sibi-dst 2025.8.7__py3-none-any.whl → 2025.8.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,7 @@
  import warnings

+ from pandas.api.types import is_period_dtype, is_bool_dtype, is_string_dtype
+ import pandas as pd
  import dask.dataframe as dd
  import pyarrow as pa

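A quick, hedged illustration of the pandas dtype helpers this release starts importing (is_bool_dtype appears in the import but not in the changed lines shown later in this file; the deprecation note reflects recent pandas releases, not anything stated in this package):

import pandas as pd
from pandas.api.types import is_period_dtype, is_bool_dtype, is_string_dtype

s = pd.Series(pd.period_range("2025-01", periods=2, freq="M"))
print(is_period_dtype(s))                         # True; newer pandas prefers isinstance(s.dtype, pd.PeriodDtype)
print(is_string_dtype(s.astype(str)))             # True once the Periods are rendered as strings
print(is_bool_dtype(pd.Series([True, False])))    # True for plain boolean columns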
@@ -15,6 +17,7 @@ class ParquetSaver(ManagedResource):

  Assumes `df_result` is a Dask DataFrame.
  """
+ logger_extra = {"sibi_dst_component": __name__}

  def __init__(
  self,
@@ -32,6 +35,10 @@ class ParquetSaver(ManagedResource):
  if "://" in self.parquet_storage_path:
  self.protocol = self.parquet_storage_path.split(":", 1)[0]

+ self.persist = kwargs.get("persist",True)
+ self.write_index = kwargs.get("write_index", False)
+ self.write_metadata_file = kwargs.get("write_metadata_file", True)
+
  def save_to_parquet(self, output_directory_name: str = "default_output", overwrite: bool = True):
  """
  Saves the Dask DataFrame to a Parquet dataset.
@@ -42,17 +49,19 @@
  full_path = f"{self.parquet_storage_path}/{output_directory_name}"

  if overwrite and self.fs and self.fs.exists(full_path):
- self.logger.info(f"Overwrite is True, clearing destination path: {full_path}")
+ self.logger.info(f"Overwrite is True, clearing destination path: {full_path}", extra=self.logger_extra)
  self._clear_directory_safely(full_path)

  # Ensure the base directory exists after clearing
  self.fs.mkdirs(full_path, exist_ok=True)

  schema = self._define_schema()
- self.logger.info(f"Saving DataFrame to Parquet dataset at: {full_path}")
+ self.logger.info(f"Saving DataFrame to Parquet dataset at: {full_path}", extra=self.logger_extra)
+ # 1) Normalize to declared schema (fixes bool→string, Period→string, etc.)
+ ddf = self._coerce_ddf_to_schema(self.df_result, schema)

- # persist then write (lets the graph be shared if the caller reuses it)
- ddf = self.df_result.persist()
+ # 2) Persist after coercion so all partitions share the coerced dtypes
+ ddf = ddf.persist() if self.persist else ddf

  try:
  ddf.to_parquet(
@@ -61,11 +70,12 @@
  schema=schema,
  overwrite=False, # we've handled deletion already
  filesystem=self.fs,
- write_index=False,
+ write_index=self.write_index, # whether to write the index
+ write_metadata_file=self.write_metadata_file, # write _metadata for easier reading later
  )
- self.logger.info(f"Successfully saved Parquet dataset to: {full_path}")
+ self.logger.info(f"Successfully saved Parquet dataset to: {full_path}", extra=self.logger_extra)
  except Exception as e:
- self.logger.error(f"Failed to save Parquet dataset to {full_path}: {e}")
+ self.logger.error(f"Failed to save Parquet dataset to {full_path}: {e}", extra=self.logger_extra)
  raise

  def _clear_directory_safely(self, directory: str):
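Taken together, the hunks above add three write-tuning options that __init__ reads from its keyword arguments and save_to_parquet forwards to to_parquet. A minimal usage sketch, assuming the constructor accepts df_result and parquet_storage_path keyword arguments and that ParquetSaver is importable from sibi_dst.utils (neither is shown in this diff); persist, write_index, write_metadata_file, and save_to_parquet(...) come from the changes above:

import pandas as pd
import dask.dataframe as dd
from sibi_dst.utils import ParquetSaver  # import path assumed

ddf = dd.from_pandas(pd.DataFrame({"id": [1, 2], "active": [True, False]}), npartitions=1)

saver = ParquetSaver(
    df_result=ddf,                                 # assumed kwarg name, per the class docstring
    parquet_storage_path="s3://bucket/datasets",   # assumed; any fsspec-style URL
    persist=False,                 # new in 2025.8.9: skip ddf.persist() before writing
    write_index=False,             # new: forwarded as to_parquet(write_index=...)
    write_metadata_file=True,      # new: emit _metadata for easier dataset discovery
)
saver.save_to_parquet(output_directory_name="events", overwrite=True)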
@@ -91,7 +101,7 @@ class ParquetSaver(ManagedResource):
  else:
  self.fs.rm(path, recursive=False)
  except Exception as e:
- self.logger.warning(f"Failed to delete '{path}': {e}")
+ self.logger.warning(f"Failed to delete '{path}': {e}", extra=self.logger_extra)
  # remove the (now empty) directory if present
  try:
  self.fs.rm(directory, recursive=False)
@@ -120,4 +130,95 @@ class ParquetSaver(ManagedResource):
  pa.field(c, pandas_dtype_to_pa.get(str(d), pa.string()))
  for c, d in self.df_result.dtypes.items()
  ]
- return pa.schema(fields)
+ return pa.schema(fields)
+
+
+ def _coerce_ddf_to_schema(self, ddf: dd.DataFrame, schema: pa.Schema) -> dd.DataFrame:
+ """
+ Coerce Dask DataFrame columns to match the provided PyArrow schema.
+ - Ensures cross-partition consistency.
+ - Converts troublesome dtypes (Period, mixed object/bool) to the declared type.
+ """
+ # Build a map: name -> target kind
+ target = {field.name: field.type for field in schema}
+
+ def _coerce_partition(pdf: pd.DataFrame) -> pd.DataFrame:
+ for col, typ in target.items():
+ if col not in pdf.columns:
+ continue
+
+ pa_type = typ
+
+ # String targets
+ if pa.types.is_string(pa_type) or pa.types.is_large_string(pa_type):
+ # Convert Period or any dtype to string with NA-preservation
+ s = pdf[col]
+ if is_period_dtype(s):
+ pdf[col] = s.astype(str)
+ elif not is_string_dtype(s):
+ # astype("string") keeps NA; str(s) can produce "NaT" strings
+ try:
+ pdf[col] = s.astype("string")
+ except Exception:
+ pdf[col] = s.astype(str).astype("string")
+ continue
+
+ # Boolean targets
+ if pa.types.is_boolean(pa_type):
+ s = pdf[col]
+ # Allow object/bool mixtures; coerce via pandas nullable boolean then to bool
+ try:
+ pdf[col] = s.astype("boolean").astype(bool)
+ except Exception:
+ pdf[col] = s.astype(bool)
+ continue
+
+ # Integer targets
+ if pa.types.is_integer(pa_type):
+ s = pdf[col]
+ # Go through pandas nullable Int64 to preserve NA, then to int64 if clean
+ s2 = pd.to_numeric(s, errors="coerce").astype("Int64")
+ # If there are no nulls, downcast to numpy int64 for speed
+ if not s2.isna().any():
+ s2 = s2.astype("int64")
+ pdf[col] = s2
+ continue
+
+ # Floating targets
+ if pa.types.is_floating(pa_type):
+ pdf[col] = pd.to_numeric(pdf[col], errors="coerce").astype("float64")
+ continue
+
+ # Timestamp[ns] (optionally with tz)
+ if pa.types.is_timestamp(pa_type):
+ # If tz in Arrow type, you may want to localize; here we just ensure ns
+ pdf[col] = pd.to_datetime(pdf[col], errors="coerce")
+ continue
+
+ # Fallback: leave as-is
+ return pdf
+
+ # Provide a meta with target dtypes to avoid meta mismatch warnings
+ meta = {}
+ for name, typ in target.items():
+ # Rough meta mapping; Arrow large_string vs string both → 'string'
+ if pa.types.is_string(typ) or pa.types.is_large_string(typ):
+ meta[name] = pd.Series([], dtype="string")
+ elif pa.types.is_boolean(typ):
+ meta[name] = pd.Series([], dtype="bool")
+ elif pa.types.is_integer(typ):
+ meta[name] = pd.Series([], dtype="Int64") # nullable int
+ elif pa.types.is_floating(typ):
+ meta[name] = pd.Series([], dtype="float64")
+ elif pa.types.is_timestamp(typ):
+ meta[name] = pd.Series([], dtype="datetime64[ns]")
+ else:
+ meta[name] = pd.Series([], dtype="object")
+
+ # Start from current meta and update known columns
+ new_meta = ddf._meta.copy()
+ for k, v in meta.items():
+ if k in new_meta.columns:
+ new_meta[k] = v
+
+ return ddf.map_partitions(_coerce_partition, meta=new_meta)
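The new _coerce_ddf_to_schema leans on a general Dask pattern: rewrite each partition with map_partitions and pass an explicit meta so the coerced dtypes are consistent across partitions before to_parquet sees them. Below is a standalone sketch of that pattern with illustrative column names and a simplified NA policy (not the package's exact rules):

import pandas as pd
import dask.dataframe as dd
import pyarrow as pa

pdf = pd.DataFrame({
    "month": pd.period_range("2025-01", periods=3, freq="M"),  # Period -> string
    "flag": [True, None, False],                               # mixed object/bool -> bool
})
ddf = dd.from_pandas(pdf, npartitions=2)
schema = pa.schema([("month", pa.string()), ("flag", pa.bool_())])

def coerce(part: pd.DataFrame) -> pd.DataFrame:
    part = part.copy()
    part["month"] = part["month"].astype(str).astype("string")
    part["flag"] = part["flag"].fillna(False).astype(bool)  # explicit NA policy for the demo
    return part

# Explicit meta keeps Dask's dtype bookkeeping in sync with the coerced output.
meta = ddf._meta.copy()
meta["month"] = meta["month"].astype(str).astype("string")
meta["flag"] = meta["flag"].fillna(False).astype(bool)

coerced = ddf.map_partitions(coerce, meta=meta)
coerced.to_parquet("coerced_dataset", engine="pyarrow", schema=schema, write_index=False)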
@@ -0,0 +1,5 @@
+ from .sse_runner import SSERunner, _as_sse_msg
+
+ __all__ = [
+ "SSERunner", "_as_sse_msg"
+ ]
@@ -0,0 +1,82 @@
+ # jobs.py
+ import asyncio, json, uuid
+ import contextlib
+ import os
+
+ import redis.asyncio as redis
+ from fastapi import APIRouter
+ from sse_starlette.sse import EventSourceResponse
+ host = os.getenv("REDIS_HOST", "0.0.0.0")
+ port = int(os.getenv("REDIS_PORT", 6379))
+ db = int(os.getenv("REDIS_DB", 0))
+ router = APIRouter(prefix="/jobs", tags=["Jobs"])
+ r = redis.Redis(host=host, port=port, db=db, decode_responses=True) # strings for pubsub
+
+ CHANNEL = lambda job_id: f"job:{job_id}:events"
+ KEY_STATUS = lambda job_id: f"job:{job_id}:status" # JSON blob with state/progress
+ KEY_RESULT = lambda job_id: f"job:{job_id}:result" # final payload
+
+ async def publish(job_id: str, event: str, data: dict):
+ msg = json.dumps({"event": event, "data": data})
+ await r.publish(CHANNEL(job_id), msg)
+ # store last status
+ await r.set(KEY_STATUS(job_id), json.dumps({"event": event, "data": data}))
+
+ # ---- Worker entry (can live in a separate process) ----
+ async def run_job(job_id: str):
+ try:
+ await publish(job_id, "progress", {"message": "Initializing..."})
+ # ... do actual work, emit more progress
+ await asyncio.sleep(0.2)
+ # compute result
+ result = [{"id": 1, "ok": True}]
+ await r.set(KEY_RESULT(job_id), json.dumps(result), ex=3600)
+ await publish(job_id, "complete", {"records": len(result)})
+ except Exception as e:
+ await publish(job_id, "error", {"detail": str(e)})
+
+ # ---- API ----
+ @router.post("/start")
+ async def start_job():
+ job_id = str(uuid.uuid4())
+ # enqueue: prefer Celery/RQ/etc. For demo we detach a task.
+ asyncio.create_task(run_job(job_id))
+ return {"job_id": job_id}
+
+ @router.get("/{job_id}/stream")
+ async def stream(job_id: str):
+ pubsub = r.pubsub()
+ await pubsub.subscribe(CHANNEL(job_id))
+
+ async def gen():
+ try:
+ # emit latest known status immediately, if any
+ if (s := await r.get(KEY_STATUS(job_id))):
+ yield {"event": "progress", "data": s}
+ while True:
+ msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=30.0)
+ if msg and msg["type"] == "message":
+ payload = msg["data"] # already a JSON string
+ yield {"event": "message", "data": payload}
+ await asyncio.sleep(0.01)
+ finally:
+ with contextlib.suppress(Exception):
+ await pubsub.unsubscribe(CHANNEL(job_id))
+ await pubsub.close()
+
+ return EventSourceResponse(
+ gen(),
+ ping=15,
+ headers={"Cache-Control": "no-cache, no-transform", "X-Accel-Buffering": "no"},
+ )
+
+ @router.get("/{job_id}/status")
+ async def status(job_id: str):
+ s = await r.get(KEY_STATUS(job_id))
+ done = await r.exists(KEY_RESULT(job_id))
+ return {"job_id": job_id, "status": json.loads(s) if s else None, "done": bool(done)}
+
+ @router.get("/{job_id}/result")
+ async def result(job_id: str):
+ data = await r.get(KEY_RESULT(job_id))
+ return {"job_id": job_id, "result": json.loads(data) if data else None}
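A hedged sketch of how a client might drive these new endpoints, assuming the router is mounted on a FastAPI app reachable at http://localhost:8000 and using httpx for the request side (httpx is not referenced anywhere in this diff):

import asyncio
import httpx

async def follow_job() -> None:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        job_id = (await client.post("/jobs/start")).json()["job_id"]
        # Stream raw SSE lines; each "data:" line carries the JSON published to Redis.
        async with client.stream("GET", f"/jobs/{job_id}/stream", timeout=None) as resp:
            async for line in resp.aiter_lines():
                if line.startswith("data:"):
                    print(line[5:].strip())
                    break  # stop after the first event for this demo
        print((await client.get(f"/jobs/{job_id}/status")).json())

asyncio.run(follow_job())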
@@ -0,0 +1,82 @@
+ from __future__ import annotations
+ import asyncio, contextlib, inspect, json
+ from typing import Any, Awaitable, Callable, Dict, Optional, Union
+ from fastapi import Request
+ from sse_starlette.sse import EventSourceResponse
+ from sibi_dst.utils import Logger
+
+ Payload = Union[str, bytes, dict, list, None]
+ Task2 = Callable[[asyncio.Queue, str], Awaitable[Payload]]
+ Task3 = Callable[[asyncio.Queue, str, Dict[str, Any]], Awaitable[Payload]]
+ TaskFn = Union[Task2, Task3]
+
+ def _as_sse_msg(event: str, data: Any) -> dict:
+ return {"event": event, "data": json.dumps(data) if not isinstance(data, (str, bytes)) else data}
+
+ class SSERunner:
+ def __init__(self, *, task: TaskFn, logger: Logger, ping: int = 15,
+ headers: Optional[dict] = None, auto_complete: bool = True) -> None:
+ self.task = task
+ self.logger = logger
+ self.ping = ping
+ self.headers = headers or {"Cache-Control": "no-cache, no-transform", "X-Accel-Buffering": "no"}
+ self.auto_complete = auto_complete
+ self._expects_ctx = len(inspect.signature(task).parameters) >= 3
+
+ async def _call_task(self, queue: asyncio.Queue, task_id: str, ctx: Dict[str, Any]) -> Payload:
+ if self._expects_ctx:
+ return await self.task(queue, task_id, ctx) # type: ignore[misc]
+ return await self.task(queue, task_id) # type: ignore[misc]
+
+ async def _worker(self, queue: asyncio.Queue, task_id: str, ctx: Dict[str, Any]) -> None:
+ self.logger.info(f"SSE {task_id}: start")
+ try:
+ await queue.put(_as_sse_msg("progress", {"message": "Task started"}))
+ payload = await self._call_task(queue, task_id, ctx)
+ if self.auto_complete:
+ final = payload if payload is not None else {"status": "complete"}
+ await queue.put(_as_sse_msg("complete", final))
+ self.logger.info(f"SSE {task_id}: complete")
+ except asyncio.CancelledError:
+ raise
+ except Exception as e:
+ self.logger.error(f"SSE {task_id} failed: {e}", exc_info=True)
+ await queue.put(_as_sse_msg("error", {"detail": str(e)}))
+ finally:
+ await queue.put(None)
+
+ def endpoint(self):
+ async def handler(request: Request): # <-- only Request
+ queue: asyncio.Queue = asyncio.Queue()
+ task_id = str(asyncio.get_running_loop().time()).replace(".", "")
+
+ ctx: Dict[str, Any] = {
+ "path": dict(request.path_params), # <-- pull path params here
+ "query": dict(request.query_params),
+ "method": request.method,
+ }
+ if request.headers.get("content-type", "").startswith("application/json"):
+ try:
+ ctx["body"] = await request.json()
+ except Exception:
+ ctx["body"] = None
+
+ worker = asyncio.create_task(self._worker(queue, task_id, ctx))
+
+ async def gen():
+ try:
+ while True:
+ msg = await queue.get()
+ if msg is None:
+ break
+ yield msg
+ finally:
+ if not worker.done():
+ worker.cancel()
+ with contextlib.suppress(Exception):
+ await worker
+
+ return EventSourceResponse(gen(), ping=self.ping, headers=self.headers)
+ return handler
+
+ __all__ = ["SSERunner", "_as_sse_msg"]
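One possible way to wire SSERunner into an application, inferred from the code above: the task receives (queue, task_id, ctx), pushes intermediate events with _as_sse_msg, and its return value becomes the "complete" event while auto_complete stays on. The route path, the import path, and the use of a stdlib logger (the runner only calls .info and .error here) are assumptions, not part of this diff:

import asyncio
import logging
from typing import Any, Dict

from fastapi import APIRouter, FastAPI
from sibi_dst.utils import SSERunner, _as_sse_msg  # import path assumed

async def export_task(queue: asyncio.Queue, task_id: str, ctx: Dict[str, Any]) -> dict:
    await queue.put(_as_sse_msg("progress", {"message": "loading", "query": ctx["query"]}))
    await asyncio.sleep(0.1)  # stand-in for real work
    return {"rows": 42}       # becomes the "complete" event payload

runner = SSERunner(task=export_task, logger=logging.getLogger("export-sse"))

router = APIRouter(prefix="/exports", tags=["Exports"])
router.add_api_route("/stream", runner.endpoint(), methods=["GET"])

app = FastAPI()
app.include_router(router)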
@@ -1,4 +1,7 @@
  from __future__ import annotations
+
+ import asyncio
+
  import pandas as pd
  import dask.dataframe as dd
  from typing import Iterable, Optional, List, Tuple, Union
@@ -192,4 +195,38 @@ class HiveDatePartitionedStore:
  clauses.append([("yyyy","==",y)])
  for m in range(1, eM):
  clauses.append([("yyyy","==",eY),("mm","==",m)])
- return clauses
+ return clauses
+
+ async def write_async(
+ self,
+ df: dd.DataFrame,
+ *,
+ repartition: int | None = None,
+ overwrite: bool = False,
+ timeout: float | None = None,
+ ) -> None:
+ async def _run():
+ return await asyncio.to_thread(self.write, df, repartition=repartition, overwrite=overwrite)
+
+ return await (asyncio.wait_for(_run(), timeout) if timeout else _run())
+
+ async def read_range_async(
+ self,
+ start, end, *, columns: Iterable[str] | None = None, timeout: float | None = None
+ ) -> dd.DataFrame:
+ async def _run():
+ return await asyncio.to_thread(self.read_range, start, end, columns=columns)
+
+ return await (asyncio.wait_for(_run(), timeout) if timeout else _run())
+
+ async def read_month_async(self, year: int, month: int, *, columns=None, timeout: float | None = None):
+ async def _run():
+ return await asyncio.to_thread(self.read_month, year, month, columns=columns)
+
+ return await (asyncio.wait_for(_run(), timeout) if timeout else _run())
+
+ async def read_day_async(self, year: int, month: int, day: int, *, columns=None, timeout: float | None = None):
+ async def _run():
+ return await asyncio.to_thread(self.read_day, year, month, day, columns=columns)
+
+ return await (asyncio.wait_for(_run(), timeout) if timeout else _run())
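These wrappers offload the existing blocking write/read methods to a worker thread via asyncio.to_thread, with an optional asyncio.wait_for timeout. A usage sketch follows, assuming an already-configured HiveDatePartitionedStore instance (its constructor is not shown in this diff); the sample frame and the string date arguments passed to read_range_async are illustrative only:

import asyncio
import pandas as pd
import dask.dataframe as dd

async def refresh(store) -> dd.DataFrame:
    ddf = dd.from_pandas(pd.DataFrame({"yyyy": [2025], "mm": [8], "value": [1.0]}), npartitions=1)
    # Run the blocking write() off the event loop; give up after 10 minutes.
    await store.write_async(ddf, repartition=1, overwrite=False, timeout=600)
    # Same pattern for reads: read_range() in a thread, with an optional timeout.
    return await store.read_range_async("2025-01-01", "2025-08-31", columns=["value"], timeout=120)

# asyncio.run(refresh(store))  # with a configured HiveDatePartitionedStore instance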