langfun 0.1.2.dev202510230805__py3-none-any.whl → 0.1.2.dev202511270805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: the registry has flagged this version of langfun as possibly problematic; see the package's registry page for details.

Files changed (155)
  1. langfun/core/__init__.py +2 -0
  2. langfun/core/agentic/__init__.py +4 -1
  3. langfun/core/agentic/action.py +447 -29
  4. langfun/core/agentic/action_eval.py +9 -2
  5. langfun/core/agentic/action_test.py +149 -21
  6. langfun/core/async_support.py +32 -3
  7. langfun/core/coding/python/correction.py +19 -9
  8. langfun/core/coding/python/execution.py +14 -12
  9. langfun/core/coding/python/generation.py +21 -16
  10. langfun/core/coding/python/sandboxing.py +23 -3
  11. langfun/core/component.py +42 -3
  12. langfun/core/concurrent.py +70 -6
  13. langfun/core/concurrent_test.py +1 -0
  14. langfun/core/console.py +1 -1
  15. langfun/core/data/conversion/anthropic.py +12 -3
  16. langfun/core/data/conversion/anthropic_test.py +8 -6
  17. langfun/core/data/conversion/gemini.py +9 -2
  18. langfun/core/data/conversion/gemini_test.py +12 -9
  19. langfun/core/data/conversion/openai.py +145 -31
  20. langfun/core/data/conversion/openai_test.py +161 -17
  21. langfun/core/eval/base.py +47 -43
  22. langfun/core/eval/base_test.py +5 -5
  23. langfun/core/eval/matching.py +5 -2
  24. langfun/core/eval/patching.py +3 -3
  25. langfun/core/eval/scoring.py +4 -3
  26. langfun/core/eval/v2/__init__.py +1 -0
  27. langfun/core/eval/v2/checkpointing.py +64 -6
  28. langfun/core/eval/v2/checkpointing_test.py +9 -2
  29. langfun/core/eval/v2/eval_test_helper.py +103 -2
  30. langfun/core/eval/v2/evaluation.py +91 -16
  31. langfun/core/eval/v2/evaluation_test.py +9 -3
  32. langfun/core/eval/v2/example.py +50 -40
  33. langfun/core/eval/v2/example_test.py +16 -8
  34. langfun/core/eval/v2/experiment.py +74 -8
  35. langfun/core/eval/v2/experiment_test.py +19 -0
  36. langfun/core/eval/v2/metric_values.py +31 -3
  37. langfun/core/eval/v2/metric_values_test.py +32 -0
  38. langfun/core/eval/v2/metrics.py +157 -44
  39. langfun/core/eval/v2/metrics_test.py +39 -18
  40. langfun/core/eval/v2/progress.py +30 -1
  41. langfun/core/eval/v2/progress_test.py +27 -0
  42. langfun/core/eval/v2/progress_tracking.py +12 -3
  43. langfun/core/eval/v2/progress_tracking_test.py +6 -1
  44. langfun/core/eval/v2/reporting.py +90 -71
  45. langfun/core/eval/v2/reporting_test.py +24 -6
  46. langfun/core/eval/v2/runners/__init__.py +30 -0
  47. langfun/core/eval/v2/{runners.py → runners/base.py} +59 -142
  48. langfun/core/eval/v2/runners/beam.py +341 -0
  49. langfun/core/eval/v2/runners/beam_test.py +131 -0
  50. langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
  51. langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
  52. langfun/core/eval/v2/runners/debug.py +40 -0
  53. langfun/core/eval/v2/runners/debug_test.py +76 -0
  54. langfun/core/eval/v2/runners/parallel.py +100 -0
  55. langfun/core/eval/v2/runners/parallel_test.py +95 -0
  56. langfun/core/eval/v2/runners/sequential.py +47 -0
  57. langfun/core/eval/v2/runners/sequential_test.py +172 -0
  58. langfun/core/langfunc.py +45 -130
  59. langfun/core/langfunc_test.py +7 -5
  60. langfun/core/language_model.py +141 -21
  61. langfun/core/language_model_test.py +54 -3
  62. langfun/core/llms/__init__.py +9 -1
  63. langfun/core/llms/anthropic.py +157 -2
  64. langfun/core/llms/azure_openai.py +29 -17
  65. langfun/core/llms/cache/base.py +25 -3
  66. langfun/core/llms/cache/in_memory.py +48 -7
  67. langfun/core/llms/cache/in_memory_test.py +14 -4
  68. langfun/core/llms/compositional.py +25 -1
  69. langfun/core/llms/deepseek.py +30 -2
  70. langfun/core/llms/fake.py +32 -1
  71. langfun/core/llms/gemini.py +55 -17
  72. langfun/core/llms/gemini_test.py +84 -0
  73. langfun/core/llms/google_genai.py +34 -1
  74. langfun/core/llms/groq.py +28 -3
  75. langfun/core/llms/llama_cpp.py +23 -4
  76. langfun/core/llms/openai.py +36 -3
  77. langfun/core/llms/openai_compatible.py +148 -27
  78. langfun/core/llms/openai_compatible_test.py +207 -20
  79. langfun/core/llms/openai_test.py +0 -2
  80. langfun/core/llms/rest.py +12 -1
  81. langfun/core/llms/vertexai.py +58 -8
  82. langfun/core/logging.py +1 -1
  83. langfun/core/mcp/client.py +77 -22
  84. langfun/core/mcp/client_test.py +8 -35
  85. langfun/core/mcp/session.py +94 -29
  86. langfun/core/mcp/session_test.py +54 -0
  87. langfun/core/mcp/tool.py +151 -22
  88. langfun/core/mcp/tool_test.py +197 -0
  89. langfun/core/memory.py +1 -0
  90. langfun/core/message.py +160 -55
  91. langfun/core/message_test.py +65 -81
  92. langfun/core/modalities/__init__.py +8 -0
  93. langfun/core/modalities/audio.py +21 -1
  94. langfun/core/modalities/image.py +19 -1
  95. langfun/core/modalities/mime.py +64 -3
  96. langfun/core/modalities/mime_test.py +11 -0
  97. langfun/core/modalities/pdf.py +19 -1
  98. langfun/core/modalities/video.py +21 -1
  99. langfun/core/modality.py +167 -29
  100. langfun/core/modality_test.py +42 -12
  101. langfun/core/natural_language.py +1 -1
  102. langfun/core/sampling.py +4 -4
  103. langfun/core/sampling_test.py +20 -4
  104. langfun/core/structured/__init__.py +2 -24
  105. langfun/core/structured/completion.py +34 -44
  106. langfun/core/structured/completion_test.py +23 -43
  107. langfun/core/structured/description.py +54 -50
  108. langfun/core/structured/function_generation.py +29 -12
  109. langfun/core/structured/mapping.py +81 -37
  110. langfun/core/structured/parsing.py +95 -79
  111. langfun/core/structured/parsing_test.py +0 -3
  112. langfun/core/structured/querying.py +215 -142
  113. langfun/core/structured/querying_test.py +65 -29
  114. langfun/core/structured/schema/__init__.py +49 -0
  115. langfun/core/structured/schema/base.py +664 -0
  116. langfun/core/structured/schema/base_test.py +531 -0
  117. langfun/core/structured/schema/json.py +174 -0
  118. langfun/core/structured/schema/json_test.py +121 -0
  119. langfun/core/structured/schema/python.py +316 -0
  120. langfun/core/structured/schema/python_test.py +410 -0
  121. langfun/core/structured/schema_generation.py +33 -14
  122. langfun/core/structured/scoring.py +47 -36
  123. langfun/core/structured/tokenization.py +26 -11
  124. langfun/core/subscription.py +2 -2
  125. langfun/core/template.py +174 -49
  126. langfun/core/template_test.py +123 -17
  127. langfun/env/__init__.py +8 -2
  128. langfun/env/base_environment.py +320 -128
  129. langfun/env/base_environment_test.py +473 -0
  130. langfun/env/base_feature.py +92 -15
  131. langfun/env/base_feature_test.py +228 -0
  132. langfun/env/base_sandbox.py +84 -361
  133. langfun/env/base_sandbox_test.py +1235 -0
  134. langfun/env/event_handlers/__init__.py +1 -1
  135. langfun/env/event_handlers/chain.py +233 -0
  136. langfun/env/event_handlers/chain_test.py +253 -0
  137. langfun/env/event_handlers/event_logger.py +95 -98
  138. langfun/env/event_handlers/event_logger_test.py +21 -21
  139. langfun/env/event_handlers/metric_writer.py +225 -140
  140. langfun/env/event_handlers/metric_writer_test.py +23 -6
  141. langfun/env/interface.py +854 -40
  142. langfun/env/interface_test.py +112 -2
  143. langfun/env/load_balancers_test.py +23 -2
  144. langfun/env/test_utils.py +126 -84
  145. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/METADATA +1 -1
  146. langfun-0.1.2.dev202511270805.dist-info/RECORD +215 -0
  147. langfun/core/eval/v2/runners_test.py +0 -343
  148. langfun/core/structured/schema.py +0 -987
  149. langfun/core/structured/schema_test.py +0 -982
  150. langfun/env/base_test.py +0 -1481
  151. langfun/env/event_handlers/base.py +0 -350
  152. langfun-0.1.2.dev202510230805.dist-info/RECORD +0 -195
  153. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/WHEEL +0 -0
  154. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/licenses/LICENSE +0 -0
  155. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/top_level.txt +0 -0

langfun/core/eval/v2/runners/ckpt_monitor.py
@@ -0,0 +1,294 @@
+# Copyright 2025 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Checkpoint aggregator for Langfun evaluations."""
+
+import concurrent.futures
+import dataclasses
+import os
+import threading
+import time
+from typing import Annotated, Iterator
+
+from langfun.core.eval.v2 import evaluation as evaluation_lib
+from langfun.core.eval.v2 import example as example_lib
+from langfun.core.eval.v2 import reporting
+from langfun.core.eval.v2.runners import base
+
+import pyglove as pg
+
+
+class CheckpointMonitor(base.RunnerBase):
+  """Runner for monitoring checkpoint files generated by other runners.
+
+  Currently, the checkpoint monitor only supports aggregating per-example
+  checkpoint files.
+  """
+
+  NAME = 'checkpoint_monitor'
+
+  plugins = [
+      reporting.HtmlReporter(),
+  ]
+
+  checkpoint_pattern: Annotated[
+      str, 'The glob pattern of the checkpoint files to monitor.'
+  ] = 'checkpoint_*.bagz'
+
+  monitor_inprogress_files: Annotated[
+      bool,
+      'If True, monitor in-progress files to aggregate.'
+  ] = False
+
+  poll_interval: Annotated[
+      int,
+      'The interval in seconds to poll for new checkpoint files.'
+  ] = 5
+
+  max_aggregation_threads: Annotated[
+      int,
+      'The maximum number of threads to aggregate checkpoints.'
+  ] = 128
+
+  @dataclasses.dataclass
+  class _AggregationEntry:
+    evaluation: evaluation_lib.Evaluation
+    output_dir: str
+    inprogress_file_pattern: str | None
+    ckpt_file_pattern: str
+    example_ids_inprogress: set[int]
+    example_ids_to_be_aggregated: set[int]
+    example_ids_being_aggregated: set[int]
+    completion_lock: threading.Lock
+    is_completed: bool = False
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._monitor_thread = None
+    self._aggregation_entries = []
+    self._aggregator_pool = None
+    self._error = None
+
+  def start(self):
+    # Reset the experiment state before getting started.
+    self.current_run.experiment.reset()
+
+    # Signal the start of the run.
+    self.on_run_start()
+
+    # Start the non-leaf nodes.
+    for node in self.current_run.experiment.nonleaf_nodes:
+      self.on_experiment_start(node)
+
+    for evaluation in self.current_run.experiment.leaf_nodes:
+      # This is not precise, but we at least notify example start.
+      if not self.current_run.filter or self.current_run.filter(evaluation):
+        self.on_experiment_start(evaluation)
+
+        # Signal the start of the examples if we are not monitoring
+        # in-progress files.
+        if not self.monitor_inprogress_files:
+          for example_id in self.current_run.examples_to_evaluate(evaluation):
+            self._mark_example_started(evaluation, example_id)
+
+        # Create the aggregation entries for polling.
+        output_dir = self.current_run.output_dir(evaluation)
+        self._aggregation_entries.append(
+            self._AggregationEntry(
+                evaluation=evaluation,
+                output_dir=output_dir,
+                ckpt_file_pattern=os.path.join(
+                    output_dir, self.checkpoint_pattern
+                ),
+                inprogress_file_pattern=os.path.join(
+                    output_dir, '*.inprogress'
+                ) if self.monitor_inprogress_files else None,
+                example_ids_to_be_aggregated=(
+                    self.current_run.examples_to_evaluate(evaluation)
+                ),
+                example_ids_inprogress=set(),
+                example_ids_being_aggregated=set(),
+                completion_lock=threading.Lock(),
+                is_completed=False,
+            )
+        )
+      else:
+        self.on_experiment_skipped(evaluation)
+
+    self._aggregator_pool = concurrent.futures.ThreadPoolExecutor(
+        max_workers=self.max_aggregation_threads
+    )
+    self._monitor_thread = threading.Thread(target=self._monitor_loop)
+    self._monitor_thread.start()
+
+  def join(self):
+    if self._monitor_thread:
+      self._monitor_thread.join()
+    if self._error is not None:
+      raise self._error
+
+  def run(self):
+    self.start()
+    self.join()
+
+  def _monitor_loop(self):
+    while not self._error and any(
+        not e.is_completed for e in self._aggregation_entries
+    ):
+      for entry in self._aggregation_entries:
+        if not entry.example_ids_to_be_aggregated:
+          continue
+
+        # Signal example processing.
+        if self.monitor_inprogress_files:
+          inprogress_files = pg.io.glob(entry.inprogress_file_pattern)
+          for inprogress_file in inprogress_files:
+            example_id = int(
+                os.path.basename(inprogress_file).split('.')[0]
+            )
+            if example_id not in entry.example_ids_inprogress:
+              self._mark_example_started(entry.evaluation, example_id)
+              entry.example_ids_inprogress.add(example_id)
+
+        for filepath in pg.io.glob(entry.ckpt_file_pattern):
+          example_id = int(
+              os.path.basename(filepath).split('.')[0].split('_')[-1]
+          )
+          if example_id in entry.example_ids_to_be_aggregated:
+            # Remove example ID from the set to avoid duplicate processing.
+            entry.example_ids_to_be_aggregated.remove(example_id)
+            entry.example_ids_being_aggregated.add(example_id)
+
+            # It could be that the example has been processed before, but the
+            # inprogress file was removed. In this case, we should signal the
+            # example has started before completing it.
+            if example_id not in entry.example_ids_inprogress:
+              self._mark_example_started(entry.evaluation, example_id)
+              entry.example_ids_inprogress.add(example_id)
+
+            self._aggregator_pool.submit(
+                self._aggregate, entry, filepath, example_id
+            )
+            pg.logging.info(
+                '[%s] Aggregating example %d from %s...',
+                entry.evaluation.id,
+                example_id,
+                filepath,
+            )
+      time.sleep(self.poll_interval)
+
+    if self._error is None:
+      self.on_run_complete()
+    else:
+      self.on_run_abort(self._error)
+
+  def _aggregate(
+      self,
+      entry: _AggregationEntry,
+      ckpt_filepath: str,
+      example_id: int
+  ):
+    """Aggregates an example from a checkpoint file."""
+    try:
+      loaded_examples = entry.evaluation.state.load(
+          ckpt_filepath,
+          example_input_by_id=entry.evaluation.example_input_by_id,
+          # Example metadata may be expensive to load, and is not used by
+          # metric aggregation. Thus we do not load example metadata.
+          load_example_metadata=False
+      )
+      assert len(loaded_examples) >= 1, loaded_examples
+      # Occasionally the per-example checkpoint file may contain the same
+      # example processed multiple times. We only need to aggregate the last
+      # example.
+      example = loaded_examples[-1]
+    except BaseException as e:  # pylint: disable=broad-except
+      error_info = pg.ErrorInfo.from_exception(e)
+      pg.logging.error(
+          '[%s] Failed to aggregate example %d: %s',
+          entry.evaluation.id,
+          example_id,
+          error_info
+      )
+      example = example_lib.Example(
+          id=example_id,
+          input=entry.evaluation.example_input_by_id(example_id),
+          error=error_info,
+      )
+
+    # This will skip processing but still allow metrics to be collected.
+    # `process` will never be called for evaluation, thus we do not
+    # need to setup/teardown evaluation.
+    example = entry.evaluation.evaluate(
+        example, reevaluate_upon_previous_errors=False
+    )
+    example.newly_processed = True
+    pg.logging.info(
+        '[%s] Successfully aggregated example %d from %s.',
+        entry.evaluation.id,
+        example_id,
+        ckpt_filepath,
+    )
+
+    try:
+      self.on_example_complete(entry.evaluation, example)
+    except BaseException as e:  # pylint: disable=broad-except
+      # Plugin failures should be raised to the user.
+      self._error = e
+
+    entry.example_ids_being_aggregated.remove(example_id)
+
+    # Remove the in-progress file to indicate that the example has been
+    # processed.
+    try:
+      pg.io.rm(os.path.join(entry.output_dir, f'{example_id}.inprogress'))
+    except FileNotFoundError:
+      pass
+
+    if (not self._error
+        and not entry.example_ids_to_be_aggregated
+        and not entry.example_ids_being_aggregated):
+      with entry.completion_lock:
+        if not entry.is_completed:
+          entry.is_completed = True
+          try:
+            self.on_experiment_complete(entry.evaluation)
+          except BaseException as e:  # pylint: disable=broad-except
+            # Plugin failures should be raised to the user.
+            self._error = e
+
+  def _mark_example_started(
+      self,
+      evaluation: evaluation_lib.Evaluation,
+      example_id: int
+  ) -> None:
+    """Marks an example as started."""
+    example = example_lib.Example(
+        id=example_id, input=evaluation.example_input_by_id(example_id),
+    )
+    example.start_time = time.time()
+    self.on_example_start(evaluation, example)
+
+    # We update the evaluation state with the in-progress status so the
+    # evaluation HTML can show remotely in-progress examples.
+    evaluation.state.update(example, in_progress=True)
+
+  def _run(self, evaluations: list[evaluation_lib.Evaluation]):
+    raise NotImplementedError('Not needed in checkpoint monitor.')
+
+  def _evaluate_items(
+      self,
+      evaluation: evaluation_lib.Evaluation,
+      items: Iterator[example_lib.Example]
+  ) -> None:
+    raise NotImplementedError('Not needed in checkpoint monitor.')
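
For orientation, the monitor is wired up through the same runner interface as everything else: an ordinary runner writes per-example checkpoint files via checkpointing.PerExampleCheckpointer, and a CheckpointMonitor constructed over the same run polls the output directory and aggregates the results. A minimal sketch, mirroring the usage exercised by ckpt_monitor_test.py below (the eval_test_helper experiment is a test fixture, used here purely for illustration):

    import os
    import tempfile

    from langfun.core.eval.v2 import checkpointing
    from langfun.core.eval.v2 import eval_test_helper
    from langfun.core.eval.v2.runners import ckpt_monitor
    from langfun.core.eval.v2.runners import sequential  # registers the 'sequential' runner

    # Produce per-example checkpoint files with an ordinary runner.
    exp = eval_test_helper.test_experiment()
    run = exp.run(
        os.path.join(tempfile.mkdtemp(), 'demo'),
        runner='sequential',
        plugins=[
            checkpointing.PerExampleCheckpointer(
                checkpoint_filename='checkpoint.jsonl'
            )
        ],
        use_cache='no',
    )

    # Aggregate the checkpoint files; run() blocks until all expected
    # examples are aggregated, then reports via the default HtmlReporter.
    ckpt_monitor.CheckpointMonitor(
        run,
        checkpoint_pattern='checkpoint_*.jsonl',
    ).run()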

langfun/core/eval/v2/runners/ckpt_monitor_test.py
@@ -0,0 +1,162 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import tempfile
+import unittest
+
+from langfun.core.eval.v2 import checkpointing
+from langfun.core.eval.v2 import eval_test_helper
+from langfun.core.eval.v2 import example as example_lib
+from langfun.core.eval.v2 import experiment as experiment_lib
+from langfun.core.eval.v2.runners import ckpt_monitor
+from langfun.core.eval.v2.runners import sequential  # pylint: disable=unused-import
+import pyglove as pg
+
+
+class CheckpointMonitorTest(unittest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.test_dir = tempfile.mkdtemp()
+
+  def test_aggregate(self):
+    exp = eval_test_helper.test_experiment()
+    root_dir = os.path.join(self.test_dir, 'test_aggregate')
+    run = exp.run(
+        root_dir,
+        runner='sequential',
+        progress_tracker=None,
+        plugins=[
+            checkpointing.PerExampleCheckpointer(
+                checkpoint_filename='checkpoint.jsonl'
+            )
+        ],
+        use_cache='no',
+    )
+    # Try to corrupt one of the checkpoint files.
+    pg.io.writefile(
+        run.output_path_for(exp.leaf_nodes[0], 'checkpoint_1.jsonl'),
+        'bad ckpt'
+    )
+    plugin = eval_test_helper.TestPlugin()
+    monitor = ckpt_monitor.CheckpointMonitor(
+        run,
+        plugins=[plugin],
+        checkpoint_pattern='checkpoint_*.jsonl',
+        monitor_inprogress_files=True,
+    )
+    monitor.run()
+
+    # Assert that the in-progress files are created and not removed.
+    for entry in monitor._aggregation_entries:
+      self.assertEqual(len(entry.example_ids_inprogress), 10)
+
+    # 6 leaf nodes + 1 suite + 1 hyper.
+    self.assertEqual(len(plugin.started_experiments), 6 + 2)
+    self.assertEqual(len(plugin.completed_experiments), 6 + 2)
+    self.assertEqual(len(plugin.started_example_ids), 10 * 6)
+    self.assertEqual(len(plugin.completed_example_ids), 10 * 6)
+    for e in exp.leaf_nodes:
+      self.assertEqual(e.progress.num_completed, 10)
+
+  def test_aggregate_with_filter(self):
+    exp = eval_test_helper.test_experiment()
+    root_dir = os.path.join(self.test_dir, 'test_aggregate_with_filter')
+
+    node_to_skip = exp.leaf_nodes[2]
+    # Run experiment to generate checkpoint files for all examples.
+    run = exp.run(
+        root_dir,
+        runner='sequential',
+        filter=lambda e: e.id != node_to_skip.id,
+        progress_tracker=None,
+        plugins=[
+            checkpointing.PerExampleCheckpointer(
+                checkpoint_filename='checkpoint.jsonl'
+            )
+        ],
+        use_cache='no',
+    )
+    plugin = eval_test_helper.TestPlugin()
+    monitor = ckpt_monitor.CheckpointMonitor(
+        run,
+        plugins=[plugin],
+        checkpoint_pattern='checkpoint_*.jsonl',
+    )
+    monitor.run()
+
+    # Assert that on_experiment_skipped was called for the filtered node.
+    self.assertEqual(len(plugin.skipped_experiments), 1)
+    self.assertEqual(plugin.skipped_experiments[0].id, node_to_skip.id)
+
+    # Assert that the skipped node was not started.
+    started_ids = [e.id for e in plugin.started_experiments]
+    self.assertNotIn(node_to_skip.id, started_ids)
+
+  def test_plugin_raise(self):
+
+    class TestPlugin(eval_test_helper.TestPlugin):
+      simulate_raise_on_example_complete: bool = False
+      simulate_raise_on_experiment_complete: bool = False
+
+      def on_example_complete(
+          self,
+          runner: experiment_lib.Runner,
+          experiment: experiment_lib.Experiment,
+          example: example_lib.Example
+      ):
+        if self.simulate_raise_on_example_complete:
+          raise ValueError('example complete error')
+
+      def on_experiment_complete(
+          self,
+          runner: experiment_lib.Runner,
+          experiment: experiment_lib.Experiment
+      ):
+        if self.simulate_raise_on_experiment_complete:
+          raise ValueError('experiment complete error')
+
+    exp = eval_test_helper.test_evaluation()
+    root_dir = os.path.join(self.test_dir, 'test_plugin_raise')
+
+    # Run experiment to generate checkpoint files for all examples.
+    run = exp.run(
+        root_dir,
+        runner='sequential',
+        progress_tracker=None,
+        plugins=[
+            checkpointing.PerExampleCheckpointer(
+                checkpoint_filename='checkpoint.jsonl'
+            )
+        ],
+        use_cache='no',
+    )
+
+    with self.assertRaisesRegex(ValueError, 'example complete error'):
+      ckpt_monitor.CheckpointMonitor(
+          run,
+          plugins=[TestPlugin(simulate_raise_on_example_complete=True)],
+          checkpoint_pattern='checkpoint_*.jsonl',
+      ).run()
+
+    with self.assertRaisesRegex(ValueError, 'experiment complete error'):
+      ckpt_monitor.CheckpointMonitor(
+          run,
+          plugins=[TestPlugin(simulate_raise_on_experiment_complete=True)],
+          checkpoint_pattern='checkpoint_*.jsonl',
+      ).run()
+
+
+if __name__ == '__main__':
+  unittest.main()

langfun/core/eval/v2/runners/debug.py
@@ -0,0 +1,40 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Debug runner."""
+
+from langfun.core.eval.v2.runners import sequential
+
+
+class DebugRunner(sequential.SequentialRunner):
+  """A runner for debugging evaluations.
+
+  The debug runner is a sequential runner that only runs the first example
+  of each evaluation, with `raise_if_has_error` enabled. This is useful for
+  quickly identifying issues in evaluation logic during development.
+  Checkpointers are disabled for this runner.
+  """
+
+  NAME = 'debug'
+
+  # Do not use the checkpointer for the debug runner.
+  plugins = []
+
+  def _on_bound(self):
+    super()._on_bound()
+    if self.current_run.example_ids is None:
+      self.current_run.rebind(example_ids=[1], skip_notification=True)
+    self.current_run.rebind(raise_if_has_error=True, skip_notification=True)
+
+  def _save_run_manifest(self) -> None:
+    """Does nothing, to avoid overriding existing runs."""

langfun/core/eval/v2/runners/debug_test.py
@@ -0,0 +1,76 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for debug runner."""
+import os
+import tempfile
+from typing import Any
+import unittest
+
+from langfun.core.eval.v2 import eval_test_helper
+from langfun.core.eval.v2.runners import debug  # pylint: disable=unused-import
+
+import pyglove as pg
+
+
+class DebugRunnerTest(unittest.TestCase):
+
+  def assert_same_list(self, actual: list[Any], expected: list[Any]):
+    self.assertEqual(len(actual), len(expected))
+    for i, (x, y) in enumerate(zip(actual, expected)):
+      if x is not y:
+        print(i, pg.diff(x, y))
+      self.assertIs(x, y)
+
+  def test_debug_runner(self):
+    plugin = eval_test_helper.TestPlugin()
+    exp = eval_test_helper.test_experiment()
+    root_dir = os.path.join(tempfile.mkdtemp(), 'test_debug_runner')
+    run = exp.run(root_dir, runner='debug', plugins=[plugin])
+
+    self.assertIsNotNone(plugin.start_time)
+    self.assertIsNotNone(plugin.complete_time)
+    self.assertGreater(plugin.complete_time, plugin.start_time)
+
+    self.assertEqual(
+        len(plugin.started_experiments), len(exp.nodes)
+    )
+    self.assertEqual(
+        len(plugin.completed_experiments), len(exp.nodes)
+    )
+    self.assertEqual(
+        len(plugin.started_example_ids), 6 * 1
+    )
+    self.assertEqual(
+        len(plugin.completed_example_ids), 6 * 1
+    )
+    self.assert_same_list(plugin.skipped_experiments, [])
+    self.assertFalse(
+        pg.io.path_exists(os.path.join(run.output_root, 'run.json'))
+    )
+
+    for node in exp.nodes:
+      self.assertTrue(node.progress.is_started)
+      self.assertTrue(node.progress.is_completed)
+      if node.is_leaf:
+        self.assertEqual(node.progress.num_skipped, 0)
+        self.assertEqual(node.progress.num_completed, 1)
+        self.assertEqual(node.progress.num_failed, 0)
+      else:
+        self.assertEqual(node.progress.num_skipped, 0)
+        self.assertEqual(node.progress.num_failed, 0)
+        self.assertEqual(node.progress.num_processed, node.progress.num_total)
+
+
+if __name__ == '__main__':
+  unittest.main()

langfun/core/eval/v2/runners/parallel.py
@@ -0,0 +1,100 @@
+# Copyright 2025 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Parallel runner."""
+
+import collections
+import random
+import threading
+import time
+
+from typing import Annotated, Iterator
+import langfun.core as lf
+from langfun.core.eval.v2.runners import base
+
+
+class ParallelRunner(base.RunnerBase):
+  """A runner that executes evaluations and examples in parallel.
+
+  The parallel runner groups evaluations by their required resources
+  (e.g., specific LLMs) and runs evaluations that do not share resources in
+  parallel. Within each evaluation, examples are also processed in parallel
+  using threads, up to `Evaluation.max_workers`.
+  """
+
+  NAME = 'parallel'
+
+  timeout: Annotated[
+      int | None,
+      'Timeout for each evaluation example.'
+  ] = None
+
+  concurrent_startup_delay: Annotated[
+      tuple[int, int] | None,
+      (
+          'A range of seconds to delay the initial evaluation of each thread '
+          'in the thread pool, helping to prevent a burst in LLM QPS at '
+          'startup. If set to None, no delay will be applied.'
+      )
+  ] = None
+
+  def _run(self, evaluations: list[base.Evaluation]) -> None:
+    """Runs the evaluations in parallel."""
+    def _run_group(evaluation_group: list[base.Evaluation]):
+      for e in evaluation_group:
+        self.run_evaluation(e)
+
+    # Run evaluations in parallel, grouped by resource key.
+    groups: dict[str, list[base.Evaluation]] = collections.defaultdict(list)
+    for e in evaluations:
+      resource_ids = e.resource_ids()
+      if not resource_ids:
+        group_id = e.id
+      else:
+        # TODO(daiyip): support group that requires multiple resources.
+        group_id = resource_ids.pop()
+      groups[group_id].append(e)
+
+    for _, _, _ in lf.concurrent_map(
+        _run_group,
+        groups.values(),
+        max_workers=max(64, len(groups)),
+        timeout=self.timeout,
+        silence_on_errors=None,
+    ):
+      pass
+
+  def _evaluate_items(
+      self, evaluation: base.Evaluation, items: Iterator[base.Example]
+  ) -> None:
+    """Overrides item evaluation to run items in parallel."""
+    if self.concurrent_startup_delay is not None:
+      thread_delayed = {}
+      def _evaluate_item(item: base.Example):
+        thread_id = threading.current_thread().ident
+        if thread_id not in thread_delayed:
+          thread_delayed[thread_id] = True
+          time.sleep(random.randint(*self.concurrent_startup_delay))
+        return self.evaluate_item(evaluation, item)
+    else:
+      def _evaluate_item(item: base.Example):
+        return self.evaluate_item(evaluation, item)
+
+    for _, _, _ in lf.concurrent_map(
+        _evaluate_item,
+        items,
+        max_workers=evaluation.max_workers,
+        timeout=self.timeout,
+        silence_on_errors=None,
+    ):
+      pass
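
Like the other runners, `ParallelRunner` is selectable by its registered `NAME`. A minimal sketch, assuming the same `exp.run(...)` entry point exercised by the tests above (the eval_test_helper experiment is again a test fixture, used for illustration; parallel_test.py in the files list covers this runner but is not shown in this diff):

    import os
    import tempfile

    from langfun.core.eval.v2 import eval_test_helper
    from langfun.core.eval.v2.runners import parallel  # registers the 'parallel' runner

    exp = eval_test_helper.test_experiment()
    # Evaluations that do not share a resource (e.g., the same LLM) run in
    # parallel groups; within each evaluation, examples fan out across up
    # to `Evaluation.max_workers` threads.
    run = exp.run(
        os.path.join(tempfile.mkdtemp(), 'parallel_run'), runner='parallel'
    )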