datachain 0.2.9__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of datachain might be problematic.

Files changed (50)
  1. datachain/__init__.py +17 -8
  2. datachain/catalog/catalog.py +5 -5
  3. datachain/cli.py +0 -2
  4. datachain/data_storage/schema.py +5 -5
  5. datachain/data_storage/sqlite.py +1 -1
  6. datachain/data_storage/warehouse.py +7 -7
  7. datachain/lib/arrow.py +25 -8
  8. datachain/lib/clip.py +6 -11
  9. datachain/lib/convert/__init__.py +0 -0
  10. datachain/lib/convert/flatten.py +67 -0
  11. datachain/lib/convert/type_converter.py +96 -0
  12. datachain/lib/convert/unflatten.py +69 -0
  13. datachain/lib/convert/values_to_tuples.py +85 -0
  14. datachain/lib/data_model.py +74 -0
  15. datachain/lib/dc.py +192 -167
  16. datachain/lib/feature_registry.py +36 -10
  17. datachain/lib/file.py +41 -41
  18. datachain/lib/gpt4_vision.py +1 -9
  19. datachain/lib/hf_image_to_text.py +9 -17
  20. datachain/lib/hf_pipeline.py +4 -12
  21. datachain/lib/image.py +2 -18
  22. datachain/lib/image_transform.py +0 -1
  23. datachain/lib/iptc_exif_xmp.py +8 -15
  24. datachain/lib/meta_formats.py +1 -5
  25. datachain/lib/model_store.py +77 -0
  26. datachain/lib/pytorch.py +9 -21
  27. datachain/lib/signal_schema.py +120 -58
  28. datachain/lib/text.py +5 -16
  29. datachain/lib/udf.py +114 -30
  30. datachain/lib/udf_signature.py +5 -5
  31. datachain/lib/webdataset.py +3 -4
  32. datachain/lib/webdataset_laion.py +2 -3
  33. datachain/node.py +4 -4
  34. datachain/query/batch.py +1 -1
  35. datachain/query/dataset.py +40 -60
  36. datachain/query/dispatch.py +28 -17
  37. datachain/query/udf.py +46 -26
  38. datachain/remote/studio.py +1 -9
  39. datachain/torch/__init__.py +21 -0
  40. {datachain-0.2.9.dist-info → datachain-0.2.10.dist-info}/METADATA +13 -12
  41. {datachain-0.2.9.dist-info → datachain-0.2.10.dist-info}/RECORD +45 -42
  42. datachain/image/__init__.py +0 -3
  43. datachain/lib/cached_stream.py +0 -38
  44. datachain/lib/claude.py +0 -69
  45. datachain/lib/feature.py +0 -412
  46. datachain/lib/feature_utils.py +0 -154
  47. {datachain-0.2.9.dist-info → datachain-0.2.10.dist-info}/LICENSE +0 -0
  48. {datachain-0.2.9.dist-info → datachain-0.2.10.dist-info}/WHEEL +0 -0
  49. {datachain-0.2.9.dist-info → datachain-0.2.10.dist-info}/entry_points.txt +0 -0
  50. {datachain-0.2.9.dist-info → datachain-0.2.10.dist-info}/top_level.txt +0 -0
datachain/node.py CHANGED
@@ -46,8 +46,8 @@ class DirTypeGroup:
 
 @attrs.define
 class Node:
-    id: int = 0
-    random: int = -1
+    sys__id: int = 0
+    sys__rand: int = -1
     vtype: str = ""
     dir_type: Optional[int] = None
    parent: str = ""
@@ -127,11 +127,11 @@ class Node:
 
     @classmethod
     def from_dir(cls, parent, name, **kwargs) -> "Node":
-        return cls(id=-1, dir_type=DirType.DIR, parent=parent, name=name, **kwargs)
+        return cls(sys__id=-1, dir_type=DirType.DIR, parent=parent, name=name, **kwargs)
 
     @classmethod
     def root(cls) -> "Node":
-        return cls(-1, dir_type=DirType.DIR)
+        return cls(sys__id=-1, dir_type=DirType.DIR)
 
 
 @attrs.define
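
The hunks above rename Node's internal columns under a reserved sys__ prefix (id → sys__id, random → sys__rand). A minimal sketch of what the change means for callers, using only the field names shown in the diff (DirType is not reproduced here, so a literal stands in for DirType.DIR):

    import attrs
    from typing import Optional

    @attrs.define
    class Node:
        sys__id: int = 0     # was `id`: internal row identifier
        sys__rand: int = -1  # was `random`: value used for sampling/shuffling
        vtype: str = ""
        dir_type: Optional[int] = None

    # construction sites must now spell out the prefixed keyword
    node = Node(sys__id=-1, dir_type=1)  # 1 stands in for DirType.DIR
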
datachain/query/batch.py CHANGED
@@ -104,7 +104,7 @@ class Partition(BatchingStrategy):
         with contextlib.closing(
             execute(
                 query,
-                order_by=(PARTITION_COLUMN_ID, "id", *query._order_by_clauses),
+                order_by=(PARTITION_COLUMN_ID, "sys__id", *query._order_by_clauses),
                 limit=query._limit,
             )
         ) as rows:
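
The sys__id tiebreaker matters because PARTITION_COLUMN_ID alone is not unique: without a unique column in the ORDER BY, rows inside a partition can come back in a different order on each execution. A standalone illustration of the same idea, with a hypothetical table:

    import sqlalchemy as sa

    metadata = sa.MetaData()
    rows = sa.Table(
        "rows",
        metadata,
        sa.Column("sys__id", sa.Integer, primary_key=True),
        sa.Column("partition_id", sa.Integer),
    )

    # partition_id is not unique, so adding the unique sys__id as a
    # tiebreaker makes the scan order deterministic across executions
    query = sa.select(rows).order_by(rows.c.partition_id, rows.c.sys__id)
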
datachain/query/dataset.py CHANGED
@@ -31,6 +31,7 @@ import sqlalchemy
 from attrs import frozen
 from dill import dumps, source
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback, TqdmCallback
+from pydantic import BaseModel
 from sqlalchemy import Column
 from sqlalchemy.sql import func as f
 from sqlalchemy.sql.elements import ColumnClause, ColumnElement
@@ -57,7 +58,6 @@ from datachain.sql.functions import rand
 from datachain.storage import Storage, StorageURI
 from datachain.utils import batched, determine_processes, inside_notebook
 
-from .batch import RowBatch
 from .metrics import metrics
 from .schema import C, UDFParamSpec, normalize_param
 from .session import Session
@@ -257,7 +257,7 @@ class DatasetDiffOperation(Step):
     """
 
     def apply(self, query_generator, temp_tables: list[str]):
-        source_query = query_generator.exclude(("id",))
+        source_query = query_generator.exclude(("sys__id",))
         target_query = self.dq.apply_steps().select()
         temp_tables.extend(self.dq.temp_table_names)
 
@@ -427,22 +427,6 @@ def get_generated_callback(is_generator: bool = False) -> Callback:
     return DEFAULT_CALLBACK
 
 
-def run_udf(
-    udf,
-    udf_inputs,
-    catalog,
-    is_generator,
-    cache,
-    download_cb: Callback = DEFAULT_CALLBACK,
-    processed_cb: Callback = DEFAULT_CALLBACK,
-) -> Iterator[Iterable["UDFResult"]]:
-    for batch in udf_inputs:
-        n_rows = len(batch.rows) if isinstance(batch, RowBatch) else 1
-        output = udf(catalog, batch, is_generator, cache, cb=download_cb)
-        processed_cb.relative_update(n_rows)
-        yield output
-
-
 @frozen
 class UDF(Step, ABC):
     udf: UDFType
@@ -548,9 +532,6 @@ class UDF(Step, ABC):
         else:
             udf = self.udf
 
-        if hasattr(udf.func, "setup") and callable(udf.func.setup):
-            udf.func.setup()
-
         warehouse = self.catalog.warehouse
 
         with contextlib.closing(
@@ -560,8 +541,7 @@
             processed_cb = get_processed_callback()
             generated_cb = get_generated_callback(self.is_generator)
             try:
-                udf_results = run_udf(
-                    udf,
+                udf_results = udf.run(
                     udf_inputs,
                     self.catalog,
                     self.is_generator,
@@ -583,9 +563,6 @@
 
                 warehouse.insert_rows_done(udf_table)
 
-                if hasattr(udf.func, "teardown") and callable(udf.func.teardown):
-                    udf.func.teardown()
-
         except QueryScriptCancelError:
             self.catalog.warehouse.close()
             sys.exit(QUERY_SCRIPT_CANCELED_EXIT_CODE)
@@ -663,7 +640,7 @@
 
         # fill table with partitions
         cols = [
-            query.selected_columns.id,
+            query.selected_columns.sys__id,
             f.dense_rank().over(order_by=list_partition_by).label(PARTITION_COLUMN_ID),
         ]
         self.catalog.warehouse.db.execute(
@@ -697,7 +674,7 @@
         subq = query.subquery()
         query = (
             sqlalchemy.select(*subq.c)
-            .outerjoin(partition_tbl, partition_tbl.c.id == subq.c.id)
+            .outerjoin(partition_tbl, partition_tbl.c.sys__id == subq.c.sys__id)
             .add_columns(*partition_columns())
         )
 
@@ -729,18 +706,18 @@ class UDFSignal(UDF):
         columns = [
             sqlalchemy.Column(c.name, c.type)
             for c in query.selected_columns
-            if c.name != "id"
+            if c.name != "sys__id"
         ]
         table = self.catalog.warehouse.create_udf_table(self.udf_table_name(), columns)
         select_q = query.with_only_columns(
-            *[c for c in query.selected_columns if c.name != "id"]
+            *[c for c in query.selected_columns if c.name != "sys__id"]
         )
 
         # if there is order by clause we need row_number to preserve order
         # if there is no order by clause we still need row_number to generate
         # unique ids as uniqueness is important for this table
         select_q = select_q.add_columns(
-            f.row_number().over(order_by=select_q._order_by_clauses).label("id")
+            f.row_number().over(order_by=select_q._order_by_clauses).label("sys__id")
         )
 
         self.catalog.warehouse.db.execute(
@@ -756,7 +733,7 @@
         if query._order_by_clauses:
             # we are adding ordering only if it's explicitly added by user in
             # query part before adding signals
-            q = q.order_by(table.c.id)
+            q = q.order_by(table.c.sys__id)
         return q, [table]
 
     def create_result_query(
@@ -766,7 +743,7 @@
         original_cols = [c for c in subq.c if c.name not in partition_col_names]
 
         # new signal columns that are added to udf_table
-        signal_cols = [c for c in udf_table.c if c.name != "id"]
+        signal_cols = [c for c in udf_table.c if c.name != "sys__id"]
         signal_name_cols = {c.name: c for c in signal_cols}
         cols = signal_cols
 
@@ -786,7 +763,7 @@
             res = (
                 sqlalchemy.select(*cols1)
                 .select_from(subq)
-                .outerjoin(udf_table, udf_table.c.id == subq.c.id)
+                .outerjoin(udf_table, udf_table.c.sys__id == subq.c.sys__id)
                 .add_columns(*cols2)
             )
         else:
@@ -795,7 +772,7 @@
             if query._order_by_clauses:
                 # if ordering is used in query part before adding signals, we
                 # will have it as order by id from select from pre-created udf table
-                res = res.order_by(subq.c.id)
+                res = res.order_by(subq.c.sys__id)
 
         if self.partition_by is not None:
             subquery = res.subquery()
@@ -833,7 +810,7 @@ class RowGenerator(UDF):
         # we get the same rows as we got as inputs of UDF since selecting
         # without ordering can be non deterministic in some databases
         c = query.selected_columns
-        query = query.order_by(c.id)
+        query = query.order_by(c.sys__id)
 
         udf_table_query = udf_table.select().subquery()
         udf_table_cols: list[sqlalchemy.Label[Any]] = [
@@ -1025,7 +1002,7 @@ class SQLJoin(Step):
         q1_column_names = {c.name for c in q1_columns}
         q2_columns = [
             c
-            if c.name not in q1_column_names and c.name != "id"
+            if c.name not in q1_column_names and c.name != "sys__id"
             else c.label(self.rname.format(name=c.name))
             for c in q2.c
         ]
@@ -1165,8 +1142,8 @@ class DatasetQuery:
         self.version = version or ds.latest_version
         self.feature_schema = ds.get_version(self.version).feature_schema
         self.column_types = copy(ds.schema)
-        if "id" in self.column_types:
-            self.column_types.pop("id")
+        if "sys__id" in self.column_types:
+            self.column_types.pop("sys__id")
         self.starting_step = QueryStep(self.catalog, name, self.version)
         # attaching to specific dataset
         self.name = name
@@ -1239,7 +1216,7 @@
             query.steps = self._chunk_limit(query.steps, index, total)
 
         # Prepend the chunk filter to the step chain.
-        query = query.filter(C.random % total == index)
+        query = query.filter(C.sys__rand % total == index)
         query.steps = query.steps[-1:] + query.steps[:-1]
 
         result = query.starting_step.apply()
@@ -1366,10 +1343,8 @@
         finally:
             self.cleanup()
 
-    def to_records(self) -> list[dict]:
-        with self.as_iterable() as result:
-            cols = result.columns
-            return [dict(zip(cols, row)) for row in result]
+    def to_records(self) -> list[dict[str, Any]]:
+        return self.results(lambda cols, row: dict(zip(cols, row)))
 
     def to_pandas(self) -> "pd.DataFrame":
         records = self.to_records()
@@ -1379,7 +1354,7 @@
 
     def shuffle(self) -> "Self":
         # ToDo: implement shuffle based on seed and/or generating random column
-        return self.order_by(C.random)
+        return self.order_by(C.sys__rand)
 
     def sample(self, n) -> "Self":
         """
@@ -1508,30 +1483,35 @@
         query.steps.append(SQLOffset(offset))
         return query
 
+    def as_scalar(self) -> Any:
+        with self.as_iterable() as rows:
+            row = next(iter(rows))
+        return row[0]
+
     def count(self) -> int:
         query = self.clone()
         query.steps.append(SQLCount())
-        return query.results()[0][0]
+        return query.as_scalar()
 
-    def sum(self, col: ColumnElement):
+    def sum(self, col: ColumnElement) -> int:
         query = self.clone()
         query.steps.append(SQLSelect((f.sum(col),)))
-        return query.results()[0][0]
+        return query.as_scalar()
 
-    def avg(self, col: ColumnElement):
+    def avg(self, col: ColumnElement) -> int:
         query = self.clone()
         query.steps.append(SQLSelect((f.avg(col),)))
-        return query.results()[0][0]
+        return query.as_scalar()
 
-    def min(self, col: ColumnElement):
+    def min(self, col: ColumnElement) -> int:
         query = self.clone()
         query.steps.append(SQLSelect((f.min(col),)))
-        return query.results()[0][0]
+        return query.as_scalar()
 
-    def max(self, col: ColumnElement):
+    def max(self, col: ColumnElement) -> int:
         query = self.clone()
         query.steps.append(SQLSelect((f.max(col),)))
-        return query.results()[0][0]
+        return query.as_scalar()
 
     @detach
     def group_by(self, *cols: ColumnElement) -> "Self":
@@ -1723,7 +1703,7 @@
             c if isinstance(c, Column) else Column(c.name, c.type)
             for c in query.columns
         ]
-        if not [c for c in columns if c.name != "id"]:
+        if not [c for c in columns if c.name != "sys__id"]:
             raise RuntimeError(
                 "No columns to save in the query. "
                 "Ensure at least one column (other than 'id') is selected."
@@ -1742,11 +1722,11 @@
 
         # Exclude the id column and let the db create it to avoid unique
         # constraint violations.
-        q = query.exclude(("id",))
+        q = query.exclude(("sys__id",))
         if q._order_by_clauses:
             # ensuring we have id sorted by order by clause if it exists in a query
             q = q.add_columns(
-                f.row_number().over(order_by=q._order_by_clauses).label("id")
+                f.row_number().over(order_by=q._order_by_clauses).label("sys__id")
             )
 
         cols = tuple(c.name for c in q.columns)
@@ -1876,9 +1856,9 @@ def _random_string(length: int) -> str:
 
 
 def _feature_predicate(obj):
-    from datachain.lib.feature import Feature
-
-    return inspect.isclass(obj) and source.isfrommain(obj) and issubclass(obj, Feature)
+    return (
+        inspect.isclass(obj) and source.isfrommain(obj) and issubclass(obj, BaseModel)
+    )
 
 
 def _imports(obj):
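
With this refactor the step no longer calls a module-level run_udf helper or invokes setup()/teardown() itself; it simply drains udf.run(...), which yields one batch of result rows at a time. A simplified sketch of the consuming side (the function name and insert_rows call are illustrative, not datachain's API; insert_rows_done does appear in the diff above):

    def drain_udf(warehouse, udf_table, udf, udf_inputs, catalog,
                  is_generator, cache, download_cb, processed_cb):
        # udf.run(...) is a generator: each item is one batch of UDFResult
        # dicts, and progress is reported through the callbacks it receives
        udf_results = udf.run(
            udf_inputs,
            catalog,
            is_generator,
            cache,
            download_cb=download_cb,
            processed_cb=processed_cb,
        )
        for row_batch in udf_results:
            warehouse.insert_rows(udf_table, row_batch)  # hypothetical insert call
        warehouse.insert_rows_done(udf_table)
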
datachain/query/dispatch.py CHANGED
@@ -16,7 +16,6 @@ from multiprocess import get_context
 
 from datachain.catalog import Catalog
 from datachain.catalog.loader import get_distributed_class
-from datachain.query.batch import RowBatch
 from datachain.query.dataset import (
     get_download_callback,
     get_generated_callback,
@@ -355,6 +354,15 @@ class WorkerCallback(Callback):
         put_into_queue(self.queue, {"status": NOTIFY_STATUS, "downloaded": inc})
 
 
+class ProcessedCallback(Callback):
+    def __init__(self):
+        self.processed_rows: Optional[int] = None
+        super().__init__()
+
+    def relative_update(self, inc: int = 1) -> None:
+        self.processed_rows = inc
+
+
 @attrs.define
 class UDFWorker:
     catalog: Catalog
@@ -370,25 +378,28 @@ class UDFWorker:
         return WorkerCallback(self.done_queue)
 
     def run(self) -> None:
-        if hasattr(self.udf.func, "setup") and callable(self.udf.func.setup):
-            self.udf.func.setup()
-        while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
-            n_rows = len(batch.rows) if isinstance(batch, RowBatch) else 1
-            udf_output = self.udf(
-                self.catalog,
-                batch,
-                is_generator=self.is_generator,
-                cache=self.cache,
-                cb=self.cb,
-            )
+        processed_cb = ProcessedCallback()
+        udf_results = self.udf.run(
+            self.get_inputs(),
+            self.catalog,
+            self.is_generator,
+            self.cache,
+            download_cb=self.cb,
+            processed_cb=processed_cb,
+        )
+        for udf_output in udf_results:
            if isinstance(udf_output, GeneratorType):
                udf_output = list(udf_output)  # can not pickle generator
            put_into_queue(
                self.done_queue,
-                {"status": OK_STATUS, "result": udf_output, "processed": n_rows},
+                {
+                    "status": OK_STATUS,
+                    "result": udf_output,
+                    "processed": processed_cb.processed_rows,
+                },
            )
-
-        if hasattr(self.udf.func, "teardown") and callable(self.udf.func.teardown):
-            self.udf.func.teardown()
-
        put_into_queue(self.done_queue, {"status": FINISHED_STATUS})
+
+    def get_inputs(self):
+        while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
+            yield batch
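
The new get_inputs() turns the worker's task queue into a plain iterator, so the same UDF.run loop works both in-process and across worker processes. A standalone sketch of the sentinel-terminated queue-to-generator pattern:

    import queue

    STOP_SIGNAL = "STOP"

    def iter_queue(q):
        # yield items until the producer pushes the sentinel
        while (item := q.get()) != STOP_SIGNAL:
            yield item

    q = queue.Queue()
    for item in (1, 2, 3, STOP_SIGNAL):
        q.put(item)
    print(list(iter_queue(q)))  # [1, 2, 3]
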
datachain/query/udf.py CHANGED
@@ -1,5 +1,5 @@
 import typing
-from collections.abc import Iterable, Mapping, Sequence
+from collections.abc import Iterable, Iterator, Mapping, Sequence
 from dataclasses import dataclass
 from functools import WRAPPER_ASSIGNMENTS
 from inspect import isclass
@@ -14,7 +14,6 @@ from typing import (
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 
 from datachain.dataset import RowDict
-from datachain.lib.utils import AbstractUDF
 
 from .batch import Batch, BatchingStrategy, NoBatching, Partition, RowBatch
 from .schema import (
@@ -100,15 +99,28 @@ class UDFBase:
 
     def __init__(
         self,
-        func: Callable,
         properties: UDFProperties,
     ):
-        self.func = func
         self.properties = properties
         self.signal_names = properties.signal_names()
         self.output = properties.output
 
-    def __call__(
+    def run(
+        self,
+        udf_inputs: "Iterable[BatchingResult]",
+        catalog: "Catalog",
+        is_generator: bool,
+        cache: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterator[Iterable["UDFResult"]]:
+        for batch in udf_inputs:
+            n_rows = len(batch.rows) if isinstance(batch, RowBatch) else 1
+            output = self.run_once(catalog, batch, is_generator, cache, cb=download_cb)
+            processed_cb.relative_update(n_rows)
+            yield output
+
+    def run_once(
         self,
         catalog: "Catalog",
         arg: "BatchingResult",
@@ -116,24 +128,7 @@
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterable[UDFResult]:
-        if isinstance(self.func, AbstractUDF):
-            self.func._catalog = catalog  # type: ignore[unreachable]
-
-        if isinstance(arg, RowBatch):
-            udf_inputs = [
-                self.bind_parameters(catalog, row, cache=cache, cb=cb)
-                for row in arg.rows
-            ]
-            udf_outputs = self.func(udf_inputs)
-            return self._process_results(arg.rows, udf_outputs, is_generator)
-        if isinstance(arg, RowDict):
-            udf_inputs = self.bind_parameters(catalog, arg, cache=cache, cb=cb)
-            udf_outputs = self.func(*udf_inputs)
-            if not is_generator:
-                # udf_outputs is generator already if is_generator=True
-                udf_outputs = [udf_outputs]
-            return self._process_results([arg], udf_outputs, is_generator)
-        raise ValueError(f"Unexpected UDF argument: {arg}")
+        raise NotImplementedError
 
     def bind_parameters(self, catalog: "Catalog", row: "RowDict", **kwargs) -> list:
         return [p.get_value(catalog, row, **kwargs) for p in self.properties.params]
@@ -152,9 +147,9 @@
         return (dict(zip(self.signal_names, row)) for row in results)
 
         # outputting signals
-        row_ids = [row["id"] for row in rows]
+        row_ids = [row["sys__id"] for row in rows]
         return [
-            dict(id=row_id, **dict(zip(self.signal_names, signals)))
+            {"sys__id": row_id} | dict(zip(self.signal_names, signals))
             for row_id, signals in zip(row_ids, results)
             if signals is not None  # skip rows with no output
        ]
@@ -194,12 +189,37 @@ class UDFWrapper(UDFBase):
         func: Callable,
         properties: UDFProperties,
     ):
-        super().__init__(func, properties)
+        self.func = func
+        super().__init__(properties)
         # This emulates the behavior of functools.wraps for a class decorator
         for attr in WRAPPER_ASSIGNMENTS:
             if hasattr(func, attr):
                 setattr(self, attr, getattr(func, attr))
 
+    def run_once(
+        self,
+        catalog: "Catalog",
+        arg: "BatchingResult",
+        is_generator: bool = False,
+        cache: bool = False,
+        cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterable[UDFResult]:
+        if isinstance(arg, RowBatch):
+            udf_inputs = [
+                self.bind_parameters(catalog, row, cache=cache, cb=cb)
+                for row in arg.rows
+            ]
+            udf_outputs = self.func(udf_inputs)
+            return self._process_results(arg.rows, udf_outputs, is_generator)
+        if isinstance(arg, RowDict):
+            udf_inputs = self.bind_parameters(catalog, arg, cache=cache, cb=cb)
+            udf_outputs = self.func(*udf_inputs)
+            if not is_generator:
+                # udf_outputs is generator already if is_generator=True
+                udf_outputs = [udf_outputs]
+            return self._process_results([arg], udf_outputs, is_generator)
+        raise ValueError(f"Unexpected UDF argument: {arg}")
+
     # This emulates the behavior of functools.wraps for a class decorator
     def __repr__(self):
         return repr(self.func)
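
The refactor is a template-method split: UDFBase.run owns batch iteration and progress accounting, while subclasses implement run_once for a single batch. A stripped-down sketch of the pattern with simplified types (not datachain's actual signatures):

    from collections.abc import Callable, Iterable, Iterator
    from typing import Any

    class UDFBase:
        def run(self, udf_inputs: Iterable[Any]) -> Iterator[Any]:
            # the base class drives iteration; subclasses supply run_once
            for batch in udf_inputs:
                yield self.run_once(batch)

        def run_once(self, batch: Any) -> Any:
            raise NotImplementedError

    class UDFWrapper(UDFBase):
        def __init__(self, func: Callable):
            self.func = func

        def run_once(self, batch: Any) -> Any:
            return self.func(batch)

    doubler = UDFWrapper(lambda xs: [x * 2 for x in xs])
    print(list(doubler.run([[1, 2], [3]])))  # [[2, 4], [6]]
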
datachain/remote/studio.py CHANGED
@@ -190,19 +190,11 @@ class StudioClient:
     def dataset_rows_chunk(
         self, name: str, version: int, offset: int
     ) -> Response[DatasetRowsData]:
-        def _parse_row(row):
-            row["id"] = int(row["id"])
-            return row
-
         req_data = {"dataset_name": name, "dataset_version": version}
-        response = self._send_request_msgpack(
+        return self._send_request_msgpack(
             "dataset-rows",
             {**req_data, "offset": offset, "limit": DATASET_ROWS_CHUNK_SIZE},
         )
-        if response.ok:
-            response.data = [_parse_row(r) for r in response.data]
-
-        return response
 
     def dataset_stats(self, name: str, version: int) -> Response[DatasetStatsData]:
         response = self._send_request(
datachain/torch/__init__.py ADDED
@@ -0,0 +1,21 @@
+try:
+    from datachain.lib.clip import similarity_scores as clip_similarity_scores
+    from datachain.lib.image import convert_image, convert_images
+    from datachain.lib.pytorch import PytorchDataset, label_to_int
+    from datachain.lib.text import convert_text
+
+except ImportError as exc:
+    raise ImportError(
+        "Missing dependencies for torch:\n"
+        "To install run:\n\n"
+        "  pip install 'datachain[torch]'\n"
+    ) from exc
+
+__all__ = [
+    "PytorchDataset",
+    "clip_similarity_scores",
+    "convert_image",
+    "convert_images",
+    "convert_text",
+    "label_to_int",
+]
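
The new datachain.torch module guards its imports so that a missing optional dependency fails at import time with an actionable message instead of a bare ModuleNotFoundError deep in a call stack. A small usage sketch that works whether or not the extra is installed:

    try:
        from datachain.torch import PytorchDataset, clip_similarity_scores
    except ImportError as exc:
        # without `pip install 'datachain[torch]'`, importing the module
        # raises ImportError carrying the install hint shown above
        print(exc)
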
{datachain-0.2.9.dist-info → datachain-0.2.10.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: datachain
-Version: 0.2.9
+Version: 0.2.10
 Summary: Wrangle unstructured AI data at scale
 Author-email: Dmitry Petrov <support@dvc.org>
 License: Apache-2.0
@@ -38,12 +38,8 @@ Requires-Dist: ujson >=5.9.0
 Requires-Dist: pydantic <3,>=2
 Requires-Dist: jmespath >=1.0
 Requires-Dist: datamodel-code-generator >=0.25
+Requires-Dist: Pillow <11,>=10.0.0
 Requires-Dist: numpy <2,>=1 ; sys_platform == "win32"
-Provides-Extra: cv
-Requires-Dist: Pillow <11,>=10.0.0 ; extra == 'cv'
-Requires-Dist: torch >=2.1.0 ; extra == 'cv'
-Requires-Dist: torchvision ; extra == 'cv'
-Requires-Dist: transformers >=4.36.0 ; extra == 'cv'
 Provides-Extra: dev
 Requires-Dist: datachain[docs,tests] ; extra == 'dev'
 Requires-Dist: mypy ==1.10.1 ; extra == 'dev'
@@ -63,7 +59,7 @@ Requires-Dist: lz4 ; extra == 'remote'
 Requires-Dist: msgpack <2,>=1.0.4 ; extra == 'remote'
 Requires-Dist: requests >=2.22.0 ; extra == 'remote'
 Provides-Extra: tests
-Requires-Dist: datachain[cv,remote,vector] ; extra == 'tests'
+Requires-Dist: datachain[remote,torch,vector] ; extra == 'tests'
 Requires-Dist: pytest <9,>=8 ; extra == 'tests'
 Requires-Dist: pytest-sugar >=0.9.6 ; extra == 'tests'
 Requires-Dist: pytest-cov >=4.1.0 ; extra == 'tests'
@@ -78,6 +74,10 @@ Requires-Dist: hypothesis ; extra == 'tests'
 Requires-Dist: open-clip-torch ; extra == 'tests'
 Requires-Dist: aiotools >=1.7.0 ; extra == 'tests'
 Requires-Dist: requests-mock ; extra == 'tests'
+Provides-Extra: torch
+Requires-Dist: torch >=2.1.0 ; extra == 'torch'
+Requires-Dist: torchvision ; extra == 'torch'
+Requires-Dist: transformers >=4.36.0 ; extra == 'torch'
 Provides-Extra: vector
 Requires-Dist: usearch ; extra == 'vector'
 
@@ -89,11 +89,11 @@ Requires-Dist: usearch ; extra == 'vector'
 .. |Python Version| image:: https://img.shields.io/pypi/pyversions/datachain
    :target: https://pypi.org/project/datachain
    :alt: Python Version
-.. |Codecov| image:: https://codecov.io/gh/iterative/dvcx/branch/main/graph/badge.svg?token=VSCP2T9R5X
-   :target: https://app.codecov.io/gh/iterative/dvcx
+.. |Codecov| image:: https://codecov.io/gh/iterative/datachain/graph/badge.svg?token=byliXGGyGB
+   :target: https://codecov.io/gh/iterative/datachain
    :alt: Codecov
-.. |Tests| image:: https://github.com/iterative/dvcx/workflows/Tests/badge.svg
-   :target: https://github.com/iterative/dvcx/actions?workflow=Tests
+.. |Tests| image:: https://github.com/iterative/datachain/workflows/Tests/badge.svg
+   :target: https://github.com/iterative/datachain/actions?workflow=Tests
    :alt: Tests
 
 AI 🔗 DataChain
@@ -397,7 +397,8 @@ Chain results can be exported or passed directly to Pytorch dataloader. For exam
 Tutorials
 ------------------
 
-* `Multimodal <examples/multimodal/clip_fine_tuning.ipynb>`_ (try in `Colab <https://colab.research.google.com/github/iterative/dvclive/blob/main/examples/multimodal/clip_fine_tuning.ipynb>`__)
+* `Computer Vision <examples/computer_vision/fashion_product_images/1-quick-start.ipynb>`_ (try in `Colab <https://colab.research.google.com/github/iterative/datachain/blob/main/examples/computer_vision/fashion_product_images/1-quick-start.ipynb>`__)
+* `Multimodal <examples/multimodal/clip_fine_tuning.ipynb>`_ (try in `Colab <https://colab.research.google.com/github/iterative/datachain/blob/main/examples/multimodal/clip_fine_tuning.ipynb>`__)
 
 Contributions
 --------------------
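
The metadata change renames the cv extra to torch (with Pillow promoted to a core dependency). A hedged sketch for checking which requirements sit behind the renamed extra in an installed environment (marker quoting in Requires-Dist strings can vary between packaging tools):

    from importlib.metadata import requires

    # Requires-Dist entries gated behind the "torch" extra (formerly "cv")
    torch_reqs = [
        r for r in (requires("datachain") or []) if "extra == 'torch'" in r
    ]
    print(torch_reqs)  # expect torch, torchvision, transformers per the metadata above
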