datachain 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of datachain might be problematic.

Files changed (46)
  1. datachain/__init__.py +3 -4
  2. datachain/cache.py +10 -4
  3. datachain/catalog/catalog.py +42 -16
  4. datachain/cli.py +48 -32
  5. datachain/data_storage/metastore.py +24 -0
  6. datachain/data_storage/warehouse.py +3 -1
  7. datachain/job.py +56 -0
  8. datachain/lib/arrow.py +19 -7
  9. datachain/lib/clip.py +89 -66
  10. datachain/lib/convert/{type_converter.py → python_to_sql.py} +6 -6
  11. datachain/lib/convert/sql_to_python.py +23 -0
  12. datachain/lib/convert/values_to_tuples.py +51 -33
  13. datachain/lib/data_model.py +6 -27
  14. datachain/lib/dataset_info.py +70 -0
  15. datachain/lib/dc.py +618 -156
  16. datachain/lib/file.py +130 -22
  17. datachain/lib/image.py +1 -1
  18. datachain/lib/meta_formats.py +14 -2
  19. datachain/lib/model_store.py +3 -2
  20. datachain/lib/pytorch.py +10 -7
  21. datachain/lib/signal_schema.py +19 -11
  22. datachain/lib/text.py +2 -1
  23. datachain/lib/udf.py +56 -5
  24. datachain/lib/udf_signature.py +1 -1
  25. datachain/node.py +11 -8
  26. datachain/query/dataset.py +62 -28
  27. datachain/query/schema.py +2 -0
  28. datachain/query/session.py +4 -4
  29. datachain/sql/functions/array.py +12 -0
  30. datachain/sql/functions/string.py +8 -0
  31. datachain/torch/__init__.py +1 -1
  32. datachain/utils.py +6 -0
  33. datachain-0.2.13.dist-info/METADATA +411 -0
  34. {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/RECORD +38 -42
  35. {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/WHEEL +1 -1
  36. datachain/lib/gpt4_vision.py +0 -97
  37. datachain/lib/hf_image_to_text.py +0 -97
  38. datachain/lib/hf_pipeline.py +0 -90
  39. datachain/lib/image_transform.py +0 -103
  40. datachain/lib/iptc_exif_xmp.py +0 -76
  41. datachain/lib/unstructured.py +0 -41
  42. datachain/text/__init__.py +0 -3
  43. datachain-0.2.11.dist-info/METADATA +0 -431
  44. {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/LICENSE +0 -0
  45. {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/entry_points.txt +0 -0
  46. {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/top_level.txt +0 -0
datachain/node.py CHANGED
@@ -5,7 +5,7 @@ import attrs
 
 from datachain.cache import UniqueId
 from datachain.storage import StorageURI
-from datachain.utils import time_to_str
+from datachain.utils import TIME_ZERO, time_to_str
 
 if TYPE_CHECKING:
     from typing_extensions import Self
@@ -111,13 +111,16 @@ class Node:
         if storage is None:
             storage = self.source
         return UniqueId(
-            storage,
-            self.parent,
-            self.name,
-            self.etag,
-            self.size,
-            self.vtype,
-            self.location,
+            storage=storage,
+            parent=self.parent,
+            name=self.name,
+            size=self.size,
+            version=self.version or "",
+            etag=self.etag,
+            is_latest=self.is_latest,
+            vtype=self.vtype,
+            location=self.location,
+            last_modified=self.last_modified or TIME_ZERO,
         )
 
     @classmethod
datachain/query/dataset.py CHANGED
@@ -54,6 +54,7 @@ from datachain.utils import (
     batched,
     determine_processes,
     filtered_cloudpickle_dumps,
+    get_datachain_executable,
 )
 
 from .metrics import metrics
@@ -426,7 +427,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback:
 
 
 @frozen
-class UDF(Step, ABC):
+class UDFStep(Step, ABC):
     udf: UDFType
     catalog: "Catalog"
     partition_by: Optional[PartitionByType] = None
@@ -507,13 +508,12 @@ class UDF(Step, ABC):
 
         # Run the UDFDispatcher in another process to avoid needing
         # if __name__ == '__main__': in user scripts
-        datachain_exec_path = os.environ.get("DATACHAIN_EXEC_PATH", "datachain")
-
+        exec_cmd = get_datachain_executable()
         envs = dict(os.environ)
         envs.update({"PYTHONPATH": os.getcwd()})
         process_data = filtered_cloudpickle_dumps(udf_info)
         result = subprocess.run(  # noqa: S603
-            [datachain_exec_path, "--internal-run-udf"],
+            [*exec_cmd, "internal-run-udf"],
             input=process_data,
             check=False,
             env=envs,
@@ -635,7 +635,7 @@ class UDF(Step, ABC):
 
 
 @frozen
-class UDFSignal(UDF):
+class UDFSignal(UDFStep):
     is_generator = False
 
     def create_udf_table(self, query: Select) -> "Table":
@@ -730,7 +730,7 @@ class UDFSignal(UDF):
 
 
 @frozen
-class RowGenerator(UDF):
+class RowGenerator(UDFStep):
     """Extend dataset with new rows."""
 
     is_generator = True
@@ -820,8 +820,16 @@ class SQLMutate(SQLClause):
     args: tuple[ColumnElement, ...]
 
     def apply_sql_clause(self, query: Select) -> Select:
-        subquery = query.subquery()
-        return sqlalchemy.select(*subquery.c, *self.args).select_from(subquery)
+        original_subquery = query.subquery()
+        # this is needed for new column to be used in clauses
+        # like ORDER BY, otherwise new column is not recognized
+        subquery = (
+            sqlalchemy.select(*original_subquery.c, *self.args)
+            .select_from(original_subquery)
+            .subquery()
+        )
+
+        return sqlalchemy.select(*subquery.c).select_from(subquery)
 
 
 @frozen
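Note: the extra subquery wrapping is what lets later clauses reference the
mutated column by name. A minimal standalone sketch of the same pattern
(table and column names are hypothetical, not from datachain):

    import sqlalchemy as sa

    t = sa.table("items", sa.column("price", sa.Integer))

    inner = sa.select(t.c.price).subquery()
    # add the derived column inside a subquery first...
    wrapped = (
        sa.select(*inner.c, (inner.c.price * 2).label("doubled"))
        .select_from(inner)
        .subquery()
    )
    # ...so the outer SELECT can ORDER BY the new column by name
    query = sa.select(*wrapped.c).order_by(wrapped.c.doubled)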
@@ -865,6 +873,18 @@ class SQLCount(SQLClause):
         return sqlalchemy.select(f.count(1)).select_from(query.subquery())
 
 
+@frozen
+class SQLDistinct(SQLClause):
+    args: tuple[ColumnElement, ...]
+    dialect: str
+
+    def apply_sql_clause(self, query):
+        if self.dialect == "sqlite":
+            return query.group_by(*self.args)
+
+        return query.distinct(*self.args)
+
+
 @frozen
 class SQLUnion(Step):
     query1: "DatasetQuery"
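Note: SQLite has no DISTINCT ON, so the clause falls back to grouping by the
chosen columns, which keeps one row per distinct combination. A hedged
SQLAlchemy sketch of the two code paths (the table is hypothetical):

    import sqlalchemy as sa

    files = sa.table("files", sa.column("name"), sa.column("size"))
    q = sa.select(files.c.name, files.c.size)

    pg_q = q.distinct(files.c.name)      # PostgreSQL: DISTINCT ON (name)
    sqlite_q = q.group_by(files.c.name)  # SQLite: GROUP BY emulation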
@@ -946,12 +966,15 @@ class SQLJoin(Step):
 
         q1_columns = list(q1.c)
         q1_column_names = {c.name for c in q1_columns}
-        q2_columns = [
-            c
-            if c.name not in q1_column_names and c.name != "sys__id"
-            else c.label(self.rname.format(name=c.name))
-            for c in q2.c
-        ]
+
+        q2_columns = []
+        for c in q2.c:
+            if c.name.startswith("sys__"):
+                continue
+
+            if c.name in q1_column_names:
+                c = c.label(self.rname.format(name=c.name))
+            q2_columns.append(c)
 
         res_columns = q1_columns + q2_columns
         predicates = (
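Note: the join now drops every right-side sys__* column (previously only
sys__id was special-cased) and relabels only real name collisions. A
standalone illustration of the new rules over plain names (the rname template
and column names are hypothetical):

    q1_names = {"id", "name"}
    q2_names = ["sys__id", "sys__rand", "name", "score"]
    rname = "{name}_right"

    q2_columns = []
    for n in q2_names:
        if n.startswith("sys__"):   # system columns are skipped entirely
            continue
        if n in q1_names:           # collisions get the rname label
            n = rname.format(name=n)
        q2_columns.append(n)

    assert q2_columns == ["name_right", "score"]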
@@ -1058,6 +1081,7 @@ class DatasetQuery:
         anon: bool = False,
         indexing_feature_schema: Optional[dict] = None,
         indexing_column_types: Optional[dict[str, Any]] = None,
+        update: Optional[bool] = False,
     ):
         if client_config is None:
             client_config = {}
@@ -1080,10 +1104,12 @@ class DatasetQuery:
         self.session = Session.get(session, catalog=catalog)
 
         if path:
-            self.starting_step = IndexingStep(path, self.catalog, {}, recursive)
+            kwargs = {"update": True} if update else {}
+            self.starting_step = IndexingStep(path, self.catalog, kwargs, recursive)
             self.feature_schema = indexing_feature_schema
             self.column_types = indexing_column_types
         elif name:
+            self.name = name
             ds = self.catalog.get_dataset(name)
             self.version = version or ds.latest_version
             self.feature_schema = ds.get_version(self.version).feature_schema
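Note: together with the new constructor argument above, update=True forwards
{"update": True} to IndexingStep so the source is re-indexed when the query is
created. A hedged usage sketch (the bucket URI is hypothetical):

    from datachain.query.dataset import DatasetQuery

    dq = DatasetQuery("s3://my-bucket/images/", update=True)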
@@ -1091,9 +1117,6 @@ class DatasetQuery:
             if "sys__id" in self.column_types:
                 self.column_types.pop("sys__id")
             self.starting_step = QueryStep(self.catalog, name, self.version)
-            # attaching to specific dataset
-            self.name = name
-            self.version = version
         else:
             raise ValueError("must provide path or name")
 
@@ -1102,7 +1125,7 @@ class DatasetQuery:
         return bool(re.compile(r"^[a-zA-Z0-9]+://").match(path))
 
     def __iter__(self):
-        return iter(self.results())
+        return iter(self.db_results())
 
     def __or__(self, other):
         return self.union(other)
@@ -1223,13 +1246,16 @@ class DatasetQuery:
         warehouse.close()
         self.temp_table_names = []
 
-    def results(self, row_factory=None, **kwargs):
+    def db_results(self, row_factory=None, **kwargs):
         with self.as_iterable(**kwargs) as result:
             if row_factory:
                 cols = result.columns
                 return [row_factory(cols, r) for r in result]
             return list(result)
 
+    def to_db_records(self) -> list[dict[str, Any]]:
+        return self.db_results(lambda cols, row: dict(zip(cols, row)))
+
     @contextlib.contextmanager
     def as_iterable(self, **kwargs) -> Iterator[ResultIter]:
         try:
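Note: results() is renamed to db_results() (and to_records() becomes
to_db_records() in the next hunk) to make explicit that rows come back in raw
database form rather than as parsed features. A usage sketch with a
hypothetical dataset name:

    dq = DatasetQuery(name="my-dataset")

    rows = dq.db_results()        # list of plain row tuples
    records = dq.to_db_records()  # list of {"column_name": value} dicts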
@@ -1289,9 +1315,6 @@ class DatasetQuery:
         finally:
             self.cleanup()
 
-    def to_records(self) -> list[dict[str, Any]]:
-        return self.results(lambda cols, row: dict(zip(cols, row)))
-
     def shuffle(self) -> "Self":
         # ToDo: implement shaffle based on seed and/or generating random column
         return self.order_by(C.sys__rand)
@@ -1407,6 +1430,14 @@ class DatasetQuery:
         query.steps.append(SQLOffset(offset))
         return query
 
+    @detach
+    def distinct(self, *args) -> "Self":
+        query = self.clone()
+        query.steps.append(
+            SQLDistinct(args, dialect=self.catalog.warehouse.db.dialect.name)
+        )
+        return query
+
     def as_scalar(self) -> Any:
         with self.as_iterable() as rows:
             row = next(iter(rows))
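Note: distinct() plugs the new SQLDistinct step into the query chain, picking
the dialect from the active warehouse. A hedged usage sketch (the dataset and
column names are hypothetical):

    from datachain.query import C, DatasetQuery

    dq = DatasetQuery(name="photos")
    unique_names = dq.distinct(C.file__name)  # compiles to GROUP BY on SQLite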
@@ -1705,10 +1736,13 @@ def _send_result(dataset_query: DatasetQuery) -> None:
 
     columns = preview_args.get("columns") or []
 
-    preview_query = (
-        dataset_query.select(*columns)
-        .limit(preview_args.get("limit", 10))
-        .offset(preview_args.get("offset", 0))
+    if type(dataset_query) is DatasetQuery:
+        preview_query = dataset_query.select(*columns)
+    else:
+        preview_query = dataset_query.select(*columns, _sys=False)
+
+    preview_query = preview_query.limit(preview_args.get("limit", 10)).offset(
+        preview_args.get("offset", 0)
     )
 
     dataset: Optional[tuple[str, int]] = None
@@ -1717,7 +1751,7 @@ def _send_result(dataset_query: DatasetQuery) -> None:
         assert dataset_query.version, "Dataset version should be provided"
         dataset = dataset_query.name, dataset_query.version
 
-    preview = preview_query.to_records()
+    preview = preview_query.to_db_records()
     result = ExecutionResult(preview, dataset, metrics)
     data = attrs.asdict(result)
 
datachain/query/schema.py CHANGED
@@ -32,6 +32,7 @@ class Column(sa.ColumnClause, metaclass=ColumnMeta):
     inherit_cache: Optional[bool] = True
 
     def __init__(self, text, type_=None, is_literal=False, _selectable=None):
+        """Dataset column."""
         self.name = ColumnMeta.to_db_name(text)
         super().__init__(
             self.name, type_=type_, is_literal=is_literal, _selectable=_selectable
@@ -41,6 +42,7 @@ class Column(sa.ColumnClause, metaclass=ColumnMeta):
         return Column(self.name + DEFAULT_DELIMITER + name)
 
     def glob(self, glob_str):
+        """Search for matches using glob pattern matching."""
         return self.op("GLOB")(glob_str)
 
 
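Note: Column.glob compiles to the SQL GLOB operator and is typically used in
filters. A hedged usage sketch (the storage URI is hypothetical):

    from datachain.query import C, DatasetQuery

    dq = DatasetQuery("s3://my-bucket/")
    jpegs = dq.filter(C.name.glob("*.jpg"))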
datachain/query/session.py CHANGED
@@ -28,9 +28,9 @@ class Session:
 
     Parameters:
 
-        `name` (str): The name of the session. Only latters and numbers are supported.
+        name (str): The name of the session. Only latters and numbers are supported.
             It can be empty.
-        `catalog` (Catalog): Catalog object.
+        catalog (Catalog): Catalog object.
     """
 
     GLOBAL_SESSION_CTX: Optional["Session"] = None
@@ -80,9 +80,9 @@ class Session:
         """Creates a Session() object from a catalog.
 
         Parameters:
-            `session` (Session): Optional Session(). If not provided a new session will
+            session (Session): Optional Session(). If not provided a new session will
                 be created. It's needed mostly for simplie API purposes.
-            `catalog` (Catalog): Optional catalog. By default a new catalog is created.
+            catalog (Catalog): Optional catalog. By default a new catalog is created.
         """
         if session:
             return session
datachain/sql/functions/array.py CHANGED
@@ -5,6 +5,10 @@ from datachain.sql.utils import compiler_not_implemented
 
 
 class cosine_distance(GenericFunction):  # noqa: N801
+    """
+    Takes a column and array and returns the cosine distance between them.
+    """
+
     type = Float()
     package = "array"
     name = "cosine_distance"
@@ -12,6 +16,10 @@ class cosine_distance(GenericFunction): # noqa: N801
 
 
 class euclidean_distance(GenericFunction):  # noqa: N801
+    """
+    Takes a column and array and returns the Euclidean distance between them.
+    """
+
     type = Float()
     package = "array"
     name = "euclidean_distance"
@@ -19,6 +27,10 @@ class euclidean_distance(GenericFunction): # noqa: N801
 
 
 class length(GenericFunction):  # noqa: N801
+    """
+    Returns the length of the array.
+    """
+
     type = Int64()
     package = "array"
     name = "length"
datachain/sql/functions/string.py CHANGED
@@ -5,6 +5,10 @@ from datachain.sql.utils import compiler_not_implemented
 
 
 class length(GenericFunction):  # noqa: N801
+    """
+    Returns the length of the string.
+    """
+
     type = Int64()
     package = "string"
     name = "length"
@@ -12,6 +16,10 @@ class length(GenericFunction): # noqa: N801
 
 
 class split(GenericFunction):  # noqa: N801
+    """
+    Takes a column and split character and returns an array of the parts.
+    """
+
     type = Array(String())
     package = "string"
     name = "split"
datachain/torch/__init__.py CHANGED
@@ -1,5 +1,5 @@
 try:
-    from datachain.lib.clip import similarity_scores as clip_similarity_scores
+    from datachain.lib.clip import clip_similarity_scores
     from datachain.lib.image import convert_image, convert_images
     from datachain.lib.pytorch import PytorchDataset, label_to_int
     from datachain.lib.text import convert_text
datachain/utils.py CHANGED
@@ -427,3 +427,9 @@ def filtered_cloudpickle_dumps(obj: Any) -> bytes:
         for model_class, namespace in model_namespaces.items():
             # Restore original __pydantic_parent_namespace__ locally.
             model_class.__pydantic_parent_namespace__ = namespace
+
+
+def get_datachain_executable() -> list[str]:
+    if datachain_exec_path := os.getenv("DATACHAIN_EXEC_PATH"):
+        return [datachain_exec_path]
+    return [sys.executable, "-m", "datachain"]
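Note: this helper replaces the old DATACHAIN_EXEC_PATH lookup in dataset.py;
when the variable is unset, UDF subprocesses run as "python -m datachain"
under the current interpreter, so they stay inside the same virtualenv. A
minimal sketch of the call site, mirroring the dataset.py hunk above (the
pickled payload is a placeholder):

    import os
    import subprocess

    exec_cmd = get_datachain_executable()  # e.g. [sys.executable, "-m", "datachain"]
    envs = dict(os.environ, PYTHONPATH=os.getcwd())
    subprocess.run([*exec_cmd, "internal-run-udf"], input=b"...", env=envs, check=False)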