@novastera-oss/llamarn 0.2.6 → 0.2.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +134 -36
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +2 -2
- package/cpp/LlamaCppModel.h +3 -3
- package/cpp/PureCppImpl.cpp +1 -1
- package/cpp/PureCppImpl.h +2 -2
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +15 -4
- package/cpp/llama.cpp/Makefile +2 -2
- package/cpp/llama.cpp/README.md +32 -13
- package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
- package/cpp/llama.cpp/common/arg.cpp +30 -6
- package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
- package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
- package/cpp/llama.cpp/common/chat-parser.h +2 -0
- package/cpp/llama.cpp/common/chat.cpp +12 -9
- package/cpp/llama.cpp/common/chat.h +1 -1
- package/cpp/llama.cpp/common/common.cpp +50 -40
- package/cpp/llama.cpp/common/common.h +5 -2
- package/cpp/llama.cpp/common/speculative.cpp +6 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +97 -56
- package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
- package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +47 -13
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +5 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +6 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -38
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +431 -247
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +0 -6
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
- package/cpp/llama.cpp/include/llama.h +134 -36
- package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
- package/cpp/llama.cpp/src/llama-arch.h +7 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +270 -19
- package/cpp/llama.cpp/src/llama-batch.h +36 -11
- package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +313 -213
- package/cpp/llama.cpp/src/llama-context.h +16 -12
- package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
- package/cpp/llama.cpp/src/llama-cparams.h +1 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +249 -129
- package/cpp/llama.cpp/src/llama-graph.h +90 -34
- package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
- package/cpp/llama.cpp/src/llama-hparams.h +8 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +82 -50
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +292 -174
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +68 -38
- package/cpp/llama.cpp/src/llama-kv-cells.h +18 -13
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +266 -282
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +54 -57
- package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
- package/cpp/llama.cpp/src/llama-memory.h +64 -23
- package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
- package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
- package/cpp/llama.cpp/src/llama-model.cpp +726 -141
- package/cpp/llama.cpp/src/llama-model.h +4 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
- package/cpp/llama.cpp/src/llama-vocab.cpp +32 -23
- package/cpp/llama.cpp/src/llama.cpp +11 -7
- package/cpp/llama.cpp/src/unicode.cpp +5 -0
- package/cpp/rn-completion.cpp +2 -2
- package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
- package/ios/include/chat.h +1 -1
- package/ios/include/common.h +5 -2
- package/ios/include/llama.h +134 -36
- package/ios/libs/llama.xcframework/Info.plist +18 -18
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3624
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4725
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4746
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3652
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -2
- package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
- package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /package/cpp/{rn-utils.hpp → rn-utils.h} +0 -0
|
@@ -993,31 +993,61 @@ kernel void kernel_neg(
|
|
|
993
993
|
dst[tpig] = -src0[tpig];
|
|
994
994
|
}
|
|
995
995
|
|
|
996
|
+
template <bool norm>
|
|
996
997
|
kernel void kernel_sum_rows(
|
|
998
|
+
constant ggml_metal_kargs_sum_rows & args,
|
|
997
999
|
device const float * src0,
|
|
998
1000
|
device float * dst,
|
|
999
|
-
|
|
1000
|
-
uint3
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
|
|
1001
|
+
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
|
1002
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1003
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
1004
|
+
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
|
1005
|
+
ushort tiisg[[thread_index_in_simdgroup]],
|
|
1006
|
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
1007
|
+
int64_t i3 = tgpig.z;
|
|
1008
|
+
int64_t i2 = tgpig.y;
|
|
1009
|
+
int64_t i1 = tgpig.x;
|
|
1004
1010
|
|
|
1005
1011
|
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
|
1006
1012
|
return;
|
|
1007
1013
|
}
|
|
1008
1014
|
|
|
1015
|
+
if (sgitg == 0) {
|
|
1016
|
+
shmem_f32[tiisg] = 0.0f;
|
|
1017
|
+
}
|
|
1018
|
+
|
|
1009
1019
|
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
|
1010
1020
|
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
|
1011
1021
|
|
|
1012
|
-
float
|
|
1022
|
+
float sumf = 0;
|
|
1023
|
+
|
|
1024
|
+
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
|
1025
|
+
sumf += src_row[i0];
|
|
1026
|
+
}
|
|
1027
|
+
|
|
1028
|
+
sumf = simd_sum(sumf);
|
|
1029
|
+
|
|
1030
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1013
1031
|
|
|
1014
|
-
|
|
1015
|
-
|
|
1032
|
+
if (tiisg == 0) {
|
|
1033
|
+
shmem_f32[sgitg] = sumf;
|
|
1016
1034
|
}
|
|
1017
1035
|
|
|
1018
|
-
|
|
1036
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1037
|
+
|
|
1038
|
+
sumf = shmem_f32[tiisg];
|
|
1039
|
+
sumf = simd_sum(sumf);
|
|
1040
|
+
|
|
1041
|
+
if (tpitg.x == 0) {
|
|
1042
|
+
dst_row[0] = norm ? sumf / args.ne00 : sumf;
|
|
1043
|
+
}
|
|
1019
1044
|
}
|
|
1020
1045
|
|
|
1046
|
+
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
|
|
1047
|
+
|
|
1048
|
+
template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
|
|
1049
|
+
template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
|
|
1050
|
+
|
|
1021
1051
|
template<typename T>
|
|
1022
1052
|
kernel void kernel_soft_max(
|
|
1023
1053
|
device const char * src0,
|
|
@@ -3328,14 +3358,12 @@ kernel void kernel_flash_attn_ext(
|
|
|
3328
3358
|
constexpr short NW = N_SIMDWIDTH;
|
|
3329
3359
|
constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
|
|
3330
3360
|
|
|
3331
|
-
const short TS = nsg*SH;
|
|
3332
|
-
const short T = DK + 2*TS; // shared memory size per query in (half)
|
|
3361
|
+
const short TS = nsg*SH; // shared memory size per query in (s_t == float)
|
|
3362
|
+
const short T = 2*DK + 2*TS; // shared memory size per query in (half)
|
|
3333
3363
|
|
|
3334
|
-
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 +
|
|
3335
|
-
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +
|
|
3336
|
-
threadgroup
|
|
3337
|
-
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
|
|
3338
|
-
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
|
|
3364
|
+
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
|
|
3365
|
+
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
|
|
3366
|
+
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix
|
|
3339
3367
|
|
|
3340
3368
|
threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
|
|
3341
3369
|
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
|
|
@@ -3354,7 +3382,7 @@ kernel void kernel_flash_attn_ext(
|
|
|
3354
3382
|
if (iq1 + j < args.ne01) {
|
|
3355
3383
|
sq4[j*DK4 + i] = (q4_t) q4[i];
|
|
3356
3384
|
} else {
|
|
3357
|
-
sq4[j*DK4 + i] =
|
|
3385
|
+
sq4[j*DK4 + i] = 0;
|
|
3358
3386
|
}
|
|
3359
3387
|
}
|
|
3360
3388
|
}
|
|
@@ -3548,20 +3576,20 @@ kernel void kernel_flash_attn_ext(
|
|
|
3548
3576
|
|
|
3549
3577
|
// O = diag(ms)*O
|
|
3550
3578
|
{
|
|
3551
|
-
s8x8_t
|
|
3552
|
-
simdgroup_load(
|
|
3579
|
+
s8x8_t ms;
|
|
3580
|
+
simdgroup_load(ms, ss + 2*C, TS, 0, false);
|
|
3553
3581
|
|
|
3554
3582
|
#pragma unroll(DV8)
|
|
3555
3583
|
for (short i = 0; i < DV8; ++i) {
|
|
3556
|
-
simdgroup_multiply(lo[i],
|
|
3584
|
+
simdgroup_multiply(lo[i], ms, lo[i]);
|
|
3557
3585
|
}
|
|
3558
3586
|
}
|
|
3559
3587
|
|
|
3560
3588
|
// O = O + (Q*K^T)*V
|
|
3561
3589
|
{
|
|
3562
3590
|
for (short cc = 0; cc < C/8; ++cc) {
|
|
3563
|
-
s8x8_t
|
|
3564
|
-
simdgroup_load(
|
|
3591
|
+
s8x8_t vs;
|
|
3592
|
+
simdgroup_load(vs, ss + 8*cc, TS, 0, false);
|
|
3565
3593
|
|
|
3566
3594
|
if (is_same<vd4x4_t, v4x4_t>::value) {
|
|
3567
3595
|
// we can read directly from global memory
|
|
@@ -3572,7 +3600,7 @@ kernel void kernel_flash_attn_ext(
|
|
|
3572
3600
|
v8x8_t mv;
|
|
3573
3601
|
simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
|
|
3574
3602
|
|
|
3575
|
-
simdgroup_multiply_accumulate(lo[i],
|
|
3603
|
+
simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]);
|
|
3576
3604
|
}
|
|
3577
3605
|
} else {
|
|
3578
3606
|
for (short ii = 0; ii < DV16; ii += 4) {
|
|
@@ -3593,10 +3621,10 @@ kernel void kernel_flash_attn_ext(
|
|
|
3593
3621
|
v8x8_t mv;
|
|
3594
3622
|
|
|
3595
3623
|
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
|
|
3596
|
-
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0],
|
|
3624
|
+
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
|
|
3597
3625
|
|
|
3598
3626
|
simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
|
|
3599
|
-
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1],
|
|
3627
|
+
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
|
|
3600
3628
|
}
|
|
3601
3629
|
} else {
|
|
3602
3630
|
if (ii + tx < DV16) {
|
|
@@ -3611,10 +3639,10 @@ kernel void kernel_flash_attn_ext(
|
|
|
3611
3639
|
v8x8_t mv;
|
|
3612
3640
|
|
|
3613
3641
|
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
|
|
3614
|
-
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0],
|
|
3642
|
+
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
|
|
3615
3643
|
|
|
3616
3644
|
simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
|
|
3617
|
-
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1],
|
|
3645
|
+
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
|
|
3618
3646
|
}
|
|
3619
3647
|
}
|
|
3620
3648
|
}
|
|
@@ -3624,93 +3652,89 @@ kernel void kernel_flash_attn_ext(
|
|
|
3624
3652
|
}
|
|
3625
3653
|
|
|
3626
3654
|
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
|
3627
|
-
for (short j =
|
|
3628
|
-
|
|
3629
|
-
|
|
3630
|
-
ss[j*TS + 1] = M[j];
|
|
3631
|
-
}
|
|
3655
|
+
for (short j = tiisg; j < Q; j += NW) {
|
|
3656
|
+
ss[j*TS + 0] = S[j];
|
|
3657
|
+
ss[j*TS + 1] = M[j];
|
|
3632
3658
|
}
|
|
3633
3659
|
}
|
|
3634
3660
|
|
|
3635
|
-
|
|
3636
|
-
for (ushort sg = 1; sg < nsg; ++sg) {
|
|
3637
|
-
float S = { 0.0f };
|
|
3638
|
-
float M = { -__FLT_MAX__/2 };
|
|
3661
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
3639
3662
|
|
|
3640
|
-
|
|
3663
|
+
threadgroup float * so = (threadgroup float *) (shmem_f16 + 0*DK); // reuse query data for accumulation
|
|
3664
|
+
threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK);
|
|
3641
3665
|
|
|
3642
|
-
|
|
3643
|
-
|
|
3644
|
-
|
|
3645
|
-
|
|
3646
|
-
|
|
3666
|
+
// store result to shared memory in F32
|
|
3667
|
+
if (sgitg == 0) {
|
|
3668
|
+
for (short i = 0; i < DV8; ++i) {
|
|
3669
|
+
//simdgroup_store(lo[i], so + i*8, DV, 0, false);
|
|
3670
|
+
simdgroup_float8x8 t(1.0f);
|
|
3671
|
+
simdgroup_multiply(t, lo[i], t);
|
|
3672
|
+
simdgroup_store(t, so + i*8, DV, 0, false);
|
|
3647
3673
|
}
|
|
3674
|
+
}
|
|
3648
3675
|
|
|
3649
|
-
|
|
3676
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
3650
3677
|
|
|
3651
|
-
|
|
3652
|
-
|
|
3653
|
-
|
|
3654
|
-
|
|
3655
|
-
const float
|
|
3678
|
+
// reduce the warps sequentially
|
|
3679
|
+
for (ushort sg = 1; sg < nsg; ++sg) {
|
|
3680
|
+
if (sgitg == sg) {
|
|
3681
|
+
for (short j = tiisg; j < Q; j += NW) {
|
|
3682
|
+
const float S0 = ss[j*TS - 1*SH + 0];
|
|
3683
|
+
const float S1 = ss[j*TS + 0];
|
|
3656
3684
|
|
|
3657
|
-
const float M0 = ss[j*TS +
|
|
3658
|
-
const float M1 = ss[j*TS
|
|
3685
|
+
const float M0 = ss[j*TS - 1*SH + 1];
|
|
3686
|
+
const float M1 = ss[j*TS + 1];
|
|
3659
3687
|
|
|
3660
|
-
M = max(M0, M1);
|
|
3688
|
+
const float M = max(M0, M1);
|
|
3661
3689
|
|
|
3662
|
-
|
|
3663
|
-
|
|
3690
|
+
float ms0 = exp(M0 - M);
|
|
3691
|
+
float ms1 = exp(M1 - M);
|
|
3664
3692
|
|
|
3665
|
-
S = S0*ms0 + S1*ms1;
|
|
3693
|
+
const float S = S0*ms0 + S1*ms1;
|
|
3666
3694
|
|
|
3667
|
-
|
|
3668
|
-
|
|
3669
|
-
ss[j*TS + 1] = M;
|
|
3695
|
+
ss[j*TS + 0] = S;
|
|
3696
|
+
ss[j*TS + 1] = M;
|
|
3670
3697
|
|
|
3671
|
-
|
|
3672
|
-
|
|
3673
|
-
}
|
|
3698
|
+
ss[j*TS + 2*C + j - 1*SH] = ms0;
|
|
3699
|
+
ss[j*TS + 2*C + j ] = ms1;
|
|
3674
3700
|
}
|
|
3675
3701
|
|
|
3702
|
+
//simdgroup_barrier(mem_flags::mem_threadgroup);
|
|
3703
|
+
|
|
3676
3704
|
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
|
3677
3705
|
{
|
|
3678
3706
|
s8x8_t ms0;
|
|
3679
3707
|
s8x8_t ms1;
|
|
3680
3708
|
|
|
3681
|
-
simdgroup_load(ms0, ss + 2*C,
|
|
3682
|
-
simdgroup_load(ms1, ss + 2*C
|
|
3709
|
+
simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false);
|
|
3710
|
+
simdgroup_load(ms1, ss + 2*C, TS, 0, false);
|
|
3683
3711
|
|
|
3684
3712
|
#pragma unroll(DV8)
|
|
3685
3713
|
for (short i = 0; i < DV8; ++i) {
|
|
3686
|
-
|
|
3714
|
+
simdgroup_float8x8 t;
|
|
3687
3715
|
|
|
3688
3716
|
simdgroup_load (t, so + i*8, DV, 0, false);
|
|
3689
|
-
simdgroup_multiply(t,
|
|
3717
|
+
simdgroup_multiply(t, ms0, t);
|
|
3690
3718
|
|
|
3691
|
-
simdgroup_multiply_accumulate(
|
|
3719
|
+
simdgroup_multiply_accumulate(t, ms1, lo[i], t);
|
|
3720
|
+
simdgroup_store(t, so + i*8, DV, 0, false);
|
|
3692
3721
|
}
|
|
3693
3722
|
}
|
|
3694
3723
|
}
|
|
3695
|
-
}
|
|
3696
3724
|
|
|
3697
|
-
|
|
3698
|
-
if (sgitg == 0) {
|
|
3699
|
-
for (short i = 0; i < DV8; ++i) {
|
|
3700
|
-
simdgroup_store(lo[i], so + i*8, DV, 0, false);
|
|
3701
|
-
}
|
|
3725
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
3702
3726
|
}
|
|
3703
3727
|
|
|
3704
|
-
|
|
3728
|
+
threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK);
|
|
3705
3729
|
|
|
3706
3730
|
// final rescale with 1/S and store to global memory
|
|
3707
|
-
|
|
3708
|
-
|
|
3709
|
-
const float S = ss[j*TS + 0];
|
|
3731
|
+
for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
|
|
3732
|
+
const float S = 1.0f/sf[j*TS + 0];
|
|
3710
3733
|
|
|
3711
|
-
|
|
3712
|
-
|
|
3713
|
-
|
|
3734
|
+
device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
|
|
3735
|
+
|
|
3736
|
+
for (short i = tiisg; i < DV4; i += NW) {
|
|
3737
|
+
dst4[i] = (float4) so4[j*DV4 + i]*S;
|
|
3714
3738
|
}
|
|
3715
3739
|
}
|
|
3716
3740
|
}
|
|
@@ -3719,12 +3743,22 @@ kernel void kernel_flash_attn_ext(
|
|
|
3719
3743
|
// template to be able to explore different combinations
|
|
3720
3744
|
//
|
|
3721
3745
|
#define FA_TYPES \
|
|
3722
|
-
|
|
3723
|
-
half,
|
|
3724
|
-
half,
|
|
3725
|
-
float,
|
|
3726
|
-
float,
|
|
3727
|
-
half,
|
|
3746
|
+
float, float4, simdgroup_float8x8, \
|
|
3747
|
+
half, half4x4, simdgroup_half8x8, \
|
|
3748
|
+
half, half4x4, simdgroup_half8x8, \
|
|
3749
|
+
float, simdgroup_float8x8, \
|
|
3750
|
+
float, simdgroup_float8x8, \
|
|
3751
|
+
half, half4, simdgroup_half8x8
|
|
3752
|
+
//float, float4, simdgroup_float8x8
|
|
3753
|
+
|
|
3754
|
+
#define FA_TYPES_BF \
|
|
3755
|
+
bfloat, bfloat4, simdgroup_bfloat8x8, \
|
|
3756
|
+
bfloat, bfloat4x4, simdgroup_bfloat8x8, \
|
|
3757
|
+
bfloat, bfloat4x4, simdgroup_bfloat8x8, \
|
|
3758
|
+
float, simdgroup_float8x8, \
|
|
3759
|
+
float, simdgroup_float8x8, \
|
|
3760
|
+
half, half4, simdgroup_half8x8
|
|
3761
|
+
//float, float4, simdgroup_float8x8
|
|
3728
3762
|
|
|
3729
3763
|
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
|
|
3730
3764
|
|
|
@@ -3739,15 +3773,15 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
|
|
|
3739
3773
|
template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
|
|
3740
3774
|
|
|
3741
3775
|
#if defined(GGML_METAL_USE_BF16)
|
|
3742
|
-
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<
|
|
3743
|
-
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<
|
|
3744
|
-
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<
|
|
3745
|
-
template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<
|
|
3746
|
-
template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<
|
|
3747
|
-
template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<
|
|
3748
|
-
template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<
|
|
3749
|
-
template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<
|
|
3750
|
-
template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<
|
|
3776
|
+
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
|
|
3777
|
+
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
|
|
3778
|
+
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
|
|
3779
|
+
template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
|
|
3780
|
+
template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
|
|
3781
|
+
template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
|
|
3782
|
+
template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
|
|
3783
|
+
template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
|
|
3784
|
+
template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
|
|
3751
3785
|
#endif
|
|
3752
3786
|
|
|
3753
3787
|
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
|
|
@@ -3801,6 +3835,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_at
|
|
|
3801
3835
|
template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
|
|
3802
3836
|
|
|
3803
3837
|
#undef FA_TYPES
|
|
3838
|
+
#undef FA_TYPES_BF
|
|
3804
3839
|
|
|
3805
3840
|
template<
|
|
3806
3841
|
typename q4_t, // query types in shared memory
|
|
@@ -3847,12 +3882,12 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
3847
3882
|
|
|
3848
3883
|
const short T = DK + nsg*SH; // shared memory size per query in (half)
|
|
3849
3884
|
|
|
3850
|
-
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 +
|
|
3851
|
-
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +
|
|
3852
|
-
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 +
|
|
3853
|
-
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 +
|
|
3854
|
-
threadgroup float * sm = (threadgroup float *) (shmem_f16 +
|
|
3855
|
-
threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
|
|
3885
|
+
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
|
|
3886
|
+
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
|
|
3887
|
+
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
|
|
3888
|
+
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
|
|
3889
|
+
threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
|
|
3890
|
+
threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results
|
|
3856
3891
|
|
|
3857
3892
|
// store the result for all queries in local memory (the O matrix from the paper)
|
|
3858
3893
|
o4_t lo[DV4/NL];
|
|
@@ -4157,7 +4192,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
4157
4192
|
half4, \
|
|
4158
4193
|
float, \
|
|
4159
4194
|
float, float4, \
|
|
4160
|
-
|
|
4195
|
+
float4
|
|
4161
4196
|
|
|
4162
4197
|
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
|
|
4163
4198
|
|
|
@@ -80,6 +80,7 @@ set(GGML_OPENCL_KERNELS
|
|
|
80
80
|
mul_mv_q4_0_f32_1d_8x_flat
|
|
81
81
|
mul_mv_q4_0_f32_1d_16x_flat
|
|
82
82
|
mul_mv_q6_k
|
|
83
|
+
mul_mv_id_q4_0_f32_8x_flat
|
|
83
84
|
mul
|
|
84
85
|
norm
|
|
85
86
|
relu
|
|
@@ -95,6 +96,12 @@ set(GGML_OPENCL_KERNELS
|
|
|
95
96
|
sub
|
|
96
97
|
sum_rows
|
|
97
98
|
transpose
|
|
99
|
+
concat
|
|
100
|
+
tsembd
|
|
101
|
+
upscale
|
|
102
|
+
tanh
|
|
103
|
+
pad
|
|
104
|
+
repeat
|
|
98
105
|
)
|
|
99
106
|
|
|
100
107
|
foreach (K ${GGML_OPENCL_KERNELS})
|