sibi_flux-2025.12.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. sibi_dst/__init__.py +44 -0
  2. sibi_flux/__init__.py +49 -0
  3. sibi_flux/artifacts/__init__.py +7 -0
  4. sibi_flux/artifacts/base.py +166 -0
  5. sibi_flux/artifacts/parquet.py +360 -0
  6. sibi_flux/artifacts/parquet_engine/__init__.py +5 -0
  7. sibi_flux/artifacts/parquet_engine/executor.py +204 -0
  8. sibi_flux/artifacts/parquet_engine/manifest.py +101 -0
  9. sibi_flux/artifacts/parquet_engine/planner.py +544 -0
  10. sibi_flux/conf/settings.py +131 -0
  11. sibi_flux/core/__init__.py +5 -0
  12. sibi_flux/core/managed_resource/__init__.py +3 -0
  13. sibi_flux/core/managed_resource/_managed_resource.py +733 -0
  14. sibi_flux/core/type_maps/__init__.py +100 -0
  15. sibi_flux/dask_cluster/__init__.py +47 -0
  16. sibi_flux/dask_cluster/async_core.py +27 -0
  17. sibi_flux/dask_cluster/client_manager.py +549 -0
  18. sibi_flux/dask_cluster/core.py +322 -0
  19. sibi_flux/dask_cluster/exceptions.py +34 -0
  20. sibi_flux/dask_cluster/utils.py +49 -0
  21. sibi_flux/datacube/__init__.py +3 -0
  22. sibi_flux/datacube/_data_cube.py +332 -0
  23. sibi_flux/datacube/config_engine.py +152 -0
  24. sibi_flux/datacube/field_factory.py +48 -0
  25. sibi_flux/datacube/field_registry.py +122 -0
  26. sibi_flux/datacube/generator.py +677 -0
  27. sibi_flux/datacube/orchestrator.py +171 -0
  28. sibi_flux/dataset/__init__.py +3 -0
  29. sibi_flux/dataset/_dataset.py +162 -0
  30. sibi_flux/df_enricher/__init__.py +56 -0
  31. sibi_flux/df_enricher/async_enricher.py +201 -0
  32. sibi_flux/df_enricher/merger.py +253 -0
  33. sibi_flux/df_enricher/specs.py +45 -0
  34. sibi_flux/df_enricher/types.py +12 -0
  35. sibi_flux/df_helper/__init__.py +5 -0
  36. sibi_flux/df_helper/_df_helper.py +450 -0
  37. sibi_flux/df_helper/backends/__init__.py +34 -0
  38. sibi_flux/df_helper/backends/_params.py +173 -0
  39. sibi_flux/df_helper/backends/_strategies.py +295 -0
  40. sibi_flux/df_helper/backends/http/__init__.py +5 -0
  41. sibi_flux/df_helper/backends/http/_http_config.py +122 -0
  42. sibi_flux/df_helper/backends/parquet/__init__.py +7 -0
  43. sibi_flux/df_helper/backends/parquet/_parquet_options.py +268 -0
  44. sibi_flux/df_helper/backends/sqlalchemy/__init__.py +9 -0
  45. sibi_flux/df_helper/backends/sqlalchemy/_db_connection.py +256 -0
  46. sibi_flux/df_helper/backends/sqlalchemy/_db_gatekeeper.py +15 -0
  47. sibi_flux/df_helper/backends/sqlalchemy/_io_dask.py +386 -0
  48. sibi_flux/df_helper/backends/sqlalchemy/_load_from_db.py +134 -0
  49. sibi_flux/df_helper/backends/sqlalchemy/_model_registry.py +239 -0
  50. sibi_flux/df_helper/backends/sqlalchemy/_sql_model_builder.py +42 -0
  51. sibi_flux/df_helper/backends/utils.py +32 -0
  52. sibi_flux/df_helper/core/__init__.py +15 -0
  53. sibi_flux/df_helper/core/_defaults.py +104 -0
  54. sibi_flux/df_helper/core/_filter_handler.py +617 -0
  55. sibi_flux/df_helper/core/_params_config.py +185 -0
  56. sibi_flux/df_helper/core/_query_config.py +17 -0
  57. sibi_flux/df_validator/__init__.py +3 -0
  58. sibi_flux/df_validator/_df_validator.py +222 -0
  59. sibi_flux/logger/__init__.py +1 -0
  60. sibi_flux/logger/_logger.py +480 -0
  61. sibi_flux/mcp/__init__.py +26 -0
  62. sibi_flux/mcp/client.py +150 -0
  63. sibi_flux/mcp/router.py +126 -0
  64. sibi_flux/orchestration/__init__.py +9 -0
  65. sibi_flux/orchestration/_artifact_orchestrator.py +346 -0
  66. sibi_flux/orchestration/_pipeline_executor.py +212 -0
  67. sibi_flux/osmnx_helper/__init__.py +22 -0
  68. sibi_flux/osmnx_helper/_pbf_handler.py +384 -0
  69. sibi_flux/osmnx_helper/graph_loader.py +225 -0
  70. sibi_flux/osmnx_helper/utils.py +100 -0
  71. sibi_flux/pipelines/__init__.py +3 -0
  72. sibi_flux/pipelines/base.py +218 -0
  73. sibi_flux/py.typed +0 -0
  74. sibi_flux/readers/__init__.py +3 -0
  75. sibi_flux/readers/base.py +82 -0
  76. sibi_flux/readers/parquet.py +106 -0
  77. sibi_flux/utils/__init__.py +53 -0
  78. sibi_flux/utils/boilerplate/__init__.py +19 -0
  79. sibi_flux/utils/boilerplate/base_attacher.py +45 -0
  80. sibi_flux/utils/boilerplate/base_cube_router.py +283 -0
  81. sibi_flux/utils/boilerplate/base_data_cube.py +132 -0
  82. sibi_flux/utils/boilerplate/base_pipeline_template.py +54 -0
  83. sibi_flux/utils/boilerplate/hybrid_data_loader.py +193 -0
  84. sibi_flux/utils/clickhouse_writer/__init__.py +6 -0
  85. sibi_flux/utils/clickhouse_writer/_clickhouse_writer.py +225 -0
  86. sibi_flux/utils/common.py +7 -0
  87. sibi_flux/utils/credentials/__init__.py +3 -0
  88. sibi_flux/utils/credentials/_config_manager.py +155 -0
  89. sibi_flux/utils/dask_utils.py +14 -0
  90. sibi_flux/utils/data_utils/__init__.py +3 -0
  91. sibi_flux/utils/data_utils/_data_utils.py +389 -0
  92. sibi_flux/utils/dataframe_utils.py +52 -0
  93. sibi_flux/utils/date_utils/__init__.py +10 -0
  94. sibi_flux/utils/date_utils/_business_days.py +220 -0
  95. sibi_flux/utils/date_utils/_date_utils.py +311 -0
  96. sibi_flux/utils/date_utils/_file_age_checker.py +319 -0
  97. sibi_flux/utils/file_utils.py +48 -0
  98. sibi_flux/utils/filepath_generator/__init__.py +5 -0
  99. sibi_flux/utils/filepath_generator/_filepath_generator.py +185 -0
  100. sibi_flux/utils/parquet_saver/__init__.py +6 -0
  101. sibi_flux/utils/parquet_saver/_parquet_saver.py +436 -0
  102. sibi_flux/utils/parquet_saver/_write_gatekeeper.py +33 -0
  103. sibi_flux/utils/retry.py +46 -0
  104. sibi_flux/utils/storage/__init__.py +7 -0
  105. sibi_flux/utils/storage/_fs_registry.py +112 -0
  106. sibi_flux/utils/storage/_storage_manager.py +257 -0
  107. sibi_flux/utils/storage/factory.py +33 -0
  108. sibi_flux-2025.12.0.dist-info/METADATA +283 -0
  109. sibi_flux-2025.12.0.dist-info/RECORD +110 -0
  110. sibi_flux-2025.12.0.dist-info/WHEEL +4 -0
sibi_flux/utils/filepath_generator/_filepath_generator.py
@@ -0,0 +1,185 @@
+ import datetime as dt
+ import re
+ from typing import List, Optional, Iterable
+
+ import fsspec
+ import pyarrow.dataset as ds
+ from fsspec.utils import infer_storage_options
+ from sibi_flux.logger import Logger
+
+
+ class FilePathGenerator:
+     """
+     Scans Hive-partitioned directories (key=value) and returns paths for a specific date range.
+     Supports single-key date partitions (e.g. partition_date=YYYY-MM-DD) or
+     composite hierarchies (year=YYYY/month=MM/day=DD).
+     """
+
+     def __init__(
+         self,
+         base_path: str = "",
+         *,
+         fs=None,
+         date_partition_key: str = "partition_date",
+         logger: Optional[Logger] = None,
+         debug: bool = False,
+         storage_options: Optional[dict] = None,
+         exclude_patterns: Optional[Iterable[str]] = None,
+         file_extension: str = "parquet",
+     ):
+         """
+         Args:
+             base_path: S3/local path to the root of the dataset.
+             fs: Optional fsspec filesystem instance.
+             date_partition_key: The hive column name used for filtering.
+                 Defaults to "partition_date". If the dataset uses the standard
+                 year/month/day hierarchy, this is ignored automatically.
+         """
+         self.logger = logger or Logger.default_logger(
+             logger_name=self.__class__.__name__
+         )
+         self.debug = debug
+         self.storage_options = storage_options or {}
+         self.file_extension = file_extension.lstrip(".")
+         self._compiled_exclusions = [re.compile(p) for p in (exclude_patterns or [])]
+         self.date_partition_key = date_partition_key
+
+         # 1. Set up the filesystem (protocol stripping)
+         if fs:
+             self.fs = fs
+             self.base_path = infer_storage_options(base_path)["path"]
+         else:
+             self.fs, self.base_path = fsspec.core.url_to_fs(
+                 base_path, **self.storage_options
+             )
+
+         # 2. Strict Hive partitioning
+         # This automatically parses "key=value" directory names.
+         self.partitioning = ds.partitioning(flavor="hive")
+
+         if self.debug:
+             self.logger.debug(
+                 f"Init: path={self.base_path}, date_key={self.date_partition_key}"
+             )
+
+     def generate_file_paths(
+         self, start_date, end_date, engine: str = "dask"
+     ) -> List[str]:
+         sd = self._to_date(start_date)
+         ed = self._to_date(end_date)
+         if sd > ed:
+             sd, ed = ed, sd
+
+         # Create the dataset
+         try:
+             dataset = ds.dataset(
+                 source=self.base_path,
+                 filesystem=self.fs,
+                 format=self.file_extension,
+                 partitioning=self.partitioning,
+             )
+         except Exception as e:
+             if self.debug:
+                 self.logger.error(f"Failed to load dataset: {e}")
+             return []
+
+         # 3. Dynamic filter logic
+         schema_names = set(dataset.schema.names)
+         filter_expr = None
+         filter_mode = "unknown"
+
+         # Case A: User-specified single column (e.g., partition_date=2024-01-01)
+         if self.date_partition_key in schema_names:
+             filter_mode = "single_key"
+             # Filter treating the key as a string comparable to YYYY-MM-DD
+             filter_expr = (ds.field(self.date_partition_key) >= str(sd)) & (
+                 ds.field(self.date_partition_key) <= str(ed)
+             )
+
+         # Case B: Standard composite hierarchy (year=..., month=..., day=...)
+         # We fall back to this automatically if the single key isn't found
+         elif {"year", "month", "day"}.issubset(schema_names):
+             filter_mode = "composite_ymd"
+             # Broad filter on year first for performance
+             filter_expr = (ds.field("year") >= sd.year) & (ds.field("year") <= ed.year)
+
+         else:
+             self.logger.warning(
+                 f"Could not find partition key '{self.date_partition_key}' or standard year/month/day keys "
+                 f"in dataset schema: {schema_names}. Returning all files (unfiltered)."
+             )
+
+         # 4. Materialize & refine
+         # PyArrow's filter is "coarse" (partition level), so we re-check exact dates in Python
+         valid_paths = []
+         fragments = dataset.get_fragments(filter=filter_expr)
+
+         for frag in fragments:
+             if self._is_file_in_range(frag, sd, ed, filter_mode):
+                 path = frag.path
+                 # Ensure full protocol paths for Pandas/Dask
+                 if not path.startswith(self.fs.protocol):
+                     path = self.fs.unstrip_protocol(path)
+
+                 if not self._is_excluded(path):
+                     valid_paths.append(path)
+
+         if self.debug:
+             self.logger.debug(
+                 f"Found {len(valid_paths)} files via '{filter_mode}' filter."
+             )
+
+         return sorted(valid_paths)
+
+     def _is_file_in_range(self, frag, sd, ed, mode: str) -> bool:
+         """
+         Parses partition keys from fragment metadata and validates the exact date range.
+         """
+         try:
+             keys = ds._get_partition_keys(frag.partition_expression)
+
+             if mode == "single_key":
+                 val = keys.get(self.date_partition_key)
+                 if not val:
+                     return False
+
+                 # Handle string vs. date type inference
+                 f_date = (
+                     val
+                     if isinstance(val, (dt.date, dt.datetime))
+                     else self._to_date(val)
+                 )
+                 # Ensure comparison between dates (not datetime vs date)
+                 if isinstance(f_date, dt.datetime):
+                     f_date = f_date.date()
+                 return sd <= f_date <= ed
+
+             elif mode == "composite_ymd":
+                 f_date = dt.date(
+                     int(keys["year"]), int(keys["month"]), int(keys["day"])
+                 )
+                 return sd <= f_date <= ed
+
+             # Unknown mode: accept the file (it passed the coarse filter, or no filter was applied)
+             return True
+
+         except Exception:
+             # If parsing fails (e.g. a malformed folder name), exclude the file safely
+             return False
+
+     # ------------------------- Helpers -------------------------
+
+     def _is_excluded(self, path: str) -> bool:
+         return any(pat.search(path) for pat in self._compiled_exclusions)
+
+     @staticmethod
+     def _to_date(x) -> dt.date:
+         if isinstance(x, dt.datetime):
+             return x.date()
+         if isinstance(x, dt.date):
+             return x
+         try:
+             return dt.datetime.strptime(str(x), "%Y-%m-%d").date()
+         except ValueError:
+             # Fallback for unrecognized formats
+             return dt.date.min
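
A minimal usage sketch (not part of the published wheel) for the class above. It assumes FilePathGenerator is re-exported from sibi_flux.utils.filepath_generator (that package's __init__.py is listed but not shown here); the bucket, prefix, and exclusion pattern are hypothetical:

from sibi_flux.utils.filepath_generator import FilePathGenerator  # assumed re-export

# Hypothetical hive layout: s3://my-bucket/events/partition_date=YYYY-MM-DD/*.parquet
gen = FilePathGenerator(
    base_path="s3://my-bucket/events/",
    date_partition_key="partition_date",
    exclude_patterns=[r"_metadata$"],  # skip metadata sidecar files
    debug=True,
)
paths = gen.generate_file_paths("2024-01-01", "2024-01-31")
# 'paths' is a sorted list of fully qualified parquet file URLs, ready to pass to
# dask.dataframe.read_parquet(paths) or a pyarrow dataset.
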
sibi_flux/utils/parquet_saver/__init__.py
@@ -0,0 +1,6 @@
+ from ._parquet_saver import ParquetSaver, _coerce_partition
+
+ __all__ = [
+     "ParquetSaver",
+     "_coerce_partition",
+ ]
sibi_flux/utils/parquet_saver/_parquet_saver.py
@@ -0,0 +1,436 @@
+ from __future__ import annotations
+
+ import warnings
+ import os
+ from concurrent.futures import ThreadPoolExecutor
+ from functools import partial
+ from multiprocessing.pool import ThreadPool
+ from typing import Any, Dict, Optional, List
+
+ import dask.dataframe as dd
+ import pandas as pd
+ import pyarrow as pa
+
+ from sibi_flux.core import ManagedResource
+ from sibi_flux.dask_cluster.core import safe_persist
+ from ._write_gatekeeper import get_write_sem
+
+ warnings.filterwarnings(
+     "ignore", message="Passing 'overwrite=True' to to_parquet is deprecated"
+ )
+
+
+ def _coerce_partition(
+     pdf: pd.DataFrame,
+     target: Dict[str, pa.DataType],
+     partition_cols: Optional[List[str]] = None,
+ ) -> pd.DataFrame:
+     """
+     Applies type conversions to a single pandas partition.
+     """
+     partition_cols = partition_cols or []
+
+     for col, pa_type in target.items():
+         if col not in pdf.columns:
+             # Fix: ensure the column exists if it's in the target schema
+             if pa.types.is_string(pa_type):
+                 pdf[col] = None
+                 # We will cast to string[pyarrow] below
+             elif pa.types.is_integer(pa_type):
+                 pdf[col] = pd.NA
+             elif pa.types.is_floating(pa_type):
+                 pdf[col] = float("nan")
+             elif pa.types.is_boolean(pa_type):
+                 pdf[col] = pd.NA
+             elif pa.types.is_timestamp(pa_type):
+                 pdf[col] = pd.NaT
+             else:
+                 pdf[col] = None
+
+         try:
+             # --- 0. Partition Column Logic ---
+             if col in partition_cols:
+                 try:
+                     if not pd.api.types.is_numeric_dtype(pdf[col]):
+                         # SAFE PARSING: use format='mixed' here too
+                         temp_series = pd.to_datetime(
+                             pdf[col], errors="raise", utc=True, format="mixed"
+                         )
+                         pdf[col] = temp_series.dt.strftime("%Y-%m-%d").astype(
+                             "string[pyarrow]"
+                         )
+                         continue
+                 except (ValueError, TypeError):
+                     pass
+
+                 pdf[col] = pdf[col].astype("string[pyarrow]")
+                 continue
+
+             # --- 1. Timestamp Coercion ---
+             if pa.types.is_timestamp(pa_type):
+                 if not pd.api.types.is_datetime64_any_dtype(pdf[col]):
+                     # CRITICAL FIX: format='mixed' handles a naive/aware mix in the same column
+                     pdf[col] = pd.to_datetime(
+                         pdf[col], errors="coerce", utc=True, format="mixed"
+                     )
+
+                 if pdf[col].dt.tz is None:
+                     pdf[col] = pdf[col].dt.tz_localize("UTC")
+                 else:
+                     pdf[col] = pdf[col].dt.tz_convert("UTC")
+
+                 pdf[col] = pdf[col].astype("timestamp[ns, tz=UTC][pyarrow]")
+                 continue
+
+             # --- 2. Other Types ---
+             current_dtype_str = str(pdf[col].dtype)
+
+             if pa.types.is_string(pa_type) and "string" not in current_dtype_str:
+                 pdf[col] = pdf[col].astype("string[pyarrow]")
+
+             elif pa.types.is_boolean(pa_type) and "bool" not in current_dtype_str:
+                 if pd.api.types.is_object_dtype(
+                     pdf[col]
+                 ) or pd.api.types.is_string_dtype(pdf[col]):
+                     pdf[col] = (
+                         pdf[col].astype(str).str.lower().isin(["true", "1", "yes"])
+                     )
+                 pdf[col] = pdf[col].astype("boolean[pyarrow]")
+
+             elif pa.types.is_integer(pa_type) and "int" not in current_dtype_str:
+                 pdf[col] = pd.to_numeric(pdf[col], errors="coerce").astype(
+                     "int64[pyarrow]"
+                 )
+
+             elif pa.types.is_floating(pa_type) and "float" not in current_dtype_str:
+                 pdf[col] = pd.to_numeric(pdf[col], errors="coerce").astype(
+                     "float64[pyarrow]"
+                 )
+
+         except Exception:
+             pass
+
+     return pdf
+
+
+ class ParquetSaver(ManagedResource):
+     """
+     Production-grade Dask -> Parquet writer.
+     """
+
+     logger_extra = {"sibi_flux_component": __name__}
+
+     def __init__(
+         self,
+         df_result: dd.DataFrame,
+         parquet_storage_path: str,
+         *,
+         repartition_size: Optional[str] = "128MB",
+         persist: bool = False,
+         write_index: bool = False,
+         write_metadata_file: bool = True,
+         pyarrow_args: Optional[Dict[str, Any]] = None,
+         writer_threads: int | str = "auto",
+         arrow_cpu: Optional[int] = None,
+         partitions_per_round: int = 24,
+         max_delete_workers: int = 8,
+         write_gate_max: int = 2,
+         write_gate_key: Optional[str] = None,
+         partition_on: Optional[list[str]] = None,
+         dask_client: Optional[Any] = None,
+         schema: Optional[pa.Schema] = None,
+         **kwargs: Any,
+     ):
+         super().__init__(**kwargs)
+
+         if not isinstance(df_result, dd.DataFrame):
+             raise TypeError("df_result must be a Dask DataFrame")
+         if not self.fs:
+             raise ValueError("File system (fs) must be provided to ParquetSaver.")
+
+         self.df_result = df_result
+         self.dask_client = dask_client
+         self.parquet_storage_path = parquet_storage_path.rstrip("/")
+         self.repartition_size = repartition_size
+         self.persist = persist
+         self.write_index = write_index
+         self.write_metadata_file = write_metadata_file
+         self.pyarrow_args = dict(pyarrow_args or {})
+
+         if writer_threads == "auto" or writer_threads is None:
+             self.writer_threads = self._calculate_optimal_threads()
+             self.logger.debug(f"Auto-tuned writer_threads to: {self.writer_threads}")
+         else:
+             self.writer_threads = max(1, int(writer_threads))
+
+         self.arrow_cpu = None if arrow_cpu is None else max(1, int(arrow_cpu))
+         self.partitions_per_round = max(1, int(partitions_per_round))
+         self.max_delete_workers = max(1, int(max_delete_workers))
+         self.write_gate_max = max(1, int(write_gate_max))
+         self.write_gate_key = (write_gate_key or self.parquet_storage_path).rstrip("/")
+         self.partition_on = partition_on or []
+         self.scheduler = kwargs.get("scheduler", None)
+         self.force_local = kwargs.get("force_local", False)
+         self.schema = schema
+
+         self.pyarrow_args.setdefault("compression", "zstd")
+
+         self.protocol = "file"
+         if "://" in self.parquet_storage_path:
+             self.protocol = self.parquet_storage_path.split(":", 1)[0]
+
+     def save_to_parquet(
+         self, output_directory_name: str = "default_output", overwrite: bool = True
+     ) -> str:
+         import dask.config
+         from contextlib import nullcontext
+
+         # 1. Determine Execution Mode
+         assert self.fs is not None, "FileSystem (fs) must be available"
+         using_distributed = False
+         if self.force_local:
+             using_distributed = False
+         elif self.dask_client is not None:
+             using_distributed = True
+         else:
+             try:
+                 from dask.distributed import get_client
+
+                 get_client()
+                 using_distributed = True
+             except (ImportError, ValueError):
+                 pass
+             except Exception:
+                 # Fallback if get_client() fails unexpectedly
+                 pass
+
+         # 2. Configure Context Factories
+         # We use factory functions so a FRESH context manager is created every time 'with' is entered.
+         if using_distributed:
+             # Distributed: no-op contexts
+             def get_config_context():
+                 return nullcontext()
+
+             def get_pool_context():
+                 return nullcontext()
+
+         else:
+             # Local: use a context manager for the 'tasks' shuffle; DO NOT SET A GLOBAL POOL
+             def get_config_context():
+                 return dask.config.set({"dataframe.shuffle.method": "tasks"})
+
+             # For a local pool we would pass the pool directly to compute() or let Dask handle it.
+             # However, since we call ddf.to_parquet(), we can't easily inject the pool argument
+             # without a global config context unless we go through compute().
+             # Safer approach: use a context manager for the pool too, so state is restored.
+             def get_pool_context():
+                 if self.scheduler == "synchronous":
+                     return dask.config.set(scheduler="synchronous")
+                 elif self.scheduler == "threads":
+                     return dask.config.set(
+                         pool=ThreadPool(self.writer_threads), scheduler="threads"
+                     )
+                 elif not using_distributed and self.scheduler is None:
+                     # Default to single-threaded for local execution to override any global client
+                     # and avoid graph dependency mismatches ("Missing Dependency" errors)
+                     return dask.config.set(scheduler="single-threaded")
+                 # Any other explicit scheduler value: rely on the surrounding config
+                 return nullcontext()
+
+         # 3. Path Resolution
+         if self.partition_on:
+             overwrite = False
+             target_path = self.parquet_storage_path.rstrip("/")
+         else:
+             target_path = f"{self.parquet_storage_path}/{output_directory_name}".rstrip(
+                 "/"
+             )
+
+         sem = get_write_sem(self.write_gate_key, self.write_gate_max)
+
+         # 4. Execution
+         # Use the factory to get a fresh context for the config block
+         with get_config_context():
+             with sem:
+                 if overwrite and self.fs.exists(target_path):
+                     self._clear_directory_safely(target_path)
+                 self.fs.mkdirs(target_path, exist_ok=True)
+
+                 schema = self._define_schema()
+                 ddf = self._coerce_ddf_to_schema(self.df_result, schema)
+
+                 if self.repartition_size:
+                     ddf = ddf.repartition(partition_size=self.repartition_size)
+
+                 # Persist: request a FRESH pool context
+                 if self.persist:
+                     with get_pool_context():
+                         if using_distributed:
+                             ddf = safe_persist(ddf)
+                         else:
+                             # Local persistence: avoid safe_persist to prevent Client resurrection
+                             ddf = ddf.persist()
+
+                 old_arrow_cpu = None
+                 if self.arrow_cpu:
+                     old_arrow_cpu = pa.get_cpu_count()
+                     pa.set_cpu_count(self.arrow_cpu)
+
+                 try:
+                     params = {
+                         "path": target_path,
+                         "engine": "pyarrow",
+                         "filesystem": self.fs,
+                         "write_index": self.write_index,
+                         "write_metadata_file": self.write_metadata_file,
+                         "schema": schema,
+                         **self.pyarrow_args,
+                     }
+
+                     if self.partition_on:
+                         params["partition_on"] = self.partition_on
+
+                     # Write: compute with an explicit scheduler if one was determined.
+                     # If we chose a local scheduler (single-threaded, threads, synchronous),
+                     # pass it explicitly to compute() to override any global/default client.
+                     # CRITICAL FIX: use compute=False plus an explicit dask.compute() to enforce the scheduler
+                     write_task = ddf.to_parquet(**params, compute=False)
+
+                     compute_kwargs = {}
+                     if not using_distributed:
+                         if self.scheduler == "synchronous":
+                             compute_kwargs["scheduler"] = "synchronous"
+                         elif self.scheduler == "threads":
+                             compute_kwargs["scheduler"] = "threads"
+                             # Note: pool logic for threads is handled by get_pool_context if needed,
+                             # but the explicit scheduler argument helps.
+                         elif self.scheduler == "single-threaded":
+                             compute_kwargs["scheduler"] = "single-threaded"
+                         elif self.scheduler is None:
+                             # Default fallback for force_local
+                             compute_kwargs["scheduler"] = "single-threaded"
+
+                     with get_pool_context():
+                         dask.compute(write_task, **compute_kwargs)
+                 finally:
+                     if old_arrow_cpu is not None:
+                         pa.set_cpu_count(old_arrow_cpu)
+
+         self.logger.info(
+             f"Parquet dataset written: {target_path}", extra=self.logger_extra
+         )
+         return target_path
+
+     # ---------- Internals ----------
+
+     def _calculate_optimal_threads(self) -> int:
+         try:
+             if "OMP_NUM_THREADS" in os.environ:
+                 available_cores = int(os.environ["OMP_NUM_THREADS"])
+             elif hasattr(os, "sched_getaffinity"):
+                 available_cores = len(os.sched_getaffinity(0))
+             else:
+                 available_cores = os.cpu_count() or 4
+             recommended = available_cores * 3
+             return max(4, min(recommended, 32))
+         except Exception:
+             return 8
+
+     def _clear_directory_safely(self, directory: str) -> None:
+         if self.protocol.startswith("s3"):
+             entries = [p for p in self.fs.glob(f"{directory}/**") if p != directory]
+             if not entries:
+                 return
+
+             def _rm_one(p: str) -> None:
+                 try:
+                     self.fs.rm_file(p)
+                 except Exception as e:
+                     self.logger.warning(
+                         f"Delete failed '{p}': {e}", extra=self.logger_extra
+                     )
+
+             with ThreadPoolExecutor(max_workers=self.max_delete_workers) as ex:
+                 list(ex.map(_rm_one, entries))
+             try:
+                 self.fs.rm(directory, recursive=False)
+             except Exception:
+                 pass
+         else:
+             self.fs.rm(directory, recursive=True)
+
+     def _define_schema(self) -> pa.Schema:
+         fields = []
+         # Pre-calculate overrides from self.schema if provided
+         overrides = {f.name: f.type for f in self.schema} if self.schema else {}
+
+         for name, dtype in self.df_result.dtypes.items():
+             if name in overrides:
+                 fields.append(pa.field(name, overrides[name]))
+                 continue
+
+             if name in self.partition_on:
+                 fields.append(pa.field(name, pa.string()))
+                 continue
+
+             dtype_str = str(dtype).lower()
+             lower_name = name.lower()
+
+             # --- PRIORITY 1: Heuristics ---
+             if any(x in lower_name for x in ["_dt", "_date", "_at", "time"]):
+                 pa_type = pa.timestamp("ns", tz="UTC")
+             elif lower_name.startswith(("is_", "has_", "flag_")):
+                 pa_type = pa.bool_()
+             elif (
+                 any(x in lower_name for x in ["_id", "id"]) and "guid" not in lower_name
+             ):
+                 pa_type = pa.int64()
+             elif any(
+                 x in lower_name for x in ["amount", "price", "score", "percent", "rate"]
+             ):
+                 pa_type = pa.float64()
+
+             # --- PRIORITY 2: Dtypes ---
+             elif "int" in dtype_str:
+                 pa_type = pa.int64()
+             elif "float" in dtype_str or "double" in dtype_str:
+                 pa_type = pa.float64()
+             elif "bool" in dtype_str:
+                 pa_type = pa.bool_()
+             elif "datetime" in dtype_str or "timestamp" in dtype_str:
+                 pa_type = pa.timestamp("ns", tz="UTC")
+             else:
+                 pa_type = pa.string()
+
+             fields.append(pa.field(name, pa_type))
+
+         return pa.schema(fields)
+
+     def _coerce_ddf_to_schema(
+         self, ddf: dd.DataFrame, schema: pa.Schema
+     ) -> dd.DataFrame:
+         target = {f.name: f.type for f in schema}
+
+         meta_cols: Dict[str, pd.Series] = {}
+         for name, typ in target.items():
+             if name in self.partition_on:
+                 meta_cols[name] = pd.Series([], dtype="string[pyarrow]")
+             elif pa.types.is_timestamp(typ):
+                 meta_cols[name] = pd.Series([], dtype="timestamp[ns, tz=UTC][pyarrow]")
+             elif pa.types.is_integer(typ):
+                 meta_cols[name] = pd.Series([], dtype="int64[pyarrow]")
+             elif pa.types.is_floating(typ):
+                 meta_cols[name] = pd.Series([], dtype="float64[pyarrow]")
+             elif pa.types.is_boolean(typ):
+                 meta_cols[name] = pd.Series([], dtype="boolean[pyarrow]")
+             else:
+                 meta_cols[name] = pd.Series([], dtype="string[pyarrow]")
+
+         new_meta = pd.DataFrame(meta_cols, index=ddf._meta.index)
+         new_meta = new_meta[list(target.keys())]
+
+         coerce_fn = partial(
+             _coerce_partition, target=target, partition_cols=self.partition_on
+         )
+         return ddf.map_partitions(coerce_fn, meta=new_meta)
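
A hedged usage sketch (not shipped in the wheel) for the writer above. It assumes ManagedResource, defined elsewhere in the package and not shown in this diff, accepts an fsspec filesystem via an fs keyword and provides self.logger; the bucket, columns, and data are hypothetical:

import pandas as pd
import dask.dataframe as dd
import fsspec

fs = fsspec.filesystem("s3")  # or fsspec.filesystem("file") for local testing
ddf = dd.from_pandas(
    pd.DataFrame(
        {
            "order_id": [1, 2],
            "amount": [9.99, 5.00],
            "partition_date": ["2024-01-01", "2024-01-02"],
        }
    ),
    npartitions=1,
)

saver = ParquetSaver(
    df_result=ddf,
    parquet_storage_path="s3://my-bucket/warehouse/orders",
    partition_on=["partition_date"],  # hive-style output, readable by FilePathGenerator above
    write_gate_max=2,                 # at most 2 concurrent writes per destination key
    fs=fs,                            # assumption: consumed by ManagedResource.__init__
)
out_path = saver.save_to_parquet()
# With partition_on set, data is written under parquet_storage_path itself and overwrite
# is forced to False; without it, a "default_output" subdirectory is created instead.
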
sibi_flux/utils/parquet_saver/_write_gatekeeper.py
@@ -0,0 +1,33 @@
+ # _write_gatekeeper.py
+ from __future__ import annotations
+ import threading
+
+ _LOCK = threading.Lock()
+ _SEMS: dict[str, threading.Semaphore] = {}
+
+
+ def get_write_sem(key: str, max_concurrency: int) -> threading.Semaphore:
+     """
+     Returns the semaphore associated with a key, creating it in a thread-safe
+     way if it does not exist yet.
+
+     The semaphore caps concurrent access to the resource or operation
+     identified by the key, so writers targeting the same destination can be
+     throttled process-wide.
+
+     :param key: The unique identifier the semaphore is associated with.
+     :param max_concurrency: The maximum number of concurrent acquisitions
+                             allowed. Values less than 1 default to a limit
+                             of 1.
+     :return: A `threading.Semaphore` for the specified key, initialized with
+              the maximum concurrency.
+     :rtype: threading.Semaphore
+     """
+     with _LOCK:
+         sem = _SEMS.get(key)
+         if sem is None:
+             sem = threading.Semaphore(max(1, int(max_concurrency)))
+             _SEMS[key] = sem
+         return sem
+ return sem