fbgemm-gpu-genai-nightly 2025.12.19-cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/experimental/gen_ai/moe/shuffling.py
@@ -0,0 +1,421 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import Optional, Union
+
+import torch
+import triton
+import triton.language as tl
+
+
+# Function APIs
+def combine_shuffling(
+    tokens: torch.Tensor,
+    token_counts: torch.Tensor,
+    expert_start: Optional[int] = None,
+    expert_end: Optional[int] = None,
+    is_padded: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    # pyre-ignore
+    return _combine_or_split_shuffling(
+        tokens=tokens,
+        token_counts=token_counts,
+        expert_start=expert_start,
+        expert_end=expert_end,
+        is_padded=is_padded,
+        is_combine=True,
+    )
+
+
+def split_shuffling(
+    tokens: torch.Tensor,
+    token_counts: torch.Tensor,
+    expert_start: Optional[int] = None,
+    expert_end: Optional[int] = None,
+    is_padded: bool = False,
+    init_with_zeros: bool = False,
+) -> torch.Tensor:
+    # pyre-ignore
+    return _combine_or_split_shuffling(
+        tokens=tokens,
+        token_counts=token_counts,
+        expert_start=expert_start,
+        expert_end=expert_end,
+        is_padded=is_padded,
+        is_combine=False,
+        init_with_zeros=init_with_zeros,
+    )
+
+
+def _combine_or_split_shuffling(
+    tokens: torch.Tensor,
+    token_counts: torch.Tensor,
+    expert_start: Optional[int],
+    expert_end: Optional[int],
+    is_padded: bool,
+    is_combine: bool,
+    init_with_zeros: bool = False,
+) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+    # T is intentionally ignored in kernel interface to avoid recompilation
+    assert tokens.is_contiguous()
+    assert token_counts.is_contiguous()
+
+    T, D = tokens.shape
+    EP, E = token_counts.shape
+    B_T = -1
+    if is_padded:
+        assert T % EP == 0
+        B_T = T // EP
+
+    if expert_start is None:
+        expert_start = 0
+    if expert_end is None:
+        expert_end = E
+
+    EG: int = expert_end - expert_start
+
+    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+    SPLIT_D = max(NUM_SMS // (EP * EG), 1)
+    SPLIT_D = triton.next_power_of_2(SPLIT_D + 1)
+    if T <= 1024:
+        SPLIT_D //= 2
+
+    if is_combine:
+        grid = (EP * EG * SPLIT_D + 1,)
+    else:
+        grid = (EP * EG * SPLIT_D,)
+
+    output_tokens = (
+        torch.zeros_like(tokens) if init_with_zeros else torch.empty_like(tokens)
+    )
+    if is_combine:
+        output_token_counts = torch.empty(
+            EG + 1, dtype=token_counts.dtype, device=token_counts.device
+        )
+    else:
+        output_token_counts = None
+
+    BLOCK_E = max(triton.next_power_of_2(E), 8)
+    BLOCK_EG = max(triton.next_power_of_2(EG), 8)
+    BLOCK_EP = max(triton.next_power_of_2(EP), 8)
+
+    _fbgemm_combine_or_split_shuffling[grid](
+        tokens,
+        token_counts,
+        output_tokens,
+        output_token_counts,
+        is_combine,
+        expert_start,
+        is_padded,
+        B_T,
+        EG,
+        EP,
+        E,
+        D,
+        BLOCK_E,
+        BLOCK_EG,
+        BLOCK_EP,
+        SPLIT_D,
+    )
+
+    if is_combine:
+        assert output_token_counts is not None
+        return output_tokens, output_token_counts
+    else:
+        return output_tokens
+
+
+# Torch Custom Op Registrations
+_COMBINE_SHUFFLING_OP_NAME = "fbgemm::combine_shuffling"
+
+torch.library.define(
+    "fbgemm::combine_shuffling",
+    "(Tensor tokens, Tensor token_counts, int? expert_start = None, int? expert_end = None, bool? is_padded = False) -> (Tensor, Tensor)",
+)
+
+
+@torch.library.impl(_COMBINE_SHUFFLING_OP_NAME, "Meta")
+def combine_shuffling_meta(
+    tokens,
+    token_counts,
+    expert_start,
+    expert_end,
+    is_padded,
+):
+    _, E = token_counts.shape
+    if expert_start is None:
+        expert_start = 0
+    if expert_end is None:
+        expert_end = E
+
+    EG: int = expert_end - expert_start
+    output_tokens = torch.empty_like(tokens)
+    output_token_counts = torch.empty(
+        EG + 1, dtype=token_counts.dtype, device=token_counts.device
+    )
+    return output_tokens, output_token_counts
+
+
+@torch.library.impl(_COMBINE_SHUFFLING_OP_NAME, "CUDA")
+def combine_shuffling_cuda(
+    tokens,
+    token_counts,
+    expert_start=None,
+    expert_end=None,
+    is_padded=False,
+):
+    return combine_shuffling(
+        tokens,
+        token_counts,
+        expert_start,
+        expert_end,
+        is_padded,
+    )
+
+
+_SPLIT_SHUFFLING_OP_NAME = "fbgemm::split_shuffling"
+
+torch.library.define(
+    "fbgemm::split_shuffling",
+    "(Tensor tokens, Tensor token_counts, int? expert_start = None, int? expert_end = None, bool? is_padded = False, bool? init_with_zeros = False) -> Tensor",
+)
+
+
+@torch.library.impl(_SPLIT_SHUFFLING_OP_NAME, "Meta")
+def split_shuffling_meta(
+    tokens,
+    token_counts,
+    expert_start,
+    expert_end,
+    is_padded,
+):
+    output_tokens = torch.empty_like(tokens)
+    return output_tokens
+
+
+@torch.library.impl(_SPLIT_SHUFFLING_OP_NAME, "CUDA")
+def split_shuffling_cuda(
+    tokens,
+    token_counts,
+    expert_start=None,
+    expert_end=None,
+    is_padded=False,
+):
+    return split_shuffling(
+        tokens,
+        token_counts,
+        expert_start,
+        expert_end,
+        is_padded,
+    )
+
+
+# Kernel Implementations
+_NV_CONFIGS = [
+    triton.Config(
+        {
+            "BLOCK_T": block_t,
+            "BLOCK_D": block_d,
+        },
+        num_stages=num_stages,
+        num_warps=num_warps,
+        num_ctas=num_ctas,
+    )
+    for block_t in [32, 64]
+    for block_d in [256, 512, 1024]
+    for num_stages in [1, 3]
+    for num_warps in [8, 16]
+    for num_ctas in [1]
+]
+
+_AMD_CONFIGS = [
+    triton.Config(
+        {
+            "BLOCK_T": block_t,
+            "BLOCK_D": block_d,
+            "waves_per_eu": waves_per_cu,
+        },
+        num_stages=num_stages,
+        num_warps=num_warps,
+    )
+    for block_t in [32, 64]
+    for block_d in [256, 512, 1024]
+    for num_stages in [1, 3]
+    for num_warps, waves_per_cu in [(8, 2), (16, 4)]
+]
+
+
+@triton.autotune(
+    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
+    key=[
+        "COMBINE",
+        "EG",
+        "EP",
+        "E",
+        "D",
+    ],
+)
+@triton.jit
+def _fbgemm_combine_or_split_shuffling(
+    input_tokens_ptr,
+    input_token_counts_ptr,
+    output_tokens_ptr,
+    output_token_counts_ptr,
+    COMBINE: tl.constexpr,
+    EG_START,
+    PADDED,
+    B_T: tl.constexpr,
+    EG: tl.constexpr,
+    EP: tl.constexpr,
+    E: tl.constexpr,
+    D: tl.constexpr,
+    BLOCK_E: tl.constexpr,
+    BLOCK_EG: tl.constexpr,
+    BLOCK_EP: tl.constexpr,
+    SPLIT_D: tl.constexpr,
+    BLOCK_T: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+) -> None:
+    """
+    tokens: [T, D]
+    input_token_counts: [EP, E]
+    output_tokens: [T, D]
+    output_token_counts: [EG + 1]
+    """
+    tidx = tl.program_id(0)
+
+    NUM_D_BLOCKS: tl.constexpr = (D + SPLIT_D * BLOCK_D - 1) // (SPLIT_D * BLOCK_D)
+
+    rank = tidx // (EG * SPLIT_D)
+    local_expert = (tidx % (EG * SPLIT_D)) // SPLIT_D
+    didx = tidx % SPLIT_D
+    # All experts in communication group
+    offs_e = tl.arange(0, BLOCK_E)
+    # Local experts
+    offs_eg = tl.arange(0, BLOCK_EG)
+    # Ranks
+    offs_ep = tl.arange(0, BLOCK_EP)
+
+    global_expert = local_expert + EG_START
+
+    input_token_counts = tl.load(
+        input_token_counts_ptr + offs_ep[:, None] * E + offs_e[None, :],
+        eviction_policy="evict_last",
+        mask=((offs_ep[:, None] < EP) & (offs_e[None, :] < E)),
+        other=0,
+    )  # [EP, E]
+
+    if E == EG:
+        input_token_counts_eg = input_token_counts
+    else:
+        input_token_counts_eg = tl.load(
+            input_token_counts_ptr + offs_ep[:, None] * E + EG_START + offs_eg[None, :],
+            eviction_policy="evict_last",
+            mask=((offs_ep[:, None] < EP) & (offs_eg[None, :] < EG)),
+            other=0,
+        )  # [EP, EG]
+
+    if COMBINE:
+        LAST_TILE: tl.constexpr = EP * EG * SPLIT_D
+
+        if tidx == LAST_TILE:
+            output_token_counts_eg = tl.sum(input_token_counts_eg, axis=0)
+            tl.store(
+                output_token_counts_ptr + offs_eg,
+                output_token_counts_eg,
+                mask=(offs_eg < EG),
+            )
+            output_token_counts_eg = tl.sum(output_token_counts_eg)
+            tl.store(output_token_counts_ptr + EG, output_token_counts_eg)
+            return
+
+    cond0 = offs_ep[:, None] < rank
+    cond1 = offs_ep[:, None] == rank
+
+    cond2 = offs_e[None, :] < global_expert
+
+    if PADDED:
+        tl.device_assert(B_T >= 0)
+        # Only need information from previous experts in the same rank.
+        ep_first_order = (
+            tl.sum(tl.where(cond1 and cond2, input_token_counts, 0)) + B_T * rank
+        )
+    else:
+        # r < rank || (r == rank && e < expert)
+        ep_first_order = tl.sum(
+            tl.where(cond0 or (cond1 and cond2), input_token_counts, 0)
+        )
+
+    cond4 = offs_eg[None, :] < local_expert
+    cond5 = offs_eg[None, :] == local_expert
+
+    # Expert first only need information from local experts across ranks.
+    # e < expert || (e == expert && r < rank)
+    expert_first_order = tl.sum(
+        tl.where(cond4 or (cond5 and cond0), input_token_counts_eg, 0)
+    )
+
+    if COMBINE:
+        input_offset = ep_first_order
+        output_offset = expert_first_order
+    else:
+        input_offset = expert_first_order
+        output_offset = ep_first_order
+
+    input_offset = input_offset.to(tl.int64)
+    output_offset = output_offset.to(tl.int64)
+
+    num_copy_tokens = tl.load(input_token_counts_ptr + rank * E + global_expert)
+    if num_copy_tokens == 0:
+        return
+
+    STEP_D: tl.constexpr = SPLIT_D * BLOCK_D
+    MASK_D: tl.constexpr = D % STEP_D != 0
+
+    num_t_blocks = tl.cdiv(num_copy_tokens, BLOCK_T)
+
+    t_1d_ptr = tl.arange(0, BLOCK_T)[:, None]
+    ti_1d_ptr = input_offset + t_1d_ptr
+    to_1d_ptr = output_offset + t_1d_ptr
+
+    d_1d_ptr = didx * NUM_D_BLOCKS * BLOCK_D + tl.arange(0, BLOCK_D)[None, :]
+
+    i_2d_ptr = input_tokens_ptr + ti_1d_ptr * D + d_1d_ptr
+    o_2d_ptr = output_tokens_ptr + to_1d_ptr * D + d_1d_ptr
+
+    for i in range(num_t_blocks * NUM_D_BLOCKS):
+        mask = t_1d_ptr < num_copy_tokens
+        if MASK_D:
+            mask &= d_1d_ptr < D
+
+        block = tl.load(
+            i_2d_ptr,
+            mask=mask,
+        )
+        tl.store(
+            o_2d_ptr,
+            value=block,
+            mask=mask,
+        )
+
+        if i % NUM_D_BLOCKS == (NUM_D_BLOCKS - 1):  # pyre-ignore
+            # just to make sure constant folding happens
+            D_1D_SHIFT: tl.constexpr = -(NUM_D_BLOCKS - 1) * BLOCK_D
+            TD_2D_SHIFT: tl.constexpr = BLOCK_T * D + D_1D_SHIFT
+            # increment T, D
+            t_1d_ptr += BLOCK_T
+            i_2d_ptr += TD_2D_SHIFT
+            o_2d_ptr += TD_2D_SHIFT
+            if MASK_D:
+                d_1d_ptr += D_1D_SHIFT
+        else:
+            # increment D
+            i_2d_ptr += BLOCK_D
+            o_2d_ptr += BLOCK_D
+            if MASK_D:
+                d_1d_ptr += BLOCK_D
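
For orientation, a minimal usage sketch of the two public entry points follows. It is editorial and not part of the wheel: the device, shapes, int32 dtype for token_counts, and the direct submodule import (taken from the file layout above) are illustrative assumptions, and the "rank-major to expert-major" reading of combine_shuffling is inferred from the kernel's offset computation rather than stated in the source.

    import torch

    from fbgemm_gpu.experimental.gen_ai.moe.shuffling import (
        combine_shuffling,
        split_shuffling,
    )

    # Illustrative sizes: 2 EP ranks, 4 experts, hidden dimension 128.
    D = 128
    # token_counts[r, e] = number of tokens received from rank r for expert e.
    token_counts = torch.tensor(
        [[3, 1, 0, 2], [1, 2, 2, 1]], dtype=torch.int32, device="cuda"
    )  # [EP, E] = [2, 4]
    T = int(token_counts.sum().item())  # 12 tokens in total
    tokens = torch.randn(T, D, dtype=torch.bfloat16, device="cuda")

    # Reorder the received tokens so they are grouped by expert; the second
    # output has E + 1 entries (per-expert totals plus the overall total,
    # matching the EG + 1 allocation in _combine_or_split_shuffling).
    combined, combined_counts = combine_shuffling(tokens, token_counts)

    # split_shuffling applies the inverse reordering for the same token_counts.
    restored = split_shuffling(combined, token_counts)

    # The same routines are also reachable through the registered custom ops.
    combined2, counts2 = torch.ops.fbgemm.combine_shuffling(tokens, token_counts)

The torch.ops path requires the module to have been imported first, since the fbgemm::combine_shuffling and fbgemm::split_shuffling schemas are registered at import time by the torch.library.define calls shown above.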