datachain 0.6.0__py3-none-any.whl → 0.6.2__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry and is provided for informational purposes only.
Potentially problematic release: this version of datachain might be problematic.
- datachain/__init__.py +2 -0
- datachain/catalog/catalog.py +62 -228
- datachain/cli.py +136 -22
- datachain/client/fsspec.py +9 -0
- datachain/client/local.py +11 -32
- datachain/config.py +126 -51
- datachain/data_storage/schema.py +66 -33
- datachain/data_storage/sqlite.py +12 -4
- datachain/data_storage/warehouse.py +101 -129
- datachain/lib/convert/sql_to_python.py +8 -12
- datachain/lib/dc.py +275 -80
- datachain/lib/func/__init__.py +32 -0
- datachain/lib/func/aggregate.py +353 -0
- datachain/lib/func/func.py +152 -0
- datachain/lib/listing.py +6 -21
- datachain/lib/listing_info.py +4 -0
- datachain/lib/signal_schema.py +17 -8
- datachain/lib/udf.py +3 -3
- datachain/lib/utils.py +5 -0
- datachain/listing.py +22 -48
- datachain/query/__init__.py +1 -2
- datachain/query/batch.py +0 -1
- datachain/query/dataset.py +33 -46
- datachain/query/schema.py +1 -61
- datachain/query/session.py +33 -25
- datachain/remote/studio.py +63 -14
- datachain/sql/functions/__init__.py +1 -1
- datachain/sql/functions/aggregate.py +47 -0
- datachain/sql/functions/array.py +0 -8
- datachain/sql/sqlite/base.py +20 -2
- datachain/studio.py +129 -0
- datachain/utils.py +58 -0
- {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/METADATA +7 -6
- {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/RECORD +38 -33
- {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/WHEEL +1 -1
- {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/LICENSE +0 -0
- {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/entry_points.txt +0 -0
- {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/top_level.txt +0 -0
datachain/listing.py
CHANGED
@@ -4,12 +4,10 @@ from collections.abc import Iterable, Iterator
 from itertools import zip_longest
 from typing import TYPE_CHECKING, Optional

-from fsspec.asyn import get_loop, sync
 from sqlalchemy import Column
 from sqlalchemy.sql import func
 from tqdm import tqdm

-from datachain.lib.file import File
 from datachain.node import DirType, Node, NodeWithPath
 from datachain.sql.functions import path as pathfunc
 from datachain.utils import suffix_to_number
@@ -17,33 +15,29 @@ from datachain.utils import suffix_to_number
 if TYPE_CHECKING:
     from datachain.catalog.datasource import DataSource
     from datachain.client import Client
-    from datachain.data_storage import
+    from datachain.data_storage import AbstractWarehouse
     from datachain.dataset import DatasetRecord
-    from datachain.storage import Storage


 class Listing:
     def __init__(
         self,
-        storage: Optional["Storage"],
-        metastore: "AbstractMetastore",
         warehouse: "AbstractWarehouse",
         client: "Client",
         dataset: Optional["DatasetRecord"],
+        object_name: str = "file",
     ):
-        self.storage = storage
-        self.metastore = metastore
         self.warehouse = warehouse
         self.client = client
         self.dataset = dataset  # dataset representing bucket listing
+        self.object_name = object_name

     def clone(self) -> "Listing":
         return self.__class__(
-            self.storage,
-            self.metastore.clone(),
             self.warehouse.clone(),
             self.client,
             self.dataset,
+            self.object_name,
         )

     def __enter__(self) -> "Listing":
@@ -53,46 +47,20 @@ class Listing:
         self.close()

     def close(self) -> None:
-        self.metastore.close()
         self.warehouse.close()

     @property
-    def
-
+    def uri(self):
+        from datachain.lib.listing import listing_uri_from_name
+
+        return listing_uri_from_name(self.dataset.name)

     @property
     def dataset_rows(self):
-        return self.warehouse.dataset_rows(
-
-    def fetch(self, start_prefix="", method: str = "default") -> None:
-        sync(get_loop(), self._fetch, start_prefix, method)
-
-    async def _fetch(self, start_prefix: str, method: str) -> None:
-        with self.clone() as fetch_listing:
-            if start_prefix:
-                start_prefix = start_prefix.rstrip("/")
-            try:
-                async for entries in fetch_listing.client.scandir(
-                    start_prefix, method=method
-                ):
-                    fetch_listing.insert_entries(entries)
-                    if len(entries) > 1:
-                        fetch_listing.metastore.update_last_inserted_at()
-            finally:
-                fetch_listing.insert_entries_done()
-
-    def insert_entry(self, entry: File) -> None:
-        self.insert_entries([entry])
-
-    def insert_entries(self, entries: Iterable[File]) -> None:
-        self.warehouse.insert_rows(
-            self.dataset_rows.get_table(),
-            self.warehouse.prepare_entries(entries),
+        return self.warehouse.dataset_rows(
+            self.dataset, self.dataset.latest_version, object_name=self.object_name
         )

-    def insert_entries_done(self) -> None:
-        self.warehouse.insert_rows_done(self.dataset_rows.get_table())
-
     def expand_path(self, path, use_glob=True) -> list[Node]:
         if use_glob and glob.has_magic(path):
             return self.warehouse.expand_path(self.dataset_rows, path)
@@ -200,25 +168,31 @@ class Listing:
         conds = []
         if names:
             for name in names:
-                conds.append(
+                conds.append(
+                    pathfunc.name(Column(dr.col_name("path"))).op("GLOB")(name)
+                )
         if inames:
             for iname in inames:
                 conds.append(
-                    func.lower(pathfunc.name(Column("path"))).op("GLOB")(
+                    func.lower(pathfunc.name(Column(dr.col_name("path")))).op("GLOB")(
+                        iname.lower()
+                    )
                 )
         if paths:
             for path in paths:
-                conds.append(Column("path").op("GLOB")(path))
+                conds.append(Column(dr.col_name("path")).op("GLOB")(path))
         if ipaths:
             for ipath in ipaths:
-                conds.append(
+                conds.append(
+                    func.lower(Column(dr.col_name("path"))).op("GLOB")(ipath.lower())
+                )

         if size is not None:
             size_limit = suffix_to_number(size)
             if size_limit >= 0:
-                conds.append(Column("size") >= size_limit)
+                conds.append(Column(dr.col_name("size")) >= size_limit)
             else:
-                conds.append(Column("size") <= -size_limit)
+                conds.append(Column(dr.col_name("size")) <= -size_limit)

         return self.warehouse.find(
             dr,
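In 0.6.2 a Listing no longer carries a metastore or writes entries itself; it only reads an already materialized listing dataset through the warehouse. A minimal sketch of the new constructor surface, assuming the catalog, client, and dataset record are obtained elsewhere (the wiring below is illustrative, not the library's exact bootstrap code):

    from datachain.listing import Listing

    def open_listing(catalog, client, listing_dataset) -> Listing:
        # Hypothetical wiring; parameter names follow the constructor shown above.
        return Listing(
            warehouse=catalog.warehouse,   # AbstractWarehouse implementation
            client=client,                 # Client for the underlying bucket
            dataset=listing_dataset,       # DatasetRecord holding the bucket listing
            object_name="file",            # column prefix used via dr.col_name(...)
        )

With that in place, listing.uri derives the listing URI from the dataset name via listing_uri_from_name, and listing.dataset_rows returns the warehouse table for the dataset's latest version under the given object_name prefix.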
datachain/query/__init__.py
CHANGED
@@ -1,12 +1,11 @@
 from .dataset import DatasetQuery
 from .params import param
-from .schema import C,
+from .schema import C, LocalFilename, Object, Stream
 from .session import Session

 __all__ = [
     "C",
     "DatasetQuery",
-    "DatasetRow",
     "LocalFilename",
     "Object",
     "Session",
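With DatasetRow dropped from the package exports, the names that remain importable from datachain.query (per the updated __all__ above) include:

    from datachain.query import C, DatasetQuery, LocalFilename, Object, Session

Code that previously did `from datachain.query import DatasetRow` can no longer rely on that helper; the class itself is deleted from datachain/query/schema.py further below.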
datachain/query/batch.py
CHANGED
datachain/query/dataset.py
CHANGED
@@ -10,6 +10,7 @@ from abc import ABC, abstractmethod
 from collections.abc import Generator, Iterable, Iterator, Sequence
 from copy import copy
 from functools import wraps
+from secrets import token_hex
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -173,10 +174,10 @@ class QueryStep(StartingStep):
             return sqlalchemy.select(*columns)

         dataset = self.catalog.get_dataset(self.dataset_name)
-
+        dr = self.catalog.warehouse.dataset_rows(dataset, self.dataset_version)

         return step_result(
-            q,
+            q, dr.columns, dependencies=[(self.dataset_name, self.dataset_version)]
         )
@@ -591,10 +592,6 @@ class UDFSignal(UDFStep):
             return query, []
         table = self.catalog.warehouse.create_pre_udf_table(query)
         q: Select = sqlalchemy.select(*table.c)
-        if query._order_by_clauses:
-            # we are adding ordering only if it's explicitly added by user in
-            # query part before adding signals
-            q = q.order_by(table.c.sys__id)
         return q, [table]

     def create_result_query(
@@ -630,11 +627,6 @@ class UDFSignal(UDFStep):
         else:
             res = sqlalchemy.select(*cols1).select_from(subq)

-        if query._order_by_clauses:
-            # if ordering is used in query part before adding signals, we
-            # will have it as order by id from select from pre-created udf table
-            res = res.order_by(subq.c.sys__id)
-
         if self.partition_by is not None:
             subquery = res.subquery()
             res = sqlalchemy.select(*subquery.c).select_from(subquery)
@@ -666,13 +658,6 @@ class RowGenerator(UDFStep):
     def create_result_query(
         self, udf_table, query: Select
     ) -> tuple[QueryGeneratorFunc, list["sqlalchemy.Column"]]:
-        if not query._order_by_clauses:
-            # if we are not selecting all rows in UDF, we need to ensure that
-            # we get the same rows as we got as inputs of UDF since selecting
-            # without ordering can be non deterministic in some databases
-            c = query.selected_columns
-            query = query.order_by(c.sys__id)
-
         udf_table_query = udf_table.select().subquery()
         udf_table_cols: list[sqlalchemy.Label[Any]] = [
             label(c.name, c) for c in udf_table_query.columns
@@ -736,10 +721,17 @@ class SQLMutate(SQLClause):

     def apply_sql_clause(self, query: Select) -> Select:
         original_subquery = query.subquery()
+        to_mutate = {c.name for c in self.args}
+
+        prefix = f"mutate{token_hex(8)}_"
+        cols = [
+            c.label(prefix + c.name) if c.name in to_mutate else c
+            for c in original_subquery.c
+        ]
         # this is needed for new column to be used in clauses
         # like ORDER BY, otherwise new column is not recognized
         subquery = (
-            sqlalchemy.select(*
+            sqlalchemy.select(*cols, *self.args)
             .select_from(original_subquery)
             .subquery()
         )
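The new SQLMutate logic renames any existing column that a mutate() call redefines, using a random mutate<hex>_ prefix, so the fresh expression can reuse the original name without a duplicate-label collision in the combined SELECT. A standalone SQLAlchemy sketch of the same construction (table and column names are illustrative, not the library's internals):

    from secrets import token_hex
    import sqlalchemy as sa

    original = sa.select(sa.column("path"), sa.column("size")).subquery()
    new_size = (original.c.size * 2).label("size")   # redefinition coming from mutate()
    to_mutate = {"size"}

    prefix = f"mutate{token_hex(8)}_"
    cols = [
        c.label(prefix + c.name) if c.name in to_mutate else c
        for c in original.c
    ]
    stmt = sa.select(*cols, new_size).select_from(original).subquery()
    # stmt exposes: path, mutate<hex>_size (old value), size (new value)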
@@ -957,24 +949,24 @@ class SQLJoin(Step):


 @frozen
-class
-
-
-    cols: PartitionByType
+class SQLGroupBy(SQLClause):
+    cols: Sequence[Union[str, ColumnElement]]
+    group_by: Sequence[Union[str, ColumnElement]]

-    def
-
+    def apply_sql_clause(self, query) -> Select:
+        if not self.cols:
+            raise ValueError("No columns to select")
+        if not self.group_by:
+            raise ValueError("No columns to group by")

-
-        self, query_generator: QueryGenerator, temp_tables: list[str]
-    ) -> StepResult:
-        query = query_generator.select()
-        grouped_query = query.group_by(*self.cols)
+        subquery = query.subquery()

-
-
+        cols = [
+            subquery.c[str(c)] if isinstance(c, (str, C)) else c
+            for c in [*self.group_by, *self.cols]
+        ]

-        return
+        return sqlalchemy.select(*cols).select_from(subquery).group_by(*self.group_by)


 def _validate_columns(
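SQLGroupBy replaces the old standalone GroupBy step: instead of appending GROUP BY to the accumulated query, it wraps the previous query in a subquery and selects the grouping keys plus the requested columns from it. A rough SQLAlchemy equivalent of what apply_sql_clause builds (column names are illustrative):

    import sqlalchemy as sa

    previous = sa.select(sa.column("file__path"), sa.column("file__size")).subquery()
    group_keys = [previous.c["file__path"]]
    aggregates = [sa.func.sum(previous.c["file__size"]).label("total_size")]

    stmt = (
        sa.select(*group_keys, *aggregates)
        .select_from(previous)
        .group_by(*group_keys)
    )
    # roughly: SELECT file__path, sum(file__size) AS total_size FROM (...) GROUP BY file__path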
@@ -1130,25 +1122,14 @@ class DatasetQuery:
         query.steps = query.steps[-1:] + query.steps[:-1]

         result = query.starting_step.apply()
-        group_by = None
         self.dependencies.update(result.dependencies)

         for step in query.steps:
-            if isinstance(step, GroupBy):
-                if group_by is not None:
-                    raise TypeError("only one group_by allowed")
-                group_by = step
-                continue
-
             result = step.apply(
                 result.query_generator, self.temp_table_names
             )  # a chain of steps linked by results
             self.dependencies.update(result.dependencies)

-        if group_by:
-            result = group_by.apply(result.query_generator, self.temp_table_names)
-            self.dependencies.update(result.dependencies)
-
         return result.query_generator

     @staticmethod
@@ -1410,9 +1391,13 @@ class DatasetQuery:
         return query.as_scalar()

     @detach
-    def group_by(
+    def group_by(
+        self,
+        cols: Sequence[ColumnElement],
+        group_by: Sequence[ColumnElement],
+    ) -> "Self":
         query = self.clone()
-        query.steps.append(
+        query.steps.append(SQLGroupBy(cols, group_by))
         return query

     @detach
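At the DatasetQuery level the signature is now explicit: cols holds the (typically aggregate) expressions to select and group_by holds the grouping keys, which are selected as well. A hedged usage sketch; the dataset name, column names, and the keyword form of the call are illustrative, not confirmed API details:

    import sqlalchemy as sa
    from datachain.query import C, DatasetQuery

    dq = DatasetQuery(name="animals")          # assumes such a dataset already exists
    grouped = dq.group_by(
        cols=[sa.func.count().label("cnt")],   # aggregates to select
        group_by=[C("file__path")],            # grouping keys, also selected
    )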
@@ -1591,6 +1576,8 @@ class DatasetQuery:
         )
         version = version or dataset.latest_version

+        self.session.add_dataset_version(dataset=dataset, version=version)
+
         dr = self.catalog.warehouse.dataset_rows(dataset)

         self.catalog.warehouse.copy_table(dr.get_table(), query.select())
datachain/query/schema.py
CHANGED
@@ -1,16 +1,13 @@
 import functools
-import json
 from abc import ABC, abstractmethod
-from datetime import datetime, timezone
 from fnmatch import fnmatch
-from typing import TYPE_CHECKING, Any, Callable,
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union

 import attrs
 import sqlalchemy as sa
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback

 from datachain.lib.file import File
-from datachain.sql.types import JSON, Boolean, DateTime, Int64, SQLType, String

 if TYPE_CHECKING:
     from datachain.catalog import Catalog
@@ -228,61 +225,4 @@ def normalize_param(param: UDFParamSpec) -> UDFParameter:
     raise TypeError(f"Invalid UDF parameter: {param}")


-class DatasetRow:
-    schema: ClassVar[dict[str, type[SQLType]]] = {
-        "source": String,
-        "path": String,
-        "size": Int64,
-        "location": JSON,
-        "is_latest": Boolean,
-        "last_modified": DateTime,
-        "version": String,
-        "etag": String,
-    }
-
-    @staticmethod
-    def create(
-        path: str,
-        source: str = "",
-        size: int = 0,
-        location: Optional[dict[str, Any]] = None,
-        is_latest: bool = True,
-        last_modified: Optional[datetime] = None,
-        version: str = "",
-        etag: str = "",
-    ) -> tuple[
-        str,
-        str,
-        int,
-        Optional[str],
-        int,
-        bool,
-        datetime,
-        str,
-        str,
-        int,
-    ]:
-        if location:
-            location = json.dumps([location])  # type: ignore [assignment]
-
-        last_modified = last_modified or datetime.now(timezone.utc)
-
-        return (  # type: ignore [return-value]
-            source,
-            path,
-            size,
-            location,
-            is_latest,
-            last_modified,
-            version,
-            etag,
-        )
-
-    @staticmethod
-    def extend(**columns):
-        cols = {**DatasetRow.schema}
-        cols.update(columns)
-        return cols
-
-
 C = Column
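The removed DatasetRow helper built raw tuples for the legacy listing schema; its role is covered by the File model from datachain.lib.file, whose import stays at the top of this module. A hedged sketch of the equivalent row, assuming File exposes the same fields the removed schema declared (source, path, size, location, is_latest, last_modified, version, etag):

    from datetime import datetime, timezone
    from datachain.lib.file import File

    row = File(
        source="s3://bucket",
        path="dir/object.csv",
        size=1024,
        etag="abc123",
        version="",
        is_latest=True,
        last_modified=datetime.now(timezone.utc),
    )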
datachain/query/session.py
CHANGED
@@ -1,9 +1,9 @@
 import atexit
+import gc
 import logging
-import os
 import re
 import sys
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, ClassVar, Optional
 from uuid import uuid4

 from datachain.catalog import get_catalog
@@ -11,6 +11,7 @@ from datachain.error import TableMissingError

 if TYPE_CHECKING:
     from datachain.catalog import Catalog
+    from datachain.dataset import DatasetRecord

 logger = logging.getLogger("datachain")

@@ -39,7 +40,7 @@ class Session:
     """

     GLOBAL_SESSION_CTX: Optional["Session"] = None
-
+    SESSION_CONTEXTS: ClassVar[list["Session"]] = []
     ORIGINAL_EXCEPT_HOOK = None

     DATASET_PREFIX = "session_"
@@ -64,18 +65,21 @@ class Session:

         session_uuid = uuid4().hex[: self.SESSION_UUID_LEN]
         self.name = f"{name}_{session_uuid}"
-        self.job_id = os.getenv("DATACHAIN_JOB_ID") or str(uuid4())
         self.is_new_catalog = not catalog
         self.catalog = catalog or get_catalog(
             client_config=client_config, in_memory=in_memory
         )
+        self.dataset_versions: list[tuple[DatasetRecord, int]] = []

     def __enter__(self):
+        # Push the current context onto the stack
+        Session.SESSION_CONTEXTS.append(self)
+
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
         if exc_type:
-            self._cleanup_created_versions(
+            self._cleanup_created_versions()

         self._cleanup_temp_datasets()
         if self.is_new_catalog:
@@ -83,6 +87,12 @@ class Session:
             self.catalog.warehouse.close_on_exit()
             self.catalog.id_generator.close_on_exit()

+        if Session.SESSION_CONTEXTS:
+            Session.SESSION_CONTEXTS.pop()
+
+    def add_dataset_version(self, dataset: "DatasetRecord", version: int) -> None:
+        self.dataset_versions.append((dataset, version))
+
     def generate_temp_dataset_name(self) -> str:
         return self.get_temp_prefix() + uuid4().hex[: self.TEMP_TABLE_UUID_LEN]

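Sessions now maintain an explicit stack of active contexts (SESSION_CONTEXTS), pushed in __enter__ and popped in __exit__, and each session records the dataset versions it created so they can be rolled back. The practical effect, sketched below with illustrative names, is that Session.get() can prefer the innermost with block over the global session:

    from datachain.query import Session

    with Session("outer") as outer:
        with Session("inner") as inner:
            assert Session.get() is inner   # most recent context wins
        assert Session.get() is outer       # popped back to the enclosing session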
@@ -98,21 +108,15 @@
         except TableMissingError:
             pass

-    def _cleanup_created_versions(self
-
-        if not versions:
+    def _cleanup_created_versions(self) -> None:
+        if not self.dataset_versions:
             return

-
-        for dataset_name, version in versions:
-            if dataset_name not in datasets:
-                datasets[dataset_name] = self.catalog.get_dataset(dataset_name)
-            dataset = datasets[dataset_name]
-            logger.info(
-                "Removing dataset version %s@%s due to exception", dataset_name, version
-            )
+        for dataset, version in self.dataset_versions:
             self.catalog.remove_dataset_version(dataset, version)

+        self.dataset_versions.clear()
+
     @classmethod
     def get(
         cls,
@@ -125,33 +129,34 @@

         Parameters:
             session (Session): Optional Session(). If not provided a new session will
-                be created. It's needed mostly for
-            catalog (Catalog): Optional catalog. By default a new catalog is created.
+                be created. It's needed mostly for simple API purposes.
+            catalog (Catalog): Optional catalog. By default, a new catalog is created.
         """
         if session:
             return session

-
+        # Access the active (most recent) context from the stack
+        if cls.SESSION_CONTEXTS:
+            return cls.SESSION_CONTEXTS[-1]
+
+        if cls.GLOBAL_SESSION_CTX is None:
             cls.GLOBAL_SESSION_CTX = Session(
                 cls.GLOBAL_SESSION_NAME,
                 catalog,
                 client_config=client_config,
                 in_memory=in_memory,
             )
-            cls.GLOBAL_SESSION = cls.GLOBAL_SESSION_CTX.__enter__()

             atexit.register(cls._global_cleanup)
             cls.ORIGINAL_EXCEPT_HOOK = sys.excepthook
             sys.excepthook = cls.except_hook

-        return cls.
+        return cls.GLOBAL_SESSION_CTX

     @staticmethod
     def except_hook(exc_type, exc_value, exc_traceback):
+        Session.GLOBAL_SESSION_CTX.__exit__(exc_type, exc_value, exc_traceback)
         Session._global_cleanup()
-        if Session.GLOBAL_SESSION_CTX is not None:
-            job_id = Session.GLOBAL_SESSION_CTX.job_id
-            Session.GLOBAL_SESSION_CTX._cleanup_created_versions(job_id)

         if Session.ORIGINAL_EXCEPT_HOOK:
             Session.ORIGINAL_EXCEPT_HOOK(exc_type, exc_value, exc_traceback)
@@ -160,7 +165,6 @@ class Session:
     def cleanup_for_tests(cls):
         if cls.GLOBAL_SESSION_CTX is not None:
             cls.GLOBAL_SESSION_CTX.__exit__(None, None, None)
-            cls.GLOBAL_SESSION = None
             cls.GLOBAL_SESSION_CTX = None
             atexit.unregister(cls._global_cleanup)

@@ -171,3 +175,7 @@ class Session:
     def _global_cleanup():
         if Session.GLOBAL_SESSION_CTX is not None:
             Session.GLOBAL_SESSION_CTX.__exit__(None, None, None)
+
+        for obj in gc.get_objects():  # Get all tracked objects
+            if isinstance(obj, Session):  # Cleanup temp dataset for session variables.
+                obj.__exit__(None, None, None)