langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512040805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +2 -0
  29. langfun/core/eval/v2/checkpointing.py +76 -7
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +92 -17
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +84 -15
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +90 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +72 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +12 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +64 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +34 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +58 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +64 -3
  104. langfun/core/modalities/mime_test.py +11 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512040805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/top_level.txt +0 -0
langfun/core/eval/v2/config_saver_test.py (new file)
@@ -0,0 +1,36 @@
+ # Copyright 2024 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Config saver test."""
+
+ import os
+ import tempfile
+ import unittest
+ from langfun.core.eval.v2 import config_saver
+ from langfun.core.eval.v2 import eval_test_helper
+ from langfun.core.eval.v2.runners import parallel  # pylint: disable=unused-import
+
+
+ class RunConfigSaverTest(unittest.TestCase):
+
+   def test_save_run_config(self):
+     root_dir = os.path.join(tempfile.mkdtemp(), 'test_run_config_saver')
+     experiment = eval_test_helper.test_evaluation()
+     run = experiment.run(
+         root_dir, 'new', plugins=[config_saver.RunConfigSaver()]
+     )
+     self.assertTrue(os.path.exists(os.path.join(run.output_root, 'run.json')))
+
+
+ if __name__ == '__main__':
+   unittest.main()

langfun/core/eval/v2/eval_test_helper.py
@@ -13,6 +13,9 @@
  # limitations under the License.
  """Helper classes and functions for evaluation tests."""

+ import threading
+ import time
+
  from langfun.core import language_model
  from langfun.core import llms
  from langfun.core import message as message_lib
@@ -47,6 +50,8 @@ class TestLLM(llms.Fake):

    offset: int = 0

+   __test__ = False
+
    def _response_from(self, prompt: message_lib.Message) -> message_lib.Message:
      return message_lib.AIMessage(
          str(prompt.metadata.x + prompt.metadata.y + self.offset)
@@ -63,6 +68,8 @@ class TestEvaluation(Evaluation):
    metrics = [metrics_lib.Match()]
    lm: language_model.LanguageModel = TestLLM()

+   __test__ = False
+
    def process(self, example):
      v = example.input
      if v.x == 5:
@@ -75,7 +82,7 @@ class TestEvaluation(Evaluation):

  class BadJsonConvertible(pg.Object):

-   def to_json(self, *args, **kwargs):
+   def sym_jsonify(self, *args, **kwargs):
      raise ValueError('Cannot convert to JSON.')


@@ -84,6 +91,8 @@ class TestEvaluationWithExampleCheckpointingError(TestEvaluation):
    inputs = test_inputs()
    metrics = [metrics_lib.Match()]

+   __test__ = False
+
    def process(self, example):
      return 1, dict(
          x=BadJsonConvertible()
@@ -101,6 +110,8 @@ class TestEvaluationWithExampleHtmlGenerationError(Evaluation):
    inputs = test_inputs()
    metrics = [metrics_lib.Match()]

+   __test__ = False
+
    def process(self, example):
      return 1, dict(
          x=BadHtmlConvertible()
@@ -110,15 +121,22 @@ class TestEvaluationWithExampleHtmlGenerationError(Evaluation):
  class TestEvaluationWithIndexHtmlGenerationError(TestEvaluation):
    """Test evaluation class with bad index HTML generation."""

+   __test__ = False
+
    def _html_tree_view(self, *args, **kwargs):
      raise ValueError('Cannot render HTML.')


+ def test_evaluation(offset: int | pg.hyper.OneOf = 0):
+   """Returns a test evaluation."""
+   return TestEvaluation(lm=TestLLM(offset=offset))
+
+
  def test_experiment():
    """Returns a test experiment."""
    return Suite([
-       TestEvaluation(lm=TestLLM(offset=0)),
-       TestEvaluation(lm=TestLLM(offset=pg.oneof(range(5)))),
+       test_evaluation(),
+       test_evaluation(pg.oneof(range(5))),
    ])


@@ -135,3 +153,86 @@ def test_experiment_with_example_html_generation_error():
  def test_experiment_with_index_html_generation_error():
    """Returns a test experiment with bad index HTML."""
    return TestEvaluationWithIndexHtmlGenerationError()
+
+
+ class TestPlugin(experiment_lib.Plugin):
+   """Plugin for testing."""
+
+   started_experiments: list[experiment_lib.Experiment] = []
+   completed_experiments: list[experiment_lib.Experiment] = []
+   skipped_experiments: list[experiment_lib.Experiment] = []
+   started_example_ids: list[int] = []
+   completed_example_ids: list[int] = []
+   start_time: float | None = None
+   complete_time: float | None = None
+
+   __test__ = False
+
+   def _on_bound(self):
+     super()._on_bound()
+     self._lock = threading.Lock()
+
+   def on_run_start(
+       self,
+       runner: experiment_lib.Runner,
+       root: experiment_lib.Experiment
+   ) -> None:
+     del root
+     with pg.notify_on_change(False), pg.allow_writable_accessors(True):
+       self.start_time = time.time()
+
+   def on_run_complete(
+       self,
+       runner: experiment_lib.Runner,
+       root: experiment_lib.Experiment
+   ) -> None:
+     del root
+     with pg.notify_on_change(False), pg.allow_writable_accessors(True):
+       self.complete_time = time.time()
+
+   def on_experiment_start(
+       self,
+       runner: experiment_lib.Runner,
+       experiment: experiment_lib.Experiment
+   ) -> None:
+     del runner
+     with pg.notify_on_change(False), self._lock:
+       self.started_experiments.append(pg.Ref(experiment))
+
+   def on_experiment_skipped(
+       self,
+       runner: experiment_lib.Runner,
+       experiment: experiment_lib.Experiment
+   ) -> None:
+     del runner
+     with pg.notify_on_change(False), self._lock:
+       self.skipped_experiments.append(pg.Ref(experiment))
+
+   def on_experiment_complete(
+       self,
+       runner: experiment_lib.Runner,
+       experiment: experiment_lib.Experiment
+   ) -> None:
+     del runner
+     with pg.notify_on_change(False), self._lock:
+       self.completed_experiments.append(pg.Ref(experiment))
+
+   def on_example_start(
+       self,
+       runner: experiment_lib.Runner,
+       experiment: experiment_lib.Experiment,
+       example: Example
+   ) -> None:
+     del runner, experiment
+     with pg.notify_on_change(False), self._lock:
+       self.started_example_ids.append(example.id)
+
+   def on_example_complete(
+       self,
+       runner: experiment_lib.Runner,
+       experiment: experiment_lib.Experiment,
+       example: Example
+   ) -> None:
+     del runner, experiment
+     with pg.notify_on_change(False), self._lock:
+       self.completed_example_ids.append(example.id)
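
Note: the `TestPlugin` added above exercises the full `experiment_lib.Plugin` callback surface (run, experiment, and example lifecycle hooks). The following is a minimal sketch of a custom plugin in the same style; `ExampleCounter` is hypothetical and not part of langfun, and the `plugins=[...]` wiring mirrors the `experiment.run(...)` call shown in `config_saver_test.py` above.

```python
import pyglove as pg

from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import experiment as experiment_lib


class ExampleCounter(experiment_lib.Plugin):
  """Hypothetical plugin that counts completed examples (illustrative only)."""

  num_completed: int = 0

  def on_example_complete(
      self,
      runner: experiment_lib.Runner,
      experiment: experiment_lib.Experiment,
      example: example_lib.Example,
  ) -> None:
    del runner, experiment, example
    # Same mutation pattern as TestPlugin above: update a symbolic field
    # in place without triggering change notifications.
    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
      self.num_completed += 1


# Usage sketch (mirrors config_saver_test.py): pass the plugin to `run()`.
# run = experiment.run(root_dir, 'new', plugins=[ExampleCounter()])
```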

langfun/core/eval/v2/evaluation.py
@@ -32,17 +32,63 @@ import pyglove as pg


  class Evaluation(experiment_lib.Experiment):
-   """Evaluation.
-
-   An evaluation can be a leaf node or a container of other evaluations,
-   depending on whether the current evaluation object is configured with
-   any `pg.oneof`.
-
-   For example, `MyEval(lm=pg.oneof([lf.llms.Gpt4(), lf.llms.Gemini1_5Pro()]))`
-   is a container of two sub-experiments, one for each LLM. In such case, the
-   evaluation object with `pg.oneof` is called a hyper evaluation, which
-   represents a search space of evaluations, and each sub-evaluation is called
-   a leaf evaluation, which will perform the actual evaluation.
+   """Base class for Langfun evaluations.
+
+   `lf.eval.Evaluation` is the base class for defining evaluation tasks in
+   Langfun. Users typically subclass it to implement custom evaluation logic by
+   overriding `inputs` and `process` methods.
+
+   An `Evaluation` object encapsulates:
+
+   * **`inputs`**: A callable that returns an iterable of input examples to be
+     processed. This is usually provided by implementing an `inputs(self)`
+     method in the subclass, which yields input items for evaluation one by
+     one.
+   * **`process(self, example)`**: An abstract method that processes one
+     example and returns the output, or a tuple of (output, metadata).
+     The output will be used for computing metrics.
+   * **`metrics`**: A list of metrics (e.g., `lf.metrics.Accuracy`) to compute
+     based on the outputs from `process`. Some metrics may require users to
+     implement a `ground_truth(self, example)` method in the subclass to
+     compute metrics against ground truth.
+   * **Hyperparameters**: Any other attributes of the class serve as
+     hyperparameters for the evaluation (e.g., the language model to use).
+
+   **Running Evaluations:**
+
+   Evaluations are executed via `lf.eval.Suite` or by calling the `.run()`
+   method on an `Evaluation` instance, which returns a `Run` object
+   containing the evaluation run information and results. If an evaluation
+   contains sweepable parameters (using `pg.oneof`), `.run()` will expand it
+   into multiple evaluation sub-tasks -- one for each combination of
+   hyperparameters -- all managed within the same `Run`.
+
+   **Example:**
+
+   ```python
+   import langfun as lf
+   import pyglove as pg
+
+   class MyEval(lf.eval.Evaluation):
+     lm: lf.LanguageModel
+     prompt: str = '1 + 1 = '
+
+     def inputs(self):
+       yield 2
+
+     def process(self, example: lf.eval.Example):
+       return int(lf.query(self.prompt, lm=self.lm))
+
+     def ground_truth(self, example: lf.eval.Example) -> int:
+       return example.input
+
+   # Run evaluation using two different LMs
+   evaluation = MyEval(
+       lm=pg.oneof([lf.llms.Gpt4(), lf.llms.Gemini()]),
+       metrics=[lf.metrics.Accuracy()]
+   )
+   run_info = evaluation.run()
+   ```
    """

    inputs: Annotated[
@@ -126,6 +172,20 @@ class Evaluation(experiment_lib.Experiment):
    # Evaluation logics.
    #

+   def setup(self) -> None:
+     """Sets up resources required by the evaluation.
+
+     Subclasses should always call the super().setup() method to ensure the
+     proper initialization of the evaluation.
+     """
+
+   def teardown(self) -> None:
+     """Tears down resources used by the evaluation.
+
+     Subclasses should always call the super().teardown() method to ensure the
+     proper cleanup of the evaluation.
+     """
+
    @abc.abstractmethod
    def process(
        self,
@@ -137,7 +197,7 @@ class Evaluation(experiment_lib.Experiment):

      Args:
        example: An example object to process. `example.input` is an object
-         returned from `Evaluable.inputs`.
+         yielded from `inputs()` method.

      Returns:
        A processed output. Or a tuple of (output, metadata).
@@ -150,6 +210,7 @@
        example: example_lib.Example | int,
        raise_if_has_error: bool = False,
        reevaluate_upon_previous_errors: bool = True,
+       force_recompute_metrics: bool = False
    ) -> example_lib.Example:
      """Evaluates a single example input.

@@ -158,6 +219,8 @@
        raise_if_has_error: Whether to raise an error if the example has error.
        reevaluate_upon_previous_errors: Whether to reevaluate the example if
          the previous checkpointed run has error.
+       force_recompute_metrics: If True, force recompute the metrics even if
+         metric metadata is already present from previous checkpoint.

      Returns:
        The evaluated example with the output and metric metadata populated.
@@ -206,6 +269,7 @@
        # Use the output and metadata obtained from the previous processing.
        example.output = checkpointed.output
        example.metadata = checkpointed.metadata
+       example.metric_metadata = checkpointed.metric_metadata
        example.error = checkpointed.error
        example.newly_processed = False
        example.execution_status = checkpointed.execution_status
@@ -225,8 +289,16 @@
      self.info(f'Starting metric computation for example {example.id}.')
      metric_metadata = {}
      for metric in self.metrics:
-       metric_metadata.update(metric.audit(example))
-     example.metric_metadata = metric_metadata
+       metric_metadata[metric.name] = metric.update(
+           example, force_recompute=force_recompute_metrics
+       )
+
+     if example.metric_metadata is None:
+       example.metric_metadata = metric_metadata
+     else:
+       # Accumulate the metric metadata as there might be existing metadata
+       # from previous metric computation runs.
+       example.metric_metadata.update(metric_metadata)
      self.info(f'Completed metric computation for example {example.id}.')

      # For previously processed examples, we keep the execution status for the
@@ -287,7 +359,7 @@
        A unique string representing the resource required.
      """
      return {
-         v.resource_id for _, v in self.sym_init_args.items()
+         v.resource_id for _, v in self.sym_init_args.sym_items()
          if isinstance(v, lf.LanguageModel)
      }

@@ -760,7 +832,7 @@


  class EvaluationState:
-   """Evaluation state."""
+   """In-memory state of an evaluation."""

    class ExampleStatus(pg.Object):
      """Example state."""
@@ -808,8 +880,9 @@
        load_example_metadata: bool | Callable[
            [example_lib.Example], bool] = True,
        filter: Callable[[example_lib.Example], bool] | None = None,  # pylint: disable=redefined-builtin
-   ) -> None:
+   ) -> list[example_lib.Example]:
      """Loads the state from the example sequence file."""
+     examples = []
      for example in example_lib.Example.iter_ckpts(
          state_file,
          example_input_by_id=example_input_by_id,
@@ -819,6 +892,8 @@
          continue
        example.newly_processed = False
        self._ckpt_examples[example.id] = example
+       examples.append(example)
+     return examples

    @property
    def evaluation_status(self) -> dict[int, ExampleStatus]:
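
Note: the evaluation.py hunks above add `setup()`/`teardown()` lifecycle hooks and a `force_recompute_metrics` argument to `evaluate()`. Below is a hedged sketch of a subclass using the new hooks, following the import paths used in the class docstring above; `MyCachedEval` and its `_cache` resource are illustrative, not part of langfun.

```python
import langfun as lf


class MyCachedEval(lf.eval.Evaluation):  # hypothetical subclass
  lm: lf.LanguageModel
  prompt: str = '1 + 1 = '

  def setup(self) -> None:
    super().setup()   # per the docstring, always call super().setup()
    self._cache = {}  # illustrative per-run resource

  def teardown(self) -> None:
    self._cache.clear()  # illustrative cleanup
    super().teardown()

  def inputs(self):
    yield 2

  def process(self, example: lf.eval.Example):
    key = example.input
    if key not in self._cache:
      self._cache[key] = int(lf.query(self.prompt, lm=self.lm))
    return self._cache[key]

  def ground_truth(self, example: lf.eval.Example) -> int:
    return example.input
```

With the new flag, metrics for a previously checkpointed example can be recomputed without reprocessing it, e.g. `evaluation.evaluate(1, force_recompute_metrics=True)`.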

langfun/core/eval/v2/evaluation_test.py
@@ -88,7 +88,7 @@ class EvaluationTest(unittest.TestCase):
      self.assertEqual(example.output, 6)
      self.assertIsNone(example.error)
      self.assertEqual(example.metadata, {})
-     self.assertEqual(example.metric_metadata, dict(match=True))
+     self.assertEqual(example.metric_metadata, dict(match=dict(is_correct=True)))
      self.assertIsNotNone(example.usage_summary)
      self.assertGreater(example.usage_summary.total.total_tokens, 0)
      self.assertEqual(example.usage_summary.total.num_requests, 1)
@@ -103,7 +103,10 @@ class EvaluationTest(unittest.TestCase):
      self.assertEqual(example.output, 7)
      self.assertIsNone(example.error)
      self.assertEqual(example.metadata, {})
-     self.assertEqual(example.metric_metadata, dict(mismatch=True))
+     self.assertEqual(
+         example.metric_metadata,
+         dict(match=dict(is_correct=False))
+     )

      with self.assertRaisesRegex(ValueError, 'x should not be 5'):
        _ = exp.evaluate(6, raise_if_has_error=True)
@@ -113,7 +116,10 @@ class EvaluationTest(unittest.TestCase):
      self.assertEqual(pg.MISSING_VALUE, example.output)
      self.assertEqual(example.error.tag, 'ValueError')
      self.assertEqual(example.metadata, {})
-     self.assertEqual(example.metric_metadata, dict(error='ValueError'))
+     self.assertEqual(
+         example.metric_metadata,
+         dict(match=dict(error='ValueError'))
+     )

    def test_evaluate_withstate(self):
      eval_dir = os.path.join(tempfile.mkdtemp(), 'test_eval')
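
Note: the updated assertions reflect the new `metric_metadata` layout, where entries are keyed by metric name (e.g. `match`) and each holds that metric's own metadata dict. A small sketch of reading it, assuming an `Evaluation` instance named `evaluation` as in the docstring example above:

```python
# Sketch: inspecting per-metric metadata under the new nested schema.
example = evaluation.evaluate(1)
match_info = example.metric_metadata['match']  # e.g. {'is_correct': True}
is_correct = match_info.get('is_correct', False)
```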

langfun/core/eval/v2/example.py
@@ -22,19 +22,30 @@ import pyglove as pg

  @dataclasses.dataclass
  class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
-   """An item for the evaluation.
+   """An example for evaluation.
+
+   An evaluation example contains the input and output of an evaluation task,
+   as well as metadata about the evaluation process, such as execution time,
+   LLM usage, and metric results.

    Attributes:
-     id: The 1-based ID of the item in the evaluation set.
-     input: An element returned from the `Evaluable.inputs` functor.
-     output: The output of the `process` method. If `pg.MISSING_VALUE`, it has
-       not been processed yet.
-     metadata: The metadata of the item produced by the `process` method.
-     metric_metadata: The dictionary returned from `Metric.audit`.
-     start_time: The start time of the evaluation item.
-     end_time: The end time of the evaluation item.
-     usage_summary: The summary of LLM usages of the evaluation item.
-     execution_status: The timeit status of the evaluation item.
+     id: The 1-based ID of the example in the evaluation set.
+     input: An element returned from the `Evaluable.inputs` functor, which serves
+       as the input for `lf.Evaluable.process`.
+     output: The output of `lf.Evaluable.process` method. If `pg.MISSING_VALUE`,
+       it indicates the example has not been processed yet.
+     error: The error raised from `lf.Evaluable.process`. If None, it
+       indicates the process was successful.
+     metadata: The metadata of the example produced by `lf.Evaluable.process`.
+     metric_metadata: The dictionary returned from `Metric.audit`, which contains
+       metadata about metric computation for this example.
+     newly_processed: Whether this example is processed in the current run. If
+       False, it indicates the example was loaded from a checkpoint from previous
+       runs.
+     start_time: The start time of processing this example.
+     end_time: The end time of processing this example.
+     usage_summary: The summary of LLM usages for processing this example.
+     execution_status: The timeit status of processing this example.
    """
    id: int
    input: Any = pg.MISSING_VALUE
@@ -49,14 +60,6 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
    usage_summary: lf.UsageSummary | None = None
    execution_status: dict[str, pg.utils.TimeIt.Status] | None = None

-   def __post_init__(self):
-     if self.execution_status is not None:
-       for status in self.execution_status.values():
-         if status.has_error:
-           assert isinstance(status.error, pg.ErrorInfo)
-           self.error = status.error
-           break
-
    @property
    def is_processed(self) -> bool:
      """Returns whether the item has been processed."""
@@ -152,6 +155,8 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
        ckpt_file: str | list[str],
        example_input_by_id: Callable[[int], Any] | None = None,
        load_example_metadata: bool = True,
+       convert_unknown: bool = True,
+       **kwargs
    ) -> Iterator['Example']:
      """Iterates Examples from the checkpoint files."""
      ckpt_files = [ckpt_file] if isinstance(ckpt_file, str) else ckpt_file
@@ -161,7 +166,9 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
          example = pg.from_json_str(
              record,
              example_input_by_id=example_input_by_id,
-             load_example_metadata=load_example_metadata
+             load_example_metadata=load_example_metadata,
+             convert_unknown=convert_unknown,
+             **kwargs
          )
          assert isinstance(example, cls), example
          yield example
@@ -182,15 +189,23 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
      extra_flags = extra_flags or {}
      num_examples = extra_flags.get('num_examples', None)

-     def _metric_metadata_badge(key, value):
-       if isinstance(value, bool) and bool:
-         text = key
-       else:
-         text = f'{key}:{value}'
-       return pg.views.html.controls.Badge(
-           text,
-           css_classes=[pg.utils.camel_to_snake(key, '-')],
-       )
+     def _metric_label_group(metric_metadata: dict[str, Any] | None):
+       """Renders a label group for metric metadata."""
+       badges = []
+       if metric_metadata:
+         for metric_name, metadata in metric_metadata.items():
+           assert isinstance(metadata, dict), (metric_name, metadata)
+           for k, v in metadata.items():
+             css_class = k
+             if isinstance(v, bool):
+               css_class += '_true' if v else '_false'
+             badge = pg.views.html.controls.Badge(
+                 f'{k}:{v}',
+                 tooltip=f'{metric_name}: {k}',
+                 css_classes=[css_class],
+             )
+             badges.append(badge)
+       return pg.views.html.controls.LabelGroup(badges)

      def _render_header():
        return pg.Html.element(
@@ -229,12 +244,7 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
              extra_flags=dict(as_badge=True)
          ) if self.usage_summary is not None else None,
          # Metric metadata.
-         pg.views.html.controls.LabelGroup(
-             [  # pylint: disable=g-long-ternary
-                 _metric_metadata_badge(k, v)
-                 for k, v in self.metric_metadata.items()
-             ] if self.metric_metadata else []
-         ),
+         _metric_label_group(self.metric_metadata)
        ],
        css_classes=['example-container'],
      )
@@ -305,18 +315,18 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
          color: black;
        }
        /* Badge styles. */
-       .eval-example .badge.match {
+       .eval-example .badge.is_correct_true {
          color: green;
          background-color: #dcefbe;
        }
+       .eval-example .badge.is_correct_false {
+         color: orange;
+         background-color: #ffefc4;
+       }
        .eval-example .badge.error {
          color: red;
          background-color: #fdcccc;
        }
-       .eval-example .badge.mismatch {
-         color: orange;
-         background-color: #ffefc4;
-       }
        .eval-example .badge.score {
          color: blue;
          background-color: #c4dced;

langfun/core/eval/v2/example_test.py
@@ -32,9 +32,9 @@ class ExampleTest(unittest.TestCase):
              name='evaluation', elapse=1.0, error=error
          )
      })
-     self.assertEqual(ex.error, error)
+     self.assertIsNone(ex.error)
      self.assertFalse(ex.is_processed)
-     self.assertTrue(ex.has_error)
+     self.assertFalse(ex.has_error)
      self.assertEqual(ex.elapse, 1.0)

      ex = Example(id=2, output=1)
@@ -94,15 +94,23 @@ class ExampleTest(unittest.TestCase):
      pg.JSONConvertible._TYPE_REGISTRY._type_to_cls_map.pop(
          inputs[0].b.__type_name__
      )
-     v = pg.from_json_str(json_str, auto_dict=True, load_example_metadata=True)
-     v.output.pop('type_name')
-     v.metadata.b.pop('type_name')
+     v = pg.from_json_str(
+         json_str,
+         convert_unknown=True,
+         load_example_metadata=True
+     )
      self.assertEqual(
          v,
          Example(
              id=1,
-             output=pg.Dict(x=1),
-             metadata=dict(b=pg.Dict(x=1, y=2)),
+             output=pg.symbolic.UnknownTypedObject(
+                 inputs[0].a.__type_name__, x=1
+             ),
+             metadata=dict(
+                 b=pg.symbolic.UnknownTypedObject(
+                     inputs[0].b.__type_name__, x=1, y=2
+                 )
+             )
          )
      )
      # Serialize with input.
@@ -116,7 +124,7 @@ class ExampleTest(unittest.TestCase):
          input=pg.Dict(a=1, b=2),
          output=3,
          metadata=dict(sum=3),
-         metric_metadata=dict(match=True),
+         metric_metadata=dict(match=dict(match=True)),
      )
      self.assertNotIn(
          'next',