datachain 0.1.13__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- datachain/__init__.py +0 -4
- datachain/asyn.py +3 -3
- datachain/catalog/__init__.py +3 -3
- datachain/catalog/catalog.py +6 -6
- datachain/catalog/loader.py +3 -3
- datachain/cli.py +10 -2
- datachain/client/azure.py +37 -1
- datachain/client/fsspec.py +1 -1
- datachain/client/local.py +1 -1
- datachain/data_storage/__init__.py +1 -1
- datachain/data_storage/metastore.py +11 -3
- datachain/data_storage/schema.py +12 -7
- datachain/data_storage/sqlite.py +3 -0
- datachain/data_storage/warehouse.py +31 -30
- datachain/dataset.py +1 -3
- datachain/lib/arrow.py +85 -0
- datachain/lib/cached_stream.py +3 -85
- datachain/lib/dc.py +382 -179
- datachain/lib/feature.py +46 -91
- datachain/lib/feature_registry.py +4 -1
- datachain/lib/feature_utils.py +2 -2
- datachain/lib/file.py +30 -44
- datachain/lib/image.py +9 -2
- datachain/lib/meta_formats.py +66 -34
- datachain/lib/settings.py +5 -5
- datachain/lib/signal_schema.py +103 -105
- datachain/lib/udf.py +10 -38
- datachain/lib/udf_signature.py +11 -6
- datachain/lib/webdataset_laion.py +5 -22
- datachain/listing.py +8 -8
- datachain/node.py +1 -1
- datachain/progress.py +1 -1
- datachain/query/builtins.py +1 -1
- datachain/query/dataset.py +42 -119
- datachain/query/dispatch.py +1 -1
- datachain/query/metrics.py +19 -0
- datachain/query/schema.py +13 -3
- datachain/sql/__init__.py +1 -1
- datachain/sql/sqlite/base.py +34 -2
- datachain/sql/sqlite/vector.py +13 -5
- datachain/utils.py +1 -122
- {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/METADATA +11 -4
- {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/RECORD +47 -47
- {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/WHEEL +1 -1
- datachain/_version.py +0 -16
- datachain/lib/parquet.py +0 -32
- {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/LICENSE +0 -0
- {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/entry_points.txt +0 -0
- {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/top_level.txt +0 -0
datachain/query/dataset.py
CHANGED
@@ -56,13 +56,13 @@ from datachain.storage import Storage, StorageURI
 from datachain.utils import batched, determine_processes
 
 from .batch import RowBatch
+from .metrics import metrics
 from .schema import C, UDFParamSpec, normalize_param
 from .session import Session
 from .udf import UDFBase, UDFClassWrapper, UDFFactory, UDFType
 
 if TYPE_CHECKING:
     import pandas as pd
-    from pandas.core.interchange.dataframe_protocol import DataFrame as DataFrameXchg
     from sqlalchemy.sql.elements import ClauseElement
     from sqlalchemy.sql.schema import Table
     from sqlalchemy.sql.selectable import GenerativeSelect

@@ -71,7 +71,6 @@ if TYPE_CHECKING:
     from datachain.catalog import Catalog
     from datachain.data_storage import AbstractWarehouse
     from datachain.dataset import DatasetRecord
-    from datachain.sql.types import SQLType
 
     from .udf import UDFResult
 

@@ -197,7 +196,7 @@ class IndexingStep(StartingStep):
     def apply(self):
         self.catalog.index([self.path], **self.kwargs)
         uri, path = self.parse_path()
-
+        _partial_id, partial_path = self.catalog.metastore.get_valid_partial_id(
             uri, path
         )
         dataset = self.catalog.get_dataset(Storage.dataset_name(uri, partial_path))

@@ -523,30 +522,23 @@ class UDF(Step, ABC):
                 "cache": self.cache,
             }
 
-            feature_module_name, feature_file = self.process_feature_module()
-
-            # Write the module content to a .py file
-            with open(f"{feature_module_name}.py", "w") as module_file:
-                module_file.write(feature_file)
-
-            process_data = dumps(udf_info, recurse=True)
             # Run the UDFDispatcher in another process to avoid needing
             # if __name__ == '__main__': in user scripts
             datachain_exec_path = os.environ.get("DATACHAIN_EXEC_PATH", "datachain")
 
             envs = dict(os.environ)
             envs.update({"PYTHONPATH": os.getcwd()})
-
-
-
+            with self.process_feature_module():
+                process_data = dumps(udf_info, recurse=True)
+                result = subprocess.run(  # noqa: S603
+                    [datachain_exec_path, "--internal-run-udf"],
                     input=process_data,
                     check=False,
                     env=envs,
                 )
                 if result.returncode != 0:
                     raise RuntimeError("UDF Execution Failed!")
-
-            os.unlink(f"{feature_module_name}.py")
+
         else:
             # Otherwise process single-threaded (faster for smaller UDFs)
             # Optionally instantiate the UDF instance if a class is provided.

@@ -600,6 +592,7 @@ class UDF(Step, ABC):
                 self.catalog.warehouse.close()
                 raise
 
+    @contextlib.contextmanager
     def process_feature_module(self):
         # Generate a random name for the feature module
         feature_module_name = "tmp" + _random_string(10)

@@ -611,10 +604,14 @@ class UDF(Step, ABC):
             for name, obj in inspect.getmembers(sys.modules["__main__"], _imports)
             if not (name.startswith("__") and name.endswith("__"))
         ]
+        main_module = sys.modules["__main__"]
+
         # Get the feature classes from the main module
-        feature_classes =
-
-
+        feature_classes = {
+            name: obj
+            for name, obj in main_module.__dict__.items()
+            if _feature_predicate(obj)
+        }
         # Get the source code of the feature classes
         feature_sources = [source.getsource(cls) for _, cls in feature_classes.items()]
         # Set the module name for the feature classes to the generated name

@@ -626,7 +623,18 @@ class UDF(Step, ABC):
         # Combine the import lines and feature sources
         feature_file = "".join(import_lines) + "\n".join(feature_sources)
 
-
+        # Write the module content to a .py file
+        with open(f"{feature_module_name}.py", "w") as module_file:
+            module_file.write(feature_file)
+
+        try:
+            yield feature_module_name
+        finally:
+            for cls in feature_classes.values():
+                cls.__module__ = main_module.__name__
+            os.unlink(f"{feature_module_name}.py")
+            # Remove the dynamic module from sys.modules
+            del sys.modules[feature_module_name]
 
     def create_partitions_table(self, query: Select) -> "Table":
         """

@@ -685,8 +693,7 @@ class UDF(Step, ABC):
         )
 
         query, tables = self.process_input_query(query)
-        for t in tables:
-            temp_tables.append(t.name)
+        temp_tables.extend(t.name for t in tables)
         udf_table = self.create_udf_table(_query)
         temp_tables.append(udf_table.name)
         self.populate_udf_table(udf_table, query)

@@ -1120,6 +1127,12 @@ class DatasetQuery:
         indexing_feature_schema: Optional[dict] = None,
         indexing_column_types: Optional[dict[str, Any]] = None,
     ):
+        if client_config is None:
+            client_config = {}
+
+        if anon:
+            client_config["anon"] = True
+
         self.steps: list[Step] = []
         self.catalog = catalog or get_catalog(client_config=client_config)
         self._chunk_index: Optional[int] = None

@@ -1134,22 +1147,14 @@ class DatasetQuery:
         self.column_types: Optional[dict[str, Any]] = None
         self.session = Session.get(session, catalog=catalog)
 
-        if client_config is None:
-            client_config = {}
-
-        if anon:
-            client_config["anon"] = True
-
         if path:
-            self.starting_step = IndexingStep(
-                path, self.catalog, {"client_config": client_config}, recursive
-            )
+            self.starting_step = IndexingStep(path, self.catalog, {}, recursive)
             self.feature_schema = indexing_feature_schema
             self.column_types = indexing_column_types
         elif name:
             ds = self.catalog.get_dataset(name)
             self.version = version or ds.latest_version
-            self.feature_schema = ds.feature_schema
+            self.feature_schema = ds.get_version(self.version).feature_schema
             self.column_types = copy(ds.schema)
             if "id" in self.column_types:
                 self.column_types.pop("id")

@@ -1348,8 +1353,7 @@ class DatasetQuery:
             MapperCls = OrderedMapper if query._order_by_clauses else AsyncMapper  # noqa: N806
             with contextlib.closing(row_iter()) as rows:
                 mapper = MapperCls(get_params, rows, workers=workers)
-
-                yield params
+                yield from mapper.iterate()
         finally:
             self.cleanup()

@@ -1386,82 +1390,6 @@ class DatasetQuery:
         records = self.to_records()
         return pd.DataFrame.from_records(records)
 
-    @classmethod
-    def from_dataframe(
-        cls,
-        df: Union["DataFrameXchg", "pd.DataFrame"],
-        name: str = "",
-        version: Optional[int] = None,
-        catalog: Optional["Catalog"] = None,
-        session: Optional[Session] = None,
-    ) -> "Self":
-        from datachain.utils import dtype_mapper
-
-        catalog = catalog or get_catalog()
-        assert catalog is not None
-        session = Session.get(session, catalog=catalog)
-        assert session is not None
-
-        try:
-            if name and version and catalog.get_dataset(name).has_version(version):
-                raise RuntimeError(f"Dataset {name} already has version {version}")
-        except DatasetNotFoundError:
-            pass
-
-        if not name and version:
-            raise RuntimeError("Cannot set version for temporary datasets")
-
-        import pandas as pd  # noqa: F401
-        from pandas.api.interchange import from_dataframe
-
-        # This is not optimal for dataframes other than pd.DataFrame, as it may copy
-        # all the data to a new dataframe.
-        pd_df = from_dataframe(df)
-
-        dtype: dict[str, type[SQLType]] = {
-            str(pd_df.columns[i]): dtype_mapper(pd_df.iloc[:, i])
-            for i in range(len(pd_df.columns))
-        }
-
-        name = name or session.generate_temp_dataset_name()
-        dataset = catalog.create_dataset(
-            name,
-            version=version,
-            columns=[Column(name, typ) for name, typ in dtype.items()],
-        )
-        version = version or dataset.latest_version
-
-        dr = catalog.warehouse.dataset_rows(dataset)
-        pd_df.to_sql(
-            dr.table.name,
-            catalog.warehouse.db.engine,
-            if_exists="append",
-            index=False,
-            chunksize=10_000,
-            dtype=dtype,
-        )
-
-        catalog.metastore.update_dataset_status(
-            dataset, DatasetStatus.COMPLETE, version=version
-        )
-        catalog.update_dataset_version_with_warehouse_info(dataset, version)
-        return cls(name=name, version=version, catalog=catalog, session=session)
-
-    from_pandas = from_dataframe
-
-    @classmethod
-    def from_parquet(
-        cls,
-        uri: str,
-        *args,
-        **kwargs,
-    ) -> "Self":
-        import pandas as pd
-
-        pd_df = pd.read_parquet(uri, dtype_backend="pyarrow")
-
-        return cls.from_dataframe(pd_df, *args, **kwargs)
-
     def shuffle(self) -> "Self":
         # ToDo: implement shaffle based on seed and/or generating random column
         return self.order_by(C.random)

@@ -1809,22 +1737,16 @@ class DatasetQuery:
 
         # Exclude the id column and let the db create it to avoid unique
         # constraint violations.
-        cols = [col.name for col in dr.get_table().c if col.name != "id"]
-        assert cols
         q = query.exclude(("id",))
-
         if q._order_by_clauses:
             # ensuring we have id sorted by order by clause if it exists in a query
             q = q.add_columns(
                 f.row_number().over(order_by=q._order_by_clauses).label("id")
             )
-            cols.append("id")
-
-        self.catalog.warehouse.db.execute(
-            sqlalchemy.insert(dr.get_table()).from_select(cols, q),
-            **kwargs,
-        )
 
+        cols = tuple(c.name for c in q.columns)
+        insert_q = sqlalchemy.insert(dr.get_table()).from_select(cols, q)
+        self.catalog.warehouse.db.execute(insert_q, **kwargs)
         self.catalog.metastore.update_dataset_status(
             dataset, DatasetStatus.COMPLETE, version=version
         )

@@ -1853,6 +1775,7 @@ def _get_output_fd_for_write() -> Union[str, int]:
 class ExecutionResult:
     preview: list[dict] = attrs.field(factory=list)
     dataset: Optional[tuple[str, int]] = None
+    metrics: dict[str, Any] = attrs.field(factory=dict)
 
 
 def _send_result(dataset_query: DatasetQuery) -> None:

@@ -1886,7 +1809,7 @@ def _send_result(dataset_query: DatasetQuery) -> None:
     dataset = dataset_query.name, dataset_query.version
 
     preview = preview_query.to_records()
-    result = ExecutionResult(preview, dataset)
+    result = ExecutionResult(preview, dataset, metrics)
     data = attrs.asdict(result)
 
     with open(_get_output_fd_for_write(), mode="w") as f:
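Note: the main behavioural change above is that process_feature_module is now a contextlib.contextmanager, so the temporary tmp<random>.py module written for the parallel UDF run is always unlinked (and dropped from sys.modules) even when the subprocess fails. A minimal standalone sketch of that pattern, with illustrative names that are not datachain's actual API:

    import contextlib
    import os
    import sys
    import uuid


    @contextlib.contextmanager
    def temporary_module(source: str):
        """Write `source` to a throwaway .py file and guarantee cleanup afterwards."""
        name = "tmp" + uuid.uuid4().hex[:10]  # random module name, like datachain's tmp<random>
        with open(f"{name}.py", "w") as f:
            f.write(source)
        try:
            yield name
        finally:
            os.unlink(f"{name}.py")        # runs even if the body raised
            sys.modules.pop(name, None)    # drop it if something imported it


    with temporary_module("X = 42\n") as mod:
        print(f"{mod}.py is on disk while the UDF subprocess would run")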
datachain/query/dispatch.py
CHANGED
@@ -257,7 +257,7 @@ class UDFDispatcher:
 
         if self.buffer_size < n_workers:
             raise RuntimeError(
-
+                "Parallel run error: buffer size is smaller than "
                 f"number of workers: {self.buffer_size} < {n_workers}"
             )
 
datachain/query/metrics.py
ADDED

@@ -0,0 +1,19 @@
+from typing import Optional, Union
+
+metrics: dict[str, Union[str, int, float, bool, None]] = {}
+
+
+def set(key: str, value: Union[str, int, float, bool, None]) -> None:  # noqa: PYI041
+    """Set a metric value."""
+    if not isinstance(key, str):
+        raise TypeError("Key must be a string")
+    if not key:
+        raise ValueError("Key must not be empty")
+    if not isinstance(value, (str, int, float, bool, type(None))):
+        raise TypeError("Value must be a string, int, float or bool")
+    metrics[key] = value
+
+
+def get(key: str) -> Optional[Union[str, int, float, bool]]:
+    """Get a metric value."""
+    return metrics[key]
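Note: this module is new in 0.2.1, and dataset.py (above) now attaches its metrics dict to ExecutionResult when reporting results. A short usage sketch of the set/get helpers defined above; the query-script context and the metric names are hypothetical:

    from datachain.query import metrics

    # Inside a query script: record simple scalar metrics...
    metrics.set("files_processed", 1024)
    metrics.set("model", "resnet50")

    # ...and read them back; values are restricted to str, int, float, bool or None.
    assert metrics.get("files_processed") == 1024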
datachain/query/schema.py
CHANGED
@@ -18,20 +18,30 @@ if TYPE_CHECKING:
     from datachain.dataset import RowDict
 
 
+DEFAULT_DELIMITER = "__"
+
+
 class ColumnMeta(type):
+    @staticmethod
+    def to_db_name(name: str) -> str:
+        return name.replace(".", DEFAULT_DELIMITER)
+
     def __getattr__(cls, name: str):
-        return cls(name)
+        return cls(ColumnMeta.to_db_name(name))
 
 
 class Column(sa.ColumnClause, metaclass=ColumnMeta):
     inherit_cache: Optional[bool] = True
 
     def __init__(self, text, type_=None, is_literal=False, _selectable=None):
-        self.name = text
+        self.name = ColumnMeta.to_db_name(text)
         super().__init__(
-
+            self.name, type_=type_, is_literal=is_literal, _selectable=_selectable
         )
 
+    def __getattr__(self, name: str):
+        return Column(self.name + DEFAULT_DELIMITER + name)
+
     def glob(self, glob_str):
         return self.op("GLOB")(glob_str)
 
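Note: with DEFAULT_DELIMITER = "__", both dotted names and chained attribute access on Column are flattened into database column names. A small sketch, assuming C is the Column alias that dataset.py imports from this module:

    from datachain.query.schema import C, Column

    print(Column("laion.caption").name)  # "laion__caption"
    print(C.laion.caption.name)          # "laion__caption", via the nested __getattr__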
datachain/sql/__init__.py
CHANGED
datachain/sql/sqlite/base.py
CHANGED
@@ -71,8 +71,6 @@ def setup():
     compiles(sql_path.name, "sqlite")(compile_path_name)
     compiles(sql_path.file_stem, "sqlite")(compile_path_file_stem)
     compiles(sql_path.file_ext, "sqlite")(compile_path_file_ext)
-    compiles(array.cosine_distance, "sqlite")(compile_cosine_distance)
-    compiles(array.euclidean_distance, "sqlite")(compile_euclidean_distance)
     compiles(array.length, "sqlite")(compile_array_length)
     compiles(string.length, "sqlite")(compile_string_length)
     compiles(string.split, "sqlite")(compile_string_split)

@@ -81,6 +79,13 @@ def setup():
     compiles(Values, "sqlite")(compile_values)
     compiles(random.rand, "sqlite")(compile_rand)
 
+    if load_usearch_extension(sqlite3.connect(":memory:")):
+        compiles(array.cosine_distance, "sqlite")(compile_cosine_distance_ext)
+        compiles(array.euclidean_distance, "sqlite")(compile_euclidean_distance_ext)
+    else:
+        compiles(array.cosine_distance, "sqlite")(compile_cosine_distance)
+        compiles(array.euclidean_distance, "sqlite")(compile_euclidean_distance)
+
     register_user_defined_sql_functions()
     setup_is_complete = True
 

@@ -246,11 +251,23 @@ def compile_path_file_ext(element, compiler, **kwargs):
     return compiler.process(path_file_ext(*element.clauses.clauses), **kwargs)
 
 
+def compile_cosine_distance_ext(element, compiler, **kwargs):
+    run_compiler_hook("cosine_distance")
+    return f"distance_cosine_f32({compiler.process(element.clauses, **kwargs)})"
+
+
 def compile_cosine_distance(element, compiler, **kwargs):
     run_compiler_hook("cosine_distance")
     return f"cosine_distance({compiler.process(element.clauses, **kwargs)})"
 
 
+def compile_euclidean_distance_ext(element, compiler, **kwargs):
+    run_compiler_hook("euclidean_distance")
+    return (
+        f"sqrt(distance_sqeuclidean_f32({compiler.process(element.clauses, **kwargs)}))"
+    )
+
+
 def compile_euclidean_distance(element, compiler, **kwargs):
     run_compiler_hook("euclidean_distance")
     return f"euclidean_distance({compiler.process(element.clauses, **kwargs)})"

@@ -330,3 +347,18 @@ def compile_values(element, compiler, **kwargs):
 
 def compile_rand(element, compiler, **kwargs):
     return compiler.process(func.random(), **kwargs)
+
+
+def load_usearch_extension(conn) -> bool:
+    try:
+        # usearch is part of the vector optional dependencies
+        # we use the extension's cosine and euclidean distance functions
+        from usearch import sqlite_path
+
+        conn.enable_load_extension(True)
+        conn.load_extension(sqlite_path())
+        conn.enable_load_extension(False)
+        return True
+
+    except Exception:  # noqa: BLE001
+        return False
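Note: setup() now probes a throwaway in-memory connection and, when the usearch SQLite extension loads, compiles cosine/euclidean distance to the extension's distance_cosine_f32 / distance_sqeuclidean_f32 functions instead of the pure-Python fallbacks. A hedged standalone sketch of the same probe; the SELECT assumes the extension accepts the "[...]"-style vector strings that the fallback UDFs in vector.py also parse:

    import sqlite3


    def load_usearch_extension(conn) -> bool:
        # Same probe as in the diff: try to load usearch's bundled SQLite extension.
        try:
            from usearch import sqlite_path  # optional "vector" extra

            conn.enable_load_extension(True)
            conn.load_extension(sqlite_path())
            conn.enable_load_extension(False)
            return True
        except Exception:
            return False


    conn = sqlite3.connect(":memory:")
    if load_usearch_extension(conn):
        # roughly what the compiled SQL looks like when the extension is available
        print(conn.execute("SELECT distance_cosine_f32('[1,0]', '[0,1]')").fetchone())
    else:
        print("usearch unavailable; the NumPy fallbacks from vector.py are registered")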
datachain/sql/sqlite/vector.py
CHANGED
@@ -1,15 +1,23 @@
-import
+import math
 
 import numpy as np
-from scipy.spatial import distance
 
 
 def euclidean_distance(a: str, b: str):
-    a_np = np.
-    b_np = np.
+    a_np = np.fromstring(a[1:-1], sep=",")
+    b_np = np.fromstring(b[1:-1], sep=",")
 
     return np.linalg.norm(b_np - a_np)
 
 
 def cosine_distance(a: str, b: str):
-
+    u = np.fromstring(a[1:-1], sep=",")
+    v = np.fromstring(b[1:-1], sep=",")
+
+    uv = np.inner(u, v)
+    uu = np.inner(u, u)
+    vv = np.inner(v, v)
+
+    dist = 1.0 - uv / math.sqrt(uu * vv)
+
+    return max(0, min(dist, 2.0))
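Note: the rewritten fallbacks drop the scipy dependency; cosine_distance is now 1 - inner(u, v) / sqrt(inner(u, u) * inner(v, v)), clamped to [0, 2], computed with plain NumPy on the "[...]"-serialized vectors. A quick worked check of that formula, re-implemented here purely for illustration:

    import math

    import numpy as np


    def cosine_distance(a: str, b: str):
        # Parse "[x, y, ...]" strings, then 1 - cosine similarity, clamped to [0, 2].
        u = np.fromstring(a[1:-1], sep=",")
        v = np.fromstring(b[1:-1], sep=",")
        dist = 1.0 - np.inner(u, v) / math.sqrt(np.inner(u, u) * np.inner(v, v))
        return max(0, min(dist, 2.0))


    print(cosine_distance("[1, 0, 0]", "[0, 1, 0]"))  # orthogonal vectors -> 1.0
    print(cosine_distance("[1, 2, 3]", "[1, 2, 3]"))  # identical vectors  -> ~0.0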
datachain/utils.py
CHANGED
@@ -18,9 +18,6 @@ from dateutil.parser import isoparse
 
 if TYPE_CHECKING:
     import pandas as pd
-    import pyarrow as pa
-
-    from datachain.sql.types import SQLType
 
 NUL = b"\0"
 TIME_ZERO = datetime.fromtimestamp(0, tz=timezone.utc)

@@ -78,7 +75,7 @@ class DataChainDir:
         if create:
             instance.init()
         else:
-            NotADirectoryError(root)
+            raise NotADirectoryError(root)
         return instance
 
 

@@ -363,121 +360,3 @@ class JSONSerialize(json.JSONEncoder):
             return str(obj)
 
         return super().default(obj)
-
-
-def dtype_mapper(col: Union["pd.Index", "pd.Series"]) -> type["SQLType"]:  # noqa: PLR0911
-    from pandas import ArrowDtype
-    from pandas.api.types import infer_dtype
-
-    from datachain.sql.types import (
-        Binary,
-        Boolean,
-        DateTime,
-        Float,
-        Float32,
-        Float64,
-        Int,
-        Int32,
-        Int64,
-        String,
-        UInt64,
-    )
-
-    if isinstance(col.dtype, ArrowDtype):
-        return arrow_type_mapper(col.dtype.pyarrow_dtype)
-
-    col_type = infer_dtype(col, skipna=True)
-
-    if col_type in ("datetime", "datetime64"):
-        return DateTime
-    if col_type == "bytes":
-        return Binary
-    if col_type == "floating":
-        if col.dtype == "float32":
-            return Float32
-        if col.dtype == "float64":
-            return Float64
-        return Float
-    if col_type == "integer":
-        if col.dtype.name.lower() in ("int8", "int16", "int32"):
-            return Int32
-        if col.dtype.name.lower() == "int64":
-            return Int64
-        if col.dtype.name.lower().startswith("uint"):
-            return UInt64
-        return Int
-    if col_type == "boolean":
-        return Boolean
-    if col_type == "date":
-        return DateTime
-    if col_type in (
-        "complex",
-        "time",
-        "timedelta",
-        "timedelta64",
-        "period",
-        "interval",
-    ):
-        raise ValueError(f"{col_type!r} datatypes not supported")
-    return String
-
-
-def arrow_type_mapper(col_type: "pa.DataType") -> type["SQLType"]:  # noqa: PLR0911,C901
-    try:
-        import pyarrow as pa
-    except ImportError as exc:
-        raise ImportError(
-            "Missing required dependency pyarrow for inferring types"
-        ) from exc
-
-    from datachain.sql.types import (
-        JSON,
-        Array,
-        Binary,
-        Boolean,
-        DateTime,
-        Float,
-        Float32,
-        Float64,
-        Int,
-        Int32,
-        Int64,
-        String,
-        UInt64,
-    )
-
-    if pa.types.is_timestamp(col_type):
-        return DateTime
-    if pa.types.is_binary(col_type):
-        return Binary
-    if pa.types.is_floating(col_type):
-        if pa.types.is_float32(col_type):
-            return Float32
-        if pa.types.is_float64(col_type):
-            return Float64
-        return Float
-    if pa.types.is_integer(col_type):
-        if (
-            pa.types.is_int8(col_type)
-            or pa.types.is_int16(col_type)
-            or pa.types.is_int32(col_type)
-        ):
-            return Int32
-        if pa.types.is_int64(col_type):
-            return Int64
-        if pa.types.is_unsigned_integer(col_type):
-            return UInt64
-        return Int
-    if pa.types.is_boolean(col_type):
-        return Boolean
-    if pa.types.is_date(col_type):
-        return DateTime
-    if pa.types.is_string(col_type):
-        return String
-    if pa.types.is_list(col_type):
-        return Array(arrow_type_mapper(col_type.value_type))  # type: ignore[return-value]
-    if pa.types.is_struct(col_type) or pa.types.is_map(col_type):
-        return JSON
-    if isinstance(col_type, pa.lib.DictionaryType):
-        return arrow_type_mapper(col_type.value_type)  # type: ignore[return-value]
-    raise ValueError(f"{col_type!r} datatypes not supported")
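Note: besides moving the dtype_mapper/arrow_type_mapper helpers out of utils.py (presumably the new datachain/lib/arrow.py in the file list takes over that role), the small but real fix here is raise NotADirectoryError(root): the old code built the exception object without raising it. A tiny sketch of the difference, with hypothetical helper names:

    def check_dir_old(root: str) -> None:
        # 0.1.13 behaviour: creates the exception object, then discards it (a no-op)
        NotADirectoryError(root)


    def check_dir_new(root: str) -> None:
        # 0.2.1 behaviour: the error actually reaches the caller
        raise NotADirectoryError(root)


    check_dir_old("/missing/dir")     # silently returns
    # check_dir_new("/missing/dir")   # would raise NotADirectoryError: /missing/dir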
{datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: datachain
-Version: 0.1.13
+Version: 0.2.1
 Summary: Wrangle unstructured AI data at scale
 Author-email: Dmitry Petrov <support@dvc.org>
 License: Apache-2.0

@@ -44,12 +44,19 @@ Requires-Dist: torch >=2.1.0 ; extra == 'cv'
 Requires-Dist: torchvision ; extra == 'cv'
 Requires-Dist: transformers >=4.36.0 ; extra == 'cv'
 Provides-Extra: dev
-Requires-Dist: datachain[tests] ; extra == 'dev'
-Requires-Dist: mypy ==1.10.
+Requires-Dist: datachain[docs,tests] ; extra == 'dev'
+Requires-Dist: mypy ==1.10.1 ; extra == 'dev'
 Requires-Dist: types-python-dateutil ; extra == 'dev'
 Requires-Dist: types-PyYAML ; extra == 'dev'
 Requires-Dist: types-requests ; extra == 'dev'
 Requires-Dist: types-ujson ; extra == 'dev'
+Provides-Extra: docs
+Requires-Dist: mkdocs >=1.5.2 ; extra == 'docs'
+Requires-Dist: mkdocs-gen-files >=0.5.0 ; extra == 'docs'
+Requires-Dist: mkdocs-material >=9.3.1 ; extra == 'docs'
+Requires-Dist: mkdocs-section-index >=0.3.6 ; extra == 'docs'
+Requires-Dist: mkdocstrings-python >=1.6.3 ; extra == 'docs'
+Requires-Dist: mkdocs-literate-nav >=0.6.1 ; extra == 'docs'
 Provides-Extra: remote
 Requires-Dist: datachain[pandas] ; extra == 'remote'
 Requires-Dist: lz4 ; extra == 'remote'

@@ -72,7 +79,7 @@ Requires-Dist: open-clip-torch ; extra == 'tests'
 Requires-Dist: aiotools >=1.7.0 ; extra == 'tests'
 Requires-Dist: requests-mock ; extra == 'tests'
 Provides-Extra: vector
-Requires-Dist:
+Requires-Dist: usearch ; extra == 'vector'
 
 |PyPI| |Python Version| |Codecov| |Tests| |License|
 