llama_cpp 0.6.0 → 0.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1032,8 +1032,8 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r
1032
1032
  y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
1033
1033
 
1034
1034
  // get the 5-th bit and store it in qh at the right position
1035
- qh |= ((xi0 & 0x10) >> 4) << (j + 0);
1036
- qh |= ((xi1 & 0x10) >> 4) << (j + qk/2);
1035
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
1036
+ qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
1037
1037
  }
1038
1038
 
1039
1039
  memcpy(&y[i].qh, &qh, sizeof(qh));
@@ -1080,8 +1080,8 @@ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * r
1080
1080
  y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
1081
1081
 
1082
1082
  // get the 5-th bit and store it in qh at the right position
1083
- qh |= ((xi0 & 0x10) >> 4) << (j + 0);
1084
- qh |= ((xi1 & 0x10) >> 4) << (j + qk/2);
1083
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
1084
+ qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
1085
1085
  }
1086
1086
 
1087
1087
  memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
@@ -1272,6 +1272,33 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1272
1272
  _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
1273
1273
  #endif
1274
1274
  }
1275
+ #elif defined(__riscv_v_intrinsic)
1276
+
1277
+ size_t vl = __riscv_vsetvl_e32m4(QK8_0);
1278
+
1279
+ for (int i = 0; i < nb; i++) {
1280
+ // load elements
1281
+ vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
1282
+
1283
+ vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
1284
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
1285
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
1286
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
1287
+
1288
+ const float d = amax / ((1 << 7) - 1);
1289
+ const float id = d ? 1.0f/d : 0.0f;
1290
+
1291
+ y[i].d = GGML_FP32_TO_FP16(d);
1292
+
1293
+ vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
1294
+
1295
+ // convert to integer
1296
+ vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
1297
+ vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
1298
+
1299
+ // store result
1300
+ __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
1301
+ }
1275
1302
  #else
1276
1303
  // scalar
1277
1304
  quantize_row_q8_0_reference(x, y, k);
@@ -1490,6 +1517,41 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
1490
1517
  _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
1491
1518
  #endif
1492
1519
  }
1520
+ #elif defined(__riscv_v_intrinsic)
1521
+
1522
+ size_t vl = __riscv_vsetvl_e32m4(QK8_1);
1523
+
1524
+ for (int i = 0; i < nb; i++) {
1525
+ // load elements
1526
+ vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
1527
+
1528
+ vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
1529
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl);
1530
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
1531
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
1532
+
1533
+ const float d = amax / ((1 << 7) - 1);
1534
+ const float id = d ? 1.0f/d : 0.0f;
1535
+
1536
+ y[i].d = d;
1537
+
1538
+ vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
1539
+
1540
+ // convert to integer
1541
+ vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
1542
+ vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
1543
+
1544
+ // store result
1545
+ __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
1546
+
1547
+ // compute sum for y[i].s
1548
+ vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
1549
+ vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
1550
+
1551
+ // set y[i].s
1552
+ int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
1553
+ y[i].s = sum*d;
1554
+ }
1493
1555
  #else
1494
1556
  // scalar
1495
1557
  quantize_row_q8_1_reference(x, y, k);
@@ -2662,30 +2724,32 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2662
2724
  size_t vl = __riscv_vsetvl_e8m1(qk/2);
2663
2725
 
2664
2726
  for (int i = 0; i < nb; i++) {
2665
- vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
2727
+ // load elements
2728
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
2666
2729
 
2667
- vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
2668
- vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
2730
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
2731
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
2669
2732
 
2670
- vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
2671
- vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
2733
+ // mask and store lower part of x, and then upper part
2734
+ vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
2735
+ vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
2672
2736
 
2673
- vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
2674
- vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
2737
+ vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
2738
+ vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
2675
2739
 
2676
- vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl);
2677
- vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl);
2740
+ // subtract offset
2741
+ vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
2742
+ vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
2678
2743
 
2679
- vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
2680
- vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
2744
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
2745
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
2681
2746
 
2682
2747
  vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
2683
2748
 
2684
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
2685
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
2749
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
2750
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
2686
2751
 
2687
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
2688
- sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
2752
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
2689
2753
 
2690
2754
  sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
2691
2755
  }
@@ -2823,27 +2887,28 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
2823
2887
  size_t vl = __riscv_vsetvl_e8m1(qk/2);
2824
2888
 
2825
2889
  for (int i = 0; i < nb; i++) {
2826
- vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
2890
+ // load elements
2891
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
2827
2892
 
2828
- vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
2829
- vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
2893
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
2894
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
2830
2895
 
2831
- vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
2832
- vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
2896
+ // mask and store lower part of x, and then upper part
2897
+ vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
2898
+ vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
2833
2899
 
2834
- vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
2835
- vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
2900
+ vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
2901
+ vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
2836
2902
 
2837
- vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
2838
- vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
2903
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
2904
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
2839
2905
 
2840
2906
  vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
2841
2907
 
2842
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
2843
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
2908
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
2909
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
2844
2910
 
2845
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
2846
- sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
2911
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
2847
2912
 
2848
2913
  sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
2849
2914
  }
@@ -3088,66 +3153,61 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
3088
3153
 
3089
3154
  uint32_t qh;
3090
3155
 
3091
- // These temp values are for masking and shift operations
3092
- uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
3093
- uint32_t temp_2[16] = {0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80,
3094
- 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000};
3095
-
3096
3156
  size_t vl = __riscv_vsetvl_e8m1(qk/2);
3097
3157
 
3158
+ // These temporary registers are for masking and shift operations
3159
+ vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
3160
+ vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
3161
+
3162
+ vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
3163
+ vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
3164
+
3098
3165
  for (int i = 0; i < nb; i++) {
3099
3166
  memcpy(&qh, x[i].qh, sizeof(uint32_t));
3100
3167
 
3101
- // temporary registers
3102
- vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_2, vl);
3103
- vuint32m4_t vt_2 = __riscv_vle32_v_u32m4(temp_1, vl);
3104
- vuint32m4_t vt_3 = __riscv_vsll_vx_u32m4(vt_1, 16, vl);
3105
- vuint32m4_t vt_4 = __riscv_vadd_vx_u32m4(vt_2, 12, vl);
3106
-
3107
3168
  // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
3108
- vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(vt_1, qh, vl);
3109
- vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(xha_0, vt_2, vl);
3110
- vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl);
3169
+ vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
3170
+ vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
3171
+ vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
3111
3172
 
3112
3173
  // ((qh & (1u << (j + 16))) >> (j + 12));
3113
- vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(vt_3, qh, vl);
3114
- vuint32m4_t xhl_1 = __riscv_vsrl_vv_u32m4(xha_1, vt_4, vl);
3174
+ vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
3175
+ vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
3115
3176
 
3116
3177
  // narrowing
3117
- vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xhl_0, vl);
3118
- vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl);
3178
+ vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
3179
+ vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
3119
3180
 
3120
- vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xhl_1, vl);
3121
- vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl);
3181
+ vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
3182
+ vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
3122
3183
 
3123
3184
  // load
3124
- vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
3185
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
3125
3186
 
3126
- vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
3127
- vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
3187
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
3188
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
3128
3189
 
3129
- vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
3130
- vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
3190
+ vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
3191
+ vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
3131
3192
 
3132
- vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl);
3133
- vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl);
3193
+ vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
3194
+ vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
3134
3195
 
3135
- vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
3136
- vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
3196
+ vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
3197
+ vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
3137
3198
 
3138
- vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 16, vl);
3139
- vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 16, vl);
3199
+ vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
3200
+ vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
3140
3201
 
3141
- vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
3142
- vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
3202
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
3203
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
3143
3204
 
3144
3205
  vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
3145
3206
 
3146
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
3147
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
3207
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
3208
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
3148
3209
 
3149
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
3150
- sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
3210
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
3151
3211
 
3152
3212
  sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
3153
3213
  }
@@ -3414,62 +3474,58 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
3414
3474
 
3415
3475
  uint32_t qh;
3416
3476
 
3417
- // These temp values are for shift operations
3418
- uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
3419
-
3420
3477
  size_t vl = __riscv_vsetvl_e8m1(qk/2);
3421
3478
 
3479
+ // temporary registers for shift operations
3480
+ vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
3481
+ vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
3482
+
3422
3483
  for (int i = 0; i < nb; i++) {
3423
3484
  memcpy(&qh, x[i].qh, sizeof(uint32_t));
3424
3485
 
3425
- // temporary registers
3426
- vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_1, vl);
3427
- vuint32m4_t vt_2 = __riscv_vadd_vx_u32m4(vt_1, 12, vl);
3428
-
3429
3486
  // load qh
3430
- vuint32m4_t vqh = __riscv_vmv_v_x_u32m4(qh, vl);
3487
+ vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
3431
3488
 
3432
3489
  // ((qh >> (j + 0)) << 4) & 0x10;
3433
- vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(vqh, vt_1, vl);
3434
- vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl);
3435
- vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(xhl_0, 0x10, vl);
3490
+ vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
3491
+ vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
3492
+ vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
3436
3493
 
3437
3494
  // ((qh >> (j + 12)) ) & 0x10;
3438
- vuint32m4_t xhr_1 = __riscv_vsrl_vv_u32m4(vqh, vt_2, vl);
3439
- vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(xhr_1, 0x10, vl);
3495
+ vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
3496
+ vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
3440
3497
 
3441
3498
  // narrowing
3442
- vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xha_0, vl);
3443
- vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl);
3499
+ vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
3500
+ vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
3444
3501
 
3445
- vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xha_1, vl);
3446
- vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl);
3502
+ vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
3503
+ vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
3447
3504
 
3448
3505
  // load
3449
- vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
3506
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
3450
3507
 
3451
- vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
3452
- vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
3508
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
3509
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
3453
3510
 
3454
- vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
3455
- vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
3511
+ vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
3512
+ vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
3456
3513
 
3457
- vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl);
3458
- vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl);
3514
+ vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
3515
+ vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
3459
3516
 
3460
- vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
3461
- vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
3517
+ vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
3518
+ vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
3462
3519
 
3463
- vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
3464
- vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
3520
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
3521
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
3465
3522
 
3466
3523
  vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
3467
3524
 
3468
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
3469
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
3525
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
3526
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
3470
3527
 
3471
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
3472
- sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
3528
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
3473
3529
 
3474
3530
  sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
3475
3531
  }
@@ -4025,12 +4081,16 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
4025
4081
  "ALIBI",
4026
4082
  "CLAMP",
4027
4083
  "CONV_1D",
4084
+ "CONV_TRANSPOSE_1D",
4028
4085
  "CONV_2D",
4029
4086
  "CONV_TRANSPOSE_2D",
4030
4087
  "POOL_1D",
4031
4088
  "POOL_2D",
4032
4089
  "UPSCALE",
4033
4090
 
4091
+ "CONV_1D_STAGE_0",
4092
+ "CONV_1D_STAGE_1",
4093
+
4034
4094
  "FLASH_ATTN",
4035
4095
  "FLASH_FF",
4036
4096
  "FLASH_ATTN_BACK",
@@ -4056,7 +4116,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
4056
4116
  "CROSS_ENTROPY_LOSS_BACK",
4057
4117
  };
4058
4118
 
4059
- static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
4119
+ static_assert(GGML_OP_COUNT == 71, "GGML_OP_COUNT != 71");
4060
4120
 
4061
4121
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
4062
4122
  "none",
@@ -4107,12 +4167,16 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
4107
4167
  "alibi(x)",
4108
4168
  "clamp(x)",
4109
4169
  "conv_1d(x)",
4170
+ "conv_transpose_1d(x)",
4110
4171
  "conv_2d(x)",
4111
4172
  "conv_transpose_2d(x)",
4112
4173
  "pool_1d(x)",
4113
4174
  "pool_2d(x)",
4114
4175
  "upscale(x)",
4115
4176
 
4177
+ "conv_1d_stage_0(x)",
4178
+ "conv_1d_stage_1(x)",
4179
+
4116
4180
  "flash_attn(x)",
4117
4181
  "flash_ff(x)",
4118
4182
  "flash_attn_back(x)",
@@ -4138,7 +4202,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
4138
4202
  "cross_entropy_loss_back(x,y)",
4139
4203
  };
4140
4204
 
4141
- static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
4205
+ static_assert(GGML_OP_COUNT == 71, "GGML_OP_COUNT != 71");
4142
4206
 
4143
4207
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
4144
4208
 
@@ -4167,7 +4231,10 @@ static void ggml_setup_op_has_task_pass(void) {
4167
4231
  p[GGML_OP_DIAG_MASK_INF ] = true;
4168
4232
  p[GGML_OP_DIAG_MASK_ZERO ] = true;
4169
4233
  p[GGML_OP_CONV_1D ] = true;
4234
+ p[GGML_OP_CONV_1D_STAGE_0 ] = true;
4235
+ p[GGML_OP_CONV_1D_STAGE_1 ] = true;
4170
4236
  p[GGML_OP_CONV_2D ] = true;
4237
+ p[GGML_OP_CONV_TRANSPOSE_1D ] = true;
4171
4238
  p[GGML_OP_CONV_TRANSPOSE_2D ] = true;
4172
4239
  p[GGML_OP_FLASH_ATTN_BACK ] = true;
4173
4240
  p[GGML_OP_CROSS_ENTROPY_LOSS ] = true;
@@ -6690,7 +6757,6 @@ struct ggml_tensor * ggml_cont_4d(
6690
6757
  return result;
6691
6758
  }
6692
6759
 
6693
-
6694
6760
  // ggml_reshape
6695
6761
 
6696
6762
  struct ggml_tensor * ggml_reshape(
@@ -7448,14 +7514,17 @@ static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p,
7448
7514
  return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
7449
7515
  }
7450
7516
 
7451
- GGML_API struct ggml_tensor * ggml_conv_1d(
7452
- struct ggml_context * ctx,
7453
- struct ggml_tensor * a,
7454
- struct ggml_tensor * b,
7455
- int s0,
7456
- int p0,
7457
- int d0) {
7458
- GGML_ASSERT(ggml_is_matrix(b));
7517
+ // im2col: [N, IC, IL] => [N, OL, IC*K]
7518
+ // a: [OC,IC, K]
7519
+ // b: [N, IC, IL]
7520
+ // result: [N, OL, IC*K]
7521
+ static struct ggml_tensor * ggml_conv_1d_stage_0(
7522
+ struct ggml_context * ctx,
7523
+ struct ggml_tensor * a,
7524
+ struct ggml_tensor * b,
7525
+ int s0,
7526
+ int p0,
7527
+ int d0) {
7459
7528
  GGML_ASSERT(a->ne[1] == b->ne[1]);
7460
7529
  bool is_node = false;
7461
7530
 
@@ -7464,16 +7533,54 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
7464
7533
  is_node = true;
7465
7534
  }
7466
7535
 
7536
+ const int64_t OL = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
7537
+
7467
7538
  const int64_t ne[4] = {
7468
- ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
7469
- a->ne[2], 1, 1,
7539
+ a->ne[1] * a->ne[0],
7540
+ OL,
7541
+ b->ne[2],
7542
+ 1,
7470
7543
  };
7471
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
7544
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
7472
7545
 
7473
7546
  int32_t params[] = { s0, p0, d0 };
7474
7547
  ggml_set_op_params(result, params, sizeof(params));
7475
7548
 
7476
- result->op = GGML_OP_CONV_1D;
7549
+ result->op = GGML_OP_CONV_1D_STAGE_0;
7550
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7551
+ result->src[0] = a;
7552
+ result->src[1] = b;
7553
+
7554
+ return result;
7555
+ }
7556
+
7557
+ // ggml_conv_1d_stage_1
7558
+
7559
+ // gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
7560
+ // a: [OC, IC, K]
7561
+ // b: [N, OL, IC * K]
7562
+ // result: [N, OC, OL]
7563
+ static struct ggml_tensor * ggml_conv_1d_stage_1(
7564
+ struct ggml_context * ctx,
7565
+ struct ggml_tensor * a,
7566
+ struct ggml_tensor * b) {
7567
+
7568
+ bool is_node = false;
7569
+
7570
+ if (a->grad || b->grad) {
7571
+ GGML_ASSERT(false); // TODO: implement backward
7572
+ is_node = true;
7573
+ }
7574
+
7575
+ const int64_t ne[4] = {
7576
+ b->ne[1],
7577
+ a->ne[2],
7578
+ b->ne[2],
7579
+ 1,
7580
+ };
7581
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
7582
+
7583
+ result->op = GGML_OP_CONV_1D_STAGE_1;
7477
7584
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7478
7585
  result->src[0] = a;
7479
7586
  result->src[1] = b;
@@ -7481,6 +7588,53 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
7481
7588
  return result;
7482
7589
  }
7483
7590
 
7591
+ // ggml_conv_1d
7592
+
7593
+ GGML_API struct ggml_tensor * ggml_conv_1d(
7594
+ struct ggml_context * ctx,
7595
+ struct ggml_tensor * a,
7596
+ struct ggml_tensor * b,
7597
+ int s0,
7598
+ int p0,
7599
+ int d0) {
7600
+ struct ggml_tensor * result = ggml_conv_1d_stage_0(ctx, a, b, s0, p0, d0);
7601
+ result = ggml_conv_1d_stage_1(ctx, a, result);
7602
+ return result;
7603
+ }
7604
+
7605
+ // GGML_API struct ggml_tensor * ggml_conv_1d(
7606
+ // struct ggml_context * ctx,
7607
+ // struct ggml_tensor * a,
7608
+ // struct ggml_tensor * b,
7609
+ // int s0,
7610
+ // int p0,
7611
+ // int d0) {
7612
+ // GGML_ASSERT(ggml_is_matrix(b));
7613
+ // GGML_ASSERT(a->ne[1] == b->ne[1]);
7614
+ // bool is_node = false;
7615
+
7616
+ // if (a->grad || b->grad) {
7617
+ // GGML_ASSERT(false); // TODO: implement backward
7618
+ // is_node = true;
7619
+ // }
7620
+
7621
+ // const int64_t ne[4] = {
7622
+ // ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
7623
+ // a->ne[2], 1, 1,
7624
+ // };
7625
+ // struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
7626
+
7627
+ // int32_t params[] = { s0, p0, d0 };
7628
+ // ggml_set_op_params(result, params, sizeof(params));
7629
+
7630
+ // result->op = GGML_OP_CONV_1D;
7631
+ // result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7632
+ // result->src[0] = a;
7633
+ // result->src[1] = b;
7634
+
7635
+ // return result;
7636
+ // }
7637
+
7484
7638
  // ggml_conv_1d_ph
7485
7639
 
7486
7640
  struct ggml_tensor* ggml_conv_1d_ph(
@@ -7492,6 +7646,50 @@ struct ggml_tensor* ggml_conv_1d_ph(
7492
7646
  return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
7493
7647
  }
7494
7648
 
7649
+ // ggml_conv_transpose_1d
7650
+
7651
+ static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
7652
+ return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
7653
+ }
7654
+
7655
+ GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
7656
+ struct ggml_context * ctx,
7657
+ struct ggml_tensor * a,
7658
+ struct ggml_tensor * b,
7659
+ int s0,
7660
+ int p0,
7661
+ int d0) {
7662
+ GGML_ASSERT(ggml_is_matrix(b));
7663
+ GGML_ASSERT(a->ne[2] == b->ne[1]);
7664
+ GGML_ASSERT(a->ne[3] == 1);
7665
+
7666
+ GGML_ASSERT(p0 == 0);
7667
+ GGML_ASSERT(d0 == 1);
7668
+
7669
+ bool is_node = false;
7670
+
7671
+ if (a->grad || b->grad) {
7672
+ GGML_ASSERT(false); // TODO: implement backward
7673
+ is_node = true;
7674
+ }
7675
+
7676
+ const int64_t ne[4] = {
7677
+ ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
7678
+ a->ne[1], b->ne[2], 1,
7679
+ };
7680
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
7681
+
7682
+ int32_t params[] = { s0, p0, d0 };
7683
+ ggml_set_op_params(result, params, sizeof(params));
7684
+
7685
+ result->op = GGML_OP_CONV_TRANSPOSE_1D;
7686
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7687
+ result->src[0] = a;
7688
+ result->src[1] = b;
7689
+
7690
+ return result;
7691
+ }
7692
+
7495
7693
  // ggml_conv_2d
7496
7694
 
7497
7695
  struct ggml_tensor * ggml_conv_2d(
@@ -11621,11 +11819,6 @@ static void ggml_compute_forward_mul_mat(
11621
11819
 
11622
11820
  #if defined(GGML_USE_CLBLAST)
11623
11821
  if (ggml_cl_can_mul_mat(src0, src1, dst)) {
11624
- // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
11625
- // ref: https://github.com/ggerganov/ggml/pull/224
11626
- GGML_ASSERT(ne02 == ne12);
11627
- GGML_ASSERT(ne03 == ne13);
11628
-
11629
11822
  if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
11630
11823
  ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
11631
11824
  }
@@ -12889,7 +13082,7 @@ static void ggml_compute_forward_alibi_f32(
12889
13082
  return;
12890
13083
  }
12891
13084
 
12892
- const int n_past = ((int32_t *) dst->op_params)[0];
13085
+ const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
12893
13086
  const int n_head = ((int32_t *) dst->op_params)[1];
12894
13087
  float max_bias;
12895
13088
  memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
@@ -12910,7 +13103,6 @@ static void ggml_compute_forward_alibi_f32(
12910
13103
  //const int nb3 = src0->nb[3];
12911
13104
 
12912
13105
  GGML_ASSERT(nb0 == sizeof(float));
12913
- GGML_ASSERT(ne1 + n_past == ne0);
12914
13106
  GGML_ASSERT(n_head == ne2);
12915
13107
 
12916
13108
  // add alibi to src0 (KQ_scaled)
@@ -13636,7 +13828,7 @@ static void ggml_compute_forward_rope_back(
13636
13828
 
13637
13829
  // ggml_compute_forward_conv_1d
13638
13830
 
13639
- static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
13831
+ static void ggml_compute_forward_conv_1d_f16_f32(
13640
13832
  const struct ggml_compute_params * params,
13641
13833
  const struct ggml_tensor * src0,
13642
13834
  const struct ggml_tensor * src1,
@@ -13654,42 +13846,33 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
13654
13846
  const int nth = params->nth;
13655
13847
 
13656
13848
  const int nk = ne00;
13657
- const int nh = nk/2;
13658
13849
 
13659
- const int ew0 = ggml_up32(ne01);
13850
+ // size of the convolution row - the kernel size unrolled across all input channels
13851
+ const int ew0 = nk*ne01;
13852
+
13853
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
13854
+ const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
13855
+ const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
13660
13856
 
13661
- GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
13662
13857
  GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
13663
13858
  GGML_ASSERT(nb10 == sizeof(float));
13664
13859
 
13665
13860
  if (params->type == GGML_TASK_INIT) {
13666
- // TODO: fix this memset (wsize is overestimated)
13667
13861
  memset(params->wdata, 0, params->wsize);
13668
13862
 
13669
- // prepare kernel data (src0)
13670
- {
13671
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
13863
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
13672
13864
 
13673
- for (int64_t i02 = 0; i02 < ne02; i02++) {
13674
- for (int64_t i01 = 0; i01 < ne01; i01++) {
13675
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
13676
- ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
13677
- for (int64_t i00 = 0; i00 < ne00; i00++) {
13678
- dst_data[i00*ew0 + i01] = src[i00];
13679
- }
13680
- }
13681
- }
13682
- }
13865
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
13866
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
13867
+ ggml_fp16_t * dst_data = wdata;
13683
13868
 
13684
- // prepare source data (src1)
13685
- {
13686
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
13869
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
13870
+ for (int64_t ik = 0; ik < nk; ik++) {
13871
+ const int idx0 = i0*s0 + ik*d0 - p0;
13687
13872
 
13688
- for (int64_t i11 = 0; i11 < ne11; i11++) {
13689
- const float * const src = (float *)((char *) src1->data + i11*nb11);
13690
- ggml_fp16_t * dst_data = wdata;
13691
- for (int64_t i10 = 0; i10 < ne10; i10++) {
13692
- dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
13873
+ if(!(idx0 < 0 || idx0 >= ne10)) {
13874
+ dst_data[i0*ew0 + i11*nk + ik] = GGML_FP32_TO_FP16(src[idx0]);
13875
+ }
13693
13876
  }
13694
13877
  }
13695
13878
  }
@@ -13702,7 +13885,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
13702
13885
  }
13703
13886
 
13704
13887
  // total rows in dst
13705
- const int nr = ne02;
13888
+ const int nr = ne2;
13706
13889
 
13707
13890
  // rows per thread
13708
13891
  const int dr = (nr + nth - 1)/nth;
@@ -13711,23 +13894,22 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
13711
13894
  const int ir0 = dr*ith;
13712
13895
  const int ir1 = MIN(ir0 + dr, nr);
13713
13896
 
13714
- for (int i1 = ir0; i1 < ir1; i1++) {
13715
- float * dst_data = (float *)((char *) dst->data + i1*nb1);
13716
- for (int64_t i0 = 0; i0 < ne10; ++i0) {
13717
- dst_data[i0] = 0;
13718
- for (int k = -nh; k <= nh; k++) {
13719
- float v = 0.0f;
13720
- ggml_vec_dot_f16(ew0, &v,
13721
- (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
13722
- (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
13723
-
13724
- dst_data[i0] += v;
13897
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
13898
+
13899
+ for (int i2 = 0; i2 < ne2; i2++) {
13900
+ for (int i1 = ir0; i1 < ir1; i1++) {
13901
+ float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
13902
+
13903
+ for (int i0 = 0; i0 < ne0; i0++) {
13904
+ ggml_vec_dot_f16(ew0, dst_data + i0,
13905
+ (ggml_fp16_t *) ((char *) src0->data + i1*nb02),
13906
+ (ggml_fp16_t *) wdata + i2*nb2 + i0*ew0);
13725
13907
  }
13726
13908
  }
13727
13909
  }
13728
13910
  }
13729
13911
 
13730
- static void ggml_compute_forward_conv_1d_s1_ph_f32(
13912
+ static void ggml_compute_forward_conv_1d_f32(
13731
13913
  const struct ggml_compute_params * params,
13732
13914
  const struct ggml_tensor * src0,
13733
13915
  const struct ggml_tensor * src1,
@@ -13745,42 +13927,32 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32(
13745
13927
  const int nth = params->nth;
13746
13928
 
13747
13929
  const int nk = ne00;
13748
- const int nh = nk/2;
13749
13930
 
13750
- const int ew0 = ggml_up32(ne01);
13931
+ const int ew0 = nk*ne01;
13932
+
13933
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
13934
+ const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
13935
+ const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
13751
13936
 
13752
- GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
13753
13937
  GGML_ASSERT(nb00 == sizeof(float));
13754
13938
  GGML_ASSERT(nb10 == sizeof(float));
13755
13939
 
13756
13940
  if (params->type == GGML_TASK_INIT) {
13757
- // TODO: fix this memset (wsize is overestimated)
13758
13941
  memset(params->wdata, 0, params->wsize);
13759
13942
 
13760
- // prepare kernel data (src0)
13761
- {
13762
- float * const wdata = (float *) params->wdata + 0;
13943
+ float * const wdata = (float *) params->wdata + 0;
13763
13944
 
13764
- for (int64_t i02 = 0; i02 < ne02; i02++) {
13765
- for (int64_t i01 = 0; i01 < ne01; i01++) {
13766
- const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
13767
- float * dst_data = wdata + i02*ew0*ne00;
13768
- for (int64_t i00 = 0; i00 < ne00; i00++) {
13769
- dst_data[i00*ew0 + i01] = src[i00];
13770
- }
13771
- }
13772
- }
13773
- }
13945
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
13946
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
13947
+ float * dst_data = wdata;
13774
13948
 
13775
- // prepare source data (src1)
13776
- {
13777
- float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
13949
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
13950
+ for (int64_t ik = 0; ik < nk; ik++) {
13951
+ const int idx0 = i0*s0 + ik*d0 - p0;
13778
13952
 
13779
- for (int64_t i11 = 0; i11 < ne11; i11++) {
13780
- const float * const src = (float *)((char *) src1->data + i11*nb11);
13781
- float * dst_data = wdata;
13782
- for (int64_t i10 = 0; i10 < ne10; i10++) {
13783
- dst_data[(i10 + nh)*ew0 + i11] = src[i10];
13953
+ if(!(idx0 < 0 || idx0 >= ne10)) {
13954
+ dst_data[i0*ew0 + i11*nk + ik] = src[idx0];
13955
+ }
13784
13956
  }
13785
13957
  }
13786
13958
  }
@@ -13802,35 +13974,242 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32(
13802
13974
  const int ir0 = dr*ith;
13803
13975
  const int ir1 = MIN(ir0 + dr, nr);
13804
13976
 
13805
- for (int i1 = ir0; i1 < ir1; i1++) {
13806
- float * dst_data = (float *)((char *) dst->data + i1*nb1);
13807
- for (int64_t i0 = 0; i0 < ne10; ++i0) {
13808
- dst_data[i0] = 0;
13809
- for (int k = -nh; k <= nh; k++) {
13810
- float v = 0.0f;
13811
- ggml_vec_dot_f32(ew0, &v,
13812
- (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
13813
- (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
13814
-
13815
- dst_data[i0] += v;
13977
+ float * const wdata = (float *) params->wdata + 0;
13978
+
13979
+ for (int i2 = 0; i2 < ne2; i2++) {
13980
+ for (int i1 = ir0; i1 < ir1; i1++) {
13981
+ float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
13982
+
13983
+ for (int i0 = 0; i0 < ne0; i0++) {
13984
+ ggml_vec_dot_f32(ew0, dst_data + i0,
13985
+ (float *) ((char *) src0->data + i1*nb02),
13986
+ (float *) wdata + i2*nb2 + i0*ew0);
13987
+ }
13988
+ }
13989
+ }
13990
+ }
13991
+
13992
+ static void gemm_f16_out_f32(int64_t m, int64_t n, int64_t k,
13993
+ ggml_fp16_t * A,
13994
+ ggml_fp16_t * B,
13995
+ float * C,
13996
+ const int ith, const int nth) {
13997
+ // does not seem to make a difference
13998
+ int64_t m0, m1, n0, n1;
13999
+ // patches per thread
14000
+ if (m > n) {
14001
+ n0 = 0;
14002
+ n1 = n;
14003
+
14004
+ // total patches in dst
14005
+ const int np = m;
14006
+
14007
+ // patches per thread
14008
+ const int dp = (np + nth - 1)/nth;
14009
+
14010
+ // patch range for this thread
14011
+ m0 = dp*ith;
14012
+ m1 = MIN(m0 + dp, np);
14013
+ } else {
14014
+ m0 = 0;
14015
+ m1 = m;
14016
+
14017
+ // total patches in dst
14018
+ const int np = n;
14019
+
14020
+ // patches per thread
14021
+ const int dp = (np + nth - 1)/nth;
14022
+
14023
+ // patch range for this thread
14024
+ n0 = dp*ith;
14025
+ n1 = MIN(n0 + dp, np);
14026
+ }
14027
+
14028
+ // block-tiling attempt
14029
+ int64_t blck_n = 16;
14030
+ int64_t blck_m = 16;
14031
+
14032
+ // int64_t CACHE_SIZE = 2 * 1024 * 1024; // 2MB
14033
+ // int64_t blck_size = CACHE_SIZE / (sizeof(float) + 2 * sizeof(ggml_fp16_t) * K);
14034
+ // if (blck_size > 0) {
14035
+ // blck_0 = 4;
14036
+ // blck_1 = blck_size / blck_0;
14037
+ // if (blck_1 < 0) {
14038
+ // blck_1 = 1;
14039
+ // }
14040
+ // // blck_0 = (int64_t)sqrt(blck_size);
14041
+ // // blck_1 = blck_0;
14042
+ // }
14043
+ // // printf("%zd %zd %zd %zd\n", blck_size, K, blck_0, blck_1);
14044
+
14045
+ for (int j = n0; j < n1; j+=blck_n) {
14046
+ for (int i = m0; i < m1; i+=blck_m) {
14047
+ // printf("i j k => %d %d %d\n", i, j, K);
14048
+ for (int ii = i; ii < i + blck_m && ii < m1; ii++) {
14049
+ for (int jj = j; jj < j + blck_n && jj < n1; jj++) {
14050
+ ggml_vec_dot_f16(k,
14051
+ C + ii*n + jj,
14052
+ A + ii * k,
14053
+ B + jj * k);
14054
+ }
13816
14055
  }
13817
14056
  }
13818
14057
  }
13819
14058
  }
13820
14059
 
13821
- static void ggml_compute_forward_conv_1d_s1_ph(
14060
+ // src0: kernel [OC, IC, K]
14061
+ // src1: signal [N, IC, IL]
14062
+ // dst: result [N, OL, IC*K]
14063
+ static void ggml_compute_forward_conv_1d_stage_0_f32(
13822
14064
  const struct ggml_compute_params * params,
13823
14065
  const struct ggml_tensor * src0,
13824
14066
  const struct ggml_tensor * src1,
13825
14067
  struct ggml_tensor * dst) {
13826
- switch (src0->type) {
14068
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
14069
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
14070
+ GGML_ASSERT( dst->type == GGML_TYPE_F16);
14071
+
14072
+ int64_t t0 = ggml_perf_time_us();
14073
+ UNUSED(t0);
14074
+
14075
+ GGML_TENSOR_BINARY_OP_LOCALS;
14076
+
14077
+ const int64_t N = ne12;
14078
+ const int64_t IC = ne11;
14079
+ const int64_t IL = ne10;
14080
+
14081
+ const int64_t K = ne00;
14082
+
14083
+ const int64_t OL = ne1;
14084
+
14085
+ const int ith = params->ith;
14086
+ const int nth = params->nth;
14087
+
14088
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
14089
+ const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
14090
+ const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
14091
+
14092
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
14093
+ GGML_ASSERT(nb10 == sizeof(float));
14094
+
14095
+ if (params->type == GGML_TASK_INIT) {
14096
+ memset(dst->data, 0, ggml_nbytes(dst));
14097
+ return;
14098
+ }
14099
+
14100
+ if (params->type == GGML_TASK_FINALIZE) {
14101
+ return;
14102
+ }
14103
+
14104
+ // im2col: [N, IC, IL] => [N, OL, IC*K]
14105
+ {
14106
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
14107
+
14108
+ for (int64_t in = 0; in < N; in++) {
14109
+ for (int64_t iol = 0; iol < OL; iol++) {
14110
+ for (int64_t iic = ith; iic < IC; iic+=nth) {
14111
+
14112
+ // micro kernel
14113
+ ggml_fp16_t * dst_data = wdata + (in*OL + iol)*(IC*K); // [IC, K]
14114
+ const float * const src_data = (float *)((char *) src1->data + in*nb12 + iic*nb11); // [IL]
14115
+
14116
+ for (int64_t ik = 0; ik < K; ik++) {
14117
+ const int64_t iil = iol*s0 + ik*d0 - p0;
14118
+
14119
+ if (!(iil < 0 || iil >= IL)) {
14120
+ dst_data[iic*K + ik] = GGML_FP32_TO_FP16(src_data[iil]);
14121
+ }
14122
+ }
14123
+ }
14124
+ }
14125
+ }
14126
+ }
14127
+ }
14128
+
14129
+ // gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
14130
+ // src0: [OC, IC, K]
14131
+ // src1: [N, OL, IC * K]
14132
+ // result: [N, OC, OL]
14133
+ static void ggml_compute_forward_conv_1d_stage_1_f16(
14134
+ const struct ggml_compute_params * params,
14135
+ const struct ggml_tensor * src0,
14136
+ const struct ggml_tensor * src1,
14137
+ struct ggml_tensor * dst) {
14138
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
14139
+ GGML_ASSERT(src1->type == GGML_TYPE_F16);
14140
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
14141
+
14142
+ int64_t t0 = ggml_perf_time_us();
14143
+ UNUSED(t0);
14144
+
14145
+ if (params->type == GGML_TASK_INIT) {
14146
+ return;
14147
+ }
14148
+
14149
+ if (params->type == GGML_TASK_FINALIZE) {
14150
+ return;
14151
+ }
14152
+
14153
+ GGML_TENSOR_BINARY_OP_LOCALS;
14154
+
14155
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
14156
+ GGML_ASSERT(nb10 == sizeof(ggml_fp16_t));
14157
+ GGML_ASSERT(nb0 == sizeof(float));
14158
+
14159
+ const int N = ne12;
14160
+ const int OL = ne11;
14161
+
14162
+ const int OC = ne02;
14163
+ const int IC = ne01;
14164
+ const int K = ne00;
14165
+
14166
+ const int ith = params->ith;
14167
+ const int nth = params->nth;
14168
+
14169
+ int64_t m = OC;
14170
+ int64_t n = OL;
14171
+ int64_t k = IC * K;
14172
+
14173
+ // [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
14174
+ for (int i = 0; i < N; i++) {
14175
+ ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
14176
+ ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k]
14177
+ float * C = (float *)dst->data + i * m * n; // [m, n]
14178
+
14179
+ gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
14180
+ }
14181
+ }
14182
+
14183
+ static void ggml_compute_forward_conv_1d(
14184
+ const struct ggml_compute_params * params,
14185
+ const struct ggml_tensor * src0,
14186
+ const struct ggml_tensor * src1,
14187
+ struct ggml_tensor * dst) {
14188
+ switch(src0->type) {
13827
14189
  case GGML_TYPE_F16:
13828
14190
  {
13829
- ggml_compute_forward_conv_1d_s1_ph_f16_f32(params, src0, src1, dst);
14191
+ ggml_compute_forward_conv_1d_f16_f32(params, src0, src1, dst);
13830
14192
  } break;
13831
14193
  case GGML_TYPE_F32:
13832
14194
  {
13833
- ggml_compute_forward_conv_1d_s1_ph_f32(params, src0, src1, dst);
14195
+ ggml_compute_forward_conv_1d_f32(params, src0, src1, dst);
14196
+ } break;
14197
+ default:
14198
+ {
14199
+ GGML_ASSERT(false);
14200
+ } break;
14201
+ }
14202
+ }
14203
+
14204
+ static void ggml_compute_forward_conv_1d_stage_0(
14205
+ const struct ggml_compute_params * params,
14206
+ const struct ggml_tensor * src0,
14207
+ const struct ggml_tensor * src1,
14208
+ struct ggml_tensor * dst) {
14209
+ switch(src0->type) {
14210
+ case GGML_TYPE_F16:
14211
+ {
14212
+ ggml_compute_forward_conv_1d_stage_0_f32(params, src0, src1, dst);
13834
14213
  } break;
13835
14214
  default:
13836
14215
  {
@@ -13839,7 +14218,26 @@ static void ggml_compute_forward_conv_1d_s1_ph(
13839
14218
  }
13840
14219
  }
13841
14220
 
13842
- static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
14221
+ static void ggml_compute_forward_conv_1d_stage_1(
14222
+ const struct ggml_compute_params * params,
14223
+ const struct ggml_tensor * src0,
14224
+ const struct ggml_tensor * src1,
14225
+ struct ggml_tensor * dst) {
14226
+ switch(src0->type) {
14227
+ case GGML_TYPE_F16:
14228
+ {
14229
+ ggml_compute_forward_conv_1d_stage_1_f16(params, src0, src1, dst);
14230
+ } break;
14231
+ default:
14232
+ {
14233
+ GGML_ASSERT(false);
14234
+ } break;
14235
+ }
14236
+ }
14237
+
14238
+ // ggml_compute_forward_conv_transpose_1d
14239
+
14240
+ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
13843
14241
  const struct ggml_compute_params * params,
13844
14242
  const struct ggml_tensor * src0,
13845
14243
  const struct ggml_tensor * src1,
@@ -13856,43 +14254,38 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
13856
14254
  const int ith = params->ith;
13857
14255
  const int nth = params->nth;
13858
14256
 
13859
- const int nk = ne00;
13860
- const int nh = nk/2;
13861
-
13862
- const int ew0 = ggml_up32(ne01);
14257
+ const int nk = ne00*ne01*ne02;
13863
14258
 
13864
- GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
13865
14259
  GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
13866
14260
  GGML_ASSERT(nb10 == sizeof(float));
13867
14261
 
13868
14262
  if (params->type == GGML_TASK_INIT) {
13869
- // TODO: fix this memset (wsize is overestimated)
13870
14263
  memset(params->wdata, 0, params->wsize);
13871
14264
 
13872
- // prepare kernel data (src0)
14265
+ // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
13873
14266
  {
13874
14267
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
13875
14268
 
13876
14269
  for (int64_t i02 = 0; i02 < ne02; i02++) {
13877
14270
  for (int64_t i01 = 0; i01 < ne01; i01++) {
13878
14271
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
13879
- ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
14272
+ ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
13880
14273
  for (int64_t i00 = 0; i00 < ne00; i00++) {
13881
- dst_data[i00*ew0 + i01] = src[i00];
14274
+ dst_data[i00*ne02 + i02] = src[i00];
13882
14275
  }
13883
14276
  }
13884
14277
  }
13885
14278
  }
13886
14279
 
13887
- // prepare source data (src1)
14280
+ // permute source data (src1) from (L x Cin) to (Cin x L)
13888
14281
  {
13889
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
14282
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
14283
+ ggml_fp16_t * dst_data = wdata;
13890
14284
 
13891
14285
  for (int64_t i11 = 0; i11 < ne11; i11++) {
13892
14286
  const float * const src = (float *)((char *) src1->data + i11*nb11);
13893
- ggml_fp16_t * dst_data = wdata;
13894
14287
  for (int64_t i10 = 0; i10 < ne10; i10++) {
13895
- dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
14288
+ dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
13896
14289
  }
13897
14290
  }
13898
14291
  }
@@ -13904,8 +14297,10 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
13904
14297
  return;
13905
14298
  }
13906
14299
 
14300
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
14301
+
13907
14302
  // total rows in dst
13908
- const int nr = ne02;
14303
+ const int nr = ne1;
13909
14304
 
13910
14305
  // rows per thread
13911
14306
  const int dr = (nr + nth - 1)/nth;
@@ -13914,23 +14309,26 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
13914
14309
  const int ir0 = dr*ith;
13915
14310
  const int ir1 = MIN(ir0 + dr, nr);
13916
14311
 
14312
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
14313
+ ggml_fp16_t * const wdata_src = wdata + nk;
14314
+
13917
14315
  for (int i1 = ir0; i1 < ir1; i1++) {
13918
14316
  float * dst_data = (float *)((char *) dst->data + i1*nb1);
13919
- for (int64_t i0 = 0; i0 < ne10; i0 += 2) {
13920
- dst_data[i0/2] = 0;
13921
- for (int k = -nh; k <= nh; k++) {
13922
- float v = 0.0f;
13923
- ggml_vec_dot_f16(ew0, &v,
13924
- (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
13925
- (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
13926
-
13927
- dst_data[i0/2] += v;
14317
+ ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
14318
+ for (int i10 = 0; i10 < ne10; i10++) {
14319
+ const int i1n = i10*ne11;
14320
+ for (int i00 = 0; i00 < ne00; i00++) {
14321
+ float v = 0;
14322
+ ggml_vec_dot_f16(ne02, &v,
14323
+ (ggml_fp16_t *) wdata_src + i1n,
14324
+ (ggml_fp16_t *) wdata_kernel + i00*ne02);
14325
+ dst_data[i10*s0 + i00] += v;
13928
14326
  }
13929
14327
  }
13930
14328
  }
13931
14329
  }
13932
14330
 
13933
- static void ggml_compute_forward_conv_1d_s2_ph_f32(
14331
+ static void ggml_compute_forward_conv_transpose_1d_f32(
13934
14332
  const struct ggml_compute_params * params,
13935
14333
  const struct ggml_tensor * src0,
13936
14334
  const struct ggml_tensor * src1,
@@ -13947,29 +14345,24 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
13947
14345
  const int ith = params->ith;
13948
14346
  const int nth = params->nth;
13949
14347
 
13950
- const int nk = ne00;
13951
- const int nh = nk/2;
13952
-
13953
- const int ew0 = ggml_up32(ne01);
14348
+ const int nk = ne00*ne01*ne02;
13954
14349
 
13955
- GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
13956
14350
  GGML_ASSERT(nb00 == sizeof(float));
13957
14351
  GGML_ASSERT(nb10 == sizeof(float));
13958
14352
 
13959
14353
  if (params->type == GGML_TASK_INIT) {
13960
- // TODO: fix this memset (wsize is overestimated)
13961
14354
  memset(params->wdata, 0, params->wsize);
13962
14355
 
13963
- // prepare kernel data (src0)
14356
+ // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
13964
14357
  {
13965
14358
  float * const wdata = (float *) params->wdata + 0;
13966
14359
 
13967
14360
  for (int64_t i02 = 0; i02 < ne02; i02++) {
13968
14361
  for (int64_t i01 = 0; i01 < ne01; i01++) {
13969
14362
  const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
13970
- float * dst_data = wdata + i02*ew0*ne00;
14363
+ float * dst_data = wdata + i01*ne00*ne02;
13971
14364
  for (int64_t i00 = 0; i00 < ne00; i00++) {
13972
- dst_data[i00*ew0 + i01] = src[i00];
14365
+ dst_data[i01*ne00*ne02 + i00*ne02 + i02] = src[i00];
13973
14366
  }
13974
14367
  }
13975
14368
  }
@@ -13977,13 +14370,13 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
13977
14370
 
13978
14371
  // prepare source data (src1)
13979
14372
  {
13980
- float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
14373
+ float * const wdata = (float *) params->wdata + nk;
14374
+ float * dst_data = wdata;
13981
14375
 
13982
14376
  for (int64_t i11 = 0; i11 < ne11; i11++) {
13983
14377
  const float * const src = (float *)((char *) src1->data + i11*nb11);
13984
- float * dst_data = wdata;
13985
14378
  for (int64_t i10 = 0; i10 < ne10; i10++) {
13986
- dst_data[(i10 + nh)*ew0 + i11] = src[i10];
14379
+ dst_data[i10*ne11 + i11] = src[i10];
13987
14380
  }
13988
14381
  }
13989
14382
  }
@@ -13995,8 +14388,10 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
13995
14388
  return;
13996
14389
  }
13997
14390
 
14391
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
14392
+
13998
14393
  // total rows in dst
13999
- const int nr = ne02;
14394
+ const int nr = ne1;
14000
14395
 
14001
14396
  // rows per thread
14002
14397
  const int dr = (nr + nth - 1)/nth;
@@ -14005,23 +14400,26 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
14005
14400
  const int ir0 = dr*ith;
14006
14401
  const int ir1 = MIN(ir0 + dr, nr);
14007
14402
 
14403
+ float * const wdata = (float *) params->wdata + 0;
14404
+ float * const wdata_src = wdata + nk;
14405
+
14008
14406
  for (int i1 = ir0; i1 < ir1; i1++) {
14009
14407
  float * dst_data = (float *)((char *) dst->data + i1*nb1);
14010
- for (int64_t i0 = 0; i0 < ne10; i0 += 2) {
14011
- dst_data[i0/2] = 0;
14012
- for (int k = -nh; k <= nh; k++) {
14013
- float v = 0.0f;
14014
- ggml_vec_dot_f32(ew0, &v,
14015
- (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
14016
- (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
14017
-
14018
- dst_data[i0/2] += v;
14408
+ float * wdata_kernel = wdata + i1*ne02*ne00;
14409
+ for (int i10 = 0; i10 < ne10; i10++) {
14410
+ const int i1n = i10*ne11;
14411
+ for (int i00 = 0; i00 < ne00; i00++) {
14412
+ float v = 0;
14413
+ ggml_vec_dot_f32(ne02, &v,
14414
+ wdata_src + i1n,
14415
+ wdata_kernel + i00*ne02);
14416
+ dst_data[i10*s0 + i00] += v;
14019
14417
  }
14020
14418
  }
14021
14419
  }
14022
14420
  }
14023
14421
 
14024
- static void ggml_compute_forward_conv_1d_s2_ph(
14422
+ static void ggml_compute_forward_conv_transpose_1d(
14025
14423
  const struct ggml_compute_params * params,
14026
14424
  const struct ggml_tensor * src0,
14027
14425
  const struct ggml_tensor * src1,
@@ -14029,11 +14427,11 @@ static void ggml_compute_forward_conv_1d_s2_ph(
14029
14427
  switch (src0->type) {
14030
14428
  case GGML_TYPE_F16:
14031
14429
  {
14032
- ggml_compute_forward_conv_1d_s2_ph_f16_f32(params, src0, src1, dst);
14430
+ ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst);
14033
14431
  } break;
14034
14432
  case GGML_TYPE_F32:
14035
14433
  {
14036
- ggml_compute_forward_conv_1d_s2_ph_f32(params, src0, src1, dst);
14434
+ ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst);
14037
14435
  } break;
14038
14436
  default:
14039
14437
  {
@@ -14042,27 +14440,6 @@ static void ggml_compute_forward_conv_1d_s2_ph(
14042
14440
  }
14043
14441
  }
14044
14442
 
14045
- // ggml_compute_forward_conv_1d
14046
-
14047
- static void ggml_compute_forward_conv_1d(
14048
- const struct ggml_compute_params * params,
14049
- const struct ggml_tensor * src0,
14050
- const struct ggml_tensor * src1,
14051
- struct ggml_tensor * dst) {
14052
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
14053
- const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
14054
- const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
14055
- GGML_ASSERT(d0 == 1); // dilation not supported
14056
- GGML_ASSERT(p0 == src0->ne[0]/2); // only half padding supported
14057
- if (s0 == 1) {
14058
- ggml_compute_forward_conv_1d_s1_ph(params, src0, src1, dst);
14059
- } else if (s0 == 2) {
14060
- ggml_compute_forward_conv_1d_s2_ph(params, src0, src1, dst);
14061
- } else {
14062
- GGML_ASSERT(false); // only stride 1 and 2 supported
14063
- }
14064
- }
14065
-
14066
14443
  // ggml_compute_forward_conv_2d
14067
14444
 
14068
14445
  static void ggml_compute_forward_conv_2d_f16_f32(
@@ -14105,20 +14482,22 @@ static void ggml_compute_forward_conv_2d_f16_f32(
14105
14482
  {
14106
14483
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
14107
14484
 
14108
- for (int i12 = 0; i12 < ne12; i12++) {
14109
- const float * const src = (float *)((char *) src1->data + i12*nb12);
14110
- ggml_fp16_t * dst_data = wdata;
14111
-
14112
- for (int i1 = 0; i1 < ne1; i1++) {
14113
- for (int i0 = 0; i0 < ne0; i0++) {
14114
- for (int ik1 = 0; ik1 < nk1; ik1++) {
14115
- for (int ik0 = 0; ik0 < nk0; ik0++) {
14116
- const int idx0 = i0*s0 + ik0*d0 - p0;
14117
- const int idx1 = i1*s1 + ik1*d1 - p1;
14118
-
14119
- if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) {
14120
- dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
14121
- GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]);
14485
+ for (int i13 = 0; i13 < ne13; i13++) {
14486
+ for (int i12 = 0; i12 < ne12; i12++) {
14487
+ const float * const src = (float *)((char *) src1->data + i13*nb13 + i12*nb12);
14488
+ ggml_fp16_t * dst_data = wdata + i13*(ne1*ne0*ew0);
14489
+
14490
+ for (int i1 = 0; i1 < ne1; i1++) {
14491
+ for (int i0 = 0; i0 < ne0; i0++) {
14492
+ for (int ik1 = 0; ik1 < nk1; ik1++) {
14493
+ for (int ik0 = 0; ik0 < nk0; ik0++) {
14494
+ const int idx0 = i0*s0 + ik0*d0 - p0;
14495
+ const int idx1 = i1*s1 + ik1*d1 - p1;
14496
+
14497
+ if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) {
14498
+ dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
14499
+ GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]);
14500
+ }
14122
14501
  }
14123
14502
  }
14124
14503
  }
@@ -16401,6 +16780,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
16401
16780
  {
16402
16781
  ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor);
16403
16782
  } break;
16783
+ case GGML_OP_CONV_1D_STAGE_0:
16784
+ {
16785
+ ggml_compute_forward_conv_1d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
16786
+ } break;
16787
+ case GGML_OP_CONV_1D_STAGE_1:
16788
+ {
16789
+ ggml_compute_forward_conv_1d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
16790
+ } break;
16791
+ case GGML_OP_CONV_TRANSPOSE_1D:
16792
+ {
16793
+ ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor);
16794
+ } break;
16404
16795
  case GGML_OP_CONV_2D:
16405
16796
  {
16406
16797
  ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor);
@@ -17326,10 +17717,22 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
17326
17717
  {
17327
17718
  GGML_ASSERT(false); // TODO: not implemented
17328
17719
  } break;
17720
+ case GGML_OP_CONV_1D_STAGE_0:
17721
+ {
17722
+ GGML_ASSERT(false); // TODO: not implemented
17723
+ } break;
17724
+ case GGML_OP_CONV_1D_STAGE_1:
17725
+ {
17726
+ GGML_ASSERT(false); // TODO: not implemented
17727
+ } break;
17329
17728
  case GGML_OP_CONV_2D:
17330
17729
  {
17331
17730
  GGML_ASSERT(false); // TODO: not implemented
17332
17731
  } break;
17732
+ case GGML_OP_CONV_TRANSPOSE_1D:
17733
+ {
17734
+ GGML_ASSERT(false); // TODO: not implemented
17735
+ } break;
17333
17736
  case GGML_OP_CONV_TRANSPOSE_2D:
17334
17737
  {
17335
17738
  GGML_ASSERT(false); // TODO: not implemented
@@ -18171,21 +18574,68 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
18171
18574
  GGML_ASSERT(node->src[1]->ne[2] == 1);
18172
18575
  GGML_ASSERT(node->src[1]->ne[3] == 1);
18173
18576
 
18577
+ const int64_t ne00 = node->src[0]->ne[0];
18578
+ const int64_t ne01 = node->src[0]->ne[1];
18579
+ const int64_t ne02 = node->src[0]->ne[2];
18580
+
18581
+ const int64_t ne10 = node->src[1]->ne[0];
18582
+ const int64_t ne11 = node->src[1]->ne[1];
18583
+
18584
+ const int64_t ne0 = node->ne[0];
18585
+ const int64_t ne1 = node->ne[1];
18586
+ const int64_t nk = ne00;
18587
+ const int64_t ew0 = nk * ne01;
18588
+
18589
+ UNUSED(ne02);
18590
+ UNUSED(ne10);
18591
+ UNUSED(ne11);
18592
+
18174
18593
  size_t cur = 0;
18175
- const int nk = node->src[0]->ne[0];
18176
18594
 
18177
18595
  if (node->src[0]->type == GGML_TYPE_F16 &&
18178
- node->src[1]->type == GGML_TYPE_F32) {
18179
- cur = sizeof(ggml_fp16_t)*(
18180
- nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] +
18181
- ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
18182
- );
18596
+ node->src[1]->type == GGML_TYPE_F32) {
18597
+ cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0);
18598
+ } else if (node->src[0]->type == GGML_TYPE_F32 &&
18599
+ node->src[1]->type == GGML_TYPE_F32) {
18600
+ cur = sizeof(float)*(ne0*ne1*ew0);
18601
+ } else {
18602
+ GGML_ASSERT(false);
18603
+ }
18604
+
18605
+ work_size = MAX(work_size, cur);
18606
+ } break;
18607
+ case GGML_OP_CONV_1D_STAGE_0:
18608
+ {
18609
+ n_tasks = n_threads;
18610
+ } break;
18611
+ case GGML_OP_CONV_1D_STAGE_1:
18612
+ {
18613
+ n_tasks = n_threads;
18614
+ } break;
18615
+ case GGML_OP_CONV_TRANSPOSE_1D:
18616
+ {
18617
+ n_tasks = n_threads;
18618
+
18619
+ GGML_ASSERT(node->src[0]->ne[3] == 1);
18620
+ GGML_ASSERT(node->src[1]->ne[2] == 1);
18621
+ GGML_ASSERT(node->src[1]->ne[3] == 1);
18622
+
18623
+ const int64_t ne00 = node->src[0]->ne[0]; // K
18624
+ const int64_t ne01 = node->src[0]->ne[1]; // Cout
18625
+ const int64_t ne02 = node->src[0]->ne[2]; // Cin
18626
+
18627
+ const int64_t ne10 = node->src[1]->ne[0]; // L
18628
+ const int64_t ne11 = node->src[1]->ne[1]; // Cin
18629
+
18630
+ size_t cur = 0;
18631
+ if (node->src[0]->type == GGML_TYPE_F16 &&
18632
+ node->src[1]->type == GGML_TYPE_F32) {
18633
+ cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
18634
+ cur += sizeof(ggml_fp16_t)*ne10*ne11;
18183
18635
  } else if (node->src[0]->type == GGML_TYPE_F32 &&
18184
- node->src[1]->type == GGML_TYPE_F32) {
18185
- cur = sizeof(float)*(
18186
- nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] +
18187
- ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
18188
- );
18636
+ node->src[1]->type == GGML_TYPE_F32) {
18637
+ cur += sizeof(float)*ne00*ne01*ne02;
18638
+ cur += sizeof(float)*ne10*ne11;
18189
18639
  } else {
18190
18640
  GGML_ASSERT(false);
18191
18641
  }
@@ -19311,7 +19761,7 @@ static enum ggml_opt_result ggml_opt_adam(
19311
19761
  if (callback) {
19312
19762
  callback(callback_data, accum_step, &sched, &cancel);
19313
19763
  if (cancel) {
19314
- break;
19764
+ return GGML_OPT_CANCEL;
19315
19765
  }
19316
19766
  }
19317
19767
  // ggml_graph_reset (gf);
@@ -19320,9 +19770,6 @@ static enum ggml_opt_result ggml_opt_adam(
19320
19770
  ggml_opt_acc_grad(np, ps, g, accum_norm);
19321
19771
  fx += ggml_get_f32_1d(f, 0);
19322
19772
  }
19323
- if (cancel) {
19324
- return GGML_OPT_DID_NOT_CONVERGE;
19325
- }
19326
19773
  fx *= accum_norm;
19327
19774
 
19328
19775
  opt->adam.fx_prev = fx;
@@ -19348,9 +19795,6 @@ static enum ggml_opt_result ggml_opt_adam(
19348
19795
 
19349
19796
  // run the optimizer
19350
19797
  for (int t = 0; t < params.adam.n_iter; ++t) {
19351
- if (cancel) {
19352
- break;
19353
- }
19354
19798
  opt->iter = iter0 + t + 1;
19355
19799
  GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
19356
19800
 
@@ -19408,7 +19852,7 @@ static enum ggml_opt_result ggml_opt_adam(
19408
19852
  if (callback) {
19409
19853
  callback(callback_data, accum_step, &sched, &cancel);
19410
19854
  if (cancel) {
19411
- break;
19855
+ return GGML_OPT_CANCEL;;
19412
19856
  }
19413
19857
  }
19414
19858
  // ggml_graph_reset (gf);
@@ -19417,9 +19861,6 @@ static enum ggml_opt_result ggml_opt_adam(
19417
19861
  ggml_opt_acc_grad(np, ps, g, accum_norm);
19418
19862
  fx += ggml_get_f32_1d(f, 0);
19419
19863
  }
19420
- if (cancel) {
19421
- break;
19422
- }
19423
19864
  fx *= accum_norm;
19424
19865
 
19425
19866
  opt->loss_after = fx;
@@ -19538,7 +19979,7 @@ static enum ggml_opt_result linesearch_backtracking(
19538
19979
  finit = *fx;
19539
19980
  dgtest = params->lbfgs.ftol*dginit;
19540
19981
 
19541
- while (!*cancel) {
19982
+ while (true) {
19542
19983
  ggml_vec_cpy_f32(nx, x, xp);
19543
19984
  ggml_vec_mad_f32(nx, x, d, *step);
19544
19985
 
@@ -19554,7 +19995,7 @@ static enum ggml_opt_result linesearch_backtracking(
19554
19995
  float sched = 0;
19555
19996
  callback(callback_data, accum_step, &sched, cancel);
19556
19997
  if (*cancel) {
19557
- break;
19998
+ return GGML_OPT_CANCEL;
19558
19999
  }
19559
20000
  }
19560
20001
  // ggml_graph_reset (gf);
@@ -19563,9 +20004,6 @@ static enum ggml_opt_result linesearch_backtracking(
19563
20004
  ggml_opt_acc_grad(np, ps, g, accum_norm);
19564
20005
  *fx += ggml_get_f32_1d(f, 0);
19565
20006
  }
19566
- if (*cancel) {
19567
- break;
19568
- }
19569
20007
  *fx *= accum_norm;
19570
20008
 
19571
20009
  }
@@ -19698,7 +20136,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
19698
20136
  float sched = 0;
19699
20137
  callback(callback_data, accum_step, &sched, &cancel);
19700
20138
  if (cancel) {
19701
- break;
20139
+ return GGML_OPT_CANCEL;
19702
20140
  }
19703
20141
  }
19704
20142
  // ggml_graph_reset (gf);
@@ -19707,9 +20145,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
19707
20145
  ggml_opt_acc_grad(np, ps, g, accum_norm);
19708
20146
  fx += ggml_get_f32_1d(f, 0);
19709
20147
  }
19710
- if (cancel) {
19711
- return GGML_OPT_DID_NOT_CONVERGE;
19712
- }
19713
20148
  fx *= accum_norm;
19714
20149
 
19715
20150
  opt->loss_before = fx;
@@ -19769,8 +20204,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
19769
20204
  ggml_vec_cpy_f32(nx, gp, g);
19770
20205
 
19771
20206
  ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data);
19772
- if (!cancel) {
19773
- break;
20207
+ if (cancel) {
20208
+ return GGML_OPT_CANCEL;
19774
20209
  }
19775
20210
 
19776
20211
  if (ls < 0) {