@fugood/llama.node 0.3.12 → 0.3.14
This diff shows the published contents of the two package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +2 -1
- package/package.json +1 -1
- package/src/LlamaCompletionWorker.cpp +14 -0
- package/src/LlamaContext.cpp +110 -79
- package/src/LlamaContext.h +1 -1
- package/src/common.hpp +1 -2
- package/src/llama.cpp/.github/workflows/build.yml +95 -13
- package/src/llama.cpp/.github/workflows/docker.yml +2 -0
- package/src/llama.cpp/.github/workflows/labeler.yml +1 -1
- package/src/llama.cpp/.github/workflows/server.yml +2 -0
- package/src/llama.cpp/common/CMakeLists.txt +23 -6
- package/src/llama.cpp/common/arg.cpp +292 -14
- package/src/llama.cpp/common/chat.cpp +1128 -315
- package/src/llama.cpp/common/chat.h +135 -0
- package/src/llama.cpp/common/common.cpp +27 -171
- package/src/llama.cpp/common/common.h +41 -73
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
- package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
- package/src/llama.cpp/common/llguidance.cpp +3 -3
- package/src/llama.cpp/common/log.cpp +1 -0
- package/src/llama.cpp/common/log.h +2 -1
- package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +21 -7
- package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +61 -14
- package/src/llama.cpp/common/ngram-cache.cpp +1 -0
- package/src/llama.cpp/common/sampling.cpp +93 -49
- package/src/llama.cpp/common/speculative.cpp +6 -5
- package/src/llama.cpp/common/speculative.h +1 -1
- package/src/llama.cpp/docs/build.md +47 -9
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +3 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +1 -0
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +4 -4
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +6 -5
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +1 -1
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +373 -107
- package/src/llama.cpp/examples/llava/clip.h +19 -3
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
- package/src/llama.cpp/examples/llava/llava.cpp +4 -2
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -0
- package/src/llama.cpp/examples/main/main.cpp +73 -28
- package/src/llama.cpp/examples/parallel/parallel.cpp +1 -0
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -0
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +1 -0
- package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
- package/src/llama.cpp/examples/run/run.cpp +115 -79
- package/src/llama.cpp/examples/server/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/server/httplib.h +381 -292
- package/src/llama.cpp/examples/server/server.cpp +134 -128
- package/src/llama.cpp/examples/server/utils.hpp +95 -106
- package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
- package/src/llama.cpp/examples/tts/tts.cpp +251 -142
- package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
- package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-cpu.h +4 -1
- package/src/llama.cpp/ggml/include/ggml-metal.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +0 -2
- package/src/llama.cpp/ggml/include/ggml.h +6 -2
- package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
- package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
- package/src/llama.cpp/ggml/src/ggml-common.h +0 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +132 -17
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +156 -11
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +2235 -641
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1572 -198
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +24 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +9 -8
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +16 -3
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +235 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -2
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +246 -120
- package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +51 -10
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +136 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +308 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +174 -728
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -77
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +949 -602
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +37 -3
- package/src/llama.cpp/ggml/src/ggml.c +9 -4
- package/src/llama.cpp/include/llama.h +32 -14
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +1 -0
- package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
- package/src/llama.cpp/requirements.txt +1 -0
- package/src/llama.cpp/src/llama-arch.cpp +21 -0
- package/src/llama.cpp/src/llama-arch.h +1 -0
- package/src/llama.cpp/src/llama-chat.cpp +1 -0
- package/src/llama.cpp/src/llama-grammar.cpp +183 -183
- package/src/llama.cpp/src/llama-grammar.h +13 -4
- package/src/llama.cpp/src/llama-impl.h +6 -6
- package/src/llama.cpp/src/llama-kv-cache.h +2 -1
- package/src/llama.cpp/src/llama-mmap.cpp +11 -1
- package/src/llama.cpp/src/llama-mmap.h +1 -0
- package/src/llama.cpp/src/llama-model.cpp +70 -6
- package/src/llama.cpp/src/llama-sampling.cpp +174 -67
- package/src/llama.cpp/src/llama-vocab.cpp +12 -0
- package/src/llama.cpp/src/llama.cpp +154 -5
- package/src/llama.cpp/src/unicode.cpp +9 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +171 -115
- package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
- package/src/llama.cpp/tests/test-chat.cpp +691 -325
- package/src/llama.cpp/tests/test-gguf.cpp +4 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
- package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
- package/src/llama.cpp/tests/test-sampling.cpp +15 -0
- package/src/llama.cpp/Sources/llama/llama.h +0 -4
- package/src/llama.cpp/common/chat.hpp +0 -52
--- /dev/null
+++ package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
@@ -0,0 +1,288 @@
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
+// SPDX-License-Identifier: MIT
+//
+#include <arm_neon.h>
+#include <assert.h>
+#include <cfloat>
+#include <stdint.h>
+#include <string.h>
+#if defined(__linux__)
+#include <asm/hwcap.h>
+#include <sys/auxv.h>
+#elif defined(__APPLE__)
+#include <string_view>
+#include <sys/sysctl.h>
+#include <sys/types.h>
+#elif defined(_WIN32)
+#include <windows.h>
+#include <excpt.h>
+#endif
+
+#include "kleidiai.h"
+
+#include "ggml-cpu.h"
+#include "ggml-impl.h"
+#include "ggml-backend-impl.h"
+#include "ggml-threading.h"
+#include "ggml-cpu-traits.h"
+
+#include "kernels.h"
+
+#include "kai_common.h"
+
+#define GGML_COMMON_DECL_CPP
+#include "ggml-common.h"
+
+struct ggml_kleidiai_context {
+    ggml_kleidiai_kernels * kernels;
+} static ctx = { NULL };
+
+static void init_kleidiai_context(void) {
+
+    ggml_critical_section_start();
+    static bool initialized = false;
+
+    if (!initialized) {
+        initialized = true;
+        const char *env_var = getenv("GGML_KLEIDIAI_SME");
+        int sme_enabled = 0;
+
+        cpu_feature features = (ggml_cpu_has_dotprod()     ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
+                               (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM    : CPU_FEATURE_NONE) |
+                               (ggml_cpu_has_sve()         ? CPU_FEATURE_SVE     : CPU_FEATURE_NONE);
+
+        if (env_var) {
+            sme_enabled = atoi(env_var);
+        }
+
+        if (sme_enabled != 0) {
+            features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
+        }
+        ctx.kernels = ggml_kleidiai_select_kernels(features);
+    }
+    ggml_critical_section_end();
+}
+
+static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
+    GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
+    return tensor->ne[dim];
+}
+
+namespace ggml::cpu::kleidiai {
+class tensor_traits : public ggml::cpu::tensor_traits {
+    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
+        GGML_ASSERT(ctx.kernels);
+        kernel_info * kernel = op->src[1]->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
+
+        size_t k = op->src[0]->ne[0];
+        size_t m = op->src[1]->ne[1];
+
+        size_t mr = kernel->get_mr();
+        size_t kr = kernel->get_kr();
+        size_t sr = kernel->get_sr();
+
+        size = ctx.kernels->lhs_info.packed_size(m, k, QK4_0, mr, kr, sr);
+
+        return true;
+    }
+
+    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
+        if (dst->op == GGML_OP_MUL_MAT) {
+            const ggml_tensor * src0 = dst->src[0];
+            const ggml_tensor * src1 = dst->src[1];
+
+            GGML_TENSOR_BINARY_OP_LOCALS
+
+            GGML_ASSERT(ctx.kernels);
+            kernel_info * kernel = src1->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
+            lhs_packing_info * lhs_info = &ctx.kernels->lhs_info;
+
+            GGML_ASSERT(kernel);
+
+            const int ith = params->ith;
+            const int nth = params->nth;
+
+            const size_t k = ne00;
+            const size_t m = ne11;
+            const size_t n = ne01;
+
+            const size_t n_step = kernel->get_n_step();
+            const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
+            const size_t n_start = ith * num_n_per_thread;
+
+            size_t n_to_process = num_n_per_thread;
+            if ((n_start + n_to_process) > n) {
+                n_to_process = n - n_start;
+            }
+
+            const uint8_t * lhs        = static_cast<const uint8_t *>(src1->data);
+            uint8_t * lhs_packed       = (uint8_t*)params->wdata;
+            const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
+
+            size_t mr = kernel->get_mr();
+            size_t kr = kernel->get_kr();
+            size_t sr = kernel->get_sr();
+
+            // Calculate number of columns to be processed per thread
+            const bool use_multithread = lhs_info->require_aligned_m_idx && m <= mr ? false : true;
+            const size_t num_m_per_thread = use_multithread ? kai_roundup(m, nth) / nth : m;
+            const size_t m_start = ith * num_m_per_thread;
+            size_t m_to_process = num_m_per_thread;
+            if ((m_start + m_to_process) > m) {
+                m_to_process = m - m_start;
+            }
+
+            if(m_start < m) {
+                // Transform LHS
+                const size_t src_stride = src1->nb[1];
+                const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(0, dst->src[1]->nb[1]));
+                const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr);
+                void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
+
+                lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, m_start, src_ptr, src_stride, lhs_packed_ptr);
+            }
+
+            ggml_barrier(params->threadpool);
+
+            // Perform the operation
+            const size_t dst_stride = dst->nb[1];
+            const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr);
+            const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0);
+            const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
+            const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
+            const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
+            float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
+
+            kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr,
+                               dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
+            return true;
+        }
+        return false;
+    }
+
+public:
+    int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
+        GGML_ASSERT(ctx.kernels);
+        const size_t n = tensor->ne[1];
+        const size_t k = tensor->ne[0];
+        size_t nr = ctx.kernels->gemm.get_nr();
+        size_t kr = ctx.kernels->gemm.get_kr();
+        size_t sr = ctx.kernels->gemm.get_sr();
+
+#ifndef NDEBUG
+        const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0);
+        GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
+#endif
+        struct kai_rhs_pack_qs4cxs1s0_param params;
+        params.lhs_zero_point = 1;
+        params.rhs_zero_point = 8;
+        ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, &params);
+
+        return 0;
+
+        GGML_UNUSED(data_size);
+    }
+};
+
+static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {
+    static tensor_traits traits;
+    return &traits;
+}
+}  // namespace ggml::cpu::kleidiai
+
+GGML_API enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
+
+    GGML_UNUSED(buffer);
+    return GGML_STATUS_SUCCESS;
+}
+
+static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
+                                                        const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset == 0);
+    GGML_ASSERT(size == ggml_nbytes(tensor));
+
+    auto tensor_traits = (ggml::cpu::kleidiai::tensor_traits *) tensor->extra;
+    auto OK = tensor_traits->repack(tensor, data, size);
+
+    GGML_ASSERT(OK == 0);
+    GGML_UNUSED(buffer);
+}
+
+static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    return "CPU_KLEIDIAI";
+
+    GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
+
+    if (buffer == nullptr) {
+        return nullptr;
+    }
+
+    buffer->buft              = buft;
+    buffer->iface.init_tensor = ggml_backend_cpu_kleidiai_buffer_init_tensor;
+    buffer->iface.set_tensor  = ggml_backend_cpu_kleidiai_buffer_set_tensor;
+    buffer->iface.get_tensor  = nullptr;
+    buffer->iface.cpy_tensor  = nullptr;
+    return buffer;
+}
+
+static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    return TENSOR_ALIGNMENT;
+
+    GGML_UNUSED(buft);
+}
+
+namespace ggml::cpu::kleidiai {
+class extra_buffer_type : ggml::cpu::extra_buffer_type {
+    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
+        if ( op->op == GGML_OP_MUL_MAT &&
+             op->src[0]->type == GGML_TYPE_Q4_0 &&
+             op->src[0]->buffer &&
+             (ggml_n_dims(op->src[0]) == 2) &&
+             op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels
+             ) {
+            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+                return false;
+            }
+            if (op->src[1]->type == GGML_TYPE_F32 &&
+                ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
+        if (op->op == GGML_OP_MUL_MAT) {
+            if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
+                return (ggml::cpu::tensor_traits *) op->src[0]->extra;
+            }
+        }
+        return nullptr;
+    }
+};
+}  // namespace ggml::cpu::kleidiai
+
+ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
+    static ggml::cpu::kleidiai::extra_buffer_type ctx;
+    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_kleidiai = {
+        /* .iface    = */ {
+            /* .get_name         = */ ggml_backend_cpu_kleidiai_buffer_type_get_name,
+            /* .alloc_buffer     = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
+            /* .get_max_size     = */ nullptr, // defaults to SIZE_MAX
+            /* .get_alloc_size   = */ nullptr, // defaults to ggml_nbytes
+            /* .is_host          = */ nullptr,
+        },
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+        /* .context = */ &ctx,
+    };
+
+    init_kleidiai_context();
+
+    return &ggml_backend_cpu_buffer_type_kleidiai;
+}
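A minimal consumer-side sketch of the new buffer type (illustrative only, not part of this diff; it assumes the standard ggml/ggml-backend headers, a hypothetical 4096x4096 weight, and that weights_size equals ggml_nbytes(w)). Allocating a 2-D Q4_0 weight in this buffer type makes init_tensor install the KleidiAI tensor traits and set_tensor repack the data, after which MUL_MAT on that tensor dispatches to the gemv/gemm kernels selected in init_kleidiai_context():

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "kleidiai.h" // declares ggml_backend_cpu_kleidiai_buffer_type() (see next hunk)

void load_q4_0_weight(const void * q4_0_data, size_t weights_size) {
    struct ggml_init_params ip = {
        /* .mem_size   = */ ggml_tensor_overhead() * 8,
        /* .mem_buffer = */ NULL,
        /* .no_alloc   = */ true, // tensor data lives in the backend buffer below
    };
    struct ggml_context * ggctx = ggml_init(ip);

    // 2-D Q4_0 weight, matching the supports_op() conditions above
    struct ggml_tensor * w = ggml_new_tensor_2d(ggctx, GGML_TYPE_Q4_0, 4096, 4096);

    // Allocate in the KleidiAI buffer type: init_tensor installs the traits...
    ggml_backend_buffer_type_t buft = ggml_backend_cpu_kleidiai_buffer_type();
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ggctx, buft);

    // ...and set_tensor runs repack() into the packed RHS layout.
    ggml_backend_tensor_set(w, q4_0_data, 0, weights_size);

    /* build a graph containing ggml_mul_mat(ggctx, w, x) and compute as usual */

    ggml_backend_buffer_free(buf);
    ggml_free(ggctx);
}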
--- /dev/null
+++ package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h
@@ -0,0 +1,17 @@
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "ggml-alloc.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void);
+
+#ifdef __cplusplus
+}
+#endif
--- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp
+++ package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp
@@ -280,14 +280,6 @@ template <> inline __m256bh load(const float *p) {
 }
 #endif
 
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// CONSTANTS
-
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
-static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
-#endif
-
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // FLOATING POINT MATRIX MULTIPLICATION
 
@@ -614,6 +606,14 @@ class tinyBLAS_Q0_AVX {
           TC *C, int64_t ldc,
           int ith, int nth)
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+        const int8_t kvalues_iq4nl[16] = {
+            -127, -104, -83, -65,
+            -49, -35, -22, -10,
+            1, 13, 25, 38,
+            53, 69, 89, 113
+        };
+
+        iq4nlt = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
     }
 
     void matmul(int64_t m, int64_t n) {
@@ -1038,6 +1038,7 @@ class tinyBLAS_Q0_AVX {
     const int64_t ldc;
     const int ith;
    const int nth;
+    __m128i iq4nlt;
 };
 #endif // __AVX__
 
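The three sgemm.cpp hunks above relocate the IQ4_NL lookup vector: previously a namespace-scope static const __m128i whose initializer executes an SSE load during static initialization, now a tinyBLAS_Q0_AVX member loaded in the constructor. The diff does not state the rationale; a plausible reading, sketched below under that assumption (tinyblas_q0_avx_like is an invented stand-in for the real class), is that the dynamic initialization at namespace scope runs at program startup regardless of whether the quantized GEMM path is ever taken:

#include <immintrin.h>
#include <cstdint>

static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10,
                                         1, 13, 25, 38, 53, 69, 89, 113};

// Before: runs at program startup in any binary linking the TU.
// static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);

// After: loaded per GEMM object, only on code paths that construct one.
struct tinyblas_q0_avx_like {
    __m128i iq4nlt;
    tinyblas_q0_avx_like() {
        iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
    }
};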
--- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt
+++ package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt
@@ -7,7 +7,7 @@ if (CUDAToolkit_FOUND)
 
     if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
         # native == GPUs available at build time
-        #
+        # 50 == Maxwell, lowest CUDA 12 standard
         # 60 == P100, FP16 CUDA intrinsics
         # 61 == Pascal, __dp4a instruction (per-byte integer dot product)
         # 70 == V100, FP16 tensor cores
@@ -15,9 +15,9 @@ if (CUDAToolkit_FOUND)
         if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
             set(CMAKE_CUDA_ARCHITECTURES "native")
         elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
-            set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75")
+            set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;80")
         else()
-            set(CMAKE_CUDA_ARCHITECTURES "
+            set(CMAKE_CUDA_ARCHITECTURES "50;61;70;75;80")
         endif()
     endif()
     message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
@@ -69,6 +69,10 @@ if (CUDAToolkit_FOUND)
         add_compile_definitions(GGML_CUDA_NO_VMM)
     endif()
 
+    if (NOT GGML_CUDA_FA)
+        add_compile_definitions(GGML_CUDA_NO_FA)
+    endif()
+
     if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
         add_compile_definitions(GGML_CUDA_F16)
     endif()
@@ -98,6 +102,15 @@ if (CUDAToolkit_FOUND)
 
     set(CUDA_FLAGS -use_fast_math)
 
+    if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
+        # Options are:
+        # - none (not recommended)
+        # - speed (nvcc's default)
+        # - balance
+        # - size
+        list(APPEND CUDA_FLAGS -compress-mode=${GGML_CUDA_COMPRESSION_MODE})
+    endif()
+
     if (GGML_FATAL_WARNINGS)
        list(APPEND CUDA_FLAGS -Werror all-warnings)
     endif()
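The two new switches above are configure-time knobs: GGML_CUDA_FA gates the FlashAttention kernels via the GGML_CUDA_NO_FA definition, and GGML_CUDA_COMPRESSION_MODE feeds nvcc's -compress-mode on CUDA 12.8+ (both options are presumably declared in the top-level ggml CMakeLists, which is not part of this diff). A hypothetical consumer of the new definition, shown only to illustrate the pattern; the real guard sites live in the CUDA backend sources, not in this diff:

// Sketch: how a compile definition such as GGML_CUDA_NO_FA is typically consumed.
static bool flash_attn_compiled_in(void) {
#ifdef GGML_CUDA_NO_FA
    return false; // built with -DGGML_CUDA_FA=OFF
#else
    return true;  // FlashAttention kernels were compiled in
#endif
}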
--- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt
+++ package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt
@@ -39,6 +39,12 @@ endif()
 find_package(hip     REQUIRED)
 find_package(hipblas REQUIRED)
 find_package(rocblas REQUIRED)
+if (GGML_HIP_ROCWMMA_FATTN)
+    CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA)
+    if (NOT ${FOUND_ROCWMMA})
+        message(FATAL_ERROR "rocwmma has not been found")
+    endif()
+endif()
 
 if (${hip_VERSION} VERSION_LESS 5.5)
     message(FATAL_ERROR "At least ROCM/HIP V5.5 is required")
@@ -107,6 +113,14 @@ if (GGML_HIP_NO_VMM)
     add_compile_definitions(GGML_HIP_NO_VMM)
 endif()
 
+if (GGML_HIP_ROCWMMA_FATTN)
+    add_compile_definitions(GGML_HIP_ROCWMMA_FATTN)
+endif()
+
+if (NOT GGML_CUDA_FA)
+    add_compile_definitions(GGML_CUDA_NO_FA)
+endif()
+
 if (CXX_IS_HIPCC)
     set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
     target_link_libraries(ggml-hip PRIVATE hip::device)
--- package/src/llama.cpp/ggml/src/ggml-impl.h
+++ package/src/llama.cpp/ggml/src/ggml-impl.h
@@ -16,7 +16,7 @@
 #include <arm_sve.h>
 #endif // __ARM_FEATURE_SVE
 
-#if defined(__ARM_NEON) && !defined(__CUDACC__)
+#if defined(__ARM_NEON) && !defined(__CUDACC__) && !defined(__MUSACC__)
 // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
 //
 // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
--- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt
+++ package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt
@@ -27,12 +27,12 @@ configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
 configure_file(ggml-metal.metal  ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal  COPYONLY)
 configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)
 
+set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
 if (GGML_METAL_EMBED_LIBRARY)
     enable_language(ASM)
 
     add_compile_definitions(GGML_METAL_EMBED_LIBRARY)
 
-    set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
     set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
     set(METALLIB_IMPL   "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
 
@@ -88,12 +88,11 @@ else()
 
     add_custom_command(
         OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
-        COMMAND xcrun -sdk macosx metal
-
-        COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air
+        COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o - |
+                xcrun -sdk macosx metallib - -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
         COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h
         COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal
-        DEPENDS ggml-metal.metal
+        DEPENDS ggml-metal.metal ${METALLIB_COMMON}
         COMMENT "Compiling Metal kernels"
     )
--- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h
+++ package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -285,4 +285,239 @@ typedef struct {
     float eps;
 } ggml_metal_kargs_rms_norm;
 
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    int32_t  n_groups;
+    float    eps;
+} ggml_metal_kargs_group_norm;
+
+typedef struct {
+    int32_t  IC;
+    int32_t  IL;
+    int32_t  K;
+    int32_t  s0;
+    uint64_t nb0;
+    uint64_t nb1;
+} ggml_metal_kargs_conv_transpose_1d;
+
+typedef struct {
+    uint64_t ofs0;
+    uint64_t ofs1;
+    int32_t  IW;
+    int32_t  IH;
+    int32_t  CHW;
+    int32_t  s0;
+    int32_t  s1;
+    int32_t  p0;
+    int32_t  p1;
+    int32_t  d0;
+    int32_t  d1;
+    int32_t  N;
+    int32_t  KH;
+    int32_t  KW;
+    int32_t  KHW; // KH * KW, pre-computed on CPU to save GPU resources
+} ggml_metal_kargs_im2col;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne10;
+    int64_t  ne11;
+    int64_t  ne12;
+    int64_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_sum_rows;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    float    scale;
+    float    max_bias;
+    float    m0;
+    float    m1;
+    uint32_t n_head_log2;
+} ggml_metal_kargs_soft_max;
+
+typedef struct {
+    int64_t ne00;
+    int64_t ne01;
+    int     n_past;
+} ggml_metal_kargs_diag_mask_inf;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    int64_t  ne10;
+    int64_t  ne11;
+    uint64_t nb10;
+    uint64_t nb11;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+} ggml_metal_kargs_ssm_conv;
+
+typedef struct {
+    int64_t  d_state;
+    int64_t  d_inner;
+    int64_t  n_seq_tokens;
+    int64_t  n_seqs;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb20;
+    uint64_t nb21;
+    uint64_t nb22;
+    uint64_t nb30;
+    uint64_t nb31;
+    uint64_t nb40;
+    uint64_t nb41;
+    uint64_t nb42;
+    uint64_t nb50;
+    uint64_t nb51;
+    uint64_t nb52;
+} ggml_metal_kargs_ssm_scan;
+
+typedef struct {
+    int64_t  ne00;
+    uint64_t nb01;
+    uint64_t nb02;
+    int64_t  ne10;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb1;
+    uint64_t nb2;
+} ggml_metal_kargs_get_rows;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    float    sf0;
+    float    sf1;
+    float    sf2;
+    float    sf3;
+} ggml_metal_kargs_upscale;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_pad;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    int32_t  p0;
+    int32_t  p1;
+} ggml_metal_kargs_pad_reflect_1d;
+
+typedef struct {
+    uint64_t nb1;
+    int      dim;
+    int      max_period;
+} ggml_metal_kargs_timestep_embedding;
+
+typedef struct {
+    float slope;
+} ggml_metal_kargs_leaky_relu;
+
+typedef struct {
+    int64_t ncols;
+    int64_t ncols_pad;
+} ggml_metal_kargs_argsort;
+
+typedef struct {
+    int64_t ne0;
+    float   start;
+    float   step;
+} ggml_metal_kargs_arange;
+
+typedef struct {
+    int32_t k0;
+    int32_t k1;
+    int32_t s0;
+    int32_t s1;
+    int32_t p0;
+    int32_t p1;
+    int64_t IH;
+    int64_t IW;
+    int64_t OH;
+    int64_t OW;
+    int64_t parallel_elements;
+} ggml_metal_kargs_pool_2d;
+
 #endif // GGML_METAL_IMPL