datachain 0.20.3__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (47)
  1. datachain/__init__.py +0 -2
  2. datachain/cache.py +2 -2
  3. datachain/catalog/catalog.py +65 -180
  4. datachain/cli/__init__.py +7 -0
  5. datachain/cli/commands/datasets.py +28 -43
  6. datachain/cli/commands/ls.py +2 -2
  7. datachain/cli/parser/__init__.py +35 -1
  8. datachain/client/fsspec.py +3 -5
  9. datachain/client/hf.py +0 -10
  10. datachain/client/local.py +4 -4
  11. datachain/data_storage/metastore.py +37 -403
  12. datachain/data_storage/sqlite.py +7 -139
  13. datachain/data_storage/warehouse.py +7 -26
  14. datachain/dataset.py +12 -126
  15. datachain/delta.py +7 -11
  16. datachain/error.py +0 -36
  17. datachain/func/func.py +1 -1
  18. datachain/lib/arrow.py +3 -3
  19. datachain/lib/dataset_info.py +0 -4
  20. datachain/lib/dc/datachain.py +92 -259
  21. datachain/lib/dc/datasets.py +49 -87
  22. datachain/lib/dc/listings.py +3 -3
  23. datachain/lib/dc/records.py +0 -1
  24. datachain/lib/dc/storage.py +40 -38
  25. datachain/lib/file.py +23 -77
  26. datachain/lib/listing.py +1 -3
  27. datachain/lib/meta_formats.py +1 -1
  28. datachain/lib/pytorch.py +1 -1
  29. datachain/lib/settings.py +0 -10
  30. datachain/lib/tar.py +2 -1
  31. datachain/lib/udf_signature.py +1 -1
  32. datachain/lib/webdataset.py +20 -30
  33. datachain/listing.py +1 -3
  34. datachain/query/dataset.py +46 -71
  35. datachain/query/session.py +1 -1
  36. datachain/remote/studio.py +26 -61
  37. datachain/studio.py +7 -23
  38. {datachain-0.20.3.dist-info → datachain-0.21.0.dist-info}/METADATA +2 -2
  39. {datachain-0.20.3.dist-info → datachain-0.21.0.dist-info}/RECORD +43 -47
  40. datachain/lib/namespaces.py +0 -71
  41. datachain/lib/projects.py +0 -86
  42. datachain/namespace.py +0 -65
  43. datachain/project.py +0 -78
  44. {datachain-0.20.3.dist-info → datachain-0.21.0.dist-info}/WHEEL +0 -0
  45. {datachain-0.20.3.dist-info → datachain-0.21.0.dist-info}/entry_points.txt +0 -0
  46. {datachain-0.20.3.dist-info → datachain-0.21.0.dist-info}/licenses/LICENSE +0 -0
  47. {datachain-0.20.3.dist-info → datachain-0.21.0.dist-info}/top_level.txt +0 -0
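
The diff below covers datachain/data_storage/metastore.py (+37 -403), the largest single change in this release: the namespace/project layer (its tables, query helpers, and CRUD methods) is removed, and the dataset APIs drop their project_id parameters.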
@@ -37,13 +37,9 @@ from datachain.dataset import (
 from datachain.error import (
     DatasetNotFoundError,
     DatasetVersionNotFoundError,
-    NamespaceNotFoundError,
-    ProjectNotFoundError,
     TableMissingError,
 )
 from datachain.job import Job
-from datachain.namespace import Namespace
-from datachain.project import Project
 from datachain.utils import JSONSerialize

 if TYPE_CHECKING:
@@ -65,8 +61,6 @@ class AbstractMetastore(ABC, Serializable):
     uri: StorageURI

     schema: "schema.Schema"
-    namespace_class: type[Namespace] = Namespace
-    project_class: type[Project] = Project
     dataset_class: type[DatasetRecord] = DatasetRecord
     dataset_list_class: type[DatasetListRecord] = DatasetListRecord
     dataset_list_version_class: type[DatasetListVersion] = DatasetListVersion
@@ -112,113 +106,14 @@ class AbstractMetastore(ABC, Serializable):
     def cleanup_for_tests(self) -> None:
         """Cleanup for tests."""

-    #
-    # Namespaces
-    #
-
-    @property
-    @abstractmethod
-    def default_namespace_name(self):
-        """Gets default namespace name"""
-
-    @property
-    def system_namespace_name(self):
-        return Namespace.system()
-
-    @abstractmethod
-    def create_namespace(
-        self,
-        name: str,
-        description: Optional[str] = None,
-        uuid: Optional[str] = None,
-        ignore_if_exists: bool = True,
-        **kwargs,
-    ) -> Namespace:
-        """Creates new namespace"""
-
-    @abstractmethod
-    def get_namespace(self, name: str, conn=None) -> Namespace:
-        """Gets a single namespace by name"""
-
-    @abstractmethod
-    def list_namespaces(self, conn=None) -> list[Namespace]:
-        """Gets a list of all namespaces"""
-
-    @property
-    @abstractmethod
-    def is_studio(self) -> bool:
-        """Returns True if this code is ran in Studio"""
-
-    def is_local_dataset(self, dataset_namespace: str) -> bool:
-        """
-        Returns True if this is local dataset i.e. not pulled from Studio but
-        created locally. This is False if we ran code in CLI mode but using dataset
-        names that are present in Studio.
-        """
-        return self.is_studio or dataset_namespace == Namespace.default()
-
-    @property
-    def namespace_allowed_to_create(self):
-        return self.is_studio
-
-    #
-    # Projects
-    #
-
-    @property
-    @abstractmethod
-    def default_project_name(self):
-        """Gets default project name"""
-
-    @property
-    def listing_project_name(self):
-        return Project.listing()
-
-    @cached_property
-    def default_project(self) -> Project:
-        return self.get_project(self.default_project_name, self.default_namespace_name)
-
-    @cached_property
-    def listing_project(self) -> Project:
-        return self.get_project(self.listing_project_name, self.system_namespace_name)
-
-    @abstractmethod
-    def create_project(
-        self,
-        namespace_name: str,
-        name: str,
-        description: Optional[str] = None,
-        uuid: Optional[str] = None,
-        ignore_if_exists: bool = True,
-        **kwargs,
-    ) -> Project:
-        """Creates new project in specific namespace"""
-
-    @abstractmethod
-    def get_project(
-        self, name: str, namespace_name: str, create: bool = False, conn=None
-    ) -> Project:
-        """
-        Gets a single project inside some namespace by name.
-        It also creates project if not found and create flag is set to True.
-        """
-
-    @abstractmethod
-    def list_projects(self, namespace_id: Optional[int], conn=None) -> list[Project]:
-        """Gets list of projects in some namespace or in general (in all namespaces)"""
-
-    @property
-    def project_allowed_to_create(self):
-        return self.is_studio
-
     #
     # Datasets
     #
+
     @abstractmethod
     def create_dataset(
         self,
         name: str,
-        project_id: Optional[int] = None,
         status: int = DatasetStatus.CREATED,
         sources: Optional[list[str]] = None,
         feature_schema: Optional[dict] = None,
@@ -278,22 +173,15 @@ class AbstractMetastore(ABC, Serializable):
         """

     @abstractmethod
-    def list_datasets(
-        self, project_id: Optional[int] = None
-    ) -> Iterator[DatasetListRecord]:
-        """Lists all datasets in some project or in all projects."""
+    def list_datasets(self) -> Iterator[DatasetListRecord]:
+        """Lists all datasets."""

     @abstractmethod
-    def list_datasets_by_prefix(
-        self, prefix: str, project_id: Optional[int] = None
-    ) -> Iterator["DatasetListRecord"]:
-        """
-        Lists all datasets which names start with prefix in some project or in all
-        projects.
-        """
+    def list_datasets_by_prefix(self, prefix: str) -> Iterator["DatasetListRecord"]:
+        """Lists all datasets which names start with prefix."""

     @abstractmethod
-    def get_dataset(self, name: str, project_id: Optional[int] = None) -> DatasetRecord:
+    def get_dataset(self, name: str) -> DatasetRecord:
         """Gets a single dataset by name."""

     @abstractmethod
@@ -314,10 +202,10 @@
     @abstractmethod
     def add_dataset_dependency(
         self,
-        source_dataset: "DatasetRecord",
+        source_dataset_name: str,
         source_dataset_version: str,
-        dep_dataset: "DatasetRecord",
-        dep_dataset_version: str,
+        dataset_name: str,
+        dataset_version: str,
     ) -> None:
         """Adds dataset dependency to dataset."""

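Under the new signature, callers pass dataset names and versions instead of DatasetRecord objects; the metastore resolves the records itself (see the concrete implementation further down). A minimal sketch of the new call shape (the metastore instance and the names/versions here are purely illustrative):

# Hypothetical call with the 0.21.0 signature: names in, records resolved inside.
metastore.add_dataset_dependency(
    source_dataset_name="docs",        # was: source_dataset=<DatasetRecord>
    source_dataset_version="1.0.0",
    dataset_name="docs-embeddings",    # was: dep_dataset=<DatasetRecord>
    dataset_version="1.0.0",
)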
@@ -416,8 +304,6 @@ class AbstractDBMetastore(AbstractMetastore):
     and has shared logic for all database systems currently in use.
     """

-    NAMESPACE_TABLE = "namespaces"
-    PROJECT_TABLE = "projects"
     DATASET_TABLE = "datasets"
     DATASET_VERSION_TABLE = "datasets_versions"
     DATASET_DEPENDENCY_TABLE = "datasets_dependencies"
@@ -436,62 +322,11 @@ class AbstractDBMetastore(AbstractMetastore):
     def cleanup_tables(self, temp_table_names: list[str]) -> None:
         """Cleanup temp tables."""

-    @classmethod
-    def _namespaces_columns(cls) -> list["SchemaItem"]:
-        """Namespace table columns."""
-        return [
-            Column("id", Integer, primary_key=True),
-            Column("uuid", Text, nullable=False, default=uuid4()),
-            Column("name", Text, nullable=False),
-            Column("description", Text),
-            Column("created_at", DateTime(timezone=True)),
-        ]
-
-    @cached_property
-    def _namespaces_fields(self) -> list[str]:
-        return [
-            c.name  # type: ignore [attr-defined]
-            for c in self._namespaces_columns()
-            if c.name  # type: ignore [attr-defined]
-        ]
-
-    @classmethod
-    def _projects_columns(cls) -> list["SchemaItem"]:
-        """Project table columns."""
-        return [
-            Column("id", Integer, primary_key=True),
-            Column("uuid", Text, nullable=False, default=uuid4()),
-            Column("name", Text, nullable=False),
-            Column("description", Text),
-            Column("created_at", DateTime(timezone=True)),
-            Column(
-                "namespace_id",
-                Integer,
-                ForeignKey(f"{cls.NAMESPACE_TABLE}.id", ondelete="CASCADE"),
-                nullable=False,
-            ),
-            UniqueConstraint("namespace_id", "name"),
-        ]
-
-    @cached_property
-    def _projects_fields(self) -> list[str]:
-        return [
-            c.name  # type: ignore [attr-defined]
-            for c in self._projects_columns()
-            if c.name  # type: ignore [attr-defined]
-        ]
-
     @classmethod
     def _datasets_columns(cls) -> list["SchemaItem"]:
         """Datasets table columns."""
         return [
             Column("id", Integer, primary_key=True),
-            Column(
-                "project_id",
-                Integer,
-                ForeignKey(f"{cls.PROJECT_TABLE}.id", ondelete="CASCADE"),
-                nullable=False,
-            ),
             Column("name", Text, nullable=False),
             Column("description", Text),
             Column("attrs", JSON, nullable=True),
@@ -610,16 +445,6 @@ class AbstractDBMetastore(AbstractMetastore):
     #
     # Query Tables
     #
-    @cached_property
-    def _namespaces(self) -> Table:
-        return Table(
-            self.NAMESPACE_TABLE, self.db.metadata, *self._namespaces_columns()
-        )
-
-    @cached_property
-    def _projects(self) -> Table:
-        return Table(self.PROJECT_TABLE, self.db.metadata, *self._projects_columns())
-
     @cached_property
     def _datasets(self) -> Table:
         return Table(self.DATASET_TABLE, self.db.metadata, *self._datasets_columns())
@@ -643,34 +468,6 @@ class AbstractDBMetastore(AbstractMetastore):
     #
     # Query Starters (These can be overridden by subclasses)
     #
-    @abstractmethod
-    def _namespaces_insert(self) -> "Insert": ...
-
-    def _namespaces_select(self, *columns) -> "Select":
-        if not columns:
-            return self._namespaces.select()
-        return select(*columns)
-
-    def _namespaces_update(self) -> "Update":
-        return self._namespaces.update()
-
-    def _namespaces_delete(self) -> "Delete":
-        return self._namespaces.delete()
-
-    @abstractmethod
-    def _projects_insert(self) -> "Insert": ...
-
-    def _projects_select(self, *columns) -> "Select":
-        if not columns:
-            return self._projects.select()
-        return select(*columns)
-
-    def _projects_update(self) -> "Update":
-        return self._projects.update()
-
-    def _projects_delete(self) -> "Delete":
-        return self._projects.delete()
-
     @abstractmethod
     def _datasets_insert(self) -> "Insert": ...

@@ -713,134 +510,6 @@ class AbstractDBMetastore(AbstractMetastore):
     def _datasets_dependencies_delete(self) -> "Delete":
         return self._datasets_dependencies.delete()

-    #
-    # Namespaces
-    #
-
-    def create_namespace(
-        self,
-        name: str,
-        description: Optional[str] = None,
-        uuid: Optional[str] = None,
-        ignore_if_exists: bool = True,
-        **kwargs,
-    ) -> Namespace:
-        query = self._namespaces_insert().values(
-            name=name,
-            uuid=uuid or str(uuid4()),
-            created_at=datetime.now(timezone.utc),
-            description=description,
-        )
-        if ignore_if_exists and hasattr(query, "on_conflict_do_nothing"):
-            # SQLite and PostgreSQL both support 'on_conflict_do_nothing',
-            # but generic SQL does not
-            query = query.on_conflict_do_nothing(index_elements=["name"])
-        self.db.execute(query)
-
-        return self.get_namespace(name)
-
-    def get_namespace(self, name: str, conn=None) -> Namespace:
-        """Gets a single namespace by name"""
-        n = self._namespaces
-
-        query = self._namespaces_select(
-            *(getattr(n.c, f) for f in self._namespaces_fields),
-        ).where(n.c.name == name)
-        rows = list(self.db.execute(query, conn=conn))
-        if not rows:
-            raise NamespaceNotFoundError(f"Namespace {name} not found.")
-        return self.namespace_class.parse(*rows[0])
-
-    def list_namespaces(self, conn=None) -> list[Namespace]:
-        """Gets a list of all namespaces"""
-        n = self._namespaces
-
-        query = self._namespaces_select(
-            *(getattr(n.c, f) for f in self._namespaces_fields),
-        )
-        rows = list(self.db.execute(query, conn=conn))
-
-        return [self.namespace_class.parse(*r) for r in rows]
-
-    #
-    # Projects
-    #
-
-    def create_project(
-        self,
-        namespace_name: str,
-        name: str,
-        description: Optional[str] = None,
-        uuid: Optional[str] = None,
-        ignore_if_exists: bool = True,
-        **kwargs,
-    ) -> Project:
-        try:
-            namespace = self.get_namespace(namespace_name)
-        except NamespaceNotFoundError:
-            namespace = self.create_namespace(namespace_name)
-
-        query = self._projects_insert().values(
-            namespace_id=namespace.id,
-            uuid=uuid or str(uuid4()),
-            name=name,
-            created_at=datetime.now(timezone.utc),
-            description=description,
-        )
-        if ignore_if_exists and hasattr(query, "on_conflict_do_nothing"):
-            # SQLite and PostgreSQL both support 'on_conflict_do_nothing',
-            # but generic SQL does not
-            query = query.on_conflict_do_nothing(
-                index_elements=["namespace_id", "name"]
-            )
-        self.db.execute(query)
-
-        return self.get_project(name, namespace.name)
-
-    def get_project(
-        self, name: str, namespace_name: str, create: bool = False, conn=None
-    ) -> Project:
-        """Gets a single project inside some namespace by name"""
-        n = self._namespaces
-        p = self._projects
-
-        query = self._projects_select(
-            *(getattr(n.c, f) for f in self._namespaces_fields),
-            *(getattr(p.c, f) for f in self._projects_fields),
-        )
-        query = query.select_from(n.join(p, n.c.id == p.c.namespace_id)).where(
-            p.c.name == name, n.c.name == namespace_name
-        )
-
-        rows = list(self.db.execute(query, conn=conn))
-        if not rows:
-            if create:
-                return self.create_project(namespace_name, name)
-            raise ProjectNotFoundError(
-                f"Project {name} in namespace {namespace_name} not found."
-            )
-        return self.project_class.parse(*rows[0])
-
-    def list_projects(self, namespace_id: Optional[int], conn=None) -> list[Project]:
-        """
-        Gets a list of projects inside some namespace, or in all namespaces
-        """
-        n = self._namespaces
-        p = self._projects
-
-        query = self._projects_select(
-            *(getattr(n.c, f) for f in self._namespaces_fields),
-            *(getattr(p.c, f) for f in self._projects_fields),
-        )
-        query = query.select_from(n.join(p, n.c.id == p.c.namespace_id))
-
-        if namespace_id:
-            query = query.where(n.c.id == namespace_id)
-
-        rows = list(self.db.execute(query, conn=conn))
-
-        return [self.project_class.parse(*r) for r in rows]
-
     #
     # Datasets
     #
@@ -848,7 +517,6 @@ class AbstractDBMetastore(AbstractMetastore):
     def create_dataset(
         self,
         name: str,
-        project_id: Optional[int] = None,
         status: int = DatasetStatus.CREATED,
         sources: Optional[list[str]] = None,
         feature_schema: Optional[dict] = None,
@@ -860,11 +528,9 @@ class AbstractDBMetastore(AbstractMetastore):
         **kwargs,  # TODO registered = True / False
     ) -> DatasetRecord:
         """Creates new dataset."""
-        project_id = project_id or self.default_project.id
-
+        # TODO abstract this method and add registered = True based on kwargs
         query = self._datasets_insert().values(
             name=name,
-            project_id=project_id,
             status=status,
             feature_schema=json.dumps(feature_schema or {}),
             created_at=datetime.now(timezone.utc),
@@ -880,10 +546,10 @@ class AbstractDBMetastore(AbstractMetastore):
         if ignore_if_exists and hasattr(query, "on_conflict_do_nothing"):
             # SQLite and PostgreSQL both support 'on_conflict_do_nothing',
             # but generic SQL does not
-            query = query.on_conflict_do_nothing(index_elements=["project_id", "name"])
+            query = query.on_conflict_do_nothing(index_elements=["name"])
         self.db.execute(query)

-        return self.get_dataset(name, project_id)
+        return self.get_dataset(name)

     def create_dataset_version(  # noqa: PLR0913
         self,
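The conflict handling above relies on SQLAlchemy's dialect-specific insert; with project_id gone, the conflict target shrinks from ("project_id", "name") to the bare name column. A self-contained sketch of the same pattern against an illustrative table (not the package's real schema):

from sqlalchemy import Column, Integer, MetaData, Table, Text
from sqlalchemy.dialects.sqlite import insert  # dialect insert exposes on_conflict_do_nothing

metadata = MetaData()
datasets = Table(
    "datasets",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("name", Text, nullable=False, unique=True),
)

stmt = insert(datasets).values(name="my-dataset")
# Silently skip the insert when a dataset with this name already exists.
stmt = stmt.on_conflict_do_nothing(index_elements=["name"])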
@@ -940,7 +606,7 @@ class AbstractDBMetastore(AbstractMetastore):
         )
         self.db.execute(query, conn=conn)

-        return self.get_dataset(dataset.name, dataset.project.id, conn=conn)
+        return self.get_dataset(dataset.name, conn=conn)

     def remove_dataset(self, dataset: DatasetRecord) -> None:
         """Removes dataset."""
@@ -1078,15 +744,13 @@ class AbstractDBMetastore(AbstractMetastore):

     def _parse_dataset_list(self, rows) -> Iterator["DatasetListRecord"]:
         # grouping rows by dataset id
-        for _, g in groupby(rows, lambda r: r[11]):
+        for _, g in groupby(rows, lambda r: r[0]):
             dataset = self._parse_list_dataset(list(g))
             if dataset:
                 yield dataset

     def _get_dataset_query(
         self,
-        namespace_fields: list[str],
-        project_fields: list[str],
         dataset_fields: list[str],
         dataset_version_fields: list[str],
         isouter: bool = True,
@@ -1097,81 +761,48 @@ class AbstractDBMetastore(AbstractMetastore):
         ):
             raise TableMissingError

-        n = self._namespaces
-        p = self._projects
         d = self._datasets
         dv = self._datasets_versions

         query = self._datasets_select(
-            *(getattr(n.c, f) for f in namespace_fields),
-            *(getattr(p.c, f) for f in project_fields),
             *(getattr(d.c, f) for f in dataset_fields),
             *(getattr(dv.c, f) for f in dataset_version_fields),
         )
-        j = (
-            n.join(p, n.c.id == p.c.namespace_id)
-            .join(d, p.c.id == d.c.project_id)
-            .join(dv, d.c.id == dv.c.dataset_id, isouter=isouter)
-        )
+        j = d.join(dv, d.c.id == dv.c.dataset_id, isouter=isouter)
         return query.select_from(j)

     def _base_dataset_query(self) -> "Select":
         return self._get_dataset_query(
-            self._namespaces_fields,
-            self._projects_fields,
-            self._dataset_fields,
-            self._dataset_version_fields,
+            self._dataset_fields, self._dataset_version_fields
         )

     def _base_list_datasets_query(self) -> "Select":
         return self._get_dataset_query(
-            self._namespaces_fields,
-            self._projects_fields,
-            self._dataset_list_fields,
-            self._dataset_list_version_fields,
-            isouter=False,
+            self._dataset_list_fields, self._dataset_list_version_fields, isouter=False
         )

-    def list_datasets(
-        self, project_id: Optional[int] = None
-    ) -> Iterator["DatasetListRecord"]:
+    def list_datasets(self) -> Iterator["DatasetListRecord"]:
         """Lists all datasets."""
-        d = self._datasets
         query = self._base_list_datasets_query().order_by(
             self._datasets.c.name, self._datasets_versions.c.version
         )
-        if project_id:
-            query = query.where(d.c.project_id == project_id)
         yield from self._parse_dataset_list(self.db.execute(query))

     def list_datasets_by_prefix(
-        self, prefix: str, project_id: Optional[int] = None, conn=None
+        self, prefix: str, conn=None
     ) -> Iterator["DatasetListRecord"]:
-        d = self._datasets
         query = self._base_list_datasets_query()
-        if project_id:
-            query = query.where(d.c.project_id == project_id)
         query = query.where(self._datasets.c.name.startswith(prefix))
         yield from self._parse_dataset_list(self.db.execute(query))

-    def get_dataset(
-        self,
-        name: str,  # normal, not full dataset name
-        project_id: Optional[int] = None,
-        conn=None,
-    ) -> DatasetRecord:
-        """
-        Gets a single dataset in project by dataset name.
-        """
-        project_id = project_id or self.default_project.id
+    def get_dataset(self, name: str, conn=None) -> DatasetRecord:
+        """Gets a single dataset by name"""
         d = self._datasets
         query = self._base_dataset_query()
-        query = query.where(d.c.name == name, d.c.project_id == project_id)  # type: ignore [attr-defined]
+        query = query.where(d.c.name == name)  # type: ignore [attr-defined]
         ds = self._parse_dataset(self.db.execute(query, conn=conn))
         if not ds:
-            raise DatasetNotFoundError(
-                f"Dataset {name} not found in project {project_id}"
-            )
+            raise DatasetNotFoundError(f"Dataset {name} not found.")
         return ds

     def remove_dataset_version(
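After this hunk, dataset queries join only two tables instead of four. A standalone sketch of the simplified query shape, again with illustrative tables rather than the package's real schema:

from sqlalchemy import Column, ForeignKey, Integer, MetaData, Table, Text, select

metadata = MetaData()
d = Table(
    "datasets",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("name", Text, nullable=False),
)
dv = Table(
    "datasets_versions",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("dataset_id", Integer, ForeignKey("datasets.id")),
    Column("version", Text),
)

# datasets LEFT OUTER JOIN datasets_versions, with no namespace/project hops.
query = (
    select(d.c.name, dv.c.version)
    .select_from(d.join(dv, d.c.id == dv.c.dataset_id, isouter=True))
    .where(d.c.name == "my-dataset")
)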
@@ -1241,20 +872,23 @@ class AbstractDBMetastore(AbstractMetastore):
     #
     def add_dataset_dependency(
         self,
-        source_dataset: "DatasetRecord",
+        source_dataset_name: str,
         source_dataset_version: str,
-        dep_dataset: "DatasetRecord",
-        dep_dataset_version: str,
+        dataset_name: str,
+        dataset_version: str,
     ) -> None:
         """Adds dataset dependency to dataset."""
+        source_dataset = self.get_dataset(source_dataset_name)
+        dataset = self.get_dataset(dataset_name)
+
         self.db.execute(
             self._datasets_dependencies_insert().values(
                 source_dataset_id=source_dataset.id,
                 source_dataset_version_id=(
                     source_dataset.get_version(source_dataset_version).id
                 ),
-                dataset_id=dep_dataset.id,
-                dataset_version_id=dep_dataset.get_version(dep_dataset_version).id,
+                dataset_id=dataset.id,
+                dataset_version_id=dataset.get_version(dataset_version).id,
             )
         )

@@ -1296,8 +930,6 @@ class AbstractDBMetastore(AbstractMetastore):
     def get_direct_dataset_dependencies(
         self, dataset: DatasetRecord, version: str
     ) -> list[Optional[DatasetDependency]]:
-        n = self._namespaces
-        p = self._projects
         d = self._datasets
         dd = self._datasets_dependencies
         dv = self._datasets_versions
@@ -1309,16 +941,18 @@ class AbstractDBMetastore(AbstractMetastore):
         query = (
             self._datasets_dependencies_select(*select_cols)
             .select_from(
-                dd.join(d, dd.c.dataset_id == d.c.id, isouter=True)
-                .join(dv, dd.c.dataset_version_id == dv.c.id, isouter=True)
-                .join(p, d.c.project_id == p.c.id, isouter=True)
-                .join(n, p.c.namespace_id == n.c.id, isouter=True)
+                dd.join(d, dd.c.dataset_id == d.c.id, isouter=True).join(
+                    dv, dd.c.dataset_version_id == dv.c.id, isouter=True
+                )
             )
             .where(
                 (dd.c.source_dataset_id == dataset.id)
                 & (dd.c.source_dataset_version_id == dataset_version.id)
             )
         )
+        if version:
+            dataset_version = dataset.get_version(version)
+            query = query.where(dd.c.source_dataset_version_id == dataset_version.id)

         return [self.dependency_class.parse(*r) for r in self.db.execute(query)]