deriva-ml 1.17.15__py3-none-any.whl → 1.17.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. deriva_ml/__init__.py +2 -2
  2. deriva_ml/asset/asset.py +0 -4
  3. deriva_ml/catalog/__init__.py +6 -0
  4. deriva_ml/catalog/clone.py +1513 -22
  5. deriva_ml/catalog/localize.py +66 -29
  6. deriva_ml/core/base.py +12 -9
  7. deriva_ml/core/definitions.py +13 -12
  8. deriva_ml/core/ermrest.py +11 -12
  9. deriva_ml/core/mixins/annotation.py +2 -2
  10. deriva_ml/core/mixins/asset.py +3 -3
  11. deriva_ml/core/mixins/dataset.py +3 -3
  12. deriva_ml/core/mixins/execution.py +1 -0
  13. deriva_ml/core/mixins/feature.py +2 -2
  14. deriva_ml/core/mixins/file.py +2 -2
  15. deriva_ml/core/mixins/path_builder.py +2 -2
  16. deriva_ml/core/mixins/rid_resolution.py +2 -2
  17. deriva_ml/core/mixins/vocabulary.py +2 -2
  18. deriva_ml/core/mixins/workflow.py +3 -3
  19. deriva_ml/dataset/catalog_graph.py +3 -4
  20. deriva_ml/dataset/dataset.py +5 -3
  21. deriva_ml/dataset/dataset_bag.py +0 -2
  22. deriva_ml/dataset/upload.py +2 -2
  23. deriva_ml/demo_catalog.py +0 -1
  24. deriva_ml/execution/__init__.py +8 -8
  25. deriva_ml/execution/base_config.py +2 -2
  26. deriva_ml/execution/execution.py +5 -3
  27. deriva_ml/execution/execution_record.py +0 -1
  28. deriva_ml/execution/model_protocol.py +1 -1
  29. deriva_ml/execution/multirun_config.py +0 -1
  30. deriva_ml/execution/runner.py +3 -3
  31. deriva_ml/experiment/experiment.py +3 -3
  32. deriva_ml/feature.py +2 -2
  33. deriva_ml/interfaces.py +2 -2
  34. deriva_ml/model/__init__.py +45 -24
  35. deriva_ml/model/annotations.py +0 -1
  36. deriva_ml/model/catalog.py +3 -2
  37. deriva_ml/model/data_loader.py +330 -0
  38. deriva_ml/model/data_sources.py +439 -0
  39. deriva_ml/model/database.py +216 -32
  40. deriva_ml/model/fk_orderer.py +379 -0
  41. deriva_ml/model/handles.py +1 -1
  42. deriva_ml/model/schema_builder.py +816 -0
  43. deriva_ml/run_model.py +3 -3
  44. deriva_ml/schema/annotations.py +2 -1
  45. deriva_ml/schema/create_schema.py +1 -1
  46. deriva_ml/schema/validation.py +1 -1
  47. {deriva_ml-1.17.15.dist-info → deriva_ml-1.17.16.dist-info}/METADATA +1 -1
  48. deriva_ml-1.17.16.dist-info/RECORD +81 -0
  49. deriva_ml-1.17.15.dist-info/RECORD +0 -77
  50. {deriva_ml-1.17.15.dist-info → deriva_ml-1.17.16.dist-info}/WHEEL +0 -0
  51. {deriva_ml-1.17.15.dist-info → deriva_ml-1.17.16.dist-info}/entry_points.txt +0 -0
  52. {deriva_ml-1.17.15.dist-info → deriva_ml-1.17.16.dist-info}/licenses/LICENSE +0 -0
  53. {deriva_ml-1.17.15.dist-info → deriva_ml-1.17.16.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,379 @@
1
+ """Foreign key dependency ordering for safe data insertion.
2
+
3
+ This module provides the ForeignKeyOrderer class which computes a
4
+ topologically sorted insertion order for tables based on their
5
+ foreign key dependencies.
6
+
7
+ When loading data into a database, tables must be populated in an order
8
+ that satisfies foreign key constraints - referenced tables must be
9
+ populated before the tables that reference them.
10
+
11
+ Example:
12
+ orderer = ForeignKeyOrderer(model, schemas=['domain', 'deriva-ml'])
13
+
14
+ # Get safe insertion order for a set of tables
15
+ tables = ['Image', 'Subject', 'Diagnosis']
16
+ ordered = orderer.get_insertion_order(tables)
17
+ # Returns the Table objects in the order: Subject, Image, Diagnosis
18
+ # (Subject first because Image references it)
19
+
20
+ # Get deletion order (reverse of insertion)
21
+ delete_order = orderer.get_deletion_order(tables)
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import logging
27
+ from graphlib import CycleError, TopologicalSorter
28
+
29
+ from deriva.core.ermrest_model import Model
30
+ from deriva.core.ermrest_model import Table as DerivaTable
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
class ForeignKeyOrderer:
    """Compute FK-safe insertion and deletion orders for tables.

    A topological sort over the foreign-key graph guarantees that every
    referenced table appears before the tables that reference it.
    Self-references are ignored. Cycles are either reported via
    ``CycleError`` or broken by dropping one edge per detected cycle.

    Tables may be given by unqualified name (``"Image"``), qualified name
    (``"domain.Image"``) or as Table objects. When the same unqualified
    name exists in several configured schemas, the schema listed first in
    ``schemas`` wins.

    Example:
        orderer = ForeignKeyOrderer(model, schemas=['domain', 'deriva-ml'])

        # Get insertion order (a list of Table objects)
        tables_to_fill = ['Image', 'Subject', 'Diagnosis']
        ordered = orderer.get_insertion_order(tables_to_fill)
        # [t.name for t in ordered] == ['Subject', 'Image', 'Diagnosis']

        # Get all tables in safe order
        all_ordered = orderer.get_insertion_order()

        # Get FK dependencies for a table (a set of Table objects)
        deps = orderer.get_dependencies('Image')
    """

    def __init__(
        self,
        model: Model,
        schemas: list[str],
    ):
        """Initialize the orderer.

        Args:
            model: ERMrest Model object.
            schemas: Schemas to consider for FK relationships.
        """
        self.model = model
        self.schemas = set(schemas)
        # Sets iterate in hash order, which differs between interpreter
        # runs for strings. Keep the caller's order (de-duplicated) so name
        # resolution and tie-breaking are reproducible.
        self._schema_order: tuple[str, ...] = tuple(dict.fromkeys(schemas))
        self._table_cache: dict[str, DerivaTable] = {}
        self._build_table_cache()

    def _build_table_cache(self) -> None:
        """Build cache mapping table names to Table objects."""
        for schema_name in self._schema_order:
            if schema_name not in self.model.schemas:
                continue
            schema = self.model.schemas[schema_name]
            for table_name, table in schema.tables.items():
                # Qualified names are always unambiguous.
                self._table_cache[f"{schema_name}.{table_name}"] = table
                # Unqualified names: first configured schema wins.
                self._table_cache.setdefault(table_name, table)

    def _schema_tables(self) -> list[DerivaTable]:
        """Return every table of the configured schemas present in the model."""
        found: list[DerivaTable] = []
        for schema_name in self._schema_order:
            if schema_name in self.model.schemas:
                found.extend(self.model.schemas[schema_name].tables.values())
        return found

    def _resolve_tables(
        self,
        tables: list[str | DerivaTable] | None,
    ) -> list[DerivaTable]:
        """Turn an optional list of names/objects into Table objects.

        ``None`` means "all tables in the configured schemas".
        """
        if tables is None:
            return self._schema_tables()
        return [self._to_table(t) for t in tables]

    def _to_table(self, t: str | DerivaTable) -> DerivaTable:
        """Convert table name to Table object.

        Args:
            t: Table name (qualified or unqualified) or Table object.

        Returns:
            DerivaTable object. Table objects are returned unchanged.

        Raises:
            ValueError: If table not found.
        """
        if isinstance(t, DerivaTable):
            return t

        try:
            return self._table_cache[t]
        except KeyError:
            raise ValueError(f"Table {t} not found in schemas {self.schemas}") from None

    def _table_key(self, t: DerivaTable) -> str:
        """Get unique ``schema.table`` key for a table."""
        return f"{t.schema.name}.{t.name}"

    @staticmethod
    def _static_order(graph: dict[str, set[str]]) -> list[str]:
        """Topologically sort ``graph`` (node -> dependencies).

        Dependencies are handed to the sorter in sorted order so that the
        result does not depend on set iteration order.

        Raises:
            CycleError: If the graph contains a cycle.
        """
        sorter = TopologicalSorter({node: sorted(deps) for node, deps in graph.items()})
        return list(sorter.static_order())

    def get_dependencies(self, table: str | DerivaTable) -> set[DerivaTable]:
        """Get tables that this table depends on (FK targets).

        Args:
            table: Table name or object.

        Returns:
            Set of Table objects, restricted to the configured schemas and
            excluding the table itself, that must be populated first.
        """
        t = self._to_table(table)
        own_key = self._table_key(t)
        return {
            fk.pk_table
            for fk in t.foreign_keys
            if fk.pk_table.schema.name in self.schemas
            and self._table_key(fk.pk_table) != own_key
        }

    def get_dependents(self, table: str | DerivaTable) -> set[DerivaTable]:
        """Get tables that depend on this table (FK sources).

        Args:
            table: Table name or object.

        Returns:
            Set of Table objects in the configured schemas (other than the
            table itself) that have a foreign key to this table.
        """
        target_key = self._table_key(self._to_table(table))
        dependents: set[DerivaTable] = set()

        for other_table in self._schema_tables():
            if self._table_key(other_table) == target_key:
                continue
            if any(self._table_key(fk.pk_table) == target_key for fk in other_table.foreign_keys):
                dependents.add(other_table)

        return dependents

    def _build_dependency_graph(
        self,
        tables: list[str | DerivaTable] | None = None,
    ) -> dict[str, set[str]]:
        """Build FK dependency graph.

        Args:
            tables: Tables to include. If None, includes all tables.

        Returns:
            Dict mapping table key -> set of table keys it depends on.
            Only edges between the selected tables are kept and
            self-references are dropped.
        """
        table_objs = self._resolve_tables(tables)
        table_keys = {self._table_key(t) for t in table_objs}
        graph: dict[str, set[str]] = {}

        for t in table_objs:
            key = self._table_key(t)
            graph[key] = {
                pk_key
                for pk_key in (self._table_key(fk.pk_table) for fk in t.foreign_keys)
                if pk_key in table_keys and pk_key != key
            }

        return graph

    def get_insertion_order(
        self,
        tables: list[str | DerivaTable] | None = None,
        handle_cycles: bool = True,
    ) -> list[DerivaTable]:
        """Compute FK-safe insertion order for the given tables.

        Returns tables ordered so that all FK dependencies among them are
        satisfied when inserting in order.

        Args:
            tables: Tables to order. If None, orders all tables in schemas.
            handle_cycles: If True, break cycles by removing edges.
                If False, raise CycleError on cycles.

        Returns:
            Ordered list of Table objects (insert from first to last).

        Raises:
            CycleError: If handle_cycles=False and cycles exist.
        """
        table_objs = self._resolve_tables(tables)
        # Map keys back through the tables actually being ordered rather
        # than the name cache: a Table object from a schema outside
        # ``self.schemas`` is accepted as input but is not in the cache.
        tables_by_key = {self._table_key(t): t for t in table_objs}
        graph = self._build_dependency_graph(table_objs)

        try:
            ordered_keys = self._static_order(graph)
        except CycleError as e:
            if not handle_cycles:
                raise
            ordered_keys = self._break_cycles_and_sort(graph, e)

        return [tables_by_key[key] for key in ordered_keys]

    def get_deletion_order(
        self,
        tables: list[str | DerivaTable] | None = None,
        handle_cycles: bool = True,
    ) -> list[DerivaTable]:
        """Compute FK-safe deletion order for the given tables.

        Returns tables in reverse dependency order - tables that are
        referenced should be deleted last.

        Args:
            tables: Tables to order. If None, orders all tables in schemas.
            handle_cycles: If True, break cycles. If False, raise on cycles.

        Returns:
            Ordered list of Table objects (delete from first to last).

        Raises:
            CycleError: If handle_cycles=False and cycles exist.
        """
        return self.get_insertion_order(tables, handle_cycles)[::-1]

    def _break_cycles_and_sort(
        self,
        graph: dict[str, set[str]],
        error: CycleError,
    ) -> list[str]:
        """Handle cycles by breaking them and re-sorting.

        Removes one edge per reported cycle until the graph sorts cleanly.
        ``graph`` is modified in place.

        Args:
            graph: Dependency graph (table key -> keys it depends on).
            error: CycleError with cycle info in ``args[1]``.

        Returns:
            Ordered list of table keys.

        Raises:
            CycleError: If a cycle remains and no edge could be removed
                (raised instead of retrying forever).
        """
        while True:
            cycle = list(error.args[1]) if len(error.args) > 1 else []
            removed = False

            if cycle:
                logger.warning("Breaking cycle in FK dependencies: %s", " -> ".join(cycle))

            if len(cycle) >= 2:
                # graphlib reports the cycle with the first node repeated at
                # the end, each node being an immediate predecessor of the
                # next. The closing edge is therefore "cycle[-1] depends on
                # cycle[-2]", not cycle[-1] -> cycle[0] (which are the same
                # node).
                dependency, dependent = cycle[-2], cycle[-1]
                deps = graph.get(dependent)
                if deps is not None and dependency in deps:
                    deps.remove(dependency)
                    removed = True
                    logger.debug("Removed edge %s -> %s", dependent, dependency)

            try:
                return self._static_order(graph)
            except CycleError as e:
                if not removed:
                    # No progress was made; retrying would loop forever.
                    raise
                error = e

    def validate_insertion_order(
        self,
        tables: list[str | DerivaTable],
    ) -> list[tuple[str, str, str]]:
        """Validate that a list of tables can be inserted in order.

        Checks each table to ensure all its FK dependencies are
        satisfied by tables earlier in the list. Self-references and
        references to tables outside the list are ignored.

        Args:
            tables: Ordered list of tables to validate.

        Returns:
            List of (table, missing_dependency, fk_name) tuples for
            any unsatisfied dependencies. Empty list if valid.
        """
        table_objs = [self._to_table(t) for t in tables]
        # Built once up front; membership is checked for every foreign key.
        all_keys = {self._table_key(t) for t in table_objs}
        seen_keys: set[str] = set()
        violations: list[tuple[str, str, str]] = []

        for t in table_objs:
            key = self._table_key(t)

            for fk in t.foreign_keys:
                pk_key = self._table_key(fk.pk_table)
                if pk_key == key or pk_key not in all_keys:
                    continue
                if pk_key not in seen_keys:
                    violations.append((key, pk_key, fk.name[1]))

            seen_keys.add(key)

        return violations

    def get_all_tables(self) -> list[DerivaTable]:
        """Get all tables in configured schemas.

        Returns:
            List of all Table objects.
        """
        return self._schema_tables()

    def find_cycles(self) -> list[list[str]]:
        """Find FK dependency cycles in the schema.

        Performs a depth-first search over the dependency graph and records
        one cycle for every edge that leads back to a node on the current
        search path. This is not guaranteed to enumerate every possible
        cycle in the graph.

        Returns:
            List of cycles, each cycle is a list of table keys whose first
            key is repeated as the last element.
        """
        graph = self._build_dependency_graph()
        cycles: list[list[str]] = []

        visited: set[str] = set()
        on_path: set[str] = set()
        path: list[str] = []

        def dfs(node: str) -> None:
            visited.add(node)
            on_path.add(node)
            path.append(node)

            # Sorted so the reported cycles are reproducible across runs.
            for neighbor in sorted(graph.get(node, ())):
                if neighbor not in visited:
                    dfs(neighbor)
                elif neighbor in on_path:
                    # Back edge: the cycle is the path from neighbor to here.
                    cycles.append(path[path.index(neighbor):] + [neighbor])

            path.pop()
            on_path.remove(node)

        for node in graph:
            if node not in visited:
                dfs(node)

        return cycles
@@ -9,6 +9,6 @@ Classes:
9
9
  TableHandle: Wrapper for ERMrest Table with simplified operations.
10
10
  """
11
11
 
12
- from deriva.core.model_handles import TableHandle, ColumnHandle
12
+ from deriva.core.model_handles import ColumnHandle, TableHandle
13
13
 
14
14
  __all__ = ["TableHandle", "ColumnHandle"]