sibi-dst 2025.8.6__py3-none-any.whl → 2025.8.8__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the differences between the two versions as they appear in their public registries.
@@ -363,349 +363,3 @@ class Logger:
363
363
  self._otel_initialized_names.add(self.logger_name)
364
364
  self._core_logger.info("OpenTelemetry logging/tracing initialized.")
365
365
 
366
- # from __future__ import annotations
367
- #
368
- # import logging
369
- # import os
370
- # import sys
371
- # import time
372
- # from contextlib import contextmanager
373
- # from logging import LoggerAdapter
374
- # from logging.handlers import RotatingFileHandler
375
- # from typing import Optional, Dict, Any
376
- #
377
- # # OpenTelemetry imports
378
- # from opentelemetry import trace
379
- # from opentelemetry._logs import set_logger_provider
380
- # from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter
381
- # from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
382
- # from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler
383
- # from opentelemetry.sdk._logs.export import BatchLogRecordProcessor
384
- # from opentelemetry.sdk.resources import Resource
385
- # from opentelemetry.sdk.trace import TracerProvider
386
- # from opentelemetry.sdk.trace.export import BatchSpanProcessor
387
- # from opentelemetry.trace import Tracer as OTelTracer
388
- #
389
- #
390
- # class Logger:
391
- # """
392
- # Handles the creation and management of logging, with optional OpenTelemetry integration.
393
- # """
394
- #
395
- # DEBUG = logging.DEBUG
396
- # INFO = logging.INFO
397
- # WARNING = logging.WARNING
398
- # ERROR = logging.ERROR
399
- # CRITICAL = logging.CRITICAL
400
- #
401
- # def __init__(
402
- # self,
403
- # log_dir: str,
404
- # logger_name: str,
405
- # log_file: str,
406
- # log_level: int = logging.DEBUG,
407
- # enable_otel: bool = False,
408
- # otel_service_name: Optional[str] = None,
409
- # otel_stream_name: Optional[str] = None,
410
- # otel_endpoint: str = "0.0.0.0:4317",
411
- # otel_insecure: bool = False,
412
- # ):
413
- # self.log_dir = log_dir
414
- # self.logger_name = logger_name
415
- # self.log_file = log_file
416
- # self.log_level = log_level
417
- #
418
- # self.enable_otel = enable_otel
419
- # self.otel_service_name = (otel_service_name or logger_name).strip() or "app"
420
- # self.otel_stream_name = (otel_stream_name or "").strip() or None
421
- # self.otel_endpoint = otel_endpoint
422
- # self.otel_insecure = otel_insecure
423
- #
424
- # self.logger_provider: Optional[LoggerProvider] = None
425
- # self.tracer_provider: Optional[TracerProvider] = None
426
- # self.tracer: Optional[OTelTracer] = None
427
- #
428
- # # Internal logger vs public (adapter) logger
429
- # self._core_logger: logging.Logger = logging.getLogger(self.logger_name)
430
- # self.logger: logging.Logger | LoggerAdapter = self._core_logger
431
- #
432
- # self._setup()
433
- #
434
- # # -------------------------
435
- # # Public API
436
- # # -------------------------
437
- #
438
- # @classmethod
439
- # def default_logger(
440
- # cls,
441
- # log_dir: str = "./logs/",
442
- # logger_name: Optional[str] = None,
443
- # log_file: Optional[str] = None,
444
- # log_level: int = logging.INFO,
445
- # enable_otel: bool = False,
446
- # otel_service_name: Optional[str] = None,
447
- # otel_stream_name: Optional[str] = None,
448
- # otel_endpoint: str = "0.0.0.0:4317",
449
- # otel_insecure: bool = False,
450
- # ) -> "Logger":
451
- # try:
452
- # frame = sys._getframe(1)
453
- # caller_name = frame.f_globals.get("__name__", "default_logger")
454
- # except (AttributeError, ValueError):
455
- # caller_name = "default_logger"
456
- #
457
- # logger_name = logger_name or caller_name
458
- # log_file = log_file or logger_name
459
- #
460
- # return cls(
461
- # log_dir=log_dir,
462
- # logger_name=logger_name,
463
- # log_file=log_file,
464
- # log_level=log_level,
465
- # enable_otel=enable_otel,
466
- # otel_service_name=otel_service_name,
467
- # otel_stream_name=otel_stream_name,
468
- # otel_endpoint=otel_endpoint,
469
- # otel_insecure=otel_insecure,
470
- # )
471
- #
472
- # def shutdown(self):
473
- # """Flush and shut down logging and tracing providers, then Python logging."""
474
- # try:
475
- # if self.enable_otel:
476
- # if self.logger_provider:
477
- # try:
478
- # self._core_logger.info("Flushing OpenTelemetry logs...")
479
- # self.logger_provider.force_flush()
480
- # except Exception:
481
- # pass
482
- # try:
483
- # self._core_logger.info("Shutting down OpenTelemetry logs...")
484
- # self.logger_provider.shutdown()
485
- # except Exception:
486
- # pass
487
- #
488
- # if self.tracer_provider:
489
- # try:
490
- # self._core_logger.info("Flushing OpenTelemetry traces...")
491
- # self.tracer_provider.force_flush()
492
- # except Exception:
493
- # pass
494
- # try:
495
- # self._core_logger.info("Shutting down OpenTelemetry traces...")
496
- # self.tracer_provider.shutdown()
497
- # except Exception:
498
- # pass
499
- # finally:
500
- # logging.shutdown()
501
- #
502
- # def set_level(self, level: int):
503
- # """Set the logging level for the logger."""
504
- # self._core_logger.setLevel(level)
505
- #
506
- # # passthrough convenience methods
507
- # def _log(self, level: int, msg: str, *args, **kwargs):
508
- # extra = kwargs.pop("extra", None)
509
- # if extra is not None:
510
- # # Always emit via an adapter so extras survive to OTel attributes
511
- # if isinstance(self.logger, LoggerAdapter):
512
- # merged = {**self.logger.extra, **extra}
513
- # LoggerAdapter(self.logger.logger, merged).log(level, msg, *args, **kwargs)
514
- # else:
515
- # LoggerAdapter(self.logger, extra).log(level, msg, *args, **kwargs)
516
- # else:
517
- # self.logger.log(level, msg, *args, **kwargs)
518
- #
519
- # def debug(self, msg: str, *args, **kwargs):
520
- # self._log(logging.DEBUG, msg, *args, **kwargs)
521
- #
522
- # def info(self, msg: str, *args, **kwargs):
523
- # self._log(logging.INFO, msg, *args, **kwargs)
524
- #
525
- # def warning(self, msg: str, *args, **kwargs):
526
- # self._log(logging.WARNING, msg, *args, **kwargs)
527
- #
528
- # def error(self, msg: str, *args, **kwargs):
529
- # self._log(logging.ERROR, msg, *args, **kwargs)
530
- #
531
- # def critical(self, msg: str, *args, **kwargs):
532
- # self._log(logging.CRITICAL, msg, *args, **kwargs)
533
- #
534
- #
535
- # def bind(self, **extra: Any) -> LoggerAdapter:
536
- # """
537
- # Return a new LoggerAdapter bound with extra context, merging with existing extras if present.
538
- # Example:
539
- # api_log = logger.bind(component="api", request_id=req.id)
540
- # api_log.info("processing")
541
- # """
542
- # if isinstance(self.logger, LoggerAdapter):
543
- # merged = {**self.logger.extra, **extra}
544
- # return LoggerAdapter(self.logger.logger, merged)
545
- # return LoggerAdapter(self.logger, extra)
546
- #
547
- # @contextmanager
548
- # def bound(self, **extra: Any):
549
- # """
550
- # Context manager that yields a bound adapter for temporary context.
551
- # Example:
552
- # with logger.bound(order_id=oid) as log:
553
- # log.info("starting")
554
- # ...
555
- # """
556
- # adapter = self.bind(**extra)
557
- # try:
558
- # yield adapter
559
- # finally:
560
- # # nothing to clean up; adapter is ephemeral
561
- # pass
562
- #
563
- # def start_span(self, name: str, attributes: Optional[Dict[str, Any]] = None):
564
- # """
565
- # Start a span as a context manager.
566
- #
567
- # Usage:
568
- # with logger.start_span("my-task", {"key": "value"}) as span:
569
- # ...
570
- # """
571
- # if not self.enable_otel or not self.tracer:
572
- # self.warning("Tracing is disabled or not initialized. Cannot start span.")
573
- # from contextlib import nullcontext
574
- # return nullcontext()
575
- #
576
- # cm = self.tracer.start_as_current_span(name)
577
- #
578
- # class _SpanCtx:
579
- # def __enter__(_self):
580
- # span = cm.__enter__()
581
- # if attributes:
582
- # for k, v in attributes.items():
583
- # try:
584
- # span.set_attribute(k, v)
585
- # except Exception:
586
- # pass
587
- # return span
588
- #
589
- # def __exit__(_self, exc_type, exc, tb):
590
- # return cm.__exit__(exc_type, exc, tb)
591
- #
592
- # return _SpanCtx()
593
- #
594
- # def trace_function(self, span_name: Optional[str] = None):
595
- # """Decorator to trace a function with an optional custom span name."""
596
- # def decorator(func):
597
- # def wrapper(*args, **kwargs):
598
- # name = span_name or func.__name__
599
- # with self.start_span(name):
600
- # return func(*args, **kwargs)
601
- # return wrapper
602
- # return decorator
603
- #
604
- # # -------------------------
605
- # # Internal setup
606
- # # -------------------------
607
- #
608
- # def _setup(self):
609
- # """Set up core logger, handlers, and optional OTel."""
610
- # # Configure base logger
611
- # self._core_logger = logging.getLogger(self.logger_name)
612
- # self._core_logger.setLevel(self.log_level)
613
- # self._core_logger.propagate = False
614
- #
615
- # # Standard (file + console) handlers
616
- # self._setup_standard_handlers()
617
- #
618
- # # OTel handlers (logs + traces)
619
- # if self.enable_otel:
620
- # self._setup_otel_handler()
621
- #
622
- # # Public-facing logger (optionally wrapped with adapter extras)
623
- # if self.enable_otel and self.otel_stream_name:
624
- # attributes = {
625
- # "log_stream": self.otel_stream_name,
626
- # "log_service_name": self.otel_service_name,
627
- # "logger_name": self.logger_name,
628
- # }
629
- # self.logger = LoggerAdapter(self._core_logger, extra=attributes)
630
- # else:
631
- # self.logger = self._core_logger
632
- #
633
- # def _setup_standard_handlers(self):
634
- # """Sets up file and console handlers with deduping."""
635
- # os.makedirs(self.log_dir, exist_ok=True)
636
- # calling_script = os.path.splitext(os.path.basename(sys.argv[0]))[0]
637
- # log_file_path = os.path.join(self.log_dir, f"{self.log_file}_{calling_script}.log")
638
- #
639
- # formatter = logging.Formatter(
640
- # "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s",
641
- # datefmt="%Y-%m-%d %H:%M:%S",
642
- # )
643
- # formatter.converter = time.gmtime # UTC timestamps
644
- #
645
- # # File handler (dedupe by filename)
646
- # if not any(
647
- # isinstance(h, RotatingFileHandler) and getattr(h, "baseFilename", "") == os.path.abspath(log_file_path)
648
- # for h in self._core_logger.handlers
649
- # ):
650
- # file_handler = RotatingFileHandler(
651
- # log_file_path, maxBytes=5 * 1024 * 1024, backupCount=5, delay=True
652
- # )
653
- # file_handler.setFormatter(formatter)
654
- # self._core_logger.addHandler(file_handler)
655
- #
656
- # # Console handler (dedupe by stream)
657
- # if not any(
658
- # isinstance(h, logging.StreamHandler) and getattr(h, "stream", None) is sys.stdout
659
- # for h in self._core_logger.handlers
660
- # ):
661
- # console_handler = logging.StreamHandler(sys.stdout)
662
- # console_handler.setFormatter(formatter)
663
- # self._core_logger.addHandler(console_handler)
664
- #
665
- # def _normalize_otlp_endpoint(self, ep: str) -> str:
666
- # """Ensure OTLP gRPC endpoint has a scheme."""
667
- # if "://" not in ep:
668
- # ep = ("http://" if self.otel_insecure else "https://") + ep
669
- # return ep
670
- #
671
- # def _setup_otel_handler(self):
672
- # """
673
- # Configure OpenTelemetry providers, exporters, and attach a LoggingHandler.
674
- # - service.name: used by most backends (incl. OpenObserve) to group streams/services.
675
- # - log.stream: extra attribute you can filter on in the backend.
676
- # """
677
- # resource_attrs = {
678
- # "service.name": self.otel_service_name,
679
- # "logger.name": self.logger_name,
680
- # }
681
- # if self.otel_stream_name:
682
- # resource_attrs["log.stream"] = self.otel_stream_name
683
- #
684
- # resource = Resource.create(resource_attrs)
685
- #
686
- # # Logs provider
687
- # self.logger_provider = LoggerProvider(resource=resource)
688
- # set_logger_provider(self.logger_provider)
689
- #
690
- # # Traces provider
691
- # self.tracer_provider = TracerProvider(resource=resource)
692
- # trace.set_tracer_provider(self.tracer_provider)
693
- #
694
- # endpoint = self._normalize_otlp_endpoint(self.otel_endpoint)
695
- #
696
- # # Logs exporter + processor
697
- # log_exporter = OTLPLogExporter(endpoint=endpoint, insecure=self.otel_insecure)
698
- # self.logger_provider.add_log_record_processor(BatchLogRecordProcessor(log_exporter))
699
- #
700
- # # Traces exporter + processor
701
- # span_exporter = OTLPSpanExporter(endpoint=endpoint, insecure=self.otel_insecure)
702
- # self.tracer_provider.add_span_processor(BatchSpanProcessor(span_exporter))
703
- # self.tracer = trace.get_tracer(self.logger_name, tracer_provider=self.tracer_provider)
704
- #
705
- # # Attach OTel LoggingHandler once
706
- # if not any(type(h).__name__ == "LoggingHandler" for h in self._core_logger.handlers):
707
- # otel_handler = LoggingHandler(level=logging.NOTSET, logger_provider=self.logger_provider)
708
- # self._core_logger.addHandler(otel_handler)
709
- #
710
- # self._core_logger.info("OpenTelemetry logging and tracing enabled and attached.")
711
- #
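Usage sketch (illustrative, not part of either package version): the hunk above only deletes a commented-out duplicate of the Logger implementation, so the sketch below assumes the live Logger class keeps the API documented in that removed copy (default_logger, bind, start_span, shutdown); the logger name and OTel settings are placeholders.

# Illustrative only: exercises the Logger API documented in the removed commented-out copy.
from sibi_dst.utils import Logger

log = Logger.default_logger(
    logger_name="demo",
    log_level=Logger.INFO,
    enable_otel=False,            # set True (with a reachable OTLP endpoint) to export logs/traces
)

req_log = log.bind(component="api", request_id="r-123")   # LoggerAdapter carrying extra context
req_log.info("processing request")

with log.start_span("demo-task", {"items": 3}):           # no-op context when OTel is disabled
    log.debug("inside span")

log.shutdown()                                             # flush handlers (and OTel providers when enabled)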
@@ -1,5 +1,7 @@
1
1
  import warnings
2
2
 
3
+ from pandas.api.types import is_period_dtype, is_bool_dtype, is_string_dtype
4
+ import pandas as pd
3
5
  import dask.dataframe as dd
4
6
  import pyarrow as pa
5
7
 
@@ -15,6 +17,7 @@ class ParquetSaver(ManagedResource):
15
17
 
16
18
  Assumes `df_result` is a Dask DataFrame.
17
19
  """
20
+ logger_extra = {"sibi_dst_component": __name__}
18
21
 
19
22
  def __init__(
20
23
  self,
@@ -32,6 +35,10 @@ class ParquetSaver(ManagedResource):
32
35
  if "://" in self.parquet_storage_path:
33
36
  self.protocol = self.parquet_storage_path.split(":", 1)[0]
34
37
 
38
+ self.persist = kwargs.get("persist",True)
39
+ self.write_index = kwargs.get("write_index", False)
40
+ self.write_metadata_file = kwargs.get("write_metadata_file", True)
41
+
35
42
  def save_to_parquet(self, output_directory_name: str = "default_output", overwrite: bool = True):
36
43
  """
37
44
  Saves the Dask DataFrame to a Parquet dataset.
@@ -42,17 +49,19 @@ class ParquetSaver(ManagedResource):
42
49
  full_path = f"{self.parquet_storage_path}/{output_directory_name}"
43
50
 
44
51
  if overwrite and self.fs and self.fs.exists(full_path):
45
- self.logger.info(f"Overwrite is True, clearing destination path: {full_path}")
52
+ self.logger.info(f"Overwrite is True, clearing destination path: {full_path}", extra=self.logger_extra)
46
53
  self._clear_directory_safely(full_path)
47
54
 
48
55
  # Ensure the base directory exists after clearing
49
56
  self.fs.mkdirs(full_path, exist_ok=True)
50
57
 
51
58
  schema = self._define_schema()
52
- self.logger.info(f"Saving DataFrame to Parquet dataset at: {full_path}")
59
+ self.logger.info(f"Saving DataFrame to Parquet dataset at: {full_path}", extra=self.logger_extra)
60
+ # 1) Normalize to declared schema (fixes bool→string, Period→string, etc.)
61
+ ddf = self._coerce_ddf_to_schema(self.df_result, schema)
53
62
 
54
- # persist then write (lets the graph be shared if the caller reuses it)
55
- ddf = self.df_result.persist()
63
+ # 2) Persist after coercion so all partitions share the coerced dtypes
64
+ ddf = ddf.persist() if self.persist else ddf
56
65
 
57
66
  try:
58
67
  ddf.to_parquet(
@@ -61,11 +70,12 @@ class ParquetSaver(ManagedResource):
61
70
  schema=schema,
62
71
  overwrite=False, # we've handled deletion already
63
72
  filesystem=self.fs,
64
- write_index=False,
73
+ write_index=self.write_index, # whether to write the index
74
+ write_metadata_file=self.write_metadata_file, # write _metadata for easier reading later
65
75
  )
66
- self.logger.info(f"Successfully saved Parquet dataset to: {full_path}")
76
+ self.logger.info(f"Successfully saved Parquet dataset to: {full_path}", extra=self.logger_extra)
67
77
  except Exception as e:
68
- self.logger.error(f"Failed to save Parquet dataset to {full_path}: {e}")
78
+ self.logger.error(f"Failed to save Parquet dataset to {full_path}: {e}", extra=self.logger_extra)
69
79
  raise
70
80
 
71
81
  def _clear_directory_safely(self, directory: str):
@@ -91,7 +101,7 @@ class ParquetSaver(ManagedResource):
91
101
  else:
92
102
  self.fs.rm(path, recursive=False)
93
103
  except Exception as e:
94
- self.logger.warning(f"Failed to delete '{path}': {e}")
104
+ self.logger.warning(f"Failed to delete '{path}': {e}", extra=self.logger_extra)
95
105
  # remove the (now empty) directory if present
96
106
  try:
97
107
  self.fs.rm(directory, recursive=False)
@@ -120,4 +130,95 @@ class ParquetSaver(ManagedResource):
120
130
  pa.field(c, pandas_dtype_to_pa.get(str(d), pa.string()))
121
131
  for c, d in self.df_result.dtypes.items()
122
132
  ]
123
- return pa.schema(fields)
133
+ return pa.schema(fields)
134
+
135
+
136
+ def _coerce_ddf_to_schema(self, ddf: dd.DataFrame, schema: pa.Schema) -> dd.DataFrame:
137
+ """
138
+ Coerce Dask DataFrame columns to match the provided PyArrow schema.
139
+ - Ensures cross-partition consistency.
140
+ - Converts troublesome dtypes (Period, mixed object/bool) to the declared type.
141
+ """
142
+ # Build a map: name -> target kind
143
+ target = {field.name: field.type for field in schema}
144
+
145
+ def _coerce_partition(pdf: pd.DataFrame) -> pd.DataFrame:
146
+ for col, typ in target.items():
147
+ if col not in pdf.columns:
148
+ continue
149
+
150
+ pa_type = typ
151
+
152
+ # String targets
153
+ if pa.types.is_string(pa_type) or pa.types.is_large_string(pa_type):
154
+ # Convert Period or any dtype to string with NA-preservation
155
+ s = pdf[col]
156
+ if is_period_dtype(s):
157
+ pdf[col] = s.astype(str)
158
+ elif not is_string_dtype(s):
159
+ # astype("string") keeps NA; str(s) can produce "NaT" strings
160
+ try:
161
+ pdf[col] = s.astype("string")
162
+ except Exception:
163
+ pdf[col] = s.astype(str).astype("string")
164
+ continue
165
+
166
+ # Boolean targets
167
+ if pa.types.is_boolean(pa_type):
168
+ s = pdf[col]
169
+ # Allow object/bool mixtures; coerce via pandas nullable boolean then to bool
170
+ try:
171
+ pdf[col] = s.astype("boolean").astype(bool)
172
+ except Exception:
173
+ pdf[col] = s.astype(bool)
174
+ continue
175
+
176
+ # Integer targets
177
+ if pa.types.is_integer(pa_type):
178
+ s = pdf[col]
179
+ # Go through pandas nullable Int64 to preserve NA, then to int64 if clean
180
+ s2 = pd.to_numeric(s, errors="coerce").astype("Int64")
181
+ # If there are no nulls, downcast to numpy int64 for speed
182
+ if not s2.isna().any():
183
+ s2 = s2.astype("int64")
184
+ pdf[col] = s2
185
+ continue
186
+
187
+ # Floating targets
188
+ if pa.types.is_floating(pa_type):
189
+ pdf[col] = pd.to_numeric(pdf[col], errors="coerce").astype("float64")
190
+ continue
191
+
192
+ # Timestamp[ns] (optionally with tz)
193
+ if pa.types.is_timestamp(pa_type):
194
+ # If tz in Arrow type, you may want to localize; here we just ensure ns
195
+ pdf[col] = pd.to_datetime(pdf[col], errors="coerce")
196
+ continue
197
+
198
+ # Fallback: leave as-is
199
+ return pdf
200
+
201
+ # Provide a meta with target dtypes to avoid meta mismatch warnings
202
+ meta = {}
203
+ for name, typ in target.items():
204
+ # Rough meta mapping; Arrow large_string vs string both → 'string'
205
+ if pa.types.is_string(typ) or pa.types.is_large_string(typ):
206
+ meta[name] = pd.Series([], dtype="string")
207
+ elif pa.types.is_boolean(typ):
208
+ meta[name] = pd.Series([], dtype="bool")
209
+ elif pa.types.is_integer(typ):
210
+ meta[name] = pd.Series([], dtype="Int64") # nullable int
211
+ elif pa.types.is_floating(typ):
212
+ meta[name] = pd.Series([], dtype="float64")
213
+ elif pa.types.is_timestamp(typ):
214
+ meta[name] = pd.Series([], dtype="datetime64[ns]")
215
+ else:
216
+ meta[name] = pd.Series([], dtype="object")
217
+
218
+ # Start from current meta and update known columns
219
+ new_meta = ddf._meta.copy()
220
+ for k, v in meta.items():
221
+ if k in new_meta.columns:
222
+ new_meta[k] = v
223
+
224
+ return ddf.map_partitions(_coerce_partition, meta=new_meta)
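Usage sketch (illustrative, not part of either package version): the ParquetSaver constructor signature is not shown in this diff, so the keyword arguments below (df_result, parquet_storage_path, fs) and the import path are assumptions inferred from attribute usage in the hunks above; the persist/write_index/write_metadata_file options come directly from the changed __init__.

# Illustrative only: constructor arguments and import path are assumptions, not shown in this diff.
import dask.dataframe as dd
import fsspec
import pandas as pd

from sibi_dst.utils import ParquetSaver   # import path assumed

pdf = pd.DataFrame({
    "id": [1, 2],
    "period": pd.period_range("2024-01", periods=2, freq="M"),   # Period dtype -> coerced to string
    "ok": [True, False],
})
ddf = dd.from_pandas(pdf, npartitions=1)

saver = ParquetSaver(
    df_result=ddf,
    parquet_storage_path="/tmp/sibi_dst_demo",
    fs=fsspec.filesystem("file"),
    persist=False,               # new in 2025.8.8: skip ddf.persist() before writing
    write_index=False,           # new: forwarded to to_parquet
    write_metadata_file=True,    # new: emit a _metadata file alongside the dataset
)
saver.save_to_parquet(output_directory_name="orders", overwrite=True)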
@@ -0,0 +1,5 @@
1
+ from .sse_runner import SSERunner, _as_sse_msg
2
+
3
+ __all__ = [
4
+ "SSERunner", "_as_sse_msg"
5
+ ]
@@ -0,0 +1,82 @@
1
+ # jobs.py
2
+ import asyncio, json, uuid
3
+ import contextlib
4
+ import os
5
+
6
+ import redis.asyncio as redis
7
+ from fastapi import APIRouter
8
+ from sse_starlette.sse import EventSourceResponse
9
+ host = os.getenv("REDIS_HOST", "0.0.0.0")
10
+ port = int(os.getenv("REDIS_PORT", 6379))
11
+ db = int(os.getenv("REDIS_DB", 0))
12
+ router = APIRouter(prefix="/jobs", tags=["Jobs"])
13
+ r = redis.Redis(host=host, port=port, db=db, decode_responses=True) # strings for pubsub
14
+
15
+ CHANNEL = lambda job_id: f"job:{job_id}:events"
16
+ KEY_STATUS = lambda job_id: f"job:{job_id}:status" # JSON blob with state/progress
17
+ KEY_RESULT = lambda job_id: f"job:{job_id}:result" # final payload
18
+
19
+ async def publish(job_id: str, event: str, data: dict):
20
+ msg = json.dumps({"event": event, "data": data})
21
+ await r.publish(CHANNEL(job_id), msg)
22
+ # store last status
23
+ await r.set(KEY_STATUS(job_id), json.dumps({"event": event, "data": data}))
24
+
25
+ # ---- Worker entry (can live in a separate process) ----
26
+ async def run_job(job_id: str):
27
+ try:
28
+ await publish(job_id, "progress", {"message": "Initializing..."})
29
+ # ... do actual work, emit more progress
30
+ await asyncio.sleep(0.2)
31
+ # compute result
32
+ result = [{"id": 1, "ok": True}]
33
+ await r.set(KEY_RESULT(job_id), json.dumps(result), ex=3600)
34
+ await publish(job_id, "complete", {"records": len(result)})
35
+ except Exception as e:
36
+ await publish(job_id, "error", {"detail": str(e)})
37
+
38
+ # ---- API ----
39
+ @router.post("/start")
40
+ async def start_job():
41
+ job_id = str(uuid.uuid4())
42
+ # enqueue: prefer Celery/RQ/etc. For demo we detach a task.
43
+ asyncio.create_task(run_job(job_id))
44
+ return {"job_id": job_id}
45
+
46
+ @router.get("/{job_id}/stream")
47
+ async def stream(job_id: str):
48
+ pubsub = r.pubsub()
49
+ await pubsub.subscribe(CHANNEL(job_id))
50
+
51
+ async def gen():
52
+ try:
53
+ # emit latest known status immediately, if any
54
+ if (s := await r.get(KEY_STATUS(job_id))):
55
+ yield {"event": "progress", "data": s}
56
+ while True:
57
+ msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=30.0)
58
+ if msg and msg["type"] == "message":
59
+ payload = msg["data"] # already a JSON string
60
+ yield {"event": "message", "data": payload}
61
+ await asyncio.sleep(0.01)
62
+ finally:
63
+ with contextlib.suppress(Exception):
64
+ await pubsub.unsubscribe(CHANNEL(job_id))
65
+ await pubsub.close()
66
+
67
+ return EventSourceResponse(
68
+ gen(),
69
+ ping=15,
70
+ headers={"Cache-Control": "no-cache, no-transform", "X-Accel-Buffering": "no"},
71
+ )
72
+
73
+ @router.get("/{job_id}/status")
74
+ async def status(job_id: str):
75
+ s = await r.get(KEY_STATUS(job_id))
76
+ done = await r.exists(KEY_RESULT(job_id))
77
+ return {"job_id": job_id, "status": json.loads(s) if s else None, "done": bool(done)}
78
+
79
+ @router.get("/{job_id}/result")
80
+ async def result(job_id: str):
81
+ data = await r.get(KEY_RESULT(job_id))
82
+ return {"job_id": job_id, "result": json.loads(data) if data else None}
@@ -0,0 +1,82 @@
1
+ from __future__ import annotations
2
+ import asyncio, contextlib, inspect, json
3
+ from typing import Any, Awaitable, Callable, Dict, Optional, Union
4
+ from fastapi import Request
5
+ from sse_starlette.sse import EventSourceResponse
6
+ from sibi_dst.utils import Logger
7
+
8
+ Payload = Union[str, bytes, dict, list, None]
9
+ Task2 = Callable[[asyncio.Queue, str], Awaitable[Payload]]
10
+ Task3 = Callable[[asyncio.Queue, str, Dict[str, Any]], Awaitable[Payload]]
11
+ TaskFn = Union[Task2, Task3]
12
+
13
+ def _as_sse_msg(event: str, data: Any) -> dict:
14
+ return {"event": event, "data": json.dumps(data) if not isinstance(data, (str, bytes)) else data}
15
+
16
+ class SSERunner:
17
+ def __init__(self, *, task: TaskFn, logger: Logger, ping: int = 15,
18
+ headers: Optional[dict] = None, auto_complete: bool = True) -> None:
19
+ self.task = task
20
+ self.logger = logger
21
+ self.ping = ping
22
+ self.headers = headers or {"Cache-Control": "no-cache, no-transform", "X-Accel-Buffering": "no"}
23
+ self.auto_complete = auto_complete
24
+ self._expects_ctx = len(inspect.signature(task).parameters) >= 3
25
+
26
+ async def _call_task(self, queue: asyncio.Queue, task_id: str, ctx: Dict[str, Any]) -> Payload:
27
+ if self._expects_ctx:
28
+ return await self.task(queue, task_id, ctx) # type: ignore[misc]
29
+ return await self.task(queue, task_id) # type: ignore[misc]
30
+
31
+ async def _worker(self, queue: asyncio.Queue, task_id: str, ctx: Dict[str, Any]) -> None:
32
+ self.logger.info(f"SSE {task_id}: start")
33
+ try:
34
+ await queue.put(_as_sse_msg("progress", {"message": "Task started"}))
35
+ payload = await self._call_task(queue, task_id, ctx)
36
+ if self.auto_complete:
37
+ final = payload if payload is not None else {"status": "complete"}
38
+ await queue.put(_as_sse_msg("complete", final))
39
+ self.logger.info(f"SSE {task_id}: complete")
40
+ except asyncio.CancelledError:
41
+ raise
42
+ except Exception as e:
43
+ self.logger.error(f"SSE {task_id} failed: {e}", exc_info=True)
44
+ await queue.put(_as_sse_msg("error", {"detail": str(e)}))
45
+ finally:
46
+ await queue.put(None)
47
+
48
+ def endpoint(self):
49
+ async def handler(request: Request): # <-- only Request
50
+ queue: asyncio.Queue = asyncio.Queue()
51
+ task_id = str(asyncio.get_running_loop().time()).replace(".", "")
52
+
53
+ ctx: Dict[str, Any] = {
54
+ "path": dict(request.path_params), # <-- pull path params here
55
+ "query": dict(request.query_params),
56
+ "method": request.method,
57
+ }
58
+ if request.headers.get("content-type", "").startswith("application/json"):
59
+ try:
60
+ ctx["body"] = await request.json()
61
+ except Exception:
62
+ ctx["body"] = None
63
+
64
+ worker = asyncio.create_task(self._worker(queue, task_id, ctx))
65
+
66
+ async def gen():
67
+ try:
68
+ while True:
69
+ msg = await queue.get()
70
+ if msg is None:
71
+ break
72
+ yield msg
73
+ finally:
74
+ if not worker.done():
75
+ worker.cancel()
76
+ with contextlib.suppress(Exception):
77
+ await worker
78
+
79
+ return EventSourceResponse(gen(), ping=self.ping, headers=self.headers)
80
+ return handler
81
+
82
+ __all__ = ["SSERunner", "_as_sse_msg"]