node-llama-cpp 3.0.0-beta.37 → 3.0.0-beta.39

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167) hide show
  1. package/bins/linux-arm64/_nlcBuildMetadata.json +1 -1
  2. package/bins/linux-arm64/libggml.so +0 -0
  3. package/bins/linux-arm64/libllama.so +0 -0
  4. package/bins/linux-arm64/llama-addon.node +0 -0
  5. package/bins/linux-armv7l/_nlcBuildMetadata.json +1 -1
  6. package/bins/linux-armv7l/libggml.so +0 -0
  7. package/bins/linux-armv7l/libllama.so +0 -0
  8. package/bins/linux-armv7l/llama-addon.node +0 -0
  9. package/bins/linux-x64/_nlcBuildMetadata.json +1 -1
  10. package/bins/linux-x64/libggml.so +0 -0
  11. package/bins/linux-x64/libllama.so +0 -0
  12. package/bins/linux-x64/llama-addon.node +0 -0
  13. package/bins/linux-x64-vulkan/_nlcBuildMetadata.json +1 -1
  14. package/bins/linux-x64-vulkan/libggml.so +0 -0
  15. package/bins/linux-x64-vulkan/libllama.so +0 -0
  16. package/bins/linux-x64-vulkan/llama-addon.node +0 -0
  17. package/bins/linux-x64-vulkan/vulkan-shaders-gen +0 -0
  18. package/bins/mac-arm64-metal/_nlcBuildMetadata.json +1 -1
  19. package/bins/mac-arm64-metal/ggml-common.h +24 -0
  20. package/bins/mac-arm64-metal/ggml-metal.metal +181 -552
  21. package/bins/mac-arm64-metal/libggml.dylib +0 -0
  22. package/bins/mac-arm64-metal/libllama.dylib +0 -0
  23. package/bins/mac-arm64-metal/llama-addon.node +0 -0
  24. package/bins/mac-x64/_nlcBuildMetadata.json +1 -1
  25. package/bins/mac-x64/libggml.dylib +0 -0
  26. package/bins/mac-x64/libllama.dylib +0 -0
  27. package/bins/mac-x64/llama-addon.node +0 -0
  28. package/bins/win-arm64/_nlcBuildMetadata.json +1 -1
  29. package/bins/win-arm64/ggml.dll +0 -0
  30. package/bins/win-arm64/llama-addon.exp +0 -0
  31. package/bins/win-arm64/llama-addon.lib +0 -0
  32. package/bins/win-arm64/llama-addon.node +0 -0
  33. package/bins/win-arm64/llama.dll +0 -0
  34. package/bins/win-x64/_nlcBuildMetadata.json +1 -1
  35. package/bins/win-x64/ggml.dll +0 -0
  36. package/bins/win-x64/llama-addon.node +0 -0
  37. package/bins/win-x64/llama.dll +0 -0
  38. package/bins/win-x64-vulkan/_nlcBuildMetadata.json +1 -1
  39. package/bins/win-x64-vulkan/ggml.dll +0 -0
  40. package/bins/win-x64-vulkan/llama-addon.node +0 -0
  41. package/bins/win-x64-vulkan/llama.dll +0 -0
  42. package/bins/win-x64-vulkan/vulkan-shaders-gen.exe +0 -0
  43. package/dist/ChatWrapper.d.ts +2 -1
  44. package/dist/ChatWrapper.js +19 -5
  45. package/dist/ChatWrapper.js.map +1 -1
  46. package/dist/bindings/AddonTypes.d.ts +13 -2
  47. package/dist/bindings/getLlama.d.ts +3 -2
  48. package/dist/bindings/getLlama.js +1 -1
  49. package/dist/bindings/getLlama.js.map +1 -1
  50. package/dist/chatWrappers/FunctionaryChatWrapper.js +8 -5
  51. package/dist/chatWrappers/FunctionaryChatWrapper.js.map +1 -1
  52. package/dist/chatWrappers/GemmaChatWrapper.js +1 -1
  53. package/dist/chatWrappers/GemmaChatWrapper.js.map +1 -1
  54. package/dist/chatWrappers/Llama3ChatWrapper.js +5 -6
  55. package/dist/chatWrappers/Llama3ChatWrapper.js.map +1 -1
  56. package/dist/chatWrappers/Llama3_1ChatWrapper.d.ts +31 -0
  57. package/dist/chatWrappers/Llama3_1ChatWrapper.js +223 -0
  58. package/dist/chatWrappers/Llama3_1ChatWrapper.js.map +1 -0
  59. package/dist/chatWrappers/generic/JinjaTemplateChatWrapper.d.ts +9 -0
  60. package/dist/chatWrappers/generic/JinjaTemplateChatWrapper.js.map +1 -1
  61. package/dist/chatWrappers/utils/ChatModelFunctionsDocumentationGenerator.d.ts +17 -2
  62. package/dist/chatWrappers/utils/ChatModelFunctionsDocumentationGenerator.js +39 -2
  63. package/dist/chatWrappers/utils/ChatModelFunctionsDocumentationGenerator.js.map +1 -1
  64. package/dist/chatWrappers/utils/jsonDumps.d.ts +7 -0
  65. package/dist/chatWrappers/utils/jsonDumps.js +18 -0
  66. package/dist/chatWrappers/utils/jsonDumps.js.map +1 -0
  67. package/dist/chatWrappers/utils/resolveChatWrapper.d.ts +5 -3
  68. package/dist/chatWrappers/utils/resolveChatWrapper.js +50 -4
  69. package/dist/chatWrappers/utils/resolveChatWrapper.js.map +1 -1
  70. package/dist/cli/commands/ChatCommand.d.ts +1 -1
  71. package/dist/cli/commands/ChatCommand.js +5 -5
  72. package/dist/cli/commands/ChatCommand.js.map +1 -1
  73. package/dist/cli/commands/CompleteCommand.js +5 -3
  74. package/dist/cli/commands/CompleteCommand.js.map +1 -1
  75. package/dist/cli/commands/InfillCommand.js +5 -3
  76. package/dist/cli/commands/InfillCommand.js.map +1 -1
  77. package/dist/cli/recommendedModels.js +43 -24
  78. package/dist/cli/recommendedModels.js.map +1 -1
  79. package/dist/cli/utils/interactivelyAskForModel.d.ts +2 -1
  80. package/dist/cli/utils/interactivelyAskForModel.js +19 -9
  81. package/dist/cli/utils/interactivelyAskForModel.js.map +1 -1
  82. package/dist/cli/utils/resolveCommandGgufPath.d.ts +2 -1
  83. package/dist/cli/utils/resolveCommandGgufPath.js +3 -2
  84. package/dist/cli/utils/resolveCommandGgufPath.js.map +1 -1
  85. package/dist/consts.d.ts +1 -0
  86. package/dist/consts.js +1 -0
  87. package/dist/consts.js.map +1 -1
  88. package/dist/evaluator/LlamaChat/LlamaChat.d.ts +22 -0
  89. package/dist/evaluator/LlamaChat/LlamaChat.js +65 -34
  90. package/dist/evaluator/LlamaChat/LlamaChat.js.map +1 -1
  91. package/dist/evaluator/LlamaChatSession/LlamaChatSession.d.ts +28 -6
  92. package/dist/evaluator/LlamaChatSession/LlamaChatSession.js +22 -16
  93. package/dist/evaluator/LlamaChatSession/LlamaChatSession.js.map +1 -1
  94. package/dist/evaluator/LlamaChatSession/utils/LlamaChatSessionPromptCompletionEngine.js +4 -5
  95. package/dist/evaluator/LlamaChatSession/utils/LlamaChatSessionPromptCompletionEngine.js.map +1 -1
  96. package/dist/evaluator/LlamaCompletion.d.ts +13 -2
  97. package/dist/evaluator/LlamaCompletion.js +10 -5
  98. package/dist/evaluator/LlamaCompletion.js.map +1 -1
  99. package/dist/evaluator/LlamaContext/LlamaContext.d.ts +1 -1
  100. package/dist/evaluator/LlamaContext/LlamaContext.js +60 -0
  101. package/dist/evaluator/LlamaContext/LlamaContext.js.map +1 -1
  102. package/dist/evaluator/LlamaContext/types.d.ts +21 -0
  103. package/dist/evaluator/LlamaGrammar.d.ts +6 -3
  104. package/dist/evaluator/LlamaGrammar.js +2 -2
  105. package/dist/evaluator/LlamaGrammar.js.map +1 -1
  106. package/dist/evaluator/LlamaModel/LlamaModel.d.ts +16 -32
  107. package/dist/evaluator/LlamaModel/LlamaModel.js +94 -53
  108. package/dist/evaluator/LlamaModel/LlamaModel.js.map +1 -1
  109. package/dist/gguf/consts.d.ts +1 -0
  110. package/dist/gguf/consts.js +4 -0
  111. package/dist/gguf/consts.js.map +1 -1
  112. package/dist/gguf/insights/GgufInsights.js +4 -0
  113. package/dist/gguf/insights/GgufInsights.js.map +1 -1
  114. package/dist/gguf/parser/GgufV2Parser.js +3 -1
  115. package/dist/gguf/parser/GgufV2Parser.js.map +1 -1
  116. package/dist/gguf/types/GgufMetadataTypes.d.ts +16 -0
  117. package/dist/gguf/types/GgufMetadataTypes.js.map +1 -1
  118. package/dist/gguf/utils/convertMetadataKeyValueRecordToNestedObject.d.ts +3 -2
  119. package/dist/gguf/utils/convertMetadataKeyValueRecordToNestedObject.js +44 -8
  120. package/dist/gguf/utils/convertMetadataKeyValueRecordToNestedObject.js.map +1 -1
  121. package/dist/index.d.ts +4 -2
  122. package/dist/index.js +3 -1
  123. package/dist/index.js.map +1 -1
  124. package/dist/types.d.ts +15 -1
  125. package/dist/types.js.map +1 -1
  126. package/dist/utils/DeepPartialObject.d.ts +3 -0
  127. package/dist/utils/DeepPartialObject.js +2 -0
  128. package/dist/utils/DeepPartialObject.js.map +1 -0
  129. package/dist/utils/StopGenerationDetector.d.ts +6 -3
  130. package/dist/utils/StopGenerationDetector.js +22 -7
  131. package/dist/utils/StopGenerationDetector.js.map +1 -1
  132. package/dist/utils/TokenStreamRegulator.d.ts +1 -0
  133. package/dist/utils/TokenStreamRegulator.js +23 -5
  134. package/dist/utils/TokenStreamRegulator.js.map +1 -1
  135. package/dist/utils/resolveLastTokens.d.ts +2 -0
  136. package/dist/utils/resolveLastTokens.js +12 -0
  137. package/dist/utils/resolveLastTokens.js.map +1 -0
  138. package/llama/CMakeLists.txt +1 -1
  139. package/llama/addon/AddonContext.cpp +772 -0
  140. package/llama/addon/AddonContext.h +53 -0
  141. package/llama/addon/AddonGrammar.cpp +44 -0
  142. package/llama/addon/AddonGrammar.h +18 -0
  143. package/llama/addon/AddonGrammarEvaluationState.cpp +28 -0
  144. package/llama/addon/AddonGrammarEvaluationState.h +15 -0
  145. package/llama/addon/AddonModel.cpp +681 -0
  146. package/llama/addon/AddonModel.h +61 -0
  147. package/llama/addon/AddonModelData.cpp +25 -0
  148. package/llama/addon/AddonModelData.h +15 -0
  149. package/llama/addon/AddonModelLora.cpp +107 -0
  150. package/llama/addon/AddonModelLora.h +28 -0
  151. package/llama/addon/addon.cpp +217 -0
  152. package/llama/addon/addonGlobals.cpp +22 -0
  153. package/llama/addon/addonGlobals.h +12 -0
  154. package/llama/addon/globals/addonLog.cpp +135 -0
  155. package/llama/addon/globals/addonLog.h +21 -0
  156. package/llama/addon/globals/addonProgress.cpp +15 -0
  157. package/llama/addon/globals/addonProgress.h +15 -0
  158. package/llama/addon/globals/getGpuInfo.cpp +108 -0
  159. package/llama/addon/globals/getGpuInfo.h +6 -0
  160. package/llama/binariesGithubRelease.json +1 -1
  161. package/llama/gitRelease.bundle +0 -0
  162. package/llama/grammars/README.md +1 -1
  163. package/llama/llama.cpp.info.json +1 -1
  164. package/package.json +3 -3
  165. package/templates/packed/electron-typescript-react.json +1 -1
  166. package/templates/packed/node-typescript.json +1 -1
  167. package/llama/addon.cpp +0 -2014
@@ -1219,9 +1219,10 @@ kernel void kernel_mul_mv_q8_0_f32(
1219
1219
  kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
1220
1220
  }
1221
1221
 
1222
- #define N_F32_F32 4
1222
+ #define N_MV_T_T 4
1223
1223
 
1224
- void kernel_mul_mv_f32_f32_impl(
1224
+ template<typename T0, typename T04, typename T1, typename T14>
1225
+ void kernel_mul_mv_impl(
1225
1226
  device const char * src0,
1226
1227
  device const char * src1,
1227
1228
  device float * dst,
@@ -1239,13 +1240,12 @@ void kernel_mul_mv_f32_f32_impl(
1239
1240
  uint64_t nb12,
1240
1241
  int64_t ne0,
1241
1242
  int64_t ne1,
1242
- uint r2,
1243
- uint r3,
1244
- uint3 tgpig,
1245
- uint tiisg) {
1246
-
1243
+ uint r2,
1244
+ uint r3,
1245
+ uint3 tgpig,
1246
+ uint tiisg) {
1247
1247
  const int64_t r0 = tgpig.x;
1248
- const int64_t rb = tgpig.y*N_F32_F32;
1248
+ const int64_t rb = tgpig.y*N_MV_T_T;
1249
1249
  const int64_t im = tgpig.z;
1250
1250
 
1251
1251
  const uint i12 = im%ne12;
@@ -1253,20 +1253,20 @@ void kernel_mul_mv_f32_f32_impl(
1253
1253
 
1254
1254
  const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1255
1255
 
1256
- device const float * x = (device const float *) (src0 + offset0);
1256
+ device const T0 * x = (device const T0 *) (src0 + offset0);
1257
1257
 
1258
1258
  if (ne00 < 128) {
1259
- for (int row = 0; row < N_F32_F32; ++row) {
1259
+ for (int row = 0; row < N_MV_T_T; ++row) {
1260
1260
  int r1 = rb + row;
1261
1261
  if (r1 >= ne11) {
1262
1262
  break;
1263
1263
  }
1264
1264
 
1265
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
1265
+ device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
1266
1266
 
1267
1267
  float sumf = 0;
1268
1268
  for (int i = tiisg; i < ne00; i += 32) {
1269
- sumf += (float) x[i] * (float) y[i];
1269
+ sumf += (T0) x[i] * (T1) y[i];
1270
1270
  }
1271
1271
 
1272
1272
  float all_sum = simd_sum(sumf);
@@ -1275,32 +1275,32 @@ void kernel_mul_mv_f32_f32_impl(
1275
1275
  }
1276
1276
  }
1277
1277
  } else {
1278
- device const float4 * x4 = (device const float4 *)x;
1279
- for (int row = 0; row < N_F32_F32; ++row) {
1278
+ device const T04 * x4 = (device const T04 *) x;
1279
+ for (int row = 0; row < N_MV_T_T; ++row) {
1280
1280
  int r1 = rb + row;
1281
1281
  if (r1 >= ne11) {
1282
1282
  break;
1283
1283
  }
1284
1284
 
1285
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
1286
- device const float4 * y4 = (device const float4 *) y;
1285
+ device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
1286
+ device const T14 * y4 = (device const T14 *) y;
1287
1287
 
1288
1288
  float sumf = 0;
1289
1289
  for (int i = tiisg; i < ne00/4; i += 32) {
1290
- for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
1290
+ for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
1291
1291
  }
1292
1292
 
1293
1293
  float all_sum = simd_sum(sumf);
1294
1294
  if (tiisg == 0) {
1295
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
1295
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
1296
1296
  dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1297
1297
  }
1298
1298
  }
1299
1299
  }
1300
1300
  }
1301
1301
 
1302
- [[host_name("kernel_mul_mv_f32_f32")]]
1303
- kernel void kernel_mul_mv_f32_f32(
1302
+ template<typename T0, typename T04, typename T1, typename T14>
1303
+ kernel void kernel_mul_mv(
1304
1304
  device const char * src0,
1305
1305
  device const char * src1,
1306
1306
  device float * dst,
@@ -1322,90 +1322,38 @@ kernel void kernel_mul_mv_f32_f32(
1322
1322
  constant uint & r3,
1323
1323
  uint3 tgpig[[threadgroup_position_in_grid]],
1324
1324
  uint tiisg[[thread_index_in_simdgroup]]) {
1325
- kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1325
+ kernel_mul_mv_impl<T0, T04, T1, T14>(
1326
+ src0,
1327
+ src1,
1328
+ dst,
1329
+ ne00,
1330
+ ne01,
1331
+ ne02,
1332
+ nb00,
1333
+ nb01,
1334
+ nb02,
1335
+ ne10,
1336
+ ne11,
1337
+ ne12,
1338
+ nb10,
1339
+ nb11,
1340
+ nb12,
1341
+ ne0,
1342
+ ne1,
1343
+ r2,
1344
+ r3,
1345
+ tgpig,
1346
+ tiisg);
1326
1347
  }
1327
1348
 
1328
- #define N_F16_F16 4
1349
+ typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
1329
1350
 
1330
- kernel void kernel_mul_mv_f16_f16(
1331
- device const char * src0,
1332
- device const char * src1,
1333
- device float * dst,
1334
- constant int64_t & ne00,
1335
- constant int64_t & ne01,
1336
- constant int64_t & ne02,
1337
- constant uint64_t & nb00,
1338
- constant uint64_t & nb01,
1339
- constant uint64_t & nb02,
1340
- constant int64_t & ne10,
1341
- constant int64_t & ne11,
1342
- constant int64_t & ne12,
1343
- constant uint64_t & nb10,
1344
- constant uint64_t & nb11,
1345
- constant uint64_t & nb12,
1346
- constant int64_t & ne0,
1347
- constant int64_t & ne1,
1348
- constant uint & r2,
1349
- constant uint & r3,
1350
- uint3 tgpig[[threadgroup_position_in_grid]],
1351
- uint tiisg[[thread_index_in_simdgroup]]) {
1351
+ template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
1352
+ template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
1353
+ template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
1352
1354
 
1353
- const int64_t r0 = tgpig.x;
1354
- const int64_t rb = tgpig.y*N_F16_F16;
1355
- const int64_t im = tgpig.z;
1356
-
1357
- const uint i12 = im%ne12;
1358
- const uint i13 = im/ne12;
1359
-
1360
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1361
-
1362
- device const half * x = (device const half *) (src0 + offset0);
1363
-
1364
- if (ne00 < 128) {
1365
- for (int row = 0; row < N_F16_F16; ++row) {
1366
- int r1 = rb + row;
1367
- if (r1 >= ne11) {
1368
- break;
1369
- }
1370
-
1371
- device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
1372
-
1373
- float sumf = 0;
1374
- for (int i = tiisg; i < ne00; i += 32) {
1375
- sumf += (half) x[i] * (half) y[i];
1376
- }
1377
-
1378
- float all_sum = simd_sum(sumf);
1379
- if (tiisg == 0) {
1380
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1381
- }
1382
- }
1383
- } else {
1384
- device const half4 * x4 = (device const half4 *)x;
1385
- for (int row = 0; row < N_F16_F16; ++row) {
1386
- int r1 = rb + row;
1387
- if (r1 >= ne11) {
1388
- break;
1389
- }
1390
-
1391
- device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
1392
- device const half4 * y4 = (device const half4 *) y;
1393
-
1394
- float sumf = 0;
1395
- for (int i = tiisg; i < ne00/4; i += 32) {
1396
- for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
1397
- }
1398
-
1399
- float all_sum = simd_sum(sumf);
1400
- if (tiisg == 0) {
1401
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
1402
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1403
- }
1404
- }
1405
- }
1406
- }
1407
-
1408
- void kernel_mul_mv_f16_f32_1row_impl(
1355
+ template<typename T, typename T4>
1356
+ kernel void kernel_mul_mv_1row(
1409
1357
  device const char * src0,
1410
1358
  device const char * src1,
1411
1359
  device float * dst,
@@ -1437,7 +1385,7 @@ void kernel_mul_mv_f16_f32_1row_impl(
1437
1385
 
1438
1386
  const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1439
1387
 
1440
- device const half * x = (device const half *) (src0 + offset0);
1388
+ device const T * x = (device const T *) (src0 + offset0);
1441
1389
  device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
1442
1390
 
1443
1391
  float sumf = 0;
@@ -1450,153 +1398,29 @@ void kernel_mul_mv_f16_f32_1row_impl(
1450
1398
  dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1451
1399
  }
1452
1400
  } else {
1453
- device const half4 * x4 = (device const half4 *) x;
1401
+ device const T4 * x4 = (device const T4 *) x;
1454
1402
  device const float4 * y4 = (device const float4 *) y;
1403
+
1455
1404
  for (int i = tiisg; i < ne00/4; i += 32) {
1456
- for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
1405
+ for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
1457
1406
  }
1407
+
1458
1408
  float all_sum = simd_sum(sumf);
1409
+
1459
1410
  if (tiisg == 0) {
1460
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
1411
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
1461
1412
  dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1462
1413
  }
1463
1414
  }
1464
1415
  }
1465
1416
 
1466
- [[host_name("kernel_mul_mv_f16_f32_1row")]]
1467
- kernel void kernel_mul_mv_f16_f32_1row(
1468
- device const char * src0,
1469
- device const char * src1,
1470
- device float * dst,
1471
- constant int64_t & ne00,
1472
- constant int64_t & ne01,
1473
- constant int64_t & ne02,
1474
- constant uint64_t & nb00,
1475
- constant uint64_t & nb01,
1476
- constant uint64_t & nb02,
1477
- constant int64_t & ne10,
1478
- constant int64_t & ne11,
1479
- constant int64_t & ne12,
1480
- constant uint64_t & nb10,
1481
- constant uint64_t & nb11,
1482
- constant uint64_t & nb12,
1483
- constant int64_t & ne0,
1484
- constant int64_t & ne1,
1485
- constant uint & r2,
1486
- constant uint & r3,
1487
- uint3 tgpig[[threadgroup_position_in_grid]],
1488
- uint tiisg[[thread_index_in_simdgroup]]) {
1489
- kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1490
- }
1491
-
1492
- #define N_F16_F32 4
1417
+ typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
1493
1418
 
1494
- void kernel_mul_mv_f16_f32_impl(
1495
- device const char * src0,
1496
- device const char * src1,
1497
- device float * dst,
1498
- int64_t ne00,
1499
- int64_t ne01,
1500
- int64_t ne02,
1501
- uint64_t nb00,
1502
- uint64_t nb01,
1503
- uint64_t nb02,
1504
- int64_t ne10,
1505
- int64_t ne11,
1506
- int64_t ne12,
1507
- uint64_t nb10,
1508
- uint64_t nb11,
1509
- uint64_t nb12,
1510
- int64_t ne0,
1511
- int64_t ne1,
1512
- uint r2,
1513
- uint r3,
1514
- uint3 tgpig,
1515
- uint tiisg) {
1516
-
1517
- const int64_t r0 = tgpig.x;
1518
- const int64_t rb = tgpig.y*N_F16_F32;
1519
- const int64_t im = tgpig.z;
1520
-
1521
- const uint i12 = im%ne12;
1522
- const uint i13 = im/ne12;
1523
-
1524
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1525
-
1526
- device const half * x = (device const half *) (src0 + offset0);
1527
-
1528
- if (ne00 < 128) {
1529
- for (int row = 0; row < N_F16_F32; ++row) {
1530
- int r1 = rb + row;
1531
- if (r1 >= ne11) {
1532
- break;
1533
- }
1534
-
1535
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
1536
-
1537
- float sumf = 0;
1538
- for (int i = tiisg; i < ne00; i += 32) {
1539
- sumf += (float) x[i] * (float) y[i];
1540
- }
1541
-
1542
- float all_sum = simd_sum(sumf);
1543
- if (tiisg == 0) {
1544
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1545
- }
1546
- }
1547
- } else {
1548
- device const half4 * x4 = (device const half4 *)x;
1549
- for (int row = 0; row < N_F16_F32; ++row) {
1550
- int r1 = rb + row;
1551
- if (r1 >= ne11) {
1552
- break;
1553
- }
1554
-
1555
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
1556
- device const float4 * y4 = (device const float4 *) y;
1557
-
1558
- float sumf = 0;
1559
- for (int i = tiisg; i < ne00/4; i += 32) {
1560
- for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
1561
- }
1562
-
1563
- float all_sum = simd_sum(sumf);
1564
- if (tiisg == 0) {
1565
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
1566
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1567
- }
1568
- }
1569
- }
1570
- }
1571
-
1572
- [[host_name("kernel_mul_mv_f16_f32")]]
1573
- kernel void kernel_mul_mv_f16_f32(
1574
- device const char * src0,
1575
- device const char * src1,
1576
- device float * dst,
1577
- constant int64_t & ne00,
1578
- constant int64_t & ne01,
1579
- constant int64_t & ne02,
1580
- constant uint64_t & nb00,
1581
- constant uint64_t & nb01,
1582
- constant uint64_t & nb02,
1583
- constant int64_t & ne10,
1584
- constant int64_t & ne11,
1585
- constant int64_t & ne12,
1586
- constant uint64_t & nb10,
1587
- constant uint64_t & nb11,
1588
- constant uint64_t & nb12,
1589
- constant int64_t & ne0,
1590
- constant int64_t & ne1,
1591
- constant uint & r2,
1592
- constant uint & r3,
1593
- uint3 tgpig[[threadgroup_position_in_grid]],
1594
- uint tiisg[[thread_index_in_simdgroup]]) {
1595
- kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1596
- }
1419
+ template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
1597
1420
 
1598
1421
  // Assumes row size (ne00) is a multiple of 4
1599
- kernel void kernel_mul_mv_f16_f32_l4(
1422
+ template<typename T, typename T4>
1423
+ kernel void kernel_mul_mv_l4(
1600
1424
  device const char * src0,
1601
1425
  device const char * src1,
1602
1426
  device float * dst,
@@ -1628,14 +1452,14 @@ kernel void kernel_mul_mv_f16_f32_l4(
1628
1452
 
1629
1453
  const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1630
1454
 
1631
- device const half4 * x4 = (device const half4 *) (src0 + offset0);
1455
+ device const T4 * x4 = (device const T4 *) (src0 + offset0);
1632
1456
 
1633
1457
  for (int r1 = 0; r1 < nrows; ++r1) {
1634
1458
  device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
1635
1459
 
1636
1460
  float sumf = 0;
1637
1461
  for (int i = tiisg; i < ne00/4; i += 32) {
1638
- for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
1462
+ for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
1639
1463
  }
1640
1464
 
1641
1465
  float all_sum = simd_sum(sumf);
@@ -1645,6 +1469,10 @@ kernel void kernel_mul_mv_f16_f32_l4(
1645
1469
  }
1646
1470
  }
1647
1471
 
1472
+ typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
1473
+
1474
+ template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
1475
+
1648
1476
  static float rope_yarn_ramp(const float low, const float high, const int i0) {
1649
1477
  const float y = (i0 / 2 - low) / max(0.001f, high - low);
1650
1478
  return 1.0f - min(1.0f, max(0.0f, y));
@@ -2765,91 +2593,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
2765
2593
  template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
2766
2594
  //template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
2767
2595
 
2768
- kernel void kernel_cpy_f16_f16(
2769
- device const half * src0,
2770
- device half * dst,
2771
- constant int64_t & ne00,
2772
- constant int64_t & ne01,
2773
- constant int64_t & ne02,
2774
- constant int64_t & ne03,
2775
- constant uint64_t & nb00,
2776
- constant uint64_t & nb01,
2777
- constant uint64_t & nb02,
2778
- constant uint64_t & nb03,
2779
- constant int64_t & ne0,
2780
- constant int64_t & ne1,
2781
- constant int64_t & ne2,
2782
- constant int64_t & ne3,
2783
- constant uint64_t & nb0,
2784
- constant uint64_t & nb1,
2785
- constant uint64_t & nb2,
2786
- constant uint64_t & nb3,
2787
- uint3 tgpig[[threadgroup_position_in_grid]],
2788
- uint3 tpitg[[thread_position_in_threadgroup]],
2789
- uint3 ntg[[threads_per_threadgroup]]) {
2790
- const int64_t i03 = tgpig[2];
2791
- const int64_t i02 = tgpig[1];
2792
- const int64_t i01 = tgpig[0];
2793
-
2794
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2795
-
2796
- const int64_t i3 = n / (ne2*ne1*ne0);
2797
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2798
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2799
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2800
-
2801
- device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2802
-
2803
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2804
- device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2805
- dst_data[i00] = src[0];
2806
- }
2807
- }
2808
-
2809
- kernel void kernel_cpy_f16_f32(
2810
- device const half * src0,
2811
- device float * dst,
2812
- constant int64_t & ne00,
2813
- constant int64_t & ne01,
2814
- constant int64_t & ne02,
2815
- constant int64_t & ne03,
2816
- constant uint64_t & nb00,
2817
- constant uint64_t & nb01,
2818
- constant uint64_t & nb02,
2819
- constant uint64_t & nb03,
2820
- constant int64_t & ne0,
2821
- constant int64_t & ne1,
2822
- constant int64_t & ne2,
2823
- constant int64_t & ne3,
2824
- constant uint64_t & nb0,
2825
- constant uint64_t & nb1,
2826
- constant uint64_t & nb2,
2827
- constant uint64_t & nb3,
2828
- uint3 tgpig[[threadgroup_position_in_grid]],
2829
- uint3 tpitg[[thread_position_in_threadgroup]],
2830
- uint3 ntg[[threads_per_threadgroup]]) {
2831
- const int64_t i03 = tgpig[2];
2832
- const int64_t i02 = tgpig[1];
2833
- const int64_t i01 = tgpig[0];
2834
-
2835
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2836
-
2837
- const int64_t i3 = n / (ne2*ne1*ne0);
2838
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2839
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2840
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2841
-
2842
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2843
-
2844
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2845
- device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2846
- dst_data[i00] = src[0];
2847
- }
2848
- }
2849
-
2850
- kernel void kernel_cpy_f32_f16(
2851
- device const float * src0,
2852
- device half * dst,
2596
+ template<typename T0, typename T1>
2597
+ kernel void kernel_cpy(
2598
+ device const void * src0,
2599
+ device void * dst,
2853
2600
  constant int64_t & ne00,
2854
2601
  constant int64_t & ne01,
2855
2602
  constant int64_t & ne02,
@@ -2880,56 +2627,20 @@ kernel void kernel_cpy_f32_f16(
2880
2627
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2881
2628
  const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2882
2629
 
2883
- device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2630
+ device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2884
2631
 
2885
2632
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2886
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2887
-
2888
- dst_data[i00] = src[0];
2633
+ device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2634
+ dst_data[i00] = (T1) src[0];
2889
2635
  }
2890
2636
  }
2891
2637
 
2892
- kernel void kernel_cpy_f32_f32(
2893
- device const float * src0,
2894
- device float * dst,
2895
- constant int64_t & ne00,
2896
- constant int64_t & ne01,
2897
- constant int64_t & ne02,
2898
- constant int64_t & ne03,
2899
- constant uint64_t & nb00,
2900
- constant uint64_t & nb01,
2901
- constant uint64_t & nb02,
2902
- constant uint64_t & nb03,
2903
- constant int64_t & ne0,
2904
- constant int64_t & ne1,
2905
- constant int64_t & ne2,
2906
- constant int64_t & ne3,
2907
- constant uint64_t & nb0,
2908
- constant uint64_t & nb1,
2909
- constant uint64_t & nb2,
2910
- constant uint64_t & nb3,
2911
- uint3 tgpig[[threadgroup_position_in_grid]],
2912
- uint3 tpitg[[thread_position_in_threadgroup]],
2913
- uint3 ntg[[threads_per_threadgroup]]) {
2914
- const int64_t i03 = tgpig[2];
2915
- const int64_t i02 = tgpig[1];
2916
- const int64_t i01 = tgpig[0];
2917
-
2918
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2919
-
2920
- const int64_t i3 = n / (ne2*ne1*ne0);
2921
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2922
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2923
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2638
+ typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
2924
2639
 
2925
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2926
-
2927
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2928
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2929
-
2930
- dst_data[i00] = src[0];
2931
- }
2932
- }
2640
+ template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
2641
+ template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
2642
+ template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
2643
+ template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
2933
2644
 
2934
2645
  kernel void kernel_cpy_f32_q8_0(
2935
2646
  device const float * src0,
@@ -5046,7 +4757,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
5046
4757
  device const float4 * y4 = (device const float4 *)yb;
5047
4758
  yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
5048
4759
 
5049
- for (int row = 0; row < 2; ++row) {
4760
+ for (int row = 0; row < 2 && first_row + row < ne01; ++row) {
5050
4761
 
5051
4762
  device const block_iq4_nl & xb = x[row*nb + ib];
5052
4763
  device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
@@ -5078,7 +4789,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
5078
4789
  yb += 16 * QK4_NL;
5079
4790
  }
5080
4791
 
5081
- for (int row = 0; row < 2; ++row) {
4792
+ for (int row = 0; row < 2 && first_row + row < ne01; ++row) {
5082
4793
  all_sum = simd_sum(sumf[row]);
5083
4794
  if (tiisg == 0) {
5084
4795
  dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
@@ -5730,9 +5441,9 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
5730
5441
  }
5731
5442
 
5732
5443
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
5733
- kernel void kernel_get_rows(
5444
+ kernel void kernel_get_rows_q(
5734
5445
  device const void * src0,
5735
- device const char * src1,
5446
+ device const void * src1,
5736
5447
  device float * dst,
5737
5448
  constant int64_t & ne00,
5738
5449
  constant uint64_t & nb01,
@@ -5745,55 +5456,24 @@ kernel void kernel_get_rows(
5745
5456
  uint3 tgpig[[threadgroup_position_in_grid]],
5746
5457
  uint tiitg[[thread_index_in_threadgroup]],
5747
5458
  uint3 tptg [[threads_per_threadgroup]]) {
5748
- //const int64_t i = tgpig;
5749
- //const int64_t r = ((device int32_t *) src1)[i];
5750
-
5751
5459
  const int64_t i10 = tgpig.x;
5752
5460
  const int64_t i11 = tgpig.y;
5753
5461
 
5754
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
5462
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
5755
5463
 
5756
5464
  const int64_t i02 = i11;
5757
5465
 
5758
5466
  for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
5759
5467
  float4x4 temp;
5760
- dequantize_func(
5761
- ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
5468
+ dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
5762
5469
  *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
5763
5470
  }
5764
5471
  }
5765
5472
 
5766
- kernel void kernel_get_rows_f32(
5767
- device const void * src0,
5768
- device const char * src1,
5769
- device float * dst,
5770
- constant int64_t & ne00,
5771
- constant uint64_t & nb01,
5772
- constant uint64_t & nb02,
5773
- constant int64_t & ne10,
5774
- constant uint64_t & nb10,
5775
- constant uint64_t & nb11,
5776
- constant uint64_t & nb1,
5777
- constant uint64_t & nb2,
5778
- uint3 tgpig[[threadgroup_position_in_grid]],
5779
- uint tiitg[[thread_index_in_threadgroup]],
5780
- uint3 tptg [[threads_per_threadgroup]]) {
5781
- const int64_t i10 = tgpig.x;
5782
- const int64_t i11 = tgpig.y;
5783
-
5784
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
5785
-
5786
- const int64_t i02 = i11;
5787
-
5788
- for (int ind = tiitg; ind < ne00; ind += tptg.x) {
5789
- ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
5790
- ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
5791
- }
5792
- }
5793
-
5794
- kernel void kernel_get_rows_f16(
5473
+ template<typename T>
5474
+ kernel void kernel_get_rows_f(
5795
5475
  device const void * src0,
5796
- device const char * src1,
5476
+ device const void * src1,
5797
5477
  device float * dst,
5798
5478
  constant int64_t & ne00,
5799
5479
  constant uint64_t & nb01,
@@ -5809,19 +5489,19 @@ kernel void kernel_get_rows_f16(
5809
5489
  const int64_t i10 = tgpig.x;
5810
5490
  const int64_t i11 = tgpig.y;
5811
5491
 
5812
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
5492
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
5813
5493
 
5814
5494
  const int64_t i02 = i11;
5815
5495
 
5816
5496
  for (int ind = tiitg; ind < ne00; ind += tptg.x) {
5817
- ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
5818
- ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
5497
+ (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
5498
+ ((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
5819
5499
  }
5820
5500
  }
5821
5501
 
5822
5502
  kernel void kernel_get_rows_i32(
5823
5503
  device const void * src0,
5824
- device const char * src1,
5504
+ device const void * src1,
5825
5505
  device int32_t * dst,
5826
5506
  constant int64_t & ne00,
5827
5507
  constant uint64_t & nb01,
@@ -5837,13 +5517,13 @@ kernel void kernel_get_rows_i32(
5837
5517
  const int64_t i10 = tgpig.x;
5838
5518
  const int64_t i11 = tgpig.y;
5839
5519
 
5840
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
5520
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
5841
5521
 
5842
5522
  const int64_t i02 = i11;
5843
5523
 
5844
5524
  for (int ind = tiitg; ind < ne00; ind += tptg.x) {
5845
- ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
5846
- ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
5525
+ (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
5526
+ ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
5847
5527
  }
5848
5528
  }
5849
5529
 
@@ -5860,28 +5540,28 @@ kernel void kernel_get_rows_i32(
5860
5540
  #define SG_MAT_ROW 8
5861
5541
 
5862
5542
  // each block_q contains 16*nl weights
5863
- template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
5864
- void kernel_mul_mm_impl(device const uchar * src0,
5865
- device const uchar * src1,
5866
- device float * dst,
5867
- constant int64_t & ne00,
5868
- constant int64_t & ne02,
5869
- constant uint64_t & nb01,
5870
- constant uint64_t & nb02,
5871
- constant int64_t & ne12,
5872
- constant uint64_t & nb10,
5873
- constant uint64_t & nb11,
5874
- constant uint64_t & nb12,
5875
- constant int64_t & ne0,
5876
- constant int64_t & ne1,
5877
- constant uint & r2,
5878
- constant uint & r3,
5879
- threadgroup uchar * shared_memory [[threadgroup(0)]],
5880
- uint3 tgpig[[threadgroup_position_in_grid]],
5881
- uint tiitg[[thread_index_in_threadgroup]],
5882
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
5543
+ template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
5544
+ kernel void kernel_mul_mm(device const uchar * src0,
5545
+ device const uchar * src1,
5546
+ device float * dst,
5547
+ constant int64_t & ne00,
5548
+ constant int64_t & ne02,
5549
+ constant uint64_t & nb01,
5550
+ constant uint64_t & nb02,
5551
+ constant int64_t & ne12,
5552
+ constant uint64_t & nb10,
5553
+ constant uint64_t & nb11,
5554
+ constant uint64_t & nb12,
5555
+ constant int64_t & ne0,
5556
+ constant int64_t & ne1,
5557
+ constant uint & r2,
5558
+ constant uint & r3,
5559
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
5560
+ uint3 tgpig[[threadgroup_position_in_grid]],
5561
+ uint tiitg[[thread_index_in_threadgroup]],
5562
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5883
5563
 
5884
- threadgroup half * sa = (threadgroup half *)(shared_memory);
5564
+ threadgroup T * sa = (threadgroup T *)(shared_memory);
5885
5565
  threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
5886
5566
 
5887
5567
  const uint r0 = tgpig.y;
@@ -5896,7 +5576,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
5896
5576
  short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
5897
5577
  short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
5898
5578
 
5899
- simdgroup_half8x8 ma[4];
5579
+ simdgroup_T8x8 ma[4];
5900
5580
  simdgroup_float8x8 mb[2];
5901
5581
  simdgroup_float8x8 c_res[8];
5902
5582
  for (int i = 0; i < 8; i++){
@@ -5919,7 +5599,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
5919
5599
 
5920
5600
  for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
5921
5601
  // load data and store to threadgroup memory
5922
- half4x4 temp_a;
5602
+ T4x4 temp_a;
5923
5603
  dequantize_func(x, il, temp_a);
5924
5604
  threadgroup_barrier(mem_flags::mem_threadgroup);
5925
5605
 
@@ -5939,7 +5619,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
5939
5619
  threadgroup_barrier(mem_flags::mem_threadgroup);
5940
5620
 
5941
5621
  // load matrices from threadgroup memory and conduct outer products
5942
- threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
5622
+ threadgroup T * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
5943
5623
  threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
5944
5624
 
5945
5625
  #pragma unroll(4)
@@ -6115,48 +5795,6 @@ void kernel_mul_mm_id_impl(
6115
5795
  }
6116
5796
  }
6117
5797
 
6118
- template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
6119
- kernel void kernel_mul_mm(device const uchar * src0,
6120
- device const uchar * src1,
6121
- device float * dst,
6122
- constant int64_t & ne00,
6123
- constant int64_t & ne02,
6124
- constant uint64_t & nb01,
6125
- constant uint64_t & nb02,
6126
- constant int64_t & ne12,
6127
- constant uint64_t & nb10,
6128
- constant uint64_t & nb11,
6129
- constant uint64_t & nb12,
6130
- constant int64_t & ne0,
6131
- constant int64_t & ne1,
6132
- constant uint & r2,
6133
- constant uint & r3,
6134
- threadgroup uchar * shared_memory [[threadgroup(0)]],
6135
- uint3 tgpig[[threadgroup_position_in_grid]],
6136
- uint tiitg[[thread_index_in_threadgroup]],
6137
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6138
- kernel_mul_mm_impl<block_q, nl, dequantize_func>(
6139
- src0,
6140
- src1,
6141
- dst,
6142
- ne00,
6143
- ne02,
6144
- nb01,
6145
- nb02,
6146
- ne12,
6147
- nb10,
6148
- nb11,
6149
- nb12,
6150
- ne0,
6151
- ne1,
6152
- r2,
6153
- r3,
6154
- shared_memory,
6155
- tgpig,
6156
- tiitg,
6157
- sgitg);
6158
- }
6159
-
6160
5798
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
6161
5799
  kernel void kernel_mul_mm_id(
6162
5800
  device const uchar * src0s,
@@ -6237,69 +5875,60 @@ kernel void kernel_mul_mm_id(
6237
5875
  // get rows
6238
5876
  //
6239
5877
 
6240
- typedef void (get_rows_t)(
6241
- device const void * src0,
6242
- device const char * src1,
6243
- device float * dst,
6244
- constant int64_t & ne00,
6245
- constant uint64_t & nb01,
6246
- constant uint64_t & nb02,
6247
- constant int64_t & ne10,
6248
- constant uint64_t & nb10,
6249
- constant uint64_t & nb11,
6250
- constant uint64_t & nb1,
6251
- constant uint64_t & nb2,
6252
- uint3, uint, uint3);
6253
-
6254
- //template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
6255
- //template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
6256
- template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
6257
- template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
6258
- template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
6259
- template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
6260
- template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
6261
- template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
6262
- template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
6263
- template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
6264
- template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
6265
- template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
6266
- template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
6267
- template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
6268
- template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6269
- template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
6270
- template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
6271
- template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
6272
- template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
6273
- template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
6274
- template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
5878
+ typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
5879
+
5880
+ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
5881
+ template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
5882
+
5883
+ typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
5884
+
5885
+ template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
5886
+ template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
5887
+ template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
5888
+ template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
5889
+ template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
5890
+ template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
5891
+ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
5892
+ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
5893
+ template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
5894
+ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
5895
+ template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5896
+ template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5897
+ template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
5898
+ template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
5899
+ template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
5900
+ template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
5901
+ template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
5902
+ template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
5903
+ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6275
5904
 
6276
5905
  //
6277
5906
  // matrix-matrix multiplication
6278
5907
  //
6279
5908
 
6280
- typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t;
6281
-
6282
- template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
6283
- template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
6284
- template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
6285
- template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
6286
- template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
6287
- template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
6288
- template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
6289
- template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
6290
- template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
6291
- template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
6292
- template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
6293
- template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
6294
- template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
6295
- template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
6296
- template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6297
- template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
6298
- template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
6299
- template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
6300
- template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
6301
- template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
6302
- template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
5909
+ typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
5910
+
5911
+ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
5912
+ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
5913
+ template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
5914
+ template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
5915
+ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
5916
+ template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
5917
+ template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
5918
+ template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
5919
+ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
5920
+ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
5921
+ template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
5922
+ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
5923
+ template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5924
+ template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5925
+ template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
5926
+ template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
5927
+ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
5928
+ template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
5929
+ template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
5930
+ template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
5931
+ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6303
5932
 
6304
5933
  //
6305
5934
  // indirect matrix-matrix multiplication
@@ -6436,7 +6065,7 @@ void mmv_fn(
6436
6065
  impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
6437
6066
  }
6438
6067
 
6439
- typedef decltype(mmv_fn<kernel_mul_mv_f32_f32_impl>) mul_mv_impl_fn_t;
6068
+ typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
6440
6069
 
6441
6070
  template<mul_mv_impl_fn_t impl_fn>
6442
6071
  kernel void kernel_mul_mv_id(
@@ -6514,20 +6143,20 @@ kernel void kernel_mul_mv_id(
6514
6143
  sgitg);
6515
6144
  }
6516
6145
 
6517
- typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>) kernel_mul_mv_id_t;
6518
-
6519
- template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
6520
- template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
6521
- template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
6522
- template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6523
- template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6524
- template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6525
- template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6526
- template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
6527
- template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
6528
- template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
6529
- template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
6530
- template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
6146
+ typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;
6147
+
6148
+ template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
6149
+ template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
6150
+ template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
6151
+ template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6152
+ template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6153
+ template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6154
+ template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6155
+ template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
6156
+ template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
6157
+ template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
6158
+ template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
6159
+ template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
6531
6160
  template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
6532
6161
  template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
6533
6162
  template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;