langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512150805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +1 -1
- langfun/core/__init__.py +7 -1
- langfun/core/agentic/__init__.py +8 -1
- langfun/core/agentic/action.py +740 -112
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +189 -24
- langfun/core/async_support.py +104 -5
- langfun/core/async_support_test.py +23 -0
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +9 -2
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +11 -2
- langfun/core/data/conversion/gemini_test.py +48 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +48 -44
- langfun/core/eval/base_test.py +5 -5
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +3 -0
- langfun/core/eval/v2/checkpointing.py +148 -46
- langfun/core/eval/v2/checkpointing_test.py +9 -2
- langfun/core/eval/v2/config_saver.py +37 -0
- langfun/core/eval/v2/config_saver_test.py +36 -0
- langfun/core/eval/v2/eval_test_helper.py +104 -3
- langfun/core/eval/v2/evaluation.py +102 -19
- langfun/core/eval/v2/evaluation_test.py +9 -3
- langfun/core/eval/v2/example.py +50 -40
- langfun/core/eval/v2/example_test.py +16 -8
- langfun/core/eval/v2/experiment.py +95 -20
- langfun/core/eval/v2/experiment_test.py +19 -0
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +157 -44
- langfun/core/eval/v2/metrics_test.py +39 -18
- langfun/core/eval/v2/progress.py +31 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking.py +13 -5
- langfun/core/eval/v2/progress_tracking_test.py +9 -1
- langfun/core/eval/v2/reporting.py +88 -71
- langfun/core/eval/v2/reporting_test.py +24 -6
- langfun/core/eval/v2/runners/__init__.py +30 -0
- langfun/core/eval/v2/{runners.py → runners/base.py} +73 -180
- langfun/core/eval/v2/runners/beam.py +354 -0
- langfun/core/eval/v2/runners/beam_test.py +153 -0
- langfun/core/eval/v2/runners/ckpt_monitor.py +350 -0
- langfun/core/eval/v2/runners/ckpt_monitor_test.py +213 -0
- langfun/core/eval/v2/runners/debug.py +40 -0
- langfun/core/eval/v2/runners/debug_test.py +76 -0
- langfun/core/eval/v2/runners/parallel.py +243 -0
- langfun/core/eval/v2/runners/parallel_test.py +182 -0
- langfun/core/eval/v2/runners/sequential.py +47 -0
- langfun/core/eval/v2/runners/sequential_test.py +169 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +7 -5
- langfun/core/language_model.py +189 -36
- langfun/core/language_model_test.py +54 -3
- langfun/core/llms/__init__.py +14 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +32 -1
- langfun/core/llms/gemini.py +90 -12
- langfun/core/llms/gemini_test.py +110 -0
- langfun/core/llms/google_genai.py +52 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +120 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +16 -1
- langfun/core/llms/vertexai.py +78 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/__init__.py +10 -0
- langfun/core/mcp/client.py +177 -0
- langfun/core/mcp/client_test.py +71 -0
- langfun/core/mcp/session.py +241 -0
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/testing/simple_mcp_client.py +33 -0
- langfun/core/mcp/testing/simple_mcp_server.py +33 -0
- langfun/core/mcp/tool.py +254 -0
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +73 -3
- langfun/core/modalities/image_test.py +116 -0
- langfun/core/modalities/mime.py +78 -4
- langfun/core/modalities/mime_test.py +59 -0
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/__init__.py +2 -24
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +81 -37
- langfun/core/structured/parsing.py +95 -79
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +230 -154
- langfun/core/structured/querying_test.py +69 -33
- langfun/core/structured/schema/__init__.py +49 -0
- langfun/core/structured/schema/base.py +664 -0
- langfun/core/structured/schema/base_test.py +531 -0
- langfun/core/structured/schema/json.py +174 -0
- langfun/core/structured/schema/json_test.py +121 -0
- langfun/core/structured/schema/python.py +316 -0
- langfun/core/structured/schema/python_test.py +410 -0
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +47 -36
- langfun/core/structured/tokenization.py +26 -11
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +175 -50
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +43 -0
- langfun/env/base_environment.py +827 -0
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +304 -0
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +842 -0
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +14 -0
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +472 -0
- langfun/env/event_handlers/event_logger_test.py +304 -0
- langfun/env/event_handlers/metric_writer.py +726 -0
- langfun/env/event_handlers/metric_writer_test.py +214 -0
- langfun/env/interface.py +1640 -0
- langfun/env/interface_test.py +153 -0
- langfun/env/load_balancers.py +59 -0
- langfun/env/load_balancers_test.py +141 -0
- langfun/env/test_utils.py +507 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/METADATA +7 -3
- langfun-0.1.2.dev202512150805.dist-info/RECORD +217 -0
- langfun/core/eval/v2/runners_test.py +0 -343
- langfun/core/structured/schema.py +0 -987
- langfun/core/structured/schema_test.py +0 -982
- langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/top_level.txt +0 -0
|
@@ -13,6 +13,8 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
"""Checkpointing evaluation runs."""
|
|
15
15
|
import abc
|
|
16
|
+
import datetime
|
|
17
|
+
import os
|
|
16
18
|
import re
|
|
17
19
|
import threading
|
|
18
20
|
import traceback
|
|
@@ -29,12 +31,32 @@ Runner = experiment_lib.Runner
|
|
|
29
31
|
|
|
30
32
|
|
|
31
33
|
class Checkpointer(experiment_lib.Plugin):
|
|
32
|
-
"""Base class for checkpointing evaluation examples.
|
|
34
|
+
"""Base class for checkpointing evaluation examples.
|
|
35
|
+
|
|
36
|
+
`Checkpointer` is a plugin that saves the state of processed examples
|
|
37
|
+
incrementally during an experiment run, allowing the experiment to be resumed
|
|
38
|
+
later. When an experiment starts, the checkpointer loads any previously saved
|
|
39
|
+
examples from an earlier run (or a warm-start run) into `experiment.state`,
|
|
40
|
+
so the runner can skip processing them again.
|
|
41
|
+
Subclasses should implement `_list_checkpoint_files` to identify
|
|
42
|
+
checkpoint files to load, and `_save_example` to save a newly processed
|
|
43
|
+
example.
|
|
44
|
+
"""
|
|
33
45
|
|
|
34
46
|
checkpoint_filename: Annotated[
|
|
35
47
|
str,
|
|
36
48
|
'Checkpoint file pattern.'
|
|
37
|
-
] = 'checkpoint.
|
|
49
|
+
] = 'checkpoint.jsonl'
|
|
50
|
+
|
|
51
|
+
enable_inprogress_file: Annotated[
|
|
52
|
+
bool,
|
|
53
|
+
'If True, write file "<example_id>.inprogress" when example gets started.'
|
|
54
|
+
] = True
|
|
55
|
+
|
|
56
|
+
max_ckpt_loading_threads: Annotated[
|
|
57
|
+
int,
|
|
58
|
+
'Max number of workers for loading checkpoint files at startup.'
|
|
59
|
+
] = 128
|
|
38
60
|
|
|
39
61
|
def on_experiment_start(
|
|
40
62
|
self,
|
|
@@ -75,6 +97,24 @@ class Checkpointer(experiment_lib.Plugin):
|
|
|
75
97
|
f'scratch. Example IDs: {example_ids_to_evaluate}.'
|
|
76
98
|
)
|
|
77
99
|
|
|
100
|
+
def on_example_start(
|
|
101
|
+
self,
|
|
102
|
+
runner: Runner,
|
|
103
|
+
experiment: Experiment,
|
|
104
|
+
example: Example,
|
|
105
|
+
) -> None:
|
|
106
|
+
"""Saves the example to the checkpoint file."""
|
|
107
|
+
if self.enable_inprogress_file:
|
|
108
|
+
def _save_inprogress_file(example: Example):
|
|
109
|
+
inprogress_file = runner.current_run.output_path_for(
|
|
110
|
+
experiment, f'{example.id}.inprogress'
|
|
111
|
+
)
|
|
112
|
+
pg.io.writefile(
|
|
113
|
+
inprogress_file,
|
|
114
|
+
datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
|
115
|
+
)
|
|
116
|
+
runner.background_run(_save_inprogress_file, example)
|
|
117
|
+
|
|
78
118
|
def on_example_complete(
|
|
79
119
|
self,
|
|
80
120
|
runner: Runner,
|
|
@@ -91,7 +131,7 @@ class Checkpointer(experiment_lib.Plugin):
|
|
|
91
131
|
experiment: Experiment,
|
|
92
132
|
) -> None:
|
|
93
133
|
"""Creates the checkpoint file."""
|
|
94
|
-
ckpt_files = self.
|
|
134
|
+
ckpt_files = self._list_checkpoint_files(runner, experiment)
|
|
95
135
|
experiment.info(f'Found {len(ckpt_files)} checkpoint files to load.')
|
|
96
136
|
|
|
97
137
|
# Load the checkpoint files in parallel.
|
|
@@ -101,18 +141,18 @@ class Checkpointer(experiment_lib.Plugin):
|
|
|
101
141
|
experiment
|
|
102
142
|
)
|
|
103
143
|
context = dict(counter=0, counter_lock=threading.Lock())
|
|
104
|
-
copy_ckpt = current_run.input_root != current_run.output_root
|
|
105
144
|
|
|
106
145
|
def _load_state(ckpt_file):
|
|
107
146
|
error = None
|
|
108
147
|
with pg.timeit() as t:
|
|
109
148
|
try:
|
|
110
|
-
experiment.load_state(
|
|
111
|
-
|
|
149
|
+
loaded_examples = experiment.load_state(
|
|
150
|
+
ckpt_file,
|
|
112
151
|
filter=lambda x: x.id in examples_to_load,
|
|
113
152
|
load_example_metadata=lambda x: x.id in examples_to_load_metadata,
|
|
114
153
|
)
|
|
115
154
|
except BaseException as e: # pylint: disable=broad-except
|
|
155
|
+
loaded_examples = []
|
|
116
156
|
error = e
|
|
117
157
|
finally:
|
|
118
158
|
with context['counter_lock']:
|
|
@@ -130,34 +170,33 @@ class Checkpointer(experiment_lib.Plugin):
|
|
|
130
170
|
f'Skipping the file. ({progress_str})'
|
|
131
171
|
)
|
|
132
172
|
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
except BaseException as e: # pylint: disable=broad-except
|
|
146
|
-
experiment.warning(
|
|
147
|
-
f'Failed to copy checkpoint {ckpt_file!r}: {e}.'
|
|
148
|
-
)
|
|
173
|
+
output_ckpt_file = current_run.output_path_for(
|
|
174
|
+
experiment, os.path.basename(ckpt_file)
|
|
175
|
+
)
|
|
176
|
+
if ckpt_file != output_ckpt_file and any(
|
|
177
|
+
e for e in loaded_examples if not e.has_error
|
|
178
|
+
):
|
|
179
|
+
# Write the error-free warm-start examples to the output checkpoint
|
|
180
|
+
# file.
|
|
181
|
+
with SequenceWriter(output_ckpt_file) as writer:
|
|
182
|
+
for example in loaded_examples:
|
|
183
|
+
if not example.has_error:
|
|
184
|
+
writer.add(example)
|
|
149
185
|
|
|
150
186
|
_ = list(
|
|
151
187
|
lf.concurrent_map(
|
|
152
|
-
_load_state,
|
|
188
|
+
_load_state,
|
|
189
|
+
ckpt_files,
|
|
190
|
+
max_workers=self.max_ckpt_loading_threads,
|
|
191
|
+
silence_on_errors=None
|
|
153
192
|
)
|
|
154
193
|
)
|
|
155
194
|
|
|
156
195
|
@abc.abstractmethod
|
|
157
|
-
def
|
|
196
|
+
def _list_checkpoint_files(
|
|
158
197
|
self, runner: Runner, experiment: Experiment
|
|
159
198
|
) -> list[str]:
|
|
160
|
-
"""Lists the checkpoint
|
|
199
|
+
"""Lists the checkpoint file paths to restore."""
|
|
161
200
|
|
|
162
201
|
@abc.abstractmethod
|
|
163
202
|
def _save_example(
|
|
@@ -170,7 +209,12 @@ class Checkpointer(experiment_lib.Plugin):
|
|
|
170
209
|
|
|
171
210
|
|
|
172
211
|
class PerExampleCheckpointer(Checkpointer):
|
|
173
|
-
"""Checkpointer that saves each example to a separate file.
|
|
212
|
+
"""Checkpointer that saves each example to a separate file.
|
|
213
|
+
|
|
214
|
+
This checkpointer saves each processed example to its own checkpoint file,
|
|
215
|
+
named using the pattern `<checkpoint_filename_prefix>_<example_id>.<ext>`.
|
|
216
|
+
For example, `checkpoint_1.bagz`, `checkpoint_2.bagz`, etc.
|
|
217
|
+
"""
|
|
174
218
|
|
|
175
219
|
def _on_bound(self):
|
|
176
220
|
super()._on_bound()
|
|
@@ -178,22 +222,41 @@ class PerExampleCheckpointer(Checkpointer):
|
|
|
178
222
|
self._checkpoint_file_prefix = prefix
|
|
179
223
|
self._checkpoint_file_ext = ext
|
|
180
224
|
|
|
181
|
-
def
|
|
225
|
+
def _list_checkpoint_files(
|
|
182
226
|
self, runner: Runner, experiment: Experiment
|
|
183
227
|
) -> list[str]:
|
|
184
|
-
|
|
185
|
-
|
|
228
|
+
|
|
229
|
+
def _list_checkpoints_from(ckpt_dir: str, examples_to_load: set[int]):
|
|
230
|
+
ckpt_files = []
|
|
231
|
+
if pg.io.path_exists(ckpt_dir):
|
|
232
|
+
regex = re.compile(
|
|
233
|
+
f'{self._checkpoint_file_prefix}_(\\d+){self._checkpoint_file_ext}'
|
|
234
|
+
.replace('.', '\\.')
|
|
235
|
+
)
|
|
236
|
+
for filename in pg.io.listdir(ckpt_dir):
|
|
237
|
+
match = regex.match(filename)
|
|
238
|
+
if match and int(match.group(1)) in examples_to_load:
|
|
239
|
+
examples_to_load.remove(int(match.group(1)))
|
|
240
|
+
ckpt_files.append(os.path.join(ckpt_dir, filename))
|
|
241
|
+
return ckpt_files
|
|
242
|
+
|
|
186
243
|
examples_to_load = runner.current_run.examples_to_load(experiment)
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
244
|
+
|
|
245
|
+
# Take output directory as the first priority to checkpoints processed in
|
|
246
|
+
# this run.
|
|
247
|
+
ckpt_files = _list_checkpoints_from(
|
|
248
|
+
runner.current_run.output_dir(experiment), examples_to_load
|
|
249
|
+
)
|
|
250
|
+
# If the input and output directories are different, also load from the
|
|
251
|
+
# input directory.
|
|
252
|
+
if (examples_to_load
|
|
253
|
+
and runner.current_run.input_root != runner.current_run.output_root):
|
|
254
|
+
ckpt_files.extend(
|
|
255
|
+
_list_checkpoints_from(
|
|
256
|
+
runner.current_run.input_dir(experiment), examples_to_load
|
|
257
|
+
)
|
|
191
258
|
)
|
|
192
|
-
|
|
193
|
-
match = regex.match(filename)
|
|
194
|
-
if match and int(match.group(1)) in examples_to_load:
|
|
195
|
-
filenames.append(filename)
|
|
196
|
-
return filenames
|
|
259
|
+
return ckpt_files
|
|
197
260
|
|
|
198
261
|
def _save_example(
|
|
199
262
|
self,
|
|
@@ -235,7 +298,13 @@ class PerExampleCheckpointer(Checkpointer):
|
|
|
235
298
|
|
|
236
299
|
|
|
237
300
|
class BulkCheckpointer(Checkpointer):
|
|
238
|
-
"""Checkpointer that saves all examples to a single file.
|
|
301
|
+
"""Checkpointer that saves all examples of an evaluation to a single file.
|
|
302
|
+
|
|
303
|
+
This checkpointer appends newly processed examples of an evaluation to a
|
|
304
|
+
single sequence file (e.g., `checkpoint.bagz`). This is often more efficient
|
|
305
|
+
than `PerExampleCheckpointer` when dealing with a large number of examples
|
|
306
|
+
or when file system overhead is a concern.
|
|
307
|
+
"""
|
|
239
308
|
|
|
240
309
|
def _on_bound(self):
|
|
241
310
|
super()._on_bound()
|
|
@@ -287,13 +356,24 @@ class BulkCheckpointer(Checkpointer):
|
|
|
287
356
|
if self._sequence_writer is not None:
|
|
288
357
|
self._sequence_writer[experiment.id] = sequence_writer
|
|
289
358
|
|
|
290
|
-
def
|
|
359
|
+
def _list_checkpoint_files(
|
|
291
360
|
self, runner: Runner, experiment: Experiment
|
|
292
361
|
) -> list[str]:
|
|
293
|
-
if
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
362
|
+
# Always honor the output directory if it's present, as it contains both
|
|
363
|
+
# the warm-started examples and newly processed examples.
|
|
364
|
+
output_ckpt_file = runner.current_run.output_path_for(
|
|
365
|
+
experiment, self.checkpoint_filename
|
|
366
|
+
)
|
|
367
|
+
if pg.io.path_exists(output_ckpt_file):
|
|
368
|
+
return [output_ckpt_file]
|
|
369
|
+
|
|
370
|
+
if runner.current_run.input_root != runner.current_run.output_root:
|
|
371
|
+
input_ckpt_file = runner.current_run.input_path_for(
|
|
372
|
+
experiment, self.checkpoint_filename
|
|
373
|
+
)
|
|
374
|
+
if pg.io.path_exists(input_ckpt_file):
|
|
375
|
+
return [input_ckpt_file]
|
|
376
|
+
print('CCC', experiment.hash, [])
|
|
297
377
|
return []
|
|
298
378
|
|
|
299
379
|
def on_experiment_complete(
|
|
@@ -341,12 +421,26 @@ class BulkCheckpointer(Checkpointer):
|
|
|
341
421
|
|
|
342
422
|
|
|
343
423
|
class SequenceWriter:
|
|
344
|
-
"""
|
|
424
|
+
"""A thread-safe writer for sequence files (e.g., Bagz) with atomic write.
|
|
425
|
+
|
|
426
|
+
`SequenceWriter` wraps a `pg.io.SequenceWriter` to provide thread-safe
|
|
427
|
+
`add` and `close` operations, ensuring that examples can be written
|
|
428
|
+
concurrently from multiple threads without corrupting the sequence file.
|
|
429
|
+
It writes to a temporary file and renames it to target path on `close` to
|
|
430
|
+
achieve atomic write. If the target path exists, new examples are appended
|
|
431
|
+
to existing content.
|
|
432
|
+
"""
|
|
345
433
|
|
|
346
434
|
def __init__(self, path: str):
|
|
347
435
|
self._lock = threading.Lock()
|
|
348
436
|
self._path = path
|
|
349
|
-
|
|
437
|
+
basename = os.path.basename(path)
|
|
438
|
+
self._tmp_path = os.path.join(
|
|
439
|
+
os.path.dirname(path), f'tmp.{basename}'
|
|
440
|
+
)
|
|
441
|
+
if pg.io.path_exists(self._path):
|
|
442
|
+
pg.io.copy(self._path, self._tmp_path)
|
|
443
|
+
self._sequence_writer = pg.io.open_sequence(self._tmp_path, 'a')
|
|
350
444
|
|
|
351
445
|
@property
|
|
352
446
|
def path(self) -> str:
|
|
@@ -371,6 +465,14 @@ class SequenceWriter:
|
|
|
371
465
|
return
|
|
372
466
|
self._sequence_writer.close()
|
|
373
467
|
self._sequence_writer = None
|
|
468
|
+
pg.io.rename(self._tmp_path, self._path)
|
|
469
|
+
|
|
470
|
+
def __enter__(self):
|
|
471
|
+
return self
|
|
472
|
+
|
|
473
|
+
def __exit__(self, *args, **kwargs):
|
|
474
|
+
del args, kwargs
|
|
475
|
+
self.close()
|
|
374
476
|
|
|
375
477
|
def __del__(self):
|
|
376
478
|
self.close()
|
|
@@ -65,7 +65,7 @@ class ExampleCollector(experiment_lib.Plugin):
|
|
|
65
65
|
return self._examples
|
|
66
66
|
|
|
67
67
|
def on_example_complete(
|
|
68
|
-
self, runner:
|
|
68
|
+
self, runner: experiment_lib.Runner,
|
|
69
69
|
experiment: experiment_lib.Experiment,
|
|
70
70
|
example: example_lib.Example,
|
|
71
71
|
):
|
|
@@ -90,7 +90,10 @@ class PerExampleCheckpointerTest(CheckpointerTest):
|
|
|
90
90
|
root_dir = os.path.join(tempfile.mkdtemp(), 'per_example_checkpointer')
|
|
91
91
|
experiment = eval_test_helper.test_experiment()
|
|
92
92
|
checkpoint_filename = 'checkpoint.jsonl'
|
|
93
|
-
checkpointer = checkpointing.PerExampleCheckpointer(
|
|
93
|
+
checkpointer = checkpointing.PerExampleCheckpointer(
|
|
94
|
+
checkpoint_filename,
|
|
95
|
+
enable_inprogress_file=True
|
|
96
|
+
)
|
|
94
97
|
collector = ExampleCollector()
|
|
95
98
|
run = experiment.run(
|
|
96
99
|
root_dir, 'new', runner='sequential', plugins=[checkpointer, collector]
|
|
@@ -102,6 +105,10 @@ class PerExampleCheckpointerTest(CheckpointerTest):
|
|
|
102
105
|
example = collector.examples[i + 1]
|
|
103
106
|
ckpt = run.output_path_for(leaf, f'checkpoint_{example.id}.jsonl')
|
|
104
107
|
self.assertTrue(pg.io.path_exists(ckpt))
|
|
108
|
+
inprogress_file = run.output_path_for(
|
|
109
|
+
leaf, f'{example.id}.inprogress'
|
|
110
|
+
)
|
|
111
|
+
self.assertTrue(pg.io.path_exists(inprogress_file))
|
|
105
112
|
with pg.io.open_sequence(ckpt) as f:
|
|
106
113
|
examples_from_ckpt = list(iter(f))
|
|
107
114
|
# `eval_test_helper.test_experiment` has two TestEvaluation with
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Config saver plugins."""
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
from langfun.core.eval.v2 import experiment as experiment_lib
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class RunConfigSaver(experiment_lib.Plugin):
|
|
21
|
+
"""Saves the current run."""
|
|
22
|
+
|
|
23
|
+
def on_run_start(
|
|
24
|
+
self,
|
|
25
|
+
runner: experiment_lib.Runner,
|
|
26
|
+
root: experiment_lib.Experiment
|
|
27
|
+
) -> None:
|
|
28
|
+
del root # Unused.
|
|
29
|
+
self._save_run_config(runner)
|
|
30
|
+
|
|
31
|
+
def _save_run_config(self, runner: experiment_lib.Runner) -> None:
|
|
32
|
+
def _save():
|
|
33
|
+
runner.current_run.save(
|
|
34
|
+
os.path.join(runner.current_run.output_root, 'run.json'),
|
|
35
|
+
hide_default_values=True,
|
|
36
|
+
)
|
|
37
|
+
runner.background_run(_save)
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Config saver test."""
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import tempfile
|
|
18
|
+
import unittest
|
|
19
|
+
from langfun.core.eval.v2 import config_saver
|
|
20
|
+
from langfun.core.eval.v2 import eval_test_helper
|
|
21
|
+
from langfun.core.eval.v2.runners import parallel # pylint: disable=unused-import
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class RunConfigSaverTest(unittest.TestCase):
|
|
25
|
+
|
|
26
|
+
def test_save_run_config(self):
|
|
27
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'test_run_config_saver')
|
|
28
|
+
experiment = eval_test_helper.test_evaluation()
|
|
29
|
+
run = experiment.run(
|
|
30
|
+
root_dir, 'new', plugins=[config_saver.RunConfigSaver()]
|
|
31
|
+
)
|
|
32
|
+
self.assertTrue(os.path.exists(os.path.join(run.output_root, 'run.json')))
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
if __name__ == '__main__':
|
|
36
|
+
unittest.main()
|
|
@@ -13,6 +13,9 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
"""Helper classes and functions for evaluation tests."""
|
|
15
15
|
|
|
16
|
+
import threading
|
|
17
|
+
import time
|
|
18
|
+
|
|
16
19
|
from langfun.core import language_model
|
|
17
20
|
from langfun.core import llms
|
|
18
21
|
from langfun.core import message as message_lib
|
|
@@ -47,6 +50,8 @@ class TestLLM(llms.Fake):
|
|
|
47
50
|
|
|
48
51
|
offset: int = 0
|
|
49
52
|
|
|
53
|
+
__test__ = False
|
|
54
|
+
|
|
50
55
|
def _response_from(self, prompt: message_lib.Message) -> message_lib.Message:
|
|
51
56
|
return message_lib.AIMessage(
|
|
52
57
|
str(prompt.metadata.x + prompt.metadata.y + self.offset)
|
|
@@ -63,6 +68,8 @@ class TestEvaluation(Evaluation):
|
|
|
63
68
|
metrics = [metrics_lib.Match()]
|
|
64
69
|
lm: language_model.LanguageModel = TestLLM()
|
|
65
70
|
|
|
71
|
+
__test__ = False
|
|
72
|
+
|
|
66
73
|
def process(self, example):
|
|
67
74
|
v = example.input
|
|
68
75
|
if v.x == 5:
|
|
@@ -75,7 +82,7 @@ class TestEvaluation(Evaluation):
|
|
|
75
82
|
|
|
76
83
|
class BadJsonConvertible(pg.Object):
|
|
77
84
|
|
|
78
|
-
def
|
|
85
|
+
def sym_jsonify(self, *args, **kwargs):
|
|
79
86
|
raise ValueError('Cannot convert to JSON.')
|
|
80
87
|
|
|
81
88
|
|
|
@@ -84,6 +91,8 @@ class TestEvaluationWithExampleCheckpointingError(TestEvaluation):
|
|
|
84
91
|
inputs = test_inputs()
|
|
85
92
|
metrics = [metrics_lib.Match()]
|
|
86
93
|
|
|
94
|
+
__test__ = False
|
|
95
|
+
|
|
87
96
|
def process(self, example):
|
|
88
97
|
return 1, dict(
|
|
89
98
|
x=BadJsonConvertible()
|
|
@@ -101,6 +110,8 @@ class TestEvaluationWithExampleHtmlGenerationError(Evaluation):
|
|
|
101
110
|
inputs = test_inputs()
|
|
102
111
|
metrics = [metrics_lib.Match()]
|
|
103
112
|
|
|
113
|
+
__test__ = False
|
|
114
|
+
|
|
104
115
|
def process(self, example):
|
|
105
116
|
return 1, dict(
|
|
106
117
|
x=BadHtmlConvertible()
|
|
@@ -110,15 +121,22 @@ class TestEvaluationWithExampleHtmlGenerationError(Evaluation):
|
|
|
110
121
|
class TestEvaluationWithIndexHtmlGenerationError(TestEvaluation):
|
|
111
122
|
"""Test evaluation class with bad index HTML generation."""
|
|
112
123
|
|
|
124
|
+
__test__ = False
|
|
125
|
+
|
|
113
126
|
def _html_tree_view(self, *args, **kwargs):
|
|
114
127
|
raise ValueError('Cannot render HTML.')
|
|
115
128
|
|
|
116
129
|
|
|
130
|
+
def test_evaluation(offset: int | pg.hyper.OneOf = 0):
|
|
131
|
+
"""Returns a test evaluation."""
|
|
132
|
+
return TestEvaluation(lm=TestLLM(offset=offset))
|
|
133
|
+
|
|
134
|
+
|
|
117
135
|
def test_experiment():
|
|
118
136
|
"""Returns a test experiment."""
|
|
119
137
|
return Suite([
|
|
120
|
-
|
|
121
|
-
|
|
138
|
+
test_evaluation(),
|
|
139
|
+
test_evaluation(pg.oneof(range(5))),
|
|
122
140
|
])
|
|
123
141
|
|
|
124
142
|
|
|
@@ -135,3 +153,86 @@ def test_experiment_with_example_html_generation_error():
|
|
|
135
153
|
def test_experiment_with_index_html_generation_error():
|
|
136
154
|
"""Returns a test experiment with bad index HTML."""
|
|
137
155
|
return TestEvaluationWithIndexHtmlGenerationError()
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class TestPlugin(experiment_lib.Plugin):
|
|
159
|
+
"""Plugin for testing."""
|
|
160
|
+
|
|
161
|
+
started_experiments: list[experiment_lib.Experiment] = []
|
|
162
|
+
completed_experiments: list[experiment_lib.Experiment] = []
|
|
163
|
+
skipped_experiments: list[experiment_lib.Experiment] = []
|
|
164
|
+
started_example_ids: list[int] = []
|
|
165
|
+
completed_example_ids: list[int] = []
|
|
166
|
+
start_time: float | None = None
|
|
167
|
+
complete_time: float | None = None
|
|
168
|
+
|
|
169
|
+
__test__ = False
|
|
170
|
+
|
|
171
|
+
def _on_bound(self):
|
|
172
|
+
super()._on_bound()
|
|
173
|
+
self._lock = threading.Lock()
|
|
174
|
+
|
|
175
|
+
def on_run_start(
|
|
176
|
+
self,
|
|
177
|
+
runner: experiment_lib.Runner,
|
|
178
|
+
root: experiment_lib.Experiment
|
|
179
|
+
) -> None:
|
|
180
|
+
del root
|
|
181
|
+
with pg.notify_on_change(False), pg.allow_writable_accessors(True):
|
|
182
|
+
self.start_time = time.time()
|
|
183
|
+
|
|
184
|
+
def on_run_complete(
|
|
185
|
+
self,
|
|
186
|
+
runner: experiment_lib.Runner,
|
|
187
|
+
root: experiment_lib.Experiment
|
|
188
|
+
) -> None:
|
|
189
|
+
del root
|
|
190
|
+
with pg.notify_on_change(False), pg.allow_writable_accessors(True):
|
|
191
|
+
self.complete_time = time.time()
|
|
192
|
+
|
|
193
|
+
def on_experiment_start(
|
|
194
|
+
self,
|
|
195
|
+
runner: experiment_lib.Runner,
|
|
196
|
+
experiment: experiment_lib.Experiment
|
|
197
|
+
) -> None:
|
|
198
|
+
del runner
|
|
199
|
+
with pg.notify_on_change(False), self._lock:
|
|
200
|
+
self.started_experiments.append(pg.Ref(experiment))
|
|
201
|
+
|
|
202
|
+
def on_experiment_skipped(
|
|
203
|
+
self,
|
|
204
|
+
runner: experiment_lib.Runner,
|
|
205
|
+
experiment: experiment_lib.Experiment
|
|
206
|
+
) -> None:
|
|
207
|
+
del runner
|
|
208
|
+
with pg.notify_on_change(False), self._lock:
|
|
209
|
+
self.skipped_experiments.append(pg.Ref(experiment))
|
|
210
|
+
|
|
211
|
+
def on_experiment_complete(
|
|
212
|
+
self,
|
|
213
|
+
runner: experiment_lib.Runner,
|
|
214
|
+
experiment: experiment_lib.Experiment
|
|
215
|
+
) -> None:
|
|
216
|
+
del runner
|
|
217
|
+
with pg.notify_on_change(False), self._lock:
|
|
218
|
+
self.completed_experiments.append(pg.Ref(experiment))
|
|
219
|
+
|
|
220
|
+
def on_example_start(
|
|
221
|
+
self,
|
|
222
|
+
runner: experiment_lib.Runner,
|
|
223
|
+
experiment: experiment_lib.Experiment,
|
|
224
|
+
example: Example
|
|
225
|
+
) -> None:
|
|
226
|
+
del runner, experiment
|
|
227
|
+
with pg.notify_on_change(False), self._lock:
|
|
228
|
+
self.started_example_ids.append(example.id)
|
|
229
|
+
|
|
230
|
+
def on_example_complete(
|
|
231
|
+
self,
|
|
232
|
+
runner: experiment_lib.Runner,
|
|
233
|
+
experiment: experiment_lib.Experiment,
|
|
234
|
+
example: Example
|
|
235
|
+
) -> None:
|
|
236
|
+
del runner, experiment
|
|
237
|
+
with pg.notify_on_change(False), self._lock:
|
|
238
|
+
self.completed_example_ids.append(example.id)
|