langfun 0.1.2.dev202510230805__py3-none-any.whl → 0.1.2.dev202511270805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of langfun might be problematic.

Files changed (155)
  1. langfun/core/__init__.py +2 -0
  2. langfun/core/agentic/__init__.py +4 -1
  3. langfun/core/agentic/action.py +447 -29
  4. langfun/core/agentic/action_eval.py +9 -2
  5. langfun/core/agentic/action_test.py +149 -21
  6. langfun/core/async_support.py +32 -3
  7. langfun/core/coding/python/correction.py +19 -9
  8. langfun/core/coding/python/execution.py +14 -12
  9. langfun/core/coding/python/generation.py +21 -16
  10. langfun/core/coding/python/sandboxing.py +23 -3
  11. langfun/core/component.py +42 -3
  12. langfun/core/concurrent.py +70 -6
  13. langfun/core/concurrent_test.py +1 -0
  14. langfun/core/console.py +1 -1
  15. langfun/core/data/conversion/anthropic.py +12 -3
  16. langfun/core/data/conversion/anthropic_test.py +8 -6
  17. langfun/core/data/conversion/gemini.py +9 -2
  18. langfun/core/data/conversion/gemini_test.py +12 -9
  19. langfun/core/data/conversion/openai.py +145 -31
  20. langfun/core/data/conversion/openai_test.py +161 -17
  21. langfun/core/eval/base.py +47 -43
  22. langfun/core/eval/base_test.py +5 -5
  23. langfun/core/eval/matching.py +5 -2
  24. langfun/core/eval/patching.py +3 -3
  25. langfun/core/eval/scoring.py +4 -3
  26. langfun/core/eval/v2/__init__.py +1 -0
  27. langfun/core/eval/v2/checkpointing.py +64 -6
  28. langfun/core/eval/v2/checkpointing_test.py +9 -2
  29. langfun/core/eval/v2/eval_test_helper.py +103 -2
  30. langfun/core/eval/v2/evaluation.py +91 -16
  31. langfun/core/eval/v2/evaluation_test.py +9 -3
  32. langfun/core/eval/v2/example.py +50 -40
  33. langfun/core/eval/v2/example_test.py +16 -8
  34. langfun/core/eval/v2/experiment.py +74 -8
  35. langfun/core/eval/v2/experiment_test.py +19 -0
  36. langfun/core/eval/v2/metric_values.py +31 -3
  37. langfun/core/eval/v2/metric_values_test.py +32 -0
  38. langfun/core/eval/v2/metrics.py +157 -44
  39. langfun/core/eval/v2/metrics_test.py +39 -18
  40. langfun/core/eval/v2/progress.py +30 -1
  41. langfun/core/eval/v2/progress_test.py +27 -0
  42. langfun/core/eval/v2/progress_tracking.py +12 -3
  43. langfun/core/eval/v2/progress_tracking_test.py +6 -1
  44. langfun/core/eval/v2/reporting.py +90 -71
  45. langfun/core/eval/v2/reporting_test.py +24 -6
  46. langfun/core/eval/v2/runners/__init__.py +30 -0
  47. langfun/core/eval/v2/{runners.py → runners/base.py} +59 -142
  48. langfun/core/eval/v2/runners/beam.py +341 -0
  49. langfun/core/eval/v2/runners/beam_test.py +131 -0
  50. langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
  51. langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
  52. langfun/core/eval/v2/runners/debug.py +40 -0
  53. langfun/core/eval/v2/runners/debug_test.py +76 -0
  54. langfun/core/eval/v2/runners/parallel.py +100 -0
  55. langfun/core/eval/v2/runners/parallel_test.py +95 -0
  56. langfun/core/eval/v2/runners/sequential.py +47 -0
  57. langfun/core/eval/v2/runners/sequential_test.py +172 -0
  58. langfun/core/langfunc.py +45 -130
  59. langfun/core/langfunc_test.py +7 -5
  60. langfun/core/language_model.py +141 -21
  61. langfun/core/language_model_test.py +54 -3
  62. langfun/core/llms/__init__.py +9 -1
  63. langfun/core/llms/anthropic.py +157 -2
  64. langfun/core/llms/azure_openai.py +29 -17
  65. langfun/core/llms/cache/base.py +25 -3
  66. langfun/core/llms/cache/in_memory.py +48 -7
  67. langfun/core/llms/cache/in_memory_test.py +14 -4
  68. langfun/core/llms/compositional.py +25 -1
  69. langfun/core/llms/deepseek.py +30 -2
  70. langfun/core/llms/fake.py +32 -1
  71. langfun/core/llms/gemini.py +55 -17
  72. langfun/core/llms/gemini_test.py +84 -0
  73. langfun/core/llms/google_genai.py +34 -1
  74. langfun/core/llms/groq.py +28 -3
  75. langfun/core/llms/llama_cpp.py +23 -4
  76. langfun/core/llms/openai.py +36 -3
  77. langfun/core/llms/openai_compatible.py +148 -27
  78. langfun/core/llms/openai_compatible_test.py +207 -20
  79. langfun/core/llms/openai_test.py +0 -2
  80. langfun/core/llms/rest.py +12 -1
  81. langfun/core/llms/vertexai.py +58 -8
  82. langfun/core/logging.py +1 -1
  83. langfun/core/mcp/client.py +77 -22
  84. langfun/core/mcp/client_test.py +8 -35
  85. langfun/core/mcp/session.py +94 -29
  86. langfun/core/mcp/session_test.py +54 -0
  87. langfun/core/mcp/tool.py +151 -22
  88. langfun/core/mcp/tool_test.py +197 -0
  89. langfun/core/memory.py +1 -0
  90. langfun/core/message.py +160 -55
  91. langfun/core/message_test.py +65 -81
  92. langfun/core/modalities/__init__.py +8 -0
  93. langfun/core/modalities/audio.py +21 -1
  94. langfun/core/modalities/image.py +19 -1
  95. langfun/core/modalities/mime.py +64 -3
  96. langfun/core/modalities/mime_test.py +11 -0
  97. langfun/core/modalities/pdf.py +19 -1
  98. langfun/core/modalities/video.py +21 -1
  99. langfun/core/modality.py +167 -29
  100. langfun/core/modality_test.py +42 -12
  101. langfun/core/natural_language.py +1 -1
  102. langfun/core/sampling.py +4 -4
  103. langfun/core/sampling_test.py +20 -4
  104. langfun/core/structured/__init__.py +2 -24
  105. langfun/core/structured/completion.py +34 -44
  106. langfun/core/structured/completion_test.py +23 -43
  107. langfun/core/structured/description.py +54 -50
  108. langfun/core/structured/function_generation.py +29 -12
  109. langfun/core/structured/mapping.py +81 -37
  110. langfun/core/structured/parsing.py +95 -79
  111. langfun/core/structured/parsing_test.py +0 -3
  112. langfun/core/structured/querying.py +215 -142
  113. langfun/core/structured/querying_test.py +65 -29
  114. langfun/core/structured/schema/__init__.py +49 -0
  115. langfun/core/structured/schema/base.py +664 -0
  116. langfun/core/structured/schema/base_test.py +531 -0
  117. langfun/core/structured/schema/json.py +174 -0
  118. langfun/core/structured/schema/json_test.py +121 -0
  119. langfun/core/structured/schema/python.py +316 -0
  120. langfun/core/structured/schema/python_test.py +410 -0
  121. langfun/core/structured/schema_generation.py +33 -14
  122. langfun/core/structured/scoring.py +47 -36
  123. langfun/core/structured/tokenization.py +26 -11
  124. langfun/core/subscription.py +2 -2
  125. langfun/core/template.py +174 -49
  126. langfun/core/template_test.py +123 -17
  127. langfun/env/__init__.py +8 -2
  128. langfun/env/base_environment.py +320 -128
  129. langfun/env/base_environment_test.py +473 -0
  130. langfun/env/base_feature.py +92 -15
  131. langfun/env/base_feature_test.py +228 -0
  132. langfun/env/base_sandbox.py +84 -361
  133. langfun/env/base_sandbox_test.py +1235 -0
  134. langfun/env/event_handlers/__init__.py +1 -1
  135. langfun/env/event_handlers/chain.py +233 -0
  136. langfun/env/event_handlers/chain_test.py +253 -0
  137. langfun/env/event_handlers/event_logger.py +95 -98
  138. langfun/env/event_handlers/event_logger_test.py +21 -21
  139. langfun/env/event_handlers/metric_writer.py +225 -140
  140. langfun/env/event_handlers/metric_writer_test.py +23 -6
  141. langfun/env/interface.py +854 -40
  142. langfun/env/interface_test.py +112 -2
  143. langfun/env/load_balancers_test.py +23 -2
  144. langfun/env/test_utils.py +126 -84
  145. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/METADATA +1 -1
  146. langfun-0.1.2.dev202511270805.dist-info/RECORD +215 -0
  147. langfun/core/eval/v2/runners_test.py +0 -343
  148. langfun/core/structured/schema.py +0 -987
  149. langfun/core/structured/schema_test.py +0 -982
  150. langfun/env/base_test.py +0 -1481
  151. langfun/env/event_handlers/base.py +0 -350
  152. langfun-0.1.2.dev202510230805.dist-info/RECORD +0 -195
  153. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/WHEEL +0 -0
  154. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/licenses/LICENSE +0 -0
  155. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/top_level.txt +0 -0
@@ -22,19 +22,30 @@ import pyglove as pg
 
 @dataclasses.dataclass
 class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
-  """An item for the evaluation.
+  """An example for evaluation.
+
+  An evaluation example contains the input and output of an evaluation task,
+  as well as metadata about the evaluation process, such as execution time,
+  LLM usage, and metric results.
 
   Attributes:
-    id: The 1-based ID of the item in the evaluation set.
-    input: An element returned from the `Evaluable.inputs` functor.
-    output: The output of the `process` method. If `pg.MISSING_VALUE`, it has
-      not been processed yet.
-    metadata: The metadata of the item produced by the `process` method.
-    metric_metadata: The dictionary returned from `Metric.audit`.
-    start_time: The start time of the evaluation item.
-    end_time: The end time of the evaluation item.
-    usage_summary: The summary of LLM usages of the evaluation item.
-    execution_status: The timeit status of the evaluation item.
+    id: The 1-based ID of the example in the evaluation set.
+    input: An element returned from the `Evaluable.inputs` functor, which serves
+      as the input for `lf.Evaluable.process`.
+    output: The output of `lf.Evaluable.process` method. If `pg.MISSING_VALUE`,
+      it indicates the example has not been processed yet.
+    error: The error raised from `lf.Evaluable.process`. If None, it
+      indicates the process was successful.
+    metadata: The metadata of the example produced by `lf.Evaluable.process`.
+    metric_metadata: The dictionary returned from `Metric.audit`, which contains
+      metadata about metric computation for this example.
+    newly_processed: Whether this example is processed in the current run. If
+      False, it indicates the example was loaded from a checkpoint from previous
+      runs.
+    start_time: The start time of processing this example.
+    end_time: The end time of processing this example.
+    usage_summary: The summary of LLM usages for processing this example.
+    execution_status: The timeit status of processing this example.
   """
   id: int
   input: Any = pg.MISSING_VALUE
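
The expanded docstring above spells out the `output`/`error` semantics (and the hunk that follows removes the old back-fill of `error` from `execution_status`). A minimal sketch of the states the docstring describes; field values are illustrative and the module alias is an assumption:

```python
from langfun.core.eval.v2 import example as example_lib

# Not yet processed: `output` keeps its default of pg.MISSING_VALUE.
ex = example_lib.Example(id=1, input={'x': 1})
assert not ex.is_processed

# Processed: `output` is filled in by the evaluation; `error` stays None
# unless `lf.Evaluable.process` itself raised -- it is no longer derived
# from a failed `execution_status` (see the removed `__post_init__` below).
ex = example_lib.Example(id=2, input={'x': 1}, output=2)
assert ex.is_processed and ex.error is None
```
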
@@ -49,14 +60,6 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
   usage_summary: lf.UsageSummary | None = None
   execution_status: dict[str, pg.utils.TimeIt.Status] | None = None
 
-  def __post_init__(self):
-    if self.execution_status is not None:
-      for status in self.execution_status.values():
-        if status.has_error:
-          assert isinstance(status.error, pg.ErrorInfo)
-          self.error = status.error
-          break
-
   @property
   def is_processed(self) -> bool:
     """Returns whether the item has been processed."""
@@ -152,6 +155,8 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
       ckpt_file: str | list[str],
       example_input_by_id: Callable[[int], Any] | None = None,
       load_example_metadata: bool = True,
+      convert_unknown: bool = True,
+      **kwargs
   ) -> Iterator['Example']:
     """Iterates Examples from the checkpoint files."""
     ckpt_files = [ckpt_file] if isinstance(ckpt_file, str) else ckpt_file
@@ -161,7 +166,9 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
         example = pg.from_json_str(
             record,
            example_input_by_id=example_input_by_id,
-           load_example_metadata=load_example_metadata
+           load_example_metadata=load_example_metadata,
+           convert_unknown=convert_unknown,
+           **kwargs
        )
        assert isinstance(example, cls), example
        yield example
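
These two hunks thread `convert_unknown` (plus arbitrary `**kwargs`) from the checkpoint-iteration classmethod down into `pg.from_json_str`, so records whose value types are no longer registered can still be deserialized (as `pg.symbolic.UnknownTypedObject`, per the updated test further down). A hedged usage sketch; the classmethod name `iter_ckpt_files` and the checkpoint path are illustrative assumptions, only the parameters come from the hunk above:

```python
from langfun.core.eval.v2 import example as example_lib

# `iter_ckpt_files` is assumed to be the classmethod whose signature is shown
# above; the checkpoint path is illustrative.
for ex in example_lib.Example.iter_ckpt_files(
    '/tmp/my_run/checkpoint_0.jsonl',
    load_example_metadata=True,
    convert_unknown=True,  # keep records whose original types are unavailable
):
  print(ex.id, ex.is_processed)
```
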
@@ -182,15 +189,23 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
     extra_flags = extra_flags or {}
     num_examples = extra_flags.get('num_examples', None)
 
-    def _metric_metadata_badge(key, value):
-      if isinstance(value, bool) and bool:
-        text = key
-      else:
-        text = f'{key}:{value}'
-      return pg.views.html.controls.Badge(
-          text,
-          css_classes=[pg.utils.camel_to_snake(key, '-')],
-      )
+    def _metric_label_group(metric_metadata: dict[str, Any] | None):
+      """Renders a label group for metric metadata."""
+      badges = []
+      if metric_metadata:
+        for metric_name, metadata in metric_metadata.items():
+          assert isinstance(metadata, dict), (metric_name, metadata)
+          for k, v in metadata.items():
+            css_class = k
+            if isinstance(v, bool):
+              css_class += '_true' if v else '_false'
+            badge = pg.views.html.controls.Badge(
+                f'{k}:{v}',
+                tooltip=f'{metric_name}: {k}',
+                css_classes=[css_class],
+            )
+            badges.append(badge)
+      return pg.views.html.controls.LabelGroup(badges)
 
     def _render_header():
       return pg.Html.element(
@@ -229,12 +244,7 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
             extra_flags=dict(as_badge=True)
         ) if self.usage_summary is not None else None,
         # Metric metadata.
-        pg.views.html.controls.LabelGroup(
-            [  # pylint: disable=g-long-ternary
-                _metric_metadata_badge(k, v)
-                for k, v in self.metric_metadata.items()
-            ] if self.metric_metadata else []
-        ),
+        _metric_label_group(self.metric_metadata)
       ],
       css_classes=['example-container'],
   )
@@ -305,18 +315,18 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
          color: black;
        }
        /* Badge styles. */
-        .eval-example .badge.match {
+        .eval-example .badge.is_correct_true {
          color: green;
          background-color: #dcefbe;
        }
+        .eval-example .badge.is_correct_false {
+          color: orange;
+          background-color: #ffefc4;
+        }
        .eval-example .badge.error {
          color: red;
          background-color: #fdcccc;
        }
-        .eval-example .badge.mismatch {
-          color: orange;
-          background-color: #ffefc4;
-        }
        .eval-example .badge.score {
          color: blue;
          background-color: #c4dced;
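
Together, these hunks change how `metric_metadata` is rendered: it is now a dict keyed by metric name, and each inner key/value becomes a badge whose CSS class is derived from the key, with booleans getting `_true`/`_false` suffixes (replacing the old `match`/`mismatch` badge styles). A small sketch of the new shape; the metric name and key are illustrative, modeled on the updated test and the CSS classes above:

```python
from langfun.core.eval.v2 import example as example_lib

ex = example_lib.Example(
    id=1,
    input={'a': 1, 'b': 2},
    output=3,
    # New nested layout: {metric_name: {key: value}}.
    metric_metadata={'match': {'is_correct': True}},
)
# When rendered to HTML, `is_correct:True` becomes a badge whose CSS class is
# `is_correct_true`, picking up the green badge style defined above.
```
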
@@ -32,9 +32,9 @@ class ExampleTest(unittest.TestCase):
             name='evaluation', elapse=1.0, error=error
         )
     })
-    self.assertEqual(ex.error, error)
+    self.assertIsNone(ex.error)
     self.assertFalse(ex.is_processed)
-    self.assertTrue(ex.has_error)
+    self.assertFalse(ex.has_error)
     self.assertEqual(ex.elapse, 1.0)
 
     ex = Example(id=2, output=1)
@@ -94,15 +94,23 @@ class ExampleTest(unittest.TestCase):
     pg.JSONConvertible._TYPE_REGISTRY._type_to_cls_map.pop(
         inputs[0].b.__type_name__
     )
-    v = pg.from_json_str(json_str, auto_dict=True, load_example_metadata=True)
-    v.output.pop('type_name')
-    v.metadata.b.pop('type_name')
+    v = pg.from_json_str(
+        json_str,
+        convert_unknown=True,
+        load_example_metadata=True
+    )
     self.assertEqual(
         v,
         Example(
             id=1,
-            output=pg.Dict(x=1),
-            metadata=dict(b=pg.Dict(x=1, y=2)),
+            output=pg.symbolic.UnknownTypedObject(
+                inputs[0].a.__type_name__, x=1
+            ),
+            metadata=dict(
+                b=pg.symbolic.UnknownTypedObject(
+                    inputs[0].b.__type_name__, x=1, y=2
+                )
+            ),
         )
     )
     # Serialize with input.
@@ -116,7 +124,7 @@ class ExampleTest(unittest.TestCase):
         input=pg.Dict(a=1, b=2),
         output=3,
         metadata=dict(sum=3),
-        metric_metadata=dict(match=True),
+        metric_metadata=dict(match=dict(match=True)),
     )
     self.assertNotIn(
         'next',
@@ -139,10 +139,10 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
 
   # Checkpointing
 
-  Experiments support checkpointing, which is enabled by default. It allows
+  Experiments support checkpointing, which is enabled by default. It allows
   users to resume their experiments from a saved state. When an experiment runs,
-  it creates a new directory for that run and saves the current state to a
-  checkpoint file. If the experiment is interrupted or fails, users can resume
+  it creates a new directory for that run and saves its progress to checkpoint
+  files. If the experiment is interrupted or fails, users can resume
   it by specifying the 'id' or 'warm_start_from' argument (shown above) to
   seamlessly continue from previously saved state without starting over.
 
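The reworded paragraph above describes per-run directories and checkpoint files; a brief sketch of the resume flow it refers to. `MyEval` and the placeholder values are assumptions for illustration; `run`, `id`, and `warm_start_from` come from the docstring and the `Suite` example later in this diff:

```python
import langfun as lf

experiment = MyEval(lm=lf.llms.Gpt4())  # MyEval: a user-defined evaluation (assumed)

# First run: creates a new directory under the root dir and writes progress
# to checkpoint files as examples complete.
experiment.run('/path/to/eval_root')

# Resume after an interruption: reuse the previous run via 'id', or start a
# new run warm-started from a previous run's saved state ('warm_start_from').
# The argument names come from the docstring above; the values are placeholders.
experiment.run('/path/to/eval_root', id='<previous-run-id>')
experiment.run('/path/to/eval_root', warm_start_from='<previous-run-state>')
```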
 
@@ -169,7 +169,7 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
 
   # Experiment Plugins
 
-  Experiment can be extended by plugins. Plugins can listen to the events of
+  Experiments can be extended by plugins. Plugins can listen to the events of
   experiment execution and produce additional outputs. For example, a plugin
   can be added to an experiment to generate additional metrics or to save
   additional data to a database. More details will be added in the future.
@@ -657,7 +657,30 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
 
 @pg.use_init_args(['children'])
 class Suite(Experiment):
-  """A suite of evaluations."""
+  """A suite of evaluations.
+
+  `lf.eval.Suite` groups multiple `lf.eval.Evaluation` or other `Suite`
+  objects into a single experiment, allowing them to be run, managed, and
+  reported together.
+
+  **Example:**
+
+  ```python
+  import langfun as lf
+
+  suite = lf.eval.Suite([
+      MyEval(lm=lf.llms.Gpt4()),
+      MyEval(lm=lf.llms.Gemini()),
+      lf.eval.Suite([
+          AnotherEval(lm=lf.llms.Gpt4()),
+          AnotherEval(lm=lf.llms.Gemini())
+      ])
+  ])
+
+  # Run all evaluations in the suite
+  run_info = suite.run('/path/to/my/suite_run')
+  ```
+  """
 
   children: Annotated[
       list[Experiment], 'A list of child experiments.'
@@ -791,7 +814,14 @@ class RunId(pg.Object):
 
 
 class Run(pg.Object, pg.views.html.HtmlTreeView.Extension):
-  """A run of an experiment."""
+  """Represents a single run of an experiment.
+
+  A `Run` object holds all the configurations for executing an experiment,
+  such as the experiment definition, input/output directories, and flags
+  controlling the execution behavior (e.g., error handling, checkpointing).
+  It also provides utility methods for accessing run-specific paths and
+  filtering examples for evaluation.
+  """
 
   root_dir: Annotated[
       str,
@@ -971,7 +1001,13 @@ class Run(pg.Object, pg.views.html.HtmlTreeView.Extension):
 
 
 class Runner(pg.Object):
-  """Interface for experiment runner."""
+  """Interface for experiment runner.
+
+  A runner is responsible for executing the evaluations within an experiment
+  based on the configuration specified in a `Run` object. Different runners
+  can implement different execution strategies, such as sequential or parallel
+  processing of examples and evaluations.
+  """
 
   # Class-level variable for registering the runner.
   NAME = None
@@ -1010,7 +1046,37 @@ class Runner(pg.Object):
 
 
 class Plugin(lf.Component):
-  """Base class for experiment plugins."""
+  """Base class for experiment plugins.
+
+  Plugins provide a mechanism to extend the behavior of an experiment run
+  by hooking into various events during the lifecycle of experiment and
+  example execution, such as `on_run_start`, `on_experiment_complete`,
+  `on_example_start`, etc. They can be used for custom logging, monitoring,
+  or result processing.
+  """
+
+  @classmethod
+  def is_per_example(cls) -> bool:
+    """Returns whether the plugin is per example only.
+
+    Per-example plugins can be installed on individual workers when examples
+    are evaluated by multiple processes in parallel.
+    """
+
+    def same_code(method1, method2):
+      return method1.__code__ == method2.__code__
+    return all(
+        same_code(method1, method2)
+        for method1, method2 in [
+            (Plugin.on_run_start, cls.on_run_start),
+            (Plugin.on_run_complete, cls.on_run_complete),
+            (Plugin.on_run_abort, cls.on_run_abort),
+            (Plugin.on_experiment_start, cls.on_experiment_start),
+            (Plugin.on_experiment_skipped, cls.on_experiment_skipped),
+            (Plugin.on_experiment_complete, cls.on_experiment_complete),
+            (Plugin.on_experiment_abort, cls.on_experiment_abort),
+        ]
+    )
 
   def on_run_start(
       self,
@@ -433,5 +433,24 @@ class RunnerTest(unittest.TestCase):
     pass
 
 
+class PluginTest(unittest.TestCase):
+
+  def test_per_example_only(self):
+
+    class PerExamplePlugin(experiment_lib.Plugin):
+
+      def on_example_complete(self, runner, experiment, example):
+        print('on_example_complete')
+
+    self.assertTrue(PerExamplePlugin.is_per_example())
+
+    class NonPerExamplePlugin(experiment_lib.Plugin):
+
+      def on_experiment_complete(self, runner, experiment):
+        print('on_example_complete')
+
+    self.assertFalse(NonPerExamplePlugin.is_per_example())
+
+
 if __name__ == '__main__':
   unittest.main()
@@ -20,7 +20,15 @@ import pyglove as pg
 
 
 class MetricValue(pg.Object):
-  """Base class for metric values."""
+  """Base class for metric values.
+
+  `MetricValue` is the base class for representing aggregated metric values
+  in an evaluation. It accumulates data points from individual examples,
+  each consisting of a value and an optional weight, associated with an example
+  ID. Subclasses must implement `reduce` method to compute a single float value
+  from accumulated data points, and `scalar_repr` to provide a string
+  representation of the reduced value.
+  """
 
   class DataPoint(pg.Object):
     """A data point for a metric value."""
@@ -88,6 +96,14 @@ class MetricValue(pg.Object):
       self.increment_total()
     return self
 
+  def merge_from(self, other: 'MetricValue') -> 'MetricValue':
+    """Merges the values from another metric value."""
+    self._weighted_sum += other._weighted_sum  # pylint: disable=protected-access
+    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
+      self.data_points.extend(other.data_points)
+    self.increment_total(other.total)
+    return self
+
   def __gt__(self, other: Union['MetricValue', float]) -> bool:
     if isinstance(other, self.__class__):
       return float(self) > float(other)
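
The new `merge_from` folds another metric value's weighted sum, data points, and total into the receiver, which is what the `test_merge_from` cases further down verify. A minimal usage sketch, assuming two `Rate` values accumulated separately (for instance by different workers):

```python
from langfun.core.eval.v2 import metric_values

part1 = metric_values.Rate()
part1.add(1, 1.0, 1.0, increment_total=True)  # example 1: value 1.0, weight 1.0

part2 = metric_values.Rate()
part2.add(2, 0.0, 1.0, increment_total=True)  # example 2: value 0.0, weight 1.0

merged = part1.merge_from(part2)  # merges in place and returns part1
assert merged.total == 2
assert float(merged) == 0.5       # (1.0 * 1.0 + 0.0 * 1.0) / 2
```
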
@@ -133,7 +149,13 @@ class MetricValue(pg.Object):
 
 
 class Rate(MetricValue):
-  """Representing a rate in range [0, 1]."""
+  """Metric value representing a rate in range [0, 1].
+
+  `Rate` is used for metrics that compute a rate, such as accuracy or error
+  rate. The final value is computed as the weighted sum of accumulated values
+  divided by the total number of examples. It's displayed as a percentage
+  (e.g., 90.0%).
+  """
 
   def reduce(self) -> float:
     return self._weighted_sum / self.total
@@ -145,7 +167,13 @@ class Rate(MetricValue):
 
 
 class Average(MetricValue):
-  """Average of a aggregated values."""
+  """Metric value representing an average of accumulated values.
+
+  `Average` is used for metrics that compute an average score across examples
+  (e.g., average quality score). The final value is computed as the weighted
+  sum of accumulated values divided by the number of data points.
+  It's displayed as a float with 3 decimal places (e.g., 4.750).
+  """
 
   def reduce(self) -> float:
     if not self.data_points:
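
To make the difference between the two reductions concrete: `Rate` divides the weighted sum by the total example count, while `Average`, per its docstring, divides by the number of accumulated data points, so the two diverge when an example is counted in the total without contributing a data point. A small sketch under that assumption; the call pattern follows the tests below:

```python
from langfun.core.eval.v2 import metric_values

rate = metric_values.Rate()
avg = metric_values.Average()
for mv in (rate, avg):
  mv.add(1, 1.0, 1.0, increment_total=True)  # example 1 contributes a data point
  mv.increment_total()                       # example 2 counted, no data point

assert float(rate) == 0.5  # weighted sum / total = 1.0 / 2
assert float(avg) == 1.0   # weighted sum / number of data points = 1.0 / 1
```
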
@@ -51,6 +51,22 @@ class RateTest(unittest.TestCase):
     self.assertEqual(rate.total, 0)
     self.assertTrue(math.isnan(float(rate)))
 
+  def test_merge_from(self):
+    rate1 = metric_values.Rate()
+    rate1.add(1, 1.0, 1.0, increment_total=True)
+    rate2 = metric_values.Rate()
+    rate2.add(2, 0.0, 1.0, increment_total=True)
+    rate1.merge_from(rate2)
+    self.assertEqual(rate1.total, 2)
+    self.assertEqual(float(rate1), 0.5)
+    self.assertEqual(
+        rate1.data_points,
+        [
+            metric_values.MetricValue.DataPoint(1, 1.0, 1.0),
+            metric_values.MetricValue.DataPoint(2, 0.0, 1.0),
+        ],
+    )
+
 
 class AverageTest(unittest.TestCase):
 
@@ -75,6 +91,22 @@ class AverageTest(unittest.TestCase):
     average.reset()
     self.assertEqual(average.total, 0)
 
+  def test_merge_from(self):
+    avg1 = metric_values.Average()
+    avg1.add(1, 1.0, 0.5, increment_total=True)
+    avg2 = metric_values.Average()
+    avg2.add(2, 0.0, 1.0, increment_total=True)
+    avg1.merge_from(avg2)
+    self.assertEqual(avg1.total, 2)
+    self.assertEqual(float(avg1), 0.25)
+    self.assertEqual(
+        avg1.data_points,
+        [
+            metric_values.MetricValue.DataPoint(1, 1.0, 0.5),
+            metric_values.MetricValue.DataPoint(2, 0.0, 1.0),
+        ],
+    )
+
 
 if __name__ == '__main__':
   unittest.main()