llama_cpp 0.6.0 → 0.7.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/llama_cpp/extconf.rb +1 -1
- data/ext/llama_cpp/llama_cpp.cpp +49 -3
- data/ext/llama_cpp/src/ggml-alloc.c +62 -107
- data/ext/llama_cpp/src/ggml-alloc.h +11 -5
- data/ext/llama_cpp/src/ggml-backend.c +385 -0
- data/ext/llama_cpp/src/ggml-backend.h +143 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +622 -150
- data/ext/llama_cpp/src/ggml-cuda.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.h +18 -1
- data/ext/llama_cpp/src/ggml-metal.m +358 -131
- data/ext/llama_cpp/src/ggml-metal.metal +137 -47
- data/ext/llama_cpp/src/ggml-opencl.cpp +136 -68
- data/ext/llama_cpp/src/ggml.c +812 -365
- data/ext/llama_cpp/src/ggml.h +25 -7
- data/ext/llama_cpp/src/k_quants.c +744 -2
- data/ext/llama_cpp/src/k_quants.h +5 -5
- data/ext/llama_cpp/src/llama.cpp +2387 -421
- data/ext/llama_cpp/src/llama.h +22 -6
- data/ext/llama_cpp/src/unicode.h +462 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +5 -0
- metadata +5 -2
data/ext/llama_cpp/src/ggml.c
CHANGED
```diff
@@ -162,40 +162,16 @@ typedef void * thread_ret_t;
 
 #define GGML_PRINT(...) printf(__VA_ARGS__)
 
+//
+// end of logging block
+//
+
 #ifdef GGML_USE_ACCELERATE
 // uncomment to use vDSP for soft max computation
 // note: not sure if it is actually faster
 //#define GGML_SOFT_MAX_ACCELERATE
 #endif
 
-//
-// logging
-//
-
-#if (GGML_DEBUG >= 1)
-#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG(...)
-#endif
-
-#if (GGML_DEBUG >= 5)
-#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG_5(...)
-#endif
-
-#if (GGML_DEBUG >= 10)
-#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG_10(...)
-#endif
-
-#define GGML_PRINT(...) printf(__VA_ARGS__)
-
-//
-// end of logging block
-//
-
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
 #define GGML_ALIGNED_FREE(ptr)    _aligned_free(ptr)
@@ -1032,8 +1008,8 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r
             y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
 
             // get the 5-th bit and store it in qh at the right position
-            qh |= ((xi0 &
-            qh |= ((xi1 &
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
         }
 
         memcpy(&y[i].qh, &qh, sizeof(qh));
@@ -1080,8 +1056,8 @@ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * r
             y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
 
             // get the 5-th bit and store it in qh at the right position
-            qh |= ((xi0 &
-            qh |= ((xi1 &
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
        }
 
         memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
```
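The rewritten lines pack the fifth bit of each 5-bit quantized value into the 32-bit `qh` bitfield; the new `0x10u` mask is unsigned, which avoids sign-extension surprises in the shifts. A minimal scalar sketch of the packing, assuming a block size of 32 as in the q5 block types (illustrative helper, not library code):

```c
#include <stdint.h>

enum { QK = 32 }; // block size assumed from the q5_0/q5_1 layout

// Pack the 5th bit of each quantized value into a 32-bit bitfield:
// value j of the first half maps to bit (j + 0), value j of the
// second half maps to bit (j + QK/2), mirroring the diff above.
static uint32_t pack_high_bits(const uint8_t q[QK]) {
    uint32_t qh = 0;
    for (int j = 0; j < QK/2; ++j) {
        const uint8_t xi0 = q[j];
        const uint8_t xi1 = q[j + QK/2];
        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
        qh |= ((xi1 & 0x10u) >> 4) << (j + QK/2);
    }
    return qh;
}
```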
```diff
@@ -1272,6 +1248,33 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
         _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
 #endif
     }
+#elif defined(__riscv_v_intrinsic)
+
+    size_t vl = __riscv_vsetvl_e32m4(QK8_0);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vfloat32m4_t v_x   = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
+
+        vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+        // convert to integer
+        vint16m2_t   vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+        vint8m1_t    vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+        // store result
+        __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+    }
 #else
     // scalar
     quantize_row_q8_0_reference(x, y, k);
@@ -1490,6 +1493,41 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
         _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
 #endif
     }
+#elif defined(__riscv_v_intrinsic)
+
+    size_t vl = __riscv_vsetvl_e32m4(QK8_1);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vfloat32m4_t v_x   = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
+
+        vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0, vl);
+        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+        const float d  = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+        // convert to integer
+        vint16m2_t   vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+        vint8m1_t    vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+        // store result
+        __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+
+        // compute sum for y[i].s
+        vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
+        vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
+
+        // set y[i].s
+        int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
+        y[i].s = sum*d;
+    }
 #else
     // scalar
     quantize_row_q8_1_reference(x, y, k);
```
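Both new `__riscv_v_intrinsic` branches vectorize the absmax quantization that the scalar reference performs: find the largest magnitude in the block, scale so it maps to 127, then round to int8 (`q8_1` additionally stores the element sum times the scale). A scalar sketch of the same math, with the F16 conversion elided:

```c
#include <math.h>
#include <stdint.h>

// Scalar sketch of one 32-element q8_0-style block: returns the scale d
// and fills qs with the quantized values. vfredmax/vfncvt do the same
// work in the vectorized branches above.
static float quantize_block_sketch(const float x[32], int8_t qs[32]) {
    float amax = 0.0f;
    for (int j = 0; j < 32; ++j) {
        amax = fmaxf(amax, fabsf(x[j]));
    }
    const float d  = amax / ((1 << 7) - 1);
    const float id = d ? 1.0f/d : 0.0f;
    for (int j = 0; j < 32; ++j) {
        qs[j] = (int8_t) roundf(x[j]*id);
    }
    return d;
}
```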
```diff
@@ -2662,30 +2700,32 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     size_t vl = __riscv_vsetvl_e8m1(qk/2);
 
     for (int i = 0; i < nb; i++) {
-
+        // load elements
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
 
-
-
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
 
-
-
+        // mask and store lower part of x, and then upper part
+        vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
 
-
-
+        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
 
-
-
+        // subtract offset
+        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
+        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
 
-
-
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
 
         vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
 
-        vint32m1_t vs1 =
-        vint32m1_t vs2 =
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
 
-        int sumi = __riscv_vmv_x_s_i32m1_i32(
-        sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
 
         sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
     }
@@ -2823,27 +2863,28 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
     size_t vl = __riscv_vsetvl_e8m1(qk/2);
 
     for (int i = 0; i < nb; i++) {
-
+        // load elements
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
 
-
-
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
 
-
-
+        // mask and store lower part of x, and then upper part
+        vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
 
-
-
+        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
 
-
-
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
 
         vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
 
-        vint32m1_t vs1 =
-        vint32m1_t vs2 =
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
 
-        int sumi = __riscv_vmv_x_s_i32m1_i32(
-        sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
 
         sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
     }
@@ -3088,66 +3129,61 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
 
     uint32_t qh;
 
-    // These temp values are for masking and shift operations
-    uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
-    uint32_t temp_2[16] = {0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80,
-                           0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000};
-
     size_t vl = __riscv_vsetvl_e8m1(qk/2);
 
+    // These tempory registers are for masking and shift operations
+    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
+    vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
+
+    vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
+    vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+
     for (int i = 0; i < nb; i++) {
         memcpy(&qh, x[i].qh, sizeof(uint32_t));
 
-        // temporary registers
-        vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_2, vl);
-        vuint32m4_t vt_2 = __riscv_vle32_v_u32m4(temp_1, vl);
-        vuint32m4_t vt_3 = __riscv_vsll_vx_u32m4(vt_1, 16, vl);
-        vuint32m4_t vt_4 = __riscv_vadd_vx_u32m4(vt_2, 12, vl);
-
         // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
-
-
-
+        vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
+        vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
+        vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
 
         // ((qh & (1u << (j + 16))) >> (j + 12));
-
-
+        vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
+        vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
 
         // narrowing
-
-
+        vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
+        vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
 
-
-
+        vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
+        vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
 
         // load
-
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
 
-
-
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
 
-
-
+        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
 
-
-
+        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
+        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
 
-
-
+        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
 
-
-
+        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
+        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
 
-
-
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
 
         vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
 
-        vint32m1_t vs1 =
-        vint32m1_t vs2 =
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
 
-        int sumi = __riscv_vmv_x_s_i32m1_i32(
-        sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
 
         sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
     }
@@ -3414,62 +3450,58 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
 
     uint32_t qh;
 
-    // These temp values are for shift operations
-    uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
-
     size_t vl = __riscv_vsetvl_e8m1(qk/2);
 
+    // temporary registers for shift operations
+    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
+    vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+
     for (int i = 0; i < nb; i++) {
         memcpy(&qh, x[i].qh, sizeof(uint32_t));
 
-        // temporary registers
-        vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_1, vl);
-        vuint32m4_t vt_2 = __riscv_vadd_vx_u32m4(vt_1, 12, vl);
-
         // load qh
-
+        vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
 
         // ((qh >> (j +  0)) << 4) & 0x10;
-
-
-
+        vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
+        vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
+        vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
 
         // ((qh >> (j + 12))     ) & 0x10;
-
-
+        vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
+        vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
 
         // narrowing
-
-
+        vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
+        vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
 
-
-
+        vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
+        vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
 
         // load
-
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
 
-
-
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
 
-
-
+        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
 
-
-
+        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
+        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
 
-
-
+        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
 
-
-
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
 
         vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
 
-        vint32m1_t vs1 =
-        vint32m1_t vs2 =
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
 
-        int sumi = __riscv_vmv_x_s_i32m1_i32(
-        sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
 
         sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
     }
```
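All four rewritten dot products share one pattern: unpack the 4-bit nibbles (plus, for q5, the high bit recovered from `qh`), widen to int16 with `vwmul`, and fold to a single int32 with `vwredsum` - the second reduction now seeds from `vs1`, which is why the two scalar extractions collapse into one. A scalar sketch of the q4_0 × q8_0 case (offset 8; names are illustrative):

```c
#include <stdint.h>

// One 32-element block of the q4_0 x q8_0 dot product, scalar form.
// xqs: 16 bytes of packed 4-bit weights; yqs: 32 int8 activations;
// dx, dy: per-block scales already converted to float.
static float dot_q4_0_q8_0_block(const uint8_t xqs[16], const int8_t yqs[32],
                                 float dx, float dy) {
    int sumi = 0;
    for (int j = 0; j < 16; ++j) {
        const int v0 = (xqs[j] & 0x0F) - 8; // low nibble,  first half
        const int v1 = (xqs[j] >>   4) - 8; // high nibble, second half
        sumi += v0*yqs[j] + v1*yqs[j + 16];
    }
    return sumi*dx*dy;
}
```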
```diff
@@ -4025,12 +4057,16 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ALIBI",
     "CLAMP",
     "CONV_1D",
+    "CONV_TRANSPOSE_1D",
     "CONV_2D",
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
     "POOL_2D",
     "UPSCALE",
 
+    "CONV_1D_STAGE_0",
+    "CONV_1D_STAGE_1",
+
     "FLASH_ATTN",
     "FLASH_FF",
     "FLASH_ATTN_BACK",
@@ -4056,7 +4092,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT ==
+static_assert(GGML_OP_COUNT == 71, "GGML_OP_COUNT != 71");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -4107,12 +4143,16 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "alibi(x)",
     "clamp(x)",
     "conv_1d(x)",
+    "conv_transpose_1d(x)",
     "conv_2d(x)",
     "conv_transpose_2d(x)",
     "pool_1d(x)",
     "pool_2d(x)",
     "upscale(x)",
 
+    "conv_1d_stage_0(x)",
+    "conv_1d_stage_1(x)",
+
     "flash_attn(x)",
     "flash_ff(x)",
     "flash_attn_back(x)",
@@ -4138,7 +4178,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT ==
+static_assert(GGML_OP_COUNT == 71, "GGML_OP_COUNT != 71");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4167,7 +4207,10 @@ static void ggml_setup_op_has_task_pass(void) {
         p[GGML_OP_DIAG_MASK_INF       ] = true;
         p[GGML_OP_DIAG_MASK_ZERO      ] = true;
         p[GGML_OP_CONV_1D             ] = true;
+        p[GGML_OP_CONV_1D_STAGE_0     ] = true;
+        p[GGML_OP_CONV_1D_STAGE_1     ] = true;
         p[GGML_OP_CONV_2D             ] = true;
+        p[GGML_OP_CONV_TRANSPOSE_1D   ] = true;
         p[GGML_OP_CONV_TRANSPOSE_2D   ] = true;
         p[GGML_OP_FLASH_ATTN_BACK     ] = true;
         p[GGML_OP_CROSS_ENTROPY_LOSS  ] = true;
@@ -4884,6 +4927,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
     *result = (struct ggml_tensor) {
         /*.type         =*/ type,
         /*.backend      =*/ GGML_BACKEND_CPU,
+        /*.buffer       =*/ NULL,
         /*.n_dims       =*/ n_dims,
         /*.ne           =*/ { 1, 1, 1, 1 },
         /*.nb           =*/ { 0, 0, 0, 0 },
```
```diff
@@ -5450,6 +5494,39 @@ struct ggml_tensor * ggml_view_tensor(
     return result;
 }
 
+struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx) {
+    struct ggml_object * obj = ctx->objects_begin;
+
+    char * const mem_buffer = ctx->mem_buffer;
+
+    while (obj != NULL) {
+        if (obj->type == GGML_OBJECT_TENSOR) {
+            return (struct ggml_tensor *)(mem_buffer + obj->offs);
+        }
+
+        obj = obj->next;
+    }
+
+    return NULL;
+}
+
+struct ggml_tensor * ggml_get_next_tensor(struct ggml_context * ctx, struct ggml_tensor * tensor) {
+    struct ggml_object * obj = (struct ggml_object *) ((char *)tensor - GGML_OBJECT_SIZE);
+    obj = obj->next;
+
+    char * const mem_buffer = ctx->mem_buffer;
+
+    while (obj != NULL) {
+        if (obj->type == GGML_OBJECT_TENSOR) {
+            return (struct ggml_tensor *)(mem_buffer + obj->offs);
+        }
+
+        obj = obj->next;
+    }
+
+    return NULL;
+}
+
 struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) {
     struct ggml_object * obj = ctx->objects_begin;
 
```
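`ggml_get_first_tensor` and `ggml_get_next_tensor` walk the context's object list and skip non-tensor objects, so callers can enumerate tensors without knowing the allocation layout. A hypothetical usage sketch:

```c
#include <stdio.h>
#include "ggml.h"

// Hypothetical helper: print every tensor allocated in a context.
// ggml_get_next_tensor returns NULL after the last tensor object.
static void dump_tensors(struct ggml_context * ctx) {
    for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL;
         t = ggml_get_next_tensor(ctx, t)) {
        printf("tensor: %s\n", t->name);
    }
}
```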
```diff
@@ -6690,7 +6767,6 @@ struct ggml_tensor * ggml_cont_4d(
     return result;
 }
 
-
 // ggml_reshape
 
 struct ggml_tensor * ggml_reshape(
```
```diff
@@ -7448,14 +7524,17 @@ static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p,
     return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
 }
 
-
-
-
-
-
-
-
-
+// im2col: [N, IC, IL] => [N, OL, IC*K]
+// a: [OC,IC, K]
+// b: [N, IC, IL]
+// result: [N, OL, IC*K]
+static struct ggml_tensor * ggml_conv_1d_stage_0(
+    struct ggml_context * ctx,
+    struct ggml_tensor  * a,
+    struct ggml_tensor  * b,
+    int                   s0,
+    int                   p0,
+    int                   d0) {
     GGML_ASSERT(a->ne[1] == b->ne[1]);
     bool is_node = false;
 
@@ -7464,16 +7543,54 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
         is_node = true;
     }
 
+    const int64_t OL = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
+
     const int64_t ne[4] = {
-
-
+        a->ne[1] * a->ne[0],
+        OL,
+        b->ne[2],
+        1,
     };
-    struct ggml_tensor * result = ggml_new_tensor(ctx,
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
 
     int32_t params[] = { s0, p0, d0 };
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op =
+    result->op = GGML_OP_CONV_1D_STAGE_0;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+// ggml_conv_1d_stage_1
+
+// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
+// a: [OC, IC, K]
+// b: [N, OL, IC * K]
+// result: [N, OC, OL]
+static struct ggml_tensor * ggml_conv_1d_stage_1(
+    struct ggml_context * ctx,
+    struct ggml_tensor  * a,
+    struct ggml_tensor  * b) {
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int64_t ne[4] = {
+        b->ne[1],
+        a->ne[2],
+        b->ne[2],
+        1,
+    };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op = GGML_OP_CONV_1D_STAGE_1;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
@@ -7481,6 +7598,53 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
     return result;
 }
 
+// ggml_conv_1d
+
+GGML_API struct ggml_tensor * ggml_conv_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   p0,
+        int                   d0) {
+    struct ggml_tensor * result = ggml_conv_1d_stage_0(ctx, a, b, s0, p0, d0);
+    result = ggml_conv_1d_stage_1(ctx, a, result);
+    return result;
+}
+
+// GGML_API struct ggml_tensor * ggml_conv_1d(
+//         struct ggml_context * ctx,
+//         struct ggml_tensor  * a,
+//         struct ggml_tensor  * b,
+//         int                   s0,
+//         int                   p0,
+//         int                   d0) {
+//     GGML_ASSERT(ggml_is_matrix(b));
+//     GGML_ASSERT(a->ne[1] == b->ne[1]);
+//     bool is_node = false;
+
+//     if (a->grad || b->grad) {
+//         GGML_ASSERT(false); // TODO: implement backward
+//         is_node = true;
+//     }
+
+//     const int64_t ne[4] = {
+//         ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
+//         a->ne[2], 1, 1,
+//     };
+//     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
+
+//     int32_t params[] = { s0, p0, d0 };
+//     ggml_set_op_params(result, params, sizeof(params));
+
+//     result->op = GGML_OP_CONV_1D;
+//     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+//     result->src[0] = a;
+//     result->src[1] = b;
+
+//     return result;
+// }
+
 // ggml_conv_1d_ph
 
 struct ggml_tensor* ggml_conv_1d_ph(
```
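`ggml_conv_1d` is now built from two primitives: stage 0 is an im2col (each output position gathers its `IC*K` receptive field into one F16 row), and stage 1 is a GEMM of that matrix against the flattened kernel. A shape walk-through with illustrative numbers, using the output-size formula from the diff:

```c
#include <stdint.h>

// Same formula as ggml_calc_conv_output_size in the diff.
static int64_t conv_out_size(int64_t il, int64_t k, int s, int p, int d) {
    return (il + 2*p - d*(k - 1) - 1)/s + 1;
}

// Illustrative shapes: kernel a = [OC=8, IC=4, K=3], signal b = [N=1, IC=4, IL=10],
// s0 = 1, p0 = 1, d0 = 1  =>  OL = conv_out_size(10, 3, 1, 1, 1) = 10.
//
// stage 0 (im2col): b -> [N, OL, IC*K] = [1, 10, 12]  (F16)
// stage 1 (gemm):   [OC, IC*K] x [N*OL, IC*K] -> [N, OC, OL] = [1, 8, 10]  (F32)
```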
```diff
@@ -7492,6 +7656,50 @@ struct ggml_tensor* ggml_conv_1d_ph(
     return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
 }
 
+// ggml_conv_transpose_1d
+
+static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
+    return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
+}
+
+GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   p0,
+        int                   d0) {
+    GGML_ASSERT(ggml_is_matrix(b));
+    GGML_ASSERT(a->ne[2] == b->ne[1]);
+    GGML_ASSERT(a->ne[3] == 1);
+
+    GGML_ASSERT(p0 == 0);
+    GGML_ASSERT(d0 == 1);
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int64_t ne[4] = {
+        ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
+        a->ne[1], b->ne[2], 1,
+    };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    int32_t params[] = { s0, p0, d0 };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op = GGML_OP_CONV_TRANSPOSE_1D;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
 // ggml_conv_2d
 
 struct ggml_tensor * ggml_conv_2d(
```
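The transposed-convolution output size inverts the forward formula; under the asserted constraints (`p0 == 0`, `d0 == 1`) it reduces to `(IL - 1)*s0 + K`. A quick numeric check:

```c
#include <stdint.h>

// ggml_calc_conv_transpose_1d_output_size, as added in the diff.
static int64_t convt_out_size(int64_t ins, int64_t ks, int s, int p, int d) {
    return (ins - 1)*s - 2*p + d*(ks - 1) + 1;
}

// Example: IL = 10, K = 3, s0 = 2, p0 = 0, d0 = 1
//   => (10 - 1)*2 - 0 + 1*(3 - 1) + 1 = 21 = (IL - 1)*s0 + K.
```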
```diff
@@ -8472,6 +8680,7 @@ void ggml_set_param(
 
     GGML_ASSERT(tensor->grad == NULL);
     tensor->grad = ggml_dup_tensor(ctx, tensor);
+    ggml_format_name(tensor->grad, "%s (grad)", tensor->name);
 }
 
 // ggml_compute_forward_dup
@@ -11058,7 +11267,7 @@ static void ggml_compute_forward_silu_f32(
 
 #ifndef NDEBUG
     for (int k = 0; k < nc; k++) {
-        const float x = ((float *) ((char *) dst->data + i1*(
+        const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
         UNUSED(x);
         assert(!isnan(x));
         assert(!isinf(x));
@@ -11621,11 +11830,6 @@ static void ggml_compute_forward_mul_mat(
 
 #if defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
-        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
-        //       ref: https://github.com/ggerganov/ggml/pull/224
-        GGML_ASSERT(ne02 == ne12);
-        GGML_ASSERT(ne03 == ne13);
-
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
         }
```
```diff
@@ -12889,28 +13093,25 @@ static void ggml_compute_forward_alibi_f32(
         return;
     }
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_head = ((int32_t *) dst->op_params)[1];
     float max_bias;
     memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
-
+    const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
+    const int64_t ne1 = src0->ne[1]; // seq_len_without_past
+    const int64_t ne2 = src0->ne[2]; // n_head -> this is k
+    //const int64_t ne3 = src0->ne[3]; // 1 -> bsz
 
-    const
-    const
-    const int ne2 = src0->ne[2]; // n_head -> this is k
-    //const int ne3 = src0->ne[3]; // 1 -> bsz
+    const int64_t n = ggml_nrows(src0);
+    const int64_t ne2_ne3 = n/ne1; // ne2*ne3
 
-    const
-    const
-
-    const int nb0 = src0->nb[0];
-    const int nb1 = src0->nb[1];
-    const int nb2 = src0->nb[2];
+    const size_t nb0 = src0->nb[0];
+    const size_t nb1 = src0->nb[1];
+    const size_t nb2 = src0->nb[2];
     //const int nb3 = src0->nb[3];
 
     GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(ne1 + n_past == ne0);
     GGML_ASSERT(n_head == ne2);
 
     // add alibi to src0 (KQ_scaled)
@@ -12919,9 +13120,9 @@ static void ggml_compute_forward_alibi_f32(
     const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
 
-    for (
-    for (
-    for (
+    for (int64_t i = 0; i < ne0; i++) {
+        for (int64_t j = 0; j < ne1; j++) {
+            for (int64_t k = 0; k < ne2_ne3; k++) {
                 float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
                 float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
 
@@ -12936,7 +13137,6 @@ static void ggml_compute_forward_alibi_f32(
             }
 
             pdst[0] = i * m_k + src[0];
-
         }
     }
 }
```
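The rewritten loops add the ALiBi bias `i * m_k` to every attention score, where `m_k` is a per-head slope built from the `m0`/`m1` bases above. The slope selection itself lies outside the changed lines; a sketch of the standard ALiBi schedule the surrounding code is assumed to follow:

```c
#include <math.h>

// Standard ALiBi slope for head k (assumed to match the unchanged code
// around the rewritten loops): heads below the largest power of two use
// powers of m0, the rest interpolate with odd powers of m1.
static float alibi_slope(int k, int n_head, float max_bias) {
    const int   n_heads_log2_floor = 1 << (int) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -(max_bias)        / n_heads_log2_floor);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
    return k < n_heads_log2_floor
        ? powf(m0, (float)(k + 1))
        : powf(m1, (float)(2*(k - n_heads_log2_floor) + 1));
}
```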
```diff
@@ -13636,7 +13836,7 @@ static void ggml_compute_forward_rope_back(
 
 // ggml_compute_forward_conv_1d
 
-static void
+static void ggml_compute_forward_conv_1d_f16_f32(
     const struct ggml_compute_params * params,
     const struct ggml_tensor * src0,
     const struct ggml_tensor * src1,
@@ -13654,42 +13854,33 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
     const int nth = params->nth;
 
     const int nk = ne00;
-    const int nh = nk/2;
 
-
+    // size of the convolution row - the kernel size unrolled across all input channels
+    const int ew0 = nk*ne01;
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
 
-    GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
     if (params->type == GGML_TASK_INIT) {
-        // TODO: fix this memset (wsize is overestimated)
         memset(params->wdata, 0, params->wsize);
 
-
-        {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+        ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
 
-
-
-
-            ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
-            for (int64_t i00 = 0; i00 < ne00; i00++) {
-                dst_data[i00*ew0 + i01] = src[i00];
-            }
-        }
-    }
+        for (int64_t i11 = 0; i11 < ne11; i11++) {
+            const float * const src = (float *)((char *) src1->data + i11*nb11);
+            ggml_fp16_t * dst_data = wdata;
 
-
-
-
+            for (int64_t i0 = 0; i0 < ne0; i0++) {
+                for (int64_t ik = 0; ik < nk; ik++) {
+                    const int idx0 = i0*s0 + ik*d0 - p0;
 
-
-
-
-            for (int64_t i10 = 0; i10 < ne10; i10++) {
-                dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
+                    if(!(idx0 < 0 || idx0 >= ne10)) {
+                        dst_data[i0*ew0 + i11*nk + ik] = GGML_FP32_TO_FP16(src[idx0]);
+                    }
                 }
             }
         }
@@ -13702,7 +13893,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
     }
 
     // total rows in dst
-    const int nr =
+    const int nr = ne2;
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -13711,23 +13902,22 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-
-
-
-
-
-
-
-
-                (ggml_fp16_t *)
-
-                dst_data[i0] += v;
+    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+    for (int i2 = 0; i2 < ne2; i2++) {
+        for (int i1 = ir0; i1 < ir1; i1++) {
+            float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
+
+            for (int i0 = 0; i0 < ne0; i0++) {
+                ggml_vec_dot_f16(ew0, dst_data + i0,
+                        (ggml_fp16_t *) ((char *) src0->data + i1*nb02),
+                        (ggml_fp16_t *) wdata + i2*nb2 + i0*ew0);
             }
         }
     }
 }
 
-static void
+static void ggml_compute_forward_conv_1d_f32(
     const struct ggml_compute_params * params,
     const struct ggml_tensor * src0,
     const struct ggml_tensor * src1,
@@ -13745,42 +13935,32 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32(
     const int nth = params->nth;
 
     const int nk = ne00;
-    const int nh = nk/2;
 
-    const int ew0 =
+    const int ew0 = nk*ne01;
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
 
-    GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
     GGML_ASSERT(nb00 == sizeof(float));
     GGML_ASSERT(nb10 == sizeof(float));
 
     if (params->type == GGML_TASK_INIT) {
-        // TODO: fix this memset (wsize is overestimated)
         memset(params->wdata, 0, params->wsize);
 
-
-        {
-            float * const wdata = (float *) params->wdata + 0;
+        float * const wdata = (float *) params->wdata + 0;
 
-
-
-
-            float * dst_data = wdata + i02*ew0*ne00;
-            for (int64_t i00 = 0; i00 < ne00; i00++) {
-                dst_data[i00*ew0 + i01] = src[i00];
-            }
-        }
-    }
+        for (int64_t i11 = 0; i11 < ne11; i11++) {
+            const float * const src = (float *)((char *) src1->data + i11*nb11);
+            float * dst_data = wdata;
 
-
-
-
+            for (int64_t i0 = 0; i0 < ne0; i0++) {
+                for (int64_t ik = 0; ik < nk; ik++) {
+                    const int idx0 = i0*s0 + ik*d0 - p0;
 
-
-
-
-                for (int64_t i10 = 0; i10 < ne10; i10++) {
-                    dst_data[(i10 + nh)*ew0 + i11] = src[i10];
+                    if(!(idx0 < 0 || idx0 >= ne10)) {
+                        dst_data[i0*ew0 + i11*nk + ik] = src[idx0];
+                    }
                 }
             }
         }
```
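Both rewritten kernels keep ggml's usual split of work: the `GGML_TASK_INIT` pass builds the scratch buffer once, then each of `nth` threads processes a contiguous slice of the `nr` output rows. The slice arithmetic, extracted as a sketch:

```c
// ggml's row-partitioning pattern, as used in the kernels above:
// thread ith of nth owns rows [ir0, ir1).
static void row_range(int nr, int ith, int nth, int * ir0, int * ir1) {
    const int dr = (nr + nth - 1)/nth;          // rows per thread, rounded up
    *ir0 = dr*ith;
    *ir1 = (*ir0 + dr < nr) ? *ir0 + dr : nr;   // MIN(ir0 + dr, nr)
}
```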
```diff
@@ -13802,35 +13982,242 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-
-
-
-
-
-
-
-
-                (float *)
-
-
+    float * const wdata = (float *) params->wdata + 0;
+
+    for (int i2 = 0; i2 < ne2; i2++) {
+        for (int i1 = ir0; i1 < ir1; i1++) {
+            float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
+
+            for (int i0 = 0; i0 < ne0; i0++) {
+                ggml_vec_dot_f32(ew0, dst_data + i0,
+                        (float *) ((char *) src0->data + i1*nb02),
+                        (float *) wdata + i2*nb2 + i0*ew0);
+            }
+        }
+    }
+}
+
+static void gemm_f16_out_f32(int64_t m, int64_t n, int64_t k,
+                             ggml_fp16_t * A,
+                             ggml_fp16_t * B,
+                             float * C,
+                             const int ith, const int nth) {
+    // does not seem to make a difference
+    int64_t m0, m1, n0, n1;
+    // patches per thread
+    if (m > n) {
+        n0 = 0;
+        n1 = n;
+
+        // total patches in dst
+        const int np = m;
+
+        // patches per thread
+        const int dp = (np + nth - 1)/nth;
+
+        // patch range for this thread
+        m0 = dp*ith;
+        m1 = MIN(m0 + dp, np);
+    } else {
+        m0 = 0;
+        m1 = m;
+
+        // total patches in dst
+        const int np = n;
+
+        // patches per thread
+        const int dp = (np + nth - 1)/nth;
+
+        // patch range for this thread
+        n0 = dp*ith;
+        n1 = MIN(n0 + dp, np);
+    }
+
+    // block-tiling attempt
+    int64_t blck_n = 16;
+    int64_t blck_m = 16;
+
+    // int64_t CACHE_SIZE = 2 * 1024 * 1024; // 2MB
+    // int64_t blck_size = CACHE_SIZE / (sizeof(float) + 2 * sizeof(ggml_fp16_t) * K);
+    // if (blck_size > 0) {
+    //     blck_0 = 4;
+    //     blck_1 = blck_size / blck_0;
+    //     if (blck_1 < 0) {
+    //         blck_1 = 1;
+    //     }
+    //     // blck_0 = (int64_t)sqrt(blck_size);
+    //     // blck_1 = blck_0;
+    // }
+    // // printf("%zd %zd %zd %zd\n", blck_size, K, blck_0, blck_1);
+
+    for (int j = n0; j < n1; j+=blck_n) {
+        for (int i = m0; i < m1; i+=blck_m) {
+            // printf("i j k => %d %d %d\n", i, j, K);
+            for (int ii = i; ii < i + blck_m && ii < m1; ii++) {
+                for (int jj = j; jj < j + blck_n && jj < n1; jj++) {
+                    ggml_vec_dot_f16(k,
+                                    C + ii*n + jj,
+                                    A + ii * k,
+                                    B + jj * k);
+                }
             }
         }
     }
 }
 
-
+// src0: kernel [OC, IC, K]
+// src1: signal [N, IC, IL]
+// dst:  result [N, OL, IC*K]
+static void ggml_compute_forward_conv_1d_stage_0_f32(
     const struct ggml_compute_params * params,
     const struct ggml_tensor * src0,
     const struct ggml_tensor * src1,
     struct ggml_tensor * dst) {
-
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int64_t N  = ne12;
+    const int64_t IC = ne11;
+    const int64_t IL = ne10;
+
+    const int64_t K = ne00;
+
+    const int64_t OL = ne1;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (params->type == GGML_TASK_INIT) {
+        memset(dst->data, 0, ggml_nbytes(dst));
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // im2col: [N, IC, IL] => [N, OL, IC*K]
+    {
+        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t iol = 0; iol < OL; iol++) {
+                for (int64_t iic = ith; iic < IC; iic+=nth) {
+
+                    // micro kernel
+                    ggml_fp16_t * dst_data = wdata + (in*OL + iol)*(IC*K); // [IC, K]
+                    const float * const src_data = (float *)((char *) src1->data + in*nb12 + iic*nb11); // [IL]
+
+                    for (int64_t ik = 0; ik < K; ik++) {
+                        const int64_t iil = iol*s0 + ik*d0 - p0;
+
+                        if (!(iil < 0 || iil >= IL)) {
+                            dst_data[iic*K + ik] = GGML_FP32_TO_FP16(src_data[iil]);
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
+// src0: [OC, IC, K]
+// src1: [N, OL, IC * K]
+// result: [N, OC, OL]
+static void ggml_compute_forward_conv_1d_stage_1_f16(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    const struct ggml_tensor * src1,
+    struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    if (params->type == GGML_TASK_INIT) {
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb0  == sizeof(float));
+
+    const int N = ne12;
+    const int OL = ne11;
+
+    const int OC = ne02;
+    const int IC = ne01;
+    const int K  = ne00;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    int64_t m = OC;
+    int64_t n = OL;
+    int64_t k = IC * K;
+
+    // [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
+    for (int i = 0; i < N; i++) {
+        ggml_fp16_t * A = (ggml_fp16_t *)src0->data;             // [m, k]
+        ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k]
+        float       * C = (float *)dst->data       + i * m * n;  // [m, n]
+
+        gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
+    }
+}
+
+static void ggml_compute_forward_conv_1d(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    const struct ggml_tensor * src1,
+    struct ggml_tensor * dst) {
+    switch(src0->type) {
         case GGML_TYPE_F16:
             {
-
+                ggml_compute_forward_conv_1d_f16_f32(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
-
+                ggml_compute_forward_conv_1d_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+static void ggml_compute_forward_conv_1d_stage_0(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    const struct ggml_tensor * src1,
+    struct ggml_tensor * dst) {
+    switch(src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_conv_1d_stage_0_f32(params, src0, src1, dst);
             } break;
         default:
             {
```
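`gemm_f16_out_f32` computes `C[m][n]` where both `A` and `B` store k-contiguous rows, so each output element is a plain dot product of row `i` of `A` with row `j` of `B` - exactly the layout stage 0 produces. The 16×16 blocking only improves cache locality. A naive F32 reference of the same contraction (hypothetical helper):

```c
#include <stdint.h>

// Naive reference for the contraction in gemm_f16_out_f32: C = A * B^T
// with A [m, k] and B [n, k] row-major. The real code dots F16 rows
// via ggml_vec_dot_f16 and tiles the i/j loops in 16x16 blocks.
static void gemm_rows_ref(int64_t m, int64_t n, int64_t k,
                          const float * A, const float * B, float * C) {
    for (int64_t i = 0; i < m; ++i) {
        for (int64_t j = 0; j < n; ++j) {
            float acc = 0.0f;
            for (int64_t p = 0; p < k; ++p) {
                acc += A[i*k + p]*B[j*k + p];
            }
            C[i*n + j] = acc;
        }
    }
}
```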
```diff
@@ -13839,7 +14226,26 @@ static void ggml_compute_forward_conv_1d_s1_ph(
     }
 }
 
-static void
+static void ggml_compute_forward_conv_1d_stage_1(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    const struct ggml_tensor * src1,
+    struct ggml_tensor * dst) {
+    switch(src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_conv_1d_stage_1_f16(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_conv_transpose_1d
+
+static void ggml_compute_forward_conv_transpose_1d_f16_f32(
     const struct ggml_compute_params * params,
     const struct ggml_tensor * src0,
     const struct ggml_tensor * src1,
@@ -13856,43 +14262,38 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nk = ne00;
-    const int nh = nk/2;
-
-    const int ew0 = ggml_up32(ne01);
+    const int nk = ne00*ne01*ne02;
 
-    GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
     if (params->type == GGML_TASK_INIT) {
-        // TODO: fix this memset (wsize is overestimated)
         memset(params->wdata, 0, params->wsize);
 
-        //
+        // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
         {
             ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
 
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 for (int64_t i01 = 0; i01 < ne01; i01++) {
                     const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
-                    ggml_fp16_t * dst_data = wdata +
+                    ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
                     for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        dst_data[i00*
+                        dst_data[i00*ne02 + i02] = src[i00];
                     }
                 }
             }
         }
 
-        //
+        // permute source data (src1) from (L x Cin) to (Cin x L)
         {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata +
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
+            ggml_fp16_t * dst_data = wdata;
 
             for (int64_t i11 = 0; i11 < ne11; i11++) {
                 const float * const src = (float *)((char *) src1->data + i11*nb11);
-                ggml_fp16_t * dst_data = wdata;
                 for (int64_t i10 = 0; i10 < ne10; i10++) {
-                    dst_data[
+                    dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
                 }
             }
         }
```
```diff
@@ -13904,8 +14305,10 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
         return;
     }
 
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+
     // total rows in dst
-    const int nr =
+    const int nr = ne1;
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -13914,23 +14317,26 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
+    ggml_fp16_t * const wdata     = (ggml_fp16_t *) params->wdata + 0;
+    ggml_fp16_t * const wdata_src = wdata + nk;
+
     for (int i1 = ir0; i1 < ir1; i1++) {
         float * dst_data = (float *)((char *) dst->data + i1*nb1);
-
-
-
-
-
-
-            (ggml_fp16_t *)
-
-            dst_data[
+        ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
+        for (int i10 = 0; i10 < ne10; i10++) {
+            const int i1n = i10*ne11;
+            for (int i00 = 0; i00 < ne00; i00++) {
+                float v = 0;
+                ggml_vec_dot_f16(ne02, &v,
+                        (ggml_fp16_t *)    wdata_src + i1n,
+                        (ggml_fp16_t *) wdata_kernel + i00*ne02);
+                dst_data[i10*s0 + i00] += v;
             }
         }
     }
 }
 
-static void
+static void ggml_compute_forward_conv_transpose_1d_f32(
     const struct ggml_compute_params * params,
     const struct ggml_tensor * src0,
     const struct ggml_tensor * src1,
@@ -13947,29 +14353,24 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nk = ne00;
-    const int nh = nk/2;
-
-    const int ew0 = ggml_up32(ne01);
+    const int nk = ne00*ne01*ne02;
 
-    GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
     GGML_ASSERT(nb00 == sizeof(float));
     GGML_ASSERT(nb10 == sizeof(float));
 
     if (params->type == GGML_TASK_INIT) {
-        // TODO: fix this memset (wsize is overestimated)
         memset(params->wdata, 0, params->wsize);
 
-        // prepare kernel data (src0)
+        // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
         {
             float * const wdata = (float *) params->wdata + 0;
 
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 for (int64_t i01 = 0; i01 < ne01; i01++) {
                     const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
-                    float * dst_data = wdata +
+                    float * dst_data = wdata + i01*ne00*ne02;
                     for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        dst_data[i00*
+                        dst_data[i01*ne00*ne02 + i00*ne02 + i02] = src[i00];
                     }
                 }
             }
@@ -13977,13 +14378,13 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
 
         // prepare source data (src1)
         {
-            float * const wdata = (float *) params->wdata +
+            float * const wdata = (float *) params->wdata + nk;
+            float * dst_data = wdata;
 
             for (int64_t i11 = 0; i11 < ne11; i11++) {
                 const float * const src = (float *)((char *) src1->data + i11*nb11);
-                float * dst_data = wdata;
                 for (int64_t i10 = 0; i10 < ne10; i10++) {
-                    dst_data[
+                    dst_data[i10*ne11 + i11] = src[i10];
                 }
             }
         }
@@ -13995,8 +14396,10 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
         return;
     }
 
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+
     // total rows in dst
-    const int nr =
+    const int nr = ne1;
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -14005,23 +14408,26 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
+    float * const wdata     = (float *) params->wdata + 0;
+    float * const wdata_src = wdata + nk;
+
     for (int i1 = ir0; i1 < ir1; i1++) {
         float * dst_data = (float *)((char *) dst->data + i1*nb1);
-
-
-
-
-
-
-
-
-        dst_data[
+        float * wdata_kernel = wdata + i1*ne02*ne00;
+        for (int i10 = 0; i10 < ne10; i10++) {
+            const int i1n = i10*ne11;
+            for (int i00 = 0; i00 < ne00; i00++) {
+                float v = 0;
+                ggml_vec_dot_f32(ne02, &v,
+                        wdata_src + i1n,
+                        wdata_kernel + i00*ne02);
+                dst_data[i10*s0 + i00] += v;
             }
         }
     }
 }
 
-static void
+static void ggml_compute_forward_conv_transpose_1d(
     const struct ggml_compute_params * params,
     const struct ggml_tensor * src0,
     const struct ggml_tensor * src1,
@@ -14029,11 +14435,11 @@ static void ggml_compute_forward_conv_1d_s2_ph(
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-
+                ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
-
+                ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst);
             } break;
         default:
            {
```
@@ -14042,27 +14448,6 @@ static void ggml_compute_forward_conv_1d_s2_ph(
     }
 }
 
-// ggml_compute_forward_conv_1d
-
-static void ggml_compute_forward_conv_1d(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-        struct ggml_tensor * dst) {
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
-    GGML_ASSERT(d0 == 1); // dilation not supported
-    GGML_ASSERT(p0 == src0->ne[0]/2); // only half padding supported
-    if (s0 == 1) {
-        ggml_compute_forward_conv_1d_s1_ph(params, src0, src1, dst);
-    } else if (s0 == 2) {
-        ggml_compute_forward_conv_1d_s2_ph(params, src0, src1, dst);
-    } else {
-        GGML_ASSERT(false); // only stride 1 and 2 supported
-    }
-}
-
 // ggml_compute_forward_conv_2d
 
 static void ggml_compute_forward_conv_2d_f16_f32(
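The deleted dispatcher above only handled stride 1 or 2 with half padding; as the dispatch hunks further down show, general conv_1d is now split into GGML_OP_CONV_1D_STAGE_0 (an im2col-style gather) followed by GGML_OP_CONV_1D_STAGE_1 (a matrix multiply over the gathered columns). A single-channel sketch of that two-stage decomposition, with hypothetical names (conv1d_stage0/conv1d_stage1) and the usual stride/padding/dilation parameters:

    // Stage 0: gather kernel-sized windows (im2col). x has length L,
    // cols has OL*K entries; out-of-range taps are zero-filled.
    static void conv1d_stage0(const float * x, float * cols,
                              int OL, int K, int s0, int p0, int d0, int L) {
        for (int o = 0; o < OL; o++) {
            for (int k = 0; k < K; k++) {
                const int i = o*s0 + k*d0 - p0;
                cols[o*K + k] = (i >= 0 && i < L) ? x[i] : 0.0f;
            }
        }
    }

    // Stage 1: reduce each gathered window against the kernel weights.
    static void conv1d_stage1(const float * cols, const float * w, float * y,
                              int OL, int K) {
        for (int o = 0; o < OL; o++) {
            float v = 0.0f;
            for (int k = 0; k < K; k++) {
                v += cols[o*K + k] * w[k];
            }
            y[o] = v;
        }
    }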
@@ -14077,7 +14462,7 @@ static void ggml_compute_forward_conv_2d_f16_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -14105,20 +14490,22 @@ static void ggml_compute_forward_conv_2d_f16_f32(
         {
             ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
 
-            for (int
-
-
-
-
-            for (int
-            for (int
-            for (int
-
-
-
-
-
-
+            for (int i13 = 0; i13 < ne13; i13++) {
+                for (int i12 = 0; i12 < ne12; i12++) {
+                    const float * const src = (float *)((char *) src1->data + i13*nb13 + i12*nb12);
+                    ggml_fp16_t * dst_data = wdata + i13*(ne1*ne0*ew0);
+
+                    for (int i1 = 0; i1 < ne1; i1++) {
+                        for (int i0 = 0; i0 < ne0; i0++) {
+                            for (int ik1 = 0; ik1 < nk1; ik1++) {
+                                for (int ik0 = 0; ik0 < nk0; ik0++) {
+                                    const int idx0 = i0*s0 + ik0*d0 - p0;
+                                    const int idx1 = i1*s1 + ik1*d1 - p1;
+
+                                    if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) {
+                                        dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
+                                            GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]);
+                                    }
                                 }
                             }
                         }
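The rewritten gather maps each output coordinate through idx0 = i0*s0 + ik0*d0 - p0 (and idx1 likewise) and skips out-of-range taps, which leaves the zero-initialized scratch entries acting as padding. The output extents this index math implies follow the standard convolution size formula; a small helper expressing it (the name conv_output_size is illustrative):

    // Output length for input n, kernel k, stride s, padding p, dilation d.
    static int conv_output_size(int n, int k, int s, int p, int d) {
        return (n + 2*p - d*(k - 1) - 1)/s + 1;
    }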
@@ -16401,6 +16788,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor);
             } break;
+        case GGML_OP_CONV_1D_STAGE_0:
+            {
+                ggml_compute_forward_conv_1d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
+            } break;
+        case GGML_OP_CONV_1D_STAGE_1:
+            {
+                ggml_compute_forward_conv_1d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
+            } break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            {
+                ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor);
+            } break;
         case GGML_OP_CONV_2D:
             {
                 ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor);
@@ -17326,10 +17725,22 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_CONV_1D_STAGE_0:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_CONV_1D_STAGE_1:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_CONV_2D:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_CONV_TRANSPOSE_2D:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -18171,21 +18582,68 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                     GGML_ASSERT(node->src[1]->ne[2] == 1);
                     GGML_ASSERT(node->src[1]->ne[3] == 1);
 
+                    const int64_t ne00 = node->src[0]->ne[0];
+                    const int64_t ne01 = node->src[0]->ne[1];
+                    const int64_t ne02 = node->src[0]->ne[2];
+
+                    const int64_t ne10 = node->src[1]->ne[0];
+                    const int64_t ne11 = node->src[1]->ne[1];
+
+                    const int64_t ne0 = node->ne[0];
+                    const int64_t ne1 = node->ne[1];
+                    const int64_t nk = ne00;
+                    const int64_t ew0 = nk * ne01;
+
+                    UNUSED(ne02);
+                    UNUSED(ne10);
+                    UNUSED(ne11);
+
                     size_t cur = 0;
-                    const int nk = node->src[0]->ne[0];
 
                     if (node->src[0]->type == GGML_TYPE_F16 &&
-
-                        cur = sizeof(ggml_fp16_t)*(
-
-
-
+                        node->src[1]->type == GGML_TYPE_F32) {
+                        cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0);
+                    } else if (node->src[0]->type == GGML_TYPE_F32 &&
+                               node->src[1]->type == GGML_TYPE_F32) {
+                        cur = sizeof(float)*(ne0*ne1*ew0);
+                    } else {
+                        GGML_ASSERT(false);
+                    }
+
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_CONV_1D_STAGE_0:
+                {
+                    n_tasks = n_threads;
+                } break;
+            case GGML_OP_CONV_1D_STAGE_1:
+                {
+                    n_tasks = n_threads;
+                } break;
+            case GGML_OP_CONV_TRANSPOSE_1D:
+                {
+                    n_tasks = n_threads;
+
+                    GGML_ASSERT(node->src[0]->ne[3] == 1);
+                    GGML_ASSERT(node->src[1]->ne[2] == 1);
+                    GGML_ASSERT(node->src[1]->ne[3] == 1);
+
+                    const int64_t ne00 = node->src[0]->ne[0];  // K
+                    const int64_t ne01 = node->src[0]->ne[1];  // Cout
+                    const int64_t ne02 = node->src[0]->ne[2];  // Cin
+
+                    const int64_t ne10 = node->src[1]->ne[0];  // L
+                    const int64_t ne11 = node->src[1]->ne[1];  // Cin
+
+                    size_t cur = 0;
+                    if (node->src[0]->type == GGML_TYPE_F16 &&
+                        node->src[1]->type == GGML_TYPE_F32) {
+                        cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
+                        cur += sizeof(ggml_fp16_t)*ne10*ne11;
                     } else if (node->src[0]->type == GGML_TYPE_F32 &&
-
-                        cur
-
-                            ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
-                        );
+                               node->src[1]->type == GGML_TYPE_F32) {
+                        cur += sizeof(float)*ne00*ne01*ne02;
+                        cur += sizeof(float)*ne10*ne11;
                     } else {
                         GGML_ASSERT(false);
                     }
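For GGML_OP_CONV_TRANSPOSE_1D the plan now reserves two staging regions in the weight element type: one for the permuted kernel (ne00*ne01*ne02, i.e. K*Cout*Cin elements) and one for the permuted source (ne10*ne11, i.e. L*Cin). The same arithmetic as a small sketch (helper name ours; pass sizeof(ggml_fp16_t) or sizeof(float) to match the branch taken above):

    #include <stddef.h>
    #include <stdint.h>

    // Scratch bytes for conv_transpose_1d: kernel staging plus source staging.
    static size_t conv_transpose_1d_work_size(size_t elt_size, int64_t K,
                                              int64_t Cout, int64_t Cin, int64_t L) {
        return elt_size*(size_t)(K*Cout*Cin) + elt_size*(size_t)(L*Cin);
    }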
@@ -19311,7 +19769,7 @@ static enum ggml_opt_result ggml_opt_adam(
             if (callback) {
                 callback(callback_data, accum_step, &sched, &cancel);
                 if (cancel) {
-
+                    return GGML_OPT_CANCEL;
                 }
             }
             // ggml_graph_reset (gf);
@@ -19320,9 +19778,6 @@ static enum ggml_opt_result ggml_opt_adam(
             ggml_opt_acc_grad(np, ps, g, accum_norm);
             fx += ggml_get_f32_1d(f, 0);
         }
-        if (cancel) {
-            return GGML_OPT_DID_NOT_CONVERGE;
-        }
         fx *= accum_norm;
 
         opt->adam.fx_prev = fx;
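The two hunks above change the cancellation contract: when the callback sets the cancel flag, ggml_opt_adam now returns the new GGML_OPT_CANCEL result immediately instead of breaking out and reporting GGML_OPT_DID_NOT_CONVERGE. A sketch of a callback that cancels after a fixed number of accumulation steps, matching the call shape callback(callback_data, accum_step, &sched, &cancel) used above (struct and function names are ours):

    #include <stdbool.h>

    struct cancel_state { int evals; int max_evals; };

    // Sets *cancel once max_evals accumulation steps have run; the
    // optimizer then returns GGML_OPT_CANCEL.
    static void my_opt_callback(void * data, int accum_step, float * sched, bool * cancel) {
        (void) accum_step; (void) sched;
        struct cancel_state * st = (struct cancel_state *) data;
        if (++st->evals >= st->max_evals) {
            *cancel = true;
        }
    }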
@@ -19348,9 +19803,6 @@ static enum ggml_opt_result ggml_opt_adam(
 
     // run the optimizer
     for (int t = 0; t < params.adam.n_iter; ++t) {
-        if (cancel) {
-            break;
-        }
         opt->iter = iter0 + t + 1;
         GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
 
@@ -19408,7 +19860,7 @@ static enum ggml_opt_result ggml_opt_adam(
             if (callback) {
                 callback(callback_data, accum_step, &sched, &cancel);
                 if (cancel) {
-
+                    return GGML_OPT_CANCEL;;
                 }
             }
             // ggml_graph_reset (gf);
@@ -19417,9 +19869,6 @@ static enum ggml_opt_result ggml_opt_adam(
             ggml_opt_acc_grad(np, ps, g, accum_norm);
             fx += ggml_get_f32_1d(f, 0);
         }
-        if (cancel) {
-            break;
-        }
         fx *= accum_norm;
 
         opt->loss_after = fx;
@@ -19538,7 +19987,7 @@ static enum ggml_opt_result linesearch_backtracking(
     finit = *fx;
     dgtest = params->lbfgs.ftol*dginit;
 
-    while (
+    while (true) {
         ggml_vec_cpy_f32(nx, x, xp);
         ggml_vec_mad_f32(nx, x, d, *step);
 
@@ -19554,7 +20003,7 @@ static enum ggml_opt_result linesearch_backtracking(
                 float sched = 0;
                 callback(callback_data, accum_step, &sched, cancel);
                 if (*cancel) {
-
+                    return GGML_OPT_CANCEL;
                 }
             }
             // ggml_graph_reset (gf);
@@ -19563,9 +20012,6 @@ static enum ggml_opt_result linesearch_backtracking(
                 ggml_opt_acc_grad(np, ps, g, accum_norm);
                 *fx += ggml_get_f32_1d(f, 0);
             }
-            if (*cancel) {
-                break;
-            }
             *fx *= accum_norm;
 
         }
@@ -19698,7 +20144,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
                 float sched = 0;
                 callback(callback_data, accum_step, &sched, &cancel);
                 if (cancel) {
-
+                    return GGML_OPT_CANCEL;
                 }
             }
             // ggml_graph_reset (gf);
@@ -19707,9 +20153,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
             ggml_opt_acc_grad(np, ps, g, accum_norm);
             fx += ggml_get_f32_1d(f, 0);
         }
-        if (cancel) {
-            return GGML_OPT_DID_NOT_CONVERGE;
-        }
         fx *= accum_norm;
 
         opt->loss_before = fx;
@@ -19768,9 +20211,13 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         ggml_vec_cpy_f32(nx, xp, x);
         ggml_vec_cpy_f32(nx, gp, g);
 
+        // TODO: instead of passing &cancel here, use the return code of the linesearch
+        //       to determine if the optimization should be cancelled
+        //       this is a simple change, but not doing this atm, since I don't have a nice
+        //       way to test and don't want to break something with so many changes lined up
         ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data);
-        if (
-
+        if (cancel) {
+            return GGML_OPT_CANCEL;
         }
 
         if (ls < 0) {
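With these changes a caller can distinguish a user-initiated cancel from a convergence failure by testing the result code; a minimal fragment, assuming ctx, params and f were set up as usual for ggml_opt:

    enum ggml_opt_result res = ggml_opt(ctx, params, f);
    if (res == GGML_OPT_CANCEL) {
        // stopped by the callback, not a convergence failure
    }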