langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +20 -2
- langfun/core/__init__.py +16 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -21
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +63 -2
- langfun/core/component_test.py +53 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +16 -1
- langfun/core/eval/base.py +622 -174
- langfun/core/eval/base_test.py +200 -54
- langfun/core/eval/matching.py +63 -76
- langfun/core/eval/matching_test.py +17 -8
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +26 -26
- langfun/core/eval/scoring_test.py +19 -2
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +4 -17
- langfun/core/langfunc_test.py +22 -6
- langfun/core/language_model.py +577 -39
- langfun/core/language_model_test.py +470 -56
- langfun/core/llms/__init__.py +87 -16
- langfun/core/llms/anthropic.py +312 -87
- langfun/core/llms/anthropic_test.py +71 -3
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +53 -2
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +11 -7
- langfun/core/llms/fake_test.py +14 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -202
- langfun/core/llms/groq.py +160 -144
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +15 -42
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +395 -203
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +30 -395
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -26
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +12 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +7 -6
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +60 -27
- langfun/core/structured/function_generation_test.py +72 -2
- langfun/core/structured/mapping.py +97 -47
- langfun/core/structured/mapping_test.py +90 -2
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +53 -9
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
- langfun/core/structured/schema.py +204 -97
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_test.py +130 -29
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +115 -1
- langfun/core/template_test.py +71 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +10 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -238
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
- langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,334 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import os
|
15
|
+
import tempfile
|
16
|
+
import threading
|
17
|
+
import time
|
18
|
+
from typing import Any
|
19
|
+
import unittest
|
20
|
+
|
21
|
+
from langfun.core.eval.v2 import eval_test_helper
|
22
|
+
from langfun.core.eval.v2 import example as example_lib
|
23
|
+
from langfun.core.eval.v2 import experiment as experiment_lib
|
24
|
+
from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import
|
25
|
+
|
26
|
+
import pyglove as pg
|
27
|
+
|
28
|
+
|
29
|
+
# Shorthand aliases for the eval v2 types exercised by these tests.
Example = example_lib.Example
Runner, Experiment, Suite, Plugin = (
    experiment_lib.Runner,
    experiment_lib.Experiment,
    experiment_lib.Suite,
    experiment_lib.Plugin,
)
|
34
|
+
|
35
|
+
|
36
|
+
class TestPlugin(Plugin):
  """A plugin that records runner callbacks so tests can inspect them."""

  # Experiments (wrapped in `pg.Ref`) seen by each experiment-level hook.
  started_experiments: list[Experiment] = []
  completed_experiments: list[Experiment] = []
  skipped_experiments: list[Experiment] = []
  # Example IDs seen by each example-level hook.
  started_example_ids: list[int] = []
  completed_example_ids: list[int] = []
  skipped_example_ids: list[int] = []
  # Wall-clock timestamps set by the run-level hooks.
  start_time: float | None = None
  complete_time: float | None = None

  def _on_bound(self):
    super()._on_bound()
    # Guards the record lists, which hooks may append to concurrently when a
    # parallel runner is used.
    self._lock = threading.Lock()

  def _append_record(self, records: list[Any], item: Any) -> None:
    """Appends `item` to `records` under the lock, without change events."""
    with pg.notify_on_change(False), self._lock:
      records.append(item)

  def _set_timestamp(self, field_name: str) -> None:
    """Sets the timestamp field `field_name` to the current `time.time()`."""
    # Symbolic objects reject plain attribute writes unless explicitly allowed.
    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
      setattr(self, field_name, time.time())

  def on_run_start(self, runner: Runner, root: Experiment):
    del root
    self._set_timestamp('start_time')

  def on_run_complete(self, runner: Runner, root: Experiment):
    del root
    self._set_timestamp('complete_time')

  def on_experiment_start(self, runner: Runner, experiment: Experiment):
    del runner
    self._append_record(self.started_experiments, pg.Ref(experiment))

  def on_experiment_skipped(self, runner: Runner, experiment: Experiment):
    del runner
    self._append_record(self.skipped_experiments, pg.Ref(experiment))

  def on_experiment_complete(self, runner: Runner, experiment: Experiment):
    del runner
    self._append_record(self.completed_experiments, pg.Ref(experiment))

  def on_example_start(
      self, runner: Runner, experiment: Experiment, example: Example):
    del runner, experiment
    self._append_record(self.started_example_ids, example.id)

  def on_example_skipped(
      self, runner: Runner, experiment: Experiment, example: Example):
    del runner, experiment
    self._append_record(self.skipped_example_ids, example.id)

  def on_example_complete(
      self, runner: Runner, experiment: Experiment, example: Example):
    del runner, experiment
    self._append_record(self.completed_example_ids, example.id)
|
92
|
+
|
93
|
+
|
94
|
+
class RunnerTest(unittest.TestCase):
  """Tests for the sequential runner, plus helpers shared with subclasses."""

  def _root_dir(self, name: str) -> str:
    """Returns a test-private output directory path ending in `name`.

    Each call creates a new temporary directory that is removed when the test
    finishes. This keeps runs from observing outputs (e.g. `run.json`,
    `cache.json`) left behind by other tests, by the copies of these tests
    inherited by subclasses, or by earlier invocations of the suite.
    """
    tmp = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
    # The bound method keeps `tmp` alive until cleanup runs.
    self.addCleanup(tmp.cleanup)
    return os.path.join(tmp.name, name)

  def assert_same_list(self, actual: list[Any], expected: list[Any]):
    """Asserts `actual` holds the very same objects as `expected`, in order."""
    self.assertEqual(len(actual), len(expected))
    for i, (x, y) in enumerate(zip(actual, expected)):
      if x is not y:
        # Report the structural diff in the failure message, not on stdout.
        self.fail(f'Item {i} is not the expected object: {pg.diff(x, y)}')

  def test_basic(self):
    plugin = TestPlugin()
    exp = eval_test_helper.test_experiment()
    root_dir = self._root_dir('test_sequential_runner')
    run = exp.run(root_dir, runner='sequential', plugins=[plugin])

    self.assertIsNotNone(plugin.start_time)
    self.assertIsNotNone(plugin.complete_time)
    self.assertGreater(plugin.complete_time, plugin.start_time)

    self.assert_same_list(
        plugin.started_experiments,
        exp.nonleaf_nodes + exp.leaf_nodes
    )
    self.assert_same_list(
        plugin.completed_experiments,
        exp.leaf_nodes + list(reversed(exp.nonleaf_nodes))
    )
    self.assert_same_list(
        plugin.started_example_ids, list(range(1, 11)) * 6
    )
    self.assert_same_list(
        plugin.completed_example_ids, list(range(1, 11)) * 6
    )
    self.assert_same_list(plugin.skipped_experiments, [])
    self.assert_same_list(plugin.skipped_example_ids, [])
    self.assertTrue(
        pg.io.path_exists(os.path.join(run.output_root, 'run.json'))
    )

    for node in exp.nodes:
      self.assertTrue(node.progress.is_started)
      self.assertTrue(node.progress.is_completed)
      if node.is_leaf:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_completed, 10)
        self.assertEqual(node.progress.num_failed, 1)
      else:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_failed, 0)
        self.assertEqual(node.progress.num_processed, node.progress.num_total)

  def test_raise_if_has_error(self):
    exp = eval_test_helper.TestEvaluation()
    with self.assertRaisesRegex(ValueError, 'x should not be 5'):
      exp.run(
          self._root_dir('test_raise_if_has_error'),
          runner='sequential', plugins=[], raise_if_has_error=True
      )

    with self.assertRaisesRegex(ValueError, 'x should not be 5'):
      exp.run(
          self._root_dir('test_raise_if_has_error'),
          runner='parallel', plugins=[], raise_if_has_error=True
      )

  def test_example_ids(self):
    root_dir = self._root_dir('test_example_ids')
    exp = eval_test_helper.test_experiment()
    plugin = TestPlugin()
    _ = exp.run(
        root_dir, runner='sequential', plugins=[plugin], example_ids=[5, 7, 9]
    )
    self.assertEqual(plugin.started_example_ids, [5, 7, 9] * 6)
    self.assertEqual(plugin.completed_example_ids, [5, 7, 9] * 6)

  def test_filter(self):
    plugin = TestPlugin()
    exp = eval_test_helper.test_experiment()
    root_dir = self._root_dir('test_filter')

    _ = exp.run(
        root_dir, runner='sequential', plugins=[plugin],
        filter=lambda e: e.lm.offset != 0
    )
    self.assert_same_list(
        plugin.started_experiments,
        exp.nonleaf_nodes + exp.leaf_nodes[2:]
    )
    self.assert_same_list(
        plugin.skipped_experiments, exp.leaf_nodes[:2]
    )
    self.assert_same_list(
        plugin.completed_experiments,
        exp.leaf_nodes[2:] + [exp.children[1], exp]
    )

  def test_use_cache(self):
    @pg.functor()
    def test_inputs(num_examples: int = 10):
      # Inputs come in duplicate pairs (x = i // 2), so every second request
      # can be served from a cache.
      return [
          pg.Dict(
              x=i // 2, y=(i // 2) ** 2,
              groundtruth=(i // 2 + (i // 2) ** 2)
          ) for i in range(num_examples)
      ]

    exp = eval_test_helper.TestEvaluation(
        inputs=test_inputs(num_examples=pg.oneof([2, 4]))
    )
    # Global cache.
    root_dir = self._root_dir('global_cache')
    run = exp.run(
        root_dir, 'new', runner='sequential', use_cache='global', plugins=[]
    )
    self.assertTrue(pg.io.path_exists(run.output_path_for(exp, 'cache.json')))
    self.assertEqual(exp.usage_summary.cached.total.num_requests, 4)
    self.assertEqual(exp.usage_summary.uncached.total.num_requests, 2)

    # Per-dataset cache.
    root_dir = self._root_dir('per_dataset')
    run = exp.run(
        root_dir, 'new', runner='sequential',
        use_cache='per_dataset', plugins=[]
    )
    for leaf in exp.leaf_nodes:
      self.assertTrue(
          pg.io.path_exists(run.output_path_for(leaf, 'cache.json'))
      )
    self.assertEqual(exp.usage_summary.cached.total.num_requests, 3)
    self.assertEqual(exp.usage_summary.uncached.total.num_requests, 3)

    # No cache.
    root_dir = self._root_dir('no')
    run = exp.run(root_dir, runner='sequential', use_cache='no', plugins=[])
    self.assertFalse(pg.io.path_exists(run.output_path_for(exp, 'cache.json')))
    for leaf in exp.leaf_nodes:
      self.assertFalse(
          pg.io.path_exists(run.output_path_for(leaf, 'cache.json'))
      )
    self.assertEqual(exp.usage_summary.cached.total.num_requests, 0)
    self.assertEqual(exp.usage_summary.uncached.total.num_requests, 6)
|
232
|
+
|
233
|
+
|
234
|
+
class ParallelRunnerTest(RunnerTest):
  """Tests for the parallel runner (also re-runs the inherited test cases)."""

  def _assert_all_examples_processed(
      self, plugin: TestPlugin, exp: Experiment) -> None:
    """Checks that a parallel run of `test_experiment()` fired every hook.

    Only counts are compared, not callback order, since the order is
    presumably nondeterministic under parallel execution.
    """
    self.assertIsNotNone(plugin.start_time)
    self.assertIsNotNone(plugin.complete_time)
    self.assertGreater(plugin.complete_time, plugin.start_time)

    self.assertEqual(
        len(plugin.started_experiments), len(exp.nodes)
    )
    self.assertEqual(
        len(plugin.completed_experiments), len(exp.nodes)
    )
    self.assertEqual(
        len(plugin.started_example_ids), 6 * 10
    )
    self.assertEqual(
        len(plugin.completed_example_ids), 6 * 10
    )
    self.assert_same_list(plugin.skipped_experiments, [])
    self.assert_same_list(plugin.skipped_example_ids, [])

  def test_parallel_runner(self):
    plugin = TestPlugin()
    exp = eval_test_helper.test_experiment()
    root_dir = os.path.join(tempfile.gettempdir(), 'test_parallel_runner')
    run = exp.run(root_dir, runner='parallel', plugins=[plugin])

    self._assert_all_examples_processed(plugin, exp)
    self.assertTrue(
        pg.io.path_exists(os.path.join(run.output_root, 'run.json'))
    )

    for node in exp.nodes:
      self.assertTrue(node.progress.is_started)
      self.assertTrue(node.progress.is_completed)
      if node.is_leaf:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_completed, 10)
        self.assertEqual(node.progress.num_failed, 1)
      else:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_failed, 0)
        self.assertEqual(node.progress.num_processed, node.progress.num_total)

  def test_concurrent_startup_delay(self):
    plugin = TestPlugin()
    exp = eval_test_helper.test_experiment()
    root_dir = os.path.join(
        tempfile.gettempdir(), 'test_concurrent_startup_delay'
    )
    _ = exp.run(
        root_dir,
        runner='parallel',
        plugins=[plugin],
        concurrent_startup_delay=(0, 5),
    )
    # Previously this test asserted nothing; staggering worker startup must
    # still process every example without skipping any.
    self._assert_all_examples_processed(plugin, exp)
|
288
|
+
|
289
|
+
|
290
|
+
class DebugRunnerTest(RunnerTest):
  """Tests for the debug runner, which processes one example per leaf."""

  def test_debug_runner(self):
    plugin = TestPlugin()
    exp = eval_test_helper.test_experiment()
    run = exp.run(
        os.path.join(tempfile.gettempdir(), 'test_debug_runner'),
        runner='debug',
        plugins=[plugin],
    )

    # Both run-level hooks must have recorded a time, start before complete.
    started_at, completed_at = plugin.start_time, plugin.complete_time
    self.assertIsNotNone(started_at)
    self.assertIsNotNone(completed_at)
    self.assertGreater(completed_at, started_at)

    # Every experiment node is started and completed ...
    num_nodes = len(exp.nodes)
    self.assertEqual(len(plugin.started_experiments), num_nodes)
    self.assertEqual(len(plugin.completed_experiments), num_nodes)
    # ... but only a single example for each of the 6 leaves is processed.
    num_debug_examples = 6 * 1
    self.assertEqual(len(plugin.started_example_ids), num_debug_examples)
    self.assertEqual(len(plugin.completed_example_ids), num_debug_examples)
    self.assert_same_list(plugin.skipped_experiments, [])
    self.assert_same_list(plugin.skipped_example_ids, [])
    # Unlike the other runners, no `run.json` is written.
    self.assertFalse(
        pg.io.path_exists(os.path.join(run.output_root, 'run.json'))
    )

    for node in exp.nodes:
      progress = node.progress
      self.assertTrue(progress.is_started)
      self.assertTrue(progress.is_completed)
      self.assertEqual(progress.num_skipped, 0)
      if node.is_leaf:
        self.assertEqual(progress.num_completed, 1)
        self.assertEqual(progress.num_failed, 0)
      else:
        self.assertEqual(progress.num_failed, 0)
        self.assertEqual(progress.num_processed, progress.num_total)
|
331
|
+
|
332
|
+
|
333
|
+
if __name__ == '__main__':
  # Lets this module be executed directly as a test program.
  unittest.main()
|
langfun/core/langfunc.py
CHANGED
@@ -14,7 +14,7 @@
|
|
14
14
|
"""LangFunc: Language-based functions."""
|
15
15
|
|
16
16
|
import dataclasses
|
17
|
-
from typing import Annotated, Type
|
17
|
+
from typing import Annotated, Type
|
18
18
|
|
19
19
|
from langfun.core import component
|
20
20
|
from langfun.core import language_model
|
@@ -269,6 +269,9 @@ class LangFunc(
|
|
269
269
|
# Send rendered text to LM.
|
270
270
|
lm_output = self.lm(lm_input, cache_seed=cache_seed)
|
271
271
|
|
272
|
+
# Attach cache seed.
|
273
|
+
lm_input.metadata.cache_seed = cache_seed
|
274
|
+
|
272
275
|
# Transform the output message.
|
273
276
|
lm_output = self.transform_output(lm_output)
|
274
277
|
lm_output.tag(message_lib.Message.TAG_LM_OUTPUT)
|
@@ -328,22 +331,6 @@ class LangFunc(
|
|
328
331
|
"""Transforms the output message before returning from __call__."""
|
329
332
|
return lm_output
|
330
333
|
|
331
|
-
@classmethod
|
332
|
-
def from_value(
|
333
|
-
cls, value: Union[str, template_lib.Template], **kwargs
|
334
|
-
) -> 'LangFunc':
|
335
|
-
"""Create a LangFunc object from a string or template."""
|
336
|
-
if isinstance(value, LangFunc):
|
337
|
-
return value
|
338
|
-
if isinstance(value, template_lib.Template):
|
339
|
-
lfun = LangFunc(value.template_str, **kwargs)
|
340
|
-
# So lfun could acccess all attributes from value.
|
341
|
-
lfun.sym_setparent(value)
|
342
|
-
return lfun
|
343
|
-
if isinstance(value, str):
|
344
|
-
return LangFunc(template_str=value, **kwargs)
|
345
|
-
return LangFunc('{{input}}', input=value, **kwargs)
|
346
|
-
|
347
334
|
|
348
335
|
# Register converter from str to LangFunc, therefore we can always
|
349
336
|
# pass strs to attributes that accept LangFunc.
|
langfun/core/langfunc_test.py
CHANGED
@@ -57,6 +57,10 @@ class BasicTest(unittest.TestCase):
|
|
57
57
|
l2 = LangFunc.from_value(l1)
|
58
58
|
self.assertIs(l2, l1)
|
59
59
|
|
60
|
+
l3 = LangFunc.from_value(l1, x=1)
|
61
|
+
self.assertIsNot(l3, l1)
|
62
|
+
self.assertTrue(pg.eq(l3, LangFunc('Hello', x=1)))
|
63
|
+
|
60
64
|
c = template_lib.Template(
|
61
65
|
'{{x}} + {{l}}',
|
62
66
|
x=1,
|
@@ -83,14 +87,20 @@ class LangFuncCallTest(unittest.TestCase):
|
|
83
87
|
|
84
88
|
r = l()
|
85
89
|
self.assertEqual(
|
86
|
-
r,
|
90
|
+
r,
|
91
|
+
message.AIMessage(
|
92
|
+
'Hello!!!', score=0.0, logprobs=None, is_cached=False,
|
93
|
+
usage=language_model.UsageNotAvailable()
|
94
|
+
)
|
87
95
|
)
|
88
96
|
self.assertEqual(r.tags, ['lm-response', 'lm-output'])
|
89
|
-
self.assertEqual(
|
97
|
+
self.assertEqual(
|
98
|
+
r.source,
|
99
|
+
message.UserMessage('Hello', metadata=dict(cache_seed=0))
|
100
|
+
)
|
90
101
|
self.assertEqual(r.source.tags, ['rendered', 'lm-input'])
|
91
102
|
|
92
103
|
self.assertEqual(str(l), 'Hello')
|
93
|
-
print(repr(l))
|
94
104
|
self.assertEqual(
|
95
105
|
repr(l),
|
96
106
|
"LangFunc(template_str='Hello', clean=True,"
|
@@ -98,7 +108,8 @@ class LangFuncCallTest(unittest.TestCase):
|
|
98
108
|
' max_tokens=None, n=1, top_k=40, top_p=None, stop=None,'
|
99
109
|
' random_seed=None, logprobs=False, top_logprobs=None), cache=None,'
|
100
110
|
' max_concurrency=None, timeout=120.0, max_attempts=5,'
|
101
|
-
' retry_interval=(5, 60), exponential_backoff=True,
|
111
|
+
' retry_interval=(5, 60), exponential_backoff=True,'
|
112
|
+
' max_retry_interval=300, debug=False))',
|
102
113
|
)
|
103
114
|
|
104
115
|
l = LangFunc('Hello')
|
@@ -106,11 +117,16 @@ class LangFuncCallTest(unittest.TestCase):
|
|
106
117
|
self.assertEqual(l, 'Hello')
|
107
118
|
self.assertEqual(l.natural_language_format(), 'Hello')
|
108
119
|
self.assertEqual(l.render(), 'Hello')
|
109
|
-
r = l()
|
120
|
+
r = l(cache_seed=1)
|
110
121
|
self.assertEqual(
|
111
|
-
r,
|
122
|
+
r,
|
123
|
+
message.AIMessage(
|
124
|
+
'Hello!!!', score=0.0, logprobs=None, is_cached=False,
|
125
|
+
usage=language_model.UsageNotAvailable()
|
126
|
+
)
|
112
127
|
)
|
113
128
|
self.assertEqual(r.tags, ['lm-response', 'lm-output'])
|
129
|
+
self.assertEqual(r.source.metadata.cache_seed, 1)
|
114
130
|
|
115
131
|
self.assertEqual(str(l), 'Hello')
|
116
132
|
|