langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512150805__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to a supported registry, and is provided for informational purposes only.
- langfun/__init__.py +1 -1
- langfun/core/__init__.py +7 -1
- langfun/core/agentic/__init__.py +8 -1
- langfun/core/agentic/action.py +740 -112
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +189 -24
- langfun/core/async_support.py +104 -5
- langfun/core/async_support_test.py +23 -0
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +9 -2
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +11 -2
- langfun/core/data/conversion/gemini_test.py +48 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +48 -44
- langfun/core/eval/base_test.py +5 -5
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +3 -0
- langfun/core/eval/v2/checkpointing.py +148 -46
- langfun/core/eval/v2/checkpointing_test.py +9 -2
- langfun/core/eval/v2/config_saver.py +37 -0
- langfun/core/eval/v2/config_saver_test.py +36 -0
- langfun/core/eval/v2/eval_test_helper.py +104 -3
- langfun/core/eval/v2/evaluation.py +102 -19
- langfun/core/eval/v2/evaluation_test.py +9 -3
- langfun/core/eval/v2/example.py +50 -40
- langfun/core/eval/v2/example_test.py +16 -8
- langfun/core/eval/v2/experiment.py +95 -20
- langfun/core/eval/v2/experiment_test.py +19 -0
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +157 -44
- langfun/core/eval/v2/metrics_test.py +39 -18
- langfun/core/eval/v2/progress.py +31 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking.py +13 -5
- langfun/core/eval/v2/progress_tracking_test.py +9 -1
- langfun/core/eval/v2/reporting.py +88 -71
- langfun/core/eval/v2/reporting_test.py +24 -6
- langfun/core/eval/v2/runners/__init__.py +30 -0
- langfun/core/eval/v2/{runners.py → runners/base.py} +73 -180
- langfun/core/eval/v2/runners/beam.py +354 -0
- langfun/core/eval/v2/runners/beam_test.py +153 -0
- langfun/core/eval/v2/runners/ckpt_monitor.py +350 -0
- langfun/core/eval/v2/runners/ckpt_monitor_test.py +213 -0
- langfun/core/eval/v2/runners/debug.py +40 -0
- langfun/core/eval/v2/runners/debug_test.py +76 -0
- langfun/core/eval/v2/runners/parallel.py +243 -0
- langfun/core/eval/v2/runners/parallel_test.py +182 -0
- langfun/core/eval/v2/runners/sequential.py +47 -0
- langfun/core/eval/v2/runners/sequential_test.py +169 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +7 -5
- langfun/core/language_model.py +189 -36
- langfun/core/language_model_test.py +54 -3
- langfun/core/llms/__init__.py +14 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +32 -1
- langfun/core/llms/gemini.py +90 -12
- langfun/core/llms/gemini_test.py +110 -0
- langfun/core/llms/google_genai.py +52 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +120 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +16 -1
- langfun/core/llms/vertexai.py +78 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/__init__.py +10 -0
- langfun/core/mcp/client.py +177 -0
- langfun/core/mcp/client_test.py +71 -0
- langfun/core/mcp/session.py +241 -0
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/testing/simple_mcp_client.py +33 -0
- langfun/core/mcp/testing/simple_mcp_server.py +33 -0
- langfun/core/mcp/tool.py +254 -0
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +73 -3
- langfun/core/modalities/image_test.py +116 -0
- langfun/core/modalities/mime.py +78 -4
- langfun/core/modalities/mime_test.py +59 -0
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/__init__.py +2 -24
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +81 -37
- langfun/core/structured/parsing.py +95 -79
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +230 -154
- langfun/core/structured/querying_test.py +69 -33
- langfun/core/structured/schema/__init__.py +49 -0
- langfun/core/structured/schema/base.py +664 -0
- langfun/core/structured/schema/base_test.py +531 -0
- langfun/core/structured/schema/json.py +174 -0
- langfun/core/structured/schema/json_test.py +121 -0
- langfun/core/structured/schema/python.py +316 -0
- langfun/core/structured/schema/python_test.py +410 -0
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +47 -36
- langfun/core/structured/tokenization.py +26 -11
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +175 -50
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +43 -0
- langfun/env/base_environment.py +827 -0
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +304 -0
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +842 -0
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +14 -0
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +472 -0
- langfun/env/event_handlers/event_logger_test.py +304 -0
- langfun/env/event_handlers/metric_writer.py +726 -0
- langfun/env/event_handlers/metric_writer_test.py +214 -0
- langfun/env/interface.py +1640 -0
- langfun/env/interface_test.py +153 -0
- langfun/env/load_balancers.py +59 -0
- langfun/env/load_balancers_test.py +141 -0
- langfun/env/test_utils.py +507 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/METADATA +7 -3
- langfun-0.1.2.dev202512150805.dist-info/RECORD +217 -0
- langfun/core/eval/v2/runners_test.py +0 -343
- langfun/core/structured/schema.py +0 -987
- langfun/core/structured/schema_test.py +0 -982
- langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/top_level.txt +0 -0
langfun/core/eval/v2/metric_values.py
CHANGED

@@ -20,7 +20,15 @@ import pyglove as pg


 class MetricValue(pg.Object):
-  """Base class for metric values.
+  """Base class for metric values.
+
+  `MetricValue` is the base class for representing aggregated metric values
+  in an evaluation. It accumulates data points from individual examples,
+  each consisting of a value and an optional weight, associated with an example
+  ID. Subclasses must implement `reduce` method to compute a single float value
+  from accumulated data points, and `scalar_repr` to provide a string
+  representation of the reduced value.
+  """

   class DataPoint(pg.Object):
     """A data point for a metric value."""

@@ -88,6 +96,14 @@ class MetricValue(pg.Object):
     self.increment_total()
     return self

+  def merge_from(self, other: 'MetricValue') -> 'MetricValue':
+    """Merges the values from another metric value."""
+    self._weighted_sum += other._weighted_sum  # pylint: disable=protected-access
+    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
+      self.data_points.extend(other.data_points)
+    self.increment_total(other.total)
+    return self
+
   def __gt__(self, other: Union['MetricValue', float]) -> bool:
     if isinstance(other, self.__class__):
       return float(self) > float(other)

@@ -133,7 +149,13 @@ class MetricValue(pg.Object):


 class Rate(MetricValue):
-  """
+  """Metric value representing a rate in range [0, 1].
+
+  `Rate` is used for metrics that compute a rate, such as accuracy or error
+  rate. The final value is computed as the weighted sum of accumulated values
+  divided by the total number of examples. It's displayed as a percentage
+  (e.g., 90.0%).
+  """

   def reduce(self) -> float:
     return self._weighted_sum / self.total

@@ -145,7 +167,13 @@ class Rate(MetricValue):


 class Average(MetricValue):
-  """
+  """Metric value representing an average of accumulated values.
+
+  `Average` is used for metrics that compute an average score across examples
+  (e.g., average quality score). The final value is computed as the weighted
+  sum of accumulated values divided by the number of data points.
+  It's displayed as a float with 3 decimal places (e.g., 4.750).
+  """

   def reduce(self) -> float:
     if not self.data_points:

langfun/core/eval/v2/metric_values_test.py
CHANGED

@@ -51,6 +51,22 @@ class RateTest(unittest.TestCase):
     self.assertEqual(rate.total, 0)
     self.assertTrue(math.isnan(float(rate)))

+  def test_merge_from(self):
+    rate1 = metric_values.Rate()
+    rate1.add(1, 1.0, 1.0, increment_total=True)
+    rate2 = metric_values.Rate()
+    rate2.add(2, 0.0, 1.0, increment_total=True)
+    rate1.merge_from(rate2)
+    self.assertEqual(rate1.total, 2)
+    self.assertEqual(float(rate1), 0.5)
+    self.assertEqual(
+        rate1.data_points,
+        [
+            metric_values.MetricValue.DataPoint(1, 1.0, 1.0),
+            metric_values.MetricValue.DataPoint(2, 0.0, 1.0),
+        ],
+    )
+

 class AverageTest(unittest.TestCase):

@@ -75,6 +91,22 @@ class AverageTest(unittest.TestCase):
     average.reset()
     self.assertEqual(average.total, 0)

+  def test_merge_from(self):
+    avg1 = metric_values.Average()
+    avg1.add(1, 1.0, 0.5, increment_total=True)
+    avg2 = metric_values.Average()
+    avg2.add(2, 0.0, 1.0, increment_total=True)
+    avg1.merge_from(avg2)
+    self.assertEqual(avg1.total, 2)
+    self.assertEqual(float(avg1), 0.25)
+    self.assertEqual(
+        avg1.data_points,
+        [
+            metric_values.MetricValue.DataPoint(1, 1.0, 0.5),
+            metric_values.MetricValue.DataPoint(2, 0.0, 1.0),
+        ],
+    )
+

 if __name__ == '__main__':
   unittest.main()
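The new `merge_from` API above lets partial aggregates be combined, for example when metric values are computed on separate evaluation shards. A minimal sketch of that pattern, assuming only the `add(example_id, value, weight, increment_total=...)` and `merge_from(...)` calls exercised in the tests above; the shard split itself is illustrative:

```python
# Sketch: folding per-shard Rate values into one aggregate. Only the
# add()/merge_from() calls shown in the tests above are assumed; the
# shard partitioning is illustrative.
from langfun.core.eval.v2 import metric_values

shard_results = [
    [(1, 1.0), (2, 0.0)],  # (example_id, value) pairs from shard 1.
    [(3, 1.0), (4, 1.0)],  # ... and from shard 2.
]

shard_rates = []
for shard in shard_results:
  rate = metric_values.Rate()
  for example_id, value in shard:
    rate.add(example_id, value, 1.0, increment_total=True)
  shard_rates.append(rate)

# Fold all shards into the first one.
merged = shard_rates[0]
for other in shard_rates[1:]:
  merged.merge_from(other)

print(float(merged))  # 0.75: weighted sum (3.0) / total examples (4).
```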
langfun/core/eval/v2/metrics.py
CHANGED
@@ -29,7 +29,15 @@ Average = metric_values.Average


 class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
-  """Interface for an evaluation metric.
+  """Interface for an evaluation metric.
+
+  A metric is used to evaluate the quality of the outputs produced by an
+  evaluation. It works by auditing each processed example via its `audit`
+  method, which in turn calls the user-overridable `_audit` method to perform
+  metric-specific logic and update metric values. Metrics can compute multiple
+  values (e.g., precision, recall, F1 score) which are exposed via the
+  `values` method.
+  """

   name: Annotated[
       str,

@@ -44,24 +52,43 @@ class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
     self._label_group = None
     self._lock = threading.Lock()

-  def
-
-
-
-
-    with
-      for v in self.values():
-        v.increment_total()
+  def update(
+      self,
+      example: example_lib.Example,
+      force_recompute: bool = False
+  ) -> dict[str, Any]:
+    """Updates metric values with a processed example.

-
+    Args:
+      example: The processed example.
+      force_recompute: Whether to force recompute the metric metadata even if
+        they are already present.

-
-
+    Returns:
+      A dict of metric metadata.
+    """
+    if (force_recompute
+        or example.metric_metadata is None
+        or self.name not in example.metric_metadata):
+      metadata = self.compute_metric_metadata(example)
+    else:
+      metadata = example.metric_metadata[self.name]
+    self.update_metric_values(example.id, metadata)
+    self._update_view()
+    return metadata

   @abc.abstractmethod
-  def
+  def compute_metric_metadata(
+      self, example: example_lib.Example
+  ) -> dict[str, Any]:
     """Subclasses should override this method to implement the metric logic."""

+  @abc.abstractmethod
+  def update_metric_values(
+      self, example_id: int, metric_metadata: dict[str, Any]
+  ) -> None:
+    """Update metric values based on metric metadata."""
+
   @abc.abstractmethod
   def values(self) -> list[metric_values.MetricValue]:
     """Returns all the values computed by this metric."""

@@ -71,6 +98,12 @@ class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
     for v in self.values():
       v.reset()

+  def merge_from(self, other: 'Metric') -> 'Metric':
+    """Merges the values from another metric."""
+    for v1, v2 in zip(self.values(), other.values()):
+      v1.merge_from(v2)
+    return self
+
   def _update_view(self):
     """Refreshes the metric values."""
     if self._label_group is None:

@@ -169,7 +202,15 @@ class Metric(pg.Object, pg.views.HtmlTreeView.Extension):


 class MetricBase(Metric):
-  """Base class for common metrics.
+  """Base class for common metrics.
+
+  `MetricBase` provides common functionalities for metrics, such as automatic
+  error counting based on whether an example has an error during evaluation.
+  It distinguishes between Object-Oriented Programming (OOP) errors
+  (e.g. `MappingError` during structured output generation) and other errors.
+  Subclasses should implement `_audit_processed` for metric computation on
+  successfully processed examples.
+  """

   oop_errors: Rate | None = Rate()
   non_oop_errors: Rate | None = Rate()

@@ -183,27 +224,67 @@ class MetricBase(Metric):
     super().reset()
     self._error_breakdown = collections.defaultdict(list)

-  def
-
+  def compute_metric_metadata(
+      self, example: example_lib.Example
+  ) -> dict[str, Any]:
+    """Computes the metric metadata for the example."""
     if example.error is None:
-      return self.
+      return self._compute_metric_metadata(example)
+    return self._compute_metric_metadata_with_processing_error(example)
+
+  def update_metric_values(
+      self,
+      example_id: int,
+      metric_metadata: dict[str, Any]
+  ) -> None:
+    """Collects the metric metadata."""
+    # NOTE(daiyip): the metric values are being updated concurrently, so we
+    # uses a lock to avoid race condition. We might consider relaxing the lock
+    # later if metric auditing becomes a bottleneck.
+    with self._lock:
+      for v in self.values():
+        v.increment_total()
+
+      if 'error' in metric_metadata:
+        self._update_metric_values_with_processing_error(
+            example_id, metric_metadata
+        )
       else:
-
+        self._update_metric_values(example_id, metric_metadata)
+
+  @abc.abstractmethod
+  def _compute_metric_metadata(
+      self,
+      example: example_lib.Example
+  ) -> dict[str, Any]:
+    """Computes the metric metadata for the example."""

-  def
+  def _compute_metric_metadata_with_processing_error(
+      self,
+      example: example_lib.Example
+  ) -> dict[str, Any]:
     """Audits the evaluation example after processing."""
     assert example.error is not None
-
-    if tag.startswith('MappingError'):
-      self.oop_errors.add(example.id, 1)
-    else:
-      self.non_oop_errors.add(example.id, 1)
-    self._error_breakdown[tag].append(example.id)
-    return dict(error=tag)
+    return dict(error=example.error.tag)

   @abc.abstractmethod
-  def
-  """
+  def _update_metric_values(self, metadata: dict[str, Any]) -> None:
+    """Update metric values based metric metadata."""
+
+  def _update_metric_values_with_processing_error(
+      self,
+      example_id: int,
+      metric_metadata: dict[str, Any]
+  ) -> None:
+    """Updates metric values with processing error."""
+    error_tag = metric_metadata.get('error')
+    assert error_tag is not None, (example_id, metric_metadata)
+    self._error_breakdown[error_tag].append(example_id)
+    if error_tag.startswith('MappingError'):
+      self.oop_errors.add(example_id, 1)
+    else:
+      self.non_oop_errors.add(example_id, 1)
+    self._error_breakdown[error_tag].append(example_id)

   def _oop_errors_breakdown(self) -> str | None:
     """Returns the OOP error breakdown as a string."""

@@ -229,7 +310,13 @@ class MetricBase(Metric):


 class Match(MetricBase):
-  """Metric for matching outputs against
+  """Metric for matching outputs against ground truth.
+
+  This metric computes match and mismatch rates by comparing the output of
+  an example with its ground truth. By default, it looks for a `groundtruth`
+  attribute in `example.input` for comparison. Users can customize this behavior
+  by subclassing `Match` and overriding the `match` method.
+  """

   name = 'match'
   matches: Rate = Rate()

@@ -257,20 +344,30 @@ class Match(MetricBase):
     )
     return pg.eq(output, groundtruth)

-  def
-
+  def _compute_metric_metadata(
+      self, example: example_lib.Example
+  ) -> dict[str, Any]:
+    """Computes the metric metadata for the example."""
     metadata = {}
-
-    if isinstance(
-
-
-
-      metadata['match'] = True
-    else:
-      self.mismatches.add(example.id, 1)
-      metadata['mismatch'] = True
+    is_correct = self.match(example.input, example.output)
+    if isinstance(is_correct, tuple):
+      is_correct, metadata = is_correct
+
+    metadata['is_correct'] = is_correct
     return metadata

+  def _update_metric_values(
+      self, example_id: int, metadata: dict[str, Any]
+  ) -> None:
+    """Update metric values based metric metadata."""
+    is_correct = metadata.get('is_correct')
+    assert is_correct is not None, (example_id, metadata)
+    if is_correct:
+      self.matches.add(example_id, 1)
+    else:
+      assert not is_correct
+      self.mismatches.add(example_id, 1)
+
   def values(self) -> list[metric_values.MetricValue]:
     """Returns all the values computed by this metric."""
     return [

@@ -302,7 +399,14 @@ class Match(MetricBase):


 class Score(MetricBase):
-  """Base class for scoring.
+  """Base class for scoring metrics.
+
+  `Score` is a base class for metrics that assign a numerical score to each
+  example's output (e.g., evaluating quality on a scale of 1-5).
+  It automatically computes the average score across all examples.
+  Subclasses must implement the `score` method to define how an example
+  should be scored.
+  """

   name = 'score'
   average_score: Average = Average()

@@ -322,16 +426,25 @@ class Score(MetricBase):
       A float score. Or a tuple of (score, metadata).
     """

-  def
-
+  def _compute_metric_metadata(
+      self, example: example_lib.Example
+  ) -> dict[str, Any]:
+    """Computes the metric metadata for the example."""
     metadata = {}
     score = self.score(example.input, example.output)
     if isinstance(score, tuple):
       score, metadata = score
-    self.average_score.add(example.id, score)
     metadata['score'] = score
     return metadata

+  def _update_metric_values(
+      self, example_id: int, metadata: dict[str, Any]
+  ) -> None:
+    """Update metric values based metric metadata."""
+    score = metadata.get('score')
+    assert score is not None, (example_id, metadata)
+    self.average_score.add(example_id, score)
+
   def values(self) -> list[metric_values.MetricValue]:
     """Returns all the values computed by this metric."""
     return [
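The refactor above splits the old per-example audit path in two: `compute_metric_metadata` derives a metadata dict for one example, and `update_metric_values` folds that metadata into the aggregated values under a lock, with `update()` reusing metadata already stored on the example unless `force_recompute=True`. User-defined metrics still only override `match` or `score`. A minimal sketch of a custom `Score` under this API; the `budget` input field and the length-based scoring rule are hypothetical:

```python
# Sketch: a custom scoring metric on top of the refactored Metric API.
# score() follows the contract shown above (a float, or a (float, metadata)
# tuple); the `budget` field is hypothetical.
import pyglove as pg
from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import metrics


class LengthScore(metrics.Score):
  """Scores an output by its length relative to a per-example budget."""

  def score(self, example_input, output):
    budget = example_input.budget
    value = min(len(str(output)) / budget, 1.0)
    # The extra dict is merged into the example's metric metadata.
    return value, dict(budget=budget)


m = LengthScore()
metadata = m.update(
    example_lib.Example(id=1, input=pg.Dict(budget=10), output='hello')
)
print(metadata)                # {'budget': 10, 'score': 0.5}
print(float(m.average_score))  # 0.5
```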
langfun/core/eval/v2/metrics_test.py
CHANGED

@@ -25,15 +25,22 @@ class MatchTest(unittest.TestCase):
   def test_basic(self):
     m = metrics.Match()  # pylint: disable=invalid-name
     self.assertEqual(
-        m.
-        dict(
+        m.update(Example(id=1, input=pg.Dict(groundtruth=1), output=1)),
+        dict(is_correct=True)
     )
     self.assertEqual(
-        m.
-
+        m.update(
+            Example(
+                id=2,
+                input=pg.Dict(groundtruth=1),
+                output=2,
+                metric_metadata=dict(match=dict(is_correct=False, x=1))
+            )
+        ),
+        dict(is_correct=False, x=1)
     )
     self.assertEqual(
-        m.
+        m.update(
         Example(
             id=3,
             input=pg.Dict(groundtruth=1),

@@ -47,7 +54,7 @@ class MatchTest(unittest.TestCase):
         dict(error='ValueError')
     )
     self.assertEqual(
-        m.
+        m.update(
         Example(
             id=3,
             input=pg.Dict(groundtruth=1),

@@ -80,7 +87,7 @@ class MatchTest(unittest.TestCase):
   def test_bad_case(self):
     m = metrics.Match()  # pylint: disable=invalid-name
     with self.assertRaisesRegex(ValueError, '`groundtruth` is not present'):
-      m.
+      m.update(Example(id=1, input=pg.Dict(x=1), output=1))

   def test_custom_metadata(self):

@@ -90,22 +97,36 @@ class MatchTest(unittest.TestCase):

     m = MyMatch()  # pylint: disable=invalid-name
     self.assertEqual(
-        m.
-        dict(
+        m.update(Example(id=1, input=pg.Dict(x=1), output=1)),
+        dict(is_correct=True, x=1)
     )
     self.assertEqual(m.matches, 1.0)

   def test_html_view(self):
     m = metrics.Match()  # pylint: disable=invalid-name
-    m.
+    m.update(Example(id=1, input=pg.Dict(groundtruth=1), output=1))
     self.assertIn(
         '100.0%',
         m.to_html().content,
     )
     with pg.views.html.controls.HtmlControl.track_scripts() as scripts:
-      m.
+      m.update(Example(id=2, input=pg.Dict(groundtruth=1), output=2))
     self.assertEqual(len(scripts), 12)

+  def test_merge_from(self):
+    m1 = metrics.Match()
+    m1.update(Example(id=1, input=pg.Dict(groundtruth=1), output=1))
+    m2 = metrics.Match()
+    m2.update(Example(id=2, input=pg.Dict(groundtruth=1), output=2))
+    m1.merge_from(m2)
+    self.assertEqual(m1.matches, 0.5)
+    self.assertEqual(m1.mismatches, 0.5)
+    self.assertEqual(m1.oop_errors, 0.0)
+    self.assertEqual(m1.non_oop_errors, 0.0)
+    self.assertEqual(m1.matches.total, 2)
+    self.assertEqual(len(m1.matches.data_points), 1)
+    self.assertEqual(len(m1.mismatches.data_points), 1)
+

 class ScoreTest(unittest.TestCase):

@@ -118,15 +139,15 @@ class ScoreTest(unittest.TestCase):

     m = MyScore()  # pylint: disable=invalid-name
     self.assertEqual(
-        m.
+        m.update(Example(id=1, input=pg.Dict(x=1), output=1)),
         dict(score=1 * 1)
     )
     self.assertEqual(
-        m.
+        m.update(Example(id=2, input=pg.Dict(x=2), output=2)),
         dict(score=2 * 2)
     )
     self.assertEqual(
-        m.
+        m.update(
         Example(
             id=3,
             input=pg.Dict(x=1),

@@ -140,7 +161,7 @@ class ScoreTest(unittest.TestCase):
         dict(error='ValueError')
     )
     self.assertEqual(
-        m.
+        m.update(
         Example(
             id=3,
             input=pg.Dict(x=1),

@@ -176,7 +197,7 @@ class ScoreTest(unittest.TestCase):

     m = MyScore()  # pylint: disable=invalid-name
     self.assertEqual(
-        m.
+        m.update(Example(id=1, input=pg.Dict(x=1), output=1)),
         dict(score=1 * 1, x=1)
     )
     self.assertEqual(m.average_score, 1.0)

@@ -189,13 +210,13 @@ class ScoreTest(unittest.TestCase):
         return example_input.x * output

     m = MyScore()  # pylint: disable=invalid-name
-    m.
+    m.update(Example(id=1, input=pg.Dict(x=1), output=2))
     self.assertIn(
         '2.000',
         m.to_html().content,
     )
     with pg.views.html.controls.HtmlControl.track_scripts() as scripts:
-      m.
+      m.update(Example(id=2, input=pg.Dict(x=1), output=2))
     self.assertEqual(len(scripts), 9)

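As the `Match` docstring above notes, the comparison can be customized by overriding `match`, which (like `score`) may return either a bool or a `(bool, metadata)` tuple; the confirmed call site is `self.match(example.input, example.output)` in the metrics.py hunk above. A minimal sketch of that customization point; the `expected` field and the case-insensitive rule are hypothetical:

```python
# Sketch: customizing Match by overriding match(). The `expected` input
# field and the case-insensitive comparison are hypothetical.
import pyglove as pg
from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import metrics


class CaseInsensitiveMatch(metrics.Match):
  """Matches string outputs against `expected`, ignoring case."""

  def match(self, example_input, output):
    expected = example_input.expected
    is_correct = str(output).lower() == str(expected).lower()
    # The second element is merged into the example's metric metadata.
    return is_correct, dict(expected=expected)


m = CaseInsensitiveMatch()
m.update(
    example_lib.Example(id=1, input=pg.Dict(expected='Paris'), output='paris')
)
print(float(m.matches))  # 1.0
```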
langfun/core/eval/v2/progress.py
CHANGED
@@ -21,7 +21,15 @@ import pyglove as pg


 class Progress(pg.Object, pg.views.HtmlTreeView.Extension):
-  """
+  """Represents and tracks the progress of an evaluation.
+
+  The `Progress` class maintains counts of processed, failed, and skipped
+  items in an evaluation, along with timing information (start time, stop time,
+  duration) and an execution summary. It provides properties to check the
+  status of the evaluation (e.g., `is_started`, `is_completed`) and methods
+  to update progress as items are evaluated.
+  It also supports HTML rendering as a progress bar for visualization.
+  """

   num_total: Annotated[
       int | None,

@@ -84,6 +92,7 @@ class Progress(pg.Object, pg.views.HtmlTreeView.Extension):
         stop_time=None,
         execution_summary=pg.object_utils.TimeIt.StatusSummary(),
     )
+    self._progress_bar = None

   @property
   def num_completed(self) -> int:

@@ -216,6 +225,27 @@ class Progress(pg.Object, pg.views.HtmlTreeView.Extension):
     """Overrides nondefault values so volatile values are not included."""
     return dict()

+  def merge_from(self, other: 'Progress') -> None:
+    """Merges the progress from another progress."""
+    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
+      if other.start_time is not None and (
+          self.start_time is None or self.start_time > other.start_time):
+        self.start_time = other.start_time
+
+      if other.stop_time is not None and (
+          self.stop_time is None or self.stop_time < other.stop_time):
+        self.stop_time = other.stop_time
+
+      if other.num_total is not None:
+        if self.num_total is None:
+          self.num_total = other.num_total
+        else:
+          assert self.num_total == other.num_total, (self, other)
+      self.num_processed += other.num_processed
+      self.num_failed += other.num_failed
+      self.num_skipped += other.num_skipped
+      self.execution_summary.aggregate(other.execution_summary.breakdown)
+
   #
   # HTML view.
   #
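Paired with `merge_from`, a single `Progress` view can be assembled from per-worker progress objects of the same run: counters are summed, the earliest start time and latest stop time win, and `num_total` must agree. A minimal sketch with an illustrative two-shard split; the API calls mirror the `test_merge_from` case in progress_test.py below:

```python
# Sketch: aggregating the progress of two workers that evaluate disjoint
# halves of the same 10-example run. The shard split is illustrative; the
# calls mirror test_merge_from in progress_test.py below.
from langfun.core.eval.v2.progress import Progress


def run_shard(num_items: int, num_failures: int) -> Progress:
  p = Progress()
  p.start(10)  # Total for the whole run, shared by both shards.
  for i in range(num_items):
    if i < num_failures:
      p.increment_failed()
    else:
      p.increment_processed()
  p.stop()
  return p


overall = run_shard(5, 0)
overall.merge_from(run_shard(5, 1))
print(overall.num_total)      # 10
print(overall.num_processed)  # 9
print(overall.num_failed)     # 1
print(overall.num_completed)  # 10
```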
langfun/core/eval/v2/progress_test.py
CHANGED

@@ -77,6 +77,33 @@ class ProgressTest(unittest.TestCase):
     self.assertTrue(p.is_stopped)
     self.assertIsNotNone(p.stop_time_str)

+  def test_merge_from(self):
+    p1 = Progress()
+    p1.start(10)
+    p1.increment_processed()
+    p1.increment_failed()
+    p1.stop()
+
+    p2 = Progress()
+    p2.start(10)
+    p2.increment_skipped()
+    p2.stop()
+
+    with pg.allow_writable_accessors(True):
+      p1.start_time = 2.0
+      p1.stop_time = 4.0
+      p2.start_time = 1.0
+      p2.stop_time = 5.0
+
+    p1.merge_from(p2)
+    self.assertEqual(p1.num_total, 10)
+    self.assertEqual(p1.num_processed, 1)
+    self.assertEqual(p1.num_failed, 1)
+    self.assertEqual(p1.num_skipped, 1)
+    self.assertEqual(p1.num_completed, 3)
+    self.assertEqual(p1.start_time, 1.0)
+    self.assertEqual(p1.stop_time, 5.0)
+

 if __name__ == '__main__':
   unittest.main()
langfun/core/eval/v2/progress_tracking.py
CHANGED

@@ -14,6 +14,7 @@
 """Tracking evaluation run progress."""

 import os
+from typing import Literal
 import langfun.core as lf
 from langfun.core.eval.v2 import example as example_lib
 from langfun.core.eval.v2 import experiment as experiment_lib

@@ -24,16 +25,24 @@ Experiment = experiment_lib.Experiment
 Example = example_lib.Example


-def progress_tracker(
+def progress_tracker(
+    tracker_type: Literal['tqdm', 'html', 'auto'] = 'auto'
+) -> experiment_lib.Plugin:
   """Creates a progress tracker as a plugin.

   Args:
-
+    tracker_type: The type of progress tracker to use.
+      If `tqdm`, force using tqdm for progress update.
+      If `html`, force using html for progress update.
+      If `auto`, determine it automatically based on the running
+      environment (console vs. notebook)

   Returns:
     The progress tracker plugin.
   """
-  if tqdm or
+  if tracker_type == 'tqdm' or (
+      tracker_type == 'auto' and not lf.console.under_notebook()
+  ):
     return _TqdmProgressTracker()
   else:
     return _HtmlProgressTracker()

@@ -88,8 +97,7 @@ class _TqdmProgressTracker(experiment_lib.Plugin):
     self._leaf_progresses = {
         leaf.id: lf.concurrent.ProgressBar.install(
             label=f'[#{i + 1} - {leaf.id}]',
-            total=
-            if runner.current_run.example_ids else leaf.num_examples),
+            total=len(runner.current_run.examples_to_evaluate(leaf)),
             color='cyan',
             status=None
         )