langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512040805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +2 -0
  29. langfun/core/eval/v2/checkpointing.py +76 -7
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +92 -17
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +84 -15
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +90 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +72 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +12 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +64 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +34 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +58 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +64 -3
  104. langfun/core/modalities/mime_test.py +11 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512040805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,294 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Checkpoint aggregator for Langfun evaluations."""
15
+
16
+ import concurrent.futures
17
+ import dataclasses
18
+ import os
19
+ import threading
20
+ import time
21
+ from typing import Annotated, Iterator
22
+
23
+ from langfun.core.eval.v2 import evaluation as evaluation_lib
24
+ from langfun.core.eval.v2 import example as example_lib
25
+ from langfun.core.eval.v2 import reporting
26
+ from langfun.core.eval.v2.runners import base
27
+
28
+ import pyglove as pg
29
+
30
+
31
+ class CheckpointMonitor(base.RunnerBase):
32
+ """Runner for monitoring checkpoint files generated by other runners.
33
+
34
+ Currently checkpoint monitor only supports aggregating per-example
35
+ checkpoint files.
36
+ """
37
+
38
+ NAME = 'checkpoint_monitor'
39
+
40
+ plugins = [
41
+ reporting.HtmlReporter(),
42
+ ]
43
+
44
+ checkpoint_pattern: Annotated[
45
+ str, 'The glob pattern of the checkpoint files to monitor.'
46
+ ] = 'checkpoint_*.bagz'
47
+
48
+ monitor_inprogress_files: Annotated[
49
+ bool,
50
+ 'If True, monitor in-progress files to aggregate.'
51
+ ] = False
52
+
53
+ poll_interval: Annotated[
54
+ int,
55
+ 'The interval in seconds to poll for new checkpoint files.'
56
+ ] = 5
57
+
58
+ max_aggregation_threads: Annotated[
59
+ int,
60
+ 'The maximum number of threads to aggregate checkpoints.'
61
+ ] = 128
62
+
63
+ @dataclasses.dataclass
64
+ class _AggregationEntry:
65
+ evaluation: evaluation_lib.Evaluation
66
+ output_dir: str
67
+ inprogress_file_pattern: str | None
68
+ ckpt_file_pattern: str
69
+ example_ids_inprogress: set[int]
70
+ example_ids_to_be_aggregated: set[int]
71
+ example_ids_being_aggregated: set[int]
72
+ completion_lock: threading.Lock
73
+ is_completed: bool = False
74
+
75
+ def _on_bound(self):
76
+ super()._on_bound()
77
+ self._monitor_thread = None
78
+ self._aggregation_entries = []
79
+ self._aggregator_pool = None
80
+ self._error = None
81
+
82
+ def start(self):
83
+ # Reset the experiment state before getting started.
84
+ self.current_run.experiment.reset()
85
+
86
+ # Signal the start of the run.
87
+ self.on_run_start()
88
+
89
+ # Start the non-leaf nodes.
90
+ for node in self.current_run.experiment.nonleaf_nodes:
91
+ self.on_experiment_start(node)
92
+
93
+ for evaluation in self.current_run.experiment.leaf_nodes:
94
+ # This is not precise, but we at least notify example start.
95
+ if not self.current_run.filter or self.current_run.filter(evaluation):
96
+ self.on_experiment_start(evaluation)
97
+
98
+ # Signal the start of the examples if we are not monitoring in-progress
99
+ # files.
100
+ if not self.monitor_inprogress_files:
101
+ for example_id in self.current_run.examples_to_evaluate(evaluation):
102
+ self._mark_example_started(evaluation, example_id)
103
+
104
+ # Create the aggregation entries for polling.
105
+ output_dir = self.current_run.output_dir(evaluation)
106
+ self._aggregation_entries.append(
107
+ self._AggregationEntry(
108
+ evaluation=evaluation,
109
+ output_dir=output_dir,
110
+ ckpt_file_pattern=os.path.join(
111
+ output_dir, self.checkpoint_pattern
112
+ ),
113
+ inprogress_file_pattern=os.path.join(
114
+ output_dir, '*.inprogress'
115
+ ) if self.monitor_inprogress_files else None,
116
+ example_ids_to_be_aggregated=(
117
+ self.current_run.examples_to_evaluate(evaluation)
118
+ ),
119
+ example_ids_inprogress=set(),
120
+ example_ids_being_aggregated=set(),
121
+ completion_lock=threading.Lock(),
122
+ is_completed=False,
123
+ )
124
+ )
125
+ else:
126
+ self.on_experiment_skipped(evaluation)
127
+
128
+ self._aggregator_pool = concurrent.futures.ThreadPoolExecutor(
129
+ max_workers=self.max_aggregation_threads
130
+ )
131
+ self._monitor_thread = threading.Thread(target=self._monitor_loop)
132
+ self._monitor_thread.start()
133
+
134
+ def join(self):
135
+ if self._monitor_thread:
136
+ self._monitor_thread.join()
137
+ if self._error is not None:
138
+ raise self._error
139
+
140
+ def run(self):
141
+ self.start()
142
+ self.join()
143
+
144
+ def _monitor_loop(self):
145
+ while not self._error and any(
146
+ not e.is_completed for e in self._aggregation_entries
147
+ ):
148
+ for entry in self._aggregation_entries:
149
+ if not entry.example_ids_to_be_aggregated:
150
+ continue
151
+
152
+ # Signal example processing.
153
+ if self.monitor_inprogress_files:
154
+ inprogress_files = pg.io.glob(entry.inprogress_file_pattern)
155
+ for inprogress_file in inprogress_files:
156
+ example_id = int(
157
+ os.path.basename(inprogress_file).split('.')[0]
158
+ )
159
+ if example_id not in entry.example_ids_inprogress:
160
+ self._mark_example_started(entry.evaluation, example_id)
161
+ entry.example_ids_inprogress.add(example_id)
162
+
163
+ for filepath in pg.io.glob(entry.ckpt_file_pattern):
164
+ example_id = int(
165
+ os.path.basename(filepath).split('.')[0].split('_')[-1]
166
+ )
167
+ if example_id in entry.example_ids_to_be_aggregated:
168
+ # Remove example ID from the set to avoid duplicate processing.
169
+ entry.example_ids_to_be_aggregated.remove(example_id)
170
+ entry.example_ids_being_aggregated.add(example_id)
171
+
172
+ # It could be that the example has been processed before, but the
173
+ # inprogress file was removed. In this case, we should signal the
174
+ # example has started before completing it.
175
+ if example_id not in entry.example_ids_inprogress:
176
+ self._mark_example_started(entry.evaluation, example_id)
177
+ entry.example_ids_inprogress.add(example_id)
178
+
179
+ self._aggregator_pool.submit(
180
+ self._aggregate, entry, filepath, example_id
181
+ )
182
+ pg.logging.info(
183
+ '[%s] Aggregating example %d from %s...',
184
+ entry.evaluation.id,
185
+ example_id,
186
+ filepath,
187
+ )
188
+ time.sleep(self.poll_interval)
189
+
190
+ if self._error is None:
191
+ self.on_run_complete()
192
+ else:
193
+ self.on_run_abort(self._error)
194
+
195
+ def _aggregate(
196
+ self,
197
+ entry: _AggregationEntry,
198
+ ckpt_filepath: str,
199
+ example_id: int
200
+ ):
201
+ """Aggregate an example from a checkpoint file."""
202
+ try:
203
+ loaded_examples = entry.evaluation.state.load(
204
+ ckpt_filepath,
205
+ example_input_by_id=entry.evaluation.example_input_by_id,
206
+ # Example metadata may be expensive to load, and is not used by
207
+ # metric aggregation. Thus we do not load example metadata.
208
+ load_example_metadata=False
209
+ )
210
+ assert len(loaded_examples) >= 1, loaded_examples
211
+ # Occasionally the per-example checkpoint file may contain the same
212
+ # example processed multiple times. We only need to aggregate the last
213
+ # example.
214
+ example = loaded_examples[-1]
215
+ except BaseException as e: # pylint: disable=broad-except
216
+ error_info = pg.ErrorInfo.from_exception(e)
217
+ pg.logging.error(
218
+ '[%s] Failed to aggregate example %d: %s',
219
+ entry.evaluation.id,
220
+ example_id,
221
+ error_info
222
+ )
223
+ example = example_lib.Example(
224
+ id=example_id,
225
+ input=entry.evaluation.example_input_by_id(example_id),
226
+ error=error_info,
227
+ )
228
+
229
+ # This will skip processing but still allow metrics to be collected.
230
+ # `process` will never be called for evaluation, thus we do not
231
+ # need to setup/teardown evaluation.
232
+ example = entry.evaluation.evaluate(
233
+ example, reevaluate_upon_previous_errors=False
234
+ )
235
+ example.newly_processed = True
236
+ pg.logging.info(
237
+ '[%s] Successfully aggregated example %d from %s.',
238
+ entry.evaluation.id,
239
+ example_id,
240
+ ckpt_filepath,
241
+ )
242
+
243
+ try:
244
+ self.on_example_complete(entry.evaluation, example)
245
+ except BaseException as e: # pylint: disable=broad-except
246
+ # Plugin failures should be raised to the user.
247
+ self._error = e
248
+
249
+ entry.example_ids_being_aggregated.remove(example_id)
250
+
251
+ # Remove the in-progress file to indicate that the example has been
252
+ # processed.
253
+ try:
254
+ pg.io.rm(os.path.join(entry.output_dir, f'{example_id}.inprogress'))
255
+ except FileNotFoundError:
256
+ pass
257
+
258
+ if (not self._error
259
+ and not entry.example_ids_to_be_aggregated
260
+ and not entry.example_ids_being_aggregated):
261
+ with entry.completion_lock:
262
+ if not entry.is_completed:
263
+ entry.is_completed = True
264
+ try:
265
+ self.on_experiment_complete(entry.evaluation)
266
+ except BaseException as e: # pylint: disable=broad-except
267
+ # Plugin failures should be raised to the user.
268
+ self._error = e
269
+
270
+ def _mark_example_started(
271
+ self,
272
+ evaluation: evaluation_lib.Evaluation,
273
+ example_id: int
274
+ ) -> None:
275
+ """Mark an example as started."""
276
+ example = example_lib.Example(
277
+ id=example_id, input=evaluation.example_input_by_id(example_id),
278
+ )
279
+ example.start_time = time.time()
280
+ self.on_example_start(evaluation, example)
281
+
282
+ # We update evaluation state with the inprogress status so the evaluation
283
+ # HTML could show remotely in-progress examples.
284
+ evaluation.state.update(example, in_progress=True)
285
+
286
+ def _run(self, evaluations: list[evaluation_lib.Evaluation]):
287
+ raise NotImplementedError('Not needed in checkpoint monitor.')
288
+
289
+ def _evaluate_items(
290
+ self,
291
+ evaluation: evaluation_lib.Evaluation,
292
+ items: Iterator[example_lib.Example]
293
+ ) -> None:
294
+ raise NotImplementedError('Not needed in checkpoint monitor.')
@@ -0,0 +1,162 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import tempfile
16
+ import unittest
17
+
18
+ from langfun.core.eval.v2 import checkpointing
19
+ from langfun.core.eval.v2 import eval_test_helper
20
+ from langfun.core.eval.v2 import example as example_lib
21
+ from langfun.core.eval.v2 import experiment as experiment_lib
22
+ from langfun.core.eval.v2.runners import ckpt_monitor
23
+ from langfun.core.eval.v2.runners import sequential # pylint: disable=unused-import
24
+ import pyglove as pg
25
+
26
+
27
+ class CheckpointMonitorTest(unittest.TestCase):
28
+
29
+ def setUp(self):
30
+ super().setUp()
31
+ self.test_dir = tempfile.mkdtemp()
32
+
33
+ def test_aggregate(self):
34
+ exp = eval_test_helper.test_experiment()
35
+ root_dir = os.path.join(self.test_dir, 'test_aggregate')
36
+ run = exp.run(
37
+ root_dir,
38
+ runner='sequential',
39
+ progress_tracker=None,
40
+ plugins=[
41
+ checkpointing.PerExampleCheckpointer(
42
+ checkpoint_filename='checkpoint.jsonl'
43
+ )
44
+ ],
45
+ use_cache='no',
46
+ )
47
+ # Try to corrupt one of the checkpoint files.
48
+ pg.io.writefile(
49
+ run.output_path_for(exp.leaf_nodes[0], 'checkpoint_1.jsonl'),
50
+ 'bad ckpt'
51
+ )
52
+ plugin = eval_test_helper.TestPlugin()
53
+ monitor = ckpt_monitor.CheckpointMonitor(
54
+ run,
55
+ plugins=[plugin],
56
+ checkpoint_pattern='checkpoint_*.jsonl',
57
+ monitor_inprogress_files=True,
58
+ )
59
+ monitor.run()
60
+
61
+ # Assert that the in-progress files are created and not removed.
62
+ for entry in monitor._aggregation_entries:
63
+ self.assertEqual(len(entry.example_ids_inprogress), 10)
64
+
65
+ # 6 leaf nodes + 1 suite + 1 hyper.
66
+ self.assertEqual(len(plugin.started_experiments), 6 + 2)
67
+ self.assertEqual(len(plugin.completed_experiments), 6 + 2)
68
+ self.assertEqual(len(plugin.started_example_ids), 10 * 6)
69
+ self.assertEqual(len(plugin.completed_example_ids), 10 * 6)
70
+ for e in exp.leaf_nodes:
71
+ self.assertEqual(e.progress.num_completed, 10)
72
+
73
+ def test_aggregate_with_filter(self):
74
+ exp = eval_test_helper.test_experiment()
75
+ root_dir = os.path.join(self.test_dir, 'test_aggregate_with_filter')
76
+
77
+ node_to_skip = exp.leaf_nodes[2]
78
+ # Run experiment to generate checkpoint files for all examples.
79
+ run = exp.run(
80
+ root_dir,
81
+ runner='sequential',
82
+ filter=lambda e: e.id != node_to_skip.id,
83
+ progress_tracker=None,
84
+ plugins=[
85
+ checkpointing.PerExampleCheckpointer(
86
+ checkpoint_filename='checkpoint.jsonl'
87
+ )
88
+ ],
89
+ use_cache='no',
90
+ )
91
+ plugin = eval_test_helper.TestPlugin()
92
+ monitor = ckpt_monitor.CheckpointMonitor(
93
+ run,
94
+ plugins=[plugin],
95
+ checkpoint_pattern='checkpoint_*.jsonl',
96
+ )
97
+ monitor.run()
98
+
99
+ # Assert that on_experiment_skipped was called for the filtered node.
100
+ self.assertEqual(len(plugin.skipped_experiments), 1)
101
+ self.assertEqual(plugin.skipped_experiments[0].id, node_to_skip.id)
102
+
103
+ # Assert that the skipped node was not started.
104
+ started_ids = [e.id for e in plugin.started_experiments]
105
+ self.assertNotIn(node_to_skip.id, started_ids)
106
+
107
+ def test_plugin_raise(self):
108
+
109
+ class TestPlugin(eval_test_helper.TestPlugin):
110
+ simulate_raise_on_example_complete: bool = False
111
+ simulate_raise_on_experiment_complete: bool = False
112
+
113
+ def on_example_complete(
114
+ self,
115
+ runner: experiment_lib.Runner,
116
+ experiment: experiment_lib.Experiment,
117
+ example: example_lib.Example
118
+ ):
119
+ if self.simulate_raise_on_example_complete:
120
+ raise ValueError('example complete error')
121
+
122
+ def on_experiment_complete(
123
+ self,
124
+ runner: experiment_lib.Runner,
125
+ experiment: experiment_lib.Experiment
126
+ ):
127
+ if self.simulate_raise_on_experiment_complete:
128
+ raise ValueError('experiment complete error')
129
+
130
+ exp = eval_test_helper.test_evaluation()
131
+ root_dir = os.path.join(self.test_dir, 'test_plugin_raise')
132
+
133
+ # Run experiment to generate checkpoint files for all examples.
134
+ run = exp.run(
135
+ root_dir,
136
+ runner='sequential',
137
+ progress_tracker=None,
138
+ plugins=[
139
+ checkpointing.PerExampleCheckpointer(
140
+ checkpoint_filename='checkpoint.jsonl'
141
+ )
142
+ ],
143
+ use_cache='no',
144
+ )
145
+
146
+ with self.assertRaisesRegex(ValueError, 'example complete error'):
147
+ ckpt_monitor.CheckpointMonitor(
148
+ run,
149
+ plugins=[TestPlugin(simulate_raise_on_example_complete=True)],
150
+ checkpoint_pattern='checkpoint_*.jsonl',
151
+ ).run()
152
+
153
+ with self.assertRaisesRegex(ValueError, 'experiment complete error'):
154
+ ckpt_monitor.CheckpointMonitor(
155
+ run,
156
+ plugins=[TestPlugin(simulate_raise_on_experiment_complete=True)],
157
+ checkpoint_pattern='checkpoint_*.jsonl',
158
+ ).run()
159
+
160
+
161
+ if __name__ == '__main__':
162
+ unittest.main()
@@ -0,0 +1,40 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Debug runner."""
15
+
16
+ from langfun.core.eval.v2.runners import sequential
17
+
18
+
19
+ class DebugRunner(sequential.SequentialRunner):
20
+ """A runner for debugging evaluations.
21
+
22
+ The debug runner is a sequential runner that only runs the first example
23
+ of each evaluation, with `raise_if_has_error` enabled. This is useful for
24
+ quickly identifying issues in evaluation logic during development.
25
+ Checkpointers are disabled for this runner.
26
+ """
27
+
28
+ NAME = 'debug'
29
+
30
+ # Do not use the checkpointer for debug runner.
31
+ plugins = []
32
+
33
+ def _on_bound(self):
34
+ super()._on_bound()
35
+ if self.current_run.example_ids is None:
36
+ self.current_run.rebind(example_ids=[1], skip_notification=True)
37
+ self.current_run.rebind(raise_if_has_error=True, skip_notification=True)
38
+
39
+ def _save_run_manifest(self) -> None:
40
+ """Do nothing to avoid overriding existing runs."""
@@ -0,0 +1,76 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for debug runner."""
15
+ import os
16
+ import tempfile
17
+ from typing import Any
18
+ import unittest
19
+
20
+ from langfun.core.eval.v2 import eval_test_helper
21
+ from langfun.core.eval.v2.runners import debug # pylint: disable=unused-import
22
+
23
+ import pyglove as pg
24
+
25
+
26
+ class DebugRunnerTest(unittest.TestCase):
27
+
28
+ def assert_same_list(self, actual: list[Any], expected: list[Any]):
29
+ self.assertEqual(len(actual), len(expected))
30
+ for i, (x, y) in enumerate(zip(actual, expected)):
31
+ if x is not y:
32
+ print(i, pg.diff(x, y))
33
+ self.assertIs(x, y)
34
+
35
+ def test_debug_runner(self):
36
+ plugin = eval_test_helper.TestPlugin()
37
+ exp = eval_test_helper.test_experiment()
38
+ root_dir = os.path.join(tempfile.mkdtemp(), 'test_debug_runner')
39
+ run = exp.run(root_dir, runner='debug', plugins=[plugin])
40
+
41
+ self.assertIsNotNone(plugin.start_time)
42
+ self.assertIsNotNone(plugin.complete_time)
43
+ self.assertGreater(plugin.complete_time, plugin.start_time)
44
+
45
+ self.assertEqual(
46
+ len(plugin.started_experiments), len(exp.nodes)
47
+ )
48
+ self.assertEqual(
49
+ len(plugin.completed_experiments), len(exp.nodes)
50
+ )
51
+ self.assertEqual(
52
+ len(plugin.started_example_ids), 6 * 1
53
+ )
54
+ self.assertEqual(
55
+ len(plugin.completed_example_ids), 6 * 1
56
+ )
57
+ self.assert_same_list(plugin.skipped_experiments, [])
58
+ self.assertFalse(
59
+ pg.io.path_exists(os.path.join(run.output_root, 'run.json'))
60
+ )
61
+
62
+ for node in exp.nodes:
63
+ self.assertTrue(node.progress.is_started)
64
+ self.assertTrue(node.progress.is_completed)
65
+ if node.is_leaf:
66
+ self.assertEqual(node.progress.num_skipped, 0)
67
+ self.assertEqual(node.progress.num_completed, 1)
68
+ self.assertEqual(node.progress.num_failed, 0)
69
+ else:
70
+ self.assertEqual(node.progress.num_skipped, 0)
71
+ self.assertEqual(node.progress.num_failed, 0)
72
+ self.assertEqual(node.progress.num_processed, node.progress.num_total)
73
+
74
+
75
+ if __name__ == '__main__':
76
+ unittest.main()