langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512150805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +3 -0
  29. langfun/core/eval/v2/checkpointing.py +148 -46
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +102 -19
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +95 -20
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +88 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +73 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +350 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +213 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +14 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +90 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +52 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +78 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +78 -4
  104. langfun/core/modalities/mime_test.py +59 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512150805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/top_level.txt +0 -0
@@ -20,7 +20,15 @@ import pyglove as pg
 
 
 class MetricValue(pg.Object):
-  """Base class for metric values."""
+  """Base class for metric values.
+
+  `MetricValue` is the base class for representing aggregated metric values
+  in an evaluation. It accumulates data points from individual examples,
+  each consisting of a value and an optional weight, associated with an example
+  ID. Subclasses must implement `reduce` method to compute a single float value
+  from accumulated data points, and `scalar_repr` to provide a string
+  representation of the reduced value.
+  """
 
 
 class DataPoint(pg.Object):
   """A data point for a metric value."""
@@ -88,6 +96,14 @@ class MetricValue(pg.Object):
       self.increment_total()
     return self
 
+  def merge_from(self, other: 'MetricValue') -> 'MetricValue':
+    """Merges the values from another metric value."""
+    self._weighted_sum += other._weighted_sum  # pylint: disable=protected-access
+    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
+      self.data_points.extend(other.data_points)
+    self.increment_total(other.total)
+    return self
+
   def __gt__(self, other: Union['MetricValue', float]) -> bool:
     if isinstance(other, self.__class__):
       return float(self) > float(other)
@@ -133,7 +149,13 @@ class MetricValue(pg.Object):
 
 
 class Rate(MetricValue):
-  """Representing a rate in range [0, 1]."""
+  """Metric value representing a rate in range [0, 1].
+
+  `Rate` is used for metrics that compute a rate, such as accuracy or error
+  rate. The final value is computed as the weighted sum of accumulated values
+  divided by the total number of examples. It's displayed as a percentage
+  (e.g., 90.0%).
+  """
 
   def reduce(self) -> float:
     return self._weighted_sum / self.total
@@ -145,7 +167,13 @@ class Rate(MetricValue):
 
 
 class Average(MetricValue):
-  """Average of a aggregated values."""
+  """Metric value representing an average of accumulated values.
+
+  `Average` is used for metrics that compute an average score across examples
+  (e.g., average quality score). The final value is computed as the weighted
+  sum of accumulated values divided by the number of data points.
+  It's displayed as a float with 3 decimal places (e.g., 4.750).
+  """
 
   def reduce(self) -> float:
     if not self.data_points:
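The practical difference between `Rate` and `Average` is the denominator used by `reduce`: `Rate` divides the weighted sum by `total` (every example counted, including ones that produced no data point), while `Average` divides by the number of accumulated data points. A minimal sketch of that distinction, using the `add`/`increment_total` calls that also appear in the tests below; the expected values follow from the reduce rules stated above:

from langfun.core.eval.v2 import metric_values

rate = metric_values.Rate()
avg = metric_values.Average()

# Two examples each contribute a data point: (example_id, value, weight).
for example_id, value in [(1, 1.0), (2, 0.0)]:
  rate.add(example_id, value, 1.0, increment_total=True)
  avg.add(example_id, value, 1.0, increment_total=True)

# A third example is counted toward the total but contributes no data point,
# e.g. one that failed before it could be scored.
rate.increment_total()
avg.increment_total()

print(float(rate))  # 1.0 / 3 ~= 0.333: weighted sum over the total count.
print(float(avg))   # 1.0 / 2 = 0.5: weighted sum over the data points only.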
@@ -51,6 +51,22 @@ class RateTest(unittest.TestCase):
     self.assertEqual(rate.total, 0)
     self.assertTrue(math.isnan(float(rate)))
 
+  def test_merge_from(self):
+    rate1 = metric_values.Rate()
+    rate1.add(1, 1.0, 1.0, increment_total=True)
+    rate2 = metric_values.Rate()
+    rate2.add(2, 0.0, 1.0, increment_total=True)
+    rate1.merge_from(rate2)
+    self.assertEqual(rate1.total, 2)
+    self.assertEqual(float(rate1), 0.5)
+    self.assertEqual(
+        rate1.data_points,
+        [
+            metric_values.MetricValue.DataPoint(1, 1.0, 1.0),
+            metric_values.MetricValue.DataPoint(2, 0.0, 1.0),
+        ],
+    )
+
 
 class AverageTest(unittest.TestCase):
 
@@ -75,6 +91,22 @@ class AverageTest(unittest.TestCase):
     average.reset()
     self.assertEqual(average.total, 0)
 
+  def test_merge_from(self):
+    avg1 = metric_values.Average()
+    avg1.add(1, 1.0, 0.5, increment_total=True)
+    avg2 = metric_values.Average()
+    avg2.add(2, 0.0, 1.0, increment_total=True)
+    avg1.merge_from(avg2)
+    self.assertEqual(avg1.total, 2)
+    self.assertEqual(float(avg1), 0.25)
+    self.assertEqual(
+        avg1.data_points,
+        [
+            metric_values.MetricValue.DataPoint(1, 1.0, 0.5),
+            metric_values.MetricValue.DataPoint(2, 0.0, 1.0),
+        ],
+    )
+
 
 if __name__ == '__main__':
   unittest.main()
@@ -29,7 +29,15 @@ Average = metric_values.Average
 
 
 class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
-  """Interface for an evaluation metric."""
+  """Interface for an evaluation metric.
+
+  A metric is used to evaluate the quality of the outputs produced by an
+  evaluation. It works by auditing each processed example via its `audit`
+  method, which in turn calls the user-overridable `_audit` method to perform
+  metric-specific logic and update metric values. Metrics can compute multiple
+  values (e.g., precision, recall, F1 score) which are exposed via the
+  `values` method.
+  """
 
   name: Annotated[
       str,
@@ -44,24 +52,43 @@ class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
     self._label_group = None
     self._lock = threading.Lock()
 
-  def audit(self, example: example_lib.Example) -> dict[str, Any]:
-    """Audits a processed example and returns metric metadata for it."""
-    # NOTE(daiyip): the metric values are being updated concurrently, so we
-    # uses a lock to avoid race condition. We might consider relaxing the lock
-    # later if metric auditing becomes a bottleneck.
-    with self._lock:
-      for v in self.values():
-        v.increment_total()
+  def update(
+      self,
+      example: example_lib.Example,
+      force_recompute: bool = False
+  ) -> dict[str, Any]:
+    """Updates metric values with a processed example.
 
-    metadata = self._audit(example)
+    Args:
+      example: The processed example.
+      force_recompute: Whether to force recompute the metric metadata even if
+        they are already present.
 
-    self._update_view()
-    return metadata
+    Returns:
+      A dict of metric metadata.
+    """
+    if (force_recompute
+        or example.metric_metadata is None
+        or self.name not in example.metric_metadata):
+      metadata = self.compute_metric_metadata(example)
+    else:
+      metadata = example.metric_metadata[self.name]
+    self.update_metric_values(example.id, metadata)
+    self._update_view()
+    return metadata
 
   @abc.abstractmethod
-  def _audit(self, example: example_lib.Example) -> dict[str, Any]:
+  def compute_metric_metadata(
+      self, example: example_lib.Example
+  ) -> dict[str, Any]:
     """Subclasses should override this method to implement the metric logic."""
 
+  @abc.abstractmethod
+  def update_metric_values(
+      self, example_id: int, metric_metadata: dict[str, Any]
+  ) -> None:
+    """Update metric values based on metric metadata."""
+
   @abc.abstractmethod
   def values(self) -> list[metric_values.MetricValue]:
     """Returns all the values computed by this metric."""
@@ -71,6 +98,12 @@ class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
     for v in self.values():
       v.reset()
 
+  def merge_from(self, other: 'Metric') -> 'Metric':
+    """Merges the values from another metric."""
+    for v1, v2 in zip(self.values(), other.values()):
+      v1.merge_from(v2)
+    return self
+
   def _update_view(self):
     """Refreshes the metric values."""
     if self._label_group is None:
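Taken together, `update` and `merge_from` split metric bookkeeping into a pure computation step (`compute_metric_metadata`), a state-mutation step (`update_metric_values`), and an aggregation step. A sketch of how a sharded evaluation could use this, based only on the methods shown in this diff; the shard/coordinator framing is illustrative, not the runners' actual code:

import pyglove as pg
from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import metrics

Example = example_lib.Example

# Each shard owns its own metric instance and processes a slice of examples.
shard_a = metrics.Match()
shard_a.update(Example(id=1, input=pg.Dict(groundtruth=1), output=1))

shard_b = metrics.Match()
shard_b.update(Example(id=2, input=pg.Dict(groundtruth=1), output=2))

# A coordinator folds the shard results into a single metric.
combined = metrics.Match()
combined.merge_from(shard_a).merge_from(shard_b)
assert float(combined.matches) == 0.5

# If an example already carries metadata for this metric (keyed by its name,
# 'match' here), `update` replays it instead of recomputing; passing
# force_recompute=True bypasses the cached metadata.
combined.update(
    Example(
        id=3,
        input=pg.Dict(groundtruth=1),
        output=2,
        metric_metadata=dict(match=dict(is_correct=False)),
    )
)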
@@ -169,7 +202,15 @@ class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
 
 
 class MetricBase(Metric):
-  """Base class for common metrics."""
+  """Base class for common metrics.
+
+  `MetricBase` provides common functionalities for metrics, such as automatic
+  error counting based on whether an example has an error during evaluation.
+  It distinguishes between Object-Oriented Programming (OOP) errors
+  (e.g. `MappingError` during structured output generation) and other errors.
+  Subclasses should implement `_audit_processed` for metric computation on
+  successfully processed examples.
+  """
 
   oop_errors: Rate | None = Rate()
   non_oop_errors: Rate | None = Rate()
@@ -183,27 +224,67 @@ class MetricBase(Metric):
     super().reset()
     self._error_breakdown = collections.defaultdict(list)
 
-  def _audit(self, example: example_lib.Example) -> dict[str, Any]:
-    """Audits the evaluation example after processing."""
+  def compute_metric_metadata(
+      self, example: example_lib.Example
+  ) -> dict[str, Any]:
+    """Computes the metric metadata for the example."""
     if example.error is None:
-      return self._audit_processed(example)
+      return self._compute_metric_metadata(example)
+    return self._compute_metric_metadata_with_processing_error(example)
+
+  def update_metric_values(
+      self,
+      example_id: int,
+      metric_metadata: dict[str, Any]
+  ) -> None:
+    """Collects the metric metadata."""
+    # NOTE(daiyip): the metric values are being updated concurrently, so we
+    # uses a lock to avoid race condition. We might consider relaxing the lock
+    # later if metric auditing becomes a bottleneck.
+    with self._lock:
+      for v in self.values():
+        v.increment_total()
+
+    if 'error' in metric_metadata:
+      self._update_metric_values_with_processing_error(
+          example_id, metric_metadata
+      )
     else:
-      return self._audit_error(example)
+      self._update_metric_values(example_id, metric_metadata)
+
+  @abc.abstractmethod
+  def _compute_metric_metadata(
+      self,
+      example: example_lib.Example
+  ) -> dict[str, Any]:
+    """Computes the metric metadata for the example."""
 
-  def _audit_error(self, example: example_lib.Example) -> dict[str, Any]:
+  def _compute_metric_metadata_with_processing_error(
+      self,
+      example: example_lib.Example
+  ) -> dict[str, Any]:
     """Audits the evaluation example after processing."""
     assert example.error is not None
-    tag = example.error.tag
-    if tag.startswith('MappingError'):
-      self.oop_errors.add(example.id, 1)
-    else:
-      self.non_oop_errors.add(example.id, 1)
-    self._error_breakdown[tag].append(example.id)
-    return dict(error=tag)
+    return dict(error=example.error.tag)
 
   @abc.abstractmethod
-  def _audit_processed(self, example: example_lib.Example) -> dict[str, Any]:
-    """Audits the evaluation example after processing."""
+  def _update_metric_values(self, metadata: dict[str, Any]) -> None:
+    """Update metric values based metric metadata."""
+
+  def _update_metric_values_with_processing_error(
+      self,
+      example_id: int,
+      metric_metadata: dict[str, Any]
+  ) -> None:
+    """Updates metric values with processing error."""
+    error_tag = metric_metadata.get('error')
+    assert error_tag is not None, (example_id, metric_metadata)
+    self._error_breakdown[error_tag].append(example_id)
+    if error_tag.startswith('MappingError'):
+      self.oop_errors.add(example_id, 1)
+    else:
+      self.non_oop_errors.add(example_id, 1)
+    self._error_breakdown[error_tag].append(example_id)
 
   def _oop_errors_breakdown(self) -> str | None:
     """Returns the OOP error breakdown as a string."""
@@ -229,7 +310,13 @@ class MetricBase(Metric):
 
 
 class Match(MetricBase):
-  """Metric for matching outputs against groundtruth."""
+  """Metric for matching outputs against ground truth.
+
+  This metric computes match and mismatch rates by comparing the output of
+  an example with its ground truth. By default, it looks for a `groundtruth`
+  attribute in `example.input` for comparison. Users can customize this behavior
+  by subclassing `Match` and overriding the `match` method.
+  """
 
   name = 'match'
   matches: Rate = Rate()
@@ -257,20 +344,30 @@ class Match(MetricBase):
     )
     return pg.eq(output, groundtruth)
 
-  def _audit_processed(self, example: example_lib.Example) -> dict[str, Any]:
-    """Audits the evaluation example after processing."""
+  def _compute_metric_metadata(
+      self, example: example_lib.Example
+  ) -> dict[str, Any]:
+    """Computes the metric metadata for the example."""
     metadata = {}
-    is_match = self.match(example.input, example.output)
-    if isinstance(is_match, tuple):
-      is_match, metadata = is_match
-    if is_match:
-      self.matches.add(example.id, 1)
-      metadata['match'] = True
-    else:
-      self.mismatches.add(example.id, 1)
-      metadata['mismatch'] = True
+    is_correct = self.match(example.input, example.output)
+    if isinstance(is_correct, tuple):
+      is_correct, metadata = is_correct
+
+    metadata['is_correct'] = is_correct
     return metadata
 
+  def _update_metric_values(
+      self, example_id: int, metadata: dict[str, Any]
+  ) -> None:
+    """Update metric values based metric metadata."""
+    is_correct = metadata.get('is_correct')
+    assert is_correct is not None, (example_id, metadata)
+    if is_correct:
+      self.matches.add(example_id, 1)
+    else:
+      assert not is_correct
+      self.mismatches.add(example_id, 1)
+
   def values(self) -> list[metric_values.MetricValue]:
     """Returns all the values computed by this metric."""
     return [
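As the docstring above notes, custom comparison logic goes into an overridden `match`, which may return either a bool or a `(bool, metadata)` tuple. A hypothetical subclass as a sketch (the class name and extra metadata key are illustrative):

import pyglove as pg
from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import metrics


class CaseInsensitiveMatch(metrics.Match):
  """Hypothetical Match subclass that compares outputs case-insensitively."""

  def match(self, example_input, output):
    expected = example_input.groundtruth
    is_correct = str(output).lower() == str(expected).lower()
    # Extra keys are returned alongside the 'is_correct' flag in the metadata.
    return is_correct, dict(expected=expected)


m = CaseInsensitiveMatch()
metadata = m.update(
    example_lib.Example(id=1, input=pg.Dict(groundtruth='Paris'), output='paris')
)
assert metadata == dict(expected='Paris', is_correct=True)
assert float(m.matches) == 1.0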
@@ -302,7 +399,14 @@ class Match(MetricBase):
 
 
 class Score(MetricBase):
-  """Base class for scoring."""
+  """Base class for scoring metrics.
+
+  `Score` is a base class for metrics that assign a numerical score to each
+  example's output (e.g., evaluating quality on a scale of 1-5).
+  It automatically computes the average score across all examples.
+  Subclasses must implement the `score` method to define how an example
+  should be scored.
+  """
 
   name = 'score'
   average_score: Average = Average()
@@ -322,16 +426,25 @@ class Score(MetricBase):
       A float score. Or a tuple of (score, metadata).
     """
 
-  def _audit_processed(self, example: example_lib.Example) -> dict[str, Any]:
-    """Audits the evaluation example after processing."""
+  def _compute_metric_metadata(
+      self, example: example_lib.Example
+  ) -> dict[str, Any]:
+    """Computes the metric metadata for the example."""
     metadata = {}
     score = self.score(example.input, example.output)
     if isinstance(score, tuple):
       score, metadata = score
-    self.average_score.add(example.id, score)
     metadata['score'] = score
     return metadata
 
+  def _update_metric_values(
+      self, example_id: int, metadata: dict[str, Any]
+  ) -> None:
+    """Update metric values based metric metadata."""
+    score = metadata.get('score')
+    assert score is not None, (example_id, metadata)
+    self.average_score.add(example_id, score)
+
   def values(self) -> list[metric_values.MetricValue]:
     """Returns all the values computed by this metric."""
     return [
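`Score` follows the same pattern with a numeric `score` instead of a boolean, and may likewise return a `(score, metadata)` tuple. A hypothetical subclass as a sketch (names are illustrative):

import pyglove as pg
from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import metrics


class BrevityScore(metrics.Score):
  """Hypothetical Score subclass: shorter outputs score higher."""

  def score(self, example_input, output):
    length = len(str(output))
    return 1.0 / (1 + length), dict(output_length=length)


m = BrevityScore()
m.update(example_lib.Example(id=1, input=pg.Dict(), output='ok'))
m.update(example_lib.Example(id=2, input=pg.Dict(), output='a considerably longer answer'))
print(float(m.average_score))  # Mean of the two per-example scores.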
@@ -25,15 +25,22 @@ class MatchTest(unittest.TestCase):
   def test_basic(self):
     m = metrics.Match()  # pylint: disable=invalid-name
     self.assertEqual(
-        m.audit(Example(id=1, input=pg.Dict(groundtruth=1), output=1)),
-        dict(match=True)
+        m.update(Example(id=1, input=pg.Dict(groundtruth=1), output=1)),
+        dict(is_correct=True)
     )
     self.assertEqual(
-        m.audit(Example(id=2, input=pg.Dict(groundtruth=1), output=2)),
-        dict(mismatch=True)
+        m.update(
+            Example(
+                id=2,
+                input=pg.Dict(groundtruth=1),
+                output=2,
+                metric_metadata=dict(match=dict(is_correct=False, x=1))
+            )
+        ),
+        dict(is_correct=False, x=1)
     )
     self.assertEqual(
-        m.audit(
+        m.update(
             Example(
                 id=3,
                 input=pg.Dict(groundtruth=1),
@@ -47,7 +54,7 @@ class MatchTest(unittest.TestCase):
         dict(error='ValueError')
     )
     self.assertEqual(
-        m.audit(
+        m.update(
             Example(
                 id=3,
                 input=pg.Dict(groundtruth=1),
@@ -80,7 +87,7 @@ class MatchTest(unittest.TestCase):
   def test_bad_case(self):
     m = metrics.Match()  # pylint: disable=invalid-name
     with self.assertRaisesRegex(ValueError, '`groundtruth` is not present'):
-      m.audit(Example(id=1, input=pg.Dict(x=1), output=1))
+      m.update(Example(id=1, input=pg.Dict(x=1), output=1))
 
   def test_custom_metadata(self):
 
@@ -90,22 +97,36 @@ class MatchTest(unittest.TestCase):
 
     m = MyMatch()  # pylint: disable=invalid-name
     self.assertEqual(
-        m.audit(Example(id=1, input=pg.Dict(x=1), output=1)),
-        dict(match=True, x=1)
+        m.update(Example(id=1, input=pg.Dict(x=1), output=1)),
+        dict(is_correct=True, x=1)
     )
     self.assertEqual(m.matches, 1.0)
 
   def test_html_view(self):
     m = metrics.Match()  # pylint: disable=invalid-name
-    m.audit(Example(id=1, input=pg.Dict(groundtruth=1), output=1))
+    m.update(Example(id=1, input=pg.Dict(groundtruth=1), output=1))
     self.assertIn(
         '100.0%',
         m.to_html().content,
     )
     with pg.views.html.controls.HtmlControl.track_scripts() as scripts:
-      m.audit(Example(id=2, input=pg.Dict(groundtruth=1), output=2))
+      m.update(Example(id=2, input=pg.Dict(groundtruth=1), output=2))
     self.assertEqual(len(scripts), 12)
 
+  def test_merge_from(self):
+    m1 = metrics.Match()
+    m1.update(Example(id=1, input=pg.Dict(groundtruth=1), output=1))
+    m2 = metrics.Match()
+    m2.update(Example(id=2, input=pg.Dict(groundtruth=1), output=2))
+    m1.merge_from(m2)
+    self.assertEqual(m1.matches, 0.5)
+    self.assertEqual(m1.mismatches, 0.5)
+    self.assertEqual(m1.oop_errors, 0.0)
+    self.assertEqual(m1.non_oop_errors, 0.0)
+    self.assertEqual(m1.matches.total, 2)
+    self.assertEqual(len(m1.matches.data_points), 1)
+    self.assertEqual(len(m1.mismatches.data_points), 1)
+
 
 class ScoreTest(unittest.TestCase):
 
@@ -118,15 +139,15 @@ class ScoreTest(unittest.TestCase):
 
     m = MyScore()  # pylint: disable=invalid-name
     self.assertEqual(
-        m.audit(Example(id=1, input=pg.Dict(x=1), output=1)),
+        m.update(Example(id=1, input=pg.Dict(x=1), output=1)),
         dict(score=1 * 1)
     )
     self.assertEqual(
-        m.audit(Example(id=2, input=pg.Dict(x=2), output=2)),
+        m.update(Example(id=2, input=pg.Dict(x=2), output=2)),
         dict(score=2 * 2)
     )
     self.assertEqual(
-        m.audit(
+        m.update(
             Example(
                 id=3,
                 input=pg.Dict(x=1),
@@ -140,7 +161,7 @@ class ScoreTest(unittest.TestCase):
         dict(error='ValueError')
     )
     self.assertEqual(
-        m.audit(
+        m.update(
             Example(
                 id=3,
                 input=pg.Dict(x=1),
@@ -176,7 +197,7 @@ class ScoreTest(unittest.TestCase):
 
     m = MyScore()  # pylint: disable=invalid-name
     self.assertEqual(
-        m.audit(Example(id=1, input=pg.Dict(x=1), output=1)),
+        m.update(Example(id=1, input=pg.Dict(x=1), output=1)),
         dict(score=1 * 1, x=1)
     )
     self.assertEqual(m.average_score, 1.0)
@@ -189,13 +210,13 @@ class ScoreTest(unittest.TestCase):
       return example_input.x * output
 
     m = MyScore()  # pylint: disable=invalid-name
-    m.audit(Example(id=1, input=pg.Dict(x=1), output=2))
+    m.update(Example(id=1, input=pg.Dict(x=1), output=2))
    self.assertIn(
        '2.000',
        m.to_html().content,
    )
    with pg.views.html.controls.HtmlControl.track_scripts() as scripts:
-      m.audit(Example(id=2, input=pg.Dict(x=1), output=2))
+      m.update(Example(id=2, input=pg.Dict(x=1), output=2))
     self.assertEqual(len(scripts), 9)
 
 
@@ -21,7 +21,15 @@ import pyglove as pg
 
 
 class Progress(pg.Object, pg.views.HtmlTreeView.Extension):
-  """Evaluation progress."""
+  """Represents and tracks the progress of an evaluation.
+
+  The `Progress` class maintains counts of processed, failed, and skipped
+  items in an evaluation, along with timing information (start time, stop time,
+  duration) and an execution summary. It provides properties to check the
+  status of the evaluation (e.g., `is_started`, `is_completed`) and methods
+  to update progress as items are evaluated.
+  It also supports HTML rendering as a progress bar for visualization.
+  """
 
   num_total: Annotated[
       int | None,
@@ -84,6 +92,7 @@ class Progress(pg.Object, pg.views.HtmlTreeView.Extension):
         stop_time=None,
         execution_summary=pg.object_utils.TimeIt.StatusSummary(),
     )
+    self._progress_bar = None
 
   @property
   def num_completed(self) -> int:
@@ -216,6 +225,27 @@ class Progress(pg.Object, pg.views.HtmlTreeView.Extension):
     """Overrides nondefault values so volatile values are not included."""
     return dict()
 
+  def merge_from(self, other: 'Progress') -> None:
+    """Merges the progress from another progress."""
+    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
+      if other.start_time is not None and (
+          self.start_time is None or self.start_time > other.start_time):
+        self.start_time = other.start_time
+
+      if other.stop_time is not None and (
+          self.stop_time is None or self.stop_time < other.stop_time):
+        self.stop_time = other.stop_time
+
+      if other.num_total is not None:
+        if self.num_total is None:
+          self.num_total = other.num_total
+        else:
+          assert self.num_total == other.num_total, (self, other)
+      self.num_processed += other.num_processed
+      self.num_failed += other.num_failed
+      self.num_skipped += other.num_skipped
+      self.execution_summary.aggregate(other.execution_summary.breakdown)
+
   #
   # HTML view.
   #
@@ -77,6 +77,33 @@ class ProgressTest(unittest.TestCase):
     self.assertTrue(p.is_stopped)
     self.assertIsNotNone(p.stop_time_str)
 
+  def test_merge_from(self):
+    p1 = Progress()
+    p1.start(10)
+    p1.increment_processed()
+    p1.increment_failed()
+    p1.stop()
+
+    p2 = Progress()
+    p2.start(10)
+    p2.increment_skipped()
+    p2.stop()
+
+    with pg.allow_writable_accessors(True):
+      p1.start_time = 2.0
+      p1.stop_time = 4.0
+      p2.start_time = 1.0
+      p2.stop_time = 5.0
+
+    p1.merge_from(p2)
+    self.assertEqual(p1.num_total, 10)
+    self.assertEqual(p1.num_processed, 1)
+    self.assertEqual(p1.num_failed, 1)
+    self.assertEqual(p1.num_skipped, 1)
+    self.assertEqual(p1.num_completed, 3)
+    self.assertEqual(p1.start_time, 1.0)
+    self.assertEqual(p1.stop_time, 5.0)
+
 
 if __name__ == '__main__':
   unittest.main()
@@ -14,6 +14,7 @@
 """Tracking evaluation run progress."""
 
 import os
+from typing import Literal
 import langfun.core as lf
 from langfun.core.eval.v2 import example as example_lib
 from langfun.core.eval.v2 import experiment as experiment_lib
@@ -24,16 +25,24 @@ Experiment = experiment_lib.Experiment
 Example = example_lib.Example
 
 
-def progress_tracker(tqdm: bool = False) -> experiment_lib.Plugin:
+def progress_tracker(
+    tracker_type: Literal['tqdm', 'html', 'auto'] = 'auto'
+) -> experiment_lib.Plugin:
   """Creates a progress tracker as a plugin.
 
   Args:
-    tqdm: If True, force using tqdm for progress update.
+    tracker_type: The type of progress tracker to use.
+      If `tqdm`, force using tqdm for progress update.
+      If `html`, force using html for progress update.
+      If `auto`, determine it automatically based on the running
+      environment (console vs. notebook)
 
   Returns:
    The progress tracker plugin.
  """
-  if tqdm or not lf.console.under_notebook():
+  if tracker_type == 'tqdm' or (
+      tracker_type == 'auto' and not lf.console.under_notebook()
+  ):
     return _TqdmProgressTracker()
   else:
     return _HtmlProgressTracker()
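The boolean `tqdm` flag is replaced by a three-valued `tracker_type`. A usage sketch of the new signature; how the returned plugin is attached to an experiment run is not shown in this diff, so only construction is illustrated:

from langfun.core.eval.v2 import progress_tracking

auto_tracker = progress_tracking.progress_tracker()        # decided by environment (console vs. notebook)
tqdm_tracker = progress_tracking.progress_tracker('tqdm')  # force the tqdm console tracker
html_tracker = progress_tracking.progress_tracker('html')  # force the HTML tracker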
@@ -88,8 +97,7 @@ class _TqdmProgressTracker(experiment_lib.Plugin):
     self._leaf_progresses = {
         leaf.id: lf.concurrent.ProgressBar.install(
             label=f'[#{i + 1} - {leaf.id}]',
-            total=(len(runner.current_run.example_ids)
-                   if runner.current_run.example_ids else leaf.num_examples),
+            total=len(runner.current_run.examples_to_evaluate(leaf)),
             color='cyan',
             status=None
         )