judgeval 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/__init__.py +139 -12
- judgeval/api/__init__.py +501 -0
- judgeval/api/api_types.py +344 -0
- judgeval/cli.py +2 -4
- judgeval/constants.py +10 -26
- judgeval/data/evaluation_run.py +49 -26
- judgeval/data/example.py +2 -2
- judgeval/data/judgment_types.py +266 -82
- judgeval/data/result.py +4 -5
- judgeval/data/scorer_data.py +4 -2
- judgeval/data/tool.py +2 -2
- judgeval/data/trace.py +7 -50
- judgeval/data/trace_run.py +7 -4
- judgeval/{dataset.py → dataset/__init__.py} +43 -28
- judgeval/env.py +67 -0
- judgeval/{run_evaluation.py → evaluation/__init__.py} +29 -95
- judgeval/exceptions.py +27 -0
- judgeval/integrations/langgraph/__init__.py +788 -0
- judgeval/judges/__init__.py +2 -2
- judgeval/judges/litellm_judge.py +75 -15
- judgeval/judges/together_judge.py +86 -18
- judgeval/judges/utils.py +7 -21
- judgeval/{common/logger.py → logger.py} +8 -6
- judgeval/scorers/__init__.py +0 -4
- judgeval/scorers/agent_scorer.py +3 -7
- judgeval/scorers/api_scorer.py +8 -13
- judgeval/scorers/base_scorer.py +52 -32
- judgeval/scorers/example_scorer.py +1 -3
- judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +0 -14
- judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +45 -20
- judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +2 -2
- judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +3 -3
- judgeval/scorers/score.py +21 -31
- judgeval/scorers/trace_api_scorer.py +5 -0
- judgeval/scorers/utils.py +1 -103
- judgeval/tracer/__init__.py +1075 -2
- judgeval/tracer/constants.py +1 -0
- judgeval/tracer/exporters/__init__.py +37 -0
- judgeval/tracer/exporters/s3.py +119 -0
- judgeval/tracer/exporters/store.py +43 -0
- judgeval/tracer/exporters/utils.py +32 -0
- judgeval/tracer/keys.py +67 -0
- judgeval/tracer/llm/__init__.py +1233 -0
- judgeval/{common/tracer → tracer/llm}/providers.py +5 -10
- judgeval/{local_eval_queue.py → tracer/local_eval_queue.py} +15 -10
- judgeval/tracer/managers.py +188 -0
- judgeval/tracer/processors/__init__.py +181 -0
- judgeval/tracer/utils.py +20 -0
- judgeval/trainer/__init__.py +5 -0
- judgeval/{common/trainer → trainer}/config.py +12 -9
- judgeval/{common/trainer → trainer}/console.py +2 -9
- judgeval/{common/trainer → trainer}/trainable_model.py +12 -7
- judgeval/{common/trainer → trainer}/trainer.py +119 -17
- judgeval/utils/async_utils.py +2 -3
- judgeval/utils/decorators.py +24 -0
- judgeval/utils/file_utils.py +37 -4
- judgeval/utils/guards.py +32 -0
- judgeval/utils/meta.py +14 -0
- judgeval/{common/api/json_encoder.py → utils/serialize.py} +7 -1
- judgeval/utils/testing.py +88 -0
- judgeval/utils/url.py +10 -0
- judgeval/{version_check.py → utils/version_check.py} +3 -3
- judgeval/version.py +5 -0
- judgeval/warnings.py +4 -0
- {judgeval-0.8.0.dist-info → judgeval-0.9.0.dist-info}/METADATA +12 -14
- judgeval-0.9.0.dist-info/RECORD +80 -0
- judgeval/clients.py +0 -35
- judgeval/common/__init__.py +0 -13
- judgeval/common/api/__init__.py +0 -3
- judgeval/common/api/api.py +0 -375
- judgeval/common/api/constants.py +0 -186
- judgeval/common/exceptions.py +0 -27
- judgeval/common/storage/__init__.py +0 -6
- judgeval/common/storage/s3_storage.py +0 -97
- judgeval/common/tracer/__init__.py +0 -31
- judgeval/common/tracer/constants.py +0 -22
- judgeval/common/tracer/core.py +0 -2427
- judgeval/common/tracer/otel_exporter.py +0 -108
- judgeval/common/tracer/otel_span_processor.py +0 -188
- judgeval/common/tracer/span_processor.py +0 -37
- judgeval/common/tracer/span_transformer.py +0 -207
- judgeval/common/tracer/trace_manager.py +0 -101
- judgeval/common/trainer/__init__.py +0 -5
- judgeval/common/utils.py +0 -948
- judgeval/integrations/langgraph.py +0 -844
- judgeval/judges/mixture_of_judges.py +0 -287
- judgeval/judgment_client.py +0 -267
- judgeval/rules.py +0 -521
- judgeval/scorers/judgeval_scorers/api_scorers/execution_order.py +0 -52
- judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +0 -28
- judgeval/utils/alerts.py +0 -93
- judgeval/utils/requests.py +0 -50
- judgeval-0.8.0.dist-info/RECORD +0 -82
- {judgeval-0.8.0.dist-info → judgeval-0.9.0.dist-info}/WHEEL +0 -0
- {judgeval-0.8.0.dist-info → judgeval-0.9.0.dist-info}/entry_points.txt +0 -0
- {judgeval-0.8.0.dist-info → judgeval-0.9.0.dist-info}/licenses/LICENSE.md +0 -0
@@ -1,12 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
import logging
|
3
2
|
from typing import Any, TypeAlias
|
4
3
|
|
5
4
|
|
6
|
-
logger = logging.getLogger(__name__)
|
7
|
-
# TODO: Have functions that assert and return the relevant exports when the client is installed.
|
8
|
-
# The method should raise if the user tries to access client information that doesnt exist.
|
9
|
-
|
10
5
|
HAS_OPENAI = False
|
11
6
|
openai_OpenAI = None
|
12
7
|
openai_AsyncOpenAI = None
|
@@ -35,7 +30,7 @@ together_Together = None
|
|
35
30
|
together_AsyncTogether = None
|
36
31
|
|
37
32
|
try:
|
38
|
-
from together import Together, AsyncTogether
|
33
|
+
from together import Together, AsyncTogether # type: ignore[import-untyped]
|
39
34
|
|
40
35
|
together_Together = Together
|
41
36
|
together_AsyncTogether = AsyncTogether
|
@@ -49,7 +44,7 @@ anthropic_Anthropic = None
|
|
49
44
|
anthropic_AsyncAnthropic = None
|
50
45
|
|
51
46
|
try:
|
52
|
-
from anthropic import Anthropic, AsyncAnthropic
|
47
|
+
from anthropic import Anthropic, AsyncAnthropic # type: ignore[import-untyped]
|
53
48
|
|
54
49
|
anthropic_Anthropic = Anthropic
|
55
50
|
anthropic_AsyncAnthropic = AsyncAnthropic
|
@@ -63,8 +58,8 @@ google_genai_Client = None
|
|
63
58
|
google_genai_cleint_AsyncClient = None
|
64
59
|
|
65
60
|
try:
|
66
|
-
from google.genai import Client
|
67
|
-
from google.genai.client import AsyncClient
|
61
|
+
from google.genai import Client # type: ignore[import-untyped]
|
62
|
+
from google.genai.client import AsyncClient # type: ignore[import-untyped]
|
68
63
|
|
69
64
|
google_genai_Client = Client
|
70
65
|
google_genai_AsyncClient = AsyncClient
|
@@ -78,7 +73,7 @@ groq_Groq = None
|
|
78
73
|
groq_AsyncGroq = None
|
79
74
|
|
80
75
|
try:
|
81
|
-
from groq import Groq, AsyncGroq
|
76
|
+
from groq import Groq, AsyncGroq # type: ignore[import-untyped]
|
82
77
|
|
83
78
|
groq_Groq = Groq
|
84
79
|
groq_AsyncGroq = AsyncGroq
|
@@ -10,12 +10,14 @@ import threading
|
|
10
10
|
from typing import Callable, List, Optional
|
11
11
|
import time
|
12
12
|
|
13
|
-
from judgeval.
|
14
|
-
from judgeval.
|
13
|
+
from judgeval.logger import judgeval_logger
|
14
|
+
from judgeval.env import JUDGMENT_MAX_CONCURRENT_EVALUATIONS
|
15
15
|
from judgeval.data import ScoringResult
|
16
16
|
from judgeval.data.evaluation_run import EvaluationRun
|
17
17
|
from judgeval.utils.async_utils import safe_run_async
|
18
18
|
from judgeval.scorers.score import a_execute_scoring
|
19
|
+
from judgeval.api import JudgmentSyncClient
|
20
|
+
from judgeval.env import JUDGMENT_API_KEY, JUDGMENT_ORG_ID
|
19
21
|
|
20
22
|
|
21
23
|
class LocalEvaluationQueue:
|
@@ -26,7 +28,9 @@ class LocalEvaluationQueue:
|
|
26
28
|
"""
|
27
29
|
|
28
30
|
def __init__(
|
29
|
-
self,
|
31
|
+
self,
|
32
|
+
max_concurrent: int = JUDGMENT_MAX_CONCURRENT_EVALUATIONS,
|
33
|
+
num_workers: int = 4,
|
30
34
|
):
|
31
35
|
if num_workers <= 0:
|
32
36
|
raise ValueError("num_workers must be a positive integer.")
|
@@ -35,6 +39,10 @@ class LocalEvaluationQueue:
|
|
35
39
|
self._num_workers = num_workers # Number of worker threads
|
36
40
|
self._worker_threads: List[threading.Thread] = []
|
37
41
|
self._shutdown_event = threading.Event()
|
42
|
+
self._api_client = JudgmentSyncClient(
|
43
|
+
api_key=JUDGMENT_API_KEY,
|
44
|
+
organization_id=JUDGMENT_ORG_ID,
|
45
|
+
)
|
38
46
|
|
39
47
|
def enqueue(self, evaluation_run: EvaluationRun) -> None:
|
40
48
|
"""Add evaluation run to the queue."""
|
@@ -81,13 +89,8 @@ class LocalEvaluationQueue:
|
|
81
89
|
|
82
90
|
def start_workers(
|
83
91
|
self,
|
84
|
-
callback: Optional[Callable[[EvaluationRun, List[ScoringResult]], None]] = None,
|
85
92
|
) -> List[threading.Thread]:
|
86
93
|
"""Start multiple background threads to process runs in parallel.
|
87
|
-
|
88
|
-
Args:
|
89
|
-
callback: Optional function called after each run with (run, results).
|
90
|
-
|
91
94
|
Returns:
|
92
95
|
List of started worker threads.
|
93
96
|
"""
|
@@ -105,8 +108,10 @@ class LocalEvaluationQueue:
|
|
105
108
|
|
106
109
|
try:
|
107
110
|
results = self._process_run(run)
|
108
|
-
|
109
|
-
|
111
|
+
results_dict = [result.model_dump() for result in results]
|
112
|
+
self._api_client.log_eval_results(
|
113
|
+
payload={"results": results_dict, "run": run.model_dump()}
|
114
|
+
)
|
110
115
|
except Exception as exc:
|
111
116
|
judgeval_logger.error(
|
112
117
|
f"Worker {worker_id} error processing {run.eval_name}: {exc}"
|
@@ -0,0 +1,188 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from contextlib import asynccontextmanager, contextmanager
|
4
|
+
from typing import TYPE_CHECKING, Dict, Optional, List, Any
|
5
|
+
from judgeval.tracer.keys import AttributeKeys, InternalAttributeKeys
|
6
|
+
import uuid
|
7
|
+
from judgeval.exceptions import JudgmentRuntimeError
|
8
|
+
from judgeval.tracer.utils import set_span_attribute
|
9
|
+
|
10
|
+
if TYPE_CHECKING:
|
11
|
+
from judgeval.tracer import Tracer
|
12
|
+
|
13
|
+
|
14
|
+
@contextmanager
def sync_span_context(
    tracer: Tracer,
    name: str,
    span_attributes: Optional[Dict[str, str]] = None,
    disable_partial_emit: bool = False,
):
    """Open a span as the current span and give it its own cost accumulator.

    Args:
        tracer: Tracer supplying the cost context variable, the underlying
            tracer object and the judgment span processor.
        name: Name passed to ``start_as_current_span``.
        span_attributes: Attributes passed to ``start_as_current_span``;
            an empty dict is used when ``None``.
        disable_partial_emit: When true, the ``DISABLE_PARTIAL_EMIT``
            internal attribute is set to ``True`` for the new span on
            ``tracer.judgment_processor``.

    Yields:
        The span returned by ``start_as_current_span``.
    """
    attributes = {} if span_attributes is None else span_attributes

    # Install a fresh accumulator for the lifetime of this span; the token
    # lets the ``finally`` clause restore whatever was set before.
    cost_var = tracer.get_current_cost_context()
    span_costs = {"cumulative_cost": 0.0}
    restore_token = cost_var.set(span_costs)

    try:
        otel_tracer = tracer.get_tracer()
        with otel_tracer.start_as_current_span(
            name=name,
            attributes=attributes,
        ) as span:
            set_span_attribute(span, AttributeKeys.JUDGMENT_CUMULATIVE_LLM_COST, 0.0)
            if disable_partial_emit:
                tracer.judgment_processor.set_internal_attribute(
                    span_context=span.get_span_context(),
                    key=InternalAttributeKeys.DISABLE_PARTIAL_EMIT,
                    value=True,
                )
            yield span
    finally:
        # Restore the previous accumulator first, then hand this span's total
        # to the tracer so it is added to the now-current (outer) context.
        cost_var.reset(restore_token)
        span_total = float(span_costs.get("cumulative_cost", 0.0))
        tracer.add_cost_to_current_context(span_total)
|
47
|
+
|
48
|
+
|
49
|
+
@asynccontextmanager
async def async_span_context(
    tracer: Tracer,
    name: str,
    span_attributes: Optional[Dict[str, str]] = None,
    disable_partial_emit: bool = False,
):
    """Async context manager that opens a span with its own cost accumulator.

    Args:
        tracer: Tracer supplying the cost context variable, the underlying
            tracer object and the judgment span processor.
        name: Name passed to ``start_as_current_span``.
        span_attributes: Attributes passed to ``start_as_current_span``;
            an empty dict is used when ``None``.
        disable_partial_emit: When true, the ``DISABLE_PARTIAL_EMIT``
            internal attribute is set to ``True`` for the new span on
            ``tracer.judgment_processor``.

    Yields:
        The span returned by ``start_as_current_span``.
    """
    attributes = {} if span_attributes is None else span_attributes

    # Install a fresh accumulator for the lifetime of this span; the token
    # lets the ``finally`` clause restore whatever was set before.
    cost_var = tracer.get_current_cost_context()
    span_costs = {"cumulative_cost": 0.0}
    restore_token = cost_var.set(span_costs)

    try:
        otel_tracer = tracer.get_tracer()
        with otel_tracer.start_as_current_span(
            name=name,
            attributes=attributes,
        ) as span:
            set_span_attribute(span, AttributeKeys.JUDGMENT_CUMULATIVE_LLM_COST, 0.0)
            if disable_partial_emit:
                tracer.judgment_processor.set_internal_attribute(
                    span_context=span.get_span_context(),
                    key=InternalAttributeKeys.DISABLE_PARTIAL_EMIT,
                    value=True,
                )
            yield span
    finally:
        # Restore the previous accumulator first, then hand this span's total
        # to the tracer so it is added to the now-current (outer) context.
        cost_var.reset(restore_token)
        span_total = float(span_costs.get("cumulative_cost", 0.0))
        tracer.add_cost_to_current_context(span_total)
|
82
|
+
|
83
|
+
|
84
|
+
def create_agent_context(
    tracer: Tracer,
    args: tuple,
    class_name: Optional[str] = None,
    identifier: Optional[str] = None,
    track_state: bool = False,
    track_attributes: Optional[List[str]] = None,
    field_mappings: Optional[Dict[str, str]] = None,
):
    """Build a new agent context, install it on the tracer, and return the token.

    The context is a dict with the keys ``agent_id`` (a fresh UUID4 string),
    ``class_name``, ``track_state``, ``track_attributes``, ``field_mappings``,
    ``instance``, ``instance_name``, ``parent_agent_id`` and
    ``is_agent_entry_point``.

    Args:
        tracer: Object whose ``get_current_agent_context()`` returns the
            context variable holding the active agent context.
        args: Positional arguments of the decorated call; ``args[0]`` is
            taken as the instance when present.
        class_name: Name of the class owning the decorated method, if any.
        identifier: Name of a non-callable attribute on the instance whose
            string value becomes ``instance_name``.
        track_state: Stored verbatim in the context.
        track_attributes: Stored in the context; ``[]`` when falsy.
        field_mappings: Stored in the context; ``{}`` when falsy.

    Returns:
        The token produced by setting the context variable; pass it to the
        variable's ``reset`` to restore the previous agent context.

    Raises:
        JudgmentRuntimeError: If ``identifier`` is given without a class name
            or instance, or if it does not name a non-callable attribute of
            the instance.
    """
    instance = args[0] if args else None

    agent_context: Dict[str, Any] = {"agent_id": str(uuid.uuid4())}
    agent_context["class_name"] = class_name if class_name else None
    agent_context["track_state"] = track_state
    agent_context["track_attributes"] = track_attributes or []
    agent_context["field_mappings"] = field_mappings or {}
    agent_context["instance"] = instance
    agent_context["instance_name"] = None

    if identifier:
        # Compare against None rather than relying on truthiness: an instance
        # whose class defines __bool__/__len__ may be falsy yet perfectly valid.
        if not class_name or instance is None:
            raise JudgmentRuntimeError(
                "'identifier' is set but no class name or instance is available. 'identifier' can only be specified when using the agent() decorator on a class method."
            )

        # Read the attribute once; the sentinel distinguishes "missing" from
        # an attribute whose value happens to be None.
        missing = object()
        attr_value = getattr(instance, identifier, missing)
        if attr_value is missing or callable(attr_value):
            raise JudgmentRuntimeError(
                f"Attribute {identifier} does not exist for {class_name}. Check your agent() decorator."
            )
        agent_context["instance_name"] = str(attr_value)

    # Link to the enclosing agent, if one is currently active.
    current_agent_context = tracer.get_current_agent_context().get()
    if current_agent_context and "agent_id" in current_agent_context:
        agent_context["parent_agent_id"] = current_agent_context["agent_id"]
    else:
        agent_context["parent_agent_id"] = None

    agent_context["is_agent_entry_point"] = True
    token = tracer.get_current_agent_context().set(agent_context)  # type: ignore
    return token
|
137
|
+
|
138
|
+
|
139
|
+
@contextmanager
def sync_agent_context(
    tracer: Tracer,
    args: tuple,
    class_name: Optional[str] = None,
    identifier: Optional[str] = None,
    track_state: bool = False,
    track_attributes: Optional[List[str]] = None,
    field_mappings: Optional[Dict[str, str]] = None,
):
    """Install an agent context for the duration of a ``with`` block.

    All arguments are forwarded unchanged to ``create_agent_context``; the
    token it returns is used to reset the tracer's agent context variable on
    exit, whether or not the body raised. Nothing is yielded.
    """
    agent_options = {
        "class_name": class_name,
        "identifier": identifier,
        "track_state": track_state,
        "track_attributes": track_attributes,
        "field_mappings": field_mappings,
    }
    reset_token = create_agent_context(tracer=tracer, args=args, **agent_options)
    try:
        yield
    finally:
        tracer.get_current_agent_context().reset(reset_token)
|
163
|
+
|
164
|
+
|
165
|
+
@asynccontextmanager
async def async_agent_context(
    tracer: Tracer,
    args: tuple,
    class_name: Optional[str] = None,
    identifier: Optional[str] = None,
    track_state: bool = False,
    track_attributes: Optional[List[str]] = None,
    field_mappings: Optional[Dict[str, str]] = None,
):
    """Install an agent context for the duration of an ``async with`` block.

    All arguments are forwarded unchanged to ``create_agent_context``; the
    token it returns is used to reset the tracer's agent context variable on
    exit, whether or not the body raised. Nothing is yielded.
    """
    agent_options = {
        "class_name": class_name,
        "identifier": identifier,
        "track_state": track_state,
        "track_attributes": track_attributes,
        "field_mappings": field_mappings,
    }
    reset_token = create_agent_context(tracer=tracer, args=args, **agent_options)
    try:
        yield
    finally:
        tracer.get_current_agent_context().reset(reset_token)
|
@@ -0,0 +1,181 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
from typing import Optional, TYPE_CHECKING, Any
|
3
|
+
from collections import defaultdict
|
4
|
+
from opentelemetry.context import Context
|
5
|
+
from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor, SpanContext
|
6
|
+
from opentelemetry.sdk.trace.export import (
|
7
|
+
BatchSpanProcessor,
|
8
|
+
)
|
9
|
+
from judgeval.tracer.exporters import JudgmentSpanExporter
|
10
|
+
from judgeval.tracer.keys import AttributeKeys, InternalAttributeKeys
|
11
|
+
|
12
|
+
if TYPE_CHECKING:
|
13
|
+
from judgeval.tracer import Tracer
|
14
|
+
|
15
|
+
|
16
|
+
class NoOpSpanProcessor(SpanProcessor):
    """Span processor whose hooks all do nothing."""

    def on_start(self, span: Span, parent_context: Optional[Context] = None) -> None:
        """Ignore the started span."""
        return None

    def on_end(self, span: ReadableSpan) -> None:
        """Ignore the ended span."""
        return None

    def shutdown(self) -> None:
        """Do nothing on shutdown."""
        return None

    def force_flush(self, timeout_millis: int = 30000) -> bool:
        """Return ``True`` immediately; ``timeout_millis`` is not used."""
        return True
|
28
|
+
|
29
|
+
|
30
|
+
class JudgmentSpanProcessor(BatchSpanProcessor):
    """Batch span processor that exports through a ``JudgmentSpanExporter``.

    Besides the inherited batching, the processor keeps a private attribute
    store keyed by ``(trace_id, span_id)`` and can enqueue a snapshot of the
    tracer's current span before it ends (``emit_partial``). Each snapshot
    carries a ``JUDGMENT_UPDATE_ID`` attribute taken from a per-span counter.
    """

    def __init__(
        self,
        tracer: Tracer,
        endpoint: str,
        api_key: str,
        organization_id: str,
        /,
        *,
        max_queue_size: int = 2**18,
        export_timeout_millis: int = 30000,
    ):
        """Create the exporter and initialise the batch processor.

        Args:
            tracer: Tracer queried for the current span in ``emit_partial``.
            endpoint: Passed to ``JudgmentSpanExporter``.
            api_key: Passed to ``JudgmentSpanExporter``.
            organization_id: Passed to ``JudgmentSpanExporter``.
            max_queue_size: Forwarded to ``BatchSpanProcessor``.
            export_timeout_millis: Forwarded to ``BatchSpanProcessor``.
        """
        self.tracer = tracer
        super().__init__(
            JudgmentSpanExporter(
                endpoint=endpoint,
                api_key=api_key,
                organization_id=organization_id,
            ),
            max_queue_size=max_queue_size,
            export_timeout_millis=export_timeout_millis,
        )
        self._internal_attributes: defaultdict[tuple[int, int], dict[str, Any]] = (
            defaultdict(dict)
        )

    def _get_span_key(self, span_context: SpanContext) -> tuple[int, int]:
        """Return the ``(trace_id, span_id)`` key used by the attribute store."""
        return (span_context.trace_id, span_context.span_id)

    def set_internal_attribute(
        self, span_context: SpanContext, key: str, value: Any
    ) -> None:
        """Store ``value`` under ``key`` for the span identified by ``span_context``."""
        span_key = self._get_span_key(span_context)
        self._internal_attributes[span_key][key] = value

    def get_internal_attribute(
        self, span_context: SpanContext, key: str, default: Any = None
    ) -> Any:
        """Return the stored value for ``key``, or ``default`` when absent.

        The lookup goes through ``dict.get`` rather than indexing the
        ``defaultdict``: indexing would insert an empty entry for every span
        that is merely queried, and entries created after a span's state has
        been cleaned up would never be removed.
        """
        span_attributes = self._internal_attributes.get(
            self._get_span_key(span_context)
        )
        if span_attributes is None:
            return default
        return span_attributes.get(key, default)

    def increment_update_id(self, span_context: SpanContext) -> int:
        """Advance the span's update counter and return its value *before* the bump.

        The first call for a span therefore returns ``0``.
        """
        current_id = self.get_internal_attribute(
            span_context=span_context, key=AttributeKeys.JUDGMENT_UPDATE_ID, default=0
        )
        self.set_internal_attribute(
            span_context=span_context,
            key=AttributeKeys.JUDGMENT_UPDATE_ID,
            value=current_id + 1,
        )
        return current_id

    def _cleanup_span_state(self, span_key: tuple[int, int]) -> None:
        """Drop every internal attribute stored for ``span_key``."""
        self._internal_attributes.pop(span_key, None)

    @staticmethod
    def _copy_span(
        span: ReadableSpan,
        *,
        context: Any,
        attributes: dict[str, Any],
        end_time: Optional[int],
    ) -> ReadableSpan:
        """Build a ``ReadableSpan`` mirroring ``span`` with the given overrides."""
        return ReadableSpan(
            name=span.name,
            context=context,
            parent=span.parent,
            resource=span.resource,
            attributes=attributes,
            events=span.events,
            links=span.links,
            status=span.status,
            kind=span.kind,
            start_time=span.start_time,
            end_time=end_time,
            instrumentation_scope=span.instrumentation_scope,
        )

    def emit_partial(self) -> None:
        """Enqueue a snapshot of the tracer's current span without ending it.

        Does nothing when there is no recording current span, when that span
        is not a ``ReadableSpan``, or when ``DISABLE_PARTIAL_EMIT`` has been
        set for it.
        """
        current_span = self.tracer.get_current_span()
        if not current_span or not current_span.is_recording():
            return

        if not isinstance(current_span, ReadableSpan):
            return

        span_context = current_span.get_span_context()
        if self.get_internal_attribute(
            span_context=span_context,
            key=InternalAttributeKeys.DISABLE_PARTIAL_EMIT,
            default=False,
        ):
            return

        current_update_id = self.increment_update_id(span_context=span_context)

        attributes = dict(current_span.attributes or {})
        attributes[AttributeKeys.JUDGMENT_UPDATE_ID] = current_update_id
        partial_span = self._copy_span(
            current_span,
            context=span_context,
            attributes=attributes,
            end_time=None,
        )

        # Call the base implementation directly so the snapshot is queued
        # without going through this class's end-of-span handling.
        super().on_end(partial_span)

    def on_end(self, span: ReadableSpan) -> None:
        """Queue an ended span, dropping it if it was marked as cancelled.

        Ended spans are re-created with ``JUDGMENT_UPDATE_ID`` set to ``20``
        and their internal state is discarded.
        """
        if not span.context:
            super().on_end(span)
            return

        span_key = self._get_span_key(span.context)

        if self.get_internal_attribute(
            span.context, InternalAttributeKeys.CANCELLED, False
        ):
            self._cleanup_span_state(span_key)
            return

        if span.end_time is None:
            super().on_end(span)
            return

        attributes = dict(span.attributes or {})
        # NOTE(review): 20 is a hard-coded update id for the final span --
        # presumably meant to outrank the ids of earlier partial emits;
        # confirm the intended behaviour once a span emits 20+ partials.
        attributes[AttributeKeys.JUDGMENT_UPDATE_ID] = 20

        final_span = self._copy_span(
            span,
            context=span.context,
            attributes=attributes,
            end_time=span.end_time,
        )

        self._cleanup_span_state(span_key)
        super().on_end(final_span)
|
159
|
+
|
160
|
+
|
161
|
+
class NoOpJudgmentSpanProcessor(JudgmentSpanProcessor):
    """``JudgmentSpanProcessor`` variant whose span hooks all do nothing."""

    def __init__(self):
        # NOTE(review): the parent initialiser still runs, so an exporter and
        # the batch processor are constructed with placeholder arguments even
        # though every hook below is a no-op -- confirm this is intended.
        super().__init__(None, "", "", "")  # type: ignore[arg-type]

    def on_start(self, span: Span, parent_context: Optional[Context] = None) -> None:
        """Ignore the started span."""
        return None

    def on_end(self, span: ReadableSpan) -> None:
        """Ignore the ended span."""
        return None

    def shutdown(self) -> None:
        """Do nothing on shutdown."""
        return None

    def force_flush(self, timeout_millis: int | None = 30000) -> bool:
        """Return ``True`` immediately; ``timeout_millis`` is not used."""
        return True

    def emit_partial(self) -> None:
        """Do nothing; no partial span is emitted."""
        return None
|
179
|
+
|
180
|
+
|
181
|
+
__all__ = ("NoOpSpanProcessor", "JudgmentSpanProcessor", "NoOpJudgmentSpanProcessor")
|
judgeval/tracer/utils.py
ADDED
@@ -0,0 +1,20 @@
|
|
1
|
+
from typing import Any
|
2
|
+
from opentelemetry.trace import Span
|
3
|
+
from pydantic import BaseModel
|
4
|
+
from typing import Callable, Optional
|
5
|
+
from judgeval.scorers.trace_api_scorer import TraceAPIScorerConfig
|
6
|
+
from judgeval.env import JUDGMENT_DEFAULT_GPT_MODEL
|
7
|
+
|
8
|
+
|
9
|
+
def set_span_attribute(span: Span, name: str, value: Any):
    """Set ``name`` to ``value`` on ``span`` unless the value is ``None`` or ``""``.

    Other falsy values such as ``0``, ``0.0`` or ``False`` are still written.
    """
    is_blank = value is None or value == ""
    if not is_blank:
        span.set_attribute(name, value)
|
14
|
+
|
15
|
+
|
16
|
+
class TraceScorerConfig(BaseModel):
    """Pydantic model pairing a trace scorer with the options for running it."""

    # Scorer configuration to apply (required; no default).
    scorer: TraceAPIScorerConfig
    # Model name; defaults to JUDGMENT_DEFAULT_GPT_MODEL from judgeval.env.
    model: str = JUDGMENT_DEFAULT_GPT_MODEL
    # NOTE(review): presumably the fraction of traces to score, with 1.0
    # meaning all of them -- confirm against the code that consumes it.
    sampling_rate: float = 1.0
    # Optional predicate returning bool; how and when it is invoked is not
    # visible here.
    run_condition: Optional[Callable[..., bool]] = None
|
@@ -1,7 +1,12 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
from dataclasses import dataclass
|
2
|
-
from typing import Optional, Dict, Any
|
4
|
+
from typing import Optional, Dict, Any, TYPE_CHECKING
|
3
5
|
import json
|
4
6
|
|
7
|
+
if TYPE_CHECKING:
|
8
|
+
from fireworks.llm.llm_reinforcement_step import ReinforcementAcceleratorTypeLiteral
|
9
|
+
|
5
10
|
|
6
11
|
@dataclass
|
7
12
|
class TrainerConfig:
|
@@ -13,15 +18,13 @@ class TrainerConfig:
|
|
13
18
|
base_model_name: str = "qwen2p5-7b-instruct"
|
14
19
|
rft_provider: str = "fireworks"
|
15
20
|
num_steps: int = 5
|
16
|
-
num_generations_per_prompt: int =
|
17
|
-
|
18
|
-
)
|
19
|
-
num_prompts_per_step: int = 4 # Number of input prompts to sample per training step
|
21
|
+
num_generations_per_prompt: int = 4
|
22
|
+
num_prompts_per_step: int = 4
|
20
23
|
concurrency: int = 100
|
21
24
|
epochs: int = 1
|
22
25
|
learning_rate: float = 1e-5
|
23
26
|
accelerator_count: int = 1
|
24
|
-
accelerator_type:
|
27
|
+
accelerator_type: ReinforcementAcceleratorTypeLiteral = "NVIDIA_A100_80GB"
|
25
28
|
temperature: float = 1.5
|
26
29
|
max_tokens: int = 50
|
27
30
|
enable_addons: bool = True
|
@@ -87,7 +90,7 @@ class ModelConfig:
|
|
87
90
|
}
|
88
91
|
|
89
92
|
@classmethod
|
90
|
-
def from_dict(cls, data: Dict[str, Any]) ->
|
93
|
+
def from_dict(cls, data: Dict[str, Any]) -> ModelConfig:
|
91
94
|
"""Create ModelConfig from dictionary."""
|
92
95
|
return cls(
|
93
96
|
base_model_name=data.get("base_model_name", "qwen2p5-7b-instruct"),
|
@@ -107,7 +110,7 @@ class ModelConfig:
|
|
107
110
|
return json.dumps(self.to_dict(), indent=2)
|
108
111
|
|
109
112
|
@classmethod
|
110
|
-
def from_json(cls, json_str: str) ->
|
113
|
+
def from_json(cls, json_str: str) -> ModelConfig:
|
111
114
|
"""Create ModelConfig from JSON string."""
|
112
115
|
data = json.loads(json_str)
|
113
116
|
return cls.from_dict(data)
|
@@ -118,7 +121,7 @@ class ModelConfig:
|
|
118
121
|
f.write(self.to_json())
|
119
122
|
|
120
123
|
@classmethod
|
121
|
-
def load_from_file(cls, filepath: str) ->
|
124
|
+
def load_from_file(cls, filepath: str) -> ModelConfig:
|
122
125
|
"""Load ModelConfig from a JSON file."""
|
123
126
|
with open(filepath, "r") as f:
|
124
127
|
json_str = f.read()
|
@@ -2,9 +2,10 @@ from contextlib import contextmanager
|
|
2
2
|
from typing import Optional
|
3
3
|
import sys
|
4
4
|
import os
|
5
|
+
from judgeval.utils.decorators import use_once
|
5
6
|
|
6
7
|
|
7
|
-
|
8
|
+
@use_once
|
8
9
|
def _is_jupyter_environment():
|
9
10
|
"""Check if we're running in a Jupyter notebook or similar environment."""
|
10
11
|
try:
|
@@ -22,28 +23,23 @@ def _is_jupyter_environment():
|
|
22
23
|
return False
|
23
24
|
|
24
25
|
|
25
|
-
# Check environment once at import time
|
26
26
|
IS_JUPYTER = _is_jupyter_environment()
|
27
27
|
|
28
28
|
if not IS_JUPYTER:
|
29
|
-
# Safe to use Rich in non-Jupyter environments
|
30
29
|
try:
|
31
30
|
from rich.console import Console
|
32
31
|
from rich.spinner import Spinner
|
33
32
|
from rich.live import Live
|
34
33
|
from rich.text import Text
|
35
34
|
|
36
|
-
# Shared console instance for the trainer module to avoid conflicts
|
37
35
|
shared_console = Console()
|
38
36
|
RICH_AVAILABLE = True
|
39
37
|
except ImportError:
|
40
38
|
RICH_AVAILABLE = False
|
41
39
|
else:
|
42
|
-
# In Jupyter, avoid Rich to prevent recursion issues
|
43
40
|
RICH_AVAILABLE = False
|
44
41
|
|
45
42
|
|
46
|
-
# Fallback implementations for when Rich is not available or safe
|
47
43
|
class SimpleSpinner:
|
48
44
|
def __init__(self, name, text):
|
49
45
|
self.text = text
|
@@ -69,7 +65,6 @@ def safe_print(message, style=None):
|
|
69
65
|
if RICH_AVAILABLE and not IS_JUPYTER:
|
70
66
|
shared_console.print(message, style=style)
|
71
67
|
else:
|
72
|
-
# Use simple print with emoji indicators for different styles
|
73
68
|
if style == "green":
|
74
69
|
print(f"✅ {message}")
|
75
70
|
elif style == "yellow":
|
@@ -97,7 +92,6 @@ def _spinner_progress(
|
|
97
92
|
with Live(spinner, console=shared_console, refresh_per_second=10):
|
98
93
|
yield
|
99
94
|
else:
|
100
|
-
# Fallback for Jupyter or when Rich is not available
|
101
95
|
print(f"🔄 {full_message}")
|
102
96
|
try:
|
103
97
|
yield
|
@@ -120,7 +114,6 @@ def _model_spinner_progress(message: str):
|
|
120
114
|
|
121
115
|
yield update_progress
|
122
116
|
else:
|
123
|
-
# Fallback for Jupyter or when Rich is not available
|
124
117
|
print(f"🔵 [Model] {message}")
|
125
118
|
|
126
119
|
def update_progress(progress_message: str):
|