datachain 0.20.4__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (47)
  1. datachain/__init__.py +0 -2
  2. datachain/cache.py +2 -2
  3. datachain/catalog/catalog.py +65 -180
  4. datachain/cli/__init__.py +7 -0
  5. datachain/cli/commands/datasets.py +28 -43
  6. datachain/cli/commands/ls.py +2 -2
  7. datachain/cli/parser/__init__.py +35 -1
  8. datachain/client/fsspec.py +3 -5
  9. datachain/client/hf.py +0 -10
  10. datachain/client/local.py +4 -4
  11. datachain/data_storage/metastore.py +37 -405
  12. datachain/data_storage/sqlite.py +7 -136
  13. datachain/data_storage/warehouse.py +7 -26
  14. datachain/dataset.py +12 -126
  15. datachain/delta.py +7 -11
  16. datachain/error.py +0 -36
  17. datachain/func/func.py +1 -1
  18. datachain/lib/arrow.py +3 -3
  19. datachain/lib/dataset_info.py +0 -4
  20. datachain/lib/dc/datachain.py +92 -260
  21. datachain/lib/dc/datasets.py +50 -104
  22. datachain/lib/dc/listings.py +3 -3
  23. datachain/lib/dc/records.py +0 -1
  24. datachain/lib/dc/storage.py +40 -38
  25. datachain/lib/file.py +23 -77
  26. datachain/lib/listing.py +1 -3
  27. datachain/lib/meta_formats.py +1 -1
  28. datachain/lib/pytorch.py +1 -1
  29. datachain/lib/settings.py +0 -10
  30. datachain/lib/tar.py +2 -1
  31. datachain/lib/udf_signature.py +1 -1
  32. datachain/lib/webdataset.py +20 -30
  33. datachain/listing.py +1 -3
  34. datachain/query/dataset.py +46 -71
  35. datachain/query/session.py +1 -1
  36. datachain/remote/studio.py +26 -61
  37. datachain/studio.py +7 -23
  38. {datachain-0.20.4.dist-info → datachain-0.21.0.dist-info}/METADATA +2 -2
  39. {datachain-0.20.4.dist-info → datachain-0.21.0.dist-info}/RECORD +43 -47
  40. datachain/lib/namespaces.py +0 -71
  41. datachain/lib/projects.py +0 -86
  42. datachain/namespace.py +0 -65
  43. datachain/project.py +0 -78
  44. {datachain-0.20.4.dist-info → datachain-0.21.0.dist-info}/WHEEL +0 -0
  45. {datachain-0.20.4.dist-info → datachain-0.21.0.dist-info}/entry_points.txt +0 -0
  46. {datachain-0.20.4.dist-info → datachain-0.21.0.dist-info}/licenses/LICENSE +0 -0
  47. {datachain-0.20.4.dist-info → datachain-0.21.0.dist-info}/top_level.txt +0 -0
datachain/data_storage/metastore.py

@@ -37,13 +37,9 @@ from datachain.dataset import (
 from datachain.error import (
     DatasetNotFoundError,
     DatasetVersionNotFoundError,
-    NamespaceNotFoundError,
-    ProjectNotFoundError,
     TableMissingError,
 )
 from datachain.job import Job
-from datachain.namespace import Namespace
-from datachain.project import Project
 from datachain.utils import JSONSerialize
 
 if TYPE_CHECKING:
@@ -65,8 +61,6 @@ class AbstractMetastore(ABC, Serializable):
     uri: StorageURI
 
     schema: "schema.Schema"
-    namespace_class: type[Namespace] = Namespace
-    project_class: type[Project] = Project
     dataset_class: type[DatasetRecord] = DatasetRecord
     dataset_list_class: type[DatasetListRecord] = DatasetListRecord
     dataset_list_version_class: type[DatasetListVersion] = DatasetListVersion
@@ -112,115 +106,14 @@ class AbstractMetastore(ABC, Serializable):
     def cleanup_for_tests(self) -> None:
         """Cleanup for tests."""
 
-    #
-    # Namespaces
-    #
-
-    @property
-    @abstractmethod
-    def default_namespace_name(self):
-        """Gets default namespace name"""
-
-    @property
-    def system_namespace_name(self):
-        return Namespace.system()
-
-    @abstractmethod
-    def create_namespace(
-        self,
-        name: str,
-        description: Optional[str] = None,
-        uuid: Optional[str] = None,
-        ignore_if_exists: bool = True,
-        **kwargs,
-    ) -> Namespace:
-        """Creates new namespace"""
-
-    @abstractmethod
-    def get_namespace(self, name: str, conn=None) -> Namespace:
-        """Gets a single namespace by name"""
-
-    @abstractmethod
-    def list_namespaces(self, conn=None) -> list[Namespace]:
-        """Gets a list of all namespaces"""
-
-    @property
-    @abstractmethod
-    def is_studio(self) -> bool:
-        """Returns True if this code is ran in Studio"""
-
-    def is_local_dataset(self, dataset_namespace: str) -> bool:
-        """
-        Returns True if this is local dataset i.e. not pulled from Studio but
-        created locally. This is False if we ran code in CLI mode but using dataset
-        names that are present in Studio.
-        """
-        return self.is_studio or dataset_namespace == Namespace.default()
-
-    @property
-    def namespace_allowed_to_create(self):
-        return self.is_studio
-
-    #
-    # Projects
-    #
-
-    @property
-    @abstractmethod
-    def default_project_name(self):
-        """Gets default project name"""
-
-    @property
-    def listing_project_name(self):
-        return Project.listing()
-
-    @cached_property
-    def default_project(self) -> Project:
-        return self.get_project(
-            self.default_project_name, self.default_namespace_name, create=True
-        )
-
-    @cached_property
-    def listing_project(self) -> Project:
-        return self.get_project(self.listing_project_name, self.system_namespace_name)
-
-    @abstractmethod
-    def create_project(
-        self,
-        namespace_name: str,
-        name: str,
-        description: Optional[str] = None,
-        uuid: Optional[str] = None,
-        ignore_if_exists: bool = True,
-        **kwargs,
-    ) -> Project:
-        """Creates new project in specific namespace"""
-
-    @abstractmethod
-    def get_project(
-        self, name: str, namespace_name: str, create: bool = False, conn=None
-    ) -> Project:
-        """
-        Gets a single project inside some namespace by name.
-        It also creates project if not found and create flag is set to True.
-        """
-
-    @abstractmethod
-    def list_projects(self, namespace_id: Optional[int], conn=None) -> list[Project]:
-        """Gets list of projects in some namespace or in general (in all namespaces)"""
-
-    @property
-    def project_allowed_to_create(self):
-        return self.is_studio
-
     #
     # Datasets
     #
+
     @abstractmethod
     def create_dataset(
         self,
         name: str,
-        project_id: Optional[int] = None,
         status: int = DatasetStatus.CREATED,
         sources: Optional[list[str]] = None,
         feature_schema: Optional[dict] = None,
@@ -280,22 +173,15 @@ class AbstractMetastore(ABC, Serializable):
         """
 
     @abstractmethod
-    def list_datasets(
-        self, project_id: Optional[int] = None
-    ) -> Iterator[DatasetListRecord]:
-        """Lists all datasets in some project or in all projects."""
+    def list_datasets(self) -> Iterator[DatasetListRecord]:
+        """Lists all datasets."""
 
     @abstractmethod
-    def list_datasets_by_prefix(
-        self, prefix: str, project_id: Optional[int] = None
-    ) -> Iterator["DatasetListRecord"]:
-        """
-        Lists all datasets which names start with prefix in some project or in all
-        projects.
-        """
+    def list_datasets_by_prefix(self, prefix: str) -> Iterator["DatasetListRecord"]:
+        """Lists all datasets which names start with prefix."""
 
     @abstractmethod
-    def get_dataset(self, name: str, project_id: Optional[int] = None) -> DatasetRecord:
+    def get_dataset(self, name: str) -> DatasetRecord:
         """Gets a single dataset by name."""
 
     @abstractmethod
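
For downstream code, the visible effect of this hunk is that dataset lookups are keyed by name alone. A minimal before/after sketch of the call sites (the `metastore` variable is a hypothetical stand-in for however an `AbstractMetastore` instance is obtained, e.g. from a catalog; dataset names are illustrative):

    # 0.20.4: lookups could be scoped to a project
    # ds = metastore.get_dataset("clothes", project_id=project.id)
    # datasets = list(metastore.list_datasets(project_id=project.id))

    # 0.21.0: project scoping is gone; the dataset name is the whole key
    ds = metastore.get_dataset("clothes")
    datasets = list(metastore.list_datasets())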
@@ -316,10 +202,10 @@ class AbstractMetastore(ABC, Serializable):
     @abstractmethod
     def add_dataset_dependency(
         self,
-        source_dataset: "DatasetRecord",
+        source_dataset_name: str,
         source_dataset_version: str,
-        dep_dataset: "DatasetRecord",
-        dep_dataset_version: str,
+        dataset_name: str,
+        dataset_version: str,
     ) -> None:
         """Adds dataset dependency to dataset."""
 
@@ -418,8 +304,6 @@ class AbstractDBMetastore(AbstractMetastore):
     and has shared logic for all database systems currently in use.
     """
 
-    NAMESPACE_TABLE = "namespaces"
-    PROJECT_TABLE = "projects"
     DATASET_TABLE = "datasets"
     DATASET_VERSION_TABLE = "datasets_versions"
     DATASET_DEPENDENCY_TABLE = "datasets_dependencies"
@@ -438,62 +322,11 @@ class AbstractDBMetastore(AbstractMetastore):
     def cleanup_tables(self, temp_table_names: list[str]) -> None:
         """Cleanup temp tables."""
 
-    @classmethod
-    def _namespaces_columns(cls) -> list["SchemaItem"]:
-        """Namespace table columns."""
-        return [
-            Column("id", Integer, primary_key=True),
-            Column("uuid", Text, nullable=False, default=uuid4()),
-            Column("name", Text, nullable=False),
-            Column("description", Text),
-            Column("created_at", DateTime(timezone=True)),
-        ]
-
-    @cached_property
-    def _namespaces_fields(self) -> list[str]:
-        return [
-            c.name  # type: ignore [attr-defined]
-            for c in self._namespaces_columns()
-            if c.name  # type: ignore [attr-defined]
-        ]
-
-    @classmethod
-    def _projects_columns(cls) -> list["SchemaItem"]:
-        """Project table columns."""
-        return [
-            Column("id", Integer, primary_key=True),
-            Column("uuid", Text, nullable=False, default=uuid4()),
-            Column("name", Text, nullable=False),
-            Column("description", Text),
-            Column("created_at", DateTime(timezone=True)),
-            Column(
-                "namespace_id",
-                Integer,
-                ForeignKey(f"{cls.NAMESPACE_TABLE}.id", ondelete="CASCADE"),
-                nullable=False,
-            ),
-            UniqueConstraint("namespace_id", "name"),
-        ]
-
-    @cached_property
-    def _projects_fields(self) -> list[str]:
-        return [
-            c.name  # type: ignore [attr-defined]
-            for c in self._projects_columns()
-            if c.name  # type: ignore [attr-defined]
-        ]
-
     @classmethod
     def _datasets_columns(cls) -> list["SchemaItem"]:
         """Datasets table columns."""
         return [
             Column("id", Integer, primary_key=True),
-            Column(
-                "project_id",
-                Integer,
-                ForeignKey(f"{cls.PROJECT_TABLE}.id", ondelete="CASCADE"),
-                nullable=False,
-            ),
             Column("name", Text, nullable=False),
             Column("description", Text),
             Column("attrs", JSON, nullable=True),
@@ -612,16 +445,6 @@ class AbstractDBMetastore(AbstractMetastore):
     #
     # Query Tables
     #
-    @cached_property
-    def _namespaces(self) -> Table:
-        return Table(
-            self.NAMESPACE_TABLE, self.db.metadata, *self._namespaces_columns()
-        )
-
-    @cached_property
-    def _projects(self) -> Table:
-        return Table(self.PROJECT_TABLE, self.db.metadata, *self._projects_columns())
-
     @cached_property
     def _datasets(self) -> Table:
         return Table(self.DATASET_TABLE, self.db.metadata, *self._datasets_columns())
@@ -645,34 +468,6 @@ class AbstractDBMetastore(AbstractMetastore):
     #
     # Query Starters (These can be overridden by subclasses)
     #
-    @abstractmethod
-    def _namespaces_insert(self) -> "Insert": ...
-
-    def _namespaces_select(self, *columns) -> "Select":
-        if not columns:
-            return self._namespaces.select()
-        return select(*columns)
-
-    def _namespaces_update(self) -> "Update":
-        return self._namespaces.update()
-
-    def _namespaces_delete(self) -> "Delete":
-        return self._namespaces.delete()
-
-    @abstractmethod
-    def _projects_insert(self) -> "Insert": ...
-
-    def _projects_select(self, *columns) -> "Select":
-        if not columns:
-            return self._projects.select()
-        return select(*columns)
-
-    def _projects_update(self) -> "Update":
-        return self._projects.update()
-
-    def _projects_delete(self) -> "Delete":
-        return self._projects.delete()
-
     @abstractmethod
     def _datasets_insert(self) -> "Insert": ...
 
@@ -715,134 +510,6 @@ class AbstractDBMetastore(AbstractMetastore):
     def _datasets_dependencies_delete(self) -> "Delete":
         return self._datasets_dependencies.delete()
 
-    #
-    # Namespaces
-    #
-
-    def create_namespace(
-        self,
-        name: str,
-        description: Optional[str] = None,
-        uuid: Optional[str] = None,
-        ignore_if_exists: bool = True,
-        **kwargs,
-    ) -> Namespace:
-        query = self._namespaces_insert().values(
-            name=name,
-            uuid=uuid or str(uuid4()),
-            created_at=datetime.now(timezone.utc),
-            description=description,
-        )
-        if ignore_if_exists and hasattr(query, "on_conflict_do_nothing"):
-            # SQLite and PostgreSQL both support 'on_conflict_do_nothing',
-            # but generic SQL does not
-            query = query.on_conflict_do_nothing(index_elements=["name"])
-        self.db.execute(query)
-
-        return self.get_namespace(name)
-
-    def get_namespace(self, name: str, conn=None) -> Namespace:
-        """Gets a single namespace by name"""
-        n = self._namespaces
-
-        query = self._namespaces_select(
-            *(getattr(n.c, f) for f in self._namespaces_fields),
-        ).where(n.c.name == name)
-        rows = list(self.db.execute(query, conn=conn))
-        if not rows:
-            raise NamespaceNotFoundError(f"Namespace {name} not found.")
-        return self.namespace_class.parse(*rows[0])
-
-    def list_namespaces(self, conn=None) -> list[Namespace]:
-        """Gets a list of all namespaces"""
-        n = self._namespaces
-
-        query = self._namespaces_select(
-            *(getattr(n.c, f) for f in self._namespaces_fields),
-        )
-        rows = list(self.db.execute(query, conn=conn))
-
-        return [self.namespace_class.parse(*r) for r in rows]
-
-    #
-    # Projects
-    #
-
-    def create_project(
-        self,
-        namespace_name: str,
-        name: str,
-        description: Optional[str] = None,
-        uuid: Optional[str] = None,
-        ignore_if_exists: bool = True,
-        **kwargs,
-    ) -> Project:
-        try:
-            namespace = self.get_namespace(namespace_name)
-        except NamespaceNotFoundError:
-            namespace = self.create_namespace(namespace_name)
-
-        query = self._projects_insert().values(
-            namespace_id=namespace.id,
-            uuid=uuid or str(uuid4()),
-            name=name,
-            created_at=datetime.now(timezone.utc),
-            description=description,
-        )
-        if ignore_if_exists and hasattr(query, "on_conflict_do_nothing"):
-            # SQLite and PostgreSQL both support 'on_conflict_do_nothing',
-            # but generic SQL does not
-            query = query.on_conflict_do_nothing(
-                index_elements=["namespace_id", "name"]
-            )
-        self.db.execute(query)
-
-        return self.get_project(name, namespace.name)
-
-    def get_project(
-        self, name: str, namespace_name: str, create: bool = False, conn=None
-    ) -> Project:
-        """Gets a single project inside some namespace by name"""
-        n = self._namespaces
-        p = self._projects
-
-        query = self._projects_select(
-            *(getattr(n.c, f) for f in self._namespaces_fields),
-            *(getattr(p.c, f) for f in self._projects_fields),
-        )
-        query = query.select_from(n.join(p, n.c.id == p.c.namespace_id)).where(
-            p.c.name == name, n.c.name == namespace_name
-        )
-
-        rows = list(self.db.execute(query, conn=conn))
-        if not rows:
-            if create:
-                return self.create_project(namespace_name, name)
-            raise ProjectNotFoundError(
-                f"Project {name} in namespace {namespace_name} not found."
-            )
-        return self.project_class.parse(*rows[0])
-
-    def list_projects(self, namespace_id: Optional[int], conn=None) -> list[Project]:
-        """
-        Gets a list of projects inside some namespace, or in all namespaces
-        """
-        n = self._namespaces
-        p = self._projects
-
-        query = self._projects_select(
-            *(getattr(n.c, f) for f in self._namespaces_fields),
-            *(getattr(p.c, f) for f in self._projects_fields),
-        )
-        query = query.select_from(n.join(p, n.c.id == p.c.namespace_id))
-
-        if namespace_id:
-            query = query.where(n.c.id == namespace_id)
-
-        rows = list(self.db.execute(query, conn=conn))
-
-        return [self.project_class.parse(*r) for r in rows]
-
     #
     # Datasets
     #
@@ -850,7 +517,6 @@ class AbstractDBMetastore(AbstractMetastore):
     def create_dataset(
         self,
         name: str,
-        project_id: Optional[int] = None,
         status: int = DatasetStatus.CREATED,
         sources: Optional[list[str]] = None,
         feature_schema: Optional[dict] = None,
@@ -862,11 +528,9 @@ class AbstractDBMetastore(AbstractMetastore):
         **kwargs,  # TODO registered = True / False
     ) -> DatasetRecord:
         """Creates new dataset."""
-        project_id = project_id or self.default_project.id
-
+        # TODO abstract this method and add registered = True based on kwargs
         query = self._datasets_insert().values(
             name=name,
-            project_id=project_id,
             status=status,
             feature_schema=json.dumps(feature_schema or {}),
             created_at=datetime.now(timezone.utc),
@@ -882,10 +546,10 @@ class AbstractDBMetastore(AbstractMetastore):
         if ignore_if_exists and hasattr(query, "on_conflict_do_nothing"):
             # SQLite and PostgreSQL both support 'on_conflict_do_nothing',
             # but generic SQL does not
-            query = query.on_conflict_do_nothing(index_elements=["project_id", "name"])
+            query = query.on_conflict_do_nothing(index_elements=["name"])
         self.db.execute(query)
 
-        return self.get_dataset(name, project_id)
+        return self.get_dataset(name)
 
     def create_dataset_version(  # noqa: PLR0913
         self,
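
The `ignore_if_exists` branch keeps its upsert-then-fetch shape; only the conflict target shrinks from the composite (project_id, name) key to name alone. A self-contained sketch of that pattern with SQLAlchemy's SQLite dialect, using an illustrative table rather than the package's actual schema:

    from datetime import datetime, timezone

    from sqlalchemy import Column, Integer, MetaData, Table, Text
    from sqlalchemy.dialects.sqlite import insert  # dialect insert supports ON CONFLICT

    datasets = Table(
        "datasets",
        MetaData(),
        Column("id", Integer, primary_key=True),
        Column("name", Text, nullable=False, unique=True),
        Column("created_at", Text),
    )

    stmt = insert(datasets).values(
        name="clothes",
        created_at=datetime.now(timezone.utc).isoformat(),
    )
    # Generic sqlalchemy.insert() has no on_conflict_do_nothing, which is why
    # the code above guards with hasattr(query, "on_conflict_do_nothing").
    stmt = stmt.on_conflict_do_nothing(index_elements=["name"])  # was ["project_id", "name"]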
@@ -942,7 +606,7 @@ class AbstractDBMetastore(AbstractMetastore):
         )
         self.db.execute(query, conn=conn)
 
-        return self.get_dataset(dataset.name, dataset.project.id, conn=conn)
+        return self.get_dataset(dataset.name, conn=conn)
 
     def remove_dataset(self, dataset: DatasetRecord) -> None:
         """Removes dataset."""
@@ -1080,15 +744,13 @@ class AbstractDBMetastore(AbstractMetastore):
 
     def _parse_dataset_list(self, rows) -> Iterator["DatasetListRecord"]:
         # grouping rows by dataset id
-        for _, g in groupby(rows, lambda r: r[11]):
+        for _, g in groupby(rows, lambda r: r[0]):
             dataset = self._parse_list_dataset(list(g))
             if dataset:
                 yield dataset
 
     def _get_dataset_query(
         self,
-        namespace_fields: list[str],
-        project_fields: list[str],
         dataset_fields: list[str],
         dataset_version_fields: list[str],
         isouter: bool = True,
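
The grouping-key change falls out of the select list: with the namespace and project columns dropped from `_get_dataset_query`, the dataset `id` is the first column of each row rather than the twelfth. A toy illustration of the surviving pattern (row tuples invented for the example):

    from itertools import groupby

    # 0.21.0 row layout: dataset columns lead, so the dataset id is r[0].
    # In 0.20.4, namespace/project columns came first and pushed it to r[11].
    rows = [(42, "clothes", "1.0.0"), (42, "clothes", "1.1.0")]

    for dataset_id, group in groupby(rows, key=lambda r: r[0]):
        print(dataset_id, [r[2] for r in group])  # -> 42 ['1.0.0', '1.1.0']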
@@ -1099,81 +761,48 @@ class AbstractDBMetastore(AbstractMetastore):
         ):
             raise TableMissingError
 
-        n = self._namespaces
-        p = self._projects
         d = self._datasets
         dv = self._datasets_versions
 
         query = self._datasets_select(
-            *(getattr(n.c, f) for f in namespace_fields),
-            *(getattr(p.c, f) for f in project_fields),
             *(getattr(d.c, f) for f in dataset_fields),
             *(getattr(dv.c, f) for f in dataset_version_fields),
         )
-        j = (
-            n.join(p, n.c.id == p.c.namespace_id)
-            .join(d, p.c.id == d.c.project_id)
-            .join(dv, d.c.id == dv.c.dataset_id, isouter=isouter)
-        )
+        j = d.join(dv, d.c.id == dv.c.dataset_id, isouter=isouter)
         return query.select_from(j)
 
     def _base_dataset_query(self) -> "Select":
         return self._get_dataset_query(
-            self._namespaces_fields,
-            self._projects_fields,
-            self._dataset_fields,
-            self._dataset_version_fields,
+            self._dataset_fields, self._dataset_version_fields
         )
 
     def _base_list_datasets_query(self) -> "Select":
         return self._get_dataset_query(
-            self._namespaces_fields,
-            self._projects_fields,
-            self._dataset_list_fields,
-            self._dataset_list_version_fields,
-            isouter=False,
+            self._dataset_list_fields, self._dataset_list_version_fields, isouter=False
        )
 
-    def list_datasets(
-        self, project_id: Optional[int] = None
-    ) -> Iterator["DatasetListRecord"]:
+    def list_datasets(self) -> Iterator["DatasetListRecord"]:
         """Lists all datasets."""
-        d = self._datasets
         query = self._base_list_datasets_query().order_by(
             self._datasets.c.name, self._datasets_versions.c.version
         )
-        if project_id:
-            query = query.where(d.c.project_id == project_id)
         yield from self._parse_dataset_list(self.db.execute(query))
 
     def list_datasets_by_prefix(
-        self, prefix: str, project_id: Optional[int] = None, conn=None
+        self, prefix: str, conn=None
     ) -> Iterator["DatasetListRecord"]:
-        d = self._datasets
         query = self._base_list_datasets_query()
-        if project_id:
-            query = query.where(d.c.project_id == project_id)
         query = query.where(self._datasets.c.name.startswith(prefix))
         yield from self._parse_dataset_list(self.db.execute(query))
 
-    def get_dataset(
-        self,
-        name: str,  # normal, not full dataset name
-        project_id: Optional[int] = None,
-        conn=None,
-    ) -> DatasetRecord:
-        """
-        Gets a single dataset in project by dataset name.
-        """
-        project_id = project_id or self.default_project.id
+    def get_dataset(self, name: str, conn=None) -> DatasetRecord:
+        """Gets a single dataset by name"""
         d = self._datasets
         query = self._base_dataset_query()
-        query = query.where(d.c.name == name, d.c.project_id == project_id)  # type: ignore [attr-defined]
+        query = query.where(d.c.name == name)  # type: ignore [attr-defined]
         ds = self._parse_dataset(self.db.execute(query, conn=conn))
         if not ds:
-            raise DatasetNotFoundError(
-                f"Dataset {name} not found in project {project_id}"
-            )
+            raise DatasetNotFoundError(f"Dataset {name} not found.")
         return ds
 
     def remove_dataset_version(
@@ -1243,20 +872,23 @@ class AbstractDBMetastore(AbstractMetastore):
     #
     def add_dataset_dependency(
         self,
-        source_dataset: "DatasetRecord",
+        source_dataset_name: str,
         source_dataset_version: str,
-        dep_dataset: "DatasetRecord",
-        dep_dataset_version: str,
+        dataset_name: str,
+        dataset_version: str,
     ) -> None:
         """Adds dataset dependency to dataset."""
+        source_dataset = self.get_dataset(source_dataset_name)
+        dataset = self.get_dataset(dataset_name)
+
         self.db.execute(
             self._datasets_dependencies_insert().values(
                 source_dataset_id=source_dataset.id,
                 source_dataset_version_id=(
                     source_dataset.get_version(source_dataset_version).id
                 ),
-                dataset_id=dep_dataset.id,
-                dataset_version_id=dep_dataset.get_version(dep_dataset_version).id,
+                dataset_id=dataset.id,
+                dataset_version_id=dataset.get_version(dataset_version).id,
             )
         )
 
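Callers now pass names and let the metastore resolve both `DatasetRecord`s itself, via the two `get_dataset` calls added above. A hedged before/after sketch; `metastore` and the dataset names are placeholders:

    # 0.20.4: the caller had to fetch both DatasetRecord objects first
    # src = metastore.get_dataset("cats-and-dogs", project_id=project.id)
    # dep = metastore.get_dataset("raw-images", project_id=project.id)
    # metastore.add_dataset_dependency(src, "1.0.0", dep, "1.0.0")

    # 0.21.0: names go straight in; record resolution happens inside the metastore
    metastore.add_dataset_dependency("cats-and-dogs", "1.0.0", "raw-images", "1.0.0")
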
@@ -1298,8 +930,6 @@ class AbstractDBMetastore(AbstractMetastore):
     def get_direct_dataset_dependencies(
         self, dataset: DatasetRecord, version: str
     ) -> list[Optional[DatasetDependency]]:
-        n = self._namespaces
-        p = self._projects
         d = self._datasets
         dd = self._datasets_dependencies
         dv = self._datasets_versions
@@ -1311,16 +941,18 @@ class AbstractDBMetastore(AbstractMetastore):
         query = (
             self._datasets_dependencies_select(*select_cols)
             .select_from(
-                dd.join(d, dd.c.dataset_id == d.c.id, isouter=True)
-                .join(dv, dd.c.dataset_version_id == dv.c.id, isouter=True)
-                .join(p, d.c.project_id == p.c.id, isouter=True)
-                .join(n, p.c.namespace_id == n.c.id, isouter=True)
+                dd.join(d, dd.c.dataset_id == d.c.id, isouter=True).join(
+                    dv, dd.c.dataset_version_id == dv.c.id, isouter=True
+                )
             )
             .where(
                 (dd.c.source_dataset_id == dataset.id)
                 & (dd.c.source_dataset_version_id == dataset_version.id)
             )
         )
+        if version:
+            dataset_version = dataset.get_version(version)
+            query = query.where(dd.c.source_dataset_version_id == dataset_version.id)
 
         return [self.dependency_class.parse(*r) for r in self.db.execute(query)]