langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501150804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +20 -2
- langfun/core/__init__.py +16 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -21
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +63 -2
- langfun/core/component_test.py +53 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +17 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +16 -1
- langfun/core/eval/base.py +622 -174
- langfun/core/eval/base_test.py +200 -54
- langfun/core/eval/matching.py +63 -76
- langfun/core/eval/matching_test.py +17 -8
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +26 -26
- langfun/core/eval/scoring_test.py +19 -2
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +4 -17
- langfun/core/langfunc_test.py +22 -6
- langfun/core/language_model.py +577 -39
- langfun/core/language_model_test.py +470 -56
- langfun/core/llms/__init__.py +87 -16
- langfun/core/llms/anthropic.py +312 -87
- langfun/core/llms/anthropic_test.py +71 -3
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +53 -2
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +11 -7
- langfun/core/llms/fake_test.py +14 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -202
- langfun/core/llms/groq.py +160 -144
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +15 -42
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +395 -203
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +30 -395
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -26
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +12 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +7 -6
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +60 -27
- langfun/core/structured/function_generation_test.py +72 -2
- langfun/core/structured/mapping.py +97 -47
- langfun/core/structured/mapping_test.py +90 -2
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +53 -9
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
- langfun/core/structured/schema.py +204 -97
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_test.py +130 -29
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +115 -1
- langfun/core/template_test.py +71 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +10 -2
- langfun-0.1.2.dev202501150804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501150804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -238
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
- langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/top_level.txt +0 -0
@@ -11,16 +11,19 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
"""Tests for structured
|
14
|
+
"""Tests for structured query."""
|
15
15
|
|
16
16
|
import inspect
|
17
|
+
import math
|
18
|
+
from typing import Any
|
17
19
|
import unittest
|
18
20
|
|
19
21
|
import langfun.core as lf
|
20
22
|
from langfun.core import modalities
|
21
23
|
from langfun.core.llms import fake
|
24
|
+
from langfun.core.llms.cache import in_memory
|
22
25
|
from langfun.core.structured import mapping
|
23
|
-
from langfun.core.structured import
|
26
|
+
from langfun.core.structured import querying
|
24
27
|
import pyglove as pg
|
25
28
|
|
26
29
|
|
@@ -41,13 +44,17 @@ class QueryTest(unittest.TestCase):
|
|
41
44
|
self,
|
42
45
|
prompt,
|
43
46
|
schema,
|
47
|
+
examples: list[mapping.MappingExample] | None = None,
|
44
48
|
*,
|
45
49
|
expected_snippet: str,
|
46
50
|
exact_match: bool = False,
|
47
51
|
expected_modalities: int = 0,
|
48
52
|
**kwargs,
|
49
53
|
):
|
50
|
-
m =
|
54
|
+
m = querying.query(
|
55
|
+
prompt, schema=schema, examples=examples,
|
56
|
+
**kwargs, returns_message=True
|
57
|
+
)
|
51
58
|
self.assertIsNotNone(m.lm_input)
|
52
59
|
if exact_match:
|
53
60
|
self.assertEqual(expected_snippet, m.lm_input.text)
|
@@ -60,14 +67,14 @@ class QueryTest(unittest.TestCase):
|
|
60
67
|
|
61
68
|
def test_call(self):
|
62
69
|
lm = fake.StaticSequence(['1'])
|
63
|
-
self.assertEqual(
|
70
|
+
self.assertEqual(querying.query('what is 1 + 0', int, lm=lm), 1)
|
64
71
|
|
65
72
|
# Testing calling the same `lm` without copy.
|
66
73
|
with self.assertRaises(IndexError):
|
67
|
-
|
74
|
+
querying.query('what is 1 + 2', int, lm=lm)
|
68
75
|
|
69
76
|
self.assertEqual(
|
70
|
-
|
77
|
+
querying.query(
|
71
78
|
'what is 1 + 0', int, lm=lm.clone(), returns_message=True
|
72
79
|
),
|
73
80
|
lf.AIMessage(
|
@@ -75,22 +82,23 @@ class QueryTest(unittest.TestCase):
|
|
75
82
|
result=1,
|
76
83
|
score=1.0,
|
77
84
|
logprobs=None,
|
85
|
+
is_cached=False,
|
78
86
|
usage=lf.LMSamplingUsage(323, 1, 324),
|
79
87
|
tags=['lm-response', 'lm-output', 'transformed'],
|
80
88
|
),
|
81
89
|
)
|
82
90
|
self.assertEqual(
|
83
|
-
|
84
|
-
lf.Template('what is {{x}} + {{y}}'
|
91
|
+
querying.query(
|
92
|
+
lf.Template('what is {{x}} + {{y}}', x=1, y=0), int, lm=lm.clone()
|
85
93
|
),
|
86
94
|
1,
|
87
95
|
)
|
88
96
|
self.assertEqual(
|
89
|
-
|
97
|
+
querying.query('what is {{x}} + {{y}}', int, x=1, y=0, lm=lm.clone()),
|
90
98
|
1,
|
91
99
|
)
|
92
100
|
self.assertEqual(
|
93
|
-
|
101
|
+
querying.query(
|
94
102
|
'what is {{x}} + {{y}}',
|
95
103
|
x=1,
|
96
104
|
y=0,
|
@@ -99,7 +107,7 @@ class QueryTest(unittest.TestCase):
|
|
99
107
|
'The answer is one.',
|
100
108
|
)
|
101
109
|
self.assertEqual(
|
102
|
-
|
110
|
+
querying.query(
|
103
111
|
Activity.partial(),
|
104
112
|
lm=fake.StaticResponse('Activity(description="hello")'),
|
105
113
|
),
|
@@ -208,7 +216,7 @@ class QueryTest(unittest.TestCase):
|
|
208
216
|
modalities.Image.from_bytes(b'mock_image'),
|
209
217
|
int,
|
210
218
|
lm=lm,
|
211
|
-
expected_snippet='\n\nINPUT_OBJECT:\n
|
219
|
+
expected_snippet='\n\nINPUT_OBJECT:\n <<[[input]]>>\n\n',
|
212
220
|
expected_modalities=1,
|
213
221
|
)
|
214
222
|
|
@@ -218,7 +226,7 @@ class QueryTest(unittest.TestCase):
|
|
218
226
|
modalities.Image.from_bytes(b'mock_image'),
|
219
227
|
None,
|
220
228
|
lm=lm,
|
221
|
-
expected_snippet='
|
229
|
+
expected_snippet='<<[[input]]>>',
|
222
230
|
exact_match=True,
|
223
231
|
expected_modalities=1,
|
224
232
|
)
|
@@ -231,7 +239,9 @@ class QueryTest(unittest.TestCase):
|
|
231
239
|
this_image=modalities.Image.from_bytes(b'cat_image'),
|
232
240
|
that_image=modalities.Image.from_bytes(b'mouse_image'),
|
233
241
|
lm=lm,
|
234
|
-
expected_snippet=
|
242
|
+
expected_snippet=(
|
243
|
+
'What are these? <<[[this_image]]>> and <<[[that_image]]>>'
|
244
|
+
),
|
235
245
|
exact_match=True,
|
236
246
|
expected_modalities=2,
|
237
247
|
)
|
@@ -245,7 +255,7 @@ class QueryTest(unittest.TestCase):
|
|
245
255
|
],
|
246
256
|
None,
|
247
257
|
lm=lm,
|
248
|
-
expected_snippet='`[
|
258
|
+
expected_snippet='`[<<[[input[0]]]>>, <<[[input[1]]]>>]`',
|
249
259
|
exact_match=True,
|
250
260
|
expected_modalities=2,
|
251
261
|
)
|
@@ -263,33 +273,349 @@ class QueryTest(unittest.TestCase):
|
|
263
273
|
INPUT_OBJECT:
|
264
274
|
```python
|
265
275
|
[
|
266
|
-
|
267
|
-
|
268
|
-
),
|
269
|
-
ModalityRef(
|
270
|
-
name='input[1]'
|
271
|
-
)
|
276
|
+
<<[[input[0]]]>>,
|
277
|
+
<<[[input[1]]]>>
|
272
278
|
]
|
273
279
|
```
|
274
|
-
|
275
|
-
MODALITY_REFERENCES:
|
276
|
-
{
|
277
|
-
'input[0]': {{input[0]}},
|
278
|
-
'input[1]': {{input[1]}}
|
279
|
-
}
|
280
280
|
"""),
|
281
281
|
expected_modalities=2,
|
282
282
|
)
|
283
283
|
|
284
|
+
def test_structure_with_modality_and_examples_to_structure_render(self):
|
285
|
+
lm = fake.StaticResponse('["cat", "mouse"]')
|
286
|
+
self.assert_render(
|
287
|
+
[
|
288
|
+
modalities.Image.from_bytes(b'cat_image'),
|
289
|
+
modalities.Image.from_bytes(b'mouse_image'),
|
290
|
+
],
|
291
|
+
list[str],
|
292
|
+
examples=[
|
293
|
+
mapping.MappingExample(
|
294
|
+
input=[modalities.Image.from_bytes(b'dog_image')],
|
295
|
+
schema=list[str],
|
296
|
+
output=['dog'],
|
297
|
+
),
|
298
|
+
],
|
299
|
+
lm=lm,
|
300
|
+
expected_snippet=inspect.cleandoc("""
|
301
|
+
INPUT_OBJECT:
|
302
|
+
```python
|
303
|
+
[
|
304
|
+
<<[[examples[0].input[0]]]>>
|
305
|
+
]
|
306
|
+
```
|
307
|
+
|
308
|
+
OUTPUT_TYPE:
|
309
|
+
list[str]
|
310
|
+
|
311
|
+
OUTPUT_OBJECT:
|
312
|
+
```python
|
313
|
+
[
|
314
|
+
'dog'
|
315
|
+
]
|
316
|
+
```
|
317
|
+
|
318
|
+
|
319
|
+
INPUT_OBJECT:
|
320
|
+
```python
|
321
|
+
[
|
322
|
+
<<[[input[0]]]>>,
|
323
|
+
<<[[input[1]]]>>
|
324
|
+
]
|
325
|
+
```
|
326
|
+
"""),
|
327
|
+
expected_modalities=3,
|
328
|
+
)
|
329
|
+
|
330
|
+
def test_multiple_queries(self):
|
331
|
+
self.assertEqual(
|
332
|
+
querying.query(
|
333
|
+
'Compute 1 + 2',
|
334
|
+
int,
|
335
|
+
lm=[
|
336
|
+
fake.StaticResponse('1'),
|
337
|
+
fake.StaticResponse('2'),
|
338
|
+
],
|
339
|
+
num_samples=[1, 2],
|
340
|
+
),
|
341
|
+
[1, 2, 2]
|
342
|
+
)
|
343
|
+
self.assertEqual(
|
344
|
+
querying.query(
|
345
|
+
'Compute 1 + 2',
|
346
|
+
int,
|
347
|
+
lm=[
|
348
|
+
fake.StaticResponse('1'),
|
349
|
+
fake.StaticResponse('2'),
|
350
|
+
],
|
351
|
+
num_samples=2,
|
352
|
+
),
|
353
|
+
[1, 1, 2, 2]
|
354
|
+
)
|
355
|
+
self.assertEqual(
|
356
|
+
querying.query(
|
357
|
+
'Compute 1 + 2',
|
358
|
+
int,
|
359
|
+
lm=[
|
360
|
+
fake.StaticResponse('1'),
|
361
|
+
fake.StaticResponse('abc'),
|
362
|
+
],
|
363
|
+
num_samples=[1, 2],
|
364
|
+
),
|
365
|
+
[1]
|
366
|
+
)
|
367
|
+
self.assertEqual(
|
368
|
+
querying.query(
|
369
|
+
'Compute 1 + 2',
|
370
|
+
int,
|
371
|
+
default=0,
|
372
|
+
lm=[
|
373
|
+
fake.StaticResponse('1'),
|
374
|
+
fake.StaticResponse('abc'),
|
375
|
+
],
|
376
|
+
num_samples=[1, 2],
|
377
|
+
),
|
378
|
+
[1, 0, 0]
|
379
|
+
)
|
380
|
+
results = querying.query(
|
381
|
+
'Compute 1 + 2',
|
382
|
+
int,
|
383
|
+
default=0,
|
384
|
+
lm=[
|
385
|
+
fake.StaticResponse('1'),
|
386
|
+
fake.StaticResponse('abc'),
|
387
|
+
],
|
388
|
+
returns_message=True,
|
389
|
+
)
|
390
|
+
self.assertEqual([r.text for r in results], ['1', 'abc'])
|
391
|
+
self.assertEqual([r.result for r in results], [1, 0])
|
392
|
+
|
284
393
|
def test_bad_protocol(self):
|
285
394
|
with self.assertRaisesRegex(ValueError, 'Unknown protocol'):
|
286
|
-
|
395
|
+
querying.query('what is 1 + 1', int, protocol='text')
|
396
|
+
|
397
|
+
def test_query_prompt(self):
|
398
|
+
self.assertEqual(
|
399
|
+
querying.query_prompt('what is this?', int),
|
400
|
+
inspect.cleandoc("""
|
401
|
+
Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT according to OUTPUT_TYPE.
|
402
|
+
|
403
|
+
INPUT_OBJECT:
|
404
|
+
1 + 1 =
|
405
|
+
|
406
|
+
OUTPUT_TYPE:
|
407
|
+
Answer
|
408
|
+
|
409
|
+
```python
|
410
|
+
class Answer:
|
411
|
+
final_answer: int
|
412
|
+
```
|
413
|
+
|
414
|
+
OUTPUT_OBJECT:
|
415
|
+
```python
|
416
|
+
Answer(
|
417
|
+
final_answer=2
|
418
|
+
)
|
419
|
+
```
|
420
|
+
|
421
|
+
INPUT_OBJECT:
|
422
|
+
what is this?
|
423
|
+
|
424
|
+
OUTPUT_TYPE:
|
425
|
+
int
|
426
|
+
|
427
|
+
OUTPUT_OBJECT:
|
428
|
+
"""),
|
429
|
+
)
|
430
|
+
|
431
|
+
def test_query_prompt_with_metadata(self):
|
432
|
+
self.assertIn(
|
433
|
+
'x',
|
434
|
+
querying.query_prompt(
|
435
|
+
'what is this?',
|
436
|
+
metadata_x=1
|
437
|
+
).metadata
|
438
|
+
)
|
439
|
+
self.assertIn(
|
440
|
+
'x',
|
441
|
+
querying.query_prompt(
|
442
|
+
'what is this?',
|
443
|
+
int,
|
444
|
+
metadata_x=1
|
445
|
+
).metadata
|
446
|
+
)
|
447
|
+
|
448
|
+
def test_query_prompt_with_unrooted_template(self):
|
449
|
+
output = querying.query_prompt(
|
450
|
+
pg.Dict(
|
451
|
+
input=lf.Template(
|
452
|
+
'what is {{image}}',
|
453
|
+
image=modalities.Image.from_bytes(b'mock_image')
|
454
|
+
)
|
455
|
+
).input,
|
456
|
+
)
|
457
|
+
self.assertIsNotNone(output.get_modality('image'))
|
458
|
+
|
459
|
+
def test_query_and_reduce(self):
|
460
|
+
self.assertEqual(
|
461
|
+
querying.query_and_reduce(
|
462
|
+
'Compute 1 + 1',
|
463
|
+
int,
|
464
|
+
reduce=sum,
|
465
|
+
lm=[
|
466
|
+
fake.StaticResponse('1'),
|
467
|
+
fake.StaticResponse('2'),
|
468
|
+
],
|
469
|
+
num_samples=[1, 2],
|
470
|
+
),
|
471
|
+
5
|
472
|
+
)
|
473
|
+
self.assertEqual(
|
474
|
+
querying.query_and_reduce(
|
475
|
+
'Compute 1 + 1',
|
476
|
+
int,
|
477
|
+
reduce=sum,
|
478
|
+
lm=fake.StaticResponse('2'),
|
479
|
+
),
|
480
|
+
2
|
481
|
+
)
|
482
|
+
|
483
|
+
def test_query_output(self):
|
484
|
+
self.assertEqual(
|
485
|
+
querying.query_output(
|
486
|
+
lf.AIMessage('1'),
|
487
|
+
int,
|
488
|
+
),
|
489
|
+
1,
|
490
|
+
)
|
491
|
+
|
492
|
+
def test_query_reward(self):
|
493
|
+
|
494
|
+
class Answer(pg.Object):
|
495
|
+
final_answer: int
|
496
|
+
|
497
|
+
def __reward__(self, inputs: lf.Template) -> None:
|
498
|
+
diff = abs(self.final_answer - (inputs.x + inputs.y))
|
499
|
+
# Center screwed sigmoid scaled to [-1.0 and 1.0].
|
500
|
+
return 4 / (1 + math.exp(diff)) - 1.0
|
501
|
+
|
502
|
+
# Case 1: Reward function based on input and output.
|
503
|
+
self.assertEqual(
|
504
|
+
querying.query_reward(
|
505
|
+
mapping.MappingExample(
|
506
|
+
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
507
|
+
schema=Answer,
|
508
|
+
output=Answer(final_answer=2),
|
509
|
+
),
|
510
|
+
'Answer(2)'
|
511
|
+
),
|
512
|
+
1.0
|
513
|
+
)
|
514
|
+
self.assertEqual(
|
515
|
+
querying.query_reward(
|
516
|
+
mapping.MappingExample(
|
517
|
+
input=lf.Template('{{x}} + {{y}}', x=2, y=3),
|
518
|
+
output=Answer(final_answer=2),
|
519
|
+
).to_json_str(),
|
520
|
+
'Answer(5)'
|
521
|
+
),
|
522
|
+
1.0
|
523
|
+
)
|
524
|
+
|
525
|
+
# Case 2: Reward function based on input, result and expected output.
|
526
|
+
class Answer2(pg.Object):
|
527
|
+
final_answer: int
|
528
|
+
|
529
|
+
def __reward__(self, inputs: lf.Template, expected_output: 'Answer2'):
|
530
|
+
return (
|
531
|
+
1.0 if self.final_answer == expected_output.final_answer else -1.0
|
532
|
+
)
|
533
|
+
|
534
|
+
self.assertEqual(
|
535
|
+
querying.query_reward(
|
536
|
+
mapping.MappingExample(
|
537
|
+
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
538
|
+
output=Answer2(final_answer=2),
|
539
|
+
),
|
540
|
+
'Answer2(3)'
|
541
|
+
),
|
542
|
+
-1.0
|
543
|
+
)
|
544
|
+
|
545
|
+
# Case 3: Reward function based on input, result, expected output
|
546
|
+
# and metadata.
|
547
|
+
class Answer3(pg.Object):
|
548
|
+
final_answer: int
|
549
|
+
|
550
|
+
def __reward__(self,
|
551
|
+
inputs: lf.Template,
|
552
|
+
expected_output: 'Answer3',
|
553
|
+
metadata: dict[str, Any]):
|
554
|
+
del inputs
|
555
|
+
return (
|
556
|
+
1.0 if self.final_answer == expected_output.final_answer else -1.0
|
557
|
+
) * metadata['weight']
|
558
|
+
|
559
|
+
self.assertEqual(
|
560
|
+
querying.query_reward(
|
561
|
+
mapping.MappingExample(
|
562
|
+
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
563
|
+
output=Answer3(final_answer=2),
|
564
|
+
metadata=dict(weight=0.5)
|
565
|
+
),
|
566
|
+
'Answer3(3)'
|
567
|
+
),
|
568
|
+
-0.5
|
569
|
+
)
|
570
|
+
|
571
|
+
# Case 4: No reward function is provided.
|
572
|
+
class Answer4(pg.Object):
|
573
|
+
final_answer: int
|
574
|
+
|
575
|
+
self.assertIsNone(
|
576
|
+
querying.query_reward(
|
577
|
+
mapping.MappingExample(
|
578
|
+
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
579
|
+
output=Answer4(final_answer=2),
|
580
|
+
),
|
581
|
+
'Answer2(2)'
|
582
|
+
)
|
583
|
+
)
|
584
|
+
|
585
|
+
# Case 5: Not a structured output.
|
586
|
+
self.assertIsNone(
|
587
|
+
querying.query_reward(
|
588
|
+
mapping.MappingExample(
|
589
|
+
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
590
|
+
output='2',
|
591
|
+
),
|
592
|
+
'2'
|
593
|
+
)
|
594
|
+
)
|
595
|
+
|
596
|
+
# Case 6: Bad reward function.
|
597
|
+
class Answer5(pg.Object):
|
598
|
+
final_answer: int
|
599
|
+
|
600
|
+
def __reward__(self):
|
601
|
+
return 0.0
|
602
|
+
|
603
|
+
with self.assertRaisesRegex(
|
604
|
+
TypeError, '.*Answer5.__reward__` should have signature'
|
605
|
+
):
|
606
|
+
querying.query_reward(
|
607
|
+
mapping.MappingExample(
|
608
|
+
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
609
|
+
output=Answer5(final_answer=2),
|
610
|
+
),
|
611
|
+
'Answer5(2)'
|
612
|
+
)
|
287
613
|
|
288
614
|
|
289
615
|
class QueryStructurePythonTest(unittest.TestCase):
|
290
616
|
|
291
617
|
def test_render_no_examples(self):
|
292
|
-
l =
|
618
|
+
l = querying._QueryStructurePython(
|
293
619
|
input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int
|
294
620
|
)
|
295
621
|
self.assertEqual(
|
@@ -326,7 +652,7 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
326
652
|
)
|
327
653
|
|
328
654
|
def test_render(self):
|
329
|
-
l =
|
655
|
+
l = querying._QueryStructurePython(
|
330
656
|
input=lf.AIMessage('Compute 12 / 6 + 2.'),
|
331
657
|
schema=int,
|
332
658
|
examples=[
|
@@ -436,7 +762,7 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
436
762
|
),
|
437
763
|
override_attrs=True,
|
438
764
|
):
|
439
|
-
l =
|
765
|
+
l = querying._QueryStructurePython(
|
440
766
|
input=lm_input,
|
441
767
|
schema=[Itinerary],
|
442
768
|
examples=[
|
@@ -473,7 +799,7 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
473
799
|
mapping.MappingError,
|
474
800
|
'name .* is not defined',
|
475
801
|
):
|
476
|
-
|
802
|
+
querying.query('Compute 1 + 2', int)
|
477
803
|
|
478
804
|
def test_autofix(self):
|
479
805
|
lm = fake.StaticSequence([
|
@@ -484,7 +810,7 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
484
810
|
)
|
485
811
|
"""),
|
486
812
|
])
|
487
|
-
self.assertEqual(
|
813
|
+
self.assertEqual(querying.query('what is 1 + 0', int, lm=lm, autofix=3), 1)
|
488
814
|
|
489
815
|
def test_response_postprocess(self):
|
490
816
|
with lf.context(
|
@@ -492,12 +818,12 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
492
818
|
override_attrs=True,
|
493
819
|
):
|
494
820
|
self.assertEqual(
|
495
|
-
|
821
|
+
querying.query(
|
496
822
|
'Compute 1 + 2', response_postprocess=lambda x: x.split('\n')[1]),
|
497
823
|
'3'
|
498
824
|
)
|
499
825
|
self.assertEqual(
|
500
|
-
|
826
|
+
querying.query(
|
501
827
|
'Compute 1 + 2', int,
|
502
828
|
response_postprocess=lambda x: x.split('\n')[1]),
|
503
829
|
3
|
@@ -507,7 +833,7 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
507
833
|
class QueryStructureJsonTest(unittest.TestCase):
|
508
834
|
|
509
835
|
def test_render_no_examples(self):
|
510
|
-
l =
|
836
|
+
l = querying._QueryStructureJson(
|
511
837
|
input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int
|
512
838
|
)
|
513
839
|
self.assertEqual(
|
@@ -523,10 +849,10 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
523
849
|
1 + 1 =
|
524
850
|
|
525
851
|
SCHEMA:
|
526
|
-
{"result": {"_type": "langfun.core.structured.
|
852
|
+
{"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}}
|
527
853
|
|
528
854
|
JSON:
|
529
|
-
{"result": {"_type": "langfun.core.structured.
|
855
|
+
{"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}}
|
530
856
|
|
531
857
|
INPUT_OBJECT:
|
532
858
|
Compute 12 / 6 + 2.
|
@@ -539,7 +865,7 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
539
865
|
)
|
540
866
|
|
541
867
|
def test_render(self):
|
542
|
-
l =
|
868
|
+
l = querying._QueryStructureJson(
|
543
869
|
input=lf.AIMessage('Compute 12 / 6 + 2.'),
|
544
870
|
schema=int,
|
545
871
|
examples=[
|
@@ -560,10 +886,10 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
560
886
|
1 + 1 =
|
561
887
|
|
562
888
|
SCHEMA:
|
563
|
-
{"result": {"_type": "langfun.core.structured.
|
889
|
+
{"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}}
|
564
890
|
|
565
891
|
JSON:
|
566
|
-
{"result": {"_type": "langfun.core.structured.
|
892
|
+
{"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}}
|
567
893
|
|
568
894
|
INPUT_OBJECT:
|
569
895
|
What is the answer of 1 plus 1?
|
@@ -674,7 +1000,7 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
674
1000
|
),
|
675
1001
|
override_attrs=True,
|
676
1002
|
):
|
677
|
-
l =
|
1003
|
+
l = querying._QueryStructureJson(
|
678
1004
|
input=lm_input,
|
679
1005
|
schema=[Itinerary],
|
680
1006
|
examples=[
|
@@ -703,22 +1029,114 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
703
1029
|
self.assertIsNone(r.result[0].hotel)
|
704
1030
|
|
705
1031
|
def test_bad_transform(self):
|
706
|
-
with
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
with self.assertRaisesRegex(
|
711
|
-
mapping.MappingError,
|
712
|
-
'No JSON dict in the output',
|
1032
|
+
with in_memory.lm_cache() as cache:
|
1033
|
+
with lf.context(
|
1034
|
+
lm=fake.StaticSequence(['3']),
|
1035
|
+
override_attrs=True,
|
713
1036
|
):
|
714
|
-
|
1037
|
+
with self.assertRaisesRegex(
|
1038
|
+
mapping.MappingError,
|
1039
|
+
'No JSON dict in the output',
|
1040
|
+
):
|
1041
|
+
querying.query('Compute 1 + 2', int, protocol='json', cache_seed=1)
|
1042
|
+
# Make sure bad mapping does not impact cache.
|
1043
|
+
self.assertEqual(len(cache), 0)
|
715
1044
|
|
716
1045
|
def test_query(self):
|
717
1046
|
lm = fake.StaticSequence(['{"result": 1}'])
|
718
1047
|
self.assertEqual(
|
719
|
-
|
1048
|
+
querying.query('what is 1 + 0', int, lm=lm, protocol='json'), 1
|
720
1049
|
)
|
721
1050
|
|
722
1051
|
|
1052
|
+
class QueryInvocationTest(unittest.TestCase):
|
1053
|
+
|
1054
|
+
def test_basics(self):
|
1055
|
+
lm = fake.StaticSequence([
|
1056
|
+
'Activity(description="hi"',
|
1057
|
+
])
|
1058
|
+
with querying.track_queries() as queries:
|
1059
|
+
querying.query('foo', Activity, default=None, lm=lm)
|
1060
|
+
|
1061
|
+
self.assertTrue(queries[0].has_error)
|
1062
|
+
self.assertIsInstance(queries[0].output, mapping.MappingError)
|
1063
|
+
|
1064
|
+
def test_to_html(self):
|
1065
|
+
lm = fake.StaticSequence([
|
1066
|
+
'Activity(description="hi")',
|
1067
|
+
])
|
1068
|
+
with querying.track_queries() as queries:
|
1069
|
+
querying.query('foo', Activity, lm=lm)
|
1070
|
+
|
1071
|
+
self.assertIn('schema', queries[0].to_html_str())
|
1072
|
+
|
1073
|
+
|
1074
|
+
class TrackQueriesTest(unittest.TestCase):
|
1075
|
+
|
1076
|
+
def test_include_child_scopes(self):
|
1077
|
+
lm = fake.StaticSequence([
|
1078
|
+
'bar',
|
1079
|
+
'Activity(description="hi")',
|
1080
|
+
])
|
1081
|
+
with querying.track_queries() as queries:
|
1082
|
+
querying.query('foo', lm=lm)
|
1083
|
+
with querying.track_queries() as child_queries:
|
1084
|
+
querying.query('give me an activity', Activity, lm=lm)
|
1085
|
+
|
1086
|
+
self.assertEqual(len(queries), 2)
|
1087
|
+
self.assertTrue(pg.eq(queries[0].input, lf.Template('foo')))
|
1088
|
+
self.assertIsNone(queries[0].schema)
|
1089
|
+
self.assertEqual(queries[0].output, 'bar')
|
1090
|
+
self.assertIs(queries[0].lm, lm)
|
1091
|
+
|
1092
|
+
self.assertTrue(pg.eq(queries[1].input, lf.Template('give me an activity')))
|
1093
|
+
self.assertEqual(queries[1].schema.spec.cls, Activity)
|
1094
|
+
self.assertTrue(pg.eq(queries[1].output, Activity(description='hi')))
|
1095
|
+
self.assertIs(queries[1].lm, lm)
|
1096
|
+
self.assertGreater(queries[0].elapse, 0)
|
1097
|
+
self.assertGreater(queries[0].usage_summary.total.total_tokens, 0)
|
1098
|
+
self.assertGreater(queries[1].usage_summary.total.total_tokens, 0)
|
1099
|
+
|
1100
|
+
self.assertEqual(len(child_queries), 1)
|
1101
|
+
self.assertIs(child_queries[0], queries[1])
|
1102
|
+
|
1103
|
+
def test_exclude_child_scopes(self):
|
1104
|
+
lm = fake.StaticSequence([
|
1105
|
+
'bar',
|
1106
|
+
'Activity(description="hi")',
|
1107
|
+
])
|
1108
|
+
with querying.track_queries(include_child_scopes=False) as queries:
|
1109
|
+
querying.query('foo', lm=lm)
|
1110
|
+
with querying.track_queries(include_child_scopes=False) as child_queries:
|
1111
|
+
querying.query('give me an activity', Activity, lm=lm)
|
1112
|
+
|
1113
|
+
self.assertEqual(len(queries), 1)
|
1114
|
+
self.assertTrue(pg.eq(queries[0].input, lf.Template('foo')))
|
1115
|
+
self.assertIsNone(queries[0].schema)
|
1116
|
+
self.assertEqual(queries[0].output, 'bar')
|
1117
|
+
self.assertIs(queries[0].lm, lm)
|
1118
|
+
|
1119
|
+
self.assertEqual(len(child_queries), 1)
|
1120
|
+
self.assertTrue(
|
1121
|
+
pg.eq(child_queries[0].input, lf.Template('give me an activity'))
|
1122
|
+
)
|
1123
|
+
self.assertEqual(child_queries[0].schema.spec.cls, Activity)
|
1124
|
+
self.assertTrue(pg.eq(child_queries[0].output, Activity(description='hi')))
|
1125
|
+
self.assertIs(child_queries[0].lm, lm)
|
1126
|
+
|
1127
|
+
def test_concurrent_map(self):
|
1128
|
+
|
1129
|
+
def make_query(prompt):
|
1130
|
+
_ = querying.query(prompt, lm=lm)
|
1131
|
+
|
1132
|
+
lm = fake.StaticSequence([
|
1133
|
+
'foo',
|
1134
|
+
'bar',
|
1135
|
+
])
|
1136
|
+
with querying.track_queries() as queries:
|
1137
|
+
list(lf.concurrent_map(make_query, ['a', 'b']))
|
1138
|
+
self.assertEqual(len(queries), 2)
|
1139
|
+
|
1140
|
+
|
723
1141
|
if __name__ == '__main__':
|
724
1142
|
unittest.main()
|