@fugood/llama.node 0.3.16 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +6 -1
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +44 -2
- package/lib/index.js +132 -1
- package/lib/index.ts +203 -3
- package/package.json +2 -1
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +374 -19
- package/src/LlamaCompletionWorker.h +31 -10
- package/src/LlamaContext.cpp +216 -7
- package/src/LlamaContext.h +12 -0
- package/src/common.hpp +15 -0
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +233 -0
- package/src/llama.cpp/.github/workflows/build.yml +89 -767
- package/src/llama.cpp/.github/workflows/docker.yml +9 -6
- package/src/llama.cpp/.github/workflows/release.yml +716 -0
- package/src/llama.cpp/.github/workflows/server.yml +19 -23
- package/src/llama.cpp/CMakeLists.txt +11 -1
- package/src/llama.cpp/cmake/build-info.cmake +8 -2
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
- package/src/llama.cpp/common/CMakeLists.txt +35 -4
- package/src/llama.cpp/common/arg.cpp +844 -121
- package/src/llama.cpp/common/arg.h +9 -0
- package/src/llama.cpp/common/chat.cpp +129 -107
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.cpp +64 -518
- package/src/llama.cpp/common/common.h +35 -45
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
- package/src/llama.cpp/common/llguidance.cpp +31 -47
- package/src/llama.cpp/common/minja/chat-template.hpp +23 -11
- package/src/llama.cpp/common/minja/minja.hpp +186 -127
- package/src/llama.cpp/common/regex-partial.cpp +204 -0
- package/src/llama.cpp/common/regex-partial.h +56 -0
- package/src/llama.cpp/common/sampling.cpp +60 -50
- package/src/llama.cpp/docs/build.md +122 -7
- package/src/llama.cpp/examples/CMakeLists.txt +2 -32
- package/src/llama.cpp/examples/batched/batched.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +9 -12
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/parallel/parallel.cpp +89 -15
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/sycl/build.sh +2 -2
- package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
- package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/training/finetune.cpp +96 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +35 -2
- package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
- package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
- package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
- package/src/llama.cpp/ggml/include/ggml.h +76 -106
- package/src/llama.cpp/ggml/src/CMakeLists.txt +11 -8
- package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
- package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +66 -33
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -194
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1060 -410
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1008 -13533
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +31 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +90 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +266 -72
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1034 -88
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8796 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +252 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +106 -14
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -262
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +307 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +125 -45
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +239 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +9 -307
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +944 -438
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +507 -411
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
- package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +83 -49
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1278 -282
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +32 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +133 -30
- package/src/llama.cpp/ggml/src/ggml.c +170 -265
- package/src/llama.cpp/ggml/src/gguf.cpp +34 -33
- package/src/llama.cpp/include/llama.h +82 -22
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +5 -3
- package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
- package/src/llama.cpp/scripts/xxd.cmake +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +4 -2
- package/src/llama.cpp/src/llama-adapter.cpp +43 -1
- package/src/llama.cpp/src/llama-arch.cpp +163 -17
- package/src/llama.cpp/src/llama-arch.h +16 -0
- package/src/llama.cpp/src/llama-batch.cpp +5 -1
- package/src/llama.cpp/src/llama-batch.h +2 -1
- package/src/llama.cpp/src/llama-chat.cpp +91 -16
- package/src/llama.cpp/src/llama-chat.h +7 -2
- package/src/llama.cpp/src/llama-context.cpp +479 -575
- package/src/llama.cpp/src/llama-context.h +44 -33
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +209 -157
- package/src/llama.cpp/src/llama-graph.h +38 -14
- package/src/llama.cpp/src/llama-hparams.h +13 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +1604 -543
- package/src/llama.cpp/src/llama-kv-cache.h +283 -171
- package/src/llama.cpp/src/llama-memory.h +12 -2
- package/src/llama.cpp/src/llama-mmap.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +34 -20
- package/src/llama.cpp/src/llama-model-loader.h +5 -3
- package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/src/llama.cpp/src/llama-model-saver.h +37 -0
- package/src/llama.cpp/src/llama-model.cpp +1803 -330
- package/src/llama.cpp/src/llama-model.h +21 -2
- package/src/llama.cpp/src/llama-quant.cpp +33 -10
- package/src/llama.cpp/src/llama-sampling.cpp +25 -7
- package/src/llama.cpp/src/llama-vocab.cpp +86 -10
- package/src/llama.cpp/src/llama-vocab.h +6 -0
- package/src/llama.cpp/src/llama.cpp +15 -1
- package/src/llama.cpp/tests/CMakeLists.txt +52 -31
- package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
- package/src/llama.cpp/tests/test-backend-ops.cpp +189 -90
- package/src/llama.cpp/tests/test-chat-template.cpp +26 -6
- package/src/llama.cpp/tests/test-chat.cpp +15 -3
- package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
- package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
- package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
- package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
- package/src/llama.cpp/tests/test-opt.cpp +33 -21
- package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
- package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
- package/src/llama.cpp/tests/test-sampling.cpp +1 -1
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
- package/src/llama.cpp/tools/CMakeLists.txt +39 -0
- package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +3 -3
- package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +1 -1
- package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +15 -16
- package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
- package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +623 -274
- package/src/llama.cpp/{examples → tools}/main/main.cpp +22 -14
- package/src/llama.cpp/tools/mtmd/CMakeLists.txt +47 -0
- package/src/llama.cpp/tools/mtmd/clip-impl.h +365 -0
- package/src/llama.cpp/tools/mtmd/clip.cpp +3646 -0
- package/src/llama.cpp/tools/mtmd/clip.h +99 -0
- package/src/llama.cpp/tools/mtmd/deprecation-warning.cpp +22 -0
- package/src/llama.cpp/tools/mtmd/mtmd-cli.cpp +370 -0
- package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
- package/src/llama.cpp/tools/mtmd/mtmd.cpp +678 -0
- package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
- package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +21 -5
- package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +53 -3
- package/src/llama.cpp/tools/rpc/CMakeLists.txt +4 -0
- package/src/llama.cpp/tools/rpc/rpc-server.cpp +322 -0
- package/src/llama.cpp/tools/run/CMakeLists.txt +16 -0
- package/src/llama.cpp/{examples → tools}/run/run.cpp +30 -30
- package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
- package/src/llama.cpp/{examples → tools}/server/httplib.h +313 -247
- package/src/llama.cpp/{examples → tools}/server/server.cpp +529 -215
- package/src/llama.cpp/{examples → tools}/server/utils.hpp +427 -6
- package/src/llama.cpp/{examples → tools}/tts/tts.cpp +6 -9
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/infill.cpp +0 -590
- package/src/llama.cpp/examples/llava/CMakeLists.txt +0 -66
- package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
- package/src/llama.cpp/examples/llava/clip.cpp +0 -3206
- package/src/llama.cpp/examples/llava/clip.h +0 -118
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
- package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
- package/src/llama.cpp/examples/llava/llava.cpp +0 -574
- package/src/llama.cpp/examples/llava/llava.h +0 -49
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +0 -584
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/rpc/CMakeLists.txt +0 -2
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +0 -171
- package/src/llama.cpp/examples/run/CMakeLists.txt +0 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
Diff of package/src/llama.cpp/ggml/src/ggml.c (+170 -265):

```diff
@@ -4,6 +4,7 @@
 #include "ggml-backend.h"
 #include "ggml-impl.h"
 #include "ggml-threading.h"
+#include "ggml-cpu.h"
 #include "ggml.h"
 
 // FIXME: required here for quantization functions
```
```diff
@@ -382,58 +383,16 @@ void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
     }
 }
 
-// FIXME: these functions must detect the instruction set at runtime, since they are part of the core ggml library
-// currently, the ggml_cpu_has_* functions are entirely compile-time
 void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
-    int64_t i = 0;
-#if defined(__F16C__)
-    //if (ggml_cpu_has_f16c()) {
-        for (; i + 7 < n; i += 8) {
-            __m256 x_vec = _mm256_loadu_ps(x + i);
-            __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
-            _mm_storeu_si128((__m128i *)(y + i), y_vec);
-        }
-        for(; i + 3 < n; i += 4) {
-            __m128 x_vec = _mm_loadu_ps(x + i);
-            __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
-            _mm_storel_epi64((__m128i *)(y + i), y_vec);
-        }
-    //}
-#endif
-    for (; i < n; i++) {
+    int i = 0;
+    for (; i < n; ++i) {
         y[i] = GGML_FP32_TO_FP16(x[i]);
     }
 }
 
 void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
-    int64_t i = 0;
-#if defined(__AVX512F__)
-    //if (ggml_cpu_has_avx512()) {
-        for (; i + 16 <= n; i += 16) {
-            _mm512_storeu_ps(y + i,
-                             _mm512_castsi512_ps(
-                                 _mm512_slli_epi32(
-                                     _mm512_cvtepu16_epi32(
-                                         _mm256_loadu_si256(
-                                             (const __m256i *)(x + i))),
-                                     16)));
-        }
-    //}
-#endif
-#if defined(__AVX2__)
-    //if (ggml_cpu_has_avx2()) {
-        for (; i + 8 <= n; i += 8) {
-            _mm256_storeu_ps(y + i,
-                             _mm256_castsi256_ps(
-                                 _mm256_slli_epi32(
-                                     _mm256_cvtepu16_epi32(
-                                         _mm_loadu_si128(
-                                             (const __m128i *)(x + i))),
-                                     16)));
-        }
-    //}
-#endif
-    for (; i < n; i++) {
+    int i = 0;
+    for (; i < n; ++i) {
         y[i] = GGML_BF16_TO_FP32(x[i]);
     }
 }
```
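The hand-written F16C/AVX2/AVX-512 conversion paths are removed from core ggml.c, leaving portable scalar loops; the SIMD-dispatched variants now live in the ggml-cpu backend. A minimal sketch of the row converters, using only the public `ggml.h` API:

```c
#include <stdio.h>
#include "ggml.h"

int main(void) {
    float src[4] = { 0.5f, 1.0f, -2.0f, 3.25f };
    ggml_fp16_t half[4];
    float back[4];

    // scalar conversion in core ggml; SIMD-accelerated variants live in ggml-cpu
    ggml_fp32_to_fp16_row(src, half, 4);
    ggml_fp16_to_fp32_row(half, back, 4);

    for (int i = 0; i < 4; ++i) {
        printf("%.2f -> %.2f\n", src[i], back[i]);
    }
    return 0;
}
```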
```diff
@@ -956,6 +915,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CONV_TRANSPOSE_1D",
     "IM2COL",
     "IM2COL_BACK",
+    "CONV_2D_DW",
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
     "POOL_2D",
@@ -982,23 +942,18 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
 
     "UNARY",
 
-    "MAP_UNARY",
-    "MAP_BINARY",
-
-    "MAP_CUSTOM1_F32",
-    "MAP_CUSTOM2_F32",
-    "MAP_CUSTOM3_F32",
-
     "MAP_CUSTOM1",
     "MAP_CUSTOM2",
     "MAP_CUSTOM3",
 
+    "CUSTOM",
+
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT ==
+static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1055,6 +1010,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "conv_transpose_1d(x)",
     "im2col(x)",
    "im2col_back(x)",
+    "conv_2d_dw(x)",
     "conv_transpose_2d(x)",
     "pool_1d(x)",
     "pool_2d(x)",
@@ -1081,23 +1037,18 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
     "unary(x)",
 
-    "f(x)",
-    "f(x,y)",
-
-    "custom_f32(x)",
-    "custom_f32(x,y)",
-    "custom_f32(x,y,z)",
+    "map_custom(x)",
+    "map_custom(x,y)",
+    "map_custom(x,y,z)",
 
     "custom(x)",
-    "custom(x,y)",
-    "custom(x,y,z)",
 
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT ==
+static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1159,6 +1110,12 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
 }
 
 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
+    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+        if (tensor->ne[i] <= 0) {
+            return 0;
+        }
+    }
+
     size_t nbytes;
     const size_t blck_size = ggml_blck_size(tensor->type);
     if (blck_size == 1) {
@@ -1342,12 +1299,23 @@ bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
     return ggml_is_contiguous_n(tensor, 2);
 }
 
+bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor) {
+    return ggml_nbytes(tensor) == ggml_nelements(tensor) * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
+}
+
 bool ggml_is_permuted(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
 }
 
+bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
+    return
+        tensor->nb[0] > tensor->nb[2] &&
+        tensor->nb[1] > tensor->nb[0] &&
+        tensor->nb[2] == ggml_type_size(tensor->type);
+}
+
 static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
```
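`ggml_is_contiguous_channels()` detects channels-contiguous (CWHN) memory layouts; the new `ggml_conv_2d_dw_direct()` below uses it to keep its output in the same memory order as its input. A sketch of a layout that satisfies the predicate, assuming the new predicate is exported through `ggml.h` like its neighbours:

```c
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // allocate with channels innermost (C, W, H, N), then view it as (W, H, C, N):
    // the channel stride stays equal to the element size, i.e. CWHN memory order
    struct ggml_tensor * chwn = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 16, 8, 8, 1);
    struct ggml_tensor * view = ggml_permute(ctx, chwn, 2, 0, 1, 3);

    printf("plain contiguous tensor: %d\n", ggml_is_contiguous_channels(chwn)); // 0
    printf("CWHN view:               %d\n", ggml_is_contiguous_channels(view)); // 1

    ggml_free(ctx);
    return 0;
}
```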
```diff
@@ -2764,11 +2732,11 @@ void ggml_mul_mat_set_prec(
     c = ggml_mul_mat_id(ctx, as, b, ids);
 
     as  -> [cols, rows, n_expert]
-    ids -> [n_experts_used, n_tokens] (i32)
     b   -> [cols, n_expert_used, n_tokens]
+    ids -> [n_expert_used, n_tokens] (i32)
     c   -> [rows, n_expert_used, n_tokens]
 
-    in b,
+    in b, n_expert_used can be broadcasted to match the n_expert_used of ids
 
     c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
 */
@@ -4054,6 +4022,46 @@ struct ggml_tensor * ggml_conv_2d_dw(
     return result;
 }
 
+// ggml_conv_2d_dw_direct
+
+struct ggml_tensor * ggml_conv_2d_dw_direct(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        int stride0,
+        int stride1,
+        int pad0,
+        int pad1,
+        int dilation0,
+        int dilation1) {
+    GGML_ASSERT(a->ne[2] == 1);
+    GGML_ASSERT(a->ne[3] == b->ne[2]);
+    int64_t ne[4];
+    ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], stride0, pad0, dilation0);
+    ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], stride1, pad1, dilation1);
+    ne[2] = b->ne[2];
+    ne[3] = b->ne[3];
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
+
+    if (ggml_is_contiguous_channels(b)) {
+        // Result will be permuted the same way as input (CWHN order)
+        const int64_t type_size = ggml_type_size(result->type);
+        GGML_ASSERT(ggml_blck_size(result->type) == 1);
+        result->nb[0] = result->ne[2] * type_size;
+        result->nb[1] = result->ne[0] * result->nb[0];
+        result->nb[2] = type_size;
+    }
+
+    int32_t params[] = { stride0, stride1, pad0, pad1, dilation0, dilation1 };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op = GGML_OP_CONV_2D_DW;
+    result->src[0] = a;
+    result->src[1] = b;
+    return result;
+}
+
 // ggml_conv_transpose_2d_p0
 
 static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
```
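`ggml_conv_2d_dw_direct()` builds a depthwise-convolution node without an im2col round-trip. A minimal sketch under the shape contract asserted above (kernel `a` is [KW, KH, 1, C], input `b` is [W, H, C, N]):

```c
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 64*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // 3x3 depthwise kernel over a 32x32, 16-channel image
    struct ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,  3,  3,  1, 16);
    struct ggml_tensor * input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 32, 32, 16,  1);

    // stride 1, pad 1, dilation 1 -> output shape { 32, 32, 16, 1 }
    struct ggml_tensor * out = ggml_conv_2d_dw_direct(ctx, kernel, input, 1, 1, 1, 1, 1, 1);

    (void) out;
    ggml_free(ctx);
    return 0;
}
```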
```diff
@@ -4178,7 +4186,8 @@ static struct ggml_tensor * ggml_upscale_impl(
     int ne0,
     int ne1,
     int ne2,
-    int ne3
+    int ne3,
+    enum ggml_scale_mode mode) {
     GGML_ASSERT(a->ne[0] <= ne0);
     GGML_ASSERT(a->ne[1] <= ne1);
     GGML_ASSERT(a->ne[2] <= ne2);
@@ -4186,6 +4195,8 @@ static struct ggml_tensor * ggml_upscale_impl(
 
     struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
 
+    ggml_set_op_params_i32(result, 0, mode);
+
     result->op = GGML_OP_UPSCALE;
     result->src[0] = a;
 
@@ -4195,8 +4206,9 @@ static struct ggml_tensor * ggml_upscale_impl(
 struct ggml_tensor * ggml_upscale(
     struct ggml_context * ctx,
     struct ggml_tensor * a,
-    int scale_factor
-
+    int scale_factor,
+    enum ggml_scale_mode mode) {
+    return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
 }
 
 struct ggml_tensor * ggml_upscale_ext(
@@ -4205,8 +4217,9 @@ struct ggml_tensor * ggml_upscale_ext(
     int ne0,
     int ne1,
     int ne2,
-    int ne3
-
+    int ne3,
+    enum ggml_scale_mode mode) {
+    return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
 }
 
 // ggml_pad
```
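`ggml_upscale()` and `ggml_upscale_ext()` now take an explicit `enum ggml_scale_mode`; `GGML_SCALE_MODE_NEAREST` (the previous hard-coded behaviour) and `GGML_SCALE_MODE_BILINEAR` are the values this llama.cpp revision adds in `ggml.h`. A short sketch:

```c
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 64*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * img = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 16, 16, 3, 1);

    // 2x upscale with the old nearest-neighbour behaviour -> { 32, 32, 3, 1 }
    struct ggml_tensor * up2 = ggml_upscale(ctx, img, 2, GGML_SCALE_MODE_NEAREST);

    // arbitrary target extents with bilinear filtering
    struct ggml_tensor * up  = ggml_upscale_ext(ctx, img, 40, 24, 3, 1, GGML_SCALE_MODE_BILINEAR);

    (void) up2; (void) up;
    ggml_free(ctx);
    return 0;
}
```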
```diff
@@ -4369,7 +4382,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     }
 
     // permute(0, 2, 1, 3)
-    int64_t ne[4] = {
+    int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     float params[] = { scale, max_bias, logit_softcap };
@@ -4836,179 +4849,6 @@ struct ggml_tensor * ggml_unary_inplace(
     return ggml_unary_impl(ctx, a, op, true);
 }
 
-// ggml_map_unary
-
-static struct ggml_tensor * ggml_map_unary_impl_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        const ggml_unary_op_f32_t fun,
-        bool inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op = GGML_OP_MAP_UNARY;
-    result->src[0] = a;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_unary_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        const ggml_unary_op_f32_t fun) {
-    return ggml_map_unary_impl_f32(ctx, a, fun, false);
-}
-
-struct ggml_tensor * ggml_map_unary_inplace_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        const ggml_unary_op_f32_t fun) {
-    return ggml_map_unary_impl_f32(ctx, a, fun, true);
-}
-
-// ggml_map_binary
-
-static struct ggml_tensor * ggml_map_binary_impl_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        const ggml_binary_op_f32_t fun,
-        bool inplace) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op = GGML_OP_MAP_BINARY;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_binary_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        const ggml_binary_op_f32_t fun) {
-    return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
-}
-
-struct ggml_tensor * ggml_map_binary_inplace_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        const ggml_binary_op_f32_t fun) {
-    return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
-}
-
-// ggml_map_custom1_f32
-
-static struct ggml_tensor * ggml_map_custom1_impl_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        const ggml_custom1_op_f32_t fun,
-        bool inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op = GGML_OP_MAP_CUSTOM1_F32;
-    result->src[0] = a;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_custom1_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        const ggml_custom1_op_f32_t fun) {
-    return ggml_map_custom1_impl_f32(ctx, a, fun, false);
-}
-
-struct ggml_tensor * ggml_map_custom1_inplace_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        const ggml_custom1_op_f32_t fun) {
-    return ggml_map_custom1_impl_f32(ctx, a, fun, true);
-}
-
-// ggml_map_custom2_f32
-
-static struct ggml_tensor * ggml_map_custom2_impl_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        const ggml_custom2_op_f32_t fun,
-        bool inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op = GGML_OP_MAP_CUSTOM2_F32;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_custom2_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        const ggml_custom2_op_f32_t fun) {
-    return ggml_map_custom2_impl_f32(ctx, a, b, fun, false);
-}
-
-struct ggml_tensor * ggml_map_custom2_inplace_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        const ggml_custom2_op_f32_t fun) {
-    return ggml_map_custom2_impl_f32(ctx, a, b, fun, true);
-}
-
-// ggml_map_custom3_f32
-
-static struct ggml_tensor * ggml_map_custom3_impl_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        struct ggml_tensor * c,
-        const ggml_custom3_op_f32_t fun,
-        bool inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op = GGML_OP_MAP_CUSTOM3_F32;
-    result->src[0] = a;
-    result->src[1] = b;
-    result->src[2] = c;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_custom3_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        struct ggml_tensor * c,
-        const ggml_custom3_op_f32_t fun) {
-    return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false);
-}
-
-struct ggml_tensor * ggml_map_custom3_inplace_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        struct ggml_tensor * c,
-        const ggml_custom3_op_f32_t fun) {
-    return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true);
-}
-
 // ggml_map_custom1
 
 static struct ggml_tensor * ggml_map_custom1_impl(
@@ -5027,7 +4867,7 @@ static struct ggml_tensor * ggml_map_custom1_impl(
         /*.n_tasks  =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    ggml_set_op_params(result,
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op = GGML_OP_MAP_CUSTOM1;
     result->src[0] = a;
@@ -5072,7 +4912,7 @@ static struct ggml_tensor * ggml_map_custom2_impl(
         /*.n_tasks  =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    ggml_set_op_params(result,
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op = GGML_OP_MAP_CUSTOM2;
     result->src[0] = a;
@@ -5121,7 +4961,7 @@ static struct ggml_tensor * ggml_map_custom3_impl(
         /*.n_tasks  =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    ggml_set_op_params(result,
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op = GGML_OP_MAP_CUSTOM3;
     result->src[0] = a;
@@ -5153,6 +4993,66 @@ struct ggml_tensor * ggml_map_custom3_inplace(
     return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
 }
 
+struct ggml_tensor * ggml_custom_4d(
+        struct ggml_context * ctx,
+        enum ggml_type type,
+        int64_t ne0,
+        int64_t ne1,
+        int64_t ne2,
+        int64_t ne3,
+        struct ggml_tensor ** args,
+        int n_args,
+        ggml_custom_op_t fun,
+        int n_tasks,
+        void * userdata) {
+
+    GGML_ASSERT(n_args < GGML_MAX_SRC);
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
+
+    struct ggml_custom_op_params params = {
+        /*.fun      =*/ fun,
+        /*.n_tasks  =*/ n_tasks,
+        /*.userdata =*/ userdata
+    };
+    ggml_set_op_params(result, &params, sizeof(params));
+
+    result->op = GGML_OP_CUSTOM;
+    for (int i = 0; i < n_args; i++) {
+        result->src[i] = args[i];
+    }
+
+    return result;
+}
+
+struct ggml_tensor * ggml_custom_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor ** args,
+        int n_args,
+        ggml_custom_op_t fun,
+        int n_tasks,
+        void * userdata) {
+
+    GGML_ASSERT(n_args < GGML_MAX_SRC - 1);
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    struct ggml_custom_op_params params = {
+        /*.fun      =*/ fun,
+        /*.n_tasks  =*/ n_tasks,
+        /*.userdata =*/ userdata
+    };
+    ggml_set_op_params(result, &params, sizeof(params));
+
+    result->op = GGML_OP_CUSTOM;
+    result->src[0] = a;
+    for (int i = 0; i < n_args; i++) {
+        result->src[i + 1] = args[i];
+    }
+
+    return result;
+}
 // ggml_cross_entropy_loss
 
 struct ggml_tensor * ggml_cross_entropy_loss(
```
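`ggml_custom_4d()` and `ggml_custom_inplace()` replace the removed per-arity `map_*_f32` helpers with a single `GGML_OP_CUSTOM` op taking an array of source tensors. A sketch, assuming `ggml_custom_op_t` follows the existing `map_custom` callback pattern (`void fn(struct ggml_tensor * dst, int ith, int nth, void * userdata)`, with the inputs reachable through `dst->src[]`):

```c
#include "ggml.h"

// scale every row of src[0] by a factor passed through userdata;
// (ith, nth) partition the rows across threads
static void scale_rows(struct ggml_tensor * dst, int ith, int nth, void * userdata) {
    const struct ggml_tensor * a = dst->src[0];
    const float factor = *(const float *) userdata;
    for (int64_t r = ith; r < ggml_nrows(a); r += nth) {
        const float * x = (const float *) ((const char *) a->data + r*a->nb[1]);
        float * y = (float *) ((char *) dst->data + r*dst->nb[1]);
        for (int64_t i = 0; i < a->ne[0]; ++i) {
            y[i] = factor*x[i];
        }
    }
}

// build the node: one source tensor, output shape copied from the input
static struct ggml_tensor * scale_node(struct ggml_context * ctx, struct ggml_tensor * a, float * factor) {
    struct ggml_tensor * args[1] = { a };
    return ggml_custom_4d(ctx, GGML_TYPE_F32,
                          a->ne[0], a->ne[1], a->ne[2], a->ne[3],
                          args, 1, scale_rows, GGML_N_TASKS_MAX, factor);
}
```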
```diff
@@ -5599,7 +5499,7 @@ static void ggml_compute_backward(
             // tensor = src0 * 1 + src1 * 0
             if (src0_needs_grads) {
                 // dsrc0 = dtensor * 1
-                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad, src0));
             }
             if (src1_needs_grads) {
                 // dsrc1 = dtensor * 0 -> noop
@@ -5880,10 +5780,9 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
 }
 
 void ggml_build_backward_expand(
-    struct ggml_context * ctx_static,
-    struct ggml_context * ctx_compute,
-    struct ggml_cgraph * cgraph,
-    bool accumulate) {
+    struct ggml_context * ctx,
+    struct ggml_cgraph * cgraph,
+    struct ggml_tensor ** grad_accs) {
     GGML_ASSERT(cgraph->n_nodes > 0);
     GGML_ASSERT(cgraph->grads);
     GGML_ASSERT(cgraph->grad_accs);
@@ -5956,21 +5855,24 @@ void ggml_build_backward_expand(
         GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
             node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
 
-        const size_t
-        GGML_ASSERT(
-        GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used,
-        if (
-            cgraph->grad_accs[
-            cgraph->grads[
-
+        const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node);
+        GGML_ASSERT(ihash != GGML_HASHSET_FULL);
+        GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, ihash));
+        if (grad_accs && grad_accs[i]) {
+            cgraph->grad_accs[ihash] = grad_accs[i];
+            cgraph->grads[ihash] = cgraph->grad_accs[ihash];
+        } else if (node->flags & GGML_TENSOR_FLAG_LOSS) {
+            // loss tensors always need a gradient accumulator
+            cgraph->grad_accs[ihash] = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
+            cgraph->grads[ihash] = cgraph->grad_accs[ihash];
         }
-        grads_needed[
+        grads_needed[ihash] = true;
     }
 
     for (int i = n_nodes_f - 1; i >= 0; --i) {
         // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
         // use allocator to automatically make inplace operations
-        ggml_compute_backward(
+        ggml_compute_backward(ctx, cgraph, i, grads_needed);
     }
 
     free(grads_needed);
```
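`ggml_build_backward_expand()` drops the old static/compute context split and the `accumulate` flag: gradient accumulators are now passed in explicitly, or left `NULL` so that only tensors flagged with `GGML_TENSOR_FLAG_LOSS` get one. A sketch of the new call sequence, assuming `weights` and `loss` are tensors already built in `ctx`:

```c
// the graph must be created with gradient storage
struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads=*/ true);

ggml_set_param(weights);   // note: no ggml_context argument anymore
ggml_set_loss(loss);

ggml_build_forward_expand(graph, loss);

// NULL grad_accs: accumulators are allocated automatically for loss tensors
ggml_build_backward_expand(ctx, graph, /*grad_accs=*/ NULL);
```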
```diff
@@ -6116,8 +6018,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
     }
 }
 
-struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
-    struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads
+struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) {
+    struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads);
     ggml_graph_cpy(cgraph, result);
     return result;
 }
@@ -6136,6 +6038,9 @@ struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
 }
 
 void ggml_graph_reset(struct ggml_cgraph * cgraph) {
+    if (!cgraph) {
+        return;
+    }
     GGML_ASSERT(cgraph->grads != NULL);
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -6445,8 +6350,8 @@ void ggml_set_output(struct ggml_tensor * tensor) {
     tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
 }
 
-void ggml_set_param(struct
-
+void ggml_set_param(struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->op == GGML_OP_NONE);
     tensor->flags |= GGML_TENSOR_FLAG_PARAM;
 }
```
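For callers of the vendored ggml, the signatures touched in this file change as summarized below (old forms inferred from the removed lines and upstream llama.cpp; a migration aid, not an exhaustive API reference):

```c
// 0.3.16 (old)                        0.4.0 (new)
// ggml_set_param(ctx, t);          -> ggml_set_param(t);
// ggml_graph_dup(ctx, g);          -> ggml_graph_dup(ctx, g, /*force_grads=*/ false);
// ggml_upscale(ctx, a, 2);         -> ggml_upscale(ctx, a, 2, GGML_SCALE_MODE_NEAREST);
// ggml_graph_reset(NULL);          // now a safe no-op instead of asserting
```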