whisper.rn 0.4.0-rc.9 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +74 -1
- package/android/build.gradle +12 -3
- package/android/src/main/CMakeLists.txt +43 -13
- package/android/src/main/java/com/rnwhisper/RNWhisper.java +211 -0
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +64 -36
- package/android/src/main/java/com/rnwhisper/WhisperVadContext.java +157 -0
- package/android/src/main/jni.cpp +205 -0
- package/android/src/main/jniLibs/arm64-v8a/librnwhisper.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnwhisper_v8fp16_va_2.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/librnwhisper.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/librnwhisper_vfpv4.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnwhisper.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnwhisper_x86_64.so +0 -0
- package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +26 -0
- package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +26 -0
- package/cpp/coreml/whisper-compat.h +10 -0
- package/cpp/coreml/whisper-compat.m +35 -0
- package/cpp/coreml/whisper-decoder-impl.h +27 -15
- package/cpp/coreml/whisper-decoder-impl.m +36 -10
- package/cpp/coreml/whisper-encoder-impl.h +21 -9
- package/cpp/coreml/whisper-encoder-impl.m +29 -3
- package/cpp/ggml-alloc.c +39 -37
- package/cpp/ggml-alloc.h +1 -1
- package/cpp/ggml-backend-impl.h +55 -27
- package/cpp/ggml-backend-reg.cpp +591 -0
- package/cpp/ggml-backend.cpp +336 -955
- package/cpp/ggml-backend.h +70 -42
- package/cpp/ggml-common.h +57 -49
- package/cpp/ggml-cpp.h +39 -0
- package/cpp/ggml-cpu/amx/amx.cpp +221 -0
- package/cpp/ggml-cpu/amx/amx.h +8 -0
- package/cpp/ggml-cpu/amx/common.h +91 -0
- package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
- package/cpp/ggml-cpu/amx/mmq.h +10 -0
- package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/ggml-cpu/arch/arm/quants.c +4113 -0
- package/cpp/ggml-cpu/arch/arm/repack.cpp +2162 -0
- package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
- package/cpp/ggml-cpu/arch/x86/quants.c +4310 -0
- package/cpp/ggml-cpu/arch/x86/repack.cpp +3284 -0
- package/cpp/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/ggml-cpu/binary-ops.cpp +158 -0
- package/cpp/ggml-cpu/binary-ops.h +16 -0
- package/cpp/ggml-cpu/common.h +72 -0
- package/cpp/ggml-cpu/ggml-cpu-impl.h +511 -0
- package/cpp/ggml-cpu/ggml-cpu.c +3473 -0
- package/cpp/ggml-cpu/ggml-cpu.cpp +671 -0
- package/cpp/ggml-cpu/ops.cpp +9085 -0
- package/cpp/ggml-cpu/ops.h +111 -0
- package/cpp/ggml-cpu/quants.c +1157 -0
- package/cpp/ggml-cpu/quants.h +89 -0
- package/cpp/ggml-cpu/repack.cpp +1570 -0
- package/cpp/ggml-cpu/repack.h +98 -0
- package/cpp/ggml-cpu/simd-mappings.h +1006 -0
- package/cpp/ggml-cpu/traits.cpp +36 -0
- package/cpp/ggml-cpu/traits.h +38 -0
- package/cpp/ggml-cpu/unary-ops.cpp +186 -0
- package/cpp/ggml-cpu/unary-ops.h +28 -0
- package/cpp/ggml-cpu/vec.cpp +321 -0
- package/cpp/ggml-cpu/vec.h +973 -0
- package/cpp/ggml-cpu.h +143 -0
- package/cpp/ggml-impl.h +417 -23
- package/cpp/ggml-metal-impl.h +622 -0
- package/cpp/ggml-metal.h +9 -9
- package/cpp/ggml-metal.m +3451 -1344
- package/cpp/ggml-opt.cpp +1037 -0
- package/cpp/ggml-opt.h +237 -0
- package/cpp/ggml-quants.c +296 -10818
- package/cpp/ggml-quants.h +78 -125
- package/cpp/ggml-threading.cpp +12 -0
- package/cpp/ggml-threading.h +14 -0
- package/cpp/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-whisper.metallib +0 -0
- package/cpp/ggml.c +4633 -21450
- package/cpp/ggml.h +320 -661
- package/cpp/gguf.cpp +1347 -0
- package/cpp/gguf.h +202 -0
- package/cpp/rn-whisper.cpp +4 -11
- package/cpp/whisper-arch.h +197 -0
- package/cpp/whisper.cpp +2022 -495
- package/cpp/whisper.h +75 -18
- package/ios/CMakeLists.txt +95 -0
- package/ios/RNWhisper.h +5 -0
- package/ios/RNWhisper.mm +147 -0
- package/ios/RNWhisperAudioUtils.m +4 -0
- package/ios/RNWhisperContext.h +5 -0
- package/ios/RNWhisperContext.mm +22 -26
- package/ios/RNWhisperVadContext.h +29 -0
- package/ios/RNWhisperVadContext.mm +152 -0
- package/ios/rnwhisper.xcframework/Info.plist +74 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/jest/mock.js +24 -0
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/index.js +111 -1
- package/lib/commonjs/index.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/index.js +112 -0
- package/lib/module/index.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +35 -0
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +39 -3
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +10 -6
- package/src/NativeRNWhisper.ts +48 -0
- package/src/index.ts +132 -1
- package/src/version.json +1 -1
- package/whisper-rn.podspec +11 -18
- package/cpp/README.md +0 -4
- package/cpp/ggml-aarch64.c +0 -3209
- package/cpp/ggml-aarch64.h +0 -39
- package/cpp/ggml-cpu-impl.h +0 -614
|
@@ -0,0 +1,3473 @@
|
|
|
1
|
+
#define _CRT_SECURE_NO_DEPRECATE // Disables "unsafe" warnings on Windows
|
|
2
|
+
#define _USE_MATH_DEFINES // For M_PI on MSVC
|
|
3
|
+
|
|
4
|
+
#include "ggml-backend-impl.h"
|
|
5
|
+
#include "ggml-backend.h"
|
|
6
|
+
#include "traits.h"
|
|
7
|
+
#include "ggml-cpu-impl.h"
|
|
8
|
+
#include "ggml-cpu.h"
|
|
9
|
+
#include "ggml-impl.h"
|
|
10
|
+
#include "quants.h"
|
|
11
|
+
#include "ggml-threading.h"
|
|
12
|
+
#include "unary-ops.h"
|
|
13
|
+
#include "binary-ops.h"
|
|
14
|
+
#include "vec.h"
|
|
15
|
+
#include "ops.h"
|
|
16
|
+
#include "ggml.h"
|
|
17
|
+
|
|
18
|
+
#if defined(_MSC_VER) || defined(__MINGW32__)
|
|
19
|
+
#include <malloc.h> // using malloc.h with MSC/MINGW
|
|
20
|
+
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
|
|
21
|
+
#include <alloca.h>
|
|
22
|
+
#endif
|
|
23
|
+
|
|
24
|
+
#include <assert.h>
|
|
25
|
+
#include <errno.h>
|
|
26
|
+
#include <time.h>
|
|
27
|
+
#include <math.h>
|
|
28
|
+
#include <stdlib.h>
|
|
29
|
+
#include <string.h>
|
|
30
|
+
#include <stdint.h>
|
|
31
|
+
#include <inttypes.h>
|
|
32
|
+
#include <stdio.h>
|
|
33
|
+
#include <float.h>
|
|
34
|
+
#include <limits.h>
|
|
35
|
+
#include <stdarg.h>
|
|
36
|
+
#include <signal.h>
|
|
37
|
+
#if defined(__gnu_linux__)
|
|
38
|
+
#include <syscall.h>
|
|
39
|
+
#endif
|
|
40
|
+
|
|
41
|
+
#ifdef WSP_GGML_USE_OPENMP
|
|
42
|
+
#include <omp.h>
|
|
43
|
+
#endif
|
|
44
|
+
|
|
45
|
+
#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8)
|
|
46
|
+
#undef WSP_GGML_USE_LLAMAFILE
|
|
47
|
+
#endif
|
|
48
|
+
|
|
49
|
+
#ifdef WSP_GGML_USE_LLAMAFILE
|
|
50
|
+
#include "llamafile/sgemm.h"
|
|
51
|
+
#endif
|
|
52
|
+
|
|
53
|
+
// Note: once we move threading into a separate C++ file
|
|
54
|
+
// we will use std::hardware_destructive_interference_size instead of hardcoding it here
|
|
55
|
+
// and we'll use C++ attribute syntax.
|
|
56
|
+
#define WSP_GGML_CACHE_LINE 64
|
|
57
|
+
|
|
58
|
+
#if defined(__clang__) || defined(__GNUC__)
|
|
59
|
+
#define WSP_GGML_CACHE_ALIGN __attribute__((aligned(WSP_GGML_CACHE_LINE)))
|
|
60
|
+
#endif
|
|
61
|
+
|
|
62
|
+
#if defined(__has_feature)
|
|
63
|
+
#if __has_feature(thread_sanitizer)
|
|
64
|
+
#define WSP_GGML_TSAN_ENABLED 1
|
|
65
|
+
#endif
|
|
66
|
+
#else // __has_feature
|
|
67
|
+
#if defined(__SANITIZE_THREAD__)
|
|
68
|
+
#define WSP_GGML_TSAN_ENABLED 1
|
|
69
|
+
#endif
|
|
70
|
+
#endif // __has_feature
|
|
71
|
+
|
|
72
|
+
#define UNUSED WSP_GGML_UNUSED
|
|
73
|
+
// Exchanges the values of the lvalues x and y, both of type T, through a
// temporary. Each argument is expanded twice, so it must be free of side
// effects. Every use of an argument is parenthesized so that an expression
// argument cannot regroup with the surrounding operators.
#define SWAP(x, y, T) do { T SWAP = (x); (x) = (y); (y) = SWAP; } while (0)
|
|
74
|
+
|
|
75
|
+
#if defined(__ARM_ARCH)
|
|
76
|
+
// ARM feature record; this block is compiled only when __ARM_ARCH is defined.
struct wsp_ggml_arm_arch_features_type {
    int sve_cnt; // presumably an SVE-related count/length filled in at init time - TODO confirm
};

// Single instance of the record, starting out with sve_cnt == 0.
struct wsp_ggml_arm_arch_features_type wsp_ggml_arm_arch_features = { .sve_cnt = 0 };
|
|
79
|
+
#endif
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
#if defined(_WIN32)
|
|
83
|
+
|
|
84
|
+
#define WIN32_LEAN_AND_MEAN
|
|
85
|
+
#ifndef NOMINMAX
|
|
86
|
+
#define NOMINMAX
|
|
87
|
+
#endif
|
|
88
|
+
#include <windows.h>
|
|
89
|
+
|
|
90
|
+
#if defined(_MSC_VER) && !defined(__clang__)
|
|
91
|
+
#define WSP_GGML_CACHE_ALIGN __declspec(align(WSP_GGML_CACHE_LINE))
|
|
92
|
+
|
|
93
|
+
// Stand-ins for the <stdatomic.h> names used by this file, built on the Win32
// Interlocked* API. This branch is taken only for MSVC without clang; the
// other branch includes <stdatomic.h> instead.
typedef volatile LONG atomic_int;
typedef atomic_int    atomic_bool;
typedef atomic_int    atomic_flag;

#define ATOMIC_FLAG_INIT 0

// Ordering tags, declared so call sites can name an ordering. None of the
// shims below act on the value they are given.
typedef enum {
    memory_order_relaxed,
    memory_order_consume,
    memory_order_acquire,
    memory_order_release,
    memory_order_acq_rel,
    memory_order_seq_cst
} memory_order;

// Writes val to *ptr with an interlocked exchange; the old value is dropped.
static void atomic_store(atomic_int * ptr, LONG val) {
    InterlockedExchange(ptr, val);
}

// Same operation as atomic_store(); mo is accepted and ignored.
static void atomic_store_explicit(atomic_int * ptr, LONG val, memory_order mo) {
    // TODO: add support for explicit memory order
    (void) mo;
    atomic_store(ptr, val);
}

// Reads *ptr through a compare-exchange of 0 against 0: the stored value is
// left as it was and the value observed by the operation is returned.
static LONG atomic_load(atomic_int * ptr) {
    return InterlockedCompareExchange(ptr, 0, 0);
}

// Same operation as atomic_load(); mo is accepted and ignored.
static LONG atomic_load_explicit(atomic_int * ptr, memory_order mo) {
    // TODO: add support for explicit memory order
    (void) mo;
    return atomic_load(ptr);
}

// Adds inc to *ptr and returns the result of InterlockedExchangeAdd().
static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
    return InterlockedExchangeAdd(ptr, inc);
}

// Same operation as atomic_fetch_add(); mo is accepted and ignored.
static LONG atomic_fetch_add_explicit(atomic_int * ptr, LONG inc, memory_order mo) {
    // TODO: add support for explicit memory order
    (void) mo;
    return atomic_fetch_add(ptr, inc);
}

// Sets the flag to 1 and returns the result of the exchange.
static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
    return InterlockedExchange(ptr, 1);
}

// Resets the flag to 0.
static void atomic_flag_clear(atomic_flag * ptr) {
    InterlockedExchange(ptr, 0);
}

// Issues MemoryBarrier() regardless of the ordering requested in mo.
static void atomic_thread_fence(memory_order mo) {
    (void) mo;
    MemoryBarrier();
}
|
|
138
|
+
#else // clang
|
|
139
|
+
#include <stdatomic.h>
|
|
140
|
+
#endif
|
|
141
|
+
|
|
142
|
+
typedef HANDLE pthread_t;
|
|
143
|
+
|
|
144
|
+
typedef DWORD thread_ret_t;
|
|
145
|
+
// pthread_create() look-alike for Windows, implemented with CreateThread().
//
//   out    - receives the handle of the new thread on success
//   unused - placeholder for the attribute argument; never read
//   func   - thread entry point; cast to LPTHREAD_START_ROUTINE for CreateThread()
//   arg    - passed through to func
//
// Returns 0 on success, or EAGAIN when CreateThread() returns NULL (in which
// case *out is left untouched).
static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) {
    (void) unused;

    const HANDLE thread = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
    if (thread != NULL) {
        *out = thread;
        return 0;
    }

    return EAGAIN;
}
|
|
156
|
+
|
|
157
|
+
// pthread_join() look-alike for Windows.
//
// Waits without a timeout on the thread handle, then closes the handle
// whatever the outcome of the wait. The second argument is never read.
// Returns the WaitForSingleObject() result converted to int.
static int pthread_join(pthread_t thread, void * unused) {
    (void) unused;

    const DWORD wait_result = WaitForSingleObject(thread, INFINITE);
    CloseHandle(thread);

    return (int) wait_result;
}
|
|
163
|
+
|
|
164
|
+
// sched_yield() look-alike for Windows: calls Sleep(0) and always reports success.
static int sched_yield(void) {
    Sleep(0);
    return 0;
}
|
|
168
|
+
#else
|
|
169
|
+
|
|
170
|
+
#include <pthread.h>
|
|
171
|
+
#include <stdatomic.h>
|
|
172
|
+
#include <sched.h>
|
|
173
|
+
#if defined(__FreeBSD__)
|
|
174
|
+
#include <pthread_np.h>
|
|
175
|
+
#endif
|
|
176
|
+
|
|
177
|
+
typedef void * thread_ret_t;
|
|
178
|
+
|
|
179
|
+
#include <sys/types.h>
|
|
180
|
+
#include <sys/stat.h>
|
|
181
|
+
#include <unistd.h>
|
|
182
|
+
|
|
183
|
+
#endif
|
|
184
|
+
|
|
185
|
+
typedef pthread_t wsp_ggml_thread_t;
|
|
186
|
+
|
|
187
|
+
#if defined(__APPLE__)
|
|
188
|
+
#include <unistd.h>
|
|
189
|
+
#include <mach/mach.h>
|
|
190
|
+
#include <TargetConditionals.h>
|
|
191
|
+
#endif
|
|
192
|
+
|
|
193
|
+
static const struct wsp_ggml_type_traits_cpu type_traits_cpu[WSP_GGML_TYPE_COUNT] = {
|
|
194
|
+
[WSP_GGML_TYPE_F32] = {
|
|
195
|
+
.vec_dot = (wsp_ggml_vec_dot_t) wsp_ggml_vec_dot_f32,
|
|
196
|
+
.vec_dot_type = WSP_GGML_TYPE_F32,
|
|
197
|
+
.nrows = 1,
|
|
198
|
+
},
|
|
199
|
+
[WSP_GGML_TYPE_F16] = {
|
|
200
|
+
.from_float = (wsp_ggml_from_float_t) wsp_ggml_cpu_fp32_to_fp16,
|
|
201
|
+
.vec_dot = (wsp_ggml_vec_dot_t) wsp_ggml_vec_dot_f16,
|
|
202
|
+
.vec_dot_type = WSP_GGML_TYPE_F16,
|
|
203
|
+
.nrows = 1,
|
|
204
|
+
},
|
|
205
|
+
[WSP_GGML_TYPE_Q4_0] = {
|
|
206
|
+
.from_float = wsp_quantize_row_q4_0,
|
|
207
|
+
.vec_dot = wsp_ggml_vec_dot_q4_0_q8_0,
|
|
208
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_0,
|
|
209
|
+
#if defined (__ARM_FEATURE_MATMUL_INT8)
|
|
210
|
+
.nrows = 2,
|
|
211
|
+
#else
|
|
212
|
+
.nrows = 1,
|
|
213
|
+
#endif
|
|
214
|
+
},
|
|
215
|
+
[WSP_GGML_TYPE_Q4_1] = {
|
|
216
|
+
.from_float = wsp_quantize_row_q4_1,
|
|
217
|
+
.vec_dot = wsp_ggml_vec_dot_q4_1_q8_1,
|
|
218
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_1,
|
|
219
|
+
#if defined (__ARM_FEATURE_MATMUL_INT8)
|
|
220
|
+
.nrows = 2,
|
|
221
|
+
#else
|
|
222
|
+
.nrows = 1,
|
|
223
|
+
#endif
|
|
224
|
+
},
|
|
225
|
+
[WSP_GGML_TYPE_Q5_0] = {
|
|
226
|
+
.from_float = wsp_quantize_row_q5_0,
|
|
227
|
+
.vec_dot = wsp_ggml_vec_dot_q5_0_q8_0,
|
|
228
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_0,
|
|
229
|
+
.nrows = 1,
|
|
230
|
+
},
|
|
231
|
+
[WSP_GGML_TYPE_Q5_1] = {
|
|
232
|
+
.from_float = wsp_quantize_row_q5_1,
|
|
233
|
+
.vec_dot = wsp_ggml_vec_dot_q5_1_q8_1,
|
|
234
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_1,
|
|
235
|
+
.nrows = 1,
|
|
236
|
+
},
|
|
237
|
+
[WSP_GGML_TYPE_Q8_0] = {
|
|
238
|
+
.from_float = wsp_quantize_row_q8_0,
|
|
239
|
+
.vec_dot = wsp_ggml_vec_dot_q8_0_q8_0,
|
|
240
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_0,
|
|
241
|
+
#if defined (__ARM_FEATURE_MATMUL_INT8)
|
|
242
|
+
.nrows = 2,
|
|
243
|
+
#else
|
|
244
|
+
.nrows = 1,
|
|
245
|
+
#endif
|
|
246
|
+
},
|
|
247
|
+
[WSP_GGML_TYPE_Q8_1] = {
|
|
248
|
+
.from_float = wsp_quantize_row_q8_1,
|
|
249
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_1,
|
|
250
|
+
.nrows = 1,
|
|
251
|
+
},
|
|
252
|
+
[WSP_GGML_TYPE_Q2_K] = {
|
|
253
|
+
.from_float = wsp_quantize_row_q2_K,
|
|
254
|
+
.vec_dot = wsp_ggml_vec_dot_q2_K_q8_K,
|
|
255
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_K,
|
|
256
|
+
.nrows = 1,
|
|
257
|
+
},
|
|
258
|
+
[WSP_GGML_TYPE_Q3_K] = {
|
|
259
|
+
.from_float = wsp_quantize_row_q3_K,
|
|
260
|
+
.vec_dot = wsp_ggml_vec_dot_q3_K_q8_K,
|
|
261
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_K,
|
|
262
|
+
.nrows = 1,
|
|
263
|
+
},
|
|
264
|
+
[WSP_GGML_TYPE_Q4_K] = {
|
|
265
|
+
.from_float = wsp_quantize_row_q4_K,
|
|
266
|
+
.vec_dot = wsp_ggml_vec_dot_q4_K_q8_K,
|
|
267
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_K,
|
|
268
|
+
#if defined (__ARM_FEATURE_MATMUL_INT8)
|
|
269
|
+
.nrows = 2,
|
|
270
|
+
#else
|
|
271
|
+
.nrows = 1,
|
|
272
|
+
#endif
|
|
273
|
+
},
|
|
274
|
+
[WSP_GGML_TYPE_Q5_K] = {
|
|
275
|
+
.from_float = wsp_quantize_row_q5_K,
|
|
276
|
+
.vec_dot = wsp_ggml_vec_dot_q5_K_q8_K,
|
|
277
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_K,
|
|
278
|
+
.nrows = 1,
|
|
279
|
+
},
|
|
280
|
+
[WSP_GGML_TYPE_Q6_K] = {
|
|
281
|
+
.from_float = wsp_quantize_row_q6_K,
|
|
282
|
+
.vec_dot = wsp_ggml_vec_dot_q6_K_q8_K,
|
|
283
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_K,
|
|
284
|
+
#if defined (__ARM_FEATURE_MATMUL_INT8)
|
|
285
|
+
.nrows = 2,
|
|
286
|
+
#else
|
|
287
|
+
.nrows = 1,
|
|
288
|
+
#endif
|
|
289
|
+
},
|
|
290
|
+
[WSP_GGML_TYPE_IQ2_XXS] = {
|
|
291
|
+
.from_float = NULL,
|
|
292
|
+
.vec_dot = wsp_ggml_vec_dot_iq2_xxs_q8_K,
|
|
293
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_K,
|
|
294
|
+
.nrows = 1,
|
|
295
|
+
},
|
|
296
|
+
[WSP_GGML_TYPE_IQ2_XS] = {
|
|
297
|
+
.from_float = NULL,
|
|
298
|
+
.vec_dot = wsp_ggml_vec_dot_iq2_xs_q8_K,
|
|
299
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_K,
|
|
300
|
+
.nrows = 1,
|
|
301
|
+
},
|
|
302
|
+
[WSP_GGML_TYPE_IQ3_XXS] = {
|
|
303
|
+
// NOTE: from_float for iq3 and iq2_s was removed because these quants require initialization in wsp_ggml_wsp_quantize_init
|
|
304
|
+
//.from_float = wsp_quantize_row_iq3_xxs,
|
|
305
|
+
.vec_dot = wsp_ggml_vec_dot_iq3_xxs_q8_K,
|
|
306
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_K,
|
|
307
|
+
.nrows = 1,
|
|
308
|
+
},
|
|
309
|
+
[WSP_GGML_TYPE_IQ3_S] = {
|
|
310
|
+
//.from_float = wsp_quantize_row_iq3_s,
|
|
311
|
+
.vec_dot = wsp_ggml_vec_dot_iq3_s_q8_K,
|
|
312
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_K,
|
|
313
|
+
.nrows = 1,
|
|
314
|
+
},
|
|
315
|
+
[WSP_GGML_TYPE_IQ2_S] = {
|
|
316
|
+
//.from_float = wsp_quantize_row_iq2_s,
|
|
317
|
+
.vec_dot = wsp_ggml_vec_dot_iq2_s_q8_K,
|
|
318
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_K,
|
|
319
|
+
.nrows = 1,
|
|
320
|
+
},
|
|
321
|
+
[WSP_GGML_TYPE_IQ1_S] = {
|
|
322
|
+
.from_float = NULL,
|
|
323
|
+
.vec_dot = wsp_ggml_vec_dot_iq1_s_q8_K,
|
|
324
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_K,
|
|
325
|
+
.nrows = 1,
|
|
326
|
+
},
|
|
327
|
+
[WSP_GGML_TYPE_IQ1_M] = {
|
|
328
|
+
.from_float = NULL,
|
|
329
|
+
.vec_dot = wsp_ggml_vec_dot_iq1_m_q8_K,
|
|
330
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_K,
|
|
331
|
+
.nrows = 1,
|
|
332
|
+
},
|
|
333
|
+
[WSP_GGML_TYPE_IQ4_NL] = {
|
|
334
|
+
.from_float = wsp_quantize_row_iq4_nl,
|
|
335
|
+
.vec_dot = wsp_ggml_vec_dot_iq4_nl_q8_0,
|
|
336
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_0,
|
|
337
|
+
.nrows = 1,
|
|
338
|
+
},
|
|
339
|
+
[WSP_GGML_TYPE_IQ4_XS] = {
|
|
340
|
+
.from_float = wsp_quantize_row_iq4_xs,
|
|
341
|
+
.vec_dot = wsp_ggml_vec_dot_iq4_xs_q8_K,
|
|
342
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_K,
|
|
343
|
+
.nrows = 1,
|
|
344
|
+
},
|
|
345
|
+
[WSP_GGML_TYPE_Q8_K] = {
|
|
346
|
+
.from_float = wsp_quantize_row_q8_K,
|
|
347
|
+
},
|
|
348
|
+
[WSP_GGML_TYPE_BF16] = {
|
|
349
|
+
.from_float = (wsp_ggml_from_float_t) wsp_ggml_cpu_fp32_to_bf16,
|
|
350
|
+
.vec_dot = (wsp_ggml_vec_dot_t) wsp_ggml_vec_dot_bf16,
|
|
351
|
+
.vec_dot_type = WSP_GGML_TYPE_BF16,
|
|
352
|
+
.nrows = 1,
|
|
353
|
+
},
|
|
354
|
+
[WSP_GGML_TYPE_TQ1_0] = {
|
|
355
|
+
.from_float = wsp_quantize_row_tq1_0,
|
|
356
|
+
.vec_dot = wsp_ggml_vec_dot_tq1_0_q8_K,
|
|
357
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_K,
|
|
358
|
+
.nrows = 1,
|
|
359
|
+
},
|
|
360
|
+
[WSP_GGML_TYPE_TQ2_0] = {
|
|
361
|
+
.from_float = wsp_quantize_row_tq2_0,
|
|
362
|
+
.vec_dot = wsp_ggml_vec_dot_tq2_0_q8_K,
|
|
363
|
+
.vec_dot_type = WSP_GGML_TYPE_Q8_K,
|
|
364
|
+
.nrows = 1,
|
|
365
|
+
},
|
|
366
|
+
};
|
|
367
|
+
|
|
368
|
+
const struct wsp_ggml_type_traits_cpu * wsp_ggml_get_type_traits_cpu(enum wsp_ggml_type type) {
|
|
369
|
+
return &type_traits_cpu[type];
|
|
370
|
+
}
|
|
371
|
+
|
|
372
|
+
//
|
|
373
|
+
// Threading defs
|
|
374
|
+
//
|
|
375
|
+
|
|
376
|
+
typedef pthread_t wsp_ggml_thread_t;
|
|
377
|
+
|
|
378
|
+
#if defined(_WIN32)
|
|
379
|
+
|
|
380
|
+
typedef CONDITION_VARIABLE wsp_ggml_cond_t;
|
|
381
|
+
typedef SRWLOCK wsp_ggml_mutex_t;
|
|
382
|
+
|
|
383
|
+
#define wsp_ggml_mutex_init(m) InitializeSRWLock(m)
|
|
384
|
+
#define wsp_ggml_mutex_destroy(m)
|
|
385
|
+
#define wsp_ggml_mutex_lock(m) AcquireSRWLockExclusive(m)
|
|
386
|
+
#define wsp_ggml_mutex_unlock(m) ReleaseSRWLockExclusive(m)
|
|
387
|
+
#define wsp_ggml_mutex_lock_shared(m) AcquireSRWLockShared(m)
|
|
388
|
+
#define wsp_ggml_mutex_unlock_shared(m) ReleaseSRWLockShared(m)
|
|
389
|
+
|
|
390
|
+
#define wsp_ggml_cond_init(c) InitializeConditionVariable(c)
|
|
391
|
+
#define wsp_ggml_cond_destroy(c)
|
|
392
|
+
#define wsp_ggml_cond_wait(c, m) SleepConditionVariableSRW(c, m, INFINITE, CONDITION_VARIABLE_LOCKMODE_SHARED)
|
|
393
|
+
#define wsp_ggml_cond_broadcast(c) WakeAllConditionVariable(c)
|
|
394
|
+
|
|
395
|
+
#define wsp_ggml_thread_create pthread_create
|
|
396
|
+
#define wsp_ggml_thread_join pthread_join
|
|
397
|
+
|
|
398
|
+
#else
|
|
399
|
+
|
|
400
|
+
// pthread backing for the wsp_ggml_* synchronization wrappers.
typedef pthread_cond_t  wsp_ggml_cond_t;
typedef pthread_mutex_t wsp_ggml_mutex_t;

// The "shared" variants take and release the very same lock as the plain
// ones; both forms exist so call sites read the same as on Windows.
#define wsp_ggml_mutex_init(m)          pthread_mutex_init(m, NULL)
#define wsp_ggml_mutex_destroy(m)       pthread_mutex_destroy(m)
#define wsp_ggml_mutex_lock(m)          pthread_mutex_lock(m)
#define wsp_ggml_mutex_unlock(m)        pthread_mutex_unlock(m)
#define wsp_ggml_mutex_lock_shared(m)   pthread_mutex_lock(m)
#define wsp_ggml_mutex_unlock_shared(m) pthread_mutex_unlock(m)

// wsp_ggml_lock_*: no lock object is touched. Each macro only marks its
// argument as unused, except wsp_ggml_lock_lock() on x86-64, which ignores
// the argument and expands to a _mm_pause() call.
#define wsp_ggml_lock_init(x)    UNUSED(x)
#define wsp_ggml_lock_destroy(x) UNUSED(x)
#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64))
#define wsp_ggml_lock_lock(x)    _mm_pause()
#else
#define wsp_ggml_lock_lock(x)    UNUSED(x)
#endif
#define wsp_ggml_lock_unlock(x)  UNUSED(x)

#define WSP_GGML_LOCK_INITIALIZER 0

// Condition variables use default attributes.
#define wsp_ggml_cond_init(c)      pthread_cond_init(c, NULL)
#define wsp_ggml_cond_destroy(c)   pthread_cond_destroy(c)
#define wsp_ggml_cond_wait(c, m)   pthread_cond_wait(c, m)
#define wsp_ggml_cond_broadcast(c) pthread_cond_broadcast(c)

#define wsp_ggml_thread_create pthread_create
#define wsp_ggml_thread_join   pthread_join
|
|
427
|
+
|
|
428
|
+
#endif
|
|
429
|
+
|
|
430
|
+
// Threadpool def
|
|
431
|
+
struct wsp_ggml_threadpool {
|
|
432
|
+
wsp_ggml_mutex_t mutex; // mutex for cond.var
|
|
433
|
+
wsp_ggml_cond_t cond; // cond.var for waiting for new work
|
|
434
|
+
|
|
435
|
+
struct wsp_ggml_cgraph * cgraph;
|
|
436
|
+
struct wsp_ggml_cplan * cplan;
|
|
437
|
+
|
|
438
|
+
// synchronization primitives
|
|
439
|
+
atomic_int n_graph; // incremented when there is work to be done (i.e each graph)
|
|
440
|
+
atomic_int WSP_GGML_CACHE_ALIGN n_barrier;
|
|
441
|
+
atomic_int WSP_GGML_CACHE_ALIGN n_barrier_passed;
|
|
442
|
+
atomic_int WSP_GGML_CACHE_ALIGN current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
|
|
443
|
+
|
|
444
|
+
// these are atomic as an annotation for thread-sanitizer
|
|
445
|
+
atomic_bool stop; // Used for stopping the threadpool altogether
|
|
446
|
+
atomic_bool pause; // Used for pausing the threadpool or individual threads
|
|
447
|
+
atomic_int abort; // Used for aborting processing of a graph
|
|
448
|
+
|
|
449
|
+
struct wsp_ggml_compute_state * workers; // per thread state
|
|
450
|
+
int n_threads_max; // number of threads in the pool
|
|
451
|
+
atomic_int n_threads_cur; // number of threads used in the current graph
|
|
452
|
+
|
|
453
|
+
int32_t prio; // Scheduling priority
|
|
454
|
+
uint32_t poll; // Polling level (0 - no polling)
|
|
455
|
+
|
|
456
|
+
enum wsp_ggml_status ec;
|
|
457
|
+
};
|
|
458
|
+
|
|
459
|
+
// Per-thread state
|
|
460
|
+
struct wsp_ggml_compute_state {
|
|
461
|
+
#ifndef WSP_GGML_USE_OPENMP
|
|
462
|
+
wsp_ggml_thread_t thrd;
|
|
463
|
+
bool cpumask[WSP_GGML_MAX_N_THREADS];
|
|
464
|
+
int last_graph;
|
|
465
|
+
bool pending;
|
|
466
|
+
#endif
|
|
467
|
+
struct wsp_ggml_threadpool * threadpool;
|
|
468
|
+
int ith;
|
|
469
|
+
};
|
|
470
|
+
|
|
471
|
+
// Helpers for polling loops
|
|
472
|
+
// Hint to the CPU that the caller is spinning in a polling loop.
// Emits "yield" on aarch64 (clang/gcc), _mm_pause() on x86_64, and nothing elsewhere.
static inline void wsp_ggml_thread_cpu_relax(void) {
#if defined(__aarch64__) && ( defined(__clang__) || defined(__GNUC__) )
    __asm__ volatile("yield" ::: "memory");
#elif defined(__x86_64__)
    _mm_pause();
#else
    ;
#endif
}
|
|
483
|
+
|
|
484
|
+
//
|
|
485
|
+
// NUMA support
|
|
486
|
+
//
|
|
487
|
+
|
|
488
|
+
// Upper bounds on the NUMA topology that can be recorded.
#define WSP_GGML_NUMA_MAX_NODES 8
#define WSP_GGML_NUMA_MAX_CPUS 512

// One NUMA node: the hardware threads that belong to it.
struct wsp_ggml_numa_node {
    uint32_t cpus[WSP_GGML_NUMA_MAX_CPUS];   // ids of the hardware threads on this node
    uint32_t n_cpus;                         // number of valid entries in cpus[]
};
|
|
495
|
+
|
|
496
|
+
struct wsp_ggml_numa_nodes {
|
|
497
|
+
enum wsp_ggml_numa_strategy numa_strategy;
|
|
498
|
+
struct wsp_ggml_numa_node nodes[WSP_GGML_NUMA_MAX_NODES];
|
|
499
|
+
uint32_t n_nodes;
|
|
500
|
+
uint32_t total_cpus; // hardware threads on system
|
|
501
|
+
uint32_t current_node; // node on which main process is execting
|
|
502
|
+
#if defined(__gnu_linux__)
|
|
503
|
+
cpu_set_t cpuset; // cpuset from numactl
|
|
504
|
+
#else
|
|
505
|
+
uint32_t cpuset; // no NUMA support outside of Linux at this time. Use a portable datatype
|
|
506
|
+
#endif
|
|
507
|
+
};
|
|
508
|
+
|
|
509
|
+
//
|
|
510
|
+
// ggml state
|
|
511
|
+
//
|
|
512
|
+
|
|
513
|
+
struct wsp_ggml_state {
|
|
514
|
+
struct wsp_ggml_numa_nodes numa;
|
|
515
|
+
};
|
|
516
|
+
|
|
517
|
+
static struct wsp_ggml_state g_state = {0};
|
|
518
|
+
|
|
519
|
+
void wsp_ggml_barrier(struct wsp_ggml_threadpool * tp) {
|
|
520
|
+
int n_threads = atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed);
|
|
521
|
+
if (n_threads == 1) {
|
|
522
|
+
return;
|
|
523
|
+
}
|
|
524
|
+
|
|
525
|
+
#ifdef WSP_GGML_USE_OPENMP
|
|
526
|
+
#pragma omp barrier
|
|
527
|
+
#else
|
|
528
|
+
int n_passed = atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed);
|
|
529
|
+
|
|
530
|
+
// enter barrier (full seq-cst fence)
|
|
531
|
+
int n_barrier = atomic_fetch_add_explicit(&tp->n_barrier, 1, memory_order_seq_cst);
|
|
532
|
+
|
|
533
|
+
if (n_barrier == (n_threads - 1)) {
|
|
534
|
+
// last thread
|
|
535
|
+
atomic_store_explicit(&tp->n_barrier, 0, memory_order_relaxed);
|
|
536
|
+
|
|
537
|
+
// exit barrier (fill seq-cst fence)
|
|
538
|
+
atomic_fetch_add_explicit(&tp->n_barrier_passed, 1, memory_order_seq_cst);
|
|
539
|
+
return;
|
|
540
|
+
}
|
|
541
|
+
|
|
542
|
+
// wait for other threads
|
|
543
|
+
while (atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed) == n_passed) {
|
|
544
|
+
wsp_ggml_thread_cpu_relax();
|
|
545
|
+
}
|
|
546
|
+
|
|
547
|
+
// exit barrier (full seq-cst fence)
|
|
548
|
+
// TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead
|
|
549
|
+
#ifdef WSP_GGML_TSAN_ENABLED
|
|
550
|
+
atomic_fetch_add_explicit(&tp->n_barrier_passed, 0, memory_order_seq_cst);
|
|
551
|
+
#else
|
|
552
|
+
atomic_thread_fence(memory_order_seq_cst);
|
|
553
|
+
#endif
|
|
554
|
+
#endif
|
|
555
|
+
}
|
|
556
|
+
|
|
557
|
+
void wsp_ggml_threadpool_chunk_set(struct wsp_ggml_threadpool * tp, int value) {
|
|
558
|
+
atomic_store_explicit(&tp->current_chunk, value, memory_order_relaxed);
|
|
559
|
+
}
|
|
560
|
+
|
|
561
|
+
int wsp_ggml_threadpool_chunk_add(struct wsp_ggml_threadpool * tp, int value) {
|
|
562
|
+
return atomic_fetch_add_explicit(&tp->current_chunk, value, memory_order_relaxed);
|
|
563
|
+
}
|
|
564
|
+
|
|
565
|
+
#if defined(__gnu_linux__)
// Returns the CPU affinity mask of the calling thread.
// The mask is cleared first, so it stays empty if pthread_getaffinity_np() fails
// (its return value is not checked).
static cpu_set_t wsp_ggml_get_numa_affinity(void) {
    cpu_set_t mask;
    CPU_ZERO(&mask);
    pthread_getaffinity_np(pthread_self(), sizeof(cpu_set_t), &mask);
    return mask;
}
#else
// No NUMA support on this platform: always reports an empty (0) mask.
static uint32_t wsp_ggml_get_numa_affinity(void) {
    return 0; // no NUMA support
}
#endif
|
|
579
|
+
|
|
580
|
+
void wsp_ggml_numa_init(enum wsp_ggml_numa_strategy numa_flag) {
|
|
581
|
+
if (g_state.numa.n_nodes > 0) {
|
|
582
|
+
fprintf(stderr, "wsp_ggml_numa_init: NUMA already initialized\n");
|
|
583
|
+
|
|
584
|
+
return;
|
|
585
|
+
}
|
|
586
|
+
|
|
587
|
+
#if defined(__gnu_linux__)
|
|
588
|
+
struct stat st;
|
|
589
|
+
char path[256];
|
|
590
|
+
int rv;
|
|
591
|
+
|
|
592
|
+
// set numa scheme
|
|
593
|
+
g_state.numa.numa_strategy = numa_flag;
|
|
594
|
+
|
|
595
|
+
WSP_GGML_PRINT_DEBUG("numa strategy %u\n",g_state.numa.numa_strategy);
|
|
596
|
+
|
|
597
|
+
g_state.numa.cpuset = wsp_ggml_get_numa_affinity();
|
|
598
|
+
|
|
599
|
+
// enumerate nodes
|
|
600
|
+
while (g_state.numa.n_nodes < WSP_GGML_NUMA_MAX_NODES) {
|
|
601
|
+
rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u", g_state.numa.n_nodes);
|
|
602
|
+
WSP_GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
|
|
603
|
+
if (stat(path, &st) != 0) { break; }
|
|
604
|
+
++g_state.numa.n_nodes;
|
|
605
|
+
}
|
|
606
|
+
|
|
607
|
+
// enumerate CPUs
|
|
608
|
+
while (g_state.numa.total_cpus < WSP_GGML_NUMA_MAX_CPUS) {
|
|
609
|
+
rv = snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%u", g_state.numa.total_cpus);
|
|
610
|
+
WSP_GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
|
|
611
|
+
if (stat(path, &st) != 0) { break; }
|
|
612
|
+
++g_state.numa.total_cpus;
|
|
613
|
+
}
|
|
614
|
+
|
|
615
|
+
WSP_GGML_PRINT_DEBUG("found %u numa nodes, %u CPUs\n", g_state.numa.n_nodes, g_state.numa.total_cpus);
|
|
616
|
+
|
|
617
|
+
// figure out which node we're on
|
|
618
|
+
uint current_cpu;
|
|
619
|
+
int getcpu_ret = 0;
|
|
620
|
+
#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 33) || defined(__COSMOPOLITAN__)
|
|
621
|
+
getcpu_ret = getcpu(¤t_cpu, &g_state.numa.current_node);
|
|
622
|
+
#else
|
|
623
|
+
// old glibc doesn't have a wrapper for this call. Fall back on direct syscall
|
|
624
|
+
# if !defined(SYS_getcpu) && defined(SYS_get_cpu)
|
|
625
|
+
# define SYS_getcpu SYS_get_cpu // some older glibc versions use this name
|
|
626
|
+
# endif
|
|
627
|
+
getcpu_ret = syscall(SYS_getcpu, ¤t_cpu, &g_state.numa.current_node);
|
|
628
|
+
#endif
|
|
629
|
+
|
|
630
|
+
if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1 || getcpu_ret != 0) {
|
|
631
|
+
g_state.numa.n_nodes = 0;
|
|
632
|
+
return;
|
|
633
|
+
}
|
|
634
|
+
|
|
635
|
+
WSP_GGML_PRINT_DEBUG("found our process on numa node %u, CPU %u\n", g_state.numa.current_node, current_cpu);
|
|
636
|
+
|
|
637
|
+
for (uint32_t n = 0; n < g_state.numa.n_nodes; ++n) {
|
|
638
|
+
struct wsp_ggml_numa_node * node = &g_state.numa.nodes[n];
|
|
639
|
+
WSP_GGML_PRINT_DEBUG("CPUs on node %u:", n);
|
|
640
|
+
node->n_cpus = 0;
|
|
641
|
+
for (uint32_t c = 0; c < g_state.numa.total_cpus; ++c) {
|
|
642
|
+
rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u/cpu%u", n, c);
|
|
643
|
+
WSP_GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
|
|
644
|
+
if (stat(path, &st) == 0) {
|
|
645
|
+
node->cpus[node->n_cpus++] = c;
|
|
646
|
+
WSP_GGML_PRINT_DEBUG(" %u", c);
|
|
647
|
+
}
|
|
648
|
+
}
|
|
649
|
+
WSP_GGML_PRINT_DEBUG("\n");
|
|
650
|
+
}
|
|
651
|
+
|
|
652
|
+
if (wsp_ggml_is_numa()) {
|
|
653
|
+
FILE *fptr = fopen("/proc/sys/kernel/numa_balancing", "r");
|
|
654
|
+
if (fptr != NULL) {
|
|
655
|
+
char buf[42];
|
|
656
|
+
if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, "0\n", sizeof(buf)) != 0) {
|
|
657
|
+
WSP_GGML_LOG_WARN("/proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n");
|
|
658
|
+
}
|
|
659
|
+
fclose(fptr);
|
|
660
|
+
}
|
|
661
|
+
}
|
|
662
|
+
#else
|
|
663
|
+
UNUSED(numa_flag);
|
|
664
|
+
// TODO
|
|
665
|
+
#endif
|
|
666
|
+
}
|
|
667
|
+
|
|
668
|
+
bool wsp_ggml_is_numa(void) {
|
|
669
|
+
return g_state.numa.n_nodes > 1;
|
|
670
|
+
}
|
|
671
|
+
|
|
672
|
+
#if defined(__ARM_ARCH)

#if defined(__linux__) && defined(__aarch64__)
#include <sys/auxv.h>
#endif

// Fills in wsp_ggml_arm_arch_features at runtime.
// Only does work on Linux/aarch64 builds with SVE enabled, where it stores the
// PR_SVE_GET_VL result masked with PR_SVE_VL_LEN_MASK in sve_cnt.
// The prctl() result is not checked for failure before masking.
static void wsp_ggml_init_arm_arch_features(void) {
#if defined(__linux__) && defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
    const int sve_vl = prctl(PR_SVE_GET_VL);
    wsp_ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & sve_vl;
#endif
}

#endif // __ARM_ARCH
|
|
685
|
+
|
|
686
|
+
struct wsp_ggml_tensor * wsp_ggml_new_i32(struct wsp_ggml_context * ctx, int32_t value) {
|
|
687
|
+
WSP_GGML_ASSERT(!wsp_ggml_get_no_alloc(ctx));
|
|
688
|
+
|
|
689
|
+
struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, 1);
|
|
690
|
+
|
|
691
|
+
wsp_ggml_set_i32(result, value);
|
|
692
|
+
|
|
693
|
+
return result;
|
|
694
|
+
}
|
|
695
|
+
|
|
696
|
+
struct wsp_ggml_tensor * wsp_ggml_new_f32(struct wsp_ggml_context * ctx, float value) {
|
|
697
|
+
WSP_GGML_ASSERT(!wsp_ggml_get_no_alloc(ctx));
|
|
698
|
+
|
|
699
|
+
struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 1);
|
|
700
|
+
|
|
701
|
+
wsp_ggml_set_f32(result, value);
|
|
702
|
+
|
|
703
|
+
return result;
|
|
704
|
+
}
|
|
705
|
+
|
|
706
|
+
struct wsp_ggml_tensor * wsp_ggml_set_i32 (struct wsp_ggml_tensor * tensor, int32_t value) {
|
|
707
|
+
const int n = wsp_ggml_nrows(tensor);
|
|
708
|
+
const int nc = tensor->ne[0];
|
|
709
|
+
const size_t n1 = tensor->nb[1];
|
|
710
|
+
|
|
711
|
+
char * const data = tensor->data;
|
|
712
|
+
|
|
713
|
+
switch (tensor->type) {
|
|
714
|
+
case WSP_GGML_TYPE_I8:
|
|
715
|
+
{
|
|
716
|
+
assert(tensor->nb[0] == sizeof(int8_t));
|
|
717
|
+
for (int i = 0; i < n; i++) {
|
|
718
|
+
wsp_ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
|
|
719
|
+
}
|
|
720
|
+
} break;
|
|
721
|
+
case WSP_GGML_TYPE_I16:
|
|
722
|
+
{
|
|
723
|
+
assert(tensor->nb[0] == sizeof(int16_t));
|
|
724
|
+
for (int i = 0; i < n; i++) {
|
|
725
|
+
wsp_ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
|
|
726
|
+
}
|
|
727
|
+
} break;
|
|
728
|
+
case WSP_GGML_TYPE_I32:
|
|
729
|
+
{
|
|
730
|
+
assert(tensor->nb[0] == sizeof(int32_t));
|
|
731
|
+
for (int i = 0; i < n; i++) {
|
|
732
|
+
wsp_ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
|
|
733
|
+
}
|
|
734
|
+
} break;
|
|
735
|
+
case WSP_GGML_TYPE_F16:
|
|
736
|
+
{
|
|
737
|
+
assert(tensor->nb[0] == sizeof(wsp_ggml_fp16_t));
|
|
738
|
+
for (int i = 0; i < n; i++) {
|
|
739
|
+
wsp_ggml_vec_set_f16(nc, (wsp_ggml_fp16_t *)(data + i*n1), WSP_GGML_FP32_TO_FP16(value));
|
|
740
|
+
}
|
|
741
|
+
} break;
|
|
742
|
+
case WSP_GGML_TYPE_BF16:
|
|
743
|
+
{
|
|
744
|
+
assert(tensor->nb[0] == sizeof(wsp_ggml_fp16_t));
|
|
745
|
+
for (int i = 0; i < n; i++) {
|
|
746
|
+
wsp_ggml_vec_set_bf16(nc, (wsp_ggml_bf16_t *)(data + i*n1), WSP_GGML_FP32_TO_BF16(value));
|
|
747
|
+
}
|
|
748
|
+
} break;
|
|
749
|
+
case WSP_GGML_TYPE_F32:
|
|
750
|
+
{
|
|
751
|
+
assert(tensor->nb[0] == sizeof(float));
|
|
752
|
+
for (int i = 0; i < n; i++) {
|
|
753
|
+
wsp_ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
|
|
754
|
+
}
|
|
755
|
+
} break;
|
|
756
|
+
default:
|
|
757
|
+
{
|
|
758
|
+
WSP_GGML_ABORT("fatal error");
|
|
759
|
+
}
|
|
760
|
+
}
|
|
761
|
+
|
|
762
|
+
return tensor;
|
|
763
|
+
}
|
|
764
|
+
|
|
765
|
+
struct wsp_ggml_tensor * wsp_ggml_set_f32(struct wsp_ggml_tensor * tensor, float value) {
|
|
766
|
+
const int n = wsp_ggml_nrows(tensor);
|
|
767
|
+
const int nc = tensor->ne[0];
|
|
768
|
+
const size_t n1 = tensor->nb[1];
|
|
769
|
+
|
|
770
|
+
char * const data = tensor->data;
|
|
771
|
+
|
|
772
|
+
switch (tensor->type) {
|
|
773
|
+
case WSP_GGML_TYPE_I8:
|
|
774
|
+
{
|
|
775
|
+
assert(tensor->nb[0] == sizeof(int8_t));
|
|
776
|
+
for (int i = 0; i < n; i++) {
|
|
777
|
+
wsp_ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
|
|
778
|
+
}
|
|
779
|
+
} break;
|
|
780
|
+
case WSP_GGML_TYPE_I16:
|
|
781
|
+
{
|
|
782
|
+
assert(tensor->nb[0] == sizeof(int16_t));
|
|
783
|
+
for (int i = 0; i < n; i++) {
|
|
784
|
+
wsp_ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
|
|
785
|
+
}
|
|
786
|
+
} break;
|
|
787
|
+
case WSP_GGML_TYPE_I32:
|
|
788
|
+
{
|
|
789
|
+
assert(tensor->nb[0] == sizeof(int32_t));
|
|
790
|
+
for (int i = 0; i < n; i++) {
|
|
791
|
+
wsp_ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
|
|
792
|
+
}
|
|
793
|
+
} break;
|
|
794
|
+
case WSP_GGML_TYPE_F16:
|
|
795
|
+
{
|
|
796
|
+
assert(tensor->nb[0] == sizeof(wsp_ggml_fp16_t));
|
|
797
|
+
for (int i = 0; i < n; i++) {
|
|
798
|
+
wsp_ggml_vec_set_f16(nc, (wsp_ggml_fp16_t *)(data + i*n1), WSP_GGML_FP32_TO_FP16(value));
|
|
799
|
+
}
|
|
800
|
+
} break;
|
|
801
|
+
case WSP_GGML_TYPE_BF16:
|
|
802
|
+
{
|
|
803
|
+
assert(tensor->nb[0] == sizeof(wsp_ggml_bf16_t));
|
|
804
|
+
for (int i = 0; i < n; i++) {
|
|
805
|
+
wsp_ggml_vec_set_bf16(nc, (wsp_ggml_bf16_t *)(data + i*n1), WSP_GGML_FP32_TO_BF16(value));
|
|
806
|
+
}
|
|
807
|
+
} break;
|
|
808
|
+
case WSP_GGML_TYPE_F32:
|
|
809
|
+
{
|
|
810
|
+
assert(tensor->nb[0] == sizeof(float));
|
|
811
|
+
for (int i = 0; i < n; i++) {
|
|
812
|
+
wsp_ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
|
|
813
|
+
}
|
|
814
|
+
} break;
|
|
815
|
+
default:
|
|
816
|
+
{
|
|
817
|
+
WSP_GGML_ABORT("fatal error");
|
|
818
|
+
}
|
|
819
|
+
}
|
|
820
|
+
|
|
821
|
+
return tensor;
|
|
822
|
+
}
|
|
823
|
+
|
|
824
|
+
int32_t wsp_ggml_get_i32_1d(const struct wsp_ggml_tensor * tensor, int i) {
|
|
825
|
+
if (!wsp_ggml_is_contiguous(tensor)) {
|
|
826
|
+
int64_t id[4] = { 0, 0, 0, 0 };
|
|
827
|
+
wsp_ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
|
|
828
|
+
return wsp_ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]);
|
|
829
|
+
}
|
|
830
|
+
switch (tensor->type) {
|
|
831
|
+
case WSP_GGML_TYPE_I8:
|
|
832
|
+
{
|
|
833
|
+
WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
|
834
|
+
return ((int8_t *)(tensor->data))[i];
|
|
835
|
+
}
|
|
836
|
+
case WSP_GGML_TYPE_I16:
|
|
837
|
+
{
|
|
838
|
+
WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
|
|
839
|
+
return ((int16_t *)(tensor->data))[i];
|
|
840
|
+
}
|
|
841
|
+
case WSP_GGML_TYPE_I32:
|
|
842
|
+
{
|
|
843
|
+
WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
|
|
844
|
+
return ((int32_t *)(tensor->data))[i];
|
|
845
|
+
}
|
|
846
|
+
case WSP_GGML_TYPE_F16:
|
|
847
|
+
{
|
|
848
|
+
WSP_GGML_ASSERT(tensor->nb[0] == sizeof(wsp_ggml_fp16_t));
|
|
849
|
+
return WSP_GGML_FP16_TO_FP32(((wsp_ggml_fp16_t *)(tensor->data))[i]);
|
|
850
|
+
}
|
|
851
|
+
case WSP_GGML_TYPE_BF16:
|
|
852
|
+
{
|
|
853
|
+
WSP_GGML_ASSERT(tensor->nb[0] == sizeof(wsp_ggml_bf16_t));
|
|
854
|
+
return WSP_GGML_BF16_TO_FP32(((wsp_ggml_bf16_t *)(tensor->data))[i]);
|
|
855
|
+
}
|
|
856
|
+
case WSP_GGML_TYPE_F32:
|
|
857
|
+
{
|
|
858
|
+
WSP_GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
|
859
|
+
return ((float *)(tensor->data))[i];
|
|
860
|
+
}
|
|
861
|
+
default:
|
|
862
|
+
{
|
|
863
|
+
WSP_GGML_ABORT("fatal error");
|
|
864
|
+
}
|
|
865
|
+
}
|
|
866
|
+
}
|
|
867
|
+
|
|
868
|
+
void wsp_ggml_set_i32_1d(const struct wsp_ggml_tensor * tensor, int i, int32_t value) {
|
|
869
|
+
if (!wsp_ggml_is_contiguous(tensor)) {
|
|
870
|
+
int64_t id[4] = { 0, 0, 0, 0 };
|
|
871
|
+
wsp_ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
|
|
872
|
+
wsp_ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value);
|
|
873
|
+
return;
|
|
874
|
+
}
|
|
875
|
+
switch (tensor->type) {
|
|
876
|
+
case WSP_GGML_TYPE_I8:
|
|
877
|
+
{
|
|
878
|
+
WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
|
879
|
+
((int8_t *)(tensor->data))[i] = value;
|
|
880
|
+
} break;
|
|
881
|
+
case WSP_GGML_TYPE_I16:
|
|
882
|
+
{
|
|
883
|
+
WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
|
|
884
|
+
((int16_t *)(tensor->data))[i] = value;
|
|
885
|
+
} break;
|
|
886
|
+
case WSP_GGML_TYPE_I32:
|
|
887
|
+
{
|
|
888
|
+
WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
|
|
889
|
+
((int32_t *)(tensor->data))[i] = value;
|
|
890
|
+
} break;
|
|
891
|
+
case WSP_GGML_TYPE_F16:
|
|
892
|
+
{
|
|
893
|
+
WSP_GGML_ASSERT(tensor->nb[0] == sizeof(wsp_ggml_fp16_t));
|
|
894
|
+
((wsp_ggml_fp16_t *)(tensor->data))[i] = WSP_GGML_FP32_TO_FP16(value);
|
|
895
|
+
} break;
|
|
896
|
+
case WSP_GGML_TYPE_BF16:
|
|
897
|
+
{
|
|
898
|
+
WSP_GGML_ASSERT(tensor->nb[0] == sizeof(wsp_ggml_bf16_t));
|
|
899
|
+
((wsp_ggml_bf16_t *)(tensor->data))[i] = WSP_GGML_FP32_TO_BF16(value);
|
|
900
|
+
} break;
|
|
901
|
+
case WSP_GGML_TYPE_F32:
|
|
902
|
+
{
|
|
903
|
+
WSP_GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
|
904
|
+
((float *)(tensor->data))[i] = value;
|
|
905
|
+
} break;
|
|
906
|
+
default:
|
|
907
|
+
{
|
|
908
|
+
WSP_GGML_ABORT("fatal error");
|
|
909
|
+
}
|
|
910
|
+
}
|
|
911
|
+
}
|
|
912
|
+
|
|
913
|
+
int32_t wsp_ggml_get_i32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
|
|
914
|
+
void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
|
|
915
|
+
switch (tensor->type) {
|
|
916
|
+
case WSP_GGML_TYPE_I8:
|
|
917
|
+
return ((int8_t *) data)[0];
|
|
918
|
+
case WSP_GGML_TYPE_I16:
|
|
919
|
+
return ((int16_t *) data)[0];
|
|
920
|
+
case WSP_GGML_TYPE_I32:
|
|
921
|
+
return ((int32_t *) data)[0];
|
|
922
|
+
case WSP_GGML_TYPE_F16:
|
|
923
|
+
return WSP_GGML_FP16_TO_FP32(((wsp_ggml_fp16_t *) data)[0]);
|
|
924
|
+
case WSP_GGML_TYPE_BF16:
|
|
925
|
+
return WSP_GGML_BF16_TO_FP32(((wsp_ggml_bf16_t *) data)[0]);
|
|
926
|
+
case WSP_GGML_TYPE_F32:
|
|
927
|
+
return ((float *) data)[0];
|
|
928
|
+
default:
|
|
929
|
+
WSP_GGML_ABORT("fatal error");
|
|
930
|
+
}
|
|
931
|
+
}
|
|
932
|
+
|
|
933
|
+
void wsp_ggml_set_i32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) {
|
|
934
|
+
void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
|
|
935
|
+
switch (tensor->type) {
|
|
936
|
+
case WSP_GGML_TYPE_I8:
|
|
937
|
+
{
|
|
938
|
+
((int8_t *)(data))[0] = value;
|
|
939
|
+
} break;
|
|
940
|
+
case WSP_GGML_TYPE_I16:
|
|
941
|
+
{
|
|
942
|
+
((int16_t *)(data))[0] = value;
|
|
943
|
+
} break;
|
|
944
|
+
case WSP_GGML_TYPE_I32:
|
|
945
|
+
{
|
|
946
|
+
((int32_t *)(data))[0] = value;
|
|
947
|
+
} break;
|
|
948
|
+
case WSP_GGML_TYPE_F16:
|
|
949
|
+
{
|
|
950
|
+
((wsp_ggml_fp16_t *)(data))[0] = WSP_GGML_FP32_TO_FP16(value);
|
|
951
|
+
} break;
|
|
952
|
+
case WSP_GGML_TYPE_BF16:
|
|
953
|
+
{
|
|
954
|
+
((wsp_ggml_bf16_t *)(data))[0] = WSP_GGML_FP32_TO_BF16(value);
|
|
955
|
+
} break;
|
|
956
|
+
case WSP_GGML_TYPE_F32:
|
|
957
|
+
{
|
|
958
|
+
((float *)(data))[0] = value;
|
|
959
|
+
} break;
|
|
960
|
+
default:
|
|
961
|
+
{
|
|
962
|
+
WSP_GGML_ABORT("fatal error");
|
|
963
|
+
}
|
|
964
|
+
}
|
|
965
|
+
}
|
|
966
|
+
|
|
967
|
+
float wsp_ggml_get_f32_1d(const struct wsp_ggml_tensor * tensor, int i) {
|
|
968
|
+
if (!wsp_ggml_is_contiguous(tensor)) {
|
|
969
|
+
int64_t id[4] = { 0, 0, 0, 0 };
|
|
970
|
+
wsp_ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
|
|
971
|
+
return wsp_ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]);
|
|
972
|
+
}
|
|
973
|
+
switch (tensor->type) {
|
|
974
|
+
case WSP_GGML_TYPE_I8:
|
|
975
|
+
{
|
|
976
|
+
return ((int8_t *)(tensor->data))[i];
|
|
977
|
+
}
|
|
978
|
+
case WSP_GGML_TYPE_I16:
|
|
979
|
+
{
|
|
980
|
+
return ((int16_t *)(tensor->data))[i];
|
|
981
|
+
}
|
|
982
|
+
case WSP_GGML_TYPE_I32:
|
|
983
|
+
{
|
|
984
|
+
return ((int32_t *)(tensor->data))[i];
|
|
985
|
+
}
|
|
986
|
+
case WSP_GGML_TYPE_F16:
|
|
987
|
+
{
|
|
988
|
+
return WSP_GGML_FP16_TO_FP32(((wsp_ggml_fp16_t *)(tensor->data))[i]);
|
|
989
|
+
}
|
|
990
|
+
case WSP_GGML_TYPE_BF16:
|
|
991
|
+
{
|
|
992
|
+
return WSP_GGML_BF16_TO_FP32(((wsp_ggml_bf16_t *)(tensor->data))[i]);
|
|
993
|
+
}
|
|
994
|
+
case WSP_GGML_TYPE_F32:
|
|
995
|
+
{
|
|
996
|
+
return ((float *)(tensor->data))[i];
|
|
997
|
+
}
|
|
998
|
+
default:
|
|
999
|
+
{
|
|
1000
|
+
WSP_GGML_ABORT("fatal error");
|
|
1001
|
+
}
|
|
1002
|
+
}
|
|
1003
|
+
}
|
|
1004
|
+
|
|
1005
|
+
void wsp_ggml_set_f32_1d(const struct wsp_ggml_tensor * tensor, int i, float value) {
|
|
1006
|
+
if (!wsp_ggml_is_contiguous(tensor)) {
|
|
1007
|
+
int64_t id[4] = { 0, 0, 0, 0 };
|
|
1008
|
+
wsp_ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
|
|
1009
|
+
wsp_ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value);
|
|
1010
|
+
return;
|
|
1011
|
+
}
|
|
1012
|
+
switch (tensor->type) {
|
|
1013
|
+
case WSP_GGML_TYPE_I8:
|
|
1014
|
+
{
|
|
1015
|
+
((int8_t *)(tensor->data))[i] = value;
|
|
1016
|
+
} break;
|
|
1017
|
+
case WSP_GGML_TYPE_I16:
|
|
1018
|
+
{
|
|
1019
|
+
((int16_t *)(tensor->data))[i] = value;
|
|
1020
|
+
} break;
|
|
1021
|
+
case WSP_GGML_TYPE_I32:
|
|
1022
|
+
{
|
|
1023
|
+
((int32_t *)(tensor->data))[i] = value;
|
|
1024
|
+
} break;
|
|
1025
|
+
case WSP_GGML_TYPE_F16:
|
|
1026
|
+
{
|
|
1027
|
+
((wsp_ggml_fp16_t *)(tensor->data))[i] = WSP_GGML_FP32_TO_FP16(value);
|
|
1028
|
+
} break;
|
|
1029
|
+
case WSP_GGML_TYPE_BF16:
|
|
1030
|
+
{
|
|
1031
|
+
((wsp_ggml_bf16_t *)(tensor->data))[i] = WSP_GGML_FP32_TO_BF16(value);
|
|
1032
|
+
} break;
|
|
1033
|
+
case WSP_GGML_TYPE_F32:
|
|
1034
|
+
{
|
|
1035
|
+
((float *)(tensor->data))[i] = value;
|
|
1036
|
+
} break;
|
|
1037
|
+
default:
|
|
1038
|
+
{
|
|
1039
|
+
WSP_GGML_ABORT("fatal error");
|
|
1040
|
+
}
|
|
1041
|
+
}
|
|
1042
|
+
}
|
|
1043
|
+
|
|
1044
|
+
float wsp_ggml_get_f32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
|
|
1045
|
+
void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
|
|
1046
|
+
switch (tensor->type) {
|
|
1047
|
+
case WSP_GGML_TYPE_I8:
|
|
1048
|
+
return ((int8_t *) data)[0];
|
|
1049
|
+
case WSP_GGML_TYPE_I16:
|
|
1050
|
+
return ((int16_t *) data)[0];
|
|
1051
|
+
case WSP_GGML_TYPE_I32:
|
|
1052
|
+
return ((int32_t *) data)[0];
|
|
1053
|
+
case WSP_GGML_TYPE_F16:
|
|
1054
|
+
return WSP_GGML_FP16_TO_FP32(((wsp_ggml_fp16_t *) data)[0]);
|
|
1055
|
+
case WSP_GGML_TYPE_BF16:
|
|
1056
|
+
return WSP_GGML_BF16_TO_FP32(((wsp_ggml_bf16_t *) data)[0]);
|
|
1057
|
+
case WSP_GGML_TYPE_F32:
|
|
1058
|
+
return ((float *) data)[0];
|
|
1059
|
+
default:
|
|
1060
|
+
WSP_GGML_ABORT("fatal error");
|
|
1061
|
+
}
|
|
1062
|
+
}
|
|
1063
|
+
|
|
1064
|
+
void wsp_ggml_set_f32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) {
|
|
1065
|
+
void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
|
|
1066
|
+
switch (tensor->type) {
|
|
1067
|
+
case WSP_GGML_TYPE_I8:
|
|
1068
|
+
{
|
|
1069
|
+
((int8_t *)(data))[0] = value;
|
|
1070
|
+
} break;
|
|
1071
|
+
case WSP_GGML_TYPE_I16:
|
|
1072
|
+
{
|
|
1073
|
+
((int16_t *)(data))[0] = value;
|
|
1074
|
+
} break;
|
|
1075
|
+
case WSP_GGML_TYPE_I32:
|
|
1076
|
+
{
|
|
1077
|
+
((int32_t *)(data))[0] = value;
|
|
1078
|
+
} break;
|
|
1079
|
+
case WSP_GGML_TYPE_F16:
|
|
1080
|
+
{
|
|
1081
|
+
((wsp_ggml_fp16_t *)(data))[0] = WSP_GGML_FP32_TO_FP16(value);
|
|
1082
|
+
} break;
|
|
1083
|
+
case WSP_GGML_TYPE_BF16:
|
|
1084
|
+
{
|
|
1085
|
+
((wsp_ggml_bf16_t *)(data))[0] = WSP_GGML_FP32_TO_BF16(value);
|
|
1086
|
+
} break;
|
|
1087
|
+
case WSP_GGML_TYPE_F32:
|
|
1088
|
+
{
|
|
1089
|
+
((float *)(data))[0] = value;
|
|
1090
|
+
} break;
|
|
1091
|
+
default:
|
|
1092
|
+
{
|
|
1093
|
+
WSP_GGML_ABORT("fatal error");
|
|
1094
|
+
}
|
|
1095
|
+
}
|
|
1096
|
+
}
|
|
1097
|
+
|
|
1098
|
+
////////////////////////////////////////////////////////////////////////////////
|
|
1099
|
+
|
|
1100
|
+
// wsp_ggml_compute_forward_mul_mat
|
|
1101
|
+
|
|
1102
|
+
static void wsp_ggml_compute_forward_mul_mat_one_chunk(
|
|
1103
|
+
const struct wsp_ggml_compute_params * params,
|
|
1104
|
+
struct wsp_ggml_tensor * dst,
|
|
1105
|
+
const enum wsp_ggml_type type,
|
|
1106
|
+
const int64_t num_rows_per_vec_dot,
|
|
1107
|
+
const int64_t ir0_start,
|
|
1108
|
+
const int64_t ir0_end,
|
|
1109
|
+
const int64_t ir1_start,
|
|
1110
|
+
const int64_t ir1_end) {
|
|
1111
|
+
|
|
1112
|
+
const struct wsp_ggml_tensor * src0 = dst->src[0];
|
|
1113
|
+
const struct wsp_ggml_tensor * src1 = dst->src[1];
|
|
1114
|
+
|
|
1115
|
+
WSP_GGML_TENSOR_BINARY_OP_LOCALS
|
|
1116
|
+
|
|
1117
|
+
const bool src1_cont = wsp_ggml_is_contiguous(src1);
|
|
1118
|
+
|
|
1119
|
+
wsp_ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
|
|
1120
|
+
enum wsp_ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
|
|
1121
|
+
|
|
1122
|
+
// broadcast factors
|
|
1123
|
+
const int64_t r2 = ne12 / ne02;
|
|
1124
|
+
const int64_t r3 = ne13 / ne03;
|
|
1125
|
+
|
|
1126
|
+
//printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end);
|
|
1127
|
+
|
|
1128
|
+
// threads with no work simply yield (not sure if it helps)
|
|
1129
|
+
if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
|
|
1130
|
+
return;
|
|
1131
|
+
}
|
|
1132
|
+
|
|
1133
|
+
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
|
1134
|
+
const size_t row_size = wsp_ggml_row_size(vec_dot_type, ne10);
|
|
1135
|
+
|
|
1136
|
+
assert(ne12 % ne02 == 0);
|
|
1137
|
+
assert(ne13 % ne03 == 0);
|
|
1138
|
+
|
|
1139
|
+
// block-tiling attempt
|
|
1140
|
+
const int64_t blck_0 = 16;
|
|
1141
|
+
const int64_t blck_1 = 16;
|
|
1142
|
+
|
|
1143
|
+
const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
|
|
1144
|
+
|
|
1145
|
+
// attempt to reduce false-sharing (does not seem to make a difference)
|
|
1146
|
+
// 16 * 2, accounting for mmla kernels
|
|
1147
|
+
float tmp[32];
|
|
1148
|
+
|
|
1149
|
+
for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
|
|
1150
|
+
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
|
|
1151
|
+
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
|
|
1152
|
+
const int64_t i13 = (ir1 / (ne12 * ne1));
|
|
1153
|
+
const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
|
|
1154
|
+
const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
|
|
1155
|
+
|
|
1156
|
+
// broadcast src0 into src1
|
|
1157
|
+
const int64_t i03 = i13 / r3;
|
|
1158
|
+
const int64_t i02 = i12 / r2;
|
|
1159
|
+
|
|
1160
|
+
const int64_t i1 = i11;
|
|
1161
|
+
const int64_t i2 = i12;
|
|
1162
|
+
const int64_t i3 = i13;
|
|
1163
|
+
|
|
1164
|
+
const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03);
|
|
1165
|
+
|
|
1166
|
+
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
|
|
1167
|
+
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
|
|
1168
|
+
// the original src1 data pointer, so we should index using the indices directly
|
|
1169
|
+
// TODO: this is a bit of a hack, we should probably have a better way to handle this
|
|
1170
|
+
const char * src1_col = (const char*)wdata +
|
|
1171
|
+
(src1_cont || src1->type != vec_dot_type
|
|
1172
|
+
? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
|
|
1173
|
+
: (i11 * nb11 + i12 * nb12 + i13 * nb13));
|
|
1174
|
+
float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
|
|
1175
|
+
|
|
1176
|
+
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
|
|
1177
|
+
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
|
|
1178
|
+
//}
|
|
1179
|
+
|
|
1180
|
+
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
|
|
1181
|
+
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
|
|
1182
|
+
}
|
|
1183
|
+
|
|
1184
|
+
for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
|
|
1185
|
+
memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
|
|
1186
|
+
}
|
|
1187
|
+
}
|
|
1188
|
+
}
|
|
1189
|
+
}
|
|
1190
|
+
}
|
|
1191
|
+
|
|
1192
|
+
static void wsp_ggml_compute_forward_mul_mat(
|
|
1193
|
+
const struct wsp_ggml_compute_params * params,
|
|
1194
|
+
struct wsp_ggml_tensor * dst) {
|
|
1195
|
+
|
|
1196
|
+
const struct wsp_ggml_tensor * src0 = dst->src[0];
|
|
1197
|
+
const struct wsp_ggml_tensor * src1 = dst->src[1];
|
|
1198
|
+
|
|
1199
|
+
WSP_GGML_TENSOR_BINARY_OP_LOCALS
|
|
1200
|
+
|
|
1201
|
+
const int ith = params->ith;
|
|
1202
|
+
const int nth = params->nth;
|
|
1203
|
+
|
|
1204
|
+
enum wsp_ggml_type const vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
|
|
1205
|
+
wsp_ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
|
|
1206
|
+
int64_t const vec_dot_num_rows = type_traits_cpu[src0->type].nrows;
|
|
1207
|
+
|
|
1208
|
+
WSP_GGML_ASSERT(ne0 == ne01);
|
|
1209
|
+
WSP_GGML_ASSERT(ne1 == ne11);
|
|
1210
|
+
WSP_GGML_ASSERT(ne2 == ne12);
|
|
1211
|
+
WSP_GGML_ASSERT(ne3 == ne13);
|
|
1212
|
+
|
|
1213
|
+
// we don't support permuted src0 or src1
|
|
1214
|
+
WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(src0->type));
|
|
1215
|
+
WSP_GGML_ASSERT(nb10 == wsp_ggml_type_size(src1->type));
|
|
1216
|
+
|
|
1217
|
+
// dst cannot be transposed or permuted
|
|
1218
|
+
WSP_GGML_ASSERT(nb0 == sizeof(float));
|
|
1219
|
+
WSP_GGML_ASSERT(nb0 <= nb1);
|
|
1220
|
+
WSP_GGML_ASSERT(nb1 <= nb2);
|
|
1221
|
+
WSP_GGML_ASSERT(nb2 <= nb3);
|
|
1222
|
+
|
|
1223
|
+
// nb01 >= nb00 - src0 is not transposed
|
|
1224
|
+
// compute by src0 rows
|
|
1225
|
+
|
|
1226
|
+
// TODO: extract to "extra_op"
|
|
1227
|
+
#if WSP_GGML_USE_LLAMAFILE
|
|
1228
|
+
// broadcast factors
|
|
1229
|
+
const int64_t r2 = ne12 / ne02;
|
|
1230
|
+
const int64_t r3 = ne13 / ne03;
|
|
1231
|
+
|
|
1232
|
+
const bool src1_cont = wsp_ggml_is_contiguous(src1);
|
|
1233
|
+
|
|
1234
|
+
if (src1_cont) {
|
|
1235
|
+
for (int64_t i13 = 0; i13 < ne13; i13++)
|
|
1236
|
+
for (int64_t i12 = 0; i12 < ne12; i12++)
|
|
1237
|
+
if (!llamafile_sgemm(params,
|
|
1238
|
+
ne01, ne11, ne00/wsp_ggml_blck_size(src0->type),
|
|
1239
|
+
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
|
|
1240
|
+
nb01/wsp_ggml_type_size(src0->type),
|
|
1241
|
+
(const char *)src1->data + i12*nb12 + i13*nb13,
|
|
1242
|
+
nb11/wsp_ggml_type_size(src1->type),
|
|
1243
|
+
(char *)dst->data + i12*nb2 + i13*nb3,
|
|
1244
|
+
nb1/wsp_ggml_type_size(dst->type),
|
|
1245
|
+
src0->type,
|
|
1246
|
+
src1->type,
|
|
1247
|
+
dst->type))
|
|
1248
|
+
goto UseGgmlGemm1;
|
|
1249
|
+
return;
|
|
1250
|
+
}
|
|
1251
|
+
UseGgmlGemm1:;
|
|
1252
|
+
#endif
|
|
1253
|
+
|
|
1254
|
+
if (src1->type != vec_dot_type) {
|
|
1255
|
+
char * wdata = params->wdata;
|
|
1256
|
+
|
|
1257
|
+
const size_t nbw0 = wsp_ggml_type_size(vec_dot_type);
|
|
1258
|
+
const size_t nbw1 = wsp_ggml_row_size(vec_dot_type, ne10);
|
|
1259
|
+
const size_t nbw2 = nbw1*ne11;
|
|
1260
|
+
const size_t nbw3 = nbw2*ne12;
|
|
1261
|
+
|
|
1262
|
+
assert(params->wsize >= ne13*nbw3);
|
|
1263
|
+
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
|
|
1264
|
+
|
|
1265
|
+
#if 0
|
|
1266
|
+
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
|
1267
|
+
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
|
1268
|
+
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
|
|
1269
|
+
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
|
|
1270
|
+
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
|
|
1271
|
+
ne10);
|
|
1272
|
+
}
|
|
1273
|
+
}
|
|
1274
|
+
}
|
|
1275
|
+
#else
|
|
1276
|
+
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
|
1277
|
+
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
|
1278
|
+
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
|
1279
|
+
size_t bs = wsp_ggml_blck_size(vec_dot_type);
|
|
1280
|
+
int64_t ne10_block_start = (ith * ne10/bs) / nth;
|
|
1281
|
+
int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth;
|
|
1282
|
+
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
|
|
1283
|
+
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
|
|
1284
|
+
(ne10_block_end - ne10_block_start) * bs);
|
|
1285
|
+
}
|
|
1286
|
+
}
|
|
1287
|
+
}
|
|
1288
|
+
#endif
|
|
1289
|
+
}
|
|
1290
|
+
|
|
1291
|
+
if (ith == 0) {
|
|
1292
|
+
// Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
|
|
1293
|
+
atomic_store_explicit(&params->threadpool->current_chunk, nth, memory_order_relaxed);
|
|
1294
|
+
}
|
|
1295
|
+
|
|
1296
|
+
wsp_ggml_barrier(params->threadpool);
|
|
1297
|
+
|
|
1298
|
+
#if WSP_GGML_USE_LLAMAFILE
|
|
1299
|
+
if (src1->type != vec_dot_type) {
|
|
1300
|
+
const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
|
1301
|
+
const size_t row_size = wsp_ggml_row_size(vec_dot_type, ne10);
|
|
1302
|
+
|
|
1303
|
+
for (int64_t i13 = 0; i13 < ne13; i13++)
|
|
1304
|
+
for (int64_t i12 = 0; i12 < ne12; i12++)
|
|
1305
|
+
if (!llamafile_sgemm(params,
|
|
1306
|
+
ne01, ne11, ne00/wsp_ggml_blck_size(src0->type),
|
|
1307
|
+
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
|
|
1308
|
+
nb01/wsp_ggml_type_size(src0->type),
|
|
1309
|
+
(const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
|
|
1310
|
+
row_size/wsp_ggml_type_size(vec_dot_type),
|
|
1311
|
+
(char *)dst->data + i12*nb2 + i13*nb3,
|
|
1312
|
+
nb1/wsp_ggml_type_size(dst->type),
|
|
1313
|
+
src0->type,
|
|
1314
|
+
vec_dot_type,
|
|
1315
|
+
dst->type))
|
|
1316
|
+
goto UseGgmlGemm2;
|
|
1317
|
+
return;
|
|
1318
|
+
}
|
|
1319
|
+
UseGgmlGemm2:;
|
|
1320
|
+
#endif
|
|
1321
|
+
|
|
1322
|
+
// This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
|
|
1323
|
+
const int64_t nr0 = ne0;
|
|
1324
|
+
|
|
1325
|
+
// This is the size of the rest of the dimensions of the result
|
|
1326
|
+
const int64_t nr1 = ne1 * ne2 * ne3;
|
|
1327
|
+
|
|
1328
|
+
// Now select a reasonable chunk size.
|
|
1329
|
+
int chunk_size = 16;
|
|
1330
|
+
|
|
1331
|
+
// We need to step up the size if it's small
|
|
1332
|
+
if (nr0 == 1 || nr1 == 1) {
|
|
1333
|
+
chunk_size = 64;
|
|
1334
|
+
}
|
|
1335
|
+
|
|
1336
|
+
// distribute the work across the inner or outer loop based on which one is larger
|
|
1337
|
+
// The number of chunks in the 0/1 dim.
|
|
1338
|
+
// CEIL(nr0/chunk_size)
|
|
1339
|
+
int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
|
|
1340
|
+
int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
|
|
1341
|
+
|
|
1342
|
+
// If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread.
|
|
1343
|
+
// Also, chunking by thread was measured to perform better on NUMA systems. See https://github.com/ggml-org/llama.cpp/pull/6915
|
|
1344
|
+
// In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that.
|
|
1345
|
+
if (nchunk0 * nchunk1 < nth * 4 || wsp_ggml_is_numa()) {
|
|
1346
|
+
// distribute the thread work across the inner or outer loop based on which one is larger
|
|
1347
|
+
nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
|
|
1348
|
+
nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
|
|
1349
|
+
}
|
|
1350
|
+
|
|
1351
|
+
// The number of elements in each chunk
|
|
1352
|
+
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
|
|
1353
|
+
const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
|
|
1354
|
+
|
|
1355
|
+
// The first chunk comes from our thread_id, the rest will get auto-assigned.
|
|
1356
|
+
int current_chunk = ith;
|
|
1357
|
+
|
|
1358
|
+
while (current_chunk < nchunk0 * nchunk1) {
|
|
1359
|
+
const int64_t ith0 = current_chunk % nchunk0;
|
|
1360
|
+
const int64_t ith1 = current_chunk / nchunk0;
|
|
1361
|
+
|
|
1362
|
+
const int64_t ir0_start = dr0 * ith0;
|
|
1363
|
+
const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
|
|
1364
|
+
|
|
1365
|
+
const int64_t ir1_start = dr1 * ith1;
|
|
1366
|
+
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
|
|
1367
|
+
|
|
1368
|
+
// dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
|
|
1369
|
+
int64_t num_rows_per_vec_dot = vec_dot_num_rows;
|
|
1370
|
+
|
|
1371
|
+
// these checks are needed to avoid crossing dim1 boundaries
|
|
1372
|
+
// can be optimized, but the logic would become more complicated, so keeping it like this for simplicity
|
|
1373
|
+
if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
|
|
1374
|
+
num_rows_per_vec_dot = 1;
|
|
1375
|
+
}
|
|
1376
|
+
wsp_ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
|
|
1377
|
+
|
|
1378
|
+
if (nth >= nchunk0 * nchunk1) {
|
|
1379
|
+
break;
|
|
1380
|
+
}
|
|
1381
|
+
|
|
1382
|
+
current_chunk = atomic_fetch_add_explicit(&params->threadpool->current_chunk, 1, memory_order_relaxed);
|
|
1383
|
+
}
|
|
1384
|
+
}
|
|
1385
|
+
|
|
1386
|
+
// wsp_ggml_compute_forward_mul_mat_id
|
|
1387
|
+
|
|
1388
|
+
// Selects entry `i1` of the row list that belongs to matrix `row_id`.
// The lists are stored back to back in `matrix_rows`, each one
// ids->ne[0]*ids->ne[1] entries long. Both `matrix_rows` and `ids` must be
// in scope where the macro is expanded. The expansion is an lvalue, so it can
// be read from as well as assigned to.
#define MMID_MATRIX_ROW(row_id, i1) \
    matrix_rows[(i1) + (row_id)*ids->ne[0]*ids->ne[1]]

// One entry of a per-matrix row list. Entries are written as {id, iid1}, where
// `id` walks ids dimension 0 and `iid1` walks ids dimension 1.
struct mmid_row_mapping {
    int32_t i1; // position along ids dimension 0; also used as the dst index along dimension 1
    int32_t i2; // position along ids dimension 1; also used as the dst index along dimension 2
};
|
|
1394
|
+
|
|
1395
|
+
static void wsp_ggml_compute_forward_mul_mat_id_one_chunk(
|
|
1396
|
+
struct wsp_ggml_tensor * dst,
|
|
1397
|
+
const struct wsp_ggml_tensor * src0,
|
|
1398
|
+
const struct wsp_ggml_tensor * src1,
|
|
1399
|
+
const struct wsp_ggml_tensor * ids,
|
|
1400
|
+
const int64_t cur_a,
|
|
1401
|
+
const int64_t ir0_start,
|
|
1402
|
+
const int64_t ir0_end,
|
|
1403
|
+
const int64_t ir1_start,
|
|
1404
|
+
const int64_t ir1_end,
|
|
1405
|
+
const char * src0_cur,
|
|
1406
|
+
const struct mmid_row_mapping * matrix_rows,
|
|
1407
|
+
const size_t row_size,
|
|
1408
|
+
const bool src1_cont,
|
|
1409
|
+
const void * wdata) {
|
|
1410
|
+
|
|
1411
|
+
WSP_GGML_TENSOR_BINARY_OP_LOCALS
|
|
1412
|
+
|
|
1413
|
+
const enum wsp_ggml_type type = src0->type;
|
|
1414
|
+
|
|
1415
|
+
wsp_ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
|
|
1416
|
+
enum wsp_ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
|
|
1417
|
+
|
|
1418
|
+
const int64_t blck_0 = 16;
|
|
1419
|
+
const int64_t blck_1 = 16;
|
|
1420
|
+
|
|
1421
|
+
float tmp[16];
|
|
1422
|
+
|
|
1423
|
+
for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
|
|
1424
|
+
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
|
|
1425
|
+
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ++ir1) {
|
|
1426
|
+
const int64_t _i12 = ir1; // logical row index for this expert
|
|
1427
|
+
|
|
1428
|
+
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
|
|
1429
|
+
const int id = row_mapping.i1; // selected expert index
|
|
1430
|
+
|
|
1431
|
+
const int64_t i11 = id % ne11;
|
|
1432
|
+
const int64_t i12 = row_mapping.i2; // row index in src1
|
|
1433
|
+
|
|
1434
|
+
const int64_t i1 = id; // selected expert index
|
|
1435
|
+
const int64_t i2 = i12; // row
|
|
1436
|
+
|
|
1437
|
+
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
|
|
1438
|
+
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
|
|
1439
|
+
// the original src1 data pointer, so we should index using the indices directly
|
|
1440
|
+
// TODO: this is a bit of a hack, we should probably have a better way to handle this
|
|
1441
|
+
const char * src1_col = (const char *) wdata +
|
|
1442
|
+
(src1_cont || src1->type != vec_dot_type
|
|
1443
|
+
? (i11 + i12*ne11)*row_size
|
|
1444
|
+
: (i11*nb11 + i12*nb12));
|
|
1445
|
+
|
|
1446
|
+
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
|
|
1447
|
+
|
|
1448
|
+
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
|
|
1449
|
+
vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
|
|
1450
|
+
}
|
|
1451
|
+
|
|
1452
|
+
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float));
|
|
1453
|
+
}
|
|
1454
|
+
}
|
|
1455
|
+
}
|
|
1456
|
+
}
|
|
1457
|
+
|
|
1458
|
+
// Bump-allocator step: rounds the cursor `*p` up with WSP_GGML_PAD(..., align),
// returns that rounded address and moves the cursor `size` bytes past it.
// No bounds are checked here; the caller has to verify the final cursor position.
// NOTE(review): WSP_GGML_PAD is defined elsewhere - presumably `align` must be a power of two; confirm.
static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
    void * const aligned = (void *) WSP_GGML_PAD((uintptr_t) *p, align);

    *p = (void *) ((char *) aligned + size);

    return aligned;
}
|
|
1465
|
+
|
|
1466
|
+
static void wsp_ggml_compute_forward_mul_mat_id(
|
|
1467
|
+
const struct wsp_ggml_compute_params * params,
|
|
1468
|
+
struct wsp_ggml_tensor * dst) {
|
|
1469
|
+
|
|
1470
|
+
const struct wsp_ggml_tensor * src0 = dst->src[0];
|
|
1471
|
+
const struct wsp_ggml_tensor * src1 = dst->src[1];
|
|
1472
|
+
const struct wsp_ggml_tensor * ids = dst->src[2];
|
|
1473
|
+
|
|
1474
|
+
WSP_GGML_TENSOR_BINARY_OP_LOCALS
|
|
1475
|
+
|
|
1476
|
+
const int ith = params->ith;
|
|
1477
|
+
const int nth = params->nth;
|
|
1478
|
+
|
|
1479
|
+
const enum wsp_ggml_type type = src0->type;
|
|
1480
|
+
|
|
1481
|
+
const bool src1_cont = wsp_ggml_is_contiguous(src1);
|
|
1482
|
+
|
|
1483
|
+
enum wsp_ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
|
|
1484
|
+
wsp_ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
|
|
1485
|
+
|
|
1486
|
+
// we don't support permuted src0 or src1
|
|
1487
|
+
WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(type));
|
|
1488
|
+
WSP_GGML_ASSERT(nb10 == wsp_ggml_type_size(src1->type));
|
|
1489
|
+
|
|
1490
|
+
// dst cannot be transposed or permuted
|
|
1491
|
+
WSP_GGML_ASSERT(nb0 == sizeof(float));
|
|
1492
|
+
WSP_GGML_ASSERT(nb0 <= nb1);
|
|
1493
|
+
WSP_GGML_ASSERT(nb1 <= nb2);
|
|
1494
|
+
WSP_GGML_ASSERT(nb2 <= nb3);
|
|
1495
|
+
|
|
1496
|
+
// row groups
|
|
1497
|
+
const int n_ids = ids->ne[0]; // n_expert_used
|
|
1498
|
+
const int n_as = ne02; // n_expert
|
|
1499
|
+
|
|
1500
|
+
void * wdata_cur = params->wdata;
|
|
1501
|
+
|
|
1502
|
+
if (src1->type != vec_dot_type) {
|
|
1503
|
+
incr_ptr_aligned(&wdata_cur, wsp_ggml_row_size(vec_dot_type, wsp_ggml_nelements(src1)), sizeof(int64_t));
|
|
1504
|
+
}
|
|
1505
|
+
|
|
1506
|
+
int64_t * matrix_row_counts = // [n_as]
|
|
1507
|
+
incr_ptr_aligned(&wdata_cur, n_as*sizeof(int64_t), sizeof(int64_t));
|
|
1508
|
+
|
|
1509
|
+
struct mmid_row_mapping * matrix_rows = // [n_as][ids->ne[0]*ids->ne[1]]
|
|
1510
|
+
incr_ptr_aligned(&wdata_cur, n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping), sizeof(int64_t));
|
|
1511
|
+
|
|
1512
|
+
char (*atomic_current_chunk)[CACHE_LINE_SIZE] = // [n_as]
|
|
1513
|
+
incr_ptr_aligned(&wdata_cur, CACHE_LINE_SIZE * n_as, CACHE_LINE_SIZE);
|
|
1514
|
+
|
|
1515
|
+
WSP_GGML_ASSERT(params->wsize >= (size_t)((char *) wdata_cur - (char *) params->wdata));
|
|
1516
|
+
|
|
1517
|
+
if (src1->type != vec_dot_type) {
|
|
1518
|
+
char * wdata = params->wdata;
|
|
1519
|
+
|
|
1520
|
+
const size_t nbw0 = wsp_ggml_type_size(vec_dot_type);
|
|
1521
|
+
const size_t nbw1 = wsp_ggml_row_size(vec_dot_type, ne10);
|
|
1522
|
+
const size_t nbw2 = nbw1*ne11;
|
|
1523
|
+
const size_t nbw3 = nbw2*ne12;
|
|
1524
|
+
|
|
1525
|
+
assert(params->wsize >= ne13*nbw3);
|
|
1526
|
+
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
|
|
1527
|
+
|
|
1528
|
+
#if 0
|
|
1529
|
+
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
|
1530
|
+
for (int64_t i12 = ith; i12 < ne12; i12 += nth) {
|
|
1531
|
+
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
|
1532
|
+
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
|
|
1533
|
+
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
|
|
1534
|
+
ne10);
|
|
1535
|
+
}
|
|
1536
|
+
}
|
|
1537
|
+
}
|
|
1538
|
+
#else
|
|
1539
|
+
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
|
1540
|
+
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
|
1541
|
+
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
|
1542
|
+
size_t bs = wsp_ggml_blck_size(vec_dot_type);
|
|
1543
|
+
int64_t ne10_block_start = (ith * ne10/bs) / nth;
|
|
1544
|
+
int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth;
|
|
1545
|
+
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
|
|
1546
|
+
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
|
|
1547
|
+
(ne10_block_end - ne10_block_start) * bs);
|
|
1548
|
+
}
|
|
1549
|
+
}
|
|
1550
|
+
}
|
|
1551
|
+
#endif
|
|
1552
|
+
}
|
|
1553
|
+
|
|
1554
|
+
if (ith == 0) {
|
|
1555
|
+
// initialize matrix_row_counts
|
|
1556
|
+
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
|
|
1557
|
+
|
|
1558
|
+
// group rows by src0 matrix
|
|
1559
|
+
for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
|
|
1560
|
+
for (int id = 0; id < n_ids; ++id) {
|
|
1561
|
+
const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
|
|
1562
|
+
|
|
1563
|
+
assert(i02 >= 0 && i02 < n_as);
|
|
1564
|
+
|
|
1565
|
+
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
|
|
1566
|
+
matrix_row_counts[i02] += 1;
|
|
1567
|
+
}
|
|
1568
|
+
}
|
|
1569
|
+
}
|
|
1570
|
+
|
|
1571
|
+
// reset current_chunk
|
|
1572
|
+
for (int cur_a = ith; cur_a < n_as; cur_a += nth) {
|
|
1573
|
+
atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
|
|
1574
|
+
*current_chunk_ctr = nth;
|
|
1575
|
+
}
|
|
1576
|
+
|
|
1577
|
+
wsp_ggml_barrier(params->threadpool);
|
|
1578
|
+
|
|
1579
|
+
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
|
|
1580
|
+
const int64_t cne1 = matrix_row_counts[cur_a];
|
|
1581
|
+
|
|
1582
|
+
if (cne1 == 0) {
|
|
1583
|
+
continue;
|
|
1584
|
+
}
|
|
1585
|
+
|
|
1586
|
+
const char * src0_cur = (const char *) src0->data + cur_a * nb02;
|
|
1587
|
+
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
|
1588
|
+
const size_t row_size = wsp_ggml_row_size(vec_dot_type, ne10);
|
|
1589
|
+
|
|
1590
|
+
const int64_t nr0 = ne01;
|
|
1591
|
+
const int64_t nr1 = cne1;
|
|
1592
|
+
|
|
1593
|
+
int chunk_size = 16;
|
|
1594
|
+
if (nr0 == 1 || nr1 == 1) {
|
|
1595
|
+
chunk_size = 64;
|
|
1596
|
+
}
|
|
1597
|
+
|
|
1598
|
+
#if defined(__aarch64__)
|
|
1599
|
+
// disable for ARM
|
|
1600
|
+
const bool disable_chunking = true;
|
|
1601
|
+
#else
|
|
1602
|
+
// disable for NUMA
|
|
1603
|
+
const bool disable_chunking = wsp_ggml_is_numa();
|
|
1604
|
+
#endif // defined(__aarch64__)
|
|
1605
|
+
|
|
1606
|
+
int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
|
|
1607
|
+
int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
|
|
1608
|
+
|
|
1609
|
+
if (nchunk0 * nchunk1 < nth * 4 || disable_chunking) {
|
|
1610
|
+
nchunk0 = nr0 > nr1 ? nth : 1;
|
|
1611
|
+
nchunk1 = nr0 > nr1 ? 1 : nth;
|
|
1612
|
+
}
|
|
1613
|
+
|
|
1614
|
+
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
|
|
1615
|
+
const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
|
|
1616
|
+
|
|
1617
|
+
int current_chunk = ith;
|
|
1618
|
+
|
|
1619
|
+
atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
|
|
1620
|
+
|
|
1621
|
+
while (current_chunk < nchunk0 * nchunk1) {
|
|
1622
|
+
const int64_t ith0 = current_chunk % nchunk0;
|
|
1623
|
+
const int64_t ith1 = current_chunk / nchunk0;
|
|
1624
|
+
|
|
1625
|
+
const int64_t ir0_start = dr0 * ith0;
|
|
1626
|
+
const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
|
|
1627
|
+
|
|
1628
|
+
const int64_t ir1_start = dr1 * ith1;
|
|
1629
|
+
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
|
|
1630
|
+
|
|
1631
|
+
wsp_ggml_compute_forward_mul_mat_id_one_chunk(
|
|
1632
|
+
dst, src0, src1, ids, cur_a,
|
|
1633
|
+
ir0_start, ir0_end, ir1_start, ir1_end,
|
|
1634
|
+
src0_cur, matrix_rows, row_size, src1_cont, wdata
|
|
1635
|
+
);
|
|
1636
|
+
|
|
1637
|
+
if (nth >= nchunk0 * nchunk1) {
|
|
1638
|
+
break;
|
|
1639
|
+
}
|
|
1640
|
+
|
|
1641
|
+
current_chunk = atomic_fetch_add_explicit(current_chunk_ctr, 1, memory_order_relaxed);
|
|
1642
|
+
}
|
|
1643
|
+
}
|
|
1644
|
+
}
|
|
1645
|
+
|
|
1646
|
+
/////////////////////////////////
|
|
1647
|
+
|
|
1648
|
+
static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * tensor) {
|
|
1649
|
+
WSP_GGML_ASSERT(params);
|
|
1650
|
+
|
|
1651
|
+
if (tensor->op == WSP_GGML_OP_NONE || wsp_ggml_is_empty(tensor)) {
|
|
1652
|
+
return;
|
|
1653
|
+
}
|
|
1654
|
+
|
|
1655
|
+
// extra_buffer op?
|
|
1656
|
+
if (wsp_ggml_cpu_extra_compute_forward(params, tensor)) {
|
|
1657
|
+
return;
|
|
1658
|
+
}
|
|
1659
|
+
|
|
1660
|
+
switch (tensor->op) {
|
|
1661
|
+
case WSP_GGML_OP_DUP:
|
|
1662
|
+
{
|
|
1663
|
+
wsp_ggml_compute_forward_dup(params, tensor);
|
|
1664
|
+
} break;
|
|
1665
|
+
case WSP_GGML_OP_ADD:
|
|
1666
|
+
{
|
|
1667
|
+
wsp_ggml_compute_forward_add(params, tensor);
|
|
1668
|
+
} break;
|
|
1669
|
+
case WSP_GGML_OP_ADD1:
|
|
1670
|
+
{
|
|
1671
|
+
wsp_ggml_compute_forward_add1(params, tensor);
|
|
1672
|
+
} break;
|
|
1673
|
+
case WSP_GGML_OP_ACC:
|
|
1674
|
+
{
|
|
1675
|
+
wsp_ggml_compute_forward_acc(params, tensor);
|
|
1676
|
+
} break;
|
|
1677
|
+
case WSP_GGML_OP_SUB:
|
|
1678
|
+
{
|
|
1679
|
+
wsp_ggml_compute_forward_sub(params, tensor);
|
|
1680
|
+
} break;
|
|
1681
|
+
case WSP_GGML_OP_MUL:
|
|
1682
|
+
{
|
|
1683
|
+
wsp_ggml_compute_forward_mul(params, tensor);
|
|
1684
|
+
} break;
|
|
1685
|
+
case WSP_GGML_OP_DIV:
|
|
1686
|
+
{
|
|
1687
|
+
wsp_ggml_compute_forward_div(params, tensor);
|
|
1688
|
+
} break;
|
|
1689
|
+
case WSP_GGML_OP_SQR:
|
|
1690
|
+
{
|
|
1691
|
+
wsp_ggml_compute_forward_sqr(params, tensor);
|
|
1692
|
+
} break;
|
|
1693
|
+
case WSP_GGML_OP_SQRT:
|
|
1694
|
+
{
|
|
1695
|
+
wsp_ggml_compute_forward_sqrt(params, tensor);
|
|
1696
|
+
} break;
|
|
1697
|
+
case WSP_GGML_OP_LOG:
|
|
1698
|
+
{
|
|
1699
|
+
wsp_ggml_compute_forward_log(params, tensor);
|
|
1700
|
+
} break;
|
|
1701
|
+
case WSP_GGML_OP_SIN:
|
|
1702
|
+
{
|
|
1703
|
+
wsp_ggml_compute_forward_sin(params, tensor);
|
|
1704
|
+
} break;
|
|
1705
|
+
case WSP_GGML_OP_COS:
|
|
1706
|
+
{
|
|
1707
|
+
wsp_ggml_compute_forward_cos(params, tensor);
|
|
1708
|
+
} break;
|
|
1709
|
+
case WSP_GGML_OP_SUM:
|
|
1710
|
+
{
|
|
1711
|
+
wsp_ggml_compute_forward_sum(params, tensor);
|
|
1712
|
+
} break;
|
|
1713
|
+
case WSP_GGML_OP_SUM_ROWS:
|
|
1714
|
+
{
|
|
1715
|
+
wsp_ggml_compute_forward_sum_rows(params, tensor);
|
|
1716
|
+
} break;
|
|
1717
|
+
case WSP_GGML_OP_MEAN:
|
|
1718
|
+
{
|
|
1719
|
+
wsp_ggml_compute_forward_mean(params, tensor);
|
|
1720
|
+
} break;
|
|
1721
|
+
case WSP_GGML_OP_ARGMAX:
|
|
1722
|
+
{
|
|
1723
|
+
wsp_ggml_compute_forward_argmax(params, tensor);
|
|
1724
|
+
} break;
|
|
1725
|
+
case WSP_GGML_OP_COUNT_EQUAL:
|
|
1726
|
+
{
|
|
1727
|
+
wsp_ggml_compute_forward_count_equal(params, tensor);
|
|
1728
|
+
} break;
|
|
1729
|
+
case WSP_GGML_OP_REPEAT:
|
|
1730
|
+
{
|
|
1731
|
+
wsp_ggml_compute_forward_repeat(params, tensor);
|
|
1732
|
+
} break;
|
|
1733
|
+
case WSP_GGML_OP_REPEAT_BACK:
|
|
1734
|
+
{
|
|
1735
|
+
wsp_ggml_compute_forward_repeat_back(params, tensor);
|
|
1736
|
+
} break;
|
|
1737
|
+
case WSP_GGML_OP_CONCAT:
|
|
1738
|
+
{
|
|
1739
|
+
wsp_ggml_compute_forward_concat(params, tensor);
|
|
1740
|
+
} break;
|
|
1741
|
+
case WSP_GGML_OP_SILU_BACK:
|
|
1742
|
+
{
|
|
1743
|
+
wsp_ggml_compute_forward_silu_back(params, tensor);
|
|
1744
|
+
} break;
|
|
1745
|
+
case WSP_GGML_OP_NORM:
|
|
1746
|
+
{
|
|
1747
|
+
wsp_ggml_compute_forward_norm(params, tensor);
|
|
1748
|
+
} break;
|
|
1749
|
+
case WSP_GGML_OP_RMS_NORM:
|
|
1750
|
+
{
|
|
1751
|
+
wsp_ggml_compute_forward_rms_norm(params, tensor);
|
|
1752
|
+
} break;
|
|
1753
|
+
case WSP_GGML_OP_RMS_NORM_BACK:
|
|
1754
|
+
{
|
|
1755
|
+
wsp_ggml_compute_forward_rms_norm_back(params, tensor);
|
|
1756
|
+
} break;
|
|
1757
|
+
case WSP_GGML_OP_GROUP_NORM:
|
|
1758
|
+
{
|
|
1759
|
+
wsp_ggml_compute_forward_group_norm(params, tensor);
|
|
1760
|
+
} break;
|
|
1761
|
+
case WSP_GGML_OP_L2_NORM:
|
|
1762
|
+
{
|
|
1763
|
+
wsp_ggml_compute_forward_l2_norm(params, tensor);
|
|
1764
|
+
} break;
|
|
1765
|
+
case WSP_GGML_OP_MUL_MAT:
|
|
1766
|
+
{
|
|
1767
|
+
wsp_ggml_compute_forward_mul_mat(params, tensor);
|
|
1768
|
+
} break;
|
|
1769
|
+
case WSP_GGML_OP_MUL_MAT_ID:
|
|
1770
|
+
{
|
|
1771
|
+
wsp_ggml_compute_forward_mul_mat_id(params, tensor);
|
|
1772
|
+
} break;
|
|
1773
|
+
case WSP_GGML_OP_OUT_PROD:
|
|
1774
|
+
{
|
|
1775
|
+
wsp_ggml_compute_forward_out_prod(params, tensor);
|
|
1776
|
+
} break;
|
|
1777
|
+
case WSP_GGML_OP_SCALE:
|
|
1778
|
+
{
|
|
1779
|
+
wsp_ggml_compute_forward_scale(params, tensor);
|
|
1780
|
+
} break;
|
|
1781
|
+
case WSP_GGML_OP_SET:
|
|
1782
|
+
{
|
|
1783
|
+
wsp_ggml_compute_forward_set(params, tensor);
|
|
1784
|
+
} break;
|
|
1785
|
+
case WSP_GGML_OP_CPY:
|
|
1786
|
+
{
|
|
1787
|
+
wsp_ggml_compute_forward_cpy(params, tensor);
|
|
1788
|
+
} break;
|
|
1789
|
+
case WSP_GGML_OP_CONT:
|
|
1790
|
+
{
|
|
1791
|
+
wsp_ggml_compute_forward_cont(params, tensor);
|
|
1792
|
+
} break;
|
|
1793
|
+
case WSP_GGML_OP_RESHAPE:
|
|
1794
|
+
{
|
|
1795
|
+
wsp_ggml_compute_forward_reshape(params, tensor);
|
|
1796
|
+
} break;
|
|
1797
|
+
case WSP_GGML_OP_VIEW:
|
|
1798
|
+
{
|
|
1799
|
+
wsp_ggml_compute_forward_view(params, tensor);
|
|
1800
|
+
} break;
|
|
1801
|
+
case WSP_GGML_OP_PERMUTE:
|
|
1802
|
+
{
|
|
1803
|
+
wsp_ggml_compute_forward_permute(params, tensor);
|
|
1804
|
+
} break;
|
|
1805
|
+
case WSP_GGML_OP_TRANSPOSE:
|
|
1806
|
+
{
|
|
1807
|
+
wsp_ggml_compute_forward_transpose(params, tensor);
|
|
1808
|
+
} break;
|
|
1809
|
+
case WSP_GGML_OP_GET_ROWS:
|
|
1810
|
+
{
|
|
1811
|
+
wsp_ggml_compute_forward_get_rows(params, tensor);
|
|
1812
|
+
} break;
|
|
1813
|
+
case WSP_GGML_OP_GET_ROWS_BACK:
|
|
1814
|
+
{
|
|
1815
|
+
wsp_ggml_compute_forward_get_rows_back(params, tensor);
|
|
1816
|
+
} break;
|
|
1817
|
+
case WSP_GGML_OP_DIAG:
|
|
1818
|
+
{
|
|
1819
|
+
wsp_ggml_compute_forward_diag(params, tensor);
|
|
1820
|
+
} break;
|
|
1821
|
+
case WSP_GGML_OP_DIAG_MASK_INF:
|
|
1822
|
+
{
|
|
1823
|
+
wsp_ggml_compute_forward_diag_mask_inf(params, tensor);
|
|
1824
|
+
} break;
|
|
1825
|
+
case WSP_GGML_OP_DIAG_MASK_ZERO:
|
|
1826
|
+
{
|
|
1827
|
+
wsp_ggml_compute_forward_diag_mask_zero(params, tensor);
|
|
1828
|
+
} break;
|
|
1829
|
+
case WSP_GGML_OP_SOFT_MAX:
|
|
1830
|
+
{
|
|
1831
|
+
wsp_ggml_compute_forward_soft_max(params, tensor);
|
|
1832
|
+
} break;
|
|
1833
|
+
case WSP_GGML_OP_SOFT_MAX_BACK:
|
|
1834
|
+
{
|
|
1835
|
+
wsp_ggml_compute_forward_soft_max_ext_back(params, tensor);
|
|
1836
|
+
} break;
|
|
1837
|
+
case WSP_GGML_OP_ROPE:
|
|
1838
|
+
{
|
|
1839
|
+
wsp_ggml_compute_forward_rope(params, tensor);
|
|
1840
|
+
} break;
|
|
1841
|
+
case WSP_GGML_OP_ROPE_BACK:
|
|
1842
|
+
{
|
|
1843
|
+
wsp_ggml_compute_forward_rope_back(params, tensor);
|
|
1844
|
+
} break;
|
|
1845
|
+
case WSP_GGML_OP_CLAMP:
|
|
1846
|
+
{
|
|
1847
|
+
wsp_ggml_compute_forward_clamp(params, tensor);
|
|
1848
|
+
} break;
|
|
1849
|
+
case WSP_GGML_OP_CONV_TRANSPOSE_1D:
|
|
1850
|
+
{
|
|
1851
|
+
wsp_ggml_compute_forward_conv_transpose_1d(params, tensor);
|
|
1852
|
+
} break;
|
|
1853
|
+
case WSP_GGML_OP_IM2COL:
|
|
1854
|
+
{
|
|
1855
|
+
wsp_ggml_compute_forward_im2col(params, tensor);
|
|
1856
|
+
} break;
|
|
1857
|
+
case WSP_GGML_OP_IM2COL_BACK:
|
|
1858
|
+
{
|
|
1859
|
+
wsp_ggml_compute_forward_im2col_back_f32(params, tensor);
|
|
1860
|
+
} break;
|
|
1861
|
+
case WSP_GGML_OP_CONV_2D_DW:
|
|
1862
|
+
{
|
|
1863
|
+
wsp_ggml_compute_forward_conv_2d_dw(params, tensor);
|
|
1864
|
+
} break;
|
|
1865
|
+
case WSP_GGML_OP_CONV_TRANSPOSE_2D:
|
|
1866
|
+
{
|
|
1867
|
+
wsp_ggml_compute_forward_conv_transpose_2d(params, tensor);
|
|
1868
|
+
} break;
|
|
1869
|
+
case WSP_GGML_OP_POOL_1D:
|
|
1870
|
+
{
|
|
1871
|
+
wsp_ggml_compute_forward_pool_1d(params, tensor);
|
|
1872
|
+
} break;
|
|
1873
|
+
case WSP_GGML_OP_POOL_2D:
|
|
1874
|
+
{
|
|
1875
|
+
wsp_ggml_compute_forward_pool_2d(params, tensor);
|
|
1876
|
+
} break;
|
|
1877
|
+
case WSP_GGML_OP_POOL_2D_BACK:
|
|
1878
|
+
{
|
|
1879
|
+
wsp_ggml_compute_forward_pool_2d_back(params, tensor);
|
|
1880
|
+
} break;
|
|
1881
|
+
case WSP_GGML_OP_UPSCALE:
|
|
1882
|
+
{
|
|
1883
|
+
wsp_ggml_compute_forward_upscale(params, tensor);
|
|
1884
|
+
} break;
|
|
1885
|
+
case WSP_GGML_OP_PAD:
|
|
1886
|
+
{
|
|
1887
|
+
wsp_ggml_compute_forward_pad(params, tensor);
|
|
1888
|
+
} break;
|
|
1889
|
+
case WSP_GGML_OP_PAD_REFLECT_1D:
|
|
1890
|
+
{
|
|
1891
|
+
wsp_ggml_compute_forward_pad_reflect_1d(params, tensor);
|
|
1892
|
+
} break;
|
|
1893
|
+
case WSP_GGML_OP_ROLL:
|
|
1894
|
+
{
|
|
1895
|
+
wsp_ggml_compute_forward_roll(params, tensor);
|
|
1896
|
+
} break;
|
|
1897
|
+
case WSP_GGML_OP_ARANGE:
|
|
1898
|
+
{
|
|
1899
|
+
wsp_ggml_compute_forward_arange(params, tensor);
|
|
1900
|
+
} break;
|
|
1901
|
+
case WSP_GGML_OP_TIMESTEP_EMBEDDING:
|
|
1902
|
+
{
|
|
1903
|
+
wsp_ggml_compute_forward_timestep_embedding(params, tensor);
|
|
1904
|
+
} break;
|
|
1905
|
+
case WSP_GGML_OP_ARGSORT:
|
|
1906
|
+
{
|
|
1907
|
+
wsp_ggml_compute_forward_argsort(params, tensor);
|
|
1908
|
+
} break;
|
|
1909
|
+
case WSP_GGML_OP_LEAKY_RELU:
|
|
1910
|
+
{
|
|
1911
|
+
wsp_ggml_compute_forward_leaky_relu(params, tensor);
|
|
1912
|
+
} break;
|
|
1913
|
+
case WSP_GGML_OP_FLASH_ATTN_EXT:
|
|
1914
|
+
{
|
|
1915
|
+
wsp_ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
|
|
1916
|
+
} break;
|
|
1917
|
+
case WSP_GGML_OP_FLASH_ATTN_BACK:
|
|
1918
|
+
{
|
|
1919
|
+
int32_t t = wsp_ggml_get_op_params_i32(tensor, 0);
|
|
1920
|
+
WSP_GGML_ASSERT(t == 0 || t == 1);
|
|
1921
|
+
bool masked = t != 0;
|
|
1922
|
+
wsp_ggml_compute_forward_flash_attn_back(params, masked, tensor);
|
|
1923
|
+
} break;
|
|
1924
|
+
case WSP_GGML_OP_SSM_CONV:
|
|
1925
|
+
{
|
|
1926
|
+
wsp_ggml_compute_forward_ssm_conv(params, tensor);
|
|
1927
|
+
} break;
|
|
1928
|
+
case WSP_GGML_OP_SSM_SCAN:
|
|
1929
|
+
{
|
|
1930
|
+
wsp_ggml_compute_forward_ssm_scan(params, tensor);
|
|
1931
|
+
} break;
|
|
1932
|
+
case WSP_GGML_OP_WIN_PART:
|
|
1933
|
+
{
|
|
1934
|
+
wsp_ggml_compute_forward_win_part(params, tensor);
|
|
1935
|
+
} break;
|
|
1936
|
+
case WSP_GGML_OP_WIN_UNPART:
|
|
1937
|
+
{
|
|
1938
|
+
wsp_ggml_compute_forward_win_unpart(params, tensor);
|
|
1939
|
+
} break;
|
|
1940
|
+
case WSP_GGML_OP_UNARY:
|
|
1941
|
+
{
|
|
1942
|
+
wsp_ggml_compute_forward_unary(params, tensor);
|
|
1943
|
+
} break;
|
|
1944
|
+
case WSP_GGML_OP_GET_REL_POS:
|
|
1945
|
+
{
|
|
1946
|
+
wsp_ggml_compute_forward_get_rel_pos(params, tensor);
|
|
1947
|
+
} break;
|
|
1948
|
+
case WSP_GGML_OP_ADD_REL_POS:
|
|
1949
|
+
{
|
|
1950
|
+
wsp_ggml_compute_forward_add_rel_pos(params, tensor);
|
|
1951
|
+
} break;
|
|
1952
|
+
case WSP_GGML_OP_RWKV_WKV6:
|
|
1953
|
+
{
|
|
1954
|
+
wsp_ggml_compute_forward_rwkv_wkv6(params, tensor);
|
|
1955
|
+
} break;
|
|
1956
|
+
case WSP_GGML_OP_GATED_LINEAR_ATTN:
|
|
1957
|
+
{
|
|
1958
|
+
wsp_ggml_compute_forward_gla(params, tensor);
|
|
1959
|
+
} break;
|
|
1960
|
+
case WSP_GGML_OP_RWKV_WKV7:
|
|
1961
|
+
{
|
|
1962
|
+
wsp_ggml_compute_forward_rwkv_wkv7(params, tensor);
|
|
1963
|
+
} break;
|
|
1964
|
+
case WSP_GGML_OP_MAP_CUSTOM1:
|
|
1965
|
+
{
|
|
1966
|
+
wsp_ggml_compute_forward_map_custom1(params, tensor);
|
|
1967
|
+
}
|
|
1968
|
+
break;
|
|
1969
|
+
case WSP_GGML_OP_MAP_CUSTOM2:
|
|
1970
|
+
{
|
|
1971
|
+
wsp_ggml_compute_forward_map_custom2(params, tensor);
|
|
1972
|
+
}
|
|
1973
|
+
break;
|
|
1974
|
+
case WSP_GGML_OP_MAP_CUSTOM3:
|
|
1975
|
+
{
|
|
1976
|
+
wsp_ggml_compute_forward_map_custom3(params, tensor);
|
|
1977
|
+
}
|
|
1978
|
+
break;
|
|
1979
|
+
case WSP_GGML_OP_CUSTOM:
|
|
1980
|
+
{
|
|
1981
|
+
wsp_ggml_compute_forward_custom(params, tensor);
|
|
1982
|
+
}
|
|
1983
|
+
break;
|
|
1984
|
+
case WSP_GGML_OP_CROSS_ENTROPY_LOSS:
|
|
1985
|
+
{
|
|
1986
|
+
wsp_ggml_compute_forward_cross_entropy_loss(params, tensor);
|
|
1987
|
+
}
|
|
1988
|
+
break;
|
|
1989
|
+
case WSP_GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
|
1990
|
+
{
|
|
1991
|
+
wsp_ggml_compute_forward_cross_entropy_loss_back(params, tensor);
|
|
1992
|
+
}
|
|
1993
|
+
break;
|
|
1994
|
+
case WSP_GGML_OP_OPT_STEP_ADAMW:
|
|
1995
|
+
{
|
|
1996
|
+
wsp_ggml_compute_forward_opt_step_adamw(params, tensor);
|
|
1997
|
+
}
|
|
1998
|
+
break;
|
|
1999
|
+
case WSP_GGML_OP_NONE:
|
|
2000
|
+
{
|
|
2001
|
+
// nop
|
|
2002
|
+
} break;
|
|
2003
|
+
case WSP_GGML_OP_COUNT:
|
|
2004
|
+
{
|
|
2005
|
+
WSP_GGML_ABORT("fatal error");
|
|
2006
|
+
}
|
|
2007
|
+
}
|
|
2008
|
+
}
|
|
2009
|
+
|
|
2010
|
+
// Android's libc implementation "bionic" does not support setting affinity
|
|
2011
|
+
#if defined(__gnu_linux__)
|
|
2012
|
+
static void set_numa_thread_affinity(int thread_n) {
|
|
2013
|
+
if (!wsp_ggml_is_numa()) {
|
|
2014
|
+
return;
|
|
2015
|
+
}
|
|
2016
|
+
|
|
2017
|
+
int node_num;
|
|
2018
|
+
int rv;
|
|
2019
|
+
size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);
|
|
2020
|
+
|
|
2021
|
+
switch(g_state.numa.numa_strategy) {
|
|
2022
|
+
case WSP_GGML_NUMA_STRATEGY_DISTRIBUTE:
|
|
2023
|
+
// run thread on node_num thread_n / (threads per node)
|
|
2024
|
+
node_num = thread_n % g_state.numa.n_nodes;
|
|
2025
|
+
break;
|
|
2026
|
+
case WSP_GGML_NUMA_STRATEGY_ISOLATE:
|
|
2027
|
+
// run thread on current_node
|
|
2028
|
+
node_num = g_state.numa.current_node;
|
|
2029
|
+
break;
|
|
2030
|
+
case WSP_GGML_NUMA_STRATEGY_NUMACTL:
|
|
2031
|
+
// use the cpuset that numactl gave us
|
|
2032
|
+
rv = pthread_setaffinity_np(pthread_self(), setsize, &g_state.numa.cpuset);
|
|
2033
|
+
if (rv) {
|
|
2034
|
+
fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n",strerror(rv));
|
|
2035
|
+
}
|
|
2036
|
+
return;
|
|
2037
|
+
default:
|
|
2038
|
+
return;
|
|
2039
|
+
}
|
|
2040
|
+
|
|
2041
|
+
struct wsp_ggml_numa_node * node = &g_state.numa.nodes[node_num];
|
|
2042
|
+
|
|
2043
|
+
cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);
|
|
2044
|
+
CPU_ZERO_S(setsize, cpus);
|
|
2045
|
+
for (size_t i = 0; i < node->n_cpus; ++i) {
|
|
2046
|
+
CPU_SET_S(node->cpus[i], setsize, cpus);
|
|
2047
|
+
}
|
|
2048
|
+
|
|
2049
|
+
rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);
|
|
2050
|
+
if (rv) {
|
|
2051
|
+
fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv));
|
|
2052
|
+
}
|
|
2053
|
+
|
|
2054
|
+
CPU_FREE(cpus);
|
|
2055
|
+
}
|
|
2056
|
+
|
|
2057
|
+
static void clear_numa_thread_affinity(void) {
|
|
2058
|
+
if (!wsp_ggml_is_numa()) {
|
|
2059
|
+
return;
|
|
2060
|
+
}
|
|
2061
|
+
|
|
2062
|
+
size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);
|
|
2063
|
+
|
|
2064
|
+
cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);
|
|
2065
|
+
CPU_ZERO_S(setsize, cpus);
|
|
2066
|
+
for (unsigned i = 0; i < g_state.numa.total_cpus; ++i) {
|
|
2067
|
+
CPU_SET_S(i, setsize, cpus);
|
|
2068
|
+
}
|
|
2069
|
+
|
|
2070
|
+
int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);
|
|
2071
|
+
if (rv) {
|
|
2072
|
+
fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv));
|
|
2073
|
+
}
|
|
2074
|
+
|
|
2075
|
+
CPU_FREE(cpus);
|
|
2076
|
+
}
|
|
2077
|
+
#else
|
|
2078
|
+
// TODO: Windows etc.
|
|
2079
|
+
// (the linux implementation may also work on BSD, someone should test)
|
|
2080
|
+
// No-op fallback used when the NUMA affinity implementation is not compiled in.
static void set_numa_thread_affinity(int thread_n) {
    UNUSED(thread_n);
}
|
|
2081
|
+
// No-op fallback used when the NUMA affinity implementation is not compiled in.
static void clear_numa_thread_affinity(void) {
}
|
|
2082
|
+
#endif
|
|
2083
|
+
|
|
2084
|
+
static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
|
|
2085
|
+
int n_tasks = 0;
|
|
2086
|
+
|
|
2087
|
+
if (wsp_ggml_is_empty(node)) {
|
|
2088
|
+
// no need to multi-thread a no-op
|
|
2089
|
+
n_tasks = 1;
|
|
2090
|
+
return n_tasks;
|
|
2091
|
+
}
|
|
2092
|
+
|
|
2093
|
+
switch (node->op) {
|
|
2094
|
+
case WSP_GGML_OP_CPY:
|
|
2095
|
+
case WSP_GGML_OP_DUP:
|
|
2096
|
+
case WSP_GGML_OP_CONT:
|
|
2097
|
+
case WSP_GGML_OP_ADD:
|
|
2098
|
+
case WSP_GGML_OP_ADD1:
|
|
2099
|
+
case WSP_GGML_OP_ACC:
|
|
2100
|
+
{
|
|
2101
|
+
n_tasks = n_threads;
|
|
2102
|
+
} break;
|
|
2103
|
+
case WSP_GGML_OP_SUB:
|
|
2104
|
+
case WSP_GGML_OP_SQR:
|
|
2105
|
+
case WSP_GGML_OP_SQRT:
|
|
2106
|
+
case WSP_GGML_OP_LOG:
|
|
2107
|
+
case WSP_GGML_OP_SIN:
|
|
2108
|
+
case WSP_GGML_OP_COS:
|
|
2109
|
+
case WSP_GGML_OP_SUM:
|
|
2110
|
+
case WSP_GGML_OP_SUM_ROWS:
|
|
2111
|
+
case WSP_GGML_OP_MEAN:
|
|
2112
|
+
case WSP_GGML_OP_ARGMAX:
|
|
2113
|
+
{
|
|
2114
|
+
n_tasks = 1;
|
|
2115
|
+
} break;
|
|
2116
|
+
case WSP_GGML_OP_COUNT_EQUAL:
|
|
2117
|
+
{
|
|
2118
|
+
n_tasks = n_threads;
|
|
2119
|
+
} break;
|
|
2120
|
+
case WSP_GGML_OP_REPEAT:
|
|
2121
|
+
case WSP_GGML_OP_REPEAT_BACK:
|
|
2122
|
+
case WSP_GGML_OP_LEAKY_RELU:
|
|
2123
|
+
{
|
|
2124
|
+
n_tasks = 1;
|
|
2125
|
+
} break;
|
|
2126
|
+
case WSP_GGML_OP_UNARY:
|
|
2127
|
+
switch (wsp_ggml_get_unary_op(node)) {
|
|
2128
|
+
case WSP_GGML_UNARY_OP_ABS:
|
|
2129
|
+
case WSP_GGML_UNARY_OP_SGN:
|
|
2130
|
+
case WSP_GGML_UNARY_OP_NEG:
|
|
2131
|
+
case WSP_GGML_UNARY_OP_STEP:
|
|
2132
|
+
case WSP_GGML_UNARY_OP_TANH:
|
|
2133
|
+
case WSP_GGML_UNARY_OP_ELU:
|
|
2134
|
+
case WSP_GGML_UNARY_OP_RELU:
|
|
2135
|
+
case WSP_GGML_UNARY_OP_SIGMOID:
|
|
2136
|
+
case WSP_GGML_UNARY_OP_HARDSWISH:
|
|
2137
|
+
case WSP_GGML_UNARY_OP_HARDSIGMOID:
|
|
2138
|
+
case WSP_GGML_UNARY_OP_EXP:
|
|
2139
|
+
{
|
|
2140
|
+
n_tasks = 1;
|
|
2141
|
+
} break;
|
|
2142
|
+
|
|
2143
|
+
case WSP_GGML_UNARY_OP_GELU:
|
|
2144
|
+
case WSP_GGML_UNARY_OP_GELU_ERF:
|
|
2145
|
+
case WSP_GGML_UNARY_OP_GELU_QUICK:
|
|
2146
|
+
case WSP_GGML_UNARY_OP_SILU:
|
|
2147
|
+
{
|
|
2148
|
+
n_tasks = n_threads;
|
|
2149
|
+
} break;
|
|
2150
|
+
default:
|
|
2151
|
+
WSP_GGML_ABORT("fatal error");
|
|
2152
|
+
}
|
|
2153
|
+
break;
|
|
2154
|
+
case WSP_GGML_OP_SILU_BACK:
|
|
2155
|
+
case WSP_GGML_OP_MUL:
|
|
2156
|
+
case WSP_GGML_OP_DIV:
|
|
2157
|
+
case WSP_GGML_OP_NORM:
|
|
2158
|
+
case WSP_GGML_OP_RMS_NORM:
|
|
2159
|
+
case WSP_GGML_OP_RMS_NORM_BACK:
|
|
2160
|
+
case WSP_GGML_OP_L2_NORM:
|
|
2161
|
+
case WSP_GGML_OP_GROUP_NORM:
|
|
2162
|
+
case WSP_GGML_OP_CONCAT:
|
|
2163
|
+
case WSP_GGML_OP_MUL_MAT:
|
|
2164
|
+
case WSP_GGML_OP_MUL_MAT_ID:
|
|
2165
|
+
case WSP_GGML_OP_OUT_PROD:
|
|
2166
|
+
{
|
|
2167
|
+
n_tasks = n_threads;
|
|
2168
|
+
} break;
|
|
2169
|
+
case WSP_GGML_OP_GET_ROWS:
|
|
2170
|
+
{
|
|
2171
|
+
// FIXME: get_rows can use additional threads, but the cost of launching additional threads
|
|
2172
|
+
// decreases performance with GPU offloading
|
|
2173
|
+
//n_tasks = n_threads;
|
|
2174
|
+
n_tasks = 1;
|
|
2175
|
+
} break;
|
|
2176
|
+
case WSP_GGML_OP_SCALE:
|
|
2177
|
+
case WSP_GGML_OP_SET:
|
|
2178
|
+
case WSP_GGML_OP_RESHAPE:
|
|
2179
|
+
case WSP_GGML_OP_VIEW:
|
|
2180
|
+
case WSP_GGML_OP_PERMUTE:
|
|
2181
|
+
case WSP_GGML_OP_TRANSPOSE:
|
|
2182
|
+
case WSP_GGML_OP_GET_ROWS_BACK:
|
|
2183
|
+
case WSP_GGML_OP_DIAG:
|
|
2184
|
+
{
|
|
2185
|
+
n_tasks = 1;
|
|
2186
|
+
} break;
|
|
2187
|
+
case WSP_GGML_OP_DIAG_MASK_ZERO:
|
|
2188
|
+
case WSP_GGML_OP_DIAG_MASK_INF:
|
|
2189
|
+
case WSP_GGML_OP_SOFT_MAX_BACK:
|
|
2190
|
+
case WSP_GGML_OP_ROPE:
|
|
2191
|
+
case WSP_GGML_OP_ROPE_BACK:
|
|
2192
|
+
case WSP_GGML_OP_ADD_REL_POS:
|
|
2193
|
+
{
|
|
2194
|
+
n_tasks = n_threads;
|
|
2195
|
+
} break;
|
|
2196
|
+
case WSP_GGML_OP_CLAMP:
|
|
2197
|
+
{
|
|
2198
|
+
n_tasks = 1; //TODO
|
|
2199
|
+
} break;
|
|
2200
|
+
case WSP_GGML_OP_SOFT_MAX:
|
|
2201
|
+
{
|
|
2202
|
+
n_tasks = MIN(n_threads, wsp_ggml_nrows(node->src[0]));
|
|
2203
|
+
} break;
|
|
2204
|
+
case WSP_GGML_OP_IM2COL:
|
|
2205
|
+
case WSP_GGML_OP_IM2COL_BACK:
|
|
2206
|
+
case WSP_GGML_OP_CONV_2D_DW:
|
|
2207
|
+
case WSP_GGML_OP_CONV_TRANSPOSE_1D:
|
|
2208
|
+
case WSP_GGML_OP_CONV_TRANSPOSE_2D:
|
|
2209
|
+
{
|
|
2210
|
+
n_tasks = n_threads;
|
|
2211
|
+
} break;
|
|
2212
|
+
case WSP_GGML_OP_POOL_1D:
|
|
2213
|
+
case WSP_GGML_OP_POOL_2D:
|
|
2214
|
+
case WSP_GGML_OP_POOL_2D_BACK:
|
|
2215
|
+
{
|
|
2216
|
+
n_tasks = 1;
|
|
2217
|
+
} break;
|
|
2218
|
+
case WSP_GGML_OP_UPSCALE:
|
|
2219
|
+
case WSP_GGML_OP_PAD:
|
|
2220
|
+
case WSP_GGML_OP_PAD_REFLECT_1D:
|
|
2221
|
+
case WSP_GGML_OP_ROLL:
|
|
2222
|
+
case WSP_GGML_OP_ARANGE:
|
|
2223
|
+
case WSP_GGML_OP_TIMESTEP_EMBEDDING:
|
|
2224
|
+
case WSP_GGML_OP_ARGSORT:
|
|
2225
|
+
case WSP_GGML_OP_FLASH_ATTN_EXT:
|
|
2226
|
+
case WSP_GGML_OP_FLASH_ATTN_BACK:
|
|
2227
|
+
case WSP_GGML_OP_SSM_CONV:
|
|
2228
|
+
case WSP_GGML_OP_SSM_SCAN:
|
|
2229
|
+
case WSP_GGML_OP_RWKV_WKV6:
|
|
2230
|
+
case WSP_GGML_OP_GATED_LINEAR_ATTN:
|
|
2231
|
+
case WSP_GGML_OP_RWKV_WKV7:
|
|
2232
|
+
{
|
|
2233
|
+
n_tasks = n_threads;
|
|
2234
|
+
} break;
|
|
2235
|
+
case WSP_GGML_OP_WIN_PART:
|
|
2236
|
+
case WSP_GGML_OP_WIN_UNPART:
|
|
2237
|
+
case WSP_GGML_OP_GET_REL_POS:
|
|
2238
|
+
{
|
|
2239
|
+
n_tasks = 1;
|
|
2240
|
+
} break;
|
|
2241
|
+
case WSP_GGML_OP_MAP_CUSTOM1:
|
|
2242
|
+
{
|
|
2243
|
+
struct wsp_ggml_map_custom1_op_params p;
|
|
2244
|
+
memcpy(&p, node->op_params, sizeof(p));
|
|
2245
|
+
if (p.n_tasks == WSP_GGML_N_TASKS_MAX) {
|
|
2246
|
+
n_tasks = n_threads;
|
|
2247
|
+
} else {
|
|
2248
|
+
n_tasks = MIN(p.n_tasks, n_threads);
|
|
2249
|
+
}
|
|
2250
|
+
} break;
|
|
2251
|
+
case WSP_GGML_OP_MAP_CUSTOM2:
|
|
2252
|
+
{
|
|
2253
|
+
struct wsp_ggml_map_custom2_op_params p;
|
|
2254
|
+
memcpy(&p, node->op_params, sizeof(p));
|
|
2255
|
+
if (p.n_tasks == WSP_GGML_N_TASKS_MAX) {
|
|
2256
|
+
n_tasks = n_threads;
|
|
2257
|
+
} else {
|
|
2258
|
+
n_tasks = MIN(p.n_tasks, n_threads);
|
|
2259
|
+
}
|
|
2260
|
+
} break;
|
|
2261
|
+
case WSP_GGML_OP_MAP_CUSTOM3:
|
|
2262
|
+
{
|
|
2263
|
+
struct wsp_ggml_map_custom3_op_params p;
|
|
2264
|
+
memcpy(&p, node->op_params, sizeof(p));
|
|
2265
|
+
if (p.n_tasks == WSP_GGML_N_TASKS_MAX) {
|
|
2266
|
+
n_tasks = n_threads;
|
|
2267
|
+
} else {
|
|
2268
|
+
n_tasks = MIN(p.n_tasks, n_threads);
|
|
2269
|
+
}
|
|
2270
|
+
} break;
|
|
2271
|
+
case WSP_GGML_OP_CUSTOM:
|
|
2272
|
+
{
|
|
2273
|
+
struct wsp_ggml_custom_op_params p;
|
|
2274
|
+
memcpy(&p, node->op_params, sizeof(p));
|
|
2275
|
+
if (p.n_tasks == WSP_GGML_N_TASKS_MAX) {
|
|
2276
|
+
n_tasks = n_threads;
|
|
2277
|
+
} else {
|
|
2278
|
+
n_tasks = MIN(p.n_tasks, n_threads);
|
|
2279
|
+
}
|
|
2280
|
+
} break;
|
|
2281
|
+
case WSP_GGML_OP_CROSS_ENTROPY_LOSS:
|
|
2282
|
+
case WSP_GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
|
2283
|
+
case WSP_GGML_OP_OPT_STEP_ADAMW:
|
|
2284
|
+
{
|
|
2285
|
+
n_tasks = n_threads;
|
|
2286
|
+
} break;
|
|
2287
|
+
case WSP_GGML_OP_NONE:
|
|
2288
|
+
{
|
|
2289
|
+
n_tasks = 1;
|
|
2290
|
+
} break;
|
|
2291
|
+
case WSP_GGML_OP_COUNT:
|
|
2292
|
+
{
|
|
2293
|
+
WSP_GGML_ABORT("fatal error");
|
|
2294
|
+
}
|
|
2295
|
+
default:
|
|
2296
|
+
{
|
|
2297
|
+
fprintf(stderr, "%s: op not implemented: ", __func__);
|
|
2298
|
+
if (node->op < WSP_GGML_OP_COUNT) {
|
|
2299
|
+
fprintf(stderr, "%s\n", wsp_ggml_op_name(node->op));
|
|
2300
|
+
} else {
|
|
2301
|
+
fprintf(stderr, "%d\n", node->op);
|
|
2302
|
+
}
|
|
2303
|
+
WSP_GGML_ABORT("fatal error");
|
|
2304
|
+
}
|
|
2305
|
+
}
|
|
2306
|
+
|
|
2307
|
+
assert(n_tasks > 0);
|
|
2308
|
+
|
|
2309
|
+
return n_tasks;
|
|
2310
|
+
}
|
|
2311
|
+
|
|
2312
|
+
static thread_ret_t wsp_ggml_graph_compute_secondary_thread(void* data);
|
|
2313
|
+
|
|
2314
|
+
#if defined(_WIN32)
|
|
2315
|
+
#include "windows.h"
|
|
2316
|
+
|
|
2317
|
+
// TODO: support > 64 CPUs
|
|
2318
|
+
static bool wsp_ggml_thread_apply_affinity(bool * mask) {
|
|
2319
|
+
HANDLE h = GetCurrentThread();
|
|
2320
|
+
uint64_t bitmask = 0ULL;
|
|
2321
|
+
|
|
2322
|
+
assert(WSP_GGML_MAX_N_THREADS >= 64);
|
|
2323
|
+
|
|
2324
|
+
for (int32_t i = 0; i < 8; i++) {
|
|
2325
|
+
int32_t idx = i * 8;
|
|
2326
|
+
uint8_t val = 0;
|
|
2327
|
+
val |= mask[idx + 0] << 0;
|
|
2328
|
+
val |= mask[idx + 1] << 1;
|
|
2329
|
+
val |= mask[idx + 2] << 2;
|
|
2330
|
+
val |= mask[idx + 3] << 3;
|
|
2331
|
+
val |= mask[idx + 4] << 4;
|
|
2332
|
+
val |= mask[idx + 5] << 5;
|
|
2333
|
+
val |= mask[idx + 6] << 6;
|
|
2334
|
+
val |= mask[idx + 7] << 7;
|
|
2335
|
+
bitmask |= (uint64_t)val << idx;
|
|
2336
|
+
}
|
|
2337
|
+
|
|
2338
|
+
for (int32_t i = 64; i < WSP_GGML_MAX_N_THREADS; i++) {
|
|
2339
|
+
if (mask[i]) {
|
|
2340
|
+
fprintf(stderr, "warn: setting thread-affinity for > 64 CPUs isn't supported on windows!\n");
|
|
2341
|
+
break;
|
|
2342
|
+
}
|
|
2343
|
+
}
|
|
2344
|
+
|
|
2345
|
+
DWORD_PTR m = (DWORD_PTR)bitmask;
|
|
2346
|
+
|
|
2347
|
+
m = SetThreadAffinityMask(h, m);
|
|
2348
|
+
|
|
2349
|
+
return m != 0;
|
|
2350
|
+
}
|
|
2351
|
+
|
|
2352
|
+
static bool wsp_ggml_thread_apply_priority(int32_t prio) {
|
|
2353
|
+
// Note that on Windows the Process Priority Class must be updated in order to set Thread priority.
|
|
2354
|
+
// This is up to the applications.
|
|
2355
|
+
DWORD p = THREAD_PRIORITY_NORMAL;
|
|
2356
|
+
switch (prio) {
|
|
2357
|
+
case WSP_GGML_SCHED_PRIO_LOW: p = THREAD_PRIORITY_BELOW_NORMAL; break;
|
|
2358
|
+
case WSP_GGML_SCHED_PRIO_NORMAL: p = THREAD_PRIORITY_NORMAL; break;
|
|
2359
|
+
case WSP_GGML_SCHED_PRIO_MEDIUM: p = THREAD_PRIORITY_ABOVE_NORMAL; break;
|
|
2360
|
+
case WSP_GGML_SCHED_PRIO_HIGH: p = THREAD_PRIORITY_HIGHEST; break;
|
|
2361
|
+
case WSP_GGML_SCHED_PRIO_REALTIME: p = THREAD_PRIORITY_TIME_CRITICAL; break;
|
|
2362
|
+
}
|
|
2363
|
+
|
|
2364
|
+
if (prio != WSP_GGML_SCHED_PRIO_LOW) {
|
|
2365
|
+
// Tell Windows that this thread should not be throttled (needs its own CPU core).
|
|
2366
|
+
// Newer Windows 11 versions aggresively park (offline) CPU cores and often place
|
|
2367
|
+
// all our threads onto the first 4 cores which results in terrible performance with
|
|
2368
|
+
// n_threads > 4
|
|
2369
|
+
#if _WIN32_WINNT >= 0x0602
|
|
2370
|
+
THREAD_POWER_THROTTLING_STATE t;
|
|
2371
|
+
ZeroMemory(&t, sizeof(t));
|
|
2372
|
+
t.Version = THREAD_POWER_THROTTLING_CURRENT_VERSION;
|
|
2373
|
+
t.ControlMask = THREAD_POWER_THROTTLING_EXECUTION_SPEED;
|
|
2374
|
+
t.StateMask = 0;
|
|
2375
|
+
|
|
2376
|
+
if (!SetThreadInformation(GetCurrentThread(), ThreadPowerThrottling, &t, sizeof(t))) {
|
|
2377
|
+
WSP_GGML_LOG_DEBUG("failed to disable thread power throttling %d : (%d)\n", prio, (int) GetLastError());
|
|
2378
|
+
return false;
|
|
2379
|
+
}
|
|
2380
|
+
#endif
|
|
2381
|
+
}
|
|
2382
|
+
|
|
2383
|
+
if (prio == WSP_GGML_SCHED_PRIO_NORMAL) {
|
|
2384
|
+
// Keep inherited policy/priority
|
|
2385
|
+
return true;
|
|
2386
|
+
}
|
|
2387
|
+
|
|
2388
|
+
if (!SetThreadPriority(GetCurrentThread(), p)) {
|
|
2389
|
+
fprintf(stderr, "warn: failed to set thread priority %d : (%d)\n", prio, (int) GetLastError());
|
|
2390
|
+
return false;
|
|
2391
|
+
}
|
|
2392
|
+
|
|
2393
|
+
return true;
|
|
2394
|
+
}
|
|
2395
|
+
|
|
2396
|
+
#elif defined(__APPLE__)
|
|
2397
|
+
#include <sys/types.h>
|
|
2398
|
+
#include <sys/resource.h>
|
|
2399
|
+
|
|
2400
|
+
// Thread affinity is not supported on Apple platforms: the mask is ignored
// and success is always reported.
static bool wsp_ggml_thread_apply_affinity(const bool * mask) {
    UNUSED(mask);

    return true;
}
|
|
2405
|
+
|
|
2406
|
+
static bool wsp_ggml_thread_apply_priority(int32_t prio) {
|
|
2407
|
+
struct sched_param p;
|
|
2408
|
+
int32_t policy = SCHED_OTHER;
|
|
2409
|
+
switch (prio) {
|
|
2410
|
+
// TODO: there seems to be no way to set lower prio on Apple platforms
|
|
2411
|
+
case WSP_GGML_SCHED_PRIO_LOW: policy = SCHED_OTHER; p.sched_priority = 0; break;
|
|
2412
|
+
case WSP_GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break;
|
|
2413
|
+
case WSP_GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break;
|
|
2414
|
+
case WSP_GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break;
|
|
2415
|
+
case WSP_GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO; p.sched_priority = 90; break;
|
|
2416
|
+
}
|
|
2417
|
+
|
|
2418
|
+
if (prio == WSP_GGML_SCHED_PRIO_NORMAL) {
|
|
2419
|
+
// Keep inherited policy/priority
|
|
2420
|
+
return true;
|
|
2421
|
+
}
|
|
2422
|
+
|
|
2423
|
+
int32_t err = pthread_setschedparam(pthread_self(), policy, &p);
|
|
2424
|
+
if (err != 0) {
|
|
2425
|
+
fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err);
|
|
2426
|
+
return false;
|
|
2427
|
+
}
|
|
2428
|
+
|
|
2429
|
+
return true;
|
|
2430
|
+
}
|
|
2431
|
+
|
|
2432
|
+
#elif defined(__gnu_linux__)
|
|
2433
|
+
// TODO: this may not work on BSD, to be verified
|
|
2434
|
+
|
|
2435
|
+
static bool wsp_ggml_thread_apply_affinity(const bool * mask) {
|
|
2436
|
+
cpu_set_t cpuset;
|
|
2437
|
+
int err;
|
|
2438
|
+
|
|
2439
|
+
CPU_ZERO(&cpuset);
|
|
2440
|
+
|
|
2441
|
+
for (uint32_t i = 0; i < WSP_GGML_MAX_N_THREADS; i++) {
|
|
2442
|
+
if (mask[i]) {
|
|
2443
|
+
WSP_GGML_PRINT_DEBUG("Thread %lx: adding %d to cpuset\n", pthread_self(), i);
|
|
2444
|
+
CPU_SET(i, &cpuset);
|
|
2445
|
+
}
|
|
2446
|
+
}
|
|
2447
|
+
|
|
2448
|
+
#ifdef __ANDROID__
|
|
2449
|
+
err = sched_setaffinity(0, sizeof(cpuset), &cpuset);
|
|
2450
|
+
if (err < 0) {
|
|
2451
|
+
err = errno;
|
|
2452
|
+
}
|
|
2453
|
+
#else
|
|
2454
|
+
err = pthread_setaffinity_np(pthread_self(), sizeof(cpuset), &cpuset);
|
|
2455
|
+
#endif
|
|
2456
|
+
if (err != 0) {
|
|
2457
|
+
fprintf(stderr, "warn: failed to set affinity mask 0x%llx : %s (%d)\n", (unsigned long long)mask, strerror(err), err);
|
|
2458
|
+
return false;
|
|
2459
|
+
}
|
|
2460
|
+
|
|
2461
|
+
return true;
|
|
2462
|
+
}
|
|
2463
|
+
|
|
2464
|
+
static bool wsp_ggml_thread_apply_priority(int32_t prio) {
|
|
2465
|
+
struct sched_param p;
|
|
2466
|
+
int32_t policy = SCHED_OTHER;
|
|
2467
|
+
switch (prio) {
|
|
2468
|
+
case WSP_GGML_SCHED_PRIO_LOW: policy = SCHED_BATCH; p.sched_priority = 0; break;
|
|
2469
|
+
case WSP_GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break;
|
|
2470
|
+
case WSP_GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break;
|
|
2471
|
+
case WSP_GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break;
|
|
2472
|
+
case WSP_GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO; p.sched_priority = 90; break;
|
|
2473
|
+
}
|
|
2474
|
+
|
|
2475
|
+
if (prio == WSP_GGML_SCHED_PRIO_NORMAL) {
|
|
2476
|
+
// Keep inherited policy/priority
|
|
2477
|
+
return true;
|
|
2478
|
+
}
|
|
2479
|
+
|
|
2480
|
+
int32_t err = pthread_setschedparam(pthread_self(), policy, &p);
|
|
2481
|
+
if (err != 0) {
|
|
2482
|
+
fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err);
|
|
2483
|
+
return false;
|
|
2484
|
+
}
|
|
2485
|
+
|
|
2486
|
+
return true;
|
|
2487
|
+
}
|
|
2488
|
+
|
|
2489
|
+
#else // unsupported platforms
|
|
2490
|
+
|
|
2491
|
+
// Fallback for platforms without an affinity implementation:
// the mask is ignored and success is always reported.
static bool wsp_ggml_thread_apply_affinity(const bool * mask) {
    UNUSED(mask);

    return true;
}
|
|
2495
|
+
|
|
2496
|
+
// Fallback for platforms without a priority implementation:
// the request is ignored and success is always reported.
static bool wsp_ggml_thread_apply_priority(int32_t prio) {
    UNUSED(prio);

    return true;
}
|
|
2500
|
+
|
|
2501
|
+
#endif
|
|
2502
|
+
|
|
2503
|
+
static bool wsp_ggml_thread_cpumask_is_valid(const bool * mask) {
|
|
2504
|
+
for (int i = 0; i < WSP_GGML_MAX_N_THREADS; i++) {
|
|
2505
|
+
if (mask[i]) { return true; }
|
|
2506
|
+
}
|
|
2507
|
+
return false;
|
|
2508
|
+
}
|
|
2509
|
+
|
|
2510
|
+
static void wsp_ggml_thread_cpumask_next(const bool * global_mask, bool * local_mask, bool strict, int32_t* iter) {
|
|
2511
|
+
if (!strict) {
|
|
2512
|
+
memcpy(local_mask, global_mask, WSP_GGML_MAX_N_THREADS);
|
|
2513
|
+
return;
|
|
2514
|
+
} else {
|
|
2515
|
+
memset(local_mask, 0, WSP_GGML_MAX_N_THREADS);
|
|
2516
|
+
int32_t base_idx = *iter;
|
|
2517
|
+
for (int32_t i = 0; i < WSP_GGML_MAX_N_THREADS; i++) {
|
|
2518
|
+
int32_t idx = base_idx + i;
|
|
2519
|
+
if (idx >= WSP_GGML_MAX_N_THREADS) {
|
|
2520
|
+
// Just a cheaper modulo
|
|
2521
|
+
idx -= WSP_GGML_MAX_N_THREADS;
|
|
2522
|
+
}
|
|
2523
|
+
if (global_mask[idx]) {
|
|
2524
|
+
local_mask[idx] = 1;
|
|
2525
|
+
*iter = idx + 1;
|
|
2526
|
+
return;
|
|
2527
|
+
}
|
|
2528
|
+
}
|
|
2529
|
+
}
|
|
2530
|
+
}
|
|
2531
|
+
|
|
2532
|
+
void wsp_ggml_threadpool_free(struct wsp_ggml_threadpool* threadpool) {
|
|
2533
|
+
if (!threadpool) return;
|
|
2534
|
+
|
|
2535
|
+
const int n_threads = threadpool->n_threads_max;
|
|
2536
|
+
|
|
2537
|
+
#ifndef WSP_GGML_USE_OPENMP
|
|
2538
|
+
struct wsp_ggml_compute_state* workers = threadpool->workers;
|
|
2539
|
+
|
|
2540
|
+
wsp_ggml_mutex_lock(&threadpool->mutex);
|
|
2541
|
+
|
|
2542
|
+
threadpool->stop = true;
|
|
2543
|
+
threadpool->pause = false;
|
|
2544
|
+
|
|
2545
|
+
wsp_ggml_cond_broadcast(&threadpool->cond);
|
|
2546
|
+
wsp_ggml_mutex_unlock(&threadpool->mutex);
|
|
2547
|
+
|
|
2548
|
+
for (int j = 1; j < n_threads; j++) {
|
|
2549
|
+
int32_t rc = wsp_ggml_thread_join(workers[j].thrd, NULL);
|
|
2550
|
+
WSP_GGML_ASSERT(rc == WSP_GGML_EXIT_SUCCESS || rc == WSP_GGML_EXIT_ABORTED);
|
|
2551
|
+
UNUSED(rc);
|
|
2552
|
+
}
|
|
2553
|
+
|
|
2554
|
+
wsp_ggml_mutex_destroy(&threadpool->mutex);
|
|
2555
|
+
wsp_ggml_cond_destroy(&threadpool->cond);
|
|
2556
|
+
#endif // WSP_GGML_USE_OPENMP
|
|
2557
|
+
|
|
2558
|
+
const size_t workers_size = sizeof(struct wsp_ggml_compute_state) * n_threads;
|
|
2559
|
+
wsp_ggml_aligned_free(threadpool->workers, workers_size);
|
|
2560
|
+
wsp_ggml_aligned_free(threadpool, sizeof(struct wsp_ggml_threadpool));
|
|
2561
|
+
}
|
|
2562
|
+
|
|
2563
|
+
#ifndef WSP_GGML_USE_OPENMP
|
|
2564
|
+
// pause/resume must be called under mutex
|
|
2565
|
+
static void wsp_ggml_threadpool_pause_locked(struct wsp_ggml_threadpool * threadpool) {
|
|
2566
|
+
WSP_GGML_PRINT_DEBUG("Pausing threadpool\n");
|
|
2567
|
+
threadpool->pause = true;
|
|
2568
|
+
wsp_ggml_cond_broadcast(&threadpool->cond);
|
|
2569
|
+
}
|
|
2570
|
+
|
|
2571
|
+
static void wsp_ggml_threadpool_resume_locked(struct wsp_ggml_threadpool * threadpool) {
|
|
2572
|
+
WSP_GGML_PRINT_DEBUG("Resuming threadpool\n");
|
|
2573
|
+
threadpool->pause = false;
|
|
2574
|
+
wsp_ggml_cond_broadcast(&threadpool->cond);
|
|
2575
|
+
}
|
|
2576
|
+
#endif
|
|
2577
|
+
|
|
2578
|
+
void wsp_ggml_threadpool_pause(struct wsp_ggml_threadpool * threadpool) {
|
|
2579
|
+
#ifndef WSP_GGML_USE_OPENMP
|
|
2580
|
+
wsp_ggml_mutex_lock(&threadpool->mutex);
|
|
2581
|
+
if (!threadpool->pause) {
|
|
2582
|
+
wsp_ggml_threadpool_pause_locked(threadpool);
|
|
2583
|
+
}
|
|
2584
|
+
wsp_ggml_mutex_unlock(&threadpool->mutex);
|
|
2585
|
+
#else
|
|
2586
|
+
UNUSED(threadpool);
|
|
2587
|
+
#endif
|
|
2588
|
+
}
|
|
2589
|
+
|
|
2590
|
+
void wsp_ggml_threadpool_resume(struct wsp_ggml_threadpool * threadpool) {
|
|
2591
|
+
#ifndef WSP_GGML_USE_OPENMP
|
|
2592
|
+
wsp_ggml_mutex_lock(&threadpool->mutex);
|
|
2593
|
+
if (threadpool->pause) {
|
|
2594
|
+
wsp_ggml_threadpool_resume_locked(threadpool);
|
|
2595
|
+
}
|
|
2596
|
+
wsp_ggml_mutex_unlock(&threadpool->mutex);
|
|
2597
|
+
#else
|
|
2598
|
+
UNUSED(threadpool);
|
|
2599
|
+
#endif
|
|
2600
|
+
}
|
|
2601
|
+
|
|
2602
|
+
struct wsp_ggml_cplan wsp_ggml_graph_plan(
|
|
2603
|
+
const struct wsp_ggml_cgraph * cgraph,
|
|
2604
|
+
int n_threads,
|
|
2605
|
+
struct wsp_ggml_threadpool * threadpool) {
|
|
2606
|
+
|
|
2607
|
+
if (threadpool == NULL) {
|
|
2608
|
+
//WSP_GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads);
|
|
2609
|
+
}
|
|
2610
|
+
if (n_threads <= 0) {
|
|
2611
|
+
n_threads = threadpool ? threadpool->n_threads_max : WSP_GGML_DEFAULT_N_THREADS;
|
|
2612
|
+
}
|
|
2613
|
+
|
|
2614
|
+
size_t work_size = 0;
|
|
2615
|
+
|
|
2616
|
+
struct wsp_ggml_cplan cplan;
|
|
2617
|
+
memset(&cplan, 0, sizeof(struct wsp_ggml_cplan));
|
|
2618
|
+
|
|
2619
|
+
int max_tasks = 1;
|
|
2620
|
+
|
|
2621
|
+
// thread scheduling for the different operations + work buffer size estimation
|
|
2622
|
+
for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
2623
|
+
struct wsp_ggml_tensor * node = cgraph->nodes[i];
|
|
2624
|
+
|
|
2625
|
+
const int n_tasks = wsp_ggml_get_n_tasks(node, n_threads);
|
|
2626
|
+
|
|
2627
|
+
max_tasks = MAX(max_tasks, n_tasks);
|
|
2628
|
+
|
|
2629
|
+
size_t cur = 0;
|
|
2630
|
+
|
|
2631
|
+
if (!wsp_ggml_cpu_extra_work_size(n_threads, node, &cur)) {
|
|
2632
|
+
switch (node->op) {
|
|
2633
|
+
case WSP_GGML_OP_CPY:
|
|
2634
|
+
case WSP_GGML_OP_DUP:
|
|
2635
|
+
{
|
|
2636
|
+
if (wsp_ggml_is_quantized(node->type) ||
|
|
2637
|
+
// F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
|
|
2638
|
+
(node->src[0]->type == WSP_GGML_TYPE_F16 && node->src[1] && node->src[1]->type == WSP_GGML_TYPE_BF16) ||
|
|
2639
|
+
(node->src[0]->type == WSP_GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == WSP_GGML_TYPE_F16)) {
|
|
2640
|
+
cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->ne[0] * n_tasks;
|
|
2641
|
+
}
|
|
2642
|
+
} break;
|
|
2643
|
+
case WSP_GGML_OP_ADD:
|
|
2644
|
+
case WSP_GGML_OP_ADD1:
|
|
2645
|
+
{
|
|
2646
|
+
if (wsp_ggml_is_quantized(node->src[0]->type)) {
|
|
2647
|
+
cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
|
|
2648
|
+
}
|
|
2649
|
+
} break;
|
|
2650
|
+
case WSP_GGML_OP_ACC:
|
|
2651
|
+
{
|
|
2652
|
+
if (wsp_ggml_is_quantized(node->src[0]->type)) {
|
|
2653
|
+
cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
|
|
2654
|
+
}
|
|
2655
|
+
} break;
|
|
2656
|
+
case WSP_GGML_OP_COUNT_EQUAL:
|
|
2657
|
+
{
|
|
2658
|
+
cur = wsp_ggml_type_size(node->type)*n_tasks;
|
|
2659
|
+
} break;
|
|
2660
|
+
case WSP_GGML_OP_MUL_MAT:
|
|
2661
|
+
{
|
|
2662
|
+
const enum wsp_ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type;
|
|
2663
|
+
|
|
2664
|
+
if (node->src[1]->type != vec_dot_type) {
|
|
2665
|
+
cur = wsp_ggml_row_size(vec_dot_type, wsp_ggml_nelements(node->src[1]));
|
|
2666
|
+
}
|
|
2667
|
+
} break;
|
|
2668
|
+
case WSP_GGML_OP_MUL_MAT_ID:
|
|
2669
|
+
{
|
|
2670
|
+
cur = 0;
|
|
2671
|
+
const struct wsp_ggml_tensor * src0 = node->src[0];
|
|
2672
|
+
const struct wsp_ggml_tensor * src1 = node->src[1];
|
|
2673
|
+
const struct wsp_ggml_tensor * ids = node->src[2];
|
|
2674
|
+
const enum wsp_ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
|
|
2675
|
+
const int n_as = src0->ne[2];
|
|
2676
|
+
// src1
|
|
2677
|
+
if (src1->type != vec_dot_type) {
|
|
2678
|
+
cur += wsp_ggml_row_size(vec_dot_type, wsp_ggml_nelements(src1)) + sizeof(int64_t);
|
|
2679
|
+
}
|
|
2680
|
+
// matrix_row_counts
|
|
2681
|
+
cur += n_as * sizeof(int64_t) + sizeof(int64_t);
|
|
2682
|
+
// matrix_rows
|
|
2683
|
+
cur += n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping) + sizeof(int64_t);
|
|
2684
|
+
// atomic_current_chunk
|
|
2685
|
+
cur += CACHE_LINE_SIZE*n_as + CACHE_LINE_SIZE;
|
|
2686
|
+
} break;
|
|
2687
|
+
case WSP_GGML_OP_OUT_PROD:
|
|
2688
|
+
{
|
|
2689
|
+
if (wsp_ggml_is_quantized(node->src[0]->type)) {
|
|
2690
|
+
cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
|
|
2691
|
+
}
|
|
2692
|
+
} break;
|
|
2693
|
+
case WSP_GGML_OP_SOFT_MAX:
|
|
2694
|
+
case WSP_GGML_OP_ROPE:
|
|
2695
|
+
case WSP_GGML_OP_ROPE_BACK:
|
|
2696
|
+
{
|
|
2697
|
+
cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->ne[0] * n_tasks;
|
|
2698
|
+
} break;
|
|
2699
|
+
case WSP_GGML_OP_CONV_TRANSPOSE_1D:
|
|
2700
|
+
{
|
|
2701
|
+
WSP_GGML_ASSERT(node->src[0]->ne[3] == 1);
|
|
2702
|
+
WSP_GGML_ASSERT(node->src[1]->ne[2] == 1);
|
|
2703
|
+
WSP_GGML_ASSERT(node->src[1]->ne[3] == 1);
|
|
2704
|
+
|
|
2705
|
+
const int64_t ne00 = node->src[0]->ne[0]; // K
|
|
2706
|
+
const int64_t ne01 = node->src[0]->ne[1]; // Cout
|
|
2707
|
+
const int64_t ne02 = node->src[0]->ne[2]; // Cin
|
|
2708
|
+
const int64_t ne10 = node->src[1]->ne[0]; // L
|
|
2709
|
+
const int64_t ne11 = node->src[1]->ne[1]; // Cin
|
|
2710
|
+
|
|
2711
|
+
if ((node->src[0]->type == WSP_GGML_TYPE_F16 ||
|
|
2712
|
+
node->src[0]->type == WSP_GGML_TYPE_BF16) &&
|
|
2713
|
+
node->src[1]->type == WSP_GGML_TYPE_F32) {
|
|
2714
|
+
cur += sizeof(wsp_ggml_fp16_t)*ne00*ne01*ne02;
|
|
2715
|
+
cur += sizeof(wsp_ggml_fp16_t)*ne10*ne11;
|
|
2716
|
+
} else if (node->src[0]->type == WSP_GGML_TYPE_F32 &&
|
|
2717
|
+
node->src[1]->type == WSP_GGML_TYPE_F32) {
|
|
2718
|
+
cur += sizeof(float)*ne00*ne01*ne02;
|
|
2719
|
+
cur += sizeof(float)*ne10*ne11;
|
|
2720
|
+
} else {
|
|
2721
|
+
WSP_GGML_ABORT("fatal error");
|
|
2722
|
+
}
|
|
2723
|
+
} break;
|
|
2724
|
+
case WSP_GGML_OP_CONV_TRANSPOSE_2D:
|
|
2725
|
+
{
|
|
2726
|
+
const int64_t ne00 = node->src[0]->ne[0]; // W
|
|
2727
|
+
const int64_t ne01 = node->src[0]->ne[1]; // H
|
|
2728
|
+
const int64_t ne02 = node->src[0]->ne[2]; // Channels Out
|
|
2729
|
+
const int64_t ne03 = node->src[0]->ne[3]; // Channels In
|
|
2730
|
+
|
|
2731
|
+
const int64_t ne10 = node->src[1]->ne[0]; // W
|
|
2732
|
+
const int64_t ne11 = node->src[1]->ne[1]; // H
|
|
2733
|
+
const int64_t ne12 = node->src[1]->ne[2]; // Channels In
|
|
2734
|
+
|
|
2735
|
+
cur += sizeof(wsp_ggml_fp16_t)*ne00*ne01*ne02*ne03;
|
|
2736
|
+
cur += sizeof(wsp_ggml_fp16_t)*ne10*ne11*ne12;
|
|
2737
|
+
} break;
|
|
2738
|
+
case WSP_GGML_OP_FLASH_ATTN_EXT:
|
|
2739
|
+
{
|
|
2740
|
+
const int64_t ne10 = node->src[1]->ne[0]; // DK
|
|
2741
|
+
const int64_t ne20 = node->src[2]->ne[0]; // DV
|
|
2742
|
+
|
|
2743
|
+
cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
|
|
2744
|
+
} break;
|
|
2745
|
+
case WSP_GGML_OP_FLASH_ATTN_BACK:
|
|
2746
|
+
{
|
|
2747
|
+
const int64_t D = node->src[0]->ne[0];
|
|
2748
|
+
const int64_t ne11 = wsp_ggml_up(node->src[1]->ne[1], WSP_GGML_SOFT_MAX_UNROLL);
|
|
2749
|
+
const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in wsp_ggml_compute_forward_flash_attn_back
|
|
2750
|
+
if (node->src[1]->type == WSP_GGML_TYPE_F32) {
|
|
2751
|
+
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
|
|
2752
|
+
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
|
|
2753
|
+
} else if (node->src[1]->type == WSP_GGML_TYPE_F16) {
|
|
2754
|
+
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
|
|
2755
|
+
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
|
|
2756
|
+
} else if (node->src[1]->type == WSP_GGML_TYPE_BF16) {
|
|
2757
|
+
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
|
|
2758
|
+
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
|
|
2759
|
+
}
|
|
2760
|
+
} break;
|
|
2761
|
+
|
|
2762
|
+
case WSP_GGML_OP_CROSS_ENTROPY_LOSS:
|
|
2763
|
+
{
|
|
2764
|
+
cur = wsp_ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
|
|
2765
|
+
} break;
|
|
2766
|
+
case WSP_GGML_OP_COUNT:
|
|
2767
|
+
{
|
|
2768
|
+
WSP_GGML_ABORT("fatal error");
|
|
2769
|
+
}
|
|
2770
|
+
default:
|
|
2771
|
+
break;
|
|
2772
|
+
}
|
|
2773
|
+
}
|
|
2774
|
+
|
|
2775
|
+
work_size = MAX(work_size, cur);
|
|
2776
|
+
}
|
|
2777
|
+
|
|
2778
|
+
if (work_size > 0) {
|
|
2779
|
+
work_size += CACHE_LINE_SIZE*(n_threads);
|
|
2780
|
+
}
|
|
2781
|
+
|
|
2782
|
+
cplan.threadpool = threadpool;
|
|
2783
|
+
cplan.n_threads = MIN(max_tasks, n_threads);
|
|
2784
|
+
cplan.work_size = work_size;
|
|
2785
|
+
cplan.work_data = NULL;
|
|
2786
|
+
|
|
2787
|
+
return cplan;
|
|
2788
|
+
}
|
|
2789
|
+
|
|
2790
|
+
static thread_ret_t wsp_ggml_graph_compute_thread(void * data) {
|
|
2791
|
+
struct wsp_ggml_compute_state * state = (struct wsp_ggml_compute_state *) data;
|
|
2792
|
+
struct wsp_ggml_threadpool * tp = state->threadpool;
|
|
2793
|
+
|
|
2794
|
+
const struct wsp_ggml_cgraph * cgraph = tp->cgraph;
|
|
2795
|
+
const struct wsp_ggml_cplan * cplan = tp->cplan;
|
|
2796
|
+
|
|
2797
|
+
set_numa_thread_affinity(state->ith);
|
|
2798
|
+
|
|
2799
|
+
struct wsp_ggml_compute_params params = {
|
|
2800
|
+
/*.ith =*/ state->ith,
|
|
2801
|
+
/*.nth =*/ atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed),
|
|
2802
|
+
/*.wsize =*/ cplan->work_size,
|
|
2803
|
+
/*.wdata =*/ cplan->work_data,
|
|
2804
|
+
/*.threadpool=*/ tp,
|
|
2805
|
+
};
|
|
2806
|
+
|
|
2807
|
+
for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
|
|
2808
|
+
struct wsp_ggml_tensor * node = cgraph->nodes[node_n];
|
|
2809
|
+
|
|
2810
|
+
wsp_ggml_compute_forward(¶ms, node);
|
|
2811
|
+
|
|
2812
|
+
if (state->ith == 0 && cplan->abort_callback &&
|
|
2813
|
+
cplan->abort_callback(cplan->abort_callback_data)) {
|
|
2814
|
+
atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);
|
|
2815
|
+
tp->ec = WSP_GGML_STATUS_ABORTED;
|
|
2816
|
+
}
|
|
2817
|
+
|
|
2818
|
+
if (node_n + 1 < cgraph->n_nodes) {
|
|
2819
|
+
wsp_ggml_barrier(state->threadpool);
|
|
2820
|
+
}
|
|
2821
|
+
}
|
|
2822
|
+
|
|
2823
|
+
wsp_ggml_barrier(state->threadpool);
|
|
2824
|
+
|
|
2825
|
+
return 0;
|
|
2826
|
+
}
|
|
2827
|
+
|
|
2828
|
+
#ifndef WSP_GGML_USE_OPENMP
|
|
2829
|
+
|
|
2830
|
+
// check if thread is active
|
|
2831
|
+
static inline bool wsp_ggml_graph_compute_thread_active(struct wsp_ggml_compute_state * state) {
|
|
2832
|
+
struct wsp_ggml_threadpool * threadpool = state->threadpool;
|
|
2833
|
+
int n_threads = atomic_load_explicit(&threadpool->n_threads_cur, memory_order_relaxed);
|
|
2834
|
+
return (state->ith < n_threads);
|
|
2835
|
+
}
|
|
2836
|
+
|
|
2837
|
+
// check if thread is ready to proceed (exit from polling or sleeping)
|
|
2838
|
+
static inline bool wsp_ggml_graph_compute_thread_ready(struct wsp_ggml_compute_state * state) {
|
|
2839
|
+
struct wsp_ggml_threadpool * threadpool = state->threadpool;
|
|
2840
|
+
|
|
2841
|
+
if (state->pending || threadpool->stop || threadpool->pause) { return true; }
|
|
2842
|
+
|
|
2843
|
+
// check for new graph/work
|
|
2844
|
+
int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed);
|
|
2845
|
+
if (new_graph != state->last_graph) {
|
|
2846
|
+
state->pending = wsp_ggml_graph_compute_thread_active(state);
|
|
2847
|
+
state->last_graph = new_graph;
|
|
2848
|
+
}
|
|
2849
|
+
|
|
2850
|
+
return state->pending;
|
|
2851
|
+
}
|
|
2852
|
+
|
|
2853
|
+
// sync thread state after polling
|
|
2854
|
+
// Issue a sequentially-consistent synchronization point after polling.
// TSAN doesn't support standalone fences yet, so TSAN builds perform a dummy
// read-modify-write on the graph counter instead.
static inline void wsp_ggml_graph_compute_thread_sync(struct wsp_ggml_compute_state * state) {
#ifndef WSP_GGML_TSAN_ENABLED
    atomic_thread_fence(memory_order_seq_cst);
#else
    atomic_fetch_add_explicit(&state->threadpool->n_graph, 0, memory_order_seq_cst);
#endif
    UNUSED(state);
}
|
|
2863
|
+
|
|
2864
|
+
static inline bool wsp_ggml_graph_compute_poll_for_work(struct wsp_ggml_compute_state * state) {
|
|
2865
|
+
struct wsp_ggml_threadpool * threadpool = state->threadpool;
|
|
2866
|
+
|
|
2867
|
+
// Skip polling for unused threads
|
|
2868
|
+
if (!wsp_ggml_graph_compute_thread_active(state)) {
|
|
2869
|
+
return state->pending;
|
|
2870
|
+
}
|
|
2871
|
+
|
|
2872
|
+
// This seems to make 0 ... 100 a decent range for polling level across modern processors.
|
|
2873
|
+
// Perhaps, we can adjust it dynamically based on load and things.
|
|
2874
|
+
const uint64_t n_rounds = 1024UL * 128 * threadpool->poll;
|
|
2875
|
+
|
|
2876
|
+
for (uint64_t i=0; !wsp_ggml_graph_compute_thread_ready(state) && i < n_rounds; i++) {
|
|
2877
|
+
// No new work. Keep polling.
|
|
2878
|
+
wsp_ggml_thread_cpu_relax();
|
|
2879
|
+
}
|
|
2880
|
+
|
|
2881
|
+
return state->pending;
|
|
2882
|
+
}
|
|
2883
|
+
|
|
2884
|
+
static inline bool wsp_ggml_graph_compute_check_for_work(struct wsp_ggml_compute_state * state) {
|
|
2885
|
+
struct wsp_ggml_threadpool * threadpool = state->threadpool;
|
|
2886
|
+
|
|
2887
|
+
if (wsp_ggml_graph_compute_poll_for_work(state)) {
|
|
2888
|
+
wsp_ggml_graph_compute_thread_sync(state);
|
|
2889
|
+
return state->pending;
|
|
2890
|
+
}
|
|
2891
|
+
|
|
2892
|
+
wsp_ggml_mutex_lock_shared(&threadpool->mutex);
|
|
2893
|
+
while (!wsp_ggml_graph_compute_thread_ready(state)) {
|
|
2894
|
+
// No new work. Wait for the signal.
|
|
2895
|
+
WSP_GGML_PRINT_DEBUG("thread #%d waiting for work (sleeping)\n", state->ith);
|
|
2896
|
+
wsp_ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
|
|
2897
|
+
}
|
|
2898
|
+
wsp_ggml_mutex_unlock_shared(&threadpool->mutex);
|
|
2899
|
+
|
|
2900
|
+
return state->pending;
|
|
2901
|
+
}
|
|
2902
|
+
|
|
2903
|
+
static thread_ret_t wsp_ggml_graph_compute_secondary_thread(void* data) {
|
|
2904
|
+
struct wsp_ggml_compute_state * state = (struct wsp_ggml_compute_state *) data;
|
|
2905
|
+
struct wsp_ggml_threadpool * threadpool = state->threadpool;
|
|
2906
|
+
|
|
2907
|
+
wsp_ggml_thread_apply_priority(threadpool->prio);
|
|
2908
|
+
if (wsp_ggml_thread_cpumask_is_valid(state->cpumask)) {
|
|
2909
|
+
wsp_ggml_thread_apply_affinity(state->cpumask);
|
|
2910
|
+
}
|
|
2911
|
+
|
|
2912
|
+
while (true) {
|
|
2913
|
+
// Check if we need to sleep
|
|
2914
|
+
while (threadpool->pause) {
|
|
2915
|
+
WSP_GGML_PRINT_DEBUG("thread #%d inside pause loop\n", state->ith);
|
|
2916
|
+
wsp_ggml_mutex_lock_shared(&threadpool->mutex);
|
|
2917
|
+
if (threadpool->pause) {
|
|
2918
|
+
wsp_ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
|
|
2919
|
+
}
|
|
2920
|
+
WSP_GGML_PRINT_DEBUG("thread #%d resuming after wait\n", state->ith);
|
|
2921
|
+
wsp_ggml_mutex_unlock_shared(&threadpool->mutex);
|
|
2922
|
+
}
|
|
2923
|
+
|
|
2924
|
+
// This needs to be checked for after the cond_wait
|
|
2925
|
+
if (threadpool->stop) break;
|
|
2926
|
+
|
|
2927
|
+
// Check if there is new work
|
|
2928
|
+
// The main thread is the only one that can dispatch new work
|
|
2929
|
+
|
|
2930
|
+
wsp_ggml_graph_compute_check_for_work(state);
|
|
2931
|
+
if (state->pending) {
|
|
2932
|
+
state->pending = false;
|
|
2933
|
+
|
|
2934
|
+
wsp_ggml_graph_compute_thread(state);
|
|
2935
|
+
}
|
|
2936
|
+
}
|
|
2937
|
+
|
|
2938
|
+
return (thread_ret_t) 0;
|
|
2939
|
+
}
|
|
2940
|
+
|
|
2941
|
+
// Start processing new graph
|
|
2942
|
+
static void wsp_ggml_graph_compute_kickoff(struct wsp_ggml_threadpool * threadpool, int n_threads)
|
|
2943
|
+
{
|
|
2944
|
+
// Always take the mutex here because the worker threads are doing hybrid poll/wait
|
|
2945
|
+
|
|
2946
|
+
wsp_ggml_mutex_lock(&threadpool->mutex);
|
|
2947
|
+
|
|
2948
|
+
WSP_GGML_PRINT_DEBUG("threadpool: n_threads_cur %d n_threads %d\n", threadpool->n_threads_cur, n_threads);
|
|
2949
|
+
|
|
2950
|
+
// Update the number of active threads
|
|
2951
|
+
atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed);
|
|
2952
|
+
|
|
2953
|
+
// Indicate the graph is ready to be processed
|
|
2954
|
+
// We need the full seq-cst fence here because of the polling threads (used in thread_sync)
|
|
2955
|
+
atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_seq_cst);
|
|
2956
|
+
|
|
2957
|
+
if (threadpool->pause) {
|
|
2958
|
+
// Update main thread prio and affinity to match the threadpool settings
|
|
2959
|
+
wsp_ggml_thread_apply_priority(threadpool->prio);
|
|
2960
|
+
if (wsp_ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) {
|
|
2961
|
+
wsp_ggml_thread_apply_affinity(threadpool->workers[0].cpumask);
|
|
2962
|
+
}
|
|
2963
|
+
|
|
2964
|
+
// resume does cond broadcast
|
|
2965
|
+
wsp_ggml_threadpool_resume_locked(threadpool);
|
|
2966
|
+
} else {
|
|
2967
|
+
wsp_ggml_cond_broadcast(&threadpool->cond);
|
|
2968
|
+
}
|
|
2969
|
+
|
|
2970
|
+
wsp_ggml_mutex_unlock(&threadpool->mutex);
|
|
2971
|
+
}
|
|
2972
|
+
|
|
2973
|
+
#endif // WSP_GGML_USE_OPENMP
|
|
2974
|
+
|
|
2975
|
+
static struct wsp_ggml_threadpool * wsp_ggml_threadpool_new_impl(
|
|
2976
|
+
struct wsp_ggml_threadpool_params * tpp,
|
|
2977
|
+
struct wsp_ggml_cgraph * cgraph,
|
|
2978
|
+
struct wsp_ggml_cplan * cplan) {
|
|
2979
|
+
|
|
2980
|
+
struct wsp_ggml_threadpool * threadpool =
|
|
2981
|
+
wsp_ggml_aligned_malloc(sizeof(struct wsp_ggml_threadpool));
|
|
2982
|
+
{
|
|
2983
|
+
threadpool->cgraph = cgraph;
|
|
2984
|
+
threadpool->cplan = cplan;
|
|
2985
|
+
threadpool->n_graph = 0;
|
|
2986
|
+
threadpool->n_barrier = 0;
|
|
2987
|
+
threadpool->n_barrier_passed = 0;
|
|
2988
|
+
threadpool->current_chunk = 0;
|
|
2989
|
+
threadpool->stop = false;
|
|
2990
|
+
threadpool->pause = tpp->paused;
|
|
2991
|
+
threadpool->abort = -1;
|
|
2992
|
+
threadpool->workers = NULL;
|
|
2993
|
+
threadpool->n_threads_max = tpp->n_threads;
|
|
2994
|
+
threadpool->n_threads_cur = tpp->n_threads;
|
|
2995
|
+
threadpool->poll = tpp->poll;
|
|
2996
|
+
threadpool->prio = tpp->prio;
|
|
2997
|
+
threadpool->ec = WSP_GGML_STATUS_SUCCESS;
|
|
2998
|
+
}
|
|
2999
|
+
|
|
3000
|
+
// Allocate and init workers state
|
|
3001
|
+
const size_t workers_size = sizeof(struct wsp_ggml_compute_state) * tpp->n_threads;
|
|
3002
|
+
struct wsp_ggml_compute_state * workers = wsp_ggml_aligned_malloc(workers_size);
|
|
3003
|
+
|
|
3004
|
+
memset(workers, 0, workers_size);
|
|
3005
|
+
for (int j = 0; j < tpp->n_threads; j++) {
|
|
3006
|
+
workers[j].threadpool = threadpool;
|
|
3007
|
+
workers[j].ith = j;
|
|
3008
|
+
}
|
|
3009
|
+
|
|
3010
|
+
threadpool->workers = workers;
|
|
3011
|
+
|
|
3012
|
+
#ifndef WSP_GGML_USE_OPENMP
|
|
3013
|
+
wsp_ggml_mutex_init(&threadpool->mutex);
|
|
3014
|
+
wsp_ggml_cond_init(&threadpool->cond);
|
|
3015
|
+
|
|
3016
|
+
// Spin the threads for all workers, and update CPU placements.
|
|
3017
|
+
// Place the main thread last (towards the higher numbered CPU cores).
|
|
3018
|
+
|
|
3019
|
+
int32_t cpumask_iter = 0;
|
|
3020
|
+
|
|
3021
|
+
for (int j = 1; j < tpp->n_threads; j++) {
|
|
3022
|
+
wsp_ggml_thread_cpumask_next(tpp->cpumask, workers[j].cpumask, tpp->strict_cpu, &cpumask_iter);
|
|
3023
|
+
|
|
3024
|
+
int32_t rc = wsp_ggml_thread_create(&workers[j].thrd, NULL, wsp_ggml_graph_compute_secondary_thread, &workers[j]);
|
|
3025
|
+
WSP_GGML_ASSERT(rc == 0);
|
|
3026
|
+
}
|
|
3027
|
+
|
|
3028
|
+
wsp_ggml_thread_cpumask_next(tpp->cpumask, workers[0].cpumask, tpp->strict_cpu, &cpumask_iter);
|
|
3029
|
+
|
|
3030
|
+
if (!threadpool->pause) {
|
|
3031
|
+
// Update main thread prio and affinity at the start, otherwise we'll do it in resume
|
|
3032
|
+
wsp_ggml_thread_apply_priority(threadpool->prio);
|
|
3033
|
+
if (wsp_ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) {
|
|
3034
|
+
wsp_ggml_thread_apply_affinity(threadpool->workers[0].cpumask);
|
|
3035
|
+
}
|
|
3036
|
+
}
|
|
3037
|
+
#endif // WSP_GGML_USE_OPENMP
|
|
3038
|
+
|
|
3039
|
+
return threadpool;
|
|
3040
|
+
}
|
|
3041
|
+
|
|
3042
|
+
// Creates a threadpool from `tpp` that is not yet bound to any graph or plan.
struct wsp_ggml_threadpool * wsp_ggml_threadpool_new(struct wsp_ggml_threadpool_params * tpp) {
    struct wsp_ggml_threadpool * created = wsp_ggml_threadpool_new_impl(tpp, NULL, NULL);
    return created;
}
|
|
3045
|
+
|
|
3046
|
+
enum wsp_ggml_status wsp_ggml_graph_compute(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_cplan * cplan) {
|
|
3047
|
+
wsp_ggml_cpu_init();
|
|
3048
|
+
|
|
3049
|
+
WSP_GGML_ASSERT(cplan);
|
|
3050
|
+
WSP_GGML_ASSERT(cplan->n_threads > 0);
|
|
3051
|
+
WSP_GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL);
|
|
3052
|
+
|
|
3053
|
+
int n_threads = cplan->n_threads;
|
|
3054
|
+
struct wsp_ggml_threadpool * threadpool = cplan->threadpool;
|
|
3055
|
+
|
|
3056
|
+
bool disposable_threadpool = false;
|
|
3057
|
+
|
|
3058
|
+
if (threadpool == NULL) {
|
|
3059
|
+
//WSP_GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads);
|
|
3060
|
+
disposable_threadpool = true;
|
|
3061
|
+
|
|
3062
|
+
struct wsp_ggml_threadpool_params ttp = wsp_ggml_threadpool_params_default(n_threads);
|
|
3063
|
+
threadpool = wsp_ggml_threadpool_new_impl(&ttp, cgraph, cplan);
|
|
3064
|
+
} else {
|
|
3065
|
+
// Reset some of the parameters that need resetting
|
|
3066
|
+
// No worker threads should be accessing the parameters below at this stage
|
|
3067
|
+
threadpool->cgraph = cgraph;
|
|
3068
|
+
threadpool->cplan = cplan;
|
|
3069
|
+
threadpool->current_chunk = 0;
|
|
3070
|
+
threadpool->abort = -1;
|
|
3071
|
+
threadpool->ec = WSP_GGML_STATUS_SUCCESS;
|
|
3072
|
+
}
|
|
3073
|
+
|
|
3074
|
+
#ifdef WSP_GGML_USE_OPENMP
|
|
3075
|
+
if (n_threads > 1) {
|
|
3076
|
+
#pragma omp parallel num_threads(n_threads)
|
|
3077
|
+
{
|
|
3078
|
+
#pragma omp single
|
|
3079
|
+
{
|
|
3080
|
+
// update the number of threads from the actual number of threads that we got from OpenMP
|
|
3081
|
+
n_threads = omp_get_num_threads();
|
|
3082
|
+
atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed);
|
|
3083
|
+
}
|
|
3084
|
+
|
|
3085
|
+
wsp_ggml_graph_compute_thread(&threadpool->workers[omp_get_thread_num()]);
|
|
3086
|
+
}
|
|
3087
|
+
} else {
|
|
3088
|
+
atomic_store_explicit(&threadpool->n_threads_cur, 1, memory_order_relaxed);
|
|
3089
|
+
wsp_ggml_graph_compute_thread(&threadpool->workers[0]);
|
|
3090
|
+
}
|
|
3091
|
+
#else
|
|
3092
|
+
if (n_threads > threadpool->n_threads_max) {
|
|
3093
|
+
WSP_GGML_LOG_WARN("cplan requested more threads (%d) than available (%d)\n", n_threads, threadpool->n_threads_max);
|
|
3094
|
+
n_threads = threadpool->n_threads_max;
|
|
3095
|
+
}
|
|
3096
|
+
|
|
3097
|
+
// Kick all threads to start the new graph
|
|
3098
|
+
wsp_ggml_graph_compute_kickoff(threadpool, n_threads);
|
|
3099
|
+
|
|
3100
|
+
// This is a work thread too
|
|
3101
|
+
wsp_ggml_graph_compute_thread(&threadpool->workers[0]);
|
|
3102
|
+
#endif
|
|
3103
|
+
|
|
3104
|
+
// don't leave affinity set on the main thread
|
|
3105
|
+
clear_numa_thread_affinity();
|
|
3106
|
+
|
|
3107
|
+
enum wsp_ggml_status ret = threadpool->ec;
|
|
3108
|
+
|
|
3109
|
+
if (disposable_threadpool) {
|
|
3110
|
+
wsp_ggml_threadpool_free(threadpool);
|
|
3111
|
+
}
|
|
3112
|
+
|
|
3113
|
+
return ret;
|
|
3114
|
+
}
|
|
3115
|
+
|
|
3116
|
+
enum wsp_ggml_status wsp_ggml_graph_compute_with_ctx(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph, int n_threads) {
|
|
3117
|
+
struct wsp_ggml_cplan cplan = wsp_ggml_graph_plan(cgraph, n_threads, NULL);
|
|
3118
|
+
|
|
3119
|
+
cplan.work_data = (uint8_t *)wsp_ggml_new_buffer(ctx, cplan.work_size);
|
|
3120
|
+
|
|
3121
|
+
return wsp_ggml_graph_compute(cgraph, &cplan);
|
|
3122
|
+
}
|
|
3123
|
+
|
|
3124
|
+
// Converts n fp32 values from x into fp16 values in y.
// With F16C available, elements are converted in groups of 16 (AVX512F only),
// 8 and 4; any remainder goes through WSP_GGML_FP32_TO_FP16 one at a time.
void wsp_ggml_cpu_fp32_to_fp16(const float * x, wsp_ggml_fp16_t * y, int64_t n) {
    int64_t pos = 0;
#if defined(__F16C__)
#if defined(__AVX512F__)
    while (n - pos >= 16) {
        const __m512  src16 = _mm512_loadu_ps(x + pos);
        const __m256i dst16 = _mm512_cvtps_ph(src16, _MM_FROUND_TO_NEAREST_INT);
        _mm256_storeu_si256((__m256i *)(y + pos), dst16);
        pos += 16;
    }
#endif
    while (n - pos >= 8) {
        const __m256  src8 = _mm256_loadu_ps(x + pos);
        const __m128i dst8 = _mm256_cvtps_ph(src8, _MM_FROUND_TO_NEAREST_INT);
        _mm_storeu_si128((__m128i *)(y + pos), dst8);
        pos += 8;
    }
    while (n - pos >= 4) {
        const __m128  src4 = _mm_loadu_ps(x + pos);
        const __m128i dst4 = _mm_cvtps_ph(src4, _MM_FROUND_TO_NEAREST_INT);
        // only the low 64 bits (4 half floats) are stored
        _mm_storel_epi64((__m128i *)(y + pos), dst4);
        pos += 4;
    }
#endif
    // scalar tail
    while (pos < n) {
        y[pos] = WSP_GGML_FP32_TO_FP16(x[pos]);
        ++pos;
    }
}
|
|
3149
|
+
|
|
3150
|
+
// Converts n fp16 values from x into fp32 values in y.
// With F16C available, elements are converted in groups of 16 (AVX512F only),
// 8 and 4; any remainder goes through WSP_GGML_FP16_TO_FP32 one at a time.
void wsp_ggml_cpu_fp16_to_fp32(const wsp_ggml_fp16_t * x, float * y, int64_t n) {
    int64_t pos = 0;
#if defined(__F16C__)
#if defined(__AVX512F__)
    while (n - pos >= 16) {
        const __m256i src16 = _mm256_loadu_si256((const __m256i *)(x + pos));
        const __m512  dst16 = _mm512_cvtph_ps(src16);
        _mm512_storeu_ps(y + pos, dst16);
        pos += 16;
    }
#endif
    while (n - pos >= 8) {
        const __m128i src8 = _mm_loadu_si128((const __m128i *)(x + pos));
        const __m256  dst8 = _mm256_cvtph_ps(src8);
        _mm256_storeu_ps(y + pos, dst8);
        pos += 8;
    }
    while (n - pos >= 4) {
        // only 64 bits (4 half floats) are loaded
        const __m128i src4 = _mm_loadl_epi64((const __m128i *)(x + pos));
        const __m128  dst4 = _mm_cvtph_ps(src4);
        _mm_storeu_ps(y + pos, dst4);
        pos += 4;
    }
#endif
    // scalar tail
    while (pos < n) {
        y[pos] = WSP_GGML_FP16_TO_FP32(x[pos]);
        ++pos;
    }
}
|
|
3175
|
+
|
|
3176
|
+
// Converts n fp32 values from x into bf16 values in y, one element at a time.
void wsp_ggml_cpu_fp32_to_bf16(const float * x, wsp_ggml_bf16_t * y, int64_t n) {
    for (int64_t pos = 0; pos < n; ++pos) {
        y[pos] = WSP_GGML_FP32_TO_BF16(x[pos]);
    }
}
|
|
3182
|
+
|
|
3183
|
+
// Converts n bf16 values from x into fp32 values in y.
// With AVX2 available, each 16-bit value is zero-extended to 32 bits and
// shifted left by 16, in groups of 16 (AVX512F only) and 8; any remainder
// goes through WSP_GGML_BF16_TO_FP32 one at a time.
void wsp_ggml_cpu_bf16_to_fp32(const wsp_ggml_bf16_t * x, float * y, int64_t n) {
    int64_t pos = 0;
#if defined(__AVX2__)
#if defined(__AVX512F__)
    while (n - pos >= 16) {
        const __m256i raw16  = _mm256_loadu_si256((const __m256i *)(x + pos));
        const __m512i wide16 = _mm512_cvtepu16_epi32(raw16);
        const __m512i bits16 = _mm512_slli_epi32(wide16, 16);
        _mm512_storeu_ps(y + pos, _mm512_castsi512_ps(bits16));
        pos += 16;
    }
#endif
    while (n - pos >= 8) {
        const __m128i raw8  = _mm_loadu_si128((const __m128i *)(x + pos));
        const __m256i wide8 = _mm256_cvtepu16_epi32(raw8);
        const __m256i bits8 = _mm256_slli_epi32(wide8, 16);
        _mm256_storeu_ps(y + pos, _mm256_castsi256_ps(bits8));
        pos += 8;
    }
#endif
    // scalar tail
    while (pos < n) {
        y[pos] = WSP_GGML_BF16_TO_FP32(x[pos]);
        pos++;
    }
}
|
|
3211
|
+
|
|
3212
|
+
// Returns 1 when this file was compiled with __AVX__ defined, 0 otherwise.
int wsp_ggml_cpu_has_avx(void) {
    int enabled = 0;
#ifdef __AVX__
    enabled = 1;
#endif
    return enabled;
}
|
|
3219
|
+
|
|
3220
|
+
// Returns 1 when this file was compiled with __AVXVNNI__ defined, 0 otherwise.
int wsp_ggml_cpu_has_avx_vnni(void) {
    int enabled = 0;
#ifdef __AVXVNNI__
    enabled = 1;
#endif
    return enabled;
}
|
|
3227
|
+
|
|
3228
|
+
// Returns 1 when this file was compiled with __AVX2__ defined, 0 otherwise.
int wsp_ggml_cpu_has_avx2(void) {
    int enabled = 0;
#ifdef __AVX2__
    enabled = 1;
#endif
    return enabled;
}
|
|
3235
|
+
|
|
3236
|
+
// Returns 1 when this file was compiled with __AVX512F__ defined, 0 otherwise.
int wsp_ggml_cpu_has_avx512(void) {
    int enabled = 0;
#ifdef __AVX512F__
    enabled = 1;
#endif
    return enabled;
}
|
|
3243
|
+
|
|
3244
|
+
// Returns 1 when this file was compiled with __AVX512VBMI__ defined, 0 otherwise.
int wsp_ggml_cpu_has_avx512_vbmi(void) {
    int enabled = 0;
#ifdef __AVX512VBMI__
    enabled = 1;
#endif
    return enabled;
}
|
|
3251
|
+
|
|
3252
|
+
// Returns 1 when this file was compiled with __AVX512VNNI__ defined, 0 otherwise.
int wsp_ggml_cpu_has_avx512_vnni(void) {
    int enabled = 0;
#ifdef __AVX512VNNI__
    enabled = 1;
#endif
    return enabled;
}
|
|
3259
|
+
|
|
3260
|
+
// Returns 1 when this file was compiled with __AVX512BF16__ defined, 0 otherwise.
int wsp_ggml_cpu_has_avx512_bf16(void) {
    int enabled = 0;
#ifdef __AVX512BF16__
    enabled = 1;
#endif
    return enabled;
}
|
|
3267
|
+
|
|
3268
|
+
// Returns 1 when this file was compiled with __AMX_INT8__ defined, 0 otherwise.
int wsp_ggml_cpu_has_amx_int8(void) {
    int enabled = 0;
#ifdef __AMX_INT8__
    enabled = 1;
#endif
    return enabled;
}
|
|
3275
|
+
|
|
3276
|
+
// Returns 1 when this file was compiled with __BMI2__ defined, 0 otherwise.
int wsp_ggml_cpu_has_bmi2(void) {
    int enabled = 0;
#ifdef __BMI2__
    enabled = 1;
#endif
    return enabled;
}
|
|
3283
|
+
|
|
3284
|
+
// Returns 1 when this file was compiled with __FMA__ defined, 0 otherwise.
int wsp_ggml_cpu_has_fma(void) {
    int enabled = 0;
#ifdef __FMA__
    enabled = 1;
#endif
    return enabled;
}
|
|
3291
|
+
|
|
3292
|
+
// Returns 1 when this file was compiled with __ARM_FEATURE_FMA defined, 0 otherwise.
int wsp_ggml_cpu_has_arm_fma(void) {
    int enabled = 0;
#ifdef __ARM_FEATURE_FMA
    enabled = 1;
#endif
    return enabled;
}
|
|
3299
|
+
|
|
3300
|
+
// Returns 1 when this file was compiled with __riscv_v_intrinsic defined, 0 otherwise.
int wsp_ggml_cpu_has_riscv_v(void) {
    int enabled = 0;
#ifdef __riscv_v_intrinsic
    enabled = 1;
#endif
    return enabled;
}
|
|
3307
|
+
|
|
3308
|
+
// Returns 1 when this file was compiled with __F16C__ defined, 0 otherwise.
int wsp_ggml_cpu_has_f16c(void) {
    int enabled = 0;
#ifdef __F16C__
    enabled = 1;
#endif
    return enabled;
}
|
|
3315
|
+
|
|
3316
|
+
// Returns 1 when this file was compiled with
// __ARM_FEATURE_FP16_VECTOR_ARITHMETIC defined, 0 otherwise.
int wsp_ggml_cpu_has_fp16_va(void) {
    int enabled = 0;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    enabled = 1;
#endif
    return enabled;
}
|
|
3323
|
+
|
|
3324
|
+
// Returns 1 when this file was compiled with __wasm_simd128__ defined, 0 otherwise.
int wsp_ggml_cpu_has_wasm_simd(void) {
    int enabled = 0;
#ifdef __wasm_simd128__
    enabled = 1;
#endif
    return enabled;
}
|
|
3331
|
+
|
|
3332
|
+
// Returns 1 when this file was compiled with WSP_GGML_USE_LLAMAFILE defined, 0 otherwise.
int wsp_ggml_cpu_has_llamafile(void) {
    int enabled = 0;
#ifdef WSP_GGML_USE_LLAMAFILE
    enabled = 1;
#endif
    return enabled;
}
|
|
3339
|
+
|
|
3340
|
+
// Returns 1 when this file was compiled with __SSE3__ defined, 0 otherwise.
int wsp_ggml_cpu_has_sse3(void) {
    int enabled = 0;
#ifdef __SSE3__
    enabled = 1;
#endif
    return enabled;
}
|
|
3347
|
+
|
|
3348
|
+
// Returns 1 when this file was compiled with __SSSE3__ defined, 0 otherwise.
int wsp_ggml_cpu_has_ssse3(void) {
    int enabled = 0;
#ifdef __SSSE3__
    enabled = 1;
#endif
    return enabled;
}
|
|
3355
|
+
|
|
3356
|
+
// Returns 1 when this file was compiled with __POWER9_VECTOR__ defined, 0 otherwise.
int wsp_ggml_cpu_has_vsx(void) {
    int enabled = 0;
#ifdef __POWER9_VECTOR__
    enabled = 1;
#endif
    return enabled;
}
|
|
3363
|
+
|
|
3364
|
+
// Returns 1 when this file was compiled with __VXE__ or __VXE2__ defined, 0 otherwise.
int wsp_ggml_cpu_has_vxe(void) {
    int enabled = 0;
#if defined(__VXE__) || defined(__VXE2__)
    enabled = 1;
#endif
    return enabled;
}
|
|
3371
|
+
|
|
3372
|
+
// Returns 1 when this file was compiled with both __ARM_ARCH and __ARM_NEON
// defined, 0 otherwise.
int wsp_ggml_cpu_has_neon(void) {
    int enabled = 0;
#if defined(__ARM_ARCH)
#if defined(__ARM_NEON)
    enabled = 1;
#endif
#endif
    return enabled;
}
|
|
3379
|
+
|
|
3380
|
+
// Returns 1 when this file was compiled with both __ARM_ARCH and
// __ARM_FEATURE_DOTPROD defined, 0 otherwise.
int wsp_ggml_cpu_has_dotprod(void) {
    int enabled = 0;
#if defined(__ARM_ARCH)
#if defined(__ARM_FEATURE_DOTPROD)
    enabled = 1;
#endif
#endif
    return enabled;
}
|
|
3387
|
+
|
|
3388
|
+
// Returns 1 when this file was compiled with both __ARM_ARCH and
// __ARM_FEATURE_SVE defined, 0 otherwise.
int wsp_ggml_cpu_has_sve(void) {
    int enabled = 0;
#if defined(__ARM_ARCH)
#if defined(__ARM_FEATURE_SVE)
    enabled = 1;
#endif
#endif
    return enabled;
}
|
|
3395
|
+
|
|
3396
|
+
// Returns 1 when this file was compiled with both __ARM_ARCH and
// __ARM_FEATURE_MATMUL_INT8 defined, 0 otherwise.
int wsp_ggml_cpu_has_matmul_int8(void) {
    int enabled = 0;
#if defined(__ARM_ARCH)
#if defined(__ARM_FEATURE_MATMUL_INT8)
    enabled = 1;
#endif
#endif
    return enabled;
}
|
|
3403
|
+
|
|
3404
|
+
// Returns wsp_ggml_arm_arch_features.sve_cnt when compiled with both
// __ARM_ARCH and __ARM_FEATURE_SVE defined, 0 otherwise.
int wsp_ggml_cpu_get_sve_cnt(void) {
#if !defined(__ARM_ARCH) || !defined(__ARM_FEATURE_SVE)
    return 0;
#else
    return wsp_ggml_arm_arch_features.sve_cnt;
#endif
}
|
|
3411
|
+
|
|
3412
|
+
// Returns 1 when this file was compiled with both __ARM_ARCH and
// __ARM_FEATURE_SME defined, 0 otherwise.
int wsp_ggml_cpu_has_sme(void) {
    int enabled = 0;
#if defined(__ARM_ARCH)
#if defined(__ARM_FEATURE_SME)
    enabled = 1;
#endif
#endif
    return enabled;
}
|
|
3419
|
+
|
|
3420
|
+
void wsp_ggml_cpu_init(void) {
|
|
3421
|
+
// needed to initialize f16 tables
|
|
3422
|
+
{
|
|
3423
|
+
struct wsp_ggml_init_params params = { 0, NULL, false };
|
|
3424
|
+
struct wsp_ggml_context * ctx = wsp_ggml_init(params);
|
|
3425
|
+
wsp_ggml_free(ctx);
|
|
3426
|
+
}
|
|
3427
|
+
|
|
3428
|
+
wsp_ggml_critical_section_start();
|
|
3429
|
+
|
|
3430
|
+
static bool is_first_call = true;
|
|
3431
|
+
|
|
3432
|
+
if (is_first_call) {
|
|
3433
|
+
// initialize GELU, Quick GELU, SILU and EXP F32 tables
|
|
3434
|
+
{
|
|
3435
|
+
const uint64_t t_start = wsp_ggml_time_us(); UNUSED(t_start);
|
|
3436
|
+
|
|
3437
|
+
for (int i = 0; i < (1 << 16); ++i) {
|
|
3438
|
+
union {
|
|
3439
|
+
uint16_t u16;
|
|
3440
|
+
wsp_ggml_fp16_t fp16;
|
|
3441
|
+
} u = {i};
|
|
3442
|
+
float f = WSP_GGML_FP16_TO_FP32(u.fp16);
|
|
3443
|
+
wsp_ggml_table_gelu_f16[i] = WSP_GGML_FP32_TO_FP16(wsp_ggml_gelu_f32(f));
|
|
3444
|
+
wsp_ggml_table_gelu_quick_f16[i] = WSP_GGML_FP32_TO_FP16(wsp_ggml_gelu_quick_f32(f));
|
|
3445
|
+
}
|
|
3446
|
+
|
|
3447
|
+
const uint64_t t_end = wsp_ggml_time_us(); UNUSED(t_end);
|
|
3448
|
+
|
|
3449
|
+
WSP_GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
|
|
3450
|
+
|
|
3451
|
+
#ifdef WSP_GGML_USE_OPENMP
|
|
3452
|
+
//if (!getenv("OMP_WAIT_POLICY")) {
|
|
3453
|
+
// // set the wait policy to active, so that OpenMP threads don't sleep
|
|
3454
|
+
// putenv("OMP_WAIT_POLICY=active");
|
|
3455
|
+
//}
|
|
3456
|
+
|
|
3457
|
+
if (!getenv("KMP_BLOCKTIME")) {
|
|
3458
|
+
// set the time to wait before sleeping a thread
|
|
3459
|
+
// this is less aggressive than setting the wait policy to active, but should achieve similar results in most cases
|
|
3460
|
+
putenv("KMP_BLOCKTIME=200"); // 200ms
|
|
3461
|
+
}
|
|
3462
|
+
#endif
|
|
3463
|
+
}
|
|
3464
|
+
|
|
3465
|
+
#if defined(__ARM_ARCH)
|
|
3466
|
+
wsp_ggml_init_arm_arch_features();
|
|
3467
|
+
#endif
|
|
3468
|
+
|
|
3469
|
+
is_first_call = false;
|
|
3470
|
+
}
|
|
3471
|
+
|
|
3472
|
+
wsp_ggml_critical_section_end();
|
|
3473
|
+
}
|