braintrust 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

braintrust/logger.py CHANGED
@@ -50,6 +50,8 @@ from .db_fields import (
     AUDIT_METADATA_FIELD,
     AUDIT_SOURCE_FIELD,
     IS_MERGE_FIELD,
+    OBJECT_DELETE_FIELD,
+    OBJECT_ID_KEYS,
     TRANSACTION_ID_FIELD,
     VALID_SOURCES,
 )
@@ -98,6 +100,23 @@ from .xact_ids import prettify_xact
 
 Metadata = dict[str, Any]
 DATA_API_VERSION = 2
+LOGS3_OVERFLOW_REFERENCE_TYPE = "logs3_overflow"
+# 6 MB for the AWS lambda gateway (from our own testing).
+DEFAULT_MAX_REQUEST_SIZE = 6 * 1024 * 1024
+
+
+@dataclasses.dataclass
+class Logs3OverflowInputRow:
+    object_ids: dict[str, Any]
+    has_comment: bool
+    is_delete: bool
+    byte_size: int
+
+
+@dataclasses.dataclass
+class LogItemWithMeta:
+    str_value: str
+    overflow_meta: Logs3OverflowInputRow
 
 
 class DatasetRef(TypedDict, total=False):
@@ -488,23 +507,25 @@ class BraintrustState:
 
     def copy_state(self, other: "BraintrustState"):
         """Copy login information from another BraintrustState instance."""
-        self.__dict__.update({
-            k: v
-            for (k, v) in other.__dict__.items()
-            if k
-            not in (
-                "current_experiment",
-                "_cv_logger",
-                "_local_logger",
-                "current_parent",
-                "current_span",
-                "_global_bg_logger",
-                "_override_bg_logger",
-                "_context_manager",
-                "_last_otel_setting",
-                "_context_manager_lock",
-            )
-        })
+        self.__dict__.update(
+            {
+                k: v
+                for (k, v) in other.__dict__.items()
+                if k
+                not in (
+                    "current_experiment",
+                    "_cv_logger",
+                    "_local_logger",
+                    "current_parent",
+                    "current_span",
+                    "_global_bg_logger",
+                    "_override_bg_logger",
+                    "_context_manager",
+                    "_last_otel_setting",
+                    "_context_manager_lock",
+                )
+            }
+        )
 
     def login(
         self,
@@ -741,20 +762,10 @@ class HTTPConnection:
     def delete(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:
         return self.session.delete(_urljoin(self.base_url, path), *args, **kwargs)
 
-    def get_json(self, object_type: str, args: Mapping[str, Any] | None = None, retries: int = 0) -> Mapping[str, Any]:
-        # FIXME[matt]: the retry logic seems to be unused and could be n*2 because of the the retry logic
-        # in the RetryRequestExceptionsAdapter. We should probably remove this.
-        tries = retries + 1
-        for i in range(tries):
-            resp = self.get(f"/{object_type}", params=args)
-            if i < tries - 1 and not resp.ok:
-                _logger.warning(f"Retrying API request {object_type} {args} {resp.status_code} {resp.text}")
-                continue
-            response_raise_for_status(resp)
-
-            return resp.json()
-        # Needed for type checking.
-        raise Exception("unreachable")
+    def get_json(self, object_type: str, args: Mapping[str, Any] | None = None) -> Mapping[str, Any]:
+        resp = self.get(f"/{object_type}", params=args)
+        response_raise_for_status(resp)
+        return resp.json()
 
     def post_json(self, object_type: str, args: Mapping[str, Any] | None = None) -> Any:
         resp = self.post(f"/{object_type.lstrip('/')}", json=args)
@@ -792,11 +803,43 @@ def construct_json_array(items: Sequence[str]):
     return "[" + ",".join(items) + "]"
 
 
-def construct_logs3_data(items: Sequence[str]):
-    rowsS = construct_json_array(items)
+def construct_logs3_data(items: Sequence[LogItemWithMeta]):
+    rowsS = construct_json_array([item.str_value for item in items])
     return '{"rows": ' + rowsS + ', "api_version": ' + str(DATA_API_VERSION) + "}"
 
 
+def construct_logs3_overflow_request(key: str, size_bytes: int | None = None) -> dict[str, Any]:
+    rows: dict[str, Any] = {"type": LOGS3_OVERFLOW_REFERENCE_TYPE, "key": key}
+    if size_bytes is not None:
+        rows["size_bytes"] = size_bytes
+    return {"rows": rows, "api_version": DATA_API_VERSION}
+
+
+def pick_logs3_overflow_object_ids(row: Mapping[str, Any]) -> dict[str, Any]:
+    object_ids: dict[str, Any] = {}
+    for key in OBJECT_ID_KEYS:
+        if key in row:
+            object_ids[key] = row[key]
+    return object_ids
+
+
+def stringify_with_overflow_meta(item: dict[str, Any]) -> LogItemWithMeta:
+    str_value = bt_dumps(item)
+    return LogItemWithMeta(
+        str_value=str_value,
+        overflow_meta=Logs3OverflowInputRow(
+            object_ids=pick_logs3_overflow_object_ids(item),
+            has_comment="comment" in item,
+            is_delete=item.get(OBJECT_DELETE_FIELD) is True,
+            byte_size=utf8_byte_length(str_value),
+        ),
+    )
+
+
+def utf8_byte_length(value: str) -> int:
+    return len(value.encode("utf-8"))
+
+
 class _MaskingError:
     """Internal class to signal masking errors that need special handling."""
 
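Note: the helpers added above pair each serialized row with the metadata the overflow path needs. A rough, self-contained sketch of the resulting shapes, using stand-in values (the real code serializes with bt_dumps and takes OBJECT_ID_KEYS / OBJECT_DELETE_FIELD from .db_fields):

import dataclasses
import json
from typing import Any

# Stand-ins for the module's constants; the real values come from .db_fields.
OBJECT_ID_KEYS = ("project_id", "experiment_id", "dataset_id")  # illustrative subset
OBJECT_DELETE_FIELD = "_object_delete"  # assumed field name for this sketch

@dataclasses.dataclass
class Logs3OverflowInputRow:
    object_ids: dict[str, Any]
    has_comment: bool
    is_delete: bool
    byte_size: int

row = {"project_id": "p-123", "input": "x" * 10, "comment": "lgtm"}
str_value = json.dumps(row)  # the real code uses bt_dumps
meta = Logs3OverflowInputRow(
    object_ids={k: row[k] for k in OBJECT_ID_KEYS if k in row},
    has_comment="comment" in row,
    is_delete=row.get(OBJECT_DELETE_FIELD) is True,
    byte_size=len(str_value.encode("utf-8")),
)

# When a batch is diverted to object storage, /logs3 receives only a reference body:
reference_body = {"rows": {"type": "logs3_overflow", "key": "example-overflow-key"}, "api_version": 2}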
@@ -886,15 +929,12 @@ class _MemoryBackgroundLogger(_BackgroundLogger):
 
         # all the logs get merged before gettig sent to the server, so simulate that
         # here
-        merged = merge_row_batch(logs)
-        first = merged[0]
-        for other in merged[1:]:
-            first.extend(other)
+        batch = merge_row_batch(logs)
 
         # Apply masking after merge, similar to HTTPBackgroundLogger
         if self.masking_function:
-            for i in range(len(first)):
-                item = first[i]
+            for i in range(len(batch)):
+                item = batch[i]
                 masked_item = item.copy()
 
                 # Only mask specific fields if they exist
@@ -912,9 +952,9 @@ class _MemoryBackgroundLogger(_BackgroundLogger):
                         else:
                             masked_item[field] = masked_value
 
-                first[i] = masked_item
+                batch[i] = masked_item
 
-        return first
+        return batch
 
 
 BACKGROUND_LOGGER_BASE_SLEEP_TIME_S = 1.0
@@ -930,6 +970,9 @@ class _HTTPBackgroundLogger:
         self.masking_function: Callable[[Any], Any] | None = None
         self.outfile = sys.stderr
         self.flush_lock = threading.RLock()
+        self._max_request_size_override: int | None = None
+        self._max_request_size_result: dict[str, Any] | None = None
+        self._max_request_size_lock = threading.Lock()
 
         try:
             self.sync_flush = bool(int(os.environ["BRAINTRUST_SYNC_FLUSH"]))
@@ -937,10 +980,9 @@ class _HTTPBackgroundLogger:
             self.sync_flush = False
 
         try:
-            self.max_request_size = int(os.environ["BRAINTRUST_MAX_REQUEST_SIZE"])
+            self._max_request_size_override = int(os.environ["BRAINTRUST_MAX_REQUEST_SIZE"])
         except:
-            # 6 MB for the AWS lambda gateway (from our own testing).
-            self.max_request_size = 6 * 1024 * 1024
+            pass
 
         try:
             self.default_batch_size = int(os.environ["BRAINTRUST_DEFAULT_BATCH_SIZE"])
@@ -981,6 +1023,9 @@ class _HTTPBackgroundLogger:
         self.logger = logging.getLogger("braintrust")
         self.queue: "LogQueue[LazyValue[Dict[str, Any]]]" = LogQueue(maxsize=self.queue_maxsize)
 
+        # Counter for tracking overflow uploads (useful for testing)
+        self._overflow_upload_count = 0
+
         atexit.register(self._finalize)
 
     def enforce_queue_size_limit(self, enforce: bool) -> None:
@@ -1037,6 +1082,38 @@ class _HTTPBackgroundLogger:
             else:
                 raise
 
+    def _get_max_request_size(self) -> dict[str, Any]:
+        if self._max_request_size_result is not None:
+            return self._max_request_size_result
+        with self._max_request_size_lock:
+            if self._max_request_size_result is not None:
+                return self._max_request_size_result
+            server_limit: int | None = None
+            try:
+                conn = self.api_conn.get()
+                info = conn.get_json("version")
+                limit = info.get("logs3_payload_max_bytes")
+                if isinstance(limit, (int, float)) and int(limit) > 0:
+                    server_limit = int(limit)
+            except Exception as e:
+                print(f"Failed to fetch version info for payload limit: {e}", file=self.outfile)
+            valid_server_limit = server_limit if server_limit is not None and server_limit > 0 else None
+            can_use_overflow = valid_server_limit is not None
+            max_request_size = DEFAULT_MAX_REQUEST_SIZE
+            if self._max_request_size_override is not None:
+                max_request_size = (
+                    min(self._max_request_size_override, valid_server_limit)
+                    if valid_server_limit is not None
+                    else self._max_request_size_override
+                )
+            elif valid_server_limit is not None:
+                max_request_size = valid_server_limit
+            self._max_request_size_result = {
+                "max_request_size": max_request_size,
+                "can_use_overflow": can_use_overflow,
+            }
+            return self._max_request_size_result
+
     def flush(self, batch_size: int | None = None):
         if batch_size is None:
             batch_size = self.default_batch_size
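Note: _get_max_request_size resolves the effective payload limit once and caches it. The BRAINTRUST_MAX_REQUEST_SIZE override wins but is clamped to the server-advertised limit when one exists, the server limit is used otherwise, and the 6 MB default applies when neither is available; the overflow path is only enabled when the server advertises a limit. A small sketch of that rule (the real method also caches the result and reads the limit from the /version endpoint):

DEFAULT_MAX_REQUEST_SIZE = 6 * 1024 * 1024  # same 6 MB default as above

def resolve_max_request_size(env_override: int | None, server_limit: int | None) -> tuple[int, bool]:
    """Return (effective byte limit, whether the overflow path may be used)."""
    valid_server_limit = server_limit if server_limit is not None and server_limit > 0 else None
    can_use_overflow = valid_server_limit is not None
    if env_override is not None:
        size = min(env_override, valid_server_limit) if valid_server_limit is not None else env_override
    elif valid_server_limit is not None:
        size = valid_server_limit
    else:
        size = DEFAULT_MAX_REQUEST_SIZE
    return size, can_use_overflow

assert resolve_max_request_size(None, None) == (DEFAULT_MAX_REQUEST_SIZE, False)
assert resolve_max_request_size(1_000_000, 8_000_000) == (1_000_000, True)
assert resolve_max_request_size(10_000_000, 8_000_000) == (8_000_000, True)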
@@ -1051,30 +1128,35 @@ class _HTTPBackgroundLogger:
         if len(all_items) == 0:
             return
 
-        # Construct batches of records to flush in parallel and in sequence.
-        all_items_str = [[bt_dumps(item) for item in bucket] for bucket in all_items]
-        batch_sets = batch_items(
-            items=all_items_str, batch_max_num_items=batch_size, batch_max_num_bytes=self.max_request_size // 2
+        # Construct batches of records to flush in parallel.
+        all_items_with_meta = [stringify_with_overflow_meta(item) for item in all_items]
+        max_request_size_result = self._get_max_request_size()
+        batches = batch_items(
+            items=all_items_with_meta,
+            batch_max_num_items=batch_size,
+            batch_max_num_bytes=max_request_size_result["max_request_size"] // 2,
+            get_byte_size=lambda item: len(item.str_value),
         )
-        for batch_set in batch_sets:
-            post_promises = []
-            try:
-                post_promises = [
-                    HTTP_REQUEST_THREAD_POOL.submit(self._submit_logs_request, batch) for batch in batch_set
-                ]
-            except RuntimeError:
-                # If the thread pool has shut down, e.g. because the process
-                # is terminating, run the requests the old fashioned way.
-                for batch in batch_set:
-                    self._submit_logs_request(batch)
-
-            concurrent.futures.wait(post_promises)
-            # Raise any exceptions from the promises as one group.
-            post_promise_exceptions = [e for e in (f.exception() for f in post_promises) if e is not None]
-            if post_promise_exceptions:
-                raise exceptiongroup.BaseExceptionGroup(
-                    f"Encountered the following errors while logging:", post_promise_exceptions
-                )
+
+        post_promises = []
+        try:
+            post_promises = [
+                HTTP_REQUEST_THREAD_POOL.submit(self._submit_logs_request, batch, max_request_size_result)
+                for batch in batches
+            ]
+        except RuntimeError:
+            # If the thread pool has shut down, e.g. because the process
+            # is terminating, run the requests the old fashioned way.
+            for batch in batches:
+                self._submit_logs_request(batch, max_request_size_result)
+
+        concurrent.futures.wait(post_promises)
+        # Raise any exceptions from the promises as one group.
+        post_promise_exceptions = [e for e in (f.exception() for f in post_promises) if e is not None]
+        if post_promise_exceptions:
+            raise exceptiongroup.BaseExceptionGroup(
+                f"Encountered the following errors while logging:", post_promise_exceptions
+            )
 
         attachment_errors: list[Exception] = []
         for attachment in attachments:
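Note: flush() now batches the pre-serialized items by byte size (capped at half the effective request limit) before fanning out to the thread pool. A toy illustration of byte-aware batching with a size callback; this is a sketch of the behavior flush() relies on, not the package's batch_items:

from typing import Callable, Sequence, TypeVar

T = TypeVar("T")

def batch_by_size(
    items: Sequence[T],
    max_items: int,
    max_bytes: int,
    get_byte_size: Callable[[T], int],
) -> list[list[T]]:
    # Start a new batch whenever the next item would exceed either limit.
    batches: list[list[T]] = []
    current: list[T] = []
    current_bytes = 0
    for item in items:
        size = get_byte_size(item)
        if current and (len(current) >= max_items or current_bytes + size > max_bytes):
            batches.append(current)
            current, current_bytes = [], 0
        current.append(item)
        current_bytes += size
    if current:
        batches.append(current)
    return batches

print(batch_by_size(["a" * 5, "b" * 5, "c" * 5], max_items=10, max_bytes=8, get_byte_size=len))
# -> [['aaaaa'], ['bbbbb'], ['ccccc']]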
@@ -1095,42 +1177,40 @@ class _HTTPBackgroundLogger:
 
     def _unwrap_lazy_values(
         self, wrapped_items: Sequence[LazyValue[dict[str, Any]]]
-    ) -> tuple[list[list[dict[str, Any]]], list["BaseAttachment"]]:
+    ) -> tuple[list[dict[str, Any]], list["BaseAttachment"]]:
         for i in range(self.num_tries):
             try:
                 unwrapped_items = [item.get() for item in wrapped_items]
-                batched_items = merge_row_batch(unwrapped_items)
+                merged_items = merge_row_batch(unwrapped_items)
 
                 # Apply masking after merging but before sending to backend
                 if self.masking_function:
-                    for batch_idx in range(len(batched_items)):
-                        for item_idx in range(len(batched_items[batch_idx])):
-                            item = batched_items[batch_idx][item_idx]
-                            masked_item = item.copy()
-
-                            # Only mask specific fields if they exist
-                            for field in REDACTION_FIELDS:
-                                if field in item:
-                                    masked_value = _apply_masking_to_field(self.masking_function, item[field], field)
-                                    if isinstance(masked_value, _MaskingError):
-                                        # Drop the field and add error message
-                                        if field in masked_item:
-                                            del masked_item[field]
-                                        if "error" in masked_item:
-                                            masked_item["error"] = f"{masked_item['error']}; {masked_value.error_msg}"
-                                        else:
-                                            masked_item["error"] = masked_value.error_msg
+                    for item_idx in range(len(merged_items)):
+                        item = merged_items[item_idx]
+                        masked_item = item.copy()
+
+                        # Only mask specific fields if they exist
+                        for field in REDACTION_FIELDS:
+                            if field in item:
+                                masked_value = _apply_masking_to_field(self.masking_function, item[field], field)
+                                if isinstance(masked_value, _MaskingError):
+                                    # Drop the field and add error message
+                                    if field in masked_item:
+                                        del masked_item[field]
+                                    if "error" in masked_item:
+                                        masked_item["error"] = f"{masked_item['error']}; {masked_value.error_msg}"
                                     else:
-                                        masked_item[field] = masked_value
+                                        masked_item["error"] = masked_value.error_msg
+                                else:
+                                    masked_item[field] = masked_value
 
-                            batched_items[batch_idx][item_idx] = masked_item
+                        merged_items[item_idx] = masked_item
 
                 attachments: list["BaseAttachment"] = []
-                for batch in batched_items:
-                    for item in batch:
-                        _extract_attachments(item, attachments)
+                for item in merged_items:
+                    _extract_attachments(item, attachments)
 
-                return batched_items, attachments
+                return merged_items, attachments
             except Exception as e:
                 errmsg = "Encountered error when constructing records to flush"
                 is_retrying = i + 1 < self.num_tries
@@ -1153,21 +1233,120 @@ class _HTTPBackgroundLogger:
             )
         return [], []
 
-    def _submit_logs_request(self, items: Sequence[str]):
+    def _request_logs3_overflow_upload(
+        self, conn: HTTPConnection, payload_size_bytes: int, rows: list[dict[str, Any]]
+    ) -> dict[str, Any]:
+        try:
+            resp = conn.post(
+                "/logs3/overflow",
+                json={"content_type": "application/json", "size_bytes": payload_size_bytes, "rows": rows},
+            )
+            resp.raise_for_status()
+            payload = resp.json()
+        except Exception as e:
+            raise RuntimeError(f"Failed to request logs3 overflow upload URL: {e}") from e
+
+        method = payload.get("method")
+        if method not in ("PUT", "POST"):
+            raise RuntimeError(f"Invalid response from API server (method must be PUT or POST): {payload}")
+        signed_url = payload.get("signedUrl")
+        headers = payload.get("headers")
+        fields = payload.get("fields")
+        key = payload.get("key")
+        if not isinstance(signed_url, str) or not isinstance(key, str):
+            raise RuntimeError(f"Invalid response from API server: {payload}")
+        if method == "PUT" and not isinstance(headers, dict):
+            raise RuntimeError(f"Invalid response from API server: {payload}")
+        if method == "POST" and not isinstance(fields, dict):
+            raise RuntimeError(f"Invalid response from API server: {payload}")
+
+        if method == "PUT":
+            add_azure_blob_headers(headers, signed_url)
+
+        return {
+            "method": method,
+            "signed_url": signed_url,
+            "headers": headers if isinstance(headers, dict) else {},
+            "fields": fields if isinstance(fields, dict) else {},
+            "key": key,
+        }
+
+    def _upload_logs3_overflow_payload(self, upload: dict[str, Any], payload: str) -> None:
+        obj_conn = HTTPConnection(base_url="", adapter=_http_adapter)
+        method = upload["method"]
+        if method == "POST":
+            fields = upload.get("fields")
+            if not isinstance(fields, dict):
+                raise RuntimeError("Missing logs3 overflow upload fields")
+            content_type = fields.get("Content-Type", "application/json")
+            headers = {k: v for k, v in upload.get("headers", {}).items() if k.lower() != "content-type"}
+            obj_response = obj_conn.post(
+                upload["signed_url"],
+                headers=headers,
+                data=fields,
+                files={"file": ("logs3.json", payload.encode("utf-8"), content_type)},
+            )
+        else:
+            obj_response = obj_conn.put(
+                upload["signed_url"],
+                headers=upload["headers"],
+                data=payload.encode("utf-8"),
+            )
+        obj_response.raise_for_status()
+
+    def _submit_logs_request(self, items: Sequence[LogItemWithMeta], max_request_size_result: dict[str, Any]):
         conn = self.api_conn.get()
         dataStr = construct_logs3_data(items)
+        payload_bytes = utf8_byte_length(dataStr)
+        max_request_size = max_request_size_result["max_request_size"]
+        can_use_overflow = max_request_size_result["can_use_overflow"]
+        use_overflow = can_use_overflow and payload_bytes > max_request_size
         if self.all_publish_payloads_dir:
             _HTTPBackgroundLogger._write_payload_to_dir(payload_dir=self.all_publish_payloads_dir, payload=dataStr)
+        overflow_upload: dict[str, Any] | None = None
+        overflow_rows = (
+            [
+                {
+                    "object_ids": item.overflow_meta.object_ids,
+                    "has_comment": item.overflow_meta.has_comment,
+                    "is_delete": item.overflow_meta.is_delete,
+                    "input_row": {"byte_size": item.overflow_meta.byte_size},
+                }
+                for item in items
+            ]
+            if use_overflow
+            else None
+        )
         for i in range(self.num_tries):
             start_time = time.time()
-            resp = conn.post("/logs3", data=dataStr.encode("utf-8"))
-            if resp.ok:
+            resp = None
+            error = None
+            try:
+                if overflow_rows:
+                    if overflow_upload is None:
+                        current_upload = self._request_logs3_overflow_upload(conn, payload_bytes, overflow_rows)
+                        self._upload_logs3_overflow_payload(current_upload, dataStr)
+                        overflow_upload = current_upload
+                    resp = conn.post(
+                        "/logs3",
+                        json=construct_logs3_overflow_request(overflow_upload["key"], payload_bytes),
+                    )
+                else:
+                    resp = conn.post("/logs3", data=dataStr.encode("utf-8"))
+            except Exception as e:
+                error = e
+            if error is None and resp is not None and resp.ok:
+                if overflow_rows:
+                    self._overflow_upload_count += 1
                 return
-            resp_errmsg = f"{resp.status_code}: {resp.text}"
+            if error is None and resp is not None:
+                resp_errmsg = f"{resp.status_code}: {resp.text}"
+            else:
+                resp_errmsg = str(error)
 
             is_retrying = i + 1 < self.num_tries
             retrying_text = "" if is_retrying else " Retrying"
-            errmsg = f"log request failed. Elapsed time: {time.time() - start_time} seconds. Payload size: {len(dataStr)}.{retrying_text}\nError: {resp_errmsg}"
+            errmsg = f"log request failed. Elapsed time: {time.time() - start_time} seconds. Payload size: {payload_bytes}.{retrying_text}\nError: {resp_errmsg}"
 
             if not is_retrying and self.failed_publish_payloads_dir:
                 _HTTPBackgroundLogger._write_payload_to_dir(
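Note: _submit_logs_request now has two paths: payloads that fit the limit are POSTed inline to /logs3 as before, while oversized payloads are first uploaded to a server-issued signed URL and then referenced by key in a small /logs3 request. A rough sketch of the control flow, with stand-in callables for the three HTTP steps (conn.post("/logs3", ...), _request_logs3_overflow_upload, and _upload_logs3_overflow_payload):

from typing import Any, Callable

def submit_batch(
    data_str: str,
    max_request_size: int,
    can_use_overflow: bool,
    post_logs3: Callable[[Any], Any],
    request_overflow_upload: Callable[[int], dict[str, Any]],
    upload_to_signed_url: Callable[[dict[str, Any], str], None],
) -> None:
    payload_bytes = len(data_str.encode("utf-8"))
    if can_use_overflow and payload_bytes > max_request_size:
        upload = request_overflow_upload(payload_bytes)  # POST /logs3/overflow -> signed URL + key
        upload_to_signed_url(upload, data_str)  # PUT/POST the full payload to object storage
        post_logs3(
            {
                "rows": {"type": "logs3_overflow", "key": upload["key"], "size_bytes": payload_bytes},
                "api_version": 2,
            }
        )  # small reference body sent to /logs3
    else:
        post_logs3(data_str.encode("utf-8"))  # payload fits: send inline as before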
@@ -1192,14 +1371,15 @@ class _HTTPBackgroundLogger:
             return
         try:
             all_items, attachments = self._unwrap_lazy_values(wrapped_items)
-            dataStr = construct_logs3_data([bt_dumps(item) for item in all_items])
+            items_with_meta = [stringify_with_overflow_meta(item) for item in all_items]
+            dataStr = construct_logs3_data(items_with_meta)
             attachment_str = bt_dumps([a.debug_info() for a in attachments])
             payload = "{" + f""""data": {dataStr}, "attachments": {attachment_str}""" + "}"
             for output_dir in publish_payloads_dir:
                 if not output_dir:
                     continue
                 _HTTPBackgroundLogger._write_payload_to_dir(payload_dir=output_dir, payload=payload)
-        except Exception as e:
+        except Exception:
             traceback.print_exc(file=self.outfile)
 
     def _register_dropped_item_count(self, num_items):
@@ -3293,17 +3473,17 @@ def _start_span_parent_args(
     if parent:
         assert parent_span_ids is None, "Cannot specify both parent and parent_span_ids"
         parent_components = SpanComponentsV4.from_str(parent)
-        assert parent_object_type == parent_components.object_type, (
-            f"Mismatch between expected span parent object type {parent_object_type} and provided type {parent_components.object_type}"
-        )
+        assert (
+            parent_object_type == parent_components.object_type
+        ), f"Mismatch between expected span parent object type {parent_object_type} and provided type {parent_components.object_type}"
 
         parent_components_object_id_lambda = _span_components_to_object_id_lambda(parent_components)
 
         def compute_parent_object_id():
             parent_components_object_id = parent_components_object_id_lambda()
-            assert parent_object_id.get() == parent_components_object_id, (
-                f"Mismatch between expected span parent object id {parent_object_id.get()} and provided id {parent_components_object_id}"
-            )
+            assert (
+                parent_object_id.get() == parent_components_object_id
+            ), f"Mismatch between expected span parent object id {parent_object_id.get()} and provided id {parent_components_object_id}"
             return parent_object_id.get()
 
         arg_parent_object_id = LazyValue(compute_parent_object_id, use_mutex=False)
@@ -3620,7 +3800,6 @@ class Experiment(ObjectFetcher[ExperimentEvent], Exportable):
                     "experiment_id": self.id,
                     "base_experiment_id": comparison_experiment_id,
                 },
-                retries=3,
             )
         except Exception as e:
             _logger.warning(
@@ -4086,10 +4265,20 @@ class SpanImpl(Span):
             self._context_token = self.state.context_manager.set_current_span(self)
 
     def unset_current(self):
+        """
+        Unset current span context.
+
+        Note: self._context_token may be None if set_current() failed.
+        This is safe - context_manager.unset_current_span() handles None.
+        """
         if self.can_set_current:
-            # Pass the stored token to context manager for cleanup
-            self.state.context_manager.unset_current_span(self._context_token)
-            self._context_token = None
+            try:
+                self.state.context_manager.unset_current_span(self._context_token)
+            except Exception as e:
+                logging.debug(f"Failed to unset current span: {e}")
+            finally:
+                # Always clear the token reference
+                self._context_token = None
 
     def __enter__(self) -> Span:
         self.set_current()
@@ -4100,8 +4289,15 @@ class SpanImpl(Span):
             if exc_type is not None:
                 self.log_internal(dict(error=stringify_exception(exc_type, exc_value, tb)))
         finally:
-            self.unset_current()
-            self.end()
+            try:
+                self.unset_current()
+            except Exception as e:
+                logging.debug(f"Failed to unset current in __exit__: {e}")
+
+            try:
+                self.end()
+            except Exception as e:
+                logging.warning(f"Error ending span: {e}")
 
     def _get_parent_info(self):
         if self.parent_object_type == SpanObjectTypeV3.PROJECT_LOGS:
@@ -4432,7 +4628,6 @@ class Dataset(ObjectFetcher[DatasetEvent]):
             args={
                 "dataset_id": self.id,
             },
-            retries=3,
         )
         data_summary = DataSummary(new_records=self.new_records, **data_summary_d)
 
@@ -4484,20 +4679,24 @@ def render_message(render: Callable[[str], str], message: PromptMessage):
         if c["type"] == "text":
             rendered_content.append({**c, "text": render(c["text"])})
         elif c["type"] == "image_url":
-            rendered_content.append({
-                **c,
-                "image_url": {**c["image_url"], "url": render(c["image_url"]["url"])},
-            })
+            rendered_content.append(
+                {
+                    **c,
+                    "image_url": {**c["image_url"], "url": render(c["image_url"]["url"])},
+                }
+            )
         elif c["type"] == "file":
-            rendered_content.append({
-                **c,
-                "file": {
-                    **c["file"],
-                    "file_data": render(c["file"]["file_data"]),
-                    **({} if "file_id" not in c["file"] else {"file_id": render(c["file"]["file_id"])}),
-                    **({} if "filename" not in c["file"] else {"filename": render(c["file"]["filename"])}),
-                },
-            })
+            rendered_content.append(
+                {
+                    **c,
+                    "file": {
+                        **c["file"],
+                        "file_data": render(c["file"]["file_data"]),
+                        **({} if "file_id" not in c["file"] else {"file_id": render(c["file"]["file_id"])}),
+                        **({} if "filename" not in c["file"] else {"filename": render(c["file"]["filename"])}),
+                    },
+                }
+            )
         else:
             raise ValueError(f"Unknown content type: {c['type']}")