whisper.rn 0.5.0-rc.8 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/cpp/ggml-alloc.c +1 -15
- package/cpp/ggml-backend-reg.cpp +17 -8
- package/cpp/ggml-backend.cpp +15 -22
- package/cpp/ggml-common.h +17 -0
- package/cpp/ggml-cpu/arch/arm/quants.c +132 -596
- package/cpp/ggml-cpu/arch/arm/repack.cpp +14 -286
- package/cpp/ggml-cpu/arch/x86/quants.c +184 -675
- package/cpp/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- package/cpp/ggml-cpu/arch-fallback.h +34 -0
- package/cpp/ggml-cpu/ggml-cpu.c +22 -1
- package/cpp/ggml-cpu/ggml-cpu.cpp +21 -24
- package/cpp/ggml-cpu/ops.cpp +870 -211
- package/cpp/ggml-cpu/ops.h +3 -8
- package/cpp/ggml-cpu/quants.c +35 -0
- package/cpp/ggml-cpu/quants.h +8 -0
- package/cpp/ggml-cpu/repack.cpp +458 -47
- package/cpp/ggml-cpu/repack.h +22 -0
- package/cpp/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/ggml-cpu/traits.cpp +2 -2
- package/cpp/ggml-cpu/traits.h +1 -1
- package/cpp/ggml-cpu/vec.cpp +12 -9
- package/cpp/ggml-cpu/vec.h +107 -13
- package/cpp/ggml-impl.h +77 -0
- package/cpp/ggml-metal-impl.h +51 -12
- package/cpp/ggml-metal.m +610 -115
- package/cpp/ggml-opt.cpp +97 -41
- package/cpp/ggml-opt.h +25 -6
- package/cpp/ggml-quants.c +110 -16
- package/cpp/ggml-quants.h +6 -0
- package/cpp/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-whisper.metallib +0 -0
- package/cpp/ggml.c +314 -88
- package/cpp/ggml.h +137 -11
- package/cpp/gguf.cpp +8 -1
- package/cpp/jsi/RNWhisperJSI.cpp +23 -6
- package/cpp/whisper.cpp +15 -6
- package/ios/RNWhisper.mm +6 -6
- package/ios/RNWhisperContext.mm +2 -0
- package/ios/RNWhisperVadContext.mm +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +28 -2
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/module/realtime-transcription/RealtimeTranscriber.js +28 -2
- package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts +1 -0
- package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
- package/lib/typescript/realtime-transcription/types.d.ts +6 -0
- package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/realtime-transcription/RealtimeTranscriber.ts +32 -0
- package/src/realtime-transcription/types.ts +6 -0
package/cpp/ggml-cpu/ops.cpp
CHANGED
@@ -8,6 +8,7 @@
 #include "vec.h"
 
 #include <float.h>
+#include <algorithm>
 
 // wsp_ggml_compute_forward_dup
 
@@ -1283,6 +1284,7 @@ void wsp_ggml_compute_forward_add(
         case WSP_GGML_TYPE_Q5_0:
         case WSP_GGML_TYPE_Q5_1:
         case WSP_GGML_TYPE_Q8_0:
+        case WSP_GGML_TYPE_MXFP4:
         case WSP_GGML_TYPE_Q2_K:
         case WSP_GGML_TYPE_Q3_K:
         case WSP_GGML_TYPE_Q4_K:
@@ -1309,6 +1311,77 @@ void wsp_ggml_compute_forward_add(
     }
 }
 
+// wsp_ggml_compute_forward_add_id
+
+static void wsp_ggml_compute_forward_add_id_f32(
+        const wsp_ggml_compute_params * params,
+        wsp_ggml_tensor * dst) {
+
+    const wsp_ggml_tensor * src0 = dst->src[0];
+    const wsp_ggml_tensor * src1 = dst->src[1];
+    const wsp_ggml_tensor * src2 = dst->src[2];
+
+    WSP_GGML_ASSERT(dst->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(src2->type == WSP_GGML_TYPE_I32);
+
+    WSP_GGML_ASSERT(src0->nb[0] == sizeof(float));
+    WSP_GGML_ASSERT(src1->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = wsp_ggml_nrows(src0);
+
+    WSP_GGML_TENSOR_TERNARY_OP_LOCALS
+
+    WSP_GGML_ASSERT( nb0 == sizeof(float));
+    WSP_GGML_ASSERT(nb10 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        // src1 indices
+        const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
+
+        WSP_GGML_ASSERT(i11 >= 0 && i11 < ne11);
+
+        wsp_ggml_vec_add_f32(ne0,
+                (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+                (float *) ((char *) src1->data + i11*nb11));
+    }
+}
+
+void wsp_ggml_compute_forward_add_id(
+        const wsp_ggml_compute_params * params,
+        wsp_ggml_tensor * dst) {
+
+    const wsp_ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case WSP_GGML_TYPE_F32:
+            {
+                wsp_ggml_compute_forward_add_id_f32(params, dst);
+            } break;
+        default:
+            {
+                WSP_GGML_ABORT("unsupported type for wsp_ggml_compute_forward_add_id: %s", wsp_ggml_type_name(src0->type));
+            }
+    }
+}
+
 // wsp_ggml_compute_forward_add1
 
 static void wsp_ggml_compute_forward_add1_f32(
@@ -1660,6 +1733,7 @@ void wsp_ggml_compute_forward_add1(
         case WSP_GGML_TYPE_Q5_1:
         case WSP_GGML_TYPE_Q8_0:
         case WSP_GGML_TYPE_Q8_1:
+        case WSP_GGML_TYPE_MXFP4:
         case WSP_GGML_TYPE_Q2_K:
         case WSP_GGML_TYPE_Q3_K:
         case WSP_GGML_TYPE_Q4_K:
@@ -1787,6 +1861,7 @@ void wsp_ggml_compute_forward_acc(
         case WSP_GGML_TYPE_Q5_1:
         case WSP_GGML_TYPE_Q8_0:
         case WSP_GGML_TYPE_Q8_1:
+        case WSP_GGML_TYPE_MXFP4:
         case WSP_GGML_TYPE_Q2_K:
         case WSP_GGML_TYPE_Q3_K:
         case WSP_GGML_TYPE_Q4_K:
@@ -3185,11 +3260,356 @@ void wsp_ggml_compute_forward_silu_back(
     }
 }
 
-// wsp_ggml_compute_forward_reglu
-
-static void wsp_ggml_compute_forward_reglu_f32(
-        const wsp_ggml_compute_params * params,
-        wsp_ggml_tensor * dst) {
+// wsp_ggml_compute_forward_reglu
+
+static void wsp_ggml_compute_forward_reglu_f32(
+        const wsp_ggml_compute_params * params,
+        wsp_ggml_tensor * dst) {
+
+    const wsp_ggml_tensor * src0 = dst->src[0];
+    const wsp_ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0));
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src1));
+        WSP_GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = wsp_ggml_nrows(src0);
+
+    WSP_GGML_ASSERT(dst->ne[0] == nc);
+    WSP_GGML_ASSERT(wsp_ggml_nrows(dst) == nr);
+
+    const int32_t swapped = wsp_ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        wsp_ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            WSP_GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void wsp_ggml_compute_forward_reglu_f16(
+        const wsp_ggml_compute_params * params,
+        wsp_ggml_tensor * dst) {
+
+    const wsp_ggml_tensor * src0 = dst->src[0];
+    const wsp_ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0));
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src1));
+        WSP_GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = wsp_ggml_nrows(src0);
+
+    WSP_GGML_ASSERT(dst->ne[0] == nc);
+    WSP_GGML_ASSERT(wsp_ggml_nrows(dst) == nr);
+
+    const int32_t swapped = wsp_ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        wsp_ggml_fp16_t * src0_p = (wsp_ggml_fp16_t *) (src0_d + i1*src0_o);
+        wsp_ggml_fp16_t * src1_p = (wsp_ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        wsp_ggml_vec_reglu_f16(nc, (wsp_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const wsp_ggml_fp16_t x = ((wsp_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = WSP_GGML_FP16_TO_FP32(x);
+            WSP_GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void wsp_ggml_compute_forward_reglu(
+        const wsp_ggml_compute_params * params,
+        wsp_ggml_tensor * dst) {
+
+    const wsp_ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case WSP_GGML_TYPE_F32:
+            {
+                wsp_ggml_compute_forward_reglu_f32(params, dst);
+            } break;
+        case WSP_GGML_TYPE_F16:
+            {
+                wsp_ggml_compute_forward_reglu_f16(params, dst);
+            } break;
+        default:
+            {
+                WSP_GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// wsp_ggml_compute_forward_geglu
+
+static void wsp_ggml_compute_forward_geglu_f32(
+        const wsp_ggml_compute_params * params,
+        wsp_ggml_tensor * dst) {
+
+    const wsp_ggml_tensor * src0 = dst->src[0];
+    const wsp_ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0));
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src1));
+        WSP_GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = wsp_ggml_nrows(src0);
+
+    WSP_GGML_ASSERT(dst->ne[0] == nc);
+    WSP_GGML_ASSERT(wsp_ggml_nrows(dst) == nr);
+
+    const int32_t swapped = wsp_ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        wsp_ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            WSP_GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void wsp_ggml_compute_forward_geglu_f16(
+        const wsp_ggml_compute_params * params,
+        wsp_ggml_tensor * dst) {
+
+    const wsp_ggml_tensor * src0 = dst->src[0];
+    const wsp_ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0));
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src1));
+        WSP_GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = wsp_ggml_nrows(src0);
+
+    WSP_GGML_ASSERT(dst->ne[0] == nc);
+    WSP_GGML_ASSERT(wsp_ggml_nrows(dst) == nr);
+
+    const int32_t swapped = wsp_ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        wsp_ggml_fp16_t * src0_p = (wsp_ggml_fp16_t *) (src0_d + i1*src0_o);
+        wsp_ggml_fp16_t * src1_p = (wsp_ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        wsp_ggml_vec_geglu_f16(nc, (wsp_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const wsp_ggml_fp16_t x = ((wsp_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = WSP_GGML_FP16_TO_FP32(x);
+            WSP_GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void wsp_ggml_compute_forward_geglu(
+        const wsp_ggml_compute_params * params,
+        wsp_ggml_tensor * dst) {
+
+    const wsp_ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case WSP_GGML_TYPE_F32:
+            {
+                wsp_ggml_compute_forward_geglu_f32(params, dst);
+            } break;
+        case WSP_GGML_TYPE_F16:
+            {
+                wsp_ggml_compute_forward_geglu_f16(params, dst);
+            } break;
+        default:
+            {
+                WSP_GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// wsp_ggml_compute_forward_swiglu
+
+static void wsp_ggml_compute_forward_swiglu_f32(
+        const wsp_ggml_compute_params * params,
+        wsp_ggml_tensor * dst) {
+
+    const wsp_ggml_tensor * src0 = dst->src[0];
+    const wsp_ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0));
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src1));
+        WSP_GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = wsp_ggml_nrows(src0);
+
+    WSP_GGML_ASSERT(dst->ne[0] == nc);
+    WSP_GGML_ASSERT(wsp_ggml_nrows(dst) == nr);
+
+    const int32_t swapped = wsp_ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        wsp_ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            WSP_GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void wsp_ggml_compute_forward_swiglu_f16(
+        const wsp_ggml_compute_params * params,
+        wsp_ggml_tensor * dst) {
 
     const wsp_ggml_tensor * src0 = dst->src[0];
     const wsp_ggml_tensor * src1 = dst->src[1];
@@ -3225,30 +3645,55 @@ static void wsp_ggml_compute_forward_reglu_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
-        float * src0_p = (float *) (src0_d + i1*src0_o);
-        float * src1_p = (float *) (src1_d + i1*src1_o);
+        wsp_ggml_fp16_t * src0_p = (wsp_ggml_fp16_t *) (src0_d + i1*src0_o);
+        wsp_ggml_fp16_t * src1_p = (wsp_ggml_fp16_t *) (src1_d + i1*src1_o);
 
         if (!src1) {
             src0_p += swapped ? nc : 0;
             src1_p += swapped ? 0 : nc;
         }
 
-        wsp_ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+        wsp_ggml_vec_swiglu_f16(nc, (wsp_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
-            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
-            WSP_GGML_UNUSED(x);
-            assert(!isnan(x));
-            assert(!isinf(x));
+            const wsp_ggml_fp16_t x = ((wsp_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = WSP_GGML_FP16_TO_FP32(x);
+            WSP_GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
         }
 #endif
     }
 }
 
-static void wsp_ggml_compute_forward_reglu_f16(
-        const wsp_ggml_compute_params * params,
-        wsp_ggml_tensor * dst) {
+static void wsp_ggml_compute_forward_swiglu(
+        const wsp_ggml_compute_params * params,
+        wsp_ggml_tensor * dst) {
+
+    const wsp_ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case WSP_GGML_TYPE_F32:
+            {
+                wsp_ggml_compute_forward_swiglu_f32(params, dst);
+            } break;
+        case WSP_GGML_TYPE_F16:
+            {
+                wsp_ggml_compute_forward_swiglu_f16(params, dst);
+            } break;
+        default:
+            {
+                WSP_GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// wsp_ggml_compute_forward_swiglu_oai
+
+static void wsp_ggml_compute_forward_swiglu_oai_f32(
+        const wsp_ggml_compute_params * params,
+        wsp_ggml_tensor * dst) {
 
     const wsp_ggml_tensor * src0 = dst->src[0];
     const wsp_ggml_tensor * src1 = dst->src[1];
@@ -3275,6 +3720,8 @@ static void wsp_ggml_compute_forward_reglu_f16(
     WSP_GGML_ASSERT(wsp_ggml_nrows(dst) == nr);
 
     const int32_t swapped = wsp_ggml_get_op_params_i32(dst, 1);
+    const float alpha = wsp_ggml_get_op_params_f32(dst, 2);
+    const float limit = wsp_ggml_get_op_params_f32(dst, 3);
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -3284,29 +3731,34 @@ static void wsp_ggml_compute_forward_reglu_f16(
     const int ir1 = MIN(ir0 + dr, nr);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
-        wsp_ggml_fp16_t * src0_p = (wsp_ggml_fp16_t *) (src0_d + i1*src0_o);
-        wsp_ggml_fp16_t * src1_p = (wsp_ggml_fp16_t *) (src1_d + i1*src1_o);
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+        float * dst_p = (float *) ((char *) dst->data + i1*(dst->nb[1]));
 
         if (!src1) {
             src0_p += swapped ? nc : 0;
            src1_p += swapped ? 0 : nc;
         }
 
-        wsp_ggml_vec_reglu_f16(nc, (wsp_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+        for (int k = 0; k < nc; k++) {
+            const float x = std::min(src0_p[k], limit);
+            const float y = std::clamp(src1_p[k], -limit, limit);
+            const float out_glu = x / (1.f + expf(alpha * (-x)));
+            dst_p[k] = out_glu * (y + 1.f);
+        }
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
-            const wsp_ggml_fp16_t x = ((wsp_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
-            const float v = WSP_GGML_FP16_TO_FP32(x);
-            WSP_GGML_UNUSED(v);
-            assert(!isnan(v));
-            assert(!isinf(v));
+            const float x = dst_p[k];
+            WSP_GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
         }
 #endif
     }
 }
 
-static void wsp_ggml_compute_forward_reglu(
+static void wsp_ggml_compute_forward_swiglu_oai(
         const wsp_ggml_compute_params * params,
         wsp_ggml_tensor * dst) {
 
@@ -3315,11 +3767,7 @@ static void wsp_ggml_compute_forward_reglu(
     switch (src0->type) {
         case WSP_GGML_TYPE_F32:
             {
-                wsp_ggml_compute_forward_reglu_f32(params, dst);
-            } break;
-        case WSP_GGML_TYPE_F16:
-            {
-                wsp_ggml_compute_forward_reglu_f16(params, dst);
+                wsp_ggml_compute_forward_swiglu_oai_f32(params, dst);
             } break;
         default:
             {
@@ -3328,9 +3776,9 @@ static void wsp_ggml_compute_forward_reglu(
     }
 }
 
-// wsp_ggml_compute_forward_geglu
+// wsp_ggml_compute_forward_geglu_erf
 
-static void wsp_ggml_compute_forward_geglu_f32(
+static void wsp_ggml_compute_forward_geglu_erf_f32(
         const wsp_ggml_compute_params * params,
         wsp_ggml_tensor * dst) {
 
@@ -3376,7 +3824,7 @@ static void wsp_ggml_compute_forward_geglu_f32(
             src1_p += swapped ? 0 : nc;
         }
 
-        wsp_ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+        wsp_ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3389,7 +3837,7 @@ static void wsp_ggml_compute_forward_geglu_f32(
     }
 }
 
-static void wsp_ggml_compute_forward_geglu_f16(
+static void wsp_ggml_compute_forward_geglu_erf_f16(
         const wsp_ggml_compute_params * params,
         wsp_ggml_tensor * dst) {
 
@@ -3435,7 +3883,7 @@ static void wsp_ggml_compute_forward_geglu_f16(
             src1_p += swapped ? 0 : nc;
         }
 
-        wsp_ggml_vec_geglu_f16(nc, (wsp_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+        wsp_ggml_vec_geglu_erf_f16(nc, (wsp_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3449,7 +3897,7 @@ static void wsp_ggml_compute_forward_geglu_f16(
     }
 }
 
-static void wsp_ggml_compute_forward_geglu(
+static void wsp_ggml_compute_forward_geglu_erf(
         const wsp_ggml_compute_params * params,
         wsp_ggml_tensor * dst) {
 
@@ -3458,11 +3906,11 @@ static void wsp_ggml_compute_forward_geglu(
     switch (src0->type) {
         case WSP_GGML_TYPE_F32:
             {
-                wsp_ggml_compute_forward_geglu_f32(params, dst);
+                wsp_ggml_compute_forward_geglu_erf_f32(params, dst);
             } break;
         case WSP_GGML_TYPE_F16:
            {
-                wsp_ggml_compute_forward_geglu_f16(params, dst);
+                wsp_ggml_compute_forward_geglu_erf_f16(params, dst);
             } break;
         default:
             {
@@ -3471,9 +3919,9 @@ static void wsp_ggml_compute_forward_geglu(
     }
 }
 
-// wsp_ggml_compute_forward_swiglu
+// wsp_ggml_compute_forward_geglu_quick
 
-static void wsp_ggml_compute_forward_swiglu_f32(
+static void wsp_ggml_compute_forward_geglu_quick_f32(
         const wsp_ggml_compute_params * params,
         wsp_ggml_tensor * dst) {
 
@@ -3519,7 +3967,7 @@ static void wsp_ggml_compute_forward_swiglu_f32(
             src1_p += swapped ? 0 : nc;
         }
 
-        wsp_ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+        wsp_ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3532,7 +3980,7 @@ static void wsp_ggml_compute_forward_swiglu_f32(
     }
 }
 
-static void wsp_ggml_compute_forward_swiglu_f16(
+static void wsp_ggml_compute_forward_geglu_quick_f16(
         const wsp_ggml_compute_params * params,
         wsp_ggml_tensor * dst) {
 
@@ -3578,7 +4026,7 @@ static void wsp_ggml_compute_forward_swiglu_f16(
             src1_p += swapped ? 0 : nc;
         }
 
-        wsp_ggml_vec_swiglu_f16(nc, (wsp_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+        wsp_ggml_vec_geglu_quick_f16(nc, (wsp_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3592,7 +4040,7 @@ static void wsp_ggml_compute_forward_swiglu_f16(
     }
 }
 
-static void wsp_ggml_compute_forward_swiglu(
+static void wsp_ggml_compute_forward_geglu_quick(
         const wsp_ggml_compute_params * params,
         wsp_ggml_tensor * dst) {
 
@@ -3601,11 +4049,11 @@ static void wsp_ggml_compute_forward_swiglu(
     switch (src0->type) {
         case WSP_GGML_TYPE_F32:
             {
-                wsp_ggml_compute_forward_swiglu_f32(params, dst);
+                wsp_ggml_compute_forward_geglu_quick_f32(params, dst);
             } break;
         case WSP_GGML_TYPE_F16:
             {
-                wsp_ggml_compute_forward_swiglu_f16(params, dst);
+                wsp_ggml_compute_forward_geglu_quick_f16(params, dst);
             } break;
         default:
             {
@@ -3729,6 +4177,9 @@ static void wsp_ggml_compute_forward_rms_norm_f32(
 
         const float scale = 1.0f/sqrtf(mean + eps);
 
+        // if you hit this, likely you got an inf somewhere earlier
+        assert(scale > 0.0f);
+
         wsp_ggml_vec_scale_f32(ne00, y, scale);
     }
 }
@@ -4310,6 +4761,7 @@ void wsp_ggml_compute_forward_out_prod(
         case WSP_GGML_TYPE_Q5_0:
         case WSP_GGML_TYPE_Q5_1:
         case WSP_GGML_TYPE_Q8_0:
+        case WSP_GGML_TYPE_MXFP4:
         case WSP_GGML_TYPE_Q2_K:
         case WSP_GGML_TYPE_Q3_K:
         case WSP_GGML_TYPE_Q4_K:
@@ -4357,9 +4809,11 @@ static void wsp_ggml_compute_forward_scale_f32(
     WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst));
     WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst));
 
-    // scale factor
-    float v;
-    memcpy(&v, dst->op_params, sizeof(float));
+    float s; // scale factor
+    float b; // bias
+
+    memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -4378,12 +4832,22 @@ static void wsp_ggml_compute_forward_scale_f32(
 
     const size_t nb1 = dst->nb[1];
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        if (dst->data != src0->data) {
-            // src0 is same shape as dst => same indices
-            memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb1, nc * sizeof(float));
+    if (b == 0.0f) {
+        for (int i1 = ir0; i1 < ir1; i1++) {
+            if (dst->data != src0->data) {
+                // src0 is same shape as dst => same indices
+                // TODO: add x parameter to wsp_ggml_vec_scale_f32 and remove this memcpy
+                memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
+            }
+            wsp_ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
+        }
+    } else {
+        for (int i1 = ir0; i1 < ir1; i1++) {
+            wsp_ggml_vec_mad1_f32(nc,
+                    (float *) ((char *) dst->data + i1*nb1),
+                    (float *) ((char *) src0->data + i1*nb1),
+                    s, b);
         }
-        wsp_ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
     }
 }
 
@@ -4572,6 +5036,7 @@ void wsp_ggml_compute_forward_set(
         case WSP_GGML_TYPE_Q5_1:
         case WSP_GGML_TYPE_Q8_0:
         case WSP_GGML_TYPE_Q8_1:
+        case WSP_GGML_TYPE_MXFP4:
         case WSP_GGML_TYPE_Q2_K:
         case WSP_GGML_TYPE_Q3_K:
         case WSP_GGML_TYPE_Q4_K:
@@ -4833,6 +5298,7 @@ void wsp_ggml_compute_forward_get_rows(
         case WSP_GGML_TYPE_Q5_1:
         case WSP_GGML_TYPE_Q8_0:
         case WSP_GGML_TYPE_Q8_1:
+        case WSP_GGML_TYPE_MXFP4:
         case WSP_GGML_TYPE_Q2_K:
         case WSP_GGML_TYPE_Q3_K:
         case WSP_GGML_TYPE_Q4_K:
@@ -5222,6 +5688,7 @@ static void wsp_ggml_compute_forward_soft_max_f32(
 
     const wsp_ggml_tensor * src0 = dst->src[0];
     const wsp_ggml_tensor * src1 = dst->src[1];
+    const wsp_ggml_tensor * src2 = dst->src[2];
 
     assert(wsp_ggml_is_contiguous(dst));
     assert(wsp_ggml_are_same_shape(src0, dst));
@@ -5232,14 +5699,17 @@ static void wsp_ggml_compute_forward_soft_max_f32(
     memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
-    // TODO: handle transposed/permuted matrices
-
     const int ith = params->ith;
     const int nth = params->nth;
 
     WSP_GGML_TENSOR_UNARY_OP_LOCALS
 
-
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;
 
     // TODO: is this supposed to be ceil instead of floor?
     // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -5249,68 +5719,78 @@ static void wsp_ggml_compute_forward_soft_max_f32(
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-
-    const int nr = wsp_ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+    float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
     const bool use_f16 = (src1 && src1->type == WSP_GGML_TYPE_F16);
 
-
-
-        const uint32_t h = (i1/ne01)%ne02; // head
-        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
-        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
+    // sinks
+    const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const int64_t i11 = i01;
+                const int64_t i12 = i02%ne12;
+                const int64_t i13 = i03%ne13;
+
+                // ALiBi
+                const uint32_t h = i02; // head
+                const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+                float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                // broadcast the mask across rows
+                wsp_ggml_fp16_t * mp_f16 = src1 ? (wsp_ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+                float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+
+                wsp_ggml_vec_cpy_f32 (ne00, wp, sp);
+                wsp_ggml_vec_scale_f32(ne00, wp, scale);
+                if (mp_f32) {
+                    if (use_f16) {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*WSP_GGML_CPU_FP16_TO_FP32(mp_f16[i]);
+                        }
+                    } else {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*mp_f32[i];
+                        }
+                    }
                 }
-            }
-        }
 
 #ifndef NDEBUG
-
-
-
-
+                for (int i = 0; i < ne00; ++i) {
+                    //printf("p[%d] = %f\n", i, p[i]);
+                    assert(!isnan(wp[i]));
+                }
 #endif
 
-
-
+                float max = -INFINITY;
+                wsp_ggml_vec_max_f32(ne00, &max, wp);
 
-
-
+                // if we have sinks, make a correction as if they were included in the softmax
+                if (sk) {
+                    max = MAX(max, sk[i02]);
+                }
+
+                wsp_ggml_float sum = wsp_ggml_vec_soft_max_f32(ne00, dp, wp, max);
+                assert(sum > 0.0);
+
+                if (sk) {
+                    sum += (wsp_ggml_float) expf(sk[i02] - max);
+                }
 
-
-
+                sum = 1.0/sum;
+                wsp_ggml_vec_scale_f32(ne00, dp, sum);
 
 #ifndef NDEBUG
-
-
-
-
+                for (int i = 0; i < ne00; ++i) {
+                    assert(!isnan(dp[i]));
+                    assert(!isinf(dp[i]));
+                }
 #endif
+            }
+        }
     }
 }
 
@@ -5534,6 +6014,7 @@ void wsp_ggml_compute_forward_clamp(
         case WSP_GGML_TYPE_Q5_1:
         case WSP_GGML_TYPE_Q8_0:
         case WSP_GGML_TYPE_Q8_1:
+        case WSP_GGML_TYPE_MXFP4:
         case WSP_GGML_TYPE_Q2_K:
         case WSP_GGML_TYPE_Q3_K:
         case WSP_GGML_TYPE_Q4_K:
@@ -7687,12 +8168,14 @@ void wsp_ggml_compute_forward_argsort(
 
 static void wsp_ggml_compute_forward_flash_attn_ext_f16(
         const wsp_ggml_compute_params * params,
-        const wsp_ggml_tensor * q,
-        const wsp_ggml_tensor * k,
-        const wsp_ggml_tensor * v,
-        const wsp_ggml_tensor * mask,
         wsp_ggml_tensor * dst) {
 
+    const wsp_ggml_tensor * q = dst->src[0];
+    const wsp_ggml_tensor * k = dst->src[1];
+    const wsp_ggml_tensor * v = dst->src[2];
+    const wsp_ggml_tensor * mask = dst->src[3];
+    const wsp_ggml_tensor * sinks = dst->src[4];
+
     WSP_GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
     WSP_GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
     WSP_GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
@@ -7766,7 +8249,7 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    wsp_ggml_type
+    wsp_ggml_type const k_vec_dot_type = wsp_ggml_get_type_traits_cpu(k->type)->vec_dot_type;
     wsp_ggml_from_float_t const q_to_vec_dot = wsp_ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
     wsp_ggml_vec_dot_t const kq_vec_dot = wsp_ggml_get_type_traits_cpu(k->type)->vec_dot;
     wsp_ggml_to_float_t const v_to_float = wsp_ggml_get_type_traits(v->type)->to_float;
@@ -7798,7 +8281,7 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
             memset(VKQ32, 0, DV*sizeof(float));
         }
 
-        const wsp_ggml_fp16_t * mp = mask ? (wsp_ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+        const wsp_ggml_fp16_t * mp = mask ? (wsp_ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
 
         // k indices
         const int ik3 = iq3 / rk3;
@@ -7887,6 +8370,23 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
             }
         }
 
+        // sinks
+        if (sinks) {
+            const float s = ((float *)((char *) sinks->data))[h];
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (s > M) {
+                ms = expf(M - s);
+                wsp_ggml_vec_scale_f32(DV, VKQ32, ms);
+            } else {
+                vs = expf(s - M);
+            }
+
+            S = S*ms + vs;
+        }
+
         // V /= S
         const float S_inv = 1.0f/S;
         wsp_ggml_vec_scale_f32(DV, VKQ32, S_inv);
@@ -7906,17 +8406,13 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
 
 void wsp_ggml_compute_forward_flash_attn_ext(
         const wsp_ggml_compute_params * params,
-        const wsp_ggml_tensor * q,
-        const wsp_ggml_tensor * k,
-        const wsp_ggml_tensor * v,
-        const wsp_ggml_tensor * mask,
         wsp_ggml_tensor * dst) {
     switch (dst->op_params[3]) {
         case WSP_GGML_PREC_DEFAULT:
         case WSP_GGML_PREC_F32:
             {
                 // uses F32 accumulators
-                wsp_ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+                wsp_ggml_compute_forward_flash_attn_ext_f16(params, dst);
             } break;
         default:
             {
@@ -8336,120 +8832,210 @@ void wsp_ggml_compute_forward_ssm_conv(
|
|
|
8336
8832
|
static void wsp_ggml_compute_forward_ssm_scan_f32(
|
|
8337
8833
|
const wsp_ggml_compute_params * params,
|
|
8338
8834
|
wsp_ggml_tensor * dst) {
|
|
8339
|
-
const wsp_ggml_tensor * src0 = dst->src[0]; // s
|
|
8340
|
-
const wsp_ggml_tensor * src1 = dst->src[1]; // x
|
|
8341
|
-
const wsp_ggml_tensor * src2 = dst->src[2]; // dt
|
|
8342
|
-
const wsp_ggml_tensor * src3 = dst->src[3]; // A
|
|
8343
|
-
const wsp_ggml_tensor * src4 = dst->src[4]; // B
|
|
8344
|
-
const wsp_ggml_tensor * src5 = dst->src[5]; // C
|
|
8835
|
+
const wsp_ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
|
|
8836
|
+
const wsp_ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
|
|
8837
|
+
const wsp_ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
|
|
8838
|
+
const wsp_ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
|
|
8839
|
+
const wsp_ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
|
|
8840
|
+
const wsp_ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
|
|
8841
|
+
const wsp_ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
|
|
8345
8842
|
|
|
8346
8843
|
const int ith = params->ith;
|
|
8347
8844
|
const int nth = params->nth;
|
|
8348
8845
|
|
|
8349
|
-
const int64_t nc
|
|
8350
|
-
const int64_t nr
|
|
8351
|
-
const int64_t
|
|
8352
|
-
const int64_t
|
|
8846
|
+
const int64_t nc = src0->ne[0]; // d_state
|
|
8847
|
+
const int64_t nr = src0->ne[1]; // dim
|
|
8848
|
+
const int64_t nh = src1->ne[1]; // n_head
|
|
8849
|
+
const int64_t ng = src4->ne[1];
|
|
8850
|
+
const int64_t nt = src1->ne[2]; // number of tokens per sequence
|
|
8851
|
+
const int64_t ns = src1->ne[3]; // number of sequences in the batch
|
|
8353
8852
|
|
|
8354
|
-
|
|
8853
|
+
// can't use wsp_ggml_nbytes because src1 is not necessarily contiguous
|
|
8854
|
+
const int64_t s_off = wsp_ggml_nelements(src1) * wsp_ggml_element_size(src1);
|
|
8855
|
+
|
|
8856
|
+
WSP_GGML_ASSERT(wsp_ggml_nelements(src1) + nc*nr*nh*ns == wsp_ggml_nelements(dst));
|
|
8355
8857
|
WSP_GGML_ASSERT(src0->nb[0] == sizeof(float));
|
|
8356
8858
|
WSP_GGML_ASSERT(src1->nb[0] == sizeof(float));
|
|
8357
8859
|
WSP_GGML_ASSERT(src2->nb[0] == sizeof(float));
|
|
8358
8860
|
WSP_GGML_ASSERT(src3->nb[0] == sizeof(float));
|
|
8359
8861
|
WSP_GGML_ASSERT(src4->nb[0] == sizeof(float));
|
|
8360
8862
|
WSP_GGML_ASSERT(src5->nb[0] == sizeof(float));
|
|
8361
|
-
|
|
8362
|
-
|
|
8363
|
-
|
|
8364
|
-
WSP_GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
|
|
8365
|
-
// required to get correct offset for state destination (i.e. src1->nb[3])
|
|
8366
|
-
WSP_GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
|
|
8863
|
+
WSP_GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
|
|
8864
|
+
// allows optimizing the modulo since n_group should be a power of 2
|
|
8865
|
+
WSP_GGML_ASSERT((ng & -ng) == ng);
|
|
8367
8866
|
|
|
8368
|
-
//
|
|
8369
|
-
const int
|
|
8867
|
+
// heads per thread
|
|
8868
|
+
const int dh = (nh + nth - 1)/nth;
|
|
8370
8869
|
|
|
8371
|
-
//
|
|
8372
|
-
const int
|
|
8373
|
-
const int
|
|
8374
|
-
|
|
8870
|
+
// head range for this thread
|
|
8871
|
+
const int ih0 = dh*ith;
|
|
8872
|
+
const int ih1 = MIN(ih0 + dh, nh);
|
|
8873
|
+
|
|
8874
|
+
const int32_t * ids = (const int32_t *) src6->data;
|
|
8875
|
+
|
|
8876
|
+
for (int i3 = 0; i3 < ns; ++i3) {
|
|
8877
|
+
const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
|
|
8878
|
+
float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
|
|
8879
|
+
|
|
8880
|
+
for (int i2 = 0; i2 < nt; ++i2) {
|
|
8881
|
+
const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
|
|
8882
|
+
const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
|
|
8883
|
+
const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
|
|
8884
|
+
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
|
|
8885
|
+
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
|
|
8886
|
+
float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
|
|
8887
|
+
|
|
8888
|
+
if (src3->ne[0] == 1) {
|
|
8889
|
+
// Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
|
|
8890
|
+
|
|
8891
|
+
// n_head
|
|
8892
|
+
for (int h = ih0; h < ih1; ++h) {
|
|
8893
|
+
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
|
8894
|
+
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
|
|
8895
|
+
const float dA = expf(dt_soft_plus * A[h]);
|
|
8896
|
+
|
|
8897
|
+
// dim
|
|
8898
|
+
for (int i1 = 0; i1 < nr; ++i1) {
|
|
8899
|
+
const int ii = i1 + h*nr;
|
|
8900
|
+
const float x_dt = x[ii] * dt_soft_plus;
|
|
8901
|
+
float sumf = 0.0f;
|
|
8902
|
+
#if defined(WSP_GGML_SIMD)
|
|
8903
|
+
#if defined(__ARM_FEATURE_SVE)
|
|
8904
|
+
const int wsp_ggml_f32_epr = svcntw();
|
|
8905
|
+
const int wsp_ggml_f32_step = 1 * wsp_ggml_f32_epr;
|
|
8906
|
+
|
|
8907
|
+
const int np = (nc & ~(wsp_ggml_f32_step - 1));
|
|
8908
|
+
|
|
8909
|
+
WSP_GGML_F32_VEC sum = WSP_GGML_F32_VEC_ZERO;
|
|
8910
|
+
|
|
8911
|
+
WSP_GGML_F32_VEC adA = WSP_GGML_F32_VEC_SET1(dA);
|
|
8912
|
+
WSP_GGML_F32_VEC axdt = WSP_GGML_F32_VEC_SET1(x_dt);
|
|
8913
|
+
|
|
8914
|
+
for (int i = 0; i < np; i += wsp_ggml_f32_step) {
|
|
8915
|
+
// TODO: maybe unroll more?
|
|
8916
|
+
for (int j = 0; j < 1; j++) {
|
|
8917
|
+
WSP_GGML_F32_VEC t0 = WSP_GGML_F32_VEC_LOAD(s0 + i + j*wsp_ggml_f32_epr + ii*nc);
|
|
8918
|
+
WSP_GGML_F32_VEC t1 = WSP_GGML_F32_VEC_LOAD(B + i + j*wsp_ggml_f32_epr + (h & (ng - 1))*nc);
|
|
8919
|
+
WSP_GGML_F32_VEC t2 = WSP_GGML_F32_VEC_LOAD(C + i + j*wsp_ggml_f32_epr + (h & (ng - 1))*nc);
|
|
8920
|
+
|
|
8921
|
+
t0 = WSP_GGML_F32_VEC_MUL(t0, adA);
|
|
8922
|
+
t1 = WSP_GGML_F32_VEC_MUL(t1, axdt);
|
|
8923
|
+
|
|
8924
|
+
t0 = WSP_GGML_F32_VEC_ADD(t0, t1);
|
|
8925
|
+
|
|
8926
|
+
sum = WSP_GGML_F32_VEC_FMA(sum, t0, t2);
|
|
8927
|
+
|
|
8928
|
+
WSP_GGML_F32_VEC_STORE(s + i + j*wsp_ggml_f32_epr + ii*nc, t0);
|
|
8929
|
+
}
|
|
8930
|
+
}
|
|
8931
|
+
|
|
8932
|
+
sumf = WSP_GGML_F32xt_REDUCE_ONE(sum);
|
|
8933
|
+
#else
|
|
8934
|
+
const int np = (nc & ~(WSP_GGML_F32_STEP - 1));
|
|
8935
|
+
|
|
8936
|
+
WSP_GGML_F32_VEC sum[WSP_GGML_F32_ARR] = { WSP_GGML_F32_VEC_ZERO };
|
|
8937
|
+
|
|
8938
|
+
WSP_GGML_F32_VEC adA = WSP_GGML_F32_VEC_SET1(dA);
|
|
8939
|
+
WSP_GGML_F32_VEC axdt = WSP_GGML_F32_VEC_SET1(x_dt);
|
|
8940
|
+
|
|
8941
|
+
WSP_GGML_F32_VEC ax[WSP_GGML_F32_ARR];
|
|
8942
|
+
WSP_GGML_F32_VEC ay[WSP_GGML_F32_ARR];
|
|
8943
|
+
WSP_GGML_F32_VEC az[WSP_GGML_F32_ARR];
|
|
8944
|
+
|
|
8945
|
+
for (int i = 0; i < np; i += WSP_GGML_F32_STEP) {
|
|
8946
|
+
for (int j = 0; j < WSP_GGML_F32_ARR; j++) {
|
|
8947
|
+
ax[j] = WSP_GGML_F32_VEC_LOAD(s0 + i + j*WSP_GGML_F32_EPR + ii*nc);
|
|
8948
|
+
ay[j] = WSP_GGML_F32_VEC_LOAD(B + i + j*WSP_GGML_F32_EPR + (h & (ng - 1))*nc);
|
|
8949
|
+
az[j] = WSP_GGML_F32_VEC_LOAD(C + i + j*WSP_GGML_F32_EPR + (h & (ng - 1))*nc);
|
|
8950
|
+
|
|
8951
|
+
ax[j] = WSP_GGML_F32_VEC_MUL(ax[j], adA);
|
|
8952
|
+
ay[j] = WSP_GGML_F32_VEC_MUL(ay[j], axdt);
|
|
8953
|
+
|
|
8954
|
+
ax[j] = WSP_GGML_F32_VEC_ADD(ax[j], ay[j]);
|
|
8955
|
+
|
|
8956
|
+
sum[j] = WSP_GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
|
|
8957
|
+
|
|
8958
|
+
WSP_GGML_F32_VEC_STORE(s + i + j*WSP_GGML_F32_EPR + ii*nc, ax[j]);
|
|
8959
|
+
}
|
|
8960
|
+
}
|
|
8375
8961
|
|
|
8376
|
-
|
|
8377
|
-
|
|
8378
|
-
|
|
8379
|
-
|
|
8380
|
-
|
|
8381
|
-
|
|
8382
|
-
|
|
8383
|
-
|
|
8384
|
-
|
|
8385
|
-
|
|
8386
|
-
|
|
8387
|
-
|
|
8388
|
-
|
|
8389
|
-
|
|
8390
|
-
|
|
8391
|
-
|
|
8392
|
-
|
|
8393
|
-
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
|
|
8394
|
-
float x_dt = x[i1] * dt_soft_plus;
|
|
8395
|
-
svfloat32_t vx_dt = WSP_GGML_F32_VEC_SET1(x_dt);
|
|
8396
|
-
svfloat32_t vdt_soft_plus = WSP_GGML_F32_VEC_SET1(dt_soft_plus);
|
|
8397
|
-
svfloat32_t r1_vector = WSP_GGML_F32_VEC_ZERO;
|
|
8398
|
-
|
|
8399
|
-
for (int64_t k = 0; k < nc; k += svcntw()) {
|
|
8400
|
-
svfloat32_t vA = WSP_GGML_F32_VEC_LOAD(&A[i1*nc + k]);
|
|
8401
|
-
svfloat32_t vB = WSP_GGML_F32_VEC_LOAD(&B[k]);
|
|
8402
|
-
svfloat32_t vC = WSP_GGML_F32_VEC_LOAD(&C[k]);
|
|
8403
|
-
svfloat32_t vs0 = WSP_GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
|
|
8404
|
-
|
|
8405
|
-
svfloat32_t t1 = WSP_GGML_F32_VEC_MUL(vdt_soft_plus, vA);
|
|
8406
|
-
t1 = exp_ps_sve(svptrue_b32(), t1);
|
|
8407
|
-
svfloat32_t t2 = WSP_GGML_F32_VEC_MUL(vx_dt, vB);
|
|
8408
|
-
|
|
8409
|
-
vs0 = WSP_GGML_F32_VEC_FMA(vs0, t1, t2);
|
|
8410
|
-
r1_vector = WSP_GGML_F32_VEC_ADD(WSP_GGML_F32_VEC_MUL(vs0, vC), r1_vector);
|
|
8411
|
-
|
|
8412
|
-
WSP_GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
|
|
8962
|
+
// reduce sum0..sum3 to sum0
|
|
8963
|
+
WSP_GGML_F32_VEC_REDUCE(sumf, sum);
|
|
8964
|
+
#endif
|
|
8965
|
+
#else
|
|
8966
|
+
const int np = 0;
|
|
8967
|
+
#endif
|
|
8968
|
+
// d_state
|
|
8969
|
+
for (int i0 = np; i0 < nc; ++i0) {
|
|
8970
|
+
const int i = i0 + ii*nc;
|
|
8971
|
+
const int ig = i0 + (h & (ng - 1))*nc;
|
|
8972
|
+
// state = prev_state * dA + dB * x
|
|
8973
|
+
const float state = (s0[i] * dA) + (B[ig] * x_dt);
|
|
8974
|
+
// y = rowwise_dotprod(state, C)
|
|
8975
|
+
sumf += state * C[ig];
|
|
8976
|
+
s[i] = state;
|
|
8977
|
+
}
|
|
8978
|
+
y[ii] = sumf;
|
|
8413
8979
|
}
|
|
8414
|
-
y[i1] = WSP_GGML_F32xt_REDUCE_ONE(r1_vector);
|
|
8415
8980
|
}
|
|
8416
|
-
}
|
|
8417
|
-
|
|
8418
|
-
|
|
8419
|
-
|
|
8420
|
-
|
|
8421
|
-
|
|
8422
|
-
|
|
8423
|
-
|
|
8424
|
-
|
|
8425
|
-
|
|
8426
|
-
|
|
8427
|
-
|
|
8428
|
-
|
|
8429
|
-
|
|
8430
|
-
|
|
8431
|
-
|
|
8432
|
-
|
|
8433
|
-
|
|
8434
|
-
|
|
8435
|
-
|
|
8436
|
-
|
|
8437
|
-
|
|
8438
|
-
|
|
8439
|
-
|
|
8440
|
-
|
|
8441
|
-
|
|
8442
|
-
|
|
8443
|
-
|
|
8444
|
-
|
|
8445
|
-
|
|
8446
|
-
|
|
+        } else {
+            // Mamba-1 has an element-wise decay factor for the states
+
+            // n_head
+            for (int h = ih0; h < ih1; ++h) {
+                // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+                const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+
+                // dim
+                for (int i1 = 0; i1 < nr; ++i1) {
+                    const int ii = i1 + h*nr;
+                    const float x_dt = x[ii] * dt_soft_plus;
+        #if defined(__ARM_FEATURE_SVE)
+                    svfloat32_t vx_dt = WSP_GGML_F32_VEC_SET1(x_dt);
+                    svfloat32_t vdt_soft_plus = WSP_GGML_F32_VEC_SET1(dt_soft_plus);
+                    svfloat32_t r1_vector = WSP_GGML_F32_VEC_ZERO;
+
+                    // d_state
+                    // TODO: what happens when (d_state % svcntw()) != 0?
+                    for (int64_t k = 0; k < nc; k += svcntw()) {
+                        svfloat32_t vA = WSP_GGML_F32_VEC_LOAD(&A[h*nc + k]);
+                        svfloat32_t vB = WSP_GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
+                        svfloat32_t vC = WSP_GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
+                        svfloat32_t vs0 = WSP_GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
+
+                        svfloat32_t t1 = WSP_GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+                        t1 = exp_ps_sve(svptrue_b32(), t1);
+                        svfloat32_t t2 = WSP_GGML_F32_VEC_MUL(vx_dt, vB);
+
+                        vs0 = WSP_GGML_F32_VEC_FMA(t2, vs0, t1);
+                        r1_vector = WSP_GGML_F32_VEC_ADD(WSP_GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+                        WSP_GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
+                    }
+                    y[ii] = WSP_GGML_F32xt_REDUCE_ONE(r1_vector);
+        #else
+                    float sumf = 0.0f;
+                    // NOTE: can't really use WSP_GGML_SIMD here because d_state is usually 16
+                    //       and also because expf is used within the loop.
+                    // d_state
+                    for (int i0 = 0; i0 < nc; ++i0) {
+                        const int i  = i0 + ii*nc;
+                        const int ig = i0 + (h & (ng - 1))*nc;
+                        // state = prev_state * dA + dB * x
+                        const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
+                        // y = rowwise_dotprod(state, C)
+                        sumf += state * C[ig];
+                        s[i] = state;
+                    }
+                    y[ii] = sumf;
+        #endif
                 }
-                    y[i1] = sumf;
             }
         }
+        // use the output as the source when it's not the first token-wise iteration
+        s0 = s;
     }
-
+    }
 }
 
 void wsp_ggml_compute_forward_ssm_scan(
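The dt_soft_plus expression in the Mamba-1 branch is the clamped softplus referenced in the comment above: for large inputs softplus(x) is numerically indistinguishable from x (the difference is about exp(-x), roughly 2e-9 at x = 20), so the exact log1pf(expf(x)) is only evaluated where expf cannot blow up the result. A small restatement of that expression, with the helper name chosen here for illustration:

    // Clamped softplus as used for dt in the scan kernel above.
    #include <cmath>

    static float softplus_clamped(float x) {
        // softplus(x) = log(1 + exp(x)); for x > 20 the identity branch is accurate to float precision
        return x <= 20.0f ? log1pf(expf(x)) : x;
    }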
@@ -8688,6 +9274,18 @@ void wsp_ggml_compute_forward_glu(
             {
                 wsp_ggml_compute_forward_swiglu(params, dst);
             } break;
+        case WSP_GGML_GLU_OP_SWIGLU_OAI:
+            {
+                wsp_ggml_compute_forward_swiglu_oai(params, dst);
+            } break;
+        case WSP_GGML_GLU_OP_GEGLU_ERF:
+            {
+                wsp_ggml_compute_forward_geglu_erf(params, dst);
+            } break;
+        case WSP_GGML_GLU_OP_GEGLU_QUICK:
+            {
+                wsp_ggml_compute_forward_geglu_quick(params, dst);
+            } break;
         default:
             {
                 WSP_GGML_ABORT("fatal error");
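The three new cases route the additional gated-linear-unit ops to their kernels. As a rough orientation only, and assuming the commonly used definitions rather than the exact ggml kernels (operand order and the swiglu_oai variant in particular may differ), the gated activations look approximately like this:

    // Assumed textbook forms of the GLU variants dispatched above; not the library's kernels.
    #include <cmath>

    static float silu(float x)       { return x / (1.0f + expf(-x)); }                 // x * sigmoid(x)
    static float gelu_erf(float x)   { return 0.5f * x * (1.0f + erff(x / sqrtf(2.0f))); }
    static float gelu_quick(float x) { return x / (1.0f + expf(-1.702f * x)); }          // sigmoid approximation

    static float swiglu(float a, float b)      { return silu(a)       * b; } // WSP_GGML_GLU_OP_SWIGLU
    static float geglu_erf(float a, float b)   { return gelu_erf(a)   * b; } // WSP_GGML_GLU_OP_GEGLU_ERF
    static float geglu_quick(float a, float b) { return gelu_quick(a) * b; } // WSP_GGML_GLU_OP_GEGLU_QUICK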
@@ -9732,6 +10330,7 @@ static void wsp_ggml_compute_forward_opt_step_adamw_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     const float * adamw_params_ptr = wsp_ggml_get_data_f32(adamw_params);
+
     const float alpha  = adamw_params_ptr[0];
     const float beta1  = adamw_params_ptr[1];
     const float beta2  = adamw_params_ptr[2];
@@ -9739,7 +10338,7 @@ static void wsp_ggml_compute_forward_opt_step_adamw_f32(
     const float wd     = adamw_params_ptr[4];
     const float beta1h = adamw_params_ptr[5];
     const float beta2h = adamw_params_ptr[6];
-
+    const float keep   = 1.f - alpha * wd;
     for (int ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
         const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -9762,7 +10361,7 @@ static void wsp_ggml_compute_forward_opt_step_adamw_f32(
             // The weight decay is applied independently of the Adam momenta m and v.
             // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
             // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*
+            w[i00] = w[i00] * keep - alpha * mh / vh;
         }
     }
 }
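The AdamW change hoists the decoupled weight-decay multiplier (1 - alpha*wd) out of the inner loop as a precomputed keep factor, so each element update is one multiply plus the Adam step. A sketch of the resulting per-weight update, where mh and vh stand for the bias-corrected moment terms already computed by the kernel:

    // Decoupled weight decay (AdamW) with the decay factor precomputed, as in the diff.
    static inline float adamw_update(float w, float mh, float vh, float alpha, float wd) {
        const float keep = 1.f - alpha * wd; // hoisted out of the element loop
        return w * keep - alpha * mh / vh;
    }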
@@ -9784,3 +10383,63 @@ void wsp_ggml_compute_forward_opt_step_adamw(
         }
     }
 }
+
+static void wsp_ggml_compute_forward_opt_step_sgd_f32(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    const wsp_ggml_tensor * src0       = dst->src[0];
+    const wsp_ggml_tensor * src0_grad  = dst->src[1];
+    const wsp_ggml_tensor * sgd_params = dst->src[2];
+
+    WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, src0_grad));
+    WSP_GGML_ASSERT(wsp_ggml_nelements(sgd_params) == 2);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = wsp_ggml_nrows(src0);
+
+    WSP_GGML_TENSOR_UNARY_OP_LOCALS
+    WSP_GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1) / nth;
+
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // using adamw param subset we care about - alpha, wd - could have a separate struct
+    const float * sgd_params_ptr = wsp_ggml_get_data_f32(sgd_params);
+    const float alpha = sgd_params_ptr[0];
+    const float keep  = 1.f - alpha * sgd_params_ptr[1];
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir / (ne02 * ne01);
+        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
+        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
+
+        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
+
+        float       * w = (float       *) ((char       *) src0->data      + offset); // weight
+        const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            w[i00] = w[i00] * keep - alpha * g[i00];
+        }
+    }
+}
+
+void wsp_ggml_compute_forward_opt_step_sgd(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    const wsp_ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case WSP_GGML_TYPE_F32:
+            {
+                wsp_ggml_compute_forward_opt_step_sgd_f32(params, dst);
+            }
+            break;
+        default:
+            {
+                WSP_GGML_ABORT("fatal error - sgd is F32 only");
+            }
+    }
+}
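The new SGD kernel follows the same convention as the AdamW step: the weight-decay term is folded into a single keep multiplier and the row range is split across threads identically. A minimal sketch of the per-element update it performs; the helper name and flat array layout here are illustrative only:

    // Plain SGD with decoupled weight decay, matching the element update in the kernel above.
    // sgd_params in the kernel holds { alpha, wd }, as asserted by WSP_GGML_ASSERT(nelements == 2).
    static void sgd_step(float * w, const float * g, int n, float alpha, float wd) {
        const float keep = 1.f - alpha * wd;
        for (int i = 0; i < n; ++i) {
            w[i] = w[i] * keep - alpha * g[i];
        }
    }

The public entry point wsp_ggml_compute_forward_opt_step_sgd dispatches on src0->type and aborts for anything other than F32, mirroring the AdamW dispatcher.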