langfun 0.1.2.dev202501010804__py3-none-any.whl → 0.1.2.dev202501030804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langfun/core/__init__.py CHANGED
@@ -123,10 +123,6 @@ from langfun.core.memory import Memory
123
123
  # Utility for console output.
124
124
  from langfun.core import console
125
125
 
126
- # Helpers for implementing _repr_xxx_ methods.
127
- from langfun.core import repr_utils
128
- Html = repr_utils.Html
129
-
130
126
  # Utility for event logging.
131
127
  from langfun.core import logging
132
128
 
@@ -251,14 +251,14 @@ class Matching(base.Evaluation):
251
251
  for i, (_, example, output, message) in enumerate(self.matches):
252
252
  bgcolor = 'white' if i % 2 == 0 else '#DDDDDD'
253
253
  s.write(f'<tr style="background-color: {bgcolor}"><td>{i + 1}</td>')
254
- input_str = lf.repr_utils.escape_quoted(
254
+ input_str = pg.Html.escape(
255
255
  pg.format(
256
256
  example, verbose=False, max_bytes_len=32,
257
257
  custom_format=_maybe_html
258
258
  )
259
259
  )
260
260
  s.write(f'<td style="color:green;white-space:pre-wrap">{input_str}</td>')
261
- output_str = lf.repr_utils.escape_quoted(
261
+ output_str = pg.Html.escape(
262
262
  pg.format(
263
263
  output, verbose=False, max_bytes_len=32,
264
264
  custom_format=_maybe_html
@@ -194,9 +194,13 @@ class Scoring(base.Evaluation):
194
194
  for i, (example, output, score, message) in enumerate(self.scored):
195
195
  bgcolor = 'white' if i % 2 == 0 else '#DDDDDD'
196
196
  s.write(f'<tr style="background-color: {bgcolor}"><td>{i + 1}</td>')
197
- input_str = pg.format(example, verbose=False, max_bytes_len=32)
197
+ input_str = pg.Html.escape(
198
+ pg.format(example, verbose=False, max_bytes_len=32)
199
+ )
198
200
  s.write(f'<td style="color:green;white-space:pre-wrap">{input_str}</td>')
199
- output_str = pg.format(output, verbose=False, max_bytes_len=32)
201
+ output_str = pg.Html.escape(
202
+ pg.format(output, verbose=False, max_bytes_len=32)
203
+ )
200
204
  s.write(f'<td style="color:blue;white-space:pre-wrap">{output_str}</td>')
201
205
  s.write(f'<td style="color:magenta;white-space:pre-wrap">{score}</td>')
202
206
  s.write('<td>')
@@ -13,8 +13,10 @@
13
13
  # limitations under the License.
14
14
  """Checkpointing evaluation runs."""
15
15
  import abc
16
+ import re
16
17
  import threading
17
18
  import traceback
19
+ from typing import Annotated
18
20
 
19
21
  import langfun.core as lf
20
22
  from langfun.core.eval.v2 import example as example_lib
@@ -29,6 +31,11 @@ Runner = experiment_lib.Runner
29
31
  class Checkpointer(experiment_lib.Plugin):
30
32
  """Base class for checkpointing evaluation examples."""
31
33
 
34
+ checkpoint_filename: Annotated[
35
+ str,
36
+ 'Checkpoint file pattern.'
37
+ ] = 'checkpoint.bagz'
38
+
32
39
  def on_experiment_start(
33
40
  self,
34
41
  runner: Runner,
@@ -37,37 +44,35 @@ class Checkpointer(experiment_lib.Plugin):
37
44
  if not experiment.is_leaf:
38
45
  return
39
46
 
40
- # For refresh runs, we don't want to load the previous state.
41
- if not runner.current_run.refresh:
42
- if runner.current_run.input_root != runner.current_run.output_root:
47
+ current_run = runner.current_run
48
+ if current_run.reprocess is not True: # pylint: disable=g-bool-id-comparison
49
+ if current_run.input_root != current_run.output_root:
43
50
  experiment.info(
44
- f'Warm starting from directory: {runner.current_run.input_root}.'
51
+ f'Warm starting from directory: {current_run.input_root}.'
45
52
  )
46
53
  self._load_experiment(runner, experiment)
47
54
 
55
+ example_ids_to_evaluate = current_run.examples_to_evaluate(experiment)
48
56
  if experiment.state.evaluated_examples:
49
57
  loaded_example_ids = list(
50
58
  sorted(experiment.state.evaluated_examples.keys())
51
59
  )
52
- example_ids_to_evaluate = (
53
- set(runner.current_run.example_ids) if runner.current_run.example_ids
54
- else set(range(1, experiment.num_examples + 1))
55
- )
56
60
  example_ids_to_evaluate -= set(loaded_example_ids)
57
-
61
+ example_ids_to_evaluate = list(sorted(example_ids_to_evaluate))
58
62
  experiment.info(
59
- f'{len(experiment.state.evaluated_examples)} examples have been '
63
+ f'{len(experiment.state.evaluated_examples)} examples '
60
64
  'loaded from checkpoint files. Their outputs will be used '
61
- f'for recomputing metrics. Example IDs: {loaded_example_ids}'
65
+ f'for recomputing metrics. Example IDs: {loaded_example_ids}.'
62
66
  )
63
67
  experiment.info(
64
68
  f'{len(example_ids_to_evaluate)} examples will be processed from '
65
- f'scratch. Example IDs: {list(sorted(example_ids_to_evaluate))}'
69
+ f'scratch. Example IDs: {example_ids_to_evaluate}.'
66
70
  )
67
71
  else:
68
72
  experiment.info(
69
73
  'No examples are loaded from checkpoint files. '
70
- f'Experiment {experiment.id} starts from scratch.'
74
+ f'{len(example_ids_to_evaluate)} examples will be processed from '
75
+ f'scratch. Example IDs: {example_ids_to_evaluate}.'
71
76
  )
72
77
 
73
78
  def on_example_complete(
@@ -81,60 +86,36 @@ class Checkpointer(experiment_lib.Plugin):
81
86
  experiment.warning(
82
87
  f'Example {example.id} has error. Skipping checkpointing.'
83
88
  )
84
- else:
89
+ elif example.newly_processed:
85
90
  self._save_example(runner, experiment, example)
86
91
 
87
- @abc.abstractmethod
88
- def _load_experiment(self, runner: Runner, experiment: Experiment) -> None:
89
- """Loads the experiment state from checkpoint files."""
90
-
91
- @abc.abstractmethod
92
- def _save_example(
93
- self,
94
- runner: Runner,
95
- experiment: Experiment,
96
- example: Example,
97
- ) -> None:
98
- """Saves an evaluated example."""
99
-
100
-
101
- class PerExampleCheckpointer(Checkpointer):
102
- """Checkpointer that saves each example to a separate file."""
103
-
104
- checkpoint_filename: str = 'checkpoint.bagz'
105
-
106
- def _on_bound(self):
107
- super()._on_bound()
108
- prefix, ext = self._file_prefix_and_ext(self.checkpoint_filename)
109
- self._checkpoint_file_prefix = prefix
110
- self._checkpoint_file_ext = ext
111
-
112
92
  def _load_experiment(
113
93
  self,
114
94
  runner: Runner,
115
95
  experiment: Experiment,
116
96
  ) -> None:
117
97
  """Creates the checkpoint file."""
118
- experiment_dir = runner.current_run.input_dir(experiment)
119
- if pg.io.path_exists(experiment_dir):
120
- ckpt_files = [
121
- runner.current_run.input_path_for(experiment, filename)
122
- for filename in pg.io.listdir(experiment_dir)
123
- if filename.startswith(self._checkpoint_file_prefix)
124
- and filename.endswith(self._checkpoint_file_ext)
125
- ]
126
- else:
127
- ckpt_files = []
128
-
98
+ ckpt_files = self._list_checkpoint_filenames(runner, experiment)
129
99
  experiment.info(f'Found {len(ckpt_files)} checkpoint files to load.')
130
100
 
131
101
  # Load the checkpoint files in parallel.
102
+ current_run = runner.current_run
103
+ examples_to_load = current_run.examples_to_load(experiment)
104
+ examples_to_load_metadata = current_run.examples_to_load_metadata(
105
+ experiment
106
+ )
132
107
  context = dict(counter=0, counter_lock=threading.Lock())
108
+ copy_ckpt = current_run.input_root != current_run.output_root
109
+
133
110
  def _load_state(ckpt_file):
134
111
  error = None
135
112
  with pg.timeit() as t:
136
113
  try:
137
- experiment.load_state(ckpt_file)
114
+ experiment.load_state(
115
+ current_run.input_path_for(experiment, ckpt_file),
116
+ filter=lambda x: x.id in examples_to_load,
117
+ load_example_metadata=lambda x: x.id in examples_to_load_metadata,
118
+ )
138
119
  except BaseException as e: # pylint: disable=broad-except
139
120
  error = e
140
121
  finally:
@@ -144,21 +125,80 @@ class PerExampleCheckpointer(Checkpointer):
144
125
  progress_str = f'{context["counter"]}/{len(ckpt_files)}'
145
126
  if error is None:
146
127
  experiment.info(
147
- f'Loaded checkpoint file {ckpt_file} in {t.elapse:.2f} '
128
+ f'Checkpoint file {ckpt_file!r} loaded in {t.elapse:.2f} '
148
129
  f'seconds. ({progress_str})'
149
130
  )
150
131
  else:
151
132
  experiment.warning(
152
- f'Failed to load checkpoint file {ckpt_file}: {error}. '
133
+ f'Failed to load checkpoint file {ckpt_file!r}: {error}. '
153
134
  f'Skipping the file. ({progress_str})'
154
135
  )
155
136
 
137
+ if not copy_ckpt:
138
+ return
139
+
140
+ # Copy the checkpoint records to the output directory.
141
+ try:
142
+ with pg.io.open_sequence(
143
+ current_run.output_path_for(experiment, ckpt_file), 'w'
144
+ ) as o, pg.io.open_sequence(
145
+ current_run.input_path_for(experiment, ckpt_file), 'r'
146
+ ) as i:
147
+ for x in i:
148
+ o.add(x)
149
+ except BaseException as e: # pylint: disable=broad-except
150
+ experiment.warning(
151
+ f'Failed to copy checkpoint {ckpt_file!r}: {e}.'
152
+ )
153
+
156
154
  _ = list(
157
155
  lf.concurrent_map(
158
156
  _load_state, ckpt_files, max_workers=16, silence_on_errors=None
159
157
  )
160
158
  )
161
159
 
160
+ @abc.abstractmethod
161
+ def _list_checkpoint_filenames(
162
+ self, runner: Runner, experiment: Experiment
163
+ ) -> list[str]:
164
+ """Lists the checkpoint filenames to restore."""
165
+
166
+ @abc.abstractmethod
167
+ def _save_example(
168
+ self,
169
+ runner: Runner,
170
+ experiment: Experiment,
171
+ example: Example,
172
+ ) -> None:
173
+ """Saves an evaluated example."""
174
+
175
+
176
+ class PerExampleCheckpointer(Checkpointer):
177
+ """Checkpointer that saves each example to a separate file."""
178
+
179
+ def _on_bound(self):
180
+ super()._on_bound()
181
+ prefix, ext = self._file_prefix_and_ext(self.checkpoint_filename)
182
+ self._checkpoint_file_prefix = prefix
183
+ self._checkpoint_file_ext = ext
184
+
185
+ def _list_checkpoint_filenames(
186
+ self, runner: Runner, experiment: Experiment
187
+ ) -> list[str]:
188
+ experiment_dir = runner.current_run.input_dir(experiment)
189
+ filenames = []
190
+ examples_to_load = runner.current_run.examples_to_load(experiment)
191
+ if pg.io.path_exists(experiment_dir):
192
+ regex = re.compile(
193
+ f'{self._checkpoint_file_prefix}_(\\d+){self._checkpoint_file_ext}'
194
+ .replace('.', '\\.')
195
+ )
196
+ for filename in pg.io.listdir(experiment_dir):
197
+ match = regex.match(filename)
198
+ if match and int(match.group(1)) in examples_to_load:
199
+ filenames.append(filename)
200
+ return filenames
201
+
162
202
  def _save_example(
163
203
  self,
164
204
  runner: Runner,
@@ -180,11 +220,11 @@ class PerExampleCheckpointer(Checkpointer):
180
220
  writer.add(example)
181
221
  writer.close()
182
222
  experiment.info(
183
- f'Example {example.id} saved to {writer.path}.',
223
+ f'Example {example.id} checkpointed to {writer.path}.',
184
224
  )
185
225
  except BaseException as e: # pylint: disable=broad-except
186
226
  experiment.error(
187
- f'Failed to save example {example.id} to {writer.path}. '
227
+ f'Failed to checkpoint example {example.id} to {writer.path}. '
188
228
  f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.',
189
229
  )
190
230
  raise e
@@ -201,8 +241,6 @@ class PerExampleCheckpointer(Checkpointer):
201
241
  class BulkCheckpointer(Checkpointer):
202
242
  """Checkpointer that saves all examples to a single file."""
203
243
 
204
- checkpoint_filename: str = 'checkpoint.bagz'
205
-
206
244
  def _on_bound(self):
207
245
  super()._on_bound()
208
246
  self._lock = threading.Lock()
@@ -253,18 +291,14 @@ class BulkCheckpointer(Checkpointer):
253
291
  if self._sequence_writer is not None:
254
292
  self._sequence_writer[experiment.id] = sequence_writer
255
293
 
256
- def _load_experiment(
257
- self,
258
- runner: Runner,
259
- experiment: Experiment,
260
- ) -> None:
261
- """Creates the checkpoint file."""
262
- experiment.load_state(
263
- runner.current_run.input_path_for(
264
- experiment, self.checkpoint_filename
265
- ),
266
- raise_if_not_exist=False
267
- )
294
+ def _list_checkpoint_filenames(
295
+ self, runner: Runner, experiment: Experiment
296
+ ) -> list[str]:
297
+ if pg.io.path_exists(
298
+ runner.current_run.input_path_for(experiment, self.checkpoint_filename)
299
+ ):
300
+ return [self.checkpoint_filename]
301
+ return []
268
302
 
269
303
  def on_experiment_complete(
270
304
  self,
@@ -299,11 +333,11 @@ class BulkCheckpointer(Checkpointer):
299
333
  try:
300
334
  writer.add(example)
301
335
  experiment.info(
302
- f'Example {example.id} added to {writer.path}.',
336
+ f'Example {example.id} checkpointed to {writer.path}.',
303
337
  )
304
338
  except BaseException as e: # pylint: disable=broad-except
305
339
  experiment.error(
306
- f'Failed to save example {example.id} to {writer.path}. '
340
+ f'Failed to checkpoint example {example.id} to {writer.path}. '
307
341
  f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.',
308
342
  )
309
343
  raise e
@@ -316,7 +350,7 @@ class SequenceWriter:
316
350
  def __init__(self, path: str):
317
351
  self._lock = threading.Lock()
318
352
  self._path = path
319
- self._sequence_writer = pg.io.open_sequence(path, 'w')
353
+ self._sequence_writer = pg.io.open_sequence(path, 'a')
320
354
 
321
355
  @property
322
356
  def path(self) -> str:
@@ -52,10 +52,20 @@ class SequenceWriterTest(unittest.TestCase):
52
52
  self.assertEqual(len(list(iter(f))), 1)
53
53
 
54
54
 
55
- class PerExampleCheckpointerTest(unittest.TestCase):
55
+ class CheckpointerTest(unittest.TestCase):
56
+
57
+ def assert_found_in_log(self, experiment, message):
58
+ found_error_log = False
59
+ for log_entry in experiment._log_entries:
60
+ if log_entry.message.startswith(message):
61
+ found_error_log = True
62
+ break
63
+ self.assertTrue(found_error_log)
64
+
65
+
66
+ class PerExampleCheckpointerTest(CheckpointerTest):
56
67
 
57
68
  def test_checkpointing(self):
58
- pg.defaults.loggers.use_stdout()
59
69
  root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer')
60
70
  experiment = eval_test_helper.test_experiment()
61
71
  checkpoint_filename = 'checkpoint.jsonl'
@@ -85,8 +95,90 @@ class PerExampleCheckpointerTest(unittest.TestCase):
85
95
  for leaf in experiment.leaf_nodes:
86
96
  self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])
87
97
 
98
+ # Test warm start without reprocess.
99
+ root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer2')
100
+ experiment = eval_test_helper.test_experiment()
101
+ _ = experiment.run(
102
+ root_dir, 'new', runner='sequential', plugins=[checkpointer],
103
+ warm_start_from=run.output_root
104
+ )
105
+ for leaf in experiment.leaf_nodes:
106
+ self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])
107
+
108
+ # Test warm start with reprocess.
109
+ root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer3')
110
+ experiment = eval_test_helper.test_experiment()
111
+ _ = experiment.run(
112
+ root_dir, 'new', runner='sequential', plugins=[checkpointer],
113
+ warm_start_from=run.output_root,
114
+ reprocess=True
115
+ )
116
+ for leaf in experiment.leaf_nodes:
117
+ self.assertEqual(leaf.progress.num_skipped, 0)
88
118
 
89
- class BulkCheckpointerTest(unittest.TestCase):
119
+ root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer4')
120
+ experiment = eval_test_helper.test_experiment()
121
+ _ = experiment.run(
122
+ root_dir, 'new', runner='sequential', plugins=[checkpointer],
123
+ warm_start_from=run.output_root,
124
+ reprocess=[1, 2, 3]
125
+ )
126
+ for leaf in experiment.leaf_nodes:
127
+ self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id] - 3)
128
+
129
+ def test_loading_corrupted_checkpoint(self):
130
+ root_dir = os.path.join(
131
+ tempfile.gettempdir(),
132
+ 'per_example_checkpointer_with_corrupted_checkpoint'
133
+ )
134
+ experiment = eval_test_helper.TestEvaluation()
135
+ checkpoint_filename = 'checkpoint.jsonl'
136
+ checkpointer = checkpointing.PerExampleCheckpointer(checkpoint_filename)
137
+ run = experiment.run(
138
+ root_dir, 'new', runner='sequential', plugins=[checkpointer]
139
+ )
140
+ num_processed = {}
141
+ for i in range(experiment.num_examples):
142
+ example = experiment.state.get(i + 1)
143
+ ckpt = run.output_path_for(experiment, f'checkpoint_{example.id}.jsonl')
144
+ if not example.has_error:
145
+ self.assertTrue(pg.io.path_exists(ckpt))
146
+ with pg.io.open_sequence(ckpt) as f:
147
+ self.assertEqual(len(list(iter(f))), 1)
148
+
149
+ # Simulate corrupting the first checkpoint.
150
+ if i == 0:
151
+ pg.io.writefile(ckpt, 'bad file')
152
+ num_processed[example.id] = i + 1
153
+
154
+ root_dir = os.path.join(
155
+ tempfile.gettempdir(),
156
+ 'per_example_checkpointer_with_corrupted_checkpoint_warm_start'
157
+ )
158
+ experiment = eval_test_helper.TestEvaluation()
159
+ _ = experiment.run(
160
+ root_dir, 'new', runner='sequential', plugins=[checkpointer],
161
+ warm_start_from=run.output_root,
162
+ )
163
+ for leaf in experiment.leaf_nodes:
164
+ self.assertEqual(leaf.progress.num_skipped, len(num_processed) - 1)
165
+ self.assert_found_in_log(experiment, 'Failed to load checkpoint')
166
+
167
+ def test_checkpointing_error(self):
168
+ root_dir = os.path.join(
169
+ tempfile.gettempdir(),
170
+ 'per_example_checkpointer_with_checkpointing_error'
171
+ )
172
+ experiment = (eval_test_helper
173
+ .test_experiment_with_example_checkpointing_error())
174
+ checkpointer = checkpointing.PerExampleCheckpointer('checkpoint.jsonl')
175
+ _ = experiment.run(
176
+ root_dir, 'new', runner='parallel', plugins=[checkpointer]
177
+ )
178
+ self.assert_found_in_log(experiment, 'Failed to checkpoint')
179
+
180
+
181
+ class BulkCheckpointerTest(CheckpointerTest):
90
182
 
91
183
  def test_checkpointing(self):
92
184
  root_dir = os.path.join(tempfile.gettempdir(), 'test_bulk_checkpointer')
@@ -118,6 +210,19 @@ class BulkCheckpointerTest(unittest.TestCase):
118
210
  for leaf in experiment.leaf_nodes:
119
211
  self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])
120
212
 
213
+ def test_checkpointing_error(self):
214
+ root_dir = os.path.join(
215
+ tempfile.gettempdir(),
216
+ 'bulk_checkpointer_with_checkpointing_error'
217
+ )
218
+ experiment = (eval_test_helper
219
+ .test_experiment_with_example_checkpointing_error())
220
+ checkpointer = checkpointing.BulkCheckpointer('checkpoint.jsonl')
221
+ _ = experiment.run(
222
+ root_dir, 'new', runner='parallel', plugins=[checkpointer]
223
+ )
224
+ self.assert_found_in_log(experiment, 'Failed to checkpoint')
225
+
121
226
 
122
227
  if __name__ == '__main__':
123
228
  unittest.main()
@@ -72,9 +72,65 @@ class TestEvaluation(Evaluation):
72
72
  )
73
73
 
74
74
 
75
+ class BadJsonConvertible(pg.Object):
76
+
77
+ def to_json(self, *args, **kwargs):
78
+ raise ValueError('Cannot convert to JSON.')
79
+
80
+
81
+ class TestEvaluationWithExampleCheckpointingError(TestEvaluation):
82
+ """Test evaluation class with bad example checkpointing."""
83
+ inputs = test_inputs()
84
+ metrics = [metrics_lib.Match()]
85
+
86
+ def process(self, v):
87
+ return 1, dict(
88
+ x=BadJsonConvertible()
89
+ )
90
+
91
+
92
+ class BadHtmlConvertible(pg.Object, pg.views.HtmlTreeView.Extension):
93
+
94
+ def _html_tree_view(self, *args, **kwargs):
95
+ raise ValueError('Cannot render HTML.')
96
+
97
+
98
+ class TestEvaluationWithExampleHtmlGenerationError(Evaluation):
99
+ """Test evaluation class with bad example HTML generation."""
100
+ inputs = test_inputs()
101
+ metrics = [metrics_lib.Match()]
102
+
103
+ def process(self, v):
104
+ return 1, dict(
105
+ x=BadHtmlConvertible()
106
+ )
107
+
108
+
109
+ class TestEvaluationWithIndexHtmlGenerationError(TestEvaluation):
110
+ """Test evaluation class with bad index HTML generation."""
111
+
112
+ def _html_tree_view(self, *args, **kwargs):
113
+ raise ValueError('Cannot render HTML.')
114
+
115
+
75
116
  def test_experiment():
76
117
  """Returns a test experiment."""
77
118
  return Suite([
78
119
  TestEvaluation(lm=TestLLM(offset=0)),
79
120
  TestEvaluation(lm=TestLLM(offset=pg.oneof(range(5)))),
80
121
  ])
122
+
123
+
124
+ def test_experiment_with_example_checkpointing_error():
125
+ """Returns a test experiment with example checkpointing error."""
126
+ return TestEvaluationWithExampleCheckpointingError()
127
+
128
+
129
+ def test_experiment_with_example_html_generation_error():
130
+ """Returns a test experiment with bad example HTML."""
131
+ return TestEvaluationWithExampleHtmlGenerationError()
132
+
133
+
134
+ def test_experiment_with_index_html_generation_error():
135
+ """Returns a test experiment with bad index HTML."""
136
+ return TestEvaluationWithIndexHtmlGenerationError()
@@ -264,11 +264,21 @@ class Evaluation(experiment_lib.Experiment):
264
264
  return self._state
265
265
 
266
266
  def load_state(
267
- self, state_file: str, raise_if_not_exist: bool = False
267
+ self,
268
+ state_file: str,
269
+ *,
270
+ load_example_metadata: bool = True,
271
+ filter: Callable[[example_lib.Example], bool] | None = None, # pylint: disable=redefined-builtin
272
+ raise_if_not_exist: bool = False
268
273
  ) -> None:
269
274
  """Loads saved state from a sequence IO file."""
270
275
  if pg.io.path_exists(state_file):
271
- self._state.load(state_file, self.example_input_by_id)
276
+ self._state.load(
277
+ state_file,
278
+ example_input_by_id=self.example_input_by_id,
279
+ load_example_metadata=load_example_metadata,
280
+ filter=filter,
281
+ )
272
282
  elif raise_if_not_exist:
273
283
  raise ValueError(f'State file {state_file} does not exist.')
274
284
 
@@ -680,14 +690,25 @@ class EvaluationState:
680
690
  self._evaluated_examples: dict[int, example_lib.Example] = {}
681
691
 
682
692
  def load(
683
- self, state_file: str, example_input_by_id: Callable[[int], Any]) -> None:
693
+ self,
694
+ state_file: str,
695
+ *,
696
+ example_input_by_id: Callable[[int], Any] | None = None,
697
+ load_example_metadata: bool | Callable[
698
+ [example_lib.Example], bool] = True,
699
+ filter: Callable[[example_lib.Example], bool] | None = None, # pylint: disable=redefined-builtin
700
+ ) -> None:
684
701
  """Loads the state from the example sequence file."""
685
702
  with pg.io.sequence.open_sequence(state_file) as f:
686
703
  for record in f:
687
704
  example = pg.from_json_str(
688
- record, example_input_by_id=example_input_by_id
705
+ record,
706
+ example_input_by_id=example_input_by_id,
707
+ load_example_metadata=load_example_metadata
689
708
  )
690
709
  assert isinstance(example, example_lib.Example), example
710
+ if filter is not None and not filter(example):
711
+ continue
691
712
  self._evaluated_examples[example.id] = example
692
713
 
693
714
  @property
@@ -138,6 +138,17 @@ class EvaluationTest(unittest.TestCase):
138
138
  self.assertEqual(example.usage_summary.uncached.total.total_tokens, 0)
139
139
  self.assertEqual(example.usage_summary.uncached.total.num_requests, 0)
140
140
 
141
+ # Test load_state with filter.
142
+ exp.reset()
143
+ self.assertEqual(len(exp._state.evaluated_examples), 0)
144
+ exp.load_state(state_file, filter=lambda x: x.id == 3)
145
+ self.assertEqual(len(exp._state.evaluated_examples), 1)
146
+
147
+ exp.reset()
148
+ self.assertEqual(len(exp._state.evaluated_examples), 0)
149
+ exp.load_state(state_file, filter=lambda x: x.id == 1)
150
+ self.assertEqual(len(exp._state.evaluated_examples), 0)
151
+
141
152
  def test_html_view(self):
142
153
  exp = eval_test_helper.TestEvaluation()
143
154
  exp.debug('debug message')
@@ -101,6 +101,7 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
101
101
  json_value: dict[str, Any],
102
102
  *,
103
103
  example_input_by_id: Callable[[int], Any] | None = None,
104
+ load_example_metadata: bool | Callable[['Example'], bool] = False,
104
105
  **kwargs
105
106
  ) -> 'Example':
106
107
  """Creates an example from the JSON representation."""
@@ -128,12 +129,21 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
128
129
  pg.traverse(example, _visit)
129
130
  return list(referred_types)
130
131
 
132
+ # We delay loading the metadata until the other parts of the example are
133
+ # loaded. So we could apply the filter to decide whether to load the
134
+ # metadata.
135
+ metadata_dict = json_value.pop('metadata', None)
131
136
  with pg.JSONConvertible.load_types_for_deserialization(
132
137
  *example_class_defs(example_input)
133
138
  ):
134
- return cls(
139
+ example = cls(
135
140
  **{k: pg.from_json(v, **kwargs) for k, v in json_value.items()}
136
141
  )
142
+ if callable(load_example_metadata):
143
+ load_example_metadata = load_example_metadata(example)
144
+ if load_example_metadata:
145
+ example.metadata = pg.from_json(metadata_dict, **kwargs)
146
+ return example
137
147
 
138
148
  #
139
149
  # HTML rendering.
@@ -70,17 +70,31 @@ class ExampleTest(unittest.TestCase):
70
70
  self.assertEqual(
71
71
  pg.from_json_str(
72
72
  json_str,
73
- example_input_by_id=lambda i: inputs[i - 1]
73
+ example_input_by_id=lambda i: inputs[i - 1],
74
+ load_example_metadata=True,
74
75
  ),
75
76
  ex
76
77
  )
78
+ self.assertEqual(
79
+ pg.from_json_str(
80
+ json_str,
81
+ example_input_by_id=lambda i: inputs[i - 1],
82
+ load_example_metadata=False,
83
+ ),
84
+ Example(
85
+ id=1,
86
+ input=inputs[0],
87
+ output=inputs[0].a(1),
88
+ metadata={}
89
+ )
90
+ )
77
91
  pg.JSONConvertible._TYPE_REGISTRY._type_to_cls_map.pop(
78
92
  inputs[0].a.__type_name__
79
93
  )
80
94
  pg.JSONConvertible._TYPE_REGISTRY._type_to_cls_map.pop(
81
95
  inputs[0].b.__type_name__
82
96
  )
83
- v = pg.from_json_str(json_str, auto_dict=True)
97
+ v = pg.from_json_str(json_str, auto_dict=True, load_example_metadata=True)
84
98
  v.output.pop('type_name')
85
99
  v.metadata.b.pop('type_name')
86
100
  self.assertEqual(