sibi-dst 0.3.56__py3-none-any.whl → 0.3.57__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sibi_dst/df_helper/_artifact_updater_multi_wrapper.py +165 -166
- sibi_dst/df_helper/_df_helper.py +55 -23
- sibi_dst/df_helper/_parquet_artifact.py +29 -11
- sibi_dst/df_helper/backends/sqlalchemy/_db_connection.py +182 -89
- sibi_dst/df_helper/backends/sqlalchemy/_load_from_db.py +6 -2
- sibi_dst/utils/__init__.py +2 -0
- sibi_dst/utils/data_wrapper.py +34 -93
- sibi_dst/utils/parquet_saver.py +15 -12
- sibi_dst/utils/update_planner.py +237 -0
- {sibi_dst-0.3.56.dist-info → sibi_dst-0.3.57.dist-info}/METADATA +1 -1
- {sibi_dst-0.3.56.dist-info → sibi_dst-0.3.57.dist-info}/RECORD +12 -11
- {sibi_dst-0.3.56.dist-info → sibi_dst-0.3.57.dist-info}/WHEEL +0 -0
sibi_dst/df_helper/_artifact_updater_multi_wrapper.py
CHANGED
@@ -25,238 +25,237 @@ class ArtifactUpdaterMultiWrapper:
     def __init__(self, wrapped_classes=None, debug=False, **kwargs):
         self.wrapped_classes = wrapped_classes or {}
         self.debug = debug
-        self.logger = kwargs.setdefault(
+        self.logger = kwargs.setdefault(
+            'logger', Logger.default_logger(logger_name=self.__class__.__name__)
+        )
         self.logger.set_level(logging.DEBUG if debug else logging.INFO)

         today = datetime.datetime.today()
-        self.
-        …
+        self.parquet_start_date = kwargs.get(
+            'parquet_start_date',
+            datetime.date(today.year, 1, 1).strftime('%Y-%m-%d')
+        )
+        self.parquet_end_date = kwargs.get(
+            'parquet_end_date',
+            today.strftime('%Y-%m-%d')
+        )
+
+        # track pending/completed/failed artifacts
+        self.pending = set()
+        self.completed = set()
+        self.failed = set()
+
+        # concurrency primitives
         self.locks = {}
+        self.locks_lock = asyncio.Lock()
         self.worker_heartbeat = defaultdict(float)
-
-        # graceful shutdown handling
-        loop = asyncio.get_event_loop()
-        self.register_signal_handlers(loop)
+        self.workers_lock = asyncio.Lock()

         # dynamic scaling config
         self.min_workers = kwargs.get('min_workers', 1)
-        self.max_workers = kwargs.get('max_workers',
-        self.memory_per_worker_gb = kwargs.get('memory_per_worker_gb', 1)
-        self.monitor_interval = kwargs.get('monitor_interval', 10)
+        self.max_workers = kwargs.get('max_workers', 3)
+        self.memory_per_worker_gb = kwargs.get('memory_per_worker_gb', 1)
+        self.monitor_interval = kwargs.get('monitor_interval', 10)
         self.retry_attempts = kwargs.get('retry_attempts', 3)
         self.update_timeout_seconds = kwargs.get('update_timeout_seconds', 600)
         self.lock_acquire_timeout_seconds = kwargs.get('lock_acquire_timeout_seconds', 10)

-    def
-        …
-        tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
-        [task.cancel() for task in tasks]
-        await asyncio.gather(*tasks, return_exceptions=True)
-        self.logger.info("Shutdown complete.")
-
-    def get_lock_for_artifact(self, artifact):
-        artifact_key = artifact.__class__.__name__
-        if artifact_key not in self.locks:
-            self.locks[artifact_key] = asyncio.Lock()
-        return self.locks[artifact_key]
+    async def get_lock_for_artifact(self, artifact):
+        key = artifact.__class__.__name__
+        async with self.locks_lock:
+            if key not in self.locks:
+                self.locks[key] = asyncio.Lock()
+            return self.locks[key]

     def get_artifacts(self, data_type):
         if data_type not in self.wrapped_classes:
             raise ValueError(f"Unsupported data type: {data_type}")
-        …
+        artifacts = [cls(
+            parquet_start_date=self.parquet_start_date,
+            parquet_end_date=self.parquet_end_date,
+            logger=self.logger,
+            debug=self.debug
+        ) for cls in self.wrapped_classes[data_type]]
+        # seed pending set and clear others
+        self.pending = set(artifacts)
+        self.completed.clear()
+        self.failed.clear()
+        return artifacts

     def estimate_complexity(self, artifact):
         try:
-            …
-            self.logger.warning(f"Failed to estimate complexity for {artifact}: {e}")
-            return 1  # default
+            return artifact.get_size_estimate()
+        except Exception:
+            return 1

     def prioritize_tasks(self, artifacts):
         queue = asyncio.PriorityQueue()
-        for
-        # we invert the complexity to ensure higher complexity -> higher priority
-        # if you want high complexity first, store negative complexity in the priority queue
-        # or if the smaller number means earlier processing, just keep as is
-            queue.put_nowait(PrioritizedItem(complexity, artifact))
+        for art in artifacts:
+            queue.put_nowait(PrioritizedItem(self.estimate_complexity(art), art))
         return queue

     async def resource_monitor(self, queue, workers):
-        while True:
-            # break if queue done
-            if queue.empty():
-                await asyncio.sleep(0.5)
-                if queue.empty():
-                    break
-
+        while not queue.empty():
             try:
-                …
-                    for _ in range(diff):
-                        worker_id = len(workers)
-                        # create a new worker
-                        w = asyncio.create_task(self.worker(queue, worker_id))
-                        workers.append(w)
-                        self.logger.info(f"Added worker {worker_id}. Total workers: {len(workers)}")
-                elif optimal_workers < current_worker_count:
-                    # remove some workers
-                    diff = current_worker_count - optimal_workers
-                    for _ in range(diff):
-                        w = workers.pop()
-                        w.cancel()
-                        self.logger.info(f"Removed a worker. Total workers: {len(workers)}")
-
+                avail = psutil.virtual_memory().available
+                max_by_mem = avail // (self.memory_per_worker_gb * 2**30)
+                optimal = max(self.min_workers,
+                              min(psutil.cpu_count(), max_by_mem, self.max_workers))
+                async with self.workers_lock:
+                    current = len(workers)
+                    if optimal > current:
+                        for _ in range(optimal - current):
+                            wid = len(workers)
+                            workers.append(asyncio.create_task(self.worker(queue, wid)))
+                            self.logger.info(f"Added worker {wid}")
+                    elif optimal < current:
+                        for _ in range(current - optimal):
+                            w = workers.pop()
+                            w.cancel()
+                            self.logger.info("Removed a worker")
                 await asyncio.sleep(self.monitor_interval)
-
             except asyncio.CancelledError:
-                # monitor is being shut down
                 break
             except Exception as e:
-                self.logger.error(f"
+                self.logger.error(f"Monitor error: {e}")
                 await asyncio.sleep(self.monitor_interval)

     @asynccontextmanager
     async def artifact_lock(self, artifact):
-        lock = self.get_lock_for_artifact(artifact)
+        lock = await self.get_lock_for_artifact(artifact)
         try:
             await asyncio.wait_for(lock.acquire(), timeout=self.lock_acquire_timeout_seconds)
             yield
-        except asyncio.TimeoutError:
-            self.logger.error(f"Timeout acquiring lock for artifact: {artifact.__class__.__name__}")
-            yield  # continue but no actual lock was acquired
         finally:
             if lock.locked():
                 lock.release()

     async def async_update_artifact(self, artifact, **kwargs):
-        for attempt in range(self.retry_attempts):
+        for attempt in range(1, self.retry_attempts + 1):
+            lock = await self.get_lock_for_artifact(artifact)
             try:
-                …
-                start_time = time.time()
+                await asyncio.wait_for(lock.acquire(), timeout=self.lock_acquire_timeout_seconds)
+                try:
+                    self.logger.info(f"Updating {artifact.__class__.__name__} (attempt {attempt})")
                     await asyncio.wait_for(
                         asyncio.to_thread(artifact.update_parquet, **kwargs),
                         timeout=self.update_timeout_seconds
                     )
-                …
+                    # mark success
+                    async with self.workers_lock:
+                        self.pending.discard(artifact)
+                        self.completed.add(artifact)
                     self.logger.info(
-                        f"
+                        f"✅ {artifact.__class__.__name__} done — "
+                        f"{len(self.completed)}/{len(self.completed) + len(self.pending) + len(self.failed)} completed, "
+                        f"{len(self.failed)} failed"
+                    )
                     return
-                …
+                finally:
+                    if lock.locked():
+                        lock.release()
             except asyncio.TimeoutError:
-                self.logger.
+                self.logger.warning(f"Timeout on {artifact.__class__.__name__}, attempt {attempt}")
             except Exception as e:
-                self.logger.error(
-                …
-                await asyncio.sleep(2 ** attempt)
+                self.logger.error(f"Error on {artifact}: {e}")
+            finally:
+                if lock.locked():
+                    lock.release()
+                await asyncio.sleep(2 ** (attempt - 1))

-        …
+        # all retries exhausted -> mark failure
+        async with self.workers_lock:
+            self.pending.discard(artifact)
+            self.failed.add(artifact)
+        self.logger.error(f"✖️ Permanently failed {artifact.__class__.__name__}")

-    async def worker(self, queue, worker_id
-        """A worker that dynamically pulls tasks from the queue."""
+    async def worker(self, queue, worker_id):
         while True:
             try:
-                …
-                    break
-                artifact = prioritized_item.artifact
-                # heartbeat
+                item = await queue.get()
+                art = item.artifact
                 self.worker_heartbeat[worker_id] = time.time()
-                await self.async_update_artifact(artifact, **kwargs)
+                await self.async_update_artifact(art)
             except asyncio.CancelledError:
-                self.logger.info(f"Worker {worker_id}
+                self.logger.info(f"Worker {worker_id} stopped")
                 break
-            except Exception as e:
-                self.logger.error(f"Error in worker {worker_id}: {e}")
             finally:
                 queue.task_done()

-    …
-        w = asyncio.create_task(self.worker(queue, worker_id, **kwargs))
-        workers.append(w)
-
-        # start resource monitor
-        monitor_task = asyncio.create_task(self.resource_monitor(queue, workers))
-
-        # wait until queue is done
-        try:
-            await queue.join()
-        finally:
-            # cancel resource monitor
-            monitor_task.cancel()
-            # all workers done
-            for w in workers:
-                w.cancel()
-            await asyncio.gather(*workers, return_exceptions=True)
+    def calculate_initial_workers(self, count: int) -> int:
+        avail = psutil.virtual_memory().available
+        max_by_mem = avail // (self.memory_per_worker_gb * 2**30)
+        return max(self.min_workers,
+                   min(psutil.cpu_count(), max_by_mem, count, self.max_workers))

     async def update_data(self, data_type, **kwargs):
-        self.logger.info(f"
+        self.logger.info(f"Starting update for {data_type}")
         artifacts = self.get_artifacts(data_type)
         queue = self.prioritize_tasks(artifacts)
+        init = self.calculate_initial_workers(len(artifacts))
+        tasks = [asyncio.create_task(self.worker(queue, i)) for i in range(init)]
+        monitor = asyncio.create_task(self.resource_monitor(queue, tasks))
+        await queue.join()
+        monitor.cancel()
+        for t in tasks:
+            t.cancel()
+        await asyncio.gather(*tasks, return_exceptions=True)
+        self.logger.info(self.format_results_table())
+        self.logger.info("All artifacts processed.")
+
+    def format_results_table(self):
+        results = self.get_update_status()
+        headers = ["Metric", "Value"]
+        rows = [
+            ["Total", results['total']],
+            ["Completed", results['completed']],
+            ["Pending", results['pending']],
+            ["Failed", results['failed']],
+            ["Pending Items", len(results['pending_items'])],
+            ["Failed Items", len(results['failed_items'])]
+        ]
+
+        # Find max lengths for alignment
+        max_metric = max(len(str(row[0])) for row in rows)
+        max_value = max(len(str(row[1])) for row in rows)

-        …
-        initial_workers = self.calculate_initial_workers(len(artifacts))
-        self.logger.info(f"Initial worker count: {initial_workers} for {len(artifacts)} artifacts")
+        format_str = "{:<%d} {:>%d}" % (max_metric, max_value)

-        …
+        table = [
+            "\n",
+            format_str.format(*headers),
+            "-" * (max_metric + max_value + 2)
+        ]

-        …
-        self.
-        …
+        for row in rows:
+            table.append(format_str.format(row[0], row[1]))
+
+        return "\n".join(table)
+
+    def get_update_status(self):
+        total = len(self.pending) + len(self.completed) + len(self.failed)
+        return {
+            "total": total,
+            "completed": len(self.completed),
+            "pending": len(self.pending),
+            "failed": len(self.failed),
+            "pending_items": [a.__class__.__name__ for a in self.pending],
+            "failed_items": [a.__class__.__name__ for a in self.failed]
+        }
+
+# Top-level driver
+# environment = None # fill this in with your wrapped_classes dict
+#
+# async def main():
+#     wrapper = ArtifactUpdaterMultiWrapper(
+#         wrapped_classes=environment,
+#         debug=True
+#     )
+#     loop = asyncio.get_running_loop()
+#     for sig in (signal.SIGINT, signal.SIGTERM):
+#         loop.add_signal_handler(sig, lambda: asyncio.create_task(wrapper.shutdown()))
+#     await wrapper.update_data("your_data_type")
+#
+# if __name__ == "__main__":
+#     asyncio.run(main())
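The rewritten `prioritize_tasks` pushes each artifact onto the `asyncio.PriorityQueue` wrapped in a `PrioritizedItem`, whose definition sits outside this hunk. A minimal sketch of what such a wrapper could look like, assuming the priority is the integer size estimate and that artifact instances themselves should never be compared:

```python
import asyncio
from dataclasses import dataclass, field
from typing import Any


# Hypothetical sketch of the PrioritizedItem wrapper referenced by prioritize_tasks().
# order=True generates comparison methods from `priority` only; `artifact` is excluded
# with compare=False so the queue never tries to order artifact instances directly.
@dataclass(order=True)
class PrioritizedItem:
    priority: int
    artifact: Any = field(compare=False)


async def demo() -> None:
    queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
    queue.put_nowait(PrioritizedItem(5, "heavy_artifact"))
    queue.put_nowait(PrioritizedItem(1, "light_artifact"))
    # Lower priority values come out first: light_artifact, then heavy_artifact.
    while not queue.empty():
        item = await queue.get()
        print(item.priority, item.artifact)
        queue.task_done()


if __name__ == "__main__":
    asyncio.run(demo())
```

Because `artifact` is excluded from comparison, two items with equal priority compare as equal rather than raising `TypeError` when the heap would otherwise fall back to comparing payloads.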
sibi_dst/df_helper/_df_helper.py
CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import datetime
 import logging
@@ -6,10 +8,10 @@ from typing import Any, Dict, TypeVar
 from typing import Union, Optional

 import dask.dataframe as dd
-…
+import fsspec
 import pandas as pd
+from dask import delayed, compute
 from pydantic import BaseModel
-import fsspec

 from sibi_dst.df_helper.core import QueryConfig, ParamsConfig, FilterHandler
 from sibi_dst.utils import Logger
@@ -45,7 +47,7 @@ class DfHelper:
     :ivar df: The DataFrame currently being processed or loaded.
     :type df: Union[dd.DataFrame, pd.DataFrame]
     :ivar backend_django: Configuration for interacting with Django database backends.
-    :type
+    :type backend_connection: Optional[DjangoConnectionConfig]
     :ivar _backend_query: Internal configuration for query handling.
     :type _backend_query: Optional[QueryConfig]
     :ivar _backend_params: Internal parameters configuration for DataFrame handling.
@@ -54,8 +56,6 @@ class DfHelper:
     :type backend_parquet: Optional[ParquetConfig]
     :ivar backend_http: Configuration for interacting with HTTP-based backends.
     :type backend_http: Optional[HttpConfig]
-    :ivar backend_sqlalchemy: Configuration for interacting with SQLAlchemy-based databases.
-    :type backend_sqlalchemy: Optional[SqlAlchemyConnectionConfig]
     :ivar parquet_filename: The filename for a Parquet file, if applicable.
     :type parquet_filename: str
    :ivar logger: Logger instance used for debugging and information logging.
@@ -64,12 +64,11 @@ class DfHelper:
     :type default_config: Dict
     """
     df: Union[dd.DataFrame, pd.DataFrame] = None
-    …
+    backend_db_connection: Optional[Union[DjangoConnectionConfig | SqlAlchemyConnectionConfig]] = None
     _backend_query: Optional[QueryConfig] = None
     _backend_params: Optional[ParamsConfig] = None
     backend_parquet: Optional[ParquetConfig] = None
     backend_http: Optional[HttpConfig] = None
-    backend_sqlalchemy: Optional[SqlAlchemyConnectionConfig] = None
     parquet_filename: str = None
     logger: Logger
     default_config: Dict = None
@@ -91,7 +90,7 @@ class DfHelper:
         self.filesystem_options = kwargs.pop('filesystem_options', {})
         kwargs.setdefault("live", True)
         kwargs.setdefault("logger", self.logger)
-        self.fs =kwargs.setdefault("fs", fsspec.filesystem('file'))
+        self.fs = kwargs.setdefault("fs", fsspec.filesystem('file'))
         self.__post_init(**kwargs)

     def __str__(self):
@@ -100,6 +99,34 @@ class DfHelper:
     def __call__(self, **options):
         return self.load(**options)

+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.__cleanup()
+        return False
+
+    def __cleanup(self):
+        """
+        Clean up resources when exiting the context manager.
+        This method is called when the context manager exits.
+        """
+
+        if self.backend_db_connection:
+            if getattr(self.backend_db_connection, "dispose_idle_connections", None):
+                self.backend_db_connection.dispose_idle_connections()
+            if getattr(self.backend_db_connection, "close", None):
+                self.backend_db_connection.close()
+
+            self.backend_db_connection = None
+
+        if self.backend_parquet:
+            self.backend_parquet = None
+        if self.backend_http:
+            self.backend_http = None
+        self._backend_query = None
+        self._backend_params = None
+
     def __post_init(self, **kwargs):
         """
         Initializes backend-specific configurations based on the provided backend type and other
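With `__enter__` and `__exit__` added, a `DfHelper` can be driven as a context manager so that the database connection is disposed of and closed when the block exits. A usage sketch only; the import path and the constructor keywords shown below are assumptions, since the accepted arguments actually come from the backend config classes:

```python
# Sketch, not a verbatim API: adjust the import and connection kwargs to your setup.
from sibi_dst.df_helper import DfHelper  # assumed re-export; the class lives in _df_helper.py

with DfHelper(backend="sqlalchemy", debug=True) as helper:   # kwargs are illustrative
    df = helper.load()   # the result is also kept on helper.df
    print(len(df))

# Leaving the block triggers __exit__ -> __cleanup(), which calls
# dispose_idle_connections()/close() on backend_db_connection when present
# and drops the backend config objects.
```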
@@ -116,15 +143,14 @@ class DfHelper:
         self._backend_query = self.__get_config(QueryConfig, kwargs)
         self._backend_params = self.__get_config(ParamsConfig, kwargs)
         if self.backend == 'django_db':
-            self.
+            self.backend_db_connection = self.__get_config(DjangoConnectionConfig, kwargs)
         elif self.backend == 'parquet':
             self.parquet_filename = kwargs.setdefault("parquet_filename", None)
             self.backend_parquet = ParquetConfig(**kwargs)
         elif self.backend == 'http':
             self.backend_http = HttpConfig(**kwargs)
         elif self.backend == 'sqlalchemy':
-            self.
-            …
+            self.backend_db_connection = self.__get_config(SqlAlchemyConnectionConfig, kwargs)

     def __get_config(self, model: [T], kwargs: Dict[str, Any]) -> Union[T]:
         """
@@ -134,6 +160,8 @@ class DfHelper:
         :param kwargs: The dictionary of keyword arguments.
         :return: The initialized Pydantic model instance.
         """
+        kwargs.setdefault("debug", self.debug)
+        kwargs.setdefault("logger", self.logger)
         # Extract keys that the model can accept
         recognized_keys = set(model.__annotations__.keys())
         self.logger.debug(f"recognized keys: {recognized_keys}")
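`__get_config` now seeds `debug` and `logger` defaults into the kwargs before filtering them down to the keys the target Pydantic model annotates. A self-contained sketch of that filtering pattern with a made-up model (the real call sites pass QueryConfig, ParamsConfig, and the connection configs):

```python
from typing import Any, Dict, Optional

from pydantic import BaseModel


class ExampleConfig(BaseModel):
    # Stand-in model; invented fields for illustration only.
    host: str = "localhost"
    port: int = 5432
    debug: bool = False
    logger: Optional[Any] = None


def build_config(model, kwargs: Dict[str, Any]):
    # Same idea as DfHelper.__get_config: keep only the keys the model annotates,
    # so unrelated kwargs (parquet paths, HTTP options, and so on) never reach it.
    recognized = set(model.__annotations__.keys())
    accepted = {k: v for k, v in kwargs.items() if k in recognized}
    return model(**accepted)


cfg = build_config(ExampleConfig, {"host": "db.internal", "port": 5433, "unrelated": 1})
print(cfg)  # host='db.internal' port=5433 debug=False logger=None
```

Filtering on `__annotations__` keeps one shared kwargs dict usable across several config models without each model having to tolerate the others' options.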
@@ -171,10 +199,10 @@ class DfHelper:
         `as_pandas` is set to True, or kept in its native backend format otherwise.
         """
         # this will be the universal method to load data from a df irrespective of the backend
-        df = self.__load(**options)
+        self.df = self.__load(**options)
         if self.as_pandas:
-            return df.compute()
-        return df
+            return self.df.compute()
+        return self.df

     def __load(self, **options):
         """
@@ -196,7 +224,7 @@ class DfHelper:
         """
         if self.backend == 'django_db':
             self._backend_params.parse_params(options)
-            return self.
+            return self.__load_from_django_db(**options)
         elif self.backend == 'sqlalchemy':
             self._backend_params.parse_params(options)
             return self.__load_from_sqlalchemy(**options)
@@ -227,7 +255,7 @@ class DfHelper:
         try:
             options.setdefault("debug", self.debug)
             db_loader = SqlAlchemyLoadFromDb(
-                self.
+                self.backend_db_connection,
                 self._backend_query,
                 self._backend_params,
                 self.logger,
@@ -236,6 +264,7 @@ class DfHelper:
             self.df = db_loader.build_and_load()
             self.__process_loaded_data()
             self.__post_process_df()
+            self.backend_db_connection.close()
             self.logger.debug("Data successfully loaded from sqlalchemy database.")
         except Exception as e:
             self.logger.debug(f"Failed to load data from sqlalchemy database: {e}: options: {options}")
@@ -243,7 +272,7 @@ class DfHelper:

         return self.df

-    def
+    def __load_from_django_db(self, **options) -> Union[pd.DataFrame, dd.DataFrame]:
         """
         Loads data from a Django database using a specific backend query mechanism. Processes the loaded data
         and applies further post-processing before returning the dataframe. If the operation fails, an
@@ -258,7 +287,7 @@ class DfHelper:
         try:
             options.setdefault("debug", self.debug)
             db_loader = DjangoLoadFromDb(
-                self.
+                self.backend_db_connection,
                 self._backend_query,
                 self._backend_params,
                 self.logger,
@@ -307,6 +336,7 @@ class DfHelper:
         :raises ValueError: If the lengths of `fieldnames` and `column_names` do not match,
             or if the specified `index_col` is not found in the DataFrame.
         """
+        self.logger.debug("Post-processing DataFrame.")
         df_params = self._backend_params.df_params
         fieldnames = df_params.get("fieldnames", None)
         index_col = df_params.get("index_col", None)
@@ -357,16 +387,16 @@ class DfHelper:

         :return: None
         """
-        self.logger.debug(f"
+        self.logger.debug(f"Processing loaded data...")
         if self.df.map_partitions(len).compute().sum() > 0:
             field_map = self._backend_params.field_map or {}
-            if isinstance(field_map, dict):
+            if isinstance(field_map, dict) and field_map != {}:
                 rename_mapping = {k: v for k, v in field_map.items() if k in self.df.columns}
                 missing_columns = [k for k in field_map.keys() if k not in self.df.columns]

                 if missing_columns:
                     self.logger.warning(
-                        f"The following columns in field_map are not in the DataFrame: {missing_columns}")
+                        f"The following columns in field_map are not in the DataFrame: {missing_columns}, field map: {field_map}")

                 def rename_columns(df, mapping):
                     return df.rename(columns=mapping)
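The `field_map` branch keeps the existing pattern of renaming columns partition by partition with `map_partitions`. A standalone sketch of that pattern on a toy Dask frame; the column names and the mapping are invented:

```python
import pandas as pd
import dask.dataframe as dd

# Toy frame standing in for self.df; field_map keys/values are invented.
pdf = pd.DataFrame({"cli_id": [1, 2, 3], "cli_name": ["a", "b", "c"]})
ddf = dd.from_pandas(pdf, npartitions=2)

field_map = {"cli_id": "client_id", "cli_name": "client_name", "missing_col": "other"}

# Only rename columns that actually exist, and report the ones that do not,
# mirroring the rename_mapping / missing_columns logic in __process_loaded_data.
rename_mapping = {k: v for k, v in field_map.items() if k in ddf.columns}
missing_columns = [k for k in field_map if k not in ddf.columns]
if missing_columns:
    print(f"Not in the DataFrame: {missing_columns}")


def rename_columns(df: pd.DataFrame, mapping: dict) -> pd.DataFrame:
    return df.rename(columns=mapping)


ddf = ddf.map_partitions(rename_columns, mapping=rename_mapping)
print(ddf.columns.tolist())  # ['client_id', 'client_name']
```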
@@ -376,6 +406,8 @@ class DfHelper:
                 self.df = self.df.map_partitions(rename_columns, mapping=rename_mapping)

             self.logger.debug("Processing of loaded data completed.")
+        else:
+            self.logger.debug("DataFrame is empty, skipping processing.")

     def save_to_parquet(self, parquet_filename: Optional[str] = None, **kwargs):
         """
@@ -536,14 +568,14 @@ class DfHelper:

         # Common logic for Django and SQLAlchemy
         if self.backend == 'django_db':
-            model_fields = {field.name: field for field in self.
+            model_fields = {field.name: field for field in self.backend_db_connection.model._meta.get_fields()}
             if mapped_field not in model_fields:
                 raise ValueError(f"Field '{dt_field}' does not exist in the Django model.")
             field_type = type(model_fields[mapped_field]).__name__
             is_date_field = field_type == 'DateField'
             is_datetime_field = field_type == 'DateTimeField'
         elif self.backend == 'sqlalchemy':
-            model = self.
+            model = self.backend_db_connection.model
             fields = [column.name for column in model.__table__.columns]
             if mapped_field not in fields:
                 raise ValueError(f"Field '{dt_field}' does not exist in the SQLAlchemy model.")
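In the SQLAlchemy branch the model now comes from `backend_db_connection.model` and its columns are read from `model.__table__.columns`, which is standard declarative-model introspection. A self-contained sketch with a throwaway model (table and column names invented); the Date versus DateTime check at the end is one way to mirror what the Django branch does by field type name:

```python
from sqlalchemy import Column, Date, DateTime, Integer
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class Order(Base):
    # Throwaway model; DfHelper reads the real one from backend_db_connection.model.
    __tablename__ = "orders"
    id = Column(Integer, primary_key=True)
    order_date = Column(Date)
    created_at = Column(DateTime)


# Same membership check as the diff: collect column names and verify the field exists.
fields = [column.name for column in Order.__table__.columns]
print(fields)  # ['id', 'order_date', 'created_at']
assert "order_date" in fields

# Distinguishing date from datetime columns via the column type objects.
col = Order.__table__.columns["order_date"]
print(isinstance(col.type, Date), isinstance(col.type, DateTime))  # True False
```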