judgeval 0.6.0__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
judgeval/cli.py CHANGED
@@ -38,7 +38,7 @@ def upload_scorer(
     try:
         client = JudgmentClient()

-        result = client.save_custom_scorer(
+        result = client.upload_custom_scorer(
            scorer_file_path=scorer_file_path,
            requirements_file_path=requirements_file_path,
            unique_name=unique_name,
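
As a quick illustration, the renamed client method is called the same way as before; a minimal sketch, assuming JudgmentClient is importable from the package root and picks up credentials from the environment:

from judgeval import JudgmentClient

client = JudgmentClient()
# Same keyword arguments the CLI passes in the hunk above (paths illustrative)
result = client.upload_custom_scorer(
    scorer_file_path="scorer.py",
    requirements_file_path="requirements.txt",
    unique_name="my-custom-scorer",
)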
@@ -51,7 +51,7 @@ JUDGMENT_ADD_TO_RUN_EVAL_QUEUE_API_URL = f"{ROOT_API}/add_to_run_eval_queue/"
 JUDGMENT_GET_EVAL_STATUS_API_URL = f"{ROOT_API}/get_evaluation_status/"

 # Custom Scorers API
-JUDGMENT_CUSTOM_SCORER_UPLOAD_API_URL = f"{ROOT_API}/build_sandbox_template/"
+JUDGMENT_CUSTOM_SCORER_UPLOAD_API_URL = f"{ROOT_API}/upload_scorer/"


 # Evaluation API Payloads
@@ -815,6 +815,8 @@ class Tracer:
         == "true",
         enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower()
         == "true",
+        show_trace_urls: bool = os.getenv("JUDGMENT_SHOW_TRACE_URLS", "true").lower()
+        == "true",
         # S3 configuration
         use_s3: bool = False,
         s3_bucket_name: Optional[str] = None,
@@ -859,6 +861,7 @@ class Tracer:
         self.traces: List[Trace] = []
         self.enable_monitoring: bool = enable_monitoring
         self.enable_evaluations: bool = enable_evaluations
+        self.show_trace_urls: bool = show_trace_urls
         self.class_identifiers: Dict[
             str, str
         ] = {}  # Dictionary to store class identifiers
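
Taken together, these two hunks add a show_trace_urls switch that defaults to the JUDGMENT_SHOW_TRACE_URLS environment variable ("true" unless overridden) and is stored on the tracer. A minimal sketch of turning the link printout off, assuming the Tracer import path and that other required settings (API key, project) come from the environment:

import os

# Option 1: via the environment, before the Tracer is constructed
os.environ["JUDGMENT_SHOW_TRACE_URLS"] = "false"

# Option 2: explicitly at construction time (import path assumed)
from judgeval.common.tracer import Tracer

tracer = Tracer(show_trace_urls=False)  # suppresses the "View Trace" link on save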
@@ -1731,6 +1734,93 @@ class Tracer:
                 f"Error during background service shutdown: {e}"
             )

+    def trace_to_message_history(
+        self, trace: Union[Trace, TraceClient]
+    ) -> List[Dict[str, str]]:
+        """
+        Extract message history from a trace for training purposes.
+
+        This method processes trace spans to reconstruct the conversation flow,
+        extracting messages in chronological order from LLM, user, and tool spans.
+
+        Args:
+            trace: Trace or TraceClient instance to extract messages from
+
+        Returns:
+            List of message dictionaries with 'role' and 'content' keys
+
+        Raises:
+            ValueError: If no trace is provided
+        """
+        if not trace:
+            raise ValueError("No trace provided")
+
+        # Handle both Trace and TraceClient objects
+        if isinstance(trace, TraceClient):
+            spans = trace.trace_spans
+        else:
+            spans = trace.trace_spans if hasattr(trace, "trace_spans") else []
+
+        messages = []
+        first_found = False
+
+        # Process spans in chronological order
+        for span in sorted(
+            spans, key=lambda s: s.created_at if hasattr(s, "created_at") else 0
+        ):
+            # Skip spans without output (except for first LLM span which may have input messages)
+            if span.output is None and span.span_type != "llm":
+                continue
+
+            if span.span_type == "llm":
+                # For the first LLM span, extract input messages (system + user prompts)
+                if not first_found and hasattr(span, "inputs") and span.inputs:
+                    input_messages = span.inputs.get("messages", [])
+                    if input_messages:
+                        first_found = True
+                        # Add input messages (typically system and user messages)
+                        for msg in input_messages:
+                            if (
+                                isinstance(msg, dict)
+                                and "role" in msg
+                                and "content" in msg
+                            ):
+                                messages.append(
+                                    {"role": msg["role"], "content": msg["content"]}
+                                )
+
+                # Add assistant response from span output
+                if span.output is not None:
+                    messages.append({"role": "assistant", "content": str(span.output)})
+
+            elif span.span_type == "user":
+                # Add user messages
+                if span.output is not None:
+                    messages.append({"role": "user", "content": str(span.output)})
+
+            elif span.span_type == "tool":
+                # Add tool responses as user messages (common pattern in training)
+                if span.output is not None:
+                    messages.append({"role": "user", "content": str(span.output)})
+
+        return messages
+
+    def get_current_message_history(self) -> List[Dict[str, str]]:
+        """
+        Get message history from the current trace.
+
+        Returns:
+            List of message dictionaries from the current trace context
+
+        Raises:
+            ValueError: If no current trace is found
+        """
+        current_trace = self.get_current_trace()
+        if not current_trace:
+            raise ValueError("No current trace found")
+
+        return self.trace_to_message_history(current_trace)
+

 def _get_current_trace(
     trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
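
The new methods flatten a trace into an OpenAI-style message list (dicts with "role" and "content" keys), mapping tool outputs to user messages. A usage sketch, assuming code running inside an active trace context:

# Inside a traced function (e.g. one decorated with @tracer.observe):
history = tracer.get_current_message_history()
for msg in history:
    print(msg["role"], "->", msg["content"][:80])

# Or from a Trace/TraceClient object you already hold (my_trace is hypothetical):
history = tracer.trace_to_message_history(my_trace)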
@@ -1746,7 +1836,7 @@ def wrap(
 ) -> Any:
     """
     Wraps an API client to add tracing capabilities.
-    Supports OpenAI, Together, Anthropic, and Google GenAI clients.
+    Supports OpenAI, Together, Anthropic, Google GenAI clients, and TrainableModel.
     Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
     """
     (
@@ -1871,6 +1961,39 @@ def wrap(
         setattr(client.chat.completions, "create", wrapped(original_create))
     elif isinstance(client, (groq_AsyncGroq)):
         setattr(client.chat.completions, "create", wrapped_async(original_create))
+
+    # Check for TrainableModel from judgeval.common.trainer
+    try:
+        from judgeval.common.trainer import TrainableModel
+
+        if isinstance(client, TrainableModel):
+            # Define a wrapper function that can be reapplied to new model instances
+            def wrap_model_instance(model_instance):
+                """Wrap a model instance with tracing functionality"""
+                if hasattr(model_instance, "chat") and hasattr(
+                    model_instance.chat, "completions"
+                ):
+                    if hasattr(model_instance.chat.completions, "create"):
+                        setattr(
+                            model_instance.chat.completions,
+                            "create",
+                            wrapped(model_instance.chat.completions.create),
+                        )
+                    if hasattr(model_instance.chat.completions, "acreate"):
+                        setattr(
+                            model_instance.chat.completions,
+                            "acreate",
+                            wrapped_async(model_instance.chat.completions.acreate),
+                        )
+
+            # Register the wrapper function with the TrainableModel
+            client._register_tracer_wrapper(wrap_model_instance)
+
+            # Apply wrapping to the current model
+            wrap_model_instance(client._current_model)
+    except ImportError:
+        pass  # TrainableModel not available
+
     return client

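wrap() now also recognizes TrainableModel: it registers a re-wrap callback via _register_tracer_wrapper so tracing survives the model swaps that happen between training steps, and it patches the current model's create/acreate. A sketch of the intended call pattern, using the from_model_config constructor documented in the ModelConfig docstring further below (wrap import path assumed):

from judgeval.common.tracer import wrap
from judgeval.common.trainer import TrainableModel, ModelConfig

config = ModelConfig.load_from_file("my_trained_model.json")
model = wrap(TrainableModel.from_model_config(config))  # traced like any other client

response = model.chat.completions.create(
    model="current",  # resolves to the currently loaded checkpoint
    messages=[{"role": "user", "content": "Hello!"}],
)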
@@ -1977,6 +2100,22 @@ def _get_client_config(
         return "GROQ_API_CALL", client.chat.completions.create, None, None, None
     elif isinstance(client, (groq_AsyncGroq)):
         return "GROQ_API_CALL", client.chat.completions.create, None, None, None
+
+    # Check for TrainableModel
+    try:
+        from judgeval.common.trainer import TrainableModel
+
+        if isinstance(client, TrainableModel):
+            return (
+                "FIREWORKS_TRAINABLE_MODEL_CALL",
+                client._current_model.chat.completions.create,
+                None,
+                None,
+                None,
+            )
+    except ImportError:
+        pass  # TrainableModel not available
+
     raise ValueError(f"Unsupported client type: {type(client)}")

@@ -2155,6 +2294,37 @@ def _format_output_data(
             cache_creation_input_tokens,
         )

+    # Check for TrainableModel
+    try:
+        from judgeval.common.trainer import TrainableModel
+
+        if isinstance(client, TrainableModel):
+            # TrainableModel uses Fireworks LLM internally, so response format should be similar to OpenAI
+            if (
+                hasattr(response, "model")
+                and hasattr(response, "usage")
+                and hasattr(response, "choices")
+            ):
+                model_name = response.model
+                prompt_tokens = response.usage.prompt_tokens if response.usage else 0
+                completion_tokens = (
+                    response.usage.completion_tokens if response.usage else 0
+                )
+                message_content = response.choices[0].message.content
+
+                # Use LiteLLM cost calculation with fireworks_ai prefix
+                # LiteLLM supports Fireworks AI models for cost calculation when prefixed with "fireworks_ai/"
+                fireworks_model_name = f"fireworks_ai/{model_name}"
+                return message_content, _create_usage(
+                    fireworks_model_name,
+                    prompt_tokens,
+                    completion_tokens,
+                    cache_read_input_tokens,
+                    cache_creation_input_tokens,
+                )
+    except ImportError:
+        pass  # TrainableModel not available
+
     judgeval_logger.warning(f"Unsupported client type: {type(client)}")
     return None, None

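The cost path leans on LiteLLM's pricing table, which (as the diff's comment notes) keys Fireworks-hosted models under a fireworks_ai/ prefix. A standalone sketch of that lookup, assuming litellm is installed and knows the model's pricing; the full account path is illustrative:

import litellm

prompt_cost, completion_cost = litellm.cost_per_token(
    model="fireworks_ai/accounts/fireworks/models/qwen2p5-7b-instruct",
    prompt_tokens=120,
    completion_tokens=40,
)
print(f"estimated cost: ${prompt_cost + completion_cost:.6f}")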
@@ -71,7 +71,12 @@ class TraceManagerClient:

         server_response = self.api_client.upsert_trace(trace_data)

-        if not offline_mode and show_link and "ui_results_url" in server_response:
+        if (
+            not offline_mode
+            and show_link
+            and "ui_results_url" in server_response
+            and self.tracer.show_trace_urls
+        ):
             pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
             rprint(pretty_str)

@@ -0,0 +1,5 @@
+from .trainer import JudgmentTrainer
+from .config import TrainerConfig, ModelConfig
+from .trainable_model import TrainableModel
+
+__all__ = ["JudgmentTrainer", "TrainerConfig", "ModelConfig", "TrainableModel"]
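
The new package __init__ makes the trainer's public surface importable in one line:

from judgeval.common.trainer import JudgmentTrainer, TrainerConfig, ModelConfig, TrainableModel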
@@ -0,0 +1,125 @@
+from dataclasses import dataclass
+from typing import Optional, Dict, Any
+import json
+
+
+@dataclass
+class TrainerConfig:
+    """Configuration class for JudgmentTrainer parameters."""
+
+    deployment_id: str
+    user_id: str
+    model_id: str
+    base_model_name: str = "qwen2p5-7b-instruct"
+    rft_provider: str = "fireworks"
+    num_steps: int = 5
+    num_generations_per_prompt: int = (
+        4  # Number of rollouts/generations per input prompt
+    )
+    num_prompts_per_step: int = 4  # Number of input prompts to sample per training step
+    concurrency: int = 100
+    epochs: int = 1
+    learning_rate: float = 1e-5
+    accelerator_count: int = 1
+    accelerator_type: str = "NVIDIA_A100_80GB"
+    temperature: float = 1.5
+    max_tokens: int = 50
+    enable_addons: bool = True
+
+
+@dataclass
+class ModelConfig:
+    """
+    Configuration class for storing and loading trained model state.
+
+    This class enables persistence of trained models so they can be loaded
+    and used later without retraining.
+
+    Example usage:
+        trainer = JudgmentTrainer(config)
+        model_config = trainer.train(agent_function, scorers, prompts)
+
+        # Save the trained model configuration
+        model_config.save_to_file("my_trained_model.json")
+
+        # Later, load and use the trained model
+        loaded_config = ModelConfig.load_from_file("my_trained_model.json")
+        trained_model = TrainableModel.from_model_config(loaded_config)
+
+        # Use the trained model for inference
+        response = trained_model.chat.completions.create(
+            model="current",  # Uses the loaded trained model
+            messages=[{"role": "user", "content": "Hello!"}]
+        )
+    """
+
+    # Base model configuration
+    base_model_name: str
+    deployment_id: str
+    user_id: str
+    model_id: str
+    enable_addons: bool
+
+    # Training state
+    current_step: int
+    total_steps: int
+
+    # Current model information
+    current_model_name: Optional[str] = None
+    is_trained: bool = False
+
+    # Training parameters used (for reference)
+    training_params: Optional[Dict[str, Any]] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert ModelConfig to dictionary for serialization."""
+        return {
+            "base_model_name": self.base_model_name,
+            "deployment_id": self.deployment_id,
+            "user_id": self.user_id,
+            "model_id": self.model_id,
+            "enable_addons": self.enable_addons,
+            "current_step": self.current_step,
+            "total_steps": self.total_steps,
+            "current_model_name": self.current_model_name,
+            "is_trained": self.is_trained,
+            "training_params": self.training_params,
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "ModelConfig":
+        """Create ModelConfig from dictionary."""
+        return cls(
+            base_model_name=data.get("base_model_name", "qwen2p5-7b-instruct"),
+            deployment_id=data.get("deployment_id", "my-base-deployment"),
+            user_id=data.get("user_id", ""),
+            model_id=data.get("model_id", ""),
+            enable_addons=data.get("enable_addons", True),
+            current_step=data.get("current_step", 0),
+            total_steps=data.get("total_steps", 0),
+            current_model_name=data.get("current_model_name"),
+            is_trained=data.get("is_trained", False),
+            training_params=data.get("training_params"),
+        )
+
+    def to_json(self) -> str:
+        """Convert ModelConfig to JSON string."""
+        return json.dumps(self.to_dict(), indent=2)
+
+    @classmethod
+    def from_json(cls, json_str: str) -> "ModelConfig":
+        """Create ModelConfig from JSON string."""
+        data = json.loads(json_str)
+        return cls.from_dict(data)
+
+    def save_to_file(self, filepath: str):
+        """Save ModelConfig to a JSON file."""
+        with open(filepath, "w") as f:
+            f.write(self.to_json())
+
+    @classmethod
+    def load_from_file(cls, filepath: str) -> "ModelConfig":
+        """Load ModelConfig from a JSON file."""
+        with open(filepath, "r") as f:
+            json_str = f.read()
+        return cls.from_json(json_str)
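
Because ModelConfig is a plain dataclass with JSON round-tripping, persisting a trained model's state reduces to the sketch below; the field values are illustrative:

from judgeval.common.trainer import ModelConfig

config = ModelConfig(
    base_model_name="qwen2p5-7b-instruct",
    deployment_id="my-base-deployment",
    user_id="user-123",
    model_id="model-abc",
    enable_addons=True,
    current_step=5,
    total_steps=5,
    current_model_name="model-abc-step-5",
    is_trained=True,
)
config.save_to_file("my_trained_model.json")

# Dataclass equality makes the round trip easy to verify
assert ModelConfig.load_from_file("my_trained_model.json") == config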
@@ -0,0 +1,151 @@
+from contextlib import contextmanager
+from typing import Optional
+import sys
+import os
+
+
+# Detect if we're running in a Jupyter environment
+def _is_jupyter_environment():
+    """Check if we're running in a Jupyter notebook or similar environment."""
+    try:
+        # Check for IPython kernel
+        if "ipykernel" in sys.modules or "IPython" in sys.modules:
+            return True
+        # Check for Jupyter environment variables
+        if "JPY_PARENT_PID" in os.environ:
+            return True
+        # Check if we're in Google Colab
+        if "google.colab" in sys.modules:
+            return True
+        return False
+    except Exception:
+        return False
+
+
+# Check environment once at import time
+IS_JUPYTER = _is_jupyter_environment()
+
+if not IS_JUPYTER:
+    # Safe to use Rich in non-Jupyter environments
+    try:
+        from rich.console import Console
+        from rich.spinner import Spinner
+        from rich.live import Live
+        from rich.text import Text
+
+        # Shared console instance for the trainer module to avoid conflicts
+        shared_console = Console()
+        RICH_AVAILABLE = True
+    except ImportError:
+        RICH_AVAILABLE = False
+else:
+    # In Jupyter, avoid Rich to prevent recursion issues
+    RICH_AVAILABLE = False
+
+
+# Fallback implementations for when Rich is not available or safe
+class SimpleSpinner:
+    def __init__(self, name, text):
+        self.text = text
+
+
+class SimpleLive:
+    def __init__(self, spinner, console=None, refresh_per_second=None):
+        self.spinner = spinner
+
+    def __enter__(self):
+        print(f"🔄 {self.spinner.text}")
+        return self
+
+    def __exit__(self, *args):
+        pass
+
+    def update(self, spinner):
+        print(f"🔄 {spinner.text}")
+
+
+def safe_print(message, style=None):
+    """Safe print function that works in all environments."""
+    if RICH_AVAILABLE and not IS_JUPYTER:
+        shared_console.print(message, style=style)
+    else:
+        # Use simple print with emoji indicators for different styles
+        if style == "green":
+            print(f"✅ {message}")
+        elif style == "yellow":
+            print(f"⚠️ {message}")
+        elif style == "blue":
+            print(f"🔵 {message}")
+        elif style == "cyan":
+            print(f"🔷 {message}")
+        else:
+            print(message)
+
+
+@contextmanager
+def _spinner_progress(
+    message: str, step: Optional[int] = None, total_steps: Optional[int] = None
+):
+    """Context manager for spinner-based progress display."""
+    if step is not None and total_steps is not None:
+        full_message = f"[Step {step}/{total_steps}] {message}"
+    else:
+        full_message = f"[Training] {message}"
+
+    if RICH_AVAILABLE and not IS_JUPYTER:
+        spinner = Spinner("dots", text=Text(full_message, style="cyan"))
+        with Live(spinner, console=shared_console, refresh_per_second=10):
+            yield
+    else:
+        # Fallback for Jupyter or when Rich is not available
+        print(f"🔄 {full_message}")
+        try:
+            yield
+        finally:
+            print(f"✅ {full_message} - Complete")
+
+
+@contextmanager
+def _model_spinner_progress(message: str):
+    """Context manager for model operation spinner-based progress display."""
+    if RICH_AVAILABLE and not IS_JUPYTER:
+        spinner = Spinner("dots", text=Text(f"[Model] {message}", style="blue"))
+        with Live(spinner, console=shared_console, refresh_per_second=10) as live:
+
+            def update_progress(progress_message: str):
+                """Update the spinner with a new progress message."""
+                new_text = f"[Model] {message}\n   └─ {progress_message}"
+                spinner.text = Text(new_text, style="blue")
+                live.update(spinner)
+
+            yield update_progress
+    else:
+        # Fallback for Jupyter or when Rich is not available
+        print(f"🔵 [Model] {message}")
+
+        def update_progress(progress_message: str):
+            print(f"   └─ {progress_message}")
+
+        yield update_progress
+
+
+def _print_progress(
+    message: str, step: Optional[int] = None, total_steps: Optional[int] = None
+):
+    """Print progress message with consistent formatting."""
+    if step is not None and total_steps is not None:
+        safe_print(f"[Step {step}/{total_steps}] {message}", style="green")
+    else:
+        safe_print(f"[Training] {message}", style="green")
+
+
+def _print_progress_update(
+    message: str, step: Optional[int] = None, total_steps: Optional[int] = None
+):
+    """Print progress update message (for status changes during long operations)."""
+    safe_print(f" └─ {message}", style="yellow")
+
+
+def _print_model_progress(message: str):
+    """Print model progress message with consistent formatting."""
+    safe_print(f"[Model] {message}", style="blue")
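
These helpers use Rich spinners on a plain terminal and fall back to print() under Jupyter/Colab, where Live displays can trigger recursion issues. In-module usage looks roughly like this (the new file's path is not shown in the diff, so no import line is given; do_work is a placeholder):

_print_progress("Sampling rollouts", step=2, total_steps=5)

with _spinner_progress("Scoring generations", step=2, total_steps=5):
    do_work()

with _model_spinner_progress("Deploying fine-tuned model") as update:
    update("Waiting for deployment to become ready")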