fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fbgemm-gpu-genai-nightly might be problematic. Click here for more details.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
@@ -0,0 +1,333 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional
8
+
9
+ import torch
10
+ from torch.library import register_fake
11
+
12
+
13
+ torch.library.define(
14
+ "blackwell_fmha::fmha_fwd",
15
+ "(Tensor q, Tensor k, Tensor v, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, float? softmax_scale, bool? causal, Tensor? seqlen_kv, Tensor? page_table, int seqlen_k=-1, int window_size_left=-1, int window_size_right=-1, bool bottom_right=True) -> (Tensor, Tensor)",
16
+ tags=torch.Tag.pt2_compliant_tag,
17
+ )
18
+
19
+ torch.library.define(
20
+ "blackwell_fmha::fmha_bwd",
21
+ "(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, float? softmax_scale, bool? causal, int window_size_left=-1, int window_size_right=-1, bool bottom_right=True, bool deterministic=False) -> (Tensor, Tensor, Tensor)",
22
+ tags=torch.Tag.pt2_compliant_tag,
23
+ )
24
+
25
+
26
@torch.library.impl("blackwell_fmha::fmha_fwd", "cuda")
def custom_op_fmha(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seq_len_q: Optional[int] = None,
    max_seq_len_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    seqlen_kv: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    seqlen_k: Optional[int] = None,
    window_size_left: int = -1,
    window_size_right: int = -1,
    bottom_right: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """CUDA implementation registered for ``blackwell_fmha::fmha_fwd``.

    Verifies that ``q``, ``k`` and ``v`` are contiguous and live on a CUDA
    device, then forwards every argument unchanged to
    ``torch.ops.fbgemm.fmha_fwd``.

    Returns:
        The pair of tensors returned by ``torch.ops.fbgemm.fmha_fwd``.

    Raises:
        AssertionError: If ``q``, ``k`` or ``v`` is non-contiguous or is not
            a CUDA tensor.
    """
    operands = (("q", q), ("k", k), ("v", v))
    # Contiguity is checked for all three operands before device placement.
    for label, tensor in operands:
        assert tensor.is_contiguous(), f"{label} is not contiguous"
    for label, tensor in operands:
        assert tensor.is_cuda, f"{label} must be on GPU"

    return torch.ops.fbgemm.fmha_fwd(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seq_len_q=max_seq_len_q,
        max_seq_len_k=max_seq_len_k,
        softmax_scale=softmax_scale,
        causal=causal,
        seqlen_kv=seqlen_kv,
        page_table=page_table,
        seqlen_k=seqlen_k,
        window_size_left=window_size_left,
        window_size_right=window_size_right,
        bottom_right=bottom_right,
    )
68
+
69
+
70
@register_fake("blackwell_fmha::fmha_fwd")
def fmha_fwd_meta(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seq_len_q: Optional[int] = None,
    max_seq_len_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    seqlen_kv: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    seqlen_k: Optional[int] = None,
    window_size_left: int = -1,
    window_size_right: int = -1,
    bottom_right: bool = True,
):
    """Fake implementation of ``blackwell_fmha::fmha_fwd``.

    Allocates the two outputs without running the kernel. Only ``q``,
    ``cu_seqlens_q`` and ``max_seq_len_q`` are read; every other argument is
    accepted to match the op signature.

    Returns:
        A pair ``(out, lse)``. ``out`` has the size and stride of ``q``, with
        dtype float16 for float16 input and bfloat16 for bfloat16 or
        float8_e4m3fn input. ``lse`` is a float32 tensor on ``q.device``.

    Raises:
        RuntimeError: If ``q.dtype`` is not one of the three handled dtypes.
    """
    # FP8 input yields a BF16 output; half-precision inputs keep their dtype.
    result_dtypes = {
        torch.float16: torch.float16,
        torch.bfloat16: torch.bfloat16,
        torch.float8_e4m3fn: torch.bfloat16,
    }
    out_dtype = result_dtypes.get(q.dtype)
    if out_dtype is None:
        raise RuntimeError(f"Unsupported dtype for q: {q.dtype}")

    # The variable-length path is selected solely by max_seq_len_q being set.
    is_varlen = max_seq_len_q is not None

    if is_varlen:
        assert cu_seqlens_q is not None
        # presumably q is (total_q, heads, head_dim) here — TODO confirm
        seq_q = q.shape[0]
        num_heads = q.shape[1]
        # Computed as in the non-varlen path, but the varlen LSE below uses a
        # leading dimension of 1 rather than this value.
        batch = cu_seqlens_q.shape[0] - 1
    else:
        # presumably q is (batch, seq, heads, head_dim) here — TODO confirm
        seq_q = q.shape[1]
        num_heads = q.shape[2]
        batch = q.shape[0]

    if is_varlen:
        assert max_seq_len_q is not None
        base = torch.empty_like(q, dtype=out_dtype)
        # Same size/stride as `base`, viewed at a non-zero storage offset.
        # NOTE(review): presumably mirrors the layout of the real kernel's
        # output — confirm against the CUDA implementation.
        attn_out = torch.as_strided(
            base,
            size=base.size(),
            stride=base.stride(),
            storage_offset=q.shape[-1] * max_seq_len_q * num_heads,
        )
        lse_shape = (1, num_heads, seq_q)
    else:
        attn_out = torch.empty_like(q, dtype=out_dtype)
        lse_shape = (batch, num_heads, seq_q)

    lse = torch.empty(lse_shape, dtype=torch.float32, device=q.device)
    return attn_out, lse
127
+
128
+
129
@torch.library.impl("blackwell_fmha::fmha_bwd", "cuda")
def custom_op_fmha_bwd(
    dOutput: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seq_len_q: Optional[int] = None,
    max_seq_len_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    bottom_right: bool = True,
    deterministic: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """CUDA implementation registered for ``blackwell_fmha::fmha_bwd``.

    Forwards every argument unchanged to ``torch.ops.fbgemm.fmha_bwd``: the
    six tensors positionally and the remaining options by keyword.

    Returns:
        The three tensors returned by ``torch.ops.fbgemm.fmha_bwd``.
    """
    # Options are passed by keyword, exactly as named in this signature.
    keyword_options = {
        "cu_seqlens_q": cu_seqlens_q,
        "cu_seqlens_k": cu_seqlens_k,
        "max_seq_len_q": max_seq_len_q,
        "max_seq_len_k": max_seq_len_k,
        "softmax_scale": softmax_scale,
        "causal": causal,
        "window_size_left": window_size_left,
        "window_size_right": window_size_right,
        "bottom_right": bottom_right,
        "deterministic": deterministic,
    }
    return torch.ops.fbgemm.fmha_bwd(
        dOutput,
        query,
        key,
        value,
        output,
        softmax_lse,
        **keyword_options,
    )
167
+
168
+
169
@register_fake("blackwell_fmha::fmha_bwd")
def fmha_bwd_meta(
    dOutput: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seq_len_q: Optional[int] = None,
    max_seq_len_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    bottom_right: bool = True,
    deterministic: bool = False,
):
    """Fake implementation of ``blackwell_fmha::fmha_bwd``.

    Only ``query``, ``key`` and ``value`` are read; the other arguments are
    accepted to match the op signature.

    Returns:
        Three uninitialized tensors created with ``torch.empty_like`` from
        ``query``, ``key`` and ``value`` respectively.
    """
    grad_query = torch.empty_like(query)
    grad_key = torch.empty_like(key)
    grad_value = torch.empty_like(value)
    return grad_query, grad_key, grad_value
193
+
194
+
195
+ def _backward(ctx, *grad):
196
+ if ctx.is_gen:
197
+ # For gen case, no backward pass is needed (generation is inference only)
198
+ raise RuntimeError("Backward pass is not supported for generation phase (sq=1)")
199
+ q, k, v, out, softmax_lse = ctx.saved_tensors
200
+ if not grad[0].is_contiguous():
201
+ grad0 = grad[0].contiguous()
202
+ else:
203
+ grad0 = grad[0]
204
+ if not softmax_lse.is_contiguous:
205
+ softmax_lse = softmax_lse.contiguous()
206
+ if not out.is_contiguous:
207
+ out = out.contiguous()
208
+ if not q.is_contiguous:
209
+ q = q.contiguous()
210
+ if not k.is_contiguous:
211
+ k = k.contiguous()
212
+
213
+ if not softmax_lse.is_contiguous:
214
+ softmax_lse = softmax_lse.contiguous()
215
+ if not out.is_contiguous:
216
+ out = out.contiguous()
217
+ if not q.is_contiguous:
218
+ q = q.contiguous()
219
+ if not k.is_contiguous:
220
+ k = k.contiguous()
221
+
222
+ dq, dk, dv = torch.ops.blackwell_fmha.fmha_bwd(
223
+ grad0,
224
+ q,
225
+ k,
226
+ v,
227
+ out,
228
+ softmax_lse,
229
+ ctx.cu_seqlens_q,
230
+ ctx.cu_seqlens_k,
231
+ ctx.max_seq_len_q,
232
+ ctx.max_seq_len_k,
233
+ ctx.softmax_scale,
234
+ ctx.causal,
235
+ ctx.window_size_left,
236
+ ctx.window_size_right,
237
+ ctx.bottom_right,
238
+ ctx.deterministic,
239
+ )
240
+ return (
241
+ dq,
242
+ dk,
243
+ dv,
244
+ None,
245
+ None,
246
+ None,
247
+ None,
248
+ None,
249
+ None,
250
+ None,
251
+ None,
252
+ None,
253
+ None,
254
+ None,
255
+ None,
256
+ )
257
+
258
+
259
+ def _setup_context(ctx, inputs, output):
260
+ (
261
+ q,
262
+ k,
263
+ v,
264
+ cu_seqlens_q,
265
+ cu_seqlens_k,
266
+ max_seq_len_q,
267
+ max_seq_len_k,
268
+ softmax_scale,
269
+ causal,
270
+ seqlen_kv,
271
+ page_table,
272
+ seqlen_k,
273
+ window_size_left,
274
+ window_size_right,
275
+ bottom_right,
276
+ ) = inputs
277
+ (out, softmax_lse) = output
278
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
279
+ ctx.softmax_scale = softmax_scale
280
+ ctx.causal = causal
281
+ ctx.max_seq_len_q = max_seq_len_q
282
+ ctx.max_seq_len_k = max_seq_len_k
283
+ ctx.cu_seqlens_q = cu_seqlens_q
284
+ ctx.cu_seqlens_k = cu_seqlens_k
285
+ ctx.window_size_left = window_size_left
286
+ ctx.window_size_right = window_size_right
287
+ ctx.bottom_right = bottom_right
288
+ ctx.deterministic = False # Set default value
289
+ ctx.is_gen = False
290
+
291
+
292
# Enable autograd for the forward op: `_backward` is the gradient formula and
# `_setup_context` stores the values that formula reads from `ctx`.
torch.library.register_autograd(
    "blackwell_fmha::fmha_fwd",
    _backward,
    setup_context=_setup_context,
)
298
+
299
+
300
def cutlass_blackwell_fmha_custom_op(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: float | None = None,
    causal: bool = False,
    cu_seqlens_q: torch.Tensor | None = None,
    cu_seqlens_k: torch.Tensor | None = None,
    max_seq_len_q: int | None = None,
    max_seq_len_k: int | None = None,
    seqlen_kv: torch.Tensor | None = None,
    page_table: torch.Tensor | None = None,
    seqlen_k: int | None = -1,
    window_size_left: int | None = -1,
    window_size_right: int | None = -1,
    bottom_right: bool | None = True,
) -> torch.Tensor:
    """Call ``torch.ops.blackwell_fmha.fmha_fwd`` and return its first output.

    All arguments are passed to the op by keyword. ``seqlen_k``,
    ``window_size_left``, ``window_size_right`` and ``bottom_right`` are
    annotated as accepting ``None``, but the op schema declares them as
    non-optional ``int``/``bool``. A ``None`` for any of them is therefore
    replaced with the schema default (``-1``, ``-1``, ``-1``, ``True``)
    before dispatch.

    Returns:
        The first of the two tensors returned by the op; the second is
        discarded.
    """
    # Map None onto the schema defaults so the dispatcher does not reject it.
    if seqlen_k is None:
        seqlen_k = -1
    if window_size_left is None:
        window_size_left = -1
    if window_size_right is None:
        window_size_right = -1
    if bottom_right is None:
        bottom_right = True

    attn_out, _softmax_lse = torch.ops.blackwell_fmha.fmha_fwd(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seq_len_q=max_seq_len_q,
        max_seq_len_k=max_seq_len_k,
        softmax_scale=softmax_scale,
        causal=causal,
        seqlen_kv=seqlen_kv,
        page_table=page_table,
        seqlen_k=seqlen_k,
        window_size_left=window_size_left,
        window_size_right=window_size_right,
        bottom_right=bottom_right,
    )
    return attn_out