linkml 1.9.4rc2__py3-none-any.whl → 1.9.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- linkml/cli/main.py +5 -1
- linkml/converter/__init__.py +0 -0
- linkml/generators/__init__.py +2 -0
- linkml/generators/common/build.py +5 -20
- linkml/generators/common/template.py +289 -3
- linkml/generators/docgen.py +55 -10
- linkml/generators/erdiagramgen.py +9 -5
- linkml/generators/graphqlgen.py +32 -6
- linkml/generators/jsonldcontextgen.py +78 -12
- linkml/generators/jsonschemagen.py +29 -12
- linkml/generators/mermaidclassdiagramgen.py +21 -3
- linkml/generators/owlgen.py +13 -2
- linkml/generators/panderagen/dataframe_class.py +13 -0
- linkml/generators/panderagen/dataframe_field.py +50 -0
- linkml/generators/panderagen/linkml_pandera_validator.py +186 -0
- linkml/generators/panderagen/panderagen.py +22 -5
- linkml/generators/panderagen/panderagen_class_based/class.jinja2 +70 -13
- linkml/generators/panderagen/panderagen_class_based/custom_checks.jinja2 +27 -0
- linkml/generators/panderagen/panderagen_class_based/enums.jinja2 +3 -3
- linkml/generators/panderagen/panderagen_class_based/pandera.jinja2 +12 -2
- linkml/generators/panderagen/panderagen_class_based/slots.jinja2 +19 -17
- linkml/generators/panderagen/slot_generator_mixin.py +143 -16
- linkml/generators/panderagen/transforms/__init__.py +19 -0
- linkml/generators/panderagen/transforms/collection_dict_model_transform.py +62 -0
- linkml/generators/panderagen/transforms/list_dict_model_transform.py +66 -0
- linkml/generators/panderagen/transforms/model_transform.py +8 -0
- linkml/generators/panderagen/transforms/nested_struct_model_transform.py +27 -0
- linkml/generators/panderagen/transforms/simple_dict_model_transform.py +86 -0
- linkml/generators/plantumlgen.py +17 -11
- linkml/generators/pydanticgen/pydanticgen.py +53 -2
- linkml/generators/pydanticgen/template.py +45 -233
- linkml/generators/pydanticgen/templates/attribute.py.jinja +1 -0
- linkml/generators/pydanticgen/templates/base_model.py.jinja +16 -2
- linkml/generators/pydanticgen/templates/imports.py.jinja +1 -1
- linkml/generators/rdfgen.py +11 -2
- linkml/generators/rustgen/__init__.py +3 -0
- linkml/generators/rustgen/build.py +97 -0
- linkml/generators/rustgen/cli.py +83 -0
- linkml/generators/rustgen/rustgen.py +1186 -0
- linkml/generators/rustgen/template.py +910 -0
- linkml/generators/rustgen/templates/Cargo.toml.jinja +42 -0
- linkml/generators/rustgen/templates/anything.rs.jinja +149 -0
- linkml/generators/rustgen/templates/as_key_value.rs.jinja +86 -0
- linkml/generators/rustgen/templates/class_module.rs.jinja +8 -0
- linkml/generators/rustgen/templates/enum.rs.jinja +70 -0
- linkml/generators/rustgen/templates/file.rs.jinja +75 -0
- linkml/generators/rustgen/templates/import.rs.jinja +4 -0
- linkml/generators/rustgen/templates/imports.rs.jinja +8 -0
- linkml/generators/rustgen/templates/lib_shim.rs.jinja +52 -0
- linkml/generators/rustgen/templates/poly.rs.jinja +9 -0
- linkml/generators/rustgen/templates/poly_containers.rs.jinja +439 -0
- linkml/generators/rustgen/templates/poly_trait.rs.jinja +15 -0
- linkml/generators/rustgen/templates/poly_trait_impl.rs.jinja +5 -0
- linkml/generators/rustgen/templates/poly_trait_impl_orsubtype.rs.jinja +5 -0
- linkml/generators/rustgen/templates/poly_trait_property.rs.jinja +8 -0
- linkml/generators/rustgen/templates/poly_trait_property_impl.rs.jinja +134 -0
- linkml/generators/rustgen/templates/poly_trait_property_match.rs.jinja +10 -0
- linkml/generators/rustgen/templates/property.rs.jinja +28 -0
- linkml/generators/rustgen/templates/pyproject.toml.jinja +10 -0
- linkml/generators/rustgen/templates/serde_utils.rs.jinja +490 -0
- linkml/generators/rustgen/templates/slot_range_as_union.rs.jinja +64 -0
- linkml/generators/rustgen/templates/struct.rs.jinja +81 -0
- linkml/generators/rustgen/templates/struct_or_subtype_enum.rs.jinja +111 -0
- linkml/generators/rustgen/templates/stub_gen.rs.jinja +71 -0
- linkml/generators/rustgen/templates/stub_utils.rs.jinja +76 -0
- linkml/generators/rustgen/templates/typealias.rs.jinja +13 -0
- linkml/generators/sqltablegen.py +18 -16
- linkml/generators/yarrrmlgen.py +173 -0
- linkml/linter/config/datamodel/config.py +160 -293
- linkml/linter/config/datamodel/config.yaml +34 -26
- linkml/linter/config/default.yaml +4 -0
- linkml/linter/config/recommended.yaml +4 -0
- linkml/linter/linter.py +1 -2
- linkml/linter/rules.py +37 -0
- linkml/utils/schema_builder.py +2 -0
- linkml/utils/schemaloader.py +76 -3
- {linkml-1.9.4rc2.dist-info → linkml-1.9.5.dist-info}/METADATA +1 -1
- {linkml-1.9.4rc2.dist-info → linkml-1.9.5.dist-info}/RECORD +82 -40
- {linkml-1.9.4rc2.dist-info → linkml-1.9.5.dist-info}/entry_points.txt +2 -1
- linkml/generators/panderagen/panderagen_class_based/mixins.jinja2 +0 -26
- /linkml/{utils/converter.py → converter/cli.py} +0 -0
- {linkml-1.9.4rc2.dist-info → linkml-1.9.5.dist-info}/WHEEL +0 -0
- {linkml-1.9.4rc2.dist-info → linkml-1.9.5.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1186 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Literal, Optional, Union, overload
|
|
5
|
+
|
|
6
|
+
from jinja2 import Environment
|
|
7
|
+
from linkml_runtime.linkml_model.meta import (
|
|
8
|
+
ClassDefinition,
|
|
9
|
+
EnumDefinition,
|
|
10
|
+
PermissibleValue,
|
|
11
|
+
SlotDefinition,
|
|
12
|
+
TypeDefinition,
|
|
13
|
+
)
|
|
14
|
+
from linkml_runtime.utils.formatutils import camelcase, uncamelcase, underscore
|
|
15
|
+
from linkml_runtime.utils.schemaview import OrderedBy, SchemaView
|
|
16
|
+
|
|
17
|
+
from linkml.generators.common.lifecycle import LifecycleMixin
|
|
18
|
+
from linkml.generators.common.template import ObjectImport
|
|
19
|
+
from linkml.generators.common.type_designators import get_accepted_type_designator_values
|
|
20
|
+
from linkml.generators.rustgen.build import (
|
|
21
|
+
AttributeResult,
|
|
22
|
+
ClassResult,
|
|
23
|
+
CrateResult,
|
|
24
|
+
EnumResult,
|
|
25
|
+
FileResult,
|
|
26
|
+
SlotResult,
|
|
27
|
+
TypeResult,
|
|
28
|
+
)
|
|
29
|
+
from linkml.generators.rustgen.template import (
|
|
30
|
+
AsKeyValue,
|
|
31
|
+
ContainerType,
|
|
32
|
+
Import,
|
|
33
|
+
Imports,
|
|
34
|
+
PolyContainersFile,
|
|
35
|
+
PolyFile,
|
|
36
|
+
PolyTrait,
|
|
37
|
+
PolyTraitImpl,
|
|
38
|
+
PolyTraitImplForSubtypeEnum,
|
|
39
|
+
PolyTraitProperty,
|
|
40
|
+
PolyTraitPropertyImpl,
|
|
41
|
+
PolyTraitPropertyMatch,
|
|
42
|
+
RustCargo,
|
|
43
|
+
RustClassModule,
|
|
44
|
+
RustEnum,
|
|
45
|
+
RustEnumItem,
|
|
46
|
+
RustFile,
|
|
47
|
+
RustLibShim,
|
|
48
|
+
RustProperty,
|
|
49
|
+
RustPyProject,
|
|
50
|
+
RustRange,
|
|
51
|
+
RustStruct,
|
|
52
|
+
RustStructOrSubtypeEnum,
|
|
53
|
+
RustTemplateModel,
|
|
54
|
+
RustTypeAlias,
|
|
55
|
+
SerdeUtilsFile,
|
|
56
|
+
SlotRangeAsUnion,
|
|
57
|
+
StubGenBin,
|
|
58
|
+
StubUtilsFile,
|
|
59
|
+
)
|
|
60
|
+
from linkml.utils.generator import Generator
|
|
61
|
+
|
|
62
|
+
# Output modes accepted by ``RustGenerator.mode`` / ``RustGenerator.render``:
# a whole cargo crate, or a single ``.rs`` file.
RUST_MODES = Literal["crate", "file"]

# Lookup table from python builtin types or base-type names (e.g. the ``base`` of a
# TypeDefinition such as "XSDDate") to the Rust type emitted for them.
# ``get_rust_type`` consults this table and uses ``PYTHON_TO_RUST[str]`` as its fallback.
PYTHON_TO_RUST = {
    int: "isize",
    float: "f64",
    str: "String",
    bool: "bool",
    "int": "isize",
    "float": "f64",
    "str": "String",
    "String": "String",
    "bool": "bool",
    "Bool": "bool",
    "XSDDate": "NaiveDate",
    "date": "NaiveDate",
    "XSDDateTime": "NaiveDateTime",
    "datetime": "NaiveDateTime",
    # Decimal is currently approximated by f64 (binary floating point cannot represent
    # every decimal exactly); the commented-out entry would map it to rust_decimal's
    # ``dec`` instead, which RUST_IMPORTS already knows how to import.
    # "Decimal": "dec",
    "Decimal": "f64",
}
"""
Mapping from python types to rust types.

.. todo::

    - Add numpy types
    - make an enum wrapper for naivedatetime and datetime<fixedoffset> that can represent both of them

"""
|
|
91
|
+
|
|
92
|
+
# Names that cannot be used verbatim as Rust identifiers. ``protect_name`` appends a
# trailing underscore to any generated field/type/variant name found here.
# The first three entries are the originally protected names (kept first for
# backward compatibility); the rest complete Rust's strict and reserved keyword list,
# so e.g. a slot called ``ref``/``match``/``use`` or a value camelcased to ``Self``
# no longer produces uncompilable code.
PROTECTED_NAMES = (
    "type",
    "typeof",
    "abstract",
    # strict keywords
    "as",
    "async",
    "await",
    "break",
    "const",
    "continue",
    "crate",
    "dyn",
    "else",
    "enum",
    "extern",
    "false",
    "fn",
    "for",
    "if",
    "impl",
    "in",
    "let",
    "loop",
    "match",
    "mod",
    "move",
    "mut",
    "pub",
    "ref",
    "return",
    "self",
    "Self",
    "static",
    "struct",
    "super",
    "trait",
    "true",
    "unsafe",
    "use",
    "where",
    "while",
    # keywords reserved for future use ("gen" is reserved from the 2024 edition)
    "become",
    "box",
    "do",
    "final",
    "gen",
    "macro",
    "override",
    "priv",
    "try",
    "unsized",
    "virtual",
    "yield",
)
|
|
93
|
+
|
|
94
|
+
# Crate imports keyed by the Rust type that requires them. ``RustGenerator.get_imports``
# looks up the Rust type of a slot/type here and, on a hit, returns that import.
# NOTE: PYTHON_TO_RUST currently never yields "dec" (its Decimal entry maps to f64),
# so the rust_decimal entry is inactive unless that mapping changes.
RUST_IMPORTS = {
    "dec": Import(module="rust_decimal", version="1.36", objects=[ObjectImport(name="dec")]),
    "NaiveDate": Import(
        module="chrono", features=["serde"], version="0.4.41", objects=[ObjectImport(name="NaiveDate")]
    ),
    "NaiveDateTime": Import(
        module="chrono", features=["serde"], version="0.4.41", objects=[ObjectImport(name="NaiveDateTime")]
    ),
}

# Classes carrying this annotation are generated with ``generate_merge=True``
# (both on the struct and on each of its properties).
MERGE_ANNOTATION = "rust.linkml.io/generate/merge"

# ``Merge`` from the ``merge`` crate; presumably added when merge generation is
# enabled -- its use site is outside this view.
MERGE_IMPORTS = Imports(
    imports=[Import(module="merge", version="0.2.0", objects=[ObjectImport(name="Merge")])],
)

# Standard-library imports; HashMap presumably backs mapping-mode (key -> object) fields.
DEFAULT_IMPORTS = Imports(
    imports=[
        Import(module="std::collections", objects=[ObjectImport(name="HashMap")]),
        # Import(module="std::fmt", objects=[ObjectImport(name="Display")]),
    ]
)

# serde-related crates. ``feature_flag`` presumably gates an import behind the cargo
# feature of that name -- confirm against the Import template.
# NOTE(review): serde-value has no feature_flag, unlike its neighbours -- confirm it is
# meant to be unconditional.
SERDE_IMPORTS = Imports(
    imports=[
        Import(
            module="serde",
            version="1.0",
            features=["derive"],
            objects=[
                ObjectImport(name="Serialize"),
                ObjectImport(name="Deserialize"),
                ObjectImport(name="de::IntoDeserializer"),
            ],
            feature_flag="serde",
        ),
        Import(module="serde-value", version="0.7.0", objects=[ObjectImport(name="Value")]),
        Import(module="serde_yml", version="0.0.12", feature_flag="serde", alias="_"),
        Import(module="serde_path_to_error", version="0.1.17", objects=[], feature_flag="serde"),
    ]
)

# PyO3 bindings (with its chrono feature enabled), behind the "pyo3" feature flag.
# NOTE(review): this pins pyo3 0.25.0 while RustGenerator.pyo3_version defaults to
# ">=0.21.1" -- confirm which one the Cargo template actually uses.
PYTHON_IMPORTS = Imports(
    imports=[
        Import(
            module="pyo3",
            version="0.25.0",
            objects=[ObjectImport(name="prelude::*"), ObjectImport(name="FromPyObject")],
            feature_flag="pyo3",
            features=["chrono"],
        ),
        # Import(module="serde_pyobject", version="0.6.1", objects=[], feature_flag="pyo3", features=[]),
    ]
)

# pyo3-stub-gen instrumentation, behind the "stubgen" flag, which depends on "pyo3".
STUBGEN_IMPORTS = Imports(
    imports=[
        Import(
            module="pyo3-stub-gen",
            version="0.13.1",
            objects=[
                ObjectImport(name="define_stub_info_gatherer"),
                ObjectImport(name="derive::gen_stub_pyclass"),
                ObjectImport(name="derive::gen_stub_pymethods"),
            ],
            feature_flag="stubgen",
            feature_dependencies=["pyo3"],
        ),
    ]
)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class SlotContainerMode(Enum):
    """How the values of a slot are held in a generated struct field.

    Chosen by ``determine_slot_mode``. ``MAPPING`` is used for inlined multivalued
    slots whose range class has a key or identifier slot. The ``.value`` string is
    what gets passed on to ``RustProperty.container_mode``.
    """

    SINGLE_VALUE = "single_value"
    MAPPING = "mapping"
    LIST = "list"
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class SlotInlineMode(Enum):
    """How the range of a slot is represented, as decided by ``determine_slot_mode``.

    * ``INLINE``: the range class object is embedded in the parent.
    * ``PRIMITIVE``: the range is not a class (e.g. a type or an enum).
    * ``REFERENCE``: the object is referred to by identifier; ``get_rust_range_info``
      types such ranges as ``String``.
    """

    INLINE = "inline"
    PRIMITIVE = "primitive"
    REFERENCE = "reference"
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def get_key_or_identifier_slot(cls: ClassDefinition, sv: SchemaView) -> Optional[SlotDefinition]:
    """Return the first induced slot of ``cls`` that is an identifier or a key.

    Slots are examined in ``sv.class_induced_slots`` order; ``None`` is returned
    when no slot carries either flag.
    """
    keyed = (candidate for candidate in sv.class_induced_slots(cls.name) if candidate.identifier or candidate.key)
    return next(keyed, None)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def get_identifier_slot(cls: ClassDefinition, sv: SchemaView) -> Optional[SlotDefinition]:
    """Return the first induced slot of ``cls`` flagged as ``identifier``, else ``None``.

    Unlike :func:`get_key_or_identifier_slot`, slots that are only ``key`` are ignored.
    """
    identifiers = filter(lambda candidate: candidate.identifier, sv.class_induced_slots(cls.name))
    return next(identifiers, None)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def class_real_descendants(sv: SchemaView, class_name: str) -> list[str]:
    """Return the strict descendants of ``class_name``; the class itself is never included.

    ``class_descendants`` may report a class as its own descendant, so it is filtered
    out here. This keeps "does this class have subtypes?" decisions (OrSubtype enum
    generation, trait typing) free of off-by-one errors. Any exception raised while
    looking up descendants is treated as "no descendants".
    """
    try:
        found = list(sv.class_descendants(class_name))
    except Exception:
        return []
    return list(filter(lambda descendant: descendant != class_name, found))
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def has_real_subtypes(sv: SchemaView, class_name: str) -> bool:
    """Return whether ``class_name`` has at least one descendant other than itself."""
    return bool(class_real_descendants(sv, class_name))
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def determine_slot_mode(s: SlotDefinition, sv: SchemaView) -> tuple[SlotContainerMode, SlotInlineMode]:
    """Decide how slot ``s`` is stored: its container mode and its inline mode.

    * Non-class ranges are ``PRIMITIVE``, held in a list when multivalued.
    * ``inlined_as_list`` multivalued slots are always an inlined list.
    * A range class without an identifier cannot be referenced, so it is inlined
      regardless of ``s.inlined``.
    * Single-valued slots are then ``INLINE`` or ``REFERENCE``; multivalued
      references are a list; multivalued inlined values become a mapping when the
      range class has a key/identifier slot, otherwise a list.
    """
    if s.range not in sv.all_classes():
        container = SlotContainerMode.LIST if s.multivalued else SlotContainerMode.SINGLE_VALUE
        return (container, SlotInlineMode.PRIMITIVE)

    if s.multivalued and s.inlined_as_list:
        return (SlotContainerMode.LIST, SlotInlineMode.INLINE)

    range_class = sv.get_class(s.range)
    key_slot = get_key_or_identifier_slot(range_class, sv)
    # Objects without an identifier have nothing to be referenced by: force inlining.
    must_inline = get_identifier_slot(range_class, sv) is None
    inlined = True if must_inline else s.inlined

    if not s.multivalued:
        inline_mode = SlotInlineMode.INLINE if inlined else SlotInlineMode.REFERENCE
        return (SlotContainerMode.SINGLE_VALUE, inline_mode)

    if not inlined:
        return (SlotContainerMode.LIST, SlotInlineMode.REFERENCE)

    container = SlotContainerMode.LIST if key_slot is None else SlotContainerMode.MAPPING
    return (container, SlotInlineMode.INLINE)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def can_contain_reference_to_class(s: SlotDefinition, cls: ClassDefinition, sv: SchemaView) -> bool:
    """Return True if a value of slot ``s`` can (transitively) contain an instance of ``cls``.

    Starting from the slot's ``range``, the attributes of every reachable class are
    walked depth-first. ``get_rust_range_info`` uses the answer to decide whether an
    inlined field needs boxing. Only each attribute's ``range`` is followed; ``any_of``
    alternatives and subclasses of a range are not considered.

    Each class is expanded at most once. Previously a class could be pushed several
    times before being visited, and every duplicate pop re-ran ``sv.induced_class``.
    """
    target = cls.name
    # Hoisted out of the loop; the class set does not change during the walk.
    known_classes = sv.all_classes()
    visited = set()
    pending = [s.range]
    while pending:
        current = pending.pop()
        if current in visited:
            # Already expanded via another path.
            continue
        visited.add(current)
        if current not in known_classes:
            # Types, enums, or missing ranges cannot contain class instances.
            continue
        if current == target:
            return True
        for attr in sv.induced_class(current).attributes.values():
            if attr.range not in visited:
                pending.append(attr.range)
    return False
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def get_rust_type(
    t: Union[TypeDefinition, type, str], sv: SchemaView, pyo3: bool = False, crate_ref: Optional[str] = None
) -> str:
    """
    Get the rust type from a given linkml type.

    Args:
        t: One of

            * a :class:`TypeDefinition`, resolved through its ``base`` (or its
              ``typeof`` parent when it has no base);
            * a python builtin type such as ``int``, looked up in ``PYTHON_TO_RUST``;
            * a name: a schema type, enum or class name, or a bare base-type name
              such as ``"int"`` or ``"XSDDate"`` (e.g. a ``TypeDefinition.base``).
        sv: Schema view used to resolve names.
        pyo3: Forwarded to recursive calls; not otherwise used here.
        crate_ref: If given, prefixed as ``crate_ref::Name`` onto generated alias
            and class names. It is not applied to enums or to types resolved by name.

    Returns:
        The Rust type expression. Inputs that cannot be resolved fall back to ``String``.

    Fixes over the previous version:

    * Python builtin types and bare base-type names are mapped via ``PYTHON_TO_RUST``.
      Previously they fell through to ``String``, so e.g. the alias for ``integer``
      became ``String`` and ``get_imports`` missed chrono for date types.
    * Class and enum names are no longer re-mapped when they happen to collide
      with a ``PYTHON_TO_RUST`` key (e.g. a class named ``Decimal``).
    """

    def _with_crate(name: str) -> str:
        return name if crate_ref is None else f"{crate_ref}::{name}"

    if isinstance(t, TypeDefinition):
        if t.base is not None:
            if t.base in PYTHON_TO_RUST:
                return PYTHON_TO_RUST[t.base]
            # A type like URIorCURIE which is an alias for a rust type
            return _with_crate(get_name(t))
        if t.typeof is not None:
            # A type with no base type: resolve through its parent type
            return get_rust_type(sv.get_type(t.typeof), sv, pyo3)

    elif isinstance(t, type):
        # Python builtin types (int, float, str, bool)
        if t in PYTHON_TO_RUST:
            return PYTHON_TO_RUST[t]

    elif isinstance(t, str):
        if tdef := sv.all_types().get(t, None):
            return get_rust_type(tdef, sv, pyo3)
        if t in sv.all_enums():
            # Map LinkML enums to generated Rust enums rather than collapsing to String
            return get_name(sv.get_enum(t))
        if t in sv.all_classes():
            return _with_crate(get_name(sv.get_class(t)))
        if t in PYTHON_TO_RUST:
            # A bare base-type name, e.g. TypeDefinition.base ("int", "XSDDate")
            return PYTHON_TO_RUST[t]

    # FIXME: Raise here once we have implemented all base types
    return PYTHON_TO_RUST[str]
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def get_rust_range_info(
    cls: ClassDefinition, s: SlotDefinition, sv: SchemaView, crate_ref: Optional[str] = None
) -> RustRange:
    """
    Describe the Rust type of slot ``s`` as it appears on class ``cls``.

    Every range of the slot (``slot_range_as_union``) is described as a child
    ``RustRange``. Child ranges are only attached when there is more than one, and
    in that case the field type is the ``<class>_utl::<slot>_range`` union enum.
    References (non-inlined objects) are typed as ``String``. Boxing is requested
    for inlined ranges that may recursively contain ``cls``.
    """
    container_mode, inline_mode = determine_slot_mode(s, sv)
    by_reference = inline_mode == SlotInlineMode.REFERENCE
    all_ranges = sv.slot_range_as_union(s)
    class_names = sv.all_classes()

    def rust_type_of(range_name) -> str:
        # Referenced objects are stored by their identifier string.
        if by_reference:
            return "String"
        return get_rust_type(range_name, sv, True, crate_ref)

    sub_ranges = []
    for r in all_ranges:
        r_type = rust_type_of(r)
        r_is_class = r in class_names
        sub_ranges.append(
            RustRange(
                type_=r_type,
                is_class_range=r_is_class,
                has_class_subtypes=has_real_subtypes(sv, r) if r_is_class else False,
            )
        )

    container_types = {
        SlotContainerMode.LIST: ContainerType.LIST,
        SlotContainerMode.MAPPING: ContainerType.MAPPING,
    }
    is_union = len(sub_ranges) > 1
    single_class_range = len(all_ranges) == 1 and all_ranges[0] in class_names

    if is_union:
        type_name = f"{underscore(uncamelcase(cls.name))}_utl::{get_name(s)}_range"
    else:
        type_name = rust_type_of(s.range)

    return RustRange(
        optional=not s.required,
        has_default=not (s.required or False) or (s.multivalued or False),
        containerType=container_types.get(container_mode),
        child_ranges=sub_ranges if is_union else None,
        box_needed=inline_mode == SlotInlineMode.INLINE and can_contain_reference_to_class(s, cls, sv),
        is_class_range=single_class_range,
        is_reference=by_reference,
        has_class_subtypes=has_real_subtypes(sv, all_ranges[0]) if single_class_range else False,
        type_=type_name,
    )
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
def protect_name(v: str) -> str:
    """
    Return ``v`` with a trailing underscore if it is one of ``PROTECTED_NAMES``,
    otherwise return it unchanged.
    """
    return f"{v}_" if v in PROTECTED_NAMES else v
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def get_name(e: Union[ClassDefinition, SlotDefinition, EnumDefinition, PermissibleValue, TypeDefinition]) -> str:
    """
    Return the Rust identifier for a schema element.

    Classes and enums use ``camelcase(name)``; permissible values use
    ``camelcase(text)``; slots and types use ``underscore(name)``. The result is
    passed through :func:`protect_name`, so protected names get a trailing underscore.

    Raises:
        ValueError: if ``e`` is not one of the supported definition kinds.
    """
    if isinstance(e, (ClassDefinition, EnumDefinition)):
        name = camelcase(e.name)
    elif isinstance(e, PermissibleValue):
        name = camelcase(e.text)
    elif isinstance(e, (SlotDefinition, TypeDefinition)):
        name = underscore(e.name)
    else:
        # The previous message only mentioned slots and classes, although enums,
        # permissible values and types are supported too.
        raise ValueError(
            "Can only get the name from a class, slot, enum, permissible value or type definition, "
            f"not {type(e).__name__}"
        )

    return protect_name(name)
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
@dataclass
|
|
375
|
+
class RustGenerator(Generator, LifecycleMixin):
|
|
376
|
+
"""
|
|
377
|
+
Generate rust types from a linkml schema
|
|
378
|
+
"""
|
|
379
|
+
|
|
380
|
+
generatorname = "rustgenerator"
|
|
381
|
+
generatorversion = "0.0.2"
|
|
382
|
+
valid_formats = ["rust"]
|
|
383
|
+
file_extension = "rs"
|
|
384
|
+
crate_name: Optional[str] = None
|
|
385
|
+
|
|
386
|
+
pyo3: bool = True
|
|
387
|
+
"""Generate pyO3 bindings for the rust defs"""
|
|
388
|
+
pyo3_version: str = ">=0.21.1"
|
|
389
|
+
serde: bool = True
|
|
390
|
+
"""Generate serde derive serialization/deserialization attributes"""
|
|
391
|
+
stubgen: bool = True
|
|
392
|
+
"""Generate pyo3-stub-gen instrumentation alongside PyO3 bindings"""
|
|
393
|
+
handwritten_lib: bool = False
|
|
394
|
+
"""Place generated sources under src/generated and leave src/lib.rs for user code"""
|
|
395
|
+
mode: RUST_MODES = "crate"
|
|
396
|
+
"""Generate a cargo.toml file"""
|
|
397
|
+
output: Optional[Path] = None
|
|
398
|
+
"""
|
|
399
|
+
* If ``mode == "crate"`` , a directory to contain the generated crate
|
|
400
|
+
* If ``mode == "file"`` , a file with a ``.rs`` extension
|
|
401
|
+
|
|
402
|
+
If output is not provided at object instantiation,
|
|
403
|
+
it must be provided on a call to :meth:`.serialize`
|
|
404
|
+
"""
|
|
405
|
+
|
|
406
|
+
_environment: Optional[Environment] = None
|
|
407
|
+
|
|
408
|
+
def __post_init__(self):
    """Build the SchemaView for ``self.schema``, then run the base Generator initialisation."""
    view = SchemaView(self.schema)
    self.schemaview: SchemaView = view
    super().__post_init__()
|
|
411
|
+
|
|
412
|
+
def _select_root_class(self, class_defs: list[ClassDefinition]) -> Optional[ClassDefinition]:
    """Return the first schema-local, non-mixin class in ``class_defs`` marked ``tree_root``.

    A class is local when its ``from_schema`` equals the schema's ``id``. When the
    schema has no ``id``, a class is local when its ``from_schema`` is ``None``.
    Returns ``None`` when no such class exists.
    """
    schema_id = getattr(self.schemaview.schema, "id", None)
    for candidate in class_defs:
        if schema_id is None:
            local = candidate.from_schema is None
        else:
            local = candidate.from_schema == schema_id
        if not local or getattr(candidate, "mixin", False):
            continue
        if getattr(candidate, "tree_root", False):
            return candidate
    return None
|
|
429
|
+
|
|
430
|
+
def generate_type(self, type_: TypeDefinition) -> TypeResult:
    """
    Generate a Rust type alias for a LinkML type.

    The alias target is the Rust type mapped from the type's ``base``. The
    ``before_generate_type``/``after_generate_type`` lifecycle hooks wrap generation.
    """
    sv = self.schemaview
    type_ = self.before_generate_type(type_, sv)
    alias = RustTypeAlias(
        name=get_name(type_),
        type_=get_rust_type(type_.base, sv, self.pyo3),
        pyo3=self.pyo3,
        stubgen=self.stubgen,
    )
    result = TypeResult(source=type_, type_=alias, imports=self.get_imports(type_))
    return self.after_generate_type(result, sv)
|
|
444
|
+
|
|
445
|
+
def generate_enum(self, enum: EnumDefinition) -> EnumResult:
    """
    Generate a Rust enum from a LinkML enum.

    Each permissible value becomes a variant named ``get_name(pv)``. The variant's
    text is the permissible value's ``text``, falling back to its key in
    ``permissible_values``. Lifecycle hooks wrap generation.
    """
    sv = self.schemaview
    enum = self.before_generate_enum(enum, sv)
    variants = []
    for key, permissible in enum.permissible_values.items():
        variants.append(RustEnumItem(variant=get_name(permissible), text=permissible.text or key))
    rust_enum = RustEnum(
        name=get_name(enum),
        items=variants,
        pyo3=self.pyo3,
        serde=self.serde,
        stubgen=self.stubgen,
    )
    return self.after_generate_enum(EnumResult(source=enum, enum=rust_enum), sv)
|
|
466
|
+
|
|
467
|
+
def generate_slot(self, slot: SlotDefinition) -> SlotResult:
    """
    Generate a top-level slot as a Rust type alias of its range type.

    Records whether the slot is multivalued and whether its range is a class, and
    attaches any crate imports its range type needs. Lifecycle hooks wrap generation.
    """
    sv = self.schemaview
    slot_def = self.before_generate_slot(slot, sv)
    is_class_range = slot_def.range in sv.all_classes()
    rust_type = get_rust_type(slot_def.range, sv, self.pyo3)

    alias = RustTypeAlias(
        name=get_name(slot_def),
        type_=rust_type,
        multivalued=slot_def.multivalued,
        pyo3=self.pyo3,
        class_range=is_class_range,
        stubgen=self.stubgen,
    )
    result = SlotResult(source=slot_def, slot=alias, imports=self.get_imports(slot_def))
    return self.after_generate_slot(result, sv)
|
|
489
|
+
|
|
490
|
+
def generate_class(self, cls: ClassDefinition) -> ClassResult:
    """
    Generate a class as a Rust struct, together with its per-class helper module.

    Induced slots become struct properties via :meth:`generate_attribute`. A slot
    may have a union of ranges on this class or on any of its descendants. For each
    such slot, a :class:`SlotRangeAsUnion` is added to the class module so that one
    canonical union enum lives with the base class. The ``before/after_generate_class``
    and ``before/after_generate_slots`` hooks are applied, and the imports of every
    attribute are merged into the result.
    """
    sv = self.schemaview
    cls = self.before_generate_class(cls, sv)
    attrs = [sv.induced_slot(slot_name, cls.name) for slot_name in sv.class_slots(cls.name)]
    attrs = self.before_generate_slots(attrs, sv)

    def _ranges_across_descendants(attr: SlotDefinition) -> list:
        # The slot's own ranges, followed by any extra range introduced by a
        # descendant's version of the slot (deduplicated, first-seen order).
        found = list(sv.slot_range_as_union(attr))
        for descendant in sv.class_descendants(cls.name):
            refined = sv.induced_slot(attr.name, descendant)
            if refined is None:
                continue
            for candidate in sv.slot_range_as_union(refined):
                if candidate not in found:
                    found.append(candidate)
        return found

    union_ranges = []
    for attr in attrs:
        ranges = _ranges_across_descendants(attr)
        if len(ranges) <= 1:
            continue
        union_ranges.append(
            SlotRangeAsUnion(
                slot_name=get_name(attr),
                ranges=[get_rust_type(r, sv, True) for r in ranges],
                stubgen=self.stubgen,
            )
        )

    class_module = RustClassModule(
        class_name=get_name(cls),
        class_name_snakecase=underscore(uncamelcase(cls.name)),
        slot_ranges=union_ranges,
        stubgen=self.stubgen,
    )

    attribute_results = self.after_generate_slots([self.generate_attribute(attr, cls) for attr in attrs], sv)

    class_names = sv.all_classes()
    unsendable = any(attr.range in class_names for attr in attrs)
    struct = RustStruct(
        name=get_name(cls),
        properties=[result.attribute for result in attribute_results],
        special_case_enabled=sv.get_uri(cls, expand=True).startswith("https://w3id.org/linkml"),
        generate_merge=MERGE_ANNOTATION in cls.annotations,
        unsendable=unsendable,
        pyo3=self.pyo3,
        serde=self.serde,
        stubgen=self.stubgen,
        as_key_value=self.generate_class_as_key_value(cls),
        struct_or_subtype_enum=self.gen_struct_or_subtype_enum(cls),
        class_module=class_module,
    )

    merged = ClassResult(source=cls, cls=struct)
    # Fold each attribute's imports into the class result.
    for attribute_result in attribute_results:
        merged = merged.merge(attribute_result)

    return self.after_generate_class(merged, sv)
|
|
552
|
+
|
|
553
|
+
def gen_struct_or_subtype_enum(self, cls: ClassDefinition) -> Optional[RustStructOrSubtypeEnum]:
    """
    Build the ``<Class>OrSubtype`` enum for a class that has real subtypes.

    If the class has a type designator slot, the accepted designator values of every
    descendant are recorded. The key type defaults to ``String``; when the class has
    a key/identifier slot, it is that slot's Rust type. Returns ``None`` when the
    class has no descendants other than itself.
    """
    sv = self.schemaview
    descendants = class_real_descendants(sv, cls.name)
    designator = sv.get_type_designator_slot(cls.name)
    if not descendants:
        return None

    designator_values = {}
    if designator is not None:
        designator_values = {
            d: get_accepted_type_designator_values(sv, designator, sv.get_class(d)) for d in descendants
        }

    key_slot = get_key_or_identifier_slot(cls, sv)
    key_type = "String" if key_slot is None else get_rust_type(key_slot.range, sv, self.pyo3)

    return RustStructOrSubtypeEnum(
        enum_name=f"{get_name(cls)}OrSubtype",
        struct_names=[get_name(sv.get_class(d)) for d in descendants],
        type_designator_name=get_name(designator) if designator else None,
        as_key_value=key_slot is not None,
        type_designators=designator_values,
        key_property_type=key_type,
    )
|
|
576
|
+
|
|
577
|
+
def generate_class_as_key_value(self, cls: ClassDefinition) -> Optional[AsKeyValue]:
    """
    Describe ``cls`` as a key/value pair when it has that shape.

    The class qualifies when exactly one induced slot is an identifier or key and at
    least one other slot is single-valued. The first such slot becomes the value.
    ``can_convert_from_primitive`` is set when the value is the only non-key slot
    and is not an inlined class. ``can_convert_from_empty`` is set when no
    single-valued non-key slot is required. Returns ``None`` when there is no key,
    more than one identifier/key slot, or no single-valued non-key slot.
    """
    sv = self.schemaview
    attrs = [sv.induced_slot(slot_name, cls.name) for slot_name in sv.class_slots(cls.name)]

    key_attr = None
    non_key_attrs = []
    single_valued = []
    required_single_valued = []
    for attr in attrs:
        if attr.identifier or attr.key:
            if key_attr is not None:
                # More than one identifier/key slot: ambiguous, don't know what to do.
                return None
            key_attr = attr
            continue
        non_key_attrs.append(attr)
        if attr.multivalued:
            continue
        single_valued.append(attr)
        if attr.required:
            required_single_valued.append(attr)

    # No key/identifier, or nothing single-valued to act as the value.
    if key_attr is None or not single_valued:
        return None

    value_attr = single_valued[0]
    # value_attr is single-valued by construction.
    simple_dict_possible = len(non_key_attrs) == 1 and (
        value_attr.range not in sv.all_classes() or not bool(getattr(value_attr, "inlined", False))
    )
    return AsKeyValue(
        name=get_name(cls),
        key_property_name=get_name(key_attr),
        key_property_type=get_rust_type(key_attr.range, sv, self.pyo3),
        value_property_name=get_name(value_attr),
        value_property_type=get_rust_type(value_attr.range, sv, self.pyo3),
        can_convert_from_primitive=simple_dict_possible,
        can_convert_from_empty=not required_single_valued,
        value_property_optional=not bool(value_attr.required),
        serde=self.serde,
        pyo3=self.pyo3,
        stubgen=self.stubgen,
    )
|
|
629
|
+
|
|
630
|
+
def generate_attribute(self, attr: SlotDefinition, cls: ClassDefinition) -> AttributeResult:
    """
    Generate one induced slot of ``cls`` as a Rust struct property.

    Combines the slot's container/inline modes, its ``RustRange`` and whether the
    range class can be represented as a key/value pair. An explicit alias is kept
    only when it differs from the generated Rust name. Lifecycle hooks wrap generation.
    """
    sv = self.schemaview
    attr = self.before_generate_slot(attr, sv)
    range_is_class = attr.range in sv.all_classes()
    container_mode, inline_mode = determine_slot_mode(attr, sv)
    range_info = get_rust_range_info(cls, attr, sv)

    rust_name = get_name(attr)
    # NOTE(review): without an explicit alias, a protected name such as ``type`` is
    # emitted as ``type_`` with alias None -- confirm templates map it back for I/O.
    alias = None
    if attr.alias is not None and attr.alias != rust_name:
        alias = attr.alias

    is_key_value = range_is_class and self.generate_class_as_key_value(sv.get_class(attr.range)) is not None

    prop = RustProperty(
        name=rust_name,
        inline_mode=inline_mode.value,
        alias=alias,
        generate_merge=MERGE_ANNOTATION in cls.annotations,
        container_mode=container_mode.value,
        type_=range_info,
        required=bool(attr.required),
        multivalued=bool(attr.multivalued),
        is_key_value=is_key_value,
        pyo3=self.pyo3,
        serde=self.serde,
        stubgen=self.stubgen,
    )
    result = AttributeResult(source=attr, attribute=prop, imports=self.get_imports(attr))
    return self.after_generate_slot(result, sv)
|
|
660
|
+
|
|
661
|
+
def generate_cargo(self, imports: Imports) -> RustCargo:
    """
    Build the Cargo.toml model.

    The crate name is ``crate_name`` when set, otherwise the schema name. The version
    is the schema version, or "0.0.0" when the schema has none.
    """
    schema = self.schemaview.schema
    crate_version = "0.0.0" if schema.version is None else schema.version
    package_name = self.crate_name
    if package_name is None:
        package_name = schema.name
    return RustCargo(
        name=package_name,
        version=crate_version,
        imports=imports,
        pyo3_version=self.pyo3_version,
        pyo3=self.pyo3,
        serde=self.serde,
        stubgen=self.stubgen,
    )
|
|
675
|
+
|
|
676
|
+
def generate_pyproject(self) -> RustPyProject:
    """
    Build the template model for the ``pyproject.toml`` of a pyo3 rust crate.

    Uses the schema name, and the schema version (``"0.0.0"`` when unset).
    """
    schema = self.schemaview.schema
    schema_version = "0.0.0" if schema.version is None else schema.version
    return RustPyProject(name=schema.name, version=schema_version)
|
|
682
|
+
|
|
683
|
+
def get_imports(self, element: Union[SlotDefinition, TypeDefinition]) -> Imports:
    """
    Return the extra Rust imports required by a slot's range or a type's base.

    The element is mapped to its Rust type; if that type has an entry in
    ``RUST_IMPORTS`` the corresponding import is returned, otherwise no imports.

    Raises:
        TypeError: if ``element`` is neither a SlotDefinition nor a TypeDefinition.
    """
    if isinstance(element, SlotDefinition):
        linkml_name = element.range
    elif isinstance(element, TypeDefinition):
        linkml_name = element.base
    else:
        raise TypeError("Must be a slot or type definition")

    rust_type = get_rust_type(linkml_name, self.schemaview, self.pyo3)
    if rust_type not in RUST_IMPORTS:
        return Imports()
    return Imports(imports=[RUST_IMPORTS[rust_type]])
|
|
695
|
+
|
|
696
|
+
@overload
def render(self, mode: Literal["file"]) -> FileResult: ...

@overload
def render(self, mode: Literal["crate"]) -> CrateResult: ...

@overload
def render(self, mode: None = None) -> Union[FileResult, CrateResult]: ...

def render(self, mode: Optional[RUST_MODES] = None) -> Union[FileResult, CrateResult]:
    """
    Render the template model of a rust file before serializing

    Types, enums, slots and classes are each passed through their
    ``before_generate_*`` hook, generated, then passed through their
    ``after_generate_*`` hook.

    Args:
        mode (:class:`.RUST_MODES`, optional): Override the instance-level generation mode

    Returns:
        When ``mode == "crate"``: a :class:`CrateResult` holding the cargo and pyproject
        models, the library file, helper modules (serde utils, poly traits, poly
        containers, and stub utils when ``stubgen`` is set) and, with ``stubgen``, the
        stub generator binary. Otherwise a :class:`FileResult` whose single file inlines
        the serde utilities and omits the poly modules.
    """
    if mode is None:
        mode = self.mode

    sv = self.schemaview

    types = list(sv.all_types(imports=True).values())
    types = self.before_generate_types(types, sv)
    types = [self.generate_type(t) for t in types]
    types = self.after_generate_types(types, sv)

    enums = list(sv.all_enums(imports=True).values())
    enums = self.before_generate_enums(enums, sv)
    enums = [self.generate_enum(e) for e in enums]
    enums = self.after_generate_enums(enums, sv)

    slots = [sv.induced_slot(s) for s in sv.all_slots()]
    slots = self.before_generate_slots(slots, sv)
    slots = [self.generate_slot(s) for s in slots]
    slots = self.after_generate_slots(slots, sv)

    class_defs = [sv.induced_class(c) for c in sv.all_classes(ordered_by=OrderedBy.INHERITANCE)]
    root_class_def = self._select_root_class(class_defs)
    root_struct_name = get_name(root_class_def) if root_class_def is not None else None
    need_merge_crate = any(MERGE_ANNOTATION in c.annotations for c in class_defs)

    classes = self.before_generate_classes(class_defs, sv)
    classes = [self.generate_class(c) for c in classes]
    classes = self.after_generate_classes(classes, sv)

    imports = DEFAULT_IMPORTS.model_copy()
    imports += PYTHON_IMPORTS
    imports += SERDE_IMPORTS
    if self.stubgen:
        imports += STUBGEN_IMPORTS
    if need_merge_crate:
        imports += MERGE_IMPORTS
    for result in [*enums, *slots, *classes]:
        imports += result.imports

    file = RustFile(
        name=sv.schema.name,
        imports=imports,
        slots=[t.slot for t in slots],
        types=[t.type_ for t in types],
        enums=[e.enum for e in enums],
        structs=[c.cls for c in classes],
        pyo3=self.pyo3,
        serde=self.serde,
        stubgen=self.stubgen,
        handwritten_lib=self.handwritten_lib,
        root_struct_name=root_struct_name,
    )

    if mode == "crate":
        # The poly traits are only emitted as a crate module, so they are built here
        # rather than unconditionally (building them walks every class's descendants).
        poly_traits = [self.gen_poly_trait(sv.get_class(c)) for c in sv.all_classes(ordered_by=OrderedBy.INHERITANCE)]
        extra_files = {
            "serde_utils": SerdeUtilsFile(),
            "poly": PolyFile(imports=imports, traits=poly_traits),
            "poly_containers": PolyContainersFile(),
        }
        if self.stubgen:
            extra_files["stub_utils"] = StubUtilsFile()
        cargo = self.generate_cargo(imports)
        pyproject = self.generate_pyproject()
        bin_files = {}
        if self.stubgen:
            bin_files["bin/stub_gen"] = StubGenBin(crate_name=cargo.name, stubgen=self.stubgen)
        return CrateResult(
            cargo=cargo,
            file=file,
            pyproject=pyproject,
            source=sv.schema,
            extra_files=extra_files,
            bin_files=bin_files,
        )

    # Single file: inline serde utils, and skip poly modules
    file.inline_serde_utils = True
    file.emit_poly = False
    file.serde_utils = SerdeUtilsFile()
    return FileResult(file=file, source=sv.schema)
|
|
797
|
+
|
|
798
|
+
def gen_poly_trait(self, cls: ClassDefinition) -> PolyTrait:
    """Build the polymorphism trait model for ``cls``.

    The trait covers only the slots ``cls`` adds itself: induced slots that are not
    also induced on its ``is_a`` parent or on any of its mixins.

    Returns a :class:`PolyTrait` holding:

    * ``attrs`` -- one property per such slot, with its range on ``cls`` and the range
      promoted across ``cls``'s descendants;
    * ``impls`` -- one implementation per class in ``class_descendants(cls.name)``;
    * ``subtypes`` -- for each of those classes that has real subtypes, an
      implementation on its ``<Name>OrSubtype`` enum that matches over those subtypes.
    """
    sv = self.schemaview
    class_name = get_name(cls)

    superclass_names = []
    if cls.is_a is not None:
        superclass_names.append(cls.is_a)
    superclass_names.extend(cls.mixins)
    superclasses = [sv.get_class(sn) for sn in superclass_names if sn is not None]

    # Drop slots already provided by a parent or mixin; a set keeps this filter linear.
    inherited_names = set()
    for superclass in superclasses:
        inherited_names.update(sc.name for sc in sv.class_induced_slots(superclass.name))
    attribs = [a for a in sv.class_induced_slots(cls.name) if a.name not in inherited_names]

    # The promoted range depends only on (cls, slot), not on which descendant is being
    # implemented, so compute it once per slot instead of several times per descendant.
    promoted_ranges = [self.get_rust_range_info_across_descendants(cls, a) for a in attribs]

    rust_attribs = [
        PolyTraitProperty(name=get_name(a), range=get_rust_range_info(cls, a, sv), promoted_range=promoted)
        for a, promoted in zip(attribs, promoted_ranges)
    ]

    impls = []
    subtype_impls = []
    for sc in sv.class_descendants(cls.name):
        sco = sv.get_class(sc)
        struct_name = get_name(sco)
        # First definition wins, matching a linear search over the induced slots.
        slots_by_name = {}
        for slot in sv.class_induced_slots(sco.name):
            slots_by_name.setdefault(slot.name, slot)

        ptis = [
            PolyTraitPropertyImpl(
                name=get_name(a),
                range=get_rust_range_info(sco, slots_by_name.get(a.name), sv),
                definition_range=promoted,
                trait_range=promoted,
                struct_name=struct_name,
            )
            for a, promoted in zip(attribs, promoted_ranges)
        ]
        impls.append(PolyTraitImpl(name=class_name, struct_name=struct_name, attrs=ptis))

        if has_real_subtypes(sv, sc):
            enum_name = f"{struct_name}OrSubtype"
            cases = [get_name(sv.get_class(x)) for x in class_real_descendants(sv, sc)]
            matches = [
                PolyTraitPropertyMatch(name=get_name(a), range=promoted, cases=cases, struct_name=enum_name)
                for a, promoted in zip(attribs, promoted_ranges)
            ]
            subtype_impls.append(PolyTraitImplForSubtypeEnum(name=class_name, enum_name=enum_name, attrs=matches))

    return PolyTrait(
        name=class_name,
        impls=impls,
        attrs=rust_attribs,
        superclass_names=[get_name(scla) for scla in superclasses],
        subtypes=subtype_impls,
    )
|
|
864
|
+
|
|
865
|
+
def serialize(self, output: Optional[Path] = None, mode: Optional[RUST_MODES] = None, force: bool = False) -> str:
    """
    Serialize a schema to a rust crate or file.

    Args:
        output (Path, optional): A ``.rs`` file if in ``file`` mode,
            directory otherwise.
        mode (:class:`.RUST_MODES`, optional): Override the instance-level generation mode.
        force (bool): If the output already exists, overwrite it.
            Otherwise raise a :class:`FileExistsError`

    Returns:
        str: In ``file`` mode, the text written to ``output`` (ending in exactly one
        newline); in ``crate`` mode, the value returned by :meth:`write_crate`.
    """
    if mode is None:
        mode = self.mode

    output = self._validate_output(output, mode, force)
    rendered = self.render(mode=mode)
    if mode == "crate":
        return self.write_crate(output, rendered, force)

    serialized = rendered.file.render(self.template_environment).rstrip("\n") + "\n"
    # Rust source files must be UTF-8; don't depend on the platform's default encoding.
    with open(output, "w", encoding="utf-8") as f:
        f.write(serialized)
    return serialized
|
|
889
|
+
|
|
890
|
+
def get_rust_range_info_across_descendants(self, cls: ClassDefinition, s: SlotDefinition) -> RustRange:
    """Compute a RustRange representing the union of a slot's ranges across a class and all its descendants.

    The slot ``s`` is inspected as induced on ``cls``, on every class returned by
    ``class_descendants(cls.name)``, and on every class listing ``cls`` among its mixins
    (plus that class's descendants).

    Returns:
        * a union range -- ``type_`` of the form ``<cls>_utl::<slot>_range`` with one
          child range per observed Rust type -- when more than one Rust type is observed;
        * a single-type range when exactly one Rust type is observed;
        * ``get_rust_range_info(cls, s, ...)`` when no type could be observed concretely.

    The container type is derived from ``s``. The range is optional unless the slot is
    required in every inspected definition.
    """
    sv = self.schemaview
    # Rust type names treated as non-class ranges.
    scalar_types = ("String", "bool", "f64", "isize")
    # Rust type names for all observed ranges, in first-seen order, and the source class
    # name (if any) responsible for each, so subtype presence is checked against the
    # metamodel using class names rather than rust type identifiers.
    type_names: list[str] = []
    rust_to_class: dict[str, Optional[str]] = {}

    def add_type(tname: str, source_class: Optional[str]) -> None:
        rust_to_class[tname] = source_class
        if tname not in type_names:
            type_names.append(tname)

    def add_for_slot(slot_def: SlotDefinition) -> None:
        for r in sv.slot_range_as_union(slot_def):
            if r not in sv.all_classes():
                add_type(get_rust_type(r, sv, True), None)
            elif r in {"Anything", "AnyValue"}:
                # Treat Anything/AnyValue as inline so promoted unions include that variant.
                add_type(get_rust_type(r, sv, True), r)
            elif slot_def.inlined is True or slot_def.inlined_as_list is True:
                add_type(get_rust_type(r, sv, True), r)
            elif slot_def.inlined is False and (slot_def.inlined_as_list is False or slot_def.inlined_as_list is None):
                # Explicitly non-inlined class ranges are represented as String here.
                add_type("String", None)
            # Otherwise inlining is unknown at this definition; don't add a guess.

    def related_slot_defs() -> list[SlotDefinition]:
        """Induced definitions of ``s`` on cls, its descendants, and its mixin users (plus their descendants)."""
        found: list[SlotDefinition] = []

        def add(class_name: str) -> None:
            slot_def = sv.induced_slot(s.name, class_name)
            if slot_def is not None:
                found.append(slot_def)

        add(cls.name)
        for d in sv.class_descendants(cls.name):
            add(d)
        # If cls is a mixin, include classes that use it and their descendants.
        try:
            all_classes = list(sv.all_classes())
        except Exception:
            # Best effort: mixin users are an additional, optional source of ranges.
            all_classes = []
        for cname in all_classes:
            cdef = sv.get_class(cname)
            if cdef is None or cls.name not in (cdef.mixins or []):
                continue
            add(cname)
            for dd in sv.class_descendants(cname):
                add(dd)
        return found

    related = related_slot_defs()
    for slot_def in related:
        add_for_slot(slot_def)

    container_mode, _ = determine_slot_mode(s, sv)
    if container_mode == SlotContainerMode.LIST:
        container_type = ContainerType.LIST
    elif container_mode == SlotContainerMode.MAPPING:
        container_type = ContainerType.MAPPING
    else:
        container_type = None
    # Optional unless every inspected definition marks the slot as required.
    base_optional = not all(bool(sd.required) for sd in related)
    has_default = base_optional or (s.multivalued or False)

    if len(type_names) > 1:
        return RustRange(
            optional=base_optional,
            has_default=has_default,
            containerType=container_type,
            child_ranges=[RustRange(type_=t, is_class_range=t not in scalar_types) for t in type_names],
            is_class_range=False,
            is_reference=False,
            type_=underscore(uncamelcase(cls.name)) + "_utl::" + get_name(s) + "_range",
        )

    if not type_names:
        # Nothing was observed concretely: fall back to the per-class range info.
        return get_rust_range_info(cls, s, sv)

    single = type_names[0]
    single_src_class = rust_to_class.get(single)
    return RustRange(
        optional=base_optional,
        has_default=has_default,
        containerType=container_type,
        child_ranges=None,
        is_class_range=single not in scalar_types,
        is_reference=False,
        has_class_subtypes=(has_real_subtypes(sv, single_src_class) if single_src_class is not None else False),
        type_=single,
    )
|
|
1056
|
+
|
|
1057
|
+
def write_crate(
    self, output: Optional[Path] = None, rendered: Union[FileResult, CrateResult] = None, force: bool = False
) -> str:
    """Write a rendered crate to disk and return the generated Rust library source.

    Files written under ``output``:

    * ``Cargo.toml`` and ``pyproject.toml``;
    * the generated library at ``src/lib.rs``, or ``src/generated/mod.rs`` when
      ``handwritten_lib`` is set;
    * each extra file as ``<name>.rs`` next to the generated library;
    * each bin template at ``src/<rel_path>.rs``;
    * with ``handwritten_lib``, a ``src/lib.rs`` shim, written only if that file does not
      already exist so hand-written code is never overwritten.

    When ``rendered`` is None the schema is rendered in crate mode first.
    """
    crate_root = self._validate_output(output, mode="crate", force=force)
    result = self.render(mode="crate") if rendered is None else rendered

    for manifest_name, manifest in (("Cargo.toml", result.cargo), ("pyproject.toml", result.pyproject)):
        manifest_text = manifest.render(self.template_environment)
        self._write_text_file(crate_root / manifest_name, manifest_text, crate_root=crate_root)

    lib_source = result.file.render(self.template_environment)
    src_dir = crate_root / "src"
    src_dir.mkdir(exist_ok=True)
    if not self.handwritten_lib:
        module_dir = src_dir
        lib_path = src_dir / "lib.rs"
    else:
        module_dir = src_dir / "generated"
        module_dir.mkdir(exist_ok=True)
        lib_path = module_dir / "mod.rs"
    self._write_text_file(lib_path, lib_source, crate_root=crate_root)

    for stem, template in result.extra_files.items():
        text = template.render(self.template_environment)
        target = self._safe_subpath(module_dir, f"{stem}.rs")
        self._write_text_file(target, text, crate_root=crate_root)

    bin_templates = getattr(result, "bin_files", None)
    if bin_templates:
        for rel_path, template in bin_templates.items():
            text = template.render(self.template_environment)
            target = self._safe_subpath(src_dir, rel_path).with_suffix(".rs")
            self._write_text_file(target, text, crate_root=crate_root)

    shim_path = src_dir / "lib.rs"
    if self.handwritten_lib and not shim_path.exists():
        root_struct_name = getattr(result.file, "root_struct_name", None)
        snake_fn = underscore(uncamelcase(root_struct_name)) if root_struct_name else None
        shim_model = RustLibShim(
            module_name=result.file.name,
            pyo3=self.pyo3,
            serde=self.serde,
            stubgen=self.stubgen,
            handwritten_lib=self.handwritten_lib,
            root_struct_name=root_struct_name,
            root_struct_fn_snake=snake_fn,
        )
        self._write_text_file(shim_path, shim_model.render(self.template_environment), crate_root=crate_root)

    return lib_source
|
|
1115
|
+
|
|
1116
|
+
def _validate_output(
|
|
1117
|
+
self, output: Optional[Path] = None, mode: Optional[RUST_MODES] = None, force: bool = False
|
|
1118
|
+
) -> Path:
|
|
1119
|
+
"""Raise a ValueError if given a dir when in file mode or vice versa"""
|
|
1120
|
+
if output is None:
|
|
1121
|
+
if self.output is None:
|
|
1122
|
+
raise ValueError("Must provide an output if generator doesn't already have one")
|
|
1123
|
+
else:
|
|
1124
|
+
output = Path(self.output)
|
|
1125
|
+
else:
|
|
1126
|
+
output = Path(output)
|
|
1127
|
+
|
|
1128
|
+
if mode == "file":
|
|
1129
|
+
assert output.suffix == ".rs", "Output must be a rust file in file mode"
|
|
1130
|
+
if not force and output.exists():
|
|
1131
|
+
raise FileExistsError(f"{output} already exists and force is False! pass force=True to overwrite")
|
|
1132
|
+
output.parent.mkdir(exist_ok=True, parents=True)
|
|
1133
|
+
elif mode == "crate":
|
|
1134
|
+
if not force and len([d for d in output.iterdir()]) != 0:
|
|
1135
|
+
raise FileExistsError(
|
|
1136
|
+
f"{output} already exists, is not empty, and force is False! pass force=True to overwrite"
|
|
1137
|
+
)
|
|
1138
|
+
output.mkdir(exist_ok=True, parents=True)
|
|
1139
|
+
else:
|
|
1140
|
+
raise ValueError(f"Invalid generation mode: {mode}")
|
|
1141
|
+
|
|
1142
|
+
return output
|
|
1143
|
+
|
|
1144
|
+
def _safe_subpath(self, base: Path, relative: Union[str, Path]) -> Path:
|
|
1145
|
+
"""Return a path nested under base, validating it does not escape."""
|
|
1146
|
+
|
|
1147
|
+
rel_path = Path(relative)
|
|
1148
|
+
if rel_path.is_absolute():
|
|
1149
|
+
raise ValueError(f"Relative path expected, got absolute path: {relative}")
|
|
1150
|
+
|
|
1151
|
+
if not rel_path.parts:
|
|
1152
|
+
raise ValueError("Relative path must contain at least one segment")
|
|
1153
|
+
|
|
1154
|
+
for part in rel_path.parts:
|
|
1155
|
+
if part in (".", ".."):
|
|
1156
|
+
raise ValueError(f"Invalid path segment: {part}")
|
|
1157
|
+
if "/" in part or "\\" in part:
|
|
1158
|
+
raise ValueError(f"Path segment must not contain separators: {part}")
|
|
1159
|
+
|
|
1160
|
+
candidate = base / rel_path
|
|
1161
|
+
base_resolved = base.resolve()
|
|
1162
|
+
try:
|
|
1163
|
+
candidate.resolve().relative_to(base_resolved)
|
|
1164
|
+
except ValueError as exc: # pragma: no cover - defensive
|
|
1165
|
+
raise ValueError(f"Path {candidate} escapes base directory {base}") from exc
|
|
1166
|
+
|
|
1167
|
+
return candidate
|
|
1168
|
+
|
|
1169
|
+
def _write_text_file(self, path: Path, content: str, *, crate_root: Path) -> None:
|
|
1170
|
+
"""Normalize trailing newline, ensure parent dirs, and write text."""
|
|
1171
|
+
|
|
1172
|
+
base_resolved = crate_root.resolve()
|
|
1173
|
+
try:
|
|
1174
|
+
path.resolve().relative_to(base_resolved)
|
|
1175
|
+
except ValueError as exc:
|
|
1176
|
+
raise ValueError(f"Path {path} escapes crate root {crate_root}") from exc
|
|
1177
|
+
|
|
1178
|
+
normalized = content.rstrip("\n") + "\n"
|
|
1179
|
+
path.parent.mkdir(parents=True, exist_ok=True)
|
|
1180
|
+
path.write_text(normalized)
|
|
1181
|
+
|
|
1182
|
+
@property
def template_environment(self) -> Environment:
    """Template rendering environment, created on first access and cached on the instance."""
    env = self._environment
    if env is None:
        env = self._environment = RustTemplateModel.environment()
    return env
|