judgeval 0.0.11__py3-none-any.whl → 0.22.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of judgeval might be problematic.
- judgeval/__init__.py +177 -12
- judgeval/api/__init__.py +519 -0
- judgeval/api/api_types.py +407 -0
- judgeval/cli.py +79 -0
- judgeval/constants.py +76 -47
- judgeval/data/__init__.py +3 -3
- judgeval/data/evaluation_run.py +125 -0
- judgeval/data/example.py +15 -56
- judgeval/data/judgment_types.py +450 -0
- judgeval/data/result.py +29 -73
- judgeval/data/scorer_data.py +29 -62
- judgeval/data/scripts/fix_default_factory.py +23 -0
- judgeval/data/scripts/openapi_transform.py +123 -0
- judgeval/data/trace.py +121 -0
- judgeval/dataset/__init__.py +264 -0
- judgeval/env.py +52 -0
- judgeval/evaluation/__init__.py +344 -0
- judgeval/exceptions.py +27 -0
- judgeval/integrations/langgraph/__init__.py +13 -0
- judgeval/integrations/openlit/__init__.py +50 -0
- judgeval/judges/__init__.py +2 -3
- judgeval/judges/base_judge.py +2 -3
- judgeval/judges/litellm_judge.py +100 -20
- judgeval/judges/together_judge.py +101 -20
- judgeval/judges/utils.py +20 -24
- judgeval/logger.py +62 -0
- judgeval/prompt/__init__.py +330 -0
- judgeval/scorers/__init__.py +18 -25
- judgeval/scorers/agent_scorer.py +17 -0
- judgeval/scorers/api_scorer.py +45 -41
- judgeval/scorers/base_scorer.py +83 -38
- judgeval/scorers/example_scorer.py +17 -0
- judgeval/scorers/exceptions.py +1 -0
- judgeval/scorers/judgeval_scorers/__init__.py +0 -148
- judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +19 -17
- judgeval/scorers/judgeval_scorers/api_scorers/answer_correctness.py +13 -19
- judgeval/scorers/judgeval_scorers/api_scorers/answer_relevancy.py +12 -19
- judgeval/scorers/judgeval_scorers/api_scorers/faithfulness.py +13 -19
- judgeval/scorers/judgeval_scorers/api_scorers/instruction_adherence.py +15 -0
- judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +327 -0
- judgeval/scorers/score.py +77 -306
- judgeval/scorers/utils.py +4 -199
- judgeval/tracer/__init__.py +1122 -2
- judgeval/tracer/constants.py +1 -0
- judgeval/tracer/exporters/__init__.py +40 -0
- judgeval/tracer/exporters/s3.py +119 -0
- judgeval/tracer/exporters/store.py +59 -0
- judgeval/tracer/exporters/utils.py +32 -0
- judgeval/tracer/keys.py +63 -0
- judgeval/tracer/llm/__init__.py +7 -0
- judgeval/tracer/llm/config.py +78 -0
- judgeval/tracer/llm/constants.py +9 -0
- judgeval/tracer/llm/llm_anthropic/__init__.py +3 -0
- judgeval/tracer/llm/llm_anthropic/config.py +6 -0
- judgeval/tracer/llm/llm_anthropic/messages.py +452 -0
- judgeval/tracer/llm/llm_anthropic/messages_stream.py +322 -0
- judgeval/tracer/llm/llm_anthropic/wrapper.py +59 -0
- judgeval/tracer/llm/llm_google/__init__.py +3 -0
- judgeval/tracer/llm/llm_google/config.py +6 -0
- judgeval/tracer/llm/llm_google/generate_content.py +127 -0
- judgeval/tracer/llm/llm_google/wrapper.py +30 -0
- judgeval/tracer/llm/llm_openai/__init__.py +3 -0
- judgeval/tracer/llm/llm_openai/beta_chat_completions.py +216 -0
- judgeval/tracer/llm/llm_openai/chat_completions.py +501 -0
- judgeval/tracer/llm/llm_openai/config.py +6 -0
- judgeval/tracer/llm/llm_openai/responses.py +506 -0
- judgeval/tracer/llm/llm_openai/utils.py +42 -0
- judgeval/tracer/llm/llm_openai/wrapper.py +63 -0
- judgeval/tracer/llm/llm_together/__init__.py +3 -0
- judgeval/tracer/llm/llm_together/chat_completions.py +406 -0
- judgeval/tracer/llm/llm_together/config.py +6 -0
- judgeval/tracer/llm/llm_together/wrapper.py +52 -0
- judgeval/tracer/llm/providers.py +19 -0
- judgeval/tracer/managers.py +167 -0
- judgeval/tracer/processors/__init__.py +220 -0
- judgeval/tracer/utils.py +19 -0
- judgeval/trainer/__init__.py +14 -0
- judgeval/trainer/base_trainer.py +122 -0
- judgeval/trainer/config.py +128 -0
- judgeval/trainer/console.py +144 -0
- judgeval/trainer/fireworks_trainer.py +396 -0
- judgeval/trainer/trainable_model.py +243 -0
- judgeval/trainer/trainer.py +70 -0
- judgeval/utils/async_utils.py +39 -0
- judgeval/utils/decorators/__init__.py +0 -0
- judgeval/utils/decorators/dont_throw.py +37 -0
- judgeval/utils/decorators/use_once.py +13 -0
- judgeval/utils/file_utils.py +97 -0
- judgeval/utils/guards.py +36 -0
- judgeval/utils/meta.py +27 -0
- judgeval/utils/project.py +15 -0
- judgeval/utils/serialize.py +253 -0
- judgeval/utils/testing.py +70 -0
- judgeval/utils/url.py +10 -0
- judgeval/utils/version_check.py +28 -0
- judgeval/utils/wrappers/README.md +3 -0
- judgeval/utils/wrappers/__init__.py +15 -0
- judgeval/utils/wrappers/immutable_wrap_async.py +74 -0
- judgeval/utils/wrappers/immutable_wrap_async_iterator.py +84 -0
- judgeval/utils/wrappers/immutable_wrap_sync.py +66 -0
- judgeval/utils/wrappers/immutable_wrap_sync_iterator.py +84 -0
- judgeval/utils/wrappers/mutable_wrap_async.py +67 -0
- judgeval/utils/wrappers/mutable_wrap_sync.py +67 -0
- judgeval/utils/wrappers/py.typed +0 -0
- judgeval/utils/wrappers/utils.py +35 -0
- judgeval/version.py +5 -0
- judgeval/warnings.py +4 -0
- judgeval-0.22.2.dist-info/METADATA +265 -0
- judgeval-0.22.2.dist-info/RECORD +112 -0
- judgeval-0.22.2.dist-info/entry_points.txt +2 -0
- judgeval/clients.py +0 -39
- judgeval/common/__init__.py +0 -8
- judgeval/common/exceptions.py +0 -28
- judgeval/common/logger.py +0 -189
- judgeval/common/tracer.py +0 -798
- judgeval/common/utils.py +0 -763
- judgeval/data/api_example.py +0 -111
- judgeval/data/datasets/__init__.py +0 -5
- judgeval/data/datasets/dataset.py +0 -286
- judgeval/data/datasets/eval_dataset_client.py +0 -193
- judgeval/data/datasets/ground_truth.py +0 -54
- judgeval/data/datasets/utils.py +0 -74
- judgeval/evaluation_run.py +0 -132
- judgeval/judges/mixture_of_judges.py +0 -248
- judgeval/judgment_client.py +0 -354
- judgeval/run_evaluation.py +0 -439
- judgeval/scorers/judgeval_scorer.py +0 -140
- judgeval/scorers/judgeval_scorers/api_scorers/contextual_precision.py +0 -19
- judgeval/scorers/judgeval_scorers/api_scorers/contextual_recall.py +0 -19
- judgeval/scorers/judgeval_scorers/api_scorers/contextual_relevancy.py +0 -22
- judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +0 -19
- judgeval/scorers/judgeval_scorers/api_scorers/json_correctness.py +0 -32
- judgeval/scorers/judgeval_scorers/api_scorers/summarization.py +0 -20
- judgeval/scorers/judgeval_scorers/api_scorers/tool_correctness.py +0 -19
- judgeval/scorers/judgeval_scorers/classifiers/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/classifiers/text2sql/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/classifiers/text2sql/text2sql_scorer.py +0 -54
- judgeval/scorers/judgeval_scorers/local_implementations/__init__.py +0 -24
- judgeval/scorers/judgeval_scorers/local_implementations/answer_correctness/__init__.py +0 -4
- judgeval/scorers/judgeval_scorers/local_implementations/answer_correctness/answer_correctness_scorer.py +0 -277
- judgeval/scorers/judgeval_scorers/local_implementations/answer_correctness/prompts.py +0 -169
- judgeval/scorers/judgeval_scorers/local_implementations/answer_relevancy/__init__.py +0 -4
- judgeval/scorers/judgeval_scorers/local_implementations/answer_relevancy/answer_relevancy_scorer.py +0 -298
- judgeval/scorers/judgeval_scorers/local_implementations/answer_relevancy/prompts.py +0 -174
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_precision/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_precision/contextual_precision_scorer.py +0 -264
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_precision/prompts.py +0 -106
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_recall/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_recall/contextual_recall_scorer.py +0 -254
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_recall/prompts.py +0 -142
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_relevancy/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_relevancy/contextual_relevancy_scorer.py +0 -245
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_relevancy/prompts.py +0 -121
- judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/faithfulness_scorer.py +0 -325
- judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/prompts.py +0 -268
- judgeval/scorers/judgeval_scorers/local_implementations/hallucination/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/hallucination/hallucination_scorer.py +0 -263
- judgeval/scorers/judgeval_scorers/local_implementations/hallucination/prompts.py +0 -104
- judgeval/scorers/judgeval_scorers/local_implementations/json_correctness/__init__.py +0 -5
- judgeval/scorers/judgeval_scorers/local_implementations/json_correctness/json_correctness_scorer.py +0 -134
- judgeval/scorers/judgeval_scorers/local_implementations/summarization/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/summarization/prompts.py +0 -247
- judgeval/scorers/judgeval_scorers/local_implementations/summarization/summarization_scorer.py +0 -550
- judgeval/scorers/judgeval_scorers/local_implementations/tool_correctness/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/tool_correctness/tool_correctness_scorer.py +0 -157
- judgeval/scorers/prompt_scorer.py +0 -439
- judgeval-0.0.11.dist-info/METADATA +0 -36
- judgeval-0.0.11.dist-info/RECORD +0 -84
- {judgeval-0.0.11.dist-info → judgeval-0.22.2.dist-info}/WHEEL +0 -0
- {judgeval-0.0.11.dist-info → judgeval-0.22.2.dist-info}/licenses/LICENSE.md +0 -0
judgeval/tracer/processors/__init__.py ADDED

@@ -0,0 +1,220 @@
+from __future__ import annotations
+from typing import Optional, TYPE_CHECKING, Any
+from collections import defaultdict
+from opentelemetry.context import Context
+from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor
+from opentelemetry.trace.span import SpanContext
+from opentelemetry.sdk.trace.export import (
+    BatchSpanProcessor,
+)
+from judgeval.tracer.exporters import JudgmentSpanExporter
+from judgeval.tracer.keys import AttributeKeys, InternalAttributeKeys, ResourceKeys
+from judgeval.utils.url import url_for
+from judgeval.utils.decorators.dont_throw import dont_throw
+from judgeval.version import get_version
+
+if TYPE_CHECKING:
+    from judgeval.tracer import Tracer
+
+
+class NoOpSpanProcessor(SpanProcessor):
+    def on_start(self, span: Span, parent_context: Optional[Context] = None) -> None:
+        pass
+
+    def on_end(self, span: ReadableSpan) -> None:
+        pass
+
+    def shutdown(self) -> None:
+        pass
+
+    def force_flush(self, timeout_millis: int = 30000) -> bool:
+        return True
+
+
+class JudgmentSpanProcessor(BatchSpanProcessor):
+    __slots__ = ("tracer", "resource_attributes", "_internal_attributes")
+
+    def __init__(
+        self,
+        tracer: Tracer,
+        project_name: str,
+        project_id: str,
+        api_key: str,
+        organization_id: str,
+        /,
+        *,
+        max_queue_size: int | None = None,
+        schedule_delay_millis: float | None = None,
+        max_export_batch_size: int | None = None,
+        export_timeout_millis: float | None = None,
+        resource_attributes: Optional[dict[str, Any]] = None,
+    ):
+        self.tracer = tracer
+
+        attrs = {
+            ResourceKeys.SERVICE_NAME: project_name,
+            ResourceKeys.TELEMETRY_SDK_NAME: "judgeval",
+            ResourceKeys.TELEMETRY_SDK_VERSION: get_version(),
+            ResourceKeys.JUDGMENT_PROJECT_ID: project_id,
+            **(resource_attributes or {}),
+        }
+        self.resource_attributes = attrs
+
+        super().__init__(
+            JudgmentSpanExporter(
+                endpoint=url_for("/otel/v1/traces"),
+                api_key=api_key,
+                organization_id=organization_id,
+                project_id=project_id,
+            ),
+            max_queue_size=max_queue_size,
+            schedule_delay_millis=schedule_delay_millis,
+            max_export_batch_size=max_export_batch_size,
+            export_timeout_millis=export_timeout_millis,
+        )
+        self._internal_attributes: defaultdict[tuple[int, int], dict[str, Any]] = (
+            defaultdict(dict)
+        )
+
+    def _get_span_key(self, span_context: SpanContext) -> tuple[int, int]:
+        return (span_context.trace_id, span_context.span_id)
+
+    def set_internal_attribute(
+        self, span_context: SpanContext, key: str, value: Any
+    ) -> None:
+        span_key = self._get_span_key(span_context)
+        self._internal_attributes[span_key][key] = value
+
+    def get_internal_attribute(
+        self, span_context: SpanContext, key: str, default: Any = None
+    ) -> Any:
+        span_key = self._get_span_key(span_context)
+        return self._internal_attributes[span_key].get(key, default)
+
+    def increment_update_id(self, span_context: SpanContext) -> int:
+        current_id = self.get_internal_attribute(
+            span_context=span_context, key=AttributeKeys.JUDGMENT_UPDATE_ID, default=0
+        )
+        new_id = current_id + 1
+        self.set_internal_attribute(
+            span_context=span_context,
+            key=AttributeKeys.JUDGMENT_UPDATE_ID,
+            value=new_id,
+        )
+        return current_id
+
+    def _cleanup_span_state(self, span_key: tuple[int, int]) -> None:
+        self._internal_attributes.pop(span_key, None)
+
+    @dont_throw
+    def emit_partial(self) -> None:
+        current_span = self.tracer.get_current_span()
+        if (
+            not current_span
+            or not current_span.is_recording()
+            or not isinstance(current_span, ReadableSpan)
+        ):
+            return
+
+        span_context = current_span.get_span_context()
+        if self.get_internal_attribute(
+            span_context, InternalAttributeKeys.DISABLE_PARTIAL_EMIT, False
+        ):
+            return
+
+        attributes = dict(current_span.attributes or {})
+        attributes[AttributeKeys.JUDGMENT_UPDATE_ID] = self.increment_update_id(
+            span_context
+        )
+
+        partial_span = ReadableSpan(
+            name=current_span.name,
+            context=span_context,
+            parent=current_span.parent,
+            resource=current_span.resource,
+            attributes=attributes,
+            events=current_span.events,
+            links=current_span.links,
+            status=current_span.status,
+            kind=current_span.kind,
+            start_time=current_span.start_time,
+            end_time=None,
+            instrumentation_scope=current_span.instrumentation_scope,
+        )
+
+        super().on_end(partial_span)
+
+    def on_end(self, span: ReadableSpan) -> None:
+        if not span.context:
+            super().on_end(span)
+            return
+
+        span_key = self._get_span_key(span.context)
+
+        if self.get_internal_attribute(
+            span.context, InternalAttributeKeys.CANCELLED, False
+        ):
+            self._cleanup_span_state(span_key)
+            return
+
+        if span.end_time is not None:
+            attributes = dict(span.attributes or {})
+            attributes[AttributeKeys.JUDGMENT_UPDATE_ID] = 20
+
+            final_span = ReadableSpan(
+                name=span.name,
+                context=span.context,
+                parent=span.parent,
+                resource=span.resource,
+                attributes=attributes,
+                events=span.events,
+                links=span.links,
+                status=span.status,
+                kind=span.kind,
+                start_time=span.start_time,
+                end_time=span.end_time,
+                instrumentation_scope=span.instrumentation_scope,
+            )
+
+            self._cleanup_span_state(span_key)
+            super().on_end(final_span)
+        else:
+            super().on_end(span)
+
+
+class NoOpJudgmentSpanProcessor(JudgmentSpanProcessor):
+    __slots__ = ("resource_attributes",)
+
+    def __init__(self):
+        self.resource_attributes = {}
+
+    def on_start(self, span: Span, parent_context: Optional[Context] = None) -> None:
+        pass
+
+    def on_end(self, span: ReadableSpan) -> None:
+        pass
+
+    def shutdown(self) -> None:
+        pass
+
+    def force_flush(self, timeout_millis: int | None = 30000) -> bool:
+        return True
+
+    def emit_partial(self) -> None:
+        pass
+
+    def set_internal_attribute(
+        self, span_context: SpanContext, key: str, value: Any
+    ) -> None:
+        pass
+
+    def get_internal_attribute(
+        self, span_context: SpanContext, key: str, default: Any = None
+    ) -> Any:
+        return default
+
+    def increment_update_id(self, span_context: SpanContext) -> int:
+        return 0
+
+
+__all__ = ["NoOpSpanProcessor", "JudgmentSpanProcessor", "NoOpJudgmentSpanProcessor"]
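For orientation only (not part of the diff): the processors above follow the standard OpenTelemetry SpanProcessor interface, so they can be attached to a plain TracerProvider. A minimal sketch using the no-op variant, with the judgeval-specific wiring (project, API key, partial emits) left out:

# Sketch: registering a processor from this module on a stock OTel TracerProvider.
# add_span_processor/set_tracer_provider are standard SDK calls; in practice the
# judgeval Tracer is expected to construct and register JudgmentSpanProcessor itself.
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from judgeval.tracer.processors import NoOpSpanProcessor

provider = TracerProvider()
provider.add_span_processor(NoOpSpanProcessor())  # drop-in processor that discards spans
trace.set_tracer_provider(provider)

with trace.get_tracer(__name__).start_as_current_span("example-span"):
    pass  # on_end fires here; the no-op processor simply ignores it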
judgeval/tracer/utils.py ADDED

@@ -0,0 +1,19 @@
+from typing import Any
+from opentelemetry.trace import Span
+from pydantic import BaseModel
+from typing import Callable, Optional
+from judgeval.scorers.api_scorer import TraceAPIScorerConfig
+
+
+def set_span_attribute(span: Span, name: str, value: Any):
+    if value is None or value == "":
+        return
+
+    span.set_attribute(name, value)
+
+
+class TraceScorerConfig(BaseModel):
+    scorer: TraceAPIScorerConfig | None
+    model: Optional[str] = None
+    sampling_rate: float = 1.0
+    run_condition: Optional[Callable[..., bool]] = None
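A brief usage sketch (assumed, not taken from the package) of set_span_attribute: it forwards to Span.set_attribute but silently skips None and empty-string values, so callers do not have to guard optional fields themselves. The attribute names below are placeholders:

# Hypothetical caller of set_span_attribute; attribute names are made up.
from opentelemetry import trace
from judgeval.tracer.utils import set_span_attribute

tracer = trace.get_tracer(__name__)
with tracer.start_as_current_span("llm.call") as span:
    set_span_attribute(span, "example.model", "gpt-4o")   # recorded normally
    set_span_attribute(span, "example.user_id", None)     # skipped: value is None
    set_span_attribute(span, "example.session", "")       # skipped: empty string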
judgeval/trainer/__init__.py ADDED

@@ -0,0 +1,14 @@
+from judgeval.trainer.trainer import JudgmentTrainer
+from judgeval.trainer.config import TrainerConfig, ModelConfig
+from judgeval.trainer.trainable_model import TrainableModel
+from judgeval.trainer.base_trainer import BaseTrainer
+from judgeval.trainer.fireworks_trainer import FireworksTrainer
+
+__all__ = [
+    "JudgmentTrainer",
+    "TrainerConfig",
+    "ModelConfig",
+    "TrainableModel",
+    "BaseTrainer",
+    "FireworksTrainer",
+]
judgeval/trainer/base_trainer.py ADDED

@@ -0,0 +1,122 @@
+from abc import ABC, abstractmethod
+from typing import Any, Callable, List, Optional, Union, Dict, TYPE_CHECKING
+from .config import TrainerConfig, ModelConfig
+from judgeval.scorers import ExampleScorer, ExampleAPIScorerConfig
+
+if TYPE_CHECKING:
+    from judgeval.tracer import Tracer
+    from .trainable_model import TrainableModel
+
+
+class BaseTrainer(ABC):
+    """
+    Abstract base class for training providers.
+
+    This class defines the interface that all training provider implementations
+    must follow. Each provider (Fireworks, Verifiers, etc.) will have its own
+    concrete implementation of this interface.
+    """
+
+    def __init__(
+        self,
+        config: TrainerConfig,
+        trainable_model: "TrainableModel",
+        tracer: "Tracer",
+        project_name: Optional[str] = None,
+    ):
+        """
+        Initialize the base trainer.
+
+        Args:
+            config: TrainerConfig instance with training parameters
+            trainable_model: TrainableModel instance to use for training
+            tracer: Tracer for observability
+            project_name: Project name for organizing training runs
+        """
+        self.config = config
+        self.trainable_model = trainable_model
+        self.tracer = tracer
+        self.project_name = project_name or "judgment_training"
+
+    @abstractmethod
+    async def generate_rollouts_and_rewards(
+        self,
+        agent_function: Callable[[Any], Any],
+        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
+        prompts: List[Any],
+        num_prompts_per_step: Optional[int] = None,
+        num_generations_per_prompt: Optional[int] = None,
+        concurrency: Optional[int] = None,
+    ) -> Any:
+        """
+        Generate rollouts and compute rewards using the current model snapshot.
+
+        Args:
+            agent_function: Function/agent to call for generating responses
+            scorers: List of scorer objects to evaluate responses
+            prompts: List of prompts to use for training
+            num_prompts_per_step: Number of prompts to use per step
+            num_generations_per_prompt: Generations per prompt
+            concurrency: Concurrency limit
+
+        Returns:
+            Provider-specific dataset format for training
+        """
+        pass
+
+    @abstractmethod
+    async def run_reinforcement_learning(
+        self,
+        agent_function: Callable[[Any], Any],
+        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
+        prompts: List[Any],
+    ) -> ModelConfig:
+        """
+        Run the iterative reinforcement learning fine-tuning loop.
+
+        Args:
+            agent_function: Function/agent to call for generating responses
+            scorers: List of scorer objects to evaluate responses
+            prompts: List of prompts to use for training
+
+        Returns:
+            ModelConfig: Configuration of the trained model
+        """
+        pass
+
+    @abstractmethod
+    async def train(
+        self,
+        agent_function: Callable[[Any], Any],
+        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
+        prompts: List[Any],
+    ) -> ModelConfig:
+        """
+        Start the reinforcement learning fine-tuning process.
+
+        This is the main entry point for running the training.
+
+        Args:
+            agent_function: Function/agent to call for generating responses
+            scorers: List of scorer objects to evaluate responses
+            prompts: List of prompts to use for training
+
+        Returns:
+            ModelConfig: Configuration of the trained model
+        """
+        pass
+
+    @abstractmethod
+    def _extract_message_history_from_spans(
+        self, trace_id: str
+    ) -> List[Dict[str, str]]:
+        """
+        Extract message history from spans for training purposes.
+
+        Args:
+            trace_id: The trace ID (32-char hex string) to extract message history from
+
+        Returns:
+            List of message dictionaries with 'role' and 'content' keys
+        """
+        pass
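To make the contract above concrete, here is a hypothetical provider skeleton; the method signatures mirror BaseTrainer, but the class name and stub bodies are illustrative only and not part of the package:

# Hypothetical subclass sketch; only the abstract interface comes from BaseTrainer.
from typing import Any, Dict, List

from judgeval.trainer.base_trainer import BaseTrainer
from judgeval.trainer.config import ModelConfig


class MyProviderTrainer(BaseTrainer):
    async def generate_rollouts_and_rewards(
        self, agent_function, scorers, prompts,
        num_prompts_per_step=None, num_generations_per_prompt=None, concurrency=None,
    ) -> Any:
        # Call agent_function on each prompt, score outputs with `scorers`,
        # and return whatever dataset format the provider's RL step expects.
        raise NotImplementedError

    async def run_reinforcement_learning(self, agent_function, scorers, prompts) -> ModelConfig:
        # Loop: rollouts -> rewards -> reinforcement step -> advance model snapshot.
        raise NotImplementedError

    async def train(self, agent_function, scorers, prompts) -> ModelConfig:
        # Entry point; a provider would typically delegate to the RL loop.
        return await self.run_reinforcement_learning(agent_function, scorers, prompts)

    def _extract_message_history_from_spans(self, trace_id: str) -> List[Dict[str, str]]:
        # A real provider would read the exported spans for this trace ID.
        return []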
judgeval/trainer/config.py ADDED

@@ -0,0 +1,128 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Optional, Dict, Any, TYPE_CHECKING
+import json
+
+if TYPE_CHECKING:
+    from fireworks.llm.llm_reinforcement_step import ReinforcementAcceleratorTypeLiteral  # type: ignore[import-not-found]
+
+
+@dataclass
+class TrainerConfig:
+    """Configuration class for JudgmentTrainer parameters."""
+
+    deployment_id: str
+    user_id: str
+    model_id: str
+    base_model_name: str = "qwen2p5-7b-instruct"
+    rft_provider: str = "fireworks"  # Supported: "fireworks", "verifiers" (future)
+    num_steps: int = 5
+    num_generations_per_prompt: int = 4
+    num_prompts_per_step: int = 4
+    concurrency: int = 100
+    epochs: int = 1
+    learning_rate: float = 1e-5
+    accelerator_count: int = 1
+    accelerator_type: ReinforcementAcceleratorTypeLiteral = "NVIDIA_A100_80GB"
+    temperature: float = 1.5
+    max_tokens: int = 50
+    enable_addons: bool = True
+
+
+@dataclass
+class ModelConfig:
+    """
+    Configuration class for storing and loading trained model state.
+
+    This class enables persistence of trained models so they can be loaded
+    and used later without retraining.
+
+    Example usage:
+        trainer = JudgmentTrainer(config)
+        model_config = trainer.train(agent_function, scorers, prompts)
+
+        # Save the trained model configuration
+        model_config.save_to_file("my_trained_model.json")
+
+        # Later, load and use the trained model
+        loaded_config = ModelConfig.load_from_file("my_trained_model.json")
+        trained_model = TrainableModel.from_model_config(loaded_config)
+
+        # Use the trained model for inference
+        response = trained_model.chat.completions.create(
+            model="current",  # Uses the loaded trained model
+            messages=[{"role": "user", "content": "Hello!"}]
+        )
+    """
+
+    # Base model configuration
+    base_model_name: str
+    deployment_id: str
+    user_id: str
+    model_id: str
+    enable_addons: bool
+
+    # Training state
+    current_step: int
+    total_steps: int
+
+    # Current model information
+    current_model_name: Optional[str] = None
+    is_trained: bool = False
+
+    # Training parameters used (for reference)
+    training_params: Optional[Dict[str, Any]] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert ModelConfig to dictionary for serialization."""
+        return {
+            "base_model_name": self.base_model_name,
+            "deployment_id": self.deployment_id,
+            "user_id": self.user_id,
+            "model_id": self.model_id,
+            "enable_addons": self.enable_addons,
+            "current_step": self.current_step,
+            "total_steps": self.total_steps,
+            "current_model_name": self.current_model_name,
+            "is_trained": self.is_trained,
+            "training_params": self.training_params,
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> ModelConfig:
+        """Create ModelConfig from dictionary."""
+        return cls(
+            base_model_name=data.get("base_model_name", "qwen2p5-7b-instruct"),
+            deployment_id=data.get("deployment_id", "my-base-deployment"),
+            user_id=data.get("user_id", ""),
+            model_id=data.get("model_id", ""),
+            enable_addons=data.get("enable_addons", True),
+            current_step=data.get("current_step", 0),
+            total_steps=data.get("total_steps", 0),
+            current_model_name=data.get("current_model_name"),
+            is_trained=data.get("is_trained", False),
+            training_params=data.get("training_params"),
+        )
+
+    def to_json(self) -> str:
+        """Convert ModelConfig to JSON string."""
+        return json.dumps(self.to_dict(), indent=2)
+
+    @classmethod
+    def from_json(cls, json_str: str) -> ModelConfig:
+        """Create ModelConfig from JSON string."""
+        data = json.loads(json_str)
+        return cls.from_dict(data)
+
+    def save_to_file(self, filepath: str):
+        """Save ModelConfig to a JSON file."""
+        with open(filepath, "w") as f:
+            f.write(self.to_json())
+
+    @classmethod
+    def load_from_file(cls, filepath: str) -> ModelConfig:
+        """Load ModelConfig from a JSON file."""
+        with open(filepath, "r") as f:
+            json_str = f.read()
+        return cls.from_json(json_str)
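The ModelConfig docstring already sketches the intended end-to-end workflow; as a narrower illustration of just the persistence round trip (the field values below are made up), one might do:

# Save and reload a ModelConfig using only the methods defined above.
from judgeval.trainer.config import ModelConfig

cfg = ModelConfig(
    base_model_name="qwen2p5-7b-instruct",
    deployment_id="my-base-deployment",  # placeholder identifiers
    user_id="user-123",
    model_id="model-abc",
    enable_addons=True,
    current_step=3,
    total_steps=5,
    current_model_name="model-abc-step-3",
    is_trained=True,
)

cfg.save_to_file("my_trained_model.json")
restored = ModelConfig.load_from_file("my_trained_model.json")
assert restored.to_dict() == cfg.to_dict()  # lossless JSON round trip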
judgeval/trainer/console.py ADDED

@@ -0,0 +1,144 @@
+from contextlib import contextmanager
+from typing import Optional
+import sys
+import os
+from judgeval.utils.decorators.use_once import use_once
+
+
+@use_once
+def _is_jupyter_environment():
+    """Check if we're running in a Jupyter notebook or similar environment."""
+    try:
+        # Check for IPython kernel
+        if "ipykernel" in sys.modules or "IPython" in sys.modules:
+            return True
+        # Check for Jupyter environment variables
+        if "JPY_PARENT_PID" in os.environ:
+            return True
+        # Check if we're in Google Colab
+        if "google.colab" in sys.modules:
+            return True
+        return False
+    except Exception:
+        return False
+
+
+IS_JUPYTER = _is_jupyter_environment()
+
+if not IS_JUPYTER:
+    try:
+        from rich.console import Console
+        from rich.spinner import Spinner
+        from rich.live import Live
+        from rich.text import Text
+
+        shared_console = Console()
+        RICH_AVAILABLE = True
+    except ImportError:
+        RICH_AVAILABLE = False
+else:
+    RICH_AVAILABLE = False
+
+
+class SimpleSpinner:
+    def __init__(self, name, text):
+        self.text = text
+
+
+class SimpleLive:
+    def __init__(self, spinner, console=None, refresh_per_second=None):
+        self.spinner = spinner
+
+    def __enter__(self):
+        print(f"🔄 {self.spinner.text}")
+        return self
+
+    def __exit__(self, *args):
+        pass
+
+    def update(self, spinner):
+        print(f"🔄 {spinner.text}")
+
+
+def safe_print(message, style=None):
+    """Safe print function that works in all environments."""
+    if RICH_AVAILABLE and not IS_JUPYTER:
+        shared_console.print(message, style=style)
+    else:
+        if style == "green":
+            print(f"✅ {message}")
+        elif style == "yellow":
+            print(f"⚠️ {message}")
+        elif style == "blue":
+            print(f"🔵 {message}")
+        elif style == "cyan":
+            print(f"🔷 {message}")
+        else:
+            print(message)
+
+
+@contextmanager
+def _spinner_progress(
+    message: str, step: Optional[int] = None, total_steps: Optional[int] = None
+):
+    """Context manager for spinner-based progress display."""
+    if step is not None and total_steps is not None:
+        full_message = f"[Step {step}/{total_steps}] {message}"
+    else:
+        full_message = f"[Training] {message}"
+
+    if RICH_AVAILABLE and not IS_JUPYTER:
+        spinner = Spinner("dots", text=Text(full_message, style="cyan"))
+        with Live(spinner, console=shared_console, refresh_per_second=10):
+            yield
+    else:
+        print(f"🔄 {full_message}")
+        try:
+            yield
+        finally:
+            print(f"✅ {full_message} - Complete")
+
+
+@contextmanager
+def _model_spinner_progress(message: str):
+    """Context manager for model operation spinner-based progress display."""
+    if RICH_AVAILABLE and not IS_JUPYTER:
+        spinner = Spinner("dots", text=Text(f"[Model] {message}", style="blue"))
+        with Live(spinner, console=shared_console, refresh_per_second=10) as live:
+
+            def update_progress(progress_message: str):
+                """Update the spinner with a new progress message."""
+                new_text = f"[Model] {message}\n └─ {progress_message}"
+                spinner.text = Text(new_text, style="blue")
+                live.update(spinner)
+
+            yield update_progress
+    else:
+        print(f"🔵 [Model] {message}")
+
+        def update_progress(progress_message: str):
+            print(f" └─ {progress_message}")
+
+        yield update_progress
+
+
+def _print_progress(
+    message: str, step: Optional[int] = None, total_steps: Optional[int] = None
+):
+    """Print progress message with consistent formatting."""
+    if step is not None and total_steps is not None:
+        safe_print(f"[Step {step}/{total_steps}] {message}", style="green")
+    else:
+        safe_print(f"[Training] {message}", style="green")
+
+
+def _print_progress_update(
+    message: str, step: Optional[int] = None, total_steps: Optional[int] = None
+):
+    """Print progress update message (for status changes during long operations)."""
+    safe_print(f" └─ {message}", style="yellow")
+
+
+def _print_model_progress(message: str):
+    """Print model progress message with consistent formatting."""
+    safe_print(f"[Model] {message}", style="blue")
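A short, assumed usage sketch of these console helpers; they are underscore-prefixed internals used by the trainer, so this is illustrative rather than a supported public API:

# Illustrative use of the console helpers above.
import time
from judgeval.trainer.console import safe_print, _spinner_progress

safe_print("Starting training run", style="green")

with _spinner_progress("Generating rollouts", step=1, total_steps=5):
    time.sleep(0.1)  # stand-in for real work; spinner (or fallback prints) shows progress

safe_print("Rollouts complete", style="cyan")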