llama_cpp 0.6.0 → 0.7.1

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -162,40 +162,16 @@ typedef void * thread_ret_t;
162
162
 
163
163
  #define GGML_PRINT(...) printf(__VA_ARGS__)
164
164
 
165
+ //
166
+ // end of logging block
167
+ //
168
+
165
169
  #ifdef GGML_USE_ACCELERATE
166
170
  // uncomment to use vDSP for soft max computation
167
171
  // note: not sure if it is actually faster
168
172
  //#define GGML_SOFT_MAX_ACCELERATE
169
173
  #endif
170
174
 
171
- //
172
- // logging
173
- //
174
-
175
- #if (GGML_DEBUG >= 1)
176
- #define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
177
- #else
178
- #define GGML_PRINT_DEBUG(...)
179
- #endif
180
-
181
- #if (GGML_DEBUG >= 5)
182
- #define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
183
- #else
184
- #define GGML_PRINT_DEBUG_5(...)
185
- #endif
186
-
187
- #if (GGML_DEBUG >= 10)
188
- #define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
189
- #else
190
- #define GGML_PRINT_DEBUG_10(...)
191
- #endif
192
-
193
- #define GGML_PRINT(...) printf(__VA_ARGS__)
194
-
195
- //
196
- // end of logging block
197
- //
198
-
199
175
  #if defined(_MSC_VER) || defined(__MINGW32__)
200
176
  #define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
201
177
  #define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr)
@@ -1032,8 +1008,8 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r
1032
1008
  y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
1033
1009
 
1034
1010
  // get the 5-th bit and store it in qh at the right position
1035
- qh |= ((xi0 & 0x10) >> 4) << (j + 0);
1036
- qh |= ((xi1 & 0x10) >> 4) << (j + qk/2);
1011
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
1012
+ qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
1037
1013
  }
1038
1014
 
1039
1015
  memcpy(&y[i].qh, &qh, sizeof(qh));
@@ -1080,8 +1056,8 @@ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * r
1080
1056
  y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
1081
1057
 
1082
1058
  // get the 5-th bit and store it in qh at the right position
1083
- qh |= ((xi0 & 0x10) >> 4) << (j + 0);
1084
- qh |= ((xi1 & 0x10) >> 4) << (j + qk/2);
1059
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
1060
+ qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
1085
1061
  }
1086
1062
 
1087
1063
  memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
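
Note on the two quantize_row_q5_x hunks above: the only change is the unsigned suffix on the nibble mask. (xi1 & 0x10) >> 4 has type int, and shifting that left by j + qk/2 (up to 31 when qk is 32) can move a 1 into the sign bit of a signed int, which is undefined behavior in C; with 0x10u the whole expression stays unsigned before it is OR-ed into the uint32_t qh. A minimal sketch of the packing pattern, assuming the block size of 32 used by ggml's Q5 formats (the helper name is illustrative):

    #include <stdint.h>

    // Packs the 5th bit of each pre-quantized nibble value into qh.
    // With plain 0x10 the shifted operand is a signed int and the
    // j + qk/2 == 31 case is undefined; 0x10u keeps it unsigned.
    static uint32_t pack_fifth_bits(const uint8_t * xi0, const uint8_t * xi1, int qk) {
        uint32_t qh = 0;
        for (int j = 0; j < qk/2; ++j) {
            qh |= ((xi0[j] & 0x10u) >> 4) << (j + 0);
            qh |= ((xi1[j] & 0x10u) >> 4) << (j + qk/2);
        }
        return qh;
    }
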
@@ -1272,6 +1248,33 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1272
1248
  _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
1273
1249
  #endif
1274
1250
  }
1251
+ #elif defined(__riscv_v_intrinsic)
1252
+
1253
+ size_t vl = __riscv_vsetvl_e32m4(QK8_0);
1254
+
1255
+ for (int i = 0; i < nb; i++) {
1256
+ // load elements
1257
+ vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
1258
+
1259
+ vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
1260
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
1261
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
1262
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
1263
+
1264
+ const float d = amax / ((1 << 7) - 1);
1265
+ const float id = d ? 1.0f/d : 0.0f;
1266
+
1267
+ y[i].d = GGML_FP32_TO_FP16(d);
1268
+
1269
+ vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
1270
+
1271
+ // convert to integer
1272
+ vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
1273
+ vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
1274
+
1275
+ // store result
1276
+ __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
1277
+ }
1275
1278
  #else
1276
1279
  // scalar
1277
1280
  quantize_row_q8_0_reference(x, y, k);
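
Note on the new __riscv_v_intrinsic branch of quantize_row_q8_0: a vector max-reduction over the absolute values gives the per-block amax, the block is scaled by 127/amax, and the scaled values are narrowed to int8 in two steps (f32 -> i16 -> i8). A scalar sketch of the same per-block math, essentially what quantize_row_q8_0_reference computes (the FP16 storage of d is omitted here):

    #include <math.h>
    #include <stdint.h>

    // One 32-element block: d = amax/127, each value stored as round(x/d).
    static void quantize_block_q8_0_sketch(const float * x, int8_t * qs, float * d_out, int qk) {
        float amax = 0.0f;
        for (int j = 0; j < qk; ++j) {
            const float ax = fabsf(x[j]);
            if (ax > amax) amax = ax;
        }
        const float d  = amax / 127.0f;
        const float id = d ? 1.0f/d : 0.0f;
        for (int j = 0; j < qk; ++j) {
            qs[j] = (int8_t) roundf(x[j]*id);
        }
        *d_out = d;
    }
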
@@ -1490,6 +1493,41 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
1490
1493
  _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
1491
1494
  #endif
1492
1495
  }
1496
+ #elif defined(__riscv_v_intrinsic)
1497
+
1498
+ size_t vl = __riscv_vsetvl_e32m4(QK8_1);
1499
+
1500
+ for (int i = 0; i < nb; i++) {
1501
+ // load elements
1502
+ vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
1503
+
1504
+ vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
1505
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl);
1506
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
1507
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
1508
+
1509
+ const float d = amax / ((1 << 7) - 1);
1510
+ const float id = d ? 1.0f/d : 0.0f;
1511
+
1512
+ y[i].d = d;
1513
+
1514
+ vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
1515
+
1516
+ // convert to integer
1517
+ vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
1518
+ vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
1519
+
1520
+ // store result
1521
+ __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
1522
+
1523
+ // compute sum for y[i].s
1524
+ vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
1525
+ vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
1526
+
1527
+ // set y[i].s
1528
+ int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
1529
+ y[i].s = sum*d;
1530
+ }
1493
1531
  #else
1494
1532
  // scalar
1495
1533
  quantize_row_q8_1_reference(x, y, k);
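
The quantize_row_q8_1 counterpart differs from the Q8_0 path only in the tail, where a widening reduction (vwredsum, i8 -> i16) computes the block sum and y[i].s is set to sum*d. That cached value is what the Q4_1/Q5_1 dot products further down fold in through the x[i].m * y[i].s term. A scalar sketch of the added tail (32 * 127 fits comfortably in int16):

    #include <stdint.h>

    static float block_sum_q8_1_sketch(const int8_t * qs, int qk, float d) {
        int16_t sum = 0;
        for (int j = 0; j < qk; ++j) {
            sum += qs[j];   // widening sum of the quantized values
        }
        return d * (float) sum;
    }
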
@@ -2662,30 +2700,32 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2662
2700
  size_t vl = __riscv_vsetvl_e8m1(qk/2);
2663
2701
 
2664
2702
  for (int i = 0; i < nb; i++) {
2665
- vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
2703
+ // load elements
2704
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
2666
2705
 
2667
- vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
2668
- vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
2706
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
2707
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
2669
2708
 
2670
- vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
2671
- vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
2709
+ // mask and store lower part of x, and then upper part
2710
+ vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
2711
+ vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
2672
2712
 
2673
- vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
2674
- vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
2713
+ vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
2714
+ vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
2675
2715
 
2676
- vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl);
2677
- vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl);
2716
+ // subtract offset
2717
+ vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
2718
+ vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
2678
2719
 
2679
- vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
2680
- vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
2720
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
2721
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
2681
2722
 
2682
2723
  vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
2683
2724
 
2684
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
2685
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
2725
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
2726
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
2686
2727
 
2687
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
2688
- sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
2728
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
2689
2729
 
2690
2730
  sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
2691
2731
  }
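
Two things change in the RISC-V ggml_vec_dot_q4_0_q8_0 path above: the 16-byte half-blocks now use fractional-LMUL types (u8mf2/i8mf2, with i16m1 widened products instead of i16m2), reducing register-group pressure, and the second widening reduction is seeded with the result of the first (vs1 is passed as the accumulator for vs2), so a single scalar extraction yields the block total. For reference, a scalar sketch of the block dot product this code vectorizes; the caller then scales the result by d_x*d_y:

    #include <stdint.h>

    static int dot_q4_0_q8_0_block_sketch(const uint8_t * xqs, const int8_t * yqs, int qk) {
        int sumi = 0;
        for (int j = 0; j < qk/2; ++j) {
            const int v0 = (xqs[j] & 0x0F) - 8;  // low nibble, recentered
            const int v1 = (xqs[j] >>   4) - 8;  // high nibble, recentered
            sumi += v0*yqs[j] + v1*yqs[j + qk/2];
        }
        return sumi;
    }
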
@@ -2823,27 +2863,28 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
2823
2863
  size_t vl = __riscv_vsetvl_e8m1(qk/2);
2824
2864
 
2825
2865
  for (int i = 0; i < nb; i++) {
2826
- vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
2866
+ // load elements
2867
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
2827
2868
 
2828
- vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
2829
- vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
2869
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
2870
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
2830
2871
 
2831
- vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
2832
- vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
2872
+ // mask and store lower part of x, and then upper part
2873
+ vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
2874
+ vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
2833
2875
 
2834
- vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
2835
- vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
2876
+ vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
2877
+ vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
2836
2878
 
2837
- vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
2838
- vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
2879
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
2880
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
2839
2881
 
2840
2882
  vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
2841
2883
 
2842
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
2843
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
2884
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
2885
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
2844
2886
 
2845
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
2846
- sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
2887
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
2847
2888
 
2848
2889
  sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
2849
2890
  }
@@ -3088,66 +3129,61 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
3088
3129
 
3089
3130
  uint32_t qh;
3090
3131
 
3091
- // These temp values are for masking and shift operations
3092
- uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
3093
- uint32_t temp_2[16] = {0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80,
3094
- 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000};
3095
-
3096
3132
  size_t vl = __riscv_vsetvl_e8m1(qk/2);
3097
3133
 
3134
+ // These temporary registers are for masking and shift operations
3135
+ vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
3136
+ vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
3137
+
3138
+ vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
3139
+ vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
3140
+
3098
3141
  for (int i = 0; i < nb; i++) {
3099
3142
  memcpy(&qh, x[i].qh, sizeof(uint32_t));
3100
3143
 
3101
- // temporary registers
3102
- vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_2, vl);
3103
- vuint32m4_t vt_2 = __riscv_vle32_v_u32m4(temp_1, vl);
3104
- vuint32m4_t vt_3 = __riscv_vsll_vx_u32m4(vt_1, 16, vl);
3105
- vuint32m4_t vt_4 = __riscv_vadd_vx_u32m4(vt_2, 12, vl);
3106
-
3107
3144
  // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
3108
- vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(vt_1, qh, vl);
3109
- vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(xha_0, vt_2, vl);
3110
- vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl);
3145
+ vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
3146
+ vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
3147
+ vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
3111
3148
 
3112
3149
  // ((qh & (1u << (j + 16))) >> (j + 12));
3113
- vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(vt_3, qh, vl);
3114
- vuint32m4_t xhl_1 = __riscv_vsrl_vv_u32m4(xha_1, vt_4, vl);
3150
+ vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
3151
+ vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
3115
3152
 
3116
3153
  // narrowing
3117
- vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xhl_0, vl);
3118
- vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl);
3154
+ vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
3155
+ vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
3119
3156
 
3120
- vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xhl_1, vl);
3121
- vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl);
3157
+ vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
3158
+ vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
3122
3159
 
3123
3160
  // load
3124
- vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
3161
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
3125
3162
 
3126
- vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
3127
- vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
3163
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
3164
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
3128
3165
 
3129
- vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
3130
- vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
3166
+ vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
3167
+ vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
3131
3168
 
3132
- vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl);
3133
- vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl);
3169
+ vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
3170
+ vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
3134
3171
 
3135
- vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
3136
- vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
3172
+ vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
3173
+ vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
3137
3174
 
3138
- vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 16, vl);
3139
- vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 16, vl);
3175
+ vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
3176
+ vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
3140
3177
 
3141
- vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
3142
- vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
3178
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
3179
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
3143
3180
 
3144
3181
  vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
3145
3182
 
3146
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
3147
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
3183
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
3184
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
3148
3185
 
3149
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
3150
- sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
3186
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
3151
3187
 
3152
3188
  sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
3153
3189
  }
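
Besides the same mf2/chained-reduction changes, this Q5_0 hunk drops the temp_1/temp_2 constant arrays that were reloaded from memory on every block: vid_v generates the lane indices {0, 1, ..., 15} directly in a register, 1 << vid reproduces the per-lane bit masks, and all four helper vectors are now hoisted above the block loop. The scalar meaning of the two vectorized qh expansions, matching the comments in the code (bit j, respectively bit j+16, of qh becomes the 5th bit of the j-th nibble):

    #include <stdint.h>

    static void expand_qh_sketch(uint32_t qh, uint8_t * xh_0, uint8_t * xh_1, int qk) {
        for (int j = 0; j < qk/2; ++j) {
            xh_0[j] = (uint8_t)(((qh & (1u << (j +  0))) >> (j +  0)) << 4);
            xh_1[j] = (uint8_t)( (qh & (1u << (j + 16))) >> (j + 12));
        }
    }
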
@@ -3414,62 +3450,58 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
3414
3450
 
3415
3451
  uint32_t qh;
3416
3452
 
3417
- // These temp values are for shift operations
3418
- uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
3419
-
3420
3453
  size_t vl = __riscv_vsetvl_e8m1(qk/2);
3421
3454
 
3455
+ // temporary registers for shift operations
3456
+ vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
3457
+ vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
3458
+
3422
3459
  for (int i = 0; i < nb; i++) {
3423
3460
  memcpy(&qh, x[i].qh, sizeof(uint32_t));
3424
3461
 
3425
- // temporary registers
3426
- vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_1, vl);
3427
- vuint32m4_t vt_2 = __riscv_vadd_vx_u32m4(vt_1, 12, vl);
3428
-
3429
3462
  // load qh
3430
- vuint32m4_t vqh = __riscv_vmv_v_x_u32m4(qh, vl);
3463
+ vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
3431
3464
 
3432
3465
  // ((qh >> (j + 0)) << 4) & 0x10;
3433
- vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(vqh, vt_1, vl);
3434
- vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl);
3435
- vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(xhl_0, 0x10, vl);
3466
+ vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
3467
+ vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
3468
+ vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
3436
3469
 
3437
3470
  // ((qh >> (j + 12)) ) & 0x10;
3438
- vuint32m4_t xhr_1 = __riscv_vsrl_vv_u32m4(vqh, vt_2, vl);
3439
- vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(xhr_1, 0x10, vl);
3471
+ vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
3472
+ vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
3440
3473
 
3441
3474
  // narrowing
3442
- vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xha_0, vl);
3443
- vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl);
3475
+ vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
3476
+ vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
3444
3477
 
3445
- vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xha_1, vl);
3446
- vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl);
3478
+ vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
3479
+ vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
3447
3480
 
3448
3481
  // load
3449
- vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
3482
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
3450
3483
 
3451
- vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
3452
- vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
3484
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
3485
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
3453
3486
 
3454
- vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
3455
- vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
3487
+ vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
3488
+ vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
3456
3489
 
3457
- vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl);
3458
- vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl);
3490
+ vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
3491
+ vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
3459
3492
 
3460
- vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
3461
- vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
3493
+ vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
3494
+ vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
3462
3495
 
3463
- vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
3464
- vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
3496
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
3497
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
3465
3498
 
3466
3499
  vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
3467
3500
 
3468
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
3469
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
3501
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
3502
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
3470
3503
 
3471
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
3472
- sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
3504
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
3473
3505
 
3474
3506
  sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
3475
3507
  }
@@ -4025,12 +4057,16 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
4025
4057
  "ALIBI",
4026
4058
  "CLAMP",
4027
4059
  "CONV_1D",
4060
+ "CONV_TRANSPOSE_1D",
4028
4061
  "CONV_2D",
4029
4062
  "CONV_TRANSPOSE_2D",
4030
4063
  "POOL_1D",
4031
4064
  "POOL_2D",
4032
4065
  "UPSCALE",
4033
4066
 
4067
+ "CONV_1D_STAGE_0",
4068
+ "CONV_1D_STAGE_1",
4069
+
4034
4070
  "FLASH_ATTN",
4035
4071
  "FLASH_FF",
4036
4072
  "FLASH_ATTN_BACK",
@@ -4056,7 +4092,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
4056
4092
  "CROSS_ENTROPY_LOSS_BACK",
4057
4093
  };
4058
4094
 
4059
- static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
4095
+ static_assert(GGML_OP_COUNT == 71, "GGML_OP_COUNT != 71");
4060
4096
 
4061
4097
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
4062
4098
  "none",
@@ -4107,12 +4143,16 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
4107
4143
  "alibi(x)",
4108
4144
  "clamp(x)",
4109
4145
  "conv_1d(x)",
4146
+ "conv_transpose_1d(x)",
4110
4147
  "conv_2d(x)",
4111
4148
  "conv_transpose_2d(x)",
4112
4149
  "pool_1d(x)",
4113
4150
  "pool_2d(x)",
4114
4151
  "upscale(x)",
4115
4152
 
4153
+ "conv_1d_stage_0(x)",
4154
+ "conv_1d_stage_1(x)",
4155
+
4116
4156
  "flash_attn(x)",
4117
4157
  "flash_ff(x)",
4118
4158
  "flash_attn_back(x)",
@@ -4138,7 +4178,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
4138
4178
  "cross_entropy_loss_back(x,y)",
4139
4179
  };
4140
4180
 
4141
- static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
4181
+ static_assert(GGML_OP_COUNT == 71, "GGML_OP_COUNT != 71");
4142
4182
 
4143
4183
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
4144
4184
 
@@ -4167,7 +4207,10 @@ static void ggml_setup_op_has_task_pass(void) {
4167
4207
  p[GGML_OP_DIAG_MASK_INF ] = true;
4168
4208
  p[GGML_OP_DIAG_MASK_ZERO ] = true;
4169
4209
  p[GGML_OP_CONV_1D ] = true;
4210
+ p[GGML_OP_CONV_1D_STAGE_0 ] = true;
4211
+ p[GGML_OP_CONV_1D_STAGE_1 ] = true;
4170
4212
  p[GGML_OP_CONV_2D ] = true;
4213
+ p[GGML_OP_CONV_TRANSPOSE_1D ] = true;
4171
4214
  p[GGML_OP_CONV_TRANSPOSE_2D ] = true;
4172
4215
  p[GGML_OP_FLASH_ATTN_BACK ] = true;
4173
4216
  p[GGML_OP_CROSS_ENTROPY_LOSS ] = true;
@@ -4884,6 +4927,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
4884
4927
  *result = (struct ggml_tensor) {
4885
4928
  /*.type =*/ type,
4886
4929
  /*.backend =*/ GGML_BACKEND_CPU,
4930
+ /*.buffer =*/ NULL,
4887
4931
  /*.n_dims =*/ n_dims,
4888
4932
  /*.ne =*/ { 1, 1, 1, 1 },
4889
4933
  /*.nb =*/ { 0, 0, 0, 0 },
@@ -5450,6 +5494,39 @@ struct ggml_tensor * ggml_view_tensor(
5450
5494
  return result;
5451
5495
  }
5452
5496
 
5497
+ struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx) {
5498
+ struct ggml_object * obj = ctx->objects_begin;
5499
+
5500
+ char * const mem_buffer = ctx->mem_buffer;
5501
+
5502
+ while (obj != NULL) {
5503
+ if (obj->type == GGML_OBJECT_TENSOR) {
5504
+ return (struct ggml_tensor *)(mem_buffer + obj->offs);
5505
+ }
5506
+
5507
+ obj = obj->next;
5508
+ }
5509
+
5510
+ return NULL;
5511
+ }
5512
+
5513
+ struct ggml_tensor * ggml_get_next_tensor(struct ggml_context * ctx, struct ggml_tensor * tensor) {
5514
+ struct ggml_object * obj = (struct ggml_object *) ((char *)tensor - GGML_OBJECT_SIZE);
5515
+ obj = obj->next;
5516
+
5517
+ char * const mem_buffer = ctx->mem_buffer;
5518
+
5519
+ while (obj != NULL) {
5520
+ if (obj->type == GGML_OBJECT_TENSOR) {
5521
+ return (struct ggml_tensor *)(mem_buffer + obj->offs);
5522
+ }
5523
+
5524
+ obj = obj->next;
5525
+ }
5526
+
5527
+ return NULL;
5528
+ }
5529
+
5453
5530
  struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) {
5454
5531
  struct ggml_object * obj = ctx->objects_begin;
5455
5532
 
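
The two helpers added above give a plain iterator over every tensor allocated in a ggml_context: they walk the context's object list and skip objects that are not tensors. A hypothetical usage sketch, assuming the usual ggml.h header and an already-populated context:

    #include <stdio.h>
    #include "ggml.h"

    // Print the name and byte size of every tensor in the context.
    static void print_context_tensors(struct ggml_context * ctx) {
        for (struct ggml_tensor * t = ggml_get_first_tensor(ctx);
             t != NULL;
             t = ggml_get_next_tensor(ctx, t)) {
            printf("%-32s %zu bytes\n", t->name, ggml_nbytes(t));
        }
    }
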
@@ -6690,7 +6767,6 @@ struct ggml_tensor * ggml_cont_4d(
6690
6767
  return result;
6691
6768
  }
6692
6769
 
6693
-
6694
6770
  // ggml_reshape
6695
6771
 
6696
6772
  struct ggml_tensor * ggml_reshape(
@@ -7448,14 +7524,17 @@ static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p,
7448
7524
  return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
7449
7525
  }
7450
7526
 
7451
- GGML_API struct ggml_tensor * ggml_conv_1d(
7452
- struct ggml_context * ctx,
7453
- struct ggml_tensor * a,
7454
- struct ggml_tensor * b,
7455
- int s0,
7456
- int p0,
7457
- int d0) {
7458
- GGML_ASSERT(ggml_is_matrix(b));
7527
+ // im2col: [N, IC, IL] => [N, OL, IC*K]
7528
+ // a: [OC,IC, K]
7529
+ // b: [N, IC, IL]
7530
+ // result: [N, OL, IC*K]
7531
+ static struct ggml_tensor * ggml_conv_1d_stage_0(
7532
+ struct ggml_context * ctx,
7533
+ struct ggml_tensor * a,
7534
+ struct ggml_tensor * b,
7535
+ int s0,
7536
+ int p0,
7537
+ int d0) {
7459
7538
  GGML_ASSERT(a->ne[1] == b->ne[1]);
7460
7539
  bool is_node = false;
7461
7540
 
@@ -7464,16 +7543,54 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
7464
7543
  is_node = true;
7465
7544
  }
7466
7545
 
7546
+ const int64_t OL = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
7547
+
7467
7548
  const int64_t ne[4] = {
7468
- ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
7469
- a->ne[2], 1, 1,
7549
+ a->ne[1] * a->ne[0],
7550
+ OL,
7551
+ b->ne[2],
7552
+ 1,
7470
7553
  };
7471
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
7554
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
7472
7555
 
7473
7556
  int32_t params[] = { s0, p0, d0 };
7474
7557
  ggml_set_op_params(result, params, sizeof(params));
7475
7558
 
7476
- result->op = GGML_OP_CONV_1D;
7559
+ result->op = GGML_OP_CONV_1D_STAGE_0;
7560
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7561
+ result->src[0] = a;
7562
+ result->src[1] = b;
7563
+
7564
+ return result;
7565
+ }
7566
+
7567
+ // ggml_conv_1d_stage_1
7568
+
7569
+ // gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
7570
+ // a: [OC, IC, K]
7571
+ // b: [N, OL, IC * K]
7572
+ // result: [N, OC, OL]
7573
+ static struct ggml_tensor * ggml_conv_1d_stage_1(
7574
+ struct ggml_context * ctx,
7575
+ struct ggml_tensor * a,
7576
+ struct ggml_tensor * b) {
7577
+
7578
+ bool is_node = false;
7579
+
7580
+ if (a->grad || b->grad) {
7581
+ GGML_ASSERT(false); // TODO: implement backward
7582
+ is_node = true;
7583
+ }
7584
+
7585
+ const int64_t ne[4] = {
7586
+ b->ne[1],
7587
+ a->ne[2],
7588
+ b->ne[2],
7589
+ 1,
7590
+ };
7591
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
7592
+
7593
+ result->op = GGML_OP_CONV_1D_STAGE_1;
7477
7594
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7478
7595
  result->src[0] = a;
7479
7596
  result->src[1] = b;
@@ -7481,6 +7598,53 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
7481
7598
  return result;
7482
7599
  }
7483
7600
 
7601
+ // ggml_conv_1d
7602
+
7603
+ GGML_API struct ggml_tensor * ggml_conv_1d(
7604
+ struct ggml_context * ctx,
7605
+ struct ggml_tensor * a,
7606
+ struct ggml_tensor * b,
7607
+ int s0,
7608
+ int p0,
7609
+ int d0) {
7610
+ struct ggml_tensor * result = ggml_conv_1d_stage_0(ctx, a, b, s0, p0, d0);
7611
+ result = ggml_conv_1d_stage_1(ctx, a, result);
7612
+ return result;
7613
+ }
7614
+
7615
+ // GGML_API struct ggml_tensor * ggml_conv_1d(
7616
+ // struct ggml_context * ctx,
7617
+ // struct ggml_tensor * a,
7618
+ // struct ggml_tensor * b,
7619
+ // int s0,
7620
+ // int p0,
7621
+ // int d0) {
7622
+ // GGML_ASSERT(ggml_is_matrix(b));
7623
+ // GGML_ASSERT(a->ne[1] == b->ne[1]);
7624
+ // bool is_node = false;
7625
+
7626
+ // if (a->grad || b->grad) {
7627
+ // GGML_ASSERT(false); // TODO: implement backward
7628
+ // is_node = true;
7629
+ // }
7630
+
7631
+ // const int64_t ne[4] = {
7632
+ // ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
7633
+ // a->ne[2], 1, 1,
7634
+ // };
7635
+ // struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
7636
+
7637
+ // int32_t params[] = { s0, p0, d0 };
7638
+ // ggml_set_op_params(result, params, sizeof(params));
7639
+
7640
+ // result->op = GGML_OP_CONV_1D;
7641
+ // result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7642
+ // result->src[0] = a;
7643
+ // result->src[1] = b;
7644
+
7645
+ // return result;
7646
+ // }
7647
+
7484
7648
  // ggml_conv_1d_ph
7485
7649
 
7486
7650
  struct ggml_tensor* ggml_conv_1d_ph(
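
ggml_conv_1d is now assembled from the two staged ops defined above: stage 0 is an im2col that rearranges the [N, IC, IL] signal into an F16 matrix of [N, OL, IC*K] patches, and stage 1 is a GEMM of that matrix against the [OC, IC*K] kernel, producing [N, OC, OL] in F32 (the old single-op implementation is kept above only as a comment). A shape walk-through under assumed sizes, with an illustrative wrapper:

    #include "ggml.h"

    // Assumed sizes: IC = 8 input channels, OC = 16 filters, K = 3 taps,
    // IL = 100 samples, s0 = 1, p0 = 1, d0 = 1.
    //   OL = (IL + 2*p0 - d0*(K-1) - 1)/s0 + 1 = (100 + 2 - 2 - 1)/1 + 1 = 100
    //   stage 0 (im2col): [N, IC, IL] -> [N, OL, IC*K] = [N, 100, 24]   (F16)
    //   stage 1 (GEMM):   [OC, IC*K] x [N*OL, IC*K]^T -> [N, OC, OL] = [N, 16, 100]
    struct ggml_tensor * conv_1d_example(struct ggml_context * ctx,
                                         struct ggml_tensor * kernel,   // [OC, IC, K], F16 (asserted by the stages)
                                         struct ggml_tensor * signal) { // [N, IC, IL], F32
        return ggml_conv_1d(ctx, kernel, signal, /*s0*/ 1, /*p0*/ 1, /*d0*/ 1);
    }
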
@@ -7492,6 +7656,50 @@ struct ggml_tensor* ggml_conv_1d_ph(
7492
7656
  return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
7493
7657
  }
7494
7658
 
7659
+ // ggml_conv_transpose_1d
7660
+
7661
+ static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
7662
+ return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
7663
+ }
7664
+
7665
+ GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
7666
+ struct ggml_context * ctx,
7667
+ struct ggml_tensor * a,
7668
+ struct ggml_tensor * b,
7669
+ int s0,
7670
+ int p0,
7671
+ int d0) {
7672
+ GGML_ASSERT(ggml_is_matrix(b));
7673
+ GGML_ASSERT(a->ne[2] == b->ne[1]);
7674
+ GGML_ASSERT(a->ne[3] == 1);
7675
+
7676
+ GGML_ASSERT(p0 == 0);
7677
+ GGML_ASSERT(d0 == 1);
7678
+
7679
+ bool is_node = false;
7680
+
7681
+ if (a->grad || b->grad) {
7682
+ GGML_ASSERT(false); // TODO: implement backward
7683
+ is_node = true;
7684
+ }
7685
+
7686
+ const int64_t ne[4] = {
7687
+ ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
7688
+ a->ne[1], b->ne[2], 1,
7689
+ };
7690
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
7691
+
7692
+ int32_t params[] = { s0, p0, d0 };
7693
+ ggml_set_op_params(result, params, sizeof(params));
7694
+
7695
+ result->op = GGML_OP_CONV_TRANSPOSE_1D;
7696
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7697
+ result->src[0] = a;
7698
+ result->src[1] = b;
7699
+
7700
+ return result;
7701
+ }
7702
+
7495
7703
  // ggml_conv_2d
7496
7704
 
7497
7705
  struct ggml_tensor * ggml_conv_2d(
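
The new ggml_conv_transpose_1d op only accepts p0 == 0 and d0 == 1 (asserted above), and its output length inverts the forward convolution formula:

    #include <stdint.h>

    // OL = (IL - 1)*s0 - 2*p0 + d0*(K - 1) + 1, as in ggml_calc_conv_transpose_1d_output_size.
    // Example: IL = 50, K = 4, s0 = 2 (p0 = 0, d0 = 1) gives 49*2 + 3 + 1 = 102.
    static int64_t conv_transpose_1d_out_size(int64_t il, int64_t k, int s0, int p0, int d0) {
        return (il - 1)*s0 - 2*p0 + d0*(k - 1) + 1;
    }
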
@@ -8472,6 +8680,7 @@ void ggml_set_param(
8472
8680
 
8473
8681
  GGML_ASSERT(tensor->grad == NULL);
8474
8682
  tensor->grad = ggml_dup_tensor(ctx, tensor);
8683
+ ggml_format_name(tensor->grad, "%s (grad)", tensor->name);
8475
8684
  }
8476
8685
 
8477
8686
  // ggml_compute_forward_dup
@@ -11058,7 +11267,7 @@ static void ggml_compute_forward_silu_f32(
11058
11267
 
11059
11268
  #ifndef NDEBUG
11060
11269
  for (int k = 0; k < nc; k++) {
11061
- const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
11270
+ const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
11062
11271
  UNUSED(x);
11063
11272
  assert(!isnan(x));
11064
11273
  assert(!isinf(x));
@@ -11621,11 +11830,6 @@ static void ggml_compute_forward_mul_mat(
11621
11830
 
11622
11831
  #if defined(GGML_USE_CLBLAST)
11623
11832
  if (ggml_cl_can_mul_mat(src0, src1, dst)) {
11624
- // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
11625
- // ref: https://github.com/ggerganov/ggml/pull/224
11626
- GGML_ASSERT(ne02 == ne12);
11627
- GGML_ASSERT(ne03 == ne13);
11628
-
11629
11833
  if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
11630
11834
  ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
11631
11835
  }
@@ -12889,28 +13093,25 @@ static void ggml_compute_forward_alibi_f32(
12889
13093
  return;
12890
13094
  }
12891
13095
 
12892
- const int n_past = ((int32_t *) dst->op_params)[0];
13096
+ //const int n_past = ((int32_t *) dst->op_params)[0];
12893
13097
  const int n_head = ((int32_t *) dst->op_params)[1];
12894
13098
  float max_bias;
12895
13099
  memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
12896
13100
 
12897
- assert(n_past >= 0);
13101
+ const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
13102
+ const int64_t ne1 = src0->ne[1]; // seq_len_without_past
13103
+ const int64_t ne2 = src0->ne[2]; // n_head -> this is k
13104
+ //const int64_t ne3 = src0->ne[3]; // 1 -> bsz
12898
13105
 
12899
- const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
12900
- const int ne1 = src0->ne[1]; // seq_len_without_past
12901
- const int ne2 = src0->ne[2]; // n_head -> this is k
12902
- //const int ne3 = src0->ne[3]; // 1 -> bsz
13106
+ const int64_t n = ggml_nrows(src0);
13107
+ const int64_t ne2_ne3 = n/ne1; // ne2*ne3
12903
13108
 
12904
- const int n = ggml_nrows(src0);
12905
- const int ne2_ne3 = n/ne1; // ne2*ne3
12906
-
12907
- const int nb0 = src0->nb[0];
12908
- const int nb1 = src0->nb[1];
12909
- const int nb2 = src0->nb[2];
13109
+ const size_t nb0 = src0->nb[0];
13110
+ const size_t nb1 = src0->nb[1];
13111
+ const size_t nb2 = src0->nb[2];
12910
13112
  //const int nb3 = src0->nb[3];
12911
13113
 
12912
13114
  GGML_ASSERT(nb0 == sizeof(float));
12913
- GGML_ASSERT(ne1 + n_past == ne0);
12914
13115
  GGML_ASSERT(n_head == ne2);
12915
13116
 
12916
13117
  // add alibi to src0 (KQ_scaled)
@@ -12919,9 +13120,9 @@ static void ggml_compute_forward_alibi_f32(
12919
13120
  const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
12920
13121
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
12921
13122
 
12922
- for (int i = 0; i < ne0; i++) {
12923
- for (int j = 0; j < ne1; j++) {
12924
- for (int k = 0; k < ne2_ne3; k++) {
13123
+ for (int64_t i = 0; i < ne0; i++) {
13124
+ for (int64_t j = 0; j < ne1; j++) {
13125
+ for (int64_t k = 0; k < ne2_ne3; k++) {
12925
13126
  float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
12926
13127
  float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
12927
13128
 
@@ -12936,7 +13137,6 @@ static void ggml_compute_forward_alibi_f32(
12936
13137
  }
12937
13138
 
12938
13139
  pdst[0] = i * m_k + src[0];
12939
-
12940
13140
  }
12941
13141
  }
12942
13142
  }
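
In ggml_compute_forward_alibi_f32 the loop counters move to int64_t and the strides to size_t, so the byte offsets stay in 64-bit arithmetic for tensors whose rows would overflow an int; the unused n_past value and its asserts are dropped since only the position and head index drive the bias. A minimal sketch of the offset computation this protects, assuming the same nb0/nb1/nb2 byte strides:

    #include <stdint.h>
    #include <stddef.h>

    // With int strides, i*nb0 + j*nb1 + k*nb2 is evaluated in 32-bit
    // arithmetic and can wrap; size_t/int64_t keeps it 64-bit.
    static float * elem_ptr(char * data, int64_t i, int64_t j, int64_t k,
                            size_t nb0, size_t nb1, size_t nb2) {
        return (float *)(data + i*nb0 + j*nb1 + k*nb2);
    }
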
@@ -13636,7 +13836,7 @@ static void ggml_compute_forward_rope_back(
13636
13836
 
13637
13837
  // ggml_compute_forward_conv_1d
13638
13838
 
13639
- static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
13839
+ static void ggml_compute_forward_conv_1d_f16_f32(
13640
13840
  const struct ggml_compute_params * params,
13641
13841
  const struct ggml_tensor * src0,
13642
13842
  const struct ggml_tensor * src1,
@@ -13654,42 +13854,33 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
13654
13854
  const int nth = params->nth;
13655
13855
 
13656
13856
  const int nk = ne00;
13657
- const int nh = nk/2;
13658
13857
 
13659
- const int ew0 = ggml_up32(ne01);
13858
+ // size of the convolution row - the kernel size unrolled across all input channels
13859
+ const int ew0 = nk*ne01;
13860
+
13861
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
13862
+ const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
13863
+ const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
13660
13864
 
13661
- GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
13662
13865
  GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
13663
13866
  GGML_ASSERT(nb10 == sizeof(float));
13664
13867
 
13665
13868
  if (params->type == GGML_TASK_INIT) {
13666
- // TODO: fix this memset (wsize is overestimated)
13667
13869
  memset(params->wdata, 0, params->wsize);
13668
13870
 
13669
- // prepare kernel data (src0)
13670
- {
13671
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
13871
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
13672
13872
 
13673
- for (int64_t i02 = 0; i02 < ne02; i02++) {
13674
- for (int64_t i01 = 0; i01 < ne01; i01++) {
13675
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
13676
- ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
13677
- for (int64_t i00 = 0; i00 < ne00; i00++) {
13678
- dst_data[i00*ew0 + i01] = src[i00];
13679
- }
13680
- }
13681
- }
13682
- }
13873
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
13874
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
13875
+ ggml_fp16_t * dst_data = wdata;
13683
13876
 
13684
- // prepare source data (src1)
13685
- {
13686
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
13877
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
13878
+ for (int64_t ik = 0; ik < nk; ik++) {
13879
+ const int idx0 = i0*s0 + ik*d0 - p0;
13687
13880
 
13688
- for (int64_t i11 = 0; i11 < ne11; i11++) {
13689
- const float * const src = (float *)((char *) src1->data + i11*nb11);
13690
- ggml_fp16_t * dst_data = wdata;
13691
- for (int64_t i10 = 0; i10 < ne10; i10++) {
13692
- dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
13881
+ if(!(idx0 < 0 || idx0 >= ne10)) {
13882
+ dst_data[i0*ew0 + i11*nk + ik] = GGML_FP32_TO_FP16(src[idx0]);
13883
+ }
13693
13884
  }
13694
13885
  }
13695
13886
  }
@@ -13702,7 +13893,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
13702
13893
  }
13703
13894
 
13704
13895
  // total rows in dst
13705
- const int nr = ne02;
13896
+ const int nr = ne2;
13706
13897
 
13707
13898
  // rows per thread
13708
13899
  const int dr = (nr + nth - 1)/nth;
@@ -13711,23 +13902,22 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
13711
13902
  const int ir0 = dr*ith;
13712
13903
  const int ir1 = MIN(ir0 + dr, nr);
13713
13904
 
13714
- for (int i1 = ir0; i1 < ir1; i1++) {
13715
- float * dst_data = (float *)((char *) dst->data + i1*nb1);
13716
- for (int64_t i0 = 0; i0 < ne10; ++i0) {
13717
- dst_data[i0] = 0;
13718
- for (int k = -nh; k <= nh; k++) {
13719
- float v = 0.0f;
13720
- ggml_vec_dot_f16(ew0, &v,
13721
- (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
13722
- (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
13723
-
13724
- dst_data[i0] += v;
13905
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
13906
+
13907
+ for (int i2 = 0; i2 < ne2; i2++) {
13908
+ for (int i1 = ir0; i1 < ir1; i1++) {
13909
+ float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
13910
+
13911
+ for (int i0 = 0; i0 < ne0; i0++) {
13912
+ ggml_vec_dot_f16(ew0, dst_data + i0,
13913
+ (ggml_fp16_t *) ((char *) src0->data + i1*nb02),
13914
+ (ggml_fp16_t *) wdata + i2*nb2 + i0*ew0);
13725
13915
  }
13726
13916
  }
13727
13917
  }
13728
13918
  }
13729
13919
 
13730
- static void ggml_compute_forward_conv_1d_s1_ph_f32(
13920
+ static void ggml_compute_forward_conv_1d_f32(
13731
13921
  const struct ggml_compute_params * params,
13732
13922
  const struct ggml_tensor * src0,
13733
13923
  const struct ggml_tensor * src1,
@@ -13745,42 +13935,32 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32(
13745
13935
  const int nth = params->nth;
13746
13936
 
13747
13937
  const int nk = ne00;
13748
- const int nh = nk/2;
13749
13938
 
13750
- const int ew0 = ggml_up32(ne01);
13939
+ const int ew0 = nk*ne01;
13940
+
13941
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
13942
+ const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
13943
+ const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
13751
13944
 
13752
- GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
13753
13945
  GGML_ASSERT(nb00 == sizeof(float));
13754
13946
  GGML_ASSERT(nb10 == sizeof(float));
13755
13947
 
13756
13948
  if (params->type == GGML_TASK_INIT) {
13757
- // TODO: fix this memset (wsize is overestimated)
13758
13949
  memset(params->wdata, 0, params->wsize);
13759
13950
 
13760
- // prepare kernel data (src0)
13761
- {
13762
- float * const wdata = (float *) params->wdata + 0;
13951
+ float * const wdata = (float *) params->wdata + 0;
13763
13952
 
13764
- for (int64_t i02 = 0; i02 < ne02; i02++) {
13765
- for (int64_t i01 = 0; i01 < ne01; i01++) {
13766
- const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
13767
- float * dst_data = wdata + i02*ew0*ne00;
13768
- for (int64_t i00 = 0; i00 < ne00; i00++) {
13769
- dst_data[i00*ew0 + i01] = src[i00];
13770
- }
13771
- }
13772
- }
13773
- }
13953
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
13954
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
13955
+ float * dst_data = wdata;
13774
13956
 
13775
- // prepare source data (src1)
13776
- {
13777
- float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
13957
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
13958
+ for (int64_t ik = 0; ik < nk; ik++) {
13959
+ const int idx0 = i0*s0 + ik*d0 - p0;
13778
13960
 
13779
- for (int64_t i11 = 0; i11 < ne11; i11++) {
13780
- const float * const src = (float *)((char *) src1->data + i11*nb11);
13781
- float * dst_data = wdata;
13782
- for (int64_t i10 = 0; i10 < ne10; i10++) {
13783
- dst_data[(i10 + nh)*ew0 + i11] = src[i10];
13961
+ if(!(idx0 < 0 || idx0 >= ne10)) {
13962
+ dst_data[i0*ew0 + i11*nk + ik] = src[idx0];
13963
+ }
13784
13964
  }
13785
13965
  }
13786
13966
  }
@@ -13802,35 +13982,242 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32(
13802
13982
  const int ir0 = dr*ith;
13803
13983
  const int ir1 = MIN(ir0 + dr, nr);
13804
13984
 
13805
- for (int i1 = ir0; i1 < ir1; i1++) {
13806
- float * dst_data = (float *)((char *) dst->data + i1*nb1);
13807
- for (int64_t i0 = 0; i0 < ne10; ++i0) {
13808
- dst_data[i0] = 0;
13809
- for (int k = -nh; k <= nh; k++) {
13810
- float v = 0.0f;
13811
- ggml_vec_dot_f32(ew0, &v,
13812
- (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
13813
- (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
13814
-
13815
- dst_data[i0] += v;
13985
+ float * const wdata = (float *) params->wdata + 0;
13986
+
13987
+ for (int i2 = 0; i2 < ne2; i2++) {
13988
+ for (int i1 = ir0; i1 < ir1; i1++) {
13989
+ float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
13990
+
13991
+ for (int i0 = 0; i0 < ne0; i0++) {
13992
+ ggml_vec_dot_f32(ew0, dst_data + i0,
13993
+ (float *) ((char *) src0->data + i1*nb02),
13994
+ (float *) wdata + i2*nb2 + i0*ew0);
13995
+ }
13996
+ }
13997
+ }
13998
+ }
13999
+
14000
+ static void gemm_f16_out_f32(int64_t m, int64_t n, int64_t k,
14001
+ ggml_fp16_t * A,
14002
+ ggml_fp16_t * B,
14003
+ float * C,
14004
+ const int ith, const int nth) {
14005
+ // does not seem to make a difference
14006
+ int64_t m0, m1, n0, n1;
14007
+ // patches per thread
14008
+ if (m > n) {
14009
+ n0 = 0;
14010
+ n1 = n;
14011
+
14012
+ // total patches in dst
14013
+ const int np = m;
14014
+
14015
+ // patches per thread
14016
+ const int dp = (np + nth - 1)/nth;
14017
+
14018
+ // patch range for this thread
14019
+ m0 = dp*ith;
14020
+ m1 = MIN(m0 + dp, np);
14021
+ } else {
14022
+ m0 = 0;
14023
+ m1 = m;
14024
+
14025
+ // total patches in dst
14026
+ const int np = n;
14027
+
14028
+ // patches per thread
14029
+ const int dp = (np + nth - 1)/nth;
14030
+
14031
+ // patch range for this thread
14032
+ n0 = dp*ith;
14033
+ n1 = MIN(n0 + dp, np);
14034
+ }
14035
+
14036
+ // block-tiling attempt
14037
+ int64_t blck_n = 16;
14038
+ int64_t blck_m = 16;
14039
+
14040
+ // int64_t CACHE_SIZE = 2 * 1024 * 1024; // 2MB
14041
+ // int64_t blck_size = CACHE_SIZE / (sizeof(float) + 2 * sizeof(ggml_fp16_t) * K);
14042
+ // if (blck_size > 0) {
14043
+ // blck_0 = 4;
14044
+ // blck_1 = blck_size / blck_0;
14045
+ // if (blck_1 < 0) {
14046
+ // blck_1 = 1;
14047
+ // }
14048
+ // // blck_0 = (int64_t)sqrt(blck_size);
14049
+ // // blck_1 = blck_0;
14050
+ // }
14051
+ // // printf("%zd %zd %zd %zd\n", blck_size, K, blck_0, blck_1);
14052
+
14053
+ for (int j = n0; j < n1; j+=blck_n) {
14054
+ for (int i = m0; i < m1; i+=blck_m) {
14055
+ // printf("i j k => %d %d %d\n", i, j, K);
14056
+ for (int ii = i; ii < i + blck_m && ii < m1; ii++) {
14057
+ for (int jj = j; jj < j + blck_n && jj < n1; jj++) {
14058
+ ggml_vec_dot_f16(k,
14059
+ C + ii*n + jj,
14060
+ A + ii * k,
14061
+ B + jj * k);
14062
+ }
13816
14063
  }
13817
14064
  }
13818
14065
  }
13819
14066
  }
13820
14067
 
13821
- static void ggml_compute_forward_conv_1d_s1_ph(
14068
+ // src0: kernel [OC, IC, K]
14069
+ // src1: signal [N, IC, IL]
14070
+ // dst: result [N, OL, IC*K]
14071
+ static void ggml_compute_forward_conv_1d_stage_0_f32(
13822
14072
  const struct ggml_compute_params * params,
13823
14073
  const struct ggml_tensor * src0,
13824
14074
  const struct ggml_tensor * src1,
13825
14075
  struct ggml_tensor * dst) {
13826
- switch (src0->type) {
14076
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
14077
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
14078
+ GGML_ASSERT( dst->type == GGML_TYPE_F16);
14079
+
14080
+ int64_t t0 = ggml_perf_time_us();
14081
+ UNUSED(t0);
14082
+
14083
+ GGML_TENSOR_BINARY_OP_LOCALS;
14084
+
14085
+ const int64_t N = ne12;
14086
+ const int64_t IC = ne11;
14087
+ const int64_t IL = ne10;
14088
+
14089
+ const int64_t K = ne00;
14090
+
14091
+ const int64_t OL = ne1;
14092
+
14093
+ const int ith = params->ith;
14094
+ const int nth = params->nth;
14095
+
14096
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
14097
+ const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
14098
+ const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
14099
+
14100
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
14101
+ GGML_ASSERT(nb10 == sizeof(float));
14102
+
14103
+ if (params->type == GGML_TASK_INIT) {
14104
+ memset(dst->data, 0, ggml_nbytes(dst));
14105
+ return;
14106
+ }
14107
+
14108
+ if (params->type == GGML_TASK_FINALIZE) {
14109
+ return;
14110
+ }
14111
+
14112
+ // im2col: [N, IC, IL] => [N, OL, IC*K]
14113
+ {
14114
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
14115
+
14116
+ for (int64_t in = 0; in < N; in++) {
14117
+ for (int64_t iol = 0; iol < OL; iol++) {
14118
+ for (int64_t iic = ith; iic < IC; iic+=nth) {
14119
+
14120
+ // micro kernel
14121
+ ggml_fp16_t * dst_data = wdata + (in*OL + iol)*(IC*K); // [IC, K]
14122
+ const float * const src_data = (float *)((char *) src1->data + in*nb12 + iic*nb11); // [IL]
14123
+
14124
+ for (int64_t ik = 0; ik < K; ik++) {
14125
+ const int64_t iil = iol*s0 + ik*d0 - p0;
14126
+
14127
+ if (!(iil < 0 || iil >= IL)) {
14128
+ dst_data[iic*K + ik] = GGML_FP32_TO_FP16(src_data[iil]);
14129
+ }
14130
+ }
14131
+ }
14132
+ }
14133
+ }
14134
+ }
14135
+ }
14136
+
14137
+ // gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
14138
+ // src0: [OC, IC, K]
14139
+ // src1: [N, OL, IC * K]
14140
+ // result: [N, OC, OL]
14141
+ static void ggml_compute_forward_conv_1d_stage_1_f16(
14142
+ const struct ggml_compute_params * params,
14143
+ const struct ggml_tensor * src0,
14144
+ const struct ggml_tensor * src1,
14145
+ struct ggml_tensor * dst) {
14146
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
14147
+ GGML_ASSERT(src1->type == GGML_TYPE_F16);
14148
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
14149
+
14150
+ int64_t t0 = ggml_perf_time_us();
14151
+ UNUSED(t0);
14152
+
14153
+ if (params->type == GGML_TASK_INIT) {
14154
+ return;
14155
+ }
14156
+
14157
+ if (params->type == GGML_TASK_FINALIZE) {
14158
+ return;
14159
+ }
14160
+
14161
+ GGML_TENSOR_BINARY_OP_LOCALS;
14162
+
14163
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
14164
+ GGML_ASSERT(nb10 == sizeof(ggml_fp16_t));
14165
+ GGML_ASSERT(nb0 == sizeof(float));
14166
+
14167
+ const int N = ne12;
14168
+ const int OL = ne11;
14169
+
14170
+ const int OC = ne02;
14171
+ const int IC = ne01;
14172
+ const int K = ne00;
14173
+
14174
+ const int ith = params->ith;
14175
+ const int nth = params->nth;
14176
+
14177
+ int64_t m = OC;
14178
+ int64_t n = OL;
14179
+ int64_t k = IC * K;
14180
+
14181
+ // [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
14182
+ for (int i = 0; i < N; i++) {
14183
+ ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
14184
+ ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k]
14185
+ float * C = (float *)dst->data + i * m * n; // [m, n]
14186
+
14187
+ gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
14188
+ }
14189
+ }
14190
+
14191
+ static void ggml_compute_forward_conv_1d(
14192
+ const struct ggml_compute_params * params,
14193
+ const struct ggml_tensor * src0,
14194
+ const struct ggml_tensor * src1,
14195
+ struct ggml_tensor * dst) {
14196
+ switch(src0->type) {
13827
14197
  case GGML_TYPE_F16:
13828
14198
  {
13829
- ggml_compute_forward_conv_1d_s1_ph_f16_f32(params, src0, src1, dst);
14199
+ ggml_compute_forward_conv_1d_f16_f32(params, src0, src1, dst);
13830
14200
  } break;
13831
14201
  case GGML_TYPE_F32:
13832
14202
  {
13833
- ggml_compute_forward_conv_1d_s1_ph_f32(params, src0, src1, dst);
14203
+ ggml_compute_forward_conv_1d_f32(params, src0, src1, dst);
14204
+ } break;
14205
+ default:
14206
+ {
14207
+ GGML_ASSERT(false);
14208
+ } break;
14209
+ }
14210
+ }
14211
+
14212
+ static void ggml_compute_forward_conv_1d_stage_0(
14213
+ const struct ggml_compute_params * params,
14214
+ const struct ggml_tensor * src0,
14215
+ const struct ggml_tensor * src1,
14216
+ struct ggml_tensor * dst) {
14217
+ switch(src0->type) {
14218
+ case GGML_TYPE_F16:
14219
+ {
14220
+ ggml_compute_forward_conv_1d_stage_0_f32(params, src0, src1, dst);
13834
14221
  } break;
13835
14222
  default:
13836
14223
  {
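
The gemm_f16_out_f32 helper added in this hunk computes C = A * B^T with A: m x k and B: n x k held in FP16 and C: m x n written in FP32; work is split across threads along whichever of m or n is larger, the loops are tiled 16 x 16, and the inner product is ggml_vec_dot_f16 over k. A plain-float reference of the same computation (a sketch; it ignores the thread split and the tiling):

    #include <stdint.h>

    static void gemm_out_f32_reference(int64_t m, int64_t n, int64_t k,
                                       const float * A, const float * B, float * C) {
        for (int64_t i = 0; i < m; ++i) {
            for (int64_t j = 0; j < n; ++j) {
                float sum = 0.0f;
                for (int64_t l = 0; l < k; ++l) {
                    sum += A[i*k + l] * B[j*k + l];  // row i of A . row j of B
                }
                C[i*n + j] = sum;
            }
        }
    }
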
@@ -13839,7 +14226,26 @@ static void ggml_compute_forward_conv_1d_s1_ph(
13839
14226
  }
13840
14227
  }
13841
14228
 
13842
- static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
14229
+ static void ggml_compute_forward_conv_1d_stage_1(
14230
+ const struct ggml_compute_params * params,
14231
+ const struct ggml_tensor * src0,
14232
+ const struct ggml_tensor * src1,
14233
+ struct ggml_tensor * dst) {
14234
+ switch(src0->type) {
14235
+ case GGML_TYPE_F16:
14236
+ {
14237
+ ggml_compute_forward_conv_1d_stage_1_f16(params, src0, src1, dst);
14238
+ } break;
14239
+ default:
14240
+ {
14241
+ GGML_ASSERT(false);
14242
+ } break;
14243
+ }
14244
+ }
14245
+
14246
+ // ggml_compute_forward_conv_transpose_1d
14247
+
14248
+ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
13843
14249
  const struct ggml_compute_params * params,
13844
14250
  const struct ggml_tensor * src0,
13845
14251
  const struct ggml_tensor * src1,
@@ -13856,43 +14262,38 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
13856
14262
  const int ith = params->ith;
13857
14263
  const int nth = params->nth;
13858
14264
 
13859
- const int nk = ne00;
13860
- const int nh = nk/2;
13861
-
13862
- const int ew0 = ggml_up32(ne01);
14265
+ const int nk = ne00*ne01*ne02;
13863
14266
 
13864
- GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
13865
14267
  GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
13866
14268
  GGML_ASSERT(nb10 == sizeof(float));
13867
14269
 
13868
14270
  if (params->type == GGML_TASK_INIT) {
13869
- // TODO: fix this memset (wsize is overestimated)
13870
14271
  memset(params->wdata, 0, params->wsize);
13871
14272
 
13872
- // prepare kernel data (src0)
14273
+ // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
13873
14274
  {
13874
14275
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
13875
14276
 
13876
14277
  for (int64_t i02 = 0; i02 < ne02; i02++) {
13877
14278
  for (int64_t i01 = 0; i01 < ne01; i01++) {
13878
14279
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
13879
- ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
14280
+ ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
13880
14281
  for (int64_t i00 = 0; i00 < ne00; i00++) {
13881
- dst_data[i00*ew0 + i01] = src[i00];
14282
+ dst_data[i00*ne02 + i02] = src[i00];
13882
14283
  }
13883
14284
  }
13884
14285
  }
13885
14286
  }
13886
14287
 
13887
- // prepare source data (src1)
14288
+ // permute source data (src1) from (L x Cin) to (Cin x L)
13888
14289
  {
13889
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
14290
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
14291
+ ggml_fp16_t * dst_data = wdata;
13890
14292
 
13891
14293
  for (int64_t i11 = 0; i11 < ne11; i11++) {
13892
14294
  const float * const src = (float *)((char *) src1->data + i11*nb11);
13893
- ggml_fp16_t * dst_data = wdata;
13894
14295
  for (int64_t i10 = 0; i10 < ne10; i10++) {
13895
- dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
14296
+ dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
13896
14297
  }
13897
14298
  }
13898
14299
  }
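
The rewritten ggml_compute_forward_conv_transpose_1d_f16_f32 permutes the kernel and the source once in the INIT pass (kernel to an [OC][K][IC] layout, source to [IL][IC]) so that the dot-product loop in the following hunks can take one contiguous ggml_vec_dot_f16 over the input channels per (output channel, input position, tap) triple, scattering each result into dst at i10*s0 + i00. A scalar sketch of the transposed convolution being computed, using a plain row-major layout chosen for the sketch:

    // kernel: [OC][IC][K], src: [IC][IL], dst: [OC][OL] zero-initialized,
    // with OL = (IL - 1)*s0 + K (p0 = 0, d0 = 1).
    static void conv_transpose_1d_reference(int OC, int IC, int K, int IL, int OL, int s0,
                                            const float * kernel, const float * src, float * dst) {
        for (int oc = 0; oc < OC; ++oc) {
            for (int i = 0; i < IL; ++i) {
                for (int k = 0; k < K; ++k) {
                    float v = 0.0f;
                    for (int ic = 0; ic < IC; ++ic) {
                        v += kernel[(oc*IC + ic)*K + k] * src[ic*IL + i];
                    }
                    dst[oc*OL + i*s0 + k] += v;  // each input position scatters K taps
                }
            }
        }
    }
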
@@ -13904,8 +14305,10 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
13904
14305
  return;
13905
14306
  }
13906
14307
 
14308
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
14309
+
13907
14310
  // total rows in dst
13908
- const int nr = ne02;
14311
+ const int nr = ne1;
13909
14312
 
13910
14313
  // rows per thread
13911
14314
  const int dr = (nr + nth - 1)/nth;
@@ -13914,23 +14317,26 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
13914
14317
  const int ir0 = dr*ith;
13915
14318
  const int ir1 = MIN(ir0 + dr, nr);
13916
14319
 
14320
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
14321
+ ggml_fp16_t * const wdata_src = wdata + nk;
14322
+
13917
14323
  for (int i1 = ir0; i1 < ir1; i1++) {
13918
14324
  float * dst_data = (float *)((char *) dst->data + i1*nb1);
13919
- for (int64_t i0 = 0; i0 < ne10; i0 += 2) {
13920
- dst_data[i0/2] = 0;
13921
- for (int k = -nh; k <= nh; k++) {
13922
- float v = 0.0f;
13923
- ggml_vec_dot_f16(ew0, &v,
13924
- (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
13925
- (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
13926
-
13927
- dst_data[i0/2] += v;
14325
+ ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
14326
+ for (int i10 = 0; i10 < ne10; i10++) {
14327
+ const int i1n = i10*ne11;
14328
+ for (int i00 = 0; i00 < ne00; i00++) {
14329
+ float v = 0;
14330
+ ggml_vec_dot_f16(ne02, &v,
14331
+ (ggml_fp16_t *) wdata_src + i1n,
14332
+ (ggml_fp16_t *) wdata_kernel + i00*ne02);
14333
+ dst_data[i10*s0 + i00] += v;
13928
14334
  }
13929
14335
  }
13930
14336
  }
13931
14337
  }
13932
14338
 
13933
- static void ggml_compute_forward_conv_1d_s2_ph_f32(
14339
+ static void ggml_compute_forward_conv_transpose_1d_f32(
13934
14340
  const struct ggml_compute_params * params,
13935
14341
  const struct ggml_tensor * src0,
13936
14342
  const struct ggml_tensor * src1,
@@ -13947,29 +14353,24 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
13947
14353
  const int ith = params->ith;
13948
14354
  const int nth = params->nth;
13949
14355
 
13950
- const int nk = ne00;
13951
- const int nh = nk/2;
13952
-
13953
- const int ew0 = ggml_up32(ne01);
14356
+ const int nk = ne00*ne01*ne02;
13954
14357
 
13955
- GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
13956
14358
  GGML_ASSERT(nb00 == sizeof(float));
13957
14359
  GGML_ASSERT(nb10 == sizeof(float));
13958
14360
 
13959
14361
  if (params->type == GGML_TASK_INIT) {
13960
- // TODO: fix this memset (wsize is overestimated)
13961
14362
  memset(params->wdata, 0, params->wsize);
13962
14363
 
13963
- // prepare kernel data (src0)
14364
+ // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
13964
14365
  {
13965
14366
  float * const wdata = (float *) params->wdata + 0;
13966
14367
 
13967
14368
  for (int64_t i02 = 0; i02 < ne02; i02++) {
13968
14369
  for (int64_t i01 = 0; i01 < ne01; i01++) {
13969
14370
  const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
13970
- float * dst_data = wdata + i02*ew0*ne00;
14371
+ float * dst_data = wdata + i01*ne00*ne02;
13971
14372
  for (int64_t i00 = 0; i00 < ne00; i00++) {
13972
- dst_data[i00*ew0 + i01] = src[i00];
14373
+ dst_data[i01*ne00*ne02 + i00*ne02 + i02] = src[i00];
13973
14374
  }
13974
14375
  }
13975
14376
  }
@@ -13977,13 +14378,13 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
13977
14378
 
13978
14379
  // prepare source data (src1)
13979
14380
  {
13980
- float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
14381
+ float * const wdata = (float *) params->wdata + nk;
14382
+ float * dst_data = wdata;
13981
14383
 
13982
14384
  for (int64_t i11 = 0; i11 < ne11; i11++) {
13983
14385
  const float * const src = (float *)((char *) src1->data + i11*nb11);
13984
- float * dst_data = wdata;
13985
14386
  for (int64_t i10 = 0; i10 < ne10; i10++) {
13986
- dst_data[(i10 + nh)*ew0 + i11] = src[i10];
14387
+ dst_data[i10*ne11 + i11] = src[i10];
13987
14388
  }
13988
14389
  }
13989
14390
  }
@@ -13995,8 +14396,10 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
13995
14396
  return;
13996
14397
  }
13997
14398
 
14399
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
14400
+
13998
14401
  // total rows in dst
13999
- const int nr = ne02;
14402
+ const int nr = ne1;
14000
14403
 
14001
14404
  // rows per thread
14002
14405
  const int dr = (nr + nth - 1)/nth;
@@ -14005,23 +14408,26 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
14005
14408
  const int ir0 = dr*ith;
14006
14409
  const int ir1 = MIN(ir0 + dr, nr);
14007
14410
 
14411
+ float * const wdata = (float *) params->wdata + 0;
14412
+ float * const wdata_src = wdata + nk;
14413
+
14008
14414
  for (int i1 = ir0; i1 < ir1; i1++) {
14009
14415
  float * dst_data = (float *)((char *) dst->data + i1*nb1);
14010
- for (int64_t i0 = 0; i0 < ne10; i0 += 2) {
14011
- dst_data[i0/2] = 0;
14012
- for (int k = -nh; k <= nh; k++) {
14013
- float v = 0.0f;
14014
- ggml_vec_dot_f32(ew0, &v,
14015
- (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
14016
- (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
14017
-
14018
- dst_data[i0/2] += v;
14416
+ float * wdata_kernel = wdata + i1*ne02*ne00;
14417
+ for (int i10 = 0; i10 < ne10; i10++) {
14418
+ const int i1n = i10*ne11;
14419
+ for (int i00 = 0; i00 < ne00; i00++) {
14420
+ float v = 0;
14421
+ ggml_vec_dot_f32(ne02, &v,
14422
+ wdata_src + i1n,
14423
+ wdata_kernel + i00*ne02);
14424
+ dst_data[i10*s0 + i00] += v;
14019
14425
  }
14020
14426
  }
14021
14427
  }
14022
14428
  }
14023
14429
 
14024
- static void ggml_compute_forward_conv_1d_s2_ph(
14430
+ static void ggml_compute_forward_conv_transpose_1d(
14025
14431
  const struct ggml_compute_params * params,
14026
14432
  const struct ggml_tensor * src0,
14027
14433
  const struct ggml_tensor * src1,
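
To make the scatter pattern of the renamed kernel explicit, here is a hedged scalar sketch of a 1D transposed convolution for a single output row over contiguous buffers. The function name, the kernel[c*K + k] / src[c*L + l] layouts and the example values are illustrative and do not mirror the permuted work-buffer layout above, but the accumulation dst[l*s0 + k] += (dot over Cin) matches the dst_data[i10*s0 + i00] += v pattern in the diff.

#include <stdio.h>

// Minimal scalar sketch of a 1D transposed convolution for one output channel,
// assuming contiguous float buffers; illustrative only.
static void conv_transpose_1d_row(const float * kernel, const float * src,
                                  float * dst, int K, int Cin, int L, int s0) {
    // dst must be zero-initialized and hold (L - 1)*s0 + K elements
    for (int l = 0; l < L; l++) {
        for (int k = 0; k < K; k++) {
            float v = 0.0f;
            for (int c = 0; c < Cin; c++) {
                v += kernel[c*K + k] * src[c*L + l];
            }
            dst[l*s0 + k] += v; // each input position scatters into output positions l*s0 + k
        }
    }
}

int main(void) {
    const float kernel[1*2] = { 1.0f, 2.0f };       // Cin = 1, K = 2
    const float src[1*3]    = { 1.0f, 1.0f, 1.0f }; // L = 3
    float dst[(3 - 1)*2 + 2] = { 0 };               // stride s0 = 2 -> 6 outputs

    conv_transpose_1d_row(kernel, src, dst, /*K=*/2, /*Cin=*/1, /*L=*/3, /*s0=*/2);
    for (int i = 0; i < 6; i++) {
        printf("%.1f ", dst[i]); // expect: 1.0 2.0 1.0 2.0 1.0 2.0
    }
    printf("\n");
    return 0;
}
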
@@ -14029,11 +14435,11 @@ static void ggml_compute_forward_conv_1d_s2_ph(
14029
14435
  switch (src0->type) {
14030
14436
  case GGML_TYPE_F16:
14031
14437
  {
14032
- ggml_compute_forward_conv_1d_s2_ph_f16_f32(params, src0, src1, dst);
14438
+ ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst);
14033
14439
  } break;
14034
14440
  case GGML_TYPE_F32:
14035
14441
  {
14036
- ggml_compute_forward_conv_1d_s2_ph_f32(params, src0, src1, dst);
14442
+ ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst);
14037
14443
  } break;
14038
14444
  default:
14039
14445
  {
@@ -14042,27 +14448,6 @@ static void ggml_compute_forward_conv_1d_s2_ph(
14042
14448
  }
14043
14449
  }
14044
14450
 
14045
- // ggml_compute_forward_conv_1d
14046
-
14047
- static void ggml_compute_forward_conv_1d(
14048
- const struct ggml_compute_params * params,
14049
- const struct ggml_tensor * src0,
14050
- const struct ggml_tensor * src1,
14051
- struct ggml_tensor * dst) {
14052
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
14053
- const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
14054
- const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
14055
- GGML_ASSERT(d0 == 1); // dilation not supported
14056
- GGML_ASSERT(p0 == src0->ne[0]/2); // only half padding supported
14057
- if (s0 == 1) {
14058
- ggml_compute_forward_conv_1d_s1_ph(params, src0, src1, dst);
14059
- } else if (s0 == 2) {
14060
- ggml_compute_forward_conv_1d_s2_ph(params, src0, src1, dst);
14061
- } else {
14062
- GGML_ASSERT(false); // only stride 1 and 2 supported
14063
- }
14064
- }
14065
-
14066
14451
  // ggml_compute_forward_conv_2d
14067
14452
 
14068
14453
  static void ggml_compute_forward_conv_2d_f16_f32(
@@ -14077,7 +14462,7 @@ static void ggml_compute_forward_conv_2d_f16_f32(
14077
14462
  int64_t t0 = ggml_perf_time_us();
14078
14463
  UNUSED(t0);
14079
14464
 
14080
- GGML_TENSOR_BINARY_OP_LOCALS
14465
+ GGML_TENSOR_BINARY_OP_LOCALS;
14081
14466
 
14082
14467
  const int ith = params->ith;
14083
14468
  const int nth = params->nth;
@@ -14105,20 +14490,22 @@ static void ggml_compute_forward_conv_2d_f16_f32(
14105
14490
  {
14106
14491
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
14107
14492
 
14108
- for (int i12 = 0; i12 < ne12; i12++) {
14109
- const float * const src = (float *)((char *) src1->data + i12*nb12);
14110
- ggml_fp16_t * dst_data = wdata;
14111
-
14112
- for (int i1 = 0; i1 < ne1; i1++) {
14113
- for (int i0 = 0; i0 < ne0; i0++) {
14114
- for (int ik1 = 0; ik1 < nk1; ik1++) {
14115
- for (int ik0 = 0; ik0 < nk0; ik0++) {
14116
- const int idx0 = i0*s0 + ik0*d0 - p0;
14117
- const int idx1 = i1*s1 + ik1*d1 - p1;
14118
-
14119
- if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) {
14120
- dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
14121
- GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]);
14493
+ for (int i13 = 0; i13 < ne13; i13++) {
14494
+ for (int i12 = 0; i12 < ne12; i12++) {
14495
+ const float * const src = (float *)((char *) src1->data + i13*nb13 + i12*nb12);
14496
+ ggml_fp16_t * dst_data = wdata + i13*(ne1*ne0*ew0);
14497
+
14498
+ for (int i1 = 0; i1 < ne1; i1++) {
14499
+ for (int i0 = 0; i0 < ne0; i0++) {
14500
+ for (int ik1 = 0; ik1 < nk1; ik1++) {
14501
+ for (int ik0 = 0; ik0 < nk0; ik0++) {
14502
+ const int idx0 = i0*s0 + ik0*d0 - p0;
14503
+ const int idx1 = i1*s1 + ik1*d1 - p1;
14504
+
14505
+ if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) {
14506
+ dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
14507
+ GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]);
14508
+ }
14122
14509
  }
14123
14510
  }
14124
14511
  }
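
A small worked example (with made-up kernel width, stride, dilation and padding) of the source-index arithmetic in this gather: indices that fall outside [0, ne10) are skipped, so those slots keep the zero written by the earlier memset, which is how the padding is realized.

#include <stdio.h>

// Worked example of idx0 = i0*s0 + ik0*d0 - p0 with a 3-wide kernel,
// stride 1, dilation 1, padding 1 and a 4-wide source row.
int main(void) {
    const int s0 = 1, d0 = 1, p0 = 1;
    const int ne10 = 4;          // source row width
    const int nk0 = 3;           // kernel width

    for (int i0 = 0; i0 < 2; i0++) {       // first two output columns
        for (int ik0 = 0; ik0 < nk0; ik0++) {
            const int idx0 = i0*s0 + ik0*d0 - p0;
            const int in_bounds = (idx0 >= 0 && idx0 < ne10);
            printf("i0=%d ik0=%d -> idx0=%d %s\n",
                   i0, ik0, idx0, in_bounds ? "" : "(padding, left as 0)");
        }
    }
    return 0;
}
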
@@ -16401,6 +16788,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
16401
16788
  {
16402
16789
  ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor);
16403
16790
  } break;
16791
+ case GGML_OP_CONV_1D_STAGE_0:
16792
+ {
16793
+ ggml_compute_forward_conv_1d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
16794
+ } break;
16795
+ case GGML_OP_CONV_1D_STAGE_1:
16796
+ {
16797
+ ggml_compute_forward_conv_1d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
16798
+ } break;
16799
+ case GGML_OP_CONV_TRANSPOSE_1D:
16800
+ {
16801
+ ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor);
16802
+ } break;
16404
16803
  case GGML_OP_CONV_2D:
16405
16804
  {
16406
16805
  ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor);
@@ -17326,10 +17725,22 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
17326
17725
  {
17327
17726
  GGML_ASSERT(false); // TODO: not implemented
17328
17727
  } break;
17728
+ case GGML_OP_CONV_1D_STAGE_0:
17729
+ {
17730
+ GGML_ASSERT(false); // TODO: not implemented
17731
+ } break;
17732
+ case GGML_OP_CONV_1D_STAGE_1:
17733
+ {
17734
+ GGML_ASSERT(false); // TODO: not implemented
17735
+ } break;
17329
17736
  case GGML_OP_CONV_2D:
17330
17737
  {
17331
17738
  GGML_ASSERT(false); // TODO: not implemented
17332
17739
  } break;
17740
+ case GGML_OP_CONV_TRANSPOSE_1D:
17741
+ {
17742
+ GGML_ASSERT(false); // TODO: not implemented
17743
+ } break;
17333
17744
  case GGML_OP_CONV_TRANSPOSE_2D:
17334
17745
  {
17335
17746
  GGML_ASSERT(false); // TODO: not implemented
@@ -18171,21 +18582,68 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
18171
18582
  GGML_ASSERT(node->src[1]->ne[2] == 1);
18172
18583
  GGML_ASSERT(node->src[1]->ne[3] == 1);
18173
18584
 
18585
+ const int64_t ne00 = node->src[0]->ne[0];
18586
+ const int64_t ne01 = node->src[0]->ne[1];
18587
+ const int64_t ne02 = node->src[0]->ne[2];
18588
+
18589
+ const int64_t ne10 = node->src[1]->ne[0];
18590
+ const int64_t ne11 = node->src[1]->ne[1];
18591
+
18592
+ const int64_t ne0 = node->ne[0];
18593
+ const int64_t ne1 = node->ne[1];
18594
+ const int64_t nk = ne00;
18595
+ const int64_t ew0 = nk * ne01;
18596
+
18597
+ UNUSED(ne02);
18598
+ UNUSED(ne10);
18599
+ UNUSED(ne11);
18600
+
18174
18601
  size_t cur = 0;
18175
- const int nk = node->src[0]->ne[0];
18176
18602
 
18177
18603
  if (node->src[0]->type == GGML_TYPE_F16 &&
18178
- node->src[1]->type == GGML_TYPE_F32) {
18179
- cur = sizeof(ggml_fp16_t)*(
18180
- nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] +
18181
- ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
18182
- );
18604
+ node->src[1]->type == GGML_TYPE_F32) {
18605
+ cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0);
18606
+ } else if (node->src[0]->type == GGML_TYPE_F32 &&
18607
+ node->src[1]->type == GGML_TYPE_F32) {
18608
+ cur = sizeof(float)*(ne0*ne1*ew0);
18609
+ } else {
18610
+ GGML_ASSERT(false);
18611
+ }
18612
+
18613
+ work_size = MAX(work_size, cur);
18614
+ } break;
18615
+ case GGML_OP_CONV_1D_STAGE_0:
18616
+ {
18617
+ n_tasks = n_threads;
18618
+ } break;
18619
+ case GGML_OP_CONV_1D_STAGE_1:
18620
+ {
18621
+ n_tasks = n_threads;
18622
+ } break;
18623
+ case GGML_OP_CONV_TRANSPOSE_1D:
18624
+ {
18625
+ n_tasks = n_threads;
18626
+
18627
+ GGML_ASSERT(node->src[0]->ne[3] == 1);
18628
+ GGML_ASSERT(node->src[1]->ne[2] == 1);
18629
+ GGML_ASSERT(node->src[1]->ne[3] == 1);
18630
+
18631
+ const int64_t ne00 = node->src[0]->ne[0]; // K
18632
+ const int64_t ne01 = node->src[0]->ne[1]; // Cout
18633
+ const int64_t ne02 = node->src[0]->ne[2]; // Cin
18634
+
18635
+ const int64_t ne10 = node->src[1]->ne[0]; // L
18636
+ const int64_t ne11 = node->src[1]->ne[1]; // Cin
18637
+
18638
+ size_t cur = 0;
18639
+ if (node->src[0]->type == GGML_TYPE_F16 &&
18640
+ node->src[1]->type == GGML_TYPE_F32) {
18641
+ cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
18642
+ cur += sizeof(ggml_fp16_t)*ne10*ne11;
18183
18643
  } else if (node->src[0]->type == GGML_TYPE_F32 &&
18184
- node->src[1]->type == GGML_TYPE_F32) {
18185
- cur = sizeof(float)*(
18186
- nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] +
18187
- ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
18188
- );
18644
+ node->src[1]->type == GGML_TYPE_F32) {
18645
+ cur += sizeof(float)*ne00*ne01*ne02;
18646
+ cur += sizeof(float)*ne10*ne11;
18189
18647
  } else {
18190
18648
  GGML_ASSERT(false);
18191
18649
  }
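
As a hedged sanity check of the new CONV_TRANSPOSE_1D work-size formula, with made-up shapes: the buffer holds the permuted kernel (ne00*ne01*ne02 elements) followed by the permuted source (ne10*ne11 elements).

#include <stdio.h>
#include <stdint.h>

// Illustrative work-size computation for GGML_OP_CONV_TRANSPOSE_1D with
// invented shapes: kernel K=5, Cout=64, Cin=32 and source L=1000, Cin=32.
int main(void) {
    const int64_t ne00 = 5, ne01 = 64, ne02 = 32; // K x Cout x Cin
    const int64_t ne10 = 1000, ne11 = 32;         // L x Cin

    size_t cur = 0;
    cur += sizeof(float)*ne00*ne01*ne02; // permuted kernel: 5*64*32 = 10240 floats
    cur += sizeof(float)*ne10*ne11;      // permuted source: 1000*32 = 32000 floats

    printf("work buffer: %zu bytes\n", cur); // (10240 + 32000) * 4 = 168960 bytes
    return 0;
}
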
@@ -19311,7 +19769,7 @@ static enum ggml_opt_result ggml_opt_adam(
19311
19769
  if (callback) {
19312
19770
  callback(callback_data, accum_step, &sched, &cancel);
19313
19771
  if (cancel) {
19314
- break;
19772
+ return GGML_OPT_CANCEL;
19315
19773
  }
19316
19774
  }
19317
19775
  // ggml_graph_reset (gf);
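
The cancellation path above now propagates GGML_OPT_CANCEL to the caller instead of breaking out of the loop. Below is a sketch of the user side, with the callback shape inferred from the call sites in this diff (data, accum_step, &sched, &cancel); the state struct and the step limit are made up.

#include <stdbool.h>
#include <stdio.h>

// Hypothetical cancel callback: once *cancel is set, the optimizer now
// returns GGML_OPT_CANCEL instead of silently breaking out of its loop.
struct my_cb_state {
    int calls;
    int max_calls;
};

static void my_opt_callback(void * data, int accum_step, float * sched, bool * cancel) {
    struct my_cb_state * st = (struct my_cb_state *) data;
    (void) accum_step;
    (void) sched;               // the callback may also adjust the schedule here
    if (++st->calls >= st->max_calls) {
        *cancel = true;         // request cancellation; the caller checks this flag
    }
}

int main(void) {
    struct my_cb_state st = { 0, 3 };
    float sched = 0.0f;
    bool cancel = false;
    for (int step = 0; step < 10 && !cancel; step++) {
        my_opt_callback(&st, step, &sched, &cancel);
        printf("step %d cancel=%d\n", step, cancel);
    }
    return 0;
}
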
@@ -19320,9 +19778,6 @@ static enum ggml_opt_result ggml_opt_adam(
19320
19778
  ggml_opt_acc_grad(np, ps, g, accum_norm);
19321
19779
  fx += ggml_get_f32_1d(f, 0);
19322
19780
  }
19323
- if (cancel) {
19324
- return GGML_OPT_DID_NOT_CONVERGE;
19325
- }
19326
19781
  fx *= accum_norm;
19327
19782
 
19328
19783
  opt->adam.fx_prev = fx;
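
For the accumulation block above, a tiny numeric illustration, assuming accum_norm is the reciprocal of the number of accumulation steps (as its use here suggests): the per-step losses are summed and then scaled to their mean. The loss values are invented.

#include <stdio.h>

// fx += loss per accumulation step, then fx *= accum_norm = 1/n_accum.
int main(void) {
    const float losses[4] = { 2.0f, 4.0f, 6.0f, 8.0f };
    const int n_accum = 4;
    const float accum_norm = 1.0f/n_accum;

    float fx = 0.0f;
    for (int accum_step = 0; accum_step < n_accum; accum_step++) {
        fx += losses[accum_step];   // corresponds to fx += ggml_get_f32_1d(f, 0);
    }
    fx *= accum_norm;               // corresponds to fx *= accum_norm;
    printf("fx = %f\n", fx);        // 5.0, the mean loss over the accumulation window
    return 0;
}
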
@@ -19348,9 +19803,6 @@ static enum ggml_opt_result ggml_opt_adam(
19348
19803
 
19349
19804
  // run the optimizer
19350
19805
  for (int t = 0; t < params.adam.n_iter; ++t) {
19351
- if (cancel) {
19352
- break;
19353
- }
19354
19806
  opt->iter = iter0 + t + 1;
19355
19807
  GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
19356
19808
 
@@ -19408,7 +19860,7 @@ static enum ggml_opt_result ggml_opt_adam(
19408
19860
  if (callback) {
19409
19861
  callback(callback_data, accum_step, &sched, &cancel);
19410
19862
  if (cancel) {
19411
- break;
19863
+ return GGML_OPT_CANCEL;
19412
19864
  }
19413
19865
  }
19414
19866
  // ggml_graph_reset (gf);
@@ -19417,9 +19869,6 @@ static enum ggml_opt_result ggml_opt_adam(
19417
19869
  ggml_opt_acc_grad(np, ps, g, accum_norm);
19418
19870
  fx += ggml_get_f32_1d(f, 0);
19419
19871
  }
19420
- if (cancel) {
19421
- break;
19422
- }
19423
19872
  fx *= accum_norm;
19424
19873
 
19425
19874
  opt->loss_after = fx;
@@ -19538,7 +19987,7 @@ static enum ggml_opt_result linesearch_backtracking(
19538
19987
  finit = *fx;
19539
19988
  dgtest = params->lbfgs.ftol*dginit;
19540
19989
 
19541
- while (!*cancel) {
19990
+ while (true) {
19542
19991
  ggml_vec_cpy_f32(nx, x, xp);
19543
19992
  ggml_vec_mad_f32(nx, x, d, *step);
19544
19993
 
@@ -19554,7 +20003,7 @@ static enum ggml_opt_result linesearch_backtracking(
19554
20003
  float sched = 0;
19555
20004
  callback(callback_data, accum_step, &sched, cancel);
19556
20005
  if (*cancel) {
19557
- break;
20006
+ return GGML_OPT_CANCEL;
19558
20007
  }
19559
20008
  }
19560
20009
  // ggml_graph_reset (gf);
@@ -19563,9 +20012,6 @@ static enum ggml_opt_result linesearch_backtracking(
19563
20012
  ggml_opt_acc_grad(np, ps, g, accum_norm);
19564
20013
  *fx += ggml_get_f32_1d(f, 0);
19565
20014
  }
19566
- if (*cancel) {
19567
- break;
19568
- }
19569
20015
  *fx *= accum_norm;
19570
20016
 
19571
20017
  }
@@ -19698,7 +20144,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
19698
20144
  float sched = 0;
19699
20145
  callback(callback_data, accum_step, &sched, &cancel);
19700
20146
  if (cancel) {
19701
- break;
20147
+ return GGML_OPT_CANCEL;
19702
20148
  }
19703
20149
  }
19704
20150
  // ggml_graph_reset (gf);
@@ -19707,9 +20153,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
19707
20153
  ggml_opt_acc_grad(np, ps, g, accum_norm);
19708
20154
  fx += ggml_get_f32_1d(f, 0);
19709
20155
  }
19710
- if (cancel) {
19711
- return GGML_OPT_DID_NOT_CONVERGE;
19712
- }
19713
20156
  fx *= accum_norm;
19714
20157
 
19715
20158
  opt->loss_before = fx;
@@ -19768,9 +20211,13 @@ static enum ggml_opt_result ggml_opt_lbfgs(
19768
20211
  ggml_vec_cpy_f32(nx, xp, x);
19769
20212
  ggml_vec_cpy_f32(nx, gp, g);
19770
20213
 
20214
+ // TODO: instead of passing &cancel here, use the return code of the linesearch
20215
+ // to determine if the optimization should be cancelled
20216
+ // this is a simple change, but not doing this atm, since I don't have a nice
20217
+ // way to test and don't want to break something with so many changes lined up
19771
20218
  ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data);
19772
- if (!cancel) {
19773
- break;
20219
+ if (cancel) {
20220
+ return GGML_OPT_CANCEL;
19774
20221
  }
19775
20222
 
19776
20223
  if (ls < 0) {