judgeval 0.16.9__py3-none-any.whl → 0.22.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of judgeval might be problematic.
- judgeval/__init__.py +32 -2
- judgeval/api/__init__.py +108 -0
- judgeval/api/api_types.py +76 -15
- judgeval/cli.py +16 -1
- judgeval/data/judgment_types.py +76 -20
- judgeval/dataset/__init__.py +11 -2
- judgeval/env.py +2 -11
- judgeval/evaluation/__init__.py +4 -0
- judgeval/prompt/__init__.py +330 -0
- judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +1 -13
- judgeval/tracer/__init__.py +371 -257
- judgeval/tracer/constants.py +1 -1
- judgeval/tracer/exporters/store.py +32 -16
- judgeval/tracer/keys.py +11 -9
- judgeval/tracer/llm/llm_anthropic/messages.py +38 -26
- judgeval/tracer/llm/llm_anthropic/messages_stream.py +14 -14
- judgeval/tracer/llm/llm_google/generate_content.py +9 -7
- judgeval/tracer/llm/llm_openai/beta_chat_completions.py +38 -14
- judgeval/tracer/llm/llm_openai/chat_completions.py +90 -26
- judgeval/tracer/llm/llm_openai/responses.py +88 -26
- judgeval/tracer/llm/llm_openai/utils.py +42 -0
- judgeval/tracer/llm/llm_together/chat_completions.py +26 -18
- judgeval/tracer/managers.py +4 -0
- judgeval/trainer/__init__.py +10 -1
- judgeval/trainer/base_trainer.py +122 -0
- judgeval/trainer/config.py +1 -1
- judgeval/trainer/fireworks_trainer.py +396 -0
- judgeval/trainer/trainer.py +52 -387
- judgeval/utils/guards.py +9 -5
- judgeval/utils/project.py +15 -0
- judgeval/utils/serialize.py +2 -2
- judgeval/version.py +1 -1
- {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/METADATA +2 -3
- {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/RECORD +37 -32
- {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/WHEEL +0 -0
- {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/entry_points.txt +0 -0
- {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/licenses/LICENSE.md +0 -0
judgeval/trainer/fireworks_trainer.py (new file):

@@ -0,0 +1,396 @@

```python
import asyncio
import json
from typing import Optional, Callable, Any, List, Union, Dict
from fireworks import Dataset  # type: ignore[import-not-found]
from .config import TrainerConfig, ModelConfig
from .base_trainer import BaseTrainer
from .trainable_model import TrainableModel
from judgeval.tracer import Tracer
from judgeval.tracer.exporters.store import SpanStore
from judgeval.tracer.exporters import InMemorySpanExporter
from judgeval.tracer.keys import AttributeKeys
from judgeval import JudgmentClient
from judgeval.scorers import ExampleScorer, ExampleAPIScorerConfig
from judgeval.data import Example
from .console import _spinner_progress, _print_progress, _print_progress_update
from judgeval.exceptions import JudgmentRuntimeError


class FireworksTrainer(BaseTrainer):
    """
    Fireworks AI implementation of the training provider.

    This trainer uses Fireworks AI's infrastructure for reinforcement learning
    fine-tuning (RFT) of language models.
    """

    def __init__(
        self,
        config: TrainerConfig,
        trainable_model: TrainableModel,
        tracer: Tracer,
        project_name: Optional[str] = None,
    ):
        """
        Initialize the FireworksTrainer.

        Args:
            config: TrainerConfig instance with training parameters
            trainable_model: TrainableModel instance for Fireworks training
            tracer: Tracer for observability
            project_name: Project name for organizing training runs and evaluations
        """
        try:
            super().__init__(config, trainable_model, tracer, project_name)

            self.judgment_client = JudgmentClient()
            self.span_store = SpanStore()
            self.span_exporter = InMemorySpanExporter(self.span_store)
        except Exception as e:
            raise JudgmentRuntimeError(
                f"Failed to initialize FireworksTrainer: {str(e)}"
            ) from e
```
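For orientation, the sketch below shows how these pieces might be wired together. The TrainerConfig field names match attributes read later in this file (num_steps, num_prompts_per_step, num_generations_per_prompt, concurrency, and so on), but the actual TrainerConfig, TrainableModel, and Tracer constructor signatures are not part of this diff, so the keyword arguments here are assumptions rather than documented API.

```python
# Hypothetical wiring; constructor arguments are assumptions (see note above).
from judgeval.tracer import Tracer
from judgeval.trainer.config import TrainerConfig
from judgeval.trainer.trainable_model import TrainableModel
from judgeval.trainer.fireworks_trainer import FireworksTrainer

tracer = Tracer()                      # real constructor arguments not shown in this diff
config = TrainerConfig(                # field names taken from usages later in this file
    num_steps=3,
    num_prompts_per_step=4,
    num_generations_per_prompt=4,
    concurrency=8,
)
trainable_model = TrainableModel(...)  # construction details are not in this diff

trainer = FireworksTrainer(config, trainable_model, tracer, project_name="rft-demo")
```

The diff continues with the helper that rebuilds a chat transcript from traced spans: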
```python
    def _extract_message_history_from_spans(
        self, trace_id: str
    ) -> List[Dict[str, str]]:
        """
        Extract message history from spans in the span store for training purposes.

        This method processes trace spans to reconstruct the conversation flow,
        extracting messages in chronological order from LLM, user, and tool spans.

        Args:
            trace_id: The trace ID (32-char hex string) to extract message history from

        Returns:
            List of message dictionaries with 'role' and 'content' keys
        """
        spans = self.span_store.get_by_trace_id(trace_id)
        if not spans:
            return []

        messages = []
        first_found = False

        for span in sorted(spans, key=lambda s: getattr(s, "start_time", 0)):
            span_attributes = span.attributes or {}
            span_type = span_attributes.get(AttributeKeys.JUDGMENT_SPAN_KIND, "span")

            if (
                not span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
                and span_type != "llm"
            ):
                continue

            if span_type == "llm":
                if not first_found and span_attributes.get(
                    AttributeKeys.JUDGMENT_INPUT
                ):
                    input_data: Any = span_attributes.get(
                        AttributeKeys.JUDGMENT_INPUT, {}
                    )
                    if isinstance(input_data, dict) and "messages" in input_data:
                        input_messages = input_data["messages"]
                        if input_messages:
                            first_found = True
                            for msg in input_messages:
                                if (
                                    isinstance(msg, dict)
                                    and "role" in msg
                                    and "content" in msg
                                ):
                                    messages.append(
                                        {"role": msg["role"], "content": msg["content"]}
                                    )

                # Add assistant response from span output
                output = span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
                if output is not None:
                    content = str(output)
                    try:
                        parsed = json.loads(content)
                        if isinstance(parsed, dict) and "messages" in parsed:
                            # Extract the actual assistant message content
                            for msg in parsed["messages"]:
                                if (
                                    isinstance(msg, dict)
                                    and msg.get("role") == "assistant"
                                ):
                                    content = msg.get("content", content)
                                    break
                    except (json.JSONDecodeError, KeyError):
                        pass
                    messages.append({"role": "assistant", "content": content})

            elif span_type in ("user", "tool"):
                output = span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
                if output is not None:
                    content = str(output)
                    try:
                        parsed = json.loads(content)
                        if isinstance(parsed, dict) and "messages" in parsed:
                            for msg in parsed["messages"]:
                                if isinstance(msg, dict) and msg.get("role") == "user":
                                    content = msg.get("content", content)
                                    break
                    except (json.JSONDecodeError, KeyError):
                        pass
                    messages.append({"role": "user", "content": content})

        return messages
```
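The helper returns a flat chat transcript. For a single-turn trace it would produce something like the following (values are illustrative, not real data):

```python
# Illustrative return value of _extract_message_history_from_spans:
[
    {"role": "system", "content": "You are a support assistant."},
    {"role": "user", "content": "Summarize this ticket."},
    {"role": "assistant", "content": "The customer reports a login failure after the 2.3 update."},
]
```

Next comes the rollout generator that turns these transcripts into scored training samples: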
```python
    async def generate_rollouts_and_rewards(
        self,
        agent_function: Callable[[Any], Any],
        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
        prompts: List[Any],
        num_prompts_per_step: Optional[int] = None,
        num_generations_per_prompt: Optional[int] = None,
        concurrency: Optional[int] = None,
    ):
        """
        Generate rollouts and compute rewards using the current model snapshot.
        Each sample contains multiple generations for reinforcement learning optimization.

        Args:
            agent_function: Function/agent to call for generating responses
            scorers: List of scorer objects to evaluate responses
            prompts: List of prompts to use for training
            num_prompts_per_step: Number of prompts to use per step (defaults to config value, limited by prompts list length)
            num_generations_per_prompt: Generations per prompt (defaults to config value)
            concurrency: Concurrency limit (defaults to config value)

        Returns:
            List of dataset rows containing samples with messages and evaluations
        """
        num_prompts_per_step = min(
            num_prompts_per_step or self.config.num_prompts_per_step, len(prompts)
        )
        num_generations_per_prompt = (
            num_generations_per_prompt or self.config.num_generations_per_prompt
        )
        concurrency = concurrency or self.config.concurrency

        semaphore = asyncio.Semaphore(concurrency)

        @self.tracer.observe(span_type="function")
        async def generate_single_response(prompt_id, generation_id):
            async with semaphore:
                prompt_input = prompts[prompt_id]
                response_data = await agent_function(**prompt_input)
                messages = response_data.get("messages", [])

                current_span = self.tracer.get_current_span()
                trace_id = None
                if current_span and current_span.is_recording():
                    # Convert trace_id to hex string per OTEL spec
                    trace_id = format(current_span.get_span_context().trace_id, "032x")

                try:
                    if trace_id is not None:
                        traced_messages = self._extract_message_history_from_spans(
                            trace_id
                        )
                        if traced_messages:
                            messages = traced_messages
                except Exception as e:
                    print(f"Warning: Failed to get message history from trace: {e}")
                    pass

                finally:
                    if trace_id is not None:
                        self.span_store.clear_trace(trace_id)

                example = Example(
                    input=prompt_input,
                    messages=messages,
                    actual_output=response_data,
                )

                scoring_results = self.judgment_client.run_evaluation(
                    examples=[example],
                    scorers=scorers,
                    project_name=self.project_name,
                    eval_run_name=f"training_step_{self.trainable_model.current_step}_prompt_{prompt_id}_gen_{generation_id}",
                )

                if scoring_results and scoring_results[0].scorers_data:
                    scores = [
                        scorer_data.score
                        for scorer_data in scoring_results[0].scorers_data
                        if scorer_data.score is not None
                    ]
                    reward = sum(scores) / len(scores) if scores else 0.0
                else:
                    reward = 0.0

                return {
                    "prompt_id": prompt_id,
                    "generation_id": generation_id,
                    "messages": messages,
                    "evals": {"score": reward},
                }

        coros = []
        for prompt_id in range(num_prompts_per_step):
            for generation_id in range(num_generations_per_prompt):
                coro = generate_single_response(prompt_id, generation_id)
                coros.append(coro)

        with _spinner_progress(f"Generating {len(coros)} rollouts..."):
            num_completed = 0
            results = []

            for coro in asyncio.as_completed(coros):
                result = await coro
                results.append(result)
                num_completed += 1

        _print_progress(f"Generated {len(results)} rollouts successfully")

        dataset_rows = []
        for prompt_id in range(num_prompts_per_step):
            prompt_generations = [r for r in results if r["prompt_id"] == prompt_id]
            sample_generations = [
                {"messages": gen["messages"], "evals": gen["evals"]}
                for gen in prompt_generations
            ]
            dataset_rows.append({"samples": sample_generations})

        return dataset_rows
```
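Each prompt becomes one dataset row whose "samples" hold every generation produced for that prompt, and each generation's reward is the mean of its non-None scorer scores (for example, scores of 0.5 and 1.0 average to a reward of 0.75). Illustratively, with two generations per prompt a row would look like this (contents made up):

```python
# Illustrative shape of one entry in dataset_rows:
{
    "samples": [
        {"messages": [...], "evals": {"score": 0.75}},
        {"messages": [...], "evals": {"score": 0.25}},
    ]
}
```

The outer training loop and the public train() entry point close out the file: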
```python
    async def run_reinforcement_learning(
        self,
        agent_function: Callable[[Any], Any],
        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
        prompts: List[Any],
    ) -> ModelConfig:
        """
        Run the iterative reinforcement learning fine-tuning loop.

        This method performs multiple steps of reinforcement learning, where each step:
        1. Advances to the appropriate model snapshot
        2. Generates rollouts and computes rewards using scorers
        3. Trains a new model using reinforcement learning
        4. Waits for training completion

        Args:
            agent_function: Function/agent to call for generating responses
            scorers: List of scorer objects to evaluate responses
            prompts: List of prompts to use for training

        Returns:
            ModelConfig: Configuration of the trained model for inference and future training
        """

        _print_progress("Starting reinforcement learning training")

        training_params = {
            "num_steps": self.config.num_steps,
            "num_prompts_per_step": self.config.num_prompts_per_step,
            "num_generations_per_prompt": self.config.num_generations_per_prompt,
            "epochs": self.config.epochs,
            "learning_rate": self.config.learning_rate,
            "accelerator_count": self.config.accelerator_count,
            "accelerator_type": self.config.accelerator_type,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
        }

        start_step = self.trainable_model.current_step

        for step in range(start_step, self.config.num_steps):
            step_num = step + 1
            _print_progress(
                f"Starting training step {step_num}", step_num, self.config.num_steps
            )

            self.trainable_model.advance_to_next_step(step)

            dataset_rows = await self.generate_rollouts_and_rewards(
                agent_function, scorers, prompts
            )

            with _spinner_progress(
                "Preparing training dataset", step_num, self.config.num_steps
            ):
                dataset = Dataset.from_list(dataset_rows)
                dataset.sync()

            _print_progress(
                "Starting reinforcement training", step_num, self.config.num_steps
            )
            job = self.trainable_model.perform_reinforcement_step(dataset, step)

            last_state = None
            with _spinner_progress(
                "Training job in progress", step_num, self.config.num_steps
            ):
                while not job.is_completed:
                    job.raise_if_bad_state()
                    current_state = job.state

                    if current_state != last_state:
                        if current_state in ["uploading", "validating"]:
                            _print_progress_update(
                                f"Training job: {current_state} data"
                            )
                        elif current_state == "training":
                            _print_progress_update(
                                "Training job: model training in progress"
                            )
                        else:
                            _print_progress_update(f"Training job: {current_state}")
                        last_state = current_state

                    await asyncio.sleep(10)
                    job = job.get()
                    if job is None:
                        raise JudgmentRuntimeError(
                            "Training job was deleted while waiting for completion"
                        )

            _print_progress(
                f"Training completed! New model: {job.output_model}",
                step_num,
                self.config.num_steps,
            )

            dataset.delete()

        _print_progress("All training steps completed!")

        with _spinner_progress("Deploying final trained model"):
            self.trainable_model.advance_to_next_step(self.config.num_steps)

        return self.trainable_model.get_model_config(training_params)

    async def train(
        self,
        agent_function: Callable[[Any], Any],
        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
        prompts: List[Any],
    ) -> ModelConfig:
        """
        Start the reinforcement learning fine-tuning process.

        This is the main entry point for running the reinforcement learning training.

        Args:
            agent_function: Function/agent to call for generating responses.
            scorers: List of scorer objects to evaluate responses
            prompts: List of prompts to use for training

        Returns:
            ModelConfig: Configuration of the trained model for future loading
        """
        try:
            return await self.run_reinforcement_learning(
                agent_function, scorers, prompts
            )
        except JudgmentRuntimeError:
            # Re-raise JudgmentRuntimeError as-is
            raise
        except Exception as e:
            raise JudgmentRuntimeError(f"Training process failed: {str(e)}") from e
```
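Putting it together, a minimal driver might look like the sketch below. Here `trainer` is a FireworksTrainer instance (see the constructor sketch earlier), `my_scorer` stands in for an ExampleScorer or ExampleAPIScorerConfig instance whose construction is not shown in this diff, and the agent's actual model call is elided.

```python
# Hypothetical end-to-end usage; `trainer` and `my_scorer` are assumed to exist.
import asyncio


async def agent_function(question: str):
    # Each prompt dict is splatted into this call; the trainer expects a dict
    # with a "messages" key and prefers the traced span history when available.
    answer = "..."  # call the current model snapshot here
    return {
        "messages": [
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer},
        ]
    }


prompts = [{"question": "Explain reinforcement fine-tuning in one paragraph."}]

model_config = asyncio.run(
    trainer.train(agent_function=agent_function, scorers=[my_scorer], prompts=prompts)
)
```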