datachain 0.2.13__py3-none-any.whl → 0.2.15__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registries, and is provided for informational purposes only.

```diff
@@ -421,10 +421,6 @@ class AbstractMetastore(ABC, Serializable):
     ) -> None:
         """Set the status of the given job and dataset."""
 
-    @abstractmethod
-    def get_possibly_stale_jobs(self) -> list[tuple[str, str, int]]:
-        """Returns the possibly stale jobs."""
-
 
 class AbstractDBMetastore(AbstractMetastore):
     """
```
```diff
@@ -19,8 +19,12 @@ from datachain.sql.types import Int, SQLType, UInt64
 if TYPE_CHECKING:
     from sqlalchemy import Engine
     from sqlalchemy.engine.interfaces import Dialect
-    from sqlalchemy.sql.base import Executable, ReadOnlyColumnCollection
-    from sqlalchemy.sql.elements import KeyedColumnElement
+    from sqlalchemy.sql.base import (
+        ColumnCollection,
+        Executable,
+        ReadOnlyColumnCollection,
+    )
+    from sqlalchemy.sql.elements import ColumnElement
 
 
 def dedup_columns(columns: Iterable[sa.Column]) -> list[sa.Column]:
@@ -43,7 +47,7 @@ def dedup_columns(columns: Iterable[sa.Column]) -> list[sa.Column]:
 
 
 def convert_rows_custom_column_types(
-    columns: "ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]",
+    columns: "ColumnCollection[str, ColumnElement[Any]]",
     rows: Iterator[tuple[Any, ...]],
     dialect: "Dialect",
 ):
```
```diff
@@ -42,6 +42,7 @@ if TYPE_CHECKING:
     from sqlalchemy.dialects.sqlite import Insert
     from sqlalchemy.schema import SchemaItem
     from sqlalchemy.sql.elements import ColumnClause, ColumnElement, TextClause
+    from sqlalchemy.sql.selectable import Select
     from sqlalchemy.types import TypeEngine
 
 
@@ -496,9 +497,6 @@ class SQLiteMetastore(AbstractDBMetastore):
     def _jobs_insert(self) -> "Insert":
         return sqlite.insert(self._jobs)
 
-    def get_possibly_stale_jobs(self) -> list[tuple[str, str, int]]:
-        raise NotImplementedError("get_possibly_stale_jobs not implemented for SQLite")
-
 
 class SQLiteWarehouse(AbstractWarehouse):
     """
@@ -594,7 +592,7 @@ class SQLiteWarehouse(AbstractWarehouse):
     ):
         rows = self.db.execute(select_query, **kwargs)
         yield from convert_rows_custom_column_types(
-            select_query.columns, rows, sqlite_dialect
+            select_query.selected_columns, rows, sqlite_dialect
         )
 
     def get_dataset_sources(
```
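The recurring `.columns` → `.selected_columns` rename in this release tracks SQLAlchemy's API: on a `Select`, the `.c`/`.columns` accessor was deprecated in 1.4 and removed in 2.0 in favor of `.selected_columns`, which also avoids implicitly treating the query as a subquery. A minimal standalone sketch of the difference (illustrative table, not datachain code):

```python
import sqlalchemy as sa

t = sa.table("nodes", sa.column("id"), sa.column("name"))
query = sa.select(t.c.id, t.c.name)

# Pre-1.4 style (removed in SQLAlchemy 2.0): query.columns / query.c
# 1.4+ style, as used throughout this release:
names = [c.name for c in query.selected_columns]
print(names)  # ['id', 'name']
```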
```diff
@@ -708,3 +706,23 @@ class SQLiteWarehouse(AbstractWarehouse):
         client_config=None,
     ) -> list[str]:
         raise NotImplementedError("Exporting dataset table not implemented for SQLite")
+
+    def create_pre_udf_table(self, query: "Select") -> "Table":
+        """
+        Create a temporary table from a query for use in a UDF.
+        """
+        columns = [
+            sqlalchemy.Column(c.name, c.type)
+            for c in query.selected_columns
+            if c.name != "sys__id"
+        ]
+        table = self.create_udf_table(columns)
+
+        select_q = query.with_only_columns(
+            *[c for c in query.selected_columns if c.name != "sys__id"]
+        )
+        self.db.execute(
+            table.insert().from_select(list(select_q.selected_columns), select_q)
+        )
+
+        return table
```
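The new SQLite `create_pre_udf_table` materializes an arbitrary `SELECT` into a scratch table: it copies the query's columns minus `sys__id` (so the new table can assign its own primary keys), creates a UDF table with those columns, and bulk-loads it with `INSERT ... FROM SELECT`. A standalone sketch of the same pattern in plain SQLAlchemy against in-memory SQLite (table and column names here are illustrative, not from the source):

```python
import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
meta = sa.MetaData()
src = sa.Table(
    "src",
    meta,
    sa.Column("sys__id", sa.Integer, primary_key=True),
    sa.Column("name", sa.Text),
)
meta.create_all(engine)

query = sa.select(src)

# Copy the query's columns, dropping sys__id so the scratch table
# can generate its own ids.
cols = [
    sa.Column(c.name, c.type)
    for c in query.selected_columns
    if c.name != "sys__id"
]
scratch = sa.Table(
    "udf_scratch",
    meta,
    sa.Column("sys__id", sa.Integer, primary_key=True),
    *cols,
)
scratch.create(engine)

select_q = query.with_only_columns(
    *[c for c in query.selected_columns if c.name != "sys__id"]
)
with engine.begin() as conn:
    conn.execute(sa.insert(src), [{"name": "a"}, {"name": "b"}])
    # INSERT INTO udf_scratch (name) SELECT name FROM src
    conn.execute(
        scratch.insert().from_select(list(select_q.selected_columns), select_q)
    )
    print(conn.execute(sa.select(scratch)).all())  # [(1, 'a'), (2, 'b')]
```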
```diff
@@ -2,6 +2,8 @@ import glob
 import json
 import logging
 import posixpath
+import random
+import string
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Iterable, Iterator, Sequence
 from typing import TYPE_CHECKING, Any, Optional, Union
@@ -24,6 +26,7 @@ from datachain.utils import sql_escape_like
 if TYPE_CHECKING:
     from sqlalchemy.sql._typing import _ColumnsClauseArgument
     from sqlalchemy.sql.elements import ColumnElement
+    from sqlalchemy.sql.selectable import Select
     from sqlalchemy.types import TypeEngine
 
     from datachain.data_storage import AbstractIDGenerator, schema
@@ -252,6 +255,12 @@ class AbstractWarehouse(ABC, Serializable):
         prefix = self.DATASET_SOURCE_TABLE_PREFIX
         return f"{prefix}{dataset_name}_{version}"
 
+    def temp_table_name(self) -> str:
+        return self.TMP_TABLE_NAME_PREFIX + _random_string(6)
+
+    def udf_table_name(self) -> str:
+        return self.UDF_TABLE_NAME_PREFIX + _random_string(6)
+
     #
     # Datasets
     #
```
```diff
@@ -494,7 +503,7 @@ class AbstractWarehouse(ABC, Serializable):
         This gets nodes based on the provided query, and should be used sparingly,
         as it will be slow on any OLAP database systems.
         """
-        columns = [c.name for c in query.columns]
+        columns = [c.name for c in query.selected_columns]
         for row in self.db.execute(query):
             d = dict(zip(columns, row))
             yield Node(**d)
@@ -869,8 +878,8 @@ class AbstractWarehouse(ABC, Serializable):
 
     def create_udf_table(
         self,
-        name: str,
         columns: Sequence["sa.Column"] = (),
+        name: Optional[str] = None,
     ) -> "sa.Table":
         """
         Create a temporary table for storing custom signals generated by a UDF.
@@ -878,7 +887,7 @@ class AbstractWarehouse(ABC, Serializable):
         and UDFs are run in other processes when run in parallel.
         """
         tbl = sa.Table(
-            name,
+            name or self.udf_table_name(),
             sa.MetaData(),
             sa.Column("sys__id", Int, primary_key=True),
             *columns,
@@ -886,6 +895,12 @@ class AbstractWarehouse(ABC, Serializable):
         self.db.create_table(tbl, if_not_exists=True)
         return tbl
 
+    @abstractmethod
+    def create_pre_udf_table(self, query: "Select") -> "Table":
+        """
+        Create a temporary table from a query for use in a UDF.
+        """
+
     def is_temp_table_name(self, name: str) -> bool:
         """Returns if the given table name refers to a temporary
         or no longer needed table."""
```
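`create_udf_table` flips its parameters: `columns` now comes first and `name` is optional, falling back to a generated `udf_<suffix>` name from the new `udf_table_name` helper. A self-contained sketch of the new calling convention (using plain `sqlalchemy.Integer` in place of datachain's `Int` type, and a free function in place of the method):

```python
import random
import string
from collections.abc import Sequence
from typing import Optional

import sqlalchemy as sa

UDF_TABLE_NAME_PREFIX = "udf_"  # assumption: mirrors the class constant

def _random_string(length: int) -> str:
    return "".join(
        random.choice(string.ascii_letters + string.digits)  # noqa: S311
        for _ in range(length)
    )

def create_udf_table(
    columns: Sequence[sa.Column] = (),
    name: Optional[str] = None,
) -> sa.Table:
    # name is optional now; a "udf_XXXXXX" name is generated on demand
    return sa.Table(
        name or UDF_TABLE_NAME_PREFIX + _random_string(6),
        sa.MetaData(),
        sa.Column("sys__id", sa.Integer, primary_key=True),
        *columns,
    )

tbl = create_udf_table([sa.Column("score", sa.Float)])
print(tbl.name)  # e.g. 'udf_k3X9aZ'
tbl2 = create_udf_table([sa.Column("score", sa.Float)], name="udf_custom")
print(tbl2.name)  # 'udf_custom'
```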
```diff
@@ -912,29 +927,6 @@ class AbstractWarehouse(ABC, Serializable):
         for name in names:
             self.db.drop_table(Table(name, self.db.metadata), if_exists=True)
 
-    def subtract_query(
-        self,
-        source_query: sa.sql.selectable.Select,
-        target_query: sa.sql.selectable.Select,
-    ) -> sa.sql.selectable.Select:
-        sq = source_query.alias("source_query")
-        tq = target_query.alias("target_query")
-
-        source_target_join = sa.join(
-            sq,
-            tq,
-            (sq.c.source == tq.c.source)
-            & (sq.c.parent == tq.c.parent)
-            & (sq.c.name == tq.c.name),
-            isouter=True,
-        )
-
-        return (
-            select(*sq.c)
-            .select_from(source_target_join)
-            .where((tq.c.name == None) | (tq.c.name == ""))  # noqa: E711
-        )
-
     def changed_query(
         self,
         source_query: sa.sql.selectable.Select,
@@ -960,3 +952,10 @@ class AbstractWarehouse(ABC, Serializable):
                 & (tq.c.is_latest == true())
             )
         )
+
+
+def _random_string(length: int) -> str:
+    return "".join(
+        random.choice(string.ascii_letters + string.digits)  # noqa: S311
+        for i in range(length)
+    )
```
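One note on the design (my arithmetic, not a claim from the source): the suffix only needs to be distinct, not unpredictable, which is why plain `random` with a `# noqa: S311` suppression of bandit's pseudo-random warning is acceptable here.

```python
# 62 characters over 6 positions gives ~5.7e10 distinct suffixes,
# so collisions between short-lived temp/UDF tables are unlikely.
print(62 ** 6)  # 56800235584
```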
datachain/lib/arrow.py CHANGED
```diff
@@ -10,13 +10,17 @@ from datachain.lib.file import File, IndexedFile
 from datachain.lib.udf import Generator
 
 if TYPE_CHECKING:
+    from pydantic import BaseModel
+
     from datachain.lib.dc import DataChain
 
 
 class ArrowGenerator(Generator):
     def __init__(
         self,
-        schema: Optional["pa.Schema"] = None,
+        input_schema: Optional["pa.Schema"] = None,
+        output_schema: Optional[type["BaseModel"]] = None,
+        source: bool = True,
         nrows: Optional[int] = None,
         **kwargs,
     ):
@@ -25,24 +29,36 @@ class ArrowGenerator(Generator):
 
         Parameters:
 
-        schema : Optional pyarrow schema for validation.
+        input_schema : Optional pyarrow schema for validation.
+        output_schema : Optional pydantic model for validation.
+        source : Whether to include info about the source file.
         nrows : Optional row limit.
         kwargs: Parameters to pass to pyarrow.dataset.dataset.
         """
         super().__init__()
-        self.schema = schema
+        self.input_schema = input_schema
+        self.output_schema = output_schema
+        self.source = source
         self.nrows = nrows
         self.kwargs = kwargs
 
     def process(self, file: File):
         path = file.get_path()
-        ds = dataset(path, filesystem=file.get_fs(), schema=self.schema, **self.kwargs)
+        ds = dataset(
+            path, filesystem=file.get_fs(), schema=self.input_schema, **self.kwargs
+        )
         index = 0
         with tqdm(desc="Parsed by pyarrow", unit=" rows") as pbar:
-            for record_batch in ds.to_batches():
+            for record_batch in ds.to_batches(use_threads=False):
                 for record in record_batch.to_pylist():
-                    source = IndexedFile(file=file, index=index)
-                    yield [source, *record.values()]
+                    vals = list(record.values())
+                    if self.output_schema:
+                        fields = self.output_schema.model_fields
+                        vals = [self.output_schema(**dict(zip(fields, vals)))]
+                    if self.source:
+                        yield [IndexedFile(file=file, index=index), *vals]
+                    else:
+                        yield vals
                     index += 1
                     if self.nrows and index >= self.nrows:
                         return
```
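Two behavioral changes ride along with the `schema` → `input_schema` rename: batches are now read with `use_threads=False`, presumably so rows arrive in file order and `index` stays aligned with row position, and the yield path can pack each row into a pydantic `output_schema` model and drop the `IndexedFile` wrapper when `source=False`. A distilled sketch of just the new yield logic, using hypothetical records in place of a real pyarrow dataset:

```python
from pydantic import BaseModel

class Row(BaseModel):
    name: str
    age: int

records = [{"name": "alice", "age": 30}, {"name": "bob", "age": 41}]
output_schema = Row
source = False

for index, record in enumerate(records):
    vals = list(record.values())
    if output_schema:
        # zip row values against the model's declared fields, in order
        fields = output_schema.model_fields
        vals = [output_schema(**dict(zip(fields, vals)))]
    row = [("file-info-placeholder", index), *vals] if source else vals
    print(row)
# [Row(name='alice', age=30)]
# [Row(name='bob', age=41)]
```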
@@ -76,7 +92,10 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = Non
76
92
  if not column:
77
93
  column = f"c{default_column}"
78
94
  default_column += 1
79
- output[column] = _arrow_type_mapper(field.type) # type: ignore[assignment]
95
+ dtype = _arrow_type_mapper(field.type) # type: ignore[assignment]
96
+ if field.nullable:
97
+ dtype = Optional[dtype] # type: ignore[assignment]
98
+ output[column] = dtype
80
99
 
81
100
  return output
82
101
 
```diff
@@ -41,17 +41,22 @@ def flatten_list(obj_list):
     )
 
 
+def _flatten_list_field(value: list):
+    assert isinstance(value, list)
+    if value and ModelStore.is_pydantic(type(value[0])):
+        return [val.model_dump() for val in value]
+    if value and isinstance(value[0], list):
+        return [_flatten_list_field(v) for v in value]
+    return value
+
+
 def _flatten_fields_values(fields, obj: BaseModel):
     for name, f_info in fields.items():
         anno = f_info.annotation
         # Optimization: Access attributes directly to skip the model_dump() call.
         value = getattr(obj, name)
-
         if isinstance(value, list):
-            yield [
-                val.model_dump() if ModelStore.is_pydantic(type(val)) else val
-                for val in value
-            ]
+            yield _flatten_list_field(value)
         elif isinstance(value, dict):
             yield {
                 key: val.model_dump() if ModelStore.is_pydantic(type(val)) else val
```
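The extracted `_flatten_list_field` adds one capability the old inline comprehension lacked: it recurses into lists of lists, so nested collections of pydantic models are dumped at every level. A standalone illustration, with a local `is_pydantic` standing in for `ModelStore.is_pydantic`:

```python
from pydantic import BaseModel

class Point(BaseModel):
    x: int
    y: int

def is_pydantic(t) -> bool:  # stand-in for ModelStore.is_pydantic
    return isinstance(t, type) and issubclass(t, BaseModel)

def _flatten_list_field(value: list):
    assert isinstance(value, list)
    if value and is_pydantic(type(value[0])):
        return [val.model_dump() for val in value]
    if value and isinstance(value[0], list):
        return [_flatten_list_field(v) for v in value]
    return value

nested = [[Point(x=1, y=2)], [Point(x=3, y=4)]]
print(_flatten_list_field(nested))
# [[{'x': 1, 'y': 2}], [{'x': 3, 'y': 4}]]
```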
```diff
@@ -82,7 +82,7 @@ def python_to_sql(typ):  # noqa: PLR0911
 def _is_json_inside_union(orig, args) -> bool:
     if orig == Union and len(args) >= 2:
         # List in JSON: Union[dict, list[dict]]
-        args_no_nones = [arg for arg in args if arg != type(None)]
+        args_no_nones = [arg for arg in args if arg != type(None)]  # noqa: E721
         if len(args_no_nones) == 2:
             args_no_dicts = [arg for arg in args_no_nones if arg is not dict]
             if len(args_no_dicts) == 1 and get_origin(args_no_dicts[0]) is list:
```
```diff
@@ -71,7 +71,10 @@ def values_to_tuples(  # noqa: C901, PLR0912
                 f"signal '{k}' has unsupported type '{typ.__name__}'."
                 f" Please use DataModel types: {DataTypeNames}",
             )
-        types_map[k] = typ
+        if typ is list:
+            types_map[k] = list[type(v[0][0])]  # type: ignore[misc]
+        else:
+            types_map[k] = typ
 
     if length < 0:
         length = len_
```
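For signals typed as a bare `list`, `values_to_tuples` now samples the first element of the first value to produce a parametrized annotation such as `list[str]`. Note that this, like the change itself, assumes the first list is non-empty. A toy reconstruction of the inference:

```python
values = {"tags": [["a", "b"], ["c"]]}  # hypothetical signal values

types_map = {}
for k, v in values.items():
    typ = type(v[0])
    if typ is list:
        # sample the element type from the first element of the first row
        types_map[k] = list[type(v[0][0])]
    else:
        types_map[k] = typ

print(types_map)  # {'tags': list[str]}
```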
```diff
@@ -47,7 +47,12 @@ def is_chain_type(t: type) -> bool:
     if any(t is ft or t is get_args(ft)[0] for ft in get_args(StandardType)):
         return True
 
-    if get_origin(t) is list and len(get_args(t)) == 1:
+    orig = get_origin(t)
+    args = get_args(t)
+    if orig is list and len(args) == 1:
         return is_chain_type(get_args(t)[0])
 
+    if orig is Union and len(args) == 2 and (type(None) in args):
+        return is_chain_type(args[0])
+
     return False
```
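`is_chain_type` now accepts `Optional[T]`: since `Optional[T]` is `Union[T, None]` with exactly two args, the new branch recurses into the non-`None` member. A trimmed-down sketch with a stand-in for the `StandardType` check:

```python
from typing import Optional, Union, get_args, get_origin

def is_chain_type(t: type) -> bool:
    if t in (int, str, float, bool):  # stand-in for the StandardType check
        return True

    orig = get_origin(t)
    args = get_args(t)
    if orig is list and len(args) == 1:
        return is_chain_type(args[0])

    # Optional[T] == Union[T, None]; recurse into T
    if orig is Union and len(args) == 2 and (type(None) in args):
        return is_chain_type(args[0])

    return False

print(is_chain_type(Optional[str]))        # True
print(is_chain_type(list[Optional[int]]))  # True
print(is_chain_type(dict))                 # False
```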