deriva-ml 1.16.0__py3-none-any.whl → 1.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,8 +4,6 @@ The module implements the sqllite interface to a set of directories representing
 
 from __future__ import annotations
 
-import sqlite3
-
 # Standard library imports
 from collections import defaultdict
 from copy import copy
@@ -16,15 +14,18 @@ import deriva.core.datapath as datapath
 # Third-party imports
 import pandas as pd
 
+# Local imports
+from deriva.core.ermrest_model import Table
+
 # Deriva imports
-from deriva.core.ermrest_model import Column, Table
 from pydantic import ConfigDict, validate_call
+from sqlalchemy import CompoundSelect, Engine, RowMapping, Select, and_, inspect, select, union
+from sqlalchemy.orm import RelationshipProperty, Session
+from sqlalchemy.orm.util import AliasedClass
 
-# Local imports
 from deriva_ml.core.definitions import RID, VocabularyTerm
 from deriva_ml.core.exceptions import DerivaMLException, DerivaMLInvalidTerm
 from deriva_ml.feature import Feature
-from deriva_ml.model.sql_mapper import SQLMapper
 
 if TYPE_CHECKING:
     from deriva_ml.model.database import DatabaseModel
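The import changes above summarize the storage-layer migration: the raw sqlite3 connection is replaced by SQLAlchemy 2.0 constructs (select, union, Session, inspect). Below is a minimal, self-contained sketch of the 2.0-style pattern these imports enable; the Subject class and its columns are illustrative stand-ins, not deriva-ml classes.

```python
# Minimal sketch of the SQLAlchemy 2.0 pattern behind the new imports.
# Subject is a hypothetical stand-in, not a deriva-ml class.
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Subject(Base):
    __tablename__ = "Subject"
    RID: Mapped[str] = mapped_column(primary_key=True)
    Name: Mapped[str]


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Subject(RID="1-abc", Name="s1"), Subject(RID="1-def", Name="s2")])
    session.commit()

    # Statements are composable objects rather than SQL strings.
    stmt = select(Subject).where(Subject.RID.in_(["1-abc", "1-def"]))
    for subject in session.scalars(stmt):
        print(subject.RID, subject.Name)
```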
@@ -64,7 +65,8 @@ class DatasetBag:
             dataset_rid: Optional RID for the dataset.
         """
         self.model = database_model
-        self.database = cast(sqlite3.Connection, self.model.dbase)
+        self.engine = cast(Engine, self.model.engine)
+        self.metadata = self.model.metadata
 
         self.dataset_rid = dataset_rid or self.model.dataset_rid
         if not self.dataset_rid:
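DatasetBag now carries a SQLAlchemy Engine plus the model's MetaData instead of a sqlite3.Connection. A hedged sketch of what such a pair looks like for a bag backed by a local SQLite file; the file name and the reflection step are assumptions for illustration, not taken from DatabaseModel.

```python
from sqlalchemy import MetaData, create_engine

# Hypothetical: a dataset bag materialized as a local SQLite file.
engine = create_engine("sqlite:///dataset_bag.db")

# Reflect the existing tables so Core/ORM code can address them by name.
metadata = MetaData()
metadata.reflect(bind=engine)
print(sorted(metadata.tables))
```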
@@ -86,54 +88,48 @@ class DatasetBag:
         """
         return self.model.list_tables()
 
-    def _dataset_table_view(self, table: str) -> str:
-        """Return a SQL command that will return all of the elements in the specified table that are associated with
-        dataset_rid"""
-
-        table_name = self.model.normalize_table_name(table)
-
-        # Get the names of the columns in the table.
-        with self.database as dbase:
-            select_args = ",".join(
-                [f'"{table_name}"."{c[1]}"' for c in dbase.execute(f'PRAGMA table_info("{table_name}")').fetchall()]
-            )
+    @staticmethod
+    def _find_relationship_attr(source, target):
+        """
+        Return the relationship attribute (InstrumentedAttribute) on `source`
+        that points to `target`. Works with classes or AliasedClass.
+        Raises LookupError if not found.
+        """
+        src_mapper = inspect(source).mapper
+        tgt_mapper = inspect(target).mapper
 
-        # Get the list of datasets in the bag including the dataset itself.
-        datasets = ",".join(
-            [f'"{self.dataset_rid}"'] + [f'"{ds.dataset_rid}"' for ds in self.list_dataset_children(recurse=True)]
-        )
+        # Collect relationships on the *class* mapper (not on the alias).
+        candidates: list[RelationshipProperty] = [rel for rel in src_mapper.relationships if rel.mapper is tgt_mapper]
 
-        # Find the paths that terminate in the table we are looking for
-        # Assemble the ON clause by looking at each table pair, and looking up the FK columns that connect them.
-        paths = [
-            (
-                [f'"{self.model.normalize_table_name(t.name)}"' for t in p],
-                [self.model._table_relationship(t1, t2) for t1, t2 in zip(p, p[1:])],
-            )
-            for p in self.model._schema_to_paths()
-            if p[-1].name == table
-        ]
+        if not candidates:
+            raise LookupError(f"No relationship from {src_mapper.class_.__name__} to {tgt_mapper.class_.__name__}")
 
-        sql = []
-        dataset_table_name = f'"{self.model.normalize_table_name(self._dataset_table.name)}"'
+        # Prefer MANYTOONE when multiple paths exist (often best for joins).
+        candidates.sort(key=lambda r: r.direction.name != "MANYTOONE")
+        rel = candidates[0]
 
-        def column_name(col: Column) -> str:
-            return f'"{self.model.normalize_table_name(col.table.name)}"."{col.name}"'
+        # Bind to the actual source (alias or class).
+        return getattr(source, rel.key) if isinstance(source, AliasedClass) else rel.class_attribute
 
-        for ts, on in paths:
-            tables = " JOIN ".join(ts)
-            on_expression = " and ".join([f"{column_name(left)}={column_name(right)}" for left, right in on])
-            sql.append(
-                f"SELECT {select_args} FROM {tables} "
-                f"{'ON ' + on_expression if on_expression else ''} "
-                f"WHERE {dataset_table_name}.RID IN ({datasets})"
-            )
-        if table_name == self.model.normalize_table_name(self._dataset_table.name):
-            sql.append(
-                f"SELECT {select_args} FROM {dataset_table_name} WHERE {dataset_table_name}.RID IN ({datasets})"
-            )
-        sql = " UNION ".join(sql) if len(sql) > 1 else sql[0]
-        return sql
+    def _dataset_table_view(self, table: str) -> CompoundSelect[Any]:
+        """Return a SQL command that will return all of the elements in the specified table that are associated with
+        dataset_rid."""
+        table_class = self.model.get_orm_class_by_name(table)
+        dataset_table_class = self.model.get_orm_class_by_name(self._dataset_table.name)
+        dataset_rids = [self.dataset_rid] + [c.dataset_rid for c in self.list_dataset_children(recurse=True)]
+
+        paths = [[t.name for t in p] for p in self.model._schema_to_paths() if p[-1].name == table]
+        sql_cmds = []
+        for path in paths:
+            path_sql = select(table_class)
+            last_class = self.model.get_orm_class_by_name(path[0])
+            for t in path[1:]:
+                t_class = self.model.get_orm_class_by_name(t)
+                path_sql = path_sql.join(self._find_relationship_attr(last_class, t_class))
+                last_class = t_class
+            path_sql = path_sql.where(dataset_table_class.RID.in_(dataset_rids))
+            sql_cmds.append(path_sql)
+        return union(*sql_cmds)
 
     def get_table(self, table: str) -> Generator[tuple, None, None]:
         """Retrieve the contents of the specified table. If schema is not provided as part of the table name,
@@ -146,9 +142,10 @@ class DatasetBag:
             A generator that yields tuples of column values.
 
         """
-        result = self.database.execute(self._dataset_table_view(table))
-        while row := result.fetchone():
-            yield row
+        with Session(self.engine) as session:
+            result = session.execute(self._dataset_table_view(table))
+            for row in result:
+                yield row
 
     def get_table_as_dataframe(self, table: str) -> pd.DataFrame:
         """Retrieve the contents of the specified table as a dataframe.
@@ -163,7 +160,7 @@ class DatasetBag:
         Returns:
             A dataframe containing the contents of the specified table.
         """
-        return pd.read_sql(self._dataset_table_view(table), self.database)
+        return pd.read_sql(self._dataset_table_view(table), self.engine)
 
     def get_table_as_dict(self, table: str) -> Generator[dict[str, Any], None, None]:
         """Retrieve the contents of the specified table as a dictionary.
@@ -176,15 +173,12 @@ class DatasetBag:
             A generator producing dictionaries containing the contents of the specified table as name/value pairs.
         """
 
-        table_name = self.model.normalize_table_name(table)
-        schema, table = table_name.split(":")
-        with self.database as _dbase:
-            mapper = SQLMapper(self.model, table)
-            result = self.database.execute(self._dataset_table_view(table))
-            while row := result.fetchone():
-                yield mapper.transform_tuple(row)
+        with Session(self.engine) as session:
+            result = session.execute(self._dataset_table_view(table))
+            for row in result.mappings():
+                yield row
 
-    @validate_call
+    # @validate_call
     def list_dataset_members(self, recurse: bool = False) -> dict[str, list[dict[str, Any]]]:
         """Return a list of entities associated with a specific dataset.
 
@@ -198,39 +192,31 @@ class DatasetBag:
         # Look at each of the element types that might be in the _dataset_table and get the list of rid for them from
         # the appropriate association table.
         members = defaultdict(list)
-        for assoc_table in self._dataset_table.find_associations():
-            member_fkey = assoc_table.other_fkeys.pop()
-            if member_fkey.pk_table.name == "Dataset" and member_fkey.foreign_key_columns[0].name != "Nested_Dataset":
-                # Sometimes find_assoc gets confused on Dataset_Dataset.
-                member_fkey = assoc_table.self_fkey
-
-            target_table = member_fkey.pk_table
-            member_table = assoc_table.table
-
-            if target_table.schema.name != self.model.domain_schema and not (
-                target_table == self._dataset_table or target_table.name == "File"
-            ):
+
+        dataset_class = self.model.get_orm_class_for_table(self._dataset_table)
+        for element_table in self.model.list_dataset_element_types():
+            element_class = self.model.get_orm_class_for_table(element_table)
+
+            assoc_class, dataset_rel, element_rel = self.model.get_orm_association_class(dataset_class, element_class)
+
+            element_table = inspect(element_class).mapped_table
+            if element_table.schema != self.model.domain_schema and element_table.name not in ["Dataset", "File"]:
                 # Look at domain tables and nested datasets.
                 continue
-            sql_target = self.model.normalize_table_name(target_table.name)
-            sql_member = self.model.normalize_table_name(member_table.name)
-
             # Get the names of the columns that we are going to need for linking
-            member_link = tuple(c.name for c in next(iter(member_fkey.column_map.items())))
-            with self.database as db:
-                col_names = [c[1] for c in db.execute(f'PRAGMA table_info("{sql_target}")').fetchall()]
-                select_cols = ",".join([f'"{sql_target}".{c}' for c in col_names])
+            with Session(self.engine) as session:
                 sql_cmd = (
-                    f'SELECT {select_cols} FROM "{sql_member}" '
-                    f'JOIN "{sql_target}" ON "{sql_member}".{member_link[0]} = "{sql_target}".{member_link[1]} '
-                    f'WHERE "{self.dataset_rid}" = "{sql_member}".Dataset;'
+                    select(element_class)
+                    .join(element_rel)
+                    .where(self.dataset_rid == assoc_class.__table__.c["Dataset"])
                 )
-                mapper = SQLMapper(self.model, sql_target)
-                target_entities = [mapper.transform_tuple(e) for e in db.execute(sql_cmd).fetchall()]
-                members[target_table.name].extend(target_entities)
-                if recurse and (target_table.name == self._dataset_table.name):
+                # Get back the list of ORM entities and convert them to dictionaries.
+                element_entities = session.scalars(sql_cmd).all()
+                element_rows = [{c.key: getattr(obj, c.key) for c in obj.__table__.columns} for obj in element_entities]
+                members[element_table.name].extend(element_rows)
+                if recurse and (element_table.name == self._dataset_table.name):
                     # Get the members for all the nested datasets and add to the member list.
-                    nested_datasets = [d["RID"] for d in target_entities]
+                    nested_datasets = [d["RID"] for d in element_rows]
                     for ds in nested_datasets:
                         nested_dataset = self.model.get_dataset(ds)
                         for k, v in nested_dataset.list_dataset_members(recurse=recurse).items():
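`list_dataset_members` now converts each ORM entity into a plain dict by iterating `obj.__table__.columns`. That conversion does not need a session; a short sketch with a hypothetical mapped class:

```python
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Subject(Base):
    __tablename__ = "Subject"
    RID: Mapped[str] = mapped_column(primary_key=True)
    Name: Mapped[str]


def to_row(obj) -> dict:
    # Same conversion as in list_dataset_members: one key per mapped column.
    return {c.key: getattr(obj, c.key) for c in obj.__table__.columns}


print(to_row(Subject(RID="1-abc", Name="demo")))  # {'RID': '1-abc', 'Name': 'demo'}
```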
@@ -259,12 +245,10 @@ class DatasetBag:
             Feature values.
         """
         feature = self.model.lookup_feature(table, feature_name)
-        feature_table = self.model.normalize_table_name(feature.feature_table.name)
-
-        with self.database as db:
-            col_names = [c[1] for c in db.execute(f'PRAGMA table_info("{feature_table}")').fetchall()]
-            sql_cmd = f'SELECT * FROM "{feature_table}"'
-            return cast(datapath._ResultSet, [dict(zip(col_names, r)) for r in db.execute(sql_cmd).fetchall()])
+        feature_class = self.model.get_orm_class_for_table(feature.feature_table)
+        with Session(self.engine) as session:
+            sql_cmd = select(feature_class)
+            return cast(datapath._ResultSet, [row for row in session.execute(sql_cmd).mappings()])
 
     def list_dataset_element_types(self) -> list[Table]:
         """
@@ -291,18 +275,18 @@ class DatasetBag:
         Returns:
             List of child dataset bags.
         """
-        ds_table = self.model.normalize_table_name("Dataset")
-        nds_table = self.model.normalize_table_name("Dataset_Dataset")
-        dv_table = self.model.normalize_table_name("Dataset_Version")
-        with self.database as db:
+        ds_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset")
+        nds_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset_Dataset")
+        dv_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset_Version")
+
+        with Session(self.engine) as session:
             sql_cmd = (
-                f'SELECT "{nds_table}".Nested_Dataset, "{dv_table}".Version '
-                f'FROM "{nds_table}" JOIN "{dv_table}" JOIN "{ds_table}" on '
-                f'"{ds_table}".Version == "{dv_table}".RID AND '
-                f'"{nds_table}".Nested_Dataset == "{ds_table}".RID '
-                f'where "{nds_table}".Dataset == "{self.dataset_rid}"'
+                select(nds_table.Nested_Dataset, dv_table.Version)
+                .join_from(ds_table, nds_table, onclause=ds_table.RID == nds_table.Nested_Dataset)
+                .join_from(ds_table, dv_table, onclause=ds_table.Version == dv_table.RID)
+                .where(nds_table.Dataset == self.dataset_rid)
             )
-            nested = [DatasetBag(self.model, r[0]) for r in db.execute(sql_cmd).fetchall()]
+            nested = [DatasetBag(self.model, r[0]) for r in session.execute(sql_cmd).all()]
 
         result = copy(nested)
         if recurse:
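The nested-dataset query now spells out both join directions with `join_from` and explicit `onclause` expressions instead of an f-string of table names. A sketch of the same shape using Core `table()` constructs as stand-ins for the ml-schema ORM classes:

```python
from sqlalchemy import column, select, table

# Hypothetical stand-ins for the Dataset, Dataset_Dataset and Dataset_Version tables.
ds = table("Dataset", column("RID"), column("Version"))
nds = table("Dataset_Dataset", column("Dataset"), column("Nested_Dataset"))
dv = table("Dataset_Version", column("RID"), column("Version"))

stmt = (
    select(nds.c.Nested_Dataset, dv.c.Version)
    .join_from(ds, nds, onclause=ds.c.RID == nds.c.Nested_Dataset)
    .join_from(ds, dv, onclause=ds.c.Version == dv.c.RID)
    .where(nds.c.Dataset == "1-abc")
)
print(stmt)  # compiles to an explicit three-way JOIN
```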
@@ -336,20 +320,19 @@ class DatasetBag:
             >>> term = ml.lookup_term("tissue_types", "epithelium")
         """
         # Get and validate vocabulary table reference
-        vocab_table = self.model.normalize_table_name(table)
         if not self.model.is_vocabulary(table):
             raise DerivaMLException(f"The table {table} is not a controlled vocabulary")
 
         # Search for term by name or synonym
-        for term in self.get_table_as_dict(vocab_table):
+        for term in self.get_table_as_dict(table):
             if term_name == term["Name"] or (term["Synonyms"] and term_name in term["Synonyms"]):
                 term["Synonyms"] = list(term["Synonyms"])
                 return VocabularyTerm.model_validate(term)
 
         # Term not found
-        raise DerivaMLInvalidTerm(vocab_table, term_name)
+        raise DerivaMLInvalidTerm(table, term_name)
 
-    def _denormalize(self, include_tables: list[str] | None) -> str:
+    def _denormalize(self, include_tables: list[str]) -> Select:
         """
         Generates an SQL statement for denormalizing the dataset based on the tables to include. Processes cycles in
         graph relationships, ensures proper join order, and generates selected columns for denormalization.
@@ -361,48 +344,57 @@ class DatasetBag:
         Returns:
             str: SQL query string that represents the process of denormalization.
         """
-
-        def column_name(col: Column) -> str:
-            return f'"{self.model.normalize_table_name(col.table.name)}"."{col.name}"'
-
         # Skip over tables that we don't want to include in the denormalized dataset.
         # Also, strip off the Dataset/Dataset_X part of the path so we don't include dataset columns in the denormalized
         # table.
 
-        join_tables, tables, denormalized_columns, dataset_rids, dataset_element_tables = (
+        def find_relationship(table, join_condition):
+            side1 = (join_condition[0].table.name, join_condition[0].name)
+            side2 = (join_condition[1].table.name, join_condition[1].name)
+
+            for relationship in inspect(table).relationships:
+                local_columns = list(relationship.local_columns)[0].table.name, list(relationship.local_columns)[0].name
+                remote_side = list(relationship.remote_side)[0].table.name, list(relationship.remote_side)[0].name
+                if local_columns == side1 and remote_side == side2 or local_columns == side2 and remote_side == side1:
+                    return relationship
+            return None
+
+        join_tables, denormalized_columns = (
             self.model._prepare_wide_table(self, self.dataset_rid, include_tables)
         )
 
-        select_args = [
-            # SQLlite will strip out the table name from the column in the select statement, so we need to add
-            # an explicit alias to the column name.
-            f'"{self.model.normalize_table_name(table_name)}"."{column_name}" AS "{table_name}.{column_name}"'
+        denormalized_columns = [
+            self.model.get_orm_class_by_name(table_name)
+            .__table__.columns[column_name]
+            .label(f"{table_name}.{column_name}")
             for table_name, column_name in denormalized_columns
         ]
-
-        # First table in the table list is the table specified in the method call.
-        normalized_join_tables = [self.model.normalize_table_name(t) for t in join_tables]
-        sql_statement = f'SELECT {",".join(select_args)} FROM "{normalized_join_tables[0]}"'
-        for t in normalized_join_tables[1:]:
-            on = tables[t]
-            sql_statement += f' LEFT JOIN "{t}" ON '
-            sql_statement += "OR ".join([f"{column_name(o[0])} = {column_name(o[1])}" for o in on])
-
-        # Select only rows from the datasets you wish to include.
-        dataset_rid_list = ",".join([f'"{self.dataset_rid}"'] + [f'"{b.dataset_rid}"' for b in dataset_rids])
-        sql_statement += f'WHERE "{self.model.normalize_table_name("Dataset")}"."RID" IN ({dataset_rid_list})'
-
-        # Only include rows that have actual values in them.
-        real_row = [f'"{self.model.normalize_table_name(t)}".RID IS NOT NULL ' for t in dataset_element_tables]
-        sql_statement += f" AND ({' OR '.join(real_row)})"
-        return sql_statement
-
-    def denormalize_as_dataframe(self, include_tables: list[str] | None = None) -> pd.DataFrame:
+        sql_statements = []
+        for key, (path, join_conditions) in join_tables.items():
+            sql_statement = select(*denormalized_columns).select_from(
+                self.model.get_orm_class_for_table(self._dataset_table)
+            )
+            for table_name in path[1:]:  # Skip over dataset table
+                table_class = self.model.get_orm_class_by_name(table_name)
+                on_clause = [
+                    getattr(table_class, r.key)
+                    for on_condition in join_conditions[table_name]
+                    if (r := find_relationship(table_class, on_condition))
+                ]
+                sql_statement = sql_statement.join(table_class, onclause=and_(*on_clause))
+            dataset_rid_list = [self.dataset_rid] + [b.dataset_rid for b in self.list_dataset_children(recurse=True)]
+            dataset_class = self.model.get_orm_class_by_name(self._dataset_table.name)
+            sql_statement = sql_statement.where(dataset_class.RID.in_(dataset_rid_list))
+            sql_statements.append(sql_statement)
+        return union(*sql_statements)
+
+    def denormalize_as_dataframe(self, include_tables: list[str]) -> pd.DataFrame:
         """
         Denormalize the dataset and return the result as a dataframe.
 
-        This routine will examine the domain schema for the dataset, determine which tables to include and denormalize
-        the dataset values into a single wide table. The result is returned as a dataframe.
+        This routine will examine the domain schema for the dataset, determine which tables to include and denormalize
+        the dataset values into a single wide table. The result is returned as a dataframe with one row per row of the
+        denormalized wide table.
 
         The optional argument include_tables can be used to specify a subset of tables to include in the denormalized
         view. The tables in this argument can appear anywhere in the dataset schema. The method will determine which
@@ -412,28 +404,27 @@ class DatasetBag:
         The resulting wide table will include a column for every table needed to complete the denormalization process.
 
         Args:
-            include_tables: List of table names to include in the denormalized dataset. If None, than the entire schema
-                is used.
+            include_tables: List of table names to include in the denormalized dataset.
 
         Returns:
             Dataframe containing the denormalized dataset.
         """
-        return pd.read_sql(self._denormalize(include_tables=include_tables), self.database)
+        return pd.read_sql(self._denormalize(include_tables=include_tables), self.engine)
 
-    def denormalize_as_dict(self, include_tables: list[str] | None = None) -> Generator[dict[str, Any], None, None]:
+    def denormalize_as_dict(self, include_tables: list[str]) -> Generator[RowMapping, None, None]:
         """
-        Denormalize the dataset and return the result as a set of dictionarys.
+        Denormalize the dataset and return the result as a set of dictionaries.
 
         This routine will examine the domain schema for the dataset, determine which tables to include and denormalize
-        the dataset values into a single wide table. The result is returned as a generateor that returns a dictionary
-        for each row in the denormlized wide table.
+        the dataset values into a single wide table. The result is returned as a generator that returns a dictionary
+        for each row in the denormalized wide table.
 
         The optional argument include_tables can be used to specify a subset of tables to include in the denormalized
         view. The tables in this argument can appear anywhere in the dataset schema. The method will determine which
         additional tables are required to complete the denormalization process. If include_tables is not specified,
         all of the tables in the schema will be included.
 
-        The resulting wide table will include a column for every table needed to complete the denormalization process.
+        The resulting wide table will include only those columns for the tables listed in include_tables.
 
         Args:
             include_tables: List of table names to include in the denormalized dataset. If None, than the entire schema
@@ -442,11 +433,13 @@ class DatasetBag:
         Returns:
             A generator that returns a dictionary representation of each row in the denormalized dataset.
         """
-        with self.database as dbase:
-            cursor = dbase.execute(self._denormalize(include_tables=include_tables))
-            columns = [desc[0] for desc in cursor.description]
-            for row in cursor:
-                yield dict(zip(columns, row))
+        with Session(self.engine) as session:
+            cursor = session.execute(
+                self._denormalize(include_tables=include_tables)
+            )
+            yield from cursor.mappings()
 
 
 # Add annotations after definition to deal with forward reference issues in pydantic
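`denormalize_as_dict` streams RowMapping objects out of the executed statement with `yield from cursor.mappings()`. Each RowMapping behaves like a read-only dict keyed by the labeled column names; a tiny sketch with made-up labels:

```python
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite:///:memory:")
with Session(engine) as session:
    result = session.execute(text('SELECT 1 AS "Subject.ID", \'demo\' AS "Subject.Name"'))
    for row in result.mappings():  # RowMapping: dict-like, read-only
        print(row["Subject.ID"], dict(row))
```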
@@ -77,11 +77,11 @@ feature_value_regex = feature_table_dir_regex + f"{SEP}(?P=feature_name)[.](?P<e
 feature_asset_dir_regex = feature_table_dir_regex + f"{SEP}asset{SEP}(?P<asset_table>[-\\w]+)"
 feature_asset_regex = feature_asset_dir_regex + f"{SEP}(?P<file>[A-Za-z0-9_-]+)[.](?P<ext>[a-z0-9]*)$"
 
-asset_path_regex = exec_dir_regex + f"{SEP}asset{SEP}(?P<schema>[-\\w]+){SEP}(?P<asset_table>[-\\w]*)"
+asset_path_regex = exec_dir_regex + rf"{SEP}asset{SEP}(?P<schema>[-\w]+){SEP}(?P<asset_table>[-\w]*)"
 
 asset_file_regex = r"(?P<file>[-\w]+)[.](?P<ext>[a-z0-9]*)$"
 
-table_regex = exec_dir_regex + f"{SEP}table{SEP}(?P<schema>[-\\w]+){SEP}(?P<table>[-\\w]+){SEP}(?P=table)[.](csv|json)$"
+table_regex = exec_dir_regex + rf"{SEP}table{SEP}(?P<schema>[-\w]+){SEP}(?P<table>[-\w]+){SEP}(?P=table)[.](csv|json)$"
 
 
 def is_feature_dir(path: Path) -> Optional[re.Match]:
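The regex fragments switch from f-strings with doubled backslashes to raw f-strings (rf"..."), which leave \w and friends untouched while still interpolating {SEP}. The two spellings produce identical patterns; a quick check (the SEP value is assumed for illustration):

```python
import re

SEP = "/"  # assumed separator for illustration

old_style = f"{SEP}asset{SEP}(?P<schema>[-\\w]+)"   # backslash escaped by hand
new_style = rf"{SEP}asset{SEP}(?P<schema>[-\w]+)"   # raw f-string keeps \w literal

assert old_style == new_style
print(re.match(new_style, "/asset/demo-schema").groupdict())  # {'schema': 'demo-schema'}
```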
@@ -190,7 +190,9 @@ def asset_table_upload_spec(model: DerivaModel, asset_table: str | Table):
     metadata_columns = model.asset_metadata(asset_table)
     asset_table = model.name_to_table(asset_table)
     schema = model.name_to_table(asset_table).schema.name
-    metadata_path = "/".join([rf"(?P<{c}>[-\w]+)" for c in metadata_columns])
+
+    # Be careful here as a metadata value might be a string which can contain special characters.
+    metadata_path = "/".join([rf"(?P<{c}>[-:._ \w]+)" for c in metadata_columns])
     asset_path = f"{exec_dir_regex}/asset/{schema}/{asset_table.name}/{metadata_path}/{asset_file_regex}"
     asset_table = model.name_to_table(asset_table)
     schema = model.name_to_table(asset_table).schema.name
@@ -417,7 +419,7 @@ def asset_file_path(
         raise DerivaMLException(f"Metadata {metadata} does not match asset metadata {asset_metadata}")
 
     for m in asset_metadata:
-        path = path / metadata.get(m, "None")
+        path = path / str(metadata.get(m, "None"))
        path.mkdir(parents=True, exist_ok=True)
     return path / file_name
 
deriva_ml/demo_catalog.py CHANGED
@@ -5,6 +5,7 @@ import itertools
 import logging
 import string
 from collections.abc import Iterator, Sequence
+from datetime import datetime
 from numbers import Integral
 from pathlib import Path
 from random import choice, randint, random
@@ -54,7 +55,13 @@ def populate_demo_catalog(ml_instance: DerivaML) -> None:
     )
     with execution.execute() as e:
         for s in ss:
-            image_file = e.asset_file_path("Image", f"test_{s['RID']}.txt", Subject=s["RID"])
+            image_file = e.asset_file_path(
+                "Image",
+                f"test_{s['RID']}.txt",
+                Subject=s["RID"],
+                Acquisition_Time=datetime.now(),
+                Acquisition_Date=datetime.now().date(),
+            )
             with image_file.open("w") as f:
                 f.write(f"Hello there {random()}\n")
     execution.upload_execution_outputs()
@@ -343,7 +350,14 @@ def create_domain_schema(catalog: ErmrestCatalog, sname: str) -> None:
     )
     with TemporaryDirectory() as tmpdir:
         ml_instance = DerivaML(hostname=catalog.deriva_server.server, catalog_id=catalog.catalog_id, working_dir=tmpdir)
-        ml_instance.create_asset("Image", referenced_tables=[subject_table])
+        ml_instance.create_asset(
+            "Image",
+            column_defs=[
+                Column.define("Acquisition_Time", builtin_types.timestamp),
+                Column.define("Acquisition_Date", builtin_types.date),
+            ],
+            referenced_tables=[subject_table],
+        )
         catalog_annotation(ml_instance.model)
 
 
@@ -1,7 +1,7 @@
 from typing import TYPE_CHECKING
 
 # Safe imports - no circular dependencies
-from deriva_ml.execution.execution_configuration import ExecutionConfiguration
+from deriva_ml.execution.execution_configuration import ExecutionConfiguration, AssetRIDConfig
 from deriva_ml.execution.workflow import Workflow
 
 if TYPE_CHECKING:
@@ -22,4 +22,5 @@ __all__ = [
     "Execution",  # Lazy-loaded
     "ExecutionConfiguration",
     "Workflow",
+    "AssetRIDConfig"
 ]
@@ -583,7 +583,6 @@ class Execution:
                     asset_rid=status.result["RID"],
                 )
             )
-
         self._update_asset_execution_table(asset_map)
         self.update_status(Status.running, "Updating features...")
 
@@ -805,7 +804,7 @@
         self,
         uploaded_assets: dict[str, list[AssetFilePath]],
         asset_role: str = "Output",
-    ):
+    ) -> None:
         """Add entry to the association table connecting an asset to an execution RID
 
         Args:
@@ -814,6 +813,9 @@
             asset_role: A term or list of terms from the Asset_Role vocabulary.
         """
         # Make sure the asset role is in the controlled vocabulary table.
+        if self._dry_run:
+            # Don't do any updates if we are doing a dry run.
+            return
         self._ml_object.lookup_term(MLVocab.asset_role, asset_role)
 
         pb = self._ml_object.pathBuilder
@@ -22,15 +22,17 @@ Typical usage example:
 
 from __future__ import annotations
 
+from dataclasses import dataclass
 import json
 import sys
 from pathlib import Path
 from typing import Any
 
+from hydra_zen import builds
 from pydantic import BaseModel, ConfigDict, Field, field_validator
 
 from deriva_ml.core.definitions import RID
-from deriva_ml.dataset.aux_classes import DatasetList, DatasetSpec
+from deriva_ml.dataset.aux_classes import DatasetSpec
 from deriva_ml.execution.workflow import Workflow
 
 
@@ -64,7 +66,7 @@ class ExecutionConfiguration(BaseModel):
         ... )
     """
 
-    datasets: list[DatasetSpec] | DatasetList = []
+    datasets: list[DatasetSpec] = []
     assets: list[RID] = []
     workflow: RID | Workflow
     description: str = ""
@@ -72,13 +74,13 @@
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
-    @field_validator("datasets", mode="before")
-    @classmethod
-    def validate_datasets(cls, value: Any) -> Any:
-        if isinstance(value, DatasetList):
-            config_list: DatasetList = value
-            value = config_list.datasets
-        return value
+    # @field_validator("datasets", mode="before")
+    # @classmethod
+    # def validate_datasets(cls, value: Any) -> Any:
+    #     if isinstance(value, DatasetList):
+    #         config_list: DatasetList = value
+    #         value = config_list.datasets
+    #     return value
 
     @field_validator("workflow", mode="before")
     @classmethod
@@ -137,3 +139,20 @@
     # hs = HatracStore("https", self.host_name, self.credential)
     # hs.get_obj(path=configuration["URL"], destfilename=dest_file.name)
     # return ExecutionConfiguration.load_configuration(Path(dest_file.name))
+
+
+@dataclass
+class AssetRID(str):
+    rid: str
+    description: str = ""
+
+    def __new__(cls, rid: str, description: str = ""):
+        obj = super().__new__(cls, rid)
+        obj.description = description
+        return obj
+
+
+AssetRIDConfig = builds(AssetRID, populate_full_signature=True)
+
+
+
+
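AssetRID subclasses str so it can be used anywhere a plain RID string is expected while also carrying a description, and AssetRIDConfig wraps it in a hydra-zen builds() config. A short sketch of how such a class behaves; the RID value and description are made up.

```python
from dataclasses import dataclass


@dataclass
class AssetRID(str):
    rid: str
    description: str = ""

    def __new__(cls, rid: str, description: str = ""):
        # str is immutable, so the value must be fixed in __new__;
        # the dataclass-generated __init__ then records rid and description.
        obj = super().__new__(cls, rid)
        obj.description = description
        return obj


weights = AssetRID("1-abc", description="trained model weights")
print(weights == "1-abc", weights.upper(), weights.description)
# True 1-ABC trained model weights

# As in the diff, a hydra-zen structured config can be generated with:
#   AssetRIDConfig = builds(AssetRID, populate_full_signature=True)
```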