featkit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. featkit/__init__.py +1 -0
  2. featkit/builders/.gitkeep +0 -0
  3. featkit/builders/__init__.py +0 -0
  4. featkit/builders/distributional_space.py +77 -0
  5. featkit/builders/pivot_space.py +102 -0
  6. featkit/builders/temporal_space.py +86 -0
  7. featkit/config.py +38 -0
  8. featkit/contracts/__init__.py +1 -0
  9. featkit/contracts/measurement/.gitkeep +0 -0
  10. featkit/contracts/measurement/__init__.py +27 -0
  11. featkit/contracts/measurement/base.py +47 -0
  12. featkit/contracts/measurement/defaults.py +117 -0
  13. featkit/contracts/output/.gitkeep +0 -0
  14. featkit/contracts/output/__init__.py +19 -0
  15. featkit/contracts/output/base.py +36 -0
  16. featkit/contracts/output/defaults.py +80 -0
  17. featkit/dataset/.gitkeep +0 -0
  18. featkit/dataset/__init__.py +0 -0
  19. featkit/dataset/base.py +120 -0
  20. featkit/enums.py +110 -0
  21. featkit/fields/.gitkeep +0 -0
  22. featkit/fields/__init__.py +9 -0
  23. featkit/fields/base.py +48 -0
  24. featkit/fields/categorical_field.py +55 -0
  25. featkit/fields/id_field.py +14 -0
  26. featkit/fields/measurement_field.py +42 -0
  27. featkit/fields/time_field.py +43 -0
  28. featkit/generators/__init__.py +0 -0
  29. featkit/generators/base.py +171 -0
  30. featkit/generators/output.py +118 -0
  31. featkit/generators/pyspark/.gitkeep +0 -0
  32. featkit/generators/pyspark/__init__.py +0 -0
  33. featkit/generators/pyspark/databricks.py +448 -0
  34. featkit/generators/sql/.gitkeep +0 -0
  35. featkit/generators/sql/__init__.py +0 -0
  36. featkit/generators/sql/base.py +496 -0
  37. featkit/generators/sql/databricks.py +19 -0
  38. featkit/generators/sql/snowflake.py +19 -0
  39. featkit/generators/sql/spark_sql.py +19 -0
  40. featkit/layer2/.gitkeep +0 -0
  41. featkit/layer2/__init__.py +0 -0
  42. featkit/layer2/base.py +86 -0
  43. featkit/layer2/distributional.py +51 -0
  44. featkit/layer2/pivoted.py +63 -0
  45. featkit/layer3/.gitkeep +0 -0
  46. featkit/layer3/__init__.py +0 -0
  47. featkit/layer3/temporal_feature.py +87 -0
  48. featkit/pipeline.py +63 -0
  49. featkit-0.1.0.dist-info/METADATA +140 -0
  50. featkit-0.1.0.dist-info/RECORD +52 -0
  51. featkit-0.1.0.dist-info/WHEEL +4 -0
  52. featkit-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,496 @@
1
+ """AbstractSQLCodeGenerator — dialect-agnostic SQL generation via SQLGlot."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ from abc import abstractmethod
7
+ from collections import defaultdict
8
+ from typing import TYPE_CHECKING
9
+
10
+ import sqlglot
11
+ import sqlglot.expressions as exp
12
+
13
+ from featkit.enums import DistributionalMetric, TemporalOperator, TimeWindowDirection
14
+ from featkit.generators.base import AbstractCodeGenerator
15
+ from featkit.generators.output import SQLOutput
16
+
17
+ if TYPE_CHECKING:
18
+ from featkit.layer2.distributional import DistributionalColumn
19
+ from featkit.layer3.temporal_feature import TemporalFeature
20
+ from featkit.pipeline import FeatureStorePipeline
21
+
22
+
23
class AbstractSQLCodeGenerator(AbstractCodeGenerator):
    """Base for all SQL-emitting code generators.

    Subclasses must declare :attr:`dialect`. All five build steps
    (:meth:`build_mob_table`, :meth:`build_layer2a`, :meth:`build_layer2b`,
    :meth:`build_layer3`, :meth:`build_final_join`) are implemented here by
    composing SQL strings that are parsed and re-emitted through SQLGlot so
    that dialect-specific rendering is applied automatically.

    Both the schema and table portions of every *generated* table reference
    are rendered as quoted identifiers (see :meth:`_tbl`), guarding against
    reserved words and special characters in user-supplied names. The
    dataset's ``source_reference`` is interpolated as given. Categorical
    values used in CASE WHEN predicates are rendered as SQLGlot
    :class:`~sqlglot.expressions.Literal` objects so that single quotes and
    other special characters are escaped correctly.

    MOB reference table structure (period self-join):

    .. code-block:: text

        periodos_unicos → periodos_ordenados (ROW_NUMBER by period)
        periodos_ordenados A LEFT JOIN periodos_ordenados B ON 1=1
            → (ts_analysis, ts_relative, mob = B.mob - A.mob)

    Layer 3 features are then derived by joining with the MOB table and using
    GROUP BY + CASE WHEN aggregation to implement each temporal operator.
    """

    @property
    @abstractmethod
    def dialect(self) -> str:
        """SQLGlot dialect identifier (e.g. ``"snowflake"``, ``"databricks"``)."""
        ...

    def render(self, expr: exp.Expr) -> str:
        """Render a SQLGlot expression tree to the target dialect SQL string."""
        return str(expr.sql(dialect=self.dialect))

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _tbl(self, pipeline: FeatureStorePipeline, suffix: str) -> str:
        """Return a fully-qualified, quoted intermediate table reference.

        The table name is ``{output_table_prefix}{suffix}`` and the schema is
        ``output_schema``, both read from ``pipeline.config``. Both identifiers
        are quoted through SQLGlot so reserved words and special characters
        are rendered safely per dialect.
        """
        cfg = pipeline.config
        table_expr = exp.Table(
            this=exp.to_identifier(f"{cfg.output_table_prefix}{suffix}", quoted=True),
            db=exp.to_identifier(cfg.output_schema, quoted=True),
        )
        return self.render(table_expr)

    def _transpile(self, sql: str) -> str:
        """Parse *sql* in :attr:`dialect` and re-emit it (pretty-printed).

        Only the first statement returned by ``sqlglot.transpile`` is kept.
        """
        return sqlglot.transpile(sql, read=self.dialect, write=self.dialect, pretty=True)[0]

    @staticmethod
    def _safe_cte_name(name: str) -> str:
        """Return a SQL-safe CTE identifier from *name*.

        Replaces any character outside ``[A-Za-z0-9_]`` with ``_`` so that
        field names containing hyphens, spaces, or other special characters
        produce valid unquoted identifiers.
        """
        # The pattern mirrors the documented keep-set; underscores pass through.
        return re.sub(r"[^A-Za-z0-9_]", "_", name)

    def _str_literal(self, value: str) -> str:
        """Render *value* as a correctly escaped SQL string literal."""
        return exp.Literal.string(value).sql(dialect=self.dialect)

    def _quoted_id(self, name: str) -> str:
        """Render *name* as a quoted SQL identifier."""
        return exp.Identifier(this=name, quoted=True).sql(dialect=self.dialect)

    # ------------------------------------------------------------------
    # build_mob_table
    # ------------------------------------------------------------------

    def build_mob_table(self, pipeline: FeatureStorePipeline) -> SQLOutput:
        """Generate the MOB (Month-on-Books) period cross-reference table.

        The implementation follows the period self-join pattern:

        1. ``periodos_unicos`` — distinct periods from the source table.
        2. ``periodos_ordenados`` — assigns a sequential ``mob`` number to each
           period via ``ROW_NUMBER() OVER (ORDER BY {time_col})``.
        3. Self-join on ``1 = 1`` to produce all (analysis, relative) period
           combinations.
        4. ``mob = B.mob - A.mob`` gives the signed period offset: 0 = current,
           negative = past, positive = future.

        The resulting table has columns:
        ``{time_col}_analysis``, ``{time_col}_relative``, ``mob``.

        Returns:
            A :class:`SQLOutput` holding the ``CREATE OR REPLACE TABLE``
            statement after a round-trip through :meth:`_transpile`.
        """
        ds = pipeline.config.dataset
        time_col = ds.time_field.name
        src = ds.source_reference
        tbl = self._tbl(pipeline, "mob_ref")

        time_analysis = f"{time_col}_analysis"
        time_relative = f"{time_col}_relative"

        sql = (
            f"CREATE OR REPLACE TABLE {tbl} AS\n"
            f"WITH periodos_unicos AS (\n"
            f" SELECT DISTINCT {time_col}\n"
            f" FROM {src}\n"
            f"),\n"
            f"periodos_ordenados AS (\n"
            f" SELECT {time_col},\n"
            f" ROW_NUMBER() OVER (ORDER BY {time_col}) AS mob\n"
            f" FROM periodos_unicos\n"
            f")\n"
            f"SELECT\n"
            f" a.{time_col} AS {time_analysis},\n"
            f" b.{time_col} AS {time_relative},\n"
            f" b.mob - a.mob AS mob\n"
            f"FROM periodos_ordenados a\n"
            f"LEFT JOIN periodos_ordenados b ON 1 = 1"
        )
        return SQLOutput(sql=self._transpile(sql), dialect=self.dialect)

    # ------------------------------------------------------------------
    # build_layer2a
    # ------------------------------------------------------------------

    def build_layer2a(self, pipeline: FeatureStorePipeline) -> SQLOutput:
        """Generate the Layer 2A pivot aggregation table.

        One output column is produced per entry of ``pipeline.layer2a``. When
        the column's categorical combination has at least one non-``None``
        value, the measurement is wrapped in ``CASE WHEN <predicate> THEN … END``
        before aggregation; otherwise the measurement is aggregated directly.

        Categorical values in CASE WHEN predicates are rendered via
        :meth:`_str_literal` to escape special characters and prevent
        SQL injection in the generated script. Categorical column identifiers
        are quoted via :meth:`_quoted_id`.
        """
        ds = pipeline.config.dataset
        id_cols = [f.name for f in ds.id_fields]
        time_col = ds.time_field.name
        src = ds.source_reference
        tbl = self._tbl(pipeline, "layer2a")

        # Grouping keys come first in the SELECT list, followed by the pivots.
        select_parts: list[str] = list(id_cols) + [time_col]

        for col in pipeline.layer2a:
            meas = col.source_measurement.name
            agg = col.layer2_aggregator.value
            alias = col.column_name

            # Sort by field name so the generated predicate is deterministic.
            # NOTE(review): categorical identifiers are quoted here while id,
            # time and measurement names are emitted unquoted — confirm this
            # mix is intended for case-folding dialects.
            conditions = [
                f"{self._quoted_id(cat_field.name)} = {self._str_literal(cat_val)}"
                for cat_field, cat_val in sorted(
                    col.categorical_combination.items(), key=lambda kv: kv[0].name
                )
                if cat_val is not None
            ]

            if conditions:
                predicate = " AND ".join(conditions)
                agg_expr = f"{agg}(CASE WHEN {predicate} THEN {meas} END)"
            else:
                agg_expr = f"{agg}({meas})"

            select_parts.append(f"{agg_expr} AS {alias}")

        group_cols = ", ".join(id_cols + [time_col])
        select_list = ",\n ".join(select_parts)

        sql = (
            f"CREATE OR REPLACE TABLE {tbl} AS\n"
            f"SELECT\n {select_list}\n"
            f"FROM {src}\n"
            f"GROUP BY {group_cols}"
        )
        return SQLOutput(sql=self._transpile(sql), dialect=self.dialect)

    # ------------------------------------------------------------------
    # build_layer2b
    # ------------------------------------------------------------------

    def build_layer2b(self, pipeline: FeatureStorePipeline) -> SQLOutput:
        """Generate the Layer 2B distributional CTEs table.

        For each (categorical, measurement, aggregator) group three CTEs are
        produced:

        * ``{safe}_raw`` — ``GROUP BY`` that computes per-category aggregates
          (``cat_val``).
        * ``{safe}_shares`` — window function over ``{safe}_raw`` that adds
          ``total_val = SUM(cat_val) OVER (PARTITION BY ids, time_col)``.
          Splitting into two CTEs avoids the invalid Snowflake pattern of
          nesting an aggregate inside a window aggregate
          (``SUM(SUM(x)) OVER (...)``).
        * ``{safe}_metrics`` — computes the requested distributional statistics
          from ``{safe}_shares``.

        All metrics CTEs are joined back to a ``base`` (DISTINCT ids × ts) CTE
        via ``LEFT JOIN``.

        CTE names are sanitised via :meth:`_safe_cte_name` to handle field
        names that contain special characters.

        Returns:
            A :class:`SQLOutput` with the generated statement, or one with an
            empty ``sql`` string when ``pipeline.layer2b`` is empty.
        """
        if not pipeline.layer2b:
            return SQLOutput(sql="", dialect=self.dialect)

        ds = pipeline.config.dataset
        id_cols = [f.name for f in ds.id_fields]
        time_col = ds.time_field.name
        src = ds.source_reference
        tbl = self._tbl(pipeline, "layer2b")

        id_list = ", ".join(id_cols)
        b_id_sel = ", ".join(f"b.{c}" for c in id_cols)

        # Group distributional columns by (cat_name, meas_name, agg_name) so
        # that columns sharing the same aggregation reuse one CTE chain.
        groups: dict[tuple[str, str, str], list[DistributionalColumn]] = defaultdict(list)
        for col in pipeline.layer2b:
            key = (col.categorical.name, col.source_measurement.name, col.layer2_aggregator.value)
            groups[key].append(col)

        cte_defs: list[str] = []
        metrics_cte_names: list[str] = []

        for (cat_name, meas_name, agg_name), cols in groups.items():
            safe = "_".join(self._safe_cte_name(p) for p in [cat_name, meas_name, agg_name.lower()])
            raw_cte = safe + "_raw"
            shares_cte = safe + "_shares"
            metrics_cte = safe + "_metrics"
            metrics_cte_names.append(metrics_cte)

            # Step 1: group-level aggregation (no nested aggregate)
            cte_defs.append(
                f"{raw_cte} AS (\n"
                f" SELECT\n"
                f" {id_list},\n"
                f" {time_col},\n"
                f" {cat_name},\n"
                f" {agg_name}({meas_name}) AS cat_val\n"
                f" FROM {src}\n"
                f" GROUP BY {id_list}, {time_col}, {cat_name}\n"
                f")"
            )

            # Step 2: window function on already-aggregated cat_val
            cte_defs.append(
                f"{shares_cte} AS (\n"
                f" SELECT\n"
                f" *,\n"
                f" SUM(cat_val) OVER (PARTITION BY {id_list}, {time_col}) AS total_val\n"
                f" FROM {raw_cte}\n"
                f")"
            )

            # Step 3: distributional metrics, one expression per requested column
            metric_exprs = [
                " "
                + self._distributional_expr(col.distributional_metric, cat_name, col.column_name)
                for col in cols
            ]
            metrics_select = ",\n".join(metric_exprs)

            cte_defs.append(
                f"{metrics_cte} AS (\n"
                f" SELECT\n"
                f" {id_list},\n"
                f" {time_col},\n"
                f"{metrics_select}\n"
                f" FROM {shares_cte}\n"
                f" GROUP BY {id_list}, {time_col}\n"
                f")"
            )

        # Prepend the base CTE
        base_cte = f"base AS (\n SELECT DISTINCT {id_list}, {time_col} FROM {src}\n)"
        all_cte_defs = [base_cte] + cte_defs
        ctes = ",\n".join(all_cte_defs)

        all_metric_cols = ",\n ".join(col.column_name for col in pipeline.layer2b)

        joins = "\n".join(
            f"LEFT JOIN {mc} USING ({id_list}, {time_col})" for mc in metrics_cte_names
        )

        sql = (
            f"CREATE OR REPLACE TABLE {tbl} AS\n"
            f"WITH\n{ctes}\nSELECT\n"
            f" {b_id_sel},\n"
            f" b.{time_col},\n"
            f" {all_metric_cols}\n"
            f"FROM base b\n"
            f"{joins}"
        )
        return SQLOutput(sql=self._transpile(sql), dialect=self.dialect)

    def _distributional_expr(self, metric: DistributionalMetric, cat_col: str, alias: str) -> str:
        """Return a SQL aggregate expression for one distributional metric.

        Operates on columns from the *shares* CTE:
        ``cat_val`` = per-category aggregate, ``total_val`` = entity×period total.

        Args:
            metric: The distributional statistic to express.
            cat_col: Name of the categorical column (used by ``MODE`` only).
            alias: Output column alias appended as ``AS {alias}``.

        Raises:
            ValueError: If *metric* is not one of the handled members.
        """
        if metric == DistributionalMetric.ENTROPY:
            return (
                "-SUM(CASE WHEN cat_val > 0 "
                "THEN (cat_val / NULLIF(total_val, 0)) * LN(cat_val / NULLIF(total_val, 0)) "
                f"ELSE 0 END) AS {alias}"
            )
        if metric == DistributionalMetric.HHI:
            return f"SUM(POWER(cat_val / NULLIF(total_val, 0), 2)) AS {alias}"
        if metric == DistributionalMetric.DOMINANT_PROPORTION:
            return f"MAX(cat_val / NULLIF(total_val, 0)) AS {alias}"
        if metric == DistributionalMetric.MODE:
            return f"MAX_BY({cat_col}, cat_val) AS {alias}"
        if metric == DistributionalMetric.COUNT:
            return f"COUNT(CASE WHEN cat_val > 0 THEN 1 END) AS {alias}"
        raise ValueError(f"Unsupported distributional metric: {metric}")

    # ------------------------------------------------------------------
    # build_layer3
    # ------------------------------------------------------------------

    def build_layer3(self, pipeline: FeatureStorePipeline) -> SQLOutput:
        """Generate the Layer 3 temporal features table.

        Joins the MOB reference table with Layer 2A (and Layer 2B if present),
        then derives every :class:`~featkit.layer3.temporal_feature.TemporalFeature`
        via a ``GROUP BY (id_cols, ts_analysis)`` aggregation. Each operator is
        expressed as an aggregate over a ``CASE WHEN mob BETWEEN … THEN col END``
        filter, so the approach scales to any window size without window-function
        ``ROWS BETWEEN`` clauses.

        The join between Layer 2A/B and the MOB table is on
        ``l2a.{time_col} = mob.{time_col}_relative`` so that the correct
        relative-period values are aggregated for each analysis snapshot.

        Raises:
            ValueError: Propagated from :meth:`_temporal_expr` when a
                window-based feature has no ``window_size``.
        """
        ds = pipeline.config.dataset
        id_cols = [f.name for f in ds.id_fields]
        time_col = ds.time_field.name
        mob_tbl = self._tbl(pipeline, "mob_ref")
        l2a_tbl = self._tbl(pipeline, "layer2a")
        l2b_tbl = self._tbl(pipeline, "layer2b")
        tbl = self._tbl(pipeline, "layer3")

        time_relative = f"{time_col}_relative"
        time_analysis = f"{time_col}_analysis"

        l2a_id_sel = ", ".join(f"l2a.{c}" for c in id_cols)
        group_by = ", ".join([f"l2a.{c}" for c in id_cols] + [f"mob.{time_analysis}"])

        # The analysis period is exposed under the dataset's own time column name.
        select_parts: list[str] = [l2a_id_sel, f"mob.{time_analysis} AS {time_col}"]
        for feat in pipeline.layer3:
            expr = self._temporal_expr(feat)
            select_parts.append(f"{expr} AS {feat.column_name}")

        select_list = ",\n ".join(select_parts)

        # Layer 2B join: use l2a's (id, ts) since l2a.ts = mob.ts_relative already
        if pipeline.layer2b:
            l2b_join_conds = " AND ".join(
                [f"l2b.{c} = l2a.{c}" for c in id_cols] + [f"l2b.{time_col} = l2a.{time_col}"]
            )
            l2b_join = f"\nLEFT JOIN {l2b_tbl} l2b ON {l2b_join_conds}"
        else:
            l2b_join = ""

        sql = (
            f"CREATE OR REPLACE TABLE {tbl} AS\n"
            f"SELECT\n {select_list}\n"
            f"FROM {mob_tbl} mob\n"
            f"JOIN {l2a_tbl} l2a ON l2a.{time_col} = mob.{time_relative}"
            f"{l2b_join}\n"
            f"GROUP BY {group_by}"
        )
        return SQLOutput(sql=self._transpile(sql), dialect=self.dialect)

    def _temporal_expr(self, feat: TemporalFeature, mob_col: str = "mob.mob") -> str:
        """Return a GROUP BY-compatible aggregate expression for one temporal feature.

        Uses ``CASE WHEN {mob_col} BETWEEN … THEN col END`` inside an aggregate
        function so that each operator maps to a standard SQL aggregate over the
        appropriate period range. Backward windows use negative mob offsets
        (current = 0, one period back = -1, etc.); forward windows use positive
        offsets.

        Point-in-time operators (``ULT_MES``, ``PREV_MES``, ``CREC``, ``REC``)
        ignore ``feat.window_size``. Operators not handled explicitly fall back
        to the current-period value (``mob = 0``).

        Args:
            feat: The temporal feature to express. Its source column is read
                from alias ``l2b`` when the source is a
                :class:`DistributionalColumn`, otherwise from ``l2a``.
            mob_col: Qualified name of the signed period-offset column.

        Raises:
            ValueError: If the operator aggregates over a window but
                ``feat.window_size`` is ``None``.
        """
        # Imported lazily: at module level this name is only bound under
        # TYPE_CHECKING, and it is needed here for a runtime isinstance check.
        from featkit.layer2.distributional import DistributionalColumn

        src_col = feat.source.column_name
        prefix = "l2b" if isinstance(feat.source, DistributionalColumn) else "l2a"
        col = f"{prefix}.{src_col}"

        op = feat.operator
        w = feat.window_size
        bwd = feat.direction == TimeWindowDirection.BACKWARD

        current_value = f"MAX(CASE WHEN {mob_col} = 0 THEN {col} END)"
        previous_value = f"MAX(CASE WHEN {mob_col} = -1 THEN {col} END)"

        # --- Point-in-time operators: no window required -----------------
        if op == TemporalOperator.ULT_MES:
            return current_value
        if op == TemporalOperator.PREV_MES:
            return previous_value
        if op == TemporalOperator.CREC:
            return f"({current_value} / NULLIF({previous_value}, 0)) - 1"
        if op == TemporalOperator.REC:
            # NOTE(review): this is not restricted to the window or direction,
            # so periods with mob > 0 also contribute — confirm that is intended.
            return f"-MAX(CASE WHEN {col} IS NOT NULL THEN {mob_col} END)"

        # --- Window-based operators --------------------------------------
        # Several operators currently share the same SQL aggregate.
        windowed_aggs = {
            TemporalOperator.PROM_U: "AVG",
            TemporalOperator.PROM_P: "AVG",
            TemporalOperator.SUM_U: "SUM",
            TemporalOperator.SUM_P: "SUM",
            TemporalOperator.MIN_U: "MIN",
            TemporalOperator.MAX_U: "MAX",
            TemporalOperator.MEDIA_ABS: "MEDIAN",
            TemporalOperator.RATIO: "SUM",
        }
        agg_fn = windowed_aggs.get(op)
        counts_non_null = op in (TemporalOperator.FREQ, TemporalOperator.XM)

        if agg_fn is None and not counts_non_null:
            # Unrecognised operator: keep the current-period fallback.
            return current_value

        if w is None:
            # Without a window the CASE filter cannot be built; fail clearly
            # instead of tripping over unassigned locals further down.
            raise ValueError(
                f"Temporal operator {op} requires a window_size, "
                f"but feature {feat.column_name!r} has window_size=None"
            )

        if bwd:
            lo, hi = -(w - 1), 0
        else:
            lo, hi = 0, w - 1
        in_window = f"{mob_col} BETWEEN {lo} AND {hi}"

        if counts_non_null:
            return f"COUNT(CASE WHEN {in_window} AND {col} IS NOT NULL THEN 1 END)"
        return f"{agg_fn}(CASE WHEN {in_window} THEN {col} END)"

    # ------------------------------------------------------------------
    # build_final_join
    # ------------------------------------------------------------------

    def build_final_join(self, pipeline: FeatureStorePipeline) -> SQLOutput:
        """Generate the final feature table joining Layer 2 and Layer 3.

        Layer 2A is the driving table. Layer 2B and Layer 3 are attached with
        ``LEFT JOIN … USING (ids, time_col)`` only when the pipeline defines
        columns for them.
        """
        ds = pipeline.config.dataset
        id_cols = [f.name for f in ds.id_fields]
        time_col = ds.time_field.name
        l2a_tbl = self._tbl(pipeline, "layer2a")
        l2b_tbl = self._tbl(pipeline, "layer2b")
        l3_tbl = self._tbl(pipeline, "layer3")
        tbl = self._tbl(pipeline, "features")

        id_using = ", ".join(id_cols)

        # Explicit column list — avoids ambiguity from USING join semantics
        select_parts: list[str] = (
            [f"l2a.{c}" for c in id_cols]
            + [f"l2a.{time_col}"]
            + [f"l2a.{col.column_name}" for col in pipeline.layer2a]
        )
        if pipeline.layer2b:
            select_parts += [f"l2b.{col.column_name}" for col in pipeline.layer2b]
        if pipeline.layer3:
            select_parts += [f"l3.{feat.column_name}" for feat in pipeline.layer3]

        select_list = ",\n ".join(select_parts)

        l2b_join = (
            f"\nLEFT JOIN {l2b_tbl} l2b USING ({id_using}, {time_col})" if pipeline.layer2b else ""
        )
        l3_join = (
            f"\nLEFT JOIN {l3_tbl} l3 USING ({id_using}, {time_col})" if pipeline.layer3 else ""
        )

        sql = (
            f"CREATE OR REPLACE TABLE {tbl} AS\n"
            f"SELECT\n {select_list}\n"
            f"FROM {l2a_tbl} l2a"
            f"{l2b_join}"
            f"{l3_join}"
        )
        return SQLOutput(sql=self._transpile(sql), dialect=self.dialect)
@@ -0,0 +1,19 @@
1
+ """DatabricksSQLCodeGenerator — SQL generation targeting the Databricks dialect."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from featkit.generators.sql.base import AbstractSQLCodeGenerator
6
+
7
+
8
class DatabricksSQLCodeGenerator(AbstractSQLCodeGenerator):
    """Concrete SQL generator bound to the SQLGlot ``"databricks"`` dialect.

    Every build step is inherited unchanged from
    :class:`AbstractSQLCodeGenerator`. This subclass exists only to supply
    the dialect identifier, which the base class passes to SQLGlot whenever
    it parses or renders SQL.
    """

    @property
    def dialect(self) -> str:
        """SQLGlot dialect identifier for Databricks."""
        return "databricks"
@@ -0,0 +1,19 @@
1
+ """SnowflakeSQLCodeGenerator — SQL generation targeting the Snowflake dialect."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from featkit.generators.sql.base import AbstractSQLCodeGenerator
6
+
7
+
8
class SnowflakeSQLCodeGenerator(AbstractSQLCodeGenerator):
    """Concrete SQL generator bound to the SQLGlot ``"snowflake"`` dialect.

    Every build step is inherited unchanged from
    :class:`AbstractSQLCodeGenerator`. This subclass exists only to supply
    the dialect identifier, which the base class passes to SQLGlot whenever
    it parses or renders SQL.
    """

    @property
    def dialect(self) -> str:
        """SQLGlot dialect identifier for Snowflake."""
        return "snowflake"
@@ -0,0 +1,19 @@
1
+ """SparkSQLCodeGenerator — SQL generation targeting the Apache Spark SQL dialect."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from featkit.generators.sql.base import AbstractSQLCodeGenerator
6
+
7
+
8
class SparkSQLCodeGenerator(AbstractSQLCodeGenerator):
    """Concrete SQL generator bound to the SQLGlot ``"spark"`` dialect.

    Every build step is inherited unchanged from
    :class:`AbstractSQLCodeGenerator`. This subclass exists only to supply
    the dialect identifier, which the base class passes to SQLGlot whenever
    it parses or renders SQL.
    """

    @property
    def dialect(self) -> str:
        """SQLGlot dialect identifier for Apache Spark SQL."""
        return "spark"
File without changes
File without changes
featkit/layer2/base.py ADDED
@@ -0,0 +1,86 @@
1
+ """Abstract base for Layer 2 output columns."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from abc import ABC, abstractmethod
6
+ from typing import TYPE_CHECKING
7
+
8
+ from featkit.contracts.measurement.defaults import get_default_contract
9
+ from featkit.contracts.output.defaults import get_default_output_contract
10
+ from featkit.enums import Layer2Aggregator, Layer2OutputType
11
+ from featkit.fields.measurement_field import MeasurementField
12
+
13
+ if TYPE_CHECKING:
14
+ from featkit.contracts.output.base import AbstractLayer2OutputContract
15
+
16
+ #: Separator used between parts of every Layer 2 column name.
17
+ #: Field names and categorical values must not contain this string.
18
+ COLUMN_NAME_SEP = "__"
19
+
20
+
21
+ class AbstractLayer2Column(ABC):
22
+ """Common base for every column in the Layer 2 horizontal concept table.
23
+
24
+ Subclasses supply the concrete ``output_type`` and ``column_name``; this
25
+ class derives ``output_contract`` from ``output_type`` automatically.
26
+
27
+ Raises:
28
+ ValueError: If ``layer2_aggregator`` is not permitted by the
29
+ measurement's contract, or if ``source_measurement.name``
30
+ contains the column name separator.
31
+ """
32
+
33
+ @staticmethod
34
+ def _check_name_part(value: str, description: str) -> None:
35
+ """Raise ``ValueError`` if *value* contains :data:`COLUMN_NAME_SEP`."""
36
+ if COLUMN_NAME_SEP in value:
37
+ raise ValueError(
38
+ f"{description} {value!r} must not contain the column name separator "
39
+ f"{COLUMN_NAME_SEP!r}"
40
+ )
41
+
42
+ def __init__(
43
+ self,
44
+ source_measurement: MeasurementField,
45
+ layer2_aggregator: Layer2Aggregator,
46
+ ) -> None:
47
+ AbstractLayer2Column._check_name_part(source_measurement.name, "source_measurement.name")
48
+ contract = source_measurement.contract or get_default_contract(
49
+ source_measurement.measurement_type
50
+ )
51
+ if not contract.is_valid(layer2_aggregator):
52
+ valid = ", ".join(
53
+ a.name for a in sorted(contract.valid_layer2_aggregators, key=lambda a: a.value)
54
+ )
55
+ raise ValueError(
56
+ f"Layer2Aggregator.{layer2_aggregator.name} is not valid for "
57
+ f"MeasurementType.{source_measurement.measurement_type.name}. "
58
+ f"Valid aggregators: {valid}"
59
+ )
60
+ self.source_measurement = source_measurement
61
+ self.layer2_aggregator = layer2_aggregator
62
+
63
+ @property
64
+ @abstractmethod
65
+ def output_type(self) -> Layer2OutputType:
66
+ """Layer 2 output type that governs valid Layer 3 temporal operators."""
67
+ ...
68
+
69
+ @property
70
+ def output_contract(self) -> AbstractLayer2OutputContract:
71
+ """Contract for the Layer 2 → Layer 3 boundary, derived from ``output_type``."""
72
+ return get_default_output_contract(self.output_type)
73
+
74
+ @property
75
+ @abstractmethod
76
+ def column_name(self) -> str:
77
+ """Deterministic name for this column in the Layer 2 output table."""
78
+ ...
79
+
80
+ def __repr__(self) -> str:
81
+ return (
82
+ f"{type(self).__name__}("
83
+ f"measurement={self.source_measurement.name!r}, "
84
+ f"aggregator={self.layer2_aggregator.name!r}, "
85
+ f"output_type={self.output_type.name!r})"
86
+ )