fluidflow 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fluidFlow/__init__.py ADDED
@@ -0,0 +1,16 @@
1
"""FluidFlow package.

Package initializer: it only declares the package version and the names of
the public submodules. Nothing is imported here; a star-import of the
package (``from fluidFlow import *``) loads the submodules listed in
``__all__``.
"""

__version__ = "0.1.0"

# Public submodules exposed by the package.
__all__ = [
    "attention",
    "basic_modules",
    "dit",
    "moe",
    "trainer",
    "unet",
    "flow_matching",
]
fluidFlow/attention.py ADDED
@@ -0,0 +1,299 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+ from einops import rearrange, repeat
6
+ from timm.layers import trunc_normal_
7
# Optional FlashAttention-4 (CuTe) backend. If it cannot be imported, the
# attention modules below fall back to torch's scaled_dot_product_attention.
try:
    from flash_attn.cute import flash_attn_func
except Exception:
    # Deliberately broad (a missing package raises ImportError, but a broken
    # CUDA/extension install can raise other errors at import time) so the
    # fallback stays best-effort. Unlike a bare `except:`, this still lets
    # KeyboardInterrupt and SystemExit propagate.
    is_flash_attn_available = False
    print("Flash attention 4 not found, using torch scaled_dot_product")
else:
    is_flash_attn_available = True
    print("Flash attention 4 enabled ⚡")
14
+
15
# Attention with RoPE and RMSNorm, borrowed from
# https://github.com/hustvl/LightningDiT/blob/main/models/lightningdit.py
# with an optional FlashAttention-4 (FA4) fast path added.
class Attention(nn.Module):
    """
    Multi-head self-attention module of LightningDiT.

    The attention kernel is selected on every call:

    * FlashAttention-4 (``flash_attn_func``) when it was importable, the
      inputs are fp16/bf16 CUDA tensors, and no attention dropout has to be
      applied (the FA4 call below takes no dropout argument);
    * ``F.scaled_dot_product_attention`` when ``fused_attn`` is True;
    * an explicit ``softmax(q @ k^T * scale) @ v`` otherwise.

    Args:
        dim: Model width C; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused q/k/v projection has a bias.
        qk_norm: If True, RMS-normalise q and k per head before attention.
        attn_drop: Dropout probability on the attention weights (training only).
        proj_drop: Dropout probability applied after the output projection.
        proj_bias: Whether the output projection has a bias.
        fused_attn: Use PyTorch SDPA when the FA4 path is not taken.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        proj_bias: bool = True,
        fused_attn: bool = True,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'

        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = fused_attn

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.qk_norm = qk_norm
        # Created even when qk_norm is False (then unused), so these weights
        # always appear in the state_dict.
        self.q_norm = nn.RMSNorm(self.head_dim)
        self.k_norm = nn.RMSNorm(self.head_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, rope=None) -> torch.Tensor:
        """
        Args:
            x: Input of shape (B, N, C).
            rope: Optional callable applied separately to q and k, each of
                shape (B, heads, N, head_dim).

        Returns:
            Tensor of shape (B, N, C).
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)  # 3, B, heads, N, head_dim
        q, k, v = qkv.unbind(0)
        dtype = v.dtype

        if self.qk_norm:
            # Normalise in the norm weight's dtype, then cast back, so that
            # fp16/bf16 activations do not end up with a dtype mismatch.
            q = self.q_norm(q.to(self.q_norm.weight.dtype)).to(dtype)
            k = self.k_norm(k.to(self.k_norm.weight.dtype)).to(dtype)

        if rope is not None:
            # A rotary table kept in fp32 promotes fp16/bf16 q/k to fp32; cast
            # back so q, k and v share one dtype (FA4 requires matching dtypes).
            q = rope(q).to(dtype)
            k = rope(k).to(dtype)

        # Without this, torch.compile blows up for some configurations.
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()

        use_flash = (
            is_flash_attn_available
            # FA4 only runs on CUDA with half-precision inputs; anything else
            # falls back to SDPA instead of crashing.
            and q.is_cuda
            and q.dtype in (torch.float16, torch.bfloat16)
            # FA4 is called without dropout; fall back so attn_drop is honoured.
            and not (self.training and self.attn_drop.p > 0)
        )

        if use_flash:
            # FA4 expects (B, N, heads, head_dim). Assumes flash_attn_func
            # returns a tuple whose first element is the attention output.
            x, *_ = flash_attn_func(
                q.permute(0, 2, 1, 3),
                k.permute(0, 2, 1, 3),
                v.permute(0, 2, 1, 3),
                causal=False,
            )
            x = x.permute(0, 2, 1, 3)  # back to (B, heads, N, head_dim)
        elif self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
94
+
95
+
96
class LinearAttention(nn.Module):
    """
    Multi-head linear attention with a ReLU kernel.

    Each head contracts keys and values into a (d, d) context matrix and
    applies it to every query, instead of forming an (N, N) score matrix. The
    cost therefore grows linearly with the sequence length N.

    A possible formulation can be found on https://arxiv.org/pdf/2503.16726.
    The ReLU feature map follows https://export.arxiv.org/pdf/2410.10629.

    Args:
        dim: Model width C; must be divisible by ``num_heads``.
        num_heads: Number of heads.
        qkv_bias: Currently unused; the q/k/v projection is always bias-free.
        proj_bias: Whether the output projection has a bias.
        qk_norm: If True, RMS-normalise q and k over all C channels, before
            they are split into heads.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, dim, num_heads=4, qkv_bias=False, proj_bias=True, qk_norm=False, **kwargs):
        super().__init__()
        assert dim % num_heads == 0, 'dimension must be divisible by number of heads'
        self.dim_head = dim // num_heads
        self.scale = self.dim_head ** -0.5  # not used by forward
        self.heads = num_heads
        # TODO: qkv_bias is not wired in yet; the projection never has a bias.
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.qk_norm = qk_norm
        if qk_norm:
            self.q_norm, self.k_norm = nn.RMSNorm(dim), nn.RMSNorm(dim)
        else:
            self.q_norm, self.k_norm = nn.Identity(), nn.Identity()
        # A single Linear inside a Sequential (state_dict keys: proj.0.*).
        self.proj = nn.Sequential(nn.Linear(dim, dim, bias=proj_bias))

    def forward(self, x, rope=None):
        """
        Args:
            x: Input of shape (B, N, C).
            rope: Not used by this module; presumably accepted so it can be
                called like ``Attention`` (TODO confirm).

        Returns:
            Tensor of shape (B, N, C).
        """
        batch, length, channels = x.shape
        q, k, v = self.qkv(x).reshape(batch, length, 3, channels).unbind(2)  # each (B, N, C)

        # Normalisation runs over all channels, before the head split.
        if self.qk_norm:
            q = self.q_norm(q.to(self.q_norm.weight.dtype))
            k = self.k_norm(k.to(self.k_norm.weight.dtype))

        # Heads, with features first and tokens last: (B, h, d, N).
        q, k, v = (
            rearrange(t, 'b n (h d) -> b h d n', h=self.heads) for t in (q, k, v)
        )

        # ReLU features make q . sum(k) >= 0, so the eps-shifted denominator
        # below stays positive.
        q = F.relu(q)
        k = F.relu(k)

        eps = torch.finfo(q.dtype).eps
        # Per-query normaliser 1 / (q_n . sum_m k_m):
        # (B, h, 1, d) @ (B, h, d, N) -> (B, h, 1, N)
        key_sum = k.sum(dim=-1, keepdim=True).transpose(-2, -1)
        normaliser = 1 / (key_sum @ q + eps)

        # Context sum_m v_m k_m^T: (B, h, d, N) @ (B, h, N, d) -> (B, h, d, d)
        context = v @ k.transpose(-2, -1)
        # Apply the context to every query, then normalise: (B, h, d, N)
        out = (context @ q) * normaliser
        return self.proj(rearrange(out, 'b h d n -> b n (h d)'))
148
+
149
+
150
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    Adapted for 1D sequences.

    Args:
        dim (int): Number of input channels.
        window_size (int): The length of the window W; inputs must contain exactly W tokens.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        qk_norm (bool, optional): If True, RMS-normalise q and k per head. Default: False
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., qk_norm=False):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # W
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # Learned bias per (relative offset, head). In 1D, offsets inside a
        # window range over [-(W-1), W-1], i.e. 2*W - 1 buckets.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(2 * window_size - 1, num_heads))  # 2*W-1, nH
        # Learned gate on the bias strength (sigmoid(0) = 0.5 at init).
        self.pos_bias_scale = nn.Parameter(torch.zeros(1))

        # Pair-wise relative position index for each token inside the window.
        coords = torch.arange(self.window_size)  # W
        relative_coords = coords[:, None] - coords[None, :]  # W, W
        relative_coords += self.window_size - 1  # shift to start from 0
        self.register_buffer("relative_position_index", relative_coords)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.qk_norm = qk_norm
        self.q_norm = nn.RMSNorm(head_dim) if qk_norm else nn.Identity()
        self.k_norm = nn.RMSNorm(head_dim) if qk_norm else nn.Identity()

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)  # not used by forward (SDPA applies the softmax)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C), N == window_size.
            mask: (0/-inf) additive float mask with shape of (num_windows, W, W) or None.
                Rows of ``x`` are assumed to be batch-major, i.e. row
                ``b * num_windows + w`` holds window ``w`` of sample ``b``
                (Swin-style window partitioning) -- TODO confirm with callers.

        Returns:
            Tensor of shape (num_windows*B, N, C).

        Raises:
            ValueError: if ``mask`` is given and num_windows does not divide
                the first dimension of ``x``.
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, C)
        q, k, v = qkv.unbind(2)  # Each is (B_, N, C)

        # Multi-head format: (B_, h, N, d)
        q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
        k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_heads)
        v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_heads)

        if self.qk_norm:
            q = self.q_norm(q.to(self.q_norm.weight.dtype))
            k = self.k_norm(k.to(self.k_norm.weight.dtype))

        # Relative position bias, scaled by the learned gate: nH, W, W
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size, self.window_size, -1)  # W, W, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, W, W
        relative_position_bias = relative_position_bias * torch.sigmoid(self.pos_bias_scale)

        if mask is None:
            attn_bias = relative_position_bias.expand(B_, -1, -1, -1)  # B_, nH, W, W
        else:
            nW = mask.shape[0]
            if B_ % nW != 0:
                raise ValueError(
                    f"number of windows in the batch ({B_}) is not a multiple "
                    f"of the number of mask windows ({nW})"
                )
            # (1, nH, W, W) + (nW, 1, W, W) -> (nW, nH, W, W)
            window_bias = relative_position_bias.unsqueeze(0) + mask.to(relative_position_bias.dtype).unsqueeze(1)
            # Tile over the batch (batch-major window order) -> (B_, nH, W, W)
            attn_bias = window_bias.unsqueeze(0).expand(B_ // nW, -1, -1, -1, -1).reshape(
                B_, *window_bias.shape[1:])

        # PyTorch fused attention; pass the scale explicitly so qk_scale is honoured.
        x = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_bias,
            dropout_p=self.attn_drop.p if self.training else 0.0,
            scale=self.scale,
        )

        x = x.transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
246
+
247
+
248
# -- RoPE --
def rotate_half(x):
    """Rotate adjacent channel pairs of ``x`` by 90 degrees.

    The last dimension is read as pairs ``(x[2i], x[2i+1])``, and each pair
    becomes ``(-x[2i+1], x[2i])``. The last dimension must have even size.
    """
    pairs = rearrange(x, '... (d r) -> ... d r', r=2)
    first = pairs[..., 0]
    second = pairs[..., 1]
    rotated = torch.stack((-second, first), dim=-1)
    return rearrange(rotated, '... d r -> ... (d r)')
254
+
255
+
256
class VisionRotaryEmbeddingFast(nn.Module):
    """1D rotary position embedding (RoPE) with cached cos/sin tables.

    Channels are rotated in adjacent pairs, matching ``rotate_half``. Position
    ``p`` rotates pair ``i`` by the angle ``p * theta ** (-2i / dim)``.

    Args:
        dim: Size of the last dimension of the tensors to rotate; must be even.
        max_seq_len: Number of positions precomputed. Longer inputs raise.
        theta: RoPE base frequency.

    Raises:
        ValueError: if ``dim`` is odd.
    """
    def __init__(
        self,
        dim,
        max_seq_len=1024,  # Set this large enough for your data (e.g. 1024 or 4096)
        theta = 10000,
    ):
        super().__init__()
        if dim % 2 != 0:
            # rotate_half pairs up channels, so an odd dim can never be rotated.
            raise ValueError(f"RoPE dim must be even, got {dim}")

        # Inverse frequencies, shape (dim // 2,)
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        # Positions 0 .. max_seq_len-1
        t = torch.arange(max_seq_len).float()
        # Angles, shape (max_seq_len, dim // 2)
        freqs = torch.outer(t, inv_freq)
        # Duplicate each angle so both channels of a pair share it, matching
        # rotate_half's interleaved layout: (max_seq_len, dim)
        freqs = repeat(freqs, 'n d -> n (d r)', r=2)

        # Buffers: saved in the state_dict and moved by .to(), but not trained.
        self.register_buffer("freqs_cos", freqs.cos())
        self.register_buffer("freqs_sin", freqs.sin())

        print(f'======== RoPE 1D initialized with shape {self.freqs_cos.shape} ========')

    def forward(self, t):
        """Rotate ``t`` of shape (..., seq_len, dim), e.g. (B, heads, N, head_dim).

        Returns:
            Tensor with the same shape and dtype as ``t``.

        Raises:
            ValueError: if seq_len exceeds the precomputed ``max_seq_len``.
        """
        seq_len = t.shape[-2]
        max_len = self.freqs_cos.shape[0]
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the RoPE table size "
                f"{max_len}; increase max_seq_len"
            )

        # (seq_len, dim) tables broadcast against any leading dims of t.
        cos = self.freqs_cos[:seq_len]
        sin = self.freqs_sin[:seq_len]

        # The tables may be fp32 while t is fp16/bf16; keep t's dtype so q/k
        # stay consistent with v downstream.
        return (t * cos + rotate_half(t) * sin).to(t.dtype)
@@ -0,0 +1,24 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
# Adapted from LightningDiT:
# https://github.com/hustvl/LightningDiT/blob/main/models/swiglu_ffn.py
class SwiGLUFFN(nn.Module):
    """Feed-forward block with a SwiGLU gate: ``w3(silu(a) * b)``.

    A single linear layer ``w12`` produces both the gate half ``a`` and the
    value half ``b``. ``w3`` projects the gated product to the output width.

    Args:
        in_features: Input width.
        hidden_features: Width of each half. A falsy value falls back to
            ``in_features``.
        out_features: Output width. A falsy value falls back to ``in_features``.
        bias: Whether both linear layers carry a bias.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features=None,
        out_features=None,
        bias=True,
    ) -> None:
        super().__init__()
        width = hidden_features if hidden_features else in_features
        out_width = out_features if out_features else in_features
        # w12 is created before w3; this fixes parameter registration order
        # and the order in which their random initialisations are drawn.
        self.w12 = nn.Linear(in_features, 2 * width, bias=bias)
        self.w3 = nn.Linear(width, out_width, bias=bias)

    def forward(self, x):
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)