judgeval 0.1.0__py3-none-any.whl → 0.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/__init__.py +173 -10
- judgeval/api/__init__.py +523 -0
- judgeval/api/api_types.py +413 -0
- judgeval/cli.py +112 -0
- judgeval/constants.py +7 -30
- judgeval/data/__init__.py +1 -3
- judgeval/data/evaluation_run.py +125 -0
- judgeval/data/example.py +14 -40
- judgeval/data/judgment_types.py +396 -146
- judgeval/data/result.py +11 -18
- judgeval/data/scorer_data.py +3 -26
- judgeval/data/scripts/openapi_transform.py +5 -5
- judgeval/data/trace.py +115 -194
- judgeval/dataset/__init__.py +335 -0
- judgeval/env.py +55 -0
- judgeval/evaluation/__init__.py +346 -0
- judgeval/exceptions.py +28 -0
- judgeval/integrations/langgraph/__init__.py +13 -0
- judgeval/integrations/openlit/__init__.py +51 -0
- judgeval/judges/__init__.py +2 -2
- judgeval/judges/litellm_judge.py +77 -16
- judgeval/judges/together_judge.py +88 -17
- judgeval/judges/utils.py +7 -20
- judgeval/judgment_attribute_keys.py +55 -0
- judgeval/{common/logger.py → logger.py} +24 -8
- judgeval/prompt/__init__.py +330 -0
- judgeval/scorers/__init__.py +11 -11
- judgeval/scorers/agent_scorer.py +15 -19
- judgeval/scorers/api_scorer.py +21 -23
- judgeval/scorers/base_scorer.py +54 -36
- judgeval/scorers/example_scorer.py +1 -3
- judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +2 -24
- judgeval/scorers/judgeval_scorers/api_scorers/answer_correctness.py +2 -10
- judgeval/scorers/judgeval_scorers/api_scorers/answer_relevancy.py +2 -2
- judgeval/scorers/judgeval_scorers/api_scorers/faithfulness.py +2 -10
- judgeval/scorers/judgeval_scorers/api_scorers/instruction_adherence.py +2 -14
- judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +171 -59
- judgeval/scorers/score.py +64 -47
- judgeval/scorers/utils.py +2 -107
- judgeval/tracer/__init__.py +1111 -2
- judgeval/tracer/constants.py +1 -0
- judgeval/tracer/exporters/__init__.py +40 -0
- judgeval/tracer/exporters/s3.py +119 -0
- judgeval/tracer/exporters/store.py +59 -0
- judgeval/tracer/exporters/utils.py +32 -0
- judgeval/tracer/keys.py +63 -0
- judgeval/tracer/llm/__init__.py +7 -0
- judgeval/tracer/llm/config.py +78 -0
- judgeval/tracer/llm/constants.py +9 -0
- judgeval/tracer/llm/llm_anthropic/__init__.py +3 -0
- judgeval/tracer/llm/llm_anthropic/config.py +6 -0
- judgeval/tracer/llm/llm_anthropic/messages.py +452 -0
- judgeval/tracer/llm/llm_anthropic/messages_stream.py +322 -0
- judgeval/tracer/llm/llm_anthropic/wrapper.py +59 -0
- judgeval/tracer/llm/llm_google/__init__.py +3 -0
- judgeval/tracer/llm/llm_google/config.py +6 -0
- judgeval/tracer/llm/llm_google/generate_content.py +127 -0
- judgeval/tracer/llm/llm_google/wrapper.py +30 -0
- judgeval/tracer/llm/llm_openai/__init__.py +3 -0
- judgeval/tracer/llm/llm_openai/beta_chat_completions.py +216 -0
- judgeval/tracer/llm/llm_openai/chat_completions.py +501 -0
- judgeval/tracer/llm/llm_openai/config.py +6 -0
- judgeval/tracer/llm/llm_openai/responses.py +506 -0
- judgeval/tracer/llm/llm_openai/utils.py +42 -0
- judgeval/tracer/llm/llm_openai/wrapper.py +63 -0
- judgeval/tracer/llm/llm_together/__init__.py +3 -0
- judgeval/tracer/llm/llm_together/chat_completions.py +406 -0
- judgeval/tracer/llm/llm_together/config.py +6 -0
- judgeval/tracer/llm/llm_together/wrapper.py +52 -0
- judgeval/tracer/llm/providers.py +19 -0
- judgeval/tracer/managers.py +167 -0
- judgeval/tracer/processors/__init__.py +220 -0
- judgeval/tracer/utils.py +19 -0
- judgeval/trainer/__init__.py +14 -0
- judgeval/trainer/base_trainer.py +122 -0
- judgeval/trainer/config.py +123 -0
- judgeval/trainer/console.py +144 -0
- judgeval/trainer/fireworks_trainer.py +392 -0
- judgeval/trainer/trainable_model.py +252 -0
- judgeval/trainer/trainer.py +70 -0
- judgeval/utils/async_utils.py +39 -0
- judgeval/utils/decorators/__init__.py +0 -0
- judgeval/utils/decorators/dont_throw.py +37 -0
- judgeval/utils/decorators/use_once.py +13 -0
- judgeval/utils/file_utils.py +74 -28
- judgeval/utils/guards.py +36 -0
- judgeval/utils/meta.py +27 -0
- judgeval/utils/project.py +15 -0
- judgeval/utils/serialize.py +253 -0
- judgeval/utils/testing.py +70 -0
- judgeval/utils/url.py +10 -0
- judgeval/{version_check.py → utils/version_check.py} +5 -3
- judgeval/utils/wrappers/README.md +3 -0
- judgeval/utils/wrappers/__init__.py +15 -0
- judgeval/utils/wrappers/immutable_wrap_async.py +74 -0
- judgeval/utils/wrappers/immutable_wrap_async_iterator.py +84 -0
- judgeval/utils/wrappers/immutable_wrap_sync.py +66 -0
- judgeval/utils/wrappers/immutable_wrap_sync_iterator.py +84 -0
- judgeval/utils/wrappers/mutable_wrap_async.py +67 -0
- judgeval/utils/wrappers/mutable_wrap_sync.py +67 -0
- judgeval/utils/wrappers/py.typed +0 -0
- judgeval/utils/wrappers/utils.py +35 -0
- judgeval/v1/__init__.py +88 -0
- judgeval/v1/data/__init__.py +7 -0
- judgeval/v1/data/example.py +44 -0
- judgeval/v1/data/scorer_data.py +42 -0
- judgeval/v1/data/scoring_result.py +44 -0
- judgeval/v1/datasets/__init__.py +6 -0
- judgeval/v1/datasets/dataset.py +214 -0
- judgeval/v1/datasets/dataset_factory.py +94 -0
- judgeval/v1/evaluation/__init__.py +6 -0
- judgeval/v1/evaluation/evaluation.py +182 -0
- judgeval/v1/evaluation/evaluation_factory.py +17 -0
- judgeval/v1/instrumentation/__init__.py +6 -0
- judgeval/v1/instrumentation/llm/__init__.py +7 -0
- judgeval/v1/instrumentation/llm/config.py +78 -0
- judgeval/v1/instrumentation/llm/constants.py +11 -0
- judgeval/v1/instrumentation/llm/llm_anthropic/__init__.py +5 -0
- judgeval/v1/instrumentation/llm/llm_anthropic/config.py +6 -0
- judgeval/v1/instrumentation/llm/llm_anthropic/messages.py +414 -0
- judgeval/v1/instrumentation/llm/llm_anthropic/messages_stream.py +307 -0
- judgeval/v1/instrumentation/llm/llm_anthropic/wrapper.py +61 -0
- judgeval/v1/instrumentation/llm/llm_google/__init__.py +5 -0
- judgeval/v1/instrumentation/llm/llm_google/config.py +6 -0
- judgeval/v1/instrumentation/llm/llm_google/generate_content.py +121 -0
- judgeval/v1/instrumentation/llm/llm_google/wrapper.py +30 -0
- judgeval/v1/instrumentation/llm/llm_openai/__init__.py +5 -0
- judgeval/v1/instrumentation/llm/llm_openai/beta_chat_completions.py +212 -0
- judgeval/v1/instrumentation/llm/llm_openai/chat_completions.py +477 -0
- judgeval/v1/instrumentation/llm/llm_openai/config.py +6 -0
- judgeval/v1/instrumentation/llm/llm_openai/responses.py +472 -0
- judgeval/v1/instrumentation/llm/llm_openai/utils.py +41 -0
- judgeval/v1/instrumentation/llm/llm_openai/wrapper.py +63 -0
- judgeval/v1/instrumentation/llm/llm_together/__init__.py +5 -0
- judgeval/v1/instrumentation/llm/llm_together/chat_completions.py +382 -0
- judgeval/v1/instrumentation/llm/llm_together/config.py +6 -0
- judgeval/v1/instrumentation/llm/llm_together/wrapper.py +57 -0
- judgeval/v1/instrumentation/llm/providers.py +19 -0
- judgeval/v1/integrations/claude_agent_sdk/__init__.py +119 -0
- judgeval/v1/integrations/claude_agent_sdk/wrapper.py +564 -0
- judgeval/v1/integrations/langgraph/__init__.py +13 -0
- judgeval/v1/integrations/openlit/__init__.py +47 -0
- judgeval/v1/internal/api/__init__.py +525 -0
- judgeval/v1/internal/api/api_types.py +413 -0
- judgeval/v1/prompts/__init__.py +6 -0
- judgeval/v1/prompts/prompt.py +29 -0
- judgeval/v1/prompts/prompt_factory.py +189 -0
- judgeval/v1/py.typed +0 -0
- judgeval/v1/scorers/__init__.py +6 -0
- judgeval/v1/scorers/api_scorer.py +82 -0
- judgeval/v1/scorers/base_scorer.py +17 -0
- judgeval/v1/scorers/built_in/__init__.py +17 -0
- judgeval/v1/scorers/built_in/answer_correctness.py +28 -0
- judgeval/v1/scorers/built_in/answer_relevancy.py +28 -0
- judgeval/v1/scorers/built_in/built_in_factory.py +26 -0
- judgeval/v1/scorers/built_in/faithfulness.py +28 -0
- judgeval/v1/scorers/built_in/instruction_adherence.py +28 -0
- judgeval/v1/scorers/custom_scorer/__init__.py +6 -0
- judgeval/v1/scorers/custom_scorer/custom_scorer.py +50 -0
- judgeval/v1/scorers/custom_scorer/custom_scorer_factory.py +16 -0
- judgeval/v1/scorers/prompt_scorer/__init__.py +6 -0
- judgeval/v1/scorers/prompt_scorer/prompt_scorer.py +86 -0
- judgeval/v1/scorers/prompt_scorer/prompt_scorer_factory.py +85 -0
- judgeval/v1/scorers/scorers_factory.py +49 -0
- judgeval/v1/tracer/__init__.py +7 -0
- judgeval/v1/tracer/base_tracer.py +520 -0
- judgeval/v1/tracer/exporters/__init__.py +14 -0
- judgeval/v1/tracer/exporters/in_memory_span_exporter.py +25 -0
- judgeval/v1/tracer/exporters/judgment_span_exporter.py +42 -0
- judgeval/v1/tracer/exporters/noop_span_exporter.py +19 -0
- judgeval/v1/tracer/exporters/span_store.py +50 -0
- judgeval/v1/tracer/judgment_tracer_provider.py +70 -0
- judgeval/v1/tracer/processors/__init__.py +6 -0
- judgeval/v1/tracer/processors/_lifecycles/__init__.py +28 -0
- judgeval/v1/tracer/processors/_lifecycles/agent_id_processor.py +53 -0
- judgeval/v1/tracer/processors/_lifecycles/context_keys.py +11 -0
- judgeval/v1/tracer/processors/_lifecycles/customer_id_processor.py +29 -0
- judgeval/v1/tracer/processors/_lifecycles/registry.py +18 -0
- judgeval/v1/tracer/processors/judgment_span_processor.py +165 -0
- judgeval/v1/tracer/processors/noop_span_processor.py +42 -0
- judgeval/v1/tracer/tracer.py +67 -0
- judgeval/v1/tracer/tracer_factory.py +38 -0
- judgeval/v1/trainers/__init__.py +5 -0
- judgeval/v1/trainers/base_trainer.py +62 -0
- judgeval/v1/trainers/config.py +123 -0
- judgeval/v1/trainers/console.py +144 -0
- judgeval/v1/trainers/fireworks_trainer.py +392 -0
- judgeval/v1/trainers/trainable_model.py +252 -0
- judgeval/v1/trainers/trainers_factory.py +37 -0
- judgeval/v1/utils.py +18 -0
- judgeval/version.py +5 -0
- judgeval/warnings.py +4 -0
- judgeval-0.23.0.dist-info/METADATA +266 -0
- judgeval-0.23.0.dist-info/RECORD +201 -0
- judgeval-0.23.0.dist-info/entry_points.txt +2 -0
- judgeval/clients.py +0 -34
- judgeval/common/__init__.py +0 -13
- judgeval/common/api/__init__.py +0 -3
- judgeval/common/api/api.py +0 -352
- judgeval/common/api/constants.py +0 -165
- judgeval/common/exceptions.py +0 -27
- judgeval/common/storage/__init__.py +0 -6
- judgeval/common/storage/s3_storage.py +0 -98
- judgeval/common/tracer/__init__.py +0 -31
- judgeval/common/tracer/constants.py +0 -22
- judgeval/common/tracer/core.py +0 -1916
- judgeval/common/tracer/otel_exporter.py +0 -108
- judgeval/common/tracer/otel_span_processor.py +0 -234
- judgeval/common/tracer/span_processor.py +0 -37
- judgeval/common/tracer/span_transformer.py +0 -211
- judgeval/common/tracer/trace_manager.py +0 -92
- judgeval/common/utils.py +0 -940
- judgeval/data/datasets/__init__.py +0 -4
- judgeval/data/datasets/dataset.py +0 -341
- judgeval/data/datasets/eval_dataset_client.py +0 -214
- judgeval/data/tool.py +0 -5
- judgeval/data/trace_run.py +0 -37
- judgeval/evaluation_run.py +0 -75
- judgeval/integrations/langgraph.py +0 -843
- judgeval/judges/mixture_of_judges.py +0 -286
- judgeval/judgment_client.py +0 -369
- judgeval/rules.py +0 -521
- judgeval/run_evaluation.py +0 -684
- judgeval/scorers/judgeval_scorers/api_scorers/derailment_scorer.py +0 -14
- judgeval/scorers/judgeval_scorers/api_scorers/execution_order.py +0 -52
- judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +0 -28
- judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +0 -20
- judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +0 -27
- judgeval/utils/alerts.py +0 -93
- judgeval/utils/requests.py +0 -50
- judgeval-0.1.0.dist-info/METADATA +0 -202
- judgeval-0.1.0.dist-info/RECORD +0 -73
- {judgeval-0.1.0.dist-info → judgeval-0.23.0.dist-info}/WHEEL +0 -0
- {judgeval-0.1.0.dist-info → judgeval-0.23.0.dist-info}/licenses/LICENSE.md +0 -0
@@ -0,0 +1,123 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Optional, Dict, Any
+import json
+
+
+@dataclass
+class TrainerConfig:
+    """Configuration class for JudgmentTrainer parameters."""
+
+    deployment_id: str
+    user_id: str
+    model_id: str
+    base_model_name: str = "qwen2p5-7b-instruct"
+    rft_provider: str = "fireworks"  # Supported: "fireworks", "verifiers" (future)
+    num_steps: int = 5
+    num_generations_per_prompt: int = 4
+    num_prompts_per_step: int = 4
+    concurrency: int = 100
+    epochs: int = 1
+    learning_rate: float = 1e-5
+    temperature: float = 1.5
+    max_tokens: int = 50
+    enable_addons: bool = True
+
+
+@dataclass
+class ModelConfig:
+    """
+    Configuration class for storing and loading trained model state.
+
+    This class enables persistence of trained models so they can be loaded
+    and used later without retraining.
+
+    Example usage:
+        trainer = JudgmentTrainer(config)
+        model_config = trainer.train(agent_function, scorers, prompts)
+
+        # Save the trained model configuration
+        model_config.save_to_file("my_trained_model.json")
+
+        # Later, load and use the trained model
+        loaded_config = ModelConfig.load_from_file("my_trained_model.json")
+        trained_model = TrainableModel.from_model_config(loaded_config)
+
+        # Use the trained model for inference
+        response = trained_model.chat.completions.create(
+            model="current",  # Uses the loaded trained model
+            messages=[{"role": "user", "content": "Hello!"}]
+        )
+    """
+
+    # Base model configuration
+    base_model_name: str
+    deployment_id: str
+    user_id: str
+    model_id: str
+    enable_addons: bool
+
+    # Training state
+    current_step: int
+    total_steps: int
+
+    # Current model information
+    current_model_name: Optional[str] = None
+    is_trained: bool = False
+
+    # Training parameters used (for reference)
+    training_params: Optional[Dict[str, Any]] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert ModelConfig to dictionary for serialization."""
+        return {
+            "base_model_name": self.base_model_name,
+            "deployment_id": self.deployment_id,
+            "user_id": self.user_id,
+            "model_id": self.model_id,
+            "enable_addons": self.enable_addons,
+            "current_step": self.current_step,
+            "total_steps": self.total_steps,
+            "current_model_name": self.current_model_name,
+            "is_trained": self.is_trained,
+            "training_params": self.training_params,
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> ModelConfig:
+        """Create ModelConfig from dictionary."""
+        return cls(
+            base_model_name=data.get("base_model_name", "qwen2p5-7b-instruct"),
+            deployment_id=data.get("deployment_id", "my-base-deployment"),
+            user_id=data.get("user_id", ""),
+            model_id=data.get("model_id", ""),
+            enable_addons=data.get("enable_addons", True),
+            current_step=data.get("current_step", 0),
+            total_steps=data.get("total_steps", 0),
+            current_model_name=data.get("current_model_name"),
+            is_trained=data.get("is_trained", False),
+            training_params=data.get("training_params"),
+        )
+
+    def to_json(self) -> str:
+        """Convert ModelConfig to JSON string."""
+        return json.dumps(self.to_dict(), indent=2)
+
+    @classmethod
+    def from_json(cls, json_str: str) -> ModelConfig:
+        """Create ModelConfig from JSON string."""
+        data = json.loads(json_str)
+        return cls.from_dict(data)
+
+    def save_to_file(self, filepath: str):
+        """Save ModelConfig to a JSON file."""
+        with open(filepath, "w") as f:
+            f.write(self.to_json())
+
+    @classmethod
+    def load_from_file(cls, filepath: str) -> ModelConfig:
+        """Load ModelConfig from a JSON file."""
+        with open(filepath, "r") as f:
+            json_str = f.read()
+        return cls.from_json(json_str)
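The trainer configuration added above is plain dataclasses plus JSON helpers, so a trained-model state can be written to disk and restored later without retraining. The following is a minimal sketch of that round trip, not part of the diff: it assumes the classes are importable from judgeval.v1.trainers.config (judgeval/trainer/config.py in the file list carries the same +123-line content), and every field value shown is a placeholder.

# Hedged usage sketch: TrainerConfig construction and the ModelConfig JSON round trip.
from judgeval.v1.trainers.config import TrainerConfig, ModelConfig  # module path assumed

config = TrainerConfig(
    deployment_id="my-deployment",  # placeholder
    user_id="user-123",             # placeholder
    model_id="model-abc",           # placeholder
    num_steps=3,
)

model_config = ModelConfig(
    base_model_name=config.base_model_name,
    deployment_id=config.deployment_id,
    user_id=config.user_id,
    model_id=config.model_id,
    enable_addons=config.enable_addons,
    current_step=0,
    total_steps=config.num_steps,
)

# Persist, reload, and confirm the serialized state survives the round trip.
model_config.save_to_file("my_trained_model.json")
restored = ModelConfig.load_from_file("my_trained_model.json")
assert restored.to_dict() == model_config.to_dict()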
@@ -0,0 +1,144 @@
+from contextlib import contextmanager
+from typing import Optional
+import sys
+import os
+from judgeval.utils.decorators.use_once import use_once
+
+
+@use_once
+def _is_jupyter_environment():
+    """Check if we're running in a Jupyter notebook or similar environment."""
+    try:
+        # Check for IPython kernel
+        if "ipykernel" in sys.modules or "IPython" in sys.modules:
+            return True
+        # Check for Jupyter environment variables
+        if "JPY_PARENT_PID" in os.environ:
+            return True
+        # Check if we're in Google Colab
+        if "google.colab" in sys.modules:
+            return True
+        return False
+    except Exception:
+        return False
+
+
+IS_JUPYTER = _is_jupyter_environment()
+
+if not IS_JUPYTER:
+    try:
+        from rich.console import Console
+        from rich.spinner import Spinner
+        from rich.live import Live
+        from rich.text import Text
+
+        shared_console = Console()
+        RICH_AVAILABLE = True
+    except ImportError:
+        RICH_AVAILABLE = False
+else:
+    RICH_AVAILABLE = False
+
+
+class SimpleSpinner:
+    def __init__(self, name, text):
+        self.text = text
+
+
+class SimpleLive:
+    def __init__(self, spinner, console=None, refresh_per_second=None):
+        self.spinner = spinner
+
+    def __enter__(self):
+        print(f"🔄 {self.spinner.text}")
+        return self
+
+    def __exit__(self, *args):
+        pass
+
+    def update(self, spinner):
+        print(f"🔄 {spinner.text}")
+
+
+def safe_print(message, style=None):
+    """Safe print function that works in all environments."""
+    if RICH_AVAILABLE and not IS_JUPYTER:
+        shared_console.print(message, style=style)
+    else:
+        if style == "green":
+            print(f"✅ {message}")
+        elif style == "yellow":
+            print(f"⚠️ {message}")
+        elif style == "blue":
+            print(f"🔵 {message}")
+        elif style == "cyan":
+            print(f"🔷 {message}")
+        else:
+            print(message)
+
+
+@contextmanager
+def _spinner_progress(
+    message: str, step: Optional[int] = None, total_steps: Optional[int] = None
+):
+    """Context manager for spinner-based progress display."""
+    if step is not None and total_steps is not None:
+        full_message = f"[Step {step}/{total_steps}] {message}"
+    else:
+        full_message = f"[Training] {message}"
+
+    if RICH_AVAILABLE and not IS_JUPYTER:
+        spinner = Spinner("dots", text=Text(full_message, style="cyan"))
+        with Live(spinner, console=shared_console, refresh_per_second=10):
+            yield
+    else:
+        print(f"🔄 {full_message}")
+        try:
+            yield
+        finally:
+            print(f"✅ {full_message} - Complete")
+
+
+@contextmanager
+def _model_spinner_progress(message: str):
+    """Context manager for model operation spinner-based progress display."""
+    if RICH_AVAILABLE and not IS_JUPYTER:
+        spinner = Spinner("dots", text=Text(f"[Model] {message}", style="blue"))
+        with Live(spinner, console=shared_console, refresh_per_second=10) as live:
+
+            def update_progress(progress_message: str):
+                """Update the spinner with a new progress message."""
+                new_text = f"[Model] {message}\n └─ {progress_message}"
+                spinner.text = Text(new_text, style="blue")
+                live.update(spinner)
+
+            yield update_progress
+    else:
+        print(f"🔵 [Model] {message}")
+
+        def update_progress(progress_message: str):
+            print(f" └─ {progress_message}")
+
+        yield update_progress
+
+
+def _print_progress(
+    message: str, step: Optional[int] = None, total_steps: Optional[int] = None
+):
+    """Print progress message with consistent formatting."""
+    if step is not None and total_steps is not None:
+        safe_print(f"[Step {step}/{total_steps}] {message}", style="green")
+    else:
+        safe_print(f"[Training] {message}", style="green")
+
+
+def _print_progress_update(
+    message: str, step: Optional[int] = None, total_steps: Optional[int] = None
+):
+    """Print progress update message (for status changes during long operations)."""
+    safe_print(f" └─ {message}", style="yellow")
+
+
+def _print_model_progress(message: str):
+    """Print model progress message with consistent formatting."""
+    safe_print(f"[Model] {message}", style="blue")
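The console helpers above degrade gracefully: with rich available and outside Jupyter they render a live spinner, otherwise they fall back to plain prints with emoji prefixes. Below is a small sketch, not part of the diff, of how the trainer-facing helpers are driven; the underscore-prefixed functions are internal to the package, so calling them directly here is purely illustrative, and the import path is an assumption (the +144-line judgeval/trainer/console.py in the file list is the same content).

# Hedged usage sketch: the progress helpers reconstructed above, with or without rich.
import time

from judgeval.v1.trainers.console import (  # module path assumed
    safe_print,
    _spinner_progress,
    _print_progress_update,
)

safe_print("Preparing run", style="blue")

with _spinner_progress("Generating rollouts", step=1, total_steps=3):
    time.sleep(0.1)  # stand-in for real work
    _print_progress_update("queued 4 prompts")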
@@ -0,0 +1,392 @@
+from __future__ import annotations
+
+import asyncio
+import json
+from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
+
+from fireworks import Dataset  # type: ignore[import-not-found,import-untyped]
+
+if TYPE_CHECKING:
+    from judgeval.v1.trainers.config import TrainerConfig, ModelConfig
+    from judgeval.v1.trainers.trainable_model import TrainableModel
+    from judgeval.v1.tracer.tracer import Tracer
+    from judgeval.v1.scorers.base_scorer import BaseScorer
+    from judgeval.v1.internal.api import JudgmentSyncClient
+
+from judgeval.v1.trainers.base_trainer import BaseTrainer
+from judgeval.v1.tracer.exporters import SpanStore, InMemorySpanExporter
+from judgeval.judgment_attribute_keys import AttributeKeys
+from judgeval.v1.data.example import Example
+from judgeval.v1.data.scoring_result import ScoringResult
+from judgeval.v1.internal.api.api_types import ExampleEvaluationRun
+from judgeval.v1.trainers.console import (
+    _spinner_progress,
+    _print_progress,
+    _print_progress_update,
+)
+from judgeval.exceptions import JudgmentRuntimeError
+from opentelemetry import trace
+
+
+class FireworksTrainer(BaseTrainer):
+    __slots__ = ("_client", "span_store", "span_exporter")
+
+    def __init__(
+        self,
+        config: "TrainerConfig",
+        trainable_model: "TrainableModel",
+        tracer: "Tracer",
+        project_name: Optional[str] = None,
+        client: Optional["JudgmentSyncClient"] = None,
+    ):
+        super().__init__(config, trainable_model, tracer, project_name)
+        if client is None:
+            raise ValueError("client is required")
+        self._client = client
+        self.span_store = SpanStore()
+        self.span_exporter = InMemorySpanExporter(self.span_store)
+
+    def _extract_message_history_from_spans(
+        self, trace_id: str
+    ) -> List[Dict[str, str]]:
+        spans = self.span_store.get_by_trace_id(trace_id)
+        if not spans:
+            return []
+
+        messages = []
+        first_found = False
+
+        for span in sorted(spans, key=lambda s: getattr(s, "start_time", 0)):
+            span_attributes = span.attributes or {}
+            span_type = span_attributes.get(AttributeKeys.JUDGMENT_SPAN_KIND, "span")
+
+            if (
+                not span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
+                and span_type != "llm"
+            ):
+                continue
+
+            if span_type == "llm":
+                if not first_found and span_attributes.get(
+                    AttributeKeys.JUDGMENT_INPUT
+                ):
+                    input_data: Any = span_attributes.get(
+                        AttributeKeys.JUDGMENT_INPUT, {}
+                    )
+                    if isinstance(input_data, dict) and "messages" in input_data:
+                        input_messages = input_data["messages"]
+                        if input_messages:
+                            first_found = True
+                            for msg in input_messages:
+                                if (
+                                    isinstance(msg, dict)
+                                    and "role" in msg
+                                    and "content" in msg
+                                ):
+                                    messages.append(
+                                        {"role": msg["role"], "content": msg["content"]}
+                                    )
+
+                output = span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
+                if output is not None:
+                    content = str(output)
+                    try:
+                        parsed = json.loads(content)
+                        if isinstance(parsed, dict) and "messages" in parsed:
+                            for msg in parsed["messages"]:
+                                if (
+                                    isinstance(msg, dict)
+                                    and msg.get("role") == "assistant"
+                                ):
+                                    content = msg.get("content", content)
+                                    break
+                    except (json.JSONDecodeError, KeyError):
+                        pass
+                    messages.append({"role": "assistant", "content": content})
+
+            elif span_type in ("user", "tool"):
+                output = span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
+                if output is not None:
+                    content = str(output)
+                    try:
+                        parsed = json.loads(content)
+                        if isinstance(parsed, dict) and "messages" in parsed:
+                            for msg in parsed["messages"]:
+                                if isinstance(msg, dict) and msg.get("role") == "user":
+                                    content = msg.get("content", content)
+                                    break
+                    except (json.JSONDecodeError, KeyError):
+                        pass
+                    messages.append({"role": "user", "content": content})
+
+        return messages
+
+    async def generate_rollouts_and_rewards(
+        self,
+        agent_function: Callable[..., Any],
+        scorers: List["BaseScorer"],
+        prompts: dict[int, dict[Any, Any]],
+        num_prompts_per_step: Optional[int] = None,
+        num_generations_per_prompt: Optional[int] = None,
+        concurrency: Optional[int] = None,
+    ):
+        num_prompts_per_step = min(
+            num_prompts_per_step or self.config.num_prompts_per_step, len(prompts)
+        )
+        num_generations_per_prompt = (
+            num_generations_per_prompt or self.config.num_generations_per_prompt
+        )
+        concurrency = concurrency or self.config.concurrency
+
+        semaphore = asyncio.Semaphore(concurrency)
+
+        @self.tracer.observe(span_type="function")
+        async def generate_single_response(prompt_id: int, generation_id: int):
+            async with semaphore:
+                prompt_input = prompts[prompt_id]
+                response_data = await agent_function(**prompt_input)
+                messages = response_data.get("messages", [])
+
+                current_span = trace.get_current_span()
+                trace_id = None
+                if current_span and current_span.is_recording():
+                    trace_id = format(current_span.get_span_context().trace_id, "032x")
+
+                try:
+                    if trace_id is not None:
+                        traced_messages = self._extract_message_history_from_spans(
+                            trace_id
+                        )
+                        if traced_messages:
+                            messages = traced_messages
+                except Exception as e:
+                    print(f"Warning: Failed to get message history from trace: {e}")
+                finally:
+                    if trace_id is not None:
+                        self.span_store.clear_trace(trace_id)
+
+                example = Example()
+                example.set_property("input", prompt_input)
+                example.set_property("messages", messages)
+                example.set_property("actual_output", response_data)
+
+                evaluation_run: ExampleEvaluationRun = {
+                    "project_name": self.project_name,
+                    "eval_name": f"training_step_{self.trainable_model.current_step}_prompt_{prompt_id}_gen_{generation_id}",
+                    "examples": [example.to_dict()],
+                    "judgment_scorers": [
+                        scorer.get_scorer_config() for scorer in scorers
+                    ],
+                }
+
+                response = self._client.add_to_run_eval_queue_examples(evaluation_run)
+                if not response.get("success", False):
+                    raise JudgmentRuntimeError(
+                        f"Failed to queue evaluation: {response.get('error', 'Unknown error')}"
+                    )
+
+                results = await self._poll_evaluation_until_complete(
+                    evaluation_run["eval_name"], len(scorers)
+                )
+
+                if results and results[0].scorers_data:
+                    scores = [
+                        scorer_data.score
+                        for scorer_data in results[0].scorers_data
+                        if scorer_data.score is not None
+                    ]
+                    reward = sum(scores) / len(scores) if scores else 0.0
+                else:
+                    reward = 0.0
+
+                return {
+                    "prompt_id": prompt_id,
+                    "generation_id": generation_id,
+                    "messages": messages,
+                    "evals": {"score": reward},
+                }
+
+        coros = []
+        for prompt_id in range(num_prompts_per_step):
+            for generation_id in range(num_generations_per_prompt):
+                coro = generate_single_response(prompt_id, generation_id)
+                coros.append(coro)
+
+        with _spinner_progress(f"Generating {len(coros)} rollouts..."):
+            num_completed = 0
+            results = []
+
+            for future in asyncio.as_completed(coros):
+                result = await future
+                results.append(result)
+                num_completed += 1
+
+        _print_progress(f"Generated {len(results)} rollouts successfully")
+
+        dataset_rows = []
+        for prompt_id in range(num_prompts_per_step):
+            prompt_generations = [r for r in results if r["prompt_id"] == prompt_id]
+            sample_generations = [
+                {"messages": gen["messages"], "evals": gen["evals"]}
+                for gen in prompt_generations
+            ]
+            dataset_rows.append({"samples": sample_generations})
+
+        return dataset_rows
+
+    async def _poll_evaluation_until_complete(
+        self, eval_name: str, expected_scorers: int
+    ) -> List[ScoringResult]:
+        import time
+
+        max_wait_time = 300
+        poll_interval = 2
+        start_time = time.time()
+
+        while time.time() - start_time < max_wait_time:
+            await asyncio.sleep(poll_interval)
+
+            from judgeval.v1.internal.api.api_types import EvalResultsFetch
+
+            fetch_request: EvalResultsFetch = {
+                "experiment_run_id": eval_name,
+                "project_name": self.project_name,
+            }
+
+            try:
+                response = self._client.fetch_experiment_run(fetch_request)
+                if response and response.get("results"):
+                    results_data = response.get("results")
+                    if results_data is not None and len(results_data) > 0:
+                        scoring_results = []
+                        for result_data in results_data:
+                            from judgeval.v1.data.scorer_data import ScorerData
+
+                            scorers_data = []
+                            for scorer_result in result_data.get("scorers_data", []):
+                                scorers_data.append(
+                                    ScorerData(
+                                        name=scorer_result.get("name", ""),
+                                        threshold=scorer_result.get("threshold", 0.0),
+                                        success=scorer_result.get("success", False),
+                                        score=scorer_result.get("score"),
+                                        reason=scorer_result.get("reason"),
+                                    )
+                                )
+
+                            scoring_results.append(
+                                ScoringResult(
+                                    success=result_data.get("success", False),
+                                    scorers_data=scorers_data,
+                                    name=result_data.get("name"),
+                                    trace_id=result_data.get("trace_id"),
+                                    run_duration=result_data.get("run_duration"),
+                                    evaluation_cost=result_data.get("evaluation_cost"),
+                                )
+                            )
+                        return scoring_results
+            except Exception:
+                pass
+
+        raise JudgmentRuntimeError(
+            f"Evaluation {eval_name} did not complete within {max_wait_time} seconds"
+        )
+
+    async def run_reinforcement_learning(
+        self,
+        agent_function: Callable[[Any], Any],
+        scorers: List["BaseScorer"],
+        prompts: dict[int, dict[Any, Any]],
+    ) -> "ModelConfig":
+        _print_progress("Starting reinforcement learning training")
+
+        training_params = {
+            "num_steps": self.config.num_steps,
+            "num_prompts_per_step": self.config.num_prompts_per_step,
+            "num_generations_per_prompt": self.config.num_generations_per_prompt,
+            "epochs": self.config.epochs,
+            "learning_rate": self.config.learning_rate,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+        }
+
+        start_step = self.trainable_model.current_step
+
+        for step in range(start_step, self.config.num_steps):
+            step_num = step + 1
+            _print_progress(
+                f"Starting training step {step_num}", step_num, self.config.num_steps
+            )
+
+            self.trainable_model.advance_to_next_step(step)
+
+            dataset_rows = await self.generate_rollouts_and_rewards(
+                agent_function, scorers, prompts
+            )
+
+            with _spinner_progress(
+                "Preparing training dataset", step_num, self.config.num_steps
+            ):
+                dataset = Dataset.from_list(dataset_rows)
+                dataset.sync()
+
+            _print_progress(
+                "Starting reinforcement training", step_num, self.config.num_steps
+            )
+            job = self.trainable_model.perform_reinforcement_step(dataset, step)
+
+            last_state = None
+            with _spinner_progress(
+                "Training job in progress", step_num, self.config.num_steps
+            ):
+                while not job.is_completed:
+                    job.raise_if_bad_state()
+                    current_state = job.state
+
+                    if current_state != last_state:
+                        if current_state in ["uploading", "validating"]:
+                            _print_progress_update(
+                                f"Training job: {current_state} data"
+                            )
+                        elif current_state == "training":
+                            _print_progress_update(
+                                "Training job: model training in progress"
+                            )
+                        else:
+                            _print_progress_update(f"Training job: {current_state}")
+                        last_state = current_state
+
+                    await asyncio.sleep(10)
+                    job = job.get()
+                    if job is None:
+                        raise JudgmentRuntimeError(
+                            "Training job was deleted while waiting for completion"
+                        )
+
+            _print_progress(
+                f"Training completed! New model: {job.output_model}",
+                step_num,
+                self.config.num_steps,
+            )
+
+        _print_progress("All training steps completed!")
+
+        with _spinner_progress("Deploying final trained model"):
+            self.trainable_model.advance_to_next_step(self.config.num_steps)
+
+        return self.trainable_model.get_model_config(training_params)
+
+    async def train(
+        self,
+        agent_function: Callable[[Any], Any],
+        scorers: List["BaseScorer"],
+        prompts: dict[int, dict[Any, Any]],
+    ) -> "ModelConfig":
+        try:
+            return await self.run_reinforcement_learning(
+                agent_function, scorers, prompts
+            )
+        except JudgmentRuntimeError:
+            raise
+        except Exception as e:
+            raise JudgmentRuntimeError(f"Training process failed: {str(e)}") from e
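FireworksTrainer.train drives the full RFT loop in the hunk above: for each step it calls the user's async agent function once per prompt/generation pair, rebuilds the message history from the traced spans, queues each rollout for scoring through the Judgment API, averages the scorer scores into a reward, and hands the resulting rows to a Fireworks Dataset for a reinforcement step. A heavily hedged calling-side sketch follows, not part of the diff: the collaborators (TrainableModel, Tracer, JudgmentSyncClient, the scorer list) are taken as parameters because their constructors are not shown in this hunk, the module paths are assumed, and the prompt key and string values are placeholders. The only shapes taken from the code above are that prompts are keyed by integer id, the agent function is async and returns a dict with a "messages" list, and train() returns a ModelConfig.

# Hedged calling-side sketch for the trainer reconstructed above.
from judgeval.v1.trainers.config import TrainerConfig  # module path assumed
from judgeval.v1.trainers.fireworks_trainer import FireworksTrainer  # module path assumed


async def agent_function(question: str):  # "question" is an illustrative prompt key
    # ... call the model being trained here ...
    return {"messages": [{"role": "assistant", "content": "a placeholder answer"}]}


async def run_training(trainable_model, tracer, client, scorers):
    # Collaborators are assumed to be pre-built elsewhere in the package.
    config = TrainerConfig(
        deployment_id="my-deployment",  # placeholder
        user_id="user-123",             # placeholder
        model_id="model-abc",           # placeholder
        num_steps=2,
    )
    trainer = FireworksTrainer(
        config=config,
        trainable_model=trainable_model,
        tracer=tracer,
        project_name="rft-demo",
        client=client,
    )
    # Prompts are keyed by integer id; each value is the kwargs for agent_function.
    prompts = {0: {"question": "What is 2 + 2?"}, 1: {"question": "Name a prime number."}}
    model_config = await trainer.train(agent_function, scorers, prompts)
    model_config.save_to_file("trained_model.json")  # persist via ModelConfig helpers
    return model_config
    # run with: asyncio.run(run_training(...)) once the collaborators exist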