langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512040805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +1 -1
- langfun/core/__init__.py +7 -1
- langfun/core/agentic/__init__.py +8 -1
- langfun/core/agentic/action.py +740 -112
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +189 -24
- langfun/core/async_support.py +104 -5
- langfun/core/async_support_test.py +23 -0
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +9 -2
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +11 -2
- langfun/core/data/conversion/gemini_test.py +48 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +48 -44
- langfun/core/eval/base_test.py +5 -5
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +2 -0
- langfun/core/eval/v2/checkpointing.py +76 -7
- langfun/core/eval/v2/checkpointing_test.py +9 -2
- langfun/core/eval/v2/config_saver.py +37 -0
- langfun/core/eval/v2/config_saver_test.py +36 -0
- langfun/core/eval/v2/eval_test_helper.py +104 -3
- langfun/core/eval/v2/evaluation.py +92 -17
- langfun/core/eval/v2/evaluation_test.py +9 -3
- langfun/core/eval/v2/example.py +50 -40
- langfun/core/eval/v2/example_test.py +16 -8
- langfun/core/eval/v2/experiment.py +84 -15
- langfun/core/eval/v2/experiment_test.py +19 -0
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +157 -44
- langfun/core/eval/v2/metrics_test.py +39 -18
- langfun/core/eval/v2/progress.py +31 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking.py +13 -5
- langfun/core/eval/v2/progress_tracking_test.py +9 -1
- langfun/core/eval/v2/reporting.py +90 -71
- langfun/core/eval/v2/reporting_test.py +24 -6
- langfun/core/eval/v2/runners/__init__.py +30 -0
- langfun/core/eval/v2/{runners.py → runners/base.py} +72 -180
- langfun/core/eval/v2/runners/beam.py +354 -0
- langfun/core/eval/v2/runners/beam_test.py +153 -0
- langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
- langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
- langfun/core/eval/v2/runners/debug.py +40 -0
- langfun/core/eval/v2/runners/debug_test.py +76 -0
- langfun/core/eval/v2/runners/parallel.py +243 -0
- langfun/core/eval/v2/runners/parallel_test.py +182 -0
- langfun/core/eval/v2/runners/sequential.py +47 -0
- langfun/core/eval/v2/runners/sequential_test.py +169 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +7 -5
- langfun/core/language_model.py +189 -36
- langfun/core/language_model_test.py +54 -3
- langfun/core/llms/__init__.py +12 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +32 -1
- langfun/core/llms/gemini.py +64 -12
- langfun/core/llms/gemini_test.py +110 -0
- langfun/core/llms/google_genai.py +34 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +120 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +16 -1
- langfun/core/llms/vertexai.py +58 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/__init__.py +10 -0
- langfun/core/mcp/client.py +177 -0
- langfun/core/mcp/client_test.py +71 -0
- langfun/core/mcp/session.py +241 -0
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/testing/simple_mcp_client.py +33 -0
- langfun/core/mcp/testing/simple_mcp_server.py +33 -0
- langfun/core/mcp/tool.py +254 -0
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +73 -3
- langfun/core/modalities/image_test.py +116 -0
- langfun/core/modalities/mime.py +64 -3
- langfun/core/modalities/mime_test.py +11 -0
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/__init__.py +2 -24
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +81 -37
- langfun/core/structured/parsing.py +95 -79
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +230 -154
- langfun/core/structured/querying_test.py +69 -33
- langfun/core/structured/schema/__init__.py +49 -0
- langfun/core/structured/schema/base.py +664 -0
- langfun/core/structured/schema/base_test.py +531 -0
- langfun/core/structured/schema/json.py +174 -0
- langfun/core/structured/schema/json_test.py +121 -0
- langfun/core/structured/schema/python.py +316 -0
- langfun/core/structured/schema/python_test.py +410 -0
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +47 -36
- langfun/core/structured/tokenization.py +26 -11
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +175 -50
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +43 -0
- langfun/env/base_environment.py +827 -0
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +304 -0
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +842 -0
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +14 -0
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +472 -0
- langfun/env/event_handlers/event_logger_test.py +304 -0
- langfun/env/event_handlers/metric_writer.py +726 -0
- langfun/env/event_handlers/metric_writer_test.py +214 -0
- langfun/env/interface.py +1640 -0
- langfun/env/interface_test.py +153 -0
- langfun/env/load_balancers.py +59 -0
- langfun/env/load_balancers_test.py +141 -0
- langfun/env/test_utils.py +507 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/METADATA +7 -3
- langfun-0.1.2.dev202512040805.dist-info/RECORD +217 -0
- langfun/core/eval/v2/runners_test.py +0 -343
- langfun/core/structured/schema.py +0 -987
- langfun/core/structured/schema_test.py +0 -982
- langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/top_level.txt +0 -0
langfun/core/eval/v2/experiment.py
CHANGED

@@ -139,10 +139,10 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
 
   # Checkpointing
 
-  Experiments support checkpointing, which is enabled by default. It allows
+  Experiments support checkpointing, which is enabled by default. It allows
   users to resume their experiments from a saved state. When an experiment runs,
-  it creates a new directory for that run and saves
-
+  it creates a new directory for that run and saves its progress to checkpoint
+  files. If the experiment is interrupted or fails, users can resume
   it by specifying the 'id' or 'warm_start_from' argument (shown above) to
   seamlessly continue from previously saved state without starting over.
 
@@ -169,7 +169,7 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
 
   # Experiment Plugins
 
-
+  Experiments can be extended by plugins. Plugins can listen to the events of
   experiment execution and produce additional outputs. For example, a plugin
   can be added to an experiment to generate additional metrics or to save
   additional data to a database. More details will be added in the future.
@@ -657,7 +657,30 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
 
 @pg.use_init_args(['children'])
 class Suite(Experiment):
-  """A suite of evaluations.
+  """A suite of evaluations.
+
+  `lf.eval.Suite` groups multiple `lf.eval.Evaluation` or other `Suite`
+  objects into a single experiment, allowing them to be run, managed, and
+  reported together.
+
+  **Example:**
+
+  ```python
+  import langfun as lf
+
+  suite = lf.eval.Suite([
+      MyEval(lm=lf.llms.Gpt4()),
+      MyEval(lm=lf.llms.Gemini()),
+      lf.eval.Suite([
+          AnotherEval(lm=lf.llms.Gpt4()),
+          AnotherEval(lm=lf.llms.Gemini())
+      ])
+  ])
+
+  # Run all evaluations in the suite
+  run_info = suite.run('/path/to/my/suite_run')
+  ```
+  """
 
   children: Annotated[
       list[Experiment], 'A list of child experiments.'
@@ -791,7 +814,14 @@ class RunId(pg.Object):
 
 
 class Run(pg.Object, pg.views.html.HtmlTreeView.Extension):
-  """
+  """Represents a single run of an experiment.
+
+  A `Run` object holds all the configurations for executing an experiment,
+  such as the experiment definition, input/output directories, and flags
+  controlling the execution behavior (e.g., error handling, checkpointing).
+  It also provides utility methods for accessing run-specific paths and
+  filtering examples for evaluation.
+  """
 
   root_dir: Annotated[
       str,
@@ -818,10 +848,10 @@ class Run(pg.Object, pg.views.html.HtmlTreeView.Extension):
   ] = None
 
   example_ids: Annotated[
-      list[int] | None,
+      list[int] | Callable[[Experiment], list[int]] | None,
       (
-          'The example IDs to run.
-          '
+          'The example IDs to run. Or a callable for determining the examples '
+          'to run based on the experiment. If None, it will run all examples. '
      )
   ] = None
 
@@ -937,10 +967,13 @@ class Run(pg.Object, pg.views.html.HtmlTreeView.Extension):
     """Returns the example IDs to evaluate."""
     if not experiment.is_leaf:
       return set()
-
-
-
-
+    if self.example_ids is None:
+      return set(range(1, experiment.num_examples + 1))
+    elif isinstance(self.example_ids, Callable):
+      return set(self.example_ids(experiment))
+    else:
+      assert isinstance(self.example_ids, list), self.example_ids
+      return set(self.example_ids)
 
   def examples_to_reprocess(self, experiment: Experiment) -> set[int]:
     """Returns the example IDs to reprocess per request."""
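
The callable form of `example_ids` makes example selection dynamic per leaf experiment. A minimal sketch under stated assumptions: `first_ten` is a hypothetical selector, `my_eval` stands in for an existing `lf.eval.v2` evaluation, and passing `example_ids` through `run()` is assumed from the `Run` fields above rather than confirmed by this diff.

```python
# Hypothetical selector: evaluate only the first ten examples of each leaf
# experiment (IDs are 1-based, matching the range(1, num_examples + 1)
# default shown above).
def first_ten(experiment) -> list[int]:
  return list(range(1, min(10, experiment.num_examples) + 1))

run_info = my_eval.run('/tmp/my_eval_run', example_ids=first_ten)
```
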
@@ -971,7 +1004,13 @@ class Run(pg.Object, pg.views.html.HtmlTreeView.Extension):
 
 
 class Runner(pg.Object):
-  """Interface for experiment runner.
+  """Interface for experiment runner.
+
+  A runner is responsible for executing the evaluations within an experiment
+  based on the configuration specified in a `Run` object. Different runners
+  can implement different execution strategies, such as sequential or parallel
+  processing of examples and evaluations.
+  """
 
   # Class-level variable for registering the runner.
   NAME = None
@@ -1010,7 +1049,37 @@ class Runner(pg.Object):
 
 
 class Plugin(lf.Component):
-  """Base class for experiment plugins.
+  """Base class for experiment plugins.
+
+  Plugins provide a mechanism to extend the behavior of an experiment run
+  by hooking into various events during the lifecycle of experiment and
+  example execution, such as `on_run_start`, `on_experiment_complete`,
+  `on_example_start`, etc. They can be used for custom logging, monitoring,
+  or result processing.
+  """
+
+  @classmethod
+  def is_per_example(cls) -> bool:
+    """Returns whether the plugin is per example only.
+
+    Per-example plugins can be installed on individual workers when examples
+    are evaluated by multiple processes in parallel.
+    """
+
+    def same_code(method1, method2):
+      return method1.__code__ == method2.__code__
+    return all(
+        same_code(method1, method2)
+        for method1, method2 in [
+            (Plugin.on_run_start, cls.on_run_start),
+            (Plugin.on_run_complete, cls.on_run_complete),
+            (Plugin.on_run_abort, cls.on_run_abort),
+            (Plugin.on_experiment_start, cls.on_experiment_start),
+            (Plugin.on_experiment_skipped, cls.on_experiment_skipped),
+            (Plugin.on_experiment_complete, cls.on_experiment_complete),
+            (Plugin.on_experiment_abort, cls.on_experiment_abort),
+        ]
+    )
 
   def on_run_start(
       self,
langfun/core/eval/v2/experiment_test.py
CHANGED

@@ -433,5 +433,24 @@ class RunnerTest(unittest.TestCase):
     pass
 
 
+class PluginTest(unittest.TestCase):
+
+  def test_per_example_only(self):
+
+    class PerExamplePlugin(experiment_lib.Plugin):
+
+      def on_example_complete(self, runner, experiment, example):
+        print('on_example_complete')
+
+    self.assertTrue(PerExamplePlugin.is_per_example())
+
+    class NonPerExamplePlugin(experiment_lib.Plugin):
+
+      def on_experiment_complete(self, runner, experiment):
+        print('on_example_complete')
+
+    self.assertFalse(NonPerExamplePlugin.is_per_example())
+
+
 if __name__ == '__main__':
   unittest.main()
langfun/core/eval/v2/metric_values.py
CHANGED

@@ -20,7 +20,15 @@ import pyglove as pg
 
 
 class MetricValue(pg.Object):
-  """Base class for metric values.
+  """Base class for metric values.
+
+  `MetricValue` is the base class for representing aggregated metric values
+  in an evaluation. It accumulates data points from individual examples,
+  each consisting of a value and an optional weight, associated with an example
+  ID. Subclasses must implement `reduce` method to compute a single float value
+  from accumulated data points, and `scalar_repr` to provide a string
+  representation of the reduced value.
+  """
 
   class DataPoint(pg.Object):
     """A data point for a metric value."""
@@ -88,6 +96,14 @@ class MetricValue(pg.Object):
     self.increment_total()
     return self
 
+  def merge_from(self, other: 'MetricValue') -> 'MetricValue':
+    """Merges the values from another metric value."""
+    self._weighted_sum += other._weighted_sum  # pylint: disable=protected-access
+    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
+      self.data_points.extend(other.data_points)
+    self.increment_total(other.total)
+    return self
+
   def __gt__(self, other: Union['MetricValue', float]) -> bool:
     if isinstance(other, self.__class__):
       return float(self) > float(other)
@@ -133,7 +149,13 @@ class MetricValue(pg.Object):
 
 
 class Rate(MetricValue):
-  """
+  """Metric value representing a rate in range [0, 1].
+
+  `Rate` is used for metrics that compute a rate, such as accuracy or error
+  rate. The final value is computed as the weighted sum of accumulated values
+  divided by the total number of examples. It's displayed as a percentage
+  (e.g., 90.0%).
+  """
 
   def reduce(self) -> float:
     return self._weighted_sum / self.total
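
A small sketch of the accumulation model described above, using the `add(example_id, value, weight, increment_total=...)` call signature visible in the metric_values tests later in this diff; treat it as illustrative rather than canonical.

```python
from langfun.core.eval.v2 import metric_values

rate = metric_values.Rate()
rate.add(1, 1.0, 1.0, increment_total=True)  # example 1 counts as a hit
rate.add(2, 0.0, 1.0, increment_total=True)  # example 2 counts as a miss
assert rate.total == 2
assert float(rate) == 0.5  # weighted sum / total; rendered as a percentage
```
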
@@ -145,7 +167,13 @@ class Rate(MetricValue):
 
 
 class Average(MetricValue):
-  """
+  """Metric value representing an average of accumulated values.
+
+  `Average` is used for metrics that compute an average score across examples
+  (e.g., average quality score). The final value is computed as the weighted
+  sum of accumulated values divided by the number of data points.
+  It's displayed as a float with 3 decimal places (e.g., 4.750).
+  """
 
   def reduce(self) -> float:
     if not self.data_points:
langfun/core/eval/v2/metric_values_test.py
CHANGED

@@ -51,6 +51,22 @@ class RateTest(unittest.TestCase):
     self.assertEqual(rate.total, 0)
     self.assertTrue(math.isnan(float(rate)))
 
+  def test_merge_from(self):
+    rate1 = metric_values.Rate()
+    rate1.add(1, 1.0, 1.0, increment_total=True)
+    rate2 = metric_values.Rate()
+    rate2.add(2, 0.0, 1.0, increment_total=True)
+    rate1.merge_from(rate2)
+    self.assertEqual(rate1.total, 2)
+    self.assertEqual(float(rate1), 0.5)
+    self.assertEqual(
+        rate1.data_points,
+        [
+            metric_values.MetricValue.DataPoint(1, 1.0, 1.0),
+            metric_values.MetricValue.DataPoint(2, 0.0, 1.0),
+        ],
+    )
+
 
 class AverageTest(unittest.TestCase):
 
@@ -75,6 +91,22 @@ class AverageTest(unittest.TestCase):
     average.reset()
     self.assertEqual(average.total, 0)
 
+  def test_merge_from(self):
+    avg1 = metric_values.Average()
+    avg1.add(1, 1.0, 0.5, increment_total=True)
+    avg2 = metric_values.Average()
+    avg2.add(2, 0.0, 1.0, increment_total=True)
+    avg1.merge_from(avg2)
+    self.assertEqual(avg1.total, 2)
+    self.assertEqual(float(avg1), 0.25)
+    self.assertEqual(
+        avg1.data_points,
+        [
+            metric_values.MetricValue.DataPoint(1, 1.0, 0.5),
+            metric_values.MetricValue.DataPoint(2, 0.0, 1.0),
+        ],
+    )
+
 
 if __name__ == '__main__':
   unittest.main()
langfun/core/eval/v2/metrics.py
CHANGED

@@ -29,7 +29,15 @@ Average = metric_values.Average
 
 
 class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
-  """Interface for an evaluation metric.
+  """Interface for an evaluation metric.
+
+  A metric is used to evaluate the quality of the outputs produced by an
+  evaluation. It works by auditing each processed example via its `audit`
+  method, which in turn calls the user-overridable `_audit` method to perform
+  metric-specific logic and update metric values. Metrics can compute multiple
+  values (e.g., precision, recall, F1 score) which are exposed via the
+  `values` method.
+  """
 
   name: Annotated[
       str,
@@ -44,24 +52,43 @@ class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
     self._label_group = None
     self._lock = threading.Lock()
 
-  def
-
-
-
-
-    with
-      for v in self.values():
-        v.increment_total()
+  def update(
+      self,
+      example: example_lib.Example,
+      force_recompute: bool = False
+  ) -> dict[str, Any]:
+    """Updates metric values with a processed example.
 
-
+    Args:
+      example: The processed example.
+      force_recompute: Whether to force recompute the metric metadata even if
+        they are already present.
 
-
-
+    Returns:
+      A dict of metric metadata.
+    """
+    if (force_recompute
+        or example.metric_metadata is None
+        or self.name not in example.metric_metadata):
+      metadata = self.compute_metric_metadata(example)
+    else:
+      metadata = example.metric_metadata[self.name]
+    self.update_metric_values(example.id, metadata)
+    self._update_view()
+    return metadata
 
   @abc.abstractmethod
-  def
+  def compute_metric_metadata(
+      self, example: example_lib.Example
+  ) -> dict[str, Any]:
     """Subclasses should override this method to implement the metric logic."""
 
+  @abc.abstractmethod
+  def update_metric_values(
+      self, example_id: int, metric_metadata: dict[str, Any]
+  ) -> None:
+    """Update metric values based on metric metadata."""
+
   @abc.abstractmethod
   def values(self) -> list[metric_values.MetricValue]:
     """Returns all the values computed by this metric."""
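
A brief usage sketch of the new `update` entry point shown above; `metric` and `example` stand in for an existing `Metric` subclass instance and a processed `Example`.

```python
# Per the logic above: metadata is read from example.metric_metadata[metric.name]
# when present, otherwise recomputed via compute_metric_metadata(); either way
# it is folded into the metric values via update_metric_values().
metadata = metric.update(example)
# Force recomputation even if cached metadata exists for this metric.
metadata = metric.update(example, force_recompute=True)
```
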
@@ -71,6 +98,12 @@ class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
     for v in self.values():
       v.reset()
 
+  def merge_from(self, other: 'Metric') -> 'Metric':
+    """Merges the values from another metric."""
+    for v1, v2 in zip(self.values(), other.values()):
+      v1.merge_from(v2)
+    return self
+
   def _update_view(self):
     """Refreshes the metric values."""
     if self._label_group is None:
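
`Metric.merge_from` pairs naturally with runs that evaluate shards in separate processes (e.g. the new beam and parallel runners listed in the file summary). A hedged sketch, where `shard_metrics` is a stand-in list of same-typed metrics gathered from workers, not a langfun API:

```python
merged = shard_metrics[0]
for other in shard_metrics[1:]:
  merged.merge_from(other)  # pairs values() element-wise, as defined above
print([float(v) for v in merged.values()])
```
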
@@ -169,7 +202,15 @@ class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
 
 
 class MetricBase(Metric):
-  """Base class for common metrics.
+  """Base class for common metrics.
+
+  `MetricBase` provides common functionalities for metrics, such as automatic
+  error counting based on whether an example has an error during evaluation.
+  It distinguishes between Object-Oriented Programming (OOP) errors
+  (e.g. `MappingError` during structured output generation) and other errors.
+  Subclasses should implement `_audit_processed` for metric computation on
+  successfully processed examples.
+  """
 
   oop_errors: Rate | None = Rate()
   non_oop_errors: Rate | None = Rate()
@@ -183,27 +224,67 @@ class MetricBase(Metric):
     super().reset()
     self._error_breakdown = collections.defaultdict(list)
 
-  def
-
+  def compute_metric_metadata(
+      self, example: example_lib.Example
+  ) -> dict[str, Any]:
+    """Computes the metric metadata for the example."""
     if example.error is None:
-      return self.
+      return self._compute_metric_metadata(example)
+    return self._compute_metric_metadata_with_processing_error(example)
+
+  def update_metric_values(
+      self,
+      example_id: int,
+      metric_metadata: dict[str, Any]
+  ) -> None:
+    """Collects the metric metadata."""
+    # NOTE(daiyip): the metric values are being updated concurrently, so we
+    # uses a lock to avoid race condition. We might consider relaxing the lock
+    # later if metric auditing becomes a bottleneck.
+    with self._lock:
+      for v in self.values():
+        v.increment_total()
+
+      if 'error' in metric_metadata:
+        self._update_metric_values_with_processing_error(
+            example_id, metric_metadata
+        )
       else:
-
+        self._update_metric_values(example_id, metric_metadata)
+
+  @abc.abstractmethod
+  def _compute_metric_metadata(
+      self,
+      example: example_lib.Example
+  ) -> dict[str, Any]:
+    """Computes the metric metadata for the example."""
 
-  def
+  def _compute_metric_metadata_with_processing_error(
+      self,
+      example: example_lib.Example
+  ) -> dict[str, Any]:
     """Audits the evaluation example after processing."""
     assert example.error is not None
-
-    if tag.startswith('MappingError'):
-      self.oop_errors.add(example.id, 1)
-    else:
-      self.non_oop_errors.add(example.id, 1)
-    self._error_breakdown[tag].append(example.id)
-    return dict(error=tag)
+    return dict(error=example.error.tag)
 
   @abc.abstractmethod
-  def
-    """
+  def _update_metric_values(self, metadata: dict[str, Any]) -> None:
+    """Update metric values based metric metadata."""
+
+  def _update_metric_values_with_processing_error(
+      self,
+      example_id: int,
+      metric_metadata: dict[str, Any]
+  ) -> None:
+    """Updates metric values with processing error."""
+    error_tag = metric_metadata.get('error')
+    assert error_tag is not None, (example_id, metric_metadata)
+    self._error_breakdown[error_tag].append(example_id)
+    if error_tag.startswith('MappingError'):
+      self.oop_errors.add(example_id, 1)
+    else:
+      self.non_oop_errors.add(example_id, 1)
+    self._error_breakdown[error_tag].append(example_id)
 
   def _oop_errors_breakdown(self) -> str | None:
     """Returns the OOP error breakdown as a string."""
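
A hypothetical metric built on the `MetricBase` contract above: subclasses supply `_compute_metric_metadata` and `_update_metric_values`, while error accounting is inherited. The `NonEmpty` metric itself is illustrative only; the field declaration style and `values()` contents mirror the `Match` metric below.

```python
from langfun.core.eval.v2 import metrics


class NonEmpty(metrics.MetricBase):
  """Rate of examples whose output is non-empty (illustrative)."""

  name = 'non_empty'
  non_empty: metrics.Rate = metrics.Rate()

  def _compute_metric_metadata(self, example):
    # Called when no metadata for this metric is cached on the example.
    return dict(non_empty=bool(str(example.output).strip()))

  def _update_metric_values(self, example_id, metadata):
    # Folds the per-example metadata into the aggregate Rate.
    self.non_empty.add(example_id, 1.0 if metadata['non_empty'] else 0.0)

  def values(self):
    return [self.non_empty, self.oop_errors, self.non_oop_errors]
```
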
@@ -229,7 +310,13 @@ class MetricBase(Metric):
 
 
 class Match(MetricBase):
-  """Metric for matching outputs against
+  """Metric for matching outputs against ground truth.
+
+  This metric computes match and mismatch rates by comparing the output of
+  an example with its ground truth. By default, it looks for a `groundtruth`
+  attribute in `example.input` for comparison. Users can customize this behavior
+  by subclassing `Match` and overriding the `match` method.
+  """
 
   name = 'match'
   matches: Rate = Rate()
@@ -257,20 +344,30 @@ class Match(MetricBase):
     )
     return pg.eq(output, groundtruth)
 
-  def
-
+  def _compute_metric_metadata(
+      self, example: example_lib.Example
+  ) -> dict[str, Any]:
+    """Computes the metric metadata for the example."""
     metadata = {}
-
-    if isinstance(
-
-
-
-      metadata['match'] = True
-    else:
-      self.mismatches.add(example.id, 1)
-      metadata['mismatch'] = True
+    is_correct = self.match(example.input, example.output)
+    if isinstance(is_correct, tuple):
+      is_correct, metadata = is_correct
+
+    metadata['is_correct'] = is_correct
     return metadata
 
+  def _update_metric_values(
+      self, example_id: int, metadata: dict[str, Any]
+  ) -> None:
+    """Update metric values based metric metadata."""
+    is_correct = metadata.get('is_correct')
+    assert is_correct is not None, (example_id, metadata)
+    if is_correct:
+      self.matches.add(example_id, 1)
+    else:
+      assert not is_correct
+      self.mismatches.add(example_id, 1)
+
   def values(self) -> list[metric_values.MetricValue]:
     """Returns all the values computed by this metric."""
     return [
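
Per the `Match` docstring and the tuple handling in `_compute_metric_metadata` above, `match` may return either a bool or a `(bool, metadata)` tuple. A hypothetical override (parameter names are positional stand-ins; the default implementation compares against `example.input.groundtruth`):

```python
from langfun.core.eval.v2 import metrics


class CaseInsensitiveMatch(metrics.Match):
  """Illustrative: compare output and groundtruth case-insensitively."""

  def match(self, example_input, output):
    expected = example_input.groundtruth
    return (
        str(output).lower() == str(expected).lower(),
        dict(expected=expected),
    )
```
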
@@ -302,7 +399,14 @@ class Match(MetricBase):
 
 
 class Score(MetricBase):
-  """Base class for scoring.
+  """Base class for scoring metrics.
+
+  `Score` is a base class for metrics that assign a numerical score to each
+  example's output (e.g., evaluating quality on a scale of 1-5).
+  It automatically computes the average score across all examples.
+  Subclasses must implement the `score` method to define how an example
+  should be scored.
+  """
 
   name = 'score'
   average_score: Average = Average()
@@ -322,16 +426,25 @@ class Score(MetricBase):
       A float score. Or a tuple of (score, metadata).
     """
 
-  def
-
+  def _compute_metric_metadata(
+      self, example: example_lib.Example
+  ) -> dict[str, Any]:
+    """Computes the metric metadata for the example."""
     metadata = {}
     score = self.score(example.input, example.output)
     if isinstance(score, tuple):
       score, metadata = score
-    self.average_score.add(example.id, score)
     metadata['score'] = score
     return metadata
 
+  def _update_metric_values(
+      self, example_id: int, metadata: dict[str, Any]
+  ) -> None:
+    """Update metric values based metric metadata."""
+    score = metadata.get('score')
+    assert score is not None, (example_id, metadata)
+    self.average_score.add(example_id, score)
+
   def values(self) -> list[metric_values.MetricValue]:
     """Returns all the values computed by this metric."""
     return [