judgeval 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,238 @@
+ from fireworks import LLM
+ from .config import TrainerConfig, ModelConfig
+ from typing import Optional, Dict, Any, Callable
+ from .console import _model_spinner_progress, _print_model_progress
+ from judgeval.common.exceptions import JudgmentAPIError
+
+
+ class TrainableModel:
+     """
+     A wrapper class for managing model snapshots during training.
+
+     This class automatically handles model snapshot creation and management
+     during the RFT (Reinforcement Fine-Tuning) process,
+     abstracting away manual snapshot management from users.
+     """
+
+     def __init__(self, config: TrainerConfig):
+         """
+         Initialize the TrainableModel.
+
+         Args:
+             config: TrainerConfig instance with model configuration
+         """
+         try:
+             self.config = config
+             self.current_step = 0
+             self._current_model = None
+             self._tracer_wrapper_func = None
+
+             self._base_model = self._create_base_model()
+             self._current_model = self._base_model
+         except Exception as e:
+             raise JudgmentAPIError(
+                 f"Failed to initialize TrainableModel: {str(e)}"
+             ) from e
+
+     @classmethod
+     def from_model_config(cls, model_config: ModelConfig) -> "TrainableModel":
+         """
+         Create a TrainableModel from a saved ModelConfig.
+
+         Args:
+             model_config: ModelConfig instance with saved model state
+
+         Returns:
+             TrainableModel instance configured to use the saved model
+         """
+         # Create a TrainerConfig from the ModelConfig
+         trainer_config = TrainerConfig(
+             base_model_name=model_config.base_model_name,
+             deployment_id=model_config.deployment_id,
+             user_id=model_config.user_id,
+             model_id=model_config.model_id,
+             enable_addons=model_config.enable_addons,
+         )
+
+         instance = cls(trainer_config)
+         instance.current_step = model_config.current_step
+
+         if model_config.is_trained and model_config.current_model_name:
+             instance._load_trained_model(model_config.current_model_name)
+
+         return instance
+
+     def _create_base_model(self):
+         """Create and configure the base model."""
+         try:
+             with _model_spinner_progress(
+                 "Creating and deploying base model..."
+             ) as update_progress:
+                 update_progress("Creating base model instance...")
+                 base_model = LLM(
+                     model=self.config.base_model_name,
+                     deployment_type="on-demand",
+                     id=self.config.deployment_id,
+                     enable_addons=self.config.enable_addons,
+                 )
+                 update_progress("Applying deployment configuration...")
+                 base_model.apply()
+                 _print_model_progress("Base model deployment ready")
+                 return base_model
+         except Exception as e:
+             raise JudgmentAPIError(
+                 f"Failed to create and deploy base model '{self.config.base_model_name}': {str(e)}"
+             ) from e
+
+     def _load_trained_model(self, model_name: str):
+         """Load a trained model by name."""
+         try:
+             with _model_spinner_progress(
+                 f"Loading and deploying trained model: {model_name}"
+             ) as update_progress:
+                 update_progress("Creating trained model instance...")
+                 self._current_model = LLM(
+                     model=model_name,
+                     deployment_type="on-demand-lora",
+                     base_id=self.config.deployment_id,
+                 )
+                 update_progress("Applying deployment configuration...")
+                 self._current_model.apply()
+                 _print_model_progress("Trained model deployment ready")
+
+             if self._tracer_wrapper_func:
+                 self._tracer_wrapper_func(self._current_model)
+         except Exception as e:
+             raise JudgmentAPIError(
+                 f"Failed to load and deploy trained model '{model_name}': {str(e)}"
+             ) from e
+
+     def get_current_model(self):
+         return self._current_model
+
+     @property
+     def chat(self):
+         """OpenAI-compatible chat interface."""
+         return self._current_model.chat
+
+     @property
+     def completions(self):
+         """OpenAI-compatible completions interface."""
+         return self._current_model.completions
+
+     def advance_to_next_step(self, step: int):
+         """
+         Advance to the next training step and update the current model snapshot.
+
+         Args:
+             step: The current training step number
+         """
+         try:
+             self.current_step = step
+
+             if step == 0:
+                 self._current_model = self._base_model
+             else:
+                 model_name = f"accounts/{self.config.user_id}/models/{self.config.model_id}-v{step}"
+                 with _model_spinner_progress(
+                     f"Creating and deploying model snapshot: {model_name}"
+                 ) as update_progress:
+                     update_progress("Creating model snapshot instance...")
+                     self._current_model = LLM(
+                         model=model_name,
+                         deployment_type="on-demand-lora",
+                         base_id=self.config.deployment_id,
+                     )
+                     update_progress("Applying deployment configuration...")
+                     self._current_model.apply()
+                     _print_model_progress("Model snapshot deployment ready")
+
+             if self._tracer_wrapper_func:
+                 self._tracer_wrapper_func(self._current_model)
+         except Exception as e:
+             raise JudgmentAPIError(
+                 f"Failed to advance to training step {step}: {str(e)}"
+             ) from e
+
+     def perform_reinforcement_step(self, dataset, step: int):
+         """
+         Perform a reinforcement learning step using the current model.
+
+         Args:
+             dataset: Training dataset for the reinforcement step
+             step: Current step number for output model naming
+
+         Returns:
+             Training job object
+         """
+         try:
+             model_name = f"{self.config.model_id}-v{step + 1}"
+             return self._current_model.reinforcement_step(
+                 dataset=dataset,
+                 output_model=model_name,
+                 epochs=self.config.epochs,
+                 learning_rate=self.config.learning_rate,
+                 accelerator_count=self.config.accelerator_count,
+                 accelerator_type=self.config.accelerator_type,
+             )
+         except Exception as e:
+             raise JudgmentAPIError(
+                 f"Failed to start reinforcement learning step {step + 1}: {str(e)}"
+             ) from e
+
+     def get_model_config(
+         self, training_params: Optional[Dict[str, Any]] = None
+     ) -> ModelConfig:
+         """
+         Get the current model configuration for persistence.
+
+         Args:
+             training_params: Optional training parameters to include in config
+
+         Returns:
+             ModelConfig instance with current model state
+         """
+         current_model_name = None
+         is_trained = False
+
+         if self.current_step > 0:
+             current_model_name = f"accounts/{self.config.user_id}/models/{self.config.model_id}-v{self.current_step}"
+             is_trained = True
+
+         return ModelConfig(
+             base_model_name=self.config.base_model_name,
+             deployment_id=self.config.deployment_id,
+             user_id=self.config.user_id,
+             model_id=self.config.model_id,
+             enable_addons=self.config.enable_addons,
+             current_step=self.current_step,
+             total_steps=self.config.num_steps,
+             current_model_name=current_model_name,
+             is_trained=is_trained,
+             training_params=training_params,
+         )
+
+     def save_model_config(
+         self, filepath: str, training_params: Optional[Dict[str, Any]] = None
+     ):
+         """
+         Save the current model configuration to a file.
+
+         Args:
+             filepath: Path to save the configuration file
+             training_params: Optional training parameters to include in config
+         """
+         model_config = self.get_model_config(training_params)
+         model_config.save_to_file(filepath)
+
+     def _register_tracer_wrapper(self, wrapper_func: Callable):
+         """
+         Register a tracer wrapper function to be reapplied when models change.
+
+         This is called internally by the tracer's wrap() function to ensure
+         that new model instances created during training are automatically wrapped.
+
+         Args:
+             wrapper_func: Function that wraps a model instance with tracing
+         """
+         self._tracer_wrapper_func = wrapper_func
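
The snippet below is a minimal usage sketch of the TrainableModel wrapper added above. It is not part of the package diff: the import path is inferred from the relative imports shown here, the TrainerConfig keyword arguments mirror the fields read in from_model_config, and the account/deployment/model values are hypothetical placeholders. Running it requires configured Fireworks credentials.

# Hypothetical usage sketch; values and import path are assumptions, not from the package.
from judgeval.trainer import TrainerConfig, TrainableModel  # assumed public import path

config = TrainerConfig(
    base_model_name="accounts/fireworks/models/llama-v3p1-8b-instruct",  # placeholder model
    deployment_id="my-rft-deployment",  # placeholder
    user_id="my-account",  # placeholder
    model_id="my-agent-model",  # placeholder
)

model = TrainableModel(config)  # creates and deploys the base model via Fireworks

# The wrapper exposes an OpenAI-compatible chat interface (per the docstrings above).
response = model.chat.completions.create(
    messages=[{"role": "user", "content": "Hello"}],
)

# Persist the current snapshot state, or rebuild a wrapper from a saved ModelConfig.
model.save_model_config("model_config.json")
restored = TrainableModel.from_model_config(model.get_model_config())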
@@ -0,0 +1,301 @@
+ import asyncio
+ import time
+ from typing import Optional, Callable, Any, List, Union
+ from fireworks import Dataset
+ from .config import TrainerConfig, ModelConfig
+ from .trainable_model import TrainableModel
+ from judgeval.tracer import Tracer
+ from judgeval.judgment_client import JudgmentClient
+ from judgeval.scorers import BaseScorer, APIScorerConfig
+ from judgeval.data import Example
+ from .console import _spinner_progress, _print_progress, _print_progress_update
+ from judgeval.common.exceptions import JudgmentAPIError
+
+
+ class JudgmentTrainer:
+     """
+     A reinforcement learning trainer for Judgment models using Fine-Tuning.
+
+     This class handles the iterative training process where models are improved
+     through reinforcement learning fine-tuning based on generated rollouts and rewards.
+     """
+
+     def __init__(
+         self,
+         config: TrainerConfig,
+         trainable_model: TrainableModel,
+         tracer: Tracer,
+         project_name: Optional[str] = None,
+     ):
+ """
31
+ Initialize the JudgmentTrainer.
32
+
33
+ Args:
34
+ config: TrainerConfig instance with training parameters. If None, uses default config.
35
+ tracer: Optional tracer for observability
36
+ trainable_model: Optional trainable model instance
37
+ project_name: Project name for organizing training runs and evaluations
38
+ """
39
+         try:
+             self.config = config
+             self.tracer = tracer
+             self.tracer.show_trace_urls = False
+             self.project_name = project_name or "judgment_training"
+
+             if trainable_model is None:
+                 self.trainable_model = TrainableModel(self.config)
+             else:
+                 self.trainable_model = trainable_model
+
+             self.judgment_client = JudgmentClient()
+         except Exception as e:
+             raise JudgmentAPIError(
+                 f"Failed to initialize JudgmentTrainer: {str(e)}"
+             ) from e
+
+     async def generate_rollouts_and_rewards(
+         self,
+         agent_function: Callable[[Any], Any],
+         scorers: List[Union[APIScorerConfig, BaseScorer]],
+         prompts: List[Any],
+         num_prompts_per_step: Optional[int] = None,
+         num_generations_per_prompt: Optional[int] = None,
+         concurrency: Optional[int] = None,
+     ):
+         """
+         Generate rollouts and compute rewards using the current model snapshot.
+         Each sample contains multiple generations for reinforcement learning optimization.
+
+         Args:
+             agent_function: Function/agent to call for generating responses
+             scorers: List of scorer objects to evaluate responses
+             prompts: List of prompts to use for training
+             num_prompts_per_step: Number of prompts to use per step (defaults to config value, limited by prompts list length)
+             num_generations_per_prompt: Generations per prompt (defaults to config value)
+             concurrency: Concurrency limit (defaults to config value)
+
+         Returns:
+             List of dataset rows containing samples with messages and evaluations
+         """
+         num_prompts_per_step = min(
+             num_prompts_per_step or self.config.num_prompts_per_step, len(prompts)
+         )
+         num_generations_per_prompt = (
+             num_generations_per_prompt or self.config.num_generations_per_prompt
+         )
+         concurrency = concurrency or self.config.concurrency
+
+         semaphore = asyncio.Semaphore(concurrency)
+
+         @self.tracer.observe(span_type="function")
+         async def generate_single_response(prompt_id, generation_id):
+             async with semaphore:
+                 prompt_input = prompts[prompt_id]
+                 response_data = await agent_function(**prompt_input)
+                 messages = response_data.get("messages", [])
+
+                 try:
+                     traced_messages = self.tracer.get_current_message_history()
+                     if traced_messages:
+                         messages = traced_messages
+                 except Exception as e:
+                     print(f"Warning: Failed to get message history from trace: {e}")
+                     pass
+
+                 example = Example(
+                     input=prompt_input,
+                     messages=messages,
+                     actual_output=response_data,
+                 )
+
+                 scoring_results = self.judgment_client.run_evaluation(
+                     examples=[example],
+                     scorers=scorers,
+                     project_name=self.project_name,
+                     eval_run_name=f"training_step_{self.trainable_model.current_step}_prompt_{prompt_id}_gen_{generation_id}",
+                     show_url=False,
+                 )
+
+                 if scoring_results and scoring_results[0].scorers_data:
+                     reward = sum(
+                         scorer_data.score
+                         for scorer_data in scoring_results[0].scorers_data
+                     ) / len(scoring_results[0].scorers_data)
+                 else:
+                     reward = 0.0
+
+                 return {
+                     "prompt_id": prompt_id,
+                     "generation_id": generation_id,
+                     "messages": messages,
+                     "evals": {"score": reward},
+                 }
+
+         coros = []
+         for prompt_id in range(num_prompts_per_step):
+             for generation_id in range(num_generations_per_prompt):
+                 coro = generate_single_response(prompt_id, generation_id)
+                 coros.append(coro)
+
+         with _spinner_progress(f"Generating {len(coros)} rollouts..."):
+             num_completed = 0
+             results = []
+
+             for coro in asyncio.as_completed(coros):
+                 result = await coro
+                 results.append(result)
+                 num_completed += 1
+
+         _print_progress(f"Generated {len(results)} rollouts successfully")
+
+         dataset_rows = []
+         for prompt_id in range(num_prompts_per_step):
+             prompt_generations = [r for r in results if r["prompt_id"] == prompt_id]
+             sample_generations = [
+                 {"messages": gen["messages"], "evals": gen["evals"]}
+                 for gen in prompt_generations
+             ]
+             dataset_rows.append({"samples": sample_generations})
+
+         return dataset_rows
+
+     async def run_reinforcement_learning(
+         self,
+         agent_function: Callable[[Any], Any],
+         scorers: List[Union[APIScorerConfig, BaseScorer]],
+         prompts: List[Any],
+     ) -> ModelConfig:
+         """
+         Run the iterative reinforcement learning fine-tuning loop.
+
+         This method performs multiple steps of reinforcement learning, where each step:
+         1. Advances to the appropriate model snapshot
+         2. Generates rollouts and computes rewards using scorers
+         3. Trains a new model using reinforcement learning
+         4. Waits for training completion
+
+         Args:
+             agent_function: Function/agent to call for generating responses
+             scorers: List of scorer objects to evaluate responses
+             prompts: List of prompts to use for training
+
+         Returns:
+             ModelConfig: Configuration of the trained model for inference and future training
+         """
+
+         _print_progress("Starting reinforcement learning training")
+
+         training_params = {
+             "num_steps": self.config.num_steps,
+             "num_prompts_per_step": self.config.num_prompts_per_step,
+             "num_generations_per_prompt": self.config.num_generations_per_prompt,
+             "epochs": self.config.epochs,
+             "learning_rate": self.config.learning_rate,
+             "accelerator_count": self.config.accelerator_count,
+             "accelerator_type": self.config.accelerator_type,
+             "temperature": self.config.temperature,
+             "max_tokens": self.config.max_tokens,
+         }
+
+         start_step = self.trainable_model.current_step
+
+         for step in range(start_step, self.config.num_steps):
+             step_num = step + 1
+             _print_progress(
+                 f"Starting training step {step_num}", step_num, self.config.num_steps
+             )
+
+             self.trainable_model.advance_to_next_step(step)
+
+             dataset_rows = await self.generate_rollouts_and_rewards(
+                 agent_function, scorers, prompts
+             )
+
+             with _spinner_progress(
+                 "Preparing training dataset", step_num, self.config.num_steps
+             ):
+                 dataset = Dataset.from_list(dataset_rows)
+                 dataset.sync()
+
+             _print_progress(
+                 "Starting reinforcement training", step_num, self.config.num_steps
+             )
+             job = self.trainable_model.perform_reinforcement_step(dataset, step)
+
+             last_state = None
+             with _spinner_progress(
+                 "Training job in progress", step_num, self.config.num_steps
+             ):
+                 while not job.is_completed:
+                     job.raise_if_bad_state()
+                     current_state = job.state
+
+                     if current_state != last_state:
+                         if current_state in ["uploading", "validating"]:
+                             _print_progress_update(
+                                 f"Training job: {current_state} data"
+                             )
+                         elif current_state == "training":
+                             _print_progress_update(
+                                 "Training job: model training in progress"
+                             )
+                         else:
+                             _print_progress_update(f"Training job: {current_state}")
+                         last_state = current_state
+
+                     time.sleep(10)
+                     job = job.get()
+                     if job is None:
+                         raise JudgmentAPIError(
+                             "Training job was deleted while waiting for completion"
+                         )
+
+             _print_progress(
+                 f"Training completed! New model: {job.output_model}",
+                 step_num,
+                 self.config.num_steps,
+             )
+
+             dataset.delete()
+
+         _print_progress("All training steps completed!")
+
+         with _spinner_progress("Deploying final trained model"):
+             self.trainable_model.advance_to_next_step(self.config.num_steps)
+
+         return self.trainable_model.get_model_config(training_params)
+
+     async def train(
+         self,
+         agent_function: Callable[[Any], Any],
+         scorers: List[Union[APIScorerConfig, BaseScorer]],
+         prompts: List[Any],
+         rft_provider: Optional[str] = None,
+     ) -> ModelConfig:
+         """
+         Start the reinforcement learning fine-tuning process.
+
+         This is the main entry point for running the reinforcement learning training.
+
+         Args:
+             agent_function: Function/agent to call for generating responses.
+             scorers: List of scorer objects to evaluate responses
+             prompts: List of prompts to use for training
+             rft_provider: RFT provider to use for training. Currently only "fireworks" is supported.
+                 Support for other providers is planned for future releases.
+
+         Returns:
+             ModelConfig: Configuration of the trained model for future loading
+         """
+         try:
+             if rft_provider is not None:
+                 self.config.rft_provider = rft_provider
+
+             return await self.run_reinforcement_learning(
+                 agent_function, scorers, prompts
+             )
+         except JudgmentAPIError:
+             # Re-raise JudgmentAPIError as-is
+             raise
+         except Exception as e:
+             raise JudgmentAPIError(f"Training process failed: {str(e)}") from e
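
To close the loop, here is a hypothetical end-to-end sketch of driving JudgmentTrainer. Everything beyond the class and method names shown in this diff is an assumption: the import paths, the Tracer and scorer construction, and the exact shape of the agent function. The trainer awaits agent_function(**prompt_input) and reads a "messages" key from the returned dict, so the sketch follows that contract.

# Hypothetical wiring sketch; import paths, scorer, and placeholder values are assumptions.
import asyncio
from judgeval.tracer import Tracer
from judgeval.trainer import TrainerConfig, TrainableModel, JudgmentTrainer  # assumed path

tracer = Tracer()  # constructor arguments depend on the Tracer API
config = TrainerConfig(...)  # same hypothetical fields as the earlier sketch
trainable_model = TrainableModel(config)
my_scorer = ...  # any APIScorerConfig or BaseScorer instance from judgeval.scorers

async def agent_function(**prompt_input):
    # The trainer unpacks each prompt dict as keyword arguments and uses
    # response_data.get("messages", []) plus the whole dict as Example.actual_output.
    messages = [{"role": "user", "content": prompt_input["question"]}]
    completion = trainable_model.chat.completions.create(messages=messages)
    answer = completion.choices[0].message.content  # OpenAI-compatible response shape assumed
    messages.append({"role": "assistant", "content": answer})
    return {"messages": messages, "answer": answer}

prompts = [{"question": "What is reinforcement fine-tuning?"}]  # one kwargs dict per prompt

trainer = JudgmentTrainer(
    config=config,
    trainable_model=trainable_model,
    tracer=tracer,
    project_name="my_training_project",
)

model_config = asyncio.run(
    trainer.train(agent_function=agent_function, scorers=[my_scorer], prompts=prompts)
)
model_config.save_to_file("trained_model.json")  # save_to_file is referenced in the diff above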