featkit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- featkit/__init__.py +1 -0
- featkit/builders/.gitkeep +0 -0
- featkit/builders/__init__.py +0 -0
- featkit/builders/distributional_space.py +77 -0
- featkit/builders/pivot_space.py +102 -0
- featkit/builders/temporal_space.py +86 -0
- featkit/config.py +38 -0
- featkit/contracts/__init__.py +1 -0
- featkit/contracts/measurement/.gitkeep +0 -0
- featkit/contracts/measurement/__init__.py +27 -0
- featkit/contracts/measurement/base.py +47 -0
- featkit/contracts/measurement/defaults.py +117 -0
- featkit/contracts/output/.gitkeep +0 -0
- featkit/contracts/output/__init__.py +19 -0
- featkit/contracts/output/base.py +36 -0
- featkit/contracts/output/defaults.py +80 -0
- featkit/dataset/.gitkeep +0 -0
- featkit/dataset/__init__.py +0 -0
- featkit/dataset/base.py +120 -0
- featkit/enums.py +110 -0
- featkit/fields/.gitkeep +0 -0
- featkit/fields/__init__.py +9 -0
- featkit/fields/base.py +48 -0
- featkit/fields/categorical_field.py +55 -0
- featkit/fields/id_field.py +14 -0
- featkit/fields/measurement_field.py +42 -0
- featkit/fields/time_field.py +43 -0
- featkit/generators/__init__.py +0 -0
- featkit/generators/base.py +171 -0
- featkit/generators/output.py +118 -0
- featkit/generators/pyspark/.gitkeep +0 -0
- featkit/generators/pyspark/__init__.py +0 -0
- featkit/generators/pyspark/databricks.py +448 -0
- featkit/generators/sql/.gitkeep +0 -0
- featkit/generators/sql/__init__.py +0 -0
- featkit/generators/sql/base.py +496 -0
- featkit/generators/sql/databricks.py +19 -0
- featkit/generators/sql/snowflake.py +19 -0
- featkit/generators/sql/spark_sql.py +19 -0
- featkit/layer2/.gitkeep +0 -0
- featkit/layer2/__init__.py +0 -0
- featkit/layer2/base.py +86 -0
- featkit/layer2/distributional.py +51 -0
- featkit/layer2/pivoted.py +63 -0
- featkit/layer3/.gitkeep +0 -0
- featkit/layer3/__init__.py +0 -0
- featkit/layer3/temporal_feature.py +87 -0
- featkit/pipeline.py +63 -0
- featkit-0.1.0.dist-info/METADATA +140 -0
- featkit-0.1.0.dist-info/RECORD +52 -0
- featkit-0.1.0.dist-info/WHEEL +4 -0
- featkit-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,496 @@
|
|
|
1
|
+
"""AbstractSQLCodeGenerator — dialect-agnostic SQL generation via SQLGlot."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
from abc import abstractmethod
|
|
7
|
+
from collections import defaultdict
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
9
|
+
|
|
10
|
+
import sqlglot
|
|
11
|
+
import sqlglot.expressions as exp
|
|
12
|
+
|
|
13
|
+
from featkit.enums import DistributionalMetric, TemporalOperator, TimeWindowDirection
|
|
14
|
+
from featkit.generators.base import AbstractCodeGenerator
|
|
15
|
+
from featkit.generators.output import SQLOutput
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from featkit.layer2.distributional import DistributionalColumn
|
|
19
|
+
from featkit.layer3.temporal_feature import TemporalFeature
|
|
20
|
+
from featkit.pipeline import FeatureStorePipeline
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class AbstractSQLCodeGenerator(AbstractCodeGenerator):
|
|
24
|
+
"""Base for all SQL-emitting code generators.
|
|
25
|
+
|
|
26
|
+
Subclasses must declare :attr:`dialect`. All five build steps are
|
|
27
|
+
implemented here by composing SQL strings that are validated and
|
|
28
|
+
normalised through SQLGlot so that dialect-specific rendering is applied
|
|
29
|
+
automatically.
|
|
30
|
+
|
|
31
|
+
Both the schema and table portions of every generated table reference are
|
|
32
|
+
double-quoted, guarding against reserved words and special characters in
|
|
33
|
+
user-supplied names. Categorical values used in CASE WHEN predicates are
|
|
34
|
+
rendered as SQLGlot :class:`~sqlglot.expressions.Literal` objects so that
|
|
35
|
+
single quotes and other special characters are escaped correctly.
|
|
36
|
+
|
|
37
|
+
MOB reference table structure (period self-join):
|
|
38
|
+
|
|
39
|
+
.. code-block:: text
|
|
40
|
+
|
|
41
|
+
periodos_unicos → periodos_ordenados (ROW_NUMBER by period)
|
|
42
|
+
periodos_ordenados A LEFT JOIN periodos_ordenados B ON 1=1
|
|
43
|
+
→ (ts_analysis, ts_relative, mob = B.mob - A.mob)
|
|
44
|
+
|
|
45
|
+
Layer 3 features are then derived by joining with the MOB table and using
|
|
46
|
+
GROUP BY + CASE WHEN aggregation to implement each temporal operator.
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
@property
@abstractmethod
def dialect(self) -> str:
    """Name of the SQLGlot dialect this generator targets.

    Concrete subclasses return the identifier that is handed to SQLGlot
    whenever SQL is parsed or rendered by this class (for example
    ``"snowflake"`` or ``"databricks"``).
    """
|
|
54
|
+
|
|
55
|
+
def render(self, expr: exp.Expr) -> str:
    """Serialise a SQLGlot expression tree as SQL text for :attr:`dialect`.

    Args:
        expr: The SQLGlot expression to serialise.

    Returns:
        Whatever ``expr.sql`` produces for this generator's dialect,
        coerced to ``str``.
    """
    rendered = expr.sql(dialect=self.dialect)
    return str(rendered)
|
|
58
|
+
|
|
59
|
+
# ------------------------------------------------------------------
|
|
60
|
+
# Internal helpers
|
|
61
|
+
# ------------------------------------------------------------------
|
|
62
|
+
|
|
63
|
+
def _tbl(self, pipeline: FeatureStorePipeline, suffix: str) -> str:
    """Build the quoted ``schema.table`` reference for an intermediate table.

    The table name is ``pipeline.config.output_table_prefix`` followed by
    *suffix*; the schema is ``pipeline.config.output_schema``. Both are
    turned into quoted SQLGlot identifiers, so the quoting characters in
    the result are whatever SQLGlot emits for :attr:`dialect`.

    Args:
        pipeline: Pipeline whose ``config`` supplies the prefix and schema.
        suffix: Layer-specific tail of the table name (e.g. ``"layer2a"``).

    Returns:
        The rendered table reference as a SQL string.
    """
    cfg = pipeline.config
    table_id = exp.to_identifier(f"{cfg.output_table_prefix}{suffix}", quoted=True)
    schema_id = exp.to_identifier(cfg.output_schema, quoted=True)
    return self.render(exp.Table(this=table_id, db=schema_id))
|
|
76
|
+
|
|
77
|
+
def _transpile(self, sql: str) -> str:
    """Round-trip *sql* through SQLGlot with :attr:`dialect` on both sides.

    The text is parsed and re-emitted in the same dialect with
    pretty-printing enabled.

    Args:
        sql: SQL text assembled by one of the ``build_*`` methods.

    Returns:
        The first statement returned by ``sqlglot.transpile``.
    """
    statements = sqlglot.transpile(
        sql,
        read=self.dialect,
        write=self.dialect,
        pretty=True,
    )
    # Only the first transpiled statement is handed back to the caller.
    return statements[0]
|
|
80
|
+
|
|
81
|
+
@staticmethod
|
|
82
|
+
def _safe_cte_name(name: str) -> str:
|
|
83
|
+
"""Return a SQL-safe CTE identifier from *name*.
|
|
84
|
+
|
|
85
|
+
Replaces any character outside ``[A-Za-z0-9_]`` with ``_`` so that
|
|
86
|
+
field names containing hyphens, spaces, or other special characters
|
|
87
|
+
produce valid unquoted identifiers.
|
|
88
|
+
"""
|
|
89
|
+
return re.sub(r"[^A-Za-z0-9]", "_", name)
|
|
90
|
+
|
|
91
|
+
def _str_literal(self, value: str) -> str:
    """Turn *value* into SQL string-literal text for :attr:`dialect`.

    The value is wrapped in a SQLGlot string ``Literal`` and serialised by
    SQLGlot, which takes care of quoting and escaping.

    Args:
        value: Raw Python string to embed in generated SQL.

    Returns:
        The literal as SQL text.
    """
    literal = exp.Literal.string(value)
    return literal.sql(dialect=self.dialect)
|
|
94
|
+
|
|
95
|
+
def _quoted_id(self, name: str) -> str:
    """Turn *name* into quoted-identifier SQL text for :attr:`dialect`.

    Args:
        name: Identifier text (e.g. a column name).

    Returns:
        The identifier as serialised by SQLGlot with ``quoted=True``.
    """
    identifier = exp.Identifier(this=name, quoted=True)
    return identifier.sql(dialect=self.dialect)
|
|
98
|
+
|
|
99
|
+
# ------------------------------------------------------------------
|
|
100
|
+
# build_mob_table
|
|
101
|
+
# ------------------------------------------------------------------
|
|
102
|
+
|
|
103
|
+
def build_mob_table(self, pipeline: FeatureStorePipeline) -> SQLOutput:
    """Emit the ``CREATE OR REPLACE TABLE`` statement for the MOB reference table.

    MOB stands for Month-on-Books. The statement is assembled as text and
    then normalised through :meth:`_transpile`. Its shape is:

    1. ``periodos_unicos`` selects the distinct values of the dataset's
       time column from the source reference.
    2. ``periodos_ordenados`` numbers those periods with
       ``ROW_NUMBER() OVER (ORDER BY {time_col})`` as ``mob``.
    3. The numbered periods are joined to themselves ``ON 1 = 1``, so every
       period ``a`` is paired with every period ``b``.
    4. The output columns are ``{time_col}_analysis`` (from ``a``),
       ``{time_col}_relative`` (from ``b``) and ``mob = b.mob - a.mob``,
       i.e. 0 for the same period, negative for earlier periods and
       positive for later ones.

    Args:
        pipeline: Pipeline whose ``config`` supplies the dataset and the
            output table naming.

    Returns:
        A :class:`SQLOutput` carrying the transpiled statement and
        :attr:`dialect`.
    """
    dataset = pipeline.config.dataset
    period = dataset.time_field.name
    source = dataset.source_reference
    target = self._tbl(pipeline, "mob_ref")

    statement = "\n".join(
        [
            f"CREATE OR REPLACE TABLE {target} AS",
            "WITH periodos_unicos AS (",
            f" SELECT DISTINCT {period}",
            f" FROM {source}",
            "),",
            "periodos_ordenados AS (",
            f" SELECT {period},",
            f" ROW_NUMBER() OVER (ORDER BY {period}) AS mob",
            " FROM periodos_unicos",
            ")",
            "SELECT",
            f" a.{period} AS {period}_analysis,",
            f" b.{period} AS {period}_relative,",
            " b.mob - a.mob AS mob",
            "FROM periodos_ordenados a",
            "LEFT JOIN periodos_ordenados b ON 1 = 1",
        ]
    )
    return SQLOutput(sql=self._transpile(statement), dialect=self.dialect)
|
|
146
|
+
|
|
147
|
+
# ------------------------------------------------------------------
|
|
148
|
+
# build_layer2a
|
|
149
|
+
# ------------------------------------------------------------------
|
|
150
|
+
|
|
151
|
+
def build_layer2a(self, pipeline: FeatureStorePipeline) -> SQLOutput:
    """Emit the ``CREATE OR REPLACE TABLE`` statement for the Layer 2A pivot table.

    The statement groups the source by the id columns plus the time column
    and adds one aggregate per entry of ``pipeline.layer2a``. When a column
    has non-``None`` categorical values, the measurement is wrapped in
    ``CASE WHEN <field> = <value> AND ... THEN <measurement> END`` before
    aggregation; otherwise the measurement is aggregated directly.

    In the predicates, categorical field names go through
    :meth:`_quoted_id` and categorical values through :meth:`_str_literal`.
    Id, time, measurement and alias names are interpolated as-is.

    Args:
        pipeline: Pipeline supplying the dataset, output naming and the
            Layer 2A column definitions.

    Returns:
        A :class:`SQLOutput` carrying the transpiled statement and
        :attr:`dialect`.
    """
    dataset = pipeline.config.dataset
    key_cols = [field.name for field in dataset.id_fields]
    key_cols.append(dataset.time_field.name)
    source = dataset.source_reference
    target = self._tbl(pipeline, "layer2a")

    projections: list[str] = list(key_cols)
    for column in pipeline.layer2a:
        measure = column.source_measurement.name
        func = column.layer2_aggregator.value
        out_name = column.column_name

        # Sort by field name so the predicate order is deterministic.
        ordered = sorted(
            column.categorical_combination.items(), key=lambda item: item[0].name
        )
        predicate = " AND ".join(
            f"{self._quoted_id(field.name)} = {self._str_literal(value)}"
            for field, value in ordered
            if value is not None
        )
        operand = f"CASE WHEN {predicate} THEN {measure} END" if predicate else measure
        projections.append(f"{func}({operand}) AS {out_name}")

    select_list = ",\n ".join(projections)
    statement = "\n".join(
        [
            f"CREATE OR REPLACE TABLE {target} AS",
            "SELECT",
            f" {select_list}",
            f"FROM {source}",
            f"GROUP BY {', '.join(key_cols)}",
        ]
    )
    return SQLOutput(sql=self._transpile(statement), dialect=self.dialect)
|
|
198
|
+
|
|
199
|
+
# ------------------------------------------------------------------
|
|
200
|
+
# build_layer2b
|
|
201
|
+
# ------------------------------------------------------------------
|
|
202
|
+
|
|
203
|
+
def build_layer2b(self, pipeline: FeatureStorePipeline) -> SQLOutput:
    """Emit the ``CREATE OR REPLACE TABLE`` statement for the Layer 2B table.

    Distributional columns are grouped by (categorical name, measurement
    name, aggregator value). Each group contributes three CTEs whose names
    share a stem sanitised with :meth:`_safe_cte_name`:

    * ``{stem}_raw`` aggregates the measurement per id columns, time
      column and category, exposing the result as ``cat_val``.
    * ``{stem}_shares`` selects everything from ``{stem}_raw`` and adds
      ``total_val = SUM(cat_val) OVER (PARTITION BY ids, time_col)``.
      Keeping this separate from the aggregation avoids nesting an
      aggregate inside a window aggregate.
    * ``{stem}_metrics`` groups ``{stem}_shares`` by ids and time column
      and evaluates one :meth:`_distributional_expr` per column of the
      group.

    A leading ``base`` CTE holds the distinct (ids, time) combinations of
    the source; every ``*_metrics`` CTE is ``LEFT JOIN``-ed to it with
    ``USING (ids, time_col)``.

    Args:
        pipeline: Pipeline supplying the dataset, output naming and the
            Layer 2B column definitions.

    Returns:
        A :class:`SQLOutput` for :attr:`dialect`. When ``pipeline.layer2b``
        is empty its ``sql`` is the empty string and nothing is transpiled.
    """
    if not pipeline.layer2b:
        return SQLOutput(sql="", dialect=self.dialect)

    dataset = pipeline.config.dataset
    id_names = [field.name for field in dataset.id_fields]
    period = dataset.time_field.name
    source = dataset.source_reference
    target = self._tbl(pipeline, "layer2b")

    ids_csv = ", ".join(id_names)
    entity_keys = f"{ids_csv}, {period}"

    # Bucket the columns by (categorical, measurement, aggregator), keeping
    # first-seen order of the buckets.
    grouped: dict[tuple[str, str, str], list[DistributionalColumn]] = {}
    for column in pipeline.layer2b:
        bucket = (
            column.categorical.name,
            column.source_measurement.name,
            column.layer2_aggregator.value,
        )
        grouped.setdefault(bucket, []).append(column)

    ctes: list[str] = [f"base AS (\n SELECT DISTINCT {entity_keys} FROM {source}\n)"]
    join_lines: list[str] = []

    for (cat, measure, func), members in grouped.items():
        stem = "_".join(map(self._safe_cte_name, (cat, measure, func.lower())))
        raw_name = f"{stem}_raw"
        shares_name = f"{stem}_shares"
        metrics_name = f"{stem}_metrics"
        join_lines.append(f"LEFT JOIN {metrics_name} USING ({entity_keys})")

        # Per-category aggregate.
        ctes.append(
            "\n".join(
                [
                    f"{raw_name} AS (",
                    " SELECT",
                    f" {ids_csv},",
                    f" {period},",
                    f" {cat},",
                    f" {func}({measure}) AS cat_val",
                    f" FROM {source}",
                    f" GROUP BY {entity_keys}, {cat}",
                    ")",
                ]
            )
        )

        # Entity-by-period total computed over the already aggregated values.
        ctes.append(
            "\n".join(
                [
                    f"{shares_name} AS (",
                    " SELECT",
                    " *,",
                    f" SUM(cat_val) OVER (PARTITION BY {entity_keys}) AS total_val",
                    f" FROM {raw_name}",
                    ")",
                ]
            )
        )

        # One distributional statistic per column in the group.
        metric_block = ",\n".join(
            " "
            + self._distributional_expr(
                member.distributional_metric, cat, member.column_name
            )
            for member in members
        )
        ctes.append(
            "\n".join(
                [
                    f"{metrics_name} AS (",
                    " SELECT",
                    f" {ids_csv},",
                    f" {period},",
                    metric_block,
                    f" FROM {shares_name}",
                    f" GROUP BY {entity_keys}",
                    ")",
                ]
            )
        )

    base_ids = ", ".join(f"b.{name}" for name in id_names)
    metric_cols = ",\n ".join(column.column_name for column in pipeline.layer2b)

    statement = "\n".join(
        [
            f"CREATE OR REPLACE TABLE {target} AS",
            "WITH",
            ",\n".join(ctes),
            "SELECT",
            f" {base_ids},",
            f" b.{period},",
            f" {metric_cols}",
            "FROM base b",
            "\n".join(join_lines),
        ]
    )
    return SQLOutput(sql=self._transpile(statement), dialect=self.dialect)
|
|
312
|
+
|
|
313
|
+
def _distributional_expr(self, metric: DistributionalMetric, cat_col: str, alias: str) -> str:
    """Build the aliased SQL aggregate for a single distributional metric.

    The expression refers to ``cat_val`` and ``total_val``, the columns
    exposed by the *shares* CTE in :meth:`build_layer2b`. Shares are
    written as ``cat_val / NULLIF(total_val, 0)`` so a zero total yields
    NULL instead of a division error.

    Args:
        metric: The distributional metric to express.
        cat_col: Name of the categorical column (used by ``MODE`` only).
        alias: Output column name appended as ``AS {alias}``.

    Returns:
        The SQL text ``<aggregate> AS <alias>``.

    Raises:
        ValueError: If *metric* is none of ENTROPY, HHI,
            DOMINANT_PROPORTION, MODE or COUNT.
    """
    share = "cat_val / NULLIF(total_val, 0)"
    candidates = (
        (
            DistributionalMetric.ENTROPY,
            f"-SUM(CASE WHEN cat_val > 0 THEN ({share}) * LN({share}) ELSE 0 END)",
        ),
        (DistributionalMetric.HHI, f"SUM(POWER({share}, 2))"),
        (DistributionalMetric.DOMINANT_PROPORTION, f"MAX({share})"),
        (DistributionalMetric.MODE, f"MAX_BY({cat_col}, cat_val)"),
        (DistributionalMetric.COUNT, "COUNT(CASE WHEN cat_val > 0 THEN 1 END)"),
    )
    for candidate, aggregate in candidates:
        if metric == candidate:
            return f"{aggregate} AS {alias}"
    raise ValueError(f"Unsupported distributional metric: {metric}")
|
|
334
|
+
|
|
335
|
+
# ------------------------------------------------------------------
|
|
336
|
+
# build_layer3
|
|
337
|
+
# ------------------------------------------------------------------
|
|
338
|
+
|
|
339
|
+
def build_layer3(self, pipeline: FeatureStorePipeline) -> SQLOutput:
    """Emit the ``CREATE OR REPLACE TABLE`` statement for the Layer 3 table.

    The MOB reference table (alias ``mob``) is joined to Layer 2A (alias
    ``l2a``) on ``l2a.{time_col} = mob.{time_col}_relative``. When the
    pipeline has Layer 2B columns, Layer 2B (alias ``l2b``) is
    ``LEFT JOIN``-ed to ``l2a`` on the id columns and the time column.
    Rows are grouped by the ``l2a`` id columns and
    ``mob.{time_col}_analysis``, which is projected back under the name
    ``{time_col}``. Each entry of ``pipeline.layer3`` adds one aggregate
    produced by :meth:`_temporal_expr`.

    Args:
        pipeline: Pipeline supplying the dataset, output naming and the
            Layer 3 feature definitions.

    Returns:
        A :class:`SQLOutput` carrying the transpiled statement and
        :attr:`dialect`.
    """
    dataset = pipeline.config.dataset
    id_names = [field.name for field in dataset.id_fields]
    period = dataset.time_field.name
    mob_table = self._tbl(pipeline, "mob_ref")
    l2a_table = self._tbl(pipeline, "layer2a")
    l2b_table = self._tbl(pipeline, "layer2b")
    target = self._tbl(pipeline, "layer3")

    analysis_col = f"mob.{period}_analysis"
    l2a_ids = [f"l2a.{name}" for name in id_names]

    projections: list[str] = [", ".join(l2a_ids), f"{analysis_col} AS {period}"]
    projections.extend(
        f"{self._temporal_expr(feature)} AS {feature.column_name}"
        for feature in pipeline.layer3
    )
    select_list = ",\n ".join(projections)

    lines = [
        f"CREATE OR REPLACE TABLE {target} AS",
        "SELECT",
        f" {select_list}",
        f"FROM {mob_table} mob",
        f"JOIN {l2a_table} l2a ON l2a.{period} = mob.{period}_relative",
    ]
    if pipeline.layer2b:
        # l2a is already aligned with the relative period, so Layer 2B is
        # matched against l2a's (ids, time) rather than the MOB table.
        match_on = [f"l2b.{name} = l2a.{name}" for name in id_names]
        match_on.append(f"l2b.{period} = l2a.{period}")
        lines.append(f"LEFT JOIN {l2b_table} l2b ON {' AND '.join(match_on)}")
    lines.append(f"GROUP BY {', '.join(l2a_ids + [analysis_col])}")

    return SQLOutput(sql=self._transpile("\n".join(lines)), dialect=self.dialect)
|
|
392
|
+
|
|
393
|
+
def _temporal_expr(self, feat: TemporalFeature, mob_col: str = "mob.mob") -> str:
    """Return a GROUP BY-compatible aggregate expression for one temporal feature.

    The source column is read from alias ``l2b`` when ``feat.source`` is a
    ``DistributionalColumn`` and from ``l2a`` otherwise. Windowed operators
    aggregate ``CASE WHEN {mob_col} BETWEEN lo AND hi THEN col END``, where
    a backward window of size ``w`` covers offsets ``-(w - 1)..0`` and a
    forward window covers ``0..w - 1``.

    Operator mapping:

    * ``PROM_U`` / ``PROM_P`` -> ``AVG``; ``SUM_U`` / ``SUM_P`` / ``RATIO``
      -> ``SUM``; ``MIN_U`` -> ``MIN``; ``MAX_U`` -> ``MAX``;
      ``MEDIA_ABS`` -> ``MEDIAN`` (all over the window).
    * ``FREQ`` / ``XM`` -> count of non-NULL values inside the window.
    * ``ULT_MES`` -> value at offset 0; ``PREV_MES`` -> value at offset -1;
      ``CREC`` -> ``(current / NULLIF(previous, 0)) - 1``.
    * ``REC`` -> negated largest offset that has a non-NULL value. The
      candidates are limited to the window when one is given, otherwise
      to the feature's direction (offsets ``<= 0`` backward, ``>= 0``
      forward). Without that limit, later periods present in the MOB
      reference table would always win the ``MAX``.
    * Any other operator falls back to the ``ULT_MES`` expression.

    NOTE(review): the paired operators (``PROM_U``/``PROM_P``,
    ``SUM_U``/``SUM_P``, ``FREQ``/``XM``) and ``RATIO`` currently share an
    expression; confirm whether they are meant to differ.

    Args:
        feat: Temporal feature describing source column, operator, window
            size and direction.
        mob_col: Qualified name of the period-offset column.

    Returns:
        The SQL aggregate expression, without an alias.

    Raises:
        ValueError: If a windowed operator is requested for a feature whose
            ``window_size`` is ``None``.
    """
    from featkit.layer2.distributional import DistributionalColumn

    prefix = "l2b" if isinstance(feat.source, DistributionalColumn) else "l2a"
    col = f"{prefix}.{feat.source.column_name}"

    op = feat.operator
    w = feat.window_size
    bwd = feat.direction == TimeWindowDirection.BACKWARD

    # Point-in-time operators do not depend on the window.
    current = f"MAX(CASE WHEN {mob_col} = 0 THEN {col} END)"
    previous = f"MAX(CASE WHEN {mob_col} = -1 THEN {col} END)"
    if op == TemporalOperator.ULT_MES:
        return current
    if op == TemporalOperator.PREV_MES:
        return previous
    if op == TemporalOperator.CREC:
        return f"({current} / NULLIF({previous}, 0)) - 1"

    in_window = None
    if w is not None:
        lo, hi = (-(w - 1), 0) if bwd else (0, w - 1)
        in_window = f"{mob_col} BETWEEN {lo} AND {hi}"

    if op == TemporalOperator.REC:
        if in_window is not None:
            scope = in_window
        else:
            # No window: still keep to the side of the analysis period
            # that matches the feature's direction.
            scope = f"{mob_col} <= 0" if bwd else f"{mob_col} >= 0"
        return f"-MAX(CASE WHEN {scope} AND {col} IS NOT NULL THEN {mob_col} END)"

    window_aggregates = {
        TemporalOperator.PROM_U: "AVG",
        TemporalOperator.PROM_P: "AVG",
        TemporalOperator.SUM_U: "SUM",
        TemporalOperator.SUM_P: "SUM",
        TemporalOperator.MIN_U: "MIN",
        TemporalOperator.MAX_U: "MAX",
        TemporalOperator.MEDIA_ABS: "MEDIAN",
        TemporalOperator.RATIO: "SUM",
    }
    counting_ops = (TemporalOperator.FREQ, TemporalOperator.XM)

    aggregate = window_aggregates.get(op)
    if aggregate is not None or op in counting_ops:
        if in_window is None:
            # Previously this path failed with UnboundLocalError.
            raise ValueError(
                f"Temporal operator {op} requires a window_size, "
                f"but feature {feat.column_name!r} has window_size=None"
            )
        if aggregate is not None:
            return f"{aggregate}(CASE WHEN {in_window} THEN {col} END)"
        return f"COUNT(CASE WHEN {in_window} AND {col} IS NOT NULL THEN 1 END)"

    # Unlisted operators keep the original fallback: the current-period value.
    return current
|
|
452
|
+
|
|
453
|
+
# ------------------------------------------------------------------
|
|
454
|
+
# build_final_join
|
|
455
|
+
# ------------------------------------------------------------------
|
|
456
|
+
|
|
457
|
+
def build_final_join(self, pipeline: FeatureStorePipeline) -> SQLOutput:
    """Emit the ``CREATE OR REPLACE TABLE`` statement for the final feature table.

    Layer 2A (alias ``l2a``) is the driving table. Layer 2B (``l2b``) and
    Layer 3 (``l3``) are each ``LEFT JOIN``-ed with
    ``USING (ids, time_col)``, but only when the pipeline defines columns
    for that layer. Every output column is listed explicitly with its
    table alias rather than relying on ``USING`` column resolution.

    Args:
        pipeline: Pipeline supplying the dataset, output naming and the
            column definitions of each layer.

    Returns:
        A :class:`SQLOutput` carrying the transpiled statement and
        :attr:`dialect`.
    """
    dataset = pipeline.config.dataset
    id_names = [field.name for field in dataset.id_fields]
    period = dataset.time_field.name
    l2a_table = self._tbl(pipeline, "layer2a")
    l2b_table = self._tbl(pipeline, "layer2b")
    l3_table = self._tbl(pipeline, "layer3")
    target = self._tbl(pipeline, "features")

    join_keys = f"{', '.join(id_names)}, {period}"

    projections: list[str] = [f"l2a.{name}" for name in id_names]
    projections.append(f"l2a.{period}")
    projections.extend(f"l2a.{column.column_name}" for column in pipeline.layer2a)

    join_lines: list[str] = []
    optional_layers = (
        ("l2b", l2b_table, pipeline.layer2b),
        ("l3", l3_table, pipeline.layer3),
    )
    for alias, table, members in optional_layers:
        if not members:
            continue
        projections.extend(f"{alias}.{member.column_name}" for member in members)
        join_lines.append(f"LEFT JOIN {table} {alias} USING ({join_keys})")

    select_list = ",\n ".join(projections)
    lines = [
        f"CREATE OR REPLACE TABLE {target} AS",
        "SELECT",
        f" {select_list}",
        f"FROM {l2a_table} l2a",
    ]
    lines.extend(join_lines)

    return SQLOutput(sql=self._transpile("\n".join(lines)), dialect=self.dialect)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""DatabricksSQLCodeGenerator — SQL generation targeting the Databricks dialect."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from featkit.generators.sql.base import AbstractSQLCodeGenerator
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class DatabricksSQLCodeGenerator(AbstractSQLCodeGenerator):
    """Concrete SQL generator bound to SQLGlot's ``"databricks"`` dialect.

    The class adds no generation logic of its own: everything is inherited
    from :class:`AbstractSQLCodeGenerator`. It only supplies the dialect
    identifier, which the base class passes to SQLGlot when parsing and
    rendering, so the emitted SQL follows SQLGlot's Databricks rules.
    """

    @property
    def dialect(self) -> str:
        """SQLGlot dialect identifier for Databricks."""
        return "databricks"
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""SnowflakeSQLCodeGenerator — SQL generation targeting the Snowflake dialect."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from featkit.generators.sql.base import AbstractSQLCodeGenerator
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SnowflakeSQLCodeGenerator(AbstractSQLCodeGenerator):
    """Concrete SQL generator bound to SQLGlot's ``"snowflake"`` dialect.

    The class adds no generation logic of its own: everything is inherited
    from :class:`AbstractSQLCodeGenerator`. It only supplies the dialect
    identifier, which the base class passes to SQLGlot when parsing and
    rendering, so the emitted SQL follows SQLGlot's Snowflake rules.
    """

    @property
    def dialect(self) -> str:
        """SQLGlot dialect identifier for Snowflake."""
        return "snowflake"
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""SparkSQLCodeGenerator — SQL generation targeting the Apache Spark SQL dialect."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from featkit.generators.sql.base import AbstractSQLCodeGenerator
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SparkSQLCodeGenerator(AbstractSQLCodeGenerator):
    """Concrete SQL generator bound to SQLGlot's ``"spark"`` dialect.

    The class adds no generation logic of its own: everything is inherited
    from :class:`AbstractSQLCodeGenerator`. It only supplies the dialect
    identifier, which the base class passes to SQLGlot when parsing and
    rendering, so the emitted SQL follows SQLGlot's Spark SQL rules.
    """

    @property
    def dialect(self) -> str:
        """SQLGlot dialect identifier for Apache Spark SQL."""
        return "spark"
|
featkit/layer2/.gitkeep
ADDED
|
File without changes
|
|
File without changes
|
featkit/layer2/base.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
"""Abstract base for Layer 2 output columns."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from featkit.contracts.measurement.defaults import get_default_contract
|
|
9
|
+
from featkit.contracts.output.defaults import get_default_output_contract
|
|
10
|
+
from featkit.enums import Layer2Aggregator, Layer2OutputType
|
|
11
|
+
from featkit.fields.measurement_field import MeasurementField
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from featkit.contracts.output.base import AbstractLayer2OutputContract
|
|
15
|
+
|
|
16
|
+
#: Separator used between parts of every Layer 2 column name.
|
|
17
|
+
#: Field names and categorical values must not contain this string.
|
|
18
|
+
COLUMN_NAME_SEP = "__"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class AbstractLayer2Column(ABC):
    """Shared base class for columns of the Layer 2 output table.

    The constructor validates and stores the source measurement and the
    Layer 2 aggregator. Subclasses provide ``output_type`` and
    ``column_name``; ``output_contract`` is looked up from ``output_type``.

    Raises:
        ValueError: If ``source_measurement.name`` contains
            :data:`COLUMN_NAME_SEP`, or if ``layer2_aggregator`` is rejected
            by the measurement's contract.
    """

    @staticmethod
    def _check_name_part(value: str, description: str) -> None:
        """Reject *value* when it contains :data:`COLUMN_NAME_SEP`.

        Args:
            value: Text that will become part of a column name.
            description: Label for *value* used in the error message.

        Raises:
            ValueError: If the separator occurs in *value*.
        """
        if COLUMN_NAME_SEP not in value:
            return
        raise ValueError(
            f"{description} {value!r} must not contain the column name separator "
            f"{COLUMN_NAME_SEP!r}"
        )

    def __init__(
        self,
        source_measurement: MeasurementField,
        layer2_aggregator: Layer2Aggregator,
    ) -> None:
        """Validate the measurement/aggregator pair and store both.

        Args:
            source_measurement: Measurement field this column aggregates.
            layer2_aggregator: Aggregator applied to the measurement.

        Raises:
            ValueError: If the measurement name contains the separator or
                the aggregator is not valid for the measurement's contract.
        """
        AbstractLayer2Column._check_name_part(source_measurement.name, "source_measurement.name")

        # Use the field's own contract when it has one, otherwise the
        # default contract for its measurement type.
        contract = source_measurement.contract
        if not contract:
            contract = get_default_contract(source_measurement.measurement_type)

        if not contract.is_valid(layer2_aggregator):
            ranked = sorted(contract.valid_layer2_aggregators, key=lambda agg: agg.value)
            valid = ", ".join(agg.name for agg in ranked)
            raise ValueError(
                f"Layer2Aggregator.{layer2_aggregator.name} is not valid for "
                f"MeasurementType.{source_measurement.measurement_type.name}. "
                f"Valid aggregators: {valid}"
            )

        self.source_measurement = source_measurement
        self.layer2_aggregator = layer2_aggregator

    @property
    @abstractmethod
    def output_type(self) -> Layer2OutputType:
        """Layer 2 output type of this column, supplied by the subclass."""
        ...

    @property
    def output_contract(self) -> AbstractLayer2OutputContract:
        """Default output contract registered for :attr:`output_type`."""
        return get_default_output_contract(self.output_type)

    @property
    @abstractmethod
    def column_name(self) -> str:
        """Name of this column in the Layer 2 output table, supplied by the subclass."""
        ...

    def __repr__(self) -> str:
        """Return ``ClassName(measurement=..., aggregator=..., output_type=...)``."""
        fields = ", ".join(
            [
                f"measurement={self.source_measurement.name!r}",
                f"aggregator={self.layer2_aggregator.name!r}",
                f"output_type={self.output_type.name!r}",
            ]
        )
        return f"{type(self).__name__}({fields})"
|