langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512150805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +1 -1
- langfun/core/__init__.py +7 -1
- langfun/core/agentic/__init__.py +8 -1
- langfun/core/agentic/action.py +740 -112
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +189 -24
- langfun/core/async_support.py +104 -5
- langfun/core/async_support_test.py +23 -0
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +9 -2
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +11 -2
- langfun/core/data/conversion/gemini_test.py +48 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +48 -44
- langfun/core/eval/base_test.py +5 -5
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +3 -0
- langfun/core/eval/v2/checkpointing.py +148 -46
- langfun/core/eval/v2/checkpointing_test.py +9 -2
- langfun/core/eval/v2/config_saver.py +37 -0
- langfun/core/eval/v2/config_saver_test.py +36 -0
- langfun/core/eval/v2/eval_test_helper.py +104 -3
- langfun/core/eval/v2/evaluation.py +102 -19
- langfun/core/eval/v2/evaluation_test.py +9 -3
- langfun/core/eval/v2/example.py +50 -40
- langfun/core/eval/v2/example_test.py +16 -8
- langfun/core/eval/v2/experiment.py +95 -20
- langfun/core/eval/v2/experiment_test.py +19 -0
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +157 -44
- langfun/core/eval/v2/metrics_test.py +39 -18
- langfun/core/eval/v2/progress.py +31 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking.py +13 -5
- langfun/core/eval/v2/progress_tracking_test.py +9 -1
- langfun/core/eval/v2/reporting.py +88 -71
- langfun/core/eval/v2/reporting_test.py +24 -6
- langfun/core/eval/v2/runners/__init__.py +30 -0
- langfun/core/eval/v2/{runners.py → runners/base.py} +73 -180
- langfun/core/eval/v2/runners/beam.py +354 -0
- langfun/core/eval/v2/runners/beam_test.py +153 -0
- langfun/core/eval/v2/runners/ckpt_monitor.py +350 -0
- langfun/core/eval/v2/runners/ckpt_monitor_test.py +213 -0
- langfun/core/eval/v2/runners/debug.py +40 -0
- langfun/core/eval/v2/runners/debug_test.py +76 -0
- langfun/core/eval/v2/runners/parallel.py +243 -0
- langfun/core/eval/v2/runners/parallel_test.py +182 -0
- langfun/core/eval/v2/runners/sequential.py +47 -0
- langfun/core/eval/v2/runners/sequential_test.py +169 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +7 -5
- langfun/core/language_model.py +189 -36
- langfun/core/language_model_test.py +54 -3
- langfun/core/llms/__init__.py +14 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +32 -1
- langfun/core/llms/gemini.py +90 -12
- langfun/core/llms/gemini_test.py +110 -0
- langfun/core/llms/google_genai.py +52 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +120 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +16 -1
- langfun/core/llms/vertexai.py +78 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/__init__.py +10 -0
- langfun/core/mcp/client.py +177 -0
- langfun/core/mcp/client_test.py +71 -0
- langfun/core/mcp/session.py +241 -0
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/testing/simple_mcp_client.py +33 -0
- langfun/core/mcp/testing/simple_mcp_server.py +33 -0
- langfun/core/mcp/tool.py +254 -0
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +73 -3
- langfun/core/modalities/image_test.py +116 -0
- langfun/core/modalities/mime.py +78 -4
- langfun/core/modalities/mime_test.py +59 -0
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/__init__.py +2 -24
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +81 -37
- langfun/core/structured/parsing.py +95 -79
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +230 -154
- langfun/core/structured/querying_test.py +69 -33
- langfun/core/structured/schema/__init__.py +49 -0
- langfun/core/structured/schema/base.py +664 -0
- langfun/core/structured/schema/base_test.py +531 -0
- langfun/core/structured/schema/json.py +174 -0
- langfun/core/structured/schema/json_test.py +121 -0
- langfun/core/structured/schema/python.py +316 -0
- langfun/core/structured/schema/python_test.py +410 -0
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +47 -36
- langfun/core/structured/tokenization.py +26 -11
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +175 -50
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +43 -0
- langfun/env/base_environment.py +827 -0
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +304 -0
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +842 -0
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +14 -0
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +472 -0
- langfun/env/event_handlers/event_logger_test.py +304 -0
- langfun/env/event_handlers/metric_writer.py +726 -0
- langfun/env/event_handlers/metric_writer_test.py +214 -0
- langfun/env/interface.py +1640 -0
- langfun/env/interface_test.py +153 -0
- langfun/env/load_balancers.py +59 -0
- langfun/env/load_balancers_test.py +141 -0
- langfun/env/test_utils.py +507 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/METADATA +7 -3
- langfun-0.1.2.dev202512150805.dist-info/RECORD +217 -0
- langfun/core/eval/v2/runners_test.py +0 -343
- langfun/core/structured/schema.py +0 -987
- langfun/core/structured/schema_test.py +0 -982
- langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/top_level.txt +0 -0
|
@@ -1,987 +0,0 @@
|
|
|
1
|
-
# Copyright 2023 The Langfun Authors
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
"""Schema for structured data."""
|
|
15
|
-
|
|
16
|
-
import abc
|
|
17
|
-
import inspect
|
|
18
|
-
import io
|
|
19
|
-
import re
|
|
20
|
-
import sys
|
|
21
|
-
import textwrap
|
|
22
|
-
import typing
|
|
23
|
-
from typing import Any, Literal, Sequence, Type, Union
|
|
24
|
-
import langfun.core as lf
|
|
25
|
-
from langfun.core.coding.python import correction
|
|
26
|
-
import pyglove as pg
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
def include_method_in_prompt(method):
|
|
30
|
-
"""Decorator to include a method in the class definition of the prompt."""
|
|
31
|
-
setattr(method, '__show_in_prompt__', True)
|
|
32
|
-
return method
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
def should_include_method_in_prompt(method):
|
|
36
|
-
"""Returns true if the method should be shown in the prompt."""
|
|
37
|
-
return getattr(method, '__show_in_prompt__', False)
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
def parse_value_spec(value) -> pg.typing.ValueSpec:
|
|
41
|
-
"""Parses a PyGlove ValueSpec equivalence into a ValueSpec."""
|
|
42
|
-
if isinstance(value, pg.typing.ValueSpec):
|
|
43
|
-
return value
|
|
44
|
-
|
|
45
|
-
if isinstance(value, dict) and len(value) == 1 and 'result' in value:
|
|
46
|
-
value = value['result']
|
|
47
|
-
|
|
48
|
-
def _parse_node(v) -> pg.typing.ValueSpec:
|
|
49
|
-
if isinstance(v, dict):
|
|
50
|
-
return pg.typing.Dict([(k, _parse_node(cv)) for k, cv in v.items()])
|
|
51
|
-
elif isinstance(v, list):
|
|
52
|
-
if len(v) != 1:
|
|
53
|
-
raise ValueError(
|
|
54
|
-
'Annotation with list must be a list of a single element. '
|
|
55
|
-
f'Encountered: {v}'
|
|
56
|
-
)
|
|
57
|
-
return pg.typing.List(_parse_node(v[0]))
|
|
58
|
-
else:
|
|
59
|
-
spec = pg.typing.ValueSpec.from_annotation(v, auto_typing=True)
|
|
60
|
-
if isinstance(
|
|
61
|
-
spec,
|
|
62
|
-
(
|
|
63
|
-
pg.typing.Any,
|
|
64
|
-
pg.typing.Callable,
|
|
65
|
-
pg.typing.Tuple,
|
|
66
|
-
pg.typing.Type,
|
|
67
|
-
pg.typing.Union,
|
|
68
|
-
),
|
|
69
|
-
):
|
|
70
|
-
raise ValueError(f'Unsupported schema specification: {v}')
|
|
71
|
-
return spec
|
|
72
|
-
|
|
73
|
-
return _parse_node(value)
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
SchemaProtocol = Literal['json', 'python']
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
class SchemaError(Exception): # pylint: disable=g-bad-exception-name
|
|
80
|
-
"""Schema error."""
|
|
81
|
-
|
|
82
|
-
def __init__(self,
|
|
83
|
-
schema: 'Schema',
|
|
84
|
-
value: Any,
|
|
85
|
-
protocol: SchemaProtocol,
|
|
86
|
-
cause: Exception):
|
|
87
|
-
self.schema = schema
|
|
88
|
-
self.value = value
|
|
89
|
-
self.protocol = protocol
|
|
90
|
-
self.cause = cause
|
|
91
|
-
|
|
92
|
-
def __str__(self):
|
|
93
|
-
r = io.StringIO()
|
|
94
|
-
r.write(
|
|
95
|
-
pg.colored(
|
|
96
|
-
f'{self.cause.__class__.__name__}: {self.cause}', 'magenta'
|
|
97
|
-
)
|
|
98
|
-
)
|
|
99
|
-
|
|
100
|
-
r.write('\n')
|
|
101
|
-
r.write(pg.colored('Schema:', 'red'))
|
|
102
|
-
r.write('\n\n')
|
|
103
|
-
r.write(textwrap.indent(
|
|
104
|
-
pg.colored(
|
|
105
|
-
schema_repr(self.protocol).repr(self.schema), 'magenta'
|
|
106
|
-
),
|
|
107
|
-
' ' * 2
|
|
108
|
-
))
|
|
109
|
-
r.write('\n\n')
|
|
110
|
-
r.write(pg.colored('Generated value:', 'red'))
|
|
111
|
-
r.write('\n\n')
|
|
112
|
-
r.write(textwrap.indent(
|
|
113
|
-
pg.colored(value_repr(self.protocol).repr(self.value), 'magenta'),
|
|
114
|
-
' ' * 2
|
|
115
|
-
))
|
|
116
|
-
return r.getvalue()
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
class Schema(
|
|
120
|
-
lf.NaturalLanguageFormattable,
|
|
121
|
-
pg.Object,
|
|
122
|
-
pg.views.HtmlTreeView.Extension
|
|
123
|
-
):
|
|
124
|
-
"""Base class for structured data schema."""
|
|
125
|
-
|
|
126
|
-
spec: pg.typing.Annotated[
|
|
127
|
-
pg.typing.Object(pg.typing.ValueSpec, transform=parse_value_spec),
|
|
128
|
-
(
|
|
129
|
-
'A PyGlove ValueSpec object representing the spec for the value '
|
|
130
|
-
'to be parsed.'
|
|
131
|
-
),
|
|
132
|
-
]
|
|
133
|
-
|
|
134
|
-
def schema_str(self, protocol: SchemaProtocol = 'json', **kwargs) -> str:
|
|
135
|
-
"""Returns the representation of the schema."""
|
|
136
|
-
return schema_repr(protocol).repr(self, **kwargs)
|
|
137
|
-
|
|
138
|
-
def value_str(
|
|
139
|
-
self, value: Any, protocol: SchemaProtocol = 'json', **kwargs
|
|
140
|
-
) -> str:
|
|
141
|
-
"""Returns the representation of a structured value."""
|
|
142
|
-
return value_repr(protocol).repr(value, self, **kwargs)
|
|
143
|
-
|
|
144
|
-
def parse(
|
|
145
|
-
self, text: str, protocol: SchemaProtocol = 'json', **kwargs
|
|
146
|
-
) -> Any:
|
|
147
|
-
"""Parse a LM generated text into a structured value."""
|
|
148
|
-
value = value_repr(protocol).parse(text, self, **kwargs)
|
|
149
|
-
|
|
150
|
-
# TODO(daiyip): support autofix for schema error.
|
|
151
|
-
try:
|
|
152
|
-
return self.spec.apply(value)
|
|
153
|
-
except Exception as e:
|
|
154
|
-
raise SchemaError(self, value, protocol, e) # pylint: disable=raise-missing-from
|
|
155
|
-
|
|
156
|
-
def natural_language_format(self) -> str:
|
|
157
|
-
return self.schema_str()
|
|
158
|
-
|
|
159
|
-
def schema_dict(self) -> dict[str, Any]:
|
|
160
|
-
"""Returns the dict representation of the schema."""
|
|
161
|
-
|
|
162
|
-
def _node(vs: pg.typing.ValueSpec) -> Any:
|
|
163
|
-
if isinstance(vs, pg.typing.PrimitiveType):
|
|
164
|
-
return vs
|
|
165
|
-
elif isinstance(vs, pg.typing.Dict):
|
|
166
|
-
assert vs.schema is not None
|
|
167
|
-
return {str(k): _node(f.value) for k, f in vs.schema.fields.items()}
|
|
168
|
-
elif isinstance(vs, pg.typing.List):
|
|
169
|
-
return [_node(vs.element.value)]
|
|
170
|
-
elif isinstance(vs, pg.typing.Object):
|
|
171
|
-
if issubclass(vs.cls, pg.Object):
|
|
172
|
-
d = {pg.JSONConvertible.TYPE_NAME_KEY: vs.cls.__serialization_key__}
|
|
173
|
-
d.update(
|
|
174
|
-
{
|
|
175
|
-
str(k): _node(f.value)
|
|
176
|
-
for k, f in vs.cls.__schema__.fields.items()
|
|
177
|
-
}
|
|
178
|
-
)
|
|
179
|
-
return d
|
|
180
|
-
raise TypeError(
|
|
181
|
-
'Unsupported value spec being used as the schema for '
|
|
182
|
-
f'structured data: {vs}.')
|
|
183
|
-
|
|
184
|
-
return {'result': _node(self.spec)}
|
|
185
|
-
|
|
186
|
-
def class_dependencies(
|
|
187
|
-
self,
|
|
188
|
-
include_base_classes: bool = True,
|
|
189
|
-
include_subclasses: bool = True,
|
|
190
|
-
include_generated_subclasses: bool = False) -> list[Type[Any]]:
|
|
191
|
-
"""Returns a list of class dependencies for current schema."""
|
|
192
|
-
return class_dependencies(
|
|
193
|
-
self.spec,
|
|
194
|
-
include_base_classes,
|
|
195
|
-
include_subclasses,
|
|
196
|
-
include_generated_subclasses
|
|
197
|
-
)
|
|
198
|
-
|
|
199
|
-
@classmethod
|
|
200
|
-
def from_value(cls, value) -> 'Schema':
|
|
201
|
-
"""Creates a schema from an equivalent representation."""
|
|
202
|
-
if isinstance(value, Schema):
|
|
203
|
-
return value
|
|
204
|
-
return cls(parse_value_spec(value))
|
|
205
|
-
|
|
206
|
-
def _html_tree_view_content(
|
|
207
|
-
self,
|
|
208
|
-
*,
|
|
209
|
-
view: pg.views.HtmlTreeView,
|
|
210
|
-
**kwargs,
|
|
211
|
-
):
|
|
212
|
-
return pg.Html.element(
|
|
213
|
-
'div',
|
|
214
|
-
[pg.Html.escape(self.schema_str(protocol='python'))],
|
|
215
|
-
css_classes=['lf-schema-definition']
|
|
216
|
-
).add_style(
|
|
217
|
-
"""
|
|
218
|
-
.lf-schema-definition {
|
|
219
|
-
color: blue;
|
|
220
|
-
margin: 5px;
|
|
221
|
-
white-space: pre-wrap;
|
|
222
|
-
}
|
|
223
|
-
"""
|
|
224
|
-
)
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
SchemaType = Union[Schema, Type[Any], list[Type[Any]], dict[str, Any]]
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
def _top_level_object_specs_from_value(value: pg.Symbolic) -> list[Type[Any]]:
|
|
231
|
-
"""Returns a list of top level value specs from a symbolic value."""
|
|
232
|
-
top_level_object_specs = []
|
|
233
|
-
|
|
234
|
-
def _collect_top_level_object_specs(k, v, p):
|
|
235
|
-
del k, p
|
|
236
|
-
if isinstance(v, pg.Object):
|
|
237
|
-
top_level_object_specs.append(pg.typing.Object(v.__class__))
|
|
238
|
-
return pg.TraverseAction.CONTINUE
|
|
239
|
-
return pg.TraverseAction.ENTER
|
|
240
|
-
|
|
241
|
-
pg.traverse(value, _collect_top_level_object_specs)
|
|
242
|
-
return top_level_object_specs
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
def class_dependencies(
|
|
246
|
-
value_or_spec: Union[
|
|
247
|
-
pg.Symbolic,
|
|
248
|
-
Schema,
|
|
249
|
-
pg.typing.ValueSpec,
|
|
250
|
-
Type[pg.Object],
|
|
251
|
-
tuple[Union[pg.typing.ValueSpec, Type[pg.Object]], ...],
|
|
252
|
-
],
|
|
253
|
-
include_base_classes: bool = True,
|
|
254
|
-
include_subclasses: bool = True,
|
|
255
|
-
include_generated_subclasses: bool = False,
|
|
256
|
-
) -> list[Type[Any]]:
|
|
257
|
-
"""Returns a list of class dependencies from a value or specs."""
|
|
258
|
-
if isinstance(value_or_spec, Schema):
|
|
259
|
-
value_or_spec = value_or_spec.spec
|
|
260
|
-
|
|
261
|
-
if inspect.isclass(value_or_spec) or isinstance(
|
|
262
|
-
value_or_spec, pg.typing.ValueSpec
|
|
263
|
-
):
|
|
264
|
-
value_or_spec = (value_or_spec,)
|
|
265
|
-
|
|
266
|
-
if isinstance(value_or_spec, tuple):
|
|
267
|
-
value_specs = []
|
|
268
|
-
for v in value_or_spec:
|
|
269
|
-
if isinstance(v, pg.typing.ValueSpec):
|
|
270
|
-
value_specs.append(v)
|
|
271
|
-
elif inspect.isclass(v):
|
|
272
|
-
value_specs.append(pg.typing.Object(v))
|
|
273
|
-
else:
|
|
274
|
-
raise TypeError(f'Unsupported spec type: {v!r}')
|
|
275
|
-
else:
|
|
276
|
-
value_specs = _top_level_object_specs_from_value(value_or_spec)
|
|
277
|
-
|
|
278
|
-
seen = set()
|
|
279
|
-
dependencies = []
|
|
280
|
-
|
|
281
|
-
def _add_dependency(cls_or_classes):
|
|
282
|
-
if isinstance(cls_or_classes, type):
|
|
283
|
-
cls_or_classes = [cls_or_classes]
|
|
284
|
-
for cls in cls_or_classes:
|
|
285
|
-
if cls not in dependencies:
|
|
286
|
-
dependencies.append(cls)
|
|
287
|
-
|
|
288
|
-
def _fill_dependencies(vs: pg.typing.ValueSpec, include_subclasses: bool):
|
|
289
|
-
if isinstance(vs, pg.typing.Object):
|
|
290
|
-
if vs.cls not in seen:
|
|
291
|
-
seen.add(vs.cls)
|
|
292
|
-
|
|
293
|
-
if include_base_classes:
|
|
294
|
-
# Add base classes as dependencies.
|
|
295
|
-
for base_cls in vs.cls.__bases__:
|
|
296
|
-
# We only keep track of user-defined symbolic classes.
|
|
297
|
-
if base_cls is not object and base_cls is not pg.Object:
|
|
298
|
-
_fill_dependencies(
|
|
299
|
-
pg.typing.Object(base_cls), include_subclasses=False
|
|
300
|
-
)
|
|
301
|
-
|
|
302
|
-
# Add members as dependencies.
|
|
303
|
-
for field in pg.schema(vs.cls).values():
|
|
304
|
-
_fill_dependencies(field.value, include_subclasses)
|
|
305
|
-
_add_dependency(vs.cls)
|
|
306
|
-
|
|
307
|
-
# Check subclasses if available.
|
|
308
|
-
if include_subclasses:
|
|
309
|
-
for cls in vs.cls.__subclasses__():
|
|
310
|
-
# NOTE(daiyip): To prevent LLM-generated "hallucinated" classes from
|
|
311
|
-
# polluting the generation space, classes dynamically created by
|
|
312
|
-
# 'eval' (which have __module__ == 'builtins') are excluded from
|
|
313
|
-
# dependencies by default.
|
|
314
|
-
if ((include_generated_subclasses or cls.__module__ != 'builtins')
|
|
315
|
-
and cls not in dependencies):
|
|
316
|
-
_fill_dependencies(pg.typing.Object(cls), include_subclasses=True)
|
|
317
|
-
|
|
318
|
-
if isinstance(vs, pg.typing.List):
|
|
319
|
-
_fill_dependencies(vs.element.value, include_subclasses)
|
|
320
|
-
elif isinstance(vs, pg.typing.Tuple):
|
|
321
|
-
for elem in vs.elements:
|
|
322
|
-
_fill_dependencies(elem.value, include_subclasses)
|
|
323
|
-
elif isinstance(vs, pg.typing.Dict) and vs.schema:
|
|
324
|
-
for v in vs.schema.values():
|
|
325
|
-
_fill_dependencies(v.value, include_subclasses)
|
|
326
|
-
elif isinstance(vs, pg.typing.Union):
|
|
327
|
-
for v in vs.candidates:
|
|
328
|
-
_fill_dependencies(v, include_subclasses)
|
|
329
|
-
|
|
330
|
-
for value_spec in value_specs:
|
|
331
|
-
_fill_dependencies(value_spec, include_subclasses)
|
|
332
|
-
return dependencies
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
def schema_spec(noneable: bool = False) -> pg.typing.ValueSpec: # pylint: disable=unused-argument
|
|
336
|
-
if typing.TYPE_CHECKING:
|
|
337
|
-
return Any
|
|
338
|
-
return pg.typing.Object(
|
|
339
|
-
Schema, transform=Schema.from_value, is_noneable=noneable
|
|
340
|
-
) # pylint: disable=unreachable-code
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
#
|
|
344
|
-
# Schema representations.
|
|
345
|
-
#
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
class SchemaRepr(metaclass=abc.ABCMeta):
|
|
349
|
-
"""Base class for schema representation."""
|
|
350
|
-
|
|
351
|
-
@abc.abstractmethod
|
|
352
|
-
def repr(self, schema: Schema) -> str:
|
|
353
|
-
"""Returns the representation of the schema."""
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
class SchemaPythonRepr(SchemaRepr):
|
|
357
|
-
"""Python-representation for a schema."""
|
|
358
|
-
|
|
359
|
-
def repr(
|
|
360
|
-
self,
|
|
361
|
-
schema: Schema,
|
|
362
|
-
*,
|
|
363
|
-
include_result_definition: bool = True,
|
|
364
|
-
markdown: bool = True,
|
|
365
|
-
**kwargs,
|
|
366
|
-
) -> str:
|
|
367
|
-
ret = ''
|
|
368
|
-
if include_result_definition:
|
|
369
|
-
ret += self.result_definition(schema)
|
|
370
|
-
class_definition_str = self.class_definitions(
|
|
371
|
-
schema, markdown=markdown, **kwargs
|
|
372
|
-
)
|
|
373
|
-
if class_definition_str:
|
|
374
|
-
ret += f'\n\n{class_definition_str}'
|
|
375
|
-
return ret.strip()
|
|
376
|
-
|
|
377
|
-
def class_definitions(
|
|
378
|
-
self,
|
|
379
|
-
schema: Schema,
|
|
380
|
-
additional_dependencies: list[Type[Any]] | None = None,
|
|
381
|
-
**kwargs
|
|
382
|
-
) -> str | None:
|
|
383
|
-
"""Returns a string containing of class definitions from a schema."""
|
|
384
|
-
deps = schema.class_dependencies(
|
|
385
|
-
include_base_classes=False, include_subclasses=True
|
|
386
|
-
)
|
|
387
|
-
allowed_dependencies = set(deps)
|
|
388
|
-
if additional_dependencies:
|
|
389
|
-
allowed_dependencies.update(additional_dependencies)
|
|
390
|
-
return class_definitions(
|
|
391
|
-
deps, allowed_dependencies=allowed_dependencies, **kwargs)
|
|
392
|
-
|
|
393
|
-
def result_definition(self, schema: Schema) -> str:
|
|
394
|
-
return annotation(schema.spec)
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
def source_form(value, compact: bool = True, markdown: bool = False) -> str:
|
|
398
|
-
"""Returns the source code form of an object."""
|
|
399
|
-
return ValuePythonRepr().repr(value, compact=compact, markdown=markdown)
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
def class_definitions(
|
|
403
|
-
classes: Sequence[Type[Any]],
|
|
404
|
-
*,
|
|
405
|
-
allowed_dependencies: set[Type[Any]] | None = None,
|
|
406
|
-
strict: bool = False,
|
|
407
|
-
markdown: bool = False,
|
|
408
|
-
) -> str | None:
|
|
409
|
-
"""Returns a str for class definitions."""
|
|
410
|
-
if not classes:
|
|
411
|
-
return None
|
|
412
|
-
def_str = io.StringIO()
|
|
413
|
-
for i, cls in enumerate(classes):
|
|
414
|
-
if i > 0:
|
|
415
|
-
def_str.write('\n')
|
|
416
|
-
def_str.write(
|
|
417
|
-
class_definition(
|
|
418
|
-
cls,
|
|
419
|
-
strict=strict,
|
|
420
|
-
allowed_dependencies=allowed_dependencies,
|
|
421
|
-
)
|
|
422
|
-
)
|
|
423
|
-
ret = def_str.getvalue()
|
|
424
|
-
if markdown and ret:
|
|
425
|
-
ret = f'```python\n{ret}```'
|
|
426
|
-
return ret
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
def class_definition(
|
|
430
|
-
cls,
|
|
431
|
-
strict: bool = False,
|
|
432
|
-
allowed_dependencies: set[Type[Any]] | None = None,
|
|
433
|
-
) -> str:
|
|
434
|
-
"""Returns the Python class definition."""
|
|
435
|
-
out = io.StringIO()
|
|
436
|
-
schema = pg.schema(cls)
|
|
437
|
-
eligible_bases = []
|
|
438
|
-
for base_cls in cls.__bases__:
|
|
439
|
-
if base_cls is not object:
|
|
440
|
-
if allowed_dependencies is None or base_cls in allowed_dependencies:
|
|
441
|
-
eligible_bases.append(base_cls.__name__)
|
|
442
|
-
|
|
443
|
-
if eligible_bases:
|
|
444
|
-
base_cls_str = ', '.join(eligible_bases)
|
|
445
|
-
out.write(f'class {cls.__name__}({base_cls_str}):\n')
|
|
446
|
-
else:
|
|
447
|
-
out.write(f'class {cls.__name__}:\n')
|
|
448
|
-
|
|
449
|
-
if cls.__doc__:
|
|
450
|
-
doc_lines = cls.__doc__.strip().split('\n')
|
|
451
|
-
if len(doc_lines) == 1:
|
|
452
|
-
out.write(f' """{cls.__doc__}"""\n')
|
|
453
|
-
else:
|
|
454
|
-
out.write(' """')
|
|
455
|
-
|
|
456
|
-
# Since Python 3.13, the indentation of docstring lines is removed.
|
|
457
|
-
# Therefore, we add two spaces to each non-empty line to keep the
|
|
458
|
-
# indentation consistent with the class definition.
|
|
459
|
-
if sys.version_info >= (3, 13):
|
|
460
|
-
for i in range(1, len(doc_lines)):
|
|
461
|
-
if doc_lines[i]:
|
|
462
|
-
doc_lines[i] = ' ' * 2 + doc_lines[i]
|
|
463
|
-
|
|
464
|
-
for line in doc_lines:
|
|
465
|
-
out.write(line)
|
|
466
|
-
out.write('\n')
|
|
467
|
-
out.write(' """\n')
|
|
468
|
-
|
|
469
|
-
empty_class = True
|
|
470
|
-
if schema.fields:
|
|
471
|
-
for key, field in schema.items():
|
|
472
|
-
if not isinstance(key, pg.typing.ConstStrKey):
|
|
473
|
-
pg.logging.warning(
|
|
474
|
-
'Variable-length keyword arguments is not supported in '
|
|
475
|
-
f'structured parsing or query. Encountered: {cls}, Schema: {schema}'
|
|
476
|
-
)
|
|
477
|
-
continue
|
|
478
|
-
|
|
479
|
-
# Skip fields that are marked as excluded from the prompt sent to LLM
|
|
480
|
-
# for OOP.
|
|
481
|
-
if field.metadata.get('exclude_from_prompt', False):
|
|
482
|
-
continue
|
|
483
|
-
|
|
484
|
-
# Write field doc string as comments before the field definition.
|
|
485
|
-
if field.description:
|
|
486
|
-
for line in field.description.split('\n'):
|
|
487
|
-
if line:
|
|
488
|
-
out.write(' # ')
|
|
489
|
-
out.write(line)
|
|
490
|
-
out.write('\n')
|
|
491
|
-
|
|
492
|
-
annotation_str = annotation(
|
|
493
|
-
field.value, strict=strict, allowed_dependencies=allowed_dependencies
|
|
494
|
-
)
|
|
495
|
-
out.write(f' {field.key}: {annotation_str}')
|
|
496
|
-
out.write('\n')
|
|
497
|
-
empty_class = False
|
|
498
|
-
|
|
499
|
-
for method in _iter_newly_defined_methods(cls, allowed_dependencies):
|
|
500
|
-
source = inspect.getsource(method)
|
|
501
|
-
# Remove decorators from the method definition.
|
|
502
|
-
source = re.sub(r'\s*@.*\.include_method_in_prompt.*\n', '', source)
|
|
503
|
-
out.write('\n')
|
|
504
|
-
out.write(
|
|
505
|
-
textwrap.indent(
|
|
506
|
-
inspect.cleandoc('\n' + source), ' ' * 2)
|
|
507
|
-
)
|
|
508
|
-
out.write('\n')
|
|
509
|
-
empty_class = False
|
|
510
|
-
|
|
511
|
-
if empty_class:
|
|
512
|
-
out.write(' pass\n')
|
|
513
|
-
return out.getvalue()
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
def _iter_newly_defined_methods(
|
|
517
|
-
cls, allowed_dependencies: set[Type[Any]] | None):
|
|
518
|
-
names = {attr_name: True for attr_name in dir(cls)}
|
|
519
|
-
for base in cls.__bases__:
|
|
520
|
-
if allowed_dependencies is None or base in allowed_dependencies:
|
|
521
|
-
for name in dir(base):
|
|
522
|
-
names.pop(name, None)
|
|
523
|
-
for name in names.keys():
|
|
524
|
-
attr = getattr(cls, name)
|
|
525
|
-
if callable(attr) and should_include_method_in_prompt(attr):
|
|
526
|
-
yield attr
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
def annotation(
|
|
530
|
-
vs: pg.typing.ValueSpec,
|
|
531
|
-
annotate_optional: bool = True,
|
|
532
|
-
strict: bool = False,
|
|
533
|
-
allowed_dependencies: set[Type[Any]] | None = None,
|
|
534
|
-
) -> str:
|
|
535
|
-
"""Returns the annotation string for a value spec."""
|
|
536
|
-
child_annotation_kwargs = dict(
|
|
537
|
-
strict=strict, allowed_dependencies=allowed_dependencies
|
|
538
|
-
)
|
|
539
|
-
if isinstance(vs, pg.typing.Any):
|
|
540
|
-
return 'Any'
|
|
541
|
-
elif isinstance(vs, pg.typing.Enum):
|
|
542
|
-
candidate_str = ', '.join([repr(v) for v in vs.values])
|
|
543
|
-
return f'Literal[{candidate_str}]'
|
|
544
|
-
elif isinstance(vs, pg.typing.Union):
|
|
545
|
-
candidate_str = ', '.join(
|
|
546
|
-
[
|
|
547
|
-
annotation(c, annotate_optional=False, **child_annotation_kwargs)
|
|
548
|
-
for c in vs.candidates
|
|
549
|
-
]
|
|
550
|
-
)
|
|
551
|
-
if vs.is_noneable:
|
|
552
|
-
candidate_str += ', None'
|
|
553
|
-
return f'Union[{candidate_str}]'
|
|
554
|
-
|
|
555
|
-
if isinstance(vs, pg.typing.Bool):
|
|
556
|
-
x = 'bool'
|
|
557
|
-
elif isinstance(vs, pg.typing.Str):
|
|
558
|
-
if vs.regex is None:
|
|
559
|
-
x = 'str'
|
|
560
|
-
else:
|
|
561
|
-
if strict:
|
|
562
|
-
x = f"pg.typing.Str(regex='{vs.regex.pattern}')"
|
|
563
|
-
else:
|
|
564
|
-
x = f"str(regex='{vs.regex.pattern}')"
|
|
565
|
-
elif isinstance(vs, pg.typing.Number):
|
|
566
|
-
constraints = []
|
|
567
|
-
min_label = 'min_value' if strict else 'min'
|
|
568
|
-
max_label = 'max_value' if strict else 'max'
|
|
569
|
-
if vs.min_value is not None:
|
|
570
|
-
constraints.append(f'{min_label}={vs.min_value}')
|
|
571
|
-
if vs.max_value is not None:
|
|
572
|
-
constraints.append(f'{max_label}={vs.max_value}')
|
|
573
|
-
x = 'int' if isinstance(vs, pg.typing.Int) else 'float'
|
|
574
|
-
if constraints:
|
|
575
|
-
if strict:
|
|
576
|
-
x = (
|
|
577
|
-
'pg.typing.Int'
|
|
578
|
-
if isinstance(vs, pg.typing.Int)
|
|
579
|
-
else 'pg.typing.Float'
|
|
580
|
-
)
|
|
581
|
-
x += '(' + ', '.join(constraints) + ')'
|
|
582
|
-
elif isinstance(vs, pg.typing.Object):
|
|
583
|
-
if allowed_dependencies is None or vs.cls in allowed_dependencies:
|
|
584
|
-
x = vs.cls.__name__
|
|
585
|
-
else:
|
|
586
|
-
x = 'Any'
|
|
587
|
-
elif isinstance(vs, pg.typing.List):
|
|
588
|
-
item_str = annotation(vs.element.value, **child_annotation_kwargs)
|
|
589
|
-
x = f'list[{item_str}]'
|
|
590
|
-
elif isinstance(vs, pg.typing.Tuple):
|
|
591
|
-
elem_str = ', '.join(
|
|
592
|
-
[annotation(el.value, **child_annotation_kwargs) for el in vs.elements]
|
|
593
|
-
)
|
|
594
|
-
x = f'tuple[{elem_str}]'
|
|
595
|
-
elif isinstance(vs, pg.typing.Dict):
|
|
596
|
-
kv_pairs = None
|
|
597
|
-
if vs.schema is not None:
|
|
598
|
-
kv_pairs = [
|
|
599
|
-
(k, annotation(f.value, **child_annotation_kwargs))
|
|
600
|
-
for k, f in vs.schema.items()
|
|
601
|
-
if isinstance(k, pg.typing.ConstStrKey)
|
|
602
|
-
]
|
|
603
|
-
|
|
604
|
-
if kv_pairs:
|
|
605
|
-
kv_str = ', '.join(f"'{k}': {v}" for k, v in kv_pairs)
|
|
606
|
-
x = '{' + kv_str + '}'
|
|
607
|
-
if strict:
|
|
608
|
-
x = f'pg.typing.Dict({x})'
|
|
609
|
-
elif vs.schema and vs.schema.dynamic_field:
|
|
610
|
-
v = annotation(vs.schema.dynamic_field.value, **child_annotation_kwargs)
|
|
611
|
-
x = f'dict[str, {v}]'
|
|
612
|
-
else:
|
|
613
|
-
x = 'dict[str, Any]'
|
|
614
|
-
|
|
615
|
-
else:
|
|
616
|
-
raise TypeError(f'Unsupported value spec being used as schema: {vs}.')
|
|
617
|
-
|
|
618
|
-
if annotate_optional and vs.is_noneable:
|
|
619
|
-
x += ' | None'
|
|
620
|
-
return x
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
class SchemaJsonRepr(SchemaRepr):
  """JSON-representation for a schema.

  Renders a `Schema` as a JSON-style skeleton string in which leaf value
  specs are replaced by their type names (e.g. `{"result": int}`), with
  numeric/string constraints and noneability surfaced inline.
  """

  def repr(self, schema: Schema, **kwargs) -> str:
    """Returns the JSON-style string representation of `schema`."""
    del kwargs
    out = io.StringIO()
    def _visit(node: Any) -> None:
      # Recursively writes the representation of `node` into `out`.
      if isinstance(node, str):
        out.write(f'"{node}"')
      elif isinstance(node, list):
        # A list node in a schema dict carries exactly one element spec.
        assert len(node) == 1, node
        out.write('[')
        _visit(node[0])
        out.write(']')
      elif isinstance(node, dict):
        out.write('{')
        for i, (k, v) in enumerate(node.items()):
          if i != 0:
            out.write(', ')
          out.write(f'"{k}": ')
          _visit(v)
        out.write('}')
      elif isinstance(node, pg.typing.Enum):
        # Enums render as a union of their literal values.
        out.write(' | '.join(
            f'"{v}"' if isinstance(v, str) else repr(v)
            for v in node.values))
      elif isinstance(node, pg.typing.PrimitiveType):
        x = node.value_type.__name__
        if isinstance(node, pg.typing.Number):
          # Surface min/max constraints when present.
          params = []
          if node.min_value is not None:
            params.append(f'min={node.min_value}')
          if node.max_value is not None:
            params.append(f'max={node.max_value}')
          if params:
            x += f'({", ".join(params)})'
        elif isinstance(node, pg.typing.Str):
          if node.regex is not None:
            x += f'(regex={node.regex.pattern})'
        if node.is_noneable:
          x = x + ' | None'
        out.write(x)
      else:
        raise ValueError(
            f'Unsupported value spec being used as schema: {node}.')
    _visit(schema.schema_dict())
    return out.getvalue()
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
#
|
|
673
|
-
# Value representations.
|
|
674
|
-
#
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
class ValueRepr(metaclass=abc.ABCMeta):
  """Base class for value representation.

  A `ValueRepr` defines a bidirectional mapping between structured Python
  values and the textual form exchanged with a language model: `repr`
  serializes a value, `parse` recovers a value from LM-generated text.
  """

  @abc.abstractmethod
  def repr(self, value: Any, schema: Schema | None = None, **kwargs) -> str:
    """Returns the representation of a structured value."""

  @abc.abstractmethod
  def parse(self, text: str, schema: Schema | None = None, **kwargs) -> Any:
    """Parse a LM generated text into a structured value."""
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
class ValuePythonRepr(ValueRepr):
  """Python-representation for value."""

  def repr(self,
           value: Any,
           schema: Schema | None = None,
           *,
           compact: bool = True,
           verbose: bool = False,
           markdown: bool = True,
           assign_to_var: str | None = None,
           **kwargs) -> str:
    """Returns the Python source representation of `value`.

    Classes are rendered as class definitions (or a result annotation for
    non-Object specs); `lf.Template` values render as their string form;
    other values render via `pg.format` in Python syntax.
    """
    del schema
    if inspect.isclass(value):
      cls_schema = Schema.from_value(value)
      if isinstance(cls_schema.spec, pg.typing.Object):
        object_code = SchemaPythonRepr().class_definitions(
            cls_schema,
            markdown=markdown,
            # We add `pg.Object` as additional dependencies to the class
            # definition so exemplars for class generation could show
            # pg.Object as their bases.
            additional_dependencies=[pg.Object]
        )
        assert object_code is not None
        return object_code
      else:
        object_code = SchemaPythonRepr().result_definition(cls_schema)
    elif isinstance(value, lf.Template):
      return str(value)
    else:
      object_code = pg.format(
          value, compact=compact, verbose=verbose, python_format=True
      )
    # NOTE(review): indentation of the two blocks below was reconstructed;
    # they appear to apply to all non-early-return paths — confirm.
    if assign_to_var is not None:
      object_code = f'{assign_to_var} = {object_code}'
    if markdown:
      return f'```python\n{ object_code }\n```'
    return object_code

  def parse(
      self,
      text: str,
      schema: Schema | None = None,
      *,
      additional_context: dict[str, Type[Any]] | None = None,
      permission: pg.coding.CodePermission = (
          pg.coding.CodePermission.ASSIGN | pg.coding.CodePermission.CALL
      ),
      autofix=0,
      autofix_lm: lf.LanguageModel = lf.contextual(),
  ) -> Any:
    """Parse a Python string into a structured object."""
    del kwargs
    global_vars = additional_context or {}
    if schema is not None:
      # Make the classes referenced by the schema resolvable by name when
      # evaluating the LM-generated code.
      dependencies = schema.class_dependencies()
      global_vars.update({d.__name__: d for d in dependencies})
    return structure_from_python(
        text,
        global_vars=global_vars,
        autofix=autofix,
        autofix_lm=autofix_lm,
        permission=permission,
    )
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
def structure_from_python(
    code: str,
    *,
    global_vars: dict[str, Any] | None = None,
    permission: pg.coding.CodePermission = (
        pg.coding.CodePermission.ASSIGN | pg.coding.CodePermission.CALL
    ),
    autofix=0,
    autofix_lm: lf.LanguageModel = lf.contextual(),
) -> Any:
  """Evaluates structure from Python code with access to symbols.

  Args:
    code: Python source (typically LM-generated) to evaluate.
    global_vars: Optional symbol table made visible to the code; it is
      augmented in place with commonly referenced names.
    permission: Code permission granted to the evaluated code.
    autofix: Max number of LM-based correction attempts on failure.
    autofix_lm: Language model used for auto-fixing.

  Returns:
    The value produced by evaluating `code`.
  """
  global_vars = global_vars or {}
  # Names that LM-generated code commonly references.
  builtin_symbols = {
      'pg': pg,
      'Object': pg.Object,
      'Any': typing.Any,
      'List': typing.List,
      'Tuple': typing.Tuple,
      'Dict': typing.Dict,
      'Sequence': typing.Sequence,
      'Optional': typing.Optional,
      'Union': typing.Union,
      # Special value markers.
      'UNKNOWN': UNKNOWN,
  }
  global_vars.update(builtin_symbols)
  # We are creating objects here, so we execute the code without a sandbox.
  return correction.run_with_correction(
      code,
      global_vars=global_vars,
      sandbox=False,
      max_attempts=autofix,
      lm=autofix_lm,
      permission=permission,
  )
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
class JsonError(Exception):
  """Json parsing error."""

  def __init__(self, json: str, cause: Exception):
    # Keep both the offending JSON text and the underlying error around
    # so the message can show full context.
    self.json = json
    self.cause = cause

  def __str__(self) -> str:
    cause_text = pg.colored(
        f'{self.cause.__class__.__name__}: {self.cause}', 'magenta'
    )
    json_text = textwrap.indent(pg.colored(self.json, 'magenta'), ' ' * 2)
    return '\n\n'.join([cause_text, pg.colored('JSON text:', 'red'), json_text])
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
class ValueJsonRepr(ValueRepr):
  """JSON-representation for value."""

  def repr(self, value: Any, schema: Schema | None = None, **kwargs) -> str:
    """Serializes `value` as a JSON string wrapped under key `result`."""
    del schema
    return pg.to_json_str({'result': value})

  def parse(self, text: str, schema: Schema | None = None, **kwargs) -> Any:
    """Parse a JSON string into a structured object."""
    del schema
    try:
      # Strip surrounding prose and escape raw newlines before decoding.
      text = cleanup_json(text)
      decoded = pg.from_json_str(text, **kwargs)
    except Exception as e:
      raise JsonError(text, e)  # pylint: disable=raise-missing-from
    if isinstance(decoded, dict) and 'result' in decoded:
      return decoded['result']
    raise JsonError(text, ValueError(
        'The root node of the JSON must be a dict with key `result`. '
        f'Encountered: {decoded}'
    ))
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
def cleanup_json(json_str: str) -> str:
  """Clean up the LM responded JSON string.

  Treatments:
  1. Extract the JSON string with a top-level dict from the response, so
     leading/trailing prose around the JSON is dropped.
  2. Escape raw newlines inside JSON string values.

  Args:
    json_str: Raw LM output that should contain a JSON dict.

  Returns:
    The cleaned JSON text of the first top-level dict.

  Raises:
    ValueError: If no top-level dict is found, or braces are unbalanced.
  """
  depth = 0            # Current nesting level of curly braces.
  in_json = False      # Whether the first '{' has been seen.
  in_string = False    # Whether we are inside a double-quoted string.
  string_start = -1    # Index of the opening quote of the current string.

  buf = io.StringIO()
  for pos, ch in enumerate(json_str):
    if ch == '{' and not in_string:
      buf.write(ch)
      depth += 1
      in_json = True
      continue
    elif not in_json:
      # Skip any prose before the first '{'.
      continue

    if ch == '}' and not in_string:
      buf.write(ch)
      depth -= 1
      if depth == 0:
        break
    elif ch == '"' and json_str[pos - 1] != '\\':
      in_string = not in_string
      if in_string:
        string_start = pos
      else:
        assert string_start > 0
        # Emit the whole string literal with newlines escaped.
        buf.write(json_str[string_start:pos + 1].replace('\n', '\\n'))
        string_start = -1
    elif not in_string:
      buf.write(ch)

  if not in_json:
    raise ValueError(f'No JSON dict in the output: {json_str}')

  if depth > 0:
    raise ValueError(
        f'Malformated JSON: missing {depth} closing curly braces.'
    )

  return buf.getvalue()
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
def schema_repr(protocol: SchemaProtocol) -> SchemaRepr:
  """Gets a SchemaRepr object from protocol."""
  if protocol == 'python':
    return SchemaPythonRepr()
  if protocol == 'json':
    return SchemaJsonRepr()
  raise ValueError(f'Unsupported protocol: {protocol}.')
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
def value_repr(protocol: SchemaProtocol) -> ValueRepr:
  """Returns the ValueRepr implementation for the given protocol."""
  if protocol == 'python':
    return ValuePythonRepr()
  if protocol == 'json':
    return ValueJsonRepr()
  raise ValueError(f'Unsupported protocol: {protocol}.')
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
#
|
|
908
|
-
# Special value markers.
|
|
909
|
-
#
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
class Missing(pg.Object, pg.typing.CustomTyping):
  """Value marker for a missing field.

  This class differs from pg.MISSING_VALUE in two aspects:
  * A field assigned with lf.Missing() is considered non-partial.
  * lf.Missing() can render its value spec as a Python annotation that is
    consistent with `lf.structured.Schema.schema_repr()`.
  """

  def _on_bound(self):
    super()._on_bound()
    # Populated lazily by `custom_apply` when bound to a field.
    self._value_spec = None

  @property
  def value_spec(self) -> pg.ValueSpec | None:
    """Returns the value spec that applies to the current missing value."""
    return self._value_spec

  def custom_apply(
      self, path: pg.KeyPath, value_spec: pg.ValueSpec, *args, **kwargs
  ) -> tuple[bool, Any]:
    # Remember the spec for `format`, and accept the value as-is.
    self._value_spec = value_spec
    return (False, self)

  def format(self, *args, **kwargs) -> str:
    if self._value_spec is not None:
      return f'MISSING({annotation(self._value_spec)})'
    return 'MISSING'

  @classmethod
  def find_missing(cls, value: Any) -> dict[str, 'Missing']:
    """Lists all missing values contained in the value."""
    found = {}

    def _collect(key, node, parent):
      del parent
      if isinstance(node, Missing):
        found[key] = node
      return pg.TraverseAction.ENTER

    pg.traverse(value, _collect)
    return found
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
# Shared singleton marker for missing field values.
MISSING = Missing()
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
def mark_missing(value: Any) -> Any:
  """Replaces pg.MISSING within the value with lf.structured.Missing objects."""
  # Wrap plain containers symbolically so their members can be rebound.
  if isinstance(value, list):
    value = pg.List(value)
  elif isinstance(value, dict):
    value = pg.Dict(value)
  if not isinstance(value, pg.Symbolic):
    return value

  def _replace(key, node, parent):
    del key, parent
    return Missing() if pg.MISSING_VALUE == node else node

  return value.rebind(_replace, raise_on_no_change=False)
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
class Unknown(pg.Object, pg.typing.CustomTyping):
  """Value marker for a field that LMs could not provide."""

  def custom_apply(self, *args, **kwargs) -> tuple[bool, Any]:
    # Accept the marker as-is under any value spec.
    return False, self

  def format(self, *args, **kwargs) -> str:
    # Always renders as the literal marker text.
    return 'UNKNOWN'
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
# Shared singleton marker for values the LM could not provide.
UNKNOWN = Unknown()
|