llama_cpp 0.10.3 → 0.11.0

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as published in that registry.
Files changed (37)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +13 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/llama_cpp/extconf.rb +35 -110
  5. data/ext/llama_cpp/llama_cpp.cpp +52 -28
  6. data/lib/llama_cpp/version.rb +2 -2
  7. data/sig/llama_cpp.rbs +3 -1
  8. data/vendor/include/.gitkeep +0 -0
  9. data/vendor/lib/.gitkeep +0 -0
  10. data/vendor/tmp/llama.cpp/Makefile +758 -0
  11. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-backend.c +6 -2
  12. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-cuda.cu +73 -63
  13. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-impl.h +1 -0
  14. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-metal.m +43 -20
  15. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-metal.metal +464 -245
  16. data/vendor/tmp/llama.cpp/ggml-opencl.h +25 -0
  17. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-quants.c +61 -57
  18. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml.c +171 -5
  19. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml.h +1 -0
  20. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/llama.cpp +222 -105
  21. data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/llama.h +31 -32
  22. data/vendor/tmp/llama.cpp/scripts/get-flags.mk +38 -0
  23. metadata +30 -27
  24. data/ext/llama_cpp/src/ggml-opencl.h +0 -25
  25. data/ext/llama_cpp/src/llama-util.h +0 -546
  26. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/LICENSE +0 -0
  27. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-alloc.c +0 -0
  28. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-alloc.h +0 -0
  29. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-backend-impl.h +0 -0
  30. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-backend.h +0 -0
  31. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-cuda.h +0 -0
  32. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-metal.h +0 -0
  33. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-mpi.c +0 -0
  34. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-mpi.h +0 -0
  35. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-opencl.cpp +0 -0
  36. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-quants.h +0 -0
  37. /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/unicode.h +0 -0
@@ -0,0 +1,25 @@
+ #pragma once
+
+ #include "ggml.h"
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ GGML_API void ggml_cl_init(void);
+
+ GGML_API void ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
+ GGML_API bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
+ GGML_API size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
+ GGML_API void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
+
+ GGML_API void * ggml_cl_host_malloc(size_t size);
+ GGML_API void ggml_cl_host_free(void * ptr);
+
+ GGML_API void ggml_cl_free_data(const struct ggml_tensor* tensor);
+
+ GGML_API void ggml_cl_transform_tensor(void * data, struct ggml_tensor * tensor);
+
+ #ifdef __cplusplus
+ }
+ #endif
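The header only declares the OpenCL hooks; ggml wires them in when built with CLBlast. For orientation, a minimal sketch of how a caller might drive these entry points (assumes an OpenCL-enabled build; mul_mat_via_opencl is an illustrative helper, not part of the gem):

    /* Sketch: check support, size the scratch buffer, then run the OpenCL mat-mul. */
    #include <stdlib.h>
    #include "ggml.h"
    #include "ggml-opencl.h"

    static void mul_mat_via_opencl(const struct ggml_tensor * src0,
                                   const struct ggml_tensor * src1,
                                   struct ggml_tensor * dst) {
        ggml_cl_init();                                   // one-time device/queue/kernel setup

        if (!ggml_cl_can_mul_mat(src0, src1, dst)) {
            return;                                       // unsupported shapes/types: fall back to CPU
        }

        const size_t wsize = ggml_cl_mul_mat_get_wsize(src0, src1, dst);
        void * wdata = wsize > 0 ? malloc(wsize) : NULL;  // scratch area requested by the OpenCL path

        ggml_cl_mul_mat(src0, src1, dst, wdata, wsize);

        free(wdata);
    }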
@@ -410,13 +410,17 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {

  #if !defined(__ARM_FEATURE_DOTPROD)

- inline static int32x4_t vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
+ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
  const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
  const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));

  return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
  }

+ #else
+
+ #define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
+
  #endif

  #endif
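With this shim, the quantized dot-product kernels below can call ggml_vdotq_s32 unconditionally: on ARM builds without __ARM_FEATURE_DOTPROD it expands to the widening-multiply emulation above, otherwise it maps straight to the vdotq_s32 intrinsic. A standalone caller sketch (AArch64 NEON; dot_i8x16 is illustrative, not gem code):

    #include <arm_neon.h>
    /* assumes the ggml_vdotq_s32 shim shown above is in scope */

    /* Sum of products of two 16-byte int8 vectors, written the same way on every ARM build. */
    static int32_t dot_i8x16(const int8_t * a, const int8_t * b) {
        const int8x16_t va  = vld1q_s8(a);
        const int8x16_t vb  = vld1q_s8(b);
        const int32x4_t acc = ggml_vdotq_s32(vdupq_n_s32(0), va, vb);
        return vaddvq_s32(acc);   /* horizontal add of the four 32-bit lanes */
    }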
@@ -2481,8 +2485,8 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

  // dot product into int32x4_t
- const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
- const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
+ const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
+ const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);

  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
@@ -2769,8 +2773,8 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

  // dot product into int32x4_t
- const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
- const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
+ const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
+ const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);

  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
@@ -2936,11 +2940,11 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
- vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
- vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
- vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
- vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
  }

  *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -3228,11 +3232,11 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
- vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
- vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
- vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
- vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
  }

  *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
@@ -3483,12 +3487,12 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
  const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);

  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
- vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
- vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+ ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
+ ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));

  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
- vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
- vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+ ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
+ ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
  }

  *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -3598,8 +3602,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
  // We use this macro instead of a function call because for some reason
  // the code runs 2-3% slower, even if the function is declared inline
  #define MULTIPLY_ACCUM_WITH_SCALE(index)\
- isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
- isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];

  #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
  q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
@@ -3973,10 +3977,10 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
  q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
  q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));

- isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
- isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
- isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
- isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
+ isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
+ isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
+ isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
+ isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];

  sum += d * (isum1 + isum2);
  }
@@ -4256,10 +4260,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
  q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
  q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));

- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];

  scale += 4;
@@ -4273,10 +4277,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
  q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
  q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));

- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];

  scale += 4;
@@ -4757,10 +4761,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
  q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
  q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));

- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];

  sum += d * isum;
@@ -5109,14 +5113,14 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
  q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
  q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));

- const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+ const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
  sumi1 += vaddvq_s32(p1) * scales[2*j+0];

  q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
  q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
  q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));

- const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+ const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);

  sumi2 += vaddvq_s32(p2) * scales[2*j+1];
  }
@@ -5449,13 +5453,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
  q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
  q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));

- const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+ const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
  const int32_t sumi1 = vaddvq_s32(p1) * scales[0];

  q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
  q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));

- const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
+ const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
  const int32_t sumi2 = vaddvq_s32(p2) * scales[1];

  sumf += d * (sumi1 + sumi2);
@@ -5722,8 +5726,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
  q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
  q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));

- sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
- sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
+ sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
+ sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
  }

  sumf += d * sumi - dmin * sumi_mins;
@@ -6112,10 +6116,10 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
  q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
  q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));

- int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
- int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
- int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
- int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
+ int32_t sumi1 = sc[0] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
+ int32_t sumi2 = sc[1] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
+ int32_t sumi3 = sc[2] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
+ int32_t sumi4 = sc[3] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));

  sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
  }
@@ -6399,10 +6403,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
  q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
  q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));

- isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];

  scale += 4;
@@ -6426,10 +6430,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
  q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
  q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));

- isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
  scale += 4;
  }
  //sum += isum * d_all * y[i].d;
@@ -6816,10 +6820,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
  q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
  q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);

- isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];

  sum += isum * d_all * y[i].d;
@@ -4766,8 +4766,11 @@ struct ggml_tensor * ggml_get_rows(
  }

  // TODO: implement non F32 return
- //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
- struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
+ enum ggml_type type = GGML_TYPE_F32;
+ if (a->type == GGML_TYPE_I32) {
+ type = a->type;
+ }
+ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);

  result->op = GGML_OP_GET_ROWS;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
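With this change, ggml_get_rows no longer forces an F32 result when the source tensor is I32; the gathered rows keep their integer type. A minimal sketch using the standard ggml graph-building call (gather_i32_rows is an illustrative wrapper, not gem code):

    #include "ggml.h"

    /* Sketch: gather rows out of an I32 table; after the change above the result is also I32. */
    static struct ggml_tensor * gather_i32_rows(struct ggml_context * ctx,
                                                struct ggml_tensor  * table,    /* GGML_TYPE_I32 matrix  */
                                                struct ggml_tensor  * row_ids)  /* GGML_TYPE_I32 indices */
    {
        struct ggml_tensor * rows = ggml_get_rows(ctx, table, row_ids);
        /* previously rows->type was always GGML_TYPE_F32; now it matches table->type for I32 inputs */
        return rows;
    }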
@@ -6938,14 +6941,165 @@ static void ggml_compute_forward_dup_f32(
  }
  }

- static void ggml_compute_forward_dup(
+ // A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
+ static void ggml_compute_forward_dup_bytes(
  const struct ggml_compute_params * params,
  const struct ggml_tensor * src0,
  struct ggml_tensor * dst) {
- if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+ GGML_ASSERT(src0->type == dst->type);
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
  ggml_compute_forward_dup_same_cont(params, src0, dst);
  return;
  }
+
+ GGML_TENSOR_UNARY_OP_LOCALS;
+
+ const size_t type_size = ggml_type_size(src0->type);
+ const int ith = params->ith; // thread index
+ const int nth = params->nth; // number of threads
+
+
+ // parallelize by rows
+ const int nr = ne01;
+ // number of rows per thread
+ const int dr = (nr + nth - 1) / nth;
+ // row range for this thread
+ const int ir0 = dr * ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ if (src0->type == dst->type &&
+ ne00 == ne0 &&
+ nb00 == type_size && nb0 == type_size) {
+ // copy by rows
+ const size_t rs = ne00 * type_size;
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ memcpy(
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+ rs);
+ }
+ }
+ }
+ return;
+ }
+
+ if (ggml_is_contiguous(dst)) {
+ size_t id = 0;
+ char * dst_ptr = (char *) dst->data;
+ const size_t rs = ne00 * type_size;
+
+ if (nb00 == type_size) {
+ // src0 is contigous on first dimension, copy by rows
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ id += rs * ir0;
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+ memcpy(dst_ptr + id, src0_ptr, rs);
+ id += rs;
+ }
+ id += rs * (ne01 - ir1);
+ }
+ }
+ } else {
+ //printf("%s: this is not optimal - fix me\n", __func__);
+
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ id += rs * ir0;
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
+ memcpy(dst_ptr + id, src0_ptr, type_size);
+
+ id += type_size;
+ }
+ }
+ id += rs * (ne01 - ir1);
+ }
+ }
+ }
+
+ return;
+ }
+
+ // dst counters
+
+ int64_t i10 = 0;
+ int64_t i11 = 0;
+ int64_t i12 = 0;
+ int64_t i13 = 0;
+
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ i10 += ne00 * ir0;
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+
+ memcpy(dst_ptr, src0_ptr, type_size);
+
+ if (++i10 == ne0) {
+ i10 = 0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ i10 += ne00 * (ne01 - ir1);
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ static void ggml_compute_forward_dup(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ if (src0->type == dst->type) {
+ ggml_compute_forward_dup_bytes(params, src0, dst);
+ return;
+ }
+
  switch (src0->type) {
  case GGML_TYPE_F16:
  {
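The new ggml_compute_forward_dup_bytes splits its work across threads with the same rows-per-thread arithmetic used elsewhere in ggml.c (dr rows per thread, thread ith handling the half-open range [ir0, ir1)). A standalone sketch of just that partitioning (row_range is illustrative, not gem code):

    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    /* nr rows are split into nth contiguous chunks; thread ith handles [*ir0, *ir1). */
    static void row_range(int nr, int nth, int ith, int * ir0, int * ir1) {
        const int dr = (nr + nth - 1) / nth; /* rows per thread, rounded up */
        *ir0 = dr * ith;
        *ir1 = MIN(*ir0 + dr, nr);
    }

    int main(void) {
        int ir0, ir1;
        for (int ith = 0; ith < 4; ith++) {
            row_range(10, 4, ith, &ir0, &ir1);
            printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1); /* [0,3) [3,6) [6,9) [9,10) */
        }
        return 0;
    }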
@@ -8404,10 +8558,12 @@ static void ggml_compute_forward_repeat(
  struct ggml_tensor * dst) {
  switch (src0->type) {
  case GGML_TYPE_F16:
+ case GGML_TYPE_I16:
  {
  ggml_compute_forward_repeat_f16(params, src0, dst);
  } break;
  case GGML_TYPE_F32:
+ case GGML_TYPE_I32:
  {
  ggml_compute_forward_repeat_f32(params, src0, dst);
  } break;
@@ -8550,6 +8706,7 @@ static void ggml_compute_forward_concat(
  struct ggml_tensor* dst) {
  switch (src0->type) {
  case GGML_TYPE_F32:
+ case GGML_TYPE_I32:
  {
  ggml_compute_forward_concat_f32(params, src0, src1, dst);
  } break;
@@ -9687,7 +9844,7 @@ static void ggml_compute_forward_mul_mat(
  const size_t row_size = ggml_row_size(vec_dot_type, ne10);

  assert(params->wsize >= ne11*ne12*ne13*row_size);
- assert(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);

  for (int64_t i13 = 0; i13 < ne13; ++i13) {
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
@@ -10674,6 +10831,7 @@ static void ggml_compute_forward_get_rows(
  ggml_compute_forward_get_rows_f16(params, src0, src1, dst);
  } break;
  case GGML_TYPE_F32:
+ case GGML_TYPE_I32:
  {
  ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
  } break;
@@ -19638,6 +19796,14 @@ int ggml_cpu_has_avx(void) {
  #endif
  }

+ int ggml_cpu_has_avx_vnni(void) {
+ #if defined(__AVXVNNI__)
+ return 1;
+ #else
+ return 0;
+ #endif
+ }
+
  int ggml_cpu_has_avx2(void) {
  #if defined(__AVX2__)
  return 1;
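The new ggml_cpu_has_avx_vnni probe sits alongside the existing feature queries and, like them, reports whether the binary was compiled with that instruction set. A small usage sketch (the declarations are added to ggml.h in the next hunk):

    #include <stdio.h>
    #include "ggml.h"

    /* Print a few of ggml's compile-time CPU feature flags, including the new AVX-VNNI probe. */
    int main(void) {
        printf("AVX      = %d\n", ggml_cpu_has_avx());
        printf("AVX_VNNI = %d\n", ggml_cpu_has_avx_vnni());
        printf("AVX2     = %d\n", ggml_cpu_has_avx2());
        return 0;
    }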
@@ -2198,6 +2198,7 @@ extern "C" {
  //

  GGML_API int ggml_cpu_has_avx        (void);
+ GGML_API int ggml_cpu_has_avx_vnni   (void);
  GGML_API int ggml_cpu_has_avx2       (void);
  GGML_API int ggml_cpu_has_avx512     (void);
  GGML_API int ggml_cpu_has_avx512_vbmi(void);