linkml-1.8.1-py3-none-any.whl → linkml-1.8.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70):
  1. linkml/cli/__init__.py +0 -0
  2. linkml/cli/__main__.py +4 -0
  3. linkml/cli/main.py +126 -0
  4. linkml/generators/common/build.py +105 -0
  5. linkml/generators/common/lifecycle.py +124 -0
  6. linkml/generators/common/template.py +89 -0
  7. linkml/generators/csvgen.py +1 -1
  8. linkml/generators/docgen/slot.md.jinja2 +4 -0
  9. linkml/generators/docgen.py +1 -1
  10. linkml/generators/dotgen.py +1 -1
  11. linkml/generators/erdiagramgen.py +1 -1
  12. linkml/generators/excelgen.py +1 -1
  13. linkml/generators/golanggen.py +1 -1
  14. linkml/generators/golrgen.py +1 -1
  15. linkml/generators/graphqlgen.py +1 -1
  16. linkml/generators/javagen.py +1 -1
  17. linkml/generators/jsonldcontextgen.py +4 -4
  18. linkml/generators/jsonldgen.py +1 -1
  19. linkml/generators/jsonschemagen.py +69 -22
  20. linkml/generators/linkmlgen.py +1 -1
  21. linkml/generators/markdowngen.py +1 -1
  22. linkml/generators/namespacegen.py +1 -1
  23. linkml/generators/oocodegen.py +2 -1
  24. linkml/generators/owlgen.py +1 -1
  25. linkml/generators/plantumlgen.py +1 -1
  26. linkml/generators/prefixmapgen.py +1 -1
  27. linkml/generators/projectgen.py +1 -1
  28. linkml/generators/protogen.py +1 -1
  29. linkml/generators/pydanticgen/__init__.py +8 -3
  30. linkml/generators/pydanticgen/array.py +114 -194
  31. linkml/generators/pydanticgen/build.py +64 -25
  32. linkml/generators/pydanticgen/includes.py +1 -31
  33. linkml/generators/pydanticgen/pydanticgen.py +616 -274
  34. linkml/generators/pydanticgen/template.py +152 -184
  35. linkml/generators/pydanticgen/templates/attribute.py.jinja +9 -7
  36. linkml/generators/pydanticgen/templates/base_model.py.jinja +0 -13
  37. linkml/generators/pydanticgen/templates/class.py.jinja +2 -2
  38. linkml/generators/pydanticgen/templates/footer.py.jinja +2 -10
  39. linkml/generators/pydanticgen/templates/module.py.jinja +2 -2
  40. linkml/generators/pydanticgen/templates/validator.py.jinja +0 -4
  41. linkml/generators/pythongen.py +12 -2
  42. linkml/generators/rdfgen.py +1 -1
  43. linkml/generators/shaclgen.py +6 -2
  44. linkml/generators/shexgen.py +1 -1
  45. linkml/generators/sparqlgen.py +1 -1
  46. linkml/generators/sqlalchemygen.py +1 -1
  47. linkml/generators/sqltablegen.py +1 -1
  48. linkml/generators/sssomgen.py +1 -1
  49. linkml/generators/summarygen.py +1 -1
  50. linkml/generators/terminusdbgen.py +7 -4
  51. linkml/generators/typescriptgen.py +1 -1
  52. linkml/generators/yamlgen.py +1 -1
  53. linkml/generators/yumlgen.py +1 -1
  54. linkml/linter/cli.py +1 -1
  55. linkml/transformers/logical_model_transformer.py +117 -18
  56. linkml/utils/converter.py +1 -1
  57. linkml/utils/execute_tutorial.py +2 -0
  58. linkml/utils/logictools.py +142 -29
  59. linkml/utils/schema_builder.py +7 -6
  60. linkml/utils/schema_fixer.py +1 -1
  61. linkml/utils/sqlutils.py +1 -1
  62. linkml/validator/cli.py +4 -1
  63. linkml/validators/jsonschemavalidator.py +1 -1
  64. linkml/validators/sparqlvalidator.py +1 -1
  65. linkml/workspaces/example_runner.py +1 -1
  66. {linkml-1.8.1.dist-info → linkml-1.8.2.dist-info}/METADATA +2 -2
  67. {linkml-1.8.1.dist-info → linkml-1.8.2.dist-info}/RECORD +70 -64
  68. {linkml-1.8.1.dist-info → linkml-1.8.2.dist-info}/entry_points.txt +1 -1
  69. {linkml-1.8.1.dist-info → linkml-1.8.2.dist-info}/LICENSE +0 -0
  70. {linkml-1.8.1.dist-info → linkml-1.8.2.dist-info}/WHEEL +0 -0
@@ -1,20 +1,20 @@
1
1
  import inspect
2
2
  import logging
3
3
  import os
4
+ import re
4
5
  import textwrap
5
6
  from collections import defaultdict
6
- from copy import copy, deepcopy
7
7
  from dataclasses import dataclass, field
8
8
  from enum import Enum
9
9
  from pathlib import Path
10
10
  from types import ModuleType
11
- from typing import Dict, List, Literal, Optional, Set, Type, TypeVar, Union, overload
11
+ from typing import ClassVar, Dict, List, Literal, Optional, Set, Tuple, Type, TypeVar, Union, overload
12
12
 
13
13
  import click
14
- from jinja2 import ChoiceLoader, Environment, FileSystemLoader
14
+ from jinja2 import ChoiceLoader, Environment, FileSystemLoader, Template
15
15
  from linkml_runtime.linkml_model.meta import (
16
- Annotation,
17
16
  ClassDefinition,
17
+ ElementName,
18
18
  SchemaDefinition,
19
19
  SlotDefinition,
20
20
  TypeDefinition,
@@ -25,13 +25,13 @@ from linkml_runtime.utils.schemaview import SchemaView
25
25
  from pydantic.version import VERSION as PYDANTIC_VERSION
26
26
 
27
27
  from linkml._version import __version__
28
+ from linkml.generators.common.lifecycle import LifecycleMixin
28
29
  from linkml.generators.common.type_designators import get_accepted_type_designator_values, get_type_designator_value
29
30
  from linkml.generators.oocodegen import OOCodeGenerator
30
31
  from linkml.generators.pydanticgen import includes
31
32
  from linkml.generators.pydanticgen.array import ArrayRangeGenerator, ArrayRepresentation
32
- from linkml.generators.pydanticgen.build import SlotResult
33
+ from linkml.generators.pydanticgen.build import ClassResult, SlotResult, SplitResult
33
34
  from linkml.generators.pydanticgen.template import (
34
- ConditionalImport,
35
35
  Import,
36
36
  Imports,
37
37
  ObjectImport,
@@ -39,7 +39,7 @@ from linkml.generators.pydanticgen.template import (
39
39
  PydanticBaseModel,
40
40
  PydanticClass,
41
41
  PydanticModule,
42
- TemplateModel,
42
+ PydanticTemplateModel,
43
43
  )
44
44
  from linkml.utils import deprecation_warning
45
45
  from linkml.utils.generator import shared_arguments
@@ -84,9 +84,7 @@ DEFAULT_IMPORTS = (
84
84
  ObjectImport(name="Union"),
85
85
  ],
86
86
  )
87
- + Import(module="pydantic.version", objects=[ObjectImport(name="VERSION", alias="PYDANTIC_VERSION")])
88
- + ConditionalImport(
89
- condition="int(PYDANTIC_VERSION[0])>=2",
87
+ + Import(
90
88
  module="pydantic",
91
89
  objects=[
92
90
  ObjectImport(name="BaseModel"),
@@ -95,20 +93,16 @@ DEFAULT_IMPORTS = (
95
93
  ObjectImport(name="RootModel"),
96
94
  ObjectImport(name="field_validator"),
97
95
  ],
98
- alternative=Import(
99
- module="pydantic",
100
- objects=[ObjectImport(name="BaseModel"), ObjectImport(name="Field"), ObjectImport(name="validator")],
101
- ),
102
96
  )
103
97
  )
104
98
 
105
- DEFAULT_INJECTS = {1: [includes.LinkMLMeta_v1], 2: [includes.LinkMLMeta_v2]}
99
+ DEFAULT_INJECTS = [includes.LinkMLMeta]
106
100
 
107
101
 
108
102
  class MetadataMode(str, Enum):
109
103
  FULL = "full"
110
104
  """
111
- all metadata from the source schema will be included, even if it is represented by the template classes,
105
+ all metadata from the source schema will be included, even if it is represented by the template classes,
112
106
  and even if it is represented by some child class (eg. "classes" will be included with schema metadata
113
107
  """
114
108
  EXCEPT_CHILDREN = "except_children"
@@ -118,7 +112,7 @@ class MetadataMode(str, Enum):
118
112
  """
119
113
  AUTO = "auto"
120
114
  """
121
- Only the metadata that isn't represented by the template classes or excluded with ``meta_exclude`` will be included
115
+ Only the metadata that isn't represented by the template classes or excluded with ``meta_exclude`` will be included
122
116
  """
123
117
  NONE = None
124
118
  """
@@ -126,16 +120,55 @@ class MetadataMode(str, Enum):
126
120
  """
127
121
 
128
122
 
123
+ class SplitMode(str, Enum):
124
+ FULL = "full"
125
+ """
126
+ Import all classes defined in imported schemas
127
+ """
128
+
129
+ AUTO = "auto"
130
+ """
131
+ Only import those classes that are actually used in the generated schema as
132
+
133
+ * parents (``is_a``)
134
+ * mixins
135
+ * slot ranges
136
+ """
137
+
138
+
129
139
  DefinitionType = TypeVar("DefinitionType", bound=Union[SchemaDefinition, ClassDefinition, SlotDefinition])
130
140
  TemplateType = TypeVar("TemplateType", bound=Union[PydanticModule, PydanticClass, PydanticAttribute])
131
141
 
132
142
 
133
143
  @dataclass
134
- class PydanticGenerator(OOCodeGenerator):
144
+ class PydanticGenerator(OOCodeGenerator, LifecycleMixin):
135
145
  """
136
146
  Generates Pydantic-compliant classes from a schema
137
147
 
138
148
  This is an alternative to the dataclasses-based Pythongen
149
+
150
+ Lifecycle methods (see :class:`.LifecycleMixin` ) supported:
151
+
152
+ * :meth:`~.LifecycleMixin.before_generate_enums`
153
+
154
+ Slot generation is nested within class generation, since the pydantic generator currently doesn't
155
+ create an independent representation of slots aside from their materialization as class fields.
156
+ Accordingly, the ``before_`` and ``after_generate_slots`` are called before and after each class's
157
+ slot generation, rather than all slot generation.
158
+
159
+ * :meth:`~.LifecycleMixin.before_generate_classes`
160
+ * :meth:`~.LifecycleMixin.before_generate_class`
161
+ * :meth:`~.LifecycleMixin.after_generate_class`
162
+ * :meth:`~.LifecycleMixin.after_generate_classes`
163
+
164
+ * :meth:`~.LifecycleMixin.before_generate_slots`
165
+ * :meth:`~.LifecycleMixin.before_generate_slot`
166
+ * :meth:`~.LifecycleMixin.after_generate_slot`
167
+ * :meth:`~.LifecycleMixin.after_generate_slots`
168
+
169
+ * :meth:`~.LifecycleMixin.before_render_template`
170
+ * :meth:`~.LifecycleMixin.after_render_template`
171
+
139
172
  """
140
173
 
141
174
  # ClassVar overrides
@@ -150,12 +183,12 @@ class PydanticGenerator(OOCodeGenerator):
150
183
  """
151
184
  If black is present in the environment, format the serialized code with it
152
185
  """
153
- pydantic_version: int = int(PYDANTIC_VERSION[0])
186
+
154
187
  template_dir: Optional[Union[str, Path]] = None
155
188
  """
156
- Override templates for each TemplateModel.
189
+ Override templates for each PydanticTemplateModel.
157
190
 
158
- Directory with templates that override the default :attr:`.TemplateModel.template`
191
+ Directory with templates that override the default :attr:`.PydanticTemplateModel.template`
159
192
  for each class. If a matching template is not found in the override directory,
160
193
  the default templates will be used.
161
194
  """
@@ -164,62 +197,62 @@ class PydanticGenerator(OOCodeGenerator):
164
197
  injected_classes: Optional[List[Union[Type, str]]] = None
165
198
  """
166
199
  A list/tuple of classes to inject into the generated module.
167
-
200
+
168
201
  Accepts either live classes or strings. Live classes will have their source code
169
202
  extracted with inspect.get - so they need to be standard python classes declared in a
170
- source file (ie. the module they are contained in needs a ``__file__`` attr,
203
+ source file (ie. the module they are contained in needs a ``__file__`` attr,
171
204
  see: :func:`inspect.getsource` )
172
205
  """
173
206
  injected_fields: Optional[List[str]] = None
174
207
  """
175
208
  A list/tuple of field strings to inject into the base class.
176
-
209
+
177
210
  Examples:
178
-
211
+
179
212
  .. code-block:: python
180
213
 
181
214
  injected_fields = (
182
215
  'object_id: Optional[str] = Field(None, description="Unique UUID for each object")',
183
216
  )
184
-
217
+
185
218
  """
186
219
  imports: Optional[List[Import]] = None
187
220
  """
188
- Additional imports to inject into generated module.
189
-
221
+ Additional imports to inject into generated module.
222
+
190
223
  Examples:
191
-
224
+
192
225
  .. code-block:: python
193
-
226
+
194
227
  from linkml.generators.pydanticgen.template import (
195
228
  ConditionalImport,
196
229
  ObjectImport,
197
230
  Import,
198
231
  Imports
199
232
  )
200
-
201
- imports = (Imports() +
202
- Import(module='sys') +
203
- Import(module='numpy', alias='np') +
233
+
234
+ imports = (Imports() +
235
+ Import(module='sys') +
236
+ Import(module='numpy', alias='np') +
204
237
  Import(module='pathlib', objects=[
205
238
  ObjectImport(name="Path"),
206
239
  ObjectImport(name="PurePath", alias="RenamedPurePath")
207
- ]) +
240
+ ]) +
208
241
  ConditionalImport(
209
242
  module="typing",
210
243
  objects=[ObjectImport(name="Literal")],
211
244
  condition="sys.version_info >= (3, 8)",
212
245
  alternative=Import(
213
- module="typing_extensions",
246
+ module="typing_extensions",
214
247
  objects=[ObjectImport(name="Literal")]
215
248
  ),
216
249
  ).imports
217
250
  )
218
-
251
+
219
252
  becomes:
220
-
253
+
221
254
  .. code-block:: python
222
-
255
+
223
256
  import sys
224
257
  import numpy as np
225
258
  from pathlib import (
@@ -230,14 +263,72 @@ class PydanticGenerator(OOCodeGenerator):
230
263
  from typing import Literal
231
264
  else:
232
265
  from typing_extensions import Literal
233
-
266
+
234
267
  """
235
268
  metadata_mode: Union[MetadataMode, str, None] = MetadataMode.AUTO
236
269
  """
237
270
  How to include schema metadata in generated pydantic models.
238
-
271
+
239
272
  See :class:`.MetadataMode` for mode documentation
240
273
  """
274
+ split: bool = False
275
+ """
276
+ Generate schema that import other schema as separate python modules
277
+ that import from one another, rather than rolling all into a single
278
+ module (default, ``False``).
279
+ """
280
+ split_pattern: str = ".{{ schema.name }}"
281
+ """
282
+ When splitting generation, imported modules need to be generated separately
283
+ and placed in a python package and import from each other. Since the
284
+ location of those imported modules is variable -- e.g. one might want to
285
+ generate schema in multiple packages depending on their version -- this
286
+ pattern is used to generate the module portion of the import statement.
287
+
288
+ These patterns should generally yield a relative module import,
289
+ since functions like :func:`.generate_split` will generate and write files
290
+ relative to some base file, though this is not a requirement since custom
291
+ split generation logic is also allowed.
292
+
293
+ The pattern is a jinja template string that is given the ``SchemaDefinition``
294
+ of the imported schema in the environment. Additional variables can be passed
295
+ into the jinja environment with the :attr:`.split_context` argument.
296
+
297
+ Further modification is possible by using jinja filters.
298
+
299
+ After templating, the string is passed through a :attr:`SNAKE_CASE` pattern
300
+ to replace whitespace and other characters that can't be used in module names.
301
+
302
+ See also :meth:`.generate_module_import`, which is used to generate the
303
+ module portion of the import statement (and can be overridden in subclasses).
304
+
305
+ Examples:
306
+
307
+ for a schema named ``ExampleSchema`` and version ``1.2.3`` ...
308
+
309
+ ``".{{ schema.name }}"`` (the default) becomes
310
+
311
+ ``from .example_schema import ClassA, ...``
312
+
313
+ ``"...{{ schema.name }}.v{{ schema.version | replace('.', '_') }}"`` becomes
314
+
315
+ ``from ...example_schema.v1_2_3 import ClassA, ...``
316
+
317
+ """
318
+ split_context: Optional[dict] = None
319
+ """
320
+ Additional variables to pass into ``split_pattern`` when
321
+ generating imported module names.
322
+
323
+ Passed in as ``**kwargs`` , so e.g. if ``split_context = {'myval': 1}``
324
+ then one would use it in a template string like ``{{ myval }}``
325
+ """
326
+ split_mode: SplitMode = SplitMode.AUTO
327
+ """
328
+ How to filter imports from imported schema.
329
+
330
+ See :class:`.SplitMode` for description of options
331
+ """
241
332
 
242
333
  # ObjectVars (identical to pythongen)
243
334
  gen_classvars: bool = True
@@ -245,10 +336,16 @@ class PydanticGenerator(OOCodeGenerator):
245
336
  genmeta: bool = False
246
337
  emit_metadata: bool = True
247
338
 
339
+ # ClassVars
340
+ SNAKE_CASE: ClassVar[str] = r"(((?<!^)(?<!\.))(?=[A-Z][a-z]))|([^\w\.]+)"
341
+ """Substitute CamelCase and non-word characters with _"""
342
+
343
+ # Private attributes
344
+ _predefined_slot_values: Optional[Dict[str, Dict[str, str]]] = None
345
+ _class_bases: Optional[Dict[str, List[str]]] = None
346
+
248
347
  def __post_init__(self):
249
348
  super().__post_init__()
250
- if int(self.pydantic_version) == 1:
251
- deprecation_warning("pydanticgen-v1")
252
349
 
253
350
  def compile_module(self, **kwargs) -> ModuleType:
254
351
  """
@@ -263,8 +360,20 @@ class PydanticGenerator(OOCodeGenerator):
263
360
  logging.error(f"Error compiling generated python code: {e}")
264
361
  raise e
265
362
 
363
+ def _get_classes(self, sv: SchemaView) -> Tuple[List[ClassDefinition], Optional[List[ClassDefinition]]]:
364
+ all_classes = sv.all_classes(imports=True).values()
365
+
366
+ if self.split:
367
+ local_classes = sv.all_classes(imports=False).values()
368
+ imported_classes = [c for c in all_classes if c not in local_classes]
369
+ return list(local_classes), imported_classes
370
+ else:
371
+ return list(all_classes), None
372
+
266
373
  @staticmethod
267
- def sort_classes(clist: List[ClassDefinition]) -> List[ClassDefinition]:
374
+ def sort_classes(
375
+ clist: List[ClassDefinition], imported: Optional[List[ClassDefinition]] = None
376
+ ) -> List[ClassDefinition]:
268
377
  """
269
378
  sort classes such that if C is a child of P then C appears after P in the list
270
379
 
@@ -272,6 +381,9 @@ class PydanticGenerator(OOCodeGenerator):
272
381
 
273
382
  TODO: This should move to SchemaView
274
383
  """
384
+ if imported is not None:
385
+ imported = [i.name for i in imported]
386
+
275
387
  clist = list(clist)
276
388
  slist = [] # sorted
277
389
  while len(clist) > 0:
@@ -283,6 +395,11 @@ class PydanticGenerator(OOCodeGenerator):
283
395
  candidates = [candidate.is_a] + candidate.mixins
284
396
  else:
285
397
  candidates = candidate.mixins
398
+
399
+ # remove blocking classes imported from other schemas if in split mode
400
+ if imported:
401
+ candidates = [c for c in candidates if c not in imported]
402
+
286
403
  if not candidates:
287
404
  can_add = True
288
405
  else:
@@ -296,82 +413,165 @@ class PydanticGenerator(OOCodeGenerator):
296
413
  raise ValueError(f"could not find suitable element in {clist} that does not ref {slist}")
297
414
  return slist
298
415
 
299
- def get_predefined_slot_values(self) -> Dict[str, Dict[str, str]]:
416
+ def generate_class(self, cls: ClassDefinition) -> ClassResult:
417
+ pyclass = PydanticClass(
418
+ name=camelcase(cls.name),
419
+ bases=self.class_bases.get(camelcase(cls.name), PydanticBaseModel.default_name),
420
+ description=cls.description.replace('"', '\\"') if cls.description is not None else None,
421
+ )
422
+
423
+ imports = self._get_imports(cls) if self.split else None
424
+
425
+ result = ClassResult(cls=pyclass, source=cls, imports=imports)
426
+
427
+ # Gather slots
428
+ slots = [self.schemaview.induced_slot(sn, cls.name) for sn in self.schemaview.class_slots(cls.name)]
429
+ slots = self.before_generate_slots(slots, self.schemaview)
430
+
431
+ slot_results = []
432
+ for slot in slots:
433
+ slot = self.before_generate_slot(slot, self.schemaview)
434
+ slot = self.generate_slot(slot, cls)
435
+ slot = self.after_generate_slot(slot, self.schemaview)
436
+ slot_results.append(slot)
437
+ result = result.merge(slot)
438
+
439
+ slot_results = self.after_generate_slots(slot_results, self.schemaview)
440
+ attributes = {slot.attribute.name: slot.attribute for slot in slot_results}
441
+
442
+ result.cls.attributes = attributes
443
+ result.cls = self.include_metadata(result.cls, cls)
444
+
445
+ return result
446
+
447
+ def generate_slot(self, slot: SlotDefinition, cls: ClassDefinition) -> SlotResult:
448
+ slot_args = {
449
+ k: slot._as_dict.get(k, None)
450
+ for k in PydanticAttribute.model_fields.keys()
451
+ if slot._as_dict.get(k, None) is not None
452
+ }
453
+ slot_args["name"] = underscore(slot.name)
454
+ slot_args["description"] = slot.description.replace('"', '\\"') if slot.description is not None else None
455
+ predef = self.predefined_slot_values.get(camelcase(cls.name), {}).get(slot.name, None)
456
+ if predef is not None:
457
+ slot_args["predefined"] = str(predef)
458
+
459
+ pyslot = PydanticAttribute(**slot_args)
460
+ pyslot = self.include_metadata(pyslot, slot)
461
+
462
+ slot_ranges = []
463
+ # Confirm that the original slot range (ignoring the default that comes in from
464
+ # induced_slot) isn't in addition to setting any_of
465
+ any_of_ranges = [a.range if a.range else slot.range for a in slot.any_of]
466
+ if any_of_ranges:
467
+ # list comprehension here is pulling ranges from within AnonymousSlotExpression
468
+ slot_ranges.extend(any_of_ranges)
469
+ else:
470
+ slot_ranges.append(slot.range)
471
+
472
+ pyranges = [self.generate_python_range(slot_range, slot, cls) for slot_range in slot_ranges]
473
+
474
+ pyranges = list(set(pyranges)) # remove duplicates
475
+ pyranges.sort()
476
+
477
+ if len(pyranges) == 1:
478
+ pyrange = pyranges[0]
479
+ elif len(pyranges) > 1:
480
+ pyrange = f"Union[{', '.join(pyranges)}]"
481
+ else:
482
+ raise Exception(f"Could not generate python range for {cls.name}.{slot.name}")
483
+
484
+ pyslot.range = pyrange
485
+
486
+ imports = self._get_imports(slot) if self.split else None
487
+
488
+ result = SlotResult(attribute=pyslot, source=slot, imports=imports)
489
+
490
+ if slot.array is not None:
491
+ results = self.get_array_representations_range(slot, result.attribute.range)
492
+ if len(results) == 1:
493
+ result.attribute.range = results[0].range
494
+ else:
495
+ result.attribute.range = f"Union[{', '.join([res.range for res in results])}]"
496
+ for res in results:
497
+ result = result.merge(res)
498
+
499
+ elif slot.multivalued:
500
+ if slot.inlined or slot.inlined_as_list:
501
+ collection_key = self.generate_collection_key(slot_ranges, slot, cls)
502
+ else:
503
+ collection_key = None
504
+ if slot.inlined is False or collection_key is None or slot.inlined_as_list is True:
505
+ result.attribute.range = f"List[{result.attribute.range}]"
506
+ else:
507
+ simple_dict_value = None
508
+ if len(slot_ranges) == 1:
509
+ simple_dict_value = self._inline_as_simple_dict_with_value(slot)
510
+ if simple_dict_value:
511
+ # simple_dict_value might be the range of the identifier of a class when range is a class,
512
+ # so we specify either that identifier or the range itself
513
+ if simple_dict_value != result.attribute.range:
514
+ simple_dict_value = f"Union[{simple_dict_value}, {result.attribute.range}]"
515
+ result.attribute.range = f"Dict[str, {simple_dict_value}]"
516
+ else:
517
+ result.attribute.range = f"Dict[{collection_key}, {result.attribute.range}]"
518
+ if not (slot.required or slot.identifier or slot.key) and not slot.designates_type:
519
+ result.attribute.range = f"Optional[{result.attribute.range}]"
520
+ return result
521
+
522
+ @property
523
+ def predefined_slot_values(self) -> Dict[str, Dict[str, str]]:
300
524
  """
301
525
  :return: Dictionary of dictionaries with predefined slot values for each class
302
526
  """
303
- sv = self.schemaview
304
- slot_values = defaultdict(dict)
305
- for class_def in sv.all_classes().values():
306
- for slot_name in sv.class_slots(class_def.name):
307
- slot = sv.induced_slot(slot_name, class_def.name)
308
- if slot.designates_type:
309
- target_value = get_type_designator_value(sv, slot, class_def)
310
- slot_values[camelcase(class_def.name)][slot.name] = f'"{target_value}"'
311
- if slot.multivalued:
312
- slot_values[camelcase(class_def.name)][slot.name] = (
313
- "[" + slot_values[camelcase(class_def.name)][slot.name] + "]"
314
- )
315
- slot_values[camelcase(class_def.name)][slot.name] = slot_values[camelcase(class_def.name)][
316
- slot.name
317
- ]
318
- elif slot.ifabsent is not None:
319
- value = ifabsent_value_declaration(slot.ifabsent, sv, class_def, slot)
320
- slot_values[camelcase(class_def.name)][slot.name] = value
321
- # Multivalued slots that are either not inlined (just an identifier) or are
322
- # inlined as lists should get default_factory list, if they're inlined but
323
- # not as a list, that means a dictionary
324
- elif "linkml:elements" in slot.implements:
325
- slot_values[camelcase(class_def.name)][slot.name] = None
326
- elif slot.multivalued:
327
- has_identifier_slot = self.range_class_has_identifier_slot(slot)
328
-
329
- if slot.inlined and not slot.inlined_as_list and has_identifier_slot:
330
- slot_values[camelcase(class_def.name)][slot.name] = "default_factory=dict"
331
- else:
332
- slot_values[camelcase(class_def.name)][slot.name] = "default_factory=list"
333
-
334
- return slot_values
335
-
336
- def range_class_has_identifier_slot(self, slot):
337
- """
338
- Check if the range class of a slot has an identifier slot, via both slot.any_of and slot.range
339
- Should return False if the range is not a class, and also if the range is a class but has no
340
- identifier slot
341
-
342
- :param slot: SlotDefinition
343
- :return: bool
344
- """
345
- sv = self.schemaview
346
- has_identifier_slot = False
347
- if slot.any_of:
348
- for slot_range in slot.any_of:
349
- any_of_range = slot_range.range
350
- if any_of_range in sv.all_classes() and sv.get_identifier_slot(any_of_range, use_key=True) is not None:
351
- has_identifier_slot = True
352
- if slot.range in sv.all_classes() and sv.get_identifier_slot(slot.range, use_key=True) is not None:
353
- has_identifier_slot = True
354
- return has_identifier_slot
355
-
356
- def get_class_isa_plus_mixins(self) -> Dict[str, List[str]]:
527
+ if self._predefined_slot_values is None:
528
+ sv = self.schemaview
529
+ slot_values = defaultdict(dict)
530
+ for class_def in sv.all_classes().values():
531
+ for slot_name in sv.class_slots(class_def.name):
532
+ slot = sv.induced_slot(slot_name, class_def.name)
533
+ if slot.designates_type:
534
+ target_value = get_type_designator_value(sv, slot, class_def)
535
+ slot_values[camelcase(class_def.name)][slot.name] = f'"{target_value}"'
536
+ if slot.multivalued:
537
+ slot_values[camelcase(class_def.name)][slot.name] = (
538
+ "[" + slot_values[camelcase(class_def.name)][slot.name] + "]"
539
+ )
540
+ slot_values[camelcase(class_def.name)][slot.name] = slot_values[camelcase(class_def.name)][
541
+ slot.name
542
+ ]
543
+ elif slot.ifabsent is not None:
544
+ value = ifabsent_value_declaration(slot.ifabsent, sv, class_def, slot)
545
+ slot_values[camelcase(class_def.name)][slot.name] = value
546
+
547
+ self._predefined_slot_values = slot_values
548
+
549
+ return self._predefined_slot_values
550
+
551
+ @property
552
+ def class_bases(self) -> Dict[str, List[str]]:
357
553
  """
358
554
  Generate the inheritance list for each class from is_a plus mixins
359
555
  :return:
360
556
  """
361
- sv = self.schemaview
362
- parents = {}
363
- for class_def in sv.all_classes().values():
364
- class_parents = []
365
- if class_def.is_a:
366
- class_parents.append(camelcase(class_def.is_a))
367
- if self.gen_mixin_inheritance and class_def.mixins:
368
- class_parents.extend([camelcase(mixin) for mixin in class_def.mixins])
369
- if len(class_parents) > 0:
370
- # Use the sorted list of classes to order the parent classes, but reversed to match MRO needs
371
- class_parents.sort(key=lambda x: self.sorted_class_names.index(x))
372
- class_parents.reverse()
373
- parents[camelcase(class_def.name)] = class_parents
374
- return parents
557
+ if self._class_bases is None:
558
+ sv = self.schemaview
559
+ parents = {}
560
+ for class_def in sv.all_classes().values():
561
+ class_parents = []
562
+ if class_def.is_a:
563
+ class_parents.append(camelcase(class_def.is_a))
564
+ if self.gen_mixin_inheritance and class_def.mixins:
565
+ class_parents.extend([camelcase(mixin) for mixin in class_def.mixins])
566
+ if len(class_parents) > 0:
567
+ # Use the sorted list of classes to order the parent classes, but reversed to match MRO needs
568
+ class_parents.sort(
569
+ key=lambda x: self.sorted_class_names.index(x) if x in self.sorted_class_names else -1
570
+ )
571
+ class_parents.reverse()
572
+ parents[camelcase(class_def.name)] = class_parents
573
+ self._class_bases = parents
574
+ return self._class_bases
375
575
 
376
576
  def get_mixin_identifier_range(self, mixin) -> str:
377
577
  sv = self.schemaview
@@ -442,6 +642,10 @@ class PydanticGenerator(OOCodeGenerator):
442
642
  + ",".join(['"' + x + '"' for x in get_accepted_type_designator_values(sv, slot_def, class_def)])
443
643
  + "]"
444
644
  )
645
+ elif slot_def.equals_string:
646
+ pyrange = f'Literal["{slot_def.equals_string}"]'
647
+ elif slot_def.equals_string_in:
648
+ pyrange = "Literal[" + ", ".join([f'"{a_string}"' for a_string in slot_def.equals_string_in]) + "]"
445
649
  elif slot_range in sv.all_classes():
446
650
  pyrange = self.get_class_slot_range(
447
651
  slot_range,
@@ -497,8 +701,18 @@ class PydanticGenerator(OOCodeGenerator):
497
701
  return list(collection_keys)[0]
498
702
  return None
499
703
 
500
- @staticmethod
501
- def _inline_as_simple_dict_with_value(slot_def: SlotDefinition, sv: SchemaView) -> Optional[str]:
704
+ def _clean_injected_classes(self, injected_classes: List[Union[str, Type]]) -> Optional[List[str]]:
705
+ """Get source, deduplicate, and dedent injected classes"""
706
+ if len(injected_classes) == 0:
707
+ return None
708
+
709
+ injected_classes = list(
710
+ dict.fromkeys([c if isinstance(c, str) else inspect.getsource(c) for c in injected_classes])
711
+ )
712
+ injected_classes = [textwrap.dedent(c) for c in injected_classes]
713
+ return injected_classes
714
+
715
+ def _inline_as_simple_dict_with_value(self, slot_def: SlotDefinition) -> Optional[str]:
502
716
  """
503
717
  Determine if a slot should be inlined as a simple dict with a value.
504
718
 
@@ -521,21 +735,21 @@ class PydanticGenerator(OOCodeGenerator):
521
735
  :return: str
522
736
  """
523
737
  if slot_def.inlined and not slot_def.inlined_as_list:
524
- if slot_def.range in sv.all_classes():
525
- id_slot = sv.get_identifier_slot(slot_def.range, use_key=True)
738
+ if slot_def.range in self.schemaview.all_classes():
739
+ id_slot = self.schemaview.get_identifier_slot(slot_def.range, use_key=True)
526
740
  if id_slot is not None:
527
- range_cls_slots = sv.class_induced_slots(slot_def.range)
741
+ range_cls_slots = self.schemaview.class_induced_slots(slot_def.range)
528
742
  if len(range_cls_slots) == 2:
529
743
  non_id_slots = [slot for slot in range_cls_slots if slot.name != id_slot.name]
530
744
  if len(non_id_slots) == 1:
531
745
  value_slot = non_id_slots[0]
532
- value_slot_range_type = sv.get_type(value_slot.range)
746
+ value_slot_range_type = self.schemaview.get_type(value_slot.range)
533
747
  if value_slot_range_type is not None:
534
- return _get_pyrange(value_slot_range_type, sv)
748
+ return _get_pyrange(value_slot_range_type, self.schemaview)
535
749
  return None
536
750
 
537
751
  def _template_environment(self) -> Environment:
538
- env = TemplateModel.environment()
752
+ env = PydanticTemplateModel.environment()
539
753
  if self.template_dir is not None:
540
754
  loader = ChoiceLoader([FileSystemLoader(self.template_dir), env.loader])
541
755
  env.loader = loader
@@ -548,7 +762,7 @@ class PydanticGenerator(OOCodeGenerator):
548
762
  array_reps = []
549
763
  for repr in self.array_representations:
550
764
  generator = ArrayRangeGenerator.get_generator(repr)
551
- result = generator(slot.array, range, self.pydantic_version).make()
765
+ result = generator(slot.array, range).make()
552
766
  array_reps.append(result)
553
767
 
554
768
  if len(array_reps) == 0:
@@ -572,7 +786,7 @@ class PydanticGenerator(OOCodeGenerator):
572
786
  Metadata inclusion mode is dependent on :attr:`.metadata_mode` - see:
573
787
 
574
788
  - :class:`.MetadataMode`
575
- - :meth:`.TemplateModel.exclude_from_meta`
789
+ - :meth:`.PydanticTemplateModel.exclude_from_meta`
576
790
 
577
791
  """
578
792
  if self.metadata_mode is None or self.metadata_mode == MetadataMode.NONE:
@@ -591,13 +805,15 @@ class PydanticGenerator(OOCodeGenerator):
591
805
  continue
592
806
 
593
807
  model_attr = getattr(model, k)
594
- if isinstance(model_attr, list) and not any([isinstance(item, TemplateModel) for item in model_attr]):
808
+ if isinstance(model_attr, list) and not any(
809
+ [isinstance(item, PydanticTemplateModel) for item in model_attr]
810
+ ):
595
811
  meta[k] = v
596
812
  elif isinstance(model_attr, dict) and not any(
597
- [isinstance(item, TemplateModel) for item in model_attr.values()]
813
+ [isinstance(item, PydanticTemplateModel) for item in model_attr.values()]
598
814
  ):
599
815
  meta[k] = v
600
- elif not isinstance(model_attr, (list, dict, TemplateModel)):
816
+ elif not isinstance(model_attr, (list, dict, PydanticTemplateModel)):
601
817
  meta[k] = v
602
818
 
603
819
  elif self.metadata_mode in (MetadataMode.FULL, MetadataMode.FULL.value):
@@ -611,155 +827,145 @@ class PydanticGenerator(OOCodeGenerator):
611
827
  model.meta = meta
612
828
  return model
613
829
 
830
+ def _get_imports(self, element: Union[ClassDefinition, SlotDefinition, None] = None) -> Imports:
831
+ """
832
+ Get imports that are implied by their usage in slots or classes
833
+ (and thus need to be imported when generating schemas in :attr:`.split` == ``True`` mode).
834
+
835
+ **Note:**
836
+ Since in pydantic (currently) the only things that are materialized are classes, we don't
837
+ import class slots from imported schemas and abandon slots, directly expressing them
838
+ in the model.
839
+
840
+ This is a parent placeholder method in case that changes, "give me something and return
841
+ a set of imports" that calls subordinate methods. If slots become materialized, keep
842
+ this as the directly called method rather than spaghetti-ing out another
843
+ independent method. This method is also isolated in anticipation of structured imports,
844
+ where we will need to revise our expectations of what is imported when.
845
+
846
+ Args:
847
+ element (:class:`.ClassDefinition` , :class:`.SlotDefinition` , None): The element
848
+ to get import for. If ``None`` , get all needed imports (see :attr:`.split_mode`
849
+ """
850
+ # import from local references, rather than serializing every class in every file
851
+ if not self.split or (self.split_mode == SplitMode.FULL and element is not None):
852
+ # we are either compiling this whole thing in one big file (default)
853
+ # or going to import all classes from the imported schemas,
854
+ # so we don't import anything
855
+ return Imports()
856
+
857
+ # gather a list of class names,
858
+ # remove local classes and transform to Imports later.
859
+ needed_classes = []
860
+
861
+ # fine to call rather than pass bc it's cached
862
+ all_classes = self.schemaview.all_classes(imports=True)
863
+ local_classes = self.schemaview.all_classes(imports=False)
864
+
865
+ if isinstance(element, ClassDefinition):
866
+ if element.is_a:
867
+ needed_classes.append(element.is_a)
868
+ if element.mixins:
869
+ needed_classes.extend(element.mixins)
870
+
871
+ elif isinstance(element, SlotDefinition):
872
+ # collapses `slot.range`, `slot.any_of`, and `slot.one_of` to a list
873
+ slot_ranges = self.schemaview.slot_range_as_union(element)
874
+ needed_classes.extend([a_range for a_range in slot_ranges if a_range in all_classes])
875
+
876
+ elif element is None:
877
+ # get all imports
878
+ needed_classes.extend([cls for cls in all_classes if cls not in local_classes])
879
+
880
+ else:
881
+ raise ValueError(f"Unsupported type of element to get imports from: f{type(element)}")
882
+
883
+ # SPECIAL CASE: classes that are not generated for structural reasons.
884
+ # TODO: Do we want to have a general means of skipping class generation?
885
+ skips = ("AnyType",)
886
+
887
+ class_imports = [
888
+ self._get_element_import(cls) for cls in needed_classes if (cls not in local_classes and cls not in skips)
889
+ ]
890
+ imports = Imports(imports=class_imports)
891
+
892
+ return imports
893
+
894
+ def generate_module_import(self, schema: SchemaDefinition, context: Optional[dict] = None) -> str:
895
+ """
896
+ Generate the module string for importing from python modules generated from imported schemas
897
+ when in :attr:`.split` mode.
898
+
899
+ Use the :attr:`.split_pattern` as a jinja template rendered with the :class:`.SchemaDefinition`
900
+ and any passed ``context``. Apply the :attr:`.SNAKE_CASE` regex to substitute matches with
901
+ ``_`` and ensure lowercase.
902
+ """
903
+ if context is None:
904
+ context = {}
905
+ module = Template(self.split_pattern).render(schema=schema, **context)
906
+ module = re.sub(self.SNAKE_CASE, "_", module) if self.SNAKE_CASE else module
907
+ module = module.lower()
908
+ return module
909
+
910
+ def _get_element_import(self, class_name: ElementName) -> Import:
911
+ """
912
+ Make an import object for an element from another schema, using the
913
+ :attr:`.split_import_pattern` to generate the module import part.
914
+ """
915
+ schema_name = self.schemaview.element_by_schema_map()[class_name]
916
+ schema = [s for s in self.schemaview.schema_map.values() if s.name == schema_name][0]
917
+ module = self.generate_module_import(schema, self.split_context)
918
+ return Import(module=module, objects=[ObjectImport(name=camelcase(class_name))], is_schema=True)
919
+
614
920
  def render(self) -> PydanticModule:
615
921
  sv: SchemaView
616
922
  sv = self.schemaview
617
- schema = sv.schema
618
- pyschema = SchemaDefinition(
619
- id=schema.id,
620
- name=schema.name,
621
- description=schema.description.replace('"', '\\"') if schema.description else None,
622
- )
623
- enums = self.generate_enums(sv.all_enums())
624
- injected_classes = copy(DEFAULT_INJECTS[self.pydantic_version])
625
- if self.injected_classes is not None:
626
- injected_classes += self.injected_classes
627
923
 
924
+ # imports
628
925
  imports = DEFAULT_IMPORTS
629
926
  if self.imports is not None:
630
927
  for i in self.imports:
631
928
  imports += i
929
+ if self.split_mode == SplitMode.FULL:
930
+ imports += self._get_imports()
632
931
 
633
- sorted_classes = self.sort_classes(list(sv.all_classes().values()))
634
- self.sorted_class_names = [camelcase(c.name) for c in sorted_classes]
932
+ # injected classes
933
+ injected_classes = DEFAULT_INJECTS.copy()
934
+ if self.injected_classes is not None:
935
+ injected_classes += self.injected_classes.copy()
936
+
937
+ # enums
938
+ enums = self.before_generate_enums(list(sv.all_enums().values()), sv)
939
+ enums = self.generate_enums({e.name: e for e in enums})
940
+
941
+ base_model = PydanticBaseModel(extra_fields=self.extra_fields, fields=self.injected_fields)
635
942
 
943
+ # schema classes
944
+ class_results = []
945
+ source_classes, imported_classes = self._get_classes(sv)
946
+ source_classes = self.sort_classes(source_classes, imported_classes)
636
947
  # Don't want to generate classes when class_uri is linkml:Any, will
637
948
  # just swap in typing.Any instead down below
638
- sorted_classes = [c for c in sorted_classes if c.class_uri != "linkml:Any"]
639
-
640
- for class_original in sorted_classes:
641
- class_def: ClassDefinition
642
- class_def = deepcopy(class_original)
643
- class_name = class_original.name
644
- class_def.name = camelcase(class_original.name)
645
- if class_def.is_a:
646
- class_def.is_a = camelcase(class_def.is_a)
647
- class_def.mixins = [camelcase(p) for p in class_def.mixins]
648
- if class_def.description:
649
- class_def.description = class_def.description.replace('"', '\\"')
650
- pyschema.classes[class_def.name] = class_def
651
- for attribute in list(class_def.attributes.keys()):
652
- del class_def.attributes[attribute]
653
- for sn in sv.class_slots(class_name):
654
- # TODO: fix runtime, copy should not be necessary
655
- s = deepcopy(sv.induced_slot(sn, class_name))
656
- # logging.error(f'Induced slot {class_name}.{sn} == {s.name} {s.range}')
657
- s.name = underscore(s.name)
658
- if s.description:
659
- s.description = s.description.replace('"', '\\"')
660
- class_def.attributes[s.name] = s
661
-
662
- slot_ranges: List[str] = []
663
-
664
- # Confirm that the original slot range (ignoring the default that comes in from
665
- # induced_slot) isn't in addition to setting any_of
666
- any_of_ranges = [a.range if a.range else s.range for a in s.any_of]
667
- if any_of_ranges:
668
- # list comprehension here is pulling ranges from within AnonymousSlotExpression
669
- slot_ranges.extend(any_of_ranges)
670
- else:
671
- slot_ranges.append(s.range)
949
+ source_classes = [c for c in source_classes if c.class_uri != "linkml:Any"]
950
+ source_classes = self.before_generate_classes(source_classes, sv)
951
+ self.sorted_class_names = [camelcase(c.name) for c in source_classes]
952
+ for cls in source_classes:
953
+ cls = self.before_generate_class(cls, sv)
954
+ result = self.generate_class(cls)
955
+ result = self.after_generate_class(result, sv)
956
+ class_results.append(result)
957
+ if result.imports is not None:
958
+ imports += result.imports
959
+ if result.injected_classes is not None:
960
+ injected_classes.extend(result.injected_classes)
672
961
 
673
- pyranges = [self.generate_python_range(slot_range, s, class_def) for slot_range in slot_ranges]
962
+ class_results = self.after_generate_classes(class_results, sv)
674
963
 
675
- pyranges = list(set(pyranges)) # remove duplicates
676
- pyranges.sort()
964
+ classes = {r.cls.name: r.cls for r in class_results}
677
965
 
678
- if len(pyranges) == 1:
679
- pyrange = pyranges[0]
680
- elif len(pyranges) > 1:
681
- pyrange = f"Union[{', '.join(pyranges)}]"
682
- else:
683
- raise Exception(f"Could not generate python range for {class_name}.{s.name}")
684
-
685
- if s.array is not None:
686
- # TODO add support for xarray
687
- results = self.get_array_representations_range(s, pyrange)
688
- # TODO: Move results unpacking to own function that is used after each slot build stage :)
689
- for res in results:
690
- if res.injected_classes:
691
- injected_classes += res.injected_classes
692
- if res.imports:
693
- imports += res.imports
694
- if len(results) == 1:
695
- pyrange = results[0].annotation
696
- else:
697
- pyrange = f"Union[{', '.join([res.annotation for res in results])}]"
698
-
699
- if "linkml:ColumnOrderedArray" in class_def.implements:
700
- raise NotImplementedError("Cannot generate Pydantic code for ColumnOrderedArrays.")
701
- elif s.multivalued:
702
- if s.inlined or s.inlined_as_list:
703
- collection_key = self.generate_collection_key(slot_ranges, s, class_def)
704
- else:
705
- collection_key = None
706
- if s.inlined is False or collection_key is None or s.inlined_as_list is True:
707
- pyrange = f"List[{pyrange}]"
708
- else:
709
- simple_dict_value = None
710
- if len(slot_ranges) == 1:
711
- simple_dict_value = self._inline_as_simple_dict_with_value(s, sv)
712
- if simple_dict_value:
713
- # inlining as simple dict
714
- pyrange = f"Dict[str, {simple_dict_value}]"
715
- else:
716
- pyrange = f"Dict[{collection_key}, {pyrange}]"
717
- if not (s.required or s.identifier or s.key) and not s.designates_type:
718
- pyrange = f"Optional[{pyrange}]"
719
- ann = Annotation("python_range", pyrange)
720
- s.annotations[ann.tag] = ann
721
-
722
- # TODO: Make cleaning injected classes its own method
723
- injected_classes = list(
724
- dict.fromkeys([c if isinstance(c, str) else inspect.getsource(c) for c in injected_classes])
725
- )
726
- injected_classes = [textwrap.dedent(c) for c in injected_classes]
727
-
728
- base_model = PydanticBaseModel(
729
- pydantic_ver=self.pydantic_version, extra_fields=self.extra_fields, fields=self.injected_fields
730
- )
731
-
732
- classes = {}
733
- predefined = self.get_predefined_slot_values()
734
- bases = self.get_class_isa_plus_mixins()
735
- for k, c in pyschema.classes.items():
736
- attrs = {}
737
- for attr_name, src_attr in c.attributes.items():
738
- src_attr = src_attr._as_dict
739
- new_fields = {
740
- k: src_attr.get(k, None)
741
- for k in PydanticAttribute.model_fields.keys()
742
- if src_attr.get(k, None) is not None
743
- }
744
- predef_slot = predefined.get(k, {}).get(attr_name, None)
745
- if predef_slot is not None:
746
- predef_slot = str(predef_slot)
747
- new_fields["predefined"] = predef_slot
748
- new_fields["name"] = attr_name
749
-
750
- attrs[attr_name] = PydanticAttribute(**new_fields, pydantic_ver=self.pydantic_version)
751
- attrs[attr_name] = self.include_metadata(attrs[attr_name], src_attr)
752
-
753
- new_class = PydanticClass(
754
- name=k, attributes=attrs, description=c.description, pydantic_ver=self.pydantic_version
755
- )
756
- new_class = self.include_metadata(new_class, c)
757
- if k in bases:
758
- new_class.bases = bases[k]
759
- classes[k] = new_class
966
+ injected_classes = self._clean_injected_classes(injected_classes)
760
967
 
761
968
  module = PydanticModule(
762
- pydantic_ver=self.pydantic_version,
763
969
  metamodel_version=self.schema.metamodel_version,
764
970
  version=self.schema.version,
765
971
  python_imports=imports.imports,
@@ -768,22 +974,166 @@ class PydanticGenerator(OOCodeGenerator):
768
974
  enums=enums,
769
975
  classes=classes,
770
976
  )
771
- module = self.include_metadata(module, schema)
977
+ module = self.include_metadata(module, self.schemaview.schema)
978
+ module = self.before_render_template(module, self.schemaview)
772
979
  return module
773
980
 
774
- def serialize(self) -> str:
775
- module = self.render()
776
- return module.render(self._template_environment(), self.black)
981
+ def serialize(self, rendered_module: Optional[PydanticModule] = None) -> str:
982
+ """
983
+ Serialize the schema to a pydantic module as a string
984
+
985
+ Args:
986
+ rendered_module ( :class:`.PydanticModule` ): Optional, if schema was previously
987
+ rendered with :meth:`.render` , use that, otherwise :meth:`.render` fresh.
988
+ """
989
+ if rendered_module is not None:
990
+ module = rendered_module
991
+ else:
992
+ module = self.render()
993
+ serialized = module.render(self._template_environment(), self.black)
994
+ serialized = self.after_render_template(serialized, self.schemaview)
995
+ return serialized
777
996
 
778
997
  def default_value_for_type(self, typ: str) -> str:
779
998
  return "None"
780
999
 
1000
+ @classmethod
1001
+ def generate_split(
1002
+ cls,
1003
+ schema: Union[str, Path, SchemaDefinition],
1004
+ output_path: Union[str, Path] = Path("."),
1005
+ split_pattern: Optional[str] = None,
1006
+ split_context: Optional[dict] = None,
1007
+ split_mode: SplitMode = SplitMode.AUTO,
1008
+ **kwargs,
1009
+ ) -> List[SplitResult]:
1010
+ """
1011
+ Generate a schema that imports from other schema as a set of python modules that
1012
+ import from one another, rather than generating all imported classes in a single schema.
1013
+
1014
+ Uses ``output_path`` for the main schema from ``schema`` , and then
1015
+ generates any imported schema (from which classes are actually used)
1016
+ to modules whose locations are determined by the module names generated
1017
+ by the ``split_pattern`` (see :attr:`.PydanticGenerator.split_pattern` ).
1018
+
1019
+ For example, for
1020
+
1021
+ * a ``output_path`` of ``my_dir/v1_2_3/main.py``
1022
+ * a schema ``main`` with a version ``v1.2.3``
1023
+ * that imports from ``s2`` with version ``v4.5.6``,
1024
+ * and a ``split_pattern`` of ``..{{ schema.version | replace('.', '_') }}.{{ schema.name }}``
1025
+
1026
+ One would get:
1027
+ * ``my_dir/v1_2_3/main.py`` , as expected
1028
+ * that imports ``from ..v4_5_6.s2``
1029
+ * a module at ``my_dir/v4_5_6/s2.py``
1030
+
1031
+ ``__init__.py`` files are generated for any directories that are between
1032
+ the generated modules and their highest common directory.
1033
+
1034
+ Args:
1035
+ schema (str, :class:`.Path` , :class:`.SchemaDefinition` ): Main schema to generate
1036
+ output_path (str, :class:`.Path` ): Python ``.py`` module to generate main schema to
1037
+ split_pattern (str): Pattern to use to generate module names, see :attr:`.PydanticGenerator.split_pattern`
1038
+ split_context (dict): Additional variables to pass into jinja context when generating module import names.
1039
+
1040
+ Returns:
1041
+ list[:class:`.SplitResult`]
1042
+ """
1043
+ output_path = Path(output_path)
1044
+ if not output_path.suffix == ".py":
1045
+ raise ValueError(f"output path must be a python file to write the main schema to, got {output_path}")
1046
+
1047
+ results = []
1048
+
1049
+ # --------------------------------------------------
1050
+ # Main schema
1051
+ # --------------------------------------------------
1052
+ gen_kwargs = kwargs
1053
+ gen_kwargs.update(
1054
+ {"split": True, "split_pattern": split_pattern, "split_context": split_context, "split_mode": split_mode}
1055
+ )
1056
+ generator = cls(schema, **gen_kwargs)
1057
+ # Generate the initial schema to figure out which of the imported schema actually need
1058
+ # to be generated
1059
+ rendered = generator.render()
1060
+ # write schema - we use the ``output_path`` for the main schema, and then
1061
+ # interpret all imported schema paths as relative to that
1062
+ output_path.parent.mkdir(parents=True, exist_ok=True)
1063
+ serialized = generator.serialize(rendered_module=rendered)
1064
+ with open(output_path, "w") as ofile:
1065
+ ofile.write(serialized)
1066
+
1067
+ results.append(
1068
+ SplitResult(main=True, source=generator.schemaview.schema, path=output_path, serialized_module=serialized)
1069
+ )
1070
+
1071
+ # --------------------------------------------------
1072
+ # Imported schemas
1073
+ # --------------------------------------------------
1074
+ imported_schema = {
1075
+ generator.generate_module_import(sch): sch for sch in generator.schemaview.schema_map.values()
1076
+ }
1077
+ for generated_import in [i for i in rendered.python_imports if i.is_schema]:
1078
+ import_generator = cls(imported_schema[generated_import.module], **gen_kwargs)
1079
+ serialized = import_generator.serialize()
1080
+ rel_path = _import_to_path(generated_import.module)
1081
+ abs_path = (output_path.parent / rel_path).resolve()
1082
+ abs_path.parent.mkdir(parents=True, exist_ok=True)
1083
+ with open(abs_path, "w") as ofile:
1084
+ ofile.write(serialized)
1085
+
1086
+ results.append(
1087
+ SplitResult(
1088
+ main=False,
1089
+ source=imported_schema[generated_import.module],
1090
+ path=abs_path,
1091
+ serialized_module=serialized,
1092
+ module_import=generated_import.module,
1093
+ )
1094
+ )
1095
+
1096
+ _ensure_inits([r.path for r in results])
1097
+ return results
1098
+
781
1099
 
782
1100
def _subclasses(cls: Type):
    """Return the set of all direct and indirect subclasses of ``cls`` (``cls`` itself excluded)."""
    found = set()
    pending = list(cls.__subclasses__())
    while pending:
        candidate = pending.pop()
        if candidate in found:
            # already reached through another base (multiple inheritance)
            continue
        found.add(candidate)
        pending.extend(candidate.__subclasses__())
    return found
784
1102
 
785
1103
 
786
- _TEMPLATE_NAMES = sorted(list(set([c.template for c in _subclasses(TemplateModel)])))
1104
# sorted, de-duplicated template names of every PydanticTemplateModel subclass
_TEMPLATE_NAMES = sorted({model.template for model in _subclasses(PydanticTemplateModel)})
1105
+
1106
+
1107
def _import_to_path(module: str) -> Path:
    """
    Make a (relative) ``Path`` object from a python module import string.

    Zero or one leading dots mean "the current directory"; every dot after the
    first adds one ``../`` step. The remaining dotted name becomes nested
    directories, with ``.py`` appended to the last piece.
    """
    # peel the leading dots off from the dotted module name
    _, leading_dots, dotted_name = re.split(r"(^\.*)(?=\w)", module, maxsplit=1)
    levels_up = max(len(leading_dots) - 1, 0)
    *packages, leaf = dotted_name.split(".")
    return Path(*(["../"] * levels_up), *packages, leaf + ".py")
1116
+
1117
+
1118
def _ensure_inits(paths: List[Path]):
    """
    For a set of module paths, find their common root directory and make sure that it
    and every directory between it and each module has an ``__init__.py``.

    Existing ``__init__.py`` files are left untouched. Nothing is done when fewer than
    two paths are given, since there is then no relative importing to be done.

    Args:
        paths (list[:class:`.Path`]): Paths of the generated module files. Relative and
            absolute paths may be mixed; all are resolved before comparison.
    """
    # if there is only one file, there is no relative importing to be done
    if len(paths) <= 1:
        return

    # Work on resolved *directories*:
    # - ``os.path.commonpath`` raises ``ValueError`` when absolute and relative paths are mixed
    # - using the parent directories guarantees the common path is itself a directory,
    #   so the upward walk below always terminates at it
    directories = sorted({Path(p).resolve().parent for p in paths})
    common_path = Path(os.path.commonpath(directories))

    def _touch_init(directory: Path) -> None:
        """Create ``directory/__init__.py`` unless it already exists."""
        init_file = directory / "__init__.py"
        if not init_file.exists():
            with open(init_file, "w", encoding="utf-8") as ifile:
                ifile.write(" \n")

    _touch_init(common_path)

    for directory in directories:
        # ensure __init__ for each directory from this one up to the common path
        while directory != common_path:
            _touch_init(directory)
            directory = directory.parent
787
1137
 
788
1138
 
789
1139
  @shared_arguments(PydanticGenerator)
@@ -795,22 +1145,16 @@ _TEMPLATE_NAMES = sorted(list(set([c.template for c in _subclasses(TemplateModel
795
1145
  Optional jinja2 template directory to use for class generation.
796
1146
 
797
1147
  Pass a directory containing templates with the same name as any of the default
798
- :class:`.TemplateModel` templates to override them. The given directory will be
1148
+ :class:`.PydanticTemplateModel` templates to override them. The given directory will be
799
1149
  searched for matching templates, and use the default templates as a fallback
800
1150
  if an override is not found
801
-
1151
+
802
1152
  Available templates to override:
803
1153
 
804
1154
  \b
805
1155
  """
806
1156
  + "\n".join(["- " + name for name in _TEMPLATE_NAMES]),
807
1157
  )
808
- @click.option(
809
- "--pydantic-version",
810
- type=click.IntRange(1, 2),
811
- default=int(PYDANTIC_VERSION[0]),
812
- help="Pydantic version to use (1 or 2)",
813
- )
814
1158
  @click.option(
815
1159
  "--array-representations",
816
1160
  type=click.Choice([k.value for k in ArrayRepresentation]),
@@ -839,7 +1183,7 @@ Available templates to override:
839
1183
  "Default (auto) is to include all metadata that can't be otherwise represented",
840
1184
  )
841
1185
  @click.version_option(__version__, "-V", "--version")
842
- @click.command()
1186
+ @click.command(name="pydantic")
843
1187
  def cli(
844
1188
  yamlfile,
845
1189
  template_file=None,
@@ -849,7 +1193,6 @@ def cli(
849
1193
  classvars=True,
850
1194
  slots=True,
851
1195
  array_representations=list("list"),
852
- pydantic_version=int(PYDANTIC_VERSION[0]),
853
1196
  extra_fields: Literal["allow", "forbid", "ignore"] = "forbid",
854
1197
  black: bool = False,
855
1198
  meta: MetadataMode = "auto",
@@ -870,7 +1213,6 @@ def cli(
870
1213
 
871
1214
  gen = PydanticGenerator(
872
1215
  yamlfile,
873
- pydantic_version=pydantic_version,
874
1216
  array_representations=[ArrayRepresentation(x) for x in array_representations],
875
1217
  extra_fields=extra_fields,
876
1218
  emit_metadata=head,