judgeval 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/cli.py +65 -0
- judgeval/common/api/api.py +44 -38
- judgeval/common/api/constants.py +18 -5
- judgeval/common/api/json_encoder.py +8 -9
- judgeval/common/tracer/core.py +448 -256
- judgeval/common/tracer/otel_span_processor.py +1 -1
- judgeval/common/tracer/span_processor.py +1 -1
- judgeval/common/tracer/span_transformer.py +2 -1
- judgeval/common/tracer/trace_manager.py +6 -1
- judgeval/common/trainer/__init__.py +5 -0
- judgeval/common/trainer/config.py +125 -0
- judgeval/common/trainer/console.py +151 -0
- judgeval/common/trainer/trainable_model.py +238 -0
- judgeval/common/trainer/trainer.py +301 -0
- judgeval/data/evaluation_run.py +104 -0
- judgeval/data/judgment_types.py +37 -8
- judgeval/data/trace.py +1 -0
- judgeval/data/trace_run.py +0 -2
- judgeval/integrations/langgraph.py +2 -1
- judgeval/judgment_client.py +90 -135
- judgeval/local_eval_queue.py +3 -5
- judgeval/run_evaluation.py +43 -299
- judgeval/scorers/base_scorer.py +9 -10
- judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +17 -3
- {judgeval-0.5.0.dist-info → judgeval-0.7.0.dist-info}/METADATA +10 -47
- {judgeval-0.5.0.dist-info → judgeval-0.7.0.dist-info}/RECORD +29 -22
- judgeval-0.7.0.dist-info/entry_points.txt +2 -0
- judgeval/evaluation_run.py +0 -80
- {judgeval-0.5.0.dist-info → judgeval-0.7.0.dist-info}/WHEEL +0 -0
- {judgeval-0.5.0.dist-info → judgeval-0.7.0.dist-info}/licenses/LICENSE.md +0 -0
@@ -21,7 +21,7 @@ from judgeval.common.tracer.otel_exporter import JudgmentAPISpanExporter
|
|
21
21
|
from judgeval.common.tracer.span_processor import SpanProcessorBase
|
22
22
|
from judgeval.common.tracer.span_transformer import SpanTransformer
|
23
23
|
from judgeval.data import TraceSpan
|
24
|
-
from judgeval.evaluation_run import EvaluationRun
|
24
|
+
from judgeval.data.evaluation_run import EvaluationRun
|
25
25
|
|
26
26
|
|
27
27
|
class SimpleReadableSpan(ReadableSpan):
|
@@ -7,7 +7,7 @@ When monitoring is enabled, we use JudgmentSpanProcessor which overrides the met
|
|
7
7
|
"""
|
8
8
|
|
9
9
|
from judgeval.data import TraceSpan
|
10
|
-
from judgeval.evaluation_run import EvaluationRun
|
10
|
+
from judgeval.data.evaluation_run import EvaluationRun
|
11
11
|
|
12
12
|
|
13
13
|
class SpanProcessorBase:
|
@@ -11,7 +11,7 @@ from pydantic import BaseModel
|
|
11
11
|
|
12
12
|
from judgeval.common.api.json_encoder import json_encoder
|
13
13
|
from judgeval.data import TraceSpan
|
14
|
-
from judgeval.evaluation_run import EvaluationRun
|
14
|
+
from judgeval.data.evaluation_run import EvaluationRun
|
15
15
|
|
16
16
|
|
17
17
|
class SpanTransformer:
|
@@ -150,6 +150,7 @@ class SpanTransformer:
|
|
150
150
|
"additional_metadata": judgment_data.get("additional_metadata"),
|
151
151
|
"has_evaluation": judgment_data.get("has_evaluation", False),
|
152
152
|
"agent_name": judgment_data.get("agent_name"),
|
153
|
+
"class_name": judgment_data.get("class_name"),
|
153
154
|
"state_before": judgment_data.get("state_before"),
|
154
155
|
"state_after": judgment_data.get("state_after"),
|
155
156
|
"update_id": judgment_data.get("update_id", 1),
|
@@ -71,7 +71,12 @@ class TraceManagerClient:
|
|
71
71
|
|
72
72
|
server_response = self.api_client.upsert_trace(trace_data)
|
73
73
|
|
74
|
-
if
|
74
|
+
if (
|
75
|
+
not offline_mode
|
76
|
+
and show_link
|
77
|
+
and "ui_results_url" in server_response
|
78
|
+
and self.tracer.show_trace_urls
|
79
|
+
):
|
75
80
|
pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
|
76
81
|
rprint(pretty_str)
|
77
82
|
|
@@ -0,0 +1,125 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from typing import Optional, Dict, Any
|
3
|
+
import json
|
4
|
+
|
5
|
+
|
6
|
+
@dataclass
class TrainerConfig:
    """Hyperparameters and identifiers for a JudgmentTrainer run.

    ``deployment_id``, ``user_id`` and ``model_id`` are required; every other
    field has a default. Field order is part of the positional constructor
    signature and must not be rearranged.
    """

    # Required identifiers
    deployment_id: str
    user_id: str
    model_id: str

    # Base model and provider
    base_model_name: str = "qwen2p5-7b-instruct"
    rft_provider: str = "fireworks"

    # Training loop sizing
    num_steps: int = 5
    num_generations_per_prompt: int = 4  # Number of rollouts/generations per input prompt
    num_prompts_per_step: int = 4  # Number of input prompts to sample per training step
    concurrency: int = 100
    epochs: int = 1
    learning_rate: float = 1e-5

    # Accelerator selection
    accelerator_count: int = 1
    accelerator_type: str = "NVIDIA_A100_80GB"

    # Generation settings
    temperature: float = 1.5
    max_tokens: int = 50

    enable_addons: bool = True
|
28
|
+
|
29
|
+
|
30
|
+
@dataclass
class ModelConfig:
    """
    Configuration class for storing and loading trained model state.

    This class enables persistence of trained models so they can be loaded
    and used later without retraining.

    Example usage:
        trainer = JudgmentTrainer(config)
        model_config = trainer.train(agent_function, scorers, prompts)

        # Save the trained model configuration
        model_config.save_to_file("my_trained_model.json")

        # Later, load and use the trained model
        loaded_config = ModelConfig.load_from_file("my_trained_model.json")
        trained_model = TrainableModel.from_model_config(loaded_config)

        # Use the trained model for inference
        response = trained_model.chat.completions.create(
            model="current",  # Uses the loaded trained model
            messages=[{"role": "user", "content": "Hello!"}]
        )

    Config files are always written and read as UTF-8, so a file saved on one
    platform loads identically on another regardless of the locale encoding.
    """

    # Base model configuration
    base_model_name: str
    deployment_id: str
    user_id: str
    model_id: str
    enable_addons: bool

    # Training state
    current_step: int
    total_steps: int

    # Current model information
    current_model_name: Optional[str] = None
    is_trained: bool = False

    # Training parameters used (for reference). Must be JSON-serializable for
    # to_json()/save_to_file() to succeed.
    training_params: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert ModelConfig to a plain dictionary for serialization.

        ``training_params`` is included by reference, not copied.
        """
        return {
            "base_model_name": self.base_model_name,
            "deployment_id": self.deployment_id,
            "user_id": self.user_id,
            "model_id": self.model_id,
            "enable_addons": self.enable_addons,
            "current_step": self.current_step,
            "total_steps": self.total_steps,
            "current_model_name": self.current_model_name,
            "is_trained": self.is_trained,
            "training_params": self.training_params,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ModelConfig":
        """Create a ModelConfig from a dictionary such as one from ``to_dict``.

        Missing keys fall back to defaults: ``base_model_name`` becomes
        ``"qwen2p5-7b-instruct"``, ``deployment_id`` becomes
        ``"my-base-deployment"``, ``user_id``/``model_id`` become ``""``,
        ``enable_addons`` becomes True, and ``current_step``/``total_steps``
        become 0.

        Raises:
            TypeError: If ``data`` is not a mapping (for example when the JSON
                document was a list or ``null``).
        """
        from collections.abc import Mapping

        # Fail with a clear message instead of "'list' object has no attribute 'get'".
        if not isinstance(data, Mapping):
            raise TypeError(
                f"ModelConfig data must be a mapping, got {type(data).__name__}"
            )

        return cls(
            base_model_name=data.get("base_model_name", "qwen2p5-7b-instruct"),
            deployment_id=data.get("deployment_id", "my-base-deployment"),
            user_id=data.get("user_id", ""),
            model_id=data.get("model_id", ""),
            enable_addons=data.get("enable_addons", True),
            current_step=data.get("current_step", 0),
            total_steps=data.get("total_steps", 0),
            current_model_name=data.get("current_model_name"),
            is_trained=data.get("is_trained", False),
            training_params=data.get("training_params"),
        )

    def to_json(self) -> str:
        """Convert ModelConfig to a JSON string (2-space indented)."""
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def from_json(cls, json_str: str) -> "ModelConfig":
        """Create ModelConfig from a JSON string.

        Raises:
            json.JSONDecodeError: If ``json_str`` is not valid JSON.
            TypeError: If the JSON document is not an object.
        """
        data = json.loads(json_str)
        return cls.from_dict(data)

    def save_to_file(self, filepath: str):
        """Save ModelConfig to a UTF-8 encoded JSON file, overwriting it."""
        # Explicit encoding: the default for open() depends on the platform locale.
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(self.to_json())

    @classmethod
    def load_from_file(cls, filepath: str) -> "ModelConfig":
        """Load ModelConfig from a UTF-8 encoded JSON file."""
        # Explicit encoding so hand-edited non-ASCII values load on any locale.
        with open(filepath, "r", encoding="utf-8") as f:
            json_str = f.read()
        return cls.from_json(json_str)
|
@@ -0,0 +1,151 @@
|
|
1
|
+
from contextlib import contextmanager
|
2
|
+
from typing import Optional
|
3
|
+
import sys
|
4
|
+
import os
|
5
|
+
|
6
|
+
|
7
|
+
# Detect if we're running in a Jupyter environment
|
8
|
+
def _is_jupyter_environment():
    """Return True when running inside Jupyter/IPython or Google Colab.

    Three signals are checked: an ``ipykernel`` or ``IPython`` module already
    imported, a ``JPY_PARENT_PID`` environment variable, or an imported
    ``google.colab`` module. Any error while probing counts as "not Jupyter".
    """
    try:
        loaded = sys.modules
        ipython_loaded = any(mod in loaded for mod in ("ipykernel", "IPython"))
        return (
            ipython_loaded
            or "JPY_PARENT_PID" in os.environ
            or "google.colab" in loaded
        )
    except Exception:
        return False
|
23
|
+
|
24
|
+
|
25
|
+
# Probe the runtime environment a single time, when this module is imported.
IS_JUPYTER = _is_jupyter_environment()

# Rich is only used outside Jupyter (where it is avoided to prevent recursion
# issues) and only when it can actually be imported.
RICH_AVAILABLE = False
if not IS_JUPYTER:
    try:
        from rich.console import Console
        from rich.spinner import Spinner
        from rich.live import Live
        from rich.text import Text

        # Shared console instance for the trainer module to avoid conflicts
        shared_console = Console()
    except ImportError:
        pass
    else:
        RICH_AVAILABLE = True
|
44
|
+
|
45
|
+
|
46
|
+
# Fallback implementations for when Rich is not available or safe
|
47
|
+
class SimpleSpinner:
    """Text-only stand-in for ``rich.spinner.Spinner``.

    Only ``text`` is stored; ``name`` is accepted so the call signature
    matches Rich's and is otherwise ignored.
    """

    def __init__(self, name, text):
        # `name` intentionally unused.
        self.text = text
|
50
|
+
|
51
|
+
|
52
|
+
class SimpleLive:
    """Print-based stand-in for ``rich.live.Live``.

    Instead of redrawing in place, each state is printed as a new line
    prefixed with "🔄". ``console`` and ``refresh_per_second`` exist only for
    signature compatibility and are ignored.
    """

    def __init__(self, spinner, console=None, refresh_per_second=None):
        self.spinner = spinner

    def __enter__(self):
        # Announce the initial spinner text once on entry.
        initial_text = self.spinner.text
        print(f"🔄 {initial_text}")
        return self

    def __exit__(self, *args):
        # Nothing to tear down; returning None lets exceptions propagate.
        return None

    def update(self, spinner):
        # Print the new text; the stored spinner is deliberately left as-is.
        new_text = spinner.text
        print(f"🔄 {new_text}")
|
65
|
+
|
66
|
+
|
67
|
+
def safe_print(message, style=None):
    """Print ``message`` through Rich when usable, otherwise via ``print``.

    In the plain-print fallback, the styles "green", "yellow", "blue" and
    "cyan" become a leading emoji; any other style (including None) prints
    the message unchanged.
    """
    if RICH_AVAILABLE and not IS_JUPYTER:
        shared_console.print(message, style=style)
        return

    # (style name, emoji prefix) pairs, checked in order with ==.
    style_icons = (
        ("green", "✅"),
        ("yellow", "⚠️"),
        ("blue", "🔵"),
        ("cyan", "🔷"),
    )
    for style_name, icon in style_icons:
        if style == style_name:
            print(f"{icon} {message}")
            return
    print(message)
|
83
|
+
|
84
|
+
|
85
|
+
@contextmanager
def _spinner_progress(
    message: str, step: Optional[int] = None, total_steps: Optional[int] = None
):
    """Context manager for spinner-based progress display.

    The message is prefixed with ``[Step {step}/{total_steps}]`` when both are
    given, otherwise with ``[Training]``.

    With Rich, an animated cyan spinner is shown for the duration of the
    ``with`` block. Without Rich (or in Jupyter), a start line is printed,
    then a "- Complete" line if the block finishes normally or a "- Failed"
    line if it raises an ``Exception``. The exception is always re-raised.
    """
    if step is not None and total_steps is not None:
        full_message = f"[Step {step}/{total_steps}] {message}"
    else:
        full_message = f"[Training] {message}"

    if RICH_AVAILABLE and not IS_JUPYTER:
        spinner = Spinner("dots", text=Text(full_message, style="cyan"))
        with Live(spinner, console=shared_console, refresh_per_second=10):
            yield
        return

    # Fallback for Jupyter or when Rich is not available
    print(f"🔄 {full_message}")
    try:
        yield
    except Exception:
        # The completion line used to be printed from `finally`, which claimed
        # success even when the wrapped operation failed.
        print(f"❌ {full_message} - Failed")
        raise
    print(f"✅ {full_message} - Complete")
|
106
|
+
|
107
|
+
|
108
|
+
@contextmanager
def _model_spinner_progress(message: str):
    """Show progress for a model operation and yield a sub-status callback.

    The header is ``[Model] {message}``. The yielded
    ``update_progress(progress_message)`` adds a " └─ progress_message" line
    under the header. With Rich it rewrites the live blue spinner; without
    Rich (or in Jupyter) it prints a new line.
    """
    header = f"[Model] {message}"

    if not (RICH_AVAILABLE and not IS_JUPYTER):
        # Plain-print fallback for Jupyter or when Rich is not available.
        print(f"🔵 {header}")

        def update_progress(progress_message: str):
            print(f" └─ {progress_message}")

        yield update_progress
        return

    spinner = Spinner("dots", text=Text(header, style="blue"))
    with Live(spinner, console=shared_console, refresh_per_second=10) as live:

        def update_progress(progress_message: str):
            """Replace the spinner text with the header plus a sub-status line."""
            spinner.text = Text(f"{header}\n └─ {progress_message}", style="blue")
            live.update(spinner)

        yield update_progress
|
130
|
+
|
131
|
+
|
132
|
+
def _print_progress(
    message: str, step: Optional[int] = None, total_steps: Optional[int] = None
):
    """Print a green progress line tagged ``[Step x/y]`` or ``[Training]``.

    The step tag is used only when both ``step`` and ``total_steps`` are given.
    """
    has_counter = step is not None and total_steps is not None
    prefix = f"[Step {step}/{total_steps}]" if has_counter else "[Training]"
    safe_print(f"{prefix} {message}", style="green")
|
140
|
+
|
141
|
+
|
142
|
+
def _print_progress_update(
    message: str, step: Optional[int] = None, total_steps: Optional[int] = None
):
    """Print an indented yellow " └─ message" sub-status line.

    ``step`` and ``total_steps`` mirror ``_print_progress``'s signature but
    are not included in the output.
    """
    line = f" └─ {message}"
    safe_print(line, style="yellow")
|
147
|
+
|
148
|
+
|
149
|
+
def _print_model_progress(message: str):
    """Print a blue ``[Model] message`` status line via ``safe_print``."""
    text = f"[Model] {message}"
    safe_print(text, style="blue")
|
@@ -0,0 +1,238 @@
|
|
1
|
+
from fireworks import LLM
|
2
|
+
from .config import TrainerConfig, ModelConfig
|
3
|
+
from typing import Optional, Dict, Any, Callable
|
4
|
+
from .console import _model_spinner_progress, _print_model_progress
|
5
|
+
from judgeval.common.exceptions import JudgmentAPIError
|
6
|
+
|
7
|
+
|
8
|
+
class TrainableModel:
    """
    A wrapper class for managing model snapshots during training.

    This class automatically handles model snapshot creation and management
    during the RFT (Reinforcement Fine-Tuning) process,
    abstracting away manual snapshot management from users.

    The snapshot for training step ``N`` is addressed as
    ``accounts/{user_id}/models/{model_id}-v{N}``. It is deployed as an
    ``on-demand-lora`` model whose ``base_id`` is ``config.deployment_id``.
    The ``chat`` and ``completions`` properties delegate to whichever model is
    current.
    """

    def __init__(self, config: "TrainerConfig"):
        """
        Initialize the TrainableModel and deploy the base model.

        Args:
            config: TrainerConfig instance with model configuration

        Raises:
            JudgmentAPIError: If the base model cannot be created or deployed.
        """
        try:
            self.config = config
            self.current_step = 0
            self._current_model = None
            self._tracer_wrapper_func = None

            self._base_model = self._create_base_model()
            self._current_model = self._base_model
        except Exception as e:
            raise JudgmentAPIError(
                f"Failed to initialize TrainableModel: {str(e)}"
            ) from e

    @classmethod
    def from_model_config(cls, model_config: "ModelConfig") -> "TrainableModel":
        """
        Create a TrainableModel from a saved ModelConfig.

        The base model is deployed first. If the config describes a trained
        snapshot, that snapshot is then deployed and made current. The saved
        ``total_steps`` is restored as ``TrainerConfig.num_steps``, so that
        ``get_model_config()`` on the loaded model reports the same step count
        instead of the TrainerConfig default. Other training hyperparameters
        use TrainerConfig defaults.

        Args:
            model_config: ModelConfig instance with saved model state

        Returns:
            TrainableModel instance configured to use the saved model

        Raises:
            JudgmentAPIError: If deploying the base or the trained model fails.
        """
        # Only override the TrainerConfig default when a positive step count was
        # saved; ModelConfig.from_dict loads a missing "total_steps" as 0.
        overrides: Dict[str, Any] = {}
        total_steps = model_config.total_steps
        if isinstance(total_steps, int) and total_steps > 0:
            overrides["num_steps"] = total_steps

        trainer_config = TrainerConfig(
            base_model_name=model_config.base_model_name,
            deployment_id=model_config.deployment_id,
            user_id=model_config.user_id,
            model_id=model_config.model_id,
            enable_addons=model_config.enable_addons,
            **overrides,
        )

        instance = cls(trainer_config)
        instance.current_step = model_config.current_step

        if model_config.is_trained and model_config.current_model_name:
            instance._load_trained_model(model_config.current_model_name)

        return instance

    def _create_base_model(self):
        """Create the on-demand base deployment and apply it.

        Returns:
            The applied base ``LLM`` instance.

        Raises:
            JudgmentAPIError: If creating or deploying the base model fails.
        """
        try:
            with _model_spinner_progress(
                "Creating and deploying base model..."
            ) as update_progress:
                update_progress("Creating base model instance...")
                base_model = LLM(
                    model=self.config.base_model_name,
                    deployment_type="on-demand",
                    id=self.config.deployment_id,
                    enable_addons=self.config.enable_addons,
                )
                update_progress("Applying deployment configuration...")
                base_model.apply()
                _print_model_progress("Base model deployment ready")
                return base_model
        except Exception as e:
            raise JudgmentAPIError(
                f"Failed to create and deploy base model '{self.config.base_model_name}': {str(e)}"
            ) from e

    def _snapshot_model_name(self, step: int) -> str:
        """Return the fully-qualified model name of the snapshot for ``step``."""
        return f"accounts/{self.config.user_id}/models/{self.config.model_id}-v{step}"

    def _deploy_lora_model(
        self, model_name: str, title: str, instance_message: str, ready_message: str
    ):
        """Deploy ``model_name`` as an on-demand LoRA and make it the current model.

        ``title``, ``instance_message`` and ``ready_message`` are the progress
        texts shown to the user. Exceptions propagate unchanged, so callers can
        wrap them with their own context.
        """
        with _model_spinner_progress(title) as update_progress:
            update_progress(instance_message)
            self._current_model = LLM(
                model=model_name,
                deployment_type="on-demand-lora",
                base_id=self.config.deployment_id,
            )
            update_progress("Applying deployment configuration...")
            self._current_model.apply()
            _print_model_progress(ready_message)

    def _apply_tracer_wrapper(self):
        """Re-apply the registered tracer wrapper (if any) to the current model."""
        if self._tracer_wrapper_func:
            self._tracer_wrapper_func(self._current_model)

    def _load_trained_model(self, model_name: str):
        """Deploy a previously trained snapshot by name and make it current.

        Raises:
            JudgmentAPIError: If the snapshot cannot be created, deployed, or
                wrapped by the registered tracer.
        """
        try:
            self._deploy_lora_model(
                model_name,
                f"Loading and deploying trained model: {model_name}",
                "Creating trained model instance...",
                "Trained model deployment ready",
            )
            self._apply_tracer_wrapper()
        except Exception as e:
            raise JudgmentAPIError(
                f"Failed to load and deploy trained model '{model_name}': {str(e)}"
            ) from e

    def get_current_model(self):
        """Return the model object that inference calls are currently routed to."""
        return self._current_model

    @property
    def chat(self):
        """OpenAI-compatible chat interface."""
        return self._current_model.chat

    @property
    def completions(self):
        """OpenAI-compatible completions interface."""
        return self._current_model.completions

    def advance_to_next_step(self, step: int):
        """
        Advance to the next training step and update the current model snapshot.

        Step 0 switches back to the already-deployed base model. Any other
        step deploys the ``{model_id}-v{step}`` snapshot. In both cases the
        registered tracer wrapper (if any) is re-applied to the current model.

        Args:
            step: The current training step number

        Raises:
            JudgmentAPIError: If the snapshot cannot be deployed or wrapped.
        """
        try:
            self.current_step = step

            if step == 0:
                self._current_model = self._base_model
            else:
                model_name = self._snapshot_model_name(step)
                self._deploy_lora_model(
                    model_name,
                    f"Creating and deploying model snapshot: {model_name}",
                    "Creating model snapshot instance...",
                    "Model snapshot deployment ready",
                )

            self._apply_tracer_wrapper()
        except Exception as e:
            raise JudgmentAPIError(
                f"Failed to advance to training step {step}: {str(e)}"
            ) from e

    def perform_reinforcement_step(self, dataset, step: int):
        """
        Perform a reinforcement learning step using the current model.

        The output model is named ``{model_id}-v{step + 1}``. This is
        presumably the short form of ``_snapshot_model_name(step + 1)``, which
        ``advance_to_next_step(step + 1)`` deploys; confirm against the provider.

        Args:
            dataset: Training dataset for the reinforcement step
            step: Current step number for output model naming

        Returns:
            Training job object

        Raises:
            JudgmentAPIError: If the reinforcement step cannot be started.
        """
        try:
            model_name = f"{self.config.model_id}-v{step + 1}"
            return self._current_model.reinforcement_step(
                dataset=dataset,
                output_model=model_name,
                epochs=self.config.epochs,
                learning_rate=self.config.learning_rate,
                accelerator_count=self.config.accelerator_count,
                accelerator_type=self.config.accelerator_type,
            )
        except Exception as e:
            raise JudgmentAPIError(
                f"Failed to start reinforcement learning step {step + 1}: {str(e)}"
            ) from e

    def get_model_config(
        self, training_params: Optional[Dict[str, Any]] = None
    ) -> "ModelConfig":
        """
        Get the current model configuration for persistence.

        Args:
            training_params: Optional training parameters to include in config

        Returns:
            ModelConfig instance with current model state. ``is_trained`` is
            True and ``current_model_name`` is set only when
            ``current_step > 0``.
        """
        current_model_name = None
        is_trained = False

        if self.current_step > 0:
            current_model_name = self._snapshot_model_name(self.current_step)
            is_trained = True

        return ModelConfig(
            base_model_name=self.config.base_model_name,
            deployment_id=self.config.deployment_id,
            user_id=self.config.user_id,
            model_id=self.config.model_id,
            enable_addons=self.config.enable_addons,
            current_step=self.current_step,
            total_steps=self.config.num_steps,
            current_model_name=current_model_name,
            is_trained=is_trained,
            training_params=training_params,
        )

    def save_model_config(
        self, filepath: str, training_params: Optional[Dict[str, Any]] = None
    ):
        """
        Save the current model configuration to a file.

        Args:
            filepath: Path to save the configuration file
            training_params: Optional training parameters to include in config
        """
        model_config = self.get_model_config(training_params)
        model_config.save_to_file(filepath)

    def _register_tracer_wrapper(self, wrapper_func: Callable):
        """
        Register a tracer wrapper function to be reapplied when models change.

        This is called internally by the tracer's wrap() function to ensure
        that new model instances created during training are automatically wrapped.

        Args:
            wrapper_func: Function that wraps a model instance with tracing
        """
        self._tracer_wrapper_func = wrapper_func
|