langfun 0.1.2.dev202412150804__py3-none-any.whl → 0.1.2.dev202412170805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/coding/python/correction.py +2 -2
- langfun/core/eval/v2/__init__.py +5 -1
- langfun/core/eval/v2/checkpointing.py +96 -16
- langfun/core/eval/v2/checkpointing_test.py +41 -8
- langfun/core/eval/v2/runners.py +1 -1
- langfun/core/structured/__init__.py +7 -22
- langfun/core/structured/completion.py +2 -2
- langfun/core/structured/completion_test.py +4 -4
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +3 -3
- langfun/core/structured/parsing.py +9 -9
- langfun/core/structured/parsing_test.py +8 -8
- langfun/core/structured/{prompting.py → querying.py} +9 -9
- langfun/core/structured/{prompting_test.py → querying_test.py} +51 -51
- langfun/core/structured/schema.py +51 -50
- langfun/core/structured/scoring.py +3 -3
- langfun/core/structured/tokenization.py +2 -2
- {langfun-0.1.2.dev202412150804.dist-info → langfun-0.1.2.dev202412170805.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202412150804.dist-info → langfun-0.1.2.dev202412170805.dist-info}/RECORD +23 -23
- {langfun-0.1.2.dev202412150804.dist-info → langfun-0.1.2.dev202412170805.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202412150804.dist-info → langfun-0.1.2.dev202412170805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202412150804.dist-info → langfun-0.1.2.dev202412170805.dist-info}/top_level.txt +0 -0
@@ -11,7 +11,7 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
"""
|
14
|
+
"""Query LLM for structured output."""
|
15
15
|
|
16
16
|
import contextlib
|
17
17
|
import functools
|
@@ -26,7 +26,7 @@ import pyglove as pg
|
|
26
26
|
|
27
27
|
|
28
28
|
@lf.use_init_args(['schema', 'default', 'examples'])
|
29
|
-
class
|
29
|
+
class _QueryStructure(mapping.Mapping):
|
30
30
|
"""Query an object out from a natural language text."""
|
31
31
|
|
32
32
|
context_title = 'CONTEXT'
|
@@ -38,7 +38,7 @@ class QueryStructure(mapping.Mapping):
|
|
38
38
|
]
|
39
39
|
|
40
40
|
|
41
|
-
class
|
41
|
+
class _QueryStructureJson(_QueryStructure):
|
42
42
|
"""Query a structured value using JSON as the protocol."""
|
43
43
|
|
44
44
|
preamble = """
|
@@ -52,10 +52,10 @@ class QueryStructureJson(QueryStructure):
|
|
52
52
|
1 + 1 =
|
53
53
|
|
54
54
|
{{ schema_title }}:
|
55
|
-
{"result": {"_type": "langfun.core.structured.
|
55
|
+
{"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}}
|
56
56
|
|
57
57
|
{{ output_title}}:
|
58
|
-
{"result": {"_type": "langfun.core.structured.
|
58
|
+
{"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}}
|
59
59
|
"""
|
60
60
|
|
61
61
|
protocol = 'json'
|
@@ -63,7 +63,7 @@ class QueryStructureJson(QueryStructure):
|
|
63
63
|
output_title = 'JSON'
|
64
64
|
|
65
65
|
|
66
|
-
class
|
66
|
+
class _QueryStructurePython(_QueryStructure):
|
67
67
|
"""Query a structured value using Python as the protocol."""
|
68
68
|
|
69
69
|
preamble = """
|
@@ -94,11 +94,11 @@ class QueryStructurePython(QueryStructure):
|
|
94
94
|
|
95
95
|
def _query_structure_cls(
|
96
96
|
protocol: schema_lib.SchemaProtocol,
|
97
|
-
) -> Type[
|
97
|
+
) -> Type[_QueryStructure]:
|
98
98
|
if protocol == 'json':
|
99
|
-
return
|
99
|
+
return _QueryStructureJson
|
100
100
|
elif protocol == 'python':
|
101
|
-
return
|
101
|
+
return _QueryStructurePython
|
102
102
|
else:
|
103
103
|
raise ValueError(f'Unknown protocol: {protocol!r}.')
|
104
104
|
|
@@ -11,7 +11,7 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
"""Tests for structured
|
14
|
+
"""Tests for structured query."""
|
15
15
|
|
16
16
|
import inspect
|
17
17
|
import math
|
@@ -23,7 +23,7 @@ from langfun.core import modalities
|
|
23
23
|
from langfun.core.llms import fake
|
24
24
|
from langfun.core.llms.cache import in_memory
|
25
25
|
from langfun.core.structured import mapping
|
26
|
-
from langfun.core.structured import
|
26
|
+
from langfun.core.structured import querying
|
27
27
|
import pyglove as pg
|
28
28
|
|
29
29
|
|
@@ -51,7 +51,7 @@ class QueryTest(unittest.TestCase):
|
|
51
51
|
expected_modalities: int = 0,
|
52
52
|
**kwargs,
|
53
53
|
):
|
54
|
-
m =
|
54
|
+
m = querying.query(
|
55
55
|
prompt, schema=schema, examples=examples,
|
56
56
|
**kwargs, returns_message=True
|
57
57
|
)
|
@@ -67,14 +67,14 @@ class QueryTest(unittest.TestCase):
|
|
67
67
|
|
68
68
|
def test_call(self):
|
69
69
|
lm = fake.StaticSequence(['1'])
|
70
|
-
self.assertEqual(
|
70
|
+
self.assertEqual(querying.query('what is 1 + 0', int, lm=lm), 1)
|
71
71
|
|
72
72
|
# Testing calling the same `lm` without copy.
|
73
73
|
with self.assertRaises(IndexError):
|
74
|
-
|
74
|
+
querying.query('what is 1 + 2', int, lm=lm)
|
75
75
|
|
76
76
|
self.assertEqual(
|
77
|
-
|
77
|
+
querying.query(
|
78
78
|
'what is 1 + 0', int, lm=lm.clone(), returns_message=True
|
79
79
|
),
|
80
80
|
lf.AIMessage(
|
@@ -88,17 +88,17 @@ class QueryTest(unittest.TestCase):
|
|
88
88
|
),
|
89
89
|
)
|
90
90
|
self.assertEqual(
|
91
|
-
|
91
|
+
querying.query(
|
92
92
|
lf.Template('what is {{x}} + {{y}}', x=1, y=0), int, lm=lm.clone()
|
93
93
|
),
|
94
94
|
1,
|
95
95
|
)
|
96
96
|
self.assertEqual(
|
97
|
-
|
97
|
+
querying.query('what is {{x}} + {{y}}', int, x=1, y=0, lm=lm.clone()),
|
98
98
|
1,
|
99
99
|
)
|
100
100
|
self.assertEqual(
|
101
|
-
|
101
|
+
querying.query(
|
102
102
|
'what is {{x}} + {{y}}',
|
103
103
|
x=1,
|
104
104
|
y=0,
|
@@ -107,7 +107,7 @@ class QueryTest(unittest.TestCase):
|
|
107
107
|
'The answer is one.',
|
108
108
|
)
|
109
109
|
self.assertEqual(
|
110
|
-
|
110
|
+
querying.query(
|
111
111
|
Activity.partial(),
|
112
112
|
lm=fake.StaticResponse('Activity(description="hello")'),
|
113
113
|
),
|
@@ -329,11 +329,11 @@ class QueryTest(unittest.TestCase):
|
|
329
329
|
|
330
330
|
def test_bad_protocol(self):
|
331
331
|
with self.assertRaisesRegex(ValueError, 'Unknown protocol'):
|
332
|
-
|
332
|
+
querying.query('what is 1 + 1', int, protocol='text')
|
333
333
|
|
334
334
|
def test_query_prompt(self):
|
335
335
|
self.assertEqual(
|
336
|
-
|
336
|
+
querying.query_prompt('what is this?', int),
|
337
337
|
inspect.cleandoc("""
|
338
338
|
Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT according to OUTPUT_TYPE.
|
339
339
|
|
@@ -368,14 +368,14 @@ class QueryTest(unittest.TestCase):
|
|
368
368
|
def test_query_prompt_with_metadata(self):
|
369
369
|
self.assertIn(
|
370
370
|
'x',
|
371
|
-
|
371
|
+
querying.query_prompt(
|
372
372
|
'what is this?',
|
373
373
|
metadata_x=1
|
374
374
|
).metadata
|
375
375
|
)
|
376
376
|
self.assertIn(
|
377
377
|
'x',
|
378
|
-
|
378
|
+
querying.query_prompt(
|
379
379
|
'what is this?',
|
380
380
|
int,
|
381
381
|
metadata_x=1
|
@@ -383,7 +383,7 @@ class QueryTest(unittest.TestCase):
|
|
383
383
|
)
|
384
384
|
|
385
385
|
def test_query_prompt_with_unrooted_template(self):
|
386
|
-
output =
|
386
|
+
output = querying.query_prompt(
|
387
387
|
pg.Dict(
|
388
388
|
input=lf.Template(
|
389
389
|
'what is {{image}}',
|
@@ -395,7 +395,7 @@ class QueryTest(unittest.TestCase):
|
|
395
395
|
|
396
396
|
def test_query_output(self):
|
397
397
|
self.assertEqual(
|
398
|
-
|
398
|
+
querying.query_output(
|
399
399
|
lf.AIMessage('1'),
|
400
400
|
int,
|
401
401
|
),
|
@@ -414,7 +414,7 @@ class QueryTest(unittest.TestCase):
|
|
414
414
|
|
415
415
|
# Case 1: Reward function based on input and output.
|
416
416
|
self.assertEqual(
|
417
|
-
|
417
|
+
querying.query_reward(
|
418
418
|
mapping.MappingExample(
|
419
419
|
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
420
420
|
schema=Answer,
|
@@ -425,7 +425,7 @@ class QueryTest(unittest.TestCase):
|
|
425
425
|
1.0
|
426
426
|
)
|
427
427
|
self.assertEqual(
|
428
|
-
|
428
|
+
querying.query_reward(
|
429
429
|
mapping.MappingExample(
|
430
430
|
input=lf.Template('{{x}} + {{y}}', x=2, y=3),
|
431
431
|
output=Answer(final_answer=2),
|
@@ -445,7 +445,7 @@ class QueryTest(unittest.TestCase):
|
|
445
445
|
)
|
446
446
|
|
447
447
|
self.assertEqual(
|
448
|
-
|
448
|
+
querying.query_reward(
|
449
449
|
mapping.MappingExample(
|
450
450
|
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
451
451
|
output=Answer2(final_answer=2),
|
@@ -470,7 +470,7 @@ class QueryTest(unittest.TestCase):
|
|
470
470
|
) * metadata['weight']
|
471
471
|
|
472
472
|
self.assertEqual(
|
473
|
-
|
473
|
+
querying.query_reward(
|
474
474
|
mapping.MappingExample(
|
475
475
|
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
476
476
|
output=Answer3(final_answer=2),
|
@@ -486,7 +486,7 @@ class QueryTest(unittest.TestCase):
|
|
486
486
|
final_answer: int
|
487
487
|
|
488
488
|
self.assertIsNone(
|
489
|
-
|
489
|
+
querying.query_reward(
|
490
490
|
mapping.MappingExample(
|
491
491
|
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
492
492
|
output=Answer4(final_answer=2),
|
@@ -497,7 +497,7 @@ class QueryTest(unittest.TestCase):
|
|
497
497
|
|
498
498
|
# Case 5: Not a structured output.
|
499
499
|
self.assertIsNone(
|
500
|
-
|
500
|
+
querying.query_reward(
|
501
501
|
mapping.MappingExample(
|
502
502
|
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
503
503
|
output='2',
|
@@ -516,7 +516,7 @@ class QueryTest(unittest.TestCase):
|
|
516
516
|
with self.assertRaisesRegex(
|
517
517
|
TypeError, '.*Answer5.__reward__` should have signature'
|
518
518
|
):
|
519
|
-
|
519
|
+
querying.query_reward(
|
520
520
|
mapping.MappingExample(
|
521
521
|
input=lf.Template('{{x}} + {{y}}', x=1, y=1),
|
522
522
|
output=Answer5(final_answer=2),
|
@@ -528,7 +528,7 @@ class QueryTest(unittest.TestCase):
|
|
528
528
|
class QueryStructurePythonTest(unittest.TestCase):
|
529
529
|
|
530
530
|
def test_render_no_examples(self):
|
531
|
-
l =
|
531
|
+
l = querying._QueryStructurePython(
|
532
532
|
input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int
|
533
533
|
)
|
534
534
|
self.assertEqual(
|
@@ -565,7 +565,7 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
565
565
|
)
|
566
566
|
|
567
567
|
def test_render(self):
|
568
|
-
l =
|
568
|
+
l = querying._QueryStructurePython(
|
569
569
|
input=lf.AIMessage('Compute 12 / 6 + 2.'),
|
570
570
|
schema=int,
|
571
571
|
examples=[
|
@@ -675,7 +675,7 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
675
675
|
),
|
676
676
|
override_attrs=True,
|
677
677
|
):
|
678
|
-
l =
|
678
|
+
l = querying._QueryStructurePython(
|
679
679
|
input=lm_input,
|
680
680
|
schema=[Itinerary],
|
681
681
|
examples=[
|
@@ -712,7 +712,7 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
712
712
|
mapping.MappingError,
|
713
713
|
'name .* is not defined',
|
714
714
|
):
|
715
|
-
|
715
|
+
querying.query('Compute 1 + 2', int)
|
716
716
|
|
717
717
|
def test_autofix(self):
|
718
718
|
lm = fake.StaticSequence([
|
@@ -723,7 +723,7 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
723
723
|
)
|
724
724
|
"""),
|
725
725
|
])
|
726
|
-
self.assertEqual(
|
726
|
+
self.assertEqual(querying.query('what is 1 + 0', int, lm=lm, autofix=3), 1)
|
727
727
|
|
728
728
|
def test_response_postprocess(self):
|
729
729
|
with lf.context(
|
@@ -731,12 +731,12 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
731
731
|
override_attrs=True,
|
732
732
|
):
|
733
733
|
self.assertEqual(
|
734
|
-
|
734
|
+
querying.query(
|
735
735
|
'Compute 1 + 2', response_postprocess=lambda x: x.split('\n')[1]),
|
736
736
|
'3'
|
737
737
|
)
|
738
738
|
self.assertEqual(
|
739
|
-
|
739
|
+
querying.query(
|
740
740
|
'Compute 1 + 2', int,
|
741
741
|
response_postprocess=lambda x: x.split('\n')[1]),
|
742
742
|
3
|
@@ -746,7 +746,7 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
746
746
|
class QueryStructureJsonTest(unittest.TestCase):
|
747
747
|
|
748
748
|
def test_render_no_examples(self):
|
749
|
-
l =
|
749
|
+
l = querying._QueryStructureJson(
|
750
750
|
input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int
|
751
751
|
)
|
752
752
|
self.assertEqual(
|
@@ -762,10 +762,10 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
762
762
|
1 + 1 =
|
763
763
|
|
764
764
|
SCHEMA:
|
765
|
-
{"result": {"_type": "langfun.core.structured.
|
765
|
+
{"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}}
|
766
766
|
|
767
767
|
JSON:
|
768
|
-
{"result": {"_type": "langfun.core.structured.
|
768
|
+
{"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}}
|
769
769
|
|
770
770
|
INPUT_OBJECT:
|
771
771
|
Compute 12 / 6 + 2.
|
@@ -778,7 +778,7 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
778
778
|
)
|
779
779
|
|
780
780
|
def test_render(self):
|
781
|
-
l =
|
781
|
+
l = querying._QueryStructureJson(
|
782
782
|
input=lf.AIMessage('Compute 12 / 6 + 2.'),
|
783
783
|
schema=int,
|
784
784
|
examples=[
|
@@ -799,10 +799,10 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
799
799
|
1 + 1 =
|
800
800
|
|
801
801
|
SCHEMA:
|
802
|
-
{"result": {"_type": "langfun.core.structured.
|
802
|
+
{"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}}
|
803
803
|
|
804
804
|
JSON:
|
805
|
-
{"result": {"_type": "langfun.core.structured.
|
805
|
+
{"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}}
|
806
806
|
|
807
807
|
INPUT_OBJECT:
|
808
808
|
What is the answer of 1 plus 1?
|
@@ -913,7 +913,7 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
913
913
|
),
|
914
914
|
override_attrs=True,
|
915
915
|
):
|
916
|
-
l =
|
916
|
+
l = querying._QueryStructureJson(
|
917
917
|
input=lm_input,
|
918
918
|
schema=[Itinerary],
|
919
919
|
examples=[
|
@@ -951,14 +951,14 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
951
951
|
mapping.MappingError,
|
952
952
|
'No JSON dict in the output',
|
953
953
|
):
|
954
|
-
|
954
|
+
querying.query('Compute 1 + 2', int, protocol='json', cache_seed=1)
|
955
955
|
# Make sure bad mapping does not impact cache.
|
956
956
|
self.assertEqual(len(cache), 0)
|
957
957
|
|
958
958
|
def test_query(self):
|
959
959
|
lm = fake.StaticSequence(['{"result": 1}'])
|
960
960
|
self.assertEqual(
|
961
|
-
|
961
|
+
querying.query('what is 1 + 0', int, lm=lm, protocol='json'), 1
|
962
962
|
)
|
963
963
|
|
964
964
|
|
@@ -968,8 +968,8 @@ class QueryInvocationTest(unittest.TestCase):
|
|
968
968
|
lm = fake.StaticSequence([
|
969
969
|
'Activity(description="hi")',
|
970
970
|
])
|
971
|
-
with
|
972
|
-
|
971
|
+
with querying.track_queries() as queries:
|
972
|
+
querying.query('foo', Activity, lm=lm)
|
973
973
|
|
974
974
|
self.assertIn('schema', queries[0].to_html_str())
|
975
975
|
|
@@ -981,10 +981,10 @@ class TrackQueriesTest(unittest.TestCase):
|
|
981
981
|
'bar',
|
982
982
|
'Activity(description="hi")',
|
983
983
|
])
|
984
|
-
with
|
985
|
-
|
986
|
-
with
|
987
|
-
|
984
|
+
with querying.track_queries() as queries:
|
985
|
+
querying.query('foo', lm=lm)
|
986
|
+
with querying.track_queries() as child_queries:
|
987
|
+
querying.query('give me an activity', Activity, lm=lm)
|
988
988
|
|
989
989
|
self.assertEqual(len(queries), 2)
|
990
990
|
self.assertTrue(pg.eq(queries[0].input, lf.Template('foo')))
|
@@ -1008,10 +1008,10 @@ class TrackQueriesTest(unittest.TestCase):
|
|
1008
1008
|
'bar',
|
1009
1009
|
'Activity(description="hi")',
|
1010
1010
|
])
|
1011
|
-
with
|
1012
|
-
|
1013
|
-
with
|
1014
|
-
|
1011
|
+
with querying.track_queries(include_child_scopes=False) as queries:
|
1012
|
+
querying.query('foo', lm=lm)
|
1013
|
+
with querying.track_queries(include_child_scopes=False) as child_queries:
|
1014
|
+
querying.query('give me an activity', Activity, lm=lm)
|
1015
1015
|
|
1016
1016
|
self.assertEqual(len(queries), 1)
|
1017
1017
|
self.assertTrue(pg.eq(queries[0].input, lf.Template('foo')))
|
@@ -1030,13 +1030,13 @@ class TrackQueriesTest(unittest.TestCase):
|
|
1030
1030
|
def test_concurrent_map(self):
|
1031
1031
|
|
1032
1032
|
def make_query(prompt):
|
1033
|
-
_ =
|
1033
|
+
_ = querying.query(prompt, lm=lm)
|
1034
1034
|
|
1035
1035
|
lm = fake.StaticSequence([
|
1036
1036
|
'foo',
|
1037
1037
|
'bar',
|
1038
1038
|
])
|
1039
|
-
with
|
1039
|
+
with querying.track_queries() as queries:
|
1040
1040
|
list(lf.concurrent_map(make_query, ['a', 'b']))
|
1041
1041
|
self.assertEqual(len(queries), 2)
|
1042
1042
|
|
@@ -388,9 +388,9 @@ class SchemaPythonRepr(SchemaRepr):
|
|
388
388
|
return annotation(schema.spec)
|
389
389
|
|
390
390
|
|
391
|
-
def source_form(value, markdown: bool = False) -> str:
|
391
|
+
def source_form(value, compact: bool = True, markdown: bool = False) -> str:
|
392
392
|
"""Returns the source code form of an object."""
|
393
|
-
return ValuePythonRepr().repr(value, markdown=markdown)
|
393
|
+
return ValuePythonRepr().repr(value, compact=compact, markdown=markdown)
|
394
394
|
|
395
395
|
|
396
396
|
def class_definitions(
|
@@ -789,7 +789,7 @@ class ValueJsonRepr(ValueRepr):
|
|
789
789
|
"""Parse a JSON string into a structured object."""
|
790
790
|
del schema
|
791
791
|
try:
|
792
|
-
text =
|
792
|
+
text = cleanup_json(text)
|
793
793
|
v = pg.from_json_str(text, **kwargs)
|
794
794
|
except Exception as e:
|
795
795
|
raise JsonError(text, e) # pylint: disable=raise-missing-from
|
@@ -801,55 +801,56 @@ class ValueJsonRepr(ValueRepr):
|
|
801
801
|
))
|
802
802
|
return v['result']
|
803
803
|
|
804
|
-
def cleanup_json(self, json_str: str) -> str:
|
805
|
-
"""Clean up the LM responded JSON string."""
|
806
|
-
# Treatments:
|
807
|
-
# 1. Extract the JSON string with a top-level dict from the response.
|
808
|
-
# This prevents the leading and trailing texts in the response to
|
809
|
-
# be counted as part of the JSON.
|
810
|
-
# 2. Escape new lines in JSON values.
|
811
|
-
|
812
|
-
curly_brackets = 0
|
813
|
-
under_json = False
|
814
|
-
under_str = False
|
815
|
-
str_begin = -1
|
816
|
-
|
817
|
-
cleaned = io.StringIO()
|
818
|
-
for i, c in enumerate(json_str):
|
819
|
-
if c == '{' and not under_str:
|
820
|
-
cleaned.write(c)
|
821
|
-
curly_brackets += 1
|
822
|
-
under_json = True
|
823
|
-
continue
|
824
|
-
elif not under_json:
|
825
|
-
continue
|
826
804
|
|
827
|
-
|
828
|
-
|
829
|
-
|
830
|
-
|
831
|
-
|
832
|
-
|
833
|
-
|
834
|
-
|
835
|
-
|
836
|
-
|
837
|
-
|
838
|
-
|
839
|
-
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
|
844
|
-
|
845
|
-
|
846
|
-
|
847
|
-
|
848
|
-
|
849
|
-
|
850
|
-
|
805
|
+
def cleanup_json(json_str: str) -> str:
|
806
|
+
"""Clean up the LM responded JSON string."""
|
807
|
+
# Treatments:
|
808
|
+
# 1. Extract the JSON string with a top-level dict from the response.
|
809
|
+
# This prevents the leading and trailing texts in the response to
|
810
|
+
# be counted as part of the JSON.
|
811
|
+
# 2. Escape new lines in JSON values.
|
812
|
+
|
813
|
+
curly_brackets = 0
|
814
|
+
under_json = False
|
815
|
+
under_str = False
|
816
|
+
str_begin = -1
|
817
|
+
|
818
|
+
cleaned = io.StringIO()
|
819
|
+
for i, c in enumerate(json_str):
|
820
|
+
if c == '{' and not under_str:
|
821
|
+
cleaned.write(c)
|
822
|
+
curly_brackets += 1
|
823
|
+
under_json = True
|
824
|
+
continue
|
825
|
+
elif not under_json:
|
826
|
+
continue
|
827
|
+
|
828
|
+
if c == '}' and not under_str:
|
829
|
+
cleaned.write(c)
|
830
|
+
curly_brackets -= 1
|
831
|
+
if curly_brackets == 0:
|
832
|
+
break
|
833
|
+
elif c == '"' and json_str[i - 1] != '\\':
|
834
|
+
under_str = not under_str
|
835
|
+
if under_str:
|
836
|
+
str_begin = i
|
837
|
+
else:
|
838
|
+
assert str_begin > 0
|
839
|
+
str_value = json_str[str_begin : i + 1].replace('\n', '\\n')
|
840
|
+
cleaned.write(str_value)
|
841
|
+
str_begin = -1
|
842
|
+
elif not under_str:
|
843
|
+
cleaned.write(c)
|
844
|
+
|
845
|
+
if not under_json:
|
846
|
+
raise ValueError(f'No JSON dict in the output: {json_str}')
|
847
|
+
|
848
|
+
if curly_brackets > 0:
|
849
|
+
raise ValueError(
|
850
|
+
f'Malformated JSON: missing {curly_brackets} closing curly braces.'
|
851
|
+
)
|
851
852
|
|
852
|
-
|
853
|
+
return cleaned.getvalue()
|
853
854
|
|
854
855
|
|
855
856
|
def schema_repr(protocol: SchemaProtocol) -> SchemaRepr:
|
@@ -17,7 +17,7 @@ from typing import Any, Type, Union
|
|
17
17
|
|
18
18
|
import langfun.core as lf
|
19
19
|
from langfun.core.structured import mapping
|
20
|
-
from langfun.core.structured import
|
20
|
+
from langfun.core.structured import querying
|
21
21
|
from langfun.core.structured import schema as schema_lib
|
22
22
|
import pyglove as pg
|
23
23
|
|
@@ -101,7 +101,7 @@ def score(
|
|
101
101
|
prompts = []
|
102
102
|
for p in prompt:
|
103
103
|
prompts.append(
|
104
|
-
|
104
|
+
querying.query_prompt(
|
105
105
|
p,
|
106
106
|
schema,
|
107
107
|
examples=examples,
|
@@ -111,7 +111,7 @@ def score(
|
|
111
111
|
)
|
112
112
|
input_message = prompts
|
113
113
|
else:
|
114
|
-
input_message =
|
114
|
+
input_message = querying.query_prompt(
|
115
115
|
prompt,
|
116
116
|
schema,
|
117
117
|
examples=examples,
|
@@ -17,7 +17,7 @@ from typing import Any, Type, Union
|
|
17
17
|
|
18
18
|
import langfun.core as lf
|
19
19
|
from langfun.core.structured import mapping
|
20
|
-
from langfun.core.structured import
|
20
|
+
from langfun.core.structured import querying
|
21
21
|
from langfun.core.structured import schema as schema_lib
|
22
22
|
import pyglove as pg
|
23
23
|
|
@@ -48,7 +48,7 @@ def tokenize(
|
|
48
48
|
Returns:
|
49
49
|
A list of (text, token_id) tuples.
|
50
50
|
"""
|
51
|
-
input_message =
|
51
|
+
input_message = querying.query_prompt(
|
52
52
|
prompt,
|
53
53
|
schema,
|
54
54
|
examples=examples,
|