langfun 0.1.2.dev202412150804__py3-none-any.whl → 0.1.2.dev202412180804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,7 +11,7 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- """Tests for structured prompting."""
14
+ """Tests for structured query."""
15
15
 
16
16
  import inspect
17
17
  import math
@@ -23,7 +23,7 @@ from langfun.core import modalities
23
23
  from langfun.core.llms import fake
24
24
  from langfun.core.llms.cache import in_memory
25
25
  from langfun.core.structured import mapping
26
- from langfun.core.structured import prompting
26
+ from langfun.core.structured import querying
27
27
  import pyglove as pg
28
28
 
29
29
 
@@ -51,7 +51,7 @@ class QueryTest(unittest.TestCase):
51
51
  expected_modalities: int = 0,
52
52
  **kwargs,
53
53
  ):
54
- m = prompting.query(
54
+ m = querying.query(
55
55
  prompt, schema=schema, examples=examples,
56
56
  **kwargs, returns_message=True
57
57
  )
@@ -67,14 +67,14 @@ class QueryTest(unittest.TestCase):
67
67
 
68
68
  def test_call(self):
69
69
  lm = fake.StaticSequence(['1'])
70
- self.assertEqual(prompting.query('what is 1 + 0', int, lm=lm), 1)
70
+ self.assertEqual(querying.query('what is 1 + 0', int, lm=lm), 1)
71
71
 
72
72
  # Testing calling the same `lm` without copy.
73
73
  with self.assertRaises(IndexError):
74
- prompting.query('what is 1 + 2', int, lm=lm)
74
+ querying.query('what is 1 + 2', int, lm=lm)
75
75
 
76
76
  self.assertEqual(
77
- prompting.query(
77
+ querying.query(
78
78
  'what is 1 + 0', int, lm=lm.clone(), returns_message=True
79
79
  ),
80
80
  lf.AIMessage(
@@ -88,17 +88,17 @@ class QueryTest(unittest.TestCase):
88
88
  ),
89
89
  )
90
90
  self.assertEqual(
91
- prompting.query(
91
+ querying.query(
92
92
  lf.Template('what is {{x}} + {{y}}', x=1, y=0), int, lm=lm.clone()
93
93
  ),
94
94
  1,
95
95
  )
96
96
  self.assertEqual(
97
- prompting.query('what is {{x}} + {{y}}', int, x=1, y=0, lm=lm.clone()),
97
+ querying.query('what is {{x}} + {{y}}', int, x=1, y=0, lm=lm.clone()),
98
98
  1,
99
99
  )
100
100
  self.assertEqual(
101
- prompting.query(
101
+ querying.query(
102
102
  'what is {{x}} + {{y}}',
103
103
  x=1,
104
104
  y=0,
@@ -107,7 +107,7 @@ class QueryTest(unittest.TestCase):
107
107
  'The answer is one.',
108
108
  )
109
109
  self.assertEqual(
110
- prompting.query(
110
+ querying.query(
111
111
  Activity.partial(),
112
112
  lm=fake.StaticResponse('Activity(description="hello")'),
113
113
  ),
@@ -327,13 +327,76 @@ class QueryTest(unittest.TestCase):
327
327
  expected_modalities=3,
328
328
  )
329
329
 
330
+ def test_multiple_queries(self):
331
+ self.assertEqual(
332
+ querying.query(
333
+ 'Compute 1 + 2',
334
+ int,
335
+ lm=[
336
+ fake.StaticResponse('1'),
337
+ fake.StaticResponse('2'),
338
+ ],
339
+ num_samples=[1, 2],
340
+ ),
341
+ [1, 2, 2]
342
+ )
343
+ self.assertEqual(
344
+ querying.query(
345
+ 'Compute 1 + 2',
346
+ int,
347
+ lm=[
348
+ fake.StaticResponse('1'),
349
+ fake.StaticResponse('2'),
350
+ ],
351
+ num_samples=2,
352
+ ),
353
+ [1, 1, 2, 2]
354
+ )
355
+ self.assertEqual(
356
+ querying.query(
357
+ 'Compute 1 + 2',
358
+ int,
359
+ lm=[
360
+ fake.StaticResponse('1'),
361
+ fake.StaticResponse('abc'),
362
+ ],
363
+ num_samples=[1, 2],
364
+ ),
365
+ [1]
366
+ )
367
+ self.assertEqual(
368
+ querying.query(
369
+ 'Compute 1 + 2',
370
+ int,
371
+ default=0,
372
+ lm=[
373
+ fake.StaticResponse('1'),
374
+ fake.StaticResponse('abc'),
375
+ ],
376
+ num_samples=[1, 2],
377
+ ),
378
+ [1, 0, 0]
379
+ )
380
+ results = querying.query(
381
+ 'Compute 1 + 2',
382
+ int,
383
+ default=0,
384
+ lm=[
385
+ fake.StaticResponse('1'),
386
+ fake.StaticResponse('abc'),
387
+ ],
388
+ returns_message=True,
389
+ )
390
+ self.assertEqual([r.text for r in results], ['1', 'abc'])
391
+ self.assertEqual([r.result for r in results], [1, 0])
392
+
330
393
  def test_bad_protocol(self):
331
394
  with self.assertRaisesRegex(ValueError, 'Unknown protocol'):
332
- prompting.query('what is 1 + 1', int, protocol='text')
395
+ querying.query('what is 1 + 1', int, protocol='text')
333
396
 
334
397
  def test_query_prompt(self):
335
398
  self.assertEqual(
336
- prompting.query_prompt('what is this?', int),
399
+ querying.query_prompt('what is this?', int),
337
400
  inspect.cleandoc("""
338
401
  Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT according to OUTPUT_TYPE.
339
402
 
@@ -368,14 +431,14 @@ class QueryTest(unittest.TestCase):
368
431
  def test_query_prompt_with_metadata(self):
369
432
  self.assertIn(
370
433
  'x',
371
- prompting.query_prompt(
434
+ querying.query_prompt(
372
435
  'what is this?',
373
436
  metadata_x=1
374
437
  ).metadata
375
438
  )
376
439
  self.assertIn(
377
440
  'x',
378
- prompting.query_prompt(
441
+ querying.query_prompt(
379
442
  'what is this?',
380
443
  int,
381
444
  metadata_x=1
@@ -383,7 +446,7 @@ class QueryTest(unittest.TestCase):
383
446
  )
384
447
 
385
448
  def test_query_prompt_with_unrooted_template(self):
386
- output = prompting.query_prompt(
449
+ output = querying.query_prompt(
387
450
  pg.Dict(
388
451
  input=lf.Template(
389
452
  'what is {{image}}',
@@ -393,9 +456,33 @@ class QueryTest(unittest.TestCase):
393
456
  )
394
457
  self.assertIsNotNone(output.get_modality('image'))
395
458
 
459
+ def test_query_and_reduce(self):
460
+ self.assertEqual(
461
+ querying.query_and_reduce(
462
+ 'Compute 1 + 1',
463
+ int,
464
+ reduce=sum,
465
+ lm=[
466
+ fake.StaticResponse('1'),
467
+ fake.StaticResponse('2'),
468
+ ],
469
+ num_samples=[1, 2],
470
+ ),
471
+ 5
472
+ )
473
+ self.assertEqual(
474
+ querying.query_and_reduce(
475
+ 'Compute 1 + 1',
476
+ int,
477
+ reduce=sum,
478
+ lm=fake.StaticResponse('2'),
479
+ ),
480
+ 2
481
+ )
482
+
396
483
  def test_query_output(self):
397
484
  self.assertEqual(
398
- prompting.query_output(
485
+ querying.query_output(
399
486
  lf.AIMessage('1'),
400
487
  int,
401
488
  ),
@@ -414,7 +501,7 @@ class QueryTest(unittest.TestCase):
414
501
 
415
502
  # Case 1: Reward function based on input and output.
416
503
  self.assertEqual(
417
- prompting.query_reward(
504
+ querying.query_reward(
418
505
  mapping.MappingExample(
419
506
  input=lf.Template('{{x}} + {{y}}', x=1, y=1),
420
507
  schema=Answer,
@@ -425,7 +512,7 @@ class QueryTest(unittest.TestCase):
425
512
  1.0
426
513
  )
427
514
  self.assertEqual(
428
- prompting.query_reward(
515
+ querying.query_reward(
429
516
  mapping.MappingExample(
430
517
  input=lf.Template('{{x}} + {{y}}', x=2, y=3),
431
518
  output=Answer(final_answer=2),
@@ -445,7 +532,7 @@ class QueryTest(unittest.TestCase):
445
532
  )
446
533
 
447
534
  self.assertEqual(
448
- prompting.query_reward(
535
+ querying.query_reward(
449
536
  mapping.MappingExample(
450
537
  input=lf.Template('{{x}} + {{y}}', x=1, y=1),
451
538
  output=Answer2(final_answer=2),
@@ -470,7 +557,7 @@ class QueryTest(unittest.TestCase):
470
557
  ) * metadata['weight']
471
558
 
472
559
  self.assertEqual(
473
- prompting.query_reward(
560
+ querying.query_reward(
474
561
  mapping.MappingExample(
475
562
  input=lf.Template('{{x}} + {{y}}', x=1, y=1),
476
563
  output=Answer3(final_answer=2),
@@ -486,7 +573,7 @@ class QueryTest(unittest.TestCase):
486
573
  final_answer: int
487
574
 
488
575
  self.assertIsNone(
489
- prompting.query_reward(
576
+ querying.query_reward(
490
577
  mapping.MappingExample(
491
578
  input=lf.Template('{{x}} + {{y}}', x=1, y=1),
492
579
  output=Answer4(final_answer=2),
@@ -497,7 +584,7 @@ class QueryTest(unittest.TestCase):
497
584
 
498
585
  # Case 5: Not a structured output.
499
586
  self.assertIsNone(
500
- prompting.query_reward(
587
+ querying.query_reward(
501
588
  mapping.MappingExample(
502
589
  input=lf.Template('{{x}} + {{y}}', x=1, y=1),
503
590
  output='2',
@@ -516,7 +603,7 @@ class QueryTest(unittest.TestCase):
516
603
  with self.assertRaisesRegex(
517
604
  TypeError, '.*Answer5.__reward__` should have signature'
518
605
  ):
519
- prompting.query_reward(
606
+ querying.query_reward(
520
607
  mapping.MappingExample(
521
608
  input=lf.Template('{{x}} + {{y}}', x=1, y=1),
522
609
  output=Answer5(final_answer=2),
@@ -528,7 +615,7 @@ class QueryTest(unittest.TestCase):
528
615
  class QueryStructurePythonTest(unittest.TestCase):
529
616
 
530
617
  def test_render_no_examples(self):
531
- l = prompting.QueryStructurePython(
618
+ l = querying._QueryStructurePython(
532
619
  input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int
533
620
  )
534
621
  self.assertEqual(
@@ -565,7 +652,7 @@ class QueryStructurePythonTest(unittest.TestCase):
565
652
  )
566
653
 
567
654
  def test_render(self):
568
- l = prompting.QueryStructurePython(
655
+ l = querying._QueryStructurePython(
569
656
  input=lf.AIMessage('Compute 12 / 6 + 2.'),
570
657
  schema=int,
571
658
  examples=[
@@ -675,7 +762,7 @@ class QueryStructurePythonTest(unittest.TestCase):
675
762
  ),
676
763
  override_attrs=True,
677
764
  ):
678
- l = prompting.QueryStructurePython(
765
+ l = querying._QueryStructurePython(
679
766
  input=lm_input,
680
767
  schema=[Itinerary],
681
768
  examples=[
@@ -712,7 +799,7 @@ class QueryStructurePythonTest(unittest.TestCase):
712
799
  mapping.MappingError,
713
800
  'name .* is not defined',
714
801
  ):
715
- prompting.query('Compute 1 + 2', int)
802
+ querying.query('Compute 1 + 2', int)
716
803
 
717
804
  def test_autofix(self):
718
805
  lm = fake.StaticSequence([
@@ -723,7 +810,7 @@ class QueryStructurePythonTest(unittest.TestCase):
723
810
  )
724
811
  """),
725
812
  ])
726
- self.assertEqual(prompting.query('what is 1 + 0', int, lm=lm, autofix=3), 1)
813
+ self.assertEqual(querying.query('what is 1 + 0', int, lm=lm, autofix=3), 1)
727
814
 
728
815
  def test_response_postprocess(self):
729
816
  with lf.context(
@@ -731,12 +818,12 @@ class QueryStructurePythonTest(unittest.TestCase):
731
818
  override_attrs=True,
732
819
  ):
733
820
  self.assertEqual(
734
- prompting.query(
821
+ querying.query(
735
822
  'Compute 1 + 2', response_postprocess=lambda x: x.split('\n')[1]),
736
823
  '3'
737
824
  )
738
825
  self.assertEqual(
739
- prompting.query(
826
+ querying.query(
740
827
  'Compute 1 + 2', int,
741
828
  response_postprocess=lambda x: x.split('\n')[1]),
742
829
  3
@@ -746,7 +833,7 @@ class QueryStructurePythonTest(unittest.TestCase):
746
833
  class QueryStructureJsonTest(unittest.TestCase):
747
834
 
748
835
  def test_render_no_examples(self):
749
- l = prompting.QueryStructureJson(
836
+ l = querying._QueryStructureJson(
750
837
  input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int
751
838
  )
752
839
  self.assertEqual(
@@ -762,10 +849,10 @@ class QueryStructureJsonTest(unittest.TestCase):
762
849
  1 + 1 =
763
850
 
764
851
  SCHEMA:
765
- {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": int}}
852
+ {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}}
766
853
 
767
854
  JSON:
768
- {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": 2}}
855
+ {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}}
769
856
 
770
857
  INPUT_OBJECT:
771
858
  Compute 12 / 6 + 2.
@@ -778,7 +865,7 @@ class QueryStructureJsonTest(unittest.TestCase):
778
865
  )
779
866
 
780
867
  def test_render(self):
781
- l = prompting.QueryStructureJson(
868
+ l = querying._QueryStructureJson(
782
869
  input=lf.AIMessage('Compute 12 / 6 + 2.'),
783
870
  schema=int,
784
871
  examples=[
@@ -799,10 +886,10 @@ class QueryStructureJsonTest(unittest.TestCase):
799
886
  1 + 1 =
800
887
 
801
888
  SCHEMA:
802
- {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": int}}
889
+ {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}}
803
890
 
804
891
  JSON:
805
- {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": 2}}
892
+ {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}}
806
893
 
807
894
  INPUT_OBJECT:
808
895
  What is the answer of 1 plus 1?
@@ -913,7 +1000,7 @@ class QueryStructureJsonTest(unittest.TestCase):
913
1000
  ),
914
1001
  override_attrs=True,
915
1002
  ):
916
- l = prompting.QueryStructureJson(
1003
+ l = querying._QueryStructureJson(
917
1004
  input=lm_input,
918
1005
  schema=[Itinerary],
919
1006
  examples=[
@@ -951,14 +1038,14 @@ class QueryStructureJsonTest(unittest.TestCase):
951
1038
  mapping.MappingError,
952
1039
  'No JSON dict in the output',
953
1040
  ):
954
- prompting.query('Compute 1 + 2', int, protocol='json', cache_seed=1)
1041
+ querying.query('Compute 1 + 2', int, protocol='json', cache_seed=1)
955
1042
  # Make sure bad mapping does not impact cache.
956
1043
  self.assertEqual(len(cache), 0)
957
1044
 
958
1045
  def test_query(self):
959
1046
  lm = fake.StaticSequence(['{"result": 1}'])
960
1047
  self.assertEqual(
961
- prompting.query('what is 1 + 0', int, lm=lm, protocol='json'), 1
1048
+ querying.query('what is 1 + 0', int, lm=lm, protocol='json'), 1
962
1049
  )
963
1050
 
964
1051
 
@@ -968,8 +1055,8 @@ class QueryInvocationTest(unittest.TestCase):
968
1055
  lm = fake.StaticSequence([
969
1056
  'Activity(description="hi")',
970
1057
  ])
971
- with prompting.track_queries() as queries:
972
- prompting.query('foo', Activity, lm=lm)
1058
+ with querying.track_queries() as queries:
1059
+ querying.query('foo', Activity, lm=lm)
973
1060
 
974
1061
  self.assertIn('schema', queries[0].to_html_str())
975
1062
 
@@ -981,10 +1068,10 @@ class TrackQueriesTest(unittest.TestCase):
981
1068
  'bar',
982
1069
  'Activity(description="hi")',
983
1070
  ])
984
- with prompting.track_queries() as queries:
985
- prompting.query('foo', lm=lm)
986
- with prompting.track_queries() as child_queries:
987
- prompting.query('give me an activity', Activity, lm=lm)
1071
+ with querying.track_queries() as queries:
1072
+ querying.query('foo', lm=lm)
1073
+ with querying.track_queries() as child_queries:
1074
+ querying.query('give me an activity', Activity, lm=lm)
988
1075
 
989
1076
  self.assertEqual(len(queries), 2)
990
1077
  self.assertTrue(pg.eq(queries[0].input, lf.Template('foo')))
@@ -1008,10 +1095,10 @@ class TrackQueriesTest(unittest.TestCase):
1008
1095
  'bar',
1009
1096
  'Activity(description="hi")',
1010
1097
  ])
1011
- with prompting.track_queries(include_child_scopes=False) as queries:
1012
- prompting.query('foo', lm=lm)
1013
- with prompting.track_queries(include_child_scopes=False) as child_queries:
1014
- prompting.query('give me an activity', Activity, lm=lm)
1098
+ with querying.track_queries(include_child_scopes=False) as queries:
1099
+ querying.query('foo', lm=lm)
1100
+ with querying.track_queries(include_child_scopes=False) as child_queries:
1101
+ querying.query('give me an activity', Activity, lm=lm)
1015
1102
 
1016
1103
  self.assertEqual(len(queries), 1)
1017
1104
  self.assertTrue(pg.eq(queries[0].input, lf.Template('foo')))
@@ -1030,13 +1117,13 @@ class TrackQueriesTest(unittest.TestCase):
1030
1117
  def test_concurrent_map(self):
1031
1118
 
1032
1119
  def make_query(prompt):
1033
- _ = prompting.query(prompt, lm=lm)
1120
+ _ = querying.query(prompt, lm=lm)
1034
1121
 
1035
1122
  lm = fake.StaticSequence([
1036
1123
  'foo',
1037
1124
  'bar',
1038
1125
  ])
1039
- with prompting.track_queries() as queries:
1126
+ with querying.track_queries() as queries:
1040
1127
  list(lf.concurrent_map(make_query, ['a', 'b']))
1041
1128
  self.assertEqual(len(queries), 2)
1042
1129
 
@@ -213,18 +213,8 @@ class Schema(
213
213
  """
214
214
  )
215
215
 
216
- def _html_tree_view_tooltip(
217
- self,
218
- *,
219
- view: pg.views.HtmlTreeView,
220
- content: pg.Html | str | None = None,
221
- **kwargs,
222
- ):
223
- return view.tooltip(
224
- self,
225
- content=content or pg.Html.escape(self.schema_str(protocol='python')),
226
- **kwargs
227
- )
216
+
217
+ SchemaType = Union[Schema, Type[Any], list[Type[Any]], dict[str, Any]]
228
218
 
229
219
 
230
220
  def _top_level_object_specs_from_value(value: pg.Symbolic) -> list[Type[Any]]:
@@ -388,9 +378,9 @@ class SchemaPythonRepr(SchemaRepr):
388
378
  return annotation(schema.spec)
389
379
 
390
380
 
391
- def source_form(value, markdown: bool = False) -> str:
381
+ def source_form(value, compact: bool = True, markdown: bool = False) -> str:
392
382
  """Returns the source code form of an object."""
393
- return ValuePythonRepr().repr(value, markdown=markdown)
383
+ return ValuePythonRepr().repr(value, compact=compact, markdown=markdown)
394
384
 
395
385
 
396
386
  def class_definitions(
@@ -789,7 +779,7 @@ class ValueJsonRepr(ValueRepr):
789
779
  """Parse a JSON string into a structured object."""
790
780
  del schema
791
781
  try:
792
- text = self.cleanup_json(text)
782
+ text = cleanup_json(text)
793
783
  v = pg.from_json_str(text, **kwargs)
794
784
  except Exception as e:
795
785
  raise JsonError(text, e) # pylint: disable=raise-missing-from
@@ -801,55 +791,56 @@ class ValueJsonRepr(ValueRepr):
801
791
  ))
802
792
  return v['result']
803
793
 
804
- def cleanup_json(self, json_str: str) -> str:
805
- """Clean up the LM responded JSON string."""
806
- # Treatments:
807
- # 1. Extract the JSON string with a top-level dict from the response.
808
- # This prevents the leading and trailing texts in the response to
809
- # be counted as part of the JSON.
810
- # 2. Escape new lines in JSON values.
811
-
812
- curly_brackets = 0
813
- under_json = False
814
- under_str = False
815
- str_begin = -1
816
-
817
- cleaned = io.StringIO()
818
- for i, c in enumerate(json_str):
819
- if c == '{' and not under_str:
820
- cleaned.write(c)
821
- curly_brackets += 1
822
- under_json = True
823
- continue
824
- elif not under_json:
825
- continue
826
794
 
827
- if c == '}' and not under_str:
828
- cleaned.write(c)
829
- curly_brackets -= 1
830
- if curly_brackets == 0:
831
- break
832
- elif c == '"' and json_str[i - 1] != '\\':
833
- under_str = not under_str
834
- if under_str:
835
- str_begin = i
836
- else:
837
- assert str_begin > 0
838
- str_value = json_str[str_begin : i + 1].replace('\n', '\\n')
839
- cleaned.write(str_value)
840
- str_begin = -1
841
- elif not under_str:
842
- cleaned.write(c)
843
-
844
- if not under_json:
845
- raise ValueError(f'No JSON dict in the output: {json_str}')
846
-
847
- if curly_brackets > 0:
848
- raise ValueError(
849
- f'Malformated JSON: missing {curly_brackets} closing curly braces.'
850
- )
795
+ def cleanup_json(json_str: str) -> str:
796
+ """Clean up the LM responded JSON string."""
797
+ # Treatments:
798
+ # 1. Extract the JSON string with a top-level dict from the response.
799
+ # This prevents the leading and trailing texts in the response to
800
+ # be counted as part of the JSON.
801
+ # 2. Escape new lines in JSON values.
802
+
803
+ curly_brackets = 0
804
+ under_json = False
805
+ under_str = False
806
+ str_begin = -1
807
+
808
+ cleaned = io.StringIO()
809
+ for i, c in enumerate(json_str):
810
+ if c == '{' and not under_str:
811
+ cleaned.write(c)
812
+ curly_brackets += 1
813
+ under_json = True
814
+ continue
815
+ elif not under_json:
816
+ continue
817
+
818
+ if c == '}' and not under_str:
819
+ cleaned.write(c)
820
+ curly_brackets -= 1
821
+ if curly_brackets == 0:
822
+ break
823
+ elif c == '"' and json_str[i - 1] != '\\':
824
+ under_str = not under_str
825
+ if under_str:
826
+ str_begin = i
827
+ else:
828
+ assert str_begin > 0
829
+ str_value = json_str[str_begin : i + 1].replace('\n', '\\n')
830
+ cleaned.write(str_value)
831
+ str_begin = -1
832
+ elif not under_str:
833
+ cleaned.write(c)
834
+
835
+ if not under_json:
836
+ raise ValueError(f'No JSON dict in the output: {json_str}')
837
+
838
+ if curly_brackets > 0:
839
+ raise ValueError(
840
+ f'Malformated JSON: missing {curly_brackets} closing curly braces.'
841
+ )
851
842
 
852
- return cleaned.getvalue()
843
+ return cleaned.getvalue()
853
844
 
854
845
 
855
846
  def schema_repr(protocol: SchemaProtocol) -> SchemaRepr:
@@ -17,7 +17,7 @@ from typing import Any, Type, Union
17
17
 
18
18
  import langfun.core as lf
19
19
  from langfun.core.structured import mapping
20
- from langfun.core.structured import prompting
20
+ from langfun.core.structured import querying
21
21
  from langfun.core.structured import schema as schema_lib
22
22
  import pyglove as pg
23
23
 
@@ -101,7 +101,7 @@ def score(
101
101
  prompts = []
102
102
  for p in prompt:
103
103
  prompts.append(
104
- prompting.query_prompt(
104
+ querying.query_prompt(
105
105
  p,
106
106
  schema,
107
107
  examples=examples,
@@ -111,7 +111,7 @@ def score(
111
111
  )
112
112
  input_message = prompts
113
113
  else:
114
- input_message = prompting.query_prompt(
114
+ input_message = querying.query_prompt(
115
115
  prompt,
116
116
  schema,
117
117
  examples=examples,
@@ -17,7 +17,7 @@ from typing import Any, Type, Union
17
17
 
18
18
  import langfun.core as lf
19
19
  from langfun.core.structured import mapping
20
- from langfun.core.structured import prompting
20
+ from langfun.core.structured import querying
21
21
  from langfun.core.structured import schema as schema_lib
22
22
  import pyglove as pg
23
23
 
@@ -48,7 +48,7 @@ def tokenize(
48
48
  Returns:
49
49
  A list of (text, token_id) tuples.
50
50
  """
51
- input_message = prompting.query_prompt(
51
+ input_message = querying.query_prompt(
52
52
  prompt,
53
53
  schema,
54
54
  examples=examples,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.1.2.dev202412150804
3
+ Version: 0.1.2.dev202412180804
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors