whisper.rn 0.4.0-rc.8 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (201) hide show
  1. package/README.md +5 -1
  2. package/android/build.gradle +12 -3
  3. package/android/src/main/CMakeLists.txt +44 -13
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +27 -12
  5. package/android/src/main/java/com/rnwhisper/RNWhisper.java +75 -34
  6. package/android/src/main/java/com/rnwhisper/WhisperContext.java +53 -38
  7. package/android/src/main/jni.cpp +38 -1
  8. package/android/src/main/jniLibs/arm64-v8a/librnwhisper.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnwhisper_v8fp16_va_2.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper_vfpv4.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnwhisper.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnwhisper_x86_64.so +0 -0
  14. package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  15. package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  16. package/cpp/coreml/whisper-compat.h +10 -0
  17. package/cpp/coreml/whisper-compat.m +35 -0
  18. package/cpp/coreml/whisper-decoder-impl.h +27 -15
  19. package/cpp/coreml/whisper-decoder-impl.m +36 -10
  20. package/cpp/coreml/whisper-encoder-impl.h +21 -9
  21. package/cpp/coreml/whisper-encoder-impl.m +29 -3
  22. package/cpp/ggml-alloc.c +727 -517
  23. package/cpp/ggml-alloc.h +47 -65
  24. package/cpp/ggml-backend-impl.h +196 -57
  25. package/cpp/ggml-backend-reg.cpp +591 -0
  26. package/cpp/ggml-backend.cpp +2016 -0
  27. package/cpp/ggml-backend.h +234 -89
  28. package/cpp/ggml-common.h +1861 -0
  29. package/cpp/ggml-cpp.h +39 -0
  30. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  31. package/cpp/ggml-cpu/amx/amx.h +8 -0
  32. package/cpp/ggml-cpu/amx/common.h +91 -0
  33. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  34. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  35. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  36. package/cpp/ggml-cpu/arch/arm/quants.c +4113 -0
  37. package/cpp/ggml-cpu/arch/arm/repack.cpp +2162 -0
  38. package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  39. package/cpp/ggml-cpu/arch/x86/quants.c +4310 -0
  40. package/cpp/ggml-cpu/arch/x86/repack.cpp +3284 -0
  41. package/cpp/ggml-cpu/arch-fallback.h +184 -0
  42. package/cpp/ggml-cpu/binary-ops.cpp +158 -0
  43. package/cpp/ggml-cpu/binary-ops.h +16 -0
  44. package/cpp/ggml-cpu/common.h +72 -0
  45. package/cpp/ggml-cpu/ggml-cpu-impl.h +511 -0
  46. package/cpp/ggml-cpu/ggml-cpu.c +3473 -0
  47. package/cpp/ggml-cpu/ggml-cpu.cpp +671 -0
  48. package/cpp/ggml-cpu/ops.cpp +9085 -0
  49. package/cpp/ggml-cpu/ops.h +111 -0
  50. package/cpp/ggml-cpu/quants.c +1157 -0
  51. package/cpp/ggml-cpu/quants.h +89 -0
  52. package/cpp/ggml-cpu/repack.cpp +1570 -0
  53. package/cpp/ggml-cpu/repack.h +98 -0
  54. package/cpp/ggml-cpu/simd-mappings.h +1006 -0
  55. package/cpp/ggml-cpu/traits.cpp +36 -0
  56. package/cpp/ggml-cpu/traits.h +38 -0
  57. package/cpp/ggml-cpu/unary-ops.cpp +186 -0
  58. package/cpp/ggml-cpu/unary-ops.h +28 -0
  59. package/cpp/ggml-cpu/vec.cpp +321 -0
  60. package/cpp/ggml-cpu/vec.h +973 -0
  61. package/cpp/ggml-cpu.h +143 -0
  62. package/cpp/ggml-impl.h +525 -168
  63. package/cpp/ggml-metal-impl.h +622 -0
  64. package/cpp/ggml-metal.h +16 -14
  65. package/cpp/ggml-metal.m +5289 -1859
  66. package/cpp/ggml-opt.cpp +1037 -0
  67. package/cpp/ggml-opt.h +237 -0
  68. package/cpp/ggml-quants.c +2916 -6877
  69. package/cpp/ggml-quants.h +87 -249
  70. package/cpp/ggml-threading.cpp +12 -0
  71. package/cpp/ggml-threading.h +14 -0
  72. package/cpp/ggml-whisper-sim.metallib +0 -0
  73. package/cpp/ggml-whisper.metallib +0 -0
  74. package/cpp/ggml.c +3293 -16770
  75. package/cpp/ggml.h +778 -835
  76. package/cpp/gguf.cpp +1347 -0
  77. package/cpp/gguf.h +202 -0
  78. package/cpp/rn-whisper.cpp +84 -0
  79. package/cpp/rn-whisper.h +2 -0
  80. package/cpp/whisper-arch.h +197 -0
  81. package/cpp/whisper.cpp +3240 -944
  82. package/cpp/whisper.h +144 -31
  83. package/ios/CMakeLists.txt +95 -0
  84. package/ios/RNWhisper.h +5 -0
  85. package/ios/RNWhisper.mm +124 -37
  86. package/ios/RNWhisperAudioUtils.h +1 -0
  87. package/ios/RNWhisperAudioUtils.m +24 -13
  88. package/ios/RNWhisperContext.h +8 -2
  89. package/ios/RNWhisperContext.mm +42 -8
  90. package/ios/rnwhisper.xcframework/Info.plist +74 -0
  91. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  92. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  93. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  94. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  95. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  96. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  97. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  98. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  99. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  100. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  101. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  102. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  103. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
  104. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
  105. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  106. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  107. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  108. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  109. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
  110. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  111. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  112. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  113. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  114. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  115. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  116. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  117. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  118. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  119. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  120. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  121. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  122. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  123. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  124. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  125. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
  126. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
  127. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  128. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  129. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  130. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  131. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
  132. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  133. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
  134. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  135. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  136. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  137. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  138. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  139. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  140. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  141. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  142. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  143. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  144. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  145. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  146. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  147. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  148. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
  149. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
  150. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  151. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  152. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  153. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  154. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
  155. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  156. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  157. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  158. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  159. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  160. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  161. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  162. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  163. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  164. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  165. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  166. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  167. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  168. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  169. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  170. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
  171. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
  172. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  173. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  174. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  175. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  176. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
  177. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  178. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
  179. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  180. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  181. package/jest/mock.js +14 -1
  182. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  183. package/lib/commonjs/index.js +48 -19
  184. package/lib/commonjs/index.js.map +1 -1
  185. package/lib/commonjs/version.json +1 -1
  186. package/lib/module/NativeRNWhisper.js.map +1 -1
  187. package/lib/module/index.js +48 -19
  188. package/lib/module/index.js.map +1 -1
  189. package/lib/module/version.json +1 -1
  190. package/lib/typescript/NativeRNWhisper.d.ts +6 -3
  191. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  192. package/lib/typescript/index.d.ts +25 -3
  193. package/lib/typescript/index.d.ts.map +1 -1
  194. package/package.json +15 -10
  195. package/src/NativeRNWhisper.ts +12 -3
  196. package/src/index.ts +63 -24
  197. package/src/version.json +1 -1
  198. package/whisper-rn.podspec +18 -18
  199. package/cpp/README.md +0 -4
  200. package/cpp/ggml-backend.c +0 -1718
  201. package/cpp/ggml-metal-whisper.metal +0 -5820
package/cpp/ggml-cpp.h ADDED
@@ -0,0 +1,39 @@
1
+ #pragma once
2
+
3
+ #ifndef __cplusplus
4
+ #error "This header is for C++ only"
5
+ #endif
6
+
7
+ #include "ggml.h"
8
+ #include "ggml-alloc.h"
9
+ #include "ggml-backend.h"
10
+ #include "gguf.h"
11
+ #include <memory>
12
+
13
+ // Smart pointers for ggml types
14
+
15
+ // ggml
16
+
17
+ struct wsp_ggml_context_deleter { void operator()(wsp_ggml_context * ctx) { wsp_ggml_free(ctx); } };
18
+ struct wsp_gguf_context_deleter { void operator()(wsp_gguf_context * ctx) { wsp_gguf_free(ctx); } };
19
+
20
+ typedef std::unique_ptr<wsp_ggml_context, wsp_ggml_context_deleter> wsp_ggml_context_ptr;
21
+ typedef std::unique_ptr<wsp_gguf_context, wsp_gguf_context_deleter> wsp_gguf_context_ptr;
22
+
23
+ // ggml-alloc
24
+
25
+ struct wsp_ggml_gallocr_deleter { void operator()(wsp_ggml_gallocr_t galloc) { wsp_ggml_gallocr_free(galloc); } };
26
+
27
+ typedef std::unique_ptr<wsp_ggml_gallocr, wsp_ggml_gallocr_deleter> wsp_ggml_gallocr_ptr;
28
+
29
+ // ggml-backend
30
+
31
+ struct wsp_ggml_backend_deleter { void operator()(wsp_ggml_backend_t backend) { wsp_ggml_backend_free(backend); } };
32
+ struct wsp_ggml_backend_buffer_deleter { void operator()(wsp_ggml_backend_buffer_t buffer) { wsp_ggml_backend_buffer_free(buffer); } };
33
+ struct wsp_ggml_backend_event_deleter { void operator()(wsp_ggml_backend_event_t event) { wsp_ggml_backend_event_free(event); } };
34
+ struct wsp_ggml_backend_sched_deleter { void operator()(wsp_ggml_backend_sched_t sched) { wsp_ggml_backend_sched_free(sched); } };
35
+
36
+ typedef std::unique_ptr<wsp_ggml_backend, wsp_ggml_backend_deleter> wsp_ggml_backend_ptr;
37
+ typedef std::unique_ptr<wsp_ggml_backend_buffer, wsp_ggml_backend_buffer_deleter> wsp_ggml_backend_buffer_ptr;
38
+ typedef std::unique_ptr<wsp_ggml_backend_event, wsp_ggml_backend_event_deleter> wsp_ggml_backend_event_ptr;
39
+ typedef std::unique_ptr<wsp_ggml_backend_sched, wsp_ggml_backend_sched_deleter> wsp_ggml_backend_sched_ptr;
@@ -0,0 +1,221 @@
1
+ #include "amx.h"
2
+ #include "common.h"
3
+ #include "mmq.h"
4
+ #include "ggml-backend-impl.h"
5
+ #include "ggml-backend.h"
6
+ #include "ggml-impl.h"
7
+ #include "ggml-cpu.h"
8
+ #include "traits.h"
9
+
10
+ #if defined(__gnu_linux__)
11
+ #include <sys/syscall.h>
12
+ #include <unistd.h>
13
+ #endif
14
+
15
+ #include <cstdlib>
16
+ #include <cstring>
17
+ #include <memory>
18
+
19
+ #if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
20
+
21
+ // AMX type_traits
22
+ namespace ggml::cpu::amx {
23
+ class tensor_traits : public ggml::cpu::tensor_traits {
24
+ bool work_size(int /* n_threads */, const struct wsp_ggml_tensor * op, size_t & size) override {
25
+ size = wsp_ggml_backend_amx_desired_wsize(op);
26
+ return true;
27
+ }
28
+
29
+ bool compute_forward(struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * op) override {
30
+ if (op->op == WSP_GGML_OP_MUL_MAT) {
31
+ wsp_ggml_backend_amx_mul_mat(params, op);
32
+ return true;
33
+ }
34
+ return false;
35
+ }
36
+ };
37
+
38
+ static ggml::cpu::tensor_traits * get_tensor_traits(wsp_ggml_backend_buffer_t, struct wsp_ggml_tensor *) {
39
+ static tensor_traits traits;
40
+ return &traits;
41
+ }
42
+ } // namespace ggml::cpu::amx
43
+
44
+ // AMX buffer interface
45
+ static void wsp_ggml_backend_amx_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) {
46
+ free(buffer->context);
47
+ }
48
+
49
+ static void * wsp_ggml_backend_amx_buffer_get_base(wsp_ggml_backend_buffer_t buffer) {
50
+ return (void *) (buffer->context);
51
+ }
52
+
53
+ static enum wsp_ggml_status wsp_ggml_backend_amx_buffer_init_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) {
54
+ tensor->extra = (void *) ggml::cpu::amx::get_tensor_traits(buffer, tensor);
55
+
56
+ WSP_GGML_UNUSED(buffer);
57
+ return WSP_GGML_STATUS_SUCCESS;
58
+ }
59
+
60
+ static void wsp_ggml_backend_amx_buffer_memset_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor,
61
+ uint8_t value, size_t offset, size_t size) {
62
+ memset((char *) tensor->data + offset, value, size);
63
+
64
+ WSP_GGML_UNUSED(buffer);
65
+ }
66
+
67
+ static void wsp_ggml_backend_amx_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor,
68
+ const void * data, size_t offset, size_t size) {
69
+ if (qtype_has_amx_kernels(tensor->type)) {
70
+ WSP_GGML_LOG_DEBUG("%s: amx repack tensor %s of type %s\n", __func__, tensor->name, wsp_ggml_type_name(tensor->type));
71
+ wsp_ggml_backend_amx_convert_weight(tensor, data, offset, size);
72
+ } else {
73
+ memcpy((char *) tensor->data + offset, data, size);
74
+ }
75
+
76
+ WSP_GGML_UNUSED(buffer);
77
+ }
78
+
79
+ /*
80
+ // TODO: figure out what we need to do with buffer->extra.
81
+ static void wsp_ggml_backend_amx_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
82
+ WSP_GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
83
+ memcpy(data, (const char *)tensor->data + offset, size);
84
+
85
+ WSP_GGML_UNUSED(buffer);
86
+ }
87
+
88
+ static bool wsp_ggml_backend_amx_buffer_cpy_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
89
+ if (wsp_ggml_backend_buffer_is_host(src->buffer)) {
90
+ if (qtype_has_amx_kernels(src->type)) {
91
+ wsp_ggml_backend_amx_convert_weight(dst, src->data, 0, wsp_ggml_nbytes(dst));
92
+ } else {
93
+ memcpy(dst->data, src->data, wsp_ggml_nbytes(src));
94
+ }
95
+ return true;
96
+ }
97
+ return false;
98
+
99
+ WSP_GGML_UNUSED(buffer);
100
+ }
101
+ */
102
+
103
+ static void wsp_ggml_backend_amx_buffer_clear(wsp_ggml_backend_buffer_t buffer, uint8_t value) {
104
+ memset(buffer->context, value, buffer->size);
105
+ }
106
+
107
+ static wsp_ggml_backend_buffer_i wsp_ggml_backend_amx_buffer_interface = { // callback table handed to wsp_ggml_backend_buffer_init() for AMX buffers
108
+ /* .free_buffer = */ wsp_ggml_backend_amx_buffer_free_buffer,
109
+ /* .get_base = */ wsp_ggml_backend_amx_buffer_get_base,
110
+ /* .init_tensor = */ wsp_ggml_backend_amx_buffer_init_tensor,
111
+ /* .memset_tensor = */ wsp_ggml_backend_amx_buffer_memset_tensor,
112
+ /* .set_tensor = */ wsp_ggml_backend_amx_buffer_set_tensor,
113
+ /* .get_tensor = */ nullptr, // no callback supplied; a draft implementation is commented out earlier in this file
114
+ /* .cpy_tensor = */ nullptr, // no callback supplied; a draft implementation is commented out earlier in this file
115
+ /* .clear = */ wsp_ggml_backend_amx_buffer_clear,
116
+ /* .reset = */ nullptr, // no callback supplied
117
+ };
118
+
119
+ static const char * wsp_ggml_backend_amx_buffer_type_get_name(wsp_ggml_backend_buffer_type_t buft) {
120
+ return "AMX";
121
+
122
+ WSP_GGML_UNUSED(buft);
123
+ }
124
+
125
+ static wsp_ggml_backend_buffer_t wsp_ggml_backend_amx_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
126
+ void * data = wsp_ggml_aligned_malloc(size);
127
+ if (data == NULL) {
128
+ fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
129
+ return NULL;
130
+ }
131
+
132
+ return wsp_ggml_backend_buffer_init(buft, wsp_ggml_backend_amx_buffer_interface, data, size);
133
+ }
134
+
135
+ static size_t wsp_ggml_backend_amx_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
136
+ return TENSOR_ALIGNMENT;
137
+
138
+ WSP_GGML_UNUSED(buft);
139
+ }
140
+
141
+ namespace ggml::cpu::amx {
142
+ class extra_buffer_type : ggml::cpu::extra_buffer_type {
143
+ bool supports_op(wsp_ggml_backend_dev_t, const struct wsp_ggml_tensor * op) override {
144
+ // handle only 2d gemm for now
145
+ auto is_contiguous_2d = [](const struct wsp_ggml_tensor * t) {
146
+ return wsp_ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
147
+ };
148
+
149
+ if (op->op == WSP_GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous
150
+ is_contiguous_2d(op->src[1]) && // src1 must be contiguous
151
+ op->src[0]->buffer && op->src[0]->buffer->buft == wsp_ggml_backend_amx_buffer_type() &&
152
+ op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x
153
+ (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == WSP_GGML_TYPE_F16))) {
154
+ // src1 must be host buffer
155
+ if (op->src[1]->buffer && !wsp_ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
156
+ return false;
157
+ }
158
+ // src1 must be float32
159
+ if (op->src[1]->type == WSP_GGML_TYPE_F32) {
160
+ return true;
161
+ }
162
+ }
163
+ return false;
164
+ }
165
+
166
+ ggml::cpu::tensor_traits * get_tensor_traits(const struct wsp_ggml_tensor * op) override {
167
+ if (op->op == WSP_GGML_OP_MUL_MAT && op->src[0]->buffer &&
168
+ op->src[0]->buffer->buft == wsp_ggml_backend_amx_buffer_type()) {
169
+ return (ggml::cpu::tensor_traits *) op->src[0]->extra;
170
+ }
171
+
172
+ return nullptr;
173
+ }
174
+ };
175
+ } // namespace ggml::cpu::amx
176
+
177
+ static size_t wsp_ggml_backend_amx_buffer_type_get_alloc_size(wsp_ggml_backend_buffer_type_t buft, const wsp_ggml_tensor * tensor) {
178
+ return wsp_ggml_backend_amx_get_alloc_size(tensor);
179
+
180
+ WSP_GGML_UNUSED(buft);
181
+ }
182
+
183
+ #define ARCH_GET_XCOMP_PERM 0x1022
184
+ #define ARCH_REQ_XCOMP_PERM 0x1023
185
+ #define XFEATURE_XTILECFG 17
186
+ #define XFEATURE_XTILEDATA 18
187
+
188
+ static bool wsp_ggml_amx_init() {
189
+ #if defined(__gnu_linux__)
190
+ if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
191
+ fprintf(stderr, "AMX is not ready to be used!\n");
192
+ return false;
193
+ }
194
+ return true;
195
+ #elif defined(_WIN32)
196
+ return true;
197
+ #endif
198
+ }
199
+
200
+ wsp_ggml_backend_buffer_type_t wsp_ggml_backend_amx_buffer_type() {
201
+ static struct wsp_ggml_backend_buffer_type wsp_ggml_backend_buffer_type_amx = {
202
+ /* .iface = */ {
203
+ /* .get_name = */ wsp_ggml_backend_amx_buffer_type_get_name,
204
+ /* .alloc_buffer = */ wsp_ggml_backend_amx_buffer_type_alloc_buffer,
205
+ /* .get_alignment = */ wsp_ggml_backend_amx_buffer_type_get_alignment,
206
+ /* .get_max_size = */ nullptr, // defaults to SIZE_MAX
207
+ /* .get_alloc_size = */ wsp_ggml_backend_amx_buffer_type_get_alloc_size,
208
+ /* .is_host = */ nullptr,
209
+ },
210
+ /* .device = */ wsp_ggml_backend_reg_dev_get(wsp_ggml_backend_cpu_reg(), 0),
211
+ /* .context = */ new ggml::cpu::amx::extra_buffer_type(),
212
+ };
213
+
214
+ if (!wsp_ggml_amx_init()) {
215
+ return nullptr;
216
+ }
217
+
218
+ return &wsp_ggml_backend_buffer_type_amx;
219
+ }
220
+
221
+ #endif // defined(__AMX_INT8__) && defined(__AVX512VNNI__)
@@ -0,0 +1,8 @@
1
+ #include "ggml-backend.h"
2
+ #include "ggml-cpu-impl.h"
3
+
4
+ // GGML internal header
5
+
6
+ #if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
7
+ wsp_ggml_backend_buffer_type_t wsp_ggml_backend_amx_buffer_type(void);
8
+ #endif
@@ -0,0 +1,91 @@
1
+ #pragma once
2
+
3
+ #include "ggml.h"
4
+ #include "ggml-cpu-impl.h"
5
+
6
+ #include <algorithm>
7
+ #include <memory>
8
+ #include <type_traits>
9
+
10
+ #if defined(WSP_GGML_USE_OPENMP)
11
+ #include <omp.h>
12
+ #endif
13
+
14
+ #define TILE_M 16
15
+ #define TILE_N 16
16
+ #define TILE_K 32
17
+ #define VNNI_BLK 4
18
+
19
+ #define AMX_BLK_SIZE 32
20
+
21
+ #define TMM0 0
22
+ #define TMM1 1
23
+ #define TMM2 2
24
+ #define TMM3 3
25
+ #define TMM4 4
26
+ #define TMM5 5
27
+ #define TMM6 6
28
+ #define TMM7 7
29
+
30
+ // parallel routines
31
// Integer ceiling division: computes (x + y - 1) / y. Only available for integral T.
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
inline T div_up(T x, T y) {
    const auto biased = x + y - 1;
    return biased / y;
}
33
+
34
// Splits the range [0, n) into `nth` consecutive chunks and writes the half-open chunk
// of worker `ith` to [n_start, n_end).
//
// Uses the pytorch aten partition pattern: every chunk holds ceil(n / nth) items except
// the trailing ones, which may be shorter or empty. The start is clamped to `n`, so a
// worker with nothing to do gets the empty range [n, n) rather than an inverted one
// with n_start > n_end.
//
// `nth` must be non-zero.
template <typename T>
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
    const T n_my = static_cast<T>((n + nth - 1) / nth);  // ceil(n / nth)

    n_start = std::min(static_cast<T>(ith * n_my), n);
    n_end   = std::min(static_cast<T>(n_start + n_my), n);
}
57
+
58
// Invokes f(begin, end) over the range [0, n).
// Without WSP_GGML_USE_OPENMP there is a single call f(0, n). With it, each thread of an
// OpenMP parallel region calls f on the sub-range that balance211() assigns to it.
template <typename func_t>
inline void parallel_for(int n, const func_t& f) {
#if !defined(WSP_GGML_USE_OPENMP)
    f(0, n);
#else
    #pragma omp parallel
    {
        const int n_threads = omp_get_num_threads();
        const int thread_id = omp_get_thread_num();
        int first;
        int last;
        balance211(n, n_threads, thread_id, first, last);
        f(first, last);
    }
#endif
}
73
+
74
+ template <typename func_t>
75
+ inline void parallel_for_ggml(const wsp_ggml_compute_params * params, int n, const func_t & f) {
76
+ int tbegin, tend;
77
+ balance211(n, params->nth, params->ith, tbegin, tend);
78
+ f(tbegin, tend);
79
+ }
80
+
81
+ // quantized types that have AMX support
82
+ inline bool qtype_has_amx_kernels(const enum wsp_ggml_type type) {
83
+ // TODO: fix padding for vnni format
84
+ return (type == WSP_GGML_TYPE_Q4_0) ||
85
+ (type == WSP_GGML_TYPE_Q4_1) ||
86
+ (type == WSP_GGML_TYPE_Q8_0) ||
87
+ (type == WSP_GGML_TYPE_Q4_K) ||
88
+ (type == WSP_GGML_TYPE_Q5_K) ||
89
+ (type == WSP_GGML_TYPE_Q6_K) ||
90
+ (type == WSP_GGML_TYPE_IQ4_XS);
91
+ }