deriva-ml 1.14.47__py3-none-any.whl → 1.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -41,7 +41,6 @@ from deriva_ml.core.base import DerivaML
  from deriva_ml.core.definitions import (
  DRY_RUN_RID,
  RID,
- ExecAssetType,
  ExecMetadataType,
  FileSpec,
  FileUploadState,
@@ -198,7 +197,6 @@ class Execution:
  workflow_rid (RID): RID of the associated workflow.
  status (Status): Current execution status.
  asset_paths (list[AssetFilePath]): Paths to execution assets.
- parameters (dict): Execution parameters.
  start_time (datetime | None): When execution started.
  stop_time (datetime | None): When execution completed.
 
@@ -206,7 +204,6 @@ class Execution:
  >>> config = ExecutionConfiguration(
  ... workflow="analysis",
  ... description="Process samples",
- ... parameters={"threshold": 0.5}
  ... )
  >>> with ml.create_execution(config) as execution:
  ... execution.download_dataset_bag(dataset_spec)
@@ -250,7 +247,6 @@ class Execution:
 
  self.dataset_rids: List[RID] = []
  self.datasets: list[DatasetBag] = []
- self.parameters = self.configuration.parameters
 
  self._working_dir = self._ml_object.working_dir
  self._cache_dir = self._ml_object.cache_dir
@@ -292,9 +288,18 @@ class Execution:
  ]
  )[0]["RID"]
 
- if isinstance(self.configuration.workflow, Workflow) and self.configuration.workflow.is_notebook:
- # Put execution_rid into the cell output so we can find it later.
- display(Markdown(f"Execution RID: {self._ml_object.cite(self.execution_rid)}"))
+ if rid_path := os.environ.get("DERIVA_ML_SAVE_EXECUTION_RID", None):
+ # Put execution_rid into the provided file path so we can find it later.
+ with Path(rid_path).open("w") as f:
+ json.dump(
+ {
+ "hostname": self._ml_object.host_name,
+ "catalog_id": self._ml_object.catalog_id,
+ "workflow_rid": self.workflow_rid,
+ "execution_rid": self.execution_rid,
+ },
+ f,
+ )
 
  # Create a directory for execution rid so we can recover the state in case of a crash.
  execution_root(prefix=self._ml_object.working_dir, exec_rid=self.execution_rid)
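
The notebook Markdown display is replaced by an environment-variable hook: when DERIVA_ML_SAVE_EXECUTION_RID names a file, the execution identifiers are written there as JSON. A minimal sketch of how a driver script might consume that file (the temp-file location is illustrative; the JSON keys match the added code above):

    import json
    import os
    import tempfile

    # Point deriva-ml at a scratch file before the execution is created.
    rid_file = os.path.join(tempfile.gettempdir(), "execution_rid.json")
    os.environ["DERIVA_ML_SAVE_EXECUTION_RID"] = rid_file

    # ... run the code that calls DerivaML.create_execution(...) ...

    # Afterwards the file holds hostname, catalog_id, workflow_rid, and execution_rid.
    with open(rid_file) as f:
        info = json.load(f)
    print(info["execution_rid"])
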
@@ -302,13 +307,28 @@ class Execution:
 
  def _save_runtime_environment(self):
  runtime_env_path = self.asset_file_path(
- "Execution_Metadata",
- f"environment_snapshot_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
- ExecMetadataType.runtime_env.value,
+ asset_name="Execution_Metadata",
+ file_name=f"environment_snapshot_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
+ asset_types=ExecMetadataType.runtime_env.value,
  )
  with Path(runtime_env_path).open("w") as fp:
  json.dump(get_execution_environment(), fp)
 
+ def _upload_hydra_config_assets(self):
+ """Upload hydra assets to the catalog."""
+ hydra_runtime_output_dir = self._ml_object.hydra_runtime_output_dir
+ if hydra_runtime_output_dir:
+ timestamp = hydra_runtime_output_dir.parts[-1]
+ for hydra_asset in hydra_runtime_output_dir.rglob("*"):
+ if hydra_asset.is_dir():
+ continue
+ asset = self.asset_file_path(
+ asset_name=MLAsset.execution_metadata,
+ file_name=hydra_runtime_output_dir / hydra_asset,
+ rename_file=f"hydra-{timestamp}-{hydra_asset.name}",
+ asset_types=ExecMetadataType.execution_config.value,
+ )
+
  def _initialize_execution(self, reload: RID | None = None) -> None:
  """Initialize the execution by a configuration in the Execution_Metadata table.
  Set up a working directory and download all the assets and data.
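
The new _upload_hydra_config_assets helper stages every file under the Hydra runtime output directory as an Execution_Metadata asset, prefixing each upload name with the run directory's final path component. A small sketch of that naming scheme, assuming a conventional Hydra run directory (the path below is a placeholder):

    from pathlib import Path

    # Hypothetical Hydra run directory; its last component is used as the timestamp.
    run_dir = Path("outputs/2024-05-01/12-30-00")
    timestamp = run_dir.parts[-1]  # "12-30-00"

    # Mirrors the rename_file expression in _upload_hydra_config_assets above.
    for f in run_dir.rglob("*"):
        if f.is_dir():
            continue
        print(f"hydra-{timestamp}-{f.name}")  # e.g. "hydra-12-30-00-config.yaml"
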
@@ -354,9 +374,9 @@ class Execution:
  # Save configuration details for later upload
  if not reload:
  cfile = self.asset_file_path(
- MLAsset.execution_metadata,
- "configuration.json",
- ExecMetadataType.execution_config.value,
+ asset_name=MLAsset.execution_metadata,
+ file_name="configuration.json",
+ asset_types=ExecMetadataType.execution_config.value,
  )
  with Path(cfile).open("w", encoding="utf-8") as config_file:
  json.dump(self.configuration.model_dump(), config_file)
@@ -364,24 +384,18 @@ class Execution:
  lock_file = Path(self.configuration.workflow.git_root) / "uv.lock"
  if lock_file.exists():
  _ = self.asset_file_path(
- MLAsset.execution_metadata,
- lock_file,
- ExecMetadataType.execution_config.value,
+ asset_name=MLAsset.execution_metadata,
+ file_name=lock_file,
+ asset_types=ExecMetadataType.execution_config.value,
  )
 
- for parameter_file in self.configuration.parameters:
- self.asset_file_path(
- MLAsset.execution_asset,
- parameter_file,
- ExecAssetType.input_file.value,
- )
+ self._upload_hydra_config_assets()
 
  # save runtime env
  self._save_runtime_environment()
 
  # Now upload the files so we have the info in case the execution fails.
  self.uploaded_assets = self._upload_execution_dirs()
-
  self.start_time = datetime.now()
  self.update_status(Status.pending, "Initialize status finished.")
 
@@ -569,7 +583,6 @@ class Execution:
  asset_rid=status.result["RID"],
  )
  )
-
  self._update_asset_execution_table(asset_map)
  self.update_status(Status.running, "Updating features...")
 
@@ -791,7 +804,7 @@ class Execution:
  self,
  uploaded_assets: dict[str, list[AssetFilePath]],
  asset_role: str = "Output",
- ):
+ ) -> None:
  """Add entry to the association table connecting an asset to an execution RID
 
  Args:
@@ -800,6 +813,9 @@ class Execution:
  asset_role: A term or list of terms from the Asset_Role vocabulary.
  """
  # Make sure the asset role is in the controlled vocabulary table.
+ if self._dry_run:
+ # Don't do any updates if we are doing a dry run.
+ return
  self._ml_object.lookup_term(MLVocab.asset_role, asset_role)
 
  pb = self._ml_object.pathBuilder
@@ -856,6 +872,7 @@ class Execution:
  file_name: str | Path,
  asset_types: list[str] | str | None = None,
  copy_file=False,
+ rename_file: str | None = None,
  **kwargs,
  ) -> AssetFilePath:
  """Return a pathlib Path to the directory in which to place files for the specified execution_asset type.
@@ -875,6 +892,8 @@ class Execution:
  asset_name: Type of asset to be uploaded. Must be a term in Asset_Type controlled vocabulary.
  file_name: Name of file to be uploaded.
  asset_types: Type of asset to be uploaded. Defaults to the name of the asset.
+ copy_file: Whether to copy the file rather than creating a symbolic link.
+ rename_file: If provided, the file will be renamed to this name if the file already exists.
  **kwargs: Any additional metadata values that may be part of the asset table.
 
  Returns:
@@ -893,12 +912,15 @@ class Execution:
  for t in asset_types:
  self._ml_object.lookup_term(MLVocab.asset_type, t)
 
+ # Determine if we will need to rename an existing file as the asset.
  file_name = Path(file_name)
+ target_name = Path(rename_file) if file_name.exists() and rename_file else file_name
+
  asset_path = asset_file_path(
  prefix=self._working_dir,
  exec_rid=self.execution_rid,
  asset_table=self._model.name_to_table(asset_name),
- file_name=file_name.name,
+ file_name=target_name.name,
  metadata=kwargs,
  )
 
@@ -914,12 +936,12 @@ class Execution:
 
  # Persist the asset types into a file
  with Path(asset_type_path(self._working_dir, self.execution_rid, asset_table)).open("a") as asset_type_file:
- asset_type_file.write(json.dumps({file_name.name: asset_types}) + "\n")
+ asset_type_file.write(json.dumps({target_name.name: asset_types}) + "\n")
 
  return AssetFilePath(
  asset_path=asset_path,
  asset_name=asset_name,
- file_name=file_name.name,
+ file_name=target_name.name,
  asset_metadata=kwargs,
  asset_types=asset_types,
  )
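
A hypothetical call showing the new rename_file keyword: when the local file exists and a rename is given, the asset is staged, logged, and returned under the renamed value rather than the original file name (the execution object and file paths below are placeholders; the asset name and type come from the surrounding code):

    from deriva_ml.core.definitions import ExecMetadataType

    path = execution.asset_file_path(
        asset_name="Execution_Metadata",
        file_name="results/metrics.json",   # existing local file (illustrative)
        rename_file="run-42-metrics.json",   # name used for the staged asset
        asset_types=ExecMetadataType.execution_config.value,
    )
    # path.file_name == "run-42-metrics.json"
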
@@ -22,11 +22,13 @@ Typical usage example:
 
  from __future__ import annotations
 
+ from dataclasses import dataclass
  import json
  import sys
  from pathlib import Path
  from typing import Any
 
+ from hydra_zen import builds
  from pydantic import BaseModel, ConfigDict, Field, field_validator
 
  from deriva_ml.core.definitions import RID
@@ -67,42 +69,18 @@ class ExecutionConfiguration(BaseModel):
  datasets: list[DatasetSpec] = []
  assets: list[RID] = []
  workflow: RID | Workflow
- parameters: dict[str, Any] | Path = {}
  description: str = ""
  argv: list[str] = Field(default_factory=lambda: sys.argv)
 
  model_config = ConfigDict(arbitrary_types_allowed=True)
 
- @field_validator("parameters", mode="before")
- @classmethod
- def validate_parameters(cls, value: Any) -> Any:
- """Validates and loads execution parameters.
-
- If value is a file path, loads and parses it as JSON. Otherwise, returns
- the value as is.
-
- Args:
- value: Parameter value to validate, either:
- - Dictionary of parameters
- - Path to JSON file
- - String path to JSON file
-
- Returns:
- dict[str, Any]: Validated parameter dictionary.
-
- Raises:
- ValueError: If JSON file is invalid or cannot be read.
- FileNotFoundError: If parameter file doesn't exist.
-
- Example:
- >>> config = ExecutionConfiguration(parameters="params.json")
- >>> print(config.parameters) # Contents of params.json as dict
- """
- if isinstance(value, str) or isinstance(value, Path):
- with Path(value).open("r") as f:
- return json.load(f)
- else:
- return value
+ # @field_validator("datasets", mode="before")
+ # @classmethod
+ # def validate_datasets(cls, value: Any) -> Any:
+ # if isinstance(value, DatasetList):
+ # config_list: DatasetList = value
+ # value = config_list.datasets
+ # return value
 
  @field_validator("workflow", mode="before")
  @classmethod
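
With the parameters field and its validator removed, an ExecutionConfiguration now carries only datasets, assets, workflow, description, and argv; run-specific parameters are handled outside the configuration, e.g. via the Hydra config files uploaded by _upload_hydra_config_assets. A sketch of a configuration under the new model (the workflow and description values are placeholders):

    config = ExecutionConfiguration(
        workflow="analysis",           # RID or Workflow object
        description="Process samples",
        datasets=[],                   # list[DatasetSpec]
        assets=[],                     # list[RID]
    )
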
@@ -161,3 +139,20 @@ class ExecutionConfiguration(BaseModel):
  # hs = HatracStore("https", self.host_name, self.credential)
  # hs.get_obj(path=configuration["URL"], destfilename=dest_file.name)
  # return ExecutionConfiguration.load_configuration(Path(dest_file.name))
+
+
+ @dataclass
+ class AssetRID(str):
+ rid: str
+ description: str = ""
+
+ def __new__(cls, rid: str, description: str = ""):
+ obj = super().__new__(cls, rid)
+ obj.description = description
+ return obj
+
+ AssetRIDConfig = builds(AssetRID, populate_full_signature=True)
+
+
+
+
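
AssetRIDConfig is a hydra-zen structured config generated from AssetRID with builds(). A rough sketch of how such a config is typically declared and instantiated (the RID and description values are placeholders, and the class is repeated here without the @dataclass decorator to keep the sketch minimal):

    from hydra_zen import builds, instantiate

    class AssetRID(str):
        # A string-valued RID that also carries a human-readable description.
        def __new__(cls, rid: str, description: str = ""):
            obj = super().__new__(cls, rid)
            obj.description = description
            return obj

    AssetRIDConfig = builds(AssetRID, populate_full_signature=True)

    # The config node can live in a Hydra config tree and be instantiated at run time.
    cfg = AssetRIDConfig(rid="1-ABCD", description="training image asset")
    asset_rid = instantiate(cfg)
    assert asset_rid == "1-ABCD" and asset_rid.description == "training image asset"
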
@@ -9,6 +9,7 @@ from typing import Any
  import requests
  from pydantic import BaseModel, PrivateAttr, model_validator
  from requests import RequestException
+ from setuptools_scm import get_version
 
  from deriva_ml.core.definitions import RID
  from deriva_ml.core.exceptions import DerivaMLException
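
setuptools_scm.get_version, imported here, is used in the Workflow validator below to derive a version string from the enclosing git checkout. A standalone sketch of the call (the root path is a placeholder; the resulting string depends on the repository's tags):

    from setuptools_scm import get_version

    version = get_version(
        root=".",                        # repository root (illustrative)
        search_parent_directories=True,  # walk upward until a repository is found
        fallback_version="0.0",          # used when no SCM metadata is available
    )
    print(version)  # e.g. "1.17.1.dev3+g1a2b3c4" on an untagged commit
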
@@ -129,6 +130,13 @@ class Workflow(BaseModel):
  self.url, self.checksum = Workflow.get_url_and_checksum(path)
  self.git_root = Workflow._get_git_root(path)
 
+ self.version = get_version(
+ root=str(self.git_root or Path.cwd()),
+ search_parent_directories=True,
+ # Optional but recommended: provide a safe fallback when tags are absent
+ fallback_version="0.0",
+ )
+
  self._logger = logging.getLogger("deriva_ml")
  return self
 
@@ -8,7 +8,8 @@ ML-specific functionality. It handles schema management, feature definitions, an
  from __future__ import annotations
 
  # Standard library imports
- from collections import Counter
+ from collections import Counter, defaultdict
+ from graphlib import CycleError, TopologicalSorter
  from typing import Any, Callable, Final, Iterable, NewType, TypeAlias
 
  from deriva.core.ermrest_catalog import ErmrestCatalog
@@ -21,6 +22,7 @@ from pydantic import ConfigDict, validate_call
 
  from deriva_ml.core.definitions import (
  ML_SCHEMA,
+ RID,
  DerivaAssetColumns,
  TableDefinition,
  )
@@ -28,6 +30,7 @@ from deriva_ml.core.exceptions import DerivaMLException, DerivaMLTableTypeError
 
  # Local imports
  from deriva_ml.feature import Feature
+ from deriva_ml.protocols.dataset import DatasetLike
 
  try:
  from icecream import ic
@@ -287,6 +290,118 @@ class DerivaModel:
  else:
  self.model.apply()
 
+ def list_dataset_element_types(self) -> list[Table]:
+ """
+ Lists the element types that a dataset can contain.
+
+ This method analyzes the model and identifies the tables whose entries can
+ be members of a dataset. It is useful for understanding the structure and
+ content of a dataset and allows for better manipulation and usage of its
+ data.
+
+ Returns:
+ list[Table]: A list of tables, each of which can appear as an element
+ of a dataset.
+
+ """
+
+ dataset_table = self.name_to_table("Dataset")
+
+ def domain_table(table: Table) -> bool:
+ return table.schema.name == self.domain_schema or table.name == dataset_table.name
+
+ return [t for a in dataset_table.find_associations() if domain_table(t := a.other_fkeys.pop().pk_table)]
+
+ def _prepare_wide_table(self,
+ dataset,
+ dataset_rid: RID,
+ include_tables: list[str]) -> tuple[dict[str, Any], list[tuple]]:
+ """
+ Generates the join structure and column list for a wide (denormalized) table from the model.
+
+ Args:
+ include_tables (list[str] | None): List of table names to include in the denormalized dataset. If None,
+ all tables from the dataset will be included.
+
+ Returns:
+ tuple[dict[str, Any], list[tuple]]: A mapping from each dataset element table to its join order and
+ join conditions, plus the list of (table, column) pairs that appear in the denormalized table.
+ """
+
+ # Skip over tables that we don't want to include in the denormalized dataset.
+ # Also, strip off the Dataset/Dataset_X part of the path so we don't include dataset columns in the denormalized
+ # table.
+ include_tables = set(include_tables)
+ for t in include_tables:
+ # Check to make sure the table is in the catalog.
+ _ = self.name_to_table(t)
+
+ table_paths = [
+ path
+ for path in self._schema_to_paths()
+ if path[-1].name in include_tables and include_tables.intersection({p.name for p in path})
+ ]
+ paths_by_element = defaultdict(list)
+ for p in table_paths:
+ paths_by_element[p[2].name].append(p)
+
+ # Get the names of all of the tables that can be dataset elements.
+ dataset_element_tables = {
+ e.name for e in self.list_dataset_element_types() if e.schema.name == self.domain_schema
+ }
+
+ skip_columns = {"RCT", "RMT", "RCB", "RMB"}
+ element_tables = {}
+ for element_table, paths in paths_by_element.items():
+ graph = {}
+ for path in paths:
+ for left, right in zip(path[0:], path[1:]):
+ graph.setdefault(left.name, set()).add(right.name)
+
+ # Now let's remove any cycles that we may have in the graph.
+ # We will use a topological sort to find the order in which we need to join the tables.
+ # If we find a cycle, we will remove the table from the graph and splice in an additional ON clause.
+ # We will then repeat the process until there are no cycles.
+ graph_has_cycles = True
+ element_join_tables = []
+ element_join_conditions = {}
+ while graph_has_cycles:
+ try:
+ ts = TopologicalSorter(graph)
+ element_join_tables = list(reversed(list(ts.static_order())))
+ graph_has_cycles = False
+ except CycleError as e:
+ cycle_nodes = e.args[1]
+ if len(cycle_nodes) > 3:
+ raise DerivaMLException(f"Unexpected cycle found when normalizing dataset {cycle_nodes}")
+ # Remove cycle from graph and splice in additional ON constraint.
+ graph[cycle_nodes[1]].remove(cycle_nodes[0])
+
+ # The Dataset_Version table is a special case as it points to dataset and dataset to version.
+ if "Dataset_Version" in element_join_tables:
+ element_join_tables.remove("Dataset_Version")
+
+ for path in paths:
+ for left, right in zip(path[0:], path[1:]):
+ if right.name == "Dataset_Version":
+ # The Dataset_Version table is a special case as it points to dataset and dataset to version.
+ continue
+ if element_join_tables.index(right.name) < element_join_tables.index(left.name):
+ continue
+ table_relationship = self._table_relationship(left, right)
+ element_join_conditions.setdefault(right.name, set()).add(
+ (table_relationship[0], table_relationship[1])
+ )
+ element_tables[element_table] = (element_join_tables, element_join_conditions)
+ # Get the list of columns that will appear in the final denormalized dataset.
+ denormalized_columns = [
+ (table_name, c.name)
+ for table_name in include_tables
+ if not self.is_association(table_name) # Don't include association columns in the denormalized view.
+ for c in self.name_to_table(table_name).columns
+ if (not include_tables or table_name in include_tables) and (c.name not in skip_columns)
+ ]
+ return element_tables, denormalized_columns
+
  def _table_relationship(
  self,
  table1: TableInput,
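
_prepare_wide_table orders its joins with graphlib's TopologicalSorter and, when the join graph contains a simple two-table cycle, drops one edge and retries, as in the while loop above. A toy illustration of that retry pattern (the table names are made up):

    from graphlib import CycleError, TopologicalSorter

    # Each key maps to the set of tables it depends on; Subject and Image form a 2-cycle.
    graph = {
        "Dataset": set(),
        "Subject": {"Dataset", "Image"},
        "Image": {"Subject"},
    }

    order = None
    while order is None:
        try:
            order = list(TopologicalSorter(graph).static_order())
        except CycleError as e:
            cycle = e.args[1]  # e.g. ['Subject', 'Image', 'Subject']
            if len(cycle) > 3:
                raise          # only simple two-table cycles are broken automatically
            graph[cycle[1]].remove(cycle[0])  # drop one edge of the cycle and retry

    print(order)  # e.g. ['Dataset', 'Image', 'Subject']
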
@@ -302,7 +417,9 @@ class DerivaModel:
  [(fk.referenced_columns[0], fk.foreign_key_columns[0]) for fk in table1.referenced_by if fk.table == table2]
  )
  if len(relationships) != 1:
- raise DerivaMLException(f"Ambiguous linkage between {table1.name} and {table2.name}")
+ raise DerivaMLException(
+ f"Ambiguous linkage between {table1.name} and {table2.name}: {[(r[0].name, r[1].name) for r in relationships]}"
+ )
  return relationships[0]
 
  def _schema_to_paths(