langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512040805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +1 -1
- langfun/core/__init__.py +7 -1
- langfun/core/agentic/__init__.py +8 -1
- langfun/core/agentic/action.py +740 -112
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +189 -24
- langfun/core/async_support.py +104 -5
- langfun/core/async_support_test.py +23 -0
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +9 -2
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +11 -2
- langfun/core/data/conversion/gemini_test.py +48 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +48 -44
- langfun/core/eval/base_test.py +5 -5
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +2 -0
- langfun/core/eval/v2/checkpointing.py +76 -7
- langfun/core/eval/v2/checkpointing_test.py +9 -2
- langfun/core/eval/v2/config_saver.py +37 -0
- langfun/core/eval/v2/config_saver_test.py +36 -0
- langfun/core/eval/v2/eval_test_helper.py +104 -3
- langfun/core/eval/v2/evaluation.py +92 -17
- langfun/core/eval/v2/evaluation_test.py +9 -3
- langfun/core/eval/v2/example.py +50 -40
- langfun/core/eval/v2/example_test.py +16 -8
- langfun/core/eval/v2/experiment.py +84 -15
- langfun/core/eval/v2/experiment_test.py +19 -0
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +157 -44
- langfun/core/eval/v2/metrics_test.py +39 -18
- langfun/core/eval/v2/progress.py +31 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking.py +13 -5
- langfun/core/eval/v2/progress_tracking_test.py +9 -1
- langfun/core/eval/v2/reporting.py +90 -71
- langfun/core/eval/v2/reporting_test.py +24 -6
- langfun/core/eval/v2/runners/__init__.py +30 -0
- langfun/core/eval/v2/{runners.py → runners/base.py} +72 -180
- langfun/core/eval/v2/runners/beam.py +354 -0
- langfun/core/eval/v2/runners/beam_test.py +153 -0
- langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
- langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
- langfun/core/eval/v2/runners/debug.py +40 -0
- langfun/core/eval/v2/runners/debug_test.py +76 -0
- langfun/core/eval/v2/runners/parallel.py +243 -0
- langfun/core/eval/v2/runners/parallel_test.py +182 -0
- langfun/core/eval/v2/runners/sequential.py +47 -0
- langfun/core/eval/v2/runners/sequential_test.py +169 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +7 -5
- langfun/core/language_model.py +189 -36
- langfun/core/language_model_test.py +54 -3
- langfun/core/llms/__init__.py +12 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +32 -1
- langfun/core/llms/gemini.py +64 -12
- langfun/core/llms/gemini_test.py +110 -0
- langfun/core/llms/google_genai.py +34 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +120 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +16 -1
- langfun/core/llms/vertexai.py +58 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/__init__.py +10 -0
- langfun/core/mcp/client.py +177 -0
- langfun/core/mcp/client_test.py +71 -0
- langfun/core/mcp/session.py +241 -0
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/testing/simple_mcp_client.py +33 -0
- langfun/core/mcp/testing/simple_mcp_server.py +33 -0
- langfun/core/mcp/tool.py +254 -0
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +73 -3
- langfun/core/modalities/image_test.py +116 -0
- langfun/core/modalities/mime.py +64 -3
- langfun/core/modalities/mime_test.py +11 -0
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/__init__.py +2 -24
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +81 -37
- langfun/core/structured/parsing.py +95 -79
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +230 -154
- langfun/core/structured/querying_test.py +69 -33
- langfun/core/structured/schema/__init__.py +49 -0
- langfun/core/structured/schema/base.py +664 -0
- langfun/core/structured/schema/base_test.py +531 -0
- langfun/core/structured/schema/json.py +174 -0
- langfun/core/structured/schema/json_test.py +121 -0
- langfun/core/structured/schema/python.py +316 -0
- langfun/core/structured/schema/python_test.py +410 -0
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +47 -36
- langfun/core/structured/tokenization.py +26 -11
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +175 -50
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +43 -0
- langfun/env/base_environment.py +827 -0
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +304 -0
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +842 -0
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +14 -0
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +472 -0
- langfun/env/event_handlers/event_logger_test.py +304 -0
- langfun/env/event_handlers/metric_writer.py +726 -0
- langfun/env/event_handlers/metric_writer_test.py +214 -0
- langfun/env/interface.py +1640 -0
- langfun/env/interface_test.py +153 -0
- langfun/env/load_balancers.py +59 -0
- langfun/env/load_balancers_test.py +141 -0
- langfun/env/test_utils.py +507 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/METADATA +7 -3
- langfun-0.1.2.dev202512040805.dist-info/RECORD +217 -0
- langfun/core/eval/v2/runners_test.py +0 -343
- langfun/core/structured/schema.py +0 -987
- langfun/core/structured/schema_test.py +0 -982
- langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,664 @@
|
|
|
1
|
+
# Copyright 2025 The Langfun Authors
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Schema and Prompting Protocol for Structured Data."""
|
|
15
|
+
|
|
16
|
+
import abc
|
|
17
|
+
import inspect
|
|
18
|
+
import io
|
|
19
|
+
import textwrap
|
|
20
|
+
import typing
|
|
21
|
+
from typing import Any, ClassVar, Type, Union
|
|
22
|
+
import langfun.core as lf
|
|
23
|
+
import pyglove as pg
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _parse_value_spec(value) -> pg.typing.ValueSpec:
  """Converts a schema-equivalent value into a PyGlove ValueSpec.

  Examples:
    ```
    _parse_value_spec(int) -> pg.typing.Int
    _parse_value_spec(list[int]) -> pg.typing.List(pg.typing.Int)
    _parse_value_spec(dict(a=int, b=str)) -> pg.typing.Dict(
        pg.typing.Int, pg.typing.Str
    )
    ```

  Args:
    value: The value to parse. It can be a PyGlove ValueSpec, a dict with a
      single 'result' key, or a Python type annotation.

  Returns:
    A PyGlove ValueSpec.
  """
  if isinstance(value, pg.typing.ValueSpec):
    return value

  # A single-key {'result': ...} dict is unwrapped to its payload.
  if isinstance(value, dict) and len(value) == 1 and 'result' in value:
    value = value['result']

  def _convert(node) -> pg.typing.ValueSpec:
    if isinstance(node, dict):
      # A dict maps field names to child annotations.
      fields = [(key, _convert(child)) for key, child in node.items()]
      return pg.typing.Dict(fields)
    if isinstance(node, list):
      # A list annotation must carry exactly one element type.
      if len(node) != 1:
        raise ValueError(
            'Annotation with list must be a list of a single element. '
            f'Encountered: {node}'
        )
      return pg.typing.List(_convert(node[0]))
    spec = pg.typing.ValueSpec.from_annotation(node, auto_typing=True)
    unsupported_specs = (
        pg.typing.Any,
        pg.typing.Callable,
        pg.typing.Tuple,
        pg.typing.Type,
        pg.typing.Union,
    )
    if isinstance(spec, unsupported_specs):
      raise ValueError(f'Unsupported schema specification: {node}')
    return spec

  return _convert(value)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class SchemaError(Exception):   # pylint: disable=g-bad-exception-name
  """Error raised when a generated value fails to conform to a schema."""

  def __init__(
      self,
      schema: 'Schema',
      value: Any,
      protocol: str,
      cause: Exception
  ):
    # Keep the full context so __str__ can render a detailed report.
    self.schema = schema
    self.value = value
    self.protocol = protocol
    self.cause = cause

  def __str__(self):
    # Render the root cause, the expected schema, and the offending value.
    # The schema and value sections are indented by two spaces.
    indent = ' ' * 2
    segments = [
        pg.colored(
            f'{self.cause.__class__.__name__}: {self.cause}', 'magenta'
        ),
        '\n',
        pg.colored('Schema:', 'red'),
        '\n\n',
        textwrap.indent(
            pg.colored(
                schema_repr(self.schema, protocol=self.protocol), 'magenta'
            ),
            indent
        ),
        '\n\n',
        pg.colored('Generated value:', 'red'),
        '\n\n',
        textwrap.indent(
            pg.colored(
                value_repr(self.value, protocol=self.protocol), 'magenta'
            ),
            indent
        ),
    ]
    return ''.join(segments)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class Schema(
    lf.NaturalLanguageFormattable,
    pg.Object,
    pg.views.HtmlTreeView.Extension
):
  """Schema for structured inputs and outputs.

  `lf.Schema` provides a unified representation for defining the output schema
  used in Langfun's structured operations like `lf.query`, `lf.parse`,
  `lf.complete`, and `lf.describe`. It acts as an abstraction layer,
  allowing schemas to be defined using Python type annotations, `pg.Object`
  classes, or dictionaries, and then converting them into a format that
  language models can understand.

  `lf.Schema` can be created from various types using `lf.Schema.from_value`:
  * Built-in types: `int`, `str`, `bool`, `float`
  * Typing constructs: `list`, `dict`, `typing.Union`, `typing.Literal`,
    `typing.Optional`
  * PyGlove classes: `pg.Object` subclasses

  **1. Creating a Schema:**

  ```python
  import langfun as lf
  import pyglove as pg
  from typing import Literal, Union

  # From a basic type
  int_schema = lf.Schema.from_value(int)

  # From a list type
  list_schema = lf.Schema.from_value(list[int])

  # From a dictionary
  dict_schema = lf.Schema.from_value(dict(a=int, b=str))

  # From pg.Object
  class Point(pg.Object):
    x: int
    y: int
  point_schema = lf.Schema.from_value(Point)

  # From Union or Literal
  union_schema = lf.Schema.from_value(Union[int, str])
  literal_schema = lf.Schema.from_value(Literal['A', 'B'])
  ```

  **2. Schema Representation:**
  Once created, a schema object can represent itself in different formats,
  such as Python-like syntax or JSON, which is used in prompts to LLMs.

  ```python
  print(point_schema.schema_repr(protocol='python'))
  # Output:
  # class Point:
  #   x: int
  #   y: int

  print(dict_schema.schema_repr(protocol='json'))
  # Output:
  # {
  #   "a": "int",
  #   "b": "str"
  # }
  ```
  """

  spec: pg.typing.Annotated[
      pg.typing.Object(pg.typing.ValueSpec, transform=_parse_value_spec),
      (
          'A PyGlove ValueSpec object representing the spec for the value '
          'to be parsed.'
      ),
  ]

  def schema_repr(self, protocol: str = 'python', **kwargs) -> str:
    """Returns the representation of the schema."""
    return schema_repr(self, protocol=protocol, **kwargs)

  def value_repr(
      self, value: Any, protocol: str = 'python', **kwargs
  ) -> str:
    """Returns the representation of a structured value."""
    return value_repr(value, schema=self, protocol=protocol, **kwargs)

  def parse_value(
      self, text: str, protocol: str = 'python', **kwargs
  ) -> Any:
    """Parses a LM generated text into a structured value.

    Args:
      text: The LM-generated text to parse.
      protocol: The prompting protocol name (e.g. 'python').
      **kwargs: Additional kwargs passed to the protocol's `parse_value`.

    Returns:
      The parsed value, conforming to `self.spec`.

    Raises:
      SchemaError: If the parsed value does not conform to the schema.
    """
    value = parse_value(text, schema=self, protocol=protocol, **kwargs)

    # TODO(daiyip): support autofix for schema error.
    try:
      return self.spec.apply(value)
    except Exception as e:
      raise SchemaError(self, value, protocol, e)  # pylint: disable=raise-missing-from

  def natural_language_format(self) -> str:
    """Returns the natural-language (Python protocol) view of the schema."""
    # FIX: the original called `self.schema_str()`, which is not defined on
    # this class (only `schema_repr` is), raising AttributeError. Use
    # `schema_repr` with its default 'python' protocol instead.
    return self.schema_repr()

  def schema_dict(self) -> dict[str, Any]:
    """Returns the dictionary representation of the schema."""

    def _node(vs: pg.typing.ValueSpec) -> Any:
      if isinstance(vs, pg.typing.PrimitiveType):
        return vs
      elif isinstance(vs, pg.typing.Dict):
        assert vs.schema is not None
        return {str(k): _node(f.value) for k, f in vs.schema.fields.items()}
      elif isinstance(vs, pg.typing.List):
        return [_node(vs.element.value)]
      elif isinstance(vs, pg.typing.Object):
        if issubclass(vs.cls, pg.Object):
          # Lead with the serialization key so consumers can identify the
          # class, followed by its fields.
          d = {pg.JSONConvertible.TYPE_NAME_KEY: vs.cls.__serialization_key__}
          d.update(
              {
                  str(k): _node(f.value)
                  for k, f in vs.cls.__schema__.fields.items()
              }
          )
          return d
      raise TypeError(
          'Unsupported value spec being used as the schema for '
          f'structured data: {vs}.')

    return {'result': _node(self.spec)}

  def class_dependencies(
      self,
      include_base_classes: bool = True,
      include_subclasses: bool = True,
      include_generated_subclasses: bool = False) -> list[Type[Any]]:
    """Returns a list of class dependencies for current schema."""
    return class_dependencies(
        self.spec,
        include_base_classes,
        include_subclasses,
        include_generated_subclasses
    )

  @classmethod
  def from_value(cls, value) -> 'Schema':
    """Creates a schema from an equivalent representation."""
    if isinstance(value, Schema):
      return value
    return cls(_parse_value_spec(value))

  def _html_tree_view_content(
      self,
      *,
      view: pg.views.HtmlTreeView,
      **kwargs,
  ):
    """Renders the schema's Python representation as preformatted HTML."""
    return pg.Html.element(
        'div',
        [pg.Html.escape(self.schema_repr(protocol='python'))],
        css_classes=['lf-schema-definition']
    ).add_style(
        """
        .lf-schema-definition {
          color: blue;
          margin: 5px;
          white-space: pre-wrap;
        }
        """
    )
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
# Union of all accepted schema representations: an `lf.Schema`, a type, a
# single-element list of a type, or a dict of field annotations — all
# convertible into `Schema` via `Schema.from_value`.
SchemaType = Union[Schema, Type[Any], list[Type[Any]], dict[str, Any]]
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def _top_level_object_specs_from_value(value: pg.Symbolic) -> list[Type[Any]]:
  """Collects object value specs for the outermost `pg.Object`s in `value`."""
  specs = []

  def _visit(key, node, parent):
    del key, parent
    if not isinstance(node, pg.Object):
      return pg.TraverseAction.ENTER
    # Record this object's class and skip its interior, so only the
    # top-most objects along each path contribute a spec.
    specs.append(pg.typing.Object(node.__class__))
    return pg.TraverseAction.CONTINUE

  pg.traverse(value, _visit)
  return specs
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def class_dependencies(
    value_or_spec: Union[
        pg.Symbolic,
        Schema,
        pg.typing.ValueSpec,
        Type[pg.Object],
        tuple[Union[pg.typing.ValueSpec, Type[pg.Object]], ...],
    ],
    include_base_classes: bool = True,
    include_subclasses: bool = True,
    include_generated_subclasses: bool = False,
) -> list[Type[Any]]:
  """Returns a list of class dependencies from a value or specs.

  Recursively walks the value spec(s), collecting every user-defined class
  referenced directly or through container/union specs. Dependencies are
  appended depth-first, so referenced classes appear before the classes that
  reference them.

  Args:
    value_or_spec: A schema, value spec, class, tuple of specs/classes, or a
      symbolic value (from which top-level object specs are extracted).
    include_base_classes: If True, also collect base classes (excluding
      `object` and `pg.Object`).
    include_subclasses: If True, also collect subclasses of each class.
    include_generated_subclasses: If True, include dynamically generated
      subclasses (those with `__module__ == 'builtins'`).

  Returns:
    An ordered, de-duplicated list of dependent classes.
  """
  # Normalize all accepted input forms into a list of value specs.
  if isinstance(value_or_spec, Schema):
    value_or_spec = value_or_spec.spec

  if inspect.isclass(value_or_spec) or isinstance(
      value_or_spec, pg.typing.ValueSpec
  ):
    value_or_spec = (value_or_spec,)

  if isinstance(value_or_spec, tuple):
    value_specs = []
    for v in value_or_spec:
      if isinstance(v, pg.typing.ValueSpec):
        value_specs.append(v)
      elif inspect.isclass(v):
        value_specs.append(pg.typing.Object(v))
      else:
        raise TypeError(f'Unsupported spec type: {v!r}')
  else:
    # A symbolic value: collect specs for its top-level objects.
    value_specs = _top_level_object_specs_from_value(value_or_spec)

  # `seen` guards against revisiting classes (and cycles); `dependencies`
  # preserves discovery order.
  seen = set()
  dependencies = []

  def _add_dependency(cls_or_classes):
    # Appends class(es) to `dependencies`, skipping duplicates.
    if isinstance(cls_or_classes, type):
      cls_or_classes = [cls_or_classes]
    for cls in cls_or_classes:
      if cls not in dependencies:
        dependencies.append(cls)

  def _fill_dependencies(vs: pg.typing.ValueSpec, include_subclasses: bool):
    # Depth-first traversal over a value spec, accumulating classes.
    if isinstance(vs, pg.typing.Object):
      cls = vs.cls
      if cls.__module__ == 'builtins':
        return

      if cls not in seen:
        seen.add(cls)

        if include_base_classes:
          # Add base classes as dependencies.
          for base_cls in cls.__bases__:
            # We only keep track of user-defined symbolic classes.
            if base_cls is not object and base_cls is not pg.Object:
              _fill_dependencies(
                  pg.typing.Object(base_cls), include_subclasses=False
              )

        # Add members as dependencies.
        for field in pg.schema(cls).values():
          _fill_dependencies(field.value, include_subclasses)
        _add_dependency(cls)

        # Check subclasses if available.
        if include_subclasses:
          for subcls in cls.__subclasses__():
            # NOTE(daiyip): To prevent LLM-generated "hallucinated" classes from
            # polluting the generation space, classes dynamically created by
            # 'eval' (which have __module__ == 'builtins') are excluded from
            # dependencies by default.
            if ((include_generated_subclasses or subcls.__module__ != 'builtins')
                and subcls not in dependencies):
              _fill_dependencies(
                  pg.typing.Object(subcls), include_subclasses=True
              )

    # Recurse into container and union specs.
    if isinstance(vs, pg.typing.List):
      _fill_dependencies(vs.element.value, include_subclasses)
    elif isinstance(vs, pg.typing.Tuple):
      for elem in vs.elements:
        _fill_dependencies(elem.value, include_subclasses)
    elif isinstance(vs, pg.typing.Dict) and vs.schema:
      for v in vs.schema.values():
        _fill_dependencies(v.value, include_subclasses)
    elif isinstance(vs, pg.typing.Union):
      for v in vs.candidates:
        _fill_dependencies(v, include_subclasses)

  for value_spec in value_specs:
    _fill_dependencies(value_spec, include_subclasses)
  return dependencies
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
def schema_spec(noneable: bool = False) -> pg.typing.ValueSpec:  # pylint: disable=unused-argument
  """Returns a value spec for declaring a `Schema`-typed field.

  Args:
    noneable: If True, the field also accepts None.

  Returns:
    A value spec that coerces assigned values into `Schema` via
    `Schema.from_value`.
  """
  # Under static type checking, present the field simply as `Any`.
  if typing.TYPE_CHECKING:
    return Any
  spec = pg.typing.Object(
      Schema, transform=Schema.from_value, is_noneable=noneable
  )
  return spec  # pylint: disable=unreachable-code
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
def annotation(
    vs: pg.typing.ValueSpec,
    annotate_optional: bool = True,
    strict: bool = False,
    allowed_dependencies: set[Type[Any]] | None = None,
) -> str:
  """Returns the annotation string for a value spec.

  Args:
    vs: The value spec to render.
    annotate_optional: If True, append ' | None' when the spec is noneable.
    strict: If True, emit `pg.typing.*` forms for constrained specs
      (regex strings, bounded numbers, closed dicts) instead of the looser
      Python-annotation-like forms.
    allowed_dependencies: If given, object specs whose class is not in this
      set are rendered as 'Any'.

  Returns:
    The annotation string for `vs`.

  Raises:
    TypeError: If `vs` is not a supported value spec kind.
  """
  child_annotation_kwargs = dict(
      strict=strict, allowed_dependencies=allowed_dependencies
  )
  # These three kinds return directly: their renderings already encode
  # optionality (Union appends ', None' itself).
  if isinstance(vs, pg.typing.Any):
    return 'Any'
  elif isinstance(vs, pg.typing.Enum):
    candidate_str = ', '.join([repr(v) for v in vs.values])
    return f'Literal[{candidate_str}]'
  elif isinstance(vs, pg.typing.Union):
    candidate_str = ', '.join(
        [
            annotation(c, annotate_optional=False, **child_annotation_kwargs)
            for c in vs.candidates
        ]
    )
    if vs.is_noneable:
      candidate_str += ', None'
    return f'Union[{candidate_str}]'

  if isinstance(vs, pg.typing.Bool):
    x = 'bool'
  elif isinstance(vs, pg.typing.Str):
    if vs.regex is None:
      x = 'str'
    else:
      # Regex-constrained strings have no plain-annotation equivalent.
      if strict:
        x = f"pg.typing.Str(regex='{vs.regex.pattern}')"
      else:
        x = f"str(regex='{vs.regex.pattern}')"
  elif isinstance(vs, pg.typing.Number):
    constraints = []
    # Strict mode uses the pg.typing keyword names; loose mode abbreviates.
    min_label = 'min_value' if strict else 'min'
    max_label = 'max_value' if strict else 'max'
    if vs.min_value is not None:
      constraints.append(f'{min_label}={vs.min_value}')
    if vs.max_value is not None:
      constraints.append(f'{max_label}={vs.max_value}')
    x = 'int' if isinstance(vs, pg.typing.Int) else 'float'
    if constraints:
      if strict:
        x = (
            'pg.typing.Int'
            if isinstance(vs, pg.typing.Int)
            else 'pg.typing.Float'
        )
      x += '(' + ', '.join(constraints) + ')'
  elif isinstance(vs, pg.typing.Object):
    # Classes outside the allowed dependency set degrade to 'Any'.
    if allowed_dependencies is None or vs.cls in allowed_dependencies:
      x = vs.cls.__name__
    else:
      x = 'Any'
  elif isinstance(vs, pg.typing.List):
    item_str = annotation(vs.element.value, **child_annotation_kwargs)
    x = f'list[{item_str}]'
  elif isinstance(vs, pg.typing.Tuple):
    elem_str = ', '.join(
        [annotation(el.value, **child_annotation_kwargs) for el in vs.elements]
    )
    x = f'tuple[{elem_str}]'
  elif isinstance(vs, pg.typing.Dict):
    kv_pairs = None
    if vs.schema is not None:
      # Only constant string keys render as explicit key/value pairs.
      kv_pairs = [
          (k, annotation(f.value, **child_annotation_kwargs))
          for k, f in vs.schema.items()
          if isinstance(k, pg.typing.ConstStrKey)
      ]

    if kv_pairs:
      kv_str = ', '.join(f"'{k}': {v}" for k, v in kv_pairs)
      x = '{' + kv_str + '}'
      if strict:
        x = f'pg.typing.Dict({x})'
    elif vs.schema and vs.schema.dynamic_field:
      # Dicts keyed by pattern render as dict[str, <value annotation>].
      v = annotation(vs.schema.dynamic_field.value, **child_annotation_kwargs)
      x = f'dict[str, {v}]'
    else:
      x = 'dict[str, Any]'

  else:
    raise TypeError(f'Unsupported value spec being used as schema: {vs}.')

  if annotate_optional and vs.is_noneable:
    x += ' | None'
  return x
|
|
502
|
+
|
|
503
|
+
#
|
|
504
|
+
# Prompting protocols for structured data.
|
|
505
|
+
#
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
class PromptingProtocol(metaclass=abc.ABCMeta):
  """Abstract interface for representing/parsing structured data with LMs.

  Each concrete protocol registers itself under its `NAME` when subclassed,
  and defines how schemas and values are rendered into prompt text and how
  LM responses are parsed back into structured values.
  """

  # Registry key under which a concrete subclass is looked up.
  NAME: ClassVar[str]

  # Maps protocol names to their implementing classes.
  _PROTOCOLS: ClassVar[dict[str, Type['PromptingProtocol']]] = {}

  def __init_subclass__(cls):
    # Auto-register every subclass by its declared NAME.
    PromptingProtocol._PROTOCOLS[cls.NAME] = cls

  @classmethod
  def from_name(cls, name: str) -> 'PromptingProtocol':
    """Instantiates the registered protocol identified by `name`."""
    if name not in cls._PROTOCOLS:
      raise ValueError(f'Unsupported protocol: {name}.')
    return cls._PROTOCOLS[name]()  # pytype: disable=not-instantiable

  @abc.abstractmethod
  def schema_repr(self, schema: Schema) -> str:
    """Renders `schema` as text for inclusion in a prompt."""

  @abc.abstractmethod
  def value_repr(
      self,
      value: Any,
      schema: Schema | None = None,
      **kwargs
  ) -> str:
    """Renders a structured `value` as text under this protocol."""

  @abc.abstractmethod
  def parse_value(
      self,
      text: str,
      schema: Schema | None = None,
      **kwargs
  ) -> Any:
    """Parses LM-generated `text` back into a structured value."""
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
def schema_repr(
    schema: Schema,
    *,
    protocol: str = 'python',
    **kwargs
) -> str:
  """Renders `schema` as text using the named prompting protocol."""
  impl = PromptingProtocol.from_name(protocol)
  return impl.schema_repr(schema, **kwargs)
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
def value_repr(
    value: Any,
    schema: Schema | None = None,
    *,
    protocol: str = 'python',
    **kwargs) -> str:
  """Renders a structured `value` as text using the named prompting protocol."""
  impl = PromptingProtocol.from_name(protocol)
  return impl.value_repr(value, schema, **kwargs)
|
|
569
|
+
|
|
570
|
+
|
|
571
|
+
def parse_value(
    text: str,
    schema: Schema | None = None,
    *,
    protocol: str = 'python',
    **kwargs
) -> Any:
  """Parses LM-generated `text` into a structured value via the named protocol."""
  impl = PromptingProtocol.from_name(protocol)
  return impl.parse_value(text, schema=schema, **kwargs)
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
#
|
|
585
|
+
# Special value markers.
|
|
586
|
+
#
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
class Missing(pg.Object, pg.typing.CustomTyping):
  """Value marker for a missing field.

  This class differs from pg.MISSING_VALUE in two aspects:
  * When a field is assigned with lf.Missing(), it's considered non-partial.
  * lf.Missing() could format the value spec as Python annotations that are
    consistent with `lf.structured.Schema.schema_repr()`.
  """

  def _on_bound(self):
    super()._on_bound()
    # Populated lazily by `custom_apply` when this marker is type-checked.
    self._value_spec = None

  @property
  def value_spec(self) -> pg.ValueSpec | None:
    """The value spec governing this missing value, if known yet."""
    return self._value_spec

  def custom_apply(
      self, path: pg.KeyPath, value_spec: pg.ValueSpec, *args, **kwargs
  ) -> tuple[bool, Any]:
    # Remember the spec for later formatting, and keep the marker as-is by
    # opting out of standard type checking.
    self._value_spec = value_spec
    return (False, self)

  def format(self, *args, **kwargs) -> str:
    if self._value_spec is None:
      return 'MISSING'
    # Include the expected annotation once a spec has been attached.
    return f'MISSING({annotation(self._value_spec)})'

  @classmethod
  def find_missing(cls, value: Any) -> dict[str, 'Missing']:
    """Lists all missing values contained in the value."""
    found = {}

    def _record(key, node, parent):
      del parent
      if isinstance(node, Missing):
        found[key] = node
      return pg.TraverseAction.ENTER

    pg.traverse(value, _record)
    return found
|
|
631
|
+
|
|
632
|
+
|
|
633
|
+
# Module-level singleton marker denoting a missing value.
MISSING = Missing()
|
|
634
|
+
|
|
635
|
+
|
|
636
|
+
def mark_missing(value: Any) -> Any:
  """Replaces pg.MISSING within the value with lf.structured.Missing objects."""
  # Promote plain containers to symbolic ones so they can be rebound.
  if isinstance(value, list):
    value = pg.List(value)
  elif isinstance(value, dict):
    value = pg.Dict(value)

  if not isinstance(value, pg.Symbolic):
    return value

  def _substitute(key, node, parent):
    del key, parent
    return Missing() if pg.MISSING_VALUE == node else node

  return value.rebind(_substitute, raise_on_no_change=False)
|
|
652
|
+
|
|
653
|
+
|
|
654
|
+
class Unknown(pg.Object, pg.typing.CustomTyping):
  """Value marker for a field that LMs could not provide."""

  def custom_apply(self, *args, **kwargs) -> tuple[bool, Any]:
    # Opt out of standard type checking and keep the marker untouched.
    return (False, self)

  def format(self, *args, **kwargs) -> str:
    # Fixed textual rendering used wherever the marker is formatted.
    return 'UNKNOWN'
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
# Module-level singleton marker denoting a value the LM could not provide.
UNKNOWN = Unknown()
|