langfun 0.1.2.dev202510230805__py3-none-any.whl → 0.1.2.dev202511270805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (155)
  1. langfun/core/__init__.py +2 -0
  2. langfun/core/agentic/__init__.py +4 -1
  3. langfun/core/agentic/action.py +447 -29
  4. langfun/core/agentic/action_eval.py +9 -2
  5. langfun/core/agentic/action_test.py +149 -21
  6. langfun/core/async_support.py +32 -3
  7. langfun/core/coding/python/correction.py +19 -9
  8. langfun/core/coding/python/execution.py +14 -12
  9. langfun/core/coding/python/generation.py +21 -16
  10. langfun/core/coding/python/sandboxing.py +23 -3
  11. langfun/core/component.py +42 -3
  12. langfun/core/concurrent.py +70 -6
  13. langfun/core/concurrent_test.py +1 -0
  14. langfun/core/console.py +1 -1
  15. langfun/core/data/conversion/anthropic.py +12 -3
  16. langfun/core/data/conversion/anthropic_test.py +8 -6
  17. langfun/core/data/conversion/gemini.py +9 -2
  18. langfun/core/data/conversion/gemini_test.py +12 -9
  19. langfun/core/data/conversion/openai.py +145 -31
  20. langfun/core/data/conversion/openai_test.py +161 -17
  21. langfun/core/eval/base.py +47 -43
  22. langfun/core/eval/base_test.py +5 -5
  23. langfun/core/eval/matching.py +5 -2
  24. langfun/core/eval/patching.py +3 -3
  25. langfun/core/eval/scoring.py +4 -3
  26. langfun/core/eval/v2/__init__.py +1 -0
  27. langfun/core/eval/v2/checkpointing.py +64 -6
  28. langfun/core/eval/v2/checkpointing_test.py +9 -2
  29. langfun/core/eval/v2/eval_test_helper.py +103 -2
  30. langfun/core/eval/v2/evaluation.py +91 -16
  31. langfun/core/eval/v2/evaluation_test.py +9 -3
  32. langfun/core/eval/v2/example.py +50 -40
  33. langfun/core/eval/v2/example_test.py +16 -8
  34. langfun/core/eval/v2/experiment.py +74 -8
  35. langfun/core/eval/v2/experiment_test.py +19 -0
  36. langfun/core/eval/v2/metric_values.py +31 -3
  37. langfun/core/eval/v2/metric_values_test.py +32 -0
  38. langfun/core/eval/v2/metrics.py +157 -44
  39. langfun/core/eval/v2/metrics_test.py +39 -18
  40. langfun/core/eval/v2/progress.py +30 -1
  41. langfun/core/eval/v2/progress_test.py +27 -0
  42. langfun/core/eval/v2/progress_tracking.py +12 -3
  43. langfun/core/eval/v2/progress_tracking_test.py +6 -1
  44. langfun/core/eval/v2/reporting.py +90 -71
  45. langfun/core/eval/v2/reporting_test.py +24 -6
  46. langfun/core/eval/v2/runners/__init__.py +30 -0
  47. langfun/core/eval/v2/{runners.py → runners/base.py} +59 -142
  48. langfun/core/eval/v2/runners/beam.py +341 -0
  49. langfun/core/eval/v2/runners/beam_test.py +131 -0
  50. langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
  51. langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
  52. langfun/core/eval/v2/runners/debug.py +40 -0
  53. langfun/core/eval/v2/runners/debug_test.py +76 -0
  54. langfun/core/eval/v2/runners/parallel.py +100 -0
  55. langfun/core/eval/v2/runners/parallel_test.py +95 -0
  56. langfun/core/eval/v2/runners/sequential.py +47 -0
  57. langfun/core/eval/v2/runners/sequential_test.py +172 -0
  58. langfun/core/langfunc.py +45 -130
  59. langfun/core/langfunc_test.py +7 -5
  60. langfun/core/language_model.py +141 -21
  61. langfun/core/language_model_test.py +54 -3
  62. langfun/core/llms/__init__.py +9 -1
  63. langfun/core/llms/anthropic.py +157 -2
  64. langfun/core/llms/azure_openai.py +29 -17
  65. langfun/core/llms/cache/base.py +25 -3
  66. langfun/core/llms/cache/in_memory.py +48 -7
  67. langfun/core/llms/cache/in_memory_test.py +14 -4
  68. langfun/core/llms/compositional.py +25 -1
  69. langfun/core/llms/deepseek.py +30 -2
  70. langfun/core/llms/fake.py +32 -1
  71. langfun/core/llms/gemini.py +55 -17
  72. langfun/core/llms/gemini_test.py +84 -0
  73. langfun/core/llms/google_genai.py +34 -1
  74. langfun/core/llms/groq.py +28 -3
  75. langfun/core/llms/llama_cpp.py +23 -4
  76. langfun/core/llms/openai.py +36 -3
  77. langfun/core/llms/openai_compatible.py +148 -27
  78. langfun/core/llms/openai_compatible_test.py +207 -20
  79. langfun/core/llms/openai_test.py +0 -2
  80. langfun/core/llms/rest.py +12 -1
  81. langfun/core/llms/vertexai.py +58 -8
  82. langfun/core/logging.py +1 -1
  83. langfun/core/mcp/client.py +77 -22
  84. langfun/core/mcp/client_test.py +8 -35
  85. langfun/core/mcp/session.py +94 -29
  86. langfun/core/mcp/session_test.py +54 -0
  87. langfun/core/mcp/tool.py +151 -22
  88. langfun/core/mcp/tool_test.py +197 -0
  89. langfun/core/memory.py +1 -0
  90. langfun/core/message.py +160 -55
  91. langfun/core/message_test.py +65 -81
  92. langfun/core/modalities/__init__.py +8 -0
  93. langfun/core/modalities/audio.py +21 -1
  94. langfun/core/modalities/image.py +19 -1
  95. langfun/core/modalities/mime.py +64 -3
  96. langfun/core/modalities/mime_test.py +11 -0
  97. langfun/core/modalities/pdf.py +19 -1
  98. langfun/core/modalities/video.py +21 -1
  99. langfun/core/modality.py +167 -29
  100. langfun/core/modality_test.py +42 -12
  101. langfun/core/natural_language.py +1 -1
  102. langfun/core/sampling.py +4 -4
  103. langfun/core/sampling_test.py +20 -4
  104. langfun/core/structured/__init__.py +2 -24
  105. langfun/core/structured/completion.py +34 -44
  106. langfun/core/structured/completion_test.py +23 -43
  107. langfun/core/structured/description.py +54 -50
  108. langfun/core/structured/function_generation.py +29 -12
  109. langfun/core/structured/mapping.py +81 -37
  110. langfun/core/structured/parsing.py +95 -79
  111. langfun/core/structured/parsing_test.py +0 -3
  112. langfun/core/structured/querying.py +215 -142
  113. langfun/core/structured/querying_test.py +65 -29
  114. langfun/core/structured/schema/__init__.py +49 -0
  115. langfun/core/structured/schema/base.py +664 -0
  116. langfun/core/structured/schema/base_test.py +531 -0
  117. langfun/core/structured/schema/json.py +174 -0
  118. langfun/core/structured/schema/json_test.py +121 -0
  119. langfun/core/structured/schema/python.py +316 -0
  120. langfun/core/structured/schema/python_test.py +410 -0
  121. langfun/core/structured/schema_generation.py +33 -14
  122. langfun/core/structured/scoring.py +47 -36
  123. langfun/core/structured/tokenization.py +26 -11
  124. langfun/core/subscription.py +2 -2
  125. langfun/core/template.py +174 -49
  126. langfun/core/template_test.py +123 -17
  127. langfun/env/__init__.py +8 -2
  128. langfun/env/base_environment.py +320 -128
  129. langfun/env/base_environment_test.py +473 -0
  130. langfun/env/base_feature.py +92 -15
  131. langfun/env/base_feature_test.py +228 -0
  132. langfun/env/base_sandbox.py +84 -361
  133. langfun/env/base_sandbox_test.py +1235 -0
  134. langfun/env/event_handlers/__init__.py +1 -1
  135. langfun/env/event_handlers/chain.py +233 -0
  136. langfun/env/event_handlers/chain_test.py +253 -0
  137. langfun/env/event_handlers/event_logger.py +95 -98
  138. langfun/env/event_handlers/event_logger_test.py +21 -21
  139. langfun/env/event_handlers/metric_writer.py +225 -140
  140. langfun/env/event_handlers/metric_writer_test.py +23 -6
  141. langfun/env/interface.py +854 -40
  142. langfun/env/interface_test.py +112 -2
  143. langfun/env/load_balancers_test.py +23 -2
  144. langfun/env/test_utils.py +126 -84
  145. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/METADATA +1 -1
  146. langfun-0.1.2.dev202511270805.dist-info/RECORD +215 -0
  147. langfun/core/eval/v2/runners_test.py +0 -343
  148. langfun/core/structured/schema.py +0 -987
  149. langfun/core/structured/schema_test.py +0 -982
  150. langfun/env/base_test.py +0 -1481
  151. langfun/env/event_handlers/base.py +0 -350
  152. langfun-0.1.2.dev202510230805.dist-info/RECORD +0 -195
  153. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/WHEEL +0 -0
  154. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/licenses/LICENSE +0 -0
  155. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/top_level.txt +0 -0
langfun/core/eval/scoring.py

@@ -41,18 +41,19 @@ class Scoring(base.Evaluation):

  @property
  def score_rate(self) -> float:
- """Returns the score rate."""
+ """Returns the rate of scored examples among the completed ones."""
  if self.num_completed == 0:
  return 0.0
  return self.num_scored / self.num_completed

  @property
  def scored_link(self) -> str:
- """Returns the matches page."""
+ """Returns the scored examples page."""
  return self.link(os.path.join(self.dir, Scoring.SCORED_HTML))

  @property
  def avg_score(self) -> float:
+ """Returns the average score of scored examples."""
  if self.num_scored == 0:
  return 0
  return sum([i[2] for i in self._scored]) / self.num_scored
@@ -181,7 +182,7 @@ class Scoring(base.Evaluation):
  super()._render_summary_metrics(s)

  def _render_scored(self, s: io.StringIO) -> None:
- """Formats the matched cases into html."""
+ """Formats the scored cases into html."""
  s.write('<h2> Scored </h2>')
  s.write('<div style="white-space:pre">\n')
  s.write(
langfun/core/eval/v2/__init__.py

@@ -38,6 +38,7 @@ from langfun.core.eval.v2 import runners
  from langfun.core.eval.v2.checkpointing import BulkCheckpointer
  from langfun.core.eval.v2.checkpointing import PerExampleCheckpointer
  from langfun.core.eval.v2.reporting import HtmlReporter
+ from langfun.core.eval.v2.reporting import ExampleHtmlGenerator


  # pylint: enable=g-bad-import-order
langfun/core/eval/v2/checkpointing.py

@@ -13,6 +13,7 @@
  # limitations under the License.
  """Checkpointing evaluation runs."""
  import abc
+ import datetime
  import re
  import threading
  import traceback
@@ -29,12 +30,32 @@ Runner = experiment_lib.Runner


  class Checkpointer(experiment_lib.Plugin):
- """Base class for checkpointing evaluation examples."""
+ """Base class for checkpointing evaluation examples.
+
+ `Checkpointer` is a plugin that saves the state of processed examples
+ incrementally during an experiment run, allowing the experiment to be resumed
+ later. When an experiment starts, the checkpointer loads any previously saved
+ examples from an earlier run (or a warm-start run) into `experiment.state`,
+ so the runner can skip processing them again.
+ Subclasses should implement `_list_checkpoint_filenames` to identify
+ checkpoint files to load, and `_save_example` to save a newly processed
+ example.
+ """

  checkpoint_filename: Annotated[
  str,
  'Checkpoint file pattern.'
- ] = 'checkpoint.bagz'
+ ] = 'checkpoint.jsonl'
+
+ enable_inprogress_file: Annotated[
+ bool,
+ 'If True, write file "<example_id>.inprogress" when example gets started.'
+ ] = True
+
+ max_ckpt_loading_threads: Annotated[
+ int,
+ 'Max number of workers for loading checkpoint files at startup.'
+ ] = 128

  def on_experiment_start(
  self,
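Note: the two new plugin options above (`enable_inprogress_file`, `max_ckpt_loading_threads`) are set at construction time, just like `checkpoint_filename`. A minimal sketch of configuring a checkpointer with them (the values are illustrative; the constructor pattern follows the updated `checkpointing_test.py` further down in this diff):

```python
# Illustrative sketch only: the field names come from the diff above; the
# concrete values and variable names are made up.
from langfun.core.eval.v2 import checkpointing

checkpointer = checkpointing.PerExampleCheckpointer(
    'checkpoint.jsonl',            # new default checkpoint file name
    enable_inprogress_file=True,   # also write '<example_id>.inprogress' markers
    max_ckpt_loading_threads=64,   # workers for loading checkpoints at startup
)
# Attached to a run the same way the updated test does:
#   experiment.run(root_dir, 'new', runner='sequential', plugins=[checkpointer])
```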
@@ -75,6 +96,24 @@ class Checkpointer(experiment_lib.Plugin):
  f'scratch. Example IDs: {example_ids_to_evaluate}.'
  )

+ def on_example_start(
+ self,
+ runner: Runner,
+ experiment: Experiment,
+ example: Example,
+ ) -> None:
+ """Saves the example to the checkpoint file."""
+ if self.enable_inprogress_file:
+ def _save_inprogress_file(example: Example):
+ inprogress_file = runner.current_run.output_path_for(
+ experiment, f'{example.id}.inprogress'
+ )
+ pg.io.writefile(
+ inprogress_file,
+ datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+ )
+ runner.background_run(_save_inprogress_file, example)
+
  def on_example_complete(
  self,
  runner: Runner,
@@ -149,7 +188,10 @@ class Checkpointer(experiment_lib.Plugin):

  _ = list(
  lf.concurrent_map(
- _load_state, ckpt_files, max_workers=16, silence_on_errors=None
+ _load_state,
+ ckpt_files,
+ max_workers=self.max_ckpt_loading_threads,
+ silence_on_errors=None
  )
  )

@@ -170,7 +212,12 @@ class Checkpointer(experiment_lib.Plugin):


  class PerExampleCheckpointer(Checkpointer):
- """Checkpointer that saves each example to a separate file."""
+ """Checkpointer that saves each example to a separate file.
+
+ This checkpointer saves each processed example to its own checkpoint file,
+ named using the pattern `<checkpoint_filename_prefix>_<example_id>.<ext>`.
+ For example, `checkpoint_1.bagz`, `checkpoint_2.bagz`, etc.
+ """

  def _on_bound(self):
  super()._on_bound()
@@ -235,7 +282,13 @@ class PerExampleCheckpointer(Checkpointer):


  class BulkCheckpointer(Checkpointer):
- """Checkpointer that saves all examples to a single file."""
+ """Checkpointer that saves all examples of an evaluation to a single file.
+
+ This checkpointer appends newly processed examples of an evaluation to a
+ single sequence file (e.g., `checkpoint.bagz`). This is often more efficient
+ than `PerExampleCheckpointer` when dealing with a large number of examples
+ or when file system overhead is a concern.
+ """

  def _on_bound(self):
  super()._on_bound()
@@ -341,7 +394,12 @@ class BulkCheckpointer(Checkpointer):


  class SequenceWriter:
- """Thread safe sequence writer."""
+ """A thread-safe writer for sequence files (e.g., Bagz).
+
+ `SequenceWriter` wraps a `pg.io.SequenceWriter` to provide thread-safe
+ `add` and `close` operations, ensuring that examples can be written
+ concurrently from multiple threads without corrupting the sequence file.
+ """

  def __init__(self, path: str):
  self._lock = threading.Lock()
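Note: the rewritten `SequenceWriter` docstring names the two thread-safe operations, `add` and `close`. A rough usage sketch (the record type accepted by `add` is not shown in this diff, so the payload below is a stand-in):

```python
# Sketch under assumptions: `add`/`close` are the operations named in the
# docstring above; the payload passed to `add` is a placeholder.
from langfun.core.eval.v2 import checkpointing

record = {'id': 1, 'output': 6}  # placeholder for a processed example record
writer = checkpointing.SequenceWriter('/tmp/eval_run/checkpoint.bagz')
try:
  writer.add(record)   # safe to call from multiple worker threads
finally:
  writer.close()       # thread-safe close of the underlying sequence writer
```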
langfun/core/eval/v2/checkpointing_test.py

@@ -65,7 +65,7 @@ class ExampleCollector(experiment_lib.Plugin):
  return self._examples

  def on_example_complete(
- self, runner: runners_lib.Runner,
+ self, runner: experiment_lib.Runner,
  experiment: experiment_lib.Experiment,
  example: example_lib.Example,
  ):
@@ -90,7 +90,10 @@ class PerExampleCheckpointerTest(CheckpointerTest):
  root_dir = os.path.join(tempfile.mkdtemp(), 'per_example_checkpointer')
  experiment = eval_test_helper.test_experiment()
  checkpoint_filename = 'checkpoint.jsonl'
- checkpointer = checkpointing.PerExampleCheckpointer(checkpoint_filename)
+ checkpointer = checkpointing.PerExampleCheckpointer(
+ checkpoint_filename,
+ enable_inprogress_file=True
+ )
  collector = ExampleCollector()
  run = experiment.run(
  root_dir, 'new', runner='sequential', plugins=[checkpointer, collector]
@@ -102,6 +105,10 @@ class PerExampleCheckpointerTest(CheckpointerTest):
  example = collector.examples[i + 1]
  ckpt = run.output_path_for(leaf, f'checkpoint_{example.id}.jsonl')
  self.assertTrue(pg.io.path_exists(ckpt))
+ inprogress_file = run.output_path_for(
+ leaf, f'{example.id}.inprogress'
+ )
+ self.assertTrue(pg.io.path_exists(inprogress_file))
  with pg.io.open_sequence(ckpt) as f:
  examples_from_ckpt = list(iter(f))
  # `eval_test_helper.test_experiment` has two TestEvaluation with
langfun/core/eval/v2/eval_test_helper.py

@@ -13,6 +13,9 @@
  # limitations under the License.
  """Helper classes and functions for evaluation tests."""

+ import threading
+ import time
+
  from langfun.core import language_model
  from langfun.core import llms
  from langfun.core import message as message_lib
@@ -47,6 +50,8 @@ class TestLLM(llms.Fake):

  offset: int = 0

+ __test__ = False
+
  def _response_from(self, prompt: message_lib.Message) -> message_lib.Message:
  return message_lib.AIMessage(
  str(prompt.metadata.x + prompt.metadata.y + self.offset)
@@ -63,6 +68,8 @@ class TestEvaluation(Evaluation):
  metrics = [metrics_lib.Match()]
  lm: language_model.LanguageModel = TestLLM()

+ __test__ = False
+
  def process(self, example):
  v = example.input
  if v.x == 5:
@@ -84,6 +91,8 @@ class TestEvaluationWithExampleCheckpointingError(TestEvaluation):
  inputs = test_inputs()
  metrics = [metrics_lib.Match()]

+ __test__ = False
+
  def process(self, example):
  return 1, dict(
  x=BadJsonConvertible()
@@ -101,6 +110,8 @@ class TestEvaluationWithExampleHtmlGenerationError(Evaluation):
  inputs = test_inputs()
  metrics = [metrics_lib.Match()]

+ __test__ = False
+
  def process(self, example):
  return 1, dict(
  x=BadHtmlConvertible()
@@ -110,15 +121,22 @@ class TestEvaluationWithExampleHtmlGenerationError(Evaluation):
  class TestEvaluationWithIndexHtmlGenerationError(TestEvaluation):
  """Test evaluation class with bad index HTML generation."""

+ __test__ = False
+
  def _html_tree_view(self, *args, **kwargs):
  raise ValueError('Cannot render HTML.')


+ def test_evaluation(offset: int | pg.hyper.OneOf = 0):
+ """Returns a test evaluation."""
+ return TestEvaluation(lm=TestLLM(offset=offset))
+
+
  def test_experiment():
  """Returns a test experiment."""
  return Suite([
- TestEvaluation(lm=TestLLM(offset=0)),
- TestEvaluation(lm=TestLLM(offset=pg.oneof(range(5)))),
+ test_evaluation(),
+ test_evaluation(pg.oneof(range(5))),
  ])


@@ -135,3 +153,86 @@ def test_experiment_with_example_html_generation_error():
  def test_experiment_with_index_html_generation_error():
  """Returns a test experiment with bad index HTML."""
  return TestEvaluationWithIndexHtmlGenerationError()
+
+
+ class TestPlugin(experiment_lib.Plugin):
+ """Plugin for testing."""
+
+ started_experiments: list[experiment_lib.Experiment] = []
+ completed_experiments: list[experiment_lib.Experiment] = []
+ skipped_experiments: list[experiment_lib.Experiment] = []
+ started_example_ids: list[int] = []
+ completed_example_ids: list[int] = []
+ start_time: float | None = None
+ complete_time: float | None = None
+
+ __test__ = False
+
+ def _on_bound(self):
+ super()._on_bound()
+ self._lock = threading.Lock()
+
+ def on_run_start(
+ self,
+ runner: experiment_lib.Runner,
+ root: experiment_lib.Experiment
+ ) -> None:
+ del root
+ with pg.notify_on_change(False), pg.allow_writable_accessors(True):
+ self.start_time = time.time()
+
+ def on_run_complete(
+ self,
+ runner: experiment_lib.Runner,
+ root: experiment_lib.Experiment
+ ) -> None:
+ del root
+ with pg.notify_on_change(False), pg.allow_writable_accessors(True):
+ self.complete_time = time.time()
+
+ def on_experiment_start(
+ self,
+ runner: experiment_lib.Runner,
+ experiment: experiment_lib.Experiment
+ ) -> None:
+ del runner
+ with pg.notify_on_change(False), self._lock:
+ self.started_experiments.append(pg.Ref(experiment))
+
+ def on_experiment_skipped(
+ self,
+ runner: experiment_lib.Runner,
+ experiment: experiment_lib.Experiment
+ ) -> None:
+ del runner
+ with pg.notify_on_change(False), self._lock:
+ self.skipped_experiments.append(pg.Ref(experiment))
+
+ def on_experiment_complete(
+ self,
+ runner: experiment_lib.Runner,
+ experiment: experiment_lib.Experiment
+ ) -> None:
+ del runner
+ with pg.notify_on_change(False), self._lock:
+ self.completed_experiments.append(pg.Ref(experiment))
+
+ def on_example_start(
+ self,
+ runner: experiment_lib.Runner,
+ experiment: experiment_lib.Experiment,
+ example: Example
+ ) -> None:
+ del runner, experiment
+ with pg.notify_on_change(False), self._lock:
+ self.started_example_ids.append(example.id)
+
+ def on_example_complete(
+ self,
+ runner: experiment_lib.Runner,
+ experiment: experiment_lib.Experiment,
+ example: Example
+ ) -> None:
+ del runner, experiment
+ with pg.notify_on_change(False), self._lock:
+ self.completed_example_ids.append(example.id)
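Note: `TestPlugin` above exercises the plugin callbacks exposed by the runner (`on_run_start`, `on_experiment_start`, `on_example_complete`, and so on). A user-defined plugin only needs to override the hooks it cares about; the sketch below mirrors the `on_example_complete` signature from the diff, while the class name and bookkeeping are illustrative:

```python
# Hypothetical plugin; the hook name and signature mirror TestPlugin above.
import threading

from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import experiment as experiment_lib


class CompletedIdTracker(experiment_lib.Plugin):
  """Collects the IDs of completed examples across worker threads."""

  def _on_bound(self):
    super()._on_bound()
    self._lock = threading.Lock()
    self._completed_ids: list[int] = []

  def on_example_complete(
      self,
      runner: experiment_lib.Runner,
      experiment: experiment_lib.Experiment,
      example: example_lib.Example,
  ) -> None:
    del runner, experiment
    with self._lock:
      self._completed_ids.append(example.id)
```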
langfun/core/eval/v2/evaluation.py

@@ -32,17 +32,63 @@ import pyglove as pg


  class Evaluation(experiment_lib.Experiment):
- """Evaluation.
-
- An evaluation can be a leaf node or a container of other evaluations,
- depending on whether the current evaluation object is configured with
- any `pg.oneof`.
-
- For example, `MyEval(lm=pg.oneof([lf.llms.Gpt4(), lf.llms.Gemini1_5Pro()]))`
- is a container of two sub-experiments, one for each LLM. In such case, the
- evaluation object with `pg.oneof` is called a hyper evaluation, which
- represents a search space of evaluations, and each sub-evaluation is called
- a leaf evaluation, which will perform the actual evaluation.
+ """Base class for Langfun evaluations.
+
+ `lf.eval.Evaluation` is the base class for defining evaluation tasks in
+ Langfun. Users typically subclass it to implement custom evaluation logic by
+ overriding `inputs` and `process` methods.
+
+ An `Evaluation` object encapsulates:
+
+ * **`inputs`**: A callable that returns an iterable of input examples to be
+ processed. This is usually provided by implementing an `inputs(self)`
+ method in the subclass, which yields input items for evaluation one by
+ one.
+ * **`process(self, example)`**: An abstract method that processes one
+ example and returns the output, or a tuple of (output, metadata).
+ The output will be used for computing metrics.
+ * **`metrics`**: A list of metrics (e.g., `lf.metrics.Accuracy`) to compute
+ based on the outputs from `process`. Some metrics may require users to
+ implement a `ground_truth(self, example)` method in the subclass to
+ compute metrics against ground truth.
+ * **Hyperparameters**: Any other attributes of the class serve as
+ hyperparameters for the evaluation (e.g., the language model to use).
+
+ **Running Evaluations:**
+
+ Evaluations are executed via `lf.eval.Suite` or by calling the `.run()`
+ method on an `Evaluation` instance, which returns a `Run` object
+ containing the evaluation run information and results. If an evaluation
+ contains sweeable parameters (using `pg.oneof`), `.run()` will expand it
+ into multiple evaluation sub-tasks -- one for each combination of
+ hyperparameters -- all managed within the same `Run`.
+
+ **Example:**
+
+ ```python
+ import langfun as lf
+ import pyglove as pg
+
+ class MyEval(lf.eval.Evaluation):
+ lm: lf.LanguageModel
+ prompt: str = '1 + 1 = '
+
+ def inputs(self):
+ yield 2
+
+ def process(self, example: lf.eval.Example):
+ return int(lf.query(self.prompt, lm=self.lm))
+
+ def ground_truth(self, example: lf.eval.Example) -> int:
+ return example.input
+
+ # Run evaluation using two different LMs
+ evaluation = MyEval(
+ lm=pg.oneof([lf.llms.Gpt4(), lf.llms.Gemini()]),
+ metrics=[lf.metrics.Accuracy()]
+ )
+ run_info = evaluation.run()
+ ```
  """

  inputs: Annotated[
@@ -126,6 +172,20 @@ class Evaluation(experiment_lib.Experiment):
  # Evaluation logics.
  #

+ def setup(self) -> None:
+ """Sets up resources required by the evaluation.
+
+ Subclasses should always call the super().setup() method to ensure the
+ proper initialization of the evaluation.
+ """
+
+ def teardown(self) -> None:
+ """Tears down resources used by the evaluation.
+
+ Subclasses should always call the super().teardown() method to ensure the
+ proper cleanup of the evaluation.
+ """
+
  @abc.abstractmethod
  def process(
  self,
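Note: the new `setup`/`teardown` hooks give subclasses a place to acquire and release per-evaluation resources, chaining to the base class as the docstrings above require. A minimal sketch (only the hook names and the `super()` calls come from the diff; the managed resource here, a temporary directory, is made up):

```python
# Illustrative subclass; setup/teardown signatures follow the diff above,
# the temporary working directory is an assumed example resource.
import shutil
import tempfile

import langfun as lf


class MyEvalWithResources(lf.eval.Evaluation):

  def setup(self) -> None:
    super().setup()                     # always chain to the base class
    self._workdir = tempfile.mkdtemp()  # acquire per-run resources here

  def teardown(self) -> None:
    shutil.rmtree(self._workdir, ignore_errors=True)
    super().teardown()                  # always chain to the base class

  def inputs(self):
    yield 2

  def process(self, example):
    return example.input
```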
@@ -137,7 +197,7 @@ class Evaluation(experiment_lib.Experiment):

  Args:
  example: An example object to process. `example.input` is an object
- returned from `Evaluable.inputs`.
+ yielded from `inputs()` method.

  Returns:
  A processed output. Or a tuple of (output, metadata).
@@ -150,6 +210,7 @@ class Evaluation(experiment_lib.Experiment):
  example: example_lib.Example | int,
  raise_if_has_error: bool = False,
  reevaluate_upon_previous_errors: bool = True,
+ force_recompute_metrics: bool = False
  ) -> example_lib.Example:
  """Evaluates a single example input.

@@ -158,6 +219,8 @@
  raise_if_has_error: Whether to raise an error if the example has error.
  reevaluate_upon_previous_errors: Whether to reevaluate the example if
  the previous checkpointed run has error.
+ force_recompute_metrics: If True, force recompute the metrics even if
+ metric metadata is already present from previous checkpoint.

  Returns:
  The evaluated example with the output and metric metadata populated.
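Note: the new `force_recompute_metrics` flag pairs with the checkpoint-restore changes in the hunks that follow, letting metrics be recomputed for examples whose metric metadata was restored from a checkpoint. A hedged usage sketch (the helper and call pattern come from `eval_test_helper.py` and `evaluation_test.py` elsewhere in this diff; the example ID is arbitrary):

```python
# Sketch only: test_evaluation() and the evaluate() keyword arguments are
# taken from other hunks in this diff.
from langfun.core.eval.v2 import eval_test_helper

exp = eval_test_helper.test_evaluation()

# Regular evaluation of example 1.
example = exp.evaluate(1)

# Recompute metrics even if metric metadata was already restored from a
# previous checkpoint.
example = exp.evaluate(1, force_recompute_metrics=True)
```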
@@ -206,6 +269,7 @@
  # Use the output and metadata obtained from the previous processing.
  example.output = checkpointed.output
  example.metadata = checkpointed.metadata
+ example.metric_metadata = checkpointed.metric_metadata
  example.error = checkpointed.error
  example.newly_processed = False
  example.execution_status = checkpointed.execution_status
@@ -225,8 +289,16 @@
  self.info(f'Starting metric computation for example {example.id}.')
  metric_metadata = {}
  for metric in self.metrics:
- metric_metadata.update(metric.audit(example))
- example.metric_metadata = metric_metadata
+ metric_metadata[metric.name] = metric.update(
+ example, force_recompute=force_recompute_metrics
+ )
+
+ if example.metric_metadata is None:
+ example.metric_metadata = metric_metadata
+ else:
+ # Accumulate the metric metadata as there might be existing metadata
+ # from previous metric computation runs.
+ example.metric_metadata.update(metric_metadata)
  self.info(f'Completed metric computation for example {example.id}.')

  # For previously processed examples, we keep the execution status for the
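Note: with this change, `example.metric_metadata` is keyed by metric name instead of holding flat flags; the updated assertions in `evaluation_test.py` at the end of this diff show the new shape. A small illustration (both shapes are taken from the removed and added test assertions):

```python
# Shapes inferred from the evaluation_test.py assertions in this diff.
old_metric_metadata = dict(match=True)                    # pre-change: flat keys
new_metric_metadata = dict(match=dict(is_correct=True))   # post-change: keyed by metric name
assert new_metric_metadata['match']['is_correct']
```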
@@ -760,7 +832,7 @@ class Evaluation(experiment_lib.Experiment):


  class EvaluationState:
- """Evaluation state."""
+ """In-memory state of an evaluation."""

  class ExampleStatus(pg.Object):
  """Example state."""
@@ -808,8 +880,9 @@ class EvaluationState:
  load_example_metadata: bool | Callable[
  [example_lib.Example], bool] = True,
  filter: Callable[[example_lib.Example], bool] | None = None, # pylint: disable=redefined-builtin
- ) -> None:
+ ) -> list[example_lib.Example]:
  """Loads the state from the example sequence file."""
+ examples = []
  for example in example_lib.Example.iter_ckpts(
  state_file,
  example_input_by_id=example_input_by_id,
@@ -819,6 +892,8 @@ class EvaluationState:
  continue
  example.newly_processed = False
  self._ckpt_examples[example.id] = example
+ examples.append(example)
+ return examples

  @property
  def evaluation_status(self) -> dict[int, ExampleStatus]:
langfun/core/eval/v2/evaluation_test.py

@@ -88,7 +88,7 @@ class EvaluationTest(unittest.TestCase):
  self.assertEqual(example.output, 6)
  self.assertIsNone(example.error)
  self.assertEqual(example.metadata, {})
- self.assertEqual(example.metric_metadata, dict(match=True))
+ self.assertEqual(example.metric_metadata, dict(match=dict(is_correct=True)))
  self.assertIsNotNone(example.usage_summary)
  self.assertGreater(example.usage_summary.total.total_tokens, 0)
  self.assertEqual(example.usage_summary.total.num_requests, 1)
@@ -103,7 +103,10 @@ class EvaluationTest(unittest.TestCase):
  self.assertEqual(example.output, 7)
  self.assertIsNone(example.error)
  self.assertEqual(example.metadata, {})
- self.assertEqual(example.metric_metadata, dict(mismatch=True))
+ self.assertEqual(
+ example.metric_metadata,
+ dict(match=dict(is_correct=False))
+ )

  with self.assertRaisesRegex(ValueError, 'x should not be 5'):
  _ = exp.evaluate(6, raise_if_has_error=True)
@@ -113,7 +116,10 @@
  self.assertEqual(pg.MISSING_VALUE, example.output)
  self.assertEqual(example.error.tag, 'ValueError')
  self.assertEqual(example.metadata, {})
- self.assertEqual(example.metric_metadata, dict(error='ValueError'))
+ self.assertEqual(
+ example.metric_metadata,
+ dict(match=dict(error='ValueError'))
+ )

  def test_evaluate_withstate(self):
  eval_dir = os.path.join(tempfile.mkdtemp(), 'test_eval')