@fugood/llama.node 1.0.0-beta.4 → 1.0.0-beta.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +7 -4
- package/lib/binding.ts +1 -1
- package/package.json +14 -14
- package/scripts/llama.cpp.patch +27 -26
- package/src/LlamaCompletionWorker.cpp +21 -4
- package/src/LlamaCompletionWorker.h +2 -0
- package/src/LlamaContext.cpp +3 -12
- package/src/common.hpp +6 -5
- package/src/llama.cpp/CMakeLists.txt +15 -4
- package/src/llama.cpp/common/CMakeLists.txt +15 -24
- package/src/llama.cpp/common/arg.cpp +172 -110
- package/src/llama.cpp/common/chat-parser.cpp +385 -0
- package/src/llama.cpp/common/chat-parser.h +120 -0
- package/src/llama.cpp/common/chat.cpp +726 -596
- package/src/llama.cpp/common/chat.h +74 -8
- package/src/llama.cpp/common/common.cpp +56 -38
- package/src/llama.cpp/common/common.h +9 -3
- package/src/llama.cpp/common/json-partial.cpp +256 -0
- package/src/llama.cpp/common/json-partial.h +38 -0
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
- package/src/llama.cpp/common/json-schema-to-grammar.h +4 -4
- package/src/llama.cpp/common/sampling.cpp +7 -8
- package/src/llama.cpp/common/speculative.cpp +6 -4
- package/src/llama.cpp/ggml/CMakeLists.txt +48 -3
- package/src/llama.cpp/ggml/include/ggml.h +22 -3
- package/src/llama.cpp/ggml/src/CMakeLists.txt +81 -22
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +131 -49
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2162 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
- package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +12 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +64 -88
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +282 -100
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1570 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +119 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +204 -49
- package/src/llama.cpp/include/llama.h +145 -40
- package/src/llama.cpp/src/CMakeLists.txt +5 -1
- package/src/llama.cpp/src/llama-arch.cpp +99 -3
- package/src/llama.cpp/src/llama-arch.h +10 -1
- package/src/llama.cpp/src/llama-batch.cpp +728 -272
- package/src/llama.cpp/src/llama-batch.h +112 -54
- package/src/llama.cpp/src/llama-chat.cpp +19 -2
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-context.cpp +525 -339
- package/src/llama.cpp/src/llama-context.h +38 -17
- package/src/llama.cpp/src/llama-cparams.cpp +4 -0
- package/src/llama.cpp/src/llama-cparams.h +2 -0
- package/src/llama.cpp/src/llama-grammar.cpp +12 -2
- package/src/llama.cpp/src/llama-graph.cpp +413 -353
- package/src/llama.cpp/src/llama-graph.h +112 -56
- package/src/llama.cpp/src/llama-hparams.cpp +10 -2
- package/src/llama.cpp/src/llama-hparams.h +13 -2
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +279 -0
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +128 -0
- package/src/llama.cpp/src/llama-kv-cache-unified.cpp +1815 -0
- package/src/llama.cpp/src/llama-kv-cache-unified.h +303 -0
- package/src/llama.cpp/src/llama-kv-cells.h +415 -0
- package/src/llama.cpp/src/llama-memory-hybrid.cpp +246 -0
- package/src/llama.cpp/src/llama-memory-hybrid.h +138 -0
- package/src/llama.cpp/src/llama-memory-recurrent.cpp +1112 -0
- package/src/llama.cpp/src/llama-memory-recurrent.h +183 -0
- package/src/llama.cpp/src/llama-memory.cpp +41 -0
- package/src/llama.cpp/src/llama-memory.h +86 -5
- package/src/llama.cpp/src/llama-mmap.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +42 -17
- package/src/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/src/llama.cpp/src/llama-model.cpp +1137 -528
- package/src/llama.cpp/src/llama-model.h +4 -0
- package/src/llama.cpp/src/llama-quant.cpp +2 -1
- package/src/llama.cpp/src/llama-sampling.cpp +2 -2
- package/src/llama.cpp/src/llama-vocab.cpp +69 -32
- package/src/llama.cpp/src/llama-vocab.h +1 -0
- package/src/llama.cpp/src/llama.cpp +11 -7
- package/src/llama.cpp/src/unicode.cpp +5 -0
- package/src/tts_utils.h +1 -1
- package/src/llama.cpp/common/json.hpp +0 -24766
- package/src/llama.cpp/common/minja/chat-template.hpp +0 -541
- package/src/llama.cpp/common/minja/minja.hpp +0 -2974
- package/src/llama.cpp/common/stb_image.h +0 -7988
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13326
- package/src/llama.cpp/src/llama-kv-cache.cpp +0 -2827
- package/src/llama.cpp/src/llama-kv-cache.h +0 -515
- /package/src/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
@@ -53,7 +53,6 @@
 #include "ggml-cpu-impl.h"
 #include "ggml-quants.h"
 
-#include <atomic>
 #include <array>
 #include <type_traits>
 
@@ -63,7 +62,7 @@
 #define NOINLINE __attribute__((__noinline__))
 #endif
 
-#if defined(__ARM_NEON) || defined(__AVX512F__)
+#if defined(__ARM_NEON) || defined(__AVX512F__) || defined(__VXE__) || defined(__VXE2__)
 #define VECTOR_REGISTERS 32
 #else
 #define VECTOR_REGISTERS 16
@@ -110,6 +109,12 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
 inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 
+#if defined(__VXE__) || defined(__VXE2__)
+inline float32x4_t add(float32x4_t x, float32x4_t y) { return vec_add(x, y); }
+inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vec_sub(x, y); }
+inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
+#endif
+
 #if defined(__MMA__)
 typedef vector unsigned char vec_t;
 typedef __vector_quad acc_t;
@@ -163,6 +168,13 @@ inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
 #endif
 #endif
 
+#if defined(__VXE__) || defined(__VXE2__)
+template <>
+inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
+    return vec_madd(a, b, c);
+}
+#endif
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // VECTORIZED HORIZONTAL SUM
 
@@ -179,6 +191,13 @@ inline float hsum(float16x8_t x) {
 }
 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 
+#if defined(__VXE__) || defined(__VXE2__)
+inline float hsum(float32x4_t x) {
+    float32x4_t tmp = x + vec_reve(x);
+    return tmp[0] + tmp[1];
+}
+#endif
+
 #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
 inline float hsum(__m128 x) {
 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
@@ -228,6 +247,21 @@ template <> inline float32x4_t load(const ggml_fp16_t *p) {
 #endif // _MSC_VER
 #endif // __ARM_NEON
 
+#if defined(__VXE__) || defined(__VXE2__)
+template <> inline float32x4_t load(const ggml_fp16_t * p) {
+    float tmp[4];
+
+    for (int i = 0; i < 4; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(p[i]);
+    }
+
+    return vec_xl(0, (const float *)(tmp));
+}
+template <> inline float32x4_t load(const float * p) {
+    return vec_xl(0, p);
+}
+#endif
+
 #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
 template <> inline __m128 load(const float *p) {
     return _mm_loadu_ps(p);
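Note (not part of the diff): the VXE hunks above follow the existing tinyBLAS pattern, where each backend only supplies a small overload set (add, sub, mul, madd, hsum, load) for its vector type and the generic kernels are written against that contract. A minimal, hypothetical sketch of how such an overload set composes; the struct and the dot() helper below are stand-ins, not code from the package:

#include <cstddef>

struct f32x4 { float v[4]; };  // stand-in for float32x4_t / __m128

inline f32x4 load(const float * p)           { return { p[0], p[1], p[2], p[3] }; }
inline f32x4 madd(f32x4 a, f32x4 b, f32x4 c) { f32x4 r; for (int i = 0; i < 4; ++i) r.v[i] = a.v[i] * b.v[i] + c.v[i]; return r; }
inline float hsum(f32x4 x)                   { return x.v[0] + x.v[1] + x.v[2] + x.v[3]; }

// A kernel written only against those overloads works for any vector type that
// provides them -- the same idea the tinyBLAS templates rely on per backend.
inline float dot(const float * a, const float * b, size_t n) {  // n: multiple of 4
    f32x4 acc = {};
    for (size_t i = 0; i < n; i += 4) {
        acc = madd(load(a + i), load(b + i), acc);
    }
    return hsum(acc);
}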
@@ -394,8 +428,6 @@ class tinyBLAS {
 
     template <int RM, int RN, int BM>
     NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
-        static std::atomic<int64_t> current_chunk;
-
         GGML_ASSERT(m % (RM * BM) == 0);
         const int64_t ytiles = m / (RM * BM);
         const int64_t xtiles = (n + RN -1) / RN;
@@ -410,7 +442,7 @@
         if (params->ith == 0) {
             GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
             // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
-            std::atomic_store_explicit(&current_chunk, (int64_t)params->nth, std::memory_order_relaxed);
+            ggml_threadpool_chunk_set(params->threadpool, params->nth);
         }
 
         ggml_barrier(params->threadpool);
@@ -439,8 +471,7 @@
             GGML_ASSERT(jj == jj2);
         }
 
-
-        job = std::atomic_fetch_add_explicit(&current_chunk, (int64_t)1, std::memory_order_relaxed);
+        job = ggml_threadpool_chunk_add(params->threadpool, 1);
     }
 
     ggml_barrier(params->threadpool);
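Note (not part of the diff): the three hunks above move tinyBLAS off a function-local static std::atomic<int64_t> chunk counter and onto the threadpool-owned counter exposed as ggml_threadpool_chunk_set / ggml_threadpool_chunk_add; the scheduling pattern itself is unchanged. A minimal sketch of that dynamic chunk-claiming pattern with a plain atomic (the worker() helper and its signature are hypothetical):

#include <atomic>
#include <cstdint>

// Each of nth workers starts on the chunk equal to its own thread index (ith),
// then keeps claiming the next unprocessed chunk by atomically bumping a shared
// counter. One thread seeds the counter to nth before the parallel region,
// which is what ggml_threadpool_chunk_set(threadpool, params->nth) does above.
void worker(std::atomic<int64_t> & next_chunk, int ith, int64_t nchunks,
            void (*process)(int64_t)) {
    for (int64_t job = ith; job < nchunks;
         job = next_chunk.fetch_add(1, std::memory_order_relaxed)) {
        process(job);
    }
}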
@@ -3323,6 +3354,14 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
                     (const float *)B, ldb,
                     (float *)C, ldc};
         return tb.matmul(m, n);
+#elif defined(__VXE__) || defined(__VXE2__)
+        if (n < 4)
+            return false;
+        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
+            k, (const float *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc};
+        return tb.matmul(m, n);
 #elif defined(__MMA__)
         if (k % 8)
             return false;
@@ -3414,6 +3453,16 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
                     (float *)C, ldc};
             return tb.matmul(m, n);
         }
+#elif defined(__VXE__) || defined(__VXE2__)
+        if (n < 4)
+            return false;
+        if (Btype == GGML_TYPE_F16) {
+            tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
+                k, (const ggml_fp16_t *)A, lda,
+                (const ggml_fp16_t *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
 #endif
         return false;
     }
package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp

@@ -6793,6 +6793,73 @@ void ggml_compute_forward_pad_reflect_1d(
     }
 }
 
+// ggml_compute_forward_roll
+
+static int64_t ggml_wrap_index(int64_t i, int64_t ne) {
+    if (i < 0) {
+        return i + ne;
+    } else if (i >= ne) {
+        return i - ne;
+    }
+    return i;
+}
+
+static void ggml_compute_forward_roll_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src_data = (const float *) src0->data;
+    float * dst_data = (float *) dst->data;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int s0 = ggml_get_op_params_i32(dst, 0);
+    const int s1 = ggml_get_op_params_i32(dst, 1);
+    const int s2 = ggml_get_op_params_i32(dst, 2);
+    const int s3 = ggml_get_op_params_i32(dst, 3);
+
+    const int64_t total      = ne1 * ne2 * ne3;
+    const int64_t per_thread = (total + params->nth) / params->nth;
+    const int64_t start      = params->ith * per_thread;
+    const int64_t end        = std::min(start + per_thread, total);
+
+    for (int64_t i = start; i < end; ++i) {
+        const int64_t i1 = i % ne1;
+        const int64_t i2 = (i / ne1) % ne2;
+        const int64_t i3 = i / (ne2 * ne1);
+        float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof(float);
+
+        const int64_t i01 = ggml_wrap_index(i1 - s1, ne01);
+        const int64_t i02 = ggml_wrap_index(i2 - s2, ne02);
+        const int64_t i03 = ggml_wrap_index(i3 - s3, ne03);
+        const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof(float);
+
+        const int64_t s = ggml_wrap_index(-s0, ne00);
+        const int64_t n = ne00 - s;
+        ggml_vec_cpy_f32(n, dst_row,     src_row + s);
+        ggml_vec_cpy_f32(s, dst_row + n, src_row);
+    }
+}
+
+void ggml_compute_forward_roll(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_roll_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_arange
 
 static void ggml_compute_forward_arange_f32(
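Note (not part of the diff): the new ggml_compute_forward_roll is a circular shift. ggml_wrap_index folds an out-of-range index back into [0, ne), and each output row is assembled from two contiguous copies of the source row. A 1-D sketch of the same semantics; the roll_1d helper is hypothetical, not from the package:

#include <cstdint>
#include <vector>

// dst[i] = src[(i - shift) mod n], i.e. the last `shift` elements wrap to the front.
std::vector<float> roll_1d(const std::vector<float> & src, int64_t shift) {
    const int64_t n = (int64_t) src.size();
    std::vector<float> dst(n);
    for (int64_t i = 0; i < n; ++i) {
        int64_t j = (i - shift) % n;   // may be negative in C++
        if (j < 0) j += n;             // wrap into [0, n), like ggml_wrap_index
        dst[i] = src[j];
    }
    return dst;
}
// roll_1d({1, 2, 3, 4}, 1) -> {4, 1, 2, 3}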
@@ -7633,39 +7700,83 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir  = ir1 - ir0;
 
-    for (int i3 = 0; i3 < n_s; ++i3) {
-        for (int i2 = 0; i2 < n_t; ++i2) {
-            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
-            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
-            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
-            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-            float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
-            // use the output as the source for the next token-wise iterations
-            if (i2 > 0) { s0 = s; }
-
-            // d_inner
-            for (int i1 = 0; i1 < ir; ++i1) {
-                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
-                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
-                float x_dt = x[i1] * dt_soft_plus;
-                float sumf = 0.0f;
-                // d_state
-                for (int i0 = 0; i0 < nc; ++i0) {
-                    int i = i0 + i1*nc;
-                    // state = prev_state * dA + dB * x
-                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
-                    // y = rowwise_dotprod(state, C)
-                    sumf += state * C[i0];
-                    s[i] = state;
-                }
-                y[i1] = sumf;
-            }
-        }
-    }
+#ifdef __ARM_FEATURE_SVE
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+            float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
+                svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
+                svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
+
+                for (int64_t k = 0; k < nc; k += svcntw()) {
+                    svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
+                    svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]);
+                    svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]);
+                    svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
+
+                    svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+                    t1 = exp_ps_sve(svptrue_b32(), t1);
+                    svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
+
+                    vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
+                    r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+                    GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
+                }
+                y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
+            }
+        }
+    }
+#else
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+            float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                float sumf = 0.0f;
+                // d_state
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    int i = i0 + i1*nc;
+                    // state = prev_state * dA + dB * x
+                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                    // y = rowwise_dotprod(state, C)
+                    sumf += state * C[i0];
+                    s[i] = state;
+                }
+                y[i1] = sumf;
+            }
+        }
+    }
+#endif
 }
 
 void ggml_compute_forward_ssm_scan(
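Note (not part of the diff): both branches of the hunk above evaluate the same selective-state-space (Mamba-style) recurrence; the SVE path only vectorizes the inner d_state loop. Distilled to a single channel and token, the math both paths compute looks roughly like the sketch below; the function name and signature are hypothetical:

#include <cmath>

// One token, one inner channel: softplus the step size, decay the state,
// inject the input through B, and read the output out through C.
float ssm_scan_channel(float * s, const float * A, const float * B, const float * C,
                       float x, float dt, int d_state) {
    const float dt_soft_plus = dt <= 20.0f ? log1pf(expf(dt)) : dt;
    const float x_dt = x * dt_soft_plus;
    float y = 0.0f;
    for (int i = 0; i < d_state; ++i) {
        s[i] = s[i] * expf(dt_soft_plus * A[i]) + B[i] * x_dt;  // state = prev * dA + dB * x
        y   += s[i] * C[i];                                     // y = dot(state, C)
    }
    return y;
}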
@@ -8070,6 +8181,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
        #define GGML_F32X_MUL GGML_F32x16_MUL
        #define GGML_F32X_FMA GGML_F32x16_FMA
        #define WKV_VECTOR_SIZE 16
+    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32xt
+        #define GGML_F32X_SET1 GGML_F32xt_SET1
+        #define GGML_F32X_LOAD GGML_F32xt_LOAD
+        #define GGML_F32X_STORE GGML_F32xt_STORE
+        #define GGML_F32X_MUL GGML_F32xt_MUL
+        #define GGML_F32X_FMA GGML_F32xt_FMA
+        #define WKV_VECTOR_SIZE 8
     #elif defined(__ARM_NEON) && defined(__aarch64__)
        #define GGML_F32X GGML_F32x4
        #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -8081,7 +8200,13 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
     #endif
 
     #ifdef WKV_VECTOR_SIZE
-    const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
+    int wkv_vector_size;
+    #if defined(__ARM_FEATURE_SVE)
+        wkv_vector_size = svcntw();
+    #else
+        wkv_vector_size = WKV_VECTOR_SIZE;
+    #endif
+    const int64_t vec_count = head_size / wkv_vector_size;
 
     for (int64_t t = 0; t < T; t++) {
         size_t t_offset = t * t_stride;
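Note (not part of the diff): unlike NEON or AVX, the SVE vector length is not a compile-time constant, so WKV_VECTOR_SIZE alone can no longer size the main loop; the code now reads the lane count at run time via svcntw() and keeps the macro as the fallback for fixed-width ISAs. The loop-splitting pattern this value feeds, sketched as a hypothetical stand-alone function (the real code does the main loop with GGML_F32X_* vector ops):

#include <cstdint>

// `lanes` is svcntw() on SVE builds, or the fixed vector size elsewhere.
void scale(float * x, int64_t n, float factor, int lanes) {
    const int64_t vec_count = n / lanes;
    for (int64_t j = 0; j < vec_count; j++) {          // full vectors
        const int64_t base_j = j * lanes;
        for (int k = 0; k < lanes; k++) {
            x[base_j + k] *= factor;
        }
    }
    for (int64_t j = vec_count * lanes; j < n; j++) {  // scalar tail
        x[j] *= factor;
    }
}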
@@ -8111,7 +8236,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
             GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
 
             for (int64_t j = 0; j < vec_count; j++) {
-                size_t base_j = j * WKV_VECTOR_SIZE;
+                size_t base_j = j * wkv_vector_size;
                 size_t t_h_j_offset = t_h_offset + base_j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
 
@@ -8136,7 +8261,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
             }
 
             // Handle remaining elements, this will not be used.
-            for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
+            for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
                 size_t t_h_j_offset = t_h_offset + j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + j;
                 float v_val = v[t_h_j_offset];
@@ -8272,6 +8397,14 @@ static void ggml_compute_forward_gla_f32(
        #define GGML_F32X_MUL GGML_F32x16_MUL
        #define GGML_F32X_FMA GGML_F32x16_FMA
        #define GLA_VECTOR_SIZE 16
+    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32xt
+        #define GGML_F32X_SET1 GGML_F32xt_SET1
+        #define GGML_F32X_LOAD GGML_F32xt_LOAD
+        #define GGML_F32X_STORE GGML_F32xt_STORE
+        #define GGML_F32X_MUL GGML_F32xt_MUL
+        #define GGML_F32X_FMA GGML_F32xt_FMA
+        #define GLA_VECTOR_SIZE 8
     #elif defined(__ARM_NEON) && defined(__aarch64__)
        #define GGML_F32X GGML_F32x4
        #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -8283,7 +8416,13 @@ static void ggml_compute_forward_gla_f32(
     #endif
 
     #ifdef GLA_VECTOR_SIZE
-    const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+    int gla_vector_size;
+    #if defined(__ARM_FEATURE_SVE)
+        gla_vector_size = svcntw();
+    #else
+        gla_vector_size = GLA_VECTOR_SIZE;
+    #endif
+    const int64_t vec_count = head_size / gla_vector_size;
 
     for (int64_t t = 0; t < T; t++) {
         size_t t_offset = t * t_stride;
@@ -8310,7 +8449,7 @@ static void ggml_compute_forward_gla_f32(
             GGML_F32X g_vec = GGML_F32X_SET1(g_val);
 
             for (int64_t j = 0; j < vec_count; j++) {
-                size_t base_j = j * GLA_VECTOR_SIZE;
+                size_t base_j = j * gla_vector_size;
                 size_t t_h_j_offset = t_h_offset + base_j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
 
@@ -8334,7 +8473,7 @@ static void ggml_compute_forward_gla_f32(
             }
 
             // Handle remaining elements, this will not be used.
-            for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+            for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
                 size_t t_h_j_offset = t_h_offset + j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + j;
                 float v_val = v[t_h_j_offset];
@@ -8443,83 +8582,126 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
     int64_t h_stride_2d = head_size * head_size;
 
     #if defined(GGML_SIMD)
-    for (int64_t t = 0; t < T; t++) {
-        int64_t t_offset = t * t_stride;
-        int64_t state_offset = head_size * C * (t / (T / n_seqs));
-        float * state_cur = state + state_offset;
-        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
-
-        for (int64_t h = h_start; h < h_end; h++) {
-            int64_t h_offset = h * h_stride;
-            int64_t t_h_offset = t_offset + h_offset;
-            int64_t h_2d_offset = h * h_stride_2d;
-
-            for (int64_t i = 0; i < head_size; i++) {
-                int64_t t_h_i_offset = t_h_offset + i;
-                int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
-                GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
-
-                float sa = 0;
-                {
-                    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                    GGML_F32_VEC ax[GGML_F32_ARR];
-                    GGML_F32_VEC ay[GGML_F32_ARR];
-                    for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
-                        for (int64_t k = 0; k < GGML_F32_ARR; k++) {
-                            ax[k] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + k * GGML_F32_EPR]);
-                            ay[k] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + k * GGML_F32_EPR]);
-                            sum[k] = GGML_F32_VEC_FMA(sum[k], ax[k], ay[k]);
-                        }
-                    }
-                    GGML_F32_VEC_REDUCE(sa, sum);
-                }
-
-                GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
-
-                int64_t j = 0;
-                GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                for (; j < head_size; j += GGML_F32_STEP) {
-                    for (int64_t k = 0; k < GGML_F32_ARR; k++) {
-                        int64_t t_h_j_offset = t_h_offset + j + k * GGML_F32_EPR;
-                        int64_t h_2d_i_j_offset = h_2d_i_offset + j + k * GGML_F32_EPR;
-
-                        GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
-                        GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
-                        GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
-                        GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
-
-                        k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
-
-                        GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
-                        // kv + s * decay + sa * b
-                        state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
-                        state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
-                        GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
-
-                        result_vec[k] = GGML_F32_VEC_FMA(result_vec[k], state_vec, r_vec);
-                    }
-                }
-                GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
-
-                // There shouldn't be left-overs though.
-                for (; j < head_size; j++) {
-                    int64_t t_h_j_offset = t_h_offset + j;
-                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                    float r_val = r[t_h_j_offset];
-                    float w_val = w[t_h_j_offset];
-                    float k_val = k[t_h_j_offset];
-                    float b_val = b[t_h_j_offset];
-                    float kv_val = v[t_h_i_offset] * k_val;
-
-                    float prev_state_val = state_prev[h_2d_i_j_offset];
-                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
-                    dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
-                }
-            }
-        }
-    }
+    #if defined(__ARM_FEATURE_SVE)
+        // scalar Route to scalar implementation       //TODO: Write SVE code
+        for (int64_t t = 0; t < T; t++) {
+            int64_t t_offset = t * t_stride;
+            int64_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                int64_t h_offset = h * h_stride;
+                int64_t t_h_offset = t_offset + h_offset;
+                int64_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    int64_t t_h_i_offset = t_h_offset + i;
+                    int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float v_val = v[t_h_i_offset];
+
+                    float sa = 0, result = 0;
+                    for (int64_t j = 0; j < head_size; j++) {
+                        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
+                    }
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        int64_t t_h_j_offset = t_h_offset + j;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float r_val = r[t_h_j_offset];
+                        float w_val = w[t_h_j_offset];
+                        float k_val = k[t_h_j_offset];
+                        float b_val = b[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                        result += state_cur[h_2d_i_j_offset] * r_val;
+                    }
+                    dst_data[t_h_i_offset] = result;
+                }
+            }
+        }
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            int64_t t_offset = t * t_stride;
+            int64_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                int64_t h_offset = h * h_stride;
+                int64_t t_h_offset = t_offset + h_offset;
+                int64_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t ii = 0; ii < head_size; ii++) {
+                    int64_t t_h_i_offset = t_h_offset + ii;
+                    int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                    GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+                    float sa = 0;
+                    {
+                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                        GGML_F32_VEC ax[GGML_F32_ARR];
+                        GGML_F32_VEC ay[GGML_F32_ARR];
+                        for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                                ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+                                ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+                                sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
+                            }
+                        }
+                        GGML_F32_VEC_REDUCE(sa, sum);
+                    }
+
+                    GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+
+                    int64_t j = 0;
+                    GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                    for (; j < head_size; j += GGML_F32_STEP) {
+                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                            int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+                            int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+
+                            GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                            GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                            GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                            GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+
+                            k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+
+                            GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                            // kv + s * decay + sa * b
+                            state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                            state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                            GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+
+                            result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                        }
+                    }
+                    GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                    // There shouldn't be left-overs though.
+                    for (; j < head_size; j++) {
+                        int64_t t_h_j_offset = t_h_offset + j;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float r_val = r[t_h_j_offset];
+                        float w_val = w[t_h_j_offset];
+                        float k_val = k[t_h_j_offset];
+                        float b_val = b[t_h_j_offset];
+                        float kv_val = v[t_h_i_offset] * k_val;
+
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                        dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
+                    }
+                }
+            }
+        }
+    #endif
     #else
     for (int64_t t = 0; t < T; t++) {
         int64_t t_offset = t * t_stride;
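Note (not part of the diff): on SVE builds this hunk routes WKV7 to a scalar loop for now (per the TODO above), while other SIMD builds keep the vectorized path; both compute the same per-row update. A compressed, hypothetical sketch of that update for a single output element:

// For output element i of one head:
//   sa       = sum_j a[j] * state_prev[j]
//   state[j] = state_prev[j] * w[j] + k[j] * v + sa * b[j]
//   out      = sum_j state[j] * r[j]
float wkv7_element(float * state, const float * state_prev,
                   const float * r, const float * w, const float * k,
                   const float * a, const float * b, float v, int head_size) {
    float sa = 0.0f;
    for (int j = 0; j < head_size; j++) {
        sa += a[j] * state_prev[j];
    }
    float out = 0.0f;
    for (int j = 0; j < head_size; j++) {
        state[j] = state_prev[j] * w[j] + k[j] * v + sa * b[j];  // decay + kv + sa*b
        out     += state[j] * r[j];                              // readout through r
    }
    return out;
}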
package/src/llama.cpp/ggml/src/ggml-cpu/ops.h

@@ -72,6 +72,7 @@ void ggml_compute_forward_pool_2d_back(const struct ggml_compute_params * params
 void ggml_compute_forward_upscale(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_pad(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_pad_reflect_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);