datachain 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of datachain has been flagged as possibly problematic.

Files changed (51)
  1. datachain/__init__.py +17 -8
  2. datachain/catalog/catalog.py +5 -5
  3. datachain/cli.py +0 -2
  4. datachain/data_storage/schema.py +5 -5
  5. datachain/data_storage/sqlite.py +1 -1
  6. datachain/data_storage/warehouse.py +7 -7
  7. datachain/lib/arrow.py +25 -8
  8. datachain/lib/clip.py +6 -11
  9. datachain/lib/convert/__init__.py +0 -0
  10. datachain/lib/convert/flatten.py +67 -0
  11. datachain/lib/convert/type_converter.py +96 -0
  12. datachain/lib/convert/unflatten.py +69 -0
  13. datachain/lib/convert/values_to_tuples.py +85 -0
  14. datachain/lib/data_model.py +74 -0
  15. datachain/lib/dc.py +225 -168
  16. datachain/lib/file.py +41 -41
  17. datachain/lib/gpt4_vision.py +1 -9
  18. datachain/lib/hf_image_to_text.py +9 -17
  19. datachain/lib/hf_pipeline.py +4 -12
  20. datachain/lib/image.py +2 -18
  21. datachain/lib/image_transform.py +0 -1
  22. datachain/lib/iptc_exif_xmp.py +8 -15
  23. datachain/lib/meta_formats.py +1 -5
  24. datachain/lib/model_store.py +77 -0
  25. datachain/lib/pytorch.py +9 -21
  26. datachain/lib/signal_schema.py +139 -60
  27. datachain/lib/text.py +5 -16
  28. datachain/lib/udf.py +114 -30
  29. datachain/lib/udf_signature.py +5 -5
  30. datachain/lib/webdataset.py +3 -3
  31. datachain/lib/webdataset_laion.py +2 -3
  32. datachain/node.py +4 -4
  33. datachain/query/batch.py +1 -1
  34. datachain/query/dataset.py +51 -178
  35. datachain/query/dispatch.py +43 -30
  36. datachain/query/udf.py +46 -26
  37. datachain/remote/studio.py +1 -9
  38. datachain/torch/__init__.py +21 -0
  39. datachain/utils.py +39 -0
  40. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/METADATA +14 -12
  41. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/RECORD +45 -43
  42. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/WHEEL +1 -1
  43. datachain/image/__init__.py +0 -3
  44. datachain/lib/cached_stream.py +0 -38
  45. datachain/lib/claude.py +0 -69
  46. datachain/lib/feature.py +0 -412
  47. datachain/lib/feature_registry.py +0 -51
  48. datachain/lib/feature_utils.py +0 -154
  49. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/LICENSE +0 -0
  50. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/entry_points.txt +0 -0
  51. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/top_level.txt +0 -0
datachain/lib/webdataset_laion.py CHANGED
@@ -2,9 +2,8 @@ from collections.abc import Iterator
  from typing import Optional
 
  import numpy as np
- from pydantic import Field
+ from pydantic import BaseModel, Field
 
- from datachain.lib.feature import Feature
  from datachain.lib.file import File
  from datachain.lib.webdataset import WDSBasic, WDSReadableSubclass
 
@@ -34,7 +33,7 @@ class WDSLaion(WDSBasic):
      json: Laion  # type: ignore[assignment]
 
 
- class LaionMeta(Feature):
+ class LaionMeta(BaseModel):
      file: File
      index: Optional[int] = Field(default=None)
      b32_img: list[float] = Field(default=None)
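
The Feature base class is removed in 0.2.11; custom signal schemas now derive from pydantic.BaseModel directly (see the new datachain/lib/data_model.py and model_store.py in the file list). A minimal sketch of a user-defined schema under the new convention; MyMeta and its extra fields are hypothetical:

from typing import Optional

from pydantic import BaseModel, Field

from datachain.lib.file import File


class MyMeta(BaseModel):
    # In 0.2.11 plain pydantic models replace datachain.lib.feature.Feature
    # as the base for custom signal schemas.
    file: File
    index: Optional[int] = Field(default=None)
    score: float = Field(default=0.0)
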
datachain/node.py CHANGED
@@ -46,8 +46,8 @@ class DirTypeGroup:
 
  @attrs.define
  class Node:
-     id: int = 0
-     random: int = -1
+     sys__id: int = 0
+     sys__rand: int = -1
      vtype: str = ""
      dir_type: Optional[int] = None
      parent: str = ""
@@ -127,11 +127,11 @@ class Node:
 
      @classmethod
      def from_dir(cls, parent, name, **kwargs) -> "Node":
-         return cls(id=-1, dir_type=DirType.DIR, parent=parent, name=name, **kwargs)
+         return cls(sys__id=-1, dir_type=DirType.DIR, parent=parent, name=name, **kwargs)
 
      @classmethod
      def root(cls) -> "Node":
-         return cls(-1, dir_type=DirType.DIR)
+         return cls(sys__id=-1, dir_type=DirType.DIR)
 
 
  @attrs.define
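
The rename threads through every place that used the old id/random attributes. A minimal sketch of the new field names on Node, assuming DirType is importable from datachain.node as used in the hunk above; the asserted defaults come from the class definition shown:

from datachain.node import DirType, Node

# Internal bookkeeping columns now carry a "sys__" prefix, and the attrs
# fields on Node follow suit: id -> sys__id, random -> sys__rand.
node = Node.from_dir("animals", "dogs")
assert node.sys__id == -1    # placeholder id until the warehouse assigns one
assert node.sys__rand == -1  # default from the class definition
assert node.dir_type == DirType.DIR
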
datachain/query/batch.py CHANGED
@@ -104,7 +104,7 @@ class Partition(BatchingStrategy):
          with contextlib.closing(
              execute(
                  query,
-                 order_by=(PARTITION_COLUMN_ID, "id", *query._order_by_clauses),
+                 order_by=(PARTITION_COLUMN_ID, "sys__id", *query._order_by_clauses),
                  limit=query._limit,
              )
          ) as rows:
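
Partitioned batching keeps its deterministic ordering; it simply orders by the renamed column. A standalone SQLAlchemy sketch of the same ordering pattern; the rows table and its columns are hypothetical:

import sqlalchemy as sa

# Hypothetical table standing in for a dataset rows table.
rows = sa.table("rows", sa.column("sys__id"), sa.column("partition_id"))

# Batches are emitted partition by partition; ordering by the renamed
# sys__id column keeps rows within each partition deterministic.
query = sa.select(rows).order_by(rows.c.partition_id, rows.c.sys__id)
print(query)
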
datachain/query/dataset.py CHANGED
@@ -1,4 +1,3 @@
- import ast
  import contextlib
  import datetime
  import inspect
@@ -10,7 +9,6 @@ import re
  import string
  import subprocess
  import sys
- import types
  from abc import ABC, abstractmethod
  from collections.abc import Generator, Iterable, Iterator, Sequence
  from copy import copy
@@ -26,10 +24,8 @@ from typing import (
  )
 
  import attrs
- import pandas as pd
  import sqlalchemy
  from attrs import frozen
- from dill import dumps, source
  from fsspec.callbacks import DEFAULT_CALLBACK, Callback, TqdmCallback
  from sqlalchemy import Column
  from sqlalchemy.sql import func as f
@@ -52,12 +48,14 @@ from datachain.data_storage.schema import (
  from datachain.dataset import DatasetStatus, RowDict
  from datachain.error import DatasetNotFoundError, QueryScriptCancelError
  from datachain.progress import CombinedDownloadCallback
- from datachain.query.schema import DEFAULT_DELIMITER
  from datachain.sql.functions import rand
  from datachain.storage import Storage, StorageURI
- from datachain.utils import batched, determine_processes, inside_notebook
+ from datachain.utils import (
+     batched,
+     determine_processes,
+     filtered_cloudpickle_dumps,
+ )
 
- from .batch import RowBatch
  from .metrics import metrics
  from .schema import C, UDFParamSpec, normalize_param
  from .session import Session
@@ -257,7 +255,7 @@ class DatasetDiffOperation(Step):
      """
 
      def apply(self, query_generator, temp_tables: list[str]):
-         source_query = query_generator.exclude(("id",))
+         source_query = query_generator.exclude(("sys__id",))
          target_query = self.dq.apply_steps().select()
          temp_tables.extend(self.dq.temp_table_names)
 
@@ -427,22 +425,6 @@ def get_generated_callback(is_generator: bool = False) -> Callback:
      return DEFAULT_CALLBACK
 
 
- def run_udf(
-     udf,
-     udf_inputs,
-     catalog,
-     is_generator,
-     cache,
-     download_cb: Callback = DEFAULT_CALLBACK,
-     processed_cb: Callback = DEFAULT_CALLBACK,
- ) -> Iterator[Iterable["UDFResult"]]:
-     for batch in udf_inputs:
-         n_rows = len(batch.rows) if isinstance(batch, RowBatch) else 1
-         output = udf(catalog, batch, is_generator, cache, cb=download_cb)
-         processed_cb.relative_update(n_rows)
-         yield output
-
-
  @frozen
  class UDF(Step, ABC):
      udf: UDFType
@@ -508,7 +490,7 @@ class UDF(Step, ABC):
          elif processes:
              # Parallel processing (faster for more CPU-heavy UDFs)
              udf_info = {
-                 "udf": self.udf,
+                 "udf_data": filtered_cloudpickle_dumps(self.udf),
                  "catalog_init": self.catalog.get_init_params(),
                  "id_generator_clone_params": (
                      self.catalog.id_generator.clone_params()
@@ -529,16 +511,15 @@ class UDF(Step, ABC):
 
              envs = dict(os.environ)
              envs.update({"PYTHONPATH": os.getcwd()})
-             with self.process_feature_module():
-                 process_data = dumps(udf_info, recurse=True)
-                 result = subprocess.run(  # noqa: S603
-                     [datachain_exec_path, "--internal-run-udf"],
-                     input=process_data,
-                     check=False,
-                     env=envs,
-                 )
-                 if result.returncode != 0:
-                     raise RuntimeError("UDF Execution Failed!")
+             process_data = filtered_cloudpickle_dumps(udf_info)
+             result = subprocess.run(  # noqa: S603
+                 [datachain_exec_path, "--internal-run-udf"],
+                 input=process_data,
+                 check=False,
+                 env=envs,
+             )
+             if result.returncode != 0:
+                 raise RuntimeError("UDF Execution Failed!")
 
          else:
              # Otherwise process single-threaded (faster for smaller UDFs)
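
The dill-based process_feature_module() workaround (removed in this and the following hunks) existed to ship Feature subclasses defined in __main__ to the worker subprocess. With cloudpickle that is unnecessary, since locally defined classes and functions are serialized by value. A minimal sketch of the idea using plain cloudpickle; the real code goes through datachain's filtered_cloudpickle_dumps helper:

import cloudpickle


def make_udf(suffix: str):
    # A closure like this cannot be imported by a subprocess, but cloudpickle
    # serializes it by value, so it can be shipped as bytes ("udf_data").
    def udf(name: str) -> str:
        return name + suffix

    return udf


payload = cloudpickle.dumps(make_udf("_signal"))
restored = cloudpickle.loads(payload)
assert restored("dog") == "dog_signal"
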
@@ -548,9 +529,6 @@ class UDF(Step, ABC):
              else:
                  udf = self.udf
 
-             if hasattr(udf.func, "setup") and callable(udf.func.setup):
-                 udf.func.setup()
-
              warehouse = self.catalog.warehouse
 
              with contextlib.closing(
@@ -560,8 +538,7 @@ class UDF(Step, ABC):
                  processed_cb = get_processed_callback()
                  generated_cb = get_generated_callback(self.is_generator)
                  try:
-                     udf_results = run_udf(
-                         udf,
+                     udf_results = udf.run(
                          udf_inputs,
                          self.catalog,
                          self.is_generator,
@@ -583,9 +560,6 @@ class UDF(Step, ABC):
 
                      warehouse.insert_rows_done(udf_table)
 
-                     if hasattr(udf.func, "teardown") and callable(udf.func.teardown):
-                         udf.func.teardown()
-
                  except QueryScriptCancelError:
                      self.catalog.warehouse.close()
                      sys.exit(QUERY_SCRIPT_CANCELED_EXIT_CODE)
@@ -594,57 +568,6 @@ class UDF(Step, ABC):
                      self.catalog.warehouse.close()
                      raise
 
-     @contextlib.contextmanager
-     def process_feature_module(self):
-         # Generate a random name for the feature module
-         feature_module_name = "tmp" + _random_string(10)
-         # Create a dynamic module with the generated name
-         dynamic_module = types.ModuleType(feature_module_name)
-         # Get the import lines for the necessary objects from the main module
-         main_module = sys.modules["__main__"]
-         if getattr(main_module, "__file__", None):
-             import_lines = list(get_imports(main_module))
-         else:
-             import_lines = [
-                 source.getimport(obj, alias=name)
-                 for name, obj in main_module.__dict__.items()
-                 if _imports(obj) and not (name.startswith("__") and name.endswith("__"))
-             ]
-
-         # Get the feature classes from the main module
-         feature_classes = {
-             name: obj
-             for name, obj in main_module.__dict__.items()
-             if _feature_predicate(obj)
-         }
-         if not feature_classes:
-             yield None
-             return
-
-         # Get the source code of the feature classes
-         feature_sources = [source.getsource(cls) for _, cls in feature_classes.items()]
-         # Set the module name for the feature classes to the generated name
-         for name, cls in feature_classes.items():
-             cls.__module__ = feature_module_name
-             setattr(dynamic_module, name, cls)
-         # Add the dynamic module to the sys.modules dictionary
-         sys.modules[feature_module_name] = dynamic_module
-         # Combine the import lines and feature sources
-         feature_file = "\n".join(import_lines) + "\n" + "\n".join(feature_sources)
-
-         # Write the module content to a .py file
-         with open(f"{feature_module_name}.py", "w") as module_file:
-             module_file.write(feature_file)
-
-         try:
-             yield feature_module_name
-         finally:
-             for cls in feature_classes.values():
-                 cls.__module__ = main_module.__name__
-             os.unlink(f"{feature_module_name}.py")
-             # Remove the dynamic module from sys.modules
-             del sys.modules[feature_module_name]
-
      def create_partitions_table(self, query: Select) -> "Table":
          """
          Create temporary table with group by partitions.
@@ -663,7 +586,7 @@ class UDF(Step, ABC):
 
          # fill table with partitions
          cols = [
-             query.selected_columns.id,
+             query.selected_columns.sys__id,
              f.dense_rank().over(order_by=list_partition_by).label(PARTITION_COLUMN_ID),
          ]
          self.catalog.warehouse.db.execute(
@@ -697,7 +620,7 @@ class UDF(Step, ABC):
              subq = query.subquery()
              query = (
                  sqlalchemy.select(*subq.c)
-                 .outerjoin(partition_tbl, partition_tbl.c.id == subq.c.id)
+                 .outerjoin(partition_tbl, partition_tbl.c.sys__id == subq.c.sys__id)
                  .add_columns(*partition_columns())
              )
 
@@ -729,18 +652,18 @@ class UDFSignal(UDF):
          columns = [
              sqlalchemy.Column(c.name, c.type)
              for c in query.selected_columns
-             if c.name != "id"
+             if c.name != "sys__id"
          ]
          table = self.catalog.warehouse.create_udf_table(self.udf_table_name(), columns)
          select_q = query.with_only_columns(
-             *[c for c in query.selected_columns if c.name != "id"]
+             *[c for c in query.selected_columns if c.name != "sys__id"]
          )
 
          # if there is order by clause we need row_number to preserve order
          # if there is no order by clause we still need row_number to generate
          # unique ids as uniqueness is important for this table
          select_q = select_q.add_columns(
-             f.row_number().over(order_by=select_q._order_by_clauses).label("id")
+             f.row_number().over(order_by=select_q._order_by_clauses).label("sys__id")
          )
 
          self.catalog.warehouse.db.execute(
@@ -756,7 +679,7 @@ class UDFSignal(UDF):
          if query._order_by_clauses:
              # we are adding ordering only if it's explicitly added by user in
              # query part before adding signals
-             q = q.order_by(table.c.id)
+             q = q.order_by(table.c.sys__id)
          return q, [table]
 
      def create_result_query(
@@ -766,7 +689,7 @@ class UDFSignal(UDF):
          original_cols = [c for c in subq.c if c.name not in partition_col_names]
 
          # new signal columns that are added to udf_table
-         signal_cols = [c for c in udf_table.c if c.name != "id"]
+         signal_cols = [c for c in udf_table.c if c.name != "sys__id"]
          signal_name_cols = {c.name: c for c in signal_cols}
          cols = signal_cols
 
@@ -786,7 +709,7 @@ class UDFSignal(UDF):
              res = (
                  sqlalchemy.select(*cols1)
                  .select_from(subq)
-                 .outerjoin(udf_table, udf_table.c.id == subq.c.id)
+                 .outerjoin(udf_table, udf_table.c.sys__id == subq.c.sys__id)
                  .add_columns(*cols2)
              )
          else:
@@ -795,7 +718,7 @@ class UDFSignal(UDF):
          if query._order_by_clauses:
              # if ordering is used in query part before adding signals, we
              # will have it as order by id from select from pre-created udf table
-             res = res.order_by(subq.c.id)
+             res = res.order_by(subq.c.sys__id)
 
          if self.partition_by is not None:
              subquery = res.subquery()
@@ -833,7 +756,7 @@ class RowGenerator(UDF):
          # we get the same rows as we got as inputs of UDF since selecting
          # without ordering can be non deterministic in some databases
          c = query.selected_columns
-         query = query.order_by(c.id)
+         query = query.order_by(c.sys__id)
 
          udf_table_query = udf_table.select().subquery()
          udf_table_cols: list[sqlalchemy.Label[Any]] = [
@@ -1025,7 +948,7 @@ class SQLJoin(Step):
          q1_column_names = {c.name for c in q1_columns}
          q2_columns = [
              c
-             if c.name not in q1_column_names and c.name != "id"
+             if c.name not in q1_column_names and c.name != "sys__id"
              else c.label(self.rname.format(name=c.name))
              for c in q2.c
          ]
@@ -1165,8 +1088,8 @@ class DatasetQuery:
              self.version = version or ds.latest_version
              self.feature_schema = ds.get_version(self.version).feature_schema
              self.column_types = copy(ds.schema)
-             if "id" in self.column_types:
-                 self.column_types.pop("id")
+             if "sys__id" in self.column_types:
+                 self.column_types.pop("sys__id")
              self.starting_step = QueryStep(self.catalog, name, self.version)
              # attaching to specific dataset
              self.name = name
@@ -1239,7 +1162,7 @@ class DatasetQuery:
          query.steps = self._chunk_limit(query.steps, index, total)
 
          # Prepend the chunk filter to the step chain.
-         query = query.filter(C.random % total == index)
+         query = query.filter(C.sys__rand % total == index)
          query.steps = query.steps[-1:] + query.steps[:-1]
 
          result = query.starting_step.apply()
@@ -1366,20 +1289,12 @@ class DatasetQuery:
          finally:
              self.cleanup()
 
-     def to_records(self) -> list[dict]:
-         with self.as_iterable() as result:
-             cols = result.columns
-             return [dict(zip(cols, row)) for row in result]
-
-     def to_pandas(self) -> "pd.DataFrame":
-         records = self.to_records()
-         df = pd.DataFrame.from_records(records)
-         df.columns = [c.replace(DEFAULT_DELIMITER, ".") for c in df.columns]
-         return df
+     def to_records(self) -> list[dict[str, Any]]:
+         return self.results(lambda cols, row: dict(zip(cols, row)))
 
      def shuffle(self) -> "Self":
          # ToDo: implement shaffle based on seed and/or generating random column
-         return self.order_by(C.random)
+         return self.order_by(C.sys__rand)
 
      def sample(self, n) -> "Self":
          """
@@ -1395,22 +1310,6 @@ class DatasetQuery:
 
          return sampled.limit(n)
 
-     def show(self, limit=20) -> None:
-         df = self.limit(limit).to_pandas()
-
-         options = ["display.max_colwidth", 50, "display.show_dimensions", False]
-         with pd.option_context(*options):
-             if inside_notebook():
-                 from IPython.display import display
-
-                 display(df)
-
-             else:
-                 print(df.to_string())
-
-         if len(df) == limit:
-             print(f"[limited by {limit} objects]")
-
      def clone(self, new_table=True) -> "Self":
          obj = copy(self)
          obj.steps = obj.steps.copy()
@@ -1508,30 +1407,35 @@ class DatasetQuery:
          query.steps.append(SQLOffset(offset))
          return query
 
+     def as_scalar(self) -> Any:
+         with self.as_iterable() as rows:
+             row = next(iter(rows))
+             return row[0]
+
      def count(self) -> int:
          query = self.clone()
          query.steps.append(SQLCount())
-         return query.results()[0][0]
+         return query.as_scalar()
 
-     def sum(self, col: ColumnElement):
+     def sum(self, col: ColumnElement) -> int:
          query = self.clone()
          query.steps.append(SQLSelect((f.sum(col),)))
-         return query.results()[0][0]
+         return query.as_scalar()
 
-     def avg(self, col: ColumnElement):
+     def avg(self, col: ColumnElement) -> int:
          query = self.clone()
          query.steps.append(SQLSelect((f.avg(col),)))
-         return query.results()[0][0]
+         return query.as_scalar()
 
-     def min(self, col: ColumnElement):
+     def min(self, col: ColumnElement) -> int:
          query = self.clone()
          query.steps.append(SQLSelect((f.min(col),)))
-         return query.results()[0][0]
+         return query.as_scalar()
 
-     def max(self, col: ColumnElement):
+     def max(self, col: ColumnElement) -> int:
          query = self.clone()
          query.steps.append(SQLSelect((f.max(col),)))
-         return query.results()[0][0]
+         return query.as_scalar()
 
      @detach
      def group_by(self, *cols: ColumnElement) -> "Self":
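
Aggregations now share a single as_scalar() helper instead of materializing the full result list via results()[0][0]. A hedged usage sketch; the dataset name, version, and column are hypothetical and require an existing dataset in the local catalog:

from datachain.query import C, DatasetQuery

# Hypothetical saved dataset; any existing dataset name works here.
dq = DatasetQuery(name="animals", version=1)

total = dq.count()        # SQLCount step, fetched through as_scalar()
biggest = dq.max(C.size)  # SQLSelect step, same single-row path
print(total, biggest)
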
@@ -1723,7 +1627,7 @@ class DatasetQuery:
              c if isinstance(c, Column) else Column(c.name, c.type)
              for c in query.columns
          ]
-         if not [c for c in columns if c.name != "id"]:
+         if not [c for c in columns if c.name != "sys__id"]:
              raise RuntimeError(
                  "No columns to save in the query. "
                  "Ensure at least one column (other than 'id') is selected."
@@ -1742,11 +1646,11 @@ class DatasetQuery:
 
          # Exclude the id column and let the db create it to avoid unique
          # constraint violations.
-         q = query.exclude(("id",))
+         q = query.exclude(("sys__id",))
          if q._order_by_clauses:
              # ensuring we have id sorted by order by clause if it exists in a query
              q = q.add_columns(
-                 f.row_number().over(order_by=q._order_by_clauses).label("id")
+                 f.row_number().over(order_by=q._order_by_clauses).label("sys__id")
              )
 
          cols = tuple(c.name for c in q.columns)
@@ -1873,34 +1777,3 @@ def _random_string(length: int) -> str:
          random.choice(string.ascii_letters + string.digits)  # noqa: S311
          for i in range(length)
      )
-
-
- def _feature_predicate(obj):
-     from datachain.lib.feature import Feature
-
-     return inspect.isclass(obj) and source.isfrommain(obj) and issubclass(obj, Feature)
-
-
- def _imports(obj):
-     return not source.isfrommain(obj)
-
-
- def get_imports(m):
-     root = ast.parse(inspect.getsource(m))
-
-     for node in ast.iter_child_nodes(root):
-         if isinstance(node, ast.Import):
-             module = None
-         elif isinstance(node, ast.ImportFrom):
-             module = node.module
-         else:
-             continue
-
-         for n in node.names:
-             import_script = ""
-             if module:
-                 import_script += f"from {module} "
-             import_script += f"import {n.name}"
-             if n.asname:
-                 import_script += f" as {n.asname}"
-             yield import_script
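
to_pandas() and show() (together with the pandas dependency) left DatasetQuery; the higher-level API in datachain/lib/dc.py is the intended home for display now. A small sketch of framing to_records() output by hand; the records and the "__" delimiter handling mirror the removed method but are illustrative only:

import pandas as pd

# Hypothetical records shaped like DatasetQuery.to_records() output,
# with "__"-delimited (DEFAULT_DELIMITER) flattened column names.
records = [
    {"sys__id": 1, "file__path": "dogs/1.jpg", "file__size": 43},
    {"sys__id": 2, "file__path": "cats/1.jpg", "file__size": 61},
]

df = pd.DataFrame.from_records(records)
df.columns = [c.replace("__", ".") for c in df.columns]  # what to_pandas() did
print(df)
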
datachain/query/dispatch.py CHANGED
@@ -10,13 +10,12 @@ from typing import Any, Optional
 
  import attrs
  import multiprocess
- from dill import load
+ from cloudpickle import load, loads
  from fsspec.callbacks import DEFAULT_CALLBACK, Callback
  from multiprocess import get_context
 
  from datachain.catalog import Catalog
  from datachain.catalog.loader import get_distributed_class
- from datachain.query.batch import RowBatch
  from datachain.query.dataset import (
      get_download_callback,
      get_generated_callback,
@@ -85,7 +84,7 @@ def put_into_queue(queue: Queue, item: Any) -> None:
 
  def udf_entrypoint() -> int:
      # Load UDF info from stdin
-     udf_info = load(stdin.buffer)  # noqa: S301
+     udf_info = load(stdin.buffer)
 
      (
          warehouse_class,
@@ -96,7 +95,7 @@ def udf_entrypoint() -> int:
 
      # Parallel processing (faster for more CPU-heavy UDFs)
      dispatch = UDFDispatcher(
-         udf_info["udf"],
+         udf_info["udf_data"],
          udf_info["catalog_init"],
          udf_info["id_generator_clone_params"],
          udf_info["metastore_clone_params"],
@@ -109,7 +108,7 @@ def udf_entrypoint() -> int:
      batching = udf_info["batching"]
      table = udf_info["table"]
      n_workers = udf_info["processes"]
-     udf = udf_info["udf"]
+     udf = loads(udf_info["udf_data"])
      if n_workers is True:
          # Use default number of CPUs (cores)
          n_workers = None
@@ -147,7 +146,7 @@ class UDFDispatcher:
 
      def __init__(
          self,
-         udf,
+         udf_data,
          catalog_init_params,
          id_generator_clone_params,
          metastore_clone_params,
@@ -156,14 +155,7 @@ class UDFDispatcher:
          is_generator=False,
          buffer_size=DEFAULT_BATCH_SIZE,
      ):
-         # isinstance cannot be used here, as dill packages the entire class definition,
-         # and so these two types are not considered exactly equal,
-         # even if they have the same import path.
-         if full_module_type_path(type(udf)) != full_module_type_path(UDFFactory):
-             self.udf = udf
-         else:
-             self.udf = None
-             self.udf_factory = udf
+         self.udf_data = udf_data
          self.catalog_init_params = catalog_init_params
          (
              self.id_generator_class,
@@ -215,6 +207,15 @@ class UDFDispatcher:
          self.catalog = Catalog(
              id_generator, metastore, warehouse, **self.catalog_init_params
          )
+         udf = loads(self.udf_data)
+         # isinstance cannot be used here, as cloudpickle packages the entire class
+         # definition, and so these two types are not considered exactly equal,
+         # even if they have the same import path.
+         if full_module_type_path(type(udf)) != full_module_type_path(UDFFactory):
+             self.udf = udf
+         else:
+             self.udf = None
+             self.udf_factory = udf
          if not self.udf:
              self.udf = self.udf_factory()
 
@@ -355,6 +356,15 @@ class WorkerCallback(Callback):
          put_into_queue(self.queue, {"status": NOTIFY_STATUS, "downloaded": inc})
 
 
+ class ProcessedCallback(Callback):
+     def __init__(self):
+         self.processed_rows: Optional[int] = None
+         super().__init__()
+
+     def relative_update(self, inc: int = 1) -> None:
+         self.processed_rows = inc
+
+
  @attrs.define
  class UDFWorker:
      catalog: Catalog
@@ -370,25 +380,28 @@ class UDFWorker:
          return WorkerCallback(self.done_queue)
 
      def run(self) -> None:
-         if hasattr(self.udf.func, "setup") and callable(self.udf.func.setup):
-             self.udf.func.setup()
-         while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
-             n_rows = len(batch.rows) if isinstance(batch, RowBatch) else 1
-             udf_output = self.udf(
-                 self.catalog,
-                 batch,
-                 is_generator=self.is_generator,
-                 cache=self.cache,
-                 cb=self.cb,
-             )
+         processed_cb = ProcessedCallback()
+         udf_results = self.udf.run(
+             self.get_inputs(),
+             self.catalog,
+             self.is_generator,
+             self.cache,
+             download_cb=self.cb,
+             processed_cb=processed_cb,
+         )
+         for udf_output in udf_results:
              if isinstance(udf_output, GeneratorType):
                  udf_output = list(udf_output)  # can not pickle generator
              put_into_queue(
                  self.done_queue,
-                 {"status": OK_STATUS, "result": udf_output, "processed": n_rows},
+                 {
+                     "status": OK_STATUS,
+                     "result": udf_output,
+                     "processed": processed_cb.processed_rows,
+                 },
              )
-
-         if hasattr(self.udf.func, "teardown") and callable(self.udf.func.teardown):
-             self.udf.func.teardown()
-
          put_into_queue(self.done_queue, {"status": FINISHED_STATUS})
+
+     def get_inputs(self):
+         while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
+             yield batch
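
The worker no longer loops over batches itself; it hands UDF.run() a generator that drains the task queue until the stop sentinel arrives. A self-contained sketch of that queue-to-generator pattern; the sentinel value and queue type are stand-ins for the module's own:

import queue

STOP_SIGNAL = "STOP"  # stand-in for the dispatcher's real sentinel constant


def get_inputs(task_queue: "queue.Queue"):
    # Same shape as UDFWorker.get_inputs(): stream batches lazily so that
    # UDF.run() can consume them as an iterator and report progress per batch.
    while (batch := task_queue.get()) != STOP_SIGNAL:
        yield batch


q: "queue.Queue" = queue.Queue()
for item in (["row-1"], ["row-2"], STOP_SIGNAL):
    q.put(item)

assert list(get_inputs(q)) == [["row-1"], ["row-2"]]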