@fugood/llama.node 0.3.12 → 0.3.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +2 -1
- package/package.json +1 -1
- package/src/LlamaCompletionWorker.cpp +14 -0
- package/src/LlamaContext.cpp +110 -79
- package/src/LlamaContext.h +1 -1
- package/src/common.hpp +1 -2
- package/src/llama.cpp/.github/workflows/build.yml +95 -13
- package/src/llama.cpp/.github/workflows/docker.yml +2 -0
- package/src/llama.cpp/.github/workflows/labeler.yml +1 -1
- package/src/llama.cpp/.github/workflows/server.yml +2 -0
- package/src/llama.cpp/common/CMakeLists.txt +23 -6
- package/src/llama.cpp/common/arg.cpp +292 -14
- package/src/llama.cpp/common/chat.cpp +1128 -315
- package/src/llama.cpp/common/chat.h +135 -0
- package/src/llama.cpp/common/common.cpp +27 -171
- package/src/llama.cpp/common/common.h +41 -73
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
- package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
- package/src/llama.cpp/common/llguidance.cpp +3 -3
- package/src/llama.cpp/common/log.cpp +1 -0
- package/src/llama.cpp/common/log.h +2 -1
- package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +21 -7
- package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +61 -14
- package/src/llama.cpp/common/ngram-cache.cpp +1 -0
- package/src/llama.cpp/common/sampling.cpp +93 -49
- package/src/llama.cpp/common/speculative.cpp +6 -5
- package/src/llama.cpp/common/speculative.h +1 -1
- package/src/llama.cpp/docs/build.md +47 -9
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +3 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +1 -0
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +4 -4
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +6 -5
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +1 -1
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +373 -107
- package/src/llama.cpp/examples/llava/clip.h +19 -3
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
- package/src/llama.cpp/examples/llava/llava.cpp +4 -2
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -0
- package/src/llama.cpp/examples/main/main.cpp +73 -28
- package/src/llama.cpp/examples/parallel/parallel.cpp +1 -0
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -0
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +1 -0
- package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
- package/src/llama.cpp/examples/run/run.cpp +115 -79
- package/src/llama.cpp/examples/server/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/server/httplib.h +381 -292
- package/src/llama.cpp/examples/server/server.cpp +134 -128
- package/src/llama.cpp/examples/server/utils.hpp +95 -106
- package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
- package/src/llama.cpp/examples/tts/tts.cpp +251 -142
- package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
- package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-cpu.h +4 -1
- package/src/llama.cpp/ggml/include/ggml-metal.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +0 -2
- package/src/llama.cpp/ggml/include/ggml.h +6 -2
- package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
- package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
- package/src/llama.cpp/ggml/src/ggml-common.h +0 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +132 -17
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +156 -11
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +2235 -641
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1572 -198
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +24 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +9 -8
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +16 -3
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +235 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -2
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +246 -120
- package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +51 -10
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +136 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +308 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +174 -728
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -77
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +949 -602
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +37 -3
- package/src/llama.cpp/ggml/src/ggml.c +9 -4
- package/src/llama.cpp/include/llama.h +32 -14
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +1 -0
- package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
- package/src/llama.cpp/requirements.txt +1 -0
- package/src/llama.cpp/src/llama-arch.cpp +21 -0
- package/src/llama.cpp/src/llama-arch.h +1 -0
- package/src/llama.cpp/src/llama-chat.cpp +1 -0
- package/src/llama.cpp/src/llama-grammar.cpp +183 -183
- package/src/llama.cpp/src/llama-grammar.h +13 -4
- package/src/llama.cpp/src/llama-impl.h +6 -6
- package/src/llama.cpp/src/llama-kv-cache.h +2 -1
- package/src/llama.cpp/src/llama-mmap.cpp +11 -1
- package/src/llama.cpp/src/llama-mmap.h +1 -0
- package/src/llama.cpp/src/llama-model.cpp +70 -6
- package/src/llama.cpp/src/llama-sampling.cpp +174 -67
- package/src/llama.cpp/src/llama-vocab.cpp +12 -0
- package/src/llama.cpp/src/llama.cpp +154 -5
- package/src/llama.cpp/src/unicode.cpp +9 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +171 -115
- package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
- package/src/llama.cpp/tests/test-chat.cpp +691 -325
- package/src/llama.cpp/tests/test-gguf.cpp +4 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
- package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
- package/src/llama.cpp/tests/test-sampling.cpp +15 -0
- package/src/llama.cpp/Sources/llama/llama.h +0 -4
- package/src/llama.cpp/common/chat.hpp +0 -52
|
@@ -18,20 +18,22 @@
|
|
|
18
18
|
#include <ggml.h>
|
|
19
19
|
#include <ggml-alloc.h>
|
|
20
20
|
#include <ggml-backend.h>
|
|
21
|
+
#include <ggml-cpp.h>
|
|
21
22
|
|
|
22
23
|
#include <algorithm>
|
|
23
24
|
#include <array>
|
|
24
25
|
#include <cfloat>
|
|
26
|
+
#include <cinttypes>
|
|
25
27
|
#include <cstdint>
|
|
28
|
+
#include <cstdio>
|
|
29
|
+
#include <cstdlib>
|
|
26
30
|
#include <cstring>
|
|
27
|
-
#include <
|
|
31
|
+
#include <future>
|
|
28
32
|
#include <memory>
|
|
29
33
|
#include <random>
|
|
30
|
-
#include <
|
|
31
|
-
#include <stdlib.h>
|
|
34
|
+
#include <regex>
|
|
32
35
|
#include <string>
|
|
33
36
|
#include <thread>
|
|
34
|
-
#include <future>
|
|
35
37
|
#include <vector>
|
|
36
38
|
|
|
37
39
|
static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
|
|
@@ -467,6 +469,7 @@ struct test_case {
|
|
|
467
469
|
|
|
468
470
|
// allocate
|
|
469
471
|
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1);
|
|
472
|
+
|
|
470
473
|
if (buf == NULL) {
|
|
471
474
|
printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1));
|
|
472
475
|
ggml_free(ctx);
|
|
@@ -588,14 +591,13 @@ struct test_case {
|
|
|
588
591
|
/* .mem_base = */ NULL,
|
|
589
592
|
/* .no_alloc = */ true,
|
|
590
593
|
};
|
|
591
|
-
|
|
594
|
+
ggml_context_ptr ctx(ggml_init(params)); // smart ptr
|
|
592
595
|
GGML_ASSERT(ctx);
|
|
593
596
|
|
|
594
|
-
ggml_tensor * out = build_graph(ctx);
|
|
597
|
+
ggml_tensor * out = build_graph(ctx.get());
|
|
595
598
|
|
|
596
599
|
if (op_name != nullptr && op_desc(out) != op_name) {
|
|
597
600
|
//printf(" %s: skipping\n", op_desc(out).c_str());
|
|
598
|
-
ggml_free(ctx);
|
|
599
601
|
return true;
|
|
600
602
|
}
|
|
601
603
|
|
|
@@ -605,7 +607,6 @@ struct test_case {
|
|
|
605
607
|
// check if backends support op
|
|
606
608
|
if (!ggml_backend_supports_op(backend, out)) {
|
|
607
609
|
printf("not supported\n");
|
|
608
|
-
ggml_free(ctx);
|
|
609
610
|
return true;
|
|
610
611
|
}
|
|
611
612
|
|
|
@@ -618,22 +619,26 @@ struct test_case {
|
|
|
618
619
|
printf("%*s", last - len, "");
|
|
619
620
|
|
|
620
621
|
// allocate
|
|
621
|
-
|
|
622
|
+
ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr
|
|
623
|
+
|
|
622
624
|
if (buf == NULL) {
|
|
623
625
|
printf("failed to allocate tensors\n");
|
|
624
|
-
ggml_free(ctx);
|
|
625
626
|
return false;
|
|
626
627
|
}
|
|
627
628
|
|
|
628
629
|
// randomize tensors
|
|
629
|
-
initialize_tensors(ctx);
|
|
630
|
+
initialize_tensors(ctx.get());
|
|
630
631
|
|
|
631
632
|
// build graph
|
|
632
|
-
ggml_cgraph * gf = ggml_new_graph_custom(ctx, graph_nodes, false);
|
|
633
|
+
ggml_cgraph * gf = ggml_new_graph_custom(ctx.get(), graph_nodes, false);
|
|
633
634
|
ggml_build_forward_expand(gf, out);
|
|
634
635
|
|
|
635
636
|
// warmup run
|
|
636
|
-
ggml_backend_graph_compute(backend, gf);
|
|
637
|
+
ggml_status status = ggml_backend_graph_compute(backend, gf);
|
|
638
|
+
if (status != GGML_STATUS_SUCCESS) {
|
|
639
|
+
fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
|
|
640
|
+
return false;
|
|
641
|
+
}
|
|
637
642
|
|
|
638
643
|
// determine number of runs
|
|
639
644
|
int n_runs;
|
|
@@ -684,7 +689,11 @@ struct test_case {
|
|
|
684
689
|
int total_runs = 0;
|
|
685
690
|
do {
|
|
686
691
|
int64_t start_time = ggml_time_us();
|
|
687
|
-
ggml_backend_graph_compute(backend, gf);
|
|
692
|
+
ggml_status status = ggml_backend_graph_compute(backend, gf);
|
|
693
|
+
if (status != GGML_STATUS_SUCCESS) {
|
|
694
|
+
fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
|
|
695
|
+
return false;
|
|
696
|
+
}
|
|
688
697
|
int64_t end_time = ggml_time_us();
|
|
689
698
|
|
|
690
699
|
total_time_us += end_time - start_time;
|
|
@@ -722,10 +731,6 @@ struct test_case {
|
|
|
722
731
|
}
|
|
723
732
|
printf("\n");
|
|
724
733
|
|
|
725
|
-
ggml_backend_buffer_free(buf);
|
|
726
|
-
|
|
727
|
-
ggml_free(ctx);
|
|
728
|
-
|
|
729
734
|
return true;
|
|
730
735
|
}
|
|
731
736
|
|
|
@@ -738,17 +743,16 @@ struct test_case {
|
|
|
738
743
|
/* .mem_base = */ NULL,
|
|
739
744
|
/* .no_alloc = */ true,
|
|
740
745
|
};
|
|
741
|
-
|
|
746
|
+
ggml_context_ptr ctx(ggml_init(params)); // smart ptr
|
|
742
747
|
GGML_ASSERT(ctx);
|
|
743
748
|
|
|
744
|
-
gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, true);
|
|
745
|
-
gb = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, true);
|
|
749
|
+
gf = ggml_new_graph_custom(ctx.get(), GGML_DEFAULT_GRAPH_SIZE, true);
|
|
750
|
+
gb = ggml_new_graph_custom(ctx.get(), GGML_DEFAULT_GRAPH_SIZE, true);
|
|
746
751
|
|
|
747
|
-
ggml_tensor * out = build_graph(ctx);
|
|
752
|
+
ggml_tensor * out = build_graph(ctx.get());
|
|
748
753
|
|
|
749
754
|
if ((op_name != nullptr && op_desc(out) != op_name) || out->op == GGML_OP_OPT_STEP_ADAMW) {
|
|
750
755
|
//printf(" %s: skipping\n", op_desc(out).c_str());
|
|
751
|
-
ggml_free(ctx);
|
|
752
756
|
return true;
|
|
753
757
|
}
|
|
754
758
|
|
|
@@ -756,7 +760,6 @@ struct test_case {
|
|
|
756
760
|
fflush(stdout);
|
|
757
761
|
|
|
758
762
|
if (out->type != GGML_TYPE_F32) {
|
|
759
|
-
ggml_free(ctx);
|
|
760
763
|
printf("not supported [%s->type != FP32]\n", out->name);
|
|
761
764
|
return true;
|
|
762
765
|
}
|
|
@@ -764,7 +767,7 @@ struct test_case {
|
|
|
764
767
|
// check if the backend supports the ops
|
|
765
768
|
bool supported = true;
|
|
766
769
|
bool any_params = false;
|
|
767
|
-
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
|
770
|
+
for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
|
|
768
771
|
if (!ggml_backend_supports_op(backend, t)) {
|
|
769
772
|
printf("not supported [%s] ", ggml_backend_name(backend));
|
|
770
773
|
supported = false;
|
|
@@ -785,40 +788,38 @@ struct test_case {
|
|
|
785
788
|
}
|
|
786
789
|
if (!supported) {
|
|
787
790
|
printf("\n");
|
|
788
|
-
ggml_free(ctx);
|
|
789
791
|
return true;
|
|
790
792
|
}
|
|
791
793
|
|
|
792
794
|
int64_t ngrads = 0;
|
|
793
|
-
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
|
795
|
+
for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
|
|
794
796
|
if (t->flags & GGML_TENSOR_FLAG_PARAM) {
|
|
795
797
|
ngrads += ggml_nelements(t);
|
|
796
798
|
}
|
|
797
799
|
}
|
|
798
800
|
if (ngrads > grad_nmax()) {
|
|
799
801
|
printf("skipping large tensors for speed \n");
|
|
800
|
-
ggml_free(ctx);
|
|
801
802
|
return true;
|
|
802
803
|
}
|
|
803
804
|
|
|
804
805
|
|
|
805
806
|
if (!ggml_is_scalar(out)) {
|
|
806
|
-
out = ggml_sum(ctx, out);
|
|
807
|
+
out = ggml_sum(ctx.get(), out);
|
|
807
808
|
ggml_set_name(out, "sum_of_out");
|
|
808
809
|
}
|
|
809
810
|
ggml_set_loss(out);
|
|
810
811
|
|
|
811
812
|
ggml_build_forward_expand(gf, out);
|
|
812
813
|
ggml_graph_cpy(gf, gb);
|
|
813
|
-
ggml_build_backward_expand(ctx, ctx, gb, false);
|
|
814
|
+
ggml_build_backward_expand(ctx.get(), ctx.get(), gb, false);
|
|
814
815
|
if (expect.size() != 1 || expect[0] != 0.0f) {
|
|
815
816
|
GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf));
|
|
816
|
-
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
|
817
|
+
for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
|
|
817
818
|
GGML_ASSERT(!(t->flags & GGML_TENSOR_FLAG_PARAM) || ggml_graph_get_grad(gb, t)->op != GGML_OP_NONE);
|
|
818
819
|
}
|
|
819
820
|
}
|
|
820
821
|
|
|
821
|
-
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
|
822
|
+
for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
|
|
822
823
|
if (!ggml_backend_supports_op(backend, t)) {
|
|
823
824
|
printf("not supported [%s] ", ggml_backend_name(backend));
|
|
824
825
|
supported = false;
|
|
@@ -832,27 +833,32 @@ struct test_case {
|
|
|
832
833
|
}
|
|
833
834
|
if (!supported) {
|
|
834
835
|
printf("\n");
|
|
835
|
-
ggml_free(ctx);
|
|
836
836
|
return true;
|
|
837
837
|
}
|
|
838
838
|
|
|
839
839
|
// allocate
|
|
840
|
-
|
|
840
|
+
ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr
|
|
841
841
|
if (buf == NULL) {
|
|
842
842
|
printf("failed to allocate tensors [%s] ", ggml_backend_name(backend));
|
|
843
|
-
ggml_free(ctx);
|
|
844
843
|
return false;
|
|
845
844
|
}
|
|
846
845
|
|
|
847
|
-
|
|
848
|
-
initialize_tensors(ctx); // Randomizes all tensors (including gradients).
|
|
846
|
+
initialize_tensors(ctx.get()); // Randomizes all tensors (including gradients).
|
|
849
847
|
ggml_graph_reset(gb); // Sets gradients to 1 if loss, 0 otherwise.
|
|
850
848
|
|
|
851
|
-
ggml_backend_graph_compute(backend, gf);
|
|
852
|
-
|
|
849
|
+
ggml_status status = ggml_backend_graph_compute(backend, gf);
|
|
850
|
+
if (status != GGML_STATUS_SUCCESS) {
|
|
851
|
+
fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
|
|
852
|
+
return false;
|
|
853
|
+
}
|
|
854
|
+
status = ggml_backend_graph_compute(backend, gb);
|
|
855
|
+
if (status != GGML_STATUS_SUCCESS) {
|
|
856
|
+
fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
|
|
857
|
+
return false;
|
|
858
|
+
}
|
|
853
859
|
|
|
854
860
|
bool ok = true;
|
|
855
|
-
for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
|
|
861
|
+
for (struct ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != nullptr; t = ggml_get_next_tensor(ctx.get(), t)) {
|
|
856
862
|
if (!(t->flags & GGML_TENSOR_FLAG_PARAM)) {
|
|
857
863
|
continue;
|
|
858
864
|
}
|
|
@@ -897,20 +903,36 @@ struct test_case {
|
|
|
897
903
|
float fu, fuh, fdh, fd; // output values for xiu, xiuh, xid, xidh
|
|
898
904
|
|
|
899
905
|
ggml_backend_tensor_set(t, &xiu, i*sizeof(float), sizeof(float));
|
|
900
|
-
ggml_backend_graph_compute(backend, gf);
|
|
906
|
+
status = ggml_backend_graph_compute(backend, gf);
|
|
907
|
+
if (status != GGML_STATUS_SUCCESS) {
|
|
908
|
+
fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
|
|
909
|
+
return false;
|
|
910
|
+
}
|
|
901
911
|
ggml_backend_tensor_get(out, &fu, 0, ggml_nbytes(out));
|
|
902
912
|
|
|
903
913
|
ggml_backend_tensor_set(t, &xid, i*sizeof(float), sizeof(float));
|
|
904
|
-
ggml_backend_graph_compute(backend, gf);
|
|
914
|
+
status = ggml_backend_graph_compute(backend, gf);
|
|
915
|
+
if (status != GGML_STATUS_SUCCESS) {
|
|
916
|
+
fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
|
|
917
|
+
return false;
|
|
918
|
+
}
|
|
905
919
|
ggml_backend_tensor_get(out, &fd, 0, ggml_nbytes(out));
|
|
906
920
|
|
|
907
921
|
if (grad_precise()) {
|
|
908
922
|
ggml_backend_tensor_set(t, &xiuh, i*sizeof(float), sizeof(float));
|
|
909
|
-
ggml_backend_graph_compute(backend, gf);
|
|
923
|
+
status = ggml_backend_graph_compute(backend, gf);
|
|
924
|
+
if (status != GGML_STATUS_SUCCESS) {
|
|
925
|
+
fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
|
|
926
|
+
return false;
|
|
927
|
+
}
|
|
910
928
|
ggml_backend_tensor_get(out, &fuh, 0, ggml_nbytes(out));
|
|
911
929
|
|
|
912
930
|
ggml_backend_tensor_set(t, &xidh, i*sizeof(float), sizeof(float));
|
|
913
|
-
ggml_backend_graph_compute(backend, gf);
|
|
931
|
+
status = ggml_backend_graph_compute(backend, gf);
|
|
932
|
+
if (status != GGML_STATUS_SUCCESS) {
|
|
933
|
+
fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
|
|
934
|
+
return false;
|
|
935
|
+
}
|
|
914
936
|
ggml_backend_tensor_get(out, &fdh, 0, ggml_nbytes(out));
|
|
915
937
|
|
|
916
938
|
gn[i] = (8.0*(double)fuh + (double)fd - (8.0*(double)fdh + (double)fu)) / (6.0*(double)eps);
|
|
@@ -936,10 +958,6 @@ struct test_case {
|
|
|
936
958
|
printf("compare failed ");
|
|
937
959
|
}
|
|
938
960
|
|
|
939
|
-
ggml_backend_buffer_free(buf);
|
|
940
|
-
|
|
941
|
-
ggml_free(ctx);
|
|
942
|
-
|
|
943
961
|
if (ok) {
|
|
944
962
|
printf("\033[1;32mOK\033[0m\n");
|
|
945
963
|
return true;
|
|
@@ -1254,7 +1272,7 @@ struct test_count_equal : public test_case {
|
|
|
1254
1272
|
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
|
|
1255
1273
|
ggml_set_name(b, "b");
|
|
1256
1274
|
|
|
1257
|
-
ggml_tensor * b_argmax = ggml_argmax(ctx,
|
|
1275
|
+
ggml_tensor * b_argmax = ggml_argmax(ctx, b);
|
|
1258
1276
|
ggml_set_name(b_argmax, "b_argmax");
|
|
1259
1277
|
|
|
1260
1278
|
ggml_tensor * out = ggml_count_equal(ctx, a_argmax, b_argmax);
|
|
@@ -1511,6 +1529,7 @@ struct test_cont : public test_case {
|
|
|
1511
1529
|
};
|
|
1512
1530
|
|
|
1513
1531
|
// GGML_OP_ADD
|
|
1532
|
+
// GGML_OP_SUB
|
|
1514
1533
|
// GGML_OP_MUL
|
|
1515
1534
|
// GGML_OP_DIV
|
|
1516
1535
|
struct test_bin_bcast : public test_case {
|
|
@@ -3118,6 +3137,7 @@ struct test_leaky_relu : public test_case {
|
|
|
3118
3137
|
struct test_flash_attn_ext : public test_case {
|
|
3119
3138
|
const int64_t hs; // head size
|
|
3120
3139
|
const int64_t nh; // num heads
|
|
3140
|
+
const int64_t nr; // repeat in Q, tests for grouped-query attention
|
|
3121
3141
|
const int64_t kv; // kv size
|
|
3122
3142
|
const int64_t nb; // batch size
|
|
3123
3143
|
|
|
@@ -3130,7 +3150,7 @@ struct test_flash_attn_ext : public test_case {
|
|
|
3130
3150
|
std::array<int32_t, 4> permute;
|
|
3131
3151
|
|
|
3132
3152
|
std::string vars() override {
|
|
3133
|
-
return
|
|
3153
|
+
return VARS_TO_STR10(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
|
|
3134
3154
|
}
|
|
3135
3155
|
|
|
3136
3156
|
double max_nmse_err() override {
|
|
@@ -3141,13 +3161,13 @@ struct test_flash_attn_ext : public test_case {
|
|
|
3141
3161
|
GGML_UNUSED(t);
|
|
3142
3162
|
// Just counting matmul costs:
|
|
3143
3163
|
// Q*K^T is nb x hs x kv, P*V is nb x kv x hs, per head
|
|
3144
|
-
return 2 * 2 * nh * nb * hs * kv;
|
|
3164
|
+
return 2 * 2 * nh*nr * nb * hs * kv;
|
|
3145
3165
|
}
|
|
3146
3166
|
|
|
3147
|
-
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8,
|
|
3167
|
+
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
|
|
3148
3168
|
bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16,
|
|
3149
3169
|
std::array<int32_t, 4> permute = {0, 1, 2, 3})
|
|
3150
|
-
: hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
|
|
3170
|
+
: hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
|
|
3151
3171
|
|
|
3152
3172
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
|
3153
3173
|
const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
|
|
@@ -3165,13 +3185,13 @@ struct test_flash_attn_ext : public test_case {
|
|
|
3165
3185
|
return t;
|
|
3166
3186
|
};
|
|
3167
3187
|
|
|
3168
|
-
ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh, 1);
|
|
3188
|
+
ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh*nr, 1);
|
|
3169
3189
|
ggml_set_name(q, "q");
|
|
3170
3190
|
|
|
3171
|
-
ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh,
|
|
3191
|
+
ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh, 1);
|
|
3172
3192
|
ggml_set_name(k, "k");
|
|
3173
3193
|
|
|
3174
|
-
ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh,
|
|
3194
|
+
ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh, 1);
|
|
3175
3195
|
ggml_set_name(v, "v");
|
|
3176
3196
|
|
|
3177
3197
|
ggml_tensor * m = nullptr;
|
|
@@ -3751,10 +3771,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|
|
3751
3771
|
std::default_random_engine rng(0);
|
|
3752
3772
|
|
|
3753
3773
|
// unary ops
|
|
3754
|
-
for (
|
|
3755
|
-
for (int
|
|
3756
|
-
|
|
3757
|
-
|
|
3774
|
+
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
|
|
3775
|
+
for (int v : {0, 1}) {
|
|
3776
|
+
for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
|
|
3777
|
+
test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 128, 2, 2, 2 }, v));
|
|
3778
|
+
test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 5, 7, 11, 13 }, v));
|
|
3779
|
+
}
|
|
3758
3780
|
}
|
|
3759
3781
|
}
|
|
3760
3782
|
|
|
@@ -3860,7 +3882,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|
|
3860
3882
|
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
|
|
3861
3883
|
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
|
|
3862
3884
|
|
|
3863
|
-
test_cases.emplace_back(new test_count_equal());
|
|
3885
|
+
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
|
|
3886
|
+
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));
|
|
3864
3887
|
|
|
3865
3888
|
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
|
|
3866
3889
|
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
|
|
@@ -3885,8 +3908,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|
|
3885
3908
|
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view));
|
|
3886
3909
|
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view));
|
|
3887
3910
|
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
|
|
3888
|
-
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
|
|
3889
|
-
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I16, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
|
|
3890
3911
|
}
|
|
3891
3912
|
|
|
3892
3913
|
test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
|
|
@@ -3938,41 +3959,42 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|
|
3938
3959
|
test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 3, 5 ,7}));
|
|
3939
3960
|
|
|
3940
3961
|
auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
|
|
3941
|
-
for (auto op : {ggml_add, ggml_mul, ggml_div}) {
|
|
3962
|
+
for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
|
|
3942
3963
|
test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
|
|
3943
3964
|
}
|
|
3944
3965
|
};
|
|
3945
|
-
|
|
3946
|
-
|
|
3947
|
-
|
|
3948
|
-
|
|
3949
|
-
|
|
3950
|
-
|
|
3951
|
-
|
|
3952
|
-
|
|
3953
|
-
|
|
3954
|
-
|
|
3955
|
-
|
|
3956
|
-
|
|
3957
|
-
|
|
3958
|
-
|
|
3959
|
-
|
|
3960
|
-
|
|
3961
|
-
|
|
3962
|
-
|
|
3963
|
-
|
|
3964
|
-
|
|
3965
|
-
|
|
3966
|
-
|
|
3967
|
-
|
|
3968
|
-
|
|
3969
|
-
|
|
3970
|
-
|
|
3971
|
-
|
|
3972
|
-
|
|
3973
|
-
|
|
3974
|
-
|
|
3975
|
-
|
|
3966
|
+
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
|
|
3967
|
+
add_test_bin_bcast(type, {1, 1, 8, 1}, {1, 1, 1, 1});
|
|
3968
|
+
add_test_bin_bcast(type, {1, 1, 1, 1}, {32, 1, 1, 1});
|
|
3969
|
+
add_test_bin_bcast(type, {1, 1, 320, 320}, {1, 1, 1, 1});
|
|
3970
|
+
add_test_bin_bcast(type, {10, 5, 1, 1}, {1, 1, 1, 1});
|
|
3971
|
+
add_test_bin_bcast(type, {10, 5, 4, 1}, {1, 1, 1, 1});
|
|
3972
|
+
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 1});
|
|
3973
|
+
add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 1, 1, 1});
|
|
3974
|
+
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 1, 1});
|
|
3975
|
+
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 1});
|
|
3976
|
+
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 2});
|
|
3977
|
+
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 2});
|
|
3978
|
+
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2});
|
|
3979
|
+
add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2});
|
|
3980
|
+
|
|
3981
|
+
// stable diffusion
|
|
3982
|
+
add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 1, 1, 1});
|
|
3983
|
+
add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 16, 16, 1});
|
|
3984
|
+
add_test_bin_bcast(type, {1280, 16, 16, 1}, {1, 1, 1, 1});
|
|
3985
|
+
add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 256, 1, 1});
|
|
3986
|
+
add_test_bin_bcast(type, {1, 1, 1280, 1}, {16, 16, 1, 1});
|
|
3987
|
+
add_test_bin_bcast(type, {16, 16, 1280, 1}, {1, 1, 1, 1});
|
|
3988
|
+
add_test_bin_bcast(type, {1, 1, 1920, 1}, {16, 16, 1, 1});
|
|
3989
|
+
add_test_bin_bcast(type, {1, 1, 2560, 1}, {16, 16, 1, 1});
|
|
3990
|
+
add_test_bin_bcast(type, {1, 1, 1280, 1}, {32, 32, 1, 1});
|
|
3991
|
+
add_test_bin_bcast(type, {1, 1, 1920, 1}, {32, 32, 1, 1});
|
|
3992
|
+
add_test_bin_bcast(type, {1, 1, 640, 1}, {32, 32, 1, 1});
|
|
3993
|
+
add_test_bin_bcast(type, {5120, 1, 1, 1}, {1, 256, 1, 1});
|
|
3994
|
+
add_test_bin_bcast(type, {640, 1, 1, 1}, {1, 1, 1, 1});
|
|
3995
|
+
//add_test_bin_bcast(type, {3, 3, 2560, 1280}, {1, 1, 1, 1});
|
|
3996
|
+
//add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1});
|
|
3997
|
+
}
|
|
3976
3998
|
|
|
3977
3999
|
test_cases.emplace_back(new test_add1());
|
|
3978
4000
|
test_cases.emplace_back(new test_scale());
|
|
@@ -4091,7 +4113,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|
|
4091
4113
|
for (int n_mats : {4, 8}) {
|
|
4092
4114
|
for (int n_used : {1, 2, 4}) {
|
|
4093
4115
|
for (bool b : {false, true}) {
|
|
4094
|
-
for (int n : {1, 32}) {
|
|
4116
|
+
for (int n : {1, 32, 129}) {
|
|
4095
4117
|
int m = 512;
|
|
4096
4118
|
int k = 256;
|
|
4097
4119
|
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
|
|
@@ -4136,12 +4158,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|
|
4136
4158
|
}
|
|
4137
4159
|
}
|
|
4138
4160
|
|
|
4139
|
-
|
|
4140
|
-
|
|
4141
|
-
|
|
4142
|
-
|
|
4143
|
-
|
|
4144
|
-
|
|
4161
|
+
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
|
|
4162
|
+
test_cases.emplace_back(new test_sqr(type));
|
|
4163
|
+
test_cases.emplace_back(new test_sqrt(type));
|
|
4164
|
+
test_cases.emplace_back(new test_log(type));
|
|
4165
|
+
test_cases.emplace_back(new test_sin(type));
|
|
4166
|
+
test_cases.emplace_back(new test_cos(type));
|
|
4167
|
+
test_cases.emplace_back(new test_clamp(type));
|
|
4168
|
+
}
|
|
4145
4169
|
|
|
4146
4170
|
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
|
|
4147
4171
|
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 3, 1}, 5));
|
|
@@ -4278,14 +4302,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|
|
4278
4302
|
if (!mask && max_bias > 0.0f) continue;
|
|
4279
4303
|
for (float logit_softcap : {0.0f, 10.0f}) {
|
|
4280
4304
|
if (hs != 128 && logit_softcap != 0.0f) continue;
|
|
4281
|
-
for (int nh : {
|
|
4282
|
-
for (int
|
|
4283
|
-
|
|
4284
|
-
|
|
4285
|
-
|
|
4286
|
-
|
|
4287
|
-
|
|
4288
|
-
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV
|
|
4305
|
+
for (int nh : { 4, }) {
|
|
4306
|
+
for (int nr : { 1, 4, 16 }) {
|
|
4307
|
+
if (nr == 16 && hs != 128) continue;
|
|
4308
|
+
for (int kv : { 512, 1024, }) {
|
|
4309
|
+
if (nr != 1 && kv != 512) continue;
|
|
4310
|
+
for (int nb : { 1, 3, 32, 35, }) {
|
|
4311
|
+
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
|
4312
|
+
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV));
|
|
4313
|
+
// run fewer test cases permuted
|
|
4314
|
+
if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
|
|
4315
|
+
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
|
|
4316
|
+
}
|
|
4289
4317
|
}
|
|
4290
4318
|
}
|
|
4291
4319
|
}
|
|
@@ -4360,9 +4388,27 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
|
|
4360
4388
|
return test_cases;
|
|
4361
4389
|
}
|
|
4362
4390
|
|
|
4363
|
-
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
|
|
4391
|
+
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name, const char * params_filter) {
|
|
4392
|
+
auto filter_test_cases = [](std::vector<std::unique_ptr<test_case>> & test_cases, const char * params_filter) {
|
|
4393
|
+
if (params_filter == nullptr) {
|
|
4394
|
+
return;
|
|
4395
|
+
}
|
|
4396
|
+
|
|
4397
|
+
std::regex params_filter_regex(params_filter);
|
|
4398
|
+
|
|
4399
|
+
for (auto it = test_cases.begin(); it != test_cases.end();) {
|
|
4400
|
+
if (!std::regex_search((*it)->vars(), params_filter_regex)) {
|
|
4401
|
+
it = test_cases.erase(it);
|
|
4402
|
+
continue;
|
|
4403
|
+
}
|
|
4404
|
+
|
|
4405
|
+
it++;
|
|
4406
|
+
}
|
|
4407
|
+
};
|
|
4408
|
+
|
|
4364
4409
|
if (mode == MODE_TEST) {
|
|
4365
4410
|
auto test_cases = make_test_cases_eval();
|
|
4411
|
+
filter_test_cases(test_cases, params_filter);
|
|
4366
4412
|
ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL);
|
|
4367
4413
|
if (backend_cpu == NULL) {
|
|
4368
4414
|
printf(" Failed to initialize CPU backend\n");
|
|
@@ -4384,6 +4430,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|
|
4384
4430
|
|
|
4385
4431
|
if (mode == MODE_GRAD) {
|
|
4386
4432
|
auto test_cases = make_test_cases_eval();
|
|
4433
|
+
filter_test_cases(test_cases, params_filter);
|
|
4387
4434
|
size_t n_ok = 0;
|
|
4388
4435
|
for (auto & test : test_cases) {
|
|
4389
4436
|
if (test->eval_grad(backend, op_name)) {
|
|
@@ -4397,6 +4444,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|
|
4397
4444
|
|
|
4398
4445
|
if (mode == MODE_PERF) {
|
|
4399
4446
|
auto test_cases = make_test_cases_perf();
|
|
4447
|
+
filter_test_cases(test_cases, params_filter);
|
|
4400
4448
|
for (auto & test : test_cases) {
|
|
4401
4449
|
test->eval_perf(backend, op_name);
|
|
4402
4450
|
}
|
|
@@ -4407,7 +4455,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|
|
4407
4455
|
}
|
|
4408
4456
|
|
|
4409
4457
|
static void usage(char ** argv) {
|
|
4410
|
-
printf("Usage: %s [mode] [-o op] [-b backend]\n", argv[0]);
|
|
4458
|
+
printf("Usage: %s [mode] [-o <op>] [-b <backend>] [-p <params regex>]\n", argv[0]);
|
|
4411
4459
|
printf(" valid modes:\n");
|
|
4412
4460
|
printf(" - test (default, compare with CPU backend for correctness)\n");
|
|
4413
4461
|
printf(" - grad (compare gradients from backpropagation with method of finite differences)\n");
|
|
@@ -4417,8 +4465,9 @@ static void usage(char ** argv) {
|
|
|
4417
4465
|
|
|
4418
4466
|
int main(int argc, char ** argv) {
|
|
4419
4467
|
test_mode mode = MODE_TEST;
|
|
4420
|
-
const char * op_name_filter =
|
|
4421
|
-
const char * backend_filter =
|
|
4468
|
+
const char * op_name_filter = nullptr;
|
|
4469
|
+
const char * backend_filter = nullptr;
|
|
4470
|
+
const char * params_filter = nullptr;
|
|
4422
4471
|
|
|
4423
4472
|
for (int i = 1; i < argc; i++) {
|
|
4424
4473
|
if (strcmp(argv[i], "test") == 0) {
|
|
@@ -4441,6 +4490,13 @@ int main(int argc, char ** argv) {
|
|
|
4441
4490
|
usage(argv);
|
|
4442
4491
|
return 1;
|
|
4443
4492
|
}
|
|
4493
|
+
} else if (strcmp(argv[i], "-p") == 0) {
|
|
4494
|
+
if (i + 1 < argc) {
|
|
4495
|
+
params_filter = argv[++i];
|
|
4496
|
+
} else {
|
|
4497
|
+
usage(argv);
|
|
4498
|
+
return 1;
|
|
4499
|
+
}
|
|
4444
4500
|
} else {
|
|
4445
4501
|
usage(argv);
|
|
4446
4502
|
return 1;
|
|
@@ -4487,7 +4543,7 @@ int main(int argc, char ** argv) {
|
|
|
4487
4543
|
printf(" Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024);
|
|
4488
4544
|
printf("\n");
|
|
4489
4545
|
|
|
4490
|
-
bool ok = test_backend(backend, mode, op_name_filter);
|
|
4546
|
+
bool ok = test_backend(backend, mode, op_name_filter, params_filter);
|
|
4491
4547
|
|
|
4492
4548
|
printf(" Backend %s: ", ggml_backend_name(backend));
|
|
4493
4549
|
if (ok) {
|