deriva-ml 1.14.47__py3-none-any.whl → 1.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,8 +4,6 @@ The module implements the sqllite interface to a set of directories representing
 
 from __future__ import annotations
 
-import sqlite3
-
 # Standard library imports
 from collections import defaultdict
 from copy import copy
@@ -16,15 +14,18 @@ import deriva.core.datapath as datapath
 # Third-party imports
 import pandas as pd
 
+# Local imports
+from deriva.core.ermrest_model import Table
+
 # Deriva imports
-from deriva.core.ermrest_model import Column, Table
 from pydantic import ConfigDict, validate_call
+from sqlalchemy import CompoundSelect, Engine, RowMapping, Select, and_, inspect, select, union
+from sqlalchemy.orm import RelationshipProperty, Session
+from sqlalchemy.orm.util import AliasedClass
 
-# Local imports
 from deriva_ml.core.definitions import RID, VocabularyTerm
 from deriva_ml.core.exceptions import DerivaMLException, DerivaMLInvalidTerm
 from deriva_ml.feature import Feature
-from deriva_ml.model.sql_mapper import SQLMapper
 
 if TYPE_CHECKING:
     from deriva_ml.model.database import DatabaseModel
@@ -64,7 +65,8 @@ class DatasetBag:
             dataset_rid: Optional RID for the dataset.
         """
         self.model = database_model
-        self.database = cast(sqlite3.Connection, self.model.dbase)
+        self.engine = cast(Engine, self.model.engine)
+        self.metadata = self.model.metadata
 
         self.dataset_rid = dataset_rid or self.model.dataset_rid
         if not self.dataset_rid:
@@ -86,54 +88,48 @@ class DatasetBag:
         """
         return self.model.list_tables()
 
-    def _dataset_table_view(self, table: str) -> str:
-        """Return a SQL command that will return all of the elements in the specified table that are associated with
-        dataset_rid"""
-
-        table_name = self.model.normalize_table_name(table)
-
-        # Get the names of the columns in the table.
-        with self.database as dbase:
-            select_args = ",".join(
-                [f'"{table_name}"."{c[1]}"' for c in dbase.execute(f'PRAGMA table_info("{table_name}")').fetchall()]
-            )
+    @staticmethod
+    def _find_relationship_attr(source, target):
+        """
+        Return the relationship attribute (InstrumentedAttribute) on `source`
+        that points to `target`. Works with classes or AliasedClass.
+        Raises LookupError if not found.
+        """
+        src_mapper = inspect(source).mapper
+        tgt_mapper = inspect(target).mapper
 
-        # Get the list of datasets in the bag including the dataset itself.
-        datasets = ",".join(
-            [f'"{self.dataset_rid}"'] + [f'"{ds.dataset_rid}"' for ds in self.list_dataset_children(recurse=True)]
-        )
+        # collect relationships on the *class* mapper (not on alias)
+        candidates: list[RelationshipProperty] = [rel for rel in src_mapper.relationships if rel.mapper is tgt_mapper]
 
-        # Find the paths that terminate in the table we are looking for
-        # Assemble the ON clause by looking at each table pair, and looking up the FK columns that connect them.
-        paths = [
-            (
-                [f'"{self.model.normalize_table_name(t.name)}"' for t in p],
-                [self.model._table_relationship(t1, t2) for t1, t2 in zip(p, p[1:])],
-            )
-            for p in self.model._schema_to_paths()
-            if p[-1].name == table
-        ]
+        if not candidates:
+            raise LookupError(f"No relationship from {src_mapper.class_.__name__} {tgt_mapper.class_.__name__}")
 
-        sql = []
-        dataset_table_name = f'"{self.model.normalize_table_name(self._dataset_table.name)}"'
+        # Prefer MANYTOONE when multiple paths exist (often best for joins)
+        candidates.sort(key=lambda r: r.direction.name != "MANYTOONE")
+        rel = candidates[0]
 
-        def column_name(col: Column) -> str:
-            return f'"{self.model.normalize_table_name(col.table.name)}"."{col.name}"'
+        # Bind to the actual source (alias or class)
+        return getattr(source, rel.key) if isinstance(source, AliasedClass) else rel.class_attribute
 
-        for ts, on in paths:
-            tables = " JOIN ".join(ts)
-            on_expression = " and ".join([f"{column_name(left)}={column_name(right)}" for left, right in on])
-            sql.append(
-                f"SELECT {select_args} FROM {tables} "
-                f"{'ON ' + on_expression if on_expression else ''} "
-                f"WHERE {dataset_table_name}.RID IN ({datasets})"
-            )
-        if table_name == self.model.normalize_table_name(self._dataset_table.name):
-            sql.append(
-                f"SELECT {select_args} FROM {dataset_table_name} WHERE {dataset_table_name}.RID IN ({datasets})"
-            )
-        sql = " UNION ".join(sql) if len(sql) > 1 else sql[0]
-        return sql
+    def _dataset_table_view(self, table: str) -> CompoundSelect[Any]:
+        """Return a SQL command that will return all of the elements in the specified table that are associated with
+        dataset_rid"""
+        table_class = self.model.get_orm_class_by_name(table)
+        dataset_table_class = self.model.get_orm_class_by_name(self._dataset_table.name)
+        dataset_rids = [self.dataset_rid] + [c.dataset_rid for c in self.list_dataset_children(recurse=True)]
+
+        paths = [[t.name for t in p] for p in self.model._schema_to_paths() if p[-1].name == table]
+        sql_cmds = []
+        for path in paths:
+            path_sql = select(table_class)
+            last_class = self.model.get_orm_class_by_name(path[0])
+            for t in path[1:]:
+                t_class = self.model.get_orm_class_by_name(t)
+                path_sql = path_sql.join(self._find_relationship_attr(last_class, t_class))
+                last_class = t_class
+            path_sql = path_sql.where(dataset_table_class.RID.in_(dataset_rids))
+            sql_cmds.append(path_sql)
+        return union(*sql_cmds)
 
     def get_table(self, table: str) -> Generator[tuple, None, None]:
         """Retrieve the contents of the specified table. If schema is not provided as part of the table name,
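The new `_find_relationship_attr` helper resolves the ORM relationship between two mapped classes and hands it to `Select.join`, and `_dataset_table_view` then combines one SELECT per schema path with `union()`. A minimal, self-contained sketch of that pattern follows; the `Dataset` and `Image` classes are illustrative stand-ins, not the package's generated ORM classes.

    from sqlalchemy import ForeignKey, create_engine, inspect, select, union
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship

    class Base(DeclarativeBase):
        pass

    class Dataset(Base):
        __tablename__ = "dataset"
        RID: Mapped[str] = mapped_column(primary_key=True)
        images: Mapped[list["Image"]] = relationship(back_populates="dataset")

    class Image(Base):
        __tablename__ = "image"
        RID: Mapped[str] = mapped_column(primary_key=True)
        dataset_rid: Mapped[str] = mapped_column(ForeignKey("dataset.RID"))
        dataset: Mapped["Dataset"] = relationship(back_populates="images")

    def find_relationship_attr(source, target):
        # Same idea as _find_relationship_attr: pick the relationship on `source`
        # whose mapper is `target`, then return its class-bound attribute for join().
        tgt_mapper = inspect(target).mapper
        rel = next(r for r in inspect(source).mapper.relationships if r.mapper is tgt_mapper)
        return rel.class_attribute

    # One SELECT per join path, combined with union(), mirroring _dataset_table_view.
    stmt = (
        select(Image)
        .join(find_relationship_attr(Image, Dataset))
        .where(Dataset.RID.in_(["1-ABC"]))
    )
    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        rows = session.execute(union(stmt)).all()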
@@ -146,9 +142,10 @@ class DatasetBag:
             A generator that yields tuples of column values.
 
         """
-        result = self.database.execute(self._dataset_table_view(table))
-        while row := result.fetchone():
-            yield row
+        with Session(self.engine) as session:
+            result = session.execute(self._dataset_table_view(table))
+            for row in result:
+                yield row
 
     def get_table_as_dataframe(self, table: str) -> pd.DataFrame:
         """Retrieve the contents of the specified table as a dataframe.
@@ -163,7 +160,7 @@ class DatasetBag:
         Returns:
             A dataframe containing the contents of the specified table.
         """
-        return pd.read_sql(self._dataset_table_view(table), self.database)
+        return pd.read_sql(self._dataset_table_view(table), self.engine)
 
     def get_table_as_dict(self, table: str) -> Generator[dict[str, Any], None, None]:
         """Retrieve the contents of the specified table as a dictionary.
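For readers following the API rather than the diff, a hedged usage sketch covering this hunk and the next; the `bag` variable and the "Image" table name are assumptions for illustration, not part of the change.

    # Assumes `bag` is a DatasetBag for a downloaded dataset and "Image" is a table in it.
    df = bag.get_table_as_dataframe("Image")      # the UNION-of-SELECTs view loaded via pd.read_sql
    print(df.columns.tolist())

    for row in bag.get_table_as_dict("Image"):    # rows now arrive as SQLAlchemy RowMapping objects
        print(row["RID"])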
@@ -176,15 +173,12 @@ class DatasetBag:
             A generator producing dictionaries containing the contents of the specified table as name/value pairs.
 
         """
-        table_name = self.model.normalize_table_name(table)
-        schema, table = table_name.split(":")
-        with self.database as _dbase:
-            mapper = SQLMapper(self.model, table)
-            result = self.database.execute(self._dataset_table_view(table))
-            while row := result.fetchone():
-                yield mapper.transform_tuple(row)
+        with Session(self.engine) as session:
+            result = session.execute(self._dataset_table_view(table))
+            for row in result.mappings():
+                yield row
 
-    @validate_call
+    # @validate_call
     def list_dataset_members(self, recurse: bool = False) -> dict[str, list[dict[str, Any]]]:
         """Return a list of entities associated with a specific dataset.
 
@@ -198,39 +192,31 @@ class DatasetBag:
         # Look at each of the element types that might be in the _dataset_table and get the list of rid for them from
         # the appropriate association table.
         members = defaultdict(list)
-        for assoc_table in self._dataset_table.find_associations():
-            member_fkey = assoc_table.other_fkeys.pop()
-            if member_fkey.pk_table.name == "Dataset" and member_fkey.foreign_key_columns[0].name != "Nested_Dataset":
-                # Sometimes find_assoc gets confused on Dataset_Dataset.
-                member_fkey = assoc_table.self_fkey
-
-            target_table = member_fkey.pk_table
-            member_table = assoc_table.table
-
-            if target_table.schema.name != self.model.domain_schema and not (
-                target_table == self._dataset_table or target_table.name == "File"
-            ):
+
+        dataset_class = self.model.get_orm_class_for_table(self._dataset_table)
+        for element_table in self.model.list_dataset_element_types():
+            element_class = self.model.get_orm_class_for_table(element_table)
+
+            assoc_class, dataset_rel, element_rel = self.model.get_orm_association_class(dataset_class, element_class)
+
+            element_table = inspect(element_class).mapped_table
+            if element_table.schema != self.model.domain_schema and element_table.name not in ["Dataset", "File"]:
                 # Look at domain tables and nested datasets.
                 continue
-            sql_target = self.model.normalize_table_name(target_table.name)
-            sql_member = self.model.normalize_table_name(member_table.name)
-
             # Get the names of the columns that we are going to need for linking
-            member_link = tuple(c.name for c in next(iter(member_fkey.column_map.items())))
-            with self.database as db:
-                col_names = [c[1] for c in db.execute(f'PRAGMA table_info("{sql_target}")').fetchall()]
-                select_cols = ",".join([f'"{sql_target}".{c}' for c in col_names])
+            with Session(self.engine) as session:
                 sql_cmd = (
-                    f'SELECT {select_cols} FROM "{sql_member}" '
-                    f'JOIN "{sql_target}" ON "{sql_member}".{member_link[0]} = "{sql_target}".{member_link[1]} '
-                    f'WHERE "{self.dataset_rid}" = "{sql_member}".Dataset;'
+                    select(element_class)
+                    .join(element_rel)
+                    .where(self.dataset_rid == assoc_class.__table__.c["Dataset"])
                 )
-                mapper = SQLMapper(self.model, sql_target)
-                target_entities = [mapper.transform_tuple(e) for e in db.execute(sql_cmd).fetchall()]
-                members[target_table.name].extend(target_entities)
-                if recurse and (target_table.name == self._dataset_table.name):
+                # Get back the list of ORM entities and convert them to dictionaries.
+                element_entities = session.scalars(sql_cmd).all()
+                element_rows = [{c.key: getattr(obj, c.key) for c in obj.__table__.columns} for obj in element_entities]
+                members[element_table.name].extend(element_rows)
+                if recurse and (element_table.name == self._dataset_table.name):
                     # Get the members for all the nested datasets and add to the member list.
-                    nested_datasets = [d["RID"] for d in target_entities]
+                    nested_datasets = [d["RID"] for d in element_rows]
                     for ds in nested_datasets:
                         nested_dataset = self.model.get_dataset(ds)
                         for k, v in nested_dataset.list_dataset_members(recurse=recurse).items():
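A hedged sketch of the structure `list_dataset_members` returns, under the same assumed `bag` instance; the table and column names are illustrative only.

    members = bag.list_dataset_members(recurse=True)
    # Maps each member table name to a list of plain row dictionaries, e.g.
    # {
    #     "Subject": [{"RID": "1-AAA", ...}, ...],
    #     "Image":   [{"RID": "1-BBB", "Subject": "1-AAA", ...}, ...],
    # }
    for table_name, rows in members.items():
        print(table_name, len(rows))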
@@ -259,12 +245,26 @@ class DatasetBag:
             Feature values.
         """
         feature = self.model.lookup_feature(table, feature_name)
-        feature_table = self.model.normalize_table_name(feature.feature_table.name)
+        feature_class = self.model.get_orm_class_for_table(feature.feature_table)
+        with Session(self.engine) as session:
+            sql_cmd = select(feature_class)
+            return cast(datapath._ResultSet, [row for row in session.execute(sql_cmd).mappings()])
 
-        with self.database as db:
-            col_names = [c[1] for c in db.execute(f'PRAGMA table_info("{feature_table}")').fetchall()]
-            sql_cmd = f'SELECT * FROM "{feature_table}"'
-            return cast(datapath._ResultSet, [dict(zip(col_names, r)) for r in db.execute(sql_cmd).fetchall()])
+    def list_dataset_element_types(self) -> list[Table]:
+        """
+        Lists the data types of elements contained within a dataset.
+
+        This method analyzes the dataset and identifies the data types for all
+        elements within it. It is useful for understanding the structure and
+        content of the dataset and allows for better manipulation and usage of its
+        data.
+
+        Returns:
+            list[str]: A list of strings where each string represents a data type
+                of an element found in the dataset.
+
+        """
+        return self.model.list_dataset_element_types()
 
     def list_dataset_children(self, recurse: bool = False) -> list[DatasetBag]:
         """Get nested datasets.
@@ -275,18 +275,18 @@ class DatasetBag:
         Returns:
             List of child dataset bags.
         """
-        ds_table = self.model.normalize_table_name("Dataset")
-        nds_table = self.model.normalize_table_name("Dataset_Dataset")
-        dv_table = self.model.normalize_table_name("Dataset_Version")
-        with self.database as db:
+        ds_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset")
+        nds_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset_Dataset")
+        dv_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset_Version")
+
+        with Session(self.engine) as session:
             sql_cmd = (
-                f'SELECT "{nds_table}".Nested_Dataset, "{dv_table}".Version '
-                f'FROM "{nds_table}" JOIN "{dv_table}" JOIN "{ds_table}" on '
-                f'"{ds_table}".Version == "{dv_table}".RID AND '
-                f'"{nds_table}".Nested_Dataset == "{ds_table}".RID '
-                f'where "{nds_table}".Dataset == "{self.dataset_rid}"'
+                select(nds_table.Nested_Dataset, dv_table.Version)
+                .join_from(ds_table, nds_table, onclause=ds_table.RID == nds_table.Nested_Dataset)
+                .join_from(ds_table, dv_table, onclause=ds_table.Version == dv_table.RID)
+                .where(nds_table.Dataset == self.dataset_rid)
             )
-            nested = [DatasetBag(self.model, r[0]) for r in db.execute(sql_cmd).fetchall()]
+            nested = [DatasetBag(self.model, r[0]) for r in session.execute(sql_cmd).all()]
 
         result = copy(nested)
         if recurse:
@@ -320,18 +320,126 @@ class DatasetBag:
             >>> term = ml.lookup_term("tissue_types", "epithelium")
         """
         # Get and validate vocabulary table reference
-        vocab_table = self.model.normalize_table_name(table)
         if not self.model.is_vocabulary(table):
             raise DerivaMLException(f"The table {table} is not a controlled vocabulary")
 
         # Search for term by name or synonym
-        for term in self.get_table_as_dict(vocab_table):
+        for term in self.get_table_as_dict(table):
             if term_name == term["Name"] or (term["Synonyms"] and term_name in term["Synonyms"]):
                 term["Synonyms"] = list(term["Synonyms"])
                 return VocabularyTerm.model_validate(term)
 
         # Term not found
-        raise DerivaMLInvalidTerm(vocab_table, term_name)
+        raise DerivaMLInvalidTerm(table, term_name)
+
+    def _denormalize(self, include_tables: list[str]) -> Select:
+        """
+        Generates an SQL statement for denormalizing the dataset based on the tables to include. Processes cycles in
+        graph relationships, ensures proper join order, and generates selected columns for denormalization.
+
+        Args:
+            include_tables (list[str] | None): List of table names to include in the denormalized dataset. If None,
+                all tables from the dataset will be included.
+
+        Returns:
+            str: SQL query string that represents the process of denormalization.
+        """
+        # Skip over tables that we don't want to include in the denormalized dataset.
+        # Also, strip off the Dataset/Dataset_X part of the path so we don't include dataset columns in the denormalized
+        # table.
+
+        def find_relationship(table, join_condition):
+            side1 = (join_condition[0].table.name, join_condition[0].name)
+            side2 = (join_condition[1].table.name, join_condition[1].name)
+
+            for relationship in inspect(table).relationships:
+                local_columns = list(relationship.local_columns)[0].table.name, list(relationship.local_columns)[0].name
+                remote_side = list(relationship.remote_side)[0].table.name, list(relationship.remote_side)[0].name
+                if local_columns == side1 and remote_side == side2 or local_columns == side2 and remote_side == side1:
+                    return relationship
+            return None
+
+        join_tables, denormalized_columns = (
+            self.model._prepare_wide_table(self, self.dataset_rid, include_tables)
+        )
+
+        denormalized_columns = [
+            self.model.get_orm_class_by_name(table_name)
+            .__table__.columns[column_name]
+            .label(f"{table_name}.{column_name}")
+            for table_name, column_name in denormalized_columns
+        ]
+        sql_statements = []
+        for key, (path, join_conditions) in join_tables.items():
+            sql_statement = select(*denormalized_columns).select_from(
+                self.model.get_orm_class_for_table(self._dataset_table)
+            )
+            for table_name in path[1:]:  # Skip over dataset table
+                table_class = self.model.get_orm_class_by_name(table_name)
+                on_clause = [
+                    getattr(table_class, r.key)
+                    for on_condition in join_conditions[table_name]
+                    if (r := find_relationship(table_class, on_condition))
+                ]
+                sql_statement = sql_statement.join(table_class, onclause=and_(*on_clause))
+            dataset_rid_list = [self.dataset_rid] + self.list_dataset_children(recurse=True)
+            dataset_class = self.model.get_orm_class_by_name(self._dataset_table.name)
+            sql_statement = sql_statement.where(dataset_class.RID.in_(dataset_rid_list))
+            sql_statements.append(sql_statement)
+        return union(*sql_statements)
+
+    def denormalize_as_dataframe(self, include_tables: list[str]) -> pd.DataFrame:
+        """
+        Denormalize the dataset and return the result as a dataframe.
+
+        This routine will examine the domain schema for the dataset, determine which tables to include and denormalize
+        the dataset values into a single wide table. The result is returned as a generator that returns a dictionary
+        for each row in the denormalized wide table.
+
+        The optional argument include_tables can be used to specify a subset of tables to include in the denormalized
+        view. The tables in this argument can appear anywhere in the dataset schema. The method will determine which
+        additional tables are required to complete the denormalization process. If include_tables is not specified,
+        all of the tables in the schema will be included.
+
+        The resulting wide table will include a column for every table needed to complete the denormalization process.
+
+        Args:
+            include_tables: List of table names to include in the denormalized dataset.
+
+        Returns:
+            Dataframe containing the denormalized dataset.
+        """
+        return pd.read_sql(self._denormalize(include_tables=include_tables), self.engine)
+
+    def denormalize_as_dict(self, include_tables: list[str]) -> Generator[RowMapping, None, None]:
+        """
+        Denormalize the dataset and return the result as a set of dictionary's.
+
+        This routine will examine the domain schema for the dataset, determine which tables to include and denormalize
+        the dataset values into a single wide table. The result is returned as a generator that returns a dictionary
+        for each row in the denormalized wide table.
+
+        The optional argument include_tables can be used to specify a subset of tables to include in the denormalized
+        view. The tables in this argument can appear anywhere in the dataset schema. The method will determine which
+        additional tables are required to complete the denormalization process. If include_tables is not specified,
+        all of the tables in the schema will be included.
+
+        The resulting wide table will include a only those column for the tables listed in include_columns.
+
+        Args:
+            include_tables: List of table names to include in the denormalized dataset. If None, than the entire schema
+                is used.
+
+        Returns:
+            A generator that returns a dictionary representation of each row in the denormalized dataset.
+        """
+        with Session(self.engine) as session:
+            cursor = session.execute(
+                self._denormalize(include_tables=include_tables)
+            )
+            yield from cursor.mappings()
+            for row in cursor.mappings():
+                yield row
 
 
 # Add annotations after definition to deal with forward reference issues in pydantic
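A hedged usage sketch for the new denormalization entry points; the table names are placeholders for domain-schema tables and are not defined by this diff.

    wide_df = bag.denormalize_as_dataframe(include_tables=["Subject", "Image"])
    print(wide_df.shape)

    for row in bag.denormalize_as_dict(include_tables=["Subject", "Image"]):
        # Keys follow the "<table_name>.<column_name>" labels built in _denormalize.
        print(dict(row))
        break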
@@ -77,11 +77,11 @@ feature_value_regex = feature_table_dir_regex + f"{SEP}(?P=feature_name)[.](?P<e
 feature_asset_dir_regex = feature_table_dir_regex + f"{SEP}asset{SEP}(?P<asset_table>[-\\w]+)"
 feature_asset_regex = feature_asset_dir_regex + f"{SEP}(?P<file>[A-Za-z0-9_-]+)[.](?P<ext>[a-z0-9]*)$"
 
-asset_path_regex = exec_dir_regex + f"{SEP}asset{SEP}(?P<schema>[-\\w]+){SEP}(?P<asset_table>[-\\w]*)"
+asset_path_regex = exec_dir_regex + rf"{SEP}asset{SEP}(?P<schema>[-\w]+){SEP}(?P<asset_table>[-\w]*)"
 
 asset_file_regex = r"(?P<file>[-\w]+)[.](?P<ext>[a-z0-9]*)$"
 
-table_regex = exec_dir_regex + f"{SEP}table{SEP}(?P<schema>[-\\w]+){SEP}(?P<table>[-\\w]+){SEP}(?P=table)[.](csv|json)$"
+table_regex = exec_dir_regex + rf"{SEP}table{SEP}(?P<schema>[-\w]+){SEP}(?P<table>[-\w]+){SEP}(?P=table)[.](csv|json)$"
 
 
 def is_feature_dir(path: Path) -> Optional[re.Match]:
@@ -190,7 +190,9 @@ def asset_table_upload_spec(model: DerivaModel, asset_table: str | Table):
     metadata_columns = model.asset_metadata(asset_table)
     asset_table = model.name_to_table(asset_table)
     schema = model.name_to_table(asset_table).schema.name
-    metadata_path = "/".join([rf"(?P<{c}>[-\w]+)" for c in metadata_columns])
+
+    # Be careful here as a metadata value might be a string with can contain special characters.
+    metadata_path = "/".join([rf"(?P<{c}>[-:._ \w]+)" for c in metadata_columns])
     asset_path = f"{exec_dir_regex}/asset/{schema}/{asset_table.name}/{metadata_path}/{asset_file_regex}"
     asset_table = model.name_to_table(asset_table)
     schema = model.name_to_table(asset_table).schema.name
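The widened character class exists because metadata values such as timestamps contain ':', '.' and spaces, which the old `[-\w]+` group would reject when the value is embedded in an upload path. A small illustrative check (the column names are examples only):

    import re

    metadata_path = "/".join([rf"(?P<{c}>[-:._ \w]+)" for c in ["Subject", "Acquisition_Time"]])
    m = re.fullmatch(metadata_path, "1-ABC/2024-05-01 12:30:00.123456")
    print(m.groupdict())
    # {'Subject': '1-ABC', 'Acquisition_Time': '2024-05-01 12:30:00.123456'}
    # The previous per-column pattern ([-\w]+) would not match the timestamp component.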
@@ -412,11 +414,12 @@ def asset_file_path(
         "Description",
     }.union(set(DerivaSystemColumns))
     asset_metadata = {c.name for c in asset_table.columns} - asset_columns
+
    if not (asset_metadata >= set(metadata.keys())):
        raise DerivaMLException(f"Metadata {metadata} does not match asset metadata {asset_metadata}")
 
    for m in asset_metadata:
-        path = path / metadata.get(m, "None")
+        path = path / str(metadata.get(m, "None"))
    path.mkdir(parents=True, exist_ok=True)
    return path / file_name
 
deriva_ml/demo_catalog.py CHANGED
@@ -5,6 +5,7 @@ import itertools
 import logging
 import string
 from collections.abc import Iterator, Sequence
+from datetime import datetime
 from numbers import Integral
 from pathlib import Path
 from random import choice, randint, random
@@ -54,7 +55,13 @@ def populate_demo_catalog(ml_instance: DerivaML) -> None:
     )
     with execution.execute() as e:
         for s in ss:
-            image_file = e.asset_file_path("Image", f"test_{s['RID']}.txt", Subject=s["RID"])
+            image_file = e.asset_file_path(
+                "Image",
+                f"test_{s['RID']}.txt",
+                Subject=s["RID"],
+                Acquisition_Time=datetime.now(),
+                Acquisition_Date=datetime.now().date(),
+            )
             with image_file.open("w") as f:
                 f.write(f"Hello there {random()}\n")
     execution.upload_execution_outputs()
@@ -343,7 +350,14 @@ def create_domain_schema(catalog: ErmrestCatalog, sname: str) -> None:
     )
     with TemporaryDirectory() as tmpdir:
         ml_instance = DerivaML(hostname=catalog.deriva_server.server, catalog_id=catalog.catalog_id, working_dir=tmpdir)
-        ml_instance.create_asset("Image", referenced_tables=[subject_table])
+        ml_instance.create_asset(
+            "Image",
+            column_defs=[
+                Column.define("Acquisition_Time", builtin_types.timestamp),
+                Column.define("Acquisition_Date", builtin_types.date),
+            ],
+            referenced_tables=[subject_table],
+        )
         catalog_annotation(ml_instance.model)
 
 
@@ -367,7 +381,7 @@ def create_demo_catalog(
     create_features=False,
     create_datasets=False,
     on_exit_delete=True,
-    logging_level=logging.INFO,
+    logging_level=logging.WARNING,
 ) -> ErmrestCatalog:
     test_catalog = create_ml_catalog(hostname, project_name=project_name)
     if on_exit_delete:
@@ -0,0 +1,26 @@
+from typing import TYPE_CHECKING
+
+# Safe imports - no circular dependencies
+from deriva_ml.execution.execution_configuration import ExecutionConfiguration, AssetRIDConfig
+from deriva_ml.execution.workflow import Workflow
+
+if TYPE_CHECKING:
+    from deriva_ml.execution.execution import Execution
+
+
+# Lazy import for runtime
+def __getattr__(name):
+    """Lazy import to avoid circular dependencies."""
+    if name == "Execution":
+        from deriva_ml.execution.execution import Execution
+
+        return Execution
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+__all__ = [
+    "Execution",  # Lazy-loaded
+    "ExecutionConfiguration",
+    "Workflow",
+    "AssetRIDConfig"
+]