whisper.rn 0.4.0-rc.9 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +5 -1
- package/android/build.gradle +12 -3
- package/android/src/main/CMakeLists.txt +43 -13
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +33 -35
- package/android/src/main/jni.cpp +9 -0
- package/android/src/main/jniLibs/arm64-v8a/librnwhisper.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnwhisper_v8fp16_va_2.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/librnwhisper.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/librnwhisper_vfpv4.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnwhisper.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnwhisper_x86_64.so +0 -0
- package/cpp/coreml/whisper-compat.h +10 -0
- package/cpp/coreml/whisper-compat.m +35 -0
- package/cpp/coreml/whisper-decoder-impl.h +27 -15
- package/cpp/coreml/whisper-decoder-impl.m +36 -10
- package/cpp/coreml/whisper-encoder-impl.h +21 -9
- package/cpp/coreml/whisper-encoder-impl.m +29 -3
- package/cpp/ggml-alloc.c +39 -37
- package/cpp/ggml-alloc.h +1 -1
- package/cpp/ggml-backend-impl.h +55 -27
- package/cpp/ggml-backend-reg.cpp +591 -0
- package/cpp/ggml-backend.cpp +336 -955
- package/cpp/ggml-backend.h +70 -42
- package/cpp/ggml-common.h +57 -49
- package/cpp/ggml-cpp.h +39 -0
- package/cpp/ggml-cpu/amx/amx.cpp +221 -0
- package/cpp/ggml-cpu/amx/amx.h +8 -0
- package/cpp/ggml-cpu/amx/common.h +91 -0
- package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
- package/cpp/ggml-cpu/amx/mmq.h +10 -0
- package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/ggml-cpu/arch/arm/quants.c +4113 -0
- package/cpp/ggml-cpu/arch/arm/repack.cpp +2162 -0
- package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
- package/cpp/ggml-cpu/arch/x86/quants.c +4310 -0
- package/cpp/ggml-cpu/arch/x86/repack.cpp +3284 -0
- package/cpp/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/ggml-cpu/binary-ops.cpp +158 -0
- package/cpp/ggml-cpu/binary-ops.h +16 -0
- package/cpp/ggml-cpu/common.h +72 -0
- package/cpp/ggml-cpu/ggml-cpu-impl.h +511 -0
- package/cpp/ggml-cpu/ggml-cpu.c +3473 -0
- package/cpp/ggml-cpu/ggml-cpu.cpp +671 -0
- package/cpp/ggml-cpu/ops.cpp +9085 -0
- package/cpp/ggml-cpu/ops.h +111 -0
- package/cpp/ggml-cpu/quants.c +1157 -0
- package/cpp/ggml-cpu/quants.h +89 -0
- package/cpp/ggml-cpu/repack.cpp +1570 -0
- package/cpp/ggml-cpu/repack.h +98 -0
- package/cpp/ggml-cpu/simd-mappings.h +1006 -0
- package/cpp/ggml-cpu/traits.cpp +36 -0
- package/cpp/ggml-cpu/traits.h +38 -0
- package/cpp/ggml-cpu/unary-ops.cpp +186 -0
- package/cpp/ggml-cpu/unary-ops.h +28 -0
- package/cpp/ggml-cpu/vec.cpp +321 -0
- package/cpp/ggml-cpu/vec.h +973 -0
- package/cpp/ggml-cpu.h +143 -0
- package/cpp/ggml-impl.h +417 -23
- package/cpp/ggml-metal-impl.h +622 -0
- package/cpp/ggml-metal.h +9 -9
- package/cpp/ggml-metal.m +3451 -1344
- package/cpp/ggml-opt.cpp +1037 -0
- package/cpp/ggml-opt.h +237 -0
- package/cpp/ggml-quants.c +296 -10818
- package/cpp/ggml-quants.h +78 -125
- package/cpp/ggml-threading.cpp +12 -0
- package/cpp/ggml-threading.h +14 -0
- package/cpp/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-whisper.metallib +0 -0
- package/cpp/ggml.c +4633 -21450
- package/cpp/ggml.h +320 -661
- package/cpp/gguf.cpp +1347 -0
- package/cpp/gguf.h +202 -0
- package/cpp/rn-whisper.cpp +4 -11
- package/cpp/whisper-arch.h +197 -0
- package/cpp/whisper.cpp +2022 -495
- package/cpp/whisper.h +75 -18
- package/ios/CMakeLists.txt +95 -0
- package/ios/RNWhisper.h +5 -0
- package/ios/RNWhisperAudioUtils.m +4 -0
- package/ios/RNWhisperContext.h +5 -0
- package/ios/RNWhisperContext.mm +4 -2
- package/ios/rnwhisper.xcframework/Info.plist +74 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/jest/mock.js +5 -0
- package/lib/commonjs/version.json +1 -1
- package/lib/module/version.json +1 -1
- package/package.json +10 -6
- package/src/version.json +1 -1
- package/whisper-rn.podspec +11 -18
- package/cpp/README.md +0 -4
- package/cpp/ggml-aarch64.c +0 -3209
- package/cpp/ggml-aarch64.h +0 -39
- package/cpp/ggml-cpu-impl.h +0 -614
package/cpp/ggml-cpu/traits.cpp

```diff
@@ -0,0 +1,36 @@
+#include "traits.h"
+
+#include "ggml-backend-impl.h"
+#include "ggml-backend.h"
+
+namespace ggml::cpu {
+tensor_traits::~tensor_traits() {}
+
+extra_buffer_type::~extra_buffer_type() {}
+} // namespace ggml::cpu
+
+bool wsp_ggml_cpu_extra_compute_forward(struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * op) {
+    for (auto extra : wsp_ggml_backend_cpu_get_extra_buffers_type()) {
+        if (extra && extra->context) {
+            auto buf_extra     = (ggml::cpu::extra_buffer_type *) extra->context;
+            auto tensor_traits = buf_extra->get_tensor_traits(op);
+            if (tensor_traits && tensor_traits->compute_forward(params, op)) {
+                return true;
+            }
+        }
+    }
+    return false;
+}
+
+bool wsp_ggml_cpu_extra_work_size(int n_threads, const struct wsp_ggml_tensor * op, size_t * size) {
+    for (auto extra : wsp_ggml_backend_cpu_get_extra_buffers_type()) {
+        if (extra && extra->context) {
+            auto buf_extra     = (ggml::cpu::extra_buffer_type *) extra->context;
+            auto tensor_traits = buf_extra->get_tensor_traits(op);
+            if (tensor_traits && tensor_traits->work_size(n_threads, op, *size)) {
+                return true;
+            }
+        }
+    }
+    return false;
+}
```
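These two dispatch helpers are the whole contract between the generic CPU backend and optional accelerated paths (such as the new AMX code under `package/cpp/ggml-cpu/amx/`): walk every registered extra buffer type, ask it for tensor traits for the op, and let the first taker handle it. The standalone sketch below models that shape outside of ggml; every name in it is hypothetical and it is not package code.

```cpp
// Minimal model of the first-match dispatch in traits.cpp (hypothetical names).
#include <iostream>
#include <vector>

struct Op { int type; };  // stand-in for wsp_ggml_tensor

class TensorTraits {
public:
    virtual ~TensorTraits() = default;
    virtual bool compute_forward(const Op & op) = 0;
};

class ExtraBufferType {
public:
    virtual ~ExtraBufferType() = default;
    // nullptr means "this buffer type does not claim the op"
    virtual TensorTraits * get_tensor_traits(const Op & op) = 0;
};

class MulMatTraits : public TensorTraits {
public:
    bool compute_forward(const Op &) override {
        std::cout << "handled by accelerated path\n";
        return true;
    }
};

class MulMatBufferType : public ExtraBufferType {
    MulMatTraits traits;
public:
    TensorTraits * get_tensor_traits(const Op & op) override {
        return op.type == 0 ? &traits : nullptr;  // only claims op type 0
    }
};

// Mirrors wsp_ggml_cpu_extra_compute_forward: first claimant wins.
static bool extra_compute_forward(std::vector<ExtraBufferType *> & extras, const Op & op) {
    for (auto * extra : extras) {
        if (auto * traits = extra->get_tensor_traits(op)) {
            if (traits->compute_forward(op)) return true;
        }
    }
    return false;  // caller falls back to the generic CPU kernel
}

int main() {
    MulMatBufferType amx;
    std::vector<ExtraBufferType *> extras = { &amx };
    extra_compute_forward(extras, Op{0});                      // accelerated
    std::cout << extra_compute_forward(extras, Op{1}) << "\n"; // 0: falls back
}
```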
package/cpp/ggml-cpu/traits.h

```diff
@@ -0,0 +1,38 @@
+#pragma once
+#include "ggml-backend-impl.h"
+#include "ggml-cpu-impl.h"
+#include "ggml.h"
+
+#ifdef __cplusplus
+#    include <vector>
+extern "C" {
+#endif
+
+// return true if op part of extra "accelerator"
+bool wsp_ggml_cpu_extra_compute_forward(struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * op);
+bool wsp_ggml_cpu_extra_work_size(int n_threads, const struct wsp_ggml_tensor * op, size_t * size);
+
+#ifdef __cplusplus
+}
+
+namespace ggml::cpu {
+// register in tensor->extra
+class tensor_traits {
+  public:
+    virtual ~tensor_traits();
+    virtual bool work_size(int n_threads, const struct wsp_ggml_tensor * op, size_t & size) = 0;
+    virtual bool compute_forward(struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * op) = 0;
+};
+
+class extra_buffer_type {
+  public:
+    virtual ~extra_buffer_type();
+    virtual bool supports_op(wsp_ggml_backend_dev_t dev, const struct wsp_ggml_tensor * op) = 0;
+    virtual tensor_traits * get_tensor_traits(const struct wsp_ggml_tensor * op) = 0;
+};
+} // namespace ggml::cpu
+
+// implemented in ggml-cpu.cpp.
+std::vector<wsp_ggml_backend_buffer_type_t> & wsp_ggml_backend_cpu_get_extra_buffers_type();
+
+#endif
```
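The header separates a C-callable surface (the two dispatch functions) from the C++ extension points (the two abstract classes). A backend plugs in by subclassing both. A skeletal subclass against this header might look like the sketch below; it is hypothetical, only compiles inside the package tree, and the op check is illustrative (the real in-tree example is the AMX buffer type).

```cpp
// Hypothetical sketch of an extra buffer type built on the traits.h interfaces.
#include "traits.h"

namespace ggml::cpu::example {

class my_tensor_traits : public tensor_traits {
public:
    bool work_size(int /*n_threads*/, const struct wsp_ggml_tensor * /*op*/, size_t & size) override {
        size = 0;     // this sketch needs no scratch memory
        return true;  // true = this traits object owns the op
    }
    bool compute_forward(struct wsp_ggml_compute_params * /*params*/,
                         struct wsp_ggml_tensor * /*op*/) override {
        // ... run the specialized kernel here ...
        return true;
    }
};

class my_extra_buffer_type : public extra_buffer_type {
    my_tensor_traits traits;
public:
    bool supports_op(wsp_ggml_backend_dev_t /*dev*/, const struct wsp_ggml_tensor * op) override {
        return op->op == WSP_GGML_OP_MUL_MAT;  // only claim matmuls (illustrative)
    }
    tensor_traits * get_tensor_traits(const struct wsp_ggml_tensor * op) override {
        return supports_op(nullptr, op) ? &traits : nullptr;
    }
};

} // namespace ggml::cpu::example
```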
package/cpp/ggml-cpu/unary-ops.cpp

```diff
@@ -0,0 +1,186 @@
+#include "unary-ops.h"
+
+static inline float op_abs(float x) {
+    return fabsf(x);
+}
+
+static inline float op_sgn(float x) {
+    return (x > 0.f) ? 1.f : ((x < 0.f) ? -1.f : 0.f);
+}
+
+static inline float op_neg(float x) {
+    return -x;
+}
+
+static inline float op_step(float x) {
+    return (x > 0.f) ? 1.f : 0.f;
+}
+
+static inline float op_tanh(float x) {
+    return tanhf(x);
+}
+
+static inline float op_elu(float x) {
+    return (x > 0.f) ? x : expm1f(x);
+}
+
+static inline float op_relu(float x) {
+    return (x > 0.f) ? x : 0.f;
+}
+
+static inline float op_sigmoid(float x) {
+    return 1.f / (1.f + expf(-x));
+}
+
+static inline float op_hardsigmoid(float x) {
+    return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
+}
+
+static inline float op_exp(float x) {
+    return expf(x);
+}
+
+static inline float op_hardswish(float x) {
+    return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
+}
+
+static inline float op_sqr(float x) {
+    return x * x;
+}
+
+static inline float op_sqrt(float x) {
+    return sqrtf(x);
+}
+
+static inline float op_sin(float x) {
+    return sinf(x);
+}
+
+static inline float op_cos(float x) {
+    return cosf(x);
+}
+
+static inline float op_log(float x) {
+    return logf(x);
+}
+
+template <float (*op)(float), typename src0_t, typename dst_t>
+static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
+    constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
+    constexpr auto f32_to_dst  = type_conversion_table<dst_t >::from_f32;
+
+    for (int i = 0; i < n; i++) {
+        y[i] = f32_to_dst(op(src0_to_f32(x[i])));
+    }
+}
+
+template <float (*op)(float), typename src0_t, typename dst_t>
+static void apply_unary_op(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    const wsp_ggml_tensor * src0 = dst->src[0];
+
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0) && wsp_ggml_is_contiguous_1(dst) && wsp_ggml_are_same_shape(src0, dst));
+
+    WSP_GGML_TENSOR_UNARY_OP_LOCALS
+
+    WSP_GGML_ASSERT( nb0 == sizeof(dst_t));
+    WSP_GGML_ASSERT(nb00 == sizeof(src0_t));
+
+    const auto [ir0, ir1] = get_thread_range(params, src0);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+        const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+        vec_unary_op<op>(ne0, dst_ptr, src0_ptr);
+    }
+}
+
+// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
+template <float (*op)(float)>
+static void unary_op(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    const wsp_ggml_tensor * src0 = dst->src[0];
+
+    /* */ if (src0->type == WSP_GGML_TYPE_F32 && dst->type == WSP_GGML_TYPE_F32) { // all f32
+        apply_unary_op<op, float, float>(params, dst);
+    } else if (src0->type == WSP_GGML_TYPE_F16 && dst->type == WSP_GGML_TYPE_F16) { // all f16
+        apply_unary_op<op, wsp_ggml_fp16_t, wsp_ggml_fp16_t>(params, dst);
+    } else if (src0->type == WSP_GGML_TYPE_BF16 && dst->type == WSP_GGML_TYPE_BF16) { // all bf16
+        apply_unary_op<op, wsp_ggml_bf16_t, wsp_ggml_bf16_t>(params, dst);
+    } else if (src0->type == WSP_GGML_TYPE_BF16 && dst->type == WSP_GGML_TYPE_F32) {
+        apply_unary_op<op, wsp_ggml_bf16_t, float>(params, dst);
+    } else if (src0->type == WSP_GGML_TYPE_F16 && dst->type == WSP_GGML_TYPE_F32) {
+        apply_unary_op<op, wsp_ggml_fp16_t, float>(params, dst);
+    } else {
+        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
+            wsp_ggml_type_name(dst->type), wsp_ggml_type_name(src0->type));
+        WSP_GGML_ABORT("fatal error");
+    }
+}
+
+void wsp_ggml_compute_forward_abs(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    unary_op<op_abs>(params, dst);
+}
+
+void wsp_ggml_compute_forward_sgn(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    unary_op<op_sgn>(params, dst);
+}
+
+void wsp_ggml_compute_forward_neg(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    unary_op<op_neg>(params, dst);
+}
+
+void wsp_ggml_compute_forward_step(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    unary_op<op_step>(params, dst);
+}
+
+void wsp_ggml_compute_forward_tanh(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    unary_op<op_tanh>(params, dst);
+}
+
+void wsp_ggml_compute_forward_elu(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    unary_op<op_elu>(params, dst);
+}
+
+void wsp_ggml_compute_forward_relu(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    unary_op<op_relu>(params, dst);
+}
+
+void wsp_ggml_compute_forward_sigmoid(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    unary_op<op_sigmoid>(params, dst);
+}
+
+void wsp_ggml_compute_forward_hardsigmoid(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    unary_op<op_hardsigmoid>(params, dst);
+}
+
+void wsp_ggml_compute_forward_exp(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    unary_op<op_exp>(params, dst);
+}
+
+void wsp_ggml_compute_forward_hardswish(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    unary_op<op_hardswish>(params, dst);
+}
+
+void wsp_ggml_compute_forward_sqr(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    unary_op<op_sqr>(params, dst);
+}
+
+void wsp_ggml_compute_forward_sqrt(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    unary_op<op_sqrt>(params, dst);
+}
+
+void wsp_ggml_compute_forward_sin(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    unary_op<op_sin>(params, dst);
+}
+
+void wsp_ggml_compute_forward_cos(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    unary_op<op_cos>(params, dst);
+}
+
+void wsp_ggml_compute_forward_log(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    unary_op<op_log>(params, dst);
+}
```
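The file's structure is one scalar op per function, a single templated loop that converts src → f32, applies the op, and converts back to dst, and a type-dispatch wrapper over the supported tensor-type pairs. The self-contained sketch below models the same pattern with a toy conversion table standing in for ggml's real f16/bf16 conversions; all names are hypothetical, not package code.

```cpp
// Minimal model of the unary-op template pattern in unary-ops.cpp.
#include <cmath>
#include <cstdio>

// Stand-in for type_conversion_table<T>: how a storage type maps to/from f32.
template <typename T> struct conv;
template <> struct conv<float> {
    static float to_f32(float x)   { return x; }
    static float from_f32(float x) { return x; }
};
// A toy low-precision type, just to show a second instantiation.
struct toy_half { float bits; };
template <> struct conv<toy_half> {
    static float    to_f32(toy_half x)  { return x.bits; }
    static toy_half from_f32(float x)   { return { std::round(x * 256.f) / 256.f }; }
};

static inline float op_relu(float x) { return x > 0.f ? x : 0.f; }

// Mirrors vec_unary_op: convert in, apply the scalar op, convert out.
template <float (*op)(float), typename src_t, typename dst_t>
void vec_unary_op(int n, dst_t * y, const src_t * x) {
    for (int i = 0; i < n; ++i)
        y[i] = conv<dst_t>::from_f32(op(conv<src_t>::to_f32(x[i])));
}

int main() {
    float in[4] = { -2.f, -0.5f, 0.5f, 2.f };
    toy_half out[4];
    vec_unary_op<op_relu>(4, out, in);              // one template, any type pair
    for (auto h : out) std::printf("%g ", h.bits);  // 0 0 0.5 2
}
```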
package/cpp/ggml-cpu/unary-ops.h

```diff
@@ -0,0 +1,28 @@
+#pragma once
+
+#include "common.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void wsp_ggml_compute_forward_abs(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+void wsp_ggml_compute_forward_sgn(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+void wsp_ggml_compute_forward_neg(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+void wsp_ggml_compute_forward_step(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+void wsp_ggml_compute_forward_tanh(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+void wsp_ggml_compute_forward_elu(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+void wsp_ggml_compute_forward_relu(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+void wsp_ggml_compute_forward_sigmoid(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+void wsp_ggml_compute_forward_hardsigmoid(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+void wsp_ggml_compute_forward_exp(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+void wsp_ggml_compute_forward_hardswish(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+void wsp_ggml_compute_forward_sqr(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+void wsp_ggml_compute_forward_sqrt(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+void wsp_ggml_compute_forward_sin(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+void wsp_ggml_compute_forward_cos(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+void wsp_ggml_compute_forward_log(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+
+#ifdef __cplusplus
+}
+#endif
```
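Like traits.h, this header is consumable from both C and C++ translation units via the `extern "C"` guards, so the C kernels in ggml-cpu.c and the C++ code in ops.cpp can share the same declarations. A minimal standalone illustration of that idiom (hypothetical function name, not package code):

```cpp
// Sketch of the dual C/C++ header idiom used by unary-ops.h and traits.h.
// Compiled as C, only the declarations are seen; compiled as C++, the same
// functions get C linkage so both sides link against identical symbols.
#include <cstdio>

#ifdef __cplusplus
extern "C" {
#endif

void forward_neg(int n, float * dst, const float * src);

#ifdef __cplusplus
}  // C++-only material (classes, templates, <vector>) would follow here
#endif

// Definition with C linkage, matching the declaration above.
extern "C" void forward_neg(int n, float * dst, const float * src) {
    for (int i = 0; i < n; ++i) dst[i] = -src[i];
}

int main() {
    float in[3] = { 1.f, -2.f, 3.f }, out[3];
    forward_neg(3, out, in);
    std::printf("%g %g %g\n", out[0], out[1], out[2]);  // -1 2 -3
}
```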
package/cpp/ggml-cpu/vec.cpp

```diff
@@ -0,0 +1,321 @@
+#include "vec.h"
+
+#include <cassert>
+
+// precomputed gelu table for f16 (128 KB)
+wsp_ggml_fp16_t wsp_ggml_table_gelu_f16[1 << 16];
+
+// precomputed quick gelu table for f16 (128 KB)
+wsp_ggml_fp16_t wsp_ggml_table_gelu_quick_f16[1 << 16];
+
+void wsp_ggml_vec_dot_f32(int n, float * WSP_GGML_RESTRICT s, size_t bs, const float * WSP_GGML_RESTRICT x, size_t bx, const float * WSP_GGML_RESTRICT y, size_t by, int nrc) {
+    assert(nrc == 1);
+    WSP_GGML_UNUSED(nrc);
+    WSP_GGML_UNUSED(bx);
+    WSP_GGML_UNUSED(by);
+    WSP_GGML_UNUSED(bs);
+
+#if defined(WSP_GGML_SIMD)
+    float sumf = 0.0f;
+
+    #if defined(__ARM_FEATURE_SVE)
+        const int sve_register_length = wsp_ggml_cpu_get_sve_cnt() * 8;
+        const int wsp_ggml_f32_epr = sve_register_length / 32;//8;//svcntw(); // SVE128:4, SVE256:8, SVE512:16
+        const int wsp_ggml_f32_step = 8 * wsp_ggml_f32_epr; // choose 8 SVE registers
+
+        const int np = (n & ~(wsp_ggml_f32_step - 1));
+        svfloat32_t sum1 = svdup_n_f32(0.0f);
+        svfloat32_t sum2 = svdup_n_f32(0.0f);
+        svfloat32_t sum3 = svdup_n_f32(0.0f);
+        svfloat32_t sum4 = svdup_n_f32(0.0f);
+        svfloat32_t sum5 = svdup_n_f32(0.0f);
+        svfloat32_t sum6 = svdup_n_f32(0.0f);
+        svfloat32_t sum7 = svdup_n_f32(0.0f);
+        svfloat32_t sum8 = svdup_n_f32(0.0f);
+        svfloat32_t ax1,ax2,ax3,ax4,ax5,ax6,ax7,ax8;
+        svfloat32_t ay1,ay2,ay3,ay4,ay5,ay6,ay7,ay8;
+        for (int i = 0; i < np; i += wsp_ggml_f32_step) {
+            ax1 = WSP_GGML_F32_VEC_LOAD(x + i);
+            ay1 = WSP_GGML_F32_VEC_LOAD(y + i);
+            sum1 = WSP_GGML_F32_VEC_FMA(ax1, ay1, sum1);
+
+            ax2 = WSP_GGML_F32_VEC_LOAD(x + i + 1*wsp_ggml_f32_epr);
+            ay2 = WSP_GGML_F32_VEC_LOAD(y + i + 1*wsp_ggml_f32_epr);
+            sum2 = WSP_GGML_F32_VEC_FMA(ax2, ay2, sum2);
+
+            ax3 = WSP_GGML_F32_VEC_LOAD(x + i + 2*wsp_ggml_f32_epr);
+            ay3 = WSP_GGML_F32_VEC_LOAD(y + i + 2*wsp_ggml_f32_epr);
+            sum3 = WSP_GGML_F32_VEC_FMA(ax3, ay3, sum3);
+
+            ax4 = WSP_GGML_F32_VEC_LOAD(x + i + 3*wsp_ggml_f32_epr);
+            ay4 = WSP_GGML_F32_VEC_LOAD(y + i + 3*wsp_ggml_f32_epr);
+            sum4 = WSP_GGML_F32_VEC_FMA(ax4, ay4, sum4);
+
+            ax5 = WSP_GGML_F32_VEC_LOAD(x + i + 4*wsp_ggml_f32_epr);
+            ay5 = WSP_GGML_F32_VEC_LOAD(y + i + 4*wsp_ggml_f32_epr);
+            sum5 = WSP_GGML_F32_VEC_FMA(ax5, ay5, sum5);
+
+            ax6 = WSP_GGML_F32_VEC_LOAD(x + i + 5*wsp_ggml_f32_epr);
+            ay6 = WSP_GGML_F32_VEC_LOAD(y + i + 5*wsp_ggml_f32_epr);
+            sum6 = WSP_GGML_F32_VEC_FMA(ax6, ay6, sum6);
+
+            ax7 = WSP_GGML_F32_VEC_LOAD(x + i + 6*wsp_ggml_f32_epr);
+            ay7 = WSP_GGML_F32_VEC_LOAD(y + i + 6*wsp_ggml_f32_epr);
+            sum7 = WSP_GGML_F32_VEC_FMA(ax7, ay7, sum7);
+
+            ax8 = WSP_GGML_F32_VEC_LOAD(x + i + 7*wsp_ggml_f32_epr);
+            ay8 = WSP_GGML_F32_VEC_LOAD(y + i + 7*wsp_ggml_f32_epr);
+            sum8 = WSP_GGML_F32_VEC_FMA(ax8, ay8, sum8);
+        }
+        // leftovers
+        // Since 8 unrolls are done in above loop, leftovers lie in range [0, wsp_ggml_f32_step] which is handled in below loop
+        const int np2 = (n & ~(wsp_ggml_f32_epr - 1));
+        for (int i = np; i < np2; i += wsp_ggml_f32_epr) {
+            ax1 = WSP_GGML_F32_VEC_LOAD(x + i);
+            ay1 = WSP_GGML_F32_VEC_LOAD(y + i);
+            sum1 = WSP_GGML_F32_VEC_FMA(ax1, ay1, sum1);
+        }
+        // maximum number of leftover elements will be less that wsp_ggml_f32_epr. Apply predicated svmad on available elements only
+        if (np2 < n) {
+            svbool_t pg = svwhilelt_b32(np2, n);
+            ax1 = svld1_f32(pg, x + np2);
+            ay1 = svld1_f32(pg, y + np2);
+            sum1 = svmad_f32_m(pg, ax1, ay1, sum1);
+        }
+        // reduce sum1,sum2 to sum1
+        WSP_GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);
+    #else
+        const int np = (n & ~(WSP_GGML_F32_STEP - 1));
+
+        WSP_GGML_F32_VEC sum[WSP_GGML_F32_ARR] = { WSP_GGML_F32_VEC_ZERO };
+
+        WSP_GGML_F32_VEC ax[WSP_GGML_F32_ARR];
+        WSP_GGML_F32_VEC ay[WSP_GGML_F32_ARR];
+
+        for (int i = 0; i < np; i += WSP_GGML_F32_STEP) {
+            for (int j = 0; j < WSP_GGML_F32_ARR; j++) {
+                ax[j] = WSP_GGML_F32_VEC_LOAD(x + i + j*WSP_GGML_F32_EPR);
+                ay[j] = WSP_GGML_F32_VEC_LOAD(y + i + j*WSP_GGML_F32_EPR);
+
+                sum[j] = WSP_GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
+            }
+        }
+
+        // reduce sum0..sum3 to sum0
+        WSP_GGML_F32_VEC_REDUCE(sumf, sum);
+
+        // leftovers
+        for (int i = np; i < n; ++i) {
+            sumf += x[i]*y[i];
+        }
+    #endif
+#else
+    // scalar
+    wsp_ggml_float sumf = 0.0;
+    for (int i = 0; i < n; ++i) {
+        sumf += (wsp_ggml_float)(x[i]*y[i]);
+    }
+#endif
+
+    *s = sumf;
+}
+
+void wsp_ggml_vec_dot_bf16(int n, float * WSP_GGML_RESTRICT s, size_t bs, wsp_ggml_bf16_t * WSP_GGML_RESTRICT x, size_t bx, wsp_ggml_bf16_t * WSP_GGML_RESTRICT y, size_t by, int nrc) {
+    assert(nrc == 1);
+    WSP_GGML_UNUSED(nrc);
+    WSP_GGML_UNUSED(bx);
+    WSP_GGML_UNUSED(by);
+    WSP_GGML_UNUSED(bs);
+    int i = 0;
+    wsp_ggml_float sumf = 0;
+
+#if defined(__AVX512BF16__)
+    __m512 c1 = _mm512_setzero_ps();
+    __m512 c2 = _mm512_setzero_ps();
+    for (; i + 64 <= n; i += 64) {
+        c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
+                                  m512bh(_mm512_loadu_si512((y + i))));
+        c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
+                                  m512bh(_mm512_loadu_si512((y + i + 32))));
+    }
+    sumf += (wsp_ggml_float)_mm512_reduce_add_ps(c1);
+    sumf += (wsp_ggml_float)_mm512_reduce_add_ps(c2);
+
+#elif defined(__AVX512F__)
+#define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
+    __m512 c1 = _mm512_setzero_ps();
+    __m512 c2 = _mm512_setzero_ps();
+    for (; i + 32 <= n; i += 32) {
+        c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
+        c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
+    }
+    sumf += (wsp_ggml_float)_mm512_reduce_add_ps(c1);
+    sumf += (wsp_ggml_float)_mm512_reduce_add_ps(c2);
+
+#undef LOAD
+#elif defined(__AVX2__) || defined(__AVX__)
+#if defined(__AVX2__)
+#define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
+#else
+#define LOAD(p) _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)), (_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_bsrli_si128(_mm_loadu_si128((const __m128i *)(p)), 8)), 16)), 1))
+#endif
+    __m256 c1 = _mm256_setzero_ps();
+    __m256 c2 = _mm256_setzero_ps();
+    __m256 c3 = _mm256_setzero_ps();
+    __m256 c4 = _mm256_setzero_ps();
+    for (; i + 32 <= n; i += 32) {
+        c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
+        c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
+        c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
+        c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
+    }
+    __m128 g;
+    c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
+                       _mm256_add_ps(c2, c4));
+    g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
+                   _mm256_castps256_ps128(c1));
+    g = _mm_add_ps(g, _mm_movehl_ps(g, g));
+    g = _mm_add_ss(g, _mm_movehdup_ps(g));
+    sumf += (wsp_ggml_float)_mm_cvtss_f32(g);
+
+#undef LOAD
+#endif
+
+    for (; i < n; ++i) {
+        sumf += (wsp_ggml_float)(WSP_GGML_BF16_TO_FP32(x[i]) *
+                                 WSP_GGML_BF16_TO_FP32(y[i]));
+    }
+    *s = sumf;
+}
+
+void wsp_ggml_vec_dot_f16(int n, float * WSP_GGML_RESTRICT s, size_t bs, wsp_ggml_fp16_t * WSP_GGML_RESTRICT x, size_t bx, wsp_ggml_fp16_t * WSP_GGML_RESTRICT y, size_t by, int nrc) {
+    assert(nrc == 1);
+    WSP_GGML_UNUSED(nrc);
+    WSP_GGML_UNUSED(bx);
+    WSP_GGML_UNUSED(by);
+    WSP_GGML_UNUSED(bs);
+
+    wsp_ggml_float sumf = 0.0;
+
+#if defined(WSP_GGML_SIMD)
+    const int np = (n & ~(WSP_GGML_F16_STEP - 1));
+
+    WSP_GGML_F16_VEC sum[WSP_GGML_F16_ARR] = { WSP_GGML_F16_VEC_ZERO };
+
+    WSP_GGML_F16_VEC ax[WSP_GGML_F16_ARR];
+    WSP_GGML_F16_VEC ay[WSP_GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += WSP_GGML_F16_STEP) {
+        for (int j = 0; j < WSP_GGML_F16_ARR; j++) {
+            ax[j] = WSP_GGML_F16_VEC_LOAD(x + i + j*WSP_GGML_F16_EPR, j);
+            ay[j] = WSP_GGML_F16_VEC_LOAD(y + i + j*WSP_GGML_F16_EPR, j);
+
+            sum[j] = WSP_GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
+        }
+    }
+
+    // reduce sum0..sum3 to sum0
+    WSP_GGML_F16_VEC_REDUCE(sumf, sum);
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        sumf += (wsp_ggml_float)(WSP_GGML_FP16_TO_FP32(x[i])*WSP_GGML_FP16_TO_FP32(y[i]));
+    }
+#else
+    for (int i = 0; i < n; ++i) {
+        sumf += (wsp_ggml_float)(WSP_GGML_FP16_TO_FP32(x[i])*WSP_GGML_FP16_TO_FP32(y[i]));
+    }
+#endif
+
+    *s = sumf;
+}
+
+void wsp_ggml_vec_silu_f32(const int n, float * y, const float * x) {
+    int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    for (; i + 15 < n; i += 16) {
+        _mm512_storeu_ps(y + i, wsp_ggml_v_silu(_mm512_loadu_ps(x + i)));
+    }
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        _mm256_storeu_ps(y + i, wsp_ggml_v_silu(_mm256_loadu_ps(x + i)));
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        _mm_storeu_ps(y + i, wsp_ggml_v_silu(_mm_loadu_ps(x + i)));
+    }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    for (; i + 3 < n; i += 4) {
+        vst1q_f32(y + i, wsp_ggml_v_silu(vld1q_f32(x + i)));
+    }
+#endif
+    for (; i < n; ++i) {
+        y[i] = wsp_ggml_silu_f32(x[i]);
+    }
+}
+
+wsp_ggml_float wsp_ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
+    int i = 0;
+    wsp_ggml_float sum = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    for (; i + 15 < n; i += 16) {
+        __m512 val = wsp_ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
+                                                   _mm512_set1_ps(max)));
+        _mm512_storeu_ps(y + i, val);
+        sum += (wsp_ggml_float)_mm512_reduce_add_ps(val);
+    }
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        __m256 val = wsp_ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
+                                                   _mm256_set1_ps(max)));
+        _mm256_storeu_ps(y + i, val);
+        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
+                                 _mm256_castps256_ps128(val));
+        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
+        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
+        sum += (wsp_ggml_float)_mm_cvtss_f32(val2);
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        __m128 val = wsp_ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
+                                                _mm_set1_ps(max)));
+        _mm_storeu_ps(y + i, val);
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
+        val = _mm_add_ss(val, _mm_movehdup_ps(val));
+#else
+        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
+        val = _mm_add_ps(val, tmp);
+        tmp = _mm_movehl_ps(tmp, val);
+        val = _mm_add_ss(val, tmp);
+#endif
+        sum += (wsp_ggml_float)_mm_cvtss_f32(val);
+    }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    for (; i + 3 < n; i += 4) {
+        float32x4_t val = wsp_ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
+                                                    vdupq_n_f32(max)));
+        vst1q_f32(y + i, val);
+        sum += (wsp_ggml_float)vaddvq_f32(val);
+    }
+#endif
+    for (; i < n; ++i) {
+        float val = expf(x[i] - max);
+        sum += (wsp_ggml_float)val;
+        y[i] = val;
+    }
+    return sum;
+}
+
+wsp_ggml_float wsp_ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {
+    // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_i)
+
+    int i = 0;
+    wsp_ggml_float sum = 0;
+    for (; i < n; ++i) {
+        float val = x[i] - max;
+        y[i] = val;
+        sum += (wsp_ggml_float)expf(val);
+    }
+    return sum = (wsp_ggml_float)logf(sum);
+}
```
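All of the SIMD branches above compute the same scalar reductions; what varies is the vector width and the unroll factor, with `np = n & ~(step - 1)` rounding n down to a multiple of the block and a scalar loop picking up the tail. The standalone sketch below checks that blocking arithmetic against a plain reference; names are hypothetical and a `double` accumulator stands in for `wsp_ggml_float`.

```cpp
// Scalar reference for the dot-product kernels in vec.cpp plus the blocking
// arithmetic used by the SIMD paths (a sketch, not package code).
#include <cassert>
#include <cstdio>

// What every wsp_ggml_vec_dot_* variant computes, up to rounding.
static double dot_ref(int n, const float * x, const float * y) {
    double sum = 0.0;
    for (int i = 0; i < n; ++i) sum += (double) x[i] * y[i];
    return sum;
}

// Same loop, blocked the way the SIMD code is: a "step" of elements per
// iteration, then a scalar tail for the remainder.
static double dot_blocked(int n, const float * x, const float * y, int step) {
    assert((step & (step - 1)) == 0);   // step must be a power of two
    const int np = n & ~(step - 1);     // n rounded down to a multiple of step
    double sum = 0.0;
    for (int i = 0; i < np; i += step)
        for (int j = 0; j < step; ++j)  // stands in for one vector FMA
            sum += (double) x[i + j] * y[i + j];
    for (int i = np; i < n; ++i)        // leftovers, as in vec.cpp
        sum += (double) x[i] * y[i];
    return sum;
}

int main() {
    float x[19], y[19];
    for (int i = 0; i < 19; ++i) { x[i] = 0.5f * i; y[i] = 1.0f / (i + 1); }
    // n = 19, step = 8: np = 16, tail of 3 elements.
    std::printf("%f %f\n", dot_ref(19, x, y), dot_blocked(19, x, y, 8));
}
```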