ezmsg-learn 1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/learn/__init__.py +2 -0
- ezmsg/learn/__version__.py +34 -0
- ezmsg/learn/dim_reduce/__init__.py +0 -0
- ezmsg/learn/dim_reduce/adaptive_decomp.py +284 -0
- ezmsg/learn/dim_reduce/incremental_decomp.py +181 -0
- ezmsg/learn/linear_model/__init__.py +1 -0
- ezmsg/learn/linear_model/adaptive_linear_regressor.py +6 -0
- ezmsg/learn/linear_model/cca.py +1 -0
- ezmsg/learn/linear_model/linear_regressor.py +5 -0
- ezmsg/learn/linear_model/sgd.py +5 -0
- ezmsg/learn/linear_model/slda.py +6 -0
- ezmsg/learn/model/__init__.py +0 -0
- ezmsg/learn/model/cca.py +122 -0
- ezmsg/learn/model/mlp.py +133 -0
- ezmsg/learn/model/mlp_old.py +49 -0
- ezmsg/learn/model/refit_kalman.py +401 -0
- ezmsg/learn/model/rnn.py +160 -0
- ezmsg/learn/model/transformer.py +175 -0
- ezmsg/learn/nlin_model/__init__.py +1 -0
- ezmsg/learn/nlin_model/mlp.py +6 -0
- ezmsg/learn/process/__init__.py +0 -0
- ezmsg/learn/process/adaptive_linear_regressor.py +157 -0
- ezmsg/learn/process/base.py +173 -0
- ezmsg/learn/process/linear_regressor.py +99 -0
- ezmsg/learn/process/mlp_old.py +200 -0
- ezmsg/learn/process/refit_kalman.py +407 -0
- ezmsg/learn/process/rnn.py +266 -0
- ezmsg/learn/process/sgd.py +131 -0
- ezmsg/learn/process/sklearn.py +274 -0
- ezmsg/learn/process/slda.py +119 -0
- ezmsg/learn/process/torch.py +378 -0
- ezmsg/learn/process/transformer.py +222 -0
- ezmsg/learn/util.py +66 -0
- ezmsg_learn-1.0.dist-info/METADATA +34 -0
- ezmsg_learn-1.0.dist-info/RECORD +36 -0
- ezmsg_learn-1.0.dist-info/WHEEL +4 -0
ezmsg/learn/model/rnn.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class RNNModel(torch.nn.Module):
|
|
7
|
+
"""
|
|
8
|
+
Recurrent neural network supporting GRU, LSTM, and vanilla RNN (tanh/relu).
|
|
9
|
+
|
|
10
|
+
Attributes:
|
|
11
|
+
input_size (int): Number of input features per time step.
|
|
12
|
+
hidden_size (int): Number of hidden units in the RNN cell.
|
|
13
|
+
num_layers (int, optional): Number of RNN layers. Default is 1.
|
|
14
|
+
output_size (int | dict[str, int], optional): Number of output features or classes if single head output or a
|
|
15
|
+
dictionary mapping head names to output sizes if multi-head output. Default is 2 (single head).
|
|
16
|
+
dropout (float, optional): Dropout rate applied after input and RNN output. Default is 0.3.
|
|
17
|
+
rnn_type (str, optional): Type of RNN cell to use: 'GRU', 'LSTM', 'RNN-Tanh', 'RNN-ReLU'. Default is 'GRU'.
|
|
18
|
+
|
|
19
|
+
Returns:
|
|
20
|
+
dict[str, torch.Tensor]: Dictionary of decoded predictions mapping head names to tensors of shape
|
|
21
|
+
(batch, seq_len, output_size). If single head output, the key is "output".
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
def __init__(
|
|
25
|
+
self,
|
|
26
|
+
input_size: int,
|
|
27
|
+
hidden_size: int,
|
|
28
|
+
num_layers: int = 1,
|
|
29
|
+
output_size: int | dict[str, int] = 2,
|
|
30
|
+
dropout: float = 0.3,
|
|
31
|
+
rnn_type: str = "GRU",
|
|
32
|
+
):
|
|
33
|
+
super().__init__()
|
|
34
|
+
self.linear_embeddings = torch.nn.Linear(input_size, input_size)
|
|
35
|
+
self.dropout_input = torch.nn.Dropout(dropout)
|
|
36
|
+
|
|
37
|
+
rnn_klass_str = rnn_type.upper().split("-")[0]
|
|
38
|
+
if rnn_klass_str not in ["GRU", "LSTM", "RNN"]:
|
|
39
|
+
raise ValueError(f"Unrecognized rnn_type: {rnn_type}")
|
|
40
|
+
rnn_klass = {"GRU": torch.nn.GRU, "LSTM": torch.nn.LSTM, "RNN": torch.nn.RNN}[
|
|
41
|
+
rnn_klass_str
|
|
42
|
+
]
|
|
43
|
+
rnn_kwargs = {}
|
|
44
|
+
if rnn_klass_str == "RNN":
|
|
45
|
+
rnn_kwargs["nonlinearity"] = rnn_type.lower().split("-")[-1]
|
|
46
|
+
self.rnn = rnn_klass(
|
|
47
|
+
input_size,
|
|
48
|
+
hidden_size,
|
|
49
|
+
num_layers,
|
|
50
|
+
batch_first=True,
|
|
51
|
+
dropout=dropout if num_layers > 1 else 0.0,
|
|
52
|
+
**rnn_kwargs,
|
|
53
|
+
)
|
|
54
|
+
self.rnn_type = rnn_klass_str
|
|
55
|
+
|
|
56
|
+
self.output_dropout = torch.nn.Dropout(dropout)
|
|
57
|
+
if isinstance(output_size, int):
|
|
58
|
+
output_size = {"output": output_size}
|
|
59
|
+
self.heads = torch.nn.ModuleDict(
|
|
60
|
+
{
|
|
61
|
+
name: torch.nn.Linear(hidden_size, size)
|
|
62
|
+
for name, size in output_size.items()
|
|
63
|
+
}
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
@classmethod
|
|
67
|
+
def infer_config_from_state_dict(
|
|
68
|
+
cls, state_dict: dict, rnn_type: str = "GRU"
|
|
69
|
+
) -> dict[str, int | float]:
|
|
70
|
+
"""
|
|
71
|
+
This method is specific to each processor.
|
|
72
|
+
|
|
73
|
+
Args:
|
|
74
|
+
state_dict: The state dict of the model.
|
|
75
|
+
rnn_type: The type of RNN used in the model (e.g., 'GRU', 'LSTM', 'RNN-Tanh', 'RNN-ReLU').
|
|
76
|
+
|
|
77
|
+
Returns:
|
|
78
|
+
A dictionary of model parameters obtained from the state dict.
|
|
79
|
+
|
|
80
|
+
"""
|
|
81
|
+
output_size = {
|
|
82
|
+
key.split(".")[1]: param.shape[0]
|
|
83
|
+
for key, param in state_dict.items()
|
|
84
|
+
if key.startswith("heads.") and key.endswith(".bias")
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
return {
|
|
88
|
+
# Infer input_size from linear_embeddings.weight (shape: [input_size, input_size])
|
|
89
|
+
"input_size": state_dict["linear_embeddings.weight"].shape[1],
|
|
90
|
+
# Infer hidden_size from rnn.weight_ih_l0 (shape: [hidden_size * 3, input_size])
|
|
91
|
+
"hidden_size": state_dict["rnn.weight_ih_l0"].shape[0]
|
|
92
|
+
// cls._get_gate_count(rnn_type),
|
|
93
|
+
# Infer num_layers by counting rnn layers in state_dict (e.g., weight_ih_l<k>)
|
|
94
|
+
"num_layers": sum(1 for key in state_dict if "rnn.weight_ih_l" in key),
|
|
95
|
+
"output_size": output_size,
|
|
96
|
+
}
|
|
97
|
+
|
|
98
|
+
@staticmethod
|
|
99
|
+
def _get_gate_count(rnn_type: str) -> int:
|
|
100
|
+
if rnn_type.upper() == "GRU":
|
|
101
|
+
return 3
|
|
102
|
+
elif rnn_type.upper() == "LSTM":
|
|
103
|
+
return 4
|
|
104
|
+
elif rnn_type.upper().startswith("RNN"):
|
|
105
|
+
return 1
|
|
106
|
+
else:
|
|
107
|
+
raise ValueError(f"Unsupported rnn_type: {rnn_type}")
|
|
108
|
+
|
|
109
|
+
def init_hidden(self, batch_size: int, device: torch.device) -> torch.Tensor:
|
|
110
|
+
"""
|
|
111
|
+
Initialize the hidden state for the RNN.
|
|
112
|
+
Args:
|
|
113
|
+
batch_size (int): Size of the batch.
|
|
114
|
+
device (torch.device): Device to place the hidden state on (e.g., 'cpu' or 'cuda').
|
|
115
|
+
Returns:
|
|
116
|
+
torch.Tensor | tuple[torch.Tensor, torch.Tensor]: Initial hidden state for the RNN.
|
|
117
|
+
For LSTM, returns a tuple of (h_n, c_n) where h_n is the hidden state and c_n is the cell state.
|
|
118
|
+
For GRU or vanilla RNN, returns just h_n.
|
|
119
|
+
"""
|
|
120
|
+
shape = (self.rnn.num_layers, batch_size, self.rnn.hidden_size)
|
|
121
|
+
if self.rnn_type == "LSTM":
|
|
122
|
+
return (
|
|
123
|
+
torch.zeros(shape, device=device),
|
|
124
|
+
torch.zeros(shape, device=device),
|
|
125
|
+
)
|
|
126
|
+
else:
|
|
127
|
+
return torch.zeros(shape, device=device)
|
|
128
|
+
|
|
129
|
+
def forward(
|
|
130
|
+
self,
|
|
131
|
+
x: torch.Tensor,
|
|
132
|
+
input_lens: Optional[torch.Tensor] = None,
|
|
133
|
+
hx: Optional[torch.Tensor | tuple[torch.Tensor, torch.Tensor]] = None,
|
|
134
|
+
) -> tuple[dict[str, torch.Tensor], torch.Tensor | tuple]:
|
|
135
|
+
"""
|
|
136
|
+
Forward pass through the RNN model.
|
|
137
|
+
Args:
|
|
138
|
+
x (torch.Tensor): Input tensor of shape (batch, seq_len, input_size).
|
|
139
|
+
input_lens (Optional[torch.Tensor]): Optional tensor of lengths for each sequence in the batch.
|
|
140
|
+
If provided, sequences will be packed before passing through the RNN.
|
|
141
|
+
hx (Optional[torch.Tensor | tuple[torch.Tensor, torch.Tensor]]): Optional initial hidden state for the RNN.
|
|
142
|
+
Returns:
|
|
143
|
+
tuple[dict[str, torch.Tensor], torch.Tensor | tuple]:
|
|
144
|
+
A dictionary mapping head names to output tensors of shape (batch, seq_len, output_size).
|
|
145
|
+
If the RNN is LSTM, the second element is the hidden state (h_n, c_n) or just h_n if GRU.
|
|
146
|
+
"""
|
|
147
|
+
x = self.linear_embeddings(x)
|
|
148
|
+
x = self.dropout_input(x)
|
|
149
|
+
total_length = x.shape[1]
|
|
150
|
+
if input_lens is not None:
|
|
151
|
+
x = torch.nn.utils.rnn.pack_padded_sequence(
|
|
152
|
+
x, input_lens, batch_first=True, enforce_sorted=False
|
|
153
|
+
)
|
|
154
|
+
x_out, hx_out = self.rnn(x, hx)
|
|
155
|
+
if input_lens is not None:
|
|
156
|
+
x_out, _ = torch.nn.utils.rnn.pad_packed_sequence(
|
|
157
|
+
x_out, batch_first=True, total_length=total_length
|
|
158
|
+
)
|
|
159
|
+
x_out = self.output_dropout(x_out)
|
|
160
|
+
return {name: head(x_out) for name, head in self.heads.items()}, hx_out
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TransformerModel(torch.nn.Module):
|
|
7
|
+
"""
|
|
8
|
+
Transformer-based encoder (optional decoder) neural network.
|
|
9
|
+
|
|
10
|
+
If `decoder_layers > 0`, the model includes a Transformer decoder. In this case, the `tgt` argument must be
|
|
11
|
+
provided: during training, it is typically the ground-truth target sequence (i.e. teacher forcing); during
|
|
12
|
+
inference, it can be constructed autoregressively from previous predictions.
|
|
13
|
+
|
|
14
|
+
Attributes:
|
|
15
|
+
input_size (int): Number of input features per time step.
|
|
16
|
+
hidden_size (int): Dimensionality of the transformer model.
|
|
17
|
+
encoder_layers (int, optional): Number of transformer encoder layers. Default is 1.
|
|
18
|
+
decoder_layers (int, optional): Number of transformer decoder layers. Default is 0.
|
|
19
|
+
output_size (int | dict[str, int], optional): Number of output features or classes if single head output, or a
|
|
20
|
+
dictionary mapping head names to output sizes if multi-head output. Default is 2 (single head).
|
|
21
|
+
dropout (float, optional): Dropout rate applied after input and transformer output. Default is 0.3.
|
|
22
|
+
attention_heads (int, optional): Number of attention heads in the transformer. Default is 4.
|
|
23
|
+
max_seq_len (int, optional): Maximum sequence length for positional embeddings. Default is 512.
|
|
24
|
+
|
|
25
|
+
Returns:
|
|
26
|
+
dict[str, torch.Tensor]: Dictionary of decoded predictions mapping head names to tensors of shape
|
|
27
|
+
(batch, seq_len, output_size). If single head output, the key is "output".
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
def __init__(
|
|
31
|
+
self,
|
|
32
|
+
input_size: int,
|
|
33
|
+
hidden_size: int,
|
|
34
|
+
encoder_layers: int = 1,
|
|
35
|
+
decoder_layers: int = 0,
|
|
36
|
+
output_size: int | dict[str, int] = 2,
|
|
37
|
+
dropout: float = 0.3,
|
|
38
|
+
attention_heads: int = 4,
|
|
39
|
+
max_seq_len: int = 512,
|
|
40
|
+
autoregressive_head: str | None = None,
|
|
41
|
+
):
|
|
42
|
+
super().__init__()
|
|
43
|
+
|
|
44
|
+
self.decoder_layers = decoder_layers
|
|
45
|
+
self.hidden_size = hidden_size
|
|
46
|
+
|
|
47
|
+
if isinstance(output_size, int):
|
|
48
|
+
autoregressive_size = output_size
|
|
49
|
+
else:
|
|
50
|
+
autoregressive_size = list(output_size.values())[0]
|
|
51
|
+
if isinstance(output_size, dict):
|
|
52
|
+
autoregressive_size = output_size.get(
|
|
53
|
+
autoregressive_head, autoregressive_size
|
|
54
|
+
)
|
|
55
|
+
self.start_token = torch.nn.Parameter(torch.zeros(1, 1, autoregressive_size))
|
|
56
|
+
self.output_to_hidden = torch.nn.Linear(autoregressive_size, hidden_size)
|
|
57
|
+
|
|
58
|
+
self.input_proj = torch.nn.Linear(input_size, hidden_size)
|
|
59
|
+
self.pos_embedding = torch.nn.Embedding(max_seq_len, hidden_size)
|
|
60
|
+
self.dropout = torch.nn.Dropout(dropout)
|
|
61
|
+
|
|
62
|
+
self.encoder = torch.nn.TransformerEncoder(
|
|
63
|
+
torch.nn.TransformerEncoderLayer(
|
|
64
|
+
d_model=hidden_size,
|
|
65
|
+
nhead=attention_heads,
|
|
66
|
+
dim_feedforward=hidden_size * 4,
|
|
67
|
+
dropout=dropout,
|
|
68
|
+
batch_first=True,
|
|
69
|
+
),
|
|
70
|
+
num_layers=encoder_layers,
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
self.decoder = None
|
|
74
|
+
if decoder_layers > 0:
|
|
75
|
+
self.decoder = torch.nn.TransformerDecoder(
|
|
76
|
+
torch.nn.TransformerDecoderLayer(
|
|
77
|
+
d_model=hidden_size,
|
|
78
|
+
nhead=attention_heads,
|
|
79
|
+
dim_feedforward=hidden_size * 4,
|
|
80
|
+
dropout=dropout,
|
|
81
|
+
batch_first=True,
|
|
82
|
+
),
|
|
83
|
+
num_layers=decoder_layers,
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
if isinstance(output_size, int):
|
|
87
|
+
output_size = {"output": output_size}
|
|
88
|
+
self.heads = torch.nn.ModuleDict(
|
|
89
|
+
{
|
|
90
|
+
name: torch.nn.Linear(hidden_size, out_dim)
|
|
91
|
+
for name, out_dim in output_size.items()
|
|
92
|
+
}
|
|
93
|
+
)
|
|
94
|
+
|
|
95
|
+
@classmethod
|
|
96
|
+
def infer_config_from_state_dict(cls, state_dict: dict) -> dict[str, int | float]:
|
|
97
|
+
# Infer output size from heads.<name>.bias (shape: [output_size])
|
|
98
|
+
output_size = {
|
|
99
|
+
key.split(".")[1]: param.shape[0]
|
|
100
|
+
for key, param in state_dict.items()
|
|
101
|
+
if key.startswith("heads.") and key.endswith(".bias")
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
return {
|
|
105
|
+
# Infer input_size from input_proj.weight (shape: [hidden_size, input_size])
|
|
106
|
+
"input_size": state_dict["input_proj.weight"].shape[1],
|
|
107
|
+
# Infer hidden_size from input_proj.weight (shape: [hidden_size, input_size])
|
|
108
|
+
"hidden_size": state_dict["input_proj.weight"].shape[0],
|
|
109
|
+
"output_size": output_size,
|
|
110
|
+
# Infer encoder_layers from transformer layers in state_dict
|
|
111
|
+
"encoder_layers": len(
|
|
112
|
+
[k for k in state_dict if k.startswith("encoder.layers")]
|
|
113
|
+
),
|
|
114
|
+
# Infer decoder_layers from transformer decoder layers in state_dict
|
|
115
|
+
"decoder_layers": len(
|
|
116
|
+
{k.split(".")[2] for k in state_dict if k.startswith("decoder.layers")}
|
|
117
|
+
)
|
|
118
|
+
if any(k.startswith("decoder.layers") for k in state_dict)
|
|
119
|
+
else 0,
|
|
120
|
+
}
|
|
121
|
+
|
|
122
|
+
def forward(
|
|
123
|
+
self,
|
|
124
|
+
src: torch.Tensor,
|
|
125
|
+
tgt: Optional[torch.Tensor] = None,
|
|
126
|
+
src_mask: Optional[torch.Tensor] = None,
|
|
127
|
+
tgt_mask: Optional[torch.Tensor] = None,
|
|
128
|
+
start_pos: int = 0,
|
|
129
|
+
) -> dict[str, torch.Tensor]:
|
|
130
|
+
"""
|
|
131
|
+
Forward pass through the transformer model.
|
|
132
|
+
Args:
|
|
133
|
+
src (torch.Tensor): Input tensor of shape (batch, seq_len, input_size).
|
|
134
|
+
tgt (Optional[torch.Tensor]): Target tensor for decoder, shape (batch, seq_len, input_size).
|
|
135
|
+
Required if `decoder_layers > 0`. In training, this can be the ground-truth target sequence
|
|
136
|
+
(i.e. teacher forcing). During inference, this is constructed autoregressively.
|
|
137
|
+
src_mask (Optional[torch.Tensor]): Optional attention mask for the encoder input. Should be broadcastable
|
|
138
|
+
to shape (batch, seq_len, seq_len) or (seq_len, seq_len).
|
|
139
|
+
tgt_mask (Optional[torch.Tensor]): Optional attention mask for the decoder input. Used to enforce causal
|
|
140
|
+
decoding (i.e. autoregressive generation) during training or inference.
|
|
141
|
+
start_pos (int): Starting offset for positional embeddings. Used for streaming inference to maintain
|
|
142
|
+
correct positional indices. Default is 0.
|
|
143
|
+
Returns:
|
|
144
|
+
dict[str, torch.Tensor]: Dictionary of output tensors each output head, each with shape (batch, seq_len,
|
|
145
|
+
output_size).
|
|
146
|
+
"""
|
|
147
|
+
B, T, _ = src.shape
|
|
148
|
+
device = src.device
|
|
149
|
+
|
|
150
|
+
x = self.input_proj(src)
|
|
151
|
+
pos_ids = torch.arange(start_pos, start_pos + T, device=device).expand(B, T)
|
|
152
|
+
x = x + self.pos_embedding(pos_ids)
|
|
153
|
+
x = self.dropout(x)
|
|
154
|
+
|
|
155
|
+
memory = self.encoder(x, mask=src_mask)
|
|
156
|
+
|
|
157
|
+
if self.decoder is not None:
|
|
158
|
+
if tgt is None:
|
|
159
|
+
tgt = self.start_token.expand(B, -1, -1).to(device)
|
|
160
|
+
tgt_proj = self.output_to_hidden(tgt)
|
|
161
|
+
tgt_pos_ids = torch.arange(tgt.shape[1], device=device).expand(
|
|
162
|
+
B, tgt.shape[1]
|
|
163
|
+
)
|
|
164
|
+
tgt_proj = tgt_proj + self.pos_embedding(tgt_pos_ids)
|
|
165
|
+
tgt_proj = self.dropout(tgt_proj)
|
|
166
|
+
out = self.decoder(
|
|
167
|
+
tgt_proj,
|
|
168
|
+
memory,
|
|
169
|
+
tgt_mask=tgt_mask,
|
|
170
|
+
memory_mask=src_mask,
|
|
171
|
+
)
|
|
172
|
+
else:
|
|
173
|
+
out = memory
|
|
174
|
+
|
|
175
|
+
return {name: head(out) for name, head in self.heads.items()}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Use of this module is deprecated. Please use `ezmsg.learn.model` or `ezmsg.learn.process` instead.
|
|
File without changes
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
from dataclasses import field
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import river.optim
|
|
6
|
+
import river.linear_model
|
|
7
|
+
import sklearn.base
|
|
8
|
+
import ezmsg.core as ez
|
|
9
|
+
from ezmsg.sigproc.sampler import SampleMessage
|
|
10
|
+
from ezmsg.sigproc.base import (
|
|
11
|
+
processor_state,
|
|
12
|
+
BaseAdaptiveTransformer,
|
|
13
|
+
BaseAdaptiveTransformerUnit,
|
|
14
|
+
)
|
|
15
|
+
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
16
|
+
|
|
17
|
+
from ..util import AdaptiveLinearRegressor, RegressorType, get_regressor
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class AdaptiveLinearRegressorSettings(ez.Settings):
    """
    Settings for ``AdaptiveLinearRegressorTransformer``.
    """

    # Which adaptive regressor to use. LINEAR and LOGISTIC are handled through river;
    # other values are built via ``get_regressor(RegressorType.ADAPTIVE, model_type)``.
    model_type: AdaptiveLinearRegressor = AdaptiveLinearRegressor.LINEAR
    # Optional path to a pickled, pre-built model. When set, the model is loaded from this
    # file instead of being constructed from ``model_kwargs``. Only use trusted files (pickle).
    settings_path: str | None = None
    # Keyword arguments for the regressor constructor. For river models, "l2" defaults to 0.0
    # and a "learn_rate" entry is converted into a river.optim.SGD "optimizer".
    model_kwargs: dict = field(default_factory=dict)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@processor_state
class AdaptiveLinearRegressorState:
    """
    Mutable state for ``AdaptiveLinearRegressorTransformer``.
    """

    # Output template built from the most recent training target (set in partial_fit);
    # None until the first successful fit, in which case processing emits an empty AxisArray.
    template: AxisArray | None = None
    # The underlying regressor, created or unpickled in the transformer's __init__.
    model: river.linear_model.base.GLM | sklearn.base.RegressorMixin | None = None
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class AdaptiveLinearRegressorTransformer(
    BaseAdaptiveTransformer[
        AdaptiveLinearRegressorSettings,
        AxisArray,
        AxisArray,
        AdaptiveLinearRegressorState,
    ]
):
    """
    Online linear regressor trained from ``SampleMessage`` objects.

    ``LINEAR`` / ``LOGISTIC`` model types use river's ``learn_many`` / ``predict_many`` API
    (trained on the first target channel only). Other model types are built with
    ``get_regressor(RegressorType.ADAPTIVE, ...)`` and driven through ``partial_fit`` / ``predict``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        model_type = AdaptiveLinearRegressor(self.settings.model_type)
        # Work on a private copy: ``replace`` only shallow-copies the settings, so editing the dict
        # in place would leak the l2 default / popped learn_rate / optimizer instance back into the
        # caller's settings (and share one optimizer between transformers built from them).
        model_kwargs = dict(self.settings.model_kwargs)
        b_river = model_type in (
            AdaptiveLinearRegressor.LINEAR,
            AdaptiveLinearRegressor.LOGISTIC,
        )
        if b_river:
            model_kwargs.setdefault("l2", 0.0)
            if "learn_rate" in model_kwargs:
                # river models take an optimizer object rather than a bare learning rate.
                model_kwargs["optimizer"] = river.optim.SGD(model_kwargs.pop("learn_rate"))
        self.settings = replace(
            self.settings, model_type=model_type, model_kwargs=model_kwargs
        )

        if self.settings.settings_path is not None:
            # Load model from file
            import pickle

            # SECURITY: pickle.load can execute arbitrary code; only point settings_path at trusted files.
            with open(self.settings.settings_path, "rb") as f:
                self.state.model = pickle.load(f)

            if b_river:
                # Settings take precedence over hyperparameters stored in the pickle.
                self.state.model.l2 = model_kwargs["l2"]
                if "optimizer" in model_kwargs:
                    self.state.model.optimizer = model_kwargs["optimizer"]
            elif model_kwargs:
                ez.logger.warning(
                    "model_kwargs are not applied to a non-river model loaded from settings_path."
                )
        else:
            # Build model from scratch.
            regressor_klass = get_regressor(RegressorType.ADAPTIVE, model_type)
            self.state.model = regressor_klass(**model_kwargs)

    def _uses_river(self) -> bool:
        """True when the configured model type is driven through river's API."""
        return self.settings.model_type in (
            AdaptiveLinearRegressor.LINEAR,
            AdaptiveLinearRegressor.LOGISTIC,
        )

    @staticmethod
    def _time_first(arr: AxisArray) -> np.ndarray:
        """Return ``arr.data`` with the "time" axis (if any) moved to axis 0."""
        if "time" not in arr.dims:
            return arr.data
        t_idx = arr.get_axis_idx("time")
        return arr.data if t_idx == 0 else np.moveaxis(arr.data, t_idx, 0)

    @classmethod
    def _to_river_frame(cls, arr: AxisArray) -> pd.DataFrame:
        """Build a DataFrame with one column per "ch" label and one row per time step."""
        data = cls._time_first(arr)
        return pd.DataFrame.from_dict(
            {k: v for k, v in zip(arr.axes["ch"].data, data.T)}
        )

    def _hash_message(self, message: AxisArray) -> int:
        # So far, nothing to reset so hash can be constant.
        return -1

    def _reset_state(self, message: AxisArray) -> None:
        # So far, there is nothing to reset.
        # .model is initialized in __init__
        # .template is updated in partial_fit
        pass

    def partial_fit(self, message: SampleMessage) -> None:
        """
        Update the model with one training sample and refresh the output template.

        Samples containing any NaN are skipped entirely (model and template unchanged).
        """
        if np.any(np.isnan(message.sample.data)):
            return

        if self._uses_river():
            x = self._to_river_frame(message.sample)
            # river linear models are single-output: train on the first target channel.
            y = pd.Series(
                data=message.trigger.value.data[:, 0],
                name=message.trigger.value.axes["ch"].data[0],
            )
            self.state.model.learn_many(x, y)
        else:
            self.state.model.partial_fit(
                self._time_first(message.sample), message.trigger.value.data
            )

        self.state.template = replace(
            message.trigger.value,
            data=np.empty_like(message.trigger.value.data),
            key=message.trigger.value.key + "_pred",
        )

    def _process(self, message: AxisArray) -> AxisArray | None:
        """
        Predict on ``message``.

        Returns an empty AxisArray before the first fit, None when the input contains NaN,
        and otherwise a (time, -1) prediction array shaped like the training target.
        """
        if self.state.template is None:
            return AxisArray(np.array([]), dims=[""])

        if np.any(np.isnan(message.data)):
            return None

        if self._uses_river():
            preds = self.state.model.predict_many(self._to_river_frame(message)).values
        else:
            # Match partial_fit, which trains with time on axis 0.
            preds = self.state.model.predict(self._time_first(message))
        return replace(
            self.state.template,
            data=preds.reshape((len(preds), -1)),
            axes={
                **self.state.template.axes,
                "time": message.axes["time"],
            },
        )
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class AdaptiveLinearRegressorUnit(
    BaseAdaptiveTransformerUnit[
        AdaptiveLinearRegressorSettings,
        AxisArray,
        AxisArray,
        AdaptiveLinearRegressorTransformer,
    ]
):
    """
    ezmsg Unit exposing ``AdaptiveLinearRegressorTransformer`` (AxisArray in, AxisArray out).
    """

    SETTINGS = AdaptiveLinearRegressorSettings
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import json
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
import typing
|
|
5
|
+
|
|
6
|
+
import ezmsg.core as ez
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ModelInitMixin:
|
|
11
|
+
"""
|
|
12
|
+
Mixin class to support model initialization from:
|
|
13
|
+
1. Setting parameters
|
|
14
|
+
2. Config file
|
|
15
|
+
3. Checkpoint file
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
@staticmethod
|
|
19
|
+
def _merge_config(model_kwargs: dict, config) -> None:
|
|
20
|
+
"""
|
|
21
|
+
Mutate the model_kwargs dictionary with the config parameters.
|
|
22
|
+
Args:
|
|
23
|
+
model_kwargs: Original to-be-mutated model kwargs.
|
|
24
|
+
config: Update config parameters.
|
|
25
|
+
|
|
26
|
+
Returns:
|
|
27
|
+
None because model_kwargs is mutated in place.
|
|
28
|
+
"""
|
|
29
|
+
if "model_params" in config:
|
|
30
|
+
config = config["model_params"]
|
|
31
|
+
# Update model_kwargs with config parameters
|
|
32
|
+
for key, value in config.items():
|
|
33
|
+
if key in model_kwargs:
|
|
34
|
+
if model_kwargs[key] != value:
|
|
35
|
+
ez.logger.warning(
|
|
36
|
+
f"Config parameter {key} ({value}) differs from settings ({model_kwargs[key]})."
|
|
37
|
+
)
|
|
38
|
+
else:
|
|
39
|
+
ez.logger.warning(f"Config parameter {key} is not in model_kwargs.")
|
|
40
|
+
model_kwargs[key] = value
|
|
41
|
+
|
|
42
|
+
def _filter_model_kwargs(self, model_class, kwargs: dict) -> dict:
|
|
43
|
+
valid_params = inspect.signature(model_class.__init__).parameters
|
|
44
|
+
filtered_out = set(kwargs.keys()) - {k for k in valid_params if k != "self"}
|
|
45
|
+
if filtered_out:
|
|
46
|
+
ez.logger.warning(
|
|
47
|
+
f"Ignoring unexpected model parameters not accepted by {model_class.__name__} constructor: {sorted(filtered_out)}"
|
|
48
|
+
)
|
|
49
|
+
# Keep all valid parameters, including None values, so checkpoint-inferred values can overwrite them
|
|
50
|
+
return {k: v for k, v in kwargs.items() if k in valid_params and k != "self"}
|
|
51
|
+
|
|
52
|
+
def _init_model(
|
|
53
|
+
self,
|
|
54
|
+
model_class,
|
|
55
|
+
params: dict[str, typing.Any] | None = None,
|
|
56
|
+
config_path: str | None = None,
|
|
57
|
+
checkpoint_path: str | None = None,
|
|
58
|
+
device: str = "cpu",
|
|
59
|
+
state_dict_prefix: str | None = None,
|
|
60
|
+
weights_only: bool | None = None,
|
|
61
|
+
) -> torch.nn.Module:
|
|
62
|
+
"""
|
|
63
|
+
Args:
|
|
64
|
+
model_class: The class of the model to be initialized.
|
|
65
|
+
params: A dictionary of setting parameters to be used for model initialization.
|
|
66
|
+
config_path: Path to a JSON config file to update model parameters.
|
|
67
|
+
checkpoint_path: Path to a checkpoint file to load model weights and possibly config.
|
|
68
|
+
|
|
69
|
+
Returns:
|
|
70
|
+
The initialized model.
|
|
71
|
+
The model will be initialized with the correct config and weights.
|
|
72
|
+
|
|
73
|
+
"""
|
|
74
|
+
# Model parameters are taken from multiple sources, in ascending priority:
|
|
75
|
+
# 1. Setting parameters
|
|
76
|
+
# 2. Config file if provided
|
|
77
|
+
# 3. "config" entry in checkpoint file if checkpoint file provided and config present
|
|
78
|
+
# 4. Sizes of weights in checkpoint file if provided
|
|
79
|
+
|
|
80
|
+
# Get configs from setting params.
|
|
81
|
+
model_kwargs = params or {}
|
|
82
|
+
state_dict = None
|
|
83
|
+
|
|
84
|
+
# Check if a config file is provided and if so use that to update kwargs (with warnings).
|
|
85
|
+
if config_path:
|
|
86
|
+
config_path = Path(config_path)
|
|
87
|
+
if not config_path.exists():
|
|
88
|
+
ez.logger.error(f"Config path {config_path} does not exist.")
|
|
89
|
+
raise FileNotFoundError(f"Config path {config_path} does not exist.")
|
|
90
|
+
try:
|
|
91
|
+
with open(config_path, "r") as f:
|
|
92
|
+
config = json.load(f)
|
|
93
|
+
self._merge_config(model_kwargs, config)
|
|
94
|
+
except Exception as e:
|
|
95
|
+
raise RuntimeError(
|
|
96
|
+
f"Failed to load config from {config_path}: {str(e)}"
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
# If a checkpoint file is provided, load it.
|
|
100
|
+
if checkpoint_path:
|
|
101
|
+
checkpoint_path = Path(checkpoint_path)
|
|
102
|
+
if not checkpoint_path.exists():
|
|
103
|
+
ez.logger.error(f"Checkpoint path {checkpoint_path} does not exist.")
|
|
104
|
+
raise FileNotFoundError(
|
|
105
|
+
f"Checkpoint path {checkpoint_path} does not exist."
|
|
106
|
+
)
|
|
107
|
+
try:
|
|
108
|
+
checkpoint = torch.load(
|
|
109
|
+
checkpoint_path, map_location=device, weights_only=weights_only
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
if "config" in checkpoint:
|
|
113
|
+
config = checkpoint["config"]
|
|
114
|
+
self._merge_config(model_kwargs, config)
|
|
115
|
+
|
|
116
|
+
# Load the model weights and infer the config.
|
|
117
|
+
state_dict = checkpoint
|
|
118
|
+
if "model_state_dict" in checkpoint:
|
|
119
|
+
state_dict = checkpoint["model_state_dict"]
|
|
120
|
+
elif "state_dict" in checkpoint:
|
|
121
|
+
# This is for backward compatibility with older checkpoints
|
|
122
|
+
# that used "state_dict" instead of "model_state_dict"
|
|
123
|
+
state_dict = checkpoint["state_dict"]
|
|
124
|
+
infer_config = getattr(
|
|
125
|
+
model_class,
|
|
126
|
+
"infer_config_from_state_dict",
|
|
127
|
+
lambda _state_dict: {}, # Default to empty dict if not defined
|
|
128
|
+
)
|
|
129
|
+
infer_kwargs = (
|
|
130
|
+
{"rnn_type": model_kwargs["rnn_type"]}
|
|
131
|
+
if "rnn_type" in model_kwargs
|
|
132
|
+
else {}
|
|
133
|
+
)
|
|
134
|
+
self._merge_config(
|
|
135
|
+
model_kwargs,
|
|
136
|
+
infer_config(state_dict, **infer_kwargs),
|
|
137
|
+
)
|
|
138
|
+
|
|
139
|
+
except Exception as e:
|
|
140
|
+
raise RuntimeError(
|
|
141
|
+
f"Failed to load checkpoint from {checkpoint_path}: {str(e)}"
|
|
142
|
+
)
|
|
143
|
+
|
|
144
|
+
# Filter model_kwargs to only include valid parameters for the model class
|
|
145
|
+
filtered_kwargs = self._filter_model_kwargs(model_class, model_kwargs)
|
|
146
|
+
|
|
147
|
+
# Remove None values from filtered_kwargs to avoid passing them to the model constructor
|
|
148
|
+
# This should only happen for parameters that weren't inferred from the checkpoint
|
|
149
|
+
final_kwargs = {k: v for k, v in filtered_kwargs.items() if v is not None}
|
|
150
|
+
|
|
151
|
+
# Create the model with the final kwargs
|
|
152
|
+
model = model_class(**final_kwargs)
|
|
153
|
+
|
|
154
|
+
# Finally, load the weights.
|
|
155
|
+
if state_dict:
|
|
156
|
+
if state_dict_prefix:
|
|
157
|
+
# If a prefix is provided, filter the state_dict keys
|
|
158
|
+
state_dict = {
|
|
159
|
+
k[len(state_dict_prefix) :]: v
|
|
160
|
+
for k, v in state_dict.items()
|
|
161
|
+
if k.startswith(state_dict_prefix)
|
|
162
|
+
}
|
|
163
|
+
# Load the model weights
|
|
164
|
+
missing, unexpected = model.load_state_dict(
|
|
165
|
+
state_dict, strict=False, assign=True
|
|
166
|
+
)
|
|
167
|
+
if missing or unexpected:
|
|
168
|
+
ez.logger.warning(
|
|
169
|
+
f"Partial load: missing keys: {missing}, unexpected keys: {unexpected}"
|
|
170
|
+
)
|
|
171
|
+
|
|
172
|
+
model.to(device)
|
|
173
|
+
return model
|