braintrust 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
New file: src/braintrust/test_context.py (1,264 lines added)
"""
Context Propagation Tests for Braintrust SDK

This test suite validates context propagation behavior across various concurrency patterns.

TEST ISOLATION STRATEGY:
- Tests use pytest-forked to run each test in an isolated process
- This ensures setup_threads() patches don't leak between tests
- Use unpatched(scenario) for xfail tests (documents context loss)
- Use patched(scenario) for tests that prove setup_threads() fixes it

Example:
    def _threadpool_scenario(test_logger, with_memory_logger):
        # test logic...

    test_threadpool_loses_context = unpatched(_threadpool_scenario)
    test_threadpool_with_patch = patched(_threadpool_scenario)

Run with: pytest --forked src/braintrust/test_context.py
"""

import asyncio
import concurrent.futures
import functools
import sys
import threading
from typing import AsyncGenerator, Callable, Generator, TypeVar

import braintrust
import pytest
from braintrust import current_span, start_span
from braintrust.test_helpers import init_test_logger, with_memory_logger  # noqa: F401
from braintrust.wrappers.threads import setup_threads

F = TypeVar("F", bound=Callable)


def isolate(instrument: bool) -> Callable[[F], F]:
    """
    Decorator for isolated context propagation tests.

    - Always runs in forked process (pytest-forked)
    - If instrument=True: calls setup_threads() before test
    - If instrument=False: marks test as xfail (context loss expected)
    """

    def decorator(fn: F) -> F:
        if asyncio.iscoroutinefunction(fn):

            @functools.wraps(fn)
            async def async_wrapper(*args, **kwargs):
                if instrument:
                    setup_threads()
                return await fn(*args, **kwargs)

            wrapped = pytest.mark.forked(async_wrapper)
        else:

            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                if instrument:
                    setup_threads()
                return fn(*args, **kwargs)

            wrapped = pytest.mark.forked(wrapper)

        if not instrument:
            wrapped = pytest.mark.xfail(reason="context lost without patch")(wrapped)
        return wrapped  # type: ignore

    return decorator


patched = isolate(instrument=True)
unpatched = isolate(instrument=False)
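
For reference, a minimal stdlib-only sketch of the general technique a patch
like setup_threads() can apply: snapshot the caller's contextvars at submit
time and run the task inside that snapshot. This illustrates the idea only;
the actual braintrust.wrappers.threads implementation may differ.

    import contextvars
    from concurrent.futures import ThreadPoolExecutor

    _original_submit = ThreadPoolExecutor.submit

    def propagating_submit(self, fn, /, *args, **kwargs):
        ctx = contextvars.copy_context()  # caller's context, captured at submit time
        return _original_submit(self, lambda: ctx.run(fn, *args, **kwargs))

    # a patch would then install the wrapper:
    # ThreadPoolExecutor.submit = propagating_submit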


@pytest.fixture
def test_logger(with_memory_logger):
    """Provide a test logger for each test with memory logger."""
    logger = init_test_logger("test-context-project")
    yield logger


# ============================================================================
# CONTEXT MANAGER PATTERN: with start_span(...)
# ============================================================================


def _threadpool_scenario(test_logger, with_memory_logger):
    """ThreadPoolExecutor context propagation."""
    parent_seen_by_worker = None

    def worker_task():
        nonlocal parent_seen_by_worker
        parent_seen_by_worker = current_span()

    with start_span(name="parent") as parent_span:
        parent_id = parent_span.id
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(worker_task)
            future.result()

    assert parent_seen_by_worker is not None
    assert parent_seen_by_worker.id == parent_id


test_threadpool_loses_context = unpatched(_threadpool_scenario)
test_threadpool_with_patch = patched(_threadpool_scenario)
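
The xfail baseline here is stdlib behavior, not anything braintrust-specific:
a pool worker thread starts from a fresh context, so a ContextVar set in the
caller is invisible to it. A self-contained illustration (names made up):

    import concurrent.futures
    import contextvars

    demo_var = contextvars.ContextVar("demo_var", default=None)
    demo_var.set("set-in-parent")
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        assert pool.submit(demo_var.get).result() is None  # default, not "set-in-parent"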


def _thread_scenario(test_logger, with_memory_logger):
    """threading.Thread context propagation."""
    parent_seen_by_worker = None

    def worker_task():
        nonlocal parent_seen_by_worker
        parent_seen_by_worker = current_span()

    with start_span(name="parent") as parent_span:
        parent_id = parent_span.id
        thread = threading.Thread(target=worker_task)
        thread.start()
        thread.join()

    assert parent_seen_by_worker is not None
    assert parent_seen_by_worker.id == parent_id


test_thread_loses_context = unpatched(_thread_scenario)
test_thread_with_patch = patched(_thread_scenario)


def _nested_threadpool_scenario(test_logger, with_memory_logger):
    """Nested ThreadPoolExecutor context propagation."""
    root_seen_by_level1 = None
    level1_seen_by_level2 = None

    def level2_task():
        nonlocal level1_seen_by_level2
        level1_seen_by_level2 = current_span()

    def level1_task():
        nonlocal root_seen_by_level1
        root_seen_by_level1 = current_span()

        with start_span(name="level1") as level1_span:
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                future = executor.submit(level2_task)
                future.result()
            return level1_span.id

    with start_span(name="root") as root_span:
        root_id = root_span.id
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(level1_task)
            level1_id = future.result()

    assert root_seen_by_level1 is not None
    assert root_seen_by_level1.id == root_id
    assert level1_seen_by_level2 is not None
    assert level1_seen_by_level2.id == level1_id


test_nested_threadpool_loses_context = unpatched(_nested_threadpool_scenario)
test_nested_threadpool_with_patch = patched(_nested_threadpool_scenario)


@pytest.mark.asyncio
async def _run_in_executor_scenario(test_logger, with_memory_logger):
    """loop.run_in_executor context propagation."""
    parent_seen_by_worker = None

    def blocking_work():
        nonlocal parent_seen_by_worker
        parent_seen_by_worker = current_span()

    with start_span(name="parent") as parent_span:
        parent_id = parent_span.id
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, blocking_work)

    assert parent_seen_by_worker is not None
    assert parent_seen_by_worker.id == parent_id


test_run_in_executor_loses_context = unpatched(_run_in_executor_scenario)
test_run_in_executor_with_patch = patched(_run_in_executor_scenario)


# ============================================================================
# ASYNCIO PATTERNS (Should Work)
# ============================================================================


@pytest.mark.asyncio
async def test_asyncio_create_task_preserves_context(test_logger, with_memory_logger):
    """
    WORKS: asyncio.create_task() DOES preserve Braintrust context.
    """

    async def async_worker():
        span = current_span()
        worker_span = start_span(name="async_worker")
        await asyncio.sleep(0.001)
        worker_span.end()
        return span

    # Create parent span
    with start_span(name="parent") as parent_span:
        parent_id = parent_span.id

        # Create async task
        task = asyncio.create_task(async_worker())
        result_span = await task

    # Task SHOULD see the parent span
    assert result_span.id == parent_id, "create_task() should preserve context"

    test_logger.flush()
    logs = with_memory_logger.pop()
    assert len(logs) == 2

    parent_log = next(l for l in logs if l["span_attributes"]["name"] == "parent")
    worker_log = next(l for l in logs if l["span_attributes"]["name"] == "async_worker")

    # Worker should have parent as its parent (same trace)
    assert worker_log["root_span_id"] == parent_log["root_span_id"], "Should be in same trace"
    assert parent_log["span_id"] in worker_log.get("span_parents", []), "Worker should have parent as parent"
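
This works without patching because asyncio.Task captures
contextvars.copy_context() at creation time; the behavior is pure stdlib
(names here are illustrative):

    import asyncio
    import contextvars

    demo_var = contextvars.ContextVar("demo_var")

    async def child():
        return demo_var.get()  # visible through the task's copied context

    async def main():
        demo_var.set("set-in-parent")
        return await asyncio.create_task(child())

    assert asyncio.run(main()) == "set-in-parent"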


@pytest.mark.skipif(sys.version_info < (3, 9), reason="to_thread requires Python 3.9+")
@pytest.mark.asyncio
async def test_to_thread_preserves_context(test_logger, with_memory_logger):
    """
    WORKS: asyncio.to_thread() DOES preserve Braintrust context.
    """

    def blocking_work():
        span = current_span()
        worker_span = start_span(name="to_thread_worker")
        worker_span.end()
        return span

    # Create parent span
    with start_span(name="parent") as parent_span:
        parent_id = parent_span.id

        # Use to_thread
        result_span = await asyncio.to_thread(blocking_work)

    # to_thread SHOULD preserve context
    assert result_span.id == parent_id, "to_thread() should preserve context"

    test_logger.flush()
    logs = with_memory_logger.pop()

    # SURPRISING: Even to_thread() can lose the logger context (the logger is a
    # ContextVar too!), in which case only the parent span is logged. However,
    # to_thread() DOES preserve the span parent context.
    assert len(logs) >= 1

    # If both spans logged (logger context preserved), verify the parent chain
    if len(logs) == 2:
        parent_log = next(l for l in logs if l["span_attributes"]["name"] == "parent")
        worker_log = next(l for l in logs if l["span_attributes"]["name"] == "to_thread_worker")
        assert worker_log["root_span_id"] == parent_log["root_span_id"]
        assert parent_log["span_id"] in worker_log.get("span_parents", [])
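
The span parent survives the thread hop because asyncio.to_thread() explicitly
forwards the caller's context: per the CPython documentation it runs the
function through contextvars.copy_context(). A stdlib-only illustration:

    import asyncio
    import contextvars

    demo_var = contextvars.ContextVar("demo_var")

    async def main():
        demo_var.set("visible-in-thread")
        return await asyncio.to_thread(demo_var.get)

    assert asyncio.run(main()) == "visible-in-thread"  # Python 3.9+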


# ============================================================================
# DECORATOR PATTERN: @traced
# ============================================================================


def _traced_decorator_scenario(test_logger, with_memory_logger):
    """@traced with ThreadPoolExecutor context propagation."""
    parent_seen_by_worker = None

    def worker_function():
        nonlocal parent_seen_by_worker
        parent_seen_by_worker = current_span()

    with start_span(name="parent") as parent_span:
        parent_id = parent_span.id
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(worker_function)
            future.result()

    assert parent_seen_by_worker is not None
    assert parent_seen_by_worker.id == parent_id


test_traced_decorator_loses_context = unpatched(_traced_decorator_scenario)
test_traced_decorator_with_patch = patched(_traced_decorator_scenario)


@pytest.mark.asyncio
async def test_traced_decorator_with_async(test_logger, with_memory_logger):
    """@traced decorator works with async functions (no patching needed)."""

    @braintrust.traced
    async def child_function():
        await asyncio.sleep(0.01)
        return "child_result"

    @braintrust.traced
    async def parent_function():
        return await child_function()

    await parent_function()

    test_logger.flush()
    logs = with_memory_logger.pop()

    assert len(logs) == 2
    parent_log = next(l for l in logs if l["span_attributes"]["name"] == "parent_function")
    child_log = next(l for l in logs if l["span_attributes"]["name"] == "child_function")
    assert child_log["root_span_id"] == parent_log["root_span_id"]
    assert parent_log["span_id"] in child_log.get("span_parents", [])


# ============================================================================
# MANUAL PATTERN: start_span() + .end()
# ============================================================================


def _manual_span_scenario(test_logger, with_memory_logger):
    """Manual span with ThreadPoolExecutor context propagation."""
    parent_seen_by_worker = None

    def worker_task():
        nonlocal parent_seen_by_worker
        parent_seen_by_worker = current_span()

    parent_span = start_span(name="parent", set_current=True)
    parent_span.set_current()
    try:
        parent_id = parent_span.id
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(worker_task)
            future.result()
    finally:
        parent_span.unset_current()
        parent_span.end()

    assert parent_seen_by_worker is not None
    assert parent_seen_by_worker.id == parent_id


test_manual_span_loses_context = unpatched(_manual_span_scenario)
test_manual_span_with_patch = patched(_manual_span_scenario)


@pytest.mark.asyncio
async def test_manual_span_with_async(test_logger, with_memory_logger):
    """
    Manual span management with explicit set_current()/unset_current() calls.

    ⚠️ IMPORTANT: This pattern is MORE VERBOSE and ERROR-PRONE than context managers.

    Incorrect pattern (DOES NOT WORK):
        parent_span = start_span("parent", set_current=True)  # ❌ Just creates span
        # parent is NOT current yet!

    Correct pattern (WORKS but verbose):
        parent_span = start_span("parent", set_current=True)
        parent_span.set_current()  # ✅ Actually set as current
        try:
            await child()
        finally:
            parent_span.unset_current()  # ✅ Clean up
            parent_span.end()

    Recommended pattern (BEST):
        with start_span("parent"):  # ✅ Automatic set/unset
            await child()
    """

    async def child_work():
        child_span = start_span(name="child", set_current=True)
        child_span.set_current()  # ✅ Must call explicitly!
        try:
            await asyncio.sleep(0.01)
            return "result"
        finally:
            child_span.unset_current()  # ✅ Must clean up!
            child_span.end()

    parent_span = start_span(name="parent", set_current=True)
    parent_span.set_current()  # ✅ Must call explicitly!
    parent_id = parent_span.id
    try:
        result = await child_work()
    finally:
        parent_span.unset_current()  # ✅ Must clean up!
        parent_span.end()

    test_logger.flush()
    logs = with_memory_logger.pop()

    # Expected: Both spans should be logged
    assert len(logs) == 2, f"Expected 2 spans, got {len(logs)}"

    parent_log = next(l for l in logs if l["span_attributes"]["name"] == "parent")
    child_log = next(l for l in logs if l["span_attributes"]["name"] == "child")

    # Child should be child of parent
    assert child_log["root_span_id"] == parent_log["root_span_id"], (
        f"Child root {child_log['root_span_id']} != parent root {parent_log['root_span_id']}"
    )
    assert parent_log["span_id"] in child_log.get("span_parents", []), (
        f"Parent {parent_log['span_id']} not in child parents {child_log.get('span_parents', [])}"
    )


# ============================================================================
# INTEGRATION PATTERNS (Based on Real SDK Integrations)
# ============================================================================


@pytest.mark.asyncio
async def test_async_generator_wrapper_pattern(test_logger, with_memory_logger):
    """
    Expected: Async generators wrapping spans should maintain parent relationships.

    Real-world pattern: Wrapping SDK streams in async generators (common in pydantic-ai, etc.)

    Pattern:
        with start_span("consumer"):
            async def stream_wrapper():
                with start_span("stream_source"):
                    async for item in source():
                        yield item

            async for item in stream_wrapper():
                process(item)

    Expected trace:
        consumer
        └─ stream_source
           └─ processing spans
    """

    async def simulated_stream():
        """Simulates an async stream source."""
        for i in range(3):
            await asyncio.sleep(0.001)
            yield f"item_{i}"

    async def stream_wrapper():
        """Wraps stream in async generator (common customer pattern)."""
        with start_span(name="stream_source") as source_span:
            async for item in simulated_stream():
                yield item

    with start_span(name="consumer") as consumer_span:
        async for item in stream_wrapper():
            # Process each item
            item_span = start_span(name=f"process_{item}")
            await asyncio.sleep(0.001)
            item_span.end()

    test_logger.flush()
    logs = with_memory_logger.pop()

    # Expected: consumer + stream_source + 3 process spans = 5
    assert len(logs) == 5, f"Expected 5 spans, got {len(logs)}"

    consumer_log = next(l for l in logs if l["span_attributes"]["name"] == "consumer")
    stream_log = next(l for l in logs if l["span_attributes"]["name"] == "stream_source")
    process_logs = [l for l in logs if l["span_attributes"]["name"].startswith("process_")]

    # All should share same root
    assert stream_log["root_span_id"] == consumer_log["root_span_id"]
    for p in process_logs:
        assert p["root_span_id"] == consumer_log["root_span_id"]

    # stream_source should be child of consumer
    assert consumer_log["span_id"] in stream_log.get("span_parents", [])


def test_library_doing_context_right(test_logger, with_memory_logger):
    """
    Test: Well-behaved library (like LangChain) that properly propagates context.

    This test works WITHOUT auto-instrumentation because the library correctly
    captures context at call time using copy_context().

    Real-world example - LangChain-style pattern:
        class WellBehavedSDK:
            def run_async(self, fn):
                ctx = contextvars.copy_context()  # Captured at call time!
                return self._pool.submit(lambda: ctx.run(fn))
    """
    import contextvars

    class WellBehavedSDK:
        def __init__(self):
            self._pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)

        def run_async(self, fn):
            ctx = contextvars.copy_context()
            return self._pool.submit(lambda: ctx.run(fn))

        def shutdown(self):
            self._pool.shutdown(wait=True)

    sdk = WellBehavedSDK()

    parent_seen_by_worker = None

    def worker_function():
        nonlocal parent_seen_by_worker
        parent_seen_by_worker = current_span()

    try:
        with start_span(name="user_parent") as parent_span:
            parent_id = parent_span.id
            future = sdk.run_async(worker_function)
            future.result()
    finally:
        sdk.shutdown()

    assert parent_seen_by_worker is not None
    assert parent_seen_by_worker.id == parent_id, "Well-behaved library preserves context"


def _integration_forgot_context_scenario(test_logger, with_memory_logger):
    """Integration without context propagation."""
    parent_seen_by_worker = None

    class NaiveIntegration:
        def __init__(self):
            self._pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)

        def process(self, fn):
            return self._pool.submit(fn)

        def shutdown(self):
            self._pool.shutdown(wait=True)

    integration = NaiveIntegration()

    def worker_function():
        nonlocal parent_seen_by_worker
        parent_seen_by_worker = current_span()

    try:
        with start_span(name="user_parent") as parent_span:
            parent_id = parent_span.id
            future = integration.process(worker_function)
            future.result()
    finally:
        integration.shutdown()

    assert parent_seen_by_worker is not None
    assert parent_seen_by_worker.id == parent_id


test_integration_forgot_context_loses = unpatched(_integration_forgot_context_scenario)
test_integration_forgot_context_with_patch = patched(_integration_forgot_context_scenario)


def test_integration_early_context_not_fixable(test_logger, with_memory_logger):
    """
    Documents: An integration that captured context too early CANNOT be fixed by auto-instrumentation.

    This pattern explicitly switches to a stale context using self._ctx.run(fn),
    which overrides our auto-instrumentation. The integration's explicit context
    switch happens AFTER our wrapper, so the stale context wins.

    Pattern:
        class EagerContextIntegration:
            def __init__(self):
                self._ctx = copy_context()  # Stale context captured here

            def process(self, fn):
                return self._pool.submit(lambda: self._ctx.run(fn))  # Explicit switch to stale

    Auto-instrumentation wraps submit(), but the lambda then switches to the stale context.

    This is NOT fixable by auto-instrumentation - the integration must be fixed
    to capture context at call time, not at __init__ time.
    """
    import contextvars

    parent_seen_by_worker = None

    class EagerContextIntegration:
        def __init__(self):
            self._pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
            self._ctx = contextvars.copy_context()

        def process(self, fn):
            return self._pool.submit(lambda: self._ctx.run(fn))

        def shutdown(self):
            self._pool.shutdown(wait=True)

    integration = EagerContextIntegration()

    def worker_function():
        nonlocal parent_seen_by_worker
        parent_seen_by_worker = current_span()

    try:
        with start_span(name="user_parent") as parent_span:
            parent_id = parent_span.id
            future = integration.process(worker_function)
            future.result()
    finally:
        integration.shutdown()

    assert parent_seen_by_worker is not None, "Worker runs"
    assert parent_seen_by_worker.id != parent_id, "Worker sees STALE context, not parent (not fixable)"


def _integration_thread_scenario(test_logger, with_memory_logger):
    """Integration using Thread directly."""
    parent_seen_by_worker = None

    class ThreadIntegration:
        def process(self, fn):
            thread = threading.Thread(target=fn)
            thread.start()
            return thread

        def wait(self, thread):
            thread.join()

    integration = ThreadIntegration()

    def worker_function():
        nonlocal parent_seen_by_worker
        parent_seen_by_worker = current_span()

    with start_span(name="user_parent") as parent_span:
        parent_id = parent_span.id
        thread = integration.process(worker_function)
        integration.wait(thread)

    assert parent_seen_by_worker is not None
    assert parent_seen_by_worker.id == parent_id


test_integration_thread_loses_context = unpatched(_integration_thread_scenario)
test_integration_thread_with_patch = patched(_integration_thread_scenario)


def _integration_decorator_scenario(test_logger, with_memory_logger):
    """Decorator pattern loses context."""
    parent_seen_by_worker = None

    def async_retry_decorator(fn):
        pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)

        def wrapper(*args, **kwargs):
            future = pool.submit(fn, *args, **kwargs)
            return future.result()

        return wrapper

    @async_retry_decorator
    def user_function():
        nonlocal parent_seen_by_worker
        parent_seen_by_worker = current_span()

    with start_span(name="user_parent") as parent_span:
        parent_id = parent_span.id
        user_function()

    assert parent_seen_by_worker is not None
    assert parent_seen_by_worker.id == parent_id


test_integration_decorator_loses_context = unpatched(_integration_decorator_scenario)
test_integration_decorator_with_patch = patched(_integration_decorator_scenario)


@pytest.mark.asyncio
async def test_copy_context_token_error_across_async_tasks(test_logger, with_memory_logger):
    """
    Expected: Span lifecycle should work even when a span is started in one async
    context and ended in another (copied) context.

    Real-world pattern: LangChain creates parallel async tasks using asyncio.create_task(),
    which gives each task a COPY of the context. If a span is started in the main
    context but ended in a task context, we get:
        "ValueError: Token was created in a different Context"

    This is what LangChain's Braintrust integration silently handles!

    Pattern:
        async with start_span("parent"):
            # Span sets ContextVar token in context A

            async def task_work():
                # Task runs in context B (copy of A)
                # Try to end parent span
                # ValueError: Token from context A can't be reset in context B

            task = asyncio.create_task(task_work())  # Context copy
            await task

    Expected: Should work without errors (or handle them gracefully)
    Actual: May raise ValueError (which integrations must handle)
    """

    with start_span(name="parent"):
        parent_log = with_memory_logger.pop()[0]
        parent_span = current_span()

        # Simulate what happens in LangChain:
        # Span is started in the main context, but the callback happens in a task context

        async def task_work():
            # This runs in a COPIED context
            # If we try to manipulate parent_span here, we might hit token errors

            # This is what LangChain callbacks do:
            # 1. Create child span (works - parent_span accessible)
            with start_span(name="child"):
                await asyncio.sleep(0.01)

            # 2. Try to unset current (might fail with token error)
            try:
                parent_span.unset_current()
                token_error = None
            except ValueError as e:
                token_error = str(e)

            return token_error

        # Create task - this copies the context
        task = asyncio.create_task(task_work())
        error = await task

        # We might see a token error here
        if error and "was created in a different Context" in error:
            # This is the error LangChain's integration silently handles!
            # It's not a bug; it's an expected consequence of context copies
            pass  # Expected in async contexts

        # Child span should still be logged correctly despite the token error
        child_log = with_memory_logger.pop()[0]

        # The child span should maintain the parent relationship
        # (Braintrust SDK handles this correctly even across context boundaries)
        assert child_log["span_parents"] == [parent_log["span_id"]], (
            f"Child span should have parent relationship despite context copy. Got: {child_log.get('span_parents')}"
        )
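
The token error itself is stdlib contextvars behavior, easy to reproduce in
isolation: a Token can only be reset in the context that created it, and
create_task() runs the coroutine in a copy (illustrative names):

    import asyncio
    import contextvars

    demo_var = contextvars.ContextVar("demo_var")

    async def main():
        token = demo_var.set("outer")  # token belongs to the main task's context

        async def reset_in_task():
            demo_var.reset(token)  # copied context -> ValueError

        try:
            await asyncio.create_task(reset_in_task())
        except ValueError as err:
            print(err)  # "... was created in a different Context"

    asyncio.run(main())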


@pytest.mark.asyncio
async def test_async_generator_early_break_context_token(test_logger, with_memory_logger):
    """
    Expected: Early breaks from async generators shouldn't cause context token errors.

    Real-world issue: Breaking early from an async generator causes cleanup in a
    different async context, leading to "Token was created in a different Context" errors.

    Pattern (from pydantic-ai integration):
        async def stream_wrapper():
            with start_span("stream"):
                async for chunk in source():
                    yield chunk
                    if condition:
                        break  # Early break triggers cleanup in different context

        async for chunk in stream_wrapper():
            process(chunk)
            if done:
                break  # Consumer breaks early

    Expected: Spans logged correctly, no context token errors
    """

    async def simulated_long_stream():
        """Simulates a long stream."""
        for i in range(100):
            await asyncio.sleep(0.001)
            yield f"chunk_{i}"

    async def stream_wrapper():
        """Wraps stream, may break early (triggers cleanup in different context)."""
        with start_span(name="wrapped_stream") as stream_span:
            count = 0
            async for chunk in simulated_long_stream():
                yield chunk
                count += 1
                if count >= 3:
                    # Break early - this triggers cleanup in a different context
                    break

    with start_span(name="consumer") as consumer_span:
        chunk_count = 0

        # Consumer breaks early too
        async for chunk in stream_wrapper():
            chunk_count += 1
            if chunk_count >= 2:
                break

    # Should not raise ValueError about "Token was created in a different Context"
    test_logger.flush()
    logs = with_memory_logger.pop()

    # Expected: At least consumer and wrapped_stream spans
    assert len(logs) >= 2, f"Expected at least 2 spans, got {len(logs)}"

    consumer_log = next((l for l in logs if l["span_attributes"]["name"] == "consumer"), None)
    stream_log = next((l for l in logs if l["span_attributes"]["name"] == "wrapped_stream"), None)

    assert consumer_log is not None, "Consumer span should be logged"
    assert stream_log is not None, "Wrapped stream span should be logged despite early break"

    # wrapped_stream should be child of consumer
    if stream_log:
        assert stream_log["root_span_id"] == consumer_log["root_span_id"]
        assert consumer_log["span_id"] in stream_log.get("span_parents", [])


# ============================================================================
# ASYNC GENERATOR TESTS
# ============================================================================


@pytest.mark.asyncio
async def test_async_generator_context_behavior(test_logger, with_memory_logger):
    """
    Test how Braintrust spans behave with async generators.
    """

    async def my_async_gen() -> AsyncGenerator[int, None]:
        gen_span = start_span(name="generator_span")

        try:
            for i in range(3):
                yield i
                await asyncio.sleep(0.001)
        finally:
            gen_span.end()

    # Consumer with parent span
    with start_span(name="consumer") as consumer_span:
        results = []
        async for value in my_async_gen():
            results.append(value)
            # Consumer does work between iterations
            item_span = start_span(name=f"process_{value}")
            await asyncio.sleep(0.001)
            item_span.end()

    assert results == [0, 1, 2]

    test_logger.flush()
    logs = with_memory_logger.pop()
    # Should have consumer + generator_span + 3 process spans
    assert len(logs) == 5


@pytest.mark.asyncio
async def test_async_generator_finalization(test_logger, with_memory_logger):
    """
    Test context during async generator cleanup.
    """

    async def generator_with_finally() -> AsyncGenerator[int, None]:
        gen_span = start_span(name="gen_with_finally")

        try:
            yield 1
            yield 2
        finally:
            # What context do we have during cleanup?
            cleanup_span = current_span()
            gen_span.end()

    # Consumer
    with start_span(name="consumer") as consumer_span:
        gen = generator_with_finally()
        await gen.__anext__()  # Get first value only

        # Explicitly close generator
        await gen.aclose()

    test_logger.flush()
    logs = with_memory_logger.pop()
    assert len(logs) == 2  # consumer + gen_with_finally


# ============================================================================
# SYNC GENERATOR CONTEXT
# ============================================================================


def test_sync_generator_context_sharing(test_logger, with_memory_logger):
    """
    Sync generators share the caller's context - changes are visible.
    """

    def sync_gen() -> Generator[int, None, None]:
        for i in range(3):
            # Check current span at each iteration
            span = current_span()
            yield i

    # Create parent span
    with start_span(name="parent") as parent_span:
        gen = sync_gen()

        for i, value in enumerate(gen):
            # Create new span for each iteration
            item_span = start_span(name=f"item_{i}")
            item_span.end()

    test_logger.flush()
    logs = with_memory_logger.pop()
    assert len(logs) == 4  # parent + 3 items
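
The sharing claim is plain Python semantics: a generator has no context of its
own, so each resumption runs in whatever context its caller currently has. A
stdlib-only illustration:

    import contextvars

    demo_var = contextvars.ContextVar("demo_var", default="outer")

    def gen():
        yield demo_var.get()  # runs in the caller's context at first next()
        yield demo_var.get()  # sees the caller's later update

    g = gen()
    assert next(g) == "outer"
    demo_var.set("updated-by-caller")
    assert next(g) == "updated-by-caller"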


# ============================================================================
# REAL-WORLD PATTERN TESTS
# ============================================================================


def _thread_wrapped_async_scenario(test_logger, with_memory_logger):
    """Thread-wrapped async (Google ADK, Pydantic AI pattern)."""
    import queue as queue_module

    event_queue = queue_module.Queue()
    parent_seen_in_thread = None

    async def _invoke_async():
        nonlocal parent_seen_in_thread
        parent_seen_in_thread = current_span()
        event_queue.put("done")

    def _thread_main():
        asyncio.run(_invoke_async())
        event_queue.put(None)

    with start_span(name="parent") as parent_span:
        parent_id = parent_span.id
        thread = threading.Thread(target=_thread_main)
        thread.start()
        while True:
            event = event_queue.get()
            if event is None:
                break
        thread.join()

    assert parent_seen_in_thread is not None
    assert parent_seen_in_thread.id == parent_id


test_thread_wrapped_async_loses_context = unpatched(_thread_wrapped_async_scenario)
test_thread_wrapped_async_with_patch = patched(_thread_wrapped_async_scenario)


async def _fastapi_background_scenario(test_logger, with_memory_logger):
    """FastAPI background tasks (run_in_executor)."""
    parent_seen_by_background = None

    def background_work():
        nonlocal parent_seen_by_background
        parent_seen_by_background = current_span()

    with start_span(name="http_request") as request_span:
        request_id = request_span.id
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, background_work)

    assert parent_seen_by_background is not None
    assert parent_seen_by_background.id == request_id


test_fastapi_background_loses_context = unpatched(pytest.mark.asyncio(_fastapi_background_scenario))
test_fastapi_background_with_patch = patched(pytest.mark.asyncio(_fastapi_background_scenario))
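
Until setup_threads() is applied, the common manual workaround for this pattern
is to hop threads through a context snapshot, e.g. (an illustrative helper, not
part of the SDK):

    import asyncio
    import contextvars

    async def run_in_executor_with_context(fn, *args):
        ctx = contextvars.copy_context()  # capture before leaving this thread
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, ctx.run, fn, *args)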


def _data_pipeline_scenario(test_logger, with_memory_logger):
    """Data pipeline with parallel ThreadPoolExecutor."""
    parents_seen = []

    def process_item(item: int):
        parent = current_span()
        parents_seen.append(parent)
        return item

    with start_span(name="pipeline") as pipeline_span:
        pipeline_id = pipeline_span.id
        data = list(range(3))

        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
            futures = [executor.submit(process_item, item) for item in data]
            [f.result() for f in futures]

    assert len(parents_seen) == 3
    for i, parent in enumerate(parents_seen):
        assert parent is not None
        assert parent.id == pipeline_id


test_data_pipeline_loses_context = unpatched(_data_pipeline_scenario)
test_data_pipeline_with_patch = patched(_data_pipeline_scenario)


@pytest.mark.asyncio
async def test_streaming_llm_pattern(test_logger, with_memory_logger):
    """
    Simulates streaming LLM responses with an async generator.
    """

    async def llm_stream_generator() -> AsyncGenerator[str, None]:
        llm_span = start_span(name="llm_generation")

        try:
            for i in range(3):
                yield f"chunk_{i}"
                await asyncio.sleep(0.001)
        finally:
            llm_span.end()

    # Consumer
    with start_span(name="http_request") as request_span:
        async for chunk in llm_stream_generator():
            # Process each chunk
            chunk_span = start_span(name=f"process_{chunk}")
            await asyncio.sleep(0.001)
            chunk_span.end()

    test_logger.flush()
    logs = with_memory_logger.pop()
    assert len(logs) == 5  # request + llm_generation + 3 process chunks


# ============================================================================
# CONTEXT ISOLATION TESTS
# ============================================================================


@pytest.mark.asyncio
async def test_parallel_tasks_context_isolation(test_logger, with_memory_logger):
    """
    Test that concurrent asyncio tasks have isolated contexts.
    """
    parent_ids = []

    async def task_work(task_id: int):
        # Each task should see the root span as parent
        parent = current_span()
        parent_ids.append(parent.id)

        task_span = start_span(name=f"task_{task_id}")

        await asyncio.sleep(0.01)
        task_span.end()

    # Root span
    with start_span(name="root") as root_span:
        root_id = root_span.id

        # Spawn multiple concurrent tasks
        tasks = [asyncio.create_task(task_work(i)) for i in range(5)]
        await asyncio.gather(*tasks)

    # All tasks should have seen root as parent
    assert all(pid == root_id for pid in parent_ids), "Tasks should see root as parent"

    test_logger.flush()
    logs = with_memory_logger.pop()
    assert len(logs) == 6  # root + 5 tasks

    root_log = next(l for l in logs if l["span_attributes"]["name"] == "root")
    task_logs = [l for l in logs if l["span_attributes"]["name"].startswith("task_")]

    # All tasks should have root as parent
    for task_log in task_logs:
        assert task_log["root_span_id"] == root_log["root_span_id"]
        assert root_log["span_id"] in task_log.get("span_parents", [])


@pytest.mark.skipif(sys.version_info < (3, 11), reason="TaskGroup requires Python 3.11+")
@pytest.mark.asyncio
async def test_taskgroup_context_propagation(test_logger, with_memory_logger):
    """
    Test that TaskGroup properly propagates context (Python 3.11+).
    """

    async def child_task(task_id: int):
        child_span = start_span(name=f"child_{task_id}")
        await asyncio.sleep(0.001)
        child_span.end()

    # Root span
    with start_span(name="root") as root_span:
        async with asyncio.TaskGroup() as tg:  # pylint: disable=no-member
            for i in range(3):
                tg.create_task(child_task(i))

    test_logger.flush()
    logs = with_memory_logger.pop()
    assert len(logs) == 4  # root + 3 children

    root_log = next(l for l in logs if l["span_attributes"]["name"] == "root")
    child_logs = [l for l in logs if l["span_attributes"]["name"].startswith("child_")]

    # All children should have root as parent
    for child_log in child_logs:
        assert child_log["root_span_id"] == root_log["root_span_id"]
        assert root_log["span_id"] in child_log.get("span_parents", [])


# ============================================================================
# NESTED CONTEXT TESTS
# ============================================================================


def test_nested_spans_same_thread(test_logger, with_memory_logger):
    """
    Test that nested spans work correctly in the same thread.
    """
    # Root span
    with start_span(name="root") as root_span:
        # Verify root is current
        assert current_span().id == root_span.id

        # Child span
        with start_span(name="child") as child_span:
            child_id = child_span.id

            # Verify child is now current
            assert current_span().id == child_span.id

            # Grandchild span
            with start_span(name="grandchild") as grandchild_span:
                grandchild_id = grandchild_span.id
                assert current_span().id == grandchild_span.id

            # After grandchild closes, child should be current
            assert current_span().id == child_span.id

        # After child closes, root should be current
        assert current_span().id == root_span.id

    test_logger.flush()
    logs = with_memory_logger.pop()
    assert len(logs) == 3

    root_log = next(l for l in logs if l["span_attributes"]["name"] == "root")
    child_log = next(l for l in logs if l["span_attributes"]["name"] == "child")
    grandchild_log = next(l for l in logs if l["span_attributes"]["name"] == "grandchild")

    # Verify parent chain
    assert root_log["span_id"] == root_log["root_span_id"], "Root is root"
    assert child_log["root_span_id"] == root_log["root_span_id"], "Child same root"
    assert grandchild_log["root_span_id"] == root_log["root_span_id"], "Grandchild same root"
    assert root_log["span_id"] in child_log.get("span_parents", []), "Child parent is root"
    assert child_log["span_id"] in grandchild_log.get("span_parents", []), "Grandchild parent is child"


@pytest.mark.asyncio
async def test_deeply_nested_async_context(test_logger, with_memory_logger):
    """
    Test deeply nested spans to ensure no corruption.
    """

    async def nested_span(depth: int):
        span = start_span(name=f"depth_{depth}")

        if depth > 0:
            await nested_span(depth - 1)

        span.end()

    with start_span(name="root") as root_span:
        root_id = root_span.id
        await nested_span(10)  # 10 levels deep

    test_logger.flush()
    logs = with_memory_logger.pop()

    # Should be 12 spans: root + 11 nested (depth_10 through depth_0)
    assert len(logs) >= 12  # Allow for timing variations

    # Get the actual root (first span created)
    root_log = next((l for l in logs if l["span_attributes"]["name"] == "root"), None)
    assert root_log is not None
    actual_root_id = root_log["root_span_id"]

    # All should share same root
    for log in logs:
        assert log["root_span_id"] == actual_root_id


# ============================================================================
# EXCEPTION HANDLING
# ============================================================================


def test_context_with_exception_propagation(test_logger, with_memory_logger):
    """
    Test that context is properly maintained during exception propagation.
    """
    fail_span_id = None

    def failing_function():
        nonlocal fail_span_id
        # Use context manager for proper span lifecycle
        with start_span(name="failing_span") as fail_span:
            fail_span_id = fail_span.id
            # During this context, fail_span should be current
            assert current_span().id == fail_span.id
            raise ValueError("Expected error")

    with start_span(name="parent") as parent_span:
        parent_id = parent_span.id

        try:
            failing_function()
        except ValueError:
            pass

        # After the exception, parent should be restored as current
        assert current_span().id == parent_id

    test_logger.flush()
    logs = with_memory_logger.pop()
    assert len(logs) == 2

    parent_log = next(l for l in logs if l["span_attributes"]["name"] == "parent")
    fail_log = next(l for l in logs if l["span_attributes"]["name"] == "failing_span")

    # Verify parent chain
    assert fail_log["root_span_id"] == parent_log["root_span_id"]
    assert parent_log["span_id"] in fail_log.get("span_parents", [])


# ============================================================================
# AUTO-INSTRUMENTATION SPECIFIC TESTS
# ============================================================================


@pytest.mark.forked
def test_setup_threads_returns_true():
    """setup_threads() returns True on success."""
    result = setup_threads()
    assert result is True


@pytest.mark.forked
def test_setup_threads_idempotent():
    """Calling setup_threads() multiple times is safe."""
    result1 = setup_threads()
    result2 = setup_threads()
    assert result1 is True
    assert result2 is True


if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])