erdify 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- erdify/__init__.py +27 -0
- erdify/__main__.py +6 -0
- erdify/cli.py +114 -0
- erdify/config.py +41 -0
- erdify/generator.py +322 -0
- erdify/parser.py +519 -0
- erdify/py.typed +0 -0
- erdify-0.3.0.dist-info/METADATA +637 -0
- erdify-0.3.0.dist-info/RECORD +12 -0
- erdify-0.3.0.dist-info/WHEEL +4 -0
- erdify-0.3.0.dist-info/entry_points.txt +3 -0
- erdify-0.3.0.dist-info/licenses/LICENSE +21 -0
erdify/parser.py
ADDED
|
@@ -0,0 +1,519 @@
|
|
|
1
|
+
"""AST-based parser for SQLModel, SQLAlchemy, Pydantic and dataclass models."""
|
|
2
|
+
|
|
3
|
+
import ast
|
|
4
|
+
import re
|
|
5
|
+
import sys
|
|
6
|
+
from fnmatch import fnmatchcase
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Dict, List, Tuple
|
|
9
|
+
|
|
10
|
+
from .config import EntityInfo, EnumInfo, FieldInfo
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Identifiers are matched on their last dotted component, so ``enum.StrEnum``,
# ``sa.ForeignKey``, ``orm.Mapped`` etc. behave like their bare-imported forms.
_ENUM_BASE_NAMES = frozenset({"Enum", "IntEnum", "StrEnum", "IntFlag"})
_COLLECTION_TYPE_NAMES = frozenset(
    {
        "list",
        "List",
        "set",
        "Set",
        "frozenset",
        "FrozenSet",
        "Sequence",
        "MutableSequence",
        "AbstractSet",
        "MutableSet",
        "Collection",
        "Iterable",
    }
)
_RELATIONSHIP_FACTORIES = frozenset({"Relationship", "relationship"})  # SQLModel / SQLAlchemy
_COLUMN_FACTORIES = frozenset({"Field", "mapped_column"})  # SQLModel / SQLAlchemy
_DEFAULT_FACTORIES = _COLUMN_FACTORIES | {"field"}  # plus dataclasses.field
_NON_FIELD_ANNOTATIONS = frozenset({"ClassVar", "InitVar"})
_KEYLESS_SOURCES = frozenset({"pydantic", "dataclass"})

_SNAKE_CASE_WORD = re.compile(r"(.)([A-Z][a-z]+)")
_SNAKE_CASE_BOUNDARY = re.compile(r"([a-z0-9])([A-Z])")


def _terminal_name(node: ast.expr) -> str | None:
    """Return a Name's identifier or the last component of a dotted Attribute."""
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        return node.attr
    return None


def _parse_forward_ref(node: ast.expr) -> ast.expr:
    """Parse a string (forward-reference) annotation into an expression node.

    Non-string nodes, and strings that are not valid expressions, are returned
    unchanged.
    """
    if isinstance(node, ast.Constant) and isinstance(node.value, str):
        try:
            return ast.parse(node.value.strip(), mode="eval").body
        except (SyntaxError, ValueError):
            return node
    return node


def _is_none_literal(node: ast.expr) -> bool:
    """Return True if *node* is the literal ``None``."""
    return isinstance(node, ast.Constant) and node.value is None


def _union_members(node: ast.expr) -> List[ast.expr]:
    """Flatten ``X | Y``, ``Union[X, Y]`` and ``Optional[X]`` into their member nodes."""
    if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
        return _union_members(node.left) + _union_members(node.right)
    if isinstance(node, ast.Subscript):
        head = _terminal_name(node.value)
        if head == "Optional":
            return _union_members(node.slice) + [ast.Constant(value=None)]
        if head == "Union":
            elts = node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice]
            return [member for elt in elts for member in _union_members(elt)]
    return [node]


def _strip_optional(node: ast.expr) -> Tuple[ast.expr, bool]:
    """Remove ``None`` from a top-level union annotation.

    Returns ``(annotation without None, whether None was present)``. The node is
    returned untouched when it has no ``None`` member, or when it is only ``None``.
    """
    members = _union_members(node)
    kept = [member for member in members if not _is_none_literal(member)]
    nullable = len(kept) != len(members)
    if not nullable or not kept:
        return node, nullable
    stripped = kept[0]
    for member in kept[1:]:
        stripped = ast.BinOp(left=stripped, op=ast.BitOr(), right=member)
    return stripped, True


def _unwrap_annotated(node: ast.expr) -> ast.expr:
    """Return ``X`` for ``Annotated[X, ...]``; any other node is returned unchanged."""
    if isinstance(node, ast.Subscript) and _terminal_name(node.value) == "Annotated":
        if isinstance(node.slice, ast.Tuple) and node.slice.elts:
            return node.slice.elts[0]
    return node


def _resolve_relationship_target(annotation: ast.expr) -> Tuple[str, str]:
    """Reduce a relationship annotation to ``(target class name, "one" | "many")``.

    Peels off forward-reference quotes and ``Mapped[...]``/``Annotated[...]``
    wrappers, drops ``None`` from unions, and descends into collection types
    (``list``, ``set``, ``Sequence``, ...), which mark the relationship as "many".
    A qualified name such as ``models.User`` resolves to ``User``.
    """
    node = annotation
    many = False
    while True:
        node = _parse_forward_ref(node)
        if isinstance(node, ast.Subscript):
            head = _terminal_name(node.value)
            if head == "Mapped":
                node = node.slice
                continue
            if head == "Annotated":
                unwrapped = _unwrap_annotated(node)
                if unwrapped is not node:
                    node = unwrapped
                    continue
            if head in _COLLECTION_TYPE_NAMES:
                many = True
                node = node.slice
                continue
        stripped, nullable = _strip_optional(node)
        if nullable and stripped is not node:
            node = stripped
            continue
        break

    name = _terminal_name(node)
    if name is None:
        # Not a plain class reference (e.g. dict[str, X]); fall back to its source text.
        name = ast.unparse(node).replace('"', "").replace("'", "").strip()
    return name, ("many" if many else "one")


def _is_dotted_name(node: ast.expr) -> bool:
    """Return True for ``a``, ``a.b``, ``a.b.c``, ... (Names joined by attribute access)."""
    while isinstance(node, ast.Attribute):
        node = node.value
    return isinstance(node, ast.Name)


def _render_default(node: ast.expr) -> str | None:
    """Render a default-value expression for display, or None if it is not static.

    Literals (including negative numbers and simple containers) are rendered
    with ``repr``. Dotted names such as ``OrderStatus.PENDING`` are rendered as
    written. ``...`` (Ellipsis, "required" in Pydantic/SQLModel) is treated as
    no default.
    """
    if isinstance(node, ast.Attribute):
        return ast.unparse(node) if _is_dotted_name(node) else None
    try:
        value = ast.literal_eval(node)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return None
    return None if value is Ellipsis else repr(value)


def _is_non_field_annotation(annotation: ast.expr) -> bool:
    """Return True for ``ClassVar[...]`` / ``InitVar[...]``, which do not declare stored fields."""
    node = _parse_forward_ref(annotation)
    if isinstance(node, ast.Subscript):
        node = node.value
    return _terminal_name(node) in _NON_FIELD_ANNOTATIONS


class ASTDatabaseParser:
    """Parses database models using AST to extract schema information.

    Model files are only parsed, never imported or executed. Classes are
    discovered and classified purely from their syntax trees.
    """

    def __init__(
        self,
        database_path: Path,
        exclude_patterns: List[str] | None = None,
        infer_keys: bool = False,
    ):
        """Create a parser rooted at *database_path*.

        Args:
            database_path: Directory searched recursively for ``models.py`` files.
            exclude_patterns: Case-sensitive glob patterns matched against entity
                class names and table names; matching entities are dropped.
            infer_keys: For Pydantic/dataclass models, infer keys from field names
                (``id`` -> primary key, ``<x>_id`` -> foreign key to ``<x>.id``).
        """
        self.database_path = database_path
        self.exclude_patterns = exclude_patterns or []
        self.infer_keys = infer_keys
        self.entities: Dict[str, EntityInfo] = {}
        self.enums: Dict[str, EnumInfo] = {}
        # Every class definition seen in any parsed file (nested ones included), by name.
        self.all_classes: Dict[str, ast.ClassDef] = {}
        self.file_trees: Dict[Path, ast.Module] = {}

    def parse_all_models(self) -> Tuple[Dict[str, EntityInfo], Dict[str, EnumInfo]]:
        """Parse every ``models.py`` under ``database_path``.

        Files are visited in sorted path order. This makes the output order
        deterministic, and also which definition wins when two files declare a
        class with the same name. Files that cannot be read or parsed are
        reported on stderr and skipped.

        Returns:
            Tuple of (entities keyed by class name, enums keyed by class name).
        """
        # First pass: parse all files and collect class definitions
        for model_file in sorted(self.database_path.rglob("models.py")):
            try:
                # Passing bytes lets the compiler apply Python's own source decoding
                # (BOM, coding cookie) instead of the platform's default text encoding.
                tree = ast.parse(model_file.read_bytes(), filename=str(model_file))
            except (OSError, SyntaxError, ValueError, RecursionError) as e:
                print(f"Error parsing {model_file}: {e}", file=sys.stderr)
                continue
            self.file_trees[model_file] = tree
            for node in ast.walk(tree):
                if isinstance(node, ast.ClassDef):
                    self.all_classes[node.name] = node

        # Second pass: process enum and model classes (needs the full class map for inheritance)
        for class_node in self.all_classes.values():
            if self._is_enum_class(class_node):
                self._parse_enum_class(class_node)
                continue
            source = self._classify_source(class_node)
            if source is not None:
                self._parse_table_class(class_node, source)

        # Third pass: apply exclude patterns
        self._apply_exclude_patterns()

        return self.entities, self.enums

    def _apply_exclude_patterns(self) -> None:
        """Remove excluded entities and strip relationships pointing at them.

        A pattern matches an entity if its case-sensitive glob matches either
        the class name or the table name. Relationships in surviving entities
        that target an excluded entity are dropped so no dangling lines remain.
        """
        if not self.exclude_patterns:
            return

        def is_excluded(entity: EntityInfo) -> bool:
            return any(
                fnmatchcase(entity.name, pattern) or fnmatchcase(entity.table_name, pattern)
                for pattern in self.exclude_patterns
            )

        excluded_names = {name for name, e in self.entities.items() if is_excluded(e)}

        self.entities = {name: e for name, e in self.entities.items() if name not in excluded_names}

        for entity in self.entities.values():
            # rel[0] is the target class name
            entity.relationships = [
                rel for rel in entity.relationships if rel[0] not in excluded_names
            ]

    def _is_enum_class(self, class_node: ast.ClassDef) -> bool:
        """Check if a class directly subclasses a stdlib enum base.

        Recognised bases are ``Enum``, ``IntEnum``, ``StrEnum`` and ``IntFlag``,
        bare or qualified (``enum.StrEnum``).
        """
        return any(_terminal_name(base) in _ENUM_BASE_NAMES for base in class_node.bases)

    def _parse_enum_class(self, class_node: ast.ClassDef) -> None:
        """Record the public member names of an enum class (only if it has any)."""
        values: List[str] = []
        for node in class_node.body:
            if isinstance(node, ast.Assign):
                for target in node.targets:
                    if isinstance(target, ast.Name) and not target.id.startswith("_"):
                        values.append(target.id)

        if values:
            self.enums[class_node.name] = EnumInfo(name=class_node.name, values=values)

    def _classify_source(self, class_node: ast.ClassDef) -> str | None:
        """Classify which model framework a class belongs to.

        Returns one of "sqlmodel", "sqlalchemy", "pydantic", "dataclass", or
        None if the class is not a drawable entity.
        """
        # SQLModel: declared with table=True
        for keyword in class_node.keywords:
            if keyword.arg == "table" and isinstance(keyword.value, ast.Constant):
                if keyword.value.value:
                    return "sqlmodel"

        # SQLAlchemy 2.0: a concrete mapped class has __tablename__ and Mapped[...] columns.
        # Mixins/abstract bases (Mapped fields but no __tablename__) are excluded here so
        # they are not emitted as entities, but their fields are still inherited.
        if self._has_tablename(class_node) and self._has_mapped_field(class_node):
            return "sqlalchemy"

        # Plain @dataclass
        if self._is_dataclass(class_node):
            return "dataclass"

        # Pydantic: inherits from BaseModel (directly or transitively)
        if self._inherits_basemodel(class_node):
            return "pydantic"

        return None

    @staticmethod
    def _is_dataclass(class_node: ast.ClassDef) -> bool:
        """Check if a class is decorated with @dataclass (bare, qualified or called)."""
        for dec in class_node.decorator_list:
            target = dec.func if isinstance(dec, ast.Call) else dec
            if _terminal_name(target) == "dataclass":
                return True
        return False

    def _inherits_basemodel(
        self, class_node: ast.ClassDef, visited: set[str] | None = None
    ) -> bool:
        """Check if a class inherits from Pydantic's BaseModel, directly or via ancestors.

        Only ancestors defined in the parsed files can be followed; *visited*
        guards against inheritance cycles.
        """
        visited = visited if visited is not None else set()
        if class_node.name in visited:
            return False
        visited.add(class_node.name)

        for base in class_node.bases:
            if isinstance(base, ast.Name):
                if base.id == "BaseModel":
                    return True
                ancestor = self.all_classes.get(base.id)
                if ancestor is not None and self._inherits_basemodel(ancestor, visited):
                    return True
            elif isinstance(base, ast.Attribute) and base.attr == "BaseModel":
                return True
        return False

    @staticmethod
    def _tablename_values(class_node: ast.ClassDef) -> List[ast.expr | None]:
        """Return the value of every ``__tablename__`` assignment in the class body, in order.

        An annotation-only declaration (``__tablename__: str``) contributes None.
        """
        values: List[ast.expr | None] = []
        for node in class_node.body:
            if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
                if node.target.id == "__tablename__":
                    values.append(node.value)
            elif isinstance(node, ast.Assign):
                if any(
                    isinstance(target, ast.Name) and target.id == "__tablename__"
                    for target in node.targets
                ):
                    values.append(node.value)
        return values

    def _has_tablename(self, class_node: ast.ClassDef) -> bool:
        """Check if a class assigns __tablename__ (annotated or plain)."""
        return bool(self._tablename_values(class_node))

    def _has_mapped_field(self, class_node: ast.ClassDef) -> bool:
        """Check if a class has at least one Mapped[...] annotated field of its own."""
        return any(
            isinstance(node, ast.AnnAssign) and self._is_mapped_annotation(node.annotation)
            for node in class_node.body
        )

    @staticmethod
    def _is_mapped_annotation(annotation: ast.expr) -> bool:
        """Check if an annotation is a SQLAlchemy ``Mapped[...]`` (or ``orm.Mapped[...]``) subscript."""
        return isinstance(annotation, ast.Subscript) and _terminal_name(annotation.value) == "Mapped"

    @staticmethod
    def _unwrap_mapped(type_str: str) -> str:
        """Strip a SQLAlchemy ``Mapped[...]`` wrapper from an annotation string.

        Accepts qualified forms such as ``orm.Mapped[int]``. This class unwraps
        annotations at the AST level internally; this string-level helper is
        retained for callers that hold annotation text.
        """
        match = re.fullmatch(r"(?:\w+\.)*Mapped\[(.*)\]", type_str, flags=re.DOTALL)
        return match.group(1) if match else type_str

    def _parse_table_class(self, class_node: ast.ClassDef, source: str) -> None:
        """Build an EntityInfo for *class_node*, including inherited fields, and store it."""
        # Explicit string __tablename__ (annotated: SQLModel, or plain: SQLAlchemy); last wins.
        table_name: str | None = None
        for value in self._tablename_values(class_node):
            if isinstance(value, ast.Constant) and isinstance(value.value, str):
                table_name = value.value

        if not table_name:
            # SQLModel's implicit __tablename__ is the lowercased class name (no underscores),
            # which is what its foreign_key="table.col" strings refer to.
            if source == "sqlmodel":
                table_name = class_node.name.lower()
            else:
                table_name = self._to_snake_case(class_node.name)

        entity = EntityInfo(
            name=class_node.name,
            table_name=table_name,
            is_link_table="Link" in class_node.name,
            base_classes=[base.id for base in class_node.bases if isinstance(base, ast.Name)],
            source=source,
        )

        # Fields and relationships from this class and all known ancestors
        fields, relationships = self._collect_fields_recursive(class_node, set(), source)
        entity.fields = list(fields.values())
        entity.relationships = list(relationships.values())

        self.entities[class_node.name] = entity

    def _collect_fields_recursive(
        self, class_node: ast.ClassDef, visited: set[str], model_kind: str
    ) -> Tuple[Dict[str, FieldInfo], Dict[str, Tuple[str, str, str]]]:
        """Recursively collect fields from a class and all its ancestors defined in parsed files.

        Bases are walked right-to-left, and each class's own fields are applied
        last. A base listed earlier therefore overrides one listed later, and a
        subclass overrides all of its bases, approximating Python's MRO.
        *visited* is shared across the whole walk so that each class (e.g. the
        top of a diamond) is merged once.
        """
        fields_dict: Dict[str, FieldInfo] = {}
        relationships_dict: Dict[str, Tuple[str, str, str]] = {}

        # Avoid infinite recursion / double-merging
        if class_node.name in visited:
            return fields_dict, relationships_dict
        visited.add(class_node.name)

        for base in reversed(class_node.bases):
            if not isinstance(base, ast.Name):
                continue
            ancestor = self.all_classes.get(base.id)
            if ancestor is None:
                continue
            base_fields, base_rels = self._collect_fields_recursive(ancestor, visited, model_kind)
            fields_dict.update(base_fields)
            relationships_dict.update(base_rels)

        class_fields, class_rels = self._extract_class_fields(class_node, model_kind)
        fields_dict.update(class_fields)
        relationships_dict.update(class_rels)

        return fields_dict, relationships_dict

    def _extract_class_fields(
        self, class_node: ast.ClassDef, model_kind: str
    ) -> Tuple[Dict[str, FieldInfo], Dict[str, Tuple[str, str, str]]]:
        """Extract fields and relationships declared directly in a class body.

        Private names and ``ClassVar``/``InitVar`` annotations are skipped, since
        they do not declare stored fields.
        """
        fields: Dict[str, FieldInfo] = {}
        relationships: Dict[str, Tuple[str, str, str]] = {}

        for node in class_node.body:
            if not (isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name)):
                continue
            field_name = node.target.id

            if field_name.startswith("_") or _is_non_field_annotation(node.annotation):
                continue

            # Explicit Relationship() (SQLModel) / relationship() (SQLAlchemy) call
            if (
                isinstance(node.value, ast.Call)
                and _terminal_name(node.value.func) in _RELATIONSHIP_FACTORIES
            ):
                rel = self._parse_relationship(node)
                if rel:
                    relationships[field_name] = rel
                continue

            # Keyless sources (Pydantic/dataclass): a field typed as another
            # model is treated as a relationship rather than a column.
            if model_kind in _KEYLESS_SOURCES:
                rel = self._parse_model_reference(node)
                if rel:
                    relationships[field_name] = rel
                    continue

            field = self._parse_field(node, model_kind)
            if field:
                fields[field_name] = field

        return fields, relationships

    def _parse_field(self, node: ast.AnnAssign, model_kind: str) -> FieldInfo | None:
        """Parse a column/field from an annotated assignment.

        The displayed type has ``Mapped[...]``, ``Annotated[...]`` and a top-level
        ``None`` union member removed (e.g. ``Mapped[Optional[list[str]]]`` ->
        ``list[str]``). Nullability comes from that ``None`` member. It is then
        overridden by an explicit ``primary_key=True`` (not nullable) and finally
        by an explicit ``nullable=...`` keyword in ``Field()``/``mapped_column()``.
        """
        if not isinstance(node.target, ast.Name):
            return None

        field_name = node.target.id

        annotation = _parse_forward_ref(node.annotation)
        if self._is_mapped_annotation(annotation):
            annotation = _parse_forward_ref(annotation.slice)
        inner, is_nullable = _strip_optional(_unwrap_annotated(annotation))
        type_str = ast.unparse(_unwrap_annotated(inner))

        # Parse Field() / mapped_column() parameters
        is_primary_key = False
        is_foreign_key = False
        foreign_table: str | None = None
        index = False
        explicit_nullable: bool | None = None

        call = node.value
        if isinstance(call, ast.Call) and _terminal_name(call.func) in _COLUMN_FACTORIES:
            for keyword in call.keywords:
                if not isinstance(keyword.value, ast.Constant):
                    continue
                value = keyword.value.value
                if keyword.arg == "primary_key":
                    is_primary_key = bool(value)
                elif keyword.arg == "foreign_key" and value is not None:
                    is_foreign_key = True
                    foreign_table = str(value)
                elif keyword.arg == "index":
                    index = bool(value)
                elif keyword.arg == "nullable" and value is not None:
                    explicit_nullable = bool(value)

            # SQLAlchemy: ForeignKey("table.col") passed to mapped_column()
            fk_target = self._extract_foreign_key_arg(call)
            if fk_target is not None:
                is_foreign_key = True
                foreign_table = fk_target

            # Primary-key columns are NOT NULL even when annotated Optional
            # (the usual SQLModel `id: int | None = Field(default=None, primary_key=True)`).
            if is_primary_key:
                is_nullable = False
            if explicit_nullable is not None:
                is_nullable = explicit_nullable

        # Optional name-based key inference for keyless sources (Pydantic/dataclass).
        # Never overrides explicitly declared keys above.
        if (
            self.infer_keys
            and model_kind in _KEYLESS_SOURCES
            and not (is_primary_key or is_foreign_key)
        ):
            if field_name == "id":
                is_primary_key = True
            elif field_name.endswith("_id"):
                is_foreign_key = True
                foreign_table = f"{field_name.removesuffix('_id')}.id"

        return FieldInfo(
            name=field_name,
            type_str=type_str,
            is_primary_key=is_primary_key,
            is_foreign_key=is_foreign_key,
            is_nullable=is_nullable,
            foreign_table=foreign_table,
            index=index,
            default_value=self._extract_default_value(node),
        )

    @staticmethod
    def _extract_foreign_key_arg(call: ast.Call) -> str | None:
        """Extract the target from a ForeignKey("table.col") call argument.

        SQLAlchemy expresses foreign keys as `mapped_column(ForeignKey("user.id"))`.
        ForeignKey may appear as a positional or keyword argument and may be
        qualified (`sa.ForeignKey`). Its target may be given positionally or as
        `column=`.
        """
        candidates = [*call.args, *(kw.value for kw in call.keywords)]
        for arg in candidates:
            if not (isinstance(arg, ast.Call) and _terminal_name(arg.func) == "ForeignKey"):
                continue
            if arg.args:
                target = arg.args[0]
            else:
                target = next((kw.value for kw in arg.keywords if kw.arg == "column"), None)
            if isinstance(target, ast.Constant):
                return str(target.value)
        return None

    def _extract_default_value(self, node: ast.AnnAssign) -> str | None:
        """Extract a display string for a field's static default value, if any."""
        value = node.value
        if value is None:
            return None

        # Direct value, e.g. `field: str = "value"` or `status: Status = Status.NEW`
        if not isinstance(value, ast.Call):
            return _render_default(value)

        callee = _terminal_name(value.func)
        if callee not in _DEFAULT_FACTORIES:
            return None

        # Only Field() takes its default positionally; mapped_column()'s positional
        # arguments are the column name, type and schema items (e.g. ForeignKey).
        if callee == "Field" and value.args:
            rendered = _render_default(value.args[0])
            if rendered is not None:
                return rendered

        for keyword in value.keywords:
            if keyword.arg == "default":
                return _render_default(keyword.value)

        return None

    def _parse_relationship(self, node: ast.AnnAssign) -> Tuple[str, str, str] | None:
        """Parse an explicit relationship into ``(target class, "one" | "many", attribute name)``."""
        if not isinstance(node.target, ast.Name):
            return None

        target, rel_type = _resolve_relationship_target(node.annotation)
        return (target, rel_type, node.target.id)

    def _parse_model_reference(self, node: ast.AnnAssign) -> Tuple[str, str, str] | None:
        """Treat a field typed as another known model as a relationship.

        Used for Pydantic/dataclass models, which express relationships as plain
        typed attributes (e.g. ``user: User`` or ``items: list[Item]``) rather
        than via a Relationship()/relationship() call. Returns None unless the
        target is a parsed class that is itself a drawable entity.
        """
        if not isinstance(node.target, ast.Name):
            return None

        target, rel_type = _resolve_relationship_target(node.annotation)
        target_class = self.all_classes.get(target)
        if target_class is not None and self._classify_source(target_class) is not None:
            return (target, rel_type, node.target.id)
        return None

    @staticmethod
    def _clean_target(type_str: str) -> str:
        """Reduce a (possibly wrapped) annotation string to a bare target class name."""
        try:
            annotation = ast.parse(type_str.strip(), mode="eval").body
        except (SyntaxError, ValueError):
            return type_str.replace('"', "").replace("'", "").strip()
        return _resolve_relationship_target(annotation)[0]

    def _to_snake_case(self, name: str) -> str:
        """Convert CamelCase to snake_case (e.g. ``HTTPResponseCode`` -> ``http_response_code``)."""
        spaced = _SNAKE_CASE_WORD.sub(r"\1_\2", name)
        return _SNAKE_CASE_BOUNDARY.sub(r"\1_\2", spaced).lower()
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
def parse_models_directory(
    path: Path,
    exclude_patterns: List[str] | None = None,
    infer_keys: bool = False,
) -> Tuple[Dict[str, EntityInfo], Dict[str, EnumInfo]]:
    """Extract entities and enums from the model files found under *path*.

    Handles SQLModel, SQLAlchemy, Pydantic and dataclass models.

    Args:
        path: Directory that is searched for model files.
        exclude_patterns: Case-sensitive glob patterns; an entity is dropped when
            a pattern matches its class name or its table name.
        infer_keys: For keyless sources (Pydantic/dataclass), treat a field named
            ``id`` as the primary key and ``<x>_id`` as a foreign key.

    Returns:
        A ``(entities, enums)`` tuple of dicts.
    """
    parser = ASTDatabaseParser(
        database_path=path,
        exclude_patterns=exclude_patterns,
        infer_keys=infer_keys,
    )
    entities, enums = parser.parse_all_models()
    return entities, enums
|
erdify/py.typed
ADDED
|
File without changes
|