sibi-dst 0.3.56__py3-none-any.whl → 0.3.58__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
@@ -25,238 +25,237 @@ class ArtifactUpdaterMultiWrapper:
     def __init__(self, wrapped_classes=None, debug=False, **kwargs):
         self.wrapped_classes = wrapped_classes or {}
         self.debug = debug
-        self.logger = kwargs.setdefault('logger', Logger.default_logger(logger_name=self.__class__.__name__))
+        self.logger = kwargs.setdefault(
+            'logger', Logger.default_logger(logger_name=self.__class__.__name__)
+        )
         self.logger.set_level(logging.DEBUG if debug else logging.INFO)

         today = datetime.datetime.today()
-        self.today_str = today.strftime('%Y-%m-%d')
-        self.current_year_starts_on_str = datetime.date(today.year, 1, 1).strftime('%Y-%m-%d')
-        self.parquet_start_date = kwargs.get('parquet_start_date', self.current_year_starts_on_str)
-        self.parquet_end_date = kwargs.get('parquet_end_date', self.today_str)
-
-        # track concurrency and locks
+        self.parquet_start_date = kwargs.get(
+            'parquet_start_date',
+            datetime.date(today.year, 1, 1).strftime('%Y-%m-%d')
+        )
+        self.parquet_end_date = kwargs.get(
+            'parquet_end_date',
+            today.strftime('%Y-%m-%d')
+        )
+
+        # track pending/completed/failed artifacts
+        self.pending = set()
+        self.completed = set()
+        self.failed = set()
+
+        # concurrency primitives
         self.locks = {}
+        self.locks_lock = asyncio.Lock()
         self.worker_heartbeat = defaultdict(float)
-
-        # graceful shutdown handling
-        loop = asyncio.get_event_loop()
-        self.register_signal_handlers(loop)
+        self.workers_lock = asyncio.Lock()

         # dynamic scaling config
         self.min_workers = kwargs.get('min_workers', 1)
-        self.max_workers = kwargs.get('max_workers', 8)
-        self.memory_per_worker_gb = kwargs.get('memory_per_worker_gb', 1)  # default 1GB per worker
-        self.monitor_interval = kwargs.get('monitor_interval', 10)  # default monitor interval in seconds
+        self.max_workers = kwargs.get('max_workers', 3)
+        self.memory_per_worker_gb = kwargs.get('memory_per_worker_gb', 1)
+        self.monitor_interval = kwargs.get('monitor_interval', 10)
         self.retry_attempts = kwargs.get('retry_attempts', 3)
         self.update_timeout_seconds = kwargs.get('update_timeout_seconds', 600)
         self.lock_acquire_timeout_seconds = kwargs.get('lock_acquire_timeout_seconds', 10)

-    def register_signal_handlers(self, loop):
-        for sig in (signal.SIGINT, signal.SIGTERM):
-            loop.add_signal_handler(sig, lambda: asyncio.create_task(self.shutdown()))
-
-    async def shutdown(self):
-        self.logger.info("Shutdown signal received. Cleaning up...")
-        tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
-        [task.cancel() for task in tasks]
-        await asyncio.gather(*tasks, return_exceptions=True)
-        self.logger.info("Shutdown complete.")
-
-    def get_lock_for_artifact(self, artifact):
-        artifact_key = artifact.__class__.__name__
-        if artifact_key not in self.locks:
-            self.locks[artifact_key] = asyncio.Lock()
-        return self.locks[artifact_key]
+    async def get_lock_for_artifact(self, artifact):
+        key = artifact.__class__.__name__
+        async with self.locks_lock:
+            if key not in self.locks:
+                self.locks[key] = asyncio.Lock()
+            return self.locks[key]

     def get_artifacts(self, data_type):
         if data_type not in self.wrapped_classes:
             raise ValueError(f"Unsupported data type: {data_type}")
-
-        return [
-            artifact_class(
-                parquet_start_date=self.parquet_start_date,
-                parquet_end_date=self.parquet_end_date,
-                logger=self.logger,
-                debug=self.debug
-            )
-            for artifact_class in self.wrapped_classes[data_type]
-        ]
+        artifacts = [cls(
+            parquet_start_date=self.parquet_start_date,
+            parquet_end_date=self.parquet_end_date,
+            logger=self.logger,
+            debug=self.debug
+        ) for cls in self.wrapped_classes[data_type]]
+        # seed pending set and clear others
+        self.pending = set(artifacts)
+        self.completed.clear()
+        self.failed.clear()
+        return artifacts

     def estimate_complexity(self, artifact):
         try:
-            if hasattr(artifact, 'get_size_estimate'):
-                return artifact.get_size_estimate()
-        except Exception as e:
-            self.logger.warning(f"Failed to estimate complexity for {artifact}: {e}")
-        return 1  # default
+            return artifact.get_size_estimate()
+        except Exception:
+            return 1

     def prioritize_tasks(self, artifacts):
         queue = asyncio.PriorityQueue()
-        for artifact in artifacts:
-            complexity = self.estimate_complexity(artifact)
-            # we invert the complexity to ensure higher complexity -> higher priority
-            # if you want high complexity first, store negative complexity in the priority queue
-            # or if the smaller number means earlier processing, just keep as is
-            queue.put_nowait(PrioritizedItem(complexity, artifact))
+        for art in artifacts:
+            queue.put_nowait(PrioritizedItem(self.estimate_complexity(art), art))
         return queue

     async def resource_monitor(self, queue, workers):
-        """Monitor system resources and adjust worker count while queue is not empty."""
-        while True:
-            # break if queue done
-            if queue.empty():
-                await asyncio.sleep(0.5)
-                if queue.empty():
-                    break
-
+        while not queue.empty():
             try:
-                available_memory = psutil.virtual_memory().available
-                worker_memory_bytes = self.memory_per_worker_gb * (1024 ** 3)
-                max_workers_by_memory = available_memory // worker_memory_bytes
-
-                # figure out how many workers we can sustain
-                # note: we also cap by self.max_workers
-                optimal_workers = min(psutil.cpu_count(), max_workers_by_memory, self.max_workers)
-
-                # ensure at least self.min_workers is used
-                optimal_workers = max(self.min_workers, optimal_workers)
-
-                current_worker_count = len(workers)
-
-                if optimal_workers > current_worker_count:
-                    # we can add more workers if queue is not empty
-                    diff = optimal_workers - current_worker_count
-                    for _ in range(diff):
-                        worker_id = len(workers)
-                        # create a new worker
-                        w = asyncio.create_task(self.worker(queue, worker_id))
-                        workers.append(w)
-                        self.logger.info(f"Added worker {worker_id}. Total workers: {len(workers)}")
-                elif optimal_workers < current_worker_count:
-                    # remove some workers
-                    diff = current_worker_count - optimal_workers
-                    for _ in range(diff):
-                        w = workers.pop()
-                        w.cancel()
-                        self.logger.info(f"Removed a worker. Total workers: {len(workers)}")
-
+                avail = psutil.virtual_memory().available
+                max_by_mem = avail // (self.memory_per_worker_gb * 2**30)
+                optimal = max(self.min_workers,
+                              min(psutil.cpu_count(), max_by_mem, self.max_workers))
+                async with self.workers_lock:
+                    current = len(workers)
+                    if optimal > current:
+                        for _ in range(optimal - current):
+                            wid = len(workers)
+                            workers.append(asyncio.create_task(self.worker(queue, wid)))
+                            self.logger.info(f"Added worker {wid}")
+                    elif optimal < current:
+                        for _ in range(current - optimal):
+                            w = workers.pop()
+                            w.cancel()
+                            self.logger.info("Removed a worker")
                 await asyncio.sleep(self.monitor_interval)
-
             except asyncio.CancelledError:
-                # monitor is being shut down
                 break
             except Exception as e:
-                self.logger.error(f"Error in resource_monitor: {e}")
+                self.logger.error(f"Monitor error: {e}")
                 await asyncio.sleep(self.monitor_interval)

     @asynccontextmanager
     async def artifact_lock(self, artifact):
-        lock = self.get_lock_for_artifact(artifact)
+        lock = await self.get_lock_for_artifact(artifact)
         try:
             await asyncio.wait_for(lock.acquire(), timeout=self.lock_acquire_timeout_seconds)
             yield
-        except asyncio.TimeoutError:
-            self.logger.error(f"Timeout acquiring lock for artifact: {artifact.__class__.__name__}")
-            yield  # continue but no actual lock was acquired
         finally:
             if lock.locked():
                 lock.release()

     async def async_update_artifact(self, artifact, **kwargs):
-        for attempt in range(self.retry_attempts):
+        for attempt in range(1, self.retry_attempts + 1):
+            lock = await self.get_lock_for_artifact(artifact)
             try:
-                async with self.artifact_lock(artifact):
-                    self.logger.info(
-                        f"Updating artifact: {artifact.__class__.__name__}, Attempt: {attempt + 1} of {self.retry_attempts}")
-                    start_time = time.time()
+                await asyncio.wait_for(lock.acquire(), timeout=self.lock_acquire_timeout_seconds)
+                try:
+                    self.logger.info(f"Updating {artifact.__class__.__name__} (attempt {attempt})")
                     await asyncio.wait_for(
                         asyncio.to_thread(artifact.update_parquet, **kwargs),
                         timeout=self.update_timeout_seconds
                     )
-                    elapsed_time = time.time() - start_time
+                    # mark success
+                    async with self.workers_lock:
+                        self.pending.discard(artifact)
+                        self.completed.add(artifact)
                     self.logger.info(
-                        f"Successfully updated artifact: {artifact.__class__.__name__} in {elapsed_time:.2f}s.")
+                        f" {artifact.__class__.__name__} done "
+                        f"{len(self.completed)}/{len(self.completed) + len(self.pending) + len(self.failed)} completed, "
+                        f"{len(self.failed)} failed"
+                    )
                     return
-
+                finally:
+                    if lock.locked():
+                        lock.release()
             except asyncio.TimeoutError:
-                self.logger.error(f"Timeout updating artifact {artifact.__class__.__name__}, Attempt: {attempt + 1}")
+                self.logger.warning(f"Timeout on {artifact.__class__.__name__}, attempt {attempt}")
             except Exception as e:
-                self.logger.error(
-                    f"Error updating artifact {artifact.__class__.__name__}, Attempt: {attempt + 1}: {e}")
-
-            # exponential backoff
-            await asyncio.sleep(2 ** attempt)
+                self.logger.error(f"Error on {artifact}: {e}")
+            finally:
+                if lock.locked():
+                    lock.release()
+            await asyncio.sleep(2 ** (attempt - 1))

-        self.logger.error(f"All retry attempts failed for artifact: {artifact.__class__.__name__}")
+        # all retries exhausted -> mark failure
+        async with self.workers_lock:
+            self.pending.discard(artifact)
+            self.failed.add(artifact)
+        self.logger.error(f"✖️ Permanently failed {artifact.__class__.__name__}")

     async def worker(self, queue, worker_id, **kwargs):
-        """A worker that dynamically pulls tasks from the queue."""
         while True:
             try:
-                prioritized_item = await queue.get()
-                if prioritized_item is None:
-                    break
-                artifact = prioritized_item.artifact
-                # heartbeat
+                item = await queue.get()
+                art = item.artifact
                 self.worker_heartbeat[worker_id] = time.time()
-
-                await self.async_update_artifact(artifact, **kwargs)
-
+                await self.async_update_artifact(art, **kwargs)
             except asyncio.CancelledError:
-                self.logger.info(f"Worker {worker_id} shutting down gracefully.")
+                self.logger.info(f"Worker {worker_id} stopped")
                 break
-            except Exception as e:
-                self.logger.error(f"Error in worker {worker_id}: {e}")
             finally:
                 queue.task_done()

-    async def process_tasks(self, queue, initial_workers, **kwargs):
-        """Start a set of workers and a resource monitor to dynamically adjust them."""
-        # create initial workers
-        workers = []
-        for worker_id in range(initial_workers):
-            w = asyncio.create_task(self.worker(queue, worker_id, **kwargs))
-            workers.append(w)
-
-        # start resource monitor
-        monitor_task = asyncio.create_task(self.resource_monitor(queue, workers))
-
-        # wait until queue is done
-        try:
-            await queue.join()
-        finally:
-            # cancel resource monitor
-            monitor_task.cancel()
-            # all workers done
-            for w in workers:
-                w.cancel()
-            await asyncio.gather(*workers, return_exceptions=True)
+    def calculate_initial_workers(self, count: int) -> int:
+        avail = psutil.virtual_memory().available
+        max_by_mem = avail // (self.memory_per_worker_gb * 2**30)
+        return max(self.min_workers,
+                   min(psutil.cpu_count(), max_by_mem, count, self.max_workers))

     async def update_data(self, data_type, **kwargs):
-        self.logger.info(f"Processing wrapper group: {data_type} with {kwargs}")
+        self.logger.info(f"Starting update for {data_type}")
         artifacts = self.get_artifacts(data_type)
         queue = self.prioritize_tasks(artifacts)
+        init = self.calculate_initial_workers(len(artifacts))
+        tasks = [asyncio.create_task(self.worker(queue, i, **kwargs)) for i in range(init)]
+        monitor = asyncio.create_task(self.resource_monitor(queue, tasks))
+        await queue.join()
+        monitor.cancel()
+        for t in tasks:
+            t.cancel()
+        await asyncio.gather(*tasks, return_exceptions=True)
+        self.logger.info(self.format_results_table())
+        self.logger.info("All artifacts processed.")
+
+    def format_results_table(self):
+        results = self.get_update_status()
+        headers = ["Metric", "Value"]
+        rows = [
+            ["Total", results['total']],
+            ["Completed", results['completed']],
+            ["Pending", results['pending']],
+            ["Failed", results['failed']],
+            ["Pending Items", len(results['pending_items'])],
+            ["Failed Items", len(results['failed_items'])]
+        ]
+
+        # Find max lengths for alignment
+        max_metric = max(len(str(row[0])) for row in rows)
+        max_value = max(len(str(row[1])) for row in rows)

-        # compute initial worker count (this can be low if memory is low initially)
-        initial_workers = self.calculate_initial_workers(len(artifacts))
-        self.logger.info(f"Initial worker count: {initial_workers} for {len(artifacts)} artifacts")
+        format_str = "{:<%d} {:>%d}" % (max_metric, max_value)

-        total_start_time = time.time()
-        await self.process_tasks(queue, initial_workers, **kwargs)
-        total_time = time.time() - total_start_time
-        self.logger.info(f"Total processing time: {total_time:.2f} seconds.")
+        table = [
+            "\n",
+            format_str.format(*headers),
+            "-" * (max_metric + max_value + 2)
+        ]

-    def calculate_initial_workers(self, artifact_count: int) -> int:
-        """Compute the initial number of workers before resource_monitor can adjust."""
-        self.logger.info("Calculating initial worker count...")
-        available_memory = psutil.virtual_memory().available
-        self.logger.info(f"Available memory: {available_memory / (1024 ** 3):.2f} GB")
-        worker_memory_bytes = self.memory_per_worker_gb * (1024 ** 3)
-        self.logger.info(f"Memory per worker: {worker_memory_bytes / (1024 ** 3):.2f} GB")
-        max_workers_by_memory = available_memory // worker_memory_bytes
-        self.logger.info(f"Max workers by memory: {max_workers_by_memory}")
-        # also consider CPU count and artifact_count
-        initial = min(psutil.cpu_count(), max_workers_by_memory, artifact_count, self.max_workers)
-        self.logger.info(f"Optimal workers: {initial} CPU: {psutil.cpu_count()} Max Workers: {self.max_workers}")
-        return max(self.min_workers, initial)
+        for row in rows:
+            table.append(format_str.format(row[0], row[1]))
+
+        return "\n".join(table)
+
+    def get_update_status(self):
+        total = len(self.pending) + len(self.completed) + len(self.failed)
+        return {
+            "total": total,
+            "completed": len(self.completed),
+            "pending": len(self.pending),
+            "failed": len(self.failed),
+            "pending_items": [a.__class__.__name__ for a in self.pending],
+            "failed_items": [a.__class__.__name__ for a in self.failed]
+        }
+
+    # Top-level driver
+    # environment = None  # fill this in with your wrapped_classes dict
+    #
+    # async def main():
+    #     wrapper = ArtifactUpdaterMultiWrapper(
+    #         wrapped_classes=environment,
+    #         debug=True
+    #     )
+    #     loop = asyncio.get_running_loop()
+    #     for sig in (signal.SIGINT, signal.SIGTERM):
+    #         loop.add_signal_handler(sig, lambda: asyncio.create_task(wrapper.shutdown()))
+    #     await wrapper.update_data("your_data_type")
+    #
+    # if __name__ == "__main__":
+    #     asyncio.run(main())

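For context, the sizing rule now shared by `resource_monitor` and `calculate_initial_workers` caps the worker count by available memory (roughly 1 GiB per worker by default), CPU count, the number of artifacts, and `max_workers`, with `min_workers` as the floor. A minimal standalone sketch of that heuristic, assuming only that `psutil` is installed (the function name here is illustrative, not part of the package):

```python
import psutil

def plan_workers(artifact_count: int, min_workers: int = 1, max_workers: int = 3,
                 memory_per_worker_gb: int = 1) -> int:
    # how many workers the available RAM can sustain at ~1 GiB each
    max_by_memory = psutil.virtual_memory().available // (memory_per_worker_gb * 2**30)
    # stay within CPU count, pending work, and the configured ceiling; never below the floor
    return max(min_workers,
               min(psutil.cpu_count(), max_by_memory, artifact_count, max_workers))

print(plan_workers(artifact_count=10))
```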
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import datetime
 import logging
@@ -6,10 +8,10 @@ from typing import Any, Dict, TypeVar
 from typing import Union, Optional

 import dask.dataframe as dd
-from dask import delayed, compute
+import fsspec
 import pandas as pd
+from dask import delayed, compute
 from pydantic import BaseModel
-import fsspec

 from sibi_dst.df_helper.core import QueryConfig, ParamsConfig, FilterHandler
 from sibi_dst.utils import Logger
@@ -45,7 +47,7 @@ class DfHelper:
     :ivar df: The DataFrame currently being processed or loaded.
     :type df: Union[dd.DataFrame, pd.DataFrame]
     :ivar backend_django: Configuration for interacting with Django database backends.
-    :type backend_django: Optional[DjangoConnectionConfig]
+    :type backend_connection: Optional[DjangoConnectionConfig]
     :ivar _backend_query: Internal configuration for query handling.
     :type _backend_query: Optional[QueryConfig]
     :ivar _backend_params: Internal parameters configuration for DataFrame handling.
@@ -54,8 +56,6 @@ class DfHelper:
     :type backend_parquet: Optional[ParquetConfig]
     :ivar backend_http: Configuration for interacting with HTTP-based backends.
     :type backend_http: Optional[HttpConfig]
-    :ivar backend_sqlalchemy: Configuration for interacting with SQLAlchemy-based databases.
-    :type backend_sqlalchemy: Optional[SqlAlchemyConnectionConfig]
     :ivar parquet_filename: The filename for a Parquet file, if applicable.
     :type parquet_filename: str
     :ivar logger: Logger instance used for debugging and information logging.
@@ -64,12 +64,11 @@ class DfHelper:
     :type default_config: Dict
     """
     df: Union[dd.DataFrame, pd.DataFrame] = None
-    backend_django: Optional[DjangoConnectionConfig] = None
+    backend_db_connection: Optional[Union[DjangoConnectionConfig | SqlAlchemyConnectionConfig]] = None
     _backend_query: Optional[QueryConfig] = None
     _backend_params: Optional[ParamsConfig] = None
     backend_parquet: Optional[ParquetConfig] = None
     backend_http: Optional[HttpConfig] = None
-    backend_sqlalchemy: Optional[SqlAlchemyConnectionConfig] = None
     parquet_filename: str = None
     logger: Logger
     default_config: Dict = None
@@ -91,7 +90,7 @@ class DfHelper:
         self.filesystem_options = kwargs.pop('filesystem_options', {})
         kwargs.setdefault("live", True)
         kwargs.setdefault("logger", self.logger)
-        self.fs =kwargs.setdefault("fs", fsspec.filesystem('file'))
+        self.fs = kwargs.setdefault("fs", fsspec.filesystem('file'))
         self.__post_init(**kwargs)

     def __str__(self):
@@ -100,6 +99,34 @@ class DfHelper:
     def __call__(self, **options):
         return self.load(**options)

+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.__cleanup()
+        return False
+
+    def __cleanup(self):
+        """
+        Clean up resources when exiting the context manager.
+        This method is called when the context manager exits.
+        """
+
+        if self.backend_db_connection:
+            if getattr(self.backend_db_connection, "dispose_idle_connections", None):
+                self.backend_db_connection.dispose_idle_connections()
+            if getattr(self.backend_db_connection, "close", None):
+                self.backend_db_connection.close()
+
+        self.backend_db_connection = None
+
+        if self.backend_parquet:
+            self.backend_parquet = None
+        if self.backend_http:
+            self.backend_http = None
+        self._backend_query = None
+        self._backend_params = None
+
     def __post_init(self, **kwargs):
         """
         Initializes backend-specific configurations based on the provided backend type and other
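The new `__enter__`/`__exit__` pair lets `DfHelper` be used as a context manager, with `__cleanup()` disposing idle connections, closing the backend connection, and clearing cached configs on exit. A hypothetical usage sketch; the import path and the connection-related keyword arguments are assumptions, not confirmed by this diff:

```python
from sibi_dst.df_helper import DfHelper  # import path assumed

# backend-specific connection kwargs omitted; pass whatever your backend expects
with DfHelper(backend='sqlalchemy', as_pandas=True, debug=False) as helper:
    df = helper.load()  # with as_pandas=True, load() returns a computed pandas DataFrame
# on exit, __exit__ -> __cleanup() releases the database connection and resets the configs
```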
@@ -111,20 +138,19 @@ class DfHelper:
         Additional parameters for specific backend types are extracted here.
         :return: None
         """
-        self.logger.debug(f"backend used: {self.backend}")
-        self.logger.debug(f"kwargs passed to backend plugins: {kwargs}")
+        # self.logger.debug(f"backend used: {self.backend}")
+        # self.logger.debug(f"kwargs passed to backend plugins: {kwargs}")
         self._backend_query = self.__get_config(QueryConfig, kwargs)
         self._backend_params = self.__get_config(ParamsConfig, kwargs)
         if self.backend == 'django_db':
-            self.backend_django = self.__get_config(DjangoConnectionConfig, kwargs)
+            self.backend_db_connection = self.__get_config(DjangoConnectionConfig, kwargs)
         elif self.backend == 'parquet':
             self.parquet_filename = kwargs.setdefault("parquet_filename", None)
             self.backend_parquet = ParquetConfig(**kwargs)
         elif self.backend == 'http':
             self.backend_http = HttpConfig(**kwargs)
         elif self.backend == 'sqlalchemy':
-            self.backend_sqlalchemy = self.__get_config(SqlAlchemyConnectionConfig, kwargs)
-
+            self.backend_db_connection = self.__get_config(SqlAlchemyConnectionConfig, kwargs)

     def __get_config(self, model: [T], kwargs: Dict[str, Any]) -> Union[T]:
         """
@@ -134,11 +160,13 @@ class DfHelper:
         :param kwargs: The dictionary of keyword arguments.
         :return: The initialized Pydantic model instance.
         """
+        kwargs.setdefault("debug", self.debug)
+        kwargs.setdefault("logger", self.logger)
         # Extract keys that the model can accept
         recognized_keys = set(model.__annotations__.keys())
         self.logger.debug(f"recognized keys: {recognized_keys}")
         model_kwargs = {k: kwargs.pop(k) for k in list(kwargs.keys()) if k in recognized_keys}
-        self.logger.debug(f"model_kwargs: {model_kwargs}")
+        # self.logger.debug(f"model_kwargs: {model_kwargs}")
         return model(**model_kwargs)

     def load_parallel(self, **options):
@@ -171,10 +199,10 @@ class DfHelper:
         `as_pandas` is set to True, or kept in its native backend format otherwise.
         """
         # this will be the universal method to load data from a df irrespective of the backend
-        df = self.__load(**options)
+        self.df = self.__load(**options)
         if self.as_pandas:
-            return df.compute()
-        return df
+            return self.df.compute()
+        return self.df

     def __load(self, **options):
         """
@@ -196,7 +224,7 @@ class DfHelper:
         """
         if self.backend == 'django_db':
             self._backend_params.parse_params(options)
-            return self.__load_from_db(**options)
+            return self.__load_from_django_db(**options)
         elif self.backend == 'sqlalchemy':
             self._backend_params.parse_params(options)
             return self.__load_from_sqlalchemy(**options)
@@ -227,7 +255,7 @@ class DfHelper:
         try:
             options.setdefault("debug", self.debug)
             db_loader = SqlAlchemyLoadFromDb(
-                self.backend_sqlalchemy,
+                self.backend_db_connection,
                 self._backend_query,
                 self._backend_params,
                 self.logger,
@@ -236,6 +264,7 @@ class DfHelper:
             self.df = db_loader.build_and_load()
             self.__process_loaded_data()
             self.__post_process_df()
+            self.backend_db_connection.close()
             self.logger.debug("Data successfully loaded from sqlalchemy database.")
         except Exception as e:
             self.logger.debug(f"Failed to load data from sqlalchemy database: {e}: options: {options}")
@@ -243,7 +272,7 @@ class DfHelper:

         return self.df

-    def __load_from_db(self, **options) -> Union[pd.DataFrame, dd.DataFrame]:
+    def __load_from_django_db(self, **options) -> Union[pd.DataFrame, dd.DataFrame]:
         """
         Loads data from a Django database using a specific backend query mechanism. Processes the loaded data
         and applies further post-processing before returning the dataframe. If the operation fails, an
@@ -258,7 +287,7 @@ class DfHelper:
         try:
             options.setdefault("debug", self.debug)
             db_loader = DjangoLoadFromDb(
-                self.backend_django,
+                self.backend_db_connection,
                 self._backend_query,
                 self._backend_params,
                 self.logger,
@@ -307,6 +336,7 @@ class DfHelper:
         :raises ValueError: If the lengths of `fieldnames` and `column_names` do not match,
             or if the specified `index_col` is not found in the DataFrame.
         """
+        self.logger.debug("Post-processing DataFrame.")
         df_params = self._backend_params.df_params
         fieldnames = df_params.get("fieldnames", None)
         index_col = df_params.get("index_col", None)
@@ -357,16 +387,16 @@ class DfHelper:

         :return: None
         """
-        self.logger.debug(f"Type of self.df: {type(self.df)}")
+        self.logger.debug(f"Processing loaded data...")
         if self.df.map_partitions(len).compute().sum() > 0:
             field_map = self._backend_params.field_map or {}
-            if isinstance(field_map, dict):
+            if isinstance(field_map, dict) and field_map != {}:
                 rename_mapping = {k: v for k, v in field_map.items() if k in self.df.columns}
                 missing_columns = [k for k in field_map.keys() if k not in self.df.columns]

                 if missing_columns:
                     self.logger.warning(
-                        f"The following columns in field_map are not in the DataFrame: {missing_columns}")
+                        f"The following columns in field_map are not in the DataFrame: {missing_columns}, field map: {field_map}")

                 def rename_columns(df, mapping):
                     return df.rename(columns=mapping)
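For reference, the `field_map` handling here builds a rename mapping from only the keys that exist as columns and, in the next hunk, applies it to every partition with `map_partitions`. A small self-contained illustration of that pattern with made-up column names:

```python
import dask.dataframe as dd
import pandas as pd

pdf = pd.DataFrame({"old_a": [1, 2], "old_b": [3, 4]})
ddf = dd.from_pandas(pdf, npartitions=1)

field_map = {"old_a": "a", "old_b": "b", "missing": "c"}
# only rename columns that are actually present, as __process_loaded_data does
rename_mapping = {k: v for k, v in field_map.items() if k in ddf.columns}

ddf = ddf.map_partitions(lambda part: part.rename(columns=rename_mapping))
print(list(ddf.compute().columns))  # ['a', 'b']
```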
@@ -376,6 +406,8 @@ class DfHelper:
                 self.df = self.df.map_partitions(rename_columns, mapping=rename_mapping)

             self.logger.debug("Processing of loaded data completed.")
+        else:
+            self.logger.debug("DataFrame is empty, skipping processing.")

     def save_to_parquet(self, parquet_filename: Optional[str] = None, **kwargs):
         """
@@ -536,14 +568,14 @@ class DfHelper:

         # Common logic for Django and SQLAlchemy
         if self.backend == 'django_db':
-            model_fields = {field.name: field for field in self.backend_django.model._meta.get_fields()}
+            model_fields = {field.name: field for field in self.backend_db_connection.model._meta.get_fields()}
             if mapped_field not in model_fields:
                 raise ValueError(f"Field '{dt_field}' does not exist in the Django model.")
             field_type = type(model_fields[mapped_field]).__name__
             is_date_field = field_type == 'DateField'
             is_datetime_field = field_type == 'DateTimeField'
         elif self.backend == 'sqlalchemy':
-            model = self.backend_sqlalchemy.model
+            model = self.backend_db_connection.model
             fields = [column.name for column in model.__table__.columns]
             if mapped_field not in fields:
                 raise ValueError(f"Field '{dt_field}' does not exist in the SQLAlchemy model.")