langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. langfun/__init__.py +20 -2
  2. langfun/core/__init__.py +16 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -21
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +63 -2
  18. langfun/core/component_test.py +53 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +16 -1
  24. langfun/core/eval/base.py +622 -174
  25. langfun/core/eval/base_test.py +200 -54
  26. langfun/core/eval/matching.py +63 -76
  27. langfun/core/eval/matching_test.py +17 -8
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +26 -26
  31. langfun/core/eval/scoring_test.py +19 -2
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +4 -17
  55. langfun/core/langfunc_test.py +22 -6
  56. langfun/core/language_model.py +577 -39
  57. langfun/core/language_model_test.py +470 -56
  58. langfun/core/llms/__init__.py +87 -16
  59. langfun/core/llms/anthropic.py +312 -87
  60. langfun/core/llms/anthropic_test.py +71 -3
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +53 -2
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +11 -7
  69. langfun/core/llms/fake_test.py +14 -0
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -202
  74. langfun/core/llms/groq.py +160 -144
  75. langfun/core/llms/groq_test.py +31 -137
  76. langfun/core/llms/llama_cpp.py +15 -42
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +395 -203
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +30 -395
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -26
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +12 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +7 -6
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +60 -27
  112. langfun/core/structured/function_generation_test.py +72 -2
  113. langfun/core/structured/mapping.py +97 -47
  114. langfun/core/structured/mapping_test.py +90 -2
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +53 -9
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
  119. langfun/core/structured/schema.py +204 -97
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_test.py +130 -29
  122. langfun/core/structured/scoring.py +125 -19
  123. langfun/core/structured/scoring_test.py +30 -0
  124. langfun/core/structured/tokenization.py +64 -0
  125. langfun/core/structured/tokenization_test.py +48 -0
  126. langfun/core/template.py +115 -1
  127. langfun/core/template_test.py +71 -1
  128. langfun/core/templates/conversation.py +9 -0
  129. langfun/core/templates/conversation_test.py +4 -3
  130. langfun/core/templates/selfplay_test.py +10 -2
  131. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  132. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  133. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  134. langfun/core/coding/python/errors.py +0 -108
  135. langfun/core/coding/python/errors_test.py +0 -99
  136. langfun/core/coding/python/permissions.py +0 -90
  137. langfun/core/coding/python/permissions_test.py +0 -86
  138. langfun/core/structured/prompting.py +0 -238
  139. langfun/core/text_formatting.py +0 -162
  140. langfun/core/text_formatting_test.py +0 -47
  141. langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
  142. langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
  143. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  144. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -11,16 +11,19 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- """Tests for structured prompting."""
14
+ """Tests for structured query."""
15
15
 
16
16
  import inspect
17
+ import math
18
+ from typing import Any
17
19
  import unittest
18
20
 
19
21
  import langfun.core as lf
20
22
  from langfun.core import modalities
21
23
  from langfun.core.llms import fake
24
+ from langfun.core.llms.cache import in_memory
22
25
  from langfun.core.structured import mapping
23
- from langfun.core.structured import prompting
26
+ from langfun.core.structured import querying
24
27
  import pyglove as pg
25
28
 
26
29
 
@@ -41,13 +44,17 @@ class QueryTest(unittest.TestCase):
41
44
  self,
42
45
  prompt,
43
46
  schema,
47
+ examples: list[mapping.MappingExample] | None = None,
44
48
  *,
45
49
  expected_snippet: str,
46
50
  exact_match: bool = False,
47
51
  expected_modalities: int = 0,
48
52
  **kwargs,
49
53
  ):
50
- m = prompting.query(prompt, schema=schema, **kwargs, returns_message=True)
54
+ m = querying.query(
55
+ prompt, schema=schema, examples=examples,
56
+ **kwargs, returns_message=True
57
+ )
51
58
  self.assertIsNotNone(m.lm_input)
52
59
  if exact_match:
53
60
  self.assertEqual(expected_snippet, m.lm_input.text)
@@ -60,14 +67,14 @@ class QueryTest(unittest.TestCase):
60
67
 
61
68
  def test_call(self):
62
69
  lm = fake.StaticSequence(['1'])
63
- self.assertEqual(prompting.query('what is 1 + 0', int, lm=lm), 1)
70
+ self.assertEqual(querying.query('what is 1 + 0', int, lm=lm), 1)
64
71
 
65
72
  # Testing calling the same `lm` without copy.
66
73
  with self.assertRaises(IndexError):
67
- prompting.query('what is 1 + 2', int, lm=lm)
74
+ querying.query('what is 1 + 2', int, lm=lm)
68
75
 
69
76
  self.assertEqual(
70
- prompting.query(
77
+ querying.query(
71
78
  'what is 1 + 0', int, lm=lm.clone(), returns_message=True
72
79
  ),
73
80
  lf.AIMessage(
@@ -75,22 +82,23 @@ class QueryTest(unittest.TestCase):
75
82
  result=1,
76
83
  score=1.0,
77
84
  logprobs=None,
85
+ is_cached=False,
78
86
  usage=lf.LMSamplingUsage(323, 1, 324),
79
87
  tags=['lm-response', 'lm-output', 'transformed'],
80
88
  ),
81
89
  )
82
90
  self.assertEqual(
83
- prompting.query(
84
- lf.Template('what is {{x}} + {{y}}'), int, x=1, y=0, lm=lm.clone()
91
+ querying.query(
92
+ lf.Template('what is {{x}} + {{y}}', x=1, y=0), int, lm=lm.clone()
85
93
  ),
86
94
  1,
87
95
  )
88
96
  self.assertEqual(
89
- prompting.query('what is {{x}} + {{y}}', int, x=1, y=0, lm=lm.clone()),
97
+ querying.query('what is {{x}} + {{y}}', int, x=1, y=0, lm=lm.clone()),
90
98
  1,
91
99
  )
92
100
  self.assertEqual(
93
- prompting.query(
101
+ querying.query(
94
102
  'what is {{x}} + {{y}}',
95
103
  x=1,
96
104
  y=0,
@@ -99,7 +107,7 @@ class QueryTest(unittest.TestCase):
99
107
  'The answer is one.',
100
108
  )
101
109
  self.assertEqual(
102
- prompting.query(
110
+ querying.query(
103
111
  Activity.partial(),
104
112
  lm=fake.StaticResponse('Activity(description="hello")'),
105
113
  ),
@@ -208,7 +216,7 @@ class QueryTest(unittest.TestCase):
208
216
  modalities.Image.from_bytes(b'mock_image'),
209
217
  int,
210
218
  lm=lm,
211
- expected_snippet='\n\nINPUT_OBJECT:\n {{input}}\n\n',
219
+ expected_snippet='\n\nINPUT_OBJECT:\n <<[[input]]>>\n\n',
212
220
  expected_modalities=1,
213
221
  )
214
222
 
@@ -218,7 +226,7 @@ class QueryTest(unittest.TestCase):
218
226
  modalities.Image.from_bytes(b'mock_image'),
219
227
  None,
220
228
  lm=lm,
221
- expected_snippet='{{input}}',
229
+ expected_snippet='<<[[input]]>>',
222
230
  exact_match=True,
223
231
  expected_modalities=1,
224
232
  )
@@ -231,7 +239,9 @@ class QueryTest(unittest.TestCase):
231
239
  this_image=modalities.Image.from_bytes(b'cat_image'),
232
240
  that_image=modalities.Image.from_bytes(b'mouse_image'),
233
241
  lm=lm,
234
- expected_snippet='What are these? {{this_image}} and {{that_image}}',
242
+ expected_snippet=(
243
+ 'What are these? <<[[this_image]]>> and <<[[that_image]]>>'
244
+ ),
235
245
  exact_match=True,
236
246
  expected_modalities=2,
237
247
  )
@@ -245,7 +255,7 @@ class QueryTest(unittest.TestCase):
245
255
  ],
246
256
  None,
247
257
  lm=lm,
248
- expected_snippet='`[{{input[0]}}, {{input[1]}}]`',
258
+ expected_snippet='`[<<[[input[0]]]>>, <<[[input[1]]]>>]`',
249
259
  exact_match=True,
250
260
  expected_modalities=2,
251
261
  )
@@ -263,33 +273,349 @@ class QueryTest(unittest.TestCase):
263
273
  INPUT_OBJECT:
264
274
  ```python
265
275
  [
266
- ModalityRef(
267
- name='input[0]'
268
- ),
269
- ModalityRef(
270
- name='input[1]'
271
- )
276
+ <<[[input[0]]]>>,
277
+ <<[[input[1]]]>>
272
278
  ]
273
279
  ```
274
-
275
- MODALITY_REFERENCES:
276
- {
277
- 'input[0]': {{input[0]}},
278
- 'input[1]': {{input[1]}}
279
- }
280
280
  """),
281
281
  expected_modalities=2,
282
282
  )
283
283
 
284
+ def test_structure_with_modality_and_examples_to_structure_render(self):
285
+ lm = fake.StaticResponse('["cat", "mouse"]')
286
+ self.assert_render(
287
+ [
288
+ modalities.Image.from_bytes(b'cat_image'),
289
+ modalities.Image.from_bytes(b'mouse_image'),
290
+ ],
291
+ list[str],
292
+ examples=[
293
+ mapping.MappingExample(
294
+ input=[modalities.Image.from_bytes(b'dog_image')],
295
+ schema=list[str],
296
+ output=['dog'],
297
+ ),
298
+ ],
299
+ lm=lm,
300
+ expected_snippet=inspect.cleandoc("""
301
+ INPUT_OBJECT:
302
+ ```python
303
+ [
304
+ <<[[examples[0].input[0]]]>>
305
+ ]
306
+ ```
307
+
308
+ OUTPUT_TYPE:
309
+ list[str]
310
+
311
+ OUTPUT_OBJECT:
312
+ ```python
313
+ [
314
+ 'dog'
315
+ ]
316
+ ```
317
+
318
+
319
+ INPUT_OBJECT:
320
+ ```python
321
+ [
322
+ <<[[input[0]]]>>,
323
+ <<[[input[1]]]>>
324
+ ]
325
+ ```
326
+ """),
327
+ expected_modalities=3,
328
+ )
329
+
330
+ def test_multiple_queries(self):
331
+ self.assertEqual(
332
+ querying.query(
333
+ 'Compute 1 + 2',
334
+ int,
335
+ lm=[
336
+ fake.StaticResponse('1'),
337
+ fake.StaticResponse('2'),
338
+ ],
339
+ num_samples=[1, 2],
340
+ ),
341
+ [1, 2, 2]
342
+ )
343
+ self.assertEqual(
344
+ querying.query(
345
+ 'Compute 1 + 2',
346
+ int,
347
+ lm=[
348
+ fake.StaticResponse('1'),
349
+ fake.StaticResponse('2'),
350
+ ],
351
+ num_samples=2,
352
+ ),
353
+ [1, 1, 2, 2]
354
+ )
355
+ self.assertEqual(
356
+ querying.query(
357
+ 'Compute 1 + 2',
358
+ int,
359
+ lm=[
360
+ fake.StaticResponse('1'),
361
+ fake.StaticResponse('abc'),
362
+ ],
363
+ num_samples=[1, 2],
364
+ ),
365
+ [1]
366
+ )
367
+ self.assertEqual(
368
+ querying.query(
369
+ 'Compute 1 + 2',
370
+ int,
371
+ default=0,
372
+ lm=[
373
+ fake.StaticResponse('1'),
374
+ fake.StaticResponse('abc'),
375
+ ],
376
+ num_samples=[1, 2],
377
+ ),
378
+ [1, 0, 0]
379
+ )
380
+ results = querying.query(
381
+ 'Compute 1 + 2',
382
+ int,
383
+ default=0,
384
+ lm=[
385
+ fake.StaticResponse('1'),
386
+ fake.StaticResponse('abc'),
387
+ ],
388
+ returns_message=True,
389
+ )
390
+ self.assertEqual([r.text for r in results], ['1', 'abc'])
391
+ self.assertEqual([r.result for r in results], [1, 0])
392
+
284
393
  def test_bad_protocol(self):
285
394
  with self.assertRaisesRegex(ValueError, 'Unknown protocol'):
286
- prompting.query('what is 1 + 1', int, protocol='text')
395
+ querying.query('what is 1 + 1', int, protocol='text')
396
+
397
+ def test_query_prompt(self):
398
+ self.assertEqual(
399
+ querying.query_prompt('what is this?', int),
400
+ inspect.cleandoc("""
401
+ Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT according to OUTPUT_TYPE.
402
+
403
+ INPUT_OBJECT:
404
+ 1 + 1 =
405
+
406
+ OUTPUT_TYPE:
407
+ Answer
408
+
409
+ ```python
410
+ class Answer:
411
+ final_answer: int
412
+ ```
413
+
414
+ OUTPUT_OBJECT:
415
+ ```python
416
+ Answer(
417
+ final_answer=2
418
+ )
419
+ ```
420
+
421
+ INPUT_OBJECT:
422
+ what is this?
423
+
424
+ OUTPUT_TYPE:
425
+ int
426
+
427
+ OUTPUT_OBJECT:
428
+ """),
429
+ )
430
+
431
+ def test_query_prompt_with_metadata(self):
432
+ self.assertIn(
433
+ 'x',
434
+ querying.query_prompt(
435
+ 'what is this?',
436
+ metadata_x=1
437
+ ).metadata
438
+ )
439
+ self.assertIn(
440
+ 'x',
441
+ querying.query_prompt(
442
+ 'what is this?',
443
+ int,
444
+ metadata_x=1
445
+ ).metadata
446
+ )
447
+
448
+ def test_query_prompt_with_unrooted_template(self):
449
+ output = querying.query_prompt(
450
+ pg.Dict(
451
+ input=lf.Template(
452
+ 'what is {{image}}',
453
+ image=modalities.Image.from_bytes(b'mock_image')
454
+ )
455
+ ).input,
456
+ )
457
+ self.assertIsNotNone(output.get_modality('image'))
458
+
459
+ def test_query_and_reduce(self):
460
+ self.assertEqual(
461
+ querying.query_and_reduce(
462
+ 'Compute 1 + 1',
463
+ int,
464
+ reduce=sum,
465
+ lm=[
466
+ fake.StaticResponse('1'),
467
+ fake.StaticResponse('2'),
468
+ ],
469
+ num_samples=[1, 2],
470
+ ),
471
+ 5
472
+ )
473
+ self.assertEqual(
474
+ querying.query_and_reduce(
475
+ 'Compute 1 + 1',
476
+ int,
477
+ reduce=sum,
478
+ lm=fake.StaticResponse('2'),
479
+ ),
480
+ 2
481
+ )
482
+
483
+ def test_query_output(self):
484
+ self.assertEqual(
485
+ querying.query_output(
486
+ lf.AIMessage('1'),
487
+ int,
488
+ ),
489
+ 1,
490
+ )
491
+
492
+ def test_query_reward(self):
493
+
494
+ class Answer(pg.Object):
495
+ final_answer: int
496
+
497
+ def __reward__(self, inputs: lf.Template) -> None:
498
+ diff = abs(self.final_answer - (inputs.x + inputs.y))
499
+ # Center screwed sigmoid scaled to [-1.0 and 1.0].
500
+ return 4 / (1 + math.exp(diff)) - 1.0
501
+
502
+ # Case 1: Reward function based on input and output.
503
+ self.assertEqual(
504
+ querying.query_reward(
505
+ mapping.MappingExample(
506
+ input=lf.Template('{{x}} + {{y}}', x=1, y=1),
507
+ schema=Answer,
508
+ output=Answer(final_answer=2),
509
+ ),
510
+ 'Answer(2)'
511
+ ),
512
+ 1.0
513
+ )
514
+ self.assertEqual(
515
+ querying.query_reward(
516
+ mapping.MappingExample(
517
+ input=lf.Template('{{x}} + {{y}}', x=2, y=3),
518
+ output=Answer(final_answer=2),
519
+ ).to_json_str(),
520
+ 'Answer(5)'
521
+ ),
522
+ 1.0
523
+ )
524
+
525
+ # Case 2: Reward function based on input, result and expected output.
526
+ class Answer2(pg.Object):
527
+ final_answer: int
528
+
529
+ def __reward__(self, inputs: lf.Template, expected_output: 'Answer2'):
530
+ return (
531
+ 1.0 if self.final_answer == expected_output.final_answer else -1.0
532
+ )
533
+
534
+ self.assertEqual(
535
+ querying.query_reward(
536
+ mapping.MappingExample(
537
+ input=lf.Template('{{x}} + {{y}}', x=1, y=1),
538
+ output=Answer2(final_answer=2),
539
+ ),
540
+ 'Answer2(3)'
541
+ ),
542
+ -1.0
543
+ )
544
+
545
+ # Case 3: Reward function based on input, result, expected output
546
+ # and metadata.
547
+ class Answer3(pg.Object):
548
+ final_answer: int
549
+
550
+ def __reward__(self,
551
+ inputs: lf.Template,
552
+ expected_output: 'Answer3',
553
+ metadata: dict[str, Any]):
554
+ del inputs
555
+ return (
556
+ 1.0 if self.final_answer == expected_output.final_answer else -1.0
557
+ ) * metadata['weight']
558
+
559
+ self.assertEqual(
560
+ querying.query_reward(
561
+ mapping.MappingExample(
562
+ input=lf.Template('{{x}} + {{y}}', x=1, y=1),
563
+ output=Answer3(final_answer=2),
564
+ metadata=dict(weight=0.5)
565
+ ),
566
+ 'Answer3(3)'
567
+ ),
568
+ -0.5
569
+ )
570
+
571
+ # Case 4: No reward function is provided.
572
+ class Answer4(pg.Object):
573
+ final_answer: int
574
+
575
+ self.assertIsNone(
576
+ querying.query_reward(
577
+ mapping.MappingExample(
578
+ input=lf.Template('{{x}} + {{y}}', x=1, y=1),
579
+ output=Answer4(final_answer=2),
580
+ ),
581
+ 'Answer2(2)'
582
+ )
583
+ )
584
+
585
+ # Case 5: Not a structured output.
586
+ self.assertIsNone(
587
+ querying.query_reward(
588
+ mapping.MappingExample(
589
+ input=lf.Template('{{x}} + {{y}}', x=1, y=1),
590
+ output='2',
591
+ ),
592
+ '2'
593
+ )
594
+ )
595
+
596
+ # Case 6: Bad reward function.
597
+ class Answer5(pg.Object):
598
+ final_answer: int
599
+
600
+ def __reward__(self):
601
+ return 0.0
602
+
603
+ with self.assertRaisesRegex(
604
+ TypeError, '.*Answer5.__reward__` should have signature'
605
+ ):
606
+ querying.query_reward(
607
+ mapping.MappingExample(
608
+ input=lf.Template('{{x}} + {{y}}', x=1, y=1),
609
+ output=Answer5(final_answer=2),
610
+ ),
611
+ 'Answer5(2)'
612
+ )
287
613
 
288
614
 
289
615
  class QueryStructurePythonTest(unittest.TestCase):
290
616
 
291
617
  def test_render_no_examples(self):
292
- l = prompting.QueryStructurePython(
618
+ l = querying._QueryStructurePython(
293
619
  input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int
294
620
  )
295
621
  self.assertEqual(
@@ -326,7 +652,7 @@ class QueryStructurePythonTest(unittest.TestCase):
326
652
  )
327
653
 
328
654
  def test_render(self):
329
- l = prompting.QueryStructurePython(
655
+ l = querying._QueryStructurePython(
330
656
  input=lf.AIMessage('Compute 12 / 6 + 2.'),
331
657
  schema=int,
332
658
  examples=[
@@ -436,7 +762,7 @@ class QueryStructurePythonTest(unittest.TestCase):
436
762
  ),
437
763
  override_attrs=True,
438
764
  ):
439
- l = prompting.QueryStructurePython(
765
+ l = querying._QueryStructurePython(
440
766
  input=lm_input,
441
767
  schema=[Itinerary],
442
768
  examples=[
@@ -473,7 +799,7 @@ class QueryStructurePythonTest(unittest.TestCase):
473
799
  mapping.MappingError,
474
800
  'name .* is not defined',
475
801
  ):
476
- prompting.query('Compute 1 + 2', int)
802
+ querying.query('Compute 1 + 2', int)
477
803
 
478
804
  def test_autofix(self):
479
805
  lm = fake.StaticSequence([
@@ -484,7 +810,7 @@ class QueryStructurePythonTest(unittest.TestCase):
484
810
  )
485
811
  """),
486
812
  ])
487
- self.assertEqual(prompting.query('what is 1 + 0', int, lm=lm, autofix=3), 1)
813
+ self.assertEqual(querying.query('what is 1 + 0', int, lm=lm, autofix=3), 1)
488
814
 
489
815
  def test_response_postprocess(self):
490
816
  with lf.context(
@@ -492,12 +818,12 @@ class QueryStructurePythonTest(unittest.TestCase):
492
818
  override_attrs=True,
493
819
  ):
494
820
  self.assertEqual(
495
- prompting.query(
821
+ querying.query(
496
822
  'Compute 1 + 2', response_postprocess=lambda x: x.split('\n')[1]),
497
823
  '3'
498
824
  )
499
825
  self.assertEqual(
500
- prompting.query(
826
+ querying.query(
501
827
  'Compute 1 + 2', int,
502
828
  response_postprocess=lambda x: x.split('\n')[1]),
503
829
  3
@@ -507,7 +833,7 @@ class QueryStructurePythonTest(unittest.TestCase):
507
833
  class QueryStructureJsonTest(unittest.TestCase):
508
834
 
509
835
  def test_render_no_examples(self):
510
- l = prompting.QueryStructureJson(
836
+ l = querying._QueryStructureJson(
511
837
  input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int
512
838
  )
513
839
  self.assertEqual(
@@ -523,10 +849,10 @@ class QueryStructureJsonTest(unittest.TestCase):
523
849
  1 + 1 =
524
850
 
525
851
  SCHEMA:
526
- {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": int}}
852
+ {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}}
527
853
 
528
854
  JSON:
529
- {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": 2}}
855
+ {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}}
530
856
 
531
857
  INPUT_OBJECT:
532
858
  Compute 12 / 6 + 2.
@@ -539,7 +865,7 @@ class QueryStructureJsonTest(unittest.TestCase):
539
865
  )
540
866
 
541
867
  def test_render(self):
542
- l = prompting.QueryStructureJson(
868
+ l = querying._QueryStructureJson(
543
869
  input=lf.AIMessage('Compute 12 / 6 + 2.'),
544
870
  schema=int,
545
871
  examples=[
@@ -560,10 +886,10 @@ class QueryStructureJsonTest(unittest.TestCase):
560
886
  1 + 1 =
561
887
 
562
888
  SCHEMA:
563
- {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": int}}
889
+ {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}}
564
890
 
565
891
  JSON:
566
- {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": 2}}
892
+ {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}}
567
893
 
568
894
  INPUT_OBJECT:
569
895
  What is the answer of 1 plus 1?
@@ -674,7 +1000,7 @@ class QueryStructureJsonTest(unittest.TestCase):
674
1000
  ),
675
1001
  override_attrs=True,
676
1002
  ):
677
- l = prompting.QueryStructureJson(
1003
+ l = querying._QueryStructureJson(
678
1004
  input=lm_input,
679
1005
  schema=[Itinerary],
680
1006
  examples=[
@@ -703,22 +1029,114 @@ class QueryStructureJsonTest(unittest.TestCase):
703
1029
  self.assertIsNone(r.result[0].hotel)
704
1030
 
705
1031
  def test_bad_transform(self):
706
- with lf.context(
707
- lm=fake.StaticSequence(['3']),
708
- override_attrs=True,
709
- ):
710
- with self.assertRaisesRegex(
711
- mapping.MappingError,
712
- 'No JSON dict in the output',
1032
+ with in_memory.lm_cache() as cache:
1033
+ with lf.context(
1034
+ lm=fake.StaticSequence(['3']),
1035
+ override_attrs=True,
713
1036
  ):
714
- prompting.query('Compute 1 + 2', int, protocol='json')
1037
+ with self.assertRaisesRegex(
1038
+ mapping.MappingError,
1039
+ 'No JSON dict in the output',
1040
+ ):
1041
+ querying.query('Compute 1 + 2', int, protocol='json', cache_seed=1)
1042
+ # Make sure bad mapping does not impact cache.
1043
+ self.assertEqual(len(cache), 0)
715
1044
 
716
1045
  def test_query(self):
717
1046
  lm = fake.StaticSequence(['{"result": 1}'])
718
1047
  self.assertEqual(
719
- prompting.query('what is 1 + 0', int, lm=lm, protocol='json'), 1
1048
+ querying.query('what is 1 + 0', int, lm=lm, protocol='json'), 1
720
1049
  )
721
1050
 
722
1051
 
1052
+ class QueryInvocationTest(unittest.TestCase):
1053
+
1054
+ def test_basics(self):
1055
+ lm = fake.StaticSequence([
1056
+ 'Activity(description="hi"',
1057
+ ])
1058
+ with querying.track_queries() as queries:
1059
+ querying.query('foo', Activity, default=None, lm=lm)
1060
+
1061
+ self.assertTrue(queries[0].has_error)
1062
+ self.assertIsInstance(queries[0].output, mapping.MappingError)
1063
+
1064
+ def test_to_html(self):
1065
+ lm = fake.StaticSequence([
1066
+ 'Activity(description="hi")',
1067
+ ])
1068
+ with querying.track_queries() as queries:
1069
+ querying.query('foo', Activity, lm=lm)
1070
+
1071
+ self.assertIn('schema', queries[0].to_html_str())
1072
+
1073
+
1074
+ class TrackQueriesTest(unittest.TestCase):
1075
+
1076
+ def test_include_child_scopes(self):
1077
+ lm = fake.StaticSequence([
1078
+ 'bar',
1079
+ 'Activity(description="hi")',
1080
+ ])
1081
+ with querying.track_queries() as queries:
1082
+ querying.query('foo', lm=lm)
1083
+ with querying.track_queries() as child_queries:
1084
+ querying.query('give me an activity', Activity, lm=lm)
1085
+
1086
+ self.assertEqual(len(queries), 2)
1087
+ self.assertTrue(pg.eq(queries[0].input, lf.Template('foo')))
1088
+ self.assertIsNone(queries[0].schema)
1089
+ self.assertEqual(queries[0].output, 'bar')
1090
+ self.assertIs(queries[0].lm, lm)
1091
+
1092
+ self.assertTrue(pg.eq(queries[1].input, lf.Template('give me an activity')))
1093
+ self.assertEqual(queries[1].schema.spec.cls, Activity)
1094
+ self.assertTrue(pg.eq(queries[1].output, Activity(description='hi')))
1095
+ self.assertIs(queries[1].lm, lm)
1096
+ self.assertGreater(queries[0].elapse, 0)
1097
+ self.assertGreater(queries[0].usage_summary.total.total_tokens, 0)
1098
+ self.assertGreater(queries[1].usage_summary.total.total_tokens, 0)
1099
+
1100
+ self.assertEqual(len(child_queries), 1)
1101
+ self.assertIs(child_queries[0], queries[1])
1102
+
1103
+ def test_exclude_child_scopes(self):
1104
+ lm = fake.StaticSequence([
1105
+ 'bar',
1106
+ 'Activity(description="hi")',
1107
+ ])
1108
+ with querying.track_queries(include_child_scopes=False) as queries:
1109
+ querying.query('foo', lm=lm)
1110
+ with querying.track_queries(include_child_scopes=False) as child_queries:
1111
+ querying.query('give me an activity', Activity, lm=lm)
1112
+
1113
+ self.assertEqual(len(queries), 1)
1114
+ self.assertTrue(pg.eq(queries[0].input, lf.Template('foo')))
1115
+ self.assertIsNone(queries[0].schema)
1116
+ self.assertEqual(queries[0].output, 'bar')
1117
+ self.assertIs(queries[0].lm, lm)
1118
+
1119
+ self.assertEqual(len(child_queries), 1)
1120
+ self.assertTrue(
1121
+ pg.eq(child_queries[0].input, lf.Template('give me an activity'))
1122
+ )
1123
+ self.assertEqual(child_queries[0].schema.spec.cls, Activity)
1124
+ self.assertTrue(pg.eq(child_queries[0].output, Activity(description='hi')))
1125
+ self.assertIs(child_queries[0].lm, lm)
1126
+
1127
+ def test_concurrent_map(self):
1128
+
1129
+ def make_query(prompt):
1130
+ _ = querying.query(prompt, lm=lm)
1131
+
1132
+ lm = fake.StaticSequence([
1133
+ 'foo',
1134
+ 'bar',
1135
+ ])
1136
+ with querying.track_queries() as queries:
1137
+ list(lf.concurrent_map(make_query, ['a', 'b']))
1138
+ self.assertEqual(len(queries), 2)
1139
+
1140
+
723
1141
  if __name__ == '__main__':
724
1142
  unittest.main()