langfun 0.1.2.dev202501010804__py3-none-any.whl → 0.1.2.dev202501060804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/__init__.py +0 -4
- langfun/core/eval/matching.py +2 -2
- langfun/core/eval/scoring.py +6 -2
- langfun/core/eval/v2/checkpointing.py +106 -72
- langfun/core/eval/v2/checkpointing_test.py +108 -3
- langfun/core/eval/v2/eval_test_helper.py +56 -0
- langfun/core/eval/v2/evaluation.py +25 -4
- langfun/core/eval/v2/evaluation_test.py +11 -0
- langfun/core/eval/v2/example.py +11 -1
- langfun/core/eval/v2/example_test.py +16 -2
- langfun/core/eval/v2/experiment.py +83 -19
- langfun/core/eval/v2/experiment_test.py +121 -3
- langfun/core/eval/v2/reporting.py +67 -20
- langfun/core/eval/v2/reporting_test.py +119 -2
- langfun/core/eval/v2/runners.py +7 -4
- langfun/core/llms/__init__.py +23 -24
- langfun/core/llms/anthropic.py +12 -0
- langfun/core/llms/cache/in_memory.py +6 -0
- langfun/core/llms/cache/in_memory_test.py +5 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +46 -310
- langfun/core/llms/google_genai_test.py +9 -204
- langfun/core/llms/openai.py +23 -37
- langfun/core/llms/vertexai.py +28 -348
- langfun/core/llms/vertexai_test.py +6 -166
- {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/METADATA +7 -13
- {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/RECORD +31 -31
- {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/WHEEL +1 -1
- langfun/core/repr_utils.py +0 -204
- langfun/core/repr_utils_test.py +0 -90
- {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/top_level.txt +0 -0
langfun/core/__init__.py
CHANGED
@@ -123,10 +123,6 @@ from langfun.core.memory import Memory
|
|
123
123
|
# Utility for console output.
|
124
124
|
from langfun.core import console
|
125
125
|
|
126
|
-
# Helpers for implementing _repr_xxx_ methods.
|
127
|
-
from langfun.core import repr_utils
|
128
|
-
Html = repr_utils.Html
|
129
|
-
|
130
126
|
# Utility for event logging.
|
131
127
|
from langfun.core import logging
|
132
128
|
|
langfun/core/eval/matching.py
CHANGED
@@ -251,14 +251,14 @@ class Matching(base.Evaluation):
|
|
251
251
|
for i, (_, example, output, message) in enumerate(self.matches):
|
252
252
|
bgcolor = 'white' if i % 2 == 0 else '#DDDDDD'
|
253
253
|
s.write(f'<tr style="background-color: {bgcolor}"><td>{i + 1}</td>')
|
254
|
-
input_str =
|
254
|
+
input_str = pg.Html.escape(
|
255
255
|
pg.format(
|
256
256
|
example, verbose=False, max_bytes_len=32,
|
257
257
|
custom_format=_maybe_html
|
258
258
|
)
|
259
259
|
)
|
260
260
|
s.write(f'<td style="color:green;white-space:pre-wrap">{input_str}</td>')
|
261
|
-
output_str =
|
261
|
+
output_str = pg.Html.escape(
|
262
262
|
pg.format(
|
263
263
|
output, verbose=False, max_bytes_len=32,
|
264
264
|
custom_format=_maybe_html
|
langfun/core/eval/scoring.py
CHANGED
@@ -194,9 +194,13 @@ class Scoring(base.Evaluation):
|
|
194
194
|
for i, (example, output, score, message) in enumerate(self.scored):
|
195
195
|
bgcolor = 'white' if i % 2 == 0 else '#DDDDDD'
|
196
196
|
s.write(f'<tr style="background-color: {bgcolor}"><td>{i + 1}</td>')
|
197
|
-
input_str = pg.
|
197
|
+
input_str = pg.Html.escape(
|
198
|
+
pg.format(example, verbose=False, max_bytes_len=32)
|
199
|
+
)
|
198
200
|
s.write(f'<td style="color:green;white-space:pre-wrap">{input_str}</td>')
|
199
|
-
output_str = pg.
|
201
|
+
output_str = pg.Html.escape(
|
202
|
+
pg.format(output, verbose=False, max_bytes_len=32)
|
203
|
+
)
|
200
204
|
s.write(f'<td style="color:blue;white-space:pre-wrap">{output_str}</td>')
|
201
205
|
s.write(f'<td style="color:magenta;white-space:pre-wrap">{score}</td>')
|
202
206
|
s.write('<td>')
|
@@ -13,8 +13,10 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Checkpointing evaluation runs."""
|
15
15
|
import abc
|
16
|
+
import re
|
16
17
|
import threading
|
17
18
|
import traceback
|
19
|
+
from typing import Annotated
|
18
20
|
|
19
21
|
import langfun.core as lf
|
20
22
|
from langfun.core.eval.v2 import example as example_lib
|
@@ -29,6 +31,11 @@ Runner = experiment_lib.Runner
|
|
29
31
|
class Checkpointer(experiment_lib.Plugin):
|
30
32
|
"""Base class for checkpointing evaluation examples."""
|
31
33
|
|
34
|
+
checkpoint_filename: Annotated[
|
35
|
+
str,
|
36
|
+
'Checkpoint file pattern.'
|
37
|
+
] = 'checkpoint.bagz'
|
38
|
+
|
32
39
|
def on_experiment_start(
|
33
40
|
self,
|
34
41
|
runner: Runner,
|
@@ -37,37 +44,35 @@ class Checkpointer(experiment_lib.Plugin):
|
|
37
44
|
if not experiment.is_leaf:
|
38
45
|
return
|
39
46
|
|
40
|
-
|
41
|
-
if not
|
42
|
-
if
|
47
|
+
current_run = runner.current_run
|
48
|
+
if current_run.reprocess is not True: # pylint: disable=g-bool-id-comparison
|
49
|
+
if current_run.input_root != current_run.output_root:
|
43
50
|
experiment.info(
|
44
|
-
f'Warm starting from directory: {
|
51
|
+
f'Warm starting from directory: {current_run.input_root}.'
|
45
52
|
)
|
46
53
|
self._load_experiment(runner, experiment)
|
47
54
|
|
55
|
+
example_ids_to_evaluate = current_run.examples_to_evaluate(experiment)
|
48
56
|
if experiment.state.evaluated_examples:
|
49
57
|
loaded_example_ids = list(
|
50
58
|
sorted(experiment.state.evaluated_examples.keys())
|
51
59
|
)
|
52
|
-
example_ids_to_evaluate = (
|
53
|
-
set(runner.current_run.example_ids) if runner.current_run.example_ids
|
54
|
-
else set(range(1, experiment.num_examples + 1))
|
55
|
-
)
|
56
60
|
example_ids_to_evaluate -= set(loaded_example_ids)
|
57
|
-
|
61
|
+
example_ids_to_evaluate = list(sorted(example_ids_to_evaluate))
|
58
62
|
experiment.info(
|
59
|
-
f'{len(experiment.state.evaluated_examples)} examples
|
63
|
+
f'{len(experiment.state.evaluated_examples)} examples '
|
60
64
|
'loaded from checkpoint files. Their outputs will be used '
|
61
|
-
f'for recomputing metrics. Example IDs: {loaded_example_ids}'
|
65
|
+
f'for recomputing metrics. Example IDs: {loaded_example_ids}.'
|
62
66
|
)
|
63
67
|
experiment.info(
|
64
68
|
f'{len(example_ids_to_evaluate)} examples will be processed from '
|
65
|
-
f'scratch. Example IDs: {
|
69
|
+
f'scratch. Example IDs: {example_ids_to_evaluate}.'
|
66
70
|
)
|
67
71
|
else:
|
68
72
|
experiment.info(
|
69
73
|
'No examples are loaded from checkpoint files. '
|
70
|
-
f'
|
74
|
+
f'{len(example_ids_to_evaluate)} examples will be processed from '
|
75
|
+
f'scratch. Example IDs: {example_ids_to_evaluate}.'
|
71
76
|
)
|
72
77
|
|
73
78
|
def on_example_complete(
|
@@ -81,60 +86,36 @@ class Checkpointer(experiment_lib.Plugin):
|
|
81
86
|
experiment.warning(
|
82
87
|
f'Example {example.id} has error. Skipping checkpointing.'
|
83
88
|
)
|
84
|
-
|
89
|
+
elif example.newly_processed:
|
85
90
|
self._save_example(runner, experiment, example)
|
86
91
|
|
87
|
-
@abc.abstractmethod
|
88
|
-
def _load_experiment(self, runner: Runner, experiment: Experiment) -> None:
|
89
|
-
"""Loads the experiment state from checkpoint files."""
|
90
|
-
|
91
|
-
@abc.abstractmethod
|
92
|
-
def _save_example(
|
93
|
-
self,
|
94
|
-
runner: Runner,
|
95
|
-
experiment: Experiment,
|
96
|
-
example: Example,
|
97
|
-
) -> None:
|
98
|
-
"""Saves an evaluated example."""
|
99
|
-
|
100
|
-
|
101
|
-
class PerExampleCheckpointer(Checkpointer):
|
102
|
-
"""Checkpointer that saves each example to a separate file."""
|
103
|
-
|
104
|
-
checkpoint_filename: str = 'checkpoint.bagz'
|
105
|
-
|
106
|
-
def _on_bound(self):
|
107
|
-
super()._on_bound()
|
108
|
-
prefix, ext = self._file_prefix_and_ext(self.checkpoint_filename)
|
109
|
-
self._checkpoint_file_prefix = prefix
|
110
|
-
self._checkpoint_file_ext = ext
|
111
|
-
|
112
92
|
def _load_experiment(
|
113
93
|
self,
|
114
94
|
runner: Runner,
|
115
95
|
experiment: Experiment,
|
116
96
|
) -> None:
|
117
97
|
"""Creates the checkpoint file."""
|
118
|
-
|
119
|
-
if pg.io.path_exists(experiment_dir):
|
120
|
-
ckpt_files = [
|
121
|
-
runner.current_run.input_path_for(experiment, filename)
|
122
|
-
for filename in pg.io.listdir(experiment_dir)
|
123
|
-
if filename.startswith(self._checkpoint_file_prefix)
|
124
|
-
and filename.endswith(self._checkpoint_file_ext)
|
125
|
-
]
|
126
|
-
else:
|
127
|
-
ckpt_files = []
|
128
|
-
|
98
|
+
ckpt_files = self._list_checkpoint_filenames(runner, experiment)
|
129
99
|
experiment.info(f'Found {len(ckpt_files)} checkpoint files to load.')
|
130
100
|
|
131
101
|
# Load the checkpoint files in parallel.
|
102
|
+
current_run = runner.current_run
|
103
|
+
examples_to_load = current_run.examples_to_load(experiment)
|
104
|
+
examples_to_load_metadata = current_run.examples_to_load_metadata(
|
105
|
+
experiment
|
106
|
+
)
|
132
107
|
context = dict(counter=0, counter_lock=threading.Lock())
|
108
|
+
copy_ckpt = current_run.input_root != current_run.output_root
|
109
|
+
|
133
110
|
def _load_state(ckpt_file):
|
134
111
|
error = None
|
135
112
|
with pg.timeit() as t:
|
136
113
|
try:
|
137
|
-
experiment.load_state(
|
114
|
+
experiment.load_state(
|
115
|
+
current_run.input_path_for(experiment, ckpt_file),
|
116
|
+
filter=lambda x: x.id in examples_to_load,
|
117
|
+
load_example_metadata=lambda x: x.id in examples_to_load_metadata,
|
118
|
+
)
|
138
119
|
except BaseException as e: # pylint: disable=broad-except
|
139
120
|
error = e
|
140
121
|
finally:
|
@@ -144,21 +125,80 @@ class PerExampleCheckpointer(Checkpointer):
|
|
144
125
|
progress_str = f'{context["counter"]}/{len(ckpt_files)}'
|
145
126
|
if error is None:
|
146
127
|
experiment.info(
|
147
|
-
f'
|
128
|
+
f'Checkpoint file {ckpt_file!r} loaded in {t.elapse:.2f} '
|
148
129
|
f'seconds. ({progress_str})'
|
149
130
|
)
|
150
131
|
else:
|
151
132
|
experiment.warning(
|
152
|
-
f'Failed to load checkpoint file {ckpt_file}: {error}. '
|
133
|
+
f'Failed to load checkpoint file {ckpt_file!r}: {error}. '
|
153
134
|
f'Skipping the file. ({progress_str})'
|
154
135
|
)
|
155
136
|
|
137
|
+
if not copy_ckpt:
|
138
|
+
return
|
139
|
+
|
140
|
+
# Copy the checkpoint records to the output directory.
|
141
|
+
try:
|
142
|
+
with pg.io.open_sequence(
|
143
|
+
current_run.output_path_for(experiment, ckpt_file), 'w'
|
144
|
+
) as o, pg.io.open_sequence(
|
145
|
+
current_run.input_path_for(experiment, ckpt_file), 'r'
|
146
|
+
) as i:
|
147
|
+
for x in i:
|
148
|
+
o.add(x)
|
149
|
+
except BaseException as e: # pylint: disable=broad-except
|
150
|
+
experiment.warning(
|
151
|
+
f'Failed to copy checkpoint {ckpt_file!r}: {e}.'
|
152
|
+
)
|
153
|
+
|
156
154
|
_ = list(
|
157
155
|
lf.concurrent_map(
|
158
156
|
_load_state, ckpt_files, max_workers=16, silence_on_errors=None
|
159
157
|
)
|
160
158
|
)
|
161
159
|
|
160
|
+
@abc.abstractmethod
|
161
|
+
def _list_checkpoint_filenames(
|
162
|
+
self, runner: Runner, experiment: Experiment
|
163
|
+
) -> list[str]:
|
164
|
+
"""Lists the checkpoint filenames to restore."""
|
165
|
+
|
166
|
+
@abc.abstractmethod
|
167
|
+
def _save_example(
|
168
|
+
self,
|
169
|
+
runner: Runner,
|
170
|
+
experiment: Experiment,
|
171
|
+
example: Example,
|
172
|
+
) -> None:
|
173
|
+
"""Saves an evaluated example."""
|
174
|
+
|
175
|
+
|
176
|
+
class PerExampleCheckpointer(Checkpointer):
|
177
|
+
"""Checkpointer that saves each example to a separate file."""
|
178
|
+
|
179
|
+
def _on_bound(self):
|
180
|
+
super()._on_bound()
|
181
|
+
prefix, ext = self._file_prefix_and_ext(self.checkpoint_filename)
|
182
|
+
self._checkpoint_file_prefix = prefix
|
183
|
+
self._checkpoint_file_ext = ext
|
184
|
+
|
185
|
+
def _list_checkpoint_filenames(
|
186
|
+
self, runner: Runner, experiment: Experiment
|
187
|
+
) -> list[str]:
|
188
|
+
experiment_dir = runner.current_run.input_dir(experiment)
|
189
|
+
filenames = []
|
190
|
+
examples_to_load = runner.current_run.examples_to_load(experiment)
|
191
|
+
if pg.io.path_exists(experiment_dir):
|
192
|
+
regex = re.compile(
|
193
|
+
f'{self._checkpoint_file_prefix}_(\\d+){self._checkpoint_file_ext}'
|
194
|
+
.replace('.', '\\.')
|
195
|
+
)
|
196
|
+
for filename in pg.io.listdir(experiment_dir):
|
197
|
+
match = regex.match(filename)
|
198
|
+
if match and int(match.group(1)) in examples_to_load:
|
199
|
+
filenames.append(filename)
|
200
|
+
return filenames
|
201
|
+
|
162
202
|
def _save_example(
|
163
203
|
self,
|
164
204
|
runner: Runner,
|
@@ -180,11 +220,11 @@ class PerExampleCheckpointer(Checkpointer):
|
|
180
220
|
writer.add(example)
|
181
221
|
writer.close()
|
182
222
|
experiment.info(
|
183
|
-
f'Example {example.id}
|
223
|
+
f'Example {example.id} checkpointed to {writer.path}.',
|
184
224
|
)
|
185
225
|
except BaseException as e: # pylint: disable=broad-except
|
186
226
|
experiment.error(
|
187
|
-
f'Failed to
|
227
|
+
f'Failed to checkpoint example {example.id} to {writer.path}. '
|
188
228
|
f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.',
|
189
229
|
)
|
190
230
|
raise e
|
@@ -201,8 +241,6 @@ class PerExampleCheckpointer(Checkpointer):
|
|
201
241
|
class BulkCheckpointer(Checkpointer):
|
202
242
|
"""Checkpointer that saves all examples to a single file."""
|
203
243
|
|
204
|
-
checkpoint_filename: str = 'checkpoint.bagz'
|
205
|
-
|
206
244
|
def _on_bound(self):
|
207
245
|
super()._on_bound()
|
208
246
|
self._lock = threading.Lock()
|
@@ -253,18 +291,14 @@ class BulkCheckpointer(Checkpointer):
|
|
253
291
|
if self._sequence_writer is not None:
|
254
292
|
self._sequence_writer[experiment.id] = sequence_writer
|
255
293
|
|
256
|
-
def
|
257
|
-
self,
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
experiment, self.checkpoint_filename
|
265
|
-
),
|
266
|
-
raise_if_not_exist=False
|
267
|
-
)
|
294
|
+
def _list_checkpoint_filenames(
|
295
|
+
self, runner: Runner, experiment: Experiment
|
296
|
+
) -> list[str]:
|
297
|
+
if pg.io.path_exists(
|
298
|
+
runner.current_run.input_path_for(experiment, self.checkpoint_filename)
|
299
|
+
):
|
300
|
+
return [self.checkpoint_filename]
|
301
|
+
return []
|
268
302
|
|
269
303
|
def on_experiment_complete(
|
270
304
|
self,
|
@@ -299,11 +333,11 @@ class BulkCheckpointer(Checkpointer):
|
|
299
333
|
try:
|
300
334
|
writer.add(example)
|
301
335
|
experiment.info(
|
302
|
-
f'Example {example.id}
|
336
|
+
f'Example {example.id} checkpointed to {writer.path}.',
|
303
337
|
)
|
304
338
|
except BaseException as e: # pylint: disable=broad-except
|
305
339
|
experiment.error(
|
306
|
-
f'Failed to
|
340
|
+
f'Failed to checkpoint example {example.id} to {writer.path}. '
|
307
341
|
f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.',
|
308
342
|
)
|
309
343
|
raise e
|
@@ -316,7 +350,7 @@ class SequenceWriter:
|
|
316
350
|
def __init__(self, path: str):
|
317
351
|
self._lock = threading.Lock()
|
318
352
|
self._path = path
|
319
|
-
self._sequence_writer = pg.io.open_sequence(path, '
|
353
|
+
self._sequence_writer = pg.io.open_sequence(path, 'a')
|
320
354
|
|
321
355
|
@property
|
322
356
|
def path(self) -> str:
|
@@ -52,10 +52,20 @@ class SequenceWriterTest(unittest.TestCase):
|
|
52
52
|
self.assertEqual(len(list(iter(f))), 1)
|
53
53
|
|
54
54
|
|
55
|
-
class
|
55
|
+
class CheckpointerTest(unittest.TestCase):
|
56
|
+
|
57
|
+
def assert_found_in_log(self, experiment, message):
|
58
|
+
found_error_log = False
|
59
|
+
for log_entry in experiment._log_entries:
|
60
|
+
if log_entry.message.startswith(message):
|
61
|
+
found_error_log = True
|
62
|
+
break
|
63
|
+
self.assertTrue(found_error_log)
|
64
|
+
|
65
|
+
|
66
|
+
class PerExampleCheckpointerTest(CheckpointerTest):
|
56
67
|
|
57
68
|
def test_checkpointing(self):
|
58
|
-
pg.defaults.loggers.use_stdout()
|
59
69
|
root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer')
|
60
70
|
experiment = eval_test_helper.test_experiment()
|
61
71
|
checkpoint_filename = 'checkpoint.jsonl'
|
@@ -85,8 +95,90 @@ class PerExampleCheckpointerTest(unittest.TestCase):
|
|
85
95
|
for leaf in experiment.leaf_nodes:
|
86
96
|
self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])
|
87
97
|
|
98
|
+
# Test warm start without reprocess.
|
99
|
+
root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer2')
|
100
|
+
experiment = eval_test_helper.test_experiment()
|
101
|
+
_ = experiment.run(
|
102
|
+
root_dir, 'new', runner='sequential', plugins=[checkpointer],
|
103
|
+
warm_start_from=run.output_root
|
104
|
+
)
|
105
|
+
for leaf in experiment.leaf_nodes:
|
106
|
+
self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])
|
107
|
+
|
108
|
+
# Test warm start with reprocess.
|
109
|
+
root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer3')
|
110
|
+
experiment = eval_test_helper.test_experiment()
|
111
|
+
_ = experiment.run(
|
112
|
+
root_dir, 'new', runner='sequential', plugins=[checkpointer],
|
113
|
+
warm_start_from=run.output_root,
|
114
|
+
reprocess=True
|
115
|
+
)
|
116
|
+
for leaf in experiment.leaf_nodes:
|
117
|
+
self.assertEqual(leaf.progress.num_skipped, 0)
|
88
118
|
|
89
|
-
|
119
|
+
root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer4')
|
120
|
+
experiment = eval_test_helper.test_experiment()
|
121
|
+
_ = experiment.run(
|
122
|
+
root_dir, 'new', runner='sequential', plugins=[checkpointer],
|
123
|
+
warm_start_from=run.output_root,
|
124
|
+
reprocess=[1, 2, 3]
|
125
|
+
)
|
126
|
+
for leaf in experiment.leaf_nodes:
|
127
|
+
self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id] - 3)
|
128
|
+
|
129
|
+
def test_loading_corrupted_checkpoint(self):
|
130
|
+
root_dir = os.path.join(
|
131
|
+
tempfile.gettempdir(),
|
132
|
+
'per_example_checkpointer_with_corrupted_checkpoint'
|
133
|
+
)
|
134
|
+
experiment = eval_test_helper.TestEvaluation()
|
135
|
+
checkpoint_filename = 'checkpoint.jsonl'
|
136
|
+
checkpointer = checkpointing.PerExampleCheckpointer(checkpoint_filename)
|
137
|
+
run = experiment.run(
|
138
|
+
root_dir, 'new', runner='sequential', plugins=[checkpointer]
|
139
|
+
)
|
140
|
+
num_processed = {}
|
141
|
+
for i in range(experiment.num_examples):
|
142
|
+
example = experiment.state.get(i + 1)
|
143
|
+
ckpt = run.output_path_for(experiment, f'checkpoint_{example.id}.jsonl')
|
144
|
+
if not example.has_error:
|
145
|
+
self.assertTrue(pg.io.path_exists(ckpt))
|
146
|
+
with pg.io.open_sequence(ckpt) as f:
|
147
|
+
self.assertEqual(len(list(iter(f))), 1)
|
148
|
+
|
149
|
+
# Simulate corrupting the first checkpoint.
|
150
|
+
if i == 0:
|
151
|
+
pg.io.writefile(ckpt, 'bad file')
|
152
|
+
num_processed[example.id] = i + 1
|
153
|
+
|
154
|
+
root_dir = os.path.join(
|
155
|
+
tempfile.gettempdir(),
|
156
|
+
'per_example_checkpointer_with_corrupted_checkpoint_warm_start'
|
157
|
+
)
|
158
|
+
experiment = eval_test_helper.TestEvaluation()
|
159
|
+
_ = experiment.run(
|
160
|
+
root_dir, 'new', runner='sequential', plugins=[checkpointer],
|
161
|
+
warm_start_from=run.output_root,
|
162
|
+
)
|
163
|
+
for leaf in experiment.leaf_nodes:
|
164
|
+
self.assertEqual(leaf.progress.num_skipped, len(num_processed) - 1)
|
165
|
+
self.assert_found_in_log(experiment, 'Failed to load checkpoint')
|
166
|
+
|
167
|
+
def test_checkpointing_error(self):
|
168
|
+
root_dir = os.path.join(
|
169
|
+
tempfile.gettempdir(),
|
170
|
+
'per_example_checkpointer_with_checkpointing_error'
|
171
|
+
)
|
172
|
+
experiment = (eval_test_helper
|
173
|
+
.test_experiment_with_example_checkpointing_error())
|
174
|
+
checkpointer = checkpointing.PerExampleCheckpointer('checkpoint.jsonl')
|
175
|
+
_ = experiment.run(
|
176
|
+
root_dir, 'new', runner='parallel', plugins=[checkpointer]
|
177
|
+
)
|
178
|
+
self.assert_found_in_log(experiment, 'Failed to checkpoint')
|
179
|
+
|
180
|
+
|
181
|
+
class BulkCheckpointerTest(CheckpointerTest):
|
90
182
|
|
91
183
|
def test_checkpointing(self):
|
92
184
|
root_dir = os.path.join(tempfile.gettempdir(), 'test_bulk_checkpointer')
|
@@ -118,6 +210,19 @@ class BulkCheckpointerTest(unittest.TestCase):
|
|
118
210
|
for leaf in experiment.leaf_nodes:
|
119
211
|
self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])
|
120
212
|
|
213
|
+
def test_checkpointing_error(self):
|
214
|
+
root_dir = os.path.join(
|
215
|
+
tempfile.gettempdir(),
|
216
|
+
'bulk_checkpointer_with_checkpointing_error'
|
217
|
+
)
|
218
|
+
experiment = (eval_test_helper
|
219
|
+
.test_experiment_with_example_checkpointing_error())
|
220
|
+
checkpointer = checkpointing.BulkCheckpointer('checkpoint.jsonl')
|
221
|
+
_ = experiment.run(
|
222
|
+
root_dir, 'new', runner='parallel', plugins=[checkpointer]
|
223
|
+
)
|
224
|
+
self.assert_found_in_log(experiment, 'Failed to checkpoint')
|
225
|
+
|
121
226
|
|
122
227
|
if __name__ == '__main__':
|
123
228
|
unittest.main()
|
@@ -72,9 +72,65 @@ class TestEvaluation(Evaluation):
|
|
72
72
|
)
|
73
73
|
|
74
74
|
|
75
|
+
class BadJsonConvertible(pg.Object):
|
76
|
+
|
77
|
+
def to_json(self, *args, **kwargs):
|
78
|
+
raise ValueError('Cannot convert to JSON.')
|
79
|
+
|
80
|
+
|
81
|
+
class TestEvaluationWithExampleCheckpointingError(TestEvaluation):
|
82
|
+
"""Test evaluation class with bad example checkpointing."""
|
83
|
+
inputs = test_inputs()
|
84
|
+
metrics = [metrics_lib.Match()]
|
85
|
+
|
86
|
+
def process(self, v):
|
87
|
+
return 1, dict(
|
88
|
+
x=BadJsonConvertible()
|
89
|
+
)
|
90
|
+
|
91
|
+
|
92
|
+
class BadHtmlConvertible(pg.Object, pg.views.HtmlTreeView.Extension):
|
93
|
+
|
94
|
+
def _html_tree_view(self, *args, **kwargs):
|
95
|
+
raise ValueError('Cannot render HTML.')
|
96
|
+
|
97
|
+
|
98
|
+
class TestEvaluationWithExampleHtmlGenerationError(Evaluation):
|
99
|
+
"""Test evaluation class with bad example HTML generation."""
|
100
|
+
inputs = test_inputs()
|
101
|
+
metrics = [metrics_lib.Match()]
|
102
|
+
|
103
|
+
def process(self, v):
|
104
|
+
return 1, dict(
|
105
|
+
x=BadHtmlConvertible()
|
106
|
+
)
|
107
|
+
|
108
|
+
|
109
|
+
class TestEvaluationWithIndexHtmlGenerationError(TestEvaluation):
|
110
|
+
"""Test evaluation class with bad index HTML generation."""
|
111
|
+
|
112
|
+
def _html_tree_view(self, *args, **kwargs):
|
113
|
+
raise ValueError('Cannot render HTML.')
|
114
|
+
|
115
|
+
|
75
116
|
def test_experiment():
|
76
117
|
"""Returns a test experiment."""
|
77
118
|
return Suite([
|
78
119
|
TestEvaluation(lm=TestLLM(offset=0)),
|
79
120
|
TestEvaluation(lm=TestLLM(offset=pg.oneof(range(5)))),
|
80
121
|
])
|
122
|
+
|
123
|
+
|
124
|
+
def test_experiment_with_example_checkpointing_error():
|
125
|
+
"""Returns a test experiment with example checkpointing error."""
|
126
|
+
return TestEvaluationWithExampleCheckpointingError()
|
127
|
+
|
128
|
+
|
129
|
+
def test_experiment_with_example_html_generation_error():
|
130
|
+
"""Returns a test experiment with bad example HTML."""
|
131
|
+
return TestEvaluationWithExampleHtmlGenerationError()
|
132
|
+
|
133
|
+
|
134
|
+
def test_experiment_with_index_html_generation_error():
|
135
|
+
"""Returns a test experiment with bad index HTML."""
|
136
|
+
return TestEvaluationWithIndexHtmlGenerationError()
|
@@ -264,11 +264,21 @@ class Evaluation(experiment_lib.Experiment):
|
|
264
264
|
return self._state
|
265
265
|
|
266
266
|
def load_state(
|
267
|
-
self,
|
267
|
+
self,
|
268
|
+
state_file: str,
|
269
|
+
*,
|
270
|
+
load_example_metadata: bool = True,
|
271
|
+
filter: Callable[[example_lib.Example], bool] | None = None, # pylint: disable=redefined-builtin
|
272
|
+
raise_if_not_exist: bool = False
|
268
273
|
) -> None:
|
269
274
|
"""Loads saved state from a sequence IO file."""
|
270
275
|
if pg.io.path_exists(state_file):
|
271
|
-
self._state.load(
|
276
|
+
self._state.load(
|
277
|
+
state_file,
|
278
|
+
example_input_by_id=self.example_input_by_id,
|
279
|
+
load_example_metadata=load_example_metadata,
|
280
|
+
filter=filter,
|
281
|
+
)
|
272
282
|
elif raise_if_not_exist:
|
273
283
|
raise ValueError(f'State file {state_file} does not exist.')
|
274
284
|
|
@@ -680,14 +690,25 @@ class EvaluationState:
|
|
680
690
|
self._evaluated_examples: dict[int, example_lib.Example] = {}
|
681
691
|
|
682
692
|
def load(
|
683
|
-
self,
|
693
|
+
self,
|
694
|
+
state_file: str,
|
695
|
+
*,
|
696
|
+
example_input_by_id: Callable[[int], Any] | None = None,
|
697
|
+
load_example_metadata: bool | Callable[
|
698
|
+
[example_lib.Example], bool] = True,
|
699
|
+
filter: Callable[[example_lib.Example], bool] | None = None, # pylint: disable=redefined-builtin
|
700
|
+
) -> None:
|
684
701
|
"""Loads the state from the example sequence file."""
|
685
702
|
with pg.io.sequence.open_sequence(state_file) as f:
|
686
703
|
for record in f:
|
687
704
|
example = pg.from_json_str(
|
688
|
-
record,
|
705
|
+
record,
|
706
|
+
example_input_by_id=example_input_by_id,
|
707
|
+
load_example_metadata=load_example_metadata
|
689
708
|
)
|
690
709
|
assert isinstance(example, example_lib.Example), example
|
710
|
+
if filter is not None and not filter(example):
|
711
|
+
continue
|
691
712
|
self._evaluated_examples[example.id] = example
|
692
713
|
|
693
714
|
@property
|
@@ -138,6 +138,17 @@ class EvaluationTest(unittest.TestCase):
|
|
138
138
|
self.assertEqual(example.usage_summary.uncached.total.total_tokens, 0)
|
139
139
|
self.assertEqual(example.usage_summary.uncached.total.num_requests, 0)
|
140
140
|
|
141
|
+
# Test load_state with filter.
|
142
|
+
exp.reset()
|
143
|
+
self.assertEqual(len(exp._state.evaluated_examples), 0)
|
144
|
+
exp.load_state(state_file, filter=lambda x: x.id == 3)
|
145
|
+
self.assertEqual(len(exp._state.evaluated_examples), 1)
|
146
|
+
|
147
|
+
exp.reset()
|
148
|
+
self.assertEqual(len(exp._state.evaluated_examples), 0)
|
149
|
+
exp.load_state(state_file, filter=lambda x: x.id == 1)
|
150
|
+
self.assertEqual(len(exp._state.evaluated_examples), 0)
|
151
|
+
|
141
152
|
def test_html_view(self):
|
142
153
|
exp = eval_test_helper.TestEvaluation()
|
143
154
|
exp.debug('debug message')
|
langfun/core/eval/v2/example.py
CHANGED
@@ -101,6 +101,7 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
|
|
101
101
|
json_value: dict[str, Any],
|
102
102
|
*,
|
103
103
|
example_input_by_id: Callable[[int], Any] | None = None,
|
104
|
+
load_example_metadata: bool | Callable[['Example'], bool] = False,
|
104
105
|
**kwargs
|
105
106
|
) -> 'Example':
|
106
107
|
"""Creates an example from the JSON representation."""
|
@@ -128,12 +129,21 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
|
|
128
129
|
pg.traverse(example, _visit)
|
129
130
|
return list(referred_types)
|
130
131
|
|
132
|
+
# We delay loading the metadata until the other parts of the example are
|
133
|
+
# loaded. So we could apply the filter to decide whether to load the
|
134
|
+
# metadata.
|
135
|
+
metadata_dict = json_value.pop('metadata', None)
|
131
136
|
with pg.JSONConvertible.load_types_for_deserialization(
|
132
137
|
*example_class_defs(example_input)
|
133
138
|
):
|
134
|
-
|
139
|
+
example = cls(
|
135
140
|
**{k: pg.from_json(v, **kwargs) for k, v in json_value.items()}
|
136
141
|
)
|
142
|
+
if callable(load_example_metadata):
|
143
|
+
load_example_metadata = load_example_metadata(example)
|
144
|
+
if load_example_metadata:
|
145
|
+
example.metadata = pg.from_json(metadata_dict, **kwargs)
|
146
|
+
return example
|
137
147
|
|
138
148
|
#
|
139
149
|
# HTML rendering.
|
@@ -70,17 +70,31 @@ class ExampleTest(unittest.TestCase):
|
|
70
70
|
self.assertEqual(
|
71
71
|
pg.from_json_str(
|
72
72
|
json_str,
|
73
|
-
example_input_by_id=lambda i: inputs[i - 1]
|
73
|
+
example_input_by_id=lambda i: inputs[i - 1],
|
74
|
+
load_example_metadata=True,
|
74
75
|
),
|
75
76
|
ex
|
76
77
|
)
|
78
|
+
self.assertEqual(
|
79
|
+
pg.from_json_str(
|
80
|
+
json_str,
|
81
|
+
example_input_by_id=lambda i: inputs[i - 1],
|
82
|
+
load_example_metadata=False,
|
83
|
+
),
|
84
|
+
Example(
|
85
|
+
id=1,
|
86
|
+
input=inputs[0],
|
87
|
+
output=inputs[0].a(1),
|
88
|
+
metadata={}
|
89
|
+
)
|
90
|
+
)
|
77
91
|
pg.JSONConvertible._TYPE_REGISTRY._type_to_cls_map.pop(
|
78
92
|
inputs[0].a.__type_name__
|
79
93
|
)
|
80
94
|
pg.JSONConvertible._TYPE_REGISTRY._type_to_cls_map.pop(
|
81
95
|
inputs[0].b.__type_name__
|
82
96
|
)
|
83
|
-
v = pg.from_json_str(json_str, auto_dict=True)
|
97
|
+
v = pg.from_json_str(json_str, auto_dict=True, load_example_metadata=True)
|
84
98
|
v.output.pop('type_name')
|
85
99
|
v.metadata.b.pop('type_name')
|
86
100
|
self.assertEqual(
|