langfun 0.1.2.dev202511160804__py3-none-any.whl → 0.1.2.dev202511270805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of langfun might be problematic.

Files changed (41)
  1. langfun/core/__init__.py +1 -0
  2. langfun/core/agentic/__init__.py +4 -1
  3. langfun/core/agentic/action.py +340 -17
  4. langfun/core/agentic/action_test.py +124 -21
  5. langfun/core/eval/base_test.py +5 -5
  6. langfun/core/eval/v2/checkpointing.py +25 -1
  7. langfun/core/eval/v2/checkpointing_test.py +8 -1
  8. langfun/core/eval/v2/eval_test_helper.py +7 -2
  9. langfun/core/eval/v2/evaluation.py +4 -1
  10. langfun/core/eval/v2/example.py +5 -1
  11. langfun/core/eval/v2/example_test.py +13 -5
  12. langfun/core/eval/v2/experiment.py +23 -0
  13. langfun/core/eval/v2/experiment_test.py +19 -0
  14. langfun/core/eval/v2/progress_tracking.py +12 -3
  15. langfun/core/eval/v2/progress_tracking_test.py +3 -1
  16. langfun/core/eval/v2/reporting_test.py +4 -0
  17. langfun/core/eval/v2/runners/__init__.py +4 -0
  18. langfun/core/eval/v2/runners/base.py +40 -21
  19. langfun/core/eval/v2/runners/beam.py +341 -0
  20. langfun/core/eval/v2/runners/beam_test.py +131 -0
  21. langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
  22. langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
  23. langfun/core/eval/v2/runners/debug_test.py +1 -4
  24. langfun/core/eval/v2/runners/parallel_test.py +1 -4
  25. langfun/core/eval/v2/runners/sequential_test.py +1 -4
  26. langfun/core/langfunc_test.py +3 -3
  27. langfun/core/language_model.py +38 -5
  28. langfun/core/language_model_test.py +45 -0
  29. langfun/core/llms/__init__.py +2 -0
  30. langfun/core/llms/gemini.py +41 -8
  31. langfun/core/llms/gemini_test.py +84 -0
  32. langfun/core/llms/google_genai.py +5 -0
  33. langfun/core/llms/vertexai.py +7 -0
  34. langfun/core/modalities/mime.py +2 -0
  35. langfun/core/modalities/mime_test.py +11 -0
  36. langfun/core/structured/schema/__init__.py +1 -0
  37. {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/METADATA +1 -1
  38. {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/RECORD +41 -37
  39. {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/WHEEL +0 -0
  40. {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/licenses/LICENSE +0 -0
  41. {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/top_level.txt +0 -0

--- /dev/null
+++ b/langfun/core/eval/v2/runners/ckpt_monitor.py
@@ -0,0 +1,294 @@
+# Copyright 2025 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Checkpoint aggregator for Langfun evaluations."""
+
+import concurrent.futures
+import dataclasses
+import os
+import threading
+import time
+from typing import Annotated, Iterator
+
+from langfun.core.eval.v2 import evaluation as evaluation_lib
+from langfun.core.eval.v2 import example as example_lib
+from langfun.core.eval.v2 import reporting
+from langfun.core.eval.v2.runners import base
+
+import pyglove as pg
+
+
+class CheckpointMonitor(base.RunnerBase):
+  """Runner for monitoring checkpoint files generated by other runners.
+
+  Currently checkpoint monitor only supports aggregating per-example
+  checkpoint files.
+  """
+
+  NAME = 'checkpoint_monitor'
+
+  plugins = [
+      reporting.HtmlReporter(),
+  ]
+
+  checkpoint_pattern: Annotated[
+      str, 'The glob pattern of the checkpoint files to monitor.'
+  ] = 'checkpoint_*.bagz'
+
+  monitor_inprogress_files: Annotated[
+      bool,
+      'If True, monitor in-progress files to aggregate.'
+  ] = False
+
+  poll_interval: Annotated[
+      int,
+      'The interval in seconds to poll for new checkpoint files.'
+  ] = 5
+
+  max_aggregation_threads: Annotated[
+      int,
+      'The maximum number of threads to aggregate checkpoints.'
+  ] = 128
+
+  @dataclasses.dataclass
+  class _AggregationEntry:
+    evaluation: evaluation_lib.Evaluation
+    output_dir: str
+    inprogress_file_pattern: str | None
+    ckpt_file_pattern: str
+    example_ids_inprogress: set[int]
+    example_ids_to_be_aggregated: set[int]
+    example_ids_being_aggregated: set[int]
+    completion_lock: threading.Lock
+    is_completed: bool = False
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._monitor_thread = None
+    self._aggregation_entries = []
+    self._aggregator_pool = None
+    self._error = None
+
+  def start(self):
+    # Reset the experiment state before getting started.
+    self.current_run.experiment.reset()
+
+    # Signal the start of the run.
+    self.on_run_start()
+
+    # Start the non-leaf nodes.
+    for node in self.current_run.experiment.nonleaf_nodes:
+      self.on_experiment_start(node)
+
+    for evaluation in self.current_run.experiment.leaf_nodes:
+      # This is not precise, but we at least notify example start.
+      if not self.current_run.filter or self.current_run.filter(evaluation):
+        self.on_experiment_start(evaluation)
+
+        # Signal the start of the examples if we are not monitoring in-progress
+        # files.
+        if not self.monitor_inprogress_files:
+          for example_id in self.current_run.examples_to_evaluate(evaluation):
+            self._mark_example_started(evaluation, example_id)
+
+        # Create the aggregation entries for polling.
+        output_dir = self.current_run.output_dir(evaluation)
+        self._aggregation_entries.append(
+            self._AggregationEntry(
+                evaluation=evaluation,
+                output_dir=output_dir,
+                ckpt_file_pattern=os.path.join(
+                    output_dir, self.checkpoint_pattern
+                ),
+                inprogress_file_pattern=os.path.join(
+                    output_dir, '*.inprogress'
+                ) if self.monitor_inprogress_files else None,
+                example_ids_to_be_aggregated=(
+                    self.current_run.examples_to_evaluate(evaluation)
+                ),
+                example_ids_inprogress=set(),
+                example_ids_being_aggregated=set(),
+                completion_lock=threading.Lock(),
+                is_completed=False,
+            )
+        )
+      else:
+        self.on_experiment_skipped(evaluation)
+
+    self._aggregator_pool = concurrent.futures.ThreadPoolExecutor(
+        max_workers=self.max_aggregation_threads
+    )
+    self._monitor_thread = threading.Thread(target=self._monitor_loop)
+    self._monitor_thread.start()
+
+  def join(self):
+    if self._monitor_thread:
+      self._monitor_thread.join()
+    if self._error is not None:
+      raise self._error
+
+  def run(self):
+    self.start()
+    self.join()
+
+  def _monitor_loop(self):
+    while not self._error and any(
+        not e.is_completed for e in self._aggregation_entries
+    ):
+      for entry in self._aggregation_entries:
+        if not entry.example_ids_to_be_aggregated:
+          continue
+
+        # Signal example processing.
+        if self.monitor_inprogress_files:
+          inprogress_files = pg.io.glob(entry.inprogress_file_pattern)
+          for inprogress_file in inprogress_files:
+            example_id = int(
+                os.path.basename(inprogress_file).split('.')[0]
+            )
+            if example_id not in entry.example_ids_inprogress:
+              self._mark_example_started(entry.evaluation, example_id)
+              entry.example_ids_inprogress.add(example_id)
+
+        for filepath in pg.io.glob(entry.ckpt_file_pattern):
+          example_id = int(
+              os.path.basename(filepath).split('.')[0].split('_')[-1]
+          )
+          if example_id in entry.example_ids_to_be_aggregated:
+            # Remove example ID from the set to avoid duplicate processing.
+            entry.example_ids_to_be_aggregated.remove(example_id)
+            entry.example_ids_being_aggregated.add(example_id)
+
+            # It could be that the example has been processed before, but the
+            # inprogress file was removed. In this case, we should signal the
+            # example has started before completing it.
+            if example_id not in entry.example_ids_inprogress:
+              self._mark_example_started(entry.evaluation, example_id)
+              entry.example_ids_inprogress.add(example_id)
+
+            self._aggregator_pool.submit(
+                self._aggregate, entry, filepath, example_id
+            )
+            pg.logging.info(
+                '[%s] Aggregating example %d from %s...',
+                entry.evaluation.id,
+                example_id,
+                filepath,
+            )
+      time.sleep(self.poll_interval)
+
+    if self._error is None:
+      self.on_run_complete()
+    else:
+      self.on_run_abort(self._error)
+
+  def _aggregate(
+      self,
+      entry: _AggregationEntry,
+      ckpt_filepath: str,
+      example_id: int
+  ):
+    """Aggregate an example from a checkpoint file."""
+    try:
+      loaded_examples = entry.evaluation.state.load(
+          ckpt_filepath,
+          example_input_by_id=entry.evaluation.example_input_by_id,
+          # Example metadata may be expensive to load, and is not used by
+          # metric aggregation. Thus we do not load example metadata.
+          load_example_metadata=False
+      )
+      assert len(loaded_examples) > 1, loaded_examples
+      # Occasionally the per-example checkpoint file may contain the same
+      # example processed multiple times. We only need to aggregate the last
+      # example.
+      example = loaded_examples[-1]
+    except BaseException as e:  # pylint: disable=broad-except
+      error_info = pg.ErrorInfo.from_exception(e)
+      pg.logging.error(
+          '[%s] Failed to aggregate example %d: %s',
+          entry.evaluation.id,
+          example_id,
+          error_info
+      )
+      example = example_lib.Example(
+          id=example_id,
+          input=entry.evaluation.example_input_by_id(example_id),
+          error=error_info,
+      )
+
+    # This will skip processing but still allow metrics to be collected.
+    # `process` will never be called for evaluation, thus we do not
+    # need to setup/teardown evaluation.
+    example = entry.evaluation.evaluate(
+        example, reevaluate_upon_previous_errors=False
+    )
+    example.newly_processed = True
+    pg.logging.info(
+        '[%s] Successfully aggregated example %d from %s.',
+        entry.evaluation.id,
+        example_id,
+        ckpt_filepath,
+    )
+
+    try:
+      self.on_example_complete(entry.evaluation, example)
+    except BaseException as e:  # pylint: disable=broad-except
+      # Plugin failures should be raised to the user.
+      self._error = e
+
+    entry.example_ids_being_aggregated.remove(example_id)
+
+    # Remove the in-progress file to indicate that the example has been
+    # processed.
+    try:
+      pg.io.rm(os.path.join(entry.output_dir, f'{example_id}.inprogress'))
+    except FileNotFoundError:
+      pass
+
+    if (not self._error
+        and not entry.example_ids_to_be_aggregated
+        and not entry.example_ids_being_aggregated):
+      with entry.completion_lock:
+        if not entry.is_completed:
+          entry.is_completed = True
+          try:
+            self.on_experiment_complete(entry.evaluation)
+          except BaseException as e:  # pylint: disable=broad-except
+            # Plugin failures should be raised to the user.
+            self._error = e
+
+  def _mark_example_started(
+      self,
+      evaluation: evaluation_lib.Evaluation,
+      example_id: int
+  ) -> None:
+    """Mark an example as started."""
+    example = example_lib.Example(
+        id=example_id, input=evaluation.example_input_by_id(example_id),
+    )
+    example.start_time = time.time()
+    self.on_example_start(evaluation, example)
+
+    # We update evaluation state with the inprogress status so the evaluation
+    # HTML could show remotely in-progress examples.
+    evaluation.state.update(example, in_progress=True)
+
+  def _run(self, evaluations: list[evaluation_lib.Evaluation]):
+    raise NotImplementedError('Not needed in checkpoint monitor.')
+
+  def _evaluate_items(
+      self,
+      evaluation: evaluation_lib.Evaluation,
+      items: Iterator[example_lib.Example]
+  ) -> None:
+    raise NotImplementedError('Not needed in checkpoint monitor.')
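
For orientation, here is a minimal usage sketch of the new monitor. It mirrors the constructor arguments exercised by the test file below; the test-helper experiment, temporary output root, and checkpoint filename are illustrative assumptions, and in a real deployment the per-example checkpoints would typically be written by a separate runner process rather than by the same script.

# Hypothetical sketch, mirroring the test below: produce per-example
# checkpoints with one runner, then aggregate them with CheckpointMonitor.
import os
import tempfile

from langfun.core.eval.v2 import checkpointing
from langfun.core.eval.v2 import eval_test_helper
from langfun.core.eval.v2.runners import ckpt_monitor
from langfun.core.eval.v2.runners import sequential  # so runner='sequential' resolves

# Producer side: run an experiment with a per-example checkpointer so each
# example lands in its own checkpoint_<id>.jsonl file.
exp = eval_test_helper.test_experiment()
run = exp.run(
    os.path.join(tempfile.mkdtemp(), 'eval_root'),  # illustrative output root
    runner='sequential',
    plugins=[
        checkpointing.PerExampleCheckpointer(
            checkpoint_filename='checkpoint.jsonl'
        )
    ],
)

# Monitor side (typically a separate process): poll the run's output
# directories, re-run metric aggregation on every checkpoint found, and
# surface any plugin error from run()/join().
ckpt_monitor.CheckpointMonitor(
    run,
    checkpoint_pattern='checkpoint_*.jsonl',
    monitor_inprogress_files=True,  # also report *.inprogress examples
).run()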

--- /dev/null
+++ b/langfun/core/eval/v2/runners/ckpt_monitor_test.py
@@ -0,0 +1,162 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import tempfile
+import unittest
+
+from langfun.core.eval.v2 import checkpointing
+from langfun.core.eval.v2 import eval_test_helper
+from langfun.core.eval.v2 import example as example_lib
+from langfun.core.eval.v2 import experiment as experiment_lib
+from langfun.core.eval.v2.runners import ckpt_monitor
+from langfun.core.eval.v2.runners import sequential  # pylint: disable=unused-import
+import pyglove as pg
+
+
+class CheckpointMonitorTest(unittest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.test_dir = tempfile.mkdtemp()
+
+  def test_aggregate(self):
+    exp = eval_test_helper.test_experiment()
+    root_dir = os.path.join(self.test_dir, 'test_aggregate')
+    run = exp.run(
+        root_dir,
+        runner='sequential',
+        progress_tracker=None,
+        plugins=[
+            checkpointing.PerExampleCheckpointer(
+                checkpoint_filename='checkpoint.jsonl'
+            )
+        ],
+        use_cache='no',
+    )
+    # Try to corrupt one of the checkpoint files.
+    pg.io.writefile(
+        run.output_path_for(exp.leaf_nodes[0], 'checkpoint_1.jsonl'),
+        'bad ckpt'
+    )
+    plugin = eval_test_helper.TestPlugin()
+    monitor = ckpt_monitor.CheckpointMonitor(
+        run,
+        plugins=[plugin],
+        checkpoint_pattern='checkpoint_*.jsonl',
+        monitor_inprogress_files=True,
+    )
+    monitor.run()
+
+    # Assert that the in-progress files are created and not removed.
+    for entry in monitor._aggregation_entries:
+      self.assertEqual(len(entry.example_ids_inprogress), 10)
+
+    # 6 leaf nodes + 1 suite + 1 hyper.
+    self.assertEqual(len(plugin.started_experiments), 6 + 2)
+    self.assertEqual(len(plugin.completed_experiments), 6 + 2)
+    self.assertEqual(len(plugin.started_example_ids), 10 * 6)
+    self.assertEqual(len(plugin.completed_example_ids), 10 * 6)
+    for e in exp.leaf_nodes:
+      self.assertEqual(e.progress.num_completed, 10)
+
+  def test_aggregate_with_filter(self):
+    exp = eval_test_helper.test_experiment()
+    root_dir = os.path.join(self.test_dir, 'test_aggregate_with_filter')
+
+    node_to_skip = exp.leaf_nodes[2]
+    # Run experiment to generate checkpoint files for all examples.
+    run = exp.run(
+        root_dir,
+        runner='sequential',
+        filter=lambda e: e.id != node_to_skip.id,
+        progress_tracker=None,
+        plugins=[
+            checkpointing.PerExampleCheckpointer(
+                checkpoint_filename='checkpoint.jsonl'
+            )
+        ],
+        use_cache='no',
+    )
+    plugin = eval_test_helper.TestPlugin()
+    monitor = ckpt_monitor.CheckpointMonitor(
+        run,
+        plugins=[plugin],
+        checkpoint_pattern='checkpoint_*.jsonl',
+    )
+    monitor.run()
+
+    # Assert that on_experiment_skipped was called for the filtered node.
+    self.assertEqual(len(plugin.skipped_experiments), 1)
+    self.assertEqual(plugin.skipped_experiments[0].id, node_to_skip.id)
+
+    # Assert that the skipped node was not started.
+    started_ids = [e.id for e in plugin.started_experiments]
+    self.assertNotIn(node_to_skip.id, started_ids)
+
+  def test_plugin_raise(self):
+
+    class TestPlugin(eval_test_helper.TestPlugin):
+      simulate_raise_on_example_complete: bool = False
+      simulate_raise_on_experiment_complete: bool = False
+
+      def on_example_complete(
+          self,
+          runner: experiment_lib.Runner,
+          experiment: experiment_lib.Experiment,
+          example: example_lib.Example
+      ):
+        if self.simulate_raise_on_example_complete:
+          raise ValueError('example complete error')
+
+      def on_experiment_complete(
+          self,
+          runner: experiment_lib.Runner,
+          experiment: experiment_lib.Experiment
+      ):
+        if self.simulate_raise_on_experiment_complete:
+          raise ValueError('experiment complete error')
+
+    exp = eval_test_helper.test_evaluation()
+    root_dir = os.path.join(self.test_dir, 'test_plugin_raise')
+
+    # Run experiment to generate checkpoint files for all examples.
+    run = exp.run(
+        root_dir,
+        runner='sequential',
+        progress_tracker=None,
+        plugins=[
+            checkpointing.PerExampleCheckpointer(
+                checkpoint_filename='checkpoint.jsonl'
+            )
+        ],
+        use_cache='no',
+    )
+
+    with self.assertRaisesRegex(ValueError, 'example complete error'):
+      ckpt_monitor.CheckpointMonitor(
+          run,
+          plugins=[TestPlugin(simulate_raise_on_example_complete=True)],
+          checkpoint_pattern='checkpoint_*.jsonl',
+      ).run()
+
+    with self.assertRaisesRegex(ValueError, 'experiment complete error'):
+      ckpt_monitor.CheckpointMonitor(
+          run,
+          plugins=[TestPlugin(simulate_raise_on_experiment_complete=True)],
+          checkpoint_pattern='checkpoint_*.jsonl',
+      ).run()
+
+
+if __name__ == '__main__':
+  unittest.main()
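
A note on the in-progress markers these tests exercise: the monitor identifies an in-flight example by a file named `<example_id>.inprogress` in the evaluation's output directory, and deletes that marker once the example has been aggregated. This diff does not show which component writes the markers (the 25-line change to checkpointing.py is the likely producer), but the naming contract can be read directly off the monitor code above; a hypothetical producer-side sketch:

# Hypothetical producer-side marker, using the naming the monitor globs for
# ('*.inprogress') and parses via basename(...).split('.')[0].
import os
import pyglove as pg

output_dir = '/tmp/eval_root/my_eval'  # assumed evaluation output directory
example_id = 1
pg.io.writefile(os.path.join(output_dir, f'{example_id}.inprogress'), '')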

--- a/langfun/core/eval/v2/runners/debug_test.py
+++ b/langfun/core/eval/v2/runners/debug_test.py
@@ -23,7 +23,7 @@ from langfun.core.eval.v2.runners import debug  # pylint: disable=unused-import
 import pyglove as pg
 
 
-class RunnerTest(unittest.TestCase):
+class DebugRunnerTest(unittest.TestCase):
 
   def assert_same_list(self, actual: list[Any], expected: list[Any]):
     self.assertEqual(len(actual), len(expected))
@@ -32,9 +32,6 @@ class RunnerTest(unittest.TestCase):
         print(i, pg.diff(x, y))
       self.assertIs(x, y)
 
-
-class DebugRunnerTest(RunnerTest):
-
   def test_debug_runner(self):
     plugin = eval_test_helper.TestPlugin()
     exp = eval_test_helper.test_experiment()

--- a/langfun/core/eval/v2/runners/parallel_test.py
+++ b/langfun/core/eval/v2/runners/parallel_test.py
@@ -23,7 +23,7 @@ from langfun.core.eval.v2.runners import parallel  # pylint: disable=unused-import
 import pyglove as pg
 
 
-class RunnerTest(unittest.TestCase):
+class ParallelRunnerTest(unittest.TestCase):
 
   def assert_same_list(self, actual: list[Any], expected: list[Any]):
     self.assertEqual(len(actual), len(expected))
@@ -32,9 +32,6 @@ class RunnerTest(unittest.TestCase):
         print(i, pg.diff(x, y))
      self.assertIs(x, y)
 
-
-class ParallelRunnerTest(RunnerTest):
-
   def test_parallel_runner(self):
     plugin = eval_test_helper.TestPlugin()
     exp = eval_test_helper.test_experiment()

--- a/langfun/core/eval/v2/runners/sequential_test.py
+++ b/langfun/core/eval/v2/runners/sequential_test.py
@@ -23,7 +23,7 @@ from langfun.core.eval.v2.runners import sequential  # pylint: disable=unused-import
 import pyglove as pg
 
 
-class RunnerTest(unittest.TestCase):
+class SequentialRunnerTest(unittest.TestCase):
 
   def assert_same_list(self, actual: list[Any], expected: list[Any]):
     self.assertEqual(len(actual), len(expected))
@@ -32,9 +32,6 @@ class RunnerTest(unittest.TestCase):
         print(i, pg.diff(x, y))
      self.assertIs(x, y)
 
-
-class SequentialRunnerTest(RunnerTest):
-
   def test_basic(self):
     plugin = eval_test_helper.TestPlugin()
     exp = eval_test_helper.test_experiment()

--- a/langfun/core/langfunc_test.py
+++ b/langfun/core/langfunc_test.py
@@ -109,9 +109,9 @@ class LangFuncCallTest(unittest.TestCase):
         ' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=None,'
         ' max_tokens=None, n=1, top_k=40, top_p=None, stop=None,'
         ' random_seed=None, logprobs=False, top_logprobs=None,'
-        ' max_thinking_tokens=None, reasoning_effort=None, extras={}),'
-        ' cache=None, max_concurrency=None, timeout=120.0, max_attempts=5,'
-        ' retry_interval=(5, 60), exponential_backoff=True,'
+        ' max_thinking_tokens=None, thinking_level=None, reasoning_effort=None,'
+        ' extras={}), cache=None, max_concurrency=None, timeout=120.0,'
+        ' max_attempts=5, retry_interval=(5, 60), exponential_backoff=True,'
         ' max_retry_interval=300, debug=False))',
     )
 

--- a/langfun/core/language_model.py
+++ b/langfun/core/language_model.py
@@ -53,6 +53,10 @@ class RetryableLMError(LMError):
   """Base class for LLM errors that can be solved by retrying."""
 
 
+class EmptyGenerationError(RetryableLMError):
+  """Error for empty generation."""
+
+
 class RateLimitError(RetryableLMError):
   """Error for rate limit reached."""
 
@@ -575,6 +579,14 @@ class LMSamplingOptions(component.Component):
       int | None, 'Number of max thinking tokens.'
   ] = None
 
+  thinking_level: Annotated[
+      Literal['low', 'high'] | None,
+      (
+          'Thinking level for Gemini models. High is for complex tasks, '
+          'while low is for faster responses.'
+      ),
+  ] = None
+
   reasoning_effort: Annotated[
       Literal['low', 'medium', 'high'] | None,
       (
@@ -1076,10 +1088,32 @@ class LanguageModel(component.Component):
     prompts = [message_lib.UserMessage.from_value(p) for p in prompts]
 
     with component.context(override_attrs=True, **kwargs):
-      if self.cache is None:
-        results = self._sample(prompts)
-      else:
-        results = self._sample_with_cache_lookup(prompts, cache_seed)
+
+      def _sample_with_retry():
+        if self.cache is None:
+          results = self._sample(prompts)
+        else:
+          results = self._sample_with_cache_lookup(prompts, cache_seed)
+
+        for i, result in enumerate(results):
+          for sample in result.samples:
+            if not sample.response.text:
+              if self.cache is not None:
+                self.cache.delete(self, prompts[i], seed=cache_seed)
+              raise EmptyGenerationError(
+                  f'Empty generation encountered from model {self.model_id}.'
+              )
+        return results
+
+      retry_fn = concurrent.with_retry(
+          _sample_with_retry,
+          retry_on_errors=EmptyGenerationError,
+          max_attempts=self.max_attempts,
+          retry_interval=self.retry_interval,
+          exponential_backoff=self.exponential_backoff,
+          max_retry_interval=self.max_retry_interval,
+      )
+      results = retry_fn()
 
       for prompt, result in zip(prompts, results):
 
@@ -1088,7 +1122,6 @@
 
         for sample in result.samples:
           # Update metadata for response message.
-
           response = sample.response
           response.metadata.score = sample.score
           response.metadata.logprobs = sample.logprobs
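
Taken together, the language_model.py hunks change what callers observe when a backend returns empty text: sampling now raises EmptyGenerationError (a RetryableLMError), evicts any cached entry for that prompt, and retries under the model's existing max_attempts / retry_interval / exponential_backoff settings; only once the attempts are exhausted does the failure surface, wrapped in concurrent.RetryError, as the tests below verify. A minimal, hypothetical caller-side sketch (the model class is only an example and assumes credentials such as GOOGLE_API_KEY are configured):

import langfun as lf
from langfun.core import concurrent

# Any LanguageModel subclass gets the new behavior; Gemini25Flash is used
# here purely as an illustration.
lm = lf.llms.Gemini25Flash(max_attempts=3, retry_interval=(5, 60))
try:
  response = lm('Summarize the attached report.')
except concurrent.RetryError:
  # Reached only after `max_attempts` consecutive empty generations.
  response = None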

--- a/langfun/core/language_model_test.py
+++ b/langfun/core/language_model_test.py
@@ -591,6 +591,51 @@ class LanguageModelTest(unittest.TestCase):
     lm = MockModel(cache=cache, top_k=1)
     self.assertEqual(lm('a'), 'a')
 
+  def test_empty_generation_error(self):
+    class MockModelWithEmptyResponse(MockModel):
+      def _sample(self,
+                  prompts: list[message_lib.Message]
+                  ) -> list[lm_lib.LMSamplingResult]:
+        return [lm_lib.LMSamplingResult(
+            [lm_lib.LMSample(response='')],
+            usage=lm_lib.LMSamplingUsage(100, 0, 100, 1, 1.0)
+        )]
+    lm = MockModelWithEmptyResponse(max_attempts=1, retry_interval=0)
+    with self.assertRaisesRegex(
+        concurrent.RetryError, 'Empty generation encountered'
+    ):
+      lm('a')
+
+  def test_empty_generation_retry(self):
+    class MockModelWithEmptyThenValid(MockModel):
+      attempt_count: int = 0
+
+      def _sample(
+          self, prompts: list[message_lib.Message]
+      ) -> list[lm_lib.LMSamplingResult]:
+        self.rebind(attempt_count=self.attempt_count + 1)
+        if self.attempt_count == 1:
+          # First attempt returns empty
+          return [
+              lm_lib.LMSamplingResult(
+                  [lm_lib.LMSample(response='')],
+                  usage=lm_lib.LMSamplingUsage(100, 0, 100, 1, 1.0),
+              )
+          ]
+        else:
+          # Subsequent attempts return valid response
+          return [
+              lm_lib.LMSamplingResult(
+                  [lm_lib.LMSample(response='valid response')],
+                  usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+              )
+          ]
+
+    lm = MockModelWithEmptyThenValid(max_attempts=3, retry_interval=0)
+    result = lm('a')
+    self.assertEqual(result.text, 'valid response')
+    self.assertEqual(lm.attempt_count, 2)
+
   def test_estimate_max_concurrency(self):
     self.assertIsNone(lm_lib.LanguageModel.estimate_max_concurrency(None, None))
     self.assertEqual(

--- a/langfun/core/llms/__init__.py
+++ b/langfun/core/llms/__init__.py
@@ -42,6 +42,7 @@ from langfun.core.llms.azure_openai import AzureOpenAI
 
 # Gemini models.
 from langfun.core.llms.google_genai import GenAI
+from langfun.core.llms.google_genai import Gemini3ProPreview
 from langfun.core.llms.google_genai import Gemini25Pro
 from langfun.core.llms.google_genai import Gemini25Flash
 from langfun.core.llms.google_genai import Gemini25ProPreview_20250605
@@ -90,6 +91,7 @@ from langfun.core.llms.vertexai import VertexAIGemini25ProPreview_20250605
 from langfun.core.llms.vertexai import VertexAIGemini25Pro
 from langfun.core.llms.vertexai import VertexAIGemini25Flash
 from langfun.core.llms.vertexai import VertexAIGemini25FlashImagePreview
+from langfun.core.llms.vertexai import VertexAIGemini3ProPreview
 
 # For backward compatibility.
 GeminiPro1_5 = Gemini15Pro
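
Finally, the new Gemini 3 Pro preview exports pair with the thinking_level sampling option added to LMSamplingOptions above. A minimal, hypothetical sketch follows; it assumes Google API credentials are available and that, as with other langfun models, sampling-option fields can be passed through the model constructor:

import langfun as lf

# Assumes GOOGLE_API_KEY (or an explicit api_key argument) is configured.
# thinking_level accepts 'low' or 'high': 'high' targets complex tasks,
# 'low' favors faster responses.
lm = lf.llms.Gemini3ProPreview(thinking_level='high')

# Ordinary langfun calls work unchanged.
answer = lf.query('Compute 12 * 37 and return only the number.', int, lm=lm)
print(answer)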