sqlobjects 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlobjects/metadata.py ADDED
@@ -0,0 +1,1130 @@
1
+ """SQLObjects Metadata Module - Model metaclass and type inference"""
2
+
3
+ import re
4
+ from dataclasses import dataclass, field
5
+ from typing import TYPE_CHECKING, Any, Protocol, Union, get_args, get_origin
6
+
7
+ from sqlalchemy import CheckConstraint, Index, UniqueConstraint
8
+ from sqlalchemy import MetaData as SqlAlchemyMetaData
9
+
10
+ from .fields import Auto, Column
11
+ from .relations import M2MTable, RelationshipDescriptor, RelationshipResolver
12
+ from .utils.naming import to_snake_case
13
+ from .utils.pattern import pluralize
14
+
15
+
16
+ if TYPE_CHECKING:
17
+ from .model import ObjectModel
18
+
19
# Names exported as this module's public API.
__all__ = [
    "ModelProcessor",
    "ModelRegistry",
    "ModelConfig",
    "M2MTable",
    "index",
    "constraint",
    "unique",
]


# Matches identifier-like words. Used to pull candidate field names out of the
# SQL text of an unnamed check constraint when a name is generated for it.
_FIELD_NAME_PATTERN = re.compile(r"\b([a-zA-Z_][a-zA-Z0-9_]*)\b")
31
+
32
+
33
+ # Internal Protocol types for type safety
34
+ class _FieldProtocol(Protocol):
35
+ """Internal protocol for field objects."""
36
+
37
+ column: Any
38
+
39
+
40
@dataclass
class _RawModelConfig:
    """Model configuration as parsed, before defaults are filled in.

    Every field is optional at this stage: the name fields default to
    ``None`` and the collection fields default to fresh empty containers.
    """

    # Naming (may still be unset at parse time)
    table_name: str | None = None
    verbose_name: str | None = None
    verbose_name_plural: str | None = None

    # Query defaults and table-level schema objects
    ordering: list[str] = field(default_factory=list)
    indexes: list[Index] = field(default_factory=list)
    constraints: list[CheckConstraint | UniqueConstraint] = field(default_factory=list)

    # Free-form description and extension points
    description: str | None = None
    db_options: dict[str, dict[str, Any]] = field(default_factory=dict)
    custom: dict[str, Any] = field(default_factory=dict)
53
+
54
+
55
@dataclass
class ModelConfig:
    """Fully resolved configuration of a model.

    Unlike the raw parsing-phase configuration, the three name fields are
    required here and have no default.

    Attributes:
        table_name: Database table name.
        verbose_name: Human-readable singular name of the model.
        verbose_name_plural: Human-readable plural name of the model.
        ordering: Default query ordering (e.g., ['-created_at', 'name']).
        indexes: Indexes to create for the table.
        constraints: Check and unique constraints for the table.
        description: Optional longer description of the model.
        db_options: Database-specific options keyed by dialect name.
        custom: Application-specific configuration values.
        field_validators: Validators keyed by field name.
        field_metadata: Per-field metadata keyed by field name.
    """

    # Required naming information
    table_name: str
    verbose_name: str
    verbose_name_plural: str

    # Query defaults and table-level schema objects
    ordering: list[str] = field(default_factory=list)
    indexes: list[Index] = field(default_factory=list)
    constraints: list[CheckConstraint | UniqueConstraint] = field(default_factory=list)

    # Free-form description and extension points
    description: str | None = None
    db_options: dict[str, dict[str, Any]] = field(default_factory=dict)
    custom: dict[str, Any] = field(default_factory=dict)

    # Filled in from field-level configuration
    field_validators: dict[str, list[Any]] = field(default_factory=dict)
    field_metadata: dict[str, dict[str, Any]] = field(default_factory=dict)
88
+
89
+
90
def _infer_type_from_annotation(annotation) -> tuple[str, dict[str, Any]]:
    """Derive a type name and parameters from a ``Column[T]`` annotation.

    Args:
        annotation: Type annotation to inspect

    Returns:
        Tuple of (type_name, parameters_dict); ``("string", {})`` when the
        annotation is not a parameterized ``Column``
    """
    fallback: tuple[str, dict[str, Any]] = ("string", {})

    if get_origin(annotation) is not Column:
        return fallback

    type_args = get_args(annotation)
    if not type_args:
        return fallback

    # Only the first type argument describes the column's Python type.
    return _map_python_type_to_sqlalchemy(type_args[0])
106
+
107
+
108
+ def _map_python_type_to_sqlalchemy(python_type) -> tuple[str, dict[str, Any]]:
109
+ """Map Python types to SQLAlchemy type names and parameters.
110
+
111
+ Args:
112
+ python_type: Python type to map
113
+
114
+ Returns:
115
+ Tuple of (type_name, parameters_dict)
116
+ """
117
+ # Handle Optional[T] -> Union[T, None]
118
+ if get_origin(python_type) is Union:
119
+ union_args = get_args(python_type)
120
+ if len(union_args) == 2 and type(None) in union_args:
121
+ # Optional[T] case
122
+ non_none_type = union_args[0] if union_args[1] is type(None) else union_args[1]
123
+ type_name, params = _map_python_type_to_sqlalchemy(non_none_type)
124
+ params["nullable"] = True
125
+ return type_name, params
126
+
127
+ # Handle list[T] -> ARRAY
128
+ if get_origin(python_type) is list:
129
+ list_args = get_args(python_type)
130
+ if list_args:
131
+ item_type_name, _ = _map_python_type_to_sqlalchemy(list_args[0])
132
+ return "array", {"item_type": item_type_name}
133
+
134
+ # Use Python type name directly as SQLAlchemy type name
135
+ # Registry alias system will handle mapping in create_type_instance()
136
+ type_name = python_type.__name__ if hasattr(python_type, "__name__") else "string"
137
+ return type_name, {}
138
+
139
+
140
class ModelRegistry(SqlAlchemyMetaData):
    """Unified registry for models, tables, and relationships.

    Extends SQLAlchemy's MetaData with bookkeeping for model classes,
    their relationship descriptors, and many-to-many association tables.

    Features:
    - Model class registration and lookup (by class name or table name)
    - Relationship resolution
    - Deferred M2M table creation
    """

    def __init__(self, bind=None, schema=None, quote_schema=None, naming_convention=None, info=None):
        """Initialize ModelRegistry with SQLAlchemy MetaData configuration.

        Args:
            bind: Database engine or connection; stored as ``self.bind`` when given
            schema: Default schema name
            quote_schema: Whether to quote schema names
            naming_convention: Naming convention for constraints
            info: Additional metadata information
        """
        super().__init__(schema=schema, quote_schema=quote_schema, naming_convention=naming_convention, info=info)
        # ``bind`` is not forwarded to MetaData; it is only kept as an attribute.
        if bind is not None:
            self.bind = bind

        # Model management: class name -> model, table name -> model
        self._models: dict[str, type[ObjectModel]] = {}
        self._table_to_model: dict[str, type[ObjectModel]] = {}

        # Relationship management: class name -> {relationship name -> descriptor}
        self._relationships: dict[str, dict[str, RelationshipDescriptor]] = {}
        self._resolved: bool = False

        # M2M table management: created definitions by table name, plus the
        # definitions still waiting for their models/tables
        self._m2m_tables: dict[str, M2MTable] = {}
        self._pending_m2m: list[M2MTable] = []

    # Model registration
    def register_model(self, model_class: type["ObjectModel"]) -> None:
        """Register model class with its table and relationships.

        Args:
            model_class: Model class to register
        """
        self._models[model_class.__name__] = model_class

        table = getattr(model_class, "__table__", None)
        if table is not None:
            self._table_to_model[table.name] = model_class

        relationships = getattr(model_class, "_relationships", None)
        if relationships is not None:
            self._relationships[model_class.__name__] = relationships
            self._resolved = False  # Mark for re-resolution

    def get_model(self, name: str) -> type["ObjectModel"] | None:
        """Get model class by name.

        Args:
            name: Model class name

        Returns:
            Model class or None if not found
        """
        return self._models.get(name)

    def get_model_by_table(self, table_name: str) -> type["ObjectModel"] | None:
        """Get model class by table name.

        Args:
            table_name: Database table name

        Returns:
            Model class or None if not found
        """
        return self._table_to_model.get(table_name)

    def list_models(self) -> list[type["ObjectModel"]]:
        """Get all registered models.

        Returns:
            List of all registered model classes
        """
        return list(self._models.values())

    # Relationship resolution
    def resolve_all_relationships(self) -> None:
        """Resolve all registered relationships.

        Does nothing when relationships were already resolved and no model
        with relationships has been registered since.
        """
        if self._resolved:
            return

        for relationships in self._relationships.values():
            for descriptor in relationships.values():
                self._resolve_relationship(descriptor)

        self._resolved = True

    def _resolve_relationship(self, descriptor: "RelationshipDescriptor") -> None:
        """Resolve single relationship.

        A string ``argument`` is looked up among the registered models and,
        when found, stored as ``resolved_model``. The relationship type is
        then (re)computed by ``RelationshipResolver``.

        Args:
            descriptor: Relationship descriptor to resolve
        """
        prop = descriptor.property
        target = prop.argument

        if isinstance(target, str):
            related_model = self._models.get(target)
            if related_model:
                prop.resolved_model = related_model

        prop.relationship_type = RelationshipResolver.resolve_relationship_type(prop)

    # M2M table management
    def register_m2m_table(self, m2m_def: "M2MTable") -> None:
        """Register M2M table for delayed creation.

        Args:
            m2m_def: M2M table definition to register
        """
        self._pending_m2m.append(m2m_def)

    def process_pending_m2m(self) -> None:
        """Process all pending M2M table registrations.

        Creates tables for the pending definitions whose two models (and
        their tables) are available. Definitions that cannot be created yet
        stay in the pending list so a later call can retry them.
        """
        still_pending = [m2m_def for m2m_def in self._pending_m2m if not self._create_m2m_table(m2m_def)]
        # Replace contents in place so the list object itself is preserved.
        self._pending_m2m[:] = still_pending

    def _create_m2m_table(self, m2m_def: "M2MTable") -> bool:
        """Create M2M table from definition.

        Args:
            m2m_def: M2M table definition to create

        Returns:
            True when the table was created, False when a model or its table
            is not available yet and the definition must stay pending
        """
        left_model = self.get_model(m2m_def.left_model)
        right_model = self.get_model(m2m_def.right_model)
        if not left_model or not right_model:
            return False  # Keep in pending

        left_table = getattr(left_model, "__table__", None)
        right_table = getattr(right_model, "__table__", None)
        if left_table is None or right_table is None:
            return False  # Keep in pending

        m2m_def.create_table(self, left_table, right_table)
        self._m2m_tables[m2m_def.table_name] = m2m_def
        return True

    def get_m2m_table(self, table_name: str) -> Any | None:
        """Get M2M table by name.

        Args:
            table_name: Name of the M2M table

        Returns:
            Entry of ``self.tables`` under that name, or None if not found
        """
        return self.tables.get(table_name)

    def get_m2m_definition(self, table_name: str) -> Union["M2MTable", None]:
        """Get M2M table definition by name.

        Args:
            table_name: Name of the M2M table

        Returns:
            M2MTable definition or None if not found
        """
        return self._m2m_tables.get(table_name)
323
+
324
+
325
class ModelProcessor(type):
    """Metaclass that processes SQLObjects model definitions.

    Creating a class through this metaclass runs, in order:
    - Type inference for ``Auto`` columns from annotations
    - Configuration parsing
    - Integration of field-level indexes, validators and metadata
    - Generation of ``__init__`` / ``__repr__`` / ``__eq__``
    - Table construction plus index and constraint name normalization
    - Model registration and processing of pending M2M tables

    Classes whose own ``__abstract__`` attribute is truthy only receive the
    shared registry; no table is built for them.
    """

    def __new__(mcs, name, bases, namespace, **kwargs):
        """Create new model class with complete processing pipeline.

        Args:
            name: Class name
            bases: Base classes
            namespace: Class namespace
            **kwargs: Additional keyword arguments

        Returns:
            Processed model class
        """
        # Reuse the registry of the first base that has one, else create one.
        registry = next((base.__registry__ for base in bases if hasattr(base, "__registry__")), None)
        if registry is None:
            registry = ModelRegistry()

        # Handle type inference for columns declared with the Auto placeholder
        for field_name, annotation in namespace.get("__annotations__", {}).items():
            field_value = namespace.get(field_name)
            if field_value is None or not hasattr(field_value, "column"):
                continue

            core_column = field_value.column
            if not isinstance(core_column.type, Auto):
                continue

            inferred_type, inferred_params = _infer_type_from_annotation(annotation)

            from .fields import create_type_instance

            # Replace the placeholder with a concrete type instance
            core_column.type = create_type_instance(inferred_type, inferred_params)

            # If nullable=True inferred, update column attribute
            if inferred_params.get("nullable"):
                core_column.nullable = True

        # Create class
        cls = super().__new__(mcs, name, bases, namespace, **kwargs)

        # Set shared registry
        cls.__registry__ = registry  # type: ignore[reportAttributeAccessIssue]

        # Abstract classes stop here; only the class's own __dict__ is checked,
        # so the flag is not inherited by subclasses.
        if cls.__dict__.get("__abstract__", False):
            return cls

        # Parse configuration
        config = _parse_model_config(cls)

        # Set default ordering
        if config.ordering:
            cls._default_ordering = config.ordering  # type: ignore[reportAttributeAccessIssue]

        # Integrate field-level config into model config
        config = mcs._integrate_field_config(cls, config)

        # Register validators
        mcs._register_field_validators(cls, config)

        # Apply dataclass functionality
        cls = mcs._apply_dataclass_functionality(cls)

        # Build table
        table = mcs._build_table(cls, config, registry)
        cls.__table__ = table

        # Initialize field cache (after table creation)
        mcs._initialize_field_cache(cls)

        # Normalize index and constraint names after table construction
        mcs._post_process_table_indexes(table, config.table_name)
        mcs._post_process_table_constraints(table, config.table_name)

        # Auto-register model to ModelRegistry
        from typing import cast

        from .model import ObjectModel

        registry.register_model(cast(type[ObjectModel], cls))

        # Process pending M2M tables
        registry.process_pending_m2m()

        return cls

    @classmethod
    def _integrate_field_config(mcs, cls: Any, config: ModelConfig) -> ModelConfig:
        """Integrate field-level configuration into model configuration.

        Args:
            cls: Model class
            config: Current model configuration (updated in place)

        Returns:
            The same configuration object with field-level settings integrated
        """
        field_indexes, field_validators, field_metadata = mcs._collect_all_field_config(cls, config.table_name)

        # Merge indexes (avoid duplicates)
        config.indexes = mcs._merge_indexes(field_indexes, config.indexes, config.table_name)

        # Set validators and metadata
        config.field_validators = field_validators
        config.field_metadata = field_metadata

        return config

    @classmethod
    def _collect_all_field_config(
        mcs, cls: Any, table_name: str
    ) -> tuple[list[Index], dict[str, list[Any]], dict[str, dict[str, Any]]]:
        """Collect indexes, validators and metadata of all fields in one pass.

        Args:
            cls: Model class
            table_name: Database table name

        Returns:
            Tuple of (indexes, validators, metadata)

        Raises:
            RuntimeError: If the fields cannot be listed, or a field fails
                with anything other than AttributeError
        """
        indexes = []
        validators = {}
        metadata = {}

        try:
            fields = mcs._get_fields(cls)
        except Exception as e:
            raise RuntimeError(f"Failed to get fields for {cls.__name__}: {e}") from e

        for name, field_def in fields.items():
            try:
                if not hasattr(field_def, "column"):
                    continue

                column = field_def.column

                # Collect indexes (primary key columns never get one here).
                if not getattr(column, "primary_key", False):
                    # A string index column is looked up among the table's
                    # columns, so reference the column by its own key/name;
                    # the attribute name is only the fallback for columns
                    # that have not been named yet.
                    column_ref = getattr(column, "key", None) or getattr(column, "name", None) or name
                    index_name = f"idx_{table_name}_{name}"
                    if getattr(column, "unique", False):
                        indexes.append(Index(index_name, column_ref, unique=True))
                    elif getattr(column, "index", False):
                        indexes.append(Index(index_name, column_ref))

                # Collect validators
                if column.info:
                    field_validators = column.info.get("_enhanced", {}).get("validators")
                    if field_validators:
                        validators[name] = field_validators

                # Collect metadata
                field_meta = {}

                # Basic metadata, only when set
                if hasattr(column, "comment") and column.comment:
                    field_meta["comment"] = column.comment
                if hasattr(column, "doc") and column.doc:
                    field_meta["doc"] = column.doc

                # Type information, always present
                field_meta["type"] = str(column.type)
                field_meta["nullable"] = getattr(column, "nullable", True)
                field_meta["primary_key"] = getattr(column, "primary_key", False)
                field_meta["unique"] = getattr(column, "unique", False)

                # Extended parameters, only the non-empty groups
                if column.info:
                    for meta_key, info_key in (
                        ("enhanced", "_enhanced"),
                        ("performance", "_performance"),
                        ("codegen", "_codegen"),
                    ):
                        group = column.info.get(info_key, {})
                        if group:
                            field_meta[meta_key] = group

                metadata[name] = field_meta

            except AttributeError:
                # Field missing expected attributes, skip silently
                continue
            except Exception as e:
                raise RuntimeError(f"Error processing field {name} in {cls.__name__}: {e}") from e

        return indexes, validators, metadata

    @classmethod
    def _merge_indexes(mcs, field_indexes: list[Index], table_indexes: list[Index], table_name: str) -> list[Index]:
        """Merge field-level and table-level indexes, avoiding duplicates.

        Two indexes count as duplicates when they share the same signature:
        the sorted column strings plus the ``unique`` flag. On a duplicate the
        table-level index wins. Field-level indexes without a signature are
        dropped; table-level indexes are always kept.

        Args:
            field_indexes: Indexes generated from field definitions
            table_indexes: Indexes defined at table level
            table_name: Database table name

        Returns:
            Merged list of indexes with normalized names
        """

        def signature_of(idx):  # noqa
            if hasattr(idx, "_columns") and idx._columns:  # noqa
                columns = tuple(sorted(str(col) for col in idx._columns))  # noqa
                return (columns, idx.unique)  # noqa
            return None

        # Signatures already covered at table level
        table_signatures = {sig for sig in map(signature_of, table_indexes) if sig}

        # Field-level indexes not covered at table level, then table-level ones
        merged_indexes = [
            idx for idx in field_indexes if (sig := signature_of(idx)) and sig not in table_signatures
        ]
        merged_indexes.extend(table_indexes)

        # Normalize all index naming format
        return mcs._normalize_all_indexes(merged_indexes, table_name)

    @classmethod
    def _normalize_all_indexes(mcs, indexes: list[Index], table_name: str) -> list[Index]:
        """Force uniform naming format for all indexes.

        Indexes are renamed in place to ``idx_<table>_<col>[_<col>...]``.
        An index for which no column names can be determined is passed
        through unchanged.

        Args:
            indexes: List of indexes to normalize
            table_name: Database table name

        Returns:
            List of the same index objects, in the same order
        """
        normalized_indexes = []

        for idx in indexes:
            if hasattr(idx, "columns") and idx.columns:
                field_names = "_".join(col.name for col in idx.columns)  # noqa
            elif hasattr(idx, "_columns") and idx._columns:  # noqa
                # Handle indexes not yet bound to table
                field_names = "_".join(str(col).split(".")[-1] for col in idx._columns)  # noqa
            else:
                field_names = None

            if field_names is not None:
                # Directly modify index name (instead of rebuilding)
                idx.name = f"idx_{table_name}_{field_names}"  # type: ignore[reportAttributeAccessIssue]

            normalized_indexes.append(idx)

        return normalized_indexes

    @classmethod
    def _register_field_validators(mcs, cls: Any, config: ModelConfig) -> None:
        """Register field-level validators to model class.

        Sets ``cls._field_validators`` only when there is at least one validator.

        Args:
            cls: Model class
            config: Model configuration containing validators
        """
        if config.field_validators:
            cls._field_validators = config.field_validators

    @classmethod
    def _build_table(mcs, cls: Any, config: ModelConfig, registry):
        """Build SQLAlchemy Core Table and integrate configuration.

        Args:
            cls: Model class
            config: Model configuration
            registry: Model registry used as the table's metadata

        Returns:
            SQLAlchemy Table instance
        """
        from sqlalchemy import Table

        # Collect column definitions; unnamed columns take the attribute name
        columns = []
        for name, field_def in mcs._get_fields(cls).items():
            if not hasattr(field_def, "column"):
                continue
            column = field_def.column
            if column.name is None:
                column.name = name
            columns.append(column)

        # Indexes and constraints (already integrated into the config)
        table_args = [*config.indexes, *config.constraints]

        # Database-specific options: "generic" entries are passed through as
        # they are, all others are prefixed with "<db_name>_".
        table_kwargs = {}
        if config.db_options:
            for db_name, options in config.db_options.items():
                if db_name == "generic":
                    table_kwargs.update(options)
                    continue
                for key, value in options.items():
                    table_kwargs[f"{db_name}_{key}"] = value

        return Table(config.table_name, registry, *columns, *table_args, **table_kwargs)

    @classmethod
    def _post_process_table_indexes(mcs, table, table_name: str) -> None:
        """Normalize index names after table construction.

        Args:
            table: SQLAlchemy Table instance
            table_name: Database table name
        """
        for idx in table.indexes:
            if hasattr(idx, "columns") and idx.columns:
                field_names = "_".join(col.name for col in idx.columns)
                idx.name = f"idx_{table_name}_{field_names}"

    @classmethod
    def _post_process_table_constraints(mcs, table, table_name: str) -> None:
        """Give names to unnamed check and unique constraints.

        Constraints that already have a name are left untouched.

        Args:
            table: SQLAlchemy Table instance
            table_name: Database table name
        """
        for cst in table.constraints:
            if cst.name is not None:
                continue

            if isinstance(cst, CheckConstraint):
                # Use at most the first two identifier-like words of the
                # condition text as the distinguishing part of the name.
                field_matches = _FIELD_NAME_PATTERN.findall(str(cst.sqltext))
                if field_matches:
                    field_part = "_".join(field_matches[:2])
                    cst.name = f"ck_{table_name}_{field_part}"
                else:
                    cst.name = f"ck_{table_name}_constraint"
            elif isinstance(cst, UniqueConstraint) and hasattr(cst, "columns"):
                field_names = "_".join(col.name for col in cst.columns)
                cst.name = f"uq_{table_name}_{field_names}"

    @classmethod
    def _apply_dataclass_functionality(mcs, cls: Any) -> Any:
        """Apply dataclass functionality to model class.

        Args:
            cls: Model class to enhance

        Returns:
            The same class, with generated methods when any field provides
            code generation parameters
        """
        # Collect codegen parameters of every field that offers them
        field_configs = {}
        for name, field_def in mcs._get_fields(cls).items():
            if not hasattr(field_def, "column"):
                continue
            column_attr = getattr(cls, name)
            if hasattr(column_attr, "get_codegen_params"):
                field_configs[name] = column_attr.get_codegen_params()

        if field_configs:
            mcs._generate_dataclass_methods(cls, field_configs)

        return cls

    @classmethod
    def _generate_dataclass_methods(mcs, cls: Any, field_configs: dict) -> None:
        """Generate dataclass-style methods and markers.

        Args:
            cls: Model class
            field_configs: Field configuration dictionary
        """
        mcs._generate_init_method(cls, field_configs)
        mcs._generate_repr_method(cls, field_configs)
        mcs._generate_eq_method(cls, field_configs)

        # Set dataclass compatibility markers
        cls.__dataclass_fields__ = dict.fromkeys(field_configs.keys(), True)
        cls.__dataclass_params__ = {
            "init": True,
            "repr": True,
            "eq": True,
            "order": False,
            "unsafe_hash": False,
            "frozen": False,
        }
        cls.__dataclass_transform__ = True

    @classmethod
    def _generate_init_method(mcs, cls: Any, field_configs: dict) -> None:
        """Generate __init__ method with support for defaults and default_factory.

        The generated ``__init__`` accepts keyword arguments only, and only
        for fields whose config has ``init`` true (the default). Nothing is
        generated when there is no such field.

        Args:
            cls: Model class
            field_configs: Field configuration dictionary
        """
        init_fields = [name for name, config in field_configs.items() if config.get("init", True)]
        if not init_fields:
            return

        # Collect field defaults and factory functions
        field_defaults = {}
        field_factories = {}

        for name in init_fields:
            field_attr = getattr(cls, name)
            column = getattr(field_attr, "column", None)
            if column is None:
                continue

            # A usable default_factory takes precedence over a column default
            if hasattr(field_attr, "get_default_factory"):
                factory = field_attr.get_default_factory()
                if factory and callable(factory):
                    field_factories[name] = factory
                    continue

            # Column default: take its ``arg`` when it exposes one. Defaults
            # without ``arg`` are ignored.
            default = column.default
            if default is not None and hasattr(default, "arg"):
                field_defaults[name] = default.arg

        allowed_names = frozenset(init_fields)

        def __init__(self, **kwargs):
            # Call parent __init__
            super(cls, self).__init__()

            # Only allow init=True fields as parameters
            for key in kwargs:
                if key not in allowed_names:
                    raise TypeError(f"{cls.__name__}.__init__() got an unexpected keyword argument '{key}'")

            # Explicit value first, then factory, then static default;
            # fields with none of these are left unset.
            for field_name in init_fields:
                if field_name in kwargs:
                    setattr(self, field_name, kwargs[field_name])
                elif field_name in field_factories:
                    setattr(self, field_name, field_factories[field_name]())
                elif field_name in field_defaults:
                    setattr(self, field_name, field_defaults[field_name])

        cls.__init__ = __init__

    @classmethod
    def _generate_repr_method(mcs, cls: Any, field_configs: dict) -> None:
        """Generate __repr__ method.

        Nothing is generated when no field has ``repr`` true (the default).

        Args:
            cls: Model class
            field_configs: Field configuration dictionary
        """
        repr_fields = [name for name, config in field_configs.items() if config.get("repr", True)]
        if not repr_fields:
            return

        def __repr__(self):
            field_strs = []
            for field_name in repr_fields:
                try:
                    value = getattr(self, field_name, None)
                    field_strs.append(f"{field_name}={value!r}")
                except AttributeError:
                    continue
            return f"{cls.__name__}({', '.join(field_strs)})"

        cls.__repr__ = __repr__

    @classmethod
    def _generate_eq_method(mcs, cls: Any, field_configs: dict) -> None:
        """Generate __eq__ method that prefers primary keys.

        Nothing is generated when no field has ``compare`` true (default false).

        Args:
            cls: Model class
            field_configs: Field configuration dictionary
        """
        compare_fields = [name for name, config in field_configs.items() if config.get("compare", False)]
        if not compare_fields:
            return

        # Identify primary key fields among the compared fields
        pk_fields = []
        for name in compare_fields:
            column = getattr(getattr(cls, name), "column", None)
            if column is not None and column.primary_key:
                pk_fields.append(name)

        def __eq__(self, other):
            if not isinstance(other, cls):
                return NotImplemented

            if pk_fields:
                pk_values = [getattr(self, name, None) for name in pk_fields]
                other_pk_values = [getattr(other, name, None) for name in pk_fields]
                all_pk_values = pk_values + other_pk_values

                # All primary keys set on both sides: they alone decide
                if all(v is not None for v in all_pk_values):
                    return pk_values == other_pk_values

                # Some set, some not: not equal
                if any(v is not None for v in all_pk_values):
                    return False

            # No usable primary keys: compare all compare=True fields
            for field_name in compare_fields:
                try:
                    if getattr(self, field_name, None) != getattr(other, field_name, None):
                        return False
                except AttributeError:
                    return False
            return True

        cls.__eq__ = __eq__

    @classmethod
    def _initialize_field_cache(mcs, cls: Any) -> None:
        """Initialize the per-class field cache.

        ``cls._field_cache`` maps ``"deferred_fields"``,
        ``"relationship_fields"`` and ``"regular_fields"`` to sets of names.

        Args:
            cls: Model class
        """
        deferred_fields: set = set()
        relationship_fields: set = set()
        regular_fields: set = set()
        cls._field_cache = {
            "deferred_fields": deferred_fields,
            "relationship_fields": relationship_fields,
            "regular_fields": regular_fields,
        }

        # Classify table columns as deferred or regular
        if hasattr(cls, "__table__"):
            for col_name in cls.__table__.columns.keys():
                try:
                    attr = getattr(cls, col_name, None)
                    if not (attr and hasattr(attr, "column") and attr.column is not None):
                        # No usable field attribute under this name: not cached
                        continue

                    is_deferred = False
                    if hasattr(attr.column, "info") and attr.column.info is not None:
                        performance_params = attr.column.info.get("_performance", {})
                        is_deferred = bool(performance_params.get("deferred", False))

                    (deferred_fields if is_deferred else regular_fields).add(col_name)
                except (AttributeError, TypeError):
                    regular_fields.add(col_name)

        # Relationship names
        if hasattr(cls, "_relationships"):
            relationships = getattr(cls, "_relationships", {})
            relationship_fields.update(relationships.keys())

    @classmethod
    def _get_fields(mcs, cls: Any) -> dict[str, _FieldProtocol]:
        """Get class field definitions.

        A field is any attribute listed by ``dir(cls)`` whose name does not
        start with an underscore and whose value has a ``column`` attribute.

        Args:
            cls: Model class

        Returns:
            Dictionary of field name to field definition

        Raises:
            RuntimeError: If reading an attribute fails with anything other
                than AttributeError
        """
        fields = {}
        for name in dir(cls):
            if name.startswith("_"):
                continue
            try:
                attr = getattr(cls, name)
                if hasattr(attr, "column"):
                    fields[name] = attr
            except AttributeError:
                # Attribute not accessible, skip silently
                continue
            except Exception as e:
                raise RuntimeError(f"Unexpected error accessing {name} on {cls.__name__}: {e}") from e
        return fields
938
+
939
+
940
def _parse_model_config(model_class: Any) -> ModelConfig:
    """Build the full configuration for a model class.

    Reads the optional inner ``Config`` class of ``model_class``; when it is
    missing (or falsy) an empty raw configuration is used instead. The raw
    configuration is then completed by ``_fill_config_defaults``.

    Args:
        model_class: Model class whose configuration is being resolved

    Returns:
        ModelConfig produced by ``_fill_config_defaults``
    """
    inner_config = getattr(model_class, "Config", None)
    raw = _parse_config_class(inner_config) if inner_config else _RawModelConfig()
    return _fill_config_defaults(raw, model_class)
956
+
957
+
958
def _parse_config_class(config_class: type) -> _RawModelConfig:
    """Copy the recognised attributes of a ``Config`` inner class.

    Attributes absent from ``config_class`` fall back to ``None`` (names and
    description), an empty list (ordering, indexes, constraints) or an empty
    dict (db_options, custom).

    Args:
        config_class: The Config inner class to read from

    Returns:
        _RawModelConfig populated from ``config_class``
    """
    parsed = _RawModelConfig()

    # Optional scalar settings: table name and descriptive metadata.
    for scalar_attr in ("table_name", "verbose_name", "verbose_name_plural", "description"):
        setattr(parsed, scalar_attr, getattr(config_class, scalar_attr, None))

    # List-valued settings: ordering, indexes and constraints.
    for list_attr in ("ordering", "indexes", "constraints"):
        setattr(parsed, list_attr, getattr(config_class, list_attr, []))

    # Dict-valued settings: database-specific and custom options.
    for dict_attr in ("db_options", "custom"):
        setattr(parsed, dict_attr, getattr(config_class, dict_attr, {}))

    return parsed
991
+
992
+
993
def _fill_config_defaults(config: _RawModelConfig, model_class: Any) -> ModelConfig:
    """Turn a raw configuration into a ModelConfig, deriving missing names.

    Derivation rules, applied only when the raw value is ``None``:

    * ``table_name``: pluralized snake_case form of the class name
    * ``verbose_name``: the class name
    * ``verbose_name_plural``: pluralized form of the resolved verbose name

    Args:
        config: _RawModelConfig holding the parsed values
        model_class: Model class whose ``__name__`` seeds the derived names

    Returns:
        ModelConfig with the three names resolved, the remaining raw values
        passed through, and empty ``field_validators`` / ``field_metadata``
    """
    resolved_table = config.table_name
    if resolved_table is None:
        resolved_table = pluralize(to_snake_case(model_class.__name__))

    singular = model_class.__name__ if config.verbose_name is None else config.verbose_name

    # The plural default is built from the resolved singular name.
    plural = pluralize(singular) if config.verbose_name_plural is None else config.verbose_name_plural

    return ModelConfig(
        table_name=resolved_table,
        verbose_name=singular,
        verbose_name_plural=plural,
        ordering=config.ordering,
        indexes=config.indexes,
        constraints=config.constraints,
        description=config.description,
        db_options=config.db_options,
        custom=config.custom,
        field_validators={},
        field_metadata={},
    )
1033
+
1034
+
1035
+ # Convenience functions for creating indexes and constraints
1036
+
1037
+
1038
def index(
    name: str | None = None,
    *fields: str,
    unique: bool = False,  # noqa
    postgresql_where: str | None = None,
    postgresql_using: str | None = None,
    mysql_using: str | None = None,
    **kwargs: Any,
) -> Index:
    """Build an SQLAlchemy Index from plain field names.

    Args:
        name: Index name; when omitted, ``idx_`` followed by the field names
            joined with underscores is used
        *fields: Field names as strings
        unique: Whether index should be unique
        postgresql_where: PostgreSQL WHERE clause for partial indexes
        postgresql_using: PostgreSQL index method (btree, hash, gin, gist, etc.)
        mysql_using: MySQL index method (btree, hash)
        **kwargs: Additional SQLAlchemy Index arguments

    Returns:
        SQLAlchemy Index instance

    Examples:
        >>> index("idx_users_email", "email", unique=True)
        >>> index("idx_users_name_age", "name", "age")
        >>> index("idx_users_status", "status", postgresql_where="status = 'active'")
        >>> index("idx_users_tags", "tags", postgresql_using="gin")
    """
    # The table name is unknown here, so a generated name is only a temporary
    # placeholder; actual name normalization is handled in _merge_indexes.
    index_name = name if name is not None else "idx_" + "_".join(fields)

    # Forward only the dialect options the caller actually supplied.
    explicit_options = {
        "postgresql_where": postgresql_where,
        "postgresql_using": postgresql_using,
        "mysql_using": mysql_using,
    }
    options = {key: value for key, value in explicit_options.items() if value is not None}
    options.update(kwargs)

    return Index(index_name, *fields, unique=unique, **options)
1086
+
1087
+
1088
def constraint(
    condition: str,
    name: str | None = None,
    **kwargs: Any,
) -> CheckConstraint:
    """Build an SQLAlchemy CheckConstraint from a SQL condition string.

    Args:
        condition: SQL condition expression
        name: Constraint name, passed through as given (may be None)
        **kwargs: Additional SQLAlchemy CheckConstraint arguments

    Returns:
        SQLAlchemy CheckConstraint instance

    Examples:
        >>> constraint("age >= 0", "ck_age_positive")
        >>> constraint("length(name) > 0")
        >>> constraint("price > 0 AND price < 10000")
    """
    options = {"name": name, **kwargs}
    return CheckConstraint(condition, **options)
1109
+
1110
+
1111
def unique(
    *fields: str,
    name: str | None = None,
    **kwargs: Any,
) -> UniqueConstraint:
    """Build an SQLAlchemy UniqueConstraint from plain field names.

    Args:
        *fields: Field names as strings
        name: Constraint name, passed through as given (may be None)
        **kwargs: Additional SQLAlchemy UniqueConstraint arguments

    Returns:
        SQLAlchemy UniqueConstraint instance

    Examples:
        >>> unique("email")
        >>> unique("first_name", "last_name", name="uq_full_name")
    """
    options = {"name": name, **kwargs}
    return UniqueConstraint(*fields, **options)