deriva-ml 1.17.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deriva_ml/.DS_Store +0 -0
- deriva_ml/__init__.py +79 -0
- deriva_ml/bump_version.py +142 -0
- deriva_ml/core/__init__.py +39 -0
- deriva_ml/core/base.py +1527 -0
- deriva_ml/core/config.py +69 -0
- deriva_ml/core/constants.py +36 -0
- deriva_ml/core/definitions.py +74 -0
- deriva_ml/core/enums.py +222 -0
- deriva_ml/core/ermrest.py +288 -0
- deriva_ml/core/exceptions.py +28 -0
- deriva_ml/core/filespec.py +116 -0
- deriva_ml/dataset/__init__.py +12 -0
- deriva_ml/dataset/aux_classes.py +225 -0
- deriva_ml/dataset/dataset.py +1519 -0
- deriva_ml/dataset/dataset_bag.py +450 -0
- deriva_ml/dataset/history.py +109 -0
- deriva_ml/dataset/upload.py +439 -0
- deriva_ml/demo_catalog.py +495 -0
- deriva_ml/execution/__init__.py +26 -0
- deriva_ml/execution/environment.py +290 -0
- deriva_ml/execution/execution.py +1180 -0
- deriva_ml/execution/execution_configuration.py +147 -0
- deriva_ml/execution/workflow.py +413 -0
- deriva_ml/feature.py +228 -0
- deriva_ml/install_kernel.py +71 -0
- deriva_ml/model/__init__.py +0 -0
- deriva_ml/model/catalog.py +485 -0
- deriva_ml/model/database.py +719 -0
- deriva_ml/protocols/dataset.py +19 -0
- deriva_ml/run_notebook.py +228 -0
- deriva_ml/schema/__init__.py +3 -0
- deriva_ml/schema/annotations.py +473 -0
- deriva_ml/schema/check_schema.py +104 -0
- deriva_ml/schema/create_schema.py +393 -0
- deriva_ml/schema/deriva-ml-reference.json +8525 -0
- deriva_ml/schema/policy.json +81 -0
- deriva_ml/schema/table_comments_utils.py +57 -0
- deriva_ml/test.py +94 -0
- deriva_ml-1.17.10.dist-info/METADATA +38 -0
- deriva_ml-1.17.10.dist-info/RECORD +45 -0
- deriva_ml-1.17.10.dist-info/WHEEL +5 -0
- deriva_ml-1.17.10.dist-info/entry_points.txt +9 -0
- deriva_ml-1.17.10.dist-info/licenses/LICENSE +201 -0
- deriva_ml-1.17.10.dist-info/top_level.txt +1 -0
@@ -0,0 +1,485 @@
"""
Model management for Deriva ML catalogs.

This module provides the DerivaModel class which augments the standard Deriva model class with
ML-specific functionality. It handles schema management, feature definitions, and asset tracking.
"""

from __future__ import annotations

# Standard library imports
from collections import Counter, defaultdict
from graphlib import CycleError, TopologicalSorter
from typing import Any, Callable, Final, Iterable, NewType, TypeAlias

# Deriva imports
from deriva.core.ermrest_catalog import ErmrestCatalog
from deriva.core.ermrest_model import Column, FindAssociationResult, Model, Schema, Table

# Third-party imports
from pydantic import ConfigDict, validate_call

# Local imports
from deriva_ml.core.definitions import (
    ML_SCHEMA,
    RID,
    DerivaAssetColumns,
    TableDefinition,
)
from deriva_ml.core.exceptions import DerivaMLException, DerivaMLTableTypeError
from deriva_ml.feature import Feature
from deriva_ml.protocols.dataset import DatasetLike

try:
    from icecream import ic
except ImportError:  # Graceful fallback if IceCream isn't installed.
    ic = lambda *a: None if not a else (a[0] if len(a) == 1 else a)  # noqa


# Define common types:
TableInput: TypeAlias = str | Table
SchemaDict: TypeAlias = dict[str, Schema]
FeatureList: TypeAlias = Iterable[Feature]
SchemaName = NewType("SchemaName", str)
ColumnSet: TypeAlias = set[Column]
AssociationResult: TypeAlias = FindAssociationResult
TableSet: TypeAlias = set[Table]
PathList: TypeAlias = list[list[Table]]

# Define constants:
VOCAB_COLUMNS: Final[set[str]] = {"NAME", "URI", "SYNONYMS", "DESCRIPTION", "ID"}
ASSET_COLUMNS: Final[set[str]] = {"Filename", "URL", "Length", "MD5", "Description"}

FilterPredicate = Callable[[Table], bool]

class DerivaModel:
    """Augmented interface to the deriva model class.

    This class provides a number of DerivaML-specific methods that augment the interface in the deriva model class.

    Attributes:
        model: ERMRest model for the catalog.
        catalog: ERMRest catalog for the model.
        hostname: Hostname of the server on which the catalog resides.
        ml_schema: The ML schema for the catalog.
        domain_schema: The domain schema for the catalog, which holds domain-specific tables and relationships.
    """

    def __init__(
        self,
        model: Model,
        ml_schema: str = ML_SCHEMA,
        domain_schema: str | None = None,
    ):
        """Create and initialize a DerivaModel instance.

        This method will connect to a catalog and initialize the local configuration for ML execution.
        This class is intended to be used as a base class on which domain-specific interfaces are built.

        Args:
            model: The ERMRest model for the catalog.
            ml_schema: The ML schema name.
            domain_schema: The domain schema name. If None, the schema is inferred, which fails if the
                catalog contains more than one candidate domain schema.
        """
        self.model = model
        self.configuration = None
        self.catalog: ErmrestCatalog = self.model.catalog
        self.hostname = self.catalog.deriva_server.server if isinstance(self.catalog, ErmrestCatalog) else "localhost"

        self.ml_schema = ml_schema
        builtin_schemas = ("public", self.ml_schema, "www", "WWW")
        if domain_schema:
            self.domain_schema = domain_schema
        elif len(user_schemas := set(self.model.schemas) - set(builtin_schemas)) == 1:
            self.domain_schema = user_schemas.pop()
        else:
            raise DerivaMLException(f"Ambiguous domain schema: {user_schemas}")
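
    # Illustrative usage (a sketch, not part of the packaged module): the server name,
    # catalog id, and domain schema below are assumptions for the example.
    #
    #   from deriva.core.ermrest_catalog import ErmrestCatalog
    #
    #   catalog = ErmrestCatalog("https", "example.org", "1")
    #   model = DerivaModel(catalog.getCatalogModel(), domain_schema="my_domain")
    #   print(model.hostname, model.domain_schema)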

    def refresh_model(self) -> None:
        """Reload the model from the catalog to pick up any schema changes."""
        self.model = self.catalog.getCatalogModel()

    @property
    def schemas(self) -> dict[str, Schema]:
        """Return the schemas of the underlying model."""
        return self.model.schemas

    @property
    def chaise_config(self) -> dict[str, Any]:
        """Return the chaise configuration."""
        return self.model.chaise_config

    def __getattr__(self, name: str) -> Any:
        # Called only if `name` is not found on DerivaModel. Delegate attribute lookups to the model class.
        return getattr(self.model, name)

    def name_to_table(self, table: TableInput) -> Table:
        """Return the table object corresponding to the given table name.

        If the table name appears in more than one schema, return the first one found. Schemas are
        searched in the order: domain schema, ML schema, WWW.

        Args:
            table: An ERMRest table object or a string that is the name of the table.

        Returns:
            Table object.

        Raises:
            DerivaMLException: If the table doesn't exist.
        """
        if isinstance(table, Table):
            return table
        for sname in [self.domain_schema, self.ml_schema, "WWW"]:
            if (s := self.model.schemas.get(sname)) and table in s.tables:
                return s.tables[table]
        raise DerivaMLException(f"The table {table} doesn't exist.")
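
    # Example lookup (sketch; "Image" is a hypothetical domain table name):
    #
    #   image_table = model.name_to_table("Image")      # str -> Table
    #   same_table = model.name_to_table(image_table)   # Table passes through unchanged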

    def is_vocabulary(self, table_name: TableInput) -> bool:
        """Check if a given table is a controlled vocabulary table.

        Args:
            table_name: An ERMRest table object or the name of the table.

        Returns:
            True if the table is a controlled vocabulary, False otherwise.

        Raises:
            DerivaMLException: If the table doesn't exist.
        """
        table = self.name_to_table(table_name)
        return VOCAB_COLUMNS.issubset({c.name.upper() for c in table.columns})

    def is_association(
        self,
        table_name: str | Table,
        unqualified: bool = True,
        pure: bool = True,
        min_arity: int = 2,
        max_arity: int = 2,
    ) -> bool | set[str] | int:
        """Check the specified table to see if it is an association table.

        Args:
            table_name: An ERMRest table object or the name of the table.
            unqualified: Passed through to deriva's Table.is_association. (Default value = True)
            pure: If True, require that the table contain nothing beyond the linking columns;
                passed through to deriva's Table.is_association. (Default value = True)
            min_arity: Minimum number of linked tables. (Default value = 2)
            max_arity: Maximum number of linked tables. (Default value = 2)

        Returns:
            Falsy if the table is not an association; otherwise a truthy value describing the
            association (see deriva's Table.is_association).
        """
        table = self.name_to_table(table_name)
        return table.is_association(unqualified=unqualified, pure=pure, min_arity=min_arity, max_arity=max_arity)
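
    # Example (sketch; "Image_Tag" is a hypothetical association table):
    #
    #   if model.is_association("Image_Tag"):
    #       ...  # treat as a binary association between two tables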

    def find_association(self, table1: Table | str, table2: Table | str) -> tuple[Table, str, str]:
        """Given two tables, return the association table that connects them and the names of the two columns used to link them.

        Raises:
            DerivaMLException: If there is either no association table or more than one association table.
        """
        table1 = self.name_to_table(table1)
        table2 = self.name_to_table(table2)

        tables = [
            (a.table, a.self_fkey.columns[0].name, other_key.columns[0].name)
            for a in table1.find_associations(pure=False)
            if len(a.other_fkeys) == 1 and (other_key := a.other_fkeys.pop()).pk_table == table2
        ]

        if len(tables) == 1:
            return tables[0]
        elif len(tables) == 0:
            raise DerivaMLException(f"No association tables found between {table1.name} and {table2.name}.")
        else:
            raise DerivaMLException(
                f"There are {len(tables)} association tables between {table1.name} and {table2.name}."
            )
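
    # Example (sketch; "Subject" and "Image" are hypothetical table names):
    #
    #   assoc, subject_col, image_col = model.find_association("Subject", "Image")
    #   print(assoc.name, subject_col, image_col)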

    def is_asset(self, table_name: TableInput) -> bool:
        """True if the specified table is an asset table.

        Args:
            table_name: An ERMRest table object or the name of the table.

        Returns:
            True if the specified table is an asset table, False otherwise.
        """
        table = self.name_to_table(table_name)
        return ASSET_COLUMNS.issubset({c.name for c in table.columns})

    def find_assets(self, with_metadata: bool = False) -> list[Table]:
        """Return the list of asset tables in the current model."""
        return [t for s in self.model.schemas.values() for t in s.tables.values() if self.is_asset(t)]

    def find_vocabularies(self) -> list[Table]:
        """Return a list of all the controlled vocabulary tables in the model."""
        return [t for s in self.model.schemas.values() for t in s.tables.values() if self.is_vocabulary(t)]
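
    # Example (sketch): enumerate table kinds in the catalog.
    #
    #   for t in model.find_assets():
    #       print("asset:", t.name)
    #   for t in model.find_vocabularies():
    #       print("vocabulary:", t.name)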

    @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
    def find_features(self, table: TableInput) -> Iterable[Feature]:
        """List the features in the specified table.

        Args:
            table: The table to find features for.

        Returns:
            An iterable of Feature instances that describe the current features in the table.
        """
        table = self.name_to_table(table)

        def is_feature(a: FindAssociationResult) -> bool:
            """Check if an association represents a feature.

            Args:
                a: Association result to check.

            Returns:
                True if the association represents a feature.
            """
            return {
                "Feature_Name",
                "Execution",
                a.self_fkey.foreign_key_columns[0].name,
            }.issubset({c.name for c in a.table.columns})

        return [
            Feature(a, self) for a in table.find_associations(min_arity=3, max_arity=3, pure=False) if is_feature(a)
        ]
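
    # Example (sketch; "Image" is a hypothetical table with features attached):
    #
    #   for feature in model.find_features("Image"):
    #       print(feature.feature_name)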

    def lookup_feature(self, table: TableInput, feature_name: str) -> Feature:
        """Look up the named feature associated with the provided table.

        Args:
            table: An ERMRest table object or the name of the table.
            feature_name: Name of the feature to look up.

        Returns:
            A Feature instance that represents the requested feature.

        Raises:
            DerivaMLException: If the feature cannot be found.
        """
        table = self.name_to_table(table)
        try:
            return [f for f in self.find_features(table) if f.feature_name == feature_name][0]
        except IndexError:
            raise DerivaMLException(f"Feature {table.name}:{feature_name} doesn't exist.")

    def asset_metadata(self, table: str | Table) -> set[str]:
        """Return the metadata columns for an asset table.

        Raises:
            DerivaMLTableTypeError: If the table is not an asset table.
        """
        table = self.name_to_table(table)
        if not self.is_asset(table):
            raise DerivaMLTableTypeError("asset table", table.name)
        return {c.name for c in table.columns} - DerivaAssetColumns
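
    # Example (sketch; assumes "Image" is an asset table with an extra "Width" column):
    #
    #   model.asset_metadata("Image")   # -> {"Width", ...}, i.e., everything beyond the standard asset columns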

    def apply(self) -> None:
        """Call ERMRestModel.apply."""
        if self.catalog == "file-system":
            raise DerivaMLException("Cannot apply() to non-catalog model.")
        else:
            self.model.apply()

    def list_dataset_element_types(self) -> list[Table]:
        """List the types of elements that may be contained within a dataset.

        This method analyzes the model and identifies the tables whose rows can be members of a
        dataset. It is useful for understanding the structure and content of a dataset and allows
        for better manipulation and usage of its data.

        Returns:
            A list of Table objects, one for each table whose elements can appear in a dataset.
        """

        dataset_table = self.name_to_table("Dataset")

        def domain_table(table: Table) -> bool:
            return table.schema.name == self.domain_schema or table.name == dataset_table.name

        return [t for a in dataset_table.find_associations() if domain_table(t := a.other_fkeys.pop().pk_table)]
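
    # Example (sketch): datasets can nest, so "Dataset" itself typically appears in this list.
    #
    #   for t in model.list_dataset_element_types():
    #       print(t.schema.name, t.name)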

    def _prepare_wide_table(
        self,
        dataset: DatasetLike,
        dataset_rid: RID,
        include_tables: list[str],
    ) -> tuple[dict[str, Any], list[tuple]]:
        """Generate the details of a wide (denormalized) table from the model.

        Args:
            dataset: The dataset being denormalized.
            dataset_rid: RID of the dataset.
            include_tables: List of table names to include in the denormalized dataset.

        Returns:
            A tuple of the per-element join information and the list of (table, column) pairs that
            will appear in the denormalized table.
        """

        # Skip over tables that we don't want to include in the denormalized dataset.
        # Also, strip off the Dataset/Dataset_X part of the path so we don't include dataset columns
        # in the denormalized table.
        include_tables = set(include_tables)
        for t in include_tables:
            # Check to make sure the table is in the catalog.
            _ = self.name_to_table(t)

        table_paths = [
            path
            for path in self._schema_to_paths()
            if path[-1].name in include_tables and include_tables.intersection({p.name for p in path})
        ]
        paths_by_element = defaultdict(list)
        for p in table_paths:
            paths_by_element[p[2].name].append(p)

        # Get the names of all of the tables that can be dataset elements.
        dataset_element_tables = {
            e.name for e in self.list_dataset_element_types() if e.schema.name == self.domain_schema
        }

        skip_columns = {"RCT", "RMT", "RCB", "RMB"}
        element_tables = {}
        for element_table, paths in paths_by_element.items():
            graph = {}
            for path in paths:
                for left, right in zip(path[0:], path[1:]):
                    graph.setdefault(left.name, set()).add(right.name)

            # Now let's remove any cycles that we may have in the graph.
            # We will use a topological sort to find the order in which we need to join the tables.
            # If we find a cycle, we will remove the table from the graph and splice in an additional ON clause.
            # We will then repeat the process until there are no cycles.
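            # Illustration (not executed; hypothetical table names): given
            #   graph = {"Dataset": {"Image"}, "Image": {"Subject"}, "Subject": {"Image"}}
            # TopologicalSorter raises CycleError whose args[1] lists the cycle, e.g.
            # ["Image", "Subject", "Image"], so we drop one edge (graph["Subject"].remove("Image"))
            # and retry; the join for the removed edge is reintroduced below as an extra ON condition.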
            graph_has_cycles = True
            element_join_tables = []
            element_join_conditions = {}
            while graph_has_cycles:
                try:
                    ts = TopologicalSorter(graph)
                    element_join_tables = list(reversed(list(ts.static_order())))
                    graph_has_cycles = False
                except CycleError as e:
                    cycle_nodes = e.args[1]
                    if len(cycle_nodes) > 3:
                        raise DerivaMLException(f"Unexpected cycle found when normalizing dataset {cycle_nodes}")
                    # Remove the cycle from the graph and splice in an additional ON constraint.
                    graph[cycle_nodes[1]].remove(cycle_nodes[0])

            # The Dataset_Version table is a special case as it points to dataset and dataset to version.
            if "Dataset_Version" in element_join_tables:
                element_join_tables.remove("Dataset_Version")

            for path in paths:
                for left, right in zip(path[0:], path[1:]):
                    if right.name == "Dataset_Version":
                        # The Dataset_Version table is a special case as it points to dataset and dataset to version.
                        continue
                    if element_join_tables.index(right.name) < element_join_tables.index(left.name):
                        continue
                    table_relationship = self._table_relationship(left, right)
                    element_join_conditions.setdefault(right.name, set()).add(
                        (table_relationship[0], table_relationship[1])
                    )
            element_tables[element_table] = (element_join_tables, element_join_conditions)

        # Get the list of columns that will appear in the final denormalized dataset.
        denormalized_columns = [
            (table_name, c.name)
            for table_name in include_tables
            if not self.is_association(table_name)  # Don't include association columns in the denormalized view.
            for c in self.name_to_table(table_name).columns
            if c.name not in skip_columns
        ]
        return element_tables, denormalized_columns

    def _table_relationship(
        self,
        table1: TableInput,
        table2: TableInput,
    ) -> tuple[Column, Column]:
        """Return the pair of columns used to relate two tables.

        Raises:
            DerivaMLException: If there is not exactly one linkage between the two tables.
        """
        table1 = self.name_to_table(table1)
        table2 = self.name_to_table(table2)
        relationships = [
            (fk.foreign_key_columns[0], fk.referenced_columns[0]) for fk in table1.foreign_keys if fk.pk_table == table2
        ]
        relationships.extend(
            [(fk.referenced_columns[0], fk.foreign_key_columns[0]) for fk in table1.referenced_by if fk.table == table2]
        )
        if len(relationships) != 1:
            raise DerivaMLException(
                f"Ambiguous or missing linkage between {table1.name} and {table2.name}: "
                f"{[(r[0].name, r[1].name) for r in relationships]}"
            )
        return relationships[0]

    def _schema_to_paths(
        self,
        root: Table | None = None,
        path: list[Table] | None = None,
    ) -> list[list[Table]]:
        """Return a list of paths through the schema graph.

        Args:
            root: The root table to start from. Defaults to the Dataset table.
            path: The current path being built.

        Returns:
            A list of paths through the schema graph, each a list of tables starting at the root.
        """
        root = root or self.model.schemas[self.ml_schema].tables["Dataset"]
        path = path.copy() if path else []
        parent = path[-1] if path else None  # Table that we are coming from.
        path.append(root)
        paths = [path]

        def find_arcs(table: Table) -> set[Table]:
            """Return the set of tables linked to the given table by a foreign key in either direction."""
            arc_list = [fk.pk_table for fk in table.foreign_keys] + [fk.table for fk in table.referenced_by]
            arc_list = [t for t in arc_list if t.schema.name in {self.domain_schema, self.ml_schema}]
            domain_tables = [t for t in arc_list if t.schema.name == self.domain_schema]
            if multiple_links := [t for t, cnt in Counter(domain_tables).items() if cnt > 1]:
                raise DerivaMLException(f"Ambiguous relationship in {table.name} {multiple_links}")
            return set(arc_list)

        def is_nested_dataset_loopback(n1: Table, n2: Table) -> bool:
            """Test whether n2 is an association table used to link elements to datasets."""
            # If we have node_name <- node_name_dataset -> Dataset then we are looping
            # back around to a new dataset element.
            dataset_table = self.model.schemas[self.ml_schema].tables["Dataset"]
            assoc_table = [a for a in dataset_table.find_associations() if a.table == n2]
            return len(assoc_table) == 1 and n1 != dataset_table

        # Don't follow vocabulary terms back to their use.
        if self.is_vocabulary(root):
            return paths

        for child in find_arcs(root):
            if child.name in {"Dataset_Execution", "Dataset_Dataset", "Execution"}:
                continue
            if child == parent:
                # Don't loop back via referenced_by.
                continue
            if is_nested_dataset_loopback(root, child):
                continue
            if child in path:
                raise DerivaMLException(f"Cycle in schema path: {child.name} path:{[p.name for p in path]}")

            paths.extend(self._schema_to_paths(child, path))
        return paths
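
    # Illustration (hypothetical schema): with Dataset linked to Image via an association table
    # Dataset_Image, and Image linked to Subject, _schema_to_paths() yields paths such as
    #   [Dataset], [Dataset, Dataset_Image], [Dataset, Dataset_Image, Image],
    #   [Dataset, Dataset_Image, Image, Subject], ...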

    @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
    def create_table(self, table_def: TableDefinition) -> Table:
        """Create a new table in the domain schema from a TableDefinition."""
        return self.model.schemas[self.domain_schema].create_table(table_def.model_dump())