judgeval 0.16.9__py3-none-any.whl → 0.22.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of judgeval might be problematic. Click here for more details.
- judgeval/__init__.py +32 -2
- judgeval/api/__init__.py +108 -0
- judgeval/api/api_types.py +76 -15
- judgeval/cli.py +16 -1
- judgeval/data/judgment_types.py +76 -20
- judgeval/dataset/__init__.py +11 -2
- judgeval/env.py +2 -11
- judgeval/evaluation/__init__.py +4 -0
- judgeval/prompt/__init__.py +330 -0
- judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +1 -13
- judgeval/tracer/__init__.py +371 -257
- judgeval/tracer/constants.py +1 -1
- judgeval/tracer/exporters/store.py +32 -16
- judgeval/tracer/keys.py +11 -9
- judgeval/tracer/llm/llm_anthropic/messages.py +38 -26
- judgeval/tracer/llm/llm_anthropic/messages_stream.py +14 -14
- judgeval/tracer/llm/llm_google/generate_content.py +9 -7
- judgeval/tracer/llm/llm_openai/beta_chat_completions.py +38 -14
- judgeval/tracer/llm/llm_openai/chat_completions.py +90 -26
- judgeval/tracer/llm/llm_openai/responses.py +88 -26
- judgeval/tracer/llm/llm_openai/utils.py +42 -0
- judgeval/tracer/llm/llm_together/chat_completions.py +26 -18
- judgeval/tracer/managers.py +4 -0
- judgeval/trainer/__init__.py +10 -1
- judgeval/trainer/base_trainer.py +122 -0
- judgeval/trainer/config.py +1 -1
- judgeval/trainer/fireworks_trainer.py +396 -0
- judgeval/trainer/trainer.py +52 -387
- judgeval/utils/guards.py +9 -5
- judgeval/utils/project.py +15 -0
- judgeval/utils/serialize.py +2 -2
- judgeval/version.py +1 -1
- {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/METADATA +2 -3
- {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/RECORD +37 -32
- {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/WHEEL +0 -0
- {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/entry_points.txt +0 -0
- {judgeval-0.16.9.dist-info → judgeval-0.22.2.dist-info}/licenses/LICENSE.md +0 -0
|
@@ -63,7 +63,7 @@ def _wrap_non_streaming_sync(
|
|
|
63
63
|
ctx["span"] = tracer.get_tracer().start_span(
|
|
64
64
|
"TOGETHER_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
|
|
65
65
|
)
|
|
66
|
-
tracer.
|
|
66
|
+
tracer._inject_judgment_context(ctx["span"])
|
|
67
67
|
set_span_attribute(
|
|
68
68
|
ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
|
|
69
69
|
)
|
|
@@ -73,7 +73,7 @@ def _wrap_non_streaming_sync(
|
|
|
73
73
|
)
|
|
74
74
|
ctx["model_name"] = prefixed_model_name
|
|
75
75
|
set_span_attribute(
|
|
76
|
-
ctx["span"], AttributeKeys.
|
|
76
|
+
ctx["span"], AttributeKeys.JUDGMENT_LLM_MODEL_NAME, prefixed_model_name
|
|
77
77
|
)
|
|
78
78
|
|
|
79
79
|
def post_hook(ctx: Dict[str, Any], result: ChatCompletionResponse) -> None:
|
|
@@ -90,10 +90,12 @@ def _wrap_non_streaming_sync(
|
|
|
90
90
|
result.usage
|
|
91
91
|
)
|
|
92
92
|
set_span_attribute(
|
|
93
|
-
span,
|
|
93
|
+
span,
|
|
94
|
+
AttributeKeys.JUDGMENT_USAGE_NON_CACHED_INPUT_TOKENS,
|
|
95
|
+
prompt_tokens,
|
|
94
96
|
)
|
|
95
97
|
set_span_attribute(
|
|
96
|
-
span, AttributeKeys.
|
|
98
|
+
span, AttributeKeys.JUDGMENT_USAGE_OUTPUT_TOKENS, completion_tokens
|
|
97
99
|
)
|
|
98
100
|
set_span_attribute(
|
|
99
101
|
span,
|
|
@@ -103,7 +105,7 @@ def _wrap_non_streaming_sync(
|
|
|
103
105
|
|
|
104
106
|
set_span_attribute(
|
|
105
107
|
span,
|
|
106
|
-
AttributeKeys.
|
|
108
|
+
AttributeKeys.JUDGMENT_LLM_MODEL_NAME,
|
|
107
109
|
ctx["model_name"],
|
|
108
110
|
)
|
|
109
111
|
|
|
@@ -133,7 +135,7 @@ def _wrap_streaming_sync(
|
|
|
133
135
|
ctx["span"] = tracer.get_tracer().start_span(
|
|
134
136
|
"TOGETHER_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
|
|
135
137
|
)
|
|
136
|
-
tracer.
|
|
138
|
+
tracer._inject_judgment_context(ctx["span"])
|
|
137
139
|
set_span_attribute(
|
|
138
140
|
ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
|
|
139
141
|
)
|
|
@@ -143,7 +145,7 @@ def _wrap_streaming_sync(
|
|
|
143
145
|
)
|
|
144
146
|
ctx["model_name"] = prefixed_model_name
|
|
145
147
|
set_span_attribute(
|
|
146
|
-
ctx["span"], AttributeKeys.
|
|
148
|
+
ctx["span"], AttributeKeys.JUDGMENT_LLM_MODEL_NAME, prefixed_model_name
|
|
147
149
|
)
|
|
148
150
|
ctx["accumulated_content"] = ""
|
|
149
151
|
|
|
@@ -171,10 +173,12 @@ def _wrap_streaming_sync(
|
|
|
171
173
|
chunk.usage
|
|
172
174
|
)
|
|
173
175
|
set_span_attribute(
|
|
174
|
-
span,
|
|
176
|
+
span,
|
|
177
|
+
AttributeKeys.JUDGMENT_USAGE_NON_CACHED_INPUT_TOKENS,
|
|
178
|
+
prompt_tokens,
|
|
175
179
|
)
|
|
176
180
|
set_span_attribute(
|
|
177
|
-
span, AttributeKeys.
|
|
181
|
+
span, AttributeKeys.JUDGMENT_USAGE_OUTPUT_TOKENS, completion_tokens
|
|
178
182
|
)
|
|
179
183
|
set_span_attribute(
|
|
180
184
|
span,
|
|
@@ -239,7 +243,7 @@ def _wrap_non_streaming_async(
|
|
|
239
243
|
ctx["span"] = tracer.get_tracer().start_span(
|
|
240
244
|
"TOGETHER_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
|
|
241
245
|
)
|
|
242
|
-
tracer.
|
|
246
|
+
tracer._inject_judgment_context(ctx["span"])
|
|
243
247
|
set_span_attribute(
|
|
244
248
|
ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
|
|
245
249
|
)
|
|
@@ -249,7 +253,7 @@ def _wrap_non_streaming_async(
|
|
|
249
253
|
)
|
|
250
254
|
ctx["model_name"] = prefixed_model_name
|
|
251
255
|
set_span_attribute(
|
|
252
|
-
ctx["span"], AttributeKeys.
|
|
256
|
+
ctx["span"], AttributeKeys.JUDGMENT_LLM_MODEL_NAME, prefixed_model_name
|
|
253
257
|
)
|
|
254
258
|
|
|
255
259
|
def post_hook(ctx: Dict[str, Any], result: ChatCompletionResponse) -> None:
|
|
@@ -266,10 +270,12 @@ def _wrap_non_streaming_async(
|
|
|
266
270
|
result.usage
|
|
267
271
|
)
|
|
268
272
|
set_span_attribute(
|
|
269
|
-
span,
|
|
273
|
+
span,
|
|
274
|
+
AttributeKeys.JUDGMENT_USAGE_NON_CACHED_INPUT_TOKENS,
|
|
275
|
+
prompt_tokens,
|
|
270
276
|
)
|
|
271
277
|
set_span_attribute(
|
|
272
|
-
span, AttributeKeys.
|
|
278
|
+
span, AttributeKeys.JUDGMENT_USAGE_OUTPUT_TOKENS, completion_tokens
|
|
273
279
|
)
|
|
274
280
|
set_span_attribute(
|
|
275
281
|
span,
|
|
@@ -279,7 +285,7 @@ def _wrap_non_streaming_async(
|
|
|
279
285
|
|
|
280
286
|
set_span_attribute(
|
|
281
287
|
span,
|
|
282
|
-
AttributeKeys.
|
|
288
|
+
AttributeKeys.JUDGMENT_LLM_MODEL_NAME,
|
|
283
289
|
ctx["model_name"],
|
|
284
290
|
)
|
|
285
291
|
|
|
@@ -310,7 +316,7 @@ def _wrap_streaming_async(
|
|
|
310
316
|
ctx["span"] = tracer.get_tracer().start_span(
|
|
311
317
|
"TOGETHER_API_CALL", attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "llm"}
|
|
312
318
|
)
|
|
313
|
-
tracer.
|
|
319
|
+
tracer._inject_judgment_context(ctx["span"])
|
|
314
320
|
set_span_attribute(
|
|
315
321
|
ctx["span"], AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs)
|
|
316
322
|
)
|
|
@@ -320,7 +326,7 @@ def _wrap_streaming_async(
|
|
|
320
326
|
)
|
|
321
327
|
ctx["model_name"] = prefixed_model_name
|
|
322
328
|
set_span_attribute(
|
|
323
|
-
ctx["span"], AttributeKeys.
|
|
329
|
+
ctx["span"], AttributeKeys.JUDGMENT_LLM_MODEL_NAME, prefixed_model_name
|
|
324
330
|
)
|
|
325
331
|
ctx["accumulated_content"] = ""
|
|
326
332
|
|
|
@@ -348,10 +354,12 @@ def _wrap_streaming_async(
|
|
|
348
354
|
chunk.usage
|
|
349
355
|
)
|
|
350
356
|
set_span_attribute(
|
|
351
|
-
span,
|
|
357
|
+
span,
|
|
358
|
+
AttributeKeys.JUDGMENT_USAGE_NON_CACHED_INPUT_TOKENS,
|
|
359
|
+
prompt_tokens,
|
|
352
360
|
)
|
|
353
361
|
set_span_attribute(
|
|
354
|
-
span, AttributeKeys.
|
|
362
|
+
span, AttributeKeys.JUDGMENT_USAGE_OUTPUT_TOKENS, completion_tokens
|
|
355
363
|
)
|
|
356
364
|
set_span_attribute(
|
|
357
365
|
span,
|
judgeval/tracer/managers.py
CHANGED
|
@@ -16,6 +16,7 @@ def sync_span_context(
|
|
|
16
16
|
name: str,
|
|
17
17
|
span_attributes: Optional[Dict[str, str]] = None,
|
|
18
18
|
disable_partial_emit: bool = False,
|
|
19
|
+
end_on_exit: bool = False,
|
|
19
20
|
):
|
|
20
21
|
if span_attributes is None:
|
|
21
22
|
span_attributes = {}
|
|
@@ -23,6 +24,7 @@ def sync_span_context(
|
|
|
23
24
|
with tracer.get_tracer().start_as_current_span(
|
|
24
25
|
name=name,
|
|
25
26
|
attributes=span_attributes,
|
|
27
|
+
end_on_exit=end_on_exit,
|
|
26
28
|
) as span:
|
|
27
29
|
if disable_partial_emit:
|
|
28
30
|
tracer.judgment_processor.set_internal_attribute(
|
|
@@ -39,6 +41,7 @@ async def async_span_context(
|
|
|
39
41
|
name: str,
|
|
40
42
|
span_attributes: Optional[Dict[str, str]] = None,
|
|
41
43
|
disable_partial_emit: bool = False,
|
|
44
|
+
end_on_exit: bool = False,
|
|
42
45
|
):
|
|
43
46
|
if span_attributes is None:
|
|
44
47
|
span_attributes = {}
|
|
@@ -46,6 +49,7 @@ async def async_span_context(
|
|
|
46
49
|
with tracer.get_tracer().start_as_current_span(
|
|
47
50
|
name=name,
|
|
48
51
|
attributes=span_attributes,
|
|
52
|
+
end_on_exit=end_on_exit,
|
|
49
53
|
) as span:
|
|
50
54
|
if disable_partial_emit:
|
|
51
55
|
tracer.judgment_processor.set_internal_attribute(
|
judgeval/trainer/__init__.py
CHANGED
|
@@ -1,5 +1,14 @@
|
|
|
1
1
|
from judgeval.trainer.trainer import JudgmentTrainer
|
|
2
2
|
from judgeval.trainer.config import TrainerConfig, ModelConfig
|
|
3
3
|
from judgeval.trainer.trainable_model import TrainableModel
|
|
4
|
+
from judgeval.trainer.base_trainer import BaseTrainer
|
|
5
|
+
from judgeval.trainer.fireworks_trainer import FireworksTrainer
|
|
4
6
|
|
|
5
|
-
__all__ = [
|
|
7
|
+
__all__ = [
|
|
8
|
+
"JudgmentTrainer",
|
|
9
|
+
"TrainerConfig",
|
|
10
|
+
"ModelConfig",
|
|
11
|
+
"TrainableModel",
|
|
12
|
+
"BaseTrainer",
|
|
13
|
+
"FireworksTrainer",
|
|
14
|
+
]
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any, Callable, List, Optional, Union, Dict, TYPE_CHECKING
|
|
3
|
+
from .config import TrainerConfig, ModelConfig
|
|
4
|
+
from judgeval.scorers import ExampleScorer, ExampleAPIScorerConfig
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
from judgeval.tracer import Tracer
|
|
8
|
+
from .trainable_model import TrainableModel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BaseTrainer(ABC):
|
|
12
|
+
"""
|
|
13
|
+
Abstract base class for training providers.
|
|
14
|
+
|
|
15
|
+
This class defines the interface that all training provider implementations
|
|
16
|
+
must follow. Each provider (Fireworks, Verifiers, etc.) will have its own
|
|
17
|
+
concrete implementation of this interface.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
def __init__(
|
|
21
|
+
self,
|
|
22
|
+
config: TrainerConfig,
|
|
23
|
+
trainable_model: "TrainableModel",
|
|
24
|
+
tracer: "Tracer",
|
|
25
|
+
project_name: Optional[str] = None,
|
|
26
|
+
):
|
|
27
|
+
"""
|
|
28
|
+
Initialize the base trainer.
|
|
29
|
+
|
|
30
|
+
Args:
|
|
31
|
+
config: TrainerConfig instance with training parameters
|
|
32
|
+
trainable_model: TrainableModel instance to use for training
|
|
33
|
+
tracer: Tracer for observability
|
|
34
|
+
project_name: Project name for organizing training runs
|
|
35
|
+
"""
|
|
36
|
+
self.config = config
|
|
37
|
+
self.trainable_model = trainable_model
|
|
38
|
+
self.tracer = tracer
|
|
39
|
+
self.project_name = project_name or "judgment_training"
|
|
40
|
+
|
|
41
|
+
@abstractmethod
|
|
42
|
+
async def generate_rollouts_and_rewards(
|
|
43
|
+
self,
|
|
44
|
+
agent_function: Callable[[Any], Any],
|
|
45
|
+
scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
|
|
46
|
+
prompts: List[Any],
|
|
47
|
+
num_prompts_per_step: Optional[int] = None,
|
|
48
|
+
num_generations_per_prompt: Optional[int] = None,
|
|
49
|
+
concurrency: Optional[int] = None,
|
|
50
|
+
) -> Any:
|
|
51
|
+
"""
|
|
52
|
+
Generate rollouts and compute rewards using the current model snapshot.
|
|
53
|
+
|
|
54
|
+
Args:
|
|
55
|
+
agent_function: Function/agent to call for generating responses
|
|
56
|
+
scorers: List of scorer objects to evaluate responses
|
|
57
|
+
prompts: List of prompts to use for training
|
|
58
|
+
num_prompts_per_step: Number of prompts to use per step
|
|
59
|
+
num_generations_per_prompt: Generations per prompt
|
|
60
|
+
concurrency: Concurrency limit
|
|
61
|
+
|
|
62
|
+
Returns:
|
|
63
|
+
Provider-specific dataset format for training
|
|
64
|
+
"""
|
|
65
|
+
pass
|
|
66
|
+
|
|
67
|
+
@abstractmethod
|
|
68
|
+
async def run_reinforcement_learning(
|
|
69
|
+
self,
|
|
70
|
+
agent_function: Callable[[Any], Any],
|
|
71
|
+
scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
|
|
72
|
+
prompts: List[Any],
|
|
73
|
+
) -> ModelConfig:
|
|
74
|
+
"""
|
|
75
|
+
Run the iterative reinforcement learning fine-tuning loop.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
agent_function: Function/agent to call for generating responses
|
|
79
|
+
scorers: List of scorer objects to evaluate responses
|
|
80
|
+
prompts: List of prompts to use for training
|
|
81
|
+
|
|
82
|
+
Returns:
|
|
83
|
+
ModelConfig: Configuration of the trained model
|
|
84
|
+
"""
|
|
85
|
+
pass
|
|
86
|
+
|
|
87
|
+
@abstractmethod
|
|
88
|
+
async def train(
|
|
89
|
+
self,
|
|
90
|
+
agent_function: Callable[[Any], Any],
|
|
91
|
+
scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
|
|
92
|
+
prompts: List[Any],
|
|
93
|
+
) -> ModelConfig:
|
|
94
|
+
"""
|
|
95
|
+
Start the reinforcement learning fine-tuning process.
|
|
96
|
+
|
|
97
|
+
This is the main entry point for running the training.
|
|
98
|
+
|
|
99
|
+
Args:
|
|
100
|
+
agent_function: Function/agent to call for generating responses
|
|
101
|
+
scorers: List of scorer objects to evaluate responses
|
|
102
|
+
prompts: List of prompts to use for training
|
|
103
|
+
|
|
104
|
+
Returns:
|
|
105
|
+
ModelConfig: Configuration of the trained model
|
|
106
|
+
"""
|
|
107
|
+
pass
|
|
108
|
+
|
|
109
|
+
@abstractmethod
|
|
110
|
+
def _extract_message_history_from_spans(
|
|
111
|
+
self, trace_id: str
|
|
112
|
+
) -> List[Dict[str, str]]:
|
|
113
|
+
"""
|
|
114
|
+
Extract message history from spans for training purposes.
|
|
115
|
+
|
|
116
|
+
Args:
|
|
117
|
+
trace_id: The trace ID (32-char hex string) to extract message history from
|
|
118
|
+
|
|
119
|
+
Returns:
|
|
120
|
+
List of message dictionaries with 'role' and 'content' keys
|
|
121
|
+
"""
|
|
122
|
+
pass
|
judgeval/trainer/config.py
CHANGED
|
@@ -16,7 +16,7 @@ class TrainerConfig:
|
|
|
16
16
|
user_id: str
|
|
17
17
|
model_id: str
|
|
18
18
|
base_model_name: str = "qwen2p5-7b-instruct"
|
|
19
|
-
rft_provider: str = "fireworks"
|
|
19
|
+
rft_provider: str = "fireworks" # Supported: "fireworks", "verifiers" (future)
|
|
20
20
|
num_steps: int = 5
|
|
21
21
|
num_generations_per_prompt: int = 4
|
|
22
22
|
num_prompts_per_step: int = 4
|