langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512150805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162) hide show
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +3 -0
  29. langfun/core/eval/v2/checkpointing.py +148 -46
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +102 -19
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +95 -20
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +88 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +73 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +350 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +213 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +14 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +90 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +52 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +78 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +78 -4
  104. langfun/core/modalities/mime_test.py +59 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512150805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,243 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Parallel runner."""
15
+
16
+ import collections
17
+ import math
18
+ import random
19
+ import threading
20
+ import time
21
+
22
+ from typing import Annotated, Iterator
23
+ import langfun.core as lf
24
+ from langfun.core.eval.v2 import checkpointing
25
+ from langfun.core.eval.v2 import experiment as experiment_lib
26
+ from langfun.core.eval.v2.runners import base
27
+ from langfun.core.eval.v2.runners import ckpt_monitor
28
+ import pyglove as pg
29
+
30
+
31
class ParallelRunner(base.RunnerBase):
  """Runner that processes evaluations and their examples concurrently.

  Evaluations are bucketed by a resource ID they depend on (e.g., a specific
  LLM). Buckets are run concurrently with each other, while the evaluations
  inside one bucket are run one after another. The examples of an evaluation
  are dispatched to a thread pool whose size is `Evaluation.max_workers`.
  """

  NAME = 'parallel'

  timeout: Annotated[
      int | None,
      'Timeout for each evaluation example.'
  ] = None

  concurrent_startup_delay: Annotated[
      tuple[int, int] | None,
      (
          'A range of seconds to delay the initial evaluation of each thread '
          'in the thread pool, helping to prevent a burst in LLM QPS at '
          'startup. If set to None, no delay will be applied.'
      )
  ] = None

  def _run(self, evaluations: list[base.Evaluation]) -> None:
    """Runs the evaluations concurrently, grouped by resource ID."""
    # Bucket the evaluations by resource key. An evaluation that reports no
    # resource gets a bucket of its own, keyed by its ID.
    buckets: dict[str, list[base.Evaluation]] = collections.defaultdict(list)
    for evaluation in evaluations:
      resource_ids = evaluation.resource_ids()
      # TODO(daiyip): support group that requires multiple resources.
      bucket_key = resource_ids.pop() if resource_ids else evaluation.id
      buckets[bucket_key].append(evaluation)

    def _run_group(evaluation_group: list[base.Evaluation]):
      # Evaluations within the same bucket are run back to back.
      for member in evaluation_group:
        self.run_evaluation(member)

    outcomes = lf.concurrent_map(
        _run_group,
        buckets.values(),
        max_workers=max(64, len(buckets)),
        timeout=self.timeout,
        silence_on_errors=None,
    )
    # The results are not needed; iterate only to drive the work.
    for _, _, _ in outcomes:
      pass

  def _evaluate_items(
      self, evaluation: base.Evaluation, items: Iterator[base.Example]
  ) -> None:
    """Evaluates the examples of `evaluation` concurrently."""
    if self.concurrent_startup_delay is None:
      def _evaluate_item(item: base.Example):
        return self.evaluate_item(evaluation, item)
    else:
      # Identifiers of the threads that have already taken their delay.
      delayed_threads = set()

      def _evaluate_item(item: base.Example):
        current_id = threading.current_thread().ident
        if current_id not in delayed_threads:
          # First item seen by this thread: sleep for a random number of
          # seconds drawn from the configured range before evaluating.
          delayed_threads.add(current_id)
          time.sleep(random.randint(*self.concurrent_startup_delay))
        return self.evaluate_item(evaluation, item)

    outcomes = lf.concurrent_map(
        _evaluate_item,
        items,
        max_workers=self._max_workers(evaluation),
        timeout=self.timeout,
        silence_on_errors=None,
    )
    # The results are not needed; iterate only to drive the work.
    for _, _, _ in outcomes:
      pass

  def _max_workers(self, evaluation: base.Evaluation) -> int | None:
    """Returns the thread pool size used for the examples of `evaluation`."""
    return evaluation.max_workers
109
+
110
+
111
class _SingleSliceRunner(ParallelRunner):
  """Parallel runner used for a single slice of a sliced evaluation.

  It behaves like `ParallelRunner`, except that progress tracking is turned
  off and the per-evaluation worker count is divided by `num_slices`.
  """

  NAME = '__single_slice_runner__'

  # Progress is not tracked by the single slice runner.
  progress_tracker = None

  num_slices: Annotated[
      int,
      'The number of slices to run the evaluations in.'
  ] = 1

  def _max_workers(self, evaluation: base.Evaluation) -> int | None:
    """Returns this slice's share of the evaluation's workers (at least 1).

    `None` from the parent implementation is passed through unchanged.
    """
    total_workers = super()._max_workers(evaluation)
    if total_workers is not None:
      # Round up so that the slices together cover the full worker count.
      return max(1, math.ceil(total_workers / self.num_slices))
    return None
129
+
130
+
131
class MultiSliceParallelRunner(experiment_lib.Runner):
  """A sliced parallel runner.

  The examples of each evaluation are partitioned into `num_slices` slices by
  `example_id % num_slices`. Each MultiSliceParallelRunner instance is
  responsible for evaluating a single slice (`slice_id`). The instance with
  `slice_id` 0 will also aggregate checkpoints from all slices.

  The sliced parallel runner allows running multiple instances across
  different machines/hosts in parallel. This can be utilized for scaling
  evaluation jobs to run on multiple machines, or for running evaluations in a
  fault-tolerant way by splitting each evaluation into multiple slices.
  """

  NAME = 'sliced-parallel'

  slice_id: Annotated[
      int,
      (
          'The slice ID of the runner. If 0, it will also run as the '
          'aggregator for collecting results from other slices. '
      )
  ] = 0

  num_slices: Annotated[
      int,
      'The number of slices to run the evaluations in parallel.'
  ] = 1

  timeout: Annotated[
      int | None,
      'Timeout for each evaluation example.'
  ] = None

  concurrent_startup_delay: Annotated[
      tuple[int, int] | None,
      (
          'A range of seconds to delay the initial evaluation of each thread '
          'in the thread pool, helping to prevent a burst in LLM QPS at '
          'startup. If set to None, no delay will be applied.'
      )
  ] = None

  ckpt_format: Annotated[
      str,
      'The file extension of the checkpoint files.'
  ] = 'bagz'

  max_aggregation_threads: Annotated[
      int,
      'The maximum number of threads to aggregate checkpoints.'
  ] = 32

  def _on_bound(self):
    """Validates the configuration and builds the monitor and slice runner.

    Raises:
      ValueError: If caching is enabled for the current run, if `num_slices`
        is less than 1, or if `slice_id` is outside `[0, num_slices)`.
    """
    super()._on_bound()
    if self.current_run.use_cache != 'no':
      raise ValueError(
          'Cache is not supported in MultiSliceParallelRunner. '
          f'Encountered: {self.current_run.use_cache}'
      )
    # Fail fast on an invalid slicing setup: `num_slices == 0` would otherwise
    # surface later as a ZeroDivisionError, and an out-of-range `slice_id`
    # would match no example IDs and silently evaluate nothing.
    if self.num_slices < 1:
      raise ValueError(
          f'`num_slices` must be at least 1. Encountered: {self.num_slices}'
      )
    if not 0 <= self.slice_id < self.num_slices:
      raise ValueError(
          f'`slice_id` must be in range [0, {self.num_slices}). '
          f'Encountered: {self.slice_id}'
      )

    # Split user plugins between the slice worker (per-example plugins) and
    # the checkpoint monitor (all other plugins). Checkpointing is always
    # done by the built-in per-example checkpointer.
    monitor_plugins = []
    worker_plugins = [
        checkpointing.PerExampleCheckpointer(
            checkpoint_filename=f'checkpoint.{self.ckpt_format}'
        ),
    ]
    for plugin in self.plugins:
      if isinstance(plugin, checkpointing.Checkpointer):
        pg.logging.warning(
            'Built-in checkpointing is enabled on MultiSliceParallelRunner. '
            f'Ignoring checkpointer: {plugin!r}.'
        )
      elif plugin.is_per_example():
        worker_plugins.append(pg.Ref(plugin))
      else:
        monitor_plugins.append(pg.Ref(plugin))

    # Only slice 0 runs the checkpoint monitor.
    if self.slice_id == 0:
      self._ckpt_monitor = ckpt_monitor.CheckpointMonitor(
          pg.Ref(self.current_run),
          plugins=monitor_plugins,
          monitor_inprogress_files=True,
          checkpoint_pattern=f'checkpoint_*.{self.ckpt_format}',
          max_aggregation_threads=self.max_aggregation_threads,
      )
    else:
      self._ckpt_monitor = None

    self._slice_runner = _SingleSliceRunner(
        current_run=self.current_run.clone(
            override=dict(
                # Clone the experiment to avoid updating the original one.
                experiment=self.current_run.experiment.clone(),
                example_ids=self._examples_to_evaluate,
            )
        ),
        plugins=worker_plugins,
        timeout=self.timeout,
        concurrent_startup_delay=self.concurrent_startup_delay,
        # Forward the slice count so that the slice runner scales down its
        # per-evaluation worker count; it would otherwise default to 1.
        num_slices=self.num_slices,
    )

  def _examples_to_evaluate(
      self,
      experiment: experiment_lib.Experiment
  ) -> list[int]:
    """Returns the example IDs of `experiment` that belong to this slice."""
    all_ids = self.current_run.examples_to_evaluate(experiment)
    return [x for x in all_ids if x % self.num_slices == self.slice_id]

  def run(self) -> None:
    """Runs this slice; slice 0 also starts and then joins the monitor."""
    if self._ckpt_monitor is not None:
      self._ckpt_monitor.start()
    self._slice_runner.run()
    if self._ckpt_monitor is not None:
      self._ckpt_monitor.join()
@@ -0,0 +1,182 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for parallel runner."""
15
+ import os
16
+ import tempfile
17
+ import threading
18
+ from typing import Any
19
+ import unittest
20
+
21
+ from langfun.core.eval.v2 import checkpointing
22
+ from langfun.core.eval.v2 import eval_test_helper
23
+ from langfun.core.eval.v2 import reporting
24
+ from langfun.core.eval.v2.runners import parallel # pylint: disable=unused-import
25
+
26
+ import pyglove as pg
27
+
28
+
29
class ParallelRunnerTest(unittest.TestCase):
  """Tests for the 'parallel' runner."""

  def assert_same_list(self, actual: list[Any], expected: list[Any]):
    # Both lists must have the same length and hold the very same objects
    # (compared by identity) at every position.
    self.assertEqual(len(actual), len(expected))
    for index, pair in enumerate(zip(actual, expected)):
      got, want = pair
      if got is not want:
        # Show where and how the two items differ before failing.
        print(index, pg.diff(got, want))
      self.assertIs(got, want)

  def test_parallel_runner(self):
    # Runs the test experiment and checks the events recorded by the plugin
    # as well as the final progress of every node.
    recorder = eval_test_helper.TestPlugin()
    experiment = eval_test_helper.test_experiment()
    output_dir = os.path.join(tempfile.mkdtemp(), 'test_parallel_runner')
    experiment.run(output_dir, runner='parallel', plugins=[recorder])

    self.assertIsNotNone(recorder.start_time)
    self.assertIsNotNone(recorder.complete_time)
    self.assertGreater(recorder.complete_time, recorder.start_time)

    self.assertEqual(len(recorder.started_experiments), len(experiment.nodes))
    self.assertEqual(
        len(recorder.completed_experiments), len(experiment.nodes)
    )
    self.assertEqual(len(recorder.started_example_ids), 6 * 10)
    self.assertEqual(len(recorder.completed_example_ids), 6 * 10)
    self.assert_same_list(recorder.skipped_experiments, [])

    for node in experiment.nodes:
      self.assertTrue(node.progress.is_started)
      self.assertTrue(node.progress.is_completed)
      if not node.is_leaf:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_failed, 0)
        self.assertEqual(node.progress.num_processed, node.progress.num_total)
      else:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_completed, 10)
        self.assertEqual(node.progress.num_failed, 1)

  def test_raise_if_has_error(self):
    # With `raise_if_has_error=True` the example error must be raised.
    output_dir = os.path.join(tempfile.mkdtemp(), 'test_raise_if_has_error')
    evaluation = eval_test_helper.TestEvaluation()
    with self.assertRaisesRegex(ValueError, 'x should not be 5'):
      evaluation.run(
          output_dir, runner='parallel', plugins=[], raise_if_has_error=True
      )

  def test_concurrent_startup_delay(self):
    # Smoke test: the run must finish when a startup delay is configured.
    recorder = eval_test_helper.TestPlugin()
    experiment = eval_test_helper.test_experiment()
    output_dir = os.path.join(
        tempfile.mkdtemp(), 'test_concurrent_startup_delay'
    )
    experiment.run(
        output_dir,
        runner='parallel',
        plugins=[recorder],
        concurrent_startup_delay=(0, 5),
    )
92
+
93
+
94
class MultiProcessParallelRunnerTest(unittest.TestCase):
  """Tests for the 'sliced-parallel' runner."""

  def assert_same_list(self, actual: list[Any], expected: list[Any]):
    # Both lists must have the same length and hold the very same objects
    # (compared by identity) at every position.
    self.assertEqual(len(actual), len(expected))
    for index, pair in enumerate(zip(actual, expected)):
      got, want = pair
      if got is not want:
        # Show where and how the two items differ before failing.
        print(index, pg.diff(got, want))
      self.assertIs(got, want)

  def test_basic(self):
    # Two slices of the same experiment are run from two threads; the events
    # recorded by the shared plugin and the node progress are then checked.
    recorder = eval_test_helper.TestPlugin()
    experiment = eval_test_helper.test_experiment()
    output_dir = os.path.join(tempfile.mkdtemp(), 'test_parallel_runner')
    runs = [None, None]

    def run_slice(slice_id: int):
      slice_plugins = [
          pg.Ref(recorder),
          # Ignored as PerExampleCheckpointer is used.
          checkpointing.BulkCheckpointer(),
          reporting.ExampleHtmlGenerator(),
      ]
      runs[slice_id] = experiment.run(
          output_dir,
          runner='sliced-parallel',
          plugins=slice_plugins,
          use_cache='no',
          slice_id=slice_id,
          num_slices=2,
          ckpt_format='jsonl',
      )

    # We simulate two slices running in parallel.
    workers = []
    for slice_index in range(2):
      workers.append(threading.Thread(target=run_slice, args=(slice_index,)))
    for worker in workers:
      worker.start()
    for worker in workers:
      worker.join()

    self.assertIsNotNone(recorder.start_time)
    self.assertIsNotNone(recorder.complete_time)
    self.assertGreater(recorder.complete_time, recorder.start_time)

    self.assertEqual(len(recorder.started_experiments), len(experiment.nodes))
    self.assertEqual(
        len(recorder.completed_experiments), len(experiment.nodes)
    )
    self.assertEqual(len(recorder.started_example_ids), 6 * 10)
    self.assertEqual(len(recorder.completed_example_ids), 6 * 10)
    self.assert_same_list(recorder.skipped_experiments, [])

    for node in experiment.nodes:
      self.assertTrue(node.progress.is_started)
      self.assertTrue(node.progress.is_completed)
      if not node.is_leaf:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_failed, 0)
        self.assertEqual(node.progress.num_processed, node.progress.num_total)
      else:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_completed, 10)
        self.assertEqual(node.progress.num_failed, 1)
        # An HTML file must exist for every example of the run from slice 0.
        for example_id in runs[0].examples_to_evaluate(node):
          html_path = runs[0].output_path_for(node, f'{example_id}.html')
          self.assertTrue(pg.io.path_exists(html_path))

  def test_parallel_mp_runner_does_not_support_cache(self):
    # Any cache mode other than 'no' must be rejected by the sliced runner.
    experiment = eval_test_helper.test_experiment()
    output_dir = os.path.join(
        tempfile.mkdtemp(), 'test_parallel_mp_runner_cache'
    )
    with self.assertRaisesRegex(ValueError, 'Cache is not supported'):
      experiment.run(
          output_dir,
          runner='sliced-parallel',
          use_cache='global',
          slice_id=0,
          num_slices=1,
      )
179
+
180
+
181
# Run the tests of this module when it is executed as a script.
if __name__ == '__main__':
  unittest.main()
@@ -0,0 +1,47 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Sequential runner."""
15
+
16
+ from typing import Any, Callable, Iterator
17
+ from langfun.core.eval.v2.runners import base
18
+
19
+
20
class SequentialRunner(base.RunnerBase):
  """Runner that processes evaluations and examples one at a time.

  Every evaluation and every example is executed in the calling thread.
  Background tasks are also executed inline, so an exception raised by a
  background task surfaces immediately, which makes debugging easier.
  """

  NAME = 'sequential'

  def background_run(
      self, func: Callable[..., Any], *args: Any, **kwargs: Any
  ) -> None:
    """Calls `func` with the given arguments inline in the calling thread."""
    func(*args, **kwargs)

  def _run(self, evaluations: list[base.Evaluation]) -> None:
    """Runs the evaluations one after another, in the given order."""
    for evaluation in evaluations:
      self.run_evaluation(evaluation)

  def _evaluate_items(
      self, evaluation: base.Evaluation, items: Iterator[base.Example]
  ) -> None:
    """Evaluates the examples of `evaluation` one after another."""
    for example in items:
      self.evaluate_item(evaluation, example)
@@ -0,0 +1,169 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for sequential runner."""
15
+ import os
16
+ import tempfile
17
+ from typing import Any
18
+ import unittest
19
+
20
+ from langfun.core.eval.v2 import eval_test_helper
21
+ from langfun.core.eval.v2.runners import sequential # pylint: disable=unused-import
22
+
23
+ import pyglove as pg
24
+
25
+
26
class SequentialRunnerTest(unittest.TestCase):
  """Tests for the 'sequential' runner."""

  def assert_same_list(self, actual: list[Any], expected: list[Any]):
    # Both lists must have the same length and hold the very same objects
    # (compared by identity) at every position.
    self.assertEqual(len(actual), len(expected))
    for index, pair in enumerate(zip(actual, expected)):
      got, want = pair
      if got is not want:
        # Show where and how the two items differ before failing.
        print(index, pg.diff(got, want))
      self.assertIs(got, want)

  def test_basic(self):
    # Runs the test experiment and checks the exact order of the events
    # recorded by the plugin as well as the final progress of every node.
    recorder = eval_test_helper.TestPlugin()
    experiment = eval_test_helper.test_experiment()
    output_dir = os.path.join(tempfile.mkdtemp(), 'test_sequential_runner')
    experiment.run(output_dir, runner='sequential', plugins=[recorder])

    self.assertIsNotNone(recorder.start_time)
    self.assertIsNotNone(recorder.complete_time)
    self.assertGreater(recorder.complete_time, recorder.start_time)

    self.assert_same_list(
        recorder.started_experiments,
        experiment.nonleaf_nodes + experiment.leaf_nodes,
    )
    self.assert_same_list(
        recorder.completed_experiments,
        experiment.leaf_nodes + list(reversed(experiment.nonleaf_nodes)),
    )
    self.assert_same_list(
        recorder.started_example_ids, list(range(1, 11)) * 6
    )
    self.assert_same_list(
        recorder.completed_example_ids, list(range(1, 11)) * 6
    )
    self.assert_same_list(recorder.skipped_experiments, [])

    for node in experiment.nodes:
      self.assertTrue(node.progress.is_started)
      self.assertTrue(node.progress.is_completed)
      if not node.is_leaf:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_failed, 0)
        self.assertEqual(node.progress.num_processed, node.progress.num_total)
      else:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_completed, 10)
        self.assertEqual(node.progress.num_failed, 1)

  def test_raise_if_has_error(self):
    # With `raise_if_has_error=True` the example error must be raised.
    output_dir = os.path.join(tempfile.mkdtemp(), 'test_raise_if_has_error')
    evaluation = eval_test_helper.TestEvaluation()
    with self.assertRaisesRegex(ValueError, 'x should not be 5'):
      evaluation.run(
          output_dir,
          runner='sequential',
          plugins=[],
          raise_if_has_error=True,
      )

  def test_example_ids(self):
    # Only the requested example IDs are evaluated, for each evaluation.
    output_dir = os.path.join(tempfile.mkdtemp(), 'test_example_ids')
    experiment = eval_test_helper.test_experiment()
    recorder = eval_test_helper.TestPlugin()
    experiment.run(
        output_dir,
        runner='sequential',
        plugins=[recorder],
        example_ids=[5, 7, 9],
    )
    self.assertEqual(recorder.started_example_ids, [5, 7, 9] * 6)
    self.assertEqual(recorder.completed_example_ids, [5, 7, 9] * 6)

  def test_shuffle_inputs(self):
    # The `shuffle_inputs` flag is carried over to the returned run.
    output_dir = os.path.join(tempfile.mkdtemp(), 'test_shuffle_inputs')
    experiment = eval_test_helper.test_experiment()
    recorder = eval_test_helper.TestPlugin()
    run = experiment.run(
        output_dir,
        runner='sequential',
        plugins=[recorder],
        shuffle_inputs=True,
    )
    self.assertTrue(run.shuffle_inputs)

  def test_filter(self):
    # Evaluations rejected by the filter are reported as skipped.
    recorder = eval_test_helper.TestPlugin()
    experiment = eval_test_helper.test_experiment()
    output_dir = os.path.join(tempfile.mkdtemp(), 'test_filter')

    experiment.run(
        output_dir,
        runner='sequential',
        plugins=[recorder],
        filter=lambda e: e.lm.offset != 0,
    )
    self.assert_same_list(
        recorder.started_experiments,
        experiment.nonleaf_nodes + experiment.leaf_nodes[2:],
    )
    self.assert_same_list(
        recorder.skipped_experiments, experiment.leaf_nodes[:2]
    )
    self.assert_same_list(
        recorder.completed_experiments,
        experiment.leaf_nodes[2:] + [experiment.children[1], experiment],
    )

  def test_use_cache(self):
    # Checks cache files and request counts for each `use_cache` mode.
    @pg.functor()
    def test_inputs(num_examples: int = 10):
      # Every input appears twice (i // 2).
      examples = []
      for i in range(num_examples):
        x = i // 2
        y = x ** 2
        examples.append(pg.Dict(x=x, y=y, groundtruth=(x + y)))
      return examples

    experiment = eval_test_helper.TestEvaluation(
        inputs=test_inputs(num_examples=pg.oneof([2, 4]))
    )
    usage = lambda: experiment.usage_summary

    # Global cache.
    output_dir = os.path.join(tempfile.mkdtemp(), 'global_cache')
    run = experiment.run(
        output_dir, 'new', runner='sequential', use_cache='global', plugins=[]
    )
    global_cache_file = run.output_path_for(experiment, 'cache.json')
    self.assertTrue(pg.io.path_exists(global_cache_file))
    self.assertEqual(usage().cached.total.num_requests, 4)
    self.assertEqual(usage().uncached.total.num_requests, 2)

    # Per-dataset cache.
    output_dir = os.path.join(tempfile.mkdtemp(), 'per_dataset')
    run = experiment.run(
        output_dir,
        'new',
        runner='sequential',
        use_cache='per_dataset',
        plugins=[],
    )
    for leaf in experiment.leaf_nodes:
      leaf_cache_file = run.output_path_for(leaf, 'cache.json')
      self.assertTrue(pg.io.path_exists(leaf_cache_file))
    self.assertEqual(usage().cached.total.num_requests, 3)
    self.assertEqual(usage().uncached.total.num_requests, 3)

    # No cache.
    output_dir = os.path.join(tempfile.mkdtemp(), 'no')
    run = experiment.run(
        output_dir, runner='sequential', use_cache='no', plugins=[]
    )
    global_cache_file = run.output_path_for(experiment, 'cache.json')
    self.assertFalse(pg.io.path_exists(global_cache_file))
    for leaf in experiment.leaf_nodes:
      leaf_cache_file = run.output_path_for(leaf, 'cache.json')
      self.assertFalse(pg.io.path_exists(leaf_cache_file))
    self.assertEqual(usage().cached.total.num_requests, 0)
    self.assertEqual(usage().uncached.total.num_requests, 6)
166
+
167
+
168
# Run the tests of this module when it is executed as a script.
if __name__ == '__main__':
  unittest.main()