datachain 0.16.4__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of datachain might be problematic.

Files changed (39)
  1. datachain/catalog/catalog.py +25 -92
  2. datachain/cli/__init__.py +11 -9
  3. datachain/cli/commands/datasets.py +1 -1
  4. datachain/cli/commands/query.py +1 -0
  5. datachain/cli/commands/show.py +1 -1
  6. datachain/cli/parser/__init__.py +11 -3
  7. datachain/data_storage/job.py +1 -0
  8. datachain/data_storage/metastore.py +105 -94
  9. datachain/data_storage/sqlite.py +8 -7
  10. datachain/data_storage/warehouse.py +58 -46
  11. datachain/dataset.py +88 -45
  12. datachain/lib/arrow.py +23 -1
  13. datachain/lib/dataset_info.py +2 -1
  14. datachain/lib/dc/csv.py +1 -0
  15. datachain/lib/dc/datachain.py +38 -16
  16. datachain/lib/dc/datasets.py +28 -7
  17. datachain/lib/dc/storage.py +10 -2
  18. datachain/lib/listing.py +2 -0
  19. datachain/lib/pytorch.py +2 -2
  20. datachain/lib/udf.py +17 -5
  21. datachain/listing.py +1 -1
  22. datachain/query/batch.py +40 -39
  23. datachain/query/dataset.py +42 -41
  24. datachain/query/dispatch.py +137 -75
  25. datachain/query/metrics.py +1 -2
  26. datachain/query/queue.py +1 -11
  27. datachain/query/session.py +2 -2
  28. datachain/query/udf.py +1 -1
  29. datachain/query/utils.py +8 -14
  30. datachain/remote/studio.py +4 -4
  31. datachain/semver.py +58 -0
  32. datachain/studio.py +1 -1
  33. datachain/utils.py +3 -0
  34. {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/METADATA +1 -1
  35. {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/RECORD +39 -38
  36. {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/WHEEL +1 -1
  37. {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/entry_points.txt +0 -0
  38. {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/licenses/LICENSE +0 -0
  39. {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/top_level.txt +0 -0
datachain/data_storage/warehouse.py CHANGED
@@ -11,16 +11,15 @@ from urllib.parse import urlparse

  import attrs
  import sqlalchemy as sa
- from sqlalchemy import Table, case, select
- from sqlalchemy.sql import func
  from sqlalchemy.sql.expression import true
- from tqdm.auto import tqdm

  from datachain.client import Client
  from datachain.data_storage.schema import convert_rows_custom_column_types
  from datachain.data_storage.serializer import Serializable
  from datachain.dataset import DatasetRecord, StorageURI
  from datachain.node import DirType, DirTypeGroup, Node, NodeWithPath, get_path
+ from datachain.query.batch import RowsOutput
+ from datachain.query.utils import get_query_id_column
  from datachain.sql.functions import path as pathfunc
  from datachain.sql.types import Int, SQLType
  from datachain.utils import sql_escape_like
@@ -31,7 +30,6 @@ if TYPE_CHECKING:
  _FromClauseArgument,
  _OnClauseArgument,
  )
- from sqlalchemy.sql.selectable import Select
  from sqlalchemy.types import TypeEngine

  from datachain.data_storage import schema
@@ -178,7 +176,7 @@ class AbstractWarehouse(ABC, Serializable):
  def dataset_rows(
  self,
  dataset: DatasetRecord,
- version: Optional[int] = None,
+ version: Optional[str] = None,
  column: str = "file",
  ):
  version = version or dataset.latest_version
@@ -199,13 +197,13 @@ class AbstractWarehouse(ABC, Serializable):
  # Query Execution
  #

- def query_count(self, query: sa.sql.selectable.Select) -> int:
+ def query_count(self, query: sa.Select) -> int:
  """Count the number of rows in a query."""
- count_query = sa.select(func.count(1)).select_from(query.subquery())
+ count_query = sa.select(sa.func.count(1)).select_from(query.subquery())
  return next(self.db.execute(count_query))[0]

  def table_rows_count(self, table) -> int:
- count_query = sa.select(func.count(1)).select_from(table)
+ count_query = sa.select(sa.func.count(1)).select_from(table)
  return next(self.db.execute(count_query))[0]

  def dataset_select_paginated(
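The `query_count` change above only moves from the bare `func`/`select` imports to the `sa.` namespace; the counting pattern itself is unchanged. A standalone sketch of that pattern (plain SQLAlchemy with an in-memory SQLite table as a stand-in, not datachain code):

```python
import sqlalchemy as sa

# Standalone sketch of the count-over-subquery pattern used by query_count().
engine = sa.create_engine("sqlite:///:memory:")
metadata = sa.MetaData()
items = sa.Table("items", metadata, sa.Column("id", sa.Integer, primary_key=True))
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(sa.insert(items), [{"id": i} for i in range(5)])
    query = sa.select(items).where(items.c.id > 1)
    # Wrap the original query as a subquery and count its rows.
    count_query = sa.select(sa.func.count(1)).select_from(query.subquery())
    print(conn.execute(count_query).scalar())  # 3
```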
@@ -255,7 +253,7 @@ class AbstractWarehouse(ABC, Serializable):
  name = parsed.path if parsed.scheme == "file" else parsed.netloc
  return parsed.scheme, name

- def dataset_table_name(self, dataset_name: str, version: int) -> str:
+ def dataset_table_name(self, dataset_name: str, version: str) -> str:
  prefix = self.DATASET_TABLE_PREFIX
  if Client.is_data_source_uri(dataset_name):
  # for datasets that are created for bucket listing we use different prefix
@@ -278,18 +276,18 @@ class AbstractWarehouse(ABC, Serializable):
  name: str,
  columns: Sequence["sa.Column"] = (),
  if_not_exists: bool = True,
- ) -> Table:
+ ) -> sa.Table:
  """Creates a dataset rows table for the given dataset name and columns"""

  def drop_dataset_rows_table(
  self,
  dataset: DatasetRecord,
- version: int,
+ version: str,
  if_exists: bool = True,
  ) -> None:
  """Drops a dataset rows table for the given dataset name."""
  table_name = self.dataset_table_name(dataset.name, version)
- table = Table(table_name, self.db.metadata)
+ table = sa.Table(table_name, self.db.metadata)
  self.db.drop_table(table, if_exists=if_exists)

  @abstractmethod
@@ -297,8 +295,8 @@ class AbstractWarehouse(ABC, Serializable):
  self,
  src: "DatasetRecord",
  dst: "DatasetRecord",
- src_version: int,
- dst_version: int,
+ src_version: str,
+ dst_version: str,
  ) -> None:
  """
  Merges source dataset rows and current latest destination dataset rows
@@ -309,7 +307,7 @@ class AbstractWarehouse(ABC, Serializable):

  def dataset_rows_select(
  self,
- query: sa.sql.selectable.Select,
+ query: sa.Select,
  **kwargs,
  ) -> Iterator[tuple[Any, ...]]:
  """
@@ -320,17 +318,35 @@ class AbstractWarehouse(ABC, Serializable):
  query.selected_columns, rows, self.db.dialect
  )

+ def dataset_rows_select_from_ids(
+ self,
+ query: sa.Select,
+ ids: Iterable[RowsOutput],
+ is_batched: bool,
+ ) -> Iterator[RowsOutput]:
+ """
+ Fetch dataset rows from database using a list of IDs.
+ """
+ if (id_col := get_query_id_column(query)) is None:
+ raise RuntimeError("sys__id column not found in query")
+
+ if is_batched:
+ for batch in ids:
+ yield list(self.dataset_rows_select(query.where(id_col.in_(batch))))
+ else:
+ yield from self.dataset_rows_select(query.where(id_col.in_(ids)))
+
  @abstractmethod
  def get_dataset_sources(
- self, dataset: DatasetRecord, version: int
+ self, dataset: DatasetRecord, version: str
  ) -> list[StorageURI]: ...

  def rename_dataset_table(
  self,
  old_name: str,
  new_name: str,
- old_version: int,
- new_version: int,
+ old_version: str,
+ new_version: str,
  ) -> None:
  old_ds_table_name = self.dataset_table_name(old_name, old_version)
  new_ds_table_name = self.dataset_table_name(new_name, new_version)
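The new `dataset_rows_select_from_ids` helper narrows an existing select to a set of `sys__id` values, either with a single `IN (...)` filter or with one query per batch when `is_batched` is set. A rough standalone illustration of that behaviour (plain SQLAlchemy and an in-memory table, not the datachain classes themselves):

```python
import sqlalchemy as sa

# Illustration of the is_batched behaviour: batched input yields one list of
# rows per batch, non-batched input yields individual rows from a single query.
engine = sa.create_engine("sqlite:///:memory:")
metadata = sa.MetaData()
rows = sa.Table(
    "rows", metadata,
    sa.Column("sys__id", sa.Integer, primary_key=True),
    sa.Column("name", sa.String),
)
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(sa.insert(rows), [{"sys__id": i, "name": f"row{i}"} for i in range(6)])
    query = sa.select(rows)
    id_col = rows.c.sys__id

    # Non-batched: a single IN (...) filter over all IDs.
    flat = conn.execute(query.where(id_col.in_([1, 2, 5]))).all()

    # Batched: one query (and one result list) per batch of IDs.
    batches = [[0, 1], [4, 5]]
    grouped = [conn.execute(query.where(id_col.in_(batch))).all() for batch in batches]
```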
@@ -341,12 +357,12 @@ class AbstractWarehouse(ABC, Serializable):
  """Returns total number of rows in a dataset"""
  dr = self.dataset_rows(dataset, version)
  table = dr.get_table()
- query = select(sa.func.count(table.c.sys__id))
+ query = sa.select(sa.func.count(table.c.sys__id))
  (res,) = self.db.execute(query)
  return res[0]

  def dataset_stats(
- self, dataset: DatasetRecord, version: int
+ self, dataset: DatasetRecord, version: str
  ) -> tuple[Optional[int], Optional[int]]:
  """
  Returns tuple with dataset stats: total number of rows and total dataset size.
@@ -364,7 +380,7 @@ class AbstractWarehouse(ABC, Serializable):
  ]
  if size_columns:
  expressions = (*expressions, sa.func.sum(sum(size_columns)))
- query = select(*expressions)
+ query = sa.select(*expressions)
  ((nrows, *rest),) = self.db.execute(query)
  return nrows, rest[0] if rest else 0

@@ -373,17 +389,17 @@ class AbstractWarehouse(ABC, Serializable):
  """Convert File entries so they can be passed on to `insert_rows()`"""

  @abstractmethod
- def insert_rows(self, table: Table, rows: Iterable[dict[str, Any]]) -> None:
+ def insert_rows(self, table: sa.Table, rows: Iterable[dict[str, Any]]) -> None:
  """Does batch inserts of any kind of rows into table"""

- def insert_rows_done(self, table: Table) -> None:
+ def insert_rows_done(self, table: sa.Table) -> None:
  """
  Only needed for certain implementations
  to signal when rows inserts are complete.
  """

  @abstractmethod
- def insert_dataset_rows(self, df, dataset: DatasetRecord, version: int) -> int:
+ def insert_dataset_rows(self, df, dataset: DatasetRecord, version: str) -> int:
  """Inserts dataset rows directly into dataset table"""

  @abstractmethod
@@ -402,7 +418,7 @@ class AbstractWarehouse(ABC, Serializable):

  @abstractmethod
  def dataset_table_export_file_names(
- self, dataset: DatasetRecord, version: int
+ self, dataset: DatasetRecord, version: str
  ) -> list[str]:
  """
  Returns list of file names that will be created when user runs dataset export
@@ -413,7 +429,7 @@ class AbstractWarehouse(ABC, Serializable):
  self,
  bucket_uri: str,
  dataset: DatasetRecord,
- version: int,
+ version: str,
  client_config=None,
  ) -> list[str]:
  """
@@ -497,7 +513,7 @@ class AbstractWarehouse(ABC, Serializable):
  ).subquery()
  path_glob = "/".join([*path_list, glob_name])
  dirpath = path_glob[: -len(glob_name)]
- relpath = func.substr(de.c(q, "path"), len(dirpath) + 1)
+ relpath = sa.func.substr(de.c(q, "path"), len(dirpath) + 1)

  return self.get_nodes(
  self.expand_query(de, q, dr)
@@ -584,13 +600,13 @@ class AbstractWarehouse(ABC, Serializable):
  default = getattr(
  attrs.fields(Node), dr.without_object(column.name)
  ).default
- return func.coalesce(column, default).label(column.name)
+ return sa.func.coalesce(column, default).label(column.name)

  return sa.select(
  q.c.sys__id,
- case((de.c(q, "is_dir") == true(), DirType.DIR), else_=DirType.FILE).label(
- dr.col_name("dir_type")
- ),
+ sa.case(
+ (de.c(q, "is_dir") == true(), DirType.DIR), else_=DirType.FILE
+ ).label(dr.col_name("dir_type")),
  de.c(q, "path"),
  with_default(dr.c("etag")),
  de.c(q, "version"),
@@ -665,7 +681,7 @@ class AbstractWarehouse(ABC, Serializable):
  return de.c(inner_query, f)

  return self.db.execute(
- select(*(field_to_expr(f) for f in fields)).order_by(
+ sa.select(*(field_to_expr(f) for f in fields)).order_by(
  de.c(inner_query, "source"),
  de.c(inner_query, "path"),
  de.c(inner_query, "version"),
@@ -687,7 +703,7 @@ class AbstractWarehouse(ABC, Serializable):
  return dr.c(f)

  q = (
- select(*(field_to_expr(f) for f in fields))
+ sa.select(*(field_to_expr(f) for f in fields))
  .where(
  dr.c("path").like(f"{sql_escape_like(dirpath)}%"),
  ~self.instr(pathfunc.name(dr.c("path")), "/"),
@@ -722,10 +738,10 @@ class AbstractWarehouse(ABC, Serializable):
  sub_glob = posixpath.join(path, "*")
  dr = dataset_rows
  selections: list[sa.ColumnElement] = [
- func.sum(dr.c("size")),
+ sa.func.sum(dr.c("size")),
  ]
  if count_files:
- selections.append(func.count())
+ selections.append(sa.func.count())
  results = next(
  self.db.execute(
  dr.select(*selections).where(
@@ -842,7 +858,7 @@ class AbstractWarehouse(ABC, Serializable):
  self,
  columns: Sequence["sa.Column"] = (),
  name: Optional[str] = None,
- ) -> "sa.Table":
+ ) -> sa.Table:
  """
  Create a temporary table for storing custom signals generated by a UDF.
  SQLite TEMPORARY tables cannot be directly used as they are process-specific,
@@ -860,8 +876,8 @@ class AbstractWarehouse(ABC, Serializable):
  @abstractmethod
  def copy_table(
  self,
- table: Table,
- query: "Select",
+ table: sa.Table,
+ query: sa.Select,
  progress_cb: Optional[Callable[[int], None]] = None,
  ) -> None:
  """
@@ -875,13 +891,13 @@ class AbstractWarehouse(ABC, Serializable):
  right: "_FromClauseArgument",
  onclause: "_OnClauseArgument",
  inner: bool = True,
- ) -> "Select":
+ ) -> sa.Select:
  """
  Join two tables together.
  """

  @abstractmethod
- def create_pre_udf_table(self, query: "Select") -> "Table":
+ def create_pre_udf_table(self, query: sa.Select) -> sa.Table:
  """
  Create a temporary table from a query for use in a UDF.
  """
@@ -906,12 +922,8 @@ class AbstractWarehouse(ABC, Serializable):
  are cleaned up as soon as they are no longer needed.
  """
  to_drop = set(names)
- with tqdm(
- desc="Cleanup", unit=" tables", total=len(to_drop), leave=False
- ) as pbar:
- for name in to_drop:
- self.db.drop_table(Table(name, self.db.metadata), if_exists=True)
- pbar.update(1)
+ for name in to_drop:
+ self.db.drop_table(sa.Table(name, self.db.metadata), if_exists=True)


  def _random_string(length: int) -> str:
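One pattern worth calling out from this file: `dataset_stats` builds a single `sa.select` that returns the row count and, when size columns are present, their summed total. A standalone sketch of that query shape (in-memory SQLite with made-up column names, not datachain code):

```python
import sqlalchemy as sa

engine = sa.create_engine("sqlite:///:memory:")
metadata = sa.MetaData()
files = sa.Table(
    "files", metadata,
    sa.Column("sys__id", sa.Integer, primary_key=True),
    sa.Column("file__size", sa.Integer),
)
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(sa.insert(files), [{"sys__id": i, "file__size": 100 * i} for i in range(1, 4)])
    # One query returning both the row count and the total size.
    query = sa.select(sa.func.count(files.c.sys__id), sa.func.sum(files.c.file__size))
    nrows, total_size = conn.execute(query).one()
    print(nrows, total_size)  # 3 600
```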
datachain/dataset.py CHANGED
@@ -12,6 +12,7 @@ from typing import (
  )
  from urllib.parse import urlparse

+ from datachain import semver
  from datachain.error import DatasetVersionNotFoundError
  from datachain.sql.types import NAME_TYPES_MAPPING, SQLType

@@ -25,6 +26,8 @@ DATASET_PREFIX = "ds://"
  QUERY_DATASET_PREFIX = "ds_query_"
  LISTING_PREFIX = "lst__"

+ DEFAULT_DATASET_VERSION = "1.0.0"
+

  # StorageURI represents a normalised URI to a valid storage location (full bucket or
  # absolute local path).
@@ -33,12 +36,12 @@ LISTING_PREFIX = "lst__"
  StorageURI = NewType("StorageURI", str)


- def parse_dataset_uri(uri: str) -> tuple[str, Optional[int]]:
+ def parse_dataset_uri(uri: str) -> tuple[str, Optional[str]]:
  """
  Parse dataser uri to extract name and version out of it (if version is defined)
  Example:
- Input: ds://zalando@v3
- Output: (zalando, 3)
+ Input: ds://zalando@v3.0.1
+ Output: (zalando, 3.0.1)
  """
  p = urlparse(uri)
  if p.scheme != "ds":
@@ -51,16 +54,15 @@ def parse_dataset_uri(uri: str) -> tuple[str, Optional[int]]:
  raise Exception(
  "Wrong dataset uri format, it should be: ds://<name>@v<version>"
  )
- version = int(s[1])
- return name, version
+ return name, s[1]


- def create_dataset_uri(name: str, version: Optional[int] = None) -> str:
+ def create_dataset_uri(name: str, version: Optional[str] = None) -> str:
  """
  Creates a dataset uri based on dataset name and optionally version
  Example:
- Input: zalando, 3
- Output: ds//zalando@v3
+ Input: zalando, 3.0.1
+ Output: ds//zalando@v3.0.1
  """
  uri = f"{DATASET_PREFIX}{name}"
  if version:
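With this change the version part of a dataset URI is a full semver string after `@v` rather than a single integer. A short usage sketch, assuming the two functions behave as their docstrings above describe:

```python
from datachain.dataset import create_dataset_uri, parse_dataset_uri

# Version is now a semver string, not an int.
uri = create_dataset_uri("zalando", "3.0.1")
print(uri)                     # ds://zalando@v3.0.1
print(parse_dataset_uri(uri))  # ('zalando', '3.0.1')

# Without a version only the name is encoded / returned.
print(parse_dataset_uri(create_dataset_uri("zalando")))  # ('zalando', None)
```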
@@ -79,7 +81,7 @@ class DatasetDependency:
  id: int
  type: str
  name: str
- version: str # TODO change to int
+ version: str
  created_at: datetime
  dependencies: list[Optional["DatasetDependency"]]

@@ -102,7 +104,7 @@ class DatasetDependency:
  dataset_id: Optional[int],
  dataset_version_id: Optional[int],
  dataset_name: Optional[str],
- dataset_version: Optional[int],
+ dataset_version: Optional[str],
  dataset_version_created_at: Optional[datetime],
  ) -> Optional["DatasetDependency"]:
  from datachain.client import Client
@@ -124,7 +126,7 @@ class DatasetDependency:
  dependency_type,
  dependency_name,
  (
- str(dataset_version) # type: ignore[arg-type]
+ dataset_version # type: ignore[arg-type]
  if dataset_version
  else None
  ),
@@ -163,7 +165,7 @@ class DatasetVersion:
  id: int
  uuid: str
  dataset_id: int
- version: int
+ version: str
  status: int
  feature_schema: dict
  created_at: datetime
@@ -185,7 +187,7 @@ class DatasetVersion:
  id: int,
  uuid: str,
  dataset_id: int,
- version: int,
+ version: str,
  status: int,
  feature_schema: Optional[str],
  created_at: datetime,
@@ -222,6 +224,10 @@ class DatasetVersion:
  job_id,
  )

+ @property
+ def version_value(self) -> int:
+ return semver.value(self.version)
+
  def __eq__(self, other):
  if not isinstance(other, DatasetVersion):
  return False
@@ -230,7 +236,7 @@ class DatasetVersion:
  def __lt__(self, other):
  if not isinstance(other, DatasetVersion):
  return False
- return self.version < other.version
+ return self.version_value < other.version_value

  def __hash__(self):
  return hash(f"{self.dataset_id}_{self.version}")
@@ -275,7 +281,7 @@ class DatasetListVersion:
  id: int
  uuid: str
  dataset_id: int
- version: int
+ version: str
  status: int
  created_at: datetime
  finished_at: Optional[datetime]
@@ -292,7 +298,7 @@ class DatasetListVersion:
  id: int,
  uuid: str,
  dataset_id: int,
- version: int,
+ version: str,
  status: int,
  created_at: datetime,
  finished_at: Optional[datetime],
@@ -323,6 +329,10 @@ class DatasetListVersion:
  def __hash__(self):
  return hash(f"{self.dataset_id}_{self.version}")

+ @property
+ def version_value(self) -> int:
+ return semver.value(self.version)
+

  @dataclass
  class DatasetRecord:
@@ -371,7 +381,7 @@ class DatasetRecord:
  version_id: int,
  version_uuid: str,
  version_dataset_id: int,
- version: int,
+ version: str,
  version_status: int,
  version_feature_schema: Optional[str],
  version_created_at: datetime,
@@ -441,7 +451,7 @@ class DatasetRecord:
  for c_name, c_type in self.schema.items()
  }

- def get_schema(self, version: int) -> dict[str, Union[SQLType, type[SQLType]]]:
+ def get_schema(self, version: str) -> dict[str, Union[SQLType, type[SQLType]]]:
  return self.get_version(version).schema if version else self.schema

  def update(self, **kwargs):
@@ -460,20 +470,23 @@ class DatasetRecord:
  self.versions = []

  self.versions = list(set(self.versions + other.versions))
- self.versions.sort(key=lambda v: v.version)
+ self.versions.sort(key=lambda v: v.version_value)
  return self

- def has_version(self, version: int) -> bool:
- return version in self.versions_values
+ def has_version(self, version: str) -> bool:
+ return version in [v.version for v in self.versions]

- def is_valid_next_version(self, version: int) -> bool:
+ def is_valid_next_version(self, version: str) -> bool:
  """
  Checks if a number can be a valid next latest version for dataset.
  The only rule is that it cannot be lower than current latest version
  """
- return not (self.latest_version and self.latest_version >= version)
+ return not (
+ self.latest_version
+ and semver.value(self.latest_version) >= semver.value(version)
+ )

- def get_version(self, version: int) -> DatasetVersion:
+ def get_version(self, version: str) -> DatasetVersion:
  if not self.has_version(version):
  raise DatasetVersionNotFoundError(
  f"Dataset {self.name} does not have version {version}"
@@ -496,15 +509,15 @@ class DatasetRecord:
  f"Dataset {self.name} does not have version with uuid {uuid}"
  ) from None

- def remove_version(self, version: int) -> None:
+ def remove_version(self, version: str) -> None:
  if not self.versions or not self.has_version(version):
  return

  self.versions = [v for v in self.versions if v.version != version]

- def identifier(self, version: int) -> str:
+ def identifier(self, version: str) -> str:
  """
- Get identifier in the form my-dataset@v3
+ Get identifier in the form my-dataset@v3.0.1
  """
  if not self.has_version(version):
  raise DatasetVersionNotFoundError(
@@ -512,43 +525,73 @@ class DatasetRecord:
  )
  return f"{self.name}@v{version}"

- def uri(self, version: int) -> str:
+ def uri(self, version: str) -> str:
  """
- Dataset uri example: ds://dogs@v3
+ Dataset uri example: ds://dogs@v3.0.1
  """
  identifier = self.identifier(version)
  return f"{DATASET_PREFIX}{identifier}"

  @property
- def versions_values(self) -> list[int]:
+ def next_version_major(self) -> str:
  """
- Extracts actual versions from list of DatasetVersion objects
- in self.versions attribute
+ Returns the next auto-incremented version if the major part is being bumped.
  """
  if not self.versions:
- return []
+ return "1.0.0"

- return sorted(v.version for v in self.versions)
+ major, minor, patch = semver.parse(self.latest_version)
+ return semver.create(major + 1, 0, 0)

  @property
- def next_version(self) -> int:
- """Returns what should be next autoincrement version of dataset"""
+ def next_version_minor(self) -> str:
+ """
+ Returns the next auto-incremented version if the minor part is being bumped.
+ """
  if not self.versions:
- return 1
- return max(self.versions_values) + 1
+ return "1.0.0"
+
+ major, minor, patch = semver.parse(self.latest_version)
+ return semver.create(major, minor + 1, 0)

  @property
- def latest_version(self) -> int:
+ def next_version_patch(self) -> str:
+ """
+ Returns the next auto-incremented version if the patch part is being bumped.
+ """
+ if not self.versions:
+ return "1.0.0"
+
+ major, minor, patch = semver.parse(self.latest_version)
+ return semver.create(major, minor, patch + 1)
+
+ @property
+ def latest_version(self) -> str:
  """Returns latest version of a dataset"""
- return max(self.versions_values)
+ return max(self.versions).version
+
+ def latest_major_version(self, major: int) -> Optional[str]:
+ """
+ Returns latest specific major version, e.g if dataset has versions:
+ - 1.4.1
+ - 2.0.1
+ - 2.1.1
+ - 2.4.0
+ and we call `.latest_major_version(2)` it will return: "2.4.0".
+ If no major version is find with input value, None will be returned
+ """
+ versions = [v for v in self.versions if semver.parse(v.version)[0] == major]
+ if not versions:
+ return None
+ return max(versions).version

  @property
- def prev_version(self) -> Optional[int]:
+ def prev_version(self) -> Optional[str]:
  """Returns previous version of a dataset"""
  if len(self.versions) == 1:
  return None

- return sorted(self.versions_values)[-2]
+ return sorted(self.versions)[-2].version

  @classmethod
  def from_dict(cls, d: dict[str, Any]) -> "DatasetRecord":
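The new `datachain/semver.py` module (+58 lines in this release) is not shown in this diff; from the call sites above it exposes at least `parse`, `create`, and `value`. A minimal sketch of what those helpers presumably look like; the real implementation may differ (validation, error types, the exact `value` weighting):

```python
import re

# Rough sketch inferred from the call sites in dataset.py; the actual
# datachain/semver.py is not included in this diff and may differ.

def parse(version: str) -> tuple[int, int, int]:
    """Split "MAJOR.MINOR.PATCH" into three ints, e.g. "2.4.0" -> (2, 4, 0)."""
    match = re.fullmatch(r"(\d+)\.(\d+)\.(\d+)", version)
    if not match:
        raise ValueError(f"Invalid version: {version}")
    major, minor, patch = (int(part) for part in match.groups())
    return major, minor, patch


def create(major: int = 0, minor: int = 0, patch: int = 0) -> str:
    """Build a version string from its parts, e.g. (3, 0, 1) -> "3.0.1"."""
    return f"{major}.{minor}.{patch}"


def value(version: str) -> int:
    """Map a version to a single comparable integer (assumed weighting)."""
    major, minor, patch = parse(version)
    return major * 1_000_000 + minor * 1_000 + patch
```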
@@ -577,7 +620,7 @@ class DatasetListRecord:
  version_id: int,
  version_uuid: str,
  version_dataset_id: int,
- version: int,
+ version: str,
  version_status: int,
  version_created_at: datetime,
  version_finished_at: Optional[datetime],
@@ -626,11 +669,11 @@ class DatasetListRecord:
  self.versions = []

  self.versions = list(set(self.versions + other.versions))
- self.versions.sort(key=lambda v: v.version)
+ self.versions.sort(key=lambda v: v.version_value)
  return self

  def latest_version(self) -> DatasetListVersion:
- return max(self.versions, key=lambda v: v.version)
+ return max(self.versions, key=lambda v: v.version_value)

  @property
  def is_bucket_listing(self) -> bool:
datachain/lib/arrow.py CHANGED
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Optional

  import orjson
  import pyarrow as pa
+ from pyarrow._csv import ParseOptions
  from pyarrow.dataset import CsvFileFormat, dataset
  from tqdm.auto import tqdm

@@ -26,6 +27,18 @@ if TYPE_CHECKING:
  DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY = b"DataChain SignalSchema"


+ def fix_pyarrow_format(format, parse_options=None):
+ # Re-init invalid row handler: https://issues.apache.org/jira/browse/ARROW-17641
+ if (
+ format
+ and isinstance(format, CsvFileFormat)
+ and parse_options
+ and isinstance(parse_options, ParseOptions)
+ ):
+ format.parse_options = parse_options
+ return format
+
+
  class ArrowGenerator(Generator):
  DEFAULT_BATCH_SIZE = 2**17 # same as `pyarrow._dataset._DEFAULT_BATCH_SIZE`

@@ -53,6 +66,7 @@ class ArrowGenerator(Generator):
  self.output_schema = output_schema
  self.source = source
  self.nrows = nrows
+ self.parse_options = kwargs.pop("parse_options", None)
  self.kwargs = kwargs

  def process(self, file: File):
@@ -64,7 +78,11 @@ class ArrowGenerator(Generator):
  else:
  fs, fs_path = file.get_fs(), file.get_path()

- ds = dataset(fs_path, schema=self.input_schema, filesystem=fs, **self.kwargs)
+ kwargs = self.kwargs
+ if format := kwargs.get("format"):
+ kwargs["format"] = fix_pyarrow_format(format, self.parse_options)
+
+ ds = dataset(fs_path, schema=self.input_schema, filesystem=fs, **kwargs)

  hf_schema = _get_hf_schema(ds.schema)
  use_datachain_schema = (
@@ -137,6 +155,10 @@ class ArrowGenerator(Generator):


  def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
+ parse_options = kwargs.pop("parse_options", None)
+ if format := kwargs.get("format"):
+ kwargs["format"] = fix_pyarrow_format(format, parse_options)
+
  schemas = []
  for file in chain.collect("file"):
  ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs) # type: ignore[union-attr]
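Both `ArrowGenerator.process` and `infer_schema` now pop `parse_options` out of the keyword arguments and re-attach them to the CSV format via `fix_pyarrow_format` right before calling `pyarrow.dataset.dataset`. This works around the issue linked in the helper (ARROW-17641), where a custom `invalid_row_handler` set on a `CsvFileFormat` can reportedly be lost when the format object is serialized. A rough pyarrow-only sketch of the re-attachment (the delimiter and handler here are made-up example values):

```python
import pyarrow.csv as pacsv
import pyarrow.dataset as pads

# Hypothetical parse options with a handler that skips malformed rows.
parse_options = pacsv.ParseOptions(
    delimiter=";",
    invalid_row_handler=lambda row: "skip",
)
csv_format = pads.CsvFileFormat(parse_options=parse_options)

# The handler can be dropped when the format object is serialized (ARROW-17641),
# so the options are re-attached right before the dataset is opened; this is
# essentially what fix_pyarrow_format(csv_format, parse_options) does:
csv_format.parse_options = parse_options
# ...the repaired format is then passed to pyarrow.dataset.dataset(path, format=csv_format).
```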