langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512040805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +1 -1
- langfun/core/__init__.py +7 -1
- langfun/core/agentic/__init__.py +8 -1
- langfun/core/agentic/action.py +740 -112
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +189 -24
- langfun/core/async_support.py +104 -5
- langfun/core/async_support_test.py +23 -0
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +9 -2
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +11 -2
- langfun/core/data/conversion/gemini_test.py +48 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +48 -44
- langfun/core/eval/base_test.py +5 -5
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +2 -0
- langfun/core/eval/v2/checkpointing.py +76 -7
- langfun/core/eval/v2/checkpointing_test.py +9 -2
- langfun/core/eval/v2/config_saver.py +37 -0
- langfun/core/eval/v2/config_saver_test.py +36 -0
- langfun/core/eval/v2/eval_test_helper.py +104 -3
- langfun/core/eval/v2/evaluation.py +92 -17
- langfun/core/eval/v2/evaluation_test.py +9 -3
- langfun/core/eval/v2/example.py +50 -40
- langfun/core/eval/v2/example_test.py +16 -8
- langfun/core/eval/v2/experiment.py +84 -15
- langfun/core/eval/v2/experiment_test.py +19 -0
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +157 -44
- langfun/core/eval/v2/metrics_test.py +39 -18
- langfun/core/eval/v2/progress.py +31 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking.py +13 -5
- langfun/core/eval/v2/progress_tracking_test.py +9 -1
- langfun/core/eval/v2/reporting.py +90 -71
- langfun/core/eval/v2/reporting_test.py +24 -6
- langfun/core/eval/v2/runners/__init__.py +30 -0
- langfun/core/eval/v2/{runners.py → runners/base.py} +72 -180
- langfun/core/eval/v2/runners/beam.py +354 -0
- langfun/core/eval/v2/runners/beam_test.py +153 -0
- langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
- langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
- langfun/core/eval/v2/runners/debug.py +40 -0
- langfun/core/eval/v2/runners/debug_test.py +76 -0
- langfun/core/eval/v2/runners/parallel.py +243 -0
- langfun/core/eval/v2/runners/parallel_test.py +182 -0
- langfun/core/eval/v2/runners/sequential.py +47 -0
- langfun/core/eval/v2/runners/sequential_test.py +169 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +7 -5
- langfun/core/language_model.py +189 -36
- langfun/core/language_model_test.py +54 -3
- langfun/core/llms/__init__.py +12 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +32 -1
- langfun/core/llms/gemini.py +64 -12
- langfun/core/llms/gemini_test.py +110 -0
- langfun/core/llms/google_genai.py +34 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +120 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +16 -1
- langfun/core/llms/vertexai.py +58 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/__init__.py +10 -0
- langfun/core/mcp/client.py +177 -0
- langfun/core/mcp/client_test.py +71 -0
- langfun/core/mcp/session.py +241 -0
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/testing/simple_mcp_client.py +33 -0
- langfun/core/mcp/testing/simple_mcp_server.py +33 -0
- langfun/core/mcp/tool.py +254 -0
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +73 -3
- langfun/core/modalities/image_test.py +116 -0
- langfun/core/modalities/mime.py +64 -3
- langfun/core/modalities/mime_test.py +11 -0
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/__init__.py +2 -24
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +81 -37
- langfun/core/structured/parsing.py +95 -79
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +230 -154
- langfun/core/structured/querying_test.py +69 -33
- langfun/core/structured/schema/__init__.py +49 -0
- langfun/core/structured/schema/base.py +664 -0
- langfun/core/structured/schema/base_test.py +531 -0
- langfun/core/structured/schema/json.py +174 -0
- langfun/core/structured/schema/json_test.py +121 -0
- langfun/core/structured/schema/python.py +316 -0
- langfun/core/structured/schema/python_test.py +410 -0
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +47 -36
- langfun/core/structured/tokenization.py +26 -11
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +175 -50
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +43 -0
- langfun/env/base_environment.py +827 -0
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +304 -0
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +842 -0
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +14 -0
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +472 -0
- langfun/env/event_handlers/event_logger_test.py +304 -0
- langfun/env/event_handlers/metric_writer.py +726 -0
- langfun/env/event_handlers/metric_writer_test.py +214 -0
- langfun/env/interface.py +1640 -0
- langfun/env/interface_test.py +153 -0
- langfun/env/load_balancers.py +59 -0
- langfun/env/load_balancers_test.py +141 -0
- langfun/env/test_utils.py +507 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/METADATA +7 -3
- langfun-0.1.2.dev202512040805.dist-info/RECORD +217 -0
- langfun/core/eval/v2/runners_test.py +0 -343
- langfun/core/structured/schema.py +0 -987
- langfun/core/structured/schema_test.py +0 -982
- langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,531 @@
|
|
|
1
|
+
# Copyright 2025 The Langfun Authors
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import dataclasses
|
|
15
|
+
import typing
|
|
16
|
+
import unittest
|
|
17
|
+
|
|
18
|
+
from langfun.core.structured.schema import base
|
|
19
|
+
from langfun.core.structured.schema import json # pylint: disable=unused-import
|
|
20
|
+
import pyglove as pg
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Activity(pg.Object):
  # A single activity, described in free text.
  description: str
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Itinerary(pg.Object):
  """A travel itinerary for a day."""

  # Day number, starting from 1.
  day: pg.typing.Int[1, None]
  # Whether this itinerary covers daytime or nighttime.
  type: pg.typing.Enum['daytime', 'nighttime']
  # Activities planned for this day.
  activities: list[Activity]
  # Optional hotel name; when present it must match the '.*Hotel' pattern.
  hotel: pg.typing.Annotated[
      pg.typing.Str['.*Hotel'] | None,
      'Hotel to stay if applicable.'
  ]


# Pin the serialization key so the '_type' field in JSON output stays stable.
Itinerary.__serialization_key__ = 'Itinerary'
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class Node(pg.Object):
  # Self-referential type used to exercise recursive dependency resolution.
  children: list['Node']
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class SchemaTest(unittest.TestCase):
  """Tests for `base.Schema` construction, rendering and parsing."""

  def assert_schema(self, annotation, spec):
    # The annotation should resolve to the expected pg value spec.
    self.assertEqual(base.Schema(annotation).spec, spec)

  def assert_unsupported_annotation(self, annotation):
    # Unsupported annotations should be rejected at construction time.
    with self.assertRaises(ValueError):
      base.Schema(annotation)

  def test_init(self):
    self.assert_schema(int, pg.typing.Int())
    self.assert_schema(float, pg.typing.Float())
    self.assert_schema(str, pg.typing.Str())
    self.assert_schema(bool, pg.typing.Bool())

    # Top-level dictionary with 'result' as the only key is flattened.
    self.assert_schema(dict(result=int), pg.typing.Int())

    self.assert_schema(list[str], pg.typing.List(pg.typing.Str()))
    self.assert_schema([str], pg.typing.List(pg.typing.Str()))

    with self.assertRaisesRegex(
        ValueError, 'Annotation with list must be a list of a single element.'
    ):
      base.Schema([str, int])

    self.assert_schema(
        dict[str, int], pg.typing.Dict([(pg.typing.StrKey(), pg.typing.Int())])
    )

    self.assert_schema(
        {
            'x': int,
            'y': [str],
        },
        pg.typing.Dict([
            ('x', int),
            ('y', pg.typing.List(pg.typing.Str())),
        ]),
    )

    self.assert_schema(Itinerary, pg.typing.Object(Itinerary))

    self.assert_unsupported_annotation(typing.Type[int])
    self.assert_unsupported_annotation(typing.Union[int, str, bool])
    self.assert_unsupported_annotation(typing.Any)

  def test_schema_dict(self):
    schema = base.Schema([{'x': Itinerary}])
    self.assertEqual(
        schema.schema_dict(),
        {
            'result': [
                {
                    'x': {
                        '_type': 'Itinerary',
                        'day': pg.typing.Int(min_value=1),
                        'activities': [{
                            '_type': Activity.__type_name__,
                            'description': pg.typing.Str(),
                        }],
                        'hotel': pg.typing.Str(regex='.*Hotel').noneable(),
                        'type': pg.typing.Enum['daytime', 'nighttime'],
                    }
                }
            ]
        },
    )

  def test_class_dependencies(self):
    class Foo(pg.Object):
      x: int

    class Bar(pg.Object):
      y: str

    class A(pg.Object):
      foo: tuple[Foo, int]

    class X(pg.Object):
      k: int | bytes

    class B(A):
      bar: Bar
      foo2: Foo | X

    schema = base.Schema([B])
    # Dependencies are returned in topological discovery order.
    self.assertEqual(schema.class_dependencies(), [Foo, A, Bar, X, B])

  def test_class_dependencies_non_pyglove(self):
    class Baz:
      def __init__(self, x: int):
        pass

    @dataclasses.dataclass(frozen=True)
    class AA:
      foo: tuple[Baz, int]

    class XX(pg.Object):
      pass

    @dataclasses.dataclass(frozen=True)
    class BB(AA):
      foo2: Baz | XX

    schema = base.Schema([AA])
    # Non-pyglove classes (dataclasses, plain classes) are tracked too.
    self.assertEqual(schema.class_dependencies(), [Baz, AA, XX, BB])

  def test_schema_repr(self):
    schema = base.Schema([{'x': Itinerary}])
    self.assertEqual(
        schema.schema_repr(protocol='json'),
        (
            '{"result": [{"x": {"_type": "Itinerary", "day":'
            ' int(min=1), "type": "daytime" | "nighttime", "activities":'
            ' [{"_type": "%s", "description": str}], "hotel":'
            ' str(regex=.*Hotel) | None}}]}' % (
                Activity.__type_name__,
            )
        ),
    )
    with self.assertRaisesRegex(ValueError, 'Unsupported protocol'):
      schema.schema_repr(protocol='text')

  def test_value_repr(self):
    schema = base.Schema(int)
    self.assertEqual(schema.value_repr(1, protocol='json'), '{"result": 1}')
    with self.assertRaisesRegex(ValueError, 'Unsupported protocol'):
      schema.value_repr(1, protocol='text')

  def test_parse(self):
    schema = base.Schema(int)
    self.assertEqual(schema.parse_value('{"result": 1}', protocol='json'), 1)
    schema = base.Schema(dict[str, int])
    # The extra trailing brace is deliberately tolerated by JSON cleanup.
    self.assertEqual(
        schema.parse_value('{"result": {"x": 1}}}', protocol='json'),
        dict(x=1)
    )
    with self.assertRaisesRegex(
        base.SchemaError, 'Expect .* but encountered .*'):
      schema.parse_value('{"result": "def"}', protocol='json')

    with self.assertRaisesRegex(ValueError, 'Unsupported protocol'):
      schema.parse_value('1', protocol='text')
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class ClassDependenciesTest(unittest.TestCase):
  """Tests for `base.class_dependencies`."""

  def test_class_dependencies_from_specs(self):
    class Foo(pg.Object):
      x: int

    class Bar(pg.Object):
      y: str

    class A(pg.Object):
      foo: tuple[Foo, int]

    class X(pg.Object):
      k: int

    class B(A):
      bar: Bar
      foo2: Foo | X

    self.assertEqual(base.class_dependencies(Foo), [Foo])

    # Without subclass expansion, only directly referenced types show up.
    self.assertEqual(
        base.class_dependencies((A,), include_subclasses=False), [Foo, A]
    )

    self.assertEqual(
        base.class_dependencies(A, include_subclasses=True),
        [Foo, A, Bar, X, B],
    )

    # Schema and value-spec inputs are also accepted.
    self.assertEqual(
        base.class_dependencies(base.Schema(A)), [Foo, A, Bar, X, B]
    )

    self.assertEqual(
        base.class_dependencies(pg.typing.Object(A)), [Foo, A, Bar, X, B]
    )

    with self.assertRaisesRegex(TypeError, 'Unsupported spec type'):
      base.class_dependencies((Foo, 1))

  def test_class_dependencies_recursive(self):
    # A self-referential class must be listed exactly once (no infinite loop).
    self.assertEqual(
        base.class_dependencies(Node),
        [Node]
    )

  def test_class_dependencies_from_value(self):
    class Foo(pg.Object):
      x: int

    class Bar(pg.Object):
      y: str

    class A(pg.Object):
      foo: tuple[Foo, int]

    class B(pg.Object):
      pass

    class X(pg.Object):
      k: dict[str, B]

    class C(A):
      bar: Bar
      foo2: Foo | X

    a = A(foo=(Foo(1), 0))
    self.assertEqual(base.class_dependencies(a), [Foo, A, Bar, B, X, C])

    # Non-symbolic values have no class dependencies.
    self.assertEqual(base.class_dependencies(1), [])
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
class AnnotationTest(unittest.TestCase):
  """Tests for `base.annotation` (value spec -> annotation string)."""

  def assert_annotation(
      self,
      value_spec: pg.typing.ValueSpec,
      expected_annotation: str,
      strict: bool = False,
      **kwargs,
  ) -> None:
    self.assertEqual(
        base.annotation(value_spec, strict=strict, **kwargs),
        expected_annotation,
    )

  def test_annotation(self):
    # Bool.
    self.assert_annotation(pg.typing.Bool(), 'bool')
    self.assert_annotation(pg.typing.Bool().noneable(), 'bool | None')

    # Str.
    self.assert_annotation(pg.typing.Str(), 'str')
    self.assert_annotation(pg.typing.Str().noneable(), 'str | None')
    self.assert_annotation(pg.typing.Str(regex='a.*'), "str(regex='a.*')")
    self.assert_annotation(pg.typing.Str(regex='a.*'), "str(regex='a.*')")
    # Strict mode renders constrained specs in pg.typing form.
    self.assert_annotation(
        pg.typing.Str(regex='a.*'), "pg.typing.Str(regex='a.*')", strict=True
    )

    # Int.
    self.assert_annotation(pg.typing.Int(), 'int')
    self.assert_annotation(pg.typing.Int().noneable(), 'int | None')
    self.assert_annotation(pg.typing.Int(min_value=0), 'int(min=0)')
    self.assert_annotation(pg.typing.Int(max_value=1), 'int(max=1)')
    self.assert_annotation(
        pg.typing.Int(min_value=0, max_value=1), 'int(min=0, max=1)'
    )

    self.assert_annotation(pg.typing.Int(), 'int', strict=True)
    self.assert_annotation(
        pg.typing.Int(min_value=0), 'pg.typing.Int(min_value=0)', strict=True
    )
    self.assert_annotation(
        pg.typing.Int(max_value=1), 'pg.typing.Int(max_value=1)', strict=True
    )
    self.assert_annotation(
        pg.typing.Int(min_value=0, max_value=1),
        'pg.typing.Int(min_value=0, max_value=1)',
        strict=True,
    )

    # Float.
    self.assert_annotation(pg.typing.Float(), 'float')
    self.assert_annotation(pg.typing.Float().noneable(), 'float | None')
    self.assert_annotation(pg.typing.Float(min_value=0), 'float(min=0)')
    self.assert_annotation(pg.typing.Float(max_value=1), 'float(max=1)')
    self.assert_annotation(
        pg.typing.Float(min_value=0, max_value=1), 'float(min=0, max=1)'
    )

    self.assert_annotation(pg.typing.Float(), 'float', strict=True)
    self.assert_annotation(
        pg.typing.Float(min_value=0),
        'pg.typing.Float(min_value=0)',
        strict=True,
    )
    self.assert_annotation(
        pg.typing.Float(max_value=1),
        'pg.typing.Float(max_value=1)',
        strict=True,
    )
    self.assert_annotation(
        pg.typing.Float(min_value=0, max_value=1),
        'pg.typing.Float(min_value=0, max_value=1)',
        strict=True,
    )

    # Enum
    self.assert_annotation(
        pg.typing.Enum[1, 'foo'].noneable(), "Literal[1, 'foo', None]"
    )

    # Object.
    self.assert_annotation(pg.typing.Object(Activity), 'Activity')
    self.assert_annotation(
        pg.typing.Object(Activity).noneable(), 'Activity | None'
    )
    self.assert_annotation(
        pg.typing.Object(Activity).noneable(), 'Activity | None',
        allowed_dependencies=set([Activity]),
    )
    # Classes outside `allowed_dependencies` degrade to `Any`.
    self.assert_annotation(
        pg.typing.Object(Activity).noneable(), 'Any | None',
        allowed_dependencies=set(),
    )

    # List.
    self.assert_annotation(
        pg.typing.List(pg.typing.Object(Activity)), 'list[Activity]'
    )
    self.assert_annotation(
        pg.typing.List(pg.typing.Object(Activity)), 'list[Activity]',
        allowed_dependencies=set([Activity]),
    )
    self.assert_annotation(
        pg.typing.List(pg.typing.Object(Activity)), 'list[Any]',
        allowed_dependencies=set(),
    )
    self.assert_annotation(
        pg.typing.List(pg.typing.Object(Activity)).noneable(),
        'list[Activity] | None',
    )
    self.assert_annotation(
        pg.typing.List(pg.typing.Object(Activity).noneable()),
        'list[Activity | None]',
    )

    # Tuple.
    self.assert_annotation(
        pg.typing.Tuple([Activity, pg.typing.Str()]), 'tuple[Activity, str]'
    )
    self.assert_annotation(
        pg.typing.Tuple([Activity, pg.typing.Str()]), 'tuple[Activity, str]',
        allowed_dependencies=set([Activity]),
    )
    self.assert_annotation(
        pg.typing.Tuple([Activity, pg.typing.Str()]), 'tuple[Any, str]',
        allowed_dependencies=set(),
    )
    self.assert_annotation(
        pg.typing.Tuple([Activity, pg.typing.Str()]).noneable(),
        'tuple[Activity, str] | None',
    )

    # Dict.
    self.assert_annotation(
        pg.typing.Dict({'x': Activity, 'y': str}),
        "{'x': Activity, 'y': str}"
    )
    self.assert_annotation(
        pg.typing.Dict({'x': Activity, 'y': str}),
        "{'x': Activity, 'y': str}",
        allowed_dependencies=set([Activity]),
    )
    self.assert_annotation(
        pg.typing.Dict({'x': Activity, 'y': str}),
        "{'x': Any, 'y': str}",
        allowed_dependencies=set(),
    )
    self.assert_annotation(
        pg.typing.Dict({'x': int, 'y': str}),
        "pg.typing.Dict({'x': int, 'y': str})",
        strict=True,
    )
    self.assert_annotation(
        pg.typing.Dict(),
        'dict[str, Any]',
        strict=False,
    )

    class DictValue(pg.Object):
      pass

    self.assert_annotation(
        pg.typing.Dict([(pg.typing.StrKey(), DictValue)]),
        'dict[str, DictValue]',
        strict=False,
    )
    self.assert_annotation(
        pg.typing.Dict(),
        'dict[str, Any]',
        strict=True,
    )

    # Union.
    self.assert_annotation(
        pg.typing.Union(
            [pg.typing.Object(Activity), pg.typing.Object(Itinerary)]
        ).noneable(),
        'Union[Activity, Itinerary, None]',
    )
    self.assert_annotation(
        pg.typing.Union(
            [pg.typing.Object(Activity), pg.typing.Object(Itinerary)]
        ).noneable(),
        'Union[Activity, Any, None]',
        allowed_dependencies=set([Activity]),
    )

    # Any.
    self.assert_annotation(pg.typing.Any(), 'Any')
    self.assert_annotation(pg.typing.Any().noneable(), 'Any')
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
class MissingTest(unittest.TestCase):
  """Tests for `base.Missing` and `base.mark_missing`."""

  def test_basics(self):
    a = Itinerary(
        day=1,
        type=base.Missing(),
        activities=base.Missing(),
        hotel=base.Missing(),
    )
    # Missing is a concrete placeholder value, so the object is not partial.
    self.assertFalse(a.is_partial)
    self.assertEqual(str(base.Missing()), 'MISSING')
    self.assertEqual(str(a.type), "MISSING(Literal['daytime', 'nighttime'])")
    self.assertEqual(str(a.activities), 'MISSING(list[Activity])')
    self.assertEqual(str(a.hotel), "MISSING(str(regex='.*Hotel') | None)")

  def assert_missing(self, value, expected_missing):
    value = base.mark_missing(value)
    self.assertEqual(base.Missing.find_missing(value), expected_missing)

  def test_find_missing(self):
    self.assert_missing(
        Itinerary.partial(),
        {
            'day': base.MISSING,
            'type': base.MISSING,
            'activities': base.MISSING,
        },
    )

    # Nested missing values are keyed by their full path.
    self.assert_missing(
        Itinerary.partial(
            day=1, type='daytime', activities=[Activity.partial()]
        ),
        {
            'activities[0].description': base.MISSING,
        },
    )

  def test_mark_missing(self):
    class A(pg.Object):
      x: typing.Any

    self.assertEqual(base.mark_missing(1), 1)
    self.assertEqual(
        base.mark_missing(pg.MISSING_VALUE), pg.MISSING_VALUE
    )
    self.assertEqual(
        base.mark_missing(A.partial(A.partial(A.partial()))),
        A(A(A(base.MISSING))),
    )
    self.assertEqual(
        base.mark_missing(dict(a=A.partial())),
        dict(a=A(base.MISSING)),
    )
    self.assertEqual(
        base.mark_missing([1, dict(a=A.partial())]),
        [1, dict(a=A(base.MISSING))],
    )
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
class UnknownTest(unittest.TestCase):
  """Tests for `base.Unknown`."""

  def test_basics(self):
    class A(pg.Object):
      x: int

    a = A(x=base.Unknown())
    # Unknown is a concrete sentinel, so the object is not partial.
    self.assertFalse(a.is_partial)
    self.assertEqual(a.x, base.UNKNOWN)
    self.assertEqual(base.UNKNOWN, base.Unknown())
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
if __name__ == '__main__':
  unittest.main()
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
# Copyright 2025 The Langfun Authors
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""JSON-based prompting protocol."""
|
|
15
|
+
|
|
16
|
+
import io
|
|
17
|
+
import textwrap
|
|
18
|
+
from typing import Any
|
|
19
|
+
from langfun.core.structured.schema import base
|
|
20
|
+
import pyglove as pg
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class JsonError(Exception):  # pylint: disable=g-bad-exception-name
  """Json parsing error."""

  def __init__(self, json: str, cause: Exception):
    self.json = json
    self.cause = cause

  def __str__(self) -> str:
    # Render the underlying cause first, then the offending JSON text,
    # indented and colorized for readability on a terminal.
    buf = io.StringIO()
    buf.write(
        pg.colored(
            f'{self.cause.__class__.__name__}: {self.cause}', 'magenta'
        )
    )
    buf.write('\n\n')
    buf.write(pg.colored('JSON text:', 'red'))
    buf.write('\n\n')
    buf.write(textwrap.indent(pg.colored(self.json, 'magenta'), ' ' * 2))
    return buf.getvalue()
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class JsonPromptingProtocol(base.PromptingProtocol):
  """JSON-based prompting protocol."""

  NAME = 'json'

  def schema_repr(self, schema: base.Schema, **kwargs) -> str:
    """Renders the schema dict as a JSON-like type description string."""
    del kwargs
    buf = io.StringIO()

    def render(node: Any) -> None:
      # Recursively renders a schema_dict node into `buf`.
      if isinstance(node, str):
        buf.write(f'"{node}"')
      elif isinstance(node, list):
        # Lists in a schema dict always hold exactly one element spec.
        assert len(node) == 1, node
        buf.write('[')
        render(node[0])
        buf.write(']')
      elif isinstance(node, dict):
        buf.write('{')
        first = True
        for key, child in node.items():
          if not first:
            buf.write(', ')
          first = False
          buf.write(f'"{key}": ')
          render(child)
        buf.write('}')
      elif isinstance(node, pg.typing.Enum):
        literals = [
            f'"{v}"' if isinstance(v, str) else repr(v) for v in node.values
        ]
        buf.write(' | '.join(literals))
      elif isinstance(node, pg.typing.PrimitiveType):
        text = node.value_type.__name__
        if isinstance(node, pg.typing.Number):
          bounds = []
          if node.min_value is not None:
            bounds.append(f'min={node.min_value}')
          if node.max_value is not None:
            bounds.append(f'max={node.max_value}')
          if bounds:
            text += f'({", ".join(bounds)})'
        elif isinstance(node, pg.typing.Str):
          if node.regex is not None:
            text += f'(regex={node.regex.pattern})'
        if node.is_noneable:
          text += ' | None'
        buf.write(text)
      else:
        raise ValueError(
            f'Unsupported value spec being used as schema: {node}.')

    render(schema.schema_dict())
    return buf.getvalue()

  def value_repr(
      self,
      value: Any,
      schema: base.Schema | None = None,
      **kwargs
  ) -> str:
    """Serializes `value` as a JSON string under the top-level `result` key."""
    del schema, kwargs
    return pg.to_json_str(dict(result=value))

  def parse_value(
      self,
      text: str,
      schema: base.Schema | None = None,
      **kwargs
  ) -> Any:
    """Parses a JSON string into a structured object."""
    del schema
    try:
      text = cleanup_json(text)
      parsed = pg.from_json_str(text, **kwargs)
    except Exception as e:
      raise JsonError(text, e)  # pylint: disable=raise-missing-from

    # The protocol requires a top-level dict with a single `result` key.
    if not isinstance(parsed, dict) or 'result' not in parsed:
      raise JsonError(text, ValueError(
          'The root node of the JSON must be a dict with key `result`. '
          f'Encountered: {parsed}'
      ))
    return parsed['result']
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def cleanup_json(json_str: str) -> str:
  """Cleans up the LM responded JSON string.

  Treatments:
  1. Extract the JSON string with a top-level dict from the response.
     This prevents the leading and trailing texts in the response to
     be counted as part of the JSON.
  2. Escape new lines in JSON values.

  Args:
    json_str: The raw LM response that should contain a JSON dict.

  Returns:
    The extracted and cleaned JSON string.

  Raises:
    ValueError: If no top-level dict is found, or the dict's curly braces
      are unbalanced.
  """
  curly_brackets = 0
  under_json = False
  under_str = False
  # Fix: track escape state explicitly instead of peeking at the previous
  # character. The old `json_str[i - 1] != '\\'` check misread a quote
  # preceded by an *escaped* backslash (e.g. `"b\\"`) as escaped, rejecting
  # valid JSON whose string values end with a backslash.
  escaped = False
  str_begin = -1

  cleaned = io.StringIO()
  for i, c in enumerate(json_str):
    if c == '{' and not under_str:
      cleaned.write(c)
      curly_brackets += 1
      under_json = True
      continue
    elif not under_json:
      # Skip leading text before the first top-level '{'.
      continue

    if c == '}' and not under_str:
      cleaned.write(c)
      curly_brackets -= 1
      if curly_brackets == 0:
        break  # Stop at the matching close of the top-level dict.
    elif under_str and escaped:
      # Current char is escaped; it can never open or close the string.
      escaped = False
    elif under_str and c == '\\':
      escaped = True
    elif c == '"':
      under_str = not under_str
      if under_str:
        str_begin = i
      else:
        assert str_begin > 0
        # Escape raw newlines inside string values; the whole string is
        # emitted in one piece when its closing quote is reached.
        str_value = json_str[str_begin : i + 1].replace('\n', '\\n')
        cleaned.write(str_value)
        str_begin = -1
    elif not under_str:
      cleaned.write(c)

  if not under_json:
    raise ValueError(f'No JSON dict in the output: {json_str}')

  if curly_brackets > 0:
    raise ValueError(
        f'Malformated JSON: missing {curly_brackets} closing curly braces.'
    )

  return cleaned.getvalue()
|