langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512150805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +3 -0
  29. langfun/core/eval/v2/checkpointing.py +148 -46
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +102 -19
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +95 -20
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +88 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +73 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +350 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +213 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +14 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +90 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +52 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +78 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +78 -4
  104. langfun/core/modalities/mime_test.py +59 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512150805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/top_level.txt +0 -0
langfun/core/eval/v2/checkpointing.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 """Checkpointing evaluation runs."""
 import abc
+import datetime
+import os
 import re
 import threading
 import traceback
@@ -29,12 +31,32 @@ Runner = experiment_lib.Runner


 class Checkpointer(experiment_lib.Plugin):
-  """Base class for checkpointing evaluation examples."""
+  """Base class for checkpointing evaluation examples.
+
+  `Checkpointer` is a plugin that saves the state of processed examples
+  incrementally during an experiment run, allowing the experiment to be resumed
+  later. When an experiment starts, the checkpointer loads any previously saved
+  examples from an earlier run (or a warm-start run) into `experiment.state`,
+  so the runner can skip processing them again.
+  Subclasses should implement `_list_checkpoint_files` to identify
+  checkpoint files to load, and `_save_example` to save a newly processed
+  example.
+  """

   checkpoint_filename: Annotated[
       str,
       'Checkpoint file pattern.'
-  ] = 'checkpoint.bagz'
+  ] = 'checkpoint.jsonl'
+
+  enable_inprogress_file: Annotated[
+      bool,
+      'If True, write file "<example_id>.inprogress" when example gets started.'
+  ] = True
+
+  max_ckpt_loading_threads: Annotated[
+      int,
+      'Max number of workers for loading checkpoint files at startup.'
+  ] = 128

   def on_experiment_start(
       self,
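Note: to make the new plugin contract concrete, here is a minimal sketch of a custom checkpointer built on the two abstract hooks named in the docstring above. It assumes the surrounding module context (`Checkpointer`, `SequenceWriter`, `pg`, and `os` in scope); the flat-directory layout and the `_save_example` parameter list are illustrative assumptions, not part of this diff.

class FlatDirCheckpointer(Checkpointer):  # Hypothetical example, not in the package.
  """Sketch: restores every *.jsonl checkpoint found in the output directory."""

  def _list_checkpoint_files(self, runner, experiment) -> list[str]:
    # Treat every .jsonl file under the output directory as a checkpoint.
    ckpt_dir = runner.current_run.output_dir(experiment)
    if not pg.io.path_exists(ckpt_dir):
      return []
    return [
        os.path.join(ckpt_dir, f)
        for f in pg.io.listdir(ckpt_dir)
        if f.endswith('.jsonl')
    ]

  def _save_example(self, runner, experiment, example) -> None:
    # Assumed signature, mirroring the abstract `_save_example` shown later
    # in this diff.
    path = runner.current_run.output_path_for(
        experiment, f'example_{example.id}.jsonl'
    )
    with SequenceWriter(path) as writer:
      writer.add(example)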
@@ -75,6 +97,24 @@ class Checkpointer(experiment_lib.Plugin):
           f'scratch. Example IDs: {example_ids_to_evaluate}.'
       )

+  def on_example_start(
+      self,
+      runner: Runner,
+      experiment: Experiment,
+      example: Example,
+  ) -> None:
+    """Saves the example to the checkpoint file."""
+    if self.enable_inprogress_file:
+      def _save_inprogress_file(example: Example):
+        inprogress_file = runner.current_run.output_path_for(
+            experiment, f'{example.id}.inprogress'
+        )
+        pg.io.writefile(
+            inprogress_file,
+            datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+        )
+      runner.background_run(_save_inprogress_file, example)
+
   def on_example_complete(
       self,
       runner: Runner,
@@ -91,7 +131,7 @@
       experiment: Experiment,
   ) -> None:
     """Creates the checkpoint file."""
-    ckpt_files = self._list_checkpoint_filenames(runner, experiment)
+    ckpt_files = self._list_checkpoint_files(runner, experiment)
     experiment.info(f'Found {len(ckpt_files)} checkpoint files to load.')

     # Load the checkpoint files in parallel.
@@ -101,18 +141,18 @@
         experiment
     )
     context = dict(counter=0, counter_lock=threading.Lock())
-    copy_ckpt = current_run.input_root != current_run.output_root

    def _load_state(ckpt_file):
      error = None
      with pg.timeit() as t:
        try:
-          experiment.load_state(
-              current_run.input_path_for(experiment, ckpt_file),
+          loaded_examples = experiment.load_state(
+              ckpt_file,
               filter=lambda x: x.id in examples_to_load,
               load_example_metadata=lambda x: x.id in examples_to_load_metadata,
           )
        except BaseException as e:  # pylint: disable=broad-except
+          loaded_examples = []
          error = e
        finally:
          with context['counter_lock']:
@@ -130,34 +170,33 @@
            f'Skipping the file. ({progress_str})'
        )

-      if not copy_ckpt:
-        return
-
-      # Copy the checkpoint records to the output directory.
-      try:
-        with pg.io.open_sequence(
-            current_run.output_path_for(experiment, ckpt_file), 'w'
-        ) as o, pg.io.open_sequence(
-            current_run.input_path_for(experiment, ckpt_file), 'r'
-        ) as i:
-          for x in i:
-            o.add(x)
-      except BaseException as e:  # pylint: disable=broad-except
-        experiment.warning(
-            f'Failed to copy checkpoint {ckpt_file!r}: {e}.'
-        )
+      output_ckpt_file = current_run.output_path_for(
+          experiment, os.path.basename(ckpt_file)
+      )
+      if ckpt_file != output_ckpt_file and any(
+          e for e in loaded_examples if not e.has_error
+      ):
+        # Write the error-free warm-start examples to the output checkpoint
+        # file.
+        with SequenceWriter(output_ckpt_file) as writer:
+          for example in loaded_examples:
+            if not example.has_error:
+              writer.add(example)

     _ = list(
         lf.concurrent_map(
-            _load_state, ckpt_files, max_workers=16, silence_on_errors=None
+            _load_state,
+            ckpt_files,
+            max_workers=self.max_ckpt_loading_threads,
+            silence_on_errors=None
         )
     )

   @abc.abstractmethod
-  def _list_checkpoint_filenames(
+  def _list_checkpoint_files(
       self, runner: Runner, experiment: Experiment
   ) -> list[str]:
-    """Lists the checkpoint filenames to restore."""
+    """Lists the checkpoint file paths to restore."""

   @abc.abstractmethod
   def _save_example(
@@ -170,7 +209,12 @@


 class PerExampleCheckpointer(Checkpointer):
-  """Checkpointer that saves each example to a separate file."""
+  """Checkpointer that saves each example to a separate file.
+
+  This checkpointer saves each processed example to its own checkpoint file,
+  named using the pattern `<checkpoint_filename_prefix>_<example_id>.<ext>`.
+  For example, `checkpoint_1.bagz`, `checkpoint_2.bagz`, etc.
+  """

   def _on_bound(self):
     super()._on_bound()
@@ -178,22 +222,41 @@ class PerExampleCheckpointer(Checkpointer):
     self._checkpoint_file_prefix = prefix
     self._checkpoint_file_ext = ext

-  def _list_checkpoint_filenames(
+  def _list_checkpoint_files(
       self, runner: Runner, experiment: Experiment
   ) -> list[str]:
-    experiment_dir = runner.current_run.input_dir(experiment)
-    filenames = []
+
+    def _list_checkpoints_from(ckpt_dir: str, examples_to_load: set[int]):
+      ckpt_files = []
+      if pg.io.path_exists(ckpt_dir):
+        regex = re.compile(
+            f'{self._checkpoint_file_prefix}_(\\d+){self._checkpoint_file_ext}'
+            .replace('.', '\\.')
+        )
+        for filename in pg.io.listdir(ckpt_dir):
+          match = regex.match(filename)
+          if match and int(match.group(1)) in examples_to_load:
+            examples_to_load.remove(int(match.group(1)))
+            ckpt_files.append(os.path.join(ckpt_dir, filename))
+      return ckpt_files
+
     examples_to_load = runner.current_run.examples_to_load(experiment)
-    if pg.io.path_exists(experiment_dir):
-      regex = re.compile(
-          f'{self._checkpoint_file_prefix}_(\\d+){self._checkpoint_file_ext}'
-          .replace('.', '\\.')
+
+    # Take output directory as the first priority to checkpoints processed in
+    # this run.
+    ckpt_files = _list_checkpoints_from(
+        runner.current_run.output_dir(experiment), examples_to_load
+    )
+    # If the input and output directories are different, also load from the
+    # input directory.
+    if (examples_to_load
+        and runner.current_run.input_root != runner.current_run.output_root):
+      ckpt_files.extend(
+          _list_checkpoints_from(
+              runner.current_run.input_dir(experiment), examples_to_load
+          )
       )
-      for filename in pg.io.listdir(experiment_dir):
-        match = regex.match(filename)
-        if match and int(match.group(1)) in examples_to_load:
-          filenames.append(filename)
-    return filenames
+    return ckpt_files

   def _save_example(
       self,
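The filename matching above hinges on escaping the literal dots of the prefix and extension before compiling the pattern. A quick self-contained illustration (the prefix 'checkpoint' and extension '.jsonl' are assumed values):

import re

# Escaping the dot turns 'checkpoint_(\d+).jsonl' into 'checkpoint_(\d+)\.jsonl',
# so '.' matches only a literal dot rather than any character.
regex = re.compile('checkpoint_(\\d+).jsonl'.replace('.', '\\.'))

assert regex.match('checkpoint_12.jsonl').group(1) == '12'
assert regex.match('checkpoint_x.jsonl') is None   # Non-numeric id: no match.
assert regex.match('checkpoint_12xjsonl') is None  # Unescaped '.' would match 'x'.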
@@ -235,7 +298,13 @@


 class BulkCheckpointer(Checkpointer):
-  """Checkpointer that saves all examples to a single file."""
+  """Checkpointer that saves all examples of an evaluation to a single file.
+
+  This checkpointer appends newly processed examples of an evaluation to a
+  single sequence file (e.g., `checkpoint.bagz`). This is often more efficient
+  than `PerExampleCheckpointer` when dealing with a large number of examples
+  or when file system overhead is a concern.
+  """

   def _on_bound(self):
     super()._on_bound()
@@ -287,13 +356,24 @@
     if self._sequence_writer is not None:
       self._sequence_writer[experiment.id] = sequence_writer

-  def _list_checkpoint_filenames(
+  def _list_checkpoint_files(
       self, runner: Runner, experiment: Experiment
   ) -> list[str]:
-    if pg.io.path_exists(
-        runner.current_run.input_path_for(experiment, self.checkpoint_filename)
-    ):
-      return [self.checkpoint_filename]
+    # Always honor the output directory if it's present, as it contains both
+    # the warm-started examples and newly processed examples.
+    output_ckpt_file = runner.current_run.output_path_for(
+        experiment, self.checkpoint_filename
+    )
+    if pg.io.path_exists(output_ckpt_file):
+      return [output_ckpt_file]
+
+    if runner.current_run.input_root != runner.current_run.output_root:
+      input_ckpt_file = runner.current_run.input_path_for(
+          experiment, self.checkpoint_filename
+      )
+      if pg.io.path_exists(input_ckpt_file):
+        return [input_ckpt_file]
+    print('CCC', experiment.hash, [])
     return []

   def on_experiment_complete(
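The lookup order implemented above, output directory first and the input directory only on a warm start from a different root, can be summarized as a small pure function. A sketch, with `path_exists` standing in for `pg.io.path_exists` and both path arguments assumed for illustration:

def resolve_bulk_checkpoint(
    output_ckpt: str, input_ckpt: str, same_root: bool, path_exists
) -> list[str]:
  # The output file wins: it holds warm-started plus newly processed examples.
  if path_exists(output_ckpt):
    return [output_ckpt]
  # Fall back to the input copy only when reading from a different root.
  if not same_root and path_exists(input_ckpt):
    return [input_ckpt]
  return []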
@@ -341,12 +421,26 @@


 class SequenceWriter:
-  """Thread safe sequence writer."""
+  """A thread-safe writer for sequence files (e.g., Bagz) with atomic write.
+
+  `SequenceWriter` wraps a `pg.io.SequenceWriter` to provide thread-safe
+  `add` and `close` operations, ensuring that examples can be written
+  concurrently from multiple threads without corrupting the sequence file.
+  It writes to a temporary file and renames it to target path on `close` to
+  achieve atomic write. If the target path exists, new examples are appended
+  to existing content.
+  """

   def __init__(self, path: str):
     self._lock = threading.Lock()
     self._path = path
-    self._sequence_writer = pg.io.open_sequence(path, 'a')
+    basename = os.path.basename(path)
+    self._tmp_path = os.path.join(
+        os.path.dirname(path), f'tmp.{basename}'
+    )
+    if pg.io.path_exists(self._path):
+      pg.io.copy(self._path, self._tmp_path)
+    self._sequence_writer = pg.io.open_sequence(self._tmp_path, 'a')

   @property
   def path(self) -> str:
@@ -371,6 +465,14 @@ class SequenceWriter:
        return
      self._sequence_writer.close()
      self._sequence_writer = None
+    pg.io.rename(self._tmp_path, self._path)
+
+  def __enter__(self):
+    return self
+
+  def __exit__(self, *args, **kwargs):
+    del args, kwargs
+    self.close()

   def __del__(self):
     self.close()
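The temp-file-then-rename protocol that `SequenceWriter` now follows is a standard atomicity pattern: readers observe either the old file or the complete new one, never a half-written file. A minimal standalone sketch of the same idea using only the standard library (the class and file names are illustrative, not from the package):

import os
import shutil
import threading

class AtomicLineWriter:  # Hypothetical analogue of SequenceWriter.
  """Appends lines to `path`; work happens on a tmp file, renamed on close."""

  def __init__(self, path: str):
    self._lock = threading.Lock()
    self._path = path
    self._tmp_path = os.path.join(
        os.path.dirname(path) or '.', f'tmp.{os.path.basename(path)}'
    )
    if os.path.exists(path):
      shutil.copy(path, self._tmp_path)  # Keep existing content for append.
    self._file = open(self._tmp_path, 'a')

  def add(self, line: str) -> None:
    with self._lock:
      self._file.write(line + '\n')

  def close(self) -> None:
    with self._lock:
      if self._file is None:
        return
      self._file.close()
      self._file = None
    os.replace(self._tmp_path, self._path)  # Atomic on POSIX file systems.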
langfun/core/eval/v2/checkpointing_test.py
@@ -65,7 +65,7 @@ class ExampleCollector(experiment_lib.Plugin):
     return self._examples

   def on_example_complete(
-      self, runner: runners_lib.Runner,
+      self, runner: experiment_lib.Runner,
       experiment: experiment_lib.Experiment,
       example: example_lib.Example,
   ):
@@ -90,7 +90,10 @@ class PerExampleCheckpointerTest(CheckpointerTest):
     root_dir = os.path.join(tempfile.mkdtemp(), 'per_example_checkpointer')
     experiment = eval_test_helper.test_experiment()
     checkpoint_filename = 'checkpoint.jsonl'
-    checkpointer = checkpointing.PerExampleCheckpointer(checkpoint_filename)
+    checkpointer = checkpointing.PerExampleCheckpointer(
+        checkpoint_filename,
+        enable_inprogress_file=True
+    )
     collector = ExampleCollector()
     run = experiment.run(
         root_dir, 'new', runner='sequential', plugins=[checkpointer, collector]
@@ -102,6 +105,10 @@
       example = collector.examples[i + 1]
       ckpt = run.output_path_for(leaf, f'checkpoint_{example.id}.jsonl')
       self.assertTrue(pg.io.path_exists(ckpt))
+      inprogress_file = run.output_path_for(
+          leaf, f'{example.id}.inprogress'
+      )
+      self.assertTrue(pg.io.path_exists(inprogress_file))
       with pg.io.open_sequence(ckpt) as f:
         examples_from_ckpt = list(iter(f))
       # `eval_test_helper.test_experiment` has two TestEvaluation with
langfun/core/eval/v2/config_saver.py (new file)
@@ -0,0 +1,37 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Config saver plugins."""
+
+import os
+from langfun.core.eval.v2 import experiment as experiment_lib
+
+
+class RunConfigSaver(experiment_lib.Plugin):
+  """Saves the current run."""
+
+  def on_run_start(
+      self,
+      runner: experiment_lib.Runner,
+      root: experiment_lib.Experiment
+  ) -> None:
+    del root  # Unused.
+    self._save_run_config(runner)
+
+  def _save_run_config(self, runner: experiment_lib.Runner) -> None:
+    def _save():
+      runner.current_run.save(
+          os.path.join(runner.current_run.output_root, 'run.json'),
+          hide_default_values=True,
+      )
+    runner.background_run(_save)
langfun/core/eval/v2/config_saver_test.py (new file)
@@ -0,0 +1,36 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Config saver test."""
+
+import os
+import tempfile
+import unittest
+from langfun.core.eval.v2 import config_saver
+from langfun.core.eval.v2 import eval_test_helper
+from langfun.core.eval.v2.runners import parallel  # pylint: disable=unused-import
+
+
+class RunConfigSaverTest(unittest.TestCase):
+
+  def test_save_run_config(self):
+    root_dir = os.path.join(tempfile.mkdtemp(), 'test_run_config_saver')
+    experiment = eval_test_helper.test_evaluation()
+    run = experiment.run(
+        root_dir, 'new', plugins=[config_saver.RunConfigSaver()]
+    )
+    self.assertTrue(os.path.exists(os.path.join(run.output_root, 'run.json')))
+
+
+if __name__ == '__main__':
+  unittest.main()
langfun/core/eval/v2/eval_test_helper.py
@@ -13,6 +13,9 @@
 # limitations under the License.
 """Helper classes and functions for evaluation tests."""

+import threading
+import time
+
 from langfun.core import language_model
 from langfun.core import llms
 from langfun.core import message as message_lib
@@ -47,6 +50,8 @@ class TestLLM(llms.Fake):

   offset: int = 0

+  __test__ = False
+
   def _response_from(self, prompt: message_lib.Message) -> message_lib.Message:
     return message_lib.AIMessage(
         str(prompt.metadata.x + prompt.metadata.y + self.offset)
@@ -63,6 +68,8 @@ class TestEvaluation(Evaluation):
   metrics = [metrics_lib.Match()]
   lm: language_model.LanguageModel = TestLLM()

+  __test__ = False
+
   def process(self, example):
     v = example.input
     if v.x == 5:
@@ -75,7 +82,7 @@

 class BadJsonConvertible(pg.Object):

-  def to_json(self, *args, **kwargs):
+  def sym_jsonify(self, *args, **kwargs):
     raise ValueError('Cannot convert to JSON.')


@@ -84,6 +91,8 @@ class TestEvaluationWithExampleCheckpointingError(TestEvaluation):
   inputs = test_inputs()
   metrics = [metrics_lib.Match()]

+  __test__ = False
+
   def process(self, example):
     return 1, dict(
         x=BadJsonConvertible()
@@ -101,6 +110,8 @@ class TestEvaluationWithExampleHtmlGenerationError(Evaluation):
   inputs = test_inputs()
   metrics = [metrics_lib.Match()]

+  __test__ = False
+
   def process(self, example):
     return 1, dict(
         x=BadHtmlConvertible()
@@ -110,15 +121,22 @@
 class TestEvaluationWithIndexHtmlGenerationError(TestEvaluation):
   """Test evaluation class with bad index HTML generation."""

+  __test__ = False
+
   def _html_tree_view(self, *args, **kwargs):
     raise ValueError('Cannot render HTML.')


+def test_evaluation(offset: int | pg.hyper.OneOf = 0):
+  """Returns a test evaluation."""
+  return TestEvaluation(lm=TestLLM(offset=offset))
+
+
 def test_experiment():
   """Returns a test experiment."""
   return Suite([
-      TestEvaluation(lm=TestLLM(offset=0)),
-      TestEvaluation(lm=TestLLM(offset=pg.oneof(range(5)))),
+      test_evaluation(),
+      test_evaluation(pg.oneof(range(5))),
   ])


@@ -135,3 +153,86 @@ def test_experiment_with_example_html_generation_error():
 def test_experiment_with_index_html_generation_error():
   """Returns a test experiment with bad index HTML."""
   return TestEvaluationWithIndexHtmlGenerationError()
+
+
+class TestPlugin(experiment_lib.Plugin):
+  """Plugin for testing."""
+
+  started_experiments: list[experiment_lib.Experiment] = []
+  completed_experiments: list[experiment_lib.Experiment] = []
+  skipped_experiments: list[experiment_lib.Experiment] = []
+  started_example_ids: list[int] = []
+  completed_example_ids: list[int] = []
+  start_time: float | None = None
+  complete_time: float | None = None
+
+  __test__ = False
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._lock = threading.Lock()
+
+  def on_run_start(
+      self,
+      runner: experiment_lib.Runner,
+      root: experiment_lib.Experiment
+  ) -> None:
+    del root
+    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
+      self.start_time = time.time()
+
+  def on_run_complete(
+      self,
+      runner: experiment_lib.Runner,
+      root: experiment_lib.Experiment
+  ) -> None:
+    del root
+    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
+      self.complete_time = time.time()
+
+  def on_experiment_start(
+      self,
+      runner: experiment_lib.Runner,
+      experiment: experiment_lib.Experiment
+  ) -> None:
+    del runner
+    with pg.notify_on_change(False), self._lock:
+      self.started_experiments.append(pg.Ref(experiment))
+
+  def on_experiment_skipped(
+      self,
+      runner: experiment_lib.Runner,
+      experiment: experiment_lib.Experiment
+  ) -> None:
+    del runner
+    with pg.notify_on_change(False), self._lock:
+      self.skipped_experiments.append(pg.Ref(experiment))
+
+  def on_experiment_complete(
+      self,
+      runner: experiment_lib.Runner,
+      experiment: experiment_lib.Experiment
+  ) -> None:
+    del runner
+    with pg.notify_on_change(False), self._lock:
+      self.completed_experiments.append(pg.Ref(experiment))
+
+  def on_example_start(
+      self,
+      runner: experiment_lib.Runner,
+      experiment: experiment_lib.Experiment,
+      example: Example
+  ) -> None:
+    del runner, experiment
+    with pg.notify_on_change(False), self._lock:
+      self.started_example_ids.append(example.id)
+
+  def on_example_complete(
+      self,
+      runner: experiment_lib.Runner,
+      experiment: experiment_lib.Experiment,
+      example: Example
+  ) -> None:
+    del runner, experiment
+    with pg.notify_on_change(False), self._lock:
+      self.completed_example_ids.append(example.id)
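For reference, such a plugin attaches to a run the same way the checkpointer and config-saver plugins do in the tests above. A hedged usage sketch (the root directory path is assumed):

plugin = TestPlugin()
experiment = test_experiment()
run = experiment.run('/tmp/my_eval', 'new', plugins=[plugin])  # Path assumed.
assert plugin.start_time is not None and plugin.complete_time is not None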