node-llama-cpp 3.0.0-beta.37 → 3.0.0-beta.39
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bins/linux-arm64/_nlcBuildMetadata.json +1 -1
- package/bins/linux-arm64/libggml.so +0 -0
- package/bins/linux-arm64/libllama.so +0 -0
- package/bins/linux-arm64/llama-addon.node +0 -0
- package/bins/linux-armv7l/_nlcBuildMetadata.json +1 -1
- package/bins/linux-armv7l/libggml.so +0 -0
- package/bins/linux-armv7l/libllama.so +0 -0
- package/bins/linux-armv7l/llama-addon.node +0 -0
- package/bins/linux-x64/_nlcBuildMetadata.json +1 -1
- package/bins/linux-x64/libggml.so +0 -0
- package/bins/linux-x64/libllama.so +0 -0
- package/bins/linux-x64/llama-addon.node +0 -0
- package/bins/linux-x64-vulkan/_nlcBuildMetadata.json +1 -1
- package/bins/linux-x64-vulkan/libggml.so +0 -0
- package/bins/linux-x64-vulkan/libllama.so +0 -0
- package/bins/linux-x64-vulkan/llama-addon.node +0 -0
- package/bins/linux-x64-vulkan/vulkan-shaders-gen +0 -0
- package/bins/mac-arm64-metal/_nlcBuildMetadata.json +1 -1
- package/bins/mac-arm64-metal/ggml-common.h +24 -0
- package/bins/mac-arm64-metal/ggml-metal.metal +181 -552
- package/bins/mac-arm64-metal/libggml.dylib +0 -0
- package/bins/mac-arm64-metal/libllama.dylib +0 -0
- package/bins/mac-arm64-metal/llama-addon.node +0 -0
- package/bins/mac-x64/_nlcBuildMetadata.json +1 -1
- package/bins/mac-x64/libggml.dylib +0 -0
- package/bins/mac-x64/libllama.dylib +0 -0
- package/bins/mac-x64/llama-addon.node +0 -0
- package/bins/win-arm64/_nlcBuildMetadata.json +1 -1
- package/bins/win-arm64/ggml.dll +0 -0
- package/bins/win-arm64/llama-addon.exp +0 -0
- package/bins/win-arm64/llama-addon.lib +0 -0
- package/bins/win-arm64/llama-addon.node +0 -0
- package/bins/win-arm64/llama.dll +0 -0
- package/bins/win-x64/_nlcBuildMetadata.json +1 -1
- package/bins/win-x64/ggml.dll +0 -0
- package/bins/win-x64/llama-addon.node +0 -0
- package/bins/win-x64/llama.dll +0 -0
- package/bins/win-x64-vulkan/_nlcBuildMetadata.json +1 -1
- package/bins/win-x64-vulkan/ggml.dll +0 -0
- package/bins/win-x64-vulkan/llama-addon.node +0 -0
- package/bins/win-x64-vulkan/llama.dll +0 -0
- package/bins/win-x64-vulkan/vulkan-shaders-gen.exe +0 -0
- package/dist/ChatWrapper.d.ts +2 -1
- package/dist/ChatWrapper.js +19 -5
- package/dist/ChatWrapper.js.map +1 -1
- package/dist/bindings/AddonTypes.d.ts +13 -2
- package/dist/bindings/getLlama.d.ts +3 -2
- package/dist/bindings/getLlama.js +1 -1
- package/dist/bindings/getLlama.js.map +1 -1
- package/dist/chatWrappers/FunctionaryChatWrapper.js +8 -5
- package/dist/chatWrappers/FunctionaryChatWrapper.js.map +1 -1
- package/dist/chatWrappers/GemmaChatWrapper.js +1 -1
- package/dist/chatWrappers/GemmaChatWrapper.js.map +1 -1
- package/dist/chatWrappers/Llama3ChatWrapper.js +5 -6
- package/dist/chatWrappers/Llama3ChatWrapper.js.map +1 -1
- package/dist/chatWrappers/Llama3_1ChatWrapper.d.ts +31 -0
- package/dist/chatWrappers/Llama3_1ChatWrapper.js +223 -0
- package/dist/chatWrappers/Llama3_1ChatWrapper.js.map +1 -0
- package/dist/chatWrappers/generic/JinjaTemplateChatWrapper.d.ts +9 -0
- package/dist/chatWrappers/generic/JinjaTemplateChatWrapper.js.map +1 -1
- package/dist/chatWrappers/utils/ChatModelFunctionsDocumentationGenerator.d.ts +17 -2
- package/dist/chatWrappers/utils/ChatModelFunctionsDocumentationGenerator.js +39 -2
- package/dist/chatWrappers/utils/ChatModelFunctionsDocumentationGenerator.js.map +1 -1
- package/dist/chatWrappers/utils/jsonDumps.d.ts +7 -0
- package/dist/chatWrappers/utils/jsonDumps.js +18 -0
- package/dist/chatWrappers/utils/jsonDumps.js.map +1 -0
- package/dist/chatWrappers/utils/resolveChatWrapper.d.ts +5 -3
- package/dist/chatWrappers/utils/resolveChatWrapper.js +50 -4
- package/dist/chatWrappers/utils/resolveChatWrapper.js.map +1 -1
- package/dist/cli/commands/ChatCommand.d.ts +1 -1
- package/dist/cli/commands/ChatCommand.js +5 -5
- package/dist/cli/commands/ChatCommand.js.map +1 -1
- package/dist/cli/commands/CompleteCommand.js +5 -3
- package/dist/cli/commands/CompleteCommand.js.map +1 -1
- package/dist/cli/commands/InfillCommand.js +5 -3
- package/dist/cli/commands/InfillCommand.js.map +1 -1
- package/dist/cli/recommendedModels.js +43 -24
- package/dist/cli/recommendedModels.js.map +1 -1
- package/dist/cli/utils/interactivelyAskForModel.d.ts +2 -1
- package/dist/cli/utils/interactivelyAskForModel.js +19 -9
- package/dist/cli/utils/interactivelyAskForModel.js.map +1 -1
- package/dist/cli/utils/resolveCommandGgufPath.d.ts +2 -1
- package/dist/cli/utils/resolveCommandGgufPath.js +3 -2
- package/dist/cli/utils/resolveCommandGgufPath.js.map +1 -1
- package/dist/consts.d.ts +1 -0
- package/dist/consts.js +1 -0
- package/dist/consts.js.map +1 -1
- package/dist/evaluator/LlamaChat/LlamaChat.d.ts +22 -0
- package/dist/evaluator/LlamaChat/LlamaChat.js +65 -34
- package/dist/evaluator/LlamaChat/LlamaChat.js.map +1 -1
- package/dist/evaluator/LlamaChatSession/LlamaChatSession.d.ts +28 -6
- package/dist/evaluator/LlamaChatSession/LlamaChatSession.js +22 -16
- package/dist/evaluator/LlamaChatSession/LlamaChatSession.js.map +1 -1
- package/dist/evaluator/LlamaChatSession/utils/LlamaChatSessionPromptCompletionEngine.js +4 -5
- package/dist/evaluator/LlamaChatSession/utils/LlamaChatSessionPromptCompletionEngine.js.map +1 -1
- package/dist/evaluator/LlamaCompletion.d.ts +13 -2
- package/dist/evaluator/LlamaCompletion.js +10 -5
- package/dist/evaluator/LlamaCompletion.js.map +1 -1
- package/dist/evaluator/LlamaContext/LlamaContext.d.ts +1 -1
- package/dist/evaluator/LlamaContext/LlamaContext.js +60 -0
- package/dist/evaluator/LlamaContext/LlamaContext.js.map +1 -1
- package/dist/evaluator/LlamaContext/types.d.ts +21 -0
- package/dist/evaluator/LlamaGrammar.d.ts +6 -3
- package/dist/evaluator/LlamaGrammar.js +2 -2
- package/dist/evaluator/LlamaGrammar.js.map +1 -1
- package/dist/evaluator/LlamaModel/LlamaModel.d.ts +16 -32
- package/dist/evaluator/LlamaModel/LlamaModel.js +94 -53
- package/dist/evaluator/LlamaModel/LlamaModel.js.map +1 -1
- package/dist/gguf/consts.d.ts +1 -0
- package/dist/gguf/consts.js +4 -0
- package/dist/gguf/consts.js.map +1 -1
- package/dist/gguf/insights/GgufInsights.js +4 -0
- package/dist/gguf/insights/GgufInsights.js.map +1 -1
- package/dist/gguf/parser/GgufV2Parser.js +3 -1
- package/dist/gguf/parser/GgufV2Parser.js.map +1 -1
- package/dist/gguf/types/GgufMetadataTypes.d.ts +16 -0
- package/dist/gguf/types/GgufMetadataTypes.js.map +1 -1
- package/dist/gguf/utils/convertMetadataKeyValueRecordToNestedObject.d.ts +3 -2
- package/dist/gguf/utils/convertMetadataKeyValueRecordToNestedObject.js +44 -8
- package/dist/gguf/utils/convertMetadataKeyValueRecordToNestedObject.js.map +1 -1
- package/dist/index.d.ts +4 -2
- package/dist/index.js +3 -1
- package/dist/index.js.map +1 -1
- package/dist/types.d.ts +15 -1
- package/dist/types.js.map +1 -1
- package/dist/utils/DeepPartialObject.d.ts +3 -0
- package/dist/utils/DeepPartialObject.js +2 -0
- package/dist/utils/DeepPartialObject.js.map +1 -0
- package/dist/utils/StopGenerationDetector.d.ts +6 -3
- package/dist/utils/StopGenerationDetector.js +22 -7
- package/dist/utils/StopGenerationDetector.js.map +1 -1
- package/dist/utils/TokenStreamRegulator.d.ts +1 -0
- package/dist/utils/TokenStreamRegulator.js +23 -5
- package/dist/utils/TokenStreamRegulator.js.map +1 -1
- package/dist/utils/resolveLastTokens.d.ts +2 -0
- package/dist/utils/resolveLastTokens.js +12 -0
- package/dist/utils/resolveLastTokens.js.map +1 -0
- package/llama/CMakeLists.txt +1 -1
- package/llama/addon/AddonContext.cpp +772 -0
- package/llama/addon/AddonContext.h +53 -0
- package/llama/addon/AddonGrammar.cpp +44 -0
- package/llama/addon/AddonGrammar.h +18 -0
- package/llama/addon/AddonGrammarEvaluationState.cpp +28 -0
- package/llama/addon/AddonGrammarEvaluationState.h +15 -0
- package/llama/addon/AddonModel.cpp +681 -0
- package/llama/addon/AddonModel.h +61 -0
- package/llama/addon/AddonModelData.cpp +25 -0
- package/llama/addon/AddonModelData.h +15 -0
- package/llama/addon/AddonModelLora.cpp +107 -0
- package/llama/addon/AddonModelLora.h +28 -0
- package/llama/addon/addon.cpp +217 -0
- package/llama/addon/addonGlobals.cpp +22 -0
- package/llama/addon/addonGlobals.h +12 -0
- package/llama/addon/globals/addonLog.cpp +135 -0
- package/llama/addon/globals/addonLog.h +21 -0
- package/llama/addon/globals/addonProgress.cpp +15 -0
- package/llama/addon/globals/addonProgress.h +15 -0
- package/llama/addon/globals/getGpuInfo.cpp +108 -0
- package/llama/addon/globals/getGpuInfo.h +6 -0
- package/llama/binariesGithubRelease.json +1 -1
- package/llama/gitRelease.bundle +0 -0
- package/llama/grammars/README.md +1 -1
- package/llama/llama.cpp.info.json +1 -1
- package/package.json +3 -3
- package/templates/packed/electron-typescript-react.json +1 -1
- package/templates/packed/node-typescript.json +1 -1
- package/llama/addon.cpp +0 -2014
|
@@ -1219,9 +1219,10 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
|
1219
1219
|
kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
|
1220
1220
|
}
|
|
1221
1221
|
|
|
1222
|
-
#define
|
|
1222
|
+
#define N_MV_T_T 4
|
|
1223
1223
|
|
|
1224
|
-
|
|
1224
|
+
template<typename T0, typename T04, typename T1, typename T14>
|
|
1225
|
+
void kernel_mul_mv_impl(
|
|
1225
1226
|
device const char * src0,
|
|
1226
1227
|
device const char * src1,
|
|
1227
1228
|
device float * dst,
|
|
@@ -1239,13 +1240,12 @@ void kernel_mul_mv_f32_f32_impl(
|
|
|
1239
1240
|
uint64_t nb12,
|
|
1240
1241
|
int64_t ne0,
|
|
1241
1242
|
int64_t ne1,
|
|
1242
|
-
|
|
1243
|
-
|
|
1244
|
-
|
|
1245
|
-
|
|
1246
|
-
|
|
1243
|
+
uint r2,
|
|
1244
|
+
uint r3,
|
|
1245
|
+
uint3 tgpig,
|
|
1246
|
+
uint tiisg) {
|
|
1247
1247
|
const int64_t r0 = tgpig.x;
|
|
1248
|
-
const int64_t rb = tgpig.y*
|
|
1248
|
+
const int64_t rb = tgpig.y*N_MV_T_T;
|
|
1249
1249
|
const int64_t im = tgpig.z;
|
|
1250
1250
|
|
|
1251
1251
|
const uint i12 = im%ne12;
|
|
@@ -1253,20 +1253,20 @@ void kernel_mul_mv_f32_f32_impl(
|
|
|
1253
1253
|
|
|
1254
1254
|
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
|
1255
1255
|
|
|
1256
|
-
device const
|
|
1256
|
+
device const T0 * x = (device const T0 *) (src0 + offset0);
|
|
1257
1257
|
|
|
1258
1258
|
if (ne00 < 128) {
|
|
1259
|
-
for (int row = 0; row <
|
|
1259
|
+
for (int row = 0; row < N_MV_T_T; ++row) {
|
|
1260
1260
|
int r1 = rb + row;
|
|
1261
1261
|
if (r1 >= ne11) {
|
|
1262
1262
|
break;
|
|
1263
1263
|
}
|
|
1264
1264
|
|
|
1265
|
-
device const
|
|
1265
|
+
device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
|
|
1266
1266
|
|
|
1267
1267
|
float sumf = 0;
|
|
1268
1268
|
for (int i = tiisg; i < ne00; i += 32) {
|
|
1269
|
-
sumf += (
|
|
1269
|
+
sumf += (T0) x[i] * (T1) y[i];
|
|
1270
1270
|
}
|
|
1271
1271
|
|
|
1272
1272
|
float all_sum = simd_sum(sumf);
|
|
@@ -1275,32 +1275,32 @@ void kernel_mul_mv_f32_f32_impl(
|
|
|
1275
1275
|
}
|
|
1276
1276
|
}
|
|
1277
1277
|
} else {
|
|
1278
|
-
device const
|
|
1279
|
-
for (int row = 0; row <
|
|
1278
|
+
device const T04 * x4 = (device const T04 *) x;
|
|
1279
|
+
for (int row = 0; row < N_MV_T_T; ++row) {
|
|
1280
1280
|
int r1 = rb + row;
|
|
1281
1281
|
if (r1 >= ne11) {
|
|
1282
1282
|
break;
|
|
1283
1283
|
}
|
|
1284
1284
|
|
|
1285
|
-
device const
|
|
1286
|
-
device const
|
|
1285
|
+
device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
|
|
1286
|
+
device const T14 * y4 = (device const T14 *) y;
|
|
1287
1287
|
|
|
1288
1288
|
float sumf = 0;
|
|
1289
1289
|
for (int i = tiisg; i < ne00/4; i += 32) {
|
|
1290
|
-
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
|
1290
|
+
for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
|
|
1291
1291
|
}
|
|
1292
1292
|
|
|
1293
1293
|
float all_sum = simd_sum(sumf);
|
|
1294
1294
|
if (tiisg == 0) {
|
|
1295
|
-
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
|
1295
|
+
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
|
|
1296
1296
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
1297
1297
|
}
|
|
1298
1298
|
}
|
|
1299
1299
|
}
|
|
1300
1300
|
}
|
|
1301
1301
|
|
|
1302
|
-
|
|
1303
|
-
kernel void
|
|
1302
|
+
template<typename T0, typename T04, typename T1, typename T14>
|
|
1303
|
+
kernel void kernel_mul_mv(
|
|
1304
1304
|
device const char * src0,
|
|
1305
1305
|
device const char * src1,
|
|
1306
1306
|
device float * dst,
|
|
@@ -1322,90 +1322,38 @@ kernel void kernel_mul_mv_f32_f32(
|
|
|
1322
1322
|
constant uint & r3,
|
|
1323
1323
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1324
1324
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1325
|
-
|
|
1325
|
+
kernel_mul_mv_impl<T0, T04, T1, T14>(
|
|
1326
|
+
src0,
|
|
1327
|
+
src1,
|
|
1328
|
+
dst,
|
|
1329
|
+
ne00,
|
|
1330
|
+
ne01,
|
|
1331
|
+
ne02,
|
|
1332
|
+
nb00,
|
|
1333
|
+
nb01,
|
|
1334
|
+
nb02,
|
|
1335
|
+
ne10,
|
|
1336
|
+
ne11,
|
|
1337
|
+
ne12,
|
|
1338
|
+
nb10,
|
|
1339
|
+
nb11,
|
|
1340
|
+
nb12,
|
|
1341
|
+
ne0,
|
|
1342
|
+
ne1,
|
|
1343
|
+
r2,
|
|
1344
|
+
r3,
|
|
1345
|
+
tgpig,
|
|
1346
|
+
tiisg);
|
|
1326
1347
|
}
|
|
1327
1348
|
|
|
1328
|
-
|
|
1349
|
+
typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
|
|
1329
1350
|
|
|
1330
|
-
kernel
|
|
1331
|
-
|
|
1332
|
-
|
|
1333
|
-
device float * dst,
|
|
1334
|
-
constant int64_t & ne00,
|
|
1335
|
-
constant int64_t & ne01,
|
|
1336
|
-
constant int64_t & ne02,
|
|
1337
|
-
constant uint64_t & nb00,
|
|
1338
|
-
constant uint64_t & nb01,
|
|
1339
|
-
constant uint64_t & nb02,
|
|
1340
|
-
constant int64_t & ne10,
|
|
1341
|
-
constant int64_t & ne11,
|
|
1342
|
-
constant int64_t & ne12,
|
|
1343
|
-
constant uint64_t & nb10,
|
|
1344
|
-
constant uint64_t & nb11,
|
|
1345
|
-
constant uint64_t & nb12,
|
|
1346
|
-
constant int64_t & ne0,
|
|
1347
|
-
constant int64_t & ne1,
|
|
1348
|
-
constant uint & r2,
|
|
1349
|
-
constant uint & r3,
|
|
1350
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1351
|
-
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1351
|
+
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
|
|
1352
|
+
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
|
|
1353
|
+
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
|
|
1352
1354
|
|
|
1353
|
-
|
|
1354
|
-
|
|
1355
|
-
const int64_t im = tgpig.z;
|
|
1356
|
-
|
|
1357
|
-
const uint i12 = im%ne12;
|
|
1358
|
-
const uint i13 = im/ne12;
|
|
1359
|
-
|
|
1360
|
-
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
|
1361
|
-
|
|
1362
|
-
device const half * x = (device const half *) (src0 + offset0);
|
|
1363
|
-
|
|
1364
|
-
if (ne00 < 128) {
|
|
1365
|
-
for (int row = 0; row < N_F16_F16; ++row) {
|
|
1366
|
-
int r1 = rb + row;
|
|
1367
|
-
if (r1 >= ne11) {
|
|
1368
|
-
break;
|
|
1369
|
-
}
|
|
1370
|
-
|
|
1371
|
-
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
|
|
1372
|
-
|
|
1373
|
-
float sumf = 0;
|
|
1374
|
-
for (int i = tiisg; i < ne00; i += 32) {
|
|
1375
|
-
sumf += (half) x[i] * (half) y[i];
|
|
1376
|
-
}
|
|
1377
|
-
|
|
1378
|
-
float all_sum = simd_sum(sumf);
|
|
1379
|
-
if (tiisg == 0) {
|
|
1380
|
-
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
1381
|
-
}
|
|
1382
|
-
}
|
|
1383
|
-
} else {
|
|
1384
|
-
device const half4 * x4 = (device const half4 *)x;
|
|
1385
|
-
for (int row = 0; row < N_F16_F16; ++row) {
|
|
1386
|
-
int r1 = rb + row;
|
|
1387
|
-
if (r1 >= ne11) {
|
|
1388
|
-
break;
|
|
1389
|
-
}
|
|
1390
|
-
|
|
1391
|
-
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
|
|
1392
|
-
device const half4 * y4 = (device const half4 *) y;
|
|
1393
|
-
|
|
1394
|
-
float sumf = 0;
|
|
1395
|
-
for (int i = tiisg; i < ne00/4; i += 32) {
|
|
1396
|
-
for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
|
|
1397
|
-
}
|
|
1398
|
-
|
|
1399
|
-
float all_sum = simd_sum(sumf);
|
|
1400
|
-
if (tiisg == 0) {
|
|
1401
|
-
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
|
|
1402
|
-
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
1403
|
-
}
|
|
1404
|
-
}
|
|
1405
|
-
}
|
|
1406
|
-
}
|
|
1407
|
-
|
|
1408
|
-
void kernel_mul_mv_f16_f32_1row_impl(
|
|
1355
|
+
template<typename T, typename T4>
|
|
1356
|
+
kernel void kernel_mul_mv_1row(
|
|
1409
1357
|
device const char * src0,
|
|
1410
1358
|
device const char * src1,
|
|
1411
1359
|
device float * dst,
|
|
@@ -1437,7 +1385,7 @@ void kernel_mul_mv_f16_f32_1row_impl(
|
|
|
1437
1385
|
|
|
1438
1386
|
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
|
1439
1387
|
|
|
1440
|
-
device const
|
|
1388
|
+
device const T * x = (device const T *) (src0 + offset0);
|
|
1441
1389
|
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
|
1442
1390
|
|
|
1443
1391
|
float sumf = 0;
|
|
@@ -1450,153 +1398,29 @@ void kernel_mul_mv_f16_f32_1row_impl(
|
|
|
1450
1398
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
1451
1399
|
}
|
|
1452
1400
|
} else {
|
|
1453
|
-
device const
|
|
1401
|
+
device const T4 * x4 = (device const T4 *) x;
|
|
1454
1402
|
device const float4 * y4 = (device const float4 *) y;
|
|
1403
|
+
|
|
1455
1404
|
for (int i = tiisg; i < ne00/4; i += 32) {
|
|
1456
|
-
for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
|
|
1405
|
+
for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
|
|
1457
1406
|
}
|
|
1407
|
+
|
|
1458
1408
|
float all_sum = simd_sum(sumf);
|
|
1409
|
+
|
|
1459
1410
|
if (tiisg == 0) {
|
|
1460
|
-
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
|
1411
|
+
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
|
|
1461
1412
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
1462
1413
|
}
|
|
1463
1414
|
}
|
|
1464
1415
|
}
|
|
1465
1416
|
|
|
1466
|
-
|
|
1467
|
-
kernel void kernel_mul_mv_f16_f32_1row(
|
|
1468
|
-
device const char * src0,
|
|
1469
|
-
device const char * src1,
|
|
1470
|
-
device float * dst,
|
|
1471
|
-
constant int64_t & ne00,
|
|
1472
|
-
constant int64_t & ne01,
|
|
1473
|
-
constant int64_t & ne02,
|
|
1474
|
-
constant uint64_t & nb00,
|
|
1475
|
-
constant uint64_t & nb01,
|
|
1476
|
-
constant uint64_t & nb02,
|
|
1477
|
-
constant int64_t & ne10,
|
|
1478
|
-
constant int64_t & ne11,
|
|
1479
|
-
constant int64_t & ne12,
|
|
1480
|
-
constant uint64_t & nb10,
|
|
1481
|
-
constant uint64_t & nb11,
|
|
1482
|
-
constant uint64_t & nb12,
|
|
1483
|
-
constant int64_t & ne0,
|
|
1484
|
-
constant int64_t & ne1,
|
|
1485
|
-
constant uint & r2,
|
|
1486
|
-
constant uint & r3,
|
|
1487
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1488
|
-
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1489
|
-
kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
|
1490
|
-
}
|
|
1491
|
-
|
|
1492
|
-
#define N_F16_F32 4
|
|
1417
|
+
typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
|
|
1493
1418
|
|
|
1494
|
-
|
|
1495
|
-
device const char * src0,
|
|
1496
|
-
device const char * src1,
|
|
1497
|
-
device float * dst,
|
|
1498
|
-
int64_t ne00,
|
|
1499
|
-
int64_t ne01,
|
|
1500
|
-
int64_t ne02,
|
|
1501
|
-
uint64_t nb00,
|
|
1502
|
-
uint64_t nb01,
|
|
1503
|
-
uint64_t nb02,
|
|
1504
|
-
int64_t ne10,
|
|
1505
|
-
int64_t ne11,
|
|
1506
|
-
int64_t ne12,
|
|
1507
|
-
uint64_t nb10,
|
|
1508
|
-
uint64_t nb11,
|
|
1509
|
-
uint64_t nb12,
|
|
1510
|
-
int64_t ne0,
|
|
1511
|
-
int64_t ne1,
|
|
1512
|
-
uint r2,
|
|
1513
|
-
uint r3,
|
|
1514
|
-
uint3 tgpig,
|
|
1515
|
-
uint tiisg) {
|
|
1516
|
-
|
|
1517
|
-
const int64_t r0 = tgpig.x;
|
|
1518
|
-
const int64_t rb = tgpig.y*N_F16_F32;
|
|
1519
|
-
const int64_t im = tgpig.z;
|
|
1520
|
-
|
|
1521
|
-
const uint i12 = im%ne12;
|
|
1522
|
-
const uint i13 = im/ne12;
|
|
1523
|
-
|
|
1524
|
-
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
|
1525
|
-
|
|
1526
|
-
device const half * x = (device const half *) (src0 + offset0);
|
|
1527
|
-
|
|
1528
|
-
if (ne00 < 128) {
|
|
1529
|
-
for (int row = 0; row < N_F16_F32; ++row) {
|
|
1530
|
-
int r1 = rb + row;
|
|
1531
|
-
if (r1 >= ne11) {
|
|
1532
|
-
break;
|
|
1533
|
-
}
|
|
1534
|
-
|
|
1535
|
-
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
|
1536
|
-
|
|
1537
|
-
float sumf = 0;
|
|
1538
|
-
for (int i = tiisg; i < ne00; i += 32) {
|
|
1539
|
-
sumf += (float) x[i] * (float) y[i];
|
|
1540
|
-
}
|
|
1541
|
-
|
|
1542
|
-
float all_sum = simd_sum(sumf);
|
|
1543
|
-
if (tiisg == 0) {
|
|
1544
|
-
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
1545
|
-
}
|
|
1546
|
-
}
|
|
1547
|
-
} else {
|
|
1548
|
-
device const half4 * x4 = (device const half4 *)x;
|
|
1549
|
-
for (int row = 0; row < N_F16_F32; ++row) {
|
|
1550
|
-
int r1 = rb + row;
|
|
1551
|
-
if (r1 >= ne11) {
|
|
1552
|
-
break;
|
|
1553
|
-
}
|
|
1554
|
-
|
|
1555
|
-
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
|
1556
|
-
device const float4 * y4 = (device const float4 *) y;
|
|
1557
|
-
|
|
1558
|
-
float sumf = 0;
|
|
1559
|
-
for (int i = tiisg; i < ne00/4; i += 32) {
|
|
1560
|
-
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
|
1561
|
-
}
|
|
1562
|
-
|
|
1563
|
-
float all_sum = simd_sum(sumf);
|
|
1564
|
-
if (tiisg == 0) {
|
|
1565
|
-
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
|
1566
|
-
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
1567
|
-
}
|
|
1568
|
-
}
|
|
1569
|
-
}
|
|
1570
|
-
}
|
|
1571
|
-
|
|
1572
|
-
[[host_name("kernel_mul_mv_f16_f32")]]
|
|
1573
|
-
kernel void kernel_mul_mv_f16_f32(
|
|
1574
|
-
device const char * src0,
|
|
1575
|
-
device const char * src1,
|
|
1576
|
-
device float * dst,
|
|
1577
|
-
constant int64_t & ne00,
|
|
1578
|
-
constant int64_t & ne01,
|
|
1579
|
-
constant int64_t & ne02,
|
|
1580
|
-
constant uint64_t & nb00,
|
|
1581
|
-
constant uint64_t & nb01,
|
|
1582
|
-
constant uint64_t & nb02,
|
|
1583
|
-
constant int64_t & ne10,
|
|
1584
|
-
constant int64_t & ne11,
|
|
1585
|
-
constant int64_t & ne12,
|
|
1586
|
-
constant uint64_t & nb10,
|
|
1587
|
-
constant uint64_t & nb11,
|
|
1588
|
-
constant uint64_t & nb12,
|
|
1589
|
-
constant int64_t & ne0,
|
|
1590
|
-
constant int64_t & ne1,
|
|
1591
|
-
constant uint & r2,
|
|
1592
|
-
constant uint & r3,
|
|
1593
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1594
|
-
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1595
|
-
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
|
1596
|
-
}
|
|
1419
|
+
template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
|
|
1597
1420
|
|
|
1598
1421
|
// Assumes row size (ne00) is a multiple of 4
|
|
1599
|
-
|
|
1422
|
+
template<typename T, typename T4>
|
|
1423
|
+
kernel void kernel_mul_mv_l4(
|
|
1600
1424
|
device const char * src0,
|
|
1601
1425
|
device const char * src1,
|
|
1602
1426
|
device float * dst,
|
|
@@ -1628,14 +1452,14 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
|
1628
1452
|
|
|
1629
1453
|
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
|
1630
1454
|
|
|
1631
|
-
device const
|
|
1455
|
+
device const T4 * x4 = (device const T4 *) (src0 + offset0);
|
|
1632
1456
|
|
|
1633
1457
|
for (int r1 = 0; r1 < nrows; ++r1) {
|
|
1634
1458
|
device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
|
|
1635
1459
|
|
|
1636
1460
|
float sumf = 0;
|
|
1637
1461
|
for (int i = tiisg; i < ne00/4; i += 32) {
|
|
1638
|
-
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
|
1462
|
+
for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
|
|
1639
1463
|
}
|
|
1640
1464
|
|
|
1641
1465
|
float all_sum = simd_sum(sumf);
|
|
@@ -1645,6 +1469,10 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
|
1645
1469
|
}
|
|
1646
1470
|
}
|
|
1647
1471
|
|
|
1472
|
+
typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
|
|
1473
|
+
|
|
1474
|
+
template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
|
|
1475
|
+
|
|
1648
1476
|
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
|
1649
1477
|
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
|
1650
1478
|
return 1.0f - min(1.0f, max(0.0f, y));
|
|
@@ -2765,91 +2593,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
|
2765
2593
|
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
|
|
2766
2594
|
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
|
|
2767
2595
|
|
|
2768
|
-
|
|
2769
|
-
|
|
2770
|
-
device
|
|
2771
|
-
|
|
2772
|
-
constant int64_t & ne01,
|
|
2773
|
-
constant int64_t & ne02,
|
|
2774
|
-
constant int64_t & ne03,
|
|
2775
|
-
constant uint64_t & nb00,
|
|
2776
|
-
constant uint64_t & nb01,
|
|
2777
|
-
constant uint64_t & nb02,
|
|
2778
|
-
constant uint64_t & nb03,
|
|
2779
|
-
constant int64_t & ne0,
|
|
2780
|
-
constant int64_t & ne1,
|
|
2781
|
-
constant int64_t & ne2,
|
|
2782
|
-
constant int64_t & ne3,
|
|
2783
|
-
constant uint64_t & nb0,
|
|
2784
|
-
constant uint64_t & nb1,
|
|
2785
|
-
constant uint64_t & nb2,
|
|
2786
|
-
constant uint64_t & nb3,
|
|
2787
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2788
|
-
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
2789
|
-
uint3 ntg[[threads_per_threadgroup]]) {
|
|
2790
|
-
const int64_t i03 = tgpig[2];
|
|
2791
|
-
const int64_t i02 = tgpig[1];
|
|
2792
|
-
const int64_t i01 = tgpig[0];
|
|
2793
|
-
|
|
2794
|
-
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
2795
|
-
|
|
2796
|
-
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
2797
|
-
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
2798
|
-
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
2799
|
-
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
2800
|
-
|
|
2801
|
-
device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
2802
|
-
|
|
2803
|
-
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
2804
|
-
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
2805
|
-
dst_data[i00] = src[0];
|
|
2806
|
-
}
|
|
2807
|
-
}
|
|
2808
|
-
|
|
2809
|
-
kernel void kernel_cpy_f16_f32(
|
|
2810
|
-
device const half * src0,
|
|
2811
|
-
device float * dst,
|
|
2812
|
-
constant int64_t & ne00,
|
|
2813
|
-
constant int64_t & ne01,
|
|
2814
|
-
constant int64_t & ne02,
|
|
2815
|
-
constant int64_t & ne03,
|
|
2816
|
-
constant uint64_t & nb00,
|
|
2817
|
-
constant uint64_t & nb01,
|
|
2818
|
-
constant uint64_t & nb02,
|
|
2819
|
-
constant uint64_t & nb03,
|
|
2820
|
-
constant int64_t & ne0,
|
|
2821
|
-
constant int64_t & ne1,
|
|
2822
|
-
constant int64_t & ne2,
|
|
2823
|
-
constant int64_t & ne3,
|
|
2824
|
-
constant uint64_t & nb0,
|
|
2825
|
-
constant uint64_t & nb1,
|
|
2826
|
-
constant uint64_t & nb2,
|
|
2827
|
-
constant uint64_t & nb3,
|
|
2828
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2829
|
-
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
2830
|
-
uint3 ntg[[threads_per_threadgroup]]) {
|
|
2831
|
-
const int64_t i03 = tgpig[2];
|
|
2832
|
-
const int64_t i02 = tgpig[1];
|
|
2833
|
-
const int64_t i01 = tgpig[0];
|
|
2834
|
-
|
|
2835
|
-
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
2836
|
-
|
|
2837
|
-
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
2838
|
-
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
2839
|
-
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
2840
|
-
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
2841
|
-
|
|
2842
|
-
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
2843
|
-
|
|
2844
|
-
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
2845
|
-
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
2846
|
-
dst_data[i00] = src[0];
|
|
2847
|
-
}
|
|
2848
|
-
}
|
|
2849
|
-
|
|
2850
|
-
kernel void kernel_cpy_f32_f16(
|
|
2851
|
-
device const float * src0,
|
|
2852
|
-
device half * dst,
|
|
2596
|
+
template<typename T0, typename T1>
|
|
2597
|
+
kernel void kernel_cpy(
|
|
2598
|
+
device const void * src0,
|
|
2599
|
+
device void * dst,
|
|
2853
2600
|
constant int64_t & ne00,
|
|
2854
2601
|
constant int64_t & ne01,
|
|
2855
2602
|
constant int64_t & ne02,
|
|
@@ -2880,56 +2627,20 @@ kernel void kernel_cpy_f32_f16(
|
|
|
2880
2627
|
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
2881
2628
|
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
2882
2629
|
|
|
2883
|
-
device
|
|
2630
|
+
device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
2884
2631
|
|
|
2885
2632
|
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
2886
|
-
device const
|
|
2887
|
-
|
|
2888
|
-
dst_data[i00] = src[0];
|
|
2633
|
+
device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
2634
|
+
dst_data[i00] = (T1) src[0];
|
|
2889
2635
|
}
|
|
2890
2636
|
}
|
|
2891
2637
|
|
|
2892
|
-
|
|
2893
|
-
device const float * src0,
|
|
2894
|
-
device float * dst,
|
|
2895
|
-
constant int64_t & ne00,
|
|
2896
|
-
constant int64_t & ne01,
|
|
2897
|
-
constant int64_t & ne02,
|
|
2898
|
-
constant int64_t & ne03,
|
|
2899
|
-
constant uint64_t & nb00,
|
|
2900
|
-
constant uint64_t & nb01,
|
|
2901
|
-
constant uint64_t & nb02,
|
|
2902
|
-
constant uint64_t & nb03,
|
|
2903
|
-
constant int64_t & ne0,
|
|
2904
|
-
constant int64_t & ne1,
|
|
2905
|
-
constant int64_t & ne2,
|
|
2906
|
-
constant int64_t & ne3,
|
|
2907
|
-
constant uint64_t & nb0,
|
|
2908
|
-
constant uint64_t & nb1,
|
|
2909
|
-
constant uint64_t & nb2,
|
|
2910
|
-
constant uint64_t & nb3,
|
|
2911
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2912
|
-
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
2913
|
-
uint3 ntg[[threads_per_threadgroup]]) {
|
|
2914
|
-
const int64_t i03 = tgpig[2];
|
|
2915
|
-
const int64_t i02 = tgpig[1];
|
|
2916
|
-
const int64_t i01 = tgpig[0];
|
|
2917
|
-
|
|
2918
|
-
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
2919
|
-
|
|
2920
|
-
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
2921
|
-
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
2922
|
-
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
2923
|
-
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
2638
|
+
typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
|
|
2924
2639
|
|
|
2925
|
-
|
|
2926
|
-
|
|
2927
|
-
|
|
2928
|
-
|
|
2929
|
-
|
|
2930
|
-
dst_data[i00] = src[0];
|
|
2931
|
-
}
|
|
2932
|
-
}
|
|
2640
|
+
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
|
|
2641
|
+
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
|
|
2642
|
+
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
|
|
2643
|
+
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
|
|
2933
2644
|
|
|
2934
2645
|
kernel void kernel_cpy_f32_q8_0(
|
|
2935
2646
|
device const float * src0,
|
|
@@ -5046,7 +4757,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
|
5046
4757
|
device const float4 * y4 = (device const float4 *)yb;
|
|
5047
4758
|
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
|
|
5048
4759
|
|
|
5049
|
-
for (int row = 0; row < 2; ++row) {
|
|
4760
|
+
for (int row = 0; row < 2 && first_row + row < ne01; ++row) {
|
|
5050
4761
|
|
|
5051
4762
|
device const block_iq4_nl & xb = x[row*nb + ib];
|
|
5052
4763
|
device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
|
|
@@ -5078,7 +4789,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
|
5078
4789
|
yb += 16 * QK4_NL;
|
|
5079
4790
|
}
|
|
5080
4791
|
|
|
5081
|
-
for (int row = 0; row < 2; ++row) {
|
|
4792
|
+
for (int row = 0; row < 2 && first_row + row < ne01; ++row) {
|
|
5082
4793
|
all_sum = simd_sum(sumf[row]);
|
|
5083
4794
|
if (tiisg == 0) {
|
|
5084
4795
|
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
|
@@ -5730,9 +5441,9 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
|
|
|
5730
5441
|
}
|
|
5731
5442
|
|
|
5732
5443
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
|
5733
|
-
kernel void
|
|
5444
|
+
kernel void kernel_get_rows_q(
|
|
5734
5445
|
device const void * src0,
|
|
5735
|
-
device const
|
|
5446
|
+
device const void * src1,
|
|
5736
5447
|
device float * dst,
|
|
5737
5448
|
constant int64_t & ne00,
|
|
5738
5449
|
constant uint64_t & nb01,
|
|
@@ -5745,55 +5456,24 @@ kernel void kernel_get_rows(
|
|
|
5745
5456
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5746
5457
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
5747
5458
|
uint3 tptg [[threads_per_threadgroup]]) {
|
|
5748
|
-
//const int64_t i = tgpig;
|
|
5749
|
-
//const int64_t r = ((device int32_t *) src1)[i];
|
|
5750
|
-
|
|
5751
5459
|
const int64_t i10 = tgpig.x;
|
|
5752
5460
|
const int64_t i11 = tgpig.y;
|
|
5753
5461
|
|
|
5754
|
-
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
5462
|
+
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
5755
5463
|
|
|
5756
5464
|
const int64_t i02 = i11;
|
|
5757
5465
|
|
|
5758
5466
|
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
|
|
5759
5467
|
float4x4 temp;
|
|
5760
|
-
dequantize_func(
|
|
5761
|
-
((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
|
|
5468
|
+
dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
|
|
5762
5469
|
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
|
|
5763
5470
|
}
|
|
5764
5471
|
}
|
|
5765
5472
|
|
|
5766
|
-
|
|
5767
|
-
|
|
5768
|
-
device const char * src1,
|
|
5769
|
-
device float * dst,
|
|
5770
|
-
constant int64_t & ne00,
|
|
5771
|
-
constant uint64_t & nb01,
|
|
5772
|
-
constant uint64_t & nb02,
|
|
5773
|
-
constant int64_t & ne10,
|
|
5774
|
-
constant uint64_t & nb10,
|
|
5775
|
-
constant uint64_t & nb11,
|
|
5776
|
-
constant uint64_t & nb1,
|
|
5777
|
-
constant uint64_t & nb2,
|
|
5778
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5779
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
|
5780
|
-
uint3 tptg [[threads_per_threadgroup]]) {
|
|
5781
|
-
const int64_t i10 = tgpig.x;
|
|
5782
|
-
const int64_t i11 = tgpig.y;
|
|
5783
|
-
|
|
5784
|
-
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
5785
|
-
|
|
5786
|
-
const int64_t i02 = i11;
|
|
5787
|
-
|
|
5788
|
-
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
|
5789
|
-
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
|
5790
|
-
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
|
5791
|
-
}
|
|
5792
|
-
}
|
|
5793
|
-
|
|
5794
|
-
kernel void kernel_get_rows_f16(
|
|
5473
|
+
template<typename T>
|
|
5474
|
+
kernel void kernel_get_rows_f(
|
|
5795
5475
|
device const void * src0,
|
|
5796
|
-
device const
|
|
5476
|
+
device const void * src1,
|
|
5797
5477
|
device float * dst,
|
|
5798
5478
|
constant int64_t & ne00,
|
|
5799
5479
|
constant uint64_t & nb01,
|
|
@@ -5809,19 +5489,19 @@ kernel void kernel_get_rows_f16(
|
|
|
5809
5489
|
const int64_t i10 = tgpig.x;
|
|
5810
5490
|
const int64_t i11 = tgpig.y;
|
|
5811
5491
|
|
|
5812
|
-
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
5492
|
+
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
5813
5493
|
|
|
5814
5494
|
const int64_t i02 = i11;
|
|
5815
5495
|
|
|
5816
5496
|
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
|
5817
|
-
((device float *) ((device char *)
|
|
5818
|
-
|
|
5497
|
+
(( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
|
5498
|
+
((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
|
|
5819
5499
|
}
|
|
5820
5500
|
}
|
|
5821
5501
|
|
|
5822
5502
|
kernel void kernel_get_rows_i32(
|
|
5823
5503
|
device const void * src0,
|
|
5824
|
-
device const
|
|
5504
|
+
device const void * src1,
|
|
5825
5505
|
device int32_t * dst,
|
|
5826
5506
|
constant int64_t & ne00,
|
|
5827
5507
|
constant uint64_t & nb01,
|
|
@@ -5837,13 +5517,13 @@ kernel void kernel_get_rows_i32(
|
|
|
5837
5517
|
const int64_t i10 = tgpig.x;
|
|
5838
5518
|
const int64_t i11 = tgpig.y;
|
|
5839
5519
|
|
|
5840
|
-
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
5520
|
+
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
5841
5521
|
|
|
5842
5522
|
const int64_t i02 = i11;
|
|
5843
5523
|
|
|
5844
5524
|
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
|
5845
|
-
((device int32_t *) ((device char *) dst
|
|
5846
|
-
|
|
5525
|
+
(( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
|
5526
|
+
((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
|
|
5847
5527
|
}
|
|
5848
5528
|
}
|
|
5849
5529
|
|
|
@@ -5860,28 +5540,28 @@ kernel void kernel_get_rows_i32(
|
|
|
5860
5540
|
#define SG_MAT_ROW 8
|
|
5861
5541
|
|
|
5862
5542
|
// each block_q contains 16*nl weights
|
|
5863
|
-
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread
|
|
5864
|
-
void
|
|
5865
|
-
|
|
5866
|
-
|
|
5867
|
-
|
|
5868
|
-
|
|
5869
|
-
|
|
5870
|
-
|
|
5871
|
-
|
|
5872
|
-
|
|
5873
|
-
|
|
5874
|
-
|
|
5875
|
-
|
|
5876
|
-
|
|
5877
|
-
|
|
5878
|
-
|
|
5879
|
-
|
|
5880
|
-
|
|
5881
|
-
|
|
5882
|
-
|
|
5543
|
+
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
|
5544
|
+
kernel void kernel_mul_mm(device const uchar * src0,
|
|
5545
|
+
device const uchar * src1,
|
|
5546
|
+
device float * dst,
|
|
5547
|
+
constant int64_t & ne00,
|
|
5548
|
+
constant int64_t & ne02,
|
|
5549
|
+
constant uint64_t & nb01,
|
|
5550
|
+
constant uint64_t & nb02,
|
|
5551
|
+
constant int64_t & ne12,
|
|
5552
|
+
constant uint64_t & nb10,
|
|
5553
|
+
constant uint64_t & nb11,
|
|
5554
|
+
constant uint64_t & nb12,
|
|
5555
|
+
constant int64_t & ne0,
|
|
5556
|
+
constant int64_t & ne1,
|
|
5557
|
+
constant uint & r2,
|
|
5558
|
+
constant uint & r3,
|
|
5559
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
|
5560
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5561
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5562
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5883
5563
|
|
|
5884
|
-
threadgroup
|
|
5564
|
+
threadgroup T * sa = (threadgroup T *)(shared_memory);
|
|
5885
5565
|
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
|
5886
5566
|
|
|
5887
5567
|
const uint r0 = tgpig.y;
|
|
@@ -5896,7 +5576,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
|
5896
5576
|
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
|
5897
5577
|
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
|
5898
5578
|
|
|
5899
|
-
|
|
5579
|
+
simdgroup_T8x8 ma[4];
|
|
5900
5580
|
simdgroup_float8x8 mb[2];
|
|
5901
5581
|
simdgroup_float8x8 c_res[8];
|
|
5902
5582
|
for (int i = 0; i < 8; i++){
|
|
@@ -5919,7 +5599,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
|
5919
5599
|
|
|
5920
5600
|
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
|
5921
5601
|
// load data and store to threadgroup memory
|
|
5922
|
-
|
|
5602
|
+
T4x4 temp_a;
|
|
5923
5603
|
dequantize_func(x, il, temp_a);
|
|
5924
5604
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
5925
5605
|
|
|
@@ -5939,7 +5619,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
|
5939
5619
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
5940
5620
|
|
|
5941
5621
|
// load matrices from threadgroup memory and conduct outer products
|
|
5942
|
-
threadgroup
|
|
5622
|
+
threadgroup T * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
|
5943
5623
|
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
|
5944
5624
|
|
|
5945
5625
|
#pragma unroll(4)
|
|
@@ -6115,48 +5795,6 @@ void kernel_mul_mm_id_impl(
|
|
|
6115
5795
|
}
|
|
6116
5796
|
}
|
|
6117
5797
|
|
|
6118
|
-
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
6119
|
-
kernel void kernel_mul_mm(device const uchar * src0,
|
|
6120
|
-
device const uchar * src1,
|
|
6121
|
-
device float * dst,
|
|
6122
|
-
constant int64_t & ne00,
|
|
6123
|
-
constant int64_t & ne02,
|
|
6124
|
-
constant uint64_t & nb01,
|
|
6125
|
-
constant uint64_t & nb02,
|
|
6126
|
-
constant int64_t & ne12,
|
|
6127
|
-
constant uint64_t & nb10,
|
|
6128
|
-
constant uint64_t & nb11,
|
|
6129
|
-
constant uint64_t & nb12,
|
|
6130
|
-
constant int64_t & ne0,
|
|
6131
|
-
constant int64_t & ne1,
|
|
6132
|
-
constant uint & r2,
|
|
6133
|
-
constant uint & r3,
|
|
6134
|
-
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
|
6135
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6136
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
|
6137
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
6138
|
-
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
|
6139
|
-
src0,
|
|
6140
|
-
src1,
|
|
6141
|
-
dst,
|
|
6142
|
-
ne00,
|
|
6143
|
-
ne02,
|
|
6144
|
-
nb01,
|
|
6145
|
-
nb02,
|
|
6146
|
-
ne12,
|
|
6147
|
-
nb10,
|
|
6148
|
-
nb11,
|
|
6149
|
-
nb12,
|
|
6150
|
-
ne0,
|
|
6151
|
-
ne1,
|
|
6152
|
-
r2,
|
|
6153
|
-
r3,
|
|
6154
|
-
shared_memory,
|
|
6155
|
-
tgpig,
|
|
6156
|
-
tiitg,
|
|
6157
|
-
sgitg);
|
|
6158
|
-
}
|
|
6159
|
-
|
|
6160
5798
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
6161
5799
|
kernel void kernel_mul_mm_id(
|
|
6162
5800
|
device const uchar * src0s,
|
|
@@ -6237,69 +5875,60 @@ kernel void kernel_mul_mm_id(
|
|
|
6237
5875
|
// get rows
|
|
6238
5876
|
//
|
|
6239
5877
|
|
|
6240
|
-
typedef
|
|
6241
|
-
|
|
6242
|
-
|
|
6243
|
-
|
|
6244
|
-
|
|
6245
|
-
|
|
6246
|
-
|
|
6247
|
-
|
|
6248
|
-
|
|
6249
|
-
|
|
6250
|
-
|
|
6251
|
-
|
|
6252
|
-
|
|
6253
|
-
|
|
6254
|
-
|
|
6255
|
-
|
|
6256
|
-
template [[host_name("
|
|
6257
|
-
template [[host_name("
|
|
6258
|
-
template [[host_name("
|
|
6259
|
-
template [[host_name("
|
|
6260
|
-
template [[host_name("
|
|
6261
|
-
template [[host_name("
|
|
6262
|
-
template [[host_name("
|
|
6263
|
-
template [[host_name("
|
|
6264
|
-
template [[host_name("
|
|
6265
|
-
template [[host_name("
|
|
6266
|
-
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
6267
|
-
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
6268
|
-
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
6269
|
-
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
|
6270
|
-
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
|
6271
|
-
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
6272
|
-
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
|
6273
|
-
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
6274
|
-
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
|
5878
|
+
typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
|
|
5879
|
+
|
|
5880
|
+
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
|
|
5881
|
+
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
|
|
5882
|
+
|
|
5883
|
+
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
|
|
5884
|
+
|
|
5885
|
+
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
|
|
5886
|
+
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
|
|
5887
|
+
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
|
|
5888
|
+
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
|
|
5889
|
+
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
|
|
5890
|
+
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
|
|
5891
|
+
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
|
|
5892
|
+
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
5893
|
+
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
5894
|
+
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
5895
|
+
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
5896
|
+
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
5897
|
+
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
5898
|
+
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
|
5899
|
+
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
|
5900
|
+
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
5901
|
+
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
|
5902
|
+
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
5903
|
+
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
|
6275
5904
|
|
|
6276
5905
|
//
|
|
6277
5906
|
// matrix-matrix multiplication
|
|
6278
5907
|
//
|
|
6279
5908
|
|
|
6280
|
-
typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t;
|
|
6281
|
-
|
|
6282
|
-
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
|
|
6283
|
-
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
|
6284
|
-
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
|
6285
|
-
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
|
6286
|
-
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
|
|
6287
|
-
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
|
|
6288
|
-
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
|
6289
|
-
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
|
6290
|
-
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
|
6291
|
-
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
6292
|
-
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
6293
|
-
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
6294
|
-
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
6295
|
-
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
6296
|
-
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
6297
|
-
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
|
6298
|
-
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
|
6299
|
-
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
6300
|
-
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
|
6301
|
-
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
6302
|
-
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
|
5909
|
+
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
|
|
5910
|
+
|
|
5911
|
+
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
|
5912
|
+
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
|
5913
|
+
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
|
5914
|
+
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
|
5915
|
+
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
|
5916
|
+
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
|
5917
|
+
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
|
5918
|
+
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
|
5919
|
+
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
|
5920
|
+
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
|
|
5921
|
+
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
|
|
5922
|
+
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
|
|
5923
|
+
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
5924
|
+
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
5925
|
+
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
5926
|
+
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
|
|
5927
|
+
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
|
|
5928
|
+
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
5929
|
+
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
|
|
5930
|
+
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
5931
|
+
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
|
6303
5932
|
|
|
6304
5933
|
//
|
|
6305
5934
|
// indirect matrix-matrix multiplication
|
|
@@ -6436,7 +6065,7 @@ void mmv_fn(
|
|
|
6436
6065
|
impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
|
|
6437
6066
|
}
|
|
6438
6067
|
|
|
6439
|
-
typedef decltype(mmv_fn<
|
|
6068
|
+
typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
|
|
6440
6069
|
|
|
6441
6070
|
template<mul_mv_impl_fn_t impl_fn>
|
|
6442
6071
|
kernel void kernel_mul_mv_id(
|
|
@@ -6514,20 +6143,20 @@ kernel void kernel_mul_mv_id(
|
|
|
6514
6143
|
sgitg);
|
|
6515
6144
|
}
|
|
6516
6145
|
|
|
6517
|
-
typedef decltype(kernel_mul_mv_id<mmv_fn<
|
|
6518
|
-
|
|
6519
|
-
template [[host_name("kernel_mul_mv_id_f32_f32")]]
|
|
6520
|
-
template [[host_name("kernel_mul_mv_id_f16_f32")]]
|
|
6521
|
-
template [[host_name("kernel_mul_mv_id_q8_0_f32")]]
|
|
6522
|
-
template [[host_name("kernel_mul_mv_id_q4_0_f32")]]
|
|
6523
|
-
template [[host_name("kernel_mul_mv_id_q4_1_f32")]]
|
|
6524
|
-
template [[host_name("kernel_mul_mv_id_q5_0_f32")]]
|
|
6525
|
-
template [[host_name("kernel_mul_mv_id_q5_1_f32")]]
|
|
6526
|
-
template [[host_name("kernel_mul_mv_id_q2_K_f32")]]
|
|
6527
|
-
template [[host_name("kernel_mul_mv_id_q3_K_f32")]]
|
|
6528
|
-
template [[host_name("kernel_mul_mv_id_q4_K_f32")]]
|
|
6529
|
-
template [[host_name("kernel_mul_mv_id_q5_K_f32")]]
|
|
6530
|
-
template [[host_name("kernel_mul_mv_id_q6_K_f32")]]
|
|
6146
|
+
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;
|
|
6147
|
+
|
|
6148
|
+
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
|
|
6149
|
+
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
|
|
6150
|
+
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
|
|
6151
|
+
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
|
6152
|
+
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
|
6153
|
+
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
|
6154
|
+
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
|
6155
|
+
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
|
|
6156
|
+
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
|
|
6157
|
+
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
|
|
6158
|
+
template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
|
|
6159
|
+
template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
|
|
6531
6160
|
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
|
|
6532
6161
|
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
|
|
6533
6162
|
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
|