judgeval 0.16.9__py3-none-any.whl → 0.22.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of judgeval may be problematic (the original page links to more details).

Files changed (37):
  1. judgeval/__init__.py +32 -2
  2. judgeval/api/__init__.py +108 -0
  3. judgeval/api/api_types.py +76 -15
  4. judgeval/cli.py +16 -1
  5. judgeval/data/judgment_types.py +76 -20
  6. judgeval/dataset/__init__.py +11 -2
  7. judgeval/env.py +2 -11
  8. judgeval/evaluation/__init__.py +4 -0
  9. judgeval/prompt/__init__.py +330 -0
  10. judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +1 -13
  11. judgeval/tracer/__init__.py +371 -257
  12. judgeval/tracer/constants.py +1 -1
  13. judgeval/tracer/exporters/store.py +32 -16
  14. judgeval/tracer/keys.py +11 -9
  15. judgeval/tracer/llm/llm_anthropic/messages.py +38 -26
  16. judgeval/tracer/llm/llm_anthropic/messages_stream.py +14 -14
  17. judgeval/tracer/llm/llm_google/generate_content.py +9 -7
  18. judgeval/tracer/llm/llm_openai/beta_chat_completions.py +38 -14
  19. judgeval/tracer/llm/llm_openai/chat_completions.py +90 -26
  20. judgeval/tracer/llm/llm_openai/responses.py +88 -26
  21. judgeval/tracer/llm/llm_openai/utils.py +42 -0
  22. judgeval/tracer/llm/llm_together/chat_completions.py +26 -18
  23. judgeval/tracer/managers.py +4 -0
  24. judgeval/trainer/__init__.py +10 -1
  25. judgeval/trainer/base_trainer.py +122 -0
  26. judgeval/trainer/config.py +1 -1
  27. judgeval/trainer/fireworks_trainer.py +396 -0
  28. judgeval/trainer/trainer.py +52 -387
  29. judgeval/utils/guards.py +9 -5
  30. judgeval/utils/project.py +15 -0
  31. judgeval/utils/serialize.py +2 -2
  32. judgeval/version.py +1 -1
  33. {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/METADATA +2 -3
  34. {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/RECORD +37 -32
  35. {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/WHEEL +0 -0
  36. {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/entry_points.txt +0 -0
  37. {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/licenses/LICENSE.md +0 -0
@@ -17,9 +17,12 @@ from typing import (
17
17
  overload,
18
18
  Literal,
19
19
  TypedDict,
20
- Iterator,
21
- AsyncIterator,
20
+ Generator,
21
+ AsyncGenerator,
22
+ Iterable,
22
23
  )
24
+ import contextvars
25
+ import asyncio
23
26
  from functools import partial
24
27
  from warnings import warn
25
28
 
@@ -71,6 +74,8 @@ from judgeval.tracer.processors import (
71
74
  NoOpJudgmentSpanProcessor,
72
75
  )
73
76
  from judgeval.tracer.utils import set_span_attribute, TraceScorerConfig
77
+ from judgeval.utils.project import _resolve_project_id
78
+ from opentelemetry.trace import use_span
74
79
 
75
80
  C = TypeVar("C", bound=Callable)
76
81
  Cls = TypeVar("Cls", bound=Type)
@@ -101,11 +106,12 @@ class Tracer(metaclass=SingletonMeta):
101
106
  "judgment_processor",
102
107
  "tracer",
103
108
  "agent_context",
109
+ "customer_id",
104
110
  "_initialized",
105
111
  )
106
112
 
107
- api_key: str
108
- organization_id: str
113
+ api_key: str | None
114
+ organization_id: str | None
109
115
  project_name: str
110
116
  enable_monitoring: bool
111
117
  enable_evaluation: bool
@@ -114,6 +120,7 @@ class Tracer(metaclass=SingletonMeta):
114
120
  judgment_processor: JudgmentSpanProcessor
115
121
  tracer: ABCTracer
116
122
  agent_context: ContextVar[Optional[AgentContext]]
123
+ customer_id: ContextVar[Optional[str]]
117
124
  _initialized: bool
118
125
 
119
126
  def __init__(
@@ -121,8 +128,8 @@ class Tracer(metaclass=SingletonMeta):
121
128
  /,
122
129
  *,
123
130
  project_name: str,
124
- api_key: Optional[str] = None,
125
- organization_id: Optional[str] = None,
131
+ api_key: str | None = None,
132
+ organization_id: str | None = None,
126
133
  enable_monitoring: bool = JUDGMENT_ENABLE_MONITORING.lower() == "true",
127
134
  enable_evaluation: bool = JUDGMENT_ENABLE_EVALUATIONS.lower() == "true",
128
135
  resource_attributes: Optional[Dict[str, Any]] = None,
@@ -131,6 +138,7 @@ class Tracer(metaclass=SingletonMeta):
131
138
  if not hasattr(self, "_initialized"):
132
139
  self._initialized = False
133
140
  self.agent_context = ContextVar("current_agent_context", default=None)
141
+ self.customer_id = ContextVar("current_customer_id", default=None)
134
142
 
135
143
  self.project_name = project_name
136
144
  self.api_key = expect_api_key(api_key or JUDGMENT_API_KEY)
@@ -141,10 +149,14 @@ class Tracer(metaclass=SingletonMeta):
141
149
  self.enable_evaluation = enable_evaluation
142
150
  self.resource_attributes = resource_attributes
143
151
 
144
- self.api_client = JudgmentSyncClient(
145
- api_key=self.api_key,
146
- organization_id=self.organization_id,
147
- )
152
+ if self.api_key and self.organization_id:
153
+ self.api_client = JudgmentSyncClient(
154
+ api_key=self.api_key, organization_id=self.organization_id
155
+ )
156
+ else:
157
+ judgeval_logger.error(
158
+ "API Key or Organization ID is not set. You must set them in the environment variables to use the tracer."
159
+ )
148
160
 
149
161
  if initialize:
150
162
  self.initialize()
@@ -155,10 +167,10 @@ class Tracer(metaclass=SingletonMeta):
155
167
 
156
168
  self.judgment_processor = NoOpJudgmentSpanProcessor()
157
169
  if self.enable_monitoring:
158
- project_id = Tracer._resolve_project_id(
170
+ project_id = _resolve_project_id(
159
171
  self.project_name, self.api_key, self.organization_id
160
172
  )
161
- if project_id:
173
+ if self.api_key and self.organization_id and project_id:
162
174
  self.judgment_processor = self.get_processor(
163
175
  tracer=self,
164
176
  project_name=self.project_name,
@@ -173,9 +185,10 @@ class Tracer(metaclass=SingletonMeta):
173
185
  provider.add_span_processor(self.judgment_processor)
174
186
  set_tracer_provider(provider)
175
187
  else:
176
- judgeval_logger.error(
177
- f"Failed to resolve or autocreate project {self.project_name}, please create it first at https://app.judgmentlabs.ai/org/{self.organization_id}/projects. Skipping Judgment export."
178
- )
188
+ if self.api_key and self.organization_id:
189
+ judgeval_logger.error(
190
+ f"Failed to resolve or autocreate project {self.project_name}, please create it first at https://app.judgmentlabs.ai/org/{self.organization_id}/projects. Skipping Judgment export."
191
+ )
179
192
 
180
193
  self.tracer = get_tracer_provider().get_tracer(
181
194
  JUDGEVAL_TRACER_INSTRUMENTING_MODULE_NAME,
@@ -194,10 +207,19 @@ class Tracer(metaclass=SingletonMeta):
194
207
  ):
195
208
  from judgeval.tracer.exporters import JudgmentSpanExporter
196
209
 
210
+ api_key = api_key or JUDGMENT_API_KEY
211
+ organization_id = organization_id or JUDGMENT_ORG_ID
212
+
213
+ if not api_key or not organization_id:
214
+ judgeval_logger.error(
215
+ "API Key or Organization ID is not set. You must set them in the environment variables to use the tracer."
216
+ )
217
+ return None
218
+
197
219
  return JudgmentSpanExporter(
198
220
  endpoint=url_for("/otel/v1/traces"),
199
- api_key=api_key or JUDGMENT_API_KEY,
200
- organization_id=organization_id or JUDGMENT_ORG_ID,
221
+ api_key=api_key,
222
+ organization_id=organization_id,
201
223
  project_id=project_id,
202
224
  )
203
225
 
@@ -213,31 +235,24 @@ class Tracer(metaclass=SingletonMeta):
213
235
  resource_attributes: Optional[Dict[str, Any]] = None,
214
236
  ) -> JudgmentSpanProcessor:
215
237
  """Create a JudgmentSpanProcessor using the correct constructor."""
238
+ api_key = api_key or JUDGMENT_API_KEY
239
+ organization_id = organization_id or JUDGMENT_ORG_ID
240
+ if not api_key or not organization_id:
241
+ judgeval_logger.error(
242
+ "API Key or Organization ID is not set. You must set them in the environment variables to use the tracer."
243
+ )
244
+ return NoOpJudgmentSpanProcessor()
216
245
  return JudgmentSpanProcessor(
217
246
  tracer,
218
247
  project_name,
219
248
  project_id,
220
- api_key or JUDGMENT_API_KEY,
221
- organization_id or JUDGMENT_ORG_ID,
249
+ api_key,
250
+ organization_id,
222
251
  max_queue_size=max_queue_size,
223
252
  export_timeout_millis=export_timeout_millis,
224
253
  resource_attributes=resource_attributes,
225
254
  )
226
255
 
227
- @dont_throw
228
- @functools.lru_cache(maxsize=64)
229
- @staticmethod
230
- def _resolve_project_id(
231
- project_name: str, api_key: str, organization_id: str
232
- ) -> str:
233
- """Resolve project_id from project_name using the API."""
234
- client = JudgmentSyncClient(
235
- api_key=api_key,
236
- organization_id=organization_id,
237
- )
238
- response = client.projects_resolve({"project_name": project_name})
239
- return response["project_id"]
240
-
241
256
  def get_current_span(self):
242
257
  return get_current_span()
243
258
 
@@ -247,17 +262,51 @@ class Tracer(metaclass=SingletonMeta):
247
262
  def get_current_agent_context(self):
248
263
  return self.agent_context
249
264
 
265
def get_current_customer_context(self):
    """Return the ContextVar that holds the active customer ID (``None`` when unset)."""
    customer_var = self.customer_id
    return customer_var
267
+
250
268
  def get_span_processor(self) -> JudgmentSpanProcessor:
251
269
  """Get the internal span processor of this tracer instance."""
252
270
  return self.judgment_processor
253
271
 
272
@dont_throw
def set_customer_id(self, customer_id: str) -> None:
    """Attach ``customer_id`` to the current span and publish it to later spans.

    The ID is written onto the current span and stored in the tracer's
    customer ``ContextVar``; ``_add_customer_id_to_span`` copies that value
    onto spans traced afterwards in the same context. The current span is
    also flagged as the owner of the customer context, so that
    ``_maybe_clear_customer_context(span)`` resets it for that span.

    Only a warning is logged (and nothing is set) when ``customer_id`` is
    empty, when there is no recording span, or when a customer ID is
    already present in the current context.

    Args:
        customer_id: Identifier of the customer to associate with the trace.
    """
    if not customer_id:
        judgeval_logger.warning("Customer ID is empty, skipping.")
        return

    span = self.get_current_span()
    if not span or not span.is_recording():
        judgeval_logger.warning(
            "No active span found. Customer ID will not be set."
        )
        return

    customer_context = self.get_current_customer_context()
    if customer_context.get():
        # First writer wins: do not overwrite an ID set by an enclosing span.
        judgeval_logger.warning("Customer ID is already set, skipping.")
        return

    # The guard above already guarantees a recording span, so the former
    # redundant ``if span and span.is_recording():`` re-check is gone.
    set_span_attribute(span, AttributeKeys.JUDGMENT_CUSTOMER_ID, customer_id)
    customer_context.set(customer_id)

    # Mark this span as the owner so the context is cleared when it finishes.
    self.get_span_processor().set_internal_attribute(
        span_context=span.get_span_context(),
        key=InternalAttributeKeys.IS_CUSTOMER_CONTEXT_OWNER,
        value=True,
    )
299
+
300
def _maybe_clear_customer_context(self, span: Span) -> None:
    """Reset the customer ContextVar to ``None`` if ``span`` owns it.

    Ownership is the ``IS_CUSTOMER_CONTEXT_OWNER`` internal attribute stored
    on the span processor for this span's context (defaults to ``False``).
    """
    owns_customer_context = self.get_span_processor().get_internal_attribute(
        span_context=span.get_span_context(),
        key=InternalAttributeKeys.IS_CUSTOMER_CONTEXT_OWNER,
        default=False,
    )
    if not owns_customer_context:
        return
    self.get_current_customer_context().set(None)
258
307
 
259
308
  @dont_throw
260
- def add_agent_attributes_to_span(self, span):
309
+ def _add_agent_attributes_to_span(self, span):
261
310
  """Add agent ID, class name, and instance name to span if they exist in context"""
262
311
  current_agent_context = self.agent_context.get()
263
312
  if not current_agent_context:
@@ -289,7 +338,7 @@ class Tracer(metaclass=SingletonMeta):
289
338
  current_agent_context["is_agent_entry_point"] = False
290
339
 
291
340
  @dont_throw
292
- def record_instance_state(self, record_point: Literal["before", "after"], span):
341
+ def _record_instance_state(self, record_point: Literal["before", "after"], span):
293
342
  current_agent_context = self.agent_context.get()
294
343
 
295
344
  if current_agent_context and current_agent_context.get("track_state"):
@@ -318,6 +367,17 @@ class Tracer(metaclass=SingletonMeta):
318
367
  safe_serialize(attributes),
319
368
  )
320
369
 
370
@dont_throw
def _add_customer_id_to_span(self, span):
    """Copy the customer ID from the current context onto ``span``, if one is set."""
    active_customer = self.get_current_customer_context().get()
    if not active_customer:
        return
    set_span_attribute(span, AttributeKeys.JUDGMENT_CUSTOMER_ID, active_customer)
375
+
376
@dont_throw
def _inject_judgment_context(self, span):
    """Stamp ``span`` with agent attributes first, then the customer ID."""
    for inject in (
        self._add_agent_attributes_to_span,
        self._add_customer_id_to_span,
    ):
        inject(span)
380
+
321
381
  def _set_pending_trace_eval(
322
382
  self,
323
383
  span: Span,
@@ -381,91 +441,37 @@ class Tracer(metaclass=SingletonMeta):
381
441
 
382
442
  def _create_traced_sync_generator(
383
443
  self,
384
- generator: Iterator[Any],
444
+ generator: Generator,
385
445
  main_span: Span,
386
- base_name: str,
387
- attributes: Optional[Dict[str, Any]],
446
+ disable_generator_yield_span: bool = False,
388
447
  ):
389
448
  """Create a traced synchronous generator that wraps each yield in a span."""
390
- try:
391
- while True:
392
- yield_span_name = f"{base_name}_yield"
393
- yield_attributes = {
394
- AttributeKeys.JUDGMENT_SPAN_KIND: "generator_yield",
395
- **(attributes or {}),
396
- }
397
-
398
- with sync_span_context(
399
- self, yield_span_name, yield_attributes, disable_partial_emit=True
400
- ) as yield_span:
401
- self.add_agent_attributes_to_span(yield_span)
402
-
403
- try:
404
- value = next(generator)
405
- except StopIteration:
406
- # Mark span as cancelled so it won't be exported
407
- self.judgment_processor.set_internal_attribute(
408
- span_context=yield_span.get_span_context(),
409
- key=InternalAttributeKeys.CANCELLED,
410
- value=True,
411
- )
412
- break
413
-
414
- set_span_attribute(
415
- yield_span,
416
- AttributeKeys.JUDGMENT_OUTPUT,
417
- safe_serialize(value),
418
- )
419
-
420
- yield value
421
- except Exception as e:
422
- main_span.record_exception(e)
423
- main_span.set_status(Status(StatusCode.ERROR, str(e)))
424
- raise
449
+ preserved_context = contextvars.copy_context()
450
+ return _ContextPreservedSyncGeneratorWrapper(
451
+ self,
452
+ generator,
453
+ preserved_context,
454
+ main_span,
455
+ None,
456
+ disable_generator_yield_span,
457
+ )
425
458
 
426
- async def _create_traced_async_generator(
459
+ def _create_traced_async_generator(
427
460
  self,
428
- async_generator: AsyncIterator[Any],
461
+ async_generator: AsyncGenerator,
429
462
  main_span: Span,
430
- base_name: str,
431
- attributes: Optional[Dict[str, Any]],
463
+ disable_generator_yield_span: bool = False,
432
464
  ):
433
465
  """Create a traced asynchronous generator that wraps each yield in a span."""
434
- try:
435
- while True:
436
- yield_span_name = f"{base_name}_yield"
437
- yield_attributes = {
438
- AttributeKeys.JUDGMENT_SPAN_KIND: "async_generator_yield",
439
- **(attributes or {}),
440
- }
441
-
442
- async with async_span_context(
443
- self, yield_span_name, yield_attributes, disable_partial_emit=True
444
- ) as yield_span:
445
- self.add_agent_attributes_to_span(yield_span)
446
-
447
- try:
448
- value = await async_generator.__anext__()
449
- except StopAsyncIteration:
450
- # Mark span as cancelled so it won't be exported
451
- self.judgment_processor.set_internal_attribute(
452
- span_context=yield_span.get_span_context(),
453
- key=InternalAttributeKeys.CANCELLED,
454
- value=True,
455
- )
456
- break
457
-
458
- set_span_attribute(
459
- yield_span,
460
- AttributeKeys.JUDGMENT_OUTPUT,
461
- safe_serialize(value),
462
- )
463
-
464
- yield value
465
- except Exception as e:
466
- main_span.record_exception(e)
467
- main_span.set_status(Status(StatusCode.ERROR, str(e)))
468
- raise
466
+ preserved_context = contextvars.copy_context()
467
+ return _ContextPreservedAsyncGeneratorWrapper(
468
+ self,
469
+ async_generator,
470
+ preserved_context,
471
+ main_span,
472
+ None,
473
+ disable_generator_yield_span,
474
+ )
469
475
 
470
476
  def _wrap_sync(
471
477
  self,
@@ -473,19 +479,16 @@ class Tracer(metaclass=SingletonMeta):
473
479
  name: Optional[str],
474
480
  attributes: Optional[Dict[str, Any]],
475
481
  scorer_config: TraceScorerConfig | None = None,
482
+ disable_generator_yield_span: bool = False,
476
483
  ):
477
- # Check if this is a generator function - if so, wrap it specially
478
- if inspect.isgeneratorfunction(f):
479
- return self._wrap_sync_generator_function(
480
- f, name, attributes, scorer_config
481
- )
482
-
483
484
  @functools.wraps(f)
484
485
  def wrapper(*args, **kwargs):
485
486
  n = name or f.__qualname__
486
487
  with sync_span_context(self, n, attributes) as span:
487
- self.add_agent_attributes_to_span(span)
488
- self.record_instance_state("before", span)
488
+ is_return_type_generator = False
489
+
490
+ self._inject_judgment_context(span)
491
+ self._record_instance_state("before", span)
489
492
  try:
490
493
  set_span_attribute(
491
494
  span,
@@ -499,73 +502,40 @@ class Tracer(metaclass=SingletonMeta):
499
502
  self._set_pending_trace_eval(span, scorer_config, args, kwargs)
500
503
 
501
504
  result = f(*args, **kwargs)
502
- except Exception as user_exc:
503
- span.record_exception(user_exc)
504
- span.set_status(Status(StatusCode.ERROR, str(user_exc)))
505
- raise
506
-
507
- if inspect.isgenerator(result):
508
- set_span_attribute(
509
- span, AttributeKeys.JUDGMENT_OUTPUT, "<generator>"
510
- )
511
- self.record_instance_state("after", span)
512
- return self._create_traced_sync_generator(
513
- result, span, n, attributes
514
- )
515
- else:
516
- set_span_attribute(
517
- span, AttributeKeys.JUDGMENT_OUTPUT, safe_serialize(result)
518
- )
519
- self.record_instance_state("after", span)
520
- return result
521
-
522
- return wrapper
523
-
524
- def _wrap_sync_generator_function(
525
- self,
526
- f: Callable,
527
- name: Optional[str],
528
- attributes: Optional[Dict[str, Any]],
529
- scorer_config: TraceScorerConfig | None = None,
530
- ):
531
- """Wrap a generator function to trace nested function calls within each yield."""
532
-
533
- @functools.wraps(f)
534
- def wrapper(*args, **kwargs):
535
- n = name or f.__qualname__
536
505
 
537
- with sync_span_context(self, n, attributes) as main_span:
538
- self.add_agent_attributes_to_span(main_span)
539
- self.record_instance_state("before", main_span)
540
-
541
- try:
542
- set_span_attribute(
543
- main_span,
544
- AttributeKeys.JUDGMENT_INPUT,
545
- safe_serialize(format_inputs(f, args, kwargs)),
546
- )
547
-
548
- self.judgment_processor.emit_partial()
549
-
550
- if scorer_config:
551
- self._set_pending_trace_eval(
552
- main_span, scorer_config, args, kwargs
506
+ if inspect.isgenerator(result):
507
+ is_return_type_generator = True
508
+ set_span_attribute(
509
+ span, AttributeKeys.JUDGMENT_OUTPUT, "<generator>"
553
510
  )
554
-
555
- generator = f(*args, **kwargs)
556
- set_span_attribute(
557
- main_span, AttributeKeys.JUDGMENT_OUTPUT, "<generator>"
558
- )
559
- self.record_instance_state("after", main_span)
560
-
561
- return self._create_traced_sync_generator(
562
- generator, main_span, n, attributes
563
- )
564
-
511
+ self._record_instance_state("after", span)
512
+ return self._create_traced_sync_generator(
513
+ result, span, disable_generator_yield_span
514
+ )
515
+ elif inspect.isasyncgen(result):
516
+ is_return_type_generator = True
517
+ set_span_attribute(
518
+ span, AttributeKeys.JUDGMENT_OUTPUT, "<async_generator>"
519
+ )
520
+ self._record_instance_state("after", span)
521
+ return self._create_traced_async_generator(
522
+ result, span, disable_generator_yield_span
523
+ )
524
+ else:
525
+ set_span_attribute(
526
+ span, AttributeKeys.JUDGMENT_OUTPUT, safe_serialize(result)
527
+ )
528
+ self._record_instance_state("after", span)
529
+ self._maybe_clear_customer_context(span)
530
+ return result
565
531
  except Exception as user_exc:
566
- main_span.record_exception(user_exc)
567
- main_span.set_status(Status(StatusCode.ERROR, str(user_exc)))
532
+ span.record_exception(user_exc)
533
+ span.set_status(Status(StatusCode.ERROR, str(user_exc)))
534
+ self._maybe_clear_customer_context(span)
568
535
  raise
536
+ finally:
537
+ if not is_return_type_generator:
538
+ span.end()
569
539
 
570
540
  return wrapper
571
541
 
@@ -575,19 +545,15 @@ class Tracer(metaclass=SingletonMeta):
575
545
  name: Optional[str],
576
546
  attributes: Optional[Dict[str, Any]],
577
547
  scorer_config: TraceScorerConfig | None = None,
548
+ disable_generator_yield_span: bool = False,
578
549
  ):
579
- # Check if this is an async generator function - if so, wrap it specially
580
- if inspect.isasyncgenfunction(f):
581
- return self._wrap_async_generator_function(
582
- f, name, attributes, scorer_config
583
- )
584
-
585
550
  @functools.wraps(f)
586
551
  async def wrapper(*args, **kwargs):
587
552
  n = name or f.__qualname__
588
553
  async with async_span_context(self, n, attributes) as span:
589
- self.add_agent_attributes_to_span(span)
590
- self.record_instance_state("before", span)
554
+ is_return_type_generator = False
555
+ self._inject_judgment_context(span)
556
+ self._record_instance_state("before", span)
591
557
  try:
592
558
  set_span_attribute(
593
559
  span,
@@ -601,73 +567,39 @@ class Tracer(metaclass=SingletonMeta):
601
567
  self._set_pending_trace_eval(span, scorer_config, args, kwargs)
602
568
 
603
569
  result = await f(*args, **kwargs)
570
+ if inspect.isasyncgen(result):
571
+ is_return_type_generator = True
572
+ set_span_attribute(
573
+ span, AttributeKeys.JUDGMENT_OUTPUT, "<async_generator>"
574
+ )
575
+ self._record_instance_state("after", span)
576
+ return self._create_traced_async_generator(
577
+ result, span, disable_generator_yield_span
578
+ )
579
+ elif inspect.isgenerator(result):
580
+ is_return_type_generator = True
581
+ set_span_attribute(
582
+ span, AttributeKeys.JUDGMENT_OUTPUT, "<generator>"
583
+ )
584
+ self._record_instance_state("after", span)
585
+ return self._create_traced_sync_generator(
586
+ result, span, disable_generator_yield_span
587
+ )
588
+ else:
589
+ set_span_attribute(
590
+ span, AttributeKeys.JUDGMENT_OUTPUT, safe_serialize(result)
591
+ )
592
+ self._record_instance_state("after", span)
593
+ self._maybe_clear_customer_context(span)
594
+ return result
604
595
  except Exception as user_exc:
605
596
  span.record_exception(user_exc)
606
597
  span.set_status(Status(StatusCode.ERROR, str(user_exc)))
598
+ self._maybe_clear_customer_context(span)
607
599
  raise
608
-
609
- if inspect.isasyncgen(result):
610
- set_span_attribute(
611
- span, AttributeKeys.JUDGMENT_OUTPUT, "<async_generator>"
612
- )
613
- self.record_instance_state("after", span)
614
- return self._create_traced_async_generator(
615
- result, span, n, attributes
616
- )
617
- else:
618
- set_span_attribute(
619
- span, AttributeKeys.JUDGMENT_OUTPUT, safe_serialize(result)
620
- )
621
- self.record_instance_state("after", span)
622
- return result
623
-
624
- return wrapper
625
-
626
- def _wrap_async_generator_function(
627
- self,
628
- f: Callable,
629
- name: Optional[str],
630
- attributes: Optional[Dict[str, Any]],
631
- scorer_config: TraceScorerConfig | None = None,
632
- ):
633
- """Wrap an async generator function to trace nested function calls within each yield."""
634
-
635
- @functools.wraps(f)
636
- def wrapper(*args, **kwargs):
637
- n = name or f.__qualname__
638
-
639
- with sync_span_context(self, n, attributes) as main_span:
640
- self.add_agent_attributes_to_span(main_span)
641
- self.record_instance_state("before", main_span)
642
-
643
- try:
644
- set_span_attribute(
645
- main_span,
646
- AttributeKeys.JUDGMENT_INPUT,
647
- safe_serialize(format_inputs(f, args, kwargs)),
648
- )
649
-
650
- self.judgment_processor.emit_partial()
651
-
652
- if scorer_config:
653
- self._set_pending_trace_eval(
654
- main_span, scorer_config, args, kwargs
655
- )
656
-
657
- async_generator = f(*args, **kwargs)
658
- set_span_attribute(
659
- main_span, AttributeKeys.JUDGMENT_OUTPUT, "<async_generator>"
660
- )
661
- self.record_instance_state("after", main_span)
662
-
663
- return self._create_traced_async_generator(
664
- async_generator, main_span, n, attributes
665
- )
666
-
667
- except Exception as user_exc:
668
- main_span.record_exception(user_exc)
669
- main_span.set_status(Status(StatusCode.ERROR, str(user_exc)))
670
- raise
600
+ finally:
601
+ if not is_return_type_generator:
602
+ span.end()
671
603
 
672
604
  return wrapper
673
605
 
@@ -704,6 +636,7 @@ class Tracer(metaclass=SingletonMeta):
704
636
  span_name: str | None = None,
705
637
  attributes: Optional[Dict[str, Any]] = None,
706
638
  scorer_config: TraceScorerConfig | None = None,
639
+ disable_generator_yield_span: bool = False,
707
640
  ) -> Callable | None:
708
641
  if func is None:
709
642
  return partial(
@@ -712,6 +645,7 @@ class Tracer(metaclass=SingletonMeta):
712
645
  span_name=span_name,
713
646
  attributes=attributes,
714
647
  scorer_config=scorer_config,
648
+ disable_generator_yield_span=disable_generator_yield_span,
715
649
  )
716
650
 
717
651
  if not self.enable_monitoring:
@@ -724,10 +658,14 @@ class Tracer(metaclass=SingletonMeta):
724
658
  **(attributes or {}),
725
659
  }
726
660
 
727
- if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func):
728
- return self._wrap_async(func, name, func_attributes, scorer_config)
661
+ if inspect.iscoroutinefunction(func):
662
+ return self._wrap_async(
663
+ func, name, func_attributes, scorer_config, disable_generator_yield_span
664
+ )
729
665
  else:
730
- return self._wrap_sync(func, name, func_attributes, scorer_config)
666
+ return self._wrap_sync(
667
+ func, name, func_attributes, scorer_config, disable_generator_yield_span
668
+ )
731
669
 
732
670
  @overload
733
671
  def agent(
@@ -1003,6 +941,182 @@ def format_inputs(
1003
941
  return {}
1004
942
 
1005
943
 
944
class _ContextPreservedSyncGeneratorWrapper:
    """Proxy a traced sync generator, advancing it inside a preserved context.

    Every step of the wrapped generator runs through ``context.run`` so that
    the ``ContextVar`` values of the supplied ``context`` are visible inside
    the generator body. Unless ``disable_generator_yield_span`` is set, each
    yielded item is also recorded as a ``generator_item`` child span of
    ``span``. ``span`` is ended at most once: when the generator is
    exhausted, when an exception escapes it, or on ``close()``. On every one
    of those paths the tracer's customer context is cleared if ``span`` owns it.
    """

    def __init__(
        self,
        tracer: Tracer,
        generator: Generator,
        context: contextvars.Context,
        span: Span,
        transform_fn: Optional[Callable[[Iterable], str]],
        disable_generator_yield_span: bool = False,
    ) -> None:
        self.tracer = tracer
        self.generator = generator
        self.context = context
        self.span = span
        # Kept for interface compatibility; this wrapper does not apply it.
        self.transform_fn = transform_fn
        self._finished = False
        self.disable_generator_yield_span = disable_generator_yield_span

    def __iter__(self) -> "_ContextPreservedSyncGeneratorWrapper":
        return self

    def __next__(self) -> Any:
        return self._advance(next, self.generator)

    def send(self, value: Any) -> Any:
        """Send ``value`` into the wrapped generator (generator protocol)."""
        return self._advance(self.generator.send, value)

    def throw(self, *args: Any) -> Any:
        """Raise an exception inside the wrapped generator (generator protocol)."""
        return self._advance(self.generator.throw, *args)

    def close(self) -> None:
        """Close the wrapped generator and end the span if it is still open."""
        try:
            self.generator.close()
        finally:
            self._finish()

    def _advance(self, step: Callable[..., Any], *args: Any) -> Any:
        """Run one generator step in the preserved context and trace its result."""
        try:
            item = self.context.run(step, *args)
            if not self.disable_generator_yield_span:
                self._record_item(item)
            return item
        except StopIteration:
            # Normal exhaustion: finish the span and propagate termination.
            self._finish()
            raise
        except Exception as e:
            self._finish(e)
            raise

    def _record_item(self, item: Any) -> None:
        """Record ``item`` as a ``generator_item`` child span of ``self.span``."""
        with use_span(self.span):
            span_name = (
                str(self.span.name)  # type: ignore[attr-defined]
                if hasattr(self.span, "name")
                else "generator_item"
            )
            with self.tracer.get_tracer().start_as_current_span(
                span_name,
                attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "generator_item"},
                end_on_exit=True,
            ) as child_span:
                set_span_attribute(
                    child_span,
                    AttributeKeys.JUDGMENT_OUTPUT,
                    safe_serialize(item),
                )

    def _finish(self, error: Optional[BaseException] = None) -> None:
        """End ``self.span`` once, recording ``error`` if the generator failed.

        Customer context owned by the span is cleared on every path,
        including normal exhaustion. Previously, normal exhaustion skipped
        this step and leaked the customer ID.
        """
        if self._finished:
            return
        self._finished = True
        if error is None:
            set_span_attribute(
                self.span, AttributeKeys.JUDGMENT_SPAN_KIND, "generator"
            )
        else:
            self.span.record_exception(error)
            self.span.set_status(
                Status(StatusCode.ERROR, str(error) or type(error).__name__)
            )
        self.tracer._maybe_clear_customer_context(self.span)
        self.span.end()
1027
+
1028
+
1029
class _ContextPreservedAsyncGeneratorWrapper:
    """Proxy a traced async generator, advancing it inside a preserved context.

    On Python 3.11+ each step of the wrapped async generator is scheduled as
    a task bound to ``context`` (``asyncio.create_task(..., context=...)``),
    so the ``ContextVar`` values of that context are visible inside the
    generator body. On older versions ``create_task`` has no ``context``
    parameter, so each step is awaited directly in the caller's context.

    Unless ``disable_generator_yield_span`` is set, each yielded item is
    recorded as a ``generator_item`` child span of ``span``. ``span`` is ended
    at most once: when the generator is exhausted, when an exception escapes
    it, or on ``aclose()``. On every one of those paths the tracer's customer
    context is cleared if ``span`` owns it.
    """

    def __init__(
        self,
        tracer: Tracer,
        generator: AsyncGenerator,
        context: contextvars.Context,
        span: Span,
        transform_fn: Optional[Callable[[Iterable], str]],
        disable_generator_yield_span: bool = False,
    ) -> None:
        self.tracer = tracer
        self.generator = generator
        self.context = context
        self.span = span
        # Kept for interface compatibility; this wrapper does not apply it.
        self.transform_fn = transform_fn
        self._finished = False
        self.disable_generator_yield_span = disable_generator_yield_span

    def __aiter__(self) -> "_ContextPreservedAsyncGeneratorWrapper":
        return self

    async def __anext__(self) -> Any:
        return await self._advance(self.generator.__anext__())

    async def asend(self, value: Any) -> Any:
        """Send ``value`` into the wrapped async generator (async-gen protocol)."""
        return await self._advance(self.generator.asend(value))

    async def athrow(self, *args: Any) -> Any:
        """Raise an exception inside the wrapped async generator."""
        return await self._advance(self.generator.athrow(*args))

    async def aclose(self) -> None:
        """Close the wrapped async generator and end the span if still open."""
        try:
            await self.generator.aclose()
        finally:
            self._finish()

    async def _advance(self, step: Any) -> Any:
        """Await one generator step in the preserved context and trace its result."""
        try:
            item = await self._await_in_context(step)
            if not self.disable_generator_yield_span:
                self._record_item(item)
            return item
        except StopAsyncIteration:
            # Normal exhaustion: finish the span and propagate termination.
            self._finish()
            raise
        except Exception as e:
            self._finish(e)
            raise

    async def _await_in_context(self, step: Any) -> Any:
        """Await ``step`` inside ``self.context`` when the runtime supports it.

        Only the ``create_task`` call is guarded, never the ``await``. A
        ``TypeError`` raised by the generator body therefore propagates to
        the caller; it must not be mistaken for "no ``context`` support",
        which would drop the error and advance the generator a second time.
        The same ``step`` awaitable is reused in the fallback, so nothing is
        created and then abandoned.
        """
        try:
            task = asyncio.create_task(step, context=self.context)  # type: ignore[call-arg]
        except TypeError:
            # create_task() accepts ``context`` only on Python 3.11+.
            return await step
        return await task

    def _record_item(self, item: Any) -> None:
        """Record ``item`` as a ``generator_item`` child span of ``self.span``."""
        with use_span(self.span):
            span_name = (
                str(self.span.name)  # type: ignore[attr-defined]
                if hasattr(self.span, "name")
                else "generator_item"
            )
            with self.tracer.get_tracer().start_as_current_span(
                span_name,
                attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "generator_item"},
                end_on_exit=True,
            ) as child_span:
                set_span_attribute(
                    child_span,
                    AttributeKeys.JUDGMENT_OUTPUT,
                    safe_serialize(item),
                )

    def _finish(self, error: Optional[BaseException] = None) -> None:
        """End ``self.span`` once, recording ``error`` if the generator failed.

        Customer context owned by the span is cleared on every path,
        including normal exhaustion.
        """
        if self._finished:
            return
        self._finished = True
        if error is None:
            set_span_attribute(
                self.span, AttributeKeys.JUDGMENT_SPAN_KIND, "generator"
            )
        else:
            self.span.record_exception(error)
            self.span.set_status(
                Status(StatusCode.ERROR, str(error) or type(error).__name__)
            )
        self.tracer._maybe_clear_customer_context(self.span)
        self.span.end()
1118
+
1119
+
1006
1120
  __all__ = [
1007
1121
  "Tracer",
1008
1122
  "wrap",