whispercpp 1.3.2 → 1.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +6 -3
- data/README.md +71 -14
- data/Rakefile +20 -7
- data/ext/.gitignore +4 -6
- data/ext/dependencies.rb +36 -24
- data/ext/extconf.rb +1 -1
- data/ext/options.rb +48 -184
- data/ext/ruby_whisper.c +18 -0
- data/ext/ruby_whisper_context.c +43 -12
- data/ext/ruby_whisper_model.c +1 -1
- data/ext/ruby_whisper_params.c +4 -2
- data/ext/ruby_whisper_segment.c +81 -4
- data/ext/ruby_whisper_transcribe.cpp +13 -7
- data/ext/ruby_whisper_vad_params.c +1 -1
- data/ext/sources/CMakeLists.txt +5 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
- data/ext/sources/examples/addon.node/addon.cpp +150 -31
- data/ext/sources/examples/addon.node/index.js +3 -0
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/bench.cpp +3 -2
- data/ext/sources/examples/cli/cli.cpp +3 -2
- data/ext/sources/examples/command/command.cpp +32 -8
- data/ext/sources/examples/common-whisper.cpp +14 -7
- data/ext/sources/examples/lsp/lsp.cpp +2 -0
- data/ext/sources/examples/quantize/quantize.cpp +3 -0
- data/ext/sources/examples/server/CMakeLists.txt +3 -0
- data/ext/sources/examples/server/server.cpp +169 -22
- data/ext/sources/examples/stream/stream.cpp +6 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
- data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
- data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
- data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
- data/ext/sources/examples/talk-llama/llama-context.h +38 -17
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
- data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
- data/ext/sources/examples/talk-llama/llama-model.h +27 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
- data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
- data/ext/sources/examples/talk-llama/llama.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama.h +147 -40
- data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
- data/ext/sources/ggml/CMakeLists.txt +48 -3
- data/ext/sources/ggml/cmake/common.cmake +24 -0
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +2 -0
- data/ext/sources/ggml/include/ggml.h +144 -5
- data/ext/sources/ggml/src/CMakeLists.txt +82 -24
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- data/ext/sources/ggml/src/ggml-common.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
- data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
- data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-impl.h +127 -183
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- data/ext/sources/ggml/src/ggml-quants.c +6 -8
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
- data/ext/sources/ggml/src/ggml.c +328 -48
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +24 -3
- data/ext/sources/include/whisper.h +2 -0
- data/ext/sources/src/CMakeLists.txt +2 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
- data/ext/sources/src/whisper.cpp +218 -169
- data/extsources.rb +15 -9
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +56 -1
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +68 -38
- data/{tests → test}/helper.rb +1 -12
- data/{tests → test}/test_model.rb +9 -0
- data/test/test_package.rb +51 -0
- data/test/test_segment.rb +146 -0
- data/{tests → test}/test_whisper.rb +70 -0
- data/whispercpp.gemspec +2 -3
- metadata +91 -43
- data/ext/sources/.dockerignore +0 -3
- data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
- data/ext/sources/ci/run.sh +0 -336
- data/ext/sources/close-issue.yml +0 -28
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
- data/tests/test_package.rb +0 -46
- data/tests/test_segment.rb +0 -74
- /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
- /data/{tests → test}/test_callback.rb +0 -0
- /data/{tests → test}/test_error.rb +0 -0
- /data/{tests → test}/test_params.rb +0 -0
- /data/{tests → test}/test_vad.rb +0 -0
- /data/{tests → test}/test_vad_params.rb +0 -0
@@ -0,0 +1,86 @@
|
|
1
|
+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
2
|
+
|
3
|
+
#ifdef cl_intel_subgroups
|
4
|
+
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
5
|
+
#else
|
6
|
+
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
7
|
+
#endif
|
8
|
+
|
9
|
+
#ifdef cl_intel_required_subgroup_size
|
10
|
+
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
11
|
+
#define INTEL_GPU 1
|
12
|
+
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
13
|
+
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
14
|
+
#elif defined(cl_qcom_reqd_sub_group_size)
|
15
|
+
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
16
|
+
#define ADRENO_GPU 1
|
17
|
+
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
18
|
+
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
19
|
+
#endif
|
20
|
+
|
21
|
+
// Exchange two lvalues of type T.
// Wrapped in do { } while (0) so that `SWAP(a, b, T);` expands to exactly one
// statement and stays correct inside an unbraced if/else.
#define SWAP(x, y, T) do { T tmp = (x); (x) = (y); (y) = tmp; } while (0)

// Sort direction; compared against the `order` kernel argument.
enum ggml_sort_order {
    GGML_SORT_ORDER_ASC,
    GGML_SORT_ORDER_DESC,
};
|
27
|
+
|
28
|
+
kernel void kernel_argsort_f32_i32(
|
29
|
+
global float * src0,
|
30
|
+
ulong offset0,
|
31
|
+
global int * dst,
|
32
|
+
ulong offsetd,
|
33
|
+
const int ne00,
|
34
|
+
const int ne00_pad,
|
35
|
+
const int order,
|
36
|
+
local int * dst_row
|
37
|
+
) {
|
38
|
+
// bitonic sort
|
39
|
+
int col = get_local_id(0);
|
40
|
+
int row = get_group_id(1);
|
41
|
+
|
42
|
+
if (col >= ne00_pad) {
|
43
|
+
return;
|
44
|
+
}
|
45
|
+
|
46
|
+
src0 = (global char *)((global char *)src0 + offset0);
|
47
|
+
dst = (global float *)((global char *)dst + offsetd);
|
48
|
+
|
49
|
+
global float * x_row = src0 + row * ne00;
|
50
|
+
|
51
|
+
// initialize indices
|
52
|
+
dst_row[col] = col;
|
53
|
+
|
54
|
+
barrier(CLK_LOCAL_MEM_FENCE);
|
55
|
+
|
56
|
+
for (int k = 2; k <= ne00_pad; k *= 2) {
|
57
|
+
for (int j = k / 2; j > 0; j /= 2) {
|
58
|
+
int ixj = col ^ j;
|
59
|
+
if (ixj > col) {
|
60
|
+
if ((col & k) == 0) {
|
61
|
+
if (dst_row[col] >= ne00 ||
|
62
|
+
(dst_row[ixj] < ne00 && (order == GGML_SORT_ORDER_ASC ?
|
63
|
+
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
|
64
|
+
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
|
65
|
+
) {
|
66
|
+
SWAP(dst_row[col], dst_row[ixj], int);
|
67
|
+
}
|
68
|
+
} else {
|
69
|
+
if (dst_row[ixj] >= ne00 ||
|
70
|
+
(dst_row[col] < ne00 && (order == GGML_SORT_ORDER_ASC ?
|
71
|
+
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
|
72
|
+
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
|
73
|
+
) {
|
74
|
+
SWAP(dst_row[col], dst_row[ixj], int);
|
75
|
+
}
|
76
|
+
}
|
77
|
+
}
|
78
|
+
barrier(CLK_LOCAL_MEM_FENCE);
|
79
|
+
}
|
80
|
+
}
|
81
|
+
|
82
|
+
// copy the result to dst without the padding
|
83
|
+
if (col < ne00) {
|
84
|
+
dst[row * ne00 + col] = dst_row[col];
|
85
|
+
}
|
86
|
+
}
|
@@ -0,0 +1,109 @@
|
|
1
|
+
kernel void kernel_concat_f32_contiguous(
|
2
|
+
global const char * p_src0, ulong off_src0,
|
3
|
+
global const char * p_src1, ulong off_src1,
|
4
|
+
global char * p_dst, ulong off_dst,
|
5
|
+
int d_ne00, int d_ne01, int d_ne02, // src0->ne[0..2] for the slice
|
6
|
+
int d_ne10, int d_ne11, int d_ne12, // src1->ne[0..2] for the slice (d_ne1X must match d_ne0X on non-concat axes)
|
7
|
+
int d_ne0, int d_ne1, int d_ne2, // dst->ne[0..2] for the slice
|
8
|
+
int dim
|
9
|
+
) {
|
10
|
+
global const float * src0 = (global const float*)((global char*)p_src0 + off_src0);
|
11
|
+
global const float * src1 = (global const float*)((global char*)p_src1 + off_src1);
|
12
|
+
global float * dst = (global float*)((global char*)p_dst + off_dst);
|
13
|
+
|
14
|
+
int i0 = get_global_id(0); // Index along dst's 0th dimension
|
15
|
+
int i1 = get_global_id(1); // Index along dst's 1st dimension
|
16
|
+
int i2 = get_global_id(2); // Index along dst's 2nd dimension
|
17
|
+
|
18
|
+
if (i0 >= d_ne0 || i1 >= d_ne1 || i2 >= d_ne2) {
|
19
|
+
return;
|
20
|
+
}
|
21
|
+
|
22
|
+
ulong dst_idx = (ulong)i2 * d_ne0 * d_ne1 + (ulong)i1 * d_ne0 + i0;
|
23
|
+
ulong src_idx;
|
24
|
+
|
25
|
+
if (dim == 0) {
|
26
|
+
if (i0 < d_ne00) { // Data from src0
|
27
|
+
src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;
|
28
|
+
dst[dst_idx] = src0[src_idx];
|
29
|
+
} else { // Data from src1
|
30
|
+
src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + (i0 - d_ne00);
|
31
|
+
dst[dst_idx] = src1[src_idx];
|
32
|
+
}
|
33
|
+
} else if (dim == 1) {
|
34
|
+
if (i1 < d_ne01) { // Data from src0
|
35
|
+
src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;
|
36
|
+
dst[dst_idx] = src0[src_idx];
|
37
|
+
} else { // Data from src1
|
38
|
+
src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)(i1 - d_ne01) * d_ne10 + i0;
|
39
|
+
dst[dst_idx] = src1[src_idx];
|
40
|
+
}
|
41
|
+
} else if (dim == 2) {
|
42
|
+
if (i2 < d_ne02) { // Data from src0
|
43
|
+
src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;
|
44
|
+
dst[dst_idx] = src0[src_idx];
|
45
|
+
} else { // Data from src1
|
46
|
+
|
47
|
+
src_idx = (ulong)(i2 - d_ne02) * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + i0;
|
48
|
+
dst[dst_idx] = src1[src_idx];
|
49
|
+
}
|
50
|
+
}
|
51
|
+
}
|
52
|
+
|
53
|
+
kernel void kernel_concat_f32_non_contiguous(
|
54
|
+
global const char * p_src0, ulong off_src0,
|
55
|
+
global const char * p_src1, ulong off_src1,
|
56
|
+
global char * p_dst, ulong off_dst,
|
57
|
+
|
58
|
+
long ne00, long ne01, long ne02, long ne03,
|
59
|
+
ulong nb00, ulong nb01, ulong nb02, ulong nb03,
|
60
|
+
|
61
|
+
ulong nb10, ulong nb11, ulong nb12, ulong nb13, // Strides for src1
|
62
|
+
|
63
|
+
long d_ne0, long d_ne1, long d_ne2, long d_ne3,
|
64
|
+
ulong d_nb0, ulong d_nb1, ulong d_nb2, ulong d_nb3,
|
65
|
+
int dim
|
66
|
+
) {
|
67
|
+
global const char * src0_base = p_src0 + off_src0;
|
68
|
+
global const char * src1_base = p_src1 + off_src1;
|
69
|
+
global char * dst_base = p_dst + off_dst;
|
70
|
+
|
71
|
+
long current_i1 = get_global_id(0); // Index for dst_dim_1
|
72
|
+
long current_i2 = get_global_id(1); // Index for dst_dim_2
|
73
|
+
long current_i3 = get_global_id(2); // Index for dst_dim_3
|
74
|
+
|
75
|
+
if (current_i1 >= d_ne1 || current_i2 >= d_ne2 || current_i3 >= d_ne3) {
|
76
|
+
return;
|
77
|
+
}
|
78
|
+
|
79
|
+
global const float * x_val_ptr;
|
80
|
+
global float * y_val_ptr;
|
81
|
+
|
82
|
+
for (long current_i0 = 0; current_i0 < d_ne0; ++current_i0) {
|
83
|
+
bool use_src0;
|
84
|
+
long s_i0 = current_i0, s_i1 = current_i1, s_i2 = current_i2, s_i3 = current_i3;
|
85
|
+
|
86
|
+
if (dim == 0) {
|
87
|
+
use_src0 = (current_i0 < ne00);
|
88
|
+
if (!use_src0) { s_i0 = current_i0 - ne00; }
|
89
|
+
} else if (dim == 1) {
|
90
|
+
use_src0 = (current_i1 < ne01);
|
91
|
+
if (!use_src0) { s_i1 = current_i1 - ne01; }
|
92
|
+
} else if (dim == 2) {
|
93
|
+
use_src0 = (current_i2 < ne02);
|
94
|
+
if (!use_src0) { s_i2 = current_i2 - ne02; }
|
95
|
+
} else { // dim == 3
|
96
|
+
use_src0 = (current_i3 < ne03);
|
97
|
+
if (!use_src0) { s_i3 = current_i3 - ne03; }
|
98
|
+
}
|
99
|
+
|
100
|
+
if (use_src0) {
|
101
|
+
x_val_ptr = (global const float *)(src0_base + (ulong)s_i3*nb03 + (ulong)s_i2*nb02 + (ulong)s_i1*nb01 + (ulong)s_i0*nb00);
|
102
|
+
} else {
|
103
|
+
x_val_ptr = (global const float *)(src1_base + (ulong)s_i3*nb13 + (ulong)s_i2*nb12 + (ulong)s_i1*nb11 + (ulong)s_i0*nb10);
|
104
|
+
}
|
105
|
+
|
106
|
+
y_val_ptr = (global float *)(dst_base + (ulong)current_i3*d_nb3 + (ulong)current_i2*d_nb2 + (ulong)current_i1*d_nb1 + (ulong)current_i0*d_nb0);
|
107
|
+
*y_val_ptr = *x_val_ptr;
|
108
|
+
}
|
109
|
+
}
|
@@ -0,0 +1,72 @@
|
|
1
|
+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
2
|
+
|
3
|
+
//------------------------------------------------------------------------------
|
4
|
+
// div
|
5
|
+
//------------------------------------------------------------------------------
|
6
|
+
kernel void kernel_div(
|
7
|
+
global char * src0,
|
8
|
+
ulong offset0,
|
9
|
+
global char * src1,
|
10
|
+
ulong offset1,
|
11
|
+
global char * dst,
|
12
|
+
ulong offsetd,
|
13
|
+
ulong nb00,
|
14
|
+
ulong nb01,
|
15
|
+
ulong nb02,
|
16
|
+
ulong nb03,
|
17
|
+
int ne10,
|
18
|
+
int ne11,
|
19
|
+
int ne12,
|
20
|
+
int ne13,
|
21
|
+
ulong nb10,
|
22
|
+
ulong nb11,
|
23
|
+
ulong nb12,
|
24
|
+
ulong nb13,
|
25
|
+
int ne0,
|
26
|
+
ulong nb0,
|
27
|
+
ulong nb1,
|
28
|
+
ulong nb2,
|
29
|
+
ulong nb3
|
30
|
+
) {
|
31
|
+
src0 = src0 + offset0;
|
32
|
+
src1 = src1 + offset1;
|
33
|
+
dst = dst + offsetd;
|
34
|
+
|
35
|
+
int i03 = get_group_id(2);
|
36
|
+
int i02 = get_group_id(1);
|
37
|
+
int i01 = get_group_id(0);
|
38
|
+
|
39
|
+
int i13 = i03 % ne13;
|
40
|
+
int i12 = i02 % ne12;
|
41
|
+
int i11 = i01 % ne11;
|
42
|
+
|
43
|
+
global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
44
|
+
global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
45
|
+
global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
46
|
+
|
47
|
+
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
48
|
+
const int i10 = i0 % ne10;
|
49
|
+
*((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) / *((global float *)(src1_ptr + i10*nb10));
|
50
|
+
}
|
51
|
+
}
|
52
|
+
|
53
|
+
// assumption: src1 is a row
|
54
|
+
// broadcast src1 into src0
|
55
|
+
kernel void kernel_div_row(
|
56
|
+
global float4 * src0,
|
57
|
+
ulong offset0,
|
58
|
+
global float4 * src1,
|
59
|
+
ulong offset1,
|
60
|
+
global float4 * dst,
|
61
|
+
ulong offsetd,
|
62
|
+
int ne
|
63
|
+
) {
|
64
|
+
src0 = (global float4*)((global char*)src0 + offset0);
|
65
|
+
src1 = (global float4*)((global char*)src1 + offset1);
|
66
|
+
dst = (global float4*)((global char*)dst + offsetd);
|
67
|
+
|
68
|
+
// This performs better than using %.
|
69
|
+
uint gid = get_global_id(0);
|
70
|
+
uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
|
71
|
+
dst[gid] = src0[gid] / src1[idx1];
|
72
|
+
}
|
@@ -0,0 +1,201 @@
|
|
1
|
+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
2
|
+
|
3
|
+
#define GELU_COEF_A 0.044715f
|
4
|
+
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
|
5
|
+
|
6
|
+
//------------------------------------------------------------------------------
|
7
|
+
// geglu
|
8
|
+
//------------------------------------------------------------------------------
|
9
|
+
kernel void kernel_geglu(
|
10
|
+
global char * src0,
|
11
|
+
ulong offset0,
|
12
|
+
global char * src1,
|
13
|
+
ulong offset1,
|
14
|
+
global char * dst,
|
15
|
+
ulong offsetd,
|
16
|
+
ulong nb01,
|
17
|
+
ulong nb11,
|
18
|
+
int ne0,
|
19
|
+
ulong nb1,
|
20
|
+
int ne00_off,
|
21
|
+
int ne10_off
|
22
|
+
) {
|
23
|
+
src0 = (global char*)((global char*)src0 + offset0);
|
24
|
+
src1 = (global char*)((global char*)src1 + offset1);
|
25
|
+
dst = (global char*)((global char*)dst + offsetd);
|
26
|
+
|
27
|
+
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
28
|
+
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
29
|
+
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
|
30
|
+
|
31
|
+
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
32
|
+
const float x0 = src0_row[i0];
|
33
|
+
const float x1 = src1_row[i0];
|
34
|
+
|
35
|
+
const float gelu = 0.5f*x0*(1.0f + tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
|
36
|
+
|
37
|
+
dst_row[i0] = gelu*x1;
|
38
|
+
}
|
39
|
+
}
|
40
|
+
|
41
|
+
kernel void kernel_geglu_f16(
|
42
|
+
global char * src0,
|
43
|
+
ulong offset0,
|
44
|
+
global char * src1,
|
45
|
+
ulong offset1,
|
46
|
+
global char * dst,
|
47
|
+
ulong offsetd,
|
48
|
+
ulong nb01,
|
49
|
+
ulong nb11,
|
50
|
+
int ne0,
|
51
|
+
ulong nb1,
|
52
|
+
int ne00_off,
|
53
|
+
int ne10_off
|
54
|
+
) {
|
55
|
+
src0 = (global char*)((global char*)src0 + offset0);
|
56
|
+
src1 = (global char*)((global char*)src1 + offset1);
|
57
|
+
dst = (global char*)((global char*)dst + offsetd);
|
58
|
+
|
59
|
+
global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
60
|
+
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
61
|
+
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
|
62
|
+
|
63
|
+
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
64
|
+
const half x0 = src0_row[i0];
|
65
|
+
const half x1 = src1_row[i0];
|
66
|
+
|
67
|
+
const half gelu = 0.5f*x0*(1.0f + tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
|
68
|
+
|
69
|
+
dst_row[i0] = gelu*x1;
|
70
|
+
}
|
71
|
+
}
|
72
|
+
|
73
|
+
//------------------------------------------------------------------------------
|
74
|
+
// reglu
|
75
|
+
//------------------------------------------------------------------------------
|
76
|
+
kernel void kernel_reglu(
|
77
|
+
global char * src0,
|
78
|
+
ulong offset0,
|
79
|
+
global char * src1,
|
80
|
+
ulong offset1,
|
81
|
+
global char * dst,
|
82
|
+
ulong offsetd,
|
83
|
+
ulong nb01,
|
84
|
+
ulong nb11,
|
85
|
+
int ne0,
|
86
|
+
ulong nb1,
|
87
|
+
int ne00_off,
|
88
|
+
int ne10_off
|
89
|
+
) {
|
90
|
+
src0 = (global char*)((global char*)src0 + offset0);
|
91
|
+
src1 = (global char*)((global char*)src1 + offset1);
|
92
|
+
dst = (global char*)((global char*)dst + offsetd);
|
93
|
+
|
94
|
+
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
95
|
+
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
96
|
+
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
|
97
|
+
|
98
|
+
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
99
|
+
const float x0 = src0_row[i0];
|
100
|
+
const float x1 = src1_row[i0];
|
101
|
+
|
102
|
+
dst_row[i0] = x0*x1*(x0 > 0.0f);
|
103
|
+
}
|
104
|
+
}
|
105
|
+
|
106
|
+
kernel void kernel_reglu_f16(
|
107
|
+
global char * src0,
|
108
|
+
ulong offset0,
|
109
|
+
global char * src1,
|
110
|
+
ulong offset1,
|
111
|
+
global char * dst,
|
112
|
+
ulong offsetd,
|
113
|
+
ulong nb01,
|
114
|
+
ulong nb11,
|
115
|
+
int ne0,
|
116
|
+
ulong nb1,
|
117
|
+
int ne00_off,
|
118
|
+
int ne10_off
|
119
|
+
) {
|
120
|
+
src0 = (global char*)((global char*)src0 + offset0);
|
121
|
+
src1 = (global char*)((global char*)src1 + offset1);
|
122
|
+
dst = (global char*)((global char*)dst + offsetd);
|
123
|
+
|
124
|
+
global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
125
|
+
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
126
|
+
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
|
127
|
+
|
128
|
+
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
129
|
+
const half x0 = src0_row[i0];
|
130
|
+
const half x1 = src1_row[i0];
|
131
|
+
|
132
|
+
dst_row[i0] = x0*x1*(x0 > 0.0f);
|
133
|
+
}
|
134
|
+
}
|
135
|
+
|
136
|
+
//------------------------------------------------------------------------------
|
137
|
+
// swiglu
|
138
|
+
//------------------------------------------------------------------------------
|
139
|
+
kernel void kernel_swiglu(
|
140
|
+
global char * src0,
|
141
|
+
ulong offset0,
|
142
|
+
global char * src1,
|
143
|
+
ulong offset1,
|
144
|
+
global char * dst,
|
145
|
+
ulong offsetd,
|
146
|
+
ulong nb01,
|
147
|
+
ulong nb11,
|
148
|
+
int ne0,
|
149
|
+
ulong nb1,
|
150
|
+
int ne00_off,
|
151
|
+
int ne10_off
|
152
|
+
) {
|
153
|
+
src0 = (global char*)((global char*)src0 + offset0);
|
154
|
+
src1 = (global char*)((global char*)src1 + offset1);
|
155
|
+
dst = (global char*)((global char*)dst + offsetd);
|
156
|
+
|
157
|
+
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
158
|
+
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
159
|
+
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
|
160
|
+
|
161
|
+
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
162
|
+
const float x0 = src0_row[i0];
|
163
|
+
const float x1 = src1_row[i0];
|
164
|
+
|
165
|
+
const float silu = x0 / (1.0f + exp(-x0));
|
166
|
+
|
167
|
+
dst_row[i0] = silu*x1;
|
168
|
+
}
|
169
|
+
}
|
170
|
+
|
171
|
+
kernel void kernel_swiglu_f16(
|
172
|
+
global char * src0,
|
173
|
+
ulong offset0,
|
174
|
+
global char * src1,
|
175
|
+
ulong offset1,
|
176
|
+
global char * dst,
|
177
|
+
ulong offsetd,
|
178
|
+
ulong nb01,
|
179
|
+
ulong nb11,
|
180
|
+
int ne0,
|
181
|
+
ulong nb1,
|
182
|
+
int ne00_off,
|
183
|
+
int ne10_off
|
184
|
+
) {
|
185
|
+
src0 = (global char*)((global char*)src0 + offset0);
|
186
|
+
src1 = (global char*)((global char*)src1 + offset1);
|
187
|
+
dst = (global char*)((global char*)dst + offsetd);
|
188
|
+
|
189
|
+
global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
190
|
+
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
191
|
+
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
|
192
|
+
|
193
|
+
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
194
|
+
const half x0 = src0_row[i0];
|
195
|
+
const half x1 = src1_row[i0];
|
196
|
+
|
197
|
+
const half silu = x0 / (1.0f + exp(-x0));
|
198
|
+
|
199
|
+
dst_row[i0] = silu*x1;
|
200
|
+
}
|
201
|
+
}
|
@@ -0,0 +1,72 @@
|
|
1
|
+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
2
|
+
|
3
|
+
#ifdef cl_intel_subgroups
|
4
|
+
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
5
|
+
#else
|
6
|
+
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
7
|
+
#endif
|
8
|
+
|
9
|
+
#ifdef cl_intel_required_subgroup_size
|
10
|
+
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
11
|
+
#define INTEL_GPU 1
|
12
|
+
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
13
|
+
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
14
|
+
#elif defined(cl_qcom_reqd_sub_group_size)
|
15
|
+
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
16
|
+
#define ADRENO_GPU 1
|
17
|
+
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
18
|
+
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
19
|
+
#endif
|
20
|
+
|
21
|
+
// Workgroup must be a subgroup
|
22
|
+
#ifdef INTEL_GPU
|
23
|
+
REQD_SUBGROUP_SIZE_32
|
24
|
+
#elif defined (ADRENO_GPU)
|
25
|
+
REQD_SUBGROUP_SIZE_64
|
26
|
+
#endif
|
27
|
+
kernel void kernel_group_norm(
|
28
|
+
global float * src0,
|
29
|
+
ulong offset0,
|
30
|
+
global float * dst,
|
31
|
+
ulong offsetd,
|
32
|
+
int ne,
|
33
|
+
int group_size,
|
34
|
+
float eps
|
35
|
+
) {
|
36
|
+
src0 = (global float *)((global char *)src0 + offset0);
|
37
|
+
dst = (global float *)((global char *)dst + offsetd);
|
38
|
+
|
39
|
+
int start = get_group_id(0) * group_size;
|
40
|
+
int end = start + group_size;
|
41
|
+
|
42
|
+
start += get_local_id(0);
|
43
|
+
|
44
|
+
if (end >= ne) {
|
45
|
+
end = ne;
|
46
|
+
}
|
47
|
+
|
48
|
+
float tmp = 0.0f;
|
49
|
+
|
50
|
+
for (int j = start; j < end; j += get_local_size(0)) {
|
51
|
+
tmp += src0[j];
|
52
|
+
}
|
53
|
+
|
54
|
+
tmp = sub_group_reduce_add(tmp);
|
55
|
+
|
56
|
+
const float mean = tmp / group_size;
|
57
|
+
tmp = 0.0f;
|
58
|
+
|
59
|
+
for (int j = start; j < end; j += get_local_size(0)) {
|
60
|
+
float xi = src0[j] - mean;
|
61
|
+
dst[j] = xi;
|
62
|
+
tmp += xi * xi;
|
63
|
+
}
|
64
|
+
|
65
|
+
tmp = sub_group_reduce_add(tmp);
|
66
|
+
|
67
|
+
const float variance = tmp / group_size;
|
68
|
+
const float scale = 1.0f/sqrt(variance + eps);
|
69
|
+
for (int j = start; j < end; j += get_local_size(0)) {
|
70
|
+
dst[j] *= scale;
|
71
|
+
}
|
72
|
+
}
|