sibi-dst 2025.9.9__py3-none-any.whl → 2025.9.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -37,7 +37,6 @@ class DataWrapper(ManagedResource):
  dataclass: Type,
  date_field: str,
  data_path: str,
- parquet_filename: str,
  class_params: Optional[Dict] = None,
  load_params: Optional[Dict] = None,
  show_progress: bool = False,
@@ -50,7 +49,7 @@ class DataWrapper(ManagedResource):
  self.dataclass: Type = dataclass
  self.date_field: str = date_field
  self.data_path: str = self._ensure_forward_slash(data_path)
- self.parquet_filename: str = parquet_filename
+ self.partition_on_date: bool = True # Assume Hive-style date partitioning by default

  if self.fs is None:
  raise ValueError("DataWrapper requires a File system (fs) to be provided.")
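
Note: parquet_filename is gone from the constructor; per-date output is now governed by the partition_on_date flag, which defaults to True. A minimal sketch of the new call surface, assuming a hypothetical MyDataset loader class and an fsspec filesystem handed in through **kwargs (not spelled out in this diff):

    import fsspec

    fs = fsspec.filesystem("file")  # assumption: any fsspec-compatible filesystem
    wrapper = DataWrapper(
        dataclass=MyDataset,            # hypothetical class exposing .load(**filters)
        date_field="created_at",
        data_path="/data/my_dataset/",  # dataset root; per-date partitions are derived from it
        fs=fs,
        show_progress=True,
    )
    # parquet_filename is no longer accepted; Hive-style date partitioning applies by default.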
@@ -282,16 +281,23 @@ class DataWrapper(ManagedResource):
  def _process_single_date(self, date: datetime.date):
  """Process a single date: load, save to Parquet."""
  # --- 1. Setup paths and logging ---
- path = f"{self.data_path}{date.year}/{date.month:02d}/{date.day:02d}/"
- log_extra = self._log_extra(date_context=date.isoformat())
- self.logger.debug(f"Processing date {date.isoformat()} for {path}", extra=log_extra)
-
+ path = self.data_path.rstrip("/")+"/"
+ if not self.partition_on_date:
+ # not a Hive-style partitioned path
+ path = f"{self.data_path}{date.year}/{date.month:02d}/{date.day:02d}/"
+ log_extra = self._log_extra(date_context=date.isoformat())
+ self.logger.debug(f"Processing date {date.isoformat()} for legacy {path}", extra=log_extra)
+ else :
+ # Hive-style partitioned path
+ log_extra = self._log_extra(date_context=date.isoformat(), partition_on=self.date_field)
+ self.logger.debug(f"Processing date {date.isoformat()} for partitioned {self.data_path} with hive-style partitions", extra=log_extra)
  # --- 2. Check if date/path should be skipped ---
  if (self.update_planner and path in self.update_planner.skipped and
  getattr(self.update_planner, 'ignore_missing', False)):
  self.logger.debug(f"Skipping {date} as it exists in the skipped list", extra=log_extra)
  return
- full_path = f"{path}{self.parquet_filename}"
+
+ self.logger.debug(f"Processing date {date.isoformat()} for {path}", extra=log_extra)

  # --- 3. Timing ---
  overall_start = time.perf_counter()
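
Note: the branch above only changes where the writer points. For a data_path of /data/my_dataset/ (illustrative, not from this package), the two modes resolve to:

    import datetime

    date = datetime.date(2025, 9, 1)
    hive_root = "/data/my_dataset/"  # partition_on_date=True: write to the dataset root
    legacy_path = f"/data/my_dataset/{date.year}/{date.month:02d}/{date.day:02d}/"
    # legacy_path == "/data/my_dataset/2025/09/01/"  (partition_on_date=False)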
@@ -326,39 +332,44 @@ class DataWrapper(ManagedResource):
  self.mmanifest.record(full_path=path)
  except Exception as e:
  self.logger.error(f"Failed to record missing path {path}: {e}", extra=log_extra)
- self.logger.info(f"No data found for {full_path}. Logged to missing manifest.", extra=log_extra)
+ self.logger.info(f"No data found for {path}. Logged to missing manifest.", extra=log_extra)
  return # Done for this date

  if total_records < 0:
- self.logger.warning(f"Negative record count ({total_records}) for {full_path}. Proceeding.", extra=log_extra)
+ self.logger.warning(f"Negative record count ({total_records}) for {path}. Proceeding.", extra=log_extra)
  # Continue processing even with negative count

- # --- 6. Save to Parquet ---
- save_start = time.perf_counter()
- parquet_params = {
- "df_result": df,
- "parquet_storage_path": path,
- "fs": self.fs,
- "logger": self.logger,
- "debug": self.debug,
- "verbose": self.verbose,
- }
- self.logger.debug(f"{self.dataclass.__name__} saving to parquet started...", extra=log_extra)
- with ParquetSaver(**parquet_params) as ps:
- ps.save_to_parquet(self.parquet_filename, overwrite=True)
- save_time = time.perf_counter() - save_start
- self.logger.debug(f"Parquet saving for {date} completed in {save_time:.2f}s", extra=log_extra)
-
- # --- 7. Benchmarking ---
- total_time = time.perf_counter() - overall_start
- self.benchmarks[date] = {
- "load_duration": load_time,
- "save_duration": save_time,
- "total_duration": total_time,
- }
-
- # --- 8. Log Success ---
- self._log_success(date, total_time, full_path)
+ # --- 6. Save to Parquet ---
+ save_start = time.perf_counter()
+
+
+ parquet_params = {
+ "df_result": df,
+ "parquet_storage_path": path,
+ "fs": self.fs,
+ "logger": self.logger,
+ "debug": self.debug,
+ "verbose": self.verbose,
+ }
+ if self.partition_on_date:
+ df["partition_date"] = df[self.date_field].dt.date.astype(str)
+ parquet_params["partition_on"] = ["partition_date"]
+ self.logger.debug(f"{self.dataclass.__name__} saving to parquet started...", extra=log_extra)
+ with ParquetSaver(**parquet_params) as ps:
+ ps.save_to_parquet()
+ save_time = time.perf_counter() - save_start
+ self.logger.debug(f"Parquet saving for {date} completed in {save_time:.2f}s", extra=log_extra)
+
+ # --- 7. Benchmarking ---
+ total_time = time.perf_counter() - overall_start
+ self.benchmarks[date] = {
+ "load_duration": load_time,
+ "save_duration": save_time,
+ "total_duration": total_time,
+ }
+
+ # --- 8. Log Success ---
+ self._log_success(date, total_time, path)

  except Exception as e:
  # --- 9. Handle Errors ---
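
Note: the partitioned save above amounts to tagging each row with its date string and letting Dask lay out Hive-style directories. A self-contained sketch of the same idea using plain dask.dataframe rather than the ParquetSaver wrapper (the path and column names are illustrative, not from this package):

    import pandas as pd
    import dask.dataframe as dd

    pdf = pd.DataFrame({
        "created_at": pd.to_datetime(["2025-09-01 10:00", "2025-09-01 12:30", "2025-09-02 08:15"]),
        "value": [1, 2, 3],
    })
    ddf = dd.from_pandas(pdf, npartitions=1)

    # Same derivation as above: a string partition column taken from the date field.
    ddf["partition_date"] = ddf["created_at"].dt.date.astype(str)

    # Produces /tmp/my_dataset/partition_date=2025-09-01/..., partition_date=2025-09-02/...
    ddf.to_parquet(
        "/tmp/my_dataset",
        engine="pyarrow",
        partition_on=["partition_date"],
        write_index=False,
    )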
@@ -397,280 +408,3 @@ class DataWrapper(ManagedResource):
  except Exception as e:
  self.logger.error(f"Error generating benchmark summary: {e}", extra=self.logger_extra)

- # import datetime
- # import random
- # import threading
- # import time
- # from concurrent.futures import ThreadPoolExecutor, as_completed
- # from typing import Type, Any, Dict, Optional, Union, List, ClassVar
- #
- # import pandas as pd
- # from tqdm import tqdm
- #
- # from . import ManagedResource
- # from .parquet_saver import ParquetSaver
- #
- #
- # class DataWrapper(ManagedResource):
- # DEFAULT_PRIORITY_MAP: ClassVar[Dict[str, int]] = {
- # "overwrite": 1,
- # "missing_in_history": 2,
- # "existing_but_stale": 3,
- # "missing_outside_history": 4,
- # "file_is_recent": 0,
- # }
- # DEFAULT_MAX_AGE_MINUTES: int = 1440
- # DEFAULT_HISTORY_DAYS_THRESHOLD: int = 30
- #
- # logger_extra = {"sibi_dst_component": __name__}
- #
- # def __init__(
- # self,
- # dataclass: Type,
- # date_field: str,
- # data_path: str,
- # parquet_filename: str,
- # class_params: Optional[Dict] = None,
- # load_params: Optional[Dict] = None,
- # show_progress: bool = False,
- # timeout: float = 30,
- # max_threads: int = 3,
- # **kwargs: Any,
- # ):
- # super().__init__(**kwargs)
- # self.dataclass = dataclass
- # self.date_field = date_field
- # self.data_path = self._ensure_forward_slash(data_path)
- # self.parquet_filename = parquet_filename
- # if self.fs is None:
- # raise ValueError("DataWrapper requires a File system (fs) to be provided.")
- # self.show_progress = show_progress
- # self.timeout = timeout
- # self.max_threads = max_threads
- # self.class_params = class_params or {
- # "debug": self.debug,
- # "logger": self.logger,
- # "fs": self.fs,
- # "verbose": self.verbose,
- # }
- # self.load_params = load_params or {}
- #
- # self._lock = threading.Lock()
- # self.processed_dates: List[datetime.date] = []
- # self.benchmarks: Dict[datetime.date, Dict[str, float]] = {}
- # self.mmanifest = kwargs.get("mmanifest", None)
- # self.update_planner = kwargs.get("update_planner", None)
- #
- # # --- NEW: stop gate tripped during cleanup/interrupt to block further scheduling/retries
- # self._stop_event = threading.Event()
- # self.logger_extra.update({"action_module_name": "data_wrapper", "dataclass": self.dataclass.__name__})
- #
- # # ensure manifest is saved on context exit
- # def __exit__(self, exc_type, exc_val, exc_tb):
- # if self.mmanifest:
- # self.mmanifest.save()
- # super().__exit__(exc_type, exc_val, exc_tb)
- # return False
- #
- # # --- NEW: trip stop gate during class-specific cleanup (close/aclose/finalizer path)
- # def _cleanup(self) -> None:
- # self._stop_event.set()
- #
- # @staticmethod
- # def _convert_to_date(date: Union[datetime.date, str]) -> datetime.date:
- # if isinstance(date, datetime.date):
- # return date
- # try:
- # return pd.to_datetime(date).date()
- # except ValueError as e:
- # raise ValueError(f"Error converting {date} to datetime: {e}")
- #
- # @staticmethod
- # def _ensure_forward_slash(path: str) -> str:
- # return path.rstrip("/") + "/"
- #
- # def process(
- # self,
- # max_retries: int = 3,
- # backoff_base: float = 2.0,
- # backoff_jitter: float = 0.1,
- # backoff_max: float = 60.0,
- # ):
- # """
- # Execute the update plan with concurrency, retries and exponential backoff.
- # Stops scheduling immediately if closed or interrupted (Ctrl-C).
- # """
- # overall_start = time.perf_counter()
- # tasks = list(self.update_planner.get_tasks_by_priority())
- # if not tasks:
- # self.logger.info("No updates required based on the current plan.")
- # return
- #
- # if self.update_planner.show_progress:
- # self.update_planner.show_update_plan()
- #
- # try:
- # for priority, dates in tasks:
- # if self._stop_event.is_set():
- # break
- # self._execute_task_batch(priority, dates, max_retries, backoff_base, backoff_jitter, backoff_max)
- # except KeyboardInterrupt:
- # self.logger.warning("KeyboardInterrupt received — stopping scheduling and shutting down.", extra=self.logger_extra)
- # self._stop_event.set()
- # raise
- # finally:
- # total_time = time.perf_counter() - overall_start
- # if self.processed_dates:
- # count = len(self.processed_dates)
- # self.logger.info(f"Processed {count} dates in {total_time:.1f}s (avg {total_time / count:.1f}s/date)", extra=self.logger_extra)
- # if self.update_planner.show_progress:
- # self.show_benchmark_summary()
- #
- # def _execute_task_batch(
- # self,
- # priority: int,
- # dates: List[datetime.date],
- # max_retries: int,
- # backoff_base: float,
- # backoff_jitter: float,
- # backoff_max: float,
- # ):
- # desc = f"Processing {self.dataclass.__name__}, priority: {priority}"
- # max_thr = min(len(dates), self.max_threads)
- # self.logger.info(f"Executing {len(dates)} tasks with priority {priority} using {max_thr} threads.", extra=self.logger_extra)
- #
- # # Use explicit try/finally so we can request cancel of queued tasks on teardown
- # executor = ThreadPoolExecutor(max_workers=max_thr, thread_name_prefix="datawrapper")
- # try:
- # futures = {}
- # for date in dates:
- # if self._stop_event.is_set():
- # break
- # try:
- # fut = executor.submit(
- # self._process_date_with_retry, date, max_retries, backoff_base, backoff_jitter, backoff_max
- # )
- # futures[fut] = date
- # except RuntimeError as e:
- # # tolerate race: executor shutting down
- # if "cannot schedule new futures after shutdown" in str(e).lower():
- # self.logger.warning("Executor is shutting down; halting new submissions for this batch.", extra=self.logger_extra)
- # break
- # raise
- #
- # iterator = as_completed(futures)
- # if self.show_progress:
- # iterator = tqdm(iterator, total=len(futures), desc=desc)
- #
- # for future in iterator:
- # try:
- # future.result(timeout=self.timeout)
- # except Exception as e:
- # self.logger.error(f"Permanent failure for {futures[future]}: {e}", extra=self.logger_extra)
- # finally:
- # # Python 3.9+: cancel_futures prevents queued tasks from starting
- # executor.shutdown(wait=True, cancel_futures=True)
- #
- # def _process_date_with_retry(
- # self,
- # date: datetime.date,
- # max_retries: int,
- # backoff_base: float,
- # backoff_jitter: float,
- # backoff_max: float,
- # ):
- # for attempt in range(max_retries):
- # # --- NEW: bail out quickly if shutdown/interrupt began
- # if self._stop_event.is_set():
- # raise RuntimeError("shutting_down")
- #
- # try:
- # self._process_single_date(date)
- # return
- # except Exception as e:
- # if attempt < max_retries - 1 and not self._stop_event.is_set():
- # base_delay = min(backoff_base ** attempt, backoff_max)
- # delay = base_delay * (1 + random.uniform(0.0, max(0.0, backoff_jitter)))
- # self.logger.warning(
- # f"Retry {attempt + 1}/{max_retries} for {date}: {e} (sleep {delay:.2f}s)",
- # extra=self.logger_extra
- # )
- # time.sleep(delay)
- # else:
- # self.logger.error(f"Failed processing {date} after {max_retries} attempts.", extra=self.logger_extra)
- # raise
- #
- # def _process_single_date(self, date: datetime.date):
- # path = f"{self.data_path}{date.year}/{date.month:02d}/{date.day:02d}/"
- # self.logger.debug(f"Processing date {date.isoformat()} for {path}", extra=self.logger_extra)
- # if path in self.update_planner.skipped and self.update_planner.ignore_missing:
- # self.logger.debug(f"Skipping {date} as it exists in the skipped list", extra=self.logger_extra)
- # return
- # full_path = f"{path}{self.parquet_filename}"
- #
- # overall_start = time.perf_counter()
- # try:
- # load_start = time.perf_counter()
- # date_filter = {f"{self.date_field}__date": {date.isoformat()}}
- # self.logger.debug(f"{self.dataclass.__name__} is loading data for {date} with filter: {date_filter}", extra=self.logger_extra)
- #
- # local_load_params = self.load_params.copy()
- # local_load_params.update(date_filter)
- #
- # with self.dataclass(**self.class_params) as local_class_instance:
- # df = local_class_instance.load(**local_load_params) # expected to be Dask
- # load_time = time.perf_counter() - load_start
- #
- # if hasattr(local_class_instance, "total_records"):
- # total_records = int(local_class_instance.total_records)
- # self.logger.debug(f"Total records loaded: {total_records}", extra=self.logger_extra)
- #
- # if total_records == 0:
- # if self.mmanifest:
- # self.mmanifest.record(full_path=path)
- # self.logger.info(f"No data found for {full_path}. Logged to missing manifest.", extra=self.logger_extra)
- # return
- #
- # if total_records < 0:
- # self.logger.warning(f"Negative record count ({total_records}) for {full_path}.", extra=self.logger_extra)
- # return
- #
- # save_start = time.perf_counter()
- # parquet_params = {
- # "df_result": df,
- # "parquet_storage_path": path,
- # "fs": self.fs,
- # "logger": self.logger,
- # "debug": self.debug,
- # }
- # with ParquetSaver(**parquet_params) as ps:
- # ps.save_to_parquet(self.parquet_filename, overwrite=True)
- # save_time = time.perf_counter() - save_start
- #
- # total_time = time.perf_counter() - overall_start
- # self.benchmarks[date] = {
- # "load_duration": load_time,
- # "save_duration": save_time,
- # "total_duration": total_time,
- # }
- # self._log_success(date, total_time, full_path)
- #
- # except Exception as e:
- # self._log_failure(date, e)
- # raise
- #
- # def _log_success(self, date: datetime.date, duration: float, path: str):
- # self.logger.info(f"Completed {date} in {duration:.1f}s | Saved to {path}", extra=self.logger_extra)
- # self.processed_dates.append(date)
- #
- # def _log_failure(self, date: datetime.date, error: Exception):
- # self.logger.error(f"Failed processing {date}: {error}", extra=self.logger_extra)
- #
- # def show_benchmark_summary(self):
- # if not self.benchmarks:
- # self.logger.info("No benchmarking data to show", extra=self.logger_extra)
- # return
- # df_bench = pd.DataFrame.from_records([{"date": d, **m} for d, m in self.benchmarks.items()])
- # df_bench = df_bench.set_index("date").sort_index(ascending=not self.update_planner.reverse_order)
- # self.logger.info(f"Benchmark Summary:\n {self.dataclass.__name__}\n" + df_bench.to_string(), extra=self.logger_extra)
- #
@@ -71,6 +71,7 @@ class ParquetSaver(ManagedResource):
  max_delete_workers: int = 8,
  write_gate_max: int = 2,
  write_gate_key: Optional[str] = None,
+ partition_on: Optional[list[str]] = None,
  **kwargs: Any,
  ):
  super().__init__(**kwargs)
@@ -93,6 +94,7 @@ class ParquetSaver(ManagedResource):
  self.max_delete_workers = max(1, int(max_delete_workers))
  self.write_gate_max = max(1, int(write_gate_max))
  self.write_gate_key = (write_gate_key or self.parquet_storage_path).rstrip("/")
+ self.partition_on = partition_on

  # Fix: Remove deprecated coerce_timestamps parameter
  self.pyarrow_args.setdefault("compression", "zstd")
@@ -103,7 +105,18 @@ class ParquetSaver(ManagedResource):

  # ---------- public API ----------
  def save_to_parquet(self, output_directory_name: str = "default_output", overwrite: bool = True) -> str:
- target_path = f"{self.parquet_storage_path}/{output_directory_name}".rstrip("/")
+ """
+ Save the Dask DataFrame to Parquet. If partition_on is provided, write as a
+ partitioned dataset without overwriting earlier partitions.
+ """
+ # Always treat as a directory target
+ if self.partition_on:
+ overwrite = False
+ # we override the output_directory_name and overwrite setting to avoid confusion since dask will (re) create subdirs
+ # Partitioned dataset → write directly to root directory
+ target_path = self.parquet_storage_path.rstrip("/")
+ else:
+ target_path = f"{self.parquet_storage_path}/{output_directory_name}".rstrip("/")

  sem = get_write_sem(self.write_gate_key, self.write_gate_max)
  with sem:
@@ -111,7 +124,7 @@ class ParquetSaver(ManagedResource):
  self._clear_directory_safely(target_path)
  self.fs.mkdirs(target_path, exist_ok=True)

- # Define a pyarrow schema and coerce the Dask frame to match it.
+ # Enforce schema before write
  schema = self._define_schema()
  ddf = self._coerce_ddf_to_schema(self.df_result, schema)

@@ -128,25 +141,25 @@ class ParquetSaver(ManagedResource):
  pa.set_cpu_count(self.arrow_cpu)

  try:
+ params = {
+ "path": target_path,
+ "engine": "pyarrow",
+ "filesystem": self.fs,
+ "write_index": self.write_index,
+ "write_metadata_file": self.write_metadata_file,
+ **self.pyarrow_args,
+ }
+ self.partition_on = self.partition_on if isinstance(self.partition_on, list) else None
+ if self.partition_on:
+ params["partition_on"] = self.partition_on
+
  with self._local_dask_pool():
- ddf.to_parquet(
- path=target_path,
- engine="pyarrow",
- schema=schema,
- overwrite=False,
- filesystem=self.fs,
- write_index=self.write_index,
- write_metadata_file=self.write_metadata_file,
- **self.pyarrow_args,
- )
+ ddf.to_parquet(**params)
  finally:
  if old_arrow_cpu is not None:
  pa.set_cpu_count(old_arrow_cpu)

- self.logger.info(
- f"Parquet dataset written: {target_path}",
- extra=self.logger_extra,
- )
+ self.logger.info(f"Parquet dataset written: {target_path}", extra=self.logger_extra)
  return target_path

  @contextmanager
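
Note: with partition_on set, the writer targets the dataset root and Dask creates one partition_date=<value> subdirectory per day (the docstring above states that earlier partitions are not overwritten). Reading back can prune on the partition column; a small sketch reusing the illustrative /tmp/my_dataset path from the earlier example:

    import dask.dataframe as dd

    ddf = dd.read_parquet(
        "/tmp/my_dataset",
        engine="pyarrow",
        filters=[("partition_date", "==", "2025-09-01")],  # partition pruning
    )
    print(ddf.compute())
    # partition_date comes back as a column; callers that relied on the old
    # <path>/<parquet_filename> layout need to switch to this dataset-style layout.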
@@ -1,6 +1,6 @@
  from __future__ import annotations
  import asyncio, contextlib, inspect, json
- from typing import Any, Awaitable, Callable, Dict, Optional, Union
+ from typing import Any, Awaitable, Callable, Dict, Optional, Union, Mapping, MutableMapping
  from fastapi import Request
  from sse_starlette.sse import EventSourceResponse
  from sibi_dst.utils import Logger
@@ -13,6 +13,16 @@ TaskFn = Union[Task2, Task3]
  def _as_sse_msg(event: str, data: Any) -> dict:
  return {"event": event, "data": json.dumps(data) if not isinstance(data, (str, bytes)) else data}

+ def _merge_ctx(*parts: Optional[Mapping[str, Any]]) -> Dict[str, Any]:
+ """Right-most precedence; shallow merge is sufficient for our keys."""
+ out: Dict[str, Any] = {}
+ for p in parts:
+ if not p:
+ continue
+ for k, v in p.items():
+ out[k] = v
+ return out
+
  class SSERunner:
  def __init__(self, *, task: TaskFn, logger: Logger, ping: int = 15,
  headers: Optional[dict] = None, auto_complete: bool = True) -> None:
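
Note: _merge_ctx is a plain shallow merge with right-most precedence, which is what the endpoint() change below relies on (request-derived < request.state.ctx < explicit ctx). A quick illustration with made-up values:

    base_ctx = {"method": "GET", "query": {"page": "1"}}
    state_ctx = {"tenant": "acme"}
    explicit_ctx = {"query": {"page": "2"}}

    merged = _merge_ctx(base_ctx, state_ctx, None, explicit_ctx)
    assert merged == {"method": "GET", "query": {"page": "2"}, "tenant": "acme"}
    # Shallow merge: the whole "query" dict is replaced, not deep-merged.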
@@ -31,39 +41,57 @@ class SSERunner:
  async def _worker(self, queue: asyncio.Queue, task_id: str, ctx: Dict[str, Any]) -> None:
  self.logger.info(f"SSE {task_id}: start")
  try:
- await queue.put(_as_sse_msg("progress", {"message": "Task started"}))
+ await queue.put(_as_sse_msg("progress", {"message": "Task started", "task_id": task_id}))
  payload = await self._call_task(queue, task_id, ctx)
  if self.auto_complete:
  final = payload if payload is not None else {"status": "complete"}
+ if isinstance(final, dict) and "task_id" not in final:
+ final["task_id"] = task_id
  await queue.put(_as_sse_msg("complete", final))
  self.logger.info(f"SSE {task_id}: complete")
  except asyncio.CancelledError:
  raise
  except Exception as e:
  self.logger.error(f"SSE {task_id} failed: {e}", exc_info=True)
- await queue.put(_as_sse_msg("error", {"detail": str(e)}))
+ await queue.put(_as_sse_msg("error", {"detail": str(e), "task_id": task_id}))
  finally:
  await queue.put(None)

- def endpoint(self):
- async def handler(request: Request): # <-- only Request
+ def endpoint(self, *, ctx: Optional[Dict[str, Any]] = None):
+ """
+ Create an SSE endpoint.
+ - ctx: optional explicit context dict provided by the caller.
+ This overrides request-derived context and request.state.ctx.
+ Precedence when merging: request-derived < request.state.ctx < ctx (explicit).
+ """
+ async def handler(request: Request):
  queue: asyncio.Queue = asyncio.Queue()
  task_id = str(asyncio.get_running_loop().time()).replace(".", "")
  self.logger.debug(
- f"SSE {task_id}: new request client={request.client} path={request.url.path} q={dict(request.query_params)}")
+ f"SSE {task_id}: new request client={request.client} path={request.url.path} q={dict(request.query_params)}"
+ )

- ctx: Dict[str, Any] = {
- "path": dict(request.path_params), # <-- pull path params here
+ # Base ctx from the HTTP request
+ base_ctx: Dict[str, Any] = {
+ "path": dict(request.path_params),
  "query": dict(request.query_params),
  "method": request.method,
+ "headers": dict(request.headers) if hasattr(request, "headers") else None,
  }
  if request.headers.get("content-type", "").startswith("application/json"):
  try:
- ctx["body"] = await request.json()
+ base_ctx["body"] = await request.json()
  except Exception:
- ctx["body"] = None
+ base_ctx["body"] = None
+
+ # Pull any pre-attached ctx from request.state
+ state_ctx: Optional[Dict[str, Any]] = getattr(request.state, "ctx", None)
+
+ # Merge with precedence: base_ctx < state_ctx < explicit ctx
+ merged_ctx = _merge_ctx(base_ctx, state_ctx, ctx)

- worker = asyncio.create_task(self._worker(queue, task_id, ctx))
+ # Run worker
+ worker = asyncio.create_task(self._worker(queue, task_id, merged_ctx))

  async def gen():
  try: