judgeval 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,301 @@
+ import asyncio
+ from typing import Optional, Callable, Any, List, Union
+ from fireworks import Dataset
+ from .config import TrainerConfig, ModelConfig
+ from .trainable_model import TrainableModel
+ from judgeval.tracer import Tracer
+ from judgeval.judgment_client import JudgmentClient
+ from judgeval.scorers import BaseScorer, APIScorerConfig
+ from judgeval.data import Example
+ from .console import _spinner_progress, _print_progress, _print_progress_update
+ from judgeval.common.exceptions import JudgmentAPIError
+
+
+ class JudgmentTrainer:
+     """
+     A reinforcement learning trainer for Judgment models using reinforcement fine-tuning.
+
+     This class handles the iterative training process in which models are improved
+     through reinforcement fine-tuning based on generated rollouts and rewards.
+     """
+
+     def __init__(
+         self,
+         config: TrainerConfig,
+         trainable_model: TrainableModel,
+         tracer: Tracer,
+         project_name: Optional[str] = None,
+     ):
+         """
+         Initialize the JudgmentTrainer.
+
+         Args:
+             config: TrainerConfig instance with training parameters
+             trainable_model: Trainable model instance; if None, one is created from config
+             tracer: Tracer used for observability
+             project_name: Project name for organizing training runs and evaluations
+         """
+         try:
+             self.config = config
+             self.tracer = tracer
+             self.tracer.show_trace_urls = False
+             self.project_name = project_name or "judgment_training"
+
+             if trainable_model is None:
+                 self.trainable_model = TrainableModel(self.config)
+             else:
+                 self.trainable_model = trainable_model
+
+             self.judgment_client = JudgmentClient()
+         except Exception as e:
+             raise JudgmentAPIError(
+                 f"Failed to initialize JudgmentTrainer: {str(e)}"
+             ) from e
+
+     async def generate_rollouts_and_rewards(
+         self,
+         agent_function: Callable[[Any], Any],
+         scorers: List[Union[APIScorerConfig, BaseScorer]],
+         prompts: List[Any],
+         num_prompts_per_step: Optional[int] = None,
+         num_generations_per_prompt: Optional[int] = None,
+         concurrency: Optional[int] = None,
+     ):
+         """
+         Generate rollouts and compute rewards using the current model snapshot.
+         Each sample contains multiple generations for reinforcement learning optimization.
+
+         Args:
+             agent_function: Function/agent to call for generating responses
+             scorers: List of scorer objects to evaluate responses
+             prompts: List of prompts to use for training
+             num_prompts_per_step: Number of prompts to use per step (defaults to the config value, capped at len(prompts))
+             num_generations_per_prompt: Generations per prompt (defaults to the config value)
+             concurrency: Concurrency limit (defaults to the config value)
+
+         Returns:
+             List of dataset rows containing samples with messages and evaluations
+         """
+         num_prompts_per_step = min(
+             num_prompts_per_step or self.config.num_prompts_per_step, len(prompts)
+         )
+         num_generations_per_prompt = (
+             num_generations_per_prompt or self.config.num_generations_per_prompt
+         )
+         concurrency = concurrency or self.config.concurrency
+
+         semaphore = asyncio.Semaphore(concurrency)
+
+         @self.tracer.observe(span_type="function")
+         async def generate_single_response(prompt_id, generation_id):
+             async with semaphore:
+                 prompt_input = prompts[prompt_id]
+                 response_data = await agent_function(**prompt_input)
+                 messages = response_data.get("messages", [])
+
+                 # Prefer the traced message history when available.
+                 try:
+                     traced_messages = self.tracer.get_current_message_history()
+                     if traced_messages:
+                         messages = traced_messages
+                 except Exception as e:
+                     print(f"Warning: Failed to get message history from trace: {e}")
+
+                 example = Example(
+                     input=prompt_input,
+                     messages=messages,
+                     actual_output=response_data,
+                 )
+
+                 scoring_results = self.judgment_client.run_evaluation(
+                     examples=[example],
+                     scorers=scorers,
+                     project_name=self.project_name,
+                     eval_run_name=f"training_step_{self.trainable_model.current_step}_prompt_{prompt_id}_gen_{generation_id}",
+                     show_url=False,
+                 )
+
+                 # The reward is the mean score across all scorers for this generation.
+                 if scoring_results and scoring_results[0].scorers_data:
+                     reward = sum(
+                         scorer_data.score
+                         for scorer_data in scoring_results[0].scorers_data
+                     ) / len(scoring_results[0].scorers_data)
+                 else:
+                     reward = 0.0
+
+                 return {
+                     "prompt_id": prompt_id,
+                     "generation_id": generation_id,
+                     "messages": messages,
+                     "evals": {"score": reward},
+                 }
+
+         coros = []
+         for prompt_id in range(num_prompts_per_step):
+             for generation_id in range(num_generations_per_prompt):
+                 coro = generate_single_response(prompt_id, generation_id)
+                 coros.append(coro)
+
+         with _spinner_progress(f"Generating {len(coros)} rollouts..."):
+             num_completed = 0
+             results = []
+
+             for coro in asyncio.as_completed(coros):
+                 result = await coro
+                 results.append(result)
+                 num_completed += 1
+
+         _print_progress(f"Generated {len(results)} rollouts successfully")
+
+         # Group generations by prompt so each dataset row holds all samples for one prompt.
+         dataset_rows = []
+         for prompt_id in range(num_prompts_per_step):
+             prompt_generations = [r for r in results if r["prompt_id"] == prompt_id]
+             sample_generations = [
+                 {"messages": gen["messages"], "evals": gen["evals"]}
+                 for gen in prompt_generations
+             ]
+             dataset_rows.append({"samples": sample_generations})
+
+         return dataset_rows
+
+     async def run_reinforcement_learning(
+         self,
+         agent_function: Callable[[Any], Any],
+         scorers: List[Union[APIScorerConfig, BaseScorer]],
+         prompts: List[Any],
+     ) -> ModelConfig:
+         """
+         Run the iterative reinforcement learning fine-tuning loop.
+
+         This method performs multiple steps of reinforcement learning, where each step:
+         1. Advances to the appropriate model snapshot
+         2. Generates rollouts and computes rewards using scorers
+         3. Trains a new model using reinforcement learning
+         4. Waits for training completion
+
+         Args:
+             agent_function: Function/agent to call for generating responses
+             scorers: List of scorer objects to evaluate responses
+             prompts: List of prompts to use for training
+
+         Returns:
+             ModelConfig: Configuration of the trained model for inference and future training
+         """
+         _print_progress("Starting reinforcement learning training")
+
+         training_params = {
+             "num_steps": self.config.num_steps,
+             "num_prompts_per_step": self.config.num_prompts_per_step,
+             "num_generations_per_prompt": self.config.num_generations_per_prompt,
+             "epochs": self.config.epochs,
+             "learning_rate": self.config.learning_rate,
+             "accelerator_count": self.config.accelerator_count,
+             "accelerator_type": self.config.accelerator_type,
+             "temperature": self.config.temperature,
+             "max_tokens": self.config.max_tokens,
+         }
+
+         start_step = self.trainable_model.current_step
+
+         for step in range(start_step, self.config.num_steps):
+             step_num = step + 1
+             _print_progress(
+                 f"Starting training step {step_num}", step_num, self.config.num_steps
+             )
+
+             self.trainable_model.advance_to_next_step(step)
+
+             dataset_rows = await self.generate_rollouts_and_rewards(
+                 agent_function, scorers, prompts
+             )
+
+             with _spinner_progress(
+                 "Preparing training dataset", step_num, self.config.num_steps
+             ):
+                 dataset = Dataset.from_list(dataset_rows)
+                 dataset.sync()
+
+             _print_progress(
+                 "Starting reinforcement training", step_num, self.config.num_steps
+             )
+             job = self.trainable_model.perform_reinforcement_step(dataset, step)
+
+             last_state = None
+             with _spinner_progress(
+                 "Training job in progress", step_num, self.config.num_steps
+             ):
+                 while not job.is_completed:
+                     job.raise_if_bad_state()
+                     current_state = job.state
+
+                     if current_state != last_state:
+                         if current_state in ["uploading", "validating"]:
+                             _print_progress_update(
+                                 f"Training job: {current_state} data"
+                             )
+                         elif current_state == "training":
+                             _print_progress_update(
+                                 "Training job: model training in progress"
+                             )
+                         else:
+                             _print_progress_update(f"Training job: {current_state}")
+                         last_state = current_state
+
+                     # Poll without blocking the event loop.
+                     await asyncio.sleep(10)
+                     job = job.get()
+                     if job is None:
+                         raise JudgmentAPIError(
+                             "Training job was deleted while waiting for completion"
+                         )
+
+             _print_progress(
+                 f"Training completed! New model: {job.output_model}",
+                 step_num,
+                 self.config.num_steps,
+             )
+
+             dataset.delete()
+
+         _print_progress("All training steps completed!")
+
+         with _spinner_progress("Deploying final trained model"):
+             self.trainable_model.advance_to_next_step(self.config.num_steps)
+
+         return self.trainable_model.get_model_config(training_params)
+
+     async def train(
+         self,
+         agent_function: Callable[[Any], Any],
+         scorers: List[Union[APIScorerConfig, BaseScorer]],
+         prompts: List[Any],
+         rft_provider: Optional[str] = None,
+     ) -> ModelConfig:
+         """
+         Start the reinforcement learning fine-tuning process.
+
+         This is the main entry point for running the reinforcement learning training.
+
+         Args:
+             agent_function: Function/agent to call for generating responses
+             scorers: List of scorer objects to evaluate responses
+             prompts: List of prompts to use for training
+             rft_provider: RFT provider to use for training. Currently only "fireworks"
+                 is supported; support for other providers is planned for future releases.
+
+         Returns:
+             ModelConfig: Configuration of the trained model for future loading
+         """
+         try:
+             if rft_provider is not None:
+                 self.config.rft_provider = rft_provider
+
+             return await self.run_reinforcement_learning(
+                 agent_function, scorers, prompts
+             )
+         except JudgmentAPIError:
+             # Re-raise JudgmentAPIError as-is
+             raise
+         except Exception as e:
+             raise JudgmentAPIError(f"Training process failed: {str(e)}") from e
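
For orientation, here is a minimal usage sketch of the JudgmentTrainer added above. It is an illustration, not part of the diff: the TrainerConfig field values, the agent function body, the empty scorer list, and the Tracer constructor argument are assumptions about the surrounding API.

    import asyncio

    # Hypothetical setup; the field names follow the config attributes the
    # trainer code reads (num_steps, num_prompts_per_step, ...).
    config = TrainerConfig(num_steps=2, num_prompts_per_step=4)
    trainer = JudgmentTrainer(
        config=config,
        trainable_model=TrainableModel(config),
        tracer=Tracer(project_name="judgment_training"),  # assumed signature
    )

    async def my_agent(question: str):
        # A real agent would call the current model snapshot here and return
        # the conversation under a "messages" key, as the trainer expects.
        return {"messages": [{"role": "user", "content": question}]}

    # Prompts are dicts because the trainer unpacks them as keyword arguments.
    prompts = [{"question": "What does RFT stand for?"}]
    scorers = []  # e.g. judgeval scorers; left empty in this sketch

    model_config = asyncio.run(trainer.train(my_agent, scorers, prompts))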
@@ -0,0 +1,104 @@
+ from typing import List, Optional, Union
+ from pydantic import field_validator, model_validator, Field
+ from datetime import datetime, timezone
+ import uuid
+
+ from judgeval.data import Example
+ from judgeval.scorers import BaseScorer, APIScorerConfig
+ from judgeval.constants import ACCEPTABLE_MODELS
+ from judgeval.data.judgment_types import EvaluationRunJudgmentType
+
+
+ class EvaluationRun(EvaluationRunJudgmentType):
+     """
+     Stores examples and evaluation scorers together for running an eval task.
+
+     Args:
+         project_name (str): The name of the project the evaluation results belong to
+         eval_name (str): A name for this evaluation run
+         examples (List[Example]): The examples to evaluate
+         scorers (List[Union[BaseScorer, APIScorerConfig]]): A list of scorers to use for evaluation
+         model (str): The model used as a judge when using LLM as a Judge
+         metadata (Optional[Dict[str, Any]]): Additional metadata for this evaluation run, e.g. comments, dataset name, purpose, etc.
+     """
+
+     id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4()))
+     created_at: Optional[str] = Field(
+         default_factory=lambda: datetime.now(timezone.utc).isoformat()
+     )
+     custom_scorers: Optional[List[BaseScorer]] = None
+     judgment_scorers: Optional[List[APIScorerConfig]] = None
+     organization_id: Optional[str] = None
+
+     def __init__(
+         self,
+         scorers: Optional[List[Union[BaseScorer, APIScorerConfig]]] = None,
+         **kwargs,
+     ):
+         """
+         Initialize EvaluationRun with automatic scorer classification.
+
+         Args:
+             scorers: List of scorers that will be automatically sorted into custom_scorers or judgment_scorers
+             **kwargs: Other initialization arguments
+         """
+         if scorers is not None:
+             # Automatically sort scorers into the appropriate fields
+             custom_scorers = [s for s in scorers if isinstance(s, BaseScorer)]
+             judgment_scorers = [s for s in scorers if isinstance(s, APIScorerConfig)]
+
+             # Always set both fields as lists (even if empty) to satisfy validation
+             kwargs["custom_scorers"] = custom_scorers
+             kwargs["judgment_scorers"] = judgment_scorers
+
+         super().__init__(**kwargs)
+
+     def model_dump(self, **kwargs):
+         data = super().model_dump(**kwargs)
+         data["custom_scorers"] = [s.model_dump() for s in (self.custom_scorers or [])]
+         data["judgment_scorers"] = [s.model_dump() for s in (self.judgment_scorers or [])]
+         data["examples"] = [example.model_dump() for example in self.examples]
+
+         return data
+
+     @field_validator("examples")
+     def validate_examples(cls, v):
+         if not v:
+             raise ValueError("Examples cannot be empty.")
+         for item in v:
+             if not isinstance(item, Example):
+                 raise ValueError(f"Item of type {type(item)} is not an Example")
+         return v
+
+     @model_validator(mode="after")
+     @classmethod
+     def validate_scorer_lists(cls, values):
+         custom_scorers = values.custom_scorers
+         judgment_scorers = values.judgment_scorers
+
+         # At least one of the two scorer lists must be non-empty
+         if not custom_scorers and not judgment_scorers:
+             raise ValueError(
+                 "At least one of custom_scorers or judgment_scorers must be provided."
+             )
+
+         # Only one of the two scorer lists may be non-empty
+         if custom_scorers and judgment_scorers:
+             raise ValueError(
+                 "Only one of custom_scorers or judgment_scorers can be provided, not both."
+             )
+
+         return values
+
+     @field_validator("model")
+     def validate_model(cls, v, values):
+         if not v:
+             raise ValueError("Model cannot be empty.")
+
+         # Validate the model name when it is given as a string
+         if isinstance(v, str):
+             if v not in ACCEPTABLE_MODELS:
+                 raise ValueError(
+                     f"Model name {v} not recognized. Please select a valid model name."
+                 )
+         return v
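
A short sketch of the scorer auto-classification the new EvaluationRun performs; the APIScorerConfig constructor argument shown is an assumption for illustration, and "gpt-4.1" stands in for any entry of ACCEPTABLE_MODELS.

    # Scorers passed via `scorers` are split by isinstance checks:
    api_scorer = APIScorerConfig(score_type="faithfulness")  # assumed field
    run = EvaluationRun(
        project_name="demo",
        eval_name="smoke_test",
        examples=[Example(input="hi", actual_output="hello")],
        scorers=[api_scorer],
        model="gpt-4.1",  # must be in ACCEPTABLE_MODELS
    )
    assert run.judgment_scorers == [api_scorer]
    assert run.custom_scorers == []  # no BaseScorer instances were passed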
@@ -1,6 +1,6 @@
  # generated by datamodel-codegen:
  # filename: openapi_new.json
- # timestamp: 2025-08-01T22:19:19+00:00
+ # timestamp: 2025-08-08T18:50:51+00:00

  from __future__ import annotations

@@ -51,6 +51,31 @@ class ScorerConfigJudgmentType(BaseModel):
      kwargs: Annotated[Optional[Dict[str, Any]], Field(title="Kwargs")] = None


+ class BaseScorerJudgmentType(BaseModel):
+     score_type: Annotated[str, Field(title="Score Type")]
+     threshold: Annotated[Optional[float], Field(title="Threshold")] = 0.5
+     name: Annotated[Optional[str], Field(title="Name")] = None
+     class_name: Annotated[Optional[str], Field(title="Class Name")] = None
+     score: Annotated[Optional[float], Field(title="Score")] = None
+     score_breakdown: Annotated[
+         Optional[Dict[str, Any]], Field(title="Score Breakdown")
+     ] = None
+     reason: Annotated[Optional[str], Field(title="Reason")] = ""
+     using_native_model: Annotated[Optional[bool], Field(title="Using Native Model")] = (
+         None
+     )
+     success: Annotated[Optional[bool], Field(title="Success")] = None
+     model: Annotated[Optional[str], Field(title="Model")] = None
+     model_client: Annotated[Any, Field(title="Model Client")] = None
+     strict_mode: Annotated[Optional[bool], Field(title="Strict Mode")] = False
+     error: Annotated[Optional[str], Field(title="Error")] = None
+     additional_metadata: Annotated[
+         Optional[Dict[str, Any]], Field(title="Additional Metadata")
+     ] = None
+     user: Annotated[Optional[str], Field(title="User")] = None
+     server_hosted: Annotated[Optional[bool], Field(title="Server Hosted")] = False
+
+
  class TraceUsageJudgmentType(BaseModel):
      prompt_tokens: Annotated[Optional[int], Field(title="Prompt Tokens")] = None
      completion_tokens: Annotated[Optional[int], Field(title="Completion Tokens")] = None
@@ -90,16 +115,21 @@ class HTTPValidationErrorJudgmentType(BaseModel):
      ] = None


- class JudgmentEvalJudgmentType(BaseModel):
+ class EvaluationRunJudgmentType(BaseModel):
+     id: Annotated[Optional[str], Field(title="Id")] = None
      project_name: Annotated[Optional[str], Field(title="Project Name")] = None
      eval_name: Annotated[Optional[str], Field(title="Eval Name")] = None
      examples: Annotated[List[ExampleJudgmentType], Field(title="Examples")]
-     scorers: Annotated[List[ScorerConfigJudgmentType], Field(title="Scorers")]
+     custom_scorers: Annotated[
+         Optional[List[BaseScorerJudgmentType]], Field(title="Custom Scorers")
+     ] = Field(default_factory=list)
+     judgment_scorers: Annotated[
+         Optional[List[ScorerConfigJudgmentType]], Field(title="Judgment Scorers")
+     ] = Field(default_factory=list)
      model: Annotated[str, Field(title="Model")]
-     append: Annotated[Optional[bool], Field(title="Append")] = False
-     override: Annotated[Optional[bool], Field(title="Override")] = False
      trace_span_id: Annotated[Optional[str], Field(title="Trace Span Id")] = None
      trace_id: Annotated[Optional[str], Field(title="Trace Id")] = None
+     created_at: Annotated[Optional[str], Field(title="Created At")] = None


  class TraceSpanJudgmentType(BaseModel):
@@ -123,6 +153,7 @@ class TraceSpanJudgmentType(BaseModel):
      ] = None
      has_evaluation: Annotated[Optional[bool], Field(title="Has Evaluation")] = False
      agent_name: Annotated[Optional[str], Field(title="Agent Name")] = None
+     class_name: Annotated[Optional[str], Field(title="Class Name")] = None
      state_before: Annotated[Optional[Dict[str, Any]], Field(title="State Before")] = (
          None
      )
@@ -172,8 +203,6 @@ class TraceRunJudgmentType(BaseModel):
      traces: Annotated[List[TraceJudgmentType], Field(title="Traces")]
      scorers: Annotated[List[ScorerConfigJudgmentType], Field(title="Scorers")]
      model: Annotated[str, Field(title="Model")]
-     append: Annotated[Optional[bool], Field(title="Append")] = False
-     override: Annotated[Optional[bool], Field(title="Override")] = False
      trace_span_id: Annotated[Optional[str], Field(title="Trace Span Id")] = None
      tools: Annotated[Optional[List[Dict[str, Any]]], Field(title="Tools")] = None

@@ -181,5 +210,5 @@ class TraceRunJudgmentType(BaseModel):
  class EvalResultsJudgmentType(BaseModel):
      results: Annotated[List[ScoringResultJudgmentType], Field(title="Results")]
      run: Annotated[
-         Union[TraceRunJudgmentType, JudgmentEvalJudgmentType], Field(title="Run")
+         Union[TraceRunJudgmentType, EvaluationRunJudgmentType], Field(title="Run")
      ]
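
The regenerated schema above replaces JudgmentEvalJudgmentType's single scorers field with the custom_scorers/judgment_scorers pair and drops the append/override flags. A small sketch of the new shape (field values are illustrative only):

    row = EvaluationRunJudgmentType(
        examples=[],      # ExampleJudgmentType items in practice
        model="gpt-4.1",  # illustrative model name
    )
    # Both scorer lists now default to empty lists, and append/override
    # are no longer accepted by the model.
    print(row.custom_scorers, row.judgment_scorers)  # -> [] []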
judgeval/data/trace.py CHANGED
@@ -32,6 +32,7 @@ class TraceSpan(TraceSpanJudgmentType):
          "usage": self.usage.model_dump() if self.usage else None,
          "has_evaluation": self.has_evaluation,
          "agent_name": self.agent_name,
+         "class_name": self.class_name,
          "state_before": self.state_before,
          "state_after": self.state_after,
          "additional_metadata": json_encoder(self.additional_metadata),
@@ -29,8 +29,6 @@ class TraceRun(BaseModel):
      scorers: List[Union[APIScorerConfig, BaseScorer]]
      model: Optional[str] = DEFAULT_GPT_MODEL
      trace_span_id: Optional[str] = None
-     append: Optional[bool] = False
-     override: Optional[bool] = False
      rules: Optional[List[Rule]] = None
      tools: Optional[List[Dict[str, Any]]] = None

@@ -133,7 +133,8 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
          inputs: Optional[Dict[str, Any]] = None,
      ) -> None:
          """Start tracking a span, ensuring trace client exists"""
-
+         if name.startswith("__") and name.endswith("__"):
+             return
          start_time = time.time()
          span_id = str(uuid.uuid4())
          parent_span_id: Optional[str] = None
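
The new guard in JudgevalCallbackHandler skips callback events whose names are dunder-style (for example framework-internal "__start__"/"__end__" nodes; the intended targets are an assumption), so they never open a span. A minimal illustration of the check in isolation:

    def _is_internal(name: str) -> bool:
        # Mirrors the condition added above.
        return name.startswith("__") and name.endswith("__")

    assert _is_internal("__start__")
    assert not _is_internal("my_tool")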