langfun 0.1.2.dev202510230805__py3-none-any.whl → 0.1.2.dev202511270805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (155)
  1. langfun/core/__init__.py +2 -0
  2. langfun/core/agentic/__init__.py +4 -1
  3. langfun/core/agentic/action.py +447 -29
  4. langfun/core/agentic/action_eval.py +9 -2
  5. langfun/core/agentic/action_test.py +149 -21
  6. langfun/core/async_support.py +32 -3
  7. langfun/core/coding/python/correction.py +19 -9
  8. langfun/core/coding/python/execution.py +14 -12
  9. langfun/core/coding/python/generation.py +21 -16
  10. langfun/core/coding/python/sandboxing.py +23 -3
  11. langfun/core/component.py +42 -3
  12. langfun/core/concurrent.py +70 -6
  13. langfun/core/concurrent_test.py +1 -0
  14. langfun/core/console.py +1 -1
  15. langfun/core/data/conversion/anthropic.py +12 -3
  16. langfun/core/data/conversion/anthropic_test.py +8 -6
  17. langfun/core/data/conversion/gemini.py +9 -2
  18. langfun/core/data/conversion/gemini_test.py +12 -9
  19. langfun/core/data/conversion/openai.py +145 -31
  20. langfun/core/data/conversion/openai_test.py +161 -17
  21. langfun/core/eval/base.py +47 -43
  22. langfun/core/eval/base_test.py +5 -5
  23. langfun/core/eval/matching.py +5 -2
  24. langfun/core/eval/patching.py +3 -3
  25. langfun/core/eval/scoring.py +4 -3
  26. langfun/core/eval/v2/__init__.py +1 -0
  27. langfun/core/eval/v2/checkpointing.py +64 -6
  28. langfun/core/eval/v2/checkpointing_test.py +9 -2
  29. langfun/core/eval/v2/eval_test_helper.py +103 -2
  30. langfun/core/eval/v2/evaluation.py +91 -16
  31. langfun/core/eval/v2/evaluation_test.py +9 -3
  32. langfun/core/eval/v2/example.py +50 -40
  33. langfun/core/eval/v2/example_test.py +16 -8
  34. langfun/core/eval/v2/experiment.py +74 -8
  35. langfun/core/eval/v2/experiment_test.py +19 -0
  36. langfun/core/eval/v2/metric_values.py +31 -3
  37. langfun/core/eval/v2/metric_values_test.py +32 -0
  38. langfun/core/eval/v2/metrics.py +157 -44
  39. langfun/core/eval/v2/metrics_test.py +39 -18
  40. langfun/core/eval/v2/progress.py +30 -1
  41. langfun/core/eval/v2/progress_test.py +27 -0
  42. langfun/core/eval/v2/progress_tracking.py +12 -3
  43. langfun/core/eval/v2/progress_tracking_test.py +6 -1
  44. langfun/core/eval/v2/reporting.py +90 -71
  45. langfun/core/eval/v2/reporting_test.py +24 -6
  46. langfun/core/eval/v2/runners/__init__.py +30 -0
  47. langfun/core/eval/v2/{runners.py → runners/base.py} +59 -142
  48. langfun/core/eval/v2/runners/beam.py +341 -0
  49. langfun/core/eval/v2/runners/beam_test.py +131 -0
  50. langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
  51. langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
  52. langfun/core/eval/v2/runners/debug.py +40 -0
  53. langfun/core/eval/v2/runners/debug_test.py +76 -0
  54. langfun/core/eval/v2/runners/parallel.py +100 -0
  55. langfun/core/eval/v2/runners/parallel_test.py +95 -0
  56. langfun/core/eval/v2/runners/sequential.py +47 -0
  57. langfun/core/eval/v2/runners/sequential_test.py +172 -0
  58. langfun/core/langfunc.py +45 -130
  59. langfun/core/langfunc_test.py +7 -5
  60. langfun/core/language_model.py +141 -21
  61. langfun/core/language_model_test.py +54 -3
  62. langfun/core/llms/__init__.py +9 -1
  63. langfun/core/llms/anthropic.py +157 -2
  64. langfun/core/llms/azure_openai.py +29 -17
  65. langfun/core/llms/cache/base.py +25 -3
  66. langfun/core/llms/cache/in_memory.py +48 -7
  67. langfun/core/llms/cache/in_memory_test.py +14 -4
  68. langfun/core/llms/compositional.py +25 -1
  69. langfun/core/llms/deepseek.py +30 -2
  70. langfun/core/llms/fake.py +32 -1
  71. langfun/core/llms/gemini.py +55 -17
  72. langfun/core/llms/gemini_test.py +84 -0
  73. langfun/core/llms/google_genai.py +34 -1
  74. langfun/core/llms/groq.py +28 -3
  75. langfun/core/llms/llama_cpp.py +23 -4
  76. langfun/core/llms/openai.py +36 -3
  77. langfun/core/llms/openai_compatible.py +148 -27
  78. langfun/core/llms/openai_compatible_test.py +207 -20
  79. langfun/core/llms/openai_test.py +0 -2
  80. langfun/core/llms/rest.py +12 -1
  81. langfun/core/llms/vertexai.py +58 -8
  82. langfun/core/logging.py +1 -1
  83. langfun/core/mcp/client.py +77 -22
  84. langfun/core/mcp/client_test.py +8 -35
  85. langfun/core/mcp/session.py +94 -29
  86. langfun/core/mcp/session_test.py +54 -0
  87. langfun/core/mcp/tool.py +151 -22
  88. langfun/core/mcp/tool_test.py +197 -0
  89. langfun/core/memory.py +1 -0
  90. langfun/core/message.py +160 -55
  91. langfun/core/message_test.py +65 -81
  92. langfun/core/modalities/__init__.py +8 -0
  93. langfun/core/modalities/audio.py +21 -1
  94. langfun/core/modalities/image.py +19 -1
  95. langfun/core/modalities/mime.py +64 -3
  96. langfun/core/modalities/mime_test.py +11 -0
  97. langfun/core/modalities/pdf.py +19 -1
  98. langfun/core/modalities/video.py +21 -1
  99. langfun/core/modality.py +167 -29
  100. langfun/core/modality_test.py +42 -12
  101. langfun/core/natural_language.py +1 -1
  102. langfun/core/sampling.py +4 -4
  103. langfun/core/sampling_test.py +20 -4
  104. langfun/core/structured/__init__.py +2 -24
  105. langfun/core/structured/completion.py +34 -44
  106. langfun/core/structured/completion_test.py +23 -43
  107. langfun/core/structured/description.py +54 -50
  108. langfun/core/structured/function_generation.py +29 -12
  109. langfun/core/structured/mapping.py +81 -37
  110. langfun/core/structured/parsing.py +95 -79
  111. langfun/core/structured/parsing_test.py +0 -3
  112. langfun/core/structured/querying.py +215 -142
  113. langfun/core/structured/querying_test.py +65 -29
  114. langfun/core/structured/schema/__init__.py +49 -0
  115. langfun/core/structured/schema/base.py +664 -0
  116. langfun/core/structured/schema/base_test.py +531 -0
  117. langfun/core/structured/schema/json.py +174 -0
  118. langfun/core/structured/schema/json_test.py +121 -0
  119. langfun/core/structured/schema/python.py +316 -0
  120. langfun/core/structured/schema/python_test.py +410 -0
  121. langfun/core/structured/schema_generation.py +33 -14
  122. langfun/core/structured/scoring.py +47 -36
  123. langfun/core/structured/tokenization.py +26 -11
  124. langfun/core/subscription.py +2 -2
  125. langfun/core/template.py +174 -49
  126. langfun/core/template_test.py +123 -17
  127. langfun/env/__init__.py +8 -2
  128. langfun/env/base_environment.py +320 -128
  129. langfun/env/base_environment_test.py +473 -0
  130. langfun/env/base_feature.py +92 -15
  131. langfun/env/base_feature_test.py +228 -0
  132. langfun/env/base_sandbox.py +84 -361
  133. langfun/env/base_sandbox_test.py +1235 -0
  134. langfun/env/event_handlers/__init__.py +1 -1
  135. langfun/env/event_handlers/chain.py +233 -0
  136. langfun/env/event_handlers/chain_test.py +253 -0
  137. langfun/env/event_handlers/event_logger.py +95 -98
  138. langfun/env/event_handlers/event_logger_test.py +21 -21
  139. langfun/env/event_handlers/metric_writer.py +225 -140
  140. langfun/env/event_handlers/metric_writer_test.py +23 -6
  141. langfun/env/interface.py +854 -40
  142. langfun/env/interface_test.py +112 -2
  143. langfun/env/load_balancers_test.py +23 -2
  144. langfun/env/test_utils.py +126 -84
  145. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/METADATA +1 -1
  146. langfun-0.1.2.dev202511270805.dist-info/RECORD +215 -0
  147. langfun/core/eval/v2/runners_test.py +0 -343
  148. langfun/core/structured/schema.py +0 -987
  149. langfun/core/structured/schema_test.py +0 -982
  150. langfun/env/base_test.py +0 -1481
  151. langfun/env/event_handlers/base.py +0 -350
  152. langfun-0.1.2.dev202510230805.dist-info/RECORD +0 -195
  153. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/WHEEL +0 -0
  154. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/licenses/LICENSE +0 -0
  155. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/top_level.txt +0 -0
langfun/core/eval/v2/runners/beam.py
@@ -0,0 +1,341 @@
+ # Copyright 2025 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Beam-based evaluation runner.
+
+ BeamRunner is a runner that uses Apache Beam to run evaluations in parallel.
+ It is useful for running evaluations with a large number of examples and/or
+ when each example is costly to evaluate and can be parallelized.
+
+ BeamRunner supports plugins as all other runners do, with the following caveats:
+
+ 1. Checkpointer plugins are ignored, as BeamRunner performs its own per-example
+    checkpointing.
+
+ 2. Per-example plugins are executed in the Beam worker to maximize throughput,
+    while all non-per-example plugins are executed in the main process, which
+    collects the results from the workers. Since it might be expensive to
+    deserialize `Example.metadata` for complex evaluations, the main process
+    does not deserialize `Example.metadata` from the workers. If you need to
+    access `Example.metadata` in your plugin, consider making it a per-example
+    plugin (which only implements `on_example_start` and/or
+    `on_example_complete`).
+
+ To use it, simply create a `lf.eval.Suite` or `lf.eval.Evaluation`
+ and run it with `lf.eval.run(runner='beam')`, optionally passing in an
+ additional `beam_runner` argument.
+ """
+
+ import datetime
+ import hashlib
+ import os
+ import random
+ import time
+ from typing import Annotated, Any, Iterator
+
+ from langfun.core.eval.v2 import checkpointing
+ from langfun.core.eval.v2 import evaluation as evaluation_lib
+ from langfun.core.eval.v2 import example as example_lib
+ from langfun.core.eval.v2.runners import base
+ from langfun.core.eval.v2.runners import ckpt_monitor
+
+ import pyglove as pg
+
+ try:
+   # pylint: disable=g-import-not-at-top
+   # pytype: disable=import-error
+   import apache_beam as beam
+   from apache_beam.options import pipeline_options
+   # pytype: enable=import-error
+   # pylint: enable=g-import-not-at-top
+ except ImportError:
+   beam = None
+   pipeline_options = None
+
+
+ if beam is not None:
+   class _EvaluateFn(beam.DoFn):
+     """Beam DoFn for evaluating examples."""
+
+     def __init__(
+         self,
+         runner_str: str,
+         ckpt_format: str,
+         concurrent_startup_delay: tuple[int, int] | None = None,
+     ):
+       self._runner_str = runner_str
+       self._ckpt_format = ckpt_format
+       self._concurrent_startup_delay = concurrent_startup_delay
+
+     def setup(self):
+       if self._concurrent_startup_delay is not None:
+         time.sleep(random.randint(*self._concurrent_startup_delay))
+       self._runner = pg.from_json_str(self._runner_str)
+       assert isinstance(self._runner, LeafNodeRunner)
+       self._runner.setup()
+       self._output_dir = self._runner.current_run.output_dir(
+           self._runner.current_run.experiment
+       )
+
+     def teardown(self):
+       assert self._runner is not None
+       self._runner.teardown()
+
+     def process(self, example: tuple[int, str]) -> Iterator[str]:
+       """Evaluates an example and writes the checkpoint file.
+
+       Args:
+         example: A tuple of (example_id, example_json).
+
+       Yields:
+         The path to the checkpoint file.
+       """
+       example_id, example_json = example
+       ckpt_file = os.path.join(
+           self._output_dir, f'checkpoint_{example_id}.{self._ckpt_format}'
+       )
+       if pg.io.path_exists(ckpt_file):
+         yield ckpt_file
+
+       # Write the in-progress file to indicate that the example is being
+       # processed.
+       in_progress_file = os.path.join(
+           self._output_dir, f'{example_id}.inprogress'
+       )
+       pg.io.writefile(
+           in_progress_file,
+           datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+       )
+
+       # Process one example.
+       example = self._runner.process(pg.from_json_str(example_json))
+
+       # Perform atomic checkpointing.
+       tmp_ckpt_file = os.path.join(
+           self._output_dir, f'tmp_checkpoint_{example_id}.{self._ckpt_format}'
+       )
+       example_json_str = pg.to_json_str(example)
+       with pg.io.open_sequence(tmp_ckpt_file, 'w') as f:
+         f.add(example_json_str)
+       pg.io.rename(tmp_ckpt_file, ckpt_file)
+
+       # Write the MD5 digest of the example so we know if the example has been
+       # processed multiple times.
+       digest = hashlib.md5(example_json_str.encode()).hexdigest()[:8]
+       pg.io.writefile(
+           os.path.join(self._output_dir, f'{example_id}.{digest}.md5'),
+           digest
+       )
+       yield ckpt_file
+
+ else:
+   _EvaluateFn = None  # pylint: disable=invalid-name
+
+
+ class LeafNodeRunner(base.RunnerBase):
+   """A runner that runs in a DoFn worker."""
+
+   NAME = '__beam_leaf_node_runner__'
+   progress_tracker = None
+   max_background_threads = 0
+
+   def _on_bound(self):
+     super()._on_bound()
+     for plugin in self.plugins:
+       if not plugin.is_per_example():
+         raise ValueError(
+             'Only per-example plugins are supported in LeafNodeRunner. '
+             f'Encountered: {plugin!r}'
+         )
+     if not isinstance(self.current_run.experiment, evaluation_lib.Evaluation):
+       raise ValueError(
+           'The experiment must be a leaf evaluation in LeafNodeRunner. '
+           f'Encountered: {self.current_run.experiment!r}'
+       )
+
+   def setup(self):
+     self.current_run.experiment.setup()
+
+   def teardown(self):
+     self.current_run.experiment.teardown()
+
+   def process(self, example: example_lib.Example) -> example_lib.Example:
+     """Processes one example."""
+     for plugin in self.plugins:
+       plugin.on_example_start(self, self.current_run.experiment, example)
+     example = self.current_run.experiment.evaluate(example)
+     for plugin in self.plugins:
+       plugin.on_example_complete(self, self.current_run.experiment, example)
+     return example
+
+   def _run(self, evaluations: list[evaluation_lib.Evaluation]) -> None:
+     """Runs the experiment in sequence."""
+     raise NotImplementedError('Not needed in leaf node runner.')
+
+   def _evaluate_items(
+       self,
+       evaluation: evaluation_lib.Evaluation,
+       items: Iterator[example_lib.Example]
+   ) -> None:
+     """Evaluates the items of an evaluation."""
+     raise NotImplementedError('Not needed in leaf node runner.')
+
+
+ class BeamRunner(base.RunnerBase):
+   """Beam runner for Langfun evaluations.
+
+   NOTE: This runner depends on Apache Beam, which needs to be installed
+   separately.
+   """
+
+   NAME = 'beam'
+
+   beam_runner: Annotated[
+       Any | None,
+       'The beam runner to use. If None, the direct runner will be used.'
+   ] = None
+
+   beam_pipeline_options: Annotated[
+       dict[str, Any],
+       'Beam pipeline options.'
+   ] = {}
+
+   ckpt_format: Annotated[
+       str,
+       'The file extension of the checkpoint files.'
+   ] = 'jsonl'
+
+   max_aggregation_threads: Annotated[
+       int,
+       'The maximum number of threads to aggregate checkpoints.'
+   ] = 128
+
+   concurrent_startup_delay: Annotated[
+       tuple[int, int] | None,
+       (
+           'A range of seconds to delay the initial evaluation of each thread '
+           'in the thread pool, helping to prevent a burst in LLM QPS at '
+           'startup. If set to None, no delay will be applied.'
+       )
+   ] = None
+
+   def _on_bound(self):
+     if beam is None:
+       raise ValueError(
+           'Apache Beam is not installed. '
+           'Please run `pip install apache-beam` to install beam.'
+       )
+     if self.current_run.use_cache != 'no':
+       raise ValueError(
+           'Cache is not supported in BeamRunner. '
+           f'Encountered: {self.current_run.use_cache}'
+       )
+     host_plugins = []
+     per_example_plugins = []
+     for plugin in self.plugins:
+       if isinstance(plugin, checkpointing.Checkpointer):
+         pg.logging.warning(
+             'Built-in checkpointing is enabled on BeamRunner. '
+             f'Ignoring checkpointer: {plugin!r}.'
+         )
+       elif plugin.is_per_example():
+         per_example_plugins.append(pg.Ref(plugin))
+       else:
+         host_plugins.append(pg.Ref(plugin))
+
+     self.rebind(
+         plugins=host_plugins,
+         skip_notification=True,
+         raise_on_no_change=False
+     )
+     self._per_example_plugins = per_example_plugins
+     super()._on_bound()
+
+   def run(self) -> None:
+     """Run evaluations using Beam."""
+     assert beam is not None
+     assert pipeline_options is not None
+
+     with beam.Pipeline(
+         runner=self.beam_runner or beam.runners.DirectRunner(),
+         options=pipeline_options.PipelineOptions(**self.beam_pipeline_options)
+     ) as pipeline:
+       evaluation_ids = set()
+       for evaluation in self.current_run.experiment.leaf_nodes:
+         if evaluation.id in evaluation_ids or (
+             self.current_run.filter is not None
+             and not self.current_run.filter(evaluation)
+         ):
+           continue
+
+         # There could be suites with duplicate evaluations, but we only want
+         # to run each evaluation once.
+         evaluation_ids.add(evaluation.id)
+
+         example_ids = self.current_run.example_ids
+         if example_ids is None:
+           example_ids = range(1, evaluation.num_examples + 1)
+         inputs = [
+             example_lib.Example(id=i, input=evaluation.example_input_by_id(i))
+             for i in example_ids
+         ]
+         if self.current_run.shuffle_inputs:
+           random.shuffle(inputs)
+
+         leaf_node_runner = LeafNodeRunner(
+             current_run=self.current_run.clone(
+                 override=dict(
+                     experiment=evaluation,
+                     raise_if_has_error=False,
+                 )
+             ),
+             plugins=self._per_example_plugins,
+         )
+         _ = (
+             pipeline
+             | f'Input-{evaluation.id}' >> beam.Create(
+                 [(x.id, pg.to_json_str(x)) for x in inputs]
+             )
+             | f'Evaluate-{evaluation.id}'
+             >> beam.ParDo(
+                 _EvaluateFn(
+                     pg.to_json_str(leaf_node_runner),
+                     ckpt_format=self.ckpt_format,
+                     concurrent_startup_delay=self.concurrent_startup_delay,
+                 )
+             )
+         )
+     monitor = ckpt_monitor.CheckpointMonitor(
+         pg.Ref(self.current_run),
+         plugins=pg.Ref(self.plugins),
+         # No need to add progress tracker as it is already added by the
+         # Beam runner.
+         progress_tracker=None,
+         monitor_inprogress_files=True,
+         checkpoint_pattern=f'checkpoint_*.{self.ckpt_format}',
+         max_aggregation_threads=self.max_aggregation_threads,
+     )
+     monitor.start()
+     monitor.join()
+
+   def _run(self, evaluations: list[evaluation_lib.Evaluation]) -> None:
+     """Runs the experiment in sequence."""
+     raise NotImplementedError('Not needed in beam runner.')
+
+   def _evaluate_items(
+       self,
+       evaluation: evaluation_lib.Evaluation,
+       items: Iterator[example_lib.Example]
+   ) -> None:
+     """Evaluates the items of an evaluation."""
+     raise NotImplementedError('Not needed in beam runner.')
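
The new runner is driven through the standard `Experiment.run` entry point, as exercised in the new beam_test.py shown next. Below is a minimal usage sketch based on that test and the module docstring above; `MyEvaluation` and the output directory are hypothetical placeholders, and `beam_runner`/`beam_pipeline_options` may be omitted, in which case Beam's direct runner and default pipeline options are used:

    # Sketch only: mirrors the arguments used in beam_test.py below.
    exp = MyEvaluation()        # hypothetical lf.eval.v2 Evaluation (or a Suite of them)
    run = exp.run(
        '/tmp/my_eval',         # hypothetical root output directory
        runner='beam',          # select the new BeamRunner
        use_cache='no',         # required: BeamRunner raises on any other cache setting
        ckpt_format='jsonl',    # extension of the per-example checkpoint files
        # beam_runner=...,                # optional: a non-default Beam runner
        # beam_pipeline_options={...},    # optional: forwarded to PipelineOptions
    )
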
langfun/core/eval/v2/runners/beam_test.py
@@ -0,0 +1,131 @@
+ # Copyright 2025 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tests for beam runner."""
+
+ import os
+ import tempfile
+ from typing import Any
+ import unittest
+
+ from langfun.core.eval.v2 import checkpointing  # pylint: disable=unused-import
+ from langfun.core.eval.v2 import eval_test_helper
+ from langfun.core.eval.v2 import reporting  # pylint: disable=unused-import
+ from langfun.core.eval.v2.runners import beam  # pylint: disable=unused-import
+ import pyglove as pg
+
+
+ @unittest.skip(
+     'These tests are flaky due to writing ckpt files with standard IO. '
+     'We will move to `beam.io` and re-enable these tests later.'
+ )
+ class BeamRunnerTest(unittest.TestCase):
+
+   def assert_same_list(self, actual: list[Any], expected: list[Any]):
+     self.assertEqual(len(actual), len(expected))
+     for i, (x, y) in enumerate(zip(actual, expected)):
+       if x is not y:
+         print(i, pg.diff(x, y))
+       self.assertIs(x, y)
+
+   def setUp(self):
+     super().setUp()
+     self.test_dir = os.path.join(tempfile.mkdtemp(), 'test_dir')
+
+   def test_basic(self):
+     plugin = eval_test_helper.TestPlugin()
+     exp = eval_test_helper.test_experiment()
+     root_dir = os.path.join(self.test_dir, 'test_beam_runner')
+     run = exp.run(
+         root_dir,
+         runner='beam',
+         plugins=[
+             plugin,
+             reporting.ExampleHtmlGenerator(),
+             checkpointing.PerExampleCheckpointer(
+                 checkpoint_filename='checkpoint.jsonl'
+             ),
+         ],
+         concurrent_startup_delay=(1, 2),
+         use_cache='no',
+         ckpt_format='jsonl',
+     )
+
+     self.assertIsNotNone(plugin.start_time)
+     self.assertIsNotNone(plugin.complete_time)
+     self.assertGreater(plugin.complete_time, plugin.start_time)
+
+     self.assertEqual(len(plugin.started_experiments), len(exp.nodes))
+     self.assertEqual(len(plugin.completed_experiments), len(exp.nodes))
+     self.assertEqual(len(plugin.started_example_ids), 6 * 10)
+     self.assertEqual(len(plugin.completed_example_ids), 6 * 10)
+     self.assert_same_list(plugin.skipped_experiments, [])
+     self.assertTrue(
+         pg.io.path_exists(os.path.join(run.output_root, 'run.json'))
+     )
+     for node in exp.leaf_nodes:
+       for i in range(node.num_examples):
+         self.assertTrue(
+             pg.io.path_exists(
+                 run.output_path_for(node, f'{i + 1}.html')
+             )
+         )
+
+     for node in exp.nodes:
+       if node.is_leaf:
+         self.assertTrue(node.progress.is_started)
+         self.assertTrue(node.progress.is_completed)
+         self.assertEqual(node.progress.num_skipped, 0)
+         self.assertEqual(node.progress.num_completed, 10)
+         self.assertEqual(node.progress.num_failed, 1)
+       else:
+         self.assertEqual(node.progress.num_skipped, 0)
+         self.assertEqual(node.progress.num_failed, 0)
+         self.assertEqual(node.progress.num_processed, node.progress.num_total)
+
+   def test_shuffle_inputs(self):
+     root_dir = os.path.join(self.test_dir, 'test_shuffle_inputs')
+     exp = eval_test_helper.test_experiment()
+     plugin = eval_test_helper.TestPlugin()
+     run = exp.run(
+         root_dir,
+         runner='beam',
+         plugins=[plugin],
+         shuffle_inputs=True,
+         use_cache='no',
+         ckpt_format='jsonl',
+     )
+     self.assertTrue(run.shuffle_inputs)
+
+   def test_beam_runner_does_not_support_cache(self):
+     exp = eval_test_helper.test_experiment()
+     root_dir = os.path.join(self.test_dir, 'test_beam_runner_cache')
+     with self.assertRaisesRegex(ValueError, 'Cache is not supported'):
+       exp.run(
+           root_dir,
+           runner='beam',
+           use_cache='global',
+       )
+
+   def test_no_beam(self):
+     orig_beam = beam.beam
+     beam.beam = None
+     with self.assertRaisesRegex(ValueError, 'Beam is not installed'):
+       exp = eval_test_helper.TestEvaluation()
+       root_dir = os.path.join(self.test_dir, 'test_no_beam')
+       exp.run(root_dir, runner='beam')
+     beam.beam = orig_beam
+
+
+ if __name__ == '__main__':
+   unittest.main()