judgeval 0.0.11__py3-none-any.whl → 0.22.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of judgeval might be problematic. Click here for more details.

Files changed (171)
  1. judgeval/__init__.py +177 -12
  2. judgeval/api/__init__.py +519 -0
  3. judgeval/api/api_types.py +407 -0
  4. judgeval/cli.py +79 -0
  5. judgeval/constants.py +76 -47
  6. judgeval/data/__init__.py +3 -3
  7. judgeval/data/evaluation_run.py +125 -0
  8. judgeval/data/example.py +15 -56
  9. judgeval/data/judgment_types.py +450 -0
  10. judgeval/data/result.py +29 -73
  11. judgeval/data/scorer_data.py +29 -62
  12. judgeval/data/scripts/fix_default_factory.py +23 -0
  13. judgeval/data/scripts/openapi_transform.py +123 -0
  14. judgeval/data/trace.py +121 -0
  15. judgeval/dataset/__init__.py +264 -0
  16. judgeval/env.py +52 -0
  17. judgeval/evaluation/__init__.py +344 -0
  18. judgeval/exceptions.py +27 -0
  19. judgeval/integrations/langgraph/__init__.py +13 -0
  20. judgeval/integrations/openlit/__init__.py +50 -0
  21. judgeval/judges/__init__.py +2 -3
  22. judgeval/judges/base_judge.py +2 -3
  23. judgeval/judges/litellm_judge.py +100 -20
  24. judgeval/judges/together_judge.py +101 -20
  25. judgeval/judges/utils.py +20 -24
  26. judgeval/logger.py +62 -0
  27. judgeval/prompt/__init__.py +330 -0
  28. judgeval/scorers/__init__.py +18 -25
  29. judgeval/scorers/agent_scorer.py +17 -0
  30. judgeval/scorers/api_scorer.py +45 -41
  31. judgeval/scorers/base_scorer.py +83 -38
  32. judgeval/scorers/example_scorer.py +17 -0
  33. judgeval/scorers/exceptions.py +1 -0
  34. judgeval/scorers/judgeval_scorers/__init__.py +0 -148
  35. judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +19 -17
  36. judgeval/scorers/judgeval_scorers/api_scorers/answer_correctness.py +13 -19
  37. judgeval/scorers/judgeval_scorers/api_scorers/answer_relevancy.py +12 -19
  38. judgeval/scorers/judgeval_scorers/api_scorers/faithfulness.py +13 -19
  39. judgeval/scorers/judgeval_scorers/api_scorers/instruction_adherence.py +15 -0
  40. judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +327 -0
  41. judgeval/scorers/score.py +77 -306
  42. judgeval/scorers/utils.py +4 -199
  43. judgeval/tracer/__init__.py +1122 -2
  44. judgeval/tracer/constants.py +1 -0
  45. judgeval/tracer/exporters/__init__.py +40 -0
  46. judgeval/tracer/exporters/s3.py +119 -0
  47. judgeval/tracer/exporters/store.py +59 -0
  48. judgeval/tracer/exporters/utils.py +32 -0
  49. judgeval/tracer/keys.py +63 -0
  50. judgeval/tracer/llm/__init__.py +7 -0
  51. judgeval/tracer/llm/config.py +78 -0
  52. judgeval/tracer/llm/constants.py +9 -0
  53. judgeval/tracer/llm/llm_anthropic/__init__.py +3 -0
  54. judgeval/tracer/llm/llm_anthropic/config.py +6 -0
  55. judgeval/tracer/llm/llm_anthropic/messages.py +452 -0
  56. judgeval/tracer/llm/llm_anthropic/messages_stream.py +322 -0
  57. judgeval/tracer/llm/llm_anthropic/wrapper.py +59 -0
  58. judgeval/tracer/llm/llm_google/__init__.py +3 -0
  59. judgeval/tracer/llm/llm_google/config.py +6 -0
  60. judgeval/tracer/llm/llm_google/generate_content.py +127 -0
  61. judgeval/tracer/llm/llm_google/wrapper.py +30 -0
  62. judgeval/tracer/llm/llm_openai/__init__.py +3 -0
  63. judgeval/tracer/llm/llm_openai/beta_chat_completions.py +216 -0
  64. judgeval/tracer/llm/llm_openai/chat_completions.py +501 -0
  65. judgeval/tracer/llm/llm_openai/config.py +6 -0
  66. judgeval/tracer/llm/llm_openai/responses.py +506 -0
  67. judgeval/tracer/llm/llm_openai/utils.py +42 -0
  68. judgeval/tracer/llm/llm_openai/wrapper.py +63 -0
  69. judgeval/tracer/llm/llm_together/__init__.py +3 -0
  70. judgeval/tracer/llm/llm_together/chat_completions.py +406 -0
  71. judgeval/tracer/llm/llm_together/config.py +6 -0
  72. judgeval/tracer/llm/llm_together/wrapper.py +52 -0
  73. judgeval/tracer/llm/providers.py +19 -0
  74. judgeval/tracer/managers.py +167 -0
  75. judgeval/tracer/processors/__init__.py +220 -0
  76. judgeval/tracer/utils.py +19 -0
  77. judgeval/trainer/__init__.py +14 -0
  78. judgeval/trainer/base_trainer.py +122 -0
  79. judgeval/trainer/config.py +128 -0
  80. judgeval/trainer/console.py +144 -0
  81. judgeval/trainer/fireworks_trainer.py +396 -0
  82. judgeval/trainer/trainable_model.py +243 -0
  83. judgeval/trainer/trainer.py +70 -0
  84. judgeval/utils/async_utils.py +39 -0
  85. judgeval/utils/decorators/__init__.py +0 -0
  86. judgeval/utils/decorators/dont_throw.py +37 -0
  87. judgeval/utils/decorators/use_once.py +13 -0
  88. judgeval/utils/file_utils.py +97 -0
  89. judgeval/utils/guards.py +36 -0
  90. judgeval/utils/meta.py +27 -0
  91. judgeval/utils/project.py +15 -0
  92. judgeval/utils/serialize.py +253 -0
  93. judgeval/utils/testing.py +70 -0
  94. judgeval/utils/url.py +10 -0
  95. judgeval/utils/version_check.py +28 -0
  96. judgeval/utils/wrappers/README.md +3 -0
  97. judgeval/utils/wrappers/__init__.py +15 -0
  98. judgeval/utils/wrappers/immutable_wrap_async.py +74 -0
  99. judgeval/utils/wrappers/immutable_wrap_async_iterator.py +84 -0
  100. judgeval/utils/wrappers/immutable_wrap_sync.py +66 -0
  101. judgeval/utils/wrappers/immutable_wrap_sync_iterator.py +84 -0
  102. judgeval/utils/wrappers/mutable_wrap_async.py +67 -0
  103. judgeval/utils/wrappers/mutable_wrap_sync.py +67 -0
  104. judgeval/utils/wrappers/py.typed +0 -0
  105. judgeval/utils/wrappers/utils.py +35 -0
  106. judgeval/version.py +5 -0
  107. judgeval/warnings.py +4 -0
  108. judgeval-0.22.2.dist-info/METADATA +265 -0
  109. judgeval-0.22.2.dist-info/RECORD +112 -0
  110. judgeval-0.22.2.dist-info/entry_points.txt +2 -0
  111. judgeval/clients.py +0 -39
  112. judgeval/common/__init__.py +0 -8
  113. judgeval/common/exceptions.py +0 -28
  114. judgeval/common/logger.py +0 -189
  115. judgeval/common/tracer.py +0 -798
  116. judgeval/common/utils.py +0 -763
  117. judgeval/data/api_example.py +0 -111
  118. judgeval/data/datasets/__init__.py +0 -5
  119. judgeval/data/datasets/dataset.py +0 -286
  120. judgeval/data/datasets/eval_dataset_client.py +0 -193
  121. judgeval/data/datasets/ground_truth.py +0 -54
  122. judgeval/data/datasets/utils.py +0 -74
  123. judgeval/evaluation_run.py +0 -132
  124. judgeval/judges/mixture_of_judges.py +0 -248
  125. judgeval/judgment_client.py +0 -354
  126. judgeval/run_evaluation.py +0 -439
  127. judgeval/scorers/judgeval_scorer.py +0 -140
  128. judgeval/scorers/judgeval_scorers/api_scorers/contextual_precision.py +0 -19
  129. judgeval/scorers/judgeval_scorers/api_scorers/contextual_recall.py +0 -19
  130. judgeval/scorers/judgeval_scorers/api_scorers/contextual_relevancy.py +0 -22
  131. judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +0 -19
  132. judgeval/scorers/judgeval_scorers/api_scorers/json_correctness.py +0 -32
  133. judgeval/scorers/judgeval_scorers/api_scorers/summarization.py +0 -20
  134. judgeval/scorers/judgeval_scorers/api_scorers/tool_correctness.py +0 -19
  135. judgeval/scorers/judgeval_scorers/classifiers/__init__.py +0 -3
  136. judgeval/scorers/judgeval_scorers/classifiers/text2sql/__init__.py +0 -3
  137. judgeval/scorers/judgeval_scorers/classifiers/text2sql/text2sql_scorer.py +0 -54
  138. judgeval/scorers/judgeval_scorers/local_implementations/__init__.py +0 -24
  139. judgeval/scorers/judgeval_scorers/local_implementations/answer_correctness/__init__.py +0 -4
  140. judgeval/scorers/judgeval_scorers/local_implementations/answer_correctness/answer_correctness_scorer.py +0 -277
  141. judgeval/scorers/judgeval_scorers/local_implementations/answer_correctness/prompts.py +0 -169
  142. judgeval/scorers/judgeval_scorers/local_implementations/answer_relevancy/__init__.py +0 -4
  143. judgeval/scorers/judgeval_scorers/local_implementations/answer_relevancy/answer_relevancy_scorer.py +0 -298
  144. judgeval/scorers/judgeval_scorers/local_implementations/answer_relevancy/prompts.py +0 -174
  145. judgeval/scorers/judgeval_scorers/local_implementations/contextual_precision/__init__.py +0 -3
  146. judgeval/scorers/judgeval_scorers/local_implementations/contextual_precision/contextual_precision_scorer.py +0 -264
  147. judgeval/scorers/judgeval_scorers/local_implementations/contextual_precision/prompts.py +0 -106
  148. judgeval/scorers/judgeval_scorers/local_implementations/contextual_recall/__init__.py +0 -3
  149. judgeval/scorers/judgeval_scorers/local_implementations/contextual_recall/contextual_recall_scorer.py +0 -254
  150. judgeval/scorers/judgeval_scorers/local_implementations/contextual_recall/prompts.py +0 -142
  151. judgeval/scorers/judgeval_scorers/local_implementations/contextual_relevancy/__init__.py +0 -3
  152. judgeval/scorers/judgeval_scorers/local_implementations/contextual_relevancy/contextual_relevancy_scorer.py +0 -245
  153. judgeval/scorers/judgeval_scorers/local_implementations/contextual_relevancy/prompts.py +0 -121
  154. judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/__init__.py +0 -3
  155. judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/faithfulness_scorer.py +0 -325
  156. judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/prompts.py +0 -268
  157. judgeval/scorers/judgeval_scorers/local_implementations/hallucination/__init__.py +0 -3
  158. judgeval/scorers/judgeval_scorers/local_implementations/hallucination/hallucination_scorer.py +0 -263
  159. judgeval/scorers/judgeval_scorers/local_implementations/hallucination/prompts.py +0 -104
  160. judgeval/scorers/judgeval_scorers/local_implementations/json_correctness/__init__.py +0 -5
  161. judgeval/scorers/judgeval_scorers/local_implementations/json_correctness/json_correctness_scorer.py +0 -134
  162. judgeval/scorers/judgeval_scorers/local_implementations/summarization/__init__.py +0 -3
  163. judgeval/scorers/judgeval_scorers/local_implementations/summarization/prompts.py +0 -247
  164. judgeval/scorers/judgeval_scorers/local_implementations/summarization/summarization_scorer.py +0 -550
  165. judgeval/scorers/judgeval_scorers/local_implementations/tool_correctness/__init__.py +0 -3
  166. judgeval/scorers/judgeval_scorers/local_implementations/tool_correctness/tool_correctness_scorer.py +0 -157
  167. judgeval/scorers/prompt_scorer.py +0 -439
  168. judgeval-0.0.11.dist-info/METADATA +0 -36
  169. judgeval-0.0.11.dist-info/RECORD +0 -84
  170. {judgeval-0.0.11.dist-info → judgeval-0.22.2.dist-info}/WHEEL +0 -0
  171. {judgeval-0.0.11.dist-info → judgeval-0.22.2.dist-info}/licenses/LICENSE.md +0 -0
@@ -0,0 +1,220 @@
1
+ from __future__ import annotations
2
+ from typing import Optional, TYPE_CHECKING, Any
3
+ from collections import defaultdict
4
+ from opentelemetry.context import Context
5
+ from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor
6
+ from opentelemetry.trace.span import SpanContext
7
+ from opentelemetry.sdk.trace.export import (
8
+ BatchSpanProcessor,
9
+ )
10
+ from judgeval.tracer.exporters import JudgmentSpanExporter
11
+ from judgeval.tracer.keys import AttributeKeys, InternalAttributeKeys, ResourceKeys
12
+ from judgeval.utils.url import url_for
13
+ from judgeval.utils.decorators.dont_throw import dont_throw
14
+ from judgeval.version import get_version
15
+
16
+ if TYPE_CHECKING:
17
+ from judgeval.tracer import Tracer
18
+
19
+
20
class NoOpSpanProcessor(SpanProcessor):
    """Span processor that ignores every span; flushing always reports success."""

    def on_start(self, span: Span, parent_context: Optional[Context] = None) -> None:
        """Ignore the span start notification."""
        return None

    def on_end(self, span: ReadableSpan) -> None:
        """Ignore the span end notification; nothing is exported."""
        return None

    def shutdown(self) -> None:
        """Nothing to release."""
        return None

    def force_flush(self, timeout_millis: int = 30000) -> bool:
        """Report success immediately because nothing is ever buffered."""
        return True
32
+
33
+
34
class JudgmentSpanProcessor(BatchSpanProcessor):
    """Batch span processor that ships spans to the Judgment OTLP endpoint.

    On top of the batching inherited from ``BatchSpanProcessor`` it:

    * keeps a side table of per-span "internal" attributes, keyed by
      ``(trace_id, span_id)``; these are stored here and not copied into the
      exported span attributes;
    * can export in-progress snapshots of the tracer's current span
      (``emit_partial``), each tagged with an increasing update id;
    * tags the ended version of each span with ``_FINAL_UPDATE_ID`` and drops
      spans whose ``CANCELLED`` internal attribute is set.
    """

    __slots__ = ("tracer", "resource_attributes", "_internal_attributes")

    # Update id written on the final export of an ended span.
    # NOTE(review): partial snapshots are numbered 0, 1, 2, ... so a span with
    # more than 20 partial emits gets partial ids larger than its final one —
    # confirm the backend does not order updates purely by this id.
    _FINAL_UPDATE_ID = 20

    def __init__(
        self,
        tracer: Tracer,
        project_name: str,
        project_id: str,
        api_key: str,
        organization_id: str,
        /,
        *,
        max_queue_size: int | None = None,
        schedule_delay_millis: float | None = None,
        max_export_batch_size: int | None = None,
        export_timeout_millis: float | None = None,
        resource_attributes: Optional[dict[str, Any]] = None,
    ):
        """Create the processor together with its ``JudgmentSpanExporter``.

        Args:
            tracer: Tracer that ``emit_partial`` asks for the current span.
            project_name: Stored as the ``SERVICE_NAME`` resource attribute.
            project_id: Passed to the exporter and stored as the
                ``JUDGMENT_PROJECT_ID`` resource attribute.
            api_key: Credential passed to the exporter.
            organization_id: Organization id passed to the exporter.
            max_queue_size, schedule_delay_millis, max_export_batch_size,
            export_timeout_millis: Forwarded unchanged to ``BatchSpanProcessor``.
            resource_attributes: Extra resource attributes. They are merged
                last, so they override the defaults above on key collision.
        """
        self.tracer = tracer

        attrs = {
            ResourceKeys.SERVICE_NAME: project_name,
            ResourceKeys.TELEMETRY_SDK_NAME: "judgeval",
            ResourceKeys.TELEMETRY_SDK_VERSION: get_version(),
            ResourceKeys.JUDGMENT_PROJECT_ID: project_id,
            **(resource_attributes or {}),
        }
        self.resource_attributes = attrs

        super().__init__(
            JudgmentSpanExporter(
                endpoint=url_for("/otel/v1/traces"),
                api_key=api_key,
                organization_id=organization_id,
                project_id=project_id,
            ),
            max_queue_size=max_queue_size,
            schedule_delay_millis=schedule_delay_millis,
            max_export_batch_size=max_export_batch_size,
            export_timeout_millis=export_timeout_millis,
        )
        self._internal_attributes: defaultdict[tuple[int, int], dict[str, Any]] = (
            defaultdict(dict)
        )

    def _get_span_key(self, span_context: SpanContext) -> tuple[int, int]:
        """Return the ``(trace_id, span_id)`` key used for the internal-attribute table."""
        return (span_context.trace_id, span_context.span_id)

    @staticmethod
    def _copy_span(
        span: ReadableSpan,
        context: SpanContext,
        attributes: dict[str, Any],
        end_time: Optional[int],
    ) -> ReadableSpan:
        """Return a new ``ReadableSpan`` mirroring ``span`` with the given
        context, attributes and end time; all other fields are taken from ``span``."""
        return ReadableSpan(
            name=span.name,
            context=context,
            parent=span.parent,
            resource=span.resource,
            attributes=attributes,
            events=span.events,
            links=span.links,
            status=span.status,
            kind=span.kind,
            start_time=span.start_time,
            end_time=end_time,
            instrumentation_scope=span.instrumentation_scope,
        )

    def set_internal_attribute(
        self, span_context: SpanContext, key: str, value: Any
    ) -> None:
        """Record ``key=value`` for the span in this processor's side table."""
        span_key = self._get_span_key(span_context)
        self._internal_attributes[span_key][key] = value

    def get_internal_attribute(
        self, span_context: SpanContext, key: str, default: Any = None
    ) -> Any:
        """Return the span's internal attribute ``key``, or ``default`` if unset.

        The lookup has no side effects. Indexing the ``defaultdict`` directly
        would insert an empty entry for every span that is merely queried, and
        that entry would persist for spans that never reach ``on_end``.
        """
        span_attributes = self._internal_attributes.get(
            self._get_span_key(span_context)
        )
        if span_attributes is None:
            return default
        return span_attributes.get(key, default)

    def increment_update_id(self, span_context: SpanContext) -> int:
        """Advance the span's update counter and return its value *before* the increment.

        The first call for a span returns 0.
        """
        current_id = self.get_internal_attribute(
            span_context=span_context, key=AttributeKeys.JUDGMENT_UPDATE_ID, default=0
        )
        self.set_internal_attribute(
            span_context=span_context,
            key=AttributeKeys.JUDGMENT_UPDATE_ID,
            value=current_id + 1,
        )
        return current_id

    def _cleanup_span_state(self, span_key: tuple[int, int]) -> None:
        """Forget all internal attributes stored for ``span_key``, if any."""
        self._internal_attributes.pop(span_key, None)

    @dont_throw
    def emit_partial(self) -> None:
        """Export a snapshot of the tracer's current span while it is still open.

        Nothing is emitted when there is no current span, when it is not
        recording or not an SDK ``ReadableSpan``, or when its
        ``DISABLE_PARTIAL_EMIT`` internal attribute is truthy. The snapshot has
        ``end_time=None`` and the next partial update id. It is handed straight
        to ``BatchSpanProcessor.on_end``, so it bypasses this class's
        ``on_end``. Wrapped in ``dont_throw`` (presumably suppresses
        exceptions — see that decorator).
        """
        current_span = self.tracer.get_current_span()
        if (
            not current_span
            or not current_span.is_recording()
            or not isinstance(current_span, ReadableSpan)
        ):
            return

        span_context = current_span.get_span_context()
        if self.get_internal_attribute(
            span_context, InternalAttributeKeys.DISABLE_PARTIAL_EMIT, False
        ):
            return

        attributes = dict(current_span.attributes or {})
        attributes[AttributeKeys.JUDGMENT_UPDATE_ID] = self.increment_update_id(
            span_context
        )

        partial_span = self._copy_span(
            current_span, span_context, attributes, end_time=None
        )
        super().on_end(partial_span)

    def on_end(self, span: ReadableSpan) -> None:
        """Route an ending span to the batch exporter.

        * A span without a context is forwarded unchanged.
        * A span whose ``CANCELLED`` internal attribute is set is dropped (not
          exported) and its internal state is discarded.
        * An ended span (``end_time`` set) is exported as a copy tagged with
          ``_FINAL_UPDATE_ID``, and its internal state is discarded.
        * Anything else is forwarded unchanged.
        """
        if not span.context:
            super().on_end(span)
            return

        span_key = self._get_span_key(span.context)

        if self.get_internal_attribute(
            span.context, InternalAttributeKeys.CANCELLED, False
        ):
            self._cleanup_span_state(span_key)
            return

        if span.end_time is None:
            super().on_end(span)
            return

        attributes = dict(span.attributes or {})
        attributes[AttributeKeys.JUDGMENT_UPDATE_ID] = self._FINAL_UPDATE_ID
        final_span = self._copy_span(
            span, span.context, attributes, end_time=span.end_time
        )

        self._cleanup_span_state(span_key)
        super().on_end(final_span)
183
+
184
+
185
class NoOpJudgmentSpanProcessor(JudgmentSpanProcessor):
    """Drop-in ``JudgmentSpanProcessor`` that records and exports nothing.

    Every public hook is overridden with a no-op, and internal-attribute reads
    always return their ``default``.
    """

    # ``resource_attributes`` is already a slot on ``JudgmentSpanProcessor``.
    # Re-declaring it here would shadow the base-class slot, which the Python
    # data model documents as undefined behaviour. An empty tuple keeps
    # instances using the inherited slots.
    __slots__ = ()

    def __init__(self):
        # JudgmentSpanProcessor.__init__ is deliberately not called. As a
        # result, no JudgmentSpanExporter is constructed and BatchSpanProcessor
        # is never initialised. Only the attribute callers read is set.
        self.resource_attributes = {}

    def on_start(self, span: Span, parent_context: Optional[Context] = None) -> None:
        """Ignore span start."""
        pass

    def on_end(self, span: ReadableSpan) -> None:
        """Ignore span end; nothing is exported."""
        pass

    def shutdown(self) -> None:
        """Nothing to shut down."""
        pass

    def force_flush(self, timeout_millis: int | None = 30000) -> bool:
        """Report success immediately; nothing is buffered."""
        return True

    def emit_partial(self) -> None:
        """Never emits partial spans."""
        pass

    def set_internal_attribute(
        self, span_context: SpanContext, key: str, value: Any
    ) -> None:
        """Discard the value."""
        pass

    def get_internal_attribute(
        self, span_context: SpanContext, key: str, default: Any = None
    ) -> Any:
        """Always return ``default``."""
        return default

    def increment_update_id(self, span_context: SpanContext) -> int:
        """Always return 0."""
        return 0
218
+
219
+
220
# Public names exported by this module.
__all__ = [
    "NoOpSpanProcessor",
    "JudgmentSpanProcessor",
    "NoOpJudgmentSpanProcessor",
]
@@ -0,0 +1,19 @@
1
+ from typing import Any
2
+ from opentelemetry.trace import Span
3
+ from pydantic import BaseModel
4
+ from typing import Callable, Optional
5
+ from judgeval.scorers.api_scorer import TraceAPIScorerConfig
6
+
7
+
8
def set_span_attribute(span: Span, name: str, value: Any):
    """Set attribute ``name`` on ``span``, skipping ``None`` and the empty string."""
    is_blank = value is None or value == ""
    if not is_blank:
        span.set_attribute(name, value)
13
+
14
+
15
class TraceScorerConfig(BaseModel):
    # Fields:
    #   scorer        - API scorer configuration; must be supplied (may be None).
    #   model         - optional model name; None when not given.
    #   sampling_rate - defaults to 1.0 (presumably the fraction of traces
    #                   that get scored — confirm with the consumer).
    #   run_condition - optional predicate; how it is applied is decided by
    #                   the code consuming this config.
    scorer: Optional[TraceAPIScorerConfig]
    model: Optional[str] = None
    sampling_rate: float = 1.0
    run_condition: Optional[Callable[..., bool]] = None
@@ -0,0 +1,14 @@
1
+ from judgeval.trainer.trainer import JudgmentTrainer
2
+ from judgeval.trainer.config import TrainerConfig, ModelConfig
3
+ from judgeval.trainer.trainable_model import TrainableModel
4
+ from judgeval.trainer.base_trainer import BaseTrainer
5
+ from judgeval.trainer.fireworks_trainer import FireworksTrainer
6
+
7
# Public API of this package: the trainer entry point, its configuration
# types, the trainable model wrapper, and the provider base/implementation.
__all__ = [
    "JudgmentTrainer",
    "TrainerConfig",
    "ModelConfig",
    "TrainableModel",
    "BaseTrainer",
    "FireworksTrainer",
]
@@ -0,0 +1,122 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Callable, List, Optional, Union, Dict, TYPE_CHECKING
3
+ from .config import TrainerConfig, ModelConfig
4
+ from judgeval.scorers import ExampleScorer, ExampleAPIScorerConfig
5
+
6
+ if TYPE_CHECKING:
7
+ from judgeval.tracer import Tracer
8
+ from .trainable_model import TrainableModel
9
+
10
+
11
class BaseTrainer(ABC):
    """Interface that every reinforcement-fine-tuning provider implements.

    A concrete subclass wraps one training backend (for example Fireworks) and
    supplies rollout generation, the RL loop, the public ``train`` entry point,
    and extraction of chat history from recorded spans.
    """

    def __init__(
        self,
        config: TrainerConfig,
        trainable_model: "TrainableModel",
        tracer: "Tracer",
        project_name: Optional[str] = None,
    ):
        """Store the collaborators shared by all providers.

        Args:
            config: Training hyper-parameters.
            trainable_model: The model wrapper being fine-tuned.
            tracer: Tracer used for observability.
            project_name: Project under which training runs are grouped. Any
                falsy value (``None`` or ``""``) becomes ``"judgment_training"``.
        """
        self.config = config
        self.trainable_model = trainable_model
        self.tracer = tracer
        self.project_name = project_name if project_name else "judgment_training"

    @abstractmethod
    async def generate_rollouts_and_rewards(
        self,
        agent_function: Callable[[Any], Any],
        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
        prompts: List[Any],
        num_prompts_per_step: Optional[int] = None,
        num_generations_per_prompt: Optional[int] = None,
        concurrency: Optional[int] = None,
    ) -> Any:
        """Produce rollouts from the current model snapshot and score them.

        Args:
            agent_function: Callable invoked to generate a response.
            scorers: Scorers applied to each response to compute rewards.
            prompts: Pool of training prompts.
            num_prompts_per_step: How many prompts to draw per step.
            num_generations_per_prompt: How many responses to sample per prompt.
            concurrency: Upper bound on concurrent generations.

        Returns:
            A dataset in whatever format the provider trains on.
        """
        ...

    @abstractmethod
    async def run_reinforcement_learning(
        self,
        agent_function: Callable[[Any], Any],
        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
        prompts: List[Any],
    ) -> ModelConfig:
        """Execute the iterative RL fine-tuning loop.

        Args:
            agent_function: Callable invoked to generate a response.
            scorers: Scorers applied to each response to compute rewards.
            prompts: Pool of training prompts.

        Returns:
            The ``ModelConfig`` describing the trained model.
        """
        ...

    @abstractmethod
    async def train(
        self,
        agent_function: Callable[[Any], Any],
        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
        prompts: List[Any],
    ) -> ModelConfig:
        """Public entry point that kicks off RL fine-tuning.

        Args:
            agent_function: Callable invoked to generate a response.
            scorers: Scorers applied to each response to compute rewards.
            prompts: Pool of training prompts.

        Returns:
            The ``ModelConfig`` describing the trained model.
        """
        ...

    @abstractmethod
    def _extract_message_history_from_spans(
        self, trace_id: str
    ) -> List[Dict[str, str]]:
        """Rebuild the chat history recorded under one trace.

        Args:
            trace_id: Trace id as a 32-character hex string.

        Returns:
            Message dicts, each with ``role`` and ``content`` keys.
        """
        ...
@@ -0,0 +1,128 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional, Dict, Any, TYPE_CHECKING
5
+ import json
6
+
7
+ if TYPE_CHECKING:
8
+ from fireworks.llm.llm_reinforcement_step import ReinforcementAcceleratorTypeLiteral # type: ignore[import-not-found]
9
+
10
+
11
@dataclass
class TrainerConfig:
    """Parameters consumed by ``JudgmentTrainer``.

    Only ``deployment_id``, ``user_id`` and ``model_id`` are required; every
    other field has a default.
    """

    # Required identifiers.
    deployment_id: str
    user_id: str
    model_id: str

    # Provider and base model.
    base_model_name: str = "qwen2p5-7b-instruct"
    rft_provider: str = "fireworks"  # Supported: "fireworks", "verifiers" (future)

    # Training-loop sizing.
    num_steps: int = 5
    num_generations_per_prompt: int = 4
    num_prompts_per_step: int = 4
    concurrency: int = 100
    epochs: int = 1
    learning_rate: float = 1e-5

    # Hardware; the accelerator type annotation is a literal type from the
    # Fireworks SDK (imported only for type checking).
    accelerator_count: int = 1
    accelerator_type: ReinforcementAcceleratorTypeLiteral = "NVIDIA_A100_80GB"

    # Generation settings and add-on toggle.
    temperature: float = 1.5
    max_tokens: int = 50
    enable_addons: bool = True
31
+
32
+
33
@dataclass
class ModelConfig:
    """
    Configuration class for storing and loading trained model state.

    This class enables persistence of trained models so they can be loaded
    and used later without retraining.

    Example usage:
        trainer = JudgmentTrainer(config)
        model_config = trainer.train(agent_function, scorers, prompts)

        # Save the trained model configuration
        model_config.save_to_file("my_trained_model.json")

        # Later, load and use the trained model
        loaded_config = ModelConfig.load_from_file("my_trained_model.json")
        trained_model = TrainableModel.from_model_config(loaded_config)

        # Use the trained model for inference
        response = trained_model.chat.completions.create(
            model="current",  # Uses the loaded trained model
            messages=[{"role": "user", "content": "Hello!"}]
        )
    """

    # Base model configuration
    base_model_name: str
    deployment_id: str
    user_id: str
    model_id: str
    enable_addons: bool

    # Training state
    current_step: int
    total_steps: int

    # Current model information
    current_model_name: Optional[str] = None
    is_trained: bool = False

    # Training parameters used (for reference)
    training_params: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return every field as a plain dict, suitable for ``json.dumps``."""
        return {
            "base_model_name": self.base_model_name,
            "deployment_id": self.deployment_id,
            "user_id": self.user_id,
            "model_id": self.model_id,
            "enable_addons": self.enable_addons,
            "current_step": self.current_step,
            "total_steps": self.total_steps,
            "current_model_name": self.current_model_name,
            "is_trained": self.is_trained,
            "training_params": self.training_params,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> ModelConfig:
        """Build a ModelConfig from a dict produced by ``to_dict``.

        Missing keys fall back to defaults:
        ``base_model_name="qwen2p5-7b-instruct"``,
        ``deployment_id="my-base-deployment"``, ``user_id=""``, ``model_id=""``,
        ``enable_addons=True``, ``current_step=0``, ``total_steps=0``,
        ``current_model_name=None``, ``is_trained=False`` and
        ``training_params=None``.
        """
        return cls(
            base_model_name=data.get("base_model_name", "qwen2p5-7b-instruct"),
            deployment_id=data.get("deployment_id", "my-base-deployment"),
            user_id=data.get("user_id", ""),
            model_id=data.get("model_id", ""),
            enable_addons=data.get("enable_addons", True),
            current_step=data.get("current_step", 0),
            total_steps=data.get("total_steps", 0),
            current_model_name=data.get("current_model_name"),
            is_trained=data.get("is_trained", False),
            training_params=data.get("training_params"),
        )

    def to_json(self) -> str:
        """Serialize to a JSON string indented by two spaces."""
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def from_json(cls, json_str: str) -> ModelConfig:
        """Parse a JSON string and build a ModelConfig via ``from_dict``."""
        data = json.loads(json_str)
        return cls.from_dict(data)

    def save_to_file(self, filepath: str) -> None:
        """Write ``to_json()`` to ``filepath``, replacing any existing file.

        The encoding is pinned to UTF-8 so the result does not depend on the
        platform's locale encoding.
        """
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(self.to_json())

    @classmethod
    def load_from_file(cls, filepath: str) -> ModelConfig:
        """Read a JSON file (decoded as UTF-8) and build a ModelConfig from it.

        Decoding is explicit so that hand-edited files containing non-ASCII
        text load the same way on every platform.
        """
        with open(filepath, "r", encoding="utf-8") as f:
            return cls.from_json(f.read())
@@ -0,0 +1,144 @@
1
+ from contextlib import contextmanager
2
+ from typing import Optional
3
+ import sys
4
+ import os
5
+ from judgeval.utils.decorators.use_once import use_once
6
+
7
+
8
@use_once
def _is_jupyter_environment():
    """Heuristically decide whether we are inside a notebook-like environment.

    Returns True when ``ipykernel`` or ``IPython`` has been imported, when the
    ``JPY_PARENT_PID`` environment variable is set, or when ``google.colab`` is
    loaded; False otherwise, including if any check raises.

    NOTE(review): ``IPython`` being imported is also true in a plain terminal
    IPython shell, which is then treated like a notebook.
    """
    try:
        kernel_modules = ("ipykernel", "IPython")
        if any(module in sys.modules for module in kernel_modules):
            return True
        if "JPY_PARENT_PID" in os.environ:
            return True
        return "google.colab" in sys.modules
    except Exception:
        return False
24
+
25
+
26
IS_JUPYTER = _is_jupyter_environment()

# Rich console output is only used outside notebooks, and only when the
# optional ``rich`` package imports cleanly; otherwise plain ``print`` is used.
RICH_AVAILABLE = False
if not IS_JUPYTER:
    try:
        from rich.console import Console
        from rich.spinner import Spinner
        from rich.live import Live
        from rich.text import Text

        shared_console = Console()
    except ImportError:
        pass
    else:
        RICH_AVAILABLE = True
41
+
42
+
43
class SimpleSpinner:
    """Plain-text stand-in for a spinner: it only remembers its ``text``."""

    def __init__(self, name, text):
        # ``name`` is accepted but not stored.
        del name
        self.text = text
46
+
47
+
48
class SimpleLive:
    """Plain-text stand-in for a live display that prints each spinner's text."""

    def __init__(self, spinner, console=None, refresh_per_second=None):
        # ``console`` and ``refresh_per_second`` mirror the arguments passed to
        # rich's ``Live`` elsewhere in this module; they are ignored here.
        self.spinner = spinner

    @staticmethod
    def _show(spinner):
        """Print one status line for ``spinner``."""
        print(f"🔄 {spinner.text}")

    def __enter__(self):
        self._show(self.spinner)
        return self

    def __exit__(self, *args):
        # Returning None lets any exception from the ``with`` body propagate.
        return None

    def update(self, spinner):
        """Print the new spinner text (the stored spinner is not replaced)."""
        self._show(spinner)
61
+
62
+
63
def safe_print(message, style=None):
    """Print ``message`` with rich when it is usable, otherwise as plain text.

    In the plain-text fallback the ``green``, ``yellow``, ``blue`` and ``cyan``
    styles become an emoji prefix; any other style is ignored.
    """
    if RICH_AVAILABLE and not IS_JUPYTER:
        shared_console.print(message, style=style)
        return

    for style_name, marker in (
        ("green", "✅"),
        ("yellow", "⚠️"),
        ("blue", "🔵"),
        ("cyan", "🔷"),
    ):
        if style == style_name:
            print(f"{marker} {message}")
            return
    print(message)
78
+
79
+
80
@contextmanager
def _spinner_progress(
    message: str, step: Optional[int] = None, total_steps: Optional[int] = None
):
    """Show a progress indicator for the duration of a ``with`` block.

    The label is ``[Step {step}/{total_steps}] {message}`` when both step
    values are given, otherwise ``[Training] {message}``. With rich available
    (and outside notebooks) a cyan spinner is shown. Otherwise a start line is
    printed, and a completion line is printed only if the block finishes
    without raising.
    """
    if step is not None and total_steps is not None:
        full_message = f"[Step {step}/{total_steps}] {message}"
    else:
        full_message = f"[Training] {message}"

    if RICH_AVAILABLE and not IS_JUPYTER:
        spinner = Spinner("dots", text=Text(full_message, style="cyan"))
        with Live(spinner, console=shared_console, refresh_per_second=10):
            yield
    else:
        print(f"🔄 {full_message}")
        # An exception raised in the ``with`` body is re-raised at this
        # ``yield``, so the completion line is not printed for a failed
        # operation. The previous ``finally`` printed "Complete" even then.
        yield
        print(f"✅ {full_message} - Complete")
100
+
101
+
102
@contextmanager
def _model_spinner_progress(message: str):
    """Show a ``[Model]`` progress indicator and yield a sub-status reporter.

    The yielded callable takes one progress string. With rich available (and
    outside notebooks) it rewrites the blue spinner text to show that string
    under the header. Otherwise it prints the string as an indented line.
    """
    header = f"[Model] {message}"

    if not RICH_AVAILABLE or IS_JUPYTER:
        print(f"🔵 {header}")

        def report(progress_message: str):
            print(f" └─ {progress_message}")

        yield report
        return

    spinner = Spinner("dots", text=Text(header, style="blue"))
    with Live(spinner, console=shared_console, refresh_per_second=10) as live:

        def report(progress_message: str):
            spinner.text = Text(f"{header}\n └─ {progress_message}", style="blue")
            live.update(spinner)

        yield report
123
+
124
+
125
def _print_progress(
    message: str, step: Optional[int] = None, total_steps: Optional[int] = None
):
    """Print a green progress line labelled with the step, or ``[Training]``."""
    has_step_info = step is not None and total_steps is not None
    label = f"[Step {step}/{total_steps}]" if has_step_info else "[Training]"
    safe_print(f"{label} {message}", style="green")
133
+
134
+
135
def _print_progress_update(
    message: str, step: Optional[int] = None, total_steps: Optional[int] = None
):
    """Print an indented, yellow status line for a long-running operation."""
    # ``step`` and ``total_steps`` are accepted but do not appear in the output.
    line = f" └─ {message}"
    safe_print(line, style="yellow")
140
+
141
+
142
def _print_model_progress(message: str):
    """Print a blue ``[Model]``-prefixed progress line."""
    line = f"[Model] {message}"
    safe_print(line, style="blue")