lsst-felis 28.2024.4500__py3-none-any.whl → 30.0.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. felis/__init__.py +9 -1
  2. felis/cli.py +308 -209
  3. felis/config/tap_schema/columns.csv +33 -0
  4. felis/config/tap_schema/key_columns.csv +8 -0
  5. felis/config/tap_schema/keys.csv +8 -0
  6. felis/config/tap_schema/schemas.csv +2 -0
  7. felis/config/tap_schema/tables.csv +6 -0
  8. felis/config/tap_schema/tap_schema_extensions.yaml +73 -0
  9. felis/datamodel.py +599 -59
  10. felis/db/{dialects.py → _dialects.py} +69 -4
  11. felis/db/{variants.py → _variants.py} +1 -1
  12. felis/db/database_context.py +917 -0
  13. felis/diff.py +234 -0
  14. felis/metadata.py +89 -19
  15. felis/tap_schema.py +271 -166
  16. felis/tests/postgresql.py +1 -1
  17. felis/tests/run_cli.py +79 -0
  18. felis/types.py +7 -7
  19. {lsst_felis-28.2024.4500.dist-info → lsst_felis-30.0.0rc3.dist-info}/METADATA +20 -16
  20. lsst_felis-30.0.0rc3.dist-info/RECORD +31 -0
  21. {lsst_felis-28.2024.4500.dist-info → lsst_felis-30.0.0rc3.dist-info}/WHEEL +1 -1
  22. felis/db/utils.py +0 -409
  23. felis/tap.py +0 -597
  24. felis/tests/utils.py +0 -122
  25. felis/version.py +0 -2
  26. lsst_felis-28.2024.4500.dist-info/RECORD +0 -26
  27. felis/{schemas → config/tap_schema}/tap_schema_std.yaml +0 -0
  28. felis/db/{sqltypes.py → _sqltypes.py} +7 -7
  29. {lsst_felis-28.2024.4500.dist-info → lsst_felis-30.0.0rc3.dist-info}/entry_points.txt +0 -0
  30. {lsst_felis-28.2024.4500.dist-info → lsst_felis-30.0.0rc3.dist-info/licenses}/COPYRIGHT +0 -0
  31. {lsst_felis-28.2024.4500.dist-info → lsst_felis-30.0.0rc3.dist-info/licenses}/LICENSE +0 -0
  32. {lsst_felis-28.2024.4500.dist-info → lsst_felis-30.0.0rc3.dist-info}/top_level.txt +0 -0
  33. {lsst_felis-28.2024.4500.dist-info → lsst_felis-30.0.0rc3.dist-info}/zip-safe +0 -0
felis/datamodel.py CHANGED
@@ -23,29 +23,42 @@
23
23
 
24
24
  from __future__ import annotations
25
25
 
26
+ import json
26
27
  import logging
28
+ import sys
27
29
  from collections.abc import Sequence
28
30
  from enum import StrEnum, auto
29
- from typing import IO, Annotated, Any, Generic, Literal, TypeAlias, TypeVar, Union
31
+ from typing import IO, Annotated, Any, Generic, Literal, TypeAlias, TypeVar
30
32
 
31
33
  import yaml
32
34
  from astropy import units as units # type: ignore
33
35
  from astropy.io.votable import ucd # type: ignore
34
36
  from lsst.resources import ResourcePath, ResourcePathExpression
35
- from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator, model_validator
37
+ from pydantic import (
38
+ BaseModel,
39
+ ConfigDict,
40
+ Field,
41
+ PrivateAttr,
42
+ ValidationError,
43
+ ValidationInfo,
44
+ field_serializer,
45
+ field_validator,
46
+ model_validator,
47
+ )
48
+ from pydantic_core import InitErrorDetails
36
49
 
37
- from .db.dialects import get_supported_dialects
38
- from .db.sqltypes import get_type_func
39
- from .db.utils import string_to_typeengine
50
+ from .db._dialects import get_supported_dialects, string_to_typeengine
51
+ from .db._sqltypes import get_type_func
40
52
  from .types import Boolean, Byte, Char, Double, FelisType, Float, Int, Long, Short, String, Text, Unicode
41
53
 
42
54
  logger = logging.getLogger(__name__)
43
55
 
44
56
  __all__ = (
45
57
  "BaseObject",
46
- "Column",
47
58
  "CheckConstraint",
59
+ "Column",
48
60
  "Constraint",
61
+ "DataType",
49
62
  "ForeignKeyConstraint",
50
63
  "Index",
51
64
  "Schema",
@@ -58,6 +71,7 @@ CONFIG = ConfigDict(
58
71
  populate_by_name=True, # Populate attributes by name.
59
72
  extra="forbid", # Do not allow extra fields.
60
73
  str_strip_whitespace=True, # Strip whitespace from string fields.
74
+ use_enum_values=False, # Do not use enum values during serialization.
61
75
  )
62
76
  """Pydantic model configuration as described in:
63
77
  https://docs.pydantic.dev/2.0/api/config/#pydantic.config.ConfigDict
@@ -117,7 +131,7 @@ class BaseObject(BaseModel):
117
131
 
118
132
 
119
133
  class DataType(StrEnum):
120
- """`Enum` representing the data types supported by Felis."""
134
+ """``Enum`` representing the data types supported by Felis."""
121
135
 
122
136
  boolean = auto()
123
137
  byte = auto()
@@ -134,6 +148,32 @@ class DataType(StrEnum):
134
148
  timestamp = auto()
135
149
 
136
150
 
151
+ def validate_ivoa_ucd(ivoa_ucd: str) -> str:
152
+ """Validate IVOA UCD values.
153
+
154
+ Parameters
155
+ ----------
156
+ ivoa_ucd
157
+ IVOA UCD value to check.
158
+
159
+ Returns
160
+ -------
161
+ `str`
162
+ The IVOA UCD value if it is valid.
163
+
164
+ Raises
165
+ ------
166
+ ValueError
167
+ If the IVOA UCD value is invalid.
168
+ """
169
+ if ivoa_ucd is not None:
170
+ try:
171
+ ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd)
172
+ except ValueError as e:
173
+ raise ValueError(f"Invalid IVOA UCD: {e}")
174
+ return ivoa_ucd
175
+
176
+
137
177
  class Column(BaseObject):
138
178
  """Column model."""
139
179
 
@@ -159,12 +199,6 @@ class Column(BaseObject):
159
199
  autoincrement: bool | None = None
160
200
  """Whether the column is autoincremented."""
161
201
 
162
- mysql_datatype: str | None = Field(None, alias="mysql:datatype")
163
- """MySQL datatype override on the column."""
164
-
165
- postgresql_datatype: str | None = Field(None, alias="postgresql:datatype")
166
- """PostgreSQL datatype override on the column."""
167
-
168
202
  ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
169
203
  """IVOA UCD of the column."""
170
204
 
@@ -193,6 +227,12 @@ class Column(BaseObject):
193
227
  votable_datatype: str | None = Field(None, alias="votable:datatype")
194
228
  """VOTable datatype of the column."""
195
229
 
230
+ mysql_datatype: str | None = Field(None, alias="mysql:datatype")
231
+ """MySQL datatype override on the column."""
232
+
233
+ postgresql_datatype: str | None = Field(None, alias="postgresql:datatype")
234
+ """PostgreSQL datatype override on the column."""
235
+
196
236
  @model_validator(mode="after")
197
237
  def check_value(self) -> Column:
198
238
  """Check that the default value is valid.
@@ -235,12 +275,7 @@ class Column(BaseObject):
235
275
  `str`
236
276
  The IVOA UCD value if it is valid.
237
277
  """
238
- if ivoa_ucd is not None:
239
- try:
240
- ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd)
241
- except ValueError as e:
242
- raise ValueError(f"Invalid IVOA UCD: {e}")
243
- return ivoa_ucd
278
+ return validate_ivoa_ucd(ivoa_ucd)
244
279
 
245
280
  @model_validator(mode="after")
246
281
  def check_units(self) -> Column:
@@ -387,7 +422,7 @@ class Column(BaseObject):
387
422
 
388
423
  @model_validator(mode="before")
389
424
  @classmethod
390
- def check_votable_arraysize(cls, values: dict[str, Any]) -> dict[str, Any]:
425
+ def check_votable_arraysize(cls, values: dict[str, Any], info: ValidationInfo) -> dict[str, Any]:
391
426
  """Set the default value for the ``votable_arraysize`` field, which
392
427
  corresponds to ``arraysize`` in the IVOA VOTable standard.
393
428
 
@@ -395,6 +430,8 @@ class Column(BaseObject):
395
430
  ----------
396
431
  values
397
432
  Values of the column.
433
+ info
434
+ Validation context used to determine if the check is enabled.
398
435
 
399
436
  Returns
400
437
  -------
@@ -409,6 +446,7 @@ class Column(BaseObject):
409
446
  if values.get("name", None) is None or values.get("datatype", None) is None:
410
447
  # Skip bad column data that will not validate
411
448
  return values
449
+ context = info.context if info.context else {}
412
450
  arraysize = values.get("votable:arraysize", None)
413
451
  if arraysize is None:
414
452
  length = values.get("length", None)
@@ -418,7 +456,14 @@ class Column(BaseObject):
418
456
  if datatype == "char":
419
457
  arraysize = str(length)
420
458
  elif datatype in ("string", "unicode", "binary"):
421
- arraysize = f"{length}*"
459
+ if context.get("force_unbounded_arraysize", False):
460
+ arraysize = "*"
461
+ logger.debug(
462
+ f"Forced VOTable's 'arraysize' to '*' on column '{values['name']}' with datatype "
463
+ + f"'{values['datatype']}' and length '{length}'"
464
+ )
465
+ else:
466
+ arraysize = f"{length}*"
422
467
  elif datatype in ("timestamp", "text"):
423
468
  arraysize = "*"
424
469
  if arraysize is not None:
@@ -437,6 +482,59 @@ class Column(BaseObject):
437
482
  values["votable:arraysize"] = str(arraysize)
438
483
  return values
439
484
 
485
+ @field_serializer("datatype")
486
+ def serialize_datatype(self, value: DataType) -> str:
487
+ """Convert `DataType` to string when serializing to JSON/YAML.
488
+
489
+ Parameters
490
+ ----------
491
+ value
492
+ The `DataType` value to serialize.
493
+
494
+ Returns
495
+ -------
496
+ `str`
497
+ The serialized `DataType` value.
498
+ """
499
+ return str(value)
500
+
501
+ @field_validator("datatype", mode="before")
502
+ @classmethod
503
+ def deserialize_datatype(cls, value: str) -> DataType:
504
+ """Convert string back into `DataType` when loading from JSON/YAML.
505
+
506
+ Parameters
507
+ ----------
508
+ value
509
+ The string value to deserialize.
510
+
511
+ Returns
512
+ -------
513
+ `DataType`
514
+ The deserialized `DataType` value.
515
+ """
516
+ return DataType(value)
517
+
518
+ @model_validator(mode="after")
519
+ def check_votable_xtype(self) -> Column:
520
+ """Set the default value for the ``votable_xtype`` field, which
521
+ corresponds to an Extended Datatype or ``xtype`` in the IVOA VOTable
522
+ standard.
523
+
524
+ Returns
525
+ -------
526
+ `Column`
527
+ The column being validated.
528
+
529
+ Notes
530
+ -----
531
+ This is currently only set automatically for the Felis ``timestamp``
532
+ datatype.
533
+ """
534
+ if self.datatype == DataType.timestamp and self.votable_xtype is None:
535
+ self.votable_xtype = "timestamp"
536
+ return self
537
+
440
538
 
441
539
  class Constraint(BaseObject):
442
540
  """Table constraint model."""
@@ -472,6 +570,22 @@ class CheckConstraint(Constraint):
472
570
  expression: str
473
571
  """Expression for the check constraint."""
474
572
 
573
+ @field_serializer("type")
574
+ def serialize_type(self, value: str) -> str:
575
+ """Ensure '@type' is included in serialized output.
576
+
577
+ Parameters
578
+ ----------
579
+ value
580
+ The value to serialize.
581
+
582
+ Returns
583
+ -------
584
+ `str`
585
+ The serialized value.
586
+ """
587
+ return value
588
+
475
589
 
476
590
  class UniqueConstraint(Constraint):
477
591
  """Table unique constraint model."""
@@ -482,12 +596,30 @@ class UniqueConstraint(Constraint):
482
596
  columns: list[str]
483
597
  """Columns in the unique constraint."""
484
598
 
599
+ @field_serializer("type")
600
+ def serialize_type(self, value: str) -> str:
601
+ """Ensure '@type' is included in serialized output.
602
+
603
+ Parameters
604
+ ----------
605
+ value
606
+ The value to serialize.
607
+
608
+ Returns
609
+ -------
610
+ `str`
611
+ The serialized value.
612
+ """
613
+ return value
614
+
485
615
 
486
616
  class ForeignKeyConstraint(Constraint):
487
617
  """Table foreign key constraint model.
488
618
 
489
619
  This constraint is used to define a foreign key relationship between two
490
- tables in the schema.
620
+ tables in the schema. There must be at least one column in the
621
+ `columns` list, and at least one column in the `referenced_columns` list
622
+ or a validation error will be raised.
491
623
 
492
624
  Notes
493
625
  -----
@@ -498,12 +630,62 @@ class ForeignKeyConstraint(Constraint):
498
630
  type: Literal["ForeignKey"] = Field("ForeignKey", alias="@type")
499
631
  """Type of the constraint."""
500
632
 
501
- columns: list[str]
633
+ columns: list[str] = Field(min_length=1)
502
634
  """The columns comprising the foreign key."""
503
635
 
504
- referenced_columns: list[str] = Field(alias="referencedColumns")
636
+ referenced_columns: list[str] = Field(alias="referencedColumns", min_length=1)
505
637
  """The columns referenced by the foreign key."""
506
638
 
639
+ on_delete: Literal["CASCADE", "SET NULL", "SET DEFAULT", "RESTRICT", "NO ACTION"] | None = None
640
+ """Action to take when the referenced row is deleted."""
641
+
642
+ on_update: Literal["CASCADE", "SET NULL", "SET DEFAULT", "RESTRICT", "NO ACTION"] | None = None
643
+ """Action to take when the referenced row is updated."""
644
+
645
+ @field_serializer("type")
646
+ def serialize_type(self, value: str) -> str:
647
+ """Ensure '@type' is included in serialized output.
648
+
649
+ Parameters
650
+ ----------
651
+ value
652
+ The value to serialize.
653
+
654
+ Returns
655
+ -------
656
+ `str`
657
+ The serialized value.
658
+ """
659
+ return value
660
+
661
+ @model_validator(mode="after")
662
+ def check_column_lengths(self) -> ForeignKeyConstraint:
663
+ """Check that the `columns` and `referenced_columns` lists have the
664
+ same length.
665
+
666
+ Returns
667
+ -------
668
+ `ForeignKeyConstraint`
669
+ The foreign key constraint being validated.
670
+
671
+ Raises
672
+ ------
673
+ ValueError
674
+ Raised if the `columns` and `referenced_columns` lists do not have
675
+ the same length.
676
+ """
677
+ if len(self.columns) != len(self.referenced_columns):
678
+ raise ValueError(
679
+ "Columns and referencedColumns must have the same length for a ForeignKey constraint"
680
+ )
681
+ return self
682
+
683
+
684
+ _ConstraintType = Annotated[
685
+ CheckConstraint | ForeignKeyConstraint | UniqueConstraint, Field(discriminator="type")
686
+ ]
687
+ """Type alias for a constraint type."""
688
+
507
689
 
508
690
  class Index(BaseObject):
509
691
  """Table index model.
@@ -545,23 +727,91 @@ class Index(BaseObject):
545
727
  return values
546
728
 
547
729
 
548
- _ConstraintType = Annotated[
549
- Union[CheckConstraint, ForeignKeyConstraint, UniqueConstraint], Field(discriminator="type")
550
- ]
551
- """Type alias for a constraint type."""
730
+ ColumnRef: TypeAlias = str
731
+ """Type alias for a column reference."""
552
732
 
553
733
 
554
- class Table(BaseObject):
555
- """Table model."""
734
+ class ColumnGroup(BaseObject):
735
+ """Column group model."""
556
736
 
557
- columns: Sequence[Column]
558
- """Columns in the table."""
737
+ columns: list[ColumnRef | Column] = Field(..., min_length=1)
738
+ """Columns in the group."""
559
739
 
560
- constraints: list[_ConstraintType] = Field(default_factory=list)
561
- """Constraints on the table."""
740
+ ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
741
+ """IVOA UCD of the column."""
562
742
 
563
- indexes: list[Index] = Field(default_factory=list)
564
- """Indexes on the table."""
743
+ table: Table | None = Field(None, exclude=True)
744
+ """Reference to the parent table."""
745
+
746
+ @field_validator("ivoa_ucd")
747
+ @classmethod
748
+ def check_ivoa_ucd(cls, ivoa_ucd: str) -> str:
749
+ """Check that IVOA UCD values are valid.
750
+
751
+ Parameters
752
+ ----------
753
+ ivoa_ucd
754
+ IVOA UCD value to check.
755
+
756
+ Returns
757
+ -------
758
+ `str`
759
+ The IVOA UCD value if it is valid.
760
+ """
761
+ return validate_ivoa_ucd(ivoa_ucd)
762
+
763
+ @model_validator(mode="after")
764
+ def check_unique_columns(self) -> ColumnGroup:
765
+ """Check that the columns list contains unique items.
766
+
767
+ Returns
768
+ -------
769
+ `ColumnGroup`
770
+ The column group being validated.
771
+ """
772
+ column_ids = [col if isinstance(col, str) else col.id for col in self.columns]
773
+ if len(column_ids) != len(set(column_ids)):
774
+ raise ValueError("Columns in the group must be unique")
775
+ return self
776
+
777
+ def _dereference_columns(self) -> None:
778
+ """Dereference ColumnRef to Column objects."""
779
+ if self.table is None:
780
+ raise ValueError("ColumnGroup must have a reference to its parent table")
781
+
782
+ dereferenced_columns: list[ColumnRef | Column] = []
783
+ for col in self.columns:
784
+ if isinstance(col, str):
785
+ # Dereference ColumnRef to Column object
786
+ try:
787
+ col_obj = self.table._find_column_by_id(col)
788
+ except KeyError as e:
789
+ raise ValueError(f"Column '{col}' not found in table '{self.table.name}'") from e
790
+ dereferenced_columns.append(col_obj)
791
+ else:
792
+ dereferenced_columns.append(col)
793
+
794
+ self.columns = dereferenced_columns
795
+
796
+ @field_serializer("columns")
797
+ def serialize_columns(self, columns: list[ColumnRef | Column]) -> list[str]:
798
+ """Serialize columns as their IDs.
799
+
800
+ Parameters
801
+ ----------
802
+ columns
803
+ The columns to serialize.
804
+
805
+ Returns
806
+ -------
807
+ `list` [ `str` ]
808
+ The serialized column IDs.
809
+ """
810
+ return [col if isinstance(col, str) else col.id for col in columns]
811
+
812
+
813
+ class Table(BaseObject):
814
+ """Table model."""
565
815
 
566
816
  primary_key: str | list[str] | None = Field(None, alias="primaryKey")
567
817
  """Primary key of the table."""
@@ -575,6 +825,18 @@ class Table(BaseObject):
575
825
  mysql_charset: str | None = Field(None, alias="mysql:charset")
576
826
  """MySQL charset to use for the table."""
577
827
 
828
+ columns: Sequence[Column]
829
+ """Columns in the table."""
830
+
831
+ column_groups: list[ColumnGroup] = Field(default_factory=list, alias="columnGroups")
832
+ """Column groups in the table."""
833
+
834
+ constraints: list[_ConstraintType] = Field(default_factory=list)
835
+ """Constraints on the table."""
836
+
837
+ indexes: list[Index] = Field(default_factory=list)
838
+ """Indexes on the table."""
839
+
578
840
  @field_validator("columns", mode="after")
579
841
  @classmethod
580
842
  def check_unique_column_names(cls, columns: list[Column]) -> list[Column]:
@@ -653,6 +915,43 @@ class Table(BaseObject):
653
915
  return self
654
916
  raise ValueError(f"Table '{self.name}' is missing at least one column designated as 'tap:principal'")
655
917
 
918
+ def _find_column_by_id(self, id: str) -> Column:
919
+ """Find a column by ID.
920
+
921
+ Parameters
922
+ ----------
923
+ id
924
+ The ID of the column to find.
925
+
926
+ Returns
927
+ -------
928
+ `Column`
929
+ The column with the given ID.
930
+
931
+ Raises
932
+ ------
933
+ ValueError
934
+ Raised if the column is not found.
935
+ """
936
+ for column in self.columns:
937
+ if column.id == id:
938
+ return column
939
+ raise KeyError(f"Column '{id}' not found in table '{self.name}'")
940
+
941
+ @model_validator(mode="after")
942
+ def dereference_column_groups(self: Table) -> Table:
943
+ """Dereference columns in column groups.
944
+
945
+ Returns
946
+ -------
947
+ `Table`
948
+ The table with dereferenced column groups.
949
+ """
950
+ for group in self.column_groups:
951
+ group.table = self
952
+ group._dereference_columns()
953
+ return self
954
+
656
955
 
657
956
  class SchemaVersion(BaseModel):
658
957
  """Schema version model."""
@@ -696,10 +995,10 @@ class SchemaIdVisitor:
696
995
  if hasattr(obj, "id"):
697
996
  obj_id = getattr(obj, "id")
698
997
  if self.schema is not None:
699
- if obj_id in self.schema.id_map:
998
+ if obj_id in self.schema._id_map:
700
999
  self.duplicates.add(obj_id)
701
1000
  else:
702
- self.schema.id_map[obj_id] = obj
1001
+ self.schema._id_map[obj_id] = obj
703
1002
 
704
1003
  def visit_schema(self, schema: Schema) -> None:
705
1004
  """Visit the objects in a schema and build the ID map.
@@ -757,6 +1056,56 @@ class SchemaIdVisitor:
757
1056
  T = TypeVar("T", bound=BaseObject)
758
1057
 
759
1058
 
1059
+ def _strip_ids(data: Any) -> Any:
1060
+ """Recursively strip '@id' fields from a dictionary or list.
1061
+
1062
+ Parameters
1063
+ ----------
1064
+ data
1065
+ The data to strip IDs from, which can be a dictionary, list, or any
1066
+ other type. Other types will be returned unchanged.
1067
+ """
1068
+ if isinstance(data, dict):
1069
+ data.pop("@id", None)
1070
+ for k, v in data.items():
1071
+ data[k] = _strip_ids(v)
1072
+ return data
1073
+ elif isinstance(data, list):
1074
+ return [_strip_ids(item) for item in data]
1075
+ else:
1076
+ return data
1077
+
1078
+
1079
+ def _append_error(
1080
+ errors: list[InitErrorDetails],
1081
+ loc: tuple,
1082
+ input_value: Any,
1083
+ error_message: str,
1084
+ error_type: str = "value_error",
1085
+ ) -> None:
1086
+ """Append an error to the errors list.
1087
+
1088
+ Parameters
1089
+ ----------
1090
+ errors : list[InitErrorDetails]
1091
+ The list of errors to append to.
1092
+ loc : tuple
1093
+ The location of the error in the schema.
1094
+ input_value : Any
1095
+ The input value that caused the error.
1096
+ error_message : str
1097
+ The error message to include in the context.
1098
+ """
1099
+ errors.append(
1100
+ {
1101
+ "type": error_type,
1102
+ "loc": loc,
1103
+ "input": input_value,
1104
+ "ctx": {"error": error_message},
1105
+ }
1106
+ )
1107
+
1108
+
760
1109
  class Schema(BaseObject, Generic[T]):
761
1110
  """Database schema model.
762
1111
 
@@ -769,7 +1118,7 @@ class Schema(BaseObject, Generic[T]):
769
1118
  tables: Sequence[Table]
770
1119
  """The tables in the schema."""
771
1120
 
772
- id_map: dict[str, Any] = Field(default_factory=dict, exclude=True)
1121
+ _id_map: dict[str, Any] = PrivateAttr(default_factory=dict)
773
1122
  """Map of IDs to objects."""
774
1123
 
775
1124
  @model_validator(mode="before")
@@ -807,6 +1156,14 @@ class Schema(BaseObject, Generic[T]):
807
1156
  if "@id" not in column:
808
1157
  column["@id"] = f"#{table['name']}.{column['name']}"
809
1158
  logger.debug(f"Generated ID '{column['@id']}' for column '{column['name']}'")
1159
+ if "columnGroups" in table:
1160
+ for column_group in table["columnGroups"]:
1161
+ if "@id" not in column_group:
1162
+ column_group["@id"] = f"#{table['name']}.{column_group['name']}"
1163
+ logger.debug(
1164
+ f"Generated ID '{column_group['@id']}' for column group "
1165
+ f"'{column_group['name']}'"
1166
+ )
810
1167
  if "constraints" in table:
811
1168
  for constraint in table["constraints"]:
812
1169
  if "@id" not in constraint:
@@ -931,20 +1288,21 @@ class Schema(BaseObject, Generic[T]):
931
1288
 
932
1289
  return self
933
1290
 
934
- def _create_id_map(self: Schema) -> Schema:
1291
+ @model_validator(mode="after")
1292
+ def create_id_map(self: Schema) -> Schema:
935
1293
  """Create a map of IDs to objects.
936
1294
 
1295
+ Returns
1296
+ -------
1297
+ `Schema`
1298
+ The schema with the ID map created.
1299
+
937
1300
  Raises
938
1301
  ------
939
1302
  ValueError
940
1303
  Raised if duplicate identifiers are found in the schema.
941
-
942
- Notes
943
- -----
944
- This is called automatically by the `model_post_init` method. If the
945
- ID map is already populated, this method will return immediately.
946
1304
  """
947
- if len(self.id_map):
1305
+ if self._id_map:
948
1306
  logger.debug("Ignoring call to create_id_map() - ID map was already populated")
949
1307
  return self
950
1308
  visitor: SchemaIdVisitor = SchemaIdVisitor()
@@ -953,25 +1311,152 @@ class Schema(BaseObject, Generic[T]):
953
1311
  raise ValueError(
954
1312
  "Duplicate IDs found in schema:\n " + "\n ".join(visitor.duplicates) + "\n"
955
1313
  )
1314
+ logger.debug("Created ID map with %d entries", len(self._id_map))
956
1315
  return self
957
1316
 
958
- def model_post_init(self, ctx: Any) -> None:
959
- """Post-initialization hook for the model.
1317
+ def _validate_column_id(
1318
+ self: Schema,
1319
+ column_id: str,
1320
+ loc: tuple,
1321
+ errors: list[InitErrorDetails],
1322
+ ) -> None:
1323
+ """Validate a column ID from a constraint and append errors if invalid.
960
1324
 
961
1325
  Parameters
962
1326
  ----------
963
- ctx
964
- The context object which was passed to the model.
1327
+ schema : Schema
1328
+ The schema being validated.
1329
+ column_id : str
1330
+ The column ID to validate.
1331
+ loc : tuple
1332
+ The location of the error in the schema.
1333
+ errors : list[InitErrorDetails]
1334
+ The list of errors to append to.
1335
+ """
1336
+ if column_id not in self:
1337
+ _append_error(
1338
+ errors,
1339
+ loc,
1340
+ column_id,
1341
+ f"Column ID '{column_id}' not found in schema",
1342
+ )
1343
+ elif not isinstance(self[column_id], Column):
1344
+ _append_error(
1345
+ errors,
1346
+ loc,
1347
+ column_id,
1348
+ f"ID '{column_id}' does not refer to a Column object",
1349
+ )
965
1350
 
966
- Notes
967
- -----
968
- This method is called automatically by Pydantic after the model is
969
- initialized. It is used to create the ID map for the schema.
1351
+ def _validate_foreign_key_column(
1352
+ self: Schema,
1353
+ column_id: str,
1354
+ table: Table,
1355
+ loc: tuple,
1356
+ errors: list[InitErrorDetails],
1357
+ ) -> None:
1358
+ """Validate a foreign key column ID from a constraint and append errors
1359
+ if invalid.
970
1360
 
971
- The ``ctx`` argument has the type `Any` because this is the function
972
- signature in Pydantic itself.
1361
+ Parameters
1362
+ ----------
1363
+ schema : Schema
1364
+ The schema being validated.
1365
+ column_id : str
1366
+ The foreign key column ID to validate.
1367
+ loc : tuple
1368
+ The location of the error in the schema.
1369
+ errors : list[InitErrorDetails]
1370
+ The list of errors to append to.
973
1371
  """
974
- self._create_id_map()
1372
+ try:
1373
+ table._find_column_by_id(column_id)
1374
+ except KeyError:
1375
+ _append_error(
1376
+ errors,
1377
+ loc,
1378
+ column_id,
1379
+ f"Column '{column_id}' not found in table '{table.name}'",
1380
+ )
1381
+
1382
+ @model_validator(mode="after")
1383
+ def check_constraints(self: Schema) -> Schema:
1384
+ """Check constraint objects for validity. This needs to be deferred
1385
+ until after the schema is fully loaded and the ID map is created.
1386
+
1387
+ Raises
1388
+ ------
1389
+ pydantic.ValidationError
1390
+ Raised if any constraints are invalid.
1391
+
1392
+ Returns
1393
+ -------
1394
+ `Schema`
1395
+ The schema being validated.
1396
+ """
1397
+ errors: list[InitErrorDetails] = []
1398
+
1399
+ for table_index, table in enumerate(self.tables):
1400
+ for constraint_index, constraint in enumerate(table.constraints):
1401
+ column_ids: list[str] = []
1402
+ referenced_column_ids: list[str] = []
1403
+
1404
+ if isinstance(constraint, ForeignKeyConstraint):
1405
+ column_ids += constraint.columns
1406
+ referenced_column_ids += constraint.referenced_columns
1407
+ elif isinstance(constraint, UniqueConstraint):
1408
+ column_ids += constraint.columns
1409
+ # No extra checks are required on CheckConstraint objects.
1410
+
1411
+ # Validate the foreign key columns
1412
+ for column_id in column_ids:
1413
+ self._validate_column_id(
1414
+ column_id,
1415
+ (
1416
+ "tables",
1417
+ table_index,
1418
+ "constraints",
1419
+ constraint_index,
1420
+ "columns",
1421
+ column_id,
1422
+ ),
1423
+ errors,
1424
+ )
1425
+ # Check that the foreign key column is within the source
1426
+ # table.
1427
+ self._validate_foreign_key_column(
1428
+ column_id,
1429
+ table,
1430
+ (
1431
+ "tables",
1432
+ table_index,
1433
+ "constraints",
1434
+ constraint_index,
1435
+ "columns",
1436
+ column_id,
1437
+ ),
1438
+ errors,
1439
+ )
1440
+
1441
+ # Validate the primary key (reference) columns
1442
+ for referenced_column_id in referenced_column_ids:
1443
+ self._validate_column_id(
1444
+ referenced_column_id,
1445
+ (
1446
+ "tables",
1447
+ table_index,
1448
+ "constraints",
1449
+ constraint_index,
1450
+ "referenced_columns",
1451
+ referenced_column_id,
1452
+ ),
1453
+ errors,
1454
+ )
1455
+
1456
+ if errors:
1457
+ raise ValidationError.from_exception_data("Schema validation failed", errors)
1458
+
1459
+ return self
975
1460
 
976
1461
  def __getitem__(self, id: str) -> BaseObject:
977
1462
  """Get an object by its ID.
@@ -988,7 +1473,7 @@ class Schema(BaseObject, Generic[T]):
988
1473
  """
989
1474
  if id not in self:
990
1475
  raise KeyError(f"Object with ID '{id}' not found in schema")
991
- return self.id_map[id]
1476
+ return self._id_map[id]
992
1477
 
993
1478
  def __contains__(self, id: str) -> bool:
994
1479
  """Check if an object with the given ID is in the schema.
@@ -998,7 +1483,7 @@ class Schema(BaseObject, Generic[T]):
998
1483
  id
999
1484
  The ID of the object to check.
1000
1485
  """
1001
- return id in self.id_map
1486
+ return id in self._id_map
1002
1487
 
1003
1488
  def find_object_by_id(self, id: str, obj_type: type[T]) -> T:
1004
1489
  """Find an object with the given type by its ID.
@@ -1114,3 +1599,58 @@ class Schema(BaseObject, Generic[T]):
1114
1599
  logger.debug("Loading schema from: '%s'", source)
1115
1600
  yaml_data = yaml.safe_load(source)
1116
1601
  return Schema.model_validate(yaml_data, context=context)
1602
+
1603
+ def _model_dump(self, strip_ids: bool = False) -> dict[str, Any]:
1604
+ """Dump the schema as a dictionary with some default arguments
1605
+ applied.
1606
+
1607
+ Parameters
1608
+ ----------
1609
+ strip_ids
1610
+ Whether to strip the IDs from the dumped data. Defaults to `False`.
1611
+
1612
+ Returns
1613
+ -------
1614
+ `dict` [ `str`, `Any` ]
1615
+ The dumped schema data as a dictionary.
1616
+ """
1617
+ data = self.model_dump(by_alias=True, exclude_none=True, exclude_defaults=True)
1618
+ if strip_ids:
1619
+ data = _strip_ids(data)
1620
+ return data
1621
+
1622
+ def dump_yaml(self, stream: IO[str] = sys.stdout, strip_ids: bool = False) -> None:
1623
+ """Pretty print the schema as YAML.
1624
+
1625
+ Parameters
1626
+ ----------
1627
+ stream
1628
+ The stream to write the YAML data to.
1629
+ strip_ids
1630
+ Whether to strip the IDs from the dumped data. Defaults to `False`.
1631
+ """
1632
+ data = self._model_dump(strip_ids=strip_ids)
1633
+ yaml.safe_dump(
1634
+ data,
1635
+ stream,
1636
+ default_flow_style=False,
1637
+ sort_keys=False,
1638
+ )
1639
+
1640
+ def dump_json(self, stream: IO[str] = sys.stdout, strip_ids: bool = False) -> None:
1641
+ """Pretty print the schema as JSON.
1642
+
1643
+ Parameters
1644
+ ----------
1645
+ stream
1646
+ The stream to write the JSON data to.
1647
+ strip_ids
1648
+ Whether to strip the IDs from the dumped data. Defaults to `False`.
1649
+ """
1650
+ data = self._model_dump(strip_ids=strip_ids)
1651
+ json.dump(
1652
+ data,
1653
+ stream,
1654
+ indent=4,
1655
+ sort_keys=False,
1656
+ )