braintrust 0.5.0__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braintrust/__init__.py +14 -0
- braintrust/_generated_types.py +56 -3
- braintrust/auto.py +179 -0
- braintrust/conftest.py +23 -4
- braintrust/db_fields.py +10 -0
- braintrust/framework.py +18 -5
- braintrust/generated_types.py +3 -1
- braintrust/logger.py +369 -134
- braintrust/merge_row_batch.py +49 -109
- braintrust/oai.py +51 -0
- braintrust/test_bt_json.py +0 -5
- braintrust/test_context.py +1264 -0
- braintrust/test_framework.py +37 -0
- braintrust/test_http.py +444 -0
- braintrust/test_logger.py +179 -5
- braintrust/test_merge_row_batch.py +160 -0
- braintrust/test_util.py +58 -1
- braintrust/util.py +20 -0
- braintrust/version.py +2 -2
- braintrust/wrappers/agno/__init__.py +2 -3
- braintrust/wrappers/anthropic.py +64 -0
- braintrust/wrappers/claude_agent_sdk/__init__.py +2 -3
- braintrust/wrappers/claude_agent_sdk/test_wrapper.py +9 -0
- braintrust/wrappers/dspy.py +52 -1
- braintrust/wrappers/google_genai/__init__.py +9 -6
- braintrust/wrappers/litellm.py +6 -43
- braintrust/wrappers/pydantic_ai.py +2 -3
- braintrust/wrappers/test_agno.py +9 -0
- braintrust/wrappers/test_anthropic.py +156 -0
- braintrust/wrappers/test_dspy.py +117 -0
- braintrust/wrappers/test_google_genai.py +9 -0
- braintrust/wrappers/test_litellm.py +57 -55
- braintrust/wrappers/test_openai.py +253 -1
- braintrust/wrappers/test_pydantic_ai_integration.py +9 -0
- braintrust/wrappers/test_utils.py +79 -0
- braintrust/wrappers/threads.py +114 -0
- {braintrust-0.5.0.dist-info → braintrust-0.5.3.dist-info}/METADATA +1 -1
- {braintrust-0.5.0.dist-info → braintrust-0.5.3.dist-info}/RECORD +41 -37
- {braintrust-0.5.0.dist-info → braintrust-0.5.3.dist-info}/WHEEL +1 -1
- braintrust/graph_util.py +0 -147
- {braintrust-0.5.0.dist-info → braintrust-0.5.3.dist-info}/entry_points.txt +0 -0
- {braintrust-0.5.0.dist-info → braintrust-0.5.3.dist-info}/top_level.txt +0 -0
braintrust/logger.py
CHANGED
(Note: removed lines whose content was collapsed in the source diff view appear below truncated or as bare "-" markers.)

@@ -50,6 +50,8 @@ from .db_fields import (
     AUDIT_METADATA_FIELD,
     AUDIT_SOURCE_FIELD,
     IS_MERGE_FIELD,
+    OBJECT_DELETE_FIELD,
+    OBJECT_ID_KEYS,
     TRANSACTION_ID_FIELD,
     VALID_SOURCES,
 )
@@ -87,6 +89,7 @@ from .util import (
     get_caller_location,
     mask_api_key,
     merge_dicts,
+    parse_env_var_float,
     response_raise_for_status,
 )

@@ -97,6 +100,23 @@ from .xact_ids import prettify_xact

 Metadata = dict[str, Any]
 DATA_API_VERSION = 2
+LOGS3_OVERFLOW_REFERENCE_TYPE = "logs3_overflow"
+# 6 MB for the AWS lambda gateway (from our own testing).
+DEFAULT_MAX_REQUEST_SIZE = 6 * 1024 * 1024
+
+
+@dataclasses.dataclass
+class Logs3OverflowInputRow:
+    object_ids: dict[str, Any]
+    has_comment: bool
+    is_delete: bool
+    byte_size: int
+
+
+@dataclasses.dataclass
+class LogItemWithMeta:
+    str_value: str
+    overflow_meta: Logs3OverflowInputRow


 class DatasetRef(TypedDict, total=False):
@@ -349,9 +369,16 @@ class BraintrustState:
     def __init__(self):
         self.id = str(uuid.uuid4())
         self.current_experiment: Experiment | None = None
-
+        # We use both a ContextVar and a plain attribute for the current logger:
+        # - _cv_logger (ContextVar): Provides async context isolation so different
+        #   async tasks can have different loggers without affecting each other.
+        # - _local_logger (plain attribute): Fallback for threads, since ContextVars
+        #   don't propagate to new threads. This way if users don't want to do
+        #   anything specific they'll always have a "global logger"
+        self._cv_logger: contextvars.ContextVar[Logger | None] = contextvars.ContextVar(
             "braintrust_current_logger", default=None
         )
+        self._local_logger: Logger | None = None
         self.current_parent: contextvars.ContextVar[str | None] = contextvars.ContextVar(
             "braintrust_current_parent", default=None
         )
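Not part of the diff: a minimal sketch of how the two-level lookup introduced above resolves the current logger (the same `_cv_logger.get() or _local_logger` expression appears later in `current_logger()` and `SpanImpl.link()`); names mirror the diff, the surrounding class is elided.

import contextvars

_cv_logger = contextvars.ContextVar("braintrust_current_logger", default=None)
_local_logger = None  # plain-attribute fallback, visible to all threads

def resolve_current_logger():
    # Async tasks get isolation from the ContextVar; threads that never set it
    # fall back to the shared attribute, i.e. the "global logger".
    return _cv_logger.get() or _local_logger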
@@ -425,7 +452,8 @@ class BraintrustState:
     def reset_parent_state(self):
         # reset possible parent state for tests
         self.current_experiment = None
-        self.
+        self._cv_logger.set(None)
+        self._local_logger = None
         self.current_parent.set(None)
         self.current_span.set(NOOP_SPAN)
@@ -479,22 +507,25 @@ class BraintrustState:

     def copy_state(self, other: "BraintrustState"):
         """Copy login information from another BraintrustState instance."""
-        self.__dict__.update(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        self.__dict__.update(
+            {
+                k: v
+                for (k, v) in other.__dict__.items()
+                if k
+                not in (
+                    "current_experiment",
+                    "_cv_logger",
+                    "_local_logger",
+                    "current_parent",
+                    "current_span",
+                    "_global_bg_logger",
+                    "_override_bg_logger",
+                    "_context_manager",
+                    "_last_otel_setting",
+                    "_context_manager_lock",
+                )
+            }
+        )

     def login(
         self,
@@ -555,10 +586,6 @@ class BraintrustState:
         self._user_info = self.api_conn().get_json("ping")
         return self._user_info

-    def set_user_info_if_null(self, info: Mapping[str, Any]):
-        if not self._user_info:
-            self._user_info = info
-
     def global_bg_logger(self) -> "_BackgroundLogger":
         return getattr(self._override_bg_logger, "logger", None) or self._global_bg_logger.get()

@@ -620,14 +647,28 @@ class RetryRequestExceptionsAdapter(HTTPAdapter):
         base_num_retries: Maximum number of retries before giving up and re-raising the exception.
         backoff_factor: A multiplier used to determine the time to wait between retries.
             The actual wait time is calculated as: backoff_factor * (2 ** retry_count).
+        default_timeout_secs: Default timeout in seconds for requests that don't specify one.
+            Prevents indefinite hangs on stale connections.
     """

-    def __init__(
+    def __init__(
+        self,
+        *args: Any,
+        base_num_retries: int = 0,
+        backoff_factor: float = 0.5,
+        default_timeout_secs: float = 60,
+        **kwargs: Any,
+    ):
         self.base_num_retries = base_num_retries
         self.backoff_factor = backoff_factor
+        self.default_timeout_secs = default_timeout_secs
         super().__init__(*args, **kwargs)

     def send(self, *args, **kwargs):
+        # Apply default timeout if none provided to prevent indefinite hangs
+        if kwargs.get("timeout") is None:
+            kwargs["timeout"] = self.default_timeout_secs
+
         num_prev_retries = 0
         while True:
             try:
@@ -639,6 +680,14 @@ class RetryRequestExceptionsAdapter(HTTPAdapter):
                 return response
             except (urllib3.exceptions.HTTPError, requests.exceptions.RequestException) as e:
                 if num_prev_retries < self.base_num_retries:
+                    if isinstance(e, requests.exceptions.ReadTimeout):
+                        # Clear all connection pools to discard stale connections. This
+                        # fixes hangs caused by NAT gateways silently dropping idle TCP
+                        # connections (e.g., Azure's ~4 min timeout). close() calls
+                        # PoolManager.clear() which is thread-safe: in-flight requests
+                        # keep their checked-out connections, and new requests create
+                        # fresh pools on demand.
+                        self.close()
                     # Emulates the sleeping logic in the backoff_factor of urllib3 Retry
                     sleep_s = self.backoff_factor * (2**num_prev_retries)
                     print("Retrying request after error:", e, file=sys.stderr)
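Not part of the diff: the retry waits implied by sleep_s = backoff_factor * (2 ** num_prev_retries), using the defaults that make_long_lived() passes in the next hunk (backoff_factor=0.5, base_num_retries=10).

backoff_factor = 0.5
waits = [backoff_factor * (2 ** n) for n in range(10)]
# [0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0] seconds of sleep before each retry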
@@ -660,14 +709,16 @@ class HTTPConnection:
     def ping(self) -> bool:
         try:
             resp = self.get("ping")
-            _state.set_user_info_if_null(resp.json())
             return resp.ok
         except requests.exceptions.ConnectionError:
             return False

     def make_long_lived(self) -> None:
         if not self.adapter:
-
+            timeout_secs = parse_env_var_float("BRAINTRUST_HTTP_TIMEOUT", 60.0)
+            self.adapter = RetryRequestExceptionsAdapter(
+                base_num_retries=10, backoff_factor=0.5, default_timeout_secs=timeout_secs
+            )
         self._reset()

     @staticmethod
@@ -711,18 +762,10 @@ class HTTPConnection:
     def delete(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:
         return self.session.delete(_urljoin(self.base_url, path), *args, **kwargs)

-    def get_json(self, object_type: str, args: Mapping[str, Any] | None = None
-
-
-
-            if i < tries - 1 and not resp.ok:
-                _logger.warning(f"Retrying API request {object_type} {args} {resp.status_code} {resp.text}")
-                continue
-            response_raise_for_status(resp)
-
-            return resp.json()
-        # Needed for type checking.
-        raise Exception("unreachable")
+    def get_json(self, object_type: str, args: Mapping[str, Any] | None = None) -> Mapping[str, Any]:
+        resp = self.get(f"/{object_type}", params=args)
+        response_raise_for_status(resp)
+        return resp.json()

     def post_json(self, object_type: str, args: Mapping[str, Any] | None = None) -> Any:
         resp = self.post(f"/{object_type.lstrip('/')}", json=args)
@@ -760,11 +803,43 @@ def construct_json_array(items: Sequence[str]):
     return "[" + ",".join(items) + "]"


-def construct_logs3_data(items: Sequence[
-    rowsS = construct_json_array(items)
+def construct_logs3_data(items: Sequence[LogItemWithMeta]):
+    rowsS = construct_json_array([item.str_value for item in items])
     return '{"rows": ' + rowsS + ', "api_version": ' + str(DATA_API_VERSION) + "}"


+def construct_logs3_overflow_request(key: str, size_bytes: int | None = None) -> dict[str, Any]:
+    rows: dict[str, Any] = {"type": LOGS3_OVERFLOW_REFERENCE_TYPE, "key": key}
+    if size_bytes is not None:
+        rows["size_bytes"] = size_bytes
+    return {"rows": rows, "api_version": DATA_API_VERSION}
+
+
+def pick_logs3_overflow_object_ids(row: Mapping[str, Any]) -> dict[str, Any]:
+    object_ids: dict[str, Any] = {}
+    for key in OBJECT_ID_KEYS:
+        if key in row:
+            object_ids[key] = row[key]
+    return object_ids
+
+
+def stringify_with_overflow_meta(item: dict[str, Any]) -> LogItemWithMeta:
+    str_value = bt_dumps(item)
+    return LogItemWithMeta(
+        str_value=str_value,
+        overflow_meta=Logs3OverflowInputRow(
+            object_ids=pick_logs3_overflow_object_ids(item),
+            has_comment="comment" in item,
+            is_delete=item.get(OBJECT_DELETE_FIELD) is True,
+            byte_size=utf8_byte_length(str_value),
+        ),
+    )
+
+
+def utf8_byte_length(value: str) -> int:
+    return len(value.encode("utf-8"))
+
+
 class _MaskingError:
     """Internal class to signal masking errors that need special handling."""

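Not part of the diff: a note on utf8_byte_length(). Payload sizing is done on the UTF-8 byte length of the serialized row, not the character count, so multi-byte characters count for more than one.

s = '{"input": "héllo"}'
len(s)                  # 18 characters
len(s.encode("utf-8"))  # 19 bytes; "é" takes two bytes in UTF-8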
@@ -854,15 +929,12 @@ class _MemoryBackgroundLogger(_BackgroundLogger):

         # all the logs get merged before gettig sent to the server, so simulate that
         # here
-
-        first = merged[0]
-        for other in merged[1:]:
-            first.extend(other)
+        batch = merge_row_batch(logs)

         # Apply masking after merge, similar to HTTPBackgroundLogger
         if self.masking_function:
-            for i in range(len(
-                item =
+            for i in range(len(batch)):
+                item = batch[i]
                 masked_item = item.copy()

                 # Only mask specific fields if they exist
@@ -880,9 +952,9 @@ class _MemoryBackgroundLogger(_BackgroundLogger):
                     else:
                         masked_item[field] = masked_value

-
+                batch[i] = masked_item

-        return
+        return batch


 BACKGROUND_LOGGER_BASE_SLEEP_TIME_S = 1.0
@@ -898,6 +970,9 @@ class _HTTPBackgroundLogger:
         self.masking_function: Callable[[Any], Any] | None = None
         self.outfile = sys.stderr
         self.flush_lock = threading.RLock()
+        self._max_request_size_override: int | None = None
+        self._max_request_size_result: dict[str, Any] | None = None
+        self._max_request_size_lock = threading.Lock()

         try:
             self.sync_flush = bool(int(os.environ["BRAINTRUST_SYNC_FLUSH"]))
@@ -905,10 +980,9 @@ class _HTTPBackgroundLogger:
             self.sync_flush = False

         try:
-            self.
+            self._max_request_size_override = int(os.environ["BRAINTRUST_MAX_REQUEST_SIZE"])
         except:
-
-            self.max_request_size = 6 * 1024 * 1024
+            pass

         try:
             self.default_batch_size = int(os.environ["BRAINTRUST_DEFAULT_BATCH_SIZE"])
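Not part of the diff: the two environment knobs involved here. BRAINTRUST_MAX_REQUEST_SIZE (bytes) is parsed above as an override of the request-size cap, and BRAINTRUST_HTTP_TIMEOUT (seconds, default 60) is read by make_long_lived() in an earlier hunk. The values below are illustrative.

import os

os.environ["BRAINTRUST_MAX_REQUEST_SIZE"] = str(4 * 1024 * 1024)  # override the request-size cap (bytes)
os.environ["BRAINTRUST_HTTP_TIMEOUT"] = "120"                      # per-request timeout (seconds)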
@@ -949,6 +1023,9 @@ class _HTTPBackgroundLogger:
         self.logger = logging.getLogger("braintrust")
         self.queue: "LogQueue[LazyValue[Dict[str, Any]]]" = LogQueue(maxsize=self.queue_maxsize)

+        # Counter for tracking overflow uploads (useful for testing)
+        self._overflow_upload_count = 0
+
         atexit.register(self._finalize)

     def enforce_queue_size_limit(self, enforce: bool) -> None:
@@ -1005,6 +1082,38 @@ class _HTTPBackgroundLogger:
             else:
                 raise

+    def _get_max_request_size(self) -> dict[str, Any]:
+        if self._max_request_size_result is not None:
+            return self._max_request_size_result
+        with self._max_request_size_lock:
+            if self._max_request_size_result is not None:
+                return self._max_request_size_result
+            server_limit: int | None = None
+            try:
+                conn = self.api_conn.get()
+                info = conn.get_json("version")
+                limit = info.get("logs3_payload_max_bytes")
+                if isinstance(limit, (int, float)) and int(limit) > 0:
+                    server_limit = int(limit)
+            except Exception as e:
+                print(f"Failed to fetch version info for payload limit: {e}", file=self.outfile)
+            valid_server_limit = server_limit if server_limit is not None and server_limit > 0 else None
+            can_use_overflow = valid_server_limit is not None
+            max_request_size = DEFAULT_MAX_REQUEST_SIZE
+            if self._max_request_size_override is not None:
+                max_request_size = (
+                    min(self._max_request_size_override, valid_server_limit)
+                    if valid_server_limit is not None
+                    else self._max_request_size_override
+                )
+            elif valid_server_limit is not None:
+                max_request_size = valid_server_limit
+            self._max_request_size_result = {
+                "max_request_size": max_request_size,
+                "can_use_overflow": can_use_overflow,
+            }
+            return self._max_request_size_result
+
     def flush(self, batch_size: int | None = None):
         if batch_size is None:
             batch_size = self.default_batch_size
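Not part of the diff: the size-resolution rule that _get_max_request_size() implements, written as a standalone sketch. The overflow path is only enabled when the server reports a logs3_payload_max_bytes limit.

DEFAULT_MAX_REQUEST_SIZE = 6 * 1024 * 1024  # same 6 MB constant as in the diff

def resolve_max_request_size(override, server_limit):
    # Mirrors the method above: an env override is capped by the server limit
    # when one is known; otherwise the server limit or the 6 MB default wins.
    if override is not None:
        return min(override, server_limit) if server_limit is not None else override
    if server_limit is not None:
        return server_limit
    return DEFAULT_MAX_REQUEST_SIZE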
@@ -1019,30 +1128,35 @@ class _HTTPBackgroundLogger:
         if len(all_items) == 0:
             return

-        # Construct batches of records to flush in parallel
-
-
-
+        # Construct batches of records to flush in parallel.
+        all_items_with_meta = [stringify_with_overflow_meta(item) for item in all_items]
+        max_request_size_result = self._get_max_request_size()
+        batches = batch_items(
+            items=all_items_with_meta,
+            batch_max_num_items=batch_size,
+            batch_max_num_bytes=max_request_size_result["max_request_size"] // 2,
+            get_byte_size=lambda item: len(item.str_value),
         )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        post_promises = []
+        try:
+            post_promises = [
+                HTTP_REQUEST_THREAD_POOL.submit(self._submit_logs_request, batch, max_request_size_result)
+                for batch in batches
+            ]
+        except RuntimeError:
+            # If the thread pool has shut down, e.g. because the process
+            # is terminating, run the requests the old fashioned way.
+            for batch in batches:
+                self._submit_logs_request(batch, max_request_size_result)
+
+        concurrent.futures.wait(post_promises)
+        # Raise any exceptions from the promises as one group.
+        post_promise_exceptions = [e for e in (f.exception() for f in post_promises) if e is not None]
+        if post_promise_exceptions:
+            raise exceptiongroup.BaseExceptionGroup(
+                f"Encountered the following errors while logging:", post_promise_exceptions
+            )

         attachment_errors: list[Exception] = []
         for attachment in attachments:
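Not part of the diff: batch_items itself lives in braintrust.util (also changed in this release) and is not shown here. One plausible greedy strategy consistent with the call above caps each batch by item count and by total serialized bytes (half of the negotiated request size):

def sketch_batch_items(items, batch_max_num_items, batch_max_num_bytes, get_byte_size):
    # Greedy packing: start a new batch when adding the next item would exceed
    # either the item-count cap or the byte cap.
    batches, current, current_bytes = [], [], 0
    for item in items:
        size = get_byte_size(item)
        if current and (len(current) >= batch_max_num_items or current_bytes + size > batch_max_num_bytes):
            batches.append(current)
            current, current_bytes = [], 0
        current.append(item)
        current_bytes += size
    if current:
        batches.append(current)
    return batches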
@@ -1063,42 +1177,40 @@ class _HTTPBackgroundLogger:

     def _unwrap_lazy_values(
         self, wrapped_items: Sequence[LazyValue[dict[str, Any]]]
-    ) -> tuple[list[
+    ) -> tuple[list[dict[str, Any]], list["BaseAttachment"]]:
         for i in range(self.num_tries):
             try:
                 unwrapped_items = [item.get() for item in wrapped_items]
-
+                merged_items = merge_row_batch(unwrapped_items)

                 # Apply masking after merging but before sending to backend
                 if self.masking_function:
-                    for
-
-
-
-
-
-
-
-
-
-
-
-
-
-                                        masked_item["error"] = f"{masked_item['error']}; {masked_value.error_msg}"
-                                    else:
-                                        masked_item["error"] = masked_value.error_msg
+                    for item_idx in range(len(merged_items)):
+                        item = merged_items[item_idx]
+                        masked_item = item.copy()
+
+                        # Only mask specific fields if they exist
+                        for field in REDACTION_FIELDS:
+                            if field in item:
+                                masked_value = _apply_masking_to_field(self.masking_function, item[field], field)
+                                if isinstance(masked_value, _MaskingError):
+                                    # Drop the field and add error message
+                                    if field in masked_item:
+                                        del masked_item[field]
+                                    if "error" in masked_item:
+                                        masked_item["error"] = f"{masked_item['error']}; {masked_value.error_msg}"
                                     else:
-                                        masked_item[
+                                        masked_item["error"] = masked_value.error_msg
+                                else:
+                                    masked_item[field] = masked_value

-
+                        merged_items[item_idx] = masked_item

                 attachments: list["BaseAttachment"] = []
-                for
-
-                    _extract_attachments(item, attachments)
+                for item in merged_items:
+                    _extract_attachments(item, attachments)

-                return
+                return merged_items, attachments
             except Exception as e:
                 errmsg = "Encountered error when constructing records to flush"
                 is_retrying = i + 1 < self.num_tries
@@ -1121,21 +1233,120 @@ class _HTTPBackgroundLogger:
                 )
                 return [], []

-    def
+    def _request_logs3_overflow_upload(
+        self, conn: HTTPConnection, payload_size_bytes: int, rows: list[dict[str, Any]]
+    ) -> dict[str, Any]:
+        try:
+            resp = conn.post(
+                "/logs3/overflow",
+                json={"content_type": "application/json", "size_bytes": payload_size_bytes, "rows": rows},
+            )
+            resp.raise_for_status()
+            payload = resp.json()
+        except Exception as e:
+            raise RuntimeError(f"Failed to request logs3 overflow upload URL: {e}") from e
+
+        method = payload.get("method")
+        if method not in ("PUT", "POST"):
+            raise RuntimeError(f"Invalid response from API server (method must be PUT or POST): {payload}")
+        signed_url = payload.get("signedUrl")
+        headers = payload.get("headers")
+        fields = payload.get("fields")
+        key = payload.get("key")
+        if not isinstance(signed_url, str) or not isinstance(key, str):
+            raise RuntimeError(f"Invalid response from API server: {payload}")
+        if method == "PUT" and not isinstance(headers, dict):
+            raise RuntimeError(f"Invalid response from API server: {payload}")
+        if method == "POST" and not isinstance(fields, dict):
+            raise RuntimeError(f"Invalid response from API server: {payload}")
+
+        if method == "PUT":
+            add_azure_blob_headers(headers, signed_url)
+
+        return {
+            "method": method,
+            "signed_url": signed_url,
+            "headers": headers if isinstance(headers, dict) else {},
+            "fields": fields if isinstance(fields, dict) else {},
+            "key": key,
+        }
+
+    def _upload_logs3_overflow_payload(self, upload: dict[str, Any], payload: str) -> None:
+        obj_conn = HTTPConnection(base_url="", adapter=_http_adapter)
+        method = upload["method"]
+        if method == "POST":
+            fields = upload.get("fields")
+            if not isinstance(fields, dict):
+                raise RuntimeError("Missing logs3 overflow upload fields")
+            content_type = fields.get("Content-Type", "application/json")
+            headers = {k: v for k, v in upload.get("headers", {}).items() if k.lower() != "content-type"}
+            obj_response = obj_conn.post(
+                upload["signed_url"],
+                headers=headers,
+                data=fields,
+                files={"file": ("logs3.json", payload.encode("utf-8"), content_type)},
+            )
+        else:
+            obj_response = obj_conn.put(
+                upload["signed_url"],
+                headers=upload["headers"],
+                data=payload.encode("utf-8"),
+            )
+        obj_response.raise_for_status()
+
+    def _submit_logs_request(self, items: Sequence[LogItemWithMeta], max_request_size_result: dict[str, Any]):
         conn = self.api_conn.get()
         dataStr = construct_logs3_data(items)
+        payload_bytes = utf8_byte_length(dataStr)
+        max_request_size = max_request_size_result["max_request_size"]
+        can_use_overflow = max_request_size_result["can_use_overflow"]
+        use_overflow = can_use_overflow and payload_bytes > max_request_size
         if self.all_publish_payloads_dir:
             _HTTPBackgroundLogger._write_payload_to_dir(payload_dir=self.all_publish_payloads_dir, payload=dataStr)
+        overflow_upload: dict[str, Any] | None = None
+        overflow_rows = (
+            [
+                {
+                    "object_ids": item.overflow_meta.object_ids,
+                    "has_comment": item.overflow_meta.has_comment,
+                    "is_delete": item.overflow_meta.is_delete,
+                    "input_row": {"byte_size": item.overflow_meta.byte_size},
+                }
+                for item in items
+            ]
+            if use_overflow
+            else None
+        )
         for i in range(self.num_tries):
             start_time = time.time()
-            resp =
-
+            resp = None
+            error = None
+            try:
+                if overflow_rows:
+                    if overflow_upload is None:
+                        current_upload = self._request_logs3_overflow_upload(conn, payload_bytes, overflow_rows)
+                        self._upload_logs3_overflow_payload(current_upload, dataStr)
+                        overflow_upload = current_upload
+                    resp = conn.post(
+                        "/logs3",
+                        json=construct_logs3_overflow_request(overflow_upload["key"], payload_bytes),
+                    )
+                else:
+                    resp = conn.post("/logs3", data=dataStr.encode("utf-8"))
+            except Exception as e:
+                error = e
+            if error is None and resp is not None and resp.ok:
+                if overflow_rows:
+                    self._overflow_upload_count += 1
                 return
-
+            if error is None and resp is not None:
+                resp_errmsg = f"{resp.status_code}: {resp.text}"
+            else:
+                resp_errmsg = str(error)

             is_retrying = i + 1 < self.num_tries
             retrying_text = "" if is_retrying else " Retrying"
-            errmsg = f"log request failed. Elapsed time: {time.time() - start_time} seconds. Payload size: {
+            errmsg = f"log request failed. Elapsed time: {time.time() - start_time} seconds. Payload size: {payload_bytes}.{retrying_text}\nError: {resp_errmsg}"

             if not is_retrying and self.failed_publish_payloads_dir:
                 _HTTPBackgroundLogger._write_payload_to_dir(
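Not part of the diff: the shape of the /logs3 request body when the overflow path is taken. The serialized rows are uploaded to the signed URL returned by /logs3/overflow, and /logs3 then receives only a reference. This sketch reuses names from the hunk above; the key is illustrative.

payload_bytes = utf8_byte_length(dataStr)
use_overflow = can_use_overflow and payload_bytes > max_request_size
if use_overflow:
    body = construct_logs3_overflow_request("uploads/abc123", size_bytes=payload_bytes)
    # body == {"rows": {"type": "logs3_overflow", "key": "uploads/abc123",
    #          "size_bytes": payload_bytes}, "api_version": 2}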
@@ -1160,14 +1371,15 @@ class _HTTPBackgroundLogger:
             return
         try:
             all_items, attachments = self._unwrap_lazy_values(wrapped_items)
-
+            items_with_meta = [stringify_with_overflow_meta(item) for item in all_items]
+            dataStr = construct_logs3_data(items_with_meta)
             attachment_str = bt_dumps([a.debug_info() for a in attachments])
             payload = "{" + f""""data": {dataStr}, "attachments": {attachment_str}""" + "}"
             for output_dir in publish_payloads_dir:
                 if not output_dir:
                     continue
                 _HTTPBackgroundLogger._write_payload_to_dir(payload_dir=output_dir, payload=payload)
-        except Exception
+        except Exception:
             traceback.print_exc(file=self.outfile)

     def _register_dropped_item_count(self, num_items):
@@ -1634,7 +1846,8 @@ def init_logger(
     if set_current:
         if _state is None:
             raise RuntimeError("_state is None in init_logger. This should never happen.")
-        _state.
+        _state._cv_logger.set(ret)
+        _state._local_logger = ret
     return ret


@@ -1955,7 +2168,7 @@ def current_experiment() -> Optional["Experiment"]:
 def current_logger() -> Optional["Logger"]:
     """Returns the currently-active logger (set by `braintrust.init_logger(...)`). Returns None if no current logger has been set."""

-    return _state.
+    return _state._cv_logger.get() or _state._local_logger


 def current_span() -> Span:
@@ -3260,17 +3473,17 @@ def _start_span_parent_args(
     if parent:
         assert parent_span_ids is None, "Cannot specify both parent and parent_span_ids"
         parent_components = SpanComponentsV4.from_str(parent)
-        assert
-
-        )
+        assert (
+            parent_object_type == parent_components.object_type
+        ), f"Mismatch between expected span parent object type {parent_object_type} and provided type {parent_components.object_type}"

         parent_components_object_id_lambda = _span_components_to_object_id_lambda(parent_components)

         def compute_parent_object_id():
             parent_components_object_id = parent_components_object_id_lambda()
-            assert
-
-            )
+            assert (
+                parent_object_id.get() == parent_components_object_id
+            ), f"Mismatch between expected span parent object id {parent_object_id.get()} and provided id {parent_components_object_id}"
             return parent_object_id.get()

         arg_parent_object_id = LazyValue(compute_parent_object_id, use_mutex=False)
@@ -3587,7 +3800,6 @@ class Experiment(ObjectFetcher[ExperimentEvent], Exportable):
                     "experiment_id": self.id,
                    "base_experiment_id": comparison_experiment_id,
                },
-                retries=3,
             )
         except Exception as e:
             _logger.warning(
@@ -3984,6 +4196,9 @@ class SpanImpl(Span):
         use_v4 = os.getenv("BRAINTRUST_OTEL_COMPAT", "false").lower() == "true"
         span_components_class = SpanComponentsV4 if use_v4 else SpanComponentsV3

+        # Disable span cache since remote function spans won't be in the local cache
+        self.state.span_cache.disable()
+
         return span_components_class(
             object_type=self.parent_object_type,
             object_id=object_id,
@@ -3997,7 +4212,7 @@ class SpanImpl(Span):
     def link(self) -> str:
         parent_type, info = self._get_parent_info()
         if parent_type == SpanObjectTypeV3.PROJECT_LOGS:
-            cur_logger = self.state.
+            cur_logger = self.state._cv_logger.get() or self.state._local_logger
             if not cur_logger:
                 return NOOP_SPAN_PERMALINK
             base_url = cur_logger._get_link_base_url()
@@ -4050,10 +4265,20 @@ class SpanImpl(Span):
             self._context_token = self.state.context_manager.set_current_span(self)

     def unset_current(self):
+        """
+        Unset current span context.
+
+        Note: self._context_token may be None if set_current() failed.
+        This is safe - context_manager.unset_current_span() handles None.
+        """
         if self.can_set_current:
-
-
-
+            try:
+                self.state.context_manager.unset_current_span(self._context_token)
+            except Exception as e:
+                logging.debug(f"Failed to unset current span: {e}")
+            finally:
+                # Always clear the token reference
+                self._context_token = None

     def __enter__(self) -> Span:
         self.set_current()
@@ -4064,8 +4289,15 @@ class SpanImpl(Span):
             if exc_type is not None:
                 self.log_internal(dict(error=stringify_exception(exc_type, exc_value, tb)))
         finally:
-
-
+            try:
+                self.unset_current()
+            except Exception as e:
+                logging.debug(f"Failed to unset current in __exit__: {e}")
+
+            try:
+                self.end()
+            except Exception as e:
+                logging.warning(f"Error ending span: {e}")

     def _get_parent_info(self):
         if self.parent_object_type == SpanObjectTypeV3.PROJECT_LOGS:
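Not part of the diff: with the hardened __exit__ above, using a span as a context manager records the exception on the span's error field and still calls end() even if unset_current() raises. A minimal usage sketch, assuming the usual Logger.start_span()/Span.log() API from braintrust.init_logger(...):

with logger.start_span(name="my-task") as span:
    span.log(input={"question": "..."})
    raise ValueError("boom")  # recorded via stringify_exception; end() still runs; the exception propagates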
@@ -4396,7 +4628,6 @@ class Dataset(ObjectFetcher[DatasetEvent]):
                 args={
                     "dataset_id": self.id,
                 },
-                retries=3,
             )
         data_summary = DataSummary(new_records=self.new_records, **data_summary_d)

@@ -4448,20 +4679,24 @@ def render_message(render: Callable[[str], str], message: PromptMessage):
             if c["type"] == "text":
                 rendered_content.append({**c, "text": render(c["text"])})
             elif c["type"] == "image_url":
-                rendered_content.append(
-
-
-
+                rendered_content.append(
+                    {
+                        **c,
+                        "image_url": {**c["image_url"], "url": render(c["image_url"]["url"])},
+                    }
+                )
             elif c["type"] == "file":
-                rendered_content.append(
-
-
-
-
-
-
-
-
+                rendered_content.append(
+                    {
+                        **c,
+                        "file": {
+                            **c["file"],
+                            "file_data": render(c["file"]["file_data"]),
+                            **({} if "file_id" not in c["file"] else {"file_id": render(c["file"]["file_id"])}),
+                            **({} if "filename" not in c["file"] else {"filename": render(c["file"]["filename"])}),
+                        },
+                    }
+                )
             else:
                 raise ValueError(f"Unknown content type: {c['type']}")
