langfun 0.1.2.dev202509020804__py3-none-any.whl → 0.1.2.dev202511110805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (133)
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +6 -1
  3. langfun/core/agentic/__init__.py +4 -0
  4. langfun/core/agentic/action.py +412 -103
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +68 -6
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +9 -2
  20. langfun/core/data/conversion/gemini_test.py +12 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +47 -43
  24. langfun/core/eval/base_test.py +4 -4
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +1 -0
  29. langfun/core/eval/v2/checkpointing.py +30 -4
  30. langfun/core/eval/v2/eval_test_helper.py +1 -1
  31. langfun/core/eval/v2/evaluation.py +60 -14
  32. langfun/core/eval/v2/example.py +22 -11
  33. langfun/core/eval/v2/experiment.py +51 -8
  34. langfun/core/eval/v2/metric_values.py +31 -3
  35. langfun/core/eval/v2/metric_values_test.py +32 -0
  36. langfun/core/eval/v2/metrics.py +39 -4
  37. langfun/core/eval/v2/metrics_test.py +14 -0
  38. langfun/core/eval/v2/progress.py +30 -1
  39. langfun/core/eval/v2/progress_test.py +27 -0
  40. langfun/core/eval/v2/progress_tracking_test.py +6 -0
  41. langfun/core/eval/v2/reporting.py +90 -71
  42. langfun/core/eval/v2/reporting_test.py +20 -6
  43. langfun/core/eval/v2/runners.py +27 -7
  44. langfun/core/eval/v2/runners_test.py +3 -0
  45. langfun/core/langfunc.py +45 -130
  46. langfun/core/langfunc_test.py +6 -4
  47. langfun/core/language_model.py +151 -31
  48. langfun/core/language_model_test.py +9 -3
  49. langfun/core/llms/__init__.py +12 -1
  50. langfun/core/llms/anthropic.py +157 -2
  51. langfun/core/llms/azure_openai.py +29 -17
  52. langfun/core/llms/cache/base.py +25 -3
  53. langfun/core/llms/cache/in_memory.py +48 -7
  54. langfun/core/llms/cache/in_memory_test.py +14 -4
  55. langfun/core/llms/compositional.py +25 -1
  56. langfun/core/llms/deepseek.py +30 -2
  57. langfun/core/llms/fake.py +39 -1
  58. langfun/core/llms/fake_test.py +9 -0
  59. langfun/core/llms/gemini.py +43 -7
  60. langfun/core/llms/google_genai.py +34 -1
  61. langfun/core/llms/groq.py +28 -3
  62. langfun/core/llms/llama_cpp.py +23 -4
  63. langfun/core/llms/openai.py +93 -3
  64. langfun/core/llms/openai_compatible.py +148 -27
  65. langfun/core/llms/openai_compatible_test.py +207 -20
  66. langfun/core/llms/openai_test.py +0 -2
  67. langfun/core/llms/rest.py +16 -1
  68. langfun/core/llms/vertexai.py +59 -8
  69. langfun/core/logging.py +1 -1
  70. langfun/core/mcp/__init__.py +10 -0
  71. langfun/core/mcp/client.py +177 -0
  72. langfun/core/mcp/client_test.py +71 -0
  73. langfun/core/mcp/session.py +241 -0
  74. langfun/core/mcp/session_test.py +54 -0
  75. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  76. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  77. langfun/core/mcp/tool.py +256 -0
  78. langfun/core/mcp/tool_test.py +197 -0
  79. langfun/core/memory.py +1 -0
  80. langfun/core/message.py +160 -55
  81. langfun/core/message_test.py +65 -81
  82. langfun/core/modalities/__init__.py +8 -0
  83. langfun/core/modalities/audio.py +21 -1
  84. langfun/core/modalities/image.py +19 -1
  85. langfun/core/modalities/mime.py +62 -3
  86. langfun/core/modalities/pdf.py +19 -1
  87. langfun/core/modalities/video.py +21 -1
  88. langfun/core/modality.py +167 -29
  89. langfun/core/modality_test.py +42 -12
  90. langfun/core/natural_language.py +1 -1
  91. langfun/core/sampling.py +4 -4
  92. langfun/core/sampling_test.py +20 -4
  93. langfun/core/structured/completion.py +34 -44
  94. langfun/core/structured/completion_test.py +23 -43
  95. langfun/core/structured/description.py +54 -50
  96. langfun/core/structured/function_generation.py +29 -12
  97. langfun/core/structured/mapping.py +74 -28
  98. langfun/core/structured/parsing.py +90 -74
  99. langfun/core/structured/parsing_test.py +0 -3
  100. langfun/core/structured/querying.py +242 -156
  101. langfun/core/structured/querying_test.py +95 -64
  102. langfun/core/structured/schema.py +70 -10
  103. langfun/core/structured/schema_generation.py +33 -14
  104. langfun/core/structured/scoring.py +45 -34
  105. langfun/core/structured/tokenization.py +24 -9
  106. langfun/core/subscription.py +2 -2
  107. langfun/core/template.py +175 -50
  108. langfun/core/template_test.py +123 -17
  109. langfun/env/__init__.py +43 -0
  110. langfun/env/base_environment.py +827 -0
  111. langfun/env/base_environment_test.py +473 -0
  112. langfun/env/base_feature.py +304 -0
  113. langfun/env/base_feature_test.py +228 -0
  114. langfun/env/base_sandbox.py +842 -0
  115. langfun/env/base_sandbox_test.py +1235 -0
  116. langfun/env/event_handlers/__init__.py +14 -0
  117. langfun/env/event_handlers/chain.py +233 -0
  118. langfun/env/event_handlers/chain_test.py +253 -0
  119. langfun/env/event_handlers/event_logger.py +472 -0
  120. langfun/env/event_handlers/event_logger_test.py +304 -0
  121. langfun/env/event_handlers/metric_writer.py +726 -0
  122. langfun/env/event_handlers/metric_writer_test.py +214 -0
  123. langfun/env/interface.py +1640 -0
  124. langfun/env/interface_test.py +151 -0
  125. langfun/env/load_balancers.py +59 -0
  126. langfun/env/load_balancers_test.py +139 -0
  127. langfun/env/test_utils.py +497 -0
  128. {langfun-0.1.2.dev202509020804.dist-info → langfun-0.1.2.dev202511110805.dist-info}/METADATA +7 -3
  129. langfun-0.1.2.dev202511110805.dist-info/RECORD +200 -0
  130. langfun-0.1.2.dev202509020804.dist-info/RECORD +0 -172
  131. {langfun-0.1.2.dev202509020804.dist-info → langfun-0.1.2.dev202511110805.dist-info}/WHEEL +0 -0
  132. {langfun-0.1.2.dev202509020804.dist-info → langfun-0.1.2.dev202511110805.dist-info}/licenses/LICENSE +0 -0
  133. {langfun-0.1.2.dev202509020804.dist-info → langfun-0.1.2.dev202511110805.dist-info}/top_level.txt +0 -0
langfun/core/eval/v2/__init__.py:

@@ -38,6 +38,7 @@ from langfun.core.eval.v2 import runners
 from langfun.core.eval.v2.checkpointing import BulkCheckpointer
 from langfun.core.eval.v2.checkpointing import PerExampleCheckpointer
 from langfun.core.eval.v2.reporting import HtmlReporter
+from langfun.core.eval.v2.reporting import ExampleHtmlGenerator
 
 
 # pylint: enable=g-bad-import-order
langfun/core/eval/v2/checkpointing.py:

@@ -29,7 +29,17 @@ Runner = experiment_lib.Runner
 
 
 class Checkpointer(experiment_lib.Plugin):
-  """Base class for checkpointing evaluation examples."""
+  """Base class for checkpointing evaluation examples.
+
+  `Checkpointer` is a plugin that saves the state of processed examples
+  incrementally during an experiment run, allowing the experiment to be resumed
+  later. When an experiment starts, the checkpointer loads any previously saved
+  examples from an earlier run (or a warm-start run) into `experiment.state`,
+  so the runner can skip processing them again.
+  Subclasses should implement `_list_checkpoint_filenames` to identify
+  checkpoint files to load, and `_save_example` to save a newly processed
+  example.
+  """
 
   checkpoint_filename: Annotated[
       str,
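The new docstring spells out the subclassing contract for checkpointers. Below is a hedged sketch of such a subclass: only the method names `_list_checkpoint_filenames` / `_save_example` and the `checkpoint_filename` field come from the diff, while the parameter lists and bodies are illustrative assumptions.

```python
# Hedged sketch of a Checkpointer subclass; not the actual langfun API.
from langfun.core.eval.v2 import checkpointing


class SingleFileCheckpointer(checkpointing.Checkpointer):
  """Illustrative checkpointer that reloads from one well-known file."""

  def _list_checkpoint_filenames(self, *args, **kwargs):
    # Tell the base class which checkpoint files to load on (re)start.
    return [self.checkpoint_filename]

  def _save_example(self, *args, **kwargs):
    # Persist one newly processed example so a resumed run can skip it.
    raise NotImplementedError('Sketch only; see checkpointing.py for the real hooks.')
```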
@@ -170,7 +180,12 @@ class Checkpointer(experiment_lib.Plugin):
 
 
 class PerExampleCheckpointer(Checkpointer):
-  """Checkpointer that saves each example to a separate file."""
+  """Checkpointer that saves each example to a separate file.
+
+  This checkpointer saves each processed example to its own checkpoint file,
+  named using the pattern `<checkpoint_filename_prefix>_<example_id>.<ext>`.
+  For example, `checkpoint_1.bagz`, `checkpoint_2.bagz`, etc.
+  """
 
   def _on_bound(self):
     super()._on_bound()
@@ -235,7 +250,13 @@ class PerExampleCheckpointer(Checkpointer):
 
 
 class BulkCheckpointer(Checkpointer):
-  """Checkpointer that saves all examples to a single file."""
+  """Checkpointer that saves all examples of an evaluation to a single file.
+
+  This checkpointer appends newly processed examples of an evaluation to a
+  single sequence file (e.g., `checkpoint.bagz`). This is often more efficient
+  than `PerExampleCheckpointer` when dealing with a large number of examples
+  or when file system overhead is a concern.
+  """
 
   def _on_bound(self):
     super()._on_bound()
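Either checkpointer can also be selected explicitly for a run. The sketch below reuses the `plugins=` keyword that the test hunks at the end of this diff pass to `run()`; `MyEval` stands in for any `lf.eval.Evaluation` subclass (such as the one in the evaluation.py docstring further down), and the bare constructors are assumed to be sufficient.

```python
# Hedged usage sketch: choosing a checkpointing strategy via plugins=.
import langfun as lf
from langfun.core.eval.v2 import checkpointing

experiment = MyEval(lm=lf.llms.Gpt4())  # MyEval: placeholder Evaluation subclass.
run_info = experiment.run(
    '/tmp/my_eval',   # Root directory for run outputs.
    'new',            # Start a fresh run, as in the test hunks below.
    plugins=[checkpointing.PerExampleCheckpointer()],  # Or BulkCheckpointer().
)
```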
@@ -341,7 +362,12 @@ class BulkCheckpointer(Checkpointer):
 
 
 class SequenceWriter:
-  """Thread safe sequence writer."""
+  """A thread-safe writer for sequence files (e.g., Bagz).
+
+  `SequenceWriter` wraps a `pg.io.SequenceWriter` to provide thread-safe
+  `add` and `close` operations, ensuring that examples can be written
+  concurrently from multiple threads without corrupting the sequence file.
+  """
 
   def __init__(self, path: str):
     self._lock = threading.Lock()
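The docstring describes a wrap-with-a-lock pattern. A minimal, generic sketch of that pattern follows; it guards a plain text file rather than a `pg.io.SequenceWriter` and is not langfun's implementation.

```python
# Generic sketch of the locking pattern described above.
import threading


class LockedLineWriter:
  """Serializes add() and close() calls from multiple threads."""

  def __init__(self, path: str):
    self._lock = threading.Lock()
    self._file = open(path, 'w')

  def add(self, record: str) -> None:
    with self._lock:  # Only one thread writes at a time.
      self._file.write(record + '\n')

  def close(self) -> None:
    with self._lock:
      if not self._file.closed:
        self._file.close()
```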
langfun/core/eval/v2/eval_test_helper.py:

@@ -75,7 +75,7 @@ class TestEvaluation(Evaluation):
 
 class BadJsonConvertible(pg.Object):
 
-  def to_json(self, *args, **kwargs):
+  def sym_jsonify(self, *args, **kwargs):
     raise ValueError('Cannot convert to JSON.')
 
 
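The test helper now overrides `sym_jsonify` instead of `to_json`, since `sym_jsonify` is the hook PyGlove's symbolic objects go through when they are serialized, so raising there reliably simulates an unserializable example. A small stand-alone sketch of the same idea (the `Point` class is hypothetical):

```python
# Hypothetical pg.Object showing where sym_jsonify sits in serialization.
import pyglove as pg


class Point(pg.Object):
  x: int = 0
  y: int = 0

  def sym_jsonify(self, **kwargs):
    # Invoked when the object is serialized, e.g. via pg.to_json_str().
    raise ValueError('Cannot convert to JSON.')


try:
  pg.to_json_str(Point(x=1, y=2))
except ValueError as e:
  print(e)  # -> Cannot convert to JSON.
```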
langfun/core/eval/v2/evaluation.py:

@@ -32,17 +32,63 @@ import pyglove as pg
 
 
 class Evaluation(experiment_lib.Experiment):
-  """Evaluation.
-
-  An evaluation can be a leaf node or a container of other evaluations,
-  depending on whether the current evaluation object is configured with
-  any `pg.oneof`.
-
-  For example, `MyEval(lm=pg.oneof([lf.llms.Gpt4(), lf.llms.Gemini1_5Pro()]))`
-  is a container of two sub-experiments, one for each LLM. In such case, the
-  evaluation object with `pg.oneof` is called a hyper evaluation, which
-  represents a search space of evaluations, and each sub-evaluation is called
-  a leaf evaluation, which will perform the actual evaluation.
+  """Base class for Langfun evaluations.
+
+  `lf.eval.Evaluation` is the base class for defining evaluation tasks in
+  Langfun. Users typically subclass it to implement custom evaluation logic by
+  overriding `inputs` and `process` methods.
+
+  An `Evaluation` object encapsulates:
+
+  * **`inputs`**: A callable that returns an iterable of input examples to be
+    processed. This is usually provided by implementing an `inputs(self)`
+    method in the subclass, which yields input items for evaluation one by
+    one.
+  * **`process(self, example)`**: An abstract method that processes one
+    example and returns the output, or a tuple of (output, metadata).
+    The output will be used for computing metrics.
+  * **`metrics`**: A list of metrics (e.g., `lf.metrics.Accuracy`) to compute
+    based on the outputs from `process`. Some metrics may require users to
+    implement a `ground_truth(self, example)` method in the subclass to
+    compute metrics against ground truth.
+  * **Hyperparameters**: Any other attributes of the class serve as
+    hyperparameters for the evaluation (e.g., the language model to use).
+
+  **Running Evaluations:**
+
+  Evaluations are executed via `lf.eval.Suite` or by calling the `.run()`
+  method on an `Evaluation` instance, which returns a `Run` object
+  containing the evaluation run information and results. If an evaluation
+  contains sweeable parameters (using `pg.oneof`), `.run()` will expand it
+  into multiple evaluation sub-tasks -- one for each combination of
+  hyperparameters -- all managed within the same `Run`.
+
+  **Example:**
+
+  ```python
+  import langfun as lf
+  import pyglove as pg
+
+  class MyEval(lf.eval.Evaluation):
+    lm: lf.LanguageModel
+    prompt: str = '1 + 1 = '
+
+    def inputs(self):
+      yield 2
+
+    def process(self, example: lf.eval.Example):
+      return int(lf.query(self.prompt, lm=self.lm))
+
+    def ground_truth(self, example: lf.eval.Example) -> int:
+      return example.input
+
+  # Run evaluation using two different LMs
+  evaluation = MyEval(
+      lm=pg.oneof([lf.llms.Gpt4(), lf.llms.Gemini()]),
+      metrics=[lf.metrics.Accuracy()]
+  )
+  run_info = evaluation.run()
+  ```
   """
 
   inputs: Annotated[
@@ -137,7 +183,7 @@ class Evaluation(experiment_lib.Experiment):
 
     Args:
       example: An example object to process. `example.input` is an object
-        returned from `Evaluable.inputs`.
+        yielded from `inputs()` method.
 
     Returns:
       A processed output. Or a tuple of (output, metadata).
@@ -287,7 +333,7 @@
      A unique string representing the resource required.
    """
    return {
-        v.resource_id for _, v in self.sym_init_args.items()
+        v.resource_id for _, v in self.sym_init_args.sym_items()
        if isinstance(v, lf.LanguageModel)
    }
 
@@ -760,7 +806,7 @@
 
 
 class EvaluationState:
-  """Evaluation state."""
+  """In-memory state of an evaluation."""
 
   class ExampleStatus(pg.Object):
     """Example state."""

langfun/core/eval/v2/example.py:

@@ -22,19 +22,30 @@ import pyglove as pg
 
 @dataclasses.dataclass
 class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
-  """An item for the evaluation.
+  """An example for evaluation.
+
+  An evaluation example contains the input and output of an evaluation task,
+  as well as metadata about the evaluation process, such as execution time,
+  LLM usage, and metric results.
 
   Attributes:
-    id: The 1-based ID of the item in the evaluation set.
-    input: An element returned from the `Evaluable.inputs` functor.
-    output: The output of the `process` method. If `pg.MISSING_VALUE`, it has
-      not been processed yet.
-    metadata: The metadata of the item produced by the `process` method.
-    metric_metadata: The dictionary returned from `Metric.audit`.
-    start_time: The start time of the evaluation item.
-    end_time: The end time of the evaluation item.
-    usage_summary: The summary of LLM usages of the evaluation item.
-    execution_status: The timeit status of the evaluation item.
+    id: The 1-based ID of the example in the evaluation set.
+    input: An element returned from the `Evaluable.inputs` functor, which serves
+      as the input for `lf.Evaluable.process`.
+    output: The output of `lf.Evaluable.process` method. If `pg.MISSING_VALUE`,
+      it indicates the example has not been processed yet.
+    error: The error encountered during `lf.Evaluable.process`. If None, it
+      indicates the process was successful.
+    metadata: The metadata of the example produced by `lf.Evaluable.process`.
+    metric_metadata: The dictionary returned from `Metric.audit`, which contains
+      metadata about metric computation for this example.
+    newly_processed: Whether this example is processed in the current run. If
+      False, it indicates the example was loaded from a checkpoint from previous
+      runs.
+    start_time: The start time of processing this example.
+    end_time: The end time of processing this example.
+    usage_summary: The summary of LLM usages for processing this example.
+    execution_status: The timeit status of processing this example.
   """
   id: int
   input: Any = pg.MISSING_VALUE
langfun/core/eval/v2/experiment.py:

@@ -139,10 +139,10 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
 
   # Checkpointing
 
-  Experiments support checkpointing, which is enabled by default. It allows
+  Experiments support checkpointing, which is enabled by default. It allows
   users to resume their experiments from a saved state. When an experiment runs,
-  it creates a new directory for that run and saves the current state to a
-  checkpoint file. If the experiment is interrupted or fails, users can resume
+  it creates a new directory for that run and saves its progress to checkpoint
+  files. If the experiment is interrupted or fails, users can resume
   it by specifying the 'id' or 'warm_start_from' argument (shown above) to
   seamlessly continue from previously saved state without starting over.
 
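The checkpointing paragraph above points at the 'id' and 'warm_start_from' arguments for resuming. A hedged sketch of that flow follows; only those two argument names come from the docstring, while the run id, paths, and `MyEval` are placeholders.

```python
# Hedged sketch of resuming an interrupted run; values are placeholders.
import langfun as lf

root_dir = '/tmp/my_eval'
experiment = MyEval(lm=lf.llms.Gpt4())  # MyEval: placeholder Evaluation subclass.

experiment.run(root_dir, 'new')            # Fresh run in a new run directory.
experiment.run(root_dir, id='20250101_1')  # Resume that run from its checkpoints.
experiment.run(                            # New run that warm-starts from the
    root_dir, 'new',                       # checkpoints of an earlier run.
    warm_start_from='/tmp/my_eval/run_20250101_1',
)
```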
@@ -169,7 +169,7 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
 
   # Experiment Plugins
 
-  Experiment can be extended by plugins. Plugins can listen to the events of
+  Experiments can be extended by plugins. Plugins can listen to the events of
   experiment execution and produce additional outputs. For example, a plugin
   can be added to an experiment to generate additional metrics or to save
   additional data to a database. More details will be added in the future.
@@ -657,7 +657,30 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
 
 @pg.use_init_args(['children'])
 class Suite(Experiment):
-  """A suite of evaluations."""
+  """A suite of evaluations.
+
+  `lf.eval.Suite` groups multiple `lf.eval.Evaluation` or other `Suite`
+  objects into a single experiment, allowing them to be run, managed, and
+  reported together.
+
+  **Example:**
+
+  ```python
+  import langfun as lf
+
+  suite = lf.eval.Suite([
+      MyEval(lm=lf.llms.Gpt4()),
+      MyEval(lm=lf.llms.Gemini()),
+      lf.eval.Suite([
+          AnotherEval(lm=lf.llms.Gpt4()),
+          AnotherEval(lm=lf.llms.Gemini())
+      ])
+  ])
+
+  # Run all evaluations in the suite
+  run_info = suite.run('/path/to/my/suite_run')
+  ```
+  """
 
   children: Annotated[
       list[Experiment], 'A list of child experiments.'
@@ -791,7 +814,14 @@ class RunId(pg.Object):
 
 
 class Run(pg.Object, pg.views.html.HtmlTreeView.Extension):
-  """A run of an experiment."""
+  """Represents a single run of an experiment.
+
+  A `Run` object holds all the configurations for executing an experiment,
+  such as the experiment definition, input/output directories, and flags
+  controlling the execution behavior (e.g., error handling, checkpointing).
+  It also provides utility methods for accessing run-specific paths and
+  filtering examples for evaluation.
+  """
 
   root_dir: Annotated[
       str,
@@ -971,7 +1001,13 @@ class Run(pg.Object, pg.views.html.HtmlTreeView.Extension):
 
 
 class Runner(pg.Object):
-  """Interface for experiment runner."""
+  """Interface for experiment runner.
+
+  A runner is responsible for executing the evaluations within an experiment
+  based on the configuration specified in a `Run` object. Different runners
+  can implement different execution strategies, such as sequential or parallel
+  processing of examples and evaluations.
+  """
 
   # Class-level variable for registering the runner.
   NAME = None
@@ -1010,7 +1046,14 @@ class Runner(pg.Object):
 
 
 class Plugin(lf.Component):
-  """Base class for experiment plugins."""
+  """Base class for experiment plugins.
+
+  Plugins provide a mechanism to extend the behavior of an experiment run
+  by hooking into various events during the lifecycle of experiment and
+  example execution, such as `on_run_start`, `on_experiment_complete`,
+  `on_example_start`, etc. They can be used for custom logging, monitoring,
+  or result processing.
+  """
 
   def on_run_start(
       self,
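A minimal plugin sketch based on the event names listed in the docstring; the parameter lists (`runner`, `run`, `experiment`, `example`) and the `on_example_complete` hook name are assumptions for illustration, not checked against experiment.py.

```python
# Hedged sketch of a custom experiment plugin.
from langfun.core.eval.v2 import experiment as experiment_lib


class LoggingPlugin(experiment_lib.Plugin):
  """Logs when a run starts and whenever an example finishes."""

  def on_run_start(self, runner, run):
    print(f'Run started under {run.root_dir!r}')

  def on_example_complete(self, runner, experiment, example):
    print(f'Example {example.id} finished.')


# Passed to a run the same way the tests in this diff do:
#   experiment.run(root_dir, 'new', plugins=[LoggingPlugin()])
```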
langfun/core/eval/v2/metric_values.py:

@@ -20,7 +20,15 @@ import pyglove as pg
 
 
 class MetricValue(pg.Object):
-  """Base class for metric values."""
+  """Base class for metric values.
+
+  `MetricValue` is the base class for representing aggregated metric values
+  in an evaluation. It accumulates data points from individual examples,
+  each consisting of a value and an optional weight, associated with an example
+  ID. Subclasses must implement `reduce` method to compute a single float value
+  from accumulated data points, and `scalar_repr` to provide a string
+  representation of the reduced value.
+  """
 
   class DataPoint(pg.Object):
     """A data point for a metric value."""
@@ -88,6 +96,14 @@ class MetricValue(pg.Object):
     self.increment_total()
     return self
 
+  def merge_from(self, other: 'MetricValue') -> 'MetricValue':
+    """Merges the values from another metric value."""
+    self._weighted_sum += other._weighted_sum  # pylint: disable=protected-access
+    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
+      self.data_points.extend(other.data_points)
+    self.increment_total(other.total)
+    return self
+
   def __gt__(self, other: Union['MetricValue', float]) -> bool:
     if isinstance(other, self.__class__):
       return float(self) > float(other)
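The new `merge_from` makes it possible to fold together metric values that were accumulated separately (for example, by parallel workers). The usage sketch below mirrors the new unit tests further down in this diff and uses only the `add` / `merge_from` / `total` APIs visible there.

```python
# Two Rate values accumulated independently, then merged into one.
from langfun.core.eval.v2 import metric_values

shard_a = metric_values.Rate()
shard_a.add(1, 1.0, 1.0, increment_total=True)  # Example 1 counted as a hit.
shard_b = metric_values.Rate()
shard_b.add(2, 0.0, 1.0, increment_total=True)  # Example 2 counted as a miss.

shard_a.merge_from(shard_b)
assert shard_a.total == 2
assert float(shard_a) == 0.5  # (1.0 + 0.0) / 2
```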
@@ -133,7 +149,13 @@ class MetricValue(pg.Object):
 
 
 class Rate(MetricValue):
-  """Representing a rate in range [0, 1]."""
+  """Metric value representing a rate in range [0, 1].
+
+  `Rate` is used for metrics that compute a rate, such as accuracy or error
+  rate. The final value is computed as the weighted sum of accumulated values
+  divided by the total number of examples. It's displayed as a percentage
+  (e.g., 90.0%).
+  """
 
   def reduce(self) -> float:
     return self._weighted_sum / self.total
@@ -145,7 +167,13 @@ class Rate(MetricValue):
 
 
 class Average(MetricValue):
-  """Average of a aggregated values."""
+  """Metric value representing an average of accumulated values.
+
+  `Average` is used for metrics that compute an average score across examples
+  (e.g., average quality score). The final value is computed as the weighted
+  sum of accumulated values divided by the number of data points.
+  It's displayed as a float with 3 decimal places (e.g., 4.750).
+  """
 
   def reduce(self) -> float:
     if not self.data_points:
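The practical difference between the two reductions is the denominator: `Rate` divides the weighted sum by `total`, which `increment_total` can bump even when no data point is added, while `Average` divides by the number of data points actually accumulated. A small illustration, assuming the weighted-sum semantics the docstrings and the tests below describe:

```python
# Rate vs. Average reduction, using only add()/increment_total()/float().
from langfun.core.eval.v2 import metric_values

rate = metric_values.Rate()
rate.add(1, 1.0, 1.0, increment_total=True)
rate.add(2, 1.0, 1.0, increment_total=True)
rate.increment_total()            # A third example with no data point added.
assert float(rate) == 2.0 / 3     # Weighted sum / total.

avg = metric_values.Average()
avg.add(1, 4.0, 1.0, increment_total=True)
avg.add(2, 5.0, 1.0, increment_total=True)
assert float(avg) == 4.5          # Weighted sum / number of data points.
```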
langfun/core/eval/v2/metric_values_test.py:

@@ -51,6 +51,22 @@ class RateTest(unittest.TestCase):
     self.assertEqual(rate.total, 0)
     self.assertTrue(math.isnan(float(rate)))
 
+  def test_merge_from(self):
+    rate1 = metric_values.Rate()
+    rate1.add(1, 1.0, 1.0, increment_total=True)
+    rate2 = metric_values.Rate()
+    rate2.add(2, 0.0, 1.0, increment_total=True)
+    rate1.merge_from(rate2)
+    self.assertEqual(rate1.total, 2)
+    self.assertEqual(float(rate1), 0.5)
+    self.assertEqual(
+        rate1.data_points,
+        [
+            metric_values.MetricValue.DataPoint(1, 1.0, 1.0),
+            metric_values.MetricValue.DataPoint(2, 0.0, 1.0),
+        ],
+    )
+
 
 class AverageTest(unittest.TestCase):
 
@@ -75,6 +91,22 @@ class AverageTest(unittest.TestCase):
     average.reset()
     self.assertEqual(average.total, 0)
 
+  def test_merge_from(self):
+    avg1 = metric_values.Average()
+    avg1.add(1, 1.0, 0.5, increment_total=True)
+    avg2 = metric_values.Average()
+    avg2.add(2, 0.0, 1.0, increment_total=True)
+    avg1.merge_from(avg2)
+    self.assertEqual(avg1.total, 2)
+    self.assertEqual(float(avg1), 0.25)
+    self.assertEqual(
+        avg1.data_points,
+        [
+            metric_values.MetricValue.DataPoint(1, 1.0, 0.5),
+            metric_values.MetricValue.DataPoint(2, 0.0, 1.0),
+        ],
+    )
+
 
 if __name__ == '__main__':
   unittest.main()
langfun/core/eval/v2/metrics.py:

@@ -29,7 +29,15 @@ Average = metric_values.Average
 
 
 class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
-  """Interface for an evaluation metric."""
+  """Interface for an evaluation metric.
+
+  A metric is used to evaluate the quality of the outputs produced by an
+  evaluation. It works by auditing each processed example via its `audit`
+  method, which in turn calls the user-overridable `_audit` method to perform
+  metric-specific logic and update metric values. Metrics can compute multiple
+  values (e.g., precision, recall, F1 score) which are exposed via the
+  `values` method.
+  """
 
   name: Annotated[
       str,
@@ -71,6 +79,12 @@ class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
     for v in self.values():
       v.reset()
 
+  def merge_from(self, other: 'Metric') -> 'Metric':
+    """Merges the values from another metric."""
+    for v1, v2 in zip(self.values(), other.values()):
+      v1.merge_from(v2)
+    return self
+
   def _update_view(self):
     """Refreshes the metric values."""
     if self._label_group is None:
@@ -169,7 +183,15 @@ class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
 
 
 class MetricBase(Metric):
-  """Base class for common metrics."""
+  """Base class for common metrics.
+
+  `MetricBase` provides common functionalities for metrics, such as automatic
+  error counting based on whether an example has an error during evaluation.
+  It distinguishes between Object-Oriented Programming (OOP) errors
+  (e.g. `MappingError` during structured output generation) and other errors.
+  Subclasses should implement `_audit_processed` for metric computation on
+  successfully processed examples.
+  """
 
   oop_errors: Rate | None = Rate()
   non_oop_errors: Rate | None = Rate()
@@ -229,7 +251,13 @@ class MetricBase(Metric):
 
 
 class Match(MetricBase):
-  """Metric for matching outputs against groundtruth."""
+  """Metric for matching outputs against ground truth.
+
+  This metric computes match and mismatch rates by comparing the output of
+  an example with its ground truth. By default, it looks for a `groundtruth`
+  attribute in `example.input` for comparison. Users can customize this behavior
+  by subclassing `Match` and overriding the `match` method.
+  """
 
   name = 'match'
   matches: Rate = Rate()
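As the docstring notes, `Match` reads a `groundtruth` attribute off `example.input` by default, which is also the shape the new `MatchTest` below uses (`pg.Dict(groundtruth=...)`). The sketch shows that contract from the evaluation side; `MyQAEval` itself is hypothetical.

```python
# Hypothetical evaluation whose inputs carry a `groundtruth` field, so the
# default Match metric can audit outputs without any custom match() override.
import langfun as lf
import pyglove as pg


class MyQAEval(lf.eval.Evaluation):
  lm: lf.LanguageModel

  def inputs(self):
    # Each input exposes `groundtruth`, which Match reads from example.input.
    yield pg.Dict(question='1 + 1 = ?', groundtruth=2)

  def process(self, example):
    return int(lf.query(example.input.question, int, lm=self.lm))


# e.g. MyQAEval(lm=..., metrics=[Match()]) with Match imported from
# langfun.core.eval.v2.metrics; the exact public alias is not shown in this diff.
```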
@@ -302,7 +330,14 @@ class Match(MetricBase):
 
 
 class Score(MetricBase):
-  """Base class for scoring."""
+  """Base class for scoring metrics.
+
+  `Score` is a base class for metrics that assign a numerical score to each
+  example's output (e.g., evaluating quality on a scale of 1-5).
+  It automatically computes the average score across all examples.
+  Subclasses must implement the `score` method to define how an example
+  should be scored.
+  """
 
   name = 'score'
   average_score: Average = Average()
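A hedged sketch of a `Score` subclass; the docstring only states that subclasses implement `score`, so the parameter list used below is an assumption.

```python
# Hedged sketch: a scoring metric that rates shorter answers higher.
from langfun.core.eval.v2 import metrics


class BrevityScore(metrics.Score):
  """Scores outputs on a 0..1 scale, favoring brevity."""

  def score(self, example, output) -> float:
    # Assumed signature; adapt to the real hook in metrics.py if it differs.
    return 1.0 / (1.0 + len(str(output)))
```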
langfun/core/eval/v2/metrics_test.py:

@@ -106,6 +106,20 @@ class MatchTest(unittest.TestCase):
     m.audit(Example(id=2, input=pg.Dict(groundtruth=1), output=2))
     self.assertEqual(len(scripts), 12)
 
+  def test_merge_from(self):
+    m1 = metrics.Match()
+    m1.audit(Example(id=1, input=pg.Dict(groundtruth=1), output=1))
+    m2 = metrics.Match()
+    m2.audit(Example(id=2, input=pg.Dict(groundtruth=1), output=2))
+    m1.merge_from(m2)
+    self.assertEqual(m1.matches, 0.5)
+    self.assertEqual(m1.mismatches, 0.5)
+    self.assertEqual(m1.oop_errors, 0.0)
+    self.assertEqual(m1.non_oop_errors, 0.0)
+    self.assertEqual(m1.matches.total, 2)
+    self.assertEqual(len(m1.matches.data_points), 1)
+    self.assertEqual(len(m1.mismatches.data_points), 1)
+
 
 class ScoreTest(unittest.TestCase):
 

langfun/core/eval/v2/progress.py:

@@ -21,7 +21,15 @@ import pyglove as pg
 
 
 class Progress(pg.Object, pg.views.HtmlTreeView.Extension):
-  """Evaluation progress."""
+  """Represents and tracks the progress of an evaluation.
+
+  The `Progress` class maintains counts of processed, failed, and skipped
+  items in an evaluation, along with timing information (start time, stop time,
+  duration) and an execution summary. It provides properties to check the
+  status of the evaluation (e.g., `is_started`, `is_completed`) and methods
+  to update progress as items are evaluated.
+  It also supports HTML rendering as a progress bar for visualization.
+  """
 
   num_total: Annotated[
      int | None,
@@ -216,6 +224,27 @@ class Progress(pg.Object, pg.views.HtmlTreeView.Extension):
     """Overrides nondefault values so volatile values are not included."""
     return dict()
 
+  def merge_from(self, other: 'Progress') -> None:
+    """Merges the progress from another progress."""
+    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
+      if other.start_time is not None and (
+          self.start_time is None or self.start_time > other.start_time):
+        self.start_time = other.start_time
+
+      if other.stop_time is not None and (
+          self.stop_time is None or self.stop_time < other.stop_time):
+        self.stop_time = other.stop_time
+
+      if other.num_total is not None:
+        if self.num_total is None:
+          self.num_total = other.num_total
+        else:
+          assert self.num_total == other.num_total, (self, other)
+      self.num_processed += other.num_processed
+      self.num_failed += other.num_failed
+      self.num_skipped += other.num_skipped
+      self.execution_summary.aggregate(other.execution_summary.breakdown)
+
   #
   # HTML view.
   #
langfun/core/eval/v2/progress_test.py:

@@ -77,6 +77,33 @@ class ProgressTest(unittest.TestCase):
     self.assertTrue(p.is_stopped)
     self.assertIsNotNone(p.stop_time_str)
 
+  def test_merge_from(self):
+    p1 = Progress()
+    p1.start(10)
+    p1.increment_processed()
+    p1.increment_failed()
+    p1.stop()
+
+    p2 = Progress()
+    p2.start(10)
+    p2.increment_skipped()
+    p2.stop()
+
+    with pg.allow_writable_accessors(True):
+      p1.start_time = 2.0
+      p1.stop_time = 4.0
+      p2.start_time = 1.0
+      p2.stop_time = 5.0
+
+    p1.merge_from(p2)
+    self.assertEqual(p1.num_total, 10)
+    self.assertEqual(p1.num_processed, 1)
+    self.assertEqual(p1.num_failed, 1)
+    self.assertEqual(p1.num_skipped, 1)
+    self.assertEqual(p1.num_completed, 3)
+    self.assertEqual(p1.start_time, 1.0)
+    self.assertEqual(p1.stop_time, 5.0)
+
 
 if __name__ == '__main__':
   unittest.main()
langfun/core/eval/v2/progress_tracking_test.py:

@@ -14,9 +14,11 @@
 import contextlib
 import io
 import os
+import sys
 import tempfile
 import unittest
 
+from langfun.core import concurrent as lf_concurrent
 from langfun.core import console as lf_console
 from langfun.core.eval.v2 import eval_test_helper
 from langfun.core.eval.v2 import progress_tracking  # pylint: disable=unused-import
@@ -49,6 +51,8 @@ class TqdmProgressTrackerTest(unittest.TestCase):
     string_io = io.StringIO()
     with contextlib.redirect_stderr(string_io):
       _ = experiment.run(root_dir, 'new', plugins=[])
+      sys.stderr.flush()
+      lf_concurrent.ProgressBar.refresh()
     self.assertIn('All: 100%', string_io.getvalue())
 
   def test_with_example_ids(self):
@@ -59,6 +63,8 @@ class TqdmProgressTrackerTest(unittest.TestCase):
     string_io = io.StringIO()
     with contextlib.redirect_stderr(string_io):
       _ = experiment.run(root_dir, 'new', example_ids=[1], plugins=[])
+      sys.stderr.flush()
+      lf_concurrent.ProgressBar.refresh()
     self.assertIn('All: 100%', string_io.getvalue())
 