langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +22 -2
- langfun/core/__init__.py +17 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -28
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +69 -2
- langfun/core/component_test.py +54 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +17 -0
- langfun/core/eval/base.py +767 -140
- langfun/core/eval/base_test.py +238 -53
- langfun/core/eval/matching.py +80 -76
- langfun/core/eval/matching_test.py +19 -9
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +37 -28
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +3 -21
- langfun/core/langfunc_test.py +26 -8
- langfun/core/language_model.py +686 -48
- langfun/core/language_model_test.py +681 -44
- langfun/core/llms/__init__.py +100 -12
- langfun/core/llms/anthropic.py +488 -0
- langfun/core/llms/anthropic_test.py +235 -0
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +88 -28
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +39 -26
- langfun/core/llms/fake_test.py +136 -11
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -197
- langfun/core/llms/groq.py +276 -0
- langfun/core/llms/groq_test.py +64 -0
- langfun/core/llms/llama_cpp.py +15 -40
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +436 -226
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +35 -174
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -23
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +15 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +9 -8
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +278 -0
- langfun/core/structured/function_generation_test.py +399 -0
- langfun/core/structured/mapping.py +150 -46
- langfun/core/structured/mapping_test.py +105 -0
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +71 -22
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
- langfun/core/structured/schema.py +208 -99
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +133 -34
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +240 -11
- langfun/core/template_test.py +146 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +14 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -217
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -11,18 +11,19 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
"""Tests for structured
|
14
|
+
"""Tests for structured query."""
|
15
15
|
|
16
16
|
import inspect
|
17
|
+
import math
|
18
|
+
from typing import Any
|
17
19
|
import unittest
|
18
20
|
|
19
21
|
import langfun.core as lf
|
20
|
-
from langfun.core import coding
|
21
22
|
from langfun.core import modalities
|
22
23
|
from langfun.core.llms import fake
|
24
|
+
from langfun.core.llms.cache import in_memory
|
23
25
|
from langfun.core.structured import mapping
|
24
|
-
from langfun.core.structured import
|
25
|
-
from langfun.core.structured import schema as schema_lib
|
26
|
+
from langfun.core.structured import querying
|
26
27
|
import pyglove as pg
|
27
28
|
|
28
29
|
|
@@ -43,13 +44,17 @@ class QueryTest(unittest.TestCase):
|
|
43
44
|
self,
|
44
45
|
prompt,
|
45
46
|
schema,
|
47
|
+
examples: list[mapping.MappingExample] | None = None,
|
46
48
|
*,
|
47
49
|
expected_snippet: str,
|
48
50
|
exact_match: bool = False,
|
49
51
|
expected_modalities: int = 0,
|
50
52
|
**kwargs,
|
51
53
|
):
|
52
|
-
m =
|
54
|
+
m = querying.query(
|
55
|
+
prompt, schema=schema, examples=examples,
|
56
|
+
**kwargs, returns_message=True
|
57
|
+
)
|
53
58
|
self.assertIsNotNone(m.lm_input)
|
54
59
|
if exact_match:
|
55
60
|
self.assertEqual(expected_snippet, m.lm_input.text)
|
@@ -62,14 +67,14 @@ class QueryTest(unittest.TestCase):
|
|
62
67
|
|
63
68
|
def test_call(self):
|
64
69
|
lm = fake.StaticSequence(['1'])
|
65
|
-
self.assertEqual(
|
70
|
+
self.assertEqual(querying.query('what is 1 + 0', int, lm=lm), 1)
|
66
71
|
|
67
72
|
# Testing calling the same `lm` without copy.
|
68
73
|
with self.assertRaises(IndexError):
|
69
|
-
|
74
|
+
querying.query('what is 1 + 2', int, lm=lm)
|
70
75
|
|
71
76
|
self.assertEqual(
|
72
|
-
|
77
|
+
querying.query(
|
73
78
|
'what is 1 + 0', int, lm=lm.clone(), returns_message=True
|
74
79
|
),
|
75
80
|
lf.AIMessage(
|
@@ -77,21 +82,23 @@ class QueryTest(unittest.TestCase):
|
|
77
82
|
result=1,
|
78
83
|
score=1.0,
|
79
84
|
logprobs=None,
|
85
|
+
is_cached=False,
|
86
|
+
usage=lf.LMSamplingUsage(323, 1, 324),
|
80
87
|
tags=['lm-response', 'lm-output', 'transformed'],
|
81
88
|
),
|
82
89
|
)
|
83
90
|
self.assertEqual(
|
84
|
-
|
85
|
-
lf.Template('what is {{x}} + {{y}}'
|
91
|
+
querying.query(
|
92
|
+
lf.Template('what is {{x}} + {{y}}', x=1, y=0), int, lm=lm.clone()
|
86
93
|
),
|
87
94
|
1,
|
88
95
|
)
|
89
96
|
self.assertEqual(
|
90
|
-
|
97
|
+
querying.query('what is {{x}} + {{y}}', int, x=1, y=0, lm=lm.clone()),
|
91
98
|
1,
|
92
99
|
)
|
93
100
|
self.assertEqual(
|
94
|
-
|
101
|
+
querying.query(
|
95
102
|
'what is {{x}} + {{y}}',
|
96
103
|
x=1,
|
97
104
|
y=0,
|
@@ -100,7 +107,7 @@ class QueryTest(unittest.TestCase):
|
|
100
107
|
'The answer is one.',
|
101
108
|
)
|
102
109
|
self.assertEqual(
|
103
|
-
|
110
|
+
querying.query(
|
104
111
|
Activity.partial(),
|
105
112
|
lm=fake.StaticResponse('Activity(description="hello")'),
|
106
113
|
),
|
@@ -116,12 +123,59 @@ class QueryTest(unittest.TestCase):
|
|
116
123
|
y=2,
|
117
124
|
lm=lm.clone(),
|
118
125
|
expected_snippet=(
|
119
|
-
'Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT'
|
120
|
-
'
|
121
|
-
'
|
122
|
-
'
|
123
|
-
'
|
124
|
-
'
|
126
|
+
'Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT '
|
127
|
+
'according to OUTPUT_TYPE.\n\n'
|
128
|
+
'INPUT_OBJECT:\n 1 + 1 =\n\n'
|
129
|
+
'OUTPUT_TYPE:\n'
|
130
|
+
' Answer\n\n'
|
131
|
+
' ```python\n'
|
132
|
+
' class Answer:\n'
|
133
|
+
' final_answer: int\n'
|
134
|
+
' ```\n\n'
|
135
|
+
'OUTPUT_OBJECT:\n'
|
136
|
+
' ```python\n'
|
137
|
+
' Answer(\n'
|
138
|
+
' final_answer=2\n'
|
139
|
+
' )\n'
|
140
|
+
' ```\n\n'
|
141
|
+
'INPUT_OBJECT:\n'
|
142
|
+
' What is 1 + 2?\n\n'
|
143
|
+
'OUTPUT_TYPE:\n'
|
144
|
+
' int\n\n'
|
145
|
+
'OUTPUT_OBJECT:'
|
146
|
+
),
|
147
|
+
)
|
148
|
+
|
149
|
+
def test_str_to_structure_render_custom_template(self):
|
150
|
+
lm = fake.StaticResponse('1')
|
151
|
+
self.assert_render(
|
152
|
+
'What is {{x}} + {{y}}?',
|
153
|
+
int,
|
154
|
+
x=1,
|
155
|
+
y=2,
|
156
|
+
lm=lm.clone(),
|
157
|
+
template_str='!!{{ DEFAULT }}!!',
|
158
|
+
expected_snippet=(
|
159
|
+
'!!Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT '
|
160
|
+
'according to OUTPUT_TYPE.\n\n'
|
161
|
+
'INPUT_OBJECT:\n 1 + 1 =\n\n'
|
162
|
+
'OUTPUT_TYPE:\n'
|
163
|
+
' Answer\n\n'
|
164
|
+
' ```python\n'
|
165
|
+
' class Answer:\n'
|
166
|
+
' final_answer: int\n'
|
167
|
+
' ```\n\n'
|
168
|
+
'OUTPUT_OBJECT:\n'
|
169
|
+
' ```python\n'
|
170
|
+
' Answer(\n'
|
171
|
+
' final_answer=2\n'
|
172
|
+
' )\n'
|
173
|
+
' ```\n\n'
|
174
|
+
'INPUT_OBJECT:\n'
|
175
|
+
' What is 1 + 2?\n\n'
|
176
|
+
'OUTPUT_TYPE:\n'
|
177
|
+
' int\n\n'
|
178
|
+
'OUTPUT_OBJECT:!!'
|
125
179
|
),
|
126
180
|
)
|
127
181
|
|
@@ -162,7 +216,7 @@ class QueryTest(unittest.TestCase):
|
|
162
216
|
modalities.Image.from_bytes(b'mock_image'),
|
163
217
|
int,
|
164
218
|
lm=lm,
|
165
|
-
expected_snippet='\n\nINPUT_OBJECT:\n
|
219
|
+
expected_snippet='\n\nINPUT_OBJECT:\n <<[[input]]>>\n\n',
|
166
220
|
expected_modalities=1,
|
167
221
|
)
|
168
222
|
|
@@ -172,7 +226,7 @@ class QueryTest(unittest.TestCase):
|
|
172
226
|
modalities.Image.from_bytes(b'mock_image'),
|
173
227
|
None,
|
174
228
|
lm=lm,
|
175
|
-
expected_snippet='
|
229
|
+
expected_snippet='<<[[input]]>>',
|
176
230
|
exact_match=True,
|
177
231
|
expected_modalities=1,
|
178
232
|
)
|
@@ -185,7 +239,9 @@ class QueryTest(unittest.TestCase):
|
|
185
239
|
this_image=modalities.Image.from_bytes(b'cat_image'),
|
186
240
|
that_image=modalities.Image.from_bytes(b'mouse_image'),
|
187
241
|
lm=lm,
|
188
|
-
expected_snippet=
|
242
|
+
expected_snippet=(
|
243
|
+
'What are these? <<[[this_image]]>> and <<[[that_image]]>>'
|
244
|
+
),
|
189
245
|
exact_match=True,
|
190
246
|
expected_modalities=2,
|
191
247
|
)
|
@@ -199,7 +255,7 @@ class QueryTest(unittest.TestCase):
|
|
199
255
|
],
|
200
256
|
None,
|
201
257
|
lm=lm,
|
202
|
-
expected_snippet='`[
|
258
|
+
expected_snippet='`[<<[[input[0]]]>>, <<[[input[1]]]>>]`',
|
203
259
|
exact_match=True,
|
204
260
|
expected_modalities=2,
|
205
261
|
)
|
@@ -217,33 +273,349 @@ class QueryTest(unittest.TestCase):
|
|
217
273
|
INPUT_OBJECT:
|
218
274
|
```python
|
219
275
|
[
|
220
|
-
|
221
|
-
|
222
|
-
),
|
223
|
-
ModalityRef(
|
224
|
-
name='input[1]'
|
225
|
-
)
|
276
|
+
<<[[input[0]]]>>,
|
277
|
+
<<[[input[1]]]>>
|
226
278
|
]
|
227
279
|
```
|
228
|
-
|
229
|
-
MODALITY_REFERENCES:
|
230
|
-
{
|
231
|
-
'input[0]': {{input[0]}},
|
232
|
-
'input[1]': {{input[1]}}
|
233
|
-
}
|
234
280
|
"""),
|
235
281
|
expected_modalities=2,
|
236
282
|
)
|
237
283
|
|
284
|
+
def test_structure_with_modality_and_examples_to_structure_render(self):
|
285
|
+
lm = fake.StaticResponse('["cat", "mouse"]')
|
286
|
+
self.assert_render(
|
287
|
+
[
|
288
|
+
modalities.Image.from_bytes(b'cat_image'),
|
289
|
+
modalities.Image.from_bytes(b'mouse_image'),
|
290
|
+
],
|
291
|
+
list[str],
|
292
|
+
examples=[
|
293
|
+
mapping.MappingExample(
|
294
|
+
input=[modalities.Image.from_bytes(b'dog_image')],
|
295
|
+
schema=list[str],
|
296
|
+
output=['dog'],
|
297
|
+
),
|
298
|
+
],
|
299
|
+
lm=lm,
|
300
|
+
expected_snippet=inspect.cleandoc("""
|
301
|
+
INPUT_OBJECT:
|
302
|
+
```python
|
303
|
+
[
|
304
|
+
<<[[examples[0].input[0]]]>>
|
305
|
+
]
|
306
|
+
```
|
307
|
+
|
308
|
+
OUTPUT_TYPE:
|
309
|
+
list[str]
|
310
|
+
|
311
|
+
OUTPUT_OBJECT:
|
312
|
+
```python
|
313
|
+
[
|
314
|
+
'dog'
|
315
|
+
]
|
316
|
+
```
|
317
|
+
|
318
|
+
|
319
|
+
INPUT_OBJECT:
|
320
|
+
```python
|
321
|
+
[
|
322
|
+
<<[[input[0]]]>>,
|
323
|
+
<<[[input[1]]]>>
|
324
|
+
]
|
325
|
+
```
|
326
|
+
"""),
|
327
|
+
expected_modalities=3,
|
328
|
+
)
|
329
|
+
|
330
|
+
def test_multiple_queries(self):
|
331
|
+
self.assertEqual(
|
332
|
+
querying.query(
|
333
|
+
'Compute 1 + 2',
|
334
|
+
int,
|
335
|
+
lm=[
|
336
|
+
fake.StaticResponse('1'),
|
337
|
+
fake.StaticResponse('2'),
|
338
|
+
],
|
339
|
+
num_samples=[1, 2],
|
340
|
+
),
|
341
|
+
[1, 2, 2]
|
342
|
+
)
|
343
|
+
self.assertEqual(
|
344
|
+
querying.query(
|
345
|
+
'Compute 1 + 2',
|
346
|
+
int,
|
347
|
+
lm=[
|
348
|
+
fake.StaticResponse('1'),
|
349
|
+
fake.StaticResponse('2'),
|
350
|
+
],
|
351
|
+
num_samples=2,
|
352
|
+
),
|
353
|
+
[1, 1, 2, 2]
|
354
|
+
)
|
355
|
+
self.assertEqual(
|
356
|
+
querying.query(
|
357
|
+
'Compute 1 + 2',
|
358
|
+
int,
|
359
|
+
lm=[
|
360
|
+
fake.StaticResponse('1'),
|
361
|
+
fake.StaticResponse('abc'),
|
362
|
+
],
|
363
|
+
num_samples=[1, 2],
|
364
|
+
),
|
365
|
+
[1]
|
366
|
+
)
|
367
|
+
self.assertEqual(
|
368
|
+
querying.query(
|
369
|
+
'Compute 1 + 2',
|
370
|
+
int,
|
371
|
+
default=0,
|
372
|
+
lm=[
|
373
|
+
fake.StaticResponse('1'),
|
374
|
+
fake.StaticResponse('abc'),
|
375
|
+
],
|
376
|
+
num_samples=[1, 2],
|
377
|
+
),
|
378
|
+
[1, 0, 0]
|
379
|
+
)
|
380
|
+
results = querying.query(
|
381
|
+
'Compute 1 + 2',
|
382
|
+
int,
|
383
|
+
default=0,
|
384
|
+
lm=[
|
385
|
+
fake.StaticResponse('1'),
|
386
|
+
fake.StaticResponse('abc'),
|
387
|
+
],
|
388
|
+
returns_message=True,
|
389
|
+
)
|
390
|
+
self.assertEqual([r.text for r in results], ['1', 'abc'])
|
391
|
+
self.assertEqual([r.result for r in results], [1, 0])
|
392
|
+
|
238
393
|
def test_bad_protocol(self):
|
239
394
|
with self.assertRaisesRegex(ValueError, 'Unknown protocol'):
|
240
|
-
|
395
|
+
querying.query('what is 1 + 1', int, protocol='text')
|
396
|
+
|
397
|
+
def test_query_prompt(self):
|
398
|
+
self.assertEqual(
|
399
|
+
querying.query_prompt('what is this?', int),
|
400
|
+
inspect.cleandoc("""
|
401
|
+
Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT according to OUTPUT_TYPE.
|
402
|
+
|
403
|
+
INPUT_OBJECT:
|
404
|
+
1 + 1 =
|
405
|
+
|
406
|
+
OUTPUT_TYPE:
|
407
|
+
Answer
|
408
|
+
|
409
|
+
```python
|
410
|
+
class Answer:
|
411
|
+
final_answer: int
|
412
|
+
```
|
413
|
+
|
414
|
+
OUTPUT_OBJECT:
|
415
|
+
```python
|
416
|
+
Answer(
|
417
|
+
final_answer=2
|
418
|
+
)
|
419
|
+
```
|
420
|
+
|
421
|
+
INPUT_OBJECT:
|
422
|
+
what is this?
|
423
|
+
|
424
|
+
OUTPUT_TYPE:
|
425
|
+
int
|
426
|
+
|
427
|
+
OUTPUT_OBJECT:
|
428
|
+
"""),
|
429
|
+
)
|
430
|
+
|
431
|
+
def test_query_prompt_with_metadata(self):
|
432
|
+
self.assertIn(
|
433
|
+
'x',
|
434
|
+
querying.query_prompt(
|
435
|
+
'what is this?',
|
436
|
+
metadata_x=1
|
437
|
+
).metadata
|
438
|
+
)
|
439
|
+
self.assertIn(
|
440
|
+
'x',
|
441
|
+
querying.query_prompt(
|
442
|
+
'what is this?',
|
443
|
+
int,
|
444
|
+
metadata_x=1
|
445
|
+
).metadata
|
446
|
+
)
|
447
|
+
|
448
|
+
def test_query_prompt_with_unrooted_template(self):
|
449
|
+
output = querying.query_prompt(
|
450
|
+
pg.Dict(
|
451
|
+
input=lf.Template(
|
452
|
+
'what is {{image}}',
|
453
|
+
image=modalities.Image.from_bytes(b'mock_image')
|
454
|
+
)
|
455
|
+
).input,
|
456
|
+
)
|
457
|
+
self.assertIsNotNone(output.get_modality('image'))
|
458
|
+
|
459
|
+
def test_query_and_reduce(self):
|
460
|
+
self.assertEqual(
|
461
|
+
querying.query_and_reduce(
|
462
|
+
'Compute 1 + 1',
|
463
|
+
int,
|
464
|
+
reduce=sum,
|
465
|
+
lm=[
|
466
|
+
fake.StaticResponse('1'),
|
467
|
+
fake.StaticResponse('2'),
|
468
|
+
],
|
469
|
+
num_samples=[1, 2],
|
470
|
+
),
|
471
|
+
5
|
472
|
+
)
|
473
|
+
self.assertEqual(
|
474
|
+
querying.query_and_reduce(
|
475
|
+
'Compute 1 + 1',
|
476
|
+
int,
|
477
|
+
reduce=sum,
|
478
|
+
lm=fake.StaticResponse('2'),
|
479
|
+
),
|
480
|
+
2
|
481
|
+
)
|
482
|
+
|
483
|
+
def test_query_output(self):
|
484
|
+
self.assertEqual(
|
485
|
+
querying.query_output(
|
486
|
+
lf.AIMessage('1'),
|
487
|
+
int,
|
488
|
+
),
|
489
|
+
1,
|
490
|
+
)
|
491
|
+
|
492
|
+
def test_query_reward(self):
|
493
|
+
|
494
|
+
class Answer(pg.Object):
|
495
|
+
final_answer: int
|
496
|
+
|
497
|
+
def __reward__(self, inputs: lf.Template) -> None:
|
498
|
+
diff = abs(self.final_answer - (inputs.x + inputs.y))
|
499
|
+
# Center screwed sigmoid scaled to [-1.0 and 1.0].
|
500
|
+
return 4 / (1 + math.exp(diff)) - 1.0
|
501
|
+
|
502
|
+
# Case 1: Reward function based on input and output.
|
503
|
+
self.assertEqual(
|
504
|
+
querying.query_reward(
|
505
|
+
mapping.MappingExample(
|
506
|
+
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
507
|
+
schema=Answer,
|
508
|
+
output=Answer(final_answer=2),
|
509
|
+
),
|
510
|
+
'Answer(2)'
|
511
|
+
),
|
512
|
+
1.0
|
513
|
+
)
|
514
|
+
self.assertEqual(
|
515
|
+
querying.query_reward(
|
516
|
+
mapping.MappingExample(
|
517
|
+
input=lf.Template('{{x}} + {{y}}', x=2, y=3),
|
518
|
+
output=Answer(final_answer=2),
|
519
|
+
).to_json_str(),
|
520
|
+
'Answer(5)'
|
521
|
+
),
|
522
|
+
1.0
|
523
|
+
)
|
524
|
+
|
525
|
+
# Case 2: Reward function based on input, result and expected output.
|
526
|
+
class Answer2(pg.Object):
|
527
|
+
final_answer: int
|
528
|
+
|
529
|
+
def __reward__(self, inputs: lf.Template, expected_output: 'Answer2'):
|
530
|
+
return (
|
531
|
+
1.0 if self.final_answer == expected_output.final_answer else -1.0
|
532
|
+
)
|
533
|
+
|
534
|
+
self.assertEqual(
|
535
|
+
querying.query_reward(
|
536
|
+
mapping.MappingExample(
|
537
|
+
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
538
|
+
output=Answer2(final_answer=2),
|
539
|
+
),
|
540
|
+
'Answer2(3)'
|
541
|
+
),
|
542
|
+
-1.0
|
543
|
+
)
|
544
|
+
|
545
|
+
# Case 3: Reward function based on input, result, expected output
|
546
|
+
# and metadata.
|
547
|
+
class Answer3(pg.Object):
|
548
|
+
final_answer: int
|
549
|
+
|
550
|
+
def __reward__(self,
|
551
|
+
inputs: lf.Template,
|
552
|
+
expected_output: 'Answer3',
|
553
|
+
metadata: dict[str, Any]):
|
554
|
+
del inputs
|
555
|
+
return (
|
556
|
+
1.0 if self.final_answer == expected_output.final_answer else -1.0
|
557
|
+
) * metadata['weight']
|
558
|
+
|
559
|
+
self.assertEqual(
|
560
|
+
querying.query_reward(
|
561
|
+
mapping.MappingExample(
|
562
|
+
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
563
|
+
output=Answer3(final_answer=2),
|
564
|
+
metadata=dict(weight=0.5)
|
565
|
+
),
|
566
|
+
'Answer3(3)'
|
567
|
+
),
|
568
|
+
-0.5
|
569
|
+
)
|
570
|
+
|
571
|
+
# Case 4: No reward function is provided.
|
572
|
+
class Answer4(pg.Object):
|
573
|
+
final_answer: int
|
574
|
+
|
575
|
+
self.assertIsNone(
|
576
|
+
querying.query_reward(
|
577
|
+
mapping.MappingExample(
|
578
|
+
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
579
|
+
output=Answer4(final_answer=2),
|
580
|
+
),
|
581
|
+
'Answer2(2)'
|
582
|
+
)
|
583
|
+
)
|
584
|
+
|
585
|
+
# Case 5: Not a structured output.
|
586
|
+
self.assertIsNone(
|
587
|
+
querying.query_reward(
|
588
|
+
mapping.MappingExample(
|
589
|
+
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
590
|
+
output='2',
|
591
|
+
),
|
592
|
+
'2'
|
593
|
+
)
|
594
|
+
)
|
595
|
+
|
596
|
+
# Case 6: Bad reward function.
|
597
|
+
class Answer5(pg.Object):
|
598
|
+
final_answer: int
|
599
|
+
|
600
|
+
def __reward__(self):
|
601
|
+
return 0.0
|
602
|
+
|
603
|
+
with self.assertRaisesRegex(
|
604
|
+
TypeError, '.*Answer5.__reward__` should have signature'
|
605
|
+
):
|
606
|
+
querying.query_reward(
|
607
|
+
mapping.MappingExample(
|
608
|
+
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
609
|
+
output=Answer5(final_answer=2),
|
610
|
+
),
|
611
|
+
'Answer5(2)'
|
612
|
+
)
|
241
613
|
|
242
614
|
|
243
615
|
class QueryStructurePythonTest(unittest.TestCase):
|
244
616
|
|
245
617
|
def test_render_no_examples(self):
|
246
|
-
l =
|
618
|
+
l = querying._QueryStructurePython(
|
247
619
|
input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int
|
248
620
|
)
|
249
621
|
self.assertEqual(
|
@@ -264,7 +636,9 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
264
636
|
|
265
637
|
OUTPUT_OBJECT:
|
266
638
|
```python
|
267
|
-
Answer(
|
639
|
+
Answer(
|
640
|
+
final_answer=2
|
641
|
+
)
|
268
642
|
```
|
269
643
|
|
270
644
|
INPUT_OBJECT:
|
@@ -278,7 +652,7 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
278
652
|
)
|
279
653
|
|
280
654
|
def test_render(self):
|
281
|
-
l =
|
655
|
+
l = querying._QueryStructurePython(
|
282
656
|
input=lf.AIMessage('Compute 12 / 6 + 2.'),
|
283
657
|
schema=int,
|
284
658
|
examples=[
|
@@ -308,7 +682,9 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
308
682
|
|
309
683
|
OUTPUT_OBJECT:
|
310
684
|
```python
|
311
|
-
Answer(
|
685
|
+
Answer(
|
686
|
+
final_answer=2
|
687
|
+
)
|
312
688
|
```
|
313
689
|
|
314
690
|
INPUT_OBJECT:
|
@@ -386,7 +762,7 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
386
762
|
),
|
387
763
|
override_attrs=True,
|
388
764
|
):
|
389
|
-
l =
|
765
|
+
l = querying._QueryStructurePython(
|
390
766
|
input=lm_input,
|
391
767
|
schema=[Itinerary],
|
392
768
|
examples=[
|
@@ -420,10 +796,10 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
420
796
|
override_attrs=True,
|
421
797
|
):
|
422
798
|
with self.assertRaisesRegex(
|
423
|
-
|
799
|
+
mapping.MappingError,
|
424
800
|
'name .* is not defined',
|
425
801
|
):
|
426
|
-
|
802
|
+
querying.query('Compute 1 + 2', int)
|
427
803
|
|
428
804
|
def test_autofix(self):
|
429
805
|
lm = fake.StaticSequence([
|
@@ -434,13 +810,30 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
434
810
|
)
|
435
811
|
"""),
|
436
812
|
])
|
437
|
-
self.assertEqual(
|
813
|
+
self.assertEqual(querying.query('what is 1 + 0', int, lm=lm, autofix=3), 1)
|
814
|
+
|
815
|
+
def test_response_postprocess(self):
|
816
|
+
with lf.context(
|
817
|
+
lm=fake.StaticResponse('<!-- some comment-->\n3'),
|
818
|
+
override_attrs=True,
|
819
|
+
):
|
820
|
+
self.assertEqual(
|
821
|
+
querying.query(
|
822
|
+
'Compute 1 + 2', response_postprocess=lambda x: x.split('\n')[1]),
|
823
|
+
'3'
|
824
|
+
)
|
825
|
+
self.assertEqual(
|
826
|
+
querying.query(
|
827
|
+
'Compute 1 + 2', int,
|
828
|
+
response_postprocess=lambda x: x.split('\n')[1]),
|
829
|
+
3
|
830
|
+
)
|
438
831
|
|
439
832
|
|
440
833
|
class QueryStructureJsonTest(unittest.TestCase):
|
441
834
|
|
442
835
|
def test_render_no_examples(self):
|
443
|
-
l =
|
836
|
+
l = querying._QueryStructureJson(
|
444
837
|
input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int
|
445
838
|
)
|
446
839
|
self.assertEqual(
|
@@ -456,10 +849,10 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
456
849
|
1 + 1 =
|
457
850
|
|
458
851
|
SCHEMA:
|
459
|
-
{"result": {"_type": "langfun.core.structured.
|
852
|
+
{"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}}
|
460
853
|
|
461
854
|
JSON:
|
462
|
-
{"result": {"_type": "langfun.core.structured.
|
855
|
+
{"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}}
|
463
856
|
|
464
857
|
INPUT_OBJECT:
|
465
858
|
Compute 12 / 6 + 2.
|
@@ -472,7 +865,7 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
472
865
|
)
|
473
866
|
|
474
867
|
def test_render(self):
|
475
|
-
l =
|
868
|
+
l = querying._QueryStructureJson(
|
476
869
|
input=lf.AIMessage('Compute 12 / 6 + 2.'),
|
477
870
|
schema=int,
|
478
871
|
examples=[
|
@@ -493,10 +886,10 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
493
886
|
1 + 1 =
|
494
887
|
|
495
888
|
SCHEMA:
|
496
|
-
{"result": {"_type": "langfun.core.structured.
|
889
|
+
{"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}}
|
497
890
|
|
498
891
|
JSON:
|
499
|
-
{"result": {"_type": "langfun.core.structured.
|
892
|
+
{"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}}
|
500
893
|
|
501
894
|
INPUT_OBJECT:
|
502
895
|
What is the answer of 1 plus 1?
|
@@ -607,7 +1000,7 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
607
1000
|
),
|
608
1001
|
override_attrs=True,
|
609
1002
|
):
|
610
|
-
l =
|
1003
|
+
l = querying._QueryStructureJson(
|
611
1004
|
input=lm_input,
|
612
1005
|
schema=[Itinerary],
|
613
1006
|
examples=[
|
@@ -636,22 +1029,114 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
636
1029
|
self.assertIsNone(r.result[0].hotel)
|
637
1030
|
|
638
1031
|
def test_bad_transform(self):
|
639
|
-
with
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
with self.assertRaisesRegex(
|
644
|
-
schema_lib.JsonError,
|
645
|
-
'No JSON dict in the output',
|
1032
|
+
with in_memory.lm_cache() as cache:
|
1033
|
+
with lf.context(
|
1034
|
+
lm=fake.StaticSequence(['3']),
|
1035
|
+
override_attrs=True,
|
646
1036
|
):
|
647
|
-
|
1037
|
+
with self.assertRaisesRegex(
|
1038
|
+
mapping.MappingError,
|
1039
|
+
'No JSON dict in the output',
|
1040
|
+
):
|
1041
|
+
querying.query('Compute 1 + 2', int, protocol='json', cache_seed=1)
|
1042
|
+
# Make sure bad mapping does not impact cache.
|
1043
|
+
self.assertEqual(len(cache), 0)
|
648
1044
|
|
649
1045
|
def test_query(self):
|
650
1046
|
lm = fake.StaticSequence(['{"result": 1}'])
|
651
1047
|
self.assertEqual(
|
652
|
-
|
1048
|
+
querying.query('what is 1 + 0', int, lm=lm, protocol='json'), 1
|
653
1049
|
)
|
654
1050
|
|
655
1051
|
|
1052
|
+
class QueryInvocationTest(unittest.TestCase):
|
1053
|
+
|
1054
|
+
def test_basics(self):
|
1055
|
+
lm = fake.StaticSequence([
|
1056
|
+
'Activity(description="hi"',
|
1057
|
+
])
|
1058
|
+
with querying.track_queries() as queries:
|
1059
|
+
querying.query('foo', Activity, default=None, lm=lm)
|
1060
|
+
|
1061
|
+
self.assertTrue(queries[0].has_error)
|
1062
|
+
self.assertIsInstance(queries[0].output, mapping.MappingError)
|
1063
|
+
|
1064
|
+
def test_to_html(self):
|
1065
|
+
lm = fake.StaticSequence([
|
1066
|
+
'Activity(description="hi")',
|
1067
|
+
])
|
1068
|
+
with querying.track_queries() as queries:
|
1069
|
+
querying.query('foo', Activity, lm=lm)
|
1070
|
+
|
1071
|
+
self.assertIn('schema', queries[0].to_html_str())
|
1072
|
+
|
1073
|
+
|
1074
|
+
class TrackQueriesTest(unittest.TestCase):
|
1075
|
+
|
1076
|
+
def test_include_child_scopes(self):
|
1077
|
+
lm = fake.StaticSequence([
|
1078
|
+
'bar',
|
1079
|
+
'Activity(description="hi")',
|
1080
|
+
])
|
1081
|
+
with querying.track_queries() as queries:
|
1082
|
+
querying.query('foo', lm=lm)
|
1083
|
+
with querying.track_queries() as child_queries:
|
1084
|
+
querying.query('give me an activity', Activity, lm=lm)
|
1085
|
+
|
1086
|
+
self.assertEqual(len(queries), 2)
|
1087
|
+
self.assertTrue(pg.eq(queries[0].input, lf.Template('foo')))
|
1088
|
+
self.assertIsNone(queries[0].schema)
|
1089
|
+
self.assertEqual(queries[0].output, 'bar')
|
1090
|
+
self.assertIs(queries[0].lm, lm)
|
1091
|
+
|
1092
|
+
self.assertTrue(pg.eq(queries[1].input, lf.Template('give me an activity')))
|
1093
|
+
self.assertEqual(queries[1].schema.spec.cls, Activity)
|
1094
|
+
self.assertTrue(pg.eq(queries[1].output, Activity(description='hi')))
|
1095
|
+
self.assertIs(queries[1].lm, lm)
|
1096
|
+
self.assertGreater(queries[0].elapse, 0)
|
1097
|
+
self.assertGreater(queries[0].usage_summary.total.total_tokens, 0)
|
1098
|
+
self.assertGreater(queries[1].usage_summary.total.total_tokens, 0)
|
1099
|
+
|
1100
|
+
self.assertEqual(len(child_queries), 1)
|
1101
|
+
self.assertIs(child_queries[0], queries[1])
|
1102
|
+
|
1103
|
+
def test_exclude_child_scopes(self):
|
1104
|
+
lm = fake.StaticSequence([
|
1105
|
+
'bar',
|
1106
|
+
'Activity(description="hi")',
|
1107
|
+
])
|
1108
|
+
with querying.track_queries(include_child_scopes=False) as queries:
|
1109
|
+
querying.query('foo', lm=lm)
|
1110
|
+
with querying.track_queries(include_child_scopes=False) as child_queries:
|
1111
|
+
querying.query('give me an activity', Activity, lm=lm)
|
1112
|
+
|
1113
|
+
self.assertEqual(len(queries), 1)
|
1114
|
+
self.assertTrue(pg.eq(queries[0].input, lf.Template('foo')))
|
1115
|
+
self.assertIsNone(queries[0].schema)
|
1116
|
+
self.assertEqual(queries[0].output, 'bar')
|
1117
|
+
self.assertIs(queries[0].lm, lm)
|
1118
|
+
|
1119
|
+
self.assertEqual(len(child_queries), 1)
|
1120
|
+
self.assertTrue(
|
1121
|
+
pg.eq(child_queries[0].input, lf.Template('give me an activity'))
|
1122
|
+
)
|
1123
|
+
self.assertEqual(child_queries[0].schema.spec.cls, Activity)
|
1124
|
+
self.assertTrue(pg.eq(child_queries[0].output, Activity(description='hi')))
|
1125
|
+
self.assertIs(child_queries[0].lm, lm)
|
1126
|
+
|
1127
|
+
def test_concurrent_map(self):
|
1128
|
+
|
1129
|
+
def make_query(prompt):
|
1130
|
+
_ = querying.query(prompt, lm=lm)
|
1131
|
+
|
1132
|
+
lm = fake.StaticSequence([
|
1133
|
+
'foo',
|
1134
|
+
'bar',
|
1135
|
+
])
|
1136
|
+
with querying.track_queries() as queries:
|
1137
|
+
list(lf.concurrent_map(make_query, ['a', 'b']))
|
1138
|
+
self.assertEqual(len(queries), 2)
|
1139
|
+
|
1140
|
+
|
656
1141
|
if __name__ == '__main__':
|
657
1142
|
unittest.main()
|