langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512150805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +3 -0
  29. langfun/core/eval/v2/checkpointing.py +148 -46
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +102 -19
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +95 -20
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +88 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +73 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +350 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +213 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +14 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +90 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +52 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +78 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +78 -4
  104. langfun/core/modalities/mime_test.py +59 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512150805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/top_level.txt +0 -0
langfun/core/eval/v2/runners/ckpt_monitor.py
@@ -0,0 +1,350 @@
+ # Copyright 2025 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Checkpoint aggregator for Langfun evaluations."""
+
+ import concurrent.futures
+ import dataclasses
+ import os
+ import threading
+ import time
+ from typing import Annotated, Iterator
+
+ from langfun.core.eval.v2 import evaluation as evaluation_lib
+ from langfun.core.eval.v2 import example as example_lib
+ from langfun.core.eval.v2 import reporting
+ from langfun.core.eval.v2.runners import base
+
+ import pyglove as pg
+
+
+ class CheckpointMonitor(base.RunnerBase):
+   """Runner for monitoring checkpoint files generated by other runners.
+
+   Currently the checkpoint monitor only supports aggregating per-example
+   checkpoint files.
+   """
+
+   NAME = 'checkpoint_monitor'
+
+   plugins = [
+       reporting.HtmlReporter(),
+   ]
+
+   checkpoint_pattern: Annotated[
+       str, 'The glob pattern of the checkpoint files to monitor.'
+   ] = 'checkpoint_*.bagz'
+
+   monitor_inprogress_files: Annotated[
+       bool,
+       'If True, monitor in-progress files to aggregate.'
+   ] = False
+
+   poll_interval: Annotated[
+       int,
+       'The interval in seconds to poll for new checkpoint files.'
+   ] = 5
+
+   max_aggregation_threads: Annotated[
+       int,
+       'The maximum number of threads to aggregate checkpoints.'
+   ] = 128
+
+   bypass_old_ckpt_files_with_non_oop_errors: Annotated[
+       bool,
+       'If True, ignore old checkpoint files with non-oop errors.'
+   ] = True
+
+   ckpt_start_time: Annotated[
+       float | None,
+       (
+           'The timestamp to treat checkpoint files modified before this '
+           'time as old.'
+       )
+   ] = None
+
+   @dataclasses.dataclass
+   class _AggregationEntry:
+     evaluation: evaluation_lib.Evaluation
+     output_dir: str
+     inprogress_file_pattern: str | None
+     ckpt_file_pattern: str
+     example_ids_inprogress: set[int]
+     example_ids_to_be_aggregated: set[int]
+     example_ids_being_aggregated: set[int]
+     completion_lock: threading.Lock
+     is_completed: bool = False
+
+   def _on_bound(self):
+     super()._on_bound()
+     self._monitor_thread = None
+     self._aggregation_entries = []
+     self._aggregator_pool = None
+     self._error = None
+     if self.ckpt_start_time is None:
+       self.rebind(ckpt_start_time=time.time(), skip_notification=True)
+     self._ckpt_bypass_timestamp: dict[str, int] = {}
+
+   def start(self):
+     # Reset the experiment state before getting started.
+     self.current_run.experiment.reset()
+
+     # Signal the start of the run.
+     self.on_run_start()
+
+     # Start the non-leaf nodes.
+     for node in self.current_run.experiment.nonleaf_nodes:
+       self.on_experiment_start(node)
+
+     for evaluation in self.current_run.experiment.leaf_nodes:
+       # This is not precise, but we at least notify example start.
+       if not self.current_run.filter or self.current_run.filter(evaluation):
+         self.on_experiment_start(evaluation)
+
+         # Signal the start of the examples if we are not monitoring
+         # in-progress files.
+         if not self.monitor_inprogress_files:
+           for example_id in self.current_run.examples_to_evaluate(evaluation):
+             self._mark_example_started(evaluation, example_id)
+
+         # Create the aggregation entries for polling.
+         output_dir = self.current_run.output_dir(evaluation)
+         self._aggregation_entries.append(
+             self._AggregationEntry(
+                 evaluation=evaluation,
+                 output_dir=output_dir,
+                 ckpt_file_pattern=os.path.join(
+                     output_dir, self.checkpoint_pattern
+                 ),
+                 inprogress_file_pattern=os.path.join(
+                     output_dir, '*.inprogress'
+                 ) if self.monitor_inprogress_files else None,
+                 example_ids_to_be_aggregated=(
+                     self.current_run.examples_to_evaluate(evaluation)
+                 ),
+                 example_ids_inprogress=set(),
+                 example_ids_being_aggregated=set(),
+                 completion_lock=threading.Lock(),
+                 is_completed=False,
+             )
+         )
+       else:
+         self.on_experiment_skipped(evaluation)
+
+     self._aggregator_pool = concurrent.futures.ThreadPoolExecutor(
+         max_workers=self.max_aggregation_threads
+     )
+     self._monitor_thread = threading.Thread(target=self._monitor_loop)
+     self._monitor_thread.start()
+
+   def join(self):
+     if self._monitor_thread:
+       self._monitor_thread.join()
+     if self._error is not None:
+       raise self._error
+
+   def run(self):
+     self.start()
+     self.join()
+
+   def _monitor_loop(self):
+     while not self._error and any(
+         not e.is_completed for e in self._aggregation_entries
+     ):
+       for entry in self._aggregation_entries:
+         if not entry.example_ids_to_be_aggregated:
+           continue
+
+         # Signal example processing.
+         if self.monitor_inprogress_files:
+           inprogress_files = pg.io.glob(entry.inprogress_file_pattern)
+           for inprogress_file in inprogress_files:
+             example_id = int(
+                 os.path.basename(inprogress_file).split('.')[0]
+             )
+             if example_id not in entry.example_ids_inprogress:
+               self._mark_example_started(entry.evaluation, example_id)
+               entry.example_ids_inprogress.add(example_id)
+
+         for filepath in pg.io.glob(entry.ckpt_file_pattern):
+           example_id = int(
+               os.path.basename(filepath).split('.')[0].split('_')[-1]
+           )
+           if example_id in entry.example_ids_to_be_aggregated:
+             last_modified_time = pg.io.getmtime(filepath)
+             bypass_timestamp = self._ckpt_bypass_timestamp.get(filepath)
+             if (
+                 bypass_timestamp is not None
+                 and last_modified_time <= bypass_timestamp
+             ):
+               continue
+
+             # Remove example ID from the set to avoid duplicate processing.
+             entry.example_ids_to_be_aggregated.remove(example_id)
+             entry.example_ids_being_aggregated.add(example_id)
+
+             # It could be that the example has been processed before, but
+             # the inprogress file was removed. In this case, we should
+             # signal the example has started before completing it.
+             if example_id not in entry.example_ids_inprogress:
+               self._mark_example_started(entry.evaluation, example_id)
+               entry.example_ids_inprogress.add(example_id)
+
+             self._aggregator_pool.submit(
+                 self._aggregate, entry, filepath, example_id, last_modified_time
+             )
+             pg.logging.info(
+                 '[%s] Aggregating example %d from %s...',
+                 entry.evaluation.id,
+                 example_id,
+                 filepath,
+             )
+       time.sleep(self.poll_interval)
+
+     if self._error is None:
+       self.on_run_complete()
+     else:
+       self.on_run_abort(self._error)
+
+   def _aggregate(
+       self,
+       entry: _AggregationEntry,
+       ckpt_filepath: str,
+       example_id: int,
+       last_modified_time: float,
+   ):
+     """Aggregate an example from a checkpoint file."""
+     try:
+       loaded_examples = entry.evaluation.state.load(
+           ckpt_filepath,
+           example_input_by_id=entry.evaluation.example_input_by_id,
+           # Example metadata may be expensive to load, and is not used by
+           # metric aggregation. Thus we do not load example metadata.
+           load_example_metadata=False
+       )
+       assert len(loaded_examples) >= 1, loaded_examples
+       # Occasionally the per-example checkpoint file may contain the same
+       # example processed multiple times. We only need to aggregate the last
+       # example.
+       example = loaded_examples[-1]
+       if (
+           self.bypass_old_ckpt_files_with_non_oop_errors
+           and last_modified_time < self.ckpt_start_time
+           and example.error is not None
+           and not example.error.tag.startswith('MappingError')
+       ):
+         entry.example_ids_being_aggregated.remove(example_id)
+         entry.example_ids_to_be_aggregated.add(example_id)
+         self._ckpt_bypass_timestamp[ckpt_filepath] = last_modified_time
+         pg.logging.info(
+             '[%s] Bypassing old checkpoint file with non-oop errors (%s) '
+             'for example %d, last_modified_time: %s, ckpt_start_time: %s',
+             entry.evaluation.id,
+             ckpt_filepath,
+             example_id,
+             last_modified_time,
+             self.ckpt_start_time,
+         )
+         return
+     except BaseException as e:  # pylint: disable=broad-except
+       error_info = pg.ErrorInfo.from_exception(e)
+       pg.logging.error(
+           '[%s] Failed to aggregate example %d: %s',
+           entry.evaluation.id,
+           example_id,
+           error_info
+       )
+       example = example_lib.Example(
+           id=example_id,
+           input=entry.evaluation.example_input_by_id(example_id),
+           error=error_info,
+       )
+
+     # This will skip processing but still allow metrics to be collected.
+     # `process` will never be called for the evaluation, thus we do not
+     # need to set up or tear down the evaluation.
+     try:
+       example = entry.evaluation.evaluate(
+           example, reevaluate_upon_previous_errors=False
+       )
+     except BaseException as e:  # pylint: disable=broad-except
+       pg.logging.error(
+           '[%s] Unexpected error found during evaluating example %d from %s.',
+           entry.evaluation.id,
+           example_id,
+           ckpt_filepath,
+       )
+       self._error = e
+       entry.example_ids_being_aggregated.remove(example_id)
+       return
+
+     example.newly_processed = True
+     pg.logging.info(
+         '[%s] Successfully aggregated example %d from %s.',
+         entry.evaluation.id,
+         example_id,
+         ckpt_filepath,
+     )
+
+     try:
+       self.on_example_complete(entry.evaluation, example)
+     except BaseException as e:  # pylint: disable=broad-except
+       # Plugin failures should be raised to the user.
+       self._error = e
+
+     entry.example_ids_being_aggregated.remove(example_id)
+
+     # Remove the in-progress file to indicate that the example has been
+     # processed.
+     try:
+       pg.io.rm(os.path.join(entry.output_dir, f'{example_id}.inprogress'))
+     except FileNotFoundError:
+       pass
+
+     if (not self._error
+         and not entry.example_ids_to_be_aggregated
+         and not entry.example_ids_being_aggregated):
+       with entry.completion_lock:
+         if not entry.is_completed:
+           entry.is_completed = True
+           try:
+             self.on_experiment_complete(entry.evaluation)
+           except BaseException as e:  # pylint: disable=broad-except
+             # Plugin failures should be raised to the user.
+             self._error = e
+
+   def _mark_example_started(
+       self,
+       evaluation: evaluation_lib.Evaluation,
+       example_id: int
+   ) -> None:
+     """Mark an example as started."""
+     example = example_lib.Example(
+         id=example_id, input=evaluation.example_input_by_id(example_id),
+     )
+     example.start_time = time.time()
+     self.on_example_start(evaluation, example)
+
+     # We update evaluation state with the in-progress status so the
+     # evaluation HTML can show remotely in-progress examples.
+     evaluation.state.update(example, in_progress=True)
+
+   def _run(self, evaluations: list[evaluation_lib.Evaluation]):
+     raise NotImplementedError('Not needed in checkpoint monitor.')
+
+   def _evaluate_items(
+       self,
+       evaluation: evaluation_lib.Evaluation,
+       items: Iterator[example_lib.Example]
+   ) -> None:
+     raise NotImplementedError('Not needed in checkpoint monitor.')
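
The test file that follows exercises this runner end to end. As quick orientation, here is a minimal usage sketch distilled from those tests; `exp` and `root_dir` stand in for an experiment and an output directory, so treat the snippet as illustrative rather than canonical API documentation:

    from langfun.core.eval.v2 import checkpointing
    from langfun.core.eval.v2.runners import ckpt_monitor

    # Some other runner writes per-example checkpoint files...
    run = exp.run(
        root_dir,
        runner='sequential',
        plugins=[
            checkpointing.PerExampleCheckpointer(
                checkpoint_filename='checkpoint.jsonl'
            )
        ],
    )

    # ...while a CheckpointMonitor polls the run's output directories and
    # aggregates metrics from checkpoint files as they appear. `run()`
    # blocks until every expected example has been aggregated.
    ckpt_monitor.CheckpointMonitor(
        run,
        checkpoint_pattern='checkpoint_*.jsonl',
        monitor_inprogress_files=True,
    ).run()

The monitor is meant to aggregate checkpoints produced by runners that may execute elsewhere; the tests below approximate this by running the producer to completion first and pointing the monitor at the same directory.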
langfun/core/eval/v2/runners/ckpt_monitor_test.py
@@ -0,0 +1,213 @@
+ # Copyright 2024 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import os
+ import tempfile
+ import time
+ import unittest
+
+ import langfun.core as lf
+ from langfun.core.eval.v2 import checkpointing
+ from langfun.core.eval.v2 import eval_test_helper
+ from langfun.core.eval.v2 import example as example_lib
+ from langfun.core.eval.v2 import experiment as experiment_lib
+ from langfun.core.eval.v2.runners import ckpt_monitor
+ from langfun.core.eval.v2.runners import sequential  # pylint: disable=unused-import
+ import pyglove as pg
+
+
+ class CheckpointMonitorTest(unittest.TestCase):
+
+   def setUp(self):
+     super().setUp()
+     self.test_dir = tempfile.mkdtemp()
+
+   def test_aggregate(self):
+     exp = eval_test_helper.test_experiment()
+     root_dir = os.path.join(self.test_dir, 'test_aggregate')
+     ckpt_start_time = time.time()
+     run = exp.run(
+         root_dir,
+         runner='sequential',
+         progress_tracker=None,
+         plugins=[
+             checkpointing.PerExampleCheckpointer(
+                 checkpoint_filename='checkpoint.jsonl'
+             )
+         ],
+         use_cache='no',
+     )
+     # Try to corrupt one of the checkpoint files.
+     pg.io.writefile(
+         run.output_path_for(exp.leaf_nodes[0], 'checkpoint_1.jsonl'),
+         'bad ckpt'
+     )
+     plugin = eval_test_helper.TestPlugin()
+     monitor = ckpt_monitor.CheckpointMonitor(
+         run,
+         plugins=[plugin],
+         checkpoint_pattern='checkpoint_*.jsonl',
+         monitor_inprogress_files=True,
+         ckpt_start_time=ckpt_start_time,
+     )
+     monitor.run()
+
+     # Assert that the in-progress files are created and not removed.
+     for entry in monitor._aggregation_entries:
+       self.assertEqual(len(entry.example_ids_inprogress), 10)
+
+     # 6 leaf nodes + 1 suite + 1 hyper.
+     self.assertEqual(len(plugin.started_experiments), 6 + 2)
+     self.assertEqual(len(plugin.completed_experiments), 6 + 2)
+     self.assertEqual(len(plugin.started_example_ids), 10 * 6)
+     self.assertEqual(len(plugin.completed_example_ids), 10 * 6)
+     for e in exp.leaf_nodes:
+       self.assertEqual(e.progress.num_completed, 10)
+
+   def test_ignore_old_ckpt_files_with_non_oop_errors(self):
+     exp = eval_test_helper.test_evaluation()
+     root_dir = os.path.join(self.test_dir, 'test_ignore_old_ckpt_files')
+     run = exp.run(
+         root_dir,
+         runner='sequential',
+         progress_tracker=None,
+         plugins=[
+             checkpointing.PerExampleCheckpointer(
+                 checkpoint_filename='checkpoint.jsonl'
+             )
+         ],
+         use_cache='no',
+     )
+     monitor = ckpt_monitor.CheckpointMonitor(
+         run,
+         plugins=[],
+         checkpoint_pattern='checkpoint_*.jsonl',
+         monitor_inprogress_files=True
+     )
+     monitor.start()
+     time.sleep(2)
+     # Example 6 is a non-oop error; we simulate a re-evaluation.
+     ex = example_lib.Example(
+         id=6, output=1, metric_metadata={'match': {'is_correct': True}},
+         start_time=time.time() - 2, end_time=time.time(),
+         usage_summary=lf.UsageSummary(),
+         execution_status={
+             'evaluate': pg.utils.TimeIt.Status(name='evaluate', elapse=1)
+         }
+     )
+     with pg.io.open_sequence(
+         run.output_path_for(exp, 'checkpoint_6.jsonl'),
+         mode='w'
+     ) as f:
+       f.add(pg.to_json_str(ex))
+     print(time.time(), pg.io.listdir(run.output_dir(exp)))
+     monitor.join()
+     self.assertEqual(exp.progress.num_processed, 10)
+     self.assertEqual(exp.progress.num_completed, 10)
+     self.assertEqual(exp.progress.num_failed, 0)
+
+   def test_aggregate_with_filter(self):
+     ckpt_start_time = time.time()
+     exp = eval_test_helper.test_experiment()
+     root_dir = os.path.join(self.test_dir, 'test_aggregate_with_filter')
+
+     node_to_skip = exp.leaf_nodes[2]
+     # Run experiment to generate checkpoint files for all examples.
+     run = exp.run(
+         root_dir,
+         runner='sequential',
+         filter=lambda e: e.id != node_to_skip.id,
+         progress_tracker=None,
+         plugins=[
+             checkpointing.PerExampleCheckpointer(
+                 checkpoint_filename='checkpoint.jsonl'
+             )
+         ],
+         use_cache='no',
+     )
+     plugin = eval_test_helper.TestPlugin()
+     monitor = ckpt_monitor.CheckpointMonitor(
+         run,
+         plugins=[plugin],
+         checkpoint_pattern='checkpoint_*.jsonl',
+         ckpt_start_time=ckpt_start_time,
+     )
+     monitor.run()
+
+     # Assert that on_experiment_skipped was called for the filtered node.
+     self.assertEqual(len(plugin.skipped_experiments), 1)
+     self.assertEqual(plugin.skipped_experiments[0].id, node_to_skip.id)
+
+     # Assert that the skipped node was not started.
+     started_ids = [e.id for e in plugin.started_experiments]
+     self.assertNotIn(node_to_skip.id, started_ids)
+
+   def test_plugin_raise(self):
+
+     class TestPlugin(eval_test_helper.TestPlugin):
+       simulate_raise_on_example_complete: bool = False
+       simulate_raise_on_experiment_complete: bool = False
+
+       def on_example_complete(
+           self,
+           runner: experiment_lib.Runner,
+           experiment: experiment_lib.Experiment,
+           example: example_lib.Example
+       ):
+         if self.simulate_raise_on_example_complete:
+           raise ValueError('example complete error')
+
+       def on_experiment_complete(
+           self,
+           runner: experiment_lib.Runner,
+           experiment: experiment_lib.Experiment
+       ):
+         if self.simulate_raise_on_experiment_complete:
+           raise ValueError('experiment complete error')
+
+     ckpt_start_time = time.time()
+     exp = eval_test_helper.test_evaluation()
+     root_dir = os.path.join(self.test_dir, 'test_plugin_raise')
+
+     # Run experiment to generate checkpoint files for all examples.
+     run = exp.run(
+         root_dir,
+         runner='sequential',
+         progress_tracker=None,
+         plugins=[
+             checkpointing.PerExampleCheckpointer(
+                 checkpoint_filename='checkpoint.jsonl'
+             )
+         ],
+         use_cache='no',
+     )
+
+     with self.assertRaisesRegex(ValueError, 'example complete error'):
+       ckpt_monitor.CheckpointMonitor(
+           run,
+           plugins=[TestPlugin(simulate_raise_on_example_complete=True)],
+           checkpoint_pattern='checkpoint_*.jsonl',
+           ckpt_start_time=ckpt_start_time,
+       ).run()
+
+     with self.assertRaisesRegex(ValueError, 'experiment complete error'):
+       ckpt_monitor.CheckpointMonitor(
+           run,
+           plugins=[TestPlugin(simulate_raise_on_experiment_complete=True)],
+           checkpoint_pattern='checkpoint_*.jsonl',
+           ckpt_start_time=ckpt_start_time,
+       ).run()
+
+
+ if __name__ == '__main__':
+   unittest.main()
langfun/core/eval/v2/runners/debug.py
@@ -0,0 +1,40 @@
+ # Copyright 2024 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Debug runner."""
+
+ from langfun.core.eval.v2.runners import sequential
+
+
+ class DebugRunner(sequential.SequentialRunner):
+   """A runner for debugging evaluations.
+
+   The debug runner is a sequential runner that only runs the first example
+   of each evaluation, with `raise_if_has_error` enabled. This is useful for
+   quickly identifying issues in evaluation logic during development.
+   Checkpointers are disabled for this runner.
+   """
+
+   NAME = 'debug'
+
+   # Do not use the checkpointer for the debug runner.
+   plugins = []
+
+   def _on_bound(self):
+     super()._on_bound()
+     if self.current_run.example_ids is None:
+       self.current_run.rebind(example_ids=[1], skip_notification=True)
+     self.current_run.rebind(raise_if_has_error=True, skip_notification=True)
+
+   def _save_run_manifest(self) -> None:
+     """Do nothing to avoid overriding existing runs."""
langfun/core/eval/v2/runners/debug_test.py
@@ -0,0 +1,76 @@
+ # Copyright 2024 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tests for debug runner."""
+ import os
+ import tempfile
+ from typing import Any
+ import unittest
+
+ from langfun.core.eval.v2 import eval_test_helper
+ from langfun.core.eval.v2.runners import debug  # pylint: disable=unused-import
+
+ import pyglove as pg
+
+
+ class DebugRunnerTest(unittest.TestCase):
+
+   def assert_same_list(self, actual: list[Any], expected: list[Any]):
+     self.assertEqual(len(actual), len(expected))
+     for i, (x, y) in enumerate(zip(actual, expected)):
+       if x is not y:
+         print(i, pg.diff(x, y))
+       self.assertIs(x, y)
+
+   def test_debug_runner(self):
+     plugin = eval_test_helper.TestPlugin()
+     exp = eval_test_helper.test_experiment()
+     root_dir = os.path.join(tempfile.mkdtemp(), 'test_debug_runner')
+     run = exp.run(root_dir, runner='debug', plugins=[plugin])
+
+     self.assertIsNotNone(plugin.start_time)
+     self.assertIsNotNone(plugin.complete_time)
+     self.assertGreater(plugin.complete_time, plugin.start_time)
+
+     self.assertEqual(
+         len(plugin.started_experiments), len(exp.nodes)
+     )
+     self.assertEqual(
+         len(plugin.completed_experiments), len(exp.nodes)
+     )
+     self.assertEqual(
+         len(plugin.started_example_ids), 6 * 1
+     )
+     self.assertEqual(
+         len(plugin.completed_example_ids), 6 * 1
+     )
+     self.assert_same_list(plugin.skipped_experiments, [])
+     self.assertFalse(
+         pg.io.path_exists(os.path.join(run.output_root, 'run.json'))
+     )
+
+     for node in exp.nodes:
+       self.assertTrue(node.progress.is_started)
+       self.assertTrue(node.progress.is_completed)
+       if node.is_leaf:
+         self.assertEqual(node.progress.num_skipped, 0)
+         self.assertEqual(node.progress.num_completed, 1)
+         self.assertEqual(node.progress.num_failed, 0)
+       else:
+         self.assertEqual(node.progress.num_skipped, 0)
+         self.assertEqual(node.progress.num_failed, 0)
+         self.assertEqual(node.progress.num_processed, node.progress.num_total)
+
+
+ if __name__ == '__main__':
+   unittest.main()