judgeval 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/cli.py +1 -1
- judgeval/common/api/constants.py +1 -1
- judgeval/common/tracer/core.py +171 -1
- judgeval/common/tracer/trace_manager.py +6 -1
- judgeval/common/trainer/__init__.py +5 -0
- judgeval/common/trainer/config.py +125 -0
- judgeval/common/trainer/console.py +151 -0
- judgeval/common/trainer/trainable_model.py +238 -0
- judgeval/common/trainer/trainer.py +301 -0
- judgeval/judgment_client.py +4 -104
- judgeval/run_evaluation.py +10 -107
- {judgeval-0.6.0.dist-info → judgeval-0.7.0.dist-info}/METADATA +8 -47
- {judgeval-0.6.0.dist-info → judgeval-0.7.0.dist-info}/RECORD +16 -11
- {judgeval-0.6.0.dist-info → judgeval-0.7.0.dist-info}/WHEEL +0 -0
- {judgeval-0.6.0.dist-info → judgeval-0.7.0.dist-info}/entry_points.txt +0 -0
- {judgeval-0.6.0.dist-info → judgeval-0.7.0.dist-info}/licenses/LICENSE.md +0 -0
judgeval/cli.py
CHANGED
judgeval/common/api/constants.py
CHANGED
@@ -51,7 +51,7 @@ JUDGMENT_ADD_TO_RUN_EVAL_QUEUE_API_URL = f"{ROOT_API}/add_to_run_eval_queue/"
|
|
51
51
|
JUDGMENT_GET_EVAL_STATUS_API_URL = f"{ROOT_API}/get_evaluation_status/"
|
52
52
|
|
53
53
|
# Custom Scorers API
|
54
|
-
JUDGMENT_CUSTOM_SCORER_UPLOAD_API_URL = f"{ROOT_API}/
|
54
|
+
JUDGMENT_CUSTOM_SCORER_UPLOAD_API_URL = f"{ROOT_API}/upload_scorer/"
|
55
55
|
|
56
56
|
|
57
57
|
# Evaluation API Payloads
|
judgeval/common/tracer/core.py
CHANGED
@@ -815,6 +815,8 @@ class Tracer:
|
|
815
815
|
== "true",
|
816
816
|
enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower()
|
817
817
|
== "true",
|
818
|
+
show_trace_urls: bool = os.getenv("JUDGMENT_SHOW_TRACE_URLS", "true").lower()
|
819
|
+
== "true",
|
818
820
|
# S3 configuration
|
819
821
|
use_s3: bool = False,
|
820
822
|
s3_bucket_name: Optional[str] = None,
|
@@ -859,6 +861,7 @@ class Tracer:
|
|
859
861
|
self.traces: List[Trace] = []
|
860
862
|
self.enable_monitoring: bool = enable_monitoring
|
861
863
|
self.enable_evaluations: bool = enable_evaluations
|
864
|
+
self.show_trace_urls: bool = show_trace_urls
|
862
865
|
self.class_identifiers: Dict[
|
863
866
|
str, str
|
864
867
|
] = {} # Dictionary to store class identifiers
|
@@ -1731,6 +1734,93 @@ class Tracer:
|
|
1731
1734
|
f"Error during background service shutdown: {e}"
|
1732
1735
|
)
|
1733
1736
|
|
1737
|
+
def trace_to_message_history(
    self, trace: Union[Trace, TraceClient]
) -> List[Dict[str, str]]:
    """
    Build a chat-style message list from the spans of a trace (for training).

    Spans are visited in ascending ``created_at`` order; a span without that
    attribute sorts with key ``0``. The first LLM span whose ``inputs`` dict
    holds a non-empty ``messages`` list contributes those messages (only
    entries that are dicts with both ``role`` and ``content``). Every LLM span
    with an output adds an ``assistant`` message. User and tool span outputs
    are both added as ``user`` messages. Outputs are converted with ``str()``.
    Spans of any other type are ignored.

    Args:
        trace: Trace or TraceClient instance whose ``trace_spans`` are read.

    Returns:
        List of message dictionaries with 'role' and 'content' keys.

    Raises:
        ValueError: If ``trace`` is falsy (e.g. ``None``).
    """
    if not trace:
        raise ValueError("No trace provided")

    # TraceClient always carries trace_spans; other objects may lack it.
    if isinstance(trace, TraceClient):
        all_spans = trace.trace_spans
    else:
        all_spans = getattr(trace, "trace_spans", [])

    history: List[Dict[str, str]] = []
    # Becomes True once an LLM span's input prompt has been copied, so the
    # system/user prompt is only emitted once.
    prompt_seeded = False

    for sp in sorted(all_spans, key=lambda item: getattr(item, "created_at", 0)):
        kind = sp.span_type
        out = sp.output

        if kind == "llm":
            if not prompt_seeded:
                span_inputs = getattr(sp, "inputs", None)
                prompt_msgs = span_inputs.get("messages", []) if span_inputs else []
                if prompt_msgs:
                    prompt_seeded = True
                    history.extend(
                        {"role": m["role"], "content": m["content"]}
                        for m in prompt_msgs
                        if isinstance(m, dict) and "role" in m and "content" in m
                    )
            if out is not None:
                history.append({"role": "assistant", "content": str(out)})
        elif kind in ("user", "tool") and out is not None:
            # Tool results are recorded as user turns, like real user input.
            history.append({"role": "user", "content": str(out)})

    return history
|
1807
|
+
|
1808
|
+
def get_current_message_history(self) -> List[Dict[str, str]]:
    """
    Return the message history of the trace active in the current context.

    Returns:
        The result of ``self.trace_to_message_history`` applied to the trace
        returned by ``self.get_current_trace()``.

    Raises:
        ValueError: If ``self.get_current_trace()`` returns a falsy value.
    """
    active_trace = self.get_current_trace()
    if active_trace:
        return self.trace_to_message_history(active_trace)
    raise ValueError("No current trace found")
|
1823
|
+
|
1734
1824
|
|
1735
1825
|
def _get_current_trace(
|
1736
1826
|
trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
|
@@ -1746,7 +1836,7 @@ def wrap(
|
|
1746
1836
|
) -> Any:
|
1747
1837
|
"""
|
1748
1838
|
Wraps an API client to add tracing capabilities.
|
1749
|
-
Supports OpenAI, Together, Anthropic,
|
1839
|
+
Supports OpenAI, Together, Anthropic, Google GenAI clients, and TrainableModel.
|
1750
1840
|
Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
|
1751
1841
|
"""
|
1752
1842
|
(
|
@@ -1871,6 +1961,39 @@ def wrap(
|
|
1871
1961
|
setattr(client.chat.completions, "create", wrapped(original_create))
|
1872
1962
|
elif isinstance(client, (groq_AsyncGroq)):
|
1873
1963
|
setattr(client.chat.completions, "create", wrapped_async(original_create))
|
1964
|
+
|
1965
|
+
# Check for TrainableModel from judgeval.common.trainer
|
1966
|
+
try:
|
1967
|
+
from judgeval.common.trainer import TrainableModel
|
1968
|
+
|
1969
|
+
if isinstance(client, TrainableModel):
|
1970
|
+
# Define a wrapper function that can be reapplied to new model instances
|
1971
|
+
def wrap_model_instance(model_instance):
|
1972
|
+
"""Wrap a model instance with tracing functionality"""
|
1973
|
+
if hasattr(model_instance, "chat") and hasattr(
|
1974
|
+
model_instance.chat, "completions"
|
1975
|
+
):
|
1976
|
+
if hasattr(model_instance.chat.completions, "create"):
|
1977
|
+
setattr(
|
1978
|
+
model_instance.chat.completions,
|
1979
|
+
"create",
|
1980
|
+
wrapped(model_instance.chat.completions.create),
|
1981
|
+
)
|
1982
|
+
if hasattr(model_instance.chat.completions, "acreate"):
|
1983
|
+
setattr(
|
1984
|
+
model_instance.chat.completions,
|
1985
|
+
"acreate",
|
1986
|
+
wrapped_async(model_instance.chat.completions.acreate),
|
1987
|
+
)
|
1988
|
+
|
1989
|
+
# Register the wrapper function with the TrainableModel
|
1990
|
+
client._register_tracer_wrapper(wrap_model_instance)
|
1991
|
+
|
1992
|
+
# Apply wrapping to the current model
|
1993
|
+
wrap_model_instance(client._current_model)
|
1994
|
+
except ImportError:
|
1995
|
+
pass # TrainableModel not available
|
1996
|
+
|
1874
1997
|
return client
|
1875
1998
|
|
1876
1999
|
|
@@ -1977,6 +2100,22 @@ def _get_client_config(
|
|
1977
2100
|
return "GROQ_API_CALL", client.chat.completions.create, None, None, None
|
1978
2101
|
elif isinstance(client, (groq_AsyncGroq)):
|
1979
2102
|
return "GROQ_API_CALL", client.chat.completions.create, None, None, None
|
2103
|
+
|
2104
|
+
# Check for TrainableModel
|
2105
|
+
try:
|
2106
|
+
from judgeval.common.trainer import TrainableModel
|
2107
|
+
|
2108
|
+
if isinstance(client, TrainableModel):
|
2109
|
+
return (
|
2110
|
+
"FIREWORKS_TRAINABLE_MODEL_CALL",
|
2111
|
+
client._current_model.chat.completions.create,
|
2112
|
+
None,
|
2113
|
+
None,
|
2114
|
+
None,
|
2115
|
+
)
|
2116
|
+
except ImportError:
|
2117
|
+
pass # TrainableModel not available
|
2118
|
+
|
1980
2119
|
raise ValueError(f"Unsupported client type: {type(client)}")
|
1981
2120
|
|
1982
2121
|
|
@@ -2155,6 +2294,37 @@ def _format_output_data(
|
|
2155
2294
|
cache_creation_input_tokens,
|
2156
2295
|
)
|
2157
2296
|
|
2297
|
+
# Check for TrainableModel
|
2298
|
+
try:
|
2299
|
+
from judgeval.common.trainer import TrainableModel
|
2300
|
+
|
2301
|
+
if isinstance(client, TrainableModel):
|
2302
|
+
# TrainableModel uses Fireworks LLM internally, so response format should be similar to OpenAI
|
2303
|
+
if (
|
2304
|
+
hasattr(response, "model")
|
2305
|
+
and hasattr(response, "usage")
|
2306
|
+
and hasattr(response, "choices")
|
2307
|
+
):
|
2308
|
+
model_name = response.model
|
2309
|
+
prompt_tokens = response.usage.prompt_tokens if response.usage else 0
|
2310
|
+
completion_tokens = (
|
2311
|
+
response.usage.completion_tokens if response.usage else 0
|
2312
|
+
)
|
2313
|
+
message_content = response.choices[0].message.content
|
2314
|
+
|
2315
|
+
# Use LiteLLM cost calculation with fireworks_ai prefix
|
2316
|
+
# LiteLLM supports Fireworks AI models for cost calculation when prefixed with "fireworks_ai/"
|
2317
|
+
fireworks_model_name = f"fireworks_ai/{model_name}"
|
2318
|
+
return message_content, _create_usage(
|
2319
|
+
fireworks_model_name,
|
2320
|
+
prompt_tokens,
|
2321
|
+
completion_tokens,
|
2322
|
+
cache_read_input_tokens,
|
2323
|
+
cache_creation_input_tokens,
|
2324
|
+
)
|
2325
|
+
except ImportError:
|
2326
|
+
pass # TrainableModel not available
|
2327
|
+
|
2158
2328
|
judgeval_logger.warning(f"Unsupported client type: {type(client)}")
|
2159
2329
|
return None, None
|
2160
2330
|
|
@@ -71,7 +71,12 @@ class TraceManagerClient:
|
|
71
71
|
|
72
72
|
server_response = self.api_client.upsert_trace(trace_data)
|
73
73
|
|
74
|
-
if
|
74
|
+
if (
|
75
|
+
not offline_mode
|
76
|
+
and show_link
|
77
|
+
and "ui_results_url" in server_response
|
78
|
+
and self.tracer.show_trace_urls
|
79
|
+
):
|
75
80
|
pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
|
76
81
|
rprint(pretty_str)
|
77
82
|
|
@@ -0,0 +1,125 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from typing import Optional, Dict, Any
|
3
|
+
import json
|
4
|
+
|
5
|
+
|
6
|
+
@dataclass
class TrainerConfig:
    """Configuration class for JudgmentTrainer parameters.

    ``deployment_id``, ``user_id`` and ``model_id`` are required; every other
    field has the default shown below. Field order is part of the positional
    constructor signature and must not change.
    """

    # Required identifiers.
    deployment_id: str
    user_id: str
    model_id: str

    # Model and provider selection.
    base_model_name: str = "qwen2p5-7b-instruct"
    rft_provider: str = "fireworks"

    # Training-loop sizing.
    num_steps: int = 5
    num_generations_per_prompt: int = 4  # Number of rollouts/generations per input prompt
    num_prompts_per_step: int = 4  # Number of input prompts to sample per training step
    concurrency: int = 100
    epochs: int = 1
    learning_rate: float = 1e-5

    # Accelerator settings.
    accelerator_count: int = 1
    accelerator_type: str = "NVIDIA_A100_80GB"

    # Generation settings.
    temperature: float = 1.5
    max_tokens: int = 50

    enable_addons: bool = True
|
28
|
+
|
29
|
+
|
30
|
+
@dataclass
class ModelConfig:
    """
    Configuration class for storing and loading trained model state.

    This class enables persistence of trained models so they can be loaded
    and used later without retraining. Files are always written and read as
    UTF-8 JSON, independent of the platform's locale encoding.

    Example usage:
        trainer = JudgmentTrainer(config)
        model_config = trainer.train(agent_function, scorers, prompts)

        # Save the trained model configuration
        model_config.save_to_file("my_trained_model.json")

        # Later, load and use the trained model
        loaded_config = ModelConfig.load_from_file("my_trained_model.json")
        trained_model = TrainableModel.from_model_config(loaded_config)

        # Use the trained model for inference
        response = trained_model.chat.completions.create(
            model="current",  # Uses the loaded trained model
            messages=[{"role": "user", "content": "Hello!"}]
        )
    """

    # Base model configuration
    base_model_name: str
    deployment_id: str
    user_id: str
    model_id: str
    enable_addons: bool

    # Training state
    current_step: int
    total_steps: int

    # Current model information
    current_model_name: Optional[str] = None
    is_trained: bool = False

    # Training parameters used (for reference)
    training_params: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert ModelConfig to a plain dictionary for serialization.

        ``training_params`` is included by reference (not copied), so the
        result is JSON-serializable only if that dict is.
        """
        return {
            "base_model_name": self.base_model_name,
            "deployment_id": self.deployment_id,
            "user_id": self.user_id,
            "model_id": self.model_id,
            "enable_addons": self.enable_addons,
            "current_step": self.current_step,
            "total_steps": self.total_steps,
            "current_model_name": self.current_model_name,
            "is_trained": self.is_trained,
            "training_params": self.training_params,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ModelConfig":
        """Create ModelConfig from a dictionary.

        Missing keys fall back to defaults: base_model_name
        "qwen2p5-7b-instruct", deployment_id "my-base-deployment", empty
        user_id/model_id, enable_addons True, zero steps, no current model,
        not trained, no training params. Unknown keys are ignored.
        """
        return cls(
            base_model_name=data.get("base_model_name", "qwen2p5-7b-instruct"),
            deployment_id=data.get("deployment_id", "my-base-deployment"),
            user_id=data.get("user_id", ""),
            model_id=data.get("model_id", ""),
            enable_addons=data.get("enable_addons", True),
            current_step=data.get("current_step", 0),
            total_steps=data.get("total_steps", 0),
            current_model_name=data.get("current_model_name"),
            is_trained=data.get("is_trained", False),
            training_params=data.get("training_params"),
        )

    def to_json(self) -> str:
        """Convert ModelConfig to a JSON string indented by 2 spaces."""
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def from_json(cls, json_str: str) -> "ModelConfig":
        """Create ModelConfig from a JSON string (see ``from_dict`` for defaults)."""
        data = json.loads(json_str)
        return cls.from_dict(data)

    def save_to_file(self, filepath: str):
        """Save ModelConfig to a UTF-8 encoded JSON file, overwriting it."""
        # Explicit encoding: without it open() uses the locale encoding, which
        # varies by platform (e.g. cp1252 on Windows) - see PEP 597.
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(self.to_json())

    @classmethod
    def load_from_file(cls, filepath: str) -> "ModelConfig":
        """Load ModelConfig from a UTF-8 encoded JSON file."""
        # Read as UTF-8 so files containing non-ASCII text (hand-edited or
        # written by other tools) decode identically on every platform.
        with open(filepath, "r", encoding="utf-8") as f:
            json_str = f.read()
        return cls.from_json(json_str)
|
@@ -0,0 +1,151 @@
|
|
1
|
+
from contextlib import contextmanager
|
2
|
+
from typing import Optional
|
3
|
+
import sys
|
4
|
+
import os
|
5
|
+
|
6
|
+
|
7
|
+
# Detect if we're running in a Jupyter environment
|
8
|
+
def _is_jupyter_environment():
|
9
|
+
"""Check if we're running in a Jupyter notebook or similar environment."""
|
10
|
+
try:
|
11
|
+
# Check for IPython kernel
|
12
|
+
if "ipykernel" in sys.modules or "IPython" in sys.modules:
|
13
|
+
return True
|
14
|
+
# Check for Jupyter environment variables
|
15
|
+
if "JPY_PARENT_PID" in os.environ:
|
16
|
+
return True
|
17
|
+
# Check if we're in Google Colab
|
18
|
+
if "google.colab" in sys.modules:
|
19
|
+
return True
|
20
|
+
return False
|
21
|
+
except Exception:
|
22
|
+
return False
|
23
|
+
|
24
|
+
|
25
|
+
# Evaluated once when the module is imported; later environment changes are
# not picked up.
IS_JUPYTER = _is_jupyter_environment()

# Rich is used only outside notebooks: in Jupyter it is deliberately avoided to
# prevent recursion issues, and it is also skipped when it cannot be imported.
RICH_AVAILABLE = False
if not IS_JUPYTER:
    try:
        from rich.console import Console
        from rich.spinner import Spinner
        from rich.live import Live
        from rich.text import Text

        # Shared console instance for the trainer module to avoid conflicts
        shared_console = Console()
    except ImportError:
        pass
    else:
        RICH_AVAILABLE = True
|
44
|
+
|
45
|
+
|
46
|
+
# Fallback implementations for when Rich is not available or safe
|
47
|
+
class SimpleSpinner:
    """Plain-text stand-in for Rich's ``Spinner`` that only keeps its text."""

    def __init__(self, name, text):
        # ``name`` selects Rich's spinner animation (e.g. "dots"); there is no
        # plain-text equivalent, so it is accepted and ignored.
        self.text = text
|
50
|
+
|
51
|
+
|
52
|
+
class SimpleLive:
    """Print-based stand-in for Rich's ``Live`` display.

    Entering the context and each ``update`` call print the spinner's text on
    its own line, prefixed with "🔄". ``console`` and ``refresh_per_second``
    are accepted for signature compatibility and ignored.
    """

    def __init__(self, spinner, console=None, refresh_per_second=None):
        self.spinner = spinner

    def __enter__(self):
        print(f"🔄 {self.spinner.text}")
        return self

    def __exit__(self, *exc_info):
        # Nothing to tear down; returning None lets exceptions propagate.
        return None

    def update(self, spinner):
        """Print the new spinner's text (the stored spinner is not replaced)."""
        print(f"🔄 {spinner.text}")
|
65
|
+
|
66
|
+
|
67
|
+
def safe_print(message, style=None):
    """Print ``message`` in a way that works in terminals and notebooks.

    When Rich is usable (``RICH_AVAILABLE`` and not ``IS_JUPYTER``), the message
    and ``style`` are passed to ``shared_console.print``. Otherwise plain
    ``print`` is used, and the styles "green", "yellow", "blue" and "cyan" are
    rendered as an emoji prefix. Any other style, including None, prints the
    message unchanged.
    """
    if not RICH_AVAILABLE or IS_JUPYTER:
        # Emoji stand-ins for the Rich colour styles used by this module.
        for style_name, marker in (
            ("green", "✅"),
            ("yellow", "⚠️"),
            ("blue", "🔵"),
            ("cyan", "🔷"),
        ):
            if style == style_name:
                print(f"{marker} {message}")
                return
        print(message)
        return

    shared_console.print(message, style=style)
|
83
|
+
|
84
|
+
|
85
|
+
@contextmanager
|
86
|
+
def _spinner_progress(
|
87
|
+
message: str, step: Optional[int] = None, total_steps: Optional[int] = None
|
88
|
+
):
|
89
|
+
"""Context manager for spinner-based progress display."""
|
90
|
+
if step is not None and total_steps is not None:
|
91
|
+
full_message = f"[Step {step}/{total_steps}] {message}"
|
92
|
+
else:
|
93
|
+
full_message = f"[Training] {message}"
|
94
|
+
|
95
|
+
if RICH_AVAILABLE and not IS_JUPYTER:
|
96
|
+
spinner = Spinner("dots", text=Text(full_message, style="cyan"))
|
97
|
+
with Live(spinner, console=shared_console, refresh_per_second=10):
|
98
|
+
yield
|
99
|
+
else:
|
100
|
+
# Fallback for Jupyter or when Rich is not available
|
101
|
+
print(f"🔄 {full_message}")
|
102
|
+
try:
|
103
|
+
yield
|
104
|
+
finally:
|
105
|
+
print(f"✅ {full_message} - Complete")
|
106
|
+
|
107
|
+
|
108
|
+
@contextmanager
|
109
|
+
def _model_spinner_progress(message: str):
|
110
|
+
"""Context manager for model operation spinner-based progress display."""
|
111
|
+
if RICH_AVAILABLE and not IS_JUPYTER:
|
112
|
+
spinner = Spinner("dots", text=Text(f"[Model] {message}", style="blue"))
|
113
|
+
with Live(spinner, console=shared_console, refresh_per_second=10) as live:
|
114
|
+
|
115
|
+
def update_progress(progress_message: str):
|
116
|
+
"""Update the spinner with a new progress message."""
|
117
|
+
new_text = f"[Model] {message}\n └─ {progress_message}"
|
118
|
+
spinner.text = Text(new_text, style="blue")
|
119
|
+
live.update(spinner)
|
120
|
+
|
121
|
+
yield update_progress
|
122
|
+
else:
|
123
|
+
# Fallback for Jupyter or when Rich is not available
|
124
|
+
print(f"🔵 [Model] {message}")
|
125
|
+
|
126
|
+
def update_progress(progress_message: str):
|
127
|
+
print(f" └─ {progress_message}")
|
128
|
+
|
129
|
+
yield update_progress
|
130
|
+
|
131
|
+
|
132
|
+
def _print_progress(
    message: str, step: Optional[int] = None, total_steps: Optional[int] = None
):
    """Print a green progress line via ``safe_print``.

    The line is prefixed with ``[Step {step}/{total_steps}]`` when both values
    are given, and with ``[Training]`` otherwise.
    """
    has_step_info = step is not None and total_steps is not None
    prefix = f"[Step {step}/{total_steps}]" if has_step_info else "[Training]"
    safe_print(f"{prefix} {message}", style="green")
|
140
|
+
|
141
|
+
|
142
|
+
def _print_progress_update(
    message: str, step: Optional[int] = None, total_steps: Optional[int] = None
):
    """Print an indented yellow status line for a long-running operation.

    ``step`` and ``total_steps`` are accepted for signature parity with
    ``_print_progress`` but are not used in the output.
    """
    status_line = f" └─ {message}"
    safe_print(status_line, style="yellow")
|
147
|
+
|
148
|
+
|
149
|
+
def _print_model_progress(message: str):
    """Print a blue, "[Model]"-prefixed progress line via ``safe_print``."""
    model_line = f"[Model] {message}"
    safe_print(model_line, style="blue")
|