datachain 0.6.0__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of datachain might be problematic.

Files changed (38)
  1. datachain/__init__.py +2 -0
  2. datachain/catalog/catalog.py +62 -228
  3. datachain/cli.py +136 -22
  4. datachain/client/fsspec.py +9 -0
  5. datachain/client/local.py +11 -32
  6. datachain/config.py +126 -51
  7. datachain/data_storage/schema.py +66 -33
  8. datachain/data_storage/sqlite.py +12 -4
  9. datachain/data_storage/warehouse.py +101 -129
  10. datachain/lib/convert/sql_to_python.py +8 -12
  11. datachain/lib/dc.py +275 -80
  12. datachain/lib/func/__init__.py +32 -0
  13. datachain/lib/func/aggregate.py +353 -0
  14. datachain/lib/func/func.py +152 -0
  15. datachain/lib/listing.py +6 -21
  16. datachain/lib/listing_info.py +4 -0
  17. datachain/lib/signal_schema.py +17 -8
  18. datachain/lib/udf.py +3 -3
  19. datachain/lib/utils.py +5 -0
  20. datachain/listing.py +22 -48
  21. datachain/query/__init__.py +1 -2
  22. datachain/query/batch.py +0 -1
  23. datachain/query/dataset.py +33 -46
  24. datachain/query/schema.py +1 -61
  25. datachain/query/session.py +33 -25
  26. datachain/remote/studio.py +63 -14
  27. datachain/sql/functions/__init__.py +1 -1
  28. datachain/sql/functions/aggregate.py +47 -0
  29. datachain/sql/functions/array.py +0 -8
  30. datachain/sql/sqlite/base.py +20 -2
  31. datachain/studio.py +129 -0
  32. datachain/utils.py +58 -0
  33. {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/METADATA +7 -6
  34. {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/RECORD +38 -33
  35. {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/WHEEL +1 -1
  36. {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/LICENSE +0 -0
  37. {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/entry_points.txt +0 -0
  38. {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/top_level.txt +0 -0
datachain/listing.py CHANGED
@@ -4,12 +4,10 @@ from collections.abc import Iterable, Iterator
 from itertools import zip_longest
 from typing import TYPE_CHECKING, Optional
 
-from fsspec.asyn import get_loop, sync
 from sqlalchemy import Column
 from sqlalchemy.sql import func
 from tqdm import tqdm
 
-from datachain.lib.file import File
 from datachain.node import DirType, Node, NodeWithPath
 from datachain.sql.functions import path as pathfunc
 from datachain.utils import suffix_to_number
@@ -17,33 +15,29 @@ from datachain.utils import suffix_to_number
 if TYPE_CHECKING:
     from datachain.catalog.datasource import DataSource
     from datachain.client import Client
-    from datachain.data_storage import AbstractMetastore, AbstractWarehouse
+    from datachain.data_storage import AbstractWarehouse
     from datachain.dataset import DatasetRecord
-    from datachain.storage import Storage
 
 
 class Listing:
     def __init__(
         self,
-        storage: Optional["Storage"],
-        metastore: "AbstractMetastore",
        warehouse: "AbstractWarehouse",
        client: "Client",
        dataset: Optional["DatasetRecord"],
+        object_name: str = "file",
    ):
-        self.storage = storage
-        self.metastore = metastore
        self.warehouse = warehouse
        self.client = client
        self.dataset = dataset  # dataset representing bucket listing
+        self.object_name = object_name

    def clone(self) -> "Listing":
        return self.__class__(
-            self.storage,
-            self.metastore.clone(),
            self.warehouse.clone(),
            self.client,
            self.dataset,
+            self.object_name,
        )

    def __enter__(self) -> "Listing":
@@ -53,46 +47,20 @@ class Listing:
        self.close()

    def close(self) -> None:
-        self.metastore.close()
        self.warehouse.close()

    @property
-    def id(self):
-        return self.storage.id
+    def uri(self):
+        from datachain.lib.listing import listing_uri_from_name
+
+        return listing_uri_from_name(self.dataset.name)

    @property
    def dataset_rows(self):
-        return self.warehouse.dataset_rows(self.dataset, self.dataset.latest_version)
-
-    def fetch(self, start_prefix="", method: str = "default") -> None:
-        sync(get_loop(), self._fetch, start_prefix, method)
-
-    async def _fetch(self, start_prefix: str, method: str) -> None:
-        with self.clone() as fetch_listing:
-            if start_prefix:
-                start_prefix = start_prefix.rstrip("/")
-            try:
-                async for entries in fetch_listing.client.scandir(
-                    start_prefix, method=method
-                ):
-                    fetch_listing.insert_entries(entries)
-                    if len(entries) > 1:
-                        fetch_listing.metastore.update_last_inserted_at()
-            finally:
-                fetch_listing.insert_entries_done()
-
-    def insert_entry(self, entry: File) -> None:
-        self.insert_entries([entry])
-
-    def insert_entries(self, entries: Iterable[File]) -> None:
-        self.warehouse.insert_rows(
-            self.dataset_rows.get_table(),
-            self.warehouse.prepare_entries(entries),
+        return self.warehouse.dataset_rows(
+            self.dataset, self.dataset.latest_version, object_name=self.object_name
        )

-    def insert_entries_done(self) -> None:
-        self.warehouse.insert_rows_done(self.dataset_rows.get_table())
-
    def expand_path(self, path, use_glob=True) -> list[Node]:
        if use_glob and glob.has_magic(path):
            return self.warehouse.expand_path(self.dataset_rows, path)
@@ -200,25 +168,31 @@ class Listing:
        conds = []
        if names:
            for name in names:
-                conds.append(pathfunc.name(Column("path")).op("GLOB")(name))
+                conds.append(
+                    pathfunc.name(Column(dr.col_name("path"))).op("GLOB")(name)
+                )
        if inames:
            for iname in inames:
                conds.append(
-                    func.lower(pathfunc.name(Column("path"))).op("GLOB")(iname.lower())
+                    func.lower(pathfunc.name(Column(dr.col_name("path")))).op("GLOB")(
+                        iname.lower()
+                    )
                )
        if paths:
            for path in paths:
-                conds.append(Column("path").op("GLOB")(path))
+                conds.append(Column(dr.col_name("path")).op("GLOB")(path))
        if ipaths:
            for ipath in ipaths:
-                conds.append(func.lower(Column("path")).op("GLOB")(ipath.lower()))
+                conds.append(
+                    func.lower(Column(dr.col_name("path"))).op("GLOB")(ipath.lower())
+                )

        if size is not None:
            size_limit = suffix_to_number(size)
            if size_limit >= 0:
-                conds.append(Column("size") >= size_limit)
+                conds.append(Column(dr.col_name("size")) >= size_limit)
            else:
-                conds.append(Column("size") <= -size_limit)
+                conds.append(Column(dr.col_name("size")) <= -size_limit)

        return self.warehouse.find(
            dr,
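
Note on the listing.py changes above: the Listing class no longer carries storage/metastore state, and the find() filters address columns through the dataset's column-naming helper instead of bare names. The snippet below is a sketch only; it assumes dr.col_name() simply prefixes a plain column name with the listing's object_name (e.g. "path" becomes "file__path"), which this diff does not show directly.

    # Hypothetical helper mirroring the assumed dr.col_name() behaviour;
    # the real implementation lives in datachain/data_storage/schema.py.
    def col_name(name: str, object_name: str = "file") -> str:
        return f"{object_name}__{name}"

    # The "path"/"size" filters in Listing.find() would then target the
    # prefixed columns of the listing dataset:
    assert col_name("path") == "file__path"
    assert col_name("size") == "file__size"
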
datachain/query/__init__.py CHANGED
@@ -1,12 +1,11 @@
 from .dataset import DatasetQuery
 from .params import param
-from .schema import C, DatasetRow, LocalFilename, Object, Stream
+from .schema import C, LocalFilename, Object, Stream
 from .session import Session
 
 __all__ = [
     "C",
     "DatasetQuery",
-    "DatasetRow",
     "LocalFilename",
     "Object",
     "Session",
datachain/query/batch.py CHANGED
@@ -97,7 +97,6 @@ class Partition(BatchingStrategy):
 
         ordered_query = query.order_by(None).order_by(
             PARTITION_COLUMN_ID,
-            "sys__id",
             *query._order_by_clauses,
         )
 
datachain/query/dataset.py CHANGED
@@ -10,6 +10,7 @@ from abc import ABC, abstractmethod
 from collections.abc import Generator, Iterable, Iterator, Sequence
 from copy import copy
 from functools import wraps
+from secrets import token_hex
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -173,10 +174,10 @@ class QueryStep(StartingStep):
             return sqlalchemy.select(*columns)
 
         dataset = self.catalog.get_dataset(self.dataset_name)
-        table = self.catalog.warehouse.dataset_rows(dataset, self.dataset_version)
+        dr = self.catalog.warehouse.dataset_rows(dataset, self.dataset_version)
 
         return step_result(
-            q, table.c, dependencies=[(self.dataset_name, self.dataset_version)]
+            q, dr.columns, dependencies=[(self.dataset_name, self.dataset_version)]
         )
 
 
@@ -591,10 +592,6 @@ class UDFSignal(UDFStep):
             return query, []
         table = self.catalog.warehouse.create_pre_udf_table(query)
         q: Select = sqlalchemy.select(*table.c)
-        if query._order_by_clauses:
-            # we are adding ordering only if it's explicitly added by user in
-            # query part before adding signals
-            q = q.order_by(table.c.sys__id)
         return q, [table]
 
     def create_result_query(
@@ -630,11 +627,6 @@ class UDFSignal(UDFStep):
         else:
             res = sqlalchemy.select(*cols1).select_from(subq)
 
-        if query._order_by_clauses:
-            # if ordering is used in query part before adding signals, we
-            # will have it as order by id from select from pre-created udf table
-            res = res.order_by(subq.c.sys__id)
-
         if self.partition_by is not None:
             subquery = res.subquery()
             res = sqlalchemy.select(*subquery.c).select_from(subquery)
@@ -666,13 +658,6 @@ class RowGenerator(UDFStep):
     def create_result_query(
         self, udf_table, query: Select
     ) -> tuple[QueryGeneratorFunc, list["sqlalchemy.Column"]]:
-        if not query._order_by_clauses:
-            # if we are not selecting all rows in UDF, we need to ensure that
-            # we get the same rows as we got as inputs of UDF since selecting
-            # without ordering can be non deterministic in some databases
-            c = query.selected_columns
-            query = query.order_by(c.sys__id)
-
         udf_table_query = udf_table.select().subquery()
         udf_table_cols: list[sqlalchemy.Label[Any]] = [
             label(c.name, c) for c in udf_table_query.columns
@@ -736,10 +721,17 @@ class SQLMutate(SQLClause):
 
     def apply_sql_clause(self, query: Select) -> Select:
         original_subquery = query.subquery()
+        to_mutate = {c.name for c in self.args}
+
+        prefix = f"mutate{token_hex(8)}_"
+        cols = [
+            c.label(prefix + c.name) if c.name in to_mutate else c
+            for c in original_subquery.c
+        ]
         # this is needed for new column to be used in clauses
         # like ORDER BY, otherwise new column is not recognized
         subquery = (
-            sqlalchemy.select(*original_subquery.c, *self.args)
+            sqlalchemy.select(*cols, *self.args)
             .select_from(original_subquery)
             .subquery()
         )
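
Note on the SQLMutate hunk above: columns that a mutate() is about to overwrite are first re-labelled with a throwaway mutate<hex>_ prefix, so the new expression can reuse the original column name without a duplicate-column conflict in the subquery. A minimal standalone SQLAlchemy sketch of the same trick (illustrative only, not the datachain internals):

    import sqlalchemy as sa
    from secrets import token_hex

    rows = sa.table("rows", sa.column("size"))
    orig = sa.select(rows.c.size).subquery()

    # Hide the old "size" under a random label, then re-introduce the name
    # with the mutated expression.
    prefix = f"mutate{token_hex(8)}_"
    mutated = sa.select(
        orig.c.size.label(prefix + "size"),
        (orig.c.size * 2).label("size"),
    ).select_from(orig)
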
@@ -957,24 +949,24 @@ class SQLJoin(Step):
 
 
 @frozen
-class GroupBy(Step):
-    """Group rows by a specific column."""
-
-    cols: PartitionByType
+class SQLGroupBy(SQLClause):
+    cols: Sequence[Union[str, ColumnElement]]
+    group_by: Sequence[Union[str, ColumnElement]]
 
-    def clone(self) -> "Self":
-        return self.__class__(self.cols)
+    def apply_sql_clause(self, query) -> Select:
+        if not self.cols:
+            raise ValueError("No columns to select")
+        if not self.group_by:
+            raise ValueError("No columns to group by")
 
-    def apply(
-        self, query_generator: QueryGenerator, temp_tables: list[str]
-    ) -> StepResult:
-        query = query_generator.select()
-        grouped_query = query.group_by(*self.cols)
+        subquery = query.subquery()
 
-        def q(*columns):
-            return grouped_query.with_only_columns(*columns)
+        cols = [
+            subquery.c[str(c)] if isinstance(c, (str, C)) else c
+            for c in [*self.group_by, *self.cols]
+        ]
 
-        return step_result(q, grouped_query.selected_columns)
+        return sqlalchemy.select(*cols).select_from(subquery).group_by(*self.group_by)
 
 
 def _validate_columns(
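
Note on the SQLGroupBy hunk above: group_by is now applied as an ordinary SQL clause step; the current query is wrapped in a subquery, and the selection becomes the group-by keys plus the aggregate columns. A plain SQLAlchemy sketch of the generated shape (illustrative, not the exact datachain output):

    import sqlalchemy as sa

    rows = sa.table("rows", sa.column("file__path"), sa.column("file__size"))
    subq = sa.select(rows).subquery()

    # Select the grouping key together with an aggregate, then GROUP BY the key.
    q = (
        sa.select(
            subq.c["file__path"],
            sa.func.sum(subq.c["file__size"]).label("total_size"),
        )
        .select_from(subq)
        .group_by(subq.c["file__path"])
    )

The new datachain/lib/func module added in this release (see the file list above) appears to build on this step to expose aggregate functions at the DataChain level.
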
@@ -1130,25 +1122,14 @@ class DatasetQuery:
         query.steps = query.steps[-1:] + query.steps[:-1]
 
         result = query.starting_step.apply()
-        group_by = None
         self.dependencies.update(result.dependencies)
 
         for step in query.steps:
-            if isinstance(step, GroupBy):
-                if group_by is not None:
-                    raise TypeError("only one group_by allowed")
-                group_by = step
-                continue
-
             result = step.apply(
                 result.query_generator, self.temp_table_names
             )  # a chain of steps linked by results
             self.dependencies.update(result.dependencies)
 
-        if group_by:
-            result = group_by.apply(result.query_generator, self.temp_table_names)
-            self.dependencies.update(result.dependencies)
-
         return result.query_generator
 
     @staticmethod
@@ -1410,9 +1391,13 @@ class DatasetQuery:
         return query.as_scalar()
 
     @detach
-    def group_by(self, *cols: ColumnElement) -> "Self":
+    def group_by(
+        self,
+        cols: Sequence[ColumnElement],
+        group_by: Sequence[ColumnElement],
+    ) -> "Self":
         query = self.clone()
-        query.steps.append(GroupBy(cols))
+        query.steps.append(SQLGroupBy(cols, group_by))
         return query
 
     @detach
@@ -1591,6 +1576,8 @@ class DatasetQuery:
         )
         version = version or dataset.latest_version
 
+        self.session.add_dataset_version(dataset=dataset, version=version)
+
         dr = self.catalog.warehouse.dataset_rows(dataset)
 
         self.catalog.warehouse.copy_table(dr.get_table(), query.select())
datachain/query/schema.py CHANGED
@@ -1,16 +1,13 @@
 import functools
-import json
 from abc import ABC, abstractmethod
-from datetime import datetime, timezone
 from fnmatch import fnmatch
-from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import attrs
 import sqlalchemy as sa
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 
 from datachain.lib.file import File
-from datachain.sql.types import JSON, Boolean, DateTime, Int64, SQLType, String
 
 if TYPE_CHECKING:
     from datachain.catalog import Catalog
@@ -228,61 +225,4 @@ def normalize_param(param: UDFParamSpec) -> UDFParameter:
     raise TypeError(f"Invalid UDF parameter: {param}")
 
 
-class DatasetRow:
-    schema: ClassVar[dict[str, type[SQLType]]] = {
-        "source": String,
-        "path": String,
-        "size": Int64,
-        "location": JSON,
-        "is_latest": Boolean,
-        "last_modified": DateTime,
-        "version": String,
-        "etag": String,
-    }
-
-    @staticmethod
-    def create(
-        path: str,
-        source: str = "",
-        size: int = 0,
-        location: Optional[dict[str, Any]] = None,
-        is_latest: bool = True,
-        last_modified: Optional[datetime] = None,
-        version: str = "",
-        etag: str = "",
-    ) -> tuple[
-        str,
-        str,
-        int,
-        Optional[str],
-        int,
-        bool,
-        datetime,
-        str,
-        str,
-        int,
-    ]:
-        if location:
-            location = json.dumps([location])  # type: ignore [assignment]
-
-        last_modified = last_modified or datetime.now(timezone.utc)
-
-        return (  # type: ignore [return-value]
-            source,
-            path,
-            size,
-            location,
-            is_latest,
-            last_modified,
-            version,
-            etag,
-        )
-
-    @staticmethod
-    def extend(**columns):
-        cols = {**DatasetRow.schema}
-        cols.update(columns)
-        return cols
-
-
 C = Column
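
Note on the schema.py changes above: the tuple-building DatasetRow helper (with its hard-coded source/path/size/... schema) is removed from the public query schema module; file rows are modelled by datachain.lib.file.File, which this module already imports. A rough sketch, assuming File is a Pydantic-style model that accepts at least path, source and size (the full field set is not shown in this diff):

    from datachain.lib.file import File

    # Hypothetical construction, for illustration only.
    f = File(path="images/cat.jpg", source="s3://my-bucket", size=1024)
    print(f.path, f.size)
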
datachain/query/session.py CHANGED
@@ -1,9 +1,9 @@
 import atexit
+import gc
 import logging
-import os
 import re
 import sys
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, ClassVar, Optional
 from uuid import uuid4
 
 from datachain.catalog import get_catalog
@@ -11,6 +11,7 @@ from datachain.error import TableMissingError
 
 if TYPE_CHECKING:
     from datachain.catalog import Catalog
+    from datachain.dataset import DatasetRecord
 
 logger = logging.getLogger("datachain")
 
@@ -39,7 +40,7 @@ class Session:
     """
 
     GLOBAL_SESSION_CTX: Optional["Session"] = None
-    GLOBAL_SESSION: Optional["Session"] = None
+    SESSION_CONTEXTS: ClassVar[list["Session"]] = []
     ORIGINAL_EXCEPT_HOOK = None
 
     DATASET_PREFIX = "session_"
@@ -64,18 +65,21 @@ class Session:
 
         session_uuid = uuid4().hex[: self.SESSION_UUID_LEN]
         self.name = f"{name}_{session_uuid}"
-        self.job_id = os.getenv("DATACHAIN_JOB_ID") or str(uuid4())
         self.is_new_catalog = not catalog
         self.catalog = catalog or get_catalog(
             client_config=client_config, in_memory=in_memory
         )
+        self.dataset_versions: list[tuple[DatasetRecord, int]] = []
 
     def __enter__(self):
+        # Push the current context onto the stack
+        Session.SESSION_CONTEXTS.append(self)
+
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         if exc_type:
-            self._cleanup_created_versions(self.name)
+            self._cleanup_created_versions()
 
         self._cleanup_temp_datasets()
         if self.is_new_catalog:
@@ -83,6 +87,12 @@ class Session:
             self.catalog.warehouse.close_on_exit()
             self.catalog.id_generator.close_on_exit()
 
+        if Session.SESSION_CONTEXTS:
+            Session.SESSION_CONTEXTS.pop()
+
+    def add_dataset_version(self, dataset: "DatasetRecord", version: int) -> None:
+        self.dataset_versions.append((dataset, version))
+
     def generate_temp_dataset_name(self) -> str:
         return self.get_temp_prefix() + uuid4().hex[: self.TEMP_TABLE_UUID_LEN]
 
@@ -98,21 +108,15 @@ class Session:
         except TableMissingError:
             pass
 
-    def _cleanup_created_versions(self, job_id: str) -> None:
-        versions = self.catalog.metastore.get_job_dataset_versions(job_id)
-        if not versions:
+    def _cleanup_created_versions(self) -> None:
+        if not self.dataset_versions:
             return
 
-        datasets = {}
-        for dataset_name, version in versions:
-            if dataset_name not in datasets:
-                datasets[dataset_name] = self.catalog.get_dataset(dataset_name)
-            dataset = datasets[dataset_name]
-            logger.info(
-                "Removing dataset version %s@%s due to exception", dataset_name, version
-            )
+        for dataset, version in self.dataset_versions:
             self.catalog.remove_dataset_version(dataset, version)
 
+        self.dataset_versions.clear()
+
     @classmethod
     def get(
         cls,
@@ -125,33 +129,34 @@ class Session:
 
         Parameters:
             session (Session): Optional Session(). If not provided a new session will
-                be created. It's needed mostly for simplie API purposes.
-            catalog (Catalog): Optional catalog. By default a new catalog is created.
+                be created. It's needed mostly for simple API purposes.
+            catalog (Catalog): Optional catalog. By default, a new catalog is created.
         """
         if session:
             return session
 
-        if cls.GLOBAL_SESSION is None:
+        # Access the active (most recent) context from the stack
+        if cls.SESSION_CONTEXTS:
+            return cls.SESSION_CONTEXTS[-1]
+
+        if cls.GLOBAL_SESSION_CTX is None:
             cls.GLOBAL_SESSION_CTX = Session(
                 cls.GLOBAL_SESSION_NAME,
                 catalog,
                 client_config=client_config,
                 in_memory=in_memory,
             )
-            cls.GLOBAL_SESSION = cls.GLOBAL_SESSION_CTX.__enter__()
 
             atexit.register(cls._global_cleanup)
             cls.ORIGINAL_EXCEPT_HOOK = sys.excepthook
             sys.excepthook = cls.except_hook
 
-        return cls.GLOBAL_SESSION
+        return cls.GLOBAL_SESSION_CTX
 
     @staticmethod
     def except_hook(exc_type, exc_value, exc_traceback):
+        Session.GLOBAL_SESSION_CTX.__exit__(exc_type, exc_value, exc_traceback)
         Session._global_cleanup()
-        if Session.GLOBAL_SESSION_CTX is not None:
-            job_id = Session.GLOBAL_SESSION_CTX.job_id
-            Session.GLOBAL_SESSION_CTX._cleanup_created_versions(job_id)
 
         if Session.ORIGINAL_EXCEPT_HOOK:
             Session.ORIGINAL_EXCEPT_HOOK(exc_type, exc_value, exc_traceback)
@@ -160,7 +165,6 @@ class Session:
     def cleanup_for_tests(cls):
         if cls.GLOBAL_SESSION_CTX is not None:
             cls.GLOBAL_SESSION_CTX.__exit__(None, None, None)
-            cls.GLOBAL_SESSION = None
             cls.GLOBAL_SESSION_CTX = None
             atexit.unregister(cls._global_cleanup)
 
@@ -171,3 +175,7 @@ class Session:
     def _global_cleanup():
         if Session.GLOBAL_SESSION_CTX is not None:
             Session.GLOBAL_SESSION_CTX.__exit__(None, None, None)
+
+        for obj in gc.get_objects():  # Get all tracked objects
+            if isinstance(obj, Session):  # Cleanup temp dataset for session variables.
+                obj.__exit__(None, None, None)
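
Note on the session.py changes above: instead of a single GLOBAL_SESSION, sessions now form a stack of active contexts. Session.get() returns the innermost `with Session(...)` block and only falls back to the lazily created global session when no context is active, and dataset versions created inside a failed block are rolled back from the session's own dataset_versions list rather than looked up by job id. A usage sketch of the stack behaviour (constructor arguments follow the __init__ shown in this diff; in_memory=True is assumed to give a throwaway catalog):

    from datachain.query.session import Session

    with Session("outer", in_memory=True) as outer:
        with Session("inner", in_memory=True) as inner:
            assert Session.get() is inner  # most recent context wins
        assert Session.get() is outer      # popped back to the outer context
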