lite-mamba 0.1.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lite_mamba-0.1.4/PKG-INFO +95 -0
- lite_mamba-0.1.4/README.md +61 -0
- lite_mamba-0.1.4/lite_mamba/__init__.py +3 -0
- lite_mamba-0.1.4/lite_mamba/causal_dilated_conv1d.py +57 -0
- lite_mamba-0.1.4/lite_mamba/mamba_simple.py +282 -0
- lite_mamba-0.1.4/lite_mamba/selective_scan.py +66 -0
- lite_mamba-0.1.4/lite_mamba.egg-info/PKG-INFO +95 -0
- lite_mamba-0.1.4/lite_mamba.egg-info/SOURCES.txt +12 -0
- lite_mamba-0.1.4/lite_mamba.egg-info/dependency_links.txt +1 -0
- lite_mamba-0.1.4/lite_mamba.egg-info/requires.txt +2 -0
- lite_mamba-0.1.4/lite_mamba.egg-info/top_level.txt +1 -0
- lite_mamba-0.1.4/pyproject.toml +3 -0
- lite_mamba-0.1.4/setup.cfg +4 -0
- lite_mamba-0.1.4/setup.py +38 -0
lite_mamba-0.1.4/PKG-INFO

@@ -0,0 +1,95 @@
+Metadata-Version: 2.4
+Name: lite-mamba
+Version: 0.1.4
+Summary: Pure-PyTorch lightweight Mamba with multi-dilated causal conv front-end
+Home-page: https://github.com/Mrrobi/lite_mamba
+Author: Md Robiuddin
+Author-email: mrrobi040@gmail.com
+License: Apache-2.0
+Project-URL: Homepage, https://github.com/Mrrobi/lite_mamba
+Classifier: License :: OSI Approved :: Apache Software License
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3 :: Only
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Intended Audience :: Science/Research
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Requires-Python: >=3.9
+Description-Content-Type: text/markdown
+Requires-Dist: torch>=2.0
+Requires-Dist: einops>=0.6
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: license
+Dynamic: project-url
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
+
+# lite-mamba
+
+A minimal, pure-PyTorch version of Mamba with a multi-dilated causal depthwise conv front-end. No CUDA/Triton build is needed; it works on CPU or GPU with standard PyTorch ops.
+
+## Install
+```bash
+pip install torch einops
+pip install lite-mamba
+```
+
+## Usage
+```python
+from lite_mamba import Mamba
+import torch
+
+x = torch.randn(2, 128, 512)  # (batch, seq, d_model)
+m = Mamba(d_model=512, d_conv=3, conv_dilations=(1, 2, 4, 8))
+y = m(x)
+print(y.shape)  # (2, 128, 512)
+```
+
+## API quick reference
+`Mamba(d_model, d_state=16, d_conv=4, conv_dilations=(1,), expand=2, dt_rank="auto", dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, conv_bias=True, bias=False, use_fast_path=False, layer_idx=None, device=None, dtype=None)`
+
+- `d_model` (int, required): input/output embedding size.
+- `d_state` (int, default 16): SSM state dimension per channel. Larger values give longer memory but increase compute.
+- `d_conv` (int, default 4): depthwise conv kernel size for each branch.
+- `conv_dilations` (tuple[int], default `(1,)`): dilation per branch. Multiple values create parallel dilated convs; the effective receptive field per branch is `(d_conv-1)*dilation`.
+- `expand` (float, default 2): inner width multiplier; sets `d_inner = expand * d_model`.
+- `dt_rank` (int or "auto", default "auto"): rank of the delta projection. "auto" sets `ceil(d_model/16)`.
+- `dt_min`, `dt_max` (float, defaults 1e-3 / 1e-1): log-uniform range for delta initialization.
+- `dt_init` ("random" | "constant", default "random") and `dt_scale`, `dt_init_floor`: control delta init magnitude/stability.
+- `conv_bias` (bool, default True): include bias in the depthwise convs.
+- `bias` (bool, default False): include bias in the input/output linear projections.
+- `use_fast_path` (bool): ignored in this pure-PyTorch build; kept for API compatibility.
+- `layer_idx` (int | None): identifier for streaming cache registration; required when using `allocate_inference_cache` + `inference_params`.
+- `device`, `dtype`: standard module factory kwargs.
+
+### Inference / streaming helpers
+- `allocate_inference_cache(batch_size, max_seqlen, dtype=None)`: preallocates conv and SSM state buffers for step-wise decoding.
+- `step(hidden_states, conv_state, ssm_state)`: single-token forward (expects `hidden_states` with shape `(B, 1, d_model)`).
+- `forward(..., inference_params)`: if `inference_params` has cached states (with `key_value_memory_dict` and `seqlen_offset`), uses them for streaming.
+
+## Highlights
+- Multi-branch causal dilated convs (weighted sum via learned gates).
+- Pure Python: no custom C++/CUDA or Triton kernels.
+- Streaming support via per-branch conv states and SSM state caching.
+
+## Practical setups
+- **Local modeling / small context**: `d_conv=3`, `conv_dilations=(1,2,4)`, `d_state=8–16`, `expand=2`.
+- **Longer context**: widen `conv_dilations` (e.g., `(1,2,4,8,16)`) or increase `d_state` to 32; expect higher memory/compute.
+- **Streaming/AR decoding**: call `allocate_inference_cache` once per layer, pass `inference_params` during forward, and use `step` inside your generation loop.
+- **Stability first**: keep `dt_min` >= 1e-4 and `dt_init_floor` small; leave the defaults unless you observe drift or exploding activations.
+
+## Notes
+- Adjust `conv_dilations` to tune the receptive field; keep kernels small (e.g., 3–5) to avoid excessive padding.
+- The `use_fast_path` flag is ignored here (kept for API compatibility).
+- The reference selective scan is implemented in PyTorch for portability; faster fused kernels are omitted intentionally.
+
+## License
+Apache-2.0
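The embedded README gives the per-branch receptive field as `(d_conv-1)*dilation`. As a quick worked check of that formula for the dilations used in the usage example (plain arithmetic, not part of the package):

```python
# Per-branch receptive field, in past timesteps, per the README formula.
d_conv = 3
for dilation in (1, 2, 4, 8):
    print(f"dilation={dilation}: reaches {(d_conv - 1) * dilation} steps back")
# dilation=1 reaches 2 steps back; dilation=8 reaches 16 steps back.
```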
lite_mamba-0.1.4/README.md

@@ -0,0 +1,61 @@
+# lite-mamba
+
+A minimal, pure-PyTorch version of Mamba with a multi-dilated causal depthwise conv front-end. No CUDA/Triton build is needed; it works on CPU or GPU with standard PyTorch ops.
+
+## Install
+```bash
+pip install torch einops
+pip install lite-mamba
+```
+
+## Usage
+```python
+from lite_mamba import Mamba
+import torch
+
+x = torch.randn(2, 128, 512)  # (batch, seq, d_model)
+m = Mamba(d_model=512, d_conv=3, conv_dilations=(1, 2, 4, 8))
+y = m(x)
+print(y.shape)  # (2, 128, 512)
+```
+
+## API quick reference
+`Mamba(d_model, d_state=16, d_conv=4, conv_dilations=(1,), expand=2, dt_rank="auto", dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, conv_bias=True, bias=False, use_fast_path=False, layer_idx=None, device=None, dtype=None)`
+
+- `d_model` (int, required): input/output embedding size.
+- `d_state` (int, default 16): SSM state dimension per channel. Larger values give longer memory but increase compute.
+- `d_conv` (int, default 4): depthwise conv kernel size for each branch.
+- `conv_dilations` (tuple[int], default `(1,)`): dilation per branch. Multiple values create parallel dilated convs; the effective receptive field per branch is `(d_conv-1)*dilation`.
+- `expand` (float, default 2): inner width multiplier; sets `d_inner = expand * d_model`.
+- `dt_rank` (int or "auto", default "auto"): rank of the delta projection. "auto" sets `ceil(d_model/16)`.
+- `dt_min`, `dt_max` (float, defaults 1e-3 / 1e-1): log-uniform range for delta initialization.
+- `dt_init` ("random" | "constant", default "random") and `dt_scale`, `dt_init_floor`: control delta init magnitude/stability.
+- `conv_bias` (bool, default True): include bias in the depthwise convs.
+- `bias` (bool, default False): include bias in the input/output linear projections.
+- `use_fast_path` (bool): ignored in this pure-PyTorch build; kept for API compatibility.
+- `layer_idx` (int | None): identifier for streaming cache registration; required when using `allocate_inference_cache` + `inference_params`.
+- `device`, `dtype`: standard module factory kwargs.
+
+### Inference / streaming helpers
+- `allocate_inference_cache(batch_size, max_seqlen, dtype=None)`: preallocates conv and SSM state buffers for step-wise decoding.
+- `step(hidden_states, conv_state, ssm_state)`: single-token forward (expects `hidden_states` with shape `(B, 1, d_model)`).
+- `forward(..., inference_params)`: if `inference_params` has cached states (with `key_value_memory_dict` and `seqlen_offset`), uses them for streaming.
+
+## Highlights
+- Multi-branch causal dilated convs (weighted sum via learned gates).
+- Pure Python: no custom C++/CUDA or Triton kernels.
+- Streaming support via per-branch conv states and SSM state caching.
+
+## Practical setups
+- **Local modeling / small context**: `d_conv=3`, `conv_dilations=(1,2,4)`, `d_state=8–16`, `expand=2`.
+- **Longer context**: widen `conv_dilations` (e.g., `(1,2,4,8,16)`) or increase `d_state` to 32; expect higher memory/compute.
+- **Streaming/AR decoding**: call `allocate_inference_cache` once per layer, pass `inference_params` during forward, and use `step` inside your generation loop.
+- **Stability first**: keep `dt_min` >= 1e-4 and `dt_init_floor` small; leave the defaults unless you observe drift or exploding activations.
+
+## Notes
+- Adjust `conv_dilations` to tune the receptive field; keep kernels small (e.g., 3–5) to avoid excessive padding.
+- The `use_fast_path` flag is ignored here (kept for API compatibility).
+- The reference selective scan is implemented in PyTorch for portability; faster fused kernels are omitted intentionally.
+
+## License
+Apache-2.0
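The streaming bullets in the README can be made concrete. Below is a minimal decode-loop sketch, assuming a hypothetical `SimpleInferenceParams` container that exposes the two attributes `forward` checks for (`key_value_memory_dict` and `seqlen_offset`); lite-mamba itself does not ship such a class:

```python
import torch
from lite_mamba import Mamba

class SimpleInferenceParams:
    """Hypothetical cache container; only the two attributes below are used."""
    def __init__(self):
        self.key_value_memory_dict = {}  # layer_idx -> (conv_state, ssm_state)
        self.seqlen_offset = 0

m = Mamba(d_model=64, d_conv=3, conv_dilations=(1, 2), layer_idx=0)  # layer_idx is required for caching
params = SimpleInferenceParams()

prompt = torch.randn(1, 16, 64)                # (batch, seq, d_model)
y = m(prompt, inference_params=params)         # prefill pass populates the caches
params.seqlen_offset += prompt.shape[1]

for _ in range(8):                             # autoregressive decode, one token at a time
    tok = torch.randn(1, 1, 64)                # stand-in for the next token embedding
    y = m(tok, inference_params=params)        # seqlen_offset > 0 routes through step()
    params.seqlen_offset += 1
```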
lite_mamba-0.1.4/lite_mamba/causal_dilated_conv1d.py

@@ -0,0 +1,57 @@
+import torch
+import torch.nn.functional as F
+
+
+def _apply_activation(x, activation):
+    if activation is None or activation == "identity":
+        return x
+    if activation in ("silu", "swish"):
+        return F.silu(x)
+    if activation == "relu":
+        return F.relu(x)
+    raise ValueError(f"Unsupported activation: {activation}")
+
+
+def causal_dilated_conv1d_fn(x, weight, bias=None, activation=None, dilation=1):
+    """
+    Depthwise causal 1D convolution with dilation.
+    x: (B, D, L)
+    weight: (D, W)
+    bias: (D,) or None
+    """
+    if dilation < 1:
+        raise ValueError(f"dilation must be >= 1, got {dilation}")
+    pad = dilation * (weight.shape[-1] - 1)
+    x = F.pad(x, (pad, 0))  # left-pad only, so the conv stays causal
+    y = F.conv1d(
+        x,
+        weight.unsqueeze(1),
+        bias=bias,
+        stride=1,
+        padding=0,
+        dilation=dilation,
+        groups=weight.shape[0],
+    )
+    return _apply_activation(y, activation)
+
+
+def causal_dilated_conv1d_update(x, conv_state, weight, bias=None, activation=None, dilation=1):
+    """
+    Single-step causal dilated conv update.
+    x: (B, D)
+    conv_state: (B, D, S) where S = dilation * (W - 1) + 1
+    weight: (D, W)
+    bias: (D,) or None
+    """
+    if dilation < 1:
+        raise ValueError(f"dilation must be >= 1, got {dilation}")
+    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
+    conv_state[:, :, -1] = x
+
+    idx = torch.arange(0, weight.shape[-1], device=conv_state.device) * dilation
+    pos = conv_state.shape[-1] - 1 - idx  # newest sample first, then strides of `dilation`
+    values = conv_state.index_select(-1, pos)
+    y = torch.sum(values * weight.flip(-1).unsqueeze(0), dim=-1)  # flip pairs w[-1] with the newest sample, matching F.conv1d
+    if bias is not None:
+        y = y + bias
+    return _apply_activation(y, activation)
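As a sanity sketch (not part of the package), the single-step update should reproduce the final output of the full-sequence convolution once the state holds the preceding inputs, given the tap alignment in the file above; the names and sizes below are arbitrary:

```python
import torch
from lite_mamba.causal_dilated_conv1d import (
    causal_dilated_conv1d_fn,
    causal_dilated_conv1d_update,
)

B, D, L, W, dil = 2, 4, 32, 3, 2
x = torch.randn(B, D, L)
w = torch.randn(D, W)

y_full = causal_dilated_conv1d_fn(x, w, dilation=dil)       # (B, D, L)

S = dil * (W - 1) + 1                                       # state length = 5
state = x[:, :, L - 1 - S : L - 1].clone()                  # the S inputs before the last one
y_step = causal_dilated_conv1d_update(x[:, :, -1], state, w, dilation=dil)
print(torch.allclose(y_step, y_full[:, :, -1], atol=1e-5))  # expected: True
```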
lite_mamba-0.1.4/lite_mamba/mamba_simple.py

@@ -0,0 +1,282 @@
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from einops import rearrange, repeat
+
+from .selective_scan import selective_scan_fn
+from .causal_dilated_conv1d import causal_dilated_conv1d_fn, causal_dilated_conv1d_update
+
+
+class Mamba(nn.Module):
+    def __init__(
+        self,
+        d_model,
+        d_state=16,
+        d_conv=4,
+        conv_dilations=(1,),
+        expand=2,
+        dt_rank="auto",
+        dt_min=0.001,
+        dt_max=0.1,
+        dt_init="random",
+        dt_scale=1.0,
+        dt_init_floor=1e-4,
+        conv_bias=True,
+        bias=False,
+        use_fast_path=False,  # kept for API; ignored in this pure-Python build
+        layer_idx=None,
+        device=None,
+        dtype=None,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.d_model = d_model
+        self.d_state = d_state
+        self.d_conv = d_conv
+        self.conv_dilations = tuple(conv_dilations)
+        self.num_conv_branches = len(self.conv_dilations)
+        self.conv_state_lens = [(self.d_conv - 1) * d + 1 for d in self.conv_dilations]
+        self.expand = expand
+        self.d_inner = int(self.expand * self.d_model)
+        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
+        self.use_fast_path = False
+        self.layer_idx = layer_idx
+
+        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
+
+        self.conv1d_layers = nn.ModuleList(
+            [
+                nn.Conv1d(
+                    in_channels=self.d_inner,
+                    out_channels=self.d_inner,
+                    bias=conv_bias,
+                    kernel_size=d_conv,
+                    groups=self.d_inner,
+                    padding=d * (d_conv - 1),
+                    dilation=d,
+                    **factory_kwargs,
+                )
+                for d in self.conv_dilations
+            ]
+        )
+        self.conv1d = self.conv1d_layers[0]
+        self.conv_gates = nn.Parameter(torch.ones(self.num_conv_branches))
+
+        self.activation = "silu"
+        self.act = nn.SiLU()
+
+        self.x_proj = nn.Linear(
+            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
+        )
+        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
+
+        dt_init_std = self.dt_rank**-0.5 * dt_scale
+        if dt_init == "constant":
+            nn.init.constant_(self.dt_proj.weight, dt_init_std)
+        elif dt_init == "random":
+            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
+        else:
+            raise NotImplementedError
+
+        dt = torch.exp(
+            torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+            + math.log(dt_min)
+        ).clamp(min=dt_init_floor)
+        inv_dt = dt + torch.log(-torch.expm1(-dt))
+        with torch.no_grad():
+            self.dt_proj.bias.copy_(inv_dt)
+        self.dt_proj.bias._no_reinit = True
+
+        A = repeat(
+            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
+            "n -> d n",
+            d=self.d_inner,
+        ).contiguous()
+        A_log = torch.log(A)
+        self.A_log = nn.Parameter(A_log)
+        self.A_log._no_weight_decay = True
+
+        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))
+        self.D._no_weight_decay = True
+
+        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
+
+    def forward(self, hidden_states, inference_params=None):
+        batch, seqlen, dim = hidden_states.shape
+
+        conv_state, ssm_state = None, None
+        if inference_params is not None:
+            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
+            if inference_params.seqlen_offset > 0:
+                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
+                return out
+
+        xz = rearrange(
+            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
+            "d (b l) -> b d l",
+            l=seqlen,
+        )
+        if self.in_proj.bias is not None:
+            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
+
+        A = -torch.exp(self.A_log.float())
+        use_multi_branch = self.num_conv_branches > 1
+
+        x, z = xz.chunk(2, dim=1)
+        # Convs
+        conv_outputs = []
+        if conv_state is not None:
+            for state, conv_layer, dilation in zip(conv_state, self.conv1d_layers, self.conv_dilations):
+                pad_len = state.shape[-1] - x.shape[-1]
+                if pad_len >= 0:
+                    state.copy_(F.pad(x, (pad_len, 0)))
+                else:
+                    state.copy_(x[..., -state.shape[-1] :])
+                if dilation > 1:
+                    xi = causal_dilated_conv1d_fn(
+                        x=x,
+                        weight=rearrange(conv_layer.weight, "d 1 w -> d w"),
+                        bias=conv_layer.bias,
+                        activation=self.activation,
+                        dilation=dilation,
+                    )
+                else:
+                    xi = self.act(conv_layer(x)[..., :seqlen])
+                conv_outputs.append(xi)
+        else:
+            for conv_layer, dilation in zip(self.conv1d_layers, self.conv_dilations):
+                if dilation > 1:
+                    xi = causal_dilated_conv1d_fn(
+                        x=x,
+                        weight=rearrange(conv_layer.weight, "d 1 w -> d w"),
+                        bias=conv_layer.bias,
+                        activation=self.activation,
+                        dilation=dilation,
+                    )
+                else:
+                    xi = self.act(conv_layer(x)[..., :seqlen])
+                conv_outputs.append(xi)
+        gate = torch.softmax(self.conv_gates, dim=0)
+        x = sum(g * xi for g, xi in zip(gate, conv_outputs))
+
+        x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))
+        dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+        dt = self.dt_proj.weight @ dt.t()
+        dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
+        B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
+        C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
+        y = selective_scan_fn(
+            x,
+            dt,
+            A,
+            B,
+            C,
+            self.D.float(),
+            z=z,
+            delta_bias=self.dt_proj.bias.float(),
+            delta_softplus=True,
+            return_last_state=ssm_state is not None,
+        )
+        if ssm_state is not None:
+            y, last_state = y
+            ssm_state.copy_(last_state)
+        y = rearrange(y, "b d l -> b l d")
+        out = self.out_proj(y)
+        return out
+
+    def step(self, hidden_states, conv_state, ssm_state):
+        dtype = hidden_states.dtype
+        assert hidden_states.shape[1] == 1
+        xz = self.in_proj(hidden_states.squeeze(1))
+        x, z = xz.chunk(2, dim=-1)
+
+        conv_outputs = []
+        for state, conv_layer, dilation in zip(conv_state, self.conv1d_layers, self.conv_dilations):
+            if dilation > 1:
+                xi = causal_dilated_conv1d_update(
+                    x,
+                    state,
+                    rearrange(conv_layer.weight, "d 1 w -> d w"),
+                    conv_layer.bias,
+                    self.activation,
+                    dilation=dilation,
+                )
+            else:
+                state.copy_(torch.roll(state, shifts=-1, dims=-1))
+                state[:, :, -1] = x
+                xi = torch.sum(state * rearrange(conv_layer.weight, "d 1 w -> d w"), dim=-1)
+                if conv_layer.bias is not None:
+                    xi = xi + conv_layer.bias
+                xi = self.act(xi).to(dtype=dtype)
+            conv_outputs.append(xi)
+        gate = torch.softmax(self.conv_gates, dim=0)
+        x = sum(g * xi for g, xi in zip(gate, conv_outputs))
+
+        x_db = self.x_proj(x)
+        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+        dt = F.linear(dt, self.dt_proj.weight)
+        A = -torch.exp(self.A_log.float())
+
+        dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
+        dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
+        dB = torch.einsum("bd,bn->bdn", dt, B)
+        ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
+        y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
+        y = y + self.D.to(dtype) * x
+        y = y * self.act(z)
+
+        out = self.out_proj(y)
+        return out.unsqueeze(1), conv_state, ssm_state
+
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        device = self.out_proj.weight.device
+        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
+        conv_state = [
+            torch.zeros(
+                batch_size,
+                self.d_inner,
+                state_len,
+                device=device,
+                dtype=conv_dtype,
+            )
+            for state_len in self.conv_state_lens
+        ]
+        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
+        ssm_state = torch.zeros(
+            batch_size, self.d_inner, self.d_state, device=device, dtype=ssm_dtype
+        )
+        return conv_state, ssm_state
+
+    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
+        assert self.layer_idx is not None
+        if self.layer_idx not in inference_params.key_value_memory_dict:
+            conv_state = [
+                torch.zeros(
+                    batch_size,
+                    self.d_inner,
+                    state_len,
+                    device=self.conv1d.weight.device,
+                    dtype=self.conv1d.weight.dtype,
+                )
+                for state_len in self.conv_state_lens
+            ]
+            ssm_state = torch.zeros(
+                batch_size,
+                self.d_inner,
+                self.d_state,
+                device=self.dt_proj.weight.device,
+                dtype=self.dt_proj.weight.dtype,
+            )
+            inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
+        else:
+            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
+            if initialize_states:
+                for s in conv_state:
+                    s.zero_()
+                ssm_state.zero_()
+        return conv_state, ssm_state
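A short sketch of the low-level cache-and-step path defined above (the shapes and sizes are arbitrary choices, not package defaults; `max_seqlen` is accepted for API parity but unused by this implementation):

```python
import torch
from lite_mamba import Mamba

m = Mamba(d_model=32, d_conv=3, conv_dilations=(1, 2))
conv_state, ssm_state = m.allocate_inference_cache(batch_size=1, max_seqlen=64)

h = torch.randn(1, 1, 32)                                    # one token: (B, 1, d_model)
out, conv_state, ssm_state = m.step(h, conv_state, ssm_state)
print(out.shape)                                             # torch.Size([1, 1, 32])
```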
lite_mamba-0.1.4/lite_mamba/selective_scan.py

@@ -0,0 +1,66 @@
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+# Pure-PyTorch reference selective scan (no CUDA/Triton deps)
+def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
+                      return_last_state=False):
+    return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
+
+
+def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
+                       return_last_state=False):
+    dtype_in = u.dtype
+    u = u.float()
+    delta = delta.float()
+    if delta_bias is not None:
+        delta = delta + delta_bias[..., None].float()
+    if delta_softplus:
+        delta = F.softplus(delta)
+    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
+    is_variable_B = B.dim() >= 3
+    is_variable_C = C.dim() >= 3
+    if A.is_complex():
+        if is_variable_B:
+            B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
+        if is_variable_C:
+            C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
+    else:
+        B = B.float()
+        C = C.float()
+    x = A.new_zeros((batch, dim, dstate))
+    ys = []
+    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
+    if not is_variable_B:
+        deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
+    else:
+        if B.dim() == 3:
+            deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
+        else:
+            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
+            deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
+    if is_variable_C and C.dim() == 4:
+        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
+    last_state = None
+    for i in range(u.shape[2]):
+        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
+        if not is_variable_C:
+            y = torch.einsum('bdn,dn->bd', x, C)
+        else:
+            if C.dim() == 3:
+                y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
+            else:
+                y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
+        if i == u.shape[2] - 1:
+            last_state = x
+        if y.is_complex():
+            y = y.real * 2
+        ys.append(y)
+    y = torch.stack(ys, dim=2)  # (batch, dim, L)
+    out = y if D is None else y + u * rearrange(D, "d -> d 1")
+    if z is not None:
+        out = out * F.silu(z)
+    out = out.to(dtype=dtype_in)
+    if return_last_state:
+        return out, last_state
+    return out
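A small shape sketch for the reference scan in its variable-B/C form, which is how `mamba_simple.py` calls it (the sizes below are arbitrary):

```python
import torch
from lite_mamba.selective_scan import selective_scan_fn

b, d, n, l = 2, 8, 4, 16                     # batch, channels, state size, length
u     = torch.randn(b, d, l)                 # input sequence
delta = torch.rand(b, d, l)                  # per-step timescales
A     = -torch.rand(d, n)                    # negative real A keeps the recurrence stable
B     = torch.randn(b, n, l)                 # input-dependent B (varies along L)
C     = torch.randn(b, n, l)                 # input-dependent C
D     = torch.ones(d)                        # skip connection

y, last_state = selective_scan_fn(u, delta, A, B, C, D=D,
                                  delta_softplus=True, return_last_state=True)
print(y.shape, last_state.shape)             # torch.Size([2, 8, 16]) torch.Size([2, 8, 4])
```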
lite_mamba-0.1.4/lite_mamba.egg-info/PKG-INFO

@@ -0,0 +1,95 @@
+Metadata-Version: 2.4
+Name: lite-mamba
+Version: 0.1.4
+Summary: Pure-PyTorch lightweight Mamba with multi-dilated causal conv front-end
+Home-page: https://github.com/Mrrobi/lite_mamba
+Author: Md Robiuddin
+Author-email: mrrobi040@gmail.com
+License: Apache-2.0
+Project-URL: Homepage, https://github.com/Mrrobi/lite_mamba
+Classifier: License :: OSI Approved :: Apache Software License
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3 :: Only
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Intended Audience :: Science/Research
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Requires-Python: >=3.9
+Description-Content-Type: text/markdown
+Requires-Dist: torch>=2.0
+Requires-Dist: einops>=0.6
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: license
+Dynamic: project-url
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
+
+# lite-mamba
+
+A minimal, pure-PyTorch version of Mamba with a multi-dilated causal depthwise conv front-end. No CUDA/Triton build is needed; it works on CPU or GPU with standard PyTorch ops.
+
+## Install
+```bash
+pip install torch einops
+pip install lite-mamba
+```
+
+## Usage
+```python
+from lite_mamba import Mamba
+import torch
+
+x = torch.randn(2, 128, 512)  # (batch, seq, d_model)
+m = Mamba(d_model=512, d_conv=3, conv_dilations=(1, 2, 4, 8))
+y = m(x)
+print(y.shape)  # (2, 128, 512)
+```
+
+## API quick reference
+`Mamba(d_model, d_state=16, d_conv=4, conv_dilations=(1,), expand=2, dt_rank="auto", dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, conv_bias=True, bias=False, use_fast_path=False, layer_idx=None, device=None, dtype=None)`
+
+- `d_model` (int, required): input/output embedding size.
+- `d_state` (int, default 16): SSM state dimension per channel. Larger values give longer memory but increase compute.
+- `d_conv` (int, default 4): depthwise conv kernel size for each branch.
+- `conv_dilations` (tuple[int], default `(1,)`): dilation per branch. Multiple values create parallel dilated convs; the effective receptive field per branch is `(d_conv-1)*dilation`.
+- `expand` (float, default 2): inner width multiplier; sets `d_inner = expand * d_model`.
+- `dt_rank` (int or "auto", default "auto"): rank of the delta projection. "auto" sets `ceil(d_model/16)`.
+- `dt_min`, `dt_max` (float, defaults 1e-3 / 1e-1): log-uniform range for delta initialization.
+- `dt_init` ("random" | "constant", default "random") and `dt_scale`, `dt_init_floor`: control delta init magnitude/stability.
+- `conv_bias` (bool, default True): include bias in the depthwise convs.
+- `bias` (bool, default False): include bias in the input/output linear projections.
+- `use_fast_path` (bool): ignored in this pure-PyTorch build; kept for API compatibility.
+- `layer_idx` (int | None): identifier for streaming cache registration; required when using `allocate_inference_cache` + `inference_params`.
+- `device`, `dtype`: standard module factory kwargs.
+
+### Inference / streaming helpers
+- `allocate_inference_cache(batch_size, max_seqlen, dtype=None)`: preallocates conv and SSM state buffers for step-wise decoding.
+- `step(hidden_states, conv_state, ssm_state)`: single-token forward (expects `hidden_states` with shape `(B, 1, d_model)`).
+- `forward(..., inference_params)`: if `inference_params` has cached states (with `key_value_memory_dict` and `seqlen_offset`), uses them for streaming.
+
+## Highlights
+- Multi-branch causal dilated convs (weighted sum via learned gates).
+- Pure Python: no custom C++/CUDA or Triton kernels.
+- Streaming support via per-branch conv states and SSM state caching.
+
+## Practical setups
+- **Local modeling / small context**: `d_conv=3`, `conv_dilations=(1,2,4)`, `d_state=8–16`, `expand=2`.
+- **Longer context**: widen `conv_dilations` (e.g., `(1,2,4,8,16)`) or increase `d_state` to 32; expect higher memory/compute.
+- **Streaming/AR decoding**: call `allocate_inference_cache` once per layer, pass `inference_params` during forward, and use `step` inside your generation loop.
+- **Stability first**: keep `dt_min` >= 1e-4 and `dt_init_floor` small; leave the defaults unless you observe drift or exploding activations.
+
+## Notes
+- Adjust `conv_dilations` to tune the receptive field; keep kernels small (e.g., 3–5) to avoid excessive padding.
+- The `use_fast_path` flag is ignored here (kept for API compatibility).
+- The reference selective scan is implemented in PyTorch for portability; faster fused kernels are omitted intentionally.
+
+## License
+Apache-2.0
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
pyproject.toml
|
|
3
|
+
setup.py
|
|
4
|
+
lite_mamba/__init__.py
|
|
5
|
+
lite_mamba/causal_dilated_conv1d.py
|
|
6
|
+
lite_mamba/mamba_simple.py
|
|
7
|
+
lite_mamba/selective_scan.py
|
|
8
|
+
lite_mamba.egg-info/PKG-INFO
|
|
9
|
+
lite_mamba.egg-info/SOURCES.txt
|
|
10
|
+
lite_mamba.egg-info/dependency_links.txt
|
|
11
|
+
lite_mamba.egg-info/requires.txt
|
|
12
|
+
lite_mamba.egg-info/top_level.txt
|
|
lite_mamba-0.1.4/lite_mamba.egg-info/dependency_links.txt

@@ -0,0 +1 @@
+
lite_mamba-0.1.4/lite_mamba.egg-info/top_level.txt

@@ -0,0 +1 @@
+lite_mamba
lite_mamba-0.1.4/setup.py

@@ -0,0 +1,38 @@
+from pathlib import Path
+
+from setuptools import find_packages, setup
+
+ROOT = Path(__file__).parent
+README = (ROOT / "README.md").read_text(encoding="utf-8")
+
+setup(
+    name="lite-mamba",
+    version="0.1.4",
+    description="Pure-PyTorch lightweight Mamba with multi-dilated causal conv front-end",
+    long_description=README,
+    long_description_content_type="text/markdown",
+    author="Md Robiuddin",
+    author_email="mrrobi040@gmail.com",
+    url="https://github.com/Mrrobi/lite_mamba",
+    project_urls={
+        "Homepage": "https://github.com/Mrrobi/lite_mamba",
+    },
+    packages=find_packages(exclude=("tests", "docs")),
+    python_requires=">=3.9",
+    install_requires=[
+        "torch>=2.0",
+        "einops>=0.6",
+    ],
+    license="Apache-2.0",
+    classifiers=[
+        "License :: OSI Approved :: Apache Software License",
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3 :: Only",
+        "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.11",
+        "Programming Language :: Python :: 3.12",
+        "Intended Audience :: Science/Research",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    ],
+)