judgeval 0.16.9__py3-none-any.whl → 0.22.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of judgeval might be problematic.

Files changed (37)
  1. judgeval/__init__.py +32 -2
  2. judgeval/api/__init__.py +108 -0
  3. judgeval/api/api_types.py +76 -15
  4. judgeval/cli.py +16 -1
  5. judgeval/data/judgment_types.py +76 -20
  6. judgeval/dataset/__init__.py +11 -2
  7. judgeval/env.py +2 -11
  8. judgeval/evaluation/__init__.py +4 -0
  9. judgeval/prompt/__init__.py +330 -0
  10. judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +1 -13
  11. judgeval/tracer/__init__.py +371 -257
  12. judgeval/tracer/constants.py +1 -1
  13. judgeval/tracer/exporters/store.py +32 -16
  14. judgeval/tracer/keys.py +11 -9
  15. judgeval/tracer/llm/llm_anthropic/messages.py +38 -26
  16. judgeval/tracer/llm/llm_anthropic/messages_stream.py +14 -14
  17. judgeval/tracer/llm/llm_google/generate_content.py +9 -7
  18. judgeval/tracer/llm/llm_openai/beta_chat_completions.py +38 -14
  19. judgeval/tracer/llm/llm_openai/chat_completions.py +90 -26
  20. judgeval/tracer/llm/llm_openai/responses.py +88 -26
  21. judgeval/tracer/llm/llm_openai/utils.py +42 -0
  22. judgeval/tracer/llm/llm_together/chat_completions.py +26 -18
  23. judgeval/tracer/managers.py +4 -0
  24. judgeval/trainer/__init__.py +10 -1
  25. judgeval/trainer/base_trainer.py +122 -0
  26. judgeval/trainer/config.py +1 -1
  27. judgeval/trainer/fireworks_trainer.py +396 -0
  28. judgeval/trainer/trainer.py +52 -387
  29. judgeval/utils/guards.py +9 -5
  30. judgeval/utils/project.py +15 -0
  31. judgeval/utils/serialize.py +2 -2
  32. judgeval/version.py +1 -1
  33. {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/METADATA +2 -3
  34. {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/RECORD +37 -32
  35. {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/WHEEL +0 -0
  36. {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/entry_points.txt +0 -0
  37. {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/licenses/LICENSE.md +0 -0
judgeval/trainer/trainer.py CHANGED
@@ -1,405 +1,70 @@
-import asyncio
-import json
-import time
-from typing import Optional, Callable, Any, List, Union, Dict
-from fireworks import Dataset  # type: ignore[import-not-found]
-from .config import TrainerConfig, ModelConfig
+from typing import Optional
+from .config import TrainerConfig
+from .base_trainer import BaseTrainer
+from .fireworks_trainer import FireworksTrainer
 from .trainable_model import TrainableModel
 from judgeval.tracer import Tracer
-from judgeval.tracer.exporters.store import SpanStore
-from judgeval.tracer.exporters import InMemorySpanExporter
-from judgeval.tracer.keys import AttributeKeys
-from judgeval import JudgmentClient
-from judgeval.scorers import ExampleScorer, ExampleAPIScorerConfig
-from judgeval.data import Example
-from .console import _spinner_progress, _print_progress, _print_progress_update
 from judgeval.exceptions import JudgmentRuntimeError


-class JudgmentTrainer:
+def JudgmentTrainer(
+    config: TrainerConfig,
+    trainable_model: TrainableModel,
+    tracer: Tracer,
+    project_name: Optional[str] = None,
+) -> BaseTrainer:
     """
-    A reinforcement learning trainer for Judgment models using Fine-Tuning.
+    Factory function for creating reinforcement learning trainers.

-    This class handles the iterative training process where models are improved
-    through reinforcement learning fine-tuning based on generated rollouts and rewards.
-    """
-
-    def __init__(
-        self,
-        config: TrainerConfig,
-        trainable_model: TrainableModel,
-        tracer: Tracer,
-        project_name: Optional[str] = None,
-    ):
-        """
-        Initialize the JudgmentTrainer.
-
-        Args:
-            config: TrainerConfig instance with training parameters. If None, uses default config.
-            tracer: Optional tracer for observability
-            trainable_model: Optional trainable model instance
-            project_name: Project name for organizing training runs and evaluations
-        """
-        try:
-            self.config = config
-            self.tracer = tracer
-            self.project_name = project_name or "judgment_training"
-            self.trainable_model = trainable_model
-
-            self.judgment_client = JudgmentClient()
-            self.span_store = SpanStore()
-            self.span_exporter = InMemorySpanExporter(self.span_store)
-        except Exception as e:
-            raise JudgmentRuntimeError(
-                f"Failed to initialize JudgmentTrainer: {str(e)}"
-            ) from e
-
-    def _extract_message_history_from_spans(self) -> List[Dict[str, str]]:
-        """
-        Extract message history from spans in the span store for training purposes.
-
-        This method processes trace spans to reconstruct the conversation flow,
-        extracting messages in chronological order from LLM, user, and tool spans.
-
-        Returns:
-            List of message dictionaries with 'role' and 'content' keys
-        """
-        spans = self.span_store.get_all()
-        if not spans:
-            return []
-
-        messages = []
-        first_found = False
-
-        for span in sorted(spans, key=lambda s: getattr(s, "start_time", 0)):
-            span_attributes = span.attributes or {}
-            span_type = span_attributes.get(AttributeKeys.JUDGMENT_SPAN_KIND, "span")
-
-            if (
-                not span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
-                and span_type != "llm"
-            ):
-                continue
-
-            if span_type == "llm":
-                if not first_found and span_attributes.get(
-                    AttributeKeys.JUDGMENT_INPUT
-                ):
-                    input_data: Any = span_attributes.get(
-                        AttributeKeys.JUDGMENT_INPUT, {}
-                    )
-                    if isinstance(input_data, dict) and "messages" in input_data:
-                        input_messages = input_data["messages"]
-                        if input_messages:
-                            first_found = True
-                            for msg in input_messages:
-                                if (
-                                    isinstance(msg, dict)
-                                    and "role" in msg
-                                    and "content" in msg
-                                ):
-                                    messages.append(
-                                        {"role": msg["role"], "content": msg["content"]}
-                                    )
-
-                # Add assistant response from span output
-                output = span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
-                if output is not None:
-                    content = str(output)
-                    try:
-                        parsed = json.loads(content)
-                        if isinstance(parsed, dict) and "messages" in parsed:
-                            # Extract the actual assistant message content
-                            for msg in parsed["messages"]:
-                                if (
-                                    isinstance(msg, dict)
-                                    and msg.get("role") == "assistant"
-                                ):
-                                    content = msg.get("content", content)
-                                    break
-                    except (json.JSONDecodeError, KeyError):
-                        pass
-                    messages.append({"role": "assistant", "content": content})
-
-            elif span_type == "user":
-                output = span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
-                if output is not None:
-                    content = str(output)
-                    try:
-                        parsed = json.loads(content)
-                        if isinstance(parsed, dict) and "messages" in parsed:
-                            for msg in parsed["messages"]:
-                                if isinstance(msg, dict) and msg.get("role") == "user":
-                                    content = msg.get("content", content)
-                                    break
-                    except (json.JSONDecodeError, KeyError):
-                        pass
-                    messages.append({"role": "user", "content": content})
+    This factory creates and returns provider-specific trainer implementations
+    (FireworksTrainer, VerifiersTrainer, etc.) based on the configured RFT provider.

-            elif span_type == "tool":
-                output = span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
-                if output is not None:
-                    content = str(output)
-                    try:
-                        parsed = json.loads(content)
-                        if isinstance(parsed, dict) and "messages" in parsed:
-                            for msg in parsed["messages"]:
-                                if isinstance(msg, dict) and msg.get("role") == "user":
-                                    content = msg.get("content", content)
-                                    break
-                    except (json.JSONDecodeError, KeyError):
-                        pass
-                    messages.append({"role": "user", "content": content})
+    The factory pattern allows for easy extension to support multiple training
+    providers without changing the client-facing API.

-        return messages
-
-    async def generate_rollouts_and_rewards(
-        self,
-        agent_function: Callable[[Any], Any],
-        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
-        prompts: List[Any],
-        num_prompts_per_step: Optional[int] = None,
-        num_generations_per_prompt: Optional[int] = None,
-        concurrency: Optional[int] = None,
-    ):
-        """
-        Generate rollouts and compute rewards using the current model snapshot.
-        Each sample contains multiple generations for reinforcement learning optimization.
-
-        Args:
-            agent_function: Function/agent to call for generating responses
-            scorers: List of scorer objects to evaluate responses
-            prompts: List of prompts to use for training
-            num_prompts_per_step: Number of prompts to use per step (defaults to config value, limited by prompts list length)
-            num_generations_per_prompt: Generations per prompt (defaults to config value)
-            concurrency: Concurrency limit (defaults to config value)
-
-        Returns:
-            List of dataset rows containing samples with messages and evaluations
-        """
-        num_prompts_per_step = min(
-            num_prompts_per_step or self.config.num_prompts_per_step, len(prompts)
-        )
-        num_generations_per_prompt = (
-            num_generations_per_prompt or self.config.num_generations_per_prompt
+    Example:
+        config = TrainerConfig(
+            deployment_id="my-deployment",
+            user_id="my-user",
+            model_id="my-model",
+            rft_provider="fireworks"  # or "verifiers" in the future
         )
-        concurrency = concurrency or self.config.concurrency
-
-        semaphore = asyncio.Semaphore(concurrency)
-
-        @self.tracer.observe(span_type="function")
-        async def generate_single_response(prompt_id, generation_id):
-            async with semaphore:
-                prompt_input = prompts[prompt_id]
-                response_data = await agent_function(**prompt_input)
-                messages = response_data.get("messages", [])
-
-                try:
-                    traced_messages = self._extract_message_history_from_spans()
-                    if traced_messages:
-                        messages = traced_messages
-                except Exception as e:
-                    print(f"Warning: Failed to get message history from trace: {e}")
-                    pass
-
-                finally:
-                    self.span_store.spans = []
-
-                example = Example(
-                    input=prompt_input,
-                    messages=messages,
-                    actual_output=response_data,
-                )
-
-                scoring_results = self.judgment_client.run_evaluation(
-                    examples=[example],
-                    scorers=scorers,
-                    project_name=self.project_name,
-                    eval_run_name=f"training_step_{self.trainable_model.current_step}_prompt_{prompt_id}_gen_{generation_id}",
-                )
-
-                if scoring_results and scoring_results[0].scorers_data:
-                    scores = [
-                        scorer_data.score
-                        for scorer_data in scoring_results[0].scorers_data
-                        if scorer_data.score is not None
-                    ]
-                    reward = sum(scores) / len(scores) if scores else 0.0
-                else:
-                    reward = 0.0
-
-                return {
-                    "prompt_id": prompt_id,
-                    "generation_id": generation_id,
-                    "messages": messages,
-                    "evals": {"score": reward},
-                }
-
-        coros = []
-        for prompt_id in range(num_prompts_per_step):
-            for generation_id in range(num_generations_per_prompt):
-                coro = generate_single_response(prompt_id, generation_id)
-                coros.append(coro)

-        with _spinner_progress(f"Generating {len(coros)} rollouts..."):
-            num_completed = 0
-            results = []
+    # User creates and configures the trainable model
+    trainable_model = TrainableModel(config)
+    tracer = Tracer()

-            for coro in asyncio.as_completed(coros):
-                result = await coro
-                results.append(result)
-                num_completed += 1
+    # JudgmentTrainer automatically creates the appropriate provider-specific trainer
+    trainer = JudgmentTrainer(config, trainable_model, tracer)

-        _print_progress(f"Generated {len(results)} rollouts successfully")
-
-        dataset_rows = []
-        for prompt_id in range(num_prompts_per_step):
-            prompt_generations = [r for r in results if r["prompt_id"] == prompt_id]
-            sample_generations = [
-                {"messages": gen["messages"], "evals": gen["evals"]}
-                for gen in prompt_generations
-            ]
-            dataset_rows.append({"samples": sample_generations})
-
-        return dataset_rows
-
-    async def run_reinforcement_learning(
-        self,
-        agent_function: Callable[[Any], Any],
-        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
-        prompts: List[Any],
-    ) -> ModelConfig:
-        """
-        Run the iterative reinforcement learning fine-tuning loop.
-
-        This method performs multiple steps of reinforcement learning, where each step:
-        1. Advances to the appropriate model snapshot
-        2. Generates rollouts and computes rewards using scorers
-        3. Trains a new model using reinforcement learning
-        4. Waits for training completion
+    # The returned trainer implements the BaseTrainer interface
+    model_config = await trainer.train(agent_function, scorers, prompts)

     Args:
-        agent_function: Function/agent to call for generating responses
-        scorers: List of scorer objects to evaluate responses
-        prompts: List of prompts to use for training
+        config: TrainerConfig instance with training parameters including rft_provider
+        trainable_model: Provider-specific trainable model instance (e.g., TrainableModel for Fireworks)
+        tracer: Tracer for observability
+        project_name: Project name for organizing training runs and evaluations

     Returns:
-        ModelConfig: Configuration of the trained model for inference and future training
-        """
-
-        _print_progress("Starting reinforcement learning training")
-
-        training_params = {
-            "num_steps": self.config.num_steps,
-            "num_prompts_per_step": self.config.num_prompts_per_step,
-            "num_generations_per_prompt": self.config.num_generations_per_prompt,
-            "epochs": self.config.epochs,
-            "learning_rate": self.config.learning_rate,
-            "accelerator_count": self.config.accelerator_count,
-            "accelerator_type": self.config.accelerator_type,
-            "temperature": self.config.temperature,
-            "max_tokens": self.config.max_tokens,
-        }
-
-        start_step = self.trainable_model.current_step
-
-        for step in range(start_step, self.config.num_steps):
-            step_num = step + 1
-            _print_progress(
-                f"Starting training step {step_num}", step_num, self.config.num_steps
-            )
-
-            self.trainable_model.advance_to_next_step(step)
-
-            dataset_rows = await self.generate_rollouts_and_rewards(
-                agent_function, scorers, prompts
-            )
-
-            with _spinner_progress(
-                "Preparing training dataset", step_num, self.config.num_steps
-            ):
-                dataset = Dataset.from_list(dataset_rows)
-                dataset.sync()
-
-            _print_progress(
-                "Starting reinforcement training", step_num, self.config.num_steps
-            )
-            job = self.trainable_model.perform_reinforcement_step(dataset, step)
-
-            last_state = None
-            with _spinner_progress(
-                "Training job in progress", step_num, self.config.num_steps
-            ):
-                while not job.is_completed:
-                    job.raise_if_bad_state()
-                    current_state = job.state
-
-                    if current_state != last_state:
-                        if current_state in ["uploading", "validating"]:
-                            _print_progress_update(
-                                f"Training job: {current_state} data"
-                            )
-                        elif current_state == "training":
-                            _print_progress_update(
-                                "Training job: model training in progress"
-                            )
-                        else:
-                            _print_progress_update(f"Training job: {current_state}")
-                        last_state = current_state
-
-                    time.sleep(10)
-                    job = job.get()
-                    if job is None:
-                        raise JudgmentRuntimeError(
-                            "Training job was deleted while waiting for completion"
-                        )
+        Provider-specific trainer instance (FireworksTrainer, etc.) that implements
+        the BaseTrainer interface

-            _print_progress(
-                f"Training completed! New model: {job.output_model}",
-                step_num,
-                self.config.num_steps,
-            )
-
-            dataset.delete()
-
-        _print_progress("All training steps completed!")
-
-        with _spinner_progress("Deploying final trained model"):
-            self.trainable_model.advance_to_next_step(self.config.num_steps)
-
-        return self.trainable_model.get_model_config(training_params)
-
-    async def train(
-        self,
-        agent_function: Callable[[Any], Any],
-        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
-        prompts: List[Any],
-        rft_provider: Optional[str] = None,
-    ) -> ModelConfig:
-        """
-        Start the reinforcement learning fine-tuning process.
-
-        This is the main entry point for running the reinforcement learning training.
-
-        Args:
-            agent_function: Function/agent to call for generating responses.
-            scorers: List of scorer objects to evaluate responses
-            prompts: List of prompts to use for training
-            rft_provider: RFT provider to use for training. Currently only "fireworks" is supported.
-                Support for other providers is planned for future releases.
-
-        Returns:
-            ModelConfig: Configuration of the trained model for future loading
-        """
-        try:
-            if rft_provider is not None:
-                self.config.rft_provider = rft_provider
-
-            return await self.run_reinforcement_learning(
-                agent_function, scorers, prompts
-            )
-        except JudgmentRuntimeError:
-            # Re-raise JudgmentAPIError as-is
-            raise
-        except Exception as e:
-            raise JudgmentRuntimeError(f"Training process failed: {str(e)}") from e
+    Raises:
+        JudgmentRuntimeError: If the specified provider is not supported
+    """
+    provider = config.rft_provider.lower()
+
+    if provider == "fireworks":
+        return FireworksTrainer(config, trainable_model, tracer, project_name)
+    elif provider == "verifiers":
+        # Placeholder for future implementation
+        raise JudgmentRuntimeError(
+            "Verifiers provider is not yet implemented. "
+            "Currently supported providers: 'fireworks'"
+        )
+    else:
+        raise JudgmentRuntimeError(
+            f"Unsupported RFT provider: '{config.rft_provider}'. "
+            f"Currently supported providers: 'fireworks'"
+        )
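The refactor replaces the old `JudgmentTrainer` class with a factory that dispatches on `TrainerConfig.rft_provider`, and `train()` no longer takes an `rft_provider` argument (see the README change in METADATA below). A minimal sketch of the new dispatch, assuming `JudgmentTrainer`, `TrainableModel`, and `TrainerConfig` are re-exported from `judgeval.trainer` and credentials are configured in the environment:

```python
# Sketch only: the TrainerConfig field values are illustrative placeholders
# taken from the factory's docstring example, not working credentials.
from judgeval.exceptions import JudgmentRuntimeError
from judgeval.tracer import Tracer
from judgeval.trainer import JudgmentTrainer, TrainableModel, TrainerConfig

config = TrainerConfig(
    deployment_id="my-deployment",
    user_id="my-user",
    model_id="my-model",
    rft_provider="verifiers",  # recognized by the factory, but not yet implemented
)

try:
    # Dispatches to FireworksTrainer for "fireworks"; anything else raises.
    trainer = JudgmentTrainer(config, TrainableModel(config), Tracer())
except JudgmentRuntimeError as err:
    print(err)  # "Verifiers provider is not yet implemented. ..."
```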
judgeval/utils/guards.py CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations

 from typing import TYPE_CHECKING
+from judgeval.logger import judgeval_logger

 if TYPE_CHECKING:
     from typing import TypeVar
@@ -8,24 +9,27 @@ if TYPE_CHECKING:
     T = TypeVar("T")


-def expect_exists(value: T | None, message: str) -> T:
-    if value is None:
-        raise ValueError(message)
+def expect_exists(value: T | None, message: str, default: T) -> T:
+    if not value:
+        judgeval_logger.error(message)
+        return default

     return value


-def expect_api_key(api_key: str | None) -> str:
+def expect_api_key(api_key: str | None) -> str | None:
     return expect_exists(
         api_key,
         "API Key is not set, please set JUDGMENT_API_KEY in the environment variables or pass it as `api_key`",
+        default=None,
     )


-def expect_organization_id(organization_id: str | None) -> str:
+def expect_organization_id(organization_id: str | None) -> str | None:
     return expect_exists(
         organization_id,
         "Organization ID is not set, please set JUDGMENT_ORG_ID in the environment variables or pass it as `organization_id`",
+        default=None,
     )

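Note the semantic shift here: `expect_exists` used to raise `ValueError`, but now logs through `judgeval_logger` and returns the supplied default, and the `is None` check became a falsiness check, so empty strings count as missing too. A small sketch of the new behavior, assuming judgeval is importable:

```python
from judgeval.utils.guards import expect_api_key, expect_exists

# 0.16.x raised ValueError on a missing key; 0.22.x logs the message
# and returns the default (None for the api-key/org-id helpers).
assert expect_api_key(None) is None

# The guard now tests falsiness rather than `is None`, so "" also
# falls through to the default.
assert expect_exists("", "value missing", default="fallback") == "fallback"
```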
judgeval/utils/project.py ADDED
@@ -0,0 +1,15 @@
+from judgeval.utils.decorators.dont_throw import dont_throw
+import functools
+from judgeval.api import JudgmentSyncClient
+
+
+@dont_throw
+@functools.lru_cache(maxsize=64)
+def _resolve_project_id(project_name: str, api_key: str, organization_id: str) -> str:
+    """Resolve project_id from project_name using the API."""
+    client = JudgmentSyncClient(
+        api_key=api_key,
+        organization_id=organization_id,
+    )
+    response = client.projects_resolve({"project_name": project_name})
+    return response["project_id"]
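This new helper memoizes project-ID lookups: `functools.lru_cache` keys the cache on the `(project_name, api_key, organization_id)` tuple, so repeated calls with identical arguments skip the HTTP round-trip, while `@dont_throw` appears to swallow API failures rather than propagate them (an assumption from the decorator's name; its implementation is not shown in this diff). A hypothetical usage sketch with placeholder credentials:

```python
from judgeval.utils.project import _resolve_project_id  # private helper

# With valid credentials, the first call issues a projects_resolve request;
# the second identical call is served from the lru_cache, not the API.
pid = _resolve_project_id("my-project", "sk-placeholder", "org-placeholder")
cached = _resolve_project_id("my-project", "sk-placeholder", "org-placeholder")
assert pid == cached
```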
judgeval/utils/serialize.py CHANGED
@@ -247,7 +247,7 @@ encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
 # Seralize arbitrary object to a json string
 def safe_serialize(obj: Any) -> str:
     try:
-        return orjson.dumps(json_encoder(obj)).decode()
+        return orjson.dumps(json_encoder(obj), option=orjson.OPT_NON_STR_KEYS).decode()
     except Exception as e:
         judgeval_logger.warning(f"Error serializing object: {e}")
-        return orjson.dumps(repr(obj)).decode()
+        return repr(obj)
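Two behavior changes land in `safe_serialize`: `orjson.OPT_NON_STR_KEYS` lets payloads with non-string dict keys (ints, UUIDs, etc.) serialize instead of hitting the error path, and the fallback now returns the bare `repr(obj)` rather than a JSON-encoded string of it, so the fallback output is no longer valid JSON. The option itself is standard orjson:

```python
import orjson

# Without the option, orjson rejects non-string dict keys outright.
try:
    orjson.dumps({1: "a"})
except TypeError as err:
    print(err)  # dict keys must be str by default

# With OPT_NON_STR_KEYS the key is coerced to a string, so safe_serialize
# no longer needs its fallback for payloads like {1: "a"}.
print(orjson.dumps({1: "a"}, option=orjson.OPT_NON_STR_KEYS))  # b'{"1":"a"}'
```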
judgeval/version.py CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.16.9"
+__version__ = "0.22.2"


 def get_version() -> str:
{judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: judgeval
-Version: 0.16.9
+Version: 0.22.2
 Summary: Judgeval Package
 Project-URL: Homepage, https://github.com/JudgmentLabs/judgeval
 Project-URL: Issues, https://github.com/JudgmentLabs/judgeval/issues
@@ -63,8 +63,7 @@ Judgeval's agent monitoring infra provides a simple harness for integrating GRPO
 await trainer.train(
     agent_function=your_agent_function, # entry point to your agent
     scorers=[RewardScorer()], # Custom scorer you define based on task criteria, acts as reward
-    prompts=training_prompts, # Tasks
-    rft_provider="fireworks"
+    prompts=training_prompts # Tasks
 )
 ```