whisper.rn 0.5.1 → 0.5.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/jni.cpp +12 -3
- package/cpp/ggml-alloc.c +49 -18
- package/cpp/ggml-backend-impl.h +0 -3
- package/cpp/ggml-backend-reg.cpp +8 -0
- package/cpp/ggml-backend.cpp +0 -2
- package/cpp/ggml-backend.h +2 -0
- package/cpp/ggml-cpu/amx/amx.cpp +1 -0
- package/cpp/ggml-cpu/arch/arm/quants.c +428 -26
- package/cpp/ggml-cpu/ggml-cpu-impl.h +4 -2
- package/cpp/ggml-cpu/ggml-cpu.c +67 -24
- package/cpp/ggml-cpu/ops.cpp +489 -364
- package/cpp/ggml-cpu/ops.h +4 -4
- package/cpp/ggml-cpu/repack.cpp +143 -29
- package/cpp/ggml-cpu/simd-mappings.h +25 -25
- package/cpp/ggml-cpu/unary-ops.cpp +151 -0
- package/cpp/ggml-cpu/unary-ops.h +7 -0
- package/cpp/ggml-cpu/vec.cpp +83 -0
- package/cpp/ggml-cpu/vec.h +20 -8
- package/cpp/ggml-impl.h +67 -2
- package/cpp/ggml-metal/ggml-metal-common.cpp +2 -2
- package/cpp/ggml-metal/ggml-metal-context.m +5 -6
- package/cpp/ggml-metal/ggml-metal-device.cpp +300 -14
- package/cpp/ggml-metal/ggml-metal-device.h +26 -1
- package/cpp/ggml-metal/ggml-metal-device.m +243 -28
- package/cpp/ggml-metal/ggml-metal-impl.h +177 -9
- package/cpp/ggml-metal/ggml-metal-ops.cpp +843 -157
- package/cpp/ggml-metal/ggml-metal-ops.h +8 -0
- package/cpp/ggml-metal/ggml-metal.cpp +8 -3
- package/cpp/ggml-metal/ggml-metal.metal +12436 -0
- package/cpp/ggml.c +317 -4
- package/cpp/ggml.h +139 -0
- package/cpp/jsi/RNWhisperJSI.cpp +7 -2
- package/cpp/rn-whisper.h +1 -0
- package/cpp/whisper.cpp +8 -2
- package/ios/RNWhisperContext.mm +3 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +2 -0
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNWhisper.ts +2 -0
- package/src/version.json +1 -1
- package/whisper-rn.podspec +1 -1
- package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
package/cpp/ggml-cpu/ops.h
CHANGED
|
@@ -34,6 +34,7 @@ void wsp_ggml_compute_forward_add1(const struct wsp_ggml_compute_params * params
|
|
|
34
34
|
void wsp_ggml_compute_forward_acc(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
35
35
|
void wsp_ggml_compute_forward_sum(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
36
36
|
void wsp_ggml_compute_forward_sum_rows(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
37
|
+
void wsp_ggml_compute_forward_cumsum(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
37
38
|
void wsp_ggml_compute_forward_mean(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
38
39
|
void wsp_ggml_compute_forward_argmax(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
39
40
|
void wsp_ggml_compute_forward_count_equal(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
@@ -51,10 +52,6 @@ void wsp_ggml_compute_forward_scale(const struct wsp_ggml_compute_params * param
|
|
|
51
52
|
void wsp_ggml_compute_forward_set(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
52
53
|
void wsp_ggml_compute_forward_cpy(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
53
54
|
void wsp_ggml_compute_forward_cont(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
54
|
-
void wsp_ggml_compute_forward_reshape(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
55
|
-
void wsp_ggml_compute_forward_view(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
56
|
-
void wsp_ggml_compute_forward_permute(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
57
|
-
void wsp_ggml_compute_forward_transpose(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
58
55
|
void wsp_ggml_compute_forward_get_rows(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
59
56
|
void wsp_ggml_compute_forward_get_rows_back(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
60
57
|
void wsp_ggml_compute_forward_set_rows(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
@@ -85,6 +82,8 @@ void wsp_ggml_compute_forward_arange(const struct wsp_ggml_compute_params * para
|
|
|
85
82
|
void wsp_ggml_compute_forward_timestep_embedding(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
86
83
|
void wsp_ggml_compute_forward_argsort(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
87
84
|
void wsp_ggml_compute_forward_leaky_relu(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
85
|
+
void wsp_ggml_compute_forward_tri(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
86
|
+
void wsp_ggml_compute_forward_fill(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
88
87
|
void wsp_ggml_compute_forward_flash_attn_ext(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
89
88
|
void wsp_ggml_compute_forward_flash_attn_back(
|
|
90
89
|
const struct wsp_ggml_compute_params * params,
|
|
@@ -100,6 +99,7 @@ void wsp_ggml_compute_forward_get_rel_pos(const struct wsp_ggml_compute_params *
|
|
|
100
99
|
void wsp_ggml_compute_forward_add_rel_pos(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
101
100
|
void wsp_ggml_compute_forward_rwkv_wkv6(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
102
101
|
void wsp_ggml_compute_forward_rwkv_wkv7(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
102
|
+
void wsp_ggml_compute_forward_solve_tri(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
103
103
|
void wsp_ggml_compute_forward_gla(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
104
104
|
void wsp_ggml_compute_forward_map_custom1(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
105
105
|
void wsp_ggml_compute_forward_map_custom2(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
package/cpp/ggml-cpu/repack.cpp
CHANGED
|
@@ -1600,6 +1600,55 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, wsp_ggml_type
|
|
|
1600
1600
|
return false;
|
|
1601
1601
|
}
|
|
1602
1602
|
|
|
1603
|
+
void forward_mul_mat_one_chunk(wsp_ggml_compute_params * params,
|
|
1604
|
+
wsp_ggml_tensor * op,
|
|
1605
|
+
int64_t src0_start,
|
|
1606
|
+
int64_t src0_end,
|
|
1607
|
+
int64_t src1_start,
|
|
1608
|
+
int64_t src1_end) {
|
|
1609
|
+
const wsp_ggml_tensor * src0 = op->src[0];
|
|
1610
|
+
const wsp_ggml_tensor * src1 = op->src[1];
|
|
1611
|
+
wsp_ggml_tensor * dst = op;
|
|
1612
|
+
|
|
1613
|
+
WSP_GGML_TENSOR_BINARY_OP_LOCALS
|
|
1614
|
+
|
|
1615
|
+
const size_t src1_col_stride = wsp_ggml_row_size(PARAM_TYPE, ne10);
|
|
1616
|
+
|
|
1617
|
+
WSP_GGML_ASSERT(ne03 == 1 && ne13 == 1);
|
|
1618
|
+
WSP_GGML_ASSERT(ne12 % ne02 == 0);
|
|
1619
|
+
const int64_t r2 = ne12 / ne02;
|
|
1620
|
+
|
|
1621
|
+
const int64_t i12 = src1_start / ne1;
|
|
1622
|
+
const int64_t i11 = src1_start - i12 * ne1;
|
|
1623
|
+
|
|
1624
|
+
// Determine batch index
|
|
1625
|
+
const int64_t i02 = i12 / r2;
|
|
1626
|
+
|
|
1627
|
+
const int64_t i1 = i11;
|
|
1628
|
+
const int64_t i2 = i12;
|
|
1629
|
+
|
|
1630
|
+
const char * src0_ptr = (const char *) src0->data + i02 * nb02;
|
|
1631
|
+
const char * src1_ptr = (const char *) params->wdata + (i11 + i12 * ne11) * src1_col_stride;
|
|
1632
|
+
char * dst_ptr = ((char *) dst->data + (i1 * nb1 + i2 * nb2));
|
|
1633
|
+
|
|
1634
|
+
const int64_t nrows = src1_end - src1_start;
|
|
1635
|
+
const int64_t ncols = src0_end - src0_start;
|
|
1636
|
+
|
|
1637
|
+
WSP_GGML_ASSERT(src1_ptr + src1_col_stride * nrows <= (const char *) params->wdata + params->wsize);
|
|
1638
|
+
|
|
1639
|
+
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
|
|
1640
|
+
if (nrows > 3) {
|
|
1641
|
+
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0,
|
|
1642
|
+
src0_ptr + src0_start * nb01, src1_ptr,
|
|
1643
|
+
nrows - (nrows % 4), ncols);
|
|
1644
|
+
}
|
|
1645
|
+
for (int iter = nrows - (nrows % 4); iter < nrows; iter++) {
|
|
1646
|
+
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr + (iter * nb1)) + src0_start,
|
|
1647
|
+
ne01, src0_ptr + src0_start * nb01,
|
|
1648
|
+
src1_ptr + (src1_col_stride * iter), 1 /* nrows */, ncols);
|
|
1649
|
+
}
|
|
1650
|
+
}
|
|
1651
|
+
|
|
1603
1652
|
void forward_mul_mat(wsp_ggml_compute_params * params, wsp_ggml_tensor * op) {
|
|
1604
1653
|
const wsp_ggml_tensor * src0 = op->src[0];
|
|
1605
1654
|
const wsp_ggml_tensor * src1 = op->src[1];
|
|
@@ -1621,6 +1670,12 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, wsp_ggml_type
|
|
|
1621
1670
|
WSP_GGML_ASSERT(nb1 <= nb2);
|
|
1622
1671
|
WSP_GGML_ASSERT(nb2 <= nb3);
|
|
1623
1672
|
|
|
1673
|
+
// TODO: General batched mul mat for 4D tensors
|
|
1674
|
+
// Currently only supports 3D tensors
|
|
1675
|
+
WSP_GGML_ASSERT(ne03 == 1);
|
|
1676
|
+
WSP_GGML_ASSERT(ne13 == 1);
|
|
1677
|
+
WSP_GGML_ASSERT(ne3 == 1);
|
|
1678
|
+
|
|
1624
1679
|
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
|
|
1625
1680
|
|
|
1626
1681
|
WSP_GGML_ASSERT(wsp_ggml_n_dims(op->src[0]) == 2);
|
|
@@ -1628,46 +1683,101 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, wsp_ggml_type
|
|
|
1628
1683
|
|
|
1629
1684
|
char * wdata = static_cast<char *>(params->wdata);
|
|
1630
1685
|
const size_t nbw1 = wsp_ggml_row_size(PARAM_TYPE, ne10);
|
|
1686
|
+
const size_t nbw2 = nbw1 * ne11;
|
|
1631
1687
|
|
|
1632
|
-
assert(params->wsize >=
|
|
1688
|
+
assert(params->wsize >= nbw2 * ne12);
|
|
1633
1689
|
|
|
1634
1690
|
const wsp_ggml_from_float_t from_float = wsp_ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
|
|
1635
1691
|
|
|
1636
|
-
|
|
1637
|
-
|
|
1638
|
-
|
|
1639
|
-
|
|
1692
|
+
// INFO: Quantization is done in planes to avoid extra complexity in chunking.
|
|
1693
|
+
// Flattening dimensions not multiple of INTER_SIZE would require extra handling depending on how
|
|
1694
|
+
// the planes are broadcast.
|
|
1695
|
+
for (int64_t i12 = 0; i12 < ne12; i12++) {
|
|
1696
|
+
char * data_ptr = (char *) src1->data + i12 * nb12;
|
|
1697
|
+
char * wdata_ptr = wdata + i12 * nbw2;
|
|
1698
|
+
|
|
1699
|
+
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
|
|
1700
|
+
wsp_ggml_wsp_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
|
|
1701
|
+
(void *) (wdata_ptr + i11 * nbw1), 4, ne10);
|
|
1702
|
+
}
|
|
1640
1703
|
|
|
1641
|
-
|
|
1642
|
-
|
|
1643
|
-
|
|
1704
|
+
const int64_t i11_processed = ne11 - ne11 % 4;
|
|
1705
|
+
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
|
|
1706
|
+
from_float((float *) (data_ptr + i11 * nb11), (void *) (wdata_ptr + i11 * nbw1), ne10);
|
|
1707
|
+
}
|
|
1644
1708
|
}
|
|
1645
1709
|
|
|
1646
|
-
|
|
1710
|
+
// disable for NUMA
|
|
1711
|
+
const bool disable_chunking = wsp_ggml_is_numa();
|
|
1647
1712
|
|
|
1648
|
-
|
|
1649
|
-
const
|
|
1650
|
-
|
|
1651
|
-
|
|
1652
|
-
|
|
1653
|
-
|
|
1654
|
-
|
|
1655
|
-
|
|
1713
|
+
// 4x chunks per thread
|
|
1714
|
+
const int64_t nr0 = wsp_ggml_nrows(op->src[0]);
|
|
1715
|
+
|
|
1716
|
+
int nth_scaled = nth * 4;
|
|
1717
|
+
int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled;
|
|
1718
|
+
int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0;
|
|
1719
|
+
|
|
1720
|
+
// src1 is chunked only by full planes.
|
|
1721
|
+
// When we flatten we need to address dimensions not multiple of the q8 INTER_SIZE
|
|
1722
|
+
// to route them thorugh GEMV.
|
|
1723
|
+
// nchunk1 = ne12 also avoids messing the chunking for models with no 3d tensors
|
|
1724
|
+
// to avoid affecting their performance
|
|
1725
|
+
int64_t nchunk1 = ne12;
|
|
1726
|
+
|
|
1727
|
+
// Ensure minimum chunk size to avoid alignment issues with high thread counts
|
|
1728
|
+
// Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
|
|
1729
|
+
const int64_t min_chunk_size = NB_COLS;
|
|
1730
|
+
if (nchunk0 > 0 && (nr0 / nchunk0) < min_chunk_size && nr0 >= min_chunk_size) {
|
|
1731
|
+
nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
|
|
1656
1732
|
}
|
|
1657
1733
|
|
|
1658
|
-
|
|
1659
|
-
|
|
1660
|
-
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
|
|
1661
|
-
(float *) ((char *) dst->data) + src0_start, ne01,
|
|
1662
|
-
(const char *) src0->data + src0_start * nb01,
|
|
1663
|
-
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
|
|
1734
|
+
if (nth == 1 || nchunk0 < nth || disable_chunking) {
|
|
1735
|
+
nchunk0 = nth;
|
|
1664
1736
|
}
|
|
1665
|
-
|
|
1666
|
-
|
|
1667
|
-
|
|
1668
|
-
|
|
1669
|
-
|
|
1670
|
-
|
|
1737
|
+
|
|
1738
|
+
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
|
|
1739
|
+
|
|
1740
|
+
// Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
|
|
1741
|
+
// This prevents creating too many tiny chunks that could overlap after alignment
|
|
1742
|
+
const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
|
|
1743
|
+
nchunk0 = MIN(nchunk0, max_nchunk);
|
|
1744
|
+
|
|
1745
|
+
if (ith == 0) {
|
|
1746
|
+
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
|
|
1747
|
+
wsp_ggml_threadpool_chunk_set(params->threadpool, nth);
|
|
1748
|
+
}
|
|
1749
|
+
|
|
1750
|
+
wsp_ggml_barrier(params->threadpool);
|
|
1751
|
+
|
|
1752
|
+
// The first chunk comes from our thread_id, the rest will get auto-assigned.
|
|
1753
|
+
int current_chunk = ith;
|
|
1754
|
+
|
|
1755
|
+
while (current_chunk < nchunk0 * nchunk1) {
|
|
1756
|
+
const int64_t ith0 = current_chunk % nchunk0;
|
|
1757
|
+
const int64_t ith1 = current_chunk / nchunk0;
|
|
1758
|
+
|
|
1759
|
+
int64_t src0_start = dr0 * ith0;
|
|
1760
|
+
int64_t src0_end = MIN(src0_start + dr0, nr0);
|
|
1761
|
+
|
|
1762
|
+
// full-plane range for src1
|
|
1763
|
+
int64_t src1_start = ith1 * ne11;
|
|
1764
|
+
int64_t src1_end = (ith1 + 1) * ne11;
|
|
1765
|
+
|
|
1766
|
+
// Align boundaries to NB_COLS - round up to ensure all data is included
|
|
1767
|
+
// The chunk size limiting above ensures chunks are large enough to prevent overlaps
|
|
1768
|
+
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
|
|
1769
|
+
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
|
|
1770
|
+
src0_end = MIN(src0_end, ne01);
|
|
1771
|
+
|
|
1772
|
+
// Make sure current plane is the last one before exiting
|
|
1773
|
+
if (src0_start >= src0_end) {
|
|
1774
|
+
current_chunk = wsp_ggml_threadpool_chunk_add(params->threadpool, 1);
|
|
1775
|
+
continue;
|
|
1776
|
+
}
|
|
1777
|
+
|
|
1778
|
+
forward_mul_mat_one_chunk(params, dst, src0_start, src0_end, src1_start, src1_end);
|
|
1779
|
+
|
|
1780
|
+
current_chunk = wsp_ggml_threadpool_chunk_add(params->threadpool, 1);
|
|
1671
1781
|
}
|
|
1672
1782
|
}
|
|
1673
1783
|
|
|
@@ -1772,8 +1882,12 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, wsp_ggml_type
|
|
|
1772
1882
|
int64_t src0_cur_start = (ith * ne01) / nth;
|
|
1773
1883
|
int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
|
|
1774
1884
|
|
|
1885
|
+
// Align boundaries to NB_COLS - round up to ensure all data is included
|
|
1775
1886
|
src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
|
|
1776
1887
|
src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
|
|
1888
|
+
if (src0_cur_end > ne01) {
|
|
1889
|
+
src0_cur_end = ne01;
|
|
1890
|
+
}
|
|
1777
1891
|
|
|
1778
1892
|
if (src0_cur_start >= src0_cur_end) {
|
|
1779
1893
|
return;
|
|
@@ -956,7 +956,7 @@ do { \
|
|
|
956
956
|
|
|
957
957
|
#define WSP_GGML_F32Cx8 __m256
|
|
958
958
|
#define WSP_GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
|
|
959
|
-
#define WSP_GGML_F32Cx8_SET1(x) (__m256)
|
|
959
|
+
#define WSP_GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
|
|
960
960
|
|
|
961
961
|
static inline __m256 __lasx_f32cx8_load(const wsp_ggml_fp16_t * x) {
|
|
962
962
|
__m256i a;
|
|
@@ -999,34 +999,34 @@ static inline void __lasx_f32cx8_store(wsp_ggml_fp16_t * x, __m256 y) {
|
|
|
999
999
|
|
|
1000
1000
|
#define WSP_GGML_F32x4 __m128
|
|
1001
1001
|
#define WSP_GGML_F32x4_ZERO (__m128)__lsx_vldi(0)
|
|
1002
|
-
#define WSP_GGML_F32x4_SET1(x) (__m128)
|
|
1002
|
+
#define WSP_GGML_F32x4_SET1(x) (__m128)__lsx_vreplfr2vr_s((x))
|
|
1003
1003
|
#define WSP_GGML_F32x4_LOAD(x) (__m128)__lsx_vld((x), 0)
|
|
1004
1004
|
#define WSP_GGML_F32x4_STORE(x, y) __lsx_vst(y, x, 0)
|
|
1005
1005
|
#define WSP_GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
|
|
1006
1006
|
#define WSP_GGML_F32x4_ADD __lsx_vfadd_s
|
|
1007
1007
|
#define WSP_GGML_F32x4_MUL __lsx_vfmul_s
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
res
|
|
1008
|
+
|
|
1009
|
+
#define WSP_GGML_F32x4_REDUCE(res, x) \
|
|
1010
|
+
{ \
|
|
1011
|
+
int offset = WSP_GGML_F32_ARR >> 1; \
|
|
1012
|
+
for (int i = 0; i < offset; ++i) { \
|
|
1013
|
+
x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
|
|
1014
|
+
} \
|
|
1015
|
+
offset >>= 1; \
|
|
1016
|
+
for (int i = 0; i < offset; ++i) { \
|
|
1017
|
+
x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
|
|
1018
|
+
} \
|
|
1019
|
+
offset >>= 1; \
|
|
1020
|
+
for (int i = 0; i < offset; ++i) { \
|
|
1021
|
+
x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
|
|
1022
|
+
} \
|
|
1023
|
+
__m128i t0 = __lsx_vpickev_w((__m128i)x[0], (__m128i)x[0]); \
|
|
1024
|
+
__m128i t1 = __lsx_vpickod_w((__m128i)x[0], (__m128i)x[0]); \
|
|
1025
|
+
__m128 t2 = __lsx_vfadd_s((__m128)t0, (__m128)t1); \
|
|
1026
|
+
__m128i t3 = __lsx_vpickev_w((__m128i)t2, (__m128i)t2); \
|
|
1027
|
+
__m128i t4 = __lsx_vpickod_w((__m128i)t2, (__m128i)t2); \
|
|
1028
|
+
__m128 t5 = __lsx_vfadd_s((__m128)t3, (__m128)t4); \
|
|
1029
|
+
res = (wsp_ggml_float) ((v4f32)t5)[0]; \
|
|
1030
1030
|
}
|
|
1031
1031
|
|
|
1032
1032
|
#define WSP_GGML_F32_VEC WSP_GGML_F32x4
|
|
@@ -1068,7 +1068,7 @@ static inline void __lsx_f16x4_store(wsp_ggml_fp16_t * x, __m128 y) {
|
|
|
1068
1068
|
|
|
1069
1069
|
#define WSP_GGML_F32Cx4 __m128
|
|
1070
1070
|
#define WSP_GGML_F32Cx4_ZERO (__m128)__lsx_vldi(0)
|
|
1071
|
-
#define WSP_GGML_F32Cx4_SET1(x) (__m128)
|
|
1071
|
+
#define WSP_GGML_F32Cx4_SET1(x) (__m128)__lsx_vreplfr2vr_s((x))
|
|
1072
1072
|
#define WSP_GGML_F32Cx4_LOAD(x) (__m128)__lsx_f16x4_load(x)
|
|
1073
1073
|
#define WSP_GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
|
|
1074
1074
|
#define WSP_GGML_F32Cx4_FMA WSP_GGML_F32x4_FMA
|
|
@@ -52,6 +52,15 @@ static inline float op_sqrt(float x) {
|
|
|
52
52
|
return sqrtf(x);
|
|
53
53
|
}
|
|
54
54
|
|
|
55
|
+
static inline float op_xielu(float x, float alpha_n, float alpha_p, float beta, float eps) {
|
|
56
|
+
if (x > 0.0f) {
|
|
57
|
+
return alpha_p * x * x + beta * x;
|
|
58
|
+
} else {
|
|
59
|
+
const float min_x_eps = fminf(x, eps);
|
|
60
|
+
return (expm1f(min_x_eps) - x) * alpha_n + beta * x;
|
|
61
|
+
}
|
|
62
|
+
}
|
|
63
|
+
|
|
55
64
|
static inline float op_sin(float x) {
|
|
56
65
|
return sinf(x);
|
|
57
66
|
}
|
|
@@ -64,6 +73,30 @@ static inline float op_log(float x) {
|
|
|
64
73
|
return logf(x);
|
|
65
74
|
}
|
|
66
75
|
|
|
76
|
+
static inline float op_expm1(float x) {
|
|
77
|
+
return expf(x) - 1.0f;
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
static inline float op_softplus(float x) {
|
|
81
|
+
return (x > 20.0f) ? x : logf(1.0f + expf(x));
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
static inline float op_floor(float x) {
|
|
85
|
+
return floorf(x);
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
static inline float op_ceil(float x) {
|
|
89
|
+
return ceilf(x);
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
static inline float op_round(float x) {
|
|
93
|
+
return roundf(x);
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
static inline float op_trunc(float x) {
|
|
97
|
+
return truncf(x);
|
|
98
|
+
}
|
|
99
|
+
|
|
67
100
|
template <float (*op)(float), typename src0_t, typename dst_t>
|
|
68
101
|
static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
|
|
69
102
|
constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
|
|
@@ -121,6 +154,86 @@ static void unary_op(const wsp_ggml_compute_params * params, wsp_ggml_tensor * d
|
|
|
121
154
|
}
|
|
122
155
|
}
|
|
123
156
|
|
|
157
|
+
template <float (*op)(float, wsp_ggml_tensor *)>
|
|
158
|
+
static void unary_op_params(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
|
|
159
|
+
const wsp_ggml_tensor * src0 = dst->src[0];
|
|
160
|
+
|
|
161
|
+
/* */ if (src0->type == WSP_GGML_TYPE_F32 && dst->type == WSP_GGML_TYPE_F32) { // all f32
|
|
162
|
+
apply_unary_op<op, float, float>(params, dst);
|
|
163
|
+
} else if (src0->type == WSP_GGML_TYPE_F16 && dst->type == WSP_GGML_TYPE_F16) { // all f16
|
|
164
|
+
apply_unary_op<op, wsp_ggml_fp16_t, wsp_ggml_fp16_t>(params, dst);
|
|
165
|
+
} else if (src0->type == WSP_GGML_TYPE_BF16 && dst->type == WSP_GGML_TYPE_BF16) { // all bf16
|
|
166
|
+
apply_unary_op<op, wsp_ggml_bf16_t, wsp_ggml_bf16_t>(params, dst);
|
|
167
|
+
} else if (src0->type == WSP_GGML_TYPE_BF16 && dst->type == WSP_GGML_TYPE_F32) {
|
|
168
|
+
apply_unary_op<op, wsp_ggml_bf16_t, float>(params, dst);
|
|
169
|
+
} else if (src0->type == WSP_GGML_TYPE_F16 && dst->type == WSP_GGML_TYPE_F32) {
|
|
170
|
+
apply_unary_op<op, wsp_ggml_fp16_t, float>(params, dst);
|
|
171
|
+
} else {
|
|
172
|
+
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
|
|
173
|
+
wsp_ggml_type_name(dst->type), wsp_ggml_type_name(src0->type));
|
|
174
|
+
WSP_GGML_ABORT("fatal error");
|
|
175
|
+
}
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
// Extend vec_unary_op to support functors
|
|
179
|
+
template <typename Op, typename src0_t, typename dst_t>
|
|
180
|
+
static inline void vec_unary_op_functor(int64_t n, dst_t * y, const src0_t * x, Op op) {
|
|
181
|
+
constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
|
|
182
|
+
constexpr auto f32_to_dst = type_conversion_table<dst_t >::from_f32;
|
|
183
|
+
|
|
184
|
+
for (int i = 0; i < n; i++) {
|
|
185
|
+
y[i] = f32_to_dst(op(src0_to_f32(x[i])));
|
|
186
|
+
}
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
// Extend apply_unary_op to support functors
|
|
190
|
+
template <typename Op, typename src0_t, typename dst_t>
|
|
191
|
+
static void apply_unary_op_functor(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst, Op op) {
|
|
192
|
+
const wsp_ggml_tensor * src0 = dst->src[0];
|
|
193
|
+
|
|
194
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0) && wsp_ggml_is_contiguous_1(dst) && wsp_ggml_are_same_shape(src0, dst));
|
|
195
|
+
|
|
196
|
+
WSP_GGML_TENSOR_UNARY_OP_LOCALS
|
|
197
|
+
|
|
198
|
+
WSP_GGML_ASSERT( nb0 == sizeof(dst_t));
|
|
199
|
+
WSP_GGML_ASSERT(nb00 == sizeof(src0_t));
|
|
200
|
+
|
|
201
|
+
const auto [ir0, ir1] = get_thread_range(params, src0);
|
|
202
|
+
|
|
203
|
+
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
|
204
|
+
const int64_t i03 = ir/(ne02*ne01);
|
|
205
|
+
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
|
206
|
+
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
|
207
|
+
|
|
208
|
+
dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
|
|
209
|
+
const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
|
210
|
+
|
|
211
|
+
vec_unary_op_functor(ne0, dst_ptr, src0_ptr, op);
|
|
212
|
+
}
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
// Generic dispatcher for functors
|
|
216
|
+
template <typename Op>
|
|
217
|
+
static void unary_op_functor(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst, Op op) {
|
|
218
|
+
const wsp_ggml_tensor * src0 = dst->src[0];
|
|
219
|
+
|
|
220
|
+
/* */ if (src0->type == WSP_GGML_TYPE_F32 && dst->type == WSP_GGML_TYPE_F32) { // all f32
|
|
221
|
+
apply_unary_op_functor<Op, float, float>(params, dst, op);
|
|
222
|
+
} else if (src0->type == WSP_GGML_TYPE_F16 && dst->type == WSP_GGML_TYPE_F16) { // all f16
|
|
223
|
+
apply_unary_op_functor<Op, wsp_ggml_fp16_t, wsp_ggml_fp16_t>(params, dst, op);
|
|
224
|
+
} else if (src0->type == WSP_GGML_TYPE_BF16 && dst->type == WSP_GGML_TYPE_BF16) { // all bf16
|
|
225
|
+
apply_unary_op_functor<Op, wsp_ggml_bf16_t, wsp_ggml_bf16_t>(params, dst, op);
|
|
226
|
+
} else if (src0->type == WSP_GGML_TYPE_BF16 && dst->type == WSP_GGML_TYPE_F32) {
|
|
227
|
+
apply_unary_op_functor<Op, wsp_ggml_bf16_t, float>(params, dst, op);
|
|
228
|
+
} else if (src0->type == WSP_GGML_TYPE_F16 && dst->type == WSP_GGML_TYPE_F32) {
|
|
229
|
+
apply_unary_op_functor<Op, wsp_ggml_fp16_t, float>(params, dst, op);
|
|
230
|
+
} else {
|
|
231
|
+
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
|
|
232
|
+
wsp_ggml_type_name(dst->type), wsp_ggml_type_name(src0->type));
|
|
233
|
+
WSP_GGML_ABORT("fatal error");
|
|
234
|
+
}
|
|
235
|
+
}
|
|
236
|
+
|
|
124
237
|
void wsp_ggml_compute_forward_abs(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
|
|
125
238
|
unary_op<op_abs>(params, dst);
|
|
126
239
|
}
|
|
@@ -184,3 +297,41 @@ void wsp_ggml_compute_forward_cos(const wsp_ggml_compute_params * params, wsp_gg
|
|
|
184
297
|
void wsp_ggml_compute_forward_log(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
|
|
185
298
|
unary_op<op_log>(params, dst);
|
|
186
299
|
}
|
|
300
|
+
|
|
301
|
+
void wsp_ggml_compute_forward_expm1(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
|
|
302
|
+
unary_op<op_expm1>(params, dst);
|
|
303
|
+
}
|
|
304
|
+
|
|
305
|
+
void wsp_ggml_compute_forward_softplus(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
|
|
306
|
+
unary_op<op_softplus>(params, dst);
|
|
307
|
+
}
|
|
308
|
+
|
|
309
|
+
void wsp_ggml_compute_forward_floor(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
|
|
310
|
+
unary_op<op_floor>(params, dst);
|
|
311
|
+
}
|
|
312
|
+
|
|
313
|
+
void wsp_ggml_compute_forward_ceil(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
|
|
314
|
+
unary_op<op_ceil>(params, dst);
|
|
315
|
+
}
|
|
316
|
+
|
|
317
|
+
void wsp_ggml_compute_forward_round(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
|
|
318
|
+
unary_op<op_round>(params, dst);
|
|
319
|
+
}
|
|
320
|
+
|
|
321
|
+
void wsp_ggml_compute_forward_trunc(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
|
|
322
|
+
unary_op<op_trunc>(params, dst);
|
|
323
|
+
}
|
|
324
|
+
|
|
325
|
+
void wsp_ggml_compute_forward_xielu(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
|
|
326
|
+
const float alpha_n = wsp_ggml_get_op_params_f32(dst, 1);
|
|
327
|
+
const float alpha_p = wsp_ggml_get_op_params_f32(dst, 2);
|
|
328
|
+
const float beta = wsp_ggml_get_op_params_f32(dst, 3);
|
|
329
|
+
const float eps = wsp_ggml_get_op_params_f32(dst, 4);
|
|
330
|
+
|
|
331
|
+
const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) {
|
|
332
|
+
return op_xielu(f, alpha_n, alpha_p, beta, eps);
|
|
333
|
+
};
|
|
334
|
+
|
|
335
|
+
unary_op_functor(params, dst, xielu_op_params);
|
|
336
|
+
}
|
|
337
|
+
|
package/cpp/ggml-cpu/unary-ops.h
CHANGED
|
@@ -22,6 +22,13 @@ void wsp_ggml_compute_forward_sqrt(const struct wsp_ggml_compute_params * params
|
|
|
22
22
|
void wsp_ggml_compute_forward_sin(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
23
23
|
void wsp_ggml_compute_forward_cos(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
24
24
|
void wsp_ggml_compute_forward_log(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
25
|
+
void wsp_ggml_compute_forward_expm1(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
26
|
+
void wsp_ggml_compute_forward_softplus(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
27
|
+
void wsp_ggml_compute_forward_floor(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
28
|
+
void wsp_ggml_compute_forward_ceil(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
29
|
+
void wsp_ggml_compute_forward_round(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
30
|
+
void wsp_ggml_compute_forward_trunc(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
31
|
+
void wsp_ggml_compute_forward_xielu(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
|
|
25
32
|
|
|
26
33
|
#ifdef __cplusplus
|
|
27
34
|
}
|
package/cpp/ggml-cpu/vec.cpp
CHANGED
|
@@ -360,6 +360,13 @@ void wsp_ggml_vec_silu_f32(const int n, float * y, const float * x) {
|
|
|
360
360
|
for (; i + 3 < n; i += 4) {
|
|
361
361
|
vst1q_f32(y + i, wsp_ggml_v_silu(vld1q_f32(x + i)));
|
|
362
362
|
}
|
|
363
|
+
#elif defined(__riscv_v_intrinsic)
|
|
364
|
+
for (int vl; i < n; i += vl) {
|
|
365
|
+
vl = __riscv_vsetvl_e32m2(n - i);
|
|
366
|
+
vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);
|
|
367
|
+
vfloat32m2_t vy = wsp_ggml_v_silu_m2(vx, vl);
|
|
368
|
+
__riscv_vse32_v_f32m2(&y[i], vy, vl);
|
|
369
|
+
}
|
|
363
370
|
#endif
|
|
364
371
|
for (; i < n; ++i) {
|
|
365
372
|
y[i] = wsp_ggml_silu_f32(x[i]);
|
|
@@ -404,6 +411,82 @@ void wsp_ggml_vec_swiglu_f32(const int n, float * y, const float * x, const floa
|
|
|
404
411
|
}
|
|
405
412
|
}
|
|
406
413
|
|
|
414
|
+
// Centered variance helper: writes y[i] = x[i] - mean for all i and returns
// (1/n) * sum((x[i] - mean)^2), i.e. the population variance of x given its
// precomputed mean. SIMD fast paths (AVX-512 / AVX2 / SSE2 / NEON / VXE / RVV)
// process full vector lanes; the scalar loop at the end handles the remainder
// (or everything when no SIMD path is compiled in).
// NOTE(review): n is assumed > 0 — n == 0 would divide by zero; confirm callers.
wsp_ggml_float wsp_ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean) {
    int i = 0;
    wsp_ggml_float sum = 0;
    // TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE
    // ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344
#if defined(__AVX512F__) && defined(__AVX512DQ__)
    // 16 floats per iteration; _mm512_reduce_add_ps folds the squared lane sums.
    for (; i + 15 < n; i += 16) {
        __m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i),
                                   _mm512_set1_ps(mean));
        _mm512_storeu_ps(y + i, val);
        sum += (wsp_ggml_float)_mm512_reduce_add_ps(_mm512_mul_ps(val, val));
    }
#elif defined(__AVX2__) && defined(__FMA__)
    // 8 floats per iteration; horizontal sum done by folding the high 128-bit
    // half onto the low half, then pairwise within the 128-bit lane.
    for (; i + 7 < n; i += 8) {
        __m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i),
                                   _mm256_set1_ps(mean));
        _mm256_storeu_ps(y + i, val);
        val = _mm256_mul_ps(val,val);
        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
                                 _mm256_castps256_ps128(val));
        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
        sum += (wsp_ggml_float)_mm_cvtss_f32(val2);
    }
#elif defined(__SSE2__)
    // 4 floats per iteration; two horizontal-sum variants depending on
    // whether SSE3's movehdup is available (implied here by AVX defines).
    for (; i + 3 < n; i += 4) {
        __m128 val = _mm_sub_ps(_mm_loadu_ps(x + i),
                                _mm_set1_ps(mean));
        _mm_storeu_ps(y + i, val);
        val = _mm_mul_ps(val, val);
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
        val = _mm_add_ss(val, _mm_movehdup_ps(val));
#else
        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
        val = _mm_add_ps(val, tmp);
        tmp = _mm_movehl_ps(tmp, val);
        val = _mm_add_ss(val, tmp);
#endif // __AVX__ || __AVX2__ || __AVX512F__
        sum += (wsp_ggml_float)_mm_cvtss_f32(val);
    }
#elif defined(__ARM_NEON) && defined(__aarch64__)
    // 4 floats per iteration; vaddvq_f32 is the across-lanes add (AArch64 only).
    for (; i + 3 < n; i += 4) {
        float32x4_t val = vsubq_f32(vld1q_f32(x + i),
                                    vdupq_n_f32(mean));
        vst1q_f32(y + i, val);
        val = vmulq_f32(val, val);
        sum += (wsp_ggml_float)vaddvq_f32(val);
    }
#elif defined(__VXE__) || defined(__VXE2__)
    // s390x vector extension: 4 floats per iteration.
    for (; i + 3 < n; i += 4) {
        float32x4_t val = vec_sub(vec_xl(0, x + i), vec_splats(mean));
        vec_xst(val, 0, y + i);
        val = vec_mul(val, val);
        sum += (wsp_ggml_float)vec_hsum_f32x4(val);
    }
#elif defined(__riscv_v_intrinsic)
    // RVV: variable vector length per iteration; squared sums are widened to
    // f64 and accumulated in vsum, read out once after the loop.
    vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
    for (int vl; i < n; i += vl) {
        vl = __riscv_vsetvl_e32m2(n - i);
        vfloat32m2_t val = __riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], vl), mean, vl);
        __riscv_vse32_v_f32m2(&y[i], val, vl);
        val = __riscv_vfmul_vv_f32m2(val, val, vl);
        vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, vl);
    }
    sum = (wsp_ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
#endif
    // Scalar tail: handles the elements left over by the vector loops above.
    for (; i < n; ++i) {
        float val = x[i] - mean;
        y[i] = val;
        val *= val;
        sum += (wsp_ggml_float)val;
    }
    return sum/n;
}
|
|
489
|
+
|
|
407
490
|
wsp_ggml_float wsp_ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
|
|
408
491
|
int i = 0;
|
|
409
492
|
wsp_ggml_float sum = 0;
|