langfun 0.1.2.dev202412140804__py3-none-any.whl → 0.1.2.dev202412170805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/coding/python/correction.py +2 -2
- langfun/core/eval/v2/__init__.py +5 -1
- langfun/core/eval/v2/checkpointing.py +96 -16
- langfun/core/eval/v2/checkpointing_test.py +41 -8
- langfun/core/eval/v2/runners.py +1 -1
- langfun/core/structured/__init__.py +7 -22
- langfun/core/structured/completion.py +2 -2
- langfun/core/structured/completion_test.py +4 -4
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +3 -3
- langfun/core/structured/parsing.py +9 -9
- langfun/core/structured/parsing_test.py +8 -8
- langfun/core/structured/{prompting.py → querying.py} +9 -9
- langfun/core/structured/{prompting_test.py → querying_test.py} +51 -51
- langfun/core/structured/schema.py +51 -50
- langfun/core/structured/scoring.py +3 -3
- langfun/core/structured/tokenization.py +2 -2
- {langfun-0.1.2.dev202412140804.dist-info → langfun-0.1.2.dev202412170805.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202412140804.dist-info → langfun-0.1.2.dev202412170805.dist-info}/RECORD +23 -23
- {langfun-0.1.2.dev202412140804.dist-info → langfun-0.1.2.dev202412170805.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202412140804.dist-info → langfun-0.1.2.dev202412170805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202412140804.dist-info → langfun-0.1.2.dev202412170805.dist-info}/top_level.txt +0 -0
@@ -76,7 +76,7 @@ def run_with_correction(
|
|
76
76
|
# Delay import at runtime to avoid circular dependency.
|
77
77
|
# pylint: disable=g-import-not-at-top
|
78
78
|
# pytype: disable=import-error
|
79
|
-
from langfun.core.structured import
|
79
|
+
from langfun.core.structured import querying
|
80
80
|
# pytype: enable=import-error
|
81
81
|
# pylint: enable=g-import-not-at-top
|
82
82
|
|
@@ -119,7 +119,7 @@ def run_with_correction(
|
|
119
119
|
# structure.
|
120
120
|
try:
|
121
121
|
# Disable autofix for code correction to avoid recursion.
|
122
|
-
correction =
|
122
|
+
correction = querying.query(
|
123
123
|
CodeWithError(code=code, error=error), CorrectedCode, lm=lm, autofix=0
|
124
124
|
)
|
125
125
|
except errors.CodeError:
|
langfun/core/eval/v2/__init__.py
CHANGED
@@ -29,10 +29,14 @@ from langfun.core.eval.v2.metrics import Metric
|
|
29
29
|
from langfun.core.eval.v2 import metrics
|
30
30
|
|
31
31
|
from langfun.core.eval.v2.experiment import Plugin
|
32
|
-
|
33
32
|
from langfun.core.eval.v2.experiment import Runner
|
34
33
|
from langfun.core.eval.v2 import runners
|
35
34
|
|
35
|
+
# Plugins
|
36
|
+
from langfun.core.eval.v2.checkpointing import BulkCheckpointer
|
37
|
+
from langfun.core.eval.v2.checkpointing import PerExampleCheckpointer
|
38
|
+
from langfun.core.eval.v2.reporting import HtmlReporter
|
39
|
+
|
36
40
|
|
37
41
|
# pylint: enable=g-bad-import-order
|
38
42
|
# pylint: enable=g-importing-member
|
@@ -14,6 +14,7 @@
|
|
14
14
|
"""Checkpointing evaluation runs."""
|
15
15
|
import threading
|
16
16
|
|
17
|
+
import langfun.core as lf
|
17
18
|
from langfun.core.eval.v2 import example as example_lib
|
18
19
|
from langfun.core.eval.v2 import experiment as experiment_lib
|
19
20
|
import pyglove as pg
|
@@ -24,21 +25,100 @@ Runner = experiment_lib.Runner
|
|
24
25
|
|
25
26
|
|
26
27
|
class Checkpointer(experiment_lib.Plugin):
|
27
|
-
"""
|
28
|
+
"""Base class for checkpointing evaluation examples."""
|
29
|
+
|
30
|
+
|
31
|
+
class PerExampleCheckpointer(Checkpointer):
|
32
|
+
"""Checkpointer that saves each example to a separate file."""
|
33
|
+
|
34
|
+
checkpoint_filename: str = 'checkpoint.bagz'
|
35
|
+
|
36
|
+
def _on_bound(self):
|
37
|
+
super()._on_bound()
|
38
|
+
prefix, ext = self._file_prefix_and_ext(self.checkpoint_filename)
|
39
|
+
self._checkpoint_file_prefix = prefix
|
40
|
+
self._checkpoint_file_ext = ext
|
41
|
+
|
42
|
+
def on_experiment_start(
|
43
|
+
self,
|
44
|
+
runner: Runner,
|
45
|
+
experiment: Experiment,
|
46
|
+
) -> None:
|
47
|
+
"""Creates the checkpoint file."""
|
48
|
+
if not experiment.is_leaf:
|
49
|
+
return
|
50
|
+
|
51
|
+
# For refresh runs, we don't want to load the previous state.
|
52
|
+
if not runner.current_run.refresh:
|
53
|
+
def _load_state(ckpt_file):
|
54
|
+
experiment.load_state(ckpt_file)
|
55
|
+
|
56
|
+
experiment_dir = runner.current_run.input_dir(experiment)
|
57
|
+
if pg.io.path_exists(experiment_dir):
|
58
|
+
ckpt_files = [
|
59
|
+
runner.current_run.input_path_for(experiment, filename)
|
60
|
+
for filename in pg.io.listdir(experiment_dir)
|
61
|
+
if filename.startswith(self._checkpoint_file_prefix)
|
62
|
+
and filename.endswith(self._checkpoint_file_ext)
|
63
|
+
]
|
64
|
+
else:
|
65
|
+
ckpt_files = []
|
66
|
+
|
67
|
+
for ckpt_file, _, error in lf.concurrent_map(
|
68
|
+
_load_state, ckpt_files, max_workers=64,
|
69
|
+
):
|
70
|
+
if error is not None:
|
71
|
+
pg.logging.warning(
|
72
|
+
'Failed to load checkpoint file %s: %s. Skipping the file.',
|
73
|
+
ckpt_file, error
|
74
|
+
)
|
75
|
+
|
76
|
+
def on_example_complete(
|
77
|
+
self,
|
78
|
+
runner: Runner,
|
79
|
+
experiment: Experiment,
|
80
|
+
example: Example,
|
81
|
+
) -> None:
|
82
|
+
"""Saves the example to the checkpoint file."""
|
83
|
+
if not example.has_error:
|
84
|
+
def save_state(example: Example):
|
85
|
+
writer = SequenceWriter(
|
86
|
+
runner.current_run.output_path_for(
|
87
|
+
experiment,
|
88
|
+
(
|
89
|
+
f'{self._checkpoint_file_prefix}_{example.id}'
|
90
|
+
f'{self._checkpoint_file_ext}'
|
91
|
+
)
|
92
|
+
)
|
93
|
+
)
|
94
|
+
writer.add(example)
|
95
|
+
del writer
|
96
|
+
runner.background_run(save_state, example)
|
97
|
+
|
98
|
+
def _file_prefix_and_ext(self, filename: str) -> tuple[str, str]:
|
99
|
+
ext_index = filename.rfind('.')
|
100
|
+
if ext_index == -1:
|
101
|
+
return filename, ''
|
102
|
+
else:
|
103
|
+
return filename[:ext_index], filename[ext_index:]
|
104
|
+
|
105
|
+
|
106
|
+
class BulkCheckpointer(Checkpointer):
|
107
|
+
"""Checkpointer that saves all examples to a single file."""
|
28
108
|
|
29
109
|
checkpoint_filename: str = 'checkpoint.bagz'
|
30
110
|
|
31
111
|
def _on_bound(self):
|
32
112
|
super()._on_bound()
|
33
113
|
self._lock = threading.Lock()
|
34
|
-
self.
|
114
|
+
self._sequence_writer = None
|
35
115
|
|
36
116
|
def on_run_start(
|
37
117
|
self,
|
38
118
|
runner: Runner,
|
39
119
|
root: Experiment,
|
40
120
|
) -> None:
|
41
|
-
self.
|
121
|
+
self._sequence_writer = {}
|
42
122
|
|
43
123
|
def on_run_abort(
|
44
124
|
self,
|
@@ -47,8 +127,8 @@ class Checkpointer(experiment_lib.Plugin):
|
|
47
127
|
error: BaseException
|
48
128
|
) -> None:
|
49
129
|
with self._lock:
|
50
|
-
if self.
|
51
|
-
self.
|
130
|
+
if self._sequence_writer is not None:
|
131
|
+
self._sequence_writer.clear()
|
52
132
|
|
53
133
|
def on_run_complete(
|
54
134
|
self,
|
@@ -56,7 +136,7 @@ class Checkpointer(experiment_lib.Plugin):
|
|
56
136
|
root: Experiment,
|
57
137
|
) -> None:
|
58
138
|
with self._lock:
|
59
|
-
assert self.
|
139
|
+
assert self._sequence_writer is not None and not self._sequence_writer
|
60
140
|
|
61
141
|
def on_experiment_start(
|
62
142
|
self,
|
@@ -74,14 +154,14 @@ class Checkpointer(experiment_lib.Plugin):
|
|
74
154
|
),
|
75
155
|
raise_if_not_exist=False
|
76
156
|
)
|
77
|
-
|
157
|
+
sequence_writer = SequenceWriter(
|
78
158
|
runner.current_run.output_path_for(
|
79
159
|
experiment, self.checkpoint_filename
|
80
160
|
)
|
81
161
|
)
|
82
162
|
with self._lock:
|
83
|
-
if self.
|
84
|
-
self.
|
163
|
+
if self._sequence_writer is not None:
|
164
|
+
self._sequence_writer[experiment.id] = sequence_writer
|
85
165
|
|
86
166
|
def on_experiment_complete(
|
87
167
|
self,
|
@@ -91,10 +171,10 @@ class Checkpointer(experiment_lib.Plugin):
|
|
91
171
|
"""Closes the checkpoint file."""
|
92
172
|
if not experiment.is_leaf:
|
93
173
|
return
|
94
|
-
assert experiment.id in self.
|
174
|
+
assert experiment.id in self._sequence_writer
|
95
175
|
with self._lock:
|
96
|
-
if self.
|
97
|
-
del self.
|
176
|
+
if self._sequence_writer is not None:
|
177
|
+
del self._sequence_writer[experiment.id]
|
98
178
|
|
99
179
|
def on_example_complete(
|
100
180
|
self,
|
@@ -103,13 +183,13 @@ class Checkpointer(experiment_lib.Plugin):
|
|
103
183
|
example: Example,
|
104
184
|
) -> None:
|
105
185
|
"""Saves the example to the checkpoint file."""
|
106
|
-
assert experiment.id in self.
|
186
|
+
assert experiment.id in self._sequence_writer
|
107
187
|
if not example.has_error:
|
108
|
-
runner.background_run(self.
|
188
|
+
runner.background_run(self._sequence_writer[experiment.id].add, example)
|
109
189
|
|
110
190
|
|
111
|
-
class
|
112
|
-
"""Thread safe
|
191
|
+
class SequenceWriter:
|
192
|
+
"""Thread safe sequence writer."""
|
113
193
|
|
114
194
|
def __init__(self, path: str):
|
115
195
|
self._lock = threading.Lock()
|
@@ -24,11 +24,11 @@ import pyglove as pg
|
|
24
24
|
Example = example_lib.Example
|
25
25
|
|
26
26
|
|
27
|
-
class
|
27
|
+
class SequenceWriterTest(unittest.TestCase):
|
28
28
|
|
29
29
|
def test_basic(self):
|
30
30
|
file = os.path.join(tempfile.gettempdir(), 'test.jsonl')
|
31
|
-
writer = checkpointing.
|
31
|
+
writer = checkpointing.SequenceWriter(file)
|
32
32
|
example = Example(id=1, input=pg.Dict(x=1), output=2)
|
33
33
|
writer.add(example)
|
34
34
|
del writer
|
@@ -36,7 +36,7 @@ class StateWriterTest(unittest.TestCase):
|
|
36
36
|
|
37
37
|
def test_error_handling(self):
|
38
38
|
file = os.path.join(tempfile.gettempdir(), 'test_error_handling.jsonl')
|
39
|
-
writer = checkpointing.
|
39
|
+
writer = checkpointing.SequenceWriter(file)
|
40
40
|
writer.add(Example(id=1, input=pg.Dict(x=1), output=2))
|
41
41
|
|
42
42
|
def f():
|
@@ -52,17 +52,50 @@ class StateWriterTest(unittest.TestCase):
|
|
52
52
|
self.assertEqual(len(list(iter(f))), 1)
|
53
53
|
|
54
54
|
|
55
|
-
class
|
55
|
+
class PerExampleCheckpointerTest(unittest.TestCase):
|
56
56
|
|
57
57
|
def test_checkpointing(self):
|
58
|
-
root_dir = os.path.join(tempfile.gettempdir(), '
|
58
|
+
root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer')
|
59
59
|
experiment = test_helper.test_experiment()
|
60
60
|
checkpoint_filename = 'checkpoint.jsonl'
|
61
|
-
checkpointer = checkpointing.
|
61
|
+
checkpointer = checkpointing.PerExampleCheckpointer(checkpoint_filename)
|
62
62
|
run = experiment.run(
|
63
63
|
root_dir, 'new', runner='sequential', plugins=[checkpointer]
|
64
64
|
)
|
65
|
-
|
65
|
+
num_processed = {}
|
66
|
+
for leaf in experiment.leaf_nodes:
|
67
|
+
for i in range(leaf.num_examples):
|
68
|
+
example = leaf.state.get(i + 1)
|
69
|
+
ckpt = run.output_path_for(leaf, f'checkpoint_{example.id}.jsonl')
|
70
|
+
if example.has_error:
|
71
|
+
self.assertFalse(pg.io.path_exists(ckpt))
|
72
|
+
else:
|
73
|
+
self.assertTrue(pg.io.path_exists(ckpt))
|
74
|
+
with pg.io.open_sequence(ckpt) as f:
|
75
|
+
self.assertEqual(len(list(iter(f))), 1)
|
76
|
+
if leaf.id not in num_processed:
|
77
|
+
self.assertEqual(leaf.progress.num_skipped, 0)
|
78
|
+
num_processed[leaf.id] = leaf.progress.num_processed
|
79
|
+
|
80
|
+
# Run again, should skip existing.
|
81
|
+
_ = experiment.run(
|
82
|
+
root_dir, 'latest', runner='sequential', plugins=[checkpointer]
|
83
|
+
)
|
84
|
+
for leaf in experiment.leaf_nodes:
|
85
|
+
self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])
|
86
|
+
|
87
|
+
|
88
|
+
class BulkCheckpointerTest(unittest.TestCase):
|
89
|
+
|
90
|
+
def test_checkpointing(self):
|
91
|
+
root_dir = os.path.join(tempfile.gettempdir(), 'test_bulk_checkpointer')
|
92
|
+
experiment = test_helper.test_experiment()
|
93
|
+
checkpoint_filename = 'checkpoint.jsonl'
|
94
|
+
checkpointer = checkpointing.BulkCheckpointer(checkpoint_filename)
|
95
|
+
run = experiment.run(
|
96
|
+
root_dir, 'new', runner='sequential', plugins=[checkpointer]
|
97
|
+
)
|
98
|
+
self.assertEqual(len(checkpointer._sequence_writer), 0)
|
66
99
|
num_processed = {}
|
67
100
|
for leaf in experiment.leaf_nodes:
|
68
101
|
ckpt = run.output_path_for(leaf, checkpoint_filename)
|
@@ -80,7 +113,7 @@ class CheckpointingTest(unittest.TestCase):
|
|
80
113
|
_ = experiment.run(
|
81
114
|
root_dir, 'latest', runner='sequential', plugins=[checkpointer]
|
82
115
|
)
|
83
|
-
self.assertEqual(len(checkpointer.
|
116
|
+
self.assertEqual(len(checkpointer._sequence_writer), 0)
|
84
117
|
for leaf in experiment.leaf_nodes:
|
85
118
|
self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])
|
86
119
|
|
langfun/core/eval/v2/runners.py
CHANGED
@@ -36,12 +36,6 @@ from langfun.core.structured.schema import class_definitions
|
|
36
36
|
from langfun.core.structured.schema import annotation
|
37
37
|
from langfun.core.structured.schema import structure_from_python
|
38
38
|
|
39
|
-
from langfun.core.structured.schema import SchemaRepr
|
40
|
-
from langfun.core.structured.schema import SchemaJsonRepr
|
41
|
-
from langfun.core.structured.schema import SchemaPythonRepr
|
42
|
-
from langfun.core.structured.schema import ValueRepr
|
43
|
-
from langfun.core.structured.schema import ValueJsonRepr
|
44
|
-
from langfun.core.structured.schema import ValuePythonRepr
|
45
39
|
from langfun.core.structured.schema import schema_repr
|
46
40
|
from langfun.core.structured.schema import source_form
|
47
41
|
from langfun.core.structured.schema import value_repr
|
@@ -56,26 +50,17 @@ from langfun.core.structured.mapping import Mapping
|
|
56
50
|
from langfun.core.structured.mapping import MappingError
|
57
51
|
from langfun.core.structured.mapping import MappingExample
|
58
52
|
|
59
|
-
from langfun.core.structured.parsing import ParseStructure
|
60
|
-
from langfun.core.structured.parsing import ParseStructureJson
|
61
|
-
from langfun.core.structured.parsing import ParseStructurePython
|
62
53
|
from langfun.core.structured.parsing import parse
|
63
54
|
from langfun.core.structured.parsing import call
|
64
55
|
|
65
|
-
from langfun.core.structured.
|
66
|
-
from langfun.core.structured.
|
67
|
-
from langfun.core.structured.
|
68
|
-
from langfun.core.structured.
|
69
|
-
from langfun.core.structured.
|
70
|
-
from langfun.core.structured.
|
71
|
-
from langfun.core.structured.prompting import query_reward
|
72
|
-
from langfun.core.structured.prompting import QueryInvocation
|
73
|
-
from langfun.core.structured.prompting import track_queries
|
74
|
-
|
75
|
-
from langfun.core.structured.description import DescribeStructure
|
76
|
-
from langfun.core.structured.description import describe
|
56
|
+
from langfun.core.structured.querying import track_queries
|
57
|
+
from langfun.core.structured.querying import QueryInvocation
|
58
|
+
from langfun.core.structured.querying import query
|
59
|
+
from langfun.core.structured.querying import query_prompt
|
60
|
+
from langfun.core.structured.querying import query_output
|
61
|
+
from langfun.core.structured.querying import query_reward
|
77
62
|
|
78
|
-
from langfun.core.structured.
|
63
|
+
from langfun.core.structured.description import describe
|
79
64
|
from langfun.core.structured.completion import complete
|
80
65
|
|
81
66
|
from langfun.core.structured.scoring import score
|
@@ -21,7 +21,7 @@ from langfun.core.structured import schema as schema_lib
|
|
21
21
|
import pyglove as pg
|
22
22
|
|
23
23
|
|
24
|
-
class
|
24
|
+
class _CompleteStructure(mapping.Mapping):
|
25
25
|
"""Complete structure by filling the missing fields."""
|
26
26
|
|
27
27
|
input: Annotated[
|
@@ -241,7 +241,7 @@ def complete(
|
|
241
241
|
Returns:
|
242
242
|
The result based on the schema.
|
243
243
|
"""
|
244
|
-
t =
|
244
|
+
t = _CompleteStructure(
|
245
245
|
input=schema_lib.mark_missing(input_value),
|
246
246
|
default=default,
|
247
247
|
examples=examples,
|
@@ -46,7 +46,7 @@ class TripPlan(pg.Object):
|
|
46
46
|
class CompleteStructureTest(unittest.TestCase):
|
47
47
|
|
48
48
|
def test_render_no_examples(self):
|
49
|
-
l = completion.
|
49
|
+
l = completion._CompleteStructure()
|
50
50
|
input_value = schema_lib.mark_missing(
|
51
51
|
TripPlan.partial(
|
52
52
|
place='San Francisco',
|
@@ -120,7 +120,7 @@ class CompleteStructureTest(unittest.TestCase):
|
|
120
120
|
)
|
121
121
|
|
122
122
|
def test_render_no_class_definitions(self):
|
123
|
-
l = completion.
|
123
|
+
l = completion._CompleteStructure()
|
124
124
|
input_value = schema_lib.mark_missing(
|
125
125
|
TripPlan.partial(
|
126
126
|
place='San Francisco',
|
@@ -200,7 +200,7 @@ class CompleteStructureTest(unittest.TestCase):
|
|
200
200
|
)
|
201
201
|
|
202
202
|
def test_render_with_examples(self):
|
203
|
-
l = completion.
|
203
|
+
l = completion._CompleteStructure()
|
204
204
|
input_value = schema_lib.mark_missing(
|
205
205
|
TripPlan.partial(
|
206
206
|
place='San Francisco',
|
@@ -411,7 +411,7 @@ class CompleteStructureTest(unittest.TestCase):
|
|
411
411
|
modalities.Image.from_bytes(b'image_of_elephant'),
|
412
412
|
)
|
413
413
|
)
|
414
|
-
l = completion.
|
414
|
+
l = completion._CompleteStructure(
|
415
415
|
input=input_value,
|
416
416
|
examples=[
|
417
417
|
mapping.MappingExample(
|
@@ -22,7 +22,7 @@ import pyglove as pg
|
|
22
22
|
|
23
23
|
|
24
24
|
@pg.use_init_args(['examples'])
|
25
|
-
class
|
25
|
+
class _DescribeStructure(mapping.Mapping):
|
26
26
|
"""Describe a structured value in natural language."""
|
27
27
|
|
28
28
|
input_title = 'PYTHON_OBJECT'
|
@@ -106,7 +106,7 @@ def describe(
|
|
106
106
|
Returns:
|
107
107
|
The parsed result based on the schema.
|
108
108
|
"""
|
109
|
-
return
|
109
|
+
return _DescribeStructure(
|
110
110
|
input=value,
|
111
111
|
context=context,
|
112
112
|
examples=examples or default_describe_examples(),
|
@@ -36,7 +36,7 @@ class Itinerary(pg.Object):
|
|
36
36
|
class DescribeStructureTest(unittest.TestCase):
|
37
37
|
|
38
38
|
def test_render(self):
|
39
|
-
l = description_lib.
|
39
|
+
l = description_lib._DescribeStructure(
|
40
40
|
input=Itinerary(
|
41
41
|
day=1,
|
42
42
|
type='daytime',
|
@@ -137,7 +137,7 @@ class DescribeStructureTest(unittest.TestCase):
|
|
137
137
|
],
|
138
138
|
hotel=None,
|
139
139
|
)
|
140
|
-
l = description_lib.
|
140
|
+
l = description_lib._DescribeStructure(
|
141
141
|
input=value, context='1 day itinerary to SF'
|
142
142
|
)
|
143
143
|
self.assertEqual(
|
@@ -187,7 +187,7 @@ class DescribeStructureTest(unittest.TestCase):
|
|
187
187
|
],
|
188
188
|
hotel=None,
|
189
189
|
)
|
190
|
-
l = description_lib.
|
190
|
+
l = description_lib._DescribeStructure(input=value)
|
191
191
|
self.assertEqual(
|
192
192
|
l.render().text,
|
193
193
|
inspect.cleandoc("""
|
@@ -21,7 +21,7 @@ from typing import Any, Callable, Literal, Optional, Tuple
|
|
21
21
|
from langfun.core import language_model
|
22
22
|
from langfun.core import template
|
23
23
|
from langfun.core.coding import python
|
24
|
-
from langfun.core.structured import
|
24
|
+
from langfun.core.structured import querying
|
25
25
|
import pyglove as pg
|
26
26
|
|
27
27
|
|
@@ -39,7 +39,7 @@ def unittest_gen(signature, lm, num_retries=1):
|
|
39
39
|
|
40
40
|
unittest_examples = None
|
41
41
|
for _ in range(num_retries):
|
42
|
-
r =
|
42
|
+
r = querying.query(
|
43
43
|
PythonFunctionSignature(signature=signature),
|
44
44
|
list[UnitTest],
|
45
45
|
lm=lm,
|
@@ -145,7 +145,7 @@ def _function_gen(
|
|
145
145
|
last_error = None
|
146
146
|
for _ in range(num_retries):
|
147
147
|
try:
|
148
|
-
source_code =
|
148
|
+
source_code = querying.query(
|
149
149
|
PythonFunctionPrompt(signature=signature), lm=lm
|
150
150
|
)
|
151
151
|
f = python.evaluate(source_code, global_vars=context)
|
@@ -16,13 +16,13 @@ from typing import Any, Callable, Type, Union
|
|
16
16
|
|
17
17
|
import langfun.core as lf
|
18
18
|
from langfun.core.structured import mapping
|
19
|
-
from langfun.core.structured import
|
19
|
+
from langfun.core.structured import querying
|
20
20
|
from langfun.core.structured import schema as schema_lib
|
21
21
|
import pyglove as pg
|
22
22
|
|
23
23
|
|
24
24
|
@lf.use_init_args(['schema', 'default', 'examples'])
|
25
|
-
class
|
25
|
+
class _ParseStructure(mapping.Mapping):
|
26
26
|
"""Parse an object out from a natural language text."""
|
27
27
|
|
28
28
|
context_title = 'USER_REQUEST'
|
@@ -37,7 +37,7 @@ class ParseStructure(mapping.Mapping):
|
|
37
37
|
]
|
38
38
|
|
39
39
|
|
40
|
-
class
|
40
|
+
class _ParseStructureJson(_ParseStructure):
|
41
41
|
"""Parse an object out from a NL text using JSON as the protocol."""
|
42
42
|
|
43
43
|
preamble = """
|
@@ -53,7 +53,7 @@ class ParseStructureJson(ParseStructure):
|
|
53
53
|
output_title = 'JSON'
|
54
54
|
|
55
55
|
|
56
|
-
class
|
56
|
+
class _ParseStructurePython(_ParseStructure):
|
57
57
|
"""Parse an object out from a NL text using Python as the protocol."""
|
58
58
|
|
59
59
|
preamble = """
|
@@ -87,7 +87,7 @@ def parse(
|
|
87
87
|
returns_message: bool = False,
|
88
88
|
**kwargs,
|
89
89
|
) -> Any:
|
90
|
-
"""Parse a natural
|
90
|
+
"""Parse a natural language message based on schema.
|
91
91
|
|
92
92
|
Examples:
|
93
93
|
|
@@ -271,7 +271,7 @@ def call(
|
|
271
271
|
return lm_output if returns_message else lm_output.text
|
272
272
|
|
273
273
|
# Call `parsing_lm` for structured parsing.
|
274
|
-
parsing_message =
|
274
|
+
parsing_message = querying.query(
|
275
275
|
lm_output.text,
|
276
276
|
schema,
|
277
277
|
examples=parsing_examples,
|
@@ -293,11 +293,11 @@ def call(
|
|
293
293
|
|
294
294
|
def _parse_structure_cls(
|
295
295
|
protocol: schema_lib.SchemaProtocol,
|
296
|
-
) -> Type[
|
296
|
+
) -> Type[_ParseStructure]:
|
297
297
|
if protocol == 'json':
|
298
|
-
return
|
298
|
+
return _ParseStructureJson
|
299
299
|
elif protocol == 'python':
|
300
|
-
return
|
300
|
+
return _ParseStructurePython
|
301
301
|
else:
|
302
302
|
raise ValueError(f'Unknown protocol: {protocol!r}.')
|
303
303
|
|
@@ -37,7 +37,7 @@ class Itinerary(pg.Object):
|
|
37
37
|
class ParseStructurePythonTest(unittest.TestCase):
|
38
38
|
|
39
39
|
def test_render_no_examples(self):
|
40
|
-
l = parsing.
|
40
|
+
l = parsing._ParseStructurePython(int)
|
41
41
|
m = lf.AIMessage('12 / 6 + 2 = 4')
|
42
42
|
self.assertEqual(
|
43
43
|
l.render(input=m, context='Compute 12 / 6 + 2.').text,
|
@@ -62,7 +62,7 @@ class ParseStructurePythonTest(unittest.TestCase):
|
|
62
62
|
)
|
63
63
|
|
64
64
|
def test_render_no_context(self):
|
65
|
-
l = parsing.
|
65
|
+
l = parsing._ParseStructurePython(int)
|
66
66
|
m = lf.AIMessage('12 / 6 + 2 = 4')
|
67
67
|
|
68
68
|
self.assertEqual(
|
@@ -85,7 +85,7 @@ class ParseStructurePythonTest(unittest.TestCase):
|
|
85
85
|
)
|
86
86
|
|
87
87
|
def test_render(self):
|
88
|
-
l = parsing.
|
88
|
+
l = parsing._ParseStructurePython(
|
89
89
|
int,
|
90
90
|
examples=[
|
91
91
|
mapping.MappingExample(
|
@@ -212,7 +212,7 @@ class ParseStructurePythonTest(unittest.TestCase):
|
|
212
212
|
),
|
213
213
|
override_attrs=True,
|
214
214
|
):
|
215
|
-
l = parsing.
|
215
|
+
l = parsing._ParseStructurePython(
|
216
216
|
[Itinerary],
|
217
217
|
examples=[
|
218
218
|
mapping.MappingExample(
|
@@ -295,7 +295,7 @@ class ParseStructurePythonTest(unittest.TestCase):
|
|
295
295
|
class ParseStructureJsonTest(unittest.TestCase):
|
296
296
|
|
297
297
|
def test_render_no_examples(self):
|
298
|
-
l = parsing.
|
298
|
+
l = parsing._ParseStructureJson(int)
|
299
299
|
m = lf.AIMessage('12 / 6 + 2 = 4')
|
300
300
|
self.assertEqual(
|
301
301
|
l.render(input=m, context='Compute 12 / 6 + 2.').text,
|
@@ -320,7 +320,7 @@ class ParseStructureJsonTest(unittest.TestCase):
|
|
320
320
|
)
|
321
321
|
|
322
322
|
def test_render_no_context(self):
|
323
|
-
l = parsing.
|
323
|
+
l = parsing._ParseStructureJson(int)
|
324
324
|
m = lf.AIMessage('12 / 6 + 2 = 4')
|
325
325
|
|
326
326
|
self.assertEqual(
|
@@ -343,7 +343,7 @@ class ParseStructureJsonTest(unittest.TestCase):
|
|
343
343
|
)
|
344
344
|
|
345
345
|
def test_render(self):
|
346
|
-
l = parsing.
|
346
|
+
l = parsing._ParseStructureJson(
|
347
347
|
int,
|
348
348
|
examples=[
|
349
349
|
mapping.MappingExample(
|
@@ -504,7 +504,7 @@ class ParseStructureJsonTest(unittest.TestCase):
|
|
504
504
|
override_attrs=True,
|
505
505
|
):
|
506
506
|
message = lf.LangFunc(lm_input)()
|
507
|
-
l = parsing.
|
507
|
+
l = parsing._ParseStructureJson(
|
508
508
|
[Itinerary],
|
509
509
|
examples=[
|
510
510
|
mapping.MappingExample(
|