lite-mamba 0.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
+++ PKG-INFO
@@ -0,0 +1,95 @@
+ Metadata-Version: 2.4
+ Name: lite-mamba
+ Version: 0.1.4
+ Summary: Pure-PyTorch lightweight Mamba with multi-dilated causal conv front-end
+ Home-page: https://github.com/Mrrobi/lite_mamba
+ Author: Md Robiuddin
+ Author-email: mrrobi040@gmail.com
+ License: Apache-2.0
+ Project-URL: Homepage, https://github.com/Mrrobi/lite_mamba
+ Classifier: License :: OSI Approved :: Apache Software License
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3 :: Only
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.9
+ Description-Content-Type: text/markdown
+ Requires-Dist: torch>=2.0
+ Requires-Dist: einops>=0.6
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: classifier
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: home-page
+ Dynamic: license
+ Dynamic: project-url
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary
+
+ # lite-mamba
+
+ A minimal, pure-PyTorch implementation of Mamba with a multi-dilated causal depthwise conv front-end. No CUDA/Triton build is needed; it runs on CPU or GPU with standard PyTorch ops.
+
+ ## Install
+ ```bash
+ pip install torch einops
+ pip install lite-mamba
+ ```
+
+ ## Usage
+ ```python
+ from lite_mamba import Mamba
+ import torch
+
+ x = torch.randn(2, 128, 512)  # (batch, seq, d_model)
+ m = Mamba(d_model=512, d_conv=3, conv_dilations=(1,2,4,8))
+ y = m(x)
+ print(y.shape)  # (2, 128, 512)
+ ```
+
+ ## API quick reference
+ `Mamba(d_model, d_state=16, d_conv=4, conv_dilations=(1,), expand=2, dt_rank="auto", dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, conv_bias=True, bias=False, use_fast_path=False, layer_idx=None, device=None, dtype=None)`
+
+ - `d_model` (int, required): input/output embedding size.
+ - `d_state` (int, default 16): SSM state dimension per channel. Larger values give longer memory but cost more compute.
+ - `d_conv` (int, default 4): depthwise conv kernel size for each branch.
+ - `conv_dilations` (tuple[int], default `(1,)`): dilation per branch. Multiple values create parallel dilated convs; each branch has an effective receptive field of `(d_conv-1)*dilation + 1` timesteps.
+ - `expand` (float, default 2): inner width multiplier; sets `d_inner = int(expand * d_model)`.
+ - `dt_rank` (int or "auto", default "auto"): rank of the delta projection. "auto" sets `ceil(d_model/16)`.
+ - `dt_min`, `dt_max` (float, defaults 1e-3 / 1e-1): log-uniform range for delta initialization.
+ - `dt_init` ("random" | "constant", default "random") and `dt_scale`, `dt_init_floor`: control delta init magnitude/stability.
+ - `conv_bias` (bool, default True): include bias in the depthwise convs.
+ - `bias` (bool, default False): include bias in the input/output linear projections.
+ - `use_fast_path` (bool): ignored in this pure-PyTorch build; kept for API compatibility.
+ - `layer_idx` (int | None): identifier for streaming cache registration; required when using `allocate_inference_cache` + `inference_params`.
+ - `device`, `dtype`: standard module factory kwargs.
+
+ ### Inference / streaming helpers
+ - `allocate_inference_cache(batch_size, max_seqlen, dtype=None)`: preallocates conv and SSM state buffers for step-wise decoding.
+ - `step(hidden_states, conv_state, ssm_state)`: single-token forward (expects `hidden_states` of shape `(B, 1, d_model)`).
+ - `forward(..., inference_params)`: if `inference_params` carries cached states (via `key_value_memory_dict` and `seqlen_offset`), they are used for streaming.
+
+ ## Highlights
+ - Multi-branch causal dilated convs (weighted sum via learned gates).
+ - Pure Python: no custom C++/CUDA or Triton kernels.
+ - Streaming support via per-branch conv states and SSM state caching.
+
+ ## Practical setups
+ - **Local modeling / small context**: `d_conv=3`, `conv_dilations=(1,2,4)`, `d_state=8–16`, `expand=2`.
+ - **Longer context**: widen `conv_dilations` (e.g., `(1,2,4,8,16)`) or increase `d_state` to 32; expect higher memory/compute.
+ - **Streaming/AR decoding**: call `allocate_inference_cache` once per layer, pass `inference_params` during forward, and use `step` inside your generation loop.
+ - **Stability first**: keep `dt_min` >= 1e-4 and `dt_init_floor` small; leave the defaults unless you observe drift or exploding activations.
+
+ ## Notes
+ - Tune `conv_dilations` to adjust the receptive field; keep kernels small (e.g., 3–5) to avoid excessive left-padding.
+ - The `use_fast_path` flag is ignored here (kept for API compatibility).
+ - The reference selective scan is implemented in PyTorch for portability; faster fused kernels are intentionally omitted.
+
+ ## License
+ Apache-2.0
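
The receptive-field formula in the API reference above is easy to check numerically. A minimal sketch (plain Python; the numbers follow directly from `(d_conv-1)*dilation + 1`):

```python
# Per-branch receptive field of the conv front-end.
d_conv = 3
for dilation in (1, 2, 4, 8):
    rf = (d_conv - 1) * dilation + 1
    print(f"dilation={dilation}: window of {rf} timesteps (current token + {rf - 1} past)")
# dilation=1 -> 3, dilation=2 -> 5, dilation=4 -> 9, dilation=8 -> 17
```

With `conv_dilations=(1,2,4,8)` and `d_conv=3`, the widest branch looks back 16 timesteps while the narrowest stays local; the learned gates weight the branches.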
+++ README.md
@@ -0,0 +1,61 @@
+ # lite-mamba
+
+ A minimal, pure-PyTorch implementation of Mamba with a multi-dilated causal depthwise conv front-end. No CUDA/Triton build is needed; it runs on CPU or GPU with standard PyTorch ops.
+
+ ## Install
+ ```bash
+ pip install torch einops
+ pip install lite-mamba
+ ```
+
+ ## Usage
+ ```python
+ from lite_mamba import Mamba
+ import torch
+
+ x = torch.randn(2, 128, 512)  # (batch, seq, d_model)
+ m = Mamba(d_model=512, d_conv=3, conv_dilations=(1,2,4,8))
+ y = m(x)
+ print(y.shape)  # (2, 128, 512)
+ ```
+
+ ## API quick reference
+ `Mamba(d_model, d_state=16, d_conv=4, conv_dilations=(1,), expand=2, dt_rank="auto", dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, conv_bias=True, bias=False, use_fast_path=False, layer_idx=None, device=None, dtype=None)`
+
+ - `d_model` (int, required): input/output embedding size.
+ - `d_state` (int, default 16): SSM state dimension per channel. Larger values give longer memory but cost more compute.
+ - `d_conv` (int, default 4): depthwise conv kernel size for each branch.
+ - `conv_dilations` (tuple[int], default `(1,)`): dilation per branch. Multiple values create parallel dilated convs; each branch has an effective receptive field of `(d_conv-1)*dilation + 1` timesteps.
+ - `expand` (float, default 2): inner width multiplier; sets `d_inner = int(expand * d_model)`.
+ - `dt_rank` (int or "auto", default "auto"): rank of the delta projection. "auto" sets `ceil(d_model/16)`.
+ - `dt_min`, `dt_max` (float, defaults 1e-3 / 1e-1): log-uniform range for delta initialization.
+ - `dt_init` ("random" | "constant", default "random") and `dt_scale`, `dt_init_floor`: control delta init magnitude/stability.
+ - `conv_bias` (bool, default True): include bias in the depthwise convs.
+ - `bias` (bool, default False): include bias in the input/output linear projections.
+ - `use_fast_path` (bool): ignored in this pure-PyTorch build; kept for API compatibility.
+ - `layer_idx` (int | None): identifier for streaming cache registration; required when using `allocate_inference_cache` + `inference_params`.
+ - `device`, `dtype`: standard module factory kwargs.
+
+ ### Inference / streaming helpers
+ - `allocate_inference_cache(batch_size, max_seqlen, dtype=None)`: preallocates conv and SSM state buffers for step-wise decoding.
+ - `step(hidden_states, conv_state, ssm_state)`: single-token forward (expects `hidden_states` of shape `(B, 1, d_model)`).
+ - `forward(..., inference_params)`: if `inference_params` carries cached states (via `key_value_memory_dict` and `seqlen_offset`), they are used for streaming.
+
+ ## Highlights
+ - Multi-branch causal dilated convs (weighted sum via learned gates).
+ - Pure Python: no custom C++/CUDA or Triton kernels.
+ - Streaming support via per-branch conv states and SSM state caching.
+
+ ## Practical setups
+ - **Local modeling / small context**: `d_conv=3`, `conv_dilations=(1,2,4)`, `d_state=8–16`, `expand=2`.
+ - **Longer context**: widen `conv_dilations` (e.g., `(1,2,4,8,16)`) or increase `d_state` to 32; expect higher memory/compute.
+ - **Streaming/AR decoding**: call `allocate_inference_cache` once per layer, pass `inference_params` during forward, and use `step` inside your generation loop.
+ - **Stability first**: keep `dt_min` >= 1e-4 and `dt_init_floor` small; leave the defaults unless you observe drift or exploding activations.
+
+ ## Notes
+ - Tune `conv_dilations` to adjust the receptive field; keep kernels small (e.g., 3–5) to avoid excessive left-padding.
+ - The `use_fast_path` flag is ignored here (kept for API compatibility).
+ - The reference selective scan is implemented in PyTorch for portability; faster fused kernels are intentionally omitted.
+
+ ## License
+ Apache-2.0
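
To make the "Streaming/AR decoding" recipe above concrete, here is a hedged end-to-end sketch. The package does not ship an `inference_params` class, so the `InferenceParams` stand-in below is an assumption that provides only the two attributes `forward` actually reads (`key_value_memory_dict` and `seqlen_offset`); the random tensors stand in for embedded tokens:

```python
import torch
from lite_mamba import Mamba

class InferenceParams:
    """Minimal stand-in: just the two attributes forward() reads."""
    def __init__(self):
        self.key_value_memory_dict = {}
        self.seqlen_offset = 0

m = Mamba(d_model=64, d_conv=3, conv_dilations=(1, 2), layer_idx=0).eval()
params = InferenceParams()

with torch.no_grad():
    prompt = torch.randn(1, 16, 64)
    y = m(prompt, inference_params=params)   # prefill: populates conv/SSM caches
    params.seqlen_offset += prompt.shape[1]

    for _ in range(4):                       # decode one token at a time
        tok = torch.randn(1, 1, 64)          # stand-in for an embedded token
        y = m(tok, inference_params=params)  # routed through step() internally
        params.seqlen_offset += 1
print(y.shape)  # torch.Size([1, 1, 64])
```

Note that `layer_idx` must be set so the layer can register its cache in `key_value_memory_dict`.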
+++ lite_mamba/__init__.py
@@ -0,0 +1,3 @@
+ __version__ = "0.1.4"
+
+ from .mamba_simple import Mamba
+++ lite_mamba/causal_dilated_conv1d.py
@@ -0,0 +1,57 @@
+ import torch
+ import torch.nn.functional as F
+
+
+ def _apply_activation(x, activation):
+     if activation is None or activation == "identity":
+         return x
+     if activation in ("silu", "swish"):
+         return F.silu(x)
+     if activation == "relu":
+         return F.relu(x)
+     raise ValueError(f"Unsupported activation: {activation}")
+
+
+ def causal_dilated_conv1d_fn(x, weight, bias=None, activation=None, dilation=1):
+     """
+     Depthwise causal 1D convolution with dilation.
+     x: (B, D, L)
+     weight: (D, W)
+     bias: (D,) or None
+     """
+     if dilation < 1:
+         raise ValueError(f"dilation must be >= 1, got {dilation}")
+     pad = dilation * (weight.shape[-1] - 1)
+     x = F.pad(x, (pad, 0))  # left-pad only, so the conv never sees the future
+     y = F.conv1d(
+         x,
+         weight.unsqueeze(1),
+         bias=bias,
+         stride=1,
+         padding=0,
+         dilation=dilation,
+         groups=weight.shape[0],
+     )
+     return _apply_activation(y, activation)
+
+
+ def causal_dilated_conv1d_update(x, conv_state, weight, bias=None, activation=None, dilation=1):
+     """
+     Single-step causal dilated conv update.
+     x: (B, D)
+     conv_state: (B, D, S) where S = dilation * (W - 1) + 1
+     weight: (D, W)
+     bias: (D,) or None
+     """
+     if dilation < 1:
+         raise ValueError(f"dilation must be >= 1, got {dilation}")
+     conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
+     conv_state[:, :, -1] = x
+
+     # Gather taps oldest-to-newest (positions 0, d, ..., (W-1)*d) so tap j pairs with weight[..., j], matching F.conv1d in the full-sequence path.
+     pos = torch.arange(0, weight.shape[-1], device=conv_state.device) * dilation
+     values = conv_state.index_select(-1, pos)
+     y = torch.sum(values * weight.unsqueeze(0), dim=-1)
+     if bias is not None:
+         y = y + bias
+     return _apply_activation(y, activation)
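
As a sanity check on the tap ordering above, replaying a sequence one token at a time through `causal_dilated_conv1d_update` should reproduce the full-sequence `causal_dilated_conv1d_fn` output. A minimal sketch (shapes are arbitrary; the import path follows the package layout):

```python
import torch
from lite_mamba.causal_dilated_conv1d import (
    causal_dilated_conv1d_fn,
    causal_dilated_conv1d_update,
)

B, D, L, W, dil = 2, 8, 32, 3, 4
x = torch.randn(B, D, L)
w = torch.randn(D, W)
b = torch.randn(D)

# Full-sequence causal conv.
y_full = causal_dilated_conv1d_fn(x, w, bias=b, activation="silu", dilation=dil)

# Stream the same sequence through the single-step update.
state = torch.zeros(B, D, dil * (W - 1) + 1)   # S = dilation * (W - 1) + 1
y_steps = torch.stack(
    [causal_dilated_conv1d_update(x[..., t], state, w, bias=b,
                                  activation="silu", dilation=dil)
     for t in range(L)],
    dim=-1,
)
print(torch.allclose(y_full, y_steps, atol=1e-5))  # True
```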
+++ lite_mamba/mamba_simple.py
@@ -0,0 +1,281 @@
+ import math
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch import Tensor
+
+ from einops import rearrange, repeat
+
+ from .selective_scan import selective_scan_fn
+ from .causal_dilated_conv1d import causal_dilated_conv1d_fn, causal_dilated_conv1d_update
+
+
+ class Mamba(nn.Module):
+     def __init__(
+         self,
+         d_model,
+         d_state=16,
+         d_conv=4,
+         conv_dilations=(1,),
+         expand=2,
+         dt_rank="auto",
+         dt_min=0.001,
+         dt_max=0.1,
+         dt_init="random",
+         dt_scale=1.0,
+         dt_init_floor=1e-4,
+         conv_bias=True,
+         bias=False,
+         use_fast_path=False,  # kept for API; ignored in this pure-Python build
+         layer_idx=None,
+         device=None,
+         dtype=None,
+     ):
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.d_model = d_model
+         self.d_state = d_state
+         self.d_conv = d_conv
+         self.conv_dilations = tuple(conv_dilations)
+         self.num_conv_branches = len(self.conv_dilations)
+         self.conv_state_lens = [(self.d_conv - 1) * d + 1 for d in self.conv_dilations]
+         self.expand = expand
+         self.d_inner = int(self.expand * self.d_model)
+         self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
+         self.use_fast_path = False
+         self.layer_idx = layer_idx
+
+         self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
+
+         self.conv1d_layers = nn.ModuleList(
+             [
+                 nn.Conv1d(
+                     in_channels=self.d_inner,
+                     out_channels=self.d_inner,
+                     bias=conv_bias,
+                     kernel_size=d_conv,
+                     groups=self.d_inner,
+                     padding=d * (d_conv - 1),
+                     dilation=d,
+                     **factory_kwargs,
+                 )
+                 for d in self.conv_dilations
+             ]
+         )
+         self.conv1d = self.conv1d_layers[0]
+         self.conv_gates = nn.Parameter(torch.ones(self.num_conv_branches))
+
+         self.activation = "silu"
+         self.act = nn.SiLU()
+
+         self.x_proj = nn.Linear(
+             self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
+         )
+         self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
+
+         dt_init_std = self.dt_rank**-0.5 * dt_scale
+         if dt_init == "constant":
+             nn.init.constant_(self.dt_proj.weight, dt_init_std)
+         elif dt_init == "random":
+             nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
+         else:
+             raise NotImplementedError
+
+         dt = torch.exp(
+             torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+             + math.log(dt_min)
+         ).clamp(min=dt_init_floor)
+         inv_dt = dt + torch.log(-torch.expm1(-dt))
+         with torch.no_grad():
+             self.dt_proj.bias.copy_(inv_dt)
+         self.dt_proj.bias._no_reinit = True
+
+         A = repeat(
+             torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
+             "n -> d n",
+             d=self.d_inner,
+         ).contiguous()
+         A_log = torch.log(A)
+         self.A_log = nn.Parameter(A_log)
+         self.A_log._no_weight_decay = True
+
+         self.D = nn.Parameter(torch.ones(self.d_inner, device=device))
+         self.D._no_weight_decay = True
+
+         self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
+
+     def forward(self, hidden_states, inference_params=None):
+         batch, seqlen, dim = hidden_states.shape
+
+         conv_state, ssm_state = None, None
+         if inference_params is not None:
+             conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
+             if inference_params.seqlen_offset > 0:
+                 out, _, _ = self.step(hidden_states, conv_state, ssm_state)
+                 return out
+
+         xz = rearrange(
+             self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
+             "d (b l) -> b d l",
+             l=seqlen,
+         )
+         if self.in_proj.bias is not None:
+             xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
+
+         A = -torch.exp(self.A_log.float())
+
+         x, z = xz.chunk(2, dim=1)
+         # Multi-branch causal convs; outputs are fused by softmax gates below
+         conv_outputs = []
+         if conv_state is not None:
+             for state, conv_layer, dilation in zip(conv_state, self.conv1d_layers, self.conv_dilations):
+                 pad_len = state.shape[-1] - x.shape[-1]
+                 if pad_len >= 0:
+                     state.copy_(F.pad(x, (pad_len, 0)))
+                 else:
+                     state.copy_(x[..., -state.shape[-1] :])
+                 if dilation > 1:
+                     xi = causal_dilated_conv1d_fn(
+                         x=x,
+                         weight=rearrange(conv_layer.weight, "d 1 w -> d w"),
+                         bias=conv_layer.bias,
+                         activation=self.activation,
+                         dilation=dilation,
+                     )
+                 else:
+                     xi = self.act(conv_layer(x)[..., :seqlen])
+                 conv_outputs.append(xi)
+         else:
+             for conv_layer, dilation in zip(self.conv1d_layers, self.conv_dilations):
+                 if dilation > 1:
+                     xi = causal_dilated_conv1d_fn(
+                         x=x,
+                         weight=rearrange(conv_layer.weight, "d 1 w -> d w"),
+                         bias=conv_layer.bias,
+                         activation=self.activation,
+                         dilation=dilation,
+                     )
+                 else:
+                     xi = self.act(conv_layer(x)[..., :seqlen])
+                 conv_outputs.append(xi)
+         gate = torch.softmax(self.conv_gates, dim=0)
+         x = sum(g * xi for g, xi in zip(gate, conv_outputs))
+
+         x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))
+         dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+         dt = self.dt_proj.weight @ dt.t()
+         dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
+         B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
+         C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
+         y = selective_scan_fn(
+             x,
+             dt,
+             A,
+             B,
+             C,
+             self.D.float(),
+             z=z,
+             delta_bias=self.dt_proj.bias.float(),
+             delta_softplus=True,
+             return_last_state=ssm_state is not None,
+         )
+         if ssm_state is not None:
+             y, last_state = y
+             ssm_state.copy_(last_state)
+         y = rearrange(y, "b d l -> b l d")
+         out = self.out_proj(y)
+         return out
+
+     def step(self, hidden_states, conv_state, ssm_state):
+         dtype = hidden_states.dtype
+         assert hidden_states.shape[1] == 1
+         xz = self.in_proj(hidden_states.squeeze(1))
+         x, z = xz.chunk(2, dim=-1)
+
+         conv_outputs = []
+         for state, conv_layer, dilation in zip(conv_state, self.conv1d_layers, self.conv_dilations):
+             if dilation > 1:
+                 xi = causal_dilated_conv1d_update(
+                     x,
+                     state,
+                     rearrange(conv_layer.weight, "d 1 w -> d w"),
+                     conv_layer.bias,
+                     self.activation,
+                     dilation=dilation,
+                 )
+             else:
+                 state.copy_(torch.roll(state, shifts=-1, dims=-1))
+                 state[:, :, -1] = x
+                 xi = torch.sum(state * rearrange(conv_layer.weight, "d 1 w -> d w"), dim=-1)
+                 if conv_layer.bias is not None:
+                     xi = xi + conv_layer.bias
+                 xi = self.act(xi).to(dtype=dtype)
+             conv_outputs.append(xi)
+         gate = torch.softmax(self.conv_gates, dim=0)
+         x = sum(g * xi for g, xi in zip(gate, conv_outputs))
+
+         x_db = self.x_proj(x)
+         dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+         dt = F.linear(dt, self.dt_proj.weight)
+         A = -torch.exp(self.A_log.float())
+
+         dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
+         dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
+         dB = torch.einsum("bd,bn->bdn", dt, B)
+         ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
+         y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
+         y = y + self.D.to(dtype) * x
+         y = y * self.act(z)
+
+         out = self.out_proj(y)
+         return out.unsqueeze(1), conv_state, ssm_state
+
+     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+         device = self.out_proj.weight.device
+         conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
+         conv_state = [
+             torch.zeros(
+                 batch_size,
+                 self.d_inner,  # = int(expand * d_model); avoids a float tensor dim
+                 state_len,
+                 device=device,
+                 dtype=conv_dtype,
+             )
+             for state_len in self.conv_state_lens
+         ]
+         ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
+         ssm_state = torch.zeros(
+             batch_size, self.d_inner, self.d_state, device=device, dtype=ssm_dtype
+         )
+         return conv_state, ssm_state
+
+     def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
+         assert self.layer_idx is not None
+         if self.layer_idx not in inference_params.key_value_memory_dict:
+             conv_state = [
+                 torch.zeros(
+                     batch_size,
+                     self.d_inner,
+                     state_len,
+                     device=self.conv1d.weight.device,
+                     dtype=self.conv1d.weight.dtype,
+                 )
+                 for state_len in self.conv_state_lens
+             ]
+             ssm_state = torch.zeros(
+                 batch_size,
+                 self.d_inner,
+                 self.d_state,
+                 device=self.dt_proj.weight.device,
+                 dtype=self.dt_proj.weight.dtype,
+             )
+             inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
+         else:
+             conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
+             if initialize_states:
+                 for s in conv_state:
+                     s.zero_()
+                 ssm_state.zero_()
+         return conv_state, ssm_state
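
The dense `forward` and token-by-token `step` paths above are meant to agree; a small consistency sketch (the shapes and tolerance are arbitrary choices, not part of the package):

```python
import torch
from lite_mamba import Mamba

torch.manual_seed(0)
m = Mamba(d_model=32, d_conv=3, conv_dilations=(1, 2)).eval()

x = torch.randn(1, 8, 32)
with torch.no_grad():
    y_full = m(x)  # dense path over the whole sequence

    # Replay the same tokens through step(); outputs should match the dense path.
    conv_state, ssm_state = m.allocate_inference_cache(batch_size=1, max_seqlen=8)
    y_steps = torch.cat(
        [m.step(x[:, t : t + 1], conv_state, ssm_state)[0] for t in range(8)],
        dim=1,
    )
print(torch.allclose(y_full, y_steps, atol=1e-4))  # True, up to float tolerance
```

The same caches can instead be managed through `inference_params`, as the README's streaming recipe shows.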
+++ lite_mamba/selective_scan.py
@@ -0,0 +1,66 @@
+ import torch
+ import torch.nn.functional as F
+ from einops import rearrange, repeat
+
+ # Pure-PyTorch reference selective scan (no CUDA/Triton deps)
+ def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
+                       return_last_state=False):
+     return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
+
+
+ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
+                        return_last_state=False):
+     dtype_in = u.dtype
+     u = u.float()
+     delta = delta.float()
+     if delta_bias is not None:
+         delta = delta + delta_bias[..., None].float()
+     if delta_softplus:
+         delta = F.softplus(delta)
+     batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
+     is_variable_B = B.dim() >= 3
+     is_variable_C = C.dim() >= 3
+     if A.is_complex():
+         if is_variable_B:
+             B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
+         if is_variable_C:
+             C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
+     else:
+         B = B.float()
+         C = C.float()
+     x = A.new_zeros((batch, dim, dstate))
+     ys = []
+     deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
+     if not is_variable_B:
+         deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
+     else:
+         if B.dim() == 3:
+             deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
+         else:
+             B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
+             deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
+     if is_variable_C and C.dim() == 4:
+         C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
+     last_state = None
+     for i in range(u.shape[2]):
+         x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
+         if not is_variable_C:
+             y = torch.einsum('bdn,dn->bd', x, C)
+         else:
+             if C.dim() == 3:
+                 y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
+             else:
+                 y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
+         if i == u.shape[2] - 1:
+             last_state = x
+         if y.is_complex():
+             y = y.real * 2
+         ys.append(y)
+     y = torch.stack(ys, dim=2)  # (batch dim L)
+     out = y if D is None else y + u * rearrange(D, "d -> d 1")
+     if z is not None:
+         out = out * F.silu(z)
+     out = out.to(dtype=dtype_in)
+     if return_last_state:
+         return out, last_state
+     return out
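
A minimal shape walk-through of the reference scan, using the input-dependent ("variable") B and C layout that `mamba_simple.py` passes in; the tensor names and sizes here are illustrative only:

```python
import torch
from lite_mamba.selective_scan import selective_scan_fn

B_, D_, N_, L_ = 2, 8, 4, 16
u = torch.randn(B_, D_, L_)        # input sequence per channel
delta = torch.rand(B_, D_, L_)     # per-step timescales (positive)
A = -torch.rand(D_, N_)            # negative real parts keep the recurrence stable
Bm = torch.randn(B_, N_, L_)       # input-dependent B -> (batch, dstate, L)
Cm = torch.randn(B_, N_, L_)       # input-dependent C -> (batch, dstate, L)
D_skip = torch.ones(D_)            # per-channel skip connection

y, last_state = selective_scan_fn(u, delta, A, Bm, Cm, D=D_skip,
                                  delta_softplus=True, return_last_state=True)
print(y.shape, last_state.shape)   # torch.Size([2, 8, 16]) torch.Size([2, 8, 4])
```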
+++ lite_mamba.egg-info/PKG-INFO
@@ -0,0 +1,95 @@
+ Metadata-Version: 2.4
+ Name: lite-mamba
+ Version: 0.1.4
+ Summary: Pure-PyTorch lightweight Mamba with multi-dilated causal conv front-end
+ Home-page: https://github.com/Mrrobi/lite_mamba
+ Author: Md Robiuddin
+ Author-email: mrrobi040@gmail.com
+ License: Apache-2.0
+ Project-URL: Homepage, https://github.com/Mrrobi/lite_mamba
+ Classifier: License :: OSI Approved :: Apache Software License
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3 :: Only
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.9
+ Description-Content-Type: text/markdown
+ Requires-Dist: torch>=2.0
+ Requires-Dist: einops>=0.6
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: classifier
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: home-page
+ Dynamic: license
+ Dynamic: project-url
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary
+
+ # lite-mamba
+
+ A minimal, pure-PyTorch implementation of Mamba with a multi-dilated causal depthwise conv front-end. No CUDA/Triton build is needed; it runs on CPU or GPU with standard PyTorch ops.
+
+ ## Install
+ ```bash
+ pip install torch einops
+ pip install lite-mamba
+ ```
+
+ ## Usage
+ ```python
+ from lite_mamba import Mamba
+ import torch
+
+ x = torch.randn(2, 128, 512)  # (batch, seq, d_model)
+ m = Mamba(d_model=512, d_conv=3, conv_dilations=(1,2,4,8))
+ y = m(x)
+ print(y.shape)  # (2, 128, 512)
+ ```
+
+ ## API quick reference
+ `Mamba(d_model, d_state=16, d_conv=4, conv_dilations=(1,), expand=2, dt_rank="auto", dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, conv_bias=True, bias=False, use_fast_path=False, layer_idx=None, device=None, dtype=None)`
+
+ - `d_model` (int, required): input/output embedding size.
+ - `d_state` (int, default 16): SSM state dimension per channel. Larger values give longer memory but cost more compute.
+ - `d_conv` (int, default 4): depthwise conv kernel size for each branch.
+ - `conv_dilations` (tuple[int], default `(1,)`): dilation per branch. Multiple values create parallel dilated convs; each branch has an effective receptive field of `(d_conv-1)*dilation + 1` timesteps.
+ - `expand` (float, default 2): inner width multiplier; sets `d_inner = int(expand * d_model)`.
+ - `dt_rank` (int or "auto", default "auto"): rank of the delta projection. "auto" sets `ceil(d_model/16)`.
+ - `dt_min`, `dt_max` (float, defaults 1e-3 / 1e-1): log-uniform range for delta initialization.
+ - `dt_init` ("random" | "constant", default "random") and `dt_scale`, `dt_init_floor`: control delta init magnitude/stability.
+ - `conv_bias` (bool, default True): include bias in the depthwise convs.
+ - `bias` (bool, default False): include bias in the input/output linear projections.
+ - `use_fast_path` (bool): ignored in this pure-PyTorch build; kept for API compatibility.
+ - `layer_idx` (int | None): identifier for streaming cache registration; required when using `allocate_inference_cache` + `inference_params`.
+ - `device`, `dtype`: standard module factory kwargs.
+
+ ### Inference / streaming helpers
+ - `allocate_inference_cache(batch_size, max_seqlen, dtype=None)`: preallocates conv and SSM state buffers for step-wise decoding.
+ - `step(hidden_states, conv_state, ssm_state)`: single-token forward (expects `hidden_states` of shape `(B, 1, d_model)`).
+ - `forward(..., inference_params)`: if `inference_params` carries cached states (via `key_value_memory_dict` and `seqlen_offset`), they are used for streaming.
+
+ ## Highlights
+ - Multi-branch causal dilated convs (weighted sum via learned gates).
+ - Pure Python: no custom C++/CUDA or Triton kernels.
+ - Streaming support via per-branch conv states and SSM state caching.
+
+ ## Practical setups
+ - **Local modeling / small context**: `d_conv=3`, `conv_dilations=(1,2,4)`, `d_state=8–16`, `expand=2`.
+ - **Longer context**: widen `conv_dilations` (e.g., `(1,2,4,8,16)`) or increase `d_state` to 32; expect higher memory/compute.
+ - **Streaming/AR decoding**: call `allocate_inference_cache` once per layer, pass `inference_params` during forward, and use `step` inside your generation loop.
+ - **Stability first**: keep `dt_min` >= 1e-4 and `dt_init_floor` small; leave the defaults unless you observe drift or exploding activations.
+
+ ## Notes
+ - Tune `conv_dilations` to adjust the receptive field; keep kernels small (e.g., 3–5) to avoid excessive left-padding.
+ - The `use_fast_path` flag is ignored here (kept for API compatibility).
+ - The reference selective scan is implemented in PyTorch for portability; faster fused kernels are intentionally omitted.
+
+ ## License
+ Apache-2.0
+++ lite_mamba.egg-info/SOURCES.txt
@@ -0,0 +1,12 @@
+ README.md
+ pyproject.toml
+ setup.py
+ lite_mamba/__init__.py
+ lite_mamba/causal_dilated_conv1d.py
+ lite_mamba/mamba_simple.py
+ lite_mamba/selective_scan.py
+ lite_mamba.egg-info/PKG-INFO
+ lite_mamba.egg-info/SOURCES.txt
+ lite_mamba.egg-info/dependency_links.txt
+ lite_mamba.egg-info/requires.txt
+ lite_mamba.egg-info/top_level.txt
+++ lite_mamba.egg-info/requires.txt
@@ -0,0 +1,2 @@
+ torch>=2.0
+ einops>=0.6
+++ lite_mamba.egg-info/top_level.txt
@@ -0,0 +1 @@
+ lite_mamba
+++ pyproject.toml
@@ -0,0 +1,3 @@
+ [build-system]
+ requires = ["setuptools>=61", "wheel"]
+ build-backend = "setuptools.build_meta"
+++ setup.cfg
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
+++ setup.py
@@ -0,0 +1,38 @@
+ from pathlib import Path
+
+ from setuptools import find_packages, setup
+
+ ROOT = Path(__file__).parent
+ README = (ROOT / "README.md").read_text(encoding="utf-8")
+
+ setup(
+     name="lite-mamba",
+     version="0.1.4",
+     description="Pure-PyTorch lightweight Mamba with multi-dilated causal conv front-end",
+     long_description=README,
+     long_description_content_type="text/markdown",
+     author="Md Robiuddin",
+     author_email="mrrobi040@gmail.com",
+     url="https://github.com/Mrrobi/lite_mamba",
+     project_urls={
+         "Homepage": "https://github.com/Mrrobi/lite_mamba",
+     },
+     packages=find_packages(exclude=("tests", "docs")),
+     python_requires=">=3.9",
+     install_requires=[
+         "torch>=2.0",
+         "einops>=0.6",
+     ],
+     license="Apache-2.0",
+     classifiers=[
+         "License :: OSI Approved :: Apache Software License",
+         "Programming Language :: Python :: 3",
+         "Programming Language :: Python :: 3 :: Only",
+         "Programming Language :: Python :: 3.9",
+         "Programming Language :: Python :: 3.10",
+         "Programming Language :: Python :: 3.11",
+         "Programming Language :: Python :: 3.12",
+         "Intended Audience :: Science/Research",
+         "Topic :: Scientific/Engineering :: Artificial Intelligence",
+     ],
+ )