langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512150805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +1 -1
- langfun/core/__init__.py +7 -1
- langfun/core/agentic/__init__.py +8 -1
- langfun/core/agentic/action.py +740 -112
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +189 -24
- langfun/core/async_support.py +104 -5
- langfun/core/async_support_test.py +23 -0
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +9 -2
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +11 -2
- langfun/core/data/conversion/gemini_test.py +48 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +48 -44
- langfun/core/eval/base_test.py +5 -5
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +3 -0
- langfun/core/eval/v2/checkpointing.py +148 -46
- langfun/core/eval/v2/checkpointing_test.py +9 -2
- langfun/core/eval/v2/config_saver.py +37 -0
- langfun/core/eval/v2/config_saver_test.py +36 -0
- langfun/core/eval/v2/eval_test_helper.py +104 -3
- langfun/core/eval/v2/evaluation.py +102 -19
- langfun/core/eval/v2/evaluation_test.py +9 -3
- langfun/core/eval/v2/example.py +50 -40
- langfun/core/eval/v2/example_test.py +16 -8
- langfun/core/eval/v2/experiment.py +95 -20
- langfun/core/eval/v2/experiment_test.py +19 -0
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +157 -44
- langfun/core/eval/v2/metrics_test.py +39 -18
- langfun/core/eval/v2/progress.py +31 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking.py +13 -5
- langfun/core/eval/v2/progress_tracking_test.py +9 -1
- langfun/core/eval/v2/reporting.py +88 -71
- langfun/core/eval/v2/reporting_test.py +24 -6
- langfun/core/eval/v2/runners/__init__.py +30 -0
- langfun/core/eval/v2/{runners.py → runners/base.py} +73 -180
- langfun/core/eval/v2/runners/beam.py +354 -0
- langfun/core/eval/v2/runners/beam_test.py +153 -0
- langfun/core/eval/v2/runners/ckpt_monitor.py +350 -0
- langfun/core/eval/v2/runners/ckpt_monitor_test.py +213 -0
- langfun/core/eval/v2/runners/debug.py +40 -0
- langfun/core/eval/v2/runners/debug_test.py +76 -0
- langfun/core/eval/v2/runners/parallel.py +243 -0
- langfun/core/eval/v2/runners/parallel_test.py +182 -0
- langfun/core/eval/v2/runners/sequential.py +47 -0
- langfun/core/eval/v2/runners/sequential_test.py +169 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +7 -5
- langfun/core/language_model.py +189 -36
- langfun/core/language_model_test.py +54 -3
- langfun/core/llms/__init__.py +14 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +32 -1
- langfun/core/llms/gemini.py +90 -12
- langfun/core/llms/gemini_test.py +110 -0
- langfun/core/llms/google_genai.py +52 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +120 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +16 -1
- langfun/core/llms/vertexai.py +78 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/__init__.py +10 -0
- langfun/core/mcp/client.py +177 -0
- langfun/core/mcp/client_test.py +71 -0
- langfun/core/mcp/session.py +241 -0
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/testing/simple_mcp_client.py +33 -0
- langfun/core/mcp/testing/simple_mcp_server.py +33 -0
- langfun/core/mcp/tool.py +254 -0
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +73 -3
- langfun/core/modalities/image_test.py +116 -0
- langfun/core/modalities/mime.py +78 -4
- langfun/core/modalities/mime_test.py +59 -0
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/__init__.py +2 -24
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +81 -37
- langfun/core/structured/parsing.py +95 -79
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +230 -154
- langfun/core/structured/querying_test.py +69 -33
- langfun/core/structured/schema/__init__.py +49 -0
- langfun/core/structured/schema/base.py +664 -0
- langfun/core/structured/schema/base_test.py +531 -0
- langfun/core/structured/schema/json.py +174 -0
- langfun/core/structured/schema/json_test.py +121 -0
- langfun/core/structured/schema/python.py +316 -0
- langfun/core/structured/schema/python_test.py +410 -0
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +47 -36
- langfun/core/structured/tokenization.py +26 -11
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +175 -50
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +43 -0
- langfun/env/base_environment.py +827 -0
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +304 -0
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +842 -0
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +14 -0
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +472 -0
- langfun/env/event_handlers/event_logger_test.py +304 -0
- langfun/env/event_handlers/metric_writer.py +726 -0
- langfun/env/event_handlers/metric_writer_test.py +214 -0
- langfun/env/interface.py +1640 -0
- langfun/env/interface_test.py +153 -0
- langfun/env/load_balancers.py +59 -0
- langfun/env/load_balancers_test.py +141 -0
- langfun/env/test_utils.py +507 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/METADATA +7 -3
- langfun-0.1.2.dev202512150805.dist-info/RECORD +217 -0
- langfun/core/eval/v2/runners_test.py +0 -343
- langfun/core/structured/schema.py +0 -987
- langfun/core/structured/schema_test.py +0 -982
- langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/top_level.txt +0 -0
langfun/core/eval/v2/runners/ckpt_monitor.py
@@ -0,0 +1,350 @@
# Copyright 2025 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpoint aggregator for Langfun evaluations."""

import concurrent.futures
import dataclasses
import os
import threading
import time
from typing import Annotated, Iterator

from langfun.core.eval.v2 import evaluation as evaluation_lib
from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import reporting
from langfun.core.eval.v2.runners import base

import pyglove as pg


class CheckpointMonitor(base.RunnerBase):
  """Runner for monitoring checkpoint files generated by other runners.

  Currently the checkpoint monitor only supports aggregating per-example
  checkpoint files.
  """

  NAME = 'checkpoint_monitor'

  plugins = [
      reporting.HtmlReporter(),
  ]

  checkpoint_pattern: Annotated[
      str, 'The glob pattern of the checkpoint files to monitor.'
  ] = 'checkpoint_*.bagz'

  monitor_inprogress_files: Annotated[
      bool,
      'If True, monitor in-progress files to aggregate.'
  ] = False

  poll_interval: Annotated[
      int,
      'The interval in seconds to poll for new checkpoint files.'
  ] = 5

  max_aggregation_threads: Annotated[
      int,
      'The maximum number of threads to aggregate checkpoints.'
  ] = 128

  bypass_old_ckpt_files_with_non_oop_errors: Annotated[
      bool,
      'If True, ignore old checkpoint files with non-oop errors.'
  ] = True

  ckpt_start_time: Annotated[
      float | None,
      (
          'The timestamp to treat checkpoint files modified before this '
          'time as old.'
      )
  ] = None

  @dataclasses.dataclass
  class _AggregationEntry:
    evaluation: evaluation_lib.Evaluation
    output_dir: str
    inprogress_file_pattern: str | None
    ckpt_file_pattern: str
    example_ids_inprogress: set[int]
    example_ids_to_be_aggregated: set[int]
    example_ids_being_aggregated: set[int]
    completion_lock: threading.Lock
    is_completed: bool = False

  def _on_bound(self):
    super()._on_bound()
    self._monitor_thread = None
    self._aggregation_entries = []
    self._aggregator_pool = None
    self._error = None
    if self.ckpt_start_time is None:
      self.rebind(ckpt_start_time=time.time(), skip_notification=True)
    self._ckpt_bypass_timestamp: dict[str, int] = {}

  def start(self):
    # Reset the experiment state before getting started.
    self.current_run.experiment.reset()

    # Signal the start of the run.
    self.on_run_start()

    # Start the non-leaf nodes.
    for node in self.current_run.experiment.nonleaf_nodes:
      self.on_experiment_start(node)

    for evaluation in self.current_run.experiment.leaf_nodes:
      # This is not precise, but we at least notify example start.
      if not self.current_run.filter or self.current_run.filter(evaluation):
        self.on_experiment_start(evaluation)

        # Signal the start of the examples if we are not monitoring
        # in-progress files.
        if not self.monitor_inprogress_files:
          for example_id in self.current_run.examples_to_evaluate(evaluation):
            self._mark_example_started(evaluation, example_id)

        # Create the aggregation entries for polling.
        output_dir = self.current_run.output_dir(evaluation)
        self._aggregation_entries.append(
            self._AggregationEntry(
                evaluation=evaluation,
                output_dir=output_dir,
                ckpt_file_pattern=os.path.join(
                    output_dir, self.checkpoint_pattern
                ),
                inprogress_file_pattern=os.path.join(
                    output_dir, '*.inprogress'
                ) if self.monitor_inprogress_files else None,
                example_ids_to_be_aggregated=(
                    self.current_run.examples_to_evaluate(evaluation)
                ),
                example_ids_inprogress=set(),
                example_ids_being_aggregated=set(),
                completion_lock=threading.Lock(),
                is_completed=False,
            )
        )
      else:
        self.on_experiment_skipped(evaluation)

    self._aggregator_pool = concurrent.futures.ThreadPoolExecutor(
        max_workers=self.max_aggregation_threads
    )
    self._monitor_thread = threading.Thread(target=self._monitor_loop)
    self._monitor_thread.start()

  def join(self):
    if self._monitor_thread:
      self._monitor_thread.join()
    if self._error is not None:
      raise self._error

  def run(self):
    self.start()
    self.join()

  def _monitor_loop(self):
    while not self._error and any(
        not e.is_completed for e in self._aggregation_entries
    ):
      for entry in self._aggregation_entries:
        if not entry.example_ids_to_be_aggregated:
          continue

        # Signal example processing.
        if self.monitor_inprogress_files:
          inprogress_files = pg.io.glob(entry.inprogress_file_pattern)
          for inprogress_file in inprogress_files:
            example_id = int(
                os.path.basename(inprogress_file).split('.')[0]
            )
            if example_id not in entry.example_ids_inprogress:
              self._mark_example_started(entry.evaluation, example_id)
              entry.example_ids_inprogress.add(example_id)

        for filepath in pg.io.glob(entry.ckpt_file_pattern):
          example_id = int(
              os.path.basename(filepath).split('.')[0].split('_')[-1]
          )
          if example_id in entry.example_ids_to_be_aggregated:
            last_modified_time = pg.io.getmtime(filepath)
            bypass_timestamp = self._ckpt_bypass_timestamp.get(filepath)
            if (
                bypass_timestamp is not None
                and last_modified_time <= bypass_timestamp
            ):
              continue

            # Remove example ID from the set to avoid duplicate processing.
            entry.example_ids_to_be_aggregated.remove(example_id)
            entry.example_ids_being_aggregated.add(example_id)

            # It could be that the example has been processed before, but the
            # inprogress file was removed. In this case, we should signal the
            # example has started before completing it.
            if example_id not in entry.example_ids_inprogress:
              self._mark_example_started(entry.evaluation, example_id)
              entry.example_ids_inprogress.add(example_id)

            self._aggregator_pool.submit(
                self._aggregate, entry, filepath, example_id,
                last_modified_time
            )
            pg.logging.info(
                '[%s] Aggregating example %d from %s...',
                entry.evaluation.id,
                example_id,
                filepath,
            )
      time.sleep(self.poll_interval)

    if self._error is None:
      self.on_run_complete()
    else:
      self.on_run_abort(self._error)

  def _aggregate(
      self,
      entry: _AggregationEntry,
      ckpt_filepath: str,
      example_id: int,
      last_modified_time: float,
  ):
    """Aggregate an example from a checkpoint file."""
    try:
      loaded_examples = entry.evaluation.state.load(
          ckpt_filepath,
          example_input_by_id=entry.evaluation.example_input_by_id,
          # Example metadata may be expensive to load, and is not used by
          # metric aggregation. Thus we do not load example metadata.
          load_example_metadata=False
      )
      assert len(loaded_examples) >= 1, loaded_examples
      # Occasionally the per-example checkpoint file may contain the same
      # example processed multiple times. We only need to aggregate the last
      # example.
      example = loaded_examples[-1]
      if (
          self.bypass_old_ckpt_files_with_non_oop_errors
          and last_modified_time < self.ckpt_start_time
          and example.error is not None
          and not example.error.tag.startswith('MappingError')
      ):
        entry.example_ids_being_aggregated.remove(example_id)
        entry.example_ids_to_be_aggregated.add(example_id)
        self._ckpt_bypass_timestamp[ckpt_filepath] = last_modified_time
        pg.logging.info(
            '[%s] Bypassing old checkpoint file with non-oop errors (%s) '
            'for example %d, last_modified_time: %s, ckpt_start_time: %s',
            entry.evaluation.id,
            ckpt_filepath,
            example_id,
            last_modified_time,
            self.ckpt_start_time,
        )
        return
    except BaseException as e:  # pylint: disable=broad-except
      error_info = pg.ErrorInfo.from_exception(e)
      pg.logging.error(
          '[%s] Failed to aggregate example %d: %s',
          entry.evaluation.id,
          example_id,
          error_info
      )
      example = example_lib.Example(
          id=example_id,
          input=entry.evaluation.example_input_by_id(example_id),
          error=error_info,
      )

    # This will skip processing but still allow metrics to be collected.
    # `process` will never be called for evaluation, thus we do not
    # need to setup/teardown evaluation.
    try:
      example = entry.evaluation.evaluate(
          example, reevaluate_upon_previous_errors=False
      )
    except BaseException as e:  # pylint: disable=broad-except
      pg.logging.error(
          '[%s] Unexpected error found during evaluating example %d from %s.',
          entry.evaluation.id,
          example_id,
          ckpt_filepath,
      )
      self._error = e
      entry.example_ids_being_aggregated.remove(example_id)
      return

    example.newly_processed = True
    pg.logging.info(
        '[%s] Successfully aggregated example %d from %s.',
        entry.evaluation.id,
        example_id,
        ckpt_filepath,
    )

    try:
      self.on_example_complete(entry.evaluation, example)
    except BaseException as e:  # pylint: disable=broad-except
      # Plugin failures should be raised to the user.
      self._error = e

    entry.example_ids_being_aggregated.remove(example_id)

    # Remove the in-progress file to indicate that the example has been
    # processed.
    try:
      pg.io.rm(os.path.join(entry.output_dir, f'{example_id}.inprogress'))
    except FileNotFoundError:
      pass

    if (not self._error
        and not entry.example_ids_to_be_aggregated
        and not entry.example_ids_being_aggregated):
      with entry.completion_lock:
        if not entry.is_completed:
          entry.is_completed = True
          try:
            self.on_experiment_complete(entry.evaluation)
          except BaseException as e:  # pylint: disable=broad-except
            # Plugin failures should be raised to the user.
            self._error = e

  def _mark_example_started(
      self,
      evaluation: evaluation_lib.Evaluation,
      example_id: int
  ) -> None:
    """Mark an example as started."""
    example = example_lib.Example(
        id=example_id, input=evaluation.example_input_by_id(example_id),
    )
    example.start_time = time.time()
    self.on_example_start(evaluation, example)

    # We update the evaluation state with the in-progress status so the
    # evaluation HTML can show remotely in-progress examples.
    evaluation.state.update(example, in_progress=True)

  def _run(self, evaluations: list[evaluation_lib.Evaluation]):
    raise NotImplementedError('Not needed in checkpoint monitor.')

  def _evaluate_items(
      self,
      evaluation: evaluation_lib.Evaluation,
      items: Iterator[example_lib.Example]
  ) -> None:
    raise NotImplementedError('Not needed in checkpoint monitor.')
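The tests below exercise the monitor end to end. As a minimal usage sketch (assuming `run` is the object returned by `exp.run(...)` with a `checkpointing.PerExampleCheckpointer` that writes `checkpoint_<id>.jsonl` files, as in those tests):

    # Aggregate per-example checkpoints produced by another runner.
    monitor = ckpt_monitor.CheckpointMonitor(
        run,
        plugins=[],                               # Plugins to notify of progress.
        checkpoint_pattern='checkpoint_*.jsonl',  # Glob for per-example checkpoint files.
        monitor_inprogress_files=True,            # Also poll '*.inprogress' markers.
    )
    monitor.run()  # start() followed by join(); re-raises any recorded error.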
langfun/core/eval/v2/runners/ckpt_monitor_test.py
@@ -0,0 +1,213 @@
# Copyright 2024 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import time
import unittest

import langfun.core as lf
from langfun.core.eval.v2 import checkpointing
from langfun.core.eval.v2 import eval_test_helper
from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import experiment as experiment_lib
from langfun.core.eval.v2.runners import ckpt_monitor
from langfun.core.eval.v2.runners import sequential  # pylint: disable=unused-import
import pyglove as pg


class CheckpointMonitorTest(unittest.TestCase):

  def setUp(self):
    super().setUp()
    self.test_dir = tempfile.mkdtemp()

  def test_aggregate(self):
    exp = eval_test_helper.test_experiment()
    root_dir = os.path.join(self.test_dir, 'test_aggregate')
    ckpt_start_time = time.time()
    run = exp.run(
        root_dir,
        runner='sequential',
        progress_tracker=None,
        plugins=[
            checkpointing.PerExampleCheckpointer(
                checkpoint_filename='checkpoint.jsonl'
            )
        ],
        use_cache='no',
    )
    # Try to corrupt one of the checkpoint files.
    pg.io.writefile(
        run.output_path_for(exp.leaf_nodes[0], 'checkpoint_1.jsonl'),
        'bad ckpt'
    )
    plugin = eval_test_helper.TestPlugin()
    monitor = ckpt_monitor.CheckpointMonitor(
        run,
        plugins=[plugin],
        checkpoint_pattern='checkpoint_*.jsonl',
        monitor_inprogress_files=True,
        ckpt_start_time=ckpt_start_time,
    )
    monitor.run()

    # Assert that the in-progress files are created and not removed.
    for entry in monitor._aggregation_entries:
      self.assertEqual(len(entry.example_ids_inprogress), 10)

    # 6 leaf nodes + 1 suite + 1 hyper.
    self.assertEqual(len(plugin.started_experiments), 6 + 2)
    self.assertEqual(len(plugin.completed_experiments), 6 + 2)
    self.assertEqual(len(plugin.started_example_ids), 10 * 6)
    self.assertEqual(len(plugin.completed_example_ids), 10 * 6)
    for e in exp.leaf_nodes:
      self.assertEqual(e.progress.num_completed, 10)

  def test_ignore_old_ckpt_files_with_non_oop_errors(self):
    exp = eval_test_helper.test_evaluation()
    root_dir = os.path.join(self.test_dir, 'test_ignore_old_ckpt_files')
    run = exp.run(
        root_dir,
        runner='sequential',
        progress_tracker=None,
        plugins=[
            checkpointing.PerExampleCheckpointer(
                checkpoint_filename='checkpoint.jsonl'
            )
        ],
        use_cache='no',
    )
    monitor = ckpt_monitor.CheckpointMonitor(
        run,
        plugins=[],
        checkpoint_pattern='checkpoint_*.jsonl',
        monitor_inprogress_files=True
    )
    monitor.start()
    time.sleep(2)
    # Example 6 is a non-oop error, we simulate a re-evaluation.
    ex = example_lib.Example(
        id=6, output=1, metric_metadata={'match': {'is_correct': True}},
        start_time=time.time() - 2, end_time=time.time(),
        usage_summary=lf.UsageSummary(),
        execution_status={
            'evaluate': pg.utils.TimeIt.Status(name='evaluate', elapse=1)
        }
    )
    with pg.io.open_sequence(
        run.output_path_for(exp, 'checkpoint_6.jsonl'),
        mode='w'
    ) as f:
      f.add(pg.to_json_str(ex))
    print(time.time(), pg.io.listdir(run.output_dir(exp)))
    monitor.join()
    self.assertEqual(exp.progress.num_processed, 10)
    self.assertEqual(exp.progress.num_completed, 10)
    self.assertEqual(exp.progress.num_failed, 0)

  def test_aggregate_with_filter(self):
    ckpt_start_time = time.time()
    exp = eval_test_helper.test_experiment()
    root_dir = os.path.join(self.test_dir, 'test_aggregate_with_filter')

    node_to_skip = exp.leaf_nodes[2]
    # Run experiment to generate checkpoint files for all examples.
    run = exp.run(
        root_dir,
        runner='sequential',
        filter=lambda e: e.id != node_to_skip.id,
        progress_tracker=None,
        plugins=[
            checkpointing.PerExampleCheckpointer(
                checkpoint_filename='checkpoint.jsonl'
            )
        ],
        use_cache='no',
    )
    plugin = eval_test_helper.TestPlugin()
    monitor = ckpt_monitor.CheckpointMonitor(
        run,
        plugins=[plugin],
        checkpoint_pattern='checkpoint_*.jsonl',
        ckpt_start_time=ckpt_start_time,
    )
    monitor.run()

    # Assert that on_experiment_skipped was called for the filtered node.
    self.assertEqual(len(plugin.skipped_experiments), 1)
    self.assertEqual(plugin.skipped_experiments[0].id, node_to_skip.id)

    # Assert that the skipped node was not started.
    started_ids = [e.id for e in plugin.started_experiments]
    self.assertNotIn(node_to_skip.id, started_ids)

  def test_plugin_raise(self):

    class TestPlugin(eval_test_helper.TestPlugin):
      simulate_raise_on_example_complete: bool = False
      simulate_raise_on_experiment_complete: bool = False

      def on_example_complete(
          self,
          runner: experiment_lib.Runner,
          experiment: experiment_lib.Experiment,
          example: example_lib.Example
      ):
        if self.simulate_raise_on_example_complete:
          raise ValueError('example complete error')

      def on_experiment_complete(
          self,
          runner: experiment_lib.Runner,
          experiment: experiment_lib.Experiment
      ):
        if self.simulate_raise_on_experiment_complete:
          raise ValueError('experiment complete error')

    ckpt_start_time = time.time()
    exp = eval_test_helper.test_evaluation()
    root_dir = os.path.join(self.test_dir, 'test_plugin_raise')

    # Run experiment to generate checkpoint files for all examples.
    run = exp.run(
        root_dir,
        runner='sequential',
        progress_tracker=None,
        plugins=[
            checkpointing.PerExampleCheckpointer(
                checkpoint_filename='checkpoint.jsonl'
            )
        ],
        use_cache='no',
    )

    with self.assertRaisesRegex(ValueError, 'example complete error'):
      ckpt_monitor.CheckpointMonitor(
          run,
          plugins=[TestPlugin(simulate_raise_on_example_complete=True)],
          checkpoint_pattern='checkpoint_*.jsonl',
          ckpt_start_time=ckpt_start_time,
      ).run()

    with self.assertRaisesRegex(ValueError, 'experiment complete error'):
      ckpt_monitor.CheckpointMonitor(
          run,
          plugins=[TestPlugin(simulate_raise_on_experiment_complete=True)],
          checkpoint_pattern='checkpoint_*.jsonl',
          ckpt_start_time=ckpt_start_time,
      ).run()


if __name__ == '__main__':
  unittest.main()
langfun/core/eval/v2/runners/debug.py
@@ -0,0 +1,40 @@
# Copyright 2024 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Debug runner."""

from langfun.core.eval.v2.runners import sequential


class DebugRunner(sequential.SequentialRunner):
  """A runner for debugging evaluations.

  The debug runner is a sequential runner that only runs the first example
  of each evaluation, with `raise_if_has_error` enabled. This is useful for
  quickly identifying issues in evaluation logic during development.
  Checkpointers are disabled for this runner.
  """

  NAME = 'debug'

  # Do not use the checkpointer for debug runner.
  plugins = []

  def _on_bound(self):
    super()._on_bound()
    if self.current_run.example_ids is None:
      self.current_run.rebind(example_ids=[1], skip_notification=True)
    self.current_run.rebind(raise_if_has_error=True, skip_notification=True)

  def _save_run_manifest(self) -> None:
    """Do nothing to avoid overriding existing runs."""
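A minimal sketch of selecting this runner (mirroring the test below; `exp`, `root_dir` and `my_plugin` are placeholders for any experiment, output directory and plugin):

    # Runs only example 1 of each leaf evaluation, with raise_if_has_error
    # enabled; no checkpointers are attached and no run manifest is saved.
    run = exp.run(root_dir, runner='debug', plugins=[my_plugin])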
langfun/core/eval/v2/runners/debug_test.py
@@ -0,0 +1,76 @@
# Copyright 2024 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for debug runner."""
import os
import tempfile
from typing import Any
import unittest

from langfun.core.eval.v2 import eval_test_helper
from langfun.core.eval.v2.runners import debug  # pylint: disable=unused-import

import pyglove as pg


class DebugRunnerTest(unittest.TestCase):

  def assert_same_list(self, actual: list[Any], expected: list[Any]):
    self.assertEqual(len(actual), len(expected))
    for i, (x, y) in enumerate(zip(actual, expected)):
      if x is not y:
        print(i, pg.diff(x, y))
      self.assertIs(x, y)

  def test_debug_runner(self):
    plugin = eval_test_helper.TestPlugin()
    exp = eval_test_helper.test_experiment()
    root_dir = os.path.join(tempfile.mkdtemp(), 'test_debug_runner')
    run = exp.run(root_dir, runner='debug', plugins=[plugin])

    self.assertIsNotNone(plugin.start_time)
    self.assertIsNotNone(plugin.complete_time)
    self.assertGreater(plugin.complete_time, plugin.start_time)

    self.assertEqual(
        len(plugin.started_experiments), len(exp.nodes)
    )
    self.assertEqual(
        len(plugin.completed_experiments), len(exp.nodes)
    )
    self.assertEqual(
        len(plugin.started_example_ids), 6 * 1
    )
    self.assertEqual(
        len(plugin.completed_example_ids), 6 * 1
    )
    self.assert_same_list(plugin.skipped_experiments, [])
    self.assertFalse(
        pg.io.path_exists(os.path.join(run.output_root, 'run.json'))
    )

    for node in exp.nodes:
      self.assertTrue(node.progress.is_started)
      self.assertTrue(node.progress.is_completed)
      if node.is_leaf:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_completed, 1)
        self.assertEqual(node.progress.num_failed, 0)
      else:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_failed, 0)
        self.assertEqual(node.progress.num_processed, node.progress.num_total)


if __name__ == '__main__':
  unittest.main()