judgeval 0.1.0__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (234)
  1. judgeval/__init__.py +173 -10
  2. judgeval/api/__init__.py +523 -0
  3. judgeval/api/api_types.py +413 -0
  4. judgeval/cli.py +112 -0
  5. judgeval/constants.py +7 -30
  6. judgeval/data/__init__.py +1 -3
  7. judgeval/data/evaluation_run.py +125 -0
  8. judgeval/data/example.py +14 -40
  9. judgeval/data/judgment_types.py +396 -146
  10. judgeval/data/result.py +11 -18
  11. judgeval/data/scorer_data.py +3 -26
  12. judgeval/data/scripts/openapi_transform.py +5 -5
  13. judgeval/data/trace.py +115 -194
  14. judgeval/dataset/__init__.py +335 -0
  15. judgeval/env.py +55 -0
  16. judgeval/evaluation/__init__.py +346 -0
  17. judgeval/exceptions.py +28 -0
  18. judgeval/integrations/langgraph/__init__.py +13 -0
  19. judgeval/integrations/openlit/__init__.py +51 -0
  20. judgeval/judges/__init__.py +2 -2
  21. judgeval/judges/litellm_judge.py +77 -16
  22. judgeval/judges/together_judge.py +88 -17
  23. judgeval/judges/utils.py +7 -20
  24. judgeval/judgment_attribute_keys.py +55 -0
  25. judgeval/{common/logger.py → logger.py} +24 -8
  26. judgeval/prompt/__init__.py +330 -0
  27. judgeval/scorers/__init__.py +11 -11
  28. judgeval/scorers/agent_scorer.py +15 -19
  29. judgeval/scorers/api_scorer.py +21 -23
  30. judgeval/scorers/base_scorer.py +54 -36
  31. judgeval/scorers/example_scorer.py +1 -3
  32. judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +2 -24
  33. judgeval/scorers/judgeval_scorers/api_scorers/answer_correctness.py +2 -10
  34. judgeval/scorers/judgeval_scorers/api_scorers/answer_relevancy.py +2 -2
  35. judgeval/scorers/judgeval_scorers/api_scorers/faithfulness.py +2 -10
  36. judgeval/scorers/judgeval_scorers/api_scorers/instruction_adherence.py +2 -14
  37. judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +171 -59
  38. judgeval/scorers/score.py +64 -47
  39. judgeval/scorers/utils.py +2 -107
  40. judgeval/tracer/__init__.py +1111 -2
  41. judgeval/tracer/constants.py +1 -0
  42. judgeval/tracer/exporters/__init__.py +40 -0
  43. judgeval/tracer/exporters/s3.py +119 -0
  44. judgeval/tracer/exporters/store.py +59 -0
  45. judgeval/tracer/exporters/utils.py +32 -0
  46. judgeval/tracer/keys.py +63 -0
  47. judgeval/tracer/llm/__init__.py +7 -0
  48. judgeval/tracer/llm/config.py +78 -0
  49. judgeval/tracer/llm/constants.py +9 -0
  50. judgeval/tracer/llm/llm_anthropic/__init__.py +3 -0
  51. judgeval/tracer/llm/llm_anthropic/config.py +6 -0
  52. judgeval/tracer/llm/llm_anthropic/messages.py +452 -0
  53. judgeval/tracer/llm/llm_anthropic/messages_stream.py +322 -0
  54. judgeval/tracer/llm/llm_anthropic/wrapper.py +59 -0
  55. judgeval/tracer/llm/llm_google/__init__.py +3 -0
  56. judgeval/tracer/llm/llm_google/config.py +6 -0
  57. judgeval/tracer/llm/llm_google/generate_content.py +127 -0
  58. judgeval/tracer/llm/llm_google/wrapper.py +30 -0
  59. judgeval/tracer/llm/llm_openai/__init__.py +3 -0
  60. judgeval/tracer/llm/llm_openai/beta_chat_completions.py +216 -0
  61. judgeval/tracer/llm/llm_openai/chat_completions.py +501 -0
  62. judgeval/tracer/llm/llm_openai/config.py +6 -0
  63. judgeval/tracer/llm/llm_openai/responses.py +506 -0
  64. judgeval/tracer/llm/llm_openai/utils.py +42 -0
  65. judgeval/tracer/llm/llm_openai/wrapper.py +63 -0
  66. judgeval/tracer/llm/llm_together/__init__.py +3 -0
  67. judgeval/tracer/llm/llm_together/chat_completions.py +406 -0
  68. judgeval/tracer/llm/llm_together/config.py +6 -0
  69. judgeval/tracer/llm/llm_together/wrapper.py +52 -0
  70. judgeval/tracer/llm/providers.py +19 -0
  71. judgeval/tracer/managers.py +167 -0
  72. judgeval/tracer/processors/__init__.py +220 -0
  73. judgeval/tracer/utils.py +19 -0
  74. judgeval/trainer/__init__.py +14 -0
  75. judgeval/trainer/base_trainer.py +122 -0
  76. judgeval/trainer/config.py +123 -0
  77. judgeval/trainer/console.py +144 -0
  78. judgeval/trainer/fireworks_trainer.py +392 -0
  79. judgeval/trainer/trainable_model.py +252 -0
  80. judgeval/trainer/trainer.py +70 -0
  81. judgeval/utils/async_utils.py +39 -0
  82. judgeval/utils/decorators/__init__.py +0 -0
  83. judgeval/utils/decorators/dont_throw.py +37 -0
  84. judgeval/utils/decorators/use_once.py +13 -0
  85. judgeval/utils/file_utils.py +74 -28
  86. judgeval/utils/guards.py +36 -0
  87. judgeval/utils/meta.py +27 -0
  88. judgeval/utils/project.py +15 -0
  89. judgeval/utils/serialize.py +253 -0
  90. judgeval/utils/testing.py +70 -0
  91. judgeval/utils/url.py +10 -0
  92. judgeval/{version_check.py → utils/version_check.py} +5 -3
  93. judgeval/utils/wrappers/README.md +3 -0
  94. judgeval/utils/wrappers/__init__.py +15 -0
  95. judgeval/utils/wrappers/immutable_wrap_async.py +74 -0
  96. judgeval/utils/wrappers/immutable_wrap_async_iterator.py +84 -0
  97. judgeval/utils/wrappers/immutable_wrap_sync.py +66 -0
  98. judgeval/utils/wrappers/immutable_wrap_sync_iterator.py +84 -0
  99. judgeval/utils/wrappers/mutable_wrap_async.py +67 -0
  100. judgeval/utils/wrappers/mutable_wrap_sync.py +67 -0
  101. judgeval/utils/wrappers/py.typed +0 -0
  102. judgeval/utils/wrappers/utils.py +35 -0
  103. judgeval/v1/__init__.py +88 -0
  104. judgeval/v1/data/__init__.py +7 -0
  105. judgeval/v1/data/example.py +44 -0
  106. judgeval/v1/data/scorer_data.py +42 -0
  107. judgeval/v1/data/scoring_result.py +44 -0
  108. judgeval/v1/datasets/__init__.py +6 -0
  109. judgeval/v1/datasets/dataset.py +214 -0
  110. judgeval/v1/datasets/dataset_factory.py +94 -0
  111. judgeval/v1/evaluation/__init__.py +6 -0
  112. judgeval/v1/evaluation/evaluation.py +182 -0
  113. judgeval/v1/evaluation/evaluation_factory.py +17 -0
  114. judgeval/v1/instrumentation/__init__.py +6 -0
  115. judgeval/v1/instrumentation/llm/__init__.py +7 -0
  116. judgeval/v1/instrumentation/llm/config.py +78 -0
  117. judgeval/v1/instrumentation/llm/constants.py +11 -0
  118. judgeval/v1/instrumentation/llm/llm_anthropic/__init__.py +5 -0
  119. judgeval/v1/instrumentation/llm/llm_anthropic/config.py +6 -0
  120. judgeval/v1/instrumentation/llm/llm_anthropic/messages.py +414 -0
  121. judgeval/v1/instrumentation/llm/llm_anthropic/messages_stream.py +307 -0
  122. judgeval/v1/instrumentation/llm/llm_anthropic/wrapper.py +61 -0
  123. judgeval/v1/instrumentation/llm/llm_google/__init__.py +5 -0
  124. judgeval/v1/instrumentation/llm/llm_google/config.py +6 -0
  125. judgeval/v1/instrumentation/llm/llm_google/generate_content.py +121 -0
  126. judgeval/v1/instrumentation/llm/llm_google/wrapper.py +30 -0
  127. judgeval/v1/instrumentation/llm/llm_openai/__init__.py +5 -0
  128. judgeval/v1/instrumentation/llm/llm_openai/beta_chat_completions.py +212 -0
  129. judgeval/v1/instrumentation/llm/llm_openai/chat_completions.py +477 -0
  130. judgeval/v1/instrumentation/llm/llm_openai/config.py +6 -0
  131. judgeval/v1/instrumentation/llm/llm_openai/responses.py +472 -0
  132. judgeval/v1/instrumentation/llm/llm_openai/utils.py +41 -0
  133. judgeval/v1/instrumentation/llm/llm_openai/wrapper.py +63 -0
  134. judgeval/v1/instrumentation/llm/llm_together/__init__.py +5 -0
  135. judgeval/v1/instrumentation/llm/llm_together/chat_completions.py +382 -0
  136. judgeval/v1/instrumentation/llm/llm_together/config.py +6 -0
  137. judgeval/v1/instrumentation/llm/llm_together/wrapper.py +57 -0
  138. judgeval/v1/instrumentation/llm/providers.py +19 -0
  139. judgeval/v1/integrations/claude_agent_sdk/__init__.py +119 -0
  140. judgeval/v1/integrations/claude_agent_sdk/wrapper.py +564 -0
  141. judgeval/v1/integrations/langgraph/__init__.py +13 -0
  142. judgeval/v1/integrations/openlit/__init__.py +47 -0
  143. judgeval/v1/internal/api/__init__.py +525 -0
  144. judgeval/v1/internal/api/api_types.py +413 -0
  145. judgeval/v1/prompts/__init__.py +6 -0
  146. judgeval/v1/prompts/prompt.py +29 -0
  147. judgeval/v1/prompts/prompt_factory.py +189 -0
  148. judgeval/v1/py.typed +0 -0
  149. judgeval/v1/scorers/__init__.py +6 -0
  150. judgeval/v1/scorers/api_scorer.py +82 -0
  151. judgeval/v1/scorers/base_scorer.py +17 -0
  152. judgeval/v1/scorers/built_in/__init__.py +17 -0
  153. judgeval/v1/scorers/built_in/answer_correctness.py +28 -0
  154. judgeval/v1/scorers/built_in/answer_relevancy.py +28 -0
  155. judgeval/v1/scorers/built_in/built_in_factory.py +26 -0
  156. judgeval/v1/scorers/built_in/faithfulness.py +28 -0
  157. judgeval/v1/scorers/built_in/instruction_adherence.py +28 -0
  158. judgeval/v1/scorers/custom_scorer/__init__.py +6 -0
  159. judgeval/v1/scorers/custom_scorer/custom_scorer.py +50 -0
  160. judgeval/v1/scorers/custom_scorer/custom_scorer_factory.py +16 -0
  161. judgeval/v1/scorers/prompt_scorer/__init__.py +6 -0
  162. judgeval/v1/scorers/prompt_scorer/prompt_scorer.py +86 -0
  163. judgeval/v1/scorers/prompt_scorer/prompt_scorer_factory.py +85 -0
  164. judgeval/v1/scorers/scorers_factory.py +49 -0
  165. judgeval/v1/tracer/__init__.py +7 -0
  166. judgeval/v1/tracer/base_tracer.py +520 -0
  167. judgeval/v1/tracer/exporters/__init__.py +14 -0
  168. judgeval/v1/tracer/exporters/in_memory_span_exporter.py +25 -0
  169. judgeval/v1/tracer/exporters/judgment_span_exporter.py +42 -0
  170. judgeval/v1/tracer/exporters/noop_span_exporter.py +19 -0
  171. judgeval/v1/tracer/exporters/span_store.py +50 -0
  172. judgeval/v1/tracer/judgment_tracer_provider.py +70 -0
  173. judgeval/v1/tracer/processors/__init__.py +6 -0
  174. judgeval/v1/tracer/processors/_lifecycles/__init__.py +28 -0
  175. judgeval/v1/tracer/processors/_lifecycles/agent_id_processor.py +53 -0
  176. judgeval/v1/tracer/processors/_lifecycles/context_keys.py +11 -0
  177. judgeval/v1/tracer/processors/_lifecycles/customer_id_processor.py +29 -0
  178. judgeval/v1/tracer/processors/_lifecycles/registry.py +18 -0
  179. judgeval/v1/tracer/processors/judgment_span_processor.py +165 -0
  180. judgeval/v1/tracer/processors/noop_span_processor.py +42 -0
  181. judgeval/v1/tracer/tracer.py +67 -0
  182. judgeval/v1/tracer/tracer_factory.py +38 -0
  183. judgeval/v1/trainers/__init__.py +5 -0
  184. judgeval/v1/trainers/base_trainer.py +62 -0
  185. judgeval/v1/trainers/config.py +123 -0
  186. judgeval/v1/trainers/console.py +144 -0
  187. judgeval/v1/trainers/fireworks_trainer.py +392 -0
  188. judgeval/v1/trainers/trainable_model.py +252 -0
  189. judgeval/v1/trainers/trainers_factory.py +37 -0
  190. judgeval/v1/utils.py +18 -0
  191. judgeval/version.py +5 -0
  192. judgeval/warnings.py +4 -0
  193. judgeval-0.23.0.dist-info/METADATA +266 -0
  194. judgeval-0.23.0.dist-info/RECORD +201 -0
  195. judgeval-0.23.0.dist-info/entry_points.txt +2 -0
  196. judgeval/clients.py +0 -34
  197. judgeval/common/__init__.py +0 -13
  198. judgeval/common/api/__init__.py +0 -3
  199. judgeval/common/api/api.py +0 -352
  200. judgeval/common/api/constants.py +0 -165
  201. judgeval/common/exceptions.py +0 -27
  202. judgeval/common/storage/__init__.py +0 -6
  203. judgeval/common/storage/s3_storage.py +0 -98
  204. judgeval/common/tracer/__init__.py +0 -31
  205. judgeval/common/tracer/constants.py +0 -22
  206. judgeval/common/tracer/core.py +0 -1916
  207. judgeval/common/tracer/otel_exporter.py +0 -108
  208. judgeval/common/tracer/otel_span_processor.py +0 -234
  209. judgeval/common/tracer/span_processor.py +0 -37
  210. judgeval/common/tracer/span_transformer.py +0 -211
  211. judgeval/common/tracer/trace_manager.py +0 -92
  212. judgeval/common/utils.py +0 -940
  213. judgeval/data/datasets/__init__.py +0 -4
  214. judgeval/data/datasets/dataset.py +0 -341
  215. judgeval/data/datasets/eval_dataset_client.py +0 -214
  216. judgeval/data/tool.py +0 -5
  217. judgeval/data/trace_run.py +0 -37
  218. judgeval/evaluation_run.py +0 -75
  219. judgeval/integrations/langgraph.py +0 -843
  220. judgeval/judges/mixture_of_judges.py +0 -286
  221. judgeval/judgment_client.py +0 -369
  222. judgeval/rules.py +0 -521
  223. judgeval/run_evaluation.py +0 -684
  224. judgeval/scorers/judgeval_scorers/api_scorers/derailment_scorer.py +0 -14
  225. judgeval/scorers/judgeval_scorers/api_scorers/execution_order.py +0 -52
  226. judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +0 -28
  227. judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +0 -20
  228. judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +0 -27
  229. judgeval/utils/alerts.py +0 -93
  230. judgeval/utils/requests.py +0 -50
  231. judgeval-0.1.0.dist-info/METADATA +0 -202
  232. judgeval-0.1.0.dist-info/RECORD +0 -73
  233. {judgeval-0.1.0.dist-info → judgeval-0.23.0.dist-info}/WHEEL +0 -0
  234. {judgeval-0.1.0.dist-info → judgeval-0.23.0.dist-info}/licenses/LICENSE.md +0 -0
@@ -0,0 +1,123 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Optional, Dict, Any
+ import json
+
+
+ @dataclass
+ class TrainerConfig:
+     """Configuration class for JudgmentTrainer parameters."""
+
+     deployment_id: str
+     user_id: str
+     model_id: str
+     base_model_name: str = "qwen2p5-7b-instruct"
+     rft_provider: str = "fireworks"  # Supported: "fireworks", "verifiers" (future)
+     num_steps: int = 5
+     num_generations_per_prompt: int = 4
+     num_prompts_per_step: int = 4
+     concurrency: int = 100
+     epochs: int = 1
+     learning_rate: float = 1e-5
+     temperature: float = 1.5
+     max_tokens: int = 50
+     enable_addons: bool = True
+
+
+ @dataclass
+ class ModelConfig:
+     """
+     Configuration class for storing and loading trained model state.
+
+     This class enables persistence of trained models so they can be loaded
+     and used later without retraining.
+
+     Example usage:
+         trainer = JudgmentTrainer(config)
+         model_config = trainer.train(agent_function, scorers, prompts)
+
+         # Save the trained model configuration
+         model_config.save_to_file("my_trained_model.json")
+
+         # Later, load and use the trained model
+         loaded_config = ModelConfig.load_from_file("my_trained_model.json")
+         trained_model = TrainableModel.from_model_config(loaded_config)
+
+         # Use the trained model for inference
+         response = trained_model.chat.completions.create(
+             model="current",  # Uses the loaded trained model
+             messages=[{"role": "user", "content": "Hello!"}]
+         )
+     """
+
+     # Base model configuration
+     base_model_name: str
+     deployment_id: str
+     user_id: str
+     model_id: str
+     enable_addons: bool
+
+     # Training state
+     current_step: int
+     total_steps: int
+
+     # Current model information
+     current_model_name: Optional[str] = None
+     is_trained: bool = False
+
+     # Training parameters used (for reference)
+     training_params: Optional[Dict[str, Any]] = None
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Convert ModelConfig to dictionary for serialization."""
+         return {
+             "base_model_name": self.base_model_name,
+             "deployment_id": self.deployment_id,
+             "user_id": self.user_id,
+             "model_id": self.model_id,
+             "enable_addons": self.enable_addons,
+             "current_step": self.current_step,
+             "total_steps": self.total_steps,
+             "current_model_name": self.current_model_name,
+             "is_trained": self.is_trained,
+             "training_params": self.training_params,
+         }
+
+     @classmethod
+     def from_dict(cls, data: Dict[str, Any]) -> ModelConfig:
+         """Create ModelConfig from dictionary."""
+         return cls(
+             base_model_name=data.get("base_model_name", "qwen2p5-7b-instruct"),
+             deployment_id=data.get("deployment_id", "my-base-deployment"),
+             user_id=data.get("user_id", ""),
+             model_id=data.get("model_id", ""),
+             enable_addons=data.get("enable_addons", True),
+             current_step=data.get("current_step", 0),
+             total_steps=data.get("total_steps", 0),
+             current_model_name=data.get("current_model_name"),
+             is_trained=data.get("is_trained", False),
+             training_params=data.get("training_params"),
+         )
+
+     def to_json(self) -> str:
+         """Convert ModelConfig to JSON string."""
+         return json.dumps(self.to_dict(), indent=2)
+
+     @classmethod
+     def from_json(cls, json_str: str) -> ModelConfig:
+         """Create ModelConfig from JSON string."""
+         data = json.loads(json_str)
+         return cls.from_dict(data)
+
+     def save_to_file(self, filepath: str):
+         """Save ModelConfig to a JSON file."""
+         with open(filepath, "w") as f:
+             f.write(self.to_json())
+
+     @classmethod
+     def load_from_file(cls, filepath: str) -> ModelConfig:
+         """Load ModelConfig from a JSON file."""
+         with open(filepath, "r") as f:
+             json_str = f.read()
+         return cls.from_json(json_str)
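The hunk above matches the +123 additions listed for judgeval/trainer/config.py and judgeval/v1/trainers/config.py: two plain dataclasses with a JSON round-trip. A minimal sketch of that round-trip, assuming the v1 import path; the field values and filename are illustrative placeholders, not package defaults:

    from judgeval.v1.trainers.config import ModelConfig  # assumed path

    config = ModelConfig(
        base_model_name="qwen2p5-7b-instruct",
        deployment_id="example-deployment",  # placeholder
        user_id="example-user",              # placeholder
        model_id="example-model",            # placeholder
        enable_addons=True,
        current_step=3,
        total_steps=5,
        current_model_name="trained-model-step-3",  # placeholder
        is_trained=True,
    )

    config.save_to_file("my_model.json")                   # serializes via to_json()
    restored = ModelConfig.load_from_file("my_model.json")
    assert restored.to_dict() == config.to_dict()          # round-trip is lossless

Because from_dict falls back to defaults for missing keys, older JSON files with fewer fields still load without raising.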
@@ -0,0 +1,144 @@
+ from contextlib import contextmanager
+ from typing import Optional
+ import sys
+ import os
+ from judgeval.utils.decorators.use_once import use_once
+
+
+ @use_once
+ def _is_jupyter_environment():
+     """Check if we're running in a Jupyter notebook or similar environment."""
+     try:
+         # Check for IPython kernel
+         if "ipykernel" in sys.modules or "IPython" in sys.modules:
+             return True
+         # Check for Jupyter environment variables
+         if "JPY_PARENT_PID" in os.environ:
+             return True
+         # Check if we're in Google Colab
+         if "google.colab" in sys.modules:
+             return True
+         return False
+     except Exception:
+         return False
+
+
+ IS_JUPYTER = _is_jupyter_environment()
+
+ if not IS_JUPYTER:
+     try:
+         from rich.console import Console
+         from rich.spinner import Spinner
+         from rich.live import Live
+         from rich.text import Text
+
+         shared_console = Console()
+         RICH_AVAILABLE = True
+     except ImportError:
+         RICH_AVAILABLE = False
+ else:
+     RICH_AVAILABLE = False
+
+
+ class SimpleSpinner:
+     def __init__(self, name, text):
+         self.text = text
+
+
+ class SimpleLive:
+     def __init__(self, spinner, console=None, refresh_per_second=None):
+         self.spinner = spinner
+
+     def __enter__(self):
+         print(f"🔄 {self.spinner.text}")
+         return self
+
+     def __exit__(self, *args):
+         pass
+
+     def update(self, spinner):
+         print(f"🔄 {spinner.text}")
+
+
+ def safe_print(message, style=None):
+     """Safe print function that works in all environments."""
+     if RICH_AVAILABLE and not IS_JUPYTER:
+         shared_console.print(message, style=style)
+     else:
+         if style == "green":
+             print(f"✅ {message}")
+         elif style == "yellow":
+             print(f"⚠️ {message}")
+         elif style == "blue":
+             print(f"🔵 {message}")
+         elif style == "cyan":
+             print(f"🔷 {message}")
+         else:
+             print(message)
+
+
+ @contextmanager
+ def _spinner_progress(
+     message: str, step: Optional[int] = None, total_steps: Optional[int] = None
+ ):
+     """Context manager for spinner-based progress display."""
+     if step is not None and total_steps is not None:
+         full_message = f"[Step {step}/{total_steps}] {message}"
+     else:
+         full_message = f"[Training] {message}"
+
+     if RICH_AVAILABLE and not IS_JUPYTER:
+         spinner = Spinner("dots", text=Text(full_message, style="cyan"))
+         with Live(spinner, console=shared_console, refresh_per_second=10):
+             yield
+     else:
+         print(f"🔄 {full_message}")
+         try:
+             yield
+         finally:
+             print(f"✅ {full_message} - Complete")
+
+
+ @contextmanager
+ def _model_spinner_progress(message: str):
+     """Context manager for model operation spinner-based progress display."""
+     if RICH_AVAILABLE and not IS_JUPYTER:
+         spinner = Spinner("dots", text=Text(f"[Model] {message}", style="blue"))
+         with Live(spinner, console=shared_console, refresh_per_second=10) as live:
+
+             def update_progress(progress_message: str):
+                 """Update the spinner with a new progress message."""
+                 new_text = f"[Model] {message}\n └─ {progress_message}"
+                 spinner.text = Text(new_text, style="blue")
+                 live.update(spinner)
+
+             yield update_progress
+     else:
+         print(f"🔵 [Model] {message}")
+
+         def update_progress(progress_message: str):
+             print(f" └─ {progress_message}")
+
+         yield update_progress
+
+
+ def _print_progress(
+     message: str, step: Optional[int] = None, total_steps: Optional[int] = None
+ ):
+     """Print progress message with consistent formatting."""
+     if step is not None and total_steps is not None:
+         safe_print(f"[Step {step}/{total_steps}] {message}", style="green")
+     else:
+         safe_print(f"[Training] {message}", style="green")
+
+
+ def _print_progress_update(
+     message: str, step: Optional[int] = None, total_steps: Optional[int] = None
+ ):
+     """Print progress update message (for status changes during long operations)."""
+     safe_print(f" └─ {message}", style="yellow")
+
+
+ def _print_model_progress(message: str):
+     """Print model progress message with consistent formatting."""
+     safe_print(f"[Model] {message}", style="blue")
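These console helpers (the +144 additions listed for judgeval/trainer/console.py and judgeval/v1/trainers/console.py) fall back from rich spinners to emoji-prefixed prints whenever rich is unavailable or a Jupyter/Colab kernel is detected, so training output works in any environment. A short usage sketch mirroring how fireworks_trainer.py calls them; the v1 import path is assumed and the messages are illustrative:

    from judgeval.v1.trainers.console import (  # assumed path
        _spinner_progress,
        _print_progress,
        _print_progress_update,
    )

    # One-line status with consistent [Step i/n] prefixing.
    _print_progress("Starting training step 1", step=1, total_steps=5)

    # Spinner while a long operation runs; nested status lines render beneath it.
    with _spinner_progress("Training job in progress", step=1, total_steps=5):
        _print_progress_update("Training job: uploading data")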
@@ -0,0 +1,392 @@
+ from __future__ import annotations
+
+ import asyncio
+ import json
+ from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
+
+ from fireworks import Dataset  # type: ignore[import-not-found,import-untyped]
+
+ if TYPE_CHECKING:
+     from judgeval.v1.trainers.config import TrainerConfig, ModelConfig
+     from judgeval.v1.trainers.trainable_model import TrainableModel
+     from judgeval.v1.tracer.tracer import Tracer
+     from judgeval.v1.scorers.base_scorer import BaseScorer
+     from judgeval.v1.internal.api import JudgmentSyncClient
+
+ from judgeval.v1.trainers.base_trainer import BaseTrainer
+ from judgeval.v1.tracer.exporters import SpanStore, InMemorySpanExporter
+ from judgeval.judgment_attribute_keys import AttributeKeys
+ from judgeval.v1.data.example import Example
+ from judgeval.v1.data.scoring_result import ScoringResult
+ from judgeval.v1.internal.api.api_types import ExampleEvaluationRun
+ from judgeval.v1.trainers.console import (
+     _spinner_progress,
+     _print_progress,
+     _print_progress_update,
+ )
+ from judgeval.exceptions import JudgmentRuntimeError
+ from opentelemetry import trace
+
+
+ class FireworksTrainer(BaseTrainer):
+     __slots__ = ("_client", "span_store", "span_exporter")
+
+     def __init__(
+         self,
+         config: "TrainerConfig",
+         trainable_model: "TrainableModel",
+         tracer: "Tracer",
+         project_name: Optional[str] = None,
+         client: Optional["JudgmentSyncClient"] = None,
+     ):
+         super().__init__(config, trainable_model, tracer, project_name)
+         if client is None:
+             raise ValueError("client is required")
+         self._client = client
+         self.span_store = SpanStore()
+         self.span_exporter = InMemorySpanExporter(self.span_store)
+
+     def _extract_message_history_from_spans(
+         self, trace_id: str
+     ) -> List[Dict[str, str]]:
+         spans = self.span_store.get_by_trace_id(trace_id)
+         if not spans:
+             return []
+
+         messages = []
+         first_found = False
+
+         for span in sorted(spans, key=lambda s: getattr(s, "start_time", 0)):
+             span_attributes = span.attributes or {}
+             span_type = span_attributes.get(AttributeKeys.JUDGMENT_SPAN_KIND, "span")
+
+             if (
+                 not span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
+                 and span_type != "llm"
+             ):
+                 continue
+
+             if span_type == "llm":
+                 if not first_found and span_attributes.get(
+                     AttributeKeys.JUDGMENT_INPUT
+                 ):
+                     input_data: Any = span_attributes.get(
+                         AttributeKeys.JUDGMENT_INPUT, {}
+                     )
+                     if isinstance(input_data, dict) and "messages" in input_data:
+                         input_messages = input_data["messages"]
+                         if input_messages:
+                             first_found = True
+                             for msg in input_messages:
+                                 if (
+                                     isinstance(msg, dict)
+                                     and "role" in msg
+                                     and "content" in msg
+                                 ):
+                                     messages.append(
+                                         {"role": msg["role"], "content": msg["content"]}
+                                     )
+
+                 output = span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
+                 if output is not None:
+                     content = str(output)
+                     try:
+                         parsed = json.loads(content)
+                         if isinstance(parsed, dict) and "messages" in parsed:
+                             for msg in parsed["messages"]:
+                                 if (
+                                     isinstance(msg, dict)
+                                     and msg.get("role") == "assistant"
+                                 ):
+                                     content = msg.get("content", content)
+                                     break
+                     except (json.JSONDecodeError, KeyError):
+                         pass
+                     messages.append({"role": "assistant", "content": content})
+
+             elif span_type in ("user", "tool"):
+                 output = span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
+                 if output is not None:
+                     content = str(output)
+                     try:
+                         parsed = json.loads(content)
+                         if isinstance(parsed, dict) and "messages" in parsed:
+                             for msg in parsed["messages"]:
+                                 if isinstance(msg, dict) and msg.get("role") == "user":
+                                     content = msg.get("content", content)
+                                     break
+                     except (json.JSONDecodeError, KeyError):
+                         pass
+                     messages.append({"role": "user", "content": content})
+
+         return messages
+
+     async def generate_rollouts_and_rewards(
+         self,
+         agent_function: Callable[..., Any],
+         scorers: List["BaseScorer"],
+         prompts: dict[int, dict[Any, Any]],
+         num_prompts_per_step: Optional[int] = None,
+         num_generations_per_prompt: Optional[int] = None,
+         concurrency: Optional[int] = None,
+     ):
+         num_prompts_per_step = min(
+             num_prompts_per_step or self.config.num_prompts_per_step, len(prompts)
+         )
+         num_generations_per_prompt = (
+             num_generations_per_prompt or self.config.num_generations_per_prompt
+         )
+         concurrency = concurrency or self.config.concurrency
+
+         semaphore = asyncio.Semaphore(concurrency)
+
+         @self.tracer.observe(span_type="function")
+         async def generate_single_response(prompt_id: int, generation_id: int):
+             async with semaphore:
+                 prompt_input = prompts[prompt_id]
+                 response_data = await agent_function(**prompt_input)
+                 messages = response_data.get("messages", [])
+
+                 current_span = trace.get_current_span()
+                 trace_id = None
+                 if current_span and current_span.is_recording():
+                     trace_id = format(current_span.get_span_context().trace_id, "032x")
+
+                 try:
+                     if trace_id is not None:
+                         traced_messages = self._extract_message_history_from_spans(
+                             trace_id
+                         )
+                         if traced_messages:
+                             messages = traced_messages
+                 except Exception as e:
+                     print(f"Warning: Failed to get message history from trace: {e}")
+                 finally:
+                     if trace_id is not None:
+                         self.span_store.clear_trace(trace_id)
+
+                 example = Example()
+                 example.set_property("input", prompt_input)
+                 example.set_property("messages", messages)
+                 example.set_property("actual_output", response_data)
+
+                 evaluation_run: ExampleEvaluationRun = {
+                     "project_name": self.project_name,
+                     "eval_name": f"training_step_{self.trainable_model.current_step}_prompt_{prompt_id}_gen_{generation_id}",
+                     "examples": [example.to_dict()],
+                     "judgment_scorers": [
+                         scorer.get_scorer_config() for scorer in scorers
+                     ],
+                 }
+
+                 response = self._client.add_to_run_eval_queue_examples(evaluation_run)
+                 if not response.get("success", False):
+                     raise JudgmentRuntimeError(
+                         f"Failed to queue evaluation: {response.get('error', 'Unknown error')}"
+                     )
+
+                 results = await self._poll_evaluation_until_complete(
+                     evaluation_run["eval_name"], len(scorers)
+                 )
+
+                 if results and results[0].scorers_data:
+                     scores = [
+                         scorer_data.score
+                         for scorer_data in results[0].scorers_data
+                         if scorer_data.score is not None
+                     ]
+                     reward = sum(scores) / len(scores) if scores else 0.0
+                 else:
+                     reward = 0.0
+
+                 return {
+                     "prompt_id": prompt_id,
+                     "generation_id": generation_id,
+                     "messages": messages,
+                     "evals": {"score": reward},
+                 }
+
+         coros = []
+         for prompt_id in range(num_prompts_per_step):
+             for generation_id in range(num_generations_per_prompt):
+                 coro = generate_single_response(prompt_id, generation_id)
+                 coros.append(coro)
+
+         with _spinner_progress(f"Generating {len(coros)} rollouts..."):
+             num_completed = 0
+             results = []
+
+             for future in asyncio.as_completed(coros):
+                 result = await future
+                 results.append(result)
+                 num_completed += 1
+
+         _print_progress(f"Generated {len(results)} rollouts successfully")
+
+         dataset_rows = []
+         for prompt_id in range(num_prompts_per_step):
+             prompt_generations = [r for r in results if r["prompt_id"] == prompt_id]
+             sample_generations = [
+                 {"messages": gen["messages"], "evals": gen["evals"]}
+                 for gen in prompt_generations
+             ]
+             dataset_rows.append({"samples": sample_generations})
+
+         return dataset_rows
+
+     async def _poll_evaluation_until_complete(
+         self, eval_name: str, expected_scorers: int
+     ) -> List[ScoringResult]:
+         import time
+
+         max_wait_time = 300
+         poll_interval = 2
+         start_time = time.time()
+
+         while time.time() - start_time < max_wait_time:
+             await asyncio.sleep(poll_interval)
+
+             from judgeval.v1.internal.api.api_types import EvalResultsFetch
+
+             fetch_request: EvalResultsFetch = {
+                 "experiment_run_id": eval_name,
+                 "project_name": self.project_name,
+             }
+
+             try:
+                 response = self._client.fetch_experiment_run(fetch_request)
+                 if response and response.get("results"):
+                     results_data = response.get("results")
+                     if results_data is not None and len(results_data) > 0:
+                         scoring_results = []
+                         for result_data in results_data:
+                             from judgeval.v1.data.scorer_data import ScorerData
+
+                             scorers_data = []
+                             for scorer_result in result_data.get("scorers_data", []):
+                                 scorers_data.append(
+                                     ScorerData(
+                                         name=scorer_result.get("name", ""),
+                                         threshold=scorer_result.get("threshold", 0.0),
+                                         success=scorer_result.get("success", False),
+                                         score=scorer_result.get("score"),
+                                         reason=scorer_result.get("reason"),
+                                     )
+                                 )
+
+                             scoring_results.append(
+                                 ScoringResult(
+                                     success=result_data.get("success", False),
+                                     scorers_data=scorers_data,
+                                     name=result_data.get("name"),
+                                     trace_id=result_data.get("trace_id"),
+                                     run_duration=result_data.get("run_duration"),
+                                     evaluation_cost=result_data.get("evaluation_cost"),
+                                 )
+                             )
+                         return scoring_results
+             except Exception:
+                 pass
+
+         raise JudgmentRuntimeError(
+             f"Evaluation {eval_name} did not complete within {max_wait_time} seconds"
+         )
+
+     async def run_reinforcement_learning(
+         self,
+         agent_function: Callable[[Any], Any],
+         scorers: List["BaseScorer"],
+         prompts: dict[int, dict[Any, Any]],
+     ) -> "ModelConfig":
+         _print_progress("Starting reinforcement learning training")
+
+         training_params = {
+             "num_steps": self.config.num_steps,
+             "num_prompts_per_step": self.config.num_prompts_per_step,
+             "num_generations_per_prompt": self.config.num_generations_per_prompt,
+             "epochs": self.config.epochs,
+             "learning_rate": self.config.learning_rate,
+             "temperature": self.config.temperature,
+             "max_tokens": self.config.max_tokens,
+         }
+
+         start_step = self.trainable_model.current_step
+
+         for step in range(start_step, self.config.num_steps):
+             step_num = step + 1
+             _print_progress(
+                 f"Starting training step {step_num}", step_num, self.config.num_steps
+             )
+
+             self.trainable_model.advance_to_next_step(step)
+
+             dataset_rows = await self.generate_rollouts_and_rewards(
+                 agent_function, scorers, prompts
+             )
+
+             with _spinner_progress(
+                 "Preparing training dataset", step_num, self.config.num_steps
+             ):
+                 dataset = Dataset.from_list(dataset_rows)
+                 dataset.sync()
+
+             _print_progress(
+                 "Starting reinforcement training", step_num, self.config.num_steps
+             )
+             job = self.trainable_model.perform_reinforcement_step(dataset, step)
+
+             last_state = None
+             with _spinner_progress(
+                 "Training job in progress", step_num, self.config.num_steps
+             ):
+                 while not job.is_completed:
+                     job.raise_if_bad_state()
+                     current_state = job.state
+
+                     if current_state != last_state:
+                         if current_state in ["uploading", "validating"]:
+                             _print_progress_update(
+                                 f"Training job: {current_state} data"
+                             )
+                         elif current_state == "training":
+                             _print_progress_update(
+                                 "Training job: model training in progress"
+                             )
+                         else:
+                             _print_progress_update(f"Training job: {current_state}")
+                         last_state = current_state
+
+                     await asyncio.sleep(10)
+                     job = job.get()
+                     if job is None:
+                         raise JudgmentRuntimeError(
+                             "Training job was deleted while waiting for completion"
+                         )
+
+             _print_progress(
+                 f"Training completed! New model: {job.output_model}",
+                 step_num,
+                 self.config.num_steps,
+             )
+
+         _print_progress("All training steps completed!")
+
+         with _spinner_progress("Deploying final trained model"):
+             self.trainable_model.advance_to_next_step(self.config.num_steps)
+
+         return self.trainable_model.get_model_config(training_params)
+
+     async def train(
+         self,
+         agent_function: Callable[[Any], Any],
+         scorers: List["BaseScorer"],
+         prompts: dict[int, dict[Any, Any]],
+     ) -> "ModelConfig":
+         try:
+             return await self.run_reinforcement_learning(
+                 agent_function, scorers, prompts
+             )
+         except JudgmentRuntimeError:
+             raise
+         except Exception as e:
+             raise JudgmentRuntimeError(f"Training process failed: {str(e)}") from e
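FireworksTrainer (judgeval/v1/trainers/fireworks_trainer.py, +392 above) runs the RFT loop end to end: each step it fans out agent rollouts under a concurrency semaphore, reconstructs message histories from OpenTelemetry spans, queues every rollout for scoring, averages the scorer scores into one reward per sample (reward = sum(scores) / len(scores)), and hands the reward-annotated rows to a Fireworks Dataset for a reinforcement step. A hedged sketch of the calling convention this implies; the agent, scorer, and constructor arguments below are placeholders inferred from the hunk's imports, not a documented public API:

    import asyncio

    async def my_agent(question: str) -> dict:
        # Placeholder agent. generate_rollouts_and_rewards reads
        # response_data.get("messages", []), so a "messages" key is expected.
        return {"messages": [{"role": "assistant", "content": f"Answer: {question}"}]}

    # Prompt ids must cover range(num_prompts_per_step), per the fan-out loops above.
    prompts = {0: {"question": "What is 2 + 2?"}, 1: {"question": "Name a prime."}}

    # Hypothetical wiring: config, trainable_model, tracer, and client would come
    # from the types named in the hunk's imports (TrainerConfig, TrainableModel,
    # Tracer, JudgmentSyncClient); their construction is omitted here.
    # trainer = FireworksTrainer(config, trainable_model, tracer,
    #                            project_name="demo", client=judgment_client)
    # model_config = asyncio.run(trainer.train(my_agent, [my_scorer], prompts))
    # model_config.save_to_file("trained_model.json")

Note that train() only wraps run_reinforcement_learning with error translation, so any non-Judgment exception surfaces as a JudgmentRuntimeError with the original cause chained.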