mps_flash_attn-0.1.0-cp314-cp314-macosx_15_0_arm64.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mps-flash-attn might be problematic.

Files changed (33)
  1. mps_flash_attn/_C.cpython-314-darwin.so +0 -0
  2. mps_flash_attn/__init__.py +246 -0
  3. mps_flash_attn/csrc/mps_flash_attn.cpp +441 -0
  4. mps_flash_attn/csrc/mps_flash_attn.mm +441 -0
  5. mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
  6. mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
  7. mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
  8. mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
  9. mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
  10. mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
  11. mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
  12. mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
  13. mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
  14. mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
  15. mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
  16. mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
  17. mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
  18. mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
  19. mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
  20. mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
  21. mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
  22. mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
  23. mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
  24. mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
  25. mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
  26. mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
  27. mps_flash_attn/kernels/manifest.json +27 -0
  28. mps_flash_attn/lib/libMFABridge.dylib +0 -0
  29. mps_flash_attn-0.1.0.dist-info/METADATA +264 -0
  30. mps_flash_attn-0.1.0.dist-info/RECORD +33 -0
  31. mps_flash_attn-0.1.0.dist-info/WHEEL +5 -0
  32. mps_flash_attn-0.1.0.dist-info/licenses/LICENSE +27 -0
  33. mps_flash_attn-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,441 @@
+ /**
+  * MPS Flash Attention - PyTorch C++ Extension
+  *
+  * Bridges PyTorch MPS tensors to the MFA Swift library.
+  */
+
+ #include <torch/extension.h>
+ #include <ATen/mps/MPSStream.h>
+ #include <ATen/native/mps/OperationUtils.h>
+
+ #import <Metal/Metal.h>
+ #import <Foundation/Foundation.h>
+
+ #include <dlfcn.h>
+
+ // ============================================================================
+ // MFA Bridge Function Types
+ // ============================================================================
+
+ typedef bool (*mfa_init_fn)();
+ typedef void* (*mfa_create_kernel_fn)(int32_t, int32_t, int32_t, bool, bool, bool); // Added causal param
+ typedef bool (*mfa_forward_fn)(void*, void*, void*, void*, void*, void*, int64_t, int64_t, int64_t, int64_t, int64_t, int32_t, int32_t);
+ typedef bool (*mfa_backward_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
+                                 int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+                                 int32_t, int32_t);
+ typedef void (*mfa_release_kernel_fn)(void*);
+
+ // Global function pointers
+ static mfa_init_fn g_mfa_init = nullptr;
+ static mfa_create_kernel_fn g_mfa_create_kernel = nullptr;
+ static mfa_forward_fn g_mfa_forward = nullptr;
+ static mfa_backward_fn g_mfa_backward = nullptr;
+ static mfa_release_kernel_fn g_mfa_release_kernel = nullptr;
+ static void* g_dylib_handle = nullptr;
+ static bool g_initialized = false;
+
+ // ============================================================================
+ // Load MFA Bridge Library
+ // ============================================================================
+
+ static bool load_mfa_bridge() {
+     if (g_dylib_handle) return true;
+
+     // Try to find the dylib relative to this extension
+     // First try the standard location
+     const char* paths[] = {
+         "libMFABridge.dylib",
+         "./libMFABridge.dylib",
+         "../swift-bridge/.build/release/libMFABridge.dylib",
+         nullptr
+     };
+
+     for (int i = 0; paths[i] != nullptr; i++) {
+         g_dylib_handle = dlopen(paths[i], RTLD_NOW);
+         if (g_dylib_handle) break;
+     }
+
+     if (!g_dylib_handle) {
+         // Try with absolute path from environment
+         const char* mfa_path = getenv("MFA_BRIDGE_PATH");
+         if (mfa_path) {
+             g_dylib_handle = dlopen(mfa_path, RTLD_NOW);
+         }
+     }
+
+     if (!g_dylib_handle) {
+         throw std::runtime_error(
+             "Failed to load libMFABridge.dylib. Set MFA_BRIDGE_PATH environment variable.");
+     }
+
+     // Load function pointers
+     g_mfa_init = (mfa_init_fn)dlsym(g_dylib_handle, "mfa_init");
+     g_mfa_create_kernel = (mfa_create_kernel_fn)dlsym(g_dylib_handle, "mfa_create_kernel");
+     g_mfa_forward = (mfa_forward_fn)dlsym(g_dylib_handle, "mfa_forward");
+     g_mfa_backward = (mfa_backward_fn)dlsym(g_dylib_handle, "mfa_backward");
+     g_mfa_release_kernel = (mfa_release_kernel_fn)dlsym(g_dylib_handle, "mfa_release_kernel");
+
+     if (!g_mfa_init || !g_mfa_create_kernel || !g_mfa_forward || !g_mfa_backward || !g_mfa_release_kernel) {
+         throw std::runtime_error("Failed to load MFA bridge functions");
+     }
+
+     return true;
+ }
+
+ // ============================================================================
+ // Get MTLBuffer from PyTorch MPS Tensor
+ // ============================================================================
+
+ struct BufferInfo {
+     id<MTLBuffer> buffer;
+     int64_t byte_offset;
+ };
+
+ static BufferInfo getBufferInfo(const at::Tensor& tensor) {
+     TORCH_CHECK(tensor.device().is_mps(), "Tensor must be on MPS device");
+     TORCH_CHECK(tensor.is_contiguous(), "Tensor must be contiguous");
+
+     // Get the underlying Metal buffer (covers entire storage)
+     id<MTLBuffer> buffer = at::native::mps::getMTLBufferStorage(tensor);
+
+     // Calculate byte offset: storage_offset() is in elements, multiply by element size
+     int64_t element_size = tensor.element_size();
+     int64_t byte_offset = tensor.storage_offset() * element_size;
+
+     return {buffer, byte_offset};
+ }
+
+ // ============================================================================
+ // Kernel Cache
+ // ============================================================================
+
+ struct KernelCacheKey {
+     int64_t seq_len_q;
+     int64_t seq_len_kv;
+     int64_t head_dim;
+     bool low_precision;
+     bool low_precision_outputs;
+     bool causal;
+
+     bool operator==(const KernelCacheKey& other) const {
+         return seq_len_q == other.seq_len_q &&
+                seq_len_kv == other.seq_len_kv &&
+                head_dim == other.head_dim &&
+                low_precision == other.low_precision &&
+                low_precision_outputs == other.low_precision_outputs &&
+                causal == other.causal;
+     }
+ };
+
+ struct KernelCacheKeyHash {
+     size_t operator()(const KernelCacheKey& k) const {
+         return std::hash<int64_t>()(k.seq_len_q) ^
+                (std::hash<int64_t>()(k.seq_len_kv) << 1) ^
+                (std::hash<int64_t>()(k.head_dim) << 2) ^
+                (std::hash<bool>()(k.low_precision) << 3) ^
+                (std::hash<bool>()(k.low_precision_outputs) << 4) ^
+                (std::hash<bool>()(k.causal) << 5);
+     }
+ };
+
+ static std::unordered_map<KernelCacheKey, void*, KernelCacheKeyHash> g_kernel_cache;
+
+ static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal) {
+     KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal};
+
+     auto it = g_kernel_cache.find(key);
+     if (it != g_kernel_cache.end()) {
+         return it->second;
+     }
+
+     void* kernel = g_mfa_create_kernel(
+         static_cast<int32_t>(seq_q),
+         static_cast<int32_t>(seq_kv),
+         static_cast<int32_t>(head_dim),
+         low_prec,
+         low_prec_outputs,
+         causal
+     );
+
+     if (!kernel) {
+         throw std::runtime_error("Failed to create MFA kernel");
+     }
+
+     g_kernel_cache[key] = kernel;
+     return kernel;
+ }
+
+ // ============================================================================
+ // Flash Attention Forward
+ // ============================================================================
+
+ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
+     const at::Tensor& query,   // (B, H, N, D)
+     const at::Tensor& key,     // (B, H, N, D)
+     const at::Tensor& value,   // (B, H, N, D)
+     bool is_causal
+ ) {
+     // Initialize MFA on first call
+     if (!g_initialized) {
+         load_mfa_bridge();
+         if (!g_mfa_init()) {
+             throw std::runtime_error("Failed to initialize MFA");
+         }
+         g_initialized = true;
+     }
+
+     // Validate inputs
+     TORCH_CHECK(query.dim() == 4, "Query must be 4D (B, H, N, D)");
+     TORCH_CHECK(key.dim() == 4, "Key must be 4D (B, H, N, D)");
+     TORCH_CHECK(value.dim() == 4, "Value must be 4D (B, H, N, D)");
+     TORCH_CHECK(query.device().is_mps(), "Query must be on MPS device");
+     TORCH_CHECK(key.device().is_mps(), "Key must be on MPS device");
+     TORCH_CHECK(value.device().is_mps(), "Value must be on MPS device");
+
+     const int64_t batch_size = query.size(0);
+     const int64_t num_heads = query.size(1);
+     const int64_t seq_len_q = query.size(2);
+     const int64_t head_dim = query.size(3);
+     const int64_t seq_len_kv = key.size(2);
+
+     TORCH_CHECK(key.size(0) == batch_size && value.size(0) == batch_size,
+                 "Batch size mismatch");
+     TORCH_CHECK(key.size(1) == num_heads && value.size(1) == num_heads,
+                 "Number of heads mismatch");
+     TORCH_CHECK(key.size(3) == head_dim && value.size(3) == head_dim,
+                 "Head dimension mismatch");
+
+     // Determine precision
+     bool low_precision = (query.scalar_type() == at::kHalf ||
+                           query.scalar_type() == at::kBFloat16);
+
+     // For fp16 inputs, we can now output directly to fp16 (no extra conversion needed!)
+     bool low_precision_outputs = low_precision;
+
+     // Make inputs contiguous
+     auto q = query.contiguous();
+     auto k = key.contiguous();
+     auto v = value.contiguous();
+
+     // Allocate output in the appropriate precision
+     // With lowPrecisionOutputs=true, MFA writes FP16 directly
+     at::Tensor output;
+     if (low_precision_outputs) {
+         output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
+                            query.options().dtype(at::kHalf));
+     } else {
+         output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
+                            query.options().dtype(at::kFloat));
+     }
+
+     // Allocate logsumexp (for backward pass, always fp32)
+     auto logsumexp = at::empty({batch_size, num_heads, seq_len_q},
+                                query.options().dtype(at::kFloat));
+
+     // Get or create kernel with matching output precision and causal mode
+     void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal);
+
+     // Get Metal buffers with byte offsets
+     auto q_info = getBufferInfo(q);
+     auto k_info = getBufferInfo(k);
+     auto v_info = getBufferInfo(v);
+     auto o_info = getBufferInfo(output);
+     auto l_info = getBufferInfo(logsumexp);
+
+     // Synchronize with PyTorch's MPS stream
+     @autoreleasepool {
+         // Wait for PyTorch operations to complete
+         auto stream = at::mps::getCurrentMPSStream();
+         stream->synchronize(at::mps::SyncType::COMMIT_AND_WAIT);
+
+         // Execute MFA with storage byte offsets
+         bool success = g_mfa_forward(
+             kernel,
+             (__bridge void*)q_info.buffer,
+             (__bridge void*)k_info.buffer,
+             (__bridge void*)v_info.buffer,
+             (__bridge void*)o_info.buffer,
+             (__bridge void*)l_info.buffer,
+             q_info.byte_offset,
+             k_info.byte_offset,
+             v_info.byte_offset,
+             o_info.byte_offset,
+             l_info.byte_offset,
+             static_cast<int32_t>(batch_size),
+             static_cast<int32_t>(num_heads)
+         );
+
+         if (!success) {
+             throw std::runtime_error("MFA forward pass failed");
+         }
+     }
+
+     // Output is already in the correct dtype (fp16 or fp32)
+     // Return both output and logsumexp (needed for backward pass)
+     return std::make_tuple(output, logsumexp);
+ }
+
+ // Simple forward that only returns output (for inference)
+ at::Tensor mps_flash_attention_forward(
+     const at::Tensor& query,
+     const at::Tensor& key,
+     const at::Tensor& value,
+     bool is_causal
+ ) {
+     auto [output, logsumexp] = mps_flash_attention_forward_with_lse(query, key, value, is_causal);
+     return output;
+ }
+
+ // ============================================================================
+ // Flash Attention Backward
+ // ============================================================================
+
+ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
+     const at::Tensor& grad_output,   // (B, H, N, D)
+     const at::Tensor& query,         // (B, H, N, D)
+     const at::Tensor& key,           // (B, H, N, D)
+     const at::Tensor& value,         // (B, H, N, D)
+     const at::Tensor& output,        // (B, H, N, D)
+     const at::Tensor& logsumexp,     // (B, H, N)
+     bool is_causal
+ ) {
+     // Initialize MFA on first call
+     if (!g_initialized) {
+         load_mfa_bridge();
+         if (!g_mfa_init()) {
+             throw std::runtime_error("Failed to initialize MFA");
+         }
+         g_initialized = true;
+     }
+
+     // Validate inputs
+     TORCH_CHECK(grad_output.dim() == 4, "grad_output must be 4D (B, H, N, D)");
+     TORCH_CHECK(query.dim() == 4, "Query must be 4D (B, H, N, D)");
+     TORCH_CHECK(key.dim() == 4, "Key must be 4D (B, H, N, D)");
+     TORCH_CHECK(value.dim() == 4, "Value must be 4D (B, H, N, D)");
+     TORCH_CHECK(output.dim() == 4, "Output must be 4D (B, H, N, D)");
+     TORCH_CHECK(logsumexp.dim() == 3, "Logsumexp must be 3D (B, H, N)");
+
+     const int64_t batch_size = query.size(0);
+     const int64_t num_heads = query.size(1);
+     const int64_t seq_len_q = query.size(2);
+     const int64_t head_dim = query.size(3);
+     const int64_t seq_len_kv = key.size(2);
+
+     // Determine precision
+     bool low_precision = (query.scalar_type() == at::kHalf ||
+                           query.scalar_type() == at::kBFloat16);
+     bool low_precision_outputs = low_precision;
+
+     // Make inputs contiguous
+     auto q = query.contiguous();
+     auto k = key.contiguous();
+     auto v = value.contiguous();
+     auto o = output.contiguous();
+     auto dO = grad_output.contiguous();
+     auto lse = logsumexp.contiguous();
+
+     // Get or create kernel (with causal mode)
+     void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal);
+
+     // Allocate D buffer (dO * O reduction, always fp32)
+     auto D = at::empty({batch_size, num_heads, seq_len_q},
+                        query.options().dtype(at::kFloat));
+
+     // Allocate gradients (always fp32 for numerical stability)
+     auto dQ = at::zeros({batch_size, num_heads, seq_len_q, head_dim},
+                         query.options().dtype(at::kFloat));
+     auto dK = at::zeros({batch_size, num_heads, seq_len_kv, head_dim},
+                         query.options().dtype(at::kFloat));
+     auto dV = at::zeros({batch_size, num_heads, seq_len_kv, head_dim},
+                         query.options().dtype(at::kFloat));
+
+     // Get Metal buffers with byte offsets
+     auto q_info = getBufferInfo(q);
+     auto k_info = getBufferInfo(k);
+     auto v_info = getBufferInfo(v);
+     auto o_info = getBufferInfo(o);
+     auto do_info = getBufferInfo(dO);
+     auto l_info = getBufferInfo(lse);
+     auto d_info = getBufferInfo(D);
+     auto dq_info = getBufferInfo(dQ);
+     auto dk_info = getBufferInfo(dK);
+     auto dv_info = getBufferInfo(dV);
+
+     // Execute backward pass
+     @autoreleasepool {
+         auto stream = at::mps::getCurrentMPSStream();
+         stream->synchronize(at::mps::SyncType::COMMIT_AND_WAIT);
+
+         bool success = g_mfa_backward(
+             kernel,
+             (__bridge void*)q_info.buffer,
+             (__bridge void*)k_info.buffer,
+             (__bridge void*)v_info.buffer,
+             (__bridge void*)o_info.buffer,
+             (__bridge void*)do_info.buffer,
+             (__bridge void*)l_info.buffer,
+             (__bridge void*)d_info.buffer,
+             (__bridge void*)dq_info.buffer,
+             (__bridge void*)dk_info.buffer,
+             (__bridge void*)dv_info.buffer,
+             q_info.byte_offset,
+             k_info.byte_offset,
+             v_info.byte_offset,
+             o_info.byte_offset,
+             do_info.byte_offset,
+             l_info.byte_offset,
+             d_info.byte_offset,
+             dq_info.byte_offset,
+             dk_info.byte_offset,
+             dv_info.byte_offset,
+             static_cast<int32_t>(batch_size),
+             static_cast<int32_t>(num_heads)
+         );
+
+         if (!success) {
+             throw std::runtime_error("MFA backward pass failed");
+         }
+     }
+
+     // Convert gradients back to input dtype if needed
+     if (low_precision) {
+         dQ = dQ.to(query.scalar_type());
+         dK = dK.to(query.scalar_type());
+         dV = dV.to(query.scalar_type());
+     }
+
+     return std::make_tuple(dQ, dK, dV);
+ }
+
+ // ============================================================================
+ // Python Bindings
+ // ============================================================================
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.doc() = "MPS Flash Attention - Metal accelerated attention for Apple Silicon";
+
+     m.def("forward", &mps_flash_attention_forward,
+           "Flash Attention forward pass (returns output only)",
+           py::arg("query"),
+           py::arg("key"),
+           py::arg("value"),
+           py::arg("is_causal") = false);
+
+     m.def("forward_with_lse", &mps_flash_attention_forward_with_lse,
+           "Flash Attention forward pass (returns output and logsumexp for backward)",
+           py::arg("query"),
+           py::arg("key"),
+           py::arg("value"),
+           py::arg("is_causal") = false);
+
+     m.def("backward", &mps_flash_attention_backward,
+           "Flash Attention backward pass",
+           py::arg("grad_output"),
+           py::arg("query"),
+           py::arg("key"),
+           py::arg("value"),
+           py::arg("output"),
+           py::arg("logsumexp"),
+           py::arg("is_causal") = false);
+ }
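
Note: the pybind11 block above exposes three entry points on the compiled extension (forward, forward_with_lse, backward). A minimal usage sketch, assuming the bundled module is importable as mps_flash_attn._C (per the _C.cpython-314-darwin.so listed above), that PyTorch has MPS support, and with illustrative tensor shapes:

    # Usage sketch only: calls the pybind11 entry points defined in the diff above.
    # The module path mps_flash_attn._C and the shapes below are assumptions.
    import torch
    from mps_flash_attn import _C

    B, H, N, D = 2, 8, 1024, 64  # illustrative shapes
    q = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
    k = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
    v = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)

    # Inference path: returns only the attention output (fp16 here, per the forward code).
    out = _C.forward(q, k, v, is_causal=True)

    # Training path: keep the fp32 logsumexp so the backward binding can be driven later.
    out, lse = _C.forward_with_lse(q, k, v, is_causal=True)
    grad_out = torch.ones_like(out)
    dq, dk, dv = _C.backward(grad_out, q, k, v, out, lse, is_causal=True)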
@@ -0,0 +1,27 @@
+ {
+   "version": "1.0",
+   "files": [
+     "06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib",
+     "adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib",
+     "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin",
+     "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin",
+     "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin",
+     "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin",
+     "ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib",
+     "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin",
+     "a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib",
+     "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib",
+     "975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib",
+     "2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib",
+     "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin",
+     "09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib",
+     "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin",
+     "0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib",
+     "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin",
+     "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin",
+     "771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib",
+     "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin",
+     "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib",
+     "f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib"
+   ]
+ }
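
Note: manifest.json is a flat index of the precompiled Metal libraries (.metallib) and pipeline caches (.bin) shipped under mps_flash_attn/kernels/. A hypothetical sketch of consuming a manifest with this shape (resolve_kernels is an illustrative helper, not part of the package's actual loader):

    # Illustrative only: reads the manifest structure shown above and maps each
    # listed kernel artifact to an absolute path, skipping files missing on disk.
    import json
    from pathlib import Path

    def resolve_kernels(kernels_dir: str) -> dict:
        kernels = Path(kernels_dir)
        manifest = json.loads((kernels / "manifest.json").read_text())
        if manifest.get("version") != "1.0":
            raise ValueError(f"unexpected manifest version: {manifest.get('version')}")
        return {name: kernels / name
                for name in manifest["files"]
                if (kernels / name).exists()}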