langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. langfun/__init__.py +20 -2
  2. langfun/core/__init__.py +16 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -21
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +63 -2
  18. langfun/core/component_test.py +53 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +16 -1
  24. langfun/core/eval/base.py +622 -174
  25. langfun/core/eval/base_test.py +200 -54
  26. langfun/core/eval/matching.py +63 -76
  27. langfun/core/eval/matching_test.py +17 -8
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +26 -26
  31. langfun/core/eval/scoring_test.py +19 -2
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +4 -17
  55. langfun/core/langfunc_test.py +22 -6
  56. langfun/core/language_model.py +577 -39
  57. langfun/core/language_model_test.py +470 -56
  58. langfun/core/llms/__init__.py +87 -16
  59. langfun/core/llms/anthropic.py +312 -87
  60. langfun/core/llms/anthropic_test.py +71 -3
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +53 -2
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +11 -7
  69. langfun/core/llms/fake_test.py +14 -0
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -202
  74. langfun/core/llms/groq.py +160 -144
  75. langfun/core/llms/groq_test.py +31 -137
  76. langfun/core/llms/llama_cpp.py +15 -42
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +395 -203
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +30 -395
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -26
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +12 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +7 -6
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +60 -27
  112. langfun/core/structured/function_generation_test.py +72 -2
  113. langfun/core/structured/mapping.py +97 -47
  114. langfun/core/structured/mapping_test.py +90 -2
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +53 -9
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
  119. langfun/core/structured/schema.py +204 -97
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_test.py +130 -29
  122. langfun/core/structured/scoring.py +125 -19
  123. langfun/core/structured/scoring_test.py +30 -0
  124. langfun/core/structured/tokenization.py +64 -0
  125. langfun/core/structured/tokenization_test.py +48 -0
  126. langfun/core/template.py +115 -1
  127. langfun/core/template_test.py +71 -1
  128. langfun/core/templates/conversation.py +9 -0
  129. langfun/core/templates/conversation_test.py +4 -3
  130. langfun/core/templates/selfplay_test.py +10 -2
  131. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  132. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  133. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  134. langfun/core/coding/python/errors.py +0 -108
  135. langfun/core/coding/python/errors_test.py +0 -99
  136. langfun/core/coding/python/permissions.py +0 -90
  137. langfun/core/coding/python/permissions_test.py +0 -86
  138. langfun/core/structured/prompting.py +0 -238
  139. langfun/core/text_formatting.py +0 -162
  140. langfun/core/text_formatting_test.py +0 -47
  141. langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
  142. langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
  143. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  144. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,334 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import tempfile
16
+ import threading
17
+ import time
18
+ from typing import Any
19
+ import unittest
20
+
21
+ from langfun.core.eval.v2 import eval_test_helper
22
+ from langfun.core.eval.v2 import example as example_lib
23
+ from langfun.core.eval.v2 import experiment as experiment_lib
24
+ from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import
25
+
26
+ import pyglove as pg
27
+
28
+
29
# Short aliases for the experiment/runner API exercised by the tests below.
(Runner, Example, Experiment, Suite, Plugin) = (
    experiment_lib.Runner,
    example_lib.Example,
    experiment_lib.Experiment,
    experiment_lib.Suite,
    experiment_lib.Plugin,
)
34
+
35
+
36
class TestPlugin(Plugin):
  """Plugin that records the runner callbacks it receives.

  The recorded experiments (stored as `pg.Ref`s), example IDs and run
  timestamps are inspected by the tests in this module to check which
  callbacks each runner issues, and in what order.
  """

  # Experiments passed to `on_experiment_start`, in call order.
  started_experiments: list[Experiment] = []
  # Experiments passed to `on_experiment_complete`, in call order.
  completed_experiments: list[Experiment] = []
  # Experiments passed to `on_experiment_skipped`, in call order.
  skipped_experiments: list[Experiment] = []
  # IDs of examples passed to `on_example_start`, in call order.
  started_example_ids: list[int] = []
  # IDs of examples passed to `on_example_complete`, in call order.
  completed_example_ids: list[int] = []
  # IDs of examples passed to `on_example_skipped`, in call order.
  skipped_example_ids: list[int] = []
  # `time.time()` captured in `on_run_start`.
  start_time: float | None = None
  # `time.time()` captured in `on_run_complete`.
  complete_time: float | None = None

  def _on_bound(self):
    super()._on_bound()
    # Serializes appends to the recording lists; presumably callbacks can
    # arrive concurrently when the parallel runner is used.
    self._lock = threading.Lock()

  def _stamp_time(self, field_name: str) -> None:
    """Sets the float field `field_name` to the current `time.time()`."""
    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
      setattr(self, field_name, time.time())

  def _record(self, field_name: str, value: Any, *, as_ref: bool = False):
    """Appends `value` (optionally wrapped in `pg.Ref`) to a list field.

    The append happens under `self._lock` with change notification disabled.
    """
    with pg.notify_on_change(False), self._lock:
      getattr(self, field_name).append(pg.Ref(value) if as_ref else value)

  def on_run_start(self, runner: Runner, root: Experiment):
    del runner, root
    self._stamp_time('start_time')

  def on_run_complete(self, runner: Runner, root: Experiment):
    del runner, root
    self._stamp_time('complete_time')

  def on_experiment_start(self, runner: Runner, experiment: Experiment):
    del runner
    self._record('started_experiments', experiment, as_ref=True)

  def on_experiment_skipped(self, runner: Runner, experiment: Experiment):
    del runner
    self._record('skipped_experiments', experiment, as_ref=True)

  def on_experiment_complete(self, runner: Runner, experiment: Experiment):
    del runner
    self._record('completed_experiments', experiment, as_ref=True)

  def on_example_start(
      self, runner: Runner, experiment: Experiment, example: Example):
    del runner, experiment
    self._record('started_example_ids', example.id)

  def on_example_skipped(
      self, runner: Runner, experiment: Experiment, example: Example):
    del runner, experiment
    self._record('skipped_example_ids', example.id)

  def on_example_complete(
      self, runner: Runner, experiment: Experiment, example: Example):
    del runner, experiment
    self._record('completed_example_ids', example.id)
92
+
93
+
94
class RunnerTest(unittest.TestCase):
  """Tests for the sequential runner and for common `Experiment.run` options."""

  def assert_same_list(self, actual: list[Any], expected: list[Any]):
    """Asserts two lists have equal length and identical items per position.

    Items are compared with `is`, so this is meant for object-identity checks
    such as experiment nodes. On a mismatch, the failure message includes the
    offending index and a `pg.diff` of the pair.

    Args:
      actual: The observed list.
      expected: The expected list.
    """
    self.assertEqual(len(actual), len(expected))
    for i, (x, y) in enumerate(zip(actual, expected)):
      if x is not y:
        self.fail(f'Item {i} is not the same object: {pg.diff(x, y)}')

  def test_basic(self):
    plugin = TestPlugin()
    exp = eval_test_helper.test_experiment()
    root_dir = os.path.join(tempfile.gettempdir(), 'test_sequential_runner')
    run = exp.run(root_dir, runner='sequential', plugins=[plugin])

    self.assertIsNotNone(plugin.start_time)
    self.assertIsNotNone(plugin.complete_time)
    self.assertGreater(plugin.complete_time, plugin.start_time)

    # Experiments are compared by identity: non-leaf nodes start before all
    # leaves and complete after them, in reverse order.
    self.assert_same_list(
        plugin.started_experiments,
        exp.nonleaf_nodes + exp.leaf_nodes
    )
    self.assert_same_list(
        plugin.completed_experiments,
        exp.leaf_nodes + list(reversed(exp.nonleaf_nodes))
    )
    # Example IDs are ints, so compare them by value. An identity check would
    # only pass by relying on CPython's small-int caching.
    self.assertEqual(plugin.started_example_ids, list(range(1, 11)) * 6)
    self.assertEqual(plugin.completed_example_ids, list(range(1, 11)) * 6)
    self.assertEqual(len(plugin.skipped_experiments), 0)
    self.assertEqual(len(plugin.skipped_example_ids), 0)
    self.assertTrue(
        pg.io.path_exists(os.path.join(run.output_root, 'run.json'))
    )

    for node in exp.nodes:
      self.assertTrue(node.progress.is_started)
      self.assertTrue(node.progress.is_completed)
      if node.is_leaf:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_completed, 10)
        # Presumably the example with x == 5 fails in each leaf (see
        # `test_raise_if_has_error`) -- TODO confirm in eval_test_helper.
        self.assertEqual(node.progress.num_failed, 1)
      else:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_failed, 0)
        self.assertEqual(node.progress.num_processed, node.progress.num_total)

  def test_raise_if_has_error(self):
    root_dir = os.path.join(tempfile.gettempdir(), 'test_raise_if_has_error')
    exp = eval_test_helper.TestEvaluation()
    with self.assertRaisesRegex(ValueError, 'x should not be 5'):
      exp.run(
          root_dir, runner='sequential', plugins=[], raise_if_has_error=True
      )

    with self.assertRaisesRegex(ValueError, 'x should not be 5'):
      exp.run(root_dir, runner='parallel', plugins=[], raise_if_has_error=True)

  def test_example_ids(self):
    root_dir = os.path.join(tempfile.gettempdir(), 'test_example_ids')
    exp = eval_test_helper.test_experiment()
    plugin = TestPlugin()
    _ = exp.run(
        root_dir, runner='sequential', plugins=[plugin], example_ids=[5, 7, 9]
    )
    self.assertEqual(plugin.started_example_ids, [5, 7, 9] * 6)
    self.assertEqual(plugin.completed_example_ids, [5, 7, 9] * 6)

  def test_filter(self):
    plugin = TestPlugin()
    exp = eval_test_helper.test_experiment()
    root_dir = os.path.join(tempfile.gettempdir(), 'test_filter')

    _ = exp.run(
        root_dir, runner='sequential', plugins=[plugin],
        filter=lambda e: e.lm.offset != 0
    )
    # The first two leaves are filtered out and reported as skipped.
    self.assert_same_list(
        plugin.started_experiments,
        exp.nonleaf_nodes + exp.leaf_nodes[2:]
    )
    self.assert_same_list(
        plugin.skipped_experiments, exp.leaf_nodes[:2]
    )
    self.assert_same_list(
        plugin.completed_experiments,
        exp.leaf_nodes[2:] + [exp.children[1], exp]
    )

  def test_use_cache(self):
    @pg.functor()
    def test_inputs(num_examples: int = 10):
      # Consecutive inputs are duplicated in pairs (x = i // 2), so repeated
      # LM requests can be served from the cache.
      return [
          pg.Dict(
              x=i // 2, y=(i // 2) ** 2,
              groundtruth=(i // 2 + (i // 2) ** 2)
          ) for i in range(num_examples)
      ]

    exp = eval_test_helper.TestEvaluation(
        inputs=test_inputs(num_examples=pg.oneof([2, 4]))
    )
    # Global cache: one cache shared by both leaves (2 + 4 = 6 requests over
    # 2 distinct inputs).
    root_dir = os.path.join(tempfile.gettempdir(), 'global_cache')
    run = exp.run(
        root_dir, 'new', runner='sequential', use_cache='global', plugins=[]
    )
    self.assertTrue(pg.io.path_exists(run.output_path_for(exp, 'cache.json')))
    self.assertEqual(exp.usage_summary.cached.total.num_requests, 4)
    self.assertEqual(exp.usage_summary.uncached.total.num_requests, 2)

    # Per-dataset cache: each leaf has its own cache file.
    root_dir = os.path.join(tempfile.gettempdir(), 'per_dataset')
    run = exp.run(
        root_dir, 'new', runner='sequential',
        use_cache='per_dataset', plugins=[]
    )
    for leaf in exp.leaf_nodes:
      self.assertTrue(
          pg.io.path_exists(run.output_path_for(leaf, 'cache.json'))
      )
    self.assertEqual(exp.usage_summary.cached.total.num_requests, 3)
    self.assertEqual(exp.usage_summary.uncached.total.num_requests, 3)

    # No cache: no cache files, every request is uncached.
    root_dir = os.path.join(tempfile.gettempdir(), 'no')
    run = exp.run(root_dir, runner='sequential', use_cache='no', plugins=[])
    self.assertFalse(pg.io.path_exists(run.output_path_for(exp, 'cache.json')))
    for leaf in exp.leaf_nodes:
      self.assertFalse(
          pg.io.path_exists(run.output_path_for(leaf, 'cache.json'))
      )
    self.assertEqual(exp.usage_summary.cached.total.num_requests, 0)
    self.assertEqual(exp.usage_summary.uncached.total.num_requests, 6)
232
+
233
+
234
class ParallelRunnerTest(unittest.TestCase):
  """Tests for the parallel runner.

  This class derives from `unittest.TestCase` rather than `RunnerTest`.
  Otherwise every `RunnerTest` case, all of which pin their own runner, would
  be collected and re-run under this class.
  """

  def test_parallel_runner(self):
    plugin = TestPlugin()
    exp = eval_test_helper.test_experiment()
    root_dir = os.path.join(tempfile.gettempdir(), 'test_parallel_runner')
    run = exp.run(root_dir, runner='parallel', plugins=[plugin])

    self.assertIsNotNone(plugin.start_time)
    self.assertIsNotNone(plugin.complete_time)
    self.assertGreater(plugin.complete_time, plugin.start_time)

    # Callback order is nondeterministic under the parallel runner, so only
    # counts (and the multiset of example IDs) are checked.
    self.assertEqual(
        len(plugin.started_experiments), len(exp.nodes)
    )
    self.assertEqual(
        len(plugin.completed_experiments), len(exp.nodes)
    )
    self.assertCountEqual(
        plugin.started_example_ids, list(range(1, 11)) * 6
    )
    self.assertCountEqual(
        plugin.completed_example_ids, list(range(1, 11)) * 6
    )
    self.assertEqual(len(plugin.skipped_experiments), 0)
    self.assertEqual(len(plugin.skipped_example_ids), 0)
    self.assertTrue(
        pg.io.path_exists(os.path.join(run.output_root, 'run.json'))
    )

    for node in exp.nodes:
      self.assertTrue(node.progress.is_started)
      self.assertTrue(node.progress.is_completed)
      if node.is_leaf:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_completed, 10)
        self.assertEqual(node.progress.num_failed, 1)
      else:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_failed, 0)
        self.assertEqual(node.progress.num_processed, node.progress.num_total)

  def test_concurrent_startup_delay(self):
    plugin = TestPlugin()
    exp = eval_test_helper.test_experiment()
    root_dir = os.path.join(
        tempfile.gettempdir(), 'test_concurrent_startup_delay'
    )
    _ = exp.run(
        root_dir,
        runner='parallel',
        plugins=[plugin],
        concurrent_startup_delay=(0, 5),
    )
    # A startup delay must not drop work: every example is still completed.
    self.assertEqual(len(plugin.completed_example_ids), 6 * 10)
288
+
289
+
290
class DebugRunnerTest(unittest.TestCase):
  """Tests for the debug runner.

  This class derives from `unittest.TestCase` rather than `RunnerTest`.
  Otherwise every `RunnerTest` case, all of which pin their own runner, would
  be collected and re-run under this class.
  """

  def test_debug_runner(self):
    plugin = TestPlugin()
    exp = eval_test_helper.test_experiment()
    root_dir = os.path.join(tempfile.gettempdir(), 'test_debug_runner')
    run = exp.run(root_dir, runner='debug', plugins=[plugin])

    self.assertIsNotNone(plugin.start_time)
    self.assertIsNotNone(plugin.complete_time)
    self.assertGreater(plugin.complete_time, plugin.start_time)

    self.assertEqual(
        len(plugin.started_experiments), len(exp.nodes)
    )
    self.assertEqual(
        len(plugin.completed_experiments), len(exp.nodes)
    )
    # The debug runner processes a single example per leaf.
    self.assertEqual(
        len(plugin.started_example_ids), 6 * 1
    )
    self.assertEqual(
        len(plugin.completed_example_ids), 6 * 1
    )
    self.assertEqual(len(plugin.skipped_experiments), 0)
    self.assertEqual(len(plugin.skipped_example_ids), 0)
    # No `run.json` is written for debug runs.
    self.assertFalse(
        pg.io.path_exists(os.path.join(run.output_root, 'run.json'))
    )

    for node in exp.nodes:
      self.assertTrue(node.progress.is_started)
      self.assertTrue(node.progress.is_completed)
      if node.is_leaf:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_completed, 1)
        self.assertEqual(node.progress.num_failed, 0)
      else:
        self.assertEqual(node.progress.num_skipped, 0)
        self.assertEqual(node.progress.num_failed, 0)
        self.assertEqual(node.progress.num_processed, node.progress.num_total)
331
+
332
+
333
if __name__ == '__main__':
  # Run this module's test cases when it is executed directly as a script.
  unittest.main()
langfun/core/langfunc.py CHANGED
@@ -14,7 +14,7 @@
14
14
  """LangFunc: Language-based functions."""
15
15
 
16
16
  import dataclasses
17
- from typing import Annotated, Type, Union
17
+ from typing import Annotated, Type
18
18
 
19
19
  from langfun.core import component
20
20
  from langfun.core import language_model
@@ -269,6 +269,9 @@ class LangFunc(
269
269
  # Send rendered text to LM.
270
270
  lm_output = self.lm(lm_input, cache_seed=cache_seed)
271
271
 
272
+ # Attach cache seed.
273
+ lm_input.metadata.cache_seed = cache_seed
274
+
272
275
  # Transform the output message.
273
276
  lm_output = self.transform_output(lm_output)
274
277
  lm_output.tag(message_lib.Message.TAG_LM_OUTPUT)
@@ -328,22 +331,6 @@ class LangFunc(
328
331
  """Transforms the output message before returning from __call__."""
329
332
  return lm_output
330
333
 
331
- @classmethod
332
- def from_value(
333
- cls, value: Union[str, template_lib.Template], **kwargs
334
- ) -> 'LangFunc':
335
- """Create a LangFunc object from a string or template."""
336
- if isinstance(value, LangFunc):
337
- return value
338
- if isinstance(value, template_lib.Template):
339
- lfun = LangFunc(value.template_str, **kwargs)
340
- # So lfun could acccess all attributes from value.
341
- lfun.sym_setparent(value)
342
- return lfun
343
- if isinstance(value, str):
344
- return LangFunc(template_str=value, **kwargs)
345
- return LangFunc('{{input}}', input=value, **kwargs)
346
-
347
334
 
348
335
  # Register converter from str to LangFunc, therefore we can always
349
336
  # pass strs to attributes that accept LangFunc.
@@ -57,6 +57,10 @@ class BasicTest(unittest.TestCase):
57
57
  l2 = LangFunc.from_value(l1)
58
58
  self.assertIs(l2, l1)
59
59
 
60
+ l3 = LangFunc.from_value(l1, x=1)
61
+ self.assertIsNot(l3, l1)
62
+ self.assertTrue(pg.eq(l3, LangFunc('Hello', x=1)))
63
+
60
64
  c = template_lib.Template(
61
65
  '{{x}} + {{l}}',
62
66
  x=1,
@@ -83,14 +87,20 @@ class LangFuncCallTest(unittest.TestCase):
83
87
 
84
88
  r = l()
85
89
  self.assertEqual(
86
- r, message.AIMessage('Hello!!!', score=0.0, logprobs=None, usage=None)
90
+ r,
91
+ message.AIMessage(
92
+ 'Hello!!!', score=0.0, logprobs=None, is_cached=False,
93
+ usage=language_model.UsageNotAvailable()
94
+ )
87
95
  )
88
96
  self.assertEqual(r.tags, ['lm-response', 'lm-output'])
89
- self.assertEqual(r.source, message.UserMessage('Hello'))
97
+ self.assertEqual(
98
+ r.source,
99
+ message.UserMessage('Hello', metadata=dict(cache_seed=0))
100
+ )
90
101
  self.assertEqual(r.source.tags, ['rendered', 'lm-input'])
91
102
 
92
103
  self.assertEqual(str(l), 'Hello')
93
- print(repr(l))
94
104
  self.assertEqual(
95
105
  repr(l),
96
106
  "LangFunc(template_str='Hello', clean=True,"
@@ -98,7 +108,8 @@ class LangFuncCallTest(unittest.TestCase):
98
108
  ' max_tokens=None, n=1, top_k=40, top_p=None, stop=None,'
99
109
  ' random_seed=None, logprobs=False, top_logprobs=None), cache=None,'
100
110
  ' max_concurrency=None, timeout=120.0, max_attempts=5,'
101
- ' retry_interval=(5, 60), exponential_backoff=True, debug=False))',
111
+ ' retry_interval=(5, 60), exponential_backoff=True,'
112
+ ' max_retry_interval=300, debug=False))',
102
113
  )
103
114
 
104
115
  l = LangFunc('Hello')
@@ -106,11 +117,16 @@ class LangFuncCallTest(unittest.TestCase):
106
117
  self.assertEqual(l, 'Hello')
107
118
  self.assertEqual(l.natural_language_format(), 'Hello')
108
119
  self.assertEqual(l.render(), 'Hello')
109
- r = l()
120
+ r = l(cache_seed=1)
110
121
  self.assertEqual(
111
- r, message.AIMessage('Hello!!!', score=0.0, logprobs=None, usage=None)
122
+ r,
123
+ message.AIMessage(
124
+ 'Hello!!!', score=0.0, logprobs=None, is_cached=False,
125
+ usage=language_model.UsageNotAvailable()
126
+ )
112
127
  )
113
128
  self.assertEqual(r.tags, ['lm-response', 'lm-output'])
129
+ self.assertEqual(r.source.metadata.cache_seed, 1)
114
130
 
115
131
  self.assertEqual(str(l), 'Hello')
116
132