whisper.rn 0.5.4 → 0.5.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +5 -0
- package/android/src/main/jni.cpp +13 -0
- package/cpp/ggml-alloc.c +78 -26
- package/cpp/ggml-alloc.h +9 -0
- package/cpp/ggml-backend-impl.h +1 -1
- package/cpp/ggml-backend-reg.cpp +19 -3
- package/cpp/ggml-backend.cpp +72 -20
- package/cpp/ggml-backend.h +2 -1
- package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- package/cpp/ggml-cpu/arch/arm/repack.cpp +1004 -0
- package/cpp/ggml-cpu/arch/x86/repack.cpp +6 -6
- package/cpp/ggml-cpu/arch-fallback.h +50 -2
- package/cpp/ggml-cpu/ggml-cpu-impl.h +1 -1
- package/cpp/ggml-cpu/ggml-cpu.c +139 -58
- package/cpp/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/ggml-cpu/ops.cpp +170 -18
- package/cpp/ggml-cpu/ops.h +1 -0
- package/cpp/ggml-cpu/repack.cpp +531 -5
- package/cpp/ggml-cpu/repack.h +14 -0
- package/cpp/ggml-cpu/simd-mappings.h +16 -18
- package/cpp/ggml-cpu/vec.cpp +41 -1
- package/cpp/ggml-cpu/vec.h +241 -138
- package/cpp/ggml-cpu.h +1 -0
- package/cpp/ggml-impl.h +0 -4
- package/cpp/ggml-metal/ggml-metal-context.m +26 -16
- package/cpp/ggml-metal/ggml-metal-device.cpp +452 -371
- package/cpp/ggml-metal/ggml-metal-device.h +87 -65
- package/cpp/ggml-metal/ggml-metal-device.m +263 -104
- package/cpp/ggml-metal/ggml-metal-impl.h +58 -4
- package/cpp/ggml-metal/ggml-metal-ops.cpp +415 -98
- package/cpp/ggml-metal/ggml-metal-ops.h +4 -0
- package/cpp/ggml-metal/ggml-metal.cpp +6 -5
- package/cpp/ggml-metal/ggml-metal.metal +404 -34
- package/cpp/ggml.c +110 -31
- package/cpp/ggml.h +51 -12
- package/cpp/jsi/RNWhisperJSI.cpp +1 -0
- package/cpp/whisper.cpp +16 -3
- package/ios/CMakeLists.txt +21 -1
- package/ios/RNWhisperContext.mm +5 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/jest-mock.js +2 -0
- package/lib/commonjs/jest-mock.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/jest-mock.js +2 -0
- package/lib/module/jest-mock.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +1 -0
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNWhisper.ts +1 -0
- package/src/jest-mock.ts +2 -0
- package/src/version.json +1 -1
|
@@ -583,6 +583,10 @@ static wsp_ggml_backend_feature * wsp_ggml_backend_cpu_get_features(wsp_ggml_bac
|
|
|
583
583
|
if (wsp_ggml_cpu_has_riscv_v()) {
|
|
584
584
|
features.push_back({ "RISCV_V", "1" });
|
|
585
585
|
}
|
|
586
|
+
if (wsp_ggml_cpu_get_rvv_vlen() > 0) {
|
|
587
|
+
static std::string rvv_vlen = std::to_string(wsp_ggml_cpu_get_rvv_vlen());
|
|
588
|
+
features.push_back({ "RVV_VLEN", rvv_vlen.c_str() });
|
|
589
|
+
}
|
|
586
590
|
if (wsp_ggml_cpu_has_vsx()) {
|
|
587
591
|
features.push_back({ "VSX", "1" });
|
|
588
592
|
}
|
package/cpp/ggml-cpu/ops.cpp
CHANGED
|
@@ -6383,7 +6383,7 @@ static void wsp_ggml_compute_forward_im2col_3d_f16(
|
|
|
6383
6383
|
const int64_t iih = ioh*s1 + ikh*d1 - p1;
|
|
6384
6384
|
const int64_t iid = iod*s2 + ikd*d2 - p2;
|
|
6385
6385
|
|
|
6386
|
-
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW
|
|
6386
|
+
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
|
6387
6387
|
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
|
|
6388
6388
|
} else {
|
|
6389
6389
|
const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
|
|
@@ -6554,8 +6554,13 @@ static void wsp_ggml_call_mul_mat(wsp_ggml_type type, const wsp_ggml_compute_par
|
|
|
6554
6554
|
wsp_ggml_compute_forward_mul_mat(params, &dst);
|
|
6555
6555
|
}
|
|
6556
6556
|
|
|
6557
|
+
// Map an arbitrary (possibly negative, possibly >= size) coordinate onto
// [0, size) with periodic wrap-around, i.e. the mathematical `coord mod size`.
//
// Why not `(coord + size) % size`: that form is only correct for coord >= -size.
// For circular padding wider than the source extent (e.g. lp0 > ne00), the sum
// stays negative. C/C++ `%` takes the sign of the dividend, so the result would
// be a negative index and the caller would read before the start of the tensor.
// Taking the remainder first and shifting negative results by `size` is correct
// for every coord.
//
// `size` must be > 0.
static inline int64_t wsp_ggml_wrap_around(int64_t coord, int64_t size) {
    const int64_t r = coord % size;
    return r < 0 ? r + size : r;
}
|
|
6560
|
+
|
|
6557
6561
|
// wsp_ggml_compute_forward_conv_2d
|
|
6558
6562
|
|
|
6563
|
+
|
|
6559
6564
|
static void wsp_ggml_compute_forward_conv_2d_impl(const wsp_ggml_compute_params * params,
|
|
6560
6565
|
const wsp_ggml_tensor * kernel, // [KW, KH, IC, OC]
|
|
6561
6566
|
const wsp_ggml_tensor * src, // [W, H, C, N]
|
|
@@ -7420,6 +7425,65 @@ static void wsp_ggml_compute_forward_upscale_f32(
|
|
|
7420
7425
|
}
|
|
7421
7426
|
}
|
|
7422
7427
|
}
|
|
7428
|
+
} else if (mode == WSP_GGML_SCALE_MODE_BILINEAR && (mode_flags & WSP_GGML_SCALE_FLAG_ANTIALIAS)) {
|
|
7429
|
+
// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
|
|
7430
|
+
// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
|
|
7431
|
+
auto triangle_filter = [](float x) -> float {
|
|
7432
|
+
return std::max(1.0f - fabsf(x), 0.0f);
|
|
7433
|
+
};
|
|
7434
|
+
|
|
7435
|
+
// support and invscale, minimum 1 pixel for bilinear
|
|
7436
|
+
const float support1 = std::max(1.0f, 1.0f / sf1);
|
|
7437
|
+
const float invscale1 = 1.0f / support1;
|
|
7438
|
+
const float support0 = std::max(1.0f, 1.0f / sf0);
|
|
7439
|
+
const float invscale0 = 1.0f / support0;
|
|
7440
|
+
|
|
7441
|
+
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
7442
|
+
const int64_t i03 = i3 / sf3;
|
|
7443
|
+
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
|
|
7444
|
+
const int64_t i02 = i2 / sf2;
|
|
7445
|
+
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
7446
|
+
const float y = ((float) i1 + pixel_offset) / sf1;
|
|
7447
|
+
for (int64_t i0 = 0; i0 < ne0; i0++) {
|
|
7448
|
+
const float x = ((float) i0 + pixel_offset) / sf0;
|
|
7449
|
+
|
|
7450
|
+
// the range of source pixels that contribute
|
|
7451
|
+
const int64_t x_min = std::max<int64_t>(x - support0 + pixel_offset, 0);
|
|
7452
|
+
const int64_t x_max = std::min<int64_t>(x + support0 + pixel_offset, ne00);
|
|
7453
|
+
const int64_t y_min = std::max<int64_t>(y - support1 + pixel_offset, 0);
|
|
7454
|
+
const int64_t y_max = std::min<int64_t>(y + support1 + pixel_offset, ne01);
|
|
7455
|
+
|
|
7456
|
+
// bilinear filter with antialiasing
|
|
7457
|
+
float val = 0.0f;
|
|
7458
|
+
float total_weight = 0.0f;
|
|
7459
|
+
|
|
7460
|
+
for (int64_t sy = y_min; sy < y_max; sy++) {
|
|
7461
|
+
const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
|
|
7462
|
+
|
|
7463
|
+
for (int64_t sx = x_min; sx < x_max; sx++) {
|
|
7464
|
+
const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
|
|
7465
|
+
const float weight = weight_x * weight_y;
|
|
7466
|
+
|
|
7467
|
+
if (weight <= 0.0f) {
|
|
7468
|
+
continue;
|
|
7469
|
+
}
|
|
7470
|
+
|
|
7471
|
+
const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03);
|
|
7472
|
+
val += pixel * weight;
|
|
7473
|
+
total_weight += weight;
|
|
7474
|
+
}
|
|
7475
|
+
}
|
|
7476
|
+
|
|
7477
|
+
if (total_weight > 0.0f) {
|
|
7478
|
+
val /= total_weight;
|
|
7479
|
+
}
|
|
7480
|
+
|
|
7481
|
+
float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
|
|
7482
|
+
*dst_ptr = val;
|
|
7483
|
+
}
|
|
7484
|
+
}
|
|
7485
|
+
}
|
|
7486
|
+
}
|
|
7423
7487
|
} else if (mode == WSP_GGML_SCALE_MODE_BILINEAR) {
|
|
7424
7488
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
7425
7489
|
const int64_t i03 = i3 / sf3;
|
|
@@ -7532,6 +7596,7 @@ void wsp_ggml_compute_forward_upscale(
|
|
|
7532
7596
|
|
|
7533
7597
|
// wsp_ggml_compute_forward_pad
|
|
7534
7598
|
|
|
7599
|
+
template<bool circular_t>
|
|
7535
7600
|
static void wsp_ggml_compute_forward_pad_f32(
|
|
7536
7601
|
const wsp_ggml_compute_params * params,
|
|
7537
7602
|
wsp_ggml_tensor * dst) {
|
|
@@ -7556,23 +7621,40 @@ static void wsp_ggml_compute_forward_pad_f32(
|
|
|
7556
7621
|
const int32_t lp3 = wsp_ggml_get_op_params_i32(dst, 6);
|
|
7557
7622
|
const int32_t rp3 = wsp_ggml_get_op_params_i32(dst, 7);
|
|
7558
7623
|
|
|
7559
|
-
|
|
7560
7624
|
// TODO: optimize
|
|
7561
7625
|
|
|
7562
7626
|
for (int64_t i2 = 0; i2 < ne2; ++i2) {
|
|
7563
7627
|
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
|
|
7564
7628
|
for (int64_t i0 = 0; i0 < ne0; ++i0) {
|
|
7565
7629
|
for (int64_t i3 = 0; i3 < ne3; ++i3) {
|
|
7566
|
-
|
|
7567
|
-
if (
|
|
7568
|
-
|
|
7569
|
-
|
|
7570
|
-
|
|
7571
|
-
const int64_t
|
|
7630
|
+
// circular means wrap around on a torus, so x and y loop around
|
|
7631
|
+
if constexpr (circular_t) {
|
|
7632
|
+
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
|
7633
|
+
const int64_t src_i0 = wsp_ggml_wrap_around(i0 - lp0, ne00);
|
|
7634
|
+
const int64_t src_i1 = wsp_ggml_wrap_around(i1 - lp1, ne01);
|
|
7635
|
+
const int64_t src_i2 = wsp_ggml_wrap_around(i2 - lp2, ne02);
|
|
7636
|
+
const int64_t src_i3 = wsp_ggml_wrap_around(i3 - lp3, ne03);
|
|
7637
|
+
|
|
7638
|
+
const int64_t src_idx =
|
|
7639
|
+
src_i3*nb03 +
|
|
7640
|
+
src_i2*nb02 +
|
|
7641
|
+
src_i1*nb01 +
|
|
7642
|
+
src_i0*nb00;
|
|
7643
|
+
|
|
7572
7644
|
const float * src_ptr = (const float *)((char *) src0->data + src_idx);
|
|
7573
7645
|
dst_ptr[dst_idx] = *src_ptr;
|
|
7574
7646
|
} else {
|
|
7575
|
-
|
|
7647
|
+
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
|
7648
|
+
if ((i0 >= lp0 && i0 < ne0 - rp0) \
|
|
7649
|
+
&& (i1 >= lp1 && i1 < ne1 - rp1) \
|
|
7650
|
+
&& (i2 >= lp2 && i2 < ne2 - rp2) \
|
|
7651
|
+
&& (i3 >= lp3 && i3 < ne3 - rp3)) {
|
|
7652
|
+
const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
|
|
7653
|
+
const float * src_ptr = (const float *)((char *) src0->data + src_idx);
|
|
7654
|
+
dst_ptr[dst_idx] = *src_ptr;
|
|
7655
|
+
} else {
|
|
7656
|
+
dst_ptr[dst_idx] = 0;
|
|
7657
|
+
}
|
|
7576
7658
|
}
|
|
7577
7659
|
}
|
|
7578
7660
|
}
|
|
@@ -7580,16 +7662,20 @@ static void wsp_ggml_compute_forward_pad_f32(
|
|
|
7580
7662
|
}
|
|
7581
7663
|
}
|
|
7582
7664
|
|
|
7665
|
+
|
|
7583
7666
|
void wsp_ggml_compute_forward_pad(
|
|
7584
7667
|
const wsp_ggml_compute_params * params,
|
|
7585
7668
|
wsp_ggml_tensor * dst) {
|
|
7586
|
-
|
|
7587
7669
|
const wsp_ggml_tensor * src0 = dst->src[0];
|
|
7588
|
-
|
|
7670
|
+
const bool circular = (bool) wsp_ggml_get_op_params_i32(dst, 8);
|
|
7589
7671
|
switch (src0->type) {
|
|
7590
7672
|
case WSP_GGML_TYPE_F32:
|
|
7591
7673
|
{
|
|
7592
|
-
|
|
7674
|
+
if (circular) {
|
|
7675
|
+
wsp_ggml_compute_forward_pad_f32<true>(params, dst);
|
|
7676
|
+
} else {
|
|
7677
|
+
wsp_ggml_compute_forward_pad_f32<false>(params, dst);
|
|
7678
|
+
}
|
|
7593
7679
|
} break;
|
|
7594
7680
|
default:
|
|
7595
7681
|
{
|
|
@@ -7794,7 +7880,7 @@ void wsp_ggml_compute_forward_timestep_embedding(
|
|
|
7794
7880
|
// wsp_ggml_compute_forward_argsort
|
|
7795
7881
|
|
|
7796
7882
|
template<enum wsp_ggml_sort_order order>
|
|
7797
|
-
struct
|
|
7883
|
+
struct cmp_argsort {
|
|
7798
7884
|
const float * data;
|
|
7799
7885
|
bool operator()(int32_t a, int32_t b) const {
|
|
7800
7886
|
if constexpr (order == WSP_GGML_SORT_ORDER_ASC) {
|
|
@@ -7833,11 +7919,11 @@ static void wsp_ggml_compute_forward_argsort_f32(
|
|
|
7833
7919
|
|
|
7834
7920
|
switch (order) {
|
|
7835
7921
|
case WSP_GGML_SORT_ORDER_ASC:
|
|
7836
|
-
std::sort(dst_data, dst_data + ne0,
|
|
7922
|
+
std::sort(dst_data, dst_data + ne0, cmp_argsort<WSP_GGML_SORT_ORDER_ASC>{src_data});
|
|
7837
7923
|
break;
|
|
7838
7924
|
|
|
7839
7925
|
case WSP_GGML_SORT_ORDER_DESC:
|
|
7840
|
-
std::sort(dst_data, dst_data + ne0,
|
|
7926
|
+
std::sort(dst_data, dst_data + ne0, cmp_argsort<WSP_GGML_SORT_ORDER_DESC>{src_data});
|
|
7841
7927
|
break;
|
|
7842
7928
|
|
|
7843
7929
|
default:
|
|
@@ -7864,6 +7950,72 @@ void wsp_ggml_compute_forward_argsort(
|
|
|
7864
7950
|
}
|
|
7865
7951
|
}
|
|
7866
7952
|
|
|
7953
|
+
// wsp_ggml_compute_forward_top_k
|
|
7954
|
+
|
|
7955
|
+
// Comparator over indices into a float array.
// Index `lhs` orders before index `rhs` when data[lhs] holds the larger value,
// so sorting a list of indices with it puts indices of larger values first.
struct cmp_top_k {
    const float * data;

    bool operator()(int32_t lhs, int32_t rhs) const {
        return data[rhs] < data[lhs];
    }
};
|
|
7961
|
+
|
|
7962
|
+
static void wsp_ggml_compute_forward_top_k_f32(
|
|
7963
|
+
const wsp_ggml_compute_params * params,
|
|
7964
|
+
wsp_ggml_tensor * dst) {
|
|
7965
|
+
|
|
7966
|
+
const wsp_ggml_tensor * src0 = dst->src[0];
|
|
7967
|
+
|
|
7968
|
+
WSP_GGML_TENSOR_UNARY_OP_LOCALS
|
|
7969
|
+
|
|
7970
|
+
WSP_GGML_ASSERT(nb0 == sizeof(float));
|
|
7971
|
+
|
|
7972
|
+
const int ith = params->ith;
|
|
7973
|
+
const int nth = params->nth;
|
|
7974
|
+
|
|
7975
|
+
const int64_t nr = wsp_ggml_nrows(src0);
|
|
7976
|
+
|
|
7977
|
+
const int top_k = ne0;
|
|
7978
|
+
|
|
7979
|
+
int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
|
7980
|
+
|
|
7981
|
+
for (int64_t i = ith; i < nr; i += nth) {
|
|
7982
|
+
const float * src_data = (float *)((char *) src0->data + i*nb01);
|
|
7983
|
+
|
|
7984
|
+
for (int64_t j = 0; j < ne00; j++) {
|
|
7985
|
+
tmp[j] = j;
|
|
7986
|
+
}
|
|
7987
|
+
|
|
7988
|
+
std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
|
|
7989
|
+
|
|
7990
|
+
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
|
|
7991
|
+
|
|
7992
|
+
std::copy(tmp, tmp + top_k, dst_data);
|
|
7993
|
+
|
|
7994
|
+
// emphasize that the order is not important
|
|
7995
|
+
if (top_k > 1) {
|
|
7996
|
+
std::swap(dst_data[0], dst_data[1]);
|
|
7997
|
+
}
|
|
7998
|
+
}
|
|
7999
|
+
}
|
|
8000
|
+
|
|
8001
|
+
void wsp_ggml_compute_forward_top_k(
|
|
8002
|
+
const wsp_ggml_compute_params * params,
|
|
8003
|
+
wsp_ggml_tensor * dst) {
|
|
8004
|
+
|
|
8005
|
+
const wsp_ggml_tensor * src0 = dst->src[0];
|
|
8006
|
+
|
|
8007
|
+
switch (src0->type) {
|
|
8008
|
+
case WSP_GGML_TYPE_F32:
|
|
8009
|
+
{
|
|
8010
|
+
wsp_ggml_compute_forward_top_k_f32(params, dst);
|
|
8011
|
+
} break;
|
|
8012
|
+
default:
|
|
8013
|
+
{
|
|
8014
|
+
WSP_GGML_ABORT("fatal error");
|
|
8015
|
+
}
|
|
8016
|
+
}
|
|
8017
|
+
}
|
|
8018
|
+
|
|
7867
8019
|
// wsp_ggml_compute_forward_flash_attn_ext
|
|
7868
8020
|
|
|
7869
8021
|
static void wsp_ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|
@@ -9696,13 +9848,13 @@ static void wsp_ggml_compute_forward_solve_tri_f32(const struct wsp_ggml_compute
|
|
|
9696
9848
|
for (int64_t i00 = 0; i00 < n; ++i00) {
|
|
9697
9849
|
float sum = 0.0f;
|
|
9698
9850
|
for (int64_t t = 0; t < i00; ++t) {
|
|
9699
|
-
sum += A_batch[i00 * n + t] * X_batch[
|
|
9851
|
+
sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
|
|
9700
9852
|
}
|
|
9701
9853
|
|
|
9702
9854
|
const float diag = A_batch[i00 * n + i00];
|
|
9703
|
-
|
|
9855
|
+
assert(diag != 0.0f && "Zero diagonal in triangular matrix");
|
|
9704
9856
|
|
|
9705
|
-
X_batch[
|
|
9857
|
+
X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
|
|
9706
9858
|
}
|
|
9707
9859
|
}
|
|
9708
9860
|
}
|
package/cpp/ggml-cpu/ops.h
CHANGED
|
@@ -81,6 +81,7 @@ void wsp_ggml_compute_forward_roll(const struct wsp_ggml_compute_params * params
|
|
|
81
81
|
void wsp_ggml_compute_forward_arange(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
82
82
|
void wsp_ggml_compute_forward_timestep_embedding(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
83
83
|
void wsp_ggml_compute_forward_argsort(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
84
|
+
void wsp_ggml_compute_forward_top_k(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
84
85
|
void wsp_ggml_compute_forward_leaky_relu(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
85
86
|
void wsp_ggml_compute_forward_tri(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
86
87
|
void wsp_ggml_compute_forward_fill(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|