featkit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. featkit/__init__.py +1 -0
  2. featkit/builders/.gitkeep +0 -0
  3. featkit/builders/__init__.py +0 -0
  4. featkit/builders/distributional_space.py +77 -0
  5. featkit/builders/pivot_space.py +102 -0
  6. featkit/builders/temporal_space.py +86 -0
  7. featkit/config.py +38 -0
  8. featkit/contracts/__init__.py +1 -0
  9. featkit/contracts/measurement/.gitkeep +0 -0
  10. featkit/contracts/measurement/__init__.py +27 -0
  11. featkit/contracts/measurement/base.py +47 -0
  12. featkit/contracts/measurement/defaults.py +117 -0
  13. featkit/contracts/output/.gitkeep +0 -0
  14. featkit/contracts/output/__init__.py +19 -0
  15. featkit/contracts/output/base.py +36 -0
  16. featkit/contracts/output/defaults.py +80 -0
  17. featkit/dataset/.gitkeep +0 -0
  18. featkit/dataset/__init__.py +0 -0
  19. featkit/dataset/base.py +120 -0
  20. featkit/enums.py +110 -0
  21. featkit/fields/.gitkeep +0 -0
  22. featkit/fields/__init__.py +9 -0
  23. featkit/fields/base.py +48 -0
  24. featkit/fields/categorical_field.py +55 -0
  25. featkit/fields/id_field.py +14 -0
  26. featkit/fields/measurement_field.py +42 -0
  27. featkit/fields/time_field.py +43 -0
  28. featkit/generators/__init__.py +0 -0
  29. featkit/generators/base.py +171 -0
  30. featkit/generators/output.py +118 -0
  31. featkit/generators/pyspark/.gitkeep +0 -0
  32. featkit/generators/pyspark/__init__.py +0 -0
  33. featkit/generators/pyspark/databricks.py +448 -0
  34. featkit/generators/sql/.gitkeep +0 -0
  35. featkit/generators/sql/__init__.py +0 -0
  36. featkit/generators/sql/base.py +496 -0
  37. featkit/generators/sql/databricks.py +19 -0
  38. featkit/generators/sql/snowflake.py +19 -0
  39. featkit/generators/sql/spark_sql.py +19 -0
  40. featkit/layer2/.gitkeep +0 -0
  41. featkit/layer2/__init__.py +0 -0
  42. featkit/layer2/base.py +86 -0
  43. featkit/layer2/distributional.py +51 -0
  44. featkit/layer2/pivoted.py +63 -0
  45. featkit/layer3/.gitkeep +0 -0
  46. featkit/layer3/__init__.py +0 -0
  47. featkit/layer3/temporal_feature.py +87 -0
  48. featkit/pipeline.py +63 -0
  49. featkit-0.1.0.dist-info/METADATA +140 -0
  50. featkit-0.1.0.dist-info/RECORD +52 -0
  51. featkit-0.1.0.dist-info/WHEEL +4 -0
  52. featkit-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,118 @@
1
+ """Output containers and DAG model for generated feature-store code."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from dataclasses import dataclass, field
7
+ from pathlib import Path
8
+
9
+
10
@dataclass
class DAGNode:
    """One step of the feature-store execution DAG.

    Attributes:
        step_name: Identifier of the pipeline step this node stands for.
        depends_on: Names of the steps that have to finish before this
            one may run. Each instance gets its own fresh list.
    """

    step_name: str
    # ``default_factory`` keeps instances from sharing one mutable list.
    depends_on: list[str] = field(default_factory=list)
21
+
22
+
23
@dataclass
class DAG:
    """Dependency graph of pipeline steps.

    Attributes:
        nodes: Every node of the graph, expected in topological order
            (sources first). The order is emitted as given; it is not
            checked or re-sorted here.
    """

    nodes: list[DAGNode]

    def to_json(self) -> str:
        """Return the DAG as a JSON array string, indented by two spaces.

        Each node becomes an object with the keys ``step_name`` and
        ``depends_on``, in the same order as ``self.nodes``.
        """
        payload = []
        for node in self.nodes:
            payload.append(
                {"step_name": node.step_name, "depends_on": node.depends_on}
            )
        return json.dumps(payload, indent=2)
39
+
40
+
41
@dataclass
class SQLOutput:
    """Holder for a generated SQL script.

    Attributes:
        sql: Full text of the SQL script.
        dialect: Name of the SQL dialect, e.g. ``"snowflake"`` or
            ``"databricks"``.
    """

    sql: str
    dialect: str

    def save(self, path: str) -> None:
        """Store the SQL text at *path* (UTF-8), creating missing parent folders.

        An existing file at *path* is overwritten.
        """
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        with target.open("w", encoding="utf-8") as handle:
            handle.write(self.sql)
58
+
59
+
60
@dataclass
class PySparkOutput:
    """Placeholder holder for a generated PySpark script.

    The script is kept as a plain string in ``code``. The original note
    says a later plan ("Plan 16") will swap the string payload for a lazy
    DataFrame chain.

    Attributes:
        code: Text of the PySpark script; defaults to the empty string,
            which is what stub generators produce.
    """

    code: str = ""

    def save(self, path: str) -> None:
        """Store the script text at *path* (UTF-8), creating missing parent folders.

        An existing file at *path* is overwritten.
        """
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        with target.open("w", encoding="utf-8") as handle:
            handle.write(self.code)
78
+
79
+
80
#: Alias for whichever code container a generator returns: SQL or PySpark.
CodeOutput = SQLOutput | PySparkOutput
82
+
83
+
84
@dataclass
class FeatureStoreOutput:
    """Everything one code-generation run produces.

    Attributes:
        code: The generated script, either SQL or PySpark.
        dag: Execution DAG of the pipeline steps.
        mermaid: Mermaid flowchart source describing the DAG.
    """

    code: CodeOutput
    dag: DAG
    mermaid: str

    def save(self, directory: str) -> None:
        """Write every artefact into *directory*, creating it if needed.

        Files written:

        * ``script.sql`` when ``code`` is a :class:`SQLOutput`, otherwise
          ``script.py``
        * ``dag.json`` holding the JSON form of the DAG
        * ``diagram.md`` holding the Mermaid source inside a fenced
          ``mermaid`` code block
        """
        out_dir = Path(directory)
        out_dir.mkdir(parents=True, exist_ok=True)

        # Anything that is not SQL output is saved as a Python script.
        script_name = "script.sql" if isinstance(self.code, SQLOutput) else "script.py"
        self.code.save(str(out_dir / script_name))

        dag_file = out_dir / "dag.json"
        dag_file.write_text(self.dag.to_json(), encoding="utf-8")

        diagram_file = out_dir / "diagram.md"
        diagram_file.write_text(f"```mermaid\n{self.mermaid}\n```\n", encoding="utf-8")
File without changes
File without changes
@@ -0,0 +1,448 @@
1
+ """PySparkCodeGenerator — PySpark DataFrame code generator for Databricks."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections import defaultdict
6
+ from typing import TYPE_CHECKING
7
+
8
+ from featkit.enums import DistributionalMetric, TemporalOperator, TimeWindowDirection
9
+ from featkit.generators.base import AbstractCodeGenerator
10
+ from featkit.generators.output import CodeOutput, FeatureStoreOutput, PySparkOutput
11
+
12
+ if TYPE_CHECKING:
13
+ from featkit.layer2.distributional import DistributionalColumn
14
+ from featkit.layer3.temporal_feature import TemporalFeature
15
+ from featkit.pipeline import FeatureStorePipeline
16
+
17
+ # ---------------------------------------------------------------------------
18
+ # Code-generation header — injected at the top of each generated snippet
19
+ # ---------------------------------------------------------------------------
20
+
21
# Import preamble placed once at the top of the generated PySpark script.
_HEADER = (
    "from pyspark.sql import SparkSession\n"
    "from pyspark.sql import functions as F\n"
    "from pyspark.sql.window import Window\n"
)
26
+
27
+
28
class PySparkCodeGenerator(AbstractCodeGenerator):
    """Code generator that emits a PySpark Python script.

    Each ``build_*`` method returns a :class:`~featkit.generators.output.PySparkOutput`
    whose ``code`` attribute contains one self-contained Python snippet.  The
    snippets are concatenated by the inherited ``generate()`` orchestrator into
    a single script that, when executed against a live ``SparkSession``, builds
    the full feature table using lazy DataFrame transformations.  The script
    refers to a ``spark`` variable that it does not create itself.

    MOB reference table structure (period self-join, mirrors the SQL generator):

    * ``_periodos_ordenados`` — distinct periods with a sequential ``mob`` number
      assigned via ``Window.orderBy(time_col)``.
    * ``crossJoin`` of the ordered periods with itself produces all
      ``(ts_analysis, ts_relative)`` pairs.
    * ``mob = mob_b - mob_a`` gives the signed offset (0 = current, negative = past).

    Layer 3 features are derived by joining the MOB table with Layer 2A/B on
    ``ts_relative == ts``, then using ``groupBy(ids + [ts_analysis])`` with
    ``F.when(F.col("mob").between(lo, hi), col)`` aggregates — no window
    functions required.

    All generated transformations are lazy — no ``.collect()`` or ``.show()``
    calls are emitted.
    """

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _tbl(pipeline: FeatureStorePipeline, suffix: str) -> str:
        """Return a fully-qualified intermediate table name string.

        Args:
            pipeline: Pipeline whose ``config`` supplies ``output_schema``
                and ``output_table_prefix``.
            suffix: Step-specific suffix appended directly to the prefix.

        Returns:
            ``"<output_schema>.<output_table_prefix><suffix>"``.
        """
        cfg = pipeline.config
        return f"{cfg.output_schema}.{cfg.output_table_prefix}{suffix}"

    @staticmethod
    def _spark_read(table: str) -> str:
        """Return a PySpark expression to read *table* as a DataFrame."""
        return f'spark.read.table("{table}")'

    @staticmethod
    def _quoted_cols(names: list[str]) -> str:
        """Render *names* as a comma-separated list of double-quoted literals."""
        return ", ".join(f'"{name}"' for name in names)

    @staticmethod
    def _str_literal(value: object) -> str:
        """Render *value* as a double-quoted Python string literal.

        Backslashes, double quotes and line breaks are escaped so that the
        value cannot terminate the literal early in the generated script.
        For values without such characters the result is exactly
        ``"<value>"``.
        """
        text = f"{value}"
        # The backslash must be handled first so later escapes are not doubled.
        for raw, escaped in (("\\", "\\\\"), ('"', '\\"'), ("\n", "\\n"), ("\r", "\\r")):
            text = text.replace(raw, escaped)
        return f'"{text}"'

    # ------------------------------------------------------------------
    # build_mob_table
    # ------------------------------------------------------------------

    def build_mob_table(self, pipeline: FeatureStorePipeline) -> CodeOutput:
        """Generate the MOB reference DataFrame via period self-join.

        Assigns a sequential ``mob`` number to each distinct period using
        ``Window.orderBy(time_col)``, then cross-joins the ordered periods
        with themselves to produce all ``(ts_analysis, ts_relative)`` pairs.
        ``mob = mob_b - mob_a`` (0 = current period, negative = past).

        Args:
            pipeline: Pipeline providing the dataset (time field name and
                source reference) and the output table naming.

        Returns:
            A :class:`PySparkOutput` whose snippet builds ``mob_ref`` and
            saves it (overwrite mode) as the ``mob_ref`` intermediate table.
        """
        ds = pipeline.config.dataset
        time_col = ds.time_field.name
        src = ds.source_reference
        tbl = self._tbl(pipeline, "mob_ref")

        time_analysis = f"{time_col}_analysis"
        time_relative = f"{time_col}_relative"

        code = (
            f"# --- MOB reference table ---\n"
            f"_facts = spark.read.table({src!r})\n"
            f'_periods = _facts.select("{time_col}").distinct()\n'
            f'_mob_win = Window.orderBy("{time_col}")\n'
            f'_periodos_ordenados = _periods.withColumn("mob", F.row_number().over(_mob_win))\n'
            f"_a = _periodos_ordenados.select(\n"
            f' F.col("{time_col}").alias("{time_analysis}"),\n'
            f' F.col("mob").alias("mob_a"),\n'
            f")\n"
            f"_b = _periodos_ordenados.select(\n"
            f' F.col("{time_col}").alias("{time_relative}"),\n'
            f' F.col("mob").alias("mob_b"),\n'
            f")\n"
            f"mob_ref = _a.crossJoin(_b).withColumn(\n"
            f' "mob", F.col("mob_b") - F.col("mob_a")\n'
            f').drop("mob_a", "mob_b")\n'
            f"mob_ref.write.mode('overwrite').saveAsTable({tbl!r})\n"
        )
        return PySparkOutput(code=code)

    # ------------------------------------------------------------------
    # build_layer2a
    # ------------------------------------------------------------------

    def build_layer2a(self, pipeline: FeatureStorePipeline) -> CodeOutput:
        """Generate the Layer 2A pivot aggregation DataFrame.

        One aggregate expression is emitted per entry of ``pipeline.layer2a``.
        When the entry's categorical combination has non-``None`` values, the
        measurement is wrapped in ``F.when(<all conditions>, ...)`` so that
        only matching rows contribute.  With no entries at all, the snippet
        just keeps the distinct ``ids + time`` grain.

        Args:
            pipeline: Pipeline providing the dataset, the Layer 2A column
                definitions and the output table naming.

        Returns:
            A :class:`PySparkOutput` whose snippet builds ``layer2a`` and
            saves it (overwrite mode) as the ``layer2a`` intermediate table.
        """
        ds = pipeline.config.dataset
        id_cols = [f.name for f in ds.id_fields]
        time_col = ds.time_field.name
        src = ds.source_reference
        tbl = self._tbl(pipeline, "layer2a")

        group_cols = self._quoted_cols(id_cols + [time_col])

        agg_exprs: list[str] = []
        for col in pipeline.layer2a:
            meas = col.source_measurement.name
            agg = col.layer2_aggregator.value.lower()  # sum, max, min, avg, count
            alias = col.column_name

            # Sort by field name so the generated predicate is deterministic.
            # Values go through _str_literal so quotes/backslashes inside a
            # category value cannot break the generated script.
            conditions = [
                f'(F.col("{cat_field.name}") == {self._str_literal(cat_val)})'
                for cat_field, cat_val in sorted(
                    col.categorical_combination.items(), key=lambda kv: kv[0].name
                )
                if cat_val is not None
            ]

            if conditions:
                predicate = " & ".join(conditions)
                inner = f'F.when({predicate}, F.col("{meas}"))'
            else:
                inner = f'F.col("{meas}")'

            agg_exprs.append(f' F.{agg}({inner}).alias("{alias}")')

        if agg_exprs:
            agg_list = ",\n".join(agg_exprs)
            code = (
                f"# --- Layer 2A: pivot aggregations ---\n"
                f"_l2a_facts = spark.read.table({src!r})\n"
                f"layer2a = _l2a_facts.groupBy({group_cols}).agg(\n"
                f"{agg_list},\n"
                f")\n"
                f"layer2a.write.mode('overwrite').saveAsTable({tbl!r})\n"
            )
        else:
            code = (
                f"# --- Layer 2A: no pivot aggregations; preserve grouping grain ---\n"
                f"_l2a_facts = spark.read.table({src!r})\n"
                f"layer2a = _l2a_facts.select({group_cols}).distinct()\n"
                f"layer2a.write.mode('overwrite').saveAsTable({tbl!r})\n"
            )
        return PySparkOutput(code=code)

    # ------------------------------------------------------------------
    # build_layer2b
    # ------------------------------------------------------------------

    def build_layer2b(self, pipeline: FeatureStorePipeline) -> CodeOutput:
        """Generate the Layer 2B distributional metrics DataFrame.

        Builds one sub-DataFrame per (categorical, measurement, aggregator)
        group, computes the requested distributional metrics via PySpark
        aggregate functions, then left-joins all sub-DataFrames onto the
        distinct ``ids + time`` base.

        Args:
            pipeline: Pipeline providing the dataset, the Layer 2B column
                definitions and the output table naming.

        Returns:
            A :class:`PySparkOutput`.  When ``pipeline.layer2b`` is empty the
            snippet is a single comment line and no table is written.
        """
        if not pipeline.layer2b:
            return PySparkOutput(code="# --- Layer 2B: no distributional columns ---\n")

        ds = pipeline.config.dataset
        id_cols = [f.name for f in ds.id_fields]
        time_col = ds.time_field.name
        src = ds.source_reference
        tbl = self._tbl(pipeline, "layer2b")

        # The same ids + time key list drives the base select, the per-group
        # roll-up, the share window and the final joins.
        key_cols = self._quoted_cols(id_cols + [time_col])
        shares_window = f"Window.partitionBy({key_cols})"

        # Columns sharing (categorical, measurement, aggregator) reuse one
        # shares DataFrame; dict insertion order fixes the _dist_<i> numbering.
        groups: dict[tuple[str, str, str], list[DistributionalColumn]] = defaultdict(list)
        for col in pipeline.layer2b:
            key = (col.categorical.name, col.source_measurement.name, col.layer2_aggregator.value)
            groups[key].append(col)

        snippets: list[str] = [
            f"# --- Layer 2B: distributional metrics ---\n"
            f"_l2b_facts = spark.read.table({src!r})\n"
            f"_base = _l2b_facts.select({key_cols}).distinct()\n"
        ]

        df_names: list[str] = []
        for i, ((cat_name, meas_name, agg_name), cols) in enumerate(groups.items()):
            agg_fn = agg_name.lower()
            df_var = f"_dist_{i}"
            df_names.append(df_var)

            cat_group = self._quoted_cols(id_cols + [time_col, cat_name])
            # These two already carry their surrounding double quotes.
            agg_col = f'"{cat_name}_{meas_name}_{agg_fn}_cat_val"'
            total_col = f'"{cat_name}_{meas_name}_{agg_fn}_total_val"'

            metric_list = ",\n".join(
                " "
                + self._distributional_agg_expr(
                    col.distributional_metric, cat_name, agg_col, total_col, col.column_name
                )
                for col in cols
            )

            snippets.append(
                f"_shares_{i} = (\n"
                f" _l2b_facts\n"
                f" .groupBy({cat_group})\n"
                f" .agg(F.{agg_fn}({meas_name!r}).alias({agg_col}))\n"
                f" .withColumn(\n"
                f" {total_col},\n"
                f" F.sum(F.col({agg_col})).over({shares_window}),\n"
                f" )\n"
                f")\n"
                f"{df_var} = _shares_{i}.groupBy({key_cols}).agg(\n"
                f"{metric_list},\n"
                f")\n"
            )

        # Join all metric DataFrames onto base
        join_lines = "_layer2b = _base\n" + "".join(
            f'_layer2b = _layer2b.join({df_var}, on=[{key_cols}], how="left")\n'
            for df_var in df_names
        )

        snippets.append(join_lines)
        snippets.append(
            f"layer2b = _layer2b\nlayer2b.write.mode('overwrite').saveAsTable({tbl!r})\n"
        )

        return PySparkOutput(code="\n".join(snippets))

    def _distributional_agg_expr(
        self,
        metric: DistributionalMetric,
        cat_col: str,
        cat_val_col: str,
        total_val_col: str,
        alias: str,
    ) -> str:
        """Return a PySpark agg() expression string for one distributional metric.

        Args:
            metric: Which distributional metric to emit.
            cat_col: Bare name of the categorical column (quoted here).
            cat_val_col: Already-quoted name of the per-category value column.
            total_val_col: Already-quoted name of the per-group total column.
            alias: Output column name for the metric.

        Returns:
            Source text of one aggregate expression ending in ``.alias(...)``.

        Raises:
            ValueError: If *metric* is not one of the handled metrics.
        """
        cv = f"F.col({cat_val_col})"
        tv = f"F.col({total_val_col})"
        # Share of the category in the group total; 0.0 when the total is 0
        # so the generated code never divides by zero.
        share = f"F.when({tv} != 0, {cv} / {tv}).otherwise(F.lit(0.0))"

        if metric == DistributionalMetric.ENTROPY:
            # Categories with a non-positive value contribute 0 to the sum.
            return (
                f"F.sum(F.when({cv} > 0, -{share} * F.log({share}))"
                f'.otherwise(F.lit(0.0))).alias("{alias}")'
            )
        if metric == DistributionalMetric.HHI:
            return f'F.sum(F.pow({share}, F.lit(2))).alias("{alias}")'
        if metric == DistributionalMetric.DOMINANT_PROPORTION:
            return f'F.max({share}).alias("{alias}")'
        if metric == DistributionalMetric.MODE:
            return f'F.max_by(F.col("{cat_col}"), {cv}).alias("{alias}")'
        if metric == DistributionalMetric.COUNT:
            return f'F.count(F.when({cv} > 0, F.lit(1))).alias("{alias}")'
        raise ValueError(f"Unsupported distributional metric: {metric}")

    # ------------------------------------------------------------------
    # build_layer3
    # ------------------------------------------------------------------

    def build_layer3(self, pipeline: FeatureStorePipeline) -> CodeOutput:
        """Generate the Layer 3 temporal features DataFrame.

        Joins the MOB reference table (``ts_analysis``, ``ts_relative``, ``mob``)
        with Layer 2A on ``mob.ts_relative == l2a.ts``, then uses
        ``groupBy(ids + [ts_analysis])`` with ``F.when(mob.between(lo, hi), col)``
        aggregates to implement each temporal operator — no window functions.
        Layer 2B is left-joined in as well when the pipeline defines any
        Layer 2B column.  The analysis-time column is renamed back to the
        dataset's time column before saving.

        Args:
            pipeline: Pipeline providing the dataset, the Layer 3 feature
                definitions and the output table naming.

        Returns:
            A :class:`PySparkOutput` whose snippet builds ``layer3`` and
            saves it (overwrite mode) as the ``layer3`` intermediate table.

        Raises:
            ValueError: Propagated from :meth:`_temporal_pyspark_expr` when a
                windowed feature has no ``window_size``.
        """
        ds = pipeline.config.dataset
        id_cols = [f.name for f in ds.id_fields]
        time_col = ds.time_field.name
        tbl = self._tbl(pipeline, "layer3")
        mob_tbl = self._tbl(pipeline, "mob_ref")
        l2a_tbl = self._tbl(pipeline, "layer2a")
        l2b_tbl = self._tbl(pipeline, "layer2b")

        time_analysis = f"{time_col}_analysis"
        time_relative = f"{time_col}_relative"

        id_time_group = self._quoted_cols(id_cols + [time_analysis])
        id_time_join = self._quoted_cols(id_cols + [time_col])

        feat_exprs = [f" {self._temporal_pyspark_expr(feat)}" for feat in pipeline.layer3]

        l2b_join = (
            f'\nl3_df = l3_df.join(spark.read.table({l2b_tbl!r}), on=[{id_time_join}], how="left")'
            if pipeline.layer2b
            else ""
        )

        if feat_exprs:
            feat_list = ",\n".join(feat_exprs)
            agg_block = (
                f"layer3 = l3_df.groupBy({id_time_group}).agg(\n"
                f"{feat_list},\n"
                f').withColumnRenamed("{time_analysis}", "{time_col}")\n'
            )
        else:
            # No temporal features: keep only the distinct ids + analysis grain.
            agg_block = (
                f"layer3 = l3_df.select({id_time_group}).distinct()"
                f'.withColumnRenamed("{time_analysis}", "{time_col}")\n'
            )

        code = (
            f"# --- Layer 3: temporal features ---\n"
            f"_mob = spark.read.table({mob_tbl!r})\n"
            f"_l2a = spark.read.table({l2a_tbl!r})\n"
            f'l3_df = _mob.join(_l2a, _mob["{time_relative}"] == _l2a["{time_col}"], "inner")'
            f"{l2b_join}\n" + agg_block + f"layer3.write.mode('overwrite').saveAsTable({tbl!r})\n"
        )
        return PySparkOutput(code=code)

    def _temporal_pyspark_expr(self, feat: TemporalFeature) -> str:
        """Return a PySpark groupBy-compatible agg expression for one temporal feature.

        Produces ``F.agg(F.when(F.col("mob").between(lo, hi), col))`` style
        expressions so that they can be passed directly to ``.agg()`` after
        a ``groupBy(ids + [ts_analysis])`` call.

        Operators fall into three families:

        * point-in-time (``ULT_MES``, ``PREV_MES``, ``CREC``, ``REC``) — do not
          use the window;
        * windowed (``PROM_*``, ``SUM_*``, ``MIN_U``, ``MAX_U``, ``RATIO``,
          ``FREQ``, ``XM``, ``MEDIA_ABS``) — require ``feat.window_size``;
        * anything else — falls back to the current-period value.

        Args:
            feat: Temporal feature; ``source.column_name``, ``column_name``,
                ``operator``, ``window_size`` and ``direction`` are read.

        Returns:
            Source text of one aggregate expression ending in ``.alias(...)``.

        Raises:
            ValueError: If a windowed operator is requested but
                ``feat.window_size`` is ``None``.
        """
        col_ref = f'F.col("{feat.source.column_name}")'
        alias = feat.column_name
        op = feat.operator
        w = feat.window_size
        bwd = feat.direction == TimeWindowDirection.BACKWARD
        mob = 'F.col("mob")'

        def at_offset(offset: int) -> str:
            """Value of the source column at exactly ``mob == offset``."""
            return f"F.max(F.when({mob} == {offset}, {col_ref}))"

        # --- Point-in-time operators: independent of window_size ----------
        if op == TemporalOperator.ULT_MES:
            return f'{at_offset(0)}.alias("{alias}")'
        if op == TemporalOperator.PREV_MES:
            return f'{at_offset(-1)}.alias("{alias}")'
        if op == TemporalOperator.CREC:
            curr = at_offset(0)
            prev = at_offset(-1)
            # A zero previous value becomes NULL so the ratio is NULL
            # instead of a division by zero.
            return (
                f"({curr} / F.when({prev} != 0, {prev}).otherwise(F.lit(None))"
                f' - F.lit(1)).alias("{alias}")'
            )
        if op == TemporalOperator.REC:
            # NOTE(review): the max runs over every mob value in the group,
            # including positive (future) offsets — confirm that is intended.
            return f'(-F.max(F.when({col_ref}.isNotNull(), {mob}))).alias("{alias}")'

        # --- Windowed operators -------------------------------------------
        if op in (TemporalOperator.PROM_U, TemporalOperator.PROM_P):
            agg_fn = "F.avg"
        elif op in (TemporalOperator.SUM_U, TemporalOperator.SUM_P, TemporalOperator.RATIO):
            agg_fn = "F.sum"
        elif op == TemporalOperator.MIN_U:
            agg_fn = "F.min"
        elif op == TemporalOperator.MAX_U:
            agg_fn = "F.max"
        else:
            agg_fn = None
        counts_rows = op in (TemporalOperator.FREQ, TemporalOperator.XM)
        is_median = op == TemporalOperator.MEDIA_ABS

        if agg_fn is None and not counts_rows and not is_median:
            # Unrecognised operator: fall back to the current-period value.
            return f'{at_offset(0)}.alias("{alias}")'

        if w is None:
            # Without a size the window bounds cannot be built; fail with a
            # clear message instead of an UnboundLocalError further down.
            raise ValueError(
                f"Temporal operator {op} requires a window_size, "
                f"but feature {alias!r} has none."
            )

        # Backward windows cover [-(w - 1), 0]; other directions cover [0, w - 1].
        lo, hi = (-(w - 1), 0) if bwd else (0, w - 1)
        in_window = f"{mob}.between({lo}, {hi})"
        case_col = f"F.when({in_window}, {col_ref})"

        if agg_fn is not None:
            return f'{agg_fn}({case_col}).alias("{alias}")'
        if counts_rows:
            return (
                f'F.count(F.when({in_window} & {col_ref}.isNotNull(), F.lit(1))).alias("{alias}")'
            )
        return f'F.percentile_approx({case_col}, F.lit(0.5)).alias("{alias}")'

    # ------------------------------------------------------------------
    # build_final_join
    # ------------------------------------------------------------------

    def build_final_join(self, pipeline: FeatureStorePipeline) -> CodeOutput:
        """Generate the final feature table by joining Layer 2 and Layer 3.

        Starts from the Layer 2A table and left-joins Layer 2B and Layer 3 on
        ``ids + time`` — each only when the pipeline defines columns for it.

        Args:
            pipeline: Pipeline providing the dataset, the layer definitions
                and the output table naming.

        Returns:
            A :class:`PySparkOutput` whose snippet builds ``final_df`` and
            saves it (overwrite mode) as the ``features`` table.
        """
        ds = pipeline.config.dataset
        id_cols = [f.name for f in ds.id_fields]
        tbl = self._tbl(pipeline, "features")
        l2a_tbl = self._tbl(pipeline, "layer2a")
        l2b_tbl = self._tbl(pipeline, "layer2b")
        l3_tbl = self._tbl(pipeline, "layer3")

        id_time_join = self._quoted_cols(id_cols + [ds.time_field.name])

        l2b_join = (
            f"\nfinal_df = final_df.join("
            f'spark.read.table({l2b_tbl!r}), on=[{id_time_join}], how="left")'
            if pipeline.layer2b
            else ""
        )
        l3_join = (
            f"\nfinal_df = final_df.join("
            f'spark.read.table({l3_tbl!r}), on=[{id_time_join}], how="left")'
            if pipeline.layer3
            else ""
        )

        code = (
            f"# --- Final join ---\n"
            f"final_df = spark.read.table({l2a_tbl!r})"
            f"{l2b_join}"
            f"{l3_join}\n"
            f"final_df.write.mode('overwrite').saveAsTable({tbl!r})\n"
        )
        return PySparkOutput(code=code)

    # ------------------------------------------------------------------
    # Override generate() to emit the header once
    # ------------------------------------------------------------------

    def generate(self, pipeline: FeatureStorePipeline) -> FeatureStoreOutput:
        """Orchestrate all build steps and prepend the PySpark import header.

        Delegates to the inherited ``generate()`` and then returns a new
        :class:`FeatureStoreOutput` whose script starts with ``_HEADER``
        followed by a blank line; the DAG and Mermaid text are passed through.

        Args:
            pipeline: Pipeline to generate code for.

        Returns:
            The complete output with the import header prepended once.
        """
        result = super().generate(pipeline)
        # Narrow the types for static checkers; the parent is expected to
        # return a FeatureStoreOutput wrapping a PySparkOutput.
        assert isinstance(result, FeatureStoreOutput)
        assert isinstance(result.code, PySparkOutput)
        full_code = _HEADER + "\n" + result.code.code
        return FeatureStoreOutput(
            code=PySparkOutput(code=full_code),
            dag=result.dag,
            mermaid=result.mermaid,
        )
File without changes
File without changes