judgeval-0.16.9-py3-none-any.whl → judgeval-0.18.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of judgeval has been flagged as potentially problematic; see the release's registry page for more details.

@@ -40,12 +40,6 @@ def push_prompt_scorer(
40
40
  }
41
41
  )
42
42
  except JudgmentAPIError as e:
43
- if e.status_code == 500:
44
- raise JudgmentAPIError(
45
- status_code=e.status_code,
46
- detail=f"The server is temporarily unavailable. Please try your request again in a few moments. Error details: {e.detail}",
47
- response=e.response,
48
- )
49
43
  raise JudgmentAPIError(
50
44
  status_code=e.status_code,
51
45
  detail=f"Failed to save prompt scorer: {e.detail}",
@@ -75,12 +69,6 @@ def fetch_prompt_scorer(
75
69
  scorer_config.pop("updated_at")
76
70
  return scorer_config
77
71
  except JudgmentAPIError as e:
78
- if e.status_code == 500:
79
- raise JudgmentAPIError(
80
- status_code=e.status_code,
81
- detail=f"The server is temporarily unavailable. Please try your request again in a few moments. Error details: {e.detail}",
82
- response=e.response,
83
- )
84
72
  raise JudgmentAPIError(
85
73
  status_code=e.status_code,
86
74
  detail=f"Failed to fetch prompt scorer '{name}': {e.detail}",
@@ -71,6 +71,7 @@ from judgeval.tracer.processors import (
71
71
  NoOpJudgmentSpanProcessor,
72
72
  )
73
73
  from judgeval.tracer.utils import set_span_attribute, TraceScorerConfig
74
+ from judgeval.utils.project import _resolve_project_id
74
75
 
75
76
  C = TypeVar("C", bound=Callable)
76
77
  Cls = TypeVar("Cls", bound=Type)
@@ -101,6 +102,7 @@ class Tracer(metaclass=SingletonMeta):
101
102
  "judgment_processor",
102
103
  "tracer",
103
104
  "agent_context",
105
+ "customer_id",
104
106
  "_initialized",
105
107
  )
106
108
 
@@ -114,6 +116,7 @@ class Tracer(metaclass=SingletonMeta):
114
116
  judgment_processor: JudgmentSpanProcessor
115
117
  tracer: ABCTracer
116
118
  agent_context: ContextVar[Optional[AgentContext]]
119
+ customer_id: ContextVar[Optional[str]]
117
120
  _initialized: bool
118
121
 
119
122
  def __init__(
@@ -131,6 +134,7 @@ class Tracer(metaclass=SingletonMeta):
131
134
  if not hasattr(self, "_initialized"):
132
135
  self._initialized = False
133
136
  self.agent_context = ContextVar("current_agent_context", default=None)
137
+ self.customer_id = ContextVar("current_customer_id", default=None)
134
138
 
135
139
  self.project_name = project_name
136
140
  self.api_key = expect_api_key(api_key or JUDGMENT_API_KEY)
@@ -155,7 +159,7 @@ class Tracer(metaclass=SingletonMeta):
155
159
 
156
160
  self.judgment_processor = NoOpJudgmentSpanProcessor()
157
161
  if self.enable_monitoring:
158
- project_id = Tracer._resolve_project_id(
162
+ project_id = _resolve_project_id(
159
163
  self.project_name, self.api_key, self.organization_id
160
164
  )
161
165
  if project_id:
@@ -224,20 +228,6 @@ class Tracer(metaclass=SingletonMeta):
224
228
  resource_attributes=resource_attributes,
225
229
  )
226
230
 
227
- @dont_throw
228
- @functools.lru_cache(maxsize=64)
229
- @staticmethod
230
- def _resolve_project_id(
231
- project_name: str, api_key: str, organization_id: str
232
- ) -> str:
233
- """Resolve project_id from project_name using the API."""
234
- client = JudgmentSyncClient(
235
- api_key=api_key,
236
- organization_id=organization_id,
237
- )
238
- response = client.projects_resolve({"project_name": project_name})
239
- return response["project_id"]
240
-
241
231
  def get_current_span(self):
242
232
  return get_current_span()
243
233
 
@@ -247,17 +237,50 @@ class Tracer(metaclass=SingletonMeta):
247
237
  def get_current_agent_context(self):
248
238
  return self.agent_context
249
239
 
240
+ def get_current_customer_context(self):
241
+ return self.customer_id
242
+
250
243
  def get_span_processor(self) -> JudgmentSpanProcessor:
251
244
  """Get the internal span processor of this tracer instance."""
252
245
  return self.judgment_processor
253
246
 
254
247
  def set_customer_id(self, customer_id: str) -> None:
248
+ if not customer_id:
249
+ judgeval_logger.warning("Customer ID is empty, skipping.")
250
+ return
251
+
255
252
  span = self.get_current_span()
253
+
254
+ if not span or not span.is_recording():
255
+ judgeval_logger.warning(
256
+ "No active span found. Customer ID will not be set."
257
+ )
258
+ return
259
+
260
+ if self.get_current_customer_context().get():
261
+ judgeval_logger.warning("Customer ID is already set, skipping.")
262
+ return
263
+
256
264
  if span and span.is_recording():
257
265
  set_span_attribute(span, AttributeKeys.JUDGMENT_CUSTOMER_ID, customer_id)
266
+ self.get_current_customer_context().set(customer_id)
267
+
268
+ self.get_span_processor().set_internal_attribute(
269
+ span_context=span.get_span_context(),
270
+ key=InternalAttributeKeys.IS_CUSTOMER_CONTEXT_OWNER,
271
+ value=True,
272
+ )
273
+
274
+ def _maybe_clear_customer_context(self, span: Span) -> None:
275
+ if self.get_span_processor().get_internal_attribute(
276
+ span_context=span.get_span_context(),
277
+ key=InternalAttributeKeys.IS_CUSTOMER_CONTEXT_OWNER,
278
+ default=False,
279
+ ):
280
+ self.get_current_customer_context().set(None)
258
281
 
259
282
  @dont_throw
260
- def add_agent_attributes_to_span(self, span):
283
+ def _add_agent_attributes_to_span(self, span):
261
284
  """Add agent ID, class name, and instance name to span if they exist in context"""
262
285
  current_agent_context = self.agent_context.get()
263
286
  if not current_agent_context:
@@ -289,7 +312,7 @@ class Tracer(metaclass=SingletonMeta):
289
312
  current_agent_context["is_agent_entry_point"] = False
290
313
 
291
314
  @dont_throw
292
- def record_instance_state(self, record_point: Literal["before", "after"], span):
315
+ def _record_instance_state(self, record_point: Literal["before", "after"], span):
293
316
  current_agent_context = self.agent_context.get()
294
317
 
295
318
  if current_agent_context and current_agent_context.get("track_state"):
@@ -318,6 +341,17 @@ class Tracer(metaclass=SingletonMeta):
318
341
  safe_serialize(attributes),
319
342
  )
320
343
 
344
+ @dont_throw
345
+ def _add_customer_id_to_span(self, span):
346
+ customer_id = self.get_current_customer_context().get()
347
+ if customer_id:
348
+ set_span_attribute(span, AttributeKeys.JUDGMENT_CUSTOMER_ID, customer_id)
349
+
350
+ @dont_throw
351
+ def _inject_judgment_context(self, span):
352
+ self._add_agent_attributes_to_span(span)
353
+ self._add_customer_id_to_span(span)
354
+
321
355
  def _set_pending_trace_eval(
322
356
  self,
323
357
  span: Span,
@@ -398,7 +432,7 @@ class Tracer(metaclass=SingletonMeta):
398
432
  with sync_span_context(
399
433
  self, yield_span_name, yield_attributes, disable_partial_emit=True
400
434
  ) as yield_span:
401
- self.add_agent_attributes_to_span(yield_span)
435
+ self._inject_judgment_context(yield_span)
402
436
 
403
437
  try:
404
438
  value = next(generator)
@@ -442,7 +476,7 @@ class Tracer(metaclass=SingletonMeta):
442
476
  async with async_span_context(
443
477
  self, yield_span_name, yield_attributes, disable_partial_emit=True
444
478
  ) as yield_span:
445
- self.add_agent_attributes_to_span(yield_span)
479
+ self._inject_judgment_context(yield_span)
446
480
 
447
481
  try:
448
482
  value = await async_generator.__anext__()
@@ -484,8 +518,8 @@ class Tracer(metaclass=SingletonMeta):
484
518
  def wrapper(*args, **kwargs):
485
519
  n = name or f.__qualname__
486
520
  with sync_span_context(self, n, attributes) as span:
487
- self.add_agent_attributes_to_span(span)
488
- self.record_instance_state("before", span)
521
+ self._inject_judgment_context(span)
522
+ self._record_instance_state("before", span)
489
523
  try:
490
524
  set_span_attribute(
491
525
  span,
@@ -502,13 +536,14 @@ class Tracer(metaclass=SingletonMeta):
502
536
  except Exception as user_exc:
503
537
  span.record_exception(user_exc)
504
538
  span.set_status(Status(StatusCode.ERROR, str(user_exc)))
539
+ self._maybe_clear_customer_context(span)
505
540
  raise
506
541
 
507
542
  if inspect.isgenerator(result):
508
543
  set_span_attribute(
509
544
  span, AttributeKeys.JUDGMENT_OUTPUT, "<generator>"
510
545
  )
511
- self.record_instance_state("after", span)
546
+ self._record_instance_state("after", span)
512
547
  return self._create_traced_sync_generator(
513
548
  result, span, n, attributes
514
549
  )
@@ -516,7 +551,8 @@ class Tracer(metaclass=SingletonMeta):
516
551
  set_span_attribute(
517
552
  span, AttributeKeys.JUDGMENT_OUTPUT, safe_serialize(result)
518
553
  )
519
- self.record_instance_state("after", span)
554
+ self._record_instance_state("after", span)
555
+ self._maybe_clear_customer_context(span)
520
556
  return result
521
557
 
522
558
  return wrapper
@@ -535,8 +571,8 @@ class Tracer(metaclass=SingletonMeta):
535
571
  n = name or f.__qualname__
536
572
 
537
573
  with sync_span_context(self, n, attributes) as main_span:
538
- self.add_agent_attributes_to_span(main_span)
539
- self.record_instance_state("before", main_span)
574
+ self._inject_judgment_context(main_span)
575
+ self._record_instance_state("before", main_span)
540
576
 
541
577
  try:
542
578
  set_span_attribute(
@@ -556,7 +592,7 @@ class Tracer(metaclass=SingletonMeta):
556
592
  set_span_attribute(
557
593
  main_span, AttributeKeys.JUDGMENT_OUTPUT, "<generator>"
558
594
  )
559
- self.record_instance_state("after", main_span)
595
+ self._record_instance_state("after", main_span)
560
596
 
561
597
  return self._create_traced_sync_generator(
562
598
  generator, main_span, n, attributes
@@ -586,8 +622,8 @@ class Tracer(metaclass=SingletonMeta):
586
622
  async def wrapper(*args, **kwargs):
587
623
  n = name or f.__qualname__
588
624
  async with async_span_context(self, n, attributes) as span:
589
- self.add_agent_attributes_to_span(span)
590
- self.record_instance_state("before", span)
625
+ self._inject_judgment_context(span)
626
+ self._record_instance_state("before", span)
591
627
  try:
592
628
  set_span_attribute(
593
629
  span,
@@ -604,13 +640,14 @@ class Tracer(metaclass=SingletonMeta):
604
640
  except Exception as user_exc:
605
641
  span.record_exception(user_exc)
606
642
  span.set_status(Status(StatusCode.ERROR, str(user_exc)))
643
+ self._maybe_clear_customer_context(span)
607
644
  raise
608
645
 
609
646
  if inspect.isasyncgen(result):
610
647
  set_span_attribute(
611
648
  span, AttributeKeys.JUDGMENT_OUTPUT, "<async_generator>"
612
649
  )
613
- self.record_instance_state("after", span)
650
+ self._record_instance_state("after", span)
614
651
  return self._create_traced_async_generator(
615
652
  result, span, n, attributes
616
653
  )
@@ -618,7 +655,8 @@ class Tracer(metaclass=SingletonMeta):
618
655
  set_span_attribute(
619
656
  span, AttributeKeys.JUDGMENT_OUTPUT, safe_serialize(result)
620
657
  )
621
- self.record_instance_state("after", span)
658
+ self._record_instance_state("after", span)
659
+ self._maybe_clear_customer_context(span)
622
660
  return result
623
661
 
624
662
  return wrapper
@@ -637,8 +675,8 @@ class Tracer(metaclass=SingletonMeta):
637
675
  n = name or f.__qualname__
638
676
 
639
677
  with sync_span_context(self, n, attributes) as main_span:
640
- self.add_agent_attributes_to_span(main_span)
641
- self.record_instance_state("before", main_span)
678
+ self._inject_judgment_context(main_span)
679
+ self._record_instance_state("before", main_span)
642
680
 
643
681
  try:
644
682
  set_span_attribute(
@@ -658,7 +696,7 @@ class Tracer(metaclass=SingletonMeta):
658
696
  set_span_attribute(
659
697
  main_span, AttributeKeys.JUDGMENT_OUTPUT, "<async_generator>"
660
698
  )
661
- self.record_instance_state("after", main_span)
699
+ self._record_instance_state("after", main_span)
662
700
 
663
701
  return self._create_traced_async_generator(
664
702
  async_generator, main_span, n, attributes
@@ -1,5 +1,5 @@
1
1
  from abc import ABC, abstractmethod
2
- from typing import List
2
+ from typing import List, Dict
3
3
 
4
4
  from opentelemetry.sdk.trace import ReadableSpan
5
5
 
@@ -9,35 +9,51 @@ class ABCSpanStore(ABC):
9
9
  def add(self, *spans: ReadableSpan): ...
10
10
 
11
11
  @abstractmethod
12
- def get(self, id: str) -> ReadableSpan: ...
12
+ def get_all(self) -> List[ReadableSpan]: ...
13
13
 
14
14
  @abstractmethod
15
- def get_all(self) -> List[ReadableSpan]: ...
15
+ def get_by_trace_id(self, trace_id: str) -> List[ReadableSpan]: ...
16
+
17
+ @abstractmethod
18
+ def clear_trace(self, trace_id: str): ...
16
19
 
17
20
 
18
21
  class SpanStore(ABCSpanStore):
19
- __slots__ = ("spans",)
22
+ __slots__ = ("_spans_by_trace",)
20
23
 
21
- spans: List[ReadableSpan]
24
+ _spans_by_trace: Dict[str, List[ReadableSpan]]
22
25
 
23
26
  def __init__(self):
24
- self.spans = []
27
+ self._spans_by_trace = {}
25
28
 
26
29
  def add(self, *spans: ReadableSpan):
27
- self.spans.extend(spans)
28
-
29
- def get(self, id: str) -> ReadableSpan:
30
- for span in self.spans:
30
+ for span in spans:
31
31
  context = span.get_span_context()
32
32
  if context is None:
33
33
  continue
34
- if context.span_id == id:
35
- return span
36
-
37
- raise ValueError(f"Span with id {id} not found")
34
+ # Convert trace_id to hex string per OTEL spec
35
+ trace_id = format(context.trace_id, "032x")
36
+ if trace_id not in self._spans_by_trace:
37
+ self._spans_by_trace[trace_id] = []
38
+ self._spans_by_trace[trace_id].append(span)
38
39
 
39
40
  def get_all(self) -> List[ReadableSpan]:
40
- return self.spans
41
+ all_spans = []
42
+ for spans in self._spans_by_trace.values():
43
+ all_spans.extend(spans)
44
+ return all_spans
45
+
46
+ def get_by_trace_id(self, trace_id: str) -> List[ReadableSpan]:
47
+ """Get all spans for a specific trace ID (32-char hex string)."""
48
+ return self._spans_by_trace.get(trace_id, [])
49
+
50
+ def clear_trace(self, trace_id: str):
51
+ """Clear all spans for a specific trace ID (32-char hex string)."""
52
+ if trace_id in self._spans_by_trace:
53
+ del self._spans_by_trace[trace_id]
41
54
 
42
55
  def __repr__(self) -> str:
43
- return f"SpanStore(spans={self.spans})"
56
+ total_spans = sum(len(spans) for spans in self._spans_by_trace.values())
57
+ return (
58
+ f"SpanStore(traces={len(self._spans_by_trace)}, total_spans={total_spans})"
59
+ )
judgeval/tracer/keys.py CHANGED
@@ -51,6 +51,7 @@ class InternalAttributeKeys(str, Enum):
51
51
 
52
52
  DISABLE_PARTIAL_EMIT = "disable_partial_emit"
53
53
  CANCELLED = "cancelled"
54
+ IS_CUSTOMER_CONTEXT_OWNER = "is_customer_context_owner"
54
55
 
55
56
 
56
57
  class ResourceKeys(str, Enum):
@@ -89,7 +89,7 @@ def _wrap_non_streaming_sync(
89
89
  ctx["span"] = tracer.get_tracer().start_span(
90
90
  "ANTHROPIC_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
91
91
  )
92
- tracer.add_agent_attributes_to_span(ctx["span"])
92
+ tracer._inject_judgment_context(ctx["span"])
93
93
  set_span_attribute(
94
94
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
95
95
  )
@@ -163,7 +163,7 @@ def _wrap_streaming_sync(
163
163
  ctx["span"] = tracer.get_tracer().start_span(
164
164
  "ANTHROPIC_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
165
165
  )
166
- tracer.add_agent_attributes_to_span(ctx["span"])
166
+ tracer._inject_judgment_context(ctx["span"])
167
167
  set_span_attribute(
168
168
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
169
169
  )
@@ -273,7 +273,7 @@ def _wrap_non_streaming_async(
273
273
  ctx["span"] = tracer.get_tracer().start_span(
274
274
  "ANTHROPIC_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
275
275
  )
276
- tracer.add_agent_attributes_to_span(ctx["span"])
276
+ tracer._inject_judgment_context(ctx["span"])
277
277
  set_span_attribute(
278
278
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
279
279
  )
@@ -348,7 +348,7 @@ def _wrap_streaming_async(
348
348
  ctx["span"] = tracer.get_tracer().start_span(
349
349
  "ANTHROPIC_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
350
350
  )
351
- tracer.add_agent_attributes_to_span(ctx["span"])
351
+ tracer._inject_judgment_context(ctx["span"])
352
352
  set_span_attribute(
353
353
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
354
354
  )
@@ -37,7 +37,7 @@ def wrap_messages_stream_sync(tracer: Tracer, client: Anthropic) -> None:
37
37
  ctx["span"] = tracer.get_tracer().start_span(
38
38
  "ANTHROPIC_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
39
39
  )
40
- tracer.add_agent_attributes_to_span(ctx["span"])
40
+ tracer._inject_judgment_context(ctx["span"])
41
41
  set_span_attribute(
42
42
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
43
43
  )
@@ -183,7 +183,7 @@ def wrap_messages_stream_async(tracer: Tracer, client: AsyncAnthropic) -> None:
183
183
  ctx["span"] = tracer.get_tracer().start_span(
184
184
  "ANTHROPIC_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
185
185
  )
186
- tracer.add_agent_attributes_to_span(ctx["span"])
186
+ tracer._inject_judgment_context(ctx["span"])
187
187
  set_span_attribute(
188
188
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
189
189
  )
@@ -57,7 +57,7 @@ def wrap_generate_content_sync(tracer: Tracer, client: Client) -> None:
57
57
  ctx["span"] = tracer.get_tracer().start_span(
58
58
  "GOOGLE_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
59
59
  )
60
- tracer.add_agent_attributes_to_span(ctx["span"])
60
+ tracer._inject_judgment_context(ctx["span"])
61
61
  set_span_attribute(
62
62
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
63
63
  )
@@ -39,7 +39,7 @@ def _wrap_beta_non_streaming_sync(
39
39
  ctx["span"] = tracer.get_tracer().start_span(
40
40
  "OPENAI_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
41
41
  )
42
- tracer.add_agent_attributes_to_span(ctx["span"])
42
+ tracer._inject_judgment_context(ctx["span"])
43
43
  set_span_attribute(
44
44
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
45
45
  )
@@ -122,7 +122,7 @@ def _wrap_beta_non_streaming_async(
122
122
  ctx["span"] = tracer.get_tracer().start_span(
123
123
  "OPENAI_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
124
124
  )
125
- tracer.add_agent_attributes_to_span(ctx["span"])
125
+ tracer._inject_judgment_context(ctx["span"])
126
126
  set_span_attribute(
127
127
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
128
128
  )
@@ -62,7 +62,7 @@ def _wrap_non_streaming_sync(
62
62
  ctx["span"] = tracer.get_tracer().start_span(
63
63
  "OPENAI_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
64
64
  )
65
- tracer.add_agent_attributes_to_span(ctx["span"])
65
+ tracer._inject_judgment_context(ctx["span"])
66
66
  set_span_attribute(
67
67
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
68
68
  )
@@ -139,7 +139,7 @@ def _wrap_streaming_sync(
139
139
  ctx["span"] = tracer.get_tracer().start_span(
140
140
  "OPENAI_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
141
141
  )
142
- tracer.add_agent_attributes_to_span(ctx["span"])
142
+ tracer._inject_judgment_context(ctx["span"])
143
143
  set_span_attribute(
144
144
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
145
145
  )
@@ -258,7 +258,7 @@ def _wrap_non_streaming_async(
258
258
  ctx["span"] = tracer.get_tracer().start_span(
259
259
  "OPENAI_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
260
260
  )
261
- tracer.add_agent_attributes_to_span(ctx["span"])
261
+ tracer._inject_judgment_context(ctx["span"])
262
262
  set_span_attribute(
263
263
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
264
264
  )
@@ -336,7 +336,7 @@ def _wrap_streaming_async(
336
336
  ctx["span"] = tracer.get_tracer().start_span(
337
337
  "OPENAI_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
338
338
  )
339
- tracer.add_agent_attributes_to_span(ctx["span"])
339
+ tracer._inject_judgment_context(ctx["span"])
340
340
  set_span_attribute(
341
341
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
342
342
  )
@@ -56,7 +56,7 @@ def _wrap_responses_non_streaming_sync(
56
56
  ctx["span"] = tracer.get_tracer().start_span(
57
57
  "OPENAI_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
58
58
  )
59
- tracer.add_agent_attributes_to_span(ctx["span"])
59
+ tracer._inject_judgment_context(ctx["span"])
60
60
  set_span_attribute(
61
61
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
62
62
  )
@@ -131,7 +131,7 @@ def _wrap_responses_streaming_sync(
131
131
  ctx["span"] = tracer.get_tracer().start_span(
132
132
  "OPENAI_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
133
133
  )
134
- tracer.add_agent_attributes_to_span(ctx["span"])
134
+ tracer._inject_judgment_context(ctx["span"])
135
135
  set_span_attribute(
136
136
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
137
137
  )
@@ -260,7 +260,7 @@ def _wrap_responses_non_streaming_async(
260
260
  ctx["span"] = tracer.get_tracer().start_span(
261
261
  "OPENAI_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
262
262
  )
263
- tracer.add_agent_attributes_to_span(ctx["span"])
263
+ tracer._inject_judgment_context(ctx["span"])
264
264
  set_span_attribute(
265
265
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
266
266
  )
@@ -335,7 +335,7 @@ def _wrap_responses_streaming_async(
335
335
  ctx["span"] = tracer.get_tracer().start_span(
336
336
  "OPENAI_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
337
337
  )
338
- tracer.add_agent_attributes_to_span(ctx["span"])
338
+ tracer._inject_judgment_context(ctx["span"])
339
339
  set_span_attribute(
340
340
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
341
341
  )
@@ -63,7 +63,7 @@ def _wrap_non_streaming_sync(
63
63
  ctx["span"] = tracer.get_tracer().start_span(
64
64
  "TOGETHER_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
65
65
  )
66
- tracer.add_agent_attributes_to_span(ctx["span"])
66
+ tracer._inject_judgment_context(ctx["span"])
67
67
  set_span_attribute(
68
68
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
69
69
  )
@@ -133,7 +133,7 @@ def _wrap_streaming_sync(
133
133
  ctx["span"] = tracer.get_tracer().start_span(
134
134
  "TOGETHER_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
135
135
  )
136
- tracer.add_agent_attributes_to_span(ctx["span"])
136
+ tracer._inject_judgment_context(ctx["span"])
137
137
  set_span_attribute(
138
138
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
139
139
  )
@@ -239,7 +239,7 @@ def _wrap_non_streaming_async(
239
239
  ctx["span"] = tracer.get_tracer().start_span(
240
240
  "TOGETHER_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
241
241
  )
242
- tracer.add_agent_attributes_to_span(ctx["span"])
242
+ tracer._inject_judgment_context(ctx["span"])
243
243
  set_span_attribute(
244
244
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
245
245
  )
@@ -310,7 +310,7 @@ def _wrap_streaming_async(
310
310
  ctx["span"] = tracer.get_tracer().start_span(
311
311
  "TOGETHER_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
312
312
  )
313
- tracer.add_agent_attributes_to_span(ctx["span"])
313
+ tracer._inject_judgment_context(ctx["span"])
314
314
  set_span_attribute(
315
315
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
316
316
  )
@@ -1,5 +1,14 @@
1
1
  from judgeval.trainer.trainer import JudgmentTrainer
2
2
  from judgeval.trainer.config import TrainerConfig, ModelConfig
3
3
  from judgeval.trainer.trainable_model import TrainableModel
4
+ from judgeval.trainer.base_trainer import BaseTrainer
5
+ from judgeval.trainer.fireworks_trainer import FireworksTrainer
4
6
 
5
- __all__ = ["JudgmentTrainer", "TrainerConfig", "ModelConfig", "TrainableModel"]
7
+ __all__ = [
8
+ "JudgmentTrainer",
9
+ "TrainerConfig",
10
+ "ModelConfig",
11
+ "TrainableModel",
12
+ "BaseTrainer",
13
+ "FireworksTrainer",
14
+ ]