async-easy-model 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,619 @@
1
+ """
2
+ Visualization utilities for EasyModel models.
3
+
4
+ This module provides tools to visualize EasyModel database schemas,
5
+ including tables, fields, relationships, and automatically generated virtual fields.
6
+ """
7
+
8
+ import inspect
9
+ import sys
10
+ from typing import Dict, List, Optional, Type, Set, Any, Union
11
+ from sqlmodel import SQLModel, Field
12
+ from sqlalchemy import inspect as sa_inspect
13
+
14
class ModelVisualizer:
    """
    Helper class for visualizing EasyModel database schemas.

    This class provides methods to generate visual representations of the database schema,
    including tables, fields, relationships, and automatically generated virtual fields.

    Attributes:
        model_registry (Dict[str, Type[SQLModel]]): Registered models keyed by class name.
        title (str): Title for the Mermaid diagram
    """

    def __init__(self, title: str = "EasyModel Table Schemas"):
        """
        Initialize the ModelVisualizer and load the currently registered models.

        Args:
            title: Optional title for the Mermaid diagram
        """
        self.model_registry = {}
        self.title = title
        self._load_registered_models()

    def set_title(self, title: str) -> None:
        """
        Set or update the title for the Mermaid diagram.

        Args:
            title: New title for the Mermaid diagram
        """
        self.title = title

    def _load_registered_models(self) -> None:
        """
        Load all registered EasyModel models into ``self.model_registry``.

        The ``auto_relationships`` registry is used when it is importable and
        non-empty; otherwise every loaded module is scanned for ``SQLModel``
        subclasses exposing ``__tablename__``.
        This should be called after init_db has been executed.
        """
        # First try to get models from auto_relationships registry
        try:
            from .auto_relationships import _model_registry
            if _model_registry:
                self.model_registry = _model_registry.copy()
                return
        except (ImportError, AttributeError):
            pass

        # Fallback: discover models from loaded modules. Iterate over snapshots:
        # touching module attributes can trigger imports, and mutating
        # sys.modules (or a module namespace) mid-iteration raises RuntimeError.
        for module in list(sys.modules.values()):
            try:
                namespace = getattr(module, "__dict__", None)
                if namespace is None:
                    continue
                members = list(namespace.values())
            except Exception:
                # Odd objects can live in sys.modules; skip anything we cannot read.
                continue
            for cls in members:
                if (inspect.isclass(cls)
                        and issubclass(cls, SQLModel)
                        and cls is not SQLModel
                        and hasattr(cls, "__tablename__")):
                    self.model_registry[cls.__name__] = cls

    @staticmethod
    def _table_name_of(model_name: str, model_class: Any) -> str:
        """Return the table name of a model, defaulting to the lowercased model name."""
        return getattr(model_class, "__tablename__", model_name.lower())

    def _find_model_for_table(self, table_name: str) -> Optional[str]:
        """
        Return the registry name of the model that owns ``table_name``.

        An exact table-name match wins; otherwise the first case-insensitive
        match is returned. Returns None when no registered model owns the table.
        """
        wanted = table_name.lower()
        fallback = None
        for model_name, model_class in self.model_registry.items():
            candidate = self._table_name_of(model_name, model_class)
            if candidate == table_name:
                return model_name
            if fallback is None and str(candidate).lower() == wanted:
                fallback = model_name
        return fallback

    @staticmethod
    def _collect_annotations(model_class: Any) -> Dict[str, Any]:
        """
        Return the annotations of ``model_class`` including inherited ones.

        Annotations are merged from the most basic SQLModel subclass down to
        ``model_class`` so that subclasses override their bases. ``SQLModel``
        itself and bases that are not SQLModel subclasses are skipped, except
        that ``model_class``'s own annotations are always included.
        """
        merged: Dict[str, Any] = {}
        for klass in reversed(getattr(model_class, "__mro__", (model_class,))):
            is_model_base = (
                inspect.isclass(klass)
                and klass is not SQLModel
                and issubclass(klass, SQLModel)
            )
            if klass is not model_class and not is_model_base:
                continue
            own: Any = None
            get_annotations = getattr(inspect, "get_annotations", None)
            if get_annotations is not None:
                try:
                    own = get_annotations(klass)
                except Exception:
                    own = None
            if own is None:
                own = getattr(klass, "__dict__", {}).get("__annotations__", {})
            merged.update(own)
        return merged

    @staticmethod
    def _describe_annotation(annotation: Any) -> tuple:
        """
        Return ``(type_name, is_optional)`` for a field annotation.

        ``Optional[X]``, ``Union[X, None]`` and ``X | None`` (PEP 604) are all
        reported as optional with the name of ``X``; unions with several
        non-None members are reported as ``"Union"``. String annotations (e.g.
        under ``from __future__ import annotations``) are inspected textually.
        """
        import types
        import typing

        none_type = type(None)

        if isinstance(annotation, str):
            text = annotation.strip()
            if text.startswith("Optional[") and text.endswith("]"):
                return text[len("Optional["):-1].strip(), True
            # Only split flat unions such as "int | None"; nested brackets are
            # left untouched rather than split incorrectly.
            if "|" in text and "[" not in text:
                parts = [part.strip() for part in text.split("|")]
                others = [part for part in parts if part not in ("None", "NoneType")]
                if others and len(others) < len(parts):
                    return (others[0] if len(others) == 1 else "Union"), True
            return text, False

        origin = typing.get_origin(annotation)
        pep604_union = getattr(types, "UnionType", None)
        is_union = origin is typing.Union or (pep604_union is not None and origin is pep604_union)
        if is_union:
            args = typing.get_args(annotation)
            if none_type in args:
                others = [arg for arg in args if arg is not none_type]
                if len(others) == 1:
                    # Common case like Optional[str]: report the wrapped type.
                    return getattr(others[0], "__name__", str(others[0])), True
                # For more complex unions, just indicate it's a Union
                return "Union", True

        return getattr(annotation, "__name__", str(annotation)), False

    def _get_foreign_keys(self, model_class: Type[SQLModel]) -> Dict[str, str]:
        """
        Extract foreign key information from a model class.

        Sources are consulted in order; the first one to report a field wins:

        1. SQLAlchemy table metadata (``__table__.columns``) for mapped models.
        2. Class attributes (annotated names first, then ``dir()``) exposing a
           string ``foreign_key`` option, e.g. SQLModel ``Field(foreign_key=...)``.
        3. A naming heuristic: ``<name>_id`` fields whose ``<name>`` matches a
           registered model name (ignoring case, optionally ignoring underscores).
           The reference uses that model's real table name.

        Args:
            model_class: The SQLModel class to extract foreign keys from.

        Returns:
            Dictionary mapping field names to their ``"table.column"`` references.
        """
        foreign_keys: Dict[str, str] = {}

        def _record(name: str, value: Any) -> None:
            # Accept only real string specs: an unset Field option may hold a
            # truthy sentinel object instead of None.
            if name not in foreign_keys and isinstance(value, str) and value:
                foreign_keys[name] = value

        try:
            # 1. Foreign keys declared on the mapped SQLAlchemy table.
            table = getattr(model_class, "__table__", None)
            for column in getattr(table, "columns", ()):
                for fk in getattr(column, "foreign_keys", ()):
                    try:
                        spec = str(fk.target_fullname)
                    except Exception:
                        continue
                    # Drop any schema prefix so the result is always "table.column".
                    _record(getattr(column, "key", column.name), ".".join(spec.split(".")[-2:]))

            # 2. Class attributes carrying a foreign_key option.
            candidate_names = list(self._collect_annotations(model_class))
            candidate_names += [name for name in dir(model_class) if not name.startswith("__")]
            for attr_name in candidate_names:
                if attr_name in foreign_keys:
                    continue
                try:
                    attr_value = getattr(model_class, attr_name)
                    fk_value = getattr(attr_value, "foreign_key", None)
                except Exception:
                    # Arbitrary descriptors may raise on class access; one bad
                    # attribute must not abort the whole scan.
                    continue
                _record(attr_name, fk_value)

            # 3. Infer foreign keys from field names ending with _id.
            model_fields = getattr(model_class, "model_fields", None)
            if not isinstance(model_fields, dict):
                # pydantic v1 fallback (``__fields__`` is deprecated in v2).
                model_fields = getattr(model_class, "__fields__", None) or {}
            for field_name in model_fields:
                if not field_name.endswith("_id") or field_name in foreign_keys:
                    continue
                related_name = field_name[:-3].lower()  # Remove _id suffix
                compact_name = related_name.replace("_", "")
                for model_name, other_class in self.model_registry.items():
                    if model_name.lower() in (related_name, compact_name):
                        target_table = self._table_name_of(model_name, other_class)
                        foreign_keys[field_name] = f"{target_table}.id"
                        break
        except Exception as e:
            # Log but don't re-raise to ensure visualization continues
            print(f"Warning: Error getting foreign keys for {model_class.__name__}: {str(e)}")

        return foreign_keys

    def _get_virtual_relationship_fields(self, model_class: Type[SQLModel]) -> Dict[str, Dict[str, Any]]:
        """
        Get virtual relationship fields that are automatically generated.

        For each foreign key ``<name>_id`` a singular virtual field ``<name>``
        pointing at the referenced model is produced. Junction tables (foreign
        keys into two or more different tables) get no virtual fields of their
        own; instead, every model they reference gets a plural list field for
        each other model referenced by the junction.

        Args:
            model_class: The SQLModel class to extract virtual relationship fields from.

        Returns:
            Dictionary of virtual relationship fields info.
        """
        virtual_fields: Dict[str, Dict[str, Any]] = {}
        table_name = getattr(model_class, "__tablename__", model_class.__name__.lower())

        try:
            foreign_keys = self._get_foreign_keys(model_class)

            # Models reached through this model's own foreign keys
            related_models = set()

            # A junction table has foreign keys into at least two different tables
            referenced_tables = {
                fk_target.split(".")[0]
                for fk_target in foreign_keys.values()
                if "." in fk_target
            }
            is_junction_table = len(foreign_keys) >= 2 and len(referenced_tables) >= 2

            # Singular (many-to-one) virtual field for each foreign key
            for field_name, fk_target in foreign_keys.items():
                if "." not in fk_target:
                    continue

                target_model = self._find_model_for_table(fk_target.split(".")[0])
                if target_model is None:
                    continue
                related_models.add(target_model)

                # The relationship field is typically named without the _id suffix
                rel_name = field_name[:-3] if field_name.endswith("_id") else field_name

                # A relationship cannot share its FK column's name: it would replace
                # the real column in the field listing and drop the FK edge. Also
                # skip empty names and duplicates.
                if not rel_name or rel_name == field_name or rel_name in virtual_fields:
                    continue

                virtual_fields[rel_name] = {
                    "name": rel_name,
                    "type": target_model,
                    "is_list": False,  # One-to-one or many-to-one
                    "related_model": target_model,
                    "is_virtual": True,
                    "is_required": False  # Virtual fields are usually optional
                }

            # Don't add many-to-many virtual fields to the junction table itself
            if is_junction_table and len(related_models) >= 2:
                return virtual_fields

            # Look for junction tables that link this model to other models
            own_table = str(table_name).lower()
            for junction_name, junction_class in self.model_registry.items():
                if junction_class == model_class:  # Skip self
                    continue

                junction_fks = self._get_foreign_keys(junction_class)
                if len(junction_fks) < 2:  # Not a potential junction
                    continue

                this_model_referenced = False
                other_referenced_models = set()
                for fk_target in junction_fks.values():
                    if "." not in fk_target:
                        continue
                    target_table = fk_target.split(".")[0]
                    if target_table.lower() == own_table:
                        this_model_referenced = True
                    else:
                        other_model = self._find_model_for_table(target_table)
                        if other_model is not None:
                            other_referenced_models.add(other_model)

                if this_model_referenced and other_referenced_models:
                    # Sorted so that diagram output is deterministic across runs.
                    for other_model in sorted(other_referenced_models):
                        # Very simple pluralization, might need to improve
                        plural_name = f"{other_model.lower()}s"
                        virtual_fields[plural_name] = {
                            "name": plural_name,
                            "type": f"List[{other_model}]",
                            "is_list": True,  # Many-to-many
                            "related_model": other_model,
                            "is_virtual": True,
                            "is_required": False
                        }

        except Exception as e:
            # Log but don't re-raise to ensure visualization continues
            print(f"Warning: Error getting virtual fields for {model_class.__name__}: {str(e)}")

        return virtual_fields

    def _get_field_information(self, model_class: Type[SQLModel]) -> Dict[str, Dict[str, Any]]:
        """
        Extract field information from a model class.

        Includes an ``id`` primary key entry (always), every annotated field of
        the model and its SQLModel bases (names starting with ``_`` are
        skipped), foreign-key flags, and virtual relationship fields.

        Args:
            model_class: The SQLModel class to extract field information from.

        Returns:
            Dictionary of field information with field names as keys.
        """
        fields: Dict[str, Dict[str, Any]] = {}

        try:
            # Make sure the 'id' field is always included (and listed first)
            fields["id"] = {
                "name": "id",
                "type": "int",
                "is_primary": True,
                "is_foreign_key": False,
                "foreign_key_reference": None,
                "is_virtual": False,
                "is_required": False  # Primary keys are not considered "required" for the diagram
            }

            # Standard database fields, including those inherited from base models
            for field_name, annotation in self._collect_annotations(model_class).items():
                if field_name.startswith("_"):
                    continue

                type_str, is_optional = self._describe_annotation(annotation)
                is_primary = field_name == "id"

                fields[field_name] = {
                    "name": field_name,
                    "type": type_str,
                    "is_primary": is_primary,
                    "is_foreign_key": False,
                    "foreign_key_reference": None,
                    "is_virtual": False,
                    "is_required": not is_optional and not is_primary  # Primary keys are not "required" for the diagram
                }

            # Get foreign key information and update fields
            for field_name, fk_target in self._get_foreign_keys(model_class).items():
                if field_name in fields:
                    fields[field_name]["is_foreign_key"] = True
                    fields[field_name]["foreign_key_reference"] = fk_target
                else:
                    # Foreign key field wasn't in annotations, add it
                    fields[field_name] = {
                        "name": field_name,
                        "type": "int",  # Most foreign keys are integers
                        "is_primary": False,
                        "is_foreign_key": True,
                        "foreign_key_reference": fk_target,
                        "is_virtual": False,
                        "is_required": field_name != "id"  # Only mark as required if not the id field
                    }

            # Virtual relationship fields replace same-named annotated entries
            # (e.g. declared Relationship attributes).
            fields.update(self._get_virtual_relationship_fields(model_class))

        except Exception as e:
            # Log but don't re-raise to ensure visualization continues
            print(f"Warning: Error processing fields for {model_class.__name__}: {str(e)}")

        return fields

    def _generate_mermaid_content(self) -> str:
        """
        Generate the raw Mermaid ER diagram content without markdown code fences.
        This is used internally by both mermaid() and mermaid_link() methods.

        Entities that fail to render are replaced by a ``%%`` comment line, so a
        single broken model never leaves an unclosed entity block behind.
        Foreign keys are drawn as ``child }o--|| parent`` (many children per
        parent); junction tables additionally produce a direct ``}o--o{`` edge
        between the tables they link.

        Returns:
            String containing raw Mermaid ER diagram markup without markdown fences.
        """
        if not self.model_registry:
            return f"---\ntitle: {self.title}\nconfig:\n layout: elk\n---\nerDiagram\n %% No models found. Run init_db first."

        # Start with the title section
        lines = [
            "---",
            f"title: {self.title}",
            "config:",
            " layout: elk",
            "---",
            "erDiagram"
        ]

        # Keep track of rendered relationships to avoid duplicates
        rendered_relationships = set()
        processed_models = set()
        # Field info per model, computed once and reused by the relationship pass
        # (each computation rescans the whole registry).
        model_fields: Dict[str, Dict[str, Dict[str, Any]]] = {}

        # Entity definitions: process all models, continuing past failures
        for model_name, model_class in self.model_registry.items():
            try:
                table_name = self._table_name_of(model_name, model_class)
                processed_models.add(model_name)

                fields = self._get_field_information(model_class)
                model_fields[model_name] = fields

                # Buffer the entity so a failure cannot leave a dangling "{".
                entity_lines = [f" {table_name} {{"]
                for field_name, field_info in fields.items():
                    field_type = self._simplify_type_for_mermaid(str(field_info["type"]))

                    # Relationships are shown with the related model's proper name
                    if field_info.get("is_virtual", False) and field_info.get("related_model"):
                        related_model = field_info["related_model"]
                        if field_info.get("is_list", False):
                            field_type = f"{related_model}[]"
                        else:
                            field_type = related_model

                    attrs_str = self._format_field_attributes(field_info)
                    entity_lines.append(f" {field_type} {field_name}{attrs_str}")
                entity_lines.append(" }")

                lines.extend(entity_lines)

            except Exception as e:
                lines.append(f" %% Error defining {model_name}: {str(e)}")

        # Relationships between models
        for model_name, model_class in self.model_registry.items():
            if model_name not in processed_models:
                continue
            try:
                table_name = self._table_name_of(model_name, model_class)
                fields = model_fields.get(model_name)
                if fields is None:
                    fields = self._get_field_information(model_class)

                # One edge per foreign key: many rows of this table reference
                # exactly one row of the target table.
                for field_name, field_info in fields.items():
                    fk_ref = field_info.get("foreign_key_reference")
                    if not (field_info.get("is_foreign_key", False) and fk_ref):
                        continue

                    target_table = fk_ref.split(".")[0] if "." in fk_ref else fk_ref

                    rel_id = f"{table_name}_{target_table}_{field_name}"
                    if rel_id in rendered_relationships:
                        continue

                    lines.append(f" {table_name} }}o--|| {target_table} : \"{field_name}\"")
                    rendered_relationships.add(rel_id)

                # Possible junction table: id + at least 2 foreign keys
                if len(fields) >= 3:
                    foreign_key_fields = [f for f in fields.values() if f.get("is_foreign_key", False)]

                    if len(foreign_key_fields) >= 2:
                        for i, fk1 in enumerate(foreign_key_fields):
                            for fk2 in foreign_key_fields[i + 1:]:
                                if not fk1.get("foreign_key_reference") or not fk2.get("foreign_key_reference"):
                                    continue

                                target1 = fk1["foreign_key_reference"].split(".")[0]
                                target2 = fk2["foreign_key_reference"].split(".")[0]

                                # Skip self-references
                                if target1 == target2:
                                    continue

                                rel_id1 = f"{target1}_{target2}_m2m"
                                rel_id2 = f"{target2}_{target1}_m2m"
                                if rel_id1 in rendered_relationships or rel_id2 in rendered_relationships:
                                    continue

                                # Many-to-many edge directly between the end entities
                                lines.append(f" {target1} }}o--o{{ {target2} : \"many-to-many\"")
                                rendered_relationships.add(rel_id1)

            except Exception as e:
                lines.append(f" %% Error processing relationships for {model_name}: {str(e)}")

        return "\n".join(lines)

    @staticmethod
    def _mermaid_safe(type_name: str) -> str:
        """Replace characters Mermaid rejects in attribute types; empty becomes "any"."""
        import re

        return re.sub(r"[^A-Za-z0-9_\-\[\]()]", "_", type_name) or "any"

    def _simplify_type_for_mermaid(self, type_str: str) -> str:
        """
        Simplify a Python type string for Mermaid ER diagram display.

        ``List[X]``/``list[X]`` become ``X[]`` (with ``X`` simplified too), other
        generic/union types collapse to a generic name (``object``, ``string``,
        ``number``, ``boolean`` or ``any``) by whole-word matching, and simple
        names are mapped through a small table. The result only contains
        characters valid in a Mermaid attribute type.

        Args:
            type_str: Python type string to simplify

        Returns:
            A simplified type string suitable for Mermaid diagrams
        """
        import re

        # Remove common Python type prefixes
        simplified = type_str.replace("typing.", "").replace("__main__.", "")
        # Unwrap forward references such as ForwardRef('Tag') -> Tag
        simplified = re.sub(r"ForwardRef\(['\"]([^'\"]*)['\"]\)", r"\1", simplified).strip()

        # List containers: keep the (proper-cased) inner type name
        for prefix in ("List[", "list["):
            if simplified.startswith(prefix) and simplified.endswith("]"):
                inner_type = simplified[len(prefix):-1].strip().strip("'\"")
                return f"{self._simplify_type_for_mermaid(inner_type)}[]"

        # Other complex types that might confuse Mermaid: match whole words so
        # e.g. "Point" is not mistaken for "int", and dicts win over their keys.
        if "[" in simplified or "|" in simplified or "Union" in simplified or "Optional" in simplified:
            tokens = set(re.findall(r"[A-Za-z_][A-Za-z0-9_]*", simplified))
            if tokens & {"dict", "Dict"}:
                return "object"
            if "str" in tokens:
                return "string"
            if tokens & {"int", "float"}:
                return "number"
            if "bool" in tokens:
                return "boolean"
            return "any"

        # Simple type mapping to more Mermaid-friendly types
        type_map = {
            "str": "string",
            "int": "number",
            "float": "number",
            "bool": "boolean",
            "dict": "object",
            "Dict": "object",
            "datetime": "date",
            "date": "date",
            "time": "time",
            "bytes": "binary",
            "bytearray": "binary"
        }

        return self._mermaid_safe(type_map.get(simplified, simplified))

    def _get_model_name_for_table(self, table_name: str) -> str:
        """
        Get the proper cased model name for a table name.

        An exact table-name match is preferred, then a case-insensitive one.

        Args:
            table_name: The table name to find the model name for

        Returns:
            The properly cased model name, or ``table_name`` if no model owns it
        """
        model_name = self._find_model_for_table(table_name)
        return model_name if model_name is not None else table_name

    def _format_field_attributes(self, field_info: Dict[str, Any]) -> str:
        """
        Format field attributes according to Mermaid syntax.

        Keys come first (``PK`` then ``FK``), separated by commas as Mermaid's
        ER grammar requires (e.g. ``PK, FK``); an optional quoted comment
        (``"virtual"`` or ``"required"``) comes last.

        Args:
            field_info: Dictionary with field information

        Returns:
            String with properly formatted attributes for the Mermaid diagram
            (with a leading space), or an empty string when there are none.
        """
        keys = []
        if field_info.get("is_primary", False):
            keys.append("PK")
        if field_info.get("is_foreign_key", False):
            keys.append("FK")

        # The comment attribute (must be last and in quotes)
        comment = None
        if field_info.get("is_virtual", False):
            comment = "virtual"
        elif field_info.get("is_required", False):
            comment = "required"

        result = ""
        if keys:
            result = " " + ", ".join(keys)
        if comment:
            result += f' "{comment}"'
        return result

    def mermaid(self) -> str:
        """
        Generate a Mermaid ER diagram for all registered models.

        Returns:
            String containing Mermaid ER diagram markup in markdown format.
        """
        content = self._generate_mermaid_content()
        # Wrap in markdown code fences
        return f"```mermaid\n{content}\n```"

    def mermaid_link(self) -> str:
        """
        Generate a Mermaid Live Editor link for the ER diagram.

        The editor state (diagram code plus a default theme) is JSON-encoded,
        zlib-compressed and base64-encoded into a ``#pako:`` URL fragment.

        Returns:
            URL string that opens the diagram in Mermaid Live Editor.
        """
        import base64
        import json
        import zlib

        def js_btoa(data: bytes) -> bytes:
            return base64.b64encode(data)

        def pako_deflate(data: bytes) -> bytes:
            # zlib-wrapped deflate stream (wbits=15) at maximum compression
            compressor = zlib.compressobj(9, zlib.DEFLATED, 15, 8, zlib.Z_DEFAULT_STRATEGY)
            return compressor.compress(data) + compressor.flush()

        def gen_pako_link(graph_markdown: str) -> str:
            state = {"code": graph_markdown, "mermaid": {"theme": "default"}}
            byte_str = json.dumps(state).encode('utf-8')
            encoded = js_btoa(pako_deflate(byte_str))
            return 'https://mermaid.live/edit#pako:' + encoded.decode('ascii')

        # Use the raw Mermaid content (without markdown fences)
        return gen_pako_link(self._generate_mermaid_content())

    # Alias for backward compatibility
    generate_mermaid_er_diagram = mermaid