datachain 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of datachain might be problematic.

Files changed (49)
  1. datachain/__init__.py +3 -4
  2. datachain/cache.py +10 -4
  3. datachain/catalog/catalog.py +35 -15
  4. datachain/cli.py +37 -32
  5. datachain/data_storage/metastore.py +24 -0
  6. datachain/data_storage/warehouse.py +3 -1
  7. datachain/job.py +56 -0
  8. datachain/lib/arrow.py +19 -7
  9. datachain/lib/clip.py +89 -66
  10. datachain/lib/convert/{type_converter.py → python_to_sql.py} +6 -6
  11. datachain/lib/convert/sql_to_python.py +23 -0
  12. datachain/lib/convert/values_to_tuples.py +51 -33
  13. datachain/lib/data_model.py +6 -27
  14. datachain/lib/dataset_info.py +70 -0
  15. datachain/lib/dc.py +646 -152
  16. datachain/lib/file.py +117 -15
  17. datachain/lib/image.py +1 -1
  18. datachain/lib/meta_formats.py +14 -2
  19. datachain/lib/model_store.py +3 -2
  20. datachain/lib/pytorch.py +10 -7
  21. datachain/lib/signal_schema.py +39 -14
  22. datachain/lib/text.py +2 -1
  23. datachain/lib/udf.py +56 -5
  24. datachain/lib/udf_signature.py +1 -1
  25. datachain/lib/webdataset.py +4 -3
  26. datachain/node.py +11 -8
  27. datachain/query/dataset.py +66 -147
  28. datachain/query/dispatch.py +15 -13
  29. datachain/query/schema.py +2 -0
  30. datachain/query/session.py +4 -4
  31. datachain/sql/functions/array.py +12 -0
  32. datachain/sql/functions/string.py +8 -0
  33. datachain/torch/__init__.py +1 -1
  34. datachain/utils.py +45 -0
  35. datachain-0.2.12.dist-info/METADATA +412 -0
  36. {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/RECORD +40 -45
  37. {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/WHEEL +1 -1
  38. datachain/lib/feature_registry.py +0 -77
  39. datachain/lib/gpt4_vision.py +0 -97
  40. datachain/lib/hf_image_to_text.py +0 -97
  41. datachain/lib/hf_pipeline.py +0 -90
  42. datachain/lib/image_transform.py +0 -103
  43. datachain/lib/iptc_exif_xmp.py +0 -76
  44. datachain/lib/unstructured.py +0 -41
  45. datachain/text/__init__.py +0 -3
  46. datachain-0.2.10.dist-info/METADATA +0 -430
  47. {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/LICENSE +0 -0
  48. {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/entry_points.txt +0 -0
  49. {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/top_level.txt +0 -0
datachain/lib/webdataset.py CHANGED
@@ -13,8 +13,9 @@ from typing import (
     get_origin,
 )
 
-from pydantic import BaseModel, Field
+from pydantic import Field
 
+from datachain.lib.data_model import DataModel
 from datachain.lib.file import File, TarVFile
 from datachain.lib.utils import DataChainError
 
@@ -45,7 +46,7 @@ class UnknownFileExtensionError(WDSError):
         super().__init__(tar_stream, f"unknown extension '{ext}' for file '{name}'")
 
 
-class WDSBasic(BaseModel):
+class WDSBasic(DataModel):
     file: File
 
 
@@ -74,7 +75,7 @@ class WDSAllFile(WDSBasic):
     cbor: Optional[bytes] = Field(default=None)
 
 
-class WDSReadableSubclass(BaseModel):
+class WDSReadableSubclass(DataModel):
     @staticmethod
     def _reader(builder, item: tarfile.TarInfo) -> "WDSReadableSubclass":
         raise NotImplementedError
datachain/node.py CHANGED
@@ -5,7 +5,7 @@ import attrs
 
 from datachain.cache import UniqueId
 from datachain.storage import StorageURI
-from datachain.utils import time_to_str
+from datachain.utils import TIME_ZERO, time_to_str
 
 if TYPE_CHECKING:
     from typing_extensions import Self
@@ -111,13 +111,16 @@ class Node:
         if storage is None:
             storage = self.source
         return UniqueId(
-            storage,
-            self.parent,
-            self.name,
-            self.etag,
-            self.size,
-            self.vtype,
-            self.location,
+            storage=storage,
+            parent=self.parent,
+            name=self.name,
+            size=self.size,
+            version=self.version or "",
+            etag=self.etag,
+            is_latest=self.is_latest,
+            vtype=self.vtype,
+            location=self.location,
+            last_modified=self.last_modified or TIME_ZERO,
         )
 
     @classmethod
datachain/query/dataset.py CHANGED
@@ -1,4 +1,3 @@
-import ast
 import contextlib
 import datetime
 import inspect
@@ -10,7 +9,6 @@ import re
 import string
 import subprocess
 import sys
-import types
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Iterable, Iterator, Sequence
 from copy import copy
@@ -26,12 +24,9 @@ from typing import (
 )
 
 import attrs
-import pandas as pd
 import sqlalchemy
 from attrs import frozen
-from dill import dumps, source
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback, TqdmCallback
-from pydantic import BaseModel
 from sqlalchemy import Column
 from sqlalchemy.sql import func as f
 from sqlalchemy.sql.elements import ColumnClause, ColumnElement
@@ -53,10 +48,14 @@ from datachain.data_storage.schema import (
 from datachain.dataset import DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.progress import CombinedDownloadCallback
-from datachain.query.schema import DEFAULT_DELIMITER
 from datachain.sql.functions import rand
 from datachain.storage import Storage, StorageURI
-from datachain.utils import batched, determine_processes, inside_notebook
+from datachain.utils import (
+    batched,
+    determine_processes,
+    filtered_cloudpickle_dumps,
+    get_datachain_executable,
+)
 
 from .metrics import metrics
 from .schema import C, UDFParamSpec, normalize_param
@@ -428,7 +427,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback:
 
 
 @frozen
-class UDF(Step, ABC):
+class UDFStep(Step, ABC):
     udf: UDFType
     catalog: "Catalog"
     partition_by: Optional[PartitionByType] = None
@@ -492,7 +491,7 @@ class UDF(Step, ABC):
         elif processes:
             # Parallel processing (faster for more CPU-heavy UDFs)
             udf_info = {
-                "udf": self.udf,
+                "udf_data": filtered_cloudpickle_dumps(self.udf),
                 "catalog_init": self.catalog.get_init_params(),
                 "id_generator_clone_params": (
                     self.catalog.id_generator.clone_params()
@@ -509,20 +508,18 @@ class UDF(Step, ABC):
 
             # Run the UDFDispatcher in another process to avoid needing
             # if __name__ == '__main__': in user scripts
-            datachain_exec_path = os.environ.get("DATACHAIN_EXEC_PATH", "datachain")
-
+            exec_cmd = get_datachain_executable()
             envs = dict(os.environ)
             envs.update({"PYTHONPATH": os.getcwd()})
-            with self.process_feature_module():
-                process_data = dumps(udf_info, recurse=True)
-                result = subprocess.run(  # noqa: S603
-                    [datachain_exec_path, "--internal-run-udf"],
-                    input=process_data,
-                    check=False,
-                    env=envs,
-                )
-                if result.returncode != 0:
-                    raise RuntimeError("UDF Execution Failed!")
+            process_data = filtered_cloudpickle_dumps(udf_info)
+            result = subprocess.run(  # noqa: S603
+                [*exec_cmd, "internal-run-udf"],
+                input=process_data,
+                check=False,
+                env=envs,
+            )
+            if result.returncode != 0:
+                raise RuntimeError("UDF Execution Failed!")
 
         else:
             # Otherwise process single-threaded (faster for smaller UDFs)
@@ -571,57 +568,6 @@ class UDF(Step, ABC):
             self.catalog.warehouse.close()
             raise
 
-    @contextlib.contextmanager
-    def process_feature_module(self):
-        # Generate a random name for the feature module
-        feature_module_name = "tmp" + _random_string(10)
-        # Create a dynamic module with the generated name
-        dynamic_module = types.ModuleType(feature_module_name)
-        # Get the import lines for the necessary objects from the main module
-        main_module = sys.modules["__main__"]
-        if getattr(main_module, "__file__", None):
-            import_lines = list(get_imports(main_module))
-        else:
-            import_lines = [
-                source.getimport(obj, alias=name)
-                for name, obj in main_module.__dict__.items()
-                if _imports(obj) and not (name.startswith("__") and name.endswith("__"))
-            ]
-
-        # Get the feature classes from the main module
-        feature_classes = {
-            name: obj
-            for name, obj in main_module.__dict__.items()
-            if _feature_predicate(obj)
-        }
-        if not feature_classes:
-            yield None
-            return
-
-        # Get the source code of the feature classes
-        feature_sources = [source.getsource(cls) for _, cls in feature_classes.items()]
-        # Set the module name for the feature classes to the generated name
-        for name, cls in feature_classes.items():
-            cls.__module__ = feature_module_name
-            setattr(dynamic_module, name, cls)
-        # Add the dynamic module to the sys.modules dictionary
-        sys.modules[feature_module_name] = dynamic_module
-        # Combine the import lines and feature sources
-        feature_file = "\n".join(import_lines) + "\n" + "\n".join(feature_sources)
-
-        # Write the module content to a .py file
-        with open(f"{feature_module_name}.py", "w") as module_file:
-            module_file.write(feature_file)
-
-        try:
-            yield feature_module_name
-        finally:
-            for cls in feature_classes.values():
-                cls.__module__ = main_module.__name__
-            os.unlink(f"{feature_module_name}.py")
-            # Remove the dynamic module from sys.modules
-            del sys.modules[feature_module_name]
-
     def create_partitions_table(self, query: Select) -> "Table":
         """
         Create temporary table with group by partitions.
@@ -689,7 +635,7 @@ class UDF(Step, ABC):
 
 
 @frozen
-class UDFSignal(UDF):
+class UDFSignal(UDFStep):
     is_generator = False
 
     def create_udf_table(self, query: Select) -> "Table":
@@ -784,7 +730,7 @@ class UDFSignal(UDF):
 
 
 @frozen
-class RowGenerator(UDF):
+class RowGenerator(UDFStep):
     """Extend dataset with new rows."""
 
     is_generator = True
@@ -919,6 +865,18 @@ class SQLCount(SQLClause):
         return sqlalchemy.select(f.count(1)).select_from(query.subquery())
 
 
+@frozen
+class SQLDistinct(SQLClause):
+    args: tuple[ColumnElement, ...]
+    dialect: str
+
+    def apply_sql_clause(self, query):
+        if self.dialect == "sqlite":
+            return query.group_by(*self.args)
+
+        return query.distinct(*self.args)
+
+
 @frozen
 class SQLUnion(Step):
     query1: "DatasetQuery"
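Note on the new SQLDistinct step above: SQLite has no `DISTINCT ON (col, ...)`, so on that dialect the step falls back to `GROUP BY` over the same columns, which likewise keeps one row per unique combination; other dialects get a real `distinct(*cols)`. A minimal SQLAlchemy sketch of the same fallback, using an invented `files` table rather than datachain code:

```python
import sqlalchemy as sa
from sqlalchemy.dialects import sqlite


def distinct_on(query, *cols, dialect_name):
    # SQLite lacks DISTINCT ON, so grouping by the same columns is used
    # instead; each group contributes a single row to the result.
    if dialect_name == "sqlite":
        return query.group_by(*cols)
    return query.distinct(*cols)


files = sa.table("files", sa.column("parent"), sa.column("name"))
stmt = distinct_on(
    sa.select(files.c.parent, files.c.name), files.c.parent, dialect_name="sqlite"
)
print(stmt.compile(dialect=sqlite.dialect()))  # ... GROUP BY files.parent
```

Further down in this file the step is exposed as `DatasetQuery.distinct(*args)`, which passes the warehouse's dialect name in exactly this way.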
@@ -1000,12 +958,15 @@ class SQLJoin(Step):
 
         q1_columns = list(q1.c)
         q1_column_names = {c.name for c in q1_columns}
-        q2_columns = [
-            c
-            if c.name not in q1_column_names and c.name != "sys__id"
-            else c.label(self.rname.format(name=c.name))
-            for c in q2.c
-        ]
+
+        q2_columns = []
+        for c in q2.c:
+            if c.name.startswith("sys__"):
+                continue
+
+            if c.name in q1_column_names:
+                c = c.label(self.rname.format(name=c.name))
+            q2_columns.append(c)
 
         res_columns = q1_columns + q2_columns
         predicates = (
@@ -1112,6 +1073,7 @@ class DatasetQuery:
         anon: bool = False,
         indexing_feature_schema: Optional[dict] = None,
         indexing_column_types: Optional[dict[str, Any]] = None,
+        update: Optional[bool] = False,
     ):
         if client_config is None:
             client_config = {}
@@ -1134,10 +1096,12 @@ class DatasetQuery:
         self.session = Session.get(session, catalog=catalog)
 
         if path:
-            self.starting_step = IndexingStep(path, self.catalog, {}, recursive)
+            kwargs = {"update": True} if update else {}
+            self.starting_step = IndexingStep(path, self.catalog, kwargs, recursive)
             self.feature_schema = indexing_feature_schema
             self.column_types = indexing_column_types
         elif name:
+            self.name = name
             ds = self.catalog.get_dataset(name)
             self.version = version or ds.latest_version
             self.feature_schema = ds.get_version(self.version).feature_schema
@@ -1145,9 +1109,6 @@ class DatasetQuery:
             if "sys__id" in self.column_types:
                 self.column_types.pop("sys__id")
             self.starting_step = QueryStep(self.catalog, name, self.version)
-            # attaching to specific dataset
-            self.name = name
-            self.version = version
         else:
             raise ValueError("must provide path or name")
 
@@ -1156,7 +1117,7 @@ class DatasetQuery:
         return bool(re.compile(r"^[a-zA-Z0-9]+://").match(path))
 
     def __iter__(self):
-        return iter(self.results())
+        return iter(self.db_results())
 
     def __or__(self, other):
         return self.union(other)
@@ -1277,13 +1238,16 @@ class DatasetQuery:
             warehouse.close()
         self.temp_table_names = []
 
-    def results(self, row_factory=None, **kwargs):
+    def db_results(self, row_factory=None, **kwargs):
        with self.as_iterable(**kwargs) as result:
            if row_factory:
                cols = result.columns
                return [row_factory(cols, r) for r in result]
            return list(result)
 
+    def to_db_records(self) -> list[dict[str, Any]]:
+        return self.db_results(lambda cols, row: dict(zip(cols, row)))
+
     @contextlib.contextmanager
     def as_iterable(self, **kwargs) -> Iterator[ResultIter]:
         try:
@@ -1343,15 +1307,6 @@ class DatasetQuery:
         finally:
             self.cleanup()
 
-    def to_records(self) -> list[dict[str, Any]]:
-        return self.results(lambda cols, row: dict(zip(cols, row)))
-
-    def to_pandas(self) -> "pd.DataFrame":
-        records = self.to_records()
-        df = pd.DataFrame.from_records(records)
-        df.columns = [c.replace(DEFAULT_DELIMITER, ".") for c in df.columns]
-        return df
-
     def shuffle(self) -> "Self":
         # ToDo: implement shaffle based on seed and/or generating random column
         return self.order_by(C.sys__rand)
@@ -1370,22 +1325,6 @@ class DatasetQuery:
 
         return sampled.limit(n)
 
-    def show(self, limit=20) -> None:
-        df = self.limit(limit).to_pandas()
-
-        options = ["display.max_colwidth", 50, "display.show_dimensions", False]
-        with pd.option_context(*options):
-            if inside_notebook():
-                from IPython.display import display
-
-                display(df)
-
-            else:
-                print(df.to_string())
-
-        if len(df) == limit:
-            print(f"[limited by {limit} objects]")
-
     def clone(self, new_table=True) -> "Self":
         obj = copy(self)
         obj.steps = obj.steps.copy()
@@ -1483,6 +1422,14 @@ class DatasetQuery:
         query.steps.append(SQLOffset(offset))
         return query
 
+    @detach
+    def distinct(self, *args) -> "Self":
+        query = self.clone()
+        query.steps.append(
+            SQLDistinct(args, dialect=self.catalog.warehouse.db.dialect.name)
+        )
+        return query
+
     def as_scalar(self) -> Any:
         with self.as_iterable() as rows:
             row = next(iter(rows))
@@ -1781,10 +1728,13 @@ def _send_result(dataset_query: DatasetQuery) -> None:
 
     columns = preview_args.get("columns") or []
 
-    preview_query = (
-        dataset_query.select(*columns)
-        .limit(preview_args.get("limit", 10))
-        .offset(preview_args.get("offset", 0))
+    if type(dataset_query) is DatasetQuery:
+        preview_query = dataset_query.select(*columns)
+    else:
+        preview_query = dataset_query.select(*columns, _sys=False)
+
+    preview_query = preview_query.limit(preview_args.get("limit", 10)).offset(
+        preview_args.get("offset", 0)
     )
 
     dataset: Optional[tuple[str, int]] = None
@@ -1793,7 +1743,7 @@ def _send_result(dataset_query: DatasetQuery) -> None:
         assert dataset_query.version, "Dataset version should be provided"
         dataset = dataset_query.name, dataset_query.version
 
-    preview = preview_query.to_records()
+    preview = preview_query.to_db_records()
     result = ExecutionResult(preview, dataset, metrics)
     data = attrs.asdict(result)
 
@@ -1853,34 +1803,3 @@ def _random_string(length: int) -> str:
         random.choice(string.ascii_letters + string.digits)  # noqa: S311
         for i in range(length)
     )
-
-
-def _feature_predicate(obj):
-    return (
-        inspect.isclass(obj) and source.isfrommain(obj) and issubclass(obj, BaseModel)
-    )
-
-
-def _imports(obj):
-    return not source.isfrommain(obj)
-
-
-def get_imports(m):
-    root = ast.parse(inspect.getsource(m))
-
-    for node in ast.iter_child_nodes(root):
-        if isinstance(node, ast.Import):
-            module = None
-        elif isinstance(node, ast.ImportFrom):
-            module = node.module
-        else:
-            continue
-
-        for n in node.names:
-            import_script = ""
-            if module:
-                import_script += f"from {module} "
-            import_script += f"import {n.name}"
-            if n.asname:
-                import_script += f" as {n.asname}"
-            yield import_script
datachain/query/dispatch.py CHANGED
@@ -10,7 +10,7 @@ from typing import Any, Optional
 
 import attrs
 import multiprocess
-from dill import load
+from cloudpickle import load, loads
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from multiprocess import get_context
 
@@ -84,7 +84,7 @@ def put_into_queue(queue: Queue, item: Any) -> None:
 
 def udf_entrypoint() -> int:
     # Load UDF info from stdin
-    udf_info = load(stdin.buffer)  # noqa: S301
+    udf_info = load(stdin.buffer)
 
     (
         warehouse_class,
@@ -95,7 +95,7 @@ def udf_entrypoint() -> int:
 
     # Parallel processing (faster for more CPU-heavy UDFs)
     dispatch = UDFDispatcher(
-        udf_info["udf"],
+        udf_info["udf_data"],
         udf_info["catalog_init"],
         udf_info["id_generator_clone_params"],
         udf_info["metastore_clone_params"],
@@ -108,7 +108,7 @@ def udf_entrypoint() -> int:
     batching = udf_info["batching"]
     table = udf_info["table"]
     n_workers = udf_info["processes"]
-    udf = udf_info["udf"]
+    udf = loads(udf_info["udf_data"])
     if n_workers is True:
         # Use default number of CPUs (cores)
         n_workers = None
@@ -146,7 +146,7 @@ class UDFDispatcher:
 
     def __init__(
        self,
-        udf,
+        udf_data,
        catalog_init_params,
        id_generator_clone_params,
        metastore_clone_params,
@@ -155,14 +155,7 @@ class UDFDispatcher:
         is_generator=False,
         buffer_size=DEFAULT_BATCH_SIZE,
     ):
-        # isinstance cannot be used here, as dill packages the entire class definition,
-        # and so these two types are not considered exactly equal,
-        # even if they have the same import path.
-        if full_module_type_path(type(udf)) != full_module_type_path(UDFFactory):
-            self.udf = udf
-        else:
-            self.udf = None
-            self.udf_factory = udf
+        self.udf_data = udf_data
         self.catalog_init_params = catalog_init_params
         (
            self.id_generator_class,
@@ -214,6 +207,15 @@ class UDFDispatcher:
             self.catalog = Catalog(
                 id_generator, metastore, warehouse, **self.catalog_init_params
             )
+        udf = loads(self.udf_data)
+        # isinstance cannot be used here, as cloudpickle packages the entire class
+        # definition, and so these two types are not considered exactly equal,
+        # even if they have the same import path.
+        if full_module_type_path(type(udf)) != full_module_type_path(UDFFactory):
+            self.udf = udf
+        else:
+            self.udf = None
+            self.udf_factory = udf
         if not self.udf:
             self.udf = self.udf_factory()
 
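Together with the dataset.py changes above, the UDF now crosses the process boundary as cloudpickle bytes: the parent serializes `udf_info` with `filtered_cloudpickle_dumps` and pipes it to the `internal-run-udf` subprocess, which deserializes it from stdin before dispatching work. A rough stand-alone sketch of that round trip (the `double` function and the inline child script are placeholders, not datachain code):

```python
import subprocess
import sys

import cloudpickle


# Parent side: serialize a callable plus its settings into one bytes payload.
def double(x: int) -> int:
    return x * 2


payload = cloudpickle.dumps({"udf": double, "batch": [1, 2, 3]})

# Child side (normally `datachain internal-run-udf` reading stdin); here a tiny
# inline script loads the payload and applies the callable to each item.
child_src = (
    "import sys, cloudpickle\n"
    "info = cloudpickle.load(sys.stdin.buffer)\n"
    "print([info['udf'](x) for x in info['batch']])\n"
)
result = subprocess.run(
    [sys.executable, "-c", child_src], input=payload, capture_output=True, check=True
)
print(result.stdout.decode().strip())  # [2, 4, 6]
```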
datachain/query/schema.py CHANGED
@@ -32,6 +32,7 @@ class Column(sa.ColumnClause, metaclass=ColumnMeta):
     inherit_cache: Optional[bool] = True
 
     def __init__(self, text, type_=None, is_literal=False, _selectable=None):
+        """Dataset column."""
         self.name = ColumnMeta.to_db_name(text)
         super().__init__(
             self.name, type_=type_, is_literal=is_literal, _selectable=_selectable
@@ -41,6 +42,7 @@ class Column(sa.ColumnClause, metaclass=ColumnMeta):
         return Column(self.name + DEFAULT_DELIMITER + name)
 
     def glob(self, glob_str):
+        """Search for matches using glob pattern matching."""
         return self.op("GLOB")(glob_str)
 
 
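Since `glob` just wraps SQLAlchemy's generic `op("GLOB")`, the new docstring documents existing behaviour: the result is an ordinary boolean clause. A small illustrative sketch (the column and dataset names here are invented):

```python
from datachain.query.schema import C

# Keep only JPEG objects by glob-matching the file name; "file.name" is
# flattened to the "file__name" database column by ColumnMeta.to_db_name.
jpeg_only = C("file.name").glob("*.jpg")

# The expression can be passed to a dataset query like any other
# SQLAlchemy clause, e.g. DatasetQuery(name="photos").filter(jpeg_only).
```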
datachain/query/session.py CHANGED
@@ -28,9 +28,9 @@ class Session:
 
     Parameters:
 
-        `name` (str): The name of the session. Only latters and numbers are supported.
+        name (str): The name of the session. Only latters and numbers are supported.
             It can be empty.
-        `catalog` (Catalog): Catalog object.
+        catalog (Catalog): Catalog object.
     """
 
     GLOBAL_SESSION_CTX: Optional["Session"] = None
@@ -80,9 +80,9 @@ class Session:
         """Creates a Session() object from a catalog.
 
         Parameters:
-            `session` (Session): Optional Session(). If not provided a new session will
+            session (Session): Optional Session(). If not provided a new session will
                 be created. It's needed mostly for simplie API purposes.
-            `catalog` (Catalog): Optional catalog. By default a new catalog is created.
+            catalog (Catalog): Optional catalog. By default a new catalog is created.
         """
         if session:
             return session
datachain/sql/functions/array.py CHANGED
@@ -5,6 +5,10 @@ from datachain.sql.utils import compiler_not_implemented
 
 
 class cosine_distance(GenericFunction):  # noqa: N801
+    """
+    Takes a column and array and returns the cosine distance between them.
+    """
+
     type = Float()
     package = "array"
     name = "cosine_distance"
@@ -12,6 +16,10 @@ class cosine_distance(GenericFunction):  # noqa: N801
 
 
 class euclidean_distance(GenericFunction):  # noqa: N801
+    """
+    Takes a column and array and returns the Euclidean distance between them.
+    """
+
     type = Float()
     package = "array"
     name = "euclidean_distance"
@@ -19,6 +27,10 @@ class euclidean_distance(GenericFunction):  # noqa: N801
 
 
 class length(GenericFunction):  # noqa: N801
+    """
+    Returns the length of the array.
+    """
+
     type = Int64()
     package = "array"
     name = "length"
datachain/sql/functions/string.py CHANGED
@@ -5,6 +5,10 @@ from datachain.sql.utils import compiler_not_implemented
 
 
 class length(GenericFunction):  # noqa: N801
+    """
+    Returns the length of the string.
+    """
+
     type = Int64()
     package = "string"
     name = "length"
@@ -12,6 +16,10 @@ class length(GenericFunction):  # noqa: N801
 
 
 class split(GenericFunction):  # noqa: N801
+    """
+    Takes a column and split character and returns an array of the parts.
+    """
+
     type = Array(String())
     package = "string"
     name = "split"
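The `array` and `string` helpers documented above are SQLAlchemy `GenericFunction` subclasses, so instantiating one only builds a SQL expression; the warehouse's dialect decides how it compiles. A sketch of constructing such expressions (the column names and query vector are invented, and nothing is executed here):

```python
from datachain.query.schema import C
from datachain.sql.functions import array, string

# String helpers: length of a path column and splitting it on "/".
path_len = string.length(C("file.path"))
path_parts = string.split(C("file.path"), "/")

# Array helper: cosine distance between an embedding column and a query vector.
dist = array.cosine_distance(C("embedding"), [0.1, 0.2, 0.3])

# These expressions are then used inside dataset queries (mutate/filter/order_by),
# where the active dialect supplies the concrete SQL implementation.
```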
datachain/torch/__init__.py CHANGED
@@ -1,5 +1,5 @@
 try:
-    from datachain.lib.clip import similarity_scores as clip_similarity_scores
+    from datachain.lib.clip import clip_similarity_scores
     from datachain.lib.image import convert_image, convert_images
     from datachain.lib.pytorch import PytorchDataset, label_to_int
     from datachain.lib.text import convert_text
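Only the internal alias changed here: the function in `datachain.lib.clip` is now defined as `clip_similarity_scores`, and the name re-exported from `datachain.torch` stays the same, so existing imports should keep working (assuming the optional torch dependencies are installed):

```python
# Same public import in 0.2.10 and 0.2.12; only the internal alias went away.
from datachain.torch import clip_similarity_scores
```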
datachain/utils.py CHANGED
@@ -1,5 +1,6 @@
 import glob
 import importlib.util
+import io
 import json
 import os
 import os.path as osp
@@ -13,8 +14,10 @@ from itertools import islice
 from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
 from uuid import UUID
 
+import cloudpickle
 from dateutil import tz
 from dateutil.parser import isoparse
+from pydantic import BaseModel
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -388,3 +391,45 @@ def inside_notebook() -> bool:
             return False
 
     return False
+
+
+def get_all_subclasses(cls):
+    """Return all subclasses of a given class.
+    Can return duplicates due to multiple inheritance."""
+    for subclass in cls.__subclasses__():
+        yield from get_all_subclasses(subclass)
+        yield subclass
+
+
+def filtered_cloudpickle_dumps(obj: Any) -> bytes:
+    """Equivalent to cloudpickle.dumps, but this supports Pydantic models."""
+    model_namespaces = {}
+
+    with io.BytesIO() as f:
+        pickler = cloudpickle.CloudPickler(f)
+
+        for model_class in get_all_subclasses(BaseModel):
+            # This "is not None" check is needed, because due to multiple inheritance,
+            # it is theoretically possible to get the same class twice from
+            # get_all_subclasses.
+            if model_class.__pydantic_parent_namespace__ is not None:
+                # __pydantic_parent_namespace__ can contain many unnecessary and
+                # unpickleable entities, so should be removed for serialization.
+                model_namespaces[model_class] = (
+                    model_class.__pydantic_parent_namespace__
+                )
+                model_class.__pydantic_parent_namespace__ = None
+
+        try:
+            pickler.dump(obj)
+            return f.getvalue()
+        finally:
+            for model_class, namespace in model_namespaces.items():
+                # Restore original __pydantic_parent_namespace__ locally.
+                model_class.__pydantic_parent_namespace__ = namespace
+
+
+def get_datachain_executable() -> list[str]:
+    if datachain_exec_path := os.getenv("DATACHAIN_EXEC_PATH"):
+        return [datachain_exec_path]
+    return [sys.executable, "-m", "datachain"]
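The payload produced by `filtered_cloudpickle_dumps` is ordinary cloudpickle output, so the receiving side restores it with a plain `cloudpickle.loads`. A minimal round-trip sketch (the `Point` model is a made-up stand-in for user-defined signal classes, not datachain code):

```python
import cloudpickle
from pydantic import BaseModel

from datachain.utils import filtered_cloudpickle_dumps


class Point(BaseModel):
    # Hypothetical Pydantic model; its __pydantic_parent_namespace__ is
    # temporarily stripped during pickling and restored afterwards.
    x: float
    y: float


data = filtered_cloudpickle_dumps({"cls": Point, "row": Point(x=1.0, y=2.0)})
restored = cloudpickle.loads(data)
print(restored["row"], restored["cls"](x=3.0, y=4.0))
```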