whisper.rn 0.5.1 → 0.5.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/jni.cpp +12 -3
- package/cpp/ggml-alloc.c +49 -18
- package/cpp/ggml-backend-impl.h +0 -3
- package/cpp/ggml-backend-reg.cpp +8 -0
- package/cpp/ggml-backend.cpp +0 -2
- package/cpp/ggml-backend.h +2 -0
- package/cpp/ggml-cpu/amx/amx.cpp +1 -0
- package/cpp/ggml-cpu/arch/arm/quants.c +428 -26
- package/cpp/ggml-cpu/ggml-cpu-impl.h +4 -2
- package/cpp/ggml-cpu/ggml-cpu.c +67 -24
- package/cpp/ggml-cpu/ops.cpp +489 -364
- package/cpp/ggml-cpu/ops.h +4 -4
- package/cpp/ggml-cpu/repack.cpp +143 -29
- package/cpp/ggml-cpu/simd-mappings.h +25 -25
- package/cpp/ggml-cpu/unary-ops.cpp +151 -0
- package/cpp/ggml-cpu/unary-ops.h +7 -0
- package/cpp/ggml-cpu/vec.cpp +83 -0
- package/cpp/ggml-cpu/vec.h +20 -8
- package/cpp/ggml-impl.h +67 -2
- package/cpp/ggml-metal/ggml-metal-common.cpp +2 -2
- package/cpp/ggml-metal/ggml-metal-context.m +5 -6
- package/cpp/ggml-metal/ggml-metal-device.cpp +300 -14
- package/cpp/ggml-metal/ggml-metal-device.h +26 -1
- package/cpp/ggml-metal/ggml-metal-device.m +243 -28
- package/cpp/ggml-metal/ggml-metal-impl.h +177 -9
- package/cpp/ggml-metal/ggml-metal-ops.cpp +843 -157
- package/cpp/ggml-metal/ggml-metal-ops.h +8 -0
- package/cpp/ggml-metal/ggml-metal.cpp +8 -3
- package/cpp/ggml-metal/ggml-metal.metal +12436 -0
- package/cpp/ggml.c +317 -4
- package/cpp/ggml.h +139 -0
- package/cpp/jsi/RNWhisperJSI.cpp +7 -2
- package/cpp/rn-whisper.h +1 -0
- package/cpp/whisper.cpp +8 -2
- package/ios/RNWhisperContext.mm +3 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +2 -0
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNWhisper.ts +2 -0
- package/src/version.json +1 -1
- package/whisper-rn.podspec +1 -1
- package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
package/cpp/ggml-cpu/ops.cpp
CHANGED
@@ -7,8 +7,10 @@
 #include "unary-ops.h"
 #include "vec.h"

-#include <
+#include <cfloat>
 #include <algorithm>
+#include <cmath>
+#include <functional>

 // wsp_ggml_compute_forward_dup

@@ -1394,6 +1396,56 @@ void wsp_ggml_compute_forward_sum(
     }
 }

+// wsp_ggml_compute_forward_cumsum
+
+static void wsp_ggml_compute_forward_cumsum_f32(
+        const wsp_ggml_compute_params * params,
+        wsp_ggml_tensor * dst) {
+
+    const wsp_ggml_tensor * src0 = dst->src[0];
+
+    WSP_GGML_ASSERT(src0->nb[0] == sizeof(float));
+    WSP_GGML_ASSERT(dst->nb[0]  == sizeof(float));
+
+    WSP_GGML_TENSOR_UNARY_OP_LOCALS
+
+    WSP_GGML_ASSERT(ne0 == ne00);
+    WSP_GGML_ASSERT(ne1 == ne01);
+    WSP_GGML_ASSERT(ne2 == ne02);
+    WSP_GGML_ASSERT(ne3 == ne03);
+
+    const auto [ir0, ir1] = get_thread_range(params, src0);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+        float * dst_row = (float *) ((char *) dst->data  + i01*nb1  + i02*nb2  + i03*nb3);
+
+        wsp_ggml_vec_cumsum_f32(ne00, dst_row, src_row);
+    }
+}
+
+void wsp_ggml_compute_forward_cumsum(
+        const wsp_ggml_compute_params * params,
+        wsp_ggml_tensor * dst) {
+
+    const wsp_ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case WSP_GGML_TYPE_F32:
+            {
+                wsp_ggml_compute_forward_cumsum_f32(params, dst);
+            } break;
+        default:
+            {
+                WSP_GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // wsp_ggml_compute_forward_sum_rows

 static void wsp_ggml_compute_forward_sum_rows_f32(
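Note: the per-row work above is delegated to wsp_ggml_vec_cumsum_f32, which this release adds in package/cpp/ggml-cpu/vec.h / vec.cpp (changed in this version but not shown in this excerpt). A minimal scalar sketch of what such a helper is assumed to do, for reference only:

    // assumed reference behaviour of the per-row cumulative sum helper;
    // the shipped version in vec.h may be vectorized
    static void cumsum_f32_ref(const int64_t n, float * dst, const float * src) {
        float acc = 0.0f;
        for (int64_t i = 0; i < n; ++i) {
            acc += src[i];  // running total along the row
            dst[i] = acc;
        }
    }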
@@ -2140,6 +2192,83 @@ static void wsp_ggml_compute_forward_gelu(
     }
 }

+// wsp_ggml_compute_fill
+
+static void wsp_ggml_compute_forward_fill_f32(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    const float c = wsp_ggml_get_op_params_f32(dst, 0);
+
+    WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
+    WSP_GGML_TENSOR_LOCALS(size_t,  nb, dst, nb);
+
+    const auto [ir0, ir1] = get_thread_range(params, dst);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne2*ne1);
+        const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
+        const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
+
+        float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
+
+        wsp_ggml_vec_set_f32(ne0, dst_ptr, c);
+    }
+}
+
+void wsp_ggml_compute_forward_fill(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    wsp_ggml_compute_forward_fill_f32(params, dst);
+}
+
+// wsp_ggml_compute_tri
+
+static void wsp_ggml_compute_forward_tri_f32(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    const wsp_ggml_tensor * src0 = dst->src[0];
+
+    const wsp_ggml_tri_type ttype = (wsp_ggml_tri_type) wsp_ggml_get_op_params_i32(dst, 0);
+
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
+
+    WSP_GGML_TENSOR_UNARY_OP_LOCALS
+
+    const auto [ir0, ir1] = get_thread_range(params, src0);
+
+    bool (*bipred)(int, int);
+
+    switch (ttype) {
+        case WSP_GGML_TRI_TYPE_LOWER:      bipred = [](int i, int r) { return i <  r; }; break;
+        case WSP_GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
+        case WSP_GGML_TRI_TYPE_UPPER:      bipred = [](int i, int r) { return i >  r; }; break;
+        case WSP_GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
+        default: WSP_GGML_ABORT("invalid tri type");
+    }
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        const float * src_ptr = (const float *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+              float * dst_ptr = (      float *) ((      char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1);
+
+        for (int i0 = 0; i0 < ne0; ++i0) {
+            dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
+        }
+    }
+}
+
+void wsp_ggml_compute_forward_tri(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    const wsp_ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case WSP_GGML_TYPE_F32:
+            {
+                wsp_ggml_compute_forward_tri_f32(params, dst);
+            } break;
+        default:
+            {
+                WSP_GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // wsp_ggml_compute_forward_gelu_erf

 static void wsp_ggml_compute_forward_gelu_erf_f32(
@@ -3467,31 +3596,27 @@ static void wsp_ggml_compute_forward_norm_f32(

     WSP_GGML_ASSERT(eps >= 0.0f);

-    // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
                 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

-
-
-                    sum += (wsp_ggml_float)x[i00];
-                }
-
+                float sum = 0.0;
+                wsp_ggml_vec_sum_f32(ne00, &sum, x);
                 float mean = sum/ne00;

                 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+                float variance = 0;

-
-
-
-
-
-
+#ifdef WSP_GGML_USE_ACCELERATE
+                mean = -mean;
+                vDSP_vsadd(x, 1, &mean, y, 1, ne00);
+                vDSP_measqv(y, 1, &variance, ne00);
+#else
+                variance = wsp_ggml_vec_cvar_f32(ne00, y, x, mean);
+#endif //WSP_GGML_USE_ACCELERATE

-                float variance = sum2/ne00;
                 const float scale = 1.0f/sqrtf(variance + eps);
-
                 wsp_ggml_vec_scale_f32(ne00, y, scale);
             }
         }
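Note: the rewritten norm kernel leans on wsp_ggml_vec_sum_f32 plus a new wsp_ggml_vec_cvar_f32 helper (added in vec.cpp / vec.h, not shown here). Judging from the call site, the helper is assumed to center the row into y and return the population variance; a scalar sketch under that assumption:

    // assumed semantics of wsp_ggml_vec_cvar_f32: y[i] = x[i] - mean,
    // return value = sum((x[i] - mean)^2) / n; the real helper is SIMD-optimized
    static float cvar_f32_ref(const int64_t n, float * y, const float * x, const float mean) {
        float sum2 = 0.0f;
        for (int64_t i = 0; i < n; ++i) {
            const float v = x[i] - mean;
            y[i]  = v;
            sum2 += v*v;
        }
        return sum2 / n;
    }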
@@ -4459,46 +4584,6 @@ void wsp_ggml_compute_forward_cont(
     wsp_ggml_compute_forward_dup(params, dst);
 }

-// wsp_ggml_compute_forward_reshape
-
-void wsp_ggml_compute_forward_reshape(
-        const wsp_ggml_compute_params * params,
-        wsp_ggml_tensor * dst) {
-    // NOP
-    WSP_GGML_UNUSED(params);
-    WSP_GGML_UNUSED(dst);
-}
-
-// wsp_ggml_compute_forward_view
-
-void wsp_ggml_compute_forward_view(
-        const wsp_ggml_compute_params * params,
-        wsp_ggml_tensor * dst) {
-    // NOP
-    WSP_GGML_UNUSED(params);
-    WSP_GGML_UNUSED(dst);
-}
-
-// wsp_ggml_compute_forward_permute
-
-void wsp_ggml_compute_forward_permute(
-        const wsp_ggml_compute_params * params,
-        wsp_ggml_tensor * dst) {
-    // NOP
-    WSP_GGML_UNUSED(params);
-    WSP_GGML_UNUSED(dst);
-}
-
-// wsp_ggml_compute_forward_transpose
-
-void wsp_ggml_compute_forward_transpose(
-        const wsp_ggml_compute_params * params,
-        wsp_ggml_tensor * dst) {
-    // NOP
-    WSP_GGML_UNUSED(params);
-    WSP_GGML_UNUSED(dst);
-}
-
 // wsp_ggml_compute_forward_get_rows

 static void wsp_ggml_compute_forward_get_rows_q(
@@ -5478,7 +5563,7 @@ static void wsp_ggml_rope_cache_init(
 }

 static void wsp_ggml_mrope_cache_init(
-    float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
+    float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
     float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
     float * cache, float sin_sign, float theta_scale) {
     // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
@@ -5513,14 +5598,26 @@ static void wsp_ggml_mrope_cache_init(
         }

         float theta = theta_t;
-        if (
-
-
-
-
-
-
-
+        if (is_imrope) { // qwen3vl apply interleaved mrope
+            if (sector % 3 == 1 && sector < 3 * sections[1]) {
+                theta = theta_h;
+            } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
+                theta = theta_w;
+            } else if (sector % 3 == 0 && sector < 3 * sections[0]) {
+                theta = theta_t;
+            } else {
+                theta = theta_e;
+            }
+        } else {
+            if (sector >= sections[0] && sector < sec_w) {
+                theta = theta_h;
+            }
+            else if (sector >= sec_w && sector < sec_w + sections[2]) {
+                theta = theta_w;
+            }
+            else if (sector >= sec_w + sections[2]) {
+                theta = theta_e;
+            }
         }

         rope_yarn(
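Note: with the interleaved variant the rotary sections are walked round-robin over the pair index instead of in contiguous blocks. A small standalone sketch of the same selection rule, handy for checking which component a given sector picks (the sections values below are made up for illustration):

    #include <cstdio>

    int main() {
        const int sections[4] = {4, 3, 3, 0}; // illustrative values only
        for (int sector = 0; sector < 16; ++sector) {
            char which = 'e'; // theta_e is the fallback
            if      (sector % 3 == 1 && sector < 3 * sections[1]) which = 'h';
            else if (sector % 3 == 2 && sector < 3 * sections[2]) which = 'w';
            else if (sector % 3 == 0 && sector < 3 * sections[0]) which = 't';
            printf("sector %2d -> theta_%c\n", sector, which);
        }
        return 0;
    }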
@@ -5535,193 +5632,28 @@ static void wsp_ggml_mrope_cache_init(
     }
 }

-static void wsp_ggml_compute_forward_rope_f32(
-        const wsp_ggml_compute_params * params,
-        wsp_ggml_tensor * dst,
-        const bool forward) {
-
-    const wsp_ggml_tensor * src0 = dst->src[0];
-    const wsp_ggml_tensor * src1 = dst->src[1];
-    const wsp_ggml_tensor * src2 = dst->src[2];
-
-    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-    int sections[4];
-
-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode = ((int32_t *) dst->op_params)[2];
-    //const int n_ctx = ((int32_t *) dst->op_params)[3];
-    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];

-
-
-
-
-    memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
-    memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
-    memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
-
-    WSP_GGML_TENSOR_UNARY_OP_LOCALS
-
-    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
-    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
-
-    WSP_GGML_ASSERT(nb00 == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr = wsp_ggml_nrows(dst);
-
-    WSP_GGML_ASSERT(n_dims <= ne0);
-    WSP_GGML_ASSERT(n_dims % 2 == 0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    // row index used to determine which thread to use
-    int ir = 0;
-
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
-    float corr_dims[2];
-    wsp_ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
-
-    const bool is_neox = mode & WSP_GGML_ROPE_TYPE_NEOX;
-    const bool is_mrope = mode & WSP_GGML_ROPE_TYPE_MROPE; // wsp_ggml_rope_multi, multimodal rotary position embedding
-    const bool is_vision = mode == WSP_GGML_ROPE_TYPE_VISION;
-
-    if (is_mrope) {
-        WSP_GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
-    }
-
-    if (is_vision) {
-        WSP_GGML_ASSERT(n_dims == ne0/2);
-    }
-
-    const float * freq_factors = NULL;
-    if (src2 != NULL) {
-        WSP_GGML_ASSERT(src2->type == WSP_GGML_TYPE_F32);
-        WSP_GGML_ASSERT(src2->ne[0] >= n_dims / 2);
-        freq_factors = (const float *) src2->data;
-    }
-
-    // backward process uses inverse rotation by cos and sin.
-    // cos and sin build a rotation matrix, where the inverse is the transpose.
-    // this essentially just switches the sign of sin.
-    const float sin_sign = forward ? 1.0f : -1.0f;
-
-    const int32_t * pos = (const int32_t *) src1->data;
-
-    for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
-        for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
-
-            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            if (!is_mrope) {
-                const int64_t p = pos[i2];
-                wsp_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
-            else {
-                const int64_t p_t = pos[i2];
-                const int64_t p_h = pos[i2 + ne2];
-                const int64_t p_w = pos[i2 + ne2 * 2];
-                const int64_t p_e = pos[i2 + ne2 * 3];
-                wsp_ggml_mrope_cache_init(
-                    p_t, p_h, p_w, p_e, sections, is_vision,
-                    freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
-
-            for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
-                if (ir++ < ir0) continue;
-                if (ir > ir1) break;
-
-                if (is_neox || is_mrope) {
-                    if (is_vision){
-                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                            const int64_t ic = i0/2;
-
-                            const float cos_theta = cache[i0 + 0];
-                            const float sin_theta = cache[i0 + 1];
-
-                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                            float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
-                            const float x0 = src[0];
-                            const float x1 = src[n_dims];
-
-                            dst_data[0] = x0*cos_theta - x1*sin_theta;
-                            dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
-                        }
-                    } else {
-                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                            const int64_t ic = i0/2;
+template<typename T>
+static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
+    for (int64_t i0 = 0; i0 < n; i0 += 2) {
+        const int64_t ic = i0/scale; // hack for WSP_GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2

-
-
+        const float cos_theta = cache[i0 + 0];
+        const float sin_theta = cache[i0 + 1];

-
-
+        const T * const src = src_data + ic;
+        T * dst = dst_data + ic;

-
-
+        const float x0 = type_conversion_table<T>::to_f32(src[0]);
+        const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);

-
-
-
-                        }
-                    } else {
-                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                            const float cos_theta = cache[i0 + 0];
-                            const float sin_theta = cache[i0 + 1];
-
-                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                            float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-                            const float x0 = src[0];
-                            const float x1 = src[1];
-
-                            dst_data[0] = x0*cos_theta - x1*sin_theta;
-                            dst_data[1] = x0*sin_theta + x1*cos_theta;
-                        }
-                    }
-
-                    if (is_vision) {
-                        for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                            const int64_t ic = i0/2;
-
-                            const float cos_theta = cache[i0 + 0];
-                            const float sin_theta = cache[i0 + 1];
-
-                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                            float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
-                            const float x0 = src[0];
-                            const float x1 = src[n_dims];
-
-                            dst_data[0] = x0*cos_theta - x1*sin_theta;
-                            dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
-                        }
-                    } else {
-                        // fill the remain channels with data from src tensor
-                        for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                            float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-                            dst_data[0] = src[0];
-                            dst_data[1] = src[1];
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
+        dst[0] = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
+        dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
+    }
 }

-
-static void
+template<typename T> //float or wsp_ggml_fp16_t
+static void wsp_ggml_compute_forward_rope_flt(
     const wsp_ggml_compute_params * params,
     wsp_ggml_tensor * dst,
     const bool forward) {
@@ -5730,6 +5662,9 @@ static void wsp_ggml_compute_forward_rope_f16(
     const wsp_ggml_tensor * src1 = dst->src[1];
     const wsp_ggml_tensor * src2 = dst->src[2];

+    WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32 || src0->type == WSP_GGML_TYPE_F16);
+    WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_I32);
+
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
     int sections[4];

@@ -5738,6 +5673,7 @@ static void wsp_ggml_compute_forward_rope_f16(
     const int mode = ((int32_t *) dst->op_params)[2];
     //const int n_ctx = ((int32_t *) dst->op_params)[3];
     const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
     memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
     memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
     memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@@ -5746,13 +5682,13 @@ static void wsp_ggml_compute_forward_rope_f16(
     memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
     memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);

-
     WSP_GGML_TENSOR_UNARY_OP_LOCALS

     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);

-    WSP_GGML_ASSERT(nb0 ==
+    WSP_GGML_ASSERT(nb0 == nb00);
+    WSP_GGML_ASSERT(nb0 == sizeof(T));

     const int ith = params->ith;
     const int nth = params->nth;
@@ -5777,11 +5713,11 @@ static void wsp_ggml_compute_forward_rope_f16(
     float corr_dims[2];
     wsp_ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);

-    const bool
-    const bool
+    const bool is_imrope = mode == WSP_GGML_ROPE_TYPE_IMROPE;  // qwen3vl apply interleaved mrope
+    const bool mrope_used = mode & WSP_GGML_ROPE_TYPE_MROPE;   // wsp_ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
     const bool is_vision = mode == WSP_GGML_ROPE_TYPE_VISION;

-    if (
+    if (mrope_used) {
         WSP_GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
     }

@@ -5803,11 +5739,11 @@ static void wsp_ggml_compute_forward_rope_f16(

     const int32_t * pos = (const int32_t *) src1->data;

-    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = 0; i2 < ne2; i2++) {
+    for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
+        for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len

             float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            if (!
+            if (!mrope_used) {
                 const int64_t p = pos[i2];
                 wsp_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
             }
@@ -5817,90 +5753,44 @@ static void wsp_ggml_compute_forward_rope_f16(
                 const int64_t p_w = pos[i2 + ne2 * 2];
                 const int64_t p_e = pos[i2 + ne2 * 3];
                 wsp_ggml_mrope_cache_init(
-                    p_t, p_h, p_w, p_e, sections, is_vision,
+                    p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
                     freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
             }

-            for (int64_t i1 = 0; i1 < ne1; i1++) {
+            for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
                 if (ir++ < ir0) continue;
                 if (ir > ir1) break;

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    } else {
-                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                            const int64_t ic = i0/2;
-
-                            const float cos_theta = cache[i0 + 0];
-                            const float sin_theta = cache[i0 + 1];
-
-                            const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                            wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
-                            const float x0 = WSP_GGML_CPU_FP16_TO_FP32(src[0]);
-                            const float x1 = WSP_GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
-
-                            dst_data[0] = WSP_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                            dst_data[n_dims/2] = WSP_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                        }
-                    }
-                } else {
-                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                        wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-                        const float x0 = WSP_GGML_CPU_FP16_TO_FP32(src[0]);
-                        const float x1 = WSP_GGML_CPU_FP16_TO_FP32(src[1]);
-
-                        dst_data[0] = WSP_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[1] = WSP_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                    }
+                T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+
+                switch (mode) {
+                    case WSP_GGML_ROPE_TYPE_NORMAL:
+                        rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
+                        break;
+                    case WSP_GGML_ROPE_TYPE_NEOX:
+                    case WSP_GGML_ROPE_TYPE_MROPE:
+                    case WSP_GGML_ROPE_TYPE_IMROPE:
+                        rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
+                        break;
+                    case WSP_GGML_ROPE_TYPE_VISION:
+                        rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
+                        break;
+                    default:
+                        WSP_GGML_ABORT("rope type not supported");
                 }

-                if (is_vision) {
-
-                        const int64_t ic = i0/2;
-
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                        wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
-                        const float x0 = WSP_GGML_CPU_FP16_TO_FP32(src[0]);
-                        const float x1 = WSP_GGML_CPU_FP16_TO_FP32(src[n_dims]);
-
-                        dst_data[0] = WSP_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[n_dims] = WSP_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                    }
-                } else {
+
+                if (!is_vision) {
+                    // fill the remain channels with data from src tensor
                     for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                        const
-
+                        const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                        T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

                         dst_data[0] = src[0];
                         dst_data[1] = src[1];
                     }
                 }
-            }
+            } //attn-heads
         }
     }
 }
@@ -5914,11 +5804,11 @@ void wsp_ggml_compute_forward_rope(
     switch (src0->type) {
         case WSP_GGML_TYPE_F16:
             {
-
+                wsp_ggml_compute_forward_rope_flt<wsp_ggml_fp16_t>(params, dst, true);
             } break;
         case WSP_GGML_TYPE_F32:
             {
-
+                wsp_ggml_compute_forward_rope_flt<float>(params, dst, true);
             } break;
         default:
             {
@@ -5938,11 +5828,11 @@ void wsp_ggml_compute_forward_rope_back(
     switch (src0->type) {
         case WSP_GGML_TYPE_F16:
             {
-
+                wsp_ggml_compute_forward_rope_flt<wsp_ggml_fp16_t>(params, dst, false);
             } break;
         case WSP_GGML_TYPE_F32:
            {
-
+                wsp_ggml_compute_forward_rope_flt<float>(params, dst, false);
             } break;
         default:
             {
@@ -7074,7 +6964,11 @@ static void wsp_ggml_compute_forward_conv_2d_dw_cwhn(
     const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);

 #ifdef WSP_GGML_SIMD
-
+    #if defined(__ARM_FEATURE_SVE)
+        const int64_t pkg_size = svcntw();
+    #else
+        const int64_t pkg_size = WSP_GGML_F32_EPR;
+    #endif
     const int64_t pkg_count = c / pkg_size;
     const int64_t c_pkg_end = pkg_count * pkg_size;
 #else
@@ -7497,10 +7391,17 @@ static void wsp_ggml_compute_forward_upscale_f32(
     float sf1 = (float)ne1/src0->ne[1];
     float sf2 = (float)ne2/src0->ne[2];
     float sf3 = (float)ne3/src0->ne[3];
+    float pixel_offset = 0.5f;

     const int32_t mode_flags = wsp_ggml_get_op_params_i32(dst, 0);
     const wsp_ggml_scale_mode mode = (wsp_ggml_scale_mode) (mode_flags & 0xFF);

+    if (mode_flags & WSP_GGML_SCALE_FLAG_ALIGN_CORNERS) {
+        pixel_offset = 0.0f;
+        sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
+        sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
+    }
+
     if (mode == WSP_GGML_SCALE_MODE_NEAREST) {
         for (int64_t i3 = 0; i3 < ne3; i3++) {
             const int64_t i03 = i3 / sf3;
@@ -7520,13 +7421,6 @@ static void wsp_ggml_compute_forward_upscale_f32(
             }
         }
     } else if (mode == WSP_GGML_SCALE_MODE_BILINEAR) {
-        float pixel_offset = 0.5f;
-        if (mode_flags & WSP_GGML_SCALE_FLAG_ALIGN_CORNERS) {
-            pixel_offset = 0.0f;
-            sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
-            sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
-        }
-
         for (int64_t i3 = 0; i3 < ne3; i3++) {
             const int64_t i03 = i3 / sf3;
             for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
@@ -7561,6 +7455,51 @@ static void wsp_ggml_compute_forward_upscale_f32(

                     const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;

+                    float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+                    *y_dst = val;
+                }
+            }
+        }
+    }
+    } else if (mode == WSP_GGML_SCALE_MODE_BICUBIC) {
+        // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
+        const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
+        auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
+        auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
+        auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
+            const float w0 = weight2(x + 1);
+            const float w1 = weight1(x + 0);
+            const float w2 = weight1(1 - x);
+            const float w3 = weight2(2 - x);
+            return p0*w0 + p1*w1 + p2*w2 + p3*w3;
+        };
+
+        for (int64_t i3 = 0; i3 < ne3; i3++) {
+            const int64_t i03 = i3 / sf3;
+            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+                const int64_t i02 = i2 / sf2;
+                for (int64_t i1 = 0; i1 < ne1; i1++) {
+                    const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
+                    const int64_t y0 = (int64_t)floorf(y);
+                    const float dy = y - (float)y0;
+
+                    for (int64_t i0 = 0; i0 < ne0; i0++) {
+                        const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
+                        const int64_t x0 = (int64_t)floorf(x);
+                        const float dx = x - (float)x0;
+
+                        auto p = [=](int64_t x_off, int64_t y_off) -> float {
+                            int64_t i00 = std::max(int64_t(0), std::min(x0 + x_off, ne00 - 1));
+                            int64_t i01 = std::max(int64_t(0), std::min(y0 + y_off, ne01 - 1));
+                            return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                        };
+
+                        const float val = bicubic(
+                            bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),
+                            bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),
+                            bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),
+                            bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);
+
                     float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
                     *y_dst = val;
                 }
@@ -7854,6 +7793,18 @@ void wsp_ggml_compute_forward_timestep_embedding(

 // wsp_ggml_compute_forward_argsort

+template<enum wsp_ggml_sort_order order>
+struct argsort_cmp {
+    const float * data;
+    bool operator()(int32_t a, int32_t b) const {
+        if constexpr (order == WSP_GGML_SORT_ORDER_ASC) {
+            return data[a] < data[b];
+        } else {
+            return data[a] > data[b];
+        }
+    }
+};
+
 static void wsp_ggml_compute_forward_argsort_f32(
     const wsp_ggml_compute_params * params,
     wsp_ggml_tensor * dst) {
@@ -7872,23 +7823,25 @@ static void wsp_ggml_compute_forward_argsort_f32(
     wsp_ggml_sort_order order = (wsp_ggml_sort_order) wsp_ggml_get_op_params_i32(dst, 0);

     for (int64_t i = ith; i < nr; i += nth) {
-        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
         const float * src_data = (float *)((char *) src0->data + i*nb01);

+        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+
         for (int64_t j = 0; j < ne0; j++) {
             dst_data[j] = j;
         }

-
-
-
-
-
-
-
-
-
-
+        switch (order) {
+            case WSP_GGML_SORT_ORDER_ASC:
+                std::sort(dst_data, dst_data + ne0, argsort_cmp<WSP_GGML_SORT_ORDER_ASC>{src_data});
+                break;
+
+            case WSP_GGML_SORT_ORDER_DESC:
+                std::sort(dst_data, dst_data + ne0, argsort_cmp<WSP_GGML_SORT_ORDER_DESC>{src_data});
+                break;
+
+            default:
+                WSP_GGML_ABORT("invalid sort order");
         }
     }
 }
@@ -7913,10 +7866,10 @@ void wsp_ggml_compute_forward_argsort(

 // wsp_ggml_compute_forward_flash_attn_ext

-static void
+static void wsp_ggml_compute_forward_flash_attn_ext_f16_one_chunk(
         const wsp_ggml_compute_params * params,
-        wsp_ggml_tensor * dst
-
+        wsp_ggml_tensor * dst,
+        int ir0, int ir1) {
     const wsp_ggml_tensor * q = dst->src[0];
     const wsp_ggml_tensor * k = dst->src[1];
     const wsp_ggml_tensor * v = dst->src[2];
@@ -7932,9 +7885,6 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
     WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
     WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb)

-    const int ith = params->ith;
-    const int nth = params->nth;
-
     const int64_t DK = nek0;
     const int64_t DV = nev0;
     const int64_t N = neq1;
@@ -7968,16 +7918,6 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(

     // parallelize by q rows using wsp_ggml_vec_dot_f32

-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
     float scale = 1.0f;
     float max_bias = 0.0f;
     float logit_softcap = 0.0f;
@@ -8004,6 +7944,8 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
     WSP_GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
     WSP_GGML_ASSERT((v->type == WSP_GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");

+    int ith = params->ith;
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
@@ -8135,7 +8077,7 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
         }

         // V /= S
-        const float S_inv = 1.0f/S;
+        const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
         wsp_ggml_vec_scale_f32(DV, VKQ32, S_inv);

         // dst indices
@@ -8151,6 +8093,91 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
     }
 }

+static void wsp_ggml_compute_forward_flash_attn_ext_f16(
+        const wsp_ggml_compute_params * params,
+        wsp_ggml_tensor * dst) {
+
+    const wsp_ggml_tensor * q = dst->src[0];
+    const wsp_ggml_tensor * k = dst->src[1];
+    const wsp_ggml_tensor * v = dst->src[2];
+
+    WSP_GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
+    WSP_GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
+    WSP_GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
+    WSP_GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
+    WSP_GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
+    WSP_GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
+    WSP_GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    WSP_GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
+
+    const int64_t DK = nek0;
+    const int64_t DV = nev0;
+    const int64_t N = neq1;
+
+    WSP_GGML_ASSERT(ne0 == DV);
+    WSP_GGML_ASSERT(ne2 == N);
+
+    // input tensor rows must be contiguous
+    WSP_GGML_ASSERT(nbq0 == wsp_ggml_type_size(q->type));
+    WSP_GGML_ASSERT(nbk0 == wsp_ggml_type_size(k->type));
+    WSP_GGML_ASSERT(nbv0 == wsp_ggml_type_size(v->type));
+
+    WSP_GGML_ASSERT(neq0 == DK);
+    WSP_GGML_ASSERT(nek0 == DK);
+    WSP_GGML_ASSERT(nev0 == DV);
+
+    WSP_GGML_ASSERT(neq1 == N);
+
+    // dst cannot be transposed or permuted
+    WSP_GGML_ASSERT(nb0 == sizeof(float));
+    WSP_GGML_ASSERT(nb0 <= nb1);
+    WSP_GGML_ASSERT(nb1 <= nb2);
+    WSP_GGML_ASSERT(nb2 <= nb3);
+
+    // parallelize by q rows using wsp_ggml_vec_dot_f32
+
+    // total rows in q
+    const int64_t nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // disable for NUMA
+    const bool disable_chunking = wsp_ggml_is_numa();
+
+    // 4x chunks per thread
+    int nth_scaled = nth * 4;
+    int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+    int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
+
+    if (nth == 1 || nchunk < nth || disable_chunking) {
+        nchunk = nth;
+    }
+
+    if (ith == 0) {
+        // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
+        wsp_ggml_threadpool_chunk_set(params->threadpool, nth);
+    }
+
+    wsp_ggml_barrier(params->threadpool);
+
+    // The number of elements in each chunk
+    const int64_t dr = (nr + nchunk - 1) / nchunk;
+
+    // The first chunk comes from our thread_id, the rest will get auto-assigned.
+    int current_chunk = ith;
+
+    while (current_chunk < nchunk) {
+        const int64_t ir0 = dr * current_chunk;
+        const int64_t ir1 = MIN(ir0 + dr, nr);
+
+        wsp_ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+
+        current_chunk = wsp_ggml_threadpool_chunk_add(params->threadpool, 1);
+    }
+}
+
 void wsp_ggml_compute_forward_flash_attn_ext(
     const wsp_ggml_compute_params * params,
     wsp_ggml_tensor * dst) {
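Note: the flash-attention kernel is now split into a per-chunk worker plus a dispatcher that hands chunks out through the threadpool counter (wsp_ggml_threadpool_chunk_set/add, whose exact semantics are assumed here to be a shared atomic fetch-and-add). The same claiming pattern in a generic, self-contained form:

    #include <atomic>
    #include <cstdint>

    // each thread seeds itself with its own id, then keeps claiming the next
    // unprocessed chunk from a shared counter that starts at nth
    template <typename Fn>
    void run_chunks(std::atomic<int64_t> & next_chunk, int ith, int nth,
                    int64_t nchunk, int64_t nr, Fn && process_range) {
        if (ith == 0) {
            next_chunk.store(nth); // chunks [0, nth) are covered by the seed pass
        }
        // a barrier would go here so every thread sees the seeded counter
        const int64_t dr = (nr + nchunk - 1) / nchunk;
        for (int64_t chunk = ith; chunk < nchunk; chunk = next_chunk.fetch_add(1)) {
            const int64_t ir0 = dr * chunk;
            const int64_t ir1 = ir0 + dr < nr ? ir0 + dr : nr;
            process_range(ir0, ir1); // work on rows [ir0, ir1)
        }
    }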
@@ -8637,7 +8664,7 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
             // n_head
             for (int h = ih0; h < ih1; ++h) {
                 // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-                const float dt_soft_plus =
+                const float dt_soft_plus = wsp_ggml_compute_softplus_f32(dt[h]);
                 const float dA = expf(dt_soft_plus * A[h]);
                 const int g = h / (nh / ng); // repeat_interleave

@@ -8734,7 +8761,7 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
             // n_head
             for (int h = ih0; h < ih1; ++h) {
                 // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-                const float dt_soft_plus =
+                const float dt_soft_plus = wsp_ggml_compute_softplus_f32(dt[h]);
                 const int g = h / (nh / ng); // repeat_interleave

                 // dim
@@ -8997,6 +9024,34 @@ void wsp_ggml_compute_forward_unary(
             {
                 wsp_ggml_compute_forward_exp(params, dst);
             } break;
+        case WSP_GGML_UNARY_OP_FLOOR:
+            {
+                wsp_ggml_compute_forward_floor(params, dst);
+            } break;
+        case WSP_GGML_UNARY_OP_CEIL:
+            {
+                wsp_ggml_compute_forward_ceil(params, dst);
+            } break;
+        case WSP_GGML_UNARY_OP_ROUND:
+            {
+                wsp_ggml_compute_forward_round(params, dst);
+            } break;
+        case WSP_GGML_UNARY_OP_TRUNC:
+            {
+                wsp_ggml_compute_forward_trunc(params, dst);
+            } break;
+        case WSP_GGML_UNARY_OP_XIELU:
+            {
+                wsp_ggml_compute_forward_xielu(params, dst);
+            } break;
+        case WSP_GGML_UNARY_OP_EXPM1:
+            {
+                wsp_ggml_compute_forward_expm1(params, dst);
+            } break;
+        case WSP_GGML_UNARY_OP_SOFTPLUS:
+            {
+                wsp_ggml_compute_forward_softplus(params, dst);
+            } break;
         default:
             {
                 WSP_GGML_ABORT("fatal error");
@@ -9593,6 +9648,76 @@ void wsp_ggml_compute_forward_gla(
     }
 }

+static void wsp_ggml_compute_forward_solve_tri_f32(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst) {
+    const struct wsp_ggml_tensor * src0 = dst->src[0]; // A (lower triangular)
+    const struct wsp_ggml_tensor * src1 = dst->src[1]; // B (RHS)
+
+    WSP_GGML_TENSOR_BINARY_OP_LOCALS;
+
+    WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(dst->type  == WSP_GGML_TYPE_F32);
+
+    WSP_GGML_ASSERT(ne00 == ne01); // A must be square
+    WSP_GGML_ASSERT(ne0 == ne10);  // solution cols == B cols
+    WSP_GGML_ASSERT(ne1 == ne11);  // solution rows == B rows
+
+    WSP_GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
+    WSP_GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t k = ne10;             // number of RHS columns
+    const int64_t n = ne11;             // A is n×n
+    const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit
+
+    // chunks per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // chunk range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    const float * A = (const float *) src0->data; // [n, n, B1, B2]
+    const float * B = (const float *) src1->data; // [n, k, B1, B2]
+    float       * X = (      float *) dst->data;  // [n, k, B1, B2]
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*k);
+        const int64_t i02 = (ir - i03*ne02*k)/k;
+        const int64_t i01 = (ir - i03*ne02*k - i02*k);
+
+        const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
+        const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
+
+        float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
+
+        for (int64_t i00 = 0; i00 < n; ++i00) {
+            float sum = 0.0f;
+            for (int64_t t = 0; t < i00; ++t) {
+                sum += A_batch[i00 * n + t] * X_batch[i01 * n + t];
+            }
+
+            const float diag = A_batch[i00 * n + i00];
+            WSP_GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
+
+            X_batch[i01 * n + i00] = (B_batch[i00 * k + i01] - sum) / diag;
+        }
+    }
+}
+
+void wsp_ggml_compute_forward_solve_tri(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst) {
+    const wsp_ggml_tensor * src0 = dst->src[0];
+    const wsp_ggml_tensor * src1 = dst->src[1];
+
+    if (src0->type == WSP_GGML_TYPE_F32 && src1->type == WSP_GGML_TYPE_F32) {
+        wsp_ggml_compute_forward_solve_tri_f32(params, dst);
+    } else {
+        WSP_GGML_ABORT("fatal error");
+    }
+}
+
 // wsp_ggml_compute_forward_rwkv_wkv7

 static void wsp_ggml_compute_forward_rwkv_wkv7_f32(