langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512150805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +3 -0
  29. langfun/core/eval/v2/checkpointing.py +148 -46
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +102 -19
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +95 -20
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +88 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +73 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +350 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +213 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +14 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +90 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +52 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +78 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +78 -4
  104. langfun/core/modalities/mime_test.py +59 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512150805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/top_level.txt +0 -0
langfun/core/eval/v2/runners/beam.py
@@ -0,0 +1,354 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Beam-based evaluation runner.
15
+
16
+ BeamRunner is a runner that uses Apache Beam to run evaluations in parallel.
17
+ It is useful for evaluations with a large number of examples, or when each
18
+ example is costly to evaluate and can be parallelized.
19
+
20
+ BeamRunner supports plugins as all other runners do, with the following caveats:
21
+
22
+ 1. Checkpointer plugins are ignored, as BeamRunner performs its own per-example
23
+ checkpointing.
24
+
25
+ 2. Per-example plugins are executed in the Beam workers to maximize throughput,
26
+ while all non-per-example plugins are executed in the main process, which
27
+ collects the results from the workers. Since it might be expensive to
28
+ deserialize `Example.metadata` for complex evaluations, the main process
29
+ does not deserialize `Example.metadata` from the workers. If you need to
30
+ access `Example.metadata` in your plugin, consider making it a per-example
31
+ plugin (which only implements `on_example_start` and/or
32
+ `on_example_complete`).
33
+
34
+ To use it, simply create a `lf.eval.Suite` or `lf.eval.Evaluation`
35
+ and run it with `lf.eval.run(runner='beam')`, optionally passing in an
37
+ additional `beam_runner` argument.
37
+ """
38
+
39
+ import datetime
40
+ import hashlib
41
+ import os
42
+ import random
43
+ import time
44
+ from typing import Annotated, Any, Iterator
45
+
46
+ from langfun.core.eval.v2 import checkpointing
47
+ from langfun.core.eval.v2 import evaluation as evaluation_lib
48
+ from langfun.core.eval.v2 import example as example_lib
49
+ from langfun.core.eval.v2.runners import base
50
+ from langfun.core.eval.v2.runners import ckpt_monitor
51
+
52
+ import pyglove as pg
53
+
54
+ try:
55
+ # pylint: disable=g-import-not-at-top
56
+ # pytype: disable=import-error
57
+ import apache_beam as beam
58
+ from apache_beam.options import pipeline_options
59
+ # pytype: enable=import-error
60
+ # pylint: enable=g-import-not-at-top
61
+ except ImportError:
62
+ beam = None
63
+ pipeline_options = None
64
+
65
+
66
+ if beam is not None:
67
+ class _EvaluateFn(beam.DoFn):
68
+ """Beam DoFn for evaluating examples."""
69
+
70
+ def __init__(
71
+ self,
72
+ runner_str: str,
73
+ ckpt_format: str,
74
+ concurrent_startup_delay: tuple[int, int] | None = None,
75
+ ):
76
+ self._runner_str = runner_str
77
+ self._ckpt_format = ckpt_format
78
+ self._concurrent_startup_delay = concurrent_startup_delay
79
+
80
+ def setup(self):
81
+ if self._concurrent_startup_delay is not None:
82
+ time.sleep(random.randint(*self._concurrent_startup_delay))
83
+ self._runner = pg.from_json_str(self._runner_str)
84
+ assert isinstance(self._runner, LeafNodeRunner)
85
+ self._runner.setup()
86
+ self._input_dir = self._runner.current_run.input_dir(
87
+ self._runner.current_run.experiment
88
+ )
89
+ self._output_dir = self._runner.current_run.output_dir(
90
+ self._runner.current_run.experiment
91
+ )
92
+
93
+ def teardown(self):
94
+ assert self._runner is not None
95
+ self._runner.teardown()
96
+
97
+ def process(self, example: tuple[int, str]) -> Iterator[str]:
98
+ """Evaluates an example and writes the checkpoint file.
99
+
100
+ Args:
101
+ example: A tuple of (example_id, example_json).
102
+
103
+ Yields:
104
+ The path to the checkpoint file.
105
+ """
106
+ example_id, example_json = example
107
+ ckpt_file = os.path.join(
108
+ self._output_dir, f'checkpoint_{example_id}.{self._ckpt_format}'
109
+ )
110
+ if pg.io.path_exists(ckpt_file):
111
+ yield ckpt_file
112
+ return
113
+
114
+ if self._input_dir != self._output_dir:
115
+ warmup_ckpt_file = os.path.join(
116
+ self._input_dir, f'checkpoint_{example_id}.{self._ckpt_format}'
117
+ )
118
+ if pg.io.path_exists(warmup_ckpt_file):
119
+ pg.io.copy(warmup_ckpt_file, ckpt_file)
120
+ yield ckpt_file
121
+ return
122
+
123
+ # Write the in-progress file to indicate that the example is being
124
+ # processed.
125
+ in_progress_file = os.path.join(
126
+ self._output_dir, f'{example_id}.inprogress'
127
+ )
128
+ pg.io.writefile(
129
+ in_progress_file,
130
+ datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
131
+ )
132
+
133
+ # Process one example.
134
+ example = self._runner.process(pg.from_json_str(example_json))
135
+
136
+ # Perform atomic checkpointing.
137
+ tmp_ckpt_file = os.path.join(
138
+ self._output_dir, f'tmp_checkpoint_{example_id}.{self._ckpt_format}'
139
+ )
140
+ example_json_str = pg.to_json_str(example)
141
+ with pg.io.open_sequence(tmp_ckpt_file, 'w') as f:
142
+ f.add(example_json_str)
143
+ pg.io.rename(tmp_ckpt_file, ckpt_file)
144
+
145
+ # Write the MD5 digest of the example so we know if the example has been
146
+ # processed multiple times.
147
+ digest = hashlib.md5(example_json_str.encode()).hexdigest()[:8]
148
+ pg.io.writefile(
149
+ os.path.join(self._output_dir, f'{example_id}.{digest}.md5'),
150
+ digest
151
+ )
152
+ yield ckpt_file
153
+
154
+ else:
155
+ _EvaluateFn = None # pylint: disable=invalid-name
156
+
157
+
158
+ class LeafNodeRunner(base.RunnerBase):
159
+ """A runner that runs in a DoFn worker."""
160
+
161
+ NAME = '__beam_leaf_node_runner__'
162
+ progress_tracker = None
163
+ max_background_threads = 0
164
+
165
+ def _on_bound(self):
166
+ super()._on_bound()
167
+ for plugin in self.plugins:
168
+ if not plugin.is_per_example():
169
+ raise ValueError(
170
+ 'Only per-example plugins are supported in LeafNodeRunner. '
171
+ f'Encountered: {plugin!r}'
172
+ )
173
+ if not isinstance(self.current_run.experiment, evaluation_lib.Evaluation):
174
+ raise ValueError(
175
+ 'The experiment must be a leaf evaluation in LeafNodeRunner. '
176
+ f'Encountered: {self.current_run.experiment!r}'
177
+ )
178
+
179
+ def setup(self):
180
+ self.current_run.experiment.setup()
181
+
182
+ def teardown(self):
183
+ self.current_run.experiment.teardown()
184
+
185
+ def process(self, example: example_lib.Example) -> example_lib.Example:
186
+ """Processes one example."""
187
+ for plugin in self.plugins:
188
+ plugin.on_example_start(self, self.current_run.experiment, example)
189
+ example = self.current_run.experiment.evaluate(example)
190
+ for plugin in self.plugins:
191
+ plugin.on_example_complete(self, self.current_run.experiment, example)
192
+ return example
193
+
194
+ def _run(self, evaluations: list[evaluation_lib.Evaluation]) -> None:
195
+ """Runs the experiment in sequence."""
196
+ raise NotImplementedError('Not needed in leaf node runner.')
197
+
198
+ def _evaluate_items(
199
+ self,
200
+ evaluation: evaluation_lib.Evaluation,
201
+ items: Iterator[example_lib.Example]
202
+ ) -> None:
203
+ """Evaluates the items of an evaluation."""
204
+ raise NotImplementedError('Not needed in leaf node runner.')
205
+
206
+
207
+ class BeamRunner(base.RunnerBase):
208
+ """Beam runner for Langfun evaluations.
209
+
210
+ NOTE: This runner depends on Apache Beam, which needs to be installed
211
+ separately.
212
+ """
213
+
214
+ NAME = 'beam'
215
+
216
+ beam_runner: Annotated[
217
+ Any | None,
218
+ 'The beam runner to use. If None, the direct runner will be used.'
219
+ ] = None
220
+
221
+ beam_pipeline_options: Annotated[
222
+ dict[str, Any],
223
+ 'Beam pipeline options.'
224
+ ] = {}
225
+
226
+ ckpt_format: Annotated[
227
+ str,
228
+ 'The file extension of the checkpoint files.'
229
+ ] = 'jsonl'
230
+
231
+ max_aggregation_threads: Annotated[
232
+ int,
233
+ 'The maximum number of threads to aggregate checkpoints.'
234
+ ] = 128
235
+
236
+ concurrent_startup_delay: Annotated[
237
+ tuple[int, int] | None,
238
+ (
239
+ 'A range of seconds to delay the initial evaluation of each thread '
240
+ 'in the thread pool, helping to prevent a burst in LLM QPS at '
241
+ 'startup. If set to None, no delay will be applied.'
242
+ )
243
+ ] = None
244
+
245
+ def _on_bound(self):
246
+ if beam is None:
247
+ raise ValueError(
248
+ 'Apache Beam is not installed. '
249
+ 'Please run `pip install apache-beam` to install beam.'
250
+ )
251
+ if self.current_run.use_cache != 'no':
252
+ raise ValueError(
253
+ 'Cache is not supported in BeamRunner. '
254
+ f'Encountered: {self.current_run.use_cache}'
255
+ )
256
+ host_plugins = []
257
+ per_example_plugins = []
258
+ for plugin in self.plugins:
259
+ if isinstance(plugin, checkpointing.Checkpointer):
260
+ pg.logging.warning(
261
+ 'Built-in checkpointing is enabled on BeamRunner. '
262
+ f'Ignoring checkpointer: {plugin!r}.'
263
+ )
264
+ elif plugin.is_per_example():
265
+ per_example_plugins.append(pg.Ref(plugin))
266
+ else:
267
+ host_plugins.append(pg.Ref(plugin))
268
+
269
+ self.rebind(
270
+ plugins=host_plugins,
271
+ skip_notification=True,
272
+ raise_on_no_change=False
273
+ )
274
+ self._per_example_plugins = per_example_plugins
275
+ super()._on_bound()
276
+
277
+ def run(self) -> None:
278
+ """Run evaluations using Beam."""
279
+ assert beam is not None
280
+ assert pipeline_options is not None
281
+
282
+ with beam.Pipeline(
283
+ runner=self.beam_runner or beam.runners.DirectRunner(),
284
+ options=pipeline_options.PipelineOptions(**self.beam_pipeline_options)
285
+ ) as pipeline:
286
+ evaluation_ids = set()
287
+ for evaluation in self.current_run.experiment.leaf_nodes:
288
+ if evaluation.id in evaluation_ids or (
289
+ self.current_run.filter is not None
290
+ and not self.current_run.filter(evaluation)
291
+ ):
292
+ continue
293
+
294
+ # There could be suites with duplicate evaluations, but we only want
295
+ # to run each evaluation once.
296
+ evaluation_ids.add(evaluation.id)
297
+
298
+ example_ids = self.current_run.example_ids
299
+ if example_ids is None:
300
+ example_ids = range(1, evaluation.num_examples + 1)
301
+ inputs = [
302
+ example_lib.Example(id=i, input=evaluation.example_input_by_id(i))
303
+ for i in example_ids
304
+ ]
305
+ if self.current_run.shuffle_inputs:
306
+ random.shuffle(inputs)
307
+
308
+ leaf_node_runner = LeafNodeRunner(
309
+ current_run=self.current_run.clone(
310
+ override=dict(
311
+ experiment=evaluation,
312
+ raise_if_has_error=False,
313
+ )
314
+ ),
315
+ plugins=self._per_example_plugins,
316
+ )
317
+ _ = (
318
+ pipeline
319
+ | f'Input-{evaluation.id}' >> beam.Create(
320
+ [(x.id, pg.to_json_str(x)) for x in inputs]
321
+ )
322
+ | f'Evaluate-{evaluation.id}'
323
+ >> beam.ParDo(
324
+ _EvaluateFn(
325
+ pg.to_json_str(leaf_node_runner),
326
+ ckpt_format=self.ckpt_format,
327
+ concurrent_startup_delay=self.concurrent_startup_delay,
328
+ )
329
+ )
330
+ )
331
+ monitor = ckpt_monitor.CheckpointMonitor(
332
+ pg.Ref(self.current_run),
333
+ plugins=pg.Ref(self.plugins),
334
+ # No need to add progress tracker as it is already added by the
335
+ # Beam runner.
336
+ progress_tracker=None,
337
+ monitor_inprogress_files=True,
338
+ checkpoint_pattern=f'checkpoint_*.{self.ckpt_format}',
339
+ max_aggregation_threads=self.max_aggregation_threads,
340
+ )
341
+ monitor.start()
342
+ monitor.join()
343
+
344
+ def _run(self, evaluations: list[evaluation_lib.Evaluation]) -> None:
345
+ """Runs the experiment in sequence."""
346
+ raise NotImplementedError('Not needed in beam runner.')
347
+
348
+ def _evaluate_items(
349
+ self,
350
+ evaluation: evaluation_lib.Evaluation,
351
+ items: Iterator[example_lib.Example]
352
+ ) -> None:
353
+ """Evaluates the items of an evaluation."""
354
+ raise NotImplementedError('Not needed in beam runner.')
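For orientation, here is a minimal usage sketch for the new runner, assembled from the module docstring, the `BeamRunner` fields above, and the test below. The experiment helper and root directory are placeholders; only the keyword arguments shown in this diff (`runner`, `use_cache`, `ckpt_format`, `beam_runner`, `concurrent_startup_delay`) are taken from the source.

    # Hypothetical usage sketch; the experiment and root dir are placeholders.
    import apache_beam as beam
    from langfun.core.eval.v2 import eval_test_helper

    my_experiment = eval_test_helper.test_experiment()  # any lf.eval.v2 experiment
    run = my_experiment.run(
        '/tmp/beam_eval_root',                    # output root directory (placeholder)
        runner='beam',                            # selects BeamRunner via NAME = 'beam'
        use_cache='no',                           # BeamRunner rejects any other cache mode
        ckpt_format='jsonl',                      # extension of per-example checkpoint files
        beam_runner=beam.runners.DirectRunner(),  # None also falls back to DirectRunner
        concurrent_startup_delay=(1, 2),          # optional per-worker startup jitter (seconds)
    )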
langfun/core/eval/v2/runners/beam_test.py
@@ -0,0 +1,153 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for beam runner."""
15
+
16
+ import os
17
+ import tempfile
18
+ from typing import Any
19
+ import unittest
20
+
21
+ from langfun.core.eval.v2 import checkpointing # pylint: disable=unused-import
22
+ from langfun.core.eval.v2 import eval_test_helper
23
+ from langfun.core.eval.v2 import reporting # pylint: disable=unused-import
24
+ from langfun.core.eval.v2.runners import beam # pylint: disable=unused-import
25
+ import pyglove as pg
26
+
27
+
28
+ @unittest.skip(
29
+ 'These tests are flaky due to writing ckpt files with standard IO.'
30
+ 'We will move to `beam.io` and re-enable these tests later.'
31
+ )
32
+ class BeamRunnerTest(unittest.TestCase):
33
+
34
+ def assert_same_list(self, actual: list[Any], expected: list[Any]):
35
+ self.assertEqual(len(actual), len(expected))
36
+ for i, (x, y) in enumerate(zip(actual, expected)):
37
+ if x is not y:
38
+ print(i, pg.diff(x, y))
39
+ self.assertIs(x, y)
40
+
41
+ def setUp(self):
42
+ super().setUp()
43
+ self.test_dir = os.path.join(tempfile.mkdtemp(), 'test_dir')
44
+
45
+ def test_basic(self):
46
+ plugin = eval_test_helper.TestPlugin()
47
+ exp = eval_test_helper.test_experiment()
48
+ root_dir = os.path.join(self.test_dir, 'test_beam_runner')
49
+ run = exp.run(
50
+ root_dir,
51
+ runner='beam',
52
+ plugins=[
53
+ plugin,
54
+ reporting.ExampleHtmlGenerator(),
55
+ checkpointing.PerExampleCheckpointer(
56
+ checkpoint_filename='checkpoint.jsonl'
57
+ ),
58
+ ],
59
+ concurrent_startup_delay=(1, 2),
60
+ use_cache='no',
61
+ ckpt_format='jsonl',
62
+ )
63
+
64
+ self.assertIsNotNone(plugin.start_time)
65
+ self.assertIsNotNone(plugin.complete_time)
66
+ self.assertGreater(plugin.complete_time, plugin.start_time)
67
+
68
+ self.assertEqual(len(plugin.started_experiments), len(exp.nodes))
69
+ self.assertEqual(len(plugin.completed_experiments), len(exp.nodes))
70
+ self.assertEqual(len(plugin.started_example_ids), 6 * 10)
71
+ self.assertEqual(len(plugin.completed_example_ids), 6 * 10)
72
+ self.assert_same_list(plugin.skipped_experiments, [])
73
+
74
+ for node in exp.leaf_nodes:
75
+ for i in range(node.num_examples):
76
+ self.assertTrue(
77
+ pg.io.path_exists(
78
+ run.output_path_for(node, f'{i + 1}.html')
79
+ )
80
+ )
81
+
82
+ for node in exp.nodes:
83
+ if node.is_leaf:
84
+ self.assertTrue(node.progress.is_started)
85
+ self.assertTrue(node.progress.is_completed)
86
+ self.assertEqual(node.progress.num_skipped, 0)
87
+ self.assertEqual(node.progress.num_completed, 10)
88
+ self.assertEqual(node.progress.num_failed, 1)
89
+ else:
90
+ self.assertEqual(node.progress.num_skipped, 0)
91
+ self.assertEqual(node.progress.num_failed, 0)
92
+ self.assertEqual(node.progress.num_processed, node.progress.num_total)
93
+
94
+ # Test warm start.
95
+ root_dir2 = os.path.join(self.test_dir, 'test_warm_start')
96
+ exp = eval_test_helper.test_experiment()
97
+ plugin = eval_test_helper.TestPlugin()
98
+ run2 = exp.run(
99
+ root_dir2,
100
+ warm_start_from=run.output_root,
101
+ runner='beam',
102
+ plugins=[plugin],
103
+ use_cache='no',
104
+ ckpt_format='jsonl',
105
+ )
106
+ for node in run2.experiment.nodes:
107
+ if node.is_leaf:
108
+ self.assertTrue(node.progress.is_started)
109
+ self.assertTrue(node.progress.is_completed)
110
+ self.assertEqual(node.progress.num_skipped, 0)
111
+ self.assertEqual(node.progress.num_completed, 10)
112
+ self.assertEqual(node.progress.num_failed, 1)
113
+ else:
114
+ self.assertEqual(node.progress.num_skipped, 0)
115
+ self.assertEqual(node.progress.num_failed, 0)
116
+ self.assertEqual(node.progress.num_processed, node.progress.num_total)
117
+
118
+ def test_shuffle_inputs(self):
119
+ root_dir = os.path.join(self.test_dir, 'test_shuffle_inputs')
120
+ exp = eval_test_helper.test_experiment()
121
+ plugin = eval_test_helper.TestPlugin()
122
+ run = exp.run(
123
+ root_dir,
124
+ runner='beam',
125
+ plugins=[plugin],
126
+ shuffle_inputs=True,
127
+ use_cache='no',
128
+ ckpt_format='jsonl',
129
+ )
130
+ self.assertTrue(run.shuffle_inputs)
131
+
132
+ def test_beam_runner_does_not_support_cache(self):
133
+ exp = eval_test_helper.test_experiment()
134
+ root_dir = os.path.join(self.test_dir, 'test_beam_runner_cache')
135
+ with self.assertRaisesRegex(ValueError, 'Cache is not supported'):
136
+ exp.run(
137
+ root_dir,
138
+ runner='beam',
139
+ use_cache='global',
140
+ )
141
+
142
+ def test_no_beam(self):
143
+ orig_beam = beam.beam
144
+ beam.beam = None
145
+ with self.assertRaisesRegex(ValueError, 'Beam is not installed'):
146
+ exp = eval_test_helper.TestEvaluation()
147
+ root_dir = os.path.join(self.test_dir, 'test_no_beam')
148
+ exp.run(root_dir, runner='beam')
149
+ beam.beam = orig_beam
150
+
151
+
152
+ if __name__ == '__main__':
153
+ unittest.main()
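As described in the `beam.py` docstring, only per-example plugins run inside the Beam workers, and `LeafNodeRunner.process` invokes them through `on_example_start` and `on_example_complete`. Below is a minimal sketch of such a plugin; the `experiment.Plugin` base class path is an assumption, and how `is_per_example()` is determined for a plugin is not shown in this diff.

    # Sketch of a per-example plugin; the base class location is an assumption.
    from langfun.core.eval.v2 import example as example_lib
    from langfun.core.eval.v2 import experiment as experiment_lib


    class ExampleLogger(experiment_lib.Plugin):
      """Hypothetical plugin that logs each example as it is processed."""

      def on_example_start(self, runner, experiment, example: example_lib.Example):
        # Runs inside the Beam worker before the example is evaluated.
        print(f'[{experiment.id}] example {example.id} started')

      def on_example_complete(self, runner, experiment, example: example_lib.Example):
        # Runs inside the Beam worker after the example is evaluated.
        print(f'[{experiment.id}] example {example.id} completed')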