langfun 0.1.2.dev202412150804__py3-none-any.whl → 0.1.2.dev202412180804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +3 -0
- langfun/core/coding/python/correction.py +2 -2
- langfun/core/eval/v2/__init__.py +5 -1
- langfun/core/eval/v2/checkpointing.py +107 -18
- langfun/core/eval/v2/checkpointing_test.py +41 -8
- langfun/core/eval/v2/runners.py +1 -1
- langfun/core/llms/vertexai.py +1 -1
- langfun/core/structured/__init__.py +9 -22
- langfun/core/structured/completion.py +2 -2
- langfun/core/structured/completion_test.py +4 -4
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +3 -3
- langfun/core/structured/parsing.py +9 -9
- langfun/core/structured/parsing_test.py +8 -8
- langfun/core/structured/{prompting.py → querying.py} +256 -77
- langfun/core/structured/{prompting_test.py → querying_test.py} +138 -51
- langfun/core/structured/schema.py +53 -62
- langfun/core/structured/scoring.py +3 -3
- langfun/core/structured/tokenization.py +2 -2
- {langfun-0.1.2.dev202412150804.dist-info → langfun-0.1.2.dev202412180804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202412150804.dist-info → langfun-0.1.2.dev202412180804.dist-info}/RECORD +25 -25
- {langfun-0.1.2.dev202412150804.dist-info → langfun-0.1.2.dev202412180804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202412150804.dist-info → langfun-0.1.2.dev202412180804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202412150804.dist-info → langfun-0.1.2.dev202412180804.dist-info}/top_level.txt +0 -0
@@ -11,7 +11,7 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
"""Tests for structured
|
14
|
+
"""Tests for structured query."""
|
15
15
|
|
16
16
|
import inspect
|
17
17
|
import math
|
@@ -23,7 +23,7 @@ from langfun.core import modalities
|
|
23
23
|
from langfun.core.llms import fake
|
24
24
|
from langfun.core.llms.cache import in_memory
|
25
25
|
from langfun.core.structured import mapping
|
26
|
-
from langfun.core.structured import
|
26
|
+
from langfun.core.structured import querying
|
27
27
|
import pyglove as pg
|
28
28
|
|
29
29
|
|
@@ -51,7 +51,7 @@ class QueryTest(unittest.TestCase):
|
|
51
51
|
expected_modalities: int = 0,
|
52
52
|
**kwargs,
|
53
53
|
):
|
54
|
-
m =
|
54
|
+
m = querying.query(
|
55
55
|
prompt, schema=schema, examples=examples,
|
56
56
|
**kwargs, returns_message=True
|
57
57
|
)
|
@@ -67,14 +67,14 @@ class QueryTest(unittest.TestCase):
|
|
67
67
|
|
68
68
|
def test_call(self):
|
69
69
|
lm = fake.StaticSequence(['1'])
|
70
|
-
self.assertEqual(
|
70
|
+
self.assertEqual(querying.query('what is 1 + 0', int, lm=lm), 1)
|
71
71
|
|
72
72
|
# Testing calling the same `lm` without copy.
|
73
73
|
with self.assertRaises(IndexError):
|
74
|
-
|
74
|
+
querying.query('what is 1 + 2', int, lm=lm)
|
75
75
|
|
76
76
|
self.assertEqual(
|
77
|
-
|
77
|
+
querying.query(
|
78
78
|
'what is 1 + 0', int, lm=lm.clone(), returns_message=True
|
79
79
|
),
|
80
80
|
lf.AIMessage(
|
@@ -88,17 +88,17 @@ class QueryTest(unittest.TestCase):
|
|
88
88
|
),
|
89
89
|
)
|
90
90
|
self.assertEqual(
|
91
|
-
|
91
|
+
querying.query(
|
92
92
|
lf.Template('what is {{x}} + {{y}}', x=1, y=0), int, lm=lm.clone()
|
93
93
|
),
|
94
94
|
1,
|
95
95
|
)
|
96
96
|
self.assertEqual(
|
97
|
-
|
97
|
+
querying.query('what is {{x}} + {{y}}', int, x=1, y=0, lm=lm.clone()),
|
98
98
|
1,
|
99
99
|
)
|
100
100
|
self.assertEqual(
|
101
|
-
|
101
|
+
querying.query(
|
102
102
|
'what is {{x}} + {{y}}',
|
103
103
|
x=1,
|
104
104
|
y=0,
|
@@ -107,7 +107,7 @@ class QueryTest(unittest.TestCase):
|
|
107
107
|
'The answer is one.',
|
108
108
|
)
|
109
109
|
self.assertEqual(
|
110
|
-
|
110
|
+
querying.query(
|
111
111
|
Activity.partial(),
|
112
112
|
lm=fake.StaticResponse('Activity(description="hello")'),
|
113
113
|
),
|
@@ -327,13 +327,76 @@ class QueryTest(unittest.TestCase):
|
|
327
327
|
expected_modalities=3,
|
328
328
|
)
|
329
329
|
|
330
|
+
def test_multiple_queries(self):
|
331
|
+
self.assertEqual(
|
332
|
+
querying.query(
|
333
|
+
'Compute 1 + 2',
|
334
|
+
int,
|
335
|
+
lm=[
|
336
|
+
fake.StaticResponse('1'),
|
337
|
+
fake.StaticResponse('2'),
|
338
|
+
],
|
339
|
+
num_samples=[1, 2],
|
340
|
+
),
|
341
|
+
[1, 2, 2]
|
342
|
+
)
|
343
|
+
self.assertEqual(
|
344
|
+
querying.query(
|
345
|
+
'Compute 1 + 2',
|
346
|
+
int,
|
347
|
+
lm=[
|
348
|
+
fake.StaticResponse('1'),
|
349
|
+
fake.StaticResponse('2'),
|
350
|
+
],
|
351
|
+
num_samples=2,
|
352
|
+
),
|
353
|
+
[1, 1, 2, 2]
|
354
|
+
)
|
355
|
+
self.assertEqual(
|
356
|
+
querying.query(
|
357
|
+
'Compute 1 + 2',
|
358
|
+
int,
|
359
|
+
lm=[
|
360
|
+
fake.StaticResponse('1'),
|
361
|
+
fake.StaticResponse('abc'),
|
362
|
+
],
|
363
|
+
num_samples=[1, 2],
|
364
|
+
),
|
365
|
+
[1]
|
366
|
+
)
|
367
|
+
self.assertEqual(
|
368
|
+
querying.query(
|
369
|
+
'Compute 1 + 2',
|
370
|
+
int,
|
371
|
+
default=0,
|
372
|
+
lm=[
|
373
|
+
fake.StaticResponse('1'),
|
374
|
+
fake.StaticResponse('abc'),
|
375
|
+
],
|
376
|
+
num_samples=[1, 2],
|
377
|
+
),
|
378
|
+
[1, 0, 0]
|
379
|
+
)
|
380
|
+
results = querying.query(
|
381
|
+
'Compute 1 + 2',
|
382
|
+
int,
|
383
|
+
default=0,
|
384
|
+
lm=[
|
385
|
+
fake.StaticResponse('1'),
|
386
|
+
fake.StaticResponse('abc'),
|
387
|
+
],
|
388
|
+
returns_message=True,
|
389
|
+
)
|
390
|
+
self.assertEqual([r.text for r in results], ['1', 'abc'])
|
391
|
+
self.assertEqual([r.result for r in results], [1, 0])
|
392
|
+
|
330
393
|
def test_bad_protocol(self):
|
331
394
|
with self.assertRaisesRegex(ValueError, 'Unknown protocol'):
|
332
|
-
|
395
|
+
querying.query('what is 1 + 1', int, protocol='text')
|
333
396
|
|
334
397
|
def test_query_prompt(self):
|
335
398
|
self.assertEqual(
|
336
|
-
|
399
|
+
querying.query_prompt('what is this?', int),
|
337
400
|
inspect.cleandoc("""
|
338
401
|
Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT according to OUTPUT_TYPE.
|
339
402
|
|
@@ -368,14 +431,14 @@ class QueryTest(unittest.TestCase):
|
|
368
431
|
def test_query_prompt_with_metadata(self):
|
369
432
|
self.assertIn(
|
370
433
|
'x',
|
371
|
-
|
434
|
+
querying.query_prompt(
|
372
435
|
'what is this?',
|
373
436
|
metadata_x=1
|
374
437
|
).metadata
|
375
438
|
)
|
376
439
|
self.assertIn(
|
377
440
|
'x',
|
378
|
-
|
441
|
+
querying.query_prompt(
|
379
442
|
'what is this?',
|
380
443
|
int,
|
381
444
|
metadata_x=1
|
@@ -383,7 +446,7 @@ class QueryTest(unittest.TestCase):
|
|
383
446
|
)
|
384
447
|
|
385
448
|
def test_query_prompt_with_unrooted_template(self):
|
386
|
-
output =
|
449
|
+
output = querying.query_prompt(
|
387
450
|
pg.Dict(
|
388
451
|
input=lf.Template(
|
389
452
|
'what is {{image}}',
|
@@ -393,9 +456,33 @@ class QueryTest(unittest.TestCase):
|
|
393
456
|
)
|
394
457
|
self.assertIsNotNone(output.get_modality('image'))
|
395
458
|
|
459
|
+
def test_query_and_reduce(self):
|
460
|
+
self.assertEqual(
|
461
|
+
querying.query_and_reduce(
|
462
|
+
'Compute 1 + 1',
|
463
|
+
int,
|
464
|
+
reduce=sum,
|
465
|
+
lm=[
|
466
|
+
fake.StaticResponse('1'),
|
467
|
+
fake.StaticResponse('2'),
|
468
|
+
],
|
469
|
+
num_samples=[1, 2],
|
470
|
+
),
|
471
|
+
5
|
472
|
+
)
|
473
|
+
self.assertEqual(
|
474
|
+
querying.query_and_reduce(
|
475
|
+
'Compute 1 + 1',
|
476
|
+
int,
|
477
|
+
reduce=sum,
|
478
|
+
lm=fake.StaticResponse('2'),
|
479
|
+
),
|
480
|
+
2
|
481
|
+
)
|
482
|
+
|
396
483
|
def test_query_output(self):
|
397
484
|
self.assertEqual(
|
398
|
-
|
485
|
+
querying.query_output(
|
399
486
|
lf.AIMessage('1'),
|
400
487
|
int,
|
401
488
|
),
|
@@ -414,7 +501,7 @@ class QueryTest(unittest.TestCase):
|
|
414
501
|
|
415
502
|
# Case 1: Reward function based on input and output.
|
416
503
|
self.assertEqual(
|
417
|
-
|
504
|
+
querying.query_reward(
|
418
505
|
mapping.MappingExample(
|
419
506
|
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
420
507
|
schema=Answer,
|
@@ -425,7 +512,7 @@ class QueryTest(unittest.TestCase):
|
|
425
512
|
1.0
|
426
513
|
)
|
427
514
|
self.assertEqual(
|
428
|
-
|
515
|
+
querying.query_reward(
|
429
516
|
mapping.MappingExample(
|
430
517
|
input=lf.Template('{{x}} + {{y}}', x=2, y=3),
|
431
518
|
output=Answer(final_answer=2),
|
@@ -445,7 +532,7 @@ class QueryTest(unittest.TestCase):
|
|
445
532
|
)
|
446
533
|
|
447
534
|
self.assertEqual(
|
448
|
-
|
535
|
+
querying.query_reward(
|
449
536
|
mapping.MappingExample(
|
450
537
|
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
451
538
|
output=Answer2(final_answer=2),
|
@@ -470,7 +557,7 @@ class QueryTest(unittest.TestCase):
|
|
470
557
|
) * metadata['weight']
|
471
558
|
|
472
559
|
self.assertEqual(
|
473
|
-
|
560
|
+
querying.query_reward(
|
474
561
|
mapping.MappingExample(
|
475
562
|
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
476
563
|
output=Answer3(final_answer=2),
|
@@ -486,7 +573,7 @@ class QueryTest(unittest.TestCase):
|
|
486
573
|
final_answer: int
|
487
574
|
|
488
575
|
self.assertIsNone(
|
489
|
-
|
576
|
+
querying.query_reward(
|
490
577
|
mapping.MappingExample(
|
491
578
|
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
492
579
|
output=Answer4(final_answer=2),
|
@@ -497,7 +584,7 @@ class QueryTest(unittest.TestCase):
|
|
497
584
|
|
498
585
|
# Case 5: Not a structured output.
|
499
586
|
self.assertIsNone(
|
500
|
-
|
587
|
+
querying.query_reward(
|
501
588
|
mapping.MappingExample(
|
502
589
|
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
503
590
|
output='2',
|
@@ -516,7 +603,7 @@ class QueryTest(unittest.TestCase):
|
|
516
603
|
with self.assertRaisesRegex(
|
517
604
|
TypeError, '.*Answer5.__reward__` should have signature'
|
518
605
|
):
|
519
|
-
|
606
|
+
querying.query_reward(
|
520
607
|
mapping.MappingExample(
|
521
608
|
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
522
609
|
output=Answer5(final_answer=2),
|
@@ -528,7 +615,7 @@ class QueryTest(unittest.TestCase):
|
|
528
615
|
class QueryStructurePythonTest(unittest.TestCase):
|
529
616
|
|
530
617
|
def test_render_no_examples(self):
|
531
|
-
l =
|
618
|
+
l = querying._QueryStructurePython(
|
532
619
|
input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int
|
533
620
|
)
|
534
621
|
self.assertEqual(
|
@@ -565,7 +652,7 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
565
652
|
)
|
566
653
|
|
567
654
|
def test_render(self):
|
568
|
-
l =
|
655
|
+
l = querying._QueryStructurePython(
|
569
656
|
input=lf.AIMessage('Compute 12 / 6 + 2.'),
|
570
657
|
schema=int,
|
571
658
|
examples=[
|
@@ -675,7 +762,7 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
675
762
|
),
|
676
763
|
override_attrs=True,
|
677
764
|
):
|
678
|
-
l =
|
765
|
+
l = querying._QueryStructurePython(
|
679
766
|
input=lm_input,
|
680
767
|
schema=[Itinerary],
|
681
768
|
examples=[
|
@@ -712,7 +799,7 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
712
799
|
mapping.MappingError,
|
713
800
|
'name .* is not defined',
|
714
801
|
):
|
715
|
-
|
802
|
+
querying.query('Compute 1 + 2', int)
|
716
803
|
|
717
804
|
def test_autofix(self):
|
718
805
|
lm = fake.StaticSequence([
|
@@ -723,7 +810,7 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
723
810
|
)
|
724
811
|
"""),
|
725
812
|
])
|
726
|
-
self.assertEqual(
|
813
|
+
self.assertEqual(querying.query('what is 1 + 0', int, lm=lm, autofix=3), 1)
|
727
814
|
|
728
815
|
def test_response_postprocess(self):
|
729
816
|
with lf.context(
|
@@ -731,12 +818,12 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
731
818
|
override_attrs=True,
|
732
819
|
):
|
733
820
|
self.assertEqual(
|
734
|
-
|
821
|
+
querying.query(
|
735
822
|
'Compute 1 + 2', response_postprocess=lambda x: x.split('\n')[1]),
|
736
823
|
'3'
|
737
824
|
)
|
738
825
|
self.assertEqual(
|
739
|
-
|
826
|
+
querying.query(
|
740
827
|
'Compute 1 + 2', int,
|
741
828
|
response_postprocess=lambda x: x.split('\n')[1]),
|
742
829
|
3
|
@@ -746,7 +833,7 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
746
833
|
class QueryStructureJsonTest(unittest.TestCase):
|
747
834
|
|
748
835
|
def test_render_no_examples(self):
|
749
|
-
l =
|
836
|
+
l = querying._QueryStructureJson(
|
750
837
|
input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int
|
751
838
|
)
|
752
839
|
self.assertEqual(
|
@@ -762,10 +849,10 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
762
849
|
1 + 1 =
|
763
850
|
|
764
851
|
SCHEMA:
|
765
|
-
{"result": {"_type": "langfun.core.structured.
|
852
|
+
{"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}}
|
766
853
|
|
767
854
|
JSON:
|
768
|
-
{"result": {"_type": "langfun.core.structured.
|
855
|
+
{"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}}
|
769
856
|
|
770
857
|
INPUT_OBJECT:
|
771
858
|
Compute 12 / 6 + 2.
|
@@ -778,7 +865,7 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
778
865
|
)
|
779
866
|
|
780
867
|
def test_render(self):
|
781
|
-
l =
|
868
|
+
l = querying._QueryStructureJson(
|
782
869
|
input=lf.AIMessage('Compute 12 / 6 + 2.'),
|
783
870
|
schema=int,
|
784
871
|
examples=[
|
@@ -799,10 +886,10 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
799
886
|
1 + 1 =
|
800
887
|
|
801
888
|
SCHEMA:
|
802
|
-
{"result": {"_type": "langfun.core.structured.
|
889
|
+
{"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}}
|
803
890
|
|
804
891
|
JSON:
|
805
|
-
{"result": {"_type": "langfun.core.structured.
|
892
|
+
{"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}}
|
806
893
|
|
807
894
|
INPUT_OBJECT:
|
808
895
|
What is the answer of 1 plus 1?
|
@@ -913,7 +1000,7 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
913
1000
|
),
|
914
1001
|
override_attrs=True,
|
915
1002
|
):
|
916
|
-
l =
|
1003
|
+
l = querying._QueryStructureJson(
|
917
1004
|
input=lm_input,
|
918
1005
|
schema=[Itinerary],
|
919
1006
|
examples=[
|
@@ -951,14 +1038,14 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
951
1038
|
mapping.MappingError,
|
952
1039
|
'No JSON dict in the output',
|
953
1040
|
):
|
954
|
-
|
1041
|
+
querying.query('Compute 1 + 2', int, protocol='json', cache_seed=1)
|
955
1042
|
# Make sure bad mapping does not impact cache.
|
956
1043
|
self.assertEqual(len(cache), 0)
|
957
1044
|
|
958
1045
|
def test_query(self):
|
959
1046
|
lm = fake.StaticSequence(['{"result": 1}'])
|
960
1047
|
self.assertEqual(
|
961
|
-
|
1048
|
+
querying.query('what is 1 + 0', int, lm=lm, protocol='json'), 1
|
962
1049
|
)
|
963
1050
|
|
964
1051
|
|
@@ -968,8 +1055,8 @@ class QueryInvocationTest(unittest.TestCase):
|
|
968
1055
|
lm = fake.StaticSequence([
|
969
1056
|
'Activity(description="hi")',
|
970
1057
|
])
|
971
|
-
with
|
972
|
-
|
1058
|
+
with querying.track_queries() as queries:
|
1059
|
+
querying.query('foo', Activity, lm=lm)
|
973
1060
|
|
974
1061
|
self.assertIn('schema', queries[0].to_html_str())
|
975
1062
|
|
@@ -981,10 +1068,10 @@ class TrackQueriesTest(unittest.TestCase):
|
|
981
1068
|
'bar',
|
982
1069
|
'Activity(description="hi")',
|
983
1070
|
])
|
984
|
-
with
|
985
|
-
|
986
|
-
with
|
987
|
-
|
1071
|
+
with querying.track_queries() as queries:
|
1072
|
+
querying.query('foo', lm=lm)
|
1073
|
+
with querying.track_queries() as child_queries:
|
1074
|
+
querying.query('give me an activity', Activity, lm=lm)
|
988
1075
|
|
989
1076
|
self.assertEqual(len(queries), 2)
|
990
1077
|
self.assertTrue(pg.eq(queries[0].input, lf.Template('foo')))
|
@@ -1008,10 +1095,10 @@ class TrackQueriesTest(unittest.TestCase):
|
|
1008
1095
|
'bar',
|
1009
1096
|
'Activity(description="hi")',
|
1010
1097
|
])
|
1011
|
-
with
|
1012
|
-
|
1013
|
-
with
|
1014
|
-
|
1098
|
+
with querying.track_queries(include_child_scopes=False) as queries:
|
1099
|
+
querying.query('foo', lm=lm)
|
1100
|
+
with querying.track_queries(include_child_scopes=False) as child_queries:
|
1101
|
+
querying.query('give me an activity', Activity, lm=lm)
|
1015
1102
|
|
1016
1103
|
self.assertEqual(len(queries), 1)
|
1017
1104
|
self.assertTrue(pg.eq(queries[0].input, lf.Template('foo')))
|
@@ -1030,13 +1117,13 @@ class TrackQueriesTest(unittest.TestCase):
|
|
1030
1117
|
def test_concurrent_map(self):
|
1031
1118
|
|
1032
1119
|
def make_query(prompt):
|
1033
|
-
_ =
|
1120
|
+
_ = querying.query(prompt, lm=lm)
|
1034
1121
|
|
1035
1122
|
lm = fake.StaticSequence([
|
1036
1123
|
'foo',
|
1037
1124
|
'bar',
|
1038
1125
|
])
|
1039
|
-
with
|
1126
|
+
with querying.track_queries() as queries:
|
1040
1127
|
list(lf.concurrent_map(make_query, ['a', 'b']))
|
1041
1128
|
self.assertEqual(len(queries), 2)
|
1042
1129
|
|
@@ -213,18 +213,8 @@ class Schema(
|
|
213
213
|
"""
|
214
214
|
)
|
215
215
|
|
216
|
-
|
217
|
-
|
218
|
-
*,
|
219
|
-
view: pg.views.HtmlTreeView,
|
220
|
-
content: pg.Html | str | None = None,
|
221
|
-
**kwargs,
|
222
|
-
):
|
223
|
-
return view.tooltip(
|
224
|
-
self,
|
225
|
-
content=content or pg.Html.escape(self.schema_str(protocol='python')),
|
226
|
-
**kwargs
|
227
|
-
)
|
216
|
+
|
217
|
+
SchemaType = Union[Schema, Type[Any], list[Type[Any]], dict[str, Any]]
|
228
218
|
|
229
219
|
|
230
220
|
def _top_level_object_specs_from_value(value: pg.Symbolic) -> list[Type[Any]]:
|
@@ -388,9 +378,9 @@ class SchemaPythonRepr(SchemaRepr):
|
|
388
378
|
return annotation(schema.spec)
|
389
379
|
|
390
380
|
|
391
|
-
def source_form(value, markdown: bool = False) -> str:
|
381
|
+
def source_form(value, compact: bool = True, markdown: bool = False) -> str:
|
392
382
|
"""Returns the source code form of an object."""
|
393
|
-
return ValuePythonRepr().repr(value, markdown=markdown)
|
383
|
+
return ValuePythonRepr().repr(value, compact=compact, markdown=markdown)
|
394
384
|
|
395
385
|
|
396
386
|
def class_definitions(
|
@@ -789,7 +779,7 @@ class ValueJsonRepr(ValueRepr):
|
|
789
779
|
"""Parse a JSON string into a structured object."""
|
790
780
|
del schema
|
791
781
|
try:
|
792
|
-
text =
|
782
|
+
text = cleanup_json(text)
|
793
783
|
v = pg.from_json_str(text, **kwargs)
|
794
784
|
except Exception as e:
|
795
785
|
raise JsonError(text, e) # pylint: disable=raise-missing-from
|
@@ -801,55 +791,56 @@ class ValueJsonRepr(ValueRepr):
|
|
801
791
|
))
|
802
792
|
return v['result']
|
803
793
|
|
804
|
-
def cleanup_json(self, json_str: str) -> str:
|
805
|
-
"""Clean up the LM responded JSON string."""
|
806
|
-
# Treatments:
|
807
|
-
# 1. Extract the JSON string with a top-level dict from the response.
|
808
|
-
# This prevents the leading and trailing texts in the response to
|
809
|
-
# be counted as part of the JSON.
|
810
|
-
# 2. Escape new lines in JSON values.
|
811
|
-
|
812
|
-
curly_brackets = 0
|
813
|
-
under_json = False
|
814
|
-
under_str = False
|
815
|
-
str_begin = -1
|
816
|
-
|
817
|
-
cleaned = io.StringIO()
|
818
|
-
for i, c in enumerate(json_str):
|
819
|
-
if c == '{' and not under_str:
|
820
|
-
cleaned.write(c)
|
821
|
-
curly_brackets += 1
|
822
|
-
under_json = True
|
823
|
-
continue
|
824
|
-
elif not under_json:
|
825
|
-
continue
|
826
794
|
|
827
|
-
|
828
|
-
|
829
|
-
|
830
|
-
|
831
|
-
|
832
|
-
|
833
|
-
|
834
|
-
|
835
|
-
|
836
|
-
|
837
|
-
|
838
|
-
|
839
|
-
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
|
844
|
-
|
845
|
-
|
846
|
-
|
847
|
-
|
848
|
-
|
849
|
-
|
850
|
-
|
795
|
+
def cleanup_json(json_str: str) -> str:
  """Extracts and cleans the first top-level JSON dict in an LM response.

  Treatments:
    1. Only the text from the first `{` through its matching `}` is kept, so
       leading and trailing chatter in the response is not counted as part of
       the JSON.
    2. Raw newline characters inside JSON string values are rewritten as the
       two-character escape sequence (backslash followed by `n`).

  A double quote is considered escaped only when it is preceded by an odd
  number of consecutive backslashes, so a string value that ends with an
  escaped backslash (e.g. `"C:\\\\"`) is closed correctly.

  Args:
    json_str: The raw text responded by the LM.

  Returns:
    The cleaned JSON string, starting with `{` and ending with its matching
    `}`.

  Raises:
    ValueError: If `json_str` contains no `{` at all, or if the top-level dict
      is never closed (including the case of an unterminated string value).
  """
  curly_brackets = 0
  under_json = False
  under_str = False
  # Number of consecutive backslashes immediately preceding the current char.
  backslash_run = 0

  cleaned = io.StringIO()
  for c in json_str:
    # Parity matters: in `\\"` the backslashes escape each other and the quote
    # is a real delimiter, while in `\"` the quote itself is escaped.
    quote_is_escaped = backslash_run % 2 == 1
    backslash_run = backslash_run + 1 if c == '\\' else 0

    if under_str:
      # Inside a string value: braces are literal text, newlines get escaped.
      cleaned.write('\\n' if c == '\n' else c)
      if c == '"' and not quote_is_escaped:
        under_str = False
      continue

    if c == '{':
      cleaned.write(c)
      curly_brackets += 1
      under_json = True
    elif not under_json:
      # Skip everything before the first opening curly brace.
      continue
    elif c == '}':
      cleaned.write(c)
      curly_brackets -= 1
      if curly_brackets == 0:
        # The top-level dict is complete; ignore any trailing text.
        break
    else:
      cleaned.write(c)
      if c == '"' and not quote_is_escaped:
        under_str = True

  if not under_json:
    raise ValueError(f'No JSON dict in the output: {json_str}')

  if curly_brackets > 0:
    raise ValueError(
        f'Malformated JSON: missing {curly_brackets} closing curly braces.'
    )

  return cleaned.getvalue()
|
853
844
|
|
854
845
|
|
855
846
|
def schema_repr(protocol: SchemaProtocol) -> SchemaRepr:
|
@@ -17,7 +17,7 @@ from typing import Any, Type, Union
|
|
17
17
|
|
18
18
|
import langfun.core as lf
|
19
19
|
from langfun.core.structured import mapping
|
20
|
-
from langfun.core.structured import
|
20
|
+
from langfun.core.structured import querying
|
21
21
|
from langfun.core.structured import schema as schema_lib
|
22
22
|
import pyglove as pg
|
23
23
|
|
@@ -101,7 +101,7 @@ def score(
|
|
101
101
|
prompts = []
|
102
102
|
for p in prompt:
|
103
103
|
prompts.append(
|
104
|
-
|
104
|
+
querying.query_prompt(
|
105
105
|
p,
|
106
106
|
schema,
|
107
107
|
examples=examples,
|
@@ -111,7 +111,7 @@ def score(
|
|
111
111
|
)
|
112
112
|
input_message = prompts
|
113
113
|
else:
|
114
|
-
input_message =
|
114
|
+
input_message = querying.query_prompt(
|
115
115
|
prompt,
|
116
116
|
schema,
|
117
117
|
examples=examples,
|
@@ -17,7 +17,7 @@ from typing import Any, Type, Union
|
|
17
17
|
|
18
18
|
import langfun.core as lf
|
19
19
|
from langfun.core.structured import mapping
|
20
|
-
from langfun.core.structured import
|
20
|
+
from langfun.core.structured import querying
|
21
21
|
from langfun.core.structured import schema as schema_lib
|
22
22
|
import pyglove as pg
|
23
23
|
|
@@ -48,7 +48,7 @@ def tokenize(
|
|
48
48
|
Returns:
|
49
49
|
A list of (text, token_id) tuples.
|
50
50
|
"""
|
51
|
-
input_message =
|
51
|
+
input_message = querying.query_prompt(
|
52
52
|
prompt,
|
53
53
|
schema,
|
54
54
|
examples=examples,
|