langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512040805__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (162)
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +2 -0
  29. langfun/core/eval/v2/checkpointing.py +76 -7
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +92 -17
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +84 -15
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +90 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +72 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +12 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +64 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +34 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +58 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +64 -3
  104. langfun/core/modalities/mime_test.py +11 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512040805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/top_level.txt +0 -0
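Among the changes above, langfun/core/structured/schema.py (987 lines removed) is split into a schema/ package with base.py, json.py and python.py, mirroring the runners.py → runners/ split. A minimal import sketch of the new layout, inferred only from the file paths above and the modules shown in the diff below; it is not authoritative:

    # Sketch of the new import layout inferred from this diff.
    from langfun.core.structured.schema import base    # Schema, class_dependencies, Missing/Unknown (see tests below)
    from langfun.core.structured.schema import json    # JSON-based prompting protocol (shown below)
    from langfun.core.structured.schema import python  # per the file list; contents not shown in this diff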
langfun/core/structured/schema/base_test.py (new file)
@@ -0,0 +1,531 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import dataclasses
15
+ import typing
16
+ import unittest
17
+
18
+ from langfun.core.structured.schema import base
19
+ from langfun.core.structured.schema import json # pylint: disable=unused-import
20
+ import pyglove as pg
21
+
22
+
23
+ class Activity(pg.Object):
24
+ description: str
25
+
26
+
27
+ class Itinerary(pg.Object):
28
+ """A travel itinerary for a day."""
29
+
30
+ day: pg.typing.Int[1, None]
31
+ type: pg.typing.Enum['daytime', 'nighttime']
32
+ activities: list[Activity]
33
+ hotel: pg.typing.Annotated[
34
+ pg.typing.Str['.*Hotel'] | None,
35
+ 'Hotel to stay if applicable.'
36
+ ]
37
+
38
+
39
+ Itinerary.__serialization_key__ = 'Itinerary'
40
+
41
+
42
+ class Node(pg.Object):
43
+ children: list['Node']
44
+
45
+
46
+ class SchemaTest(unittest.TestCase):
47
+
48
+ def assert_schema(self, annotation, spec):
49
+ self.assertEqual(base.Schema(annotation).spec, spec)
50
+
51
+ def assert_unsupported_annotation(self, annotation):
52
+ with self.assertRaises(ValueError):
53
+ base.Schema(annotation)
54
+
55
+ def test_init(self):
56
+ self.assert_schema(int, pg.typing.Int())
57
+ self.assert_schema(float, pg.typing.Float())
58
+ self.assert_schema(str, pg.typing.Str())
59
+ self.assert_schema(bool, pg.typing.Bool())
60
+
61
+ # Top-level dictionary with 'result' as the only key is flattened.
62
+ self.assert_schema(dict(result=int), pg.typing.Int())
63
+
64
+ self.assert_schema(list[str], pg.typing.List(pg.typing.Str()))
65
+ self.assert_schema([str], pg.typing.List(pg.typing.Str()))
66
+
67
+ with self.assertRaisesRegex(
68
+ ValueError, 'Annotation with list must be a list of a single element.'
69
+ ):
70
+ base.Schema([str, int])
71
+
72
+ self.assert_schema(
73
+ dict[str, int], pg.typing.Dict([(pg.typing.StrKey(), pg.typing.Int())])
74
+ )
75
+
76
+ self.assert_schema(
77
+ {
78
+ 'x': int,
79
+ 'y': [str],
80
+ },
81
+ pg.typing.Dict([
82
+ ('x', int),
83
+ ('y', pg.typing.List(pg.typing.Str())),
84
+ ]),
85
+ )
86
+
87
+ self.assert_schema(Itinerary, pg.typing.Object(Itinerary))
88
+
89
+ self.assert_unsupported_annotation(typing.Type[int])
90
+ self.assert_unsupported_annotation(typing.Union[int, str, bool])
91
+ self.assert_unsupported_annotation(typing.Any)
92
+
93
+ def test_schema_dict(self):
94
+ schema = base.Schema([{'x': Itinerary}])
95
+ self.assertEqual(
96
+ schema.schema_dict(),
97
+ {
98
+ 'result': [
99
+ {
100
+ 'x': {
101
+ '_type': 'Itinerary',
102
+ 'day': pg.typing.Int(min_value=1),
103
+ 'activities': [{
104
+ '_type': Activity.__type_name__,
105
+ 'description': pg.typing.Str(),
106
+ }],
107
+ 'hotel': pg.typing.Str(regex='.*Hotel').noneable(),
108
+ 'type': pg.typing.Enum['daytime', 'nighttime'],
109
+ }
110
+ }
111
+ ]
112
+ },
113
+ )
114
+
115
+ def test_class_dependencies(self):
116
+ class Foo(pg.Object):
117
+ x: int
118
+
119
+ class Bar(pg.Object):
120
+ y: str
121
+
122
+ class A(pg.Object):
123
+ foo: tuple[Foo, int]
124
+
125
+ class X(pg.Object):
126
+ k: int | bytes
127
+
128
+ class B(A):
129
+ bar: Bar
130
+ foo2: Foo | X
131
+
132
+ schema = base.Schema([B])
133
+ self.assertEqual(schema.class_dependencies(), [Foo, A, Bar, X, B])
134
+
135
+ def test_class_dependencies_non_pyglove(self):
136
+ class Baz:
137
+ def __init__(self, x: int):
138
+ pass
139
+
140
+ @dataclasses.dataclass(frozen=True)
141
+ class AA:
142
+ foo: tuple[Baz, int]
143
+
144
+ class XX(pg.Object):
145
+ pass
146
+
147
+ @dataclasses.dataclass(frozen=True)
148
+ class BB(AA):
149
+ foo2: Baz | XX
150
+
151
+ schema = base.Schema([AA])
152
+ self.assertEqual(schema.class_dependencies(), [Baz, AA, XX, BB])
153
+
154
+ def test_schema_repr(self):
155
+ schema = base.Schema([{'x': Itinerary}])
156
+ self.assertEqual(
157
+ schema.schema_repr(protocol='json'),
158
+ (
159
+ '{"result": [{"x": {"_type": "Itinerary", "day":'
160
+ ' int(min=1), "type": "daytime" | "nighttime", "activities":'
161
+ ' [{"_type": "%s", "description": str}], "hotel":'
162
+ ' str(regex=.*Hotel) | None}}]}' % (
163
+ Activity.__type_name__,
164
+ )
165
+ ),
166
+ )
167
+ with self.assertRaisesRegex(ValueError, 'Unsupported protocol'):
168
+ schema.schema_repr(protocol='text')
169
+
170
+ def test_value_repr(self):
171
+ schema = base.Schema(int)
172
+ self.assertEqual(schema.value_repr(1, protocol='json'), '{"result": 1}')
173
+ with self.assertRaisesRegex(ValueError, 'Unsupported protocol'):
174
+ schema.value_repr(1, protocol='text')
175
+
176
+ def test_parse(self):
177
+ schema = base.Schema(int)
178
+ self.assertEqual(schema.parse_value('{"result": 1}', protocol='json'), 1)
179
+ schema = base.Schema(dict[str, int])
180
+ self.assertEqual(
181
+ schema.parse_value('{"result": {"x": 1}}}', protocol='json'),
182
+ dict(x=1)
183
+ )
184
+ with self.assertRaisesRegex(
185
+ base.SchemaError, 'Expect .* but encountered .*'):
186
+ schema.parse_value('{"result": "def"}', protocol='json')
187
+
188
+ with self.assertRaisesRegex(ValueError, 'Unsupported protocol'):
189
+ schema.parse_value('1', protocol='text')
190
+
191
+
192
+ class ClassDependenciesTest(unittest.TestCase):
193
+
194
+ def test_class_dependencies_from_specs(self):
195
+ class Foo(pg.Object):
196
+ x: int
197
+
198
+ class Bar(pg.Object):
199
+ y: str
200
+
201
+ class A(pg.Object):
202
+ foo: tuple[Foo, int]
203
+
204
+ class X(pg.Object):
205
+ k: int
206
+
207
+ class B(A):
208
+ bar: Bar
209
+ foo2: Foo | X
210
+
211
+ self.assertEqual(base.class_dependencies(Foo), [Foo])
212
+
213
+ self.assertEqual(
214
+ base.class_dependencies((A,), include_subclasses=False), [Foo, A]
215
+ )
216
+
217
+ self.assertEqual(
218
+ base.class_dependencies(A, include_subclasses=True),
219
+ [Foo, A, Bar, X, B],
220
+ )
221
+
222
+ self.assertEqual(
223
+ base.class_dependencies(base.Schema(A)), [Foo, A, Bar, X, B]
224
+ )
225
+
226
+ self.assertEqual(
227
+ base.class_dependencies(pg.typing.Object(A)), [Foo, A, Bar, X, B]
228
+ )
229
+
230
+ with self.assertRaisesRegex(TypeError, 'Unsupported spec type'):
231
+ base.class_dependencies((Foo, 1))
232
+
233
+ def test_class_dependencies_recursive(self):
234
+ self.assertEqual(
235
+ base.class_dependencies(Node),
236
+ [Node]
237
+ )
238
+
239
+ def test_class_dependencies_from_value(self):
240
+ class Foo(pg.Object):
241
+ x: int
242
+
243
+ class Bar(pg.Object):
244
+ y: str
245
+
246
+ class A(pg.Object):
247
+ foo: tuple[Foo, int]
248
+
249
+ class B(pg.Object):
250
+ pass
251
+
252
+ class X(pg.Object):
253
+ k: dict[str, B]
254
+
255
+ class C(A):
256
+ bar: Bar
257
+ foo2: Foo | X
258
+
259
+ a = A(foo=(Foo(1), 0))
260
+ self.assertEqual(base.class_dependencies(a), [Foo, A, Bar, B, X, C])
261
+
262
+ self.assertEqual(base.class_dependencies(1), [])
263
+
264
+
265
+ class AnnotationTest(unittest.TestCase):
266
+
267
+ def assert_annotation(
268
+ self,
269
+ value_spec: pg.typing.ValueSpec,
270
+ expected_annotation: str,
271
+ strict: bool = False,
272
+ **kwargs,
273
+ ) -> None:
274
+ self.assertEqual(
275
+ base.annotation(value_spec, strict=strict, **kwargs),
276
+ expected_annotation,
277
+ )
278
+
279
+ def test_annotation(self):
280
+ # Bool.
281
+ self.assert_annotation(pg.typing.Bool(), 'bool')
282
+ self.assert_annotation(pg.typing.Bool().noneable(), 'bool | None')
283
+
284
+ # Str.
285
+ self.assert_annotation(pg.typing.Str(), 'str')
286
+ self.assert_annotation(pg.typing.Str().noneable(), 'str | None')
287
+ self.assert_annotation(pg.typing.Str(regex='a.*'), "str(regex='a.*')")
288
+ self.assert_annotation(pg.typing.Str(regex='a.*'), "str(regex='a.*')")
289
+ self.assert_annotation(
290
+ pg.typing.Str(regex='a.*'), "pg.typing.Str(regex='a.*')", strict=True
291
+ )
292
+
293
+ # Int.
294
+ self.assert_annotation(pg.typing.Int(), 'int')
295
+ self.assert_annotation(pg.typing.Int().noneable(), 'int | None')
296
+ self.assert_annotation(pg.typing.Int(min_value=0), 'int(min=0)')
297
+ self.assert_annotation(pg.typing.Int(max_value=1), 'int(max=1)')
298
+ self.assert_annotation(
299
+ pg.typing.Int(min_value=0, max_value=1), 'int(min=0, max=1)'
300
+ )
301
+
302
+ self.assert_annotation(pg.typing.Int(), 'int', strict=True)
303
+ self.assert_annotation(
304
+ pg.typing.Int(min_value=0), 'pg.typing.Int(min_value=0)', strict=True
305
+ )
306
+ self.assert_annotation(
307
+ pg.typing.Int(max_value=1), 'pg.typing.Int(max_value=1)', strict=True
308
+ )
309
+ self.assert_annotation(
310
+ pg.typing.Int(min_value=0, max_value=1),
311
+ 'pg.typing.Int(min_value=0, max_value=1)',
312
+ strict=True,
313
+ )
314
+
315
+ # Float.
316
+ self.assert_annotation(pg.typing.Float(), 'float')
317
+ self.assert_annotation(pg.typing.Float().noneable(), 'float | None')
318
+ self.assert_annotation(pg.typing.Float(min_value=0), 'float(min=0)')
319
+ self.assert_annotation(pg.typing.Float(max_value=1), 'float(max=1)')
320
+ self.assert_annotation(
321
+ pg.typing.Float(min_value=0, max_value=1), 'float(min=0, max=1)'
322
+ )
323
+
324
+ self.assert_annotation(pg.typing.Float(), 'float', strict=True)
325
+ self.assert_annotation(
326
+ pg.typing.Float(min_value=0),
327
+ 'pg.typing.Float(min_value=0)',
328
+ strict=True,
329
+ )
330
+ self.assert_annotation(
331
+ pg.typing.Float(max_value=1),
332
+ 'pg.typing.Float(max_value=1)',
333
+ strict=True,
334
+ )
335
+ self.assert_annotation(
336
+ pg.typing.Float(min_value=0, max_value=1),
337
+ 'pg.typing.Float(min_value=0, max_value=1)',
338
+ strict=True,
339
+ )
340
+
341
+ # Enum
342
+ self.assert_annotation(
343
+ pg.typing.Enum[1, 'foo'].noneable(), "Literal[1, 'foo', None]"
344
+ )
345
+
346
+ # Object.
347
+ self.assert_annotation(pg.typing.Object(Activity), 'Activity')
348
+ self.assert_annotation(
349
+ pg.typing.Object(Activity).noneable(), 'Activity | None'
350
+ )
351
+ self.assert_annotation(
352
+ pg.typing.Object(Activity).noneable(), 'Activity | None',
353
+ allowed_dependencies=set([Activity]),
354
+ )
355
+ self.assert_annotation(
356
+ pg.typing.Object(Activity).noneable(), 'Any | None',
357
+ allowed_dependencies=set(),
358
+ )
359
+
360
+ # List.
361
+ self.assert_annotation(
362
+ pg.typing.List(pg.typing.Object(Activity)), 'list[Activity]'
363
+ )
364
+ self.assert_annotation(
365
+ pg.typing.List(pg.typing.Object(Activity)), 'list[Activity]',
366
+ allowed_dependencies=set([Activity]),
367
+ )
368
+ self.assert_annotation(
369
+ pg.typing.List(pg.typing.Object(Activity)), 'list[Any]',
370
+ allowed_dependencies=set(),
371
+ )
372
+ self.assert_annotation(
373
+ pg.typing.List(pg.typing.Object(Activity)).noneable(),
374
+ 'list[Activity] | None',
375
+ )
376
+ self.assert_annotation(
377
+ pg.typing.List(pg.typing.Object(Activity).noneable()),
378
+ 'list[Activity | None]',
379
+ )
380
+
381
+ # Tuple.
382
+ self.assert_annotation(
383
+ pg.typing.Tuple([Activity, pg.typing.Str()]), 'tuple[Activity, str]'
384
+ )
385
+ self.assert_annotation(
386
+ pg.typing.Tuple([Activity, pg.typing.Str()]), 'tuple[Activity, str]',
387
+ allowed_dependencies=set([Activity]),
388
+ )
389
+ self.assert_annotation(
390
+ pg.typing.Tuple([Activity, pg.typing.Str()]), 'tuple[Any, str]',
391
+ allowed_dependencies=set(),
392
+ )
393
+ self.assert_annotation(
394
+ pg.typing.Tuple([Activity, pg.typing.Str()]).noneable(),
395
+ 'tuple[Activity, str] | None',
396
+ )
397
+
398
+ # Dict.
399
+ self.assert_annotation(
400
+ pg.typing.Dict({'x': Activity, 'y': str}),
401
+ '{\'x\': Activity, \'y\': str}'
402
+ )
403
+ self.assert_annotation(
404
+ pg.typing.Dict({'x': Activity, 'y': str}),
405
+ '{\'x\': Activity, \'y\': str}',
406
+ allowed_dependencies=set([Activity]),
407
+ )
408
+ self.assert_annotation(
409
+ pg.typing.Dict({'x': Activity, 'y': str}),
410
+ '{\'x\': Any, \'y\': str}',
411
+ allowed_dependencies=set(),
412
+ )
413
+ self.assert_annotation(
414
+ pg.typing.Dict({'x': int, 'y': str}),
415
+ 'pg.typing.Dict({\'x\': int, \'y\': str})',
416
+ strict=True,
417
+ )
418
+ self.assert_annotation(
419
+ pg.typing.Dict(),
420
+ 'dict[str, Any]',
421
+ strict=False,
422
+ )
423
+
424
+ class DictValue(pg.Object):
425
+ pass
426
+
427
+ self.assert_annotation(
428
+ pg.typing.Dict([(pg.typing.StrKey(), DictValue)]),
429
+ 'dict[str, DictValue]',
430
+ strict=False,
431
+ )
432
+ self.assert_annotation(
433
+ pg.typing.Dict(),
434
+ 'dict[str, Any]',
435
+ strict=True,
436
+ )
437
+
438
+ # Union.
439
+ self.assert_annotation(
440
+ pg.typing.Union(
441
+ [pg.typing.Object(Activity), pg.typing.Object(Itinerary)]
442
+ ).noneable(),
443
+ 'Union[Activity, Itinerary, None]',
444
+ )
445
+ self.assert_annotation(
446
+ pg.typing.Union(
447
+ [pg.typing.Object(Activity), pg.typing.Object(Itinerary)]
448
+ ).noneable(),
449
+ 'Union[Activity, Any, None]',
450
+ allowed_dependencies=set([Activity]),
451
+ )
452
+
453
+ # Any.
454
+ self.assert_annotation(pg.typing.Any(), 'Any')
455
+ self.assert_annotation(pg.typing.Any().noneable(), 'Any')
456
+
457
+
458
+ class MissingTest(unittest.TestCase):
459
+
460
+ def test_basics(self):
461
+ a = Itinerary(
462
+ day=1,
463
+ type=base.Missing(),
464
+ activities=base.Missing(),
465
+ hotel=base.Missing(),
466
+ )
467
+ self.assertFalse(a.is_partial)
468
+ self.assertEqual(str(base.Missing()), 'MISSING')
469
+ self.assertEqual(str(a.type), "MISSING(Literal['daytime', 'nighttime'])")
470
+ self.assertEqual(str(a.activities), 'MISSING(list[Activity])')
471
+ self.assertEqual(str(a.hotel), "MISSING(str(regex='.*Hotel') | None)")
472
+
473
+ def assert_missing(self, value, expected_missing):
474
+ value = base.mark_missing(value)
475
+ self.assertEqual(base.Missing.find_missing(value), expected_missing)
476
+
477
+ def test_find_missing(self):
478
+ self.assert_missing(
479
+ Itinerary.partial(),
480
+ {
481
+ 'day': base.MISSING,
482
+ 'type': base.MISSING,
483
+ 'activities': base.MISSING,
484
+ },
485
+ )
486
+
487
+ self.assert_missing(
488
+ Itinerary.partial(
489
+ day=1, type='daytime', activities=[Activity.partial()]
490
+ ),
491
+ {
492
+ 'activities[0].description': base.MISSING,
493
+ },
494
+ )
495
+
496
+ def test_mark_missing(self):
497
+ class A(pg.Object):
498
+ x: typing.Any
499
+
500
+ self.assertEqual(base.mark_missing(1), 1)
501
+ self.assertEqual(
502
+ base.mark_missing(pg.MISSING_VALUE), pg.MISSING_VALUE
503
+ )
504
+ self.assertEqual(
505
+ base.mark_missing(A.partial(A.partial(A.partial()))),
506
+ A(A(A(base.MISSING))),
507
+ )
508
+ self.assertEqual(
509
+ base.mark_missing(dict(a=A.partial())),
510
+ dict(a=A(base.MISSING)),
511
+ )
512
+ self.assertEqual(
513
+ base.mark_missing([1, dict(a=A.partial())]),
514
+ [1, dict(a=A(base.MISSING))],
515
+ )
516
+
517
+
518
+ class UnknownTest(unittest.TestCase):
519
+
520
+ def test_basics(self):
521
+ class A(pg.Object):
522
+ x: int
523
+
524
+ a = A(x=base.Unknown())
525
+ self.assertFalse(a.is_partial)
526
+ self.assertEqual(a.x, base.UNKNOWN)
527
+ self.assertEqual(base.UNKNOWN, base.Unknown())
528
+
529
+
530
+ if __name__ == '__main__':
531
+ unittest.main()
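Taken together, the tests above document the public surface of the new base module (Schema, class_dependencies, annotation, Missing/Unknown). A minimal usage sketch distilled from the assertions in base_test.py, using only calls and outputs that appear above:

    from langfun.core.structured.schema import base
    from langfun.core.structured.schema import json  # imported for its side effect (enables protocol='json'), as in the test above  # pylint: disable=unused-import

    int_schema = base.Schema(int)
    int_schema.value_repr(1, protocol='json')                 # '{"result": 1}'
    int_schema.parse_value('{"result": 1}', protocol='json')  # 1

    dict_schema = base.Schema(dict[str, int])
    dict_schema.parse_value('{"result": {"x": 1}}', protocol='json')  # equals dict(x=1)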
langfun/core/structured/schema/json.py (new file)
@@ -0,0 +1,174 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """JSON-based prompting protocol."""
15
+
16
+ import io
17
+ import textwrap
18
+ from typing import Any
19
+ from langfun.core.structured.schema import base
20
+ import pyglove as pg
21
+
22
+
23
+ class JsonError(Exception): # pylint: disable=g-bad-exception-name
24
+ """Json parsing error."""
25
+
26
+ def __init__(self, json: str, cause: Exception):
27
+ self.json = json
28
+ self.cause = cause
29
+
30
+ def __str__(self) -> str:
31
+ r = io.StringIO()
32
+ r.write(
33
+ pg.colored(
34
+ f'{self.cause.__class__.__name__}: {self.cause}', 'magenta'
35
+ )
36
+ )
37
+
38
+ r.write('\n\n')
39
+ r.write(pg.colored('JSON text:', 'red'))
40
+ r.write('\n\n')
41
+ r.write(textwrap.indent(pg.colored(self.json, 'magenta'), ' ' * 2))
42
+ return r.getvalue()
43
+
44
+
45
+ class JsonPromptingProtocol(base.PromptingProtocol):
46
+ """JSON-based prompting protocol."""
47
+
48
+ NAME = 'json'
49
+
50
+ def schema_repr(self, schema: base.Schema, **kwargs) -> str:
51
+ del kwargs
52
+ out = io.StringIO()
53
+ def _visit(node: Any) -> None:
54
+ if isinstance(node, str):
55
+ out.write(f'"{node}"')
56
+ elif isinstance(node, list):
57
+ assert len(node) == 1, node
58
+ out.write('[')
59
+ _visit(node[0])
60
+ out.write(']')
61
+ elif isinstance(node, dict):
62
+ out.write('{')
63
+ for i, (k, v) in enumerate(node.items()):
64
+ if i != 0:
65
+ out.write(', ')
66
+ out.write(f'"{k}": ')
67
+ _visit(v)
68
+ out.write('}')
69
+ elif isinstance(node, pg.typing.Enum):
70
+ out.write(' | '.join(
71
+ f'"{v}"' if isinstance(v, str) else repr(v)
72
+ for v in node.values))
73
+ elif isinstance(node, pg.typing.PrimitiveType):
74
+ x = node.value_type.__name__
75
+ if isinstance(node, pg.typing.Number):
76
+ params = []
77
+ if node.min_value is not None:
78
+ params.append(f'min={node.min_value}')
79
+ if node.max_value is not None:
80
+ params.append(f'max={node.max_value}')
81
+ if params:
82
+ x += f'({", ".join(params)})'
83
+ elif isinstance(node, pg.typing.Str):
84
+ if node.regex is not None:
85
+ x += f'(regex={node.regex.pattern})'
86
+ if node.is_noneable:
87
+ x = x + ' | None'
88
+ out.write(x)
89
+ else:
90
+ raise ValueError(
91
+ f'Unsupported value spec being used as schema: {node}.')
92
+ _visit(schema.schema_dict())
93
+ return out.getvalue()
94
+
95
+ def value_repr(
96
+ self,
97
+ value: Any,
98
+ schema: base.Schema | None = None,
99
+ **kwargs
100
+ ) -> str:
101
+ del schema, kwargs
102
+ return pg.to_json_str(dict(result=value))
103
+
104
+ def parse_value(
105
+ self,
106
+ text: str,
107
+ schema: base.Schema | None = None,
108
+ **kwargs
109
+ ) -> Any:
110
+ """Parses a JSON string into a structured object."""
111
+ del schema
112
+ try:
113
+ text = cleanup_json(text)
114
+ v = pg.from_json_str(text, **kwargs)
115
+ except Exception as e:
116
+ raise JsonError(text, e) # pylint: disable=raise-missing-from
117
+
118
+ if not isinstance(v, dict) or 'result' not in v:
119
+ raise JsonError(text, ValueError(
120
+ 'The root node of the JSON must be a dict with key `result`. '
121
+ f'Encountered: {v}'
122
+ ))
123
+ return v['result']
124
+
125
+
126
+ def cleanup_json(json_str: str) -> str:
127
+ """Cleans up the LM responded JSON string."""
128
+ # Treatments:
129
+ # 1. Extract the JSON string with a top-level dict from the response.
130
+ # This prevents leading and trailing text in the response from
131
+ # being counted as part of the JSON.
132
+ # 2. Escape new lines in JSON values.
133
+
134
+ curly_brackets = 0
135
+ under_json = False
136
+ under_str = False
137
+ str_begin = -1
138
+
139
+ cleaned = io.StringIO()
140
+ for i, c in enumerate(json_str):
141
+ if c == '{' and not under_str:
142
+ cleaned.write(c)
143
+ curly_brackets += 1
144
+ under_json = True
145
+ continue
146
+ elif not under_json:
147
+ continue
148
+
149
+ if c == '}' and not under_str:
150
+ cleaned.write(c)
151
+ curly_brackets -= 1
152
+ if curly_brackets == 0:
153
+ break
154
+ elif c == '"' and json_str[i - 1] != '\\':
155
+ under_str = not under_str
156
+ if under_str:
157
+ str_begin = i
158
+ else:
159
+ assert str_begin > 0
160
+ str_value = json_str[str_begin : i + 1].replace('\n', '\\n')
161
+ cleaned.write(str_value)
162
+ str_begin = -1
163
+ elif not under_str:
164
+ cleaned.write(c)
165
+
166
+ if not under_json:
167
+ raise ValueError(f'No JSON dict in the output: {json_str}')
168
+
169
+ if curly_brackets > 0:
170
+ raise ValueError(
171
+ f'Malformed JSON: missing {curly_brackets} closing curly braces.'
172
+ )
173
+
174
+ return cleaned.getvalue()
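cleanup_json documents two treatments: extracting the top-level JSON dict from the surrounding response text, and escaping raw newlines inside string values. A small illustration of both (the response text here is made up):

    from langfun.core.structured.schema.json import cleanup_json

    lm_response = 'Sure, here it is:\n{"result": "line1\nline2"}\nHope this helps.'
    print(cleanup_json(lm_response))
    # Prints: {"result": "line1\nline2"}
    # The prose around the top-level dict is dropped and the raw newline inside the
    # string value is escaped, so pg.from_json_str() can parse the result.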