erdify 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
erdify/parser.py ADDED
@@ -0,0 +1,519 @@
1
+ """AST-based parser for SQLModel, SQLAlchemy, Pydantic and dataclass models."""
2
+
3
+ import ast
4
+ import re
5
+ import sys
6
+ from fnmatch import fnmatchcase
7
+ from pathlib import Path
8
+ from typing import Dict, List, Tuple
9
+
10
+ from .config import EntityInfo, EnumInfo, FieldInfo
11
+
12
+
13
class ASTDatabaseParser:
    """Parses database models using AST to extract schema information.

    Every ``models.py`` found below ``database_path`` is parsed with
    :mod:`ast`; the files are never imported or executed. Classes are indexed
    by their bare class name, so when two files define a class with the same
    name the one parsed last replaces the earlier one.
    """

    # Base-class names (bare or module-qualified) that mark a class as an enum.
    _ENUM_BASES = frozenset({"Enum", "IntEnum", "StrEnum", "Flag", "IntFlag"})
    # Calls that declare a column: SQLModel/Pydantic ``Field`` and SQLAlchemy ``mapped_column``.
    _COLUMN_CALLS = frozenset({"Field", "mapped_column"})
    # Calls that declare a relationship: SQLModel ``Relationship`` and SQLAlchemy ``relationship``.
    _RELATIONSHIP_CALLS = frozenset({"Relationship", "relationship"})
    # Matches ``Mapped[...]`` or a module-qualified ``orm.Mapped[...]`` spanning the whole string.
    _MAPPED_WRAPPER = re.compile(r"(?:\w+\.)*Mapped\[(.*)\]", re.DOTALL)

    def __init__(
        self,
        database_path: Path,
        exclude_patterns: List[str] | None = None,
        infer_keys: bool = False,
    ):
        """Store the parser configuration and create empty result containers.

        Args:
            database_path: Directory searched recursively for ``models.py`` files.
            exclude_patterns: Case-sensitive glob patterns matched against class
                names and table names; matching entities are dropped.
            infer_keys: When true, Pydantic/dataclass fields named ``id`` or
                ``<x>_id`` are reported as primary/foreign keys.
        """
        self.database_path = database_path
        self.exclude_patterns = exclude_patterns or []
        self.infer_keys = infer_keys
        self.entities: Dict[str, EntityInfo] = {}
        self.enums: Dict[str, EnumInfo] = {}
        self.all_classes: Dict[str, ast.ClassDef] = {}  # Store all class definitions
        self.file_trees: Dict[Path, ast.Module] = {}

    def parse_all_models(self) -> Tuple[Dict[str, EntityInfo], Dict[str, EnumInfo]]:
        """Parse all model files in the database directory.

        A file that cannot be read or parsed is reported on stderr and skipped;
        the remaining files are still processed.

        Returns:
            Tuple of (entities keyed by class name, enums keyed by class name).
        """
        # Sorted so that the result (in particular which definition wins when a
        # class name is duplicated) does not depend on filesystem order.
        model_files = sorted(self.database_path.rglob("models.py"))

        # First pass: parse all files and collect class definitions
        for model_file in model_files:
            try:
                # Passing bytes lets the compiler honour the BOM / coding cookie
                # instead of decoding with the platform's locale encoding.
                tree = ast.parse(model_file.read_bytes(), filename=str(model_file))
            except (OSError, SyntaxError, ValueError, RecursionError) as e:
                print(f"Error parsing {model_file}: {e}", file=sys.stderr)
                continue

            self.file_trees[model_file] = tree
            # ast.walk also yields nested classes, which are indexed like any other.
            for node in ast.walk(tree):
                if isinstance(node, ast.ClassDef):
                    self.all_classes[node.name] = node

        # Second pass: process enum and model classes
        for class_node in self.all_classes.values():
            if self._is_enum_class(class_node):
                self._parse_enum_class(class_node)
                continue
            source = self._classify_source(class_node)
            if source is not None:
                self._parse_table_class(class_node, source)

        # Third pass: apply exclude patterns
        self._apply_exclude_patterns()

        return self.entities, self.enums

    def _apply_exclude_patterns(self) -> None:
        """Remove excluded entities and strip relationships pointing at them.

        A pattern matches an entity if its case-sensitive glob matches either
        the class name or the table name. Relationships in surviving entities
        that target an excluded entity are dropped so no dangling lines remain.
        """
        if not self.exclude_patterns:
            return

        def is_excluded(entity: EntityInfo) -> bool:
            return any(
                fnmatchcase(entity.name, pattern) or fnmatchcase(entity.table_name, pattern)
                for pattern in self.exclude_patterns
            )

        excluded_names = {name for name, e in self.entities.items() if is_excluded(e)}

        self.entities = {
            name: e for name, e in self.entities.items() if name not in excluded_names
        }

        for entity in self.entities.values():
            # rel[0] is the relationship's target class name.
            entity.relationships = [
                rel for rel in entity.relationships if rel[0] not in excluded_names
            ]

    @staticmethod
    def _trailing_name(node: ast.AST | None) -> str | None:
        """Return the last identifier of a bare or dotted name expression.

        ``Enum`` and ``enum.Enum`` both yield ``"Enum"``. Anything that is not
        an ``ast.Name`` or ``ast.Attribute`` yields ``None``.
        """
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Attribute):
            return node.attr
        return None

    def _is_enum_class(self, class_node: ast.ClassDef) -> bool:
        """Check if a class lists an enum base (``Enum``, ``IntEnum``, ``Flag``, ...)."""
        return any(self._trailing_name(base) in self._ENUM_BASES for base in class_node.bases)

    def _parse_enum_class(self, class_node: ast.ClassDef) -> None:
        """Parse an enum class and record its member names.

        Members are the plain (un-annotated) assignments in the class body whose
        target does not start with an underscore. Enums without any such
        assignment are not recorded.
        """
        values: List[str] = [
            target.id
            for node in class_node.body
            if isinstance(node, ast.Assign)
            for target in node.targets
            if isinstance(target, ast.Name) and not target.id.startswith("_")
        ]

        if values:
            self.enums[class_node.name] = EnumInfo(name=class_node.name, values=values)

    def _classify_source(self, class_node: ast.ClassDef) -> str | None:
        """Classify which model framework a class belongs to.

        Returns one of "sqlmodel", "sqlalchemy", "pydantic", "dataclass", or
        None if the class is not a drawable entity.
        """
        # SQLModel: declared with a truthy constant table=... class keyword
        for keyword in class_node.keywords:
            if keyword.arg == "table" and isinstance(keyword.value, ast.Constant):
                if keyword.value.value:
                    return "sqlmodel"

        # SQLAlchemy 2.0: a concrete mapped class has __tablename__ and Mapped[...] columns.
        # Mixins/abstract bases (Mapped fields but no __tablename__) are excluded here so
        # they are not emitted as entities, but their fields are still inherited.
        if self._has_tablename(class_node) and self._has_mapped_field(class_node):
            return "sqlalchemy"

        # Plain @dataclass
        if self._is_dataclass(class_node):
            return "dataclass"

        # Pydantic: inherits from BaseModel (directly or transitively)
        if self._inherits_basemodel(class_node):
            return "pydantic"

        return None

    @staticmethod
    def _is_dataclass(class_node: ast.ClassDef) -> bool:
        """Check if a class is decorated with @dataclass (bare, dotted or called)."""
        for dec in class_node.decorator_list:
            # ``@dataclass(...)`` is a Call whose ``func`` is the decorator name.
            target = dec.func if isinstance(dec, ast.Call) else dec
            if ASTDatabaseParser._trailing_name(target) == "dataclass":
                return True
        return False

    def _inherits_basemodel(
        self, class_node: ast.ClassDef, visited: set[str] | None = None
    ) -> bool:
        """Check if a class inherits from Pydantic's BaseModel, directly or via ancestors.

        Only bases written as bare names are followed into ``self.all_classes``;
        ``visited`` guards against cyclic base references.
        """
        visited = visited if visited is not None else set()
        if class_node.name in visited:
            return False
        visited.add(class_node.name)

        for base in class_node.bases:
            if isinstance(base, ast.Name):
                if base.id == "BaseModel":
                    return True
                ancestor = self.all_classes.get(base.id)
                if ancestor is not None and self._inherits_basemodel(ancestor, visited):
                    return True
            elif isinstance(base, ast.Attribute) and base.attr == "BaseModel":
                return True
        return False

    @staticmethod
    def _tablename_values(class_node: ast.ClassDef) -> List[ast.expr | None]:
        """Return the value node of every ``__tablename__`` assignment in the class body.

        Both annotated and plain assignments are considered. An annotation
        without a value contributes ``None``.
        """
        values: List[ast.expr | None] = []
        for node in class_node.body:
            if isinstance(node, ast.AnnAssign):
                targets = [node.target]
            elif isinstance(node, ast.Assign):
                targets = node.targets
            else:
                continue
            if any(isinstance(t, ast.Name) and t.id == "__tablename__" for t in targets):
                values.append(node.value)
        return values

    def _has_tablename(self, class_node: ast.ClassDef) -> bool:
        """Check if a class assigns __tablename__ (annotated or plain)."""
        return bool(self._tablename_values(class_node))

    def _has_mapped_field(self, class_node: ast.ClassDef) -> bool:
        """Check if a class has at least one Mapped[...] annotated field."""
        return any(
            isinstance(node, ast.AnnAssign) and self._is_mapped_annotation(node.annotation)
            for node in class_node.body
        )

    @staticmethod
    def _is_mapped_annotation(annotation: ast.expr) -> bool:
        """Check if an annotation is a ``Mapped[...]`` subscript (bare or module-qualified)."""
        return (
            isinstance(annotation, ast.Subscript)
            and ASTDatabaseParser._trailing_name(annotation.value) == "Mapped"
        )

    @staticmethod
    def _unwrap_mapped(type_str: str) -> str:
        """Strip a ``Mapped[...]`` wrapper (bare or module-qualified), returning the inner type.

        Strings that are not entirely one ``Mapped[...]`` wrapper are returned unchanged.
        """
        match = ASTDatabaseParser._MAPPED_WRAPPER.fullmatch(type_str)
        return match.group(1) if match else type_str

    @staticmethod
    def _without_none(node: ast.expr) -> ast.expr | None:
        """Remove top-level ``None`` alternatives from an annotation expression.

        Looks through ``Optional[X]``, ``Union[..., None]``, ``X | None`` and
        quoted annotations. ``None`` nested inside another generic (for example
        ``list[int | None]``) is left alone because it describes the element,
        not the field. Returns ``None`` when nothing but ``None`` remains.
        """
        strip = ASTDatabaseParser._without_none

        if isinstance(node, ast.Constant):
            if node.value is None:
                return None
            if isinstance(node.value, str):
                # Quoted (forward-reference) annotation: clean its contents, keep it quoted.
                try:
                    inner = ast.parse(node.value, mode="eval").body
                except (SyntaxError, ValueError):
                    return node
                stripped = strip(inner)
                return None if stripped is None else ast.Constant(value=ast.unparse(stripped))
            return node

        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
            left, right = strip(node.left), strip(node.right)
            if left is None:
                return right
            if right is None:
                return left
            return ast.BinOp(left=left, op=ast.BitOr(), right=right)

        if isinstance(node, ast.Subscript):
            wrapper = ASTDatabaseParser._trailing_name(node.value)
            if wrapper == "Optional":
                return strip(node.slice)
            if wrapper == "Union":
                members = node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice]
                kept = [m for m in map(strip, members) if m is not None]
                if not kept:
                    return None
                if len(kept) == 1:
                    return kept[0]
                return ast.Subscript(
                    value=node.value,
                    slice=ast.Tuple(elts=kept, ctx=ast.Load()),
                    ctx=ast.Load(),
                )

        return node

    @staticmethod
    def _strip_optional(type_str: str) -> str:
        """Return ``type_str`` without its optional wrapper.

        ``Optional[int]``, ``int | None`` and ``None | int`` all become ``int``,
        while brackets of the remaining type are preserved
        (``Optional[list[str]]`` becomes ``list[str]``). A string that is not a
        valid expression, or that consists only of ``None``, is returned as is.
        """
        try:
            root = ast.parse(type_str, mode="eval").body
        except (SyntaxError, ValueError):
            return type_str
        stripped = ASTDatabaseParser._without_none(root)
        return type_str if stripped is None else ast.unparse(stripped)

    def _parse_table_class(self, class_node: ast.ClassDef, source: str) -> None:
        """Parse a table class and all its inherited fields into ``self.entities``."""
        # Table name: the last constant assigned to __tablename__ (annotated for
        # SQLModel, plain for SQLAlchemy), else the snake_case class name.
        table_name = None
        for value in self._tablename_values(class_node):
            if isinstance(value, ast.Constant):
                table_name = str(value.value)
        if not table_name:
            table_name = self._to_snake_case(class_node.name)

        entity = EntityInfo(
            name=class_node.name,
            table_name=table_name,
            # Heuristic: any class whose name contains "Link" is flagged as a link table.
            is_link_table="Link" in class_node.name,
            # Only bases written as bare names are recorded.
            base_classes=[base.id for base in class_node.bases if isinstance(base, ast.Name)],
            source=source,
        )

        # Collect fields from this class and all base classes (recursively)
        fields, relationships = self._collect_fields_recursive(class_node, set(), source)
        entity.fields = list(fields.values())
        entity.relationships = list(relationships.values())

        self.entities[class_node.name] = entity

    def _collect_fields_recursive(
        self, class_node: ast.ClassDef, visited: set[str], model_kind: str
    ) -> Tuple[Dict[str, FieldInfo], Dict[str, Tuple[str, str, str]]]:
        """Recursively collect fields from a class and all its ancestors.

        Inherited members come first, in base-declaration order. When two bases
        define the same name the leftmost base wins, as in Python attribute
        lookup; the class's own members override any inherited one.

        Args:
            class_node: Class whose members are collected.
            visited: Names of classes already processed; shared across the
                recursion so cycles and repeated ancestors are handled once.
            model_kind: Framework classification passed on to field parsing.

        Returns:
            Tuple of (fields by name, relationships by attribute name).
        """
        fields_dict: Dict[str, FieldInfo] = {}
        relationships_dict: Dict[str, Tuple[str, str, str]] = {}

        # Avoid infinite recursion
        if class_node.name in visited:
            return fields_dict, relationships_dict
        visited.add(class_node.name)

        # First, process all base classes recursively
        for base in class_node.bases:
            if isinstance(base, ast.Name) and base.id in self.all_classes:
                base_fields, base_rels = self._collect_fields_recursive(
                    self.all_classes[base.id], visited, model_kind
                )
                # setdefault keeps the entry from an earlier (leftmost) base.
                for name, info in base_fields.items():
                    fields_dict.setdefault(name, info)
                for name, rel in base_rels.items():
                    relationships_dict.setdefault(name, rel)

        # Then add/override with current class fields
        class_fields, class_rels = self._extract_class_fields(class_node, model_kind)
        fields_dict.update(class_fields)
        relationships_dict.update(class_rels)

        return fields_dict, relationships_dict

    def _extract_class_fields(
        self, class_node: ast.ClassDef, model_kind: str
    ) -> Tuple[Dict[str, FieldInfo], Dict[str, Tuple[str, str, str]]]:
        """Extract fields and relationships declared directly in a class body.

        Only annotated assignments to plain names are considered, and names
        starting with an underscore are skipped.
        """
        fields: Dict[str, FieldInfo] = {}
        relationships: Dict[str, Tuple[str, str, str]] = {}

        for node in class_node.body:
            if not (isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name)):
                continue
            field_name = node.target.id

            # Skip special attributes
            if field_name.startswith("_"):
                continue

            # Explicit relationship call (SQLModel/SQLAlchemy), bare or module-qualified
            value = node.value
            if (
                isinstance(value, ast.Call)
                and self._trailing_name(value.func) in self._RELATIONSHIP_CALLS
            ):
                rel = self._parse_relationship(node)
                if rel:
                    relationships[field_name] = rel
                continue

            # Keyless sources (Pydantic/dataclass): a field typed as another
            # model is treated as a relationship rather than a column.
            if model_kind in ("pydantic", "dataclass"):
                rel = self._parse_model_reference(node)
                if rel:
                    relationships[field_name] = rel
                    continue

            field = self._parse_field(node, model_kind)
            if field:
                fields[field_name] = field

        return fields, relationships

    def _parse_field(self, node: ast.AnnAssign, model_kind: str) -> FieldInfo | None:
        """Parse a column/field from an annotated assignment.

        Returns None when the assignment target is not a plain name.
        """
        if not isinstance(node.target, ast.Name):
            return None

        field_name = node.target.id
        type_str = self._unwrap_mapped(ast.unparse(node.annotation))

        # Nullable when the annotation text mentions None or Optional
        is_nullable = "None" in type_str or "Optional" in type_str

        # Parse Field() / mapped_column() parameters
        is_primary_key = False
        is_foreign_key = False
        foreign_table = None
        index = False

        value = node.value
        if isinstance(value, ast.Call) and self._trailing_name(value.func) in self._COLUMN_CALLS:
            for keyword in value.keywords:
                if not isinstance(keyword.value, ast.Constant):
                    continue
                if keyword.arg == "primary_key":
                    is_primary_key = bool(keyword.value.value)
                elif keyword.arg == "foreign_key":
                    is_foreign_key = True
                    foreign_table = str(keyword.value.value)
                elif keyword.arg == "index":
                    index = bool(keyword.value.value)

            # SQLAlchemy: ForeignKey("table.col") passed to mapped_column()
            fk_target = self._extract_foreign_key_arg(value)
            if fk_target is not None:
                is_foreign_key = True
                foreign_table = fk_target

        # Optional name-based key inference for keyless sources (Pydantic/dataclass).
        # Never overrides a foreign key declared explicitly above.
        if self.infer_keys and model_kind in ("pydantic", "dataclass"):
            if field_name == "id":
                is_primary_key = True
            elif field_name.endswith("_id") and not is_foreign_key:
                is_foreign_key = True
                foreign_table = f"{field_name[: -len('_id')]}.id"

        return FieldInfo(
            name=field_name,
            # Drop the optional wrapper but keep the brackets of the remaining type.
            type_str=self._strip_optional(type_str),
            is_primary_key=is_primary_key,
            is_foreign_key=is_foreign_key,
            is_nullable=is_nullable,
            foreign_table=foreign_table,
            index=index,
            default_value=self._extract_default_value(node),
        )

    @staticmethod
    def _extract_foreign_key_arg(call: ast.Call) -> str | None:
        """Extract the target from a ForeignKey("table.col") call argument.

        SQLAlchemy expresses foreign keys as `mapped_column(ForeignKey("user.id"))`,
        where ForeignKey (bare or module-qualified) may appear as a positional
        or keyword argument. Returns None when no such argument is present.
        """
        candidates = list(call.args) + [kw.value for kw in call.keywords]
        for arg in candidates:
            if (
                isinstance(arg, ast.Call)
                and ASTDatabaseParser._trailing_name(arg.func) == "ForeignKey"
                and arg.args
                and isinstance(arg.args[0], ast.Constant)
            ):
                return str(arg.args[0].value)
        return None

    @staticmethod
    def _literal_default(expr: ast.expr) -> str | None:
        """Render a default-value expression, or None if it is not a recognised literal.

        Constants are rendered with ``repr``; ``Name.attr`` (e.g. an enum member
        such as ``OrderStatus.PENDING``) is rendered as written. ``...`` is the
        "required" marker rather than a default, so it yields None.
        """
        if isinstance(expr, ast.Constant):
            return None if expr.value is Ellipsis else repr(expr.value)
        if isinstance(expr, ast.Attribute) and isinstance(expr.value, ast.Name):
            return f"{expr.value.id}.{expr.attr}"
        return None

    def _extract_default_value(self, node: ast.AnnAssign) -> str | None:
        """Extract the default value from a field definition, if one is declared.

        Handles a direct literal (``x: int = 5``), ``Field(<default>, ...)`` and
        ``Field(default=...)`` / ``mapped_column(default=...)``.
        """
        value = node.value
        if value is None:
            return None

        # Direct constant value (e.g., field: str = "value")
        if isinstance(value, ast.Constant):
            return self._literal_default(value)

        if not isinstance(value, ast.Call):
            return None
        call_name = self._trailing_name(value.func)
        if call_name not in self._COLUMN_CALLS:
            return None

        # Only Field() takes its default as the first positional argument;
        # mapped_column()'s positional arguments are the column name/type.
        if call_name == "Field" and value.args:
            default = self._literal_default(value.args[0])
            if default is not None:
                return default

        # Keyword form: Field(default=value) / mapped_column(default=value)
        for keyword in value.keywords:
            if keyword.arg == "default":
                return self._literal_default(keyword.value)

        return None

    @staticmethod
    def _rel_type(type_str: str) -> str:
        """Return "many" for list-typed annotations (``list[``/``List[``), else "one"."""
        return "many" if ("list[" in type_str or "List[" in type_str) else "one"

    def _parse_relationship(self, node: ast.AnnAssign) -> Tuple[str, str, str] | None:
        """Parse a relationship from an annotated assignment.

        Returns:
            ``(target class name, "one" | "many", attribute name)``, or None
            when the assignment target is not a plain name.
        """
        if not isinstance(node.target, ast.Name):
            return None

        type_str = self._unwrap_mapped(ast.unparse(node.annotation))
        return (self._clean_target(type_str), self._rel_type(type_str), node.target.id)

    def _parse_model_reference(self, node: ast.AnnAssign) -> Tuple[str, str, str] | None:
        """Treat a field typed as another known model as a relationship.

        Used for Pydantic/dataclass models, which express relationships as plain
        typed attributes (e.g. ``user: User`` or ``items: list[Item]``) rather
        than via a Relationship()/relationship() call. Returns None unless the
        annotation resolves to a collected class that is itself a drawable entity.
        """
        if not isinstance(node.target, ast.Name):
            return None

        type_str = self._unwrap_mapped(ast.unparse(node.annotation))
        target = self._clean_target(type_str)

        target_class = self.all_classes.get(target)
        if target_class is not None and self._classify_source(target_class) is not None:
            return (target, self._rel_type(type_str), node.target.id)
        return None

    @staticmethod
    def _clean_target(type_str: str) -> str:
        """Reduce a (possibly wrapped) annotation to a bare target class name.

        Strips ``Mapped[``/``list[``/``List[``/``Optional[`` wrappers, closing
        brackets and quotes, then drops ``None`` alternatives of a ``|`` union
        wherever they appear. Only whole ``None`` members are removed, so a
        class whose name merely contains "None" is left intact.
        """
        cleaned = type_str
        for token in ("Mapped[", "list[", "List[", "Optional[", "]", '"', "'"):
            cleaned = cleaned.replace(token, "")
        members = (part.strip() for part in cleaned.split("|"))
        return " | ".join(part for part in members if part and part != "None")

    def _to_snake_case(self, name: str) -> str:
        """Convert CamelCase to snake_case."""
        s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
        return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
498
+
499
+
500
def parse_models_directory(
    path: Path,
    exclude_patterns: List[str] | None = None,
    infer_keys: bool = False,
) -> Tuple[Dict[str, EntityInfo], Dict[str, EnumInfo]]:
    """
    Parse SQLModel, SQLAlchemy, Pydantic and dataclass models in a directory.

    Convenience wrapper that builds an ``ASTDatabaseParser`` for ``path`` and
    returns the result of its ``parse_all_models()`` call.

    Args:
        path: Path to directory containing model files
        exclude_patterns: List of case-sensitive glob patterns. An entity is
            excluded if a pattern matches its class name or its table name.
        infer_keys: For keyless sources (Pydantic/dataclass), infer a primary
            key from a field named ``id`` and a foreign key from ``<x>_id``.

    Returns:
        Tuple of (entities dict, enums dict)
    """
    return ASTDatabaseParser(
        path,
        exclude_patterns=exclude_patterns,
        infer_keys=infer_keys,
    ).parse_all_models()
erdify/py.typed ADDED
File without changes