judgeval 0.16.9__py3-none-any.whl → 0.22.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of judgeval might be problematic. See the package registry's advisory page for more details.

Files changed (37)
  1. judgeval/__init__.py +32 -2
  2. judgeval/api/__init__.py +108 -0
  3. judgeval/api/api_types.py +76 -15
  4. judgeval/cli.py +16 -1
  5. judgeval/data/judgment_types.py +76 -20
  6. judgeval/dataset/__init__.py +11 -2
  7. judgeval/env.py +2 -11
  8. judgeval/evaluation/__init__.py +4 -0
  9. judgeval/prompt/__init__.py +330 -0
  10. judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +1 -13
  11. judgeval/tracer/__init__.py +371 -257
  12. judgeval/tracer/constants.py +1 -1
  13. judgeval/tracer/exporters/store.py +32 -16
  14. judgeval/tracer/keys.py +11 -9
  15. judgeval/tracer/llm/llm_anthropic/messages.py +38 -26
  16. judgeval/tracer/llm/llm_anthropic/messages_stream.py +14 -14
  17. judgeval/tracer/llm/llm_google/generate_content.py +9 -7
  18. judgeval/tracer/llm/llm_openai/beta_chat_completions.py +38 -14
  19. judgeval/tracer/llm/llm_openai/chat_completions.py +90 -26
  20. judgeval/tracer/llm/llm_openai/responses.py +88 -26
  21. judgeval/tracer/llm/llm_openai/utils.py +42 -0
  22. judgeval/tracer/llm/llm_together/chat_completions.py +26 -18
  23. judgeval/tracer/managers.py +4 -0
  24. judgeval/trainer/__init__.py +10 -1
  25. judgeval/trainer/base_trainer.py +122 -0
  26. judgeval/trainer/config.py +1 -1
  27. judgeval/trainer/fireworks_trainer.py +396 -0
  28. judgeval/trainer/trainer.py +52 -387
  29. judgeval/utils/guards.py +9 -5
  30. judgeval/utils/project.py +15 -0
  31. judgeval/utils/serialize.py +2 -2
  32. judgeval/version.py +1 -1
  33. {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/METADATA +2 -3
  34. {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/RECORD +37 -32
  35. {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/WHEEL +0 -0
  36. {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/entry_points.txt +0 -0
  37. {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/licenses/LICENSE.md +0 -0
@@ -63,7 +63,7 @@ def _wrap_non_streaming_sync(
63
63
  ctx["span"] = tracer.get_tracer().start_span(
64
64
  "TOGETHER_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
65
65
  )
66
- tracer.add_agent_attributes_to_span(ctx["span"])
66
+ tracer._inject_judgment_context(ctx["span"])
67
67
  set_span_attribute(
68
68
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
69
69
  )
@@ -73,7 +73,7 @@ def _wrap_non_streaming_sync(
73
73
  )
74
74
  ctx["model_name"] = prefixed_model_name
75
75
  set_span_attribute(
76
- ctx["span"], AttributeKeys.GEN_AI_REQUEST_MODEL, prefixed_model_name
76
+ ctx["span"], AttributeKeys.JUDGMENT_LLM_MODEL_NAME, prefixed_model_name
77
77
  )
78
78
 
79
79
  def post_hook(ctx: Dict[str, Any], result: ChatCompletionResponse) -> None:
@@ -90,10 +90,12 @@ def _wrap_non_streaming_sync(
90
90
  result.usage
91
91
  )
92
92
  set_span_attribute(
93
- span, AttributeKeys.GEN_AI_USAGE_INPUT_TOKENS, prompt_tokens
93
+ span,
94
+ AttributeKeys.JUDGMENT_USAGE_NON_CACHED_INPUT_TOKENS,
95
+ prompt_tokens,
94
96
  )
95
97
  set_span_attribute(
96
- span, AttributeKeys.GEN_AI_USAGE_OUTPUT_TOKENS, completion_tokens
98
+ span, AttributeKeys.JUDGMENT_USAGE_OUTPUT_TOKENS, completion_tokens
97
99
  )
98
100
  set_span_attribute(
99
101
  span,
@@ -103,7 +105,7 @@ def _wrap_non_streaming_sync(
103
105
 
104
106
  set_span_attribute(
105
107
  span,
106
- AttributeKeys.GEN_AI_RESPONSE_MODEL,
108
+ AttributeKeys.JUDGMENT_LLM_MODEL_NAME,
107
109
  ctx["model_name"],
108
110
  )
109
111
 
@@ -133,7 +135,7 @@ def _wrap_streaming_sync(
133
135
  ctx["span"] = tracer.get_tracer().start_span(
134
136
  "TOGETHER_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
135
137
  )
136
- tracer.add_agent_attributes_to_span(ctx["span"])
138
+ tracer._inject_judgment_context(ctx["span"])
137
139
  set_span_attribute(
138
140
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
139
141
  )
@@ -143,7 +145,7 @@ def _wrap_streaming_sync(
143
145
  )
144
146
  ctx["model_name"] = prefixed_model_name
145
147
  set_span_attribute(
146
- ctx["span"], AttributeKeys.GEN_AI_REQUEST_MODEL, prefixed_model_name
148
+ ctx["span"], AttributeKeys.JUDGMENT_LLM_MODEL_NAME, prefixed_model_name
147
149
  )
148
150
  ctx["accumulated_content"] = ""
149
151
 
@@ -171,10 +173,12 @@ def _wrap_streaming_sync(
171
173
  chunk.usage
172
174
  )
173
175
  set_span_attribute(
174
- span, AttributeKeys.GEN_AI_USAGE_INPUT_TOKENS, prompt_tokens
176
+ span,
177
+ AttributeKeys.JUDGMENT_USAGE_NON_CACHED_INPUT_TOKENS,
178
+ prompt_tokens,
175
179
  )
176
180
  set_span_attribute(
177
- span, AttributeKeys.GEN_AI_USAGE_OUTPUT_TOKENS, completion_tokens
181
+ span, AttributeKeys.JUDGMENT_USAGE_OUTPUT_TOKENS, completion_tokens
178
182
  )
179
183
  set_span_attribute(
180
184
  span,
@@ -239,7 +243,7 @@ def _wrap_non_streaming_async(
239
243
  ctx["span"] = tracer.get_tracer().start_span(
240
244
  "TOGETHER_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
241
245
  )
242
- tracer.add_agent_attributes_to_span(ctx["span"])
246
+ tracer._inject_judgment_context(ctx["span"])
243
247
  set_span_attribute(
244
248
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
245
249
  )
@@ -249,7 +253,7 @@ def _wrap_non_streaming_async(
249
253
  )
250
254
  ctx["model_name"] = prefixed_model_name
251
255
  set_span_attribute(
252
- ctx["span"], AttributeKeys.GEN_AI_REQUEST_MODEL, prefixed_model_name
256
+ ctx["span"], AttributeKeys.JUDGMENT_LLM_MODEL_NAME, prefixed_model_name
253
257
  )
254
258
 
255
259
  def post_hook(ctx: Dict[str, Any], result: ChatCompletionResponse) -> None:
@@ -266,10 +270,12 @@ def _wrap_non_streaming_async(
266
270
  result.usage
267
271
  )
268
272
  set_span_attribute(
269
- span, AttributeKeys.GEN_AI_USAGE_INPUT_TOKENS, prompt_tokens
273
+ span,
274
+ AttributeKeys.JUDGMENT_USAGE_NON_CACHED_INPUT_TOKENS,
275
+ prompt_tokens,
270
276
  )
271
277
  set_span_attribute(
272
- span, AttributeKeys.GEN_AI_USAGE_OUTPUT_TOKENS, completion_tokens
278
+ span, AttributeKeys.JUDGMENT_USAGE_OUTPUT_TOKENS, completion_tokens
273
279
  )
274
280
  set_span_attribute(
275
281
  span,
@@ -279,7 +285,7 @@ def _wrap_non_streaming_async(
279
285
 
280
286
  set_span_attribute(
281
287
  span,
282
- AttributeKeys.GEN_AI_RESPONSE_MODEL,
288
+ AttributeKeys.JUDGMENT_LLM_MODEL_NAME,
283
289
  ctx["model_name"],
284
290
  )
285
291
 
@@ -310,7 +316,7 @@ def _wrap_streaming_async(
310
316
  ctx["span"] = tracer.get_tracer().start_span(
311
317
  "TOGETHER_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
312
318
  )
313
- tracer.add_agent_attributes_to_span(ctx["span"])
319
+ tracer._inject_judgment_context(ctx["span"])
314
320
  set_span_attribute(
315
321
  ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
316
322
  )
@@ -320,7 +326,7 @@ def _wrap_streaming_async(
320
326
  )
321
327
  ctx["model_name"] = prefixed_model_name
322
328
  set_span_attribute(
323
- ctx["span"], AttributeKeys.GEN_AI_REQUEST_MODEL, prefixed_model_name
329
+ ctx["span"], AttributeKeys.JUDGMENT_LLM_MODEL_NAME, prefixed_model_name
324
330
  )
325
331
  ctx["accumulated_content"] = ""
326
332
 
@@ -348,10 +354,12 @@ def _wrap_streaming_async(
348
354
  chunk.usage
349
355
  )
350
356
  set_span_attribute(
351
- span, AttributeKeys.GEN_AI_USAGE_INPUT_TOKENS, prompt_tokens
357
+ span,
358
+ AttributeKeys.JUDGMENT_USAGE_NON_CACHED_INPUT_TOKENS,
359
+ prompt_tokens,
352
360
  )
353
361
  set_span_attribute(
354
- span, AttributeKeys.GEN_AI_USAGE_OUTPUT_TOKENS, completion_tokens
362
+ span, AttributeKeys.JUDGMENT_USAGE_OUTPUT_TOKENS, completion_tokens
355
363
  )
356
364
  set_span_attribute(
357
365
  span,
@@ -16,6 +16,7 @@ def sync_span_context(
16
16
  name: str,
17
17
  span_attributes: Optional[Dict[str, str]] = None,
18
18
  disable_partial_emit: bool = False,
19
+ end_on_exit: bool = False,
19
20
  ):
20
21
  if span_attributes is None:
21
22
  span_attributes = {}
@@ -23,6 +24,7 @@ def sync_span_context(
23
24
  with tracer.get_tracer().start_as_current_span(
24
25
  name=name,
25
26
  attributes=span_attributes,
27
+ end_on_exit=end_on_exit,
26
28
  ) as span:
27
29
  if disable_partial_emit:
28
30
  tracer.judgment_processor.set_internal_attribute(
@@ -39,6 +41,7 @@ async def async_span_context(
39
41
  name: str,
40
42
  span_attributes: Optional[Dict[str, str]] = None,
41
43
  disable_partial_emit: bool = False,
44
+ end_on_exit: bool = False,
42
45
  ):
43
46
  if span_attributes is None:
44
47
  span_attributes = {}
@@ -46,6 +49,7 @@ async def async_span_context(
46
49
  with tracer.get_tracer().start_as_current_span(
47
50
  name=name,
48
51
  attributes=span_attributes,
52
+ end_on_exit=end_on_exit,
49
53
  ) as span:
50
54
  if disable_partial_emit:
51
55
  tracer.judgment_processor.set_internal_attribute(
@@ -1,5 +1,14 @@
1
1
  from judgeval.trainer.trainer import JudgmentTrainer
2
2
  from judgeval.trainer.config import TrainerConfig, ModelConfig
3
3
  from judgeval.trainer.trainable_model import TrainableModel
4
+ from judgeval.trainer.base_trainer import BaseTrainer
5
+ from judgeval.trainer.fireworks_trainer import FireworksTrainer
4
6
 
5
- __all__ = ["JudgmentTrainer", "TrainerConfig", "ModelConfig", "TrainableModel"]
7
+ __all__ = [
8
+ "JudgmentTrainer",
9
+ "TrainerConfig",
10
+ "ModelConfig",
11
+ "TrainableModel",
12
+ "BaseTrainer",
13
+ "FireworksTrainer",
14
+ ]
@@ -0,0 +1,122 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Callable, List, Optional, Union, Dict, TYPE_CHECKING
3
+ from .config import TrainerConfig, ModelConfig
4
+ from judgeval.scorers import ExampleScorer, ExampleAPIScorerConfig
5
+
6
+ if TYPE_CHECKING:
7
+ from judgeval.tracer import Tracer
8
+ from .trainable_model import TrainableModel
9
+
10
+
11
+ class BaseTrainer(ABC):
12
+ """
13
+ Abstract base class for training providers.
14
+
15
+ This class defines the interface that all training provider implementations
16
+ must follow. Each provider (Fireworks, Verifiers, etc.) will have its own
17
+ concrete implementation of this interface.
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ config: TrainerConfig,
23
+ trainable_model: "TrainableModel",
24
+ tracer: "Tracer",
25
+ project_name: Optional[str] = None,
26
+ ):
27
+ """
28
+ Initialize the base trainer.
29
+
30
+ Args:
31
+ config: TrainerConfig instance with training parameters
32
+ trainable_model: TrainableModel instance to use for training
33
+ tracer: Tracer for observability
34
+ project_name: Project name for organizing training runs
35
+ """
36
+ self.config = config
37
+ self.trainable_model = trainable_model
38
+ self.tracer = tracer
39
+ self.project_name = project_name or "judgment_training"
40
+
41
+ @abstractmethod
42
+ async def generate_rollouts_and_rewards(
43
+ self,
44
+ agent_function: Callable[[Any], Any],
45
+ scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
46
+ prompts: List[Any],
47
+ num_prompts_per_step: Optional[int] = None,
48
+ num_generations_per_prompt: Optional[int] = None,
49
+ concurrency: Optional[int] = None,
50
+ ) -> Any:
51
+ """
52
+ Generate rollouts and compute rewards using the current model snapshot.
53
+
54
+ Args:
55
+ agent_function: Function/agent to call for generating responses
56
+ scorers: List of scorer objects to evaluate responses
57
+ prompts: List of prompts to use for training
58
+ num_prompts_per_step: Number of prompts to use per step
59
+ num_generations_per_prompt: Generations per prompt
60
+ concurrency: Concurrency limit
61
+
62
+ Returns:
63
+ Provider-specific dataset format for training
64
+ """
65
+ pass
66
+
67
+ @abstractmethod
68
+ async def run_reinforcement_learning(
69
+ self,
70
+ agent_function: Callable[[Any], Any],
71
+ scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
72
+ prompts: List[Any],
73
+ ) -> ModelConfig:
74
+ """
75
+ Run the iterative reinforcement learning fine-tuning loop.
76
+
77
+ Args:
78
+ agent_function: Function/agent to call for generating responses
79
+ scorers: List of scorer objects to evaluate responses
80
+ prompts: List of prompts to use for training
81
+
82
+ Returns:
83
+ ModelConfig: Configuration of the trained model
84
+ """
85
+ pass
86
+
87
+ @abstractmethod
88
+ async def train(
89
+ self,
90
+ agent_function: Callable[[Any], Any],
91
+ scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
92
+ prompts: List[Any],
93
+ ) -> ModelConfig:
94
+ """
95
+ Start the reinforcement learning fine-tuning process.
96
+
97
+ This is the main entry point for running the training.
98
+
99
+ Args:
100
+ agent_function: Function/agent to call for generating responses
101
+ scorers: List of scorer objects to evaluate responses
102
+ prompts: List of prompts to use for training
103
+
104
+ Returns:
105
+ ModelConfig: Configuration of the trained model
106
+ """
107
+ pass
108
+
109
+ @abstractmethod
110
+ def _extract_message_history_from_spans(
111
+ self, trace_id: str
112
+ ) -> List[Dict[str, str]]:
113
+ """
114
+ Extract message history from spans for training purposes.
115
+
116
+ Args:
117
+ trace_id: The trace ID (32-char hex string) to extract message history from
118
+
119
+ Returns:
120
+ List of message dictionaries with 'role' and 'content' keys
121
+ """
122
+ pass
@@ -16,7 +16,7 @@ class TrainerConfig:
16
16
  user_id: str
17
17
  model_id: str
18
18
  base_model_name: str = "qwen2p5-7b-instruct"
19
- rft_provider: str = "fireworks"
19
+ rft_provider: str = "fireworks" # Supported: "fireworks", "verifiers" (future)
20
20
  num_steps: int = 5
21
21
  num_generations_per_prompt: int = 4
22
22
  num_prompts_per_step: int = 4