judgeval 0.1.0__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (234) hide show
  1. judgeval/__init__.py +173 -10
  2. judgeval/api/__init__.py +523 -0
  3. judgeval/api/api_types.py +413 -0
  4. judgeval/cli.py +112 -0
  5. judgeval/constants.py +7 -30
  6. judgeval/data/__init__.py +1 -3
  7. judgeval/data/evaluation_run.py +125 -0
  8. judgeval/data/example.py +14 -40
  9. judgeval/data/judgment_types.py +396 -146
  10. judgeval/data/result.py +11 -18
  11. judgeval/data/scorer_data.py +3 -26
  12. judgeval/data/scripts/openapi_transform.py +5 -5
  13. judgeval/data/trace.py +115 -194
  14. judgeval/dataset/__init__.py +335 -0
  15. judgeval/env.py +55 -0
  16. judgeval/evaluation/__init__.py +346 -0
  17. judgeval/exceptions.py +28 -0
  18. judgeval/integrations/langgraph/__init__.py +13 -0
  19. judgeval/integrations/openlit/__init__.py +51 -0
  20. judgeval/judges/__init__.py +2 -2
  21. judgeval/judges/litellm_judge.py +77 -16
  22. judgeval/judges/together_judge.py +88 -17
  23. judgeval/judges/utils.py +7 -20
  24. judgeval/judgment_attribute_keys.py +55 -0
  25. judgeval/{common/logger.py → logger.py} +24 -8
  26. judgeval/prompt/__init__.py +330 -0
  27. judgeval/scorers/__init__.py +11 -11
  28. judgeval/scorers/agent_scorer.py +15 -19
  29. judgeval/scorers/api_scorer.py +21 -23
  30. judgeval/scorers/base_scorer.py +54 -36
  31. judgeval/scorers/example_scorer.py +1 -3
  32. judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +2 -24
  33. judgeval/scorers/judgeval_scorers/api_scorers/answer_correctness.py +2 -10
  34. judgeval/scorers/judgeval_scorers/api_scorers/answer_relevancy.py +2 -2
  35. judgeval/scorers/judgeval_scorers/api_scorers/faithfulness.py +2 -10
  36. judgeval/scorers/judgeval_scorers/api_scorers/instruction_adherence.py +2 -14
  37. judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +171 -59
  38. judgeval/scorers/score.py +64 -47
  39. judgeval/scorers/utils.py +2 -107
  40. judgeval/tracer/__init__.py +1111 -2
  41. judgeval/tracer/constants.py +1 -0
  42. judgeval/tracer/exporters/__init__.py +40 -0
  43. judgeval/tracer/exporters/s3.py +119 -0
  44. judgeval/tracer/exporters/store.py +59 -0
  45. judgeval/tracer/exporters/utils.py +32 -0
  46. judgeval/tracer/keys.py +63 -0
  47. judgeval/tracer/llm/__init__.py +7 -0
  48. judgeval/tracer/llm/config.py +78 -0
  49. judgeval/tracer/llm/constants.py +9 -0
  50. judgeval/tracer/llm/llm_anthropic/__init__.py +3 -0
  51. judgeval/tracer/llm/llm_anthropic/config.py +6 -0
  52. judgeval/tracer/llm/llm_anthropic/messages.py +452 -0
  53. judgeval/tracer/llm/llm_anthropic/messages_stream.py +322 -0
  54. judgeval/tracer/llm/llm_anthropic/wrapper.py +59 -0
  55. judgeval/tracer/llm/llm_google/__init__.py +3 -0
  56. judgeval/tracer/llm/llm_google/config.py +6 -0
  57. judgeval/tracer/llm/llm_google/generate_content.py +127 -0
  58. judgeval/tracer/llm/llm_google/wrapper.py +30 -0
  59. judgeval/tracer/llm/llm_openai/__init__.py +3 -0
  60. judgeval/tracer/llm/llm_openai/beta_chat_completions.py +216 -0
  61. judgeval/tracer/llm/llm_openai/chat_completions.py +501 -0
  62. judgeval/tracer/llm/llm_openai/config.py +6 -0
  63. judgeval/tracer/llm/llm_openai/responses.py +506 -0
  64. judgeval/tracer/llm/llm_openai/utils.py +42 -0
  65. judgeval/tracer/llm/llm_openai/wrapper.py +63 -0
  66. judgeval/tracer/llm/llm_together/__init__.py +3 -0
  67. judgeval/tracer/llm/llm_together/chat_completions.py +406 -0
  68. judgeval/tracer/llm/llm_together/config.py +6 -0
  69. judgeval/tracer/llm/llm_together/wrapper.py +52 -0
  70. judgeval/tracer/llm/providers.py +19 -0
  71. judgeval/tracer/managers.py +167 -0
  72. judgeval/tracer/processors/__init__.py +220 -0
  73. judgeval/tracer/utils.py +19 -0
  74. judgeval/trainer/__init__.py +14 -0
  75. judgeval/trainer/base_trainer.py +122 -0
  76. judgeval/trainer/config.py +123 -0
  77. judgeval/trainer/console.py +144 -0
  78. judgeval/trainer/fireworks_trainer.py +392 -0
  79. judgeval/trainer/trainable_model.py +252 -0
  80. judgeval/trainer/trainer.py +70 -0
  81. judgeval/utils/async_utils.py +39 -0
  82. judgeval/utils/decorators/__init__.py +0 -0
  83. judgeval/utils/decorators/dont_throw.py +37 -0
  84. judgeval/utils/decorators/use_once.py +13 -0
  85. judgeval/utils/file_utils.py +74 -28
  86. judgeval/utils/guards.py +36 -0
  87. judgeval/utils/meta.py +27 -0
  88. judgeval/utils/project.py +15 -0
  89. judgeval/utils/serialize.py +253 -0
  90. judgeval/utils/testing.py +70 -0
  91. judgeval/utils/url.py +10 -0
  92. judgeval/{version_check.py → utils/version_check.py} +5 -3
  93. judgeval/utils/wrappers/README.md +3 -0
  94. judgeval/utils/wrappers/__init__.py +15 -0
  95. judgeval/utils/wrappers/immutable_wrap_async.py +74 -0
  96. judgeval/utils/wrappers/immutable_wrap_async_iterator.py +84 -0
  97. judgeval/utils/wrappers/immutable_wrap_sync.py +66 -0
  98. judgeval/utils/wrappers/immutable_wrap_sync_iterator.py +84 -0
  99. judgeval/utils/wrappers/mutable_wrap_async.py +67 -0
  100. judgeval/utils/wrappers/mutable_wrap_sync.py +67 -0
  101. judgeval/utils/wrappers/py.typed +0 -0
  102. judgeval/utils/wrappers/utils.py +35 -0
  103. judgeval/v1/__init__.py +88 -0
  104. judgeval/v1/data/__init__.py +7 -0
  105. judgeval/v1/data/example.py +44 -0
  106. judgeval/v1/data/scorer_data.py +42 -0
  107. judgeval/v1/data/scoring_result.py +44 -0
  108. judgeval/v1/datasets/__init__.py +6 -0
  109. judgeval/v1/datasets/dataset.py +214 -0
  110. judgeval/v1/datasets/dataset_factory.py +94 -0
  111. judgeval/v1/evaluation/__init__.py +6 -0
  112. judgeval/v1/evaluation/evaluation.py +182 -0
  113. judgeval/v1/evaluation/evaluation_factory.py +17 -0
  114. judgeval/v1/instrumentation/__init__.py +6 -0
  115. judgeval/v1/instrumentation/llm/__init__.py +7 -0
  116. judgeval/v1/instrumentation/llm/config.py +78 -0
  117. judgeval/v1/instrumentation/llm/constants.py +11 -0
  118. judgeval/v1/instrumentation/llm/llm_anthropic/__init__.py +5 -0
  119. judgeval/v1/instrumentation/llm/llm_anthropic/config.py +6 -0
  120. judgeval/v1/instrumentation/llm/llm_anthropic/messages.py +414 -0
  121. judgeval/v1/instrumentation/llm/llm_anthropic/messages_stream.py +307 -0
  122. judgeval/v1/instrumentation/llm/llm_anthropic/wrapper.py +61 -0
  123. judgeval/v1/instrumentation/llm/llm_google/__init__.py +5 -0
  124. judgeval/v1/instrumentation/llm/llm_google/config.py +6 -0
  125. judgeval/v1/instrumentation/llm/llm_google/generate_content.py +121 -0
  126. judgeval/v1/instrumentation/llm/llm_google/wrapper.py +30 -0
  127. judgeval/v1/instrumentation/llm/llm_openai/__init__.py +5 -0
  128. judgeval/v1/instrumentation/llm/llm_openai/beta_chat_completions.py +212 -0
  129. judgeval/v1/instrumentation/llm/llm_openai/chat_completions.py +477 -0
  130. judgeval/v1/instrumentation/llm/llm_openai/config.py +6 -0
  131. judgeval/v1/instrumentation/llm/llm_openai/responses.py +472 -0
  132. judgeval/v1/instrumentation/llm/llm_openai/utils.py +41 -0
  133. judgeval/v1/instrumentation/llm/llm_openai/wrapper.py +63 -0
  134. judgeval/v1/instrumentation/llm/llm_together/__init__.py +5 -0
  135. judgeval/v1/instrumentation/llm/llm_together/chat_completions.py +382 -0
  136. judgeval/v1/instrumentation/llm/llm_together/config.py +6 -0
  137. judgeval/v1/instrumentation/llm/llm_together/wrapper.py +57 -0
  138. judgeval/v1/instrumentation/llm/providers.py +19 -0
  139. judgeval/v1/integrations/claude_agent_sdk/__init__.py +119 -0
  140. judgeval/v1/integrations/claude_agent_sdk/wrapper.py +564 -0
  141. judgeval/v1/integrations/langgraph/__init__.py +13 -0
  142. judgeval/v1/integrations/openlit/__init__.py +47 -0
  143. judgeval/v1/internal/api/__init__.py +525 -0
  144. judgeval/v1/internal/api/api_types.py +413 -0
  145. judgeval/v1/prompts/__init__.py +6 -0
  146. judgeval/v1/prompts/prompt.py +29 -0
  147. judgeval/v1/prompts/prompt_factory.py +189 -0
  148. judgeval/v1/py.typed +0 -0
  149. judgeval/v1/scorers/__init__.py +6 -0
  150. judgeval/v1/scorers/api_scorer.py +82 -0
  151. judgeval/v1/scorers/base_scorer.py +17 -0
  152. judgeval/v1/scorers/built_in/__init__.py +17 -0
  153. judgeval/v1/scorers/built_in/answer_correctness.py +28 -0
  154. judgeval/v1/scorers/built_in/answer_relevancy.py +28 -0
  155. judgeval/v1/scorers/built_in/built_in_factory.py +26 -0
  156. judgeval/v1/scorers/built_in/faithfulness.py +28 -0
  157. judgeval/v1/scorers/built_in/instruction_adherence.py +28 -0
  158. judgeval/v1/scorers/custom_scorer/__init__.py +6 -0
  159. judgeval/v1/scorers/custom_scorer/custom_scorer.py +50 -0
  160. judgeval/v1/scorers/custom_scorer/custom_scorer_factory.py +16 -0
  161. judgeval/v1/scorers/prompt_scorer/__init__.py +6 -0
  162. judgeval/v1/scorers/prompt_scorer/prompt_scorer.py +86 -0
  163. judgeval/v1/scorers/prompt_scorer/prompt_scorer_factory.py +85 -0
  164. judgeval/v1/scorers/scorers_factory.py +49 -0
  165. judgeval/v1/tracer/__init__.py +7 -0
  166. judgeval/v1/tracer/base_tracer.py +520 -0
  167. judgeval/v1/tracer/exporters/__init__.py +14 -0
  168. judgeval/v1/tracer/exporters/in_memory_span_exporter.py +25 -0
  169. judgeval/v1/tracer/exporters/judgment_span_exporter.py +42 -0
  170. judgeval/v1/tracer/exporters/noop_span_exporter.py +19 -0
  171. judgeval/v1/tracer/exporters/span_store.py +50 -0
  172. judgeval/v1/tracer/judgment_tracer_provider.py +70 -0
  173. judgeval/v1/tracer/processors/__init__.py +6 -0
  174. judgeval/v1/tracer/processors/_lifecycles/__init__.py +28 -0
  175. judgeval/v1/tracer/processors/_lifecycles/agent_id_processor.py +53 -0
  176. judgeval/v1/tracer/processors/_lifecycles/context_keys.py +11 -0
  177. judgeval/v1/tracer/processors/_lifecycles/customer_id_processor.py +29 -0
  178. judgeval/v1/tracer/processors/_lifecycles/registry.py +18 -0
  179. judgeval/v1/tracer/processors/judgment_span_processor.py +165 -0
  180. judgeval/v1/tracer/processors/noop_span_processor.py +42 -0
  181. judgeval/v1/tracer/tracer.py +67 -0
  182. judgeval/v1/tracer/tracer_factory.py +38 -0
  183. judgeval/v1/trainers/__init__.py +5 -0
  184. judgeval/v1/trainers/base_trainer.py +62 -0
  185. judgeval/v1/trainers/config.py +123 -0
  186. judgeval/v1/trainers/console.py +144 -0
  187. judgeval/v1/trainers/fireworks_trainer.py +392 -0
  188. judgeval/v1/trainers/trainable_model.py +252 -0
  189. judgeval/v1/trainers/trainers_factory.py +37 -0
  190. judgeval/v1/utils.py +18 -0
  191. judgeval/version.py +5 -0
  192. judgeval/warnings.py +4 -0
  193. judgeval-0.23.0.dist-info/METADATA +266 -0
  194. judgeval-0.23.0.dist-info/RECORD +201 -0
  195. judgeval-0.23.0.dist-info/entry_points.txt +2 -0
  196. judgeval/clients.py +0 -34
  197. judgeval/common/__init__.py +0 -13
  198. judgeval/common/api/__init__.py +0 -3
  199. judgeval/common/api/api.py +0 -352
  200. judgeval/common/api/constants.py +0 -165
  201. judgeval/common/exceptions.py +0 -27
  202. judgeval/common/storage/__init__.py +0 -6
  203. judgeval/common/storage/s3_storage.py +0 -98
  204. judgeval/common/tracer/__init__.py +0 -31
  205. judgeval/common/tracer/constants.py +0 -22
  206. judgeval/common/tracer/core.py +0 -1916
  207. judgeval/common/tracer/otel_exporter.py +0 -108
  208. judgeval/common/tracer/otel_span_processor.py +0 -234
  209. judgeval/common/tracer/span_processor.py +0 -37
  210. judgeval/common/tracer/span_transformer.py +0 -211
  211. judgeval/common/tracer/trace_manager.py +0 -92
  212. judgeval/common/utils.py +0 -940
  213. judgeval/data/datasets/__init__.py +0 -4
  214. judgeval/data/datasets/dataset.py +0 -341
  215. judgeval/data/datasets/eval_dataset_client.py +0 -214
  216. judgeval/data/tool.py +0 -5
  217. judgeval/data/trace_run.py +0 -37
  218. judgeval/evaluation_run.py +0 -75
  219. judgeval/integrations/langgraph.py +0 -843
  220. judgeval/judges/mixture_of_judges.py +0 -286
  221. judgeval/judgment_client.py +0 -369
  222. judgeval/rules.py +0 -521
  223. judgeval/run_evaluation.py +0 -684
  224. judgeval/scorers/judgeval_scorers/api_scorers/derailment_scorer.py +0 -14
  225. judgeval/scorers/judgeval_scorers/api_scorers/execution_order.py +0 -52
  226. judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +0 -28
  227. judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +0 -20
  228. judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +0 -27
  229. judgeval/utils/alerts.py +0 -93
  230. judgeval/utils/requests.py +0 -50
  231. judgeval-0.1.0.dist-info/METADATA +0 -202
  232. judgeval-0.1.0.dist-info/RECORD +0 -73
  233. {judgeval-0.1.0.dist-info → judgeval-0.23.0.dist-info}/WHEEL +0 -0
  234. {judgeval-0.1.0.dist-info → judgeval-0.23.0.dist-info}/licenses/LICENSE.md +0 -0
@@ -0,0 +1,392 @@
1
+ import asyncio
2
+ import json
3
+ from typing import Optional, Callable, Any, List, Union, Dict
4
+ from fireworks import Dataset # type: ignore[import-not-found,import-untyped]
5
+ from .config import TrainerConfig, ModelConfig
6
+ from .base_trainer import BaseTrainer
7
+ from .trainable_model import TrainableModel
8
+ from judgeval.tracer import Tracer
9
+ from judgeval.tracer.exporters.store import SpanStore
10
+ from judgeval.tracer.exporters import InMemorySpanExporter
11
+ from judgeval.tracer.keys import AttributeKeys
12
+ from judgeval import JudgmentClient
13
+ from judgeval.scorers import ExampleScorer, ExampleAPIScorerConfig
14
+ from judgeval.data import Example
15
+ from .console import _spinner_progress, _print_progress, _print_progress_update
16
+ from judgeval.exceptions import JudgmentRuntimeError
17
+
18
+
19
class FireworksTrainer(BaseTrainer):
    """
    Fireworks AI implementation of the training provider.

    This trainer uses Fireworks AI's infrastructure for reinforcement learning
    fine-tuning (RFT) of language models. Each training step generates rollouts
    by calling the supplied agent, scores them through ``JudgmentClient``,
    uploads the scored samples as a Fireworks ``Dataset`` and starts a
    reinforcement step on the trainable model.
    """

    def __init__(
        self,
        config: TrainerConfig,
        trainable_model: TrainableModel,
        tracer: Tracer,
        project_name: Optional[str] = None,
    ):
        """
        Initialize the FireworksTrainer.

        Args:
            config: TrainerConfig instance with training parameters
            trainable_model: TrainableModel instance for Fireworks training
            tracer: Tracer for observability
            project_name: Project name for organizing training runs and evaluations

        Raises:
            JudgmentRuntimeError: If the base trainer, the judgment client or
                the in-memory span plumbing cannot be created.
        """
        try:
            super().__init__(config, trainable_model, tracer, project_name)

            self.judgment_client = JudgmentClient()
            self.span_store = SpanStore()
            self.span_exporter = InMemorySpanExporter(self.span_store)
        except Exception as e:
            raise JudgmentRuntimeError(
                f"Failed to initialize FireworksTrainer: {str(e)}"
            ) from e

    @staticmethod
    def _as_json_object(value: Any) -> Optional[Dict[str, Any]]:
        """Return ``value`` as a dict, decoding it from JSON text when needed.

        Returns None when the value is neither a dict nor a string holding a
        JSON object.
        """
        if isinstance(value, dict):
            return value
        if not isinstance(value, str):
            return None
        try:
            decoded = json.loads(value)
        except json.JSONDecodeError:
            return None
        return decoded if isinstance(decoded, dict) else None

    @classmethod
    def _content_for_role(cls, output: Any, role: str) -> Any:
        """Pick the message content for ``role`` out of a span output.

        When the output is (or decodes to) an object with a ``"messages"``
        list, the content of the first message with the requested role is
        returned. Otherwise the output rendered with ``str()`` is returned.
        """
        fallback = str(output)
        payload = cls._as_json_object(output)
        candidates = payload.get("messages") if payload is not None else None
        # Only a real list is scanned so that a malformed payload (for example
        # ``"messages": null``) falls back to the raw text instead of raising.
        if isinstance(candidates, list):
            for msg in candidates:
                if isinstance(msg, dict) and msg.get("role") == role:
                    return msg.get("content", fallback)
        return fallback

    def _extract_message_history_from_spans(
        self, trace_id: str
    ) -> List[Dict[str, str]]:
        """
        Extract message history from spans in the span store for training purposes.

        Spans are visited in ``start_time`` order. The first LLM span that
        carries input messages seeds the history with its prompt; every LLM
        span with an output then contributes an assistant message, and every
        "user"/"tool" span with a non-empty output contributes a user message.

        Args:
            trace_id: The trace ID (32-char hex string) to extract message history from

        Returns:
            List of message dictionaries with 'role' and 'content' keys
        """
        spans = self.span_store.get_by_trace_id(trace_id)
        if not spans:
            return []

        messages: List[Dict[str, Any]] = []
        prompt_seeded = False

        # A span without a start time sorts first rather than breaking the sort
        # by comparing None with an integer timestamp.
        ordered_spans = sorted(
            spans, key=lambda s: getattr(s, "start_time", None) or 0
        )

        for span in ordered_spans:
            span_attributes = span.attributes or {}
            span_type = span_attributes.get(AttributeKeys.JUDGMENT_SPAN_KIND, "span")
            output = span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)

            if span_type == "llm":
                if not prompt_seeded:
                    # NOTE(review): the input is accepted both as a dict and as
                    # JSON text, mirroring how outputs are decoded below;
                    # confirm which form the tracer actually records.
                    input_payload = self._as_json_object(
                        span_attributes.get(AttributeKeys.JUDGMENT_INPUT)
                    )
                    input_messages = (
                        input_payload.get("messages") if input_payload else None
                    )
                    if isinstance(input_messages, list) and input_messages:
                        prompt_seeded = True
                        messages.extend(
                            {"role": msg["role"], "content": msg["content"]}
                            for msg in input_messages
                            if isinstance(msg, dict)
                            and "role" in msg
                            and "content" in msg
                        )

                # Add assistant response from span output
                if output is not None:
                    messages.append(
                        {
                            "role": "assistant",
                            "content": self._content_for_role(output, "assistant"),
                        }
                    )

            elif span_type in ("user", "tool") and output:
                # Both user and tool spans are recorded under the "user" role.
                messages.append(
                    {"role": "user", "content": self._content_for_role(output, "user")}
                )

        return messages

    async def generate_rollouts_and_rewards(
        self,
        agent_function: Callable[[Any], Any],
        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
        prompts: List[Any],
        num_prompts_per_step: Optional[int] = None,
        num_generations_per_prompt: Optional[int] = None,
        concurrency: Optional[int] = None,
    ):
        """
        Generate rollouts and compute rewards using the current model snapshot.
        Each sample contains multiple generations for reinforcement learning optimization.

        Args:
            agent_function: Async callable invoked as ``agent_function(**prompt)``;
                its result must expose ``.get("messages", ...)``
            scorers: List of scorer objects to evaluate responses
            prompts: List of prompts to use for training (each one a mapping of
                keyword arguments for ``agent_function``)
            num_prompts_per_step: Number of prompts to use per step (defaults to config value, limited by prompts list length)
            num_generations_per_prompt: Generations per prompt (defaults to config value)
            concurrency: Concurrency limit (defaults to config value)

        Returns:
            List of dataset rows, one per prompt, each of the form
            ``{"samples": [{"messages": ..., "evals": {"score": ...}}, ...]}``
        """
        num_prompts_per_step = min(
            num_prompts_per_step or self.config.num_prompts_per_step, len(prompts)
        )
        num_generations_per_prompt = (
            num_generations_per_prompt or self.config.num_generations_per_prompt
        )
        concurrency = concurrency or self.config.concurrency

        semaphore = asyncio.Semaphore(concurrency)

        @self.tracer.observe(span_type="function")
        async def generate_single_response(prompt_id, generation_id):
            async with semaphore:
                prompt_input = prompts[prompt_id]
                response_data = await agent_function(**prompt_input)
                messages = response_data.get("messages", [])

                current_span = self.tracer.get_current_span()
                trace_id = None
                if current_span and current_span.is_recording():
                    # Convert trace_id to hex string per OTEL spec
                    trace_id = format(current_span.get_span_context().trace_id, "032x")

                # Prefer the history reconstructed from the trace; extraction is
                # best-effort, so a failure keeps the agent-reported messages.
                try:
                    if trace_id is not None:
                        traced_messages = self._extract_message_history_from_spans(
                            trace_id
                        )
                        if traced_messages:
                            messages = traced_messages
                except Exception as e:
                    print(f"Warning: Failed to get message history from trace: {e}")
                finally:
                    # Always drop the stored spans for this trace once consumed.
                    if trace_id is not None:
                        self.span_store.clear_trace(trace_id)

                example = Example(
                    input=prompt_input,
                    messages=messages,
                    actual_output=response_data,
                )

                # NOTE(review): this call is synchronous and runs while the
                # semaphore is held, so it occupies the event loop for its
                # whole duration.
                scoring_results = self.judgment_client.run_evaluation(
                    examples=[example],
                    scorers=scorers,
                    project_name=self.project_name,
                    eval_run_name=f"training_step_{self.trainable_model.current_step}_prompt_{prompt_id}_gen_{generation_id}",
                )

                # Reward is the mean of all available scorer scores, or 0.0
                # when nothing was scored.
                reward = 0.0
                if scoring_results and scoring_results[0].scorers_data:
                    scores = [
                        scorer_data.score
                        for scorer_data in scoring_results[0].scorers_data
                        if scorer_data.score is not None
                    ]
                    if scores:
                        reward = sum(scores) / len(scores)

                return {
                    "prompt_id": prompt_id,
                    "generation_id": generation_id,
                    "messages": messages,
                    "evals": {"score": reward},
                }

        # Schedule every rollout up front; the semaphore bounds how many run at
        # the same time.
        tasks = [
            asyncio.ensure_future(generate_single_response(prompt_id, generation_id))
            for prompt_id in range(num_prompts_per_step)
            for generation_id in range(num_generations_per_prompt)
        ]

        results = []
        with _spinner_progress(f"Generating {len(tasks)} rollouts..."):
            try:
                for finished in asyncio.as_completed(tasks):
                    results.append(await finished)
            finally:
                # If a rollout failed (or this call was cancelled), stop the
                # remaining rollouts instead of leaving them running unattended.
                unfinished = [task for task in tasks if not task.done()]
                for task in unfinished:
                    task.cancel()
                if unfinished:
                    await asyncio.gather(*unfinished, return_exceptions=True)

        _print_progress(f"Generated {len(results)} rollouts successfully")

        # Group generations by prompt in a single pass; within a prompt the
        # samples keep their completion order.
        samples_by_prompt: Dict[int, List[Dict[str, Any]]] = {
            prompt_id: [] for prompt_id in range(num_prompts_per_step)
        }
        for result in results:
            samples_by_prompt[result["prompt_id"]].append(
                {"messages": result["messages"], "evals": result["evals"]}
            )

        return [
            {"samples": samples_by_prompt[prompt_id]}
            for prompt_id in range(num_prompts_per_step)
        ]

    async def run_reinforcement_learning(
        self,
        agent_function: Callable[[Any], Any],
        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
        prompts: List[Any],
    ) -> ModelConfig:
        """
        Run the iterative reinforcement learning fine-tuning loop.

        This method performs multiple steps of reinforcement learning, where each step:
        1. Advances to the appropriate model snapshot
        2. Generates rollouts and computes rewards using scorers
        3. Trains a new model using reinforcement learning
        4. Waits for training completion

        The loop starts at ``trainable_model.current_step`` and runs up to
        ``config.num_steps``.

        Args:
            agent_function: Function/agent to call for generating responses
            scorers: List of scorer objects to evaluate responses
            prompts: List of prompts to use for training

        Returns:
            ModelConfig: Configuration of the trained model for inference and future training

        Raises:
            JudgmentRuntimeError: If the training job disappears while it is
                being polled.
        """

        _print_progress("Starting reinforcement learning training")

        total_steps = self.config.num_steps
        training_params = {
            "num_steps": self.config.num_steps,
            "num_prompts_per_step": self.config.num_prompts_per_step,
            "num_generations_per_prompt": self.config.num_generations_per_prompt,
            "epochs": self.config.epochs,
            "learning_rate": self.config.learning_rate,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
        }

        start_step = self.trainable_model.current_step

        for step in range(start_step, total_steps):
            step_num = step + 1
            _print_progress(
                f"Starting training step {step_num}", step_num, total_steps
            )

            self.trainable_model.advance_to_next_step(step)

            dataset_rows = await self.generate_rollouts_and_rewards(
                agent_function, scorers, prompts
            )

            with _spinner_progress(
                "Preparing training dataset", step_num, total_steps
            ):
                dataset = Dataset.from_list(dataset_rows)
                dataset.sync()

            _print_progress(
                "Starting reinforcement training", step_num, total_steps
            )
            job = self.trainable_model.perform_reinforcement_step(dataset, step)

            last_state = None
            with _spinner_progress(
                "Training job in progress", step_num, total_steps
            ):
                while not job.is_completed:
                    job.raise_if_bad_state()
                    current_state = job.state

                    # Report each state transition once rather than on every poll.
                    if current_state != last_state:
                        if current_state in ["uploading", "validating"]:
                            _print_progress_update(
                                f"Training job: {current_state} data"
                            )
                        elif current_state == "training":
                            _print_progress_update(
                                "Training job: model training in progress"
                            )
                        else:
                            _print_progress_update(f"Training job: {current_state}")
                        last_state = current_state

                    await asyncio.sleep(10)
                    job = job.get()
                    if job is None:
                        raise JudgmentRuntimeError(
                            "Training job was deleted while waiting for completion"
                        )

            _print_progress(
                f"Training completed! New model: {job.output_model}",
                step_num,
                total_steps,
            )

        _print_progress("All training steps completed!")

        with _spinner_progress("Deploying final trained model"):
            self.trainable_model.advance_to_next_step(self.config.num_steps)

        return self.trainable_model.get_model_config(training_params)

    async def train(
        self,
        agent_function: Callable[[Any], Any],
        scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
        prompts: List[Any],
    ) -> ModelConfig:
        """
        Start the reinforcement learning fine-tuning process.

        This is the main entry point for running the reinforcement learning training.

        Args:
            agent_function: Function/agent to call for generating responses.
            scorers: List of scorer objects to evaluate responses
            prompts: List of prompts to use for training

        Returns:
            ModelConfig: Configuration of the trained model for future loading

        Raises:
            JudgmentRuntimeError: For any failure during training; errors of
                another type are wrapped, with the original kept as the cause.
        """
        try:
            return await self.run_reinforcement_learning(
                agent_function, scorers, prompts
            )
        except JudgmentRuntimeError:
            # Re-raise JudgmentRuntimeError as-is
            raise
        except Exception as e:
            raise JudgmentRuntimeError(f"Training process failed: {str(e)}") from e
@@ -0,0 +1,252 @@
1
+ import time
2
+ from fireworks import LLM # type: ignore[import-not-found,import-untyped]
3
+ from .config import TrainerConfig, ModelConfig
4
+ from typing import Optional, Dict, Any, Callable
5
+ from .console import _model_spinner_progress, _print_model_progress
6
+ from judgeval.exceptions import JudgmentRuntimeError
7
+
8
+
9
class TrainableModel:
    """
    A wrapper class for managing model snapshots during training.

    This class automatically handles model snapshot creation and management
    during the RFT (Reinforcement Fine-Tuning) process,
    abstracting away manual snapshot management from users.
    """

    # Training/deployment settings this model was built from.
    config: "TrainerConfig"
    # Training step the current snapshot corresponds to (0 means base model).
    current_step: int
    # Model currently used for inference and as the source of the next step.
    _current_model: "LLM"
    # Optional callback re-applied to every newly deployed snapshot.
    _tracer_wrapper_func: Optional[Callable]
    # Deployment created from ``config.base_model_name``.
    _base_model: "LLM"

    def __init__(self, config: "TrainerConfig"):
        """
        Initialize the TrainableModel.

        Creates and deploys the base model and makes it the current model.

        Args:
            config: TrainerConfig instance with model configuration

        Raises:
            JudgmentRuntimeError: If the base model cannot be created.
        """
        try:
            self.config = config
            self.current_step = 0
            self._tracer_wrapper_func = None

            self._base_model = self._create_base_model()
            self._current_model = self._base_model
        except Exception as e:
            raise JudgmentRuntimeError(
                f"Failed to initialize TrainableModel: {str(e)}"
            ) from e

    @classmethod
    def from_model_config(cls, model_config: "ModelConfig") -> "TrainableModel":
        """
        Create a TrainableModel from a saved ModelConfig.

        Args:
            model_config: ModelConfig instance with saved model state

        Returns:
            TrainableModel instance configured to use the saved model
        """
        # Create a TrainerConfig from the ModelConfig
        # NOTE(review): only the identity fields are carried over; the saved
        # total_steps / training_params are not restored here - confirm that
        # is intended.
        trainer_config = TrainerConfig(
            base_model_name=model_config.base_model_name,
            deployment_id=model_config.deployment_id,
            user_id=model_config.user_id,
            model_id=model_config.model_id,
            enable_addons=model_config.enable_addons,
        )

        instance = cls(trainer_config)
        instance.current_step = model_config.current_step

        if model_config.is_trained and model_config.current_model_name:
            instance._load_trained_model(model_config.current_model_name)

        return instance

    def _snapshot_model_name(self, step: int) -> str:
        """Return the fully qualified model name of the snapshot for ``step``."""
        return f"accounts/{self.config.user_id}/models/{self.config.model_id}-v{step}"

    def _create_base_model(self):
        """Create and configure the base model."""
        try:
            with _model_spinner_progress(
                "Creating and deploying base model..."
            ) as update_progress:
                update_progress("Creating base model instance...")
                base_model = LLM(
                    model=self.config.base_model_name,
                    deployment_type="on-demand",
                    id=self.config.deployment_id,
                    enable_addons=self.config.enable_addons,
                )
                update_progress("Applying deployment configuration...")
                base_model.apply()
            _print_model_progress("Base model deployment ready")
            return base_model
        except Exception as e:
            raise JudgmentRuntimeError(
                f"Failed to create and deploy base model '{self.config.base_model_name}': {str(e)}"
            ) from e

    def _deploy_lora_snapshot(
        self,
        model_name: str,
        spinner_message: str,
        create_message: str,
        ready_message: str,
    ) -> None:
        """Deploy ``model_name`` as an on-demand LoRA model and make it current.

        The registered tracer wrapper, if any, is re-applied to the new model.
        """
        with _model_spinner_progress(spinner_message) as update_progress:
            update_progress(create_message)
            self._current_model = LLM(
                model=model_name,
                deployment_type="on-demand-lora",
                base_id=self.config.deployment_id,
            )
            update_progress("Applying deployment configuration...")
            self._current_model.apply()
        _print_model_progress(ready_message)

        if self._tracer_wrapper_func:
            self._tracer_wrapper_func(self._current_model)

    def _load_trained_model(self, model_name: str):
        """Load a trained model by name."""
        try:
            self._deploy_lora_snapshot(
                model_name,
                f"Loading and deploying trained model: {model_name}",
                "Creating trained model instance...",
                "Trained model deployment ready",
            )
        except Exception as e:
            raise JudgmentRuntimeError(
                f"Failed to load and deploy trained model '{model_name}': {str(e)}"
            ) from e

    def get_current_model(self):
        """Return the model object currently in use."""
        return self._current_model

    @property
    def chat(self):
        """OpenAI-compatible chat interface."""
        return self._current_model.chat

    @property
    def completions(self):
        """OpenAI-compatible completions interface."""
        return self._current_model.completions

    def advance_to_next_step(self, step: int):
        """
        Advance to the next training step and update the current model snapshot.

        Step 0 selects the base model; any other step deploys the snapshot
        named ``<model_id>-v<step>``.

        Args:
            step: The current training step number

        Raises:
            JudgmentRuntimeError: If the snapshot cannot be deployed.
        """
        try:
            self.current_step = step

            if step == 0:
                self._current_model = self._base_model
            else:
                model_name = self._snapshot_model_name(step)
                self._deploy_lora_snapshot(
                    model_name,
                    f"Creating and deploying model snapshot: {model_name}",
                    "Creating model snapshot instance...",
                    "Model snapshot deployment ready",
                )
        except Exception as e:
            raise JudgmentRuntimeError(
                f"Failed to advance to training step {step}: {str(e)}"
            ) from e

    def perform_reinforcement_step(
        self, dataset, step: int, max_retries: int = 3, initial_backoff: float = 1.0
    ):
        """
        Perform a reinforcement learning step using the current model.

        The request is retried with exponential backoff
        (``initial_backoff * 2**attempt`` seconds between attempts). At least
        one attempt is always made, even when ``max_retries`` is zero or
        negative.

        Args:
            dataset: Training dataset for the reinforcement step
            step: Current step number for output model naming
            max_retries: Maximum number of retry attempts (default: 3)
            initial_backoff: Initial backoff time in seconds for exponential backoff (default: 1.0)

        Returns:
            Training job object

        Raises:
            JudgmentRuntimeError: If every attempt fails.
        """
        model_name = f"{self.config.model_id}-v{step + 1}"

        # Guarantee one attempt so a non-positive max_retries cannot make this
        # method fall through and hand the caller None instead of a job.
        attempts = max(1, max_retries)

        for attempt in range(attempts):
            try:
                return self._current_model.reinforcement_step(
                    dataset=dataset,
                    output_model=model_name,
                    epochs=self.config.epochs,
                    learning_rate=self.config.learning_rate,
                )
            except Exception as e:
                if attempt == attempts - 1:
                    raise JudgmentRuntimeError(
                        f"Failed to start reinforcement learning step {step + 1} after {attempts} attempts: {str(e)}"
                    ) from e
                # NOTE: blocking sleep; callers on an event loop are paused
                # for the backoff duration.
                time.sleep(initial_backoff * (2**attempt))

    def get_model_config(
        self, training_params: Optional[Dict[str, Any]] = None
    ) -> "ModelConfig":
        """
        Get the current model configuration for persistence.

        Args:
            training_params: Optional training parameters to include in config

        Returns:
            ModelConfig instance with current model state
        """
        # A model counts as trained once at least one step has been taken.
        is_trained = self.current_step > 0
        current_model_name = (
            self._snapshot_model_name(self.current_step) if is_trained else None
        )

        return ModelConfig(
            base_model_name=self.config.base_model_name,
            deployment_id=self.config.deployment_id,
            user_id=self.config.user_id,
            model_id=self.config.model_id,
            enable_addons=self.config.enable_addons,
            current_step=self.current_step,
            total_steps=self.config.num_steps,
            current_model_name=current_model_name,
            is_trained=is_trained,
            training_params=training_params,
        )

    def save_model_config(
        self, filepath: str, training_params: Optional[Dict[str, Any]] = None
    ):
        """
        Save the current model configuration to a file.

        Args:
            filepath: Path to save the configuration file
            training_params: Optional training parameters to include in config
        """
        model_config = self.get_model_config(training_params)
        model_config.save_to_file(filepath)

    def _register_tracer_wrapper(self, wrapper_func: Callable):
        """
        Register a tracer wrapper function to be reapplied when models change.

        This is called internally by the tracer's wrap() function to ensure
        that new model instances created during training are automatically wrapped.

        Args:
            wrapper_func: Function that wraps a model instance with tracing
        """
        self._tracer_wrapper_func = wrapper_func