@fugood/llama.node 0.3.6 → 0.3.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +17 -2
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +3 -1
- package/lib/index.js +16 -1
- package/lib/index.ts +16 -0
- package/package.json +1 -1
- package/src/EmbeddingWorker.cpp +4 -3
- package/src/LlamaCompletionWorker.cpp +4 -2
- package/src/LlamaContext.cpp +61 -6
- package/src/LlamaContext.h +1 -0
- package/src/common.hpp +6 -11
- package/src/llama.cpp/.github/workflows/build.yml +19 -17
- package/src/llama.cpp/.github/workflows/docker.yml +77 -30
- package/src/llama.cpp/.github/workflows/editorconfig.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +22 -3
- package/src/llama.cpp/CMakeLists.txt +49 -24
- package/src/llama.cpp/common/arg.cpp +82 -26
- package/src/llama.cpp/common/arg.h +3 -0
- package/src/llama.cpp/common/common.cpp +192 -72
- package/src/llama.cpp/common/common.h +51 -18
- package/src/llama.cpp/common/ngram-cache.cpp +12 -12
- package/src/llama.cpp/common/ngram-cache.h +2 -2
- package/src/llama.cpp/common/sampling.cpp +11 -6
- package/src/llama.cpp/common/speculative.cpp +18 -15
- package/src/llama.cpp/docs/build.md +2 -0
- package/src/llama.cpp/examples/batched/batched.cpp +9 -7
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +3 -3
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +10 -8
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +11 -8
- package/src/llama.cpp/examples/cvector-generator/mean.hpp +1 -1
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +8 -7
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +7 -6
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +8 -7
- package/src/llama.cpp/examples/gguf/gguf.cpp +10 -6
- package/src/llama.cpp/examples/gguf-hash/gguf-hash.cpp +1 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +8 -7
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +13 -10
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +13 -12
- package/src/llama.cpp/examples/infill/infill.cpp +23 -24
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +44 -13
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -6
- package/src/llama.cpp/examples/llava/clip.cpp +4 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +9 -6
- package/src/llama.cpp/examples/llava/llava.cpp +2 -2
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +8 -4
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +11 -8
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -7
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +4 -9
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +3 -7
- package/src/llama.cpp/examples/lookup/lookup.cpp +5 -6
- package/src/llama.cpp/examples/main/main.cpp +51 -29
- package/src/llama.cpp/examples/parallel/parallel.cpp +5 -6
- package/src/llama.cpp/examples/passkey/passkey.cpp +7 -5
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +37 -23
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -14
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +8 -8
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +12 -0
- package/src/llama.cpp/examples/run/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +1351 -0
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +114 -0
- package/src/llama.cpp/examples/run/run.cpp +175 -61
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -25
- package/src/llama.cpp/examples/server/CMakeLists.txt +1 -0
- package/src/llama.cpp/examples/server/httplib.h +1295 -409
- package/src/llama.cpp/examples/server/server.cpp +387 -181
- package/src/llama.cpp/examples/server/tests/requirements.txt +1 -0
- package/src/llama.cpp/examples/server/utils.hpp +170 -58
- package/src/llama.cpp/examples/simple/simple.cpp +9 -8
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +16 -12
- package/src/llama.cpp/examples/speculative/speculative.cpp +22 -23
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +8 -12
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +17 -5
- package/src/llama.cpp/examples/tts/tts.cpp +64 -23
- package/src/llama.cpp/ggml/CMakeLists.txt +5 -21
- package/src/llama.cpp/ggml/include/ggml-backend.h +2 -0
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -0
- package/src/llama.cpp/ggml/include/ggml.h +36 -145
- package/src/llama.cpp/ggml/include/gguf.h +202 -0
- package/src/llama.cpp/ggml/src/CMakeLists.txt +6 -3
- package/src/llama.cpp/ggml/src/ggml-alloc.c +5 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +0 -1
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +79 -49
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +33 -23
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +57 -72
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +87 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +335 -66
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +10 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1090 -378
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +2 -2
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +3 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +3 -1
- package/src/llama.cpp/ggml/src/ggml-impl.h +11 -16
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +16 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +6 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +154 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +9 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +18 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +40 -95
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +48 -48
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +24 -24
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -164
- package/src/llama.cpp/ggml/src/ggml-sycl/gla.cpp +105 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gla.hpp +8 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +3 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +7 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +74 -4
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +314 -116
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -2
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +9 -3
- package/src/llama.cpp/ggml/src/ggml.c +117 -1327
- package/src/llama.cpp/ggml/src/gguf.cpp +1329 -0
- package/src/llama.cpp/include/llama-cpp.h +6 -1
- package/src/llama.cpp/include/llama.h +138 -75
- package/src/llama.cpp/src/CMakeLists.txt +13 -1
- package/src/llama.cpp/src/llama-adapter.cpp +347 -0
- package/src/llama.cpp/src/llama-adapter.h +74 -0
- package/src/llama.cpp/src/llama-arch.cpp +1487 -0
- package/src/llama.cpp/src/llama-arch.h +400 -0
- package/src/llama.cpp/src/llama-batch.cpp +368 -0
- package/src/llama.cpp/src/llama-batch.h +88 -0
- package/src/llama.cpp/src/llama-chat.cpp +578 -0
- package/src/llama.cpp/src/llama-chat.h +52 -0
- package/src/llama.cpp/src/llama-context.cpp +1775 -0
- package/src/llama.cpp/src/llama-context.h +128 -0
- package/src/llama.cpp/src/llama-cparams.cpp +1 -0
- package/src/llama.cpp/src/llama-cparams.h +37 -0
- package/src/llama.cpp/src/llama-grammar.cpp +5 -4
- package/src/llama.cpp/src/llama-grammar.h +3 -1
- package/src/llama.cpp/src/llama-hparams.cpp +71 -0
- package/src/llama.cpp/src/llama-hparams.h +139 -0
- package/src/llama.cpp/src/llama-impl.cpp +167 -0
- package/src/llama.cpp/src/llama-impl.h +16 -136
- package/src/llama.cpp/src/llama-kv-cache.cpp +718 -0
- package/src/llama.cpp/src/llama-kv-cache.h +218 -0
- package/src/llama.cpp/src/llama-mmap.cpp +589 -0
- package/src/llama.cpp/src/llama-mmap.h +67 -0
- package/src/llama.cpp/src/llama-model-loader.cpp +1124 -0
- package/src/llama.cpp/src/llama-model-loader.h +167 -0
- package/src/llama.cpp/src/llama-model.cpp +3953 -0
- package/src/llama.cpp/src/llama-model.h +370 -0
- package/src/llama.cpp/src/llama-quant.cpp +934 -0
- package/src/llama.cpp/src/llama-quant.h +1 -0
- package/src/llama.cpp/src/llama-sampling.cpp +147 -32
- package/src/llama.cpp/src/llama-sampling.h +3 -19
- package/src/llama.cpp/src/llama-vocab.cpp +1832 -575
- package/src/llama.cpp/src/llama-vocab.h +97 -142
- package/src/llama.cpp/src/llama.cpp +7160 -20314
- package/src/llama.cpp/src/unicode.cpp +8 -3
- package/src/llama.cpp/tests/CMakeLists.txt +2 -0
- package/src/llama.cpp/tests/test-autorelease.cpp +3 -3
- package/src/llama.cpp/tests/test-backend-ops.cpp +370 -59
- package/src/llama.cpp/tests/test-chat-template.cpp +162 -125
- package/src/llama.cpp/tests/test-gguf.cpp +222 -187
- package/src/llama.cpp/tests/test-model-load-cancel.cpp +1 -1
- package/src/llama.cpp/tests/test-sampling.cpp +0 -1
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +4 -4
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +9 -7
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +8 -6
@@ -986,7 +986,7 @@ inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
 #define GGML_F16_STEP 32
 #define GGML_F16_EPR  4
 
-static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) {
+static inline __m128 __sse_f16x4_load(const ggml_fp16_t * x) {
     float tmp[4];
 
     tmp[0] = GGML_FP16_TO_FP32(x[0]);
@@ -997,7 +997,7 @@ static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) {
     return _mm_loadu_ps(tmp);
 }
 
-static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
+static inline void __sse_f16x4_store(ggml_fp16_t * x, __m128 y) {
     float arr[4];
 
     _mm_storeu_ps(arr, y);
@@ -3967,6 +3967,57 @@ static void ggml_compute_forward_dup_bytes(
     }
 }
 
+static void ggml_compute_forward_dup_q(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const enum ggml_type type = src0->type;
+    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
+
+    size_t qk = ggml_blck_size(type);
+    const int64_t nr = ggml_nelements(src1) / qk;
+
+    // destination must be contiguous in the first dimension
+    GGML_ASSERT(nb10 == ggml_type_size(dst->type));
+    // must either have first dimension large enough to hold a row, or fully contiguous
+    GGML_ASSERT((ne10 % qk) == 0 || ggml_is_contiguous(dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+
+        uint32_t i = ir * qk;
+
+        const int64_t i03 = i/(ne00 * ne01 * ne02);
+        const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+        const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+        const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+        const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
+
+        const int64_t i13 = i/(ne10 * ne11 * ne12);
+        const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+        const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+        const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+        const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+
+        dequantize_row_q(
+                (const void *) ((char *) src0->data + x_offset),
+                     (float *) ((char *)  dst->data + dst_offset), qk);
+    }
+}
+
 static void ggml_compute_forward_dup(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -3993,6 +4044,10 @@ static void ggml_compute_forward_dup(
             } break;
         default:
             {
+                if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_dup_q(params, dst);
+                    break;
+                }
                 GGML_ABORT("fatal error");
             }
     }
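The two hunks above add a quantized-to-F32 path to the dup/copy operator. As a rough illustration of what this enables at the public API level (not part of the diff), a quantized tensor can now be copied straight into an F32 tensor with ggml_cpy on the CPU backend. The helper below is a minimal sketch, assuming the quantized tensor already holds valid data, the context has enough memory, and ggml_graph_compute_with_ctx is available from ggml-cpu.h in this tree.

// Sketch only: dequantize a quantized 2-D tensor into F32 via ggml_cpy.
#include "ggml.h"
#include "ggml-cpu.h"

static struct ggml_tensor * dequantize_to_f32(struct ggml_context * ctx,
                                              struct ggml_tensor  * quantized) {
    // F32 destination with the same 2-D shape as the quantized source
    struct ggml_tensor * f32 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,
                                                  quantized->ne[0], quantized->ne[1]);

    // quantized -> F32 copy; on the CPU backend this should now be served by
    // the ggml_compute_forward_dup_q path added in this release
    struct ggml_tensor * out = ggml_cpy(ctx, quantized, f32);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);

    return f32;
}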
@@ -6691,20 +6746,20 @@ static void ggml_compute_forward_silu_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor *
-    const struct ggml_tensor *
+    const struct ggml_tensor * grad = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
 
     assert(ggml_is_contiguous_1(grad));
-    assert(ggml_is_contiguous_1(
+    assert(ggml_is_contiguous_1(src1));
     assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(
-    assert(ggml_are_same_shape(
+    assert(ggml_are_same_shape(src1, dst));
+    assert(ggml_are_same_shape(src1, grad));
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nc =
-    const int nr = ggml_nrows(
+    const int nc = src1->ne[0];
+    const int nr = ggml_nrows(src1);
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -6716,7 +6771,7 @@ static void ggml_compute_forward_silu_back_f32(
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_silu_backward_f32(nc,
                 (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *)
+                (float *) ((char *) src1->data + i1*(src1->nb[1])),
                 (float *) ((char *) grad->data + i1*(grad->nb[1])));
 
 #ifndef NDEBUG
@@ -6895,7 +6950,7 @@ static void ggml_compute_forward_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    GGML_ASSERT(eps
+    GGML_ASSERT(eps >= 0.0f);
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -6966,7 +7021,7 @@ static void ggml_compute_forward_rms_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    GGML_ASSERT(eps
+    GGML_ASSERT(eps >= 0.0f);
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -7018,12 +7073,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
+    const struct ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
 
     GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
 
     GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -7042,8 +7098,8 @@ static void ggml_compute_forward_rms_norm_back_f32(
                 const int64_t i12 = i02;
                 const int64_t i13 = i03;
 
-                const float *
-                const float *
+                const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                const float * x  = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
 
                 ggml_float sum_xx  = 0.0;
                 ggml_float sum_xdz = 0.0;
@@ -7066,9 +7122,9 @@ static void ggml_compute_forward_rms_norm_back_f32(
                 {
                     // z = rms_norm(x)
                     //
-                    // rms_norm(
+                    // rms_norm(src1) =
                     //     scale(
-                    //
+                    //         src1,
                     //         div(
                     //             1,
                     //             sqrt(
@@ -7076,13 +7132,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
                     //                     scale(
                     //                         sum(
                     //                             sqr(
-                    //
+                    //                                 src1)),
                     //                         (1.0/N)),
                     //                     eps))));
 
                     // postorder:
                     // ## op    args         grad
-                    // 00 param
+                    // 00 param src1         grad[#00]
                     // 01 const 1
                     // 02 sqr   (#00)        grad[#02]
                     // 03 sum   (#02)        grad[#03]
@@ -7159,6 +7215,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
                 // dx := scale(dx, rrms)
                 float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
+                // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
                 ggml_vec_cpy_f32  (ne00, dx, x);
                 // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
                 ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
@@ -7419,14 +7476,14 @@ static void ggml_compute_forward_mul_mat(
     if (src1_cont) {
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!llamafile_sgemm(
+                if (!llamafile_sgemm(params,
+                                     ne01, ne11, ne00/ggml_blck_size(src0->type),
                                      (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
                                      nb01/ggml_type_size(src0->type),
                                      (const char *)src1->data + i12*nb12 + i13*nb13,
                                      nb11/ggml_type_size(src1->type),
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),
-                                     ith, nth,
                                      src0->type,
                                      src1->type,
                                      dst->type))
@@ -7471,14 +7528,14 @@ UseGgmlGemm1:;
 
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!llamafile_sgemm(
+                if (!llamafile_sgemm(params,
+                                     ne01, ne11, ne00/ggml_blck_size(src0->type),
                                      (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
                                      nb01/ggml_type_size(src0->type),
                                      (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
                                      row_size/ggml_type_size(vec_dot_type),
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),
-                                     ith, nth,
                                      src0->type,
                                      vec_dot_type,
                                      dst->type))
@@ -7750,12 +7807,13 @@ static void ggml_compute_forward_out_prod_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_ASSERT(ne0
-    GGML_ASSERT(ne1
-    GGML_ASSERT(ne2
-    GGML_ASSERT(
-
-    GGML_ASSERT(
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    GGML_ASSERT(ne2 % ne02 == 0);
+    GGML_ASSERT(ne3 % ne03 == 0);
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == sizeof(float));
@@ -7797,6 +7855,10 @@ static void ggml_compute_forward_out_prod_f32(
     const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
     const int64_t blck_1 = 16;
 
+    // dps == dst per src0, used for group query attention
+    const int64_t dps2 = ne2 / ne02;
+    const int64_t dps3 = ne3 / ne03;
+
     for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
         const int64_t bir1 = MIN(bir + blck_1, ir1);
         for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
@@ -7807,8 +7869,8 @@ static void ggml_compute_forward_out_prod_f32(
             const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
             const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
 
-            const int64_t i02 = i2;
-            const int64_t i03 = i3;
+            const int64_t i02 = i2 / dps2;
+            const int64_t i03 = i3 / dps3;
 
             //const int64_t i10 = i1;
             const int64_t i12 = i2;
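The out_prod hunks above replace the exact shape equalities with broadcast-friendly assertions: dst's batch dimensions only need to be multiples of src0's, and dps2/dps3 ("dst per src0") map several output heads onto one src0 head, which is what grouped-query attention needs. A tiny standalone sketch of that index mapping (illustrative head counts, not taken from the diff):

// Sketch: the broadcast mapping introduced for out_prod.
#include <stdio.h>

int main(void) {
    const int ne2  = 32;            // dst heads (e.g. query heads)
    const int ne02 = 8;             // src0 heads (e.g. KV heads)
    const int dps2 = ne2 / ne02;    // "dst per src0" == GQA group size

    for (int i2 = 0; i2 < ne2; i2++) {
        const int i02 = i2 / dps2;  // same mapping as the new out_prod code
        printf("dst head %2d -> src0 head %d\n", i2, i02);
    }
    return 0;
}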
@@ -8906,9 +8968,9 @@ static void ggml_compute_forward_soft_max(
     }
 }
 
-//
+// ggml_compute_forward_soft_max_ext_back
 
-static void
+static void ggml_compute_forward_soft_max_ext_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
@@ -8921,6 +8983,14 @@ static void ggml_compute_forward_soft_max_back_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_are_same_shape(src1, dst));
 
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+
+    GGML_ASSERT(max_bias == 0.0f);
+
     // TODO: handle transposed/permuted matrices
 
     const int ith = params->ith;
@@ -8969,10 +9039,11 @@ static void ggml_compute_forward_soft_max_back_f32(
 
         // linear runtime, no additional memory
        float dot_y_dy = 0;
-        ggml_vec_dot_f32
-        ggml_vec_cpy_f32
-        ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
-        ggml_vec_mul_f32
+        ggml_vec_dot_f32  (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
+        ggml_vec_cpy_f32  (nc, dx, dy);
+        ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
+        ggml_vec_mul_f32  (nc, dx, dx, y);
+        ggml_vec_scale_f32(nc, dx, scale);
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
@@ -8983,7 +9054,7 @@ static void ggml_compute_forward_soft_max_back_f32(
     }
 }
 
-static void
+static void ggml_compute_forward_soft_max_ext_back(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
@@ -8992,7 +9063,7 @@ static void ggml_compute_forward_soft_max_back(
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-
+                ggml_compute_forward_soft_max_ext_back_f32(params, dst);
             } break;
         default:
            {
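With the rename to ggml_compute_forward_soft_max_ext_back, the backward kernel now reads the scale parameter from op_params and applies it, so each row effectively computes dx = scale * (dy - dot(y, dy)) * y. The scalar sketch below (not from the diff, illustrative names) mirrors the ggml_vec_* call sequence above:

// Sketch of the per-row computation in plain scalar C.
#include <stddef.h>

static void soft_max_back_row(size_t nc, float * dx,
                              const float * y, const float * dy, float scale) {
    float dot_y_dy = 0.0f;
    for (size_t i = 0; i < nc; i++) {
        dot_y_dy += y[i] * dy[i];            // ggml_vec_dot_f32
    }
    for (size_t i = 0; i < nc; i++) {
        dx[i]  = dy[i];                      // ggml_vec_cpy_f32
        dx[i] -= dot_y_dy;                   // ggml_vec_acc1_f32
        dx[i] *= y[i];                       // ggml_vec_mul_f32
        dx[i] *= scale;                      // ggml_vec_scale_f32 (new in this diff)
    }
}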
@@ -9985,9 +10056,10 @@ static void ggml_compute_forward_im2col_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
+    const struct ggml_tensor * src1 = dst->src[1]; // convolution kernel
 
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -10009,11 +10081,11 @@ static void ggml_compute_forward_im2col_back_f32(
     const int64_t IH = is_2D ? ne1 : 1;
     const int64_t IW = ne0;
 
-    const int64_t KH = is_2D ?
-    const int64_t KW =
+    const int64_t KH = is_2D ? ne11 : 1;
+    const int64_t KW = ne10;
 
-    const int64_t OH = is_2D ?
-    const int64_t OW =
+    const int64_t OH = is_2D ? ne02 : 1;
+    const int64_t OW = ne01;
 
     int ofs0 = is_2D ? nb3 : nb2;
     int ofs1 = is_2D ? nb2 : nb1;
@@ -10059,9 +10131,9 @@ static void ggml_compute_forward_im2col_back_f32(
                                 continue;
                             }
 
-                            const float * const
+                            const float * const grad_in = (const float *) src0->data
                                 + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                            grad +=
+                            grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
                         }
                     }
                     float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
@@ -11803,9 +11875,9 @@ static void ggml_compute_forward_add_rel_pos(
 static void ggml_compute_forward_rwkv_wkv6_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const int64_t T = dst->src[1]->ne[
+    const int64_t T = dst->src[1]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t HEADS = dst->src[1]->ne[
+    const int64_t HEADS = dst->src[1]->ne[1];
     const int64_t n_seqs = dst->src[5]->ne[1];
     const int64_t head_size = C / HEADS;
 
@@ -12000,6 +12072,197 @@ static void ggml_compute_forward_rwkv_wkv6(
     }
 }
 
+// ggml_compute_forward_gla
+
+static void ggml_compute_forward_gla_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[4]->ne[1];
+    const int64_t head_size = C / HEADS;
+    const float scale = ggml_get_op_params_f32(dst, 0);
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                        (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * k = (float *) dst->src[0]->data;
+    float * v = (float *) dst->src[1]->data;
+    float * q = (float *) dst->src[2]->data;
+    float * g = (float *) dst->src[3]->data;
+
+    size_t t_stride = HEADS * head_size; // Same to C
+
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
+
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
+
+
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define GLA_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define GLA_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define GLA_VECTOR_SIZE 4
+    #endif
+
+    #ifdef GLA_VECTOR_SIZE
+        const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X q_vec = GGML_F32X_SET1(q_val);
+                    GGML_F32X g_vec = GGML_F32X_SET1(g_val);
+
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * GLA_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                        // Load x elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                        // Compute temp = prev_state * g + kv
+                        GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
+
+                        // Update dst: dst += temp * q
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
+                    }
+
+                    // Handle remaining elements, this will not be used.
+                    for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val + prev_state_val * g_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = prev_state_val * g_val + kv_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+    #endif
+}
+
+
+static void ggml_compute_forward_gla(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gla_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
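The new gated-linear-attention kernel above keeps a per-head head_size x head_size state and, for each time step, evaluates state' = g * state + k*v (elementwise gate times previous state, plus the key/value outer product) while accumulating the q-weighted state into the output. The scalar sketch below (single head, single step, illustrative names, not from the diff) mirrors the non-vectorized branch of the kernel:

// Sketch: one GLA recurrence step for a single head of size D.
// `state` holds the previous D x D state and is updated in place;
// `out` must be zeroed by the caller, matching the kernel's memset of dst.
static void gla_step(int D, const float * k, const float * v, const float * q,
                     const float * g, float scale,
                     float * state /* D x D */, float * out /* D */) {
    for (int i = 0; i < D; i++) {
        const float k_val = k[i];
        const float q_val = q[i] * scale;
        const float g_val = g[i];
        for (int j = 0; j < D; j++) {
            // state' = g * state_prev + k * v   (per (i, j) element)
            const float temp = state[i*D + j] * g_val + k_val * v[j];
            // output accumulates the q-weighted updated state
            out[j] += temp * q_val;
            state[i*D + j] = temp;
        }
    }
}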
@@ -12293,22 +12556,22 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor *
-    const struct ggml_tensor *
-    const struct ggml_tensor *
+    const struct ggml_tensor * grad  = dst->src[0]; // gradient of forward pass output
+    const struct ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
+    const struct ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
 
     GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_is_contiguous(
-    GGML_ASSERT(ggml_is_contiguous(
-    GGML_ASSERT(ggml_is_contiguous(
-    GGML_ASSERT(ggml_are_same_shape(
+    GGML_ASSERT(ggml_is_contiguous(src0f));
+    GGML_ASSERT(ggml_is_contiguous(src1f));
+    GGML_ASSERT(ggml_is_contiguous(grad));
+    GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
 
     const int64_t ith = params->ith;
     const int64_t nth = params->nth;
 
     // TODO: handle transposed/permuted matrices
-    const int64_t nc =
-    const int64_t nr = ggml_nrows(
+    const int64_t nc = src0f->ne[0];
+    const int64_t nr = ggml_nrows(src0f);
 
     // rows per thread
     const int64_t dr = (nr + nth - 1)/nth;
@@ -12317,12 +12580,12 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
     const int64_t ir0 = dr*ith;
     const int64_t ir1 = MIN(ir0 + dr, nr);
 
-    const float d_by_nr = ((const float *)
+    const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
 
     for (int64_t i1 = ir0; i1 < ir1; i1++) {
-        float
-        float * s0 = (float *)((char *)
-        float * s1 = (float *)((char *)
+        float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
+        const float * s0 = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
+        const float * s1 = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
 
 #ifndef NDEBUG
         for (int64_t i = 0; i < nc; ++i) {
@@ -12335,11 +12598,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         // soft_max
         float max = -INFINITY;
         ggml_vec_max_f32(nc, &max, s0);
-        ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
+        const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
         assert(sum > 0.0);
         ggml_vec_scale_f32(nc, ds0, 1.0/sum);
 
-        // grad(
+        // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
         ggml_vec_sub_f32(nc, ds0, ds0, s1);
         ggml_vec_scale_f32(nc, ds0, d_by_nr);
 
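The cross-entropy backward hunks mainly name the operands (grad, src0f, src1f) and document the gradient being computed: grad(src0f) = (softmax(src0f) - src1f) * grad_out / nr. A self-contained scalar sketch of that per-row formula (illustrative names, not the ggml_vec_* implementation):

// Sketch: per-row cross-entropy-loss backward, scalar C.
#include <math.h>
#include <stddef.h>

static void cross_entropy_back_row(size_t nc, float * ds0,
                                   const float * s0, const float * s1,
                                   float grad_out, float nr) {
    // softmax(s0) computed in place into ds0
    float max = -INFINITY;
    for (size_t i = 0; i < nc; i++) { if (s0[i] > max) max = s0[i]; }
    float sum = 0.0f;
    for (size_t i = 0; i < nc; i++) { ds0[i] = expf(s0[i] - max); sum += ds0[i]; }
    for (size_t i = 0; i < nc; i++) { ds0[i] /= sum; }

    // (softmax(s0) - s1) * grad_out / nr
    const float d_by_nr = grad_out / nr;
    for (size_t i = 0; i < nc; i++) {
        ds0[i] = (ds0[i] - s1[i]) * d_by_nr;
    }
}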
@@ -12636,7 +12899,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SOFT_MAX_BACK:
             {
-
+                ggml_compute_forward_soft_max_ext_back(params, tensor);
             } break;
         case GGML_OP_ROPE:
             {
@@ -12749,6 +13012,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv6(params, tensor);
             } break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            {
+                ggml_compute_forward_gla(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -13047,6 +13314,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
@@ -13472,6 +13740,7 @@ struct ggml_cplan ggml_graph_plan(
             } break;
         case GGML_OP_SOFT_MAX:
         case GGML_OP_ROPE:
+        case GGML_OP_ROPE_BACK:
             {
                 cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
             } break;