llama_cpp 0.10.3 → 0.11.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/LICENSE.txt +1 -1
- data/ext/llama_cpp/extconf.rb +35 -110
- data/ext/llama_cpp/llama_cpp.cpp +52 -28
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +3 -1
- data/vendor/include/.gitkeep +0 -0
- data/vendor/lib/.gitkeep +0 -0
- data/vendor/tmp/llama.cpp/Makefile +758 -0
- data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-backend.c +6 -2
- data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-cuda.cu +73 -63
- data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-impl.h +1 -0
- data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-metal.m +43 -20
- data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-metal.metal +464 -245
- data/vendor/tmp/llama.cpp/ggml-opencl.h +25 -0
- data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-quants.c +61 -57
- data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml.c +171 -5
- data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml.h +1 -0
- data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/llama.cpp +222 -105
- data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/llama.h +31 -32
- data/vendor/tmp/llama.cpp/scripts/get-flags.mk +38 -0
- metadata +30 -27
- data/ext/llama_cpp/src/ggml-opencl.h +0 -25
- data/ext/llama_cpp/src/llama-util.h +0 -546
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/LICENSE +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-alloc.c +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-alloc.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-backend-impl.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-backend.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-cuda.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-metal.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-mpi.c +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-mpi.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-opencl.cpp +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-quants.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/unicode.h +0 -0
data/vendor/tmp/llama.cpp/ggml-opencl.h (new file):

@@ -0,0 +1,25 @@
+#pragma once
+
+#include "ggml.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+GGML_API void ggml_cl_init(void);
+
+GGML_API void ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
+GGML_API bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
+GGML_API size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
+GGML_API void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
+
+GGML_API void * ggml_cl_host_malloc(size_t size);
+GGML_API void ggml_cl_host_free(void * ptr);
+
+GGML_API void ggml_cl_free_data(const struct ggml_tensor* tensor);
+
+GGML_API void ggml_cl_transform_tensor(void * data, struct ggml_tensor * tensor);
+
+#ifdef __cplusplus
+}
+#endif
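The new header only declares the OpenCL offload entry points that the vendored llama.cpp sources call when built with CLBlast. A minimal usage sketch of that dispatch pattern follows; the helper name and the surrounding control flow are illustrative assumptions, not code from the gem, and only the `ggml_cl_*` signatures come from the header above.

```c
#include <stdbool.h>
#include <stddef.h>

#include "ggml.h"
#include "ggml-opencl.h"

// Hypothetical helper: offload a matmul to OpenCL if the backend accepts it,
// otherwise tell the caller to fall back to the CPU kernel.
static bool try_cl_mul_mat(const struct ggml_tensor * src0,
                           const struct ggml_tensor * src1,
                           struct ggml_tensor * dst,
                           void * wdata, size_t wsize) {
    if (!ggml_cl_can_mul_mat(src0, src1, dst)) {
        return false;
    }
    if (wsize < ggml_cl_mul_mat_get_wsize(src0, src1, dst)) {
        return false; // scratch buffer too small for this shape
    }
    ggml_cl_mul_mat(src0, src1, dst, wdata, wsize);
    return true;
}
```

`ggml_cl_init()` is expected to have been called once beforehand, and `ggml_cl_host_malloc`/`ggml_cl_host_free` manage host-side buffers (again, as suggested by the declarations, not verified against the implementation).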
data/vendor/tmp/llama.cpp/ggml-quants.c:

@@ -410,13 +410,17 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
 
 #if !defined(__ARM_FEATURE_DOTPROD)
 
-inline static int32x4_t
+inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
     const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
     const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
 
     return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
 }
 
+#else
+
+#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
+
 #endif
 
 #endif
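This is the heart of the ggml-quants.c changes in this release: every direct use of the ARM `vdotq_s32` intrinsic becomes `ggml_vdotq_s32`, which is the hardware dot-product instruction when `__ARM_FEATURE_DOTPROD` is defined and otherwise the widening-multiply fallback shown above. Both paths preserve the reduction the callers rely on: after `acc = ggml_vdotq_s32(acc, a, b)`, the horizontal sum of `acc` has grown by the full 16-element int8 dot product (the per-lane grouping differs between the two paths, but every caller reduces with `vaddvq_s32`). A portable scalar reference for that quantity:

```c
#include <stdint.h>
#include <stdio.h>

// What vaddvq_s32(ggml_vdotq_s32(vdupq_n_s32(0), a, b)) evaluates to,
// written out as plain C for reference.
static int32_t dot16_i8(const int8_t a[16], const int8_t b[16]) {
    int32_t sum = 0;
    for (int i = 0; i < 16; ++i) {
        sum += (int32_t) a[i] * (int32_t) b[i];
    }
    return sum;
}

int main(void) {
    int8_t a[16], b[16];
    for (int i = 0; i < 16; ++i) {
        a[i] = (int8_t) (i - 8);        // arbitrary test data
        b[i] = (int8_t) (3 * i - 20);
    }
    printf("dot = %d\n", dot16_i8(a, b));
    return 0;
}
```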
@@ -2481,8 +2485,8 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
         // dot product into int32x4_t
-        const int32x4_t p_0 =
-        const int32x4_t p_1 =
+        const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
+        const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
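For context on the operands fed to `ggml_vdotq_s32` here: a q4_0 block stores 32 weights as 4-bit values (biased by 8) plus one fp16 scale, and a q8_0 block stores 32 int8 values plus one fp16 scale. A scalar sketch of one block's contribution to the dot product, assuming those layouts (it mirrors what the NEON code computes; it is not taken from the diff):

```c
#include <stdint.h>

#define QK8_0 32  // elements per q4_0/q8_0 block

// One q4_0 x q8_0 block: low nibbles hold the first 16 weights, high nibbles
// the last 16, each biased by 8; d4 and d8 are the fp16 scales already
// converted to float (GGML_FP16_TO_FP32 in the real code).
static float dot_q4_0_q8_0_block(const uint8_t q4[QK8_0 / 2], float d4,
                                 const int8_t  q8[QK8_0],     float d8) {
    int32_t sumi = 0;
    for (int i = 0; i < QK8_0 / 2; ++i) {
        const int lo = (q4[i] & 0x0F) - 8;
        const int hi = (q4[i] >> 4)   - 8;
        sumi += lo * q8[i];
        sumi += hi * q8[i + QK8_0 / 2];
    }
    return d4 * d8 * (float) sumi;
}
```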
@@ -2769,8 +2773,8 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
         // dot product into int32x4_t
-        const int32x4_t p_0 =
-        const int32x4_t p_1 =
+        const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
+        const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
@@ -2936,11 +2940,11 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
-
-
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
-
-
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -3228,11 +3232,11 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
-
-
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
-
-
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
@@ -3483,12 +3487,12 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
         const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
-
-
+                        ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
+                        ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
 
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
-
-
+                        ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
+                        ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -3598,8 +3602,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
     // We use this macro instead of a function call because for some reason
     // the code runs 2-3% slower, even if the function is declared inline
 #define MULTIPLY_ACCUM_WITH_SCALE(index)\
-        isum += vaddvq_s32(
-        isum += vaddvq_s32(
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
 
 #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
         q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
@@ -3973,10 +3977,10 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
         q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
 
-        isum1 += vaddvq_s32(
-        isum2 += vaddvq_s32(
-        isum1 += vaddvq_s32(
-        isum2 += vaddvq_s32(
+        isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
+        isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
+        isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
+        isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
 
         sum += d * (isum1 + isum2);
     }
@@ -4256,10 +4260,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
         q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
 
-        isum += vaddvq_s32(
-        isum += vaddvq_s32(
-        isum += vaddvq_s32(
-        isum += vaddvq_s32(
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
 
         scale += 4;
 
@@ -4273,10 +4277,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
         q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
 
-        isum += vaddvq_s32(
-        isum += vaddvq_s32(
-        isum += vaddvq_s32(
-        isum += vaddvq_s32(
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
 
         scale += 4;
 
@@ -4757,10 +4761,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
         q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));
 
-        isum += vaddvq_s32(
-        isum += vaddvq_s32(
-        isum += vaddvq_s32(
-        isum += vaddvq_s32(
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
 
         sum += d * isum;
 
@@ -5109,14 +5113,14 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
         q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
 
-        const int32x4_t p1 =
+        const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
         sumi1 += vaddvq_s32(p1) * scales[2*j+0];
 
         q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
         q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
         q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
 
-        const int32x4_t p2 =
+        const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
 
         sumi2 += vaddvq_s32(p2) * scales[2*j+1];
     }
@@ -5449,13 +5453,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
         q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
 
-        const int32x4_t p1 =
+        const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
         const int32_t sumi1 = vaddvq_s32(p1) * scales[0];
 
         q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
         q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
 
-        const int32x4_t p2 =
+        const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
         const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
 
         sumf += d * (sumi1 + sumi2);
@@ -5722,8 +5726,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
         q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
 
-        sumi += vaddvq_s32(
-        sumi += vaddvq_s32(
+        sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
+        sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
     }
 
     sumf += d * sumi - dmin * sumi_mins;
@@ -6112,10 +6116,10 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
         q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));
 
-        int32_t sumi1 = sc[0] * vaddvq_s32(
-        int32_t sumi2 = sc[1] * vaddvq_s32(
-        int32_t sumi3 = sc[2] * vaddvq_s32(
-        int32_t sumi4 = sc[3] * vaddvq_s32(
+        int32_t sumi1 = sc[0] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
+        int32_t sumi2 = sc[1] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
+        int32_t sumi3 = sc[2] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
+        int32_t sumi4 = sc[3] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
 
         sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
     }
@@ -6399,10 +6403,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
         q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
         q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
 
-        isum += vaddvq_s32(
-                vaddvq_s32(
-                vaddvq_s32(
-                vaddvq_s32(
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
 
         scale += 4;
 
@@ -6426,10 +6430,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
         q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
         q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
 
-        isum += vaddvq_s32(
-                vaddvq_s32(
-                vaddvq_s32(
-                vaddvq_s32(
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
         scale += 4;
     }
     //sum += isum * d_all * y[i].d;
@@ -6816,10 +6820,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
         q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
         q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);
 
-        isum += vaddvq_s32(
-                vaddvq_s32(
-                vaddvq_s32(
-                vaddvq_s32(
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
 
         sum += isum * d_all * y[i].d;
 
data/vendor/tmp/llama.cpp/ggml.c:

@@ -4766,8 +4766,11 @@ struct ggml_tensor * ggml_get_rows(
     }
 
     // TODO: implement non F32 return
-
-
+    enum ggml_type type = GGML_TYPE_F32;
+    if (a->type == GGML_TYPE_I32) {
+        type = a->type;
+    }
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
 
     result->op = GGML_OP_GET_ROWS;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
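The `ggml_get_rows` change lets an I32 source keep its type: the result tensor is now allocated as I32 when `a` is I32, instead of always F32. A scalar sketch of the gather this op performs for that case (hypothetical reference helper, not ggml API):

```c
#include <stdint.h>

// dst gets n_idx rows; row r is a copy of source row idx[r], row_len elements
// each. With the change above, int32 rows stay int32 instead of being
// converted to float on the way out.
static void get_rows_i32_ref(int32_t * dst, const int32_t * src,
                             const int32_t * idx, int n_idx, int row_len) {
    for (int r = 0; r < n_idx; ++r) {
        const int32_t * src_row = src + (int64_t) idx[r] * row_len;
        for (int c = 0; c < row_len; ++c) {
            dst[(int64_t) r * row_len + c] = src_row[c];
        }
    }
}
```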
@@ -6938,14 +6941,165 @@ static void ggml_compute_forward_dup_f32(
     }
 }
 
-
+// A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
+static void ggml_compute_forward_dup_bytes(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_ASSERT(src0->type == dst->type);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
         ggml_compute_forward_dup_same_cont(params, src0, dst);
         return;
     }
+
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
+    const size_t type_size = ggml_type_size(src0->type);
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (src0->type == dst->type &&
+        ne00 == ne0 &&
+        nb00 == type_size && nb0 == type_size) {
+        // copy by rows
+        const size_t rs = ne00 * type_size;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    memcpy(
+                        ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
+                }
+            }
+        }
+        return;
+    }
+
+    if (ggml_is_contiguous(dst)) {
+        size_t id = 0;
+        char * dst_ptr = (char *) dst->data;
+        const size_t rs = ne00 * type_size;
+
+        if (nb00 == type_size) {
+            // src0 is contigous on first dimension, copy by rows
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    id += rs * ir0;
+                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                        memcpy(dst_ptr + id, src0_ptr, rs);
+                        id += rs;
+                    }
+                    id += rs * (ne01 - ir1);
+                }
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
+
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    id += rs * ir0;
+                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                        for (int64_t i00 = 0; i00 < ne00; i00++) {
+                            const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
+                            memcpy(dst_ptr + id, src0_ptr, type_size);
+
+                            id += type_size;
+                        }
+                    }
+                    id += rs * (ne01 - ir1);
+                }
+            }
+        }
+
+        return;
+    }
+
+    // dst counters
+
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            i10 += ne00 * ir0;
+            while (i10 >= ne0) {
+                i10 -= ne0;
+                if (++i11 == ne1) {
+                    i11 = 0;
+                    if (++i12 == ne2) {
+                        i12 = 0;
+                        if (++i13 == ne3) {
+                            i13 = 0;
+                        }
+                    }
+                }
+            }
+            for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                    char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+
+                    memcpy(dst_ptr, src0_ptr, type_size);
+
+                    if (++i10 == ne0) {
+                        i10 = 0;
+                        if (++i11 == ne1) {
+                            i11 = 0;
+                            if (++i12 == ne2) {
+                                i12 = 0;
+                                if (++i13 == ne3) {
+                                    i13 = 0;
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+            i10 += ne00 * (ne01 - ir1);
+            while (i10 >= ne0) {
+                i10 -= ne0;
+                if (++i11 == ne1) {
+                    i11 = 0;
+                    if (++i12 == ne2) {
+                        i12 = 0;
+                        if (++i13 == ne3) {
+                            i13 = 0;
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_dup(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    if (src0->type == dst->type) {
+        ggml_compute_forward_dup_bytes(params, src0, dst);
+        return;
+    }
+
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
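`ggml_compute_forward_dup_bytes` copies raw bytes whenever source and destination types already match, skipping the float up/down-conversion of the generic dup path, and `ggml_compute_forward_dup` now routes same-type copies through it. The work is split across threads by rows with the rounding scheme visible above; a tiny standalone illustration of that partition (the example sizes are mine, not from the diff):

```c
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

// Each thread ith of nth handles the half-open row range [ir0, ir1).
// With nr = 10 rows and nth = 4 threads this prints [0,3) [3,6) [6,9) [9,10).
int main(void) {
    const int nr = 10, nth = 4;
    for (int ith = 0; ith < nth; ++ith) {
        const int dr  = (nr + nth - 1) / nth;
        const int ir0 = dr * ith;
        const int ir1 = MIN(ir0 + dr, nr);
        printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
    }
    return 0;
}
```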
@@ -8404,10 +8558,12 @@ static void ggml_compute_forward_repeat(
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
+        case GGML_TYPE_I16:
             {
                 ggml_compute_forward_repeat_f16(params, src0, dst);
             } break;
         case GGML_TYPE_F32:
+        case GGML_TYPE_I32:
             {
                 ggml_compute_forward_repeat_f32(params, src0, dst);
             } break;
@@ -8550,6 +8706,7 @@ static void ggml_compute_forward_concat(
         struct ggml_tensor* dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
+        case GGML_TYPE_I32:
             {
                 ggml_compute_forward_concat_f32(params, src0, src1, dst);
             } break;
@@ -9687,7 +9844,7 @@ static void ggml_compute_forward_mul_mat(
         const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
         assert(params->wsize >= ne11*ne12*ne13*row_size);
-
+        GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
         for (int64_t i13 = 0; i13 < ne13; ++i13) {
             for (int64_t i12 = 0; i12 < ne12; ++i12) {
@@ -10674,6 +10831,7 @@ static void ggml_compute_forward_get_rows(
                 ggml_compute_forward_get_rows_f16(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
+        case GGML_TYPE_I32:
             {
                 ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
             } break;
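This hunk and the two before it extend the `repeat`, `concat`, and `get_rows` dispatch so that I16 reuses the F16 kernel and I32 reuses the F32 kernel. That appears to be safe because these operators only copy elements and the paired types have identical widths, which can be stated as a compile-time check (my summary of the rationale, not a comment from the source):

```c
#include <assert.h>
#include <stdint.h>

// C11 static_assert: the size equivalences the shared kernels rely on.
// (ggml's F16 is a 16-bit storage type, represented here by uint16_t.)
static_assert(sizeof(float)    == sizeof(int32_t), "F32 and I32 are both 4 bytes");
static_assert(sizeof(uint16_t) == sizeof(int16_t), "F16 and I16 are both 2 bytes");
```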
@@ -19638,6 +19796,14 @@ int ggml_cpu_has_avx(void) {
 #endif
 }
 
+int ggml_cpu_has_avx_vnni(void) {
+#if defined(__AVXVNNI__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_avx2(void) {
 #if defined(__AVX2__)
     return 1;
data/vendor/tmp/llama.cpp/ggml.h:

@@ -2198,6 +2198,7 @@ extern "C" {
     //
 
     GGML_API int ggml_cpu_has_avx        (void);
+    GGML_API int ggml_cpu_has_avx_vnni   (void);
     GGML_API int ggml_cpu_has_avx2       (void);
     GGML_API int ggml_cpu_has_avx512     (void);
     GGML_API int ggml_cpu_has_avx512_vbmi(void);
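Finally, the `ggml_cpu_has_avx_vnni()` flag added in ggml.c and declared here follows the existing `ggml_cpu_has_*` pattern: it reports whether the library was compiled with `__AVXVNNI__`, not the result of a runtime CPUID probe. A short usage sketch:

```c
#include <stdio.h>
#include "ggml.h"

// Print a few of the compile-time CPU feature flags exposed by ggml.
int main(void) {
    printf("AVX:      %d\n", ggml_cpu_has_avx());
    printf("AVX-VNNI: %d\n", ggml_cpu_has_avx_vnni());
    printf("AVX2:     %d\n", ggml_cpu_has_avx2());
    return 0;
}
```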