fluidflow 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluidFlow/__init__.py +16 -0
- fluidFlow/attention.py +299 -0
- fluidFlow/basic_modules.py +24 -0
- fluidFlow/dit.py +650 -0
- fluidFlow/flow_matching/__init__.py +124 -0
- fluidFlow/flow_matching/integrators.py +123 -0
- fluidFlow/flow_matching/path.py +192 -0
- fluidFlow/flow_matching/transport.py +615 -0
- fluidFlow/moe.py +184 -0
- fluidFlow/trainer.py +383 -0
- fluidFlow/unet.py +521 -0
- fluidflow-0.1.0.dist-info/METADATA +237 -0
- fluidflow-0.1.0.dist-info/RECORD +15 -0
- fluidflow-0.1.0.dist-info/WHEEL +5 -0
- fluidflow-0.1.0.dist-info/top_level.txt +1 -0
fluidFlow/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""FluidFlow package root.

Declares the package version and the names of its public submodules. No
submodule is imported here; each must be imported explicitly (or via
``from fluidFlow import *``, which imports the submodules listed in
``__all__``).
"""

__version__ = "0.1.0"

# Public submodules of the package.
__all__ = [
    "attention",
    "basic_modules",
    "dit",
    "moe",
    "trainer",
    "unet",
    "flow_matching",
]
|
fluidFlow/attention.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
|
|
5
|
+
from einops import rearrange, repeat
|
|
6
|
+
from timm.layers import trunc_normal_
|
|
7
|
+
try:
    # Optional FlashAttention 4 kernel, imported from the ``flash_attn.cute`` namespace.
    from flash_attn.cute import flash_attn_func
    is_flash_attn_available = True
    print("Flash attention 4 enabled ⚡")
except Exception:
    # Best-effort optional dependency: a missing package or any other error
    # raised while importing it falls back to torch's
    # scaled_dot_product_attention. Unlike a bare ``except:``, this lets
    # KeyboardInterrupt / SystemExit propagate.
    flash_attn_func = None
    is_flash_attn_available = False
    print("Flash attention 4 not found, using torch scaled_dot_product")
|
|
14
|
+
|
|
15
|
+
# Attention with RoPE and RMSNorm. Borrowed from https://github.com/hustvl/LightningDiT/blob/main/models/lightningdit.py
|
|
16
|
+
# FlashAttention 4 (FA4) support was added on top.
|
|
17
|
+
class Attention(nn.Module):
    """Multi-head self-attention with optional QK RMSNorm and RoPE (LightningDiT style).

    The attention kernel is chosen on every call:

    1. FlashAttention 4 (``flash_attn_func``) when it imported successfully,
       ``fused_attn`` is True, q/k/v are CUDA fp16/bf16 tensors of one dtype,
       and no attention dropout is active (the FA4 call passes no dropout);
    2. ``F.scaled_dot_product_attention`` when ``fused_attn`` is True;
    3. an explicit ``softmax(q @ k^T * scale) @ v`` otherwise.

    Args:
        dim: Model width ``C``; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Add a bias to the fused QKV projection.
        qk_norm: Apply RMSNorm over the head dimension to queries and keys.
        attn_drop: Dropout probability on attention weights (training only).
        proj_drop: Dropout probability after the output projection.
        proj_bias: Add a bias to the output projection.
        fused_attn: Allow fused kernels (FlashAttention 4 / SDPA).
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        proj_bias: bool = True,
        fused_attn: bool = True,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'

        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = fused_attn

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.qk_norm = qk_norm
        # Always created (unused when qk_norm is False), so the state_dict has
        # q_norm/k_norm keys regardless of the flag.
        # NOTE(review): unused parameters may need find_unused_parameters=True under DDP.
        self.q_norm = nn.RMSNorm(self.head_dim)
        self.k_norm = nn.RMSNorm(self.head_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def _can_use_flash(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dropout_p: float) -> bool:
        """Return True when the FlashAttention 4 kernel can serve this call."""
        return (
            is_flash_attn_available
            and self.fused_attn
            and dropout_p == 0.0  # the FA4 call below passes no dropout
            and q.is_cuda
            and q.dtype in (torch.float16, torch.bfloat16)
            and q.dtype == k.dtype == v.dtype
        )

    def forward(self, x: torch.Tensor, rope=None) -> torch.Tensor:
        """Attend over the sequence axis.

        Args:
            x: Input of shape ``(B, N, C)``.
            rope: Optional callable applied to queries and keys laid out as
                ``(B, heads, N, head_dim)``.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        dtype = q.dtype

        if self.qk_norm:
            # Normalise in the norm's parameter dtype, then cast back so q/k/v
            # keep one dtype under fp16/bf16.
            q = self.q_norm(q.to(self.q_norm.weight.dtype)).to(dtype)
            k = self.k_norm(k.to(self.k_norm.weight.dtype)).to(dtype)

        if rope is not None:
            # rope may promote dtype (e.g. float32 tables); restore the v dtype.
            q = rope(q).to(dtype)
            k = rope(k).to(dtype)

        # The original author saw failures under torch.compile for some
        # configurations without materialising contiguous tensors here.
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()

        dropout_p = self.attn_drop.p if self.training else 0.0

        if self._can_use_flash(q, k, v, dropout_p):
            # FA4 takes (B, N, heads, head_dim).
            result = flash_attn_func(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
                causal=False,
            )
            # Accept either a bare tensor or a tuple whose first item is the output.
            out = result[0] if isinstance(result, (tuple, list)) else result
            x = out.transpose(1, 2)  # back to (B, heads, N, head_dim)
        elif self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        else:
            attn = (q * self.scale) @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class LinearAttention(nn.Module):
|
|
97
|
+
"""
|
|
98
|
+
A possible formulation can be found on https://arxiv.org/pdf/2503.16726
|
|
99
|
+
"""
|
|
100
|
+
def __init__(self, dim, num_heads=4, qkv_bias=False, proj_bias=True, qk_norm=False, **kwargs):
|
|
101
|
+
super().__init__()
|
|
102
|
+
assert dim % num_heads == 0, 'dimension must be divisible by number of heads'
|
|
103
|
+
self.dim_head = dim // num_heads
|
|
104
|
+
self.scale = self.dim_head ** -0.5
|
|
105
|
+
self.heads = num_heads
|
|
106
|
+
# TODO: temporary left qkv_bias unused
|
|
107
|
+
self.qkv = nn.Linear(dim, dim * 3, bias=False)
|
|
108
|
+
self.qk_norm = qk_norm
|
|
109
|
+
self.q_norm = nn.RMSNorm(dim) if qk_norm else nn.Identity()
|
|
110
|
+
self.k_norm = nn.RMSNorm(dim) if qk_norm else nn.Identity()
|
|
111
|
+
self.proj = nn.Sequential(
|
|
112
|
+
nn.Linear(dim, dim, bias=proj_bias),
|
|
113
|
+
# nn.RMSNorm(dim),
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
def forward(self, x, rope=None):
|
|
117
|
+
B, N, C = x.shape # batch, sequence, channels
|
|
118
|
+
qkv = self.qkv(x).reshape(B, N, 3, C)
|
|
119
|
+
q, k, v = qkv.unbind(2) # Each is (B, N, C)
|
|
120
|
+
|
|
121
|
+
# Apply normalization BEFORE reshaping into heads
|
|
122
|
+
if self.qk_norm:
|
|
123
|
+
q = self.q_norm(q.to(self.q_norm.weight.dtype))
|
|
124
|
+
k = self.k_norm(k.to(self.k_norm.weight.dtype)) # (B, N, C)
|
|
125
|
+
|
|
126
|
+
# Now reshape into multi-head format
|
|
127
|
+
q = rearrange(q, 'b n (h d) -> b h d n', h=self.heads) # (B, h, d, N)
|
|
128
|
+
k = rearrange(k, 'b n (h d) -> b h d n', h=self.heads) # (B, h, d, N)
|
|
129
|
+
v = rearrange(v, 'b n (h d) -> b h d n', h=self.heads) # (B, h, d, N)
|
|
130
|
+
|
|
131
|
+
# use the relu approach https://export.arxiv.org/pdf/2410.10629
|
|
132
|
+
q = F.relu(q, inplace=False)
|
|
133
|
+
k = F.relu(k, inplace=False)
|
|
134
|
+
|
|
135
|
+
eps = torch.finfo(q.dtype).eps
|
|
136
|
+
# Use matrix multiplication for normalization
|
|
137
|
+
z = 1 / (k.sum(dim=-1, keepdim=True).transpose(-2, -1) @ q + eps)
|
|
138
|
+
# k.sum(dim=-1, keepdim=True): (B, h, d, 1)
|
|
139
|
+
# .transpose(-2, -1): (B, h, 1, d)
|
|
140
|
+
# @ q: (B, h, 1, d) @ (B, h, d, N) = (B, h, 1, N)
|
|
141
|
+
|
|
142
|
+
context = v @ k.transpose(-2, -1) # (B, h, d, N) @ (B, h, N, d) = (B, h, d, d)
|
|
143
|
+
out = context @ q # (B, h, d, d) @ (B, h, d, N) = (B, h, d, N)
|
|
144
|
+
out = (out * z) # (B, h, d, N) * (B, h, 1, N) = (B, h, d, N)
|
|
145
|
+
# out /= self.scale
|
|
146
|
+
out = rearrange(out, 'b h d n -> b n (h d)') # (B, N, C)
|
|
147
|
+
return self.proj(out)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class WindowAttention(nn.Module):
|
|
151
|
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
|
152
|
+
Adapted for 1D sequences.
|
|
153
|
+
|
|
154
|
+
Args:
|
|
155
|
+
dim (int): Number of input channels.
|
|
156
|
+
window_size (int): The length of the window.
|
|
157
|
+
num_heads (int): Number of attention heads.
|
|
158
|
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
|
159
|
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
|
160
|
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
|
161
|
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
|
162
|
+
"""
|
|
163
|
+
|
|
164
|
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., qk_norm=False):
|
|
165
|
+
|
|
166
|
+
super().__init__()
|
|
167
|
+
self.dim = dim
|
|
168
|
+
self.window_size = window_size # W
|
|
169
|
+
self.num_heads = num_heads
|
|
170
|
+
head_dim = dim // num_heads
|
|
171
|
+
self.scale = qk_scale or head_dim ** -0.5
|
|
172
|
+
|
|
173
|
+
# Define a parameter table of relative position bias
|
|
174
|
+
# For 1D, the range of relative positions is [-(W-1), W-1], so we need 2*W - 1 buckets
|
|
175
|
+
self.relative_position_bias_table = nn.Parameter(
|
|
176
|
+
torch.zeros(2 * window_size - 1, num_heads)) # 2*W-1, nH
|
|
177
|
+
self.pos_bias_scale = nn.Parameter(torch.zeros(1))
|
|
178
|
+
|
|
179
|
+
# Get pair-wise relative position index for each token inside the window
|
|
180
|
+
coords = torch.arange(self.window_size) # W
|
|
181
|
+
relative_coords = coords[:, None] - coords[None, :] # W, W
|
|
182
|
+
relative_coords += self.window_size - 1 # shift to start from 0
|
|
183
|
+
self.register_buffer("relative_position_index", relative_coords)
|
|
184
|
+
|
|
185
|
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
|
186
|
+
self.attn_drop = nn.Dropout(attn_drop)
|
|
187
|
+
self.proj = nn.Linear(dim, dim)
|
|
188
|
+
self.proj_drop = nn.Dropout(proj_drop)
|
|
189
|
+
self.qk_norm = qk_norm
|
|
190
|
+
self.q_norm = nn.RMSNorm(head_dim) if qk_norm else nn.Identity()
|
|
191
|
+
self.k_norm = nn.RMSNorm(head_dim) if qk_norm else nn.Identity()
|
|
192
|
+
|
|
193
|
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
|
194
|
+
self.softmax = nn.Softmax(dim=-1)
|
|
195
|
+
|
|
196
|
+
def forward(self, x, mask=None):
|
|
197
|
+
"""
|
|
198
|
+
Args:
|
|
199
|
+
x: input features with shape of (num_windows*B, N, C)
|
|
200
|
+
mask: (0/-inf) mask with shape of (num_windows, W, W) or None
|
|
201
|
+
"""
|
|
202
|
+
B_, N, C = x.shape
|
|
203
|
+
# qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
204
|
+
# q, k, v = qkv[0], qkv[1], qkv[2] # (B_, num_heads, N, head_dim)
|
|
205
|
+
qkv = self.qkv(x).reshape(B_, N, 3, C)
|
|
206
|
+
q, k, v = qkv.unbind(2) # Each is (B, N, C)
|
|
207
|
+
|
|
208
|
+
# Now reshape into multi-head format
|
|
209
|
+
q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads) # (B, h, N, d)
|
|
210
|
+
k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_heads) # (B, h, N, d)
|
|
211
|
+
v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_heads) # (B, h, N, d)
|
|
212
|
+
|
|
213
|
+
if self.qk_norm:
|
|
214
|
+
q = self.q_norm(q.to(self.q_norm.weight.dtype))
|
|
215
|
+
k = self.k_norm(k.to(self.k_norm.weight.dtype)) # (B, N, C)
|
|
216
|
+
|
|
217
|
+
# Prepare relative position bias
|
|
218
|
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
|
219
|
+
self.window_size, self.window_size, -1) # W, W, nH
|
|
220
|
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, W, W
|
|
221
|
+
|
|
222
|
+
# Combine with window mask if present
|
|
223
|
+
# if mask is not None:
|
|
224
|
+
# nW = mask.shape[0]
|
|
225
|
+
# # Expand relative_position_bias for all windows
|
|
226
|
+
# attn_mask = relative_position_bias.unsqueeze(0).expand(B_ // nW, -1, -1, -1) # (B_//nW, nH, W, W)
|
|
227
|
+
# # Add window mask (broadcast across heads)
|
|
228
|
+
# attn_mask = attn_mask + mask.unsqueeze(0).unsqueeze(1) # (B_//nW, nW, nH, W, W) -> (B_//nW, 1, 1, W, W)
|
|
229
|
+
# attn_mask = attn_mask.view(B_, self.num_heads, N, N)
|
|
230
|
+
# else:
|
|
231
|
+
# # Just use relative position bias
|
|
232
|
+
# attn_mask = relative_position_bias.unsqueeze(0) # (1, nH, W, W)
|
|
233
|
+
relative_position_bias = relative_position_bias * torch.sigmoid(self.pos_bias_scale)
|
|
234
|
+
attn_bias = relative_position_bias.expand(B_, -1, -1, -1)
|
|
235
|
+
# Use PyTorch's fused attention
|
|
236
|
+
x = F.scaled_dot_product_attention(
|
|
237
|
+
q, k, v,
|
|
238
|
+
attn_mask=attn_bias,
|
|
239
|
+
dropout_p=self.attn_drop.p if self.training else 0.0,
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
x = x.transpose(1, 2).reshape(B_, N, C)
|
|
243
|
+
x = self.proj(x)
|
|
244
|
+
x = self.proj_drop(x)
|
|
245
|
+
return x
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
# -- RoPE --
|
|
249
|
+
def rotate_half(x):
    """Rotate interleaved feature pairs by 90 degrees (RoPE helper).

    The last dimension is read as consecutive pairs ``(a, b)``; each pair is
    replaced by ``(-b, a)``. The output has the same shape as ``x``.
    """
    pairs = x.unflatten(-1, (-1, 2))  # (..., d, 2)
    first, second = pairs[..., 0], pairs[..., 1]
    return torch.stack((-second, first), dim=-1).flatten(-2)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
class VisionRotaryEmbeddingFast(nn.Module):
    """1D rotary position embedding (RoPE) with precomputed cos/sin tables.

    Frequencies are ``theta ** (-2i / dim)`` for ``i = 0, 1, ...``. Each one is
    duplicated in place ``(f0, f0, f1, f1, ...)`` so it lines up with the
    interleaved pairs rotated by ``rotate_half``.

    Args:
        dim: Size of the last (head) dimension the embedding is applied to.
        max_seq_len: Number of positions to precompute; longer inputs raise.
        theta: RoPE base.
    """

    def __init__(
        self,
        dim,
        max_seq_len=1024,  # must cover the longest sequence seen in forward
        theta = 10000,
    ):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))  # (ceil(dim/2),)
        positions = torch.arange(max_seq_len).float()
        freqs = torch.outer(positions, inv_freq)  # (max_seq_len, ceil(dim/2))
        # Same layout as einops repeat 'n d -> n (d r)', r=2.
        freqs = freqs.repeat_interleave(2, dim=-1)

        # Buffers: saved in the state_dict, not trained.
        self.register_buffer("freqs_cos", freqs.cos())
        self.register_buffer("freqs_sin", freqs.sin())

        print(f'======== RoPE 1D initialized with shape {self.freqs_cos.shape} ========')

    def forward(self, t):
        """Rotate ``t`` by its position along the second-to-last axis.

        Args:
            t: Tensor whose last two axes are ``(seq_len, dim)``, e.g.
                ``(batch, heads, seq_len, head_dim)``.

        Returns:
            Rotated tensor, returned in the dtype of ``t``.

        Raises:
            ValueError: if ``seq_len`` exceeds the precomputed ``max_seq_len``.
        """
        seq_len = t.shape[-2]
        max_len = self.freqs_cos.shape[0]
        if seq_len > max_len:
            # Without this, the slice below is silently short and .view fails
            # or misreshapes.
            raise ValueError(
                f"RoPE got sequence length {seq_len}, but max_seq_len is {max_len}"
            )

        # (1, 1, seq_len, dim) for broadcasting over leading axes.
        cos = self.freqs_cos[:seq_len].view(1, 1, seq_len, -1)
        sin = self.freqs_sin[:seq_len].view(1, 1, seq_len, -1)

        out = t * cos + rotate_half(t) * sin
        # The tables may be float32 while t is fp16/bf16, which promotes the
        # product; return in t's dtype so q/k keep matching v.
        return out.to(t.dtype)
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
|
|
4
|
+
# borrowed from LightningDiT https://github.com/hustvl/LightningDiT/blob/main/models/swiglu_ffn.py
|
|
5
|
+
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward block: ``w3(silu(gate) * value)``, where ``[gate, value] = w12(x)``.

    Args:
        in_features: Input width.
        hidden_features: Width of each of the gate/value halves; a falsy value
            falls back to ``in_features``.
        out_features: Output width; a falsy value falls back to ``in_features``.
        bias: Whether both linear layers carry a bias.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features=None,
        out_features=None,
        bias=True,
    ) -> None:
        super().__init__()
        width = hidden_features if hidden_features else in_features
        target = out_features if out_features else in_features
        # One fused projection produces both the gate and the value halves.
        self.w12 = nn.Linear(in_features, 2 * width, bias=bias)
        self.w3 = nn.Linear(width, target, bias=bias)

    def forward(self, x):
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
|