langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +22 -2
- langfun/core/__init__.py +17 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -28
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +69 -2
- langfun/core/component_test.py +54 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +17 -0
- langfun/core/eval/base.py +767 -140
- langfun/core/eval/base_test.py +238 -53
- langfun/core/eval/matching.py +80 -76
- langfun/core/eval/matching_test.py +19 -9
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +37 -28
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +3 -21
- langfun/core/langfunc_test.py +26 -8
- langfun/core/language_model.py +686 -48
- langfun/core/language_model_test.py +681 -44
- langfun/core/llms/__init__.py +100 -12
- langfun/core/llms/anthropic.py +488 -0
- langfun/core/llms/anthropic_test.py +235 -0
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +88 -28
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +39 -26
- langfun/core/llms/fake_test.py +136 -11
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -197
- langfun/core/llms/groq.py +276 -0
- langfun/core/llms/groq_test.py +64 -0
- langfun/core/llms/llama_cpp.py +15 -40
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +436 -226
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +35 -174
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -23
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +15 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +9 -8
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +278 -0
- langfun/core/structured/function_generation_test.py +399 -0
- langfun/core/structured/mapping.py +150 -46
- langfun/core/structured/mapping_test.py +105 -0
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +71 -22
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
- langfun/core/structured/schema.py +208 -99
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +133 -34
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +240 -11
- langfun/core/template_test.py +146 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +14 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -217
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,334 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import os
|
15
|
+
import tempfile
|
16
|
+
import threading
|
17
|
+
import time
|
18
|
+
from typing import Any
|
19
|
+
import unittest
|
20
|
+
|
21
|
+
from langfun.core.eval.v2 import eval_test_helper
|
22
|
+
from langfun.core.eval.v2 import example as example_lib
|
23
|
+
from langfun.core.eval.v2 import experiment as experiment_lib
|
24
|
+
from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import
|
25
|
+
|
26
|
+
import pyglove as pg
|
27
|
+
|
28
|
+
|
29
|
+
# Short local aliases for the evaluation types referenced by the tests below.
Runner = experiment_lib.Runner
Example = example_lib.Example
Experiment = experiment_lib.Experiment
Suite = experiment_lib.Suite
Plugin = experiment_lib.Plugin
|
34
|
+
|
35
|
+
|
36
|
+
class TestPlugin(Plugin):
  """Plugin that records the runner callbacks it receives.

  Each callback appends what it was given to one of the lists below (or
  stamps a timestamp), so tests can assert on which callbacks fired and in
  what order.
  """

  # Experiments passed to the experiment-level callbacks, in call order.
  started_experiments: list[Experiment] = []
  completed_experiments: list[Experiment] = []
  skipped_experiments: list[Experiment] = []

  # Example ids passed to the example-level callbacks, in call order.
  started_example_ids: list[int] = []
  completed_example_ids: list[int] = []
  skipped_example_ids: list[int] = []

  # `time.time()` values captured in `on_run_start` / `on_run_complete`.
  start_time: float | None = None
  complete_time: float | None = None

  def _on_bound(self):
    super()._on_bound()
    # Held around every append to the recording lists.
    # NOTE(review): presumably needed because the parallel runner invokes
    # callbacks from multiple threads -- confirm against the runner.
    self._lock = threading.Lock()

  def _stamp(self, field_name: str):
    """Stores the current time into the attribute named `field_name`."""
    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
      setattr(self, field_name, time.time())

  def _record(self, field_name: str, item):
    """Appends `item` to the list attribute `field_name` under the lock."""
    with pg.notify_on_change(False), self._lock:
      getattr(self, field_name).append(item)

  def on_run_start(self, runner: Runner, root: Experiment):
    del root
    self._stamp('start_time')

  def on_run_complete(self, runner: Runner, root: Experiment):
    del root
    self._stamp('complete_time')

  def on_experiment_start(self, runner: Runner, experiment: Experiment):
    del runner
    # Experiments are stored as `pg.Ref` wrappers rather than as raw values.
    self._record('started_experiments', pg.Ref(experiment))

  def on_experiment_skipped(self, runner: Runner, experiment: Experiment):
    del runner
    self._record('skipped_experiments', pg.Ref(experiment))

  def on_experiment_complete(self, runner: Runner, experiment: Experiment):
    del runner
    self._record('completed_experiments', pg.Ref(experiment))

  def on_example_start(
      self, runner: Runner, experiment: Experiment, example: Example):
    del runner, experiment
    self._record('started_example_ids', example.id)

  def on_example_skipped(
      self, runner: Runner, experiment: Experiment, example: Example):
    del runner, experiment
    self._record('skipped_example_ids', example.id)

  def on_example_complete(
      self, runner: Runner, experiment: Experiment, example: Example):
    del runner, experiment
    self._record('completed_example_ids', example.id)
|
92
|
+
|
93
|
+
|
94
|
+
class RunnerTest(unittest.TestCase):
  """Tests for the sequential runner and for runner-independent run options."""

  def assert_same_list(self, actual: list[Any], expected: list[Any]):
    """Asserts that two lists have equal length and pairwise-identical items.

    Items are compared by identity (`is`). When a pair differs, its index and
    `pg.diff` output are printed before the identity assertion fails.

    Args:
      actual: The list produced by the code under test.
      expected: The list it should match element by element.
    """
    self.assertEqual(len(actual), len(expected))
    for index, pair in enumerate(zip(actual, expected)):
      got, want = pair
      if got is not want:
        print(index, pg.diff(got, want))
      self.assertIs(got, want)

  def test_basic(self):
    """Runs the test experiment sequentially and checks every callback."""
    recorder = TestPlugin()
    experiment = eval_test_helper.test_experiment()
    output_dir = os.path.join(tempfile.gettempdir(), 'test_sequential_runner')
    run = experiment.run(output_dir, runner='sequential', plugins=[recorder])

    # Run-level callbacks must both have fired, in order.
    self.assertIsNotNone(recorder.start_time)
    self.assertIsNotNone(recorder.complete_time)
    self.assertGreater(recorder.complete_time, recorder.start_time)

    # Start order: non-leaf nodes first, then leaves.
    self.assert_same_list(
        recorder.started_experiments,
        experiment.nonleaf_nodes + experiment.leaf_nodes,
    )
    # Completion order: leaves first, then non-leaf nodes in reverse.
    self.assert_same_list(
        recorder.completed_experiments,
        experiment.leaf_nodes + list(reversed(experiment.nonleaf_nodes)),
    )
    # Ids 1..10 are expected to repeat six times.
    self.assert_same_list(
        recorder.started_example_ids, list(range(1, 11)) * 6
    )
    self.assert_same_list(
        recorder.completed_example_ids, list(range(1, 11)) * 6
    )
    self.assert_same_list(recorder.skipped_experiments, [])
    self.assert_same_list(recorder.skipped_example_ids, [])

    run_json = os.path.join(run.output_root, 'run.json')
    self.assertTrue(pg.io.path_exists(run_json))

    for node in experiment.nodes:
      progress = node.progress
      self.assertTrue(progress.is_started)
      self.assertTrue(progress.is_completed)
      if node.is_leaf:
        self.assertEqual(progress.num_skipped, 0)
        self.assertEqual(progress.num_completed, 10)
        self.assertEqual(progress.num_failed, 1)
      else:
        self.assertEqual(progress.num_skipped, 0)
        self.assertEqual(progress.num_failed, 0)
        self.assertEqual(progress.num_processed, progress.num_total)

  def test_raise_if_has_error(self):
    """Checks that `raise_if_has_error=True` surfaces the example error."""
    output_dir = os.path.join(
        tempfile.gettempdir(), 'test_raise_if_has_error'
    )
    evaluation = eval_test_helper.TestEvaluation()
    # Same expectation for both runners; sequential is exercised first.
    for runner_name in ('sequential', 'parallel'):
      with self.assertRaisesRegex(ValueError, 'x should not be 5'):
        evaluation.run(
            output_dir,
            runner=runner_name,
            plugins=[],
            raise_if_has_error=True,
        )

  def test_example_ids(self):
    """Checks that `example_ids` restricts which examples are evaluated."""
    output_dir = os.path.join(tempfile.gettempdir(), 'test_example_ids')
    experiment = eval_test_helper.test_experiment()
    recorder = TestPlugin()
    _ = experiment.run(
        output_dir,
        runner='sequential',
        plugins=[recorder],
        example_ids=[5, 7, 9],
    )
    self.assertEqual(recorder.started_example_ids, [5, 7, 9] * 6)
    self.assertEqual(recorder.completed_example_ids, [5, 7, 9] * 6)

  def test_filter(self):
    """Checks that experiments rejected by `filter` are reported as skipped."""
    recorder = TestPlugin()
    experiment = eval_test_helper.test_experiment()
    output_dir = os.path.join(tempfile.gettempdir(), 'test_filter')

    _ = experiment.run(
        output_dir,
        runner='sequential',
        plugins=[recorder],
        filter=lambda node: node.lm.offset != 0,
    )
    # The first two leaves are expected to be filtered out.
    self.assert_same_list(
        recorder.started_experiments,
        experiment.nonleaf_nodes + experiment.leaf_nodes[2:],
    )
    self.assert_same_list(
        recorder.skipped_experiments, experiment.leaf_nodes[:2]
    )
    self.assert_same_list(
        recorder.completed_experiments,
        experiment.leaf_nodes[2:] + [experiment.children[1], experiment],
    )

  def test_use_cache(self):
    """Checks cache files and request counts for each `use_cache` mode."""

    @pg.functor()
    def test_inputs(num_examples: int = 10):
      # Consecutive pairs of examples share the same `x`, so every input
      # value appears twice.
      rows = []
      for i in range(num_examples):
        half = i // 2
        square = half ** 2
        rows.append(pg.Dict(x=half, y=square, groundtruth=(half + square)))
      return rows

    evaluation = eval_test_helper.TestEvaluation(
        inputs=test_inputs(num_examples=pg.oneof([2, 4]))
    )

    # Global cache: one cache file at the root of the experiment.
    output_dir = os.path.join(tempfile.gettempdir(), 'global_cache')
    run = evaluation.run(
        output_dir, 'new', runner='sequential', use_cache='global', plugins=[]
    )
    self.assertTrue(
        pg.io.path_exists(run.output_path_for(evaluation, 'cache.json'))
    )
    usage = evaluation.usage_summary
    self.assertEqual(usage.cached.total.num_requests, 4)
    self.assertEqual(usage.uncached.total.num_requests, 2)

    # Per-dataset cache: one cache file per leaf.
    output_dir = os.path.join(tempfile.gettempdir(), 'per_dataset')
    run = evaluation.run(
        output_dir,
        'new',
        runner='sequential',
        use_cache='per_dataset',
        plugins=[],
    )
    for leaf in evaluation.leaf_nodes:
      leaf_cache = run.output_path_for(leaf, 'cache.json')
      self.assertTrue(pg.io.path_exists(leaf_cache))
    usage = evaluation.usage_summary
    self.assertEqual(usage.cached.total.num_requests, 3)
    self.assertEqual(usage.uncached.total.num_requests, 3)

    # No cache: no cache file anywhere, and every request is uncached.
    output_dir = os.path.join(tempfile.gettempdir(), 'no')
    run = evaluation.run(
        output_dir, runner='sequential', use_cache='no', plugins=[]
    )
    self.assertFalse(
        pg.io.path_exists(run.output_path_for(evaluation, 'cache.json'))
    )
    for leaf in evaluation.leaf_nodes:
      leaf_cache = run.output_path_for(leaf, 'cache.json')
      self.assertFalse(pg.io.path_exists(leaf_cache))
    usage = evaluation.usage_summary
    self.assertEqual(usage.cached.total.num_requests, 0)
    self.assertEqual(usage.uncached.total.num_requests, 6)
|
232
|
+
|
233
|
+
|
234
|
+
class ParallelRunnerTest(RunnerTest):
  """Tests for the parallel runner.

  Subclassing `RunnerTest` provides `assert_same_list`; it also means the
  inherited test methods are collected again under this class.
  """

  def test_parallel_runner(self):
    """Runs the test experiment in parallel and checks callback counts."""
    recorder = TestPlugin()
    experiment = eval_test_helper.test_experiment()
    output_dir = os.path.join(tempfile.gettempdir(), 'test_parallel_runner')
    run = experiment.run(output_dir, runner='parallel', plugins=[recorder])

    self.assertIsNotNone(recorder.start_time)
    self.assertIsNotNone(recorder.complete_time)
    self.assertGreater(recorder.complete_time, recorder.start_time)

    # Only counts are compared here, not the order of the callbacks.
    self.assertEqual(
        len(recorder.started_experiments), len(experiment.nodes)
    )
    self.assertEqual(
        len(recorder.completed_experiments), len(experiment.nodes)
    )
    self.assertEqual(len(recorder.started_example_ids), 6 * 10)
    self.assertEqual(len(recorder.completed_example_ids), 6 * 10)
    self.assert_same_list(recorder.skipped_experiments, [])
    self.assert_same_list(recorder.skipped_example_ids, [])

    run_json = os.path.join(run.output_root, 'run.json')
    self.assertTrue(pg.io.path_exists(run_json))

    for node in experiment.nodes:
      progress = node.progress
      self.assertTrue(progress.is_started)
      self.assertTrue(progress.is_completed)
      if node.is_leaf:
        self.assertEqual(progress.num_skipped, 0)
        self.assertEqual(progress.num_completed, 10)
        self.assertEqual(progress.num_failed, 1)
      else:
        self.assertEqual(progress.num_skipped, 0)
        self.assertEqual(progress.num_failed, 0)
        self.assertEqual(progress.num_processed, progress.num_total)

  def test_concurrent_startup_delay(self):
    """Smoke test: a parallel run accepts `concurrent_startup_delay`."""
    recorder = TestPlugin()
    experiment = eval_test_helper.test_experiment()
    output_dir = os.path.join(
        tempfile.gettempdir(), 'test_concurrent_startup_delay'
    )
    # No assertions follow; the test passes if the run does not raise.
    _ = experiment.run(
        output_dir,
        runner='parallel',
        plugins=[recorder],
        concurrent_startup_delay=(0, 5),
    )
|
288
|
+
|
289
|
+
|
290
|
+
class DebugRunnerTest(RunnerTest):
  """Tests for the debug runner.

  Subclassing `RunnerTest` provides `assert_same_list`; it also means the
  inherited test methods are collected again under this class.
  """

  def test_debug_runner(self):
    """Runs the test experiment with the debug runner and checks callbacks."""
    recorder = TestPlugin()
    experiment = eval_test_helper.test_experiment()
    output_dir = os.path.join(tempfile.gettempdir(), 'test_debug_runner')
    run = experiment.run(output_dir, runner='debug', plugins=[recorder])

    self.assertIsNotNone(recorder.start_time)
    self.assertIsNotNone(recorder.complete_time)
    self.assertGreater(recorder.complete_time, recorder.start_time)

    self.assertEqual(
        len(recorder.started_experiments), len(experiment.nodes)
    )
    self.assertEqual(
        len(recorder.completed_experiments), len(experiment.nodes)
    )
    # Six example callbacks in total are expected for the debug run.
    self.assertEqual(len(recorder.started_example_ids), 6 * 1)
    self.assertEqual(len(recorder.completed_example_ids), 6 * 1)
    self.assert_same_list(recorder.skipped_experiments, [])
    self.assert_same_list(recorder.skipped_example_ids, [])

    # Unlike the sequential and parallel tests, `run.json` must NOT exist.
    run_json = os.path.join(run.output_root, 'run.json')
    self.assertFalse(pg.io.path_exists(run_json))

    for node in experiment.nodes:
      progress = node.progress
      self.assertTrue(progress.is_started)
      self.assertTrue(progress.is_completed)
      if node.is_leaf:
        self.assertEqual(progress.num_skipped, 0)
        self.assertEqual(progress.num_completed, 1)
        self.assertEqual(progress.num_failed, 0)
      else:
        self.assertEqual(progress.num_skipped, 0)
        self.assertEqual(progress.num_failed, 0)
        self.assertEqual(progress.num_processed, progress.num_total)
|
331
|
+
|
332
|
+
|
333
|
+
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
  unittest.main()
|
langfun/core/langfunc.py
CHANGED
@@ -14,7 +14,7 @@
|
|
14
14
|
"""LangFunc: Language-based functions."""
|
15
15
|
|
16
16
|
import dataclasses
|
17
|
-
from typing import Annotated, Type
|
17
|
+
from typing import Annotated, Type
|
18
18
|
|
19
19
|
from langfun.core import component
|
20
20
|
from langfun.core import language_model
|
@@ -261,7 +261,6 @@ class LangFunc(
|
|
261
261
|
if lm_input is None:
|
262
262
|
lm_input = self.render(**kwargs)
|
263
263
|
|
264
|
-
lm_input.tag(message_lib.Message.TAG_LM_INPUT)
|
265
264
|
if skip_lm:
|
266
265
|
return lm_input
|
267
266
|
|
@@ -270,9 +269,8 @@ class LangFunc(
|
|
270
269
|
# Send rendered text to LM.
|
271
270
|
lm_output = self.lm(lm_input, cache_seed=cache_seed)
|
272
271
|
|
273
|
-
#
|
274
|
-
|
275
|
-
lm_output.tag(message_lib.Message.TAG_LM_RESPONSE)
|
272
|
+
# Attach cache seed.
|
273
|
+
lm_input.metadata.cache_seed = cache_seed
|
276
274
|
|
277
275
|
# Transform the output message.
|
278
276
|
lm_output = self.transform_output(lm_output)
|
@@ -333,22 +331,6 @@ class LangFunc(
|
|
333
331
|
"""Transforms the output message before returning from __call__."""
|
334
332
|
return lm_output
|
335
333
|
|
336
|
-
@classmethod
|
337
|
-
def from_value(
|
338
|
-
cls, value: Union[str, template_lib.Template], **kwargs
|
339
|
-
) -> 'LangFunc':
|
340
|
-
"""Create a LangFunc object from a string or template."""
|
341
|
-
if isinstance(value, LangFunc):
|
342
|
-
return value
|
343
|
-
if isinstance(value, template_lib.Template):
|
344
|
-
lfun = LangFunc(value.template_str, **kwargs)
|
345
|
-
# So lfun could acccess all attributes from value.
|
346
|
-
lfun.sym_setparent(value)
|
347
|
-
return lfun
|
348
|
-
if isinstance(value, str):
|
349
|
-
return LangFunc(template_str=value, **kwargs)
|
350
|
-
return LangFunc('{{input}}', input=value, **kwargs)
|
351
|
-
|
352
334
|
|
353
335
|
# Register converter from str to LangFunc, therefore we can always
|
354
336
|
# pass strs to attributes that accept LangFunc.
|
langfun/core/langfunc_test.py
CHANGED
@@ -57,6 +57,10 @@ class BasicTest(unittest.TestCase):
|
|
57
57
|
l2 = LangFunc.from_value(l1)
|
58
58
|
self.assertIs(l2, l1)
|
59
59
|
|
60
|
+
l3 = LangFunc.from_value(l1, x=1)
|
61
|
+
self.assertIsNot(l3, l1)
|
62
|
+
self.assertTrue(pg.eq(l3, LangFunc('Hello', x=1)))
|
63
|
+
|
60
64
|
c = template_lib.Template(
|
61
65
|
'{{x}} + {{l}}',
|
62
66
|
x=1,
|
@@ -82,21 +86,30 @@ class LangFuncCallTest(unittest.TestCase):
|
|
82
86
|
self.assertEqual(i.tags, ['rendered'])
|
83
87
|
|
84
88
|
r = l()
|
85
|
-
self.assertEqual(
|
89
|
+
self.assertEqual(
|
90
|
+
r,
|
91
|
+
message.AIMessage(
|
92
|
+
'Hello!!!', score=0.0, logprobs=None, is_cached=False,
|
93
|
+
usage=language_model.UsageNotAvailable()
|
94
|
+
)
|
95
|
+
)
|
86
96
|
self.assertEqual(r.tags, ['lm-response', 'lm-output'])
|
87
|
-
self.assertEqual(
|
97
|
+
self.assertEqual(
|
98
|
+
r.source,
|
99
|
+
message.UserMessage('Hello', metadata=dict(cache_seed=0))
|
100
|
+
)
|
88
101
|
self.assertEqual(r.source.tags, ['rendered', 'lm-input'])
|
89
102
|
|
90
103
|
self.assertEqual(str(l), 'Hello')
|
91
|
-
print(repr(l))
|
92
104
|
self.assertEqual(
|
93
105
|
repr(l),
|
94
106
|
"LangFunc(template_str='Hello', clean=True,"
|
95
|
-
' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=
|
96
|
-
' max_tokens=
|
107
|
+
' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=None,'
|
108
|
+
' max_tokens=None, n=1, top_k=40, top_p=None, stop=None,'
|
97
109
|
' random_seed=None, logprobs=False, top_logprobs=None), cache=None,'
|
98
110
|
' max_concurrency=None, timeout=120.0, max_attempts=5,'
|
99
|
-
' retry_interval=(5, 60), exponential_backoff=True,
|
111
|
+
' retry_interval=(5, 60), exponential_backoff=True,'
|
112
|
+
' max_retry_interval=300, debug=False))',
|
100
113
|
)
|
101
114
|
|
102
115
|
l = LangFunc('Hello')
|
@@ -104,11 +117,16 @@ class LangFuncCallTest(unittest.TestCase):
|
|
104
117
|
self.assertEqual(l, 'Hello')
|
105
118
|
self.assertEqual(l.natural_language_format(), 'Hello')
|
106
119
|
self.assertEqual(l.render(), 'Hello')
|
107
|
-
r = l()
|
120
|
+
r = l(cache_seed=1)
|
108
121
|
self.assertEqual(
|
109
|
-
r,
|
122
|
+
r,
|
123
|
+
message.AIMessage(
|
124
|
+
'Hello!!!', score=0.0, logprobs=None, is_cached=False,
|
125
|
+
usage=language_model.UsageNotAvailable()
|
126
|
+
)
|
110
127
|
)
|
111
128
|
self.assertEqual(r.tags, ['lm-response', 'lm-output'])
|
129
|
+
self.assertEqual(r.source.metadata.cache_seed, 1)
|
112
130
|
|
113
131
|
self.assertEqual(str(l), 'Hello')
|
114
132
|
|