langfun 0.1.2.dev202511160804__py3-none-any.whl → 0.1.2.dev202511270805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- langfun/core/__init__.py +1 -0
- langfun/core/agentic/__init__.py +4 -1
- langfun/core/agentic/action.py +340 -17
- langfun/core/agentic/action_test.py +124 -21
- langfun/core/eval/base_test.py +5 -5
- langfun/core/eval/v2/checkpointing.py +25 -1
- langfun/core/eval/v2/checkpointing_test.py +8 -1
- langfun/core/eval/v2/eval_test_helper.py +7 -2
- langfun/core/eval/v2/evaluation.py +4 -1
- langfun/core/eval/v2/example.py +5 -1
- langfun/core/eval/v2/example_test.py +13 -5
- langfun/core/eval/v2/experiment.py +23 -0
- langfun/core/eval/v2/experiment_test.py +19 -0
- langfun/core/eval/v2/progress_tracking.py +12 -3
- langfun/core/eval/v2/progress_tracking_test.py +3 -1
- langfun/core/eval/v2/reporting_test.py +4 -0
- langfun/core/eval/v2/runners/__init__.py +4 -0
- langfun/core/eval/v2/runners/base.py +40 -21
- langfun/core/eval/v2/runners/beam.py +341 -0
- langfun/core/eval/v2/runners/beam_test.py +131 -0
- langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
- langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
- langfun/core/eval/v2/runners/debug_test.py +1 -4
- langfun/core/eval/v2/runners/parallel_test.py +1 -4
- langfun/core/eval/v2/runners/sequential_test.py +1 -4
- langfun/core/langfunc_test.py +3 -3
- langfun/core/language_model.py +38 -5
- langfun/core/language_model_test.py +45 -0
- langfun/core/llms/__init__.py +2 -0
- langfun/core/llms/gemini.py +41 -8
- langfun/core/llms/gemini_test.py +84 -0
- langfun/core/llms/google_genai.py +5 -0
- langfun/core/llms/vertexai.py +7 -0
- langfun/core/modalities/mime.py +2 -0
- langfun/core/modalities/mime_test.py +11 -0
- langfun/core/structured/schema/__init__.py +1 -0
- {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/RECORD +41 -37
- {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/top_level.txt +0 -0
langfun/core/eval/v2/runners/ckpt_monitor.py
ADDED
@@ -0,0 +1,294 @@
+# Copyright 2025 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Checkpoint aggregator for Langfun evaluations."""
+
+import concurrent.futures
+import dataclasses
+import os
+import threading
+import time
+from typing import Annotated, Iterator
+
+from langfun.core.eval.v2 import evaluation as evaluation_lib
+from langfun.core.eval.v2 import example as example_lib
+from langfun.core.eval.v2 import reporting
+from langfun.core.eval.v2.runners import base
+
+import pyglove as pg
+
+
+class CheckpointMonitor(base.RunnerBase):
+  """Runner for monitoring checkpoint files generated by other runners.
+
+  Currently the checkpoint monitor only supports aggregating per-example
+  checkpoint files.
+  """
+
+  NAME = 'checkpoint_monitor'
+
+  plugins = [
+      reporting.HtmlReporter(),
+  ]
+
+  checkpoint_pattern: Annotated[
+      str, 'The glob pattern of the checkpoint files to monitor.'
+  ] = 'checkpoint_*.bagz'
+
+  monitor_inprogress_files: Annotated[
+      bool,
+      'If True, monitor in-progress files to aggregate.'
+  ] = False
+
+  poll_interval: Annotated[
+      int,
+      'The interval in seconds to poll for new checkpoint files.'
+  ] = 5
+
+  max_aggregation_threads: Annotated[
+      int,
+      'The maximum number of threads to aggregate checkpoints.'
+  ] = 128
+
+  @dataclasses.dataclass
+  class _AggregationEntry:
+    evaluation: evaluation_lib.Evaluation
+    output_dir: str
+    inprogress_file_pattern: str | None
+    ckpt_file_pattern: str
+    example_ids_inprogress: set[int]
+    example_ids_to_be_aggregated: set[int]
+    example_ids_being_aggregated: set[int]
+    completion_lock: threading.Lock
+    is_completed: bool = False
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._monitor_thread = None
+    self._aggregation_entries = []
+    self._aggregator_pool = None
+    self._error = None
+
+  def start(self):
+    # Reset the experiment state before getting started.
+    self.current_run.experiment.reset()
+
+    # Signal the start of the run.
+    self.on_run_start()
+
+    # Start the non-leaf nodes.
+    for node in self.current_run.experiment.nonleaf_nodes:
+      self.on_experiment_start(node)
+
+    for evaluation in self.current_run.experiment.leaf_nodes:
+      # This is not precise, but we at least notify example start.
+      if not self.current_run.filter or self.current_run.filter(evaluation):
+        self.on_experiment_start(evaluation)
+
+        # Signal the start of the examples if we are not monitoring
+        # in-progress files.
+        if not self.monitor_inprogress_files:
+          for example_id in self.current_run.examples_to_evaluate(evaluation):
+            self._mark_example_started(evaluation, example_id)
+
+        # Create the aggregation entries for polling.
+        output_dir = self.current_run.output_dir(evaluation)
+        self._aggregation_entries.append(
+            self._AggregationEntry(
+                evaluation=evaluation,
+                output_dir=output_dir,
+                ckpt_file_pattern=os.path.join(
+                    output_dir, self.checkpoint_pattern
+                ),
+                inprogress_file_pattern=os.path.join(
+                    output_dir, '*.inprogress'
+                ) if self.monitor_inprogress_files else None,
+                example_ids_to_be_aggregated=(
+                    self.current_run.examples_to_evaluate(evaluation)
+                ),
+                example_ids_inprogress=set(),
+                example_ids_being_aggregated=set(),
+                completion_lock=threading.Lock(),
+                is_completed=False,
+            )
+        )
+      else:
+        self.on_experiment_skipped(evaluation)
+
+    self._aggregator_pool = concurrent.futures.ThreadPoolExecutor(
+        max_workers=self.max_aggregation_threads
+    )
+    self._monitor_thread = threading.Thread(target=self._monitor_loop)
+    self._monitor_thread.start()
+
+  def join(self):
+    if self._monitor_thread:
+      self._monitor_thread.join()
+    if self._error is not None:
+      raise self._error
+
+  def run(self):
+    self.start()
+    self.join()
+
+  def _monitor_loop(self):
+    while not self._error and any(
+        not e.is_completed for e in self._aggregation_entries
+    ):
+      for entry in self._aggregation_entries:
+        if not entry.example_ids_to_be_aggregated:
+          continue
+
+        # Signal example processing.
+        if self.monitor_inprogress_files:
+          inprogress_files = pg.io.glob(entry.inprogress_file_pattern)
+          for inprogress_file in inprogress_files:
+            example_id = int(
+                os.path.basename(inprogress_file).split('.')[0]
+            )
+            if example_id not in entry.example_ids_inprogress:
+              self._mark_example_started(entry.evaluation, example_id)
+              entry.example_ids_inprogress.add(example_id)
+
+        for filepath in pg.io.glob(entry.ckpt_file_pattern):
+          example_id = int(
+              os.path.basename(filepath).split('.')[0].split('_')[-1]
+          )
+          if example_id in entry.example_ids_to_be_aggregated:
+            # Remove example ID from the set to avoid duplicate processing.
+            entry.example_ids_to_be_aggregated.remove(example_id)
+            entry.example_ids_being_aggregated.add(example_id)
+
+            # It could be that the example has been processed before, but the
+            # inprogress file was removed. In this case, we should signal the
+            # example has started before completing it.
+            if example_id not in entry.example_ids_inprogress:
+              self._mark_example_started(entry.evaluation, example_id)
+              entry.example_ids_inprogress.add(example_id)
+
+            self._aggregator_pool.submit(
+                self._aggregate, entry, filepath, example_id
+            )
+            pg.logging.info(
+                '[%s] Aggregating example %d from %s...',
+                entry.evaluation.id,
+                example_id,
+                filepath,
+            )
+      time.sleep(self.poll_interval)
+
+    if self._error is None:
+      self.on_run_complete()
+    else:
+      self.on_run_abort(self._error)
+
+  def _aggregate(
+      self,
+      entry: _AggregationEntry,
+      ckpt_filepath: str,
+      example_id: int
+  ):
+    """Aggregate an example from a checkpoint file."""
+    try:
+      loaded_examples = entry.evaluation.state.load(
+          ckpt_filepath,
+          example_input_by_id=entry.evaluation.example_input_by_id,
+          # Example metadata may be expensive to load, and is not used by
+          # metric aggregation. Thus we do not load example metadata.
+          load_example_metadata=False
+      )
+      assert len(loaded_examples) >= 1, loaded_examples
+      # Occasionally the per-example checkpoint file may contain the same
+      # example processed multiple times. We only need to aggregate the last
+      # example.
+      example = loaded_examples[-1]
+    except BaseException as e:  # pylint: disable=broad-except
+      error_info = pg.ErrorInfo.from_exception(e)
+      pg.logging.error(
+          '[%s] Failed to aggregate example %d: %s',
+          entry.evaluation.id,
+          example_id,
+          error_info
+      )
+      example = example_lib.Example(
+          id=example_id,
+          input=entry.evaluation.example_input_by_id(example_id),
+          error=error_info,
+      )
+
+    # This will skip processing but still allow metrics to be collected.
+    # `process` will never be called for evaluation, thus we do not
+    # need to set up/tear down the evaluation.
+    example = entry.evaluation.evaluate(
+        example, reevaluate_upon_previous_errors=False
+    )
+    example.newly_processed = True
+    pg.logging.info(
+        '[%s] Successfully aggregated example %d from %s.',
+        entry.evaluation.id,
+        example_id,
+        ckpt_filepath,
+    )
+
+    try:
+      self.on_example_complete(entry.evaluation, example)
+    except BaseException as e:  # pylint: disable=broad-except
+      # Plugin failures should be raised to the user.
+      self._error = e
+
+    entry.example_ids_being_aggregated.remove(example_id)
+
+    # Remove the in-progress file to indicate that the example has been
+    # processed.
+    try:
+      pg.io.rm(os.path.join(entry.output_dir, f'{example_id}.inprogress'))
+    except FileNotFoundError:
+      pass
+
+    if (not self._error
+        and not entry.example_ids_to_be_aggregated
+        and not entry.example_ids_being_aggregated):
+      with entry.completion_lock:
+        if not entry.is_completed:
+          entry.is_completed = True
+          try:
+            self.on_experiment_complete(entry.evaluation)
+          except BaseException as e:  # pylint: disable=broad-except
+            # Plugin failures should be raised to the user.
+            self._error = e
+
+  def _mark_example_started(
+      self,
+      evaluation: evaluation_lib.Evaluation,
+      example_id: int
+  ) -> None:
+    """Mark an example as started."""
+    example = example_lib.Example(
+        id=example_id, input=evaluation.example_input_by_id(example_id),
+    )
+    example.start_time = time.time()
+    self.on_example_start(evaluation, example)
+
+    # We update the evaluation state with the in-progress status so the
+    # evaluation HTML can show remotely in-progress examples.
+    evaluation.state.update(example, in_progress=True)
+
+  def _run(self, evaluations: list[evaluation_lib.Evaluation]):
+    raise NotImplementedError('Not needed in checkpoint monitor.')
+
+  def _evaluate_items(
+      self,
+      evaluation: evaluation_lib.Evaluation,
+      items: Iterator[example_lib.Example]
+  ) -> None:
+    raise NotImplementedError('Not needed in checkpoint monitor.')
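The file above implements a polling aggregator: start() registers one _AggregationEntry per unfiltered leaf evaluation, _monitor_loop() globs for new checkpoint and in-progress files every poll_interval seconds, and _aggregate() replays each checkpointed example through evaluation.evaluate() on a thread pool so metrics accumulate without reprocessing. A minimal usage sketch follows; `my_experiment` and the root directory are hypothetical placeholders, while the CheckpointMonitor arguments are taken from the code above and the test below:

    from langfun.core.eval.v2.runners import ckpt_monitor

    # A prior run (possibly on other machines) has written per-example files
    # such as checkpoint_<example_id>.jsonl, plus optional
    # <example_id>.inprogress markers, under each evaluation's output dir.
    # `my_experiment` is a placeholder for any lf.eval.v2 Experiment.
    run = my_experiment.run('/tmp/eval_root', runner='sequential')

    # The monitor polls those directories every `poll_interval` seconds,
    # aggregates each discovered checkpoint on a thread pool, and re-raises
    # any plugin error from join().
    ckpt_monitor.CheckpointMonitor(
        run,
        checkpoint_pattern='checkpoint_*.jsonl',
        monitor_inprogress_files=True,
    ).run()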
langfun/core/eval/v2/runners/ckpt_monitor_test.py
ADDED
@@ -0,0 +1,162 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import tempfile
+import unittest
+
+from langfun.core.eval.v2 import checkpointing
+from langfun.core.eval.v2 import eval_test_helper
+from langfun.core.eval.v2 import example as example_lib
+from langfun.core.eval.v2 import experiment as experiment_lib
+from langfun.core.eval.v2.runners import ckpt_monitor
+from langfun.core.eval.v2.runners import sequential  # pylint: disable=unused-import
+import pyglove as pg
+
+
+class CheckpointMonitorTest(unittest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.test_dir = tempfile.mkdtemp()
+
+  def test_aggregate(self):
+    exp = eval_test_helper.test_experiment()
+    root_dir = os.path.join(self.test_dir, 'test_aggregate')
+    run = exp.run(
+        root_dir,
+        runner='sequential',
+        progress_tracker=None,
+        plugins=[
+            checkpointing.PerExampleCheckpointer(
+                checkpoint_filename='checkpoint.jsonl'
+            )
+        ],
+        use_cache='no',
+    )
+    # Try to corrupt one of the checkpoint files.
+    pg.io.writefile(
+        run.output_path_for(exp.leaf_nodes[0], 'checkpoint_1.jsonl'),
+        'bad ckpt'
+    )
+    plugin = eval_test_helper.TestPlugin()
+    monitor = ckpt_monitor.CheckpointMonitor(
+        run,
+        plugins=[plugin],
+        checkpoint_pattern='checkpoint_*.jsonl',
+        monitor_inprogress_files=True,
+    )
+    monitor.run()
+
+    # Assert that the in-progress files are created and not removed.
+    for entry in monitor._aggregation_entries:
+      self.assertEqual(len(entry.example_ids_inprogress), 10)
+
+    # 6 leaf nodes + 1 suite + 1 hyper.
+    self.assertEqual(len(plugin.started_experiments), 6 + 2)
+    self.assertEqual(len(plugin.completed_experiments), 6 + 2)
+    self.assertEqual(len(plugin.started_example_ids), 10 * 6)
+    self.assertEqual(len(plugin.completed_example_ids), 10 * 6)
+    for e in exp.leaf_nodes:
+      self.assertEqual(e.progress.num_completed, 10)
+
+  def test_aggregate_with_filter(self):
+    exp = eval_test_helper.test_experiment()
+    root_dir = os.path.join(self.test_dir, 'test_aggregate_with_filter')
+
+    node_to_skip = exp.leaf_nodes[2]
+    # Run experiment to generate checkpoint files for all examples.
+    run = exp.run(
+        root_dir,
+        runner='sequential',
+        filter=lambda e: e.id != node_to_skip.id,
+        progress_tracker=None,
+        plugins=[
+            checkpointing.PerExampleCheckpointer(
+                checkpoint_filename='checkpoint.jsonl'
+            )
+        ],
+        use_cache='no',
+    )
+    plugin = eval_test_helper.TestPlugin()
+    monitor = ckpt_monitor.CheckpointMonitor(
+        run,
+        plugins=[plugin],
+        checkpoint_pattern='checkpoint_*.jsonl',
+    )
+    monitor.run()
+
+    # Assert that on_experiment_skipped was called for the filtered node.
+    self.assertEqual(len(plugin.skipped_experiments), 1)
+    self.assertEqual(plugin.skipped_experiments[0].id, node_to_skip.id)
+
+    # Assert that the skipped node was not started.
+    started_ids = [e.id for e in plugin.started_experiments]
+    self.assertNotIn(node_to_skip.id, started_ids)
+
+  def test_plugin_raise(self):
+
+    class TestPlugin(eval_test_helper.TestPlugin):
+      simulate_raise_on_example_complete: bool = False
+      simulate_raise_on_experiment_complete: bool = False
+
+      def on_example_complete(
+          self,
+          runner: experiment_lib.Runner,
+          experiment: experiment_lib.Experiment,
+          example: example_lib.Example
+      ):
+        if self.simulate_raise_on_example_complete:
+          raise ValueError('example complete error')
+
+      def on_experiment_complete(
+          self,
+          runner: experiment_lib.Runner,
+          experiment: experiment_lib.Experiment
+      ):
+        if self.simulate_raise_on_experiment_complete:
+          raise ValueError('experiment complete error')
+
+    exp = eval_test_helper.test_evaluation()
+    root_dir = os.path.join(self.test_dir, 'test_plugin_raise')
+
+    # Run experiment to generate checkpoint files for all examples.
+    run = exp.run(
+        root_dir,
+        runner='sequential',
+        progress_tracker=None,
+        plugins=[
+            checkpointing.PerExampleCheckpointer(
+                checkpoint_filename='checkpoint.jsonl'
+            )
+        ],
+        use_cache='no',
+    )
+
+    with self.assertRaisesRegex(ValueError, 'example complete error'):
+      ckpt_monitor.CheckpointMonitor(
+          run,
+          plugins=[TestPlugin(simulate_raise_on_example_complete=True)],
+          checkpoint_pattern='checkpoint_*.jsonl',
+      ).run()
+
+    with self.assertRaisesRegex(ValueError, 'experiment complete error'):
+      ckpt_monitor.CheckpointMonitor(
+          run,
+          plugins=[TestPlugin(simulate_raise_on_experiment_complete=True)],
+          checkpoint_pattern='checkpoint_*.jsonl',
+      ).run()
+
+
+if __name__ == '__main__':
+  unittest.main()
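Both files rely on a simple filename convention: a checkpoint's example ID is the last underscore-separated token of its base name, and in-progress markers are named <example_id>.inprogress. The snippet below isolates the parsing logic from _monitor_loop above; the helper function names are illustrative, not part of the package:

    import os

    def example_id_from_checkpoint(filepath: str) -> int:
      # e.g. '/out/checkpoint_17.jsonl' -> 17
      return int(os.path.basename(filepath).split('.')[0].split('_')[-1])

    def example_id_from_inprogress(filepath: str) -> int:
      # e.g. '/out/17.inprogress' -> 17
      return int(os.path.basename(filepath).split('.')[0])

    assert example_id_from_checkpoint('/out/checkpoint_17.jsonl') == 17
    assert example_id_from_inprogress('/out/17.inprogress') == 17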
langfun/core/eval/v2/runners/debug_test.py
CHANGED
@@ -23,7 +23,7 @@ from langfun.core.eval.v2.runners import debug  # pylint: disable=unused-import
 import pyglove as pg
 
 
-class RunnerTest(unittest.TestCase):
+class DebugRunnerTest(unittest.TestCase):
 
   def assert_same_list(self, actual: list[Any], expected: list[Any]):
     self.assertEqual(len(actual), len(expected))
@@ -32,9 +32,6 @@ class RunnerTest(unittest.TestCase):
       print(i, pg.diff(x, y))
      self.assertIs(x, y)
 
-
-class DebugRunnerTest(RunnerTest):
-
   def test_debug_runner(self):
     plugin = eval_test_helper.TestPlugin()
     exp = eval_test_helper.test_experiment()
langfun/core/eval/v2/runners/parallel_test.py
CHANGED
@@ -23,7 +23,7 @@ from langfun.core.eval.v2.runners import parallel  # pylint: disable=unused-import
 import pyglove as pg
 
 
-class RunnerTest(unittest.TestCase):
+class ParallelRunnerTest(unittest.TestCase):
 
   def assert_same_list(self, actual: list[Any], expected: list[Any]):
     self.assertEqual(len(actual), len(expected))
@@ -32,9 +32,6 @@ class RunnerTest(unittest.TestCase):
       print(i, pg.diff(x, y))
      self.assertIs(x, y)
 
-
-class ParallelRunnerTest(RunnerTest):
-
   def test_parallel_runner(self):
     plugin = eval_test_helper.TestPlugin()
     exp = eval_test_helper.test_experiment()
langfun/core/eval/v2/runners/sequential_test.py
CHANGED
@@ -23,7 +23,7 @@ from langfun.core.eval.v2.runners import sequential  # pylint: disable=unused-import
 import pyglove as pg
 
 
-class RunnerTest(unittest.TestCase):
+class SequentialRunnerTest(unittest.TestCase):
 
   def assert_same_list(self, actual: list[Any], expected: list[Any]):
     self.assertEqual(len(actual), len(expected))
@@ -32,9 +32,6 @@ class RunnerTest(unittest.TestCase):
       print(i, pg.diff(x, y))
      self.assertIs(x, y)
 
-
-class SequentialRunnerTest(RunnerTest):
-
   def test_basic(self):
     plugin = eval_test_helper.TestPlugin()
     exp = eval_test_helper.test_experiment()
langfun/core/langfunc_test.py
CHANGED
@@ -109,9 +109,9 @@ class LangFuncCallTest(unittest.TestCase):
         ' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=None,'
         ' max_tokens=None, n=1, top_k=40, top_p=None, stop=None,'
         ' random_seed=None, logprobs=False, top_logprobs=None,'
-        ' max_thinking_tokens=None, reasoning_effort=None, extras={}),'
-        ' cache=None, max_concurrency=None, timeout=120.0, max_attempts=5,'
-        ' retry_interval=(5, 60), exponential_backoff=True,'
+        ' max_thinking_tokens=None, thinking_level=None, reasoning_effort=None,'
+        ' extras={}), cache=None, max_concurrency=None, timeout=120.0,'
+        ' max_attempts=5, retry_interval=(5, 60), exponential_backoff=True,'
         ' max_retry_interval=300, debug=False))',
     )
 
langfun/core/language_model.py
CHANGED
@@ -53,6 +53,10 @@ class RetryableLMError(LMError):
   """Base class for LLM errors that can be solved by retrying."""
 
 
+class EmptyGenerationError(RetryableLMError):
+  """Error for empty generation."""
+
+
 class RateLimitError(RetryableLMError):
   """Error for rate limit reached."""
 
@@ -575,6 +579,14 @@ class LMSamplingOptions(component.Component):
       int | None, 'Number of max thinking tokens.'
   ] = None
 
+  thinking_level: Annotated[
+      Literal['low', 'high'] | None,
+      (
+          'Thinking level for Gemini models. High is for complex tasks, '
+          'while low is for faster responses.'
+      ),
+  ] = None
+
   reasoning_effort: Annotated[
       Literal['low', 'medium', 'high'] | None,
       (
@@ -1076,10 +1088,32 @@ class LanguageModel(component.Component):
     prompts = [message_lib.UserMessage.from_value(p) for p in prompts]
 
     with component.context(override_attrs=True, **kwargs):
-      if self.cache is None:
-        results = self._sample(prompts)
-      else:
-        results = self._sample_with_cache_lookup(prompts, cache_seed)
+
+      def _sample_with_retry():
+        if self.cache is None:
+          results = self._sample(prompts)
+        else:
+          results = self._sample_with_cache_lookup(prompts, cache_seed)
+
+        for i, result in enumerate(results):
+          for sample in result.samples:
+            if not sample.response.text:
+              if self.cache is not None:
+                self.cache.delete(self, prompts[i], seed=cache_seed)
+              raise EmptyGenerationError(
+                  f'Empty generation encountered from model {self.model_id}.'
+              )
+        return results
+
+      retry_fn = concurrent.with_retry(
+          _sample_with_retry,
+          retry_on_errors=EmptyGenerationError,
+          max_attempts=self.max_attempts,
+          retry_interval=self.retry_interval,
+          exponential_backoff=self.exponential_backoff,
+          max_retry_interval=self.max_retry_interval,
+      )
+      results = retry_fn()
 
     for prompt, result in zip(prompts, results):
 
@@ -1088,7 +1122,6 @@ class LanguageModel(component.Component):
 
       for sample in result.samples:
         # Update metadata for response message.
-
        response = sample.response
        response.metadata.score = sample.score
        response.metadata.logprobs = sample.logprobs
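The sampling change above wraps the existing cache-aware sampling path in a retry loop: a sample with empty text now raises EmptyGenerationError (a RetryableLMError), evicts any cached entry so the retry is not served the same empty result, and reuses the model's existing max_attempts, retry_interval, and backoff settings. Below is a standalone sketch of the same retry-on-empty pattern in plain Python, not langfun's concurrent.with_retry helper; all names are illustrative:

    import time

    def sample_with_retry(sample_fn, is_empty, max_attempts=5, base_interval=1.0):
      """Retries sample_fn until its result is non-empty."""
      for attempt in range(max_attempts):
        result = sample_fn()
        if not is_empty(result):
          return result
        if attempt + 1 < max_attempts:
          # Exponential backoff between attempts, mirroring
          # exponential_backoff=True in the diff above.
          time.sleep(base_interval * (2 ** attempt))
      raise RuntimeError(f'Empty generation after {max_attempts} attempts.')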
langfun/core/language_model_test.py
CHANGED
@@ -591,6 +591,51 @@ class LanguageModelTest(unittest.TestCase):
     lm = MockModel(cache=cache, top_k=1)
     self.assertEqual(lm('a'), 'a')
 
+  def test_empty_generation_error(self):
+    class MockModelWithEmptyResponse(MockModel):
+      def _sample(self,
+                  prompts: list[message_lib.Message]
+                  ) -> list[lm_lib.LMSamplingResult]:
+        return [lm_lib.LMSamplingResult(
+            [lm_lib.LMSample(response='')],
+            usage=lm_lib.LMSamplingUsage(100, 0, 100, 1, 1.0)
+        )]
+    lm = MockModelWithEmptyResponse(max_attempts=1, retry_interval=0)
+    with self.assertRaisesRegex(
+        concurrent.RetryError, 'Empty generation encountered'
+    ):
+      lm('a')
+
+  def test_empty_generation_retry(self):
+    class MockModelWithEmptyThenValid(MockModel):
+      attempt_count: int = 0
+
+      def _sample(
+          self, prompts: list[message_lib.Message]
+      ) -> list[lm_lib.LMSamplingResult]:
+        self.rebind(attempt_count=self.attempt_count + 1)
+        if self.attempt_count == 1:
+          # First attempt returns empty
+          return [
+              lm_lib.LMSamplingResult(
+                  [lm_lib.LMSample(response='')],
+                  usage=lm_lib.LMSamplingUsage(100, 0, 100, 1, 1.0),
+              )
+          ]
+        else:
+          # Subsequent attempts return valid response
+          return [
+              lm_lib.LMSamplingResult(
+                  [lm_lib.LMSample(response='valid response')],
+                  usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+              )
+          ]
+
+    lm = MockModelWithEmptyThenValid(max_attempts=3, retry_interval=0)
+    result = lm('a')
+    self.assertEqual(result.text, 'valid response')
+    self.assertEqual(lm.attempt_count, 2)
+
   def test_estimate_max_concurrency(self):
     self.assertIsNone(lm_lib.LanguageModel.estimate_max_concurrency(None, None))
     self.assertEqual(
langfun/core/llms/__init__.py
CHANGED
@@ -42,6 +42,7 @@ from langfun.core.llms.azure_openai import AzureOpenAI
 
 # Gemini models.
 from langfun.core.llms.google_genai import GenAI
+from langfun.core.llms.google_genai import Gemini3ProPreview
 from langfun.core.llms.google_genai import Gemini25Pro
 from langfun.core.llms.google_genai import Gemini25Flash
 from langfun.core.llms.google_genai import Gemini25ProPreview_20250605
@@ -90,6 +91,7 @@ from langfun.core.llms.vertexai import VertexAIGemini25ProPreview_20250605
 from langfun.core.llms.vertexai import VertexAIGemini25Pro
 from langfun.core.llms.vertexai import VertexAIGemini25Flash
 from langfun.core.llms.vertexai import VertexAIGemini25FlashImagePreview
+from langfun.core.llms.vertexai import VertexAIGemini3ProPreview
 
 # For backward compatibility.
 GeminiPro1_5 = Gemini15Pro
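With these exports, the preview model is reachable directly from langfun.core.llms, and the new thinking_level option added to LMSamplingOptions above can be set on it. A hedged usage sketch; model availability and naming depend on your GenAI or Vertex AI credentials and access:

    import langfun as lf

    # Hypothetical usage of the newly exported class; requires access to the
    # preview model via the GenAI API (or VertexAIGemini3ProPreview on Vertex).
    lm = lf.llms.Gemini3ProPreview(
        sampling_options=lf.LMSamplingOptions(thinking_level='low')
    )
    print(lm('In one sentence, what does a checkpoint monitor do?'))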
|