braintrust 0.5.0__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. braintrust/__init__.py +14 -0
  2. braintrust/_generated_types.py +56 -3
  3. braintrust/auto.py +179 -0
  4. braintrust/conftest.py +23 -4
  5. braintrust/db_fields.py +10 -0
  6. braintrust/framework.py +18 -5
  7. braintrust/generated_types.py +3 -1
  8. braintrust/logger.py +369 -134
  9. braintrust/merge_row_batch.py +49 -109
  10. braintrust/oai.py +51 -0
  11. braintrust/test_bt_json.py +0 -5
  12. braintrust/test_context.py +1264 -0
  13. braintrust/test_framework.py +37 -0
  14. braintrust/test_http.py +444 -0
  15. braintrust/test_logger.py +179 -5
  16. braintrust/test_merge_row_batch.py +160 -0
  17. braintrust/test_util.py +58 -1
  18. braintrust/util.py +20 -0
  19. braintrust/version.py +2 -2
  20. braintrust/wrappers/agno/__init__.py +2 -3
  21. braintrust/wrappers/anthropic.py +64 -0
  22. braintrust/wrappers/claude_agent_sdk/__init__.py +2 -3
  23. braintrust/wrappers/claude_agent_sdk/test_wrapper.py +9 -0
  24. braintrust/wrappers/dspy.py +52 -1
  25. braintrust/wrappers/google_genai/__init__.py +9 -6
  26. braintrust/wrappers/litellm.py +6 -43
  27. braintrust/wrappers/pydantic_ai.py +2 -3
  28. braintrust/wrappers/test_agno.py +9 -0
  29. braintrust/wrappers/test_anthropic.py +156 -0
  30. braintrust/wrappers/test_dspy.py +117 -0
  31. braintrust/wrappers/test_google_genai.py +9 -0
  32. braintrust/wrappers/test_litellm.py +57 -55
  33. braintrust/wrappers/test_openai.py +253 -1
  34. braintrust/wrappers/test_pydantic_ai_integration.py +9 -0
  35. braintrust/wrappers/test_utils.py +79 -0
  36. braintrust/wrappers/threads.py +114 -0
  37. {braintrust-0.5.0.dist-info → braintrust-0.5.3.dist-info}/METADATA +1 -1
  38. {braintrust-0.5.0.dist-info → braintrust-0.5.3.dist-info}/RECORD +41 -37
  39. {braintrust-0.5.0.dist-info → braintrust-0.5.3.dist-info}/WHEEL +1 -1
  40. braintrust/graph_util.py +0 -147
  41. {braintrust-0.5.0.dist-info → braintrust-0.5.3.dist-info}/entry_points.txt +0 -0
  42. {braintrust-0.5.0.dist-info → braintrust-0.5.3.dist-info}/top_level.txt +0 -0
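Note: the largest change in this release is the HTTP transport hardening in braintrust/logger.py (+369 -134), exercised by the new test_http.py below. The implementation itself is not part of this diff excerpt; the following is a minimal sketch (hypothetical names, not the braintrust code) of the behavior those tests pin down: apply a default timeout, retry transport errors with backoff, and reset the connection pool after a timeout so the retry gets a fresh socket.

    # Hedged sketch only; the real adapter is RetryRequestExceptionsAdapter
    # in braintrust/logger.py, which this diff excerpt does not show.
    import time

    import requests
    from requests.adapters import HTTPAdapter

    class SketchRetryAdapter(HTTPAdapter):
        def __init__(self, base_num_retries=3, backoff_factor=0.1, default_timeout_secs=60, **kwargs):
            super().__init__(**kwargs)
            self.base_num_retries = base_num_retries
            self.backoff_factor = backoff_factor
            self.default_timeout_secs = default_timeout_secs

        def send(self, request, **kwargs):
            if kwargs.get("timeout") is None:
                kwargs["timeout"] = self.default_timeout_secs  # never hang forever
            for attempt in range(self.base_num_retries + 1):
                try:
                    return super().send(request, **kwargs)
                except requests.exceptions.Timeout:
                    # A stale (e.g. NAT-dropped) connection timed out: clear the
                    # pool so the next attempt opens a fresh connection.
                    self.close()
                    if attempt == self.base_num_retries:
                        raise
                except requests.exceptions.ConnectionError:
                    if attempt == self.base_num_retries:
                        raise
                time.sleep(self.backoff_factor * (2 ** attempt))

Mounted on a requests.Session (session.mount("http://", adapter)), this is the pattern the test fixtures below exercise against deliberately hanging and connection-closing servers.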
braintrust/test_framework.py CHANGED
@@ -1,6 +1,8 @@
  from typing import List
+ from unittest.mock import MagicMock

  import pytest
+ from braintrust.logger import BraintrustState

  from .framework import (
      Eval,
@@ -241,6 +243,7 @@ async def test_hooks_trial_index_multiple_inputs():
      assert sorted(input_2_trials) == [0, 1]


+ @pytest.mark.vcr
  @pytest.mark.asyncio
  async def test_scorer_spans_have_purpose_attribute(with_memory_logger, with_simulate_login):
      """Test that scorer spans have span_attributes.purpose='scorer' and propagate to subspans."""
@@ -527,3 +530,37 @@ async def test_hooks_without_setting_tags(with_memory_logger, with_simulate_logi
      root_span = [log for log in logs if not log["span_parents"]]
      assert len(root_span) == 1
      assert root_span[0].get("tags") == None
+
+ @pytest.mark.asyncio
+ async def test_eval_enable_cache():
+     state = BraintrustState()
+     state.span_cache = MagicMock()
+
+     # Test enable_cache=False
+     await Eval(
+         "test-enable-cache-false",
+         data=[EvalCase(input=1, expected=1)],
+         task=lambda x: x,
+         scores=[],
+         state=state,
+         no_send_logs=True,
+         enable_cache=False,
+     )
+     state.span_cache.start.assert_not_called()
+     state.span_cache.stop.assert_not_called()
+
+     # Test enable_cache=True (default)
+     state.span_cache.start.reset_mock()
+     state.span_cache.stop.reset_mock()
+
+     await Eval(
+         "test-enable-cache-true",
+         data=[EvalCase(input=1, expected=1)],
+         task=lambda x: x,
+         scores=[],
+         state=state,
+         no_send_logs=True,
+         # enable_cache defaults to True
+     )
+     state.span_cache.start.assert_called()
+     state.span_cache.stop.assert_called()
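The test above pins down the new enable_cache flag on Eval: when False, the span cache is never started or stopped. Usage is a single extra keyword argument; a minimal sketch (the Eval/EvalCase import path is assumed here, and the call must run inside an async context):

    # Hedged sketch: disable the span cache for one eval run.
    from braintrust.framework import Eval, EvalCase  # import path assumed

    await Eval(
        "my-eval",
        data=[EvalCase(input=1, expected=1)],
        task=lambda x: x,
        scores=[],
        enable_cache=False,  # default is True (cache starts and stops with the run)
    )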
braintrust/test_http.py ADDED
@@ -0,0 +1,444 @@
+ """Tests for HTTP connection handling, retries, and timeouts."""
+
+ import http.server
+ import os
+ import socketserver
+ import threading
+ import time
+
+ import pytest
+ import requests
+ from braintrust.logger import HTTPConnection, RetryRequestExceptionsAdapter
+
+
+ class HangingConnectionHandler(http.server.BaseHTTPRequestHandler):
+     """HTTP handler that simulates stale connections by HANGING (not responding).
+
+     This simulates what happens when a NAT gateway silently drops packets:
+     - The TCP connection appears open
+     - Packets are sent but never acknowledged
+     - The client waits forever for a response
+     """
+
+     request_count = 0
+     hang_count = 1
+
+     def log_message(self, format, *args):
+         pass
+
+     def do_POST(self):
+         HangingConnectionHandler.request_count += 1
+
+         if HangingConnectionHandler.request_count <= HangingConnectionHandler.hang_count:
+             # Simulate stale connection: hang long enough for client to timeout
+             for _ in range(100):  # 10 seconds total, interruptible
+                 time.sleep(0.1)
+             return
+
+         self.send_response(200)
+         self.send_header("Content-Type", "application/json")
+         self.end_headers()
+         self.wfile.write(b'{"status": "ok"}')
+
+     def do_GET(self):
+         self.do_POST()
+
+
+ class CloseConnectionHandler(http.server.BaseHTTPRequestHandler):
+     """HTTP handler that closes connection immediately (triggers ConnectionError)."""
+
+     request_count = 0
+     fail_count = 1
+
+     def log_message(self, format, *args):
+         pass
+
+     def do_POST(self):
+         CloseConnectionHandler.request_count += 1
+
+         if CloseConnectionHandler.request_count <= CloseConnectionHandler.fail_count:
+             self.connection.close()
+             return
+
+         self.send_response(200)
+         self.send_header("Content-Type", "application/json")
+         self.end_headers()
+         self.wfile.write(b'{"status": "ok"}')
+
+     def do_GET(self):
+         self.do_POST()
+
+
+ @pytest.fixture
+ def hanging_server():
+     """Fixture that creates a server that HANGS on first request (simulates stale NAT)."""
+     HangingConnectionHandler.request_count = 0
+     HangingConnectionHandler.hang_count = 1
+
+     server = socketserver.ThreadingTCPServer(("127.0.0.1", 0), HangingConnectionHandler)
+     server.daemon_threads = True
+     port = server.server_address[1]
+
+     thread = threading.Thread(target=server.serve_forever)
+     thread.daemon = True
+     thread.start()
+
+     yield f"http://127.0.0.1:{port}"
+
+     server.shutdown()
+     server.server_close()
+
+
+ @pytest.fixture
+ def closing_server():
+     """Fixture that creates a server that CLOSES connection on first request."""
+     CloseConnectionHandler.request_count = 0
+     CloseConnectionHandler.fail_count = 1
+
+     server = socketserver.ThreadingTCPServer(("127.0.0.1", 0), CloseConnectionHandler)
+     server.daemon_threads = True
+     port = server.server_address[1]
+
+     thread = threading.Thread(target=server.serve_forever)
+     thread.daemon = True
+     thread.start()
+
+     yield f"http://127.0.0.1:{port}"
+
+     server.shutdown()
+     server.server_close()
+
+
+ class TestRetryRequestExceptionsAdapter:
+     """Tests for RetryRequestExceptionsAdapter timeout and retry behavior."""
+
+     def test_adapter_has_default_timeout(self):
+         """Adapter should have a default_timeout_secs attribute."""
+         adapter = RetryRequestExceptionsAdapter(base_num_retries=3, backoff_factor=0.1)
+
+         assert hasattr(adapter, "default_timeout_secs")
+         assert adapter.default_timeout_secs == 60
+
+     def test_adapter_applies_default_timeout_to_requests(self, hanging_server):
+         """Requests without explicit timeout should use default_timeout_secs."""
+         adapter = RetryRequestExceptionsAdapter(
+             base_num_retries=3,
+             backoff_factor=0.05,
+             default_timeout_secs=0.2,
+         )
+         session = requests.Session()
+         session.mount("http://", adapter)
+
+         start = time.time()
+         resp = session.post(f"{hanging_server}/test", json={"hello": "world"})
+         elapsed = time.time() - start
+
+         assert resp.status_code == 200
+         assert elapsed < 2.0, f"Should complete within 2s, took {elapsed:.2f}s"
+         assert HangingConnectionHandler.request_count >= 2
+
+     def test_adapter_retries_on_connection_close(self, closing_server):
+         """Adapter retries on connection close errors."""
+         adapter = RetryRequestExceptionsAdapter(base_num_retries=5, backoff_factor=0.05)
+         session = requests.Session()
+         session.mount("http://", adapter)
+
+         start = time.time()
+         resp = session.post(f"{closing_server}/test", json={"hello": "world"})
+         elapsed = time.time() - start
+
+         assert resp.status_code == 200
+         assert elapsed < 5.0
+         assert CloseConnectionHandler.request_count >= 2
+
+     def test_adapter_resets_pool_on_timeout(self, hanging_server):
+         """Adapter resets connection pool on timeout errors via self.close().
+
+         This is the key fix for stale NAT connections: when a request times out,
+         we reset the connection pool to ensure the next retry uses a fresh connection.
+         """
+         adapter = RetryRequestExceptionsAdapter(
+             base_num_retries=10,
+             backoff_factor=0.05,
+             default_timeout_secs=0.2,
+         )
+         session = requests.Session()
+         session.mount("http://", adapter)
+
+         start = time.time()
+         resp = session.post(f"{hanging_server}/test", json={"hello": "world"})
+         elapsed = time.time() - start
+
+         assert resp.status_code == 200
+         assert elapsed < 10.0, f"Request took too long: {elapsed:.2f}s"
+         assert HangingConnectionHandler.request_count >= 2
+
+
+ class TestHTTPConnection:
+     """Tests for HTTPConnection timeout configuration."""
+
+     def test_make_long_lived_uses_default_timeout(self, hanging_server):
+         """HTTPConnection.make_long_lived() should use default_timeout_secs.
+
+         This tests the exact scenario from the stale connection bug:
+         - Long eval run (15+ minutes)
+         - app_conn() becomes stale due to NAT gateway idle timeout
+         - summarize() calls fetch_base_experiment()
+         - Request hangs forever because no timeout
+
+         With the fix, make_long_lived() uses default_timeout_secs (60s by default).
+         """
+         os.environ["BRAINTRUST_HTTP_TIMEOUT"] = "0.2"
+         try:
+             conn = HTTPConnection(hanging_server)
+             conn.make_long_lived()
+
+             assert hasattr(conn.adapter, "default_timeout_secs")
+             assert conn.adapter.default_timeout_secs == 0.2
+
+             start = time.time()
+             resp = conn.post("/test", json={"hello": "world"})
+             elapsed = time.time() - start
+
+             assert resp.status_code == 200
+             # Allow more time due to backoff_factor=0.5 in make_long_lived()
+             assert elapsed < 15.0, f"Should complete within 15s, took {elapsed:.2f}s"
+         finally:
+             del os.environ["BRAINTRUST_HTTP_TIMEOUT"]
+
+     def test_env_var_configures_timeout(self):
+         """BRAINTRUST_HTTP_TIMEOUT env var configures timeout via make_long_lived()."""
+         os.environ["BRAINTRUST_HTTP_TIMEOUT"] = "30"
+         try:
+             conn = HTTPConnection("http://localhost:8080")
+             conn.make_long_lived()
+
+             assert hasattr(conn.adapter, "default_timeout_secs")
+             assert conn.adapter.default_timeout_secs == 30
+         finally:
+             del os.environ["BRAINTRUST_HTTP_TIMEOUT"]
+
+
+ class TestAdapterCloseAndReuse:
+     """Tests verifying that adapter.close() allows subsequent requests.
+
+     This addresses the review concern about whether calling self.close()
+     (which calls PoolManager.clear()) breaks subsequent request handling.
+     """
+
+     @pytest.fixture
+     def simple_server(self):
+         """Fixture that creates a server that always succeeds."""
+
+         class SimpleHandler(http.server.BaseHTTPRequestHandler):
+             request_count = 0
+
+             def log_message(self, format, *args):
+                 pass
+
+             def do_GET(self):
+                 SimpleHandler.request_count += 1
+                 self.send_response(200)
+                 self.send_header("Content-Type", "application/json")
+                 self.end_headers()
+                 self.wfile.write(b'{"status": "ok"}')
+
+         SimpleHandler.request_count = 0
+         server = socketserver.ThreadingTCPServer(("127.0.0.1", 0), SimpleHandler)
+         server.daemon_threads = True
+         port = server.server_address[1]
+
+         thread = threading.Thread(target=server.serve_forever)
+         thread.daemon = True
+         thread.start()
+
+         yield f"http://127.0.0.1:{port}", SimpleHandler
+
+         server.shutdown()
+         server.server_close()
+
+     def test_adapter_works_after_close(self, simple_server):
+         """Verify adapter.close() does not break subsequent requests.
+
+         This is the key test for the PR feedback: after calling close(),
+         the PoolManager should create new connection pools on demand.
+         """
+         url, handler = simple_server
+
+         adapter = RetryRequestExceptionsAdapter(base_num_retries=3, backoff_factor=0.1)
+         session = requests.Session()
+         session.mount("http://", adapter)
+
+         # First request works
+         resp1 = session.get(f"{url}/test1")
+         assert resp1.status_code == 200
+         assert handler.request_count == 1
+
+         # Explicitly close the adapter (simulates what happens on timeout)
+         adapter.close()
+
+         # Second request should still work after close()
+         resp2 = session.get(f"{url}/test2")
+         assert resp2.status_code == 200
+         assert handler.request_count == 2
+
+     def test_adapter_works_after_multiple_closes(self, simple_server):
+         """Verify adapter works even after multiple close() calls."""
+         url, handler = simple_server
+
+         adapter = RetryRequestExceptionsAdapter(base_num_retries=3, backoff_factor=0.1)
+         session = requests.Session()
+         session.mount("http://", adapter)
+
+         for i in range(3):
+             resp = session.get(f"{url}/test{i}")
+             assert resp.status_code == 200
+             adapter.close()
+
+         assert handler.request_count == 3
+
+     def test_concurrent_requests_with_close(self):
+         """Test thread safety: close() called while requests are in-flight.
+
+         This tests a potential race condition where one thread calls close()
+         while another thread is mid-request. Requests are staggered to ensure
+         close() happens while some requests are in-flight.
+         """
+         import concurrent.futures
+
+         class SlowHandler(http.server.BaseHTTPRequestHandler):
+             request_count = 0
+
+             def log_message(self, format, *args):
+                 pass
+
+             def do_GET(self):
+                 SlowHandler.request_count += 1
+                 # Simulate slow response
+                 time.sleep(0.1)
+                 self.send_response(200)
+                 self.send_header("Content-Type", "application/json")
+                 self.end_headers()
+                 self.wfile.write(b'{"status": "ok"}')
+
+         SlowHandler.request_count = 0
+         server = socketserver.ThreadingTCPServer(("127.0.0.1", 0), SlowHandler)
+         server.daemon_threads = True
+         port = server.server_address[1]
+         url = f"http://127.0.0.1:{port}"
+
+         server_thread = threading.Thread(target=server.serve_forever)
+         server_thread.daemon = True
+         server_thread.start()
+
+         try:
+             adapter = RetryRequestExceptionsAdapter(base_num_retries=3, backoff_factor=0.1)
+             session = requests.Session()
+             session.mount("http://", adapter)
+
+             errors = []
+
+             def make_request(i):
+                 try:
+                     time.sleep(i * 0.02)  # Stagger requests
+                     resp = session.get(f"{url}/test{i}")
+                     return resp.status_code
+                 except Exception as e:
+                     errors.append(e)
+                     return None
+
+             def close_adapter():
+                 time.sleep(0.05)  # Close while requests are in-flight
+                 adapter.close()
+
+             # Launch concurrent requests and a close() call
+             with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
+                 # Start several requests (staggered)
+                 request_futures = [executor.submit(make_request, i) for i in range(5)]
+                 # Start close() call mid-flight
+                 close_future = executor.submit(close_adapter)
+
+                 close_future.result()
+                 results = [f.result() for f in request_futures]
+
+             # All requests should succeed (retry on failure)
+             assert all(r == 200 for r in results), f"Some requests failed: {results}, errors: {errors}"
+
+         finally:
+             server.shutdown()
+             server.server_close()
+
+     def test_stress_concurrent_close_and_requests(self):
+         """Stress test: many close() calls interleaved with requests.
+
+         Requests are staggered to ensure close() calls happen during requests.
+         """
+         import concurrent.futures
+
+         class FastHandler(http.server.BaseHTTPRequestHandler):
+             request_count = 0
+
+             def log_message(self, format, *args):
+                 pass
+
+             def do_GET(self):
+                 FastHandler.request_count += 1
+                 self.send_response(200)
+                 self.send_header("Content-Type", "application/json")
+                 self.end_headers()
+                 self.wfile.write(b'{"status": "ok"}')
+
+         FastHandler.request_count = 0
+         server = socketserver.ThreadingTCPServer(("127.0.0.1", 0), FastHandler)
+         server.daemon_threads = True
+         port = server.server_address[1]
+         url = f"http://127.0.0.1:{port}"
+
+         server_thread = threading.Thread(target=server.serve_forever)
+         server_thread.daemon = True
+         server_thread.start()
+
+         try:
+             adapter = RetryRequestExceptionsAdapter(base_num_retries=5, backoff_factor=0.01)
+             session = requests.Session()
+             session.mount("http://", adapter)
+
+             errors = []
+             success_count = 0
+             lock = threading.Lock()
+
+             def make_request(i):
+                 nonlocal success_count
+                 try:
+                     time.sleep(i * 0.005)  # Stagger requests
+                     resp = session.get(f"{url}/test{i}")
+                     if resp.status_code == 200:
+                         with lock:
+                             success_count += 1
+                     return resp.status_code
+                 except Exception as e:
+                     with lock:
+                         errors.append(str(e))
+                     return None
+
+             def close_repeatedly():
+                 for _ in range(20):
+                     time.sleep(0.01)  # Close throughout the request window
+                     adapter.close()
+
+             # Launch many concurrent requests while repeatedly closing
+             with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
+                 request_futures = [executor.submit(make_request, i) for i in range(50)]
+                 close_futures = [executor.submit(close_repeatedly) for _ in range(3)]
+
+                 # Wait for all
+                 for f in close_futures:
+                     f.result()
+                 results = [f.result() for f in request_futures]
+
+             failed = [r for r in results if r != 200]
+             assert len(failed) == 0, f"Failed requests: {len(failed)}, errors: {errors[:5]}"
+
+         finally:
+             server.shutdown()
+             server.server_close()
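As the TestHTTPConnection cases above show, the timeout is also configurable end to end via an environment variable. A short usage sketch, grounded directly in test_env_var_configures_timeout (the base URL is illustrative):

    # BRAINTRUST_HTTP_TIMEOUT (seconds) is read when make_long_lived()
    # installs the retry adapter, so set it first.
    import os

    os.environ["BRAINTRUST_HTTP_TIMEOUT"] = "30"

    from braintrust.logger import HTTPConnection

    conn = HTTPConnection("http://localhost:8080")  # illustrative base URL
    conn.make_long_lived()
    assert conn.adapter.default_timeout_secs == 30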
braintrust/test_logger.py CHANGED
@@ -833,6 +833,16 @@ def test_span_project_id_logged_in(with_memory_logger, with_simulate_login):
      )


+ def test_span_export_disables_cache(with_memory_logger):
+     """Test that span.export() disables the span cache."""
+     logger = init_test_logger(__name__)
+
+     with logger.start_span(name="test_span") as span:
+         # Exporting should disable the span cache
+         span.export()
+         assert logger.state.span_cache.disabled
+
+
  def test_span_project_name_logged_in(with_simulate_login, with_memory_logger):
      init_logger(project="test-project")
      span = logger.start_span(name="test-span")
@@ -929,11 +939,7 @@ def test_permalink_with_valid_span_logged_in(with_simulate_login, with_memory_lo

  @pytest.mark.asyncio
  async def test_span_link_in_async_context(with_simulate_login, with_memory_logger):
-     """Test that span.link() works correctly when called from within an async function.
-
-     This tests the bug where current_logger was a plain attribute instead of a ContextVar,
-     causing span.link() to return a noop link in async contexts even though the span was valid.
-     """
+     """Test that span.link() works correctly when called from within an async function."""
      import asyncio

      logger = init_logger(
@@ -966,6 +972,174 @@ async def test_span_link_in_async_context(with_simulate_login, with_memory_logge
      assert "test-project-id" in link


+ @pytest.mark.asyncio
+ async def test_current_logger_after_multiple_awaits(with_simulate_login, with_memory_logger):
+     """Test that current_logger() works after multiple await points."""
+     import asyncio
+
+     logger = init_logger(project="test-project", project_id="test-project-id")
+
+     async def check_logger_after_awaits():
+         assert braintrust.current_logger() is logger
+         await asyncio.sleep(0.01)
+         assert braintrust.current_logger() is logger
+         await asyncio.sleep(0.01)
+         assert braintrust.current_logger() is logger
+         return braintrust.current_logger()
+
+     result = await check_logger_after_awaits()
+     assert result is logger
+
+
+ @pytest.mark.asyncio
+ async def test_current_logger_in_async_generator(with_simulate_login, with_memory_logger):
+     """Test that current_logger() works within an async generator (yield)."""
+     import asyncio
+
+     logger = init_logger(project="test-project", project_id="test-project-id")
+
+     async def logger_generator():
+         for i in range(3):
+             await asyncio.sleep(0.01)
+             yield braintrust.current_logger()
+
+     results = []
+     async for log in logger_generator():
+         results.append(log)
+
+     assert len(results) == 3
+     assert all(r is logger for r in results)
+
+
+ @pytest.mark.asyncio
+ async def test_current_logger_in_separate_task(with_simulate_login, with_memory_logger):
+     """Test that current_logger() works in a separately created asyncio task."""
+     import asyncio
+
+     logger = init_logger(project="test-project", project_id="test-project-id")
+
+     async def get_logger_in_task():
+         await asyncio.sleep(0.01)
+         return braintrust.current_logger()
+
+     # Create a separate task
+     task = asyncio.create_task(get_logger_in_task())
+     result = await task
+
+     assert result is logger
+
+
+ @pytest.mark.asyncio
+ async def test_span_link_in_nested_async(with_simulate_login, with_memory_logger):
+     """Test that span.link() works in deeply nested async calls."""
+     import asyncio
+
+     logger = init_logger(project="test-project", project_id="test-project-id")
+     span = logger.start_span(name="test-span")
+
+     async def level3():
+         await asyncio.sleep(0.01)
+         return span.link()
+
+     async def level2():
+         await asyncio.sleep(0.01)
+         return await level3()
+
+     async def level1():
+         await asyncio.sleep(0.01)
+         return await level2()
+
+     link = await level1()
+     span.end()
+
+     assert link != "https://www.braintrust.dev/noop-span"
+     assert span._id in link
+
+
+ def test_current_logger_in_thread(with_simulate_login, with_memory_logger):
+     """Test that current_logger() works correctly when called from a new thread.
+
+     Regression test: ContextVar values don't propagate to new threads,
+     so current_logger must be a plain attribute for thread access.
+     """
+     import threading
+
+     logger = init_logger(project="test-project", project_id="test-project-id")
+     assert braintrust.current_logger() is logger
+
+     thread_result = {}
+
+     def check_logger_in_thread():
+         thread_result["logger"] = braintrust.current_logger()
+
+     thread = threading.Thread(target=check_logger_in_thread)
+     thread.start()
+     thread.join()
+
+     assert thread_result["logger"] is logger
+
+
+ def test_span_link_in_thread(with_simulate_login, with_memory_logger):
+     """Test that span.link() works correctly when called from a new thread.
+
+     The span should be able to generate a valid link even when link() is called
+     from a different thread than where the span was created.
+     """
+     import threading
+
+     logger = init_logger(project="test-project", project_id="test-project-id")
+     span = logger.start_span(name="test-span")
+
+     thread_result = {}
+
+     def get_link_in_thread():
+         # Call link() on the span directly (not via current_span() which uses ContextVar)
+         thread_result["link"] = span.link()
+
+     thread = threading.Thread(target=get_link_in_thread)
+     thread.start()
+     thread.join()
+     span.end()
+
+     # The link should NOT be the noop link
+     assert thread_result["link"] != "https://www.braintrust.dev/noop-span"
+     # The link should contain the span ID
+     assert span._id in thread_result["link"]
+
+
+ @pytest.mark.asyncio
+ async def test_current_logger_async_context_isolation(with_simulate_login, with_memory_logger):
+     """Test that different async contexts can have different loggers.
+
+     When a child task sets its own logger, it should not affect the parent context.
+     This ensures async context isolation via ContextVar.
+     """
+     import asyncio
+
+     parent_logger = init_logger(project="parent-project", project_id="parent-project-id")
+     assert braintrust.current_logger() is parent_logger
+
+     child_result = {}
+
+     async def child_task():
+         # Child initially inherits parent's logger
+         assert braintrust.current_logger() is parent_logger
+
+         # Child sets its own logger
+         child_logger = init_logger(project="child-project", project_id="child-project-id")
+         child_result["logger"] = braintrust.current_logger()
+         return child_logger
+
+     # Run child task
+     child_logger = await asyncio.create_task(child_task())
+
+     # Child should have seen its own logger
+     assert child_result["logger"] is child_logger
+
+     # Parent should still see parent logger (not affected by child)
+     assert braintrust.current_logger() is parent_logger
+
+
  def test_span_set_current(with_memory_logger):
      """Test that span.set_current() makes the span accessible via current_span()."""
      init_test_logger(__name__)
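Taken together, the thread and async tests above constrain how current_logger() must be stored: a ContextVar alone would fail the thread tests (a new thread starts from a fresh context), while a plain attribute alone would fail the isolation test. One pattern that satisfies both, as an illustrative sketch rather than braintrust's actual internals:

    # Illustrative sketch only: a ContextVar gives per-task isolation, while a
    # module-level fallback keeps the logger visible from freshly started threads.
    import contextvars

    _logger_var = contextvars.ContextVar("current_logger", default=None)
    _logger_fallback = None  # plain attribute, visible from any thread

    def set_current_logger(logger):
        global _logger_fallback
        _logger_fallback = logger  # new threads see this
        _logger_var.set(logger)    # async tasks see their own context's value

    def current_logger():
        # Inside the context that set it, the ContextVar wins; a new thread
        # starts with a fresh context, so it falls back to the attribute.
        return _logger_var.get() or _logger_fallback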