whisper.rn 0.4.0-rc.10 → 0.4.0-rc.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +9 -3
- package/cpp/ggml-alloc.c +6 -14
- package/cpp/ggml-backend-impl.h +50 -11
- package/cpp/ggml-backend-reg.cpp +409 -31
- package/cpp/ggml-backend.cpp +9 -3
- package/cpp/ggml-backend.h +18 -0
- package/cpp/ggml-common.h +41 -43
- package/cpp/ggml-cpp.h +1 -0
- package/cpp/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +941 -254
- package/cpp/ggml-cpu-aarch64.h +2 -24
- package/cpp/ggml-cpu-impl.h +171 -11
- package/cpp/ggml-cpu-quants.c +1812 -389
- package/cpp/ggml-cpu-traits.cpp +36 -0
- package/cpp/ggml-cpu-traits.h +38 -0
- package/cpp/ggml-cpu.c +1432 -610
- package/cpp/ggml-cpu.cpp +131 -141
- package/cpp/ggml-cpu.h +10 -50
- package/cpp/ggml-impl.h +27 -11
- package/cpp/ggml-metal-impl.h +39 -0
- package/cpp/ggml-metal.h +1 -1
- package/cpp/ggml-metal.m +1031 -359
- package/cpp/ggml-opt.cpp +854 -0
- package/cpp/ggml-opt.h +216 -0
- package/cpp/ggml-quants.c +0 -9
- package/cpp/ggml-threading.h +4 -2
- package/cpp/ggml-whisper.metallib +0 -0
- package/cpp/ggml.c +501 -1537
- package/cpp/ggml.h +144 -171
- package/cpp/gguf.cpp +1329 -0
- package/cpp/gguf.h +202 -0
- package/cpp/whisper.cpp +254 -114
- package/cpp/whisper.h +6 -3
- package/lib/commonjs/version.json +1 -1
- package/lib/module/version.json +1 -1
- package/package.json +1 -1
- package/src/version.json +1 -1
- package/whisper-rn.podspec +2 -2
- package/cpp/README.md +0 -4
- package/cpp/ggml-aarch64.c +0 -129
- package/cpp/ggml-aarch64.h +0 -19
- package/cpp/ggml-backend.cpp.rej +0 -12
package/cpp/ggml-cpu.c
CHANGED
|
@@ -3,11 +3,10 @@
|
|
|
3
3
|
|
|
4
4
|
#include "ggml-backend-impl.h"
|
|
5
5
|
#include "ggml-backend.h"
|
|
6
|
-
#include "ggml-cpu-
|
|
6
|
+
#include "ggml-cpu-traits.h"
|
|
7
7
|
#include "ggml-cpu-impl.h"
|
|
8
8
|
#include "ggml-cpu.h"
|
|
9
9
|
#include "ggml-impl.h"
|
|
10
|
-
#include "ggml-quants.h"
|
|
11
10
|
#include "ggml-cpu-quants.h"
|
|
12
11
|
#include "ggml-threading.h"
|
|
13
12
|
#include "ggml.h"
|
|
@@ -109,10 +108,12 @@ static wsp_ggml_fp16_t wsp_ggml_table_gelu_quick_f16[1 << 16];
|
|
|
109
108
|
#if defined(__ARM_ARCH)
|
|
110
109
|
struct wsp_ggml_arm_arch_features_type {
|
|
111
110
|
int has_neon;
|
|
111
|
+
int has_dotprod;
|
|
112
112
|
int has_i8mm;
|
|
113
113
|
int has_sve;
|
|
114
114
|
int sve_cnt;
|
|
115
|
-
|
|
115
|
+
int has_sme;
|
|
116
|
+
} wsp_ggml_arm_arch_features = {-1, -1, -1, -1, 0, -1};
|
|
116
117
|
#endif
|
|
117
118
|
|
|
118
119
|
|
|
@@ -124,8 +125,7 @@ struct wsp_ggml_arm_arch_features_type {
|
|
|
124
125
|
#endif
|
|
125
126
|
#include <windows.h>
|
|
126
127
|
|
|
127
|
-
|
|
128
|
-
#if !defined(__clang__)
|
|
128
|
+
#if defined(_MSC_VER) && !defined(__clang__)
|
|
129
129
|
#define WSP_GGML_CACHE_ALIGN __declspec(align(WSP_GGML_CACHE_LINE))
|
|
130
130
|
|
|
131
131
|
typedef volatile LONG atomic_int;
|
|
@@ -222,10 +222,6 @@ typedef void * thread_ret_t;
|
|
|
222
222
|
|
|
223
223
|
typedef pthread_t wsp_ggml_thread_t;
|
|
224
224
|
|
|
225
|
-
#ifdef WSP_GGML_USE_CPU_HBM
|
|
226
|
-
#include <hbwmalloc.h>
|
|
227
|
-
#endif
|
|
228
|
-
|
|
229
225
|
#if defined(__APPLE__)
|
|
230
226
|
#include <unistd.h>
|
|
231
227
|
#include <mach/mach.h>
|
|
@@ -241,6 +237,8 @@ typedef pthread_t wsp_ggml_thread_t;
|
|
|
241
237
|
#else
|
|
242
238
|
#if defined(__POWER9_VECTOR__)
|
|
243
239
|
#define CACHE_LINE_SIZE 128
|
|
240
|
+
#elif defined(__VXE__) || defined(__VXE2__)
|
|
241
|
+
#define CACHE_LINE_SIZE 256
|
|
244
242
|
#else
|
|
245
243
|
#define CACHE_LINE_SIZE 64
|
|
246
244
|
#endif
|
|
@@ -299,7 +297,6 @@ static const struct wsp_ggml_type_traits_cpu type_traits_cpu[WSP_GGML_TYPE_COUNT
|
|
|
299
297
|
},
|
|
300
298
|
[WSP_GGML_TYPE_Q8_0] = {
|
|
301
299
|
.from_float = wsp_quantize_row_q8_0,
|
|
302
|
-
.from_float_to_mat = wsp_quantize_mat_q8_0,
|
|
303
300
|
.vec_dot = wsp_ggml_vec_dot_q8_0_q8_0,
|
|
304
301
|
.vec_dot_type = WSP_GGML_TYPE_Q8_0,
|
|
305
302
|
#if defined (__ARM_FEATURE_MATMUL_INT8)
|
|
@@ -407,33 +404,6 @@ static const struct wsp_ggml_type_traits_cpu type_traits_cpu[WSP_GGML_TYPE_COUNT
|
|
|
407
404
|
.vec_dot_type = WSP_GGML_TYPE_BF16,
|
|
408
405
|
.nrows = 1,
|
|
409
406
|
},
|
|
410
|
-
[WSP_GGML_TYPE_Q4_0_4_4] = {
|
|
411
|
-
.from_float = NULL,
|
|
412
|
-
.vec_dot = NULL,
|
|
413
|
-
.vec_dot_type = WSP_GGML_TYPE_Q8_0,
|
|
414
|
-
.nrows = 1,
|
|
415
|
-
.ncols = 4,
|
|
416
|
-
.gemv = wsp_ggml_gemv_q4_0_4x4_q8_0,
|
|
417
|
-
.gemm = wsp_ggml_gemm_q4_0_4x4_q8_0,
|
|
418
|
-
},
|
|
419
|
-
[WSP_GGML_TYPE_Q4_0_4_8] = {
|
|
420
|
-
.from_float = NULL,
|
|
421
|
-
.vec_dot = NULL,
|
|
422
|
-
.vec_dot_type = WSP_GGML_TYPE_Q8_0,
|
|
423
|
-
.nrows = 1,
|
|
424
|
-
.ncols = 4,
|
|
425
|
-
.gemv = wsp_ggml_gemv_q4_0_4x8_q8_0,
|
|
426
|
-
.gemm = wsp_ggml_gemm_q4_0_4x8_q8_0,
|
|
427
|
-
},
|
|
428
|
-
[WSP_GGML_TYPE_Q4_0_8_8] = {
|
|
429
|
-
.from_float = NULL,
|
|
430
|
-
.vec_dot = NULL,
|
|
431
|
-
.vec_dot_type = WSP_GGML_TYPE_Q8_0,
|
|
432
|
-
.nrows = 1,
|
|
433
|
-
.ncols = 8,
|
|
434
|
-
.gemv = wsp_ggml_gemv_q4_0_8x8_q8_0,
|
|
435
|
-
.gemm = wsp_ggml_gemm_q4_0_8x8_q8_0,
|
|
436
|
-
},
|
|
437
407
|
[WSP_GGML_TYPE_TQ1_0] = {
|
|
438
408
|
.from_float = wsp_quantize_row_tq1_0,
|
|
439
409
|
.vec_dot = wsp_ggml_vec_dot_tq1_0_q8_K,
|
|
@@ -485,21 +455,21 @@ const struct wsp_ggml_type_traits_cpu * wsp_ggml_get_type_traits_cpu(enum wsp_gg
|
|
|
485
455
|
#define WSP_GGML_F32x4_ADD vaddq_f32
|
|
486
456
|
#define WSP_GGML_F32x4_MUL vmulq_f32
|
|
487
457
|
#define WSP_GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
|
|
488
|
-
#define WSP_GGML_F32x4_REDUCE(res, x)
|
|
489
|
-
{
|
|
490
|
-
int offset = WSP_GGML_F32_ARR >> 1;
|
|
491
|
-
for (int i = 0; i < offset; ++i) {
|
|
492
|
-
(x)[i] = vaddq_f32((x)[i], (x)[offset+i]);
|
|
493
|
-
}
|
|
494
|
-
offset >>= 1;
|
|
495
|
-
for (int i = 0; i < offset; ++i) {
|
|
496
|
-
(x)[i] = vaddq_f32((x)[i], (x)[offset+i]);
|
|
497
|
-
}
|
|
498
|
-
offset >>= 1;
|
|
499
|
-
for (int i = 0; i < offset; ++i) {
|
|
500
|
-
(x)[i] = vaddq_f32((x)[i], (x)[offset+i]);
|
|
501
|
-
}
|
|
502
|
-
(res) = WSP_GGML_F32x4_REDUCE_ONE((x)[0]);
|
|
458
|
+
#define WSP_GGML_F32x4_REDUCE(res, x) \
|
|
459
|
+
{ \
|
|
460
|
+
int offset = WSP_GGML_F32_ARR >> 1; \
|
|
461
|
+
for (int i = 0; i < offset; ++i) { \
|
|
462
|
+
(x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
|
|
463
|
+
} \
|
|
464
|
+
offset >>= 1; \
|
|
465
|
+
for (int i = 0; i < offset; ++i) { \
|
|
466
|
+
(x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
|
|
467
|
+
} \
|
|
468
|
+
offset >>= 1; \
|
|
469
|
+
for (int i = 0; i < offset; ++i) { \
|
|
470
|
+
(x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
|
|
471
|
+
} \
|
|
472
|
+
(res) = (wsp_ggml_float) WSP_GGML_F32x4_REDUCE_ONE((x)[0]); \
|
|
503
473
|
}
|
|
504
474
|
|
|
505
475
|
#define WSP_GGML_F32_VEC WSP_GGML_F32x4
|
|
@@ -614,7 +584,7 @@ do { \
|
|
|
614
584
|
for (int i = 0; i < offset; ++i) { \
|
|
615
585
|
x[i] = _mm512_add_ps(x[i], x[offset+i]); \
|
|
616
586
|
} \
|
|
617
|
-
res = _mm512_reduce_add_ps(x[0]);
|
|
587
|
+
res = (wsp_ggml_float) _mm512_reduce_add_ps(x[0]); \
|
|
618
588
|
} while (0)
|
|
619
589
|
|
|
620
590
|
// TODO: is this optimal ?
|
|
@@ -664,7 +634,7 @@ do { \
|
|
|
664
634
|
for (int i = 0; i < offset; ++i) { \
|
|
665
635
|
x[i] = _mm512_add_ps(x[i], x[offset+i]); \
|
|
666
636
|
} \
|
|
667
|
-
res = _mm512_reduce_add_ps(x[0]);
|
|
637
|
+
res = (wsp_ggml_float) _mm512_reduce_add_ps(x[0]); \
|
|
668
638
|
} while (0)
|
|
669
639
|
|
|
670
640
|
#define WSP_GGML_F16_VEC WSP_GGML_F32Cx16
|
|
@@ -675,8 +645,8 @@ do { \
|
|
|
675
645
|
#define WSP_GGML_F16_VEC_FMA WSP_GGML_F32Cx16_FMA
|
|
676
646
|
#define WSP_GGML_F16_VEC_ADD WSP_GGML_F32Cx16_ADD
|
|
677
647
|
#define WSP_GGML_F16_VEC_MUL WSP_GGML_F32Cx16_MUL
|
|
678
|
-
#define WSP_GGML_F16_VEC_REDUCE WSP_GGML_F32Cx16_REDUCE
|
|
679
648
|
|
|
649
|
+
#define WSP_GGML_F16_VEC_REDUCE WSP_GGML_F32Cx16_REDUCE
|
|
680
650
|
#elif defined(__AVX__)
|
|
681
651
|
|
|
682
652
|
#define WSP_GGML_SIMD
|
|
@@ -745,7 +715,7 @@ do { \
|
|
|
745
715
|
#define WSP_GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
|
|
746
716
|
#define WSP_GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
|
|
747
717
|
#else
|
|
748
|
-
static inline __m256 __avx_f32cx8_load(wsp_ggml_fp16_t *x) {
|
|
718
|
+
static inline __m256 __avx_f32cx8_load(const wsp_ggml_fp16_t * x) {
|
|
749
719
|
float tmp[8];
|
|
750
720
|
|
|
751
721
|
for (int i = 0; i < 8; i++) {
|
|
@@ -1017,7 +987,7 @@ inline static void __wasm_f16x4_store(wsp_ggml_fp16_t * p, v128_t x) {
|
|
|
1017
987
|
#define WSP_GGML_F16_STEP 32
|
|
1018
988
|
#define WSP_GGML_F16_EPR 4
|
|
1019
989
|
|
|
1020
|
-
static inline __m128 __sse_f16x4_load(wsp_ggml_fp16_t *x) {
|
|
990
|
+
static inline __m128 __sse_f16x4_load(const wsp_ggml_fp16_t * x) {
|
|
1021
991
|
float tmp[4];
|
|
1022
992
|
|
|
1023
993
|
tmp[0] = WSP_GGML_FP16_TO_FP32(x[0]);
|
|
@@ -1028,7 +998,7 @@ static inline __m128 __sse_f16x4_load(wsp_ggml_fp16_t *x) {
|
|
|
1028
998
|
return _mm_loadu_ps(tmp);
|
|
1029
999
|
}
|
|
1030
1000
|
|
|
1031
|
-
static inline void __sse_f16x4_store(wsp_ggml_fp16_t *x, __m128 y) {
|
|
1001
|
+
static inline void __sse_f16x4_store(wsp_ggml_fp16_t * x, __m128 y) {
|
|
1032
1002
|
float arr[4];
|
|
1033
1003
|
|
|
1034
1004
|
_mm_storeu_ps(arr, y);
|
|
@@ -1109,29 +1079,23 @@ do { \
|
|
|
1109
1079
|
#define WSP_GGML_F16_STEP 32
|
|
1110
1080
|
#define WSP_GGML_F16_EPR 8
|
|
1111
1081
|
|
|
1112
|
-
// F16 arithmetic is not supported by
|
|
1082
|
+
// F16 arithmetic is not supported by LASX, so we use F32 instead
|
|
1113
1083
|
|
|
1114
1084
|
#define WSP_GGML_F32Cx8 __m256
|
|
1115
1085
|
#define WSP_GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
|
|
1116
1086
|
#define WSP_GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
|
|
1117
1087
|
|
|
1118
1088
|
static inline __m256 __lasx_f32cx8_load(const wsp_ggml_fp16_t * x) {
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
}
|
|
1124
|
-
|
|
1125
|
-
return (__m256)__lasx_xvld(tmp, 0);
|
|
1089
|
+
__m256i a;
|
|
1090
|
+
memcpy(&a, x, sizeof(wsp_ggml_fp16_t) * 8);
|
|
1091
|
+
a = __lasx_xvpermi_d(a, 0 | (1 << 4));
|
|
1092
|
+
return __lasx_xvfcvtl_s_h(a);
|
|
1126
1093
|
}
|
|
1127
|
-
static inline void __lasx_f32cx8_store(wsp_ggml_fp16_t * x, __m256 y) {
|
|
1128
|
-
float arr[8];
|
|
1129
|
-
|
|
1130
|
-
__lasx_xvst(y, arr, 0);
|
|
1131
1094
|
|
|
1132
|
-
|
|
1133
|
-
|
|
1134
|
-
|
|
1095
|
+
static inline void __lasx_f32cx8_store(wsp_ggml_fp16_t * x, __m256 y) {
|
|
1096
|
+
__m256i a = __lasx_xvfcvt_h_s(y, y);
|
|
1097
|
+
a = __lasx_xvpermi_d(a, 0 | (2 << 2));
|
|
1098
|
+
memcpy(x, &a, sizeof(wsp_ggml_fp16_t) * 8);
|
|
1135
1099
|
}
|
|
1136
1100
|
#define WSP_GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
|
|
1137
1101
|
#define WSP_GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
|
|
@@ -1168,28 +1132,28 @@ static inline void __lasx_f32cx8_store(wsp_ggml_fp16_t * x, __m256 y) {
|
|
|
1168
1132
|
#define WSP_GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
|
|
1169
1133
|
#define WSP_GGML_F32x4_ADD __lsx_vfadd_s
|
|
1170
1134
|
#define WSP_GGML_F32x4_MUL __lsx_vfmul_s
|
|
1171
|
-
#define WSP_GGML_F32x4_REDUCE(res, x)
|
|
1172
|
-
{
|
|
1173
|
-
int offset = WSP_GGML_F32_ARR >> 1;
|
|
1174
|
-
for (int i = 0; i < offset; ++i) {
|
|
1175
|
-
x[i] = __lsx_vfadd_s(x[i], x[offset+i]);
|
|
1176
|
-
}
|
|
1177
|
-
offset >>= 1;
|
|
1178
|
-
for (int i = 0; i < offset; ++i) {
|
|
1179
|
-
x[i] = __lsx_vfadd_s(x[i], x[offset+i]);
|
|
1180
|
-
}
|
|
1181
|
-
offset >>= 1;
|
|
1182
|
-
for (int i = 0; i < offset; ++i) {
|
|
1183
|
-
x[i] = __lsx_vfadd_s(x[i], x[offset+i]);
|
|
1184
|
-
}
|
|
1185
|
-
__m128i tmp
|
|
1186
|
-
tmp
|
|
1187
|
-
tmp
|
|
1188
|
-
const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88);
|
|
1189
|
-
tmp
|
|
1190
|
-
tmp
|
|
1191
|
-
tmp
|
|
1192
|
-
res
|
|
1135
|
+
#define WSP_GGML_F32x4_REDUCE(res, x) \
|
|
1136
|
+
{ \
|
|
1137
|
+
int offset = WSP_GGML_F32_ARR >> 1; \
|
|
1138
|
+
for (int i = 0; i < offset; ++i) { \
|
|
1139
|
+
x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \
|
|
1140
|
+
} \
|
|
1141
|
+
offset >>= 1; \
|
|
1142
|
+
for (int i = 0; i < offset; ++i) { \
|
|
1143
|
+
x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \
|
|
1144
|
+
} \
|
|
1145
|
+
offset >>= 1; \
|
|
1146
|
+
for (int i = 0; i < offset; ++i) { \
|
|
1147
|
+
x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \
|
|
1148
|
+
} \
|
|
1149
|
+
__m128i tmp = __lsx_vsrli_d((__m128i) x[0], 32); \
|
|
1150
|
+
tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, x[0]); \
|
|
1151
|
+
tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
|
|
1152
|
+
const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \
|
|
1153
|
+
tmp = __lsx_vsrli_d((__m128i) t0, 32); \
|
|
1154
|
+
tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, t0); \
|
|
1155
|
+
tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
|
|
1156
|
+
res = (wsp_ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
|
|
1193
1157
|
}
|
|
1194
1158
|
|
|
1195
1159
|
#define WSP_GGML_F32_VEC WSP_GGML_F32x4
|
|
@@ -1249,6 +1213,87 @@ static inline void __lsx_f16x4_store(wsp_ggml_fp16_t * x, __m128 y) {
|
|
|
1249
1213
|
#define WSP_GGML_F16_VEC_MUL WSP_GGML_F32Cx4_MUL
|
|
1250
1214
|
#define WSP_GGML_F16_VEC_REDUCE WSP_GGML_F32Cx4_REDUCE
|
|
1251
1215
|
|
|
1216
|
+
#elif defined(__VXE__) || defined(__VXE2__)
|
|
1217
|
+
|
|
1218
|
+
#define WSP_GGML_SIMD
|
|
1219
|
+
|
|
1220
|
+
// F32 s390x
|
|
1221
|
+
|
|
1222
|
+
#define WSP_GGML_F32_STEP 32
|
|
1223
|
+
#define WSP_GGML_F32_EPR 4
|
|
1224
|
+
|
|
1225
|
+
#define WSP_GGML_F32x4 __vector float
|
|
1226
|
+
#define WSP_GGML_F32x4_ZERO vec_splats(0.0f)
|
|
1227
|
+
#define WSP_GGML_F32x4_SET1 vec_splats
|
|
1228
|
+
#define WSP_GGML_F32x4_LOAD(p) vec_xl(0, p)
|
|
1229
|
+
#define WSP_GGML_F32x4_STORE(p, r) vec_xst(r, 0, p)
|
|
1230
|
+
#define WSP_GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
|
|
1231
|
+
#define WSP_GGML_F32x4_ADD vec_add
|
|
1232
|
+
#define WSP_GGML_F32x4_MUL vec_mul
|
|
1233
|
+
#define WSP_GGML_F32x4_REDUCE(res, x) \
|
|
1234
|
+
{ \
|
|
1235
|
+
int offset = WSP_GGML_F32_ARR >> 1; \
|
|
1236
|
+
for (int i = 0; i < offset; ++i) { \
|
|
1237
|
+
x[i] = vec_add(x[i], x[offset + i]); \
|
|
1238
|
+
} \
|
|
1239
|
+
offset >>= 1; \
|
|
1240
|
+
for (int i = 0; i < offset; ++i) { \
|
|
1241
|
+
x[i] = vec_add(x[i], x[offset + i]); \
|
|
1242
|
+
} \
|
|
1243
|
+
offset >>= 1; \
|
|
1244
|
+
for (int i = 0; i < offset; ++i) { \
|
|
1245
|
+
x[i] = vec_add(x[i], x[offset + i]); \
|
|
1246
|
+
} \
|
|
1247
|
+
res = vec_extract(x[0], 0) + \
|
|
1248
|
+
vec_extract(x[0], 1) + \
|
|
1249
|
+
vec_extract(x[0], 2) + \
|
|
1250
|
+
vec_extract(x[0], 3); \
|
|
1251
|
+
}
|
|
1252
|
+
|
|
1253
|
+
#define WSP_GGML_F32_VEC WSP_GGML_F32x4
|
|
1254
|
+
#define WSP_GGML_F32_VEC_ZERO WSP_GGML_F32x4_ZERO
|
|
1255
|
+
#define WSP_GGML_F32_VEC_SET1 WSP_GGML_F32x4_SET1
|
|
1256
|
+
#define WSP_GGML_F32_VEC_LOAD WSP_GGML_F32x4_LOAD
|
|
1257
|
+
#define WSP_GGML_F32_VEC_STORE WSP_GGML_F32x4_STORE
|
|
1258
|
+
#define WSP_GGML_F32_VEC_FMA WSP_GGML_F32x4_FMA
|
|
1259
|
+
#define WSP_GGML_F32_VEC_ADD WSP_GGML_F32x4_ADD
|
|
1260
|
+
#define WSP_GGML_F32_VEC_MUL WSP_GGML_F32x4_MUL
|
|
1261
|
+
#define WSP_GGML_F32_VEC_REDUCE WSP_GGML_F32x4_REDUCE
|
|
1262
|
+
|
|
1263
|
+
// F16 s390x
|
|
1264
|
+
#define WSP_GGML_F16_STEP WSP_GGML_F32_STEP
|
|
1265
|
+
#define WSP_GGML_F16_EPR WSP_GGML_F32_EPR
|
|
1266
|
+
|
|
1267
|
+
static inline __vector float __lzs_f16cx4_load(const wsp_ggml_fp16_t * x) {
|
|
1268
|
+
float tmp[4];
|
|
1269
|
+
|
|
1270
|
+
for (int i = 0; i < 4; i++) {
|
|
1271
|
+
tmp[i] = WSP_GGML_FP16_TO_FP32(x[i]);
|
|
1272
|
+
}
|
|
1273
|
+
|
|
1274
|
+
return vec_xl(0, tmp);
|
|
1275
|
+
}
|
|
1276
|
+
|
|
1277
|
+
static inline void __lzs_f16cx4_store(wsp_ggml_fp16_t * x, __vector float y) {
|
|
1278
|
+
float arr[4];
|
|
1279
|
+
|
|
1280
|
+
vec_xst(y, 0, arr);
|
|
1281
|
+
|
|
1282
|
+
for (int i = 0; i < 4; i++) {
|
|
1283
|
+
x[i] = WSP_GGML_FP32_TO_FP16(arr[i]);
|
|
1284
|
+
}
|
|
1285
|
+
}
|
|
1286
|
+
|
|
1287
|
+
#define WSP_GGML_F16_VEC WSP_GGML_F32x4
|
|
1288
|
+
#define WSP_GGML_F16_VEC_ZERO WSP_GGML_F32x4_ZERO
|
|
1289
|
+
#define WSP_GGML_F16_VEC_SET1 WSP_GGML_F32x4_SET1
|
|
1290
|
+
#define WSP_GGML_F16_VEC_LOAD(p, i) __lzs_f16cx4_load(p)
|
|
1291
|
+
#define WSP_GGML_F16_VEC_STORE(p, r, i) __lzs_f16cx4_store(p, r[i])
|
|
1292
|
+
#define WSP_GGML_F16_VEC_FMA WSP_GGML_F32x4_FMA
|
|
1293
|
+
#define WSP_GGML_F16_VEC_ADD WSP_GGML_F32x4_ADD
|
|
1294
|
+
#define WSP_GGML_F16_VEC_MUL WSP_GGML_F32x4_MUL
|
|
1295
|
+
#define WSP_GGML_F16_VEC_REDUCE WSP_GGML_F32x4_REDUCE
|
|
1296
|
+
|
|
1252
1297
|
#endif
|
|
1253
1298
|
|
|
1254
1299
|
// WSP_GGML_F32_ARR / WSP_GGML_F16_ARR
|
|
@@ -1328,12 +1373,12 @@ struct wsp_ggml_threadpool {
|
|
|
1328
1373
|
atomic_int n_graph; // incremented when there is work to be done (i.e each graph)
|
|
1329
1374
|
atomic_int WSP_GGML_CACHE_ALIGN n_barrier;
|
|
1330
1375
|
atomic_int WSP_GGML_CACHE_ALIGN n_barrier_passed;
|
|
1331
|
-
atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
|
|
1376
|
+
atomic_int WSP_GGML_CACHE_ALIGN current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
|
|
1332
1377
|
|
|
1333
1378
|
// these are atomic as an annotation for thread-sanitizer
|
|
1334
1379
|
atomic_bool stop; // Used for stopping the threadpool altogether
|
|
1335
1380
|
atomic_bool pause; // Used for pausing the threadpool or individual threads
|
|
1336
|
-
|
|
1381
|
+
atomic_int abort; // Used for aborting processing of a graph
|
|
1337
1382
|
|
|
1338
1383
|
struct wsp_ggml_compute_state * workers; // per thread state
|
|
1339
1384
|
int n_threads_max; // number of threads in the pool
|
|
@@ -1357,41 +1402,48 @@ struct wsp_ggml_compute_state {
|
|
|
1357
1402
|
int ith;
|
|
1358
1403
|
};
|
|
1359
1404
|
|
|
1360
|
-
struct wsp_ggml_compute_params {
|
|
1361
|
-
// ith = thread index, nth = number of threads
|
|
1362
|
-
int ith, nth;
|
|
1363
|
-
|
|
1364
|
-
// work buffer for all threads
|
|
1365
|
-
size_t wsize;
|
|
1366
|
-
void * wdata;
|
|
1367
|
-
|
|
1368
|
-
struct wsp_ggml_threadpool * threadpool;
|
|
1369
|
-
};
|
|
1370
|
-
|
|
1371
1405
|
//
|
|
1372
1406
|
// fundamental operations
|
|
1373
1407
|
//
|
|
1374
1408
|
|
|
1375
1409
|
inline static void wsp_ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
|
1376
|
-
|
|
1377
1410
|
inline static void wsp_ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
|
1378
1411
|
|
|
1379
|
-
inline static void wsp_ggml_vec_set_i32(const int n, int32_t * x, const int32_t
|
|
1412
|
+
inline static void wsp_ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
|
1413
|
+
inline static void wsp_ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
|
|
1380
1414
|
|
|
1381
1415
|
inline static void wsp_ggml_vec_set_f16(const int n, wsp_ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
|
1382
|
-
|
|
1383
1416
|
inline static void wsp_ggml_vec_set_bf16(const int n, wsp_ggml_bf16_t * x, const wsp_ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
|
1384
|
-
|
|
1385
1417
|
inline static void wsp_ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
|
|
1418
|
+
inline static void wsp_ggml_vec_add_f16 (const int n, wsp_ggml_fp16_t * z, const wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t * y) {
|
|
1419
|
+
for (int i = 0; i < n; ++i) {
|
|
1420
|
+
z[i] = WSP_GGML_FP32_TO_FP16(WSP_GGML_FP16_TO_FP32(x[i]) + WSP_GGML_FP16_TO_FP32(y[i]));
|
|
1421
|
+
}
|
|
1422
|
+
}
|
|
1386
1423
|
inline static void wsp_ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
|
|
1387
1424
|
inline static void wsp_ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
|
|
1388
1425
|
inline static void wsp_ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
|
|
1389
1426
|
inline static void wsp_ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; }
|
|
1427
|
+
inline static void wsp_ggml_vec_sub_f16 (const int n, wsp_ggml_fp16_t * z, const wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t * y) {
|
|
1428
|
+
for (int i = 0; i < n; ++i) {
|
|
1429
|
+
z[i] = WSP_GGML_FP32_TO_FP16(WSP_GGML_FP16_TO_FP32(x[i]) - WSP_GGML_FP16_TO_FP32(y[i]));
|
|
1430
|
+
}
|
|
1431
|
+
}
|
|
1390
1432
|
inline static void wsp_ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
|
1391
1433
|
inline static void wsp_ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
|
|
1392
1434
|
inline static void wsp_ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; }
|
|
1393
1435
|
inline static void wsp_ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
|
|
1436
|
+
inline static void wsp_ggml_vec_mul_f16 (const int n, wsp_ggml_fp16_t * z, const wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t * y) {
|
|
1437
|
+
for (int i = 0; i < n; ++i) {
|
|
1438
|
+
z[i] = WSP_GGML_FP32_TO_FP16(WSP_GGML_FP16_TO_FP32(x[i]) * WSP_GGML_FP16_TO_FP32(y[i]));
|
|
1439
|
+
}
|
|
1440
|
+
}
|
|
1394
1441
|
inline static void wsp_ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
|
|
1442
|
+
inline static void wsp_ggml_vec_div_f16 (const int n, wsp_ggml_fp16_t * z, const wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t * y) {
|
|
1443
|
+
for (int i = 0; i < n; ++i) {
|
|
1444
|
+
z[i] = WSP_GGML_FP32_TO_FP16(WSP_GGML_FP16_TO_FP32(x[i]) / WSP_GGML_FP16_TO_FP32(y[i]));
|
|
1445
|
+
}
|
|
1446
|
+
}
|
|
1395
1447
|
|
|
1396
1448
|
static void wsp_ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
|
|
1397
1449
|
assert(nrc == 1);
|
|
@@ -1868,7 +1920,7 @@ inline static float wsp_ggml_silu_f32(float x) {
|
|
|
1868
1920
|
|
|
1869
1921
|
#if __FINITE_MATH_ONLY__
|
|
1870
1922
|
#error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix"
|
|
1871
|
-
#error "ref: https://github.com/
|
|
1923
|
+
#error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461"
|
|
1872
1924
|
#endif
|
|
1873
1925
|
|
|
1874
1926
|
#if defined(__ARM_NEON) && defined(__aarch64__)
|
|
@@ -2276,7 +2328,7 @@ struct wsp_ggml_state {
|
|
|
2276
2328
|
|
|
2277
2329
|
static struct wsp_ggml_state g_state = {0};
|
|
2278
2330
|
|
|
2279
|
-
|
|
2331
|
+
void wsp_ggml_barrier(struct wsp_ggml_threadpool * tp) {
|
|
2280
2332
|
int n_threads = atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed);
|
|
2281
2333
|
if (n_threads == 1) {
|
|
2282
2334
|
return;
|
|
@@ -2430,7 +2482,11 @@ bool wsp_ggml_is_numa(void) {
|
|
|
2430
2482
|
#endif
|
|
2431
2483
|
|
|
2432
2484
|
#if !defined(HWCAP2_I8MM)
|
|
2433
|
-
#define HWCAP2_I8MM
|
|
2485
|
+
#define HWCAP2_I8MM (1 << 13)
|
|
2486
|
+
#endif
|
|
2487
|
+
|
|
2488
|
+
#if !defined(HWCAP2_SME)
|
|
2489
|
+
#define HWCAP2_SME (1 << 23)
|
|
2434
2490
|
#endif
|
|
2435
2491
|
|
|
2436
2492
|
static void wsp_ggml_init_arm_arch_features(void) {
|
|
@@ -2438,9 +2494,11 @@ static void wsp_ggml_init_arm_arch_features(void) {
|
|
|
2438
2494
|
uint32_t hwcap = getauxval(AT_HWCAP);
|
|
2439
2495
|
uint32_t hwcap2 = getauxval(AT_HWCAP2);
|
|
2440
2496
|
|
|
2441
|
-
wsp_ggml_arm_arch_features.has_neon
|
|
2442
|
-
wsp_ggml_arm_arch_features.
|
|
2443
|
-
wsp_ggml_arm_arch_features.
|
|
2497
|
+
wsp_ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
|
|
2498
|
+
wsp_ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP);
|
|
2499
|
+
wsp_ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
|
|
2500
|
+
wsp_ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE);
|
|
2501
|
+
wsp_ggml_arm_arch_features.has_sme = !!(hwcap2 & HWCAP2_SME);
|
|
2444
2502
|
|
|
2445
2503
|
#if defined(__ARM_FEATURE_SVE)
|
|
2446
2504
|
wsp_ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
|
|
@@ -2453,11 +2511,21 @@ static void wsp_ggml_init_arm_arch_features(void) {
|
|
|
2453
2511
|
}
|
|
2454
2512
|
wsp_ggml_arm_arch_features.has_neon = oldp;
|
|
2455
2513
|
|
|
2514
|
+
if (sysctlbyname("hw.optional.arm.FEAT_DotProd", &oldp, &size, NULL, 0) != 0) {
|
|
2515
|
+
oldp = 0;
|
|
2516
|
+
}
|
|
2517
|
+
wsp_ggml_arm_arch_features.has_dotprod = oldp;
|
|
2518
|
+
|
|
2456
2519
|
if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) != 0) {
|
|
2457
2520
|
oldp = 0;
|
|
2458
2521
|
}
|
|
2459
2522
|
wsp_ggml_arm_arch_features.has_i8mm = oldp;
|
|
2460
2523
|
|
|
2524
|
+
if (sysctlbyname("hw.optional.arm.FEAT_SME", &oldp, &size, NULL, 0) != 0) {
|
|
2525
|
+
oldp = 0;
|
|
2526
|
+
}
|
|
2527
|
+
wsp_ggml_arm_arch_features.has_sme = oldp;
|
|
2528
|
+
|
|
2461
2529
|
wsp_ggml_arm_arch_features.has_sve = 0;
|
|
2462
2530
|
wsp_ggml_arm_arch_features.sve_cnt = 0;
|
|
2463
2531
|
#else
|
|
@@ -2481,6 +2549,12 @@ static void wsp_ggml_init_arm_arch_features(void) {
|
|
|
2481
2549
|
wsp_ggml_arm_arch_features.has_sve = 0;
|
|
2482
2550
|
wsp_ggml_arm_arch_features.sve_cnt = 0;
|
|
2483
2551
|
#endif
|
|
2552
|
+
|
|
2553
|
+
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_SME2)
|
|
2554
|
+
wsp_ggml_arm_arch_features.has_sme = 1;
|
|
2555
|
+
#else
|
|
2556
|
+
wsp_ggml_arm_arch_features.has_sme = 0;
|
|
2557
|
+
#endif
|
|
2484
2558
|
#endif
|
|
2485
2559
|
}
|
|
2486
2560
|
#endif
|
|
@@ -4005,6 +4079,57 @@ static void wsp_ggml_compute_forward_dup_bytes(
|
|
|
4005
4079
|
}
|
|
4006
4080
|
}
|
|
4007
4081
|
|
|
4082
|
+
static void wsp_ggml_compute_forward_dup_q(
|
|
4083
|
+
const struct wsp_ggml_compute_params * params,
|
|
4084
|
+
struct wsp_ggml_tensor * dst) {
|
|
4085
|
+
|
|
4086
|
+
const struct wsp_ggml_tensor * src0 = dst->src[0];
|
|
4087
|
+
const struct wsp_ggml_tensor * src1 = dst->src[1];
|
|
4088
|
+
|
|
4089
|
+
WSP_GGML_TENSOR_BINARY_OP_LOCALS
|
|
4090
|
+
|
|
4091
|
+
const enum wsp_ggml_type type = src0->type;
|
|
4092
|
+
wsp_ggml_to_float_t const wsp_dewsp_quantize_row_q = wsp_ggml_get_type_traits(type)->to_float;
|
|
4093
|
+
|
|
4094
|
+
size_t qk = wsp_ggml_blck_size(type);
|
|
4095
|
+
const int64_t nr = wsp_ggml_nelements(src1) / qk;
|
|
4096
|
+
|
|
4097
|
+
// destination must be contiguous in the first dimension
|
|
4098
|
+
WSP_GGML_ASSERT(nb10 == wsp_ggml_type_size(dst->type));
|
|
4099
|
+
// must either have first dimension large enough to hold a row, or fully contiguous
|
|
4100
|
+
WSP_GGML_ASSERT((ne10 % qk) == 0 || wsp_ggml_is_contiguous(dst));
|
|
4101
|
+
|
|
4102
|
+
const int ith = params->ith;
|
|
4103
|
+
const int nth = params->nth;
|
|
4104
|
+
|
|
4105
|
+
const int dr = (nr + nth - 1)/nth;
|
|
4106
|
+
|
|
4107
|
+
// row range for this thread
|
|
4108
|
+
const int ir0 = dr*ith;
|
|
4109
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
4110
|
+
|
|
4111
|
+
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
|
4112
|
+
|
|
4113
|
+
uint32_t i = ir * qk;
|
|
4114
|
+
|
|
4115
|
+
const int64_t i03 = i/(ne00 * ne01 * ne02);
|
|
4116
|
+
const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
|
|
4117
|
+
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
|
|
4118
|
+
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
|
|
4119
|
+
const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
|
|
4120
|
+
|
|
4121
|
+
const int64_t i13 = i/(ne10 * ne11 * ne12);
|
|
4122
|
+
const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
|
|
4123
|
+
const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
|
|
4124
|
+
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
|
|
4125
|
+
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
|
|
4126
|
+
|
|
4127
|
+
wsp_dewsp_quantize_row_q(
|
|
4128
|
+
(const void *) ((char *) src0->data + x_offset),
|
|
4129
|
+
(float *) ((char *) dst->data + dst_offset), qk);
|
|
4130
|
+
}
|
|
4131
|
+
}
|
|
4132
|
+
|
|
4008
4133
|
static void wsp_ggml_compute_forward_dup(
|
|
4009
4134
|
const struct wsp_ggml_compute_params * params,
|
|
4010
4135
|
struct wsp_ggml_tensor * dst) {
|
|
@@ -4031,6 +4156,10 @@ static void wsp_ggml_compute_forward_dup(
|
|
|
4031
4156
|
} break;
|
|
4032
4157
|
default:
|
|
4033
4158
|
{
|
|
4159
|
+
if (wsp_ggml_is_quantized(src0->type) && dst->type == WSP_GGML_TYPE_F32) {
|
|
4160
|
+
wsp_ggml_compute_forward_dup_q(params, dst);
|
|
4161
|
+
break;
|
|
4162
|
+
}
|
|
4034
4163
|
WSP_GGML_ABORT("fatal error");
|
|
4035
4164
|
}
|
|
4036
4165
|
}
|
|
@@ -4270,7 +4399,7 @@ static void wsp_ggml_compute_forward_add_f16_f16(
|
|
|
4270
4399
|
const struct wsp_ggml_tensor * src0 = dst->src[0];
|
|
4271
4400
|
const struct wsp_ggml_tensor * src1 = dst->src[1];
|
|
4272
4401
|
|
|
4273
|
-
WSP_GGML_ASSERT(
|
|
4402
|
+
WSP_GGML_ASSERT(wsp_ggml_can_repeat(src1, src0) && wsp_ggml_are_same_shape(src0, dst));
|
|
4274
4403
|
|
|
4275
4404
|
const int ith = params->ith;
|
|
4276
4405
|
const int nth = params->nth;
|
|
@@ -4295,17 +4424,22 @@ static void wsp_ggml_compute_forward_add_f16_f16(
|
|
|
4295
4424
|
|
|
4296
4425
|
if (nb10 == sizeof(wsp_ggml_fp16_t)) {
|
|
4297
4426
|
for (int ir = ir0; ir < ir1; ++ir) {
|
|
4298
|
-
// src0
|
|
4299
|
-
const
|
|
4300
|
-
const
|
|
4301
|
-
const
|
|
4427
|
+
// src1 is broadcastable across src0 and dst in i1, i2, i3
|
|
4428
|
+
const int64_t i03 = ir/(ne02*ne01);
|
|
4429
|
+
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
|
4430
|
+
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
|
4431
|
+
|
|
4432
|
+
const int64_t i13 = i03 % ne13;
|
|
4433
|
+
const int64_t i12 = i02 % ne12;
|
|
4434
|
+
const int64_t i11 = i01 % ne11;
|
|
4435
|
+
const int64_t nr0 = ne00 / ne10;
|
|
4302
4436
|
|
|
4303
|
-
wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) ((char *) dst->data +
|
|
4304
|
-
wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data +
|
|
4305
|
-
wsp_ggml_fp16_t * src1_ptr = (wsp_ggml_fp16_t *) ((char *) src1->data +
|
|
4437
|
+
wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
|
|
4438
|
+
wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
|
4439
|
+
wsp_ggml_fp16_t * src1_ptr = (wsp_ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
|
|
4306
4440
|
|
|
4307
|
-
for (
|
|
4308
|
-
dst_ptr
|
|
4441
|
+
for (int64_t r = 0; r < nr0; ++r) {
|
|
4442
|
+
wsp_ggml_vec_add_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
|
|
4309
4443
|
}
|
|
4310
4444
|
}
|
|
4311
4445
|
}
|
|
@@ -4505,9 +4639,6 @@ static void wsp_ggml_compute_forward_add(
|
|
|
4505
4639
|
case WSP_GGML_TYPE_IQ4_XS:
|
|
4506
4640
|
case WSP_GGML_TYPE_IQ3_S:
|
|
4507
4641
|
case WSP_GGML_TYPE_IQ2_S:
|
|
4508
|
-
case WSP_GGML_TYPE_Q4_0_4_4:
|
|
4509
|
-
case WSP_GGML_TYPE_Q4_0_4_8:
|
|
4510
|
-
case WSP_GGML_TYPE_Q4_0_8_8:
|
|
4511
4642
|
{
|
|
4512
4643
|
wsp_ggml_compute_forward_add_q_f32(params, dst);
|
|
4513
4644
|
} break;
|
|
@@ -4885,9 +5016,6 @@ static void wsp_ggml_compute_forward_add1(
|
|
|
4885
5016
|
case WSP_GGML_TYPE_IQ4_XS:
|
|
4886
5017
|
case WSP_GGML_TYPE_IQ3_S:
|
|
4887
5018
|
case WSP_GGML_TYPE_IQ2_S:
|
|
4888
|
-
case WSP_GGML_TYPE_Q4_0_4_4:
|
|
4889
|
-
case WSP_GGML_TYPE_Q4_0_4_8:
|
|
4890
|
-
case WSP_GGML_TYPE_Q4_0_8_8:
|
|
4891
5019
|
{
|
|
4892
5020
|
wsp_ggml_compute_forward_add1_q_f32(params, dst);
|
|
4893
5021
|
} break;
|
|
@@ -5015,9 +5143,6 @@ static void wsp_ggml_compute_forward_acc(
|
|
|
5015
5143
|
case WSP_GGML_TYPE_IQ4_XS:
|
|
5016
5144
|
case WSP_GGML_TYPE_IQ3_S:
|
|
5017
5145
|
case WSP_GGML_TYPE_IQ2_S:
|
|
5018
|
-
case WSP_GGML_TYPE_Q4_0_4_4:
|
|
5019
|
-
case WSP_GGML_TYPE_Q4_0_4_8:
|
|
5020
|
-
case WSP_GGML_TYPE_Q4_0_8_8:
|
|
5021
5146
|
default:
|
|
5022
5147
|
{
|
|
5023
5148
|
WSP_GGML_ABORT("fatal error");
|
|
@@ -5102,6 +5227,62 @@ static void wsp_ggml_compute_forward_sub_f32(
|
|
|
5102
5227
|
}
|
|
5103
5228
|
}
|
|
5104
5229
|
|
|
5230
|
+
static void wsp_ggml_compute_forward_sub_f16(
|
|
5231
|
+
const struct wsp_ggml_compute_params * params,
|
|
5232
|
+
struct wsp_ggml_tensor * dst) {
|
|
5233
|
+
|
|
5234
|
+
const struct wsp_ggml_tensor * src0 = dst->src[0];
|
|
5235
|
+
const struct wsp_ggml_tensor * src1 = dst->src[1];
|
|
5236
|
+
|
|
5237
|
+
assert(wsp_ggml_can_repeat(src1, src0) && wsp_ggml_are_same_shape(src0, dst));
|
|
5238
|
+
|
|
5239
|
+
const int ith = params->ith;
|
|
5240
|
+
const int nth = params->nth;
|
|
5241
|
+
|
|
5242
|
+
const int nr = wsp_ggml_nrows(src0);
|
|
5243
|
+
|
|
5244
|
+
WSP_GGML_TENSOR_BINARY_OP_LOCALS
|
|
5245
|
+
|
|
5246
|
+
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
|
|
5247
|
+
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F16);
|
|
5248
|
+
WSP_GGML_ASSERT(dst->type == WSP_GGML_TYPE_F16);
|
|
5249
|
+
|
|
5250
|
+
WSP_GGML_ASSERT( nb0 == sizeof(wsp_ggml_fp16_t));
|
|
5251
|
+
WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
|
|
5252
|
+
|
|
5253
|
+
// rows per thread
|
|
5254
|
+
const int dr = (nr + nth - 1)/nth;
|
|
5255
|
+
|
|
5256
|
+
// row range for this thread
|
|
5257
|
+
const int ir0 = dr*ith;
|
|
5258
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
5259
|
+
|
|
5260
|
+
if (nb10 == sizeof(wsp_ggml_fp16_t)) {
|
|
5261
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
5262
|
+
// src1 is broadcastable across src0 and dst in i1, i2, i3
|
|
5263
|
+
const int64_t i03 = ir/(ne02*ne01);
|
|
5264
|
+
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
|
5265
|
+
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
|
5266
|
+
|
|
5267
|
+
const int64_t i13 = i03 % ne13;
|
|
5268
|
+
const int64_t i12 = i02 % ne12;
|
|
5269
|
+
const int64_t i11 = i01 % ne11;
|
|
5270
|
+
const int64_t nr0 = ne00 / ne10;
|
|
5271
|
+
|
|
5272
|
+
wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
|
|
5273
|
+
wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
|
5274
|
+
wsp_ggml_fp16_t * src1_ptr = (wsp_ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
|
|
5275
|
+
|
|
5276
|
+
for (int64_t r = 0; r < nr0; ++r) {
|
|
5277
|
+
wsp_ggml_vec_sub_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
|
|
5278
|
+
}
|
|
5279
|
+
}
|
|
5280
|
+
} else {
|
|
5281
|
+
// src1 is not contiguous
|
|
5282
|
+
WSP_GGML_ABORT("unimplemented error");
|
|
5283
|
+
}
|
|
5284
|
+
}
|
|
5285
|
+
|
|
5105
5286
|
static void wsp_ggml_compute_forward_sub(
|
|
5106
5287
|
const struct wsp_ggml_compute_params * params,
|
|
5107
5288
|
struct wsp_ggml_tensor * dst) {
|
|
@@ -5113,6 +5294,10 @@ static void wsp_ggml_compute_forward_sub(
|
|
|
5113
5294
|
{
|
|
5114
5295
|
wsp_ggml_compute_forward_sub_f32(params, dst);
|
|
5115
5296
|
} break;
|
|
5297
|
+
case WSP_GGML_TYPE_F16:
|
|
5298
|
+
{
|
|
5299
|
+
wsp_ggml_compute_forward_sub_f16(params, dst);
|
|
5300
|
+
} break;
|
|
5116
5301
|
default:
|
|
5117
5302
|
{
|
|
5118
5303
|
WSP_GGML_ABORT("fatal error");
|
|
@@ -5193,32 +5378,9 @@ static void wsp_ggml_compute_forward_mul_f32(
|
|
|
5193
5378
|
}
|
|
5194
5379
|
}
|
|
5195
5380
|
|
|
5196
|
-
static void
|
|
5197
|
-
|
|
5198
|
-
|
|
5199
|
-
|
|
5200
|
-
const struct wsp_ggml_tensor * src0 = dst->src[0];
|
|
5201
|
-
const struct wsp_ggml_tensor * src1 = dst->src[1];
|
|
5202
|
-
|
|
5203
|
-
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32 && "only f32 src1 supported for now");
|
|
5204
|
-
|
|
5205
|
-
switch (src0->type) {
|
|
5206
|
-
case WSP_GGML_TYPE_F32:
|
|
5207
|
-
{
|
|
5208
|
-
wsp_ggml_compute_forward_mul_f32(params, dst);
|
|
5209
|
-
} break;
|
|
5210
|
-
default:
|
|
5211
|
-
{
|
|
5212
|
-
WSP_GGML_ABORT("fatal error");
|
|
5213
|
-
}
|
|
5214
|
-
}
|
|
5215
|
-
}
|
|
5216
|
-
|
|
5217
|
-
// wsp_ggml_compute_forward_div
|
|
5218
|
-
|
|
5219
|
-
static void wsp_ggml_compute_forward_div_f32(
|
|
5220
|
-
const struct wsp_ggml_compute_params * params,
|
|
5221
|
-
struct wsp_ggml_tensor * dst) {
|
|
5381
|
+
static void wsp_ggml_compute_forward_mul_f16(
|
|
5382
|
+
const struct wsp_ggml_compute_params * params,
|
|
5383
|
+
struct wsp_ggml_tensor * dst) {
|
|
5222
5384
|
|
|
5223
5385
|
const struct wsp_ggml_tensor * src0 = dst->src[0];
|
|
5224
5386
|
const struct wsp_ggml_tensor * src1 = dst->src[1];
|
|
@@ -5232,8 +5394,84 @@ static void wsp_ggml_compute_forward_div_f32(
|
|
|
5232
5394
|
|
|
5233
5395
|
WSP_GGML_TENSOR_BINARY_OP_LOCALS
|
|
5234
5396
|
|
|
5235
|
-
WSP_GGML_ASSERT(
|
|
5236
|
-
WSP_GGML_ASSERT(
|
|
5397
|
+
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
|
|
5398
|
+
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F16);
|
|
5399
|
+
WSP_GGML_ASSERT(dst->type == WSP_GGML_TYPE_F16);
|
|
5400
|
+
|
|
5401
|
+
WSP_GGML_ASSERT( nb0 == sizeof(wsp_ggml_fp16_t));
|
|
5402
|
+
WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
|
|
5403
|
+
|
|
5404
|
+
if (nb10 == sizeof(wsp_ggml_fp16_t)) {
|
|
5405
|
+
for (int64_t ir = ith; ir < nr; ir += nth) {
|
|
5406
|
+
// src0 and dst are same shape => same indices
|
|
5407
|
+
const int64_t i03 = ir/(ne02*ne01);
|
|
5408
|
+
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
|
5409
|
+
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
|
5410
|
+
|
|
5411
|
+
const int64_t i13 = i03 % ne13;
|
|
5412
|
+
const int64_t i12 = i02 % ne12;
|
|
5413
|
+
const int64_t i11 = i01 % ne11;
|
|
5414
|
+
const int64_t nr0 = ne00 / ne10;
|
|
5415
|
+
|
|
5416
|
+
wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
|
|
5417
|
+
wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
|
5418
|
+
wsp_ggml_fp16_t * src1_ptr = (wsp_ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
|
|
5419
|
+
|
|
5420
|
+
for (int64_t r = 0 ; r < nr0; ++r) {
|
|
5421
|
+
wsp_ggml_vec_mul_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
|
|
5422
|
+
}
|
|
5423
|
+
}
|
|
5424
|
+
} else {
|
|
5425
|
+
// src1 is not contiguous
|
|
5426
|
+
WSP_GGML_ABORT("unimplemented error");
|
|
5427
|
+
}
|
|
5428
|
+
}
|
|
5429
|
+
|
|
5430
|
+
static void wsp_ggml_compute_forward_mul(
|
|
5431
|
+
const struct wsp_ggml_compute_params * params,
|
|
5432
|
+
struct wsp_ggml_tensor * dst) {
|
|
5433
|
+
|
|
5434
|
+
const struct wsp_ggml_tensor * src0 = dst->src[0];
|
|
5435
|
+
const struct wsp_ggml_tensor * src1 = dst->src[1];
|
|
5436
|
+
|
|
5437
|
+
WSP_GGML_ASSERT((src1->type == WSP_GGML_TYPE_F32 || src1->type == WSP_GGML_TYPE_F16) && "only f32/f16 src1 supported for now");
|
|
5438
|
+
|
|
5439
|
+
switch (src0->type) {
|
|
5440
|
+
case WSP_GGML_TYPE_F32:
|
|
5441
|
+
{
|
|
5442
|
+
wsp_ggml_compute_forward_mul_f32(params, dst);
|
|
5443
|
+
} break;
|
|
5444
|
+
case WSP_GGML_TYPE_F16:
|
|
5445
|
+
{
|
|
5446
|
+
wsp_ggml_compute_forward_mul_f16(params, dst);
|
|
5447
|
+
} break;
|
|
5448
|
+
default:
|
|
5449
|
+
{
|
|
5450
|
+
WSP_GGML_ABORT("fatal error");
|
|
5451
|
+
}
|
|
5452
|
+
}
|
|
5453
|
+
}
|
|
5454
|
+
|
|
5455
|
+
// wsp_ggml_compute_forward_div
|
|
5456
|
+
|
|
5457
|
+
static void wsp_ggml_compute_forward_div_f32(
|
|
5458
|
+
const struct wsp_ggml_compute_params * params,
|
|
5459
|
+
struct wsp_ggml_tensor * dst) {
|
|
5460
|
+
|
|
5461
|
+
const struct wsp_ggml_tensor * src0 = dst->src[0];
|
|
5462
|
+
const struct wsp_ggml_tensor * src1 = dst->src[1];
|
|
5463
|
+
|
|
5464
|
+
WSP_GGML_ASSERT(wsp_ggml_can_repeat(src1, src0) && wsp_ggml_are_same_shape(src0, dst));
|
|
5465
|
+
|
|
5466
|
+
const int ith = params->ith;
|
|
5467
|
+
const int nth = params->nth;
|
|
5468
|
+
|
|
5469
|
+
const int64_t nr = wsp_ggml_nrows(src0);
|
|
5470
|
+
|
|
5471
|
+
WSP_GGML_TENSOR_BINARY_OP_LOCALS
|
|
5472
|
+
|
|
5473
|
+
WSP_GGML_ASSERT( nb0 == sizeof(float));
|
|
5474
|
+
WSP_GGML_ASSERT(nb00 == sizeof(float));
|
|
5237
5475
|
|
|
5238
5476
|
if (nb10 == sizeof(float)) {
|
|
5239
5477
|
for (int64_t ir = ith; ir < nr; ir += nth) {
|
|
@@ -5287,6 +5525,55 @@ static void wsp_ggml_compute_forward_div_f32(
|
|
|
5287
5525
|
}
|
|
5288
5526
|
}
|
|
5289
5527
|
|
|
5528
|
+
static void wsp_ggml_compute_forward_div_f16(
|
|
5529
|
+
const struct wsp_ggml_compute_params * params,
|
|
5530
|
+
struct wsp_ggml_tensor * dst) {
|
|
5531
|
+
|
|
5532
|
+
const struct wsp_ggml_tensor * src0 = dst->src[0];
|
|
5533
|
+
const struct wsp_ggml_tensor * src1 = dst->src[1];
|
|
5534
|
+
|
|
5535
|
+
WSP_GGML_ASSERT(wsp_ggml_can_repeat(src1, src0) && wsp_ggml_are_same_shape(src0, dst));
|
|
5536
|
+
|
|
5537
|
+
const int ith = params->ith;
|
|
5538
|
+
const int nth = params->nth;
|
|
5539
|
+
|
|
5540
|
+
const int64_t nr = wsp_ggml_nrows(src0);
|
|
5541
|
+
|
|
5542
|
+
WSP_GGML_TENSOR_BINARY_OP_LOCALS
|
|
5543
|
+
|
|
5544
|
+
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
|
|
5545
|
+
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F16);
|
|
5546
|
+
WSP_GGML_ASSERT(dst->type == WSP_GGML_TYPE_F16);
|
|
5547
|
+
|
|
5548
|
+
WSP_GGML_ASSERT( nb0 == sizeof(wsp_ggml_fp16_t));
|
|
5549
|
+
WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
|
|
5550
|
+
|
|
5551
|
+
if (nb10 == sizeof(wsp_ggml_fp16_t)) {
|
|
5552
|
+
for (int64_t ir = ith; ir < nr; ir += nth) {
|
|
5553
|
+
// src0 and dst are same shape => same indices
|
|
5554
|
+
const int64_t i03 = ir/(ne02*ne01);
|
|
5555
|
+
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
|
5556
|
+
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
|
5557
|
+
|
|
5558
|
+
const int64_t i13 = i03 % ne13;
|
|
5559
|
+
const int64_t i12 = i02 % ne12;
|
|
5560
|
+
const int64_t i11 = i01 % ne11;
|
|
5561
|
+
const int64_t nr0 = ne00 / ne10;
|
|
5562
|
+
|
|
5563
|
+
wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
|
|
5564
|
+
wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
|
5565
|
+
wsp_ggml_fp16_t * src1_ptr = (wsp_ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
|
|
5566
|
+
|
|
5567
|
+
for (int64_t r = 0; r < nr0; ++r) {
|
|
5568
|
+
wsp_ggml_vec_div_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
|
|
5569
|
+
}
|
|
5570
|
+
}
|
|
5571
|
+
} else {
|
|
5572
|
+
// src1 is not contiguous
|
|
5573
|
+
WSP_GGML_ABORT("unimplemented error");
|
|
5574
|
+
}
|
|
5575
|
+
}
|
|
5576
|
+
|
|
5290
5577
|
static void wsp_ggml_compute_forward_div(
|
|
5291
5578
|
const struct wsp_ggml_compute_params * params,
|
|
5292
5579
|
struct wsp_ggml_tensor * dst) {
|
|
@@ -5298,6 +5585,10 @@ static void wsp_ggml_compute_forward_div(
|
|
|
5298
5585
|
{
|
|
5299
5586
|
wsp_ggml_compute_forward_div_f32(params, dst);
|
|
5300
5587
|
} break;
|
|
5588
|
+
case WSP_GGML_TYPE_F16:
|
|
5589
|
+
{
|
|
5590
|
+
wsp_ggml_compute_forward_div_f16(params, dst);
|
|
5591
|
+
} break;
|
|
5301
5592
|
default:
|
|
5302
5593
|
{
|
|
5303
5594
|
WSP_GGML_ABORT("fatal error");
|
|
@@ -6738,20 +7029,20 @@ static void wsp_ggml_compute_forward_silu_back_f32(
|
|
|
6738
7029
|
const struct wsp_ggml_compute_params * params,
|
|
6739
7030
|
struct wsp_ggml_tensor * dst) {
|
|
6740
7031
|
|
|
6741
|
-
const struct wsp_ggml_tensor *
|
|
6742
|
-
const struct wsp_ggml_tensor *
|
|
7032
|
+
const struct wsp_ggml_tensor * grad = dst->src[0];
|
|
7033
|
+
const struct wsp_ggml_tensor * src1 = dst->src[1];
|
|
6743
7034
|
|
|
6744
7035
|
assert(wsp_ggml_is_contiguous_1(grad));
|
|
6745
|
-
assert(wsp_ggml_is_contiguous_1(
|
|
7036
|
+
assert(wsp_ggml_is_contiguous_1(src1));
|
|
6746
7037
|
assert(wsp_ggml_is_contiguous_1(dst));
|
|
6747
|
-
assert(wsp_ggml_are_same_shape(
|
|
6748
|
-
assert(wsp_ggml_are_same_shape(
|
|
7038
|
+
assert(wsp_ggml_are_same_shape(src1, dst));
|
|
7039
|
+
assert(wsp_ggml_are_same_shape(src1, grad));
|
|
6749
7040
|
|
|
6750
7041
|
const int ith = params->ith;
|
|
6751
7042
|
const int nth = params->nth;
|
|
6752
7043
|
|
|
6753
|
-
const int nc =
|
|
6754
|
-
const int nr = wsp_ggml_nrows(
|
|
7044
|
+
const int nc = src1->ne[0];
|
|
7045
|
+
const int nr = wsp_ggml_nrows(src1);
|
|
6755
7046
|
|
|
6756
7047
|
// rows per thread
|
|
6757
7048
|
const int dr = (nr + nth - 1)/nth;
|
|
@@ -6763,7 +7054,7 @@ static void wsp_ggml_compute_forward_silu_back_f32(
|
|
|
6763
7054
|
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
6764
7055
|
wsp_ggml_vec_silu_backward_f32(nc,
|
|
6765
7056
|
(float *) ((char *) dst->data + i1*( dst->nb[1])),
|
|
6766
|
-
(float *) ((char *)
|
|
7057
|
+
(float *) ((char *) src1->data + i1*(src1->nb[1])),
|
|
6767
7058
|
(float *) ((char *) grad->data + i1*(grad->nb[1])));
|
|
6768
7059
|
|
|
6769
7060
|
#ifndef NDEBUG
|
|
@@ -6942,7 +7233,7 @@ static void wsp_ggml_compute_forward_norm_f32(
|
|
|
6942
7233
|
float eps;
|
|
6943
7234
|
memcpy(&eps, dst->op_params, sizeof(float));
|
|
6944
7235
|
|
|
6945
|
-
WSP_GGML_ASSERT(eps
|
|
7236
|
+
WSP_GGML_ASSERT(eps >= 0.0f);
|
|
6946
7237
|
|
|
6947
7238
|
// TODO: optimize
|
|
6948
7239
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
@@ -7013,7 +7304,7 @@ static void wsp_ggml_compute_forward_rms_norm_f32(
|
|
|
7013
7304
|
float eps;
|
|
7014
7305
|
memcpy(&eps, dst->op_params, sizeof(float));
|
|
7015
7306
|
|
|
7016
|
-
WSP_GGML_ASSERT(eps
|
|
7307
|
+
WSP_GGML_ASSERT(eps >= 0.0f);
|
|
7017
7308
|
|
|
7018
7309
|
// TODO: optimize
|
|
7019
7310
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
@@ -7065,12 +7356,13 @@ static void wsp_ggml_compute_forward_rms_norm_back_f32(
|
|
|
7065
7356
|
const struct wsp_ggml_compute_params * params,
|
|
7066
7357
|
struct wsp_ggml_tensor * dst) {
|
|
7067
7358
|
|
|
7068
|
-
const struct wsp_ggml_tensor * src0 = dst->src[0];
|
|
7069
|
-
const struct wsp_ggml_tensor * src1 = dst->src[1];
|
|
7359
|
+
const struct wsp_ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
|
|
7360
|
+
const struct wsp_ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
|
|
7070
7361
|
|
|
7071
7362
|
WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst) && wsp_ggml_are_same_shape(src0, src1));
|
|
7072
7363
|
|
|
7073
7364
|
WSP_GGML_ASSERT(src0->nb[0] == sizeof(float));
|
|
7365
|
+
WSP_GGML_ASSERT(src1->nb[0] == sizeof(float));
|
|
7074
7366
|
|
|
7075
7367
|
const int ith = params->ith;
|
|
7076
7368
|
const int nth = params->nth;
|
|
@@ -7089,8 +7381,8 @@ static void wsp_ggml_compute_forward_rms_norm_back_f32(
|
|
|
7089
7381
|
const int64_t i12 = i02;
|
|
7090
7382
|
const int64_t i13 = i03;
|
|
7091
7383
|
|
|
7092
|
-
const float *
|
|
7093
|
-
const float *
|
|
7384
|
+
const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
7385
|
+
const float * x = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
|
|
7094
7386
|
|
|
7095
7387
|
wsp_ggml_float sum_xx = 0.0;
|
|
7096
7388
|
wsp_ggml_float sum_xdz = 0.0;
|
|
@@ -7113,9 +7405,9 @@ static void wsp_ggml_compute_forward_rms_norm_back_f32(
|
|
|
7113
7405
|
{
|
|
7114
7406
|
// z = rms_norm(x)
|
|
7115
7407
|
//
|
|
7116
|
-
// rms_norm(
|
|
7408
|
+
// rms_norm(src1) =
|
|
7117
7409
|
// scale(
|
|
7118
|
-
//
|
|
7410
|
+
// src1,
|
|
7119
7411
|
// div(
|
|
7120
7412
|
// 1,
|
|
7121
7413
|
// sqrt(
|
|
@@ -7123,13 +7415,13 @@ static void wsp_ggml_compute_forward_rms_norm_back_f32(
|
|
|
7123
7415
|
// scale(
|
|
7124
7416
|
// sum(
|
|
7125
7417
|
// sqr(
|
|
7126
|
-
//
|
|
7418
|
+
// src1)),
|
|
7127
7419
|
// (1.0/N)),
|
|
7128
7420
|
// eps))));
|
|
7129
7421
|
|
|
7130
7422
|
// postorder:
|
|
7131
7423
|
// ## op args grad
|
|
7132
|
-
// 00 param
|
|
7424
|
+
// 00 param src1 grad[#00]
|
|
7133
7425
|
// 01 const 1
|
|
7134
7426
|
// 02 sqr (#00) grad[#02]
|
|
7135
7427
|
// 03 sum (#02) grad[#03]
|
|
@@ -7206,6 +7498,7 @@ static void wsp_ggml_compute_forward_rms_norm_back_f32(
|
|
|
7206
7498
|
// dx := scale(dx, rrms)
|
|
7207
7499
|
float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
|
7208
7500
|
|
|
7501
|
+
// dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
|
|
7209
7502
|
wsp_ggml_vec_cpy_f32 (ne00, dx, x);
|
|
7210
7503
|
// wsp_ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
|
|
7211
7504
|
wsp_ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
|
|
@@ -7433,20 +7726,9 @@ static void wsp_ggml_compute_forward_mul_mat(
|
|
|
7433
7726
|
const int ith = params->ith;
|
|
7434
7727
|
const int nth = params->nth;
|
|
7435
7728
|
|
|
7436
|
-
enum wsp_ggml_type
|
|
7437
|
-
|
|
7438
|
-
if (src0->buffer && wsp_ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) {
|
|
7439
|
-
type = (enum wsp_ggml_type)(intptr_t)src0->extra;
|
|
7440
|
-
}
|
|
7441
|
-
|
|
7442
|
-
enum wsp_ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
|
|
7729
|
+
enum wsp_ggml_type const vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
|
|
7443
7730
|
wsp_ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
|
|
7444
|
-
|
|
7445
|
-
int64_t const vec_dot_num_rows = type_traits_cpu[type].nrows;
|
|
7446
|
-
int64_t const matmul_num_cols = type_traits_cpu[type].ncols;
|
|
7447
|
-
int64_t const blck_size_interleave = wsp_ggml_get_type_traits(type)->blck_size_interleave;
|
|
7448
|
-
wsp_ggml_gemv_t const gemv = type_traits_cpu[type].gemv;
|
|
7449
|
-
wsp_ggml_gemm_t const gemm = type_traits_cpu[type].gemm;
|
|
7731
|
+
int64_t const vec_dot_num_rows = type_traits_cpu[src0->type].nrows;
|
|
7450
7732
|
|
|
7451
7733
|
WSP_GGML_ASSERT(ne0 == ne01);
|
|
7452
7734
|
WSP_GGML_ASSERT(ne1 == ne11);
|
|
@@ -7454,7 +7736,7 @@ static void wsp_ggml_compute_forward_mul_mat(
|
|
|
7454
7736
|
WSP_GGML_ASSERT(ne3 == ne13);
|
|
7455
7737
|
|
|
7456
7738
|
// we don't support permuted src0 or src1
|
|
7457
|
-
WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(type));
|
|
7739
|
+
WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(src0->type));
|
|
7458
7740
|
WSP_GGML_ASSERT(nb10 == wsp_ggml_type_size(src1->type));
|
|
7459
7741
|
|
|
7460
7742
|
// dst cannot be transposed or permuted
|
|
@@ -7466,6 +7748,7 @@ static void wsp_ggml_compute_forward_mul_mat(
|
|
|
7466
7748
|
// nb01 >= nb00 - src0 is not transposed
|
|
7467
7749
|
// compute by src0 rows
|
|
7468
7750
|
|
|
7751
|
+
// TODO: extract to "extra_op"
|
|
7469
7752
|
#if WSP_GGML_USE_LLAMAFILE
|
|
7470
7753
|
// broadcast factors
|
|
7471
7754
|
const int64_t r2 = ne12 / ne02;
|
|
@@ -7476,15 +7759,15 @@ static void wsp_ggml_compute_forward_mul_mat(
|
|
|
7476
7759
|
if (src1_cont) {
|
|
7477
7760
|
for (int64_t i13 = 0; i13 < ne13; i13++)
|
|
7478
7761
|
for (int64_t i12 = 0; i12 < ne12; i12++)
|
|
7479
|
-
if (!llamafile_sgemm(
|
|
7762
|
+
if (!llamafile_sgemm(params,
|
|
7763
|
+
ne01, ne11, ne00/wsp_ggml_blck_size(src0->type),
|
|
7480
7764
|
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
|
|
7481
|
-
nb01/wsp_ggml_type_size(type),
|
|
7765
|
+
nb01/wsp_ggml_type_size(src0->type),
|
|
7482
7766
|
(const char *)src1->data + i12*nb12 + i13*nb13,
|
|
7483
7767
|
nb11/wsp_ggml_type_size(src1->type),
|
|
7484
7768
|
(char *)dst->data + i12*nb2 + i13*nb3,
|
|
7485
7769
|
nb1/wsp_ggml_type_size(dst->type),
|
|
7486
|
-
|
|
7487
|
-
type,
|
|
7770
|
+
src0->type,
|
|
7488
7771
|
src1->type,
|
|
7489
7772
|
dst->type))
|
|
7490
7773
|
goto UseGgmlGemm1;
|
|
@@ -7496,6 +7779,7 @@ UseGgmlGemm1:;
|
|
|
7496
7779
|
if (src1->type != vec_dot_type) {
|
|
7497
7780
|
char * wdata = params->wdata;
|
|
7498
7781
|
|
|
7782
|
+
const size_t nbw0 = wsp_ggml_type_size(vec_dot_type);
|
|
7499
7783
|
const size_t nbw1 = wsp_ggml_row_size(vec_dot_type, ne10);
|
|
7500
7784
|
const size_t nbw2 = nbw1*ne11;
|
|
7501
7785
|
const size_t nbw3 = nbw2*ne12;
|
|
@@ -7503,24 +7787,30 @@ UseGgmlGemm1:;
|
|
|
7503
7787
|
assert(params->wsize >= ne13*nbw3);
|
|
7504
7788
|
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
|
|
7505
7789
|
|
|
7790
|
+
#if 0
|
|
7506
7791
|
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
|
7507
7792
|
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
|
7508
|
-
int64_t
|
|
7509
|
-
if ((wsp_ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) {
|
|
7510
|
-
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
|
|
7511
|
-
from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
|
|
7512
|
-
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
|
|
7513
|
-
4, ne10, blck_size_interleave);
|
|
7514
|
-
}
|
|
7515
|
-
i11_processed = ne11 - ne11 % 4;
|
|
7516
|
-
}
|
|
7517
|
-
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
|
|
7793
|
+
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
|
|
7518
7794
|
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
|
|
7519
|
-
|
|
7520
|
-
|
|
7795
|
+
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
|
|
7796
|
+
ne10);
|
|
7797
|
+
}
|
|
7798
|
+
}
|
|
7799
|
+
}
|
|
7800
|
+
#else
|
|
7801
|
+
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
|
7802
|
+
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
|
7803
|
+
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
|
7804
|
+
size_t bs = wsp_ggml_blck_size(vec_dot_type);
|
|
7805
|
+
int64_t ne10_block_start = (ith * ne10/bs) / nth;
|
|
7806
|
+
int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth;
|
|
7807
|
+
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
|
|
7808
|
+
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
|
|
7809
|
+
(ne10_block_end - ne10_block_start) * bs);
|
|
7521
7810
|
}
|
|
7522
7811
|
}
|
|
7523
7812
|
}
|
|
7813
|
+
#endif
|
|
7524
7814
|
}
|
|
7525
7815
|
|
|
7526
7816
|
if (ith == 0) {
|
|
@@ -7537,15 +7827,15 @@ UseGgmlGemm1:;
|
|
|
7537
7827
|
|
|
7538
7828
|
for (int64_t i13 = 0; i13 < ne13; i13++)
|
|
7539
7829
|
for (int64_t i12 = 0; i12 < ne12; i12++)
|
|
7540
|
-
if (!llamafile_sgemm(
|
|
7830
|
+
if (!llamafile_sgemm(params,
|
|
7831
|
+
ne01, ne11, ne00/wsp_ggml_blck_size(src0->type),
|
|
7541
7832
|
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
|
|
7542
|
-
nb01/wsp_ggml_type_size(type),
|
|
7833
|
+
nb01/wsp_ggml_type_size(src0->type),
|
|
7543
7834
|
(const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
|
|
7544
7835
|
row_size/wsp_ggml_type_size(vec_dot_type),
|
|
7545
7836
|
(char *)dst->data + i12*nb2 + i13*nb3,
|
|
7546
7837
|
nb1/wsp_ggml_type_size(dst->type),
|
|
7547
|
-
|
|
7548
|
-
type,
|
|
7838
|
+
src0->type,
|
|
7549
7839
|
vec_dot_type,
|
|
7550
7840
|
dst->type))
|
|
7551
7841
|
goto UseGgmlGemm2;
|
|
@@ -7560,14 +7850,6 @@ UseGgmlGemm2:;
|
|
|
7560
7850
|
// This is the size of the rest of the dimensions of the result
|
|
7561
7851
|
const int64_t nr1 = ne1 * ne2 * ne3;
|
|
7562
7852
|
|
|
7563
|
-
// dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
|
|
7564
|
-
int64_t num_rows_per_vec_dot = vec_dot_num_rows;
|
|
7565
|
-
// TODO: currently the mmla kernels support only even numbered rows/cols.
|
|
7566
|
-
// this check can be removed once they are extended to support odd numbered rows/cols too
|
|
7567
|
-
if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
|
|
7568
|
-
num_rows_per_vec_dot = 1;
|
|
7569
|
-
}
|
|
7570
|
-
|
|
7571
7853
|
// Now select a reasonable chunk size.
|
|
7572
7854
|
int chunk_size = 16;
|
|
7573
7855
|
|
|
@@ -7583,7 +7865,7 @@ UseGgmlGemm2:;
|
|
|
7583
7865
|
int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
|
|
7584
7866
|
|
|
7585
7867
|
// If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread.
|
|
7586
|
-
// Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/
|
|
7868
|
+
// Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggml-org/llama.cpp/pull/6915
|
|
7587
7869
|
// In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that.
|
|
7588
7870
|
if (nchunk0 * nchunk1 < nth * 4 || wsp_ggml_is_numa()) {
|
|
7589
7871
|
// distribute the thread work across the inner or outer loop based on which one is larger
|
|
@@ -7595,28 +7877,6 @@ UseGgmlGemm2:;
|
|
|
7595
7877
|
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
|
|
7596
7878
|
const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
|
|
7597
7879
|
|
|
7598
|
-
if ((wsp_ggml_n_dims(src0) == 2) && gemv) {
|
|
7599
|
-
const void * src1_wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
|
7600
|
-
const size_t src1_col_stride = wsp_ggml_is_contiguous(src1) || src1->type != vec_dot_type ? wsp_ggml_row_size(vec_dot_type, ne10) : nb11;
|
|
7601
|
-
int64_t src0_start = (ith * ne01) / nth;
|
|
7602
|
-
int64_t src0_end = ((ith + 1) * ne01) / nth;
|
|
7603
|
-
src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start;
|
|
7604
|
-
src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end;
|
|
7605
|
-
if (src0_start >= src0_end) return;
|
|
7606
|
-
|
|
7607
|
-
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
|
|
7608
|
-
if (gemm && (ne11 > 3)) {
|
|
7609
|
-
gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01,
|
|
7610
|
-
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
|
|
7611
|
-
}
|
|
7612
|
-
for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) {
|
|
7613
|
-
gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
|
|
7614
|
-
(const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
|
|
7615
|
-
src0_end - src0_start);
|
|
7616
|
-
}
|
|
7617
|
-
return;
|
|
7618
|
-
}
|
|
7619
|
-
|
|
7620
7880
|
// The first chunk comes from our thread_id, the rest will get auto-assigned.
|
|
7621
7881
|
int current_chunk = ith;
|
|
7622
7882
|
|
|
@@ -7630,7 +7890,15 @@ UseGgmlGemm2:;
|
|
|
7630
7890
|
const int64_t ir1_start = dr1 * ith1;
|
|
7631
7891
|
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
|
|
7632
7892
|
|
|
7633
|
-
|
|
7893
|
+
// dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
|
|
7894
|
+
int64_t num_rows_per_vec_dot = vec_dot_num_rows;
|
|
7895
|
+
|
|
7896
|
+
// these checks are needed to avoid crossing dim1 boundaries
|
|
7897
|
+
// can be optimized, but the logic would become more complicated, so keeping it like this for simplicity
|
|
7898
|
+
if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
|
|
7899
|
+
num_rows_per_vec_dot = 1;
|
|
7900
|
+
}
|
|
7901
|
+
wsp_ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
|
|
7634
7902
|
|
|
7635
7903
|
if (nth >= nchunk0 * nchunk1) {
|
|
7636
7904
|
break;
|
|
@@ -7642,6 +7910,84 @@ UseGgmlGemm2:;
|
|
|
7642
7910
|
|
|
7643
7911
|
// wsp_ggml_compute_forward_mul_mat_id
|
|
7644
7912
|
|
|
7913
|
+
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ids->ne[0]*ids->ne[1] + (i1)]
|
|
7914
|
+
|
|
7915
|
+
struct mmid_row_mapping {
|
|
7916
|
+
int32_t i1;
|
|
7917
|
+
int32_t i2;
|
|
7918
|
+
};
|
|
7919
|
+
|
|
7920
|
+
static void wsp_ggml_compute_forward_mul_mat_id_one_chunk(
|
|
7921
|
+
struct wsp_ggml_tensor * dst,
|
|
7922
|
+
const struct wsp_ggml_tensor * src0,
|
|
7923
|
+
const struct wsp_ggml_tensor * src1,
|
|
7924
|
+
const struct wsp_ggml_tensor * ids,
|
|
7925
|
+
const int64_t cur_a,
|
|
7926
|
+
const int64_t ir0_start,
|
|
7927
|
+
const int64_t ir0_end,
|
|
7928
|
+
const int64_t ir1_start,
|
|
7929
|
+
const int64_t ir1_end,
|
|
7930
|
+
const char * src0_cur,
|
|
7931
|
+
const struct mmid_row_mapping * matrix_rows,
|
|
7932
|
+
const size_t row_size,
|
|
7933
|
+
const bool src1_cont,
|
|
7934
|
+
const void * wdata) {
|
|
7935
|
+
|
|
7936
|
+
WSP_GGML_TENSOR_BINARY_OP_LOCALS
|
|
7937
|
+
|
|
7938
|
+
const enum wsp_ggml_type type = src0->type;
|
|
7939
|
+
|
|
7940
|
+
wsp_ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
|
|
7941
|
+
enum wsp_ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
|
|
7942
|
+
|
|
7943
|
+
const int64_t blck_0 = 16;
|
|
7944
|
+
const int64_t blck_1 = 16;
|
|
7945
|
+
|
|
7946
|
+
float tmp[16];
|
|
7947
|
+
|
|
7948
|
+
for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
|
|
7949
|
+
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
|
|
7950
|
+
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ++ir1) {
|
|
7951
|
+
const int64_t _i12 = ir1; // logical row index for this expert
|
|
7952
|
+
|
|
7953
|
+
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
|
|
7954
|
+
const int id = row_mapping.i1; // selected expert index
|
|
7955
|
+
|
|
7956
|
+
const int64_t i11 = id % ne11;
|
|
7957
|
+
const int64_t i12 = row_mapping.i2; // row index in src1
|
|
7958
|
+
|
|
7959
|
+
const int64_t i1 = id; // selected expert index
|
|
7960
|
+
const int64_t i2 = i12; // row
|
|
7961
|
+
|
|
7962
|
+
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
|
|
7963
|
+
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
|
|
7964
|
+
// the original src1 data pointer, so we should index using the indices directly
|
|
7965
|
+
// TODO: this is a bit of a hack, we should probably have a better way to handle this
|
|
7966
|
+
const char * src1_col = (const char *) wdata +
|
|
7967
|
+
(src1_cont || src1->type != vec_dot_type
|
|
7968
|
+
? (i11 + i12*ne11)*row_size
|
|
7969
|
+
: (i11*nb11 + i12*nb12));
|
|
7970
|
+
|
|
7971
|
+
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
|
|
7972
|
+
|
|
7973
|
+
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
|
|
7974
|
+
vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
|
|
7975
|
+
}
|
|
7976
|
+
|
|
7977
|
+
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float));
|
|
7978
|
+
}
|
|
7979
|
+
}
|
|
7980
|
+
}
|
|
7981
|
+
}
|
|
7982
|
+
|
|
7983
|
+
static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
|
|
7984
|
+
|
|
7985
|
+
void * ptr = *p;
|
|
7986
|
+
ptr = (void *) WSP_GGML_PAD((uintptr_t) ptr, align);
|
|
7987
|
+
*p = (void *) ((char *) ptr + size);
|
|
7988
|
+
return ptr;
|
|
7989
|
+
}
|
|
7990
|
+
|
|
7645
7991
|
static void wsp_ggml_compute_forward_mul_mat_id(
|
|
7646
7992
|
const struct wsp_ggml_compute_params * params,
|
|
7647
7993
|
struct wsp_ggml_tensor * dst) {
|
|
@@ -7659,11 +8005,8 @@ static void wsp_ggml_compute_forward_mul_mat_id(
|
|
|
7659
8005
|
|
|
7660
8006
|
const bool src1_cont = wsp_ggml_is_contiguous(src1);
|
|
7661
8007
|
|
|
7662
|
-
wsp_ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
|
|
7663
8008
|
enum wsp_ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
|
|
7664
8009
|
wsp_ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
|
|
7665
|
-
int64_t const matmul_num_cols = type_traits_cpu[type].ncols;
|
|
7666
|
-
wsp_ggml_gemv_t const gemv = type_traits_cpu[type].gemv;
|
|
7667
8010
|
|
|
7668
8011
|
// we don't support permuted src0 or src1
|
|
7669
8012
|
WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(type));
|
|
@@ -7679,21 +8022,27 @@ static void wsp_ggml_compute_forward_mul_mat_id(
|
|
|
7679
8022
|
const int n_ids = ids->ne[0]; // n_expert_used
|
|
7680
8023
|
const int n_as = ne02; // n_expert
|
|
7681
8024
|
|
|
7682
|
-
|
|
7683
|
-
(char *) params->wdata :
|
|
7684
|
-
(char *) params->wdata + WSP_GGML_PAD(wsp_ggml_row_size(vec_dot_type, wsp_ggml_nelements(src1)), sizeof(int64_t));
|
|
8025
|
+
void * wdata_cur = params->wdata;
|
|
7685
8026
|
|
|
7686
|
-
|
|
7687
|
-
|
|
7688
|
-
|
|
7689
|
-
|
|
8027
|
+
if (src1->type != vec_dot_type) {
|
|
8028
|
+
incr_ptr_aligned(&wdata_cur, wsp_ggml_row_size(vec_dot_type, wsp_ggml_nelements(src1)), sizeof(int64_t));
|
|
8029
|
+
}
|
|
8030
|
+
|
|
8031
|
+
int64_t * matrix_row_counts = // [n_as]
|
|
8032
|
+
incr_ptr_aligned(&wdata_cur, n_as*sizeof(int64_t), sizeof(int64_t));
|
|
8033
|
+
|
|
8034
|
+
struct mmid_row_mapping * matrix_rows = // [n_as][ids->ne[0]*ids->ne[1]]
|
|
8035
|
+
incr_ptr_aligned(&wdata_cur, n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping), sizeof(int64_t));
|
|
7690
8036
|
|
|
7691
|
-
|
|
7692
|
-
|
|
8037
|
+
char (*atomic_current_chunk)[CACHE_LINE_SIZE] = // [n_as]
|
|
8038
|
+
incr_ptr_aligned(&wdata_cur, CACHE_LINE_SIZE * n_as, CACHE_LINE_SIZE);
|
|
8039
|
+
|
|
8040
|
+
WSP_GGML_ASSERT(params->wsize >= (size_t)((char *) wdata_cur - (char *) params->wdata));
|
|
7693
8041
|
|
|
7694
8042
|
if (src1->type != vec_dot_type) {
|
|
7695
8043
|
char * wdata = params->wdata;
|
|
7696
8044
|
|
|
8045
|
+
const size_t nbw0 = wsp_ggml_type_size(vec_dot_type);
|
|
7697
8046
|
const size_t nbw1 = wsp_ggml_row_size(vec_dot_type, ne10);
|
|
7698
8047
|
const size_t nbw2 = nbw1*ne11;
|
|
7699
8048
|
const size_t nbw3 = nbw2*ne12;
|
|
@@ -7701,19 +8050,32 @@ static void wsp_ggml_compute_forward_mul_mat_id(
|
|
|
7701
8050
|
assert(params->wsize >= ne13*nbw3);
|
|
7702
8051
|
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
|
|
7703
8052
|
|
|
8053
|
+
#if 0
|
|
7704
8054
|
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
|
7705
|
-
for (int64_t i12 =
|
|
7706
|
-
for (int64_t i11 =
|
|
8055
|
+
for (int64_t i12 = ith; i12 < ne12; i12 += nth) {
|
|
8056
|
+
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
|
7707
8057
|
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
|
|
7708
8058
|
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
|
|
7709
8059
|
ne10);
|
|
7710
8060
|
}
|
|
7711
8061
|
}
|
|
7712
8062
|
}
|
|
8063
|
+
#else
|
|
8064
|
+
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
|
8065
|
+
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
|
8066
|
+
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
|
8067
|
+
size_t bs = wsp_ggml_blck_size(vec_dot_type);
|
|
8068
|
+
int64_t ne10_block_start = (ith * ne10/bs) / nth;
|
|
8069
|
+
int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth;
|
|
8070
|
+
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
|
|
8071
|
+
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
|
|
8072
|
+
(ne10_block_end - ne10_block_start) * bs);
|
|
8073
|
+
}
|
|
8074
|
+
}
|
|
8075
|
+
}
|
|
8076
|
+
#endif
|
|
7713
8077
|
}
|
|
7714
8078
|
|
|
7715
|
-
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
|
|
7716
|
-
|
|
7717
8079
|
if (ith == 0) {
|
|
7718
8080
|
// initialize matrix_row_counts
|
|
7719
8081
|
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
|
|
@@ -7731,9 +8093,14 @@ static void wsp_ggml_compute_forward_mul_mat_id(
|
|
|
7731
8093
|
}
|
|
7732
8094
|
}
|
|
7733
8095
|
|
|
8096
|
+
// reset current_chunk
|
|
8097
|
+
for (int cur_a = ith; cur_a < n_as; cur_a += nth) {
|
|
8098
|
+
atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
|
|
8099
|
+
*current_chunk_ctr = nth;
|
|
8100
|
+
}
|
|
8101
|
+
|
|
7734
8102
|
wsp_ggml_barrier(params->threadpool);
|
|
7735
8103
|
|
|
7736
|
-
// compute each matrix multiplication in sequence
|
|
7737
8104
|
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
|
|
7738
8105
|
const int64_t cne1 = matrix_row_counts[cur_a];
|
|
7739
8106
|
|
|
@@ -7741,112 +8108,64 @@ static void wsp_ggml_compute_forward_mul_mat_id(
|
|
|
7741
8108
|
continue;
|
|
7742
8109
|
}
|
|
7743
8110
|
|
|
7744
|
-
const char * src0_cur = (const char *) src0->data + cur_a*nb02;
|
|
7745
|
-
|
|
7746
|
-
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
|
8111
|
+
const char * src0_cur = (const char *) src0->data + cur_a * nb02;
|
|
8112
|
+
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
|
7747
8113
|
const size_t row_size = wsp_ggml_row_size(vec_dot_type, ne10);
|
|
7748
8114
|
|
|
7749
|
-
const int64_t nr0 = ne01;
|
|
7750
|
-
const int64_t nr1 = cne1;
|
|
7751
|
-
|
|
7752
|
-
if (((wsp_ggml_n_dims(src0) - 1) == 2) && gemv) {
|
|
7753
|
-
int64_t src0_cur_start = (ith * ne01) / nth;
|
|
7754
|
-
int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
|
|
7755
|
-
src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start;
|
|
7756
|
-
src0_cur_end = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end;
|
|
7757
|
-
if (src0_cur_start >= src0_cur_end) return;
|
|
7758
|
-
|
|
7759
|
-
for (int ir1 = 0; ir1 < nr1; ir1++) {
|
|
7760
|
-
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
|
|
7761
|
-
const int id = row_mapping.i1; // selected expert index
|
|
7762
|
-
|
|
7763
|
-
const int64_t i11 = id % ne11;
|
|
7764
|
-
const int64_t i12 = row_mapping.i2; // row index in src1
|
|
7765
|
-
|
|
7766
|
-
const int64_t i1 = id; // selected expert index
|
|
7767
|
-
const int64_t i2 = i12; // row
|
|
7768
|
-
|
|
7769
|
-
const char * src1_col = (const char *) wdata +
|
|
7770
|
-
(src1_cont || src1->type != vec_dot_type
|
|
7771
|
-
? (i11 + i12 * ne11) * row_size
|
|
7772
|
-
: (i11 * nb11 + i12 * nb12));
|
|
8115
|
+
const int64_t nr0 = ne01;
|
|
8116
|
+
const int64_t nr1 = cne1;
|
|
7773
8117
|
|
|
7774
|
-
|
|
7775
|
-
|
|
7776
|
-
|
|
7777
|
-
continue;
|
|
8118
|
+
int chunk_size = 16;
|
|
8119
|
+
if (nr0 == 1 || nr1 == 1) {
|
|
8120
|
+
chunk_size = 64;
|
|
7778
8121
|
}
|
|
7779
8122
|
|
|
7780
|
-
|
|
7781
|
-
|
|
7782
|
-
const
|
|
7783
|
-
|
|
7784
|
-
|
|
7785
|
-
const
|
|
7786
|
-
|
|
7787
|
-
|
|
7788
|
-
const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
|
|
7789
|
-
const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
|
|
7790
|
-
|
|
7791
|
-
const int64_t ir010 = dr0*ith0;
|
|
7792
|
-
const int64_t ir011 = MIN(ir010 + dr0, nr0);
|
|
7793
|
-
|
|
7794
|
-
const int64_t ir110 = dr1*ith1;
|
|
7795
|
-
const int64_t ir111 = MIN(ir110 + dr1, nr1);
|
|
7796
|
-
|
|
7797
|
-
// threads with no work simply yield (not sure if it helps)
|
|
7798
|
-
//if (ir010 >= ir011 || ir110 >= ir111) {
|
|
7799
|
-
// sched_yield();
|
|
7800
|
-
// continue;
|
|
7801
|
-
//}
|
|
7802
|
-
|
|
7803
|
-
// block-tiling attempt
|
|
7804
|
-
const int64_t blck_0 = 16;
|
|
7805
|
-
const int64_t blck_1 = 16;
|
|
8123
|
+
#if defined(__aarch64__)
|
|
8124
|
+
// disable for ARM
|
|
8125
|
+
const bool disable_chunking = true;
|
|
8126
|
+
#else
|
|
8127
|
+
// disable for NUMA
|
|
8128
|
+
const bool disable_chunking = wsp_ggml_is_numa();
|
|
8129
|
+
#endif // defined(__aarch64__)
|
|
7806
8130
|
|
|
7807
|
-
|
|
7808
|
-
|
|
8131
|
+
int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
|
|
8132
|
+
int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
|
|
7809
8133
|
|
|
7810
|
-
|
|
7811
|
-
|
|
7812
|
-
|
|
7813
|
-
|
|
8134
|
+
if (nchunk0 * nchunk1 < nth * 4 || disable_chunking) {
|
|
8135
|
+
nchunk0 = nr0 > nr1 ? nth : 1;
|
|
8136
|
+
nchunk1 = nr0 > nr1 ? 1 : nth;
|
|
8137
|
+
}
|
|
7814
8138
|
|
|
7815
|
-
|
|
7816
|
-
|
|
8139
|
+
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
|
|
8140
|
+
const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
|
|
7817
8141
|
|
|
7818
|
-
|
|
7819
|
-
const int64_t i12 = row_mapping.i2; // row index in src1
|
|
8142
|
+
int current_chunk = ith;
|
|
7820
8143
|
|
|
7821
|
-
|
|
7822
|
-
const int64_t i2 = i12; // row
|
|
8144
|
+
atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
|
|
7823
8145
|
|
|
7824
|
-
|
|
7825
|
-
|
|
7826
|
-
|
|
7827
|
-
// TODO: this is a bit of a hack, we should probably have a better way to handle this
|
|
7828
|
-
const char * src1_col = (const char *) wdata +
|
|
7829
|
-
(src1_cont || src1->type != vec_dot_type
|
|
7830
|
-
? (i11 + i12*ne11)*row_size
|
|
7831
|
-
: (i11*nb11 + i12*nb12));
|
|
8146
|
+
while (current_chunk < nchunk0 * nchunk1) {
|
|
8147
|
+
const int64_t ith0 = current_chunk % nchunk0;
|
|
8148
|
+
const int64_t ith1 = current_chunk / nchunk0;
|
|
7832
8149
|
|
|
7833
|
-
|
|
8150
|
+
const int64_t ir0_start = dr0 * ith0;
|
|
8151
|
+
const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
|
|
7834
8152
|
|
|
7835
|
-
|
|
7836
|
-
|
|
7837
|
-
//}
|
|
8153
|
+
const int64_t ir1_start = dr1 * ith1;
|
|
8154
|
+
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
|
|
7838
8155
|
|
|
7839
|
-
|
|
7840
|
-
|
|
7841
|
-
|
|
8156
|
+
wsp_ggml_compute_forward_mul_mat_id_one_chunk(
|
|
8157
|
+
dst, src0, src1, ids, cur_a,
|
|
8158
|
+
ir0_start, ir0_end, ir1_start, ir1_end,
|
|
8159
|
+
src0_cur, matrix_rows, row_size, src1_cont, wdata
|
|
8160
|
+
);
|
|
7842
8161
|
|
|
7843
|
-
|
|
7844
|
-
|
|
8162
|
+
if (nth >= nchunk0 * nchunk1) {
|
|
8163
|
+
break;
|
|
7845
8164
|
}
|
|
8165
|
+
|
|
8166
|
+
current_chunk = atomic_fetch_add_explicit(current_chunk_ctr, 1, memory_order_relaxed);
|
|
7846
8167
|
}
|
|
7847
8168
|
}
|
|
7848
|
-
|
|
7849
|
-
#undef MMID_MATRIX_ROW
|
|
7850
8169
|
}
|
|
7851
8170
|
|
|
7852
8171
|
// wsp_ggml_compute_forward_out_prod
|
|
@@ -7867,12 +8186,13 @@ static void wsp_ggml_compute_forward_out_prod_f32(
|
|
|
7867
8186
|
const int ith = params->ith;
|
|
7868
8187
|
const int nth = params->nth;
|
|
7869
8188
|
|
|
7870
|
-
WSP_GGML_ASSERT(ne0
|
|
7871
|
-
WSP_GGML_ASSERT(ne1
|
|
7872
|
-
WSP_GGML_ASSERT(ne2
|
|
7873
|
-
WSP_GGML_ASSERT(
|
|
7874
|
-
|
|
7875
|
-
WSP_GGML_ASSERT(
|
|
8189
|
+
WSP_GGML_ASSERT(ne0 == ne00);
|
|
8190
|
+
WSP_GGML_ASSERT(ne1 == ne10);
|
|
8191
|
+
WSP_GGML_ASSERT(ne2 == ne12);
|
|
8192
|
+
WSP_GGML_ASSERT(ne3 == ne13);
|
|
8193
|
+
|
|
8194
|
+
WSP_GGML_ASSERT(ne2 % ne02 == 0);
|
|
8195
|
+
WSP_GGML_ASSERT(ne3 % ne03 == 0);
|
|
7876
8196
|
|
|
7877
8197
|
// we don't support permuted src0 or src1
|
|
7878
8198
|
WSP_GGML_ASSERT(nb00 == sizeof(float));
|
|
@@ -7914,6 +8234,10 @@ static void wsp_ggml_compute_forward_out_prod_f32(
|
|
|
7914
8234
|
const int64_t blck_0 = MAX(WSP_GGML_VEC_MAD_UNROLL, 32);
|
|
7915
8235
|
const int64_t blck_1 = 16;
|
|
7916
8236
|
|
|
8237
|
+
// dps == dst per src0, used for group query attention
|
|
8238
|
+
const int64_t dps2 = ne2 / ne02;
|
|
8239
|
+
const int64_t dps3 = ne3 / ne03;
|
|
8240
|
+
|
|
7917
8241
|
for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
|
|
7918
8242
|
const int64_t bir1 = MIN(bir + blck_1, ir1);
|
|
7919
8243
|
for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
|
|
@@ -7924,8 +8248,8 @@ static void wsp_ggml_compute_forward_out_prod_f32(
|
|
|
7924
8248
|
const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
|
|
7925
8249
|
const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
|
|
7926
8250
|
|
|
7927
|
-
const int64_t i02 = i2;
|
|
7928
|
-
const int64_t i03 = i3;
|
|
8251
|
+
const int64_t i02 = i2 / dps2;
|
|
8252
|
+
const int64_t i03 = i3 / dps3;
|
|
7929
8253
|
|
|
7930
8254
|
//const int64_t i10 = i1;
|
|
7931
8255
|
const int64_t i12 = i2;
|
|
@@ -7938,7 +8262,7 @@ static void wsp_ggml_compute_forward_out_prod_f32(
|
|
|
7938
8262
|
|
|
7939
8263
|
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
|
7940
8264
|
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
|
7941
|
-
float * d = (float *) ((char *) dst->data + ( i1*nb1
|
|
8265
|
+
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
|
7942
8266
|
|
|
7943
8267
|
wsp_ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
|
|
7944
8268
|
}
|
|
@@ -7947,7 +8271,7 @@ static void wsp_ggml_compute_forward_out_prod_f32(
|
|
|
7947
8271
|
|
|
7948
8272
|
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
|
7949
8273
|
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
|
7950
|
-
float * d = (float *) ((char *) dst->data + ( i1*nb1
|
|
8274
|
+
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
|
7951
8275
|
|
|
7952
8276
|
wsp_ggml_vec_mad_f32(ne0, d, s0, *s1);
|
|
7953
8277
|
}
|
|
@@ -8084,9 +8408,6 @@ static void wsp_ggml_compute_forward_out_prod(
|
|
|
8084
8408
|
case WSP_GGML_TYPE_IQ4_XS:
|
|
8085
8409
|
case WSP_GGML_TYPE_IQ3_S:
|
|
8086
8410
|
case WSP_GGML_TYPE_IQ2_S:
|
|
8087
|
-
case WSP_GGML_TYPE_Q4_0_4_4:
|
|
8088
|
-
case WSP_GGML_TYPE_Q4_0_4_8:
|
|
8089
|
-
case WSP_GGML_TYPE_Q4_0_8_8:
|
|
8090
8411
|
{
|
|
8091
8412
|
wsp_ggml_compute_forward_out_prod_q_f32(params, dst);
|
|
8092
8413
|
} break;
|
|
@@ -8239,6 +8560,77 @@ static void wsp_ggml_compute_forward_set_f32(
|
|
|
8239
8560
|
}
|
|
8240
8561
|
}
|
|
8241
8562
|
|
|
8563
|
+
static void wsp_ggml_compute_forward_set_i32(
|
|
8564
|
+
const struct wsp_ggml_compute_params * params,
|
|
8565
|
+
struct wsp_ggml_tensor * dst) {
|
|
8566
|
+
|
|
8567
|
+
const struct wsp_ggml_tensor * src0 = dst->src[0];
|
|
8568
|
+
const struct wsp_ggml_tensor * src1 = dst->src[1];
|
|
8569
|
+
|
|
8570
|
+
WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst));
|
|
8571
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst) && wsp_ggml_is_contiguous(src0));
|
|
8572
|
+
|
|
8573
|
+
// view src0 and dst with these strides and data offset inbytes during set
|
|
8574
|
+
// nb0 is implicitly element_size because src0 and dst are contiguous
|
|
8575
|
+
size_t nb1 = ((int32_t *) dst->op_params)[0];
|
|
8576
|
+
size_t nb2 = ((int32_t *) dst->op_params)[1];
|
|
8577
|
+
size_t nb3 = ((int32_t *) dst->op_params)[2];
|
|
8578
|
+
size_t offset = ((int32_t *) dst->op_params)[3];
|
|
8579
|
+
bool inplace = (bool) ((int32_t *) dst->op_params)[4];
|
|
8580
|
+
|
|
8581
|
+
if (!inplace) {
|
|
8582
|
+
if (params->ith == 0) {
|
|
8583
|
+
// memcpy needs to be synchronized across threads to avoid race conditions.
|
|
8584
|
+
// => do it in INIT phase
|
|
8585
|
+
memcpy(
|
|
8586
|
+
((char *) dst->data),
|
|
8587
|
+
((char *) src0->data),
|
|
8588
|
+
wsp_ggml_nbytes(dst));
|
|
8589
|
+
}
|
|
8590
|
+
wsp_ggml_barrier(params->threadpool);
|
|
8591
|
+
}
|
|
8592
|
+
|
|
8593
|
+
const int ith = params->ith;
|
|
8594
|
+
const int nth = params->nth;
|
|
8595
|
+
|
|
8596
|
+
const int nr = wsp_ggml_nrows(src1);
|
|
8597
|
+
const int nc = src1->ne[0];
|
|
8598
|
+
|
|
8599
|
+
WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
|
|
8600
|
+
WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
|
|
8601
|
+
|
|
8602
|
+
// src0 and dst as viewed during set
|
|
8603
|
+
const size_t nb0 = wsp_ggml_element_size(src0);
|
|
8604
|
+
|
|
8605
|
+
const int im0 = (ne10 == 0 ? 0 : ne10-1);
|
|
8606
|
+
const int im1 = (ne11 == 0 ? 0 : ne11-1);
|
|
8607
|
+
const int im2 = (ne12 == 0 ? 0 : ne12-1);
|
|
8608
|
+
const int im3 = (ne13 == 0 ? 0 : ne13-1);
|
|
8609
|
+
|
|
8610
|
+
WSP_GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= wsp_ggml_nbytes(dst));
|
|
8611
|
+
|
|
8612
|
+
WSP_GGML_ASSERT(nb10 == sizeof(int32_t));
|
|
8613
|
+
|
|
8614
|
+
// rows per thread
|
|
8615
|
+
const int dr = (nr + nth - 1)/nth;
|
|
8616
|
+
|
|
8617
|
+
// row range for this thread
|
|
8618
|
+
const int ir0 = dr*ith;
|
|
8619
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
8620
|
+
|
|
8621
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
8622
|
+
// src0 and dst are viewed with shape of src1 and offset
|
|
8623
|
+
// => same indices
|
|
8624
|
+
const int i3 = ir/(ne12*ne11);
|
|
8625
|
+
const int i2 = (ir - i3*ne12*ne11)/ne11;
|
|
8626
|
+
const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
|
|
8627
|
+
|
|
8628
|
+
wsp_ggml_vec_cpy_i32(nc,
|
|
8629
|
+
(int32_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
|
|
8630
|
+
(int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
|
|
8631
|
+
}
|
|
8632
|
+
}
|
|
8633
|
+
|
|
8242
8634
|
static void wsp_ggml_compute_forward_set(
|
|
8243
8635
|
const struct wsp_ggml_compute_params * params,
|
|
8244
8636
|
struct wsp_ggml_tensor * dst) {
|
|
@@ -8250,6 +8642,10 @@ static void wsp_ggml_compute_forward_set(
|
|
|
8250
8642
|
{
|
|
8251
8643
|
wsp_ggml_compute_forward_set_f32(params, dst);
|
|
8252
8644
|
} break;
|
|
8645
|
+
case WSP_GGML_TYPE_I32:
|
|
8646
|
+
{
|
|
8647
|
+
wsp_ggml_compute_forward_set_i32(params, dst);
|
|
8648
|
+
} break;
|
|
8253
8649
|
case WSP_GGML_TYPE_F16:
|
|
8254
8650
|
case WSP_GGML_TYPE_BF16:
|
|
8255
8651
|
case WSP_GGML_TYPE_Q4_0:
|
|
@@ -8274,9 +8670,6 @@ static void wsp_ggml_compute_forward_set(
|
|
|
8274
8670
|
case WSP_GGML_TYPE_IQ4_XS:
|
|
8275
8671
|
case WSP_GGML_TYPE_IQ3_S:
|
|
8276
8672
|
case WSP_GGML_TYPE_IQ2_S:
|
|
8277
|
-
case WSP_GGML_TYPE_Q4_0_4_4:
|
|
8278
|
-
case WSP_GGML_TYPE_Q4_0_4_8:
|
|
8279
|
-
case WSP_GGML_TYPE_Q4_0_8_8:
|
|
8280
8673
|
default:
|
|
8281
8674
|
{
|
|
8282
8675
|
WSP_GGML_ABORT("fatal error");
|
|
@@ -8538,9 +8931,6 @@ static void wsp_ggml_compute_forward_get_rows(
|
|
|
8538
8931
|
case WSP_GGML_TYPE_IQ4_XS:
|
|
8539
8932
|
case WSP_GGML_TYPE_IQ3_S:
|
|
8540
8933
|
case WSP_GGML_TYPE_IQ2_S:
|
|
8541
|
-
case WSP_GGML_TYPE_Q4_0_4_4:
|
|
8542
|
-
case WSP_GGML_TYPE_Q4_0_4_8:
|
|
8543
|
-
case WSP_GGML_TYPE_Q4_0_8_8:
|
|
8544
8934
|
{
|
|
8545
8935
|
wsp_ggml_compute_forward_get_rows_q(params, dst);
|
|
8546
8936
|
} break;
|
|
@@ -8957,9 +9347,9 @@ static void wsp_ggml_compute_forward_soft_max(
|
|
|
8957
9347
|
}
|
|
8958
9348
|
|
|
8959
9349
|
|
|
8960
|
-
//
|
|
9350
|
+
// wsp_ggml_compute_forward_soft_max_ext_back
|
|
8961
9351
|
|
|
8962
|
-
static void
|
|
9352
|
+
static void wsp_ggml_compute_forward_soft_max_ext_back_f32(
|
|
8963
9353
|
const struct wsp_ggml_compute_params * params,
|
|
8964
9354
|
struct wsp_ggml_tensor * dst) {
|
|
8965
9355
|
|
|
@@ -8972,6 +9362,14 @@ static void wsp_ggml_compute_forward_soft_max_back_f32(
|
|
|
8972
9362
|
WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst));
|
|
8973
9363
|
WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src1, dst));
|
|
8974
9364
|
|
|
9365
|
+
float scale = 1.0f;
|
|
9366
|
+
float max_bias = 0.0f;
|
|
9367
|
+
|
|
9368
|
+
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
|
|
9369
|
+
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
|
|
9370
|
+
|
|
9371
|
+
WSP_GGML_ASSERT(max_bias == 0.0f);
|
|
9372
|
+
|
|
8975
9373
|
// TODO: handle transposed/permuted matrices
|
|
8976
9374
|
|
|
8977
9375
|
const int ith = params->ith;
|
|
@@ -9020,10 +9418,11 @@ static void wsp_ggml_compute_forward_soft_max_back_f32(
|
|
|
9020
9418
|
|
|
9021
9419
|
// linear runtime, no additional memory
|
|
9022
9420
|
float dot_y_dy = 0;
|
|
9023
|
-
wsp_ggml_vec_dot_f32
|
|
9024
|
-
wsp_ggml_vec_cpy_f32
|
|
9025
|
-
wsp_ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
|
|
9026
|
-
wsp_ggml_vec_mul_f32
|
|
9421
|
+
wsp_ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
|
|
9422
|
+
wsp_ggml_vec_cpy_f32 (nc, dx, dy);
|
|
9423
|
+
wsp_ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
|
|
9424
|
+
wsp_ggml_vec_mul_f32 (nc, dx, dx, y);
|
|
9425
|
+
wsp_ggml_vec_scale_f32(nc, dx, scale);
|
|
9027
9426
|
|
|
9028
9427
|
#ifndef NDEBUG
|
|
9029
9428
|
for (int i = 0; i < nc; ++i) {
|
|
@@ -9034,7 +9433,7 @@ static void wsp_ggml_compute_forward_soft_max_back_f32(
|
|
|
9034
9433
|
}
|
|
9035
9434
|
}
|
|
9036
9435
|
|
|
9037
|
-
static void
|
|
9436
|
+
static void wsp_ggml_compute_forward_soft_max_ext_back(
|
|
9038
9437
|
const struct wsp_ggml_compute_params * params,
|
|
9039
9438
|
struct wsp_ggml_tensor * dst) {
|
|
9040
9439
|
|
|
@@ -9043,7 +9442,7 @@ static void wsp_ggml_compute_forward_soft_max_back(
|
|
|
9043
9442
|
switch (src0->type) {
|
|
9044
9443
|
case WSP_GGML_TYPE_F32:
|
|
9045
9444
|
{
|
|
9046
|
-
|
|
9445
|
+
wsp_ggml_compute_forward_soft_max_ext_back_f32(params, dst);
|
|
9047
9446
|
} break;
|
|
9048
9447
|
default:
|
|
9049
9448
|
{
|
|
@@ -9060,10 +9459,6 @@ static void wsp_ggml_compute_forward_clamp_f32(
|
|
|
9060
9459
|
|
|
9061
9460
|
const struct wsp_ggml_tensor * src0 = dst->src[0];
|
|
9062
9461
|
|
|
9063
|
-
if (params->ith != 0) {
|
|
9064
|
-
return;
|
|
9065
|
-
}
|
|
9066
|
-
|
|
9067
9462
|
float min;
|
|
9068
9463
|
float max;
|
|
9069
9464
|
memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
|
|
@@ -9130,9 +9525,6 @@ static void wsp_ggml_compute_forward_clamp(
|
|
|
9130
9525
|
case WSP_GGML_TYPE_IQ3_S:
|
|
9131
9526
|
case WSP_GGML_TYPE_IQ2_S:
|
|
9132
9527
|
case WSP_GGML_TYPE_Q8_K:
|
|
9133
|
-
case WSP_GGML_TYPE_Q4_0_4_4:
|
|
9134
|
-
case WSP_GGML_TYPE_Q4_0_4_8:
|
|
9135
|
-
case WSP_GGML_TYPE_Q4_0_8_8:
|
|
9136
9528
|
case WSP_GGML_TYPE_I8:
|
|
9137
9529
|
case WSP_GGML_TYPE_I16:
|
|
9138
9530
|
case WSP_GGML_TYPE_I32:
|
|
@@ -9187,6 +9579,64 @@ static void wsp_ggml_rope_cache_init(
|
|
|
9187
9579
|
}
|
|
9188
9580
|
}
|
|
9189
9581
|
|
|
9582
|
+
static void wsp_ggml_mrope_cache_init(
|
|
9583
|
+
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
|
|
9584
|
+
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
|
|
9585
|
+
float * cache, float sin_sign, float theta_scale) {
|
|
9586
|
+
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
|
9587
|
+
float theta_t = theta_base_t;
|
|
9588
|
+
float theta_h = theta_base_h;
|
|
9589
|
+
float theta_w = theta_base_w;
|
|
9590
|
+
float theta_e = theta_base_e; // extra position id for vision encoder
|
|
9591
|
+
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
|
|
9592
|
+
int sec_w = sections[1] + sections[0];
|
|
9593
|
+
int sec_e = sections[2] + sec_w;
|
|
9594
|
+
WSP_GGML_ASSERT(sect_dims <= ne0);
|
|
9595
|
+
|
|
9596
|
+
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
|
9597
|
+
const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
|
|
9598
|
+
|
|
9599
|
+
int sector = (i0 / 2) % sect_dims;
|
|
9600
|
+
if (indep_sects) {
|
|
9601
|
+
// compute theta independently for each dim sections
|
|
9602
|
+
// (i.e. reset corresponding theta when `i0` go from one section to another)
|
|
9603
|
+
if (sector == 0) {
|
|
9604
|
+
theta_t = theta_base_t;
|
|
9605
|
+
}
|
|
9606
|
+
else if (sector == sections[0]) {
|
|
9607
|
+
theta_h = theta_base_h;;
|
|
9608
|
+
}
|
|
9609
|
+
else if (sector == sec_w) {
|
|
9610
|
+
theta_w = theta_base_w;
|
|
9611
|
+
}
|
|
9612
|
+
else if (sector == sec_e) {
|
|
9613
|
+
theta_e = theta_base_e;
|
|
9614
|
+
}
|
|
9615
|
+
}
|
|
9616
|
+
|
|
9617
|
+
float theta = theta_t;
|
|
9618
|
+
if (sector >= sections[0] && sector < sec_w) {
|
|
9619
|
+
theta = theta_h;
|
|
9620
|
+
}
|
|
9621
|
+
else if (sector >= sec_w && sector < sec_w + sections[2]) {
|
|
9622
|
+
theta = theta_w;
|
|
9623
|
+
}
|
|
9624
|
+
else if (sector >= sec_w + sections[2]) {
|
|
9625
|
+
theta = theta_e;
|
|
9626
|
+
}
|
|
9627
|
+
|
|
9628
|
+
rope_yarn(
|
|
9629
|
+
theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
|
|
9630
|
+
);
|
|
9631
|
+
cache[i0 + 1] *= sin_sign;
|
|
9632
|
+
|
|
9633
|
+
theta_t *= theta_scale;
|
|
9634
|
+
theta_w *= theta_scale;
|
|
9635
|
+
theta_h *= theta_scale;
|
|
9636
|
+
theta_e *= theta_scale;
|
|
9637
|
+
}
|
|
9638
|
+
}
|
|
9639
|
+
|
|
9190
9640
|
static void wsp_ggml_compute_forward_rope_f32(
|
|
9191
9641
|
const struct wsp_ggml_compute_params * params,
|
|
9192
9642
|
struct wsp_ggml_tensor * dst,
|
|
@@ -9197,6 +9647,7 @@ static void wsp_ggml_compute_forward_rope_f32(
|
|
|
9197
9647
|
const struct wsp_ggml_tensor * src2 = dst->src[2];
|
|
9198
9648
|
|
|
9199
9649
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
|
9650
|
+
int sections[4];
|
|
9200
9651
|
|
|
9201
9652
|
//const int n_past = ((int32_t *) dst->op_params)[0];
|
|
9202
9653
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
|
@@ -9210,6 +9661,7 @@ static void wsp_ggml_compute_forward_rope_f32(
|
|
|
9210
9661
|
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
|
9211
9662
|
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
|
9212
9663
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
|
9664
|
+
memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
|
9213
9665
|
|
|
9214
9666
|
WSP_GGML_TENSOR_UNARY_OP_LOCALS
|
|
9215
9667
|
|
|
@@ -9242,6 +9694,16 @@ static void wsp_ggml_compute_forward_rope_f32(
|
|
|
9242
9694
|
wsp_ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
|
9243
9695
|
|
|
9244
9696
|
const bool is_neox = mode & WSP_GGML_ROPE_TYPE_NEOX;
|
|
9697
|
+
const bool is_mrope = mode & WSP_GGML_ROPE_TYPE_MROPE; // wsp_ggml_rope_multi, multimodal rotary position embedding
|
|
9698
|
+
const bool is_vision = mode == WSP_GGML_ROPE_TYPE_VISION;
|
|
9699
|
+
|
|
9700
|
+
if (is_mrope) {
|
|
9701
|
+
WSP_GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
|
9702
|
+
}
|
|
9703
|
+
|
|
9704
|
+
if (is_vision) {
|
|
9705
|
+
WSP_GGML_ASSERT(n_dims == ne0/2);
|
|
9706
|
+
}
|
|
9245
9707
|
|
|
9246
9708
|
const float * freq_factors = NULL;
|
|
9247
9709
|
if (src2 != NULL) {
|
|
@@ -9257,18 +9719,63 @@ static void wsp_ggml_compute_forward_rope_f32(
|
|
|
9257
9719
|
|
|
9258
9720
|
const int32_t * pos = (const int32_t *) src1->data;
|
|
9259
9721
|
|
|
9260
|
-
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
9261
|
-
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
|
9262
|
-
const int64_t p = pos[i2];
|
|
9722
|
+
for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
|
|
9723
|
+
for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
|
9263
9724
|
|
|
9264
9725
|
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
|
9265
|
-
|
|
9726
|
+
if (!is_mrope) {
|
|
9727
|
+
const int64_t p = pos[i2];
|
|
9728
|
+
wsp_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
9729
|
+
}
|
|
9730
|
+
else {
|
|
9731
|
+
const int64_t p_t = pos[i2];
|
|
9732
|
+
const int64_t p_h = pos[i2 + ne2];
|
|
9733
|
+
const int64_t p_w = pos[i2 + ne2 * 2];
|
|
9734
|
+
const int64_t p_e = pos[i2 + ne2 * 3];
|
|
9735
|
+
wsp_ggml_mrope_cache_init(
|
|
9736
|
+
p_t, p_h, p_w, p_e, sections, is_vision,
|
|
9737
|
+
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
9738
|
+
}
|
|
9266
9739
|
|
|
9267
|
-
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
9740
|
+
for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
|
|
9268
9741
|
if (ir++ < ir0) continue;
|
|
9269
9742
|
if (ir > ir1) break;
|
|
9270
9743
|
|
|
9271
|
-
if (
|
|
9744
|
+
if (is_neox || is_mrope) {
|
|
9745
|
+
if (is_vision){
|
|
9746
|
+
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
9747
|
+
const int64_t ic = i0/2;
|
|
9748
|
+
|
|
9749
|
+
const float cos_theta = cache[i0 + 0];
|
|
9750
|
+
const float sin_theta = cache[i0 + 1];
|
|
9751
|
+
|
|
9752
|
+
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
9753
|
+
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
9754
|
+
|
|
9755
|
+
const float x0 = src[0];
|
|
9756
|
+
const float x1 = src[n_dims];
|
|
9757
|
+
|
|
9758
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
9759
|
+
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
|
|
9760
|
+
}
|
|
9761
|
+
} else {
|
|
9762
|
+
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
9763
|
+
const int64_t ic = i0/2;
|
|
9764
|
+
|
|
9765
|
+
const float cos_theta = cache[i0 + 0];
|
|
9766
|
+
const float sin_theta = cache[i0 + 1];
|
|
9767
|
+
|
|
9768
|
+
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
9769
|
+
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
9770
|
+
|
|
9771
|
+
const float x0 = src[0];
|
|
9772
|
+
const float x1 = src[n_dims/2];
|
|
9773
|
+
|
|
9774
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
9775
|
+
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
|
9776
|
+
}
|
|
9777
|
+
}
|
|
9778
|
+
} else {
|
|
9272
9779
|
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
9273
9780
|
const float cos_theta = cache[i0 + 0];
|
|
9274
9781
|
const float sin_theta = cache[i0 + 1];
|
|
@@ -9282,8 +9789,10 @@ static void wsp_ggml_compute_forward_rope_f32(
|
|
|
9282
9789
|
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
9283
9790
|
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
|
9284
9791
|
}
|
|
9285
|
-
}
|
|
9286
|
-
|
|
9792
|
+
}
|
|
9793
|
+
|
|
9794
|
+
if (is_vision) {
|
|
9795
|
+
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
|
9287
9796
|
const int64_t ic = i0/2;
|
|
9288
9797
|
|
|
9289
9798
|
const float cos_theta = cache[i0 + 0];
|
|
@@ -9293,19 +9802,20 @@ static void wsp_ggml_compute_forward_rope_f32(
|
|
|
9293
9802
|
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
9294
9803
|
|
|
9295
9804
|
const float x0 = src[0];
|
|
9296
|
-
const float x1 = src[n_dims
|
|
9805
|
+
const float x1 = src[n_dims];
|
|
9297
9806
|
|
|
9298
|
-
dst_data[0]
|
|
9299
|
-
dst_data[n_dims
|
|
9807
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
9808
|
+
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
|
|
9300
9809
|
}
|
|
9301
|
-
}
|
|
9302
|
-
|
|
9303
|
-
|
|
9304
|
-
|
|
9305
|
-
|
|
9810
|
+
} else {
|
|
9811
|
+
// fill the remain channels with data from src tensor
|
|
9812
|
+
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
|
9813
|
+
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
9814
|
+
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
9306
9815
|
|
|
9307
|
-
|
|
9308
|
-
|
|
9816
|
+
dst_data[0] = src[0];
|
|
9817
|
+
dst_data[1] = src[1];
|
|
9818
|
+
}
|
|
9309
9819
|
}
|
|
9310
9820
|
}
|
|
9311
9821
|
}
|
|
@@ -9323,6 +9833,7 @@ static void wsp_ggml_compute_forward_rope_f16(
|
|
|
9323
9833
|
const struct wsp_ggml_tensor * src2 = dst->src[2];
|
|
9324
9834
|
|
|
9325
9835
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
|
9836
|
+
int sections[4];
|
|
9326
9837
|
|
|
9327
9838
|
//const int n_past = ((int32_t *) dst->op_params)[0];
|
|
9328
9839
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
|
@@ -9335,6 +9846,8 @@ static void wsp_ggml_compute_forward_rope_f16(
|
|
|
9335
9846
|
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
|
9336
9847
|
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
|
9337
9848
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
|
9849
|
+
memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
|
9850
|
+
|
|
9338
9851
|
|
|
9339
9852
|
WSP_GGML_TENSOR_UNARY_OP_LOCALS
|
|
9340
9853
|
|
|
@@ -9367,6 +9880,16 @@ static void wsp_ggml_compute_forward_rope_f16(
|
|
|
9367
9880
|
wsp_ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
|
9368
9881
|
|
|
9369
9882
|
const bool is_neox = mode & WSP_GGML_ROPE_TYPE_NEOX;
|
|
9883
|
+
const bool is_mrope = mode & WSP_GGML_ROPE_TYPE_MROPE;
|
|
9884
|
+
const bool is_vision = mode == WSP_GGML_ROPE_TYPE_VISION;
|
|
9885
|
+
|
|
9886
|
+
if (is_mrope) {
|
|
9887
|
+
WSP_GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
|
9888
|
+
}
|
|
9889
|
+
|
|
9890
|
+
if (is_vision) {
|
|
9891
|
+
WSP_GGML_ASSERT(n_dims == ne0/2);
|
|
9892
|
+
}
|
|
9370
9893
|
|
|
9371
9894
|
const float * freq_factors = NULL;
|
|
9372
9895
|
if (src2 != NULL) {
|
|
@@ -9384,16 +9907,61 @@ static void wsp_ggml_compute_forward_rope_f16(
|
|
|
9384
9907
|
|
|
9385
9908
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
9386
9909
|
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
|
9387
|
-
const int64_t p = pos[i2];
|
|
9388
9910
|
|
|
9389
9911
|
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
|
9390
|
-
|
|
9912
|
+
if (!is_mrope) {
|
|
9913
|
+
const int64_t p = pos[i2];
|
|
9914
|
+
wsp_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
9915
|
+
}
|
|
9916
|
+
else {
|
|
9917
|
+
const int64_t p_t = pos[i2];
|
|
9918
|
+
const int64_t p_h = pos[i2 + ne2];
|
|
9919
|
+
const int64_t p_w = pos[i2 + ne2 * 2];
|
|
9920
|
+
const int64_t p_e = pos[i2 + ne2 * 3];
|
|
9921
|
+
wsp_ggml_mrope_cache_init(
|
|
9922
|
+
p_t, p_h, p_w, p_e, sections, is_vision,
|
|
9923
|
+
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
9924
|
+
}
|
|
9391
9925
|
|
|
9392
9926
|
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
9393
9927
|
if (ir++ < ir0) continue;
|
|
9394
9928
|
if (ir > ir1) break;
|
|
9395
9929
|
|
|
9396
|
-
if (
|
|
9930
|
+
if (is_neox || is_mrope) {
|
|
9931
|
+
if (is_vision) {
|
|
9932
|
+
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
9933
|
+
const int64_t ic = i0/2;
|
|
9934
|
+
|
|
9935
|
+
const float cos_theta = cache[i0 + 0];
|
|
9936
|
+
const float sin_theta = cache[i0 + 1];
|
|
9937
|
+
|
|
9938
|
+
const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
9939
|
+
wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
9940
|
+
|
|
9941
|
+
const float x0 = WSP_GGML_FP16_TO_FP32(src[0]);
|
|
9942
|
+
const float x1 = WSP_GGML_FP16_TO_FP32(src[n_dims]);
|
|
9943
|
+
|
|
9944
|
+
dst_data[0] = WSP_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
9945
|
+
dst_data[n_dims] = WSP_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
9946
|
+
}
|
|
9947
|
+
} else {
|
|
9948
|
+
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
9949
|
+
const int64_t ic = i0/2;
|
|
9950
|
+
|
|
9951
|
+
const float cos_theta = cache[i0 + 0];
|
|
9952
|
+
const float sin_theta = cache[i0 + 1];
|
|
9953
|
+
|
|
9954
|
+
const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
9955
|
+
wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
9956
|
+
|
|
9957
|
+
const float x0 = WSP_GGML_FP16_TO_FP32(src[0]);
|
|
9958
|
+
const float x1 = WSP_GGML_FP16_TO_FP32(src[n_dims/2]);
|
|
9959
|
+
|
|
9960
|
+
dst_data[0] = WSP_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
9961
|
+
dst_data[n_dims/2] = WSP_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
9962
|
+
}
|
|
9963
|
+
}
|
|
9964
|
+
} else {
|
|
9397
9965
|
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
9398
9966
|
const float cos_theta = cache[i0 + 0];
|
|
9399
9967
|
const float sin_theta = cache[i0 + 1];
|
|
@@ -9407,8 +9975,10 @@ static void wsp_ggml_compute_forward_rope_f16(
|
|
|
9407
9975
|
dst_data[0] = WSP_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
9408
9976
|
dst_data[1] = WSP_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
9409
9977
|
}
|
|
9410
|
-
}
|
|
9411
|
-
|
|
9978
|
+
}
|
|
9979
|
+
|
|
9980
|
+
if (is_vision) {
|
|
9981
|
+
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
|
9412
9982
|
const int64_t ic = i0/2;
|
|
9413
9983
|
|
|
9414
9984
|
const float cos_theta = cache[i0 + 0];
|
|
@@ -9418,19 +9988,19 @@ static void wsp_ggml_compute_forward_rope_f16(
|
|
|
9418
9988
|
wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
9419
9989
|
|
|
9420
9990
|
const float x0 = WSP_GGML_FP16_TO_FP32(src[0]);
|
|
9421
|
-
const float x1 = WSP_GGML_FP16_TO_FP32(src[n_dims
|
|
9991
|
+
const float x1 = WSP_GGML_FP16_TO_FP32(src[n_dims]);
|
|
9422
9992
|
|
|
9423
|
-
dst_data[0]
|
|
9424
|
-
dst_data[n_dims
|
|
9993
|
+
dst_data[0] = WSP_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
9994
|
+
dst_data[n_dims] = WSP_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
9425
9995
|
}
|
|
9426
|
-
}
|
|
9427
|
-
|
|
9428
|
-
|
|
9429
|
-
|
|
9430
|
-
wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
9996
|
+
} else {
|
|
9997
|
+
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
|
9998
|
+
const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
9999
|
+
wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
9431
10000
|
|
|
9432
|
-
|
|
9433
|
-
|
|
10001
|
+
dst_data[0] = src[0];
|
|
10002
|
+
dst_data[1] = src[1];
|
|
10003
|
+
}
|
|
9434
10004
|
}
|
|
9435
10005
|
}
|
|
9436
10006
|
}
|
|
@@ -9861,9 +10431,10 @@ static void wsp_ggml_compute_forward_im2col_back_f32(
|
|
|
9861
10431
|
const struct wsp_ggml_compute_params * params,
|
|
9862
10432
|
struct wsp_ggml_tensor * dst) {
|
|
9863
10433
|
|
|
9864
|
-
const struct wsp_ggml_tensor * src0 = dst->src[0];
|
|
9865
|
-
const struct wsp_ggml_tensor * src1 = dst->src[1];
|
|
10434
|
+
const struct wsp_ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
|
|
10435
|
+
const struct wsp_ggml_tensor * src1 = dst->src[1]; // convolution kernel
|
|
9866
10436
|
|
|
10437
|
+
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
|
|
9867
10438
|
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
|
|
9868
10439
|
WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
|
|
9869
10440
|
|
|
@@ -9885,11 +10456,11 @@ static void wsp_ggml_compute_forward_im2col_back_f32(
|
|
|
9885
10456
|
const int64_t IH = is_2D ? ne1 : 1;
|
|
9886
10457
|
const int64_t IW = ne0;
|
|
9887
10458
|
|
|
9888
|
-
const int64_t KH = is_2D ?
|
|
9889
|
-
const int64_t KW =
|
|
10459
|
+
const int64_t KH = is_2D ? ne11 : 1;
|
|
10460
|
+
const int64_t KW = ne10;
|
|
9890
10461
|
|
|
9891
|
-
const int64_t OH = is_2D ?
|
|
9892
|
-
const int64_t OW =
|
|
10462
|
+
const int64_t OH = is_2D ? ne02 : 1;
|
|
10463
|
+
const int64_t OW = ne01;
|
|
9893
10464
|
|
|
9894
10465
|
int ofs0 = is_2D ? nb3 : nb2;
|
|
9895
10466
|
int ofs1 = is_2D ? nb2 : nb1;
|
|
@@ -9935,9 +10506,9 @@ static void wsp_ggml_compute_forward_im2col_back_f32(
|
|
|
9935
10506
|
continue;
|
|
9936
10507
|
}
|
|
9937
10508
|
|
|
9938
|
-
const float * const
|
|
10509
|
+
const float * const grad_in = (const float *) src0->data
|
|
9939
10510
|
+ (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
|
|
9940
|
-
grad +=
|
|
10511
|
+
grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
|
|
9941
10512
|
}
|
|
9942
10513
|
}
|
|
9943
10514
|
float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
|
|
@@ -10429,6 +11000,40 @@ static void wsp_ggml_compute_forward_pad(
|
|
|
10429
11000
|
}
|
|
10430
11001
|
}
|
|
10431
11002
|
|
|
11003
|
+
// wsp_ggml_compute_forward_pad_reflect_1d
|
|
11004
|
+
|
|
11005
|
+
static void wsp_ggml_compute_forward_pad_reflect_1d(
|
|
11006
|
+
const struct wsp_ggml_compute_params * params,
|
|
11007
|
+
struct wsp_ggml_tensor * dst) {
|
|
11008
|
+
|
|
11009
|
+
const struct wsp_ggml_tensor * src0 = dst->src[0];
|
|
11010
|
+
|
|
11011
|
+
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
|
|
11012
|
+
WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
|
|
11013
|
+
|
|
11014
|
+
const int ith = params->ith;
|
|
11015
|
+
const int nth = params->nth;
|
|
11016
|
+
|
|
11017
|
+
const int32_t * opts = (const int32_t *) dst->op_params;
|
|
11018
|
+
const int p0 = opts[0];
|
|
11019
|
+
const int p1 = opts[1];
|
|
11020
|
+
|
|
11021
|
+
WSP_GGML_TENSOR_UNARY_OP_LOCALS
|
|
11022
|
+
|
|
11023
|
+
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
11024
|
+
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
|
11025
|
+
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
|
|
11026
|
+
float * left = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + p0*nb0);
|
|
11027
|
+
float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0);
|
|
11028
|
+
|
|
11029
|
+
wsp_ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
11030
|
+
|
|
11031
|
+
for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0]; }
|
|
11032
|
+
for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; }
|
|
11033
|
+
}
|
|
11034
|
+
}
|
|
11035
|
+
}
|
|
11036
|
+
}
|
|
10432
11037
|
|
|
10433
11038
|
// wsp_ggml_compute_forward_arange
|
|
10434
11039
|
|
|
@@ -11645,9 +12250,9 @@ static void wsp_ggml_compute_forward_add_rel_pos(
|
|
|
11645
12250
|
static void wsp_ggml_compute_forward_rwkv_wkv6_f32(
|
|
11646
12251
|
const struct wsp_ggml_compute_params * params,
|
|
11647
12252
|
struct wsp_ggml_tensor * dst) {
|
|
11648
|
-
const int64_t T = dst->src[1]->ne[
|
|
12253
|
+
const int64_t T = dst->src[1]->ne[2];
|
|
11649
12254
|
const int64_t C = dst->ne[0];
|
|
11650
|
-
const int64_t HEADS = dst->src[1]->ne[
|
|
12255
|
+
const int64_t HEADS = dst->src[1]->ne[1];
|
|
11651
12256
|
const int64_t n_seqs = dst->src[5]->ne[1];
|
|
11652
12257
|
const int64_t head_size = C / HEADS;
|
|
11653
12258
|
|
|
@@ -11842,6 +12447,197 @@ static void wsp_ggml_compute_forward_rwkv_wkv6(
|
|
|
11842
12447
|
}
|
|
11843
12448
|
}
|
|
11844
12449
|
|
|
12450
|
+
// wsp_ggml_compute_forward_gla
|
|
12451
|
+
|
|
12452
|
+
static void wsp_ggml_compute_forward_gla_f32(
|
|
12453
|
+
const struct wsp_ggml_compute_params * params,
|
|
12454
|
+
struct wsp_ggml_tensor * dst) {
|
|
12455
|
+
const int64_t T = dst->src[1]->ne[2];
|
|
12456
|
+
const int64_t C = dst->ne[0];
|
|
12457
|
+
const int64_t HEADS = dst->src[1]->ne[1];
|
|
12458
|
+
const int64_t n_seqs = dst->src[4]->ne[1];
|
|
12459
|
+
const int64_t head_size = C / HEADS;
|
|
12460
|
+
const float scale = wsp_ggml_get_op_params_f32(dst, 0);
|
|
12461
|
+
|
|
12462
|
+
float * dst_data = (float *) dst->data;
|
|
12463
|
+
float * state = ((float *) dst->data) + C * T;
|
|
12464
|
+
|
|
12465
|
+
const int ith = params->ith;
|
|
12466
|
+
const int nth = params->nth;
|
|
12467
|
+
|
|
12468
|
+
if (ith >= HEADS) {
|
|
12469
|
+
return;
|
|
12470
|
+
}
|
|
12471
|
+
|
|
12472
|
+
const int h_start = (HEADS * ith) / nth;
|
|
12473
|
+
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
|
|
12474
|
+
(HEADS * (ith + 1)) / nth : HEADS;
|
|
12475
|
+
|
|
12476
|
+
float * k = (float *) dst->src[0]->data;
|
|
12477
|
+
float * v = (float *) dst->src[1]->data;
|
|
12478
|
+
float * q = (float *) dst->src[2]->data;
|
|
12479
|
+
float * g = (float *) dst->src[3]->data;
|
|
12480
|
+
|
|
12481
|
+
size_t t_stride = HEADS * head_size; // Same to C
|
|
12482
|
+
|
|
12483
|
+
size_t h_stride = C / HEADS;
|
|
12484
|
+
WSP_GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
|
|
12485
|
+
size_t h_stride_2d = head_size * head_size;
|
|
12486
|
+
|
|
12487
|
+
if (ith == 0) {
|
|
12488
|
+
memset(dst_data, 0, T * C * sizeof(float));
|
|
12489
|
+
}
|
|
12490
|
+
wsp_ggml_barrier(params->threadpool);
|
|
12491
|
+
|
|
12492
|
+
|
|
12493
|
+
#if defined(__AVX__) && !defined(__AVX512F__)
|
|
12494
|
+
#define WSP_GGML_F32X WSP_GGML_F32x8
|
|
12495
|
+
#define WSP_GGML_F32X_SET1 WSP_GGML_F32x8_SET1
|
|
12496
|
+
#define WSP_GGML_F32X_LOAD WSP_GGML_F32x8_LOAD
|
|
12497
|
+
#define WSP_GGML_F32X_STORE WSP_GGML_F32x8_STORE
|
|
12498
|
+
#define WSP_GGML_F32X_MUL WSP_GGML_F32x8_MUL
|
|
12499
|
+
#define WSP_GGML_F32X_FMA WSP_GGML_F32x8_FMA
|
|
12500
|
+
#define GLA_VECTOR_SIZE 8
|
|
12501
|
+
#elif defined(__AVX512F__)
|
|
12502
|
+
#define WSP_GGML_F32X WSP_GGML_F32x16
|
|
12503
|
+
#define WSP_GGML_F32X_SET1 WSP_GGML_F32x16_SET1
|
|
12504
|
+
#define WSP_GGML_F32X_LOAD WSP_GGML_F32x16_LOAD
|
|
12505
|
+
#define WSP_GGML_F32X_STORE WSP_GGML_F32x16_STORE
|
|
12506
|
+
#define WSP_GGML_F32X_MUL WSP_GGML_F32x16_MUL
|
|
12507
|
+
#define WSP_GGML_F32X_FMA WSP_GGML_F32x16_FMA
|
|
12508
|
+
#define GLA_VECTOR_SIZE 16
|
|
12509
|
+
#elif defined(__ARM_NEON) && defined(__aarch64__)
|
|
12510
|
+
#define WSP_GGML_F32X WSP_GGML_F32x4
|
|
12511
|
+
#define WSP_GGML_F32X_SET1 WSP_GGML_F32x4_SET1
|
|
12512
|
+
#define WSP_GGML_F32X_LOAD WSP_GGML_F32x4_LOAD
|
|
12513
|
+
#define WSP_GGML_F32X_STORE WSP_GGML_F32x4_STORE
|
|
12514
|
+
#define WSP_GGML_F32X_MUL WSP_GGML_F32x4_MUL
|
|
12515
|
+
#define WSP_GGML_F32X_FMA WSP_GGML_F32x4_FMA
|
|
12516
|
+
#define GLA_VECTOR_SIZE 4
|
|
12517
|
+
#endif
|
|
12518
|
+
|
|
12519
|
+
#ifdef GLA_VECTOR_SIZE
|
|
12520
|
+
const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
|
|
12521
|
+
|
|
12522
|
+
for (int64_t t = 0; t < T; t++) {
|
|
12523
|
+
size_t t_offset = t * t_stride;
|
|
12524
|
+
size_t state_offset = head_size * C * (t / (T / n_seqs));
|
|
12525
|
+
float * state_cur = state + state_offset;
|
|
12526
|
+
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
|
|
12527
|
+
|
|
12528
|
+
for (int64_t h = h_start; h < h_end; h++) {
|
|
12529
|
+
size_t h_offset = h * h_stride;
|
|
12530
|
+
size_t t_h_offset = t_offset + h_offset;
|
|
12531
|
+
size_t h_2d_offset = h * h_stride_2d;
|
|
12532
|
+
|
|
12533
|
+
for (int64_t i = 0; i < head_size; i++) {
|
|
12534
|
+
size_t t_h_i_offset = t_h_offset + i;
|
|
12535
|
+
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
|
|
12536
|
+
|
|
12537
|
+
float k_val = k[t_h_i_offset];
|
|
12538
|
+
float q_val = q[t_h_i_offset] * scale;
|
|
12539
|
+
float g_val = g[t_h_i_offset];
|
|
12540
|
+
|
|
12541
|
+
// Broadcast scalar values to vectors
|
|
12542
|
+
WSP_GGML_F32X k_vec = WSP_GGML_F32X_SET1(k_val);
|
|
12543
|
+
WSP_GGML_F32X q_vec = WSP_GGML_F32X_SET1(q_val);
|
|
12544
|
+
WSP_GGML_F32X g_vec = WSP_GGML_F32X_SET1(g_val);
|
|
12545
|
+
|
|
12546
|
+
for (int64_t j = 0; j < vec_count; j++) {
|
|
12547
|
+
size_t base_j = j * GLA_VECTOR_SIZE;
|
|
12548
|
+
size_t t_h_j_offset = t_h_offset + base_j;
|
|
12549
|
+
size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
|
|
12550
|
+
|
|
12551
|
+
// Load x elements at once
|
|
12552
|
+
WSP_GGML_F32X v_vec = WSP_GGML_F32X_LOAD(&v[t_h_j_offset]);
|
|
12553
|
+
WSP_GGML_F32X prev_state_vec = WSP_GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
|
|
12554
|
+
WSP_GGML_F32X dst_vec = WSP_GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
|
|
12555
|
+
|
|
12556
|
+
// Compute kv = v * k
|
|
12557
|
+
WSP_GGML_F32X kv_vec = WSP_GGML_F32X_MUL(v_vec, k_vec);
|
|
12558
|
+
|
|
12559
|
+
// Compute temp = prev_state * g + kv
|
|
12560
|
+
WSP_GGML_F32X temp_vec = WSP_GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
|
|
12561
|
+
|
|
12562
|
+
// Update dst: dst += temp * q
|
|
12563
|
+
dst_vec = WSP_GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
|
|
12564
|
+
WSP_GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
|
|
12565
|
+
|
|
12566
|
+
// Update state
|
|
12567
|
+
WSP_GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
|
|
12568
|
+
}
|
|
12569
|
+
|
|
12570
|
+
// Handle remaining elements, this will not be used.
|
|
12571
|
+
for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
|
|
12572
|
+
size_t t_h_j_offset = t_h_offset + j;
|
|
12573
|
+
size_t h_2d_i_j_offset = h_2d_i_offset + j;
|
|
12574
|
+
float v_val = v[t_h_j_offset];
|
|
12575
|
+
float kv_val = v_val * k_val;
|
|
12576
|
+
float prev_state_val = state_prev[h_2d_i_j_offset];
|
|
12577
|
+
float temp_val = kv_val + prev_state_val * g_val;
|
|
12578
|
+
dst_data[t_h_j_offset] += temp_val * q_val;
|
|
12579
|
+
state_cur[h_2d_i_j_offset] = temp_val;
|
|
12580
|
+
}
|
|
12581
|
+
}
|
|
12582
|
+
}
|
|
12583
|
+
}
|
|
12584
|
+
|
|
12585
|
+
#else
|
|
12586
|
+
for (int64_t t = 0; t < T; t++) {
|
|
12587
|
+
size_t t_offset = t * t_stride;
|
|
12588
|
+
size_t state_offset = head_size * C * (t / (T / n_seqs));
|
|
12589
|
+
float * state_cur = state + state_offset;
|
|
12590
|
+
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
|
|
12591
|
+
|
|
12592
|
+
for (int64_t h = h_start; h < h_end; h++) {
|
|
12593
|
+
size_t h_offset = h * h_stride;
|
|
12594
|
+
size_t t_h_offset = t_offset + h_offset;
|
|
12595
|
+
size_t h_2d_offset = h * h_stride_2d;
|
|
12596
|
+
|
|
12597
|
+
for (int64_t i = 0; i < head_size; i++) {
|
|
12598
|
+
size_t t_h_i_offset = t_h_offset + i;
|
|
12599
|
+
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
|
|
12600
|
+
|
|
12601
|
+
float k_val = k[t_h_i_offset];
|
|
12602
|
+
float q_val = q[t_h_i_offset] * scale;
|
|
12603
|
+
float g_val = g[t_h_i_offset];
|
|
12604
|
+
|
|
12605
|
+
for (int64_t j = 0; j < head_size; j++) {
|
|
12606
|
+
size_t t_h_j_offset = t_h_offset + j;
|
|
12607
|
+
size_t h_2d_i_j_offset = h_2d_i_offset + j;
|
|
12608
|
+
|
|
12609
|
+
float v_val = v[t_h_j_offset];
|
|
12610
|
+
float kv_val = v_val * k_val;
|
|
12611
|
+
float prev_state_val = state_prev[h_2d_i_j_offset];
|
|
12612
|
+
float temp_val = prev_state_val * g_val + kv_val;
|
|
12613
|
+
dst_data[t_h_j_offset] += temp_val * q_val;
|
|
12614
|
+
state_cur[h_2d_i_j_offset] = temp_val;
|
|
12615
|
+
}
|
|
12616
|
+
}
|
|
12617
|
+
}
|
|
12618
|
+
}
|
|
12619
|
+
#endif
|
|
12620
|
+
}
|
|
12621
|
+
|
|
12622
|
+
|
|
12623
|
+
static void wsp_ggml_compute_forward_gla(
|
|
12624
|
+
const struct wsp_ggml_compute_params * params,
|
|
12625
|
+
struct wsp_ggml_tensor * dst) {
|
|
12626
|
+
|
|
12627
|
+
const struct wsp_ggml_tensor * src0 = dst->src[0];
|
|
12628
|
+
|
|
12629
|
+
switch (src0->type) {
|
|
12630
|
+
case WSP_GGML_TYPE_F32:
|
|
12631
|
+
{
|
|
12632
|
+
wsp_ggml_compute_forward_gla_f32(params, dst);
|
|
12633
|
+
} break;
|
|
12634
|
+
default:
|
|
12635
|
+
{
|
|
12636
|
+
WSP_GGML_ABORT("fatal error");
|
|
12637
|
+
}
|
|
12638
|
+
}
|
|
12639
|
+
}
|
|
12640
|
+
|
|
11845
12641
|
// wsp_ggml_compute_forward_map_unary
|
|
11846
12642
|
|
|
11847
12643
|
static void wsp_ggml_compute_forward_map_unary_f32(
|
|
@@ -12135,22 +12931,22 @@ static void wsp_ggml_compute_forward_cross_entropy_loss_back_f32(
|
|
|
12135
12931
|
const struct wsp_ggml_compute_params * params,
|
|
12136
12932
|
struct wsp_ggml_tensor * dst) {
|
|
12137
12933
|
|
|
12138
|
-
const struct wsp_ggml_tensor *
|
|
12139
|
-
const struct wsp_ggml_tensor *
|
|
12140
|
-
const struct wsp_ggml_tensor *
|
|
12934
|
+
const struct wsp_ggml_tensor * grad = dst->src[0]; // gradient of forward pass output
|
|
12935
|
+
const struct wsp_ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
|
|
12936
|
+
const struct wsp_ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
|
|
12141
12937
|
|
|
12142
12938
|
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst));
|
|
12143
|
-
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(
|
|
12144
|
-
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(
|
|
12145
|
-
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(
|
|
12146
|
-
WSP_GGML_ASSERT(wsp_ggml_are_same_shape(
|
|
12939
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0f));
|
|
12940
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1f));
|
|
12941
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(grad));
|
|
12942
|
+
WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0f, src1f) && wsp_ggml_are_same_shape(src0f, dst));
|
|
12147
12943
|
|
|
12148
12944
|
const int64_t ith = params->ith;
|
|
12149
12945
|
const int64_t nth = params->nth;
|
|
12150
12946
|
|
|
12151
12947
|
// TODO: handle transposed/permuted matrices
|
|
12152
|
-
const int64_t nc =
|
|
12153
|
-
const int64_t nr = wsp_ggml_nrows(
|
|
12948
|
+
const int64_t nc = src0f->ne[0];
|
|
12949
|
+
const int64_t nr = wsp_ggml_nrows(src0f);
|
|
12154
12950
|
|
|
12155
12951
|
// rows per thread
|
|
12156
12952
|
const int64_t dr = (nr + nth - 1)/nth;
|
|
@@ -12159,12 +12955,12 @@ static void wsp_ggml_compute_forward_cross_entropy_loss_back_f32(
|
|
|
12159
12955
|
const int64_t ir0 = dr*ith;
|
|
12160
12956
|
const int64_t ir1 = MIN(ir0 + dr, nr);
|
|
12161
12957
|
|
|
12162
|
-
const float d_by_nr = ((const float *)
|
|
12958
|
+
const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
|
|
12163
12959
|
|
|
12164
12960
|
for (int64_t i1 = ir0; i1 < ir1; i1++) {
|
|
12165
|
-
float
|
|
12166
|
-
float * s0 = (float *)((char *)
|
|
12167
|
-
float * s1 = (float *)((char *)
|
|
12961
|
+
float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
|
|
12962
|
+
const float * s0 = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
|
|
12963
|
+
const float * s1 = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
|
|
12168
12964
|
|
|
12169
12965
|
#ifndef NDEBUG
|
|
12170
12966
|
for (int64_t i = 0; i < nc; ++i) {
|
|
@@ -12177,11 +12973,11 @@ static void wsp_ggml_compute_forward_cross_entropy_loss_back_f32(
|
|
|
12177
12973
|
// soft_max
|
|
12178
12974
|
float max = -INFINITY;
|
|
12179
12975
|
wsp_ggml_vec_max_f32(nc, &max, s0);
|
|
12180
|
-
wsp_ggml_float sum = wsp_ggml_vec_soft_max_f32(nc, ds0, s0, max);
|
|
12976
|
+
const wsp_ggml_float sum = wsp_ggml_vec_soft_max_f32(nc, ds0, s0, max);
|
|
12181
12977
|
assert(sum > 0.0);
|
|
12182
12978
|
wsp_ggml_vec_scale_f32(nc, ds0, 1.0/sum);
|
|
12183
12979
|
|
|
12184
|
-
// grad(
|
|
12980
|
+
// grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
|
|
12185
12981
|
wsp_ggml_vec_sub_f32(nc, ds0, ds0, s1);
|
|
12186
12982
|
wsp_ggml_vec_scale_f32(nc, ds0, d_by_nr);
|
|
12187
12983
|
|
|
@@ -12304,6 +13100,9 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
|
|
|
12304
13100
|
return;
|
|
12305
13101
|
}
|
|
12306
13102
|
|
|
13103
|
+
// extra_buffer op?
|
|
13104
|
+
if (wsp_ggml_cpu_extra_compute_forward(params, tensor)) return;
|
|
13105
|
+
|
|
12307
13106
|
switch (tensor->op) {
|
|
12308
13107
|
case WSP_GGML_OP_DUP:
|
|
12309
13108
|
{
|
|
@@ -12475,7 +13274,7 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
|
|
|
12475
13274
|
} break;
|
|
12476
13275
|
case WSP_GGML_OP_SOFT_MAX_BACK:
|
|
12477
13276
|
{
|
|
12478
|
-
|
|
13277
|
+
wsp_ggml_compute_forward_soft_max_ext_back(params, tensor);
|
|
12479
13278
|
} break;
|
|
12480
13279
|
case WSP_GGML_OP_ROPE:
|
|
12481
13280
|
{
|
|
@@ -12525,6 +13324,10 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
|
|
|
12525
13324
|
{
|
|
12526
13325
|
wsp_ggml_compute_forward_pad(params, tensor);
|
|
12527
13326
|
} break;
|
|
13327
|
+
case WSP_GGML_OP_PAD_REFLECT_1D:
|
|
13328
|
+
{
|
|
13329
|
+
wsp_ggml_compute_forward_pad_reflect_1d(params, tensor);
|
|
13330
|
+
} break;
|
|
12528
13331
|
case WSP_GGML_OP_ARANGE:
|
|
12529
13332
|
{
|
|
12530
13333
|
wsp_ggml_compute_forward_arange(params, tensor);
|
|
@@ -12584,6 +13387,10 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
|
|
|
12584
13387
|
{
|
|
12585
13388
|
wsp_ggml_compute_forward_rwkv_wkv6(params, tensor);
|
|
12586
13389
|
} break;
|
|
13390
|
+
case WSP_GGML_OP_GATED_LINEAR_ATTN:
|
|
13391
|
+
{
|
|
13392
|
+
wsp_ggml_compute_forward_gla(params, tensor);
|
|
13393
|
+
} break;
|
|
12587
13394
|
case WSP_GGML_OP_MAP_UNARY:
|
|
12588
13395
|
{
|
|
12589
13396
|
wsp_ggml_unary_op_f32_t fun;
|
|
@@ -12867,6 +13674,7 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
|
|
|
12867
13674
|
} break;
|
|
12868
13675
|
case WSP_GGML_OP_UPSCALE:
|
|
12869
13676
|
case WSP_GGML_OP_PAD:
|
|
13677
|
+
case WSP_GGML_OP_PAD_REFLECT_1D:
|
|
12870
13678
|
case WSP_GGML_OP_ARANGE:
|
|
12871
13679
|
case WSP_GGML_OP_TIMESTEP_EMBEDDING:
|
|
12872
13680
|
case WSP_GGML_OP_ARGSORT:
|
|
@@ -12881,6 +13689,7 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
|
|
|
12881
13689
|
case WSP_GGML_OP_WIN_UNPART:
|
|
12882
13690
|
case WSP_GGML_OP_GET_REL_POS:
|
|
12883
13691
|
case WSP_GGML_OP_RWKV_WKV6:
|
|
13692
|
+
case WSP_GGML_OP_GATED_LINEAR_ATTN:
|
|
12884
13693
|
case WSP_GGML_OP_MAP_UNARY:
|
|
12885
13694
|
case WSP_GGML_OP_MAP_BINARY:
|
|
12886
13695
|
case WSP_GGML_OP_MAP_CUSTOM1_F32:
|
|
@@ -12956,7 +13765,7 @@ static thread_ret_t wsp_ggml_graph_compute_secondary_thread(void* data);
|
|
|
12956
13765
|
#include "windows.h"
|
|
12957
13766
|
|
|
12958
13767
|
// TODO: support > 64 CPUs
|
|
12959
|
-
bool wsp_ggml_thread_apply_affinity(bool * mask) {
|
|
13768
|
+
static bool wsp_ggml_thread_apply_affinity(bool * mask) {
|
|
12960
13769
|
HANDLE h = GetCurrentThread();
|
|
12961
13770
|
uint64_t bitmask = 0ULL;
|
|
12962
13771
|
|
|
@@ -13246,140 +14055,148 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(
|
|
|
13246
14055
|
|
|
13247
14056
|
size_t cur = 0;
|
|
13248
14057
|
|
|
13249
|
-
|
|
13250
|
-
|
|
13251
|
-
|
|
13252
|
-
|
|
13253
|
-
|
|
13254
|
-
|
|
13255
|
-
(node->
|
|
13256
|
-
|
|
14058
|
+
if (!wsp_ggml_cpu_extra_work_size(n_threads, node, &cur)) {
|
|
14059
|
+
|
|
14060
|
+
switch (node->op) {
|
|
14061
|
+
case WSP_GGML_OP_CPY:
|
|
14062
|
+
case WSP_GGML_OP_DUP:
|
|
14063
|
+
{
|
|
14064
|
+
if (wsp_ggml_is_quantized(node->type) ||
|
|
14065
|
+
// F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
|
|
14066
|
+
(node->src[0]->type == WSP_GGML_TYPE_F16 && node->src[1] && node->src[1]->type == WSP_GGML_TYPE_BF16) ||
|
|
14067
|
+
(node->src[0]->type == WSP_GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == WSP_GGML_TYPE_F16)) {
|
|
14068
|
+
cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->ne[0] * n_tasks;
|
|
14069
|
+
}
|
|
14070
|
+
} break;
|
|
14071
|
+
case WSP_GGML_OP_ADD:
|
|
14072
|
+
case WSP_GGML_OP_ADD1:
|
|
14073
|
+
{
|
|
14074
|
+
if (wsp_ggml_is_quantized(node->src[0]->type)) {
|
|
14075
|
+
cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
|
|
14076
|
+
}
|
|
14077
|
+
} break;
|
|
14078
|
+
case WSP_GGML_OP_ACC:
|
|
14079
|
+
{
|
|
14080
|
+
if (wsp_ggml_is_quantized(node->src[0]->type)) {
|
|
14081
|
+
cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
|
|
14082
|
+
}
|
|
14083
|
+
} break;
|
|
14084
|
+
case WSP_GGML_OP_COUNT_EQUAL:
|
|
14085
|
+
{
|
|
14086
|
+
cur = wsp_ggml_type_size(node->type)*n_tasks;
|
|
14087
|
+
} break;
|
|
14088
|
+
case WSP_GGML_OP_MUL_MAT:
|
|
14089
|
+
{
|
|
14090
|
+
const enum wsp_ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type;
|
|
14091
|
+
|
|
14092
|
+
if (node->src[1]->type != vec_dot_type) {
|
|
14093
|
+
cur = wsp_ggml_row_size(vec_dot_type, wsp_ggml_nelements(node->src[1]));
|
|
14094
|
+
}
|
|
14095
|
+
} break;
|
|
14096
|
+
case WSP_GGML_OP_MUL_MAT_ID:
|
|
14097
|
+
{
|
|
14098
|
+
cur = 0;
|
|
14099
|
+
const struct wsp_ggml_tensor * src0 = node->src[0];
|
|
14100
|
+
const struct wsp_ggml_tensor * src1 = node->src[1];
|
|
14101
|
+
const struct wsp_ggml_tensor * ids = node->src[2];
|
|
14102
|
+
const enum wsp_ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
|
|
14103
|
+
const int n_as = src0->ne[2];
|
|
14104
|
+
// src1
|
|
14105
|
+
if (src1->type != vec_dot_type) {
|
|
14106
|
+
cur += wsp_ggml_row_size(vec_dot_type, wsp_ggml_nelements(src1)) + sizeof(int64_t);
|
|
14107
|
+
}
|
|
14108
|
+
// matrix_row_counts
|
|
14109
|
+
cur += n_as * sizeof(int64_t) + sizeof(int64_t);
|
|
14110
|
+
// matrix_rows
|
|
14111
|
+
cur += n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping) + sizeof(int64_t);
|
|
14112
|
+
// atomic_current_chunk
|
|
14113
|
+
cur += CACHE_LINE_SIZE*n_as + CACHE_LINE_SIZE;
|
|
14114
|
+
} break;
|
|
14115
|
+
case WSP_GGML_OP_OUT_PROD:
|
|
14116
|
+
{
|
|
14117
|
+
if (wsp_ggml_is_quantized(node->src[0]->type)) {
|
|
14118
|
+
cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
|
|
14119
|
+
}
|
|
14120
|
+
} break;
|
|
14121
|
+
case WSP_GGML_OP_SOFT_MAX:
|
|
14122
|
+
case WSP_GGML_OP_ROPE:
|
|
14123
|
+
case WSP_GGML_OP_ROPE_BACK:
|
|
14124
|
+
{
|
|
13257
14125
|
cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->ne[0] * n_tasks;
|
|
13258
|
-
}
|
|
13259
|
-
|
|
13260
|
-
|
|
13261
|
-
|
|
13262
|
-
|
|
13263
|
-
|
|
13264
|
-
|
|
13265
|
-
|
|
13266
|
-
|
|
13267
|
-
|
|
13268
|
-
|
|
13269
|
-
|
|
13270
|
-
|
|
13271
|
-
|
|
13272
|
-
|
|
13273
|
-
|
|
13274
|
-
|
|
13275
|
-
|
|
13276
|
-
|
|
13277
|
-
|
|
13278
|
-
|
|
13279
|
-
|
|
14126
|
+
} break;
|
|
14127
|
+
case WSP_GGML_OP_CONV_TRANSPOSE_1D:
|
|
14128
|
+
{
|
|
14129
|
+
WSP_GGML_ASSERT(node->src[0]->ne[3] == 1);
|
|
14130
|
+
WSP_GGML_ASSERT(node->src[1]->ne[2] == 1);
|
|
14131
|
+
WSP_GGML_ASSERT(node->src[1]->ne[3] == 1);
|
|
14132
|
+
|
|
14133
|
+
const int64_t ne00 = node->src[0]->ne[0]; // K
|
|
14134
|
+
const int64_t ne01 = node->src[0]->ne[1]; // Cout
|
|
14135
|
+
const int64_t ne02 = node->src[0]->ne[2]; // Cin
|
|
14136
|
+
const int64_t ne10 = node->src[1]->ne[0]; // L
|
|
14137
|
+
const int64_t ne11 = node->src[1]->ne[1]; // Cin
|
|
14138
|
+
|
|
14139
|
+
if ((node->src[0]->type == WSP_GGML_TYPE_F16 ||
|
|
14140
|
+
node->src[0]->type == WSP_GGML_TYPE_BF16) &&
|
|
14141
|
+
node->src[1]->type == WSP_GGML_TYPE_F32) {
|
|
14142
|
+
cur += sizeof(wsp_ggml_fp16_t)*ne00*ne01*ne02;
|
|
14143
|
+
cur += sizeof(wsp_ggml_fp16_t)*ne10*ne11;
|
|
14144
|
+
} else if (node->src[0]->type == WSP_GGML_TYPE_F32 &&
|
|
14145
|
+
node->src[1]->type == WSP_GGML_TYPE_F32) {
|
|
14146
|
+
cur += sizeof(float)*ne00*ne01*ne02;
|
|
14147
|
+
cur += sizeof(float)*ne10*ne11;
|
|
14148
|
+
} else {
|
|
14149
|
+
WSP_GGML_ABORT("fatal error");
|
|
14150
|
+
}
|
|
14151
|
+
} break;
|
|
14152
|
+
case WSP_GGML_OP_CONV_TRANSPOSE_2D:
|
|
14153
|
+
{
|
|
14154
|
+
const int64_t ne00 = node->src[0]->ne[0]; // W
|
|
14155
|
+
const int64_t ne01 = node->src[0]->ne[1]; // H
|
|
14156
|
+
const int64_t ne02 = node->src[0]->ne[2]; // Channels Out
|
|
14157
|
+
const int64_t ne03 = node->src[0]->ne[3]; // Channels In
|
|
13280
14158
|
|
|
13281
|
-
|
|
13282
|
-
|
|
13283
|
-
|
|
13284
|
-
} break;
|
|
13285
|
-
case WSP_GGML_OP_MUL_MAT_ID:
|
|
13286
|
-
{
|
|
13287
|
-
cur = 0;
|
|
13288
|
-
const struct wsp_ggml_tensor * src0 = node->src[0];
|
|
13289
|
-
const struct wsp_ggml_tensor * src1 = node->src[1];
|
|
13290
|
-
const enum wsp_ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
|
|
13291
|
-
if (src1->type != vec_dot_type) {
|
|
13292
|
-
cur += wsp_ggml_row_size(vec_dot_type, wsp_ggml_nelements(src1));
|
|
13293
|
-
}
|
|
13294
|
-
const int n_as = src0->ne[2];
|
|
13295
|
-
cur += WSP_GGML_PAD(cur, sizeof(int64_t)); // align
|
|
13296
|
-
cur += n_as * sizeof(int64_t); // matrix_row_counts
|
|
13297
|
-
cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
|
|
13298
|
-
} break;
|
|
13299
|
-
case WSP_GGML_OP_OUT_PROD:
|
|
13300
|
-
{
|
|
13301
|
-
if (wsp_ggml_is_quantized(node->src[0]->type)) {
|
|
13302
|
-
cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
|
|
13303
|
-
}
|
|
13304
|
-
} break;
|
|
13305
|
-
case WSP_GGML_OP_SOFT_MAX:
|
|
13306
|
-
case WSP_GGML_OP_ROPE:
|
|
13307
|
-
{
|
|
13308
|
-
cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->ne[0] * n_tasks;
|
|
13309
|
-
} break;
|
|
13310
|
-
case WSP_GGML_OP_CONV_TRANSPOSE_1D:
|
|
13311
|
-
{
|
|
13312
|
-
WSP_GGML_ASSERT(node->src[0]->ne[3] == 1);
|
|
13313
|
-
WSP_GGML_ASSERT(node->src[1]->ne[2] == 1);
|
|
13314
|
-
WSP_GGML_ASSERT(node->src[1]->ne[3] == 1);
|
|
13315
|
-
|
|
13316
|
-
const int64_t ne00 = node->src[0]->ne[0]; // K
|
|
13317
|
-
const int64_t ne01 = node->src[0]->ne[1]; // Cout
|
|
13318
|
-
const int64_t ne02 = node->src[0]->ne[2]; // Cin
|
|
13319
|
-
|
|
13320
|
-
const int64_t ne10 = node->src[1]->ne[0]; // L
|
|
13321
|
-
const int64_t ne11 = node->src[1]->ne[1]; // Cin
|
|
13322
|
-
|
|
13323
|
-
if ((node->src[0]->type == WSP_GGML_TYPE_F16 ||
|
|
13324
|
-
node->src[0]->type == WSP_GGML_TYPE_BF16) &&
|
|
13325
|
-
node->src[1]->type == WSP_GGML_TYPE_F32) {
|
|
13326
|
-
cur += sizeof(wsp_ggml_fp16_t)*ne00*ne01*ne02;
|
|
13327
|
-
cur += sizeof(wsp_ggml_fp16_t)*ne10*ne11;
|
|
13328
|
-
} else if (node->src[0]->type == WSP_GGML_TYPE_F32 &&
|
|
13329
|
-
node->src[1]->type == WSP_GGML_TYPE_F32) {
|
|
13330
|
-
cur += sizeof(float)*ne00*ne01*ne02;
|
|
13331
|
-
cur += sizeof(float)*ne10*ne11;
|
|
13332
|
-
} else {
|
|
13333
|
-
WSP_GGML_ABORT("fatal error");
|
|
13334
|
-
}
|
|
13335
|
-
} break;
|
|
13336
|
-
case WSP_GGML_OP_CONV_TRANSPOSE_2D:
|
|
13337
|
-
{
|
|
13338
|
-
const int64_t ne00 = node->src[0]->ne[0]; // W
|
|
13339
|
-
const int64_t ne01 = node->src[0]->ne[1]; // H
|
|
13340
|
-
const int64_t ne02 = node->src[0]->ne[2]; // Channels Out
|
|
13341
|
-
const int64_t ne03 = node->src[0]->ne[3]; // Channels In
|
|
13342
|
-
|
|
13343
|
-
const int64_t ne10 = node->src[1]->ne[0]; // W
|
|
13344
|
-
const int64_t ne11 = node->src[1]->ne[1]; // H
|
|
13345
|
-
const int64_t ne12 = node->src[1]->ne[2]; // Channels In
|
|
13346
|
-
|
|
13347
|
-
cur += sizeof(wsp_ggml_fp16_t)*ne00*ne01*ne02*ne03;
|
|
13348
|
-
cur += sizeof(wsp_ggml_fp16_t)*ne10*ne11*ne12;
|
|
13349
|
-
} break;
|
|
13350
|
-
case WSP_GGML_OP_FLASH_ATTN_EXT:
|
|
13351
|
-
{
|
|
13352
|
-
const int64_t ne00 = node->src[0]->ne[0]; // D
|
|
14159
|
+
const int64_t ne10 = node->src[1]->ne[0]; // W
|
|
14160
|
+
const int64_t ne11 = node->src[1]->ne[1]; // H
|
|
14161
|
+
const int64_t ne12 = node->src[1]->ne[2]; // Channels In
|
|
13353
14162
|
|
|
13354
|
-
|
|
13355
|
-
|
|
13356
|
-
|
|
13357
|
-
|
|
13358
|
-
|
|
13359
|
-
|
|
13360
|
-
const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in wsp_ggml_compute_forward_flash_attn_back
|
|
13361
|
-
if (node->src[1]->type == WSP_GGML_TYPE_F32) {
|
|
13362
|
-
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
|
|
13363
|
-
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
|
|
13364
|
-
} else if (node->src[1]->type == WSP_GGML_TYPE_F16) {
|
|
13365
|
-
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
|
|
13366
|
-
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
|
|
13367
|
-
} else if (node->src[1]->type == WSP_GGML_TYPE_BF16) {
|
|
13368
|
-
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
|
|
13369
|
-
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
|
|
13370
|
-
}
|
|
13371
|
-
} break;
|
|
14163
|
+
cur += sizeof(wsp_ggml_fp16_t)*ne00*ne01*ne02*ne03;
|
|
14164
|
+
cur += sizeof(wsp_ggml_fp16_t)*ne10*ne11*ne12;
|
|
14165
|
+
} break;
|
|
14166
|
+
case WSP_GGML_OP_FLASH_ATTN_EXT:
|
|
14167
|
+
{
|
|
14168
|
+
const int64_t ne00 = node->src[0]->ne[0]; // D
|
|
13372
14169
|
|
|
13373
|
-
|
|
13374
|
-
|
|
13375
|
-
|
|
13376
|
-
|
|
13377
|
-
|
|
13378
|
-
|
|
13379
|
-
|
|
13380
|
-
|
|
13381
|
-
|
|
13382
|
-
|
|
14170
|
+
cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
|
|
14171
|
+
} break;
|
|
14172
|
+
case WSP_GGML_OP_FLASH_ATTN_BACK:
|
|
14173
|
+
{
|
|
14174
|
+
const int64_t D = node->src[0]->ne[0];
|
|
14175
|
+
const int64_t ne11 = wsp_ggml_up(node->src[1]->ne[1], WSP_GGML_SOFT_MAX_UNROLL);
|
|
14176
|
+
const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in wsp_ggml_compute_forward_flash_attn_back
|
|
14177
|
+
if (node->src[1]->type == WSP_GGML_TYPE_F32) {
|
|
14178
|
+
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
|
|
14179
|
+
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
|
|
14180
|
+
} else if (node->src[1]->type == WSP_GGML_TYPE_F16) {
|
|
14181
|
+
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
|
|
14182
|
+
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
|
|
14183
|
+
} else if (node->src[1]->type == WSP_GGML_TYPE_BF16) {
|
|
14184
|
+
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
|
|
14185
|
+
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
|
|
14186
|
+
}
|
|
14187
|
+
} break;
|
|
14188
|
+
|
|
14189
|
+
case WSP_GGML_OP_CROSS_ENTROPY_LOSS:
|
|
14190
|
+
{
|
|
14191
|
+
cur = wsp_ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
|
|
14192
|
+
} break;
|
|
14193
|
+
case WSP_GGML_OP_COUNT:
|
|
14194
|
+
{
|
|
14195
|
+
WSP_GGML_ABORT("fatal error");
|
|
14196
|
+
}
|
|
14197
|
+
default:
|
|
14198
|
+
break;
|
|
14199
|
+
}
|
|
13383
14200
|
}
|
|
13384
14201
|
|
|
13385
14202
|
work_size = MAX(work_size, cur);
|
|
@@ -13414,20 +14231,24 @@ static thread_ret_t wsp_ggml_graph_compute_thread(void * data) {
|
|
|
13414
14231
|
/*.threadpool=*/ tp,
|
|
13415
14232
|
};
|
|
13416
14233
|
|
|
13417
|
-
for (int node_n = 0; node_n < cgraph->n_nodes &&
|
|
14234
|
+
for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
|
|
13418
14235
|
struct wsp_ggml_tensor * node = cgraph->nodes[node_n];
|
|
13419
14236
|
|
|
13420
14237
|
wsp_ggml_compute_forward(¶ms, node);
|
|
13421
14238
|
|
|
13422
14239
|
if (state->ith == 0 && cplan->abort_callback &&
|
|
13423
14240
|
cplan->abort_callback(cplan->abort_callback_data)) {
|
|
13424
|
-
tp->abort
|
|
14241
|
+
atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);
|
|
13425
14242
|
tp->ec = WSP_GGML_STATUS_ABORTED;
|
|
13426
14243
|
}
|
|
13427
14244
|
|
|
13428
|
-
|
|
14245
|
+
if (node_n + 1 < cgraph->n_nodes) {
|
|
14246
|
+
wsp_ggml_barrier(state->threadpool);
|
|
14247
|
+
}
|
|
13429
14248
|
}
|
|
13430
14249
|
|
|
14250
|
+
wsp_ggml_barrier(state->threadpool);
|
|
14251
|
+
|
|
13431
14252
|
return 0;
|
|
13432
14253
|
}
|
|
13433
14254
|
|
|
@@ -13578,29 +14399,6 @@ static void wsp_ggml_graph_compute_kickoff(struct wsp_ggml_threadpool * threadpo
|
|
|
13578
14399
|
|
|
13579
14400
|
#endif // WSP_GGML_USE_OPENMP
|
|
13580
14401
|
|
|
13581
|
-
void wsp_ggml_threadpool_params_init(struct wsp_ggml_threadpool_params * p, int n_threads) {
|
|
13582
|
-
p->n_threads = n_threads;
|
|
13583
|
-
p->prio = 0; // default priority (usually means normal or inherited)
|
|
13584
|
-
p->poll = 50; // hybrid-polling enabled
|
|
13585
|
-
p->strict_cpu = false; // no strict placement (all threads share same cpumask)
|
|
13586
|
-
p->paused = false; // threads are ready to go
|
|
13587
|
-
memset(p->cpumask, 0, WSP_GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited)
|
|
13588
|
-
}
|
|
13589
|
-
|
|
13590
|
-
struct wsp_ggml_threadpool_params wsp_ggml_threadpool_params_default(int n_threads) {
|
|
13591
|
-
struct wsp_ggml_threadpool_params p;
|
|
13592
|
-
wsp_ggml_threadpool_params_init(&p, n_threads);
|
|
13593
|
-
return p;
|
|
13594
|
-
}
|
|
13595
|
-
|
|
13596
|
-
bool wsp_ggml_threadpool_params_match(const struct wsp_ggml_threadpool_params * p0, const struct wsp_ggml_threadpool_params * p1) {
|
|
13597
|
-
if (p0->n_threads != p1->n_threads ) return false;
|
|
13598
|
-
if (p0->prio != p1->prio ) return false;
|
|
13599
|
-
if (p0->poll != p1->poll ) return false;
|
|
13600
|
-
if (p0->strict_cpu != p1->strict_cpu ) return false;
|
|
13601
|
-
return memcmp(p0->cpumask, p1->cpumask, WSP_GGML_MAX_N_THREADS) == 0;
|
|
13602
|
-
}
|
|
13603
|
-
|
|
13604
14402
|
static struct wsp_ggml_threadpool * wsp_ggml_threadpool_new_impl(
|
|
13605
14403
|
struct wsp_ggml_threadpool_params * tpp,
|
|
13606
14404
|
struct wsp_ggml_cgraph * cgraph,
|
|
@@ -13617,7 +14415,7 @@ static struct wsp_ggml_threadpool * wsp_ggml_threadpool_new_impl(
|
|
|
13617
14415
|
threadpool->current_chunk = 0;
|
|
13618
14416
|
threadpool->stop = false;
|
|
13619
14417
|
threadpool->pause = tpp->paused;
|
|
13620
|
-
threadpool->abort =
|
|
14418
|
+
threadpool->abort = -1;
|
|
13621
14419
|
threadpool->workers = NULL;
|
|
13622
14420
|
threadpool->n_threads_max = tpp->n_threads;
|
|
13623
14421
|
threadpool->n_threads_cur = tpp->n_threads;
|
|
@@ -13696,7 +14494,7 @@ enum wsp_ggml_status wsp_ggml_graph_compute(struct wsp_ggml_cgraph * cgraph, str
|
|
|
13696
14494
|
threadpool->cgraph = cgraph;
|
|
13697
14495
|
threadpool->cplan = cplan;
|
|
13698
14496
|
threadpool->current_chunk = 0;
|
|
13699
|
-
threadpool->abort =
|
|
14497
|
+
threadpool->abort = -1;
|
|
13700
14498
|
threadpool->ec = WSP_GGML_STATUS_SUCCESS;
|
|
13701
14499
|
}
|
|
13702
14500
|
|
|
@@ -13895,16 +14693,32 @@ int wsp_ggml_cpu_has_vsx(void) {
|
|
|
13895
14693
|
#endif
|
|
13896
14694
|
}
|
|
13897
14695
|
|
|
14696
|
+
/*
 * Compile-time probe for the IBM Z vector-enhancement facilities.
 * Returns 1 when this translation unit was compiled with __VXE__ or
 * __VXE2__ defined, and 0 otherwise. No runtime CPU query is performed.
 */
int wsp_ggml_cpu_has_vxe(void) {
    int supported = 0;
#if defined(__VXE2__) || defined(__VXE__)
    supported = 1;
#endif
    return supported;
}
|
|
14703
|
+
|
|
13898
14704
|
/*
 * Reports NEON support.
 * Returns 0 unless the build targets ARM with NEON enabled at compile time
 * (both __ARM_ARCH and __ARM_NEON defined). In that case it returns the
 * has_neon field recorded in wsp_ggml_arm_arch_features.
 */
int wsp_ggml_cpu_has_neon(void) {
#if !defined(__ARM_ARCH) || !defined(__ARM_NEON)
    return 0;
#else
    return wsp_ggml_arm_arch_features.has_neon;
#endif
}
|
|
13905
14711
|
|
|
14712
|
+
/*
 * Reports ARM dot-product instruction support.
 * Returns 0 unless the build targets ARM with __ARM_FEATURE_DOTPROD defined.
 * In that case it returns the has_dotprod field recorded in
 * wsp_ggml_arm_arch_features.
 */
int wsp_ggml_cpu_has_dotprod(void) {
#if !defined(__ARM_ARCH) || !defined(__ARM_FEATURE_DOTPROD)
    return 0;
#else
    return wsp_ggml_arm_arch_features.has_dotprod;
#endif
}
|
|
14719
|
+
|
|
13906
14720
|
int wsp_ggml_cpu_has_sve(void) {
|
|
13907
|
-
#if defined(__ARM_ARCH)
|
|
14721
|
+
#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE)
|
|
13908
14722
|
return wsp_ggml_arm_arch_features.has_sve;
|
|
13909
14723
|
#else
|
|
13910
14724
|
return 0;
|
|
@@ -13912,7 +14726,7 @@ int wsp_ggml_cpu_has_sve(void) {
|
|
|
13912
14726
|
}
|
|
13913
14727
|
|
|
13914
14728
|
int wsp_ggml_cpu_has_matmul_int8(void) {
|
|
13915
|
-
#if defined(__ARM_ARCH)
|
|
14729
|
+
#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_MATMUL_INT8)
|
|
13916
14730
|
return wsp_ggml_arm_arch_features.has_i8mm;
|
|
13917
14731
|
#else
|
|
13918
14732
|
return 0;
|
|
@@ -13920,13 +14734,21 @@ int wsp_ggml_cpu_has_matmul_int8(void) {
|
|
|
13920
14734
|
}
|
|
13921
14735
|
|
|
13922
14736
|
/*
 * Returns the SVE count stored in wsp_ggml_arm_arch_features.sve_cnt when
 * the build targets ARM with __ARM_FEATURE_SVE defined; returns 0 in every
 * other build configuration.
 */
int wsp_ggml_cpu_get_sve_cnt(void) {
#if !defined(__ARM_ARCH) || !defined(__ARM_FEATURE_SVE)
    return 0;
#else
    return wsp_ggml_arm_arch_features.sve_cnt;
#endif
}
|
|
13929
14743
|
|
|
14744
|
+
/*
 * Reports ARM SME support.
 * Returns 0 unless the build targets ARM with __ARM_FEATURE_SME defined.
 * In that case it returns the has_sme field recorded in
 * wsp_ggml_arm_arch_features.
 */
int wsp_ggml_cpu_has_sme(void) {
#if !defined(__ARM_ARCH) || !defined(__ARM_FEATURE_SME)
    return 0;
#else
    return wsp_ggml_arm_arch_features.has_sme;
#endif
}
|
|
14751
|
+
|
|
13930
14752
|
void wsp_ggml_cpu_init(void) {
|
|
13931
14753
|
// needed to initialize f16 tables
|
|
13932
14754
|
{
|