node-llama-cpp 3.0.0-beta.1 → 3.0.0-beta.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152) hide show
  1. package/README.md +2 -0
  2. package/dist/ChatWrapper.d.ts +49 -0
  3. package/dist/ChatWrapper.js +120 -0
  4. package/dist/ChatWrapper.js.map +1 -0
  5. package/dist/chatWrappers/AlpacaChatWrapper.d.ts +12 -0
  6. package/dist/chatWrappers/AlpacaChatWrapper.js +21 -0
  7. package/dist/chatWrappers/AlpacaChatWrapper.js.map +1 -0
  8. package/dist/chatWrappers/ChatMLChatWrapper.d.ts +13 -0
  9. package/dist/chatWrappers/ChatMLChatWrapper.js +83 -0
  10. package/dist/chatWrappers/ChatMLChatWrapper.js.map +1 -0
  11. package/dist/chatWrappers/EmptyChatWrapper.d.ts +4 -0
  12. package/dist/chatWrappers/EmptyChatWrapper.js +5 -0
  13. package/dist/chatWrappers/EmptyChatWrapper.js.map +1 -0
  14. package/dist/chatWrappers/FalconChatWrapper.d.ts +21 -0
  15. package/dist/chatWrappers/FalconChatWrapper.js +104 -0
  16. package/dist/chatWrappers/FalconChatWrapper.js.map +1 -0
  17. package/dist/chatWrappers/FunctionaryChatWrapper.d.ts +41 -0
  18. package/dist/chatWrappers/FunctionaryChatWrapper.js +200 -0
  19. package/dist/chatWrappers/FunctionaryChatWrapper.js.map +1 -0
  20. package/dist/chatWrappers/GeneralChatWrapper.d.ts +21 -0
  21. package/dist/chatWrappers/GeneralChatWrapper.js +112 -0
  22. package/dist/chatWrappers/GeneralChatWrapper.js.map +1 -0
  23. package/dist/chatWrappers/LlamaChatWrapper.d.ts +13 -0
  24. package/dist/chatWrappers/LlamaChatWrapper.js +78 -0
  25. package/dist/chatWrappers/LlamaChatWrapper.js.map +1 -0
  26. package/dist/chatWrappers/resolveChatWrapperBasedOnModel.d.ts +4 -4
  27. package/dist/chatWrappers/resolveChatWrapperBasedOnModel.js +24 -16
  28. package/dist/chatWrappers/resolveChatWrapperBasedOnModel.js.map +1 -1
  29. package/dist/cli/commands/ChatCommand.d.ts +2 -1
  30. package/dist/cli/commands/ChatCommand.js +71 -33
  31. package/dist/cli/commands/ChatCommand.js.map +1 -1
  32. package/dist/config.js +1 -1
  33. package/dist/config.js.map +1 -1
  34. package/dist/index.d.ts +17 -10
  35. package/dist/index.js +16 -8
  36. package/dist/index.js.map +1 -1
  37. package/dist/llamaEvaluator/LlamaBins.d.ts +0 -1
  38. package/dist/llamaEvaluator/LlamaChat/LlamaChat.d.ts +175 -0
  39. package/dist/llamaEvaluator/LlamaChat/LlamaChat.js +704 -0
  40. package/dist/llamaEvaluator/LlamaChat/LlamaChat.js.map +1 -0
  41. package/dist/llamaEvaluator/LlamaChat/utils/FunctionCallGrammar.d.ts +21 -0
  42. package/dist/llamaEvaluator/LlamaChat/utils/FunctionCallGrammar.js +120 -0
  43. package/dist/llamaEvaluator/LlamaChat/utils/FunctionCallGrammar.js.map +1 -0
  44. package/dist/llamaEvaluator/LlamaChat/utils/contextShiftStrategies/eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy.d.ts +16 -0
  45. package/dist/llamaEvaluator/LlamaChat/utils/contextShiftStrategies/eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy.js +117 -0
  46. package/dist/llamaEvaluator/LlamaChat/utils/contextShiftStrategies/eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy.js.map +1 -0
  47. package/dist/llamaEvaluator/{LlamaChatSession.d.ts → LlamaChatSession/LlamaChatSession.d.ts} +48 -25
  48. package/dist/llamaEvaluator/LlamaChatSession/LlamaChatSession.js +211 -0
  49. package/dist/llamaEvaluator/LlamaChatSession/LlamaChatSession.js.map +1 -0
  50. package/dist/llamaEvaluator/LlamaChatSession/utils/defineChatSessionFunction.d.ts +7 -0
  51. package/dist/llamaEvaluator/LlamaChatSession/utils/defineChatSessionFunction.js +8 -0
  52. package/dist/llamaEvaluator/LlamaChatSession/utils/defineChatSessionFunction.js.map +1 -0
  53. package/dist/llamaEvaluator/LlamaContext/LlamaContext.d.ts +18 -23
  54. package/dist/llamaEvaluator/LlamaContext/LlamaContext.js +60 -103
  55. package/dist/llamaEvaluator/LlamaContext/LlamaContext.js.map +1 -1
  56. package/dist/llamaEvaluator/LlamaContext/types.d.ts +6 -14
  57. package/dist/llamaEvaluator/LlamaEmbeddingContext.d.ts +35 -0
  58. package/dist/llamaEvaluator/LlamaEmbeddingContext.js +73 -0
  59. package/dist/llamaEvaluator/LlamaEmbeddingContext.js.map +1 -0
  60. package/dist/llamaEvaluator/LlamaGrammar.d.ts +8 -12
  61. package/dist/llamaEvaluator/LlamaGrammar.js +7 -12
  62. package/dist/llamaEvaluator/LlamaGrammar.js.map +1 -1
  63. package/dist/llamaEvaluator/LlamaJsonSchemaGrammar.js +2 -1
  64. package/dist/llamaEvaluator/LlamaJsonSchemaGrammar.js.map +1 -1
  65. package/dist/llamaEvaluator/LlamaModel.d.ts +10 -2
  66. package/dist/llamaEvaluator/LlamaModel.js +14 -3
  67. package/dist/llamaEvaluator/LlamaModel.js.map +1 -1
  68. package/dist/types.d.ts +41 -3
  69. package/dist/types.js +5 -1
  70. package/dist/types.js.map +1 -1
  71. package/dist/utils/LlamaText.d.ts +42 -0
  72. package/dist/utils/LlamaText.js +207 -0
  73. package/dist/utils/LlamaText.js.map +1 -0
  74. package/dist/utils/StopGenerationDetector.d.ts +28 -0
  75. package/dist/utils/StopGenerationDetector.js +205 -0
  76. package/dist/utils/StopGenerationDetector.js.map +1 -0
  77. package/dist/utils/TokenStreamRegulator.d.ts +30 -0
  78. package/dist/utils/TokenStreamRegulator.js +96 -0
  79. package/dist/utils/TokenStreamRegulator.js.map +1 -0
  80. package/dist/utils/appendUserMessageToChatHistory.d.ts +2 -0
  81. package/dist/utils/appendUserMessageToChatHistory.js +18 -0
  82. package/dist/utils/appendUserMessageToChatHistory.js.map +1 -0
  83. package/dist/utils/compareTokens.d.ts +2 -0
  84. package/dist/utils/compareTokens.js +4 -0
  85. package/dist/utils/compareTokens.js.map +1 -0
  86. package/dist/utils/compileLLamaCpp.js +11 -6
  87. package/dist/utils/compileLLamaCpp.js.map +1 -1
  88. package/dist/utils/findCharacterRemovalCountToFitChatHistoryInContext.d.ts +18 -0
  89. package/dist/utils/findCharacterRemovalCountToFitChatHistoryInContext.js +61 -0
  90. package/dist/utils/findCharacterRemovalCountToFitChatHistoryInContext.js.map +1 -0
  91. package/dist/utils/gbnfJson/GbnfGrammarGenerator.d.ts +1 -0
  92. package/dist/utils/gbnfJson/GbnfGrammarGenerator.js +17 -0
  93. package/dist/utils/gbnfJson/GbnfGrammarGenerator.js.map +1 -1
  94. package/dist/utils/gbnfJson/GbnfTerminal.d.ts +1 -1
  95. package/dist/utils/gbnfJson/GbnfTerminal.js.map +1 -1
  96. package/dist/utils/gbnfJson/terminals/GbnfVerbatimText.d.ts +6 -0
  97. package/dist/utils/gbnfJson/terminals/GbnfVerbatimText.js +21 -0
  98. package/dist/utils/gbnfJson/terminals/GbnfVerbatimText.js.map +1 -0
  99. package/dist/utils/gbnfJson/types.d.ts +1 -1
  100. package/dist/utils/gbnfJson/types.js.map +1 -1
  101. package/dist/utils/gbnfJson/utils/validateObjectAgainstGbnfSchema.d.ts +1 -0
  102. package/dist/utils/gbnfJson/utils/validateObjectAgainstGbnfSchema.js.map +1 -1
  103. package/dist/utils/getBin.d.ts +3 -2
  104. package/dist/utils/getGbnfGrammarForGbnfJsonSchema.js +1 -15
  105. package/dist/utils/getGbnfGrammarForGbnfJsonSchema.js.map +1 -1
  106. package/dist/utils/getTypeScriptTypeStringForGbnfJsonSchema.d.ts +2 -0
  107. package/dist/utils/getTypeScriptTypeStringForGbnfJsonSchema.js +49 -0
  108. package/dist/utils/getTypeScriptTypeStringForGbnfJsonSchema.js.map +1 -0
  109. package/dist/utils/resolveChatWrapper.d.ts +4 -0
  110. package/dist/utils/resolveChatWrapper.js +16 -0
  111. package/dist/utils/resolveChatWrapper.js.map +1 -0
  112. package/dist/utils/truncateTextAndRoundToWords.d.ts +8 -0
  113. package/dist/utils/truncateTextAndRoundToWords.js +27 -0
  114. package/dist/utils/truncateTextAndRoundToWords.js.map +1 -0
  115. package/llama/addon.cpp +45 -17
  116. package/llama/binariesGithubRelease.json +1 -1
  117. package/llama/gitRelease.bundle +0 -0
  118. package/llamaBins/linux-arm64/llama-addon.node +0 -0
  119. package/llamaBins/linux-armv7l/llama-addon.node +0 -0
  120. package/llamaBins/linux-x64/llama-addon.node +0 -0
  121. package/llamaBins/mac-arm64/llama-addon.node +0 -0
  122. package/llamaBins/mac-x64/llama-addon.node +0 -0
  123. package/llamaBins/win-x64/llama-addon.node +0 -0
  124. package/package.json +21 -9
  125. package/dist/ChatPromptWrapper.d.ts +0 -11
  126. package/dist/ChatPromptWrapper.js +0 -20
  127. package/dist/ChatPromptWrapper.js.map +0 -1
  128. package/dist/chatWrappers/ChatMLChatPromptWrapper.d.ts +0 -12
  129. package/dist/chatWrappers/ChatMLChatPromptWrapper.js +0 -22
  130. package/dist/chatWrappers/ChatMLChatPromptWrapper.js.map +0 -1
  131. package/dist/chatWrappers/EmptyChatPromptWrapper.d.ts +0 -4
  132. package/dist/chatWrappers/EmptyChatPromptWrapper.js +0 -5
  133. package/dist/chatWrappers/EmptyChatPromptWrapper.js.map +0 -1
  134. package/dist/chatWrappers/FalconChatPromptWrapper.d.ts +0 -19
  135. package/dist/chatWrappers/FalconChatPromptWrapper.js +0 -33
  136. package/dist/chatWrappers/FalconChatPromptWrapper.js.map +0 -1
  137. package/dist/chatWrappers/GeneralChatPromptWrapper.d.ts +0 -19
  138. package/dist/chatWrappers/GeneralChatPromptWrapper.js +0 -38
  139. package/dist/chatWrappers/GeneralChatPromptWrapper.js.map +0 -1
  140. package/dist/chatWrappers/LlamaChatPromptWrapper.d.ts +0 -12
  141. package/dist/chatWrappers/LlamaChatPromptWrapper.js +0 -23
  142. package/dist/chatWrappers/LlamaChatPromptWrapper.js.map +0 -1
  143. package/dist/chatWrappers/generateContextTextFromConversationHistory.d.ts +0 -15
  144. package/dist/chatWrappers/generateContextTextFromConversationHistory.js +0 -39
  145. package/dist/chatWrappers/generateContextTextFromConversationHistory.js.map +0 -1
  146. package/dist/llamaEvaluator/LlamaChatSession.js +0 -290
  147. package/dist/llamaEvaluator/LlamaChatSession.js.map +0 -1
  148. package/dist/utils/getTextCompletion.d.ts +0 -3
  149. package/dist/utils/getTextCompletion.js +0 -12
  150. package/dist/utils/getTextCompletion.js.map +0 -1
  151. package/llamaBins/mac-arm64/ggml-metal.metal +0 -2929
  152. package/llamaBins/mac-x64/ggml-metal.metal +0 -2929
@@ -1,2929 +0,0 @@
1
- #include <metal_stdlib>
2
-
3
- using namespace metal;
4
-
5
- #define MAX(x, y) ((x) > (y) ? (x) : (y))
6
-
7
- #define QK4_0 32
8
- #define QR4_0 2
9
- typedef struct {
10
- half d; // delta
11
- uint8_t qs[QK4_0 / 2]; // nibbles / quants
12
- } block_q4_0;
13
-
14
- #define QK4_1 32
15
- typedef struct {
16
- half d; // delta
17
- half m; // min
18
- uint8_t qs[QK4_1 / 2]; // nibbles / quants
19
- } block_q4_1;
20
-
21
- #define QK5_0 32
22
- typedef struct {
23
- half d; // delta
24
- uint8_t qh[4]; // 5-th bit of quants
25
- uint8_t qs[QK5_0 / 2]; // nibbles / quants
26
- } block_q5_0;
27
-
28
- #define QK5_1 32
29
- typedef struct {
30
- half d; // delta
31
- half m; // min
32
- uint8_t qh[4]; // 5-th bit of quants
33
- uint8_t qs[QK5_1 / 2]; // nibbles / quants
34
- } block_q5_1;
35
-
36
- #define QK8_0 32
37
- typedef struct {
38
- half d; // delta
39
- int8_t qs[QK8_0]; // quants
40
- } block_q8_0;
41
-
42
- // general-purpose kernel for addition of two tensors
43
- // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
44
- // cons: not very efficient
45
- kernel void kernel_add(
46
- device const char * src0,
47
- device const char * src1,
48
- device char * dst,
49
- constant int64_t & ne00,
50
- constant int64_t & ne01,
51
- constant int64_t & ne02,
52
- constant int64_t & ne03,
53
- constant int64_t & nb00,
54
- constant int64_t & nb01,
55
- constant int64_t & nb02,
56
- constant int64_t & nb03,
57
- constant int64_t & ne10,
58
- constant int64_t & ne11,
59
- constant int64_t & ne12,
60
- constant int64_t & ne13,
61
- constant int64_t & nb10,
62
- constant int64_t & nb11,
63
- constant int64_t & nb12,
64
- constant int64_t & nb13,
65
- constant int64_t & ne0,
66
- constant int64_t & ne1,
67
- constant int64_t & ne2,
68
- constant int64_t & ne3,
69
- constant int64_t & nb0,
70
- constant int64_t & nb1,
71
- constant int64_t & nb2,
72
- constant int64_t & nb3,
73
- uint3 tgpig[[threadgroup_position_in_grid]],
74
- uint3 tpitg[[thread_position_in_threadgroup]],
75
- uint3 ntg[[threads_per_threadgroup]]) {
76
- const int64_t i03 = tgpig.z;
77
- const int64_t i02 = tgpig.y;
78
- const int64_t i01 = tgpig.x;
79
-
80
- const int64_t i13 = i03 % ne13;
81
- const int64_t i12 = i02 % ne12;
82
- const int64_t i11 = i01 % ne11;
83
-
84
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
85
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
86
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
87
-
88
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
89
- ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
90
-
91
- src0_ptr += ntg.x*nb00;
92
- src1_ptr += ntg.x*nb10;
93
- dst_ptr += ntg.x*nb0;
94
- }
95
- }
96
-
97
- // assumption: src1 is a row
98
- // broadcast src1 into src0
99
- kernel void kernel_add_row(
100
- device const float4 * src0,
101
- device const float4 * src1,
102
- device float4 * dst,
103
- constant int64_t & nb [[buffer(27)]],
104
- uint tpig[[thread_position_in_grid]]) {
105
- dst[tpig] = src0[tpig] + src1[tpig % nb];
106
- }
107
-
108
- kernel void kernel_mul(
109
- device const float4 * src0,
110
- device const float4 * src1,
111
- device float4 * dst,
112
- uint tpig[[thread_position_in_grid]]) {
113
- dst[tpig] = src0[tpig] * src1[tpig];
114
- }
115
-
116
- // assumption: src1 is a row
117
- // broadcast src1 into src0
118
- kernel void kernel_mul_row(
119
- device const float4 * src0,
120
- device const float4 * src1,
121
- device float4 * dst,
122
- constant int64_t & nb,
123
- uint tpig[[thread_position_in_grid]]) {
124
- dst[tpig] = src0[tpig] * src1[tpig % nb];
125
- }
126
-
127
- kernel void kernel_scale(
128
- device const float * src0,
129
- device float * dst,
130
- constant float & scale,
131
- uint tpig[[thread_position_in_grid]]) {
132
- dst[tpig] = src0[tpig] * scale;
133
- }
134
-
135
- kernel void kernel_scale_4(
136
- device const float4 * src0,
137
- device float4 * dst,
138
- constant float & scale,
139
- uint tpig[[thread_position_in_grid]]) {
140
- dst[tpig] = src0[tpig] * scale;
141
- }
142
-
143
- kernel void kernel_silu(
144
- device const float4 * src0,
145
- device float4 * dst,
146
- uint tpig[[thread_position_in_grid]]) {
147
- device const float4 & x = src0[tpig];
148
- dst[tpig] = x / (1.0f + exp(-x));
149
- }
150
-
151
- kernel void kernel_relu(
152
- device const float * src0,
153
- device float * dst,
154
- uint tpig[[thread_position_in_grid]]) {
155
- dst[tpig] = max(0.0f, src0[tpig]);
156
- }
157
-
158
- kernel void kernel_sqr(
159
- device const float * src0,
160
- device float * dst,
161
- uint tpig[[thread_position_in_grid]]) {
162
- dst[tpig] = src0[tpig] * src0[tpig];
163
- }
164
-
165
- constant float GELU_COEF_A = 0.044715f;
166
- constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
167
-
168
- kernel void kernel_gelu(
169
- device const float4 * src0,
170
- device float4 * dst,
171
- uint tpig[[thread_position_in_grid]]) {
172
- device const float4 & x = src0[tpig];
173
-
174
- // BEWARE !!!
175
- // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
176
- // This was observed with Falcon 7B and 40B models
177
- //
178
- dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
179
- }
180
-
181
- kernel void kernel_soft_max(
182
- device const float * src0,
183
- device float * dst,
184
- constant int64_t & ne00,
185
- constant int64_t & ne01,
186
- constant int64_t & ne02,
187
- threadgroup float * buf [[threadgroup(0)]],
188
- uint tgpig[[threadgroup_position_in_grid]],
189
- uint tpitg[[thread_position_in_threadgroup]],
190
- uint sgitg[[simdgroup_index_in_threadgroup]],
191
- uint tiisg[[thread_index_in_simdgroup]],
192
- uint ntg[[threads_per_threadgroup]]) {
193
- const int64_t i03 = (tgpig) / (ne02*ne01);
194
- const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
195
- const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
196
-
197
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
198
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
199
-
200
- // parallel max
201
- float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
202
-
203
- for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
204
- lmax = MAX(lmax, psrc0[i00]);
205
- }
206
-
207
- float max = simd_max(lmax);
208
- if (tiisg == 0) {
209
- buf[sgitg] = max;
210
- }
211
-
212
- threadgroup_barrier(mem_flags::mem_threadgroup);
213
-
214
- // broadcast, simd group number is ntg / 32
215
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
216
- if (tpitg < i) {
217
- buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
218
- }
219
- }
220
-
221
- threadgroup_barrier(mem_flags::mem_threadgroup);
222
-
223
- max = buf[0];
224
-
225
- // parallel sum
226
- float lsum = 0.0f;
227
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
228
- const float exp_psrc0 = exp(psrc0[i00] - max);
229
- lsum += exp_psrc0;
230
- // Remember the result of exp here. exp is expensive, so we really do not
231
- // wish to compute it twice.
232
- pdst[i00] = exp_psrc0;
233
- }
234
-
235
- float sum = simd_sum(lsum);
236
- if (tiisg == 0) {
237
- buf[sgitg] = sum;
238
- }
239
-
240
- threadgroup_barrier(mem_flags::mem_threadgroup);
241
-
242
- // broadcast, simd group number is ntg / 32
243
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
244
- if (tpitg < i) {
245
- buf[tpitg] += buf[tpitg + i];
246
- }
247
- }
248
-
249
- threadgroup_barrier(mem_flags::mem_threadgroup);
250
-
251
- sum = buf[0];
252
-
253
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
254
- pdst[i00] /= sum;
255
- }
256
- }
257
-
258
- kernel void kernel_soft_max_4(
259
- device const float * src0,
260
- device float * dst,
261
- constant int64_t & ne00,
262
- constant int64_t & ne01,
263
- constant int64_t & ne02,
264
- threadgroup float * buf [[threadgroup(0)]],
265
- uint tgpig[[threadgroup_position_in_grid]],
266
- uint tpitg[[thread_position_in_threadgroup]],
267
- uint sgitg[[simdgroup_index_in_threadgroup]],
268
- uint tiisg[[thread_index_in_simdgroup]],
269
- uint ntg[[threads_per_threadgroup]]) {
270
- const int64_t i03 = (tgpig) / (ne02*ne01);
271
- const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
272
- const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
273
-
274
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
275
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
276
-
277
- // parallel max
278
- float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
279
-
280
- for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
281
- lmax4 = fmax(lmax4, psrc4[i00]);
282
- }
283
-
284
- const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
285
- float max = simd_max(lmax);
286
- if (tiisg == 0) {
287
- buf[sgitg] = max;
288
- }
289
-
290
- threadgroup_barrier(mem_flags::mem_threadgroup);
291
-
292
- // broadcast, simd group number is ntg / 32
293
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
294
- if (tpitg < i) {
295
- buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
296
- }
297
- }
298
-
299
- threadgroup_barrier(mem_flags::mem_threadgroup);
300
-
301
- max = buf[0];
302
-
303
- // parallel sum
304
- float4 lsum4 = 0.0f;
305
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
306
- const float4 exp_psrc4 = exp(psrc4[i00] - max);
307
- lsum4 += exp_psrc4;
308
- pdst4[i00] = exp_psrc4;
309
- }
310
-
311
- const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
312
- float sum = simd_sum(lsum);
313
- if (tiisg == 0) {
314
- buf[sgitg] = sum;
315
- }
316
-
317
- threadgroup_barrier(mem_flags::mem_threadgroup);
318
-
319
- // broadcast, simd group number is ntg / 32
320
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
321
- if (tpitg < i) {
322
- buf[tpitg] += buf[tpitg + i];
323
- }
324
- }
325
-
326
- threadgroup_barrier(mem_flags::mem_threadgroup);
327
-
328
- sum = buf[0];
329
-
330
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
331
- pdst4[i00] /= sum;
332
- }
333
- }
334
-
335
- kernel void kernel_diag_mask_inf(
336
- device const float * src0,
337
- device float * dst,
338
- constant int64_t & ne00,
339
- constant int64_t & ne01,
340
- constant int & n_past,
341
- uint3 tpig[[thread_position_in_grid]]) {
342
- const int64_t i02 = tpig[2];
343
- const int64_t i01 = tpig[1];
344
- const int64_t i00 = tpig[0];
345
-
346
- if (i00 > n_past + i01) {
347
- dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
348
- } else {
349
- dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
350
- }
351
- }
352
-
353
- kernel void kernel_diag_mask_inf_8(
354
- device const float4 * src0,
355
- device float4 * dst,
356
- constant int64_t & ne00,
357
- constant int64_t & ne01,
358
- constant int & n_past,
359
- uint3 tpig[[thread_position_in_grid]]) {
360
-
361
- const int64_t i = 2*tpig[0];
362
-
363
- dst[i+0] = src0[i+0];
364
- dst[i+1] = src0[i+1];
365
- int64_t i4 = 4*i;
366
- const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
367
- const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
368
- const int64_t i00 = i4;
369
- for (int k = 3; k >= 0; --k) {
370
- if (i00 + 4 + k <= n_past + i01) {
371
- break;
372
- }
373
- dst[i+1][k] = -INFINITY;
374
- if (i00 + k > n_past + i01) {
375
- dst[i][k] = -INFINITY;
376
- }
377
- }
378
- }
379
-
380
- kernel void kernel_norm(
381
- device const void * src0,
382
- device float * dst,
383
- constant int64_t & ne00,
384
- constant uint64_t & nb01,
385
- constant float & eps,
386
- threadgroup float * sum [[threadgroup(0)]],
387
- uint tgpig[[threadgroup_position_in_grid]],
388
- uint tpitg[[thread_position_in_threadgroup]],
389
- uint ntg[[threads_per_threadgroup]]) {
390
- device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
391
- // MEAN
392
- // parallel sum
393
- sum[tpitg] = 0.0f;
394
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
395
- sum[tpitg] += x[i00];
396
- }
397
- // reduce
398
- threadgroup_barrier(mem_flags::mem_threadgroup);
399
- for (uint i = ntg/2; i > 0; i /= 2) {
400
- if (tpitg < i) {
401
- sum[tpitg] += sum[tpitg + i];
402
- }
403
- threadgroup_barrier(mem_flags::mem_threadgroup);
404
- }
405
- const float mean = sum[0] / ne00;
406
-
407
- // recenter and VARIANCE
408
- threadgroup_barrier(mem_flags::mem_threadgroup);
409
- device float * y = dst + tgpig*ne00;
410
- sum[tpitg] = 0.0f;
411
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
412
- y[i00] = x[i00] - mean;
413
- sum[tpitg] += y[i00] * y[i00];
414
- }
415
-
416
- // reduce
417
- threadgroup_barrier(mem_flags::mem_threadgroup);
418
- for (uint i = ntg/2; i > 0; i /= 2) {
419
- if (tpitg < i) {
420
- sum[tpitg] += sum[tpitg + i];
421
- }
422
- threadgroup_barrier(mem_flags::mem_threadgroup);
423
- }
424
- const float variance = sum[0] / ne00;
425
-
426
- const float scale = 1.0f/sqrt(variance + eps);
427
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
428
- y[i00] = y[i00] * scale;
429
- }
430
- }
431
-
432
- kernel void kernel_rms_norm(
433
- device const void * src0,
434
- device float * dst,
435
- constant int64_t & ne00,
436
- constant uint64_t & nb01,
437
- constant float & eps,
438
- threadgroup float * sum [[threadgroup(0)]],
439
- uint tgpig[[threadgroup_position_in_grid]],
440
- uint tpitg[[thread_position_in_threadgroup]],
441
- uint sgitg[[simdgroup_index_in_threadgroup]],
442
- uint tiisg[[thread_index_in_simdgroup]],
443
- uint ntg[[threads_per_threadgroup]]) {
444
- device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
445
- device const float * x_scalar = (device const float *) x;
446
-
447
- float4 sumf = 0;
448
- float all_sum = 0;
449
-
450
- // parallel sum
451
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
452
- sumf += x[i00] * x[i00];
453
- }
454
- all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
455
- all_sum = simd_sum(all_sum);
456
- if (tiisg == 0) {
457
- sum[sgitg] = all_sum;
458
- }
459
-
460
- threadgroup_barrier(mem_flags::mem_threadgroup);
461
-
462
- // broadcast, simd group number is ntg / 32
463
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
464
- if (tpitg < i) {
465
- sum[tpitg] += sum[tpitg + i];
466
- }
467
- }
468
- if (tpitg == 0) {
469
- for (int i = 4 * (ne00 / 4); i < ne00; i++) {
470
- sum[0] += x_scalar[i];
471
- }
472
- sum[0] /= ne00;
473
- }
474
-
475
- threadgroup_barrier(mem_flags::mem_threadgroup);
476
-
477
- const float mean = sum[0];
478
- const float scale = 1.0f/sqrt(mean + eps);
479
-
480
- device float4 * y = (device float4 *) (dst + tgpig*ne00);
481
- device float * y_scalar = (device float *) y;
482
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
483
- y[i00] = x[i00] * scale;
484
- }
485
- if (tpitg == 0) {
486
- for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
487
- y_scalar[i00] = x_scalar[i00] * scale;
488
- }
489
- }
490
- }
491
-
492
- // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
493
- // il indicates where the q4 quants begin (0 or QK4_0/4)
494
- // we assume that the yl's have been multiplied with the appropriate scale factor
495
- // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
496
- inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
497
- float d = qb_curr->d;
498
-
499
- float2 acc = 0.f;
500
-
501
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
502
-
503
- for (int i = 0; i < 8; i+=2) {
504
- acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
505
- + yl[i + 1] * (qs[i / 2] & 0x0F00);
506
- acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
507
- + yl[i + 9] * (qs[i / 2] & 0xF000);
508
- }
509
- return d * (sumy * -8.f + acc[0] + acc[1]);
510
- }
511
-
512
- // function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
513
- // il indicates where the q4 quants begin (0 or QK4_0/4)
514
- // we assume that the yl's have been multiplied with the appropriate scale factor
515
- // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
516
- inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
517
- float d = qb_curr->d;
518
- float m = qb_curr->m;
519
-
520
- float2 acc = 0.f;
521
-
522
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
523
-
524
- for (int i = 0; i < 8; i+=2) {
525
- acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
526
- + yl[i + 1] * (qs[i / 2] & 0x0F00);
527
- acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
528
- + yl[i + 9] * (qs[i / 2] & 0xF000);
529
- }
530
- return d * (acc[0] + acc[1]) + sumy * m;
531
- }
532
-
533
- // function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
534
- // il indicates where the q5 quants begin (0 or QK5_0/4)
535
- // we assume that the yl's have been multiplied with the appropriate scale factor
536
- // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
537
- inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
538
- float d = qb_curr->d;
539
-
540
- float2 acc = 0.f;
541
-
542
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
543
- const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
544
-
545
- for (int i = 0; i < 8; i+=2) {
546
- acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
547
- + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
548
- acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
549
- + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
550
- }
551
- return d * (sumy * -16.f + acc[0] + acc[1]);
552
- }
553
-
554
- // function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
555
- // il indicates where the q5 quants begin (0 or QK5_1/4)
556
- // we assume that the yl's have been multiplied with the appropriate scale factor
557
- // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
558
- inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
559
- float d = qb_curr->d;
560
- float m = qb_curr->m;
561
-
562
- float2 acc = 0.f;
563
-
564
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
565
- const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
566
-
567
- for (int i = 0; i < 8; i+=2) {
568
- acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
569
- + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
570
- acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
571
- + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
572
- }
573
- return d * (acc[0] + acc[1]) + sumy * m;
574
- }
575
-
576
- // putting them in the kernel cause a significant performance penalty
577
- #define N_DST 4 // each SIMD group works on 4 rows
578
- #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
579
- #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
580
- //Note: This is a template, but strictly speaking it only applies to
581
- // quantizations where the block size is 32. It also does not
582
- // giard against the number of rows not being divisible by
583
- // N_DST, so this is another explicit assumption of the implementation.
584
- template<typename block_q_type, int nr, int nsg, int nw>
585
- void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
586
- int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
587
- uint3 tgpig, uint tiisg, uint sgitg) {
588
- const int nb = ne00/QK4_0;
589
-
590
- const int r0 = tgpig.x;
591
- const int r1 = tgpig.y;
592
- const int im = tgpig.z;
593
-
594
- const int first_row = (r0 * nsg + sgitg) * nr;
595
-
596
- const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
597
-
598
- device const block_q_type * x = (device const block_q_type *) src0 + offset0;
599
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
600
-
601
- float yl[16]; // src1 vector cache
602
- float sumf[nr] = {0.f};
603
-
604
- const int ix = (tiisg/2);
605
- const int il = (tiisg%2)*8;
606
-
607
- device const float * yb = y + ix * QK4_0 + il;
608
-
609
- // each thread in a SIMD group deals with half a block.
610
- for (int ib = ix; ib < nb; ib += nw/2) {
611
- float sumy = 0;
612
- for (int i = 0; i < 8; i += 2) {
613
- sumy += yb[i] + yb[i+1];
614
- yl[i+0] = yb[i+ 0];
615
- yl[i+1] = yb[i+ 1]/256.f;
616
-
617
- sumy += yb[i+16] + yb[i+17];
618
- yl[i+8] = yb[i+16]/16.f;
619
- yl[i+9] = yb[i+17]/4096.f;
620
- }
621
-
622
- for (int row = 0; row < nr; row++) {
623
- sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
624
- }
625
-
626
- yb += QK4_0 * 16;
627
- }
628
-
629
- for (int row = 0; row < nr; ++row) {
630
- const float tot = simd_sum(sumf[row]);
631
- if (tiisg == 0 && first_row + row < ne01) {
632
- dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
633
- }
634
- }
635
- }
636
-
637
- kernel void kernel_mul_mv_q4_0_f32(
638
- device const void * src0,
639
- device const float * src1,
640
- device float * dst,
641
- constant int64_t & ne00,
642
- constant int64_t & ne01[[buffer(4)]],
643
- constant int64_t & ne02[[buffer(5)]],
644
- constant int64_t & ne10[[buffer(9)]],
645
- constant int64_t & ne12[[buffer(11)]],
646
- constant int64_t & ne0[[buffer(15)]],
647
- constant int64_t & ne1[[buffer(16)]],
648
- constant uint & gqa[[buffer(17)]],
649
- uint3 tgpig[[threadgroup_position_in_grid]],
650
- uint tiisg[[thread_index_in_simdgroup]],
651
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
652
- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
653
- }
654
-
655
- kernel void kernel_mul_mv_q4_1_f32(
656
- device const void * src0,
657
- device const float * src1,
658
- device float * dst,
659
- constant int64_t & ne00,
660
- constant int64_t & ne01[[buffer(4)]],
661
- constant int64_t & ne02[[buffer(5)]],
662
- constant int64_t & ne10[[buffer(9)]],
663
- constant int64_t & ne12[[buffer(11)]],
664
- constant int64_t & ne0[[buffer(15)]],
665
- constant int64_t & ne1[[buffer(16)]],
666
- constant uint & gqa[[buffer(17)]],
667
- uint3 tgpig[[threadgroup_position_in_grid]],
668
- uint tiisg[[thread_index_in_simdgroup]],
669
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
670
- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
671
- }
672
-
673
- kernel void kernel_mul_mv_q5_0_f32(
674
- device const void * src0,
675
- device const float * src1,
676
- device float * dst,
677
- constant int64_t & ne00,
678
- constant int64_t & ne01[[buffer(4)]],
679
- constant int64_t & ne02[[buffer(5)]],
680
- constant int64_t & ne10[[buffer(9)]],
681
- constant int64_t & ne12[[buffer(11)]],
682
- constant int64_t & ne0[[buffer(15)]],
683
- constant int64_t & ne1[[buffer(16)]],
684
- constant uint & gqa[[buffer(17)]],
685
- uint3 tgpig[[threadgroup_position_in_grid]],
686
- uint tiisg[[thread_index_in_simdgroup]],
687
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
688
- mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
689
- }
690
-
691
- kernel void kernel_mul_mv_q5_1_f32(
692
- device const void * src0,
693
- device const float * src1,
694
- device float * dst,
695
- constant int64_t & ne00,
696
- constant int64_t & ne01[[buffer(4)]],
697
- constant int64_t & ne02[[buffer(5)]],
698
- constant int64_t & ne10[[buffer(9)]],
699
- constant int64_t & ne12[[buffer(11)]],
700
- constant int64_t & ne0[[buffer(15)]],
701
- constant int64_t & ne1[[buffer(16)]],
702
- constant uint & gqa[[buffer(17)]],
703
- uint3 tgpig[[threadgroup_position_in_grid]],
704
- uint tiisg[[thread_index_in_simdgroup]],
705
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
706
- mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
707
- }
708
-
709
-
710
- #define NB_Q8_0 8
711
-
712
- kernel void kernel_mul_mv_q8_0_f32(
713
- device const void * src0,
714
- device const float * src1,
715
- device float * dst,
716
- constant int64_t & ne00,
717
- constant int64_t & ne01[[buffer(4)]],
718
- constant int64_t & ne02[[buffer(5)]],
719
- constant int64_t & ne10[[buffer(9)]],
720
- constant int64_t & ne12[[buffer(11)]],
721
- constant int64_t & ne0[[buffer(15)]],
722
- constant int64_t & ne1[[buffer(16)]],
723
- constant uint & gqa[[buffer(17)]],
724
- uint3 tgpig[[threadgroup_position_in_grid]],
725
- uint tiisg[[thread_index_in_simdgroup]],
726
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
727
- const int nr = N_DST;
728
- const int nsg = N_SIMDGROUP;
729
- const int nw = N_SIMDWIDTH;
730
-
731
- const int nb = ne00/QK8_0;
732
- const int r0 = tgpig.x;
733
- const int r1 = tgpig.y;
734
- const int im = tgpig.z;
735
- const int first_row = (r0 * nsg + sgitg) * nr;
736
- const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
737
- device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
738
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
739
-
740
- float yl[NB_Q8_0];
741
- float sumf[nr]={0.f};
742
-
743
- const int ix = tiisg/4;
744
- const int il = tiisg%4;
745
-
746
- device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
747
-
748
- // each thread in a SIMD group deals with NB_Q8_0 quants at a time
749
- for (int ib = ix; ib < nb; ib += nw/4) {
750
- for (int i = 0; i < NB_Q8_0; ++i) {
751
- yl[i] = yb[i];
752
- }
753
-
754
- for (int row = 0; row < nr; row++) {
755
- device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
756
- float sumq = 0.f;
757
- for (int iq = 0; iq < NB_Q8_0; ++iq) {
758
- sumq += qs[iq] * yl[iq];
759
- }
760
- sumf[row] += sumq*x[ib+row*nb].d;
761
- }
762
-
763
- yb += NB_Q8_0 * nw;
764
- }
765
-
766
- for (int row = 0; row < nr; ++row) {
767
- const float tot = simd_sum(sumf[row]);
768
- if (tiisg == 0 && first_row + row < ne01) {
769
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
770
- }
771
- }
772
- }
773
-
774
- #define N_F32_F32 4
775
-
776
- kernel void kernel_mul_mv_f32_f32(
777
- device const char * src0,
778
- device const char * src1,
779
- device float * dst,
780
- constant int64_t & ne00,
781
- constant int64_t & ne01,
782
- constant int64_t & ne02,
783
- constant uint64_t & nb00,
784
- constant uint64_t & nb01,
785
- constant uint64_t & nb02,
786
- constant int64_t & ne10,
787
- constant int64_t & ne11,
788
- constant int64_t & ne12,
789
- constant uint64_t & nb10,
790
- constant uint64_t & nb11,
791
- constant uint64_t & nb12,
792
- constant int64_t & ne0,
793
- constant int64_t & ne1,
794
- uint3 tgpig[[threadgroup_position_in_grid]],
795
- uint tiisg[[thread_index_in_simdgroup]]) {
796
-
797
- const int64_t r0 = tgpig.x;
798
- const int64_t rb = tgpig.y*N_F32_F32;
799
- const int64_t im = tgpig.z;
800
-
801
- device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
802
-
803
- if (ne00 < 128) {
804
- for (int row = 0; row < N_F32_F32; ++row) {
805
- int r1 = rb + row;
806
- if (r1 >= ne11) {
807
- break;
808
- }
809
-
810
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
811
-
812
- float sumf = 0;
813
- for (int i = tiisg; i < ne00; i += 32) {
814
- sumf += (float) x[i] * (float) y[i];
815
- }
816
-
817
- float all_sum = simd_sum(sumf);
818
- if (tiisg == 0) {
819
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
820
- }
821
- }
822
- } else {
823
- device const float4 * x4 = (device const float4 *)x;
824
- for (int row = 0; row < N_F32_F32; ++row) {
825
- int r1 = rb + row;
826
- if (r1 >= ne11) {
827
- break;
828
- }
829
-
830
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
831
- device const float4 * y4 = (device const float4 *) y;
832
-
833
- float sumf = 0;
834
- for (int i = tiisg; i < ne00/4; i += 32) {
835
- for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
836
- }
837
-
838
- float all_sum = simd_sum(sumf);
839
- if (tiisg == 0) {
840
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
841
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
842
- }
843
- }
844
- }
845
- }
846
-
847
- #define N_F16_F16 4
848
-
849
- kernel void kernel_mul_mv_f16_f16(
850
- device const char * src0,
851
- device const char * src1,
852
- device float * dst,
853
- constant int64_t & ne00,
854
- constant int64_t & ne01,
855
- constant int64_t & ne02,
856
- constant uint64_t & nb00,
857
- constant uint64_t & nb01,
858
- constant uint64_t & nb02,
859
- constant int64_t & ne10,
860
- constant int64_t & ne11,
861
- constant int64_t & ne12,
862
- constant uint64_t & nb10,
863
- constant uint64_t & nb11,
864
- constant uint64_t & nb12,
865
- constant int64_t & ne0,
866
- constant int64_t & ne1,
867
- uint3 tgpig[[threadgroup_position_in_grid]],
868
- uint tiisg[[thread_index_in_simdgroup]]) {
869
-
870
- const int64_t r0 = tgpig.x;
871
- const int64_t rb = tgpig.y*N_F16_F16;
872
- const int64_t im = tgpig.z;
873
-
874
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
875
-
876
- if (ne00 < 128) {
877
- for (int row = 0; row < N_F16_F16; ++row) {
878
- int r1 = rb + row;
879
- if (r1 >= ne11) {
880
- break;
881
- }
882
-
883
- device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
884
-
885
- float sumf = 0;
886
- for (int i = tiisg; i < ne00; i += 32) {
887
- sumf += (half) x[i] * (half) y[i];
888
- }
889
-
890
- float all_sum = simd_sum(sumf);
891
- if (tiisg == 0) {
892
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
893
- }
894
- }
895
- } else {
896
- device const half4 * x4 = (device const half4 *)x;
897
- for (int row = 0; row < N_F16_F16; ++row) {
898
- int r1 = rb + row;
899
- if (r1 >= ne11) {
900
- break;
901
- }
902
-
903
- device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
904
- device const half4 * y4 = (device const half4 *) y;
905
-
906
- float sumf = 0;
907
- for (int i = tiisg; i < ne00/4; i += 32) {
908
- for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
909
- }
910
-
911
- float all_sum = simd_sum(sumf);
912
- if (tiisg == 0) {
913
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
914
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
915
- }
916
- }
917
- }
918
- }
919
-
920
- kernel void kernel_mul_mv_f16_f32_1row(
921
- device const char * src0,
922
- device const char * src1,
923
- device float * dst,
924
- constant int64_t & ne00,
925
- constant int64_t & ne01,
926
- constant int64_t & ne02,
927
- constant uint64_t & nb00,
928
- constant uint64_t & nb01,
929
- constant uint64_t & nb02,
930
- constant int64_t & ne10,
931
- constant int64_t & ne11,
932
- constant int64_t & ne12,
933
- constant uint64_t & nb10,
934
- constant uint64_t & nb11,
935
- constant uint64_t & nb12,
936
- constant int64_t & ne0,
937
- constant int64_t & ne1,
938
- uint3 tgpig[[threadgroup_position_in_grid]],
939
- uint tiisg[[thread_index_in_simdgroup]]) {
940
-
941
- const int64_t r0 = tgpig.x;
942
- const int64_t r1 = tgpig.y;
943
- const int64_t im = tgpig.z;
944
-
945
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
946
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
947
-
948
- float sumf = 0;
949
- if (ne00 < 128) {
950
- for (int i = tiisg; i < ne00; i += 32) {
951
- sumf += (float) x[i] * (float) y[i];
952
- }
953
- float all_sum = simd_sum(sumf);
954
- if (tiisg == 0) {
955
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
956
- }
957
- } else {
958
- device const half4 * x4 = (device const half4 *) x;
959
- device const float4 * y4 = (device const float4 *) y;
960
- for (int i = tiisg; i < ne00/4; i += 32) {
961
- for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
962
- }
963
- float all_sum = simd_sum(sumf);
964
- if (tiisg == 0) {
965
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
966
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
967
- }
968
- }
969
-
970
- }
971
-
972
- #define N_F16_F32 4
973
-
974
- kernel void kernel_mul_mv_f16_f32(
975
- device const char * src0,
976
- device const char * src1,
977
- device float * dst,
978
- constant int64_t & ne00,
979
- constant int64_t & ne01,
980
- constant int64_t & ne02,
981
- constant uint64_t & nb00,
982
- constant uint64_t & nb01,
983
- constant uint64_t & nb02,
984
- constant int64_t & ne10,
985
- constant int64_t & ne11,
986
- constant int64_t & ne12,
987
- constant uint64_t & nb10,
988
- constant uint64_t & nb11,
989
- constant uint64_t & nb12,
990
- constant int64_t & ne0,
991
- constant int64_t & ne1,
992
- uint3 tgpig[[threadgroup_position_in_grid]],
993
- uint tiisg[[thread_index_in_simdgroup]]) {
994
-
995
- const int64_t r0 = tgpig.x;
996
- const int64_t rb = tgpig.y*N_F16_F32;
997
- const int64_t im = tgpig.z;
998
-
999
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1000
-
1001
- if (ne00 < 128) {
1002
- for (int row = 0; row < N_F16_F32; ++row) {
1003
- int r1 = rb + row;
1004
- if (r1 >= ne11) {
1005
- break;
1006
- }
1007
-
1008
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
1009
-
1010
- float sumf = 0;
1011
- for (int i = tiisg; i < ne00; i += 32) {
1012
- sumf += (float) x[i] * (float) y[i];
1013
- }
1014
-
1015
- float all_sum = simd_sum(sumf);
1016
- if (tiisg == 0) {
1017
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1018
- }
1019
- }
1020
- } else {
1021
- device const half4 * x4 = (device const half4 *)x;
1022
- for (int row = 0; row < N_F16_F32; ++row) {
1023
- int r1 = rb + row;
1024
- if (r1 >= ne11) {
1025
- break;
1026
- }
1027
-
1028
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
1029
- device const float4 * y4 = (device const float4 *) y;
1030
-
1031
- float sumf = 0;
1032
- for (int i = tiisg; i < ne00/4; i += 32) {
1033
- for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
1034
- }
1035
-
1036
- float all_sum = simd_sum(sumf);
1037
- if (tiisg == 0) {
1038
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
1039
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1040
- }
1041
- }
1042
- }
1043
- }
1044
-
1045
- // Assumes row size (ne00) is a multiple of 4
1046
- kernel void kernel_mul_mv_f16_f32_l4(
1047
- device const char * src0,
1048
- device const char * src1,
1049
- device float * dst,
1050
- constant int64_t & ne00,
1051
- constant int64_t & ne01,
1052
- constant int64_t & ne02,
1053
- constant uint64_t & nb00,
1054
- constant uint64_t & nb01,
1055
- constant uint64_t & nb02,
1056
- constant int64_t & ne10,
1057
- constant int64_t & ne11,
1058
- constant int64_t & ne12,
1059
- constant uint64_t & nb10,
1060
- constant uint64_t & nb11,
1061
- constant uint64_t & nb12,
1062
- constant int64_t & ne0,
1063
- constant int64_t & ne1,
1064
- uint3 tgpig[[threadgroup_position_in_grid]],
1065
- uint tiisg[[thread_index_in_simdgroup]]) {
1066
-
1067
- const int nrows = ne11;
1068
- const int64_t r0 = tgpig.x;
1069
- const int64_t im = tgpig.z;
1070
-
1071
- device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1072
-
1073
- for (int r1 = 0; r1 < nrows; ++r1) {
1074
- device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
1075
-
1076
- float sumf = 0;
1077
- for (int i = tiisg; i < ne00/4; i += 32) {
1078
- for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
1079
- }
1080
-
1081
- float all_sum = simd_sum(sumf);
1082
- if (tiisg == 0) {
1083
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1084
- }
1085
- }
1086
- }
1087
-
1088
- kernel void kernel_alibi_f32(
1089
- device const float * src0,
1090
- device float * dst,
1091
- constant int64_t & ne00,
1092
- constant int64_t & ne01,
1093
- constant int64_t & ne02,
1094
- constant int64_t & ne03,
1095
- constant uint64_t & nb00,
1096
- constant uint64_t & nb01,
1097
- constant uint64_t & nb02,
1098
- constant uint64_t & nb03,
1099
- constant int64_t & ne0,
1100
- constant int64_t & ne1,
1101
- constant int64_t & ne2,
1102
- constant int64_t & ne3,
1103
- constant uint64_t & nb0,
1104
- constant uint64_t & nb1,
1105
- constant uint64_t & nb2,
1106
- constant uint64_t & nb3,
1107
- constant float & m0,
1108
- constant float & m1,
1109
- constant int & n_heads_log2_floor,
1110
- uint3 tgpig[[threadgroup_position_in_grid]],
1111
- uint3 tpitg[[thread_position_in_threadgroup]],
1112
- uint3 ntg[[threads_per_threadgroup]]) {
1113
- const int64_t i03 = tgpig[2];
1114
- const int64_t i02 = tgpig[1];
1115
- const int64_t i01 = tgpig[0];
1116
-
1117
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1118
-
1119
- const int64_t i3 = n / (ne2*ne1*ne0);
1120
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1121
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1122
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1123
-
1124
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1125
- float m_k;
1126
- if (i2 < n_heads_log2_floor) {
1127
- m_k = pow(m0, i2 + 1);
1128
- } else {
1129
- m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
1130
- }
1131
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1132
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1133
- dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
1134
- }
1135
- }
1136
-
1137
- static float rope_yarn_ramp(const float low, const float high, const int i0) {
1138
- const float y = (i0 / 2 - low) / max(0.001f, high - low);
1139
- return 1.0f - min(1.0f, max(0.0f, y));
1140
- }
1141
-
1142
- // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
1143
- // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
1144
- static void rope_yarn(
1145
- float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
1146
- thread float * cos_theta, thread float * sin_theta
1147
- ) {
1148
- // Get n-d rotational scaling corrected for extrapolation
1149
- float theta_interp = freq_scale * theta_extrap;
1150
- float theta = theta_interp;
1151
- if (ext_factor != 0.0f) {
1152
- float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
1153
- theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
1154
-
1155
- // Get n-d magnitude scaling corrected for interpolation
1156
- mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
1157
- }
1158
- *cos_theta = cos(theta) * mscale;
1159
- *sin_theta = sin(theta) * mscale;
1160
- }
1161
-
1162
- // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
1163
- // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
1164
- static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
1165
- return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
1166
- }
1167
-
1168
- static void rope_yarn_corr_dims(
1169
- int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
1170
- ) {
1171
- // start and end correction dims
1172
- dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
1173
- dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
1174
- }
1175
-
1176
- typedef void (rope_t)(
1177
- device const void * src0,
1178
- device const int32_t * src1,
1179
- device float * dst,
1180
- constant int64_t & ne00,
1181
- constant int64_t & ne01,
1182
- constant int64_t & ne02,
1183
- constant int64_t & ne03,
1184
- constant uint64_t & nb00,
1185
- constant uint64_t & nb01,
1186
- constant uint64_t & nb02,
1187
- constant uint64_t & nb03,
1188
- constant int64_t & ne0,
1189
- constant int64_t & ne1,
1190
- constant int64_t & ne2,
1191
- constant int64_t & ne3,
1192
- constant uint64_t & nb0,
1193
- constant uint64_t & nb1,
1194
- constant uint64_t & nb2,
1195
- constant uint64_t & nb3,
1196
- constant int & n_past,
1197
- constant int & n_dims,
1198
- constant int & mode,
1199
- constant int & n_orig_ctx,
1200
- constant float & freq_base,
1201
- constant float & freq_scale,
1202
- constant float & ext_factor,
1203
- constant float & attn_factor,
1204
- constant float & beta_fast,
1205
- constant float & beta_slow,
1206
- uint tiitg[[thread_index_in_threadgroup]],
1207
- uint3 tptg[[threads_per_threadgroup]],
1208
- uint3 tgpig[[threadgroup_position_in_grid]]);
1209
-
1210
- template<typename T>
1211
- kernel void kernel_rope(
1212
- device const void * src0,
1213
- device const int32_t * src1,
1214
- device float * dst,
1215
- constant int64_t & ne00,
1216
- constant int64_t & ne01,
1217
- constant int64_t & ne02,
1218
- constant int64_t & ne03,
1219
- constant uint64_t & nb00,
1220
- constant uint64_t & nb01,
1221
- constant uint64_t & nb02,
1222
- constant uint64_t & nb03,
1223
- constant int64_t & ne0,
1224
- constant int64_t & ne1,
1225
- constant int64_t & ne2,
1226
- constant int64_t & ne3,
1227
- constant uint64_t & nb0,
1228
- constant uint64_t & nb1,
1229
- constant uint64_t & nb2,
1230
- constant uint64_t & nb3,
1231
- constant int & n_past,
1232
- constant int & n_dims,
1233
- constant int & mode,
1234
- constant int & n_orig_ctx,
1235
- constant float & freq_base,
1236
- constant float & freq_scale,
1237
- constant float & ext_factor,
1238
- constant float & attn_factor,
1239
- constant float & beta_fast,
1240
- constant float & beta_slow,
1241
- uint tiitg[[thread_index_in_threadgroup]],
1242
- uint3 tptg[[threads_per_threadgroup]],
1243
- uint3 tgpig[[threadgroup_position_in_grid]]) {
1244
- const int64_t i3 = tgpig[2];
1245
- const int64_t i2 = tgpig[1];
1246
- const int64_t i1 = tgpig[0];
1247
-
1248
- const bool is_neox = mode & 2;
1249
-
1250
- float corr_dims[2];
1251
- rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
1252
-
1253
- device const int32_t * pos = src1;
1254
-
1255
- const int64_t p = pos[i2];
1256
-
1257
- const float theta_0 = (float)p;
1258
- const float inv_ndims = -1.f/n_dims;
1259
-
1260
- if (!is_neox) {
1261
- for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
1262
-
1263
- const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
1264
- float cos_theta, sin_theta;
1265
- rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1266
-
1267
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1268
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1269
-
1270
- const T x0 = src[0];
1271
- const T x1 = src[1];
1272
-
1273
- dst_data[0] = x0*cos_theta - x1*sin_theta;
1274
- dst_data[1] = x0*sin_theta + x1*cos_theta;
1275
- }
1276
- } else {
1277
- for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
1278
- for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
1279
-
1280
- // simplified from `(ib * n_dims + ic) * inv_ndims`
1281
- const float cur_rot = inv_ndims*ic - ib;
1282
-
1283
- const float theta = theta_0 * pow(freq_base, cur_rot);
1284
- float cos_theta, sin_theta;
1285
- rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
1286
-
1287
- const int64_t i0 = ib*n_dims + ic/2;
1288
-
1289
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1290
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1291
-
1292
- const float x0 = src[0];
1293
- const float x1 = src[n_dims/2];
1294
-
1295
- dst_data[0] = x0*cos_theta - x1*sin_theta;
1296
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
1297
- }
1298
- }
1299
- }
1300
- }
1301
-
1302
- template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
1303
- template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
1304
-
1305
- kernel void kernel_im2col_f16(
1306
- device const float * x,
1307
- device half * dst,
1308
- constant int32_t & ofs0,
1309
- constant int32_t & ofs1,
1310
- constant int32_t & IW,
1311
- constant int32_t & IH,
1312
- constant int32_t & CHW,
1313
- constant int32_t & s0,
1314
- constant int32_t & s1,
1315
- constant int32_t & p0,
1316
- constant int32_t & p1,
1317
- constant int32_t & d0,
1318
- constant int32_t & d1,
1319
- uint3 tgpig[[threadgroup_position_in_grid]],
1320
- uint3 tgpg[[threadgroups_per_grid]],
1321
- uint3 tpitg[[thread_position_in_threadgroup]],
1322
- uint3 ntg[[threads_per_threadgroup]]) {
1323
- const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
1324
- const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
1325
-
1326
- const int32_t offset_dst =
1327
- (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
1328
- (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
1329
-
1330
- if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
1331
- dst[offset_dst] = 0.0f;
1332
- } else {
1333
- const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
1334
- dst[offset_dst] = x[offset_src + iih * IW + iiw];
1335
- }
1336
- }
1337
-
1338
- kernel void kernel_cpy_f16_f16(
1339
- device const half * src0,
1340
- device half * dst,
1341
- constant int64_t & ne00,
1342
- constant int64_t & ne01,
1343
- constant int64_t & ne02,
1344
- constant int64_t & ne03,
1345
- constant uint64_t & nb00,
1346
- constant uint64_t & nb01,
1347
- constant uint64_t & nb02,
1348
- constant uint64_t & nb03,
1349
- constant int64_t & ne0,
1350
- constant int64_t & ne1,
1351
- constant int64_t & ne2,
1352
- constant int64_t & ne3,
1353
- constant uint64_t & nb0,
1354
- constant uint64_t & nb1,
1355
- constant uint64_t & nb2,
1356
- constant uint64_t & nb3,
1357
- uint3 tgpig[[threadgroup_position_in_grid]],
1358
- uint3 tpitg[[thread_position_in_threadgroup]],
1359
- uint3 ntg[[threads_per_threadgroup]]) {
1360
- const int64_t i03 = tgpig[2];
1361
- const int64_t i02 = tgpig[1];
1362
- const int64_t i01 = tgpig[0];
1363
-
1364
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1365
-
1366
- const int64_t i3 = n / (ne2*ne1*ne0);
1367
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1368
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1369
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1370
-
1371
- device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1372
-
1373
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1374
- device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1375
- dst_data[i00] = src[0];
1376
- }
1377
- }
1378
-
1379
- kernel void kernel_cpy_f32_f16(
1380
- device const float * src0,
1381
- device half * dst,
1382
- constant int64_t & ne00,
1383
- constant int64_t & ne01,
1384
- constant int64_t & ne02,
1385
- constant int64_t & ne03,
1386
- constant uint64_t & nb00,
1387
- constant uint64_t & nb01,
1388
- constant uint64_t & nb02,
1389
- constant uint64_t & nb03,
1390
- constant int64_t & ne0,
1391
- constant int64_t & ne1,
1392
- constant int64_t & ne2,
1393
- constant int64_t & ne3,
1394
- constant uint64_t & nb0,
1395
- constant uint64_t & nb1,
1396
- constant uint64_t & nb2,
1397
- constant uint64_t & nb3,
1398
- uint3 tgpig[[threadgroup_position_in_grid]],
1399
- uint3 tpitg[[thread_position_in_threadgroup]],
1400
- uint3 ntg[[threads_per_threadgroup]]) {
1401
- const int64_t i03 = tgpig[2];
1402
- const int64_t i02 = tgpig[1];
1403
- const int64_t i01 = tgpig[0];
1404
-
1405
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1406
-
1407
- const int64_t i3 = n / (ne2*ne1*ne0);
1408
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1409
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1410
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1411
-
1412
- device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1413
-
1414
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1415
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1416
-
1417
- dst_data[i00] = src[0];
1418
- }
1419
- }
1420
-
1421
- kernel void kernel_cpy_f32_f32(
1422
- device const float * src0,
1423
- device float * dst,
1424
- constant int64_t & ne00,
1425
- constant int64_t & ne01,
1426
- constant int64_t & ne02,
1427
- constant int64_t & ne03,
1428
- constant uint64_t & nb00,
1429
- constant uint64_t & nb01,
1430
- constant uint64_t & nb02,
1431
- constant uint64_t & nb03,
1432
- constant int64_t & ne0,
1433
- constant int64_t & ne1,
1434
- constant int64_t & ne2,
1435
- constant int64_t & ne3,
1436
- constant uint64_t & nb0,
1437
- constant uint64_t & nb1,
1438
- constant uint64_t & nb2,
1439
- constant uint64_t & nb3,
1440
- uint3 tgpig[[threadgroup_position_in_grid]],
1441
- uint3 tpitg[[thread_position_in_threadgroup]],
1442
- uint3 ntg[[threads_per_threadgroup]]) {
1443
- const int64_t i03 = tgpig[2];
1444
- const int64_t i02 = tgpig[1];
1445
- const int64_t i01 = tgpig[0];
1446
-
1447
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1448
-
1449
- const int64_t i3 = n / (ne2*ne1*ne0);
1450
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1451
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1452
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1453
-
1454
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1455
-
1456
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1457
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1458
-
1459
- dst_data[i00] = src[0];
1460
- }
1461
- }
1462
-
1463
- kernel void kernel_concat(
1464
- device const char * src0,
1465
- device const char * src1,
1466
- device char * dst,
1467
- constant int64_t & ne00,
1468
- constant int64_t & ne01,
1469
- constant int64_t & ne02,
1470
- constant int64_t & ne03,
1471
- constant uint64_t & nb00,
1472
- constant uint64_t & nb01,
1473
- constant uint64_t & nb02,
1474
- constant uint64_t & nb03,
1475
- constant int64_t & ne10,
1476
- constant int64_t & ne11,
1477
- constant int64_t & ne12,
1478
- constant int64_t & ne13,
1479
- constant uint64_t & nb10,
1480
- constant uint64_t & nb11,
1481
- constant uint64_t & nb12,
1482
- constant uint64_t & nb13,
1483
- constant int64_t & ne0,
1484
- constant int64_t & ne1,
1485
- constant int64_t & ne2,
1486
- constant int64_t & ne3,
1487
- constant uint64_t & nb0,
1488
- constant uint64_t & nb1,
1489
- constant uint64_t & nb2,
1490
- constant uint64_t & nb3,
1491
- uint3 tgpig[[threadgroup_position_in_grid]],
1492
- uint3 tpitg[[thread_position_in_threadgroup]],
1493
- uint3 ntg[[threads_per_threadgroup]]) {
1494
-
1495
- const int64_t i03 = tgpig.z;
1496
- const int64_t i02 = tgpig.y;
1497
- const int64_t i01 = tgpig.x;
1498
-
1499
- const int64_t i13 = i03 % ne13;
1500
- const int64_t i12 = i02 % ne12;
1501
- const int64_t i11 = i01 % ne11;
1502
-
1503
- device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
1504
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
1505
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
1506
-
1507
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1508
- if (i02 < ne02) {
1509
- ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
1510
- src0_ptr += ntg.x*nb00;
1511
- } else {
1512
- ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
1513
- src1_ptr += ntg.x*nb10;
1514
- }
1515
- dst_ptr += ntg.x*nb0;
1516
- }
1517
- }
1518
-
1519
- //============================================ k-quants ======================================================
1520
-
1521
- #ifndef QK_K
1522
- #define QK_K 256
1523
- #else
1524
- static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
1525
- #endif
1526
-
1527
- #if QK_K == 256
1528
- #define K_SCALE_SIZE 12
1529
- #else
1530
- #define K_SCALE_SIZE 4
1531
- #endif
1532
-
1533
- typedef struct {
1534
- uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
1535
- uint8_t qs[QK_K/4]; // quants
1536
- half d; // super-block scale for quantized scales
1537
- half dmin; // super-block scale for quantized mins
1538
- } block_q2_K;
1539
- // 84 bytes / block
1540
-
1541
- typedef struct {
1542
- uint8_t hmask[QK_K/8]; // quants - high bit
1543
- uint8_t qs[QK_K/4]; // quants - low 2 bits
1544
- #if QK_K == 64
1545
- uint8_t scales[2];
1546
- #else
1547
- uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
1548
- #endif
1549
- half d; // super-block scale
1550
- } block_q3_K;
1551
-
1552
- #if QK_K == 64
1553
- typedef struct {
1554
- half d[2]; // super-block scales/mins
1555
- uint8_t scales[2];
1556
- uint8_t qs[QK_K/2]; // 4-bit quants
1557
- } block_q4_K;
1558
- #else
1559
- typedef struct {
1560
- half d; // super-block scale for quantized scales
1561
- half dmin; // super-block scale for quantized mins
1562
- uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
1563
- uint8_t qs[QK_K/2]; // 4--bit quants
1564
- } block_q4_K;
1565
- #endif
1566
-
1567
- #if QK_K == 64
1568
- typedef struct {
1569
- half d; // super-block scales/mins
1570
- int8_t scales[QK_K/16]; // 8-bit block scales
1571
- uint8_t qh[QK_K/8]; // quants, high bit
1572
- uint8_t qs[QK_K/2]; // quants, low 4 bits
1573
- } block_q5_K;
1574
- #else
1575
- typedef struct {
1576
- half d; // super-block scale for quantized scales
1577
- half dmin; // super-block scale for quantized mins
1578
- uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
1579
- uint8_t qh[QK_K/8]; // quants, high bit
1580
- uint8_t qs[QK_K/2]; // quants, low 4 bits
1581
- } block_q5_K;
1582
- // 176 bytes / block
1583
- #endif
1584
-
1585
- typedef struct {
1586
- uint8_t ql[QK_K/2]; // quants, lower 4 bits
1587
- uint8_t qh[QK_K/4]; // quants, upper 2 bits
1588
- int8_t scales[QK_K/16]; // scales, quantized with 8 bits
1589
- half d; // super-block scale
1590
- } block_q6_K;
1591
- // 210 bytes / block
1592
-
1593
- static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
1594
- uchar4 r;
1595
- if (j < 4) {
1596
- r[0] = q[j+0] & 63;
1597
- r[2] = q[j+1] & 63;
1598
- r[1] = q[j+4] & 63;
1599
- r[3] = q[j+5] & 63;
1600
- } else {
1601
- r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
1602
- r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
1603
- r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
1604
- r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
1605
- }
1606
- return r;
1607
- }
1608
-
1609
- //====================================== dot products =========================
1610
-
1611
- kernel void kernel_mul_mv_q2_K_f32(
1612
- device const void * src0,
1613
- device const float * src1,
1614
- device float * dst,
1615
- constant int64_t & ne00,
1616
- constant int64_t & ne01[[buffer(4)]],
1617
- constant int64_t & ne02[[buffer(5)]],
1618
- constant int64_t & ne10[[buffer(9)]],
1619
- constant int64_t & ne12[[buffer(11)]],
1620
- constant int64_t & ne0[[buffer(15)]],
1621
- constant int64_t & ne1[[buffer(16)]],
1622
- constant uint & gqa[[buffer(17)]],
1623
- uint3 tgpig[[threadgroup_position_in_grid]],
1624
- uint tiisg[[thread_index_in_simdgroup]],
1625
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
1626
-
1627
- const int nb = ne00/QK_K;
1628
- const int r0 = tgpig.x;
1629
- const int r1 = tgpig.y;
1630
- const int r2 = tgpig.z;
1631
-
1632
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1633
- const int ib_row = first_row * nb;
1634
- const uint offset0 = r2/gqa*(nb*ne0);
1635
- device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
1636
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1637
- float yl[32];
1638
- float sumf[N_DST]={0.f}, all_sum;
1639
-
1640
- const int step = sizeof(block_q2_K) * nb;
1641
-
1642
- #if QK_K == 256
1643
- const int ix = tiisg/8; // 0...3
1644
- const int it = tiisg%8; // 0...7
1645
- const int im = it/4; // 0 or 1
1646
- const int ir = it%4; // 0...3
1647
- const int is = (8*ir)/16;// 0 or 1
1648
-
1649
- device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
1650
-
1651
- for (int ib = ix; ib < nb; ib += 4) {
1652
-
1653
- float4 sumy = {0.f, 0.f, 0.f, 0.f};
1654
- for (int i = 0; i < 8; ++i) {
1655
- yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
1656
- yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
1657
- yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
1658
- yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
1659
- }
1660
-
1661
- device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
1662
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
1663
- device const half * dh = &x[ib].d;
1664
-
1665
- for (int row = 0; row < N_DST; row++) {
1666
-
1667
- float4 acc1 = {0.f, 0.f, 0.f, 0.f};
1668
- float4 acc2 = {0.f, 0.f, 0.f, 0.f};
1669
- for (int i = 0; i < 8; i += 2) {
1670
- acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
1671
- acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
1672
- acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
1673
- acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
1674
- acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
1675
- acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
1676
- acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
1677
- acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
1678
- }
1679
- float dall = dh[0];
1680
- float dmin = dh[1] * 1.f/16.f;
1681
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
1682
- (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
1683
- (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
1684
- (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
1685
- dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
1686
-
1687
- qs += step/2;
1688
- sc += step;
1689
- dh += step/2;
1690
- }
1691
-
1692
- y4 += 4 * QK_K;
1693
- }
1694
- #else
1695
- const int ix = tiisg/2; // 0...15
1696
- const int it = tiisg%2; // 0...1
1697
-
1698
- device const float * y4 = y + ix * QK_K + 8 * it;
1699
-
1700
- for (int ib = ix; ib < nb; ib += 16) {
1701
-
1702
- float4 sumy = {0.f, 0.f, 0.f, 0.f};
1703
- for (int i = 0; i < 8; ++i) {
1704
- yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
1705
- yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
1706
- yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
1707
- yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
1708
- }
1709
-
1710
- device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
1711
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
1712
- device const half * dh = &x[ib].d;
1713
-
1714
- for (int row = 0; row < N_DST; row++) {
1715
-
1716
- float4 acc1 = {0.f, 0.f, 0.f, 0.f};
1717
- float4 acc2 = {0.f, 0.f, 0.f, 0.f};
1718
- for (int i = 0; i < 8; i += 2) {
1719
- acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
1720
- acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
1721
- acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
1722
- acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
1723
- acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
1724
- acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
1725
- acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
1726
- acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
1727
- }
1728
-
1729
- float dall = dh[0];
1730
- float dmin = dh[1];
1731
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
1732
- (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
1733
- (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
1734
- (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
1735
- dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
1736
-
1737
- qs += step/2;
1738
- sc += step;
1739
- dh += step/2;
1740
- }
1741
-
1742
- y4 += 16 * QK_K;
1743
- }
1744
- #endif
1745
-
1746
- for (int row = 0; row < N_DST; ++row) {
1747
- all_sum = simd_sum(sumf[row]);
1748
- if (tiisg == 0) {
1749
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
1750
- }
1751
- }
1752
- }
1753
-
1754
- #if QK_K == 256
1755
- kernel void kernel_mul_mv_q3_K_f32(
1756
- device const void * src0,
1757
- device const float * src1,
1758
- device float * dst,
1759
- constant int64_t & ne00,
1760
- constant int64_t & ne01[[buffer(4)]],
1761
- constant int64_t & ne02[[buffer(5)]],
1762
- constant int64_t & ne10[[buffer(9)]],
1763
- constant int64_t & ne12[[buffer(11)]],
1764
- constant int64_t & ne0[[buffer(15)]],
1765
- constant int64_t & ne1[[buffer(16)]],
1766
- constant uint & gqa[[buffer(17)]],
1767
- uint3 tgpig[[threadgroup_position_in_grid]],
1768
- uint tiisg[[thread_index_in_simdgroup]],
1769
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
1770
-
1771
- const int nb = ne00/QK_K;
1772
-
1773
- const int64_t r0 = tgpig.x;
1774
- const int64_t r1 = tgpig.y;
1775
- const int64_t r2 = tgpig.z;
1776
-
1777
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1778
- const uint offset0 = r2/gqa*(nb*ne0);
1779
- device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
1780
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1781
-
1782
- float yl[32];
1783
-
1784
- //const uint16_t kmask1 = 0x3030;
1785
- //const uint16_t kmask2 = 0x0f0f;
1786
-
1787
- const int tid = tiisg/4;
1788
- const int ix = tiisg%4;
1789
- const int ip = tid/4; // 0 or 1
1790
- const int il = 2*((tid%4)/2); // 0 or 2
1791
- const int ir = tid%2;
1792
- const int n = 8;
1793
- const int l0 = n*ir;
1794
-
1795
- // One would think that the Metal compiler would figure out that ip and il can only have
1796
- // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
1797
- // with these two tales.
1798
- //
1799
- // Possible masks for the high bit
1800
- const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0
1801
- {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2
1802
- {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0
1803
- {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
1804
-
1805
- // Possible masks for the low 2 bits
1806
- const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
1807
-
1808
- const ushort4 hm = mm[2*ip + il/2];
1809
-
1810
- const int shift = 2*il;
1811
- const float v1 = il == 0 ? 4.f : 64.f;
1812
- const float v2 = 4.f * v1;
1813
-
1814
- const uint16_t s_shift1 = 4*ip;
1815
- const uint16_t s_shift2 = s_shift1 + il;
1816
-
1817
- const int q_offset = 32*ip + l0;
1818
- const int y_offset = 128*ip + 32*il + l0;
1819
-
1820
- const int step = sizeof(block_q3_K) * nb / 2;
1821
-
1822
- device const float * y1 = yy + ix*QK_K + y_offset;
1823
-
1824
- uint32_t scales32, aux32;
1825
- thread uint16_t * scales16 = (thread uint16_t *)&scales32;
1826
- thread const int8_t * scales = (thread const int8_t *)&scales32;
1827
-
1828
- float sumf1[2] = {0.f};
1829
- float sumf2[2] = {0.f};
1830
- for (int i = ix; i < nb; i += 4) {
1831
-
1832
- for (int l = 0; l < 8; ++l) {
1833
- yl[l+ 0] = y1[l+ 0];
1834
- yl[l+ 8] = y1[l+16];
1835
- yl[l+16] = y1[l+32];
1836
- yl[l+24] = y1[l+48];
1837
- }
1838
-
1839
- device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
1840
- device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
1841
- device const uint16_t * a = (device const uint16_t *)(x[i].scales);
1842
- device const half * dh = &x[i].d;
1843
-
1844
- for (int row = 0; row < 2; ++row) {
1845
-
1846
- const float d_all = (float)dh[0];
1847
-
1848
- scales16[0] = a[4];
1849
- scales16[1] = a[5];
1850
- aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
1851
- scales16[0] = a[il+0];
1852
- scales16[1] = a[il+1];
1853
- scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
1854
-
1855
- float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
1856
- for (int l = 0; l < n; l += 2) {
1857
- const int32_t qs = q[l/2];
1858
- s1 += yl[l+0] * (qs & qm[il/2][0]);
1859
- s2 += yl[l+1] * (qs & qm[il/2][1]);
1860
- s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
1861
- s4 += yl[l+16] * (qs & qm[il/2][2]);
1862
- s5 += yl[l+17] * (qs & qm[il/2][3]);
1863
- s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
1864
- }
1865
- float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
1866
- float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
1867
- sumf1[row] += d1 * (scales[0] - 32);
1868
- sumf2[row] += d2 * (scales[2] - 32);
1869
-
1870
- s1 = s2 = s3 = s4 = s5 = s6 = 0;
1871
- for (int l = 0; l < n; l += 2) {
1872
- const int32_t qs = q[l/2+8];
1873
- s1 += yl[l+8] * (qs & qm[il/2][0]);
1874
- s2 += yl[l+9] * (qs & qm[il/2][1]);
1875
- s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
1876
- s4 += yl[l+24] * (qs & qm[il/2][2]);
1877
- s5 += yl[l+25] * (qs & qm[il/2][3]);
1878
- s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
1879
- }
1880
- d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
1881
- d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
1882
- sumf1[row] += d1 * (scales[1] - 32);
1883
- sumf2[row] += d2 * (scales[3] - 32);
1884
-
1885
- q += step;
1886
- h += step;
1887
- a += step;
1888
- dh += step;
1889
-
1890
- }
1891
-
1892
- y1 += 4 * QK_K;
1893
-
1894
- }
1895
-
1896
- for (int row = 0; row < 2; ++row) {
1897
- const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
1898
- sumf1[row] = simd_sum(sumf);
1899
- }
1900
- if (tiisg == 0) {
1901
- for (int row = 0; row < 2; ++row) {
1902
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
1903
- }
1904
- }
1905
- }
1906
- #else
1907
- kernel void kernel_mul_mv_q3_K_f32(
1908
- device const void * src0,
1909
- device const float * src1,
1910
- device float * dst,
1911
- constant int64_t & ne00,
1912
- constant int64_t & ne01[[buffer(4)]],
1913
- constant int64_t & ne02[[buffer(5)]],
1914
- constant int64_t & ne10[[buffer(9)]],
1915
- constant int64_t & ne12[[buffer(11)]],
1916
- constant int64_t & ne0[[buffer(15)]],
1917
- constant int64_t & ne1[[buffer(16)]],
1918
- constant uint & gqa[[buffer(17)]],
1919
- uint3 tgpig[[threadgroup_position_in_grid]],
1920
- uint tiisg[[thread_index_in_simdgroup]],
1921
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
1922
-
1923
- const int nb = ne00/QK_K;
1924
-
1925
- const int64_t r0 = tgpig.x;
1926
- const int64_t r1 = tgpig.y;
1927
- const int64_t r2 = tgpig.z;
1928
-
1929
- const int row = 2 * r0 + sgitg;
1930
- const uint offset0 = r2/gqa*(nb*ne0);
1931
- device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
1932
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1933
- const int ix = tiisg/4;
1934
- const int il = 4 * (tiisg%4);// 0, 4, 8, 12
1935
- const int im = il/8; // 0, 0, 1, 1
1936
- const int in = il%8; // 0, 4, 0, 4
1937
-
1938
- float2 sum = {0.f, 0.f};
1939
-
1940
- for (int i = ix; i < nb; i += 8) {
1941
-
1942
- const float d_all = (float)(x[i].d);
1943
-
1944
- device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
1945
- device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
1946
- device const uint16_t * s = (device const uint16_t *)(x[i].scales);
1947
- device const float * y = yy + i * QK_K + il;
1948
-
1949
- const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
1950
- const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
1951
- const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
1952
- const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
1953
-
1954
- for (int l = 0; l < 4; l += 2) {
1955
- const uint16_t hm = h[l/2] >> im;
1956
- sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
1957
- + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
1958
- + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
1959
- + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
1960
- sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
1961
- + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
1962
- + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
1963
- + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
1964
- }
1965
-
1966
- }
1967
- const float sumf = sum[0] + sum[1] * 1.f/256.f;
1968
-
1969
- const float tot = simd_sum(sumf);
1970
- if (tiisg == 0) {
1971
- dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
1972
- }
1973
-
1974
- }
1975
- #endif
1976
-
1977
- #if QK_K == 256
1978
- kernel void kernel_mul_mv_q4_K_f32(
1979
- device const void * src0,
1980
- device const float * src1,
1981
- device float * dst,
1982
- constant int64_t & ne00,
1983
- constant int64_t & ne01 [[buffer(4)]],
1984
- constant int64_t & ne02 [[buffer(5)]],
1985
- constant int64_t & ne10 [[buffer(9)]],
1986
- constant int64_t & ne12 [[buffer(11)]],
1987
- constant int64_t & ne0 [[buffer(15)]],
1988
- constant int64_t & ne1 [[buffer(16)]],
1989
- constant uint & gqa [[buffer(17)]],
1990
- uint3 tgpig[[threadgroup_position_in_grid]],
1991
- uint tiisg[[thread_index_in_simdgroup]],
1992
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
1993
-
1994
- const uint16_t kmask1 = 0x3f3f;
1995
- const uint16_t kmask2 = 0x0f0f;
1996
- const uint16_t kmask3 = 0xc0c0;
1997
-
1998
- const int ix = tiisg/8; // 0...3
1999
- const int it = tiisg%8; // 0...7
2000
- const int im = it/4; // 0 or 1
2001
- const int ir = it%4; // 0...3
2002
-
2003
- const int nb = ne00/QK_K;
2004
- const int r0 = tgpig.x;
2005
- const int r1 = tgpig.y;
2006
- const int r2 = tgpig.z;
2007
- //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
2008
- const int first_row = r0 * N_DST;
2009
- const int ib_row = first_row * nb;
2010
- const uint offset0 = r2/gqa*(nb*ne0);
2011
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
2012
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2013
- float yl[16];
2014
- float yh[16];
2015
- float sumf[N_DST]={0.f}, all_sum;
2016
-
2017
- const int step = sizeof(block_q4_K) * nb / 2;
2018
-
2019
- device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
2020
-
2021
- uint16_t sc16[4];
2022
- thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
2023
-
2024
- for (int ib = ix; ib < nb; ib += 4) {
2025
-
2026
- float4 sumy = {0.f, 0.f, 0.f, 0.f};
2027
- for (int i = 0; i < 8; ++i) {
2028
- yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
2029
- yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
2030
- yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
2031
- yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
2032
- }
2033
-
2034
- device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
2035
- device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
2036
- device const half * dh = &x[ib].d;
2037
-
2038
- for (int row = 0; row < N_DST; row++) {
2039
-
2040
- sc16[0] = sc[0] & kmask1;
2041
- sc16[1] = sc[2] & kmask1;
2042
- sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
2043
- sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
2044
-
2045
- device const uint16_t * q2 = q1 + 32;
2046
-
2047
- float4 acc1 = {0.f, 0.f, 0.f, 0.f};
2048
- float4 acc2 = {0.f, 0.f, 0.f, 0.f};
2049
- for (int i = 0; i < 8; i += 2) {
2050
- acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
2051
- acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
2052
- acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
2053
- acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
2054
- acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
2055
- acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
2056
- acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
2057
- acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
2058
- }
2059
-
2060
- float dall = dh[0];
2061
- float dmin = dh[1];
2062
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
2063
- (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
2064
- (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
2065
- (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
2066
- dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
2067
-
2068
- q1 += step;
2069
- sc += step;
2070
- dh += step;
2071
- }
2072
-
2073
- y4 += 4 * QK_K;
2074
- }
2075
-
2076
- for (int row = 0; row < N_DST; ++row) {
2077
- all_sum = simd_sum(sumf[row]);
2078
- if (tiisg == 0) {
2079
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
2080
- }
2081
- }
2082
- }
2083
- #else
2084
- kernel void kernel_mul_mv_q4_K_f32(
2085
- device const void * src0,
2086
- device const float * src1,
2087
- device float * dst,
2088
- constant int64_t & ne00,
2089
- constant int64_t & ne01[[buffer(4)]],
2090
- constant int64_t & ne02[[buffer(5)]],
2091
- constant int64_t & ne10[[buffer(9)]],
2092
- constant int64_t & ne12[[buffer(11)]],
2093
- constant int64_t & ne0[[buffer(15)]],
2094
- constant int64_t & ne1[[buffer(16)]],
2095
- constant uint & gqa[[buffer(17)]],
2096
- uint3 tgpig[[threadgroup_position_in_grid]],
2097
- uint tiisg[[thread_index_in_simdgroup]],
2098
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2099
-
2100
- const int ix = tiisg/4; // 0...7
2101
- const int it = tiisg%4; // 0...3
2102
-
2103
- const int nb = ne00/QK_K;
2104
- const int r0 = tgpig.x;
2105
- const int r1 = tgpig.y;
2106
- const int r2 = tgpig.z;
2107
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
2108
- const int ib_row = first_row * nb;
2109
- const uint offset0 = r2/gqa*(nb*ne0);
2110
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
2111
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2112
- float yl[8];
2113
- float yh[8];
2114
- float sumf[N_DST]={0.f}, all_sum;
2115
-
2116
- const int step = sizeof(block_q4_K) * nb / 2;
2117
-
2118
- device const float * y4 = y + ix * QK_K + 8 * it;
2119
-
2120
- uint16_t sc16[4];
2121
-
2122
- for (int ib = ix; ib < nb; ib += 8) {
2123
-
2124
- float2 sumy = {0.f, 0.f};
2125
- for (int i = 0; i < 8; ++i) {
2126
- yl[i] = y4[i+ 0]; sumy[0] += yl[i];
2127
- yh[i] = y4[i+32]; sumy[1] += yh[i];
2128
- }
2129
-
2130
- device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
2131
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
2132
- device const half * dh = x[ib].d;
2133
-
2134
- for (int row = 0; row < N_DST; row++) {
2135
-
2136
- sc16[0] = sc[0] & 0x000f;
2137
- sc16[1] = sc[0] & 0x0f00;
2138
- sc16[2] = sc[0] & 0x00f0;
2139
- sc16[3] = sc[0] & 0xf000;
2140
-
2141
- float2 acc1 = {0.f, 0.f};
2142
- float2 acc2 = {0.f, 0.f};
2143
- for (int i = 0; i < 8; i += 2) {
2144
- acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
2145
- acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
2146
- acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
2147
- acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
2148
- }
2149
-
2150
- float dall = dh[0];
2151
- float dmin = dh[1];
2152
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
2153
- (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
2154
- dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
2155
-
2156
- qs += step;
2157
- sc += step;
2158
- dh += step;
2159
- }
2160
-
2161
- y4 += 8 * QK_K;
2162
- }
2163
-
2164
- for (int row = 0; row < N_DST; ++row) {
2165
- all_sum = simd_sum(sumf[row]);
2166
- if (tiisg == 0) {
2167
- dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
2168
- }
2169
- }
2170
- }
2171
- #endif
2172
-
2173
- kernel void kernel_mul_mv_q5_K_f32(
2174
- device const void * src0,
2175
- device const float * src1,
2176
- device float * dst,
2177
- constant int64_t & ne00,
2178
- constant int64_t & ne01[[buffer(4)]],
2179
- constant int64_t & ne02[[buffer(5)]],
2180
- constant int64_t & ne10[[buffer(9)]],
2181
- constant int64_t & ne12[[buffer(11)]],
2182
- constant int64_t & ne0[[buffer(15)]],
2183
- constant int64_t & ne1[[buffer(16)]],
2184
- constant uint & gqa[[buffer(17)]],
2185
- uint3 tgpig[[threadgroup_position_in_grid]],
2186
- uint tiisg[[thread_index_in_simdgroup]],
2187
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2188
-
2189
- const int nb = ne00/QK_K;
2190
-
2191
- const int64_t r0 = tgpig.x;
2192
- const int64_t r1 = tgpig.y;
2193
- const int r2 = tgpig.z;
2194
-
2195
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
2196
- const uint offset0 = r2/gqa*(nb*ne0);
2197
- device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
2198
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2199
-
2200
- float sumf[2]={0.f};
2201
-
2202
- const int step = sizeof(block_q5_K) * nb;
2203
-
2204
- #if QK_K == 256
2205
- #
2206
- float yl[16], yh[16];
2207
-
2208
- const uint16_t kmask1 = 0x3f3f;
2209
- const uint16_t kmask2 = 0x0f0f;
2210
- const uint16_t kmask3 = 0xc0c0;
2211
-
2212
- const int tid = tiisg/4;
2213
- const int ix = tiisg%4;
2214
- const int im = tid/4;
2215
- const int ir = tid%4;
2216
- const int n = 8;
2217
-
2218
- const int l0 = n*ir;
2219
- const int q_offset = 32*im + l0;
2220
- const int y_offset = 64*im + l0;
2221
-
2222
- const uint8_t hm1 = 1u << (2*im);
2223
- const uint8_t hm2 = hm1 << 1;
2224
- const uint8_t hm3 = hm1 << 4;
2225
- const uint8_t hm4 = hm2 << 4;
2226
-
2227
- uint16_t sc16[4];
2228
- thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
2229
-
2230
- device const float * y1 = yy + ix*QK_K + y_offset;
2231
-
2232
- for (int i = ix; i < nb; i += 4) {
2233
-
2234
- device const uint8_t * q1 = x[i].qs + q_offset;
2235
- device const uint8_t * qh = x[i].qh + l0;
2236
- device const half * dh = &x[i].d;
2237
- device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
2238
-
2239
- device const float * y2 = y1 + 128;
2240
- float4 sumy = {0.f, 0.f, 0.f, 0.f};
2241
- for (int l = 0; l < 8; ++l) {
2242
- yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
2243
- yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
2244
- yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
2245
- yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
2246
- }
2247
-
2248
- for (int row = 0; row < 2; ++row) {
2249
-
2250
- device const uint8_t * q2 = q1 + 64;
2251
-
2252
- sc16[0] = a[0] & kmask1;
2253
- sc16[1] = a[2] & kmask1;
2254
- sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
2255
- sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
2256
-
2257
- float4 acc1 = {0.f};
2258
- float4 acc2 = {0.f};
2259
- for (int l = 0; l < n; ++l) {
2260
- uint8_t h = qh[l];
2261
- acc1[0] += yl[l+0] * (q1[l] & 0x0F);
2262
- acc1[1] += yl[l+8] * (q1[l] & 0xF0);
2263
- acc1[2] += yh[l+0] * (q2[l] & 0x0F);
2264
- acc1[3] += yh[l+8] * (q2[l] & 0xF0);
2265
- acc2[0] += h & hm1 ? yl[l+0] : 0.f;
2266
- acc2[1] += h & hm2 ? yl[l+8] : 0.f;
2267
- acc2[2] += h & hm3 ? yh[l+0] : 0.f;
2268
- acc2[3] += h & hm4 ? yh[l+8] : 0.f;
2269
- }
2270
- const float dall = dh[0];
2271
- const float dmin = dh[1];
2272
- sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
2273
- sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
2274
- sc8[4] * (acc1[2] + 16.f*acc2[2]) +
2275
- sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
2276
- dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
2277
-
2278
- q1 += step;
2279
- qh += step;
2280
- dh += step/2;
2281
- a += step/2;
2282
-
2283
- }
2284
-
2285
- y1 += 4 * QK_K;
2286
-
2287
- }
2288
- #else
2289
- float yl[8], yh[8];
2290
-
2291
- const int il = 4 * (tiisg/8); // 0, 4, 8, 12
2292
- const int ix = tiisg%8;
2293
- const int im = il/8; // 0, 0, 1, 1
2294
- const int in = il%8; // 0, 4, 0, 4
2295
-
2296
- device const float * y = yy + ix*QK_K + il;
2297
-
2298
- for (int i = ix; i < nb; i += 8) {
2299
-
2300
- for (int l = 0; l < 4; ++l) {
2301
- yl[l+0] = y[l+ 0];
2302
- yl[l+4] = y[l+16];
2303
- yh[l+0] = y[l+32];
2304
- yh[l+4] = y[l+48];
2305
- }
2306
-
2307
- device const half * dh = &x[i].d;
2308
- device const uint8_t * q = x[i].qs + il;
2309
- device const uint8_t * h = x[i].qh + in;
2310
- device const int8_t * s = x[i].scales;
2311
-
2312
- for (int row = 0; row < 2; ++row) {
2313
-
2314
- const float d = dh[0];
2315
-
2316
- float2 acc = {0.f, 0.f};
2317
- for (int l = 0; l < 4; ++l) {
2318
- const uint8_t hl = h[l] >> im;
2319
- acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
2320
- + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
2321
- acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
2322
- + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
2323
- }
2324
- sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
2325
-
2326
- q += step;
2327
- h += step;
2328
- s += step;
2329
- dh += step/2;
2330
-
2331
- }
2332
-
2333
- y += 8 * QK_K;
2334
- }
2335
- #endif
2336
-
2337
- for (int row = 0; row < 2; ++row) {
2338
- const float tot = simd_sum(sumf[row]);
2339
- if (tiisg == 0) {
2340
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
2341
- }
2342
- }
2343
-
2344
- }
2345
-
2346
- kernel void kernel_mul_mv_q6_K_f32(
2347
- device const void * src0,
2348
- device const float * src1,
2349
- device float * dst,
2350
- constant int64_t & ne00,
2351
- constant int64_t & ne01[[buffer(4)]],
2352
- constant int64_t & ne02[[buffer(5)]],
2353
- constant int64_t & ne10[[buffer(9)]],
2354
- constant int64_t & ne12[[buffer(11)]],
2355
- constant int64_t & ne0[[buffer(15)]],
2356
- constant int64_t & ne1[[buffer(16)]],
2357
- constant uint & gqa[[buffer(17)]],
2358
- uint3 tgpig[[threadgroup_position_in_grid]],
2359
- uint tiisg[[thread_index_in_simdgroup]],
2360
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2361
-
2362
- const uint8_t kmask1 = 0x03;
2363
- const uint8_t kmask2 = 0x0C;
2364
- const uint8_t kmask3 = 0x30;
2365
- const uint8_t kmask4 = 0xC0;
2366
-
2367
- const int nb = ne00/QK_K;
2368
-
2369
- const int64_t r0 = tgpig.x;
2370
- const int64_t r1 = tgpig.y;
2371
- const int r2 = tgpig.z;
2372
-
2373
- const int row = 2 * r0 + sgitg;
2374
- const uint offset0 = r2/gqa*(nb*ne0);
2375
- device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
2376
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2377
-
2378
- float sumf = 0;
2379
-
2380
- #if QK_K == 256
2381
- const int tid = tiisg/2;
2382
- const int ix = tiisg%2;
2383
- const int ip = tid/8; // 0 or 1
2384
- const int il = tid%8;
2385
- const int n = 4;
2386
- const int l0 = n*il;
2387
- const int is = 8*ip + l0/16;
2388
-
2389
- const int y_offset = 128*ip + l0;
2390
- const int q_offset_l = 64*ip + l0;
2391
- const int q_offset_h = 32*ip + l0;
2392
-
2393
- for (int i = ix; i < nb; i += 2) {
2394
-
2395
- device const uint8_t * q1 = x[i].ql + q_offset_l;
2396
- device const uint8_t * q2 = q1 + 32;
2397
- device const uint8_t * qh = x[i].qh + q_offset_h;
2398
- device const int8_t * sc = x[i].scales + is;
2399
-
2400
- device const float * y = yy + i * QK_K + y_offset;
2401
-
2402
- const float dall = x[i].d;
2403
-
2404
- float4 sums = {0.f, 0.f, 0.f, 0.f};
2405
- for (int l = 0; l < n; ++l) {
2406
- sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
2407
- sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
2408
- sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
2409
- sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
2410
- }
2411
-
2412
- sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
2413
-
2414
- }
2415
-
2416
- #else
2417
- const int ix = tiisg/4;
2418
- const int il = 4*(tiisg%4);
2419
-
2420
- for (int i = ix; i < nb; i += 8) {
2421
- device const float * y = yy + i * QK_K + il;
2422
- device const uint8_t * ql = x[i].ql + il;
2423
- device const uint8_t * qh = x[i].qh + il;
2424
- device const int8_t * s = x[i].scales;
2425
-
2426
- const float d = x[i].d;
2427
-
2428
- float4 sums = {0.f, 0.f, 0.f, 0.f};
2429
- for (int l = 0; l < 4; ++l) {
2430
- sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
2431
- sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
2432
- sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
2433
- sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
2434
- }
2435
- sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
2436
- }
2437
-
2438
- #endif
2439
-
2440
- const float tot = simd_sum(sumf);
2441
- if (tiisg == 0) {
2442
- dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
2443
- }
2444
- }
2445
-
2446
- //============================= templates and their specializations =============================
2447
-
2448
- // NOTE: this is not dequantizing - we are simply fitting the template
2449
- template <typename type4x4>
2450
- void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
2451
- float4x4 temp = *(((device float4x4 *)src));
2452
- for (int i = 0; i < 16; i++){
2453
- reg[i/4][i%4] = temp[i/4][i%4];
2454
- }
2455
- }
2456
-
2457
- template <typename type4x4>
2458
- void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
2459
- half4x4 temp = *(((device half4x4 *)src));
2460
- for (int i = 0; i < 16; i++){
2461
- reg[i/4][i%4] = temp[i/4][i%4];
2462
- }
2463
- }
2464
-
2465
- template <typename type4x4>
2466
- void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
2467
- device const uint16_t * qs = ((device const uint16_t *)xb + 1);
2468
- const float d1 = il ? (xb->d / 16.h) : xb->d;
2469
- const float d2 = d1 / 256.f;
2470
- const float md = -8.h * xb->d;
2471
- const ushort mask0 = il ? 0x00F0 : 0x000F;
2472
- const ushort mask1 = mask0 << 8;
2473
-
2474
- for (int i=0;i<8;i++) {
2475
- reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
2476
- reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
2477
- }
2478
- }
2479
-
2480
- template <typename type4x4>
2481
- void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
2482
- device const uint16_t * qs = ((device const uint16_t *)xb + 2);
2483
- const float d1 = il ? (xb->d / 16.h) : xb->d;
2484
- const float d2 = d1 / 256.f;
2485
- const float m = xb->m;
2486
- const ushort mask0 = il ? 0x00F0 : 0x000F;
2487
- const ushort mask1 = mask0 << 8;
2488
-
2489
- for (int i=0;i<8;i++) {
2490
- reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
2491
- reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
2492
- }
2493
- }
2494
-
2495
- template <typename type4x4>
2496
- void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
2497
- device const uint16_t * qs = ((device const uint16_t *)xb + 3);
2498
- const float d = xb->d;
2499
- const float md = -16.h * xb->d;
2500
- const ushort mask = il ? 0x00F0 : 0x000F;
2501
-
2502
- const uint32_t qh = *((device const uint32_t *)xb->qh);
2503
-
2504
- const int x_mv = il ? 4 : 0;
2505
-
2506
- const int gh_mv = il ? 12 : 0;
2507
- const int gh_bk = il ? 0 : 4;
2508
-
2509
- for (int i = 0; i < 8; i++) {
2510
- // extract the 5-th bits for x0 and x1
2511
- const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
2512
- const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
2513
-
2514
- // combine the 4-bits from qs with the 5th bit
2515
- const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
2516
- const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
2517
-
2518
- reg[i/2][2*(i%2)+0] = d * x0 + md;
2519
- reg[i/2][2*(i%2)+1] = d * x1 + md;
2520
- }
2521
- }
2522
-
2523
- template <typename type4x4>
2524
- void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
2525
- device const uint16_t * qs = ((device const uint16_t *)xb + 4);
2526
- const float d = xb->d;
2527
- const float m = xb->m;
2528
- const ushort mask = il ? 0x00F0 : 0x000F;
2529
-
2530
- const uint32_t qh = *((device const uint32_t *)xb->qh);
2531
-
2532
- const int x_mv = il ? 4 : 0;
2533
-
2534
- const int gh_mv = il ? 12 : 0;
2535
- const int gh_bk = il ? 0 : 4;
2536
-
2537
- for (int i = 0; i < 8; i++) {
2538
- // extract the 5-th bits for x0 and x1
2539
- const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
2540
- const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
2541
-
2542
- // combine the 4-bits from qs with the 5th bit
2543
- const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
2544
- const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
2545
-
2546
- reg[i/2][2*(i%2)+0] = d * x0 + m;
2547
- reg[i/2][2*(i%2)+1] = d * x1 + m;
2548
- }
2549
- }
2550
-
2551
- template <typename type4x4>
2552
- void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
2553
- device const int8_t * qs = ((device const int8_t *)xb->qs);
2554
- const half d = xb->d;
2555
-
2556
- for (int i=0;i<16;i++) {
2557
- reg[i/4][i%4] = (qs[i + 16*il] * d);
2558
- }
2559
- }
2560
-
2561
- template <typename type4x4>
2562
- void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
2563
- const half d = xb->d;
2564
- const half min = xb->dmin;
2565
- device const uint8_t * q = (device const uint8_t *)xb->qs;
2566
- half dl, ml;
2567
- uint8_t sc = xb->scales[il];
2568
-
2569
- #if QK_K == 256
2570
- q = q + 32*(il/8) + 16*(il&1);
2571
- il = (il/2)%4;
2572
- #endif
2573
- half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
2574
- uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
2575
- dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
2576
- for (int i = 0; i < 16; ++i) {
2577
- reg[i/4][i%4] = dl * (q[i] & mask) - ml;
2578
- }
2579
- }
2580
-
2581
- template <typename type4x4>
2582
- void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
2583
- const half d_all = xb->d;
2584
- device const uint8_t * q = (device const uint8_t *)xb->qs;
2585
- device const uint8_t * h = (device const uint8_t *)xb->hmask;
2586
- device const int8_t * scales = (device const int8_t *)xb->scales;
2587
-
2588
- #if QK_K == 256
2589
- q = q + 32 * (il/8) + 16 * (il&1);
2590
- h = h + 16 * (il&1);
2591
- uint8_t m = 1 << (il/2);
2592
- uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
2593
- ((il/4)>0 ? 12 : 3);
2594
- uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
2595
- uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
2596
- int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
2597
- : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
2598
- half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
2599
- const half ml = 4.h * dl;
2600
-
2601
- il = (il/2) & 3;
2602
- const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
2603
- const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
2604
- dl *= coef;
2605
-
2606
- for (int i = 0; i < 16; ++i) {
2607
- reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
2608
- }
2609
- #else
2610
- float kcoef = il&1 ? 1.f/16.f : 1.f;
2611
- uint16_t kmask = il&1 ? 0xF0 : 0x0F;
2612
- float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
2613
- float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
2614
- uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
2615
- uint8_t m = 1<<(il*2);
2616
- for (int i = 0; i < 16; ++i) {
2617
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
2618
- }
2619
- #endif
2620
- }
2621
-
2622
- static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
2623
- return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
2624
- : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
2625
- }
2626
-
2627
- template <typename type4x4>
2628
- void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
2629
- device const uchar * q = xb->qs;
2630
-
2631
- #if QK_K == 256
2632
- short is = (il/4) * 2;
2633
- q = q + (il/4) * 32 + 16 * (il&1);
2634
- il = il & 3;
2635
- const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
2636
- const half d = il < 2 ? xb->d : xb->d / 16.h;
2637
- const half min = xb->dmin;
2638
- const half dl = d * sc[0];
2639
- const half ml = min * sc[1];
2640
- #else
2641
- q = q + 16 * (il&1);
2642
- device const uint8_t * s = xb->scales;
2643
- device const half2 * dh = (device const half2 *)xb->d;
2644
- const float2 d = (float2)dh[0];
2645
- const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
2646
- const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
2647
- #endif
2648
- const ushort mask = il<2 ? 0x0F : 0xF0;
2649
- for (int i = 0; i < 16; ++i) {
2650
- reg[i/4][i%4] = dl * (q[i] & mask) - ml;
2651
- }
2652
- }
2653
-
2654
- template <typename type4x4>
2655
- void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
2656
- device const uint8_t * q = xb->qs;
2657
- device const uint8_t * qh = xb->qh;
2658
-
2659
- #if QK_K == 256
2660
- short is = (il/4) * 2;
2661
- q = q + 32 * (il/4) + 16 * (il&1);
2662
- qh = qh + 16 * (il&1);
2663
- uint8_t ul = 1 << (il/2);
2664
- il = il & 3;
2665
- const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
2666
- const half d = il < 2 ? xb->d : xb->d / 16.h;
2667
- const half min = xb->dmin;
2668
- const half dl = d * sc[0];
2669
- const half ml = min * sc[1];
2670
-
2671
- const ushort mask = il<2 ? 0x0F : 0xF0;
2672
- const half qh_val = il<2 ? 16.h : 256.h;
2673
- for (int i = 0; i < 16; ++i) {
2674
- reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
2675
- }
2676
- #else
2677
- q = q + 16 * (il&1);
2678
- device const int8_t * s = xb->scales;
2679
- const float dl = xb->d * s[il];
2680
- uint8_t m = 1<<(il*2);
2681
- const float coef = il<2 ? 1.f : 1.f/16.f;
2682
- const ushort mask = il<2 ? 0x0F : 0xF0;
2683
- for (int i = 0; i < 16; ++i) {
2684
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
2685
- }
2686
- #endif
2687
- }
2688
-
2689
- template <typename type4x4>
2690
- void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
2691
- const half d_all = xb->d;
2692
- device const uint8_t * ql = (device const uint8_t *)xb->ql;
2693
- device const uint8_t * qh = (device const uint8_t *)xb->qh;
2694
- device const int8_t * scales = (device const int8_t *)xb->scales;
2695
-
2696
- #if QK_K == 256
2697
- ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
2698
- qh = qh + 32*(il/8) + 16*(il&1);
2699
- half sc = scales[(il%2) + 2 * ((il/2))];
2700
- il = (il/2) & 3;
2701
- #else
2702
- ql = ql + 16 * (il&1);
2703
- half sc = scales[il];
2704
- #endif
2705
- const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
2706
- const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
2707
- const half coef = il>1 ? 1.f/16.h : 1.h;
2708
- const half ml = d_all * sc * 32.h;
2709
- const half dl = d_all * sc * coef;
2710
- for (int i = 0; i < 16; ++i) {
2711
- const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
2712
- : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
2713
- reg[i/4][i%4] = dl * q - ml;
2714
- }
2715
- }
2716
-
2717
- template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
2718
- kernel void kernel_get_rows(
2719
- device const void * src0,
2720
- device const int * src1,
2721
- device float * dst,
2722
- constant int64_t & ne00,
2723
- constant uint64_t & nb01,
2724
- constant uint64_t & nb1,
2725
- uint tgpig[[threadgroup_position_in_grid]],
2726
- uint tiitg[[thread_index_in_threadgroup]],
2727
- uint tptg[[threads_per_threadgroup]]) {
2728
- const int i = tgpig;
2729
- const int r = ((device int32_t *) src1)[i];
2730
-
2731
- for (int ind = tiitg; ind < ne00/16; ind += tptg) {
2732
- float4x4 temp;
2733
- dequantize_func(
2734
- ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
2735
- *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
2736
- }
2737
- }
2738
-
2739
- #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
2740
- #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
2741
- #define BLOCK_SIZE_K 32
2742
- #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
2743
- #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
2744
- #define THREAD_PER_BLOCK 128
2745
- #define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
2746
- #define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
2747
- #define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
2748
- #define SG_MAT_ROW 8
2749
-
2750
- // each block_q contains 16*nl weights
2751
- template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
2752
- kernel void kernel_mul_mm(device const uchar * src0,
2753
- device const uchar * src1,
2754
- device float * dst,
2755
- constant int64_t & ne00,
2756
- constant int64_t & ne02,
2757
- constant int64_t & nb01,
2758
- constant int64_t & nb02,
2759
- constant int64_t & ne12,
2760
- constant int64_t & nb10,
2761
- constant int64_t & nb11,
2762
- constant int64_t & nb12,
2763
- constant int64_t & ne0,
2764
- constant int64_t & ne1,
2765
- constant uint & gqa,
2766
- threadgroup uchar * shared_memory [[threadgroup(0)]],
2767
- uint3 tgpig[[threadgroup_position_in_grid]],
2768
- uint tiitg[[thread_index_in_threadgroup]],
2769
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2770
-
2771
- threadgroup half * sa = (threadgroup half *)(shared_memory);
2772
- threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
2773
-
2774
- const uint r0 = tgpig.y;
2775
- const uint r1 = tgpig.x;
2776
- const uint im = tgpig.z;
2777
-
2778
- // if this block is of 64x32 shape or smaller
2779
- short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
2780
- short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
2781
-
2782
- // a thread shouldn't load data outside of the matrix
2783
- short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
2784
- short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
2785
-
2786
- simdgroup_half8x8 ma[4];
2787
- simdgroup_float8x8 mb[2];
2788
- simdgroup_float8x8 c_res[8];
2789
- for (int i = 0; i < 8; i++){
2790
- c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
2791
- }
2792
-
2793
- short il = (tiitg % THREAD_PER_ROW);
2794
-
2795
- uint offset0 = im/gqa*nb02;
2796
- ushort offset1 = il/nl;
2797
-
2798
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
2799
- device const float * y = (device const float *)(src1
2800
- + nb12 * im
2801
- + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
2802
- + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
2803
-
2804
- for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
2805
- // load data and store to threadgroup memory
2806
- half4x4 temp_a;
2807
- dequantize_func(x, il, temp_a);
2808
- threadgroup_barrier(mem_flags::mem_threadgroup);
2809
-
2810
- #pragma unroll(16)
2811
- for (int i = 0; i < 16; i++) {
2812
- *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
2813
- + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
2814
- + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
2815
- }
2816
-
2817
- *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
2818
-
2819
- il = (il + 2 < nl) ? il + 2 : il % 2;
2820
- x = (il < 2) ? x + (2+nl-1)/nl : x;
2821
- y += BLOCK_SIZE_K;
2822
-
2823
- threadgroup_barrier(mem_flags::mem_threadgroup);
2824
-
2825
- // load matrices from threadgroup memory and conduct outer products
2826
- threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
2827
- threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
2828
-
2829
- #pragma unroll(4)
2830
- for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
2831
- #pragma unroll(4)
2832
- for (int i = 0; i < 4; i++) {
2833
- simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
2834
- }
2835
- simdgroup_barrier(mem_flags::mem_none);
2836
- #pragma unroll(2)
2837
- for (int i = 0; i < 2; i++) {
2838
- simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
2839
- }
2840
-
2841
- lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
2842
- lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
2843
-
2844
- #pragma unroll(8)
2845
- for (int i = 0; i < 8; i++){
2846
- simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
2847
- }
2848
- }
2849
- }
2850
-
2851
- if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
2852
- device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
2853
- + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
2854
- for (int i = 0; i < 8; i++) {
2855
- simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
2856
- }
2857
- } else {
2858
- // block is smaller than 64x32, we should avoid writing data outside of the matrix
2859
- threadgroup_barrier(mem_flags::mem_threadgroup);
2860
- threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
2861
- + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
2862
- for (int i = 0; i < 8; i++) {
2863
- simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
2864
- }
2865
-
2866
- threadgroup_barrier(mem_flags::mem_threadgroup);
2867
-
2868
- device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
2869
- if (sgitg == 0) {
2870
- for (int i = 0; i < n_rows; i++) {
2871
- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
2872
- *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
2873
- }
2874
- }
2875
- }
2876
- }
2877
- }
2878
-
2879
- #if QK_K == 256
2880
- #define QK_NL 16
2881
- #else
2882
- #define QK_NL 4
2883
- #endif
2884
-
2885
- typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
2886
- constant uint64_t &, constant uint64_t &, uint, uint, uint);
2887
-
2888
- template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
2889
- template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
2890
- template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
2891
- template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
2892
- template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
2893
- template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
2894
- template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
2895
- template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
2896
- template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
2897
- template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
2898
- template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
2899
- template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
2900
-
2901
- typedef void (mat_mm_t)(
2902
- device const uchar * src0,
2903
- device const uchar * src1,
2904
- device float * dst,
2905
- constant int64_t & ne00,
2906
- constant int64_t & ne02,
2907
- constant int64_t & nb01,
2908
- constant int64_t & nb02,
2909
- constant int64_t & ne12,
2910
- constant int64_t & nb10,
2911
- constant int64_t & nb11,
2912
- constant int64_t & nb12,
2913
- constant int64_t & ne0,
2914
- constant int64_t & ne1,
2915
- constant uint & gqa,
2916
- threadgroup uchar *, uint3, uint, uint);
2917
-
2918
- template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
2919
- template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
2920
- template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
2921
- template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
2922
- template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
2923
- template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
2924
- template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
2925
- template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
2926
- template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
2927
- template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
2928
- template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
2929
- template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;