langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145) hide show
  1. langfun/__init__.py +22 -2
  2. langfun/core/__init__.py +17 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -28
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +69 -2
  18. langfun/core/component_test.py +54 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +17 -0
  24. langfun/core/eval/base.py +767 -140
  25. langfun/core/eval/base_test.py +238 -53
  26. langfun/core/eval/matching.py +80 -76
  27. langfun/core/eval/matching_test.py +19 -9
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +37 -28
  31. langfun/core/eval/scoring_test.py +21 -3
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +3 -21
  55. langfun/core/langfunc_test.py +26 -8
  56. langfun/core/language_model.py +686 -48
  57. langfun/core/language_model_test.py +681 -44
  58. langfun/core/llms/__init__.py +100 -12
  59. langfun/core/llms/anthropic.py +488 -0
  60. langfun/core/llms/anthropic_test.py +235 -0
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +88 -28
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +39 -26
  69. langfun/core/llms/fake_test.py +136 -11
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -197
  74. langfun/core/llms/groq.py +276 -0
  75. langfun/core/llms/groq_test.py +64 -0
  76. langfun/core/llms/llama_cpp.py +15 -40
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +436 -226
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +35 -174
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -23
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +15 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +9 -8
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +278 -0
  112. langfun/core/structured/function_generation_test.py +399 -0
  113. langfun/core/structured/mapping.py +150 -46
  114. langfun/core/structured/mapping_test.py +105 -0
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +71 -22
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
  119. langfun/core/structured/schema.py +208 -99
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_generation_test.py +2 -2
  122. langfun/core/structured/schema_test.py +133 -34
  123. langfun/core/structured/scoring.py +125 -19
  124. langfun/core/structured/scoring_test.py +30 -0
  125. langfun/core/structured/tokenization.py +64 -0
  126. langfun/core/structured/tokenization_test.py +48 -0
  127. langfun/core/template.py +240 -11
  128. langfun/core/template_test.py +146 -1
  129. langfun/core/templates/conversation.py +9 -0
  130. langfun/core/templates/conversation_test.py +4 -3
  131. langfun/core/templates/selfplay_test.py +14 -2
  132. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  133. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  134. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  135. langfun/core/coding/python/errors.py +0 -108
  136. langfun/core/coding/python/errors_test.py +0 -99
  137. langfun/core/coding/python/permissions.py +0 -90
  138. langfun/core/coding/python/permissions_test.py +0 -86
  139. langfun/core/structured/prompting.py +0 -217
  140. langfun/core/text_formatting.py +0 -162
  141. langfun/core/text_formatting_test.py +0 -47
  142. langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
  143. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  144. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  145. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -11,18 +11,19 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- """Tests for structured prompting."""
14
+ """Tests for structured query."""
15
15
 
16
16
  import inspect
17
+ import math
18
+ from typing import Any
17
19
  import unittest
18
20
 
19
21
  import langfun.core as lf
20
- from langfun.core import coding
21
22
  from langfun.core import modalities
22
23
  from langfun.core.llms import fake
24
+ from langfun.core.llms.cache import in_memory
23
25
  from langfun.core.structured import mapping
24
- from langfun.core.structured import prompting
25
- from langfun.core.structured import schema as schema_lib
26
+ from langfun.core.structured import querying
26
27
  import pyglove as pg
27
28
 
28
29
 
@@ -43,13 +44,17 @@ class QueryTest(unittest.TestCase):
43
44
  self,
44
45
  prompt,
45
46
  schema,
47
+ examples: list[mapping.MappingExample] | None = None,
46
48
  *,
47
49
  expected_snippet: str,
48
50
  exact_match: bool = False,
49
51
  expected_modalities: int = 0,
50
52
  **kwargs,
51
53
  ):
52
- m = prompting.query(prompt, schema=schema, **kwargs, returns_message=True)
54
+ m = querying.query(
55
+ prompt, schema=schema, examples=examples,
56
+ **kwargs, returns_message=True
57
+ )
53
58
  self.assertIsNotNone(m.lm_input)
54
59
  if exact_match:
55
60
  self.assertEqual(expected_snippet, m.lm_input.text)
@@ -62,14 +67,14 @@ class QueryTest(unittest.TestCase):
62
67
 
63
68
  def test_call(self):
64
69
  lm = fake.StaticSequence(['1'])
65
- self.assertEqual(prompting.query('what is 1 + 0', int, lm=lm), 1)
70
+ self.assertEqual(querying.query('what is 1 + 0', int, lm=lm), 1)
66
71
 
67
72
  # Testing calling the same `lm` without copy.
68
73
  with self.assertRaises(IndexError):
69
- prompting.query('what is 1 + 2', int, lm=lm)
74
+ querying.query('what is 1 + 2', int, lm=lm)
70
75
 
71
76
  self.assertEqual(
72
- prompting.query(
77
+ querying.query(
73
78
  'what is 1 + 0', int, lm=lm.clone(), returns_message=True
74
79
  ),
75
80
  lf.AIMessage(
@@ -77,21 +82,23 @@ class QueryTest(unittest.TestCase):
77
82
  result=1,
78
83
  score=1.0,
79
84
  logprobs=None,
85
+ is_cached=False,
86
+ usage=lf.LMSamplingUsage(323, 1, 324),
80
87
  tags=['lm-response', 'lm-output', 'transformed'],
81
88
  ),
82
89
  )
83
90
  self.assertEqual(
84
- prompting.query(
85
- lf.Template('what is {{x}} + {{y}}'), int, x=1, y=0, lm=lm.clone()
91
+ querying.query(
92
+ lf.Template('what is {{x}} + {{y}}', x=1, y=0), int, lm=lm.clone()
86
93
  ),
87
94
  1,
88
95
  )
89
96
  self.assertEqual(
90
- prompting.query('what is {{x}} + {{y}}', int, x=1, y=0, lm=lm.clone()),
97
+ querying.query('what is {{x}} + {{y}}', int, x=1, y=0, lm=lm.clone()),
91
98
  1,
92
99
  )
93
100
  self.assertEqual(
94
- prompting.query(
101
+ querying.query(
95
102
  'what is {{x}} + {{y}}',
96
103
  x=1,
97
104
  y=0,
@@ -100,7 +107,7 @@ class QueryTest(unittest.TestCase):
100
107
  'The answer is one.',
101
108
  )
102
109
  self.assertEqual(
103
- prompting.query(
110
+ querying.query(
104
111
  Activity.partial(),
105
112
  lm=fake.StaticResponse('Activity(description="hello")'),
106
113
  ),
@@ -116,12 +123,59 @@ class QueryTest(unittest.TestCase):
116
123
  y=2,
117
124
  lm=lm.clone(),
118
125
  expected_snippet=(
119
- 'Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT'
120
- ' according to OUTPUT_TYPE.\n\nINPUT_OBJECT:\n 1 + 1'
121
- ' =\n\nOUTPUT_TYPE:\n Answer\n\n ```python\n class Answer:\n '
122
- ' final_answer: int\n ```\n\nOUTPUT_OBJECT:\n ```python\n '
123
- ' Answer(final_answer=2)\n ```\n\nINPUT_OBJECT:\n What is 1 +'
124
- ' 2?\n\nOUTPUT_TYPE:\n int\n\nOUTPUT_OBJECT:'
126
+ 'Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT '
127
+ 'according to OUTPUT_TYPE.\n\n'
128
+ 'INPUT_OBJECT:\n 1 + 1 =\n\n'
129
+ 'OUTPUT_TYPE:\n'
130
+ ' Answer\n\n'
131
+ ' ```python\n'
132
+ ' class Answer:\n'
133
+ ' final_answer: int\n'
134
+ ' ```\n\n'
135
+ 'OUTPUT_OBJECT:\n'
136
+ ' ```python\n'
137
+ ' Answer(\n'
138
+ ' final_answer=2\n'
139
+ ' )\n'
140
+ ' ```\n\n'
141
+ 'INPUT_OBJECT:\n'
142
+ ' What is 1 + 2?\n\n'
143
+ 'OUTPUT_TYPE:\n'
144
+ ' int\n\n'
145
+ 'OUTPUT_OBJECT:'
146
+ ),
147
+ )
148
+
149
+ def test_str_to_structure_render_custom_template(self):
150
+ lm = fake.StaticResponse('1')
151
+ self.assert_render(
152
+ 'What is {{x}} + {{y}}?',
153
+ int,
154
+ x=1,
155
+ y=2,
156
+ lm=lm.clone(),
157
+ template_str='!!{{ DEFAULT }}!!',
158
+ expected_snippet=(
159
+ '!!Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT '
160
+ 'according to OUTPUT_TYPE.\n\n'
161
+ 'INPUT_OBJECT:\n 1 + 1 =\n\n'
162
+ 'OUTPUT_TYPE:\n'
163
+ ' Answer\n\n'
164
+ ' ```python\n'
165
+ ' class Answer:\n'
166
+ ' final_answer: int\n'
167
+ ' ```\n\n'
168
+ 'OUTPUT_OBJECT:\n'
169
+ ' ```python\n'
170
+ ' Answer(\n'
171
+ ' final_answer=2\n'
172
+ ' )\n'
173
+ ' ```\n\n'
174
+ 'INPUT_OBJECT:\n'
175
+ ' What is 1 + 2?\n\n'
176
+ 'OUTPUT_TYPE:\n'
177
+ ' int\n\n'
178
+ 'OUTPUT_OBJECT:!!'
125
179
  ),
126
180
  )
127
181
 
@@ -162,7 +216,7 @@ class QueryTest(unittest.TestCase):
162
216
  modalities.Image.from_bytes(b'mock_image'),
163
217
  int,
164
218
  lm=lm,
165
- expected_snippet='\n\nINPUT_OBJECT:\n {{input}}\n\n',
219
+ expected_snippet='\n\nINPUT_OBJECT:\n <<[[input]]>>\n\n',
166
220
  expected_modalities=1,
167
221
  )
168
222
 
@@ -172,7 +226,7 @@ class QueryTest(unittest.TestCase):
172
226
  modalities.Image.from_bytes(b'mock_image'),
173
227
  None,
174
228
  lm=lm,
175
- expected_snippet='{{input}}',
229
+ expected_snippet='<<[[input]]>>',
176
230
  exact_match=True,
177
231
  expected_modalities=1,
178
232
  )
@@ -185,7 +239,9 @@ class QueryTest(unittest.TestCase):
185
239
  this_image=modalities.Image.from_bytes(b'cat_image'),
186
240
  that_image=modalities.Image.from_bytes(b'mouse_image'),
187
241
  lm=lm,
188
- expected_snippet='What are these? {{this_image}} and {{that_image}}',
242
+ expected_snippet=(
243
+ 'What are these? <<[[this_image]]>> and <<[[that_image]]>>'
244
+ ),
189
245
  exact_match=True,
190
246
  expected_modalities=2,
191
247
  )
@@ -199,7 +255,7 @@ class QueryTest(unittest.TestCase):
199
255
  ],
200
256
  None,
201
257
  lm=lm,
202
- expected_snippet='`[{{input[0]}}, {{input[1]}}]`',
258
+ expected_snippet='`[<<[[input[0]]]>>, <<[[input[1]]]>>]`',
203
259
  exact_match=True,
204
260
  expected_modalities=2,
205
261
  )
@@ -217,33 +273,349 @@ class QueryTest(unittest.TestCase):
217
273
  INPUT_OBJECT:
218
274
  ```python
219
275
  [
220
- ModalityRef(
221
- name='input[0]'
222
- ),
223
- ModalityRef(
224
- name='input[1]'
225
- )
276
+ <<[[input[0]]]>>,
277
+ <<[[input[1]]]>>
226
278
  ]
227
279
  ```
228
-
229
- MODALITY_REFERENCES:
230
- {
231
- 'input[0]': {{input[0]}},
232
- 'input[1]': {{input[1]}}
233
- }
234
280
  """),
235
281
  expected_modalities=2,
236
282
  )
237
283
 
284
+ def test_structure_with_modality_and_examples_to_structure_render(self):
285
+ lm = fake.StaticResponse('["cat", "mouse"]')
286
+ self.assert_render(
287
+ [
288
+ modalities.Image.from_bytes(b'cat_image'),
289
+ modalities.Image.from_bytes(b'mouse_image'),
290
+ ],
291
+ list[str],
292
+ examples=[
293
+ mapping.MappingExample(
294
+ input=[modalities.Image.from_bytes(b'dog_image')],
295
+ schema=list[str],
296
+ output=['dog'],
297
+ ),
298
+ ],
299
+ lm=lm,
300
+ expected_snippet=inspect.cleandoc("""
301
+ INPUT_OBJECT:
302
+ ```python
303
+ [
304
+ <<[[examples[0].input[0]]]>>
305
+ ]
306
+ ```
307
+
308
+ OUTPUT_TYPE:
309
+ list[str]
310
+
311
+ OUTPUT_OBJECT:
312
+ ```python
313
+ [
314
+ 'dog'
315
+ ]
316
+ ```
317
+
318
+
319
+ INPUT_OBJECT:
320
+ ```python
321
+ [
322
+ <<[[input[0]]]>>,
323
+ <<[[input[1]]]>>
324
+ ]
325
+ ```
326
+ """),
327
+ expected_modalities=3,
328
+ )
329
+
330
+ def test_multiple_queries(self):
331
+ self.assertEqual(
332
+ querying.query(
333
+ 'Compute 1 + 2',
334
+ int,
335
+ lm=[
336
+ fake.StaticResponse('1'),
337
+ fake.StaticResponse('2'),
338
+ ],
339
+ num_samples=[1, 2],
340
+ ),
341
+ [1, 2, 2]
342
+ )
343
+ self.assertEqual(
344
+ querying.query(
345
+ 'Compute 1 + 2',
346
+ int,
347
+ lm=[
348
+ fake.StaticResponse('1'),
349
+ fake.StaticResponse('2'),
350
+ ],
351
+ num_samples=2,
352
+ ),
353
+ [1, 1, 2, 2]
354
+ )
355
+ self.assertEqual(
356
+ querying.query(
357
+ 'Compute 1 + 2',
358
+ int,
359
+ lm=[
360
+ fake.StaticResponse('1'),
361
+ fake.StaticResponse('abc'),
362
+ ],
363
+ num_samples=[1, 2],
364
+ ),
365
+ [1]
366
+ )
367
+ self.assertEqual(
368
+ querying.query(
369
+ 'Compute 1 + 2',
370
+ int,
371
+ default=0,
372
+ lm=[
373
+ fake.StaticResponse('1'),
374
+ fake.StaticResponse('abc'),
375
+ ],
376
+ num_samples=[1, 2],
377
+ ),
378
+ [1, 0, 0]
379
+ )
380
+ results = querying.query(
381
+ 'Compute 1 + 2',
382
+ int,
383
+ default=0,
384
+ lm=[
385
+ fake.StaticResponse('1'),
386
+ fake.StaticResponse('abc'),
387
+ ],
388
+ returns_message=True,
389
+ )
390
+ self.assertEqual([r.text for r in results], ['1', 'abc'])
391
+ self.assertEqual([r.result for r in results], [1, 0])
392
+
238
393
  def test_bad_protocol(self):
239
394
  with self.assertRaisesRegex(ValueError, 'Unknown protocol'):
240
- prompting.query('what is 1 + 1', int, protocol='text')
395
+ querying.query('what is 1 + 1', int, protocol='text')
396
+
397
+ def test_query_prompt(self):
398
+ self.assertEqual(
399
+ querying.query_prompt('what is this?', int),
400
+ inspect.cleandoc("""
401
+ Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT according to OUTPUT_TYPE.
402
+
403
+ INPUT_OBJECT:
404
+ 1 + 1 =
405
+
406
+ OUTPUT_TYPE:
407
+ Answer
408
+
409
+ ```python
410
+ class Answer:
411
+ final_answer: int
412
+ ```
413
+
414
+ OUTPUT_OBJECT:
415
+ ```python
416
+ Answer(
417
+ final_answer=2
418
+ )
419
+ ```
420
+
421
+ INPUT_OBJECT:
422
+ what is this?
423
+
424
+ OUTPUT_TYPE:
425
+ int
426
+
427
+ OUTPUT_OBJECT:
428
+ """),
429
+ )
430
+
431
+ def test_query_prompt_with_metadata(self):
432
+ self.assertIn(
433
+ 'x',
434
+ querying.query_prompt(
435
+ 'what is this?',
436
+ metadata_x=1
437
+ ).metadata
438
+ )
439
+ self.assertIn(
440
+ 'x',
441
+ querying.query_prompt(
442
+ 'what is this?',
443
+ int,
444
+ metadata_x=1
445
+ ).metadata
446
+ )
447
+
448
+ def test_query_prompt_with_unrooted_template(self):
449
+ output = querying.query_prompt(
450
+ pg.Dict(
451
+ input=lf.Template(
452
+ 'what is {{image}}',
453
+ image=modalities.Image.from_bytes(b'mock_image')
454
+ )
455
+ ).input,
456
+ )
457
+ self.assertIsNotNone(output.get_modality('image'))
458
+
459
+ def test_query_and_reduce(self):
460
+ self.assertEqual(
461
+ querying.query_and_reduce(
462
+ 'Compute 1 + 1',
463
+ int,
464
+ reduce=sum,
465
+ lm=[
466
+ fake.StaticResponse('1'),
467
+ fake.StaticResponse('2'),
468
+ ],
469
+ num_samples=[1, 2],
470
+ ),
471
+ 5
472
+ )
473
+ self.assertEqual(
474
+ querying.query_and_reduce(
475
+ 'Compute 1 + 1',
476
+ int,
477
+ reduce=sum,
478
+ lm=fake.StaticResponse('2'),
479
+ ),
480
+ 2
481
+ )
482
+
483
+ def test_query_output(self):
484
+ self.assertEqual(
485
+ querying.query_output(
486
+ lf.AIMessage('1'),
487
+ int,
488
+ ),
489
+ 1,
490
+ )
491
+
492
+ def test_query_reward(self):
493
+
494
+ class Answer(pg.Object):
495
+ final_answer: int
496
+
497
+ def __reward__(self, inputs: lf.Template) -> None:
498
+ diff = abs(self.final_answer - (inputs.x + inputs.y))
499
+ # Center screwed sigmoid scaled to [-1.0 and 1.0].
500
+ return 4 / (1 + math.exp(diff)) - 1.0
501
+
502
+ # Case 1: Reward function based on input and output.
503
+ self.assertEqual(
504
+ querying.query_reward(
505
+ mapping.MappingExample(
506
+ input=lf.Template('{{x}} + {{y}}', x=1, y=1),
507
+ schema=Answer,
508
+ output=Answer(final_answer=2),
509
+ ),
510
+ 'Answer(2)'
511
+ ),
512
+ 1.0
513
+ )
514
+ self.assertEqual(
515
+ querying.query_reward(
516
+ mapping.MappingExample(
517
+ input=lf.Template('{{x}} + {{y}}', x=2, y=3),
518
+ output=Answer(final_answer=2),
519
+ ).to_json_str(),
520
+ 'Answer(5)'
521
+ ),
522
+ 1.0
523
+ )
524
+
525
+ # Case 2: Reward function based on input, result and expected output.
526
+ class Answer2(pg.Object):
527
+ final_answer: int
528
+
529
+ def __reward__(self, inputs: lf.Template, expected_output: 'Answer2'):
530
+ return (
531
+ 1.0 if self.final_answer == expected_output.final_answer else -1.0
532
+ )
533
+
534
+ self.assertEqual(
535
+ querying.query_reward(
536
+ mapping.MappingExample(
537
+ input=lf.Template('{{x}} + {{y}}', x=1, y=1),
538
+ output=Answer2(final_answer=2),
539
+ ),
540
+ 'Answer2(3)'
541
+ ),
542
+ -1.0
543
+ )
544
+
545
+ # Case 3: Reward function based on input, result, expected output
546
+ # and metadata.
547
+ class Answer3(pg.Object):
548
+ final_answer: int
549
+
550
+ def __reward__(self,
551
+ inputs: lf.Template,
552
+ expected_output: 'Answer3',
553
+ metadata: dict[str, Any]):
554
+ del inputs
555
+ return (
556
+ 1.0 if self.final_answer == expected_output.final_answer else -1.0
557
+ ) * metadata['weight']
558
+
559
+ self.assertEqual(
560
+ querying.query_reward(
561
+ mapping.MappingExample(
562
+ input=lf.Template('{{x}} + {{y}}', x=1, y=1),
563
+ output=Answer3(final_answer=2),
564
+ metadata=dict(weight=0.5)
565
+ ),
566
+ 'Answer3(3)'
567
+ ),
568
+ -0.5
569
+ )
570
+
571
+ # Case 4: No reward function is provided.
572
+ class Answer4(pg.Object):
573
+ final_answer: int
574
+
575
+ self.assertIsNone(
576
+ querying.query_reward(
577
+ mapping.MappingExample(
578
+ input=lf.Template('{{x}} + {{y}}', x=1, y=1),
579
+ output=Answer4(final_answer=2),
580
+ ),
581
+ 'Answer2(2)'
582
+ )
583
+ )
584
+
585
+ # Case 5: Not a structured output.
586
+ self.assertIsNone(
587
+ querying.query_reward(
588
+ mapping.MappingExample(
589
+ input=lf.Template('{{x}} + {{y}}', x=1, y=1),
590
+ output='2',
591
+ ),
592
+ '2'
593
+ )
594
+ )
595
+
596
+ # Case 6: Bad reward function.
597
+ class Answer5(pg.Object):
598
+ final_answer: int
599
+
600
+ def __reward__(self):
601
+ return 0.0
602
+
603
+ with self.assertRaisesRegex(
604
+ TypeError, '.*Answer5.__reward__` should have signature'
605
+ ):
606
+ querying.query_reward(
607
+ mapping.MappingExample(
608
+ input=lf.Template('{{x}} + {{y}}', x=1, y=1),
609
+ output=Answer5(final_answer=2),
610
+ ),
611
+ 'Answer5(2)'
612
+ )
241
613
 
242
614
 
243
615
  class QueryStructurePythonTest(unittest.TestCase):
244
616
 
245
617
  def test_render_no_examples(self):
246
- l = prompting.QueryStructurePython(
618
+ l = querying._QueryStructurePython(
247
619
  input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int
248
620
  )
249
621
  self.assertEqual(
@@ -264,7 +636,9 @@ class QueryStructurePythonTest(unittest.TestCase):
264
636
 
265
637
  OUTPUT_OBJECT:
266
638
  ```python
267
- Answer(final_answer=2)
639
+ Answer(
640
+ final_answer=2
641
+ )
268
642
  ```
269
643
 
270
644
  INPUT_OBJECT:
@@ -278,7 +652,7 @@ class QueryStructurePythonTest(unittest.TestCase):
278
652
  )
279
653
 
280
654
  def test_render(self):
281
- l = prompting.QueryStructurePython(
655
+ l = querying._QueryStructurePython(
282
656
  input=lf.AIMessage('Compute 12 / 6 + 2.'),
283
657
  schema=int,
284
658
  examples=[
@@ -308,7 +682,9 @@ class QueryStructurePythonTest(unittest.TestCase):
308
682
 
309
683
  OUTPUT_OBJECT:
310
684
  ```python
311
- Answer(final_answer=2)
685
+ Answer(
686
+ final_answer=2
687
+ )
312
688
  ```
313
689
 
314
690
  INPUT_OBJECT:
@@ -386,7 +762,7 @@ class QueryStructurePythonTest(unittest.TestCase):
386
762
  ),
387
763
  override_attrs=True,
388
764
  ):
389
- l = prompting.QueryStructurePython(
765
+ l = querying._QueryStructurePython(
390
766
  input=lm_input,
391
767
  schema=[Itinerary],
392
768
  examples=[
@@ -420,10 +796,10 @@ class QueryStructurePythonTest(unittest.TestCase):
420
796
  override_attrs=True,
421
797
  ):
422
798
  with self.assertRaisesRegex(
423
- coding.CodeError,
799
+ mapping.MappingError,
424
800
  'name .* is not defined',
425
801
  ):
426
- prompting.query('Compute 1 + 2', int)
802
+ querying.query('Compute 1 + 2', int)
427
803
 
428
804
  def test_autofix(self):
429
805
  lm = fake.StaticSequence([
@@ -434,13 +810,30 @@ class QueryStructurePythonTest(unittest.TestCase):
434
810
  )
435
811
  """),
436
812
  ])
437
- self.assertEqual(prompting.query('what is 1 + 0', int, lm=lm, autofix=3), 1)
813
+ self.assertEqual(querying.query('what is 1 + 0', int, lm=lm, autofix=3), 1)
814
+
815
+ def test_response_postprocess(self):
816
+ with lf.context(
817
+ lm=fake.StaticResponse('<!-- some comment-->\n3'),
818
+ override_attrs=True,
819
+ ):
820
+ self.assertEqual(
821
+ querying.query(
822
+ 'Compute 1 + 2', response_postprocess=lambda x: x.split('\n')[1]),
823
+ '3'
824
+ )
825
+ self.assertEqual(
826
+ querying.query(
827
+ 'Compute 1 + 2', int,
828
+ response_postprocess=lambda x: x.split('\n')[1]),
829
+ 3
830
+ )
438
831
 
439
832
 
440
833
  class QueryStructureJsonTest(unittest.TestCase):
441
834
 
442
835
  def test_render_no_examples(self):
443
- l = prompting.QueryStructureJson(
836
+ l = querying._QueryStructureJson(
444
837
  input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int
445
838
  )
446
839
  self.assertEqual(
@@ -456,10 +849,10 @@ class QueryStructureJsonTest(unittest.TestCase):
456
849
  1 + 1 =
457
850
 
458
851
  SCHEMA:
459
- {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": int}}
852
+ {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}}
460
853
 
461
854
  JSON:
462
- {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": 2}}
855
+ {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}}
463
856
 
464
857
  INPUT_OBJECT:
465
858
  Compute 12 / 6 + 2.
@@ -472,7 +865,7 @@ class QueryStructureJsonTest(unittest.TestCase):
472
865
  )
473
866
 
474
867
  def test_render(self):
475
- l = prompting.QueryStructureJson(
868
+ l = querying._QueryStructureJson(
476
869
  input=lf.AIMessage('Compute 12 / 6 + 2.'),
477
870
  schema=int,
478
871
  examples=[
@@ -493,10 +886,10 @@ class QueryStructureJsonTest(unittest.TestCase):
493
886
  1 + 1 =
494
887
 
495
888
  SCHEMA:
496
- {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": int}}
889
+ {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}}
497
890
 
498
891
  JSON:
499
- {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": 2}}
892
+ {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}}
500
893
 
501
894
  INPUT_OBJECT:
502
895
  What is the answer of 1 plus 1?
@@ -607,7 +1000,7 @@ class QueryStructureJsonTest(unittest.TestCase):
607
1000
  ),
608
1001
  override_attrs=True,
609
1002
  ):
610
- l = prompting.QueryStructureJson(
1003
+ l = querying._QueryStructureJson(
611
1004
  input=lm_input,
612
1005
  schema=[Itinerary],
613
1006
  examples=[
@@ -636,22 +1029,114 @@ class QueryStructureJsonTest(unittest.TestCase):
636
1029
  self.assertIsNone(r.result[0].hotel)
637
1030
 
638
1031
  def test_bad_transform(self):
639
- with lf.context(
640
- lm=fake.StaticSequence(['3']),
641
- override_attrs=True,
642
- ):
643
- with self.assertRaisesRegex(
644
- schema_lib.JsonError,
645
- 'No JSON dict in the output',
1032
+ with in_memory.lm_cache() as cache:
1033
+ with lf.context(
1034
+ lm=fake.StaticSequence(['3']),
1035
+ override_attrs=True,
646
1036
  ):
647
- prompting.query('Compute 1 + 2', int, protocol='json')
1037
+ with self.assertRaisesRegex(
1038
+ mapping.MappingError,
1039
+ 'No JSON dict in the output',
1040
+ ):
1041
+ querying.query('Compute 1 + 2', int, protocol='json', cache_seed=1)
1042
+ # Make sure bad mapping does not impact cache.
1043
+ self.assertEqual(len(cache), 0)
648
1044
 
649
1045
  def test_query(self):
650
1046
  lm = fake.StaticSequence(['{"result": 1}'])
651
1047
  self.assertEqual(
652
- prompting.query('what is 1 + 0', int, lm=lm, protocol='json'), 1
1048
+ querying.query('what is 1 + 0', int, lm=lm, protocol='json'), 1
653
1049
  )
654
1050
 
655
1051
 
1052
+ class QueryInvocationTest(unittest.TestCase):
1053
+
1054
+ def test_basics(self):
1055
+ lm = fake.StaticSequence([
1056
+ 'Activity(description="hi"',
1057
+ ])
1058
+ with querying.track_queries() as queries:
1059
+ querying.query('foo', Activity, default=None, lm=lm)
1060
+
1061
+ self.assertTrue(queries[0].has_error)
1062
+ self.assertIsInstance(queries[0].output, mapping.MappingError)
1063
+
1064
+ def test_to_html(self):
1065
+ lm = fake.StaticSequence([
1066
+ 'Activity(description="hi")',
1067
+ ])
1068
+ with querying.track_queries() as queries:
1069
+ querying.query('foo', Activity, lm=lm)
1070
+
1071
+ self.assertIn('schema', queries[0].to_html_str())
1072
+
1073
+
1074
+ class TrackQueriesTest(unittest.TestCase):
1075
+
1076
+ def test_include_child_scopes(self):
1077
+ lm = fake.StaticSequence([
1078
+ 'bar',
1079
+ 'Activity(description="hi")',
1080
+ ])
1081
+ with querying.track_queries() as queries:
1082
+ querying.query('foo', lm=lm)
1083
+ with querying.track_queries() as child_queries:
1084
+ querying.query('give me an activity', Activity, lm=lm)
1085
+
1086
+ self.assertEqual(len(queries), 2)
1087
+ self.assertTrue(pg.eq(queries[0].input, lf.Template('foo')))
1088
+ self.assertIsNone(queries[0].schema)
1089
+ self.assertEqual(queries[0].output, 'bar')
1090
+ self.assertIs(queries[0].lm, lm)
1091
+
1092
+ self.assertTrue(pg.eq(queries[1].input, lf.Template('give me an activity')))
1093
+ self.assertEqual(queries[1].schema.spec.cls, Activity)
1094
+ self.assertTrue(pg.eq(queries[1].output, Activity(description='hi')))
1095
+ self.assertIs(queries[1].lm, lm)
1096
+ self.assertGreater(queries[0].elapse, 0)
1097
+ self.assertGreater(queries[0].usage_summary.total.total_tokens, 0)
1098
+ self.assertGreater(queries[1].usage_summary.total.total_tokens, 0)
1099
+
1100
+ self.assertEqual(len(child_queries), 1)
1101
+ self.assertIs(child_queries[0], queries[1])
1102
+
1103
+ def test_exclude_child_scopes(self):
1104
+ lm = fake.StaticSequence([
1105
+ 'bar',
1106
+ 'Activity(description="hi")',
1107
+ ])
1108
+ with querying.track_queries(include_child_scopes=False) as queries:
1109
+ querying.query('foo', lm=lm)
1110
+ with querying.track_queries(include_child_scopes=False) as child_queries:
1111
+ querying.query('give me an activity', Activity, lm=lm)
1112
+
1113
+ self.assertEqual(len(queries), 1)
1114
+ self.assertTrue(pg.eq(queries[0].input, lf.Template('foo')))
1115
+ self.assertIsNone(queries[0].schema)
1116
+ self.assertEqual(queries[0].output, 'bar')
1117
+ self.assertIs(queries[0].lm, lm)
1118
+
1119
+ self.assertEqual(len(child_queries), 1)
1120
+ self.assertTrue(
1121
+ pg.eq(child_queries[0].input, lf.Template('give me an activity'))
1122
+ )
1123
+ self.assertEqual(child_queries[0].schema.spec.cls, Activity)
1124
+ self.assertTrue(pg.eq(child_queries[0].output, Activity(description='hi')))
1125
+ self.assertIs(child_queries[0].lm, lm)
1126
+
1127
+ def test_concurrent_map(self):
1128
+
1129
+ def make_query(prompt):
1130
+ _ = querying.query(prompt, lm=lm)
1131
+
1132
+ lm = fake.StaticSequence([
1133
+ 'foo',
1134
+ 'bar',
1135
+ ])
1136
+ with querying.track_queries() as queries:
1137
+ list(lf.concurrent_map(make_query, ['a', 'b']))
1138
+ self.assertEqual(len(queries), 2)
1139
+
1140
+
656
1141
  if __name__ == '__main__':
657
1142
  unittest.main()