langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. langfun/__init__.py +20 -2
  2. langfun/core/__init__.py +16 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -21
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +63 -2
  18. langfun/core/component_test.py +53 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +16 -1
  24. langfun/core/eval/base.py +622 -174
  25. langfun/core/eval/base_test.py +200 -54
  26. langfun/core/eval/matching.py +63 -76
  27. langfun/core/eval/matching_test.py +17 -8
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +26 -26
  31. langfun/core/eval/scoring_test.py +19 -2
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +4 -17
  55. langfun/core/langfunc_test.py +22 -6
  56. langfun/core/language_model.py +577 -39
  57. langfun/core/language_model_test.py +470 -56
  58. langfun/core/llms/__init__.py +87 -16
  59. langfun/core/llms/anthropic.py +312 -87
  60. langfun/core/llms/anthropic_test.py +71 -3
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +53 -2
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +11 -7
  69. langfun/core/llms/fake_test.py +14 -0
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -202
  74. langfun/core/llms/groq.py +160 -144
  75. langfun/core/llms/groq_test.py +31 -137
  76. langfun/core/llms/llama_cpp.py +15 -42
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +395 -203
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +30 -395
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -26
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +12 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +7 -6
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +60 -27
  112. langfun/core/structured/function_generation_test.py +72 -2
  113. langfun/core/structured/mapping.py +97 -47
  114. langfun/core/structured/mapping_test.py +90 -2
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +53 -9
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
  119. langfun/core/structured/schema.py +204 -97
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_test.py +130 -29
  122. langfun/core/structured/scoring.py +125 -19
  123. langfun/core/structured/scoring_test.py +30 -0
  124. langfun/core/structured/tokenization.py +64 -0
  125. langfun/core/structured/tokenization_test.py +48 -0
  126. langfun/core/template.py +115 -1
  127. langfun/core/template_test.py +71 -1
  128. langfun/core/templates/conversation.py +9 -0
  129. langfun/core/templates/conversation_test.py +4 -3
  130. langfun/core/templates/selfplay_test.py +10 -2
  131. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  132. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  133. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  134. langfun/core/coding/python/errors.py +0 -108
  135. langfun/core/coding/python/errors_test.py +0 -99
  136. langfun/core/coding/python/permissions.py +0 -90
  137. langfun/core/coding/python/permissions_test.py +0 -86
  138. langfun/core/structured/prompting.py +0 -238
  139. langfun/core/text_formatting.py +0 -162
  140. langfun/core/text_formatting_test.py +0 -47
  141. langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
  142. langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
  143. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  144. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,380 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Checkpointing evaluation runs."""
15
+ import abc
16
+ import re
17
+ import threading
18
+ import traceback
19
+ from typing import Annotated
20
+
21
+ import langfun.core as lf
22
+ from langfun.core.eval.v2 import example as example_lib
23
+ from langfun.core.eval.v2 import experiment as experiment_lib
24
+ import pyglove as pg
25
+
26
+ Example = example_lib.Example
27
+ Experiment = experiment_lib.Experiment
28
+ Runner = experiment_lib.Runner
29
+
30
+
31
class Checkpointer(experiment_lib.Plugin):
  """Base class for checkpointing evaluation examples.

  A checkpointer persists each successfully evaluated example so that a later
  run can reload those examples — possibly from a different directory (warm
  start) — instead of reprocessing them. Subclasses decide the on-disk layout
  by implementing `_list_checkpoint_filenames` and `_save_example`.
  """

  # File name (or pattern) of checkpoint files under an experiment's
  # output directory.
  checkpoint_filename: Annotated[
      str,
      'Checkpoint file pattern.'
  ] = 'checkpoint.bagz'

  def on_experiment_start(
      self,
      runner: Runner,
      experiment: Experiment
  ) -> None:
    """Loads prior checkpoints (unless fully reprocessing) and logs the plan."""
    # Only leaf experiments evaluate examples, so only they checkpoint.
    if not experiment.is_leaf:
      return

    current_run = runner.current_run
    # `reprocess` may be True (reprocess everything), False, or a list of
    # example IDs; only a blanket True skips loading existing checkpoints.
    if current_run.reprocess is not True:  # pylint: disable=g-bool-id-comparison
      if current_run.input_root != current_run.output_root:
        # Differing roots means we warm start from a previous run's output.
        experiment.info(
            f'Warm starting from directory: {current_run.input_root}.'
        )
      self._load_experiment(runner, experiment)

    example_ids_to_evaluate = current_run.examples_to_evaluate(experiment)
    if experiment.state.evaluated_examples:
      # Examples restored from checkpoints need no re-evaluation; their
      # stored outputs are reused for recomputing metrics.
      loaded_example_ids = list(
          sorted(experiment.state.evaluated_examples.keys())
      )
      example_ids_to_evaluate -= set(loaded_example_ids)
      example_ids_to_evaluate = list(sorted(example_ids_to_evaluate))
      experiment.info(
          f'{len(experiment.state.evaluated_examples)} examples '
          'loaded from checkpoint files. Their outputs will be used '
          f'for recomputing metrics. Example IDs: {loaded_example_ids}.'
      )
      experiment.info(
          f'{len(example_ids_to_evaluate)} examples will be processed from '
          f'scratch. Example IDs: {example_ids_to_evaluate}.'
      )
    else:
      experiment.info(
          'No examples are loaded from checkpoint files. '
          f'{len(example_ids_to_evaluate)} examples will be processed from '
          f'scratch. Example IDs: {example_ids_to_evaluate}.'
      )

  def on_example_complete(
      self,
      runner: Runner,
      experiment: Experiment,
      example: Example,
  ) -> None:
    """Saves the example to the checkpoint file."""
    if example.has_error:
      # Failed examples are not checkpointed, so they get retried next run.
      experiment.warning(
          f'Example {example.id} has error. Skipping checkpointing.'
      )
    elif example.newly_processed:
      # Examples replayed from a checkpoint are already persisted.
      self._save_example(runner, experiment, example)

  def _load_experiment(
      self,
      runner: Runner,
      experiment: Experiment,
  ) -> None:
    """Creates the checkpoint file."""
    ckpt_files = self._list_checkpoint_filenames(runner, experiment)
    experiment.info(f'Found {len(ckpt_files)} checkpoint files to load.')

    # Load the checkpoint files in parallel.
    current_run = runner.current_run
    examples_to_load = current_run.examples_to_load(experiment)
    examples_to_load_metadata = current_run.examples_to_load_metadata(
        experiment
    )
    # Shared progress counter across worker threads, guarded by its lock.
    context = dict(counter=0, counter_lock=threading.Lock())
    # When warm starting from another directory, loaded checkpoints are also
    # copied into the output directory so it is self-contained.
    copy_ckpt = current_run.input_root != current_run.output_root

    def _load_state(ckpt_file):
      # Loads one checkpoint file into the experiment state, then (when warm
      # starting) mirrors its records into the output directory.
      error = None
      with pg.timeit() as t:
        try:
          experiment.load_state(
              current_run.input_path_for(experiment, ckpt_file),
              filter=lambda x: x.id in examples_to_load,
              load_example_metadata=lambda x: x.id in examples_to_load_metadata,
          )
        except BaseException as e:  # pylint: disable=broad-except
          # A corrupted or unreadable checkpoint is skipped with a warning
          # instead of aborting the whole run.
          error = e
        finally:
          with context['counter_lock']:
            context['counter'] += 1

          progress_str = f'{context["counter"]}/{len(ckpt_files)}'
          if error is None:
            experiment.info(
                f'Checkpoint file {ckpt_file!r} loaded in {t.elapse:.2f} '
                f'seconds. ({progress_str})'
            )
          else:
            experiment.warning(
                f'Failed to load checkpoint file {ckpt_file!r}: {error}. '
                f'Skipping the file. ({progress_str})'
            )

      if not copy_ckpt:
        return

      # Copy the checkpoint records to the output directory.
      try:
        with pg.io.open_sequence(
            current_run.output_path_for(experiment, ckpt_file), 'w'
        ) as o, pg.io.open_sequence(
            current_run.input_path_for(experiment, ckpt_file), 'r'
        ) as i:
          for x in i:
            o.add(x)
      except BaseException as e:  # pylint: disable=broad-except
        experiment.warning(
            f'Failed to copy checkpoint {ckpt_file!r}: {e}.'
        )

    _ = list(
        lf.concurrent_map(
            _load_state, ckpt_files, max_workers=16, silence_on_errors=None
        )
    )

  @abc.abstractmethod
  def _list_checkpoint_filenames(
      self, runner: Runner, experiment: Experiment
  ) -> list[str]:
    """Lists the checkpoint filenames to restore."""

  @abc.abstractmethod
  def _save_example(
      self,
      runner: Runner,
      experiment: Experiment,
      example: Example,
  ) -> None:
    """Saves an evaluated example."""
174
+
175
+
176
class PerExampleCheckpointer(Checkpointer):
  """Checkpointer that saves each example to a separate file.

  Each evaluated example is written to its own file named
  `<prefix>_<example_id><ext>`, derived from `checkpoint_filename`, so
  examples can be checkpointed and restored independently of each other.
  """

  def _on_bound(self):
    super()._on_bound()
    # Split e.g. 'checkpoint.bagz' into ('checkpoint', '.bagz') once, so
    # per-example filenames can be derived cheaply.
    prefix, ext = self._file_prefix_and_ext(self.checkpoint_filename)
    self._checkpoint_file_prefix = prefix
    self._checkpoint_file_ext = ext

  def _list_checkpoint_filenames(
      self, runner: Runner, experiment: Experiment
  ) -> list[str]:
    """Returns per-example checkpoint files for the examples to load."""
    experiment_dir = runner.current_run.input_dir(experiment)
    filenames = []
    examples_to_load = runner.current_run.examples_to_load(experiment)
    if pg.io.path_exists(experiment_dir):
      # Escape the user-provided prefix/extension with `re.escape` so ALL
      # regex metacharacters are matched literally. (The previous pattern
      # only escaped '.', so a prefix containing e.g. '+' or '(' would
      # produce a wrong or invalid regex.) Group 1 captures the example ID.
      regex = re.compile(
          re.escape(self._checkpoint_file_prefix)
          + r'_(\d+)'
          + re.escape(self._checkpoint_file_ext)
      )
      for filename in pg.io.listdir(experiment_dir):
        match = regex.match(filename)
        if match and int(match.group(1)) in examples_to_load:
          filenames.append(filename)
    return filenames

  def _save_example(
      self,
      runner: Runner,
      experiment: Experiment,
      example: Example,
  ) -> None:
    """Saves the example to the checkpoint file."""
    def save_state(example: Example):
      # Writes a single-record checkpoint file for `example`.
      writer = SequenceWriter(
          runner.current_run.output_path_for(
              experiment,
              (
                  f'{self._checkpoint_file_prefix}_{example.id}'
                  f'{self._checkpoint_file_ext}'
              )
          )
      )
      try:
        writer.add(example)
        writer.close()
        experiment.info(
            f'Example {example.id} checkpointed to {writer.path}.',
        )
      except BaseException as e:  # pylint: disable=broad-except
        experiment.error(
            f'Failed to checkpoint example {example.id} to {writer.path}. '
            f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.',
        )
        raise e
    # Write in the background so evaluation is not blocked on file I/O.
    runner.background_run(save_state, example)

  def _file_prefix_and_ext(self, filename: str) -> tuple[str, str]:
    """Splits `filename` into (prefix, extension including the dot)."""
    ext_index = filename.rfind('.')
    if ext_index == -1:
      return filename, ''
    else:
      return filename[:ext_index], filename[ext_index:]
239
+
240
+
241
class BulkCheckpointer(Checkpointer):
  """Checkpointer that saves all examples to a single file.

  Keeps one `SequenceWriter` per leaf experiment (keyed by experiment ID)
  and appends every evaluated example of an experiment to the same
  checkpoint file. `self._lock` guards the writer table, which is accessed
  from multiple runner threads.
  """

  def _on_bound(self):
    super()._on_bound()
    self._lock = threading.Lock()
    # Experiment ID -> SequenceWriter; None until `on_run_start`.
    self._sequence_writer = None

  def on_run_start(
      self,
      runner: Runner,
      root: Experiment,
  ) -> None:
    """Initializes an empty writer table for the run."""
    self._sequence_writer = {}

  def on_run_abort(
      self,
      runner: Runner,
      root: Experiment,
      error: BaseException
  ) -> None:
    """Closes and drops all writers so partial checkpoints are flushed."""
    with self._lock:
      if self._sequence_writer is not None:
        for writer in self._sequence_writer.values():
          writer.close()
        self._sequence_writer.clear()

  def on_run_complete(
      self,
      runner: Runner,
      root: Experiment,
  ) -> None:
    """Sanity check: every writer was closed in `on_experiment_complete`."""
    with self._lock:
      assert self._sequence_writer is not None and not self._sequence_writer

  def on_experiment_start(
      self,
      runner: Runner,
      experiment: Experiment,
  ) -> None:
    """Loads prior state (via base class) and opens this experiment's writer."""
    super().on_experiment_start(runner, experiment)

    # Prepare the sequence writer for the experiment.
    if experiment.is_leaf:
      sequence_writer = SequenceWriter(
          runner.current_run.output_path_for(
              experiment, self.checkpoint_filename
          )
      )
      with self._lock:
        if self._sequence_writer is not None:
          self._sequence_writer[experiment.id] = sequence_writer

  def _list_checkpoint_filenames(
      self, runner: Runner, experiment: Experiment
  ) -> list[str]:
    """Returns the single bulk checkpoint file, if it exists."""
    if pg.io.path_exists(
        runner.current_run.input_path_for(experiment, self.checkpoint_filename)
    ):
      return [self.checkpoint_filename]
    return []

  def on_experiment_complete(
      self,
      runner: Runner,
      experiment: Experiment,
  ) -> None:
    """Closes the checkpoint file."""
    if not experiment.is_leaf:
      return
    assert experiment.id in self._sequence_writer
    with self._lock:
      if self._sequence_writer is not None:
        # Make sure the writer is closed without delay so the file will be
        # available immediately.
        writer = self._sequence_writer.pop(experiment.id)
        writer.close()
        experiment.info(
            f'{len(experiment.state.evaluated_examples)} examples are '
            f'checkpointed to {writer.path}.'
        )

  def _save_example(
      self,
      runner: Runner,
      experiment: Experiment,
      example: Example,
  ) -> None:
    """Saves the example to the checkpoint file."""
    assert experiment.id in self._sequence_writer
    def _save_example(example: Example):
      # Appends `example` to the experiment's shared checkpoint file.
      writer = self._sequence_writer[experiment.id]
      try:
        writer.add(example)
        experiment.info(
            f'Example {example.id} checkpointed to {writer.path}.',
        )
      except BaseException as e:  # pylint: disable=broad-except
        experiment.error(
            f'Failed to checkpoint example {example.id} to {writer.path}. '
            f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.',
        )
        raise e
    # Run in the background so evaluation is not blocked on file I/O.
    runner.background_run(_save_example, example)
345
+
346
+
347
class SequenceWriter:
  """Lock-guarded writer that appends serialized examples to a sequence file.

  `close` is idempotent and is also invoked on garbage collection, so the
  underlying file handle is always released.
  """

  def __init__(self, path: str):
    self._lock = threading.Lock()
    self._path = path
    self._seq = pg.io.open_sequence(path, 'a')

  @property
  def path(self) -> str:
    """Path of the underlying sequence file."""
    return self._path

  def add(self, example: Example):
    """Serializes `example` and appends it, unless the writer is closed."""
    # Serialize outside the lock so slow JSON conversion does not block
    # concurrent writers.
    blob = pg.to_json_str(
        example,
        hide_default_values=True,
        save_ref_value=True,
        exclude_input=True
    )
    with self._lock:
      if self._seq is not None:
        self._seq.add(blob)

  def close(self):
    """Closes the underlying writer; safe to call multiple times."""
    # Taking the lock also ensures no write is in progress while closing.
    with self._lock:
      writer, self._seq = self._seq, None
      if writer is not None:
        writer.close()

  def __del__(self):
    self.close()
@@ -0,0 +1,228 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import tempfile
16
+ import unittest
17
+
18
+ from langfun.core.eval.v2 import checkpointing
19
+ from langfun.core.eval.v2 import eval_test_helper
20
+ from langfun.core.eval.v2 import example as example_lib
21
+ from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import
22
+ import pyglove as pg
23
+
24
+ Example = example_lib.Example
25
+
26
+
27
class SequenceWriterTest(unittest.TestCase):
  """Tests for checkpointing.SequenceWriter."""

  def test_basic(self):
    # Deleting the writer triggers __del__, which closes the file.
    path = os.path.join(tempfile.gettempdir(), 'test.jsonl')
    seq_writer = checkpointing.SequenceWriter(path)
    seq_writer.add(Example(id=1, input=pg.Dict(x=1), output=2))
    del seq_writer
    self.assertTrue(pg.io.path_exists(path))

  def test_error_handling(self):
    path = os.path.join(tempfile.gettempdir(), 'test_error_handling.jsonl')
    seq_writer = checkpointing.SequenceWriter(path)
    seq_writer.add(Example(id=1, input=pg.Dict(x=1), output=2))

    def f():
      raise ValueError('Intentional error')

    # The failing second `add` must not corrupt the already-written record.
    try:
      seq_writer.add(f())
    except ValueError:
      del seq_writer

    self.assertTrue(pg.io.path_exists(path))
    with pg.io.open_sequence(path, 'r') as f:
      self.assertEqual(len(list(iter(f))), 1)
53
+
54
+
55
class CheckpointerTest(unittest.TestCase):
  """Base class providing shared assertions for checkpointer tests."""

  def assert_found_in_log(self, experiment, message):
    """Asserts that some log entry of `experiment` starts with `message`."""
    self.assertTrue(
        any(
            log_entry.message.startswith(message)
            for log_entry in experiment._log_entries
        )
    )
64
+
65
+
66
class PerExampleCheckpointerTest(CheckpointerTest):
  """End-to-end tests for checkpointing.PerExampleCheckpointer."""

  def test_checkpointing(self):
    """Covers fresh run, resume, and warm start with/without reprocess."""
    root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer')
    experiment = eval_test_helper.test_experiment()
    checkpoint_filename = 'checkpoint.jsonl'
    checkpointer = checkpointing.PerExampleCheckpointer(checkpoint_filename)
    run = experiment.run(
        root_dir, 'new', runner='sequential', plugins=[checkpointer]
    )
    num_processed = {}
    for leaf in experiment.leaf_nodes:
      for i in range(leaf.num_examples):
        example = leaf.state.get(i + 1)
        ckpt = run.output_path_for(leaf, f'checkpoint_{example.id}.jsonl')
        if example.has_error:
          # Errored examples must not leave a checkpoint behind.
          self.assertFalse(pg.io.path_exists(ckpt))
        else:
          # Each successful example gets exactly one single-record file.
          self.assertTrue(pg.io.path_exists(ckpt))
          with pg.io.open_sequence(ckpt) as f:
            self.assertEqual(len(list(iter(f))), 1)
      if leaf.id not in num_processed:
        self.assertEqual(leaf.progress.num_skipped, 0)
        num_processed[leaf.id] = leaf.progress.num_processed

    # Run again, should skip existing.
    _ = experiment.run(
        root_dir, 'latest', runner='sequential', plugins=[checkpointer]
    )
    for leaf in experiment.leaf_nodes:
      self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])

    # Test warm start without reprocess.
    root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer2')
    experiment = eval_test_helper.test_experiment()
    _ = experiment.run(
        root_dir, 'new', runner='sequential', plugins=[checkpointer],
        warm_start_from=run.output_root
    )
    for leaf in experiment.leaf_nodes:
      self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])

    # Test warm start with reprocess.
    root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer3')
    experiment = eval_test_helper.test_experiment()
    _ = experiment.run(
        root_dir, 'new', runner='sequential', plugins=[checkpointer],
        warm_start_from=run.output_root,
        reprocess=True
    )
    for leaf in experiment.leaf_nodes:
      # reprocess=True ignores all checkpoints, so nothing is skipped.
      self.assertEqual(leaf.progress.num_skipped, 0)

    root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer4')
    experiment = eval_test_helper.test_experiment()
    _ = experiment.run(
        root_dir, 'new', runner='sequential', plugins=[checkpointer],
        warm_start_from=run.output_root,
        reprocess=[1, 2, 3]
    )
    for leaf in experiment.leaf_nodes:
      # Reprocessing a list of 3 IDs reduces the skipped count by 3.
      self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id] - 3)

  def test_loading_corrupted_checkpoint(self):
    """A corrupted checkpoint is skipped (with a warning), not fatal."""
    root_dir = os.path.join(
        tempfile.gettempdir(),
        'per_example_checkpointer_with_corrupted_checkpoint'
    )
    experiment = eval_test_helper.TestEvaluation()
    checkpoint_filename = 'checkpoint.jsonl'
    checkpointer = checkpointing.PerExampleCheckpointer(checkpoint_filename)
    run = experiment.run(
        root_dir, 'new', runner='sequential', plugins=[checkpointer]
    )
    num_processed = {}
    for i in range(experiment.num_examples):
      example = experiment.state.get(i + 1)
      ckpt = run.output_path_for(experiment, f'checkpoint_{example.id}.jsonl')
      if not example.has_error:
        self.assertTrue(pg.io.path_exists(ckpt))
        with pg.io.open_sequence(ckpt) as f:
          self.assertEqual(len(list(iter(f))), 1)

        # Simulate corrupting the first checkpoint.
        if i == 0:
          pg.io.writefile(ckpt, 'bad file')
        num_processed[example.id] = i + 1

    root_dir = os.path.join(
        tempfile.gettempdir(),
        'per_example_checkpointer_with_corrupted_checkpoint_warm_start'
    )
    experiment = eval_test_helper.TestEvaluation()
    _ = experiment.run(
        root_dir, 'new', runner='sequential', plugins=[checkpointer],
        warm_start_from=run.output_root,
    )
    for leaf in experiment.leaf_nodes:
      # One checkpoint was corrupted, so one fewer example is skipped.
      self.assertEqual(leaf.progress.num_skipped, len(num_processed) - 1)
    self.assert_found_in_log(experiment, 'Failed to load checkpoint')

  def test_checkpointing_error(self):
    """Checkpointing failures are logged via `experiment.error`."""
    root_dir = os.path.join(
        tempfile.gettempdir(),
        'per_example_checkpointer_with_checkpointing_error'
    )
    experiment = (eval_test_helper
                  .test_experiment_with_example_checkpointing_error())
    checkpointer = checkpointing.PerExampleCheckpointer('checkpoint.jsonl')
    _ = experiment.run(
        root_dir, 'new', runner='parallel', plugins=[checkpointer]
    )
    self.assert_found_in_log(experiment, 'Failed to checkpoint')
179
+
180
+
181
class BulkCheckpointerTest(CheckpointerTest):
  """End-to-end tests for checkpointing.BulkCheckpointer."""

  def test_checkpointing(self):
    root_dir = os.path.join(tempfile.gettempdir(), 'test_bulk_checkpointer')
    experiment = eval_test_helper.test_experiment()
    checkpoint_filename = 'checkpoint.jsonl'
    checkpointer = checkpointing.BulkCheckpointer(checkpoint_filename)
    run = experiment.run(
        root_dir, 'new', runner='sequential', plugins=[checkpointer]
    )
    # Every writer must be closed once the run finishes.
    self.assertEqual(len(checkpointer._sequence_writer), 0)
    processed_counts = {}
    for node in experiment.leaf_nodes:
      ckpt = run.output_path_for(node, checkpoint_filename)
      self.assertTrue(pg.io.path_exists(ckpt))
      # The bulk file holds one record per successfully completed example.
      with pg.io.open_sequence(ckpt) as seq:
        expected = node.progress.num_completed - node.progress.num_failed
        self.assertEqual(len(list(iter(seq))), expected)
      if node.id not in processed_counts:
        self.assertEqual(node.progress.num_skipped, 0)
        processed_counts[node.id] = node.progress.num_processed

    # Run again, should skip existing.
    _ = experiment.run(
        root_dir, 'latest', runner='sequential', plugins=[checkpointer]
    )
    self.assertEqual(len(checkpointer._sequence_writer), 0)
    for node in experiment.leaf_nodes:
      self.assertEqual(node.progress.num_skipped, processed_counts[node.id])

  def test_checkpointing_error(self):
    root_dir = os.path.join(
        tempfile.gettempdir(),
        'bulk_checkpointer_with_checkpointing_error'
    )
    experiment = (
        eval_test_helper.test_experiment_with_example_checkpointing_error()
    )
    checkpointer = checkpointing.BulkCheckpointer('checkpoint.jsonl')
    _ = experiment.run(
        root_dir, 'new', runner='parallel', plugins=[checkpointer]
    )
    self.assert_found_in_log(experiment, 'Failed to checkpoint')
225
+
226
+
227
# Allows running this test file directly, e.g. `python checkpointing_test.py`.
if __name__ == '__main__':
  unittest.main()