tsagentkit 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tsagentkit/__init__.py +126 -0
- tsagentkit/anomaly/__init__.py +130 -0
- tsagentkit/backtest/__init__.py +48 -0
- tsagentkit/backtest/engine.py +788 -0
- tsagentkit/backtest/metrics.py +244 -0
- tsagentkit/backtest/report.py +342 -0
- tsagentkit/calibration/__init__.py +136 -0
- tsagentkit/contracts/__init__.py +133 -0
- tsagentkit/contracts/errors.py +275 -0
- tsagentkit/contracts/results.py +418 -0
- tsagentkit/contracts/schema.py +44 -0
- tsagentkit/contracts/task_spec.py +300 -0
- tsagentkit/covariates/__init__.py +340 -0
- tsagentkit/eval/__init__.py +285 -0
- tsagentkit/features/__init__.py +20 -0
- tsagentkit/features/covariates.py +328 -0
- tsagentkit/features/extra/__init__.py +5 -0
- tsagentkit/features/extra/native.py +179 -0
- tsagentkit/features/factory.py +187 -0
- tsagentkit/features/matrix.py +159 -0
- tsagentkit/features/tsfeatures_adapter.py +115 -0
- tsagentkit/features/versioning.py +203 -0
- tsagentkit/hierarchy/__init__.py +39 -0
- tsagentkit/hierarchy/aggregation.py +62 -0
- tsagentkit/hierarchy/evaluator.py +400 -0
- tsagentkit/hierarchy/reconciliation.py +232 -0
- tsagentkit/hierarchy/structure.py +453 -0
- tsagentkit/models/__init__.py +182 -0
- tsagentkit/models/adapters/__init__.py +83 -0
- tsagentkit/models/adapters/base.py +321 -0
- tsagentkit/models/adapters/chronos.py +387 -0
- tsagentkit/models/adapters/moirai.py +256 -0
- tsagentkit/models/adapters/registry.py +171 -0
- tsagentkit/models/adapters/timesfm.py +440 -0
- tsagentkit/models/baselines.py +207 -0
- tsagentkit/models/sktime.py +307 -0
- tsagentkit/monitoring/__init__.py +51 -0
- tsagentkit/monitoring/alerts.py +302 -0
- tsagentkit/monitoring/coverage.py +203 -0
- tsagentkit/monitoring/drift.py +330 -0
- tsagentkit/monitoring/report.py +214 -0
- tsagentkit/monitoring/stability.py +275 -0
- tsagentkit/monitoring/triggers.py +423 -0
- tsagentkit/qa/__init__.py +347 -0
- tsagentkit/router/__init__.py +37 -0
- tsagentkit/router/bucketing.py +489 -0
- tsagentkit/router/fallback.py +132 -0
- tsagentkit/router/plan.py +23 -0
- tsagentkit/router/router.py +271 -0
- tsagentkit/series/__init__.py +26 -0
- tsagentkit/series/alignment.py +206 -0
- tsagentkit/series/dataset.py +449 -0
- tsagentkit/series/sparsity.py +261 -0
- tsagentkit/series/validation.py +393 -0
- tsagentkit/serving/__init__.py +39 -0
- tsagentkit/serving/orchestration.py +943 -0
- tsagentkit/serving/packaging.py +73 -0
- tsagentkit/serving/provenance.py +317 -0
- tsagentkit/serving/tsfm_cache.py +214 -0
- tsagentkit/skill/README.md +135 -0
- tsagentkit/skill/__init__.py +8 -0
- tsagentkit/skill/recipes.md +429 -0
- tsagentkit/skill/tool_map.md +21 -0
- tsagentkit/time/__init__.py +134 -0
- tsagentkit/utils/__init__.py +20 -0
- tsagentkit/utils/quantiles.py +83 -0
- tsagentkit/utils/signature.py +47 -0
- tsagentkit/utils/temporal.py +41 -0
- tsagentkit-1.0.2.dist-info/METADATA +371 -0
- tsagentkit-1.0.2.dist-info/RECORD +72 -0
- tsagentkit-1.0.2.dist-info/WHEEL +4 -0
- tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
"""Hierarchical forecast reconciliation (adapter to hierarchicalforecast)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pandas as pd
|
|
10
|
+
|
|
11
|
+
from .structure import HierarchyStructure
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ReconciliationMethod(Enum):
    """Available reconciliation methods.

    Values map onto ``hierarchicalforecast`` method objects in
    ``_select_reconciler``.
    """

    BOTTOM_UP = "bottom_up"    # sum bottom-level forecasts upward
    TOP_DOWN = "top_down"      # disaggregate the top level via forecast proportions
    MIDDLE_OUT = "middle_out"  # pivot at a middle level, bottom-up above / top-down below
    OLS = "ols"                # MinTrace with ordinary-least-squares weighting
    WLS = "wls"                # MinTrace with structural (wls_struct) weighting
    MIN_TRACE = "min_trace"    # MinTrace; shrinkage estimator when in-sample data exists
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _build_tag_indices(tags: dict[str, np.ndarray], order: list[str]) -> dict[str, np.ndarray]:
|
|
26
|
+
node_to_idx = {node: idx for idx, node in enumerate(order)}
|
|
27
|
+
indexed: dict[str, np.ndarray] = {}
|
|
28
|
+
for key, nodes in tags.items():
|
|
29
|
+
indices = [node_to_idx[n] for n in nodes if n in node_to_idx]
|
|
30
|
+
if indices:
|
|
31
|
+
indexed[key] = np.array(indices, dtype=int)
|
|
32
|
+
return indexed
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _build_s_matrix(structure: HierarchyStructure, order: list[str]) -> np.ndarray:
    """Return the summation matrix as a float array with rows arranged per *order*.

    Rows are looked up by node name (``unique_id``), so *order* must contain
    node identifiers known to *structure*; columns follow
    ``structure.bottom_nodes``.
    """
    s_df = structure.to_s_df()
    s_df = s_df.set_index("unique_id").reindex(order)
    return s_df[structure.bottom_nodes].to_numpy(dtype=float)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _get_level_key(tags: dict[str, np.ndarray], middle_level: int | str | None) -> str:
|
|
42
|
+
level_keys = [k for k in tags if k.startswith("level_")]
|
|
43
|
+
level_keys = sorted(level_keys, key=lambda k: int(k.split("_")[1]))
|
|
44
|
+
if not level_keys:
|
|
45
|
+
return next(iter(tags), "bottom")
|
|
46
|
+
|
|
47
|
+
if isinstance(middle_level, str) and middle_level in tags:
|
|
48
|
+
return middle_level
|
|
49
|
+
if isinstance(middle_level, int):
|
|
50
|
+
idx = min(max(middle_level, 0), len(level_keys) - 1)
|
|
51
|
+
return level_keys[idx]
|
|
52
|
+
|
|
53
|
+
if len(level_keys) >= 3:
|
|
54
|
+
return level_keys[len(level_keys) // 2]
|
|
55
|
+
if len(level_keys) == 2:
|
|
56
|
+
return level_keys[1]
|
|
57
|
+
return level_keys[0]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _select_reconciler(
    method: ReconciliationMethod,
    tags: dict[str, np.ndarray],
    middle_level: int | str | None = None,
    has_insample: bool = False,
) -> Any:
    """Instantiate the ``hierarchicalforecast`` reconciler for *method*.

    MIN_TRACE uses shrinkage estimation only when in-sample data is
    available; any unrecognised method falls back to OLS MinTrace.
    """
    from hierarchicalforecast.methods import BottomUp, MiddleOut, MinTrace, TopDown

    builders = {
        ReconciliationMethod.BOTTOM_UP: lambda: BottomUp(),
        ReconciliationMethod.TOP_DOWN: lambda: TopDown(method="forecast_proportions"),
        ReconciliationMethod.MIDDLE_OUT: lambda: MiddleOut(
            middle_level=_get_level_key(tags, middle_level),
            top_down_method="forecast_proportions",
        ),
        ReconciliationMethod.OLS: lambda: MinTrace(method="ols"),
        ReconciliationMethod.WLS: lambda: MinTrace(method="wls_struct"),
        ReconciliationMethod.MIN_TRACE: lambda: MinTrace(
            method="mint_shrink" if has_insample else "ols"
        ),
    }
    build = builders.get(method, lambda: MinTrace(method="ols"))
    return build()
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _apply_reconciler(
|
|
87
|
+
reconciler: Any,
|
|
88
|
+
s_matrix: np.ndarray,
|
|
89
|
+
y_hat: np.ndarray,
|
|
90
|
+
tags: dict[str, np.ndarray],
|
|
91
|
+
y_insample: np.ndarray | None = None,
|
|
92
|
+
y_hat_insample: np.ndarray | None = None,
|
|
93
|
+
) -> np.ndarray:
|
|
94
|
+
result = reconciler.fit_predict(
|
|
95
|
+
S=s_matrix,
|
|
96
|
+
y_hat=y_hat,
|
|
97
|
+
tags=tags,
|
|
98
|
+
y_insample=y_insample,
|
|
99
|
+
y_hat_insample=y_hat_insample,
|
|
100
|
+
)
|
|
101
|
+
if isinstance(result, dict):
|
|
102
|
+
if "mean" in result:
|
|
103
|
+
return np.asarray(result["mean"])
|
|
104
|
+
# Fallback to first array-like entry
|
|
105
|
+
for value in result.values():
|
|
106
|
+
if isinstance(value, (np.ndarray, list)):
|
|
107
|
+
return np.asarray(value)
|
|
108
|
+
raise ValueError("Unsupported reconciliation output format.")
|
|
109
|
+
return np.asarray(result)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class Reconciler:
    """Hierarchical forecast reconciliation engine (adapter).

    Wraps ``hierarchicalforecast`` method objects: reorders input rows into
    the library's expected node order (aggregates first), runs the selected
    method, then restores the caller's original row order.
    """

    def __init__(
        self,
        method: ReconciliationMethod,
        structure: HierarchyStructure,
    ) -> None:
        # Reconciliation strategy and the hierarchy the forecasts belong to.
        self.method = method
        self.structure = structure

    def reconcile(
        self,
        base_forecasts: np.ndarray,
        fitted_values: np.ndarray | None = None,
        residuals: np.ndarray | None = None,
        **kwargs: Any,
    ) -> np.ndarray:
        """Reconcile base forecasts to be hierarchy-consistent.

        Args:
            base_forecasts: Shape (n_nodes,) or (n_nodes, horizon); rows are
                indexed by ``structure.all_nodes`` order (that is how rows
                are permuted below).
            fitted_values: Optional in-sample fitted values, same row order.
            residuals: Optional in-sample residuals, same row order. Both
                must be supplied together to enable the in-sample path
                (which lets MIN_TRACE use shrinkage).
            **kwargs: Extra options; ``middle_level`` selects the pivot
                level for MIDDLE_OUT.

        Returns:
            Reconciled forecasts with the same shape and row order as
            ``base_forecasts``.
        """
        y_hat = np.asarray(base_forecasts, dtype=float)
        was_1d = y_hat.ndim == 1
        if was_1d:
            # Promote to 2-D (n_nodes, horizon=1) for the library.
            y_hat = y_hat[:, None]

        hf_order = self.structure.node_order()
        tags = self.structure.to_tags()
        indexed_tags = _build_tag_indices(tags, hf_order)
        s_matrix = _build_s_matrix(self.structure, hf_order)
        # Permutation from caller order (all_nodes) into hierarchicalforecast
        # order; argsort of a permutation yields its inverse.
        order_to_hf = [self.structure.all_nodes.index(node) for node in hf_order]
        hf_to_order = np.argsort(order_to_hf)
        y_hat = y_hat[order_to_hf]

        y_insample = None
        y_hat_insample = None
        has_insample = False
        if fitted_values is not None and residuals is not None:
            y_hat_insample = np.asarray(fitted_values, dtype=float)[order_to_hf]
            residuals_arr = np.asarray(residuals, dtype=float)[order_to_hf]
            # Actuals reconstructed as fitted + residual.
            y_insample = y_hat_insample + residuals_arr
            has_insample = True

        reconciler = _select_reconciler(
            self.method,
            tags,
            middle_level=kwargs.get("middle_level"),
            has_insample=has_insample,
        )
        reconciled = _apply_reconciler(
            reconciler,
            s_matrix,
            y_hat,
            indexed_tags,
            y_insample=y_insample,
            y_hat_insample=y_hat_insample,
        )

        # Undo the row reordering and restore the input's dimensionality.
        reconciled = reconciled[hf_to_order]
        if was_1d:
            return reconciled[:, 0]
        return reconciled
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def reconcile_forecasts(
    base_forecasts: pd.DataFrame,
    structure: HierarchyStructure,
    method: ReconciliationMethod | str = ReconciliationMethod.BOTTOM_UP,
) -> pd.DataFrame:
    """Reconcile forecast DataFrame to ensure hierarchy coherence.

    Args:
        base_forecasts: Long-format frame with ``unique_id`` and ``ds``
            columns plus one or more numeric forecast columns.
        structure: Hierarchy describing how series aggregate.
        method: Reconciliation method (enum member or its string value).

    Returns:
        Frame sorted by (unique_id, ds) where each numeric forecast column
        holds its reconciled values; the library's ``<col>/<Method>``
        helper columns are dropped.

    Raises:
        ValueError: If no numeric forecast columns are present.
    """
    if isinstance(method, str):
        method = ReconciliationMethod(method)

    df = base_forecasts.copy()
    id_col = "unique_id"
    ds_col = "ds"
    if not pd.api.types.is_datetime64_any_dtype(df[ds_col]):
        df[ds_col] = pd.to_datetime(df[ds_col])

    # Every numeric non-key column is treated as a forecast to reconcile.
    value_cols = [
        c
        for c in df.columns
        if c not in {id_col, ds_col} and pd.api.types.is_numeric_dtype(df[c])
    ]
    if not value_cols:
        raise ValueError("No numeric forecast columns to reconcile.")

    from hierarchicalforecast.core import HierarchicalReconciliation

    tags = structure.to_tags()
    s_df = structure.to_s_df()
    # No in-sample data on this path, so MIN_TRACE resolves to its OLS form.
    reconciler = _select_reconciler(method, tags)
    engine = HierarchicalReconciliation([reconciler])

    y_hat_df = df[[id_col, ds_col] + value_cols].copy()
    reconciled = engine.reconcile(
        Y_hat_df=y_hat_df,
        S_df=s_df,
        tags=tags,
        Y_df=None,
        id_col=id_col,
        time_col=ds_col,
        target_col="y",
    )

    # hierarchicalforecast appends "<col>/<MethodClass>" columns; copy the
    # reconciled values back onto the original column names.
    method_label = reconciler.__class__.__name__
    reconciled = reconciled.copy()
    for col in value_cols:
        candidate = f"{col}/{method_label}"
        if candidate in reconciled.columns:
            reconciled[col] = reconciled[candidate]
        else:
            # The exact class-name suffix may not match; accept a unique
            # "<col>/..." alternative when there is exactly one.
            alternatives = [c for c in reconciled.columns if c.startswith(f"{col}/")]
            if len(alternatives) == 1:
                reconciled[col] = reconciled[alternatives[0]]
    drop_cols = [c for c in reconciled.columns if "/" in c and c.split("/")[0] in value_cols]
    if drop_cols:
        reconciled = reconciled.drop(columns=drop_cols)

    return reconciled.sort_values([id_col, ds_col]).reset_index(drop=True)
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
__all__ = ["ReconciliationMethod", "Reconciler", "reconcile_forecasts"]
|
|
@@ -0,0 +1,453 @@
|
|
|
1
|
+
"""Hierarchy structure definition for hierarchical time series.
|
|
2
|
+
|
|
3
|
+
Defines the aggregation relationships between time series in a hierarchy
|
|
4
|
+
and provides utilities for validation and navigation.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from typing import TYPE_CHECKING
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import pandas as pd
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
pass
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True)
class HierarchyStructure:
    """Defines hierarchical relationships between time series.

    Represents the aggregation structure where bottom-level series
    sum up to higher-level series. The structure is encoded using:
    - aggregation_graph: Parent -> children mapping
    - bottom_nodes: Leaf nodes (no children)
    - s_matrix: Summation matrix (n_total x n_bottom)

    Example structure for retail:
        Total
        ├── Region_North
        │   ├── Store_A
        │   └── Store_B
        └── Region_South
            ├── Store_C
            └── Store_D

    Attributes:
        aggregation_graph: Mapping from parent to list of children
        bottom_nodes: List of bottom-level (leaf) node names
        s_matrix: Summation matrix where S[i,j] = 1 if bottom node j
            contributes to node i
        all_nodes: List of all nodes (computed automatically)

    Example:
        >>> structure = HierarchyStructure(
        ...     aggregation_graph={
        ...         "Total": ["Region_North", "Region_South"],
        ...         "Region_North": ["Store_A", "Store_B"],
        ...         "Region_South": ["Store_C", "Store_D"],
        ...     },
        ...     bottom_nodes=["Store_A", "Store_B", "Store_C", "Store_D"],
        ...     s_matrix=s_matrix,  # 7x4 matrix
        ... )
    """

    # Mapping from parent to children
    aggregation_graph: dict[str, list[str]]

    # Bottom-level nodes (leaf nodes)
    bottom_nodes: list[str]

    # All nodes in the hierarchy (computed in __post_init__, sorted)
    all_nodes: list[str] = field(init=False)

    # Aggregation matrix S (numpy array)
    # Shape: (n_total, n_bottom)
    # where S[i, j] = 1 if bottom node j contributes to node i
    # Rows must be aligned with the SORTED node list, since __post_init__
    # sets all_nodes = sorted(...) and validation indexes rows by it.
    s_matrix: np.ndarray = field(repr=False)

    def __post_init__(self) -> None:
        """Validate the hierarchy structure after creation."""
        # Check for empty hierarchy first
        if not self.bottom_nodes:
            raise ValueError("bottom_nodes cannot be empty")

        # Compute all nodes first (needed for validation)
        nodes = set(self.bottom_nodes)
        for parent, children in self.aggregation_graph.items():
            nodes.add(parent)
            nodes.update(children)

        # Use object.__setattr__ since dataclass is frozen
        object.__setattr__(self, "all_nodes", sorted(nodes))

        # Validate structure
        self._validate_structure()

    def _validate_structure(self) -> None:
        """Validate hierarchy structure is consistent.

        Raises:
            ValueError: If structure is invalid
        """
        # Check all children exist (before S matrix validation)
        for parent, children in self.aggregation_graph.items():
            for child in children:
                if child not in self.all_nodes:
                    raise ValueError(
                        f"Child '{child}' of parent '{parent}' not found in hierarchy"
                    )

        # Check S matrix dimensions
        n_total = len(self.all_nodes)
        n_bottom = len(self.bottom_nodes)
        if self.s_matrix.shape != (n_total, n_bottom):
            raise ValueError(
                f"S matrix shape {self.s_matrix.shape} doesn't match "
                f"expected ({n_total}, {n_bottom})"
            )

        # Check bottom nodes have no children
        for node in self.bottom_nodes:
            if node in self.aggregation_graph:
                raise ValueError(
                    f"Bottom node '{node}' cannot have children in aggregation_graph"
                )

        # Check S matrix values are valid (0 or 1)
        if not np.all(np.isin(self.s_matrix, [0, 1])):
            raise ValueError("S matrix must contain only 0s and 1s")

        # Check each bottom node contributes to exactly one bottom position
        for j, bottom_node in enumerate(self.bottom_nodes):
            bottom_idx = self.all_nodes.index(bottom_node)
            if self.s_matrix[bottom_idx, j] != 1:
                raise ValueError(
                    f"Bottom node '{bottom_node}' must have S[{bottom_idx}, {j}] = 1"
                )

    @classmethod
    def from_dataframe(
        cls,
        df: pd.DataFrame,
        hierarchy_columns: list[str],
        value_column: str = "y",
    ) -> HierarchyStructure:
        """Build hierarchy structure from DataFrame with hierarchical columns.

        Automatically constructs the aggregation graph and summation matrix
        from hierarchical identifiers in the DataFrame.

        Args:
            df: DataFrame with hierarchical identifiers
            hierarchy_columns: Columns defining hierarchy (top to bottom)
            value_column: Column containing values (for validation)

        Returns:
            HierarchyStructure built from the data

        Example:
            >>> df = pd.DataFrame({
            ...     "country": ["US", "US", "US", "US"],
            ...     "state": ["CA", "CA", "NY", "NY"],
            ...     "city": ["SF", "LA", "NYC", "BUF"],
            ...     "y": [100, 200, 300, 50]
            ... })
            >>> structure = HierarchyStructure.from_dataframe(
            ...     df, ["country", "state", "city"]
            ... )
        """
        if not hierarchy_columns:
            raise ValueError("hierarchy_columns cannot be empty")

        # Get unique combinations
        unique_combos = df[hierarchy_columns].drop_duplicates()

        # Build aggregation graph
        aggregation_graph: dict[str, list[str]] = {}

        for level_idx in range(len(hierarchy_columns) - 1):
            parent_col = hierarchy_columns[level_idx]
            child_col = hierarchy_columns[level_idx + 1]

            # Group by immediate parent column to find children
            for parent_value, group in unique_combos.groupby(parent_col):
                # NOTE(review): the parent key is stringified but children are
                # kept as raw column values — for non-string hierarchy columns
                # these will not match up; verify inputs are strings.
                parent_key = str(parent_value)
                children = group[child_col].unique().tolist()

                if parent_key not in aggregation_graph:
                    aggregation_graph[parent_key] = []
                aggregation_graph[parent_key].extend(children)
                aggregation_graph[parent_key] = list(
                    dict.fromkeys(aggregation_graph[parent_key])
                )  # Remove duplicates, preserve order

        # Bottom nodes are the unique values at the lowest level
        # NOTE(review): bottom identifiers must be globally unique — a city
        # name repeated under two states would collapse into one node.
        bottom_nodes = unique_combos[hierarchy_columns[-1]].unique().tolist()

        # Build summation matrix
        all_nodes = _get_all_nodes_from_graph(aggregation_graph, bottom_nodes)
        s_matrix = _build_summation_matrix(
            all_nodes, bottom_nodes, aggregation_graph
        )

        return cls(
            aggregation_graph=aggregation_graph,
            bottom_nodes=bottom_nodes,
            s_matrix=s_matrix,
        )

    @classmethod
    def from_summation_matrix(
        cls,
        s_matrix: np.ndarray,
        node_names: list[str],
        bottom_node_names: list[str],
    ) -> HierarchyStructure:
        """Build hierarchy from explicit summation matrix.

        Args:
            s_matrix: Summation matrix (n_nodes x n_bottom)
            node_names: Names for all nodes (length n_nodes)
            bottom_node_names: Names for bottom nodes (length n_bottom)

        Returns:
            HierarchyStructure
        """
        # NOTE(review): __post_init__ recomputes all_nodes as sorted(...) and
        # validates s_matrix rows against that sorted order — callers passing
        # node_names in a different order should verify row alignment.
        if len(node_names) != s_matrix.shape[0]:
            raise ValueError(
                f"node_names length {len(node_names)} doesn't match "
                f"S matrix rows {s_matrix.shape[0]}"
            )
        if len(bottom_node_names) != s_matrix.shape[1]:
            raise ValueError(
                f"bottom_node_names length {len(bottom_node_names)} doesn't match "
                f"S matrix columns {s_matrix.shape[1]}"
            )

        # Infer aggregation graph from S matrix
        # NOTE(review): every contributing bottom node is recorded as a direct
        # child, so intermediate levels are flattened in the inferred graph.
        aggregation_graph: dict[str, list[str]] = {}

        for i, node in enumerate(node_names):
            if node in bottom_node_names:
                continue  # Skip bottom nodes

            # Find children: nodes at lower level that sum to this node
            children = []
            for j, bottom_node in enumerate(bottom_node_names):
                if s_matrix[i, j] == 1 and node != bottom_node:
                    # Check if this bottom node is a direct child
                    # or if there's an intermediate node
                    children.append(bottom_node)

            if children:
                aggregation_graph[node] = children

        return cls(
            aggregation_graph=aggregation_graph,
            bottom_nodes=bottom_node_names,
            s_matrix=s_matrix,
        )

    def get_parents(self, node: str) -> list[str]:
        """Get parent nodes of a given node.

        Args:
            node: Node name

        Returns:
            List of parent node names

        Raises:
            ValueError: If node is not in hierarchy
        """
        if node not in self.all_nodes:
            raise ValueError(f"Node '{node}' not in hierarchy")

        parents = []
        for parent, children in self.aggregation_graph.items():
            if node in children:
                parents.append(parent)
        return parents

    def get_children(self, node: str) -> list[str]:
        """Get child nodes of a given node.

        Args:
            node: Node name

        Returns:
            List of child node names (empty for leaves / unknown nodes)
        """
        return self.aggregation_graph.get(node, [])

    def get_level(self, node: str) -> int:
        """Get hierarchy level (0 = root, increasing downward).

        Args:
            node: Node name

        Returns:
            Hierarchy level (0 for root nodes)

        Raises:
            ValueError: If node is not in hierarchy
        """
        if node not in self.all_nodes:
            raise ValueError(f"Node '{node}' not in hierarchy")

        level = 0
        current = node
        parents = self.get_parents(current)

        # Traverse up to find depth
        while parents:
            level += 1
            current = parents[0]  # Use first parent (works for tree structures)
            parents = self.get_parents(current)

        return level

    def get_nodes_at_level(self, level: int) -> list[str]:
        """Get all nodes at a specific hierarchy level.

        Args:
            level: Hierarchy level (0 = root)

        Returns:
            List of node names at that level
        """
        return [node for node in self.all_nodes if self.get_level(node) == level]

    def is_leaf(self, node: str) -> bool:
        """Check if node is a leaf (bottom-level) node.

        Args:
            node: Node name

        Returns:
            True if node is a bottom node
        """
        return node in self.bottom_nodes

    def get_num_levels(self) -> int:
        """Get the number of levels in the hierarchy.

        Returns:
            Maximum level + 1 (since levels start at 0)
        """
        if not self.all_nodes:
            return 0
        return max(self.get_level(node) for node in self.all_nodes) + 1

    def node_order(self) -> list[str]:
        """Return nodes ordered with aggregates first and bottom nodes last."""
        order: list[str] = []
        for level in range(self.get_num_levels()):
            order.extend(self.get_nodes_at_level(level))
        # Move bottom nodes to the end regardless of their depth.
        bottom = [n for n in self.bottom_nodes if n in order]
        order = [n for n in order if n not in bottom] + bottom
        return order

    def to_s_df(self, id_col: str = "unique_id") -> pd.DataFrame:
        """Return summation matrix as S_df (rows ordered aggregates -> bottom)."""
        s_df = pd.DataFrame(
            self.s_matrix,
            index=self.all_nodes,
            columns=self.bottom_nodes,
        )
        ordered = self.node_order()
        s_df = s_df.reindex(index=ordered)
        s_df = s_df.reset_index().rename(columns={"index": id_col})
        return s_df

    def to_tags(self, level_prefix: str = "level_") -> dict[str, np.ndarray]:
        """Return hierarchical tags mapping level -> node names."""
        tags: dict[str, np.ndarray] = {}
        order = self.node_order()
        for level in range(self.get_num_levels()):
            level_nodes = [n for n in order if self.get_level(n) == level]
            if level_nodes:
                tags[f"{level_prefix}{level}"] = np.array(level_nodes, dtype=object)
        return tags
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def _get_all_nodes_from_graph(
|
|
378
|
+
aggregation_graph: dict[str, list[str]],
|
|
379
|
+
bottom_nodes: list[str],
|
|
380
|
+
) -> list[str]:
|
|
381
|
+
"""Get all nodes from aggregation graph and bottom nodes."""
|
|
382
|
+
nodes = set(bottom_nodes)
|
|
383
|
+
for parent, children in aggregation_graph.items():
|
|
384
|
+
nodes.add(parent)
|
|
385
|
+
nodes.update(children)
|
|
386
|
+
return sorted(nodes)
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
def _build_summation_matrix(
|
|
390
|
+
all_nodes: list[str],
|
|
391
|
+
bottom_nodes: list[str],
|
|
392
|
+
aggregation_graph: dict[str, list[str]],
|
|
393
|
+
) -> np.ndarray:
|
|
394
|
+
"""Build summation matrix from hierarchy definition.
|
|
395
|
+
|
|
396
|
+
The S matrix encodes which bottom nodes contribute to each node.
|
|
397
|
+
S[i, j] = 1 if bottom node j contributes to node i.
|
|
398
|
+
"""
|
|
399
|
+
n_total = len(all_nodes)
|
|
400
|
+
n_bottom = len(bottom_nodes)
|
|
401
|
+
|
|
402
|
+
s_matrix = np.zeros((n_total, n_bottom), dtype=int)
|
|
403
|
+
|
|
404
|
+
# Map node names to indices
|
|
405
|
+
node_to_idx = {node: i for i, node in enumerate(all_nodes)}
|
|
406
|
+
bottom_to_idx = {node: j for j, node in enumerate(bottom_nodes)}
|
|
407
|
+
|
|
408
|
+
# For each bottom node, determine all ancestors
|
|
409
|
+
for bottom_node in bottom_nodes:
|
|
410
|
+
bottom_idx = node_to_idx[bottom_node]
|
|
411
|
+
j = bottom_to_idx[bottom_node]
|
|
412
|
+
|
|
413
|
+
# Bottom node contributes to itself
|
|
414
|
+
s_matrix[bottom_idx, j] = 1
|
|
415
|
+
|
|
416
|
+
# Trace up the hierarchy
|
|
417
|
+
_add_ancestor_contributions(
|
|
418
|
+
bottom_node,
|
|
419
|
+
j,
|
|
420
|
+
s_matrix,
|
|
421
|
+
node_to_idx,
|
|
422
|
+
aggregation_graph,
|
|
423
|
+
)
|
|
424
|
+
|
|
425
|
+
return s_matrix
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
def _add_ancestor_contributions(
|
|
429
|
+
node: str,
|
|
430
|
+
bottom_idx: int,
|
|
431
|
+
s_matrix: np.ndarray,
|
|
432
|
+
node_to_idx: dict[str, int],
|
|
433
|
+
aggregation_graph: dict[str, list[str]],
|
|
434
|
+
) -> None:
|
|
435
|
+
"""Recursively add contributions for all ancestors of a node."""
|
|
436
|
+
# Find all parents of this node
|
|
437
|
+
parents = [
|
|
438
|
+
parent
|
|
439
|
+
for parent, children in aggregation_graph.items()
|
|
440
|
+
if node in children
|
|
441
|
+
]
|
|
442
|
+
|
|
443
|
+
for parent in parents:
|
|
444
|
+
parent_idx = node_to_idx[parent]
|
|
445
|
+
s_matrix[parent_idx, bottom_idx] = 1
|
|
446
|
+
# Recurse up the hierarchy
|
|
447
|
+
_add_ancestor_contributions(
|
|
448
|
+
parent,
|
|
449
|
+
bottom_idx,
|
|
450
|
+
s_matrix,
|
|
451
|
+
node_to_idx,
|
|
452
|
+
aggregation_graph,
|
|
453
|
+
)
|