braintrust 0.5.0__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registries.
Files changed (42)
  1. braintrust/__init__.py +14 -0
  2. braintrust/_generated_types.py +56 -3
  3. braintrust/auto.py +179 -0
  4. braintrust/conftest.py +23 -4
  5. braintrust/db_fields.py +10 -0
  6. braintrust/framework.py +18 -5
  7. braintrust/generated_types.py +3 -1
  8. braintrust/logger.py +369 -134
  9. braintrust/merge_row_batch.py +49 -109
  10. braintrust/oai.py +51 -0
  11. braintrust/test_bt_json.py +0 -5
  12. braintrust/test_context.py +1264 -0
  13. braintrust/test_framework.py +37 -0
  14. braintrust/test_http.py +444 -0
  15. braintrust/test_logger.py +179 -5
  16. braintrust/test_merge_row_batch.py +160 -0
  17. braintrust/test_util.py +58 -1
  18. braintrust/util.py +20 -0
  19. braintrust/version.py +2 -2
  20. braintrust/wrappers/agno/__init__.py +2 -3
  21. braintrust/wrappers/anthropic.py +64 -0
  22. braintrust/wrappers/claude_agent_sdk/__init__.py +2 -3
  23. braintrust/wrappers/claude_agent_sdk/test_wrapper.py +9 -0
  24. braintrust/wrappers/dspy.py +52 -1
  25. braintrust/wrappers/google_genai/__init__.py +9 -6
  26. braintrust/wrappers/litellm.py +6 -43
  27. braintrust/wrappers/pydantic_ai.py +2 -3
  28. braintrust/wrappers/test_agno.py +9 -0
  29. braintrust/wrappers/test_anthropic.py +156 -0
  30. braintrust/wrappers/test_dspy.py +117 -0
  31. braintrust/wrappers/test_google_genai.py +9 -0
  32. braintrust/wrappers/test_litellm.py +57 -55
  33. braintrust/wrappers/test_openai.py +253 -1
  34. braintrust/wrappers/test_pydantic_ai_integration.py +9 -0
  35. braintrust/wrappers/test_utils.py +79 -0
  36. braintrust/wrappers/threads.py +114 -0
  37. {braintrust-0.5.0.dist-info → braintrust-0.5.3.dist-info}/METADATA +1 -1
  38. {braintrust-0.5.0.dist-info → braintrust-0.5.3.dist-info}/RECORD +41 -37
  39. {braintrust-0.5.0.dist-info → braintrust-0.5.3.dist-info}/WHEEL +1 -1
  40. braintrust/graph_util.py +0 -147
  41. {braintrust-0.5.0.dist-info → braintrust-0.5.3.dist-info}/entry_points.txt +0 -0
  42. {braintrust-0.5.0.dist-info → braintrust-0.5.3.dist-info}/top_level.txt +0 -0
braintrust/logger.py CHANGED
@@ -50,6 +50,8 @@ from .db_fields import (
  AUDIT_METADATA_FIELD,
  AUDIT_SOURCE_FIELD,
  IS_MERGE_FIELD,
+ OBJECT_DELETE_FIELD,
+ OBJECT_ID_KEYS,
  TRANSACTION_ID_FIELD,
  VALID_SOURCES,
  )
@@ -87,6 +89,7 @@ from .util import (
  get_caller_location,
  mask_api_key,
  merge_dicts,
+ parse_env_var_float,
  response_raise_for_status,
  )

@@ -97,6 +100,23 @@ from .xact_ids import prettify_xact

  Metadata = dict[str, Any]
  DATA_API_VERSION = 2
+ LOGS3_OVERFLOW_REFERENCE_TYPE = "logs3_overflow"
+ # 6 MB for the AWS lambda gateway (from our own testing).
+ DEFAULT_MAX_REQUEST_SIZE = 6 * 1024 * 1024
+
+
+ @dataclasses.dataclass
+ class Logs3OverflowInputRow:
+ object_ids: dict[str, Any]
+ has_comment: bool
+ is_delete: bool
+ byte_size: int
+
+
+ @dataclasses.dataclass
+ class LogItemWithMeta:
+ str_value: str
+ overflow_meta: Logs3OverflowInputRow


  class DatasetRef(TypedDict, total=False):
@@ -349,9 +369,16 @@ class BraintrustState:
  def __init__(self):
  self.id = str(uuid.uuid4())
  self.current_experiment: Experiment | None = None
- self.current_logger: contextvars.ContextVar[Logger | None] = contextvars.ContextVar(
+ # We use both a ContextVar and a plain attribute for the current logger:
+ # - _cv_logger (ContextVar): Provides async context isolation so different
+ # async tasks can have different loggers without affecting each other.
+ # - _local_logger (plain attribute): Fallback for threads, since ContextVars
+ # don't propagate to new threads. This way if users don't want to do
+ # anything specific they'll always have a "global logger"
+ self._cv_logger: contextvars.ContextVar[Logger | None] = contextvars.ContextVar(
  "braintrust_current_logger", default=None
  )
+ self._local_logger: Logger | None = None
  self.current_parent: contextvars.ContextVar[str | None] = contextvars.ContextVar(
  "braintrust_current_parent", default=None
  )
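The dual tracking above is easiest to see from user code. A minimal sketch (not part of the diff), assuming init_logger and current_logger are exported from the package root as in earlier releases:

    import threading
    import braintrust

    logger = braintrust.init_logger(project="my-project")  # sets both _cv_logger and _local_logger

    def worker():
        # ContextVars set on the main thread do not propagate into a newly created
        # thread, so this lookup could come back empty before this change; the
        # _local_logger fallback now resolves to the logger initialized above.
        assert braintrust.current_logger() is not None

    t = threading.Thread(target=worker)
    t.start()
    t.join()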
@@ -425,7 +452,8 @@ class BraintrustState:
  def reset_parent_state(self):
  # reset possible parent state for tests
  self.current_experiment = None
- self.current_logger.set(None)
+ self._cv_logger.set(None)
+ self._local_logger = None
  self.current_parent.set(None)
  self.current_span.set(NOOP_SPAN)

@@ -479,22 +507,25 @@ class BraintrustState:

  def copy_state(self, other: "BraintrustState"):
  """Copy login information from another BraintrustState instance."""
- self.__dict__.update({
- k: v
- for (k, v) in other.__dict__.items()
- if k
- not in (
- "current_experiment",
- "current_logger",
- "current_parent",
- "current_span",
- "_global_bg_logger",
- "_override_bg_logger",
- "_context_manager",
- "_last_otel_setting",
- "_context_manager_lock",
- )
- })
+ self.__dict__.update(
+ {
+ k: v
+ for (k, v) in other.__dict__.items()
+ if k
+ not in (
+ "current_experiment",
+ "_cv_logger",
+ "_local_logger",
+ "current_parent",
+ "current_span",
+ "_global_bg_logger",
+ "_override_bg_logger",
+ "_context_manager",
+ "_last_otel_setting",
+ "_context_manager_lock",
+ )
+ }
+ )

  def login(
  self,
@@ -555,10 +586,6 @@ class BraintrustState:
  self._user_info = self.api_conn().get_json("ping")
  return self._user_info

- def set_user_info_if_null(self, info: Mapping[str, Any]):
- if not self._user_info:
- self._user_info = info
-
  def global_bg_logger(self) -> "_BackgroundLogger":
  return getattr(self._override_bg_logger, "logger", None) or self._global_bg_logger.get()

@@ -620,14 +647,28 @@ class RetryRequestExceptionsAdapter(HTTPAdapter):
  base_num_retries: Maximum number of retries before giving up and re-raising the exception.
  backoff_factor: A multiplier used to determine the time to wait between retries.
  The actual wait time is calculated as: backoff_factor * (2 ** retry_count).
+ default_timeout_secs: Default timeout in seconds for requests that don't specify one.
+ Prevents indefinite hangs on stale connections.
  """

- def __init__(self, *args: Any, base_num_retries: int = 0, backoff_factor: float = 0.5, **kwargs: Any):
+ def __init__(
+ self,
+ *args: Any,
+ base_num_retries: int = 0,
+ backoff_factor: float = 0.5,
+ default_timeout_secs: float = 60,
+ **kwargs: Any,
+ ):
  self.base_num_retries = base_num_retries
  self.backoff_factor = backoff_factor
+ self.default_timeout_secs = default_timeout_secs
  super().__init__(*args, **kwargs)

  def send(self, *args, **kwargs):
+ # Apply default timeout if none provided to prevent indefinite hangs
+ if kwargs.get("timeout") is None:
+ kwargs["timeout"] = self.default_timeout_secs
+
  num_prev_retries = 0
  while True:
  try:
@@ -639,6 +680,14 @@ class RetryRequestExceptionsAdapter(HTTPAdapter):
  return response
  except (urllib3.exceptions.HTTPError, requests.exceptions.RequestException) as e:
  if num_prev_retries < self.base_num_retries:
+ if isinstance(e, requests.exceptions.ReadTimeout):
+ # Clear all connection pools to discard stale connections. This
+ # fixes hangs caused by NAT gateways silently dropping idle TCP
+ # connections (e.g., Azure's ~4 min timeout). close() calls
+ # PoolManager.clear() which is thread-safe: in-flight requests
+ # keep their checked-out connections, and new requests create
+ # fresh pools on demand.
+ self.close()
  # Emulates the sleeping logic in the backoff_factor of urllib3 Retry
  sleep_s = self.backoff_factor * (2**num_prev_retries)
  print("Retrying request after error:", e, file=sys.stderr)
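Since the adapter above is a standard requests HTTPAdapter, the new default timeout and the pool-clearing-on-ReadTimeout behavior apply to whatever session it is mounted on; make_long_lived() in the next hunk is where the SDK does this. A rough standalone sketch (not part of the diff):

    import requests
    from braintrust.logger import RetryRequestExceptionsAdapter

    session = requests.Session()
    adapter = RetryRequestExceptionsAdapter(base_num_retries=10, backoff_factor=0.5, default_timeout_secs=60)
    session.mount("https://", adapter)
    session.mount("http://", adapter)
    # Requests issued without an explicit timeout now inherit the 60-second default,
    # and a ReadTimeout on a retryable attempt discards the stale connection pools.
    resp = session.get("https://example.com/ping")  # placeholder URL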
@@ -660,14 +709,16 @@ class HTTPConnection:
  def ping(self) -> bool:
  try:
  resp = self.get("ping")
- _state.set_user_info_if_null(resp.json())
  return resp.ok
  except requests.exceptions.ConnectionError:
  return False

  def make_long_lived(self) -> None:
  if not self.adapter:
- self.adapter = RetryRequestExceptionsAdapter(base_num_retries=10, backoff_factor=0.5)
+ timeout_secs = parse_env_var_float("BRAINTRUST_HTTP_TIMEOUT", 60.0)
+ self.adapter = RetryRequestExceptionsAdapter(
+ base_num_retries=10, backoff_factor=0.5, default_timeout_secs=timeout_secs
+ )
  self._reset()

  @staticmethod
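The timeout that make_long_lived() passes along can be overridden per process. A small sketch (not part of the diff); the value is read with parse_env_var_float, so fractional seconds are accepted:

    import os

    # Must be set before the long-lived connection is created (i.e., before the
    # first login/flush); falls back to 60.0 seconds when unset or unparsable.
    os.environ["BRAINTRUST_HTTP_TIMEOUT"] = "120.5"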
@@ -711,18 +762,10 @@ class HTTPConnection:
  def delete(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:
  return self.session.delete(_urljoin(self.base_url, path), *args, **kwargs)

- def get_json(self, object_type: str, args: Mapping[str, Any] | None = None, retries: int = 0) -> Mapping[str, Any]:
- tries = retries + 1
- for i in range(tries):
- resp = self.get(f"/{object_type}", params=args)
- if i < tries - 1 and not resp.ok:
- _logger.warning(f"Retrying API request {object_type} {args} {resp.status_code} {resp.text}")
- continue
- response_raise_for_status(resp)
-
- return resp.json()
- # Needed for type checking.
- raise Exception("unreachable")
+ def get_json(self, object_type: str, args: Mapping[str, Any] | None = None) -> Mapping[str, Any]:
+ resp = self.get(f"/{object_type}", params=args)
+ response_raise_for_status(resp)
+ return resp.json()

  def post_json(self, object_type: str, args: Mapping[str, Any] | None = None) -> Any:
  resp = self.post(f"/{object_type.lstrip('/')}", json=args)
@@ -760,11 +803,43 @@ def construct_json_array(items: Sequence[str]):
  return "[" + ",".join(items) + "]"


- def construct_logs3_data(items: Sequence[str]):
- rowsS = construct_json_array(items)
+ def construct_logs3_data(items: Sequence[LogItemWithMeta]):
+ rowsS = construct_json_array([item.str_value for item in items])
  return '{"rows": ' + rowsS + ', "api_version": ' + str(DATA_API_VERSION) + "}"


+ def construct_logs3_overflow_request(key: str, size_bytes: int | None = None) -> dict[str, Any]:
+ rows: dict[str, Any] = {"type": LOGS3_OVERFLOW_REFERENCE_TYPE, "key": key}
+ if size_bytes is not None:
+ rows["size_bytes"] = size_bytes
+ return {"rows": rows, "api_version": DATA_API_VERSION}
+
+
+ def pick_logs3_overflow_object_ids(row: Mapping[str, Any]) -> dict[str, Any]:
+ object_ids: dict[str, Any] = {}
+ for key in OBJECT_ID_KEYS:
+ if key in row:
+ object_ids[key] = row[key]
+ return object_ids
+
+
+ def stringify_with_overflow_meta(item: dict[str, Any]) -> LogItemWithMeta:
+ str_value = bt_dumps(item)
+ return LogItemWithMeta(
+ str_value=str_value,
+ overflow_meta=Logs3OverflowInputRow(
+ object_ids=pick_logs3_overflow_object_ids(item),
+ has_comment="comment" in item,
+ is_delete=item.get(OBJECT_DELETE_FIELD) is True,
+ byte_size=utf8_byte_length(str_value),
+ ),
+ )
+
+
+ def utf8_byte_length(value: str) -> int:
+ return len(value.encode("utf-8"))
+
+
  class _MaskingError:
  """Internal class to signal masking errors that need special handling."""

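One detail worth noting (illustration, not part of the diff): byte_size is measured on the UTF-8 encoding rather than the Python string length, so multi-byte characters are counted correctly when a payload is compared against the request-size limit:

    s = "héllo"
    assert len(s) == 5                  # character count
    assert len(s.encode("utf-8")) == 6  # what utf8_byte_length() reports; "é" is 2 bytes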
@@ -854,15 +929,12 @@ class _MemoryBackgroundLogger(_BackgroundLogger):

  # all the logs get merged before gettig sent to the server, so simulate that
  # here
- merged = merge_row_batch(logs)
- first = merged[0]
- for other in merged[1:]:
- first.extend(other)
+ batch = merge_row_batch(logs)

  # Apply masking after merge, similar to HTTPBackgroundLogger
  if self.masking_function:
- for i in range(len(first)):
- item = first[i]
+ for i in range(len(batch)):
+ item = batch[i]
  masked_item = item.copy()

  # Only mask specific fields if they exist
@@ -880,9 +952,9 @@
  else:
  masked_item[field] = masked_value

- first[i] = masked_item
+ batch[i] = masked_item

- return first
+ return batch


  BACKGROUND_LOGGER_BASE_SLEEP_TIME_S = 1.0
@@ -898,6 +970,9 @@ class _HTTPBackgroundLogger:
  self.masking_function: Callable[[Any], Any] | None = None
  self.outfile = sys.stderr
  self.flush_lock = threading.RLock()
+ self._max_request_size_override: int | None = None
+ self._max_request_size_result: dict[str, Any] | None = None
+ self._max_request_size_lock = threading.Lock()

  try:
  self.sync_flush = bool(int(os.environ["BRAINTRUST_SYNC_FLUSH"]))
@@ -905,10 +980,9 @@
  self.sync_flush = False

  try:
- self.max_request_size = int(os.environ["BRAINTRUST_MAX_REQUEST_SIZE"])
+ self._max_request_size_override = int(os.environ["BRAINTRUST_MAX_REQUEST_SIZE"])
  except:
- # 6 MB for the AWS lambda gateway (from our own testing).
- self.max_request_size = 6 * 1024 * 1024
+ pass

  try:
  self.default_batch_size = int(os.environ["BRAINTRUST_DEFAULT_BATCH_SIZE"])
@@ -949,6 +1023,9 @@
  self.logger = logging.getLogger("braintrust")
  self.queue: "LogQueue[LazyValue[Dict[str, Any]]]" = LogQueue(maxsize=self.queue_maxsize)

+ # Counter for tracking overflow uploads (useful for testing)
+ self._overflow_upload_count = 0
+
  atexit.register(self._finalize)

  def enforce_queue_size_limit(self, enforce: bool) -> None:
@@ -1005,6 +1082,38 @@
  else:
  raise

+ def _get_max_request_size(self) -> dict[str, Any]:
+ if self._max_request_size_result is not None:
+ return self._max_request_size_result
+ with self._max_request_size_lock:
+ if self._max_request_size_result is not None:
+ return self._max_request_size_result
+ server_limit: int | None = None
+ try:
+ conn = self.api_conn.get()
+ info = conn.get_json("version")
+ limit = info.get("logs3_payload_max_bytes")
+ if isinstance(limit, (int, float)) and int(limit) > 0:
+ server_limit = int(limit)
+ except Exception as e:
+ print(f"Failed to fetch version info for payload limit: {e}", file=self.outfile)
+ valid_server_limit = server_limit if server_limit is not None and server_limit > 0 else None
+ can_use_overflow = valid_server_limit is not None
+ max_request_size = DEFAULT_MAX_REQUEST_SIZE
+ if self._max_request_size_override is not None:
+ max_request_size = (
+ min(self._max_request_size_override, valid_server_limit)
+ if valid_server_limit is not None
+ else self._max_request_size_override
+ )
+ elif valid_server_limit is not None:
+ max_request_size = valid_server_limit
+ self._max_request_size_result = {
+ "max_request_size": max_request_size,
+ "can_use_overflow": can_use_overflow,
+ }
+ return self._max_request_size_result
+
  def flush(self, batch_size: int | None = None):
  if batch_size is None:
  batch_size = self.default_batch_size
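The precedence that _get_max_request_size implements can be restated as a pure function (a sketch, not part of the diff): an explicit BRAINTRUST_MAX_REQUEST_SIZE override is capped by the server's logs3_payload_max_bytes when one is advertised, otherwise the server limit wins, otherwise the 6 MB default applies; overflow uploads are only possible when the server advertises a limit.

    def resolve_max_request_size(override: int | None, server_limit: int | None) -> tuple[int, bool]:
        DEFAULT = 6 * 1024 * 1024  # DEFAULT_MAX_REQUEST_SIZE
        valid_server_limit = server_limit if server_limit is not None and server_limit > 0 else None
        can_use_overflow = valid_server_limit is not None
        if override is not None:
            size = min(override, valid_server_limit) if valid_server_limit is not None else override
        elif valid_server_limit is not None:
            size = valid_server_limit
        else:
            size = DEFAULT
        return size, can_use_overflow

    assert resolve_max_request_size(None, None) == (6 * 1024 * 1024, False)
    assert resolve_max_request_size(10_000_000, 5_000_000) == (5_000_000, True)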
@@ -1019,30 +1128,35 @@
  if len(all_items) == 0:
  return

- # Construct batches of records to flush in parallel and in sequence.
- all_items_str = [[bt_dumps(item) for item in bucket] for bucket in all_items]
- batch_sets = batch_items(
- items=all_items_str, batch_max_num_items=batch_size, batch_max_num_bytes=self.max_request_size // 2
+ # Construct batches of records to flush in parallel.
+ all_items_with_meta = [stringify_with_overflow_meta(item) for item in all_items]
+ max_request_size_result = self._get_max_request_size()
+ batches = batch_items(
+ items=all_items_with_meta,
+ batch_max_num_items=batch_size,
+ batch_max_num_bytes=max_request_size_result["max_request_size"] // 2,
+ get_byte_size=lambda item: len(item.str_value),
  )
- for batch_set in batch_sets:
- post_promises = []
- try:
- post_promises = [
- HTTP_REQUEST_THREAD_POOL.submit(self._submit_logs_request, batch) for batch in batch_set
- ]
- except RuntimeError:
- # If the thread pool has shut down, e.g. because the process
- # is terminating, run the requests the old fashioned way.
- for batch in batch_set:
- self._submit_logs_request(batch)
-
- concurrent.futures.wait(post_promises)
- # Raise any exceptions from the promises as one group.
- post_promise_exceptions = [e for e in (f.exception() for f in post_promises) if e is not None]
- if post_promise_exceptions:
- raise exceptiongroup.BaseExceptionGroup(
- f"Encountered the following errors while logging:", post_promise_exceptions
- )
+
+ post_promises = []
+ try:
+ post_promises = [
+ HTTP_REQUEST_THREAD_POOL.submit(self._submit_logs_request, batch, max_request_size_result)
+ for batch in batches
+ ]
+ except RuntimeError:
+ # If the thread pool has shut down, e.g. because the process
+ # is terminating, run the requests the old fashioned way.
+ for batch in batches:
+ self._submit_logs_request(batch, max_request_size_result)
+
+ concurrent.futures.wait(post_promises)
+ # Raise any exceptions from the promises as one group.
+ post_promise_exceptions = [e for e in (f.exception() for f in post_promises) if e is not None]
+ if post_promise_exceptions:
+ raise exceptiongroup.BaseExceptionGroup(
+ f"Encountered the following errors while logging:", post_promise_exceptions
+ )

  attachment_errors: list[Exception] = []
  for attachment in attachments:
@@ -1063,42 +1177,40 @@

  def _unwrap_lazy_values(
  self, wrapped_items: Sequence[LazyValue[dict[str, Any]]]
- ) -> tuple[list[list[dict[str, Any]]], list["BaseAttachment"]]:
+ ) -> tuple[list[dict[str, Any]], list["BaseAttachment"]]:
  for i in range(self.num_tries):
  try:
  unwrapped_items = [item.get() for item in wrapped_items]
- batched_items = merge_row_batch(unwrapped_items)
+ merged_items = merge_row_batch(unwrapped_items)

  # Apply masking after merging but before sending to backend
  if self.masking_function:
- for batch_idx in range(len(batched_items)):
- for item_idx in range(len(batched_items[batch_idx])):
- item = batched_items[batch_idx][item_idx]
- masked_item = item.copy()
-
- # Only mask specific fields if they exist
- for field in REDACTION_FIELDS:
- if field in item:
- masked_value = _apply_masking_to_field(self.masking_function, item[field], field)
- if isinstance(masked_value, _MaskingError):
- # Drop the field and add error message
- if field in masked_item:
- del masked_item[field]
- if "error" in masked_item:
- masked_item["error"] = f"{masked_item['error']}; {masked_value.error_msg}"
- else:
- masked_item["error"] = masked_value.error_msg
+ for item_idx in range(len(merged_items)):
+ item = merged_items[item_idx]
+ masked_item = item.copy()
+
+ # Only mask specific fields if they exist
+ for field in REDACTION_FIELDS:
+ if field in item:
+ masked_value = _apply_masking_to_field(self.masking_function, item[field], field)
+ if isinstance(masked_value, _MaskingError):
+ # Drop the field and add error message
+ if field in masked_item:
+ del masked_item[field]
+ if "error" in masked_item:
+ masked_item["error"] = f"{masked_item['error']}; {masked_value.error_msg}"
  else:
- masked_item[field] = masked_value
+ masked_item["error"] = masked_value.error_msg
+ else:
+ masked_item[field] = masked_value

- batched_items[batch_idx][item_idx] = masked_item
+ merged_items[item_idx] = masked_item

  attachments: list["BaseAttachment"] = []
- for batch in batched_items:
- for item in batch:
- _extract_attachments(item, attachments)
+ for item in merged_items:
+ _extract_attachments(item, attachments)

- return batched_items, attachments
+ return merged_items, attachments
  except Exception as e:
  errmsg = "Encountered error when constructing records to flush"
  is_retrying = i + 1 < self.num_tries
@@ -1121,21 +1233,120 @@ class _HTTPBackgroundLogger:
  )
  return [], []

- def _submit_logs_request(self, items: Sequence[str]):
+ def _request_logs3_overflow_upload(
+ self, conn: HTTPConnection, payload_size_bytes: int, rows: list[dict[str, Any]]
+ ) -> dict[str, Any]:
+ try:
+ resp = conn.post(
+ "/logs3/overflow",
+ json={"content_type": "application/json", "size_bytes": payload_size_bytes, "rows": rows},
+ )
+ resp.raise_for_status()
+ payload = resp.json()
+ except Exception as e:
+ raise RuntimeError(f"Failed to request logs3 overflow upload URL: {e}") from e
+
+ method = payload.get("method")
+ if method not in ("PUT", "POST"):
+ raise RuntimeError(f"Invalid response from API server (method must be PUT or POST): {payload}")
+ signed_url = payload.get("signedUrl")
+ headers = payload.get("headers")
+ fields = payload.get("fields")
+ key = payload.get("key")
+ if not isinstance(signed_url, str) or not isinstance(key, str):
+ raise RuntimeError(f"Invalid response from API server: {payload}")
+ if method == "PUT" and not isinstance(headers, dict):
+ raise RuntimeError(f"Invalid response from API server: {payload}")
+ if method == "POST" and not isinstance(fields, dict):
+ raise RuntimeError(f"Invalid response from API server: {payload}")
+
+ if method == "PUT":
+ add_azure_blob_headers(headers, signed_url)
+
+ return {
+ "method": method,
+ "signed_url": signed_url,
+ "headers": headers if isinstance(headers, dict) else {},
+ "fields": fields if isinstance(fields, dict) else {},
+ "key": key,
+ }
+
+ def _upload_logs3_overflow_payload(self, upload: dict[str, Any], payload: str) -> None:
+ obj_conn = HTTPConnection(base_url="", adapter=_http_adapter)
+ method = upload["method"]
+ if method == "POST":
+ fields = upload.get("fields")
+ if not isinstance(fields, dict):
+ raise RuntimeError("Missing logs3 overflow upload fields")
+ content_type = fields.get("Content-Type", "application/json")
+ headers = {k: v for k, v in upload.get("headers", {}).items() if k.lower() != "content-type"}
+ obj_response = obj_conn.post(
+ upload["signed_url"],
+ headers=headers,
+ data=fields,
+ files={"file": ("logs3.json", payload.encode("utf-8"), content_type)},
+ )
+ else:
+ obj_response = obj_conn.put(
+ upload["signed_url"],
+ headers=upload["headers"],
+ data=payload.encode("utf-8"),
+ )
+ obj_response.raise_for_status()
+
+ def _submit_logs_request(self, items: Sequence[LogItemWithMeta], max_request_size_result: dict[str, Any]):
  conn = self.api_conn.get()
  dataStr = construct_logs3_data(items)
+ payload_bytes = utf8_byte_length(dataStr)
+ max_request_size = max_request_size_result["max_request_size"]
+ can_use_overflow = max_request_size_result["can_use_overflow"]
+ use_overflow = can_use_overflow and payload_bytes > max_request_size
  if self.all_publish_payloads_dir:
  _HTTPBackgroundLogger._write_payload_to_dir(payload_dir=self.all_publish_payloads_dir, payload=dataStr)
+ overflow_upload: dict[str, Any] | None = None
+ overflow_rows = (
+ [
+ {
+ "object_ids": item.overflow_meta.object_ids,
+ "has_comment": item.overflow_meta.has_comment,
+ "is_delete": item.overflow_meta.is_delete,
+ "input_row": {"byte_size": item.overflow_meta.byte_size},
+ }
+ for item in items
+ ]
+ if use_overflow
+ else None
+ )
  for i in range(self.num_tries):
  start_time = time.time()
- resp = conn.post("/logs3", data=dataStr.encode("utf-8"))
- if resp.ok:
+ resp = None
+ error = None
+ try:
+ if overflow_rows:
+ if overflow_upload is None:
+ current_upload = self._request_logs3_overflow_upload(conn, payload_bytes, overflow_rows)
+ self._upload_logs3_overflow_payload(current_upload, dataStr)
+ overflow_upload = current_upload
+ resp = conn.post(
+ "/logs3",
+ json=construct_logs3_overflow_request(overflow_upload["key"], payload_bytes),
+ )
+ else:
+ resp = conn.post("/logs3", data=dataStr.encode("utf-8"))
+ except Exception as e:
+ error = e
+ if error is None and resp is not None and resp.ok:
+ if overflow_rows:
+ self._overflow_upload_count += 1
  return
- resp_errmsg = f"{resp.status_code}: {resp.text}"
+ if error is None and resp is not None:
+ resp_errmsg = f"{resp.status_code}: {resp.text}"
+ else:
+ resp_errmsg = str(error)

  is_retrying = i + 1 < self.num_tries
  retrying_text = "" if is_retrying else " Retrying"
- errmsg = f"log request failed. Elapsed time: {time.time() - start_time} seconds. Payload size: {len(dataStr)}.{retrying_text}\nError: {resp_errmsg}"
+ errmsg = f"log request failed. Elapsed time: {time.time() - start_time} seconds. Payload size: {payload_bytes}.{retrying_text}\nError: {resp_errmsg}"

  if not is_retrying and self.failed_publish_payloads_dir:
  _HTTPBackgroundLogger._write_payload_to_dir(
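When the overflow path is taken, the full payload goes to the signed URL and /logs3 only receives the small reference built by construct_logs3_overflow_request. An illustrative call (not part of the diff; the key is a made-up placeholder standing in for whatever /logs3/overflow returns):

    construct_logs3_overflow_request("overflow/example-key.json", size_bytes=9_437_184)
    # => {"rows": {"type": "logs3_overflow", "key": "overflow/example-key.json", "size_bytes": 9437184},
    #     "api_version": 2}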
@@ -1160,14 +1371,15 @@
  return
  try:
  all_items, attachments = self._unwrap_lazy_values(wrapped_items)
- dataStr = construct_logs3_data([bt_dumps(item) for item in all_items])
+ items_with_meta = [stringify_with_overflow_meta(item) for item in all_items]
+ dataStr = construct_logs3_data(items_with_meta)
  attachment_str = bt_dumps([a.debug_info() for a in attachments])
  payload = "{" + f""""data": {dataStr}, "attachments": {attachment_str}""" + "}"
  for output_dir in publish_payloads_dir:
  if not output_dir:
  continue
  _HTTPBackgroundLogger._write_payload_to_dir(payload_dir=output_dir, payload=payload)
- except Exception as e:
+ except Exception:
  traceback.print_exc(file=self.outfile)

  def _register_dropped_item_count(self, num_items):
@@ -1634,7 +1846,8 @@ def init_logger(
  if set_current:
  if _state is None:
  raise RuntimeError("_state is None in init_logger. This should never happen.")
- _state.current_logger.set(ret)
+ _state._cv_logger.set(ret)
+ _state._local_logger = ret
  return ret


@@ -1955,7 +2168,7 @@ def current_experiment() -> Optional["Experiment"]:
  def current_logger() -> Optional["Logger"]:
  """Returns the currently-active logger (set by `braintrust.init_logger(...)`). Returns None if no current logger has been set."""

- return _state.current_logger.get()
+ return _state._cv_logger.get() or _state._local_logger


  def current_span() -> Span:
@@ -3260,17 +3473,17 @@ def _start_span_parent_args(
  if parent:
  assert parent_span_ids is None, "Cannot specify both parent and parent_span_ids"
  parent_components = SpanComponentsV4.from_str(parent)
- assert parent_object_type == parent_components.object_type, (
- f"Mismatch between expected span parent object type {parent_object_type} and provided type {parent_components.object_type}"
- )
+ assert (
+ parent_object_type == parent_components.object_type
+ ), f"Mismatch between expected span parent object type {parent_object_type} and provided type {parent_components.object_type}"

  parent_components_object_id_lambda = _span_components_to_object_id_lambda(parent_components)

  def compute_parent_object_id():
  parent_components_object_id = parent_components_object_id_lambda()
- assert parent_object_id.get() == parent_components_object_id, (
- f"Mismatch between expected span parent object id {parent_object_id.get()} and provided id {parent_components_object_id}"
- )
+ assert (
+ parent_object_id.get() == parent_components_object_id
+ ), f"Mismatch between expected span parent object id {parent_object_id.get()} and provided id {parent_components_object_id}"
  return parent_object_id.get()

  arg_parent_object_id = LazyValue(compute_parent_object_id, use_mutex=False)
@@ -3587,7 +3800,6 @@ class Experiment(ObjectFetcher[ExperimentEvent], Exportable):
  "experiment_id": self.id,
  "base_experiment_id": comparison_experiment_id,
  },
- retries=3,
  )
  except Exception as e:
  _logger.warning(
@@ -3984,6 +4196,9 @@ class SpanImpl(Span):
  use_v4 = os.getenv("BRAINTRUST_OTEL_COMPAT", "false").lower() == "true"
  span_components_class = SpanComponentsV4 if use_v4 else SpanComponentsV3

+ # Disable span cache since remote function spans won't be in the local cache
+ self.state.span_cache.disable()
+
  return span_components_class(
  object_type=self.parent_object_type,
  object_id=object_id,
@@ -3997,7 +4212,7 @@
  def link(self) -> str:
  parent_type, info = self._get_parent_info()
  if parent_type == SpanObjectTypeV3.PROJECT_LOGS:
- cur_logger = self.state.current_logger.get()
+ cur_logger = self.state._cv_logger.get() or self.state._local_logger
  if not cur_logger:
  return NOOP_SPAN_PERMALINK
  base_url = cur_logger._get_link_base_url()
@@ -4050,10 +4265,20 @@
  self._context_token = self.state.context_manager.set_current_span(self)

  def unset_current(self):
+ """
+ Unset current span context.
+
+ Note: self._context_token may be None if set_current() failed.
+ This is safe - context_manager.unset_current_span() handles None.
+ """
  if self.can_set_current:
- # Pass the stored token to context manager for cleanup
- self.state.context_manager.unset_current_span(self._context_token)
- self._context_token = None
+ try:
+ self.state.context_manager.unset_current_span(self._context_token)
+ except Exception as e:
+ logging.debug(f"Failed to unset current span: {e}")
+ finally:
+ # Always clear the token reference
+ self._context_token = None

  def __enter__(self) -> Span:
  self.set_current()
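Together with the __exit__ change in the next hunk, span teardown is now best-effort: a failure while restoring the context no longer prevents the span from being ended. A usage sketch (not part of the diff), assuming the usual Logger.start_span / Span.log API:

    import braintrust

    logger = braintrust.init_logger(project="my-project")
    try:
        with logger.start_span(name="step") as span:
            span.log(input={"question": "..."})
            raise ValueError("boom")
    except ValueError:
        pass
    # The exception is recorded on the span (__exit__ calls log_internal(error=...))
    # and the span is still ended, even if clearing the current-span context fails.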
@@ -4064,8 +4289,15 @@
  if exc_type is not None:
  self.log_internal(dict(error=stringify_exception(exc_type, exc_value, tb)))
  finally:
- self.unset_current()
- self.end()
+ try:
+ self.unset_current()
+ except Exception as e:
+ logging.debug(f"Failed to unset current in __exit__: {e}")
+
+ try:
+ self.end()
+ except Exception as e:
+ logging.warning(f"Error ending span: {e}")

  def _get_parent_info(self):
  if self.parent_object_type == SpanObjectTypeV3.PROJECT_LOGS:
@@ -4396,7 +4628,6 @@ class Dataset(ObjectFetcher[DatasetEvent]):
  args={
  "dataset_id": self.id,
  },
- retries=3,
  )
  data_summary = DataSummary(new_records=self.new_records, **data_summary_d)

@@ -4448,20 +4679,24 @@ def render_message(render: Callable[[str], str], message: PromptMessage):
  if c["type"] == "text":
  rendered_content.append({**c, "text": render(c["text"])})
  elif c["type"] == "image_url":
- rendered_content.append({
- **c,
- "image_url": {**c["image_url"], "url": render(c["image_url"]["url"])},
- })
+ rendered_content.append(
+ {
+ **c,
+ "image_url": {**c["image_url"], "url": render(c["image_url"]["url"])},
+ }
+ )
  elif c["type"] == "file":
- rendered_content.append({
- **c,
- "file": {
- **c["file"],
- "file_data": render(c["file"]["file_data"]),
- **({} if "file_id" not in c["file"] else {"file_id": render(c["file"]["file_id"])}),
- **({} if "filename" not in c["file"] else {"filename": render(c["file"]["filename"])}),
- },
- })
+ rendered_content.append(
+ {
+ **c,
+ "file": {
+ **c["file"],
+ "file_data": render(c["file"]["file_data"]),
+ **({} if "file_id" not in c["file"] else {"file_id": render(c["file"]["file_id"])}),
+ **({} if "filename" not in c["file"] else {"filename": render(c["file"]["filename"])}),
+ },
+ }
+ )
  else:
  raise ValueError(f"Unknown content type: {c['type']}")