langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512150805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +1 -1
- langfun/core/__init__.py +7 -1
- langfun/core/agentic/__init__.py +8 -1
- langfun/core/agentic/action.py +740 -112
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +189 -24
- langfun/core/async_support.py +104 -5
- langfun/core/async_support_test.py +23 -0
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +9 -2
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +11 -2
- langfun/core/data/conversion/gemini_test.py +48 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +48 -44
- langfun/core/eval/base_test.py +5 -5
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +3 -0
- langfun/core/eval/v2/checkpointing.py +148 -46
- langfun/core/eval/v2/checkpointing_test.py +9 -2
- langfun/core/eval/v2/config_saver.py +37 -0
- langfun/core/eval/v2/config_saver_test.py +36 -0
- langfun/core/eval/v2/eval_test_helper.py +104 -3
- langfun/core/eval/v2/evaluation.py +102 -19
- langfun/core/eval/v2/evaluation_test.py +9 -3
- langfun/core/eval/v2/example.py +50 -40
- langfun/core/eval/v2/example_test.py +16 -8
- langfun/core/eval/v2/experiment.py +95 -20
- langfun/core/eval/v2/experiment_test.py +19 -0
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +157 -44
- langfun/core/eval/v2/metrics_test.py +39 -18
- langfun/core/eval/v2/progress.py +31 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking.py +13 -5
- langfun/core/eval/v2/progress_tracking_test.py +9 -1
- langfun/core/eval/v2/reporting.py +88 -71
- langfun/core/eval/v2/reporting_test.py +24 -6
- langfun/core/eval/v2/runners/__init__.py +30 -0
- langfun/core/eval/v2/{runners.py → runners/base.py} +73 -180
- langfun/core/eval/v2/runners/beam.py +354 -0
- langfun/core/eval/v2/runners/beam_test.py +153 -0
- langfun/core/eval/v2/runners/ckpt_monitor.py +350 -0
- langfun/core/eval/v2/runners/ckpt_monitor_test.py +213 -0
- langfun/core/eval/v2/runners/debug.py +40 -0
- langfun/core/eval/v2/runners/debug_test.py +76 -0
- langfun/core/eval/v2/runners/parallel.py +243 -0
- langfun/core/eval/v2/runners/parallel_test.py +182 -0
- langfun/core/eval/v2/runners/sequential.py +47 -0
- langfun/core/eval/v2/runners/sequential_test.py +169 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +7 -5
- langfun/core/language_model.py +189 -36
- langfun/core/language_model_test.py +54 -3
- langfun/core/llms/__init__.py +14 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +32 -1
- langfun/core/llms/gemini.py +90 -12
- langfun/core/llms/gemini_test.py +110 -0
- langfun/core/llms/google_genai.py +52 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +120 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +16 -1
- langfun/core/llms/vertexai.py +78 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/__init__.py +10 -0
- langfun/core/mcp/client.py +177 -0
- langfun/core/mcp/client_test.py +71 -0
- langfun/core/mcp/session.py +241 -0
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/testing/simple_mcp_client.py +33 -0
- langfun/core/mcp/testing/simple_mcp_server.py +33 -0
- langfun/core/mcp/tool.py +254 -0
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +73 -3
- langfun/core/modalities/image_test.py +116 -0
- langfun/core/modalities/mime.py +78 -4
- langfun/core/modalities/mime_test.py +59 -0
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/__init__.py +2 -24
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +81 -37
- langfun/core/structured/parsing.py +95 -79
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +230 -154
- langfun/core/structured/querying_test.py +69 -33
- langfun/core/structured/schema/__init__.py +49 -0
- langfun/core/structured/schema/base.py +664 -0
- langfun/core/structured/schema/base_test.py +531 -0
- langfun/core/structured/schema/json.py +174 -0
- langfun/core/structured/schema/json_test.py +121 -0
- langfun/core/structured/schema/python.py +316 -0
- langfun/core/structured/schema/python_test.py +410 -0
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +47 -36
- langfun/core/structured/tokenization.py +26 -11
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +175 -50
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +43 -0
- langfun/env/base_environment.py +827 -0
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +304 -0
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +842 -0
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +14 -0
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +472 -0
- langfun/env/event_handlers/event_logger_test.py +304 -0
- langfun/env/event_handlers/metric_writer.py +726 -0
- langfun/env/event_handlers/metric_writer_test.py +214 -0
- langfun/env/interface.py +1640 -0
- langfun/env/interface_test.py +153 -0
- langfun/env/load_balancers.py +59 -0
- langfun/env/load_balancers_test.py +141 -0
- langfun/env/test_utils.py +507 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/METADATA +7 -3
- langfun-0.1.2.dev202512150805.dist-info/RECORD +217 -0
- langfun/core/eval/v2/runners_test.py +0 -343
- langfun/core/structured/schema.py +0 -987
- langfun/core/structured/schema_test.py +0 -982
- langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
# Copyright 2025 The Langfun Authors
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Parallel runner."""
|
|
15
|
+
|
|
16
|
+
import collections
|
|
17
|
+
import math
|
|
18
|
+
import random
|
|
19
|
+
import threading
|
|
20
|
+
import time
|
|
21
|
+
|
|
22
|
+
from typing import Annotated, Iterator
|
|
23
|
+
import langfun.core as lf
|
|
24
|
+
from langfun.core.eval.v2 import checkpointing
|
|
25
|
+
from langfun.core.eval.v2 import experiment as experiment_lib
|
|
26
|
+
from langfun.core.eval.v2.runners import base
|
|
27
|
+
from langfun.core.eval.v2.runners import ckpt_monitor
|
|
28
|
+
import pyglove as pg
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ParallelRunner(base.RunnerBase):
  """Runner that processes evaluations and their examples concurrently.

  Evaluations are bucketed by one of the resource IDs they report (or by
  their own ID when they report none). Buckets are dispatched through
  `lf.concurrent_map`, while the evaluations inside a bucket run one after
  another. The examples of a single evaluation are likewise fanned out through
  `lf.concurrent_map`, using `Evaluation.max_workers` as the worker count.
  """

  NAME = 'parallel'

  timeout: Annotated[
      int | None,
      'Timeout for each evaluation example.'
  ] = None

  concurrent_startup_delay: Annotated[
      tuple[int, int] | None,
      (
          'A range of seconds to delay the initial evaluation of each thread '
          'in the thread pool, helping to prevent a burst in LLM QPS at '
          'startup. If set to None, no delay will be applied.'
      )
  ] = None

  def _run(self, evaluations: list[base.Evaluation]) -> None:
    """Dispatches the evaluations, one concurrent task per resource group.

    Args:
      evaluations: The evaluations to run.
    """
    # Bucket evaluations by resource key; evaluations without any resource
    # get a bucket keyed by their own ID.
    groups: dict[str, list[base.Evaluation]] = collections.defaultdict(list)
    for evaluation in evaluations:
      resource_ids = evaluation.resource_ids()
      # TODO(daiyip): support group that requires multiple resources.
      group_id = resource_ids.pop() if resource_ids else evaluation.id
      groups[group_id].append(evaluation)

    def _run_group(evaluation_group: list[base.Evaluation]):
      # Members of a group run back to back on the same worker.
      for member in evaluation_group:
        self.run_evaluation(member)

    # NOTE(review): `max(64, ...)` makes 64 a floor rather than a cap on the
    # worker count, and `timeout` (documented as per-example) is applied to
    # each whole group here - confirm both are intended.
    group_results = lf.concurrent_map(
        _run_group,
        groups.values(),
        max_workers=max(64, len(groups)),
        timeout=self.timeout,
        silence_on_errors=None,
    )
    # Drain the iterator; the per-group results themselves are not used.
    for _, _, _ in group_results:
      pass

  def _evaluate_items(
      self, evaluation: base.Evaluation, items: Iterator[base.Example]
  ) -> None:
    """Evaluates the examples of `evaluation` concurrently.

    When `concurrent_startup_delay` is set, the first example handled by each
    worker thread is preceded by a sleep of a random whole number of seconds
    drawn from that range.

    Args:
      evaluation: The evaluation the examples belong to.
      items: The examples to evaluate.
    """
    if self.concurrent_startup_delay is None:
      def _evaluate_item(item: base.Example):
        return self.evaluate_item(evaluation, item)
    else:
      # Identifiers of the threads that have already slept once.
      delayed_threads = set()

      def _evaluate_item(item: base.Example):
        thread_id = threading.current_thread().ident
        if thread_id not in delayed_threads:
          delayed_threads.add(thread_id)
          time.sleep(random.randint(*self.concurrent_startup_delay))
        return self.evaluate_item(evaluation, item)

    item_results = lf.concurrent_map(
        _evaluate_item,
        items,
        max_workers=self._max_workers(evaluation),
        timeout=self.timeout,
        silence_on_errors=None,
    )
    # Drain the iterator; outputs are not consumed here.
    for _, _, _ in item_results:
      pass

  def _max_workers(self, evaluation: base.Evaluation) -> int | None:
    """Returns the worker count for `evaluation` (its `max_workers`)."""
    return evaluation.max_workers
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class _SingleSliceRunner(ParallelRunner):
  """Parallel runner whose worker count is divided across `num_slices`."""

  NAME = '__single_slice_runner__'

  # Do not track progress in single slice runner.
  progress_tracker = None

  num_slices: Annotated[
      int,
      'The number of slices to run the evaluations in.'
  ] = 1

  def _max_workers(self, evaluation: base.Evaluation) -> int | None:
    """Returns this slice's share of the evaluation's workers.

    Args:
      evaluation: The evaluation whose examples will be processed.

    Returns:
      None if the parent runner reports no worker count; otherwise the parent
      count divided by `num_slices`, rounded up and never below 1.
    """
    total_workers = super()._max_workers(evaluation)
    if total_workers is not None:
      per_slice = math.ceil(total_workers / self.num_slices)
      return max(1, per_slice)
    return None
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class MultiSliceParallelRunner(experiment_lib.Runner):
  """A sliced parallel runner.

  An evaluation is split into `num_slices` slices. Each MultiSliceParallelRunner
  instance is responsible for evaluating a single slice. The instance with
  `slice_id` 0 will also aggregate checkpoints from all slices.

  The sliced parallel runner allows running multiple instances across
  different machines/hosts in parallel. This can be utilized for scaling the
  evaluation jobs to run on multiple machines, or for running evaluations in a
  fault-tolerant way by splitting each evaluation into multiple slices.
  """

  NAME = 'sliced-parallel'

  slice_id: Annotated[
      int,
      (
          'The slice ID of the runner. If 0, it will also run as the '
          'aggregator for collecting results from other slices. '
      )
  ] = 0

  num_slices: Annotated[
      int,
      'The number of slices to run the evaluations in parallel.'
  ] = 1

  timeout: Annotated[
      int | None,
      'Timeout for each evaluation example.'
  ] = None

  concurrent_startup_delay: Annotated[
      tuple[int, int] | None,
      (
          'A range of seconds to delay the initial evaluation of each thread '
          'in the thread pool, helping to prevent a burst in LLM QPS at '
          'startup. If set to None, no delay will be applied.'
      )
  ] = None

  ckpt_format: Annotated[
      str,
      'The file extension of the checkpoint files.'
  ] = 'bagz'

  max_aggregation_threads: Annotated[
      int,
      'The maximum number of threads to aggregate checkpoints.'
  ] = 32

  def _on_bound(self):
    """Validates the run and builds the checkpoint monitor and slice runner.

    Plugins are partitioned as follows: user-supplied checkpointers are
    dropped with a warning (a `PerExampleCheckpointer` is always installed for
    the worker), per-example plugins go to the slice runner, and all remaining
    plugins go to the checkpoint monitor, which is only created for slice 0.

    Raises:
      ValueError: If the current run has caching enabled.
    """
    super()._on_bound()
    if self.current_run.use_cache != 'no':
      raise ValueError(
          'Cache is not supported in MultiSliceParallelRunner. '
          f'Encountered: {self.current_run.use_cache}'
      )

    monitor_plugins = []
    worker_plugins = [
        checkpointing.PerExampleCheckpointer(
            checkpoint_filename=f'checkpoint.{self.ckpt_format}'
        ),
    ]
    for plugin in self.plugins:
      if isinstance(plugin, checkpointing.Checkpointer):
        pg.logging.warning(
            'Built-in checkpointing is enabled on MultiSliceParallelRunner. '
            f'Ignoring checkpointer: {plugin!r}.'
        )
      elif plugin.is_per_example():
        worker_plugins.append(pg.Ref(plugin))
      else:
        monitor_plugins.append(pg.Ref(plugin))

    # Only slice 0 aggregates the checkpoints produced by the slices.
    if self.slice_id == 0:
      self._ckpt_monitor = ckpt_monitor.CheckpointMonitor(
          pg.Ref(self.current_run),
          plugins=monitor_plugins,
          monitor_inprogress_files=True,
          checkpoint_pattern=f'checkpoint_*.{self.ckpt_format}',
          max_aggregation_threads=self.max_aggregation_threads,
      )
    else:
      self._ckpt_monitor = None

    self._slice_runner = _SingleSliceRunner(
        current_run=self.current_run.clone(
            override=dict(
                # Clone the experiment to avoid updating the original one.
                experiment=self.current_run.experiment.clone(),
                example_ids=self._examples_to_evaluate,
            )
        ),
        plugins=worker_plugins,
        timeout=self.timeout,
        concurrent_startup_delay=self.concurrent_startup_delay,
        # Forward the slice count so that `_SingleSliceRunner._max_workers`
        # actually divides the evaluation's workers across slices; without it
        # the slice runner keeps its default of 1 and never scales down.
        num_slices=self.num_slices,
    )

  def _examples_to_evaluate(
      self,
      experiment: experiment_lib.Experiment
  ) -> list[int]:
    """Returns the example IDs of `experiment` that belong to this slice.

    Args:
      experiment: The experiment whose example IDs are being selected.

    Returns:
      The IDs reported by `current_run.examples_to_evaluate` whose value
      modulo `num_slices` equals `slice_id`.
    """
    all_ids = self.current_run.examples_to_evaluate(experiment)
    return [x for x in all_ids if x % self.num_slices == self.slice_id]

  def run(self) -> None:
    """Runs this slice; slice 0 also starts and then joins the monitor."""
    if self._ckpt_monitor is not None:
      self._ckpt_monitor.start()
    self._slice_runner.run()
    if self._ckpt_monitor is not None:
      self._ckpt_monitor.join()
|
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Tests for parallel runner."""
|
|
15
|
+
import os
|
|
16
|
+
import tempfile
|
|
17
|
+
import threading
|
|
18
|
+
from typing import Any
|
|
19
|
+
import unittest
|
|
20
|
+
|
|
21
|
+
from langfun.core.eval.v2 import checkpointing
|
|
22
|
+
from langfun.core.eval.v2 import eval_test_helper
|
|
23
|
+
from langfun.core.eval.v2 import reporting
|
|
24
|
+
from langfun.core.eval.v2.runners import parallel # pylint: disable=unused-import
|
|
25
|
+
|
|
26
|
+
import pyglove as pg
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ParallelRunnerTest(unittest.TestCase):
  """Tests that run experiments with `runner='parallel'`."""

  def assert_same_list(self, actual: list[Any], expected: list[Any]):
    """Asserts equal lengths and pairwise identity (`is`) of the items."""
    self.assertEqual(len(actual), len(expected))
    for index, pair in enumerate(zip(actual, expected)):
      lhs, rhs = pair
      if lhs is not rhs:
        # Print the diff of the mismatching pair before the assertion fails.
        print(index, pg.diff(lhs, rhs))
      self.assertIs(lhs, rhs)

  def test_parallel_runner(self):
    plugin = eval_test_helper.TestPlugin()
    exp = eval_test_helper.test_experiment()
    root_dir = os.path.join(tempfile.mkdtemp(), 'test_parallel_runner')
    _ = exp.run(root_dir, runner='parallel', plugins=[plugin])

    # The plugin observed the start and the completion of the run.
    self.assertIsNotNone(plugin.start_time)
    self.assertIsNotNone(plugin.complete_time)
    self.assertGreater(plugin.complete_time, plugin.start_time)

    # Every node and every example was started and completed.
    self.assertEqual(len(plugin.started_experiments), len(exp.nodes))
    self.assertEqual(len(plugin.completed_experiments), len(exp.nodes))
    expected_examples = 6 * 10
    self.assertEqual(len(plugin.started_example_ids), expected_examples)
    self.assertEqual(len(plugin.completed_example_ids), expected_examples)
    self.assert_same_list(plugin.skipped_experiments, [])

    for node in exp.nodes:
      progress = node.progress
      self.assertTrue(progress.is_started)
      self.assertTrue(progress.is_completed)
      self.assertEqual(progress.num_skipped, 0)
      if node.is_leaf:
        self.assertEqual(progress.num_completed, 10)
        self.assertEqual(progress.num_failed, 1)
      else:
        self.assertEqual(progress.num_failed, 0)
        self.assertEqual(progress.num_processed, progress.num_total)

  def test_raise_if_has_error(self):
    root_dir = os.path.join(tempfile.mkdtemp(), 'test_raise_if_has_error')
    exp = eval_test_helper.TestEvaluation()
    run_kwargs = dict(runner='parallel', plugins=[], raise_if_has_error=True)
    with self.assertRaisesRegex(ValueError, 'x should not be 5'):
      exp.run(root_dir, **run_kwargs)

  def test_concurrent_startup_delay(self):
    plugin = eval_test_helper.TestPlugin()
    exp = eval_test_helper.test_experiment()
    parent_dir = tempfile.mkdtemp()
    root_dir = os.path.join(parent_dir, 'test_concurrent_startup_delay')
    # Only checks that a run with a startup delay goes through.
    _ = exp.run(
        root_dir,
        runner='parallel',
        plugins=[plugin],
        concurrent_startup_delay=(0, 5),
    )
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class MultiProcessParallelRunnerTest(unittest.TestCase):
  """Tests that run experiments with `runner='sliced-parallel'`."""

  def assert_same_list(self, actual: list[Any], expected: list[Any]):
    """Asserts equal lengths and pairwise identity (`is`) of the items."""
    self.assertEqual(len(actual), len(expected))
    for index, pair in enumerate(zip(actual, expected)):
      lhs, rhs = pair
      if lhs is not rhs:
        # Print the diff of the mismatching pair before the assertion fails.
        print(index, pg.diff(lhs, rhs))
      self.assertIs(lhs, rhs)

  def test_basic(self):
    plugin = eval_test_helper.TestPlugin()
    exp = eval_test_helper.test_experiment()
    root_dir = os.path.join(tempfile.mkdtemp(), 'test_parallel_runner')
    runs = [None, None]

    def run_slice(slice_id: int):
      slice_plugins = [
          pg.Ref(plugin),
          # Ignored as PerExampleCheckpointer is used.
          checkpointing.BulkCheckpointer(),
          reporting.ExampleHtmlGenerator(),
      ]
      runs[slice_id] = exp.run(
          root_dir,
          runner='sliced-parallel',
          plugins=slice_plugins,
          use_cache='no',
          slice_id=slice_id,
          num_slices=2,
          ckpt_format='jsonl',
      )

    # We simulate two slices running in parallel.
    workers = []
    for slice_id in range(2):
      workers.append(threading.Thread(target=run_slice, args=(slice_id,)))
    for worker in workers:
      worker.start()
    for worker in workers:
      worker.join()

    # The plugin observed the start and the completion of the run.
    self.assertIsNotNone(plugin.start_time)
    self.assertIsNotNone(plugin.complete_time)
    self.assertGreater(plugin.complete_time, plugin.start_time)

    # Every node and every example was started and completed.
    self.assertEqual(len(plugin.started_experiments), len(exp.nodes))
    self.assertEqual(len(plugin.completed_experiments), len(exp.nodes))
    expected_examples = 6 * 10
    self.assertEqual(len(plugin.started_example_ids), expected_examples)
    self.assertEqual(len(plugin.completed_example_ids), expected_examples)
    self.assert_same_list(plugin.skipped_experiments, [])

    for node in exp.nodes:
      progress = node.progress
      self.assertTrue(progress.is_started)
      self.assertTrue(progress.is_completed)
      self.assertEqual(progress.num_skipped, 0)
      if node.is_leaf:
        self.assertEqual(progress.num_completed, 10)
        self.assertEqual(progress.num_failed, 1)
        # An HTML file must exist for each example ID reported by slice 0's
        # run for this leaf.
        for example_id in runs[0].examples_to_evaluate(node):
          html_path = runs[0].output_path_for(node, f'{example_id}.html')
          self.assertTrue(pg.io.path_exists(html_path))
      else:
        self.assertEqual(progress.num_failed, 0)
        self.assertEqual(progress.num_processed, progress.num_total)

  def test_parallel_mp_runner_does_not_support_cache(self):
    exp = eval_test_helper.test_experiment()
    root_dir = os.path.join(tempfile.mkdtemp(), 'test_parallel_mp_runner_cache')
    run_kwargs = dict(
        runner='sliced-parallel',
        use_cache='global',
        slice_id=0,
        num_slices=1,
    )
    with self.assertRaisesRegex(ValueError, 'Cache is not supported'):
      exp.run(root_dir, **run_kwargs)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
# Run the tests in this module when it is executed as a script.
if __name__ == '__main__':
  unittest.main()
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Sequential runner."""
|
|
15
|
+
|
|
16
|
+
from typing import Any, Callable, Iterator
|
|
17
|
+
from langfun.core.eval.v2.runners import base
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SequentialRunner(base.RunnerBase):
  """Runner that processes evaluations and examples one at a time.

  Everything, including work submitted through `background_run`, executes
  inline in the calling thread. An exception raised by such work therefore
  propagates to the caller right away, which makes debugging easier.
  """

  NAME = 'sequential'

  def background_run(
      self, func: Callable[..., Any], *args: Any, **kwargs: Any
  ) -> None:
    """Calls `func(*args, **kwargs)` synchronously in the calling thread.

    Args:
      func: The callable to invoke.
      *args: Positional arguments forwarded to `func`.
      **kwargs: Keyword arguments forwarded to `func`.
    """
    func(*args, **kwargs)

  def _run(self, evaluations: list[base.Evaluation]) -> None:
    """Runs each evaluation in turn, in the order given.

    Args:
      evaluations: The evaluations to run.
    """
    for evaluation in evaluations:
      self.run_evaluation(evaluation)

  def _evaluate_items(
      self, evaluation: base.Evaluation, items: Iterator[base.Example]
  ) -> None:
    """Evaluates each example of `evaluation` in turn.

    Args:
      evaluation: The evaluation the examples belong to.
      items: The examples to evaluate, consumed in iteration order.
    """
    for example in items:
      self.evaluate_item(evaluation, example)
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Tests for sequential runner."""
|
|
15
|
+
import os
|
|
16
|
+
import tempfile
|
|
17
|
+
from typing import Any
|
|
18
|
+
import unittest
|
|
19
|
+
|
|
20
|
+
from langfun.core.eval.v2 import eval_test_helper
|
|
21
|
+
from langfun.core.eval.v2.runners import sequential # pylint: disable=unused-import
|
|
22
|
+
|
|
23
|
+
import pyglove as pg
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class SequentialRunnerTest(unittest.TestCase):
  """Tests that run experiments with `runner='sequential'`."""

  def assert_same_list(self, actual: list[Any], expected: list[Any]):
    """Asserts equal lengths and pairwise identity (`is`) of the items."""
    self.assertEqual(len(actual), len(expected))
    for index, pair in enumerate(zip(actual, expected)):
      lhs, rhs = pair
      if lhs is not rhs:
        # Print the diff of the mismatching pair before the assertion fails.
        print(index, pg.diff(lhs, rhs))
      self.assertIs(lhs, rhs)

  def test_basic(self):
    plugin = eval_test_helper.TestPlugin()
    exp = eval_test_helper.test_experiment()
    root_dir = os.path.join(tempfile.mkdtemp(), 'test_sequential_runner')
    _ = exp.run(root_dir, runner='sequential', plugins=[plugin])

    # The plugin observed the start and the completion of the run.
    self.assertIsNotNone(plugin.start_time)
    self.assertIsNotNone(plugin.complete_time)
    self.assertGreater(plugin.complete_time, plugin.start_time)

    # Sequential execution makes the event order deterministic.
    expected_started = exp.nonleaf_nodes + exp.leaf_nodes
    self.assert_same_list(plugin.started_experiments, expected_started)
    expected_completed = exp.leaf_nodes + list(reversed(exp.nonleaf_nodes))
    self.assert_same_list(plugin.completed_experiments, expected_completed)
    expected_example_ids = list(range(1, 11)) * 6
    self.assert_same_list(plugin.started_example_ids, expected_example_ids)
    self.assert_same_list(plugin.completed_example_ids, expected_example_ids)
    self.assert_same_list(plugin.skipped_experiments, [])

    for node in exp.nodes:
      progress = node.progress
      self.assertTrue(progress.is_started)
      self.assertTrue(progress.is_completed)
      self.assertEqual(progress.num_skipped, 0)
      if node.is_leaf:
        self.assertEqual(progress.num_completed, 10)
        self.assertEqual(progress.num_failed, 1)
      else:
        self.assertEqual(progress.num_failed, 0)
        self.assertEqual(progress.num_processed, progress.num_total)

  def test_raise_if_has_error(self):
    root_dir = os.path.join(tempfile.mkdtemp(), 'test_raise_if_has_error')
    exp = eval_test_helper.TestEvaluation()
    run_kwargs = dict(runner='sequential', plugins=[], raise_if_has_error=True)
    with self.assertRaisesRegex(ValueError, 'x should not be 5'):
      exp.run(root_dir, **run_kwargs)

  def test_example_ids(self):
    root_dir = os.path.join(tempfile.mkdtemp(), 'test_example_ids')
    exp = eval_test_helper.test_experiment()
    plugin = eval_test_helper.TestPlugin()
    selected_ids = [5, 7, 9]
    _ = exp.run(
        root_dir,
        runner='sequential',
        plugins=[plugin],
        example_ids=selected_ids,
    )
    self.assertEqual(plugin.started_example_ids, [5, 7, 9] * 6)
    self.assertEqual(plugin.completed_example_ids, [5, 7, 9] * 6)

  def test_shuffle_inputs(self):
    root_dir = os.path.join(tempfile.mkdtemp(), 'test_shuffle_inputs')
    exp = eval_test_helper.test_experiment()
    plugin = eval_test_helper.TestPlugin()
    run = exp.run(
        root_dir,
        runner='sequential',
        plugins=[plugin],
        shuffle_inputs=True,
    )
    self.assertTrue(run.shuffle_inputs)

  def test_filter(self):
    plugin = eval_test_helper.TestPlugin()
    exp = eval_test_helper.test_experiment()
    root_dir = os.path.join(tempfile.mkdtemp(), 'test_filter')

    _ = exp.run(
        root_dir,
        runner='sequential',
        plugins=[plugin],
        filter=lambda e: e.lm.offset != 0,
    )
    # The first two leaves are expected to be skipped by the filter.
    skipped_leaves = exp.leaf_nodes[:2]
    kept_leaves = exp.leaf_nodes[2:]
    self.assert_same_list(
        plugin.started_experiments, exp.nonleaf_nodes + kept_leaves
    )
    self.assert_same_list(plugin.skipped_experiments, skipped_leaves)
    self.assert_same_list(
        plugin.completed_experiments, kept_leaves + [exp.children[1], exp]
    )

  def test_use_cache(self):
    @pg.functor()
    def test_inputs(num_examples: int = 10):
      examples = []
      for i in range(num_examples):
        x = i // 2
        y = x ** 2
        examples.append(pg.Dict(x=x, y=y, groundtruth=x + y))
      return examples

    exp = eval_test_helper.TestEvaluation(
        inputs=test_inputs(num_examples=pg.oneof([2, 4]))
    )

    def assert_num_requests(cached: int, uncached: int):
      summary = exp.usage_summary
      self.assertEqual(summary.cached.total.num_requests, cached)
      self.assertEqual(summary.uncached.total.num_requests, uncached)

    # Global cache.
    root_dir = os.path.join(tempfile.mkdtemp(), 'global_cache')
    run = exp.run(
        root_dir, 'new', runner='sequential', use_cache='global', plugins=[]
    )
    self.assertTrue(pg.io.path_exists(run.output_path_for(exp, 'cache.json')))
    assert_num_requests(cached=4, uncached=2)

    # Per-dataset cache.
    root_dir = os.path.join(tempfile.mkdtemp(), 'per_dataset')
    run = exp.run(
        root_dir,
        'new',
        runner='sequential',
        use_cache='per_dataset',
        plugins=[],
    )
    for leaf in exp.leaf_nodes:
      leaf_cache = run.output_path_for(leaf, 'cache.json')
      self.assertTrue(pg.io.path_exists(leaf_cache))
    assert_num_requests(cached=3, uncached=3)

    # No cache.
    root_dir = os.path.join(tempfile.mkdtemp(), 'no')
    run = exp.run(root_dir, runner='sequential', use_cache='no', plugins=[])
    self.assertFalse(pg.io.path_exists(run.output_path_for(exp, 'cache.json')))
    for leaf in exp.leaf_nodes:
      leaf_cache = run.output_path_for(leaf, 'cache.json')
      self.assertFalse(pg.io.path_exists(leaf_cache))
    assert_num_requests(cached=0, uncached=6)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
# Run the tests in this module when it is executed as a script.
if __name__ == '__main__':
  unittest.main()
|