cui-llama.rn 1.6.0 → 1.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +35 -7
- package/android/src/main/CMakeLists.txt +16 -11
- package/android/src/main/java/com/rnllama/LlamaContext.java +4 -1
- package/android/src/main/jni.cpp +20 -4
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/cpp/LICENSE +21 -0
- package/cpp/chat.cpp +1 -1
- package/cpp/common.cpp +17 -2
- package/cpp/common.h +7 -3
- package/cpp/ggml-alloc.c +4 -1
- package/cpp/ggml-cpp.h +1 -1
- package/cpp/ggml-cpu/amx/amx.cpp +221 -0
- package/cpp/ggml-cpu/amx/amx.h +8 -0
- package/cpp/ggml-cpu/amx/common.h +91 -0
- package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
- package/cpp/ggml-cpu/amx/mmq.h +10 -0
- package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/binary-ops.h +1 -1
- package/cpp/ggml-cpu/common.h +72 -0
- package/cpp/{ggml-cpu-aarch64.cpp → ggml-cpu/ggml-cpu-aarch64.cpp} +809 -101
- package/cpp/{ggml-cpu.c → ggml-cpu/ggml-cpu.c} +109 -42
- package/cpp/{ggml-cpu.cpp → ggml-cpu/ggml-cpu.cpp} +3 -0
- package/cpp/{ops.cpp → ggml-cpu/ops.cpp} +246 -160
- package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/ops.h +2 -20
- package/cpp/{sgemm.cpp → ggml-cpu/sgemm.cpp} +501 -0
- package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/simd-mappings.h +7 -3
- package/{ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/unary-ops.h +1 -1
- package/cpp/ggml-cpu.h +5 -0
- package/cpp/ggml-impl.h +16 -9
- package/cpp/ggml-llama-sim.metallib +0 -0
- package/cpp/ggml-llama.metallib +0 -0
- package/cpp/ggml-metal.m +492 -47
- package/cpp/ggml.c +134 -244
- package/cpp/ggml.h +61 -94
- package/cpp/json-schema-to-grammar.cpp +3 -0
- package/cpp/llama-arch.cpp +46 -17
- package/cpp/llama-arch.h +9 -0
- package/cpp/llama-batch.cpp +5 -1
- package/cpp/llama-batch.h +2 -1
- package/cpp/llama-chat.cpp +31 -10
- package/cpp/llama-chat.h +3 -2
- package/cpp/llama-context.cpp +104 -489
- package/cpp/llama-context.h +14 -30
- package/cpp/llama-graph.cpp +69 -62
- package/cpp/llama-graph.h +21 -18
- package/cpp/llama-hparams.h +5 -0
- package/cpp/llama-kv-cache.cpp +1497 -391
- package/cpp/llama-kv-cache.h +272 -80
- package/cpp/llama-memory.h +11 -1
- package/cpp/llama-model.cpp +502 -176
- package/cpp/llama-model.h +13 -3
- package/cpp/llama-sampling.cpp +2 -1
- package/cpp/llama-vocab.cpp +8 -1
- package/cpp/llama.h +14 -11
- package/cpp/rn-llama.cpp +20 -172
- package/cpp/rn-llama.h +1 -5
- package/ios/CMakeLists.txt +13 -10
- package/ios/RNLlama.h +6 -0
- package/ios/RNLlama.mm +5 -0
- package/ios/RNLlamaContext.mm +26 -28
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +7 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +61 -94
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +3 -2
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +14 -30
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +21 -18
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +5 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +272 -80
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +11 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +13 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +14 -11
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +1 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +7 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +61 -94
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +3 -2
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +14 -30
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +21 -18
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +5 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +272 -80
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +11 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +13 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +14 -11
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +1 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +7 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +61 -94
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +3 -2
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +14 -30
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +21 -18
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +5 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +272 -80
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +11 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +13 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +14 -11
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +1 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +7 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +61 -94
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +3 -2
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +14 -30
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +21 -18
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +5 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +272 -80
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +11 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +13 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +14 -11
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +1 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +4 -0
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +5 -0
- package/cpp/binary-ops.h +0 -16
- package/cpp/ops.h +0 -128
- package/cpp/simd-mappings.h +0 -888
- package/cpp/unary-ops.h +0 -28
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +0 -128
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +0 -14
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +0 -802
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +0 -128
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +0 -14
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +0 -28
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/vec.h +0 -802
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +0 -16
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +0 -128
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +0 -888
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
- /package/cpp/{binary-ops.cpp → ggml-cpu/binary-ops.cpp} +0 -0
- /package/cpp/{ggml-cpu-aarch64.h → ggml-cpu/ggml-cpu-aarch64.h} +0 -0
- /package/cpp/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -0
- /package/cpp/{ggml-cpu-quants.c → ggml-cpu/ggml-cpu-quants.c} +0 -0
- /package/cpp/{ggml-cpu-quants.h → ggml-cpu/ggml-cpu-quants.h} +0 -0
- /package/cpp/{ggml-cpu-traits.cpp → ggml-cpu/ggml-cpu-traits.cpp} +0 -0
- /package/cpp/{ggml-cpu-traits.h → ggml-cpu/ggml-cpu-traits.h} +0 -0
- /package/cpp/{sgemm.h → ggml-cpu/sgemm.h} +0 -0
- /package/cpp/{unary-ops.cpp → ggml-cpu/unary-ops.cpp} +0 -0
- /package/cpp/{vec.cpp → ggml-cpu/vec.cpp} +0 -0
- /package/cpp/{vec.h → ggml-cpu/vec.h} +0 -0
package/cpp/ggml-cpu/ops.h
CHANGED
@@ -65,6 +65,7 @@ void lm_ggml_compute_forward_conv_transpose_1d(const struct lm_ggml_compute_para
 void lm_ggml_compute_forward_im2col(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
 void lm_ggml_compute_forward_im2col_back_f32(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
 void lm_ggml_compute_forward_conv_transpose_2d(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
+void lm_ggml_compute_forward_conv_2d_dw(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
 void lm_ggml_compute_forward_pool_1d(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
 void lm_ggml_compute_forward_pool_2d(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
 void lm_ggml_compute_forward_pool_2d_back(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
@@ -96,29 +97,10 @@ void lm_ggml_compute_forward_add_rel_pos(const struct lm_ggml_compute_params * p
 void lm_ggml_compute_forward_rwkv_wkv6(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
 void lm_ggml_compute_forward_rwkv_wkv7(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
 void lm_ggml_compute_forward_gla(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
-void lm_ggml_compute_forward_map_unary(
-        const struct lm_ggml_compute_params * params,
-        struct lm_ggml_tensor * dst,
-        const lm_ggml_unary_op_f32_t fun);
-void lm_ggml_compute_forward_map_binary(
-        const struct lm_ggml_compute_params * params,
-        struct lm_ggml_tensor * dst,
-        const lm_ggml_binary_op_f32_t fun);
-void lm_ggml_compute_forward_map_custom1_f32(
-        const struct lm_ggml_compute_params * params,
-        struct lm_ggml_tensor * dst,
-        const lm_ggml_custom1_op_f32_t fun);
-void lm_ggml_compute_forward_map_custom2_f32(
-        const struct lm_ggml_compute_params * params,
-        struct lm_ggml_tensor * dst,
-        const lm_ggml_custom2_op_f32_t fun);
-void lm_ggml_compute_forward_map_custom3_f32(
-        const struct lm_ggml_compute_params * params,
-        struct lm_ggml_tensor * dst,
-        const lm_ggml_custom3_op_f32_t fun);
 void lm_ggml_compute_forward_map_custom1(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
 void lm_ggml_compute_forward_map_custom2(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
 void lm_ggml_compute_forward_map_custom3(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
+void lm_ggml_compute_forward_custom(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
 void lm_ggml_compute_forward_cross_entropy_loss(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
 void lm_ggml_compute_forward_cross_entropy_loss_back(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
 void lm_ggml_compute_forward_opt_step_adamw(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
package/cpp/ggml-cpu/sgemm.cpp
CHANGED
@@ -1054,6 +1054,493 @@ class tinyBLAS_Q0_AVX {
 } \
 } \

+template <typename TA, typename TB, typename TC>
+class tinyBLAS_BF16_PPC {
+  public:
+    tinyBLAS_BF16_PPC(int64_t k,
+        const TA *A, int64_t lda,
+        const TB *B, int64_t ldb,
+        TC *C, int64_t ldc,
+        int ith, int nth)
+        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+    }
+
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
+    }
+
+  private:
+    void vector_permute_store(vec_t *c, int numVec, unsigned char *vecOffset) {
+        vec_t t[8], s[8];
+        vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
+        vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
+        vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+
+        if (numVec == 2) {
+            t[0] = vec_perm(c[0], c[1], swiz1);
+            t[1] = vec_perm(c[2], c[3], swiz1);
+            s[0] = vec_perm(t[0], t[1], swiz3);
+            s[1] = vec_perm(t[0], t[1], swiz4);
+            vec_xst(s[0], 0, (vec_t*)vecOffset);
+            vec_xst(s[1], 0, (vec_t*)(vecOffset + 16));
+        } else if (numVec == 4) {
+            t[0] = vec_perm(c[0], c[1], swiz1);
+            t[1] = vec_perm(c[0], c[1], swiz2);
+            t[2] = vec_perm(c[2], c[3], swiz1);
+            t[3] = vec_perm(c[2], c[3], swiz2);
+            s[0] = vec_perm(t[0], t[2], swiz3);
+            s[1] = vec_perm(t[0], t[2], swiz4);
+            s[2] = vec_perm(t[1], t[3], swiz3);
+            s[3] = vec_perm(t[1], t[3], swiz4);
+            for (int i = 0; i < 4; ++i)
+                vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
+        } else if (numVec == 8) {
+            for (int i = 0; i < 4; i += 2) {
+                t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
+                t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
+            }
+            for (int i = 4; i < 8; i += 2) {
+                t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
+                t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
+            }
+            s[0] = vec_perm(t[0], t[2], swiz3);
+            s[1] = vec_perm(t[0], t[2], swiz4);
+            s[2] = vec_perm(t[1], t[3], swiz3);
+            s[3] = vec_perm(t[1], t[3], swiz4);
+            s[4] = vec_perm(t[4], t[6], swiz3);
+            s[5] = vec_perm(t[4], t[6], swiz4);
+            s[6] = vec_perm(t[5], t[7], swiz3);
+            s[7] = vec_perm(t[5], t[7], swiz4);
+            for (int i = 0; i < 8; ++i)
+                vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
+        }
+    }
+
+    void packNormal(const TA* a, int64_t lda, int rows, int cols, unsigned char* vec) {
+        int64_t i, j;
+        TA *aoffset = NULL;
+        unsigned char *vecOffset = NULL;
+        TA * aoffsets[8];
+        vector unsigned char c_arr[8];
+        aoffset = const_cast<TA*>(a);
+        vecOffset = vec;
+        j = (rows >> 3);
+        if (j > 0) {
+            do {
+                if (cols == 4) {
+                    aoffsets[0] = aoffset;
+                    for (int it = 1; it < 4; ++it)
+                        aoffsets[it] = aoffsets[it-1] + lda;
+                    aoffset += 4 * lda;
+                    for (int i = 0; i < 4; ++i)
+                        c_arr[i] = vec_xl(0, (vector unsigned char*)aoffsets[i]);
+                    vector_permute_store(c_arr, 4, vecOffset);
+                    for (int i = 0; i<4; i++)
+                        aoffsets[i] = aoffsets[i]+lda;
+                    vecOffset +=64;
+                }
+                i = (cols >> 3);
+                if (i > 0) {
+                    aoffsets[0] = aoffset;
+                    for (int it = 1; it < 8; ++it) {
+                        aoffsets[it] = aoffsets[it-1] + lda;
+                    }
+                    aoffset += 8 * lda;
+                    do {
+                        for (int it = 0; it < 8; ++it)
+                            c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
+                        vector_permute_store(c_arr, 8, vecOffset);
+                        for (int it = 0; it < 8; ++it)
+                            aoffsets[it] = aoffsets[it] + 8*lda;
+                        vecOffset += 128;
+                        i--;
+                    } while(i > 0);
+                }
+                j--;
+            } while(j > 0);
+        }
+        if (rows & 4) {
+            aoffsets[0] = aoffset;
+            for (int it = 1; it < 4; ++it)
+                aoffsets[it] = aoffsets[it-1] + lda;
+            aoffset += 4 * lda;
+            if (cols == 4) {
+                for (int it = 0; it < 4; ++it)
+                    c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
+                vector_permute_store(c_arr, 2, vecOffset);
+                for (int it = 0; it< 4; it++)
+                    aoffsets[it] = aoffsets[it] + lda;
+                vecOffset += 32;
+            }
+            i = (cols >> 3);
+            if (i > 0) {
+                do {
+                    for (int it = 0; it < 4; ++it)
+                        c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
+                    vector_permute_store(c_arr, 4, vecOffset);
+                    for (int it = 0; it< 4; it++)
+                        aoffsets[it] = aoffsets[it] + 8*lda;
+                    vecOffset += 64;
+                    i--;
+                } while(i > 0);
+            }
+        }
+        if (rows & 3) {
+            aoffsets[0] = aoffset;
+            for (int it = 1; it < 4; ++it)
+                aoffsets[it] = aoffsets[it-1] + lda;
+            if (cols == 4) {
+                switch(rows) {
+                    case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
+                    case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
+                    case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
+                            break;
+                }
+                vector_permute_store(c_arr, 2, vecOffset);
+                for (int it = 0; it< 4; it++)
+                    aoffsets[it] = aoffsets[it] + lda;
+                vecOffset += 32;
+            }
+            i = (cols >> 3);
+            if (i > 0) {
+                do {
+                    switch(rows) {
+                        case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
+                        case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
+                        case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
+                                break;
+                    }
+                    vector_permute_store(c_arr, 4, vecOffset);
+                    for (int it = 0; it <4; it++)
+                        aoffsets[it] = aoffsets[it] + 8* lda;
+                    vecOffset += 64;
+                    i--;
+                } while(i > 0);
+            }
+        }
+    }
+
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        int m_rem = MIN(m - m0, 8);
+        int n_rem = MIN(n - n0, 8);
+
+        if (m_rem >= 8 && n_rem >= 8) {
+            mc = 8;
+            nc = 8;
+            gemm<8,8>(m0, m, n0, n);
+        } else if (m_rem >= 4 && n_rem >= 8) {
+            mc = 4;
+            nc = 8;
+            gemm<4,8>(m0, m, n0, n);
+        } else if (m_rem >=8 && n_rem >=4){
+            mc = 8;
+            nc = 4;
+            gemm<8,4>(m0, m, n0, n);
+        } else if ((m_rem < 4) && (n_rem >= 8)) {
+            nc = 8;
+            switch(m_rem) {
+                case 1:
+                    mc = 1;
+                    gemm_Mx8<1>(m0, m, n0, n);
+                    break;
+                case 2:
+                    mc = 2;
+                    gemm_Mx8<2>(m0, m, n0, n);
+                    break;
+                case 3:
+                    mc = 3;
+                    gemm_Mx8<3>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        } else if (m_rem >= 4 && n_rem >= 4) {
+            mc = 4;
+            nc = 4;
+            gemm_small<4, 4>(m0, m, n0, n);
+        } else if ((m_rem > 4) && (n_rem < 4)) {
+            mc = 4;
+            switch(n_rem) {
+                case 1:
+                    nc = 1;
+                    gemm_small<4, 1>(m0, m, n0, n);
+                    break;
+                case 2:
+                    nc = 2;
+                    gemm_small<4, 2>(m0, m, n0, n);
+                    break;
+                case 3:
+                    nc = 3;
+                    gemm_small<4, 3>(m0, m, n0, n);
+                    break;
+
+                default:
+                    return;
+            }
+        } else {
+            switch((m_rem << 4) | n_rem) {
+                case 0x43:
+                    mc = 4;
+                    nc = 3;
+                    gemm_small<4, 3>(m0, m, n0, n);
+                    break;
+                case 0x42:
+                    mc = 4;
+                    nc = 2;
+                    gemm_small<4, 2>(m0, m, n0, n);
+                    break;
+                case 0x41:
+                    mc = 4;
+                    nc = 1;
+                    gemm_small<4, 1>(m0, m, n0, n);
+                    break;
+                case 0x34:
+                    mc = 3;
+                    nc = 4;
+                    gemm_small<3, 4>(m0, m, n0, n);
+                    break;
+                case 0x33:
+                    mc = 3;
+                    nc = 3;
+                    gemm_small<3, 3>(m0, m, n0, n);
+                    break;
+                case 0x32:
+                    mc = 3;
+                    nc = 2;
+                    gemm_small<3, 2>(m0, m, n0, n);
+                    break;
+                case 0x31:
+                    mc = 3;
+                    nc = 1;
+                    gemm_small<3, 1>(m0, m, n0, n);
+                    break;
+                case 0x24:
+                    mc = 2;
+                    nc = 4;
+                    gemm_small<2,4>(m0, m, n0, n);
+                    break;
+                case 0x23:
+                    mc = 2;
+                    nc = 3;
+                    gemm_small<2, 3>(m0, m, n0, n);
+                    break;
+                case 0x22:
+                    mc = 2;
+                    nc = 2;
+                    gemm_small<2, 2>(m0, m, n0, n);
+                    break;
+                case 0x21:
+                    mc = 2;
+                    nc = 1;
+                    gemm_small<2, 1>(m0, m, n0, n);
+                    break;
+                case 0x14:
+                    mc = 1;
+                    nc = 4;
+                    gemm_small<1, 4>(m0, m, n0, n);
+                    break;
+                case 0x13:
+                    mc = 1;
+                    nc = 3;
+                    gemm_small<1, 3>(m0, m, n0, n);
+                    break;
+                case 0x12:
+                    mc = 1;
+                    nc = 2;
+                    gemm_small<1, 2>(m0, m, n0, n);
+                    break;
+                case 0x11:
+                    mc = 1;
+                    nc = 1;
+                    gemm_small<1, 1>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        }
+        mp = m0 + (m - m0) / mc * mc;
+        np = n0 + (n - n0) / nc * nc;
+        mnpack(mp, m, n0, np);
+        mnpack(m0, m, np, n);
+    }
+
+    void KERNEL_4x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[4], vec_B[8] , vec_C[4];
+        acc_t acc_0, acc_1;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        for (int l = 0; l < k; l+=8) {
+            packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
+            packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
+            for (int x = 0; x < 4; x++) {
+                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
+            }
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii, jj+4);
+    }
+
+    void KERNEL_8x4(int64_t ii, int64_t jj) {
+        vec_t vec_A[8], vec_B[4] , vec_C[4];
+        acc_t acc_0, acc_1;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        for (int l = 0; l < k; l+=8) {
+            packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
+            packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
+            for (int x = 0; x < 4; x++) {
+                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
+            }
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii+4, jj);
+    }
+
+
+    void KERNEL_8x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[8], vec_B[8], vec_C[4];
+        acc_t acc_0, acc_1, acc_2, acc_3;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        __builtin_mma_xxsetaccz(&acc_2);
+        __builtin_mma_xxsetaccz(&acc_3);
+        for (int l = 0; l < k; l+=8) {
+            packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
+            packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
+            for (int x = 0; x < 4; x++) {
+                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]);
+                __builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]);
+            }
+        }
+
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii, jj+4);
+        SAVE_ACC(&acc_2, ii+4, jj);
+        SAVE_ACC(&acc_3, ii+4, jj+4);
+    }
+
+    template<int RM, int RN>
+    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            vec_t vec_C[4];
+            acc_t acc_0;
+            __builtin_mma_xxsetaccz(&acc_0);
+            vec_t vec_A[2], vec_B[2];
+            for (int l=0; l<k; l+=4) {
+                packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
+                packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
+                for (int x = 0; x<2; x++) {
+                    __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                }
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_0);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < RN; J++) {
+                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                }
+            }
+        }
+    }
+
+    template<int RM>
+    void gemm_Mx8(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int RN = 8;
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            vec_t vec_C[4];
+            acc_t acc_0, acc_1;
+            __builtin_mma_xxsetaccz(&acc_0);
+            __builtin_mma_xxsetaccz(&acc_1);
+            vec_t vec_A[4], vec_B[8];
+            for (int l=0; l<k; l+=8) {
+                packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
+                packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
+                for (int x = 0; x<4; x++) {
+                    __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                    __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
+                }
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_0);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < 4; J++) {
+                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                }
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_1);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < 4; J++) {
+                    *((TC*)(C+ii+((jj+4+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                }
+            }
+        }
+    }
+
+    template<int RM, int RN>
+    inline void kernel(int64_t ii, int64_t jj) {
+        if constexpr(RM == 4 && RN == 8) {
+            KERNEL_4x8(ii,jj);
+        } else if constexpr(RM == 8 && RN == 8) {
+            KERNEL_8x8(ii,jj);
+        } else if constexpr(RM == 8 && RN == 4) {
+            KERNEL_8x4(ii,jj);
+        } else {
+            static_assert(false, "RN/RM values not supported");
+        }
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            kernel<RM, RN>(ii, jj);
+        }
+    }
+
+    const TA *const A;
+    const TB *const B;
+    TC *C;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
+
 template <typename TA, typename TB, typename TC>
 class tinyBLAS_Q0_PPC {
   public:
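Note: the kernel drivers above (gemm, gemm_small, gemm_Mx8) all split the output into RM x RN tiles and hand the tiles out to threads through the ith/nth pair. A minimal standalone sketch of that partitioning arithmetic, with illustrative sizes and names that are not part of the diff:

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t m = 32, n = 24;              // output dimensions (illustrative)
        const int64_t RM = 8, RN = 8;              // tile shape, as in gemm<8,8>
        const int nth = 4;                         // number of threads
        int64_t ytiles = m / RM;                   // tile rows
        int64_t xtiles = n / RN;                   // tile columns
        int64_t tiles  = xtiles * ytiles;
        int64_t duty   = (tiles + nth - 1) / nth;  // ceil(tiles / nth) tiles per thread
        for (int ith = 0; ith < nth; ith++) {
            int64_t start = duty * ith;
            int64_t end   = start + duty > tiles ? tiles : start + duty;
            for (int64_t job = start; job < end; ++job) {
                int64_t ii = job / xtiles * RM;    // tile origin row in C
                int64_t jj = job % xtiles * RN;    // tile origin column in C
                printf("thread %d -> tile (%lld, %lld)\n",
                       ith, (long long)ii, (long long)jj);
            }
        }
        return 0;
    }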
@@ -2202,6 +2689,7 @@ class tinyBLAS_PPC {
         boffset = vec;
         j = (rows >> 3);
         if (j > 0) {
+
             do {
                 aoffset1 = aoffset;
                 aoffset2 = aoffset1 + lda;
@@ -2875,9 +3363,22 @@ bool llamafile_sgemm(const struct lm_ggml_compute_params * params, int64_t m, in
             (float *)C, ldc};
         return tb.matmul(m, n);
     }
+#elif defined(__MMA__)
+        if ((k % 8))
+            return false;
+        if(Btype == LM_GGML_TYPE_BF16) {
+           tinyBLAS_BF16_PPC<lm_ggml_bf16_t, lm_ggml_bf16_t, float> tb{ k,
+                        (const lm_ggml_bf16_t *)A, lda,
+                        (const lm_ggml_bf16_t *)B, ldb,
+                        (float *)C, ldc,
+                        params->ith, params->nth};
+           tb.matmul(m, n);
+           return true;
+        }
 #endif
         return false;
     }
+
     case LM_GGML_TYPE_F16: {
 #if defined(__AVX512F__)
         if (Btype == LM_GGML_TYPE_F16) {
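Note: the new __MMA__ branch only accelerates LM_GGML_TYPE_BF16 inputs whose inner dimension k is a multiple of 8; for anything else llamafile_sgemm returns false and the caller is expected to fall back to the generic mat-mul path. A hypothetical caller-side sketch, assuming the llamafile_sgemm signature suggested by the hunk header above:

    #include <stdbool.h>
    #include <stdint.h>
    #include "sgemm.h"

    // Try the accelerated path; the return value tells the caller whether
    // the generic lm_ggml matrix multiplication must run instead.
    static bool try_fast_gemm(const struct lm_ggml_compute_params *params,
                              int64_t m, int64_t n, int64_t k,
                              const void *A, int64_t lda,
                              const void *B, int64_t ldb,
                              void *C, int64_t ldc,
                              int Atype, int Btype, int Ctype) {
        return llamafile_sgemm(params, m, n, k, A, lda, B, ldb, C, ldc,
                               Atype, Btype, Ctype);
    }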
package/cpp/ggml-cpu/simd-mappings.h
CHANGED
@@ -341,7 +341,7 @@ static inline void __avx_f32cx8_store(lm_ggml_fp16_t *x, __m256 y) {
 #define LM_GGML_F32_EPR 4

 #define LM_GGML_F32x4 vector float
-#define LM_GGML_F32x4_ZERO 0.0f
+#define LM_GGML_F32x4_ZERO {0.0f}
 #define LM_GGML_F32x4_SET1 vec_splats
 #define LM_GGML_F32x4_LOAD(p) vec_xl(0, p)
 #define LM_GGML_F32x4_STORE(p, r) vec_xst(r, 0, p)
@@ -855,13 +855,17 @@ static inline __vector float __lzs_f16cx4_load(const lm_ggml_fp16_t * x) {
         tmp[i] = LM_GGML_FP16_TO_FP32(x[i]);
     }

-    return vec_xl(0, tmp);
+    // note: keep type-cast here to prevent compiler bugs
+    // see: https://github.com/ggml-org/llama.cpp/issues/12846
+    return vec_xl(0, (const float *)(tmp));
 }

 static inline void __lzs_f16cx4_store(lm_ggml_fp16_t * x, __vector float y) {
     float arr[4];

-    vec_xst(y, 0, arr);
+    // note: keep type-cast here to prevent compiler bugs
+    // see: https://github.com/ggml-org/llama.cpp/issues/12846
+    vec_xst(y, 0, (float *)(arr));

     for (int i = 0; i < 4; i++) {
         x[i] = LM_GGML_FP32_TO_FP16(arr[i]);
package/cpp/ggml-cpu.h
CHANGED
@@ -133,6 +133,11 @@ extern "C" {

 LM_GGML_BACKEND_API lm_ggml_backend_reg_t lm_ggml_backend_cpu_reg(void);

+LM_GGML_BACKEND_API void lm_ggml_cpu_fp32_to_fp16(const float *, lm_ggml_fp16_t *, int64_t);
+LM_GGML_BACKEND_API void lm_ggml_cpu_fp16_to_fp32(const lm_ggml_fp16_t *, float *, int64_t);
+LM_GGML_BACKEND_API void lm_ggml_cpu_fp32_to_bf16(const float *, lm_ggml_bf16_t *, int64_t);
+LM_GGML_BACKEND_API void lm_ggml_cpu_bf16_to_fp32(const lm_ggml_bf16_t *, float *, int64_t);
+
 #ifdef __cplusplus
 }
 #endif
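Note: these four declarations export the CPU backend's bulk float conversions. A minimal usage sketch (buffer names are illustrative; signatures as declared above):

    #include "ggml-cpu.h"

    // Convert n fp32 values to fp16 and back with the newly exported helpers.
    void roundtrip_fp16(const float *src, lm_ggml_fp16_t *tmp,
                        float *dst, int64_t n) {
        lm_ggml_cpu_fp32_to_fp16(src, tmp, n);
        lm_ggml_cpu_fp16_to_fp32(tmp, dst, n);
    }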
package/cpp/ggml-impl.h
CHANGED
@@ -16,6 +16,14 @@
 #include <arm_sve.h>
 #endif // __ARM_FEATURE_SVE

+#if defined(__ARM_NEON) && !defined(__CUDACC__) && !defined(__MUSACC__)
+// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
+//
+//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
+//
+#include <arm_neon.h>
+#endif
+
 #if defined(__F16C__)
 #include <immintrin.h>
 #endif
@@ -140,8 +148,14 @@ struct lm_ggml_map_custom2_op_params {

 struct lm_ggml_map_custom3_op_params {
     lm_ggml_custom3_op_t fun;
-    int n_tasks;
-    void * userdata;
+    int                  n_tasks;
+    void               * userdata;
+};
+
+struct lm_ggml_custom_op_params {
+    lm_ggml_custom_op_t fun;
+    int                 n_tasks;
+    void              * userdata;
 };

 // bitset
@@ -311,13 +325,6 @@ LM_GGML_API void lm_ggml_aligned_free(void * ptr, size_t size);
 // for MUSA compilers , we use uint16_t: ref https://github.com/ggml-org/llama.cpp/pull/11843
 //
 #if defined(__ARM_NEON) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) && !defined(__MUSACC__)
-
-// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
-//
-//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
-//
-#include <arm_neon.h>
-
 #define LM_GGML_COMPUTE_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x)
 #define LM_GGML_COMPUTE_FP32_TO_FP16(x) lm_ggml_compute_fp32_to_fp16(x)

package/cpp/ggml-llama-sim.metallib
CHANGED
Binary file

package/cpp/ggml-llama.metallib
CHANGED
Binary file