llama_cpp 0.6.0 → 0.7.1

@@ -162,40 +162,16 @@ typedef void * thread_ret_t;
162
162
 
163
163
  #define GGML_PRINT(...) printf(__VA_ARGS__)
164
164
 
165
+ //
166
+ // end of logging block
167
+ //
168
+
165
169
  #ifdef GGML_USE_ACCELERATE
166
170
  // uncomment to use vDSP for soft max computation
167
171
  // note: not sure if it is actually faster
168
172
  //#define GGML_SOFT_MAX_ACCELERATE
169
173
  #endif
170
174
 
171
- //
172
- // logging
173
- //
174
-
175
- #if (GGML_DEBUG >= 1)
176
- #define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
177
- #else
178
- #define GGML_PRINT_DEBUG(...)
179
- #endif
180
-
181
- #if (GGML_DEBUG >= 5)
182
- #define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
183
- #else
184
- #define GGML_PRINT_DEBUG_5(...)
185
- #endif
186
-
187
- #if (GGML_DEBUG >= 10)
188
- #define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
189
- #else
190
- #define GGML_PRINT_DEBUG_10(...)
191
- #endif
192
-
193
- #define GGML_PRINT(...) printf(__VA_ARGS__)
194
-
195
- //
196
- // end of logging block
197
- //
198
-
199
175
  #if defined(_MSC_VER) || defined(__MINGW32__)
200
176
  #define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
201
177
  #define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr)
@@ -1032,8 +1008,8 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r
1032
1008
  y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
1033
1009
 
1034
1010
  // get the 5-th bit and store it in qh at the right position
1035
- qh |= ((xi0 & 0x10) >> 4) << (j + 0);
1036
- qh |= ((xi1 & 0x10) >> 4) << (j + qk/2);
1011
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
1012
+ qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
1037
1013
  }
1038
1014
 
1039
1015
  memcpy(&y[i].qh, &qh, sizeof(qh));
@@ -1080,8 +1056,8 @@ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * r
1080
1056
  y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
1081
1057
 
1082
1058
  // get the 5-th bit and store it in qh at the right position
1083
- qh |= ((xi0 & 0x10) >> 4) << (j + 0);
1084
- qh |= ((xi1 & 0x10) >> 4) << (j + qk/2);
1059
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
1060
+ qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
1085
1061
  }
1086
1062
 
1087
1063
  memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
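For reference, both quantize_row_q5_0_reference and quantize_row_q5_1_reference split each 5-bit quant into a 4-bit nibble stored in qs plus one high bit collected into the 32-bit qh mask; the change above only makes the 0x10 mask literal unsigned. A minimal standalone sketch of the packing idea (the helper name is illustrative, not part of the library):

    #include <stdint.h>

    // Pack a pair of 5-bit values xi0, xi1 (each in [0, 31]) at nibble index j of a
    // block holding qk quants: the low 4 bits go into qs[j], the 5th bits into qh.
    static inline void pack_q5_pair(uint8_t * qs, uint32_t * qh, int j, int qk,
                                    uint8_t xi0, uint8_t xi1) {
        qs[j]  = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
        *qh   |= ((xi0 & 0x10u) >> 4) << (j + 0);
        *qh   |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
    }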
@@ -1272,6 +1248,33 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1272
1248
  _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
1273
1249
  #endif
1274
1250
  }
1251
+ #elif defined(__riscv_v_intrinsic)
1252
+
1253
+ size_t vl = __riscv_vsetvl_e32m4(QK8_0);
1254
+
1255
+ for (int i = 0; i < nb; i++) {
1256
+ // load elements
1257
+ vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
1258
+
1259
+ vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
1260
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
1261
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
1262
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
1263
+
1264
+ const float d = amax / ((1 << 7) - 1);
1265
+ const float id = d ? 1.0f/d : 0.0f;
1266
+
1267
+ y[i].d = GGML_FP32_TO_FP16(d);
1268
+
1269
+ vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
1270
+
1271
+ // convert to integer
1272
+ vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
1273
+ vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
1274
+
1275
+ // store result
1276
+ __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
1277
+ }
1275
1278
  #else
1276
1279
  // scalar
1277
1280
  quantize_row_q8_0_reference(x, y, k);
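The new __riscv_v_intrinsic branch mirrors the scalar reference path: reduce the absolute maximum of the block, derive the scale d = amax/127, multiply by the reciprocal, and narrow to int8. Roughly, per block i (a sketch of the reference logic; vfncvt uses the current rounding mode rather than roundf):

    float amax = 0.0f;                          // absolute max of the block
    for (int j = 0; j < QK8_0; j++) {
        amax = MAX(amax, fabsf(x[i*QK8_0 + j]));
    }
    const float d  = amax / ((1 << 7) - 1);
    const float id = d ? 1.0f/d : 0.0f;
    y[i].d = GGML_FP32_TO_FP16(d);
    for (int j = 0; j < QK8_0; j++) {
        y[i].qs[j] = (int8_t) roundf(x[i*QK8_0 + j]*id);
    }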
@@ -1490,6 +1493,41 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
1490
1493
  _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
1491
1494
  #endif
1492
1495
  }
1496
+ #elif defined(__riscv_v_intrinsic)
1497
+
1498
+ size_t vl = __riscv_vsetvl_e32m4(QK8_1);
1499
+
1500
+ for (int i = 0; i < nb; i++) {
1501
+ // load elements
1502
+ vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
1503
+
1504
+ vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
1505
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl);
1506
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
1507
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
1508
+
1509
+ const float d = amax / ((1 << 7) - 1);
1510
+ const float id = d ? 1.0f/d : 0.0f;
1511
+
1512
+ y[i].d = d;
1513
+
1514
+ vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
1515
+
1516
+ // convert to integer
1517
+ vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
1518
+ vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
1519
+
1520
+ // store result
1521
+ __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
1522
+
1523
+ // compute sum for y[i].s
1524
+ vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
1525
+ vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
1526
+
1527
+ // set y[i].s
1528
+ int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
1529
+ y[i].s = sum*d;
1530
+ }
1493
1531
  #else
1494
1532
  // scalar
1495
1533
  quantize_row_q8_1_reference(x, y, k);
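Compared to the q8_0 path, the q8_1 path additionally reduces the quantized values with vwredsum and stores the scaled block sum. In scalar terms the tail of the loop body is simply:

    int sum = 0;
    for (int j = 0; j < QK8_1; j++) {
        sum += y[i].qs[j];          // sum of the already-quantized int8 values
    }
    y[i].s = sum*d;                 // consumed later by the q4_1/q5_1 dot products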
@@ -2662,30 +2700,32 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2662
2700
  size_t vl = __riscv_vsetvl_e8m1(qk/2);
2663
2701
 
2664
2702
  for (int i = 0; i < nb; i++) {
2665
- vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
2703
+ // load elements
2704
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
2666
2705
 
2667
- vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
2668
- vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
2706
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
2707
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
2669
2708
 
2670
- vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
2671
- vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
2709
+ // mask and store lower part of x, and then upper part
2710
+ vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
2711
+ vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
2672
2712
 
2673
- vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
2674
- vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
2713
+ vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
2714
+ vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
2675
2715
 
2676
- vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl);
2677
- vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl);
2716
+ // subtract offset
2717
+ vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
2718
+ vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
2678
2719
 
2679
- vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
2680
- vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
2720
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
2721
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
2681
2722
 
2682
2723
  vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
2683
2724
 
2684
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
2685
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
2725
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
2726
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
2686
2727
 
2687
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
2688
- sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
2728
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
2689
2729
 
2690
2730
  sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
2691
2731
  }
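The rewrite moves the 4-bit operands into fractional-LMUL (mf2) registers and chains the two widening reductions by passing vs1 as the scalar accumulator of the second vwredsum, so only one scalar move per block is needed. The per-block arithmetic is unchanged; in scalar form (matching the reference implementation):

    int sumi = 0;
    for (int j = 0; j < qk/2; ++j) {
        const int v0 = (x[i].qs[j] & 0x0F) - 8;    // low nibble, re-centered at zero
        const int v1 = (x[i].qs[j] >>   4) - 8;    // high nibble
        sumi += v0*y[i].qs[j] + v1*y[i].qs[j + qk/2];
    }
    sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);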
@@ -2823,27 +2863,28 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
2823
2863
  size_t vl = __riscv_vsetvl_e8m1(qk/2);
2824
2864
 
2825
2865
  for (int i = 0; i < nb; i++) {
2826
- vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
2866
+ // load elements
2867
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
2827
2868
 
2828
- vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
2829
- vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
2869
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
2870
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
2830
2871
 
2831
- vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
2832
- vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
2872
+ // mask and store lower part of x, and then upper part
2873
+ vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
2874
+ vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
2833
2875
 
2834
- vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
2835
- vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
2876
+ vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
2877
+ vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
2836
2878
 
2837
- vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
2838
- vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
2879
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
2880
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
2839
2881
 
2840
2882
  vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
2841
2883
 
2842
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
2843
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
2884
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
2885
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
2844
2886
 
2845
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
2846
- sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
2887
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
2847
2888
 
2848
2889
  sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
2849
2890
  }
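The final accumulation line is where the precomputed y[i].s pays off: with x dequantized as d_x*q_x + m_x and y as d_y*q_y, the block dot product expands so that the offset term only needs the block sum of y:

    //   Σ_j (d_x*q_xj + m_x) * (d_y*q_yj)
    // = d_x*d_y * Σ_j q_xj*q_yj   +   m_x * (d_y * Σ_j q_yj)
    // = (x[i].d * y[i].d) * sumi  +   x[i].m * y[i].s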
@@ -3088,66 +3129,61 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
3088
3129
 
3089
3130
  uint32_t qh;
3090
3131
 
3091
- // These temp values are for masking and shift operations
3092
- uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
3093
- uint32_t temp_2[16] = {0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80,
3094
- 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000};
3095
-
3096
3132
  size_t vl = __riscv_vsetvl_e8m1(qk/2);
3097
3133
 
3134
+ // These temporary registers are for masking and shift operations
3135
+ vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
3136
+ vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
3137
+
3138
+ vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
3139
+ vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
3140
+
3098
3141
  for (int i = 0; i < nb; i++) {
3099
3142
  memcpy(&qh, x[i].qh, sizeof(uint32_t));
3100
3143
 
3101
- // temporary registers
3102
- vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_2, vl);
3103
- vuint32m4_t vt_2 = __riscv_vle32_v_u32m4(temp_1, vl);
3104
- vuint32m4_t vt_3 = __riscv_vsll_vx_u32m4(vt_1, 16, vl);
3105
- vuint32m4_t vt_4 = __riscv_vadd_vx_u32m4(vt_2, 12, vl);
3106
-
3107
3144
  // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
3108
- vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(vt_1, qh, vl);
3109
- vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(xha_0, vt_2, vl);
3110
- vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl);
3145
+ vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
3146
+ vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
3147
+ vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
3111
3148
 
3112
3149
  // ((qh & (1u << (j + 16))) >> (j + 12));
3113
- vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(vt_3, qh, vl);
3114
- vuint32m4_t xhl_1 = __riscv_vsrl_vv_u32m4(xha_1, vt_4, vl);
3150
+ vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
3151
+ vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
3115
3152
 
3116
3153
  // narrowing
3117
- vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xhl_0, vl);
3118
- vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl);
3154
+ vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
3155
+ vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
3119
3156
 
3120
- vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xhl_1, vl);
3121
- vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl);
3157
+ vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
3158
+ vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
3122
3159
 
3123
3160
  // load
3124
- vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
3161
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
3125
3162
 
3126
- vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
3127
- vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
3163
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
3164
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
3128
3165
 
3129
- vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
3130
- vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
3166
+ vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
3167
+ vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
3131
3168
 
3132
- vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl);
3133
- vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl);
3169
+ vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
3170
+ vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
3134
3171
 
3135
- vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
3136
- vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
3172
+ vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
3173
+ vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
3137
3174
 
3138
- vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 16, vl);
3139
- vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 16, vl);
3175
+ vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
3176
+ vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
3140
3177
 
3141
- vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
3142
- vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
3178
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
3179
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
3143
3180
 
3144
3181
  vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
3145
3182
 
3146
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
3147
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
3183
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
3184
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
3148
3185
 
3149
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
3150
- sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
3186
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
3151
3187
 
3152
3188
  sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
3153
3189
  }
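Instead of loading the temp_1/temp_2 shift and mask tables from memory, the new code generates them in registers with vid.v plus a vector shift, and drops from LMUL=4 to LMUL=2 so the block uses fewer vector registers. Per element j the vector sequence computes the same quantities as the scalar reference:

    const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;   // bit j      -> bit 4
    const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));        // bit j + 16 -> bit 4
    const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;              // re-center at zero
    const int32_t x1 = ((x[i].qs[j] >>   4) | xh_1) - 16;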
@@ -3414,62 +3450,58 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
3414
3450
 
3415
3451
  uint32_t qh;
3416
3452
 
3417
- // These temp values are for shift operations
3418
- uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
3419
-
3420
3453
  size_t vl = __riscv_vsetvl_e8m1(qk/2);
3421
3454
 
3455
+ // temporary registers for shift operations
3456
+ vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
3457
+ vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
3458
+
3422
3459
  for (int i = 0; i < nb; i++) {
3423
3460
  memcpy(&qh, x[i].qh, sizeof(uint32_t));
3424
3461
 
3425
- // temporary registers
3426
- vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_1, vl);
3427
- vuint32m4_t vt_2 = __riscv_vadd_vx_u32m4(vt_1, 12, vl);
3428
-
3429
3462
  // load qh
3430
- vuint32m4_t vqh = __riscv_vmv_v_x_u32m4(qh, vl);
3463
+ vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
3431
3464
 
3432
3465
  // ((qh >> (j + 0)) << 4) & 0x10;
3433
- vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(vqh, vt_1, vl);
3434
- vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl);
3435
- vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(xhl_0, 0x10, vl);
3466
+ vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
3467
+ vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
3468
+ vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
3436
3469
 
3437
3470
  // ((qh >> (j + 12)) ) & 0x10;
3438
- vuint32m4_t xhr_1 = __riscv_vsrl_vv_u32m4(vqh, vt_2, vl);
3439
- vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(xhr_1, 0x10, vl);
3471
+ vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
3472
+ vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
3440
3473
 
3441
3474
  // narrowing
3442
- vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xha_0, vl);
3443
- vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl);
3475
+ vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
3476
+ vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
3444
3477
 
3445
- vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xha_1, vl);
3446
- vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl);
3478
+ vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
3479
+ vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
3447
3480
 
3448
3481
  // load
3449
- vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
3482
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
3450
3483
 
3451
- vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
3452
- vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
3484
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
3485
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
3453
3486
 
3454
- vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
3455
- vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
3487
+ vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
3488
+ vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
3456
3489
 
3457
- vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl);
3458
- vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl);
3490
+ vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
3491
+ vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
3459
3492
 
3460
- vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
3461
- vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
3493
+ vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
3494
+ vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
3462
3495
 
3463
- vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
3464
- vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
3496
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
3497
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
3465
3498
 
3466
3499
  vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
3467
3500
 
3468
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
3469
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
3501
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
3502
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
3470
3503
 
3471
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
3472
- sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
3504
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
3473
3505
 
3474
3506
  sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
3475
3507
  }
@@ -4025,12 +4057,16 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
4025
4057
  "ALIBI",
4026
4058
  "CLAMP",
4027
4059
  "CONV_1D",
4060
+ "CONV_TRANSPOSE_1D",
4028
4061
  "CONV_2D",
4029
4062
  "CONV_TRANSPOSE_2D",
4030
4063
  "POOL_1D",
4031
4064
  "POOL_2D",
4032
4065
  "UPSCALE",
4033
4066
 
4067
+ "CONV_1D_STAGE_0",
4068
+ "CONV_1D_STAGE_1",
4069
+
4034
4070
  "FLASH_ATTN",
4035
4071
  "FLASH_FF",
4036
4072
  "FLASH_ATTN_BACK",
@@ -4056,7 +4092,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
4056
4092
  "CROSS_ENTROPY_LOSS_BACK",
4057
4093
  };
4058
4094
 
4059
- static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
4095
+ static_assert(GGML_OP_COUNT == 71, "GGML_OP_COUNT != 71");
4060
4096
 
4061
4097
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
4062
4098
  "none",
@@ -4107,12 +4143,16 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
4107
4143
  "alibi(x)",
4108
4144
  "clamp(x)",
4109
4145
  "conv_1d(x)",
4146
+ "conv_transpose_1d(x)",
4110
4147
  "conv_2d(x)",
4111
4148
  "conv_transpose_2d(x)",
4112
4149
  "pool_1d(x)",
4113
4150
  "pool_2d(x)",
4114
4151
  "upscale(x)",
4115
4152
 
4153
+ "conv_1d_stage_0(x)",
4154
+ "conv_1d_stage_1(x)",
4155
+
4116
4156
  "flash_attn(x)",
4117
4157
  "flash_ff(x)",
4118
4158
  "flash_attn_back(x)",
@@ -4138,7 +4178,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
4138
4178
  "cross_entropy_loss_back(x,y)",
4139
4179
  };
4140
4180
 
4141
- static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
4181
+ static_assert(GGML_OP_COUNT == 71, "GGML_OP_COUNT != 71");
4142
4182
 
4143
4183
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
4144
4184
 
@@ -4167,7 +4207,10 @@ static void ggml_setup_op_has_task_pass(void) {
4167
4207
  p[GGML_OP_DIAG_MASK_INF ] = true;
4168
4208
  p[GGML_OP_DIAG_MASK_ZERO ] = true;
4169
4209
  p[GGML_OP_CONV_1D ] = true;
4210
+ p[GGML_OP_CONV_1D_STAGE_0 ] = true;
4211
+ p[GGML_OP_CONV_1D_STAGE_1 ] = true;
4170
4212
  p[GGML_OP_CONV_2D ] = true;
4213
+ p[GGML_OP_CONV_TRANSPOSE_1D ] = true;
4171
4214
  p[GGML_OP_CONV_TRANSPOSE_2D ] = true;
4172
4215
  p[GGML_OP_FLASH_ATTN_BACK ] = true;
4173
4216
  p[GGML_OP_CROSS_ENTROPY_LOSS ] = true;
@@ -4884,6 +4927,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
4884
4927
  *result = (struct ggml_tensor) {
4885
4928
  /*.type =*/ type,
4886
4929
  /*.backend =*/ GGML_BACKEND_CPU,
4930
+ /*.buffer =*/ NULL,
4887
4931
  /*.n_dims =*/ n_dims,
4888
4932
  /*.ne =*/ { 1, 1, 1, 1 },
4889
4933
  /*.nb =*/ { 0, 0, 0, 0 },
@@ -5450,6 +5494,39 @@ struct ggml_tensor * ggml_view_tensor(
5450
5494
  return result;
5451
5495
  }
5452
5496
 
5497
+ struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx) {
5498
+ struct ggml_object * obj = ctx->objects_begin;
5499
+
5500
+ char * const mem_buffer = ctx->mem_buffer;
5501
+
5502
+ while (obj != NULL) {
5503
+ if (obj->type == GGML_OBJECT_TENSOR) {
5504
+ return (struct ggml_tensor *)(mem_buffer + obj->offs);
5505
+ }
5506
+
5507
+ obj = obj->next;
5508
+ }
5509
+
5510
+ return NULL;
5511
+ }
5512
+
5513
+ struct ggml_tensor * ggml_get_next_tensor(struct ggml_context * ctx, struct ggml_tensor * tensor) {
5514
+ struct ggml_object * obj = (struct ggml_object *) ((char *)tensor - GGML_OBJECT_SIZE);
5515
+ obj = obj->next;
5516
+
5517
+ char * const mem_buffer = ctx->mem_buffer;
5518
+
5519
+ while (obj != NULL) {
5520
+ if (obj->type == GGML_OBJECT_TENSOR) {
5521
+ return (struct ggml_tensor *)(mem_buffer + obj->offs);
5522
+ }
5523
+
5524
+ obj = obj->next;
5525
+ }
5526
+
5527
+ return NULL;
5528
+ }
5529
+
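ggml_get_first_tensor and ggml_get_next_tensor walk the context's object list and skip non-tensor objects, giving callers a simple way to enumerate every tensor allocated in a context. A usage sketch:

    for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL;
         t = ggml_get_next_tensor(ctx, t)) {
        printf("%-32s %8.2f KiB\n", t->name, ggml_nbytes(t)/1024.0);
    }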
5453
5530
  struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) {
5454
5531
  struct ggml_object * obj = ctx->objects_begin;
5455
5532
 
@@ -6690,7 +6767,6 @@ struct ggml_tensor * ggml_cont_4d(
6690
6767
  return result;
6691
6768
  }
6692
6769
 
6693
-
6694
6770
  // ggml_reshape
6695
6771
 
6696
6772
  struct ggml_tensor * ggml_reshape(
@@ -7448,14 +7524,17 @@ static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p,
7448
7524
  return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
7449
7525
  }
7450
7526
 
7451
- GGML_API struct ggml_tensor * ggml_conv_1d(
7452
- struct ggml_context * ctx,
7453
- struct ggml_tensor * a,
7454
- struct ggml_tensor * b,
7455
- int s0,
7456
- int p0,
7457
- int d0) {
7458
- GGML_ASSERT(ggml_is_matrix(b));
7527
+ // im2col: [N, IC, IL] => [N, OL, IC*K]
7528
+ // a: [OC,IC, K]
7529
+ // b: [N, IC, IL]
7530
+ // result: [N, OL, IC*K]
7531
+ static struct ggml_tensor * ggml_conv_1d_stage_0(
7532
+ struct ggml_context * ctx,
7533
+ struct ggml_tensor * a,
7534
+ struct ggml_tensor * b,
7535
+ int s0,
7536
+ int p0,
7537
+ int d0) {
7459
7538
  GGML_ASSERT(a->ne[1] == b->ne[1]);
7460
7539
  bool is_node = false;
7461
7540
 
@@ -7464,16 +7543,54 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
7464
7543
  is_node = true;
7465
7544
  }
7466
7545
 
7546
+ const int64_t OL = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
7547
+
7467
7548
  const int64_t ne[4] = {
7468
- ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
7469
- a->ne[2], 1, 1,
7549
+ a->ne[1] * a->ne[0],
7550
+ OL,
7551
+ b->ne[2],
7552
+ 1,
7470
7553
  };
7471
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
7554
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
7472
7555
 
7473
7556
  int32_t params[] = { s0, p0, d0 };
7474
7557
  ggml_set_op_params(result, params, sizeof(params));
7475
7558
 
7476
- result->op = GGML_OP_CONV_1D;
7559
+ result->op = GGML_OP_CONV_1D_STAGE_0;
7560
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7561
+ result->src[0] = a;
7562
+ result->src[1] = b;
7563
+
7564
+ return result;
7565
+ }
7566
+
7567
+ // ggml_conv_1d_stage_1
7568
+
7569
+ // gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
7570
+ // a: [OC, IC, K]
7571
+ // b: [N, OL, IC * K]
7572
+ // result: [N, OC, OL]
7573
+ static struct ggml_tensor * ggml_conv_1d_stage_1(
7574
+ struct ggml_context * ctx,
7575
+ struct ggml_tensor * a,
7576
+ struct ggml_tensor * b) {
7577
+
7578
+ bool is_node = false;
7579
+
7580
+ if (a->grad || b->grad) {
7581
+ GGML_ASSERT(false); // TODO: implement backward
7582
+ is_node = true;
7583
+ }
7584
+
7585
+ const int64_t ne[4] = {
7586
+ b->ne[1],
7587
+ a->ne[2],
7588
+ b->ne[2],
7589
+ 1,
7590
+ };
7591
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
7592
+
7593
+ result->op = GGML_OP_CONV_1D_STAGE_1;
7477
7594
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7478
7595
  result->src[0] = a;
7479
7596
  result->src[1] = b;
@@ -7481,6 +7598,53 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
7481
7598
  return result;
7482
7599
  }
7483
7600
 
7601
+ // ggml_conv_1d
7602
+
7603
+ GGML_API struct ggml_tensor * ggml_conv_1d(
7604
+ struct ggml_context * ctx,
7605
+ struct ggml_tensor * a,
7606
+ struct ggml_tensor * b,
7607
+ int s0,
7608
+ int p0,
7609
+ int d0) {
7610
+ struct ggml_tensor * result = ggml_conv_1d_stage_0(ctx, a, b, s0, p0, d0);
7611
+ result = ggml_conv_1d_stage_1(ctx, a, result);
7612
+ return result;
7613
+ }
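ggml_conv_1d is now lowered into two ops: an im2col pass (stage_0) that produces an F16 buffer of shape [N, OL, IC*K], and a GEMM (stage_1) against the kernel viewed as [OC, IC*K], yielding [N, OC, OL]. The output length follows ggml_calc_conv_output_size above:

    //   OL = (IL + 2*p0 - d0*(K - 1) - 1)/s0 + 1
    // e.g. IL = 32, K = 3, s0 = 1, p0 = 1, d0 = 1:
    //   OL = (32 + 2 - 2 - 1)/1 + 1 = 32      ("same"-style padding)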
7614
+
7615
+ // GGML_API struct ggml_tensor * ggml_conv_1d(
7616
+ // struct ggml_context * ctx,
7617
+ // struct ggml_tensor * a,
7618
+ // struct ggml_tensor * b,
7619
+ // int s0,
7620
+ // int p0,
7621
+ // int d0) {
7622
+ // GGML_ASSERT(ggml_is_matrix(b));
7623
+ // GGML_ASSERT(a->ne[1] == b->ne[1]);
7624
+ // bool is_node = false;
7625
+
7626
+ // if (a->grad || b->grad) {
7627
+ // GGML_ASSERT(false); // TODO: implement backward
7628
+ // is_node = true;
7629
+ // }
7630
+
7631
+ // const int64_t ne[4] = {
7632
+ // ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
7633
+ // a->ne[2], 1, 1,
7634
+ // };
7635
+ // struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
7636
+
7637
+ // int32_t params[] = { s0, p0, d0 };
7638
+ // ggml_set_op_params(result, params, sizeof(params));
7639
+
7640
+ // result->op = GGML_OP_CONV_1D;
7641
+ // result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7642
+ // result->src[0] = a;
7643
+ // result->src[1] = b;
7644
+
7645
+ // return result;
7646
+ // }
7647
+
7484
7648
  // ggml_conv_1d_ph
7485
7649
 
7486
7650
  struct ggml_tensor* ggml_conv_1d_ph(
@@ -7492,6 +7656,50 @@ struct ggml_tensor* ggml_conv_1d_ph(
7492
7656
  return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
7493
7657
  }
7494
7658
 
7659
+ // ggml_conv_transpose_1d
7660
+
7661
+ static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
7662
+ return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
7663
+ }
7664
+
7665
+ GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
7666
+ struct ggml_context * ctx,
7667
+ struct ggml_tensor * a,
7668
+ struct ggml_tensor * b,
7669
+ int s0,
7670
+ int p0,
7671
+ int d0) {
7672
+ GGML_ASSERT(ggml_is_matrix(b));
7673
+ GGML_ASSERT(a->ne[2] == b->ne[1]);
7674
+ GGML_ASSERT(a->ne[3] == 1);
7675
+
7676
+ GGML_ASSERT(p0 == 0);
7677
+ GGML_ASSERT(d0 == 1);
7678
+
7679
+ bool is_node = false;
7680
+
7681
+ if (a->grad || b->grad) {
7682
+ GGML_ASSERT(false); // TODO: implement backward
7683
+ is_node = true;
7684
+ }
7685
+
7686
+ const int64_t ne[4] = {
7687
+ ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
7688
+ a->ne[1], b->ne[2], 1,
7689
+ };
7690
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
7691
+
7692
+ int32_t params[] = { s0, p0, d0 };
7693
+ ggml_set_op_params(result, params, sizeof(params));
7694
+
7695
+ result->op = GGML_OP_CONV_TRANSPOSE_1D;
7696
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7697
+ result->src[0] = a;
7698
+ result->src[1] = b;
7699
+
7700
+ return result;
7701
+ }
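ggml_conv_transpose_1d uses the inverse size relation, and the implementation currently asserts p0 == 0 and d0 == 1:

    //   OL = (IL - 1)*s0 - 2*p0 + d0*(K - 1) + 1
    // with p0 = 0 and d0 = 1 this reduces to OL = (IL - 1)*s0 + K,
    // e.g. IL = 16, K = 4, s0 = 2  ->  OL = 15*2 + 4 = 34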
7702
+
7495
7703
  // ggml_conv_2d
7496
7704
 
7497
7705
  struct ggml_tensor * ggml_conv_2d(
@@ -8472,6 +8680,7 @@ void ggml_set_param(
8472
8680
 
8473
8681
  GGML_ASSERT(tensor->grad == NULL);
8474
8682
  tensor->grad = ggml_dup_tensor(ctx, tensor);
8683
+ ggml_format_name(tensor->grad, "%s (grad)", tensor->name);
8475
8684
  }
8476
8685
 
8477
8686
  // ggml_compute_forward_dup
@@ -11058,7 +11267,7 @@ static void ggml_compute_forward_silu_f32(
11058
11267
 
11059
11268
  #ifndef NDEBUG
11060
11269
  for (int k = 0; k < nc; k++) {
11061
- const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
11270
+ const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
11062
11271
  UNUSED(x);
11063
11272
  assert(!isnan(x));
11064
11273
  assert(!isinf(x));
@@ -11621,11 +11830,6 @@ static void ggml_compute_forward_mul_mat(
11621
11830
 
11622
11831
  #if defined(GGML_USE_CLBLAST)
11623
11832
  if (ggml_cl_can_mul_mat(src0, src1, dst)) {
11624
- // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
11625
- // ref: https://github.com/ggerganov/ggml/pull/224
11626
- GGML_ASSERT(ne02 == ne12);
11627
- GGML_ASSERT(ne03 == ne13);
11628
-
11629
11833
  if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
11630
11834
  ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
11631
11835
  }
@@ -12889,28 +13093,25 @@ static void ggml_compute_forward_alibi_f32(
12889
13093
  return;
12890
13094
  }
12891
13095
 
12892
- const int n_past = ((int32_t *) dst->op_params)[0];
13096
+ //const int n_past = ((int32_t *) dst->op_params)[0];
12893
13097
  const int n_head = ((int32_t *) dst->op_params)[1];
12894
13098
  float max_bias;
12895
13099
  memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
12896
13100
 
12897
- assert(n_past >= 0);
13101
+ const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
13102
+ const int64_t ne1 = src0->ne[1]; // seq_len_without_past
13103
+ const int64_t ne2 = src0->ne[2]; // n_head -> this is k
13104
+ //const int64_t ne3 = src0->ne[3]; // 1 -> bsz
12898
13105
 
12899
- const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
12900
- const int ne1 = src0->ne[1]; // seq_len_without_past
12901
- const int ne2 = src0->ne[2]; // n_head -> this is k
12902
- //const int ne3 = src0->ne[3]; // 1 -> bsz
13106
+ const int64_t n = ggml_nrows(src0);
13107
+ const int64_t ne2_ne3 = n/ne1; // ne2*ne3
12903
13108
 
12904
- const int n = ggml_nrows(src0);
12905
- const int ne2_ne3 = n/ne1; // ne2*ne3
12906
-
12907
- const int nb0 = src0->nb[0];
12908
- const int nb1 = src0->nb[1];
12909
- const int nb2 = src0->nb[2];
13109
+ const size_t nb0 = src0->nb[0];
13110
+ const size_t nb1 = src0->nb[1];
13111
+ const size_t nb2 = src0->nb[2];
12910
13112
  //const int nb3 = src0->nb[3];
12911
13113
 
12912
13114
  GGML_ASSERT(nb0 == sizeof(float));
12913
- GGML_ASSERT(ne1 + n_past == ne0);
12914
13115
  GGML_ASSERT(n_head == ne2);
12915
13116
 
12916
13117
  // add alibi to src0 (KQ_scaled)
@@ -12919,9 +13120,9 @@ static void ggml_compute_forward_alibi_f32(
12919
13120
  const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
12920
13121
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
12921
13122
 
12922
- for (int i = 0; i < ne0; i++) {
12923
- for (int j = 0; j < ne1; j++) {
12924
- for (int k = 0; k < ne2_ne3; k++) {
13123
+ for (int64_t i = 0; i < ne0; i++) {
13124
+ for (int64_t j = 0; j < ne1; j++) {
13125
+ for (int64_t k = 0; k < ne2_ne3; k++) {
12925
13126
  float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
12926
13127
  float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
12927
13128
 
@@ -12936,7 +13137,6 @@ static void ggml_compute_forward_alibi_f32(
12936
13137
  }
12937
13138
 
12938
13139
  pdst[0] = i * m_k + src[0];
12939
-
12940
13140
  }
12941
13141
  }
12942
13142
  }
@@ -13636,7 +13836,7 @@ static void ggml_compute_forward_rope_back(
13636
13836
 
13637
13837
  // ggml_compute_forward_conv_1d
13638
13838
 
13639
- static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
13839
+ static void ggml_compute_forward_conv_1d_f16_f32(
13640
13840
  const struct ggml_compute_params * params,
13641
13841
  const struct ggml_tensor * src0,
13642
13842
  const struct ggml_tensor * src1,
@@ -13654,42 +13854,33 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
13654
13854
  const int nth = params->nth;
13655
13855
 
13656
13856
  const int nk = ne00;
13657
- const int nh = nk/2;
13658
13857
 
13659
- const int ew0 = ggml_up32(ne01);
13858
+ // size of the convolution row - the kernel size unrolled across all input channels
13859
+ const int ew0 = nk*ne01;
13860
+
13861
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
13862
+ const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
13863
+ const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
13660
13864
 
13661
- GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
13662
13865
  GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
13663
13866
  GGML_ASSERT(nb10 == sizeof(float));
13664
13867
 
13665
13868
  if (params->type == GGML_TASK_INIT) {
13666
- // TODO: fix this memset (wsize is overestimated)
13667
13869
  memset(params->wdata, 0, params->wsize);
13668
13870
 
13669
- // prepare kernel data (src0)
13670
- {
13671
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
13871
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
13672
13872
 
13673
- for (int64_t i02 = 0; i02 < ne02; i02++) {
13674
- for (int64_t i01 = 0; i01 < ne01; i01++) {
13675
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
13676
- ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
13677
- for (int64_t i00 = 0; i00 < ne00; i00++) {
13678
- dst_data[i00*ew0 + i01] = src[i00];
13679
- }
13680
- }
13681
- }
13682
- }
13873
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
13874
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
13875
+ ggml_fp16_t * dst_data = wdata;
13683
13876
 
13684
- // prepare source data (src1)
13685
- {
13686
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
13877
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
13878
+ for (int64_t ik = 0; ik < nk; ik++) {
13879
+ const int idx0 = i0*s0 + ik*d0 - p0;
13687
13880
 
13688
- for (int64_t i11 = 0; i11 < ne11; i11++) {
13689
- const float * const src = (float *)((char *) src1->data + i11*nb11);
13690
- ggml_fp16_t * dst_data = wdata;
13691
- for (int64_t i10 = 0; i10 < ne10; i10++) {
13692
- dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
13881
+ if(!(idx0 < 0 || idx0 >= ne10)) {
13882
+ dst_data[i0*ew0 + i11*nk + ik] = GGML_FP32_TO_FP16(src[idx0]);
13883
+ }
13693
13884
  }
13694
13885
  }
13695
13886
  }
@@ -13702,7 +13893,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
13702
13893
  }
13703
13894
 
13704
13895
  // total rows in dst
13705
- const int nr = ne02;
13896
+ const int nr = ne2;
13706
13897
 
13707
13898
  // rows per thread
13708
13899
  const int dr = (nr + nth - 1)/nth;
@@ -13711,23 +13902,22 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
13711
13902
  const int ir0 = dr*ith;
13712
13903
  const int ir1 = MIN(ir0 + dr, nr);
13713
13904
 
13714
- for (int i1 = ir0; i1 < ir1; i1++) {
13715
- float * dst_data = (float *)((char *) dst->data + i1*nb1);
13716
- for (int64_t i0 = 0; i0 < ne10; ++i0) {
13717
- dst_data[i0] = 0;
13718
- for (int k = -nh; k <= nh; k++) {
13719
- float v = 0.0f;
13720
- ggml_vec_dot_f16(ew0, &v,
13721
- (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
13722
- (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
13723
-
13724
- dst_data[i0] += v;
13905
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
13906
+
13907
+ for (int i2 = 0; i2 < ne2; i2++) {
13908
+ for (int i1 = ir0; i1 < ir1; i1++) {
13909
+ float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
13910
+
13911
+ for (int i0 = 0; i0 < ne0; i0++) {
13912
+ ggml_vec_dot_f16(ew0, dst_data + i0,
13913
+ (ggml_fp16_t *) ((char *) src0->data + i1*nb02),
13914
+ (ggml_fp16_t *) wdata + i2*nb2 + i0*ew0);
13725
13915
  }
13726
13916
  }
13727
13917
  }
13728
13918
  }
13729
13919
 
13730
- static void ggml_compute_forward_conv_1d_s1_ph_f32(
13920
+ static void ggml_compute_forward_conv_1d_f32(
13731
13921
  const struct ggml_compute_params * params,
13732
13922
  const struct ggml_tensor * src0,
13733
13923
  const struct ggml_tensor * src1,
@@ -13745,42 +13935,32 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32(
13745
13935
  const int nth = params->nth;
13746
13936
 
13747
13937
  const int nk = ne00;
13748
- const int nh = nk/2;
13749
13938
 
13750
- const int ew0 = ggml_up32(ne01);
13939
+ const int ew0 = nk*ne01;
13940
+
13941
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
13942
+ const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
13943
+ const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
13751
13944
 
13752
- GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
13753
13945
  GGML_ASSERT(nb00 == sizeof(float));
13754
13946
  GGML_ASSERT(nb10 == sizeof(float));
13755
13947
 
13756
13948
  if (params->type == GGML_TASK_INIT) {
13757
- // TODO: fix this memset (wsize is overestimated)
13758
13949
  memset(params->wdata, 0, params->wsize);
13759
13950
 
13760
- // prepare kernel data (src0)
13761
- {
13762
- float * const wdata = (float *) params->wdata + 0;
13951
+ float * const wdata = (float *) params->wdata + 0;
13763
13952
 
13764
- for (int64_t i02 = 0; i02 < ne02; i02++) {
13765
- for (int64_t i01 = 0; i01 < ne01; i01++) {
13766
- const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
13767
- float * dst_data = wdata + i02*ew0*ne00;
13768
- for (int64_t i00 = 0; i00 < ne00; i00++) {
13769
- dst_data[i00*ew0 + i01] = src[i00];
13770
- }
13771
- }
13772
- }
13773
- }
13953
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
13954
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
13955
+ float * dst_data = wdata;
13774
13956
 
13775
- // prepare source data (src1)
13776
- {
13777
- float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
13957
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
13958
+ for (int64_t ik = 0; ik < nk; ik++) {
13959
+ const int idx0 = i0*s0 + ik*d0 - p0;
13778
13960
 
13779
- for (int64_t i11 = 0; i11 < ne11; i11++) {
13780
- const float * const src = (float *)((char *) src1->data + i11*nb11);
13781
- float * dst_data = wdata;
13782
- for (int64_t i10 = 0; i10 < ne10; i10++) {
13783
- dst_data[(i10 + nh)*ew0 + i11] = src[i10];
13961
+ if(!(idx0 < 0 || idx0 >= ne10)) {
13962
+ dst_data[i0*ew0 + i11*nk + ik] = src[idx0];
13963
+ }
13784
13964
  }
13785
13965
  }
13786
13966
  }
@@ -13802,35 +13982,242 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32(
13802
13982
  const int ir0 = dr*ith;
13803
13983
  const int ir1 = MIN(ir0 + dr, nr);
13804
13984
 
13805
- for (int i1 = ir0; i1 < ir1; i1++) {
13806
- float * dst_data = (float *)((char *) dst->data + i1*nb1);
13807
- for (int64_t i0 = 0; i0 < ne10; ++i0) {
13808
- dst_data[i0] = 0;
13809
- for (int k = -nh; k <= nh; k++) {
13810
- float v = 0.0f;
13811
- ggml_vec_dot_f32(ew0, &v,
13812
- (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
13813
- (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
13814
-
13815
- dst_data[i0] += v;
13985
+ float * const wdata = (float *) params->wdata + 0;
13986
+
13987
+ for (int i2 = 0; i2 < ne2; i2++) {
13988
+ for (int i1 = ir0; i1 < ir1; i1++) {
13989
+ float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
13990
+
13991
+ for (int i0 = 0; i0 < ne0; i0++) {
13992
+ ggml_vec_dot_f32(ew0, dst_data + i0,
13993
+ (float *) ((char *) src0->data + i1*nb02),
13994
+ (float *) wdata + i2*nb2 + i0*ew0);
13995
+ }
13996
+ }
13997
+ }
13998
+ }
13999
+
14000
+ static void gemm_f16_out_f32(int64_t m, int64_t n, int64_t k,
14001
+ ggml_fp16_t * A,
14002
+ ggml_fp16_t * B,
14003
+ float * C,
14004
+ const int ith, const int nth) {
14005
+ // does not seem to make a difference
14006
+ int64_t m0, m1, n0, n1;
14007
+ // patches per thread
14008
+ if (m > n) {
14009
+ n0 = 0;
14010
+ n1 = n;
14011
+
14012
+ // total patches in dst
14013
+ const int np = m;
14014
+
14015
+ // patches per thread
14016
+ const int dp = (np + nth - 1)/nth;
14017
+
14018
+ // patch range for this thread
14019
+ m0 = dp*ith;
14020
+ m1 = MIN(m0 + dp, np);
14021
+ } else {
14022
+ m0 = 0;
14023
+ m1 = m;
14024
+
14025
+ // total patches in dst
14026
+ const int np = n;
14027
+
14028
+ // patches per thread
14029
+ const int dp = (np + nth - 1)/nth;
14030
+
14031
+ // patch range for this thread
14032
+ n0 = dp*ith;
14033
+ n1 = MIN(n0 + dp, np);
14034
+ }
14035
+
14036
+ // block-tiling attempt
14037
+ int64_t blck_n = 16;
14038
+ int64_t blck_m = 16;
14039
+
14040
+ // int64_t CACHE_SIZE = 2 * 1024 * 1024; // 2MB
14041
+ // int64_t blck_size = CACHE_SIZE / (sizeof(float) + 2 * sizeof(ggml_fp16_t) * K);
14042
+ // if (blck_size > 0) {
14043
+ // blck_0 = 4;
14044
+ // blck_1 = blck_size / blck_0;
14045
+ // if (blck_1 < 0) {
14046
+ // blck_1 = 1;
14047
+ // }
14048
+ // // blck_0 = (int64_t)sqrt(blck_size);
14049
+ // // blck_1 = blck_0;
14050
+ // }
14051
+ // // printf("%zd %zd %zd %zd\n", blck_size, K, blck_0, blck_1);
14052
+
14053
+ for (int j = n0; j < n1; j+=blck_n) {
14054
+ for (int i = m0; i < m1; i+=blck_m) {
14055
+ // printf("i j k => %d %d %d\n", i, j, K);
14056
+ for (int ii = i; ii < i + blck_m && ii < m1; ii++) {
14057
+ for (int jj = j; jj < j + blck_n && jj < n1; jj++) {
14058
+ ggml_vec_dot_f16(k,
14059
+ C + ii*n + jj,
14060
+ A + ii * k,
14061
+ B + jj * k);
14062
+ }
13816
14063
  }
13817
14064
  }
13818
14065
  }
13819
14066
  }
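gemm_f16_out_f32 computes C = A*B^T with f16 inputs and f32 accumulation, partitioning whichever of the two output dimensions is larger across threads and tiling the loops in 16x16 blocks. Stripped of threading and blocking it is just:

    // C is m x n (f32), A is m x k, B is n x k (both row-major f16)
    for (int64_t ii = 0; ii < m; ii++) {
        for (int64_t jj = 0; jj < n; jj++) {
            ggml_vec_dot_f16(k, C + ii*n + jj, A + ii*k, B + jj*k);
        }
    }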
13820
14067
 
13821
- static void ggml_compute_forward_conv_1d_s1_ph(
14068
+ // src0: kernel [OC, IC, K]
14069
+ // src1: signal [N, IC, IL]
14070
+ // dst: result [N, OL, IC*K]
14071
+ static void ggml_compute_forward_conv_1d_stage_0_f32(
13822
14072
  const struct ggml_compute_params * params,
13823
14073
  const struct ggml_tensor * src0,
13824
14074
  const struct ggml_tensor * src1,
13825
14075
  struct ggml_tensor * dst) {
13826
- switch (src0->type) {
14076
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
14077
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
14078
+ GGML_ASSERT( dst->type == GGML_TYPE_F16);
14079
+
14080
+ int64_t t0 = ggml_perf_time_us();
14081
+ UNUSED(t0);
14082
+
14083
+ GGML_TENSOR_BINARY_OP_LOCALS;
14084
+
14085
+ const int64_t N = ne12;
14086
+ const int64_t IC = ne11;
14087
+ const int64_t IL = ne10;
14088
+
14089
+ const int64_t K = ne00;
14090
+
14091
+ const int64_t OL = ne1;
14092
+
14093
+ const int ith = params->ith;
14094
+ const int nth = params->nth;
14095
+
14096
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
14097
+ const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
14098
+ const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
14099
+
14100
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
14101
+ GGML_ASSERT(nb10 == sizeof(float));
14102
+
14103
+ if (params->type == GGML_TASK_INIT) {
14104
+ memset(dst->data, 0, ggml_nbytes(dst));
14105
+ return;
14106
+ }
14107
+
14108
+ if (params->type == GGML_TASK_FINALIZE) {
14109
+ return;
14110
+ }
14111
+
14112
+ // im2col: [N, IC, IL] => [N, OL, IC*K]
14113
+ {
14114
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
14115
+
14116
+ for (int64_t in = 0; in < N; in++) {
14117
+ for (int64_t iol = 0; iol < OL; iol++) {
14118
+ for (int64_t iic = ith; iic < IC; iic+=nth) {
14119
+
14120
+ // micro kernel
14121
+ ggml_fp16_t * dst_data = wdata + (in*OL + iol)*(IC*K); // [IC, K]
14122
+ const float * const src_data = (float *)((char *) src1->data + in*nb12 + iic*nb11); // [IL]
14123
+
14124
+ for (int64_t ik = 0; ik < K; ik++) {
14125
+ const int64_t iil = iol*s0 + ik*d0 - p0;
14126
+
14127
+ if (!(iil < 0 || iil >= IL)) {
14128
+ dst_data[iic*K + ik] = GGML_FP32_TO_FP16(src_data[iil]);
14129
+ }
14130
+ }
14131
+ }
14132
+ }
14133
+ }
14134
+ }
14135
+ }
14136
+
14137
+ // gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
14138
+ // src0: [OC, IC, K]
14139
+ // src1: [N, OL, IC * K]
14140
+ // result: [N, OC, OL]
14141
+ static void ggml_compute_forward_conv_1d_stage_1_f16(
14142
+ const struct ggml_compute_params * params,
14143
+ const struct ggml_tensor * src0,
14144
+ const struct ggml_tensor * src1,
14145
+ struct ggml_tensor * dst) {
14146
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
14147
+ GGML_ASSERT(src1->type == GGML_TYPE_F16);
14148
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
14149
+
14150
+ int64_t t0 = ggml_perf_time_us();
14151
+ UNUSED(t0);
14152
+
14153
+ if (params->type == GGML_TASK_INIT) {
14154
+ return;
14155
+ }
14156
+
14157
+ if (params->type == GGML_TASK_FINALIZE) {
14158
+ return;
14159
+ }
14160
+
14161
+ GGML_TENSOR_BINARY_OP_LOCALS;
14162
+
14163
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
14164
+ GGML_ASSERT(nb10 == sizeof(ggml_fp16_t));
14165
+ GGML_ASSERT(nb0 == sizeof(float));
14166
+
14167
+ const int N = ne12;
14168
+ const int OL = ne11;
14169
+
14170
+ const int OC = ne02;
14171
+ const int IC = ne01;
14172
+ const int K = ne00;
14173
+
14174
+ const int ith = params->ith;
14175
+ const int nth = params->nth;
14176
+
14177
+ int64_t m = OC;
14178
+ int64_t n = OL;
14179
+ int64_t k = IC * K;
14180
+
14181
+ // [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
14182
+ for (int i = 0; i < N; i++) {
14183
+ ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
14184
+ ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k]
14185
+ float * C = (float *)dst->data + i * m * n; // [m, n]
14186
+
14187
+ gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
14188
+ }
14189
+ }
14190
+
14191
+ static void ggml_compute_forward_conv_1d(
14192
+ const struct ggml_compute_params * params,
14193
+ const struct ggml_tensor * src0,
14194
+ const struct ggml_tensor * src1,
14195
+ struct ggml_tensor * dst) {
14196
+ switch(src0->type) {
13827
14197
  case GGML_TYPE_F16:
13828
14198
  {
13829
- ggml_compute_forward_conv_1d_s1_ph_f16_f32(params, src0, src1, dst);
14199
+ ggml_compute_forward_conv_1d_f16_f32(params, src0, src1, dst);
13830
14200
  } break;
13831
14201
  case GGML_TYPE_F32:
13832
14202
  {
13833
- ggml_compute_forward_conv_1d_s1_ph_f32(params, src0, src1, dst);
14203
+ ggml_compute_forward_conv_1d_f32(params, src0, src1, dst);
14204
+ } break;
14205
+ default:
14206
+ {
14207
+ GGML_ASSERT(false);
14208
+ } break;
14209
+ }
14210
+ }
14211
+
14212
+ static void ggml_compute_forward_conv_1d_stage_0(
14213
+ const struct ggml_compute_params * params,
14214
+ const struct ggml_tensor * src0,
14215
+ const struct ggml_tensor * src1,
14216
+ struct ggml_tensor * dst) {
14217
+ switch(src0->type) {
14218
+ case GGML_TYPE_F16:
14219
+ {
14220
+ ggml_compute_forward_conv_1d_stage_0_f32(params, src0, src1, dst);
13834
14221
  } break;
13835
14222
  default:
13836
14223
  {
@@ -13839,7 +14226,26 @@ static void ggml_compute_forward_conv_1d_s1_ph(
13839
14226
  }
13840
14227
  }
13841
14228
 
13842
- static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
14229
+ static void ggml_compute_forward_conv_1d_stage_1(
14230
+ const struct ggml_compute_params * params,
14231
+ const struct ggml_tensor * src0,
14232
+ const struct ggml_tensor * src1,
14233
+ struct ggml_tensor * dst) {
14234
+ switch(src0->type) {
14235
+ case GGML_TYPE_F16:
14236
+ {
14237
+ ggml_compute_forward_conv_1d_stage_1_f16(params, src0, src1, dst);
14238
+ } break;
14239
+ default:
14240
+ {
14241
+ GGML_ASSERT(false);
14242
+ } break;
14243
+ }
14244
+ }
14245
+
14246
+ // ggml_compute_forward_conv_transpose_1d
14247
+
14248
+ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
13843
14249
  const struct ggml_compute_params * params,
13844
14250
  const struct ggml_tensor * src0,
13845
14251
  const struct ggml_tensor * src1,
@@ -13856,43 +14262,38 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
13856
14262
  const int ith = params->ith;
13857
14263
  const int nth = params->nth;
13858
14264
 
13859
- const int nk = ne00;
13860
- const int nh = nk/2;
13861
-
13862
- const int ew0 = ggml_up32(ne01);
14265
+ const int nk = ne00*ne01*ne02;
13863
14266
 
13864
- GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
13865
14267
  GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
13866
14268
  GGML_ASSERT(nb10 == sizeof(float));
13867
14269
 
13868
14270
  if (params->type == GGML_TASK_INIT) {
13869
- // TODO: fix this memset (wsize is overestimated)
13870
14271
  memset(params->wdata, 0, params->wsize);
13871
14272
 
13872
- // prepare kernel data (src0)
14273
+ // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
13873
14274
  {
13874
14275
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
13875
14276
 
13876
14277
  for (int64_t i02 = 0; i02 < ne02; i02++) {
13877
14278
  for (int64_t i01 = 0; i01 < ne01; i01++) {
13878
14279
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
13879
- ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
14280
+ ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
13880
14281
  for (int64_t i00 = 0; i00 < ne00; i00++) {
13881
- dst_data[i00*ew0 + i01] = src[i00];
14282
+ dst_data[i00*ne02 + i02] = src[i00];
13882
14283
  }
13883
14284
  }
13884
14285
  }
13885
14286
  }
13886
14287
 
13887
- // prepare source data (src1)
14288
+ // permute source data (src1) from (L x Cin) to (Cin x L)
13888
14289
  {
13889
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
14290
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
14291
+ ggml_fp16_t * dst_data = wdata;
13890
14292
 
13891
14293
  for (int64_t i11 = 0; i11 < ne11; i11++) {
13892
14294
  const float * const src = (float *)((char *) src1->data + i11*nb11);
13893
- ggml_fp16_t * dst_data = wdata;
13894
14295
  for (int64_t i10 = 0; i10 < ne10; i10++) {
13895
- dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
14296
+ dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
13896
14297
  }
13897
14298
  }
13898
14299
  }
@@ -13904,8 +14305,10 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
13904
14305
  return;
13905
14306
  }
13906
14307
 
14308
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
14309
+
13907
14310
  // total rows in dst
13908
- const int nr = ne02;
14311
+ const int nr = ne1;
13909
14312
 
13910
14313
  // rows per thread
13911
14314
  const int dr = (nr + nth - 1)/nth;
@@ -13914,23 +14317,26 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
13914
14317
  const int ir0 = dr*ith;
13915
14318
  const int ir1 = MIN(ir0 + dr, nr);
13916
14319
 
14320
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
14321
+ ggml_fp16_t * const wdata_src = wdata + nk;
14322
+
13917
14323
  for (int i1 = ir0; i1 < ir1; i1++) {
13918
14324
  float * dst_data = (float *)((char *) dst->data + i1*nb1);
13919
- for (int64_t i0 = 0; i0 < ne10; i0 += 2) {
13920
- dst_data[i0/2] = 0;
13921
- for (int k = -nh; k <= nh; k++) {
13922
- float v = 0.0f;
13923
- ggml_vec_dot_f16(ew0, &v,
13924
- (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
13925
- (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
13926
-
13927
- dst_data[i0/2] += v;
14325
+ ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
14326
+ for (int i10 = 0; i10 < ne10; i10++) {
14327
+ const int i1n = i10*ne11;
14328
+ for (int i00 = 0; i00 < ne00; i00++) {
14329
+ float v = 0;
14330
+ ggml_vec_dot_f16(ne02, &v,
14331
+ (ggml_fp16_t *) wdata_src + i1n,
14332
+ (ggml_fp16_t *) wdata_kernel + i00*ne02);
14333
+ dst_data[i10*s0 + i00] += v;
13928
14334
  }
13929
14335
  }
13930
14336
  }
13931
14337
  }
13932
14338
 
13933
- static void ggml_compute_forward_conv_1d_s2_ph_f32(
14339
+ static void ggml_compute_forward_conv_transpose_1d_f32(
13934
14340
  const struct ggml_compute_params * params,
13935
14341
  const struct ggml_tensor * src0,
13936
14342
  const struct ggml_tensor * src1,
@@ -13947,29 +14353,24 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
13947
14353
  const int ith = params->ith;
13948
14354
  const int nth = params->nth;
13949
14355
 
13950
- const int nk = ne00;
13951
- const int nh = nk/2;
13952
-
13953
- const int ew0 = ggml_up32(ne01);
14356
+ const int nk = ne00*ne01*ne02;
13954
14357
 
13955
- GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
13956
14358
  GGML_ASSERT(nb00 == sizeof(float));
13957
14359
  GGML_ASSERT(nb10 == sizeof(float));
13958
14360
 
13959
14361
  if (params->type == GGML_TASK_INIT) {
13960
- // TODO: fix this memset (wsize is overestimated)
13961
14362
  memset(params->wdata, 0, params->wsize);
13962
14363
 
13963
- // prepare kernel data (src0)
14364
+ // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
13964
14365
  {
13965
14366
  float * const wdata = (float *) params->wdata + 0;
13966
14367
 
13967
14368
  for (int64_t i02 = 0; i02 < ne02; i02++) {
13968
14369
  for (int64_t i01 = 0; i01 < ne01; i01++) {
13969
14370
  const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
13970
- float * dst_data = wdata + i02*ew0*ne00;
14371
+ float * dst_data = wdata + i01*ne00*ne02;
13971
14372
  for (int64_t i00 = 0; i00 < ne00; i00++) {
13972
- dst_data[i00*ew0 + i01] = src[i00];
14373
+ dst_data[i01*ne00*ne02 + i00*ne02 + i02] = src[i00];
13973
14374
  }
13974
14375
  }
13975
14376
  }
@@ -13977,13 +14378,13 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
13977
14378
 
13978
14379
  // prepare source data (src1)
13979
14380
  {
13980
- float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
14381
+ float * const wdata = (float *) params->wdata + nk;
14382
+ float * dst_data = wdata;
13981
14383
 
13982
14384
  for (int64_t i11 = 0; i11 < ne11; i11++) {
13983
14385
  const float * const src = (float *)((char *) src1->data + i11*nb11);
13984
- float * dst_data = wdata;
13985
14386
  for (int64_t i10 = 0; i10 < ne10; i10++) {
13986
- dst_data[(i10 + nh)*ew0 + i11] = src[i10];
14387
+ dst_data[i10*ne11 + i11] = src[i10];
13987
14388
  }
13988
14389
  }
13989
14390
  }
@@ -13995,8 +14396,10 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
13995
14396
  return;
13996
14397
  }
13997
14398
 
14399
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
14400
+
13998
14401
  // total rows in dst
13999
- const int nr = ne02;
14402
+ const int nr = ne1;
14000
14403
 
14001
14404
  // rows per thread
14002
14405
  const int dr = (nr + nth - 1)/nth;
@@ -14005,23 +14408,26 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
14005
14408
  const int ir0 = dr*ith;
14006
14409
  const int ir1 = MIN(ir0 + dr, nr);
14007
14410
 
14411
+ float * const wdata = (float *) params->wdata + 0;
14412
+ float * const wdata_src = wdata + nk;
14413
+
14008
14414
  for (int i1 = ir0; i1 < ir1; i1++) {
14009
14415
  float * dst_data = (float *)((char *) dst->data + i1*nb1);
14010
- for (int64_t i0 = 0; i0 < ne10; i0 += 2) {
14011
- dst_data[i0/2] = 0;
14012
- for (int k = -nh; k <= nh; k++) {
14013
- float v = 0.0f;
14014
- ggml_vec_dot_f32(ew0, &v,
14015
- (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
14016
- (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
14017
-
14018
- dst_data[i0/2] += v;
14416
+ float * wdata_kernel = wdata + i1*ne02*ne00;
14417
+ for (int i10 = 0; i10 < ne10; i10++) {
14418
+ const int i1n = i10*ne11;
14419
+ for (int i00 = 0; i00 < ne00; i00++) {
14420
+ float v = 0;
14421
+ ggml_vec_dot_f32(ne02, &v,
14422
+ wdata_src + i1n,
14423
+ wdata_kernel + i00*ne02);
14424
+ dst_data[i10*s0 + i00] += v;
14019
14425
  }
14020
14426
  }
14021
14427
  }
14022
14428
  }
14023
14429
 
14024
- static void ggml_compute_forward_conv_1d_s2_ph(
14430
+ static void ggml_compute_forward_conv_transpose_1d(
14025
14431
  const struct ggml_compute_params * params,
14026
14432
  const struct ggml_tensor * src0,
14027
14433
  const struct ggml_tensor * src1,
@@ -14029,11 +14435,11 @@ static void ggml_compute_forward_conv_1d_s2_ph(
14029
14435
  switch (src0->type) {
14030
14436
  case GGML_TYPE_F16:
14031
14437
  {
14032
- ggml_compute_forward_conv_1d_s2_ph_f16_f32(params, src0, src1, dst);
14438
+ ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst);
14033
14439
  } break;
14034
14440
  case GGML_TYPE_F32:
14035
14441
  {
14036
- ggml_compute_forward_conv_1d_s2_ph_f32(params, src0, src1, dst);
14442
+ ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst);
14037
14443
  } break;
14038
14444
  default:
14039
14445
  {
@@ -14042,27 +14448,6 @@ static void ggml_compute_forward_conv_1d_s2_ph(
14042
14448
  }
14043
14449
  }
14044
14450
 
14045
- // ggml_compute_forward_conv_1d
14046
-
14047
- static void ggml_compute_forward_conv_1d(
14048
- const struct ggml_compute_params * params,
14049
- const struct ggml_tensor * src0,
14050
- const struct ggml_tensor * src1,
14051
- struct ggml_tensor * dst) {
14052
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
14053
- const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
14054
- const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
14055
- GGML_ASSERT(d0 == 1); // dilation not supported
14056
- GGML_ASSERT(p0 == src0->ne[0]/2); // only half padding supported
14057
- if (s0 == 1) {
14058
- ggml_compute_forward_conv_1d_s1_ph(params, src0, src1, dst);
14059
- } else if (s0 == 2) {
14060
- ggml_compute_forward_conv_1d_s2_ph(params, src0, src1, dst);
14061
- } else {
14062
- GGML_ASSERT(false); // only stride 1 and 2 supported
14063
- }
14064
- }
14065
-
14066
14451
  // ggml_compute_forward_conv_2d
14067
14452
 
14068
14453
  static void ggml_compute_forward_conv_2d_f16_f32(
@@ -14077,7 +14462,7 @@ static void ggml_compute_forward_conv_2d_f16_f32(
14077
14462
  int64_t t0 = ggml_perf_time_us();
14078
14463
  UNUSED(t0);
14079
14464
 
14080
- GGML_TENSOR_BINARY_OP_LOCALS
14465
+ GGML_TENSOR_BINARY_OP_LOCALS;
14081
14466
 
14082
14467
  const int ith = params->ith;
14083
14468
  const int nth = params->nth;
@@ -14105,20 +14490,22 @@ static void ggml_compute_forward_conv_2d_f16_f32(
14105
14490
  {
14106
14491
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
14107
14492
 
14108
- for (int i12 = 0; i12 < ne12; i12++) {
14109
- const float * const src = (float *)((char *) src1->data + i12*nb12);
14110
- ggml_fp16_t * dst_data = wdata;
14111
-
14112
- for (int i1 = 0; i1 < ne1; i1++) {
14113
- for (int i0 = 0; i0 < ne0; i0++) {
14114
- for (int ik1 = 0; ik1 < nk1; ik1++) {
14115
- for (int ik0 = 0; ik0 < nk0; ik0++) {
14116
- const int idx0 = i0*s0 + ik0*d0 - p0;
14117
- const int idx1 = i1*s1 + ik1*d1 - p1;
14118
-
14119
- if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) {
14120
- dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
14121
- GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]);
14493
+ for (int i13 = 0; i13 < ne13; i13++) {
14494
+ for (int i12 = 0; i12 < ne12; i12++) {
14495
+ const float * const src = (float *)((char *) src1->data + i13*nb13 + i12*nb12);
14496
+ ggml_fp16_t * dst_data = wdata + i13*(ne1*ne0*ew0);
14497
+
14498
+ for (int i1 = 0; i1 < ne1; i1++) {
14499
+ for (int i0 = 0; i0 < ne0; i0++) {
14500
+ for (int ik1 = 0; ik1 < nk1; ik1++) {
14501
+ for (int ik0 = 0; ik0 < nk0; ik0++) {
14502
+ const int idx0 = i0*s0 + ik0*d0 - p0;
14503
+ const int idx1 = i1*s1 + ik1*d1 - p1;
14504
+
14505
+ if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) {
14506
+ dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
14507
+ GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]);
14508
+ }
14122
14509
  }
14123
14510
  }
14124
14511
  }
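
The conv_2d init phase above gathers an im2col-style row of kernel taps for every output pixel, using idx0 = i0*s0 + ik0*d0 - p0 and idx1 = i1*s1 + ik1*d1 - p1 and skipping taps that fall outside the source. A small single-channel sketch of the same index math; sizes and names are made up, and the sketch zero-fills its buffer so skipped taps act as zero padding.

// im2col index math for one batch and one input channel
#include <stdio.h>
#include <string.h>

enum { IW = 5, IH = 5, KW = 3, KH = 3, S0 = 1, S1 = 1, P0 = 1, P1 = 1, D0 = 1, D1 = 1 };
enum { OW = (IW + 2*P0 - D0*(KW - 1) - 1)/S0 + 1,
       OH = (IH + 2*P1 - D1*(KH - 1) - 1)/S1 + 1 };

int main(void) {
    float src[IH][IW];
    float cols[OH*OW][KH*KW];   // one row of gathered taps per output pixel

    for (int y = 0; y < IH; y++)
        for (int x = 0; x < IW; x++)
            src[y][x] = (float)(y*IW + x);
    memset(cols, 0, sizeof(cols));  // skipped taps stay zero

    for (int i1 = 0; i1 < OH; i1++) {
        for (int i0 = 0; i0 < OW; i0++) {
            for (int ik1 = 0; ik1 < KH; ik1++) {
                for (int ik0 = 0; ik0 < KW; ik0++) {
                    const int idx0 = i0*S0 + ik0*D0 - P0;   // source x
                    const int idx1 = i1*S1 + ik1*D1 - P1;   // source y
                    if (!(idx1 < 0 || idx1 >= IH || idx0 < 0 || idx0 >= IW)) {
                        cols[i1*OW + i0][ik1*KW + ik0] = src[idx1][idx0];
                    }
                }
            }
        }
    }

    // each row of `cols` can now be dotted against a flattened 3x3 kernel
    printf("taps for output (0,0): %g %g %g ...\n", cols[0][0], cols[0][1], cols[0][2]);
    return 0;
}
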
@@ -16401,6 +16788,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
16401
16788
  {
16402
16789
  ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor);
16403
16790
  } break;
16791
+ case GGML_OP_CONV_1D_STAGE_0:
16792
+ {
16793
+ ggml_compute_forward_conv_1d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
16794
+ } break;
16795
+ case GGML_OP_CONV_1D_STAGE_1:
16796
+ {
16797
+ ggml_compute_forward_conv_1d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
16798
+ } break;
16799
+ case GGML_OP_CONV_TRANSPOSE_1D:
16800
+ {
16801
+ ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor);
16802
+ } break;
16404
16803
  case GGML_OP_CONV_2D:
16405
16804
  {
16406
16805
  ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor);
@@ -17326,10 +17725,22 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
17326
17725
  {
17327
17726
  GGML_ASSERT(false); // TODO: not implemented
17328
17727
  } break;
17728
+ case GGML_OP_CONV_1D_STAGE_0:
17729
+ {
17730
+ GGML_ASSERT(false); // TODO: not implemented
17731
+ } break;
17732
+ case GGML_OP_CONV_1D_STAGE_1:
17733
+ {
17734
+ GGML_ASSERT(false); // TODO: not implemented
17735
+ } break;
17329
17736
  case GGML_OP_CONV_2D:
17330
17737
  {
17331
17738
  GGML_ASSERT(false); // TODO: not implemented
17332
17739
  } break;
17740
+ case GGML_OP_CONV_TRANSPOSE_1D:
17741
+ {
17742
+ GGML_ASSERT(false); // TODO: not implemented
17743
+ } break;
17333
17744
  case GGML_OP_CONV_TRANSPOSE_2D:
17334
17745
  {
17335
17746
  GGML_ASSERT(false); // TODO: not implemented
@@ -18171,21 +18582,68 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
18171
18582
  GGML_ASSERT(node->src[1]->ne[2] == 1);
18172
18583
  GGML_ASSERT(node->src[1]->ne[3] == 1);
18173
18584
 
18585
+ const int64_t ne00 = node->src[0]->ne[0];
18586
+ const int64_t ne01 = node->src[0]->ne[1];
18587
+ const int64_t ne02 = node->src[0]->ne[2];
18588
+
18589
+ const int64_t ne10 = node->src[1]->ne[0];
18590
+ const int64_t ne11 = node->src[1]->ne[1];
18591
+
18592
+ const int64_t ne0 = node->ne[0];
18593
+ const int64_t ne1 = node->ne[1];
18594
+ const int64_t nk = ne00;
18595
+ const int64_t ew0 = nk * ne01;
18596
+
18597
+ UNUSED(ne02);
18598
+ UNUSED(ne10);
18599
+ UNUSED(ne11);
18600
+
18174
18601
  size_t cur = 0;
18175
- const int nk = node->src[0]->ne[0];
18176
18602
 
18177
18603
  if (node->src[0]->type == GGML_TYPE_F16 &&
18178
- node->src[1]->type == GGML_TYPE_F32) {
18179
- cur = sizeof(ggml_fp16_t)*(
18180
- nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] +
18181
- ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
18182
- );
18604
+ node->src[1]->type == GGML_TYPE_F32) {
18605
+ cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0);
18606
+ } else if (node->src[0]->type == GGML_TYPE_F32 &&
18607
+ node->src[1]->type == GGML_TYPE_F32) {
18608
+ cur = sizeof(float)*(ne0*ne1*ew0);
18609
+ } else {
18610
+ GGML_ASSERT(false);
18611
+ }
18612
+
18613
+ work_size = MAX(work_size, cur);
18614
+ } break;
18615
+ case GGML_OP_CONV_1D_STAGE_0:
18616
+ {
18617
+ n_tasks = n_threads;
18618
+ } break;
18619
+ case GGML_OP_CONV_1D_STAGE_1:
18620
+ {
18621
+ n_tasks = n_threads;
18622
+ } break;
18623
+ case GGML_OP_CONV_TRANSPOSE_1D:
18624
+ {
18625
+ n_tasks = n_threads;
18626
+
18627
+ GGML_ASSERT(node->src[0]->ne[3] == 1);
18628
+ GGML_ASSERT(node->src[1]->ne[2] == 1);
18629
+ GGML_ASSERT(node->src[1]->ne[3] == 1);
18630
+
18631
+ const int64_t ne00 = node->src[0]->ne[0]; // K
18632
+ const int64_t ne01 = node->src[0]->ne[1]; // Cout
18633
+ const int64_t ne02 = node->src[0]->ne[2]; // Cin
18634
+
18635
+ const int64_t ne10 = node->src[1]->ne[0]; // L
18636
+ const int64_t ne11 = node->src[1]->ne[1]; // Cin
18637
+
18638
+ size_t cur = 0;
18639
+ if (node->src[0]->type == GGML_TYPE_F16 &&
18640
+ node->src[1]->type == GGML_TYPE_F32) {
18641
+ cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
18642
+ cur += sizeof(ggml_fp16_t)*ne10*ne11;
18183
18643
  } else if (node->src[0]->type == GGML_TYPE_F32 &&
18184
- node->src[1]->type == GGML_TYPE_F32) {
18185
- cur = sizeof(float)*(
18186
- nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] +
18187
- ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
18188
- );
18644
+ node->src[1]->type == GGML_TYPE_F32) {
18645
+ cur += sizeof(float)*ne00*ne01*ne02;
18646
+ cur += sizeof(float)*ne10*ne11;
18189
18647
  } else {
18190
18648
  GGML_ASSERT(false);
18191
18649
  }
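
The graph-planning hunk above replaces the previous estimate with exact scratch sizes: ne0*ne1*ew0 elements for GGML_OP_CONV_1D, with ew0 = nk*ne01, and the repacked kernel plus the repacked source for GGML_OP_CONV_TRANSPOSE_1D. A sketch of the same arithmetic for the F32 case; the helper functions and the toy dimensions are hypothetical, not part of the ggml API.

// scratch sizing for the two 1-D convolution ops, F32 case
#include <stdio.h>
#include <stdint.h>
#include <stddef.h>

static size_t conv_1d_work_size_f32(int64_t ne0, int64_t ne1, int64_t nk, int64_t ne01) {
    const int64_t ew0 = nk * ne01;                      // as computed in the hunk above
    return sizeof(float) * (size_t)(ne0 * ne1 * ew0);
}

static size_t conv_transpose_1d_work_size_f32(
        int64_t ne00 /* K */, int64_t ne01 /* Cout */, int64_t ne02 /* Cin */,
        int64_t ne10 /* L */, int64_t ne11 /* Cin  */) {
    size_t cur = 0;
    cur += sizeof(float) * (size_t)(ne00 * ne01 * ne02); // kernel, (Cin x K x Cout)
    cur += sizeof(float) * (size_t)(ne10 * ne11);        // source, channel-fastest
    return cur;
}

int main(void) {
    printf("conv_1d:           %zu bytes\n", conv_1d_work_size_f32(16, 8, 3, 4));
    printf("conv_transpose_1d: %zu bytes\n", conv_transpose_1d_work_size_f32(3, 2, 4, 16, 4));
    return 0;
}
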
@@ -19311,7 +19769,7 @@ static enum ggml_opt_result ggml_opt_adam(
19311
19769
  if (callback) {
19312
19770
  callback(callback_data, accum_step, &sched, &cancel);
19313
19771
  if (cancel) {
19314
- break;
19772
+ return GGML_OPT_CANCEL;
19315
19773
  }
19316
19774
  }
19317
19775
  // ggml_graph_reset (gf);
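
With the changes above, a callback that raises the cancel flag now makes the optimizer return GGML_OPT_CANCEL immediately instead of breaking out of the loop and continuing. The sketch below shows the shape of such a callback, matching the (data, accum_step, &sched, &cancel) call sites in this diff; the struct, the threshold, and the driver loop are purely illustrative.

// user-side cancellation callback, illustrative only
#include <stdbool.h>
#include <stdio.h>

struct my_opt_state {
    int calls;
    int max_calls;   // stop after this many callback invocations
};

static void my_opt_callback(void * data, int accum_step, float * sched, bool * cancel) {
    struct my_opt_state * st = (struct my_opt_state *) data;
    (void) accum_step;

    st->calls++;
    *sched  = 1.0f;                          // keep the full learning-rate schedule
    *cancel = (st->calls >= st->max_calls);  // request early exit

    if (*cancel) {
        printf("requesting cancel after %d callback calls\n", st->calls);
    }
}

int main(void) {
    struct my_opt_state st = { 0, 3 };
    bool  cancel = false;
    float sched  = 0.0f;
    // stand-in for the optimizer's callback invocations
    for (int step = 0; !cancel; step++) {
        my_opt_callback(&st, step, &sched, &cancel);
    }
    return 0;
}
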
@@ -19320,9 +19778,6 @@ static enum ggml_opt_result ggml_opt_adam(
19320
19778
  ggml_opt_acc_grad(np, ps, g, accum_norm);
19321
19779
  fx += ggml_get_f32_1d(f, 0);
19322
19780
  }
19323
- if (cancel) {
19324
- return GGML_OPT_DID_NOT_CONVERGE;
19325
- }
19326
19781
  fx *= accum_norm;
19327
19782
 
19328
19783
  opt->adam.fx_prev = fx;
@@ -19348,9 +19803,6 @@ static enum ggml_opt_result ggml_opt_adam(
19348
19803
 
19349
19804
  // run the optimizer
19350
19805
  for (int t = 0; t < params.adam.n_iter; ++t) {
19351
- if (cancel) {
19352
- break;
19353
- }
19354
19806
  opt->iter = iter0 + t + 1;
19355
19807
  GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
19356
19808
 
@@ -19408,7 +19860,7 @@ static enum ggml_opt_result ggml_opt_adam(
19408
19860
  if (callback) {
19409
19861
  callback(callback_data, accum_step, &sched, &cancel);
19410
19862
  if (cancel) {
19411
- break;
19863
+ return GGML_OPT_CANCEL;
19412
19864
  }
19413
19865
  }
19414
19866
  // ggml_graph_reset (gf);
@@ -19417,9 +19869,6 @@ static enum ggml_opt_result ggml_opt_adam(
19417
19869
  ggml_opt_acc_grad(np, ps, g, accum_norm);
19418
19870
  fx += ggml_get_f32_1d(f, 0);
19419
19871
  }
19420
- if (cancel) {
19421
- break;
19422
- }
19423
19872
  fx *= accum_norm;
19424
19873
 
19425
19874
  opt->loss_after = fx;
@@ -19538,7 +19987,7 @@ static enum ggml_opt_result linesearch_backtracking(
19538
19987
  finit = *fx;
19539
19988
  dgtest = params->lbfgs.ftol*dginit;
19540
19989
 
19541
- while (!*cancel) {
19990
+ while (true) {
19542
19991
  ggml_vec_cpy_f32(nx, x, xp);
19543
19992
  ggml_vec_mad_f32(nx, x, d, *step);
19544
19993
 
@@ -19554,7 +20003,7 @@ static enum ggml_opt_result linesearch_backtracking(
19554
20003
  float sched = 0;
19555
20004
  callback(callback_data, accum_step, &sched, cancel);
19556
20005
  if (*cancel) {
19557
- break;
20006
+ return GGML_OPT_CANCEL;
19558
20007
  }
19559
20008
  }
19560
20009
  // ggml_graph_reset (gf);
@@ -19563,9 +20012,6 @@ static enum ggml_opt_result linesearch_backtracking(
19563
20012
  ggml_opt_acc_grad(np, ps, g, accum_norm);
19564
20013
  *fx += ggml_get_f32_1d(f, 0);
19565
20014
  }
19566
- if (*cancel) {
19567
- break;
19568
- }
19569
20015
  *fx *= accum_norm;
19570
20016
 
19571
20017
  }
@@ -19698,7 +20144,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
19698
20144
  float sched = 0;
19699
20145
  callback(callback_data, accum_step, &sched, &cancel);
19700
20146
  if (cancel) {
19701
- break;
20147
+ return GGML_OPT_CANCEL;
19702
20148
  }
19703
20149
  }
19704
20150
  // ggml_graph_reset (gf);
@@ -19707,9 +20153,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
19707
20153
  ggml_opt_acc_grad(np, ps, g, accum_norm);
19708
20154
  fx += ggml_get_f32_1d(f, 0);
19709
20155
  }
19710
- if (cancel) {
19711
- return GGML_OPT_DID_NOT_CONVERGE;
19712
- }
19713
20156
  fx *= accum_norm;
19714
20157
 
19715
20158
  opt->loss_before = fx;
@@ -19768,9 +20211,13 @@ static enum ggml_opt_result ggml_opt_lbfgs(
19768
20211
  ggml_vec_cpy_f32(nx, xp, x);
19769
20212
  ggml_vec_cpy_f32(nx, gp, g);
19770
20213
 
20214
+ // TODO: instead of passing &cancel here, use the return code of the linesearch
20215
+ // to determine if the optimization should be cancelled
20216
+ // this is a simple change, but not doing this atm, since I don't have a nice
20217
+ // way to test and don't want to break something with so many changes lined up
19771
20218
  ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data);
19772
- if (!cancel) {
19773
- break;
20219
+ if (cancel) {
20220
+ return GGML_OPT_CANCEL;
19774
20221
  }
19775
20222
 
19776
20223
  if (ls < 0) {