whisper.rn 0.5.2 → 0.5.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/cpp/ggml-alloc.c +11 -4
- package/cpp/ggml-backend-reg.cpp +8 -0
- package/cpp/ggml-backend.cpp +0 -2
- package/cpp/ggml-cpu/arch/arm/quants.c +428 -26
- package/cpp/ggml-cpu/ggml-cpu-impl.h +3 -1
- package/cpp/ggml-cpu/ggml-cpu.c +50 -21
- package/cpp/ggml-cpu/ops.cpp +458 -349
- package/cpp/ggml-cpu/ops.h +4 -4
- package/cpp/ggml-cpu/repack.cpp +143 -29
- package/cpp/ggml-cpu/simd-mappings.h +25 -25
- package/cpp/ggml-cpu/unary-ops.cpp +16 -0
- package/cpp/ggml-cpu/unary-ops.h +2 -0
- package/cpp/ggml-cpu/vec.cpp +17 -0
- package/cpp/ggml-cpu/vec.h +10 -0
- package/cpp/ggml-impl.h +17 -1
- package/cpp/ggml-metal/ggml-metal-context.m +5 -6
- package/cpp/ggml-metal/ggml-metal-device.cpp +101 -4
- package/cpp/ggml-metal/ggml-metal-device.h +8 -1
- package/cpp/ggml-metal/ggml-metal-device.m +216 -14
- package/cpp/ggml-metal/ggml-metal-impl.h +90 -2
- package/cpp/ggml-metal/ggml-metal-ops.cpp +346 -85
- package/cpp/ggml-metal/ggml-metal-ops.h +2 -0
- package/cpp/ggml-metal/ggml-metal.cpp +5 -0
- package/cpp/ggml-metal/ggml-metal.metal +12436 -0
- package/cpp/ggml.c +154 -5
- package/cpp/ggml.h +73 -0
- package/cpp/whisper.cpp +5 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +17 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +73 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +17 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +73 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +17 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +73 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +17 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +73 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/package.json +1 -1
- package/whisper-rn.podspec +1 -1
- package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
package/cpp/ggml-cpu/ops.cpp
CHANGED
@@ -7,8 +7,10 @@
 #include "unary-ops.h"
 #include "vec.h"

-#include <
+#include <cfloat>
 #include <algorithm>
+#include <cmath>
+#include <functional>

 // wsp_ggml_compute_forward_dup

@@ -1394,6 +1396,56 @@ void wsp_ggml_compute_forward_sum(
     }
 }

+// wsp_ggml_compute_forward_cumsum
+
+static void wsp_ggml_compute_forward_cumsum_f32(
+        const wsp_ggml_compute_params * params,
+        wsp_ggml_tensor * dst) {
+
+    const wsp_ggml_tensor * src0 = dst->src[0];
+
+    WSP_GGML_ASSERT(src0->nb[0] == sizeof(float));
+    WSP_GGML_ASSERT(dst->nb[0]  == sizeof(float));
+
+    WSP_GGML_TENSOR_UNARY_OP_LOCALS
+
+    WSP_GGML_ASSERT(ne0 == ne00);
+    WSP_GGML_ASSERT(ne1 == ne01);
+    WSP_GGML_ASSERT(ne2 == ne02);
+    WSP_GGML_ASSERT(ne3 == ne03);
+
+    const auto [ir0, ir1] = get_thread_range(params, src0);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+        float * dst_row = (float *) ((char *) dst->data  + i01*nb1  + i02*nb2  + i03*nb3);
+
+        wsp_ggml_vec_cumsum_f32(ne00, dst_row, src_row);
+    }
+}
+
+void wsp_ggml_compute_forward_cumsum(
+        const wsp_ggml_compute_params * params,
+        wsp_ggml_tensor * dst) {
+
+    const wsp_ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case WSP_GGML_TYPE_F32:
+            {
+                wsp_ggml_compute_forward_cumsum_f32(params, dst);
+            } break;
+        default:
+            {
+                WSP_GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // wsp_ggml_compute_forward_sum_rows

 static void wsp_ggml_compute_forward_sum_rows_f32(
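The new cumsum op reduces each innermost row to its inclusive prefix sums via wsp_ggml_vec_cumsum_f32. As a rough illustration of what that per-row helper computes, here is a minimal standalone sketch (not the actual vec.h implementation):

```cpp
#include <cstdint>

// Inclusive prefix sum over one row: dst[i] = src[0] + ... + src[i].
static void cumsum_row_f32(int64_t n, float * dst, const float * src) {
    float acc = 0.0f;
    for (int64_t i = 0; i < n; ++i) {
        acc   += src[i];
        dst[i] = acc;
    }
}
```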
@@ -2140,6 +2192,83 @@ static void wsp_ggml_compute_forward_gelu(
     }
 }

+// wsp_ggml_compute_fill
+
+static void wsp_ggml_compute_forward_fill_f32(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    const float c = wsp_ggml_get_op_params_f32(dst, 0);
+
+    WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
+    WSP_GGML_TENSOR_LOCALS(size_t,  nb, dst, nb);
+
+    const auto [ir0, ir1] = get_thread_range(params, dst);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne2*ne1);
+        const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
+        const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
+
+        float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
+
+        wsp_ggml_vec_set_f32(ne0, dst_ptr, c);
+    }
+}
+
+void wsp_ggml_compute_forward_fill(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    wsp_ggml_compute_forward_fill_f32(params, dst);
+}
+
+// wsp_ggml_compute_tri
+
+static void wsp_ggml_compute_forward_tri_f32(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    const wsp_ggml_tensor * src0 = dst->src[0];
+
+    const wsp_ggml_tri_type ttype = (wsp_ggml_tri_type) wsp_ggml_get_op_params_i32(dst, 0);
+
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
+
+    WSP_GGML_TENSOR_UNARY_OP_LOCALS
+
+    const auto [ir0, ir1] = get_thread_range(params, src0);
+
+    bool (*bipred)(int, int);
+
+    switch (ttype) {
+        case WSP_GGML_TRI_TYPE_LOWER:      bipred = [](int i, int r) { return i <  r; }; break;
+        case WSP_GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
+        case WSP_GGML_TRI_TYPE_UPPER:      bipred = [](int i, int r) { return i >  r; }; break;
+        case WSP_GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
+        default: WSP_GGML_ABORT("invalid tri type");
+    }
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        const float * src_ptr = (const float *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+              float * dst_ptr = (      float *) ((      char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1);
+
+        for (int i0 = 0; i0 < ne0; ++i0) {
+            dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
+        }
+    }
+}
+
+void wsp_ggml_compute_forward_tri(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    const wsp_ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case WSP_GGML_TYPE_F32:
+            {
+                wsp_ggml_compute_forward_tri_f32(params, dst);
+            } break;
+        default:
+            {
+                WSP_GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // wsp_ggml_compute_forward_gelu_erf

 static void wsp_ggml_compute_forward_gelu_erf_f32(
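The tri op zeroes everything outside the selected triangle; the four bipred lambdas encode which (column, row) pairs survive. A small self-contained sketch of the same keep/zero rule, with hypothetical names mirroring the kernel above:

```cpp
#include <cstdio>

enum tri_type { TRI_LOWER, TRI_LOWER_DIAG, TRI_UPPER, TRI_UPPER_DIAG };

// Returns true when element (row r, column i) is kept for the given variant.
static bool tri_keep(tri_type t, int i, int r) {
    switch (t) {
        case TRI_LOWER:      return i <  r;
        case TRI_LOWER_DIAG: return i <= r;
        case TRI_UPPER:      return i >  r;
        case TRI_UPPER_DIAG: return i >= r;
    }
    return false;
}

int main() {
    // Print the keep-mask of a 4x4 lower-triangular (with diagonal) matrix.
    for (int r = 0; r < 4; ++r) {
        for (int i = 0; i < 4; ++i) {
            printf("%d ", tri_keep(TRI_LOWER_DIAG, i, r) ? 1 : 0);
        }
        printf("\n");
    }
    return 0;
}
```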
@@ -4455,46 +4584,6 @@ void wsp_ggml_compute_forward_cont(
     wsp_ggml_compute_forward_dup(params, dst);
 }

-// wsp_ggml_compute_forward_reshape
-
-void wsp_ggml_compute_forward_reshape(
-        const wsp_ggml_compute_params * params,
-        wsp_ggml_tensor * dst) {
-    // NOP
-    WSP_GGML_UNUSED(params);
-    WSP_GGML_UNUSED(dst);
-}
-
-// wsp_ggml_compute_forward_view
-
-void wsp_ggml_compute_forward_view(
-        const wsp_ggml_compute_params * params,
-        wsp_ggml_tensor * dst) {
-    // NOP
-    WSP_GGML_UNUSED(params);
-    WSP_GGML_UNUSED(dst);
-}
-
-// wsp_ggml_compute_forward_permute
-
-void wsp_ggml_compute_forward_permute(
-        const wsp_ggml_compute_params * params,
-        wsp_ggml_tensor * dst) {
-    // NOP
-    WSP_GGML_UNUSED(params);
-    WSP_GGML_UNUSED(dst);
-}
-
-// wsp_ggml_compute_forward_transpose
-
-void wsp_ggml_compute_forward_transpose(
-        const wsp_ggml_compute_params * params,
-        wsp_ggml_tensor * dst) {
-    // NOP
-    WSP_GGML_UNUSED(params);
-    WSP_GGML_UNUSED(dst);
-}
-
 // wsp_ggml_compute_forward_get_rows

 static void wsp_ggml_compute_forward_get_rows_q(
@@ -5474,7 +5563,7 @@ static void wsp_ggml_rope_cache_init(
 }

 static void wsp_ggml_mrope_cache_init(
-     float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
+     float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
      float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
      float * cache, float sin_sign, float theta_scale) {
     // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
@@ -5509,14 +5598,26 @@ static void wsp_ggml_mrope_cache_init(
         }

         float theta = theta_t;
-        if (
-
-
-
-
-
-
-
+        if (is_imrope) { // qwen3vl apply interleaved mrope
+            if (sector % 3 == 1 && sector < 3 * sections[1]) {
+                theta = theta_h;
+            } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
+                theta = theta_w;
+            } else if (sector % 3 == 0 && sector < 3 * sections[0]) {
+                theta = theta_t;
+            } else {
+                theta = theta_e;
+            }
+        } else {
+            if (sector >= sections[0] && sector < sec_w) {
+                theta = theta_h;
+            }
+            else if (sector >= sec_w && sector < sec_w + sections[2]) {
+                theta = theta_w;
+            }
+            else if (sector >= sec_w + sections[2]) {
+                theta = theta_e;
+            }
         }

         rope_yarn(
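The interleaved (imrope) branch added above picks the time/height/width/extra position stream for a rotary sector by its index modulo 3 rather than by contiguous ranges. A hedged sketch of just that selection, using hypothetical names:

```cpp
// Which position stream ('t', 'h', 'w' or 'e') drives a given sector in the
// interleaved layout. sections[0..3] are the per-stream section counts,
// as in the kernel above.
static char imrope_stream(int sector, const int sections[4]) {
    if (sector % 3 == 1 && sector < 3 * sections[1]) return 'h';
    if (sector % 3 == 2 && sector < 3 * sections[2]) return 'w';
    if (sector % 3 == 0 && sector < 3 * sections[0]) return 't';
    return 'e';
}
```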
@@ -5531,193 +5632,28 @@ static void wsp_ggml_mrope_cache_init(
     }
 }

-static void wsp_ggml_compute_forward_rope_f32(
-        const wsp_ggml_compute_params * params,
-        wsp_ggml_tensor * dst,
-        const bool forward) {
-
-    const wsp_ggml_tensor * src0 = dst->src[0];
-    const wsp_ggml_tensor * src1 = dst->src[1];
-    const wsp_ggml_tensor * src2 = dst->src[2];
-
-    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-    int sections[4];
-
-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode = ((int32_t *) dst->op_params)[2];
-    //const int n_ctx = ((int32_t *) dst->op_params)[3];
-    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
-
-    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
-    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
-    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
-    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
-    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
-    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
-    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);
-
-    WSP_GGML_TENSOR_UNARY_OP_LOCALS
-
-    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
-    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
-
-    WSP_GGML_ASSERT(nb00 == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr = wsp_ggml_nrows(dst);
-
-    WSP_GGML_ASSERT(n_dims <= ne0);
-    WSP_GGML_ASSERT(n_dims % 2 == 0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    // row index used to determine which thread to use
-    int ir = 0;
-
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
-    float corr_dims[2];
-    wsp_ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
-
-    const bool is_neox = mode & WSP_GGML_ROPE_TYPE_NEOX;
-    const bool is_mrope = mode & WSP_GGML_ROPE_TYPE_MROPE; // wsp_ggml_rope_multi, multimodal rotary position embedding
-    const bool is_vision = mode == WSP_GGML_ROPE_TYPE_VISION;
-
-    if (is_mrope) {
-        WSP_GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
-    }
-
-    if (is_vision) {
-        WSP_GGML_ASSERT(n_dims == ne0/2);
-    }
-
-    const float * freq_factors = NULL;
-    if (src2 != NULL) {
-        WSP_GGML_ASSERT(src2->type == WSP_GGML_TYPE_F32);
-        WSP_GGML_ASSERT(src2->ne[0] >= n_dims / 2);
-        freq_factors = (const float *) src2->data;
-    }
-
-    // backward process uses inverse rotation by cos and sin.
-    // cos and sin build a rotation matrix, where the inverse is the transpose.
-    // this essentially just switches the sign of sin.
-    const float sin_sign = forward ? 1.0f : -1.0f;

-
-
-
-
-
-            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            if (!is_mrope) {
-                const int64_t p = pos[i2];
-                wsp_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
-            else {
-                const int64_t p_t = pos[i2];
-                const int64_t p_h = pos[i2 + ne2];
-                const int64_t p_w = pos[i2 + ne2 * 2];
-                const int64_t p_e = pos[i2 + ne2 * 3];
-                wsp_ggml_mrope_cache_init(
-                    p_t, p_h, p_w, p_e, sections, is_vision,
-                    freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
-
-            for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
-                if (ir++ < ir0) continue;
-                if (ir   > ir1) break;
-
-                if (is_neox || is_mrope) {
-                    if (is_vision){
-                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                            const int64_t ic = i0/2;
-
-                            const float cos_theta = cache[i0 + 0];
-                            const float sin_theta = cache[i0 + 1];
-
-                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                            float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
-                            const float x0 = src[0];
-                            const float x1 = src[n_dims];
-
-                            dst_data[0]      = x0*cos_theta - x1*sin_theta;
-                            dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
-                        }
-                    } else {
-                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                            const int64_t ic = i0/2;
+template<typename T>
+static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
+    for (int64_t i0 = 0; i0 < n; i0 += 2) {
+        const int64_t ic = i0/scale; // hack for WSP_GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2

-
-
+        const float cos_theta = cache[i0 + 0];
+        const float sin_theta = cache[i0 + 1];

-
-
+        const T * const src = src_data + ic;
+        T * dst = dst_data + ic;

-
-
+        const float x0 = type_conversion_table<T>::to_f32(src[0]);
+        const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);

-
-
-
-                        }
-                    } else {
-                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                            const float cos_theta = cache[i0 + 0];
-                            const float sin_theta = cache[i0 + 1];
-
-                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                            float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-                            const float x0 = src[0];
-                            const float x1 = src[1];
-
-                            dst_data[0] = x0*cos_theta - x1*sin_theta;
-                            dst_data[1] = x0*sin_theta + x1*cos_theta;
-                        }
-                    }
-
-                    if (is_vision) {
-                        for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                            const int64_t ic = i0/2;
-
-                            const float cos_theta = cache[i0 + 0];
-                            const float sin_theta = cache[i0 + 1];
-
-                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                            float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
-                            const float x0 = src[0];
-                            const float x1 = src[n_dims];
-
-                            dst_data[0]      = x0*cos_theta - x1*sin_theta;
-                            dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
-                        }
-                    } else {
-                        // fill the remain channels with data from src tensor
-                        for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                            float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-                            dst_data[0] = src[0];
-                            dst_data[1] = src[1];
-                        }
-                    }
-                }
-            }
-        }
+        dst[0]        = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
+        dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
+    }
 }

-
-static void
+template<typename T> //float or wsp_ggml_fp16_t
+static void wsp_ggml_compute_forward_rope_flt(
         const wsp_ggml_compute_params * params,
         wsp_ggml_tensor * dst,
         const bool forward) {
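The refactored rope path funnels both the f32 and f16 kernels through the templated rotate_pairs helper, which applies the standard 2-D rotation to each element pair using the precomputed cos/sin cache. A float-only standalone sketch of the same rotation (assuming cache[i0] = cos and cache[i0+1] = sin, as in the kernel above):

```cpp
#include <cstdint>

// Rotate pairs (x0, x1) -> (x0*cos - x1*sin, x0*sin + x1*cos).
// n_offset is the distance between the two elements of a pair (1 for the
// "normal" layout, n_dims/2 for the NeoX-style layouts); scale controls how
// i0 maps to the pair index (1 for normal, 2 otherwise), mirroring the diff.
static void rotate_pairs_f32(int64_t n, int64_t n_offset, const float * cache,
                             const float * src, float * dst, int scale = 2) {
    for (int64_t i0 = 0; i0 < n; i0 += 2) {
        const int64_t ic = i0/scale;

        const float cos_theta = cache[i0 + 0];
        const float sin_theta = cache[i0 + 1];

        const float x0 = src[ic];
        const float x1 = src[ic + n_offset];

        dst[ic]            = x0*cos_theta - x1*sin_theta;
        dst[ic + n_offset] = x0*sin_theta + x1*cos_theta;
    }
}
```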
@@ -5726,6 +5662,9 @@ static void wsp_ggml_compute_forward_rope_f16(
     const wsp_ggml_tensor * src1 = dst->src[1];
     const wsp_ggml_tensor * src2 = dst->src[2];

+    WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32 || src0->type == WSP_GGML_TYPE_F16);
+    WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_I32);
+
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
     int sections[4];

@@ -5734,6 +5673,7 @@ static void wsp_ggml_compute_forward_rope_f16(
     const int mode = ((int32_t *) dst->op_params)[2];
     //const int n_ctx = ((int32_t *) dst->op_params)[3];
     const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
     memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
     memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
     memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
@@ -5742,13 +5682,13 @@ static void wsp_ggml_compute_forward_rope_f16(
     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
     memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);

-
     WSP_GGML_TENSOR_UNARY_OP_LOCALS

     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);

-    WSP_GGML_ASSERT(nb0 ==
+    WSP_GGML_ASSERT(nb0 == nb00);
+    WSP_GGML_ASSERT(nb0 == sizeof(T));

     const int ith = params->ith;
     const int nth = params->nth;
@@ -5773,11 +5713,11 @@ static void wsp_ggml_compute_forward_rope_f16(
     float corr_dims[2];
     wsp_ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);

-    const bool
-    const bool
+    const bool is_imrope = mode == WSP_GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
+    const bool mrope_used = mode & WSP_GGML_ROPE_TYPE_MROPE; // wsp_ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
     const bool is_vision = mode == WSP_GGML_ROPE_TYPE_VISION;

-    if (
+    if (mrope_used) {
         WSP_GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
     }

@@ -5799,11 +5739,11 @@ static void wsp_ggml_compute_forward_rope_f16(

     const int32_t * pos = (const int32_t *) src1->data;

-    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = 0; i2 < ne2; i2++) {
+    for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
+        for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len

             float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            if (!
+            if (!mrope_used) {
                 const int64_t p = pos[i2];
                 wsp_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
             }
@@ -5813,90 +5753,44 @@ static void wsp_ggml_compute_forward_rope_f16(
                 const int64_t p_w = pos[i2 + ne2 * 2];
                 const int64_t p_e = pos[i2 + ne2 * 3];
                 wsp_ggml_mrope_cache_init(
-                    p_t, p_h, p_w, p_e, sections, is_vision,
+                    p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
                     freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
             }

-            for (int64_t i1 = 0; i1 < ne1; i1++) {
+            for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                } else {
-                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                        const int64_t ic = i0/2;
-
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                        wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
-                        const float x0 = WSP_GGML_CPU_FP16_TO_FP32(src[0]);
-                        const float x1 = WSP_GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
-
-                        dst_data[0]        = WSP_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[n_dims/2] = WSP_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                    }
-                }
-                } else {
-                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                        wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-                        const float x0 = WSP_GGML_CPU_FP16_TO_FP32(src[0]);
-                        const float x1 = WSP_GGML_CPU_FP16_TO_FP32(src[1]);
-
-                        dst_data[0] = WSP_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[1] = WSP_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                    }
+                T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+
+                switch (mode) {
+                    case WSP_GGML_ROPE_TYPE_NORMAL:
+                        rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
+                        break;
+                    case WSP_GGML_ROPE_TYPE_NEOX:
+                    case WSP_GGML_ROPE_TYPE_MROPE:
+                    case WSP_GGML_ROPE_TYPE_IMROPE:
+                        rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
+                        break;
+                    case WSP_GGML_ROPE_TYPE_VISION:
+                        rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
+                        break;
+                    default:
+                        WSP_GGML_ABORT("rope type not supported");
                 }

-                if (is_vision) {
-
-                        const int64_t ic = i0/2;
-
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                        wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
-                        const float x0 = WSP_GGML_CPU_FP16_TO_FP32(src[0]);
-                        const float x1 = WSP_GGML_CPU_FP16_TO_FP32(src[n_dims]);
-
-                        dst_data[0]      = WSP_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[n_dims] = WSP_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                    }
-                } else {
+                if (!is_vision) {
+                    // fill the remain channels with data from src tensor
                     for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                        const
-
+                        const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                        T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

                         dst_data[0] = src[0];
                         dst_data[1] = src[1];
                     }
                 }
-            }
+            } //attn-heads
         }
     }
 }
@@ -5910,11 +5804,11 @@ void wsp_ggml_compute_forward_rope(
     switch (src0->type) {
         case WSP_GGML_TYPE_F16:
             {
-
+                wsp_ggml_compute_forward_rope_flt<wsp_ggml_fp16_t>(params, dst, true);
             } break;
         case WSP_GGML_TYPE_F32:
             {
-
+                wsp_ggml_compute_forward_rope_flt<float>(params, dst, true);
             } break;
         default:
             {
@@ -5934,11 +5828,11 @@ void wsp_ggml_compute_forward_rope_back(
     switch (src0->type) {
         case WSP_GGML_TYPE_F16:
             {
-
+                wsp_ggml_compute_forward_rope_flt<wsp_ggml_fp16_t>(params, dst, false);
             } break;
         case WSP_GGML_TYPE_F32:
             {
-
+                wsp_ggml_compute_forward_rope_flt<float>(params, dst, false);
             } break;
         default:
             {
@@ -7070,7 +6964,11 @@ static void wsp_ggml_compute_forward_conv_2d_dw_cwhn(
     const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);

 #ifdef WSP_GGML_SIMD
-
+    #if defined(__ARM_FEATURE_SVE)
+    const int64_t pkg_size = svcntw();
+    #else
+    const int64_t pkg_size = WSP_GGML_F32_EPR;
+    #endif
     const int64_t pkg_count = c / pkg_size;
     const int64_t c_pkg_end = pkg_count * pkg_size;
 #else
@@ -7493,10 +7391,17 @@ static void wsp_ggml_compute_forward_upscale_f32(
     float sf1 = (float)ne1/src0->ne[1];
     float sf2 = (float)ne2/src0->ne[2];
     float sf3 = (float)ne3/src0->ne[3];
+    float pixel_offset = 0.5f;

     const int32_t mode_flags = wsp_ggml_get_op_params_i32(dst, 0);
     const wsp_ggml_scale_mode mode = (wsp_ggml_scale_mode) (mode_flags & 0xFF);

+    if (mode_flags & WSP_GGML_SCALE_FLAG_ALIGN_CORNERS) {
+        pixel_offset = 0.0f;
+        sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
+        sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
+    }
+
     if (mode == WSP_GGML_SCALE_MODE_NEAREST) {
         for (int64_t i3 = 0; i3 < ne3; i3++) {
             const int64_t i03 = i3 / sf3;
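Hoisting pixel_offset and the align-corners scale factors out of the bilinear branch lets every interpolation mode share the same output-to-source coordinate mapping. As a hedged reference for that mapping (standalone helper, not part of the diff):

```cpp
#include <cstdint>

// Source coordinate for output index i under the two conventions used above:
// half-pixel centers (pixel_offset = 0.5) vs. align-corners (pixel_offset = 0,
// with the scale factor redefined as (ne_out - 1)/(ne_in - 1)).
static float src_coord(int64_t i, float sf, float pixel_offset) {
    return ((float)i + pixel_offset) / sf - pixel_offset;
}
```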
@@ -7516,13 +7421,6 @@ static void wsp_ggml_compute_forward_upscale_f32(
             }
         }
     } else if (mode == WSP_GGML_SCALE_MODE_BILINEAR) {
-        float pixel_offset = 0.5f;
-        if (mode_flags & WSP_GGML_SCALE_FLAG_ALIGN_CORNERS) {
-            pixel_offset = 0.0f;
-            sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
-            sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
-        }
-
         for (int64_t i3 = 0; i3 < ne3; i3++) {
             const int64_t i03 = i3 / sf3;
             for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
@@ -7557,6 +7455,51 @@ static void wsp_ggml_compute_forward_upscale_f32(

                         const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;

+                        float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+                        *y_dst = val;
+                    }
+                }
+            }
+        }
+    } else if (mode == WSP_GGML_SCALE_MODE_BICUBIC) {
+        // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
+        const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
+        auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
+        auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
+        auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
+            const float w0 = weight2(x + 1);
+            const float w1 = weight1(x + 0);
+            const float w2 = weight1(1 - x);
+            const float w3 = weight2(2 - x);
+            return p0*w0 + p1*w1 + p2*w2 + p3*w3;
+        };
+
+        for (int64_t i3 = 0; i3 < ne3; i3++) {
+            const int64_t i03 = i3 / sf3;
+            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+                const int64_t i02 = i2 / sf2;
+                for (int64_t i1 = 0; i1 < ne1; i1++) {
+                    const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
+                    const int64_t y0 = (int64_t)floorf(y);
+                    const float dy = y - (float)y0;
+
+                    for (int64_t i0 = 0; i0 < ne0; i0++) {
+                        const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
+                        const int64_t x0 = (int64_t)floorf(x);
+                        const float dx = x - (float)x0;
+
+                        auto p = [=](int64_t x_off, int64_t y_off) -> float {
+                            int64_t i00 = std::max(int64_t(0), std::min(x0 + x_off, ne00 - 1));
+                            int64_t i01 = std::max(int64_t(0), std::min(y0 + y_off, ne01 - 1));
+                            return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                        };
+
+                        const float val = bicubic(
+                            bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),
+                            bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),
+                            bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),
+                            bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);
+
                         float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
                         *y_dst = val;
                     }
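The new bicubic branch uses the bicubic convolution kernel with alpha = -0.75 (PyTorch's choice), applied once along x for four neighbouring rows and then once along y. A standalone sketch of the 1-D step with the same weight polynomials:

```cpp
// 1-D bicubic convolution with alpha = -0.75, interpolating at fractional
// position x in [0, 1) between samples p1 and p2 (p0 and p3 are the outer neighbours).
static float cubic1d(float p0, float p1, float p2, float p3, float x) {
    const float a = -0.75f;
    auto w_near = [a](float t) { return ((a + 2) * t - (a + 3)) * t * t + 1; };        // for |t| <= 1
    auto w_far  = [a](float t) { return ((a * t - 5 * a) * t + 8 * a) * t - 4 * a; };  // for 1 < |t| <= 2
    return p0*w_far(x + 1) + p1*w_near(x) + p2*w_near(1 - x) + p3*w_far(2 - x);
}
```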
@@ -7850,6 +7793,18 @@ void wsp_ggml_compute_forward_timestep_embedding(

 // wsp_ggml_compute_forward_argsort

+template<enum wsp_ggml_sort_order order>
+struct argsort_cmp {
+    const float * data;
+    bool operator()(int32_t a, int32_t b) const {
+        if constexpr (order == WSP_GGML_SORT_ORDER_ASC) {
+            return data[a] < data[b];
+        } else {
+            return data[a] > data[b];
+        }
+    }
+};
+
 static void wsp_ggml_compute_forward_argsort_f32(
     const wsp_ggml_compute_params * params,
     wsp_ggml_tensor * dst) {
@@ -7868,23 +7823,25 @@ static void wsp_ggml_compute_forward_argsort_f32(
     wsp_ggml_sort_order order = (wsp_ggml_sort_order) wsp_ggml_get_op_params_i32(dst, 0);

     for (int64_t i = ith; i < nr; i += nth) {
-        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
         const float * src_data = (float *)((char *) src0->data + i*nb01);

+        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+
         for (int64_t j = 0; j < ne0; j++) {
             dst_data[j] = j;
         }

-
-
-
-
-
-
-
-
-
-
+        switch (order) {
+            case WSP_GGML_SORT_ORDER_ASC:
+                std::sort(dst_data, dst_data + ne0, argsort_cmp<WSP_GGML_SORT_ORDER_ASC>{src_data});
+                break;
+
+            case WSP_GGML_SORT_ORDER_DESC:
+                std::sort(dst_data, dst_data + ne0, argsort_cmp<WSP_GGML_SORT_ORDER_DESC>{src_data});
+                break;
+
+            default:
+                WSP_GGML_ABORT("invalid sort order");
         }
     }
 }
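argsort now fills an index row and hands it to std::sort with a comparator templated on the sort order. A minimal self-contained equivalent for a single contiguous row (hypothetical helper, ascending only):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Return the permutation that sorts `data` ascending:
// data[idx[0]] <= data[idx[1]] <= ... <= data[idx[n-1]].
static std::vector<int32_t> argsort_asc(const float * data, int64_t n) {
    std::vector<int32_t> idx(n);
    for (int64_t j = 0; j < n; ++j) {
        idx[j] = (int32_t) j;
    }
    std::sort(idx.begin(), idx.end(),
              [data](int32_t a, int32_t b) { return data[a] < data[b]; });
    return idx;
}
```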
@@ -7909,10 +7866,10 @@ void wsp_ggml_compute_forward_argsort(

 // wsp_ggml_compute_forward_flash_attn_ext

-static void
+static void wsp_ggml_compute_forward_flash_attn_ext_f16_one_chunk(
         const wsp_ggml_compute_params * params,
-        wsp_ggml_tensor * dst
-
+        wsp_ggml_tensor * dst,
+        int ir0, int ir1) {
     const wsp_ggml_tensor * q = dst->src[0];
     const wsp_ggml_tensor * k = dst->src[1];
     const wsp_ggml_tensor * v = dst->src[2];
@@ -7928,9 +7885,6 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
     WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
     WSP_GGML_TENSOR_LOCALS(size_t,  nb, dst, nb)

-    const int ith = params->ith;
-    const int nth = params->nth;
-
     const int64_t DK = nek0;
     const int64_t DV = nev0;
     const int64_t N = neq1;
@@ -7964,16 +7918,6 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(

     // parallelize by q rows using wsp_ggml_vec_dot_f32

-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
     float scale = 1.0f;
     float max_bias = 0.0f;
     float logit_softcap = 0.0f;
@@ -8000,6 +7944,8 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
     WSP_GGML_ASSERT((q_to_vec_dot) && "fattn: unsupported K-type");
     WSP_GGML_ASSERT((v->type == WSP_GGML_TYPE_F32 || v_to_float) && "fattn: unsupported V-type");

+    int ith = params->ith;
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
@@ -8147,6 +8093,91 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
     }
 }

+static void wsp_ggml_compute_forward_flash_attn_ext_f16(
+        const wsp_ggml_compute_params * params,
+        wsp_ggml_tensor * dst) {
+
+    const wsp_ggml_tensor * q = dst->src[0];
+    const wsp_ggml_tensor * k = dst->src[1];
+    const wsp_ggml_tensor * v = dst->src[2];
+
+    WSP_GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
+    WSP_GGML_TENSOR_LOCALS(size_t,  nbq, q, nb)
+    WSP_GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
+    WSP_GGML_TENSOR_LOCALS(size_t,  nbk, k, nb)
+    WSP_GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
+    WSP_GGML_TENSOR_LOCALS(size_t,  nbv, v, nb)
+    WSP_GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    WSP_GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
+
+    const int64_t DK = nek0;
+    const int64_t DV = nev0;
+    const int64_t N = neq1;
+
+    WSP_GGML_ASSERT(ne0 == DV);
+    WSP_GGML_ASSERT(ne2 == N);
+
+    // input tensor rows must be contiguous
+    WSP_GGML_ASSERT(nbq0 == wsp_ggml_type_size(q->type));
+    WSP_GGML_ASSERT(nbk0 == wsp_ggml_type_size(k->type));
+    WSP_GGML_ASSERT(nbv0 == wsp_ggml_type_size(v->type));
+
+    WSP_GGML_ASSERT(neq0 == DK);
+    WSP_GGML_ASSERT(nek0 == DK);
+    WSP_GGML_ASSERT(nev0 == DV);
+
+    WSP_GGML_ASSERT(neq1 == N);
+
+    // dst cannot be transposed or permuted
+    WSP_GGML_ASSERT(nb0 == sizeof(float));
+    WSP_GGML_ASSERT(nb0 <= nb1);
+    WSP_GGML_ASSERT(nb1 <= nb2);
+    WSP_GGML_ASSERT(nb2 <= nb3);
+
+    // parallelize by q rows using wsp_ggml_vec_dot_f32
+
+    // total rows in q
+    const int64_t nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // disable for NUMA
+    const bool disable_chunking = wsp_ggml_is_numa();
+
+    // 4x chunks per thread
+    int nth_scaled = nth * 4;
+    int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+    int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
+
+    if (nth == 1 || nchunk < nth || disable_chunking) {
+        nchunk = nth;
+    }
+
+    if (ith == 0) {
+        // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
+        wsp_ggml_threadpool_chunk_set(params->threadpool, nth);
+    }
+
+    wsp_ggml_barrier(params->threadpool);
+
+    // The number of elements in each chunk
+    const int64_t dr = (nr + nchunk - 1) / nchunk;
+
+    // The first chunk comes from our thread_id, the rest will get auto-assigned.
+    int current_chunk = ith;
+
+    while (current_chunk < nchunk) {
+        const int64_t ir0 = dr * current_chunk;
+        const int64_t ir1 = MIN(ir0 + dr, nr);
+
+        wsp_ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+
+        current_chunk = wsp_ggml_threadpool_chunk_add(params->threadpool, 1);
+    }
+}
+
 void wsp_ggml_compute_forward_flash_attn_ext(
     const wsp_ggml_compute_params * params,
     wsp_ggml_tensor * dst) {
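The flash-attention forward pass is now split into more chunks than threads, and workers claim chunks from a shared threadpool counter so that rows with uneven cost do not leave threads idle. A hedged sketch of the same claiming pattern using std::atomic in place of the wsp_ggml_threadpool_chunk_* helpers:

```cpp
#include <algorithm>
#include <atomic>
#include <cstdint>
#include <functional>

// Worker-side chunk claiming. `next_chunk` must be initialized to the thread
// count before the workers start, since chunks 0..nth-1 are taken implicitly:
// each worker begins with the chunk equal to its own thread id (ith).
static void run_chunks(int ith, int64_t nr, int64_t nchunk,
                       std::atomic<int64_t> & next_chunk,
                       const std::function<void(int64_t, int64_t)> & process) {
    const int64_t dr = (nr + nchunk - 1) / nchunk; // rows per chunk
    int64_t current = ith;

    while (current < nchunk) {
        const int64_t ir0 = dr * current;
        const int64_t ir1 = std::min(ir0 + dr, nr);
        process(ir0, ir1);                 // handle rows [ir0, ir1)
        current = next_chunk.fetch_add(1); // claim the next unprocessed chunk
    }
}
```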
@@ -8633,7 +8664,7 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
             // n_head
             for (int h = ih0; h < ih1; ++h) {
                 // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-                const float dt_soft_plus =
+                const float dt_soft_plus = wsp_ggml_compute_softplus_f32(dt[h]);
                 const float dA = expf(dt_soft_plus * A[h]);
                 const int g = h / (nh / ng); // repeat_interleave

@@ -8730,7 +8761,7 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
             // n_head
             for (int h = ih0; h < ih1; ++h) {
                 // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-                const float dt_soft_plus =
+                const float dt_soft_plus = wsp_ggml_compute_softplus_f32(dt[h]);
                 const int g = h / (nh / ng); // repeat_interleave

                 // dim
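Both ssm_scan branches now call a shared softplus helper for dt. Softplus is log(1 + exp(x)); a minimal numerically guarded sketch (the real wsp_ggml_compute_softplus_f32 lives in the shared CPU headers and may differ in its cutoff):

```cpp
#include <cmath>

// softplus(x) = log(1 + exp(x)); for large x this is ~x, which avoids overflow in expf.
static float softplus_f32(float x) {
    return x > 20.0f ? x : log1pf(expf(x));
}
```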
@@ -9013,6 +9044,14 @@ void wsp_ggml_compute_forward_unary(
             {
                 wsp_ggml_compute_forward_xielu(params, dst);
             } break;
+        case WSP_GGML_UNARY_OP_EXPM1:
+            {
+                wsp_ggml_compute_forward_expm1(params, dst);
+            } break;
+        case WSP_GGML_UNARY_OP_SOFTPLUS:
+            {
+                wsp_ggml_compute_forward_softplus(params, dst);
+            } break;
         default:
             {
                 WSP_GGML_ABORT("fatal error");
@@ -9609,6 +9648,76 @@ void wsp_ggml_compute_forward_gla(
     }
 }

+static void wsp_ggml_compute_forward_solve_tri_f32(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst) {
+    const struct wsp_ggml_tensor * src0 = dst->src[0]; // A (lower triangular)
+    const struct wsp_ggml_tensor * src1 = dst->src[1]; // B (RHS)
+
+    WSP_GGML_TENSOR_BINARY_OP_LOCALS;
+
+    WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(dst->type  == WSP_GGML_TYPE_F32);
+
+    WSP_GGML_ASSERT(ne00 == ne01); // A must be square
+    WSP_GGML_ASSERT(ne0 == ne10);  // solution cols == B cols
+    WSP_GGML_ASSERT(ne1 == ne11);  // solution rows == B rows
+
+    WSP_GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
+    WSP_GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t k = ne10;             // number of RHS columns
+    const int64_t n = ne11;             // A is n×n
+    const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit
+
+    // chunks per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // chunk range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    const float * A = (const float *) src0->data; // [n, n, B1, B2]
+    const float * B = (const float *) src1->data; // [n, k, B1, B2]
+    float       * X = (      float *) dst->data;  // [n, k, B1, B2]
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*k);
+        const int64_t i02 = (ir - i03*ne02*k)/k;
+        const int64_t i01 = (ir - i03*ne02*k - i02*k);
+
+        const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
+        const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
+
+        float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
+
+        for (int64_t i00 = 0; i00 < n; ++i00) {
+            float sum = 0.0f;
+            for (int64_t t = 0; t < i00; ++t) {
+                sum += A_batch[i00 * n + t] * X_batch[i01 * n + t];
+            }
+
+            const float diag = A_batch[i00 * n + i00];
+            WSP_GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
+
+            X_batch[i01 * n + i00] = (B_batch[i00 * k + i01] - sum) / diag;
+        }
+    }
+}
+
+void wsp_ggml_compute_forward_solve_tri(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst) {
+    const wsp_ggml_tensor * src0 = dst->src[0];
+    const wsp_ggml_tensor * src1 = dst->src[1];
+
+    if (src0->type == WSP_GGML_TYPE_F32 && src1->type == WSP_GGML_TYPE_F32) {
+        wsp_ggml_compute_forward_solve_tri_f32(params, dst);
+    } else {
+        WSP_GGML_ABORT("fatal error");
+    }
+}
+
 // wsp_ggml_compute_forward_rwkv_wkv7

 static void wsp_ggml_compute_forward_rwkv_wkv7_f32(
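The new solve_tri op performs forward substitution on a lower-triangular system A·X = B, one right-hand-side column per work item. A single-matrix, unbatched sketch of the same recurrence (row-major, hypothetical helper):

```cpp
#include <cassert>
#include <cstdint>

// Solve A * x = b for x, where A is an n x n lower-triangular matrix (row-major).
// Row i uses only the already-computed entries x[0..i-1]:
//   x[i] = (b[i] - sum_{t<i} A[i][t] * x[t]) / A[i][i]
static void solve_lower_tri(const float * A, const float * b, float * x, int64_t n) {
    for (int64_t i = 0; i < n; ++i) {
        float sum = 0.0f;
        for (int64_t t = 0; t < i; ++t) {
            sum += A[i*n + t] * x[t];
        }
        assert(A[i*n + i] != 0.0f && "zero diagonal");
        x[i] = (b[i] - sum) / A[i*n + i];
    }
}
```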
|