datachain 0.7.4__py3-none-any.whl → 0.7.6__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of datachain was flagged as possibly problematic by the registry diff service.
- datachain/catalog/catalog.py +99 -113
- datachain/catalog/loader.py +8 -65
- datachain/cli.py +148 -57
- datachain/data_storage/__init__.py +0 -3
- datachain/data_storage/metastore.py +2 -9
- datachain/data_storage/sqlite.py +7 -145
- datachain/data_storage/warehouse.py +1 -5
- datachain/dataset.py +16 -3
- datachain/lib/pytorch.py +1 -4
- datachain/query/dataset.py +0 -3
- datachain/query/dispatch.py +1 -13
- datachain/query/session.py +0 -1
- datachain/remote/studio.py +33 -1
- datachain/studio.py +80 -0
- {datachain-0.7.4.dist-info → datachain-0.7.6.dist-info}/METADATA +1 -1
- {datachain-0.7.4.dist-info → datachain-0.7.6.dist-info}/RECORD +20 -21
- datachain/data_storage/id_generator.py +0 -136
- {datachain-0.7.4.dist-info → datachain-0.7.6.dist-info}/LICENSE +0 -0
- {datachain-0.7.4.dist-info → datachain-0.7.6.dist-info}/WHEEL +0 -0
- {datachain-0.7.4.dist-info → datachain-0.7.6.dist-info}/entry_points.txt +0 -0
- {datachain-0.7.4.dist-info → datachain-0.7.6.dist-info}/top_level.txt +0 -0
datachain/catalog/catalog.py (CHANGED)

@@ -42,7 +42,6 @@ from datachain.dataset import (
     DatasetRecord,
     DatasetStats,
     DatasetStatus,
-    RowDict,
     StorageURI,
     create_dataset_uri,
     parse_dataset_uri,
@@ -69,13 +68,11 @@ from .datasource import DataSource

 if TYPE_CHECKING:
     from datachain.data_storage import (
-        AbstractIDGenerator,
         AbstractMetastore,
         AbstractWarehouse,
     )
     from datachain.dataset import DatasetListVersion
     from datachain.job import Job
-    from datachain.lib.file import File
     from datachain.listing import Listing

 logger = logging.getLogger("datachain")
@@ -127,8 +124,10 @@ class DatasetRowsFetcher(NodesThreadPool):
         self,
         metastore: "AbstractMetastore",
         warehouse: "AbstractWarehouse",
-
-
+        remote_ds_name: str,
+        remote_ds_version: int,
+        local_ds_name: str,
+        local_ds_version: int,
         schema: dict[str, Union[SQLType, type[SQLType]]],
         max_threads: int = PULL_DATASET_MAX_THREADS,
     ):
@@ -136,8 +135,10 @@ class DatasetRowsFetcher(NodesThreadPool):
         self._check_dependencies()
         self.metastore = metastore
         self.warehouse = warehouse
-        self.
-        self.
+        self.remote_ds_name = remote_ds_name
+        self.remote_ds_version = remote_ds_version
+        self.local_ds_name = local_ds_name
+        self.local_ds_version = local_ds_version
         self.schema = schema
         self.last_status_check: Optional[float] = None
         self.studio_client = StudioClient()
@@ -171,7 +172,7 @@ class DatasetRowsFetcher(NodesThreadPool):
         Checks are done every PULL_DATASET_CHECK_STATUS_INTERVAL seconds
         """
         export_status_response = self.studio_client.dataset_export_status(
-            self.
+            self.remote_ds_name, self.remote_ds_version
         )
         if not export_status_response.ok:
             raise_remote_error(export_status_response.message)
@@ -203,7 +204,7 @@ class DatasetRowsFetcher(NodesThreadPool):

         # metastore and warehouse are not thread safe
         with self.metastore.clone() as metastore, self.warehouse.clone() as warehouse:
-
+            local_ds = metastore.get_dataset(self.local_ds_name)

             urls = list(urls)
             while urls:
@@ -227,7 +228,7 @@ class DatasetRowsFetcher(NodesThreadPool):
                 df = df.drop("sys__id", axis=1)

                 inserted = warehouse.insert_dataset_rows(
-                    df,
+                    df, local_ds, self.local_ds_version
                 )
                 self.increase_counter(inserted)  # type: ignore [arg-type]
                 urls.remove(url)
@@ -520,7 +521,6 @@ def find_column_to_str(  # noqa: PLR0911
 class Catalog:
     def __init__(
         self,
-        id_generator: "AbstractIDGenerator",
         metastore: "AbstractMetastore",
         warehouse: "AbstractWarehouse",
         cache_dir=None,
@@ -533,7 +533,6 @@ class Catalog:
     ):
         datachain_dir = DataChainDir(cache=cache_dir, tmp=tmp_dir)
         datachain_dir.init()
-        self.id_generator = id_generator
         self.metastore = metastore
         self._warehouse = warehouse
         self.cache = DataChainCache(datachain_dir.cache, datachain_dir.tmp)
@@ -567,7 +566,6 @@ class Catalog:
     def copy(self, cache=True, db=True):
         result = copy(self)
         if not db:
-            result.id_generator = None
             result.metastore = None
             result._warehouse = None
             result.warehouse = None
@@ -967,7 +965,6 @@ class Catalog:
         are cleaned up as soon as they are no longer needed.
         """
         self.warehouse.cleanup_tables(names)
-        self.id_generator.delete_uris(names)

     def create_dataset_from_sources(
         self,
@@ -1101,6 +1098,13 @@ class Catalog:
     def get_dataset(self, name: str) -> DatasetRecord:
         return self.metastore.get_dataset(name)

+    def get_dataset_with_version_uuid(self, uuid: str) -> DatasetRecord:
+        """Returns dataset that contains version with specific uuid"""
+        for dataset in self.ls_datasets():
+            if dataset.has_version_with_uuid(uuid):
+                return self.get_dataset(dataset.name)
+        raise DatasetNotFoundError(f"Dataset with version uuid {uuid} not found.")
+
     def get_remote_dataset(self, name: str) -> DatasetRecord:
         studio_client = StudioClient()

@@ -1268,35 +1272,6 @@ class Catalog:
         dataset = self.get_dataset(name)
         return self.update_dataset(dataset, **update_data)

-    def get_file_from_row(
-        self, dataset_name: str, dataset_version: int, row: RowDict, signal_name: str
-    ) -> "File":
-        """
-        Function that returns specific file signal from dataset row by name.
-        """
-        from datachain.lib.file import File
-        from datachain.lib.signal_schema import DEFAULT_DELIMITER, SignalSchema
-
-        version = self.get_dataset(dataset_name).get_version(dataset_version)
-        schema = SignalSchema.deserialize(version.feature_schema)
-
-        if signal_name not in schema.get_signals(File):
-            raise RuntimeError(
-                f"File signal with path {signal_name} not found in ",
-                f"dataset {dataset_name}@v{dataset_version} signals schema",
-            )
-
-        prefix = signal_name.replace(".", DEFAULT_DELIMITER) + DEFAULT_DELIMITER
-        file_signals = {
-            c_name.removeprefix(prefix): c_value
-            for c_name, c_value in row.items()
-            if c_name.startswith(prefix)
-            and DEFAULT_DELIMITER not in c_name.removeprefix(prefix)
-            and c_name.removeprefix(prefix) in File.model_fields
-        }
-
-        return File(**file_signals)
-
     def ls(
         self,
         sources: list[str],
@@ -1316,10 +1291,12 @@ class Catalog:
         for source in data_sources:  # type: ignore [union-attr]
             yield source, source.ls(fields)

-    def pull_dataset(
+    def pull_dataset(  # noqa: PLR0915
         self,
-
+        remote_ds_uri: str,
         output: Optional[str] = None,
+        local_ds_name: Optional[str] = None,
+        local_ds_version: Optional[int] = None,
         no_cp: bool = False,
         force: bool = False,
         edatachain: bool = False,
@@ -1327,105 +1304,112 @@ class Catalog:
         *,
         client_config=None,
     ) -> None:
-
-        # TODO copy correct remote dates https://github.com/iterative/dvcx/issues/new
-        # TODO compare dataset stats on remote vs local pull to assert it's ok
-        def _instantiate_dataset():
+        def _instantiate(ds_uri: str) -> None:
             if no_cp:
                 return
+            assert output
             self.cp(
-                [
+                [ds_uri],
                 output,
                 force=force,
                 no_edatachain_file=not edatachain,
                 edatachain_file=edatachain_file,
                 client_config=client_config,
             )
-            print(f"Dataset {
+            print(f"Dataset {ds_uri} instantiated locally to {output}")

         if not output and not no_cp:
             raise ValueError("Please provide output directory for instantiation")

-        client_config = client_config or self.client_config
-
         studio_client = StudioClient()

         try:
-
+            remote_ds_name, version = parse_dataset_uri(remote_ds_uri)
         except Exception as e:
             raise DataChainError("Error when parsing dataset uri") from e

-
-        try:
-            dataset = self.get_dataset(remote_dataset_name)
-        except DatasetNotFoundError:
-            # we will create new one if it doesn't exist
-            pass
-
-        if dataset and version and dataset.has_version(version):
-            """No need to communicate with Studio at all"""
-            dataset_uri = create_dataset_uri(remote_dataset_name, version)
-            print(f"Local copy of dataset {dataset_uri} already present")
-            _instantiate_dataset()
-            return
+        remote_ds = self.get_remote_dataset(remote_ds_name)

-
-
-
-
-
-
-
+        try:
+            # if version is not specified in uri, take the latest one
+            if not version:
+                version = remote_ds.latest_version
+                print(f"Version not specified, pulling the latest one (v{version})")
+                # updating dataset uri with latest version
+                remote_ds_uri = create_dataset_uri(remote_ds_name, version)
+            remote_ds_version = remote_ds.get_version(version)
+        except (DatasetVersionNotFoundError, StopIteration) as exc:
+            raise DataChainError(
+                f"Dataset {remote_ds_name} doesn't have version {version} on server"
+            ) from exc

-
+        local_ds_name = local_ds_name or remote_ds.name
+        local_ds_version = local_ds_version or remote_ds_version.version
+        local_ds_uri = create_dataset_uri(local_ds_name, local_ds_version)

-
-
-
+        try:
+            # try to find existing dataset with the same uuid to avoid pulling again
+            existing_ds = self.get_dataset_with_version_uuid(remote_ds_version.uuid)
+            existing_ds_version = existing_ds.get_version_by_uuid(
+                remote_ds_version.uuid
+            )
+            existing_ds_uri = create_dataset_uri(
+                existing_ds.name, existing_ds_version.version
+            )
+            if existing_ds_uri == remote_ds_uri:
+                print(f"Local copy of dataset {remote_ds_uri} already present")
+            else:
+                print(
+                    f"Local copy of dataset {remote_ds_uri} already present as"
+                    f" dataset {existing_ds_uri}"
+                )
+            _instantiate(existing_ds_uri)
             return
+        except DatasetNotFoundError:
+            pass

         try:
-
-
-
-
-
-
+            local_dataset = self.get_dataset(local_ds_name)
+            if local_dataset and local_dataset.has_version(local_ds_version):
+                raise DataChainError(
+                    f"Local dataset {local_ds_uri} already exists with different uuid,"
+                    " please choose different local dataset name or version"
+                )
+        except DatasetNotFoundError:
+            pass

-        stats_response = studio_client.dataset_stats(
+        stats_response = studio_client.dataset_stats(
+            remote_ds_name, remote_ds_version.version
+        )
         if not stats_response.ok:
             raise_remote_error(stats_response.message)
-
+        ds_stats = stats_response.data

         dataset_save_progress_bar = tqdm(
-            desc=f"Saving dataset {
+            desc=f"Saving dataset {remote_ds_uri} locally: ",
             unit=" rows",
             unit_scale=True,
             unit_divisor=1000,
-            total=
+            total=ds_stats.num_objects,  # type: ignore [union-attr]
         )

-        schema = DatasetRecord.parse_schema(
+        schema = DatasetRecord.parse_schema(remote_ds_version.schema)

-
-
-
-
-        dataset = self.create_dataset(
-            remote_dataset_name,
-            version,
-            query_script=remote_dataset_version.query_script,
+        local_ds = self.create_dataset(
+            local_ds_name,
+            local_ds_version,
+            query_script=remote_ds_version.query_script,
             create_rows=True,
-            columns=
-            feature_schema=
+            columns=tuple(sa.Column(n, t) for n, t in schema.items() if n != "sys__id"),
+            feature_schema=remote_ds_version.feature_schema,
             validate_version=False,
-            uuid=
+            uuid=remote_ds_version.uuid,
         )

         # asking remote to export dataset rows table to s3 and to return signed
         # urls of exported parts, which are in parquet format
         export_response = studio_client.export_dataset_table(
-
+            remote_ds_name, remote_ds_version.version
         )
         if not export_response.ok:
             raise_remote_error(export_response.message)
@@ -1442,8 +1426,10 @@ class Catalog:
         rows_fetcher = DatasetRowsFetcher(
             metastore,
             warehouse,
-
-            version,
+            remote_ds_name,
+            remote_ds_version.version,
+            local_ds_name,
+            local_ds_version,
             schema,
         )
         try:
@@ -1455,23 +1441,23 @@ class Catalog:
                 dataset_save_progress_bar,
             )
         except:
-            self.remove_dataset(
+            self.remove_dataset(local_ds_name, local_ds_version)
             raise

-
-
+        local_ds = self.metastore.update_dataset_status(
+            local_ds,
            DatasetStatus.COMPLETE,
-            version=
-            error_message=
-            error_stack=
-            script_output=
+            version=local_ds_version,
+            error_message=remote_ds.error_message,
+            error_stack=remote_ds.error_stack,
+            script_output=remote_ds.error_stack,
         )
-        self.update_dataset_version_with_warehouse_info(
+        self.update_dataset_version_with_warehouse_info(local_ds, local_ds_version)

         dataset_save_progress_bar.close()
-        print(f"Dataset {
+        print(f"Dataset {remote_ds_uri} saved locally")

-
+        _instantiate(local_ds_uri)

     def clone(
         self,
datachain/catalog/loader.py (CHANGED)

@@ -4,21 +4,16 @@ from typing import Any, Optional

 from datachain.catalog import Catalog
 from datachain.data_storage import (
-    AbstractIDGenerator,
     AbstractMetastore,
     AbstractWarehouse,
 )
 from datachain.data_storage.serializer import deserialize
 from datachain.data_storage.sqlite import (
-    SQLiteIDGenerator,
     SQLiteMetastore,
     SQLiteWarehouse,
 )
 from datachain.utils import get_envs_by_prefix

-ID_GENERATOR_SERIALIZED = "DATACHAIN__ID_GENERATOR"
-ID_GENERATOR_IMPORT_PATH = "DATACHAIN_ID_GENERATOR"
-ID_GENERATOR_ARG_PREFIX = "DATACHAIN_ID_GENERATOR_ARG_"
 METASTORE_SERIALIZED = "DATACHAIN__METASTORE"
 METASTORE_IMPORT_PATH = "DATACHAIN_METASTORE"
 METASTORE_ARG_PREFIX = "DATACHAIN_METASTORE_ARG_"
@@ -31,45 +26,7 @@ DISTRIBUTED_ARG_PREFIX = "DATACHAIN_DISTRIBUTED_ARG_"
 IN_MEMORY_ERROR_MESSAGE = "In-memory is only supported on SQLite"


-def
-    id_generator_serialized = os.environ.get(ID_GENERATOR_SERIALIZED)
-    if id_generator_serialized:
-        id_generator_obj = deserialize(id_generator_serialized)
-        if not isinstance(id_generator_obj, AbstractIDGenerator):
-            raise RuntimeError(
-                "Deserialized ID generator is not an instance of AbstractIDGenerator: "
-                f"{id_generator_obj}"
-            )
-        return id_generator_obj
-
-    id_generator_import_path = os.environ.get(ID_GENERATOR_IMPORT_PATH)
-    id_generator_arg_envs = get_envs_by_prefix(ID_GENERATOR_ARG_PREFIX)
-    # Convert env variable names to keyword argument names by lowercasing them
-    id_generator_args: dict[str, Any] = {
-        k.lower(): v for k, v in id_generator_arg_envs.items()
-    }
-
-    if not id_generator_import_path:
-        id_generator_args["in_memory"] = in_memory
-        return SQLiteIDGenerator(**id_generator_args)
-    if in_memory:
-        raise RuntimeError(IN_MEMORY_ERROR_MESSAGE)
-    # ID generator paths are specified as (for example):
-    # datachain.data_storage.SQLiteIDGenerator
-    if "." not in id_generator_import_path:
-        raise RuntimeError(
-            f"Invalid {ID_GENERATOR_IMPORT_PATH} import path:"
-            f"{id_generator_import_path}"
-        )
-    module_name, _, class_name = id_generator_import_path.rpartition(".")
-    id_generator = import_module(module_name)
-    id_generator_class = getattr(id_generator, class_name)
-    return id_generator_class(**id_generator_args)
-
-
-def get_metastore(
-    id_generator: Optional["AbstractIDGenerator"], in_memory: bool = False
-) -> "AbstractMetastore":
+def get_metastore(in_memory: bool = False) -> "AbstractMetastore":
     metastore_serialized = os.environ.get(METASTORE_SERIALIZED)
     if metastore_serialized:
         metastore_obj = deserialize(metastore_serialized)
@@ -80,9 +37,6 @@ def get_metastore(
         )
         return metastore_obj

-    if id_generator is None:
-        id_generator = get_id_generator()
-
     metastore_import_path = os.environ.get(METASTORE_IMPORT_PATH)
     metastore_arg_envs = get_envs_by_prefix(METASTORE_ARG_PREFIX)
     # Convert env variable names to keyword argument names by lowercasing them
@@ -91,10 +45,8 @@ def get_metastore(
     }

     if not metastore_import_path:
-        if not isinstance(id_generator, SQLiteIDGenerator):
-            raise ValueError("SQLiteMetastore can only be used with SQLiteIDGenerator")
         metastore_args["in_memory"] = in_memory
-        return SQLiteMetastore(
+        return SQLiteMetastore(**metastore_args)
     if in_memory:
         raise RuntimeError(IN_MEMORY_ERROR_MESSAGE)
     # Metastore paths are specified as (for example):
@@ -106,12 +58,10 @@ def get_metastore(
     module_name, _, class_name = metastore_import_path.rpartition(".")
     metastore = import_module(module_name)
     metastore_class = getattr(metastore, class_name)
-    return metastore_class(
+    return metastore_class(**metastore_args)


-def get_warehouse(
-    id_generator: Optional["AbstractIDGenerator"], in_memory: bool = False
-) -> "AbstractWarehouse":
+def get_warehouse(in_memory: bool = False) -> "AbstractWarehouse":
     warehouse_serialized = os.environ.get(WAREHOUSE_SERIALIZED)
     if warehouse_serialized:
         warehouse_obj = deserialize(warehouse_serialized)
@@ -122,9 +72,6 @@ def get_warehouse(
         )
         return warehouse_obj

-    if id_generator is None:
-        id_generator = get_id_generator()
-
     warehouse_import_path = os.environ.get(WAREHOUSE_IMPORT_PATH)
     warehouse_arg_envs = get_envs_by_prefix(WAREHOUSE_ARG_PREFIX)
     # Convert env variable names to keyword argument names by lowercasing them
@@ -133,10 +80,8 @@ def get_warehouse(
     }

     if not warehouse_import_path:
-        if not isinstance(id_generator, SQLiteIDGenerator):
-            raise ValueError("SQLiteWarehouse can only be used with SQLiteIDGenerator")
         warehouse_args["in_memory"] = in_memory
-        return SQLiteWarehouse(
+        return SQLiteWarehouse(**warehouse_args)
     if in_memory:
         raise RuntimeError(IN_MEMORY_ERROR_MESSAGE)
     # Warehouse paths are specified as (for example):
@@ -148,7 +93,7 @@ def get_warehouse(
     module_name, _, class_name = warehouse_import_path.rpartition(".")
     warehouse = import_module(module_name)
     warehouse_class = getattr(warehouse, class_name)
-    return warehouse_class(
+    return warehouse_class(**warehouse_args)


 def get_distributed_class(**kwargs):
@@ -188,11 +133,9 @@ def get_catalog(
     and name of variable after, e.g. if it accepts team_id as kwargs
     we can provide DATACHAIN_METASTORE_ARG_TEAM_ID=12345 env variable.
     """
-    id_generator = get_id_generator(in_memory=in_memory)
     return Catalog(
-
-
-        warehouse=get_warehouse(id_generator, in_memory=in_memory),
+        metastore=get_metastore(in_memory=in_memory),
+        warehouse=get_warehouse(in_memory=in_memory),
         client_config=client_config,
         in_memory=in_memory,
     )