cui-llama.rn 1.7.4 → 1.7.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +217 -17
- package/android/src/main/CMakeLists.txt +34 -15
- package/android/src/main/java/com/rnllama/LlamaContext.java +79 -5
- package/android/src/main/java/com/rnllama/RNLlama.java +237 -0
- package/android/src/main/jni.cpp +213 -14
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +35 -0
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +34 -0
- package/cpp/README.md +1 -1
- package/cpp/chat-parser.cpp +385 -0
- package/cpp/chat-parser.h +120 -0
- package/cpp/chat.cpp +726 -596
- package/cpp/chat.h +71 -6
- package/cpp/common.cpp +56 -38
- package/cpp/common.h +9 -3
- package/cpp/ggml-backend-reg.cpp +5 -0
- package/cpp/ggml-backend.cpp +10 -2
- package/cpp/ggml-common.h +4 -0
- package/cpp/ggml-cpu/amx/amx.cpp +1 -1
- package/cpp/ggml-cpu/amx/mmq.cpp +11 -10
- package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/ggml-cpu/arch/arm/quants.c +4114 -0
- package/cpp/ggml-cpu/arch/arm/repack.cpp +2163 -0
- package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
- package/cpp/ggml-cpu/arch/x86/quants.c +4311 -0
- package/cpp/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- package/cpp/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/ggml-cpu/common.h +4 -3
- package/cpp/ggml-cpu/ggml-cpu-impl.h +21 -16
- package/cpp/ggml-cpu/ggml-cpu.c +123 -104
- package/cpp/ggml-cpu/ggml-cpu.cpp +11 -8
- package/cpp/ggml-cpu/ops.cpp +330 -148
- package/cpp/ggml-cpu/ops.h +1 -0
- package/cpp/ggml-cpu/quants.c +1158 -0
- package/cpp/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/cpp/ggml-cpu/repack.cpp +1571 -0
- package/cpp/ggml-cpu/repack.h +98 -0
- package/cpp/ggml-cpu/simd-mappings.h +330 -38
- package/cpp/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/cpp/ggml-cpu/vec.cpp +87 -18
- package/cpp/ggml-cpu/vec.h +249 -94
- package/cpp/ggml-cpu.h +1 -0
- package/cpp/ggml-impl.h +63 -183
- package/cpp/ggml-llama-sim.metallib +0 -0
- package/cpp/ggml-llama.metallib +0 -0
- package/cpp/ggml-metal.m +152 -45
- package/cpp/ggml-quants.c +0 -2
- package/cpp/ggml.c +61 -21
- package/cpp/ggml.h +22 -3
- package/cpp/gguf.cpp +24 -3
- package/cpp/json-partial.cpp +256 -0
- package/cpp/json-partial.h +38 -0
- package/cpp/json-schema-to-grammar.cpp +5 -47
- package/cpp/json-schema-to-grammar.h +4 -4
- package/cpp/llama-arch.cpp +153 -3
- package/cpp/llama-arch.h +27 -1
- package/cpp/llama-batch.cpp +741 -272
- package/cpp/llama-batch.h +112 -54
- package/cpp/llama-chat.cpp +30 -8
- package/cpp/llama-chat.h +1 -0
- package/cpp/llama-context.cpp +524 -339
- package/cpp/llama-context.h +38 -17
- package/cpp/llama-cparams.cpp +4 -0
- package/cpp/llama-cparams.h +2 -0
- package/cpp/llama-grammar.cpp +12 -2
- package/cpp/llama-graph.cpp +431 -356
- package/cpp/llama-graph.h +126 -58
- package/cpp/llama-hparams.cpp +10 -2
- package/cpp/llama-hparams.h +19 -2
- package/cpp/llama-kv-cache-unified-iswa.cpp +279 -0
- package/cpp/llama-kv-cache-unified-iswa.h +128 -0
- package/cpp/llama-kv-cache-unified.cpp +1841 -0
- package/cpp/llama-kv-cache-unified.h +303 -0
- package/cpp/llama-kv-cells.h +439 -0
- package/cpp/llama-memory-hybrid.cpp +246 -0
- package/cpp/llama-memory-hybrid.h +138 -0
- package/cpp/llama-memory-recurrent.cpp +1112 -0
- package/cpp/llama-memory-recurrent.h +183 -0
- package/cpp/llama-memory.cpp +41 -0
- package/cpp/llama-memory.h +86 -5
- package/cpp/llama-mmap.cpp +1 -1
- package/cpp/llama-model-loader.cpp +42 -17
- package/cpp/llama-model-saver.cpp +1 -0
- package/cpp/llama-model.cpp +1639 -513
- package/cpp/llama-model.h +26 -0
- package/cpp/llama-sampling.cpp +2 -2
- package/cpp/llama-vocab.cpp +65 -28
- package/cpp/llama-vocab.h +1 -0
- package/cpp/llama.cpp +11 -7
- package/cpp/llama.h +150 -42
- package/cpp/minja/chat-template.hpp +1 -1
- package/cpp/minja/minja.hpp +1 -1
- package/cpp/{json.hpp → nlohmann/json.hpp} +3027 -2267
- package/cpp/nlohmann/json_fwd.hpp +187 -0
- package/cpp/regex-partial.cpp +204 -0
- package/cpp/regex-partial.h +56 -0
- package/cpp/rn-llama.cpp +646 -35
- package/cpp/rn-llama.h +32 -1
- package/cpp/rn-tts.h +39 -0
- package/cpp/sampling.cpp +7 -8
- package/cpp/tools/mtmd/clip-impl.h +5 -0
- package/cpp/tools/mtmd/clip.cpp +572 -436
- package/cpp/tools/mtmd/clip.h +14 -4
- package/cpp/tools/mtmd/mtmd-audio.cpp +0 -86
- package/cpp/tools/mtmd/mtmd-audio.h +2 -17
- package/cpp/tools/mtmd/mtmd-helper.cpp +175 -12
- package/cpp/tools/mtmd/mtmd-helper.h +91 -0
- package/cpp/tools/mtmd/mtmd.cpp +368 -248
- package/cpp/tools/mtmd/mtmd.h +6 -70
- package/cpp/unicode.cpp +5 -0
- package/ios/CMakeLists.txt +26 -6
- package/ios/RNLlama.h +1 -1
- package/ios/RNLlama.mm +153 -3
- package/ios/RNLlamaContext.h +9 -1
- package/ios/RNLlamaContext.mm +112 -9
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +71 -6
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +9 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +22 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-partial.h +38 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +38 -17
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +26 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +150 -42
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/{json.hpp → nlohmann/json.hpp} +3027 -2267
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
- package/ios/rnllama.xcframework/{tvos-arm64/rnllama.framework/Headers → ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +71 -6
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +9 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +22 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-partial.h +38 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +38 -17
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +26 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +150 -42
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
- package/ios/rnllama.xcframework/{ios-arm64_x86_64-simulator/rnllama.framework/Headers → tvos-arm64/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json.hpp +25526 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/jest/mock.js +24 -0
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +46 -2
- package/src/index.ts +105 -1
- package/cpp/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/cpp/ggml-cpu/ggml-cpu-quants.c +0 -13326
- package/cpp/ggml-cpu/sgemm.cpp +0 -3544
- package/cpp/ggml-cpu/sgemm.h +0 -14
- package/cpp/llama-kv-cache.cpp +0 -2827
- package/cpp/llama-kv-cache.h +0 -515
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +0 -24766
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
- /package/cpp/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /package/cpp/tools/mtmd/{miniaudio.h → miniaudio/miniaudio.h} +0 -0
- /package/cpp/tools/mtmd/{stb_image.h → stb/stb_image.h} +0 -0
package/cpp/ggml-cpu/ops.cpp
CHANGED
@@ -108,7 +108,7 @@ static void lm_ggml_compute_forward_dup_f16(
         for (int i01 = ir0; i01 < ir1; i01++) {
             const lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
             for (int i00 = 0; i00 < ne00; i00++) {
-                dst_ptr[id] =
+                dst_ptr[id] = LM_GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
                 id++;
             }
         }
@@ -130,7 +130,7 @@ static void lm_ggml_compute_forward_dup_f16(
             const lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

             for (int i00 = 0; i00 < ne00; i00++) {
-                src0_f32[i00] =
+                src0_f32[i00] = LM_GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
             }

             quantize_row_q(src0_f32, dst_ptr + id, ne00);
@@ -156,7 +156,7 @@ static void lm_ggml_compute_forward_dup_f16(
             for (int i00 = 0; i00 < ne00; i00++) {
                 const lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

-                dst_ptr[id] =
+                dst_ptr[id] = LM_GGML_CPU_FP16_TO_FP32(*src0_ptr);
                 id++;
             }
         }
@@ -267,7 +267,7 @@ static void lm_ggml_compute_forward_dup_f16(
             const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
             char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

-            *(float *) dst_ptr =
+            *(float *) dst_ptr = LM_GGML_CPU_FP16_TO_FP32(*(const lm_ggml_fp16_t *) src0_ptr);

             if (++i10 == ne0) {
                 i10 = 0;
@@ -372,7 +372,7 @@ static void lm_ggml_compute_forward_dup_bf16(
         for (int i01 = ir0; i01 < ir1; i01++) {
             const lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
             for (int i00 = 0; i00 < ne00; i00++) {
-                dst_ptr[id] =
+                dst_ptr[id] = LM_GGML_CPU_FP32_TO_FP16(LM_GGML_BF16_TO_FP32(src0_ptr[i00]));
                 id++;
             }
         }
@@ -473,7 +473,7 @@ static void lm_ggml_compute_forward_dup_bf16(
             for (int i00 = 0; i00 < ne00; i00++) {
                 const lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

-                dst_ptr[id] =
+                dst_ptr[id] = LM_GGML_CPU_FP32_TO_FP16(LM_GGML_BF16_TO_FP32(*src0_ptr));
                 id++;
             }
         }
@@ -566,7 +566,7 @@ static void lm_ggml_compute_forward_dup_bf16(
             const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
             char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

-            *(lm_ggml_fp16_t *) dst_ptr =
+            *(lm_ggml_fp16_t *) dst_ptr = LM_GGML_CPU_FP32_TO_FP16(LM_GGML_BF16_TO_FP32(*(const lm_ggml_bf16_t *) src0_ptr));

             if (++i10 == ne0) {
                 i10 = 0;
@@ -765,7 +765,7 @@ static void lm_ggml_compute_forward_dup_f32(
             for (int i00 = 0; i00 < ne00; i00++) {
                 const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

-                dst_ptr[id] =
+                dst_ptr[id] = LM_GGML_CPU_FP32_TO_FP16(*src0_ptr);
                 id++;
             }
         }
@@ -878,7 +878,7 @@ static void lm_ggml_compute_forward_dup_f32(
             const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
             char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

-            *(lm_ggml_fp16_t *) dst_ptr =
+            *(lm_ggml_fp16_t *) dst_ptr = LM_GGML_CPU_FP32_TO_FP16(*(const float *) src0_ptr);

             if (++i10 == ne0) {
                 i10 = 0;
@@ -1419,7 +1419,7 @@ static void lm_ggml_compute_forward_add1_f16_f32(
         lm_ggml_fp16_t * dst_ptr = (lm_ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
         lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
         for (int i = 0; i < ne0; i++) {
-            dst_ptr[i] =
+            dst_ptr[i] = LM_GGML_CPU_FP32_TO_FP16(LM_GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
         }
     }
 }
@@ -1435,7 +1435,7 @@ static void lm_ggml_compute_forward_add1_f16_f16(
     LM_GGML_ASSERT(lm_ggml_is_scalar(src1));

     // scalar to add
-    const float v =
+    const float v = LM_GGML_CPU_FP16_TO_FP32(*(lm_ggml_fp16_t *) src1->data);

     const int ith = params->ith;
     const int nth = params->nth;
@@ -1467,7 +1467,7 @@ static void lm_ggml_compute_forward_add1_f16_f16(
         lm_ggml_fp16_t * dst_ptr = (lm_ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
         lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
         for (int i = 0; i < ne0; i++) {
-            dst_ptr[i] =
+            dst_ptr[i] = LM_GGML_CPU_FP32_TO_FP16(LM_GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
         }
     }
 }
@@ -1889,7 +1889,7 @@ static void lm_ggml_compute_forward_sum_f16(
             }
         }
     }
-    ((lm_ggml_fp16_t *) dst->data)[0] =
+    ((lm_ggml_fp16_t *) dst->data)[0] = LM_GGML_CPU_FP32_TO_FP16(sum);
 }

 static void lm_ggml_compute_forward_sum_bf16(
@@ -2660,7 +2660,7 @@ static void lm_ggml_compute_forward_gelu_f16(
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
             const lm_ggml_fp16_t x = ((lm_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
-            const float v =
+            const float v = LM_GGML_CPU_FP16_TO_FP32(x);
             LM_GGML_UNUSED(v);
             assert(!isnan(v));
             assert(!isinf(v));
@@ -2763,7 +2763,7 @@ static void lm_ggml_compute_forward_gelu_erf_f16(
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
             const lm_ggml_fp16_t x = ((lm_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
-            const float v =
+            const float v = LM_GGML_CPU_FP16_TO_FP32(x);
             LM_GGML_UNUSED(v);
             assert(!isnan(v));
             assert(!isinf(v));
@@ -2866,7 +2866,7 @@ static void lm_ggml_compute_forward_gelu_quick_f16(
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
             const lm_ggml_fp16_t x = ((lm_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
-            const float v =
+            const float v = LM_GGML_CPU_FP16_TO_FP32(x);
             LM_GGML_UNUSED(v);
             assert(!isnan(v));
             assert(!isinf(v));
@@ -2969,7 +2969,7 @@ static void lm_ggml_compute_forward_silu_f16(
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
             const lm_ggml_fp16_t x = ((lm_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
-            const float v =
+            const float v = LM_GGML_CPU_FP16_TO_FP32(x);
             LM_GGML_UNUSED(v);
             assert(!isnan(v));
             assert(!isinf(v));
@@ -3163,7 +3163,7 @@ static void lm_ggml_compute_forward_silu_back_f16(
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
             const float x = ((lm_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
-            const float v =
+            const float v = LM_GGML_CPU_FP16_TO_FP32(x);
             LM_GGML_UNUSED(v);
             assert(!isnan(v));
             assert(!isinf(v));
@@ -4500,7 +4500,7 @@ static void lm_ggml_compute_forward_get_rows_back_f32_f16(

         for (int j = 0; j < nc; ++j) {
             lm_ggml_fp16_t v = ((lm_ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
-            ((float *) ((char *) dst->data + r*dst->nb[1]))[j] +=
+            ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += LM_GGML_CPU_FP16_TO_FP32(v);
         }
     }
 }
@@ -4792,7 +4792,7 @@ static void lm_ggml_compute_forward_soft_max_f32(
         if (mp_f32) {
             if (use_f16) {
                 for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*
+                    wp[i] += slope*LM_GGML_CPU_FP16_TO_FP32(mp_f16[i]);
                 }
             } else {
                 for (int i = 0; i < nc; ++i) {
@@ -5018,8 +5018,8 @@ static void lm_ggml_compute_forward_clamp_f16(
         lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + j*nb01);

         for (int i = 0; i < nc; i++) {
-            float v =
-            dst_ptr[i] =
+            float v = LM_GGML_CPU_FP16_TO_FP32(src0_ptr[i]);
+            dst_ptr[i] = LM_GGML_CPU_FP32_TO_FP16(MAX(MIN(v, max), min));
         }
     }
 }
@@ -5476,11 +5476,11 @@ static void lm_ggml_compute_forward_rope_f16(
                 const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
                 lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);

-                const float x0 =
-                const float x1 =
+                const float x0 = LM_GGML_CPU_FP16_TO_FP32(src[0]);
+                const float x1 = LM_GGML_CPU_FP16_TO_FP32(src[n_dims]);

-                dst_data[0] =
-                dst_data[n_dims] =
+                dst_data[0] = LM_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                dst_data[n_dims] = LM_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
             }
         } else {
             for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
@@ -5492,11 +5492,11 @@ static void lm_ggml_compute_forward_rope_f16(
                 const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
                 lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);

-                const float x0 =
-                const float x1 =
+                const float x0 = LM_GGML_CPU_FP16_TO_FP32(src[0]);
+                const float x1 = LM_GGML_CPU_FP16_TO_FP32(src[n_dims/2]);

-                dst_data[0] =
-                dst_data[n_dims/2] =
+                dst_data[0] = LM_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                dst_data[n_dims/2] = LM_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
             }
         }
     } else {
@@ -5507,11 +5507,11 @@ static void lm_ggml_compute_forward_rope_f16(
             const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
             lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

-            const float x0 =
-            const float x1 =
+            const float x0 = LM_GGML_CPU_FP16_TO_FP32(src[0]);
+            const float x1 = LM_GGML_CPU_FP16_TO_FP32(src[1]);

-            dst_data[0] =
-            dst_data[1] =
+            dst_data[0] = LM_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+            dst_data[1] = LM_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
         }
     }

@@ -5525,11 +5525,11 @@ static void lm_ggml_compute_forward_rope_f16(
             const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
             lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);

-            const float x0 =
-            const float x1 =
+            const float x0 = LM_GGML_CPU_FP16_TO_FP32(src[0]);
+            const float x1 = LM_GGML_CPU_FP16_TO_FP32(src[n_dims]);

-            dst_data[0] =
-            dst_data[n_dims] =
+            dst_data[0] = LM_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+            dst_data[n_dims] = LM_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
         }
     } else {
         for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
@@ -5640,7 +5640,7 @@ static void lm_ggml_compute_forward_conv_transpose_1d_f16_f32(
         for (int64_t i11 = 0; i11 < ne11; i11++) {
             const float * const src = (float *)((char *) src1->data + i11*nb11);
             for (int64_t i10 = 0; i10 < ne10; i10++) {
-                dst_data[i10*ne11 + i11] =
+                dst_data[i10*ne11 + i11] = LM_GGML_CPU_FP32_TO_FP16(src[i10]);
             }
         }
     }
@@ -5933,7 +5933,7 @@ static void lm_ggml_compute_forward_im2col_f16(
                 if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
                     dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
                 } else {
-                    dst_data[iic*(KH*KW) + ikh*KW + ikw] =
+                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = LM_GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]);
                 }
             }
         }
@@ -6109,7 +6109,7 @@ void lm_ggml_compute_forward_conv_transpose_2d(
             const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
             lm_ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
             for (int i10 = 0; i10 < ne10; i10++) {
-                dst_data[i10*ne12 + i12] =
+                dst_data[i10*ne12 + i12] = LM_GGML_CPU_FP32_TO_FP16(src[i10]);
             }
         }
     }
@@ -6358,7 +6358,7 @@ static void lm_ggml_compute_forward_pool_1d_sk_p0(
             case LM_GGML_OP_POOL_COUNT: LM_GGML_ABORT("fatal error");
         }
         for (int ki = 0; ki < k; ++ki) {
-            const float srow_j = (src->type == LM_GGML_TYPE_F32) ? ((const float*)srow)[j] :
+            const float srow_j = (src->type == LM_GGML_TYPE_F32) ? ((const float*)srow)[j] : LM_GGML_CPU_FP16_TO_FP32(((const lm_ggml_fp16_t*)srow)[j]);
             switch (op) {
                 case LM_GGML_OP_POOL_AVG: drow[i] += srow_j; break;
                 case LM_GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
@@ -6450,7 +6450,7 @@ void lm_ggml_compute_forward_pool_2d(
             for (int kx = 0; kx < k0; ++kx) {
                 int j = ix + kx;
                 if (j < 0 || j >= src->ne[0]) continue;
-                const float srow_j = (src->type == LM_GGML_TYPE_F32) ? ((const float*)srow)[j] :
+                const float srow_j = (src->type == LM_GGML_TYPE_F32) ? ((const float*)srow)[j] : LM_GGML_CPU_FP16_TO_FP32(((const lm_ggml_fp16_t*)srow)[j]);
                 switch (op) {
                     case LM_GGML_OP_POOL_AVG: *out += srow_j; break;
                     case LM_GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
@@ -6538,7 +6538,7 @@ void lm_ggml_compute_forward_pool_2d_back(
                 }

                 const float val = dst->type == LM_GGML_TYPE_F32 ?
-                    ((const float *) drowf)[j] :
+                    ((const float *) drowf)[j] : LM_GGML_CPU_FP16_TO_FP32(((const lm_ggml_fp16_t *) drowf)[j]);
                 if (val <= maxval) {
                     continue;
                 }
@@ -6558,7 +6558,7 @@ void lm_ggml_compute_forward_pool_2d_back(
             if (dst->type == LM_GGML_TYPE_F32) {
                 ((float *) drow)[j] += grad0;
             } else {
-                ((lm_ggml_fp16_t *) drow)[j] =
+                ((lm_ggml_fp16_t *) drow)[j] = LM_GGML_CPU_FP32_TO_FP16(grad0 + LM_GGML_CPU_FP16_TO_FP32(((const lm_ggml_fp16_t *) drow)[j]));
             }
         } else if (op == LM_GGML_OP_POOL_AVG) {
             const float grad = grad0 / ka;
@@ -6577,7 +6577,7 @@ void lm_ggml_compute_forward_pool_2d_back(
                 if (dst->type == LM_GGML_TYPE_F32) {
                     ((float *) drow)[j] += grad;
                 } else {
-                    ((lm_ggml_fp16_t *) drow)[j] +=
+                    ((lm_ggml_fp16_t *) drow)[j] += LM_GGML_CPU_FP32_TO_FP16(grad);
                 }
             }
         }
@@ -6793,6 +6793,73 @@ void lm_ggml_compute_forward_pad_reflect_1d(
     }
 }

+// lm_ggml_compute_forward_roll
+
+static int64_t lm_ggml_wrap_index(int64_t i, int64_t ne) {
+    if (i < 0) {
+        return i + ne;
+    } else if (i >= ne) {
+        return i - ne;
+    }
+    return i;
+}
+
+static void lm_ggml_compute_forward_roll_f32(
+        const lm_ggml_compute_params * params,
+        lm_ggml_tensor * dst) {
+
+    const lm_ggml_tensor * src0 = dst->src[0];
+    const float * src_data = (const float *) src0->data;
+    float * dst_data = (float *) dst->data;
+
+    LM_GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int s0 = lm_ggml_get_op_params_i32(dst, 0);
+    const int s1 = lm_ggml_get_op_params_i32(dst, 1);
+    const int s2 = lm_ggml_get_op_params_i32(dst, 2);
+    const int s3 = lm_ggml_get_op_params_i32(dst, 3);
+
+    const int64_t total = ne1 * ne2 * ne3;
+    const int64_t per_thread = (total + params->nth) / params->nth;
+    const int64_t start = params->ith * per_thread;
+    const int64_t end = std::min(start + per_thread, total);
+
+    for (int64_t i = start; i < end; ++i) {
+        const int64_t i1 = i % ne1;
+        const int64_t i2 = (i / ne1) % ne2;
+        const int64_t i3 = i / (ne2 * ne1);
+        float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof(float);
+
+        const int64_t i01 = lm_ggml_wrap_index(i1 - s1, ne01);
+        const int64_t i02 = lm_ggml_wrap_index(i2 - s2, ne02);
+        const int64_t i03 = lm_ggml_wrap_index(i3 - s3, ne03);
+        const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof(float);
+
+        const int64_t s = lm_ggml_wrap_index(-s0, ne00);
+        const int64_t n = ne00 - s;
+        lm_ggml_vec_cpy_f32(n, dst_row, src_row + s);
+        lm_ggml_vec_cpy_f32(s, dst_row + n, src_row);
+    }
+}
+
+void lm_ggml_compute_forward_roll(
+        const lm_ggml_compute_params * params,
+        lm_ggml_tensor * dst) {
+
+    const lm_ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case LM_GGML_TYPE_F32:
+            {
+                lm_ggml_compute_forward_roll_f32(params, dst);
+            } break;
+        default:
+            {
+                LM_GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // lm_ggml_compute_forward_arange

 static void lm_ggml_compute_forward_arange_f32(
@@ -7075,7 +7142,7 @@ static void lm_ggml_compute_forward_flash_attn_ext_f16(
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
         for (int64_t ic = 0; ic < nek1; ++ic) {
-            const float mv = mp ? slope*
+            const float mv = mp ? slope*LM_GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
             if (mv == -INFINITY) {
                 continue;
             }
@@ -7143,7 +7210,7 @@ static void lm_ggml_compute_forward_flash_attn_ext_f16(

         if (v->type == LM_GGML_TYPE_F16) {
             for (int64_t d = 0; d < DV; ++d) {
-                VKQ32[d] =
+                VKQ32[d] = LM_GGML_CPU_FP16_TO_FP32(VKQ16[d]);
             }
         }

@@ -7633,39 +7700,83 @@ static void lm_ggml_compute_forward_ssm_scan_f32(
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir = ir1 - ir0;

-
-    for (int
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+#ifdef __ARM_FEATURE_SVE
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+            float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                svfloat32_t vx_dt = LM_GGML_F32_VEC_SET1(x_dt);
+                svfloat32_t vdt_soft_plus = LM_GGML_F32_VEC_SET1(dt_soft_plus);
+                svfloat32_t r1_vector = LM_GGML_F32_VEC_ZERO;
+
+                for (int64_t k = 0; k < nc; k += svcntw()) {
+                    svfloat32_t vA = LM_GGML_F32_VEC_LOAD(&A[i1*nc + k]);
+                    svfloat32_t vB = LM_GGML_F32_VEC_LOAD(&B[k]);
+                    svfloat32_t vC = LM_GGML_F32_VEC_LOAD(&C[k]);
+                    svfloat32_t vs0 = LM_GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
+
+                    svfloat32_t t1 = LM_GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+                    t1 = exp_ps_sve(svptrue_b32(), t1);
+                    svfloat32_t t2 = LM_GGML_F32_VEC_MUL(vx_dt, vB);
+
+                    vs0 = LM_GGML_F32_VEC_FMA(vs0, t1, t2);
+                    r1_vector = LM_GGML_F32_VEC_ADD(LM_GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+                    LM_GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
+                }
+                y[i1] = LM_GGML_F32xt_REDUCE_ONE(r1_vector);
             }
-                y[i1] = sumf;
         }
     }
-
+#else
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+            float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                float sumf = 0.0f;
+                // d_state
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    int i = i0 + i1*nc;
+                    // state = prev_state * dA + dB * x
+                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                    // y = rowwise_dotprod(state, C)
+                    sumf += state * C[i0];
+                    s[i] = state;
+                }
+                y[i1] = sumf;
+            }
+        }
+    }
+#endif
 }

 void lm_ggml_compute_forward_ssm_scan(
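The scalar branch added above performs the Mamba-style selective state update per state element: state = prev_state * exp(dt_softplus * A) + B * (x * dt_softplus), with the row's output accumulated as a dot product against C. A tiny numeric illustration of one such update, using made-up values purely to show the arithmetic (not data from the package):

#include <cmath>
#include <cstdio>

int main() {
    // one state element, made-up inputs
    const float prev_state = 0.5f, A = -1.0f, B = 0.3f, C = 2.0f, x = 1.2f, dt = 0.1f;
    // softplus of dt, clamped exactly as in the kernel to avoid overflow in expf
    const float dt_soft_plus = dt <= 20.0f ? std::log1p(std::exp(dt)) : dt;
    const float x_dt  = x * dt_soft_plus;
    // state = prev_state * dA + dB * x
    const float state = prev_state * std::exp(dt_soft_plus * A) + B * x_dt;
    std::printf("state = %f, contribution to y = %f\n", state, state * C);
    return 0;
}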
@@ -8070,6 +8181,14 @@ static void lm_ggml_compute_forward_rwkv_wkv6_f32(
     #define LM_GGML_F32X_MUL LM_GGML_F32x16_MUL
     #define LM_GGML_F32X_FMA LM_GGML_F32x16_FMA
     #define WKV_VECTOR_SIZE 16
+#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+    #define LM_GGML_F32X LM_GGML_F32xt
+    #define LM_GGML_F32X_SET1 LM_GGML_F32xt_SET1
+    #define LM_GGML_F32X_LOAD LM_GGML_F32xt_LOAD
+    #define LM_GGML_F32X_STORE LM_GGML_F32xt_STORE
+    #define LM_GGML_F32X_MUL LM_GGML_F32xt_MUL
+    #define LM_GGML_F32X_FMA LM_GGML_F32xt_FMA
+    #define WKV_VECTOR_SIZE 8
 #elif defined(__ARM_NEON) && defined(__aarch64__)
     #define LM_GGML_F32X LM_GGML_F32x4
     #define LM_GGML_F32X_SET1 LM_GGML_F32x4_SET1
@@ -8081,7 +8200,13 @@ static void lm_ggml_compute_forward_rwkv_wkv6_f32(
 #endif

 #ifdef WKV_VECTOR_SIZE
-
+    int wkv_vector_size;
+    #if defined(__ARM_FEATURE_SVE)
+        wkv_vector_size = svcntw();
+    #else
+        wkv_vector_size = WKV_VECTOR_SIZE;
+    #endif
+    const int64_t vec_count = head_size / wkv_vector_size;

     for (int64_t t = 0; t < T; t++) {
         size_t t_offset = t * t_stride;
@@ -8111,7 +8236,7 @@ static void lm_ggml_compute_forward_rwkv_wkv6_f32(
             LM_GGML_F32X time_decay_vec = LM_GGML_F32X_SET1(time_decay_val);

             for (int64_t j = 0; j < vec_count; j++) {
-                size_t base_j = j *
+                size_t base_j = j * wkv_vector_size;
                 size_t t_h_j_offset = t_h_offset + base_j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

@@ -8136,7 +8261,7 @@ static void lm_ggml_compute_forward_rwkv_wkv6_f32(
             }

             // Handle remaining elements, this will not be used.
-            for (int64_t j = vec_count *
+            for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
                 size_t t_h_j_offset = t_h_offset + j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + j;
                 float v_val = v[t_h_j_offset];
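In the wkv6 hunks above (and in the gla hunks that follow) the vector width is no longer a pure compile-time constant: when SVE is available, the lane count is read at run time with svcntw(), and the #define is only a fallback. A small hedged sketch of that pattern (illustrative helper name, assumes an AArch64 toolchain when the SVE branch is compiled in):

#if defined(__ARM_FEATURE_SVE)
#include <arm_sve.h>
#endif

// Number of 32-bit float lanes to step by, resolved the same way
// as wkv_vector_size / gla_vector_size in the hunks above.
static int f32_step_lanes() {
#if defined(__ARM_FEATURE_SVE)
    return (int) svcntw();  // SVE vector length is only known at run time
#else
    return 8;               // stand-in for the compile-time *_VECTOR_SIZE fallback
#endif
}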
@@ -8272,6 +8397,14 @@ static void lm_ggml_compute_forward_gla_f32(
     #define LM_GGML_F32X_MUL LM_GGML_F32x16_MUL
     #define LM_GGML_F32X_FMA LM_GGML_F32x16_FMA
     #define GLA_VECTOR_SIZE 16
+#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+    #define LM_GGML_F32X LM_GGML_F32xt
+    #define LM_GGML_F32X_SET1 LM_GGML_F32xt_SET1
+    #define LM_GGML_F32X_LOAD LM_GGML_F32xt_LOAD
+    #define LM_GGML_F32X_STORE LM_GGML_F32xt_STORE
+    #define LM_GGML_F32X_MUL LM_GGML_F32xt_MUL
+    #define LM_GGML_F32X_FMA LM_GGML_F32xt_FMA
+    #define GLA_VECTOR_SIZE 8
 #elif defined(__ARM_NEON) && defined(__aarch64__)
     #define LM_GGML_F32X LM_GGML_F32x4
     #define LM_GGML_F32X_SET1 LM_GGML_F32x4_SET1
@@ -8283,7 +8416,13 @@ static void lm_ggml_compute_forward_gla_f32(
 #endif

 #ifdef GLA_VECTOR_SIZE
-
+    int gla_vector_size;
+    #if defined(__ARM_FEATURE_SVE)
+        gla_vector_size = svcntw();
+    #else
+        gla_vector_size = GLA_VECTOR_SIZE;
+    #endif
+    const int64_t vec_count = head_size / gla_vector_size;

     for (int64_t t = 0; t < T; t++) {
         size_t t_offset = t * t_stride;
@@ -8310,7 +8449,7 @@ static void lm_ggml_compute_forward_gla_f32(
             LM_GGML_F32X g_vec = LM_GGML_F32X_SET1(g_val);

             for (int64_t j = 0; j < vec_count; j++) {
-                size_t base_j = j *
+                size_t base_j = j * gla_vector_size;
                 size_t t_h_j_offset = t_h_offset + base_j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

@@ -8334,7 +8473,7 @@ static void lm_ggml_compute_forward_gla_f32(
             }

             // Handle remaining elements, this will not be used.
-            for (int64_t j = vec_count *
+            for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
                 size_t t_h_j_offset = t_h_offset + j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + j;
                 float v_val = v[t_h_j_offset];
@@ -8443,83 +8582,126 @@ static void lm_ggml_compute_forward_rwkv_wkv7_f32(
     int64_t h_stride_2d = head_size * head_size;

 #if defined(LM_GGML_SIMD)
-
-
-    int64_t
-
-
-
-
-
-    int64_t
-
-
-
-
-    int64_t
-
-
+    #if defined(__ARM_FEATURE_SVE)
+        // scalar Route to scalar implementation //TODO: Write SVE code
+        for (int64_t t = 0; t < T; t++) {
+            int64_t t_offset = t * t_stride;
+            int64_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                int64_t h_offset = h * h_stride;
+                int64_t t_h_offset = t_offset + h_offset;
+                int64_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    int64_t t_h_i_offset = t_h_offset + i;
+                    int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float v_val = v[t_h_i_offset];
+
+                    float sa = 0, result = 0;
+                    for (int64_t j = 0; j < head_size; j++) {
+                        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
+                    }

-
-
-
-
-
-
-
-
-
-
-
+                    for (int64_t j = 0; j < head_size; j++) {
+                        int64_t t_h_j_offset = t_h_offset + j;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float r_val = r[t_h_j_offset];
+                        float w_val = w[t_h_j_offset];
+                        float k_val = k[t_h_j_offset];
+                        float b_val = b[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                        result += state_cur[h_2d_i_j_offset] * r_val;
                     }
-
+                    dst_data[t_h_i_offset] = result;
                 }
+            }
+        }
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            int64_t t_offset = t * t_stride;
+            int64_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                int64_t h_offset = h * h_stride;
+                int64_t t_h_offset = t_offset + h_offset;
+                int64_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t ii = 0; ii < head_size; ii++) {
+                    int64_t t_h_i_offset = t_h_offset + ii;
+                    int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                    LM_GGML_F32_VEC v_vec = LM_GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+                    float sa = 0;
+                    {
+                        LM_GGML_F32_VEC sum[LM_GGML_F32_ARR] = { LM_GGML_F32_VEC_ZERO };
+                        LM_GGML_F32_VEC ax[LM_GGML_F32_ARR];
+                        LM_GGML_F32_VEC ay[LM_GGML_F32_ARR];
+                        for (int64_t j = 0; j < head_size; j += LM_GGML_F32_STEP) {
+                            for (int64_t kk = 0; kk < LM_GGML_F32_ARR; kk++) {
+                                ax[kk] = LM_GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * LM_GGML_F32_EPR]);
+                                ay[kk] = LM_GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * LM_GGML_F32_EPR]);
+                                sum[kk] = LM_GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
+                            }
+                        }
+                        LM_GGML_F32_VEC_REDUCE(sa, sum);
+                    }

-
+                    LM_GGML_F32_VEC sa_vec = LM_GGML_F32_VEC_SET1(sa);

-
-
-
-
-
-
+                    int64_t j = 0;
+                    LM_GGML_F32_VEC result_vec[LM_GGML_F32_ARR] = { LM_GGML_F32_VEC_ZERO };
+                    for (; j < head_size; j += LM_GGML_F32_STEP) {
+                        for (int64_t kk = 0; kk < LM_GGML_F32_ARR; kk++) {
+                            int64_t t_h_j_offset = t_h_offset + j + kk * LM_GGML_F32_EPR;
+                            int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * LM_GGML_F32_EPR;

-
-
-
-
+                            LM_GGML_F32_VEC r_vec = LM_GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                            LM_GGML_F32_VEC w_vec = LM_GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                            LM_GGML_F32_VEC k_vec = LM_GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                            LM_GGML_F32_VEC b_vec = LM_GGML_F32_VEC_LOAD(&b[t_h_j_offset]);

-
+                            k_vec = LM_GGML_F32_VEC_MUL(v_vec, k_vec);

-
-
-
-
-
+                            LM_GGML_F32_VEC state_vec = LM_GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                            // kv + s * decay + sa * b
+                            state_vec = LM_GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                            state_vec = LM_GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                            LM_GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);

-
+                            result_vec[kk] = LM_GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                        }
+                    }
+                    LM_GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                    // There shouldn't be left-overs though.
+                    for (; j < head_size; j++) {
+                        int64_t t_h_j_offset = t_h_offset + j;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float r_val = r[t_h_j_offset];
+                        float w_val = w[t_h_j_offset];
+                        float k_val = k[t_h_j_offset];
+                        float b_val = b[t_h_j_offset];
+                        float kv_val = v[t_h_i_offset] * k_val;
+
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                        dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
                     }
-                    }
-                    LM_GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
-
-                    // There shouldn't be left-overs though.
-                    for (; j < head_size; j++) {
-                        int64_t t_h_j_offset = t_h_offset + j;
-                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                        float r_val = r[t_h_j_offset];
-                        float w_val = w[t_h_j_offset];
-                        float k_val = k[t_h_j_offset];
-                        float b_val = b[t_h_j_offset];
-                        float kv_val = v[t_h_i_offset] * k_val;
-
-                        float prev_state_val = state_prev[h_2d_i_j_offset];
-                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
-                        dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
                 }
             }
         }
-
+    #endif
 #else
     for (int64_t t = 0; t < T; t++) {
         int64_t t_offset = t * t_stride;