langfun 0.1.2.dev202511160804__py3-none-any.whl → 0.1.2.dev202511270805__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective registries and is provided for informational purposes only.

Potentially problematic release: this version of langfun might be problematic.

Files changed (41)
  1. langfun/core/__init__.py +1 -0
  2. langfun/core/agentic/__init__.py +4 -1
  3. langfun/core/agentic/action.py +340 -17
  4. langfun/core/agentic/action_test.py +124 -21
  5. langfun/core/eval/base_test.py +5 -5
  6. langfun/core/eval/v2/checkpointing.py +25 -1
  7. langfun/core/eval/v2/checkpointing_test.py +8 -1
  8. langfun/core/eval/v2/eval_test_helper.py +7 -2
  9. langfun/core/eval/v2/evaluation.py +4 -1
  10. langfun/core/eval/v2/example.py +5 -1
  11. langfun/core/eval/v2/example_test.py +13 -5
  12. langfun/core/eval/v2/experiment.py +23 -0
  13. langfun/core/eval/v2/experiment_test.py +19 -0
  14. langfun/core/eval/v2/progress_tracking.py +12 -3
  15. langfun/core/eval/v2/progress_tracking_test.py +3 -1
  16. langfun/core/eval/v2/reporting_test.py +4 -0
  17. langfun/core/eval/v2/runners/__init__.py +4 -0
  18. langfun/core/eval/v2/runners/base.py +40 -21
  19. langfun/core/eval/v2/runners/beam.py +341 -0
  20. langfun/core/eval/v2/runners/beam_test.py +131 -0
  21. langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
  22. langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
  23. langfun/core/eval/v2/runners/debug_test.py +1 -4
  24. langfun/core/eval/v2/runners/parallel_test.py +1 -4
  25. langfun/core/eval/v2/runners/sequential_test.py +1 -4
  26. langfun/core/langfunc_test.py +3 -3
  27. langfun/core/language_model.py +38 -5
  28. langfun/core/language_model_test.py +45 -0
  29. langfun/core/llms/__init__.py +2 -0
  30. langfun/core/llms/gemini.py +41 -8
  31. langfun/core/llms/gemini_test.py +84 -0
  32. langfun/core/llms/google_genai.py +5 -0
  33. langfun/core/llms/vertexai.py +7 -0
  34. langfun/core/modalities/mime.py +2 -0
  35. langfun/core/modalities/mime_test.py +11 -0
  36. langfun/core/structured/schema/__init__.py +1 -0
  37. {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/METADATA +1 -1
  38. {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/RECORD +41 -37
  39. {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/WHEEL +0 -0
  40. {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/licenses/LICENSE +0 -0
  41. {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/top_level.txt +0 -0

langfun/core/eval/v2/runners/base.py

@@ -18,7 +18,7 @@ import concurrent.futures
  import random
  import threading
  import traceback
- from typing import Any, Annotated, Callable, Iterator
+ from typing import Any, Annotated, Callable, Iterator, Literal

  from langfun import core as lf
  from langfun.core.eval.v2 import checkpointing
@@ -50,14 +50,16 @@ class RunnerBase(Runner):
  execution strategies.
  """

- tqdm: Annotated[
- bool,
+ progress_tracker: Annotated[
+ Literal['tqdm', 'html', 'auto', None],
  (
- 'If True, force using tqdm for progress update. Otherwise, determine '
- 'it automatically based on the running environment (console vs. '
- 'notebook)'
+ 'If `tqdm`, force using tqdm for progress update. '
+ 'If `html`, force using html for progress update. '
+ 'If `auto`, determine it automatically based on the running '
+ 'environment (console vs. notebook)'
+ 'If `none`, disable progress update.'
  )
- ] = False
+ ] = 'auto'

  plugins = [
  checkpointing.BulkCheckpointer(),
@@ -73,13 +75,21 @@ class RunnerBase(Runner):
  super()._on_bound()

  # Install the tqdm plugin if needed.
- with pg.notify_on_change(False):
- self.plugins.append(progress_tracking.progress_tracker(self.tqdm))
+ if self.progress_tracker is not None:
+ with pg.notify_on_change(False):
+ self.plugins.append(
+ progress_tracking.progress_tracker(self.progress_tracker)
+ )
+
+ if self.max_background_threads > 0:
+ self._io_pool_lock = threading.Lock()
+ self._io_pool = concurrent.futures.ThreadPoolExecutor(
+ max_workers=self.max_background_threads
+ )
+ else:
+ self._io_pool_lock = None
+ self._io_pool = None

- self._io_pool_lock = threading.Lock()
- self._io_pool = concurrent.futures.ThreadPoolExecutor(
- max_workers=self.max_background_threads
- )

  # TODO(daiyip): render background errors.
  self._background_last_error = None
@@ -91,9 +101,12 @@ class RunnerBase(Runner):
  except Exception as e: # pylint: disable=broad-except
  self._background_last_error = e

- with self._io_pool_lock:
- if self._io_pool is not None:
- self._io_pool.submit(_background_run, *args, **kwargs)
+ if self.max_background_threads > 0:
+ with self._io_pool_lock:
+ if self._io_pool is not None:
+ self._io_pool.submit(_background_run, *args, **kwargs)
+ else:
+ _background_run(*args, **kwargs)

  def _all_plugins(self, experiment: Experiment) -> Iterator[Plugin]:
  """Returns all plugins for the experiment."""
@@ -152,6 +165,7 @@ class RunnerBase(Runner):
  plugin.on_experiment_start(self, experiment)

  if experiment.is_leaf:
+ pg.io.mkdirs(self.current_run.output_dir(experiment))
  experiment.info(
  f'Starting evaluation {experiment.id!r} with '
  f'{num_examples_to_evaluate} examples to evaluate.'
@@ -248,6 +262,8 @@ class RunnerBase(Runner):
  example: Example
  ) -> None:
  """Called when an evaluation example is started."""
+ assert isinstance(experiment, Evaluation), experiment
+ experiment.state.update(example, in_progress=True)
  for plugin in self._all_plugins(experiment):
  plugin.on_example_start(self, experiment, example)
  experiment.info(f'Starting to evaluate example {example.id}.')
@@ -258,6 +274,8 @@ class RunnerBase(Runner):
  example: Example
  ) -> None:
  """Called when an evaluation example is complete."""
+ assert isinstance(experiment, Evaluation), experiment
+ experiment.state.update(example, in_progress=False)
  if example.newly_processed:
  if example.error is None:
  experiment.progress.increment_processed()
@@ -269,7 +287,7 @@ class RunnerBase(Runner):
  experiment.progress.increment_failed()
  experiment.error(
  (
- f'Failed to evaluate example {example.id} in'
+ f'Failed to evaluate example {example.id} in '
  f'{example.elapse:.2f} seconds.'
  ),
  error=example.error
@@ -329,7 +347,7 @@ class RunnerBase(Runner):
  self._run(targets)

  self.on_run_complete()
- except Exception as e: # pylint: disable=broad-except
+ except BaseException as e: # pylint: disable=broad-except
  self.on_run_abort(e)
  raise e
  finally:
@@ -337,9 +355,10 @@ class RunnerBase(Runner):
  self.background_run(cache.save)

  # Wait for the background tasks to finish.
- with self._io_pool_lock:
- self._io_pool, io_pool = None, self._io_pool
- io_pool.shutdown(wait=True)
+ if self.max_background_threads > 0:
+ with self._io_pool_lock:
+ self._io_pool, io_pool = None, self._io_pool
+ io_pool.shutdown(wait=True)

  @abc.abstractmethod
  def _run(self, evaluations: list[Evaluation]) -> None:
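
With this change, the former `tqdm: bool` flag on `RunnerBase` becomes a `progress_tracker` literal ('tqdm', 'html', 'auto', or None, default 'auto'), and the background I/O pool is only created when `max_background_threads > 0`. Below is a minimal usage sketch, assuming an existing `lf.eval.v2` experiment `exp` and assuming that runner fields such as `progress_tracker` are forwarded through `Experiment.run(...)` the same way `ckpt_format` is in the beam tests further down; the 'parallel' runner name and the output directory are illustrative placeholders, not taken from this diff.

# Sketch only; `exp`, the root dir and the 'parallel' runner name are assumptions.
exp.run('/tmp/my_eval', runner='parallel', progress_tracker='tqdm')  # force tqdm bars
exp.run('/tmp/my_eval', runner='parallel', progress_tracker=None)    # disable progress updates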

langfun/core/eval/v2/runners/beam.py (new file)

@@ -0,0 +1,341 @@
+ # Copyright 2025 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Beam-based evaluation runner.
+
+ BeamRunner is a runner that uses Apache Beam to run evaluations in parallel.
+ It is useful for running evaluations with a large number of examples and/or when
+ each example is costly to evaluate and can be parallelized.
+
+ BeamRunner supports plugins as all other runners do, with the following caveats:
+
+ 1. Checkpointer plugins are ignored, as BeamRunner performs its own per-example
+ checkpointing.
+
+ 2. Per-example plugins are executed in the Beam worker to maximize throughput,
+ while all non-per-example plugins are executed in the main process, which
+ collects the results from the workers. Since it might be expensive to
+ deserialize `Example.metadata` for complex evaluations, the main process
+ does not deserialize `Example.metadata` from the workers. If you need to
+ to access `Example.metadata` in your plugin, consider make it a per-example
+ plugin (which only implements `on_example_start` and/or
+ `on_example_complete`)
+
+ To use it, simply create a `lf.eval.Suite` or `lf.eval.Evaluation`
+ and run it with `lf.eval.run(runner='beam')` and passing in an additional
+ `beam_runner` argument.
+ """
+
+ import datetime
+ import hashlib
+ import os
+ import random
+ import time
+ from typing import Annotated, Any, Iterator
+
+ from langfun.core.eval.v2 import checkpointing
+ from langfun.core.eval.v2 import evaluation as evaluation_lib
+ from langfun.core.eval.v2 import example as example_lib
+ from langfun.core.eval.v2.runners import base
+ from langfun.core.eval.v2.runners import ckpt_monitor
+
+ import pyglove as pg
+
+ try:
+ # pylint: disable=g-import-not-at-top
+ # pytype: disable=import-error
+ import apache_beam as beam
+ from apache_beam.options import pipeline_options
+ # pytype: enable=import-error
+ # pylint: enable=g-import-not-at-top
+ except ImportError:
+ beam = None
+ pipeline_options = None
+
+
+ if beam is not None:
+ class _EvaluateFn(beam.DoFn):
+ """Beam DoFn for evaluating examples."""
+
+ def __init__(
+ self,
+ runner_str: str,
+ ckpt_format: str,
+ concurrent_startup_delay: tuple[int, int] | None = None,
+ ):
+ self._runner_str = runner_str
+ self._ckpt_format = ckpt_format
+ self._concurrent_startup_delay = concurrent_startup_delay
+
+ def setup(self):
+ if self._concurrent_startup_delay is not None:
+ time.sleep(random.randint(*self._concurrent_startup_delay))
+ self._runner = pg.from_json_str(self._runner_str)
+ assert isinstance(self._runner, LeafNodeRunner)
+ self._runner.setup()
+ self._output_dir = self._runner.current_run.output_dir(
+ self._runner.current_run.experiment
+ )
+
+ def teardown(self):
+ assert self._runner is not None
+ self._runner.teardown()
+
+ def process(self, example: tuple[int, str]) -> Iterator[str]:
+ """Evaluates an example and writes the checkpoint file.
+
+ Args:
+ example: A tuple of (example_id, example_json).
+
+ Yields:
+ The path to the checkpoint file.
+ """
+ example_id, example_json = example
+ ckpt_file = os.path.join(
+ self._output_dir, f'checkpoint_{example_id}.{self._ckpt_format}'
+ )
+ if pg.io.path_exists(ckpt_file):
+ yield ckpt_file
+
+ # Write the in-progress file to indicate that the example is being
+ # processed.
+ in_progress_file = os.path.join(
+ self._output_dir, f'{example_id}.inprogress'
+ )
+ pg.io.writefile(
+ in_progress_file,
+ datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+ )
+
+ # Process one example.
+ example = self._runner.process(pg.from_json_str(example_json))
+
+ # Perform atomic checkpointing.
+ tmp_ckpt_file = os.path.join(
+ self._output_dir, f'tmp_checkpoint_{example_id}.{self._ckpt_format}'
+ )
+ example_json_str = pg.to_json_str(example)
+ with pg.io.open_sequence(tmp_ckpt_file, 'w') as f:
+ f.add(example_json_str)
+ pg.io.rename(tmp_ckpt_file, ckpt_file)
+
+ # Write the MD5 digest of the example so we know if the example has been
+ # processed multiple times.
+ digest = hashlib.md5(example_json_str.encode()).hexdigest()[:8]
+ pg.io.writefile(
+ os.path.join(self._output_dir, f'{example_id}.{digest}.md5'),
+ digest
+ )
+ yield ckpt_file
+
+ else:
+ _EvaluateFn = None # pylint: disable=invalid-name
+
+
+ class LeafNodeRunner(base.RunnerBase):
+ """A runner that runs in a DoFn worker."""
+
+ NAME = '__beam_leaf_node_runner__'
+ progress_tracker = None
+ max_background_threads = 0
+
+ def _on_bound(self):
+ super()._on_bound()
+ for plugin in self.plugins:
+ if not plugin.is_per_example():
+ raise ValueError(
+ 'Only per-example plugins are supported in LeafNodeRunner. '
+ f'Encountered: {plugin!r}'
+ )
+ if not isinstance(self.current_run.experiment, evaluation_lib.Evaluation):
+ raise ValueError(
+ 'The experiment must be a leaf evaluation in LeafNodeRunner. '
+ f'Encountered: {self.current_run.experiment!r}'
+ )
+
+ def setup(self):
+ self.current_run.experiment.setup()
+
+ def teardown(self):
+ self.current_run.experiment.teardown()
+
+ def process(self, example: example_lib.Example) -> example_lib.Example:
+ """Processes one example."""
+ for plugin in self.plugins:
+ plugin.on_example_start(self, self.current_run.experiment, example)
+ example = self.current_run.experiment.evaluate(example)
+ for plugin in self.plugins:
+ plugin.on_example_complete(self, self.current_run.experiment, example)
+ return example
+
+ def _run(self, evaluations: list[evaluation_lib.Evaluation]) -> None:
+ """Runs the experiment in sequence."""
+ raise NotImplementedError('Not needed in leaf node runner.')
+
+ def _evaluate_items(
+ self,
+ evaluation: evaluation_lib.Evaluation,
+ items: Iterator[example_lib.Example]
+ ) -> None:
+ """Evaluates the items of an evaluation."""
+ raise NotImplementedError('Not needed in leaf node runner.')
+
+
+ class BeamRunner(base.RunnerBase):
+ """Beam runner for Langfun evaluations.
+
+ NOTE: This runner depends on Apache Beam, which needs to be installed
+ separately.
+ """
+
+ NAME = 'beam'
+
+ beam_runner: Annotated[
+ Any | None,
+ 'The beam runner to use. If None, the direct runner will be used.'
+ ] = None
+
+ beam_pipeline_options: Annotated[
+ dict[str, Any],
+ 'Beam pipeline options.'
+ ] = {}
+
+ ckpt_format: Annotated[
+ str,
+ 'The file extension of the checkpoint files.'
+ ] = 'jsonl'
+
+ max_aggregation_threads: Annotated[
+ int,
+ 'The maximum number of threads to aggregate checkpoints.'
+ ] = 128
+
+ concurrent_startup_delay: Annotated[
+ tuple[int, int] | None,
+ (
+ 'A range of seconds to delay the initial evaluation of each thread '
+ 'in the thread pool, helping to prevent a burst in LLM QPS at '
+ 'startup. If set to None, no delay will be applied.'
+ )
+ ] = None
+
+ def _on_bound(self):
+ if beam is None:
+ raise ValueError(
+ 'Apache Beam is not installed. '
+ 'Please run `pip install apache-beam` to install beam.'
+ )
+ if self.current_run.use_cache != 'no':
+ raise ValueError(
+ 'Cache is not supported in BeamRunner. '
+ f'Encountered: {self.current_run.use_cache}'
+ )
+ host_plugins = []
+ per_example_plugins = []
+ for plugin in self.plugins:
+ if isinstance(plugin, checkpointing.Checkpointer):
+ pg.logging.warning(
+ 'Built-in checkpointing is enabled on BeamRunner. '
+ f'Ignoring checkpointer: {plugin!r}.'
+ )
+ elif plugin.is_per_example():
+ per_example_plugins.append(pg.Ref(plugin))
+ else:
+ host_plugins.append(pg.Ref(plugin))
+
+ self.rebind(
+ plugins=host_plugins,
+ skip_notification=True,
+ raise_on_no_change=False
+ )
+ self._per_example_plugins = per_example_plugins
+ super()._on_bound()
+
+ def run(self) -> None:
+ """Run evaluations using Beam."""
+ assert beam is not None
+ assert pipeline_options is not None
+
+ with beam.Pipeline(
+ runner=self.beam_runner or beam.runners.DirectRunner(),
+ options=pipeline_options.PipelineOptions(**self.beam_pipeline_options)
+ ) as pipeline:
+ evaluation_ids = set()
+ for evaluation in self.current_run.experiment.leaf_nodes:
+ if evaluation.id in evaluation_ids or (
+ self.current_run.filter is not None
+ and not self.current_run.filter(evaluation)
+ ):
+ continue
+
+ # There could be suites with duplicate evaluations, but we only want
+ # to run each evaluation once.
+ evaluation_ids.add(evaluation.id)
+
+ example_ids = self.current_run.example_ids
+ if example_ids is None:
+ example_ids = range(1, evaluation.num_examples + 1)
+ inputs = [
+ example_lib.Example(id=i, input=evaluation.example_input_by_id(i))
+ for i in example_ids
+ ]
+ if self.current_run.shuffle_inputs:
+ random.shuffle(inputs)
+
+ leaf_node_runner = LeafNodeRunner(
+ current_run=self.current_run.clone(
+ override=dict(
+ experiment=evaluation,
+ raise_if_has_error=False,
+ )
+ ),
+ plugins=self._per_example_plugins,
+ )
+ _ = (
+ pipeline
+ | f'Input-{evaluation.id}' >> beam.Create(
+ [(x.id, pg.to_json_str(x)) for x in inputs]
+ )
+ | f'Evaluate-{evaluation.id}'
+ >> beam.ParDo(
+ _EvaluateFn(
+ pg.to_json_str(leaf_node_runner),
+ ckpt_format=self.ckpt_format,
+ concurrent_startup_delay=self.concurrent_startup_delay,
+ )
+ )
+ )
+ monitor = ckpt_monitor.CheckpointMonitor(
+ pg.Ref(self.current_run),
+ plugins=pg.Ref(self.plugins),
+ # No need to add progress tracker as it is already added by the
+ # Beam runner.
+ progress_tracker=None,
+ monitor_inprogress_files=True,
+ checkpoint_pattern=f'checkpoint_*.{self.ckpt_format}',
+ max_aggregation_threads=self.max_aggregation_threads,
+ )
+ monitor.start()
+ monitor.join()
+
+ def _run(self, evaluations: list[evaluation_lib.Evaluation]) -> None:
+ """Runs the experiment in sequence."""
+ raise NotImplementedError('Not needed in beam runner.')
+
+ def _evaluate_items(
+ self,
+ evaluation: evaluation_lib.Evaluation,
+ items: Iterator[example_lib.Example]
+ ) -> None:
+ """Evaluates the items of an evaluation."""
+ raise NotImplementedError('Not needed in beam runner.')
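
Per the module docstring above, using the new runner amounts to running an existing experiment with runner='beam'. A minimal sketch follows, mirroring the keyword arguments exercised in beam_test.py below; `exp`, the root directory, and the commented-out options are placeholders.

# Sketch only; `exp` is an existing lf.eval.v2 experiment.
exp.run(
    '/tmp/my_eval',
    runner='beam',
    use_cache='no',        # BeamRunner rejects any other cache setting
    ckpt_format='jsonl',   # file extension for per-example checkpoint files
    # beam_runner=...,               # optional; DirectRunner is used if None
    # beam_pipeline_options={...},   # forwarded to PipelineOptions(**options)
)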

langfun/core/eval/v2/runners/beam_test.py (new file)

@@ -0,0 +1,131 @@
+ # Copyright 2025 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tests for beam runner."""
+
+ import os
+ import tempfile
+ from typing import Any
+ import unittest
+
+ from langfun.core.eval.v2 import checkpointing # pylint: disable=unused-import
+ from langfun.core.eval.v2 import eval_test_helper
+ from langfun.core.eval.v2 import reporting # pylint: disable=unused-import
+ from langfun.core.eval.v2.runners import beam # pylint: disable=unused-import
+ import pyglove as pg
+
+
+ @unittest.skip(
+ 'These tests are flaky due to writing ckpt files with standard IO.'
+ 'We will move to `beam.io` and re-enable these tests later.'
+ )
+ class BeamRunnerTest(unittest.TestCase):
+
+ def assert_same_list(self, actual: list[Any], expected: list[Any]):
+ self.assertEqual(len(actual), len(expected))
+ for i, (x, y) in enumerate(zip(actual, expected)):
+ if x is not y:
+ print(i, pg.diff(x, y))
+ self.assertIs(x, y)
+
+ def setUp(self):
+ super().setUp()
+ self.test_dir = os.path.join(tempfile.mkdtemp(), 'test_dir')
+
+ def test_basic(self):
+ plugin = eval_test_helper.TestPlugin()
+ exp = eval_test_helper.test_experiment()
+ root_dir = os.path.join(self.test_dir, 'test_beam_runner')
+ run = exp.run(
+ root_dir,
+ runner='beam',
+ plugins=[
+ plugin,
+ reporting.ExampleHtmlGenerator(),
+ checkpointing.PerExampleCheckpointer(
+ checkpoint_filename='checkpoint.jsonl'
+ ),
+ ],
+ concurrent_startup_delay=(1, 2),
+ use_cache='no',
+ ckpt_format='jsonl',
+ )
+
+ self.assertIsNotNone(plugin.start_time)
+ self.assertIsNotNone(plugin.complete_time)
+ self.assertGreater(plugin.complete_time, plugin.start_time)
+
+ self.assertEqual(len(plugin.started_experiments), len(exp.nodes))
+ self.assertEqual(len(plugin.completed_experiments), len(exp.nodes))
+ self.assertEqual(len(plugin.started_example_ids), 6 * 10)
+ self.assertEqual(len(plugin.completed_example_ids), 6 * 10)
+ self.assert_same_list(plugin.skipped_experiments, [])
+ self.assertTrue(
+ pg.io.path_exists(os.path.join(run.output_root, 'run.json'))
+ )
+ for node in exp.leaf_nodes:
+ for i in range(node.num_examples):
+ self.assertTrue(
+ pg.io.path_exists(
+ run.output_path_for(node, f'{i + 1}.html')
+ )
+ )
+
+ for node in exp.nodes:
+ if node.is_leaf:
+ self.assertTrue(node.progress.is_started)
+ self.assertTrue(node.progress.is_completed)
+ self.assertEqual(node.progress.num_skipped, 0)
+ self.assertEqual(node.progress.num_completed, 10)
+ self.assertEqual(node.progress.num_failed, 1)
+ else:
+ self.assertEqual(node.progress.num_skipped, 0)
+ self.assertEqual(node.progress.num_failed, 0)
+ self.assertEqual(node.progress.num_processed, node.progress.num_total)
+
+ def test_shuffle_inputs(self):
+ root_dir = os.path.join(self.test_dir, 'test_shuffle_inputs')
+ exp = eval_test_helper.test_experiment()
+ plugin = eval_test_helper.TestPlugin()
+ run = exp.run(
+ root_dir,
+ runner='beam',
+ plugins=[plugin],
+ shuffle_inputs=True,
+ use_cache='no',
+ ckpt_format='jsonl',
+ )
+ self.assertTrue(run.shuffle_inputs)
+
+ def test_beam_runner_does_not_support_cache(self):
+ exp = eval_test_helper.test_experiment()
+ root_dir = os.path.join(self.test_dir, 'test_beam_runner_cache')
+ with self.assertRaisesRegex(ValueError, 'Cache is not supported'):
+ exp.run(
+ root_dir,
+ runner='beam',
+ use_cache='global',
+ )
+
+ def test_no_beam(self):
+ orig_beam = beam.beam
+ beam.beam = None
+ with self.assertRaisesRegex(ValueError, 'Beam is not installed'):
+ exp = eval_test_helper.TestEvaluation()
+ root_dir = os.path.join(self.test_dir, 'test_no_beam')
+ exp.run(root_dir, runner='beam')
+ beam.beam = orig_beam
+
+
+ if __name__ == '__main__':
+ unittest.main()
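
Since the beam.py docstring steers metadata-dependent logic toward per-example plugins (so it runs inside the Beam worker), a hypothetical sketch of one follows. Only the hook signatures (runner, experiment, example), the `is_per_example()` check, `experiment.info` and `example.elapse` are taken from the diff above; the `Plugin` base-class import path, the idea of overriding `is_per_example`, and the class name are assumptions.

# Hypothetical per-example plugin sketch; the base-class path is an assumption.
from langfun.core.eval.v2 import experiment as experiment_lib


class ExampleLatencyLogger(experiment_lib.Plugin):  # assumed base class

  def is_per_example(self) -> bool:
    # Marking the plugin per-example lets BeamRunner route it to the worker
    # (see BeamRunner._on_bound above).
    return True

  def on_example_start(self, runner, experiment, example):
    experiment.info(f'Worker started example {example.id}.')

  def on_example_complete(self, runner, experiment, example):
    # example.metadata is still available here, inside the worker process.
    experiment.info(
        f'Worker finished example {example.id} in {example.elapse:.2f} seconds.'
    )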