langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512150805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +1 -1
- langfun/core/__init__.py +7 -1
- langfun/core/agentic/__init__.py +8 -1
- langfun/core/agentic/action.py +740 -112
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +189 -24
- langfun/core/async_support.py +104 -5
- langfun/core/async_support_test.py +23 -0
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +9 -2
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +11 -2
- langfun/core/data/conversion/gemini_test.py +48 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +48 -44
- langfun/core/eval/base_test.py +5 -5
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +3 -0
- langfun/core/eval/v2/checkpointing.py +148 -46
- langfun/core/eval/v2/checkpointing_test.py +9 -2
- langfun/core/eval/v2/config_saver.py +37 -0
- langfun/core/eval/v2/config_saver_test.py +36 -0
- langfun/core/eval/v2/eval_test_helper.py +104 -3
- langfun/core/eval/v2/evaluation.py +102 -19
- langfun/core/eval/v2/evaluation_test.py +9 -3
- langfun/core/eval/v2/example.py +50 -40
- langfun/core/eval/v2/example_test.py +16 -8
- langfun/core/eval/v2/experiment.py +95 -20
- langfun/core/eval/v2/experiment_test.py +19 -0
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +157 -44
- langfun/core/eval/v2/metrics_test.py +39 -18
- langfun/core/eval/v2/progress.py +31 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking.py +13 -5
- langfun/core/eval/v2/progress_tracking_test.py +9 -1
- langfun/core/eval/v2/reporting.py +88 -71
- langfun/core/eval/v2/reporting_test.py +24 -6
- langfun/core/eval/v2/runners/__init__.py +30 -0
- langfun/core/eval/v2/{runners.py → runners/base.py} +73 -180
- langfun/core/eval/v2/runners/beam.py +354 -0
- langfun/core/eval/v2/runners/beam_test.py +153 -0
- langfun/core/eval/v2/runners/ckpt_monitor.py +350 -0
- langfun/core/eval/v2/runners/ckpt_monitor_test.py +213 -0
- langfun/core/eval/v2/runners/debug.py +40 -0
- langfun/core/eval/v2/runners/debug_test.py +76 -0
- langfun/core/eval/v2/runners/parallel.py +243 -0
- langfun/core/eval/v2/runners/parallel_test.py +182 -0
- langfun/core/eval/v2/runners/sequential.py +47 -0
- langfun/core/eval/v2/runners/sequential_test.py +169 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +7 -5
- langfun/core/language_model.py +189 -36
- langfun/core/language_model_test.py +54 -3
- langfun/core/llms/__init__.py +14 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +32 -1
- langfun/core/llms/gemini.py +90 -12
- langfun/core/llms/gemini_test.py +110 -0
- langfun/core/llms/google_genai.py +52 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +120 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +16 -1
- langfun/core/llms/vertexai.py +78 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/__init__.py +10 -0
- langfun/core/mcp/client.py +177 -0
- langfun/core/mcp/client_test.py +71 -0
- langfun/core/mcp/session.py +241 -0
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/testing/simple_mcp_client.py +33 -0
- langfun/core/mcp/testing/simple_mcp_server.py +33 -0
- langfun/core/mcp/tool.py +254 -0
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +73 -3
- langfun/core/modalities/image_test.py +116 -0
- langfun/core/modalities/mime.py +78 -4
- langfun/core/modalities/mime_test.py +59 -0
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/__init__.py +2 -24
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +81 -37
- langfun/core/structured/parsing.py +95 -79
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +230 -154
- langfun/core/structured/querying_test.py +69 -33
- langfun/core/structured/schema/__init__.py +49 -0
- langfun/core/structured/schema/base.py +664 -0
- langfun/core/structured/schema/base_test.py +531 -0
- langfun/core/structured/schema/json.py +174 -0
- langfun/core/structured/schema/json_test.py +121 -0
- langfun/core/structured/schema/python.py +316 -0
- langfun/core/structured/schema/python_test.py +410 -0
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +47 -36
- langfun/core/structured/tokenization.py +26 -11
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +175 -50
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +43 -0
- langfun/env/base_environment.py +827 -0
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +304 -0
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +842 -0
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +14 -0
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +472 -0
- langfun/env/event_handlers/event_logger_test.py +304 -0
- langfun/env/event_handlers/metric_writer.py +726 -0
- langfun/env/event_handlers/metric_writer_test.py +214 -0
- langfun/env/interface.py +1640 -0
- langfun/env/interface_test.py +153 -0
- langfun/env/load_balancers.py +59 -0
- langfun/env/load_balancers_test.py +141 -0
- langfun/env/test_utils.py +507 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/METADATA +7 -3
- langfun-0.1.2.dev202512150805.dist-info/RECORD +217 -0
- langfun/core/eval/v2/runners_test.py +0 -343
- langfun/core/structured/schema.py +0 -987
- langfun/core/structured/schema_test.py +0 -982
- langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/top_level.txt +0 -0
langfun/core/eval/v2/runners/beam.py
@@ -0,0 +1,354 @@
+# Copyright 2025 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Beam-based evaluation runner.
+
+BeamRunner is a runner that uses Apache Beam to run evaluations in parallel.
+It is useful for running evaluations with a large number of examples and/or when
+each example is costly to evaluate and can be parallelized.
+
+BeamRunner supports plugins as all other runners do, with the following caveats:
+
+1. Checkpointer plugins are ignored, as BeamRunner performs its own per-example
+   checkpointing.
+
+2. Per-example plugins are executed in the Beam worker to maximize throughput,
+   while all non-per-example plugins are executed in the main process, which
+   collects the results from the workers. Since it might be expensive to
+   deserialize `Example.metadata` for complex evaluations, the main process
+   does not deserialize `Example.metadata` from the workers. If you need to
+   access `Example.metadata` in your plugin, consider making it a per-example
+   plugin (which only implements `on_example_start` and/or
+   `on_example_complete`).
+
+To use it, simply create a `lf.eval.Suite` or `lf.eval.Evaluation`
+and run it with `lf.eval.run(runner='beam')`, optionally passing in a
+`beam_runner` argument.
+"""
+
+import datetime
+import hashlib
+import os
+import random
+import time
+from typing import Annotated, Any, Iterator
+
+from langfun.core.eval.v2 import checkpointing
+from langfun.core.eval.v2 import evaluation as evaluation_lib
+from langfun.core.eval.v2 import example as example_lib
+from langfun.core.eval.v2.runners import base
+from langfun.core.eval.v2.runners import ckpt_monitor
+
+import pyglove as pg
+
+try:
+  # pylint: disable=g-import-not-at-top
+  # pytype: disable=import-error
+  import apache_beam as beam
+  from apache_beam.options import pipeline_options
+  # pytype: enable=import-error
+  # pylint: enable=g-import-not-at-top
+except ImportError:
+  beam = None
+  pipeline_options = None
+
+
+if beam is not None:
+  class _EvaluateFn(beam.DoFn):
+    """Beam DoFn for evaluating examples."""
+
+    def __init__(
+        self,
+        runner_str: str,
+        ckpt_format: str,
+        concurrent_startup_delay: tuple[int, int] | None = None,
+    ):
+      self._runner_str = runner_str
+      self._ckpt_format = ckpt_format
+      self._concurrent_startup_delay = concurrent_startup_delay
+
+    def setup(self):
+      if self._concurrent_startup_delay is not None:
+        time.sleep(random.randint(*self._concurrent_startup_delay))
+      self._runner = pg.from_json_str(self._runner_str)
+      assert isinstance(self._runner, LeafNodeRunner)
+      self._runner.setup()
+      self._input_dir = self._runner.current_run.input_dir(
+          self._runner.current_run.experiment
+      )
+      self._output_dir = self._runner.current_run.output_dir(
+          self._runner.current_run.experiment
+      )
+
+    def teardown(self):
+      assert self._runner is not None
+      self._runner.teardown()
+
+    def process(self, example: tuple[int, str]) -> Iterator[str]:
+      """Evaluates an example and writes the checkpoint file.
+
+      Args:
+        example: A tuple of (example_id, example_json).
+
+      Yields:
+        The path to the checkpoint file.
+      """
+      example_id, example_json = example
+      ckpt_file = os.path.join(
+          self._output_dir, f'checkpoint_{example_id}.{self._ckpt_format}'
+      )
+      if pg.io.path_exists(ckpt_file):
+        yield ckpt_file
+        return
+
+      if self._input_dir != self._output_dir:
+        warmup_ckpt_file = os.path.join(
+            self._input_dir, f'checkpoint_{example_id}.{self._ckpt_format}'
+        )
+        if pg.io.path_exists(warmup_ckpt_file):
+          pg.io.copy(warmup_ckpt_file, ckpt_file)
+          yield ckpt_file
+          return
+
+      # Write the in-progress file to indicate that the example is being
+      # processed.
+      in_progress_file = os.path.join(
+          self._output_dir, f'{example_id}.inprogress'
+      )
+      pg.io.writefile(
+          in_progress_file,
+          datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+      )
+
+      # Process one example.
+      example = self._runner.process(pg.from_json_str(example_json))
+
+      # Perform atomic checkpointing.
+      tmp_ckpt_file = os.path.join(
+          self._output_dir, f'tmp_checkpoint_{example_id}.{self._ckpt_format}'
+      )
+      example_json_str = pg.to_json_str(example)
+      with pg.io.open_sequence(tmp_ckpt_file, 'w') as f:
+        f.add(example_json_str)
+      pg.io.rename(tmp_ckpt_file, ckpt_file)
+
+      # Write the MD5 digest of the example so we know if the example has been
+      # processed multiple times.
+      digest = hashlib.md5(example_json_str.encode()).hexdigest()[:8]
+      pg.io.writefile(
+          os.path.join(self._output_dir, f'{example_id}.{digest}.md5'),
+          digest
+      )
+      yield ckpt_file
+
+else:
+  _EvaluateFn = None  # pylint: disable=invalid-name
+
+
+class LeafNodeRunner(base.RunnerBase):
+  """A runner that runs in a DoFn worker."""
+
+  NAME = '__beam_leaf_node_runner__'
+  progress_tracker = None
+  max_background_threads = 0
+
+  def _on_bound(self):
+    super()._on_bound()
+    for plugin in self.plugins:
+      if not plugin.is_per_example():
+        raise ValueError(
+            'Only per-example plugins are supported in LeafNodeRunner. '
+            f'Encountered: {plugin!r}'
+        )
+    if not isinstance(self.current_run.experiment, evaluation_lib.Evaluation):
+      raise ValueError(
+          'The experiment must be a leaf evaluation in LeafNodeRunner. '
+          f'Encountered: {self.current_run.experiment!r}'
+      )
+
+  def setup(self):
+    self.current_run.experiment.setup()
+
+  def teardown(self):
+    self.current_run.experiment.teardown()
+
+  def process(self, example: example_lib.Example) -> example_lib.Example:
+    """Processes one example."""
+    for plugin in self.plugins:
+      plugin.on_example_start(self, self.current_run.experiment, example)
+    example = self.current_run.experiment.evaluate(example)
+    for plugin in self.plugins:
+      plugin.on_example_complete(self, self.current_run.experiment, example)
+    return example
+
+  def _run(self, evaluations: list[evaluation_lib.Evaluation]) -> None:
+    """Runs the experiment in sequence."""
+    raise NotImplementedError('Not needed in leaf node runner.')
+
+  def _evaluate_items(
+      self,
+      evaluation: evaluation_lib.Evaluation,
+      items: Iterator[example_lib.Example]
+  ) -> None:
+    """Evaluates the items of an evaluation."""
+    raise NotImplementedError('Not needed in leaf node runner.')
+
+
+class BeamRunner(base.RunnerBase):
+  """Beam runner for Langfun evaluations.
+
+  NOTE: This runner depends on Apache Beam, which needs to be installed
+  separately.
+  """
+
+  NAME = 'beam'
+
+  beam_runner: Annotated[
+      Any | None,
+      'The beam runner to use. If None, the direct runner will be used.'
+  ] = None
+
+  beam_pipeline_options: Annotated[
+      dict[str, Any],
+      'Beam pipeline options.'
+  ] = {}
+
+  ckpt_format: Annotated[
+      str,
+      'The file extension of the checkpoint files.'
+  ] = 'jsonl'
+
+  max_aggregation_threads: Annotated[
+      int,
+      'The maximum number of threads to aggregate checkpoints.'
+  ] = 128
+
+  concurrent_startup_delay: Annotated[
+      tuple[int, int] | None,
+      (
+          'A range of seconds to delay the initial evaluation of each thread '
+          'in the thread pool, helping to prevent a burst in LLM QPS at '
+          'startup. If set to None, no delay will be applied.'
+      )
+  ] = None
+
+  def _on_bound(self):
+    if beam is None:
+      raise ValueError(
+          'Apache Beam is not installed. '
+          'Please run `pip install apache-beam` to install beam.'
+      )
+    if self.current_run.use_cache != 'no':
+      raise ValueError(
+          'Cache is not supported in BeamRunner. '
+          f'Encountered: {self.current_run.use_cache}'
+      )
+    host_plugins = []
+    per_example_plugins = []
+    for plugin in self.plugins:
+      if isinstance(plugin, checkpointing.Checkpointer):
+        pg.logging.warning(
+            'Built-in checkpointing is enabled on BeamRunner. '
+            f'Ignoring checkpointer: {plugin!r}.'
+        )
+      elif plugin.is_per_example():
+        per_example_plugins.append(pg.Ref(plugin))
+      else:
+        host_plugins.append(pg.Ref(plugin))
+
+    self.rebind(
+        plugins=host_plugins,
+        skip_notification=True,
+        raise_on_no_change=False
+    )
+    self._per_example_plugins = per_example_plugins
+    super()._on_bound()
+
+  def run(self) -> None:
+    """Run evaluations using Beam."""
+    assert beam is not None
+    assert pipeline_options is not None
+
+    with beam.Pipeline(
+        runner=self.beam_runner or beam.runners.DirectRunner(),
+        options=pipeline_options.PipelineOptions(**self.beam_pipeline_options)
+    ) as pipeline:
+      evaluation_ids = set()
+      for evaluation in self.current_run.experiment.leaf_nodes:
+        if evaluation.id in evaluation_ids or (
+            self.current_run.filter is not None
+            and not self.current_run.filter(evaluation)
+        ):
+          continue
+
+        # There could be suites with duplicate evaluations, but we only want
+        # to run each evaluation once.
+        evaluation_ids.add(evaluation.id)
+
+        example_ids = self.current_run.example_ids
+        if example_ids is None:
+          example_ids = range(1, evaluation.num_examples + 1)
+        inputs = [
+            example_lib.Example(id=i, input=evaluation.example_input_by_id(i))
+            for i in example_ids
+        ]
+        if self.current_run.shuffle_inputs:
+          random.shuffle(inputs)
+
+        leaf_node_runner = LeafNodeRunner(
+            current_run=self.current_run.clone(
+                override=dict(
+                    experiment=evaluation,
+                    raise_if_has_error=False,
+                )
+            ),
+            plugins=self._per_example_plugins,
+        )
+        _ = (
+            pipeline
+            | f'Input-{evaluation.id}' >> beam.Create(
+                [(x.id, pg.to_json_str(x)) for x in inputs]
+            )
+            | f'Evaluate-{evaluation.id}'
+            >> beam.ParDo(
+                _EvaluateFn(
+                    pg.to_json_str(leaf_node_runner),
+                    ckpt_format=self.ckpt_format,
+                    concurrent_startup_delay=self.concurrent_startup_delay,
+                )
+            )
+        )
+    monitor = ckpt_monitor.CheckpointMonitor(
+        pg.Ref(self.current_run),
+        plugins=pg.Ref(self.plugins),
+        # No need to add progress tracker as it is already added by the
+        # Beam runner.
+        progress_tracker=None,
+        monitor_inprogress_files=True,
+        checkpoint_pattern=f'checkpoint_*.{self.ckpt_format}',
+        max_aggregation_threads=self.max_aggregation_threads,
+    )
+    monitor.start()
+    monitor.join()
+
+  def _run(self, evaluations: list[evaluation_lib.Evaluation]) -> None:
+    """Runs the experiment in sequence."""
+    raise NotImplementedError('Not needed in beam runner.')
+
+  def _evaluate_items(
+      self,
+      evaluation: evaluation_lib.Evaluation,
+      items: Iterator[example_lib.Example]
+  ) -> None:
+    """Evaluates the items of an evaluation."""
+    raise NotImplementedError('Not needed in beam runner.')
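For orientation, here is a minimal usage sketch of the runner added above; it is not part of the diff. `MyEvaluation` and the output root directory are hypothetical placeholders, while the keyword arguments mirror the `exp.run(...)` calls exercised in beam_test.py below.

# Hypothetical usage sketch; `MyEvaluation` stands in for a user-defined
# Langfun evaluation and is not part of this package.
experiment = MyEvaluation()
run = experiment.run(
    '/tmp/beam_eval_root',  # hypothetical output root directory
    runner='beam',          # selects BeamRunner by its NAME
    use_cache='no',         # BeamRunner rejects any other cache setting
    ckpt_format='jsonl',    # extension used for per-example checkpoint files
)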
langfun/core/eval/v2/runners/beam_test.py
@@ -0,0 +1,153 @@
+# Copyright 2025 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for beam runner."""
+
+import os
+import tempfile
+from typing import Any
+import unittest
+
+from langfun.core.eval.v2 import checkpointing  # pylint: disable=unused-import
+from langfun.core.eval.v2 import eval_test_helper
+from langfun.core.eval.v2 import reporting  # pylint: disable=unused-import
+from langfun.core.eval.v2.runners import beam  # pylint: disable=unused-import
+import pyglove as pg
+
+
+@unittest.skip(
+    'These tests are flaky due to writing ckpt files with standard IO.'
+    'We will move to `beam.io` and re-enable these tests later.'
+)
+class BeamRunnerTest(unittest.TestCase):
+
+  def assert_same_list(self, actual: list[Any], expected: list[Any]):
+    self.assertEqual(len(actual), len(expected))
+    for i, (x, y) in enumerate(zip(actual, expected)):
+      if x is not y:
+        print(i, pg.diff(x, y))
+      self.assertIs(x, y)
+
+  def setUp(self):
+    super().setUp()
+    self.test_dir = os.path.join(tempfile.mkdtemp(), 'test_dir')
+
+  def test_basic(self):
+    plugin = eval_test_helper.TestPlugin()
+    exp = eval_test_helper.test_experiment()
+    root_dir = os.path.join(self.test_dir, 'test_beam_runner')
+    run = exp.run(
+        root_dir,
+        runner='beam',
+        plugins=[
+            plugin,
+            reporting.ExampleHtmlGenerator(),
+            checkpointing.PerExampleCheckpointer(
+                checkpoint_filename='checkpoint.jsonl'
+            ),
+        ],
+        concurrent_startup_delay=(1, 2),
+        use_cache='no',
+        ckpt_format='jsonl',
+    )
+
+    self.assertIsNotNone(plugin.start_time)
+    self.assertIsNotNone(plugin.complete_time)
+    self.assertGreater(plugin.complete_time, plugin.start_time)
+
+    self.assertEqual(len(plugin.started_experiments), len(exp.nodes))
+    self.assertEqual(len(plugin.completed_experiments), len(exp.nodes))
+    self.assertEqual(len(plugin.started_example_ids), 6 * 10)
+    self.assertEqual(len(plugin.completed_example_ids), 6 * 10)
+    self.assert_same_list(plugin.skipped_experiments, [])
+
+    for node in exp.leaf_nodes:
+      for i in range(node.num_examples):
+        self.assertTrue(
+            pg.io.path_exists(
+                run.output_path_for(node, f'{i + 1}.html')
+            )
+        )
+
+    for node in exp.nodes:
+      if node.is_leaf:
+        self.assertTrue(node.progress.is_started)
+        self.assertTrue(node.progress.is_completed)
+        self.assertEqual(node.progress.num_skipped, 0)
+        self.assertEqual(node.progress.num_completed, 10)
+        self.assertEqual(node.progress.num_failed, 1)
+      else:
+        self.assertEqual(node.progress.num_skipped, 0)
+        self.assertEqual(node.progress.num_failed, 0)
+        self.assertEqual(node.progress.num_processed, node.progress.num_total)
+
+    # Test warm start.
+    root_dir2 = os.path.join(self.test_dir, 'test_warm_start')
+    exp = eval_test_helper.test_experiment()
+    plugin = eval_test_helper.TestPlugin()
+    run2 = exp.run(
+        root_dir2,
+        warm_start_from=run.output_root,
+        runner='beam',
+        plugins=[plugin],
+        use_cache='no',
+        ckpt_format='jsonl',
+    )
+    for node in run2.experiment.nodes:
+      if node.is_leaf:
+        self.assertTrue(node.progress.is_started)
+        self.assertTrue(node.progress.is_completed)
+        self.assertEqual(node.progress.num_skipped, 0)
+        self.assertEqual(node.progress.num_completed, 10)
+        self.assertEqual(node.progress.num_failed, 1)
+      else:
+        self.assertEqual(node.progress.num_skipped, 0)
+        self.assertEqual(node.progress.num_failed, 0)
+        self.assertEqual(node.progress.num_processed, node.progress.num_total)
+
+  def test_shuffle_inputs(self):
+    root_dir = os.path.join(self.test_dir, 'test_shuffle_inputs')
+    exp = eval_test_helper.test_experiment()
+    plugin = eval_test_helper.TestPlugin()
+    run = exp.run(
+        root_dir,
+        runner='beam',
+        plugins=[plugin],
+        shuffle_inputs=True,
+        use_cache='no',
+        ckpt_format='jsonl',
+    )
+    self.assertTrue(run.shuffle_inputs)
+
+  def test_beam_runner_does_not_support_cache(self):
+    exp = eval_test_helper.test_experiment()
+    root_dir = os.path.join(self.test_dir, 'test_beam_runner_cache')
+    with self.assertRaisesRegex(ValueError, 'Cache is not supported'):
+      exp.run(
+          root_dir,
+          runner='beam',
+          use_cache='global',
+      )
+
+  def test_no_beam(self):
+    orig_beam = beam.beam
+    beam.beam = None
+    with self.assertRaisesRegex(ValueError, 'Beam is not installed'):
+      exp = eval_test_helper.TestEvaluation()
+      root_dir = os.path.join(self.test_dir, 'test_no_beam')
+      exp.run(root_dir, runner='beam')
+    beam.beam = orig_beam
+
+
+if __name__ == '__main__':
+  unittest.main()