datachain 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their public registries, and is provided for informational purposes only.
Potentially problematic release: this version of datachain might be problematic.
- datachain/catalog/catalog.py +33 -5
- datachain/catalog/loader.py +19 -13
- datachain/cli/__init__.py +3 -1
- datachain/cli/commands/show.py +12 -1
- datachain/cli/parser/studio.py +13 -1
- datachain/cli/parser/utils.py +6 -0
- datachain/client/fsspec.py +12 -16
- datachain/client/hf.py +36 -14
- datachain/client/local.py +1 -4
- datachain/data_storage/warehouse.py +3 -8
- datachain/dataset.py +8 -0
- datachain/error.py +0 -12
- datachain/fs/utils.py +30 -0
- datachain/func/__init__.py +5 -0
- datachain/func/func.py +2 -1
- datachain/lib/data_model.py +6 -0
- datachain/lib/dc.py +114 -28
- datachain/lib/file.py +100 -25
- datachain/lib/image.py +30 -6
- datachain/lib/listing.py +21 -39
- datachain/lib/signal_schema.py +194 -15
- datachain/lib/video.py +7 -5
- datachain/model/bbox.py +209 -58
- datachain/model/pose.py +49 -37
- datachain/model/segment.py +22 -18
- datachain/model/ultralytics/bbox.py +9 -9
- datachain/model/ultralytics/pose.py +7 -7
- datachain/model/ultralytics/segment.py +7 -7
- datachain/model/utils.py +191 -0
- datachain/nodes_thread_pool.py +32 -11
- datachain/query/dataset.py +4 -2
- datachain/studio.py +8 -6
- datachain/utils.py +3 -16
- {datachain-0.11.0.dist-info → datachain-0.12.0.dist-info}/METADATA +6 -4
- {datachain-0.11.0.dist-info → datachain-0.12.0.dist-info}/RECORD +39 -37
- {datachain-0.11.0.dist-info → datachain-0.12.0.dist-info}/WHEEL +1 -1
- {datachain-0.11.0.dist-info → datachain-0.12.0.dist-info}/LICENSE +0 -0
- {datachain-0.11.0.dist-info → datachain-0.12.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.11.0.dist-info → datachain-0.12.0.dist-info}/top_level.txt +0 -0
datachain/lib/dc.py
CHANGED
@@ -22,7 +22,7 @@ import orjson
 import sqlalchemy
 from pydantic import BaseModel
 from sqlalchemy.sql.functions import GenericFunction
-from …
+from tqdm import tqdm
 
 from datachain.dataset import DatasetRecord
 from datachain.func import literal
@@ -32,7 +32,14 @@ from datachain.lib.convert.python_to_sql import python_to_sql
 from datachain.lib.convert.values_to_tuples import values_to_tuples
 from datachain.lib.data_model import DataModel, DataType, DataValue, dict_to_data_model
 from datachain.lib.dataset_info import DatasetInfo
-from datachain.lib.file import …
+from datachain.lib.file import (
+    EXPORT_FILES_MAX_THREADS,
+    ArrowRow,
+    File,
+    FileExporter,
+    FileType,
+    get_file_type,
+)
 from datachain.lib.file import ExportPlacement as FileExportPlacement
 from datachain.lib.listing import get_file_info, get_listing, list_bucket, ls
 from datachain.lib.listing_info import ListingInfo
@@ -47,7 +54,6 @@ from datachain.query import Session
 from datachain.query.dataset import DatasetQuery, PartitionByType
 from datachain.query.schema import DEFAULT_DELIMITER, Column, ColumnMeta
 from datachain.sql.functions import path as pathfunc
-from datachain.telemetry import telemetry
 from datachain.utils import batched_it, inside_notebook, row_to_nested_dict
 
 if TYPE_CHECKING:
@@ -65,7 +71,6 @@ _T = TypeVar("_T")
 D = TypeVar("D", bound="DataChain")
 UDFObjT = TypeVar("UDFObjT", bound=UDFBase)
 
-
 DEFAULT_PARQUET_CHUNK_SIZE = 100_000
 
 
@@ -208,7 +213,7 @@ class DataChain:
        from mistralai.client import MistralClient
        from mistralai.models.chat_completion import ChatMessage

-       from datachain …
+       from datachain import DataChain, Column

        PROMPT = (
            "Was this bot dialog successful? "
@@ -401,7 +406,7 @@
     @classmethod
     def from_storage(
         cls,
-        uri,
+        uri: Union[str, os.PathLike[str]],
         *,
         type: FileType = "binary",
         session: Optional[Session] = None,
@@ -543,6 +548,8 @@
            )
            ```
        """
+        from datachain.telemetry import telemetry
+
        query = DatasetQuery(
            name=name,
            version=version,
@@ -566,7 +573,7 @@
     @classmethod
     def from_json(
         cls,
-        path,
+        path: Union[str, os.PathLike[str]],
         type: FileType = "text",
         spec: Optional[DataType] = None,
         schema_from: Optional[str] = "auto",
@@ -603,7 +610,7 @@
            ```
        """
        if schema_from == "auto":
-           schema_from = path
+           schema_from = str(path)

        def jmespath_to_name(s: str):
            name_end = re.search(r"\W", s).start() if re.search(r"\W", s) else len(s)  # type: ignore[union-attr]
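`from_storage` and `from_json` now accept `os.PathLike` alongside `str`, and `from_json` stringifies the path before reusing it as `schema_from`. A minimal usage sketch (the local paths are hypothetical):

```py
from pathlib import Path

from datachain import DataChain

# pathlib.Path objects now work alongside plain string URIs.
images = DataChain.from_storage(Path("data/images"), type="image")
meta = DataChain.from_json(Path("data/annotations.json"))
```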
@@ -694,9 +701,22 @@
         in_memory: bool = False,
         object_name: str = "dataset",
         include_listing: bool = False,
+        studio: bool = False,
     ) -> "DataChain":
         """Generate chain with list of registered datasets.
 
+        Args:
+            session: Optional session instance. If not provided, uses default session.
+            settings: Optional dictionary of settings to configure the chain.
+            in_memory: If True, creates an in-memory session. Defaults to False.
+            object_name: Name of the output object in the chain. Defaults to "dataset".
+            include_listing: If True, includes listing datasets. Defaults to False.
+            studio: If True, returns datasets from Studio only,
+                otherwise returns all local datasets. Defaults to False.
+
+        Returns:
+            DataChain: A new DataChain instance containing dataset information.
+
         Example:
             ```py
             from datachain import DataChain
@@ -712,7 +732,7 @@
         datasets = [
             DatasetInfo.from_models(d, v, j)
             for d, v, j in catalog.list_datasets_versions(
-                include_listing=include_listing
+                include_listing=include_listing, studio=studio
             )
         ]
 
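The new `studio` flag switches `datasets()` between local datasets and those registered in Studio. A sketch following the docstring example above:

```py
from datachain import DataChain

local = DataChain.datasets()             # all local datasets (previous behavior)
remote = DataChain.datasets(studio=True)  # only datasets from Studio

for ds in local.collect("dataset"):
    print(f"{ds.name}@v{ds.version}")
```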
@@ -1050,7 +1070,7 @@
     def select(self, *args: str, _sys: bool = True) -> "Self":
         """Select only a specified set of signals."""
         new_schema = self.signals_schema.resolve(*args)
-        if _sys:
+        if self._sys and _sys:
             new_schema = SignalSchema({"sys": Sys}) | new_schema
         columns = new_schema.db_signals()
         return self._evolve(
@@ -1093,6 +1113,7 @@
         partition_by_columns: list[Column] = []
         signal_columns: list[Column] = []
         schema_fields: dict[str, DataType] = {}
+        keep_columns: list[str] = []
 
         # validate partition_by columns and add them to the schema
         for col in partition_by:
@@ -1100,10 +1121,13 @@
                 col_db_name = ColumnMeta.to_db_name(col)
                 col_type = self.signals_schema.get_column_type(col_db_name)
                 column = Column(col_db_name, python_to_sql(col_type))
+                if col not in keep_columns:
+                    keep_columns.append(col)
             elif isinstance(col, Function):
                 column = col.get_column(self.signals_schema)
                 col_db_name = column.name
                 col_type = column.type.python_type
+                schema_fields[col_db_name] = col_type
             else:
                 raise DataChainColumnError(
                     col,
@@ -1113,7 +1137,6 @@
                     ),
                 )
             partition_by_columns.append(column)
-            schema_fields[col_db_name] = col_type
 
         # validate signal columns and add them to the schema
         if not kwargs:
@@ -1128,9 +1151,13 @@
             signal_columns.append(column)
             schema_fields[col_name] = func.get_result_type(self.signals_schema)
 
+        signal_schema = SignalSchema(schema_fields)
+        if keep_columns:
+            signal_schema |= self.signals_schema.to_partial(*keep_columns)
+
         return self._evolve(
             query=self._query.group_by(signal_columns, partition_by_columns),
-            signal_schema=SignalSchema(schema_fields),
+            signal_schema=signal_schema,
         )
 
     def mutate(self, **kwargs) -> "Self":
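With `keep_columns`, string `partition_by` columns are merged back into the output schema via `SignalSchema.to_partial`, so they stay selectable after grouping. A minimal sketch (assuming `func.count` as the aggregate):

```py
from datachain import DataChain, func

chain = DataChain.from_values(letter=["a", "a", "b"], value=[1, 2, 3])

# "letter" is now part of the grouped schema, next to the aggregate.
grouped = chain.group_by(cnt=func.count(), partition_by="letter")
grouped.select("letter", "cnt").show()
```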
@@ -1181,6 +1208,8 @@
            )
            ```
        """
+        from sqlalchemy.sql.sqltypes import NullType
+
        primitives = (bool, str, int, float)

        for col_name, expr in kwargs.items():
@@ -1225,23 +1254,37 @@
     @overload
     def collect_flatten(self) -> Iterator[tuple[Any, ...]]: ...
 
+    @overload
+    def collect_flatten(self, *, include_hidden: bool) -> Iterator[tuple[Any, ...]]: ...
+
     @overload
     def collect_flatten(
         self, *, row_factory: Callable[[list[str], tuple[Any, ...]], _T]
     ) -> Iterator[_T]: ...
 
-    def collect_flatten(self, *, row_factory=None):
+    @overload
+    def collect_flatten(
+        self,
+        *,
+        row_factory: Callable[[list[str], tuple[Any, ...]], _T],
+        include_hidden: bool,
+    ) -> Iterator[_T]: ...
+
+    def collect_flatten(self, *, row_factory=None, include_hidden: bool = True):
         """Yields flattened rows of values as a tuple.
 
         Args:
             row_factory : A callable to convert row to a custom format.
                 It should accept two arguments: a list of column names and
                 a tuple of row values.
+            include_hidden: Whether to include hidden signals from the schema.
         """
-        db_signals = self._effective_signals_schema.db_signals()
+        db_signals = self._effective_signals_schema.db_signals(
+            include_hidden=include_hidden
+        )
         with self._query.ordered_select(*db_signals).as_iterable() as rows:
             if row_factory:
-                rows = (row_factory(db_signals, r) for r in rows)
+                rows = (row_factory(db_signals, r) for r in rows)  # type: ignore[assignment]
             yield from rows
 
     def to_columnar_data_with_names(
@@ -1275,10 +1318,23 @@
         self, *, row_factory: Callable[[list[str], tuple[Any, ...]], _T]
     ) -> list[_T]: ...
 
-    def results(self, *, row_factory=None):  # noqa: D102
+    @overload
+    def results(
+        self,
+        *,
+        row_factory: Callable[[list[str], tuple[Any, ...]], _T],
+        include_hidden: bool,
+    ) -> list[_T]: ...
+
+    @overload
+    def results(self, *, include_hidden: bool) -> list[tuple[Any, ...]]: ...
+
+    def results(self, *, row_factory=None, include_hidden=True):  # noqa: D102
         if row_factory is None:
-            return list(self.collect_flatten())
-        return list(self.collect_flatten(row_factory=row_factory))
+            return list(self.collect_flatten(include_hidden=include_hidden))
+        return list(
+            self.collect_flatten(row_factory=row_factory, include_hidden=include_hidden)
+        )
 
     def to_records(self) -> list[dict[str, Any]]:
         """Convert every row to a dictionary."""
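`collect_flatten` and `results` gain an `include_hidden` flag (default `True`) that controls whether signals listed in a model's `_hidden_fields` (e.g. `file.source`, `file.version`) appear in flattened rows. A sketch with a hypothetical bucket:

```py
from datachain import DataChain

chain = DataChain.from_storage("s3://my-bucket/images/")  # hypothetical bucket

all_columns = chain.results()                       # hidden signals included
visible_only = chain.results(include_hidden=False)  # hidden signals dropped
```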
@@ -1788,21 +1844,25 @@
             **fr_map,
         )
 
-    def to_pandas(self, flatten=False) -> "pd.DataFrame":
+    def to_pandas(self, flatten=False, include_hidden=True) -> "pd.DataFrame":
         """Return a pandas DataFrame from the chain.
 
         Parameters:
             flatten : Whether to use a multiindex or flatten column names.
+            include_hidden : Whether to include hidden columns.
         """
         import pandas as pd
 
-        headers, max_length = self._effective_signals_schema.get_headers_with_length()
+        headers, max_length = self._effective_signals_schema.get_headers_with_length(
+            include_hidden=include_hidden
+        )
         if flatten or max_length < 2:
             columns = [".".join(filter(None, header)) for header in headers]
         else:
             columns = pd.MultiIndex.from_tuples(map(tuple, headers))
 
-        return pd.DataFrame.from_records(self.results(), columns=columns)
+        results = self.results(include_hidden=include_hidden)
+        return pd.DataFrame.from_records(results, columns=columns)
 
     def show(
         self,
@@ -1810,6 +1870,7 @@
         flatten=False,
         transpose=False,
         truncate=True,
+        include_hidden=False,
     ) -> None:
         """Show a preview of the chain results.
 
@@ -1818,11 +1879,12 @@
             flatten : Whether to use a multiindex or flatten column names.
             transpose : Whether to transpose rows and columns.
             truncate : Whether or not to truncate the contents of columns.
+            include_hidden : Whether to include hidden columns.
         """
         import pandas as pd
 
         dc = self.limit(limit) if limit > 0 else self  # type: ignore[misc]
-        df = dc.to_pandas(flatten)
+        df = dc.to_pandas(flatten, include_hidden=include_hidden)
 
         if df.empty:
             print("Empty result")
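`to_pandas` forwards the same flag, and `show` now hides hidden columns by default (`include_hidden=False`). A short sketch:

```py
from datachain import DataChain

chain = DataChain.from_storage("gs://my-bucket/")  # hypothetical bucket

df = chain.to_pandas(include_hidden=False)  # drop file.source, file.version, ...
chain.show(3, include_hidden=True)          # opt back in for previews
```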
@@ -2495,22 +2557,28 @@
 
     def to_storage(
         self,
-        output: str,
+        output: Union[str, os.PathLike[str]],
         signal: str = "file",
         placement: FileExportPlacement = "fullpath",
-        use_cache: bool = True,
         link_type: Literal["copy", "symlink"] = "copy",
+        num_threads: Optional[int] = EXPORT_FILES_MAX_THREADS,
+        anon: bool = False,
+        client_config: Optional[dict] = None,
     ) -> None:
-        """Export files from a specified signal to a directory.
+        """Export files from a specified signal to a directory. Files can be
+        exported to a local or cloud directory.
 
         Args:
             output: Path to the target directory for exporting files.
             signal: Name of the signal to export files from.
             placement: The method to use for naming exported files.
                 The possible values are: "filename", "etag", "fullpath", and "checksum".
-            use_cache: If `True`, cache the files before exporting.
             link_type: Method to use for exporting files.
                 Falls back to `'copy'` if symlinking fails.
+            num_threads : number of threads to use for exporting files.
+                By default it uses 5 threads.
+            anon: If true, we will treat cloud bucket as public one
+            client_config: Optional configuration for the destination storage client
 
         Example:
             Cross cloud transfer
@@ -2525,8 +2593,26 @@
         ):
             raise ValueError("Files with the same name found")
 
-        …
-        …
+        if anon:
+            client_config = (client_config or {}) | {"anon": True}
+
+        progress_bar = tqdm(
+            desc=f"Exporting files to {output}: ",
+            unit=" files",
+            unit_scale=True,
+            unit_divisor=10,
+            total=self.count(),
+            leave=False,
+        )
+        file_exporter = FileExporter(
+            output,
+            placement,
+            self._settings.cache if self._settings else False,
+            link_type,
+            max_threads=num_threads or 1,
+            client_config=client_config,
+        )
+        file_exporter.run(self.collect(signal), progress_bar)
 
     def shuffle(self) -> "Self":
         """Shuffle the rows of the chain deterministically."""
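`to_storage` no longer copies files inline: it builds a `tqdm` progress bar and hands the collected file signals to a `FileExporter` thread pool; `use_cache` moved to chain settings, and `anon`/`client_config` configure the destination filesystem. A usage sketch (bucket names are hypothetical):

```py
from datachain import DataChain

chain = DataChain.from_storage("s3://my-source-bucket/images/")
chain.to_storage(
    "gs://my-destination-bucket/images/",  # cross-cloud destination
    signal="file",
    placement="filename",
    num_threads=5,   # size of the exporter thread pool
    anon=False,      # set True for public, credential-free buckets
)
```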
datachain/lib/file.py
CHANGED
@@ -18,12 +18,12 @@ from urllib.request import url2pathname
 
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from fsspec.utils import stringify_path
-from PIL import Image as PilImage
 from pydantic import Field, field_validator
 
 from datachain.client.fileslice import FileSlice
 from datachain.lib.data_model import DataModel
 from datachain.lib.utils import DataChainError
+from datachain.nodes_thread_pool import NodesThreadPool
 from datachain.sql.types import JSON, Boolean, DateTime, Int, String
 from datachain.utils import TIME_ZERO
 
@@ -43,6 +43,41 @@ logger = logging.getLogger("datachain")
 ExportPlacement = Literal["filename", "etag", "fullpath", "checksum"]
 
 FileType = Literal["binary", "text", "image", "video"]
+EXPORT_FILES_MAX_THREADS = 5
+
+
+class FileExporter(NodesThreadPool):
+    """Class that does file exporting concurrently with thread pool"""
+
+    def __init__(
+        self,
+        output: Union[str, os.PathLike[str]],
+        placement: ExportPlacement,
+        use_cache: bool,
+        link_type: Literal["copy", "symlink"],
+        max_threads: int = EXPORT_FILES_MAX_THREADS,
+        client_config: Optional[dict] = None,
+    ):
+        super().__init__(max_threads)
+        self.output = output
+        self.placement = placement
+        self.use_cache = use_cache
+        self.link_type = link_type
+        self.client_config = client_config
+
+    def done_task(self, done):
+        for task in done:
+            task.result()
+
+    def do_task(self, file):
+        file.export(
+            self.output,
+            self.placement,
+            self.use_cache,
+            link_type=self.link_type,
+            client_config=self.client_config,
+        )
+        self.increase_counter(1)
 
 
 class VFileError(DataChainError):
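`FileExporter` follows the `NodesThreadPool` contract visible above: `run()` fans items out to `do_task()` on worker threads, `done_task()` re-raises any worker exception, and `increase_counter(1)` feeds the progress bar. A standalone sketch mirroring what `DataChain.to_storage` does internally (paths and bucket are hypothetical):

```py
from tqdm import tqdm

from datachain import DataChain
from datachain.lib.file import EXPORT_FILES_MAX_THREADS, FileExporter

chain = DataChain.from_storage("s3://my-bucket/docs/")  # hypothetical bucket

progress = tqdm(desc="Exporting", unit=" files", total=chain.count(), leave=False)
exporter = FileExporter(
    "/tmp/export",  # output directory
    "fullpath",     # placement
    False,          # use_cache
    "copy",         # link_type
    max_threads=EXPORT_FILES_MAX_THREADS,
)
exporter.run(chain.collect("file"), progress)
```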
@@ -158,6 +193,7 @@ class File(DataModel):
         "last_modified": DateTime,
         "location": JSON,
     }
+    _hidden_fields: ClassVar[list[str]] = ["version", "source"]
 
     _unique_id_keys: ClassVar[list[str]] = [
         "source",
@@ -206,6 +242,30 @@
         self._catalog = None
         self._caching_enabled: bool = False
 
+    def as_text_file(self) -> "TextFile":
+        """Convert the file to a `TextFile` object."""
+        if isinstance(self, TextFile):
+            return self
+        file = TextFile(**self.model_dump())
+        file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
+        return file
+
+    def as_image_file(self) -> "ImageFile":
+        """Convert the file to a `ImageFile` object."""
+        if isinstance(self, ImageFile):
+            return self
+        file = ImageFile(**self.model_dump())
+        file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
+        return file
+
+    def as_video_file(self) -> "VideoFile":
+        """Convert the file to a `VideoFile` object."""
+        if isinstance(self, VideoFile):
+            return self
+        file = VideoFile(**self.model_dump())
+        file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
+        return file
+
     @classmethod
     def upload(
         cls, data: bytes, path: str, catalog: Optional["Catalog"] = None
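The `as_text_file`/`as_image_file`/`as_video_file` helpers re-type a plain `File` (copying its fields and re-attaching the catalog stream), which is handy in UDFs that receive generic `File` objects. A sketch with a hypothetical bucket:

```py
from datachain import DataChain, File

def dimensions(file: File) -> str:
    # Promote the generic File to an ImageFile to use image helpers.
    info = file.as_image_file().get_info()
    return f"{info.width}x{info.height}"

chain = (
    DataChain.from_storage("s3://my-bucket/photos/")  # hypothetical bucket
    .map(dim=dimensions)
)
```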
@@ -255,24 +315,24 @@
         ) as f:
             yield io.TextIOWrapper(f) if mode == "r" else f
 
-    def read(self, length: int = -1):
-        """Returns file contents."""
+    def read_bytes(self, length: int = -1):
+        """Returns file contents as bytes."""
         with self.open() as stream:
             return stream.read(length)
 
-    def read_bytes(self):
-        """Returns file contents as bytes."""
-        return self.read()
-
     def read_text(self):
         """Returns file contents as text."""
         with self.open(mode="r") as stream:
             return stream.read()
 
-    def save(self, destination: str):
+    def read(self, length: int = -1):
+        """Returns file contents."""
+        return self.read_bytes(length)
+
+    def save(self, destination: str, client_config: Optional[dict] = None):
         """Writes it's content to destination"""
         destination = stringify_path(destination)
-        client: Client = self._catalog.get_client(destination)
+        client: Client = self._catalog.get_client(destination, **(client_config or {}))
 
         if client.PREFIX == "file://" and not destination.startswith(client.PREFIX):
             destination = Path(destination).absolute().as_uri()
@@ -296,17 +356,17 @@
 
     def export(
         self,
-        output: str,
+        output: Union[str, os.PathLike[str]],
         placement: ExportPlacement = "fullpath",
         use_cache: bool = True,
         link_type: Literal["copy", "symlink"] = "copy",
+        client_config: Optional[dict] = None,
     ) -> None:
         """Export file to new location."""
-
-        self._caching_enabled = use_cache
+        self._caching_enabled = use_cache
         dst = self.get_destination_path(output, placement)
         dst_dir = os.path.dirname(dst)
-        client: Client = self._catalog.get_client(dst_dir)
+        client: Client = self._catalog.get_client(dst_dir, **(client_config or {}))
         client.fs.makedirs(dst_dir, exist_ok=True)
 
         if link_type == "symlink":
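`read` is now a thin wrapper over `read_bytes`, and both accept a `length`, so the head of a large object can be sampled without downloading all of it; `save`/`export` take an optional `client_config` for the destination filesystem. A sketch (bucket hypothetical):

```py
from datachain import DataChain

chain = DataChain.from_storage("s3://my-bucket/raw/")  # hypothetical bucket

for file in chain.limit(1).collect("file"):
    magic = file.read_bytes(8)    # only the first 8 bytes
    assert file.read(8) == magic  # read() delegates to read_bytes()
```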
@@ -337,15 +397,10 @@
         client.download(self, callback=self._download_cb)
 
     async def _prefetch(self, download_cb: Optional["Callback"] = None) -> bool:
-        from datachain.client.hf import HfClient
-
         if self._catalog is None:
             raise RuntimeError("cannot prefetch file because catalog is not setup")
 
         client = self._catalog.get_client(self.source)
-        if client.protocol == HfClient.protocol:
-            return False
-
         await client._download(self, callback=download_cb or self._download_cb)
         self._set_stream(
             self._catalog, caching_enabled=True, download_cb=DEFAULT_CALLBACK
@@ -393,7 +448,9 @@
             path = url2pathname(path)
         return path
 
-    def get_destination_path(self, output: str, placement: ExportPlacement) -> str:
+    def get_destination_path(
+        self, output: Union[str, os.PathLike[str]], placement: ExportPlacement
+    ) -> str:
         """
         Returns full destination path of a file for exporting to some output
         based on export placement
@@ -502,11 +559,11 @@ class TextFile(File):
         with self.open() as stream:
             return stream.read()
 
-    def save(self, destination: str):
+    def save(self, destination: str, client_config: Optional[dict] = None):
         """Writes it's content to destination"""
         destination = stringify_path(destination)
 
-        client: Client = self._catalog.get_client(destination)
+        client: Client = self._catalog.get_client(destination, **(client_config or {}))
         with client.fs.open(destination, mode="w") as f:
             f.write(self.read_text())
 
@@ -514,18 +571,36 @@
 class ImageFile(File):
     """`DataModel` for reading image files."""
 
+    def get_info(self) -> "Image":
+        """
+        Retrieves metadata and information about the image file.
+
+        Returns:
+            Image: A Model containing image metadata such as width, height and format.
+        """
+        from .image import image_info
+
+        return image_info(self)
+
     def read(self):
         """Returns `PIL.Image.Image` object."""
+        from PIL import Image as PilImage
+
         fobj = super().read()
         return PilImage.open(BytesIO(fobj))
 
-    def save(self, destination: str):
+    def save(  # type: ignore[override]
+        self,
+        destination: str,
+        format: Optional[str] = None,
+        client_config: Optional[dict] = None,
+    ):
         """Writes it's content to destination"""
         destination = stringify_path(destination)
 
-        client: Client = self._catalog.get_client(destination)
+        client: Client = self._catalog.get_client(destination, **(client_config or {}))
         with client.fs.open(destination, mode="wb") as f:
-            self.read().save(f)
+            self.read().save(f, format=format)
 
 
 class Image(DataModel):
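`ImageFile.get_info` surfaces width/height/format via the new `image_info` helper, and `save` can now re-encode through PIL's `format` argument. A sketch (bucket and output path hypothetical):

```py
from datachain import DataChain

chain = DataChain.from_storage("s3://my-bucket/photos/", type="image")

for img in chain.limit(1).collect("file"):
    info = img.get_info()
    print(info.width, info.height, info.format)
    img.save("/tmp/preview.png", format="PNG")  # re-encode on save
```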
datachain/lib/image.py
CHANGED
@@ -1,17 +1,41 @@
 from typing import Callable, Optional, Union
 
 import torch
-from PIL import Image
+from PIL import Image as PILImage
+
+from datachain.lib.file import File, FileError, Image, ImageFile
+
+
+def image_info(file: Union[File, ImageFile]) -> Image:
+    """
+    Returns image file information.
+
+    Args:
+        file (ImageFile): Image file object.
+
+    Returns:
+        Image: Image file information.
+    """
+    try:
+        img = file.as_image_file().read()
+    except Exception as exc:
+        raise FileError(file, "unable to open image file") from exc
+
+    return Image(
+        width=img.width,
+        height=img.height,
+        format=img.format or "",
+    )
 
 
 def convert_image(
-    img: Image.Image,
+    img: PILImage.Image,
     mode: str = "RGB",
     size: Optional[tuple[int, int]] = None,
     transform: Optional[Callable] = None,
     encoder: Optional[Callable] = None,
     device: Optional[Union[str, torch.device]] = None,
-) -> Union[Image.Image, torch.Tensor]:
+) -> Union[PILImage.Image, torch.Tensor]:
     """
     Resize, transform, and otherwise convert an image.
 
@@ -47,13 +71,13 @@ def convert_image(
 
 
 def convert_images(
-    images: Union[Image.Image, list[Image.Image]],
+    images: Union[PILImage.Image, list[PILImage.Image]],
     mode: str = "RGB",
     size: Optional[tuple[int, int]] = None,
     transform: Optional[Callable] = None,
     encoder: Optional[Callable] = None,
     device: Optional[Union[str, torch.device]] = None,
-) -> Union[list[Image.Image], torch.Tensor]:
+) -> Union[list[PILImage.Image], torch.Tensor]:
     """
     Resize, transform, and otherwise convert one or more images.
 
@@ -65,7 +89,7 @@ def convert_images(
         encoder (Callable): Encode image using model.
         device (str or torch.device): Device to use.
     """
-    if isinstance(images, Image.Image):
+    if isinstance(images, PILImage.Image):
         images = [images]
 
     converted = [
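Since `image_info` takes a `File` and returns the `Image` model, it can be mapped directly over a chain to add width/height/format signals. A sketch, assuming signal-name matching picks up the `file` column (bucket hypothetical):

```py
from datachain import C, DataChain
from datachain.lib.image import image_info

chain = (
    DataChain.from_storage("s3://my-bucket/photos/", type="image")
    .map(info=image_info)            # adds info.width / info.height / info.format
    .filter(C("info.width") > 1024)  # keep only large images
)
chain.show(3)
```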