@fugood/llama.node 0.3.17 → 0.4.0
This diff shows the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- package/CMakeLists.txt +3 -1
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +39 -2
- package/lib/index.js +132 -1
- package/lib/index.ts +203 -3
- package/package.json +2 -1
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +366 -19
- package/src/LlamaCompletionWorker.h +30 -10
- package/src/LlamaContext.cpp +213 -5
- package/src/LlamaContext.h +12 -0
- package/src/common.hpp +15 -0
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
- package/src/llama.cpp/.github/workflows/build.yml +41 -762
- package/src/llama.cpp/.github/workflows/docker.yml +5 -2
- package/src/llama.cpp/.github/workflows/release.yml +716 -0
- package/src/llama.cpp/.github/workflows/server.yml +12 -12
- package/src/llama.cpp/CMakeLists.txt +5 -17
- package/src/llama.cpp/cmake/build-info.cmake +8 -2
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
- package/src/llama.cpp/common/CMakeLists.txt +31 -3
- package/src/llama.cpp/common/arg.cpp +48 -29
- package/src/llama.cpp/common/chat.cpp +128 -106
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.cpp +37 -1
- package/src/llama.cpp/common/common.h +18 -9
- package/src/llama.cpp/common/llguidance.cpp +1 -0
- package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
- package/src/llama.cpp/common/minja/minja.hpp +69 -36
- package/src/llama.cpp/common/regex-partial.cpp +204 -0
- package/src/llama.cpp/common/regex-partial.h +56 -0
- package/src/llama.cpp/common/sampling.cpp +57 -50
- package/src/llama.cpp/examples/CMakeLists.txt +2 -23
- package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
- package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
- package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/training/finetune.cpp +96 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
- package/src/llama.cpp/ggml/include/ggml.h +10 -7
- package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
- package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
- package/src/llama.cpp/ggml/src/ggml.c +29 -20
- package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
- package/src/llama.cpp/include/llama.h +52 -11
- package/src/llama.cpp/requirements/requirements-all.txt +3 -3
- package/src/llama.cpp/scripts/xxd.cmake +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +1 -0
- package/src/llama.cpp/src/llama-adapter.cpp +6 -0
- package/src/llama.cpp/src/llama-arch.cpp +3 -0
- package/src/llama.cpp/src/llama-batch.cpp +5 -1
- package/src/llama.cpp/src/llama-batch.h +2 -1
- package/src/llama.cpp/src/llama-chat.cpp +17 -7
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-context.cpp +389 -501
- package/src/llama.cpp/src/llama-context.h +44 -32
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +20 -38
- package/src/llama.cpp/src/llama-graph.h +12 -8
- package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
- package/src/llama.cpp/src/llama-kv-cache.h +271 -85
- package/src/llama.cpp/src/llama-memory.h +11 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
- package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/src/llama.cpp/src/llama-model-saver.h +37 -0
- package/src/llama.cpp/src/llama-model.cpp +316 -69
- package/src/llama.cpp/src/llama-model.h +8 -1
- package/src/llama.cpp/src/llama-quant.cpp +15 -13
- package/src/llama.cpp/src/llama-sampling.cpp +18 -6
- package/src/llama.cpp/src/llama-vocab.cpp +42 -4
- package/src/llama.cpp/src/llama-vocab.h +6 -0
- package/src/llama.cpp/src/llama.cpp +14 -0
- package/src/llama.cpp/tests/CMakeLists.txt +10 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
- package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
- package/src/llama.cpp/tests/test-chat.cpp +3 -1
- package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
- package/src/llama.cpp/tests/test-opt.cpp +33 -21
- package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
- package/src/llama.cpp/tests/test-sampling.cpp +1 -1
- package/src/llama.cpp/tools/CMakeLists.txt +39 -0
- package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
- package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
- package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
- package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
- package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
- package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
- package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
- package/src/llama.cpp/tools/mtmd/clip.h +99 -0
- package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
- package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
- package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
- package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
- package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
- package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
- package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
- package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
- package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
- package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
- package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
- package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/infill.cpp +0 -590
- package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
- package/src/llama.cpp/examples/llava/clip.h +0 -135
- package/src/llama.cpp/examples/llava/llava.cpp +0 -586
- package/src/llama.cpp/examples/llava/llava.h +0 -49
- package/src/llama.cpp/examples/llava/mtmd.h +0 -168
- package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
- /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0
package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt:

@@ -5,15 +5,35 @@ find_package (Threads REQUIRED)
 
 if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
     add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
+    message(STATUS "Enabling coopmat glslc support")
 endif()
 if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
     add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+    message(STATUS "Enabling coopmat2 glslc support")
 endif()
 if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
     add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+    message(STATUS "Enabling dot glslc support")
 endif()
+if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+    add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+    message(STATUS "Enabling bfloat16 glslc support")
+endif()
+
 set(TARGET vulkan-shaders-gen)
 add_executable(${TARGET} vulkan-shaders-gen.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
 target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
+
+# Configure output directories for MSVC builds
+if(MSVC)
+    # Get the main project's runtime output directory if possible
+    if(DEFINED CMAKE_RUNTIME_OUTPUT_DIRECTORY)
+        foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES})
+            string(TOUPPER ${CONFIG} CONFIG)
+            set_target_properties(${TARGET} PROPERTIES
+                RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
+        endforeach()
+    endif()
+endif()
package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp:

@@ -63,7 +63,8 @@ const std::vector<std::string> type_names = {
     "iq3_xxs",
     "iq3_s",
     "iq4_xs",
-    "iq4_nl"
+    "iq4_nl",
+    "bf16",
 };
 
 namespace {
@@ -214,7 +215,7 @@ static std::mutex compile_count_mutex;
 static std::condition_variable compile_count_cond;
 
 void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
-    std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "
+    std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
     std::string out_fname = join_paths(output_dir, name + ".spv");
     std::string in_path = join_paths(input_dir, in_fname);
 
@@ -296,7 +297,6 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
     std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
 
     std::map<std::string, std::string> base_dict = {
-        {"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"},
         {"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
     };
     std::string shader_name = "matmul";
@@ -318,12 +318,45 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
 
     const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
 
+    auto const &FLOAT_TYPE = [&](const std::string &t) -> std::string {
+        if (t == "bf16") {
+            // scalar path promotes to float
+            if (!coopmat && !coopmat2) {
+                return "float";
+            }
+            return "bfloat16_t";
+        }
+        if (coopmat2 || fp16) {
+            return "float16_t";
+        }
+        return "float";
+    };
+
     // Shaders with f16 B_TYPE
-    string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
-    string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
 
-    string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
-    string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+
+    // bf16
+    {
+        std::string load_vec_a_unaligned = "1";
+        // For aligned matmul loads
+        std::string load_vec_a = coopmat2 ? "1" : "4";
+
+        // scalar path promotes to float
+        std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";
+
+        // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader
+#if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+        if (!(coopmat || coopmat2))
+#endif
+        {
+            string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
+        }
+    }
 
     for (const auto& tname : type_names) {
         std::string load_vec_quant = "2";
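The scalar fallback selected by `TO_FLOAT_TYPE=bf16_to_fp32` in the hunk above relies on bf16 being the upper half of an IEEE-754 binary32 value, so promotion is just a 16-bit shift. A minimal C++ illustration of that conversion (an explanatory sketch, not the GLSL helper shipped with the package):

```cpp
#include <cstdint>
#include <cstring>

// Promote a bfloat16 bit pattern to float: bf16 keeps the sign bit, the 8
// exponent bits and the top 7 mantissa bits of a binary32, so shifting the
// 16-bit pattern into the high half of a 32-bit word reproduces the value.
static float bf16_to_fp32(uint16_t h) {
    uint32_t bits = (uint32_t) h << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}
```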
@@ -332,26 +365,30 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
         else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
             load_vec_quant = "4";
 
+        if (tname == "bf16") {
+            continue;
+        }
+
         std::string data_a_key = "DATA_A_" + to_uppercase(tname);
         // For unaligned, load one at a time for f32/f16, or two at a time for quants
-        std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : load_vec_quant;
+        std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
         // For aligned matmul loads
-        std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : load_vec_quant;
+        std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
 
         // don't generate f32 variants for coopmat2
         if (!coopmat2) {
-            string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
-            string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
         }
 
         if (tname != "f16" && tname != "f32") {
-            string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
-            string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
         }
 
 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
         if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
-            string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
         }
 #endif
     }
@@ -384,16 +421,18 @@ void process_shaders() {
 #endif
     }
 
-#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
     // flash attention
     for (const auto& f16acc : {false, true}) {
         std::string acctype = f16acc ? "float16_t" : "float";
+        std::string acctypev4 = f16acc ? "f16vec4" : "vec4";
 
         for (const auto& tname : type_names) {
             if (tname == "f32") {
                 continue;
             }
+            if (tname == "bf16") continue;
 
+#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
                     merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
@@ -402,9 +441,27 @@ void process_shaders() {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
                     merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
             }
+#endif
+#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
+            if (tname == "f16") {
+                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
+                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc);
+            } else if (tname == "q4_0" || tname == "q8_0") {
+                std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
+                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
+            }
+#endif
+            if (tname == "f16") {
+                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
+                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc);
+            } else if (tname == "q4_0" || tname == "q8_0") {
+                std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
+                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
+            }
         }
     }
-#endif
 
     for (const auto& tname : type_names) {
         // mul mat vec
@@ -417,12 +474,12 @@ void process_shaders() {
         string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
 
         // Dequant shaders
-        if (tname != "f16") {
+        if (tname != "f16" && tname != "bf16") {
             string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
         }
 
         if (!string_ends_with(tname, "_k")) {
-            shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
+            shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";
 
             if (tname == "f16") {
                 string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
@@ -447,9 +504,13 @@ void process_shaders() {
     string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
     string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+    string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+    string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
     string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
     string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+    string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+    string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
 
     for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
         string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
@@ -457,8 +518,26 @@ void process_shaders() {
         string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     }
 
-
-
+    auto get_type_str = [](bool f16) {
+        return f16 ? "float16_t" : "float";
+    };
+    auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
+        std::string s;
+        s += std::string(src0_f16 ? "_f16" : "_f32");
+        s += std::string(src1_f16 ? "_f16" : "_f32");
+        s += std::string(dst_f16 ? "_f16" : "_f32");
+        return s;
+    };
+    for (std::string op : {"add", "sub", "mul", "div"}) {
+        for (auto src0_f16 : {false, true}) {
+            for (auto src1_f16 : {false, true}) {
+                for (auto dst_f16 : {false, true}) {
+                    auto name = op + get_suffix(src0_f16, src1_f16, dst_f16);
+                    string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}});
+                }
+            }
+        }
+    }
 
     string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
@@ -493,14 +572,21 @@ void process_shaders() {
 
     string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
 
-    string_to_spv("
-    string_to_spv("
-    string_to_spv("
-    string_to_spv("
-    string_to_spv("
-    string_to_spv("
-    string_to_spv("
-    string_to_spv("
+    string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("silu_f16", "silu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("relu_f16", "relu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("tanh_f16", "tanh.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+
+    string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
 
     string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 
@@ -544,6 +630,9 @@ void process_shaders() {
 
     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
+    string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
+    string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
+
     for (auto &c : compiles) {
         c.wait();
     }
@@ -598,7 +687,12 @@ void write_output_files() {
             std::remove(path.c_str());
         }
     }
-
+    for (const char *op : {"add", "sub", "mul", "div"}) {
+        fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op);
+        fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op);
+        fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op);
+        fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op);
+    }
     fclose(hdr);
     fclose(src);
 }
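The `[2][2][2]` tables emitted above let the backend pick one of the eight generated add/sub/mul/div shader variants by element type at runtime. A hedged sketch of how such a lookup could be indexed (hypothetical helper, not code from the diff; index order assumed to follow the `src0/src1/dst` suffix order produced by `get_suffix()`, with 0 = f32 and 1 = f16):

```cpp
#include <cstdint>

// Declarations matching what write_output_files() emits into the generated header.
extern unsigned char * add_data[2][2][2];
extern uint64_t        add_len [2][2][2];

struct spv_blob {
    const unsigned char * data;
    uint64_t              len;
};

// Hypothetical selector over the generated SPIR-V blobs; it only links against
// a build where the generated shader sources are compiled in.
static spv_blob select_add_shader(bool src0_f16, bool src1_f16, bool dst_f16) {
    return { add_data[src0_f16][src1_f16][dst_f16],
             add_len [src0_f16][src1_f16][dst_f16] };
}
```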
package/src/llama.cpp/ggml/src/ggml.c:

@@ -1299,6 +1299,10 @@ bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
     return ggml_is_contiguous_n(tensor, 2);
 }
 
+bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor) {
+    return ggml_nbytes(tensor) == ggml_nelements(tensor) * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
+}
+
 bool ggml_is_permuted(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
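The new `ggml_is_contiguously_allocated()` only checks that the tensor's bytes form one dense block; unlike `ggml_is_contiguous()` it does not require row-major element order. A small sketch of the distinction (illustrative usage, assuming a standard CPU-only ggml build from this revision):

```cpp
#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 8);
    struct ggml_tensor * t = ggml_transpose(ctx, a); // permuted view over the same dense block

    // expected: the view is NOT contiguous, but it IS contiguously allocated
    printf("contiguous=%d contiguously_allocated=%d\n",
           ggml_is_contiguous(t), ggml_is_contiguously_allocated(t));

    ggml_free(ctx);
    return 0;
}
```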
@@ -2728,11 +2732,11 @@ void ggml_mul_mat_set_prec(
     c = ggml_mul_mat_id(ctx, as, b, ids);
 
     as  -> [cols, rows, n_expert]
-    ids -> [n_experts_used, n_tokens] (i32)
     b   -> [cols, n_expert_used, n_tokens]
+    ids -> [n_expert_used, n_tokens] (i32)
     c   -> [rows, n_expert_used, n_tokens]
 
-    in b,
+    in b, n_expert_used can be broadcasted to match the n_expert_used of ids
 
     c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
 */
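The updated comment documents `b` before `ids`, uses the `n_expert_used` name consistently, and completes the broadcasting note. A hedged sketch of constructing the operation with those shapes (illustrative only; the sizes are parameters, not values from the diff):

```cpp
#include "ggml.h"

// as  : [cols, rows, n_expert]            (per-expert weight matrices)
// b   : [cols, n_expert_used, n_tokens]   (per-token activations, one slot per selected expert)
// ids : [n_expert_used, n_tokens] (i32)   (which expert each slot uses)
// c   : [rows, n_expert_used, n_tokens]
static struct ggml_tensor * moe_mul_mat(struct ggml_context * ctx,
                                        int64_t cols, int64_t rows, int64_t n_expert,
                                        int64_t n_expert_used, int64_t n_tokens) {
    struct ggml_tensor * as  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, cols, rows, n_expert);
    struct ggml_tensor * b   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, cols, n_expert_used, n_tokens);
    struct ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, n_tokens);
    return ggml_mul_mat_id(ctx, as, b, ids); // c: [rows, n_expert_used, n_tokens]
}
```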
@@ -5495,7 +5499,7 @@ static void ggml_compute_backward(
             // tensor = src0 * 1 + src1 * 0
             if (src0_needs_grads) {
                 // dsrc0 = dtensor * 1
-                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad, src0));
             }
             if (src1_needs_grads) {
                 // dsrc1 = dtensor * 0 -> noop
@@ -5776,10 +5780,9 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
 }
 
 void ggml_build_backward_expand(
-        struct ggml_context *
-        struct
-        struct
-        bool accumulate) {
+        struct ggml_context * ctx,
+        struct ggml_cgraph * cgraph,
+        struct ggml_tensor ** grad_accs) {
     GGML_ASSERT(cgraph->n_nodes > 0);
     GGML_ASSERT(cgraph->grads);
     GGML_ASSERT(cgraph->grad_accs);
@@ -5852,21 +5855,24 @@ void ggml_build_backward_expand(
         GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
             node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
 
-        const size_t
-        GGML_ASSERT(
-        GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used,
-        if (
-            cgraph->grad_accs[
-            cgraph->grads[
-
+        const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node);
+        GGML_ASSERT(ihash != GGML_HASHSET_FULL);
+        GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, ihash));
+        if (grad_accs && grad_accs[i]) {
+            cgraph->grad_accs[ihash] = grad_accs[i];
+            cgraph->grads[ihash] = cgraph->grad_accs[ihash];
+        } else if (node->flags & GGML_TENSOR_FLAG_LOSS) {
+            // loss tensors always need a gradient accumulator
+            cgraph->grad_accs[ihash] = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
+            cgraph->grads[ihash] = cgraph->grad_accs[ihash];
         }
-        grads_needed[
+        grads_needed[ihash] = true;
     }
 
     for (int i = n_nodes_f - 1; i >= 0; --i) {
         // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
         // use allocator to automatically make inplace operations
-        ggml_compute_backward(
+        ggml_compute_backward(ctx, cgraph, i, grads_needed);
     }
 
     free(grads_needed);
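Callers now pass a context, the graph, and an optional `grad_accs` array; with `grad_accs == NULL`, accumulators are created automatically for tensors flagged as loss. A hedged sketch of a caller adapting to the new signatures (hypothetical training-graph helper, assuming this ggml revision; the single-argument `ggml_set_param` appears in the last hunk below):

```cpp
#include "ggml.h"

// Build a toy scalar loss = sum(w * x) and expand its backward graph.
static struct ggml_cgraph * build_train_graph(struct ggml_context * ctx,
                                              struct ggml_tensor  * w,
                                              struct ggml_tensor  * x) {
    struct ggml_tensor * loss = ggml_sum(ctx, ggml_mul(ctx, w, x));

    ggml_set_param(w);    // mark the trainable parameter (no context argument anymore)
    ggml_set_loss(loss);  // loss tensors always get a gradient accumulator

    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
    ggml_build_forward_expand(gf, loss);

    // NULL grad_accs: let ggml allocate accumulators for loss tensors on demand
    ggml_build_backward_expand(ctx, gf, /*grad_accs =*/ NULL);
    return gf;
}
```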
@@ -6012,8 +6018,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
     }
 }
 
-struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
-    struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads
+struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) {
+    struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads);
     ggml_graph_cpy(cgraph, result);
     return result;
 }
@@ -6032,6 +6038,9 @@ struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
 }
 
 void ggml_graph_reset(struct ggml_cgraph * cgraph) {
+    if (!cgraph) {
+        return;
+    }
     GGML_ASSERT(cgraph->grads != NULL);
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -6341,8 +6350,8 @@ void ggml_set_output(struct ggml_tensor * tensor) {
     tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
 }
 
-void ggml_set_param(struct
-
+void ggml_set_param(struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->op == GGML_OP_NONE);
     tensor->flags |= GGML_TENSOR_FLAG_PARAM;
 }
 