lsst-felis 28.2024.4500__py3-none-any.whl → 29.2025.4500__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
felis/datamodel.py CHANGED
@@ -23,16 +23,29 @@
23
23
 
24
24
  from __future__ import annotations
25
25
 
26
+ import json
26
27
  import logging
28
+ import sys
27
29
  from collections.abc import Sequence
28
30
  from enum import StrEnum, auto
29
- from typing import IO, Annotated, Any, Generic, Literal, TypeAlias, TypeVar, Union
31
+ from typing import IO, Annotated, Any, Generic, Literal, TypeAlias, TypeVar
30
32
 
31
33
  import yaml
32
34
  from astropy import units as units # type: ignore
33
35
  from astropy.io.votable import ucd # type: ignore
34
36
  from lsst.resources import ResourcePath, ResourcePathExpression
35
- from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator, model_validator
37
+ from pydantic import (
38
+ BaseModel,
39
+ ConfigDict,
40
+ Field,
41
+ PrivateAttr,
42
+ ValidationError,
43
+ ValidationInfo,
44
+ field_serializer,
45
+ field_validator,
46
+ model_validator,
47
+ )
48
+ from pydantic_core import InitErrorDetails
36
49
 
37
50
  from .db.dialects import get_supported_dialects
38
51
  from .db.sqltypes import get_type_func
@@ -43,9 +56,10 @@ logger = logging.getLogger(__name__)
43
56
 
44
57
  __all__ = (
45
58
  "BaseObject",
46
- "Column",
47
59
  "CheckConstraint",
60
+ "Column",
48
61
  "Constraint",
62
+ "DataType",
49
63
  "ForeignKeyConstraint",
50
64
  "Index",
51
65
  "Schema",
@@ -58,6 +72,7 @@ CONFIG = ConfigDict(
58
72
  populate_by_name=True, # Populate attributes by name.
59
73
  extra="forbid", # Do not allow extra fields.
60
74
  str_strip_whitespace=True, # Strip whitespace from string fields.
75
+ use_enum_values=False, # Keep enum members on the model instead of replacing them with their values.
61
76
  )
62
77
  """Pydantic model configuration as described in:
63
78
  https://docs.pydantic.dev/2.0/api/config/#pydantic.config.ConfigDict
@@ -117,7 +132,7 @@ class BaseObject(BaseModel):
117
132
 
118
133
 
119
134
  class DataType(StrEnum):
120
- """`Enum` representing the data types supported by Felis."""
135
+ """``Enum`` representing the data types supported by Felis."""
121
136
 
122
137
  boolean = auto()
123
138
  byte = auto()
@@ -134,6 +149,32 @@ class DataType(StrEnum):
134
149
  timestamp = auto()
135
150
 
136
151
 
152
+ def validate_ivoa_ucd(ivoa_ucd: str) -> str:
153
+ """Validate IVOA UCD values.
154
+
155
+ Parameters
156
+ ----------
157
+ ivoa_ucd
158
+ IVOA UCD value to check.
159
+
160
+ Returns
161
+ -------
162
+ `str`
163
+ The IVOA UCD value if it is valid.
164
+
165
+ Raises
166
+ ------
167
+ ValueError
168
+ If the IVOA UCD value is invalid.
169
+ """
170
+ if ivoa_ucd is not None:
171
+ try:
172
+ ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd)
173
+ except ValueError as e:
174
+ raise ValueError(f"Invalid IVOA UCD: {e}")
175
+ return ivoa_ucd
176
+
177
+
137
178
  class Column(BaseObject):
138
179
  """Column model."""
139
180
 
@@ -159,12 +200,6 @@ class Column(BaseObject):
159
200
  autoincrement: bool | None = None
160
201
  """Whether the column is autoincremented."""
161
202
 
162
- mysql_datatype: str | None = Field(None, alias="mysql:datatype")
163
- """MySQL datatype override on the column."""
164
-
165
- postgresql_datatype: str | None = Field(None, alias="postgresql:datatype")
166
- """PostgreSQL datatype override on the column."""
167
-
168
203
  ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
169
204
  """IVOA UCD of the column."""
170
205
 
@@ -193,6 +228,12 @@ class Column(BaseObject):
193
228
  votable_datatype: str | None = Field(None, alias="votable:datatype")
194
229
  """VOTable datatype of the column."""
195
230
 
231
+ mysql_datatype: str | None = Field(None, alias="mysql:datatype")
232
+ """MySQL datatype override on the column."""
233
+
234
+ postgresql_datatype: str | None = Field(None, alias="postgresql:datatype")
235
+ """PostgreSQL datatype override on the column."""
236
+
196
237
  @model_validator(mode="after")
197
238
  def check_value(self) -> Column:
198
239
  """Check that the default value is valid.
@@ -235,12 +276,7 @@ class Column(BaseObject):
235
276
  `str`
236
277
  The IVOA UCD value if it is valid.
237
278
  """
238
- if ivoa_ucd is not None:
239
- try:
240
- ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd)
241
- except ValueError as e:
242
- raise ValueError(f"Invalid IVOA UCD: {e}")
243
- return ivoa_ucd
279
+ return validate_ivoa_ucd(ivoa_ucd)
244
280
 
245
281
  @model_validator(mode="after")
246
282
  def check_units(self) -> Column:
@@ -387,7 +423,7 @@ class Column(BaseObject):
387
423
 
388
424
  @model_validator(mode="before")
389
425
  @classmethod
390
- def check_votable_arraysize(cls, values: dict[str, Any]) -> dict[str, Any]:
426
+ def check_votable_arraysize(cls, values: dict[str, Any], info: ValidationInfo) -> dict[str, Any]:
391
427
  """Set the default value for the ``votable_arraysize`` field, which
392
428
  corresponds to ``arraysize`` in the IVOA VOTable standard.
393
429
 
@@ -395,6 +431,8 @@ class Column(BaseObject):
395
431
  ----------
396
432
  values
397
433
  Values of the column.
434
+ info
435
+ Validation context used to determine if the check is enabled.
398
436
 
399
437
  Returns
400
438
  -------
@@ -409,6 +447,7 @@ class Column(BaseObject):
409
447
  if values.get("name", None) is None or values.get("datatype", None) is None:
410
448
  # Skip bad column data that will not validate
411
449
  return values
450
+ context = info.context if info.context else {}
412
451
  arraysize = values.get("votable:arraysize", None)
413
452
  if arraysize is None:
414
453
  length = values.get("length", None)
@@ -418,7 +457,14 @@ class Column(BaseObject):
418
457
  if datatype == "char":
419
458
  arraysize = str(length)
420
459
  elif datatype in ("string", "unicode", "binary"):
421
- arraysize = f"{length}*"
460
+ if context.get("force_unbounded_arraysize", False):
461
+ arraysize = "*"
462
+ logger.debug(
463
+ f"Forced VOTable's 'arraysize' to '*' on column '{values['name']}' with datatype "
464
+ + f"'{values['datatype']}' and length '{length}'"
465
+ )
466
+ else:
467
+ arraysize = f"{length}*"
422
468
  elif datatype in ("timestamp", "text"):
423
469
  arraysize = "*"
424
470
  if arraysize is not None:
@@ -437,6 +483,59 @@ class Column(BaseObject):
437
483
  values["votable:arraysize"] = str(arraysize)
438
484
  return values
439
485
 
486
+ @field_serializer("datatype")
487
+ def serialize_datatype(self, value: DataType) -> str:
488
+ """Convert `DataType` to string when serializing to JSON/YAML.
489
+
490
+ Parameters
491
+ ----------
492
+ value
493
+ The `DataType` value to serialize.
494
+
495
+ Returns
496
+ -------
497
+ `str`
498
+ The serialized `DataType` value.
499
+ """
500
+ return str(value)
501
+
502
+ @field_validator("datatype", mode="before")
503
+ @classmethod
504
+ def deserialize_datatype(cls, value: str) -> DataType:
505
+ """Convert string back into `DataType` when loading from JSON/YAML.
506
+
507
+ Parameters
508
+ ----------
509
+ value
510
+ The string value to deserialize.
511
+
512
+ Returns
513
+ -------
514
+ `DataType`
515
+ The deserialized `DataType` value.
516
+ """
517
+ return DataType(value)
518
+
519
+ @model_validator(mode="after")
520
+ def check_votable_xtype(self) -> Column:
521
+ """Set the default value for the ``votable_xtype`` field, which
522
+ corresponds to an Extended Datatype or ``xtype`` in the IVOA VOTable
523
+ standard.
524
+
525
+ Returns
526
+ -------
527
+ `Column`
528
+ The column being validated.
529
+
530
+ Notes
531
+ -----
532
+ This is currently only set automatically for the Felis ``timestamp``
533
+ datatype.
534
+ """
535
+ if self.datatype == DataType.timestamp and self.votable_xtype is None:
536
+ self.votable_xtype = "timestamp"
537
+ return self
538
+
440
539
 
441
540
  class Constraint(BaseObject):
442
541
  """Table constraint model."""
@@ -472,6 +571,22 @@ class CheckConstraint(Constraint):
472
571
  expression: str
473
572
  """Expression for the check constraint."""
474
573
 
574
+ @field_serializer("type")
575
+ def serialize_type(self, value: str) -> str:
576
+ """Ensure '@type' is included in serialized output.
577
+
578
+ Parameters
579
+ ----------
580
+ value
581
+ The value to serialize.
582
+
583
+ Returns
584
+ -------
585
+ `str`
586
+ The serialized value.
587
+ """
588
+ return value
589
+
475
590
 
476
591
  class UniqueConstraint(Constraint):
477
592
  """Table unique constraint model."""
@@ -482,12 +597,30 @@ class UniqueConstraint(Constraint):
482
597
  columns: list[str]
483
598
  """Columns in the unique constraint."""
484
599
 
600
+ @field_serializer("type")
601
+ def serialize_type(self, value: str) -> str:
602
+ """Ensure '@type' is included in serialized output.
603
+
604
+ Parameters
605
+ ----------
606
+ value
607
+ The value to serialize.
608
+
609
+ Returns
610
+ -------
611
+ `str`
612
+ The serialized value.
613
+ """
614
+ return value
615
+
485
616
 
486
617
  class ForeignKeyConstraint(Constraint):
487
618
  """Table foreign key constraint model.
488
619
 
489
620
  This constraint is used to define a foreign key relationship between two
490
- tables in the schema.
621
+ tables in the schema. There must be at least one column in the
622
+ `columns` list, and at least one column in the `referenced_columns` list
623
+ or a validation error will be raised.
491
624
 
492
625
  Notes
493
626
  -----
@@ -498,12 +631,62 @@ class ForeignKeyConstraint(Constraint):
498
631
  type: Literal["ForeignKey"] = Field("ForeignKey", alias="@type")
499
632
  """Type of the constraint."""
500
633
 
501
- columns: list[str]
634
+ columns: list[str] = Field(min_length=1)
502
635
  """The columns comprising the foreign key."""
503
636
 
504
- referenced_columns: list[str] = Field(alias="referencedColumns")
637
+ referenced_columns: list[str] = Field(alias="referencedColumns", min_length=1)
505
638
  """The columns referenced by the foreign key."""
506
639
 
640
+ on_delete: Literal["CASCADE", "SET NULL", "SET DEFAULT", "RESTRICT", "NO ACTION"] | None = None
641
+ """Action to take when the referenced row is deleted."""
642
+
643
+ on_update: Literal["CASCADE", "SET NULL", "SET DEFAULT", "RESTRICT", "NO ACTION"] | None = None
644
+ """Action to take when the referenced row is updated."""
645
+
646
+ @field_serializer("type")
647
+ def serialize_type(self, value: str) -> str:
648
+ """Ensure '@type' is included in serialized output.
649
+
650
+ Parameters
651
+ ----------
652
+ value
653
+ The value to serialize.
654
+
655
+ Returns
656
+ -------
657
+ `str`
658
+ The serialized value.
659
+ """
660
+ return value
661
+
662
+ @model_validator(mode="after")
663
+ def check_column_lengths(self) -> ForeignKeyConstraint:
664
+ """Check that the `columns` and `referenced_columns` lists have the
665
+ same length.
666
+
667
+ Returns
668
+ -------
669
+ `ForeignKeyConstraint`
670
+ The foreign key constraint being validated.
671
+
672
+ Raises
673
+ ------
674
+ ValueError
675
+ Raised if the `columns` and `referenced_columns` lists do not have
676
+ the same length.
677
+ """
678
+ if len(self.columns) != len(self.referenced_columns):
679
+ raise ValueError(
680
+ "Columns and referencedColumns must have the same length for a ForeignKey constraint"
681
+ )
682
+ return self
683
+
684
+
685
+ _ConstraintType = Annotated[
686
+ CheckConstraint | ForeignKeyConstraint | UniqueConstraint, Field(discriminator="type")
687
+ ]
688
+ """Type alias for a constraint type."""
689
+
507
690
 
508
691
  class Index(BaseObject):
509
692
  """Table index model.
@@ -545,23 +728,91 @@ class Index(BaseObject):
545
728
  return values
546
729
 
547
730
 
548
- _ConstraintType = Annotated[
549
- Union[CheckConstraint, ForeignKeyConstraint, UniqueConstraint], Field(discriminator="type")
550
- ]
551
- """Type alias for a constraint type."""
731
+ ColumnRef: TypeAlias = str
732
+ """Type alias for a column reference."""
552
733
 
553
734
 
554
- class Table(BaseObject):
555
- """Table model."""
735
+ class ColumnGroup(BaseObject):
736
+ """Column group model."""
556
737
 
557
- columns: Sequence[Column]
558
- """Columns in the table."""
738
+ columns: list[ColumnRef | Column] = Field(..., min_length=1)
739
+ """Columns in the group."""
559
740
 
560
- constraints: list[_ConstraintType] = Field(default_factory=list)
561
- """Constraints on the table."""
741
+ ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
742
+ """IVOA UCD of the column group."""
562
743
 
563
- indexes: list[Index] = Field(default_factory=list)
564
- """Indexes on the table."""
744
+ table: Table | None = Field(None, exclude=True)
745
+ """Reference to the parent table."""
746
+
747
+ @field_validator("ivoa_ucd")
748
+ @classmethod
749
+ def check_ivoa_ucd(cls, ivoa_ucd: str) -> str:
750
+ """Check that IVOA UCD values are valid.
751
+
752
+ Parameters
753
+ ----------
754
+ ivoa_ucd
755
+ IVOA UCD value to check.
756
+
757
+ Returns
758
+ -------
759
+ `str`
760
+ The IVOA UCD value if it is valid.
761
+ """
762
+ return validate_ivoa_ucd(ivoa_ucd)
763
+
764
+ @model_validator(mode="after")
765
+ def check_unique_columns(self) -> ColumnGroup:
766
+ """Check that the columns list contains unique items.
767
+
768
+ Returns
769
+ -------
770
+ `ColumnGroup`
771
+ The column group being validated.
772
+ """
773
+ column_ids = [col if isinstance(col, str) else col.id for col in self.columns]
774
+ if len(column_ids) != len(set(column_ids)):
775
+ raise ValueError("Columns in the group must be unique")
776
+ return self
777
+
778
+ def _dereference_columns(self) -> None:
779
+ """Dereference ColumnRef to Column objects."""
780
+ if self.table is None:
781
+ raise ValueError("ColumnGroup must have a reference to its parent table")
782
+
783
+ dereferenced_columns: list[ColumnRef | Column] = []
784
+ for col in self.columns:
785
+ if isinstance(col, str):
786
+ # Dereference ColumnRef to Column object
787
+ try:
788
+ col_obj = self.table._find_column_by_id(col)
789
+ except KeyError as e:
790
+ raise ValueError(f"Column '{col}' not found in table '{self.table.name}'") from e
791
+ dereferenced_columns.append(col_obj)
792
+ else:
793
+ dereferenced_columns.append(col)
794
+
795
+ self.columns = dereferenced_columns
796
+
797
+ @field_serializer("columns")
798
+ def serialize_columns(self, columns: list[ColumnRef | Column]) -> list[str]:
799
+ """Serialize columns as their IDs.
800
+
801
+ Parameters
802
+ ----------
803
+ columns
804
+ The columns to serialize.
805
+
806
+ Returns
807
+ -------
808
+ `list` [ `str` ]
809
+ The serialized column IDs.
810
+ """
811
+ return [col if isinstance(col, str) else col.id for col in columns]
812
+
813
+
814
+ class Table(BaseObject):
815
+ """Table model."""
565
816
 
566
817
  primary_key: str | list[str] | None = Field(None, alias="primaryKey")
567
818
  """Primary key of the table."""
@@ -575,6 +826,18 @@ class Table(BaseObject):
575
826
  mysql_charset: str | None = Field(None, alias="mysql:charset")
576
827
  """MySQL charset to use for the table."""
577
828
 
829
+ columns: Sequence[Column]
830
+ """Columns in the table."""
831
+
832
+ column_groups: list[ColumnGroup] = Field(default_factory=list, alias="columnGroups")
833
+ """Column groups in the table."""
834
+
835
+ constraints: list[_ConstraintType] = Field(default_factory=list)
836
+ """Constraints on the table."""
837
+
838
+ indexes: list[Index] = Field(default_factory=list)
839
+ """Indexes on the table."""
840
+
578
841
  @field_validator("columns", mode="after")
579
842
  @classmethod
580
843
  def check_unique_column_names(cls, columns: list[Column]) -> list[Column]:
@@ -653,6 +916,43 @@ class Table(BaseObject):
653
916
  return self
654
917
  raise ValueError(f"Table '{self.name}' is missing at least one column designated as 'tap:principal'")
655
918
 
919
+ def _find_column_by_id(self, id: str) -> Column:
920
+ """Find a column by ID.
921
+
922
+ Parameters
923
+ ----------
924
+ id
925
+ The ID of the column to find.
926
+
927
+ Returns
928
+ -------
929
+ `Column`
930
+ The column with the given ID.
931
+
932
+ Raises
933
+ ------
934
+ KeyError
935
+ Raised if the column is not found.
936
+ """
937
+ for column in self.columns:
938
+ if column.id == id:
939
+ return column
940
+ raise KeyError(f"Column '{id}' not found in table '{self.name}'")
941
+
942
+ @model_validator(mode="after")
943
+ def dereference_column_groups(self: Table) -> Table:
944
+ """Dereference columns in column groups.
945
+
946
+ Returns
947
+ -------
948
+ `Table`
949
+ The table with dereferenced column groups.
950
+ """
951
+ for group in self.column_groups:
952
+ group.table = self
953
+ group._dereference_columns()
954
+ return self
955
+
656
956
 
657
957
  class SchemaVersion(BaseModel):
658
958
  """Schema version model."""
@@ -696,10 +996,10 @@ class SchemaIdVisitor:
696
996
  if hasattr(obj, "id"):
697
997
  obj_id = getattr(obj, "id")
698
998
  if self.schema is not None:
699
- if obj_id in self.schema.id_map:
999
+ if obj_id in self.schema._id_map:
700
1000
  self.duplicates.add(obj_id)
701
1001
  else:
702
- self.schema.id_map[obj_id] = obj
1002
+ self.schema._id_map[obj_id] = obj
703
1003
 
704
1004
  def visit_schema(self, schema: Schema) -> None:
705
1005
  """Visit the objects in a schema and build the ID map.
@@ -757,6 +1057,56 @@ class SchemaIdVisitor:
757
1057
  T = TypeVar("T", bound=BaseObject)
758
1058
 
759
1059
 
1060
+ def _strip_ids(data: Any) -> Any:
1061
+ """Recursively strip '@id' fields from a dictionary or list.
1062
+
1063
+ Parameters
1064
+ ----------
1065
+ data
1066
+ The data to strip IDs from, which can be a dictionary, list, or any
1067
+ other type. Other types will be returned unchanged.
1068
+ """
1069
+ if isinstance(data, dict):
1070
+ data.pop("@id", None)
1071
+ for k, v in data.items():
1072
+ data[k] = _strip_ids(v)
1073
+ return data
1074
+ elif isinstance(data, list):
1075
+ return [_strip_ids(item) for item in data]
1076
+ else:
1077
+ return data
1078
+
1079
+
1080
+ def _append_error(
1081
+ errors: list[InitErrorDetails],
1082
+ loc: tuple,
1083
+ input_value: Any,
1084
+ error_message: str,
1085
+ error_type: str = "value_error",
1086
+ ) -> None:
1087
+ """Append an error to the errors list.
1088
+
1089
+ Parameters
1090
+ ----------
1091
+ errors : list[InitErrorDetails]
1092
+ The list of errors to append to.
1093
+ loc : tuple
1094
+ The location of the error in the schema.
1095
+ input_value : Any
1096
+ The input value that caused the error.
1097
+ error_message : str
1098
+ The error message to include in the context.
1099
+ """
1100
+ errors.append(
1101
+ {
1102
+ "type": error_type,
1103
+ "loc": loc,
1104
+ "input": input_value,
1105
+ "ctx": {"error": error_message},
1106
+ }
1107
+ )
1108
+
1109
+
760
1110
  class Schema(BaseObject, Generic[T]):
761
1111
  """Database schema model.
762
1112
 
@@ -769,7 +1119,7 @@ class Schema(BaseObject, Generic[T]):
769
1119
  tables: Sequence[Table]
770
1120
  """The tables in the schema."""
771
1121
 
772
- id_map: dict[str, Any] = Field(default_factory=dict, exclude=True)
1122
+ _id_map: dict[str, Any] = PrivateAttr(default_factory=dict)
773
1123
  """Map of IDs to objects."""
774
1124
 
775
1125
  @model_validator(mode="before")
@@ -807,6 +1157,14 @@ class Schema(BaseObject, Generic[T]):
807
1157
  if "@id" not in column:
808
1158
  column["@id"] = f"#{table['name']}.{column['name']}"
809
1159
  logger.debug(f"Generated ID '{column['@id']}' for column '{column['name']}'")
1160
+ if "columnGroups" in table:
1161
+ for column_group in table["columnGroups"]:
1162
+ if "@id" not in column_group:
1163
+ column_group["@id"] = f"#{table['name']}.{column_group['name']}"
1164
+ logger.debug(
1165
+ f"Generated ID '{column_group['@id']}' for column group "
1166
+ f"'{column_group['name']}'"
1167
+ )
810
1168
  if "constraints" in table:
811
1169
  for constraint in table["constraints"]:
812
1170
  if "@id" not in constraint:
@@ -931,20 +1289,21 @@ class Schema(BaseObject, Generic[T]):
931
1289
 
932
1290
  return self
933
1291
 
934
- def _create_id_map(self: Schema) -> Schema:
1292
+ @model_validator(mode="after")
1293
+ def create_id_map(self: Schema) -> Schema:
935
1294
  """Create a map of IDs to objects.
936
1295
 
1296
+ Returns
1297
+ -------
1298
+ `Schema`
1299
+ The schema with the ID map created.
1300
+
937
1301
  Raises
938
1302
  ------
939
1303
  ValueError
940
1304
  Raised if duplicate identifiers are found in the schema.
941
-
942
- Notes
943
- -----
944
- This is called automatically by the `model_post_init` method. If the
945
- ID map is already populated, this method will return immediately.
946
1305
  """
947
- if len(self.id_map):
1306
+ if self._id_map:
948
1307
  logger.debug("Ignoring call to create_id_map() - ID map was already populated")
949
1308
  return self
950
1309
  visitor: SchemaIdVisitor = SchemaIdVisitor()
@@ -953,25 +1312,152 @@ class Schema(BaseObject, Generic[T]):
953
1312
  raise ValueError(
954
1313
  "Duplicate IDs found in schema:\n " + "\n ".join(visitor.duplicates) + "\n"
955
1314
  )
1315
+ logger.debug("Created ID map with %d entries", len(self._id_map))
956
1316
  return self
957
1317
 
958
- def model_post_init(self, ctx: Any) -> None:
959
- """Post-initialization hook for the model.
1318
+ def _validate_column_id(
1319
+ self: Schema,
1320
+ column_id: str,
1321
+ loc: tuple,
1322
+ errors: list[InitErrorDetails],
1323
+ ) -> None:
1324
+ """Validate a column ID from a constraint and append errors if invalid.
960
1325
 
961
1326
  Parameters
962
1327
  ----------
963
- ctx
964
- The context object which was passed to the model.
1328
+ schema : Schema
1329
+ The schema being validated.
1330
+ column_id : str
1331
+ The column ID to validate.
1332
+ loc : tuple
1333
+ The location of the error in the schema.
1334
+ errors : list[InitErrorDetails]
1335
+ The list of errors to append to.
1336
+ """
1337
+ if column_id not in self:
1338
+ _append_error(
1339
+ errors,
1340
+ loc,
1341
+ column_id,
1342
+ f"Column ID '{column_id}' not found in schema",
1343
+ )
1344
+ elif not isinstance(self[column_id], Column):
1345
+ _append_error(
1346
+ errors,
1347
+ loc,
1348
+ column_id,
1349
+ f"ID '{column_id}' does not refer to a Column object",
1350
+ )
965
1351
 
966
- Notes
967
- -----
968
- This method is called automatically by Pydantic after the model is
969
- initialized. It is used to create the ID map for the schema.
1352
+ def _validate_foreign_key_column(
1353
+ self: Schema,
1354
+ column_id: str,
1355
+ table: Table,
1356
+ loc: tuple,
1357
+ errors: list[InitErrorDetails],
1358
+ ) -> None:
1359
+ """Validate a foreign key column ID from a constraint and append errors
1360
+ if invalid.
970
1361
 
971
- The ``ctx`` argument has the type `Any` because this is the function
972
- signature in Pydantic itself.
1362
+ Parameters
1363
+ ----------
1364
+ table : Table
1365
+ The table that is expected to contain the column.
1366
+ column_id : str
1367
+ The foreign key column ID to validate.
1368
+ loc : tuple
1369
+ The location of the error in the schema.
1370
+ errors : list[InitErrorDetails]
1371
+ The list of errors to append to.
973
1372
  """
974
- self._create_id_map()
1373
+ try:
1374
+ table._find_column_by_id(column_id)
1375
+ except KeyError:
1376
+ _append_error(
1377
+ errors,
1378
+ loc,
1379
+ column_id,
1380
+ f"Column '{column_id}' not found in table '{table.name}'",
1381
+ )
1382
+
1383
+ @model_validator(mode="after")
1384
+ def check_constraints(self: Schema) -> Schema:
1385
+ """Check constraint objects for validity. This needs to be deferred
1386
+ until after the schema is fully loaded and the ID map is created.
1387
+
1388
+ Raises
1389
+ ------
1390
+ pydantic.ValidationError
1391
+ Raised if any constraints are invalid.
1392
+
1393
+ Returns
1394
+ -------
1395
+ `Schema`
1396
+ The schema being validated.
1397
+ """
1398
+ errors: list[InitErrorDetails] = []
1399
+
1400
+ for table_index, table in enumerate(self.tables):
1401
+ for constraint_index, constraint in enumerate(table.constraints):
1402
+ column_ids: list[str] = []
1403
+ referenced_column_ids: list[str] = []
1404
+
1405
+ if isinstance(constraint, ForeignKeyConstraint):
1406
+ column_ids += constraint.columns
1407
+ referenced_column_ids += constraint.referenced_columns
1408
+ elif isinstance(constraint, UniqueConstraint):
1409
+ column_ids += constraint.columns
1410
+ # No extra checks are required on CheckConstraint objects.
1411
+
1412
+ # Validate the constraint's columns (foreign key or unique)
1413
+ for column_id in column_ids:
1414
+ self._validate_column_id(
1415
+ column_id,
1416
+ (
1417
+ "tables",
1418
+ table_index,
1419
+ "constraints",
1420
+ constraint_index,
1421
+ "columns",
1422
+ column_id,
1423
+ ),
1424
+ errors,
1425
+ )
1426
+ # Check that the constrained column is within the source
1427
+ # table.
1428
+ self._validate_foreign_key_column(
1429
+ column_id,
1430
+ table,
1431
+ (
1432
+ "tables",
1433
+ table_index,
1434
+ "constraints",
1435
+ constraint_index,
1436
+ "columns",
1437
+ column_id,
1438
+ ),
1439
+ errors,
1440
+ )
1441
+
1442
+ # Validate the referenced columns
1443
+ for referenced_column_id in referenced_column_ids:
1444
+ self._validate_column_id(
1445
+ referenced_column_id,
1446
+ (
1447
+ "tables",
1448
+ table_index,
1449
+ "constraints",
1450
+ constraint_index,
1451
+ "referenced_columns",
1452
+ referenced_column_id,
1453
+ ),
1454
+ errors,
1455
+ )
1456
+
1457
+ if errors:
1458
+ raise ValidationError.from_exception_data("Schema validation failed", errors)
1459
+
1460
+ return self
975
1461
 
976
1462
  def __getitem__(self, id: str) -> BaseObject:
977
1463
  """Get an object by its ID.
@@ -988,7 +1474,7 @@ class Schema(BaseObject, Generic[T]):
988
1474
  """
989
1475
  if id not in self:
990
1476
  raise KeyError(f"Object with ID '{id}' not found in schema")
991
- return self.id_map[id]
1477
+ return self._id_map[id]
992
1478
 
993
1479
  def __contains__(self, id: str) -> bool:
994
1480
  """Check if an object with the given ID is in the schema.
@@ -998,7 +1484,7 @@ class Schema(BaseObject, Generic[T]):
998
1484
  id
999
1485
  The ID of the object to check.
1000
1486
  """
1001
- return id in self.id_map
1487
+ return id in self._id_map
1002
1488
 
1003
1489
  def find_object_by_id(self, id: str, obj_type: type[T]) -> T:
1004
1490
  """Find an object with the given type by its ID.
@@ -1114,3 +1600,58 @@ class Schema(BaseObject, Generic[T]):
1114
1600
  logger.debug("Loading schema from: '%s'", source)
1115
1601
  yaml_data = yaml.safe_load(source)
1116
1602
  return Schema.model_validate(yaml_data, context=context)
1603
+
1604
+ def _model_dump(self, strip_ids: bool = False) -> dict[str, Any]:
1605
+ """Dump the schema as a dictionary with some default arguments
1606
+ applied.
1607
+
1608
+ Parameters
1609
+ ----------
1610
+ strip_ids
1611
+ Whether to strip the IDs from the dumped data. Defaults to `False`.
1612
+
1613
+ Returns
1614
+ -------
1615
+ `dict` [ `str`, `Any` ]
1616
+ The dumped schema data as a dictionary.
1617
+ """
1618
+ data = self.model_dump(by_alias=True, exclude_none=True, exclude_defaults=True)
1619
+ if strip_ids:
1620
+ data = _strip_ids(data)
1621
+ return data
1622
+
1623
+ def dump_yaml(self, stream: IO[str] = sys.stdout, strip_ids: bool = False) -> None:
1624
+ """Pretty print the schema as YAML.
1625
+
1626
+ Parameters
1627
+ ----------
1628
+ stream
1629
+ The stream to write the YAML data to.
1630
+ strip_ids
1631
+ Whether to strip the IDs from the dumped data. Defaults to `False`.
1632
+ """
1633
+ data = self._model_dump(strip_ids=strip_ids)
1634
+ yaml.safe_dump(
1635
+ data,
1636
+ stream,
1637
+ default_flow_style=False,
1638
+ sort_keys=False,
1639
+ )
1640
+
1641
+ def dump_json(self, stream: IO[str] = sys.stdout, strip_ids: bool = False) -> None:
1642
+ """Pretty print the schema as JSON.
1643
+
1644
+ Parameters
1645
+ ----------
1646
+ stream
1647
+ The stream to write the JSON data to.
1648
+ strip_ids
1649
+ Whether to strip the IDs from the dumped data. Defaults to `False`.
1650
+ """
1651
+ data = self._model_dump(strip_ids=strip_ids)
1652
+ json.dump(
1653
+ data,
1654
+ stream,
1655
+ indent=4,
1656
+ sort_keys=False,
1657
+ )