llama_cpp 0.10.3 → 0.10.4

@@ -6,19 +6,19 @@
  extern "C" {
  #endif

- void ggml_cl_init(void);
+ GGML_API void ggml_cl_init(void);

- void ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
- bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
- size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
- void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
+ GGML_API void ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
+ GGML_API bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
+ GGML_API size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
+ GGML_API void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);

- void * ggml_cl_host_malloc(size_t size);
- void ggml_cl_host_free(void * ptr);
+ GGML_API void * ggml_cl_host_malloc(size_t size);
+ GGML_API void ggml_cl_host_free(void * ptr);

- void ggml_cl_free_data(const struct ggml_tensor* tensor);
+ GGML_API void ggml_cl_free_data(const struct ggml_tensor* tensor);

- void ggml_cl_transform_tensor(void * data, struct ggml_tensor * tensor);
+ GGML_API void ggml_cl_transform_tensor(void * data, struct ggml_tensor * tensor);

  #ifdef __cplusplus
  }
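The hunk above only tags the OpenCL backend's entry points with GGML_API; the macro itself is defined elsewhere in the ggml headers. As a rough sketch of the usual shape of such an export macro (not copied from ggml.h; the MYLIB_* names and guards here are assumptions), it expands to an export/visibility attribute for shared builds and to nothing or dllimport otherwise:

    /* Illustrative only: the general pattern of an export macro like GGML_API. */
    #ifdef _WIN32
    #  ifdef MYLIB_BUILD_SHARED
    #    define MYLIB_API __declspec(dllexport)
    #  else
    #    define MYLIB_API __declspec(dllimport)
    #  endif
    #else
    #  define MYLIB_API __attribute__((visibility("default")))
    #endif

    MYLIB_API void mylib_init(void);   /* symbol is exported from the shared object */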
@@ -410,13 +410,17 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {

  #if !defined(__ARM_FEATURE_DOTPROD)

- inline static int32x4_t vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
+ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
  const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
  const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));

  return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
  }

+ #else
+
+ #define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
+
  #endif

  #endif
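The rename above gives ggml its own ggml_vdotq_s32 wrapper: when the compiler defines __ARM_FEATURE_DOTPROD it is just a macro over the vdotq_s32 intrinsic, otherwise the inline fallback emulates it with widening multiplies and pairwise adds. Semantically, each of the four int32 lanes accumulates the dot product of four signed bytes. A plain scalar reference of that contract (illustrative only, not ggml code):

    #include <stdint.h>

    /* Scalar model of what ggml_vdotq_s32 / vdotq_s32 compute per 128-bit register:
       acc[i] += a[4*i+0]*b[4*i+0] + ... + a[4*i+3]*b[4*i+3], for i = 0..3. */
    static void vdot_s32_ref(int32_t acc[4], const int8_t a[16], const int8_t b[16]) {
        for (int i = 0; i < 4; ++i) {
            int32_t sum = 0;
            for (int j = 0; j < 4; ++j) {
                sum += (int32_t) a[4*i + j] * (int32_t) b[4*i + j];
            }
            acc[i] += sum;
        }
    }

The remaining quantization hunks below are the mechanical call-site rename from vdotq_s32 to ggml_vdotq_s32.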
@@ -2481,8 +2485,8 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

  // dot product into int32x4_t
- const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
- const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
+ const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
+ const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);

  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
@@ -2769,8 +2773,8 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

  // dot product into int32x4_t
- const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
- const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
+ const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
+ const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);

  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
@@ -2936,11 +2940,11 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
- vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
- vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
- vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
- vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
  }

  *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -3228,11 +3232,11 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
- vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
- vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
- vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
- vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
  }

  *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
@@ -3483,12 +3487,12 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
  const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);

  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
- vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
- vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+ ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
+ ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));

  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
- vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
- vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+ ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
+ ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
  }

  *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -3598,8 +3602,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
  // We use this macro instead of a function call because for some reason
  // the code runs 2-3% slower, even if the function is declared inline
  #define MULTIPLY_ACCUM_WITH_SCALE(index)\
- isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
- isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];

  #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
  q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
@@ -3973,10 +3977,10 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
  q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
  q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));

- isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
- isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
- isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
- isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
+ isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
+ isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
+ isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
+ isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];

  sum += d * (isum1 + isum2);
  }
@@ -4256,10 +4260,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
  q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
  q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));

- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];

  scale += 4;

@@ -4273,10 +4277,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
  q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
  q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));

- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];

  scale += 4;

@@ -4757,10 +4761,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
  q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
  q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));

- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];

  sum += d * isum;

@@ -5109,14 +5113,14 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
  q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
  q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));

- const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+ const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
  sumi1 += vaddvq_s32(p1) * scales[2*j+0];

  q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
  q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
  q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));

- const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+ const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);

  sumi2 += vaddvq_s32(p2) * scales[2*j+1];
  }
@@ -5449,13 +5453,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
  q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
  q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));

- const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+ const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
  const int32_t sumi1 = vaddvq_s32(p1) * scales[0];

  q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
  q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));

- const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
+ const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
  const int32_t sumi2 = vaddvq_s32(p2) * scales[1];

  sumf += d * (sumi1 + sumi2);
@@ -5722,8 +5726,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
  q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
  q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));

- sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
- sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
+ sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
+ sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
  }

  sumf += d * sumi - dmin * sumi_mins;
@@ -6112,10 +6116,10 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
  q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
  q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));

- int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
- int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
- int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
- int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
+ int32_t sumi1 = sc[0] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
+ int32_t sumi2 = sc[1] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
+ int32_t sumi3 = sc[2] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
+ int32_t sumi4 = sc[3] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));

  sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
  }
@@ -6399,10 +6403,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
  q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
  q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));

- isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];

  scale += 4;

@@ -6426,10 +6430,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
  q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
  q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));

- isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
  scale += 4;
  }
  //sum += isum * d_all * y[i].d;
@@ -6816,10 +6820,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
  q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
  q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);

- isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];

  sum += isum * d_all * y[i].d;

@@ -4766,8 +4766,11 @@ struct ggml_tensor * ggml_get_rows(
  }

  // TODO: implement non F32 return
- //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
- struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
+ enum ggml_type type = GGML_TYPE_F32;
+ if (a->type == GGML_TYPE_I32) {
+ type = a->type;
+ }
+ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);

  result->op = GGML_OP_GET_ROWS;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
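With the change above, ggml_get_rows keeps an I32 source as I32 instead of forcing the result to F32 (other source types still produce an F32 result). A minimal sketch of how that surfaces through the public API, assuming the usual ggml context setup; the sizes here are arbitrary and the graph is built but not computed:

    #include "ggml.h"

    /* Build a get_rows node over an I32 matrix; with the change above the
       result tensor keeps GGML_TYPE_I32 instead of being forced to F32. */
    static void get_rows_i32_example(void) {
        struct ggml_init_params ip = { 16*1024*1024, NULL, false };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * table = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 8, 32); /* 32 rows of 8 ints */
        struct ggml_tensor * ids   = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);     /* 4 row indices */
        struct ggml_tensor * rows  = ggml_get_rows(ctx, table, ids);

        /* rows->type is now GGML_TYPE_I32; before 0.10.4 it was always GGML_TYPE_F32. */
        (void) rows;
        ggml_free(ctx);
    }

The repeat, concat, and get_rows compute paths further down gain the matching I16/I32 cases so such tensors can actually flow through the graph.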
@@ -6938,14 +6941,165 @@ static void ggml_compute_forward_dup_f32(
  }
  }

- static void ggml_compute_forward_dup(
+ // A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
+ static void ggml_compute_forward_dup_bytes(
  const struct ggml_compute_params * params,
  const struct ggml_tensor * src0,
  struct ggml_tensor * dst) {
- if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+ GGML_ASSERT(src0->type == dst->type);
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
  ggml_compute_forward_dup_same_cont(params, src0, dst);
  return;
  }
+
+ GGML_TENSOR_UNARY_OP_LOCALS;
+
+ const size_t type_size = ggml_type_size(src0->type);
+ const int ith = params->ith; // thread index
+ const int nth = params->nth; // number of threads
+
+
+ // parallelize by rows
+ const int nr = ne01;
+ // number of rows per thread
+ const int dr = (nr + nth - 1) / nth;
+ // row range for this thread
+ const int ir0 = dr * ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ if (src0->type == dst->type &&
+ ne00 == ne0 &&
+ nb00 == type_size && nb0 == type_size) {
+ // copy by rows
+ const size_t rs = ne00 * type_size;
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ memcpy(
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+ rs);
+ }
+ }
+ }
+ return;
+ }
+
+ if (ggml_is_contiguous(dst)) {
+ size_t id = 0;
+ char * dst_ptr = (char *) dst->data;
+ const size_t rs = ne00 * type_size;
+
+ if (nb00 == type_size) {
+ // src0 is contigous on first dimension, copy by rows
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ id += rs * ir0;
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+ memcpy(dst_ptr + id, src0_ptr, rs);
+ id += rs;
+ }
+ id += rs * (ne01 - ir1);
+ }
+ }
+ } else {
+ //printf("%s: this is not optimal - fix me\n", __func__);
+
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ id += rs * ir0;
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
+ memcpy(dst_ptr + id, src0_ptr, type_size);
+
+ id += type_size;
+ }
+ }
+ id += rs * (ne01 - ir1);
+ }
+ }
+ }
+
+ return;
+ }
+
+ // dst counters
+
+ int64_t i10 = 0;
+ int64_t i11 = 0;
+ int64_t i12 = 0;
+ int64_t i13 = 0;
+
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ i10 += ne00 * ir0;
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+
+ memcpy(dst_ptr, src0_ptr, type_size);
+
+ if (++i10 == ne0) {
+ i10 = 0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ i10 += ne00 * (ne01 - ir1);
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ static void ggml_compute_forward_dup(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ if (src0->type == dst->type) {
+ ggml_compute_forward_dup_bytes(params, src0, dst);
+ return;
+ }
+
  switch (src0->type) {
  case GGML_TYPE_F16:
  {
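The new ggml_compute_forward_dup_bytes above handles every same-type copy with memcpy instead of routing through the per-type float conversion paths, and ggml_compute_forward_dup now dispatches to it whenever src0->type == dst->type. Stripped of ggml's threading and stride bookkeeping, the core idea is just a strided row copy. A standalone sketch of that idea (illustrative only, not ggml code; names are hypothetical):

    #include <stdint.h>
    #include <string.h>

    /* Copy a 3-D block row by row when rows are contiguous in both buffers.
       nb* / db* are byte strides per dimension, row_bytes = ne0 * element_size. */
    static void dup_bytes_rows(char * dst, const char * src,
                               int64_t ne1, int64_t ne2, int64_t ne3,
                               size_t row_bytes,
                               size_t nb1, size_t nb2, size_t nb3,
                               size_t db1, size_t db2, size_t db3) {
        for (int64_t i3 = 0; i3 < ne3; ++i3) {
            for (int64_t i2 = 0; i2 < ne2; ++i2) {
                for (int64_t i1 = 0; i1 < ne1; ++i1) {
                    memcpy(dst + i1*db1 + i2*db2 + i3*db3,
                           src + i1*nb1 + i2*nb2 + i3*nb3,
                           row_bytes);
                }
            }
        }
    }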
@@ -8404,10 +8558,12 @@ static void ggml_compute_forward_repeat(
  struct ggml_tensor * dst) {
  switch (src0->type) {
  case GGML_TYPE_F16:
+ case GGML_TYPE_I16:
  {
  ggml_compute_forward_repeat_f16(params, src0, dst);
  } break;
  case GGML_TYPE_F32:
+ case GGML_TYPE_I32:
  {
  ggml_compute_forward_repeat_f32(params, src0, dst);
  } break;
@@ -8550,6 +8706,7 @@ static void ggml_compute_forward_concat(
  struct ggml_tensor* dst) {
  switch (src0->type) {
  case GGML_TYPE_F32:
+ case GGML_TYPE_I32:
  {
  ggml_compute_forward_concat_f32(params, src0, src1, dst);
  } break;
@@ -9687,7 +9844,7 @@ static void ggml_compute_forward_mul_mat(
  const size_t row_size = ggml_row_size(vec_dot_type, ne10);

  assert(params->wsize >= ne11*ne12*ne13*row_size);
- assert(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);

  for (int64_t i13 = 0; i13 < ne13; ++i13) {
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
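Swapping assert for GGML_ASSERT keeps the check alive in release builds: the C assert macro compiles away under NDEBUG, whereas ggml's assertion macro reports the failure and aborts unconditionally. A hedged sketch of that always-on pattern (the real GGML_ASSERT in ggml.h may differ in detail; MY_ASSERT is an illustrative stand-in):

    #include <stdio.h>
    #include <stdlib.h>

    /* Always-on assertion in the spirit of GGML_ASSERT: not disabled by NDEBUG. */
    #define MY_ASSERT(x)                                             \
        do {                                                         \
            if (!(x)) {                                              \
                fprintf(stderr, "%s:%d: assertion failed: %s\n",     \
                        __FILE__, __LINE__, #x);                     \
                abort();                                             \
            }                                                        \
        } while (0)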
@@ -10674,6 +10831,7 @@ static void ggml_compute_forward_get_rows(
  ggml_compute_forward_get_rows_f16(params, src0, src1, dst);
  } break;
  case GGML_TYPE_F32:
+ case GGML_TYPE_I32:
  {
  ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
  } break;
@@ -19638,6 +19796,14 @@ int ggml_cpu_has_avx(void) {
  #endif
  }

+ int ggml_cpu_has_avx_vnni(void) {
+ #if defined(__AVXVNNI__)
+ return 1;
+ #else
+ return 0;
+ #endif
+ }
+
  int ggml_cpu_has_avx2(void) {
  #if defined(__AVX2__)
  return 1;
@@ -2198,6 +2198,7 @@ extern "C" {
  //

  GGML_API int ggml_cpu_has_avx        (void);
+ GGML_API int ggml_cpu_has_avx_vnni   (void);
  GGML_API int ggml_cpu_has_avx2       (void);
  GGML_API int ggml_cpu_has_avx512     (void);
  GGML_API int ggml_cpu_has_avx512_vbmi(void);
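The new ggml_cpu_has_avx_vnni() follows the existing feature-probe functions: it returns 1 only when the translation unit was compiled with __AVXVNNI__ defined, i.e. it reports compile-time ISA support rather than probing the CPU at run time. A small usage sketch against the public header:

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        /* Each probe returns 1 or 0 depending on the compile-time target flags. */
        printf("AVX:      %d\n", ggml_cpu_has_avx());
        printf("AVX-VNNI: %d\n", ggml_cpu_has_avx_vnni());
        printf("AVX2:     %d\n", ggml_cpu_has_avx2());
        return 0;
    }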