langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501150804__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- langfun/__init__.py +20 -2
- langfun/core/__init__.py +16 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -21
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +63 -2
- langfun/core/component_test.py +53 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +17 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +16 -1
- langfun/core/eval/base.py +622 -174
- langfun/core/eval/base_test.py +200 -54
- langfun/core/eval/matching.py +63 -76
- langfun/core/eval/matching_test.py +17 -8
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +26 -26
- langfun/core/eval/scoring_test.py +19 -2
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +4 -17
- langfun/core/langfunc_test.py +22 -6
- langfun/core/language_model.py +577 -39
- langfun/core/language_model_test.py +470 -56
- langfun/core/llms/__init__.py +87 -16
- langfun/core/llms/anthropic.py +312 -87
- langfun/core/llms/anthropic_test.py +71 -3
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +53 -2
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +11 -7
- langfun/core/llms/fake_test.py +14 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -202
- langfun/core/llms/groq.py +160 -144
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +15 -42
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +395 -203
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +30 -395
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -26
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +12 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +7 -6
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +60 -27
- langfun/core/structured/function_generation_test.py +72 -2
- langfun/core/structured/mapping.py +97 -47
- langfun/core/structured/mapping_test.py +90 -2
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +53 -9
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
- langfun/core/structured/schema.py +204 -97
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_test.py +130 -29
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +115 -1
- langfun/core/template_test.py +71 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +10 -2
- langfun-0.1.2.dev202501150804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501150804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -238
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
- langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/top_level.txt +0 -0
langfun/core/eval/v2/checkpointing.py (new file)
@@ -0,0 +1,380 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Checkpointing evaluation runs."""
+
+import abc
+import re
+import threading
+import traceback
+from typing import Annotated
+
+import langfun.core as lf
+from langfun.core.eval.v2 import example as example_lib
+from langfun.core.eval.v2 import experiment as experiment_lib
+import pyglove as pg
+
+Example = example_lib.Example
+Experiment = experiment_lib.Experiment
+Runner = experiment_lib.Runner
+
+
+class Checkpointer(experiment_lib.Plugin):
+  """Base class for checkpointing evaluation examples."""
+
+  checkpoint_filename: Annotated[
+      str,
+      'Checkpoint file pattern.'
+  ] = 'checkpoint.bagz'
+
+  def on_experiment_start(
+      self,
+      runner: Runner,
+      experiment: Experiment
+  ) -> None:
+    if not experiment.is_leaf:
+      return
+
+    current_run = runner.current_run
+    if current_run.reprocess is not True:  # pylint: disable=g-bool-id-comparison
+      if current_run.input_root != current_run.output_root:
+        experiment.info(
+            f'Warm starting from directory: {current_run.input_root}.'
+        )
+      self._load_experiment(runner, experiment)
+
+    example_ids_to_evaluate = current_run.examples_to_evaluate(experiment)
+    if experiment.state.evaluated_examples:
+      loaded_example_ids = list(
+          sorted(experiment.state.evaluated_examples.keys())
+      )
+      example_ids_to_evaluate -= set(loaded_example_ids)
+      example_ids_to_evaluate = list(sorted(example_ids_to_evaluate))
+      experiment.info(
+          f'{len(experiment.state.evaluated_examples)} examples '
+          'loaded from checkpoint files. Their outputs will be used '
+          f'for recomputing metrics. Example IDs: {loaded_example_ids}.'
+      )
+      experiment.info(
+          f'{len(example_ids_to_evaluate)} examples will be processed from '
+          f'scratch. Example IDs: {example_ids_to_evaluate}.'
+      )
+    else:
+      experiment.info(
+          'No examples are loaded from checkpoint files. '
+          f'{len(example_ids_to_evaluate)} examples will be processed from '
+          f'scratch. Example IDs: {example_ids_to_evaluate}.'
+      )
+
+  def on_example_complete(
+      self,
+      runner: Runner,
+      experiment: Experiment,
+      example: Example,
+  ) -> None:
+    """Saves the example to the checkpoint file."""
+    if example.has_error:
+      experiment.warning(
+          f'Example {example.id} has error. Skipping checkpointing.'
+      )
+    elif example.newly_processed:
+      self._save_example(runner, experiment, example)
+
+  def _load_experiment(
+      self,
+      runner: Runner,
+      experiment: Experiment,
+  ) -> None:
+    """Loads the experiment state from existing checkpoint files."""
+    ckpt_files = self._list_checkpoint_filenames(runner, experiment)
+    experiment.info(f'Found {len(ckpt_files)} checkpoint files to load.')
+
+    # Load the checkpoint files in parallel.
+    current_run = runner.current_run
+    examples_to_load = current_run.examples_to_load(experiment)
+    examples_to_load_metadata = current_run.examples_to_load_metadata(
+        experiment
+    )
+    context = dict(counter=0, counter_lock=threading.Lock())
+    copy_ckpt = current_run.input_root != current_run.output_root
+
+    def _load_state(ckpt_file):
+      error = None
+      with pg.timeit() as t:
+        try:
+          experiment.load_state(
+              current_run.input_path_for(experiment, ckpt_file),
+              filter=lambda x: x.id in examples_to_load,
+              load_example_metadata=lambda x: x.id in examples_to_load_metadata,
+          )
+        except BaseException as e:  # pylint: disable=broad-except
+          error = e
+        finally:
+          with context['counter_lock']:
+            context['counter'] += 1
+
+          progress_str = f'{context["counter"]}/{len(ckpt_files)}'
+          if error is None:
+            experiment.info(
+                f'Checkpoint file {ckpt_file!r} loaded in {t.elapse:.2f} '
+                f'seconds. ({progress_str})'
+            )
+          else:
+            experiment.warning(
+                f'Failed to load checkpoint file {ckpt_file!r}: {error}. '
+                f'Skipping the file. ({progress_str})'
+            )
+
+      if not copy_ckpt:
+        return
+
+      # Copy the checkpoint records to the output directory.
+      try:
+        with pg.io.open_sequence(
+            current_run.output_path_for(experiment, ckpt_file), 'w'
+        ) as o, pg.io.open_sequence(
+            current_run.input_path_for(experiment, ckpt_file), 'r'
+        ) as i:
+          for x in i:
+            o.add(x)
+      except BaseException as e:  # pylint: disable=broad-except
+        experiment.warning(
+            f'Failed to copy checkpoint {ckpt_file!r}: {e}.'
+        )
+
+    _ = list(
+        lf.concurrent_map(
+            _load_state, ckpt_files, max_workers=16, silence_on_errors=None
+        )
+    )
+
+  @abc.abstractmethod
+  def _list_checkpoint_filenames(
+      self, runner: Runner, experiment: Experiment
+  ) -> list[str]:
+    """Lists the checkpoint filenames to restore."""
+
+  @abc.abstractmethod
+  def _save_example(
+      self,
+      runner: Runner,
+      experiment: Experiment,
+      example: Example,
+  ) -> None:
+    """Saves an evaluated example."""
+
+
+class PerExampleCheckpointer(Checkpointer):
+  """Checkpointer that saves each example to a separate file."""
+
+  def _on_bound(self):
+    super()._on_bound()
+    prefix, ext = self._file_prefix_and_ext(self.checkpoint_filename)
+    self._checkpoint_file_prefix = prefix
+    self._checkpoint_file_ext = ext
+
+  def _list_checkpoint_filenames(
+      self, runner: Runner, experiment: Experiment
+  ) -> list[str]:
+    experiment_dir = runner.current_run.input_dir(experiment)
+    filenames = []
+    examples_to_load = runner.current_run.examples_to_load(experiment)
+    if pg.io.path_exists(experiment_dir):
+      regex = re.compile(
+          f'{self._checkpoint_file_prefix}_(\\d+){self._checkpoint_file_ext}'
+          .replace('.', '\\.')
+      )
+      for filename in pg.io.listdir(experiment_dir):
+        match = regex.match(filename)
+        if match and int(match.group(1)) in examples_to_load:
+          filenames.append(filename)
+    return filenames
+
+  def _save_example(
+      self,
+      runner: Runner,
+      experiment: Experiment,
+      example: Example,
+  ) -> None:
+    """Saves the example to the checkpoint file."""
+    def save_state(example: Example):
+      writer = SequenceWriter(
+          runner.current_run.output_path_for(
+              experiment,
+              (
+                  f'{self._checkpoint_file_prefix}_{example.id}'
+                  f'{self._checkpoint_file_ext}'
+              )
+          )
+      )
+      try:
+        writer.add(example)
+        writer.close()
+        experiment.info(
+            f'Example {example.id} checkpointed to {writer.path}.',
+        )
+      except BaseException as e:  # pylint: disable=broad-except
+        experiment.error(
+            f'Failed to checkpoint example {example.id} to {writer.path}. '
+            f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.',
+        )
+        raise e
+    runner.background_run(save_state, example)
+
+  def _file_prefix_and_ext(self, filename: str) -> tuple[str, str]:
+    ext_index = filename.rfind('.')
+    if ext_index == -1:
+      return filename, ''
+    else:
+      return filename[:ext_index], filename[ext_index:]
+
+
+class BulkCheckpointer(Checkpointer):
+  """Checkpointer that saves all examples to a single file."""
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._lock = threading.Lock()
+    self._sequence_writer = None
+
+  def on_run_start(
+      self,
+      runner: Runner,
+      root: Experiment,
+  ) -> None:
+    self._sequence_writer = {}
+
+  def on_run_abort(
+      self,
+      runner: Runner,
+      root: Experiment,
+      error: BaseException
+  ) -> None:
+    with self._lock:
+      if self._sequence_writer is not None:
+        for writer in self._sequence_writer.values():
+          writer.close()
+        self._sequence_writer.clear()
+
+  def on_run_complete(
+      self,
+      runner: Runner,
+      root: Experiment,
+  ) -> None:
+    with self._lock:
+      assert self._sequence_writer is not None and not self._sequence_writer
+
+  def on_experiment_start(
+      self,
+      runner: Runner,
+      experiment: Experiment,
+  ) -> None:
+    super().on_experiment_start(runner, experiment)
+
+    # Prepare the sequence writer for the experiment.
+    if experiment.is_leaf:
+      sequence_writer = SequenceWriter(
+          runner.current_run.output_path_for(
+              experiment, self.checkpoint_filename
+          )
+      )
+      with self._lock:
+        if self._sequence_writer is not None:
+          self._sequence_writer[experiment.id] = sequence_writer
+
+  def _list_checkpoint_filenames(
+      self, runner: Runner, experiment: Experiment
+  ) -> list[str]:
+    if pg.io.path_exists(
+        runner.current_run.input_path_for(experiment, self.checkpoint_filename)
+    ):
+      return [self.checkpoint_filename]
+    return []
+
+  def on_experiment_complete(
+      self,
+      runner: Runner,
+      experiment: Experiment,
+  ) -> None:
+    """Closes the checkpoint file."""
+    if not experiment.is_leaf:
+      return
+    assert experiment.id in self._sequence_writer
+    with self._lock:
+      if self._sequence_writer is not None:
+        # Make sure the writer is closed without delay so the file will be
+        # available immediately.
+        writer = self._sequence_writer.pop(experiment.id)
+        writer.close()
+        experiment.info(
+            f'{len(experiment.state.evaluated_examples)} examples are '
+            f'checkpointed to {writer.path}.'
+        )
+
+  def _save_example(
+      self,
+      runner: Runner,
+      experiment: Experiment,
+      example: Example,
+  ) -> None:
+    """Saves the example to the checkpoint file."""
+    assert experiment.id in self._sequence_writer
+    def _save_example(example: Example):
+      writer = self._sequence_writer[experiment.id]
+      try:
+        writer.add(example)
+        experiment.info(
+            f'Example {example.id} checkpointed to {writer.path}.',
+        )
+      except BaseException as e:  # pylint: disable=broad-except
+        experiment.error(
+            f'Failed to checkpoint example {example.id} to {writer.path}. '
+            f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.',
+        )
+        raise e
+    runner.background_run(_save_example, example)
+
+
+class SequenceWriter:
+  """Thread safe sequence writer."""
+
+  def __init__(self, path: str):
+    self._lock = threading.Lock()
+    self._path = path
+    self._sequence_writer = pg.io.open_sequence(path, 'a')
+
+  @property
+  def path(self) -> str:
+    return self._path
+
+  def add(self, example: Example):
+    example_blob = pg.to_json_str(
+        example,
+        hide_default_values=True,
+        save_ref_value=True,
+        exclude_input=True
+    )
+    with self._lock:
+      if self._sequence_writer is None:
+        return
+      self._sequence_writer.add(example_blob)
+
+  def close(self):
+    # Make sure there is no write in progress.
+    with self._lock:
+      if self._sequence_writer is None:
+        return
+      self._sequence_writer.close()
+      self._sequence_writer = None
+
+  def __del__(self):
+    self.close()
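For orientation, here is a minimal usage sketch distilled from checkpointing_test.py (shown below): a checkpointer is attached to an evaluation run as a plugin, and a later run can warm-start from a previous run's output root so already-checkpointed examples are skipped. The root directories are illustrative placeholders, and the experiment comes from the package's own eval_test_helper module.

import os
import tempfile

from langfun.core.eval.v2 import checkpointing
from langfun.core.eval.v2 import eval_test_helper
from langfun.core.eval.v2 import runners as runners_lib  # pylint: disable=unused-import

# One checkpoint file per evaluated example (checkpoint_<id>.jsonl).
checkpointer = checkpointing.PerExampleCheckpointer('checkpoint.jsonl')

experiment = eval_test_helper.test_experiment()
run = experiment.run(
    os.path.join(tempfile.gettempdir(), 'my_eval'),  # placeholder root dir
    'new', runner='sequential', plugins=[checkpointer]
)

# A later run can reuse the previous output as a warm start; examples found in
# the loaded checkpoints are skipped instead of being reprocessed.
_ = experiment.run(
    os.path.join(tempfile.gettempdir(), 'my_eval2'),  # placeholder root dir
    'new', runner='sequential', plugins=[checkpointer],
    warm_start_from=run.output_root,
)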
langfun/core/eval/v2/checkpointing_test.py (new file)
@@ -0,0 +1,228 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import tempfile
+import unittest
+
+from langfun.core.eval.v2 import checkpointing
+from langfun.core.eval.v2 import eval_test_helper
+from langfun.core.eval.v2 import example as example_lib
+from langfun.core.eval.v2 import runners as runners_lib  # pylint: disable=unused-import
+import pyglove as pg
+
+Example = example_lib.Example
+
+
+class SequenceWriterTest(unittest.TestCase):
+
+  def test_basic(self):
+    file = os.path.join(tempfile.gettempdir(), 'test.jsonl')
+    writer = checkpointing.SequenceWriter(file)
+    example = Example(id=1, input=pg.Dict(x=1), output=2)
+    writer.add(example)
+    del writer
+    self.assertTrue(pg.io.path_exists(file))
+
+  def test_error_handling(self):
+    file = os.path.join(tempfile.gettempdir(), 'test_error_handling.jsonl')
+    writer = checkpointing.SequenceWriter(file)
+    writer.add(Example(id=1, input=pg.Dict(x=1), output=2))
+
+    def f():
+      raise ValueError('Intentional error')
+
+    try:
+      writer.add(f())
+    except ValueError:
+      del writer
+
+    self.assertTrue(pg.io.path_exists(file))
+    with pg.io.open_sequence(file, 'r') as f:
+      self.assertEqual(len(list(iter(f))), 1)
+
+
+class CheckpointerTest(unittest.TestCase):
+
+  def assert_found_in_log(self, experiment, message):
+    found_error_log = False
+    for log_entry in experiment._log_entries:
+      if log_entry.message.startswith(message):
+        found_error_log = True
+        break
+    self.assertTrue(found_error_log)
+
+
+class PerExampleCheckpointerTest(CheckpointerTest):
+
+  def test_checkpointing(self):
+    root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer')
+    experiment = eval_test_helper.test_experiment()
+    checkpoint_filename = 'checkpoint.jsonl'
+    checkpointer = checkpointing.PerExampleCheckpointer(checkpoint_filename)
+    run = experiment.run(
+        root_dir, 'new', runner='sequential', plugins=[checkpointer]
+    )
+    num_processed = {}
+    for leaf in experiment.leaf_nodes:
+      for i in range(leaf.num_examples):
+        example = leaf.state.get(i + 1)
+        ckpt = run.output_path_for(leaf, f'checkpoint_{example.id}.jsonl')
+        if example.has_error:
+          self.assertFalse(pg.io.path_exists(ckpt))
+        else:
+          self.assertTrue(pg.io.path_exists(ckpt))
+          with pg.io.open_sequence(ckpt) as f:
+            self.assertEqual(len(list(iter(f))), 1)
+      if leaf.id not in num_processed:
+        self.assertEqual(leaf.progress.num_skipped, 0)
+        num_processed[leaf.id] = leaf.progress.num_processed
+
+    # Run again, should skip existing.
+    _ = experiment.run(
+        root_dir, 'latest', runner='sequential', plugins=[checkpointer]
+    )
+    for leaf in experiment.leaf_nodes:
+      self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])
+
+    # Test warm start without reprocess.
+    root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer2')
+    experiment = eval_test_helper.test_experiment()
+    _ = experiment.run(
+        root_dir, 'new', runner='sequential', plugins=[checkpointer],
+        warm_start_from=run.output_root
+    )
+    for leaf in experiment.leaf_nodes:
+      self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])
+
+    # Test warm start with reprocess.
+    root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer3')
+    experiment = eval_test_helper.test_experiment()
+    _ = experiment.run(
+        root_dir, 'new', runner='sequential', plugins=[checkpointer],
+        warm_start_from=run.output_root,
+        reprocess=True
+    )
+    for leaf in experiment.leaf_nodes:
+      self.assertEqual(leaf.progress.num_skipped, 0)
+
+    root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer4')
+    experiment = eval_test_helper.test_experiment()
+    _ = experiment.run(
+        root_dir, 'new', runner='sequential', plugins=[checkpointer],
+        warm_start_from=run.output_root,
+        reprocess=[1, 2, 3]
+    )
+    for leaf in experiment.leaf_nodes:
+      self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id] - 3)
+
+  def test_loading_corrupted_checkpoint(self):
+    root_dir = os.path.join(
+        tempfile.gettempdir(),
+        'per_example_checkpointer_with_corrupted_checkpoint'
+    )
+    experiment = eval_test_helper.TestEvaluation()
+    checkpoint_filename = 'checkpoint.jsonl'
+    checkpointer = checkpointing.PerExampleCheckpointer(checkpoint_filename)
+    run = experiment.run(
+        root_dir, 'new', runner='sequential', plugins=[checkpointer]
+    )
+    num_processed = {}
+    for i in range(experiment.num_examples):
+      example = experiment.state.get(i + 1)
+      ckpt = run.output_path_for(experiment, f'checkpoint_{example.id}.jsonl')
+      if not example.has_error:
+        self.assertTrue(pg.io.path_exists(ckpt))
+        with pg.io.open_sequence(ckpt) as f:
+          self.assertEqual(len(list(iter(f))), 1)
+
+        # Simulate corrupting the first checkpoint.
+        if i == 0:
+          pg.io.writefile(ckpt, 'bad file')
+        num_processed[example.id] = i + 1
+
+    root_dir = os.path.join(
+        tempfile.gettempdir(),
+        'per_example_checkpointer_with_corrupted_checkpoint_warm_start'
+    )
+    experiment = eval_test_helper.TestEvaluation()
+    _ = experiment.run(
+        root_dir, 'new', runner='sequential', plugins=[checkpointer],
+        warm_start_from=run.output_root,
+    )
+    for leaf in experiment.leaf_nodes:
+      self.assertEqual(leaf.progress.num_skipped, len(num_processed) - 1)
+    self.assert_found_in_log(experiment, 'Failed to load checkpoint')
+
+  def test_checkpointing_error(self):
+    root_dir = os.path.join(
+        tempfile.gettempdir(),
+        'per_example_checkpointer_with_checkpointing_error'
+    )
+    experiment = (eval_test_helper
+                  .test_experiment_with_example_checkpointing_error())
+    checkpointer = checkpointing.PerExampleCheckpointer('checkpoint.jsonl')
+    _ = experiment.run(
+        root_dir, 'new', runner='parallel', plugins=[checkpointer]
+    )
+    self.assert_found_in_log(experiment, 'Failed to checkpoint')
+
+
+class BulkCheckpointerTest(CheckpointerTest):
+
+  def test_checkpointing(self):
+    root_dir = os.path.join(tempfile.gettempdir(), 'test_bulk_checkpointer')
+    experiment = eval_test_helper.test_experiment()
+    checkpoint_filename = 'checkpoint.jsonl'
+    checkpointer = checkpointing.BulkCheckpointer(checkpoint_filename)
+    run = experiment.run(
+        root_dir, 'new', runner='sequential', plugins=[checkpointer]
+    )
+    self.assertEqual(len(checkpointer._sequence_writer), 0)
+    num_processed = {}
+    for leaf in experiment.leaf_nodes:
+      ckpt = run.output_path_for(leaf, checkpoint_filename)
+      self.assertTrue(pg.io.path_exists(ckpt))
+      with pg.io.open_sequence(ckpt) as f:
+        self.assertEqual(
+            len(list(iter(f))),
+            leaf.progress.num_completed - leaf.progress.num_failed
+        )
+      if leaf.id not in num_processed:
+        self.assertEqual(leaf.progress.num_skipped, 0)
+        num_processed[leaf.id] = leaf.progress.num_processed
+
+    # Run again, should skip existing.
+    _ = experiment.run(
+        root_dir, 'latest', runner='sequential', plugins=[checkpointer]
+    )
+    self.assertEqual(len(checkpointer._sequence_writer), 0)
+    for leaf in experiment.leaf_nodes:
+      self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])
+
+  def test_checkpointing_error(self):
+    root_dir = os.path.join(
+        tempfile.gettempdir(),
+        'bulk_checkpointer_with_checkpointing_error'
+    )
+    experiment = (eval_test_helper
+                  .test_experiment_with_example_checkpointing_error())
+    checkpointer = checkpointing.BulkCheckpointer('checkpoint.jsonl')
+    _ = experiment.run(
+        root_dir, 'new', runner='parallel', plugins=[checkpointer]
+    )
+    self.assert_found_in_log(experiment, 'Failed to checkpoint')
+
+
+if __name__ == '__main__':
+  unittest.main()