langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512040805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +2 -0
  29. langfun/core/eval/v2/checkpointing.py +76 -7
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +92 -17
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +84 -15
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +90 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +72 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +12 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +64 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +34 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +58 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +64 -3
  104. langfun/core/modalities/mime_test.py +11 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512040805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/top_level.txt +0 -0
@@ -139,10 +139,10 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
 
  # Checkpointing
 
- Experiments support checkpointing, which is enabled by default. It allows
+ Experiments support checkpointing, which is enabled by default. It allows
  users to resume their experiments from a saved state. When an experiment runs,
- it creates a new directory for that run and saves the current state to a
- checkpoint file. If the experiment is interrupted or fails, users can resume
+ it creates a new directory for that run and saves its progress to checkpoint
+ files. If the experiment is interrupted or fails, users can resume
  it by specifying the 'id' or 'warm_start_from' argument (shown above) to
  seamlessly continue from previously saved state without starting over.
 
@@ -169,7 +169,7 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
 
  # Experiment Plugins
 
- Experiment can be extended by plugins. Plugins can listen to the events of
+ Experiments can be extended by plugins. Plugins can listen to the events of
  experiment execution and produce additional outputs. For example, a plugin
  can be added to an experiment to generate additional metrics or to save
  additional data to a database. More details will be added in the future.
@@ -657,7 +657,30 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
 
  @pg.use_init_args(['children'])
  class Suite(Experiment):
- """A suite of evaluations."""
+ """A suite of evaluations.
+
+ `lf.eval.Suite` groups multiple `lf.eval.Evaluation` or other `Suite`
+ objects into a single experiment, allowing them to be run, managed, and
+ reported together.
+
+ **Example:**
+
+ ```python
+ import langfun as lf
+
+ suite = lf.eval.Suite([
+ MyEval(lm=lf.llms.Gpt4()),
+ MyEval(lm=lf.llms.Gemini()),
+ lf.eval.Suite([
+ AnotherEval(lm=lf.llms.Gpt4()),
+ AnotherEval(lm=lf.llms.Gemini())
+ ])
+ ])
+
+ # Run all evaluations in the suite
+ run_info = suite.run('/path/to/my/suite_run')
+ ```
+ """
 
  children: Annotated[
  list[Experiment], 'A list of child experiments.'
@@ -791,7 +814,14 @@ class RunId(pg.Object):
 
 
  class Run(pg.Object, pg.views.html.HtmlTreeView.Extension):
- """A run of an experiment."""
+ """Represents a single run of an experiment.
+
+ A `Run` object holds all the configurations for executing an experiment,
+ such as the experiment definition, input/output directories, and flags
+ controlling the execution behavior (e.g., error handling, checkpointing).
+ It also provides utility methods for accessing run-specific paths and
+ filtering examples for evaluation.
+ """
 
  root_dir: Annotated[
  str,
@@ -818,10 +848,10 @@ class Run(pg.Object, pg.views.html.HtmlTreeView.Extension):
  ] = None
 
  example_ids: Annotated[
- list[int] | None,
+ list[int] | Callable[[Experiment], list[int]] | None,
  (
- 'The example IDs to run. If None, it will run all examples. '
- 'Though '
+ 'The example IDs to run. Or a callable for determining the examples '
+ 'to run based on the experiment. If None, it will run all examples. '
  )
  ] = None
 
@@ -937,10 +967,13 @@ class Run(pg.Object, pg.views.html.HtmlTreeView.Extension):
  """Returns the example IDs to evaluate."""
  if not experiment.is_leaf:
  return set()
- return set(
- self.example_ids if self.example_ids else
- range(1, experiment.num_examples + 1)
- )
+ if self.example_ids is None:
+ return set(range(1, experiment.num_examples + 1))
+ elif isinstance(self.example_ids, Callable):
+ return set(self.example_ids(experiment))
+ else:
+ assert isinstance(self.example_ids, list), self.example_ids
+ return set(self.example_ids)
 
  def examples_to_reprocess(self, experiment: Experiment) -> set[int]:
  """Returns the example IDs to reprocess per request."""
@@ -971,7 +1004,13 @@ class Run(pg.Object, pg.views.html.HtmlTreeView.Extension):
 
 
  class Runner(pg.Object):
- """Interface for experiment runner."""
+ """Interface for experiment runner.
+
+ A runner is responsible for executing the evaluations within an experiment
+ based on the configuration specified in a `Run` object. Different runners
+ can implement different execution strategies, such as sequential or parallel
+ processing of examples and evaluations.
+ """
 
  # Class-level variable for registering the runner.
  NAME = None
@@ -1010,7 +1049,37 @@ class Runner(pg.Object):
 
 
  class Plugin(lf.Component):
- """Base class for experiment plugins."""
+ """Base class for experiment plugins.
+
+ Plugins provide a mechanism to extend the behavior of an experiment run
+ by hooking into various events during the lifecycle of experiment and
+ example execution, such as `on_run_start`, `on_experiment_complete`,
+ `on_example_start`, etc. They can be used for custom logging, monitoring,
+ or result processing.
+ """
+
+ @classmethod
+ def is_per_example(cls) -> bool:
+ """Returns whether the plugin is per example only.
+
+ Per-example plugins can be installed on individual workers when examples
+ are evaluated by multiple processes in parallel.
+ """
+
+ def same_code(method1, method2):
+ return method1.__code__ == method2.__code__
+ return all(
+ same_code(method1, method2)
+ for method1, method2 in [
+ (Plugin.on_run_start, cls.on_run_start),
+ (Plugin.on_run_complete, cls.on_run_complete),
+ (Plugin.on_run_abort, cls.on_run_abort),
+ (Plugin.on_experiment_start, cls.on_experiment_start),
+ (Plugin.on_experiment_skipped, cls.on_experiment_skipped),
+ (Plugin.on_experiment_complete, cls.on_experiment_complete),
+ (Plugin.on_experiment_abort, cls.on_experiment_abort),
+ ]
+ )
 
  def on_run_start(
  self,
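Note: `is_per_example` decides whether a plugin only cares about example-level events by comparing each run/experiment-level hook's `__code__` object against `Plugin`'s own definition. A toy sketch of that override-detection idiom (class names here are illustrative, not langfun's):

```python
class BasePlugin:
  """Stand-in base class with one run-level and one example-level hook."""

  def on_run_start(self):
    pass

  def on_example_complete(self):
    pass


class PerExampleOnly(BasePlugin):

  def on_example_complete(self):  # overrides only an example-level hook
    print('example done')


class RunLevel(BasePlugin):

  def on_run_start(self):  # overrides a run-level hook
    print('run started')


def overrides(cls, base, name: str) -> bool:
  # An inherited method shares the base class' code object; an override does not.
  return getattr(cls, name).__code__ is not getattr(base, name).__code__


assert not overrides(PerExampleOnly, BasePlugin, 'on_run_start')
assert overrides(RunLevel, BasePlugin, 'on_run_start')
```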
@@ -433,5 +433,24 @@ class RunnerTest(unittest.TestCase):
  pass
 
 
+ class PluginTest(unittest.TestCase):
+
+ def test_per_example_only(self):
+
+ class PerExamplePlugin(experiment_lib.Plugin):
+
+ def on_example_complete(self, runner, experiment, example):
+ print('on_example_complete')
+
+ self.assertTrue(PerExamplePlugin.is_per_example())
+
+ class NonPerExamplePlugin(experiment_lib.Plugin):
+
+ def on_experiment_complete(self, runner, experiment):
+ print('on_example_complete')
+
+ self.assertFalse(NonPerExamplePlugin.is_per_example())
+
+
  if __name__ == '__main__':
  unittest.main()
@@ -20,7 +20,15 @@ import pyglove as pg
 
 
  class MetricValue(pg.Object):
- """Base class for metric values."""
+ """Base class for metric values.
+
+ `MetricValue` is the base class for representing aggregated metric values
+ in an evaluation. It accumulates data points from individual examples,
+ each consisting of a value and an optional weight, associated with an example
+ ID. Subclasses must implement `reduce` method to compute a single float value
+ from accumulated data points, and `scalar_repr` to provide a string
+ representation of the reduced value.
+ """
 
  class DataPoint(pg.Object):
  """A data point for a metric value."""
@@ -88,6 +96,14 @@ class MetricValue(pg.Object):
  self.increment_total()
  return self
 
+ def merge_from(self, other: 'MetricValue') -> 'MetricValue':
+ """Merges the values from another metric value."""
+ self._weighted_sum += other._weighted_sum # pylint: disable=protected-access
+ with pg.notify_on_change(False), pg.allow_writable_accessors(True):
+ self.data_points.extend(other.data_points)
+ self.increment_total(other.total)
+ return self
+
  def __gt__(self, other: Union['MetricValue', float]) -> bool:
  if isinstance(other, self.__class__):
  return float(self) > float(other)
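Note: `merge_from` exists so that partial metric values (for example, accumulated by parallel workers or restored from separate checkpoints) can be folded into one aggregate; it combines the weighted sum, the data points, and the total. A hedged usage sketch mirroring the tests added further below (import path per the file list; it assumes a freshly constructed `Rate()` starts empty):

```python
from langfun.core.eval.v2 import metric_values


def combine(shards):
  """Folds per-worker Rate values into a single aggregate."""
  combined = metric_values.Rate()
  for shard in shards:
    combined.merge_from(shard)
  return combined


shard1, shard2 = metric_values.Rate(), metric_values.Rate()
shard1.add(1, 1.0, 1.0, increment_total=True)  # example 1: a hit
shard2.add(2, 0.0, 1.0, increment_total=True)  # example 2: a miss
assert float(combine([shard1, shard2])) == 0.5
```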
@@ -133,7 +149,13 @@ class MetricValue(pg.Object):
 
 
  class Rate(MetricValue):
- """Representing a rate in range [0, 1]."""
+ """Metric value representing a rate in range [0, 1].
+
+ `Rate` is used for metrics that compute a rate, such as accuracy or error
+ rate. The final value is computed as the weighted sum of accumulated values
+ divided by the total number of examples. It's displayed as a percentage
+ (e.g., 90.0%).
+ """
 
  def reduce(self) -> float:
  return self._weighted_sum / self.total
@@ -145,7 +167,13 @@ class Rate(MetricValue):
 
 
  class Average(MetricValue):
- """Average of a aggregated values."""
+ """Metric value representing an average of accumulated values.
+
+ `Average` is used for metrics that compute an average score across examples
+ (e.g., average quality score). The final value is computed as the weighted
+ sum of accumulated values divided by the number of data points.
+ It's displayed as a float with 3 decimal places (e.g., 4.750).
+ """
 
  def reduce(self) -> float:
  if not self.data_points:
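Note: the two docstrings above describe different denominators, which is the whole difference between `Rate` and `Average`. A worked contrast in plain Python (not langfun code): assume three examples were counted toward the total but only two produced data points.

```python
# (example_id, value, weight) triples accumulated via add().
data_points = [(1, 1.0, 1.0), (2, 0.5, 1.0)]
total = 3  # examples counted overall, including the one without a data point

weighted_sum = sum(value * weight for _, value, weight in data_points)  # 1.5

rate = weighted_sum / total                # 1.5 / 3 = 0.5  -> rendered as 50.0%
average = weighted_sum / len(data_points)  # 1.5 / 2 = 0.75 -> rendered as 0.750
```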
@@ -51,6 +51,22 @@ class RateTest(unittest.TestCase):
  self.assertEqual(rate.total, 0)
  self.assertTrue(math.isnan(float(rate)))
 
+ def test_merge_from(self):
+ rate1 = metric_values.Rate()
+ rate1.add(1, 1.0, 1.0, increment_total=True)
+ rate2 = metric_values.Rate()
+ rate2.add(2, 0.0, 1.0, increment_total=True)
+ rate1.merge_from(rate2)
+ self.assertEqual(rate1.total, 2)
+ self.assertEqual(float(rate1), 0.5)
+ self.assertEqual(
+ rate1.data_points,
+ [
+ metric_values.MetricValue.DataPoint(1, 1.0, 1.0),
+ metric_values.MetricValue.DataPoint(2, 0.0, 1.0),
+ ],
+ )
+
 
  class AverageTest(unittest.TestCase):
 
@@ -75,6 +91,22 @@ class AverageTest(unittest.TestCase):
  average.reset()
  self.assertEqual(average.total, 0)
 
+ def test_merge_from(self):
+ avg1 = metric_values.Average()
+ avg1.add(1, 1.0, 0.5, increment_total=True)
+ avg2 = metric_values.Average()
+ avg2.add(2, 0.0, 1.0, increment_total=True)
+ avg1.merge_from(avg2)
+ self.assertEqual(avg1.total, 2)
+ self.assertEqual(float(avg1), 0.25)
+ self.assertEqual(
+ avg1.data_points,
+ [
+ metric_values.MetricValue.DataPoint(1, 1.0, 0.5),
+ metric_values.MetricValue.DataPoint(2, 0.0, 1.0),
+ ],
+ )
+
 
  if __name__ == '__main__':
  unittest.main()
@@ -29,7 +29,15 @@ Average = metric_values.Average
 
 
  class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
- """Interface for an evaluation metric."""
+ """Interface for an evaluation metric.
+
+ A metric is used to evaluate the quality of the outputs produced by an
+ evaluation. It works by auditing each processed example via its `audit`
+ method, which in turn calls the user-overridable `_audit` method to perform
+ metric-specific logic and update metric values. Metrics can compute multiple
+ values (e.g., precision, recall, F1 score) which are exposed via the
+ `values` method.
+ """
 
  name: Annotated[
  str,
@@ -44,24 +52,43 @@ class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
  self._label_group = None
  self._lock = threading.Lock()
 
- def audit(self, example: example_lib.Example) -> dict[str, Any]:
- """Audits a processed example and returns metric metadata for it."""
- # NOTE(daiyip): the metric values are being updated concurrently, so we
- # uses a lock to avoid race condition. We might consider relaxing the lock
- # later if metric auditing becomes a bottleneck.
- with self._lock:
- for v in self.values():
- v.increment_total()
+ def update(
+ self,
+ example: example_lib.Example,
+ force_recompute: bool = False
+ ) -> dict[str, Any]:
+ """Updates metric values with a processed example.
 
- metadata = self._audit(example)
+ Args:
+ example: The processed example.
+ force_recompute: Whether to force recompute the metric metadata even if
+ they are already present.
 
- self._update_view()
- return metadata
+ Returns:
+ A dict of metric metadata.
+ """
+ if (force_recompute
+ or example.metric_metadata is None
+ or self.name not in example.metric_metadata):
+ metadata = self.compute_metric_metadata(example)
+ else:
+ metadata = example.metric_metadata[self.name]
+ self.update_metric_values(example.id, metadata)
+ self._update_view()
+ return metadata
 
  @abc.abstractmethod
- def _audit(self, example: example_lib.Example) -> dict[str, Any]:
+ def compute_metric_metadata(
+ self, example: example_lib.Example
+ ) -> dict[str, Any]:
  """Subclasses should override this method to implement the metric logic."""
 
+ @abc.abstractmethod
+ def update_metric_values(
+ self, example_id: int, metric_metadata: dict[str, Any]
+ ) -> None:
+ """Update metric values based on metric metadata."""
+
  @abc.abstractmethod
  def values(self) -> list[metric_values.MetricValue]:
  """Returns all the values computed by this metric."""
@@ -71,6 +98,12 @@ class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
  for v in self.values():
  v.reset()
 
+ def merge_from(self, other: 'Metric') -> 'Metric':
+ """Merges the values from another metric."""
+ for v1, v2 in zip(self.values(), other.values()):
+ v1.merge_from(v2)
+ return self
+
  def _update_view(self):
  """Refreshes the metric values."""
  if self._label_group is None:
@@ -169,7 +202,15 @@ class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
 
 
  class MetricBase(Metric):
- """Base class for common metrics."""
+ """Base class for common metrics.
+
+ `MetricBase` provides common functionalities for metrics, such as automatic
+ error counting based on whether an example has an error during evaluation.
+ It distinguishes between Object-Oriented Programming (OOP) errors
+ (e.g. `MappingError` during structured output generation) and other errors.
+ Subclasses should implement `_audit_processed` for metric computation on
+ successfully processed examples.
+ """
 
  oop_errors: Rate | None = Rate()
  non_oop_errors: Rate | None = Rate()
@@ -183,27 +224,67 @@ class MetricBase(Metric):
  super().reset()
  self._error_breakdown = collections.defaultdict(list)
 
- def _audit(self, example: example_lib.Example) -> dict[str, Any]:
- """Audits the evaluation example after processing."""
+ def compute_metric_metadata(
+ self, example: example_lib.Example
+ ) -> dict[str, Any]:
+ """Computes the metric metadata for the example."""
  if example.error is None:
- return self._audit_processed(example)
+ return self._compute_metric_metadata(example)
+ return self._compute_metric_metadata_with_processing_error(example)
+
+ def update_metric_values(
+ self,
+ example_id: int,
+ metric_metadata: dict[str, Any]
+ ) -> None:
+ """Collects the metric metadata."""
+ # NOTE(daiyip): the metric values are being updated concurrently, so we
+ # uses a lock to avoid race condition. We might consider relaxing the lock
+ # later if metric auditing becomes a bottleneck.
+ with self._lock:
+ for v in self.values():
+ v.increment_total()
+
+ if 'error' in metric_metadata:
+ self._update_metric_values_with_processing_error(
+ example_id, metric_metadata
+ )
  else:
- return self._audit_error(example)
+ self._update_metric_values(example_id, metric_metadata)
+
+ @abc.abstractmethod
+ def _compute_metric_metadata(
+ self,
+ example: example_lib.Example
+ ) -> dict[str, Any]:
+ """Computes the metric metadata for the example."""
 
- def _audit_error(self, example: example_lib.Example) -> dict[str, Any]:
+ def _compute_metric_metadata_with_processing_error(
+ self,
+ example: example_lib.Example
+ ) -> dict[str, Any]:
  """Audits the evaluation example after processing."""
  assert example.error is not None
- tag = example.error.tag
- if tag.startswith('MappingError'):
- self.oop_errors.add(example.id, 1)
- else:
- self.non_oop_errors.add(example.id, 1)
- self._error_breakdown[tag].append(example.id)
- return dict(error=tag)
+ return dict(error=example.error.tag)
 
  @abc.abstractmethod
- def _audit_processed(self, example: example_lib.Example) -> dict[str, Any]:
- """Audits the evaluation example after processing."""
+ def _update_metric_values(self, metadata: dict[str, Any]) -> None:
+ """Update metric values based metric metadata."""
+
+ def _update_metric_values_with_processing_error(
+ self,
+ example_id: int,
+ metric_metadata: dict[str, Any]
+ ) -> None:
+ """Updates metric values with processing error."""
+ error_tag = metric_metadata.get('error')
+ assert error_tag is not None, (example_id, metric_metadata)
+ self._error_breakdown[error_tag].append(example_id)
+ if error_tag.startswith('MappingError'):
+ self.oop_errors.add(example_id, 1)
+ else:
+ self.non_oop_errors.add(example_id, 1)
+ self._error_breakdown[error_tag].append(example_id)
 
  def _oop_errors_breakdown(self) -> str | None:
  """Returns the OOP error breakdown as a string."""
@@ -229,7 +310,13 @@ class MetricBase(Metric):
 
 
  class Match(MetricBase):
- """Metric for matching outputs against groundtruth."""
+ """Metric for matching outputs against ground truth.
+
+ This metric computes match and mismatch rates by comparing the output of
+ an example with its ground truth. By default, it looks for a `groundtruth`
+ attribute in `example.input` for comparison. Users can customize this behavior
+ by subclassing `Match` and overriding the `match` method.
+ """
 
  name = 'match'
  matches: Rate = Rate()
@@ -257,20 +344,30 @@ class Match(MetricBase):
  )
  return pg.eq(output, groundtruth)
 
- def _audit_processed(self, example: example_lib.Example) -> dict[str, Any]:
- """Audits the evaluation example after processing."""
+ def _compute_metric_metadata(
+ self, example: example_lib.Example
+ ) -> dict[str, Any]:
+ """Computes the metric metadata for the example."""
  metadata = {}
- is_match = self.match(example.input, example.output)
- if isinstance(is_match, tuple):
- is_match, metadata = is_match
- if is_match:
- self.matches.add(example.id, 1)
- metadata['match'] = True
- else:
- self.mismatches.add(example.id, 1)
- metadata['mismatch'] = True
+ is_correct = self.match(example.input, example.output)
+ if isinstance(is_correct, tuple):
+ is_correct, metadata = is_correct
+
+ metadata['is_correct'] = is_correct
  return metadata
 
+ def _update_metric_values(
+ self, example_id: int, metadata: dict[str, Any]
+ ) -> None:
+ """Update metric values based metric metadata."""
+ is_correct = metadata.get('is_correct')
+ assert is_correct is not None, (example_id, metadata)
+ if is_correct:
+ self.matches.add(example_id, 1)
+ else:
+ assert not is_correct
+ self.mismatches.add(example_id, 1)
+
  def values(self) -> list[metric_values.MetricValue]:
  """Returns all the values computed by this metric."""
  return [
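Note: per the `Match` docstring above, custom comparison logic goes into an overridden `match` method, which (as `_compute_metric_metadata` shows) may return either a bool or a `(bool, metadata)` tuple. A hedged sketch of such a subclass; the import path follows the file list, and the `groundtruth` field on the input is an assumption:

```python
from langfun.core.eval.v2 import metrics


class CaseInsensitiveMatch(metrics.Match):
  """Counts a match when output and groundtruth agree ignoring case/whitespace."""

  def match(self, example_input, output):
    groundtruth = getattr(example_input, 'groundtruth', None)  # assumed field
    is_correct = (
        isinstance(output, str)
        and isinstance(groundtruth, str)
        and output.strip().lower() == groundtruth.strip().lower()
    )
    # Returning (bool, metadata) attaches extra metric metadata to the example.
    return is_correct, dict(groundtruth=groundtruth)
```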
@@ -302,7 +399,14 @@ class Match(MetricBase):
 
 
  class Score(MetricBase):
- """Base class for scoring."""
+ """Base class for scoring metrics.
+
+ `Score` is a base class for metrics that assign a numerical score to each
+ example's output (e.g., evaluating quality on a scale of 1-5).
+ It automatically computes the average score across all examples.
+ Subclasses must implement the `score` method to define how an example
+ should be scored.
+ """
 
  name = 'score'
  average_score: Average = Average()
@@ -322,16 +426,25 @@ class Score(MetricBase):
  A float score. Or a tuple of (score, metadata).
  """
 
- def _audit_processed(self, example: example_lib.Example) -> dict[str, Any]:
- """Audits the evaluation example after processing."""
+ def _compute_metric_metadata(
+ self, example: example_lib.Example
+ ) -> dict[str, Any]:
+ """Computes the metric metadata for the example."""
  metadata = {}
  score = self.score(example.input, example.output)
  if isinstance(score, tuple):
  score, metadata = score
- self.average_score.add(example.id, score)
  metadata['score'] = score
  return metadata
 
+ def _update_metric_values(
+ self, example_id: int, metadata: dict[str, Any]
+ ) -> None:
+ """Update metric values based metric metadata."""
+ score = metadata.get('score')
+ assert score is not None, (example_id, metadata)
+ self.average_score.add(example_id, score)
+
  def values(self) -> list[metric_values.MetricValue]:
  """Returns all the values computed by this metric."""
  return [