ezmsg-learn 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,160 @@
+ from typing import Optional
+
+ import torch
+
+
+ class RNNModel(torch.nn.Module):
+     """
+     Recurrent neural network supporting GRU, LSTM, and vanilla RNN (tanh/relu).
+
+     Attributes:
+         input_size (int): Number of input features per time step.
+         hidden_size (int): Number of hidden units in the RNN cell.
+         num_layers (int, optional): Number of RNN layers. Default is 1.
+         output_size (int | dict[str, int], optional): Number of output features or classes if single-head output, or a
+             dictionary mapping head names to output sizes if multi-head output. Default is 2 (single head).
+         dropout (float, optional): Dropout rate applied after the input and the RNN output. Default is 0.3.
+         rnn_type (str, optional): Type of RNN cell to use: 'GRU', 'LSTM', 'RNN-Tanh', 'RNN-ReLU'. Default is 'GRU'.
+
+     Returns:
+         dict[str, torch.Tensor]: Dictionary of decoded predictions mapping head names to tensors of shape
+             (batch, seq_len, output_size). If single-head output, the key is "output".
+     """
+
+     def __init__(
+         self,
+         input_size: int,
+         hidden_size: int,
+         num_layers: int = 1,
+         output_size: int | dict[str, int] = 2,
+         dropout: float = 0.3,
+         rnn_type: str = "GRU",
+     ):
+         super().__init__()
+         self.linear_embeddings = torch.nn.Linear(input_size, input_size)
+         self.dropout_input = torch.nn.Dropout(dropout)
+
+         rnn_klass_str = rnn_type.upper().split("-")[0]
+         if rnn_klass_str not in ["GRU", "LSTM", "RNN"]:
+             raise ValueError(f"Unrecognized rnn_type: {rnn_type}")
+         rnn_klass = {"GRU": torch.nn.GRU, "LSTM": torch.nn.LSTM, "RNN": torch.nn.RNN}[
+             rnn_klass_str
+         ]
+         rnn_kwargs = {}
+         if rnn_klass_str == "RNN":
+             rnn_kwargs["nonlinearity"] = rnn_type.lower().split("-")[-1]
+         self.rnn = rnn_klass(
+             input_size,
+             hidden_size,
+             num_layers,
+             batch_first=True,
+             dropout=dropout if num_layers > 1 else 0.0,
+             **rnn_kwargs,
+         )
+         self.rnn_type = rnn_klass_str
+
+         self.output_dropout = torch.nn.Dropout(dropout)
+         if isinstance(output_size, int):
+             output_size = {"output": output_size}
+         self.heads = torch.nn.ModuleDict(
+             {
+                 name: torch.nn.Linear(hidden_size, size)
+                 for name, size in output_size.items()
+             }
+         )
+
+     @classmethod
+     def infer_config_from_state_dict(
+         cls, state_dict: dict, rnn_type: str = "GRU"
+     ) -> dict[str, int | float]:
+         """
+         Infer the model's constructor parameters from a saved state dict. This method is specific to each model
+         class.
+
+         Args:
+             state_dict: The state dict of the model.
+             rnn_type: The type of RNN used in the model (e.g., 'GRU', 'LSTM', 'RNN-Tanh', 'RNN-ReLU').
+
+         Returns:
+             A dictionary of model parameters obtained from the state dict.
+         """
+         # Infer output sizes from heads.<name>.bias (shape: [output_size])
+         output_size = {
+             key.split(".")[1]: param.shape[0]
+             for key, param in state_dict.items()
+             if key.startswith("heads.") and key.endswith(".bias")
+         }
+
+         return {
+             # Infer input_size from linear_embeddings.weight (shape: [input_size, input_size])
+             "input_size": state_dict["linear_embeddings.weight"].shape[1],
+             # Infer hidden_size from rnn.weight_ih_l0 (shape: [hidden_size * num_gates, input_size])
+             "hidden_size": state_dict["rnn.weight_ih_l0"].shape[0]
+             // cls._get_gate_count(rnn_type),
+             # Infer num_layers by counting rnn layers in state_dict (e.g., weight_ih_l<k>)
+             "num_layers": sum(1 for key in state_dict if "rnn.weight_ih_l" in key),
+             "output_size": output_size,
+         }
+
+     @staticmethod
+     def _get_gate_count(rnn_type: str) -> int:
+         if rnn_type.upper() == "GRU":
+             return 3
+         elif rnn_type.upper() == "LSTM":
+             return 4
+         elif rnn_type.upper().startswith("RNN"):
+             return 1
+         else:
+             raise ValueError(f"Unsupported rnn_type: {rnn_type}")
+
+     def init_hidden(
+         self, batch_size: int, device: torch.device
+     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+         """
+         Initialize the hidden state for the RNN.
+
+         Args:
+             batch_size (int): Size of the batch.
+             device (torch.device): Device to place the hidden state on (e.g., 'cpu' or 'cuda').
+
+         Returns:
+             torch.Tensor | tuple[torch.Tensor, torch.Tensor]: Initial hidden state for the RNN.
+                 For LSTM, returns a tuple of (h_n, c_n) where h_n is the hidden state and c_n is the cell state.
+                 For GRU or vanilla RNN, returns just h_n.
+         """
+         shape = (self.rnn.num_layers, batch_size, self.rnn.hidden_size)
+         if self.rnn_type == "LSTM":
+             return (
+                 torch.zeros(shape, device=device),
+                 torch.zeros(shape, device=device),
+             )
+         else:
+             return torch.zeros(shape, device=device)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         input_lens: Optional[torch.Tensor] = None,
+         hx: Optional[torch.Tensor | tuple[torch.Tensor, torch.Tensor]] = None,
+     ) -> tuple[dict[str, torch.Tensor], torch.Tensor | tuple]:
+         """
+         Forward pass through the RNN model.
+
+         Args:
+             x (torch.Tensor): Input tensor of shape (batch, seq_len, input_size).
+             input_lens (Optional[torch.Tensor]): Optional tensor of lengths for each sequence in the batch.
+                 If provided, sequences will be packed before passing through the RNN.
+             hx (Optional[torch.Tensor | tuple[torch.Tensor, torch.Tensor]]): Optional initial hidden state for the RNN.
+
+         Returns:
+             tuple[dict[str, torch.Tensor], torch.Tensor | tuple]: A tuple of (predictions, hidden state):
+                 a dictionary mapping head names to output tensors of shape (batch, seq_len, output_size),
+                 and the final hidden state, which is (h_n, c_n) for LSTM or just h_n for GRU and vanilla RNN.
+         """
+         x = self.linear_embeddings(x)
+         x = self.dropout_input(x)
+         total_length = x.shape[1]
+         if input_lens is not None:
+             # pack_padded_sequence expects sequence lengths on the CPU.
+             x = torch.nn.utils.rnn.pack_padded_sequence(
+                 x, input_lens, batch_first=True, enforce_sorted=False
+             )
+         x_out, hx_out = self.rnn(x, hx)
+         if input_lens is not None:
+             # Restore the padded (batch, seq_len, hidden_size) layout.
+             x_out, _ = torch.nn.utils.rnn.pad_packed_sequence(
+                 x_out, batch_first=True, total_length=total_length
+             )
+         x_out = self.output_dropout(x_out)
+         return {name: head(x_out) for name, head in self.heads.items()}, hx_out
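A minimal usage sketch of the class above (shapes are illustrative; `RNNModel` and its defaults come from the code in this hunk):

    import torch

    model = RNNModel(input_size=16, hidden_size=32, rnn_type="GRU")
    x = torch.randn(4, 100, 16)  # (batch, seq_len, input_size)
    hx = model.init_hidden(batch_size=4, device=x.device)
    outputs, hx = model(x, hx=hx)  # hx can be carried across calls for streaming
    assert outputs["output"].shape == (4, 100, 2)  # default single head, output_size=2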
@@ -0,0 +1,175 @@
+ from typing import Optional
+
+ import torch
+
+
+ class TransformerModel(torch.nn.Module):
+     """
+     Transformer-based encoder (optional decoder) neural network.
+
+     If `decoder_layers > 0`, the model includes a Transformer decoder. In this case, the `tgt` argument should be
+     provided: during training, it is typically the ground-truth target sequence (i.e. teacher forcing); during
+     inference, it can be constructed autoregressively from previous predictions. If `tgt` is omitted, decoding
+     starts from a learned start token.
+
+     Attributes:
+         input_size (int): Number of input features per time step.
+         hidden_size (int): Dimensionality of the transformer model.
+         encoder_layers (int, optional): Number of transformer encoder layers. Default is 1.
+         decoder_layers (int, optional): Number of transformer decoder layers. Default is 0.
+         output_size (int | dict[str, int], optional): Number of output features or classes if single-head output, or a
+             dictionary mapping head names to output sizes if multi-head output. Default is 2 (single head).
+         dropout (float, optional): Dropout rate applied after input and transformer output. Default is 0.3.
+         attention_heads (int, optional): Number of attention heads in the transformer. Default is 4.
+         max_seq_len (int, optional): Maximum sequence length for positional embeddings. Default is 512.
+         autoregressive_head (str | None, optional): When `output_size` is a dict, the name of the head whose output
+             size is used for the decoder start token and the output-to-hidden projection. Default is None, which
+             falls back to the first head.
+
+     Returns:
+         dict[str, torch.Tensor]: Dictionary of decoded predictions mapping head names to tensors of shape
+             (batch, seq_len, output_size). If single-head output, the key is "output".
+     """
+
+     def __init__(
+         self,
+         input_size: int,
+         hidden_size: int,
+         encoder_layers: int = 1,
+         decoder_layers: int = 0,
+         output_size: int | dict[str, int] = 2,
+         dropout: float = 0.3,
+         attention_heads: int = 4,
+         max_seq_len: int = 512,
+         autoregressive_head: str | None = None,
+     ):
+         super().__init__()
+
+         self.decoder_layers = decoder_layers
+         self.hidden_size = hidden_size
+
+         # Size the decoder start token after the autoregressive head
+         # (or the first head if no autoregressive head is named).
+         if isinstance(output_size, int):
+             autoregressive_size = output_size
+         else:
+             autoregressive_size = list(output_size.values())[0]
+             autoregressive_size = output_size.get(
+                 autoregressive_head, autoregressive_size
+             )
+         self.start_token = torch.nn.Parameter(torch.zeros(1, 1, autoregressive_size))
+         self.output_to_hidden = torch.nn.Linear(autoregressive_size, hidden_size)
+
+         self.input_proj = torch.nn.Linear(input_size, hidden_size)
+         self.pos_embedding = torch.nn.Embedding(max_seq_len, hidden_size)
+         self.dropout = torch.nn.Dropout(dropout)
+
+         self.encoder = torch.nn.TransformerEncoder(
+             torch.nn.TransformerEncoderLayer(
+                 d_model=hidden_size,
+                 nhead=attention_heads,
+                 dim_feedforward=hidden_size * 4,
+                 dropout=dropout,
+                 batch_first=True,
+             ),
+             num_layers=encoder_layers,
+         )
+
+         self.decoder = None
+         if decoder_layers > 0:
+             self.decoder = torch.nn.TransformerDecoder(
+                 torch.nn.TransformerDecoderLayer(
+                     d_model=hidden_size,
+                     nhead=attention_heads,
+                     dim_feedforward=hidden_size * 4,
+                     dropout=dropout,
+                     batch_first=True,
+                 ),
+                 num_layers=decoder_layers,
+             )
+
+         if isinstance(output_size, int):
+             output_size = {"output": output_size}
+         self.heads = torch.nn.ModuleDict(
+             {
+                 name: torch.nn.Linear(hidden_size, out_dim)
+                 for name, out_dim in output_size.items()
+             }
+         )
+
+
95
+ @classmethod
96
+ def infer_config_from_state_dict(cls, state_dict: dict) -> dict[str, int | float]:
97
+ # Infer output size from heads.<name>.bias (shape: [output_size])
98
+ output_size = {
99
+ key.split(".")[1]: param.shape[0]
100
+ for key, param in state_dict.items()
101
+ if key.startswith("heads.") and key.endswith(".bias")
102
+ }
103
+
104
+ return {
105
+ # Infer input_size from input_proj.weight (shape: [hidden_size, input_size])
106
+ "input_size": state_dict["input_proj.weight"].shape[1],
107
+ # Infer hidden_size from input_proj.weight (shape: [hidden_size, input_size])
108
+ "hidden_size": state_dict["input_proj.weight"].shape[0],
109
+ "output_size": output_size,
110
+ # Infer encoder_layers from transformer layers in state_dict
111
+ "encoder_layers": len(
112
+ [k for k in state_dict if k.startswith("encoder.layers")]
113
+ ),
114
+ # Infer decoder_layers from transformer decoder layers in state_dict
115
+ "decoder_layers": len(
116
+ {k.split(".")[2] for k in state_dict if k.startswith("decoder.layers")}
117
+ )
118
+ if any(k.startswith("decoder.layers") for k in state_dict)
119
+ else 0,
120
+ }
+
+     def forward(
+         self,
+         src: torch.Tensor,
+         tgt: Optional[torch.Tensor] = None,
+         src_mask: Optional[torch.Tensor] = None,
+         tgt_mask: Optional[torch.Tensor] = None,
+         start_pos: int = 0,
+     ) -> dict[str, torch.Tensor]:
+         """
+         Forward pass through the transformer model.
+
+         Args:
+             src (torch.Tensor): Input tensor of shape (batch, seq_len, input_size).
+             tgt (Optional[torch.Tensor]): Target tensor for the decoder, of shape (batch, seq_len,
+                 autoregressive_size). Used when `decoder_layers > 0`. In training, this can be the ground-truth
+                 target sequence (i.e. teacher forcing). During inference, this is constructed autoregressively;
+                 if omitted, the learned start token is used instead.
+             src_mask (Optional[torch.Tensor]): Optional attention mask for the encoder input. Should be broadcastable
+                 to shape (batch, seq_len, seq_len) or (seq_len, seq_len).
+             tgt_mask (Optional[torch.Tensor]): Optional attention mask for the decoder input. Used to enforce causal
+                 decoding (i.e. autoregressive generation) during training or inference.
+             start_pos (int): Starting offset for positional embeddings. Used for streaming inference to maintain
+                 correct positional indices. Default is 0.
+
+         Returns:
+             dict[str, torch.Tensor]: Dictionary of output tensors, one per output head, each with shape
+                 (batch, seq_len, output_size).
+         """
+         B, T, _ = src.shape
+         device = src.device
+
+         x = self.input_proj(src)
+         pos_ids = torch.arange(start_pos, start_pos + T, device=device).expand(B, T)
+         x = x + self.pos_embedding(pos_ids)
+         x = self.dropout(x)
+
+         memory = self.encoder(x, mask=src_mask)
+
+         if self.decoder is not None:
+             if tgt is None:
+                 tgt = self.start_token.expand(B, -1, -1).to(device)
+             tgt_proj = self.output_to_hidden(tgt)
+             tgt_pos_ids = torch.arange(tgt.shape[1], device=device).expand(
+                 B, tgt.shape[1]
+             )
+             tgt_proj = tgt_proj + self.pos_embedding(tgt_pos_ids)
+             tgt_proj = self.dropout(tgt_proj)
+             out = self.decoder(
+                 tgt_proj,
+                 memory,
+                 tgt_mask=tgt_mask,
+                 memory_mask=src_mask,
+             )
+         else:
+             out = memory
+
+         return {name: head(out) for name, head in self.heads.items()}
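A hedged usage sketch for the encoder-decoder configuration above (sizes are illustrative; the causal-mask helper is a standard torch utility, not part of this package):

    import torch

    model = TransformerModel(
        input_size=16,
        hidden_size=64,
        encoder_layers=2,
        decoder_layers=2,
        output_size={"kinematics": 2},
        autoregressive_head="kinematics",
    )
    src = torch.randn(4, 50, 16)  # (batch, seq_len, input_size)
    tgt = torch.randn(4, 50, 2)   # teacher-forced targets, sized to the autoregressive head
    tgt_mask = torch.nn.Transformer.generate_square_subsequent_mask(50)  # causal mask
    outs = model(src, tgt=tgt, tgt_mask=tgt_mask)
    assert outs["kinematics"].shape == (4, 50, 2)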
@@ -0,0 +1 @@
+ # Use of this module is deprecated. Please use `ezmsg.learn.model` or `ezmsg.learn.process` instead.
@@ -0,0 +1,6 @@
+ from ..model.mlp_old import MLP as MLP
+ from ..process.mlp_old import (
+     MLPSettings as MLPSettings,
+     MLPState as MLPState,
+     MLPProcessor as MLPProcessor,
+ )
@@ -0,0 +1,157 @@
+ from dataclasses import field
+
+ import numpy as np
+ import pandas as pd
+ import river.optim
+ import river.linear_model
+ import sklearn.base
+ import ezmsg.core as ez
+ from ezmsg.sigproc.sampler import SampleMessage
+ from ezmsg.sigproc.base import (
+     processor_state,
+     BaseAdaptiveTransformer,
+     BaseAdaptiveTransformerUnit,
+ )
+ from ezmsg.util.messages.axisarray import AxisArray, replace
+
+ from ..util import AdaptiveLinearRegressor, RegressorType, get_regressor
+
+
+ class AdaptiveLinearRegressorSettings(ez.Settings):
+     model_type: AdaptiveLinearRegressor = AdaptiveLinearRegressor.LINEAR
+     settings_path: str | None = None
+     model_kwargs: dict = field(default_factory=dict)
+
+
+ @processor_state
+ class AdaptiveLinearRegressorState:
+     template: AxisArray | None = None
+     model: river.linear_model.base.GLM | sklearn.base.RegressorMixin | None = None
+
+
+ class AdaptiveLinearRegressorTransformer(
+     BaseAdaptiveTransformer[
+         AdaptiveLinearRegressorSettings,
+         AxisArray,
+         AxisArray,
+         AdaptiveLinearRegressorState,
+     ]
+ ):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.settings = replace(
+             self.settings, model_type=AdaptiveLinearRegressor(self.settings.model_type)
+         )
+         b_river = self.settings.model_type in [
+             AdaptiveLinearRegressor.LINEAR,
+             AdaptiveLinearRegressor.LOGISTIC,
+         ]
+         if b_river:
+             self.settings.model_kwargs["l2"] = self.settings.model_kwargs.get("l2", 0.0)
+             if "learn_rate" in self.settings.model_kwargs:
+                 self.settings.model_kwargs["optimizer"] = river.optim.SGD(
+                     self.settings.model_kwargs.pop("learn_rate")
+                 )
+
+         if self.settings.settings_path is not None:
+             # Load a pickled model from file.
+             import pickle
+
+             with open(self.settings.settings_path, "rb") as f:
+                 self.state.model = pickle.load(f)
+
+             if b_river:
+                 # Override the loaded model's hyperparameters with any explicitly provided kwargs.
+                 self.state.model.l2 = self.settings.model_kwargs["l2"]
+                 if "optimizer" in self.settings.model_kwargs:
+                     self.state.model.optimizer = self.settings.model_kwargs["optimizer"]
+             else:
+                 ez.logger.warning("TODO: Override sklearn model with kwargs")
+         else:
+             # Build the model from scratch.
+             regressor_klass = get_regressor(
+                 RegressorType.ADAPTIVE, self.settings.model_type
+             )
+             self.state.model = regressor_klass(**self.settings.model_kwargs)
+
+     def _hash_message(self, message: AxisArray) -> int:
+         # So far, nothing to reset, so the hash can be constant.
+         return -1
+
+     def _reset_state(self, message: AxisArray) -> None:
+         # So far, there is nothing to reset:
+         # .model is initialized in __init__;
+         # .template is updated in partial_fit.
+         pass
+
+     def partial_fit(self, message: SampleMessage) -> None:
+         if np.any(np.isnan(message.sample.data)):
+             return
+
+         if self.settings.model_type in [
+             AdaptiveLinearRegressor.LINEAR,
+             AdaptiveLinearRegressor.LOGISTIC,
+         ]:
+             # river's mini-batch API takes a DataFrame of features
+             # (one column per channel) and a Series target.
+             x = pd.DataFrame.from_dict(
+                 {
+                     k: v
+                     for k, v in zip(
+                         message.sample.axes["ch"].data, message.sample.data.T
+                     )
+                 }
+             )
+             y = pd.Series(
+                 data=message.trigger.value.data[:, 0],
+                 name=message.trigger.value.axes["ch"].data[0],
+             )
+             self.state.model.learn_many(x, y)
+         else:
+             X = message.sample.data
+             if message.sample.get_axis_idx("time") != 0:
+                 X = np.moveaxis(X, message.sample.get_axis_idx("time"), 0)
+             self.state.model.partial_fit(X, message.trigger.value.data)
+
+         self.state.template = replace(
+             message.trigger.value,
+             data=np.empty_like(message.trigger.value.data),
+             key=message.trigger.value.key + "_pred",
+         )
+
+     def _process(self, message: AxisArray) -> AxisArray | None:
+         if self.state.template is None:
+             # Not fitted yet; emit an empty placeholder.
+             return AxisArray(np.array([]), dims=[""])
+
+         # Skip prediction (implicitly returning None) when the input contains NaNs.
+         if not np.any(np.isnan(message.data)):
+             if self.settings.model_type in [
+                 AdaptiveLinearRegressor.LINEAR,
+                 AdaptiveLinearRegressor.LOGISTIC,
+             ]:
+                 # Convert message.data to something appropriate for river.
+                 x = pd.DataFrame.from_dict(
+                     {k: v for k, v in zip(message.axes["ch"].data, message.data.T)}
+                 )
+                 preds = self.state.model.predict_many(x).values
+             else:
+                 preds = self.state.model.predict(message.data)
+             return replace(
+                 self.state.template,
+                 data=preds.reshape((len(preds), -1)),
+                 axes={
+                     **self.state.template.axes,
+                     "time": replace(
+                         message.axes["time"],
+                         offset=message.axes["time"].offset,
+                     ),
+                 },
+             )
+
+
+ class AdaptiveLinearRegressorUnit(
+     BaseAdaptiveTransformerUnit[
+         AdaptiveLinearRegressorSettings,
+         AxisArray,
+         AxisArray,
+         AdaptiveLinearRegressorTransformer,
+     ]
+ ):
+     SETTINGS = AdaptiveLinearRegressorSettings
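For context, a standalone sketch of the river mini-batch API that `partial_fit` and `_process` rely on for the LINEAR/LOGISTIC model types (synthetic data; assumes river and pandas are installed):

    import numpy as np
    import pandas as pd
    import river.linear_model
    import river.optim

    model = river.linear_model.LinearRegression(
        optimizer=river.optim.SGD(0.01),  # corresponds to the "learn_rate" kwarg above
        l2=0.0,
    )
    rng = np.random.default_rng(0)
    X = pd.DataFrame(rng.normal(size=(8, 3)), columns=["ch0", "ch1", "ch2"])
    y = pd.Series(2.0 * X["ch0"] - X["ch1"], name="target")
    model.learn_many(X, y)         # incremental fit on one mini-batch
    preds = model.predict_many(X)  # pd.Series, consumed via .values in _process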
@@ -0,0 +1,173 @@
+ import inspect
+ import json
+ from pathlib import Path
+ import typing
+
+ import ezmsg.core as ez
+ import torch
+
+
+ class ModelInitMixin:
+     """
+     Mixin class to support model initialization from:
+     1. Setting parameters
+     2. Config file
+     3. Checkpoint file
+     """
+
+     @staticmethod
+     def _merge_config(model_kwargs: dict, config: dict) -> None:
+         """
+         Mutate the model_kwargs dictionary with the config parameters.
+
+         Args:
+             model_kwargs: Original, to-be-mutated model kwargs.
+             config: Dictionary of configuration parameters used for the update.
+
+         Returns:
+             None, because model_kwargs is mutated in place.
+         """
+         if "model_params" in config:
+             config = config["model_params"]
+         # Update model_kwargs with config parameters.
+         for key, value in config.items():
+             if key in model_kwargs:
+                 if model_kwargs[key] != value:
+                     ez.logger.warning(
+                         f"Config parameter {key} ({value}) differs from settings ({model_kwargs[key]})."
+                     )
+             else:
+                 ez.logger.warning(f"Config parameter {key} is not in model_kwargs.")
+             model_kwargs[key] = value
+
+     def _filter_model_kwargs(self, model_class, kwargs: dict) -> dict:
+         valid_params = inspect.signature(model_class.__init__).parameters
+         filtered_out = set(kwargs.keys()) - {k for k in valid_params if k != "self"}
+         if filtered_out:
+             ez.logger.warning(
+                 f"Ignoring unexpected model parameters not accepted by {model_class.__name__} constructor: {sorted(filtered_out)}"
+             )
+         # Keep all valid parameters, including None values, so checkpoint-inferred values can overwrite them.
+         return {k: v for k, v in kwargs.items() if k in valid_params and k != "self"}
+
+     def _init_model(
+         self,
+         model_class,
+         params: dict[str, typing.Any] | None = None,
+         config_path: str | None = None,
+         checkpoint_path: str | None = None,
+         device: str = "cpu",
+         state_dict_prefix: str | None = None,
+         weights_only: bool | None = None,
+     ) -> torch.nn.Module:
+         """
+         Initialize a model from setting parameters, an optional config file, and an optional checkpoint.
+
+         Args:
+             model_class: The class of the model to be initialized.
+             params: A dictionary of setting parameters to be used for model initialization.
+             config_path: Path to a JSON config file to update model parameters.
+             checkpoint_path: Path to a checkpoint file to load model weights and possibly config.
+             device: Device to load the checkpoint onto and move the model to. Default is "cpu".
+             state_dict_prefix: If given, only state-dict keys with this prefix are loaded, with the prefix stripped.
+             weights_only: Passed through to `torch.load` when reading the checkpoint.
+
+         Returns:
+             The initialized model, with the resolved config and (if available) loaded weights.
+         """
+         # Model parameters are taken from multiple sources, in ascending priority:
+         # 1. Setting parameters
+         # 2. Config file, if provided
+         # 3. "config" entry in the checkpoint file, if a checkpoint is provided and a config is present
+         # 4. Sizes of weights in the checkpoint file, if provided
+
+         # Get configs from setting params.
+         model_kwargs = params or {}
+         state_dict = None
+
+         # If a config file is provided, use it to update kwargs (with warnings).
+         if config_path:
+             config_path = Path(config_path)
+             if not config_path.exists():
+                 ez.logger.error(f"Config path {config_path} does not exist.")
+                 raise FileNotFoundError(f"Config path {config_path} does not exist.")
+             try:
+                 with open(config_path, "r") as f:
+                     config = json.load(f)
+                 self._merge_config(model_kwargs, config)
+             except Exception as e:
+                 raise RuntimeError(
+                     f"Failed to load config from {config_path}: {str(e)}"
+                 ) from e
+
+         # If a checkpoint file is provided, load it.
+         if checkpoint_path:
+             checkpoint_path = Path(checkpoint_path)
+             if not checkpoint_path.exists():
+                 ez.logger.error(f"Checkpoint path {checkpoint_path} does not exist.")
+                 raise FileNotFoundError(
+                     f"Checkpoint path {checkpoint_path} does not exist."
+                 )
+             try:
+                 checkpoint = torch.load(
+                     checkpoint_path, map_location=device, weights_only=weights_only
+                 )
+
+                 if "config" in checkpoint:
+                     config = checkpoint["config"]
+                     self._merge_config(model_kwargs, config)
+
+                 # Load the model weights and infer the config.
+                 state_dict = checkpoint
+                 if "model_state_dict" in checkpoint:
+                     state_dict = checkpoint["model_state_dict"]
+                 elif "state_dict" in checkpoint:
+                     # This is for backward compatibility with older checkpoints
+                     # that used "state_dict" instead of "model_state_dict".
+                     state_dict = checkpoint["state_dict"]
+                 infer_config = getattr(
+                     model_class,
+                     "infer_config_from_state_dict",
+                     # Default to an empty dict if the model class does not define
+                     # this method; accept and ignore any extra kwargs.
+                     lambda _state_dict, **_kwargs: {},
+                 )
+                 infer_kwargs = (
+                     {"rnn_type": model_kwargs["rnn_type"]}
+                     if "rnn_type" in model_kwargs
+                     else {}
+                 )
+                 self._merge_config(
+                     model_kwargs,
+                     infer_config(state_dict, **infer_kwargs),
+                 )
+
+             except Exception as e:
+                 raise RuntimeError(
+                     f"Failed to load checkpoint from {checkpoint_path}: {str(e)}"
+                 ) from e
+
+         # Filter model_kwargs to only include valid parameters for the model class.
+         filtered_kwargs = self._filter_model_kwargs(model_class, model_kwargs)
+
+         # Remove None values from filtered_kwargs to avoid passing them to the model constructor.
+         # This should only happen for parameters that weren't inferred from the checkpoint.
+         final_kwargs = {k: v for k, v in filtered_kwargs.items() if v is not None}
+
+         # Create the model with the final kwargs.
+         model = model_class(**final_kwargs)
+
+         # Finally, load the weights.
+         if state_dict:
+             if state_dict_prefix:
+                 # If a prefix is provided, keep only matching keys and strip the prefix.
+                 state_dict = {
+                     k[len(state_dict_prefix) :]: v
+                     for k, v in state_dict.items()
+                     if k.startswith(state_dict_prefix)
+                 }
+             # Load the model weights.
+             missing, unexpected = model.load_state_dict(
+                 state_dict, strict=False, assign=True
+             )
+             if missing or unexpected:
+                 ez.logger.warning(
+                     f"Partial load: missing keys: {missing}, unexpected keys: {unexpected}"
+                 )
+
+         model.to(device)
+         return model