25
25
PatchTaskRequest ,
26
26
UpdateTaskRequest ,
27
27
)
28
+
29
+ # AIDEV-NOTE: Fix Pydantic forward reference issues
30
+ # Import all step types first
31
+ from agents_api .autogen .Tasks import (
32
+ EvaluateStep ,
33
+ ForeachStep ,
34
+ IfElseWorkflowStep ,
35
+ ParallelStep ,
36
+ PromptStep ,
37
+ SwitchStep ,
38
+ ToolCallStep ,
39
+ WaitForInputStep ,
40
+ YieldStep ,
41
+ )
28
42
from agents_api .clients .pg import create_db_pool
29
43
from agents_api .common .utils .memory import total_size
30
44
from agents_api .env import api_key , api_key_header_name , multi_tenant_mode
47
61
from agents_api .queries .tools .create_tools import create_tools
48
62
from agents_api .queries .users .create_user import create_user
49
63
from agents_api .web import app
50
- from aiobotocore .session import get_session
51
64
from fastapi .testclient import TestClient
52
65
from temporalio .client import WorkflowHandle
53
66
from uuid_extensions import uuid7
54
67
55
68
from .utils import (
56
- get_localstack ,
57
69
get_pg_dsn ,
58
70
make_vector_with_similarity ,
59
71
)
60
72
from .utils import (
61
73
patch_embed_acompletion as patch_embed_acompletion_ctx ,
62
74
)
63
75
64
- # AIDEV-NOTE: Fix Pydantic forward reference issues
65
- # Import all step types first
66
- from agents_api .autogen .Tasks import (
67
- EvaluateStep ,
68
- ErrorWorkflowStep ,
69
- ForeachStep ,
70
- GetStep ,
71
- IfElseWorkflowStep ,
72
- LogStep ,
73
- ParallelStep ,
74
- PromptStep ,
75
- ReturnStep ,
76
- SetStep ,
77
- SleepStep ,
78
- SwitchStep ,
79
- ToolCallStep ,
80
- WaitForInputStep ,
81
- YieldStep ,
82
- )
83
-
84
76
# Rebuild models to resolve forward references
85
77
try :
86
78
CreateTaskRequest .model_rebuild ()
@@ -220,13 +212,13 @@ async def test_doc(pg_dsn, test_developer, test_agent):
220
212
owner_id = test_agent .id ,
221
213
connection_pool = pool ,
222
214
)
223
-
215
+
224
216
# Explicitly Refresh Indices
225
217
await pool .execute ("REINDEX DATABASE" )
226
-
218
+
227
219
doc = await get_doc (developer_id = test_developer .id , doc_id = resp .id , connection_pool = pool )
228
220
yield doc
229
-
221
+
230
222
# TODO: Delete the doc
231
223
# await delete_doc(
232
224
# developer_id=test_developer.id,
@@ -245,7 +237,7 @@ async def test_doc_with_embedding(pg_dsn, test_developer, test_doc):
245
237
embedding_with_confidence_0_5 = make_vector_with_similarity (d = 0.5 )
246
238
embedding_with_confidence_neg_0_5 = make_vector_with_similarity (d = - 0.5 )
247
239
embedding_with_confidence_1_neg = make_vector_with_similarity (d = - 1.0 )
248
-
240
+
249
241
# Insert embedding with all 1.0s (similarity = 1.0)
250
242
await pool .execute (
251
243
"""
@@ -257,7 +249,7 @@ async def test_doc_with_embedding(pg_dsn, test_developer, test_doc):
257
249
test_doc .content [0 ] if isinstance (test_doc .content , list ) else test_doc .content ,
258
250
f"[{ ', ' .join ([str (x ) for x in [1.0 ] * 1024 ])} ]" ,
259
251
)
260
-
252
+
261
253
# Insert embedding with confidence 0
262
254
await pool .execute (
263
255
"""
@@ -269,7 +261,7 @@ async def test_doc_with_embedding(pg_dsn, test_developer, test_doc):
269
261
"Test content 1" ,
270
262
f"[{ ', ' .join ([str (x ) for x in embedding_with_confidence_0 ])} ]" ,
271
263
)
272
-
264
+
273
265
# Insert embedding with confidence 0.5
274
266
await pool .execute (
275
267
"""
@@ -281,7 +273,7 @@ async def test_doc_with_embedding(pg_dsn, test_developer, test_doc):
281
273
"Test content 2" ,
282
274
f"[{ ', ' .join ([str (x ) for x in embedding_with_confidence_0_5 ])} ]" ,
283
275
)
284
-
276
+
285
277
# Insert embedding with confidence -0.5
286
278
await pool .execute (
287
279
"""
@@ -293,7 +285,7 @@ async def test_doc_with_embedding(pg_dsn, test_developer, test_doc):
293
285
"Test content 3" ,
294
286
f"[{ ', ' .join ([str (x ) for x in embedding_with_confidence_neg_0_5 ])} ]" ,
295
287
)
296
-
288
+
297
289
# Insert embedding with confidence -1
298
290
await pool .execute (
299
291
"""
@@ -305,11 +297,13 @@ async def test_doc_with_embedding(pg_dsn, test_developer, test_doc):
305
297
"Test content 4" ,
306
298
f"[{ ', ' .join ([str (x ) for x in embedding_with_confidence_1_neg ])} ]" ,
307
299
)
308
-
300
+
309
301
# Explicitly Refresh Indices
310
302
await pool .execute ("REINDEX DATABASE" )
311
-
312
- yield await get_doc (developer_id = test_developer .id , doc_id = test_doc .id , connection_pool = pool )
303
+
304
+ yield await get_doc (
305
+ developer_id = test_developer .id , doc_id = test_doc .id , connection_pool = pool
306
+ )
313
307
314
308
315
309
@pytest .fixture
@@ -328,13 +322,13 @@ async def test_user_doc(pg_dsn, test_developer, test_user):
328
322
owner_id = test_user .id ,
329
323
connection_pool = pool ,
330
324
)
331
-
325
+
332
326
# Explicitly Refresh Indices
333
327
await pool .execute ("REINDEX DATABASE" )
334
-
328
+
335
329
doc = await get_doc (developer_id = test_developer .id , doc_id = resp .id , connection_pool = pool )
336
330
yield doc
337
-
331
+
338
332
# TODO: Delete the doc
339
333
340
334
@@ -376,7 +370,7 @@ async def test_new_developer(pg_dsn, random_email):
376
370
developer_id = dev_id ,
377
371
connection_pool = pool ,
378
372
)
379
-
373
+
380
374
return await get_developer (
381
375
developer_id = dev_id ,
382
376
connection_pool = pool ,
@@ -416,7 +410,7 @@ async def test_execution(
416
410
client = None ,
417
411
id = "blah" ,
418
412
)
419
-
413
+
420
414
execution = await create_execution (
421
415
developer_id = test_developer_id ,
422
416
task_id = test_task .id ,
@@ -450,7 +444,7 @@ async def test_execution_started(
450
444
client = None ,
451
445
id = "blah" ,
452
446
)
453
-
447
+
454
448
execution = await create_execution (
455
449
developer_id = test_developer_id ,
456
450
task_id = test_task .id ,
@@ -462,9 +456,9 @@ async def test_execution_started(
462
456
workflow_handle = workflow_handle ,
463
457
connection_pool = pool ,
464
458
)
465
-
459
+
466
460
actual_scope_id = custom_scope_id or uuid7 ()
467
-
461
+
468
462
# Start the execution
469
463
await create_execution_transition (
470
464
developer_id = test_developer_id ,
@@ -515,13 +509,13 @@ async def test_tool(
515
509
"description" : "A function that prints hello world" ,
516
510
"parameters" : {"type" : "object" , "properties" : {}},
517
511
}
518
-
512
+
519
513
tool_spec = {
520
514
"function" : function ,
521
515
"name" : "hello_world1" ,
522
516
"type" : "function" ,
523
517
}
524
-
518
+
525
519
[tool , * _ ] = await create_tools (
526
520
developer_id = test_developer_id ,
527
521
agent_id = test_agent .id ,
@@ -539,8 +533,15 @@ async def test_tool(
539
533
540
534
541
535
@pytest .fixture (scope = "session" )
542
- def client (pg_dsn ):
536
+ def client (pg_dsn , localstack_container ):
543
537
"""Test client fixture."""
538
+ import os
539
+
540
+ # Set S3 environment variables before creating TestClient
541
+ os .environ ["S3_ACCESS_KEY" ] = localstack_container .env ["AWS_ACCESS_KEY_ID" ]
542
+ os .environ ["S3_SECRET_KEY" ] = localstack_container .env ["AWS_SECRET_ACCESS_KEY" ]
543
+ os .environ ["S3_ENDPOINT" ] = localstack_container .get_url ()
544
+
544
545
with (
545
546
TestClient (app = app ) as test_client ,
546
547
patch (
@@ -550,63 +551,81 @@ def client(pg_dsn):
550
551
):
551
552
yield test_client
552
553
554
+ # Clean up env vars
555
+ for key in ["S3_ACCESS_KEY" , "S3_SECRET_KEY" , "S3_ENDPOINT" ]:
556
+ if key in os .environ :
557
+ del os .environ [key ]
558
+
553
559
554
560
@pytest .fixture
555
561
async def make_request (client , test_developer_id ):
556
562
"""Factory fixture for making authenticated requests."""
563
+
557
564
def _make_request (method , url , ** kwargs ):
558
565
headers = kwargs .pop ("headers" , {})
559
566
headers = {
560
567
** headers ,
561
568
api_key_header_name : api_key ,
562
569
}
563
-
570
+
564
571
if multi_tenant_mode :
565
572
headers ["X-Developer-Id" ] = str (test_developer_id )
566
-
573
+
567
574
headers ["Content-Length" ] = str (total_size (kwargs .get ("json" , {})))
568
-
575
+
569
576
return client .request (method , url , headers = headers , ** kwargs )
570
-
577
+
571
578
return _make_request
572
579
573
580
574
- @pytest_asyncio .fixture
575
- async def s3_client ():
576
- """S3 client fixture."""
577
- with get_localstack () as localstack :
578
- s3_endpoint = localstack .get_url ()
579
-
580
- from botocore .config import Config
581
-
582
- session = get_session ()
583
- s3 = await session .create_client (
584
- "s3" ,
585
- endpoint_url = s3_endpoint ,
586
- aws_access_key_id = localstack .env ["AWS_ACCESS_KEY_ID" ],
587
- aws_secret_access_key = localstack .env ["AWS_SECRET_ACCESS_KEY" ],
588
- config = Config (s3 = {'addressing_style' : 'path' })
589
- ).__aenter__ ()
590
-
591
- app .state .s3_client = s3
592
-
593
- # Create the bucket if it doesn't exist
594
- from agents_api .env import blob_store_bucket
595
- try :
596
- await s3 .head_bucket (Bucket = blob_store_bucket )
597
- except Exception :
598
- await s3 .create_bucket (Bucket = blob_store_bucket )
599
-
600
- try :
601
- yield s3
602
- finally :
603
- await s3 .close ()
604
- app .state .s3_client = None
581
+ @pytest .fixture (scope = "session" )
582
+ def localstack_container ():
583
+ """Session-scoped LocalStack container."""
584
+ from testcontainers .localstack import LocalStackContainer
585
+
586
+ localstack = LocalStackContainer (image = "localstack/localstack:s3-latest" ).with_services (
587
+ "s3"
588
+ )
589
+ localstack .start ()
590
+
591
+ try :
592
+ yield localstack
593
+ finally :
594
+ localstack .stop ()
595
+
596
+
597
+ @pytest .fixture (autouse = True , scope = "session" )
598
+ def disable_s3_cache ():
599
+ """Disable async_s3 cache during tests to avoid event loop issues."""
600
+ from agents_api .clients import async_s3
601
+
602
+ # Check if the functions are wrapped with alru_cache
603
+ if hasattr (async_s3 .setup , "__wrapped__" ):
604
+ # Save original functions
605
+ original_setup = async_s3 .setup .__wrapped__
606
+ original_exists = async_s3 .exists .__wrapped__
607
+ original_list_buckets = async_s3 .list_buckets .__wrapped__
608
+
609
+ # Replace cached functions with uncached versions
610
+ async_s3 .setup = original_setup
611
+ async_s3 .exists = original_exists
612
+ async_s3 .list_buckets = original_list_buckets
613
+
614
+ yield
615
+
616
+
617
+ @pytest .fixture
618
+ def s3_client ():
619
+ """S3 client fixture that works with TestClient's event loop."""
620
+ # The TestClient's lifespan will create the S3 client
621
+ # The disable_s3_cache fixture ensures we don't have event loop issues
622
+ yield
605
623
606
624
607
625
@pytest .fixture
608
626
async def clean_secrets (pg_dsn , test_developer_id ):
609
627
"""Fixture to clean up secrets before and after tests."""
628
+
610
629
async def purge () -> None :
611
630
pool = await create_db_pool (dsn = pg_dsn )
612
631
try :
@@ -623,7 +642,7 @@ async def purge() -> None:
623
642
finally :
624
643
# pool is closed in *the same* loop it was created in
625
644
await pool .close ()
626
-
645
+
627
646
await purge ()
628
647
yield
629
648
await purge ()
@@ -635,4 +654,4 @@ def pytest_configure(config):
635
654
config .addinivalue_line ("markers" , "slow: marks tests as slow" )
636
655
config .addinivalue_line ("markers" , "integration: marks tests as integration tests" )
637
656
config .addinivalue_line ("markers" , "unit: marks tests as unit tests" )
638
- config .addinivalue_line ("markers" , "workflow: marks tests as workflow tests" )
657
+ config .addinivalue_line ("markers" , "workflow: marks tests as workflow tests" )
0 commit comments