sglang-0.5.0rc1-py3-none-any.whl → sglang-0.5.0rc2-py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
Files changed (76)
  1. sglang/bench_one_batch.py +0 -1
  2. sglang/srt/configs/model_config.py +1 -0
  3. sglang/srt/disaggregation/decode.py +0 -1
  4. sglang/srt/entrypoints/engine.py +2 -2
  5. sglang/srt/entrypoints/http_server.py +64 -0
  6. sglang/srt/entrypoints/openai/protocol.py +2 -0
  7. sglang/srt/entrypoints/openai/serving_chat.py +1 -0
  8. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  9. sglang/srt/layers/attention/flashinfer_backend.py +3 -0
  10. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  11. sglang/srt/layers/attention/triton_backend.py +24 -27
  12. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  13. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -3
  14. sglang/srt/layers/communicator.py +7 -7
  15. sglang/srt/layers/dp_attention.py +118 -27
  16. sglang/srt/layers/logits_processor.py +12 -18
  17. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/multimodal.py +156 -40
  29. sglang/srt/layers/quantization/__init__.py +5 -32
  30. sglang/srt/layers/quantization/awq.py +15 -16
  31. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  32. sglang/srt/layers/quantization/gptq.py +12 -17
  33. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  34. sglang/srt/layers/quantization/modelopt_quant.py +52 -30
  35. sglang/srt/layers/quantization/mxfp4.py +16 -2
  36. sglang/srt/layers/quantization/utils.py +52 -2
  37. sglang/srt/layers/sampler.py +5 -2
  38. sglang/srt/lora/layers.py +6 -2
  39. sglang/srt/managers/cache_controller.py +4 -1
  40. sglang/srt/managers/io_struct.py +14 -0
  41. sglang/srt/managers/schedule_batch.py +18 -39
  42. sglang/srt/managers/scheduler.py +3 -4
  43. sglang/srt/managers/tokenizer_manager.py +28 -18
  44. sglang/srt/mem_cache/allocator.py +8 -157
  45. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  46. sglang/srt/mem_cache/chunk_cache.py +1 -1
  47. sglang/srt/model_executor/cuda_graph_runner.py +8 -21
  48. sglang/srt/model_executor/forward_batch_info.py +8 -10
  49. sglang/srt/model_executor/model_runner.py +57 -53
  50. sglang/srt/models/deepseek_nextn.py +2 -1
  51. sglang/srt/models/deepseek_v2.py +5 -3
  52. sglang/srt/models/glm4_moe.py +2 -2
  53. sglang/srt/models/glm4_moe_nextn.py +2 -1
  54. sglang/srt/models/gpt_oss.py +7 -2
  55. sglang/srt/models/llama.py +10 -2
  56. sglang/srt/models/llama4.py +18 -5
  57. sglang/srt/models/qwen2.py +2 -2
  58. sglang/srt/models/qwen2_moe.py +20 -5
  59. sglang/srt/models/qwen3_classification.py +78 -0
  60. sglang/srt/models/qwen3_moe.py +18 -5
  61. sglang/srt/models/step3_vl.py +6 -2
  62. sglang/srt/operations.py +17 -2
  63. sglang/srt/sampling/sampling_batch_info.py +7 -4
  64. sglang/srt/server_args.py +33 -7
  65. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  66. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  67. sglang/srt/two_batch_overlap.py +4 -8
  68. sglang/test/test_marlin_moe.py +1 -1
  69. sglang/test/test_marlin_utils.py +1 -1
  70. sglang/version.py +1 -1
  71. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +5 -5
  72. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +75 -63
  73. sglang/srt/layers/quantization/scalar_type.py +0 -352
  74. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  75. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  76. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "2": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "4": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "8": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "16": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "24": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "32": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "48": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "64": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "96": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "128": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "256": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "512": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     }
+ }
@@ -0,0 +1,146 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "2": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "4": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "8": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "16": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "24": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "32": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "48": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "64": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "96": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "128": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "256": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "512": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     }
+ }
@@ -0,0 +1,146 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "2": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "4": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "8": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "16": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "24": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "32": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "48": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "64": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "96": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "128": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "256": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "512": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     }
+ }
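
Each of the new JSON files above is a fused-MoE Triton tuning table: the top-level keys are batch sizes (M), and each entry pins the tile sizes, group size, warp count, and pipeline depth used for that M. As a rough illustration of how such a table is typically consumed, the sketch below loads one file and picks the entry whose key is nearest to the runtime batch size; the helper names and the nearest-key policy are assumptions for illustration, not sglang's exact lookup code.

```python
import json
from typing import Dict


def load_moe_config(path: str) -> Dict[int, dict]:
    # Parse one tuning table; string keys like "1024" become integer batch sizes.
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}


def pick_config(table: Dict[int, dict], m: int) -> dict:
    # Assumed nearest-key policy: choose the tuned M closest to the runtime M.
    return table[min(table, key=lambda k: abs(k - m))]


# Hypothetical usage: for m=200, a table like the ones above resolves to its "256" entry.
# cfg = pick_config(load_moe_config("E=128,N=768,device_name=NVIDIA_H20.json"), 200)
# cfg["BLOCK_SIZE_M"], cfg["num_warps"], cfg["num_stages"]
```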
sglang/srt/layers/multimodal.py
@@ -17,57 +17,173 @@ import torch
  import triton
  import triton.language as tl
 
+ FMIX32_C1 = 0x85EBCA6B
+ FMIX32_C2 = 0xC2B2AE35
+ POS_C1 = 0x27D4EB2D
+ POS_C2 = 0x165667B1
+
+
+ @triton.jit
+ def _rotl32(x, r: tl.constexpr):
+     return (x << r) | (x >> (32 - r))
+
+
+ @triton.jit
+ def _fmix32(x, C1: tl.constexpr, C2: tl.constexpr):
+     c1 = tl.full((), C1, tl.uint32)
+     c2 = tl.full((), C2, tl.uint32)
+     x ^= x >> 16
+     x = x * c1
+     x ^= x >> 13
+     x = x * c2
+     x ^= x >> 16
+     return x
+
 
  @triton.jit
- def hash_kernel(
-     input_ptr,
-     output_ptr,
-     n_elements,
-     BLOCK_SIZE: tl.constexpr,
-     PRIME: tl.constexpr,
-     XCONST: tl.constexpr,
+ def hash_tiles32_kernel_blocked(
+     in_ptr,
+     out_ptr,
+     n_u32,
+     seed1,
+     seed2,
+     FM_C1: tl.constexpr,
+     FM_C2: tl.constexpr,
+     POS_A: tl.constexpr,
+     POS_B: tl.constexpr,
+     TILE: tl.constexpr,
+     BLOCK: tl.constexpr,
+     USE_CG: tl.constexpr,
  ):
      pid = tl.program_id(axis=0)
-     block_start = pid * BLOCK_SIZE
-     offsets = block_start + tl.arange(0, BLOCK_SIZE)
-     mask = offsets < n_elements
+     base = pid * TILE
+
+     s1 = tl.full((), seed1, tl.uint32)
+     s2 = tl.full((), seed2, tl.uint32)
+     posA = tl.full((), POS_A, tl.uint32)
+     posB = tl.full((), POS_B, tl.uint32)
+
+     h1 = tl.zeros((), dtype=tl.uint32)
+     h2 = tl.zeros((), dtype=tl.uint32)
+
+     for off in tl.static_range(0, TILE, BLOCK):
+         idx = base + off + tl.arange(0, BLOCK)
+         m = idx < n_u32
 
-     data = tl.load(input_ptr + offsets, mask=mask, other=0).to(tl.int64)
-     mixed = data ^ (offsets.to(tl.int64) + XCONST)
-     hash_val = mixed * PRIME
-     hash_val = hash_val ^ (hash_val >> 16)
-     hash_val = hash_val * (PRIME ^ XCONST)
-     hash_val = hash_val ^ (hash_val >> 13)
+         if USE_CG:
+             v = tl.load(in_ptr + idx, mask=m, other=0, cache_modifier=".cg")
+         else:
+             v = tl.load(in_ptr + idx, mask=m, other=0)
+         v = v.to(tl.uint32)
+
+         iu = idx.to(tl.uint32)
+         p1 = (iu * posA + s1) ^ _rotl32(iu, 15)
+         p2 = (iu * posB + s2) ^ _rotl32(iu, 13)
+
+         k1 = _fmix32(v ^ p1, C1=FM_C1, C2=FM_C2)
+         k2 = _fmix32(v ^ p2, C1=FM_C1, C2=FM_C2)
+
+         zero32 = tl.zeros_like(k1)
+         k1 = tl.where(m, k1, zero32)
+         k2 = tl.where(m, k2, zero32)
+
+         h1 += tl.sum(k1, axis=0).to(tl.uint32)
+         h2 += tl.sum(k2, axis=0).to(tl.uint32)
+
+     nbytes = tl.full((), n_u32 * 4, tl.uint32)
+     h1 ^= nbytes
+     h2 ^= nbytes
+     h1 = _fmix32(h1, C1=FM_C1, C2=FM_C2)
+     h2 = (
+         _fmix32(h2, C1=FMIX32_C1, C2=FMIX32_C2)
+         if False
+         else _fmix32(h2, C1=FM_C1, C2=FM_C2)
+     )
+
+     out = (h1.to(tl.uint64) << 32) | h2.to(tl.uint64)
+     tl.store(out_ptr + pid, out)
+
+
+ @triton.jit
+ def add_tree_reduce_u64_kernel(in_ptr, out_ptr, n_elems, CHUNK: tl.constexpr):
+     pid = tl.program_id(axis=0)
+     start = pid * CHUNK
+     h = tl.zeros((), dtype=tl.uint64)
+     for i in tl.static_range(0, CHUNK):
+         idx = start + i
+         m = idx < n_elems
+         v = tl.load(in_ptr + idx, mask=m, other=0).to(tl.uint64)
+         h += v
+     tl.store(out_ptr + pid, h)
 
-     tl.store(output_ptr + offsets, hash_val, mask=mask)
 
+ def _as_uint32_words(t: torch.Tensor) -> torch.Tensor:
+     assert t.is_cuda, "Use .cuda() first"
+     tb = t.contiguous().view(torch.uint8)
+     nbytes = tb.numel()
+     pad = (4 - (nbytes & 3)) & 3
+     if pad:
+         tb_p = torch.empty(nbytes + pad, dtype=torch.uint8, device=tb.device)
+         tb_p[:nbytes].copy_(tb)
+         tb_p[nbytes:].zero_()
+         tb = tb_p
+     return tb.view(torch.uint32)
 
- PRIME_1 = -(11400714785074694791 ^ 0xFFFFFFFFFFFFFFFF) - 1
- PRIME_2 = -(14029467366897019727 ^ 0xFFFFFFFFFFFFFFFF) - 1
 
+ def _final_splitmix64(x: int) -> int:
+     mask = (1 << 64) - 1
+     x &= mask
+     x ^= x >> 30
+     x = (x * 0xBF58476D1CE4E5B9) & mask
+     x ^= x >> 27
+     x = (x * 0x94D049BB133111EB) & mask
+     x ^= x >> 31
+     return x
 
- def gpu_tensor_hash(tensor: torch.Tensor) -> int:
-     assert tensor.is_cuda
-     tensor = tensor.contiguous().view(torch.int32)
-     n = tensor.numel()
-     BLOCK_SIZE = 1024
-     grid = (triton.cdiv(n, BLOCK_SIZE),)
 
-     intermediate_hashes = torch.empty(n, dtype=torch.int64, device=tensor.device)
+ @torch.inference_mode()
+ def gpu_tensor_hash(
+     tensor: torch.Tensor,
+     *,
+     seed: int = 0x243F6A88,
+     tile_words: int = 8192,
+     block_words: int = 256,
+     reduce_chunk: int = 1024,
+     num_warps: int = 4,
+     num_stages: int = 4,
+     use_cg: bool = True,
+ ) -> int:
+     assert tensor.is_cuda, "Use .cuda() first"
+     u32 = _as_uint32_words(tensor)
+     n = u32.numel()
+     if n == 0:
+         return 0
 
-     # Set cuda device to prevent ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
-     # Solution from Tri: https://github.com/Dao-AILab/flash-attention/issues/523#issuecomment-1707611579
-     with torch.cuda.device(tensor.device):
-         hash_kernel[grid](
-             tensor,
-             intermediate_hashes,
-             n,
-             BLOCK_SIZE=BLOCK_SIZE,
-             PRIME=PRIME_1,
-             XCONST=PRIME_2,
-         )
+     grid1 = (triton.cdiv(n, tile_words),)
+     partials = torch.empty(grid1[0], dtype=torch.uint64, device=u32.device)
+     hash_tiles32_kernel_blocked[grid1](
+         u32,
+         partials,
+         n,
+         seed1=seed & 0xFFFFFFFF,
+         seed2=((seed * 0x9E3779B1) ^ 0xDEADBEEF) & 0xFFFFFFFF,
+         FM_C1=FMIX32_C1,
+         FM_C2=FMIX32_C2,
+         POS_A=POS_C1,
+         POS_B=POS_C2,
+         TILE=tile_words,
+         BLOCK=block_words,
+         USE_CG=use_cg,
+         num_warps=num_warps,
+         num_stages=num_stages,
+     )
 
-     # TODO: threads can't be synced on triton kernel
-     final_hash = intermediate_hashes.sum().item()
+     cur = partials
+     while cur.numel() > 1:
+         n_elems = cur.numel()
+         grid2 = (triton.cdiv(n_elems, reduce_chunk),)
+         nxt = torch.empty(grid2[0], dtype=torch.uint64, device=cur.device)
+         add_tree_reduce_u64_kernel[grid2](cur, nxt, n_elems, CHUNK=reduce_chunk)
+         cur = nxt
 
-     return final_hash
+     return _final_splitmix64(int(cur.item()))
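
This hunk (the +156/-40 change to sglang/srt/layers/multimodal.py) replaces the old single-pass hash_kernel with a two-stage scheme: per-tile 32-bit mixing in Triton, a uint64 tree reduction across tiles, and a final splitmix64 finalizer on the host. Below is a minimal usage sketch of the new entry point, assuming the import path matches the file shown in the diff and a CUDA device is available.

```python
import torch

from sglang.srt.layers.multimodal import gpu_tensor_hash

# Hash the raw bytes of a CUDA tensor; the result is a 64-bit Python int.
x = torch.randn(3, 224, 224, device="cuda")
digest = gpu_tensor_hash(x)

# Identical contents produce identical digests: hashing runs over the tensor's
# uint32-word view (zero-padded to a multiple of 4 bytes), not over Python objects.
assert gpu_tensor_hash(x.clone()) == digest

# Seed and tile/block sizes are keyword-only knobs; a different seed almost
# surely yields a different digest for the same data.
digest_other_seed = gpu_tensor_hash(x, seed=1234)
```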