langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145)
  1. langfun/__init__.py +22 -2
  2. langfun/core/__init__.py +17 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -28
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +69 -2
  18. langfun/core/component_test.py +54 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +17 -0
  24. langfun/core/eval/base.py +767 -140
  25. langfun/core/eval/base_test.py +238 -53
  26. langfun/core/eval/matching.py +80 -76
  27. langfun/core/eval/matching_test.py +19 -9
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +37 -28
  31. langfun/core/eval/scoring_test.py +21 -3
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +3 -21
  55. langfun/core/langfunc_test.py +26 -8
  56. langfun/core/language_model.py +686 -48
  57. langfun/core/language_model_test.py +681 -44
  58. langfun/core/llms/__init__.py +100 -12
  59. langfun/core/llms/anthropic.py +488 -0
  60. langfun/core/llms/anthropic_test.py +235 -0
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +88 -28
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +39 -26
  69. langfun/core/llms/fake_test.py +136 -11
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -197
  74. langfun/core/llms/groq.py +276 -0
  75. langfun/core/llms/groq_test.py +64 -0
  76. langfun/core/llms/llama_cpp.py +15 -40
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +436 -226
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +35 -174
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -23
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +15 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +9 -8
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +278 -0
  112. langfun/core/structured/function_generation_test.py +399 -0
  113. langfun/core/structured/mapping.py +150 -46
  114. langfun/core/structured/mapping_test.py +105 -0
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +71 -22
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
  119. langfun/core/structured/schema.py +208 -99
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_generation_test.py +2 -2
  122. langfun/core/structured/schema_test.py +133 -34
  123. langfun/core/structured/scoring.py +125 -19
  124. langfun/core/structured/scoring_test.py +30 -0
  125. langfun/core/structured/tokenization.py +64 -0
  126. langfun/core/structured/tokenization_test.py +48 -0
  127. langfun/core/template.py +240 -11
  128. langfun/core/template_test.py +146 -1
  129. langfun/core/templates/conversation.py +9 -0
  130. langfun/core/templates/conversation_test.py +4 -3
  131. langfun/core/templates/selfplay_test.py +14 -2
  132. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  133. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  134. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  135. langfun/core/coding/python/errors.py +0 -108
  136. langfun/core/coding/python/errors_test.py +0 -99
  137. langfun/core/coding/python/permissions.py +0 -90
  138. langfun/core/coding/python/permissions_test.py +0 -86
  139. langfun/core/structured/prompting.py +0 -217
  140. langfun/core/text_formatting.py +0 -162
  141. langfun/core/text_formatting_test.py +0 -47
  142. langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
  143. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  144. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  145. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
langfun/core/eval/v2/runners_test.py ADDED
@@ -0,0 +1,334 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import tempfile
16
+ import threading
17
+ import time
18
+ from typing import Any
19
+ import unittest
20
+
21
+ from langfun.core.eval.v2 import eval_test_helper
22
+ from langfun.core.eval.v2 import example as example_lib
23
+ from langfun.core.eval.v2 import experiment as experiment_lib
24
+ from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import
25
+
26
+ import pyglove as pg
27
+
28
+
29
# Short aliases for the eval v2 classes exercised by the tests below.
Example = example_lib.Example
Runner, Experiment, Suite, Plugin = (
    experiment_lib.Runner,
    experiment_lib.Experiment,
    experiment_lib.Suite,
    experiment_lib.Plugin,
)
34
+
35
+
36
class TestPlugin(Plugin):
  """Plugin that records the runner callbacks it receives, for assertions.

  Experiment-level hooks record `pg.Ref`s to the experiment objects, so tests
  can compare them by identity against the experiment tree. Example-level hooks
  record example ids. Appends are serialized by a lock because the hooks may be
  invoked concurrently (e.g. by the 'parallel' runner).
  """

  started_experiments: list[Experiment] = []
  completed_experiments: list[Experiment] = []
  skipped_experiments: list[Experiment] = []
  started_example_ids: list[int] = []
  completed_example_ids: list[int] = []
  skipped_example_ids: list[int] = []
  # `time.time()` values captured by `on_run_start` / `on_run_complete`.
  start_time: float | None = None
  complete_time: float | None = None

  def _on_bound(self):
    super()._on_bound()
    self._lock = threading.Lock()

  def _append_locked(self, records: list[Any], value: Any) -> None:
    """Appends `value` to `records` under the lock, without notifications."""
    with pg.notify_on_change(False), self._lock:
      records.append(value)

  def on_run_start(self, runner: Runner, root: Experiment):
    del runner, root
    # Symbolic fields are read-only by default; allow this one-off write.
    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
      self.start_time = time.time()

  def on_run_complete(self, runner: Runner, root: Experiment):
    del runner, root
    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
      self.complete_time = time.time()

  def on_experiment_start(self, runner: Runner, experiment: Experiment):
    del runner
    self._append_locked(self.started_experiments, pg.Ref(experiment))

  def on_experiment_skipped(self, runner: Runner, experiment: Experiment):
    del runner
    self._append_locked(self.skipped_experiments, pg.Ref(experiment))

  def on_experiment_complete(self, runner: Runner, experiment: Experiment):
    del runner
    self._append_locked(self.completed_experiments, pg.Ref(experiment))

  def on_example_start(
      self, runner: Runner, experiment: Experiment, example: Example):
    del runner, experiment
    self._append_locked(self.started_example_ids, example.id)

  def on_example_skipped(
      self, runner: Runner, experiment: Experiment, example: Example):
    del runner, experiment
    self._append_locked(self.skipped_example_ids, example.id)

  def on_example_complete(
      self, runner: Runner, experiment: Experiment, example: Example):
    del runner, experiment
    self._append_locked(self.completed_example_ids, example.id)
92
+
93
+
94
class RunnerTest(unittest.TestCase):
  """Tests for the sequential runner and for runner options."""

  def assert_same_list(self, actual: list[Any], expected: list[Any]):
    """Asserts both lists hold the very same objects (by identity), in order.

    On the first mismatch the test fails with the offending index and a
    `pg.diff` of the two items in the failure message, rather than printing
    the diff to stdout.
    """
    self.assertEqual(len(actual), len(expected))
    for i, (x, y) in enumerate(zip(actual, expected)):
      if x is not y:
        self.fail(f'Items at index {i} are not the same object: {pg.diff(x, y)}')

  def test_basic(self):
    plugin = TestPlugin()
    exp = eval_test_helper.test_experiment()
    root_dir = os.path.join(tempfile.gettempdir(), 'test_sequential_runner')
    run = exp.run(root_dir, runner='sequential', plugins=[plugin])

    self.assertIsNotNone(plugin.start_time)
    self.assertIsNotNone(plugin.complete_time)
    self.assertGreater(plugin.complete_time, plugin.start_time)

    # Non-leaf nodes start before leaves; leaves complete before non-leaf
    # nodes, which complete in reverse order.
    self.assert_same_list(
        plugin.started_experiments,
        exp.nonleaf_nodes + exp.leaf_nodes
    )
    self.assert_same_list(
        plugin.completed_experiments,
        exp.leaf_nodes + list(reversed(exp.nonleaf_nodes))
    )
    self.assert_same_list(
        plugin.started_example_ids, list(range(1, 11)) * 6
    )
    self.assert_same_list(
        plugin.completed_example_ids, list(range(1, 11)) * 6
    )
    self.assert_same_list(plugin.skipped_experiments, [])
    self.assert_same_list(plugin.skipped_example_ids, [])
    self.assertTrue(
        pg.io.path_exists(os.path.join(run.output_root, 'run.json'))
    )

    for node in exp.nodes:
      self.assertTrue(node.progress.is_started)
      self.assertTrue(node.progress.is_completed)
      if node.is_leaf:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_completed, 10)
        self.assertEqual(node.progress.num_failed, 1)
      else:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_failed, 0)
        self.assertEqual(node.progress.num_processed, node.progress.num_total)

  def test_raise_if_has_error(self):
    root_dir = os.path.join(tempfile.gettempdir(), 'test_raise_if_has_error')
    exp = eval_test_helper.TestEvaluation()
    with self.assertRaisesRegex(ValueError, 'x should not be 5'):
      exp.run(
          root_dir, runner='sequential', plugins=[], raise_if_has_error=True
      )

    with self.assertRaisesRegex(ValueError, 'x should not be 5'):
      exp.run(root_dir, runner='parallel', plugins=[], raise_if_has_error=True)

  def test_example_ids(self):
    root_dir = os.path.join(tempfile.gettempdir(), 'test_example_ids')
    exp = eval_test_helper.test_experiment()
    plugin = TestPlugin()
    _ = exp.run(
        root_dir, runner='sequential', plugins=[plugin], example_ids=[5, 7, 9]
    )
    self.assertEqual(plugin.started_example_ids, [5, 7, 9] * 6)
    self.assertEqual(plugin.completed_example_ids, [5, 7, 9] * 6)

  def test_filter(self):
    plugin = TestPlugin()
    exp = eval_test_helper.test_experiment()
    root_dir = os.path.join(tempfile.gettempdir(), 'test_filter')

    _ = exp.run(
        root_dir, runner='sequential', plugins=[plugin],
        filter=lambda e: e.lm.offset != 0
    )
    self.assert_same_list(
        plugin.started_experiments,
        exp.nonleaf_nodes + exp.leaf_nodes[2:]
    )
    self.assert_same_list(
        plugin.skipped_experiments, exp.leaf_nodes[:2]
    )
    self.assert_same_list(
        plugin.completed_experiments,
        exp.leaf_nodes[2:] + [exp.children[1], exp]
    )

  def test_use_cache(self):
    @pg.functor()
    def test_inputs(num_examples: int = 10):
      # `i // 2` makes every input appear twice, so the second occurrence
      # can be served from the cache.
      return [
          pg.Dict(
              x=i // 2, y=(i // 2) ** 2,
              groundtruth=(i // 2 + (i // 2) ** 2)
          ) for i in range(num_examples)
      ]

    exp = eval_test_helper.TestEvaluation(
        inputs=test_inputs(num_examples=pg.oneof([2, 4]))
    )
    # Global cache.
    root_dir = os.path.join(tempfile.gettempdir(), 'global_cache')
    run = exp.run(
        root_dir, 'new', runner='sequential', use_cache='global', plugins=[]
    )
    self.assertTrue(pg.io.path_exists(run.output_path_for(exp, 'cache.json')))
    self.assertEqual(exp.usage_summary.cached.total.num_requests, 4)
    self.assertEqual(exp.usage_summary.uncached.total.num_requests, 2)

    # Per-dataset cache.
    root_dir = os.path.join(tempfile.gettempdir(), 'per_dataset')
    run = exp.run(
        root_dir, 'new', runner='sequential',
        use_cache='per_dataset', plugins=[]
    )
    for leaf in exp.leaf_nodes:
      self.assertTrue(
          pg.io.path_exists(run.output_path_for(leaf, 'cache.json'))
      )
    self.assertEqual(exp.usage_summary.cached.total.num_requests, 3)
    self.assertEqual(exp.usage_summary.uncached.total.num_requests, 3)

    # No cache.
    root_dir = os.path.join(tempfile.gettempdir(), 'no')
    run = exp.run(root_dir, runner='sequential', use_cache='no', plugins=[])
    self.assertFalse(pg.io.path_exists(run.output_path_for(exp, 'cache.json')))
    for leaf in exp.leaf_nodes:
      self.assertFalse(
          pg.io.path_exists(run.output_path_for(leaf, 'cache.json'))
      )
    self.assertEqual(exp.usage_summary.cached.total.num_requests, 0)
    self.assertEqual(exp.usage_summary.uncached.total.num_requests, 6)
232
+
233
+
234
class ParallelRunnerTest(unittest.TestCase):
  """Tests for the 'parallel' runner.

  Derives from `unittest.TestCase` rather than `RunnerTest`. Inheriting from
  `RunnerTest` would make unittest re-run all of its tests (which hard-code
  `runner='sequential'`) a second time under this class.
  """

  def test_parallel_runner(self):
    plugin = TestPlugin()
    exp = eval_test_helper.test_experiment()
    root_dir = os.path.join(tempfile.gettempdir(), 'test_parallel_runner')
    run = exp.run(root_dir, runner='parallel', plugins=[plugin])

    self.assertIsNotNone(plugin.start_time)
    self.assertIsNotNone(plugin.complete_time)
    self.assertGreater(plugin.complete_time, plugin.start_time)

    # Unlike the sequential test, only counts are checked here — presumably
    # because callback order is not deterministic under the parallel runner.
    self.assertEqual(len(plugin.started_experiments), len(exp.nodes))
    self.assertEqual(len(plugin.completed_experiments), len(exp.nodes))
    self.assertEqual(len(plugin.started_example_ids), 6 * 10)
    self.assertEqual(len(plugin.completed_example_ids), 6 * 10)
    self.assertEqual(len(plugin.skipped_experiments), 0)
    self.assertEqual(len(plugin.skipped_example_ids), 0)
    self.assertTrue(
        pg.io.path_exists(os.path.join(run.output_root, 'run.json'))
    )

    for node in exp.nodes:
      self.assertTrue(node.progress.is_started)
      self.assertTrue(node.progress.is_completed)
      if node.is_leaf:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_completed, 10)
        self.assertEqual(node.progress.num_failed, 1)
      else:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_failed, 0)
        self.assertEqual(node.progress.num_processed, node.progress.num_total)

  def test_concurrent_startup_delay(self):
    plugin = TestPlugin()
    exp = eval_test_helper.test_experiment()
    root_dir = os.path.join(
        tempfile.gettempdir(), 'test_concurrent_startup_delay'
    )
    _ = exp.run(
        root_dir,
        runner='parallel',
        plugins=[plugin],
        concurrent_startup_delay=(0, 5),
    )
    # A startup delay must not prevent any experiment or example from running.
    self.assertEqual(len(plugin.completed_experiments), len(exp.nodes))
    self.assertEqual(len(plugin.completed_example_ids), 6 * 10)
288
+
289
+
290
class DebugRunnerTest(unittest.TestCase):
  """Tests for the 'debug' runner.

  Derives from `unittest.TestCase` rather than `RunnerTest`. Inheriting from
  `RunnerTest` would make unittest re-run all of its sequential-runner tests
  again under this class.
  """

  def test_debug_runner(self):
    plugin = TestPlugin()
    exp = eval_test_helper.test_experiment()
    root_dir = os.path.join(tempfile.gettempdir(), 'test_debug_runner')
    run = exp.run(root_dir, runner='debug', plugins=[plugin])

    self.assertIsNotNone(plugin.start_time)
    self.assertIsNotNone(plugin.complete_time)
    self.assertGreater(plugin.complete_time, plugin.start_time)

    self.assertEqual(len(plugin.started_experiments), len(exp.nodes))
    self.assertEqual(len(plugin.completed_experiments), len(exp.nodes))
    # The debug runner processes a single example per leaf experiment.
    self.assertEqual(len(plugin.started_example_ids), 6 * 1)
    self.assertEqual(len(plugin.completed_example_ids), 6 * 1)
    self.assertEqual(len(plugin.skipped_experiments), 0)
    self.assertEqual(len(plugin.skipped_example_ids), 0)
    # Unlike the other runners, no run.json is written in debug mode.
    self.assertFalse(
        pg.io.path_exists(os.path.join(run.output_root, 'run.json'))
    )

    for node in exp.nodes:
      self.assertTrue(node.progress.is_started)
      self.assertTrue(node.progress.is_completed)
      if node.is_leaf:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_completed, 1)
        self.assertEqual(node.progress.num_failed, 0)
      else:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_failed, 0)
        self.assertEqual(node.progress.num_processed, node.progress.num_total)
331
+
332
+
333
if __name__ == '__main__':
  # Discover and run every TestCase defined in this module.
  unittest.main(module='__main__')
langfun/core/langfunc.py CHANGED
@@ -14,7 +14,7 @@
14
14
  """LangFunc: Language-based functions."""
15
15
 
16
16
  import dataclasses
17
- from typing import Annotated, Type, Union
17
+ from typing import Annotated, Type
18
18
 
19
19
  from langfun.core import component
20
20
  from langfun.core import language_model
@@ -261,7 +261,6 @@ class LangFunc(
261
261
  if lm_input is None:
262
262
  lm_input = self.render(**kwargs)
263
263
 
264
- lm_input.tag(message_lib.Message.TAG_LM_INPUT)
265
264
  if skip_lm:
266
265
  return lm_input
267
266
 
@@ -270,9 +269,8 @@ class LangFunc(
270
269
  # Send rendered text to LM.
271
270
  lm_output = self.lm(lm_input, cache_seed=cache_seed)
272
271
 
273
- # Track the input as the source of the output.
274
- lm_output.source = lm_input
275
- lm_output.tag(message_lib.Message.TAG_LM_RESPONSE)
272
+ # Attach cache seed.
273
+ lm_input.metadata.cache_seed = cache_seed
276
274
 
277
275
  # Transform the output message.
278
276
  lm_output = self.transform_output(lm_output)
@@ -333,22 +331,6 @@ class LangFunc(
333
331
  """Transforms the output message before returning from __call__."""
334
332
  return lm_output
335
333
 
336
- @classmethod
337
- def from_value(
338
- cls, value: Union[str, template_lib.Template], **kwargs
339
- ) -> 'LangFunc':
340
- """Create a LangFunc object from a string or template."""
341
- if isinstance(value, LangFunc):
342
- return value
343
- if isinstance(value, template_lib.Template):
344
- lfun = LangFunc(value.template_str, **kwargs)
345
- # So lfun could acccess all attributes from value.
346
- lfun.sym_setparent(value)
347
- return lfun
348
- if isinstance(value, str):
349
- return LangFunc(template_str=value, **kwargs)
350
- return LangFunc('{{input}}', input=value, **kwargs)
351
-
352
334
 
353
335
  # Register converter from str to LangFunc, therefore we can always
354
336
  # pass strs to attributes that accept LangFunc.
langfun/core/langfunc_test.py CHANGED
@@ -57,6 +57,10 @@ class BasicTest(unittest.TestCase):
57
57
  l2 = LangFunc.from_value(l1)
58
58
  self.assertIs(l2, l1)
59
59
 
60
+ l3 = LangFunc.from_value(l1, x=1)
61
+ self.assertIsNot(l3, l1)
62
+ self.assertTrue(pg.eq(l3, LangFunc('Hello', x=1)))
63
+
60
64
  c = template_lib.Template(
61
65
  '{{x}} + {{l}}',
62
66
  x=1,
@@ -82,21 +86,30 @@ class LangFuncCallTest(unittest.TestCase):
82
86
  self.assertEqual(i.tags, ['rendered'])
83
87
 
84
88
  r = l()
85
- self.assertEqual(r, message.AIMessage('Hello!!!', score=0.0, logprobs=None))
89
+ self.assertEqual(
90
+ r,
91
+ message.AIMessage(
92
+ 'Hello!!!', score=0.0, logprobs=None, is_cached=False,
93
+ usage=language_model.UsageNotAvailable()
94
+ )
95
+ )
86
96
  self.assertEqual(r.tags, ['lm-response', 'lm-output'])
87
- self.assertEqual(r.source, message.UserMessage('Hello'))
97
+ self.assertEqual(
98
+ r.source,
99
+ message.UserMessage('Hello', metadata=dict(cache_seed=0))
100
+ )
88
101
  self.assertEqual(r.source.tags, ['rendered', 'lm-input'])
89
102
 
90
103
  self.assertEqual(str(l), 'Hello')
91
- print(repr(l))
92
104
  self.assertEqual(
93
105
  repr(l),
94
106
  "LangFunc(template_str='Hello', clean=True,"
95
- ' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=0.0,'
96
- ' max_tokens=1024, n=1, top_k=40, top_p=None, stop=None,'
107
+ ' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=None,'
108
+ ' max_tokens=None, n=1, top_k=40, top_p=None, stop=None,'
97
109
  ' random_seed=None, logprobs=False, top_logprobs=None), cache=None,'
98
110
  ' max_concurrency=None, timeout=120.0, max_attempts=5,'
99
- ' retry_interval=(5, 60), exponential_backoff=True, debug=False))',
111
+ ' retry_interval=(5, 60), exponential_backoff=True,'
112
+ ' max_retry_interval=300, debug=False))',
100
113
  )
101
114
 
102
115
  l = LangFunc('Hello')
@@ -104,11 +117,16 @@ class LangFuncCallTest(unittest.TestCase):
104
117
  self.assertEqual(l, 'Hello')
105
118
  self.assertEqual(l.natural_language_format(), 'Hello')
106
119
  self.assertEqual(l.render(), 'Hello')
107
- r = l()
120
+ r = l(cache_seed=1)
108
121
  self.assertEqual(
109
- r, message.AIMessage('Hello!!!', score=0.0, logprobs=None)
122
+ r,
123
+ message.AIMessage(
124
+ 'Hello!!!', score=0.0, logprobs=None, is_cached=False,
125
+ usage=language_model.UsageNotAvailable()
126
+ )
110
127
  )
111
128
  self.assertEqual(r.tags, ['lm-response', 'lm-output'])
129
+ self.assertEqual(r.source.metadata.cache_seed, 1)
112
130
 
113
131
  self.assertEqual(str(l), 'Hello')
114
132