langfun 0.0.2.dev20240530__py3-none-any.whl → 0.0.2.dev20240601__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic. Click here for more details.
- langfun/__init__.py +4 -0
- langfun/core/__init__.py +1 -0
- langfun/core/llms/google_genai.py +66 -13
- langfun/core/llms/google_genai_test.py +1 -1
- langfun/core/llms/vertexai.py +67 -14
- langfun/core/llms/vertexai_test.py +1 -1
- langfun/core/modalities/__init__.py +1 -1
- langfun/core/modalities/audio.py +1 -1
- langfun/core/modalities/image.py +1 -1
- langfun/core/modalities/image_test.py +23 -6
- langfun/core/modalities/mime.py +105 -16
- langfun/core/modalities/mime_test.py +18 -3
- langfun/core/modalities/ms_office.py +38 -10
- langfun/core/modalities/ms_office_test.py +93 -16
- langfun/core/modalities/pdf.py +1 -1
- langfun/core/modalities/video.py +1 -1
- langfun/core/modality.py +4 -0
- langfun/core/structured/__init__.py +2 -0
- langfun/core/structured/completion.py +3 -1
- langfun/core/structured/mapping.py +1 -5
- langfun/core/structured/prompting.py +0 -4
- langfun/core/structured/schema.py +88 -42
- langfun/core/structured/schema_test.py +87 -34
- {langfun-0.0.2.dev20240530.dist-info → langfun-0.0.2.dev20240601.dist-info}/METADATA +4 -3
- {langfun-0.0.2.dev20240530.dist-info → langfun-0.0.2.dev20240601.dist-info}/RECORD +28 -28
- {langfun-0.0.2.dev20240530.dist-info → langfun-0.0.2.dev20240601.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240530.dist-info → langfun-0.0.2.dev20240601.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240530.dist-info → langfun-0.0.2.dev20240601.dist-info}/top_level.txt +0 -0
@@ -12,11 +12,13 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
"""Video tests."""
|
15
|
+
import base64
|
15
16
|
import io
|
16
17
|
import unittest
|
17
18
|
from unittest import mock
|
18
19
|
|
19
20
|
from langfun.core.modalities import ms_office as ms_office_lib
|
21
|
+
from langfun.core.modalities import pdf as pdf_lib
|
20
22
|
import pyglove as pg
|
21
23
|
|
22
24
|
|
@@ -243,23 +245,72 @@ def pptx_mock_request(*args, **kwargs):
|
|
243
245
|
return pg.Dict(content=pptx_bytes)
|
244
246
|
|
245
247
|
|
248
|
+
pdf_bytes = (
|
249
|
+
b'%PDF-1.1\n%\xc2\xa5\xc2\xb1\xc3\xab\n\n1 0 obj\n'
|
250
|
+
b'<< /Type /Catalog\n /Pages 2 0 R\n >>\nendobj\n\n2 0 obj\n '
|
251
|
+
b'<< /Type /Pages\n /Kids [3 0 R]\n '
|
252
|
+
b'/Count 1\n /MediaBox [0 0 300 144]\n '
|
253
|
+
b'>>\nendobj\n\n3 0 obj\n '
|
254
|
+
b'<< /Type /Page\n /Parent 2 0 R\n /Resources\n '
|
255
|
+
b'<< /Font\n'
|
256
|
+
b'<< /F1\n'
|
257
|
+
b'<< /Type /Font\n'
|
258
|
+
b'/Subtype /Type1\n'
|
259
|
+
b'/BaseFont /Times-Roman\n'
|
260
|
+
b'>>\n>>\n>>\n '
|
261
|
+
b'/Contents 4 0 R\n >>\nendobj\n\n4 0 obj\n '
|
262
|
+
b'<< /Length 55 >>\nstream\n BT\n /F1 18 Tf\n 0 0 Td\n '
|
263
|
+
b'(Hello World) Tj\n ET\nendstream\nendobj\n\nxref\n0 5\n0000000000 '
|
264
|
+
b'65535 f \n0000000018 00000 n \n0000000077 00000 n \n0000000178 00000 n '
|
265
|
+
b'\n0000000457 00000 n \ntrailer\n << /Root 1 0 R\n /Size 5\n '
|
266
|
+
b'>>\nstartxref\n565\n%%EOF\n'
|
267
|
+
)
|
268
|
+
|
269
|
+
|
270
|
+
def convert_mock_request(*args, **kwargs):
|
271
|
+
del args, kwargs
|
272
|
+
|
273
|
+
class Result:
|
274
|
+
def json(self):
|
275
|
+
return {
|
276
|
+
'Files': [
|
277
|
+
{
|
278
|
+
'FileData': base64.b64encode(pdf_bytes).decode()
|
279
|
+
}
|
280
|
+
]
|
281
|
+
}
|
282
|
+
return Result()
|
283
|
+
|
284
|
+
|
246
285
|
class DocxTest(unittest.TestCase):
|
247
286
|
|
248
|
-
def
|
287
|
+
def test_from_bytes(self):
|
249
288
|
content = ms_office_lib.Docx.from_bytes(docx_bytes)
|
250
|
-
self.
|
289
|
+
self.assertIn(
|
251
290
|
content.mime_type,
|
252
|
-
|
291
|
+
(
|
292
|
+
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
293
|
+
'application/octet-stream',
|
294
|
+
),
|
253
295
|
)
|
254
296
|
self.assertEqual(content.to_bytes(), docx_bytes)
|
297
|
+
self.assertTrue(content.is_compatible('text/plain'))
|
298
|
+
self.assertFalse(content.is_compatible('application/pdf'))
|
299
|
+
self.assertEqual(
|
300
|
+
content.make_compatible(['image/png', 'text/plain']).mime_type,
|
301
|
+
'text/plain'
|
302
|
+
)
|
255
303
|
|
256
|
-
def
|
304
|
+
def test_from_uri(self):
|
257
305
|
content = ms_office_lib.Docx.from_uri('http://mock/web/a.docx')
|
258
306
|
with mock.patch('requests.get') as mock_requests_get:
|
259
307
|
mock_requests_get.side_effect = docx_mock_request
|
260
|
-
self.
|
308
|
+
self.assertIn(
|
261
309
|
content.mime_type,
|
262
|
-
|
310
|
+
(
|
311
|
+
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
312
|
+
'application/octet-stream',
|
313
|
+
),
|
263
314
|
)
|
264
315
|
self.assertEqual(content.to_bytes(), docx_bytes)
|
265
316
|
self.assertEqual(content.to_xml(), expected_docx_xml)
|
@@ -267,21 +318,33 @@ class DocxTest(unittest.TestCase):
|
|
267
318
|
|
268
319
|
class XlsxTest(unittest.TestCase):
|
269
320
|
|
270
|
-
def
|
321
|
+
def test_from_bytes(self):
|
271
322
|
content = ms_office_lib.Xlsx.from_bytes(xlsx_bytes)
|
272
|
-
self.
|
323
|
+
self.assertIn(
|
273
324
|
content.mime_type,
|
274
|
-
|
325
|
+
(
|
326
|
+
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
327
|
+
'application/octet-stream',
|
328
|
+
),
|
275
329
|
)
|
276
330
|
self.assertEqual(content.to_bytes(), xlsx_bytes)
|
331
|
+
self.assertTrue(content.is_compatible('text/plain'))
|
332
|
+
self.assertFalse(content.is_compatible('application/pdf'))
|
333
|
+
self.assertEqual(
|
334
|
+
content.make_compatible('text/plain').mime_type,
|
335
|
+
'text/html'
|
336
|
+
)
|
277
337
|
|
278
|
-
def
|
338
|
+
def test_from_uri(self):
|
279
339
|
content = ms_office_lib.Xlsx.from_uri('http://mock/web/a.xlsx')
|
280
340
|
with mock.patch('requests.get') as mock_requests_get:
|
281
341
|
mock_requests_get.side_effect = xlsx_mock_request
|
282
|
-
self.
|
342
|
+
self.assertIn(
|
283
343
|
content.mime_type,
|
284
|
-
|
344
|
+
(
|
345
|
+
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
346
|
+
'application/octet-stream',
|
347
|
+
),
|
285
348
|
)
|
286
349
|
self.assertEqual(content.to_bytes(), xlsx_bytes)
|
287
350
|
self.assertEqual(content.to_html(), expected_xlsx_html)
|
@@ -291,22 +354,36 @@ class PptxTest(unittest.TestCase):
|
|
291
354
|
|
292
355
|
def test_content(self):
|
293
356
|
content = ms_office_lib.Pptx.from_bytes(pptx_bytes)
|
294
|
-
self.
|
357
|
+
self.assertIn(
|
295
358
|
content.mime_type,
|
296
|
-
|
359
|
+
(
|
360
|
+
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
361
|
+
'application/octet-stream',
|
362
|
+
),
|
297
363
|
)
|
298
364
|
self.assertEqual(content.to_bytes(), pptx_bytes)
|
299
365
|
|
300
366
|
def test_file(self):
|
301
367
|
content = ms_office_lib.Pptx.from_uri('http://mock/web/a.pptx')
|
368
|
+
self.assertFalse(content.is_compatible('text/plain'))
|
369
|
+
self.assertTrue(content.is_compatible('application/pdf'))
|
302
370
|
with mock.patch('requests.get') as mock_requests_get:
|
303
371
|
mock_requests_get.side_effect = pptx_mock_request
|
304
|
-
self.
|
372
|
+
self.assertIn(
|
305
373
|
content.mime_type,
|
306
|
-
|
374
|
+
(
|
375
|
+
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
376
|
+
'application/octet-stream',
|
377
|
+
),
|
307
378
|
)
|
308
379
|
self.assertEqual(content.to_bytes(), pptx_bytes)
|
309
380
|
|
381
|
+
with mock.patch('requests.post') as mock_requests_post:
|
382
|
+
mock_requests_post.side_effect = convert_mock_request
|
383
|
+
self.assertIsInstance(
|
384
|
+
content.make_compatible('application/pdf'), pdf_lib.PDF
|
385
|
+
)
|
386
|
+
|
310
387
|
|
311
388
|
if __name__ == '__main__':
|
312
389
|
unittest.main()
|
langfun/core/modalities/pdf.py
CHANGED
langfun/core/modalities/video.py
CHANGED
langfun/core/modality.py
CHANGED
@@ -108,3 +108,7 @@ class ModalityRef(pg.Object, pg.typing.CustomTyping):
|
|
108
108
|
return ModalityRef(name=value.sym_path + k)
|
109
109
|
return v
|
110
110
|
return value.clone().rebind(_placehold, raise_on_no_change=False)
|
111
|
+
|
112
|
+
|
113
|
+
class ModalityError(RuntimeError): # pylint: disable=g-bad-exception-name
|
114
|
+
"""Exception raised when modality is not supported."""
|
@@ -16,6 +16,8 @@
|
|
16
16
|
# pylint: disable=g-bad-import-order
|
17
17
|
# pylint: disable=g-importing-member
|
18
18
|
|
19
|
+
from langfun.core.structured.schema import include_method_in_prompt
|
20
|
+
|
19
21
|
from langfun.core.structured.schema import Missing
|
20
22
|
from langfun.core.structured.schema import MISSING
|
21
23
|
from langfun.core.structured.schema import Unknown
|
@@ -107,7 +107,9 @@ class CompleteStructure(mapping.Mapping):
|
|
107
107
|
|
108
108
|
def class_defs_repr(self, value: Any) -> str | None:
|
109
109
|
return schema_lib.class_definitions(
|
110
|
-
self.missing_type_dependencies(value),
|
110
|
+
self.missing_type_dependencies(value),
|
111
|
+
markdown=True,
|
112
|
+
allowed_dependencies=set()
|
111
113
|
)
|
112
114
|
|
113
115
|
def postprocess_result(self, result: Any) -> Any:
|
@@ -251,7 +251,7 @@ class Mapping(lf.LangFunc):
|
|
251
251
|
|
252
252
|
{%- if example.schema -%}
|
253
253
|
{{ schema_title }}:
|
254
|
-
{{ example.schema_repr(protocol
|
254
|
+
{{ example.schema_repr(protocol) | indent(2, True) }}
|
255
255
|
|
256
256
|
{% endif -%}
|
257
257
|
|
@@ -279,10 +279,6 @@ class Mapping(lf.LangFunc):
|
|
279
279
|
'The protocol for representing the schema and value.',
|
280
280
|
] = 'python'
|
281
281
|
|
282
|
-
include_methods: Annotated[
|
283
|
-
bool, 'If True, include method definitions in the schema.'
|
284
|
-
] = False
|
285
|
-
|
286
282
|
#
|
287
283
|
# Other user-provided flags.
|
288
284
|
#
|
@@ -114,7 +114,6 @@ def query(
|
|
114
114
|
autofix: int = 0,
|
115
115
|
autofix_lm: lf.LanguageModel | None = None,
|
116
116
|
protocol: schema_lib.SchemaProtocol = 'python',
|
117
|
-
include_methods: bool = False,
|
118
117
|
returns_message: bool = False,
|
119
118
|
skip_lm: bool = False,
|
120
119
|
**kwargs,
|
@@ -174,8 +173,6 @@ def query(
|
|
174
173
|
will use `lm`.
|
175
174
|
protocol: The protocol for schema/value representation. Applicable values
|
176
175
|
are 'json' and 'python'. By default `python` will be used.
|
177
|
-
include_methods: If True, include method definitions in the output type
|
178
|
-
during prompting.
|
179
176
|
returns_message: If True, returns `lf.Message` as the output, instead of
|
180
177
|
returning the structured `message.result`.
|
181
178
|
skip_lm: If True, returns the rendered prompt as a UserMessage object.
|
@@ -225,7 +222,6 @@ def query(
|
|
225
222
|
schema=schema,
|
226
223
|
default=default,
|
227
224
|
examples=examples,
|
228
|
-
include_methods=include_methods,
|
229
225
|
response_postprocess=response_postprocess,
|
230
226
|
autofix=autofix if protocol == 'python' else 0,
|
231
227
|
**kwargs,
|
@@ -16,6 +16,7 @@
|
|
16
16
|
import abc
|
17
17
|
import inspect
|
18
18
|
import io
|
19
|
+
import re
|
19
20
|
import textwrap
|
20
21
|
import typing
|
21
22
|
from typing import Any, Literal, Sequence, Type, Union
|
@@ -24,6 +25,17 @@ from langfun.core.coding.python import correction
|
|
24
25
|
import pyglove as pg
|
25
26
|
|
26
27
|
|
28
|
+
def include_method_in_prompt(method):
|
29
|
+
"""Decorator to include a method in the class definition of the prompt."""
|
30
|
+
setattr(method, '__show_in_prompt__', True)
|
31
|
+
return method
|
32
|
+
|
33
|
+
|
34
|
+
def should_include_method_in_prompt(method):
|
35
|
+
"""Returns true if the method should be shown in the prompt."""
|
36
|
+
return getattr(method, '__show_in_prompt__', False)
|
37
|
+
|
38
|
+
|
27
39
|
def parse_value_spec(value) -> pg.typing.ValueSpec:
|
28
40
|
"""Parses a PyGlove ValueSpec equivalence into a ValueSpec."""
|
29
41
|
if isinstance(value, pg.typing.ValueSpec):
|
@@ -163,9 +175,12 @@ class Schema(lf.NaturalLanguageFormattable, pg.Object):
|
|
163
175
|
|
164
176
|
def class_dependencies(
|
165
177
|
self,
|
178
|
+
include_base_classes: bool = True,
|
166
179
|
include_subclasses: bool = True) -> list[Type[Any]]:
|
167
180
|
"""Returns a list of class dependencies for current schema."""
|
168
|
-
return class_dependencies(
|
181
|
+
return class_dependencies(
|
182
|
+
self.spec, include_base_classes, include_subclasses
|
183
|
+
)
|
169
184
|
|
170
185
|
@classmethod
|
171
186
|
def from_value(cls, value) -> 'Schema':
|
@@ -198,11 +213,12 @@ def class_dependencies(
|
|
198
213
|
Type[pg.Object],
|
199
214
|
tuple[Union[pg.typing.ValueSpec, Type[pg.Object]], ...],
|
200
215
|
],
|
216
|
+
include_base_classes: bool = True,
|
201
217
|
include_subclasses: bool = True,
|
202
218
|
) -> list[Type[Any]]:
|
203
219
|
"""Returns a list of class dependencies from a value or specs."""
|
204
220
|
if isinstance(value_or_spec, Schema):
|
205
|
-
|
221
|
+
value_or_spec = value_or_spec.spec
|
206
222
|
|
207
223
|
if inspect.isclass(value_or_spec) or isinstance(
|
208
224
|
value_or_spec, pg.typing.ValueSpec
|
@@ -236,13 +252,14 @@ def class_dependencies(
|
|
236
252
|
if vs.cls not in seen:
|
237
253
|
seen.add(vs.cls)
|
238
254
|
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
255
|
+
if include_base_classes:
|
256
|
+
# Add base classes as dependencies.
|
257
|
+
for base_cls in vs.cls.__bases__:
|
258
|
+
# We only keep track of user-defined symbolic classes.
|
259
|
+
if base_cls is not object and base_cls is not pg.Object:
|
260
|
+
_fill_dependencies(
|
261
|
+
pg.typing.Object(base_cls), include_subclasses=False
|
262
|
+
)
|
246
263
|
|
247
264
|
# Add members as dependencies.
|
248
265
|
for field in _pg_schema(vs.cls).values():
|
@@ -301,7 +318,6 @@ class SchemaPythonRepr(SchemaRepr):
|
|
301
318
|
schema: Schema,
|
302
319
|
*,
|
303
320
|
include_result_definition: bool = True,
|
304
|
-
include_methods: bool = False,
|
305
321
|
markdown: bool = True,
|
306
322
|
**kwargs,
|
307
323
|
) -> str:
|
@@ -309,15 +325,27 @@ class SchemaPythonRepr(SchemaRepr):
|
|
309
325
|
if include_result_definition:
|
310
326
|
ret += self.result_definition(schema)
|
311
327
|
class_definition_str = self.class_definitions(
|
312
|
-
schema, markdown=markdown,
|
328
|
+
schema, markdown=markdown, **kwargs
|
313
329
|
)
|
314
330
|
if class_definition_str:
|
315
331
|
ret += f'\n\n{class_definition_str}'
|
316
332
|
return ret.strip()
|
317
333
|
|
318
|
-
def class_definitions(
|
319
|
-
|
320
|
-
|
334
|
+
def class_definitions(
|
335
|
+
self,
|
336
|
+
schema: Schema,
|
337
|
+
additional_dependencies: list[Type[Any]] | None = None,
|
338
|
+
**kwargs
|
339
|
+
) -> str | None:
|
340
|
+
"""Returns a string containing of class definitions from a schema."""
|
341
|
+
deps = schema.class_dependencies(
|
342
|
+
include_base_classes=False, include_subclasses=True
|
343
|
+
)
|
344
|
+
allowed_dependencies = set(deps)
|
345
|
+
if additional_dependencies:
|
346
|
+
allowed_dependencies.update(additional_dependencies)
|
347
|
+
return class_definitions(
|
348
|
+
deps, allowed_dependencies=allowed_dependencies, **kwargs)
|
321
349
|
|
322
350
|
def result_definition(self, schema: Schema) -> str:
|
323
351
|
return annotation(schema.spec)
|
@@ -331,8 +359,7 @@ def source_form(value, markdown: bool = False) -> str:
|
|
331
359
|
def class_definitions(
|
332
360
|
classes: Sequence[Type[Any]],
|
333
361
|
*,
|
334
|
-
|
335
|
-
include_methods: bool = False,
|
362
|
+
allowed_dependencies: set[Type[Any]] | None = None,
|
336
363
|
strict: bool = False,
|
337
364
|
markdown: bool = False,
|
338
365
|
) -> str | None:
|
@@ -347,8 +374,7 @@ def class_definitions(
|
|
347
374
|
class_definition(
|
348
375
|
cls,
|
349
376
|
strict=strict,
|
350
|
-
|
351
|
-
include_methods=include_methods,
|
377
|
+
allowed_dependencies=allowed_dependencies,
|
352
378
|
)
|
353
379
|
)
|
354
380
|
ret = def_str.getvalue()
|
@@ -360,8 +386,7 @@ def class_definitions(
|
|
360
386
|
def class_definition(
|
361
387
|
cls,
|
362
388
|
strict: bool = False,
|
363
|
-
|
364
|
-
include_methods: bool = False,
|
389
|
+
allowed_dependencies: set[Type[Any]] | None = None,
|
365
390
|
) -> str:
|
366
391
|
"""Returns the Python class definition."""
|
367
392
|
out = io.StringIO()
|
@@ -369,7 +394,7 @@ def class_definition(
|
|
369
394
|
eligible_bases = []
|
370
395
|
for base_cls in cls.__bases__:
|
371
396
|
if base_cls is not object:
|
372
|
-
if
|
397
|
+
if allowed_dependencies is None or base_cls in allowed_dependencies:
|
373
398
|
eligible_bases.append(base_cls.__name__)
|
374
399
|
|
375
400
|
if eligible_bases:
|
@@ -406,32 +431,41 @@ def class_definition(
|
|
406
431
|
out.write(' # ')
|
407
432
|
out.write(line)
|
408
433
|
out.write('\n')
|
409
|
-
out.write(f' {field.key}: {annotation(field.value, strict=strict)}')
|
410
|
-
out.write('\n')
|
411
|
-
empty_class = False
|
412
434
|
|
413
|
-
|
414
|
-
|
415
|
-
out.write('\n')
|
416
|
-
out.write(
|
417
|
-
textwrap.indent(
|
418
|
-
inspect.cleandoc('\n' + inspect.getsource(method)), ' ' * 2)
|
435
|
+
annotation_str = annotation(
|
436
|
+
field.value, strict=strict, allowed_dependencies=allowed_dependencies
|
419
437
|
)
|
438
|
+
out.write(f' {field.key}: {annotation_str}')
|
420
439
|
out.write('\n')
|
421
440
|
empty_class = False
|
422
441
|
|
442
|
+
for method in _iter_newly_defined_methods(cls, allowed_dependencies):
|
443
|
+
source = inspect.getsource(method)
|
444
|
+
# Remove decorators from the method definition.
|
445
|
+
source = re.sub(r'\s*@.*\.include_method_in_prompt.*\n', '', source)
|
446
|
+
out.write('\n')
|
447
|
+
out.write(
|
448
|
+
textwrap.indent(
|
449
|
+
inspect.cleandoc('\n' + source), ' ' * 2)
|
450
|
+
)
|
451
|
+
out.write('\n')
|
452
|
+
empty_class = False
|
453
|
+
|
423
454
|
if empty_class:
|
424
455
|
out.write(' pass\n')
|
425
456
|
return out.getvalue()
|
426
457
|
|
427
458
|
|
428
|
-
def _iter_newly_defined_methods(
|
429
|
-
|
459
|
+
def _iter_newly_defined_methods(
|
460
|
+
cls, allowed_dependencies: set[Type[Any]] | None):
|
461
|
+
names = {attr_name: True for attr_name in dir(cls)}
|
430
462
|
for base in cls.__bases__:
|
431
|
-
|
432
|
-
|
463
|
+
if allowed_dependencies is None or base in allowed_dependencies:
|
464
|
+
for name in dir(base):
|
465
|
+
names.pop(name, None)
|
466
|
+
for name in names.keys():
|
433
467
|
attr = getattr(cls, name)
|
434
|
-
if callable(attr):
|
468
|
+
if callable(attr) and should_include_method_in_prompt(attr):
|
435
469
|
yield attr
|
436
470
|
|
437
471
|
|
@@ -439,8 +473,12 @@ def annotation(
|
|
439
473
|
vs: pg.typing.ValueSpec,
|
440
474
|
annotate_optional: bool = True,
|
441
475
|
strict: bool = False,
|
476
|
+
allowed_dependencies: set[Type[Any]] | None = None,
|
442
477
|
) -> str:
|
443
478
|
"""Returns the annotation string for a value spec."""
|
479
|
+
child_annotation_kwargs = dict(
|
480
|
+
strict=strict, allowed_dependencies=allowed_dependencies
|
481
|
+
)
|
444
482
|
if isinstance(vs, pg.typing.Any):
|
445
483
|
return 'Any'
|
446
484
|
elif isinstance(vs, pg.typing.Enum):
|
@@ -449,7 +487,7 @@ def annotation(
|
|
449
487
|
elif isinstance(vs, pg.typing.Union):
|
450
488
|
candidate_str = ', '.join(
|
451
489
|
[
|
452
|
-
annotation(c, annotate_optional=False,
|
490
|
+
annotation(c, annotate_optional=False, **child_annotation_kwargs)
|
453
491
|
for c in vs.candidates
|
454
492
|
]
|
455
493
|
)
|
@@ -485,20 +523,23 @@ def annotation(
|
|
485
523
|
)
|
486
524
|
x += '(' + ', '.join(constraints) + ')'
|
487
525
|
elif isinstance(vs, pg.typing.Object):
|
488
|
-
|
526
|
+
if allowed_dependencies is None or vs.cls in allowed_dependencies:
|
527
|
+
x = vs.cls.__name__
|
528
|
+
else:
|
529
|
+
x = 'Any'
|
489
530
|
elif isinstance(vs, pg.typing.List):
|
490
|
-
item_str = annotation(vs.element.value,
|
531
|
+
item_str = annotation(vs.element.value, **child_annotation_kwargs)
|
491
532
|
x = f'list[{item_str}]'
|
492
533
|
elif isinstance(vs, pg.typing.Tuple):
|
493
534
|
elem_str = ', '.join(
|
494
|
-
[annotation(el.value,
|
535
|
+
[annotation(el.value, **child_annotation_kwargs) for el in vs.elements]
|
495
536
|
)
|
496
537
|
x = f'tuple[{elem_str}]'
|
497
538
|
elif isinstance(vs, pg.typing.Dict):
|
498
539
|
kv_pairs = None
|
499
540
|
if vs.schema is not None:
|
500
541
|
kv_pairs = [
|
501
|
-
(k, annotation(f.value,
|
542
|
+
(k, annotation(f.value, **child_annotation_kwargs))
|
502
543
|
for k, f in vs.schema.items()
|
503
544
|
if isinstance(k, pg.typing.ConstStrKey)
|
504
545
|
]
|
@@ -509,7 +550,7 @@ def annotation(
|
|
509
550
|
if strict:
|
510
551
|
x = f'pg.typing.Dict({x})'
|
511
552
|
elif vs.schema and vs.schema.dynamic_field:
|
512
|
-
v = annotation(vs.schema.dynamic_field.value,
|
553
|
+
v = annotation(vs.schema.dynamic_field.value, **child_annotation_kwargs)
|
513
554
|
x = f'dict[str, {v}]'
|
514
555
|
else:
|
515
556
|
x = 'dict[str, Any]'
|
@@ -604,7 +645,12 @@ class ValuePythonRepr(ValueRepr):
|
|
604
645
|
cls_schema = Schema.from_value(value)
|
605
646
|
if isinstance(cls_schema.spec, pg.typing.Object):
|
606
647
|
object_code = SchemaPythonRepr().class_definitions(
|
607
|
-
cls_schema,
|
648
|
+
cls_schema,
|
649
|
+
markdown=markdown,
|
650
|
+
# We add `pg.Object` as additional dependencies to the class
|
651
|
+
# definition so exemplars for class generation could show
|
652
|
+
# pg.Object as their bases.
|
653
|
+
additional_dependencies=[pg.Object]
|
608
654
|
)
|
609
655
|
assert object_code is not None
|
610
656
|
return object_code
|