datachain 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: the registry flags this version of datachain as possibly problematic.
- datachain/__init__.py +17 -8
- datachain/catalog/catalog.py +5 -5
- datachain/cli.py +0 -2
- datachain/data_storage/schema.py +5 -5
- datachain/data_storage/sqlite.py +1 -1
- datachain/data_storage/warehouse.py +7 -7
- datachain/lib/arrow.py +25 -8
- datachain/lib/clip.py +6 -11
- datachain/lib/convert/__init__.py +0 -0
- datachain/lib/convert/flatten.py +67 -0
- datachain/lib/convert/type_converter.py +96 -0
- datachain/lib/convert/unflatten.py +69 -0
- datachain/lib/convert/values_to_tuples.py +85 -0
- datachain/lib/data_model.py +74 -0
- datachain/lib/dc.py +192 -167
- datachain/lib/feature_registry.py +36 -10
- datachain/lib/file.py +41 -41
- datachain/lib/gpt4_vision.py +1 -9
- datachain/lib/hf_image_to_text.py +9 -17
- datachain/lib/hf_pipeline.py +4 -12
- datachain/lib/image.py +2 -18
- datachain/lib/image_transform.py +0 -1
- datachain/lib/iptc_exif_xmp.py +8 -15
- datachain/lib/meta_formats.py +1 -5
- datachain/lib/model_store.py +77 -0
- datachain/lib/pytorch.py +9 -21
- datachain/lib/signal_schema.py +120 -58
- datachain/lib/text.py +5 -16
- datachain/lib/udf.py +114 -30
- datachain/lib/udf_signature.py +5 -5
- datachain/lib/webdataset.py +3 -4
- datachain/lib/webdataset_laion.py +2 -3
- datachain/node.py +4 -4
- datachain/query/batch.py +1 -1
- datachain/query/dataset.py +40 -60
- datachain/query/dispatch.py +28 -17
- datachain/query/udf.py +46 -26
- datachain/remote/studio.py +1 -9
- datachain/torch/__init__.py +21 -0
- {datachain-0.2.8.dist-info → datachain-0.2.10.dist-info}/METADATA +13 -12
- {datachain-0.2.8.dist-info → datachain-0.2.10.dist-info}/RECORD +45 -42
- datachain/image/__init__.py +0 -3
- datachain/lib/cached_stream.py +0 -38
- datachain/lib/claude.py +0 -69
- datachain/lib/feature.py +0 -412
- datachain/lib/feature_utils.py +0 -154
- {datachain-0.2.8.dist-info → datachain-0.2.10.dist-info}/LICENSE +0 -0
- {datachain-0.2.8.dist-info → datachain-0.2.10.dist-info}/WHEEL +0 -0
- {datachain-0.2.8.dist-info → datachain-0.2.10.dist-info}/entry_points.txt +0 -0
- {datachain-0.2.8.dist-info → datachain-0.2.10.dist-info}/top_level.txt +0 -0
datachain/node.py
CHANGED
@@ -46,8 +46,8 @@ class DirTypeGroup:
 
 @attrs.define
 class Node:
-    id: int = 0
-    random: int = -1
+    sys__id: int = 0
+    sys__rand: int = -1
     vtype: str = ""
     dir_type: Optional[int] = None
     parent: str = ""
@@ -127,11 +127,11 @@ class Node:
 
     @classmethod
     def from_dir(cls, parent, name, **kwargs) -> "Node":
-        return cls(id=-1, dir_type=DirType.DIR, parent=parent, name=name, **kwargs)
+        return cls(sys__id=-1, dir_type=DirType.DIR, parent=parent, name=name, **kwargs)
 
     @classmethod
     def root(cls) -> "Node":
-        return cls(id=-1, dir_type=DirType.DIR)
+        return cls(sys__id=-1, dir_type=DirType.DIR)
 
 
 @attrs.define
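The hunks above move the row id and the sampling value under the sys__ prefix, and from_dir()/root() now seed sys__id=-1 for synthesized directory nodes. A minimal sketch of the renamed fields in use (the print call and field values are illustrative, not from the diff):

    import attrs
    from typing import Optional

    @attrs.define
    class Node:
        sys__id: int = 0
        sys__rand: int = -1
        vtype: str = ""
        dir_type: Optional[int] = None
        parent: str = ""

    print(Node(sys__id=-1, parent="animals"))  # dir-style node, as in from_dir()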
datachain/query/batch.py
CHANGED
@@ -104,7 +104,7 @@ class Partition(BatchingStrategy):
         with contextlib.closing(
             execute(
                 query,
-                order_by=(PARTITION_COLUMN_ID, "id", *query._order_by_clauses),
+                order_by=(PARTITION_COLUMN_ID, "sys__id", *query._order_by_clauses),
                 limit=query._limit,
             )
         ) as rows:
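The row-id tie-breaker was already in the ordering tuple; this hunk only renames it to the sys__ prefix. For reference, a minimal SQLAlchemy sketch of the ORDER BY this produces (table and column names are illustrative), showing why the tie-breaker keeps batches deterministic within a partition:

    import sqlalchemy as sa

    rows = sa.table("rows", sa.column("partition_id"), sa.column("sys__id"))
    # partition_id alone does not define an order inside a partition;
    # sys__id makes the row order, and hence batch boundaries, stable
    print(sa.select(rows).order_by(rows.c.partition_id, rows.c.sys__id))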
datachain/query/dataset.py
CHANGED
@@ -31,6 +31,7 @@ import sqlalchemy
 from attrs import frozen
 from dill import dumps, source
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback, TqdmCallback
+from pydantic import BaseModel
 from sqlalchemy import Column
 from sqlalchemy.sql import func as f
 from sqlalchemy.sql.elements import ColumnClause, ColumnElement
@@ -57,7 +58,6 @@ from datachain.sql.functions import rand
 from datachain.storage import Storage, StorageURI
 from datachain.utils import batched, determine_processes, inside_notebook
 
-from .batch import RowBatch
 from .metrics import metrics
 from .schema import C, UDFParamSpec, normalize_param
 from .session import Session
@@ -257,7 +257,7 @@ class DatasetDiffOperation(Step):
     """
 
     def apply(self, query_generator, temp_tables: list[str]):
-        source_query = query_generator.exclude(("id",))
+        source_query = query_generator.exclude(("sys__id",))
         target_query = self.dq.apply_steps().select()
         temp_tables.extend(self.dq.temp_table_names)
 
@@ -427,22 +427,6 @@ def get_generated_callback(is_generator: bool = False) -> Callback:
     return DEFAULT_CALLBACK
 
 
-def run_udf(
-    udf,
-    udf_inputs,
-    catalog,
-    is_generator,
-    cache,
-    download_cb: Callback = DEFAULT_CALLBACK,
-    processed_cb: Callback = DEFAULT_CALLBACK,
-) -> Iterator[Iterable["UDFResult"]]:
-    for batch in udf_inputs:
-        n_rows = len(batch.rows) if isinstance(batch, RowBatch) else 1
-        output = udf(catalog, batch, is_generator, cache, cb=download_cb)
-        processed_cb.relative_update(n_rows)
-        yield output
-
-
 @frozen
 class UDF(Step, ABC):
     udf: UDFType
@@ -548,9 +532,6 @@ class UDF(Step, ABC):
         else:
             udf = self.udf
 
-        if hasattr(udf.func, "setup") and callable(udf.func.setup):
-            udf.func.setup()
-
         warehouse = self.catalog.warehouse
 
         with contextlib.closing(
@@ -560,8 +541,7 @@ class UDF(Step, ABC):
             processed_cb = get_processed_callback()
            generated_cb = get_generated_callback(self.is_generator)
             try:
-                udf_results = run_udf(
-                    udf,
+                udf_results = udf.run(
                     udf_inputs,
                     self.catalog,
                     self.is_generator,
@@ -583,9 +563,6 @@ class UDF(Step, ABC):
 
             warehouse.insert_rows_done(udf_table)
 
-            if hasattr(udf.func, "teardown") and callable(udf.func.teardown):
-                udf.func.teardown()
-
         except QueryScriptCancelError:
             self.catalog.warehouse.close()
             sys.exit(QUERY_SCRIPT_CANCELED_EXIT_CODE)
@@ -663,7 +640,7 @@ class UDF(Step, ABC):
 
         # fill table with partitions
         cols = [
-            query.selected_columns.id,
+            query.selected_columns.sys__id,
             f.dense_rank().over(order_by=list_partition_by).label(PARTITION_COLUMN_ID),
         ]
         self.catalog.warehouse.db.execute(
@@ -697,7 +674,7 @@ class UDF(Step, ABC):
         subq = query.subquery()
         query = (
             sqlalchemy.select(*subq.c)
-            .outerjoin(partition_tbl, partition_tbl.c.id == subq.c.id)
+            .outerjoin(partition_tbl, partition_tbl.c.sys__id == subq.c.sys__id)
             .add_columns(*partition_columns())
         )
 
@@ -729,18 +706,18 @@ class UDFSignal(UDF):
         columns = [
             sqlalchemy.Column(c.name, c.type)
             for c in query.selected_columns
-            if c.name != "id"
+            if c.name != "sys__id"
         ]
         table = self.catalog.warehouse.create_udf_table(self.udf_table_name(), columns)
         select_q = query.with_only_columns(
-            *[c for c in query.selected_columns if c.name != "id"]
+            *[c for c in query.selected_columns if c.name != "sys__id"]
         )
 
         # if there is order by clause we need row_number to preserve order
         # if there is no order by clause we still need row_number to generate
         # unique ids as uniqueness is important for this table
         select_q = select_q.add_columns(
-            f.row_number().over(order_by=select_q._order_by_clauses).label("id")
+            f.row_number().over(order_by=select_q._order_by_clauses).label("sys__id")
         )
 
         self.catalog.warehouse.db.execute(
@@ -756,7 +733,7 @@ class UDFSignal(UDF):
         if query._order_by_clauses:
             # we are adding ordering only if it's explicitly added by user in
             # query part before adding signals
-            q = q.order_by(table.c.id)
+            q = q.order_by(table.c.sys__id)
         return q, [table]
 
     def create_result_query(
@@ -766,7 +743,7 @@ class UDFSignal(UDF):
         original_cols = [c for c in subq.c if c.name not in partition_col_names]
 
         # new signal columns that are added to udf_table
-        signal_cols = [c for c in udf_table.c if c.name != "id"]
+        signal_cols = [c for c in udf_table.c if c.name != "sys__id"]
         signal_name_cols = {c.name: c for c in signal_cols}
         cols = signal_cols
 
@@ -786,7 +763,7 @@ class UDFSignal(UDF):
             res = (
                 sqlalchemy.select(*cols1)
                 .select_from(subq)
-                .outerjoin(udf_table, udf_table.c.id == subq.c.id)
+                .outerjoin(udf_table, udf_table.c.sys__id == subq.c.sys__id)
                 .add_columns(*cols2)
             )
         else:
@@ -795,7 +772,7 @@ class UDFSignal(UDF):
             if query._order_by_clauses:
                 # if ordering is used in query part before adding signals, we
                 # will have it as order by id from select from pre-created udf table
-                res = res.order_by(subq.c.id)
+                res = res.order_by(subq.c.sys__id)
 
         if self.partition_by is not None:
             subquery = res.subquery()
@@ -833,7 +810,7 @@ class RowGenerator(UDF):
         # we get the same rows as we got as inputs of UDF since selecting
         # without ordering can be non deterministic in some databases
         c = query.selected_columns
-        query = query.order_by(c.id)
+        query = query.order_by(c.sys__id)
 
         udf_table_query = udf_table.select().subquery()
         udf_table_cols: list[sqlalchemy.Label[Any]] = [
@@ -1025,7 +1002,7 @@ class SQLJoin(Step):
         q1_column_names = {c.name for c in q1_columns}
         q2_columns = [
             c
-            if c.name not in q1_column_names and c.name != "id"
+            if c.name not in q1_column_names and c.name != "sys__id"
             else c.label(self.rname.format(name=c.name))
             for c in q2.c
         ]
@@ -1165,8 +1142,8 @@ class DatasetQuery:
         self.version = version or ds.latest_version
         self.feature_schema = ds.get_version(self.version).feature_schema
         self.column_types = copy(ds.schema)
-        if "id" in self.column_types:
-            self.column_types.pop("id")
+        if "sys__id" in self.column_types:
+            self.column_types.pop("sys__id")
         self.starting_step = QueryStep(self.catalog, name, self.version)
         # attaching to specific dataset
         self.name = name
@@ -1239,7 +1216,7 @@ class DatasetQuery:
             query.steps = self._chunk_limit(query.steps, index, total)
 
         # Prepend the chunk filter to the step chain.
-        query = query.filter(C.random % total == index)
+        query = query.filter(C.sys__rand % total == index)
         query.steps = query.steps[-1:] + query.steps[:-1]
 
         result = query.starting_step.apply()
@@ -1366,10 +1343,8 @@ class DatasetQuery:
         finally:
             self.cleanup()
 
-    def to_records(self) -> list[dict]:
-
-        cols = result.columns
-        return [dict(zip(cols, row)) for row in result]
+    def to_records(self) -> list[dict[str, Any]]:
+        return self.results(lambda cols, row: dict(zip(cols, row)))
 
     def to_pandas(self) -> "pd.DataFrame":
         records = self.to_records()
@@ -1379,7 +1354,7 @@ class DatasetQuery:
 
     def shuffle(self) -> "Self":
         # ToDo: implement shaffle based on seed and/or generating random column
-        return self.order_by(C.random)
+        return self.order_by(C.sys__rand)
 
     def sample(self, n) -> "Self":
         """
@@ -1508,30 +1483,35 @@ class DatasetQuery:
         query.steps.append(SQLOffset(offset))
         return query
 
+    def as_scalar(self) -> Any:
+        with self.as_iterable() as rows:
+            row = next(iter(rows))
+        return row[0]
+
     def count(self) -> int:
         query = self.clone()
         query.steps.append(SQLCount())
-        return query.results()[0][0]
+        return query.as_scalar()
 
-    def sum(self, col: ColumnElement):
+    def sum(self, col: ColumnElement) -> int:
         query = self.clone()
         query.steps.append(SQLSelect((f.sum(col),)))
-        return query.results()[0][0]
+        return query.as_scalar()
 
-    def avg(self, col: ColumnElement):
+    def avg(self, col: ColumnElement) -> int:
         query = self.clone()
         query.steps.append(SQLSelect((f.avg(col),)))
-        return query.results()[0][0]
+        return query.as_scalar()
 
-    def min(self, col: ColumnElement):
+    def min(self, col: ColumnElement) -> int:
         query = self.clone()
         query.steps.append(SQLSelect((f.min(col),)))
-        return query.results()[0][0]
+        return query.as_scalar()
 
-    def max(self, col: ColumnElement):
+    def max(self, col: ColumnElement) -> int:
         query = self.clone()
         query.steps.append(SQLSelect((f.max(col),)))
-        return query.results()[0][0]
+        return query.as_scalar()
 
     @detach
     def group_by(self, *cols: ColumnElement) -> "Self":
@@ -1723,7 +1703,7 @@ class DatasetQuery:
             c if isinstance(c, Column) else Column(c.name, c.type)
             for c in query.columns
         ]
-        if not [c for c in columns if c.name != "id"]:
+        if not [c for c in columns if c.name != "sys__id"]:
            raise RuntimeError(
                 "No columns to save in the query. "
                 "Ensure at least one column (other than 'id') is selected."
@@ -1742,11 +1722,11 @@ class DatasetQuery:
 
         # Exclude the id column and let the db create it to avoid unique
         # constraint violations.
-        q = query.exclude(("id",))
+        q = query.exclude(("sys__id",))
         if q._order_by_clauses:
             # ensuring we have id sorted by order by clause if it exists in a query
             q = q.add_columns(
-                f.row_number().over(order_by=q._order_by_clauses).label("id")
+                f.row_number().over(order_by=q._order_by_clauses).label("sys__id")
             )
 
         cols = tuple(c.name for c in q.columns)
@@ -1876,9 +1856,9 @@ def _random_string(length: int) -> str:
 
 
 def _feature_predicate(obj):
-
-
-
+    return (
+        inspect.isclass(obj) and source.isfrommain(obj) and issubclass(obj, BaseModel)
+    )
 
 
 def _imports(obj):
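Taken together, the aggregate changes deduplicate the scalar-extraction tail of count/sum/avg/min/max into the new as_scalar() helper. A hedged usage sketch of the resulting API (the dataset name and the size column are hypothetical, not taken from the diff):

    from datachain.query.dataset import DatasetQuery
    from datachain.query.schema import C

    dq = DatasetQuery(name="animals")  # hypothetical saved dataset
    print(dq.count())      # appends SQLCount, then reads one row via as_scalar()
    print(dq.max(C.size))  # every aggregate now ends in the same as_scalar() call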
datachain/query/dispatch.py
CHANGED
@@ -16,7 +16,6 @@ from multiprocess import get_context
 
 from datachain.catalog import Catalog
 from datachain.catalog.loader import get_distributed_class
-from datachain.query.batch import RowBatch
 from datachain.query.dataset import (
     get_download_callback,
     get_generated_callback,
@@ -355,6 +354,15 @@ class WorkerCallback(Callback):
         put_into_queue(self.queue, {"status": NOTIFY_STATUS, "downloaded": inc})
 
 
+class ProcessedCallback(Callback):
+    def __init__(self):
+        self.processed_rows: Optional[int] = None
+        super().__init__()
+
+    def relative_update(self, inc: int = 1) -> None:
+        self.processed_rows = inc
+
+
 @attrs.define
 class UDFWorker:
     catalog: Catalog
@@ -370,25 +378,28 @@ class UDFWorker:
         return WorkerCallback(self.done_queue)
 
     def run(self) -> None:
-
-
-
-
-
-
-
-
-
-
-            )
+        processed_cb = ProcessedCallback()
+        udf_results = self.udf.run(
+            self.get_inputs(),
+            self.catalog,
+            self.is_generator,
+            self.cache,
+            download_cb=self.cb,
+            processed_cb=processed_cb,
+        )
+        for udf_output in udf_results:
             if isinstance(udf_output, GeneratorType):
                 udf_output = list(udf_output)  # can not pickle generator
             put_into_queue(
                 self.done_queue,
-                {"status": OK_STATUS, "result": udf_output},
+                {
+                    "status": OK_STATUS,
+                    "result": udf_output,
+                    "processed": processed_cb.processed_rows,
+                },
             )
-
-        if hasattr(self.udf.func, "teardown") and callable(self.udf.func.teardown):
-            self.udf.func.teardown()
-
         put_into_queue(self.done_queue, {"status": FINISHED_STATUS})
+
+    def get_inputs(self):
+        while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
+            yield batch
datachain/query/udf.py
CHANGED
@@ -1,5 +1,5 @@
 import typing
-from collections.abc import Iterable, Mapping, Sequence
+from collections.abc import Iterable, Iterator, Mapping, Sequence
 from dataclasses import dataclass
 from functools import WRAPPER_ASSIGNMENTS
 from inspect import isclass
@@ -14,7 +14,6 @@ from typing import (
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 
 from datachain.dataset import RowDict
-from datachain.lib.utils import AbstractUDF
 
 from .batch import Batch, BatchingStrategy, NoBatching, Partition, RowBatch
 from .schema import (
@@ -100,15 +99,28 @@ class UDFBase:
 
     def __init__(
         self,
-        func: Callable,
         properties: UDFProperties,
     ):
-        self.func = func
         self.properties = properties
         self.signal_names = properties.signal_names()
         self.output = properties.output
 
-    def __call__(
+    def run(
+        self,
+        udf_inputs: "Iterable[BatchingResult]",
+        catalog: "Catalog",
+        is_generator: bool,
+        cache: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterator[Iterable["UDFResult"]]:
+        for batch in udf_inputs:
+            n_rows = len(batch.rows) if isinstance(batch, RowBatch) else 1
+            output = self.run_once(catalog, batch, is_generator, cache, cb=download_cb)
+            processed_cb.relative_update(n_rows)
+            yield output
+
+    def run_once(
         self,
         catalog: "Catalog",
         arg: "BatchingResult",
@@ -116,24 +128,7 @@ class UDFBase:
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterable[UDFResult]:
-        if isinstance(self.func, AbstractUDF):
-            self.func._catalog = catalog  # type: ignore[unreachable]
-
-        if isinstance(arg, RowBatch):
-            udf_inputs = [
-                self.bind_parameters(catalog, row, cache=cache, cb=cb)
-                for row in arg.rows
-            ]
-            udf_outputs = self.func(udf_inputs)
-            return self._process_results(arg.rows, udf_outputs, is_generator)
-        if isinstance(arg, RowDict):
-            udf_inputs = self.bind_parameters(catalog, arg, cache=cache, cb=cb)
-            udf_outputs = self.func(*udf_inputs)
-            if not is_generator:
-                # udf_outputs is generator already if is_generator=True
-                udf_outputs = [udf_outputs]
-            return self._process_results([arg], udf_outputs, is_generator)
-        raise ValueError(f"Unexpected UDF argument: {arg}")
+        raise NotImplementedError
 
     def bind_parameters(self, catalog: "Catalog", row: "RowDict", **kwargs) -> list:
         return [p.get_value(catalog, row, **kwargs) for p in self.properties.params]
@@ -152,9 +147,9 @@ class UDFBase:
         return (dict(zip(self.signal_names, row)) for row in results)
 
         # outputting signals
-        row_ids = [row["id"] for row in rows]
+        row_ids = [row["sys__id"] for row in rows]
         return [
-
+            {"sys__id": row_id} | dict(zip(self.signal_names, signals))
             for row_id, signals in zip(row_ids, results)
             if signals is not None  # skip rows with no output
         ]
@@ -194,12 +189,37 @@ class UDFWrapper(UDFBase):
         func: Callable,
         properties: UDFProperties,
     ):
-        super().__init__(func, properties)
+        self.func = func
+        super().__init__(properties)
         # This emulates the behavior of functools.wraps for a class decorator
         for attr in WRAPPER_ASSIGNMENTS:
            if hasattr(func, attr):
                 setattr(self, attr, getattr(func, attr))
 
+    def run_once(
+        self,
+        catalog: "Catalog",
+        arg: "BatchingResult",
+        is_generator: bool = False,
+        cache: bool = False,
+        cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterable[UDFResult]:
+        if isinstance(arg, RowBatch):
+            udf_inputs = [
+                self.bind_parameters(catalog, row, cache=cache, cb=cb)
+                for row in arg.rows
+            ]
+            udf_outputs = self.func(udf_inputs)
+            return self._process_results(arg.rows, udf_outputs, is_generator)
+        if isinstance(arg, RowDict):
+            udf_inputs = self.bind_parameters(catalog, arg, cache=cache, cb=cb)
+            udf_outputs = self.func(*udf_inputs)
+            if not is_generator:
+                # udf_outputs is generator already if is_generator=True
+                udf_outputs = [udf_outputs]
+            return self._process_results([arg], udf_outputs, is_generator)
+        raise ValueError(f"Unexpected UDF argument: {arg}")
+
     # This emulates the behavior of functools.wraps for a class decorator
     def __repr__(self):
         return repr(self.func)
datachain/remote/studio.py
CHANGED
@@ -190,19 +190,11 @@ class StudioClient:
     def dataset_rows_chunk(
         self, name: str, version: int, offset: int
     ) -> Response[DatasetRowsData]:
-        def _parse_row(row):
-            row["id"] = int(row["id"])
-            return row
-
         req_data = {"dataset_name": name, "dataset_version": version}
-        response = self._send_request_msgpack(
+        return self._send_request_msgpack(
             "dataset-rows",
             {**req_data, "offset": offset, "limit": DATASET_ROWS_CHUNK_SIZE},
         )
-        if response.ok:
-            response.data = [_parse_row(r) for r in response.data]
-
-        return response
 
     def dataset_stats(self, name: str, version: int) -> Response[DatasetStatsData]:
         response = self._send_request(
datachain/torch/__init__.py
ADDED

@@ -0,0 +1,21 @@
+try:
+    from datachain.lib.clip import similarity_scores as clip_similarity_scores
+    from datachain.lib.image import convert_image, convert_images
+    from datachain.lib.pytorch import PytorchDataset, label_to_int
+    from datachain.lib.text import convert_text
+
+except ImportError as exc:
+    raise ImportError(
+        "Missing dependencies for torch:\n"
+        "To install run:\n\n"
+        "  pip install 'datachain[torch]'\n"
+    ) from exc
+
+__all__ = [
+    "PytorchDataset",
+    "clip_similarity_scores",
+    "convert_image",
+    "convert_images",
+    "convert_text",
+    "label_to_int",
+]
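This new module gives the torch-dependent helpers a single import point behind one ImportError guard. Assuming the extra is installed, usage reduces to one import (names taken from the __all__ list above):

    # requires: pip install 'datachain[torch]'
    from datachain.torch import PytorchDataset, clip_similarity_scores, convert_text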
{datachain-0.2.8.dist-info → datachain-0.2.10.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: datachain
-Version: 0.2.8
+Version: 0.2.10
 Summary: Wrangle unstructured AI data at scale
 Author-email: Dmitry Petrov <support@dvc.org>
 License: Apache-2.0
@@ -38,12 +38,8 @@ Requires-Dist: ujson >=5.9.0
 Requires-Dist: pydantic <3,>=2
 Requires-Dist: jmespath >=1.0
 Requires-Dist: datamodel-code-generator >=0.25
+Requires-Dist: Pillow <11,>=10.0.0
 Requires-Dist: numpy <2,>=1 ; sys_platform == "win32"
-Provides-Extra: cv
-Requires-Dist: Pillow <11,>=10.0.0 ; extra == 'cv'
-Requires-Dist: torch >=2.1.0 ; extra == 'cv'
-Requires-Dist: torchvision ; extra == 'cv'
-Requires-Dist: transformers >=4.36.0 ; extra == 'cv'
 Provides-Extra: dev
 Requires-Dist: datachain[docs,tests] ; extra == 'dev'
 Requires-Dist: mypy ==1.10.1 ; extra == 'dev'
@@ -63,7 +59,7 @@ Requires-Dist: lz4 ; extra == 'remote'
 Requires-Dist: msgpack <2,>=1.0.4 ; extra == 'remote'
 Requires-Dist: requests >=2.22.0 ; extra == 'remote'
 Provides-Extra: tests
-Requires-Dist: datachain[cv,remote,vector] ; extra == 'tests'
+Requires-Dist: datachain[remote,torch,vector] ; extra == 'tests'
 Requires-Dist: pytest <9,>=8 ; extra == 'tests'
 Requires-Dist: pytest-sugar >=0.9.6 ; extra == 'tests'
 Requires-Dist: pytest-cov >=4.1.0 ; extra == 'tests'
@@ -78,6 +74,10 @@ Requires-Dist: hypothesis ; extra == 'tests'
 Requires-Dist: open-clip-torch ; extra == 'tests'
 Requires-Dist: aiotools >=1.7.0 ; extra == 'tests'
 Requires-Dist: requests-mock ; extra == 'tests'
+Provides-Extra: torch
+Requires-Dist: torch >=2.1.0 ; extra == 'torch'
+Requires-Dist: torchvision ; extra == 'torch'
+Requires-Dist: transformers >=4.36.0 ; extra == 'torch'
 Provides-Extra: vector
 Requires-Dist: usearch ; extra == 'vector'
 
@@ -89,11 +89,11 @@ Requires-Dist: usearch ; extra == 'vector'
 .. |Python Version| image:: https://img.shields.io/pypi/pyversions/datachain
    :target: https://pypi.org/project/datachain
    :alt: Python Version
-.. |Codecov| image:: https://codecov.io/gh/iterative/
-   :target: https://
+.. |Codecov| image:: https://codecov.io/gh/iterative/datachain/graph/badge.svg?token=byliXGGyGB
+   :target: https://codecov.io/gh/iterative/datachain
    :alt: Codecov
-.. |Tests| image:: https://github.com/iterative/
-   :target: https://github.com/iterative/
+.. |Tests| image:: https://github.com/iterative/datachain/workflows/Tests/badge.svg
+   :target: https://github.com/iterative/datachain/actions?workflow=Tests
    :alt: Tests
 
 AI 🔗 DataChain
@@ -397,7 +397,8 @@ Chain results can be exported or passed directly to Pytorch dataloader. For exam
 Tutorials
 ------------------
 
-* `
+* `Computer Vision <examples/computer_vision/fashion_product_images/1-quick-start.ipynb>`_ (try in `Colab <https://colab.research.google.com/github/iterative/datachain/blob/main/examples/computer_vision/fashion_product_images/1-quick-start.ipynb>`__)
+* `Multimodal <examples/multimodal/clip_fine_tuning.ipynb>`_ (try in `Colab <https://colab.research.google.com/github/iterative/datachain/blob/main/examples/multimodal/clip_fine_tuning.ipynb>`__)
 
 Contributions
 --------------------