langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512040805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162) hide show
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +2 -0
  29. langfun/core/eval/v2/checkpointing.py +76 -7
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +92 -17
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +84 -15
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +90 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +72 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +12 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +64 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +34 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +58 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +64 -3
  104. langfun/core/modalities/mime_test.py +11 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512040805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/top_level.txt +0 -0
@@ -1,982 +0,0 @@
1
- # Copyright 2023 The Langfun Authors
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- """Tests for structured parsing."""
15
-
16
- import dataclasses
17
- import inspect
18
- import typing
19
- import unittest
20
-
21
- import langfun.core as lf
22
- from langfun.core.llms import fake
23
- from langfun.core.structured import schema as schema_lib
24
- import pyglove as pg
25
-
26
-
27
- class Activity(pg.Object):
28
- description: str
29
-
30
-
31
- class Itinerary(pg.Object):
32
- """A travel itinerary for a day."""
33
-
34
- day: pg.typing.Int[1, None]
35
- type: pg.typing.Enum['daytime', 'nighttime']
36
- activities: list[Activity]
37
- hotel: pg.typing.Annotated[
38
- pg.typing.Str['.*Hotel'] | None,
39
- 'Hotel to stay if applicable.'
40
- ]
41
-
42
-
43
- class PlaceOfInterest(pg.Object):
44
- """The name of a place of interest.
45
-
46
- A place of interest is a place that people ususally visit during their
47
- travels.
48
- """
49
-
50
- name: str
51
-
52
-
53
- Itinerary.__serialization_key__ = 'Itinerary'
54
-
55
-
56
- class Node(pg.Object):
57
- children: list['Node']
58
-
59
-
60
- class SchemaTest(unittest.TestCase):
61
-
62
- def assert_schema(self, annotation, spec):
63
- self.assertEqual(schema_lib.Schema(annotation).spec, spec)
64
-
65
- def assert_unsupported_annotation(self, annotation):
66
- with self.assertRaises(ValueError):
67
- schema_lib.Schema(annotation)
68
-
69
- def test_init(self):
70
- self.assert_schema(int, pg.typing.Int())
71
- self.assert_schema(float, pg.typing.Float())
72
- self.assert_schema(str, pg.typing.Str())
73
- self.assert_schema(bool, pg.typing.Bool())
74
-
75
- # Top-level dictionary with 'result' as the only key is flattened.
76
- self.assert_schema(dict(result=int), pg.typing.Int())
77
-
78
- self.assert_schema(list[str], pg.typing.List(pg.typing.Str()))
79
- self.assert_schema([str], pg.typing.List(pg.typing.Str()))
80
-
81
- with self.assertRaisesRegex(
82
- ValueError, 'Annotation with list must be a list of a single element.'
83
- ):
84
- schema_lib.Schema([str, int])
85
-
86
- self.assert_schema(
87
- dict[str, int], pg.typing.Dict([(pg.typing.StrKey(), pg.typing.Int())])
88
- )
89
-
90
- self.assert_schema(
91
- {
92
- 'x': int,
93
- 'y': [str],
94
- },
95
- pg.typing.Dict([
96
- ('x', int),
97
- ('y', pg.typing.List(pg.typing.Str())),
98
- ]),
99
- )
100
-
101
- self.assert_schema(Itinerary, pg.typing.Object(Itinerary))
102
-
103
- self.assert_unsupported_annotation(typing.Type[int])
104
- self.assert_unsupported_annotation(typing.Union[int, str, bool])
105
- self.assert_unsupported_annotation(typing.Any)
106
-
107
- def test_schema_dict(self):
108
- schema = schema_lib.Schema([{'x': Itinerary}])
109
- self.assertEqual(
110
- schema.schema_dict(),
111
- {
112
- 'result': [
113
- {
114
- 'x': {
115
- '_type': 'Itinerary',
116
- 'day': pg.typing.Int(min_value=1),
117
- 'activities': [{
118
- '_type': Activity.__type_name__,
119
- 'description': pg.typing.Str(),
120
- }],
121
- 'hotel': pg.typing.Str(regex='.*Hotel').noneable(),
122
- 'type': pg.typing.Enum['daytime', 'nighttime'],
123
- }
124
- }
125
- ]
126
- },
127
- )
128
-
129
- def test_class_dependencies(self):
130
- class Foo(pg.Object):
131
- x: int
132
-
133
- class Bar(pg.Object):
134
- y: str
135
-
136
- class A(pg.Object):
137
- foo: tuple[Foo, int]
138
-
139
- class X(pg.Object):
140
- k: int
141
-
142
- class B(A):
143
- bar: Bar
144
- foo2: Foo | X
145
-
146
- schema = schema_lib.Schema([B])
147
- v = schema_lib.structure_from_python(
148
- """
149
- class C(B):
150
- pass
151
- """,
152
- global_vars=dict(B=B),
153
- permission=pg.coding.CodePermission.ALL,
154
- )
155
- self.assertEqual(v.__module__, 'builtins')
156
- self.assertEqual(schema.class_dependencies(), [Foo, A, Bar, X, B])
157
-
158
- def test_class_dependencies_non_pyglove(self):
159
- class Baz:
160
- def __init__(self, x: int):
161
- pass
162
-
163
- @dataclasses.dataclass(frozen=True)
164
- class AA:
165
- foo: tuple[Baz, int]
166
-
167
- class XX(pg.Object):
168
- pass
169
-
170
- @dataclasses.dataclass(frozen=True)
171
- class BB(AA):
172
- foo2: Baz | XX
173
-
174
- v = schema_lib.structure_from_python(
175
- """
176
- class CC(BB):
177
- pass
178
- """,
179
- global_vars=dict(BB=BB),
180
- permission=pg.coding.CodePermission.ALL,
181
- )
182
- self.assertEqual(v.__module__, 'builtins')
183
- schema = schema_lib.Schema([AA])
184
- self.assertEqual(schema.class_dependencies(), [Baz, AA, XX, BB])
185
-
186
- def test_schema_repr(self):
187
- schema = schema_lib.Schema([{'x': Itinerary}])
188
- self.assertEqual(
189
- schema.schema_str(protocol='json'),
190
- (
191
- '{"result": [{"x": {"_type": "Itinerary", "day":'
192
- ' int(min=1), "type": "daytime" | "nighttime", "activities":'
193
- ' [{"_type": "%s", "description": str}], "hotel":'
194
- ' str(regex=.*Hotel) | None}}]}' % (
195
- Activity.__type_name__,
196
- )
197
- ),
198
- )
199
- with self.assertRaisesRegex(ValueError, 'Unsupported protocol'):
200
- schema.schema_str(protocol='text')
201
-
202
- def test_value_repr(self):
203
- schema = schema_lib.Schema(int)
204
- self.assertEqual(schema.value_str(1, protocol='json'), '{"result": 1}')
205
- with self.assertRaisesRegex(ValueError, 'Unsupported protocol'):
206
- schema.value_str(1, protocol='text')
207
-
208
- def test_parse(self):
209
- schema = schema_lib.Schema(int)
210
- self.assertEqual(schema.parse('{"result": 1}'), 1)
211
- schema = schema_lib.Schema(dict[str, int])
212
- self.assertEqual(
213
- schema.parse('{"result": {"x": 1}}}'),
214
- dict(x=1)
215
- )
216
- with self.assertRaisesRegex(
217
- schema_lib.SchemaError, 'Expect .* but encountered .*'):
218
- schema.parse('{"result": "def"}')
219
-
220
- with self.assertRaisesRegex(ValueError, 'Unsupported protocol'):
221
- schema.parse('1', protocol='text')
222
-
223
-
224
- class ClassDependenciesTest(unittest.TestCase):
225
-
226
- def test_class_dependencies_from_specs(self):
227
- class Foo(pg.Object):
228
- x: int
229
-
230
- class Bar(pg.Object):
231
- y: str
232
-
233
- class A(pg.Object):
234
- foo: tuple[Foo, int]
235
-
236
- class X(pg.Object):
237
- k: int
238
-
239
- class B(A):
240
- bar: Bar
241
- foo2: Foo | X
242
-
243
- self.assertEqual(schema_lib.class_dependencies(Foo), [Foo])
244
-
245
- self.assertEqual(
246
- schema_lib.class_dependencies((A,), include_subclasses=False), [Foo, A]
247
- )
248
-
249
- self.assertEqual(
250
- schema_lib.class_dependencies(A, include_subclasses=True),
251
- [Foo, A, Bar, X, B],
252
- )
253
-
254
- self.assertEqual(
255
- schema_lib.class_dependencies(schema_lib.Schema(A)), [Foo, A, Bar, X, B]
256
- )
257
-
258
- self.assertEqual(
259
- schema_lib.class_dependencies(pg.typing.Object(A)), [Foo, A, Bar, X, B]
260
- )
261
-
262
- with self.assertRaisesRegex(TypeError, 'Unsupported spec type'):
263
- schema_lib.class_dependencies((Foo, 1))
264
-
265
- def test_class_dependencies_recursive(self):
266
- self.assertEqual(
267
- schema_lib.class_dependencies(Node),
268
- [Node]
269
- )
270
-
271
- def test_class_dependencies_from_value(self):
272
- class Foo(pg.Object):
273
- x: int
274
-
275
- class Bar(pg.Object):
276
- y: str
277
-
278
- class A(pg.Object):
279
- foo: tuple[Foo, int]
280
-
281
- class B(pg.Object):
282
- pass
283
-
284
- class X(pg.Object):
285
- k: dict[str, B]
286
-
287
- class C(A):
288
- bar: Bar
289
- foo2: Foo | X
290
-
291
- a = A(foo=(Foo(1), 0))
292
- self.assertEqual(schema_lib.class_dependencies(a), [Foo, A, Bar, B, X, C])
293
-
294
- self.assertEqual(schema_lib.class_dependencies(1), [])
295
-
296
-
297
- class SchemaPythonReprTest(unittest.TestCase):
298
-
299
- def assert_annotation(
300
- self,
301
- value_spec: pg.typing.ValueSpec,
302
- expected_annotation: str,
303
- strict: bool = False,
304
- **kwargs,
305
- ) -> None:
306
- self.assertEqual(
307
- schema_lib.annotation(value_spec, strict=strict, **kwargs),
308
- expected_annotation,
309
- )
310
-
311
- def test_annotation(self):
312
- # Bool.
313
- self.assert_annotation(pg.typing.Bool(), 'bool')
314
- self.assert_annotation(pg.typing.Bool().noneable(), 'bool | None')
315
-
316
- # Str.
317
- self.assert_annotation(pg.typing.Str(), 'str')
318
- self.assert_annotation(pg.typing.Str().noneable(), 'str | None')
319
- self.assert_annotation(pg.typing.Str(regex='a.*'), "str(regex='a.*')")
320
- self.assert_annotation(pg.typing.Str(regex='a.*'), "str(regex='a.*')")
321
- self.assert_annotation(
322
- pg.typing.Str(regex='a.*'), "pg.typing.Str(regex='a.*')", strict=True
323
- )
324
-
325
- # Int.
326
- self.assert_annotation(pg.typing.Int(), 'int')
327
- self.assert_annotation(pg.typing.Int().noneable(), 'int | None')
328
- self.assert_annotation(pg.typing.Int(min_value=0), 'int(min=0)')
329
- self.assert_annotation(pg.typing.Int(max_value=1), 'int(max=1)')
330
- self.assert_annotation(
331
- pg.typing.Int(min_value=0, max_value=1), 'int(min=0, max=1)'
332
- )
333
-
334
- self.assert_annotation(pg.typing.Int(), 'int', strict=True)
335
- self.assert_annotation(
336
- pg.typing.Int(min_value=0), 'pg.typing.Int(min_value=0)', strict=True
337
- )
338
- self.assert_annotation(
339
- pg.typing.Int(max_value=1), 'pg.typing.Int(max_value=1)', strict=True
340
- )
341
- self.assert_annotation(
342
- pg.typing.Int(min_value=0, max_value=1),
343
- 'pg.typing.Int(min_value=0, max_value=1)',
344
- strict=True,
345
- )
346
-
347
- # Float.
348
- self.assert_annotation(pg.typing.Float(), 'float')
349
- self.assert_annotation(pg.typing.Float().noneable(), 'float | None')
350
- self.assert_annotation(pg.typing.Float(min_value=0), 'float(min=0)')
351
- self.assert_annotation(pg.typing.Float(max_value=1), 'float(max=1)')
352
- self.assert_annotation(
353
- pg.typing.Float(min_value=0, max_value=1), 'float(min=0, max=1)'
354
- )
355
-
356
- self.assert_annotation(pg.typing.Float(), 'float', strict=True)
357
- self.assert_annotation(
358
- pg.typing.Float(min_value=0),
359
- 'pg.typing.Float(min_value=0)',
360
- strict=True,
361
- )
362
- self.assert_annotation(
363
- pg.typing.Float(max_value=1),
364
- 'pg.typing.Float(max_value=1)',
365
- strict=True,
366
- )
367
- self.assert_annotation(
368
- pg.typing.Float(min_value=0, max_value=1),
369
- 'pg.typing.Float(min_value=0, max_value=1)',
370
- strict=True,
371
- )
372
-
373
- # Enum
374
- self.assert_annotation(
375
- pg.typing.Enum[1, 'foo'].noneable(), "Literal[1, 'foo', None]"
376
- )
377
-
378
- # Object.
379
- self.assert_annotation(pg.typing.Object(Activity), 'Activity')
380
- self.assert_annotation(
381
- pg.typing.Object(Activity).noneable(), 'Activity | None'
382
- )
383
- self.assert_annotation(
384
- pg.typing.Object(Activity).noneable(), 'Activity | None',
385
- allowed_dependencies=set([Activity]),
386
- )
387
- self.assert_annotation(
388
- pg.typing.Object(Activity).noneable(), 'Any | None',
389
- allowed_dependencies=set(),
390
- )
391
-
392
- # List.
393
- self.assert_annotation(
394
- pg.typing.List(pg.typing.Object(Activity)), 'list[Activity]'
395
- )
396
- self.assert_annotation(
397
- pg.typing.List(pg.typing.Object(Activity)), 'list[Activity]',
398
- allowed_dependencies=set([Activity]),
399
- )
400
- self.assert_annotation(
401
- pg.typing.List(pg.typing.Object(Activity)), 'list[Any]',
402
- allowed_dependencies=set(),
403
- )
404
- self.assert_annotation(
405
- pg.typing.List(pg.typing.Object(Activity)).noneable(),
406
- 'list[Activity] | None',
407
- )
408
- self.assert_annotation(
409
- pg.typing.List(pg.typing.Object(Activity).noneable()),
410
- 'list[Activity | None]',
411
- )
412
-
413
- # Tuple.
414
- self.assert_annotation(
415
- pg.typing.Tuple([Activity, pg.typing.Str()]), 'tuple[Activity, str]'
416
- )
417
- self.assert_annotation(
418
- pg.typing.Tuple([Activity, pg.typing.Str()]), 'tuple[Activity, str]',
419
- allowed_dependencies=set([Activity]),
420
- )
421
- self.assert_annotation(
422
- pg.typing.Tuple([Activity, pg.typing.Str()]), 'tuple[Any, str]',
423
- allowed_dependencies=set(),
424
- )
425
- self.assert_annotation(
426
- pg.typing.Tuple([Activity, pg.typing.Str()]).noneable(),
427
- 'tuple[Activity, str] | None',
428
- )
429
-
430
- # Dict.
431
- self.assert_annotation(
432
- pg.typing.Dict({'x': Activity, 'y': str}),
433
- '{\'x\': Activity, \'y\': str}'
434
- )
435
- self.assert_annotation(
436
- pg.typing.Dict({'x': Activity, 'y': str}),
437
- '{\'x\': Activity, \'y\': str}',
438
- allowed_dependencies=set([Activity]),
439
- )
440
- self.assert_annotation(
441
- pg.typing.Dict({'x': Activity, 'y': str}),
442
- '{\'x\': Any, \'y\': str}',
443
- allowed_dependencies=set(),
444
- )
445
- self.assert_annotation(
446
- pg.typing.Dict({'x': int, 'y': str}),
447
- 'pg.typing.Dict({\'x\': int, \'y\': str})',
448
- strict=True,
449
- )
450
- self.assert_annotation(
451
- pg.typing.Dict(),
452
- 'dict[str, Any]',
453
- strict=False,
454
- )
455
-
456
- class DictValue(pg.Object):
457
- pass
458
-
459
- self.assert_annotation(
460
- pg.typing.Dict([(pg.typing.StrKey(), DictValue)]),
461
- 'dict[str, DictValue]',
462
- strict=False,
463
- )
464
- self.assert_annotation(
465
- pg.typing.Dict(),
466
- 'dict[str, Any]',
467
- strict=True,
468
- )
469
-
470
- # Union.
471
- self.assert_annotation(
472
- pg.typing.Union(
473
- [pg.typing.Object(Activity), pg.typing.Object(Itinerary)]
474
- ).noneable(),
475
- 'Union[Activity, Itinerary, None]',
476
- )
477
- self.assert_annotation(
478
- pg.typing.Union(
479
- [pg.typing.Object(Activity), pg.typing.Object(Itinerary)]
480
- ).noneable(),
481
- 'Union[Activity, Any, None]',
482
- allowed_dependencies=set([Activity]),
483
- )
484
-
485
- # Any.
486
- self.assert_annotation(pg.typing.Any(), 'Any')
487
- self.assert_annotation(pg.typing.Any().noneable(), 'Any')
488
-
489
- def test_class_definition(self):
490
- self.assertEqual(
491
- schema_lib.class_definition(Activity, allowed_dependencies=set()),
492
- 'class Activity:\n description: str\n',
493
- )
494
- self.assertEqual(
495
- schema_lib.class_definition(Itinerary),
496
- inspect.cleandoc("""
497
- class Itinerary(Object):
498
- \"\"\"A travel itinerary for a day.\"\"\"
499
- day: int(min=1)
500
- type: Literal['daytime', 'nighttime']
501
- activities: list[Activity]
502
- # Hotel to stay if applicable.
503
- hotel: str(regex='.*Hotel') | None
504
- """) + '\n',
505
- )
506
- self.assertEqual(
507
- schema_lib.class_definition(
508
- PlaceOfInterest, allowed_dependencies=set()
509
- ),
510
- inspect.cleandoc("""
511
- class PlaceOfInterest:
512
- \"\"\"The name of a place of interest.
513
-
514
- A place of interest is a place that people ususally visit during their
515
- travels.
516
- \"\"\"
517
- name: str
518
- """) + '\n',
519
- )
520
-
521
- class A(pg.Object):
522
- pass
523
-
524
- self.assertEqual(
525
- schema_lib.class_definition(A, allowed_dependencies=set()),
526
- 'class A:\n pass\n',
527
- )
528
- self.assertEqual(
529
- schema_lib.class_definition(A),
530
- 'class A(Object):\n pass\n',
531
- )
532
-
533
- class C(pg.Object):
534
- x: str
535
- __kwargs__: typing.Any
536
-
537
- self.assertEqual(
538
- schema_lib.class_definition(C), 'class C(Object):\n x: str\n'
539
- )
540
-
541
- class D(pg.Object):
542
- x: str
543
- @schema_lib.include_method_in_prompt
544
- def __call__(self, y: int) -> int:
545
- return len(self.x) + y
546
-
547
- self.assertEqual(
548
- schema_lib.class_definition(D),
549
- inspect.cleandoc(
550
- """
551
- class D(Object):
552
- x: str
553
-
554
- def __call__(self, y: int) -> int:
555
- return len(self.x) + y
556
- """) + '\n'
557
- )
558
-
559
- class E(pg.Object):
560
- x: str
561
- y: typing.Annotated[int, 'y', dict(exclude_from_prompt=True)]
562
-
563
- self.assertEqual(
564
- schema_lib.class_definition(E),
565
- inspect.cleandoc(
566
- """
567
- class E(Object):
568
- x: str
569
- """) + '\n'
570
- )
571
-
572
- def test_repr(self):
573
- class Foo(pg.Object):
574
- x: int
575
-
576
- @dataclasses.dataclass(frozen=True)
577
- class Bar:
578
- """Class Bar."""
579
- y: str
580
-
581
- @dataclasses.dataclass(frozen=True)
582
- class Baz(Bar): # pylint: disable=unused-variable
583
- pass
584
-
585
- class A(pg.Object):
586
- foo: Foo
587
-
588
- @schema_lib.include_method_in_prompt
589
- def foo_value(self) -> int:
590
- return self.foo.x
591
-
592
- def baz_value(self) -> str:
593
- return 'baz'
594
-
595
- class B(A):
596
- bar: Bar
597
- foo2: Foo
598
-
599
- @schema_lib.include_method_in_prompt
600
- def bar_value(self) -> str:
601
- return self.bar.y
602
-
603
- schema = schema_lib.Schema([B])
604
- self.assertEqual(
605
- schema_lib.SchemaPythonRepr().class_definitions(schema),
606
- inspect.cleandoc('''
607
- class Foo:
608
- x: int
609
-
610
- class Bar:
611
- """Class Bar."""
612
- y: str
613
-
614
- class Baz(Bar):
615
- """Baz(y: str)"""
616
- y: str
617
-
618
- class B:
619
- foo: Foo
620
- bar: Bar
621
- foo2: Foo
622
-
623
- def bar_value(self) -> str:
624
- return self.bar.y
625
-
626
- def foo_value(self) -> int:
627
- return self.foo.x
628
- ''') + '\n',
629
- )
630
-
631
- self.assertEqual(
632
- schema_lib.SchemaPythonRepr().result_definition(schema), 'list[B]'
633
- )
634
-
635
- self.assertEqual(
636
- schema_lib.SchemaPythonRepr().repr(schema),
637
- inspect.cleandoc('''
638
- list[B]
639
-
640
- ```python
641
- class Foo:
642
- x: int
643
-
644
- class Bar:
645
- """Class Bar."""
646
- y: str
647
-
648
- class Baz(Bar):
649
- """Baz(y: str)"""
650
- y: str
651
-
652
- class B:
653
- foo: Foo
654
- bar: Bar
655
- foo2: Foo
656
-
657
- def bar_value(self) -> str:
658
- return self.bar.y
659
-
660
- def foo_value(self) -> int:
661
- return self.foo.x
662
- ```
663
- '''),
664
- )
665
- self.assertEqual(
666
- schema_lib.SchemaPythonRepr().repr(
667
- schema,
668
- include_result_definition=False,
669
- markdown=False,
670
- ),
671
- inspect.cleandoc('''
672
- class Foo:
673
- x: int
674
-
675
- class Bar:
676
- """Class Bar."""
677
- y: str
678
-
679
- class Baz(Bar):
680
- """Baz(y: str)"""
681
- y: str
682
-
683
- class B:
684
- foo: Foo
685
- bar: Bar
686
- foo2: Foo
687
-
688
- def bar_value(self) -> str:
689
- return self.bar.y
690
-
691
- def foo_value(self) -> int:
692
- return self.foo.x
693
- '''),
694
- )
695
-
696
-
697
- class SchemaJsonReprTest(unittest.TestCase):
698
-
699
- def test_repr(self):
700
- schema = schema_lib.Schema([{'x': Itinerary}])
701
- self.assertEqual(
702
- schema_lib.SchemaJsonRepr().repr(schema),
703
- (
704
- '{"result": [{"x": {"_type": "Itinerary", "day":'
705
- ' int(min=1), "type": "daytime" | "nighttime", "activities":'
706
- ' [{"_type": "%s", "description": str}], "hotel":'
707
- ' str(regex=.*Hotel) | None}}]}' % (
708
- Activity.__type_name__,
709
- )
710
- ),
711
- )
712
-
713
-
714
- class ValuePythonReprTest(unittest.TestCase):
715
-
716
- def test_repr(self):
717
- class Foo(pg.Object):
718
- x: int
719
-
720
- class A(pg.Object):
721
- foo: list[Foo]
722
- y: str | None
723
-
724
- self.assertEqual(
725
- schema_lib.ValuePythonRepr().repr(1, schema_lib.Schema(int)),
726
- '```python\n1\n```'
727
- )
728
- self.assertEqual(
729
- schema_lib.ValuePythonRepr().repr(lf.Template('hi, {{a}}', a='foo')),
730
- 'hi, foo'
731
- )
732
- self.assertEqual(
733
- schema_lib.ValuePythonRepr().repr(
734
- A([Foo(1), Foo(2)], 'bar'), schema_lib.Schema(A), markdown=False,
735
- ),
736
- "A(foo=[Foo(x=1), Foo(x=2)], y='bar')",
737
- )
738
- self.assertEqual(
739
- schema_lib.ValuePythonRepr().repr(
740
- A([Foo(1), Foo(2)], 'bar'),
741
- schema_lib.Schema(A),
742
- markdown=True,
743
- compact=False,
744
- assign_to_var='output',
745
- ),
746
- inspect.cleandoc("""
747
- ```python
748
- output = A(
749
- foo=[
750
- Foo(
751
- x=1
752
- ),
753
- Foo(
754
- x=2
755
- )
756
- ],
757
- y='bar'
758
- )
759
- ```
760
- """),
761
- )
762
- self.assertEqual(
763
- schema_lib.ValuePythonRepr().repr(A),
764
- inspect.cleandoc("""
765
- ```python
766
- class Foo(Object):
767
- x: int
768
-
769
- class A(Object):
770
- foo: list[Foo]
771
- y: str | None
772
- ```
773
- """),
774
- )
775
- self.assertEqual(schema_lib.source_form(int), 'int')
776
-
777
- def test_parse(self):
778
- class Foo(pg.Object):
779
- x: int
780
-
781
- class A(pg.Object):
782
- foo: list[Foo]
783
- y: str | None
784
-
785
- self.assertEqual(
786
- schema_lib.ValuePythonRepr().parse(
787
- "A(foo=[Foo(x=1), Foo(x=2)], y='bar')", schema_lib.Schema(A)
788
- ),
789
- A([Foo(1), Foo(2)], y='bar'),
790
- )
791
-
792
- def test_parse_with_correction(self):
793
- class Foo(pg.Object):
794
- x: int
795
-
796
- class A(pg.Object):
797
- foo: list[Foo]
798
- y: str | None
799
-
800
- self.assertEqual(
801
- schema_lib.ValuePythonRepr().parse(
802
- "A(foo=[Foo(x=1), Foo(x=2)], y='bar'",
803
- schema_lib.Schema(A),
804
- autofix=1,
805
- autofix_lm=fake.StaticResponse(inspect.cleandoc("""
806
- CorrectedCode(
807
- corrected_code='A(foo=[Foo(x=1), Foo(x=2)], y=\\\'bar\\\')',
808
- )
809
- """)),
810
- ),
811
- A([Foo(1), Foo(2)], y='bar'),
812
- )
813
-
814
- def test_parse_class_def(self):
815
- self.assertTrue(
816
- inspect.isclass(
817
- schema_lib.ValuePythonRepr().parse(
818
- """
819
- class A:
820
- x: Dict[str, Any]
821
- y: Optional[Sequence[str]]
822
- z: Union[int, List[int], Tuple[int]]
823
- """,
824
- permission=pg.coding.CodePermission.ALL,
825
- )
826
- )
827
- )
828
-
829
-
830
- class ValueJsonReprTest(unittest.TestCase):
831
-
832
- def test_repr(self):
833
- self.assertEqual(schema_lib.ValueJsonRepr().repr(1), '{"result": 1}')
834
-
835
- def assert_parse(self, inputs, output) -> None:
836
- self.assertEqual(schema_lib.ValueJsonRepr().parse(inputs), output)
837
-
838
- def test_parse_basics(self):
839
- self.assert_parse('{"result": 1}', 1)
840
- self.assert_parse('{"result": "\\"}ab{"}', '"}ab{')
841
- self.assert_parse(
842
- '{"result": {"x": true, "y": null}}',
843
- {'x': True, 'y': None},
844
- )
845
- self.assert_parse(
846
- (
847
- '{"result": {"_type": "%s", "description": "play"}}'
848
- % Activity.__type_name__
849
- ),
850
- Activity('play'),
851
- )
852
- with self.assertRaisesRegex(
853
- schema_lib.JsonError, 'JSONDecodeError'
854
- ):
855
- schema_lib.ValueJsonRepr().parse('{"abc", 1}')
856
-
857
- with self.assertRaisesRegex(
858
- schema_lib.JsonError,
859
- 'The root node of the JSON must be a dict with key `result`'
860
- ):
861
- schema_lib.ValueJsonRepr().parse('{"abc": 1}')
862
-
863
- def test_parse_with_surrounding_texts(self):
864
- self.assert_parse('The answer is {"result": 1}.', 1)
865
-
866
- def test_parse_with_new_lines(self):
867
- self.assert_parse(
868
- """
869
- {
870
- "result": [
871
- "foo
872
- bar"]
873
- }
874
- """,
875
- ['foo\nbar'])
876
-
877
- def test_parse_with_malformated_json(self):
878
- with self.assertRaisesRegex(
879
- schema_lib.JsonError, 'No JSON dict in the output'
880
- ):
881
- schema_lib.ValueJsonRepr().parse('The answer is 1.')
882
-
883
- with self.assertRaisesRegex(
884
- schema_lib.JsonError,
885
- 'Malformated JSON: missing .* closing curly braces'
886
- ):
887
- schema_lib.ValueJsonRepr().parse('{"result": 1')
888
-
889
-
890
- class ProtocolTest(unittest.TestCase):
891
-
892
- def test_schema_repr(self):
893
- self.assertIsInstance(
894
- schema_lib.schema_repr('json'), schema_lib.SchemaJsonRepr)
895
- self.assertIsInstance(
896
- schema_lib.schema_repr('python'), schema_lib.SchemaPythonRepr)
897
- with self.assertRaisesRegex(ValueError, 'Unsupported protocol'):
898
- schema_lib.schema_repr('text')
899
-
900
- def test_value_repr(self):
901
- self.assertIsInstance(
902
- schema_lib.value_repr('json'), schema_lib.ValueJsonRepr)
903
- self.assertIsInstance(
904
- schema_lib.value_repr('python'), schema_lib.ValuePythonRepr)
905
- with self.assertRaisesRegex(ValueError, 'Unsupported protocol'):
906
- schema_lib.value_repr('text')
907
-
908
-
909
- class MissingTest(unittest.TestCase):
910
-
911
- def test_basics(self):
912
- a = Itinerary(
913
- day=1,
914
- type=schema_lib.Missing(),
915
- activities=schema_lib.Missing(),
916
- hotel=schema_lib.Missing(),
917
- )
918
- self.assertFalse(a.is_partial)
919
- self.assertEqual(str(schema_lib.Missing()), 'MISSING')
920
- self.assertEqual(str(a.type), "MISSING(Literal['daytime', 'nighttime'])")
921
- self.assertEqual(str(a.activities), 'MISSING(list[Activity])')
922
- self.assertEqual(str(a.hotel), "MISSING(str(regex='.*Hotel') | None)")
923
-
924
- def assert_missing(self, value, expected_missing):
925
- value = schema_lib.mark_missing(value)
926
- self.assertEqual(schema_lib.Missing.find_missing(value), expected_missing)
927
-
928
- def test_find_missing(self):
929
- self.assert_missing(
930
- Itinerary.partial(),
931
- {
932
- 'day': schema_lib.MISSING,
933
- 'type': schema_lib.MISSING,
934
- 'activities': schema_lib.MISSING,
935
- },
936
- )
937
-
938
- self.assert_missing(
939
- Itinerary.partial(
940
- day=1, type='daytime', activities=[Activity.partial()]
941
- ),
942
- {
943
- 'activities[0].description': schema_lib.MISSING,
944
- },
945
- )
946
-
947
- def test_mark_missing(self):
948
- class A(pg.Object):
949
- x: typing.Any
950
-
951
- self.assertEqual(schema_lib.mark_missing(1), 1)
952
- self.assertEqual(
953
- schema_lib.mark_missing(pg.MISSING_VALUE), pg.MISSING_VALUE
954
- )
955
- self.assertEqual(
956
- schema_lib.mark_missing(A.partial(A.partial(A.partial()))),
957
- A(A(A(schema_lib.MISSING))),
958
- )
959
- self.assertEqual(
960
- schema_lib.mark_missing(dict(a=A.partial())),
961
- dict(a=A(schema_lib.MISSING)),
962
- )
963
- self.assertEqual(
964
- schema_lib.mark_missing([1, dict(a=A.partial())]),
965
- [1, dict(a=A(schema_lib.MISSING))],
966
- )
967
-
968
-
969
- class UnknownTest(unittest.TestCase):
970
-
971
- def test_basics(self):
972
- class A(pg.Object):
973
- x: int
974
-
975
- a = A(x=schema_lib.Unknown())
976
- self.assertFalse(a.is_partial)
977
- self.assertEqual(a.x, schema_lib.UNKNOWN)
978
- self.assertEqual(schema_lib.UNKNOWN, schema_lib.Unknown())
979
-
980
-
981
- if __name__ == '__main__':
982
- unittest.main()