langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512150805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +3 -0
  29. langfun/core/eval/v2/checkpointing.py +148 -46
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +102 -19
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +95 -20
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +88 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +73 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +350 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +213 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +14 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +90 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +52 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +78 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +78 -4
  104. langfun/core/modalities/mime_test.py +59 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512150805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/top_level.txt +0 -0
langfun/core/eval/v2/evaluation.py

@@ -32,17 +32,63 @@ import pyglove as pg
 
 
 class Evaluation(experiment_lib.Experiment):
-  """Evaluation.
-
-  An evaluation can be a leaf node or a container of other evaluations,
-  depending on whether the current evaluation object is configured with
-  any `pg.oneof`.
-
-  For example, `MyEval(lm=pg.oneof([lf.llms.Gpt4(), lf.llms.Gemini1_5Pro()]))`
-  is a container of two sub-experiments, one for each LLM. In such case, the
-  evaluation object with `pg.oneof` is called a hyper evaluation, which
-  represents a search space of evaluations, and each sub-evaluation is called
-  a leaf evaluation, which will perform the actual evaluation.
+  """Base class for Langfun evaluations.
+
+  `lf.eval.Evaluation` is the base class for defining evaluation tasks in
+  Langfun. Users typically subclass it to implement custom evaluation logic by
+  overriding `inputs` and `process` methods.
+
+  An `Evaluation` object encapsulates:
+
+  * **`inputs`**: A callable that returns an iterable of input examples to be
+    processed. This is usually provided by implementing an `inputs(self)`
+    method in the subclass, which yields input items for evaluation one by
+    one.
+  * **`process(self, example)`**: An abstract method that processes one
+    example and returns the output, or a tuple of (output, metadata).
+    The output will be used for computing metrics.
+  * **`metrics`**: A list of metrics (e.g., `lf.metrics.Accuracy`) to compute
+    based on the outputs from `process`. Some metrics may require users to
+    implement a `ground_truth(self, example)` method in the subclass to
+    compute metrics against ground truth.
+  * **Hyperparameters**: Any other attributes of the class serve as
+    hyperparameters for the evaluation (e.g., the language model to use).
+
+  **Running Evaluations:**
+
+  Evaluations are executed via `lf.eval.Suite` or by calling the `.run()`
+  method on an `Evaluation` instance, which returns a `Run` object
+  containing the evaluation run information and results. If an evaluation
+  contains sweeable parameters (using `pg.oneof`), `.run()` will expand it
+  into multiple evaluation sub-tasks -- one for each combination of
+  hyperparameters -- all managed within the same `Run`.
+
+  **Example:**
+
+  ```python
+  import langfun as lf
+  import pyglove as pg
+
+  class MyEval(lf.eval.Evaluation):
+    lm: lf.LanguageModel
+    prompt: str = '1 + 1 = '
+
+    def inputs(self):
+      yield 2
+
+    def process(self, example: lf.eval.Example):
+      return int(lf.query(self.prompt, lm=self.lm))
+
+    def ground_truth(self, example: lf.eval.Example) -> int:
+      return example.input
+
+  # Run evaluation using two different LMs
+  evaluation = MyEval(
+      lm=pg.oneof([lf.llms.Gpt4(), lf.llms.Gemini()]),
+      metrics=[lf.metrics.Accuracy()]
+  )
+  run_info = evaluation.run()
+  ```
   """
 
   inputs: Annotated[
@@ -68,6 +114,13 @@ class Evaluation(experiment_lib.Experiment):
     self._log_entries = []
     self._log_lock = threading.Lock()
 
+  def _identity(self) -> str:
+    """Returns the definition of the evaluation."""
+    return self.format(
+        compact=True, hide_default_values=True, use_inferred=True,
+        exclude_keys=('plugins', 'progress', 'usage_summary')
+    )
+
   #
   # Handling evaluation hierarchy (materialized vs. hyper evaluations).
   #
@@ -126,6 +179,20 @@
   # Evaluation logics.
   #
 
+  def setup(self) -> None:
+    """Sets up resources required by the evaluation.
+
+    Subclasses should always call the super().setup() method to ensure the
+    proper initialization of the evaluation.
+    """
+
+  def teardown(self) -> None:
+    """Tears down resources used by the evaluation.
+
+    Subclasses should always call the super().teardown() method to ensure the
+    proper cleanup of the evaluation.
+    """
+
   @abc.abstractmethod
   def process(
       self,
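The new `setup()` and `teardown()` hooks above are empty in the base class and exist to be overridden. A minimal sketch of a subclass using them, in the spirit of the `MyEval` docstring example earlier; the class name, resource, and path below are illustrative assumptions, not part of the diff:

```python
import langfun as lf


class MyEvalWithResources(lf.eval.Evaluation):
  """Hypothetical evaluation that acquires a shared resource for the run."""

  def setup(self) -> None:
    super().setup()  # Per the docstring, always call super().setup().
    # Acquire a resource shared by all examples (path is made up).
    self._lookup = open('/tmp/lookup_table.txt')

  def teardown(self) -> None:
    self._lookup.close()  # Release the resource.
    super().teardown()    # And always call super().teardown().

  def inputs(self):
    yield from range(3)

  def process(self, example: lf.eval.Example):
    return example.input * 2
```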
@@ -137,7 +204,7 @@
 
     Args:
       example: An example object to process. `example.input` is an object
-        returned from `Evaluable.inputs`.
+        yielded from `inputs()` method.
 
     Returns:
       A processed output. Or a tuple of (output, metadata).
@@ -150,6 +217,7 @@
       example: example_lib.Example | int,
       raise_if_has_error: bool = False,
       reevaluate_upon_previous_errors: bool = True,
+      force_recompute_metrics: bool = False
   ) -> example_lib.Example:
     """Evaluates a single example input.
 
@@ -158,6 +226,8 @@
       raise_if_has_error: Whether to raise an error if the example has error.
       reevaluate_upon_previous_errors: Whether to reevaluate the example if
         the previous checkpointed run has error.
+      force_recompute_metrics: If True, force recompute the metrics even if
+        metric metadata is already present from previous checkpoint.
 
     Returns:
       The evaluated example with the output and metric metadata populated.
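A sketch of how the new `force_recompute_metrics` flag could be used, reusing the `MyEval` shape from the `Evaluation` docstring above; the LM and metric choices are illustrative:

```python
import langfun as lf


class MyEval(lf.eval.Evaluation):  # Same shape as the docstring example.
  lm: lf.LanguageModel
  prompt: str = '1 + 1 = '

  def inputs(self):
    yield 2

  def process(self, example: lf.eval.Example):
    return int(lf.query(self.prompt, lm=self.lm))

  def ground_truth(self, example: lf.eval.Example) -> int:
    return example.input


exp = MyEval(lm=lf.llms.Gpt4(), metrics=[lf.metrics.Accuracy()])

# First evaluation of example #1 (IDs are 1-based). If a checkpoint exists,
# the previously computed output and metric metadata are reused.
ex = exp.evaluate(1)

# Recompute the metrics even though metric metadata is already present from
# the previous checkpointed run.
ex = exp.evaluate(1, force_recompute_metrics=True)
```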
@@ -206,6 +276,7 @@
         # Use the output and metadata obtained from the previous processing.
         example.output = checkpointed.output
         example.metadata = checkpointed.metadata
+        example.metric_metadata = checkpointed.metric_metadata
         example.error = checkpointed.error
         example.newly_processed = False
         example.execution_status = checkpointed.execution_status
@@ -225,8 +296,16 @@
     self.info(f'Starting metric computation for example {example.id}.')
     metric_metadata = {}
     for metric in self.metrics:
-      metric_metadata.update(metric.audit(example))
-    example.metric_metadata = metric_metadata
+      metric_metadata[metric.name] = metric.update(
+          example, force_recompute=force_recompute_metrics
+      )
+
+    if example.metric_metadata is None:
+      example.metric_metadata = metric_metadata
+    else:
+      # Accumulate the metric metadata as there might be existing metadata
+      # from previous metric computation runs.
+      example.metric_metadata.update(metric_metadata)
     self.info(f'Completed metric computation for example {example.id}.')
 
     # For previously processed examples, we keep the execution status for the
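The hunk above changes the shape of `example.metric_metadata`: instead of one flat dict produced by `Metric.audit`, it is now keyed by metric name, with each metric contributing its own sub-dict via `Metric.update`, and repeated metric runs merge into the existing dict. A sketch of the resulting shapes, inferred from the updated tests later in this diff (the second metric in the last line is hypothetical):

```python
# Before: flat keys shared by all metrics.
old_style = {'match': True}                    # or {'mismatch': True}, {'error': 'ValueError'}

# After: one entry per metric, keyed by metric.name.
correct = {'match': {'is_correct': True}}      # correct prediction
incorrect = {'match': {'is_correct': False}}   # incorrect prediction
errored = {'match': {'error': 'ValueError'}}   # processing raised an error

# Recomputation accumulates rather than replaces:
example_metric_metadata = {'match': {'is_correct': True}}
example_metric_metadata.update({'latency': {'seconds': 1.2}})  # hypothetical second metric
```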
@@ -287,7 +366,7 @@
       A unique string representing the resource required.
     """
     return {
-        v.resource_id for _, v in self.sym_init_args.items()
+        v.resource_id for _, v in self.sym_init_args.sym_items()
         if isinstance(v, lf.LanguageModel)
     }
 
@@ -307,10 +386,10 @@
       load_example_metadata: bool = True,
       filter: Callable[[example_lib.Example], bool] | None = None,  # pylint: disable=redefined-builtin
       raise_if_not_exist: bool = False
-  ) -> None:
+  ) -> list[example_lib.Example]:
     """Loads saved state from a sequence IO file."""
     if pg.io.path_exists(state_file):
-      self._state.load(
+      return self._state.load(
          state_file,
          example_input_by_id=self.example_input_by_id,
          load_example_metadata=load_example_metadata,
@@ -318,6 +397,7 @@
      )
    elif raise_if_not_exist:
      raise ValueError(f'State file {state_file} does not exist.')
+    return []
 
   def _reset(self) -> None:
     """Resets the state of the evaluation."""
@@ -760,7 +840,7 @@
 
 
 class EvaluationState:
-  """Evaluation state."""
+  """In-memory state of an evaluation."""
 
   class ExampleStatus(pg.Object):
     """Example state."""
@@ -808,8 +888,9 @@
       load_example_metadata: bool | Callable[
           [example_lib.Example], bool] = True,
       filter: Callable[[example_lib.Example], bool] | None = None,  # pylint: disable=redefined-builtin
-  ) -> None:
+  ) -> list[example_lib.Example]:
     """Loads the state from the example sequence file."""
+    examples = []
     for example in example_lib.Example.iter_ckpts(
         state_file,
         example_input_by_id=example_input_by_id,
@@ -819,6 +900,8 @@
         continue
       example.newly_processed = False
       self._ckpt_examples[example.id] = example
+      examples.append(example)
+    return examples
 
   @property
   def evaluation_status(self) -> dict[int, ExampleStatus]:
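Both the `Evaluation` state-loading method (earlier hunk) and `EvaluationState.load` now return the restored examples instead of `None`. A sketch of what a caller could do with the return value; the method name `load_state` and the checkpoint path are assumptions based on the signatures shown above:

```python
# `exp` is an Evaluation instance; the state file is a sequence-IO checkpoint
# of serialized Example records written by a previous run.
restored = exp.load_state(
    '/path/to/run/checkpoint.bagz',  # illustrative path
    raise_if_not_exist=False,
)
# Previously the call returned None; it now returns the list of restored
# Example objects ([] when the file does not exist).
for ex in restored:
  print(ex.id, ex.is_processed, ex.has_error)
```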
langfun/core/eval/v2/evaluation_test.py

@@ -88,7 +88,7 @@ class EvaluationTest(unittest.TestCase):
     self.assertEqual(example.output, 6)
     self.assertIsNone(example.error)
     self.assertEqual(example.metadata, {})
-    self.assertEqual(example.metric_metadata, dict(match=True))
+    self.assertEqual(example.metric_metadata, dict(match=dict(is_correct=True)))
     self.assertIsNotNone(example.usage_summary)
     self.assertGreater(example.usage_summary.total.total_tokens, 0)
     self.assertEqual(example.usage_summary.total.num_requests, 1)
@@ -103,7 +103,10 @@
     self.assertEqual(example.output, 7)
     self.assertIsNone(example.error)
     self.assertEqual(example.metadata, {})
-    self.assertEqual(example.metric_metadata, dict(mismatch=True))
+    self.assertEqual(
+        example.metric_metadata,
+        dict(match=dict(is_correct=False))
+    )
 
     with self.assertRaisesRegex(ValueError, 'x should not be 5'):
       _ = exp.evaluate(6, raise_if_has_error=True)
@@ -113,7 +116,10 @@
     self.assertEqual(pg.MISSING_VALUE, example.output)
     self.assertEqual(example.error.tag, 'ValueError')
     self.assertEqual(example.metadata, {})
-    self.assertEqual(example.metric_metadata, dict(error='ValueError'))
+    self.assertEqual(
+        example.metric_metadata,
+        dict(match=dict(error='ValueError'))
+    )
 
   def test_evaluate_withstate(self):
     eval_dir = os.path.join(tempfile.mkdtemp(), 'test_eval')
langfun/core/eval/v2/example.py

@@ -22,19 +22,30 @@ import pyglove as pg
 
 @dataclasses.dataclass
 class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
-  """An item for the evaluation.
+  """An example for evaluation.
+
+  An evaluation example contains the input and output of an evaluation task,
+  as well as metadata about the evaluation process, such as execution time,
+  LLM usage, and metric results.
 
   Attributes:
-    id: The 1-based ID of the item in the evaluation set.
-    input: An element returned from the `Evaluable.inputs` functor.
-    output: The output of the `process` method. If `pg.MISSING_VALUE`, it has
-      not been processed yet.
-    metadata: The metadata of the item produced by the `process` method.
-    metric_metadata: The dictionary returned from `Metric.audit`.
-    start_time: The start time of the evaluation item.
-    end_time: The end time of the evaluation item.
-    usage_summary: The summary of LLM usages of the evaluation item.
-    execution_status: The timeit status of the evaluation item.
+    id: The 1-based ID of the example in the evaluation set.
+    input: An element returned from the `Evaluable.inputs` functor, which serves
+      as the input for `lf.Evaluable.process`.
+    output: The output of `lf.Evaluable.process` method. If `pg.MISSING_VALUE`,
+      it indicates the example has not been processed yet.
+    error: The error raised from `lf.Evaluable.process`. If None, it
+      indicates the process was successful.
+    metadata: The metadata of the example produced by `lf.Evaluable.process`.
+    metric_metadata: The dictionary returned from `Metric.audit`, which contains
+      metadata about metric computation for this example.
+    newly_processed: Whether this example is processed in the current run. If
+      False, it indicates the example was loaded from a checkpoint from previous
+      runs.
+    start_time: The start time of processing this example.
+    end_time: The end time of processing this example.
+    usage_summary: The summary of LLM usages for processing this example.
+    execution_status: The timeit status of processing this example.
   """
   id: int
   input: Any = pg.MISSING_VALUE
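A small sketch of the fields documented above; in a real run these objects are created and populated by the runner rather than by hand:

```python
import langfun as lf
import pyglove as pg

ex = lf.eval.Example(id=1, input=pg.Dict(question='1 + 1 = '))
assert not ex.is_processed       # output is still pg.MISSING_VALUE
assert ex.error is None          # no processing error recorded

ex.output = 2                    # normally filled in from `process()`
ex.metric_metadata = {'match': {'is_correct': True}}  # keyed by metric name
assert ex.is_processed
```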
@@ -49,14 +60,6 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
   usage_summary: lf.UsageSummary | None = None
   execution_status: dict[str, pg.utils.TimeIt.Status] | None = None
 
-  def __post_init__(self):
-    if self.execution_status is not None:
-      for status in self.execution_status.values():
-        if status.has_error:
-          assert isinstance(status.error, pg.ErrorInfo)
-          self.error = status.error
-          break
-
   @property
   def is_processed(self) -> bool:
     """Returns whether the item has been processed."""
@@ -152,6 +155,8 @@
       ckpt_file: str | list[str],
       example_input_by_id: Callable[[int], Any] | None = None,
       load_example_metadata: bool = True,
+      convert_unknown: bool = True,
+      **kwargs
   ) -> Iterator['Example']:
     """Iterates Examples from the checkpoint files."""
     ckpt_files = [ckpt_file] if isinstance(ckpt_file, str) else ckpt_file
@@ -161,7 +166,9 @@
         example = pg.from_json_str(
             record,
             example_input_by_id=example_input_by_id,
-            load_example_metadata=load_example_metadata
+            load_example_metadata=load_example_metadata,
+            convert_unknown=convert_unknown,
+            **kwargs
         )
         assert isinstance(example, cls), example
         yield example
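`iter_ckpts` now forwards `convert_unknown` (plus any extra keyword arguments) to `pg.from_json_str`, so checkpoint records whose value types are no longer registered can still be loaded as unknown-typed stand-ins instead of failing. A sketch, with an illustrative checkpoint path:

```python
import langfun as lf

for ex in lf.eval.Example.iter_ckpts(
    '/path/to/run/checkpoint.jsonl',  # illustrative path
    load_example_metadata=True,
    convert_unknown=True,             # unknown types deserialize as stand-ins
):
  print(ex.id, type(ex.output))
```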
@@ -182,15 +189,23 @@
     extra_flags = extra_flags or {}
     num_examples = extra_flags.get('num_examples', None)
 
-    def _metric_metadata_badge(key, value):
-      if isinstance(value, bool) and bool:
-        text = key
-      else:
-        text = f'{key}:{value}'
-      return pg.views.html.controls.Badge(
-          text,
-          css_classes=[pg.utils.camel_to_snake(key, '-')],
-      )
+    def _metric_label_group(metric_metadata: dict[str, Any] | None):
+      """Renders a label group for metric metadata."""
+      badges = []
+      if metric_metadata:
+        for metric_name, metadata in metric_metadata.items():
+          assert isinstance(metadata, dict), (metric_name, metadata)
+          for k, v in metadata.items():
+            css_class = k
+            if isinstance(v, bool):
+              css_class += '_true' if v else '_false'
+            badge = pg.views.html.controls.Badge(
+                f'{k}:{v}',
+                tooltip=f'{metric_name}: {k}',
+                css_classes=[css_class],
+            )
+            badges.append(badge)
+      return pg.views.html.controls.LabelGroup(badges)
 
     def _render_header():
       return pg.Html.element(
@@ -229,12 +244,7 @@
               extra_flags=dict(as_badge=True)
           ) if self.usage_summary is not None else None,
           # Metric metadata.
-          pg.views.html.controls.LabelGroup(
-              [  # pylint: disable=g-long-ternary
-                  _metric_metadata_badge(k, v)
-                  for k, v in self.metric_metadata.items()
-              ] if self.metric_metadata else []
-          ),
+          _metric_label_group(self.metric_metadata)
         ],
         css_classes=['example-container'],
     )
@@ -305,18 +315,18 @@
          color: black;
        }
        /* Badge styles. */
-        .eval-example .badge.match {
+        .eval-example .badge.is_correct_true {
          color: green;
          background-color: #dcefbe;
        }
+        .eval-example .badge.is_correct_false {
+          color: orange;
+          background-color: #ffefc4;
+        }
        .eval-example .badge.error {
          color: red;
          background-color: #fdcccc;
        }
-        .eval-example .badge.mismatch {
-          color: orange;
-          background-color: #ffefc4;
-        }
        .eval-example .badge.score {
          color: blue;
          background-color: #c4dced;
langfun/core/eval/v2/example_test.py

@@ -32,9 +32,9 @@ class ExampleTest(unittest.TestCase):
            name='evaluation', elapse=1.0, error=error
        )
    })
-    self.assertEqual(ex.error, error)
+    self.assertIsNone(ex.error)
     self.assertFalse(ex.is_processed)
-    self.assertTrue(ex.has_error)
+    self.assertFalse(ex.has_error)
     self.assertEqual(ex.elapse, 1.0)
 
     ex = Example(id=2, output=1)
@@ -94,15 +94,23 @@
     pg.JSONConvertible._TYPE_REGISTRY._type_to_cls_map.pop(
         inputs[0].b.__type_name__
     )
-    v = pg.from_json_str(json_str, auto_dict=True, load_example_metadata=True)
-    v.output.pop('type_name')
-    v.metadata.b.pop('type_name')
+    v = pg.from_json_str(
+        json_str,
+        convert_unknown=True,
+        load_example_metadata=True
+    )
     self.assertEqual(
         v,
         Example(
            id=1,
-            output=pg.Dict(x=1),
-            metadata=dict(b=pg.Dict(x=1, y=2)),
+            output=pg.symbolic.UnknownTypedObject(
+                inputs[0].a.__type_name__, x=1
+            ),
+            metadata=dict(
+                b=pg.symbolic.UnknownTypedObject(
+                    inputs[0].b.__type_name__, x=1, y=2
+                )
+            ),
        )
    )
     # Serialize with input.
@@ -116,7 +124,7 @@
        input=pg.Dict(a=1, b=2),
        output=3,
        metadata=dict(sum=3),
-        metric_metadata=dict(match=True),
+        metric_metadata=dict(match=dict(match=True)),
    )
     self.assertNotIn(
        'next',
langfun/core/eval/v2/experiment.py

@@ -139,10 +139,10 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
 
   # Checkpointing
 
-  Experiments support checkpointing, which is enabled by default. It allows
+  Experiments support checkpointing, which is enabled by default. It allows
   users to resume their experiments from a saved state. When an experiment runs,
-  it creates a new directory for that run and saves the current state to a
-  checkpoint file. If the experiment is interrupted or fails, users can resume
+  it creates a new directory for that run and saves its progress to checkpoint
+  files. If the experiment is interrupted or fails, users can resume
   it by specifying the 'id' or 'warm_start_from' argument (shown above) to
   seamlessly continue from previously saved state without starting over.
 
@@ -169,7 +169,7 @@
 
   # Experiment Plugins
 
-  Experiment can be extended by plugins. Plugins can listen to the events of
+  Experiments can be extended by plugins. Plugins can listen to the events of
   experiment execution and produce additional outputs. For example, a plugin
   can be added to an experiment to generate additional metrics or to save
   additional data to a database. More details will be added in the future.
@@ -268,11 +268,11 @@
   @functools.cached_property
   def hash(self) -> str:
     """A 8-byte MD5 hash computed from experiment identity."""
-    identity = self.format(
-        compact=True, hide_default_values=True, use_inferred=True,
-        exclude_keys=('plugins', 'progress', 'usage_summary')
-    )
-    return hashlib.md5(identity.encode()).hexdigest()[:8]
+    return hashlib.md5(self._identity().encode()).hexdigest()[:8]
+
+  @abc.abstractmethod
+  def _identity(self) -> str:
+    """Returns the identity of the experiment."""
 
   @classmethod
   def link(cls, path: str) -> str:
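With this change `Experiment.hash` is computed from a new abstract `_identity()` method: leaf evaluations format themselves while excluding `plugins`, `progress` and `usage_summary` (see the `_identity` hunk near the top of this diff), and `Suite` joins its children's identities (later hunk). A sketch of the relationship; the helper name and the identity strings are illustrative only:

```python
import hashlib


def experiment_hash(identity: str) -> str:
  # What `hash` now does with whatever `_identity()` returns:
  # the first 8 hex characters of the MD5 digest.
  return hashlib.md5(identity.encode()).hexdigest()[:8]


# A suite's identity is the bracketed, comma-joined identities of its children,
# so two suites with the same children in the same order share a hash.
suite_identity = '[' + ', '.join(
    ['MyEval(lm=Gpt4())', 'MyEval(lm=Gemini())']) + ']'
print(experiment_hash(suite_identity))
```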
@@ -657,7 +657,30 @@
 
 @pg.use_init_args(['children'])
 class Suite(Experiment):
-  """A suite of evaluations."""
+  """A suite of evaluations.
+
+  `lf.eval.Suite` groups multiple `lf.eval.Evaluation` or other `Suite`
+  objects into a single experiment, allowing them to be run, managed, and
+  reported together.
+
+  **Example:**
+
+  ```python
+  import langfun as lf
+
+  suite = lf.eval.Suite([
+      MyEval(lm=lf.llms.Gpt4()),
+      MyEval(lm=lf.llms.Gemini()),
+      lf.eval.Suite([
+          AnotherEval(lm=lf.llms.Gpt4()),
+          AnotherEval(lm=lf.llms.Gemini())
+      ])
+  ])
+
+  # Run all evaluations in the suite
+  run_info = suite.run('/path/to/my/suite_run')
+  ```
+  """
 
   children: Annotated[
       list[Experiment], 'A list of child experiments.'
@@ -668,6 +691,12 @@
     """Returns whether the task is a leaf."""
     return False
 
+  def _identity(self) -> str:
+    """Returns the definition of the evaluation."""
+    return '[' + ', '.join(
+        [child._identity() for child in self.children]  # pylint: disable=protected-access
+    ) + ']'
+
 
 class RunId(pg.Object):
   """Structured repreesentation a experiment run ID."""
@@ -791,7 +820,14 @@
 
 
 class Run(pg.Object, pg.views.html.HtmlTreeView.Extension):
-  """A run of an experiment."""
+  """Represents a single run of an experiment.
+
+  A `Run` object holds all the configurations for executing an experiment,
+  such as the experiment definition, input/output directories, and flags
+  controlling the execution behavior (e.g., error handling, checkpointing).
+  It also provides utility methods for accessing run-specific paths and
+  filtering examples for evaluation.
+  """
 
   root_dir: Annotated[
       str,
@@ -818,10 +854,10 @@
   ] = None
 
   example_ids: Annotated[
-      list[int] | None,
+      list[int] | Callable[[Experiment], list[int]] | None,
      (
-          'The example IDs to run. If None, it will run all examples. '
-          'Though '
+          'The example IDs to run. Or a callable for determining the examples '
+          'to run based on the experiment. If None, it will run all examples. '
      )
   ] = None
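`Run.example_ids` now also accepts a callable that receives the experiment and returns the IDs to run; `examples_to_evaluate` (next hunk) resolves either form. A sketch of both forms; passing `example_ids` through `Evaluation.run()` as a keyword argument is an assumption here, as is the run directory:

```python
# `evaluation` is an Evaluation instance, e.g. the MyEval example earlier.

# Fixed list of 1-based example IDs.
run_info = evaluation.run('/path/to/run_root', example_ids=[1, 2, 3])

# Callable form: decide per (leaf) experiment which IDs to run -- here, the
# first half of its examples.
run_info = evaluation.run(
    '/path/to/run_root',
    example_ids=lambda exp: list(range(1, exp.num_examples // 2 + 1)),
)
```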
 
@@ -937,10 +973,13 @@
     """Returns the example IDs to evaluate."""
     if not experiment.is_leaf:
       return set()
-    return set(
-        self.example_ids if self.example_ids else
-        range(1, experiment.num_examples + 1)
-    )
+    if self.example_ids is None:
+      return set(range(1, experiment.num_examples + 1))
+    elif isinstance(self.example_ids, Callable):
+      return set(self.example_ids(experiment))
+    else:
+      assert isinstance(self.example_ids, list), self.example_ids
+      return set(self.example_ids)
 
   def examples_to_reprocess(self, experiment: Experiment) -> set[int]:
     """Returns the example IDs to reprocess per request."""
@@ -971,7 +1010,13 @@
 
 
 class Runner(pg.Object):
-  """Interface for experiment runner."""
+  """Interface for experiment runner.
+
+  A runner is responsible for executing the evaluations within an experiment
+  based on the configuration specified in a `Run` object. Different runners
+  can implement different execution strategies, such as sequential or parallel
+  processing of examples and evaluations.
+  """
 
   # Class-level variable for registering the runner.
   NAME = None
@@ -1010,7 +1055,37 @@
 
 
 class Plugin(lf.Component):
-  """Base class for experiment plugins."""
+  """Base class for experiment plugins.
+
+  Plugins provide a mechanism to extend the behavior of an experiment run
+  by hooking into various events during the lifecycle of experiment and
+  example execution, such as `on_run_start`, `on_experiment_complete`,
+  `on_example_start`, etc. They can be used for custom logging, monitoring,
+  or result processing.
+  """
+
+  @classmethod
+  def is_per_example(cls) -> bool:
+    """Returns whether the plugin is per example only.
+
+    Per-example plugins can be installed on individual workers when examples
+    are evaluated by multiple processes in parallel.
+    """
+
+    def same_code(method1, method2):
+      return method1.__code__ == method2.__code__
+    return all(
+        same_code(method1, method2)
+        for method1, method2 in [
+            (Plugin.on_run_start, cls.on_run_start),
+            (Plugin.on_run_complete, cls.on_run_complete),
+            (Plugin.on_run_abort, cls.on_run_abort),
+            (Plugin.on_experiment_start, cls.on_experiment_start),
+            (Plugin.on_experiment_skipped, cls.on_experiment_skipped),
+            (Plugin.on_experiment_complete, cls.on_experiment_complete),
+            (Plugin.on_experiment_abort, cls.on_experiment_abort),
+        ]
+    )
 
   def on_run_start(
       self,
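The new `is_per_example()` classmethod reports whether a plugin overrides only per-example hooks, by comparing the `__code__` of the run- and experiment-level handlers against the base class; such plugins can then be installed on parallel workers. A generic sketch of the detection technique (plain Python, not the Langfun classes):

```python
class Base:
  def on_run_start(self):
    pass

  def on_example_complete(self):
    pass


class PerExampleOnly(Base):
  # Only a per-example hook is overridden.
  def on_example_complete(self):
    print('example done')


def overrides(cls, base, name: str) -> bool:
  # The subclass shares the base method's __code__ iff it did not override it.
  return getattr(cls, name).__code__ is not getattr(base, name).__code__


assert not overrides(PerExampleOnly, Base, 'on_run_start')
assert overrides(PerExampleOnly, Base, 'on_example_complete')
```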
langfun/core/eval/v2/experiment_test.py

@@ -433,5 +433,24 @@ class RunnerTest(unittest.TestCase):
     pass
 
 
+class PluginTest(unittest.TestCase):
+
+  def test_per_example_only(self):
+
+    class PerExamplePlugin(experiment_lib.Plugin):
+
+      def on_example_complete(self, runner, experiment, example):
+        print('on_example_complete')
+
+    self.assertTrue(PerExamplePlugin.is_per_example())
+
+    class NonPerExamplePlugin(experiment_lib.Plugin):
+
+      def on_experiment_complete(self, runner, experiment):
+        print('on_example_complete')
+
+    self.assertFalse(NonPerExamplePlugin.is_per_example())
+
+
 if __name__ == '__main__':
   unittest.main()