deriva-ml 1.17.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. deriva_ml/.DS_Store +0 -0
  2. deriva_ml/__init__.py +79 -0
  3. deriva_ml/bump_version.py +142 -0
  4. deriva_ml/core/__init__.py +39 -0
  5. deriva_ml/core/base.py +1527 -0
  6. deriva_ml/core/config.py +69 -0
  7. deriva_ml/core/constants.py +36 -0
  8. deriva_ml/core/definitions.py +74 -0
  9. deriva_ml/core/enums.py +222 -0
  10. deriva_ml/core/ermrest.py +288 -0
  11. deriva_ml/core/exceptions.py +28 -0
  12. deriva_ml/core/filespec.py +116 -0
  13. deriva_ml/dataset/__init__.py +12 -0
  14. deriva_ml/dataset/aux_classes.py +225 -0
  15. deriva_ml/dataset/dataset.py +1519 -0
  16. deriva_ml/dataset/dataset_bag.py +450 -0
  17. deriva_ml/dataset/history.py +109 -0
  18. deriva_ml/dataset/upload.py +439 -0
  19. deriva_ml/demo_catalog.py +495 -0
  20. deriva_ml/execution/__init__.py +26 -0
  21. deriva_ml/execution/environment.py +290 -0
  22. deriva_ml/execution/execution.py +1180 -0
  23. deriva_ml/execution/execution_configuration.py +147 -0
  24. deriva_ml/execution/workflow.py +413 -0
  25. deriva_ml/feature.py +228 -0
  26. deriva_ml/install_kernel.py +71 -0
  27. deriva_ml/model/__init__.py +0 -0
  28. deriva_ml/model/catalog.py +485 -0
  29. deriva_ml/model/database.py +719 -0
  30. deriva_ml/protocols/dataset.py +19 -0
  31. deriva_ml/run_notebook.py +228 -0
  32. deriva_ml/schema/__init__.py +3 -0
  33. deriva_ml/schema/annotations.py +473 -0
  34. deriva_ml/schema/check_schema.py +104 -0
  35. deriva_ml/schema/create_schema.py +393 -0
  36. deriva_ml/schema/deriva-ml-reference.json +8525 -0
  37. deriva_ml/schema/policy.json +81 -0
  38. deriva_ml/schema/table_comments_utils.py +57 -0
  39. deriva_ml/test.py +94 -0
  40. deriva_ml-1.17.10.dist-info/METADATA +38 -0
  41. deriva_ml-1.17.10.dist-info/RECORD +45 -0
  42. deriva_ml-1.17.10.dist-info/WHEEL +5 -0
  43. deriva_ml-1.17.10.dist-info/entry_points.txt +9 -0
  44. deriva_ml-1.17.10.dist-info/licenses/LICENSE +201 -0
  45. deriva_ml-1.17.10.dist-info/top_level.txt +1 -0
@@ -0,0 +1,485 @@
1
+ """
2
+ Model management for Deriva ML catalogs.
3
+
4
+ This module provides the DerivaModel class which augments the standard Deriva model class with
5
+ ML-specific functionality. It handles schema management, feature definitions, and asset tracking.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ # Standard library imports
11
+ from collections import Counter, defaultdict
12
+ from graphlib import CycleError, TopologicalSorter
13
+ from typing import Any, Callable, Final, Iterable, NewType, TypeAlias
14
+
15
+ from deriva.core.ermrest_catalog import ErmrestCatalog
16
+
17
+ # Deriva imports
18
+ from deriva.core.ermrest_model import Column, FindAssociationResult, Model, Schema, Table
19
+
20
+ # Third-party imports
21
+ from pydantic import ConfigDict, validate_call
22
+
23
+ from deriva_ml.core.definitions import (
24
+ ML_SCHEMA,
25
+ RID,
26
+ DerivaAssetColumns,
27
+ TableDefinition,
28
+ )
29
+ from deriva_ml.core.exceptions import DerivaMLException, DerivaMLTableTypeError
30
+
31
+ # Local imports
32
+ from deriva_ml.feature import Feature
33
+ from deriva_ml.protocols.dataset import DatasetLike
34
+
35
try:
    from icecream import ic
except ImportError:  # Graceful fallback if IceCream isn't installed.
    # PEP 8 (E731): use a def rather than assigning a lambda, so tracebacks
    # show a useful name.  Mirrors icecream's return contract:
    # no args -> None, one arg -> that arg, several args -> the arg tuple.
    def ic(*a):  # noqa
        """No-op debug stand-in for icecream.ic."""
        if not a:
            return None
        return a[0] if len(a) == 1 else a
39
+
40
+
41
# Define common types:
TableInput: TypeAlias = str | Table  # Accept either a table name or an already-resolved Table.
SchemaDict: TypeAlias = dict[str, Schema]
FeatureList: TypeAlias = Iterable[Feature]
SchemaName = NewType("SchemaName", str)
ColumnSet: TypeAlias = set[Column]
AssociationResult: TypeAlias = FindAssociationResult
TableSet: TypeAlias = set[Table]
PathList: TypeAlias = list[list[Table]]  # Each path is an ordered list of tables through the schema graph.

# Define constants:
# Upper-cased column names whose presence identifies a controlled vocabulary table.
VOCAB_COLUMNS: Final[set[str]] = {"NAME", "URI", "SYNONYMS", "DESCRIPTION", "ID"}
# Column names whose presence identifies an asset table.
ASSET_COLUMNS: Final[set[str]] = {"Filename", "URL", "Length", "MD5", "Description"}

# Predicate used to filter tables when walking the model.
FilterPredicate = Callable[[Table], bool]
56
+
57
+
58
class DerivaModel:
    """Augmented interface to deriva model class.

    This class provides a number of DerivaML specific methods that augment the interface in the deriva model class.

    Attributes:
        model: ERMRest model for the catalog.
        catalog: ERMRest catalog for the model.
        hostname: Hostname of the server hosting the catalog ("localhost" when the
            model is not backed by a live ErmrestCatalog).
        ml_schema: The ML schema for the catalog.
        domain_schema: Schema name for domain-specific tables and relationships.
    """

    def __init__(
        self,
        model: Model,
        ml_schema: str = ML_SCHEMA,
        domain_schema: str | None = None,
    ):
        """Create and initialize a DerivaML instance.

        This method will connect to a catalog, and initialize local configuration for the ML execution.
        This class is intended to be used as a base class on which domain-specific interfaces are built.

        Args:
            model: The ERMRest model for the catalog.
            ml_schema: The ML schema name.
            domain_schema: The domain schema name. If None, it is inferred as the single
                schema that is not one of the built-in schemas.

        Raises:
            DerivaMLException: If domain_schema is not given and cannot be unambiguously
                inferred from the catalog's schemas.
        """
        self.model = model
        self.configuration = None
        self.catalog: ErmrestCatalog = self.model.catalog
        # A model loaded from a file (e.g. a dataset bag) has no server; fall back to "localhost".
        self.hostname = self.catalog.deriva_server.server if isinstance(self.catalog, ErmrestCatalog) else "localhost"

        self.ml_schema = ml_schema
        builtin_schemas = ("public", self.ml_schema, "www", "WWW")
        if domain_schema:
            self.domain_schema = domain_schema
        else:
            # Infer the domain schema: exactly one non-builtin schema must exist.
            user_schemas = set(self.model.schemas) - set(builtin_schemas)
            if len(user_schemas) == 1:
                self.domain_schema = user_schemas.pop()
            else:
                raise DerivaMLException(f"Ambiguous domain schema: {user_schemas}")
103
+
104
    def refresh_model(self) -> None:
        """Re-fetch the model from the catalog so local state reflects any remote schema changes."""
        self.model = self.catalog.getCatalogModel()

    @property
    def schemas(self) -> dict[str, Schema]:
        """Mapping of schema name to Schema for every schema in the catalog model."""
        return self.model.schemas

    @property
    def chaise_config(self) -> dict[str, Any]:
        """Return the chaise configuration."""
        return self.model.chaise_config

    def __getattr__(self, name: str) -> Any:
        # Called only if `name` is not found on DerivaModel itself.
        # Delegate unknown attribute access to the wrapped ERMRest model.
        return getattr(self.model, name)
119
+
120
+ def name_to_table(self, table: TableInput) -> Table:
121
+ """Return the table object corresponding to the given table name.
122
+
123
+ If the table name appears in more than one schema, return the first one you find.
124
+
125
+ Args:
126
+ table: A ERMRest table object or a string that is the name of the table.
127
+
128
+ Returns:
129
+ Table object.
130
+ """
131
+ if isinstance(table, Table):
132
+ return table
133
+ if table in (s := self.model.schemas[self.domain_schema].tables):
134
+ return s[table]
135
+ for s in [self.model.schemas[sname] for sname in [self.domain_schema, self.ml_schema, "WWW"]]:
136
+ if table in s.tables.keys():
137
+ return s.tables[table]
138
+ raise DerivaMLException(f"The table {table} doesn't exist.")
139
+
140
+ def is_vocabulary(self, table_name: TableInput) -> bool:
141
+ """Check if a given table is a controlled vocabulary table.
142
+
143
+ Args:
144
+ table_name: A ERMRest table object or the name of the table.
145
+
146
+ Returns:
147
+ Table object if the table is a controlled vocabulary, False otherwise.
148
+
149
+ Raises:
150
+ DerivaMLException: if the table doesn't exist.
151
+
152
+ """
153
+ vocab_columns = {"NAME", "URI", "SYNONYMS", "DESCRIPTION", "ID"}
154
+ table = self.name_to_table(table_name)
155
+ return vocab_columns.issubset({c.name.upper() for c in table.columns})
156
+
157
    def is_association(
        self,
        table_name: str | Table,
        unqualified: bool = True,
        pure: bool = True,
        min_arity: int = 2,
        max_arity: int = 2,
    ) -> bool | set[str] | int:
        """Check the specified table to see if it is an association table.

        Thin wrapper that resolves the table name and delegates to ERMRest's
        ``Table.is_association``.

        Args:
            table_name: An ERMRest table object or the name of the table to test.
            unqualified: Passed through to ``Table.is_association`` (Default value = True).
            pure: Passed through to ``Table.is_association`` (Default value = True).
            min_arity: Minimum number of linked tables to qualify (Default value = 2).
            max_arity: Maximum number of linked tables to qualify (Default value = 2).

        Returns:
            The result of ``Table.is_association`` — falsy when the table is not an
            association; see the deriva-py documentation for the truthy forms.
        """
        table = self.name_to_table(table_name)
        return table.is_association(unqualified=unqualified, pure=pure, min_arity=min_arity, max_arity=max_arity)
179
+
180
+ def find_association(self, table1: Table | str, table2: Table | str) -> tuple[Table, Column, Column]:
181
+ """Given two tables, return an association table that connects the two and the two columns used to link them..
182
+
183
+ Raises:
184
+ DerivaML exception if there is either not an association table or more than one association table.
185
+ """
186
+ table1 = self.name_to_table(table1)
187
+ table2 = self.name_to_table(table2)
188
+
189
+ tables = [
190
+ (a.table, a.self_fkey.columns[0].name, other_key.columns[0].name)
191
+ for a in table1.find_associations(pure=False)
192
+ if len(a.other_fkeys) == 1 and (other_key := a.other_fkeys.pop()).pk_table == table2
193
+ ]
194
+
195
+ if len(tables) == 1:
196
+ return tables[0]
197
+ elif len(tables) == 0:
198
+ raise DerivaMLException(f"No association tables found between {table1.name} and {table2.name}.")
199
+ else:
200
+ raise DerivaMLException(
201
+ f"There are {len(tables)} association tables between {table1.name} and {table2.name}."
202
+ )
203
+
204
+ def is_asset(self, table_name: TableInput) -> bool:
205
+ """True if the specified table is an asset table.
206
+
207
+ Args:
208
+ table_name: str | Table:
209
+
210
+ Returns:
211
+ True if the specified table is an asset table, False otherwise.
212
+
213
+ """
214
+ asset_columns = {"Filename", "URL", "Length", "MD5", "Description"}
215
+ table = self.name_to_table(table_name)
216
+ return asset_columns.issubset({c.name for c in table.columns})
217
+
218
    def find_assets(self, with_metadata: bool = False) -> list[Table]:
        """Return the list of asset tables in the current model.

        Scans every schema in the model and keeps tables satisfying ``is_asset``.

        Args:
            with_metadata: NOTE(review): currently unused by this implementation —
                confirm whether metadata-bearing results were intended before removing it.
        """
        return [t for s in self.model.schemas.values() for t in s.tables.values() if self.is_asset(t)]
221
+
222
+ def find_vocabularies(self) -> list[Table]:
223
+ """Return a list of all the controlled vocabulary tables in the domain schema."""
224
+ return [t for s in self.model.schemas.values() for t in s.tables.values() if self.is_vocabulary(t)]
225
+
226
    @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
    def find_features(self, table: TableInput) -> Iterable[Feature]:
        """List the features defined on the specified table.

        A feature is recognized as a ternary association that links the target table
        with an Execution and a Feature_Name.

        Args:
            table: The table (or table name) to find features for.

        Returns:
            An iterable of Feature instances that describe the current features in the table.
        """
        table = self.name_to_table(table)

        def is_feature(a: FindAssociationResult) -> bool:
            """Check if association represents a feature.

            Args:
                a: Association result to check
            Returns:
                bool: True if association represents a feature
            """
            # A feature association must carry Feature_Name, Execution, and a FK back
            # to the target table among its columns.
            return {
                "Feature_Name",
                "Execution",
                a.self_fkey.foreign_key_columns[0].name,
            }.issubset({c.name for c in a.table.columns})

        return [
            Feature(a, self) for a in table.find_associations(min_arity=3, max_arity=3, pure=False) if is_feature(a)
        ]
256
+
257
+ def lookup_feature(self, table: TableInput, feature_name: str) -> Feature:
258
+ """Lookup the named feature associated with the provided table.
259
+
260
+ Args:
261
+ table: param feature_name:
262
+ table: str | Table:
263
+ feature_name: str:
264
+
265
+ Returns:
266
+ A Feature class that represents the requested feature.
267
+
268
+ Raises:
269
+ DerivaMLException: If the feature cannot be found.
270
+ """
271
+ table = self.name_to_table(table)
272
+ try:
273
+ return [f for f in self.find_features(table) if f.feature_name == feature_name][0]
274
+ except IndexError:
275
+ raise DerivaMLException(f"Feature {table.name}:{feature_name} doesn't exist.")
276
+
277
+ def asset_metadata(self, table: str | Table) -> set[str]:
278
+ """Return the metadata columns for an asset table."""
279
+
280
+ table = self.name_to_table(table)
281
+
282
+ if not self.is_asset(table):
283
+ raise DerivaMLTableTypeError("asset table", table.name)
284
+ return {c.name for c in table.columns} - DerivaAssetColumns
285
+
286
+ def apply(self) -> None:
287
+ """Call ERMRestModel.apply"""
288
+ if self.catalog == "file-system":
289
+ raise DerivaMLException("Cannot apply() to non-catalog model.")
290
+ else:
291
+ self.model.apply()
292
+
293
    def list_dataset_element_types(self) -> list[Table]:
        """
        Lists the tables whose members can be elements of a dataset.

        This method inspects the associations of the Dataset table and returns the
        tables (in the domain schema, or the Dataset table itself) that can be
        linked to a dataset.

        Returns:
            list[Table]: A list of Table objects, one for each table type whose
            rows may appear as elements of a dataset.
        """

        dataset_table = self.name_to_table("Dataset")

        def domain_table(table: Table) -> bool:
            # Keep domain-schema tables plus the Dataset table itself (nested datasets).
            return table.schema.name == self.domain_schema or table.name == dataset_table.name

        return [t for a in dataset_table.find_associations() if domain_table(t := a.other_fkeys.pop().pk_table)]
314
+
315
    def _prepare_wide_table(self,
                            dataset,
                            dataset_rid: RID,
                            include_tables: list[str]) -> tuple[dict[str, Any], list[tuple]]:
        """
        Generates details of a wide (denormalized) table from the model.

        NOTE(review): ``dataset`` and ``dataset_rid`` are never referenced in this body —
        confirm whether they are required by subclasses/callers before removing them.

        Args:
            dataset: Unused in this implementation (see note above).
            dataset_rid (RID): Unused in this implementation (see note above).
            include_tables (list[str] | None): List of table names to include in the denormalized dataset. If None,
                all tables from the dataset will be included.

        Returns:
            A tuple ``(element_tables, denormalized_columns)`` where ``element_tables`` maps each
            dataset element table name to ``(join_table_order, join_conditions)`` and
            ``denormalized_columns`` is a list of ``(table_name, column_name)`` pairs that will
            appear in the wide table.
        """

        # Skip over tables that we don't want to include in the denormalized dataset.
        # Also, strip off the Dataset/Dataset_X part of the path so we don't include dataset columns in the denormalized
        # table.
        include_tables = set(include_tables)
        for t in include_tables:
            # Check to make sure the table is in the catalog.
            _ = self.name_to_table(t)

        # Keep only schema paths that end at an included table and touch at least one included table.
        table_paths = [
            path
            for path in self._schema_to_paths()
            if path[-1].name in include_tables and include_tables.intersection({p.name for p in path})
        ]
        # Group paths by their dataset-element table (third entry on each path:
        # Dataset -> association -> element).
        paths_by_element = defaultdict(list)
        for p in table_paths:
            paths_by_element[p[2].name].append(p)

        # Get the names of all of the tables that can be dataset elements.
        # NOTE(review): this value is computed but never used below — confirm intent.
        dataset_element_tables = {
            e.name for e in self.list_dataset_element_types() if e.schema.name == self.domain_schema
        }

        # System columns excluded from the denormalized output.
        skip_columns = {"RCT", "RMT", "RCB", "RMB"}
        element_tables = {}
        for element_table, paths in paths_by_element.items():
            # Build a dependency graph between adjacent tables on each path.
            graph = {}
            for path in paths:
                for left, right in zip(path[0:], path[1:]):
                    graph.setdefault(left.name, set()).add(right.name)

            # Now let's remove any cycles that we may have in the graph.
            # We will use a topological sort to find the order in which we need to join the tables.
            # If we find a cycle, we will remove the table from the graph and splice in an additional ON clause.
            # We will then repeat the process until there are no cycles.
            graph_has_cycles = True
            element_join_tables = []
            element_join_conditions = {}
            while graph_has_cycles:
                try:
                    ts = TopologicalSorter(graph)
                    element_join_tables = list(reversed(list(ts.static_order())))
                    graph_has_cycles = False
                except CycleError as e:
                    cycle_nodes = e.args[1]
                    if len(cycle_nodes) > 3:
                        raise DerivaMLException(f"Unexpected cycle found when normalizing dataset {cycle_nodes}")
                    # Remove cycle from graph and splice in additional ON constraint.
                    graph[cycle_nodes[1]].remove(cycle_nodes[0])

            # The Dataset_Version table is a special case as it points to dataset and dataset to version.
            if "Dataset_Version" in element_join_tables:
                element_join_tables.remove("Dataset_Version")

            for path in paths:
                for left, right in zip(path[0:], path[1:]):
                    if right.name == "Dataset_Version":
                        # The Dataset_Version table is a special case as it points to dataset and dataset to version.
                        continue
                    if element_join_tables.index(right.name) < element_join_tables.index(left.name):
                        # Only record conditions that follow the topological join order.
                        continue
                    table_relationship = self._table_relationship(left, right)
                    element_join_conditions.setdefault(right.name, set()).add(
                        (table_relationship[0], table_relationship[1])
                    )
            element_tables[element_table] = (element_join_tables, element_join_conditions)
        # Get the list of columns that will appear in the final denormalized dataset.
        denormalized_columns = [
            (table_name, c.name)
            for table_name in include_tables
            if not self.is_association(table_name)  # Don't include association columns in the denormalized view.
            for c in self.name_to_table(table_name).columns
            if (not include_tables or table_name in include_tables) and (c.name not in skip_columns)
        ]
        return element_tables, denormalized_columns
404
+
405
+ def _table_relationship(
406
+ self,
407
+ table1: TableInput,
408
+ table2: TableInput,
409
+ ) -> tuple[Column, Column]:
410
+ """Return columns used to relate two tables."""
411
+ table1 = self.name_to_table(table1)
412
+ table2 = self.name_to_table(table2)
413
+ relationships = [
414
+ (fk.foreign_key_columns[0], fk.referenced_columns[0]) for fk in table1.foreign_keys if fk.pk_table == table2
415
+ ]
416
+ relationships.extend(
417
+ [(fk.referenced_columns[0], fk.foreign_key_columns[0]) for fk in table1.referenced_by if fk.table == table2]
418
+ )
419
+ if len(relationships) != 1:
420
+ raise DerivaMLException(
421
+ f"Ambiguous linkage between {table1.name} and {table2.name}: {[(r[0].name, r[1].name) for r in relationships]}"
422
+ )
423
+ return relationships[0]
424
+
425
+ def _schema_to_paths(
426
+ self,
427
+ root: Table | None = None,
428
+ path: list[Table] | None = None,
429
+ ) -> list[list[Table]]:
430
+ """Return a list of paths through the schema graph.
431
+
432
+ Args:
433
+ root: The root table to start from.
434
+ path: The current path being built.
435
+
436
+ Returns:
437
+ A list of paths through the schema graph.
438
+ """
439
+ path = path or []
440
+
441
+ root = root or self.model.schemas[self.ml_schema].tables["Dataset"]
442
+ path = path.copy() if path else []
443
+ parent = path[-1] if path else None # Table that we are coming from.
444
+ path.append(root)
445
+ paths = [path]
446
+
447
+ def find_arcs(table: Table) -> set[Table]:
448
+ """Given a path through the model, return the FKs that link the tables"""
449
+ arc_list = [fk.pk_table for fk in table.foreign_keys] + [fk.table for fk in table.referenced_by]
450
+ arc_list = [t for t in arc_list if t.schema.name in {self.domain_schema, self.ml_schema}]
451
+ domain_tables = [t for t in arc_list if t.schema.name == self.domain_schema]
452
+ if multiple_columns := [c for c, cnt in Counter(domain_tables).items() if cnt > 1]:
453
+ raise DerivaMLException(f"Ambiguous relationship in {table.name} {multiple_columns}")
454
+ return set(arc_list)
455
+
456
+ def is_nested_dataset_loopback(n1: Table, n2: Table) -> bool:
457
+ """Test to see if node is an association table used to link elements to datasets."""
458
+ # If we have node_name <- node_name_dataset-> Dataset then we are looping
459
+ # back around to a new dataset element
460
+ dataset_table = self.model.schemas[self.ml_schema].tables["Dataset"]
461
+ assoc_table = [a for a in dataset_table.find_associations() if a.table == n2]
462
+ return len(assoc_table) == 1 and n1 != dataset_table
463
+
464
+ # Don't follow vocabulary terms back to their use.
465
+ if self.is_vocabulary(root):
466
+ return paths
467
+
468
+ for child in find_arcs(root):
469
+ if child.name in {"Dataset_Execution", "Dataset_Dataset", "Execution"}:
470
+ continue
471
+ if child == parent:
472
+ # Don't loop back via referred_by
473
+ continue
474
+ if is_nested_dataset_loopback(root, child):
475
+ continue
476
+ if child in path:
477
+ raise DerivaMLException(f"Cycle in schema path: {child.name} path:{[p.name for p in path]}")
478
+
479
+ paths.extend(self._schema_to_paths(child, path))
480
+ return paths
481
+
482
+ @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
483
+ def create_table(self, table_def: TableDefinition) -> Table:
484
+ """Create a new table from TableDefinition."""
485
+ return self.model.schemas[self.domain_schema].create_table(table_def.model_dump())