llama_cpp 0.6.0 → 0.7.0 (diff of the vendored ggml.c)

@@ -1032,8 +1032,8 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r
1032
1032
  y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
1033
1033
 
1034
1034
  // get the 5-th bit and store it in qh at the right position
1035
- qh |= ((xi0 & 0x10) >> 4) << (j + 0);
1036
- qh |= ((xi1 & 0x10) >> 4) << (j + qk/2);
1035
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
1036
+ qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
1037
1037
  }
1038
1038
 
1039
1039
  memcpy(&y[i].qh, &qh, sizeof(qh));
@@ -1080,8 +1080,8 @@ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * r
1080
1080
  y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
1081
1081
 
1082
1082
  // get the 5-th bit and store it in qh at the right position
1083
- qh |= ((xi0 & 0x10) >> 4) << (j + 0);
1084
- qh |= ((xi1 & 0x10) >> 4) << (j + qk/2);
1083
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
1084
+ qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
1085
1085
  }
1086
1086
 
1087
1087
  memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
@@ -1272,6 +1272,33 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1272
1272
  _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
1273
1273
  #endif
1274
1274
  }
1275
+ #elif defined(__riscv_v_intrinsic)
1276
+
1277
+ size_t vl = __riscv_vsetvl_e32m4(QK8_0);
1278
+
1279
+ for (int i = 0; i < nb; i++) {
1280
+ // load elements
1281
+ vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
1282
+
1283
+ vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
1284
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
1285
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
1286
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
1287
+
1288
+ const float d = amax / ((1 << 7) - 1);
1289
+ const float id = d ? 1.0f/d : 0.0f;
1290
+
1291
+ y[i].d = GGML_FP32_TO_FP16(d);
1292
+
1293
+ vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
1294
+
1295
+ // convert to integer
1296
+ vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
1297
+ vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
1298
+
1299
+ // store result
1300
+ __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
1301
+ }
1275
1302
  #else
1276
1303
  // scalar
1277
1304
  quantize_row_q8_0_reference(x, y, k);
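
For context on what the new `__riscv_v_intrinsic` branch above computes: it is the same absmax-to-int8 quantization as the scalar fallback (`quantize_row_q8_0_reference`) called just below, only vectorized with RVV intrinsics. A minimal scalar sketch, using an illustrative `block_q8_0_demo` struct rather than the real ggml type (which stores `d` as fp16):

```c
#include <math.h>
#include <stdint.h>

#define QK8_0 32

// Simplified stand-in for ggml's block_q8_0 (the real block stores d as fp16).
typedef struct {
    float  d;           // scale
    int8_t qs[QK8_0];   // quantized values
} block_q8_0_demo;

// Absmax quantization of one block of 32 floats to int8, mirroring the reference path.
static void quantize_block_q8_0_demo(const float * x, block_q8_0_demo * y) {
    float amax = 0.0f;                       // absolute max over the block
    for (int j = 0; j < QK8_0; j++) {
        amax = fmaxf(amax, fabsf(x[j]));
    }

    const float d  = amax / ((1 << 7) - 1);  // map amax to 127
    const float id = d ? 1.0f/d : 0.0f;      // avoid division by zero

    y->d = d;
    for (int j = 0; j < QK8_0; j++) {
        y->qs[j] = (int8_t) roundf(x[j]*id); // the vector path rounds via vfncvt
    }
}
```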
@@ -1490,6 +1517,41 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
1490
1517
  _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
1491
1518
  #endif
1492
1519
  }
1520
+ #elif defined(__riscv_v_intrinsic)
1521
+
1522
+ size_t vl = __riscv_vsetvl_e32m4(QK8_1);
1523
+
1524
+ for (int i = 0; i < nb; i++) {
1525
+ // load elements
1526
+ vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
1527
+
1528
+ vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
1529
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl);
1530
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
1531
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
1532
+
1533
+ const float d = amax / ((1 << 7) - 1);
1534
+ const float id = d ? 1.0f/d : 0.0f;
1535
+
1536
+ y[i].d = d;
1537
+
1538
+ vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
1539
+
1540
+ // convert to integer
1541
+ vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
1542
+ vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
1543
+
1544
+ // store result
1545
+ __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
1546
+
1547
+ // compute sum for y[i].s
1548
+ vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
1549
+ vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
1550
+
1551
+ // set y[i].s
1552
+ int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
1553
+ y[i].s = sum*d;
1554
+ }
1493
1555
  #else
1494
1556
  // scalar
1495
1557
  quantize_row_q8_1_reference(x, y, k);
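
The `q8_1` variant differs from `q8_0` only in that each block also stores `s = d * sum(qs)`, which the vector path above computes with the widening reduction `vwredsum` before writing `y[i].s = sum*d`. A scalar sketch of that extra step, again with an illustrative `_demo` struct rather than the real ggml layout:

```c
#include <stdint.h>

#define QK8_1 32

// Illustrative stand-in for ggml's block_q8_1 (field layout simplified).
typedef struct {
    float  d;           // scale
    float  s;           // d * sum(qs), used later to fold in the q4_1/q5_1 min term
    int8_t qs[QK8_1];
} block_q8_1_demo;

// Given already-quantized values, compute the stored sum exactly as the RVV path does:
// widen int8 to int, reduce, then scale by d.
static void set_q8_1_sum_demo(block_q8_1_demo * y) {
    int sum = 0;
    for (int j = 0; j < QK8_1; j++) {
        sum += y->qs[j];
    }
    y->s = sum * y->d;   // matches "y[i].s = sum*d" in the diff
}
```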
@@ -2662,30 +2724,32 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2662
2724
  size_t vl = __riscv_vsetvl_e8m1(qk/2);
2663
2725
 
2664
2726
  for (int i = 0; i < nb; i++) {
2665
- vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
2727
+ // load elements
2728
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
2666
2729
 
2667
- vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
2668
- vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
2730
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
2731
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
2669
2732
 
2670
- vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
2671
- vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
2733
+ // mask and store lower part of x, and then upper part
2734
+ vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
2735
+ vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
2672
2736
 
2673
- vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
2674
- vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
2737
+ vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
2738
+ vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
2675
2739
 
2676
- vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl);
2677
- vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl);
2740
+ // subtract offset
2741
+ vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
2742
+ vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
2678
2743
 
2679
- vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
2680
- vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
2744
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
2745
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
2681
2746
 
2682
2747
  vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
2683
2748
 
2684
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
2685
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
2749
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
2750
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
2686
2751
 
2687
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
2688
- sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
2752
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
2689
2753
 
2690
2754
  sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
2691
2755
  }
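
The change in this hunk is that the second widening reduction is now seeded with `vs1` instead of zero, so the two half-block sums are chained and only one scalar move (`vmv_x_s`) is needed; the narrower `mf2`/`m1` register groups follow from the fixed `vl = qk/2` element count. The integer quantity being reduced is the same per-block dot product as in the scalar path, roughly (illustrative helper, not a ggml function):

```c
#include <stdint.h>

#define QK4_0 32

// One q4_0 block packs 32 4-bit values into 16 bytes: low nibbles hold the first half,
// high nibbles the second half. Scalar equivalent of the integer part of the dot product.
static int q4_0_q8_0_block_dot_demo(const uint8_t * xqs, const int8_t * yqs) {
    int sumi = 0;
    for (int j = 0; j < QK4_0/2; j++) {
        const int v0 = (xqs[j] & 0x0F) - 8;   // lower nibble, offset removed
        const int v1 = (xqs[j] >>   4) - 8;   // upper nibble, offset removed

        sumi += v0*yqs[j] + v1*yqs[j + QK4_0/2];
    }
    return sumi;  // caller scales by d_x*d_y, as in "sumf += sumi*...d...d"
}
```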
@@ -2823,27 +2887,28 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
2823
2887
  size_t vl = __riscv_vsetvl_e8m1(qk/2);
2824
2888
 
2825
2889
  for (int i = 0; i < nb; i++) {
2826
- vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
2890
+ // load elements
2891
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
2827
2892
 
2828
- vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
2829
- vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
2893
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
2894
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
2830
2895
 
2831
- vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
2832
- vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
2896
+ // mask and store lower part of x, and then upper part
2897
+ vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
2898
+ vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
2833
2899
 
2834
- vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
2835
- vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
2900
+ vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
2901
+ vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
2836
2902
 
2837
- vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
2838
- vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
2903
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
2904
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
2839
2905
 
2840
2906
  vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
2841
2907
 
2842
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
2843
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
2908
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
2909
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
2844
2910
 
2845
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
2846
- sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
2911
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
2847
2912
 
2848
2913
  sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
2849
2914
  }
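
The final accumulation line also shows why `quantize_row_q8_1` stores `y[i].s`: a q4_1 weight dequantizes to `d_x*q + m`, so the block dot product expands to `d_x*d_y*sum(q*p) + m*(d_y*sum(p)) = (d_x*d_y)*sumi + m*s`, exactly the expression added to `sumf` above. A small self-contained check of that identity (names are illustrative, not ggml API):

```c
#include <assert.h>
#include <math.h>

// Computes the dequantized dot product two ways and checks they agree:
// directly, and via the folded form (d_x*d_y)*sumi + m*s used by the kernel.
static float q4_1_block_dot_two_ways_demo(const int * q, const int * p,
                                           float d_x, float m, float d_y, int n) {
    float direct = 0.0f;  // dot product of dequantized values
    int   sumi   = 0;     // integer dot product
    int   sump   = 0;     // sum of the q8_1 quants
    for (int j = 0; j < n; j++) {
        direct += (d_x*q[j] + m) * (d_y*p[j]);
        sumi   += q[j]*p[j];
        sump   += p[j];
    }
    const float s      = d_y*sump;                  // what quantize_row_q8_1 stores as y[i].s
    const float folded = (d_x*d_y)*sumi + m*s;      // what the dot kernel accumulates
    assert(fabsf(direct - folded) <= 1e-3f*fmaxf(1.0f, fabsf(direct)));
    return folded;
}
```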
@@ -3088,66 +3153,61 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
3088
3153
 
3089
3154
  uint32_t qh;
3090
3155
 
3091
- // These temp values are for masking and shift operations
3092
- uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
3093
- uint32_t temp_2[16] = {0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80,
3094
- 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000};
3095
-
3096
3156
  size_t vl = __riscv_vsetvl_e8m1(qk/2);
3097
3157
 
3158
+ // These tempory registers are for masking and shift operations
3159
+ vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
3160
+ vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
3161
+
3162
+ vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
3163
+ vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
3164
+
3098
3165
  for (int i = 0; i < nb; i++) {
3099
3166
  memcpy(&qh, x[i].qh, sizeof(uint32_t));
3100
3167
 
3101
- // temporary registers
3102
- vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_2, vl);
3103
- vuint32m4_t vt_2 = __riscv_vle32_v_u32m4(temp_1, vl);
3104
- vuint32m4_t vt_3 = __riscv_vsll_vx_u32m4(vt_1, 16, vl);
3105
- vuint32m4_t vt_4 = __riscv_vadd_vx_u32m4(vt_2, 12, vl);
3106
-
3107
3168
  // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
3108
- vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(vt_1, qh, vl);
3109
- vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(xha_0, vt_2, vl);
3110
- vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl);
3169
+ vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
3170
+ vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
3171
+ vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
3111
3172
 
3112
3173
  // ((qh & (1u << (j + 16))) >> (j + 12));
3113
- vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(vt_3, qh, vl);
3114
- vuint32m4_t xhl_1 = __riscv_vsrl_vv_u32m4(xha_1, vt_4, vl);
3174
+ vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
3175
+ vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
3115
3176
 
3116
3177
  // narrowing
3117
- vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xhl_0, vl);
3118
- vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl);
3178
+ vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
3179
+ vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
3119
3180
 
3120
- vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xhl_1, vl);
3121
- vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl);
3181
+ vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
3182
+ vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
3122
3183
 
3123
3184
  // load
3124
- vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
3185
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
3125
3186
 
3126
- vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
3127
- vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
3187
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
3188
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
3128
3189
 
3129
- vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
3130
- vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
3190
+ vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
3191
+ vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
3131
3192
 
3132
- vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl);
3133
- vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl);
3193
+ vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
3194
+ vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
3134
3195
 
3135
- vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
3136
- vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
3196
+ vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
3197
+ vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
3137
3198
 
3138
- vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 16, vl);
3139
- vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 16, vl);
3199
+ vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
3200
+ vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
3140
3201
 
3141
- vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
3142
- vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
3202
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
3203
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
3143
3204
 
3144
3205
  vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
3145
3206
 
3146
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
3147
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
3207
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
3208
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
3148
3209
 
3149
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
3150
- sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
3210
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
3151
3211
 
3152
3212
  sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
3153
3213
  }
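
Here the two constant tables (`temp_1`, `temp_2`) are replaced by index and mask vectors generated on the fly with `vid`/`vsll`, and the register groups shrink from `m4`/`m1` to `m2`/`mf2`. The scalar bit manipulation being vectorized, which the inline comments quote, reassembles each 5-bit value from its nibble in `qs` and its high bit in `qh`, then subtracts the q5_0 offset of 16; a sketch (illustrative helper, not a ggml function):

```c
#include <stdint.h>

#define QK5_0 32

// Scalar reassembly of the 5-bit q5_0 values that the vector code parallelizes:
// the 5th bit of each value lives in the 32-bit qh word.
static void q5_0_expand_demo(const uint8_t * qs, uint32_t qh, int8_t * out /* QK5_0 values */) {
    for (int j = 0; j < QK5_0/2; j++) {
        const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; // bit j    -> bit 4
        const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));      // bit j+16 -> bit 4

        out[j]           = (int8_t)(((qs[j] & 0x0F) | xh_0) - 16); // lower nibble + high bit
        out[j + QK5_0/2] = (int8_t)(((qs[j] >>   4) | xh_1) - 16); // upper nibble + high bit
    }
}
```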
@@ -3414,62 +3474,58 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
3414
3474
 
3415
3475
  uint32_t qh;
3416
3476
 
3417
- // These temp values are for shift operations
3418
- uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
3419
-
3420
3477
  size_t vl = __riscv_vsetvl_e8m1(qk/2);
3421
3478
 
3479
+ // temporary registers for shift operations
3480
+ vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
3481
+ vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
3482
+
3422
3483
  for (int i = 0; i < nb; i++) {
3423
3484
  memcpy(&qh, x[i].qh, sizeof(uint32_t));
3424
3485
 
3425
- // temporary registers
3426
- vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_1, vl);
3427
- vuint32m4_t vt_2 = __riscv_vadd_vx_u32m4(vt_1, 12, vl);
3428
-
3429
3486
  // load qh
3430
- vuint32m4_t vqh = __riscv_vmv_v_x_u32m4(qh, vl);
3487
+ vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
3431
3488
 
3432
3489
  // ((qh >> (j + 0)) << 4) & 0x10;
3433
- vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(vqh, vt_1, vl);
3434
- vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl);
3435
- vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(xhl_0, 0x10, vl);
3490
+ vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
3491
+ vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
3492
+ vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
3436
3493
 
3437
3494
  // ((qh >> (j + 12)) ) & 0x10;
3438
- vuint32m4_t xhr_1 = __riscv_vsrl_vv_u32m4(vqh, vt_2, vl);
3439
- vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(xhr_1, 0x10, vl);
3495
+ vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
3496
+ vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
3440
3497
 
3441
3498
  // narrowing
3442
- vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xha_0, vl);
3443
- vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl);
3499
+ vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
3500
+ vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
3444
3501
 
3445
- vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xha_1, vl);
3446
- vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl);
3502
+ vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
3503
+ vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
3447
3504
 
3448
3505
  // load
3449
- vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
3506
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
3450
3507
 
3451
- vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
3452
- vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
3508
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
3509
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
3453
3510
 
3454
- vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
3455
- vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
3511
+ vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
3512
+ vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
3456
3513
 
3457
- vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl);
3458
- vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl);
3514
+ vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
3515
+ vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
3459
3516
 
3460
- vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
3461
- vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
3517
+ vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
3518
+ vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
3462
3519
 
3463
- vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
3464
- vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
3520
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
3521
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
3465
3522
 
3466
3523
  vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
3467
3524
 
3468
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
3469
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
3525
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
3526
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
3470
3527
 
3471
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
3472
- sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
3528
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
3473
3529
 
3474
3530
  sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
3475
3531
  }
@@ -4025,12 +4081,16 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
4025
4081
  "ALIBI",
4026
4082
  "CLAMP",
4027
4083
  "CONV_1D",
4084
+ "CONV_TRANSPOSE_1D",
4028
4085
  "CONV_2D",
4029
4086
  "CONV_TRANSPOSE_2D",
4030
4087
  "POOL_1D",
4031
4088
  "POOL_2D",
4032
4089
  "UPSCALE",
4033
4090
 
4091
+ "CONV_1D_STAGE_0",
4092
+ "CONV_1D_STAGE_1",
4093
+
4034
4094
  "FLASH_ATTN",
4035
4095
  "FLASH_FF",
4036
4096
  "FLASH_ATTN_BACK",
@@ -4056,7 +4116,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
4056
4116
  "CROSS_ENTROPY_LOSS_BACK",
4057
4117
  };
4058
4118
 
4059
- static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
4119
+ static_assert(GGML_OP_COUNT == 71, "GGML_OP_COUNT != 71");
4060
4120
 
4061
4121
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
4062
4122
  "none",
@@ -4107,12 +4167,16 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
4107
4167
  "alibi(x)",
4108
4168
  "clamp(x)",
4109
4169
  "conv_1d(x)",
4170
+ "conv_transpose_1d(x)",
4110
4171
  "conv_2d(x)",
4111
4172
  "conv_transpose_2d(x)",
4112
4173
  "pool_1d(x)",
4113
4174
  "pool_2d(x)",
4114
4175
  "upscale(x)",
4115
4176
 
4177
+ "conv_1d_stage_0(x)",
4178
+ "conv_1d_stage_1(x)",
4179
+
4116
4180
  "flash_attn(x)",
4117
4181
  "flash_ff(x)",
4118
4182
  "flash_attn_back(x)",
@@ -4138,7 +4202,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
4138
4202
  "cross_entropy_loss_back(x,y)",
4139
4203
  };
4140
4204
 
4141
- static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
4205
+ static_assert(GGML_OP_COUNT == 71, "GGML_OP_COUNT != 71");
4142
4206
 
4143
4207
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
4144
4208
 
@@ -4167,7 +4231,10 @@ static void ggml_setup_op_has_task_pass(void) {
4167
4231
  p[GGML_OP_DIAG_MASK_INF ] = true;
4168
4232
  p[GGML_OP_DIAG_MASK_ZERO ] = true;
4169
4233
  p[GGML_OP_CONV_1D ] = true;
4234
+ p[GGML_OP_CONV_1D_STAGE_0 ] = true;
4235
+ p[GGML_OP_CONV_1D_STAGE_1 ] = true;
4170
4236
  p[GGML_OP_CONV_2D ] = true;
4237
+ p[GGML_OP_CONV_TRANSPOSE_1D ] = true;
4171
4238
  p[GGML_OP_CONV_TRANSPOSE_2D ] = true;
4172
4239
  p[GGML_OP_FLASH_ATTN_BACK ] = true;
4173
4240
  p[GGML_OP_CROSS_ENTROPY_LOSS ] = true;
@@ -6690,7 +6757,6 @@ struct ggml_tensor * ggml_cont_4d(
6690
6757
  return result;
6691
6758
  }
6692
6759
 
6693
-
6694
6760
  // ggml_reshape
6695
6761
 
6696
6762
  struct ggml_tensor * ggml_reshape(
@@ -7448,14 +7514,17 @@ static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p,
7448
7514
  return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
7449
7515
  }
7450
7516
 
7451
- GGML_API struct ggml_tensor * ggml_conv_1d(
7452
- struct ggml_context * ctx,
7453
- struct ggml_tensor * a,
7454
- struct ggml_tensor * b,
7455
- int s0,
7456
- int p0,
7457
- int d0) {
7458
- GGML_ASSERT(ggml_is_matrix(b));
7517
+ // im2col: [N, IC, IL] => [N, OL, IC*K]
7518
+ // a: [OC,IC, K]
7519
+ // b: [N, IC, IL]
7520
+ // result: [N, OL, IC*K]
7521
+ static struct ggml_tensor * ggml_conv_1d_stage_0(
7522
+ struct ggml_context * ctx,
7523
+ struct ggml_tensor * a,
7524
+ struct ggml_tensor * b,
7525
+ int s0,
7526
+ int p0,
7527
+ int d0) {
7459
7528
  GGML_ASSERT(a->ne[1] == b->ne[1]);
7460
7529
  bool is_node = false;
7461
7530
 
@@ -7464,16 +7533,54 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
7464
7533
  is_node = true;
7465
7534
  }
7466
7535
 
7536
+ const int64_t OL = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
7537
+
7467
7538
  const int64_t ne[4] = {
7468
- ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
7469
- a->ne[2], 1, 1,
7539
+ a->ne[1] * a->ne[0],
7540
+ OL,
7541
+ b->ne[2],
7542
+ 1,
7470
7543
  };
7471
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
7544
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
7472
7545
 
7473
7546
  int32_t params[] = { s0, p0, d0 };
7474
7547
  ggml_set_op_params(result, params, sizeof(params));
7475
7548
 
7476
- result->op = GGML_OP_CONV_1D;
7549
+ result->op = GGML_OP_CONV_1D_STAGE_0;
7550
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7551
+ result->src[0] = a;
7552
+ result->src[1] = b;
7553
+
7554
+ return result;
7555
+ }
7556
+
7557
+ // ggml_conv_1d_stage_1
7558
+
7559
+ // gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
7560
+ // a: [OC, IC, K]
7561
+ // b: [N, OL, IC * K]
7562
+ // result: [N, OC, OL]
7563
+ static struct ggml_tensor * ggml_conv_1d_stage_1(
7564
+ struct ggml_context * ctx,
7565
+ struct ggml_tensor * a,
7566
+ struct ggml_tensor * b) {
7567
+
7568
+ bool is_node = false;
7569
+
7570
+ if (a->grad || b->grad) {
7571
+ GGML_ASSERT(false); // TODO: implement backward
7572
+ is_node = true;
7573
+ }
7574
+
7575
+ const int64_t ne[4] = {
7576
+ b->ne[1],
7577
+ a->ne[2],
7578
+ b->ne[2],
7579
+ 1,
7580
+ };
7581
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
7582
+
7583
+ result->op = GGML_OP_CONV_1D_STAGE_1;
7477
7584
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7478
7585
  result->src[0] = a;
7479
7586
  result->src[1] = b;
@@ -7481,6 +7588,53 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
7481
7588
  return result;
7482
7589
  }
7483
7590
 
7591
+ // ggml_conv_1d
7592
+
7593
+ GGML_API struct ggml_tensor * ggml_conv_1d(
7594
+ struct ggml_context * ctx,
7595
+ struct ggml_tensor * a,
7596
+ struct ggml_tensor * b,
7597
+ int s0,
7598
+ int p0,
7599
+ int d0) {
7600
+ struct ggml_tensor * result = ggml_conv_1d_stage_0(ctx, a, b, s0, p0, d0);
7601
+ result = ggml_conv_1d_stage_1(ctx, a, result);
7602
+ return result;
7603
+ }
7604
+
7605
+ // GGML_API struct ggml_tensor * ggml_conv_1d(
7606
+ // struct ggml_context * ctx,
7607
+ // struct ggml_tensor * a,
7608
+ // struct ggml_tensor * b,
7609
+ // int s0,
7610
+ // int p0,
7611
+ // int d0) {
7612
+ // GGML_ASSERT(ggml_is_matrix(b));
7613
+ // GGML_ASSERT(a->ne[1] == b->ne[1]);
7614
+ // bool is_node = false;
7615
+
7616
+ // if (a->grad || b->grad) {
7617
+ // GGML_ASSERT(false); // TODO: implement backward
7618
+ // is_node = true;
7619
+ // }
7620
+
7621
+ // const int64_t ne[4] = {
7622
+ // ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
7623
+ // a->ne[2], 1, 1,
7624
+ // };
7625
+ // struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
7626
+
7627
+ // int32_t params[] = { s0, p0, d0 };
7628
+ // ggml_set_op_params(result, params, sizeof(params));
7629
+
7630
+ // result->op = GGML_OP_CONV_1D;
7631
+ // result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7632
+ // result->src[0] = a;
7633
+ // result->src[1] = b;
7634
+
7635
+ // return result;
7636
+ // }
7637
+
7484
7638
  // ggml_conv_1d_ph
7485
7639
 
7486
7640
  struct ggml_tensor* ggml_conv_1d_ph(
@@ -7492,6 +7646,50 @@ struct ggml_tensor* ggml_conv_1d_ph(
7492
7646
  return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
7493
7647
  }
7494
7648
 
7649
+ // ggml_conv_transpose_1d
7650
+
7651
+ static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
7652
+ return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
7653
+ }
7654
+
7655
+ GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
7656
+ struct ggml_context * ctx,
7657
+ struct ggml_tensor * a,
7658
+ struct ggml_tensor * b,
7659
+ int s0,
7660
+ int p0,
7661
+ int d0) {
7662
+ GGML_ASSERT(ggml_is_matrix(b));
7663
+ GGML_ASSERT(a->ne[2] == b->ne[1]);
7664
+ GGML_ASSERT(a->ne[3] == 1);
7665
+
7666
+ GGML_ASSERT(p0 == 0);
7667
+ GGML_ASSERT(d0 == 1);
7668
+
7669
+ bool is_node = false;
7670
+
7671
+ if (a->grad || b->grad) {
7672
+ GGML_ASSERT(false); // TODO: implement backward
7673
+ is_node = true;
7674
+ }
7675
+
7676
+ const int64_t ne[4] = {
7677
+ ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
7678
+ a->ne[1], b->ne[2], 1,
7679
+ };
7680
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
7681
+
7682
+ int32_t params[] = { s0, p0, d0 };
7683
+ ggml_set_op_params(result, params, sizeof(params));
7684
+
7685
+ result->op = GGML_OP_CONV_TRANSPOSE_1D;
7686
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7687
+ result->src[0] = a;
7688
+ result->src[1] = b;
7689
+
7690
+ return result;
7691
+ }
7692
+
7495
7693
  // ggml_conv_2d
7496
7694
 
7497
7695
  struct ggml_tensor * ggml_conv_2d(
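
The net effect of the hunks above is that `ggml_conv_1d` is no longer a single fused op: stage 0 performs an im2col transform of the signal into rows of length `IC*K` (one row per output position, zero-padded where a tap falls outside the input), and stage 1 multiplies those rows by the kernel flattened to `[OC, IC*K]`. A plain-C, float-only sketch of that decomposition, with `OL = (IL + 2*p0 - d0*(K-1) - 1)/s0 + 1` as in `ggml_calc_conv_output_size`; function and array names here are illustrative, not ggml API:

```c
#include <string.h>

// Stage 0 (im2col): signal [IC][IL] -> rows [OL][IC*K], zero-padding out-of-range taps.
static void im2col_1d_demo(const float * signal, int IC, int IL,
                           int K, int s0, int p0, int d0,
                           float * rows, int OL) {
    memset(rows, 0, (size_t)OL*IC*K*sizeof(float));
    for (int ol = 0; ol < OL; ol++) {
        for (int ic = 0; ic < IC; ic++) {
            for (int ik = 0; ik < K; ik++) {
                const int il = ol*s0 + ik*d0 - p0;
                if (il >= 0 && il < IL) {
                    rows[(ol*IC + ic)*K + ik] = signal[ic*IL + il];
                }
            }
        }
    }
}

// Stage 1 (gemm): kernel [OC][IC*K] x rows [OL][IC*K] -> output [OC][OL].
static void conv_1d_gemm_demo(const float * kernel, const float * rows,
                              int OC, int OL, int ICK, float * out) {
    for (int oc = 0; oc < OC; oc++) {
        for (int ol = 0; ol < OL; ol++) {
            float acc = 0.0f;
            for (int t = 0; t < ICK; t++) {
                acc += kernel[oc*ICK + t] * rows[ol*ICK + t];
            }
            out[oc*OL + ol] = acc;
        }
    }
}
```

Splitting the op this way lets the heavy part run as a plain GEMM (see `gemm_f16_out_f32` further down in the diff) instead of the previous stride-1/stride-2, half-padding-only special cases.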
@@ -11621,11 +11819,6 @@ static void ggml_compute_forward_mul_mat(
11621
11819
 
11622
11820
  #if defined(GGML_USE_CLBLAST)
11623
11821
  if (ggml_cl_can_mul_mat(src0, src1, dst)) {
11624
- // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
11625
- // ref: https://github.com/ggerganov/ggml/pull/224
11626
- GGML_ASSERT(ne02 == ne12);
11627
- GGML_ASSERT(ne03 == ne13);
11628
-
11629
11822
  if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
11630
11823
  ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
11631
11824
  }
@@ -12889,7 +13082,7 @@ static void ggml_compute_forward_alibi_f32(
12889
13082
  return;
12890
13083
  }
12891
13084
 
12892
- const int n_past = ((int32_t *) dst->op_params)[0];
13085
+ const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
12893
13086
  const int n_head = ((int32_t *) dst->op_params)[1];
12894
13087
  float max_bias;
12895
13088
  memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
@@ -12910,7 +13103,6 @@ static void ggml_compute_forward_alibi_f32(
12910
13103
  //const int nb3 = src0->nb[3];
12911
13104
 
12912
13105
  GGML_ASSERT(nb0 == sizeof(float));
12913
- GGML_ASSERT(ne1 + n_past == ne0);
12914
13106
  GGML_ASSERT(n_head == ne2);
12915
13107
 
12916
13108
  // add alibi to src0 (KQ_scaled)
@@ -13636,7 +13828,7 @@ static void ggml_compute_forward_rope_back(
13636
13828
 
13637
13829
  // ggml_compute_forward_conv_1d
13638
13830
 
13639
- static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
13831
+ static void ggml_compute_forward_conv_1d_f16_f32(
13640
13832
  const struct ggml_compute_params * params,
13641
13833
  const struct ggml_tensor * src0,
13642
13834
  const struct ggml_tensor * src1,
@@ -13654,42 +13846,33 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
13654
13846
  const int nth = params->nth;
13655
13847
 
13656
13848
  const int nk = ne00;
13657
- const int nh = nk/2;
13658
13849
 
13659
- const int ew0 = ggml_up32(ne01);
13850
+ // size of the convolution row - the kernel size unrolled across all input channels
13851
+ const int ew0 = nk*ne01;
13852
+
13853
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
13854
+ const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
13855
+ const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
13660
13856
 
13661
- GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
13662
13857
  GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
13663
13858
  GGML_ASSERT(nb10 == sizeof(float));
13664
13859
 
13665
13860
  if (params->type == GGML_TASK_INIT) {
13666
- // TODO: fix this memset (wsize is overestimated)
13667
13861
  memset(params->wdata, 0, params->wsize);
13668
13862
 
13669
- // prepare kernel data (src0)
13670
- {
13671
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
13863
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
13672
13864
 
13673
- for (int64_t i02 = 0; i02 < ne02; i02++) {
13674
- for (int64_t i01 = 0; i01 < ne01; i01++) {
13675
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
13676
- ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
13677
- for (int64_t i00 = 0; i00 < ne00; i00++) {
13678
- dst_data[i00*ew0 + i01] = src[i00];
13679
- }
13680
- }
13681
- }
13682
- }
13865
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
13866
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
13867
+ ggml_fp16_t * dst_data = wdata;
13683
13868
 
13684
- // prepare source data (src1)
13685
- {
13686
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
13869
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
13870
+ for (int64_t ik = 0; ik < nk; ik++) {
13871
+ const int idx0 = i0*s0 + ik*d0 - p0;
13687
13872
 
13688
- for (int64_t i11 = 0; i11 < ne11; i11++) {
13689
- const float * const src = (float *)((char *) src1->data + i11*nb11);
13690
- ggml_fp16_t * dst_data = wdata;
13691
- for (int64_t i10 = 0; i10 < ne10; i10++) {
13692
- dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
13873
+ if(!(idx0 < 0 || idx0 >= ne10)) {
13874
+ dst_data[i0*ew0 + i11*nk + ik] = GGML_FP32_TO_FP16(src[idx0]);
13875
+ }
13693
13876
  }
13694
13877
  }
13695
13878
  }
@@ -13702,7 +13885,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
13702
13885
  }
13703
13886
 
13704
13887
  // total rows in dst
13705
- const int nr = ne02;
13888
+ const int nr = ne2;
13706
13889
 
13707
13890
  // rows per thread
13708
13891
  const int dr = (nr + nth - 1)/nth;
@@ -13711,23 +13894,22 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
13711
13894
  const int ir0 = dr*ith;
13712
13895
  const int ir1 = MIN(ir0 + dr, nr);
13713
13896
 
13714
- for (int i1 = ir0; i1 < ir1; i1++) {
13715
- float * dst_data = (float *)((char *) dst->data + i1*nb1);
13716
- for (int64_t i0 = 0; i0 < ne10; ++i0) {
13717
- dst_data[i0] = 0;
13718
- for (int k = -nh; k <= nh; k++) {
13719
- float v = 0.0f;
13720
- ggml_vec_dot_f16(ew0, &v,
13721
- (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
13722
- (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
13723
-
13724
- dst_data[i0] += v;
13897
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
13898
+
13899
+ for (int i2 = 0; i2 < ne2; i2++) {
13900
+ for (int i1 = ir0; i1 < ir1; i1++) {
13901
+ float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
13902
+
13903
+ for (int i0 = 0; i0 < ne0; i0++) {
13904
+ ggml_vec_dot_f16(ew0, dst_data + i0,
13905
+ (ggml_fp16_t *) ((char *) src0->data + i1*nb02),
13906
+ (ggml_fp16_t *) wdata + i2*nb2 + i0*ew0);
13725
13907
  }
13726
13908
  }
13727
13909
  }
13728
13910
  }
13729
13911
 
13730
- static void ggml_compute_forward_conv_1d_s1_ph_f32(
13912
+ static void ggml_compute_forward_conv_1d_f32(
13731
13913
  const struct ggml_compute_params * params,
13732
13914
  const struct ggml_tensor * src0,
13733
13915
  const struct ggml_tensor * src1,
@@ -13745,42 +13927,32 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32(
13745
13927
  const int nth = params->nth;
13746
13928
 
13747
13929
  const int nk = ne00;
13748
- const int nh = nk/2;
13749
13930
 
13750
- const int ew0 = ggml_up32(ne01);
13931
+ const int ew0 = nk*ne01;
13932
+
13933
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
13934
+ const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
13935
+ const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
13751
13936
 
13752
- GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
13753
13937
  GGML_ASSERT(nb00 == sizeof(float));
13754
13938
  GGML_ASSERT(nb10 == sizeof(float));
13755
13939
 
13756
13940
  if (params->type == GGML_TASK_INIT) {
13757
- // TODO: fix this memset (wsize is overestimated)
13758
13941
  memset(params->wdata, 0, params->wsize);
13759
13942
 
13760
- // prepare kernel data (src0)
13761
- {
13762
- float * const wdata = (float *) params->wdata + 0;
13943
+ float * const wdata = (float *) params->wdata + 0;
13763
13944
 
13764
- for (int64_t i02 = 0; i02 < ne02; i02++) {
13765
- for (int64_t i01 = 0; i01 < ne01; i01++) {
13766
- const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
13767
- float * dst_data = wdata + i02*ew0*ne00;
13768
- for (int64_t i00 = 0; i00 < ne00; i00++) {
13769
- dst_data[i00*ew0 + i01] = src[i00];
13770
- }
13771
- }
13772
- }
13773
- }
13945
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
13946
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
13947
+ float * dst_data = wdata;
13774
13948
 
13775
- // prepare source data (src1)
13776
- {
13777
- float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
13949
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
13950
+ for (int64_t ik = 0; ik < nk; ik++) {
13951
+ const int idx0 = i0*s0 + ik*d0 - p0;
13778
13952
 
13779
- for (int64_t i11 = 0; i11 < ne11; i11++) {
13780
- const float * const src = (float *)((char *) src1->data + i11*nb11);
13781
- float * dst_data = wdata;
13782
- for (int64_t i10 = 0; i10 < ne10; i10++) {
13783
- dst_data[(i10 + nh)*ew0 + i11] = src[i10];
13953
+ if(!(idx0 < 0 || idx0 >= ne10)) {
13954
+ dst_data[i0*ew0 + i11*nk + ik] = src[idx0];
13955
+ }
13784
13956
  }
13785
13957
  }
13786
13958
  }
@@ -13802,35 +13974,242 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32(
13802
13974
  const int ir0 = dr*ith;
13803
13975
  const int ir1 = MIN(ir0 + dr, nr);
13804
13976
 
13805
- for (int i1 = ir0; i1 < ir1; i1++) {
13806
- float * dst_data = (float *)((char *) dst->data + i1*nb1);
13807
- for (int64_t i0 = 0; i0 < ne10; ++i0) {
13808
- dst_data[i0] = 0;
13809
- for (int k = -nh; k <= nh; k++) {
13810
- float v = 0.0f;
13811
- ggml_vec_dot_f32(ew0, &v,
13812
- (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
13813
- (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
13814
-
13815
- dst_data[i0] += v;
13977
+ float * const wdata = (float *) params->wdata + 0;
13978
+
13979
+ for (int i2 = 0; i2 < ne2; i2++) {
13980
+ for (int i1 = ir0; i1 < ir1; i1++) {
13981
+ float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
13982
+
13983
+ for (int i0 = 0; i0 < ne0; i0++) {
13984
+ ggml_vec_dot_f32(ew0, dst_data + i0,
13985
+ (float *) ((char *) src0->data + i1*nb02),
13986
+ (float *) wdata + i2*nb2 + i0*ew0);
13987
+ }
13988
+ }
13989
+ }
13990
+ }
13991
+
13992
+ static void gemm_f16_out_f32(int64_t m, int64_t n, int64_t k,
13993
+ ggml_fp16_t * A,
13994
+ ggml_fp16_t * B,
13995
+ float * C,
13996
+ const int ith, const int nth) {
13997
+ // does not seem to make a difference
13998
+ int64_t m0, m1, n0, n1;
13999
+ // patches per thread
14000
+ if (m > n) {
14001
+ n0 = 0;
14002
+ n1 = n;
14003
+
14004
+ // total patches in dst
14005
+ const int np = m;
14006
+
14007
+ // patches per thread
14008
+ const int dp = (np + nth - 1)/nth;
14009
+
14010
+ // patch range for this thread
14011
+ m0 = dp*ith;
14012
+ m1 = MIN(m0 + dp, np);
14013
+ } else {
14014
+ m0 = 0;
14015
+ m1 = m;
14016
+
14017
+ // total patches in dst
14018
+ const int np = n;
14019
+
14020
+ // patches per thread
14021
+ const int dp = (np + nth - 1)/nth;
14022
+
14023
+ // patch range for this thread
14024
+ n0 = dp*ith;
14025
+ n1 = MIN(n0 + dp, np);
14026
+ }
14027
+
14028
+ // block-tiling attempt
14029
+ int64_t blck_n = 16;
14030
+ int64_t blck_m = 16;
14031
+
14032
+ // int64_t CACHE_SIZE = 2 * 1024 * 1024; // 2MB
14033
+ // int64_t blck_size = CACHE_SIZE / (sizeof(float) + 2 * sizeof(ggml_fp16_t) * K);
14034
+ // if (blck_size > 0) {
14035
+ // blck_0 = 4;
14036
+ // blck_1 = blck_size / blck_0;
14037
+ // if (blck_1 < 0) {
14038
+ // blck_1 = 1;
14039
+ // }
14040
+ // // blck_0 = (int64_t)sqrt(blck_size);
14041
+ // // blck_1 = blck_0;
14042
+ // }
14043
+ // // printf("%zd %zd %zd %zd\n", blck_size, K, blck_0, blck_1);
14044
+
14045
+ for (int j = n0; j < n1; j+=blck_n) {
14046
+ for (int i = m0; i < m1; i+=blck_m) {
14047
+ // printf("i j k => %d %d %d\n", i, j, K);
14048
+ for (int ii = i; ii < i + blck_m && ii < m1; ii++) {
14049
+ for (int jj = j; jj < j + blck_n && jj < n1; jj++) {
14050
+ ggml_vec_dot_f16(k,
14051
+ C + ii*n + jj,
14052
+ A + ii * k,
14053
+ B + jj * k);
14054
+ }
13816
14055
  }
13817
14056
  }
13818
14057
  }
13819
14058
  }
13820
14059
 
13821
- static void ggml_compute_forward_conv_1d_s1_ph(
14060
+ // src0: kernel [OC, IC, K]
14061
+ // src1: signal [N, IC, IL]
14062
+ // dst: result [N, OL, IC*K]
14063
+ static void ggml_compute_forward_conv_1d_stage_0_f32(
13822
14064
  const struct ggml_compute_params * params,
13823
14065
  const struct ggml_tensor * src0,
13824
14066
  const struct ggml_tensor * src1,
13825
14067
  struct ggml_tensor * dst) {
13826
- switch (src0->type) {
14068
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
14069
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
14070
+ GGML_ASSERT( dst->type == GGML_TYPE_F16);
14071
+
14072
+ int64_t t0 = ggml_perf_time_us();
14073
+ UNUSED(t0);
14074
+
14075
+ GGML_TENSOR_BINARY_OP_LOCALS;
14076
+
14077
+ const int64_t N = ne12;
14078
+ const int64_t IC = ne11;
14079
+ const int64_t IL = ne10;
14080
+
14081
+ const int64_t K = ne00;
14082
+
14083
+ const int64_t OL = ne1;
14084
+
14085
+ const int ith = params->ith;
14086
+ const int nth = params->nth;
14087
+
14088
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
14089
+ const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
14090
+ const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
14091
+
14092
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
14093
+ GGML_ASSERT(nb10 == sizeof(float));
14094
+
14095
+ if (params->type == GGML_TASK_INIT) {
14096
+ memset(dst->data, 0, ggml_nbytes(dst));
14097
+ return;
14098
+ }
14099
+
14100
+ if (params->type == GGML_TASK_FINALIZE) {
14101
+ return;
14102
+ }
14103
+
14104
+ // im2col: [N, IC, IL] => [N, OL, IC*K]
14105
+ {
14106
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
14107
+
14108
+ for (int64_t in = 0; in < N; in++) {
14109
+ for (int64_t iol = 0; iol < OL; iol++) {
14110
+ for (int64_t iic = ith; iic < IC; iic+=nth) {
14111
+
14112
+ // micro kernel
14113
+ ggml_fp16_t * dst_data = wdata + (in*OL + iol)*(IC*K); // [IC, K]
14114
+ const float * const src_data = (float *)((char *) src1->data + in*nb12 + iic*nb11); // [IL]
14115
+
14116
+ for (int64_t ik = 0; ik < K; ik++) {
14117
+ const int64_t iil = iol*s0 + ik*d0 - p0;
14118
+
14119
+ if (!(iil < 0 || iil >= IL)) {
14120
+ dst_data[iic*K + ik] = GGML_FP32_TO_FP16(src_data[iil]);
14121
+ }
14122
+ }
14123
+ }
14124
+ }
14125
+ }
14126
+ }
14127
+ }
14128
+
14129
+ // gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
14130
+ // src0: [OC, IC, K]
14131
+ // src1: [N, OL, IC * K]
14132
+ // result: [N, OC, OL]
14133
+ static void ggml_compute_forward_conv_1d_stage_1_f16(
14134
+ const struct ggml_compute_params * params,
14135
+ const struct ggml_tensor * src0,
14136
+ const struct ggml_tensor * src1,
14137
+ struct ggml_tensor * dst) {
14138
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
14139
+ GGML_ASSERT(src1->type == GGML_TYPE_F16);
14140
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
14141
+
14142
+ int64_t t0 = ggml_perf_time_us();
14143
+ UNUSED(t0);
14144
+
14145
+ if (params->type == GGML_TASK_INIT) {
14146
+ return;
14147
+ }
14148
+
14149
+ if (params->type == GGML_TASK_FINALIZE) {
14150
+ return;
14151
+ }
14152
+
14153
+ GGML_TENSOR_BINARY_OP_LOCALS;
14154
+
14155
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
14156
+ GGML_ASSERT(nb10 == sizeof(ggml_fp16_t));
14157
+ GGML_ASSERT(nb0 == sizeof(float));
14158
+
14159
+ const int N = ne12;
14160
+ const int OL = ne11;
14161
+
14162
+ const int OC = ne02;
14163
+ const int IC = ne01;
14164
+ const int K = ne00;
14165
+
14166
+ const int ith = params->ith;
14167
+ const int nth = params->nth;
14168
+
14169
+ int64_t m = OC;
14170
+ int64_t n = OL;
14171
+ int64_t k = IC * K;
14172
+
14173
+ // [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
14174
+ for (int i = 0; i < N; i++) {
14175
+ ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
14176
+ ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k]
14177
+ float * C = (float *)dst->data + i * m * n; // [m, n]
14178
+
14179
+ gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
14180
+ }
14181
+ }
14182
+
14183
+ static void ggml_compute_forward_conv_1d(
14184
+ const struct ggml_compute_params * params,
14185
+ const struct ggml_tensor * src0,
14186
+ const struct ggml_tensor * src1,
14187
+ struct ggml_tensor * dst) {
14188
+ switch(src0->type) {
13827
14189
  case GGML_TYPE_F16:
13828
14190
  {
13829
- ggml_compute_forward_conv_1d_s1_ph_f16_f32(params, src0, src1, dst);
14191
+ ggml_compute_forward_conv_1d_f16_f32(params, src0, src1, dst);
13830
14192
  } break;
13831
14193
  case GGML_TYPE_F32:
13832
14194
  {
13833
- ggml_compute_forward_conv_1d_s1_ph_f32(params, src0, src1, dst);
14195
+ ggml_compute_forward_conv_1d_f32(params, src0, src1, dst);
14196
+ } break;
14197
+ default:
14198
+ {
14199
+ GGML_ASSERT(false);
14200
+ } break;
14201
+ }
14202
+ }
14203
+
14204
+ static void ggml_compute_forward_conv_1d_stage_0(
14205
+ const struct ggml_compute_params * params,
14206
+ const struct ggml_tensor * src0,
14207
+ const struct ggml_tensor * src1,
14208
+ struct ggml_tensor * dst) {
14209
+ switch(src0->type) {
14210
+ case GGML_TYPE_F16:
14211
+ {
14212
+ ggml_compute_forward_conv_1d_stage_0_f32(params, src0, src1, dst);
13834
14213
  } break;
13835
14214
  default:
13836
14215
  {
@@ -13839,7 +14218,26 @@ static void ggml_compute_forward_conv_1d_s1_ph(
13839
14218
  }
13840
14219
  }
13841
14220
 
13842
- static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
14221
+ static void ggml_compute_forward_conv_1d_stage_1(
14222
+ const struct ggml_compute_params * params,
14223
+ const struct ggml_tensor * src0,
14224
+ const struct ggml_tensor * src1,
14225
+ struct ggml_tensor * dst) {
14226
+ switch(src0->type) {
14227
+ case GGML_TYPE_F16:
14228
+ {
14229
+ ggml_compute_forward_conv_1d_stage_1_f16(params, src0, src1, dst);
14230
+ } break;
14231
+ default:
14232
+ {
14233
+ GGML_ASSERT(false);
14234
+ } break;
14235
+ }
14236
+ }
14237
+
14238
+ // ggml_compute_forward_conv_transpose_1d
14239
+
14240
+ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
13843
14241
  const struct ggml_compute_params * params,
13844
14242
  const struct ggml_tensor * src0,
13845
14243
  const struct ggml_tensor * src1,
@@ -13856,43 +14254,38 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
13856
14254
  const int ith = params->ith;
13857
14255
  const int nth = params->nth;
13858
14256
 
13859
- const int nk = ne00;
13860
- const int nh = nk/2;
13861
-
13862
- const int ew0 = ggml_up32(ne01);
14257
+ const int nk = ne00*ne01*ne02;
13863
14258
 
13864
- GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
13865
14259
  GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
13866
14260
  GGML_ASSERT(nb10 == sizeof(float));
13867
14261
 
13868
14262
  if (params->type == GGML_TASK_INIT) {
13869
- // TODO: fix this memset (wsize is overestimated)
13870
14263
  memset(params->wdata, 0, params->wsize);
13871
14264
 
13872
- // prepare kernel data (src0)
14265
+ // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
13873
14266
  {
13874
14267
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
13875
14268
 
13876
14269
  for (int64_t i02 = 0; i02 < ne02; i02++) {
13877
14270
  for (int64_t i01 = 0; i01 < ne01; i01++) {
13878
14271
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
13879
- ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
14272
+ ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
13880
14273
  for (int64_t i00 = 0; i00 < ne00; i00++) {
13881
- dst_data[i00*ew0 + i01] = src[i00];
14274
+ dst_data[i00*ne02 + i02] = src[i00];
13882
14275
  }
13883
14276
  }
13884
14277
  }
13885
14278
  }
13886
14279
 
13887
- // prepare source data (src1)
14280
+ // permute source data (src1) from (L x Cin) to (Cin x L)
13888
14281
  {
13889
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
14282
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
14283
+ ggml_fp16_t * dst_data = wdata;
13890
14284
 
13891
14285
  for (int64_t i11 = 0; i11 < ne11; i11++) {
13892
14286
  const float * const src = (float *)((char *) src1->data + i11*nb11);
13893
- ggml_fp16_t * dst_data = wdata;
13894
14287
  for (int64_t i10 = 0; i10 < ne10; i10++) {
13895
- dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
14288
+ dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
13896
14289
  }
13897
14290
  }
13898
14291
  }
@@ -13904,8 +14297,10 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
13904
14297
  return;
13905
14298
  }
13906
14299
 
14300
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
14301
+
13907
14302
  // total rows in dst
13908
- const int nr = ne02;
14303
+ const int nr = ne1;
13909
14304
 
13910
14305
  // rows per thread
13911
14306
  const int dr = (nr + nth - 1)/nth;
@@ -13914,23 +14309,26 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
13914
14309
  const int ir0 = dr*ith;
13915
14310
  const int ir1 = MIN(ir0 + dr, nr);
13916
14311
 
14312
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
14313
+ ggml_fp16_t * const wdata_src = wdata + nk;
14314
+
13917
14315
  for (int i1 = ir0; i1 < ir1; i1++) {
13918
14316
  float * dst_data = (float *)((char *) dst->data + i1*nb1);
13919
- for (int64_t i0 = 0; i0 < ne10; i0 += 2) {
13920
- dst_data[i0/2] = 0;
13921
- for (int k = -nh; k <= nh; k++) {
13922
- float v = 0.0f;
13923
- ggml_vec_dot_f16(ew0, &v,
13924
- (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
13925
- (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
13926
-
13927
- dst_data[i0/2] += v;
14317
+ ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
14318
+ for (int i10 = 0; i10 < ne10; i10++) {
14319
+ const int i1n = i10*ne11;
14320
+ for (int i00 = 0; i00 < ne00; i00++) {
14321
+ float v = 0;
14322
+ ggml_vec_dot_f16(ne02, &v,
14323
+ (ggml_fp16_t *) wdata_src + i1n,
14324
+ (ggml_fp16_t *) wdata_kernel + i00*ne02);
14325
+ dst_data[i10*s0 + i00] += v;
13928
14326
  }
13929
14327
  }
13930
14328
  }
13931
14329
  }
13932
14330
 
13933
- static void ggml_compute_forward_conv_1d_s2_ph_f32(
14331
+ static void ggml_compute_forward_conv_transpose_1d_f32(
13934
14332
  const struct ggml_compute_params * params,
13935
14333
  const struct ggml_tensor * src0,
13936
14334
  const struct ggml_tensor * src1,
@@ -13947,29 +14345,24 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
13947
14345
  const int ith = params->ith;
13948
14346
  const int nth = params->nth;
13949
14347
 
13950
- const int nk = ne00;
13951
- const int nh = nk/2;
13952
-
13953
- const int ew0 = ggml_up32(ne01);
14348
+ const int nk = ne00*ne01*ne02;
13954
14349
 
13955
- GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
13956
14350
  GGML_ASSERT(nb00 == sizeof(float));
13957
14351
  GGML_ASSERT(nb10 == sizeof(float));
13958
14352
 
13959
14353
  if (params->type == GGML_TASK_INIT) {
13960
- // TODO: fix this memset (wsize is overestimated)
13961
14354
  memset(params->wdata, 0, params->wsize);
13962
14355
 
13963
- // prepare kernel data (src0)
14356
+ // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
13964
14357
  {
13965
14358
  float * const wdata = (float *) params->wdata + 0;
13966
14359
 
13967
14360
  for (int64_t i02 = 0; i02 < ne02; i02++) {
13968
14361
  for (int64_t i01 = 0; i01 < ne01; i01++) {
13969
14362
  const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
13970
- float * dst_data = wdata + i02*ew0*ne00;
14363
+ float * dst_data = wdata + i01*ne00*ne02;
13971
14364
  for (int64_t i00 = 0; i00 < ne00; i00++) {
13972
- dst_data[i00*ew0 + i01] = src[i00];
14365
+ dst_data[i01*ne00*ne02 + i00*ne02 + i02] = src[i00];
13973
14366
  }
13974
14367
  }
13975
14368
  }
@@ -13977,13 +14370,13 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
13977
14370
 
13978
14371
  // prepare source data (src1)
13979
14372
  {
13980
- float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
14373
+ float * const wdata = (float *) params->wdata + nk;
14374
+ float * dst_data = wdata;
13981
14375
 
13982
14376
  for (int64_t i11 = 0; i11 < ne11; i11++) {
13983
14377
  const float * const src = (float *)((char *) src1->data + i11*nb11);
13984
- float * dst_data = wdata;
13985
14378
  for (int64_t i10 = 0; i10 < ne10; i10++) {
13986
- dst_data[(i10 + nh)*ew0 + i11] = src[i10];
14379
+ dst_data[i10*ne11 + i11] = src[i10];
13987
14380
  }
13988
14381
  }
13989
14382
  }
@@ -13995,8 +14388,10 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
13995
14388
  return;
13996
14389
  }
13997
14390
 
14391
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
14392
+
13998
14393
  // total rows in dst
13999
- const int nr = ne02;
14394
+ const int nr = ne1;
14000
14395
 
14001
14396
  // rows per thread
14002
14397
  const int dr = (nr + nth - 1)/nth;
@@ -14005,23 +14400,26 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
14005
14400
  const int ir0 = dr*ith;
14006
14401
  const int ir1 = MIN(ir0 + dr, nr);
14007
14402
 
14403
+ float * const wdata = (float *) params->wdata + 0;
14404
+ float * const wdata_src = wdata + nk;
14405
+
14008
14406
  for (int i1 = ir0; i1 < ir1; i1++) {
14009
14407
  float * dst_data = (float *)((char *) dst->data + i1*nb1);
14010
- for (int64_t i0 = 0; i0 < ne10; i0 += 2) {
14011
- dst_data[i0/2] = 0;
14012
- for (int k = -nh; k <= nh; k++) {
14013
- float v = 0.0f;
14014
- ggml_vec_dot_f32(ew0, &v,
14015
- (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
14016
- (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
14017
-
14018
- dst_data[i0/2] += v;
14408
+ float * wdata_kernel = wdata + i1*ne02*ne00;
14409
+ for (int i10 = 0; i10 < ne10; i10++) {
14410
+ const int i1n = i10*ne11;
14411
+ for (int i00 = 0; i00 < ne00; i00++) {
14412
+ float v = 0;
14413
+ ggml_vec_dot_f32(ne02, &v,
14414
+ wdata_src + i1n,
14415
+ wdata_kernel + i00*ne02);
14416
+ dst_data[i10*s0 + i00] += v;
14019
14417
  }
14020
14418
  }
14021
14419
  }
14022
14420
  }
14023
14421
 
14024
- static void ggml_compute_forward_conv_1d_s2_ph(
14422
+ static void ggml_compute_forward_conv_transpose_1d(
14025
14423
  const struct ggml_compute_params * params,
14026
14424
  const struct ggml_tensor * src0,
14027
14425
  const struct ggml_tensor * src1,
@@ -14029,11 +14427,11 @@ static void ggml_compute_forward_conv_1d_s2_ph(
14029
14427
  switch (src0->type) {
14030
14428
  case GGML_TYPE_F16:
14031
14429
  {
14032
- ggml_compute_forward_conv_1d_s2_ph_f16_f32(params, src0, src1, dst);
14430
+ ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst);
14033
14431
  } break;
14034
14432
  case GGML_TYPE_F32:
14035
14433
  {
14036
- ggml_compute_forward_conv_1d_s2_ph_f32(params, src0, src1, dst);
14434
+ ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst);
14037
14435
  } break;
14038
14436
  default:
14039
14437
  {
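
For reference on the dispatcher above: the output length produced by the new transposed-convolution path follows `ggml_calc_conv_transpose_1d_output_size`, i.e. `(ins - 1)*s0 - 2*p0 + d0*(ks - 1) + 1`. A tiny standalone worked example (illustrative names, not ggml API):

```c
#include <assert.h>
#include <stdint.h>

// Output length of the new transposed 1-D convolution, mirroring
// ggml_calc_conv_transpose_1d_output_size in the diff.
static int64_t conv_transpose_1d_out_len_demo(int64_t ins, int64_t ks, int s, int p, int d) {
    return (ins - 1)*s - 2*p + d*(ks - 1) + 1;
}

int main(void) {
    // e.g. an 8-sample signal, kernel size 4, stride 2, with p0 == 0 and d0 == 1
    // (the only values the new ggml_conv_transpose_1d currently asserts):
    assert(conv_transpose_1d_out_len_demo(8, 4, 2, 0, 1) == 18);
    return 0;
}
```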
@@ -14042,27 +14440,6 @@ static void ggml_compute_forward_conv_1d_s2_ph(
14042
14440
  }
14043
14441
  }
14044
14442
 
14045
- // ggml_compute_forward_conv_1d
14046
-
14047
- static void ggml_compute_forward_conv_1d(
14048
- const struct ggml_compute_params * params,
14049
- const struct ggml_tensor * src0,
14050
- const struct ggml_tensor * src1,
14051
- struct ggml_tensor * dst) {
14052
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
14053
- const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
14054
- const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
14055
- GGML_ASSERT(d0 == 1); // dilation not supported
14056
- GGML_ASSERT(p0 == src0->ne[0]/2); // only half padding supported
14057
- if (s0 == 1) {
14058
- ggml_compute_forward_conv_1d_s1_ph(params, src0, src1, dst);
14059
- } else if (s0 == 2) {
14060
- ggml_compute_forward_conv_1d_s2_ph(params, src0, src1, dst);
14061
- } else {
14062
- GGML_ASSERT(false); // only stride 1 and 2 supported
14063
- }
14064
- }
14065
-
14066
14443
  // ggml_compute_forward_conv_2d
14067
14444
 
14068
14445
  static void ggml_compute_forward_conv_2d_f16_f32(
@@ -14105,20 +14482,22 @@ static void ggml_compute_forward_conv_2d_f16_f32(
14105
14482
  {
14106
14483
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
14107
14484
 
14108
- for (int i12 = 0; i12 < ne12; i12++) {
14109
- const float * const src = (float *)((char *) src1->data + i12*nb12);
14110
- ggml_fp16_t * dst_data = wdata;
14111
-
14112
- for (int i1 = 0; i1 < ne1; i1++) {
14113
- for (int i0 = 0; i0 < ne0; i0++) {
14114
- for (int ik1 = 0; ik1 < nk1; ik1++) {
14115
- for (int ik0 = 0; ik0 < nk0; ik0++) {
14116
- const int idx0 = i0*s0 + ik0*d0 - p0;
14117
- const int idx1 = i1*s1 + ik1*d1 - p1;
14118
-
14119
- if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) {
14120
- dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
14121
- GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]);
14485
+ for (int i13 = 0; i13 < ne13; i13++) {
14486
+ for (int i12 = 0; i12 < ne12; i12++) {
14487
+ const float * const src = (float *)((char *) src1->data + i13*nb13 + i12*nb12);
14488
+ ggml_fp16_t * dst_data = wdata + i13*(ne1*ne0*ew0);
14489
+
14490
+ for (int i1 = 0; i1 < ne1; i1++) {
14491
+ for (int i0 = 0; i0 < ne0; i0++) {
14492
+ for (int ik1 = 0; ik1 < nk1; ik1++) {
14493
+ for (int ik0 = 0; ik0 < nk0; ik0++) {
14494
+ const int idx0 = i0*s0 + ik0*d0 - p0;
14495
+ const int idx1 = i1*s1 + ik1*d1 - p1;
14496
+
14497
+ if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) {
14498
+ dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
14499
+ GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]);
14500
+ }
14122
14501
  }
14123
14502
  }
14124
14503
  }
@@ -16401,6 +16780,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
16401
16780
  {
16402
16781
  ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor);
16403
16782
  } break;
16783
+ case GGML_OP_CONV_1D_STAGE_0:
16784
+ {
16785
+ ggml_compute_forward_conv_1d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
16786
+ } break;
16787
+ case GGML_OP_CONV_1D_STAGE_1:
16788
+ {
16789
+ ggml_compute_forward_conv_1d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
16790
+ } break;
16791
+ case GGML_OP_CONV_TRANSPOSE_1D:
16792
+ {
16793
+ ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor);
16794
+ } break;
16404
16795
  case GGML_OP_CONV_2D:
16405
16796
  {
16406
16797
  ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor);
@@ -17326,10 +17717,22 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
17326
17717
  {
17327
17718
  GGML_ASSERT(false); // TODO: not implemented
17328
17719
  } break;
17720
+ case GGML_OP_CONV_1D_STAGE_0:
17721
+ {
17722
+ GGML_ASSERT(false); // TODO: not implemented
17723
+ } break;
17724
+ case GGML_OP_CONV_1D_STAGE_1:
17725
+ {
17726
+ GGML_ASSERT(false); // TODO: not implemented
17727
+ } break;
17329
17728
  case GGML_OP_CONV_2D:
17330
17729
  {
17331
17730
  GGML_ASSERT(false); // TODO: not implemented
17332
17731
  } break;
17732
+ case GGML_OP_CONV_TRANSPOSE_1D:
17733
+ {
17734
+ GGML_ASSERT(false); // TODO: not implemented
17735
+ } break;
17333
17736
  case GGML_OP_CONV_TRANSPOSE_2D:
17334
17737
  {
17335
17738
  GGML_ASSERT(false); // TODO: not implemented
@@ -18171,21 +18574,68 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
18171
18574
  GGML_ASSERT(node->src[1]->ne[2] == 1);
18172
18575
  GGML_ASSERT(node->src[1]->ne[3] == 1);
18173
18576
 
18577
+ const int64_t ne00 = node->src[0]->ne[0];
18578
+ const int64_t ne01 = node->src[0]->ne[1];
18579
+ const int64_t ne02 = node->src[0]->ne[2];
18580
+
18581
+ const int64_t ne10 = node->src[1]->ne[0];
18582
+ const int64_t ne11 = node->src[1]->ne[1];
18583
+
18584
+ const int64_t ne0 = node->ne[0];
18585
+ const int64_t ne1 = node->ne[1];
18586
+ const int64_t nk = ne00;
18587
+ const int64_t ew0 = nk * ne01;
18588
+
18589
+ UNUSED(ne02);
18590
+ UNUSED(ne10);
18591
+ UNUSED(ne11);
18592
+
18174
18593
  size_t cur = 0;
18175
- const int nk = node->src[0]->ne[0];
18176
18594
 
18177
18595
  if (node->src[0]->type == GGML_TYPE_F16 &&
18178
- node->src[1]->type == GGML_TYPE_F32) {
18179
- cur = sizeof(ggml_fp16_t)*(
18180
- nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] +
18181
- ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
18182
- );
18596
+ node->src[1]->type == GGML_TYPE_F32) {
18597
+ cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0);
18598
+ } else if (node->src[0]->type == GGML_TYPE_F32 &&
18599
+ node->src[1]->type == GGML_TYPE_F32) {
18600
+ cur = sizeof(float)*(ne0*ne1*ew0);
18601
+ } else {
18602
+ GGML_ASSERT(false);
18603
+ }
18604
+
18605
+ work_size = MAX(work_size, cur);
18606
+ } break;
18607
+ case GGML_OP_CONV_1D_STAGE_0:
18608
+ {
18609
+ n_tasks = n_threads;
18610
+ } break;
18611
+ case GGML_OP_CONV_1D_STAGE_1:
18612
+ {
18613
+ n_tasks = n_threads;
18614
+ } break;
18615
+ case GGML_OP_CONV_TRANSPOSE_1D:
18616
+ {
18617
+ n_tasks = n_threads;
18618
+
18619
+ GGML_ASSERT(node->src[0]->ne[3] == 1);
18620
+ GGML_ASSERT(node->src[1]->ne[2] == 1);
18621
+ GGML_ASSERT(node->src[1]->ne[3] == 1);
18622
+
18623
+ const int64_t ne00 = node->src[0]->ne[0]; // K
18624
+ const int64_t ne01 = node->src[0]->ne[1]; // Cout
18625
+ const int64_t ne02 = node->src[0]->ne[2]; // Cin
18626
+
18627
+ const int64_t ne10 = node->src[1]->ne[0]; // L
18628
+ const int64_t ne11 = node->src[1]->ne[1]; // Cin
18629
+
18630
+ size_t cur = 0;
18631
+ if (node->src[0]->type == GGML_TYPE_F16 &&
18632
+ node->src[1]->type == GGML_TYPE_F32) {
18633
+ cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
18634
+ cur += sizeof(ggml_fp16_t)*ne10*ne11;
18183
18635
  } else if (node->src[0]->type == GGML_TYPE_F32 &&
18184
- node->src[1]->type == GGML_TYPE_F32) {
18185
- cur = sizeof(float)*(
18186
- nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] +
18187
- ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
18188
- );
18636
+ node->src[1]->type == GGML_TYPE_F32) {
18637
+ cur += sizeof(float)*ne00*ne01*ne02;
18638
+ cur += sizeof(float)*ne10*ne11;
18189
18639
  } else {
18190
18640
  GGML_ASSERT(false);
18191
18641
  }
@@ -19311,7 +19761,7 @@ static enum ggml_opt_result ggml_opt_adam(
19311
19761
  if (callback) {
19312
19762
  callback(callback_data, accum_step, &sched, &cancel);
19313
19763
  if (cancel) {
19314
- break;
19764
+ return GGML_OPT_CANCEL;
19315
19765
  }
19316
19766
  }
19317
19767
  // ggml_graph_reset (gf);
@@ -19320,9 +19770,6 @@ static enum ggml_opt_result ggml_opt_adam(
19320
19770
  ggml_opt_acc_grad(np, ps, g, accum_norm);
19321
19771
  fx += ggml_get_f32_1d(f, 0);
19322
19772
  }
19323
- if (cancel) {
19324
- return GGML_OPT_DID_NOT_CONVERGE;
19325
- }
19326
19773
  fx *= accum_norm;
19327
19774
 
19328
19775
  opt->adam.fx_prev = fx;
@@ -19348,9 +19795,6 @@ static enum ggml_opt_result ggml_opt_adam(
19348
19795
 
19349
19796
  // run the optimizer
19350
19797
  for (int t = 0; t < params.adam.n_iter; ++t) {
19351
- if (cancel) {
19352
- break;
19353
- }
19354
19798
  opt->iter = iter0 + t + 1;
19355
19799
  GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
19356
19800
 
@@ -19408,7 +19852,7 @@ static enum ggml_opt_result ggml_opt_adam(
19408
19852
  if (callback) {
19409
19853
  callback(callback_data, accum_step, &sched, &cancel);
19410
19854
  if (cancel) {
19411
- break;
19855
+ return GGML_OPT_CANCEL;;
19412
19856
  }
19413
19857
  }
19414
19858
  // ggml_graph_reset (gf);
@@ -19417,9 +19861,6 @@ static enum ggml_opt_result ggml_opt_adam(
19417
19861
  ggml_opt_acc_grad(np, ps, g, accum_norm);
19418
19862
  fx += ggml_get_f32_1d(f, 0);
19419
19863
  }
19420
- if (cancel) {
19421
- break;
19422
- }
19423
19864
  fx *= accum_norm;
19424
19865
 
19425
19866
  opt->loss_after = fx;
@@ -19538,7 +19979,7 @@ static enum ggml_opt_result linesearch_backtracking(
19538
19979
  finit = *fx;
19539
19980
  dgtest = params->lbfgs.ftol*dginit;
19540
19981
 
19541
- while (!*cancel) {
19982
+ while (true) {
19542
19983
  ggml_vec_cpy_f32(nx, x, xp);
19543
19984
  ggml_vec_mad_f32(nx, x, d, *step);
19544
19985
 
@@ -19554,7 +19995,7 @@ static enum ggml_opt_result linesearch_backtracking(
19554
19995
  float sched = 0;
19555
19996
  callback(callback_data, accum_step, &sched, cancel);
19556
19997
  if (*cancel) {
19557
- break;
19998
+ return GGML_OPT_CANCEL;
19558
19999
  }
19559
20000
  }
19560
20001
  // ggml_graph_reset (gf);
@@ -19563,9 +20004,6 @@ static enum ggml_opt_result linesearch_backtracking(
19563
20004
  ggml_opt_acc_grad(np, ps, g, accum_norm);
19564
20005
  *fx += ggml_get_f32_1d(f, 0);
19565
20006
  }
19566
- if (*cancel) {
19567
- break;
19568
- }
19569
20007
  *fx *= accum_norm;
19570
20008
 
19571
20009
  }
@@ -19698,7 +20136,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
19698
20136
  float sched = 0;
19699
20137
  callback(callback_data, accum_step, &sched, &cancel);
19700
20138
  if (cancel) {
19701
- break;
20139
+ return GGML_OPT_CANCEL;
19702
20140
  }
19703
20141
  }
19704
20142
  // ggml_graph_reset (gf);
@@ -19707,9 +20145,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
19707
20145
  ggml_opt_acc_grad(np, ps, g, accum_norm);
19708
20146
  fx += ggml_get_f32_1d(f, 0);
19709
20147
  }
19710
- if (cancel) {
19711
- return GGML_OPT_DID_NOT_CONVERGE;
19712
- }
19713
20148
  fx *= accum_norm;
19714
20149
 
19715
20150
  opt->loss_before = fx;
@@ -19769,8 +20204,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
19769
20204
  ggml_vec_cpy_f32(nx, gp, g);
19770
20205
 
19771
20206
  ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data);
19772
- if (!cancel) {
19773
- break;
20207
+ if (cancel) {
20208
+ return GGML_OPT_CANCEL;
19774
20209
  }
19775
20210
 
19776
20211
  if (ls < 0) {
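
The optimizer hunks above change what happens when the user callback requests cancellation: instead of silently breaking out of the accumulation loop (and previously reporting `GGML_OPT_DID_NOT_CONVERGE`), both the Adam and L-BFGS paths now return the dedicated `GGML_OPT_CANCEL` status immediately. A sketch of a callback matching the call sites visible in the diff, `callback(callback_data, accum_step, &sched, &cancel)`; the exact typedef and the stopping condition here are assumptions for illustration only:

```c
#include <stdbool.h>

// Illustrative user data for the callback; not a ggml type.
struct demo_cancel_state {
    int max_accum_steps;  // stop after this many accumulation steps
    int seen;
};

// Matches the call pattern used in the diff: (user data, accum_step, &sched, &cancel).
// Setting *cancel = true now makes the optimizer return GGML_OPT_CANCEL.
static void demo_opt_callback(void * data, int accum_step, float * sched, bool * cancel) {
    (void) sched;  // learning-rate schedule multiplier; left untouched here
    struct demo_cancel_state * st = data;

    st->seen = accum_step;
    if (accum_step >= st->max_accum_steps) {
        *cancel = true;
    }
}
```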