patchfm 1.1.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- patchfm-1.1.9/LICENSE +21 -0
- patchfm-1.1.9/PKG-INFO +128 -0
- patchfm-1.1.9/README.md +86 -0
- patchfm-1.1.9/pyproject.toml +31 -0
- patchfm-1.1.9/setup.cfg +4 -0
- patchfm-1.1.9/src/patchfm/__init__.py +2 -0
- patchfm-1.1.9/src/patchfm/configs/model_config.py +21 -0
- patchfm-1.1.9/src/patchfm/inference/forecaster.py +146 -0
- patchfm-1.1.9/src/patchfm/inference/modules.py +268 -0
- patchfm-1.1.9/src/patchfm.egg-info/PKG-INFO +128 -0
- patchfm-1.1.9/src/patchfm.egg-info/SOURCES.txt +12 -0
- patchfm-1.1.9/src/patchfm.egg-info/dependency_links.txt +1 -0
- patchfm-1.1.9/src/patchfm.egg-info/requires.txt +6 -0
- patchfm-1.1.9/src/patchfm.egg-info/top_level.txt +1 -0
patchfm-1.1.9/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 Samy-Melwan Vilhes
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
patchfm-1.1.9/PKG-INFO
ADDED
@@ -0,0 +1,128 @@
+Metadata-Version: 2.4
+Name: patchfm
+Version: 1.1.9
+Summary: a Foundation Model for Univariate Time Series Forecasting
+Author-email: Samy-Melwan Vilhes <samy-melwan.vilhes@insa-rouen.fr>
+License: MIT License
+
+Copyright (c) 2025 Samy-Melwan Vilhes
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+Project-URL: Repository, https://github.com/vilhess/PatchFM
+Project-URL: Issues, https://github.com/vilhess/PatchFM/issues
+Keywords: Transformer,LLM,Time Series,Zero-shot,Deep Learning
+Classifier: Programming Language :: Python :: 3
+Classifier: Operating System :: OS Independent
+Requires-Python: >=3.11
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: torch>=2.5.0
+Requires-Dist: einops>=0.8.1
+Requires-Dist: huggingface-hub>=0.35.1
+Requires-Dist: rotary-embedding-torch>=0.8.9
+Requires-Dist: numpy>=1.26.0
+Requires-Dist: safetensors==0.5.3
+Dynamic: license-file
+
+# A tutorial on how to build a Foundation Model for Univariate Time Series Forecasting
+
+[Huggingface Model Card](https://huggingface.co/vilhess/PatchFM)
+
+A transformer-based forecasting model for univariate time series. The approach mirrors Large Language Model (LLM) practices (next-token → next-patch) while remaining lightweight and practical compared to a classic LLM.
+
+## Highlights
+- Next-patch prediction objective (autoregressive, causal)
+- Patch-based representation of time series (tokens ↔ patches)
+- Causal masking self-attention with RoPE (relative positions)
+- RevIN (Reversible Instance Normalization) with causal statistics
+- SwiGLU feed-forward networks
+- Multi-quantile outputs (median + uncertainty bands)
+- Efficient rollout with KV caching
+
+## Installation
+```bash
+pip install patchfm
+```
+
+## Quick Start
+
+```python
+import torch
+from patchfm import PatchFMConfig, Forecaster
+
+# --- Instantiate model ---
+config = PatchFMConfig()
+model = Forecaster(config)
+
+# --- Inference ---
+forecast_horizon = 64
+seq = torch.randn(1, 1024)  # (batch, time)
+pred_median, pred_quantiles = model(seq, forecast_horizon=forecast_horizon, quantiles=[0.1, 0.5, 0.9])  # (batch, forecast_horizon), (batch, forecast_horizon, quantiles)
+```
+
+We provide an extended quick start example in [notebooks/tutorial.ipynb](./notebooks/tutorial.ipynb).
+If you don't have suitable hardware, you can also run the extended quick start example in Google Colab:
+
+<a target="_blank" href="https://colab.research.google.com/drive/17sdf-7luCkv5TaeLj3Z6kIaTDkwkz3VR?usp=share_link">
+  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open Quick Start In Colab"/>
+</a>
+
+## Method (TL;DR)
+- Patching: Split a context signal of length $w$ into $P_{num} = w / P_{len}$ patches of length $P_{len}$.
+- RevIN: Normalize patches using causal running mean/variance over past patches, and denormalize outputs to the original scale.
+- Architecture: Input residual MLP → stacked Transformer blocks (MHA + SwiGLU FFN, pre-norm, residual) → $|\mathcal{Q}|$ output heads mapping back to patch space.
+- Positional encoding: Rotary Position Embeddings (RoPE) applied to queries/keys.
+- Training: Multi-quantile (pinball) loss across positions, elements, and quantiles $\mathcal{Q}$.
+- Inference: Predict next patch; roll out autoregressively with KV caching for long horizons.
+
+## Problem Formulation
+Given context patches $x_{p_1}, \ldots, x_{p_n}$, predict the next patch $x_{p_{i+1}}$ for each position $i$ using only past patches (causality). The model outputs quantiles $\{\hat{x}_{p_{i+1}}^{(q)}: q \in \mathcal{Q}\}$ with the median ($q = 0.5$) as the point forecast.
+
+## Loss: Multi-Quantile (Pinball)
+For the residual $u = x - \hat{x}^{(q)}$:
+$$\rho_q(u) = \begin{cases} q\,u, & u \ge 0,\\ (q-1)\,u, & u < 0. \end{cases}$$
+Aggregate over positions, patch elements, and quantiles.
+
+## Architecture
+- Input MLP: $\mathbb{R}^{P_{len}} \to \mathbb{R}^{dim}$ residual 2-layer MLP (ReLU)
+- Multi-Head Attention: causal mask, RoPE; queries/keys/values per head
+- FFN: SwiGLU (SiLU-gated), pre-norm + residual
+- Output heads: $|\mathcal{Q}|$ linear maps $\mathbb{R}^{dim} \to \mathbb{R}^{P_{len}}$ (one per quantile)
+
+### Model Details
+- Patch size: 32
+- Max context: 32 patches (1024 steps)
+- Forecast horizon: 32 steps per forward pass
+- Quantiles $\mathcal{Q}$: {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}
+- Layers: 6
+- Attention heads: 64 (head dim 32)
+- Model dim: 2048
+- Parameters: ~300M
+
+## Inference
+- Single step: predict next patch ($P_{len}$ values)
+- Long-horizon: append prediction to context and repeat (optionally drop oldest patch to keep window fixed)
+- KV caching: reuse cached keys/values for past patches; compute new Q/K/V only for the appended patch
+
+## Acknowledgements
+We thank the authors of the following repositories for inspiration and code snippets:
+- [TiRex](https://github.com/NX-AI/tirex)
+
+## Citation
+If you use this work, please cite the paper ...
patchfm-1.1.9/README.md
ADDED
@@ -0,0 +1,86 @@
+# A tutorial on how to build a Foundation Model for Univariate Time Series Forecasting
+
+[Huggingface Model Card](https://huggingface.co/vilhess/PatchFM)
+
+A transformer-based forecasting model for univariate time series. The approach mirrors Large Language Model (LLM) practices (next-token → next-patch) while remaining lightweight and practical compared to a classic LLM.
+
+## Highlights
+- Next-patch prediction objective (autoregressive, causal)
+- Patch-based representation of time series (tokens ↔ patches)
+- Causal masking self-attention with RoPE (relative positions)
+- RevIN (Reversible Instance Normalization) with causal statistics
+- SwiGLU feed-forward networks
+- Multi-quantile outputs (median + uncertainty bands)
+- Efficient rollout with KV caching
+
+## Installation
+```bash
+pip install patchfm
+```
+
+## Quick Start
+
+```python
+import torch
+from patchfm import PatchFMConfig, Forecaster
+
+# --- Instantiate model ---
+config = PatchFMConfig()
+model = Forecaster(config)
+
+# --- Inference ---
+forecast_horizon = 64
+seq = torch.randn(1, 1024)  # (batch, time)
+pred_median, pred_quantiles = model(seq, forecast_horizon=forecast_horizon, quantiles=[0.1, 0.5, 0.9])  # (batch, forecast_horizon), (batch, forecast_horizon, quantiles)
+```
+
+We provide an extended quick start example in [notebooks/tutorial.ipynb](./notebooks/tutorial.ipynb).
+If you don't have suitable hardware, you can also run the extended quick start example in Google Colab:
+
+<a target="_blank" href="https://colab.research.google.com/drive/17sdf-7luCkv5TaeLj3Z6kIaTDkwkz3VR?usp=share_link">
+  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open Quick Start In Colab"/>
+</a>
+
+## Method (TL;DR)
+- Patching: Split a context signal of length $w$ into $P_{num} = w / P_{len}$ patches of length $P_{len}$.
+- RevIN: Normalize patches using causal running mean/variance over past patches, and denormalize outputs to the original scale.
+- Architecture: Input residual MLP → stacked Transformer blocks (MHA + SwiGLU FFN, pre-norm, residual) → $|\mathcal{Q}|$ output heads mapping back to patch space.
+- Positional encoding: Rotary Position Embeddings (RoPE) applied to queries/keys.
+- Training: Multi-quantile (pinball) loss across positions, elements, and quantiles $\mathcal{Q}$.
+- Inference: Predict next patch; roll out autoregressively with KV caching for long horizons.
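The patching and causal-RevIN bullets above can be sketched in plain Python. This is a minimal illustration, not the package's implementation (the real code works on `torch` tensors and additionally applies `asinh` after normalization); the helper names `patchify` and `causal_stats` are made up:

```python
def patchify(series, patch_len=32):
    """Split a 1-D series of length w into w / patch_len non-overlapping patches."""
    assert len(series) % patch_len == 0, "pad the series to a multiple of patch_len first"
    return [series[i:i + patch_len] for i in range(0, len(series), patch_len)]

def causal_stats(patches):
    """Mean/std over everything observed up to and including patch i,
    so normalizing patch i never peeks at future values (causality)."""
    stats, seen = [], []
    for patch in patches:
        seen.extend(patch)
        mean = sum(seen) / len(seen)
        var = sum((v - mean) ** 2 for v in seen) / max(len(seen) - 1, 1)
        stats.append((mean, var ** 0.5))
    return stats

patches = patchify([float(t) for t in range(64)], patch_len=32)
print(len(patches))                 # 2 patches from a length-64 context
print(causal_stats(patches)[0][0])  # first patch normalized with its own past only: 15.5
```

Reversibility then amounts to storing each patch's (mean, std) pair and applying `x * std + mean` to the model's outputs.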
+
+## Problem Formulation
+Given context patches $x_{p_1}, \ldots, x_{p_n}$, predict the next patch $x_{p_{i+1}}$ for each position $i$ using only past patches (causality). The model outputs quantiles $\{\hat{x}_{p_{i+1}}^{(q)}: q \in \mathcal{Q}\}$ with the median ($q = 0.5$) as the point forecast.
+
+## Loss: Multi-Quantile (Pinball)
+For the residual $u = x - \hat{x}^{(q)}$:
+$$\rho_q(u) = \begin{cases} q\,u, & u \ge 0,\\ (q-1)\,u, & u < 0. \end{cases}$$
+Aggregate over positions, patch elements, and quantiles.
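The piecewise definition above drops straight into code. A minimal plain-Python sketch (illustrative only, not the package's training loss; the helper names are made up):

```python
def pinball(y, y_hat, q):
    """Pinball loss rho_q(u) for residual u = y - y_hat and quantile level q."""
    u = y - y_hat
    return q * u if u >= 0 else (q - 1) * u

def multi_quantile_loss(y, preds):
    """Average pinball loss over quantile heads, given as {q: prediction}."""
    return sum(pinball(y, y_hat, q) for q, y_hat in preds.items()) / len(preds)

print(pinball(1.0, 0.0, 0.9))  # under-prediction penalized by q * |u| = 0.9
print(pinball(0.0, 1.0, 0.9))  # over-prediction penalized by (1 - q) * |u| ~ 0.1
```

The asymmetry is what makes each head track its quantile: the q = 0.9 head pays 9x more for predicting too low than too high, so it learns to sit near the 90th percentile.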
+
+## Architecture
+- Input MLP: $\mathbb{R}^{P_{len}} \to \mathbb{R}^{dim}$ residual 2-layer MLP (ReLU)
+- Multi-Head Attention: causal mask, RoPE; queries/keys/values per head
+- FFN: SwiGLU (SiLU-gated), pre-norm + residual
+- Output heads: $|\mathcal{Q}|$ linear maps $\mathbb{R}^{dim} \to \mathbb{R}^{P_{len}}$ (one per quantile)
+
+### Model Details
+- Patch size: 32
+- Max context: 32 patches (1024 steps)
+- Forecast horizon: 32 steps per forward pass
+- Quantiles $\mathcal{Q}$: {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}
+- Layers: 6
+- Attention heads: 64 (head dim 32)
+- Model dim: 2048
+- Parameters: ~300M
+
+## Inference
+- Single step: predict next patch ($P_{len}$ values)
+- Long-horizon: append prediction to context and repeat (optionally drop oldest patch to keep window fixed)
+- KV caching: reuse cached keys/values for past patches; compute new Q/K/V only for the appended patch
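The long-horizon rollout above can be sketched independently of the model. Here `predict_patch` is a made-up stand-in for one full forward pass (which returns the median patch), and the KV-caching optimization is omitted for clarity:

```python
def rollout(context, predict_patch, horizon, patch_len=32, max_patches=32):
    """Autoregressive rollout: forecast ceil(horizon / patch_len) patches,
    feeding each prediction back into a fixed-size context window."""
    out, ctx = [], list(context)
    while len(out) < horizon:
        ctx = ctx[-(max_patches * patch_len):]  # drop oldest values to keep the window fixed
        patch = predict_patch(ctx)              # one forward pass -> next patch (patch_len values)
        assert len(patch) == patch_len
        out.extend(patch)
        ctx.extend(patch)                       # append prediction to the context
    return out[:horizon]

# Dummy "model" that just repeats the last observed value.
preds = rollout([0.5] * 1024, lambda ctx: [ctx[-1]] * 32, horizon=48)
print(len(preds))  # 48
```

A horizon of 48 costs two forward passes here (2 x 32 = 64 predicted steps, truncated to 48), matching the ceil-division rollout count used in `Forecaster.forecast` above.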
+
+## Acknowledgements
+We thank the authors of the following repositories for inspiration and code snippets:
+- [TiRex](https://github.com/NX-AI/tirex)
+
+## Citation
+If you use this work, please cite the paper ...
patchfm-1.1.9/pyproject.toml
ADDED
@@ -0,0 +1,31 @@
+[project]
+name = "patchfm"
+version = "1.1.9"
+authors = [
+  { name="Samy-Melwan Vilhes", email="samy-melwan.vilhes@insa-rouen.fr" },
+]
+description = "a Foundation Model for Univariate Time Series Forecasting"
+readme = "README.md"
+license = {file="LICENSE"}
+requires-python = ">=3.11"
+classifiers = [
+    "Programming Language :: Python :: 3",
+    "Operating System :: OS Independent",
+]
+keywords = ["Transformer", "LLM", "Time Series", "Zero-shot", "Deep Learning"]
+dependencies = [
+    "torch>=2.5.0",
+    "einops>=0.8.1",
+    "huggingface-hub>=0.35.1",
+    "rotary-embedding-torch>=0.8.9",
+    "numpy>=1.26.0",
+    "safetensors==0.5.3"
+]
+
+[project.urls]
+Repository = "https://github.com/vilhess/PatchFM"
+Issues = "https://github.com/vilhess/PatchFM/issues"
+
+[build-system]
+requires = ["setuptools >= 77.0.3"]
+build-backend = "setuptools.build_meta"
patchfm-1.1.9/setup.cfg
ADDED
patchfm-1.1.9/src/patchfm/configs/model_config.py
ADDED
@@ -0,0 +1,21 @@
+from dataclasses import dataclass, field, asdict
+
+@dataclass
+class PatchFMConfig:
+    max_seq_len: int = 1024
+    patch_len: int = 32
+    d_model: int = 2048
+    n_heads: int = 64
+    n_layers_encoder: int = 6
+    quantiles: list[float] = field(default_factory=lambda: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
+
+    compile: bool = True
+
+    def __getitem__(self, key):
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        return setattr(self, key, value)
+
+    def to_dict(self):
+        return asdict(self)
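A usage note on the config above: `Forecaster` reads hyperparameters with `config["max_seq_len"]`-style indexing, which works because the dataclass defines `__getitem__`/`__setitem__`. A self-contained sketch (the dataclass body is copied from the file above so the snippet runs without installing the package):

```python
from dataclasses import dataclass, field, asdict

@dataclass
class PatchFMConfig:  # copied from patchfm/configs/model_config.py above
    max_seq_len: int = 1024
    patch_len: int = 32
    d_model: int = 2048
    n_heads: int = 64
    n_layers_encoder: int = 6
    quantiles: list[float] = field(default_factory=lambda: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
    compile: bool = True

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def to_dict(self):
        return asdict(self)

config = PatchFMConfig()
print(config["patch_len"])               # dict-style access maps to getattr: 32
config["compile"] = False                # item assignment maps to setattr
print(config.to_dict()["quantiles"][4])  # median quantile level: 0.5
```

This lets the same `Forecaster` code accept either this dataclass or a plain dict.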
patchfm-1.1.9/src/patchfm/inference/forecaster.py
ADDED
@@ -0,0 +1,146 @@
+import torch
+import torch.nn as nn
+from einops import rearrange
+from patchfm.inference.modules import RevIN, ResidualBlock, TransformerEncoder, PatchFM, SeqTypeConverter
+
+
+# --- Forecaster Model ---
+class Forecaster(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        # Store config
+        self.max_seq_len = config["max_seq_len"]
+        self.patch_len = config["patch_len"]
+        self.d_model = config["d_model"]
+        self.n_heads = config["n_heads"]
+        self.n_layers_encoder = config["n_layers_encoder"]
+        self.quantiles = config["quantiles"]
+        self.n_quantiles = len(self.quantiles)
+        self.max_patches = self.max_seq_len // self.patch_len
+
+        print("Loading base model from HuggingFace Hub...")
+        base_model = PatchFM.from_pretrained("vilhess/PatchFM")
+        self._init_from_base(base_model)
+
+        self.eval()
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.to(self.device)
+
+        self.converter = SeqTypeConverter()
+
+        if config["compile"]:
+            self = torch.compile(self)
+
+    def _init_components(self):
+        """Initialize modules from scratch."""
+        self.revin = RevIN()
+        self.proj_embedding = ResidualBlock(
+            in_dim=self.patch_len,
+            hid_dim=2 * self.patch_len,
+            out_dim=self.d_model
+        )
+        self.transformer_encoder = TransformerEncoder(
+            d_model=self.d_model,
+            n_heads=self.n_heads,
+            n_layers=self.n_layers_encoder
+        )
+        self.proj_output = ResidualBlock(
+            in_dim=self.d_model,
+            hid_dim=2 * self.d_model,
+            out_dim=self.patch_len * self.n_quantiles
+        )
+
+    def _init_from_base(self, base_model):
+        """Initialize modules by reusing a pretrained PatchFM model."""
+        self.revin = base_model.revin
+        self.proj_embedding = base_model.proj_embedding
+        self.transformer_encoder = base_model.transformer_encoder
+        self.proj_output = base_model.proj_output
+
+    @torch.inference_mode()
+    def forecast(self, x: torch.Tensor, forecast_horizon: int | None = None, quantiles: list[float] | None = None) -> torch.Tensor:
+        x = self.converter.convert(x)
+        assert x.ndim in (1, 2), f"Input dimension must be 1D (time) or 2D (batch, time), got {x.ndim}D."
+
+        batch_dim = True
+        if x.ndim != 2:
+            x = x.unsqueeze(0)
+            batch_dim = False
+        bs, ws = x.size()
+
+        x = x.to(self.device)
+
+        if ws > self.max_seq_len:
+            print(f"Warning: Input length {ws} exceeds max_seq_len {self.max_seq_len}. Truncating input.")
+            x = x[:, -self.max_seq_len:]
+            ws = self.max_seq_len
+
+        # Pad so length is divisible by patch_len
+        pad = (self.patch_len - ws % self.patch_len) % self.patch_len
+        if pad > 0:
+            x = torch.cat([x[:, :1].repeat(1, pad), x], dim=1)
+
+        # Default horizon = patch_len
+        forecast_horizon = forecast_horizon or self.patch_len
+
+        # Reshape into patches
+        x = rearrange(x, "b (pn pl) -> b pn pl", pl=self.patch_len)
+
+        rollouts = -(-forecast_horizon // self.patch_len)  # ceil division
+        predictions = []
+
+        for _ in range(rollouts):
+
+            if x.size(1) > self.max_patches:
+                x = x[:, -self.max_patches:, :]
+
+            init_x = x.clone()
+            # Forward pass
+            x = self.revin(x, mode="norm")
+            x = self.proj_embedding(x)
+            x = self.transformer_encoder(x)
+            x = x[:, -1:, :]  # Keep only the last patch for autoregressive forecasting
+            forecasting = self.proj_output(x)
+            forecasting = self.revin(forecasting, mode="denorm")
+
+            # Reshape to (bs, patch_num, patch_len, n_quantiles)
+            forecasting = rearrange(
+                forecasting, "b 1 (pl q) -> b 1 pl q",
+                pl=self.patch_len, q=self.n_quantiles
+            )
+
+            # Take median quantile (index 4)
+            patch_median = forecasting[:, -1:, :, 4].detach()
+            predictions.append(forecasting[:, -1, :, :])
+
+            # Append median patch for next rollout
+            x = patch_median.clone()
+            x = torch.cat([init_x, x], dim=1)
+
+
+        pred_quantiles = torch.cat(predictions, dim=1)
+        pred_quantiles = pred_quantiles[:, :forecast_horizon, :]
+        pred_median = pred_quantiles[:, :, 4]
+
+        pred_quantiles = pred_quantiles[..., [self.quantiles.index(q) for q in quantiles]] if quantiles is not None else pred_quantiles
+
+        self.clear_cache()
+
+        if torch.any(torch.isnan(pred_median)) or torch.any(torch.isinf(pred_median)):
+            print("Warning: NaN or Inf values detected in predictions. Returning zeros.")
+            pred_median = torch.zeros_like(pred_median)
+            pred_quantiles = torch.zeros_like(pred_quantiles)
+
+        if not batch_dim:
+            pred_median = pred_median.squeeze(0)
+            pred_quantiles = pred_quantiles.squeeze(0)
+
+        pred_median, pred_quantiles = self.converter.deconvert(pred_median, pred_quantiles)
+        return pred_median, pred_quantiles
+
+    def __call__(self, context: torch.Tensor, forecast_horizon: int | None = None, quantiles: list[float] | None = None) -> torch.Tensor:
+        return self.forecast(context, forecast_horizon, quantiles)
+
+    def clear_cache(self):
+        self.revin.clear_cache()
patchfm-1.1.9/src/patchfm/inference/modules.py
ADDED
@@ -0,0 +1,268 @@
+# Modules efficient for inference with caching
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from rotary_embedding_torch import RotaryEmbedding
+from huggingface_hub import PyTorchModelHubMixin
+import numpy as np
+
+class SeqTypeConverter:
+    def __init__(self):
+        self.init_type = None
+
+    def convert(self, seq):
+        if isinstance(seq, torch.Tensor):
+            self.init_type = 'torch'
+            return seq
+
+        elif isinstance(seq, np.ndarray):
+            self.init_type = 'numpy'
+            return torch.from_numpy(seq)
+
+        elif isinstance(seq, list):
+            if all(isinstance(x, torch.Tensor) for x in seq):
+                self.init_type = 'list_of_tensors'
+                try:
+                    return torch.stack(seq)
+                except Exception:
+                    raise ValueError("All tensors in the list must have the same shape to stack.")
+            else:
+                self.init_type = 'list'
+                return torch.tensor(seq)
+
+        else:
+            raise ValueError(f"Unsupported type: {type(seq)}")
+
+    def deconvert(self, seq, quantiles):
+        seq = seq.detach().cpu()
+        quantiles = quantiles.detach().cpu()
+
+        if self.init_type == 'torch':
+            return self._ensure_torch(seq), self._ensure_torch(quantiles)
+
+        elif self.init_type == 'numpy':
+            return self._ensure_numpy(seq), self._ensure_numpy(quantiles)
+
+        elif self.init_type == 'list':
+            return seq.tolist(), quantiles.tolist()
+
+        elif self.init_type == 'list_of_tensors':
+            seqs = list(seq.unbind(0))
+            quants = list(quantiles.unbind(0))
+            return seqs, quants
+
+        else:
+            raise ValueError(f"Unsupported type: {self.init_type}")
+
+    def _ensure_torch(self, x):
+        return x if isinstance(x, torch.Tensor) else torch.tensor(x)
+
+    def _ensure_numpy(self, x):
+        return x if isinstance(x, np.ndarray) else np.array(x)
+
+def fill_nan_with_last_observed(x):
+    bs, pn, pl = x.size()
+    x = rearrange(x, "b pn pl -> (b pn) pl")
+    valid_mask = ~torch.isnan(x)
+    x_temp = torch.where(valid_mask, x, torch.zeros_like(x))
+    seq_indices = torch.arange(x.size(-1), device=x.device).unsqueeze(0)
+    valid_indices = torch.where(valid_mask, seq_indices, torch.tensor(-1, device=x.device))
+    last_valid_idx = torch.cummax(valid_indices, dim=-1)[0]
+    x = x_temp.gather(-1, torch.clamp(last_valid_idx, min=0))
+    x = rearrange(x, "(b pn) pl -> b pn pl", b=bs)
+    return x
+
+def nanstd(o, dim, keepdim=False):
+    m = torch.nanmean(o, dim=dim, keepdim=True)
+    sq = (o - m) ** 2
+    n = torch.sum(~torch.isnan(o), dim=dim, keepdim=True).float()
+    n_safe = torch.clamp(n - 1, min=1.0)
+    var = torch.nansum(sq, dim=dim, keepdim=True) / n_safe
+    std = torch.sqrt(var)
+    if not keepdim:
+        std = std.squeeze(dim)
+    return std
+
+class RevIN(nn.Module):
+    def __init__(self, eps=1e-5):
+        super().__init__()
+        self.eps = eps
+        self.cached_mean = None
+        self.cached_std = None
+
+    def forward(self, x, mode: str):
+        assert x.dim() == 3, "Input tensor must be (batch, n_patches, patch_len)"
+
+        if mode == "norm":
+            mean, std = self._get_statistics(x)
+            self.cached_mean, self.cached_std = mean.detach(), std.detach()
+            out = (x - mean) / std
+            out = torch.asinh(out)
+
+            if torch.isnan(out).any():
+                out = fill_nan_with_last_observed(out)
+
+        elif mode == "denorm":
+            assert self.cached_mean is not None and self.cached_std is not None, \
+                "Call forward(..., 'norm') before 'denorm'"
+            out = torch.sinh(x) * self.cached_std + self.cached_mean
+
+        else:
+            raise NotImplementedError(f"Mode '{mode}' not implemented.")
+        return out
+
+    def _get_statistics(self, x):
+
+        if not x.isnan().any():
+            mean = x.mean(dim=(-1, -2), keepdim=True)
+            std = x.std(dim=(-1, -2), keepdim=True) + self.eps
+
+        else:
+            mean = x.nanmean(dim=(-1, -2), keepdim=True)
+            std = nanstd(x, dim=(-1, -2), keepdim=True) + self.eps
+
+        return mean, std
+
+    def clear_cache(self):
+        self.cached_mean = None
+        self.cached_std = None
+
+class ResidualBlock(nn.Module):
+    def __init__(self, in_dim, hid_dim, out_dim):
+        super().__init__()
+        self.hidden_layer = nn.Linear(in_dim, hid_dim)
+        self.output_layer = nn.Linear(hid_dim, out_dim)
+        self.residual_layer = nn.Linear(in_dim, out_dim)
+        self.act = nn.ReLU()
+
+    def forward(self, x):
+        hid = self.act(self.hidden_layer(x))
+        out = self.output_layer(hid)
+        res = self.residual_layer(x)
+        out = out + res
+        return out
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, d_model, n_heads, last=False):
+        super().__init__()
+        assert d_model % n_heads == 0, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
+
+        self.WQ = nn.Linear(d_model, d_model)
+        self.WK = nn.Linear(d_model, d_model)
+        self.WV = nn.Linear(d_model, d_model)
+
+        self.out_proj = nn.Linear(d_model, d_model)
+
+        self.head_dim = d_model // n_heads
+        self.n_heads = n_heads
+
+        self.rope = RotaryEmbedding(dim=self.head_dim // 2)
+
+        self.last = last
+
+    def forward(self, q):
+        bs, context, dim = q.size()
+        offset = 0
+        is_causal = True
+
+        k = q
+        v = q
+
+        if self.last:
+            q = q[:, -1:, :]
+            is_causal = False
+            offset += (context - 1)
+
+        q = self.WQ(q).reshape(bs, -1, self.n_heads, self.head_dim).transpose(1, 2)
+        k = self.WK(k).reshape(bs, -1, self.n_heads, self.head_dim).transpose(1, 2)
+        v = self.WV(v).reshape(bs, -1, self.n_heads, self.head_dim).transpose(1, 2)
+
+        q = self.rope.rotate_queries_or_keys(q, offset=offset)
+        k = self.rope.rotate_queries_or_keys(k)
+
+        values = nn.functional.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
+
+        values = values.transpose(1, 2).reshape(bs, -1, dim)
+        values = self.out_proj(values)
+        return values
+
+class FeedForward(nn.Module):
+    def __init__(self, d_model, multiple_of=256):
+        super().__init__()
+
+        hidden_dim = d_model * 4
+        hidden_dim = int(2 * hidden_dim / 3)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, d_model, bias=False)
+        self.w3 = nn.Linear(d_model, hidden_dim, bias=False)
+
+        self.act = nn.SiLU()
+
+    def forward(self, x):
+        x = self.w2(self.act(self.w1(x)) * self.w3(x))
+        return x
+
+
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self, d_model, n_heads, last=False):
+        super().__init__()
+        self.ln1 = nn.LayerNorm(d_model)
+        self.attn = MultiHeadAttention(d_model=d_model, n_heads=n_heads, last=last)
+        self.ln2 = nn.LayerNorm(d_model)
+        self.ff = FeedForward(d_model=d_model)
+
+    def forward(self, x):
+        out_attn = self.attn(self.ln1(x))
+        x = x + out_attn
+        out = x + self.ff(self.ln2(x))
+        return out
+
+class TransformerEncoder(nn.Module):
+    def __init__(self, d_model, n_heads, n_layers):
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [
+                TransformerEncoderLayer(d_model=d_model, n_heads=n_heads)
+                for _ in range(n_layers - 1)
+            ]
+        )
+        self.layers.append(TransformerEncoderLayer(d_model=d_model, n_heads=n_heads, last=True))
+        self.norm = nn.LayerNorm(d_model)
+
+    def forward(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        return self.norm(x)
+
+class PatchFM(nn.Module, PyTorchModelHubMixin):
+    def __init__(self, config):
+        super().__init__()
+
+        # Store config
+        self.patch_len = config["patch_len"]
+        self.d_model = config["d_model"]
+        self.n_heads = config["n_heads"]
+        self.n_layers_encoder = config["n_layers_encoder"]
+        self.quantiles = config["quantiles"]
+        self.n_quantiles = len(self.quantiles)
+
+        # Components
+        self.revin = RevIN()
+        self.proj_embedding = ResidualBlock(
+            in_dim=self.patch_len,
+            hid_dim=2 * self.patch_len,
+            out_dim=self.d_model
+        )
+        self.transformer_encoder = TransformerEncoder(
+            d_model=self.d_model,
+            n_heads=self.n_heads,
+            n_layers=self.n_layers_encoder
+        )
+        self.proj_output = ResidualBlock(
+            in_dim=self.d_model,
+            hid_dim=2 * self.d_model,
+            out_dim=self.patch_len * self.n_quantiles
+        )
@@ -0,0 +1,128 @@
Metadata-Version: 2.4
Name: patchfm
Version: 1.1.9
Summary: a Foundation Model for Univariate Time Series Forecasting
Author-email: Samy-Melwan Vilhes <samy-melwan.vilhes@insa-rouen.fr>
License: MIT License

Copyright (c) 2025 Samy-Melwan Vilhes

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Project-URL: Repository, https://github.com/vilhess/PatchFM
Project-URL: Issues, https://github.com/vilhess/PatchFM/issues
Keywords: Transformer,LLM,Time Series,Zero-shot,Deep Learning
Classifier: Programming Language :: Python :: 3
Classifier: Operating System :: OS Independent
Requires-Python: >=3.11
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.5.0
Requires-Dist: einops>=0.8.1
Requires-Dist: huggingface-hub>=0.35.1
Requires-Dist: rotary-embedding-torch>=0.8.9
Requires-Dist: numpy>=1.26.0
Requires-Dist: safetensors==0.5.3
Dynamic: license-file

|
|
44
|
+
|
|
45
|
+
[Huggingface Model Card](https://huggingface.co/vilhess/PatchFM)
|
|
46
|
+
|
|
47
|
+
A transformer-based forecasting model for univariate time series. The approach mirrors Large Language Model (LLM) practices (next-token → next-patch) while remaining lightweight compared to a classic LLM and practical.
|
|
48
|
+
|
|
49
|
+
## Highlights
|
|
50
|
+
- Next-patch prediction objective (autoregressive, causal)
|
|
51
|
+
- Patch-based representation of time series (tokens ↔ patches)
|
|
52
|
+
- Causal masking self-attention with RoPE (relative positions)
|
|
53
|
+
- RevIN (Reversible Instance Normalization) with causal statistics
|
|
54
|
+
- SwiGLU feed-forward networks
|
|
55
|
+
- Multi-quantile outputs (median + uncertainty bands)
|
|
56
|
+
- Efficient rollout with KV caching
|
|
57
|
+
|
|
58
|
+
## Installation
|
|
59
|
+
```bash
|
|
60
|
+
pip install patchfm
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
## Quick Start
|
|
64
|
+
|
|
65
|
+
```python
|
|
66
|
+
import torch
|
|
67
|
+
from patchfm import PatchFMConfig, Forecaster
|
|
68
|
+
|
|
69
|
+
# --- Instantiate model ---
|
|
70
|
+
config = PatchFMConfig()
|
|
71
|
+
model = Forecaster(config)
|
|
72
|
+
|
|
73
|
+
# --- Inference ---
|
|
74
|
+
forecast_horizon = 64
|
|
75
|
+
seq = torch.randn(1, 1024) # (batch, time)
|
|
76
|
+
pred_median, pred_quantiles = model(seq, forecast_horizon=forecast_horizon, quantiles=[0.1, 0.5, 0.9]) # (batch, forecast_horizon), (batch, forecast_horizon, quantiles)
|
|
77
|
+
```
|
|
78
|
+
|
|
79
|
+
We provide an extended quick start example in [notebooks/tutorial.ipynb](./notebooks/tutorial.ipynb).
|
|
80
|
+
If you dont have suitable hardware you can run the the extended quick start example example also in Google Colab:
|
|
81
|
+
|
|
82
|
+
<a target="_blank" href="https://colab.research.google.com/drive/17sdf-7luCkv5TaeLj3Z6kIaTDkwkz3VR?usp=share_link">
|
|
83
|
+
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open Quick Start In Colab"/>
|
|
84
|
+
</a>
|
|
85
|
+
|
|
86
|
+
## Method (TL;DR)
- Patching: Split a context signal of length $w$ into $P_{num} = w / P_{len}$ patches of length $P_{len}$.
- RevIN: Normalize patches using causal running mean/variance over past patches, and denormalize outputs to the original scale.
- Architecture: Input residual MLP → stacked Transformer blocks (MHA + SwiGLU FFN, pre-norm, residual) → $|\mathcal{Q}|$ output heads mapping back to patch space.
- Positional encoding: Rotary Position Embeddings (RoPE) applied to queries/keys.
- Training: Multi-quantile (pinball) loss across positions, elements, and quantiles $\mathcal{Q}$.
- Inference: Predict next patch; roll out autoregressively with KV caching for long horizons.
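
The patching step above is just a reshape. A minimal sketch (the function name `patchify` is illustrative, not part of the package API):

```python
import torch

def patchify(series: torch.Tensor, patch_len: int) -> torch.Tensor:
    """Split a (batch, time) series into (batch, n_patches, patch_len) patches.
    Assumes time is divisible by patch_len."""
    batch, time = series.shape
    assert time % patch_len == 0, "context length must be a multiple of patch_len"
    return series.reshape(batch, time // patch_len, patch_len)

seq = torch.randn(2, 1024)           # w = 1024
patches = patchify(seq, patch_len=32)
print(patches.shape)                 # torch.Size([2, 32, 32]) -> P_num = 1024 / 32
```

Each patch then plays the role a token plays in an LLM: the model reads patches 1..i and predicts patch i+1.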

## Problem Formulation
Given context patches $x_{p_1}, \ldots, x_{p_n}$, predict the next patch $x_{p_{i+1}}$ for each position $i$ using only past patches (causality). The model outputs quantiles $\{\hat{x}_{p_{i+1}}^{(q)}: q \in \mathcal{Q}\}$ with the median ($q = 0.5$) as the point forecast.

## Loss: Multi-Quantile (Pinball)
For residual $u = x - \hat{x}^{(q)}$:
$$\rho_q(u) = \begin{cases} q\,u, & u \ge 0,\\ (q-1)\,u, & u < 0. \end{cases}$$
Aggregate over positions, patch elements, and quantiles.
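
The aggregation can be sketched in a few lines of tensor code (a hedged sketch, not the package's training loop; the tensor layout is an assumption):

```python
import torch

def pinball_loss(target: torch.Tensor, pred: torch.Tensor, quantiles) -> torch.Tensor:
    """Multi-quantile (pinball) loss.
    target: (batch, n_patches, patch_len)
    pred:   (batch, n_patches, patch_len, n_quantiles)
    """
    q = torch.tensor(quantiles).view(1, 1, 1, -1)   # broadcast over all positions
    u = target.unsqueeze(-1) - pred                 # residual u = x - x_hat^(q)
    loss = torch.where(u >= 0, q * u, (q - 1) * u)  # rho_q(u), both branches >= 0
    return loss.mean()                              # mean over positions, elements, quantiles

target = torch.zeros(1, 2, 4)
pred = torch.zeros(1, 2, 4, 3)
print(pinball_loss(target, pred, [0.1, 0.5, 0.9]).item())  # 0.0 (perfect prediction)
```

Note that for $q = 0.5$ the pinball loss reduces to half the absolute error, which is why the median head doubles as the point forecast.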

## Architecture
- Input MLP: $\mathbb{R}^{P_{len}} \to \mathbb{R}^{dim}$ residual 2-layer MLP (ReLU)
- Multi-Head Attention: causal mask, RoPE; queries/keys/values per head
- FFN: SwiGLU (SiLU-gated), pre-norm + residual
- Output heads: $|\mathcal{Q}|$ linear maps $\mathbb{R}^{dim} \to \mathbb{R}^{P_{len}}$ (one per quantile)
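
The $|\mathcal{Q}|$ heads can equivalently be realized as one projection to $P_{len} \cdot |\mathcal{Q}|$ followed by a reshape, as the package's `PatchFM` module does with `proj_output`. A self-contained sketch (`nn.Linear` stands in for the package's `ResidualBlock`, and the reshape order is an assumption):

```python
import torch
import torch.nn as nn

d_model, patch_len, quantiles = 16, 8, [0.1, 0.5, 0.9]

# Single map to patch_len * |Q|, then split into one patch per quantile.
head = nn.Linear(d_model, patch_len * len(quantiles))

h = torch.randn(2, 5, d_model)                            # (batch, n_patches, d_model)
out = head(h).reshape(2, 5, patch_len, len(quantiles))    # (batch, n_patches, patch_len, |Q|)
print(out.shape)  # torch.Size([2, 5, 8, 3])
```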

### Model Details
- Patch size: 32
- Max context: 32 patches (1024 steps)
- Forecast horizon: 32 steps per forward pass
- Quantiles $\mathcal{Q}$: {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}
- Layers: 6
- Attention heads: 64 (head dim 32)
- Model dim: 2048
- Parameters: ~300M

## Inference
- Single step: predict next patch ($P_{len}$ values)
- Long-horizon: append prediction to context and repeat (optionally drop oldest patch to keep window fixed)
- KV caching: reuse cached keys/values for past patches; compute new Q/K/V only for the appended patch
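
The rollout loop can be sketched as follows (a simplified sketch: `model` is assumed to return the median next patch, and KV caching is omitted for clarity):

```python
import torch

def rollout(model, context: torch.Tensor, horizon: int, patch_len: int = 32,
            max_context: int = 1024) -> torch.Tensor:
    """Autoregressive rollout: predict one patch, append it to the context, repeat.
    `model(context)` is assumed to return the median next patch, shape (batch, patch_len)."""
    preds = []
    for _ in range(horizon // patch_len):
        next_patch = model(context)                                          # (batch, patch_len)
        preds.append(next_patch)
        context = torch.cat([context, next_patch], dim=1)[:, -max_context:]  # slide the window
    return torch.cat(preds, dim=1)                                           # (batch, horizon)

# Toy stand-in model that simply repeats the last patch.
toy = lambda ctx: ctx[:, -32:]
out = rollout(toy, torch.randn(1, 1024), horizon=64)
print(out.shape)  # torch.Size([1, 64])
```

With KV caching, each iteration only computes Q/K/V for the newly appended patch and reuses the cached keys/values for the rest of the window.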

## Acknowledgements
We thank the authors of the following repositories for inspiration and code snippets:
- [TiRex](https://github.com/NX-AI/tirex)

## Citation
If you use this work, please cite the paper ...
@@ -0,0 +1,12 @@
LICENSE
README.md
pyproject.toml
src/patchfm/__init__.py
src/patchfm.egg-info/PKG-INFO
src/patchfm.egg-info/SOURCES.txt
src/patchfm.egg-info/dependency_links.txt
src/patchfm.egg-info/requires.txt
src/patchfm.egg-info/top_level.txt
src/patchfm/configs/model_config.py
src/patchfm/inference/forecaster.py
src/patchfm/inference/modules.py
@@ -0,0 +1 @@
@@ -0,0 +1 @@
patchfm