lsst-felis 24.1.6rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lsst-felis has been flagged as potentially problematic; consult the registry's advisory page for details.

felis/datamodel.py ADDED
@@ -0,0 +1,1116 @@
1
+ """Define Pydantic data models for Felis."""
2
+
3
+ # This file is part of felis.
4
+ #
5
+ # Developed for the LSST Data Management System.
6
+ # This product includes software developed by the LSST Project
7
+ # (https://www.lsst.org).
8
+ # See the COPYRIGHT file at the top-level directory of this distribution
9
+ # for details of code ownership.
10
+ #
11
+ # This program is free software: you can redistribute it and/or modify
12
+ # it under the terms of the GNU General Public License as published by
13
+ # the Free Software Foundation, either version 3 of the License, or
14
+ # (at your option) any later version.
15
+ #
16
+ # This program is distributed in the hope that it will be useful,
17
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
18
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19
+ # GNU General Public License for more details.
20
+ #
21
+ # You should have received a copy of the GNU General Public License
22
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
23
+
24
+ from __future__ import annotations
25
+
26
+ import logging
27
+ from collections.abc import Sequence
28
+ from enum import StrEnum, auto
29
+ from typing import IO, Annotated, Any, Generic, Literal, TypeAlias, TypeVar, Union
30
+
31
+ import yaml
32
+ from astropy import units as units # type: ignore
33
+ from astropy.io.votable import ucd # type: ignore
34
+ from lsst.resources import ResourcePath, ResourcePathExpression
35
+ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator, model_validator
36
+
37
+ from .db.dialects import get_supported_dialects
38
+ from .db.sqltypes import get_type_func
39
+ from .db.utils import string_to_typeengine
40
+ from .types import Boolean, Byte, Char, Double, FelisType, Float, Int, Long, Short, String, Text, Unicode
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
# Public API of this module. "DataType" is defined and used publicly below
# (it is the declared type of Column.datatype), so it is exported here as
# well; it was previously missing from this tuple.
__all__ = (
    "BaseObject",
    "Column",
    "CheckConstraint",
    "Constraint",
    "DataType",
    "ForeignKeyConstraint",
    "Index",
    "Schema",
    "SchemaVersion",
    "Table",
    "UniqueConstraint",
)

CONFIG = ConfigDict(
    populate_by_name=True,  # Populate attributes by name.
    extra="forbid",  # Do not allow extra fields.
    str_strip_whitespace=True,  # Strip whitespace from string fields.
)
"""Pydantic model configuration as described in:
https://docs.pydantic.dev/2.0/api/config/#pydantic.config.ConfigDict
"""

DESCR_MIN_LENGTH = 3
"""Minimum length for a description field."""

DescriptionStr: TypeAlias = Annotated[str, Field(min_length=DESCR_MIN_LENGTH)]
"""Type for a description, which must be three or more characters long."""
71
+
72
+
73
class BaseObject(BaseModel):
    """Base model.

    All classes representing objects in the Felis data model should inherit
    from this class.
    """

    model_config = CONFIG
    """Pydantic model configuration."""

    name: str
    """Name of the database object."""

    id: str = Field(alias="@id")
    """Unique identifier of the database object."""

    description: DescriptionStr | None = None
    """Description of the database object."""

    votable_utype: str | None = Field(None, alias="votable:utype")
    """VOTable utype (usage-specific or unique type) of the object."""

    @model_validator(mode="after")
    def check_description(self, info: ValidationInfo) -> BaseObject:
        """Check that the description is present if required.

        Parameters
        ----------
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `BaseObject`
            The object being validated.
        """
        # Opt-in check: only runs when the caller of model_validate() passed
        # a context dict containing a truthy "check_description" flag.
        context = info.context
        if not context or not context.get("check_description", False):
            return self
        if self.description is None or self.description == "":
            raise ValueError("Description is required and must be non-empty")
        # NOTE(review): DescriptionStr already enforces min_length for any
        # non-None description, so this guard looks redundant but is harmless.
        if len(self.description) < DESCR_MIN_LENGTH:
            raise ValueError(f"Description must be at least {DESCR_MIN_LENGTH} characters long")
        return self
117
+
118
+
119
+ class DataType(StrEnum):
120
+ """`Enum` representing the data types supported by Felis."""
121
+
122
+ boolean = auto()
123
+ byte = auto()
124
+ short = auto()
125
+ int = auto()
126
+ long = auto()
127
+ float = auto()
128
+ double = auto()
129
+ char = auto()
130
+ string = auto()
131
+ unicode = auto()
132
+ text = auto()
133
+ binary = auto()
134
+ timestamp = auto()
135
+
136
+
137
class Column(BaseObject):
    """Column model."""

    datatype: DataType
    """Datatype of the column."""

    length: int | None = Field(None, gt=0)
    """Length of the column."""

    precision: int | None = Field(None, ge=0)
    """The numerical precision of the column.

    For timestamps, this is the number of fractional digits retained in the
    seconds field.
    """

    nullable: bool = True
    """Whether the column can be ``NULL``."""

    value: str | int | float | bool | None = None
    """Default value of the column."""

    autoincrement: bool | None = None
    """Whether the column is autoincremented."""

    mysql_datatype: str | None = Field(None, alias="mysql:datatype")
    """MySQL datatype override on the column."""

    postgresql_datatype: str | None = Field(None, alias="postgresql:datatype")
    """PostgreSQL datatype override on the column."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """IVOA UCD of the column."""

    fits_tunit: str | None = Field(None, alias="fits:tunit")
    """FITS TUNIT of the column."""

    ivoa_unit: str | None = Field(None, alias="ivoa:unit")
    """IVOA unit of the column."""

    tap_column_index: int | None = Field(None, alias="tap:column_index")
    """TAP_SCHEMA column index of the column."""

    tap_principal: int | None = Field(0, alias="tap:principal", ge=0, le=1)
    """Whether this is a TAP_SCHEMA principal column."""

    votable_arraysize: int | str | None = Field(None, alias="votable:arraysize")
    """VOTable arraysize of the column."""

    tap_std: int | None = Field(0, alias="tap:std", ge=0, le=1)
    """TAP_SCHEMA indication that this column is defined by an IVOA standard.
    """

    votable_xtype: str | None = Field(None, alias="votable:xtype")
    """VOTable xtype (extended type) of the column."""

    votable_datatype: str | None = Field(None, alias="votable:datatype")
    """VOTable datatype of the column."""

    @model_validator(mode="after")
    def check_value(self) -> Column:
        """Check that the default value is valid.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if the default value is incompatible with the column's
            datatype or the column is autoincremented.
        """
        if (value := self.value) is not None:
            # The walrus condition above guarantees value is non-None here,
            # so the original redundant re-check was removed.
            if self.autoincrement is True:
                raise ValueError("Column cannot have both a default value and be autoincremented")
            felis_type = FelisType.felis_type(self.datatype)
            if felis_type.is_numeric:
                if felis_type in (Byte, Short, Int, Long) and not isinstance(value, int):
                    raise ValueError("Default value must be an int for integer type columns")
                elif felis_type in (Float, Double) and not isinstance(value, float):
                    raise ValueError("Default value must be a decimal number for float and double columns")
            elif felis_type in (String, Char, Unicode, Text):
                if not isinstance(value, str):
                    raise ValueError("Default value must be a string for string columns")
                if not len(value):
                    raise ValueError("Default value must be a non-empty string for string columns")
            elif felis_type is Boolean and not isinstance(value, bool):
                raise ValueError("Default value must be a boolean for boolean columns")
        return self

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str) -> str:
        """Check that IVOA UCD values are valid.

        Parameters
        ----------
        ivoa_ucd
            IVOA UCD value to check.

        Returns
        -------
        `str`
            The IVOA UCD value if it is valid.

        Raises
        ------
        ValueError
            Raised if the IVOA UCD value does not parse.
        """
        if ivoa_ucd is not None:
            try:
                # Treat the UCD as colon-delimited only when a colon appears.
                ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd)
            except ValueError as e:
                raise ValueError(f"Invalid IVOA UCD: {e}")
        return ivoa_ucd

    @model_validator(mode="after")
    def check_units(self) -> Column:
        """Check that the ``fits:tunit`` or ``ivoa:unit`` field has valid
        units according to astropy. Only one may be provided.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if both FITS and IVOA units are provided, or if the unit is
            invalid.
        """
        fits_unit = self.fits_tunit
        ivoa_unit = self.ivoa_unit

        if fits_unit and ivoa_unit:
            raise ValueError("Column cannot have both FITS and IVOA units")
        unit = fits_unit or ivoa_unit

        if unit is not None:
            try:
                # Delegate unit syntax validation entirely to astropy.
                units.Unit(unit)
            except ValueError as e:
                raise ValueError(f"Invalid unit: {e}")

        return self

    @model_validator(mode="before")
    @classmethod
    def check_length(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that a valid length is provided for sized types.

        Parameters
        ----------
        values
            Values of the column.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the column.

        Raises
        ------
        ValueError
            Raised if a length is not provided for a sized type.
        """
        datatype = values.get("datatype")
        if datatype is None:
            # Skip this validation if datatype is not provided
            return values
        length = values.get("length")
        felis_type = FelisType.felis_type(datatype)
        if felis_type.is_sized and length is None:
            raise ValueError(
                f"Length must be provided for type '{datatype}'"
                + (f" in column '{values['@id']}'" if "@id" in values else "")
            )
        elif not felis_type.is_sized and length is not None:
            # A spurious length is only a warning, not an error.
            logger.warning(
                f"The datatype '{datatype}' does not support a specified length"
                + (f" in column '{values['@id']}'" if "@id" in values else "")
            )
        return values

    @model_validator(mode="after")
    def check_redundant_datatypes(self, info: ValidationInfo) -> Column:
        """Check for redundant datatypes on columns.

        Parameters
        ----------
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if a datatype override is redundant.
        """
        context = info.context
        if not context or not context.get("check_redundant_datatypes", False):
            return self
        # Skip the check when every dialect has an explicit override.
        # Bug fix: the attribute names use an underscore (e.g.
        # ``mysql_datatype``); the previous colon form ("mysql:datatype")
        # never matched an attribute, so this early exit could not trigger.
        if all(
            getattr(self, f"{dialect}_datatype", None) is not None
            for dialect in get_supported_dialects().keys()
        ):
            return self

        datatype = self.datatype
        length: int | None = self.length or None

        # Build the type object implied by the generic Felis datatype.
        datatype_func = get_type_func(datatype)
        felis_type = FelisType.felis_type(datatype)
        if felis_type.is_sized:
            datatype_obj = datatype_func(length)
        else:
            datatype_obj = datatype_func()

        for dialect_name, dialect in get_supported_dialects().items():
            db_annotation = f"{dialect_name}_datatype"
            if datatype_string := self.model_dump().get(db_annotation):
                db_datatype_obj = string_to_typeengine(datatype_string, dialect, length)
                # An override is redundant when it compiles to the same SQL
                # as the generic datatype would for this dialect.
                if datatype_obj.compile(dialect) == db_datatype_obj.compile(dialect):
                    raise ValueError(
                        "'{}: {}' is a redundant override of 'datatype: {}' in column '{}'{}".format(
                            db_annotation,
                            datatype_string,
                            self.datatype,
                            self.id,
                            "" if length is None else f" with length {length}",
                        )
                    )
                else:
                    logger.debug(
                        f"Type override of 'datatype: {self.datatype}' "
                        f"with '{db_annotation}: {datatype_string}' in column '{self.id}' "
                        f"compiled to '{datatype_obj.compile(dialect)}' and "
                        f"'{db_datatype_obj.compile(dialect)}'"
                    )
        return self

    @model_validator(mode="after")
    def check_precision(self) -> Column:
        """Check that precision is only valid for timestamp columns.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if precision is set on a non-timestamp column.
        """
        if self.precision is not None and self.datatype != "timestamp":
            raise ValueError("Precision is only valid for timestamp columns")
        return self

    @model_validator(mode="before")
    @classmethod
    def check_votable_arraysize(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Set the default value for the ``votable_arraysize`` field, which
        corresponds to ``arraysize`` in the IVOA VOTable standard.

        Parameters
        ----------
        values
            Values of the column.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the column.

        Notes
        -----
        Following the IVOA VOTable standard, an ``arraysize`` of 1 should not
        be used.
        """
        if values.get("name", None) is None or values.get("datatype", None) is None:
            # Skip bad column data that will not validate
            return values
        arraysize = values.get("votable:arraysize", None)
        if arraysize is None:
            length = values.get("length", None)
            datatype = values.get("datatype")
            if length is not None and length > 1:
                # Following the IVOA standard, arraysize of 1 is disallowed
                if datatype == "char":
                    arraysize = str(length)
                elif datatype in ("string", "unicode", "binary"):
                    arraysize = f"{length}*"
            elif datatype in ("timestamp", "text"):
                # Unbounded types get a variable-length arraysize.
                arraysize = "*"
            if arraysize is not None:
                values["votable:arraysize"] = arraysize
                logger.debug(
                    f"Set default 'votable:arraysize' to '{arraysize}' on column '{values['name']}'"
                    + f" with datatype '{values['datatype']}' and length '{values.get('length', None)}'"
                )
        else:
            logger.debug(f"Using existing 'votable:arraysize' of '{arraysize}' on column '{values['name']}'")
            if isinstance(values["votable:arraysize"], int):
                # Integer values are still accepted but normalized to str.
                logger.warning(
                    f"Usage of an integer value for 'votable:arraysize' in column '{values['name']}' is "
                    + "deprecated"
                )
                values["votable:arraysize"] = str(arraysize)
        return values
439
+
440
+
441
class Constraint(BaseObject):
    """Table constraint model."""

    deferrable: bool = False
    """Whether this constraint will be declared as deferrable."""

    initially: Literal["IMMEDIATE", "DEFERRED"] | None = None
    """Value for ``INITIALLY`` clause; only used if `deferrable` is
    `True`."""

    @model_validator(mode="after")
    def check_deferrable(self) -> Constraint:
        """Check that the ``INITIALLY`` clause is only used if `deferrable` is
        `True`.

        Returns
        -------
        `Constraint`
            The constraint being validated.

        Raises
        ------
        ValueError
            Raised if ``initially`` is set while ``deferrable`` is `False`.
        """
        if self.initially is not None and not self.deferrable:
            raise ValueError("INITIALLY clause can only be used if deferrable is True")
        return self
464
+
465
+
466
class CheckConstraint(Constraint):
    """Table check constraint model."""

    # ``type`` is the discriminator value used by the constraint union
    # (``_ConstraintType`` below); it is serialized under the key ``@type``.
    type: Literal["Check"] = Field("Check", alias="@type")
    """Type of the constraint."""

    expression: str
    """Expression for the check constraint."""
474
+
475
+
476
class UniqueConstraint(Constraint):
    """Table unique constraint model."""

    # ``type`` is the discriminator value used by the constraint union
    # (``_ConstraintType`` below); it is serialized under the key ``@type``.
    type: Literal["Unique"] = Field("Unique", alias="@type")
    """Type of the constraint."""

    columns: list[str]
    """Columns in the unique constraint."""
484
+
485
+
486
class ForeignKeyConstraint(Constraint):
    """Table foreign key constraint model.

    This constraint is used to define a foreign key relationship between two
    tables in the schema.

    Notes
    -----
    These relationships will be reflected in the TAP_SCHEMA ``keys`` and
    ``key_columns`` data.
    """

    # ``type`` is the discriminator value used by the constraint union
    # (``_ConstraintType`` below); it is serialized under the key ``@type``.
    type: Literal["ForeignKey"] = Field("ForeignKey", alias="@type")
    """Type of the constraint."""

    columns: list[str]
    """The columns comprising the foreign key."""

    referenced_columns: list[str] = Field(alias="referencedColumns")
    """The columns referenced by the foreign key."""
506
+
507
+
508
class Index(BaseObject):
    """Table index model.

    An index can be defined on either columns or expressions, but not both.
    """

    columns: list[str] | None = None
    """Columns in the index."""

    expressions: list[str] | None = None
    """Expressions in the index."""

    @model_validator(mode="before")
    @classmethod
    def check_columns_or_expressions(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that columns or expressions are specified, but not both.

        Parameters
        ----------
        values
            Values of the index.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the index.

        Raises
        ------
        ValueError
            Raised if both columns and expressions are specified, or if neither
            are specified.

        Notes
        -----
        The check is presence-based (``in values``): a key explicitly set to
        ``None`` in the input still counts as "defined" here.
        """
        if "columns" in values and "expressions" in values:
            raise ValueError("Defining columns and expressions is not valid")
        elif "columns" not in values and "expressions" not in values:
            raise ValueError("Must define columns or expressions")
        return values
546
+
547
+
548
# Discriminated union of the concrete constraint models; Pydantic selects the
# subtype from the value of the ``type`` field (serialized as ``@type``).
_ConstraintType = Annotated[
    Union[CheckConstraint, ForeignKeyConstraint, UniqueConstraint], Field(discriminator="type")
]
"""Type alias for a constraint type."""
552
+
553
+
554
class Table(BaseObject):
    """Table model."""

    columns: Sequence[Column]
    """Columns in the table."""

    constraints: list[_ConstraintType] = Field(default_factory=list)
    """Constraints on the table."""

    indexes: list[Index] = Field(default_factory=list)
    """Indexes on the table."""

    primary_key: str | list[str] | None = Field(None, alias="primaryKey")
    """Primary key of the table."""

    tap_table_index: int | None = Field(None, alias="tap:table_index")
    """IVOA TAP_SCHEMA table index of the table."""

    # Note the non-None default: tables use the MyISAM engine unless the
    # schema sets ``mysql:engine`` explicitly.
    mysql_engine: str | None = Field("MyISAM", alias="mysql:engine")
    """MySQL engine to use for the table."""

    mysql_charset: str | None = Field(None, alias="mysql:charset")
    """MySQL charset to use for the table."""

    @field_validator("columns", mode="after")
    @classmethod
    def check_unique_column_names(cls, columns: list[Column]) -> list[Column]:
        """Check that column names are unique.

        Parameters
        ----------
        columns
            The columns to check.

        Returns
        -------
        `list` [ `Column` ]
            The columns if they are unique.

        Raises
        ------
        ValueError
            Raised if column names are not unique.
        """
        # Duplicate names collapse in the set, making the lengths differ.
        if len(columns) != len(set(column.name for column in columns)):
            raise ValueError("Column names must be unique")
        return columns

    @model_validator(mode="after")
    def check_tap_table_index(self, info: ValidationInfo) -> Table:
        """Check that the table has a TAP table index.

        Parameters
        ----------
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `Table`
            The table being validated.

        Raises
        ------
        ValueError
            Raised If the table is missing a TAP table index.
        """
        # Opt-in check, enabled via the "check_tap_table_indexes" context
        # flag passed to model_validate().
        context = info.context
        if not context or not context.get("check_tap_table_indexes", False):
            return self
        if self.tap_table_index is None:
            raise ValueError("Table is missing a TAP table index")
        return self

    @model_validator(mode="after")
    def check_tap_principal(self, info: ValidationInfo) -> Table:
        """Check that at least one column is flagged as 'principal' for TAP
        purposes.

        Parameters
        ----------
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `Table`
            The table being validated.

        Raises
        ------
        ValueError
            Raised if the table is missing a column flagged as 'principal'.
        """
        # Opt-in check, enabled via the "check_tap_principal" context flag.
        context = info.context
        if not context or not context.get("check_tap_principal", False):
            return self
        for col in self.columns:
            if col.tap_principal == 1:
                return self
        raise ValueError(f"Table '{self.name}' is missing at least one column designated as 'tap:principal'")
655
+
656
+
657
class SchemaVersion(BaseModel):
    """Schema version model."""

    # NOTE: derives from BaseModel rather than BaseObject, so it carries no
    # name/@id/description fields of its own.
    current: str
    """The current version of the schema."""

    compatible: list[str] = Field(default_factory=list)
    """The compatible versions of the schema."""

    read_compatible: list[str] = Field(default_factory=list)
    """The read compatible versions of the schema."""
668
+
669
+
670
class SchemaIdVisitor:
    """Visit a schema and build the map of IDs to objects.

    Notes
    -----
    Any ID seen more than once is recorded in the ``duplicates`` set rather
    than raising an error; only the first object bearing a given ID lands in
    the map. That is acceptable because the schema's own ``model_validator``
    machinery raises a ``ValidationError`` whenever duplicates exist.
    """

    def __init__(self) -> None:
        """Create a new SchemaVisitor."""
        self.schema: Schema | None = None
        self.duplicates: set[str] = set()

    def add(self, obj: BaseObject) -> None:
        """Add an object to the ID map.

        Parameters
        ----------
        obj
            The object to add to the ID map.
        """
        # Silently skip objects without an ID or calls before a schema is set.
        if self.schema is None or not hasattr(obj, "id"):
            return
        obj_id = obj.id
        id_map = self.schema.id_map
        if obj_id in id_map:
            self.duplicates.add(obj_id)
        else:
            id_map[obj_id] = obj

    def visit_schema(self, schema: Schema) -> None:
        """Visit the objects in a schema and build the ID map.

        Parameters
        ----------
        schema
            The schema object to visit.

        Notes
        -----
        This will set an internal variable pointing to the schema object.
        """
        self.schema = schema
        # Start fresh so the visitor can be reused across schemas.
        self.duplicates.clear()
        self.add(schema)
        for tbl in schema.tables:
            self.visit_table(tbl)

    def visit_table(self, table: Table) -> None:
        """Visit a table object.

        Parameters
        ----------
        table
            The table object to visit.
        """
        self.add(table)
        for col in table.columns:
            self.visit_column(col)
        for cons in table.constraints:
            self.visit_constraint(cons)

    def visit_column(self, column: Column) -> None:
        """Visit a column object.

        Parameters
        ----------
        column
            The column object to visit.
        """
        self.add(column)

    def visit_constraint(self, constraint: Constraint) -> None:
        """Visit a constraint object.

        Parameters
        ----------
        constraint
            The constraint object to visit.
        """
        self.add(constraint)
755
+
756
+
757
# Type variable used by Schema (e.g. find_object_by_id) to tie a requested
# object type to the return type; constrained to BaseObject subclasses.
T = TypeVar("T", bound=BaseObject)
758
+
759
+
760
class Schema(BaseObject, Generic[T]):
    """Database schema model.

    This represents a database schema, which contains one or more tables.
    """

    version: SchemaVersion | str | None = None
    """The version of the schema."""

    tables: Sequence[Table]
    """The tables in the schema."""

    id_map: dict[str, Any] = Field(default_factory=dict, exclude=True)
    """Map of IDs to objects."""

    @model_validator(mode="before")
    @classmethod
    def generate_ids(cls, values: dict[str, Any], info: ValidationInfo) -> dict[str, Any]:
        """Generate IDs for objects that do not have them.

        Parameters
        ----------
        values
            The values of the schema.
        info
            Validation context used to determine if ID generation is enabled.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the schema with generated IDs.
        """
        # Opt-in behavior: only runs when the "id_generation" context flag
        # was passed to model_validate().
        context = info.context
        if not context or not context.get("id_generation", False):
            logger.debug("Skipping ID generation")
            return values
        # NOTE(review): assumes 'name' is present; a schema mapping without a
        # name raises KeyError here rather than a ValidationError — confirm.
        schema_name = values["name"]
        if "@id" not in values:
            values["@id"] = f"#{schema_name}"
            logger.debug(f"Generated ID '{values['@id']}' for schema '{schema_name}'")
        if "tables" in values:
            for table in values["tables"]:
                if "@id" not in table:
                    table["@id"] = f"#{table['name']}"
                    logger.debug(f"Generated ID '{table['@id']}' for table '{table['name']}'")
                if "columns" in table:
                    for column in table["columns"]:
                        if "@id" not in column:
                            # Column IDs are qualified by their table name.
                            column["@id"] = f"#{table['name']}.{column['name']}"
                            logger.debug(f"Generated ID '{column['@id']}' for column '{column['name']}'")
                if "constraints" in table:
                    for constraint in table["constraints"]:
                        if "@id" not in constraint:
                            constraint["@id"] = f"#{constraint['name']}"
                            logger.debug(
                                f"Generated ID '{constraint['@id']}' for constraint '{constraint['name']}'"
                            )
                if "indexes" in table:
                    for index in table["indexes"]:
                        if "@id" not in index:
                            index["@id"] = f"#{index['name']}"
                            logger.debug(f"Generated ID '{index['@id']}' for index '{index['name']}'")
        return values

    @field_validator("tables", mode="after")
    @classmethod
    def check_unique_table_names(cls, tables: list[Table]) -> list[Table]:
        """Check that table names are unique.

        Parameters
        ----------
        tables
            The tables to check.

        Returns
        -------
        `list` [ `Table` ]
            The tables if they are unique.

        Raises
        ------
        ValueError
            Raised if table names are not unique.
        """
        # Duplicate names collapse in the set, making the lengths differ.
        if len(tables) != len(set(table.name for table in tables)):
            raise ValueError("Table names must be unique")
        return tables

    @model_validator(mode="after")
    def check_tap_table_indexes(self, info: ValidationInfo) -> Schema:
        """Check that the TAP table indexes are unique.

        Parameters
        ----------
        info
            The validation context used to determine if the check is enabled.

        Returns
        -------
        `Schema`
            The schema being validated.

        Raises
        ------
        ValueError
            Raised if duplicate 'tap:table_index' values are found.
        """
        context = info.context
        if not context or not context.get("check_tap_table_indexes", False):
            return self
        table_indices: set[int] = set()
        for table in self.tables:
            table_index = table.tap_table_index
            if table_index is not None:
                if table_index in table_indices:
                    raise ValueError(f"Duplicate 'tap:table_index' value {table_index} found in schema")
                table_indices.add(table_index)
        return self

    @model_validator(mode="after")
    def check_unique_constraint_names(self: Schema) -> Schema:
        """Check for duplicate constraint names in the schema.

        Returns
        -------
        `Schema`
            The schema being validated.

        Raises
        ------
        ValueError
            Raised if duplicate constraint names are found in the schema.
        """
        # Constraint names must be unique across the whole schema, not just
        # within a single table.
        constraint_names: set[str] = set()
        duplicate_names: list[str] = []

        for table in self.tables:
            for constraint in table.constraints:
                constraint_name = constraint.name
                if constraint_name in constraint_names:
                    duplicate_names.append(constraint_name)
                else:
                    constraint_names.add(constraint_name)

        if duplicate_names:
            raise ValueError(f"Duplicate constraint names found in schema: {duplicate_names}")

        return self

    @model_validator(mode="after")
    def check_unique_index_names(self: Schema) -> Schema:
        """Check for duplicate index names in the schema.

        Returns
        -------
        `Schema`
            The schema being validated.

        Raises
        ------
        ValueError
            Raised if duplicate index names are found in the schema.
        """
        # Index names must be unique across the whole schema, not just
        # within a single table.
        index_names: set[str] = set()
        duplicate_names: list[str] = []

        for table in self.tables:
            for index in table.indexes:
                index_name = index.name
                if index_name in index_names:
                    duplicate_names.append(index_name)
                else:
                    index_names.add(index_name)

        if duplicate_names:
            raise ValueError(f"Duplicate index names found in schema: {duplicate_names}")

        return self

    def _create_id_map(self: Schema) -> Schema:
        """Create a map of IDs to objects.

        Raises
        ------
        ValueError
            Raised if duplicate identifiers are found in the schema.

        Notes
        -----
        This is called automatically by the `model_post_init` method. If the
        ID map is already populated, this method will return immediately.
        """
        if len(self.id_map):
            logger.debug("Ignoring call to create_id_map() - ID map was already populated")
            return self
        visitor: SchemaIdVisitor = SchemaIdVisitor()
        visitor.visit_schema(self)
        if len(visitor.duplicates):
            raise ValueError(
                "Duplicate IDs found in schema:\n  " + "\n  ".join(visitor.duplicates) + "\n"
            )
        return self

    def model_post_init(self, ctx: Any) -> None:
        """Post-initialization hook for the model.

        Parameters
        ----------
        ctx
            The context object which was passed to the model.

        Notes
        -----
        This method is called automatically by Pydantic after the model is
        initialized. It is used to create the ID map for the schema.

        The ``ctx`` argument has the type `Any` because this is the function
        signature in Pydantic itself.
        """
        self._create_id_map()

    def __getitem__(self, id: str) -> BaseObject:
        """Get an object by its ID.

        Parameters
        ----------
        id
            The ID of the object to get.

        Raises
        ------
        KeyError
            Raised if the object with the given ID is not found in the schema.
        """
        if id not in self:
            raise KeyError(f"Object with ID '{id}' not found in schema")
        return self.id_map[id]

    def __contains__(self, id: str) -> bool:
        """Check if an object with the given ID is in the schema.

        Parameters
        ----------
        id
            The ID of the object to check.
        """
        return id in self.id_map

    def find_object_by_id(self, id: str, obj_type: type[T]) -> T:
        """Find an object with the given type by its ID.

        Parameters
        ----------
        id
            The ID of the object to find.
        obj_type
            The type of the object to find.

        Returns
        -------
        BaseObject
            The object with the given ID and type.

        Raises
        ------
        KeyError
            If the object with the given ID is not found in the schema.
        TypeError
            If the object that is found does not have the right type.

        Notes
        -----
        The actual return type is the user-specified argument ``T``, which is
        expected to be a subclass of `BaseObject`.
        """
        obj = self[id]
        if not isinstance(obj, obj_type):
            raise TypeError(f"Object with ID '{id}' is not of type '{obj_type.__name__}'")
        return obj

    def get_table_by_column(self, column: Column) -> Table:
        """Find the table that contains a column.

        Parameters
        ----------
        column
            The column to find.

        Returns
        -------
        `Table`
            The table that contains the column.

        Raises
        ------
        ValueError
            If the column is not found in any table.
        """
        for table in self.tables:
            if column in table.columns:
                return table
        raise ValueError(f"Column '{column.name}' not found in any table")

    @classmethod
    def from_uri(cls, resource_path: ResourcePathExpression, context: dict[str, Any] | None = None) -> Schema:
        """Load a `Schema` from a string representing a ``ResourcePath``.

        Parameters
        ----------
        resource_path
            The ``ResourcePath`` pointing to a YAML file.
        context
            Pydantic context to be used in validation, or `None` for no
            context.

        Returns
        -------
        `Schema`
            The Felis schema loaded from the resource.

        Raises
        ------
        yaml.YAMLError
            Raised if there is an error loading the YAML data.
        ValueError
            Raised if there is an error reading the resource.
        pydantic.ValidationError
            Raised if the schema fails validation.
        """
        # Bug fix: the default was a shared mutable dict ({}); use None.
        logger.debug(f"Loading schema from: '{resource_path}'")
        try:
            rp_stream = ResourcePath(resource_path).read()
        except Exception as e:
            raise ValueError(f"Error reading resource from '{resource_path}' : {e}") from e
        yaml_data = yaml.safe_load(rp_stream)
        # Use cls rather than Schema so subclasses validate as themselves.
        return cls.model_validate(yaml_data, context=context)

    @classmethod
    def from_stream(cls, source: IO[str], context: dict[str, Any] | None = None) -> Schema:
        """Load a `Schema` from a file stream which should contain YAML data.

        Parameters
        ----------
        source
            The file stream to read from.
        context
            Pydantic context to be used in validation, or `None` for no
            context.

        Returns
        -------
        `Schema`
            The Felis schema loaded from the stream.

        Raises
        ------
        yaml.YAMLError
            Raised if there is an error loading the YAML file.
        pydantic.ValidationError
            Raised if the schema fails validation.
        """
        # Bug fix: the default was a shared mutable dict ({}); use None.
        logger.debug("Loading schema from: '%s'", source)
        yaml_data = yaml.safe_load(source)
        # Use cls rather than Schema so subclasses validate as themselves.
        return cls.model_validate(yaml_data, context=context)