@novastera-oss/llamarn 0.2.6 → 0.2.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +134 -36
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +2 -2
- package/cpp/LlamaCppModel.h +3 -3
- package/cpp/PureCppImpl.cpp +1 -1
- package/cpp/PureCppImpl.h +2 -2
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +15 -4
- package/cpp/llama.cpp/Makefile +2 -2
- package/cpp/llama.cpp/README.md +32 -13
- package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
- package/cpp/llama.cpp/common/arg.cpp +30 -6
- package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
- package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
- package/cpp/llama.cpp/common/chat-parser.h +2 -0
- package/cpp/llama.cpp/common/chat.cpp +12 -9
- package/cpp/llama.cpp/common/chat.h +1 -1
- package/cpp/llama.cpp/common/common.cpp +50 -40
- package/cpp/llama.cpp/common/common.h +5 -2
- package/cpp/llama.cpp/common/speculative.cpp +6 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +97 -56
- package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
- package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +47 -13
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +5 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +6 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -38
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +431 -247
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +0 -6
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
- package/cpp/llama.cpp/include/llama.h +134 -36
- package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
- package/cpp/llama.cpp/src/llama-arch.h +7 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +270 -19
- package/cpp/llama.cpp/src/llama-batch.h +36 -11
- package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +313 -213
- package/cpp/llama.cpp/src/llama-context.h +16 -12
- package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
- package/cpp/llama.cpp/src/llama-cparams.h +1 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +249 -129
- package/cpp/llama.cpp/src/llama-graph.h +90 -34
- package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
- package/cpp/llama.cpp/src/llama-hparams.h +8 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +82 -50
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +292 -174
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +68 -38
- package/cpp/llama.cpp/src/llama-kv-cells.h +18 -13
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +266 -282
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +54 -57
- package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
- package/cpp/llama.cpp/src/llama-memory.h +64 -23
- package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
- package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
- package/cpp/llama.cpp/src/llama-model.cpp +726 -141
- package/cpp/llama.cpp/src/llama-model.h +4 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
- package/cpp/llama.cpp/src/llama-vocab.cpp +32 -23
- package/cpp/llama.cpp/src/llama.cpp +11 -7
- package/cpp/llama.cpp/src/unicode.cpp +5 -0
- package/cpp/rn-completion.cpp +2 -2
- package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
- package/ios/include/chat.h +1 -1
- package/ios/include/common.h +5 -2
- package/ios/include/llama.h +134 -36
- package/ios/libs/llama.xcframework/Info.plist +18 -18
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3624
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4725
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4746
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3652
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -2
- package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
- package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /package/cpp/{rn-utils.hpp → rn-utils.h} +0 -0
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
// Concatenate two contiguous f32 tensors along dimension `dim` (0, 1 or 2).
// One work-item writes exactly one destination element.
kernel void kernel_concat_f32_contiguous(
    global const char * p_src0, ulong off_src0,
    global const char * p_src1, ulong off_src1,
    global char * p_dst, ulong off_dst,
    int d_ne00, int d_ne01, int d_ne02, // src0->ne[0..2] for the slice
    int d_ne10, int d_ne11, int d_ne12, // src1->ne[0..2] for the slice (d_ne1X must match d_ne0X on non-concat axes)
    int d_ne0, int d_ne1, int d_ne2,    // dst->ne[0..2] for the slice
    int dim
) {
    global const float * src0 = (global const float*)((global char*)p_src0 + off_src0);
    global const float * src1 = (global const float*)((global char*)p_src1 + off_src1);
    global float * dst = (global float*)((global char*)p_dst + off_dst);

    const int i0 = get_global_id(0); // index along dst dimension 0
    const int i1 = get_global_id(1); // index along dst dimension 1
    const int i2 = get_global_id(2); // index along dst dimension 2

    if (i0 >= d_ne0 || i1 >= d_ne1 || i2 >= d_ne2) {
        return;
    }

    const ulong dst_idx = (ulong)i2 * d_ne0 * d_ne1 + (ulong)i1 * d_ne0 + i0;

    // On the concat axis, coordinates below src0's extent read from src0;
    // the remainder reads from src1 with the coordinate shifted back.
    switch (dim) {
        case 0:
            if (i0 < d_ne00) {
                dst[dst_idx] = src0[(ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0];
            } else {
                dst[dst_idx] = src1[(ulong)i2 * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + (i0 - d_ne00)];
            }
            break;
        case 1:
            if (i1 < d_ne01) {
                dst[dst_idx] = src0[(ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0];
            } else {
                dst[dst_idx] = src1[(ulong)i2 * d_ne10 * d_ne11 + (ulong)(i1 - d_ne01) * d_ne10 + i0];
            }
            break;
        case 2:
            if (i2 < d_ne02) {
                dst[dst_idx] = src0[(ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0];
            } else {
                dst[dst_idx] = src1[(ulong)(i2 - d_ne02) * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + i0];
            }
            break;
        default:
            // unsupported dim: write nothing (matches original behavior)
            break;
    }
}
|
|
52
|
+
|
|
53
|
+
// Concatenate two possibly non-contiguous f32 tensors along `dim` (0..3)
// using explicit byte strides. The grid covers dst dims 1..3; each work-item
// walks dst dimension 0 sequentially.
kernel void kernel_concat_f32_non_contiguous(
    global const char * p_src0, ulong off_src0,
    global const char * p_src1, ulong off_src1,
    global char * p_dst, ulong off_dst,

    long ne00, long ne01, long ne02, long ne03,     // src0 extents
    ulong nb00, ulong nb01, ulong nb02, ulong nb03, // src0 byte strides

    ulong nb10, ulong nb11, ulong nb12, ulong nb13, // Strides for src1

    long d_ne0, long d_ne1, long d_ne2, long d_ne3,
    ulong d_nb0, ulong d_nb1, ulong d_nb2, ulong d_nb3,
    int dim
) {
    global const char * src0_base = p_src0 + off_src0;
    global const char * src1_base = p_src1 + off_src1;
    global char * dst_base = p_dst + off_dst;

    const long i1 = get_global_id(0); // index for dst dim 1
    const long i2 = get_global_id(1); // index for dst dim 2
    const long i3 = get_global_id(2); // index for dst dim 3

    if (i1 >= d_ne1 || i2 >= d_ne2 || i3 >= d_ne3) {
        return;
    }

    for (long i0 = 0; i0 < d_ne0; ++i0) {
        // Map the dst coordinate back onto a source coordinate: only the
        // concat axis needs shifting when the element belongs to src1.
        long s_i0 = i0, s_i1 = i1, s_i2 = i2, s_i3 = i3;
        bool from_src0;

        if (dim == 0) {
            from_src0 = (i0 < ne00);
            if (!from_src0) { s_i0 = i0 - ne00; }
        } else if (dim == 1) {
            from_src0 = (i1 < ne01);
            if (!from_src0) { s_i1 = i1 - ne01; }
        } else if (dim == 2) {
            from_src0 = (i2 < ne02);
            if (!from_src0) { s_i2 = i2 - ne02; }
        } else { // dim == 3
            from_src0 = (i3 < ne03);
            if (!from_src0) { s_i3 = i3 - ne03; }
        }

        global const float * src_el;
        if (from_src0) {
            src_el = (global const float *)(src0_base + (ulong)s_i3*nb03 + (ulong)s_i2*nb02 + (ulong)s_i1*nb01 + (ulong)s_i0*nb00);
        } else {
            src_el = (global const float *)(src1_base + (ulong)s_i3*nb13 + (ulong)s_i2*nb12 + (ulong)s_i1*nb11 + (ulong)s_i0*nb10);
        }

        global float * dst_el = (global float *)(dst_base + (ulong)i3*d_nb3 + (ulong)i2*d_nb2 + (ulong)i1*d_nb1 + (ulong)i0*d_nb0);
        *dst_el = *src_el;
    }
}
|
|
@@ -0,0 +1,283 @@
|
|
|
1
|
+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Subgroup support: prefer the Intel extension when available, otherwise the
// Khronos one.
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif

// Required-subgroup-size attributes differ per vendor; INTEL_GPU/ADRENO_GPU
// select the tuned kernel parameters below.
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif

#define QK4_0 32 // elements per q4_0 quantization block

// OpenCL C has no <stdint.h>; alias the fixed-width names to built-in types.
typedef char int8_t;
typedef uchar uint8_t;
typedef short int16_t;
typedef ushort uint16_t;
typedef int int32_t;
typedef uint uint32_t;

//------------------------------------------------------------------------------
// block_q4_0
//------------------------------------------------------------------------------
struct block_q4_0
{
    half d;                // per-block scale
    uint8_t qs[QK4_0 / 2]; // QK4_0 4-bit quants packed two per byte
};
|
|
38
|
+
|
|
39
|
+
// This function requires the original shuffled weights.
|
|
40
|
+
// As a reminder, the original weights are shuffled so that (q[0], q[16]) are
|
|
41
|
+
// packed together in a byte, so are (q[1], q[17]) and so on.
|
|
42
|
+
// This function requires the original shuffled weights.
// As a reminder, the original weights are shuffled so that (q[0], q[16]) are
// packed together in a byte, so are (q[1], q[17]) and so on.
//
// `x` points at the packed nibbles, `dh` at the block scale, `sumy` is the
// sum of the 16 activations this lane covers, `yl` holds those activations
// pre-scaled so the masked-in-place nibbles multiply correctly, and `il`
// selects which half of the block this lane works on.
inline float block_q_4_0_dot_y_flat(
    global uchar * x,
    global half * dh,
    float sumy,
    float16 yl,
    int il
) {
    const float d = *dh;
    global ushort * qs = ((global ushort *)x + il/2);

    // Four 16-bit loads cover 8 packed bytes = 16 quant values. Each mask
    // isolates one nibble in place; the matching yl component carries the
    // compensating factor (1, 1/256, 1/16, 1/4096).
    const ushort q0 = qs[0];
    const ushort q1 = qs[1];
    const ushort q2 = qs[2];
    const ushort q3 = qs[3];

    float acc = 0.f;

    acc += yl.s0 * (q0 & 0x000F);
    acc += yl.s1 * (q0 & 0x0F00);
    acc += yl.s8 * (q0 & 0x00F0);
    acc += yl.s9 * (q0 & 0xF000);

    acc += yl.s2 * (q1 & 0x000F);
    acc += yl.s3 * (q1 & 0x0F00);
    acc += yl.sa * (q1 & 0x00F0);
    acc += yl.sb * (q1 & 0xF000);

    acc += yl.s4 * (q2 & 0x000F);
    acc += yl.s5 * (q2 & 0x0F00);
    acc += yl.sc * (q2 & 0x00F0);
    acc += yl.sd * (q2 & 0xF000);

    acc += yl.s6 * (q3 & 0x000F);
    acc += yl.s7 * (q3 & 0x0F00);
    acc += yl.se * (q3 & 0x00F0);
    acc += yl.sf * (q3 & 0xF000);

    // q4_0 stores unsigned nibbles with an implicit zero point of 8:
    // subtract 8*sum(y) once instead of per weight.
    return d * (sumy * -8.f + acc);
}
|
|
75
|
+
|
|
76
|
+
//
// This variant outputs 8 values.
//
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH

#ifdef INTEL_GPU
#define N_DST 8        // each SIMD group works on 8 rows
#define N_SIMDGROUP 1  // number of SIMD groups in a thread group
#define N_SIMDWIDTH 16 // subgroup size
#elif defined (ADRENO_GPU)
#define N_DST 8
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif

// q4_0 mat-vec product over "flat" storage: nibbles (src0_q) and scales
// (src0_d) live in separate buffers. Each subgroup accumulates N_DST (8)
// output rows; lanes pair up so two lanes cooperate on one block per
// iteration (even lane: lower half, odd lane: upper half).
inline void mul_vec_q_n_f32_8x_flat(
    global char * src0_q,
    global half * src0_d,
    global float * src1,
    global float * dst,
    int ne00,
    int ne01,
    int ne02,
    int ne10,
    int ne12,
    int ne0,
    int ne1,
    int r2,
    int r3
) {
    const ulong nb = ne00/QK4_0; // quant blocks per row

    const int r0 = get_group_id(0);
    const int r1 = get_group_id(1);
    const int im = 0;

    const int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;

    const int i12 = im%ne12;
    const int i13 = im/ne12;

    // The number of scales is the same as the number of blocks.
    const ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
    // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
    const ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;

    global uchar * x = (global uchar *) src0_q + offset0_q;
    global half  * d = (global half  *) src0_d + offset0_d;
    global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;

    float16 yl;
    float8 sumf = 0.f;

    const int ix = get_sub_group_local_id()/2;      // block handled first by this lane
    const int il = 8*(get_sub_group_local_id()%2);  // 0 or 8: half-block offset

    global float * yb = y + ix*QK4_0 + il;

    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
        // Sum of the 16 activations this lane touches (low and high halves of
        // the block) — feeds the -8 zero-point correction in the dot helper.
        float sumy = 0.f;

        sumy += yb[0];
        sumy += yb[1];
        sumy += yb[2];
        sumy += yb[3];
        sumy += yb[4];
        sumy += yb[5];
        sumy += yb[6];
        sumy += yb[7];

        sumy += yb[16];
        sumy += yb[17];
        sumy += yb[18];
        sumy += yb[19];
        sumy += yb[20];
        sumy += yb[21];
        sumy += yb[22];
        sumy += yb[23];

        // Pre-scale activations so the in-place nibble masks multiply
        // correctly: /256 for high-byte low nibbles, /16 and /4096 for the
        // upper-half weights stored in the high nibble of each byte.
        yl.s0 = yb[0];
        yl.s1 = yb[1]/256.f;

        yl.s2 = yb[2];
        yl.s3 = yb[3]/256.f;

        yl.s4 = yb[4];
        yl.s5 = yb[5]/256.f;

        yl.s6 = yb[6];
        yl.s7 = yb[7]/256.f;

        yl.s8 = yb[16]/16.f;
        yl.s9 = yb[17]/4096.f;

        yl.sa = yb[18]/16.f;
        yl.sb = yb[19]/4096.f;

        yl.sc = yb[20]/16.f;
        yl.sd = yb[21]/4096.f;

        yl.se = yb[22]/16.f;
        yl.sf = yb[23]/4096.f;

        // One partial dot product per output row; rows are nb blocks apart.
        sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il);
        sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il);
        sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il);
        sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il);

        sumf.s4 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il);
        sumf.s5 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il);
        sumf.s6 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il);
        sumf.s7 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il);

        yb += QK4_0 * (N_SIMDWIDTH/2);
    }

    // Combine per-lane partial sums across the subgroup.
    float8 tot = (float8)(
        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),
        sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5),
        sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)
    );

    // Lane 0 writes the (up to) 8 results, guarding the row-count boundary.
    if (get_sub_group_local_id() == 0) {
        if (first_row + 0 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
        }
        if (first_row + 1 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
        }
        if (first_row + 2 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
        }
        if (first_row + 3 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
        }

        if (first_row + 4 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
        }
        if (first_row + 5 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
        }
        if (first_row + 6 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
        }
        if (first_row + 7 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
        }
    }
}
|
|
229
|
+
|
|
230
|
+
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
// Indirect (expert-selected) q4_0 mat-vec: src2 supplies per-(token,slot)
// int32 indices i02 choosing which ne01-row slice of src0 to multiply.
kernel void kernel_mul_mv_id_q4_0_f32_8x_flat(
    global char * src0_q,
    global half * src0_d,
    global float * src1,
    ulong offset1,
    global char * src2,
    ulong offset2,
    global float * dst,
    ulong offsetd,
    int ne00,
    int ne01,
    int ne02,
    ulong nb00,
    ulong nb02,
    int ne10,
    int ne11,
    int ne12,
    ulong nb11,
    ulong nb12,
    int ne20,
    int ne21,
    ulong nb21,
    int ne0,
    int ne1,
    int r2,
    int r3
) {
    // Apply the buffer base offsets up front.
    src1 = (global float *)((global char *)src1 + offset1);
    src2 = src2 + offset2;
    dst  = (global float *)((global char *)dst + offsetd);

    // Grid dim 2 enumerates (row of src2) x (slot within that row).
    const int iid1 = get_group_id(2)/ne20;
    const int idx  = get_group_id(2)%ne20;

    const int i02 = ((global int *)(src2 + iid1*nb21))[idx];

    const int i11 = idx%ne11;
    const int i12 = iid1;

    const int i1 = idx;
    const int i2 = i12;

    // nb02/nb00 converts the byte stride of one src0 slice into an element
    // count; nibbles advance by QK4_0/2 bytes per block, scales by one half.
    global char  * src0_q_cur = src0_q + (i02*nb02/nb00)*(QK4_0/2);
    global half  * src0_d_cur = src0_d + (i02*nb02/nb00);
    global float * src1_cur   = (global float *)((global char *) src1 + i11*nb11 + i12*nb12);
    global float * dst_cur    = dst + i1*ne0 + i2*ne1*ne0;

    mul_vec_q_n_f32_8x_flat(src0_q_cur, src0_d_cur, src1_cur, dst_cur, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
}
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
// Pad a 3-D f32 tensor: copy src0 where the destination coordinate lies
// inside the source extents, write 0.0f elsewhere. Both tensors contiguous.
kernel void kernel_pad(
    global const void * src0_ptr,
    ulong src0_offset,
    global void * dst_ptr,
    ulong dst_offset,
    int s_ne0, int s_ne1, int s_ne2,
    int d_ne0, int d_ne1, int d_ne2
) {
    global const float * src0 = (global const float *)((global const char *)src0_ptr + src0_offset);
    global float * dst = (global float *)((global char *)dst_ptr + dst_offset);

    const int i0 = get_global_id(0); // dst index along dim 0
    const int i1 = get_group_id(1);  // dst index along dim 1 (one group per row)
    const int i2 = get_group_id(2);  // dst index along dim 2

    if (i0 >= d_ne0) {
        return;
    }

    const int dst_idx = i0 + i1 * d_ne0 + i2 * d_ne0 * d_ne1;

    if (i0 < s_ne0 && i1 < s_ne1 && i2 < s_ne2) {
        const int src_idx = i0 + i1 * s_ne0 + i2 * s_ne0 * s_ne1;
        dst[dst_idx] = src0[src_idx];
    } else {
        dst[dst_idx] = 0.0f; // padding region
    }
}
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
// Tile/broadcast src0 into dst: every dst coordinate d maps onto the source
// coordinate d % src_ne in each dimension. Works on raw bytes with explicit
// strides, so it handles any element type (src0_nb0 = element size).
kernel void kernel_repeat(
    global const char * src0_data_in,
    global char * dst_data_in,
    ulong src0_offset,
    ulong dst_offset,
    int src0_ne0, int src0_ne1, int src0_ne2, int src0_ne3,
    ulong src0_nb0, ulong src0_nb1, ulong src0_nb2, ulong src0_nb3,
    int dst_ne0, int dst_ne1, int dst_ne2, int dst_ne3,
    ulong dst_nb0, ulong dst_nb1, ulong dst_nb2, ulong dst_nb3
) {
    global const char * src0_data = src0_data_in + src0_offset;
    global char * dst_data = dst_data_in + dst_offset;

    const int d3 = get_global_id(2);
    const int d2 = get_global_id(1);
    const int d1 = get_global_id(0);

    if (d3 >= dst_ne3 || d2 >= dst_ne2 || d1 >= dst_ne1) {
        return;
    }

    // Wrap the outer destination coordinates onto the source.
    const int s3 = d3 % src0_ne3;
    const int s2 = d2 % src0_ne2;
    const int s1 = d1 % src0_ne1;

    const global char * src_row = src0_data + (ulong)s3*src0_nb3 + (ulong)s2*src0_nb2 + (ulong)s1*src0_nb1;
    global char * dst_row = dst_data + (ulong)d3*dst_nb3 + (ulong)d2*dst_nb2 + (ulong)d1*dst_nb1;

    // Each work-item copies a full row along dimension 0, element by element.
    for (int d0 = 0; d0 < dst_ne0; ++d0) {
        const int s0 = d0 % src0_ne0;

        const global char * restrict src_el = src_row + (ulong)s0*src0_nb0;
        global char * restrict dst_el = dst_row + (ulong)d0*dst_nb0;

        for (int b = 0; b < src0_nb0; ++b) {
            dst_el[b] = src_el[b];
        }
    }
}
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
|
2
|
+
|
|
3
|
+
#ifdef cl_intel_required_subgroup_size
|
|
4
|
+
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
|
5
|
+
#define INTEL_GPU 1
|
|
6
|
+
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
|
7
|
+
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
|
8
|
+
#elif defined(cl_qcom_reqd_sub_group_size)
|
|
9
|
+
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
|
10
|
+
#define ADRENO_GPU 1
|
|
11
|
+
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
|
12
|
+
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
|
13
|
+
#endif
|
|
14
|
+
|
|
15
|
+
//------------------------------------------------------------------------------
// tanh on an f32 tensor with arbitrary (possibly non-contiguous) strides.
// Work-items cover dimensions 0..2; each item iterates dimension 3 serially.
// ne0x/nb0x describe src0, ne1x/nb1x describe dst; bounds come from dst.
//------------------------------------------------------------------------------
kernel void kernel_tanh_f32_nd(
        global void * p_src0_base, ulong off_src0_abs,
        global void * p_dst_base, ulong off_dst_abs,
        int ne00, int ne01, int ne02, int ne03,
        ulong nb00, ulong nb01, ulong nb02, ulong nb03,
        int ne10, int ne11, int ne12, int ne13,
        ulong nb10, ulong nb11, ulong nb12, ulong nb13
) {
    const int i0 = get_global_id(0);
    const int i1 = get_global_id(1);
    const int i2 = get_global_id(2);

    // Launch grid may be padded past the dst extents.
    if (i0 >= ne10 || i1 >= ne11 || i2 >= ne12) {
        return;
    }

    global const char * src_bytes = (global const char *)p_src0_base + off_src0_abs;
    global char       * dst_bytes = (global char *)p_dst_base + off_dst_abs;

    for (int i3 = 0; i3 < ne13; ++i3) {
        const ulong src_off = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
        const ulong dst_off = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;

        global const float * src = (global const float *)(src_bytes + src_off);
        global float       * dst = (global float *)(dst_bytes + dst_off);

        *dst = tanh(*src);
    }
}
|
|
39
|
+
|
|
40
|
+
//------------------------------------------------------------------------------
// tanh on an f16 tensor with arbitrary (possibly non-contiguous) strides.
// Same indexing scheme as the f32 variant: work-items cover dims 0..2 and
// loop dim 3; bounds are taken from the dst extents (ne1x).
//------------------------------------------------------------------------------
kernel void kernel_tanh_f16_nd(
        global void * p_src0_base, ulong off_src0_abs,
        global void * p_dst_base, ulong off_dst_abs,
        int ne00, int ne01, int ne02, int ne03,
        ulong nb00, ulong nb01, ulong nb02, ulong nb03,
        int ne10, int ne11, int ne12, int ne13,
        ulong nb10, ulong nb11, ulong nb12, ulong nb13
) {
    const int i0 = get_global_id(0);
    const int i1 = get_global_id(1);
    const int i2 = get_global_id(2);

    // Launch grid may be padded past the dst extents.
    if (i0 >= ne10 || i1 >= ne11 || i2 >= ne12) {
        return;
    }

    global const char * src_bytes = (global const char *)p_src0_base + off_src0_abs;
    global char       * dst_bytes = (global char *)p_dst_base + off_dst_abs;

    for (int i3 = 0; i3 < ne13; ++i3) {
        const ulong src_off = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
        const ulong dst_off = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;

        global const half * src = (global const half *)(src_bytes + src_off);
        global half       * dst = (global half *)(dst_bytes + dst_off);

        *dst = tanh(*src);
    }
}
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
//------------------------------------------------------------------------------
// Sinusoidal timestep embedding (diffusion-style).
// Global id 1 selects the timestep row i, global id 0 selects the frequency j.
// Row i of dst gets [cos(t*f_0) .. cos(t*f_{h-1}), sin(t*f_0) .. sin(t*f_{h-1})]
// with h = logical_dim / 2 and f_j = max_period^(-j/h); an odd logical_dim
// gets a trailing zero written by the single extra work-item.
//------------------------------------------------------------------------------
kernel void kernel_timestep_embedding(
        global const void * p_timesteps,
        ulong off_timesteps,
        global void * p_dst,
        ulong off_dst,
        int dst_nb1_bytes,
        int logical_dim,
        int max_period
) {
    global const float * timesteps = (global const float *)((global char *)p_timesteps + off_timesteps);
    global float * dst_base        = (global float *)((global char *)p_dst + off_dst);

    const int i = get_global_id(1);  // timestep (row) index
    const int j = get_global_id(0);  // frequency index within the row

    const int half_dim = logical_dim / 2;
    // Row stride is given in bytes, so step through a char pointer.
    global float * embed_row = (global float *)((global char *)dst_base + i * dst_nb1_bytes);

    // For odd dims, one extra work-item zero-fills the last slot.
    if (logical_dim % 2 != 0 && j == ((logical_dim + 1) / 2)) {
        embed_row[logical_dim] = 0.0f;
    }

    if (j >= half_dim) {
        return;
    }

    const float timestep_val = timesteps[i];

    float freq;
    if (half_dim == 0) {
        freq = 1.0f;  // avoid division by zero for degenerate dims
    } else {
        freq = exp(-log((float)max_period) * (float)j / (float)half_dim);
    }

    const float arg = timestep_val * freq;
    embed_row[j]            = cos(arg);
    embed_row[j + half_dim] = sin(arg);
}
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
//------------------------------------------------------------------------------
// Nearest-neighbour upscale: each dst element pulls the src element obtained
// by dividing its coordinates by the per-dimension scale factors (truncating).
// One work-item per dst element, flat-indexed; dst is written contiguously.
//------------------------------------------------------------------------------
kernel void kernel_upscale(
    global const void * p_src0,
    ulong off_src0,
    global void * p_dst,
    ulong off_dst,
    ulong nb00,
    ulong nb01,
    ulong nb02,
    ulong nb03,
    int ne10,
    int ne11,
    int ne12,
    int ne13,
    float sf0,
    float sf1,
    float sf2,
    float sf3
) {
    global const char * src_bytes = (global const char *)p_src0 + off_src0;
    global float * dst            = (global float *)((global char *)p_dst + off_dst);

    const int idx   = get_global_id(0);
    const int n_dst = ne10 * ne11 * ne12 * ne13;

    if (idx >= n_dst) {
        return;
    }

    // Unflatten idx into 4-D dst coordinates.
    const int i10 = idx % ne10;
    const int i11 = (idx / ne10) % ne11;
    const int i12 = (idx / (ne10 * ne11)) % ne12;
    const int i13 = idx / (ne10 * ne11 * ne12);

    // Map each dst coordinate back to its source coordinate.
    const int i00 = (int)(i10 / sf0);
    const int i01 = (int)(i11 / sf1);
    const int i02 = (int)(i12 / sf2);
    const int i03 = (int)(i13 / sf3);

    const ulong src_off = (ulong)i03 * nb03 + (ulong)i02 * nb02 + (ulong)i01 * nb01 + (ulong)i00 * nb00;

    dst[idx] = *(global const float *)(src_bytes + src_off);
}
|
|
44
|
+
|
|
45
|
+
//------------------------------------------------------------------------------
// Bilinear upscale of an f32 tensor: one work-item per dst element (flat
// index). Dims 0/1 (x/y) are interpolated bilinearly with half-pixel centre
// alignment; dims 2/3 are mapped nearest-neighbour. dst is written
// contiguously; src is addressed through its byte strides nb00..nb03.
//------------------------------------------------------------------------------
kernel void kernel_upscale_bilinear(
    global const void * p_src0,
    ulong off_src0,
    global void * p_dst,
    ulong off_dst,
    ulong nb00,
    ulong nb01,
    ulong nb02,
    ulong nb03,
    int ne00_src,
    int ne01_src,
    int ne10_dst,
    int ne11_dst,
    int ne12_dst,
    int ne13_dst,
    float sf0,
    float sf1,
    float sf2,
    float sf3
) {
    global const char * src_base = (global const char *)p_src0 + off_src0;
    global float * dst_base = (global float *)((global char *)p_dst + off_dst);

    int index = get_global_id(0);
    int dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;

    // Launch grid may be padded past the dst size.
    if (index >= dst_total_elements) {
        return;
    }

    // Unflatten the linear index into 4-D dst coordinates.
    int i10_dst = index % ne10_dst;
    int i11_dst = (index / ne10_dst) % ne11_dst;
    int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
    int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);

    // Dims 2 and 3 are not interpolated: plain truncating scale.
    int i02_src = (int)(i12_dst / sf2);
    int i03_src = (int)(i13_dst / sf3);

    // Half-pixel offset aligns sample positions with pixel centres.
    const float pixel_offset = 0.5f;

    // Fractional source y, and the two integer rows bracketing it.
    float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
    long y0_src = (long)floor(y_src_f);
    long y1_src = y0_src + 1;

    // Clamp both rows into [0, ne01_src-1] (edge replication at the borders).
    y0_src = max(0L, min(y0_src, (long)ne01_src - 1));
    y1_src = max(0L, min(y1_src, (long)ne01_src - 1));

    // Vertical interpolation weight, clamped to [0, 1].
    float dy = y_src_f - (float)y0_src;
    dy = max(0.0f, min(dy, 1.0f));

    // Same construction along x.
    float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
    long x0_src = (long)floor(x_src_f);
    long x1_src = x0_src + 1;

    x0_src = max(0L, min(x0_src, (long)ne00_src - 1));
    x1_src = max(0L, min(x1_src, (long)ne00_src - 1));

    float dx = x_src_f - (float)x0_src;
    dx = max(0.0f, min(dx, 1.0f));

    // The four neighbouring source texels: a=(x0,y0) b=(x1,y0) c=(x0,y1) d=(x1,y1).
    global const float * p_a = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
    global const float * p_b = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
    global const float * p_c = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
    global const float * p_d = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);

    const float val_a = *p_a;
    const float val_b = *p_b;
    const float val_c = *p_c;
    const float val_d = *p_d;

    // Standard bilinear blend of the four texels.
    float result = val_a * (1.0f - dx) * (1.0f - dy) +
                   val_b * dx * (1.0f - dy) +
                   val_c * (1.0f - dx) * dy +
                   val_d * dx * dy;

    dst_base[index] = result;
}
|