tsagentkit 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. tsagentkit/__init__.py +126 -0
  2. tsagentkit/anomaly/__init__.py +130 -0
  3. tsagentkit/backtest/__init__.py +48 -0
  4. tsagentkit/backtest/engine.py +788 -0
  5. tsagentkit/backtest/metrics.py +244 -0
  6. tsagentkit/backtest/report.py +342 -0
  7. tsagentkit/calibration/__init__.py +136 -0
  8. tsagentkit/contracts/__init__.py +133 -0
  9. tsagentkit/contracts/errors.py +275 -0
  10. tsagentkit/contracts/results.py +418 -0
  11. tsagentkit/contracts/schema.py +44 -0
  12. tsagentkit/contracts/task_spec.py +300 -0
  13. tsagentkit/covariates/__init__.py +340 -0
  14. tsagentkit/eval/__init__.py +285 -0
  15. tsagentkit/features/__init__.py +20 -0
  16. tsagentkit/features/covariates.py +328 -0
  17. tsagentkit/features/extra/__init__.py +5 -0
  18. tsagentkit/features/extra/native.py +179 -0
  19. tsagentkit/features/factory.py +187 -0
  20. tsagentkit/features/matrix.py +159 -0
  21. tsagentkit/features/tsfeatures_adapter.py +115 -0
  22. tsagentkit/features/versioning.py +203 -0
  23. tsagentkit/hierarchy/__init__.py +39 -0
  24. tsagentkit/hierarchy/aggregation.py +62 -0
  25. tsagentkit/hierarchy/evaluator.py +400 -0
  26. tsagentkit/hierarchy/reconciliation.py +232 -0
  27. tsagentkit/hierarchy/structure.py +453 -0
  28. tsagentkit/models/__init__.py +182 -0
  29. tsagentkit/models/adapters/__init__.py +83 -0
  30. tsagentkit/models/adapters/base.py +321 -0
  31. tsagentkit/models/adapters/chronos.py +387 -0
  32. tsagentkit/models/adapters/moirai.py +256 -0
  33. tsagentkit/models/adapters/registry.py +171 -0
  34. tsagentkit/models/adapters/timesfm.py +440 -0
  35. tsagentkit/models/baselines.py +207 -0
  36. tsagentkit/models/sktime.py +307 -0
  37. tsagentkit/monitoring/__init__.py +51 -0
  38. tsagentkit/monitoring/alerts.py +302 -0
  39. tsagentkit/monitoring/coverage.py +203 -0
  40. tsagentkit/monitoring/drift.py +330 -0
  41. tsagentkit/monitoring/report.py +214 -0
  42. tsagentkit/monitoring/stability.py +275 -0
  43. tsagentkit/monitoring/triggers.py +423 -0
  44. tsagentkit/qa/__init__.py +347 -0
  45. tsagentkit/router/__init__.py +37 -0
  46. tsagentkit/router/bucketing.py +489 -0
  47. tsagentkit/router/fallback.py +132 -0
  48. tsagentkit/router/plan.py +23 -0
  49. tsagentkit/router/router.py +271 -0
  50. tsagentkit/series/__init__.py +26 -0
  51. tsagentkit/series/alignment.py +206 -0
  52. tsagentkit/series/dataset.py +449 -0
  53. tsagentkit/series/sparsity.py +261 -0
  54. tsagentkit/series/validation.py +393 -0
  55. tsagentkit/serving/__init__.py +39 -0
  56. tsagentkit/serving/orchestration.py +943 -0
  57. tsagentkit/serving/packaging.py +73 -0
  58. tsagentkit/serving/provenance.py +317 -0
  59. tsagentkit/serving/tsfm_cache.py +214 -0
  60. tsagentkit/skill/README.md +135 -0
  61. tsagentkit/skill/__init__.py +8 -0
  62. tsagentkit/skill/recipes.md +429 -0
  63. tsagentkit/skill/tool_map.md +21 -0
  64. tsagentkit/time/__init__.py +134 -0
  65. tsagentkit/utils/__init__.py +20 -0
  66. tsagentkit/utils/quantiles.py +83 -0
  67. tsagentkit/utils/signature.py +47 -0
  68. tsagentkit/utils/temporal.py +41 -0
  69. tsagentkit-1.0.2.dist-info/METADATA +371 -0
  70. tsagentkit-1.0.2.dist-info/RECORD +72 -0
  71. tsagentkit-1.0.2.dist-info/WHEEL +4 -0
  72. tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,232 @@
1
+ """Hierarchical forecast reconciliation (adapter to hierarchicalforecast)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from enum import Enum
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+
11
+ from .structure import HierarchyStructure
12
+
13
+
14
class ReconciliationMethod(Enum):
    """Enumerates the supported hierarchical reconciliation strategies."""

    # Structural approaches that push values up or down the tree.
    BOTTOM_UP = "bottom_up"
    TOP_DOWN = "top_down"
    MIDDLE_OUT = "middle_out"
    # Projection-based approaches (variants of trace minimisation).
    OLS = "ols"
    WLS = "wls"
    MIN_TRACE = "min_trace"
23
+
24
+
25
+ def _build_tag_indices(tags: dict[str, np.ndarray], order: list[str]) -> dict[str, np.ndarray]:
26
+ node_to_idx = {node: idx for idx, node in enumerate(order)}
27
+ indexed: dict[str, np.ndarray] = {}
28
+ for key, nodes in tags.items():
29
+ indices = [node_to_idx[n] for n in nodes if n in node_to_idx]
30
+ if indices:
31
+ indexed[key] = np.array(indices, dtype=int)
32
+ return indexed
33
+
34
+
35
+ def _build_s_matrix(structure: HierarchyStructure, order: list[str]) -> np.ndarray:
36
+ s_df = structure.to_s_df()
37
+ s_df = s_df.set_index("unique_id").reindex(order)
38
+ return s_df[structure.bottom_nodes].to_numpy(dtype=float)
39
+
40
+
41
+ def _get_level_key(tags: dict[str, np.ndarray], middle_level: int | str | None) -> str:
42
+ level_keys = [k for k in tags if k.startswith("level_")]
43
+ level_keys = sorted(level_keys, key=lambda k: int(k.split("_")[1]))
44
+ if not level_keys:
45
+ return next(iter(tags), "bottom")
46
+
47
+ if isinstance(middle_level, str) and middle_level in tags:
48
+ return middle_level
49
+ if isinstance(middle_level, int):
50
+ idx = min(max(middle_level, 0), len(level_keys) - 1)
51
+ return level_keys[idx]
52
+
53
+ if len(level_keys) >= 3:
54
+ return level_keys[len(level_keys) // 2]
55
+ if len(level_keys) == 2:
56
+ return level_keys[1]
57
+ return level_keys[0]
58
+
59
+
60
def _select_reconciler(
    method: ReconciliationMethod,
    tags: dict[str, np.ndarray],
    middle_level: int | str | None = None,
    has_insample: bool = False,
) -> Any:
    """Instantiate the hierarchicalforecast reconciler matching *method*.

    MIN_TRACE uses mint_shrink only when in-sample data is available
    (it needs residuals) and otherwise degrades to the OLS variant;
    unrecognised methods also fall back to MinTrace(ols).
    """
    from hierarchicalforecast.methods import BottomUp, MiddleOut, MinTrace, TopDown

    if method is ReconciliationMethod.BOTTOM_UP:
        return BottomUp()
    if method is ReconciliationMethod.TOP_DOWN:
        return TopDown(method="forecast_proportions")
    if method is ReconciliationMethod.MIDDLE_OUT:
        level_key = _get_level_key(tags, middle_level)
        return MiddleOut(
            middle_level=level_key,
            top_down_method="forecast_proportions",
        )
    if method is ReconciliationMethod.OLS:
        return MinTrace(method="ols")
    if method is ReconciliationMethod.WLS:
        return MinTrace(method="wls_struct")
    if method is ReconciliationMethod.MIN_TRACE:
        variant = "mint_shrink" if has_insample else "ols"
        return MinTrace(method=variant)
    return MinTrace(method="ols")
84
+
85
+
86
+ def _apply_reconciler(
87
+ reconciler: Any,
88
+ s_matrix: np.ndarray,
89
+ y_hat: np.ndarray,
90
+ tags: dict[str, np.ndarray],
91
+ y_insample: np.ndarray | None = None,
92
+ y_hat_insample: np.ndarray | None = None,
93
+ ) -> np.ndarray:
94
+ result = reconciler.fit_predict(
95
+ S=s_matrix,
96
+ y_hat=y_hat,
97
+ tags=tags,
98
+ y_insample=y_insample,
99
+ y_hat_insample=y_hat_insample,
100
+ )
101
+ if isinstance(result, dict):
102
+ if "mean" in result:
103
+ return np.asarray(result["mean"])
104
+ # Fallback to first array-like entry
105
+ for value in result.values():
106
+ if isinstance(value, (np.ndarray, list)):
107
+ return np.asarray(value)
108
+ raise ValueError("Unsupported reconciliation output format.")
109
+ return np.asarray(result)
110
+
111
+
112
class Reconciler:
    """Hierarchy-coherent reconciliation engine (adapter).

    Thin wrapper over hierarchicalforecast: reorders inputs into the
    aggregates-first layout the library expects, runs the selected method,
    then restores the caller's original node ordering.
    """

    def __init__(
        self,
        method: ReconciliationMethod,
        structure: HierarchyStructure,
    ) -> None:
        self.method = method
        self.structure = structure

    def reconcile(
        self,
        base_forecasts: np.ndarray,
        fitted_values: np.ndarray | None = None,
        residuals: np.ndarray | None = None,
        **kwargs: Any,
    ) -> np.ndarray:
        """Reconcile base forecasts to be hierarchy-consistent.

        Args:
            base_forecasts: Per-node forecasts, shaped ``(n_nodes,)`` or
                ``(n_nodes, horizon)``, rows in ``structure.all_nodes`` order.
            fitted_values: Optional per-node in-sample fitted values.
            residuals: Optional per-node in-sample residuals; both must be
                supplied to enable insample-aware methods (e.g. mint_shrink).
            **kwargs: Extra options such as ``middle_level`` for MIDDLE_OUT.

        Returns:
            Reconciled forecasts with the same shape as *base_forecasts*.
        """
        forecasts = np.asarray(base_forecasts, dtype=float)
        squeeze = forecasts.ndim == 1
        if squeeze:
            forecasts = forecasts[:, None]

        ordering = self.structure.node_order()
        tags = self.structure.to_tags()
        tag_indices = _build_tag_indices(tags, ordering)
        summation = _build_s_matrix(self.structure, ordering)

        # Permutation from all_nodes order into hf ordering, and its inverse
        # for mapping the result back to the caller's ordering.
        to_hf = [self.structure.all_nodes.index(node) for node in ordering]
        from_hf = np.argsort(to_hf)
        forecasts = forecasts[to_hf]

        insample_fit = None
        insample_actual = None
        if fitted_values is not None and residuals is not None:
            insample_fit = np.asarray(fitted_values, dtype=float)[to_hf]
            insample_actual = insample_fit + np.asarray(residuals, dtype=float)[to_hf]

        engine = _select_reconciler(
            self.method,
            tags,
            middle_level=kwargs.get("middle_level"),
            has_insample=insample_fit is not None,
        )
        result = _apply_reconciler(
            engine,
            summation,
            forecasts,
            tag_indices,
            y_insample=insample_actual,
            y_hat_insample=insample_fit,
        )

        result = result[from_hf]
        return result[:, 0] if squeeze else result
172
+
173
+
174
def reconcile_forecasts(
    base_forecasts: pd.DataFrame,
    structure: HierarchyStructure,
    method: ReconciliationMethod | str = ReconciliationMethod.BOTTOM_UP,
) -> pd.DataFrame:
    """Reconcile forecast DataFrame to ensure hierarchy coherence.

    Args:
        base_forecasts: Long-format frame with ``unique_id``, ``ds`` and one
            or more numeric forecast columns.
        structure: Hierarchy describing the aggregation relationships.
        method: Reconciliation method (enum member or its string value).

    Returns:
        Frame sorted by id and timestamp where each numeric column holds its
        reconciled values.

    Raises:
        ValueError: If no numeric forecast columns are present.
    """
    if isinstance(method, str):
        method = ReconciliationMethod(method)

    frame = base_forecasts.copy()
    id_col, ds_col = "unique_id", "ds"
    if not pd.api.types.is_datetime64_any_dtype(frame[ds_col]):
        frame[ds_col] = pd.to_datetime(frame[ds_col])

    value_cols = [
        col
        for col in frame.columns
        if col not in {id_col, ds_col} and pd.api.types.is_numeric_dtype(frame[col])
    ]
    if not value_cols:
        raise ValueError("No numeric forecast columns to reconcile.")

    from hierarchicalforecast.core import HierarchicalReconciliation

    tags = structure.to_tags()
    s_df = structure.to_s_df()
    reconciler = _select_reconciler(method, tags)
    engine = HierarchicalReconciliation([reconciler])

    reconciled = engine.reconcile(
        Y_hat_df=frame[[id_col, ds_col] + value_cols].copy(),
        S_df=s_df,
        tags=tags,
        Y_df=None,
        id_col=id_col,
        time_col=ds_col,
        target_col="y",
    )

    # hierarchicalforecast appends "<col>/<MethodName>" columns; fold them
    # back into the original column names.
    method_label = reconciler.__class__.__name__
    reconciled = reconciled.copy()
    for col in value_cols:
        preferred = f"{col}/{method_label}"
        if preferred in reconciled.columns:
            reconciled[col] = reconciled[preferred]
            continue
        matches = [c for c in reconciled.columns if c.startswith(f"{col}/")]
        if len(matches) == 1:
            reconciled[col] = reconciled[matches[0]]
    suffixed = [c for c in reconciled.columns if "/" in c and c.split("/")[0] in value_cols]
    if suffixed:
        reconciled = reconciled.drop(columns=suffixed)

    return reconciled.sort_values([id_col, ds_col]).reset_index(drop=True)
230
+
231
+
232
# Public API of this module.
__all__ = ["ReconciliationMethod", "Reconciler", "reconcile_forecasts"]
@@ -0,0 +1,453 @@
1
+ """Hierarchy structure definition for hierarchical time series.
2
+
3
+ Defines the aggregation relationships between time series in a hierarchy
4
+ and provides utilities for validation and navigation.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass, field
10
+ from typing import TYPE_CHECKING
11
+
12
+ import numpy as np
13
+ import pandas as pd
14
+
15
+ if TYPE_CHECKING:
16
+ pass
17
+
18
+
19
@dataclass(frozen=True)
class HierarchyStructure:
    """Defines hierarchical relationships between time series.

    Represents the aggregation structure where bottom-level series
    sum up to higher-level series. The structure is encoded using:
    - aggregation_graph: Parent -> children mapping
    - bottom_nodes: Leaf nodes (no children)
    - s_matrix: Summation matrix (n_total x n_bottom)

    Example structure for retail:
        Total
        ├── Region_North
        │   ├── Store_A
        │   └── Store_B
        └── Region_South
            ├── Store_C
            └── Store_D

    Attributes:
        aggregation_graph: Mapping from parent to list of children
        bottom_nodes: List of bottom-level (leaf) node names
        s_matrix: Summation matrix where S[i,j] = 1 if bottom node j
            contributes to node i
        all_nodes: Sorted list of all nodes (computed in __post_init__)

    Example:
        >>> structure = HierarchyStructure(
        ...     aggregation_graph={
        ...         "Total": ["Region_North", "Region_South"],
        ...         "Region_North": ["Store_A", "Store_B"],
        ...         "Region_South": ["Store_C", "Store_D"],
        ...     },
        ...     bottom_nodes=["Store_A", "Store_B", "Store_C", "Store_D"],
        ...     s_matrix=s_matrix,  # 7x4 matrix
        ... )
    """

    # Mapping from parent to children
    aggregation_graph: dict[str, list[str]]

    # Bottom-level nodes (leaf nodes)
    bottom_nodes: list[str]

    # All nodes in the hierarchy (derived in __post_init__, not an __init__ arg)
    all_nodes: list[str] = field(init=False)

    # Aggregation matrix S, shape (n_total, n_bottom), where
    # S[i, j] = 1 if bottom node j contributes to node i
    s_matrix: np.ndarray = field(repr=False)

    def __post_init__(self) -> None:
        """Validate the hierarchy structure after creation.

        Raises:
            ValueError: If the hierarchy is empty or inconsistent.
        """
        # Check for empty hierarchy first
        if not self.bottom_nodes:
            raise ValueError("bottom_nodes cannot be empty")

        # Compute all nodes first (needed for validation)
        nodes = set(self.bottom_nodes)
        for parent, children in self.aggregation_graph.items():
            nodes.add(parent)
            nodes.update(children)

        # Use object.__setattr__ since dataclass is frozen
        object.__setattr__(self, "all_nodes", sorted(nodes))

        self._validate_structure()

    def _validate_structure(self) -> None:
        """Validate hierarchy structure is consistent.

        Raises:
            ValueError: If structure is invalid
        """
        # Check all children exist (before S matrix validation)
        for parent, children in self.aggregation_graph.items():
            for child in children:
                if child not in self.all_nodes:
                    raise ValueError(
                        f"Child '{child}' of parent '{parent}' not found in hierarchy"
                    )

        # Check S matrix dimensions
        n_total = len(self.all_nodes)
        n_bottom = len(self.bottom_nodes)
        if self.s_matrix.shape != (n_total, n_bottom):
            raise ValueError(
                f"S matrix shape {self.s_matrix.shape} doesn't match "
                f"expected ({n_total}, {n_bottom})"
            )

        # Check bottom nodes have no children
        for node in self.bottom_nodes:
            if node in self.aggregation_graph:
                raise ValueError(
                    f"Bottom node '{node}' cannot have children in aggregation_graph"
                )

        # Check S matrix values are valid (0 or 1)
        if not np.all(np.isin(self.s_matrix, [0, 1])):
            raise ValueError("S matrix must contain only 0s and 1s")

        # Each bottom node's own row must carry a 1 in its own column
        for j, bottom_node in enumerate(self.bottom_nodes):
            bottom_idx = self.all_nodes.index(bottom_node)
            if self.s_matrix[bottom_idx, j] != 1:
                raise ValueError(
                    f"Bottom node '{bottom_node}' must have S[{bottom_idx}, {j}] = 1"
                )

    @classmethod
    def from_dataframe(
        cls,
        df: pd.DataFrame,
        hierarchy_columns: list[str],
        value_column: str = "y",
    ) -> HierarchyStructure:
        """Build hierarchy structure from DataFrame with hierarchical columns.

        Automatically constructs the aggregation graph and summation matrix
        from hierarchical identifiers in the DataFrame. All node identifiers
        are stringified so the graph is consistent even when the hierarchy
        columns hold non-string values.

        Args:
            df: DataFrame with hierarchical identifiers
            hierarchy_columns: Columns defining hierarchy (top to bottom)
            value_column: Retained for API compatibility; not used here

        Returns:
            HierarchyStructure built from the data

        Raises:
            ValueError: If hierarchy_columns is empty

        Example:
            >>> df = pd.DataFrame({
            ...     "country": ["US", "US", "US", "US"],
            ...     "state": ["CA", "CA", "NY", "NY"],
            ...     "city": ["SF", "LA", "NYC", "BUF"],
            ...     "y": [100, 200, 300, 50]
            ... })
            >>> structure = HierarchyStructure.from_dataframe(
            ...     df, ["country", "state", "city"]
            ... )
        """
        if not hierarchy_columns:
            raise ValueError("hierarchy_columns cannot be empty")

        # Get unique combinations
        unique_combos = df[hierarchy_columns].drop_duplicates()

        # Build aggregation graph
        aggregation_graph: dict[str, list[str]] = {}

        for level_idx in range(len(hierarchy_columns) - 1):
            parent_col = hierarchy_columns[level_idx]
            child_col = hierarchy_columns[level_idx + 1]

            # Group by immediate parent column to find children
            for parent_value, group in unique_combos.groupby(parent_col):
                parent_key = str(parent_value)
                # BUG FIX: stringify children as well. Previously only the
                # parent key was stringified, so non-string column values
                # produced mismatched node identities (e.g. "1" vs 1) and a
                # disconnected graph.
                children = [str(c) for c in group[child_col].unique()]

                if parent_key not in aggregation_graph:
                    aggregation_graph[parent_key] = []
                aggregation_graph[parent_key].extend(children)
                aggregation_graph[parent_key] = list(
                    dict.fromkeys(aggregation_graph[parent_key])
                )  # Remove duplicates, preserve order

        # Bottom nodes are the unique values at the lowest level
        # (stringified for consistency with the graph entries)
        bottom_nodes = [str(v) for v in unique_combos[hierarchy_columns[-1]].unique()]

        # Build summation matrix
        all_nodes = _get_all_nodes_from_graph(aggregation_graph, bottom_nodes)
        s_matrix = _build_summation_matrix(
            all_nodes, bottom_nodes, aggregation_graph
        )

        return cls(
            aggregation_graph=aggregation_graph,
            bottom_nodes=bottom_nodes,
            s_matrix=s_matrix,
        )

    @classmethod
    def from_summation_matrix(
        cls,
        s_matrix: np.ndarray,
        node_names: list[str],
        bottom_node_names: list[str],
    ) -> HierarchyStructure:
        """Build hierarchy from explicit summation matrix.

        NOTE: the inferred graph links each aggregate node directly to the
        bottom nodes it covers; intermediate parent/child edges cannot be
        recovered from S alone, so multi-level structure is flattened.

        Args:
            s_matrix: Summation matrix (n_nodes x n_bottom)
            node_names: Names for all nodes (length n_nodes)
            bottom_node_names: Names for bottom nodes (length n_bottom)

        Returns:
            HierarchyStructure

        Raises:
            ValueError: If the name lists don't match the matrix dimensions
        """
        if len(node_names) != s_matrix.shape[0]:
            raise ValueError(
                f"node_names length {len(node_names)} doesn't match "
                f"S matrix rows {s_matrix.shape[0]}"
            )
        if len(bottom_node_names) != s_matrix.shape[1]:
            raise ValueError(
                f"bottom_node_names length {len(bottom_node_names)} doesn't match "
                f"S matrix columns {s_matrix.shape[1]}"
            )

        # Infer aggregation graph from S matrix
        aggregation_graph: dict[str, list[str]] = {}

        for i, node in enumerate(node_names):
            if node in bottom_node_names:
                continue  # Skip bottom nodes

            # Children: bottom nodes whose column carries a 1 in this row
            children = []
            for j, bottom_node in enumerate(bottom_node_names):
                if s_matrix[i, j] == 1 and node != bottom_node:
                    children.append(bottom_node)

            if children:
                aggregation_graph[node] = children

        return cls(
            aggregation_graph=aggregation_graph,
            bottom_nodes=bottom_node_names,
            s_matrix=s_matrix,
        )

    def get_parents(self, node: str) -> list[str]:
        """Get parent nodes of a given node.

        Args:
            node: Node name

        Returns:
            List of parent node names

        Raises:
            ValueError: If node is not in hierarchy
        """
        if node not in self.all_nodes:
            raise ValueError(f"Node '{node}' not in hierarchy")

        return [
            parent
            for parent, children in self.aggregation_graph.items()
            if node in children
        ]

    def get_children(self, node: str) -> list[str]:
        """Get child nodes of a given node.

        Args:
            node: Node name

        Returns:
            List of child node names (a copy; mutating it does not alter
            the hierarchy). Empty for leaves and unknown nodes.
        """
        # BUG FIX: return a copy so callers cannot mutate the internal graph
        # of this frozen dataclass through the returned list.
        return list(self.aggregation_graph.get(node, []))

    def get_level(self, node: str) -> int:
        """Get hierarchy level (0 = root, increasing downward).

        Args:
            node: Node name

        Returns:
            Hierarchy level (0 for root nodes)

        Raises:
            ValueError: If node is not in hierarchy
        """
        if node not in self.all_nodes:
            raise ValueError(f"Node '{node}' not in hierarchy")

        level = 0
        current = node
        parents = self.get_parents(current)

        # Traverse up to find depth
        while parents:
            level += 1
            current = parents[0]  # Use first parent (works for tree structures)
            parents = self.get_parents(current)

        return level

    def get_nodes_at_level(self, level: int) -> list[str]:
        """Get all nodes at a specific hierarchy level.

        Args:
            level: Hierarchy level (0 = root)

        Returns:
            List of node names at that level
        """
        return [node for node in self.all_nodes if self.get_level(node) == level]

    def is_leaf(self, node: str) -> bool:
        """Check if node is a leaf (bottom-level) node.

        Args:
            node: Node name

        Returns:
            True if node is a bottom node
        """
        return node in self.bottom_nodes

    def get_num_levels(self) -> int:
        """Get the number of levels in the hierarchy.

        Returns:
            Maximum level + 1 (since levels start at 0)
        """
        if not self.all_nodes:
            return 0
        return max(self.get_level(node) for node in self.all_nodes) + 1

    def node_order(self) -> list[str]:
        """Return nodes ordered with aggregates first and bottom nodes last."""
        order: list[str] = []
        for level in range(self.get_num_levels()):
            order.extend(self.get_nodes_at_level(level))
        # Use a set for O(1) membership instead of repeated O(n) list scans
        present = set(order)
        bottom = [n for n in self.bottom_nodes if n in present]
        bottom_set = set(bottom)
        order = [n for n in order if n not in bottom_set] + bottom
        return order

    def to_s_df(self, id_col: str = "unique_id") -> pd.DataFrame:
        """Return summation matrix as S_df (rows ordered aggregates -> bottom)."""
        s_df = pd.DataFrame(
            self.s_matrix,
            index=self.all_nodes,
            columns=self.bottom_nodes,
        )
        ordered = self.node_order()
        s_df = s_df.reindex(index=ordered)
        s_df = s_df.reset_index().rename(columns={"index": id_col})
        return s_df

    def to_tags(self, level_prefix: str = "level_") -> dict[str, np.ndarray]:
        """Return hierarchical tags mapping level -> node names."""
        tags: dict[str, np.ndarray] = {}
        order = self.node_order()
        for level in range(self.get_num_levels()):
            level_nodes = [n for n in order if self.get_level(n) == level]
            if level_nodes:
                tags[f"{level_prefix}{level}"] = np.array(level_nodes, dtype=object)
        return tags
375
+
376
+
377
+ def _get_all_nodes_from_graph(
378
+ aggregation_graph: dict[str, list[str]],
379
+ bottom_nodes: list[str],
380
+ ) -> list[str]:
381
+ """Get all nodes from aggregation graph and bottom nodes."""
382
+ nodes = set(bottom_nodes)
383
+ for parent, children in aggregation_graph.items():
384
+ nodes.add(parent)
385
+ nodes.update(children)
386
+ return sorted(nodes)
387
+
388
+
389
+ def _build_summation_matrix(
390
+ all_nodes: list[str],
391
+ bottom_nodes: list[str],
392
+ aggregation_graph: dict[str, list[str]],
393
+ ) -> np.ndarray:
394
+ """Build summation matrix from hierarchy definition.
395
+
396
+ The S matrix encodes which bottom nodes contribute to each node.
397
+ S[i, j] = 1 if bottom node j contributes to node i.
398
+ """
399
+ n_total = len(all_nodes)
400
+ n_bottom = len(bottom_nodes)
401
+
402
+ s_matrix = np.zeros((n_total, n_bottom), dtype=int)
403
+
404
+ # Map node names to indices
405
+ node_to_idx = {node: i for i, node in enumerate(all_nodes)}
406
+ bottom_to_idx = {node: j for j, node in enumerate(bottom_nodes)}
407
+
408
+ # For each bottom node, determine all ancestors
409
+ for bottom_node in bottom_nodes:
410
+ bottom_idx = node_to_idx[bottom_node]
411
+ j = bottom_to_idx[bottom_node]
412
+
413
+ # Bottom node contributes to itself
414
+ s_matrix[bottom_idx, j] = 1
415
+
416
+ # Trace up the hierarchy
417
+ _add_ancestor_contributions(
418
+ bottom_node,
419
+ j,
420
+ s_matrix,
421
+ node_to_idx,
422
+ aggregation_graph,
423
+ )
424
+
425
+ return s_matrix
426
+
427
+
428
+ def _add_ancestor_contributions(
429
+ node: str,
430
+ bottom_idx: int,
431
+ s_matrix: np.ndarray,
432
+ node_to_idx: dict[str, int],
433
+ aggregation_graph: dict[str, list[str]],
434
+ ) -> None:
435
+ """Recursively add contributions for all ancestors of a node."""
436
+ # Find all parents of this node
437
+ parents = [
438
+ parent
439
+ for parent, children in aggregation_graph.items()
440
+ if node in children
441
+ ]
442
+
443
+ for parent in parents:
444
+ parent_idx = node_to_idx[parent]
445
+ s_matrix[parent_idx, bottom_idx] = 1
446
+ # Recurse up the hierarchy
447
+ _add_ancestor_contributions(
448
+ parent,
449
+ bottom_idx,
450
+ s_matrix,
451
+ node_to_idx,
452
+ aggregation_graph,
453
+ )