judgeval 0.16.9__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Warning: this release of judgeval has been flagged as potentially problematic.

judgeval/trainer/__init__.py CHANGED
@@ -1,5 +1,14 @@
 from judgeval.trainer.trainer import JudgmentTrainer
 from judgeval.trainer.config import TrainerConfig, ModelConfig
 from judgeval.trainer.trainable_model import TrainableModel
+from judgeval.trainer.base_trainer import BaseTrainer
+from judgeval.trainer.fireworks_trainer import FireworksTrainer
 
-__all__ = ["JudgmentTrainer", "TrainerConfig", "ModelConfig", "TrainableModel"]
+__all__ = [
+    "JudgmentTrainer",
+    "TrainerConfig",
+    "ModelConfig",
+    "TrainableModel",
+    "BaseTrainer",
+    "FireworksTrainer",
+]
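
With the expanded `__all__`, both the abstract interface and the Fireworks implementation are importable straight from `judgeval.trainer`. A minimal sketch of consumer code against 0.17.0 (the `describe` helper is hypothetical, not part of the package):

```python
from judgeval.trainer import (
    BaseTrainer,       # new in 0.17.0: abstract provider interface
    FireworksTrainer,  # new in 0.17.0: Fireworks AI implementation
    JudgmentTrainer,   # now a factory function (see trainer.py below)
    TrainerConfig,
)

# Hypothetical helper: any provider-specific trainer can be handled
# uniformly through the BaseTrainer interface.
def describe(trainer: BaseTrainer) -> str:
    return f"{type(trainer).__name__} on {trainer.config.base_model_name}"
```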
judgeval/trainer/base_trainer.py ADDED
@@ -0,0 +1,117 @@
+from abc import ABC, abstractmethod
+from typing import Any, Callable, List, Optional, Union, Dict, TYPE_CHECKING
+from .config import TrainerConfig, ModelConfig
+from judgeval.scorers import ExampleScorer, ExampleAPIScorerConfig
+
+if TYPE_CHECKING:
+    from judgeval.tracer import Tracer
+    from .trainable_model import TrainableModel
+
+
+class BaseTrainer(ABC):
+    """
+    Abstract base class for training providers.
+
+    This class defines the interface that all training provider implementations
+    must follow. Each provider (Fireworks, Verifiers, etc.) will have its own
+    concrete implementation of this interface.
+    """
+
+    def __init__(
+        self,
+        config: TrainerConfig,
+        trainable_model: "TrainableModel",
+        tracer: "Tracer",
+        project_name: Optional[str] = None,
+    ):
+        """
+        Initialize the base trainer.
+
+        Args:
+            config: TrainerConfig instance with training parameters
+            trainable_model: TrainableModel instance to use for training
+            tracer: Tracer for observability
+            project_name: Project name for organizing training runs
+        """
+        self.config = config
+        self.trainable_model = trainable_model
+        self.tracer = tracer
+        self.project_name = project_name or "judgment_training"
+
+    @abstractmethod
+    async def generate_rollouts_and_rewards(
+        self,
+        agent_function: Callable[[Any], Any],
+        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
+        prompts: List[Any],
+        num_prompts_per_step: Optional[int] = None,
+        num_generations_per_prompt: Optional[int] = None,
+        concurrency: Optional[int] = None,
+    ) -> Any:
+        """
+        Generate rollouts and compute rewards using the current model snapshot.
+
+        Args:
+            agent_function: Function/agent to call for generating responses
+            scorers: List of scorer objects to evaluate responses
+            prompts: List of prompts to use for training
+            num_prompts_per_step: Number of prompts to use per step
+            num_generations_per_prompt: Generations per prompt
+            concurrency: Concurrency limit
+
+        Returns:
+            Provider-specific dataset format for training
+        """
+        pass
+
+    @abstractmethod
+    async def run_reinforcement_learning(
+        self,
+        agent_function: Callable[[Any], Any],
+        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
+        prompts: List[Any],
+    ) -> ModelConfig:
+        """
+        Run the iterative reinforcement learning fine-tuning loop.
+
+        Args:
+            agent_function: Function/agent to call for generating responses
+            scorers: List of scorer objects to evaluate responses
+            prompts: List of prompts to use for training
+
+        Returns:
+            ModelConfig: Configuration of the trained model
+        """
+        pass
+
+    @abstractmethod
+    async def train(
+        self,
+        agent_function: Callable[[Any], Any],
+        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
+        prompts: List[Any],
+    ) -> ModelConfig:
+        """
+        Start the reinforcement learning fine-tuning process.
+
+        This is the main entry point for running the training.
+
+        Args:
+            agent_function: Function/agent to call for generating responses
+            scorers: List of scorer objects to evaluate responses
+            prompts: List of prompts to use for training
+
+        Returns:
+            ModelConfig: Configuration of the trained model
+        """
+        pass
+
+    @abstractmethod
+    def _extract_message_history_from_spans(self) -> List[Dict[str, str]]:
+        """
+        Extract message history from spans for training purposes.
+
+        Returns:
+            List of message dictionaries with 'role' and 'content' keys
+        """
+        pass
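
Because `BaseTrainer` is an ABC, adding a provider means implementing the four abstract methods above. A hypothetical skeleton, using the `VerifiersTrainer` name that appears as a placeholder later in this diff (the body is illustrative, not shipped code):

```python
from typing import Any, Dict, List

from judgeval.trainer import BaseTrainer, ModelConfig


class VerifiersTrainer(BaseTrainer):
    """Illustrative skeleton; 0.17.0 ships no Verifiers implementation."""

    async def generate_rollouts_and_rewards(
        self, agent_function, scorers, prompts,
        num_prompts_per_step=None, num_generations_per_prompt=None,
        concurrency=None,
    ) -> Any:
        raise NotImplementedError  # collect provider-specific rollouts here

    async def run_reinforcement_learning(
        self, agent_function, scorers, prompts
    ) -> ModelConfig:
        raise NotImplementedError  # drive the provider's RFT loop here

    async def train(self, agent_function, scorers, prompts) -> ModelConfig:
        return await self.run_reinforcement_learning(agent_function, scorers, prompts)

    def _extract_message_history_from_spans(self) -> List[Dict[str, str]]:
        return []  # reconstruct conversation history from trace spans
```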
judgeval/trainer/config.py CHANGED
@@ -16,7 +16,7 @@ class TrainerConfig:
     user_id: str
     model_id: str
     base_model_name: str = "qwen2p5-7b-instruct"
-    rft_provider: str = "fireworks"
+    rft_provider: str = "fireworks"  # Supported: "fireworks", "verifiers" (future)
     num_steps: int = 5
     num_generations_per_prompt: int = 4
     num_prompts_per_step: int = 4
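
The new comment on `rft_provider` documents the single switch that selects the training backend. A hedged construction sketch (field values are placeholders; `deployment_id` is taken from the factory docstring shown later in this diff):

```python
from judgeval.trainer import TrainerConfig

config = TrainerConfig(
    deployment_id="my-deployment",  # placeholder values
    user_id="my-user",
    model_id="my-model",
    rft_provider="fireworks",  # the only provider accepted in 0.17.0
)
# Unset fields fall back to the dataclass defaults visible above:
# base_model_name="qwen2p5-7b-instruct", num_steps=5,
# num_generations_per_prompt=4, num_prompts_per_step=4.
```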
judgeval/trainer/fireworks_trainer.py ADDED
@@ -0,0 +1,381 @@
+import asyncio
+import json
+from typing import Optional, Callable, Any, List, Union, Dict
+from fireworks import Dataset  # type: ignore[import-not-found]
+from .config import TrainerConfig, ModelConfig
+from .base_trainer import BaseTrainer
+from .trainable_model import TrainableModel
+from judgeval.tracer import Tracer
+from judgeval.tracer.exporters.store import SpanStore
+from judgeval.tracer.exporters import InMemorySpanExporter
+from judgeval.tracer.keys import AttributeKeys
+from judgeval import JudgmentClient
+from judgeval.scorers import ExampleScorer, ExampleAPIScorerConfig
+from judgeval.data import Example
+from .console import _spinner_progress, _print_progress, _print_progress_update
+from judgeval.exceptions import JudgmentRuntimeError
+
+
+class FireworksTrainer(BaseTrainer):
+    """
+    Fireworks AI implementation of the training provider.
+
+    This trainer uses Fireworks AI's infrastructure for reinforcement learning
+    fine-tuning (RFT) of language models.
+    """
+
+    def __init__(
+        self,
+        config: TrainerConfig,
+        trainable_model: TrainableModel,
+        tracer: Tracer,
+        project_name: Optional[str] = None,
+    ):
+        """
+        Initialize the FireworksTrainer.
+
+        Args:
+            config: TrainerConfig instance with training parameters
+            trainable_model: TrainableModel instance for Fireworks training
+            tracer: Tracer for observability
+            project_name: Project name for organizing training runs and evaluations
+        """
+        try:
+            super().__init__(config, trainable_model, tracer, project_name)
+
+            self.judgment_client = JudgmentClient()
+            self.span_store = SpanStore()
+            self.span_exporter = InMemorySpanExporter(self.span_store)
+        except Exception as e:
+            raise JudgmentRuntimeError(
+                f"Failed to initialize FireworksTrainer: {str(e)}"
+            ) from e
+
+    def _extract_message_history_from_spans(self) -> List[Dict[str, str]]:
+        """
+        Extract message history from spans in the span store for training purposes.
+
+        This method processes trace spans to reconstruct the conversation flow,
+        extracting messages in chronological order from LLM, user, and tool spans.
+
+        Returns:
+            List of message dictionaries with 'role' and 'content' keys
+        """
+        spans = self.span_store.get_all()
+        if not spans:
+            return []
+
+        messages = []
+        first_found = False
+
+        for span in sorted(spans, key=lambda s: getattr(s, "start_time", 0)):
+            span_attributes = span.attributes or {}
+            span_type = span_attributes.get(AttributeKeys.JUDGMENT_SPAN_KIND, "span")
+
+            if (
+                not span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
+                and span_type != "llm"
+            ):
+                continue
+
+            if span_type == "llm":
+                if not first_found and span_attributes.get(
+                    AttributeKeys.JUDGMENT_INPUT
+                ):
+                    input_data: Any = span_attributes.get(
+                        AttributeKeys.JUDGMENT_INPUT, {}
+                    )
+                    if isinstance(input_data, dict) and "messages" in input_data:
+                        input_messages = input_data["messages"]
+                        if input_messages:
+                            first_found = True
+                            for msg in input_messages:
+                                if (
+                                    isinstance(msg, dict)
+                                    and "role" in msg
+                                    and "content" in msg
+                                ):
+                                    messages.append(
+                                        {"role": msg["role"], "content": msg["content"]}
+                                    )
+
+                # Add assistant response from span output
+                output = span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
+                if output is not None:
+                    content = str(output)
+                    try:
+                        parsed = json.loads(content)
+                        if isinstance(parsed, dict) and "messages" in parsed:
+                            # Extract the actual assistant message content
+                            for msg in parsed["messages"]:
+                                if (
+                                    isinstance(msg, dict)
+                                    and msg.get("role") == "assistant"
+                                ):
+                                    content = msg.get("content", content)
+                                    break
+                    except (json.JSONDecodeError, KeyError):
+                        pass
+                    messages.append({"role": "assistant", "content": content})
+
+            elif span_type in ("user", "tool"):
+                output = span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
+                if output is not None:
+                    content = str(output)
+                    try:
+                        parsed = json.loads(content)
+                        if isinstance(parsed, dict) and "messages" in parsed:
+                            for msg in parsed["messages"]:
+                                if isinstance(msg, dict) and msg.get("role") == "user":
+                                    content = msg.get("content", content)
+                                    break
+                    except (json.JSONDecodeError, KeyError):
+                        pass
+                    messages.append({"role": "user", "content": content})
+
+        return messages
+
+    async def generate_rollouts_and_rewards(
+        self,
+        agent_function: Callable[[Any], Any],
+        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
+        prompts: List[Any],
+        num_prompts_per_step: Optional[int] = None,
+        num_generations_per_prompt: Optional[int] = None,
+        concurrency: Optional[int] = None,
+    ):
+        """
+        Generate rollouts and compute rewards using the current model snapshot.
+        Each sample contains multiple generations for reinforcement learning optimization.
+
+        Args:
+            agent_function: Function/agent to call for generating responses
+            scorers: List of scorer objects to evaluate responses
+            prompts: List of prompts to use for training
+            num_prompts_per_step: Number of prompts to use per step (defaults to config value, limited by prompts list length)
+            num_generations_per_prompt: Generations per prompt (defaults to config value)
+            concurrency: Concurrency limit (defaults to config value)
+
+        Returns:
+            List of dataset rows containing samples with messages and evaluations
+        """
+        num_prompts_per_step = min(
+            num_prompts_per_step or self.config.num_prompts_per_step, len(prompts)
+        )
+        num_generations_per_prompt = (
+            num_generations_per_prompt or self.config.num_generations_per_prompt
+        )
+        concurrency = concurrency or self.config.concurrency
+
+        semaphore = asyncio.Semaphore(concurrency)
+
+        @self.tracer.observe(span_type="function")
+        async def generate_single_response(prompt_id, generation_id):
+            async with semaphore:
+                prompt_input = prompts[prompt_id]
+                response_data = await agent_function(**prompt_input)
+                messages = response_data.get("messages", [])
+
+                try:
+                    traced_messages = self._extract_message_history_from_spans()
+                    if traced_messages:
+                        messages = traced_messages
+                except Exception as e:
+                    print(f"Warning: Failed to get message history from trace: {e}")
+                    pass
+
+                finally:
+                    self.span_store.spans = []
+
+                example = Example(
+                    input=prompt_input,
+                    messages=messages,
+                    actual_output=response_data,
+                )
+
+                scoring_results = self.judgment_client.run_evaluation(
+                    examples=[example],
+                    scorers=scorers,
+                    project_name=self.project_name,
+                    eval_run_name=f"training_step_{self.trainable_model.current_step}_prompt_{prompt_id}_gen_{generation_id}",
+                )
+
+                if scoring_results and scoring_results[0].scorers_data:
+                    scores = [
+                        scorer_data.score
+                        for scorer_data in scoring_results[0].scorers_data
+                        if scorer_data.score is not None
+                    ]
+                    reward = sum(scores) / len(scores) if scores else 0.0
+                else:
+                    reward = 0.0
+
+                return {
+                    "prompt_id": prompt_id,
+                    "generation_id": generation_id,
+                    "messages": messages,
+                    "evals": {"score": reward},
+                }
+
+        coros = []
+        for prompt_id in range(num_prompts_per_step):
+            for generation_id in range(num_generations_per_prompt):
+                coro = generate_single_response(prompt_id, generation_id)
+                coros.append(coro)
+
+        with _spinner_progress(f"Generating {len(coros)} rollouts..."):
+            num_completed = 0
+            results = []
+
+            for coro in asyncio.as_completed(coros):
+                result = await coro
+                results.append(result)
+                num_completed += 1
+
+        _print_progress(f"Generated {len(results)} rollouts successfully")
+
+        dataset_rows = []
+        for prompt_id in range(num_prompts_per_step):
+            prompt_generations = [r for r in results if r["prompt_id"] == prompt_id]
+            sample_generations = [
+                {"messages": gen["messages"], "evals": gen["evals"]}
+                for gen in prompt_generations
+            ]
+            dataset_rows.append({"samples": sample_generations})
+
+        return dataset_rows
+
+    async def run_reinforcement_learning(
+        self,
+        agent_function: Callable[[Any], Any],
+        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
+        prompts: List[Any],
+    ) -> ModelConfig:
+        """
+        Run the iterative reinforcement learning fine-tuning loop.
+
+        This method performs multiple steps of reinforcement learning, where each step:
+        1. Advances to the appropriate model snapshot
+        2. Generates rollouts and computes rewards using scorers
+        3. Trains a new model using reinforcement learning
+        4. Waits for training completion
+
+        Args:
+            agent_function: Function/agent to call for generating responses
+            scorers: List of scorer objects to evaluate responses
+            prompts: List of prompts to use for training
+
+        Returns:
+            ModelConfig: Configuration of the trained model for inference and future training
+        """
+
+        _print_progress("Starting reinforcement learning training")
+
+        training_params = {
+            "num_steps": self.config.num_steps,
+            "num_prompts_per_step": self.config.num_prompts_per_step,
+            "num_generations_per_prompt": self.config.num_generations_per_prompt,
+            "epochs": self.config.epochs,
+            "learning_rate": self.config.learning_rate,
+            "accelerator_count": self.config.accelerator_count,
+            "accelerator_type": self.config.accelerator_type,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+        }
+
+        start_step = self.trainable_model.current_step
+
+        for step in range(start_step, self.config.num_steps):
+            step_num = step + 1
+            _print_progress(
+                f"Starting training step {step_num}", step_num, self.config.num_steps
+            )
+
+            self.trainable_model.advance_to_next_step(step)
+
+            dataset_rows = await self.generate_rollouts_and_rewards(
+                agent_function, scorers, prompts
+            )
+
+            with _spinner_progress(
+                "Preparing training dataset", step_num, self.config.num_steps
+            ):
+                dataset = Dataset.from_list(dataset_rows)
+                dataset.sync()
+
+            _print_progress(
+                "Starting reinforcement training", step_num, self.config.num_steps
+            )
+            job = self.trainable_model.perform_reinforcement_step(dataset, step)
+
+            last_state = None
+            with _spinner_progress(
+                "Training job in progress", step_num, self.config.num_steps
+            ):
+                while not job.is_completed:
+                    job.raise_if_bad_state()
+                    current_state = job.state
+
+                    if current_state != last_state:
+                        if current_state in ["uploading", "validating"]:
+                            _print_progress_update(
+                                f"Training job: {current_state} data"
+                            )
+                        elif current_state == "training":
+                            _print_progress_update(
+                                "Training job: model training in progress"
+                            )
+                        else:
+                            _print_progress_update(f"Training job: {current_state}")
+                        last_state = current_state
+
+                    await asyncio.sleep(10)
+                    job = job.get()
+                    if job is None:
+                        raise JudgmentRuntimeError(
+                            "Training job was deleted while waiting for completion"
+                        )
+
+            _print_progress(
+                f"Training completed! New model: {job.output_model}",
+                step_num,
+                self.config.num_steps,
+            )
+
+            dataset.delete()
+
+        _print_progress("All training steps completed!")
+
+        with _spinner_progress("Deploying final trained model"):
+            self.trainable_model.advance_to_next_step(self.config.num_steps)
+
+        return self.trainable_model.get_model_config(training_params)
+
+    async def train(
+        self,
+        agent_function: Callable[[Any], Any],
+        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
+        prompts: List[Any],
+    ) -> ModelConfig:
+        """
+        Start the reinforcement learning fine-tuning process.
+
+        This is the main entry point for running the reinforcement learning training.
+
+        Args:
+            agent_function: Function/agent to call for generating responses.
+            scorers: List of scorer objects to evaluate responses
+            prompts: List of prompts to use for training
+
+        Returns:
+            ModelConfig: Configuration of the trained model for future loading
+        """
+        try:
+            return await self.run_reinforcement_learning(
+                agent_function, scorers, prompts
+            )
+        except JudgmentRuntimeError:
+            # Re-raise JudgmentRuntimeError as-is
+            raise
+        except Exception as e:
+            raise JudgmentRuntimeError(f"Training process failed: {str(e)}") from e
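
Two details of the implementation above are easy to miss: the reward for each rollout is the mean of all non-None scorer scores (falling back to 0.0), and each dataset row groups every generation for one prompt under a `samples` key. A sketch of one row as handed to `Dataset.from_list`, with illustrative values:

```python
# One dataset row with num_generations_per_prompt=2: both rollouts of the
# same prompt are grouped so the RFT step can compare them.
row = {
    "samples": [
        {
            "messages": [
                {"role": "user", "content": "Book a table for two at 7pm."},
                {"role": "assistant", "content": "Done: reserved for 7pm."},
            ],
            "evals": {"score": 0.75},  # mean of non-None scorer scores
        },
        {
            "messages": [
                {"role": "user", "content": "Book a table for two at 7pm."},
                {"role": "assistant", "content": "I cannot help with that."},
            ],
            "evals": {"score": 0.0},  # 0.0 is also the fallback when no scores return
        },
    ]
}
```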
judgeval/trainer/trainer.py CHANGED
@@ -1,405 +1,70 @@
-import asyncio
-import json
-import time
-from typing import Optional, Callable, Any, List, Union, Dict
-from fireworks import Dataset  # type: ignore[import-not-found]
-from .config import TrainerConfig, ModelConfig
+from typing import Optional
+from .config import TrainerConfig
+from .base_trainer import BaseTrainer
+from .fireworks_trainer import FireworksTrainer
 from .trainable_model import TrainableModel
 from judgeval.tracer import Tracer
-from judgeval.tracer.exporters.store import SpanStore
-from judgeval.tracer.exporters import InMemorySpanExporter
-from judgeval.tracer.keys import AttributeKeys
-from judgeval import JudgmentClient
-from judgeval.scorers import ExampleScorer, ExampleAPIScorerConfig
-from judgeval.data import Example
-from .console import _spinner_progress, _print_progress, _print_progress_update
 from judgeval.exceptions import JudgmentRuntimeError
 
 
-class JudgmentTrainer:
+def JudgmentTrainer(
+    config: TrainerConfig,
+    trainable_model: TrainableModel,
+    tracer: Tracer,
+    project_name: Optional[str] = None,
+) -> BaseTrainer:
     """
-    A reinforcement learning trainer for Judgment models using Fine-Tuning.
+    Factory function for creating reinforcement learning trainers.
 
-    This class handles the iterative training process where models are improved
-    through reinforcement learning fine-tuning based on generated rollouts and rewards.
-    """
-
-    def __init__(
-        self,
-        config: TrainerConfig,
-        trainable_model: TrainableModel,
-        tracer: Tracer,
-        project_name: Optional[str] = None,
-    ):
-        """
-        Initialize the JudgmentTrainer.
-
-        Args:
-            config: TrainerConfig instance with training parameters. If None, uses default config.
-            tracer: Optional tracer for observability
-            trainable_model: Optional trainable model instance
-            project_name: Project name for organizing training runs and evaluations
-        """
-        try:
-            self.config = config
-            self.tracer = tracer
-            self.project_name = project_name or "judgment_training"
-            self.trainable_model = trainable_model
-
-            self.judgment_client = JudgmentClient()
-            self.span_store = SpanStore()
-            self.span_exporter = InMemorySpanExporter(self.span_store)
-        except Exception as e:
-            raise JudgmentRuntimeError(
-                f"Failed to initialize JudgmentTrainer: {str(e)}"
-            ) from e
-
-    def _extract_message_history_from_spans(self) -> List[Dict[str, str]]:
-        """
-        Extract message history from spans in the span store for training purposes.
-
-        This method processes trace spans to reconstruct the conversation flow,
-        extracting messages in chronological order from LLM, user, and tool spans.
-
-        Returns:
-            List of message dictionaries with 'role' and 'content' keys
-        """
-        spans = self.span_store.get_all()
-        if not spans:
-            return []
-
-        messages = []
-        first_found = False
-
-        for span in sorted(spans, key=lambda s: getattr(s, "start_time", 0)):
-            span_attributes = span.attributes or {}
-            span_type = span_attributes.get(AttributeKeys.JUDGMENT_SPAN_KIND, "span")
-
-            if (
-                not span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
-                and span_type != "llm"
-            ):
-                continue
-
-            if span_type == "llm":
-                if not first_found and span_attributes.get(
-                    AttributeKeys.JUDGMENT_INPUT
-                ):
-                    input_data: Any = span_attributes.get(
-                        AttributeKeys.JUDGMENT_INPUT, {}
-                    )
-                    if isinstance(input_data, dict) and "messages" in input_data:
-                        input_messages = input_data["messages"]
-                        if input_messages:
-                            first_found = True
-                            for msg in input_messages:
-                                if (
-                                    isinstance(msg, dict)
-                                    and "role" in msg
-                                    and "content" in msg
-                                ):
-                                    messages.append(
-                                        {"role": msg["role"], "content": msg["content"]}
-                                    )
-
-                # Add assistant response from span output
-                output = span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
-                if output is not None:
-                    content = str(output)
-                    try:
-                        parsed = json.loads(content)
-                        if isinstance(parsed, dict) and "messages" in parsed:
-                            # Extract the actual assistant message content
-                            for msg in parsed["messages"]:
-                                if (
-                                    isinstance(msg, dict)
-                                    and msg.get("role") == "assistant"
-                                ):
-                                    content = msg.get("content", content)
-                                    break
-                    except (json.JSONDecodeError, KeyError):
-                        pass
-                    messages.append({"role": "assistant", "content": content})
-
-            elif span_type == "user":
-                output = span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
-                if output is not None:
-                    content = str(output)
-                    try:
-                        parsed = json.loads(content)
-                        if isinstance(parsed, dict) and "messages" in parsed:
-                            for msg in parsed["messages"]:
-                                if isinstance(msg, dict) and msg.get("role") == "user":
-                                    content = msg.get("content", content)
-                                    break
-                    except (json.JSONDecodeError, KeyError):
-                        pass
-                    messages.append({"role": "user", "content": content})
+    This factory creates and returns provider-specific trainer implementations
+    (FireworksTrainer, VerifiersTrainer, etc.) based on the configured RFT provider.
 
-            elif span_type == "tool":
-                output = span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
-                if output is not None:
-                    content = str(output)
-                    try:
-                        parsed = json.loads(content)
-                        if isinstance(parsed, dict) and "messages" in parsed:
-                            for msg in parsed["messages"]:
-                                if isinstance(msg, dict) and msg.get("role") == "user":
-                                    content = msg.get("content", content)
-                                    break
-                    except (json.JSONDecodeError, KeyError):
-                        pass
-                    messages.append({"role": "user", "content": content})
+    The factory pattern allows for easy extension to support multiple training
+    providers without changing the client-facing API.
 
-        return messages
-
-    async def generate_rollouts_and_rewards(
-        self,
-        agent_function: Callable[[Any], Any],
-        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
-        prompts: List[Any],
-        num_prompts_per_step: Optional[int] = None,
-        num_generations_per_prompt: Optional[int] = None,
-        concurrency: Optional[int] = None,
-    ):
-        """
-        Generate rollouts and compute rewards using the current model snapshot.
-        Each sample contains multiple generations for reinforcement learning optimization.
-
-        Args:
-            agent_function: Function/agent to call for generating responses
-            scorers: List of scorer objects to evaluate responses
-            prompts: List of prompts to use for training
-            num_prompts_per_step: Number of prompts to use per step (defaults to config value, limited by prompts list length)
-            num_generations_per_prompt: Generations per prompt (defaults to config value)
-            concurrency: Concurrency limit (defaults to config value)
-
-        Returns:
-            List of dataset rows containing samples with messages and evaluations
-        """
-        num_prompts_per_step = min(
-            num_prompts_per_step or self.config.num_prompts_per_step, len(prompts)
-        )
-        num_generations_per_prompt = (
-            num_generations_per_prompt or self.config.num_generations_per_prompt
+    Example:
+        config = TrainerConfig(
+            deployment_id="my-deployment",
+            user_id="my-user",
+            model_id="my-model",
+            rft_provider="fireworks"  # or "verifiers" in the future
         )
-        concurrency = concurrency or self.config.concurrency
-
-        semaphore = asyncio.Semaphore(concurrency)
-
-        @self.tracer.observe(span_type="function")
-        async def generate_single_response(prompt_id, generation_id):
-            async with semaphore:
-                prompt_input = prompts[prompt_id]
-                response_data = await agent_function(**prompt_input)
-                messages = response_data.get("messages", [])
-
-                try:
-                    traced_messages = self._extract_message_history_from_spans()
-                    if traced_messages:
-                        messages = traced_messages
-                except Exception as e:
-                    print(f"Warning: Failed to get message history from trace: {e}")
-                    pass
-
-                finally:
-                    self.span_store.spans = []
-
-                example = Example(
-                    input=prompt_input,
-                    messages=messages,
-                    actual_output=response_data,
-                )
-
-                scoring_results = self.judgment_client.run_evaluation(
-                    examples=[example],
-                    scorers=scorers,
-                    project_name=self.project_name,
-                    eval_run_name=f"training_step_{self.trainable_model.current_step}_prompt_{prompt_id}_gen_{generation_id}",
-                )
-
-                if scoring_results and scoring_results[0].scorers_data:
-                    scores = [
-                        scorer_data.score
-                        for scorer_data in scoring_results[0].scorers_data
-                        if scorer_data.score is not None
-                    ]
-                    reward = sum(scores) / len(scores) if scores else 0.0
-                else:
-                    reward = 0.0
-
-                return {
-                    "prompt_id": prompt_id,
-                    "generation_id": generation_id,
-                    "messages": messages,
-                    "evals": {"score": reward},
-                }
-
-        coros = []
-        for prompt_id in range(num_prompts_per_step):
-            for generation_id in range(num_generations_per_prompt):
-                coro = generate_single_response(prompt_id, generation_id)
-                coros.append(coro)
 
-        with _spinner_progress(f"Generating {len(coros)} rollouts..."):
-            num_completed = 0
-            results = []
+        # User creates and configures the trainable model
+        trainable_model = TrainableModel(config)
+        tracer = Tracer()
 
-            for coro in asyncio.as_completed(coros):
-                result = await coro
-                results.append(result)
-                num_completed += 1
+        # JudgmentTrainer automatically creates the appropriate provider-specific trainer
+        trainer = JudgmentTrainer(config, trainable_model, tracer)
 
-        _print_progress(f"Generated {len(results)} rollouts successfully")
-
-        dataset_rows = []
-        for prompt_id in range(num_prompts_per_step):
-            prompt_generations = [r for r in results if r["prompt_id"] == prompt_id]
-            sample_generations = [
-                {"messages": gen["messages"], "evals": gen["evals"]}
-                for gen in prompt_generations
-            ]
-            dataset_rows.append({"samples": sample_generations})
-
-        return dataset_rows
-
-    async def run_reinforcement_learning(
-        self,
-        agent_function: Callable[[Any], Any],
-        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
-        prompts: List[Any],
-    ) -> ModelConfig:
-        """
-        Run the iterative reinforcement learning fine-tuning loop.
-
-        This method performs multiple steps of reinforcement learning, where each step:
-        1. Advances to the appropriate model snapshot
-        2. Generates rollouts and computes rewards using scorers
-        3. Trains a new model using reinforcement learning
-        4. Waits for training completion
+        # The returned trainer implements the BaseTrainer interface
+        model_config = await trainer.train(agent_function, scorers, prompts)
 
     Args:
-        agent_function: Function/agent to call for generating responses
-        scorers: List of scorer objects to evaluate responses
-        prompts: List of prompts to use for training
+        config: TrainerConfig instance with training parameters including rft_provider
+        trainable_model: Provider-specific trainable model instance (e.g., TrainableModel for Fireworks)
+        tracer: Tracer for observability
+        project_name: Project name for organizing training runs and evaluations
 
     Returns:
-        ModelConfig: Configuration of the trained model for inference and future training
-        """
-
-        _print_progress("Starting reinforcement learning training")
-
-        training_params = {
-            "num_steps": self.config.num_steps,
-            "num_prompts_per_step": self.config.num_prompts_per_step,
-            "num_generations_per_prompt": self.config.num_generations_per_prompt,
-            "epochs": self.config.epochs,
-            "learning_rate": self.config.learning_rate,
-            "accelerator_count": self.config.accelerator_count,
-            "accelerator_type": self.config.accelerator_type,
-            "temperature": self.config.temperature,
-            "max_tokens": self.config.max_tokens,
-        }
-
-        start_step = self.trainable_model.current_step
-
-        for step in range(start_step, self.config.num_steps):
-            step_num = step + 1
-            _print_progress(
-                f"Starting training step {step_num}", step_num, self.config.num_steps
-            )
-
-            self.trainable_model.advance_to_next_step(step)
-
-            dataset_rows = await self.generate_rollouts_and_rewards(
-                agent_function, scorers, prompts
-            )
-
-            with _spinner_progress(
-                "Preparing training dataset", step_num, self.config.num_steps
-            ):
-                dataset = Dataset.from_list(dataset_rows)
-                dataset.sync()
-
-            _print_progress(
-                "Starting reinforcement training", step_num, self.config.num_steps
-            )
-            job = self.trainable_model.perform_reinforcement_step(dataset, step)
-
-            last_state = None
-            with _spinner_progress(
-                "Training job in progress", step_num, self.config.num_steps
-            ):
-                while not job.is_completed:
-                    job.raise_if_bad_state()
-                    current_state = job.state
-
-                    if current_state != last_state:
-                        if current_state in ["uploading", "validating"]:
-                            _print_progress_update(
-                                f"Training job: {current_state} data"
-                            )
-                        elif current_state == "training":
-                            _print_progress_update(
-                                "Training job: model training in progress"
-                            )
-                        else:
-                            _print_progress_update(f"Training job: {current_state}")
-                        last_state = current_state
-
-                    time.sleep(10)
-                    job = job.get()
-                    if job is None:
-                        raise JudgmentRuntimeError(
-                            "Training job was deleted while waiting for completion"
-                        )
+        Provider-specific trainer instance (FireworksTrainer, etc.) that implements
+        the BaseTrainer interface
 
-            _print_progress(
-                f"Training completed! New model: {job.output_model}",
-                step_num,
-                self.config.num_steps,
-            )
-
-            dataset.delete()
-
-        _print_progress("All training steps completed!")
-
-        with _spinner_progress("Deploying final trained model"):
-            self.trainable_model.advance_to_next_step(self.config.num_steps)
-
-        return self.trainable_model.get_model_config(training_params)
-
-    async def train(
-        self,
-        agent_function: Callable[[Any], Any],
-        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
-        prompts: List[Any],
-        rft_provider: Optional[str] = None,
-    ) -> ModelConfig:
-        """
-        Start the reinforcement learning fine-tuning process.
-
-        This is the main entry point for running the reinforcement learning training.
-
-        Args:
-            agent_function: Function/agent to call for generating responses.
-            scorers: List of scorer objects to evaluate responses
-            prompts: List of prompts to use for training
-            rft_provider: RFT provider to use for training. Currently only "fireworks" is supported.
-                Support for other providers is planned for future releases.
-
-        Returns:
-            ModelConfig: Configuration of the trained model for future loading
-        """
-        try:
-            if rft_provider is not None:
-                self.config.rft_provider = rft_provider
-
-            return await self.run_reinforcement_learning(
-                agent_function, scorers, prompts
-            )
-        except JudgmentRuntimeError:
-            # Re-raise JudgmentAPIError as-is
-            raise
-        except Exception as e:
-            raise JudgmentRuntimeError(f"Training process failed: {str(e)}") from e
+    Raises:
+        JudgmentRuntimeError: If the specified provider is not supported
+    """
+    provider = config.rft_provider.lower()
+
+    if provider == "fireworks":
+        return FireworksTrainer(config, trainable_model, tracer, project_name)
+    elif provider == "verifiers":
+        # Placeholder for future implementation
+        raise JudgmentRuntimeError(
+            "Verifiers provider is not yet implemented. "
+            "Currently supported providers: 'fireworks'"
+        )
+    else:
+        raise JudgmentRuntimeError(
+            f"Unsupported RFT provider: '{config.rft_provider}'. "
+            f"Currently supported providers: 'fireworks'"
+        )
judgeval/version.py CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.16.9"
+__version__ = "0.17.0"
 
 
 def get_version() -> str:
judgeval-0.16.9.dist-info/METADATA → judgeval-0.17.0.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: judgeval
-Version: 0.16.9
+Version: 0.17.0
 Summary: Judgeval Package
 Project-URL: Homepage, https://github.com/JudgmentLabs/judgeval
 Project-URL: Issues, https://github.com/JudgmentLabs/judgeval/issues
@@ -63,8 +63,7 @@ Judgeval's agent monitoring infra provides a simple harness for integrating GRPO
 await trainer.train(
     agent_function=your_agent_function, # entry point to your agent
     scorers=[RewardScorer()], # Custom scorer you define based on task criteria, acts as reward
-    prompts=training_prompts, # Tasks
-    rft_provider="fireworks"
+    prompts=training_prompts # Tasks
 )
 ```
 
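This README change mirrors the API move above: `rft_provider` is no longer a keyword of `train()` and now lives on `TrainerConfig`. A hedged before/after sketch for upgrading callers (`RewardScorer`, `your_agent_function`, and `training_prompts` are the README's own placeholders):

```python
# 0.16.9: provider chosen per call
await trainer.train(
    agent_function=your_agent_function,
    scorers=[RewardScorer()],
    prompts=training_prompts,
    rft_provider="fireworks",
)

# 0.17.0: provider chosen once on the config; the JudgmentTrainer
# factory then returns the matching BaseTrainer implementation
config = TrainerConfig(
    deployment_id="my-deployment",  # placeholder values
    user_id="my-user",
    model_id="my-model",
    rft_provider="fireworks",
)
trainer = JudgmentTrainer(config, trainable_model, tracer)
await trainer.train(
    agent_function=your_agent_function,
    scorers=[RewardScorer()],
    prompts=training_prompts,
)
```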
judgeval-0.16.9.dist-info/RECORD → judgeval-0.17.0.dist-info/RECORD RENAMED
@@ -4,7 +4,7 @@ judgeval/constants.py,sha256=JZZJ1MqzZZDVk-5PRPRbmLnM8mXI-RDL5vxa1JFuscs,3408
 judgeval/env.py,sha256=37Mn4g0OkpFxXCZGlO_CLqKJnyX-jx_R24tC28XJzig,2112
 judgeval/exceptions.py,sha256=tTbfe4yoOtPXmn22UQz9-6a-5PT9uOko85xaRRwr0Sw,621
 judgeval/logger.py,sha256=VP5blbsJ53mvJbNHfBf5p2KrARUrkrErpPkB-__Hh3U,1562
-judgeval/version.py,sha256=na4SICn1_ldveglTM2Suf3pZLRnw2qbMJMUmIhGkh0Q,74
+judgeval/version.py,sha256=vPcSY2o-MH6v7gn4Fzt6yeb_jPUs2hu117IC_EWy33g,74
 judgeval/warnings.py,sha256=LbGte14ppiFjrkp-JJYueZ40NWFvMkWRvPXr6r-fUWw,73
 judgeval/api/__init__.py,sha256=ho8L4wC9y-STYEpk5zHwc2mZJhC4ezW8jiGgOIERBVY,12058
 judgeval/api/api_types.py,sha256=xOHcgK8NTHMuBr1HBHlCvoSYldVOtG8DQsXeo23-YQk,8874
@@ -74,11 +74,13 @@ judgeval/tracer/llm/llm_together/chat_completions.py,sha256=YxVL1zqG7Tjoss0BH3hm
 judgeval/tracer/llm/llm_together/config.py,sha256=jCJY0KQcHJZZJk2vq038GKIDUMusqgvRjQ0B6OV5uEc,150
 judgeval/tracer/llm/llm_together/wrapper.py,sha256=HFqy_MabQeSq8oj2diZhEuk1SDt_hDfk5MFdPn9MFhg,1733
 judgeval/tracer/processors/__init__.py,sha256=BdOOPOD1RfMI5YHW76DNPKR07EAev-JxoolZ3KaXNNU,7100
-judgeval/trainer/__init__.py,sha256=h_DDVV7HFF7HUPAJFpt2d9wjqgnmEVcHxqZyB1k7pPQ,257
-judgeval/trainer/config.py,sha256=sAAVBgeoFDJWYjGIgOvoQoiO0gtqNAOI6MHncwdN_mk,4292
+judgeval/trainer/__init__.py,sha256=nJo913vFdss3E_PR-M1OUjznS0SYgNZ-MP-Y_6Mj5PA,437
+judgeval/trainer/base_trainer.py,sha256=21adIMmYyn7XKbiI1Dc6N5thPbuH5wK7vVfrtoFX6Ys,3886
+judgeval/trainer/config.py,sha256=7ZSwr6p7vq0MRadh9axm6XB-RAotdWqULZ5yDl0xGbQ,4340
 judgeval/trainer/console.py,sha256=SvokkFEU-K1vLV4Rd1m6YJJ7HyYwTr4Azdzwx_JPZUY,4351
+judgeval/trainer/fireworks_trainer.py,sha256=FqGoS1OzmxzyT0134e_EW3pgzFNO04GpKST4NcjYSyU,15432
 judgeval/trainer/trainable_model.py,sha256=T-Sioi_sXtfYlcu3lE0cd60PHs8DrYaZ-Kxb4h1nU04,8993
-judgeval/trainer/trainer.py,sha256=FBhHq2YPooKADDCC_IEKex81L6a5quCmAMyl9mn3QLk,16675
+judgeval/trainer/trainer.py,sha256=twLEHNaomelTg6ZYG6veI9OpB3wzhPCtPVQMTnDZWx4,2626
 judgeval/utils/async_utils.py,sha256=AF1xdu8Ao5GyhFvfaLOaKJHn1RISyXZ4U70UZe9zfBA,1083
 judgeval/utils/file_utils.py,sha256=vq-n5WZEZjVbZ5S9QTkW8nSH6Pvw-Jx0ttsQ1t0wnPQ,3140
 judgeval/utils/guards.py,sha256=QBb6m6KElxdvt2bskLZCKh_zGHbBcqV-VfGzT63o3hY,807
@@ -100,8 +102,8 @@ judgeval/utils/wrappers/mutable_wrap_async.py,sha256=stHISOUCGFUJXY8seXmxUo4ZpMF
 judgeval/utils/wrappers/mutable_wrap_sync.py,sha256=t5jygAQ1vqhy8s1GfiLeYygYgaLTgfoYASN47U5JiPs,2888
 judgeval/utils/wrappers/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 judgeval/utils/wrappers/utils.py,sha256=j18vaa6JWDw2s3nQy1z5PfV_9Xxio-bVARaHG_0XyL0,1228
-judgeval-0.16.9.dist-info/METADATA,sha256=OiLnf6tEWwnFyLkEjqBbqORUSfcTgjJSyK9nFr6dxHo,11513
-judgeval-0.16.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-judgeval-0.16.9.dist-info/entry_points.txt,sha256=-eoeD-oDLn4A7MSgeBS9Akwanf3_0r0cgEleBcIOjg0,46
-judgeval-0.16.9.dist-info/licenses/LICENSE.md,sha256=tKmCg7k5QOmxPK19XMfzim04QiQJPmgIm0pAn55IJwk,11352
-judgeval-0.16.9.dist-info/RECORD,,
+judgeval-0.17.0.dist-info/METADATA,sha256=0A2L0alaZoA7KR-b43_IZlD9IolcBSwyVJj8Db-DC20,11483
+judgeval-0.17.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+judgeval-0.17.0.dist-info/entry_points.txt,sha256=-eoeD-oDLn4A7MSgeBS9Akwanf3_0r0cgEleBcIOjg0,46
+judgeval-0.17.0.dist-info/licenses/LICENSE.md,sha256=tKmCg7k5QOmxPK19XMfzim04QiQJPmgIm0pAn55IJwk,11352
+judgeval-0.17.0.dist-info/RECORD,,