llama_cpp 0.6.0 → 0.7.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/llama_cpp/extconf.rb +1 -1
- data/ext/llama_cpp/llama_cpp.cpp +49 -3
- data/ext/llama_cpp/src/ggml-alloc.c +62 -107
- data/ext/llama_cpp/src/ggml-alloc.h +11 -5
- data/ext/llama_cpp/src/ggml-backend.c +385 -0
- data/ext/llama_cpp/src/ggml-backend.h +143 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +622 -150
- data/ext/llama_cpp/src/ggml-cuda.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.h +18 -1
- data/ext/llama_cpp/src/ggml-metal.m +358 -131
- data/ext/llama_cpp/src/ggml-metal.metal +137 -47
- data/ext/llama_cpp/src/ggml-opencl.cpp +136 -68
- data/ext/llama_cpp/src/ggml.c +812 -365
- data/ext/llama_cpp/src/ggml.h +25 -7
- data/ext/llama_cpp/src/k_quants.c +744 -2
- data/ext/llama_cpp/src/k_quants.h +5 -5
- data/ext/llama_cpp/src/llama.cpp +2387 -421
- data/ext/llama_cpp/src/llama.h +22 -6
- data/ext/llama_cpp/src/unicode.h +462 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +5 -0
- metadata +5 -2
data/ext/llama_cpp/src/ggml.c
CHANGED
@@ -162,40 +162,16 @@ typedef void * thread_ret_t;
|
|
162
162
|
|
163
163
|
#define GGML_PRINT(...) printf(__VA_ARGS__)
|
164
164
|
|
165
|
+
//
|
166
|
+
// end of logging block
|
167
|
+
//
|
168
|
+
|
165
169
|
#ifdef GGML_USE_ACCELERATE
|
166
170
|
// uncomment to use vDSP for soft max computation
|
167
171
|
// note: not sure if it is actually faster
|
168
172
|
//#define GGML_SOFT_MAX_ACCELERATE
|
169
173
|
#endif
|
170
174
|
|
171
|
-
//
|
172
|
-
// logging
|
173
|
-
//
|
174
|
-
|
175
|
-
#if (GGML_DEBUG >= 1)
|
176
|
-
#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
|
177
|
-
#else
|
178
|
-
#define GGML_PRINT_DEBUG(...)
|
179
|
-
#endif
|
180
|
-
|
181
|
-
#if (GGML_DEBUG >= 5)
|
182
|
-
#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
|
183
|
-
#else
|
184
|
-
#define GGML_PRINT_DEBUG_5(...)
|
185
|
-
#endif
|
186
|
-
|
187
|
-
#if (GGML_DEBUG >= 10)
|
188
|
-
#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
|
189
|
-
#else
|
190
|
-
#define GGML_PRINT_DEBUG_10(...)
|
191
|
-
#endif
|
192
|
-
|
193
|
-
#define GGML_PRINT(...) printf(__VA_ARGS__)
|
194
|
-
|
195
|
-
//
|
196
|
-
// end of logging block
|
197
|
-
//
|
198
|
-
|
199
175
|
#if defined(_MSC_VER) || defined(__MINGW32__)
|
200
176
|
#define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
|
201
177
|
#define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr)
|
@@ -1032,8 +1008,8 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r
|
|
1032
1008
|
y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
|
1033
1009
|
|
1034
1010
|
// get the 5-th bit and store it in qh at the right position
|
1035
|
-
qh |= ((xi0 &
|
1036
|
-
qh |= ((xi1 &
|
1011
|
+
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
1012
|
+
qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
|
1037
1013
|
}
|
1038
1014
|
|
1039
1015
|
memcpy(&y[i].qh, &qh, sizeof(qh));
|
@@ -1080,8 +1056,8 @@ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * r
|
|
1080
1056
|
y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
|
1081
1057
|
|
1082
1058
|
// get the 5-th bit and store it in qh at the right position
|
1083
|
-
qh |= ((xi0 &
|
1084
|
-
qh |= ((xi1 &
|
1059
|
+
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
1060
|
+
qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
|
1085
1061
|
}
|
1086
1062
|
|
1087
1063
|
memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
|
@@ -1272,6 +1248,33 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
|
|
1272
1248
|
_mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
|
1273
1249
|
#endif
|
1274
1250
|
}
|
1251
|
+
#elif defined(__riscv_v_intrinsic)
|
1252
|
+
|
1253
|
+
size_t vl = __riscv_vsetvl_e32m4(QK8_0);
|
1254
|
+
|
1255
|
+
for (int i = 0; i < nb; i++) {
|
1256
|
+
// load elements
|
1257
|
+
vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
|
1258
|
+
|
1259
|
+
vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
|
1260
|
+
vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
|
1261
|
+
vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
|
1262
|
+
float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
|
1263
|
+
|
1264
|
+
const float d = amax / ((1 << 7) - 1);
|
1265
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1266
|
+
|
1267
|
+
y[i].d = GGML_FP32_TO_FP16(d);
|
1268
|
+
|
1269
|
+
vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
|
1270
|
+
|
1271
|
+
// convert to integer
|
1272
|
+
vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
|
1273
|
+
vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
|
1274
|
+
|
1275
|
+
// store result
|
1276
|
+
__riscv_vse8_v_i8m1(y[i].qs , vs, vl);
|
1277
|
+
}
|
1275
1278
|
#else
|
1276
1279
|
// scalar
|
1277
1280
|
quantize_row_q8_0_reference(x, y, k);
|
@@ -1490,6 +1493,41 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
|
|
1490
1493
|
_mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
|
1491
1494
|
#endif
|
1492
1495
|
}
|
1496
|
+
#elif defined(__riscv_v_intrinsic)
|
1497
|
+
|
1498
|
+
size_t vl = __riscv_vsetvl_e32m4(QK8_1);
|
1499
|
+
|
1500
|
+
for (int i = 0; i < nb; i++) {
|
1501
|
+
// load elements
|
1502
|
+
vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
|
1503
|
+
|
1504
|
+
vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
|
1505
|
+
vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl);
|
1506
|
+
vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
|
1507
|
+
float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
|
1508
|
+
|
1509
|
+
const float d = amax / ((1 << 7) - 1);
|
1510
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1511
|
+
|
1512
|
+
y[i].d = d;
|
1513
|
+
|
1514
|
+
vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
|
1515
|
+
|
1516
|
+
// convert to integer
|
1517
|
+
vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
|
1518
|
+
vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
|
1519
|
+
|
1520
|
+
// store result
|
1521
|
+
__riscv_vse8_v_i8m1(y[i].qs , vs, vl);
|
1522
|
+
|
1523
|
+
// compute sum for y[i].s
|
1524
|
+
vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
|
1525
|
+
vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
|
1526
|
+
|
1527
|
+
// set y[i].s
|
1528
|
+
int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
|
1529
|
+
y[i].s = sum*d;
|
1530
|
+
}
|
1493
1531
|
#else
|
1494
1532
|
// scalar
|
1495
1533
|
quantize_row_q8_1_reference(x, y, k);
|
@@ -2662,30 +2700,32 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
|
2662
2700
|
size_t vl = __riscv_vsetvl_e8m1(qk/2);
|
2663
2701
|
|
2664
2702
|
for (int i = 0; i < nb; i++) {
|
2665
|
-
|
2703
|
+
// load elements
|
2704
|
+
vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
|
2666
2705
|
|
2667
|
-
|
2668
|
-
|
2706
|
+
vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
|
2707
|
+
vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
|
2669
2708
|
|
2670
|
-
|
2671
|
-
|
2709
|
+
// mask and store lower part of x, and then upper part
|
2710
|
+
vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
|
2711
|
+
vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
|
2672
2712
|
|
2673
|
-
|
2674
|
-
|
2713
|
+
vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
|
2714
|
+
vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
|
2675
2715
|
|
2676
|
-
|
2677
|
-
|
2716
|
+
// subtract offset
|
2717
|
+
vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
|
2718
|
+
vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
|
2678
2719
|
|
2679
|
-
|
2680
|
-
|
2720
|
+
vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
|
2721
|
+
vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
|
2681
2722
|
|
2682
2723
|
vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
|
2683
2724
|
|
2684
|
-
vint32m1_t vs1 =
|
2685
|
-
vint32m1_t vs2 =
|
2725
|
+
vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
|
2726
|
+
vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
|
2686
2727
|
|
2687
|
-
int sumi = __riscv_vmv_x_s_i32m1_i32(
|
2688
|
-
sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
|
2728
|
+
int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
|
2689
2729
|
|
2690
2730
|
sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
|
2691
2731
|
}
|
@@ -2823,27 +2863,28 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
|
|
2823
2863
|
size_t vl = __riscv_vsetvl_e8m1(qk/2);
|
2824
2864
|
|
2825
2865
|
for (int i = 0; i < nb; i++) {
|
2826
|
-
|
2866
|
+
// load elements
|
2867
|
+
vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
|
2827
2868
|
|
2828
|
-
|
2829
|
-
|
2869
|
+
vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
|
2870
|
+
vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
|
2830
2871
|
|
2831
|
-
|
2832
|
-
|
2872
|
+
// mask and store lower part of x, and then upper part
|
2873
|
+
vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
|
2874
|
+
vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
|
2833
2875
|
|
2834
|
-
|
2835
|
-
|
2876
|
+
vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
|
2877
|
+
vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
|
2836
2878
|
|
2837
|
-
|
2838
|
-
|
2879
|
+
vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
|
2880
|
+
vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
|
2839
2881
|
|
2840
2882
|
vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
|
2841
2883
|
|
2842
|
-
vint32m1_t vs1 =
|
2843
|
-
vint32m1_t vs2 =
|
2884
|
+
vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
|
2885
|
+
vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
|
2844
2886
|
|
2845
|
-
int sumi = __riscv_vmv_x_s_i32m1_i32(
|
2846
|
-
sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
|
2887
|
+
int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
|
2847
2888
|
|
2848
2889
|
sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
|
2849
2890
|
}
|
@@ -3088,66 +3129,61 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
|
|
3088
3129
|
|
3089
3130
|
uint32_t qh;
|
3090
3131
|
|
3091
|
-
// These temp values are for masking and shift operations
|
3092
|
-
uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
|
3093
|
-
uint32_t temp_2[16] = {0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80,
|
3094
|
-
0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000};
|
3095
|
-
|
3096
3132
|
size_t vl = __riscv_vsetvl_e8m1(qk/2);
|
3097
3133
|
|
3134
|
+
// These temporary registers are for masking and shift operations
|
3135
|
+
vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
|
3136
|
+
vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
|
3137
|
+
|
3138
|
+
vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
|
3139
|
+
vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
|
3140
|
+
|
3098
3141
|
for (int i = 0; i < nb; i++) {
|
3099
3142
|
memcpy(&qh, x[i].qh, sizeof(uint32_t));
|
3100
3143
|
|
3101
|
-
// temporary registers
|
3102
|
-
vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_2, vl);
|
3103
|
-
vuint32m4_t vt_2 = __riscv_vle32_v_u32m4(temp_1, vl);
|
3104
|
-
vuint32m4_t vt_3 = __riscv_vsll_vx_u32m4(vt_1, 16, vl);
|
3105
|
-
vuint32m4_t vt_4 = __riscv_vadd_vx_u32m4(vt_2, 12, vl);
|
3106
|
-
|
3107
3144
|
// ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
|
3108
|
-
|
3109
|
-
|
3110
|
-
|
3145
|
+
vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
|
3146
|
+
vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
|
3147
|
+
vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
|
3111
3148
|
|
3112
3149
|
// ((qh & (1u << (j + 16))) >> (j + 12));
|
3113
|
-
|
3114
|
-
|
3150
|
+
vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
|
3151
|
+
vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
|
3115
3152
|
|
3116
3153
|
// narrowing
|
3117
|
-
|
3118
|
-
|
3154
|
+
vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
|
3155
|
+
vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
|
3119
3156
|
|
3120
|
-
|
3121
|
-
|
3157
|
+
vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
|
3158
|
+
vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
|
3122
3159
|
|
3123
3160
|
// load
|
3124
|
-
|
3161
|
+
vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
|
3125
3162
|
|
3126
|
-
|
3127
|
-
|
3163
|
+
vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
|
3164
|
+
vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
|
3128
3165
|
|
3129
|
-
|
3130
|
-
|
3166
|
+
vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
|
3167
|
+
vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
|
3131
3168
|
|
3132
|
-
|
3133
|
-
|
3169
|
+
vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
|
3170
|
+
vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
|
3134
3171
|
|
3135
|
-
|
3136
|
-
|
3172
|
+
vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
|
3173
|
+
vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
|
3137
3174
|
|
3138
|
-
|
3139
|
-
|
3175
|
+
vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
|
3176
|
+
vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
|
3140
3177
|
|
3141
|
-
|
3142
|
-
|
3178
|
+
vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
|
3179
|
+
vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
|
3143
3180
|
|
3144
3181
|
vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
|
3145
3182
|
|
3146
|
-
vint32m1_t vs1 =
|
3147
|
-
vint32m1_t vs2 =
|
3183
|
+
vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
|
3184
|
+
vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
|
3148
3185
|
|
3149
|
-
int sumi = __riscv_vmv_x_s_i32m1_i32(
|
3150
|
-
sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
|
3186
|
+
int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
|
3151
3187
|
|
3152
3188
|
sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
|
3153
3189
|
}
|
@@ -3414,62 +3450,58 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
|
|
3414
3450
|
|
3415
3451
|
uint32_t qh;
|
3416
3452
|
|
3417
|
-
// These temp values are for shift operations
|
3418
|
-
uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
|
3419
|
-
|
3420
3453
|
size_t vl = __riscv_vsetvl_e8m1(qk/2);
|
3421
3454
|
|
3455
|
+
// temporary registers for shift operations
|
3456
|
+
vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
|
3457
|
+
vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
|
3458
|
+
|
3422
3459
|
for (int i = 0; i < nb; i++) {
|
3423
3460
|
memcpy(&qh, x[i].qh, sizeof(uint32_t));
|
3424
3461
|
|
3425
|
-
// temporary registers
|
3426
|
-
vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_1, vl);
|
3427
|
-
vuint32m4_t vt_2 = __riscv_vadd_vx_u32m4(vt_1, 12, vl);
|
3428
|
-
|
3429
3462
|
// load qh
|
3430
|
-
|
3463
|
+
vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
|
3431
3464
|
|
3432
3465
|
// ((qh >> (j + 0)) << 4) & 0x10;
|
3433
|
-
|
3434
|
-
|
3435
|
-
|
3466
|
+
vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
|
3467
|
+
vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
|
3468
|
+
vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
|
3436
3469
|
|
3437
3470
|
// ((qh >> (j + 12)) ) & 0x10;
|
3438
|
-
|
3439
|
-
|
3471
|
+
vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
|
3472
|
+
vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
|
3440
3473
|
|
3441
3474
|
// narrowing
|
3442
|
-
|
3443
|
-
|
3475
|
+
vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
|
3476
|
+
vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
|
3444
3477
|
|
3445
|
-
|
3446
|
-
|
3478
|
+
vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
|
3479
|
+
vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
|
3447
3480
|
|
3448
3481
|
// load
|
3449
|
-
|
3482
|
+
vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
|
3450
3483
|
|
3451
|
-
|
3452
|
-
|
3484
|
+
vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
|
3485
|
+
vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
|
3453
3486
|
|
3454
|
-
|
3455
|
-
|
3487
|
+
vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
|
3488
|
+
vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
|
3456
3489
|
|
3457
|
-
|
3458
|
-
|
3490
|
+
vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
|
3491
|
+
vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
|
3459
3492
|
|
3460
|
-
|
3461
|
-
|
3493
|
+
vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
|
3494
|
+
vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
|
3462
3495
|
|
3463
|
-
|
3464
|
-
|
3496
|
+
vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
|
3497
|
+
vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
|
3465
3498
|
|
3466
3499
|
vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
|
3467
3500
|
|
3468
|
-
vint32m1_t vs1 =
|
3469
|
-
vint32m1_t vs2 =
|
3501
|
+
vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
|
3502
|
+
vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
|
3470
3503
|
|
3471
|
-
int sumi = __riscv_vmv_x_s_i32m1_i32(
|
3472
|
-
sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
|
3504
|
+
int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
|
3473
3505
|
|
3474
3506
|
sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
|
3475
3507
|
}
|
@@ -4025,12 +4057,16 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|
4025
4057
|
"ALIBI",
|
4026
4058
|
"CLAMP",
|
4027
4059
|
"CONV_1D",
|
4060
|
+
"CONV_TRANSPOSE_1D",
|
4028
4061
|
"CONV_2D",
|
4029
4062
|
"CONV_TRANSPOSE_2D",
|
4030
4063
|
"POOL_1D",
|
4031
4064
|
"POOL_2D",
|
4032
4065
|
"UPSCALE",
|
4033
4066
|
|
4067
|
+
"CONV_1D_STAGE_0",
|
4068
|
+
"CONV_1D_STAGE_1",
|
4069
|
+
|
4034
4070
|
"FLASH_ATTN",
|
4035
4071
|
"FLASH_FF",
|
4036
4072
|
"FLASH_ATTN_BACK",
|
@@ -4056,7 +4092,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|
4056
4092
|
"CROSS_ENTROPY_LOSS_BACK",
|
4057
4093
|
};
|
4058
4094
|
|
4059
|
-
static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
|
4095
|
+
static_assert(GGML_OP_COUNT == 71, "GGML_OP_COUNT != 71");
|
4060
4096
|
|
4061
4097
|
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
4062
4098
|
"none",
|
@@ -4107,12 +4143,16 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
4107
4143
|
"alibi(x)",
|
4108
4144
|
"clamp(x)",
|
4109
4145
|
"conv_1d(x)",
|
4146
|
+
"conv_transpose_1d(x)",
|
4110
4147
|
"conv_2d(x)",
|
4111
4148
|
"conv_transpose_2d(x)",
|
4112
4149
|
"pool_1d(x)",
|
4113
4150
|
"pool_2d(x)",
|
4114
4151
|
"upscale(x)",
|
4115
4152
|
|
4153
|
+
"conv_1d_stage_0(x)",
|
4154
|
+
"conv_1d_stage_1(x)",
|
4155
|
+
|
4116
4156
|
"flash_attn(x)",
|
4117
4157
|
"flash_ff(x)",
|
4118
4158
|
"flash_attn_back(x)",
|
@@ -4138,7 +4178,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
4138
4178
|
"cross_entropy_loss_back(x,y)",
|
4139
4179
|
};
|
4140
4180
|
|
4141
|
-
static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
|
4181
|
+
static_assert(GGML_OP_COUNT == 71, "GGML_OP_COUNT != 71");
|
4142
4182
|
|
4143
4183
|
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
4144
4184
|
|
@@ -4167,7 +4207,10 @@ static void ggml_setup_op_has_task_pass(void) {
|
|
4167
4207
|
p[GGML_OP_DIAG_MASK_INF ] = true;
|
4168
4208
|
p[GGML_OP_DIAG_MASK_ZERO ] = true;
|
4169
4209
|
p[GGML_OP_CONV_1D ] = true;
|
4210
|
+
p[GGML_OP_CONV_1D_STAGE_0 ] = true;
|
4211
|
+
p[GGML_OP_CONV_1D_STAGE_1 ] = true;
|
4170
4212
|
p[GGML_OP_CONV_2D ] = true;
|
4213
|
+
p[GGML_OP_CONV_TRANSPOSE_1D ] = true;
|
4171
4214
|
p[GGML_OP_CONV_TRANSPOSE_2D ] = true;
|
4172
4215
|
p[GGML_OP_FLASH_ATTN_BACK ] = true;
|
4173
4216
|
p[GGML_OP_CROSS_ENTROPY_LOSS ] = true;
|
@@ -4884,6 +4927,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
|
|
4884
4927
|
*result = (struct ggml_tensor) {
|
4885
4928
|
/*.type =*/ type,
|
4886
4929
|
/*.backend =*/ GGML_BACKEND_CPU,
|
4930
|
+
/*.buffer =*/ NULL,
|
4887
4931
|
/*.n_dims =*/ n_dims,
|
4888
4932
|
/*.ne =*/ { 1, 1, 1, 1 },
|
4889
4933
|
/*.nb =*/ { 0, 0, 0, 0 },
|
@@ -5450,6 +5494,39 @@ struct ggml_tensor * ggml_view_tensor(
|
|
5450
5494
|
return result;
|
5451
5495
|
}
|
5452
5496
|
|
5497
|
+
struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx) {
|
5498
|
+
struct ggml_object * obj = ctx->objects_begin;
|
5499
|
+
|
5500
|
+
char * const mem_buffer = ctx->mem_buffer;
|
5501
|
+
|
5502
|
+
while (obj != NULL) {
|
5503
|
+
if (obj->type == GGML_OBJECT_TENSOR) {
|
5504
|
+
return (struct ggml_tensor *)(mem_buffer + obj->offs);
|
5505
|
+
}
|
5506
|
+
|
5507
|
+
obj = obj->next;
|
5508
|
+
}
|
5509
|
+
|
5510
|
+
return NULL;
|
5511
|
+
}
|
5512
|
+
|
5513
|
+
struct ggml_tensor * ggml_get_next_tensor(struct ggml_context * ctx, struct ggml_tensor * tensor) {
|
5514
|
+
struct ggml_object * obj = (struct ggml_object *) ((char *)tensor - GGML_OBJECT_SIZE);
|
5515
|
+
obj = obj->next;
|
5516
|
+
|
5517
|
+
char * const mem_buffer = ctx->mem_buffer;
|
5518
|
+
|
5519
|
+
while (obj != NULL) {
|
5520
|
+
if (obj->type == GGML_OBJECT_TENSOR) {
|
5521
|
+
return (struct ggml_tensor *)(mem_buffer + obj->offs);
|
5522
|
+
}
|
5523
|
+
|
5524
|
+
obj = obj->next;
|
5525
|
+
}
|
5526
|
+
|
5527
|
+
return NULL;
|
5528
|
+
}
|
5529
|
+
|
5453
5530
|
struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) {
|
5454
5531
|
struct ggml_object * obj = ctx->objects_begin;
|
5455
5532
|
|
@@ -6690,7 +6767,6 @@ struct ggml_tensor * ggml_cont_4d(
|
|
6690
6767
|
return result;
|
6691
6768
|
}
|
6692
6769
|
|
6693
|
-
|
6694
6770
|
// ggml_reshape
|
6695
6771
|
|
6696
6772
|
struct ggml_tensor * ggml_reshape(
|
@@ -7448,14 +7524,17 @@ static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p,
|
|
7448
7524
|
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
|
7449
7525
|
}
|
7450
7526
|
|
7451
|
-
|
7452
|
-
|
7453
|
-
|
7454
|
-
|
7455
|
-
|
7456
|
-
|
7457
|
-
|
7458
|
-
|
7527
|
+
// im2col: [N, IC, IL] => [N, OL, IC*K]
|
7528
|
+
// a: [OC,IC, K]
|
7529
|
+
// b: [N, IC, IL]
|
7530
|
+
// result: [N, OL, IC*K]
|
7531
|
+
static struct ggml_tensor * ggml_conv_1d_stage_0(
|
7532
|
+
struct ggml_context * ctx,
|
7533
|
+
struct ggml_tensor * a,
|
7534
|
+
struct ggml_tensor * b,
|
7535
|
+
int s0,
|
7536
|
+
int p0,
|
7537
|
+
int d0) {
|
7459
7538
|
GGML_ASSERT(a->ne[1] == b->ne[1]);
|
7460
7539
|
bool is_node = false;
|
7461
7540
|
|
@@ -7464,16 +7543,54 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
|
|
7464
7543
|
is_node = true;
|
7465
7544
|
}
|
7466
7545
|
|
7546
|
+
const int64_t OL = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
|
7547
|
+
|
7467
7548
|
const int64_t ne[4] = {
|
7468
|
-
|
7469
|
-
|
7549
|
+
a->ne[1] * a->ne[0],
|
7550
|
+
OL,
|
7551
|
+
b->ne[2],
|
7552
|
+
1,
|
7470
7553
|
};
|
7471
|
-
struct ggml_tensor * result = ggml_new_tensor(ctx,
|
7554
|
+
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
|
7472
7555
|
|
7473
7556
|
int32_t params[] = { s0, p0, d0 };
|
7474
7557
|
ggml_set_op_params(result, params, sizeof(params));
|
7475
7558
|
|
7476
|
-
result->op =
|
7559
|
+
result->op = GGML_OP_CONV_1D_STAGE_0;
|
7560
|
+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
7561
|
+
result->src[0] = a;
|
7562
|
+
result->src[1] = b;
|
7563
|
+
|
7564
|
+
return result;
|
7565
|
+
}
|
7566
|
+
|
7567
|
+
// ggml_conv_1d_stage_1
|
7568
|
+
|
7569
|
+
// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
|
7570
|
+
// a: [OC, IC, K]
|
7571
|
+
// b: [N, OL, IC * K]
|
7572
|
+
// result: [N, OC, OL]
|
7573
|
+
static struct ggml_tensor * ggml_conv_1d_stage_1(
|
7574
|
+
struct ggml_context * ctx,
|
7575
|
+
struct ggml_tensor * a,
|
7576
|
+
struct ggml_tensor * b) {
|
7577
|
+
|
7578
|
+
bool is_node = false;
|
7579
|
+
|
7580
|
+
if (a->grad || b->grad) {
|
7581
|
+
GGML_ASSERT(false); // TODO: implement backward
|
7582
|
+
is_node = true;
|
7583
|
+
}
|
7584
|
+
|
7585
|
+
const int64_t ne[4] = {
|
7586
|
+
b->ne[1],
|
7587
|
+
a->ne[2],
|
7588
|
+
b->ne[2],
|
7589
|
+
1,
|
7590
|
+
};
|
7591
|
+
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
7592
|
+
|
7593
|
+
result->op = GGML_OP_CONV_1D_STAGE_1;
|
7477
7594
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
7478
7595
|
result->src[0] = a;
|
7479
7596
|
result->src[1] = b;
|
@@ -7481,6 +7598,53 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
|
|
7481
7598
|
return result;
|
7482
7599
|
}
|
7483
7600
|
|
7601
|
+
// ggml_conv_1d
|
7602
|
+
|
7603
|
+
GGML_API struct ggml_tensor * ggml_conv_1d(
|
7604
|
+
struct ggml_context * ctx,
|
7605
|
+
struct ggml_tensor * a,
|
7606
|
+
struct ggml_tensor * b,
|
7607
|
+
int s0,
|
7608
|
+
int p0,
|
7609
|
+
int d0) {
|
7610
|
+
struct ggml_tensor * result = ggml_conv_1d_stage_0(ctx, a, b, s0, p0, d0);
|
7611
|
+
result = ggml_conv_1d_stage_1(ctx, a, result);
|
7612
|
+
return result;
|
7613
|
+
}
|
7614
|
+
|
7615
|
+
// GGML_API struct ggml_tensor * ggml_conv_1d(
|
7616
|
+
// struct ggml_context * ctx,
|
7617
|
+
// struct ggml_tensor * a,
|
7618
|
+
// struct ggml_tensor * b,
|
7619
|
+
// int s0,
|
7620
|
+
// int p0,
|
7621
|
+
// int d0) {
|
7622
|
+
// GGML_ASSERT(ggml_is_matrix(b));
|
7623
|
+
// GGML_ASSERT(a->ne[1] == b->ne[1]);
|
7624
|
+
// bool is_node = false;
|
7625
|
+
|
7626
|
+
// if (a->grad || b->grad) {
|
7627
|
+
// GGML_ASSERT(false); // TODO: implement backward
|
7628
|
+
// is_node = true;
|
7629
|
+
// }
|
7630
|
+
|
7631
|
+
// const int64_t ne[4] = {
|
7632
|
+
// ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
|
7633
|
+
// a->ne[2], 1, 1,
|
7634
|
+
// };
|
7635
|
+
// struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
|
7636
|
+
|
7637
|
+
// int32_t params[] = { s0, p0, d0 };
|
7638
|
+
// ggml_set_op_params(result, params, sizeof(params));
|
7639
|
+
|
7640
|
+
// result->op = GGML_OP_CONV_1D;
|
7641
|
+
// result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
7642
|
+
// result->src[0] = a;
|
7643
|
+
// result->src[1] = b;
|
7644
|
+
|
7645
|
+
// return result;
|
7646
|
+
// }
|
7647
|
+
|
7484
7648
|
// ggml_conv_1d_ph
|
7485
7649
|
|
7486
7650
|
struct ggml_tensor* ggml_conv_1d_ph(
|
@@ -7492,6 +7656,50 @@ struct ggml_tensor* ggml_conv_1d_ph(
|
|
7492
7656
|
return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
|
7493
7657
|
}
|
7494
7658
|
|
7659
|
+
// ggml_conv_transpose_1d
|
7660
|
+
|
7661
|
+
static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
|
7662
|
+
return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
|
7663
|
+
}
|
7664
|
+
|
7665
|
+
GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
|
7666
|
+
struct ggml_context * ctx,
|
7667
|
+
struct ggml_tensor * a,
|
7668
|
+
struct ggml_tensor * b,
|
7669
|
+
int s0,
|
7670
|
+
int p0,
|
7671
|
+
int d0) {
|
7672
|
+
GGML_ASSERT(ggml_is_matrix(b));
|
7673
|
+
GGML_ASSERT(a->ne[2] == b->ne[1]);
|
7674
|
+
GGML_ASSERT(a->ne[3] == 1);
|
7675
|
+
|
7676
|
+
GGML_ASSERT(p0 == 0);
|
7677
|
+
GGML_ASSERT(d0 == 1);
|
7678
|
+
|
7679
|
+
bool is_node = false;
|
7680
|
+
|
7681
|
+
if (a->grad || b->grad) {
|
7682
|
+
GGML_ASSERT(false); // TODO: implement backward
|
7683
|
+
is_node = true;
|
7684
|
+
}
|
7685
|
+
|
7686
|
+
const int64_t ne[4] = {
|
7687
|
+
ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
|
7688
|
+
a->ne[1], b->ne[2], 1,
|
7689
|
+
};
|
7690
|
+
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
7691
|
+
|
7692
|
+
int32_t params[] = { s0, p0, d0 };
|
7693
|
+
ggml_set_op_params(result, params, sizeof(params));
|
7694
|
+
|
7695
|
+
result->op = GGML_OP_CONV_TRANSPOSE_1D;
|
7696
|
+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
7697
|
+
result->src[0] = a;
|
7698
|
+
result->src[1] = b;
|
7699
|
+
|
7700
|
+
return result;
|
7701
|
+
}
|
7702
|
+
|
7495
7703
|
// ggml_conv_2d
|
7496
7704
|
|
7497
7705
|
struct ggml_tensor * ggml_conv_2d(
|
@@ -8472,6 +8680,7 @@ void ggml_set_param(
|
|
8472
8680
|
|
8473
8681
|
GGML_ASSERT(tensor->grad == NULL);
|
8474
8682
|
tensor->grad = ggml_dup_tensor(ctx, tensor);
|
8683
|
+
ggml_format_name(tensor->grad, "%s (grad)", tensor->name);
|
8475
8684
|
}
|
8476
8685
|
|
8477
8686
|
// ggml_compute_forward_dup
|
@@ -11058,7 +11267,7 @@ static void ggml_compute_forward_silu_f32(
|
|
11058
11267
|
|
11059
11268
|
#ifndef NDEBUG
|
11060
11269
|
for (int k = 0; k < nc; k++) {
|
11061
|
-
const float x = ((float *) ((char *) dst->data + i1*(
|
11270
|
+
const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
|
11062
11271
|
UNUSED(x);
|
11063
11272
|
assert(!isnan(x));
|
11064
11273
|
assert(!isinf(x));
|
@@ -11621,11 +11830,6 @@ static void ggml_compute_forward_mul_mat(
|
|
11621
11830
|
|
11622
11831
|
#if defined(GGML_USE_CLBLAST)
|
11623
11832
|
if (ggml_cl_can_mul_mat(src0, src1, dst)) {
|
11624
|
-
// TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
|
11625
|
-
// ref: https://github.com/ggerganov/ggml/pull/224
|
11626
|
-
GGML_ASSERT(ne02 == ne12);
|
11627
|
-
GGML_ASSERT(ne03 == ne13);
|
11628
|
-
|
11629
11833
|
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
11630
11834
|
ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
11631
11835
|
}
|
@@ -12889,28 +13093,25 @@ static void ggml_compute_forward_alibi_f32(
|
|
12889
13093
|
return;
|
12890
13094
|
}
|
12891
13095
|
|
12892
|
-
const int n_past = ((int32_t *) dst->op_params)[0];
|
13096
|
+
//const int n_past = ((int32_t *) dst->op_params)[0];
|
12893
13097
|
const int n_head = ((int32_t *) dst->op_params)[1];
|
12894
13098
|
float max_bias;
|
12895
13099
|
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
12896
13100
|
|
12897
|
-
|
13101
|
+
const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
|
13102
|
+
const int64_t ne1 = src0->ne[1]; // seq_len_without_past
|
13103
|
+
const int64_t ne2 = src0->ne[2]; // n_head -> this is k
|
13104
|
+
//const int64_t ne3 = src0->ne[3]; // 1 -> bsz
|
12898
13105
|
|
12899
|
-
const
|
12900
|
-
const
|
12901
|
-
const int ne2 = src0->ne[2]; // n_head -> this is k
|
12902
|
-
//const int ne3 = src0->ne[3]; // 1 -> bsz
|
13106
|
+
const int64_t n = ggml_nrows(src0);
|
13107
|
+
const int64_t ne2_ne3 = n/ne1; // ne2*ne3
|
12903
13108
|
|
12904
|
-
const
|
12905
|
-
const
|
12906
|
-
|
12907
|
-
const int nb0 = src0->nb[0];
|
12908
|
-
const int nb1 = src0->nb[1];
|
12909
|
-
const int nb2 = src0->nb[2];
|
13109
|
+
const size_t nb0 = src0->nb[0];
|
13110
|
+
const size_t nb1 = src0->nb[1];
|
13111
|
+
const size_t nb2 = src0->nb[2];
|
12910
13112
|
//const int nb3 = src0->nb[3];
|
12911
13113
|
|
12912
13114
|
GGML_ASSERT(nb0 == sizeof(float));
|
12913
|
-
GGML_ASSERT(ne1 + n_past == ne0);
|
12914
13115
|
GGML_ASSERT(n_head == ne2);
|
12915
13116
|
|
12916
13117
|
// add alibi to src0 (KQ_scaled)
|
@@ -12919,9 +13120,9 @@ static void ggml_compute_forward_alibi_f32(
|
|
12919
13120
|
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
12920
13121
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
12921
13122
|
|
12922
|
-
for (
|
12923
|
-
for (
|
12924
|
-
for (
|
13123
|
+
for (int64_t i = 0; i < ne0; i++) {
|
13124
|
+
for (int64_t j = 0; j < ne1; j++) {
|
13125
|
+
for (int64_t k = 0; k < ne2_ne3; k++) {
|
12925
13126
|
float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
12926
13127
|
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
12927
13128
|
|
@@ -12936,7 +13137,6 @@ static void ggml_compute_forward_alibi_f32(
|
|
12936
13137
|
}
|
12937
13138
|
|
12938
13139
|
pdst[0] = i * m_k + src[0];
|
12939
|
-
|
12940
13140
|
}
|
12941
13141
|
}
|
12942
13142
|
}
|
@@ -13636,7 +13836,7 @@ static void ggml_compute_forward_rope_back(
|
|
13636
13836
|
|
13637
13837
|
// ggml_compute_forward_conv_1d
|
13638
13838
|
|
13639
|
-
static void
|
13839
|
+
static void ggml_compute_forward_conv_1d_f16_f32(
|
13640
13840
|
const struct ggml_compute_params * params,
|
13641
13841
|
const struct ggml_tensor * src0,
|
13642
13842
|
const struct ggml_tensor * src1,
|
@@ -13654,42 +13854,33 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
|
|
13654
13854
|
const int nth = params->nth;
|
13655
13855
|
|
13656
13856
|
const int nk = ne00;
|
13657
|
-
const int nh = nk/2;
|
13658
13857
|
|
13659
|
-
|
13858
|
+
// size of the convolution row - the kernel size unrolled across all input channels
|
13859
|
+
const int ew0 = nk*ne01;
|
13860
|
+
|
13861
|
+
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
13862
|
+
const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
|
13863
|
+
const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
|
13660
13864
|
|
13661
|
-
GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
|
13662
13865
|
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
13663
13866
|
GGML_ASSERT(nb10 == sizeof(float));
|
13664
13867
|
|
13665
13868
|
if (params->type == GGML_TASK_INIT) {
|
13666
|
-
// TODO: fix this memset (wsize is overestimated)
|
13667
13869
|
memset(params->wdata, 0, params->wsize);
|
13668
13870
|
|
13669
|
-
|
13670
|
-
{
|
13671
|
-
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
|
13871
|
+
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
|
13672
13872
|
|
13673
|
-
|
13674
|
-
|
13675
|
-
|
13676
|
-
ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
|
13677
|
-
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
13678
|
-
dst_data[i00*ew0 + i01] = src[i00];
|
13679
|
-
}
|
13680
|
-
}
|
13681
|
-
}
|
13682
|
-
}
|
13873
|
+
for (int64_t i11 = 0; i11 < ne11; i11++) {
|
13874
|
+
const float * const src = (float *)((char *) src1->data + i11*nb11);
|
13875
|
+
ggml_fp16_t * dst_data = wdata;
|
13683
13876
|
|
13684
|
-
|
13685
|
-
|
13686
|
-
|
13877
|
+
for (int64_t i0 = 0; i0 < ne0; i0++) {
|
13878
|
+
for (int64_t ik = 0; ik < nk; ik++) {
|
13879
|
+
const int idx0 = i0*s0 + ik*d0 - p0;
|
13687
13880
|
|
13688
|
-
|
13689
|
-
|
13690
|
-
|
13691
|
-
for (int64_t i10 = 0; i10 < ne10; i10++) {
|
13692
|
-
dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
|
13881
|
+
if(!(idx0 < 0 || idx0 >= ne10)) {
|
13882
|
+
dst_data[i0*ew0 + i11*nk + ik] = GGML_FP32_TO_FP16(src[idx0]);
|
13883
|
+
}
|
13693
13884
|
}
|
13694
13885
|
}
|
13695
13886
|
}
|
@@ -13702,7 +13893,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
|
|
13702
13893
|
}
|
13703
13894
|
|
13704
13895
|
// total rows in dst
|
13705
|
-
const int nr =
|
13896
|
+
const int nr = ne2;
|
13706
13897
|
|
13707
13898
|
// rows per thread
|
13708
13899
|
const int dr = (nr + nth - 1)/nth;
|
@@ -13711,23 +13902,22 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
|
|
13711
13902
|
const int ir0 = dr*ith;
|
13712
13903
|
const int ir1 = MIN(ir0 + dr, nr);
|
13713
13904
|
|
13714
|
-
|
13715
|
-
|
13716
|
-
|
13717
|
-
|
13718
|
-
|
13719
|
-
|
13720
|
-
|
13721
|
-
|
13722
|
-
(ggml_fp16_t *)
|
13723
|
-
|
13724
|
-
dst_data[i0] += v;
|
13905
|
+
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
|
13906
|
+
|
13907
|
+
for (int i2 = 0; i2 < ne2; i2++) {
|
13908
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
13909
|
+
float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
|
13910
|
+
|
13911
|
+
for (int i0 = 0; i0 < ne0; i0++) {
|
13912
|
+
ggml_vec_dot_f16(ew0, dst_data + i0,
|
13913
|
+
(ggml_fp16_t *) ((char *) src0->data + i1*nb02),
|
13914
|
+
(ggml_fp16_t *) wdata + i2*nb2 + i0*ew0);
|
13725
13915
|
}
|
13726
13916
|
}
|
13727
13917
|
}
|
13728
13918
|
}
|
13729
13919
|
|
13730
|
-
static void
|
13920
|
+
static void ggml_compute_forward_conv_1d_f32(
|
13731
13921
|
const struct ggml_compute_params * params,
|
13732
13922
|
const struct ggml_tensor * src0,
|
13733
13923
|
const struct ggml_tensor * src1,
|
@@ -13745,42 +13935,32 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32(
|
|
13745
13935
|
const int nth = params->nth;
|
13746
13936
|
|
13747
13937
|
const int nk = ne00;
|
13748
|
-
const int nh = nk/2;
|
13749
13938
|
|
13750
|
-
const int ew0 =
|
13939
|
+
const int ew0 = nk*ne01;
|
13940
|
+
|
13941
|
+
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
13942
|
+
const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
|
13943
|
+
const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
|
13751
13944
|
|
13752
|
-
GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
|
13753
13945
|
GGML_ASSERT(nb00 == sizeof(float));
|
13754
13946
|
GGML_ASSERT(nb10 == sizeof(float));
|
13755
13947
|
|
13756
13948
|
if (params->type == GGML_TASK_INIT) {
|
13757
|
-
// TODO: fix this memset (wsize is overestimated)
|
13758
13949
|
memset(params->wdata, 0, params->wsize);
|
13759
13950
|
|
13760
|
-
|
13761
|
-
{
|
13762
|
-
float * const wdata = (float *) params->wdata + 0;
|
13951
|
+
float * const wdata = (float *) params->wdata + 0;
|
13763
13952
|
|
13764
|
-
|
13765
|
-
|
13766
|
-
|
13767
|
-
float * dst_data = wdata + i02*ew0*ne00;
|
13768
|
-
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
13769
|
-
dst_data[i00*ew0 + i01] = src[i00];
|
13770
|
-
}
|
13771
|
-
}
|
13772
|
-
}
|
13773
|
-
}
|
13953
|
+
for (int64_t i11 = 0; i11 < ne11; i11++) {
|
13954
|
+
const float * const src = (float *)((char *) src1->data + i11*nb11);
|
13955
|
+
float * dst_data = wdata;
|
13774
13956
|
|
13775
|
-
|
13776
|
-
|
13777
|
-
|
13957
|
+
for (int64_t i0 = 0; i0 < ne0; i0++) {
|
13958
|
+
for (int64_t ik = 0; ik < nk; ik++) {
|
13959
|
+
const int idx0 = i0*s0 + ik*d0 - p0;
|
13778
13960
|
|
13779
|
-
|
13780
|
-
|
13781
|
-
|
13782
|
-
for (int64_t i10 = 0; i10 < ne10; i10++) {
|
13783
|
-
dst_data[(i10 + nh)*ew0 + i11] = src[i10];
|
13961
|
+
if(!(idx0 < 0 || idx0 >= ne10)) {
|
13962
|
+
dst_data[i0*ew0 + i11*nk + ik] = src[idx0];
|
13963
|
+
}
|
13784
13964
|
}
|
13785
13965
|
}
|
13786
13966
|
}
|
@@ -13802,35 +13982,242 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32(
|
|
13802
13982
|
const int ir0 = dr*ith;
|
13803
13983
|
const int ir1 = MIN(ir0 + dr, nr);
|
13804
13984
|
|
13805
|
-
|
13806
|
-
|
13807
|
-
|
13808
|
-
|
13809
|
-
|
13810
|
-
|
13811
|
-
|
13812
|
-
|
13813
|
-
(float *)
|
13814
|
-
|
13815
|
-
|
13985
|
+
float * const wdata = (float *) params->wdata + 0;
|
13986
|
+
|
13987
|
+
for (int i2 = 0; i2 < ne2; i2++) {
|
13988
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
13989
|
+
float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
|
13990
|
+
|
13991
|
+
for (int i0 = 0; i0 < ne0; i0++) {
|
13992
|
+
ggml_vec_dot_f32(ew0, dst_data + i0,
|
13993
|
+
(float *) ((char *) src0->data + i1*nb02),
|
13994
|
+
(float *) wdata + i2*nb2 + i0*ew0);
|
13995
|
+
}
|
13996
|
+
}
|
13997
|
+
}
|
13998
|
+
}
|
13999
|
+
|
14000
|
+
static void gemm_f16_out_f32(int64_t m, int64_t n, int64_t k,
|
14001
|
+
ggml_fp16_t * A,
|
14002
|
+
ggml_fp16_t * B,
|
14003
|
+
float * C,
|
14004
|
+
const int ith, const int nth) {
|
14005
|
+
// does not seem to make a difference
|
14006
|
+
int64_t m0, m1, n0, n1;
|
14007
|
+
// patches per thread
|
14008
|
+
if (m > n) {
|
14009
|
+
n0 = 0;
|
14010
|
+
n1 = n;
|
14011
|
+
|
14012
|
+
// total patches in dst
|
14013
|
+
const int np = m;
|
14014
|
+
|
14015
|
+
// patches per thread
|
14016
|
+
const int dp = (np + nth - 1)/nth;
|
14017
|
+
|
14018
|
+
// patch range for this thread
|
14019
|
+
m0 = dp*ith;
|
14020
|
+
m1 = MIN(m0 + dp, np);
|
14021
|
+
} else {
|
14022
|
+
m0 = 0;
|
14023
|
+
m1 = m;
|
14024
|
+
|
14025
|
+
// total patches in dst
|
14026
|
+
const int np = n;
|
14027
|
+
|
14028
|
+
// patches per thread
|
14029
|
+
const int dp = (np + nth - 1)/nth;
|
14030
|
+
|
14031
|
+
// patch range for this thread
|
14032
|
+
n0 = dp*ith;
|
14033
|
+
n1 = MIN(n0 + dp, np);
|
14034
|
+
}
|
14035
|
+
|
14036
|
+
// block-tiling attempt
|
14037
|
+
int64_t blck_n = 16;
|
14038
|
+
int64_t blck_m = 16;
|
14039
|
+
|
14040
|
+
// int64_t CACHE_SIZE = 2 * 1024 * 1024; // 2MB
|
14041
|
+
// int64_t blck_size = CACHE_SIZE / (sizeof(float) + 2 * sizeof(ggml_fp16_t) * K);
|
14042
|
+
// if (blck_size > 0) {
|
14043
|
+
// blck_0 = 4;
|
14044
|
+
// blck_1 = blck_size / blck_0;
|
14045
|
+
// if (blck_1 < 0) {
|
14046
|
+
// blck_1 = 1;
|
14047
|
+
// }
|
14048
|
+
// // blck_0 = (int64_t)sqrt(blck_size);
|
14049
|
+
// // blck_1 = blck_0;
|
14050
|
+
// }
|
14051
|
+
// // printf("%zd %zd %zd %zd\n", blck_size, K, blck_0, blck_1);
|
14052
|
+
|
14053
|
+
for (int j = n0; j < n1; j+=blck_n) {
|
14054
|
+
for (int i = m0; i < m1; i+=blck_m) {
|
14055
|
+
// printf("i j k => %d %d %d\n", i, j, K);
|
14056
|
+
for (int ii = i; ii < i + blck_m && ii < m1; ii++) {
|
14057
|
+
for (int jj = j; jj < j + blck_n && jj < n1; jj++) {
|
14058
|
+
ggml_vec_dot_f16(k,
|
14059
|
+
C + ii*n + jj,
|
14060
|
+
A + ii * k,
|
14061
|
+
B + jj * k);
|
14062
|
+
}
|
13816
14063
|
}
|
13817
14064
|
}
|
13818
14065
|
}
|
13819
14066
|
}
|
13820
14067
|
|
13821
|
-
|
14068
|
+
// src0: kernel [OC, IC, K]
|
14069
|
+
// src1: signal [N, IC, IL]
|
14070
|
+
// dst: result [N, OL, IC*K]
|
14071
|
+
static void ggml_compute_forward_conv_1d_stage_0_f32(
|
13822
14072
|
const struct ggml_compute_params * params,
|
13823
14073
|
const struct ggml_tensor * src0,
|
13824
14074
|
const struct ggml_tensor * src1,
|
13825
14075
|
struct ggml_tensor * dst) {
|
13826
|
-
|
14076
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
14077
|
+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
14078
|
+
GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
14079
|
+
|
14080
|
+
int64_t t0 = ggml_perf_time_us();
|
14081
|
+
UNUSED(t0);
|
14082
|
+
|
14083
|
+
GGML_TENSOR_BINARY_OP_LOCALS;
|
14084
|
+
|
14085
|
+
const int64_t N = ne12;
|
14086
|
+
const int64_t IC = ne11;
|
14087
|
+
const int64_t IL = ne10;
|
14088
|
+
|
14089
|
+
const int64_t K = ne00;
|
14090
|
+
|
14091
|
+
const int64_t OL = ne1;
|
14092
|
+
|
14093
|
+
const int ith = params->ith;
|
14094
|
+
const int nth = params->nth;
|
14095
|
+
|
14096
|
+
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
14097
|
+
const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
|
14098
|
+
const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
|
14099
|
+
|
14100
|
+
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
14101
|
+
GGML_ASSERT(nb10 == sizeof(float));
|
14102
|
+
|
14103
|
+
if (params->type == GGML_TASK_INIT) {
|
14104
|
+
memset(dst->data, 0, ggml_nbytes(dst));
|
14105
|
+
return;
|
14106
|
+
}
|
14107
|
+
|
14108
|
+
if (params->type == GGML_TASK_FINALIZE) {
|
14109
|
+
return;
|
14110
|
+
}
|
14111
|
+
|
14112
|
+
// im2col: [N, IC, IL] => [N, OL, IC*K]
|
14113
|
+
{
|
14114
|
+
ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
|
14115
|
+
|
14116
|
+
for (int64_t in = 0; in < N; in++) {
|
14117
|
+
for (int64_t iol = 0; iol < OL; iol++) {
|
14118
|
+
for (int64_t iic = ith; iic < IC; iic+=nth) {
|
14119
|
+
|
14120
|
+
// micro kernel
|
14121
|
+
ggml_fp16_t * dst_data = wdata + (in*OL + iol)*(IC*K); // [IC, K]
|
14122
|
+
const float * const src_data = (float *)((char *) src1->data + in*nb12 + iic*nb11); // [IL]
|
14123
|
+
|
14124
|
+
for (int64_t ik = 0; ik < K; ik++) {
|
14125
|
+
const int64_t iil = iol*s0 + ik*d0 - p0;
|
14126
|
+
|
14127
|
+
if (!(iil < 0 || iil >= IL)) {
|
14128
|
+
dst_data[iic*K + ik] = GGML_FP32_TO_FP16(src_data[iil]);
|
14129
|
+
}
|
14130
|
+
}
|
14131
|
+
}
|
14132
|
+
}
|
14133
|
+
}
|
14134
|
+
}
|
14135
|
+
}
|
14136
|
+
|
14137
|
+
// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
|
14138
|
+
// src0: [OC, IC, K]
|
14139
|
+
// src1: [N, OL, IC * K]
|
14140
|
+
// result: [N, OC, OL]
|
14141
|
+
static void ggml_compute_forward_conv_1d_stage_1_f16(
|
14142
|
+
const struct ggml_compute_params * params,
|
14143
|
+
const struct ggml_tensor * src0,
|
14144
|
+
const struct ggml_tensor * src1,
|
14145
|
+
struct ggml_tensor * dst) {
|
14146
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
14147
|
+
GGML_ASSERT(src1->type == GGML_TYPE_F16);
|
14148
|
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
14149
|
+
|
14150
|
+
int64_t t0 = ggml_perf_time_us();
|
14151
|
+
UNUSED(t0);
|
14152
|
+
|
14153
|
+
if (params->type == GGML_TASK_INIT) {
|
14154
|
+
return;
|
14155
|
+
}
|
14156
|
+
|
14157
|
+
if (params->type == GGML_TASK_FINALIZE) {
|
14158
|
+
return;
|
14159
|
+
}
|
14160
|
+
|
14161
|
+
GGML_TENSOR_BINARY_OP_LOCALS;
|
14162
|
+
|
14163
|
+
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
14164
|
+
GGML_ASSERT(nb10 == sizeof(ggml_fp16_t));
|
14165
|
+
GGML_ASSERT(nb0 == sizeof(float));
|
14166
|
+
|
14167
|
+
const int N = ne12;
|
14168
|
+
const int OL = ne11;
|
14169
|
+
|
14170
|
+
const int OC = ne02;
|
14171
|
+
const int IC = ne01;
|
14172
|
+
const int K = ne00;
|
14173
|
+
|
14174
|
+
const int ith = params->ith;
|
14175
|
+
const int nth = params->nth;
|
14176
|
+
|
14177
|
+
int64_t m = OC;
|
14178
|
+
int64_t n = OL;
|
14179
|
+
int64_t k = IC * K;
|
14180
|
+
|
14181
|
+
// [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
|
14182
|
+
for (int i = 0; i < N; i++) {
|
14183
|
+
ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
|
14184
|
+
ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k]
|
14185
|
+
float * C = (float *)dst->data + i * m * n; // [m, n]
|
14186
|
+
|
14187
|
+
gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
|
14188
|
+
}
|
14189
|
+
}
|
14190
|
+
|
14191
|
+
static void ggml_compute_forward_conv_1d(
|
14192
|
+
const struct ggml_compute_params * params,
|
14193
|
+
const struct ggml_tensor * src0,
|
14194
|
+
const struct ggml_tensor * src1,
|
14195
|
+
struct ggml_tensor * dst) {
|
14196
|
+
switch(src0->type) {
|
13827
14197
|
case GGML_TYPE_F16:
|
13828
14198
|
{
|
13829
|
-
|
14199
|
+
ggml_compute_forward_conv_1d_f16_f32(params, src0, src1, dst);
|
13830
14200
|
} break;
|
13831
14201
|
case GGML_TYPE_F32:
|
13832
14202
|
{
|
13833
|
-
|
14203
|
+
ggml_compute_forward_conv_1d_f32(params, src0, src1, dst);
|
14204
|
+
} break;
|
14205
|
+
default:
|
14206
|
+
{
|
14207
|
+
GGML_ASSERT(false);
|
14208
|
+
} break;
|
14209
|
+
}
|
14210
|
+
}
|
14211
|
+
|
14212
|
+
static void ggml_compute_forward_conv_1d_stage_0(
|
14213
|
+
const struct ggml_compute_params * params,
|
14214
|
+
const struct ggml_tensor * src0,
|
14215
|
+
const struct ggml_tensor * src1,
|
14216
|
+
struct ggml_tensor * dst) {
|
14217
|
+
switch(src0->type) {
|
14218
|
+
case GGML_TYPE_F16:
|
14219
|
+
{
|
14220
|
+
ggml_compute_forward_conv_1d_stage_0_f32(params, src0, src1, dst);
|
13834
14221
|
} break;
|
13835
14222
|
default:
|
13836
14223
|
{
|
@@ -13839,7 +14226,26 @@ static void ggml_compute_forward_conv_1d_s1_ph(
|
|
13839
14226
|
}
|
13840
14227
|
}
|
13841
14228
|
|
13842
|
-
static void
|
14229
|
+
static void ggml_compute_forward_conv_1d_stage_1(
|
14230
|
+
const struct ggml_compute_params * params,
|
14231
|
+
const struct ggml_tensor * src0,
|
14232
|
+
const struct ggml_tensor * src1,
|
14233
|
+
struct ggml_tensor * dst) {
|
14234
|
+
switch(src0->type) {
|
14235
|
+
case GGML_TYPE_F16:
|
14236
|
+
{
|
14237
|
+
ggml_compute_forward_conv_1d_stage_1_f16(params, src0, src1, dst);
|
14238
|
+
} break;
|
14239
|
+
default:
|
14240
|
+
{
|
14241
|
+
GGML_ASSERT(false);
|
14242
|
+
} break;
|
14243
|
+
}
|
14244
|
+
}
|
14245
|
+
|
14246
|
+
// ggml_compute_forward_conv_transpose_1d
|
14247
|
+
|
14248
|
+
static void ggml_compute_forward_conv_transpose_1d_f16_f32(
|
13843
14249
|
const struct ggml_compute_params * params,
|
13844
14250
|
const struct ggml_tensor * src0,
|
13845
14251
|
const struct ggml_tensor * src1,
|
@@ -13856,43 +14262,38 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
|
|
13856
14262
|
const int ith = params->ith;
|
13857
14263
|
const int nth = params->nth;
|
13858
14264
|
|
13859
|
-
const int nk = ne00;
|
13860
|
-
const int nh = nk/2;
|
13861
|
-
|
13862
|
-
const int ew0 = ggml_up32(ne01);
|
14265
|
+
const int nk = ne00*ne01*ne02;
|
13863
14266
|
|
13864
|
-
GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
|
13865
14267
|
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
13866
14268
|
GGML_ASSERT(nb10 == sizeof(float));
|
13867
14269
|
|
13868
14270
|
if (params->type == GGML_TASK_INIT) {
|
13869
|
-
// TODO: fix this memset (wsize is overestimated)
|
13870
14271
|
memset(params->wdata, 0, params->wsize);
|
13871
14272
|
|
13872
|
-
//
|
14273
|
+
// permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
|
13873
14274
|
{
|
13874
14275
|
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
|
13875
14276
|
|
13876
14277
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
13877
14278
|
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
13878
14279
|
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
|
13879
|
-
ggml_fp16_t * dst_data = wdata +
|
14280
|
+
ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
|
13880
14281
|
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
13881
|
-
dst_data[i00*
|
14282
|
+
dst_data[i00*ne02 + i02] = src[i00];
|
13882
14283
|
}
|
13883
14284
|
}
|
13884
14285
|
}
|
13885
14286
|
}
|
13886
14287
|
|
13887
|
-
//
|
14288
|
+
// permute source data (src1) from (L x Cin) to (Cin x L)
|
13888
14289
|
{
|
13889
|
-
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata +
|
14290
|
+
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
|
14291
|
+
ggml_fp16_t * dst_data = wdata;
|
13890
14292
|
|
13891
14293
|
for (int64_t i11 = 0; i11 < ne11; i11++) {
|
13892
14294
|
const float * const src = (float *)((char *) src1->data + i11*nb11);
|
13893
|
-
ggml_fp16_t * dst_data = wdata;
|
13894
14295
|
for (int64_t i10 = 0; i10 < ne10; i10++) {
|
13895
|
-
dst_data[
|
14296
|
+
dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
|
13896
14297
|
}
|
13897
14298
|
}
|
13898
14299
|
}
|
@@ -13904,8 +14305,10 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
|
|
13904
14305
|
return;
|
13905
14306
|
}
|
13906
14307
|
|
14308
|
+
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
14309
|
+
|
13907
14310
|
// total rows in dst
|
13908
|
-
const int nr =
|
14311
|
+
const int nr = ne1;
|
13909
14312
|
|
13910
14313
|
// rows per thread
|
13911
14314
|
const int dr = (nr + nth - 1)/nth;
|
@@ -13914,23 +14317,26 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
|
|
13914
14317
|
const int ir0 = dr*ith;
|
13915
14318
|
const int ir1 = MIN(ir0 + dr, nr);
|
13916
14319
|
|
14320
|
+
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
|
14321
|
+
ggml_fp16_t * const wdata_src = wdata + nk;
|
14322
|
+
|
13917
14323
|
for (int i1 = ir0; i1 < ir1; i1++) {
|
13918
14324
|
float * dst_data = (float *)((char *) dst->data + i1*nb1);
|
13919
|
-
|
13920
|
-
|
13921
|
-
|
13922
|
-
|
13923
|
-
|
13924
|
-
|
13925
|
-
(ggml_fp16_t *)
|
13926
|
-
|
13927
|
-
dst_data[
|
14325
|
+
ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
|
14326
|
+
for (int i10 = 0; i10 < ne10; i10++) {
|
14327
|
+
const int i1n = i10*ne11;
|
14328
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
14329
|
+
float v = 0;
|
14330
|
+
ggml_vec_dot_f16(ne02, &v,
|
14331
|
+
(ggml_fp16_t *) wdata_src + i1n,
|
14332
|
+
(ggml_fp16_t *) wdata_kernel + i00*ne02);
|
14333
|
+
dst_data[i10*s0 + i00] += v;
|
13928
14334
|
}
|
13929
14335
|
}
|
13930
14336
|
}
|
13931
14337
|
}
|
13932
14338
|
|
13933
|
-
static void
|
14339
|
+
static void ggml_compute_forward_conv_transpose_1d_f32(
|
13934
14340
|
const struct ggml_compute_params * params,
|
13935
14341
|
const struct ggml_tensor * src0,
|
13936
14342
|
const struct ggml_tensor * src1,
|
@@ -13947,29 +14353,24 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
|
|
13947
14353
|
const int ith = params->ith;
|
13948
14354
|
const int nth = params->nth;
|
13949
14355
|
|
13950
|
-
const int nk = ne00;
|
13951
|
-
const int nh = nk/2;
|
13952
|
-
|
13953
|
-
const int ew0 = ggml_up32(ne01);
|
14356
|
+
const int nk = ne00*ne01*ne02;
|
13954
14357
|
|
13955
|
-
GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
|
13956
14358
|
GGML_ASSERT(nb00 == sizeof(float));
|
13957
14359
|
GGML_ASSERT(nb10 == sizeof(float));
|
13958
14360
|
|
13959
14361
|
if (params->type == GGML_TASK_INIT) {
|
13960
|
-
// TODO: fix this memset (wsize is overestimated)
|
13961
14362
|
memset(params->wdata, 0, params->wsize);
|
13962
14363
|
|
13963
|
-
// prepare kernel data (src0)
|
14364
|
+
// prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
|
13964
14365
|
{
|
13965
14366
|
float * const wdata = (float *) params->wdata + 0;
|
13966
14367
|
|
13967
14368
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
13968
14369
|
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
13969
14370
|
const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
|
13970
|
-
float * dst_data = wdata +
|
14371
|
+
float * dst_data = wdata + i01*ne00*ne02;
|
13971
14372
|
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
13972
|
-
dst_data[i00*
|
14373
|
+
dst_data[i01*ne00*ne02 + i00*ne02 + i02] = src[i00];
|
13973
14374
|
}
|
13974
14375
|
}
|
13975
14376
|
}
|
@@ -13977,13 +14378,13 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
|
|
13977
14378
|
|
13978
14379
|
// prepare source data (src1)
|
13979
14380
|
{
|
13980
|
-
float * const wdata = (float *) params->wdata +
|
14381
|
+
float * const wdata = (float *) params->wdata + nk;
|
14382
|
+
float * dst_data = wdata;
|
13981
14383
|
|
13982
14384
|
for (int64_t i11 = 0; i11 < ne11; i11++) {
|
13983
14385
|
const float * const src = (float *)((char *) src1->data + i11*nb11);
|
13984
|
-
float * dst_data = wdata;
|
13985
14386
|
for (int64_t i10 = 0; i10 < ne10; i10++) {
|
13986
|
-
dst_data[
|
14387
|
+
dst_data[i10*ne11 + i11] = src[i10];
|
13987
14388
|
}
|
13988
14389
|
}
|
13989
14390
|
}
|
@@ -13995,8 +14396,10 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
|
|
13995
14396
|
return;
|
13996
14397
|
}
|
13997
14398
|
|
14399
|
+
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
14400
|
+
|
13998
14401
|
// total rows in dst
|
13999
|
-
const int nr =
|
14402
|
+
const int nr = ne1;
|
14000
14403
|
|
14001
14404
|
// rows per thread
|
14002
14405
|
const int dr = (nr + nth - 1)/nth;
|
@@ -14005,23 +14408,26 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
|
|
14005
14408
|
const int ir0 = dr*ith;
|
14006
14409
|
const int ir1 = MIN(ir0 + dr, nr);
|
14007
14410
|
|
14411
|
+
float * const wdata = (float *) params->wdata + 0;
|
14412
|
+
float * const wdata_src = wdata + nk;
|
14413
|
+
|
14008
14414
|
for (int i1 = ir0; i1 < ir1; i1++) {
|
14009
14415
|
float * dst_data = (float *)((char *) dst->data + i1*nb1);
|
14010
|
-
|
14011
|
-
|
14012
|
-
|
14013
|
-
|
14014
|
-
|
14015
|
-
|
14016
|
-
|
14017
|
-
|
14018
|
-
dst_data[
|
14416
|
+
float * wdata_kernel = wdata + i1*ne02*ne00;
|
14417
|
+
for (int i10 = 0; i10 < ne10; i10++) {
|
14418
|
+
const int i1n = i10*ne11;
|
14419
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
14420
|
+
float v = 0;
|
14421
|
+
ggml_vec_dot_f32(ne02, &v,
|
14422
|
+
wdata_src + i1n,
|
14423
|
+
wdata_kernel + i00*ne02);
|
14424
|
+
dst_data[i10*s0 + i00] += v;
|
14019
14425
|
}
|
14020
14426
|
}
|
14021
14427
|
}
|
14022
14428
|
}
|
14023
14429
|
|
14024
|
-
static void
|
14430
|
+
static void ggml_compute_forward_conv_transpose_1d(
|
14025
14431
|
const struct ggml_compute_params * params,
|
14026
14432
|
const struct ggml_tensor * src0,
|
14027
14433
|
const struct ggml_tensor * src1,
|
@@ -14029,11 +14435,11 @@ static void ggml_compute_forward_conv_1d_s2_ph(
|
|
14029
14435
|
switch (src0->type) {
|
14030
14436
|
case GGML_TYPE_F16:
|
14031
14437
|
{
|
14032
|
-
|
14438
|
+
ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst);
|
14033
14439
|
} break;
|
14034
14440
|
case GGML_TYPE_F32:
|
14035
14441
|
{
|
14036
|
-
|
14442
|
+
ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst);
|
14037
14443
|
} break;
|
14038
14444
|
default:
|
14039
14445
|
{
|
@@ -14042,27 +14448,6 @@ static void ggml_compute_forward_conv_1d_s2_ph(
|
|
14042
14448
|
}
|
14043
14449
|
}
|
14044
14450
|
|
14045
|
-
// ggml_compute_forward_conv_1d
|
14046
|
-
|
14047
|
-
static void ggml_compute_forward_conv_1d(
|
14048
|
-
const struct ggml_compute_params * params,
|
14049
|
-
const struct ggml_tensor * src0,
|
14050
|
-
const struct ggml_tensor * src1,
|
14051
|
-
struct ggml_tensor * dst) {
|
14052
|
-
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
14053
|
-
const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
|
14054
|
-
const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
|
14055
|
-
GGML_ASSERT(d0 == 1); // dilation not supported
|
14056
|
-
GGML_ASSERT(p0 == src0->ne[0]/2); // only half padding supported
|
14057
|
-
if (s0 == 1) {
|
14058
|
-
ggml_compute_forward_conv_1d_s1_ph(params, src0, src1, dst);
|
14059
|
-
} else if (s0 == 2) {
|
14060
|
-
ggml_compute_forward_conv_1d_s2_ph(params, src0, src1, dst);
|
14061
|
-
} else {
|
14062
|
-
GGML_ASSERT(false); // only stride 1 and 2 supported
|
14063
|
-
}
|
14064
|
-
}
|
14065
|
-
|
14066
14451
|
// ggml_compute_forward_conv_2d
|
14067
14452
|
|
14068
14453
|
static void ggml_compute_forward_conv_2d_f16_f32(
|
@@ -14077,7 +14462,7 @@ static void ggml_compute_forward_conv_2d_f16_f32(
|
|
14077
14462
|
int64_t t0 = ggml_perf_time_us();
|
14078
14463
|
UNUSED(t0);
|
14079
14464
|
|
14080
|
-
GGML_TENSOR_BINARY_OP_LOCALS
|
14465
|
+
GGML_TENSOR_BINARY_OP_LOCALS;
|
14081
14466
|
|
14082
14467
|
const int ith = params->ith;
|
14083
14468
|
const int nth = params->nth;
|
@@ -14105,20 +14490,22 @@ static void ggml_compute_forward_conv_2d_f16_f32(
|
|
14105
14490
|
{
|
14106
14491
|
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
|
14107
14492
|
|
14108
|
-
for (int
|
14109
|
-
|
14110
|
-
|
14111
|
-
|
14112
|
-
|
14113
|
-
for (int
|
14114
|
-
for (int
|
14115
|
-
for (int
|
14116
|
-
|
14117
|
-
|
14118
|
-
|
14119
|
-
|
14120
|
-
|
14121
|
-
|
14493
|
+
for (int i13 = 0; i13 < ne13; i13++) {
|
14494
|
+
for (int i12 = 0; i12 < ne12; i12++) {
|
14495
|
+
const float * const src = (float *)((char *) src1->data + i13*nb13 + i12*nb12);
|
14496
|
+
ggml_fp16_t * dst_data = wdata + i13*(ne1*ne0*ew0);
|
14497
|
+
|
14498
|
+
for (int i1 = 0; i1 < ne1; i1++) {
|
14499
|
+
for (int i0 = 0; i0 < ne0; i0++) {
|
14500
|
+
for (int ik1 = 0; ik1 < nk1; ik1++) {
|
14501
|
+
for (int ik0 = 0; ik0 < nk0; ik0++) {
|
14502
|
+
const int idx0 = i0*s0 + ik0*d0 - p0;
|
14503
|
+
const int idx1 = i1*s1 + ik1*d1 - p1;
|
14504
|
+
|
14505
|
+
if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) {
|
14506
|
+
dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
|
14507
|
+
GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]);
|
14508
|
+
}
|
14122
14509
|
}
|
14123
14510
|
}
|
14124
14511
|
}
|
@@ -16401,6 +16788,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
16401
16788
|
{
|
16402
16789
|
ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor);
|
16403
16790
|
} break;
|
16791
|
+
case GGML_OP_CONV_1D_STAGE_0:
|
16792
|
+
{
|
16793
|
+
ggml_compute_forward_conv_1d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
|
16794
|
+
} break;
|
16795
|
+
case GGML_OP_CONV_1D_STAGE_1:
|
16796
|
+
{
|
16797
|
+
ggml_compute_forward_conv_1d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
|
16798
|
+
} break;
|
16799
|
+
case GGML_OP_CONV_TRANSPOSE_1D:
|
16800
|
+
{
|
16801
|
+
ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor);
|
16802
|
+
} break;
|
16404
16803
|
case GGML_OP_CONV_2D:
|
16405
16804
|
{
|
16406
16805
|
ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor);
|
@@ -17326,10 +17725,22 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
17326
17725
|
{
|
17327
17726
|
GGML_ASSERT(false); // TODO: not implemented
|
17328
17727
|
} break;
|
17728
|
+
case GGML_OP_CONV_1D_STAGE_0:
|
17729
|
+
{
|
17730
|
+
GGML_ASSERT(false); // TODO: not implemented
|
17731
|
+
} break;
|
17732
|
+
case GGML_OP_CONV_1D_STAGE_1:
|
17733
|
+
{
|
17734
|
+
GGML_ASSERT(false); // TODO: not implemented
|
17735
|
+
} break;
|
17329
17736
|
case GGML_OP_CONV_2D:
|
17330
17737
|
{
|
17331
17738
|
GGML_ASSERT(false); // TODO: not implemented
|
17332
17739
|
} break;
|
17740
|
+
case GGML_OP_CONV_TRANSPOSE_1D:
|
17741
|
+
{
|
17742
|
+
GGML_ASSERT(false); // TODO: not implemented
|
17743
|
+
} break;
|
17333
17744
|
case GGML_OP_CONV_TRANSPOSE_2D:
|
17334
17745
|
{
|
17335
17746
|
GGML_ASSERT(false); // TODO: not implemented
|
@@ -18171,21 +18582,68 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
|
18171
18582
|
GGML_ASSERT(node->src[1]->ne[2] == 1);
|
18172
18583
|
GGML_ASSERT(node->src[1]->ne[3] == 1);
|
18173
18584
|
|
18585
|
+
const int64_t ne00 = node->src[0]->ne[0];
|
18586
|
+
const int64_t ne01 = node->src[0]->ne[1];
|
18587
|
+
const int64_t ne02 = node->src[0]->ne[2];
|
18588
|
+
|
18589
|
+
const int64_t ne10 = node->src[1]->ne[0];
|
18590
|
+
const int64_t ne11 = node->src[1]->ne[1];
|
18591
|
+
|
18592
|
+
const int64_t ne0 = node->ne[0];
|
18593
|
+
const int64_t ne1 = node->ne[1];
|
18594
|
+
const int64_t nk = ne00;
|
18595
|
+
const int64_t ew0 = nk * ne01;
|
18596
|
+
|
18597
|
+
UNUSED(ne02);
|
18598
|
+
UNUSED(ne10);
|
18599
|
+
UNUSED(ne11);
|
18600
|
+
|
18174
18601
|
size_t cur = 0;
|
18175
|
-
const int nk = node->src[0]->ne[0];
|
18176
18602
|
|
18177
18603
|
if (node->src[0]->type == GGML_TYPE_F16 &&
|
18178
|
-
|
18179
|
-
cur = sizeof(ggml_fp16_t)*(
|
18180
|
-
|
18181
|
-
|
18182
|
-
|
18604
|
+
node->src[1]->type == GGML_TYPE_F32) {
|
18605
|
+
cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0);
|
18606
|
+
} else if (node->src[0]->type == GGML_TYPE_F32 &&
|
18607
|
+
node->src[1]->type == GGML_TYPE_F32) {
|
18608
|
+
cur = sizeof(float)*(ne0*ne1*ew0);
|
18609
|
+
} else {
|
18610
|
+
GGML_ASSERT(false);
|
18611
|
+
}
|
18612
|
+
|
18613
|
+
work_size = MAX(work_size, cur);
|
18614
|
+
} break;
|
18615
|
+
case GGML_OP_CONV_1D_STAGE_0:
|
18616
|
+
{
|
18617
|
+
n_tasks = n_threads;
|
18618
|
+
} break;
|
18619
|
+
case GGML_OP_CONV_1D_STAGE_1:
|
18620
|
+
{
|
18621
|
+
n_tasks = n_threads;
|
18622
|
+
} break;
|
18623
|
+
case GGML_OP_CONV_TRANSPOSE_1D:
|
18624
|
+
{
|
18625
|
+
n_tasks = n_threads;
|
18626
|
+
|
18627
|
+
GGML_ASSERT(node->src[0]->ne[3] == 1);
|
18628
|
+
GGML_ASSERT(node->src[1]->ne[2] == 1);
|
18629
|
+
GGML_ASSERT(node->src[1]->ne[3] == 1);
|
18630
|
+
|
18631
|
+
const int64_t ne00 = node->src[0]->ne[0]; // K
|
18632
|
+
const int64_t ne01 = node->src[0]->ne[1]; // Cout
|
18633
|
+
const int64_t ne02 = node->src[0]->ne[2]; // Cin
|
18634
|
+
|
18635
|
+
const int64_t ne10 = node->src[1]->ne[0]; // L
|
18636
|
+
const int64_t ne11 = node->src[1]->ne[1]; // Cin
|
18637
|
+
|
18638
|
+
size_t cur = 0;
|
18639
|
+
if (node->src[0]->type == GGML_TYPE_F16 &&
|
18640
|
+
node->src[1]->type == GGML_TYPE_F32) {
|
18641
|
+
cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
|
18642
|
+
cur += sizeof(ggml_fp16_t)*ne10*ne11;
|
18183
18643
|
} else if (node->src[0]->type == GGML_TYPE_F32 &&
|
18184
|
-
|
18185
|
-
cur
|
18186
|
-
|
18187
|
-
( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
|
18188
|
-
);
|
18644
|
+
node->src[1]->type == GGML_TYPE_F32) {
|
18645
|
+
cur += sizeof(float)*ne00*ne01*ne02;
|
18646
|
+
cur += sizeof(float)*ne10*ne11;
|
18189
18647
|
} else {
|
18190
18648
|
GGML_ASSERT(false);
|
18191
18649
|
}
|
@@ -19311,7 +19769,7 @@ static enum ggml_opt_result ggml_opt_adam(
|
|
19311
19769
|
if (callback) {
|
19312
19770
|
callback(callback_data, accum_step, &sched, &cancel);
|
19313
19771
|
if (cancel) {
|
19314
|
-
|
19772
|
+
return GGML_OPT_CANCEL;
|
19315
19773
|
}
|
19316
19774
|
}
|
19317
19775
|
// ggml_graph_reset (gf);
|
@@ -19320,9 +19778,6 @@ static enum ggml_opt_result ggml_opt_adam(
|
|
19320
19778
|
ggml_opt_acc_grad(np, ps, g, accum_norm);
|
19321
19779
|
fx += ggml_get_f32_1d(f, 0);
|
19322
19780
|
}
|
19323
|
-
if (cancel) {
|
19324
|
-
return GGML_OPT_DID_NOT_CONVERGE;
|
19325
|
-
}
|
19326
19781
|
fx *= accum_norm;
|
19327
19782
|
|
19328
19783
|
opt->adam.fx_prev = fx;
|
@@ -19348,9 +19803,6 @@ static enum ggml_opt_result ggml_opt_adam(
|
|
19348
19803
|
|
19349
19804
|
// run the optimizer
|
19350
19805
|
for (int t = 0; t < params.adam.n_iter; ++t) {
|
19351
|
-
if (cancel) {
|
19352
|
-
break;
|
19353
|
-
}
|
19354
19806
|
opt->iter = iter0 + t + 1;
|
19355
19807
|
GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
|
19356
19808
|
|
@@ -19408,7 +19860,7 @@ static enum ggml_opt_result ggml_opt_adam(
|
|
19408
19860
|
if (callback) {
|
19409
19861
|
callback(callback_data, accum_step, &sched, &cancel);
|
19410
19862
|
if (cancel) {
|
19411
|
-
|
19863
|
+
return GGML_OPT_CANCEL;;
|
19412
19864
|
}
|
19413
19865
|
}
|
19414
19866
|
// ggml_graph_reset (gf);
|
@@ -19417,9 +19869,6 @@ static enum ggml_opt_result ggml_opt_adam(
|
|
19417
19869
|
ggml_opt_acc_grad(np, ps, g, accum_norm);
|
19418
19870
|
fx += ggml_get_f32_1d(f, 0);
|
19419
19871
|
}
|
19420
|
-
if (cancel) {
|
19421
|
-
break;
|
19422
|
-
}
|
19423
19872
|
fx *= accum_norm;
|
19424
19873
|
|
19425
19874
|
opt->loss_after = fx;
|
@@ -19538,7 +19987,7 @@ static enum ggml_opt_result linesearch_backtracking(
|
|
19538
19987
|
finit = *fx;
|
19539
19988
|
dgtest = params->lbfgs.ftol*dginit;
|
19540
19989
|
|
19541
|
-
while (
|
19990
|
+
while (true) {
|
19542
19991
|
ggml_vec_cpy_f32(nx, x, xp);
|
19543
19992
|
ggml_vec_mad_f32(nx, x, d, *step);
|
19544
19993
|
|
@@ -19554,7 +20003,7 @@ static enum ggml_opt_result linesearch_backtracking(
|
|
19554
20003
|
float sched = 0;
|
19555
20004
|
callback(callback_data, accum_step, &sched, cancel);
|
19556
20005
|
if (*cancel) {
|
19557
|
-
|
20006
|
+
return GGML_OPT_CANCEL;
|
19558
20007
|
}
|
19559
20008
|
}
|
19560
20009
|
// ggml_graph_reset (gf);
|
@@ -19563,9 +20012,6 @@ static enum ggml_opt_result linesearch_backtracking(
|
|
19563
20012
|
ggml_opt_acc_grad(np, ps, g, accum_norm);
|
19564
20013
|
*fx += ggml_get_f32_1d(f, 0);
|
19565
20014
|
}
|
19566
|
-
if (*cancel) {
|
19567
|
-
break;
|
19568
|
-
}
|
19569
20015
|
*fx *= accum_norm;
|
19570
20016
|
|
19571
20017
|
}
|
@@ -19698,7 +20144,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|
19698
20144
|
float sched = 0;
|
19699
20145
|
callback(callback_data, accum_step, &sched, &cancel);
|
19700
20146
|
if (cancel) {
|
19701
|
-
|
20147
|
+
return GGML_OPT_CANCEL;
|
19702
20148
|
}
|
19703
20149
|
}
|
19704
20150
|
// ggml_graph_reset (gf);
|
@@ -19707,9 +20153,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|
19707
20153
|
ggml_opt_acc_grad(np, ps, g, accum_norm);
|
19708
20154
|
fx += ggml_get_f32_1d(f, 0);
|
19709
20155
|
}
|
19710
|
-
if (cancel) {
|
19711
|
-
return GGML_OPT_DID_NOT_CONVERGE;
|
19712
|
-
}
|
19713
20156
|
fx *= accum_norm;
|
19714
20157
|
|
19715
20158
|
opt->loss_before = fx;
|
@@ -19768,9 +20211,13 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|
19768
20211
|
ggml_vec_cpy_f32(nx, xp, x);
|
19769
20212
|
ggml_vec_cpy_f32(nx, gp, g);
|
19770
20213
|
|
20214
|
+
// TODO: instead of passing &cancel here, use the return code of the linesearch
|
20215
|
+
// to determine if the optimization should be cancelled
|
20216
|
+
// this is a simple change, but not doing this atm, since I don't have a nice
|
20217
|
+
// way to test and don't want to break something with so many changes lined up
|
19771
20218
|
ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data);
|
19772
|
-
if (
|
19773
|
-
|
20219
|
+
if (cancel) {
|
20220
|
+
return GGML_OPT_CANCEL;
|
19774
20221
|
}
|
19775
20222
|
|
19776
20223
|
if (ls < 0) {
|