llama_cpp 0.5.3 → 0.6.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +6 -5
- data/examples/chat.rb +13 -13
- data/examples/embedding.rb +9 -9
- data/ext/llama_cpp/llama_cpp.cpp +547 -272
- data/ext/llama_cpp/src/ggml-alloc.c +8 -2
- data/ext/llama_cpp/src/ggml-alloc.h +1 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +209 -82
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.m +163 -84
- data/ext/llama_cpp/src/ggml-metal.metal +121 -38
- data/ext/llama_cpp/src/ggml.c +1596 -842
- data/ext/llama_cpp/src/ggml.h +116 -35
- data/ext/llama_cpp/src/llama.cpp +1015 -586
- data/ext/llama_cpp/src/llama.h +304 -119
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +5 -9
- data/sig/llama_cpp.rbs +65 -34
- metadata +3 -3
data/ext/llama_cpp/src/ggml.c
CHANGED
@@ -89,7 +89,9 @@ static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(vo

 static int pthread_join(pthread_t thread, void * unused) {
     (void) unused;
-    return (int) WaitForSingleObject(thread, INFINITE);
+    int ret = (int) WaitForSingleObject(thread, INFINITE);
+    CloseHandle(thread);
+    return ret;
 }

 static int sched_yield (void) {
@@ -134,6 +136,7 @@ typedef void * thread_ret_t;

 #define GGML_SOFT_MAX_UNROLL 4
 #define GGML_VEC_DOT_UNROLL  2
+#define GGML_VEC_MAD_UNROLL  32

 //
 // logging
@@ -242,18 +245,18 @@ inline static void * ggml_aligned_malloc(size_t size) {
 //

 #define GGML_TENSOR_UNARY_OP_LOCALS \
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb); \
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne); \
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb);
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)

 #define GGML_TENSOR_BINARY_OP_LOCALS \
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb); \
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); \
-    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb); \
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne); \
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb);
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)

 #if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
@@ -1863,7 +1866,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
 #define GGML_F16x8_ADD vaddq_f16
 #define GGML_F16x8_MUL vmulq_f16
 #define GGML_F16x8_REDUCE(res, x) \
-    { \
+    do { \
         int offset = GGML_F16_ARR >> 1; \
         for (int i = 0; i < offset; ++i) { \
             x[i] = vaddq_f16(x[i], x[offset+i]); \
@@ -1879,7 +1882,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
         const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
         const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
         res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \
-    }
+    } while (0)

 #define GGML_F16_VEC      GGML_F16x8
 #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
@@ -1940,7 +1943,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
 #define GGML_F32x8_ADD _mm256_add_ps
 #define GGML_F32x8_MUL _mm256_mul_ps
 #define GGML_F32x8_REDUCE(res, x) \
-    { \
+    do { \
         int offset = GGML_F32_ARR >> 1; \
         for (int i = 0; i < offset; ++i) { \
             x[i] = _mm256_add_ps(x[i], x[offset+i]); \
@@ -1957,7 +1960,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
                                  _mm256_extractf128_ps(x[0], 1)); \
     const __m128 t1 = _mm_hadd_ps(t0, t0); \
     res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \
-    }
+    } while (0)
 // TODO: is this optimal ?

 #define GGML_F32_VEC GGML_F32x8
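The GGML_F16x8_REDUCE and GGML_F32x8_REDUCE macros above are switched from a bare { ... } block to do { ... } while (0). A minimal standalone C sketch (illustrative, not part of the gem) of why that wrapper makes a multi-statement macro behave like a single statement:

    #include <stdio.h>

    /* A bare-block macro breaks when used as the body of an if/else because the
     * trailing ';' after the block ends the if-statement; the do-while(0) form
     * is a single statement and composes cleanly. */
    #define SWAP_BAD(a, b)  { int t = (a); (a) = (b); (b) = t; }
    #define SWAP_OK(a, b)   do { int t = (a); (a) = (b); (b) = t; } while (0)

    int main(void) {
        int x = 1, y = 2;

        // With SWAP_BAD here, the 'else' branch would not compile.
        if (x < y)
            SWAP_OK(x, y);
        else
            printf("already ordered\n");

        printf("x=%d y=%d\n", x, y);
        return 0;
    }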
@@ -3707,6 +3710,58 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
 #endif
 }

+// xs and vs are byte strides of x and v
+inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
+
+    const float * restrict x[GGML_VEC_MAD_UNROLL];
+    const float * restrict v[GGML_VEC_MAD_UNROLL];
+
+    for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
+        x[i] = (const float *) ((const char *) xv + i*xs);
+        v[i] = (const float *) ((const char *) vv + i*vs);
+    }
+
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];
+
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        vx[k] = GGML_F32_VEC_SET1(v[k][0]);
+    }
+
+    GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+
+            for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+                ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
+                ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
+            }
+
+            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+        }
+    }
+
+    // leftovers
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        for (int i = np; i < n; ++i) {
+            y[i] += x[k][i]*v[k][0];
+        }
+    }
+#else
+    // scalar
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        for (int i = 0; i < n; ++i) {
+            y[i] += x[k][i]*v[k][0];
+        }
+    }
+#endif
+}
+
 //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
 inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
 #if defined(GGML_USE_ACCELERATE)
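In effect, ggml_vec_mad_f32_unroll accumulates GGML_VEC_MAD_UNROLL scaled rows into y in a single pass, reaching row k from xv/vv via the byte strides xs/vs; the quantized out_prod path added further below relies on this. A plain scalar reference of the same contract (the helper name and the explicit unroll parameter are illustrative, not part of ggml; it mirrors the #else branch above):

    /* y[i] += x_k[i] * v_k[0] for k = 0 .. unroll-1, where x_k and v_k are the
     * k-th rows reached from xv/vv by the byte strides xs/vs. */
    static void vec_mad_unroll_ref(int n, int unroll, int xs, int vs,
                                   float * y, const char * xv, const char * vv) {
        for (int k = 0; k < unroll; ++k) {
            const float * x = (const float *) (xv + k*xs);
            const float   s = *(const float *) (vv + k*vs);
            for (int i = 0; i < n; ++i) {
                y[i] += x[i]*s;
            }
        }
    }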
@@ -4392,10 +4447,9 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
 static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

-    return
-
-
-    (t0->ne[3] == t1->ne[3]);
+    return (t0->ne[1] == t1->ne[1])   &&
+           (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
+           (t1->ne[3]%t0->ne[3] == 0);
 }

 enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
@@ -5065,43 +5119,78 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
     return tensor;
 }

+void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) {
+    const int64_t ne2 = tensor->ne[2];
+    const int64_t ne1 = tensor->ne[1];
+    const int64_t ne0 = tensor->ne[0];
+
+    const int64_t i3_ = (i/(ne2*ne1*ne0));
+    const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0);
+    const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0;
+    const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0);
+
+    if (i0) {
+        * i0 = i0_;
+    }
+    if (i1) {
+        * i1 = i1_;
+    }
+    if (i2) {
+        * i2 = i2_;
+    }
+    if (i3) {
+        * i3 = i3_;
+    }
+}
+
 int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
+    if (!ggml_is_contiguous(tensor)) {
+        int64_t id[4] = { 0, 0, 0, 0 };
+        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+        return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]);
+    }
     switch (tensor->type) {
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
                 return ((int8_t *)(tensor->data))[i];
-            } break;
+            }
         case GGML_TYPE_I16:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
                 return ((int16_t *)(tensor->data))[i];
-            } break;
+            }
         case GGML_TYPE_I32:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
                 return ((int32_t *)(tensor->data))[i];
-            } break;
+            }
         case GGML_TYPE_F16:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
                 return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
-            } break;
+            }
         case GGML_TYPE_F32:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(float));
                 return ((float *)(tensor->data))[i];
-            } break;
+            }
         default:
             {
                 GGML_ASSERT(false);
-            } break;
+            }
     }

     return 0.0f;
 }

 void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
+    if (!ggml_is_contiguous(tensor)) {
+        int64_t id[4] = { 0, 0, 0, 0 };
+        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+        ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value);
+        return;
+    }
     switch (tensor->type) {
         case GGML_TYPE_I8:
             {
@@ -5135,43 +5224,104 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
     }
 }

+int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
+    void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            return ((int8_t *) data)[0];
+        case GGML_TYPE_I16:
+            return ((int16_t *) data)[0];
+        case GGML_TYPE_I32:
+            return ((int32_t *) data)[0];
+        case GGML_TYPE_F16:
+            return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
+        case GGML_TYPE_F32:
+            return ((float *) data)[0];
+        default:
+            GGML_ASSERT(false);
+    }
+
+    return 0.0f;
+}
+
+void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) {
+    void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                ((int8_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_I16:
+            {
+                ((int16_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_I32:
+            {
+                ((int32_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ((float *)(data))[0] = value;
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
+    if (!ggml_is_contiguous(tensor)) {
+        int64_t id[4] = { 0, 0, 0, 0 };
+        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+        return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]);
+    }
     switch (tensor->type) {
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
                 return ((int8_t *)(tensor->data))[i];
-            } break;
+            }
         case GGML_TYPE_I16:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
                 return ((int16_t *)(tensor->data))[i];
-            } break;
+            }
         case GGML_TYPE_I32:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
                 return ((int32_t *)(tensor->data))[i];
-            } break;
+            }
         case GGML_TYPE_F16:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
                 return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
-            } break;
+            }
         case GGML_TYPE_F32:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(float));
                 return ((float *)(tensor->data))[i];
-            } break;
+            }
         default:
             {
                 GGML_ASSERT(false);
-            } break;
+            }
     }

     return 0.0f;
 }

 void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
+    if (!ggml_is_contiguous(tensor)) {
+        int64_t id[4] = { 0, 0, 0, 0 };
+        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+        ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value);
+        return;
+    }
     switch (tensor->type) {
         case GGML_TYPE_I8:
             {
@@ -5205,6 +5355,56 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
     }
 }

+float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
+    void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            return ((int8_t *) data)[0];
+        case GGML_TYPE_I16:
+            return ((int16_t *) data)[0];
+        case GGML_TYPE_I32:
+            return ((int32_t *) data)[0];
+        case GGML_TYPE_F16:
+            return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
+        case GGML_TYPE_F32:
+            return ((float *) data)[0];
+        default:
+            GGML_ASSERT(false);
+    }
+
+    return 0.0f;
+}
+
+void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) {
+    void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                ((int8_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_I16:
+            {
+                ((int16_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_I32:
+            {
+                ((int32_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ((float *)(data))[0] = value;
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 void * ggml_get_data(const struct ggml_tensor * tensor) {
     return tensor->data;
 }
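The new *_nd accessors index a tensor by explicit coordinates using its byte strides, while ggml_unravel_index converts a flat index into those coordinates; the 1d getters/setters above now fall back to this path for non-contiguous tensors. A small hedged usage sketch (the helper name is illustrative, not part of ggml):

    #include "ggml.h"

    /* Read one element of a 4-D tensor given a flat index. */
    float get_by_flat_index(const struct ggml_tensor * t, int64_t flat) {
        int64_t i0, i1, i2, i3;
        ggml_unravel_index(t, flat, &i0, &i1, &i2, &i3);   // flat -> (i0,i1,i2,i3)
        return ggml_get_f32_nd(t, (int) i0, (int) i1, (int) i2, (int) i3);
    }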
@@ -5347,6 +5547,44 @@ struct ggml_tensor * ggml_add_inplace(
     return ggml_add_impl(ctx, a, b, true);
 }

+// ggml_add_cast
+
+static struct ggml_tensor * ggml_add_cast_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        enum ggml_type type) {
+    // TODO: support less-strict constraint
+    //       GGML_ASSERT(ggml_can_repeat(b, a));
+    GGML_ASSERT(ggml_can_repeat_rows(b, a));
+    GGML_ASSERT(ggml_is_quantized(a->type)); // currently only supported for quantized input
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        // TODO: support backward pass for broadcasting
+        GGML_ASSERT(ggml_are_same_shape(a, b));
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, type, a->n_dims, a->ne);
+
+    result->op   = GGML_OP_ADD;
+    result->grad = is_node ? ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne) : NULL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_add_cast(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        enum ggml_type type) {
+    return ggml_add_cast_impl(ctx, a, b, type);
+}
+
 // ggml_add1

 static struct ggml_tensor * ggml_add1_impl(
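ggml_add_cast performs an ADD whose result tensor is created with an explicitly chosen type; per the asserts above, a must be quantized and b must be row-repeatable over a. A hedged usage sketch, assuming an F32 result is wanted so the sum is not rounded back through the quantized format (tensor names and the scenario are made up for illustration):

    #include "ggml.h"

    /* Add an F32 delta onto a quantized weight, producing an F32 result. */
    struct ggml_tensor * add_delta_f32(struct ggml_context * ctx,
                                       struct ggml_tensor  * w_quant,    // e.g. GGML_TYPE_Q4_0
                                       struct ggml_tensor  * delta_f32)  // F32, row-repeatable over w_quant
    {
        return ggml_add_cast(ctx, w_quant, delta_f32, GGML_TYPE_F32);
    }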
@@ -5783,7 +6021,6 @@ struct ggml_tensor * ggml_repeat(
     result->op   = GGML_OP_REPEAT;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = b;

     return result;
 }
@@ -5811,7 +6048,6 @@ struct ggml_tensor * ggml_repeat_back(
     result->op   = GGML_OP_REPEAT_BACK;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = b;

     return result;
 }
@@ -6186,8 +6422,9 @@ struct ggml_tensor * ggml_out_prod(
         is_node = true;
     }

-
-
+    // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]
+    const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);

     result->op   = GGML_OP_OUT_PROD;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6406,6 +6643,54 @@ struct ggml_tensor * ggml_cont_inplace(
     return ggml_cont_impl(ctx, a, true);
 }

+
+// make contiguous, with new shape
+GGML_API struct ggml_tensor * ggml_cont_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        int64_t ne0) {
+    return ggml_cont_4d(ctx, a, ne0, 1, 1, 1);
+}
+
+GGML_API struct ggml_tensor * ggml_cont_2d(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        int64_t ne0,
+        int64_t ne1) {
+    return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1);
+}
+
+GGML_API struct ggml_tensor * ggml_cont_3d(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        int64_t ne0,
+        int64_t ne1,
+        int64_t ne2) {
+    return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1);
+}
+
+struct ggml_tensor * ggml_cont_4d(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        int64_t ne0,
+        int64_t ne1,
+        int64_t ne2,
+        int64_t ne3) {
+    GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3));
+
+    bool is_node = false;
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
+    ggml_format_name(result, "%s (cont)", a->name);
+
+    result->op   = GGML_OP_CONT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
+
 // ggml_reshape

 struct ggml_tensor * ggml_reshape(
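ggml_cont_1d/2d/3d/4d combine making a tensor contiguous with giving it a new shape; ggml_cont_4d does the work and asserts the element count is unchanged. A hedged usage sketch (the function name and the 2-D scenario are illustrative):

    #include "ggml.h"

    /* Turn a transposed (non-contiguous) 2-D view into a contiguous tensor
     * with the matching flipped shape. */
    struct ggml_tensor * contiguous_transpose(struct ggml_context * ctx, struct ggml_tensor * w) {
        struct ggml_tensor * wt = ggml_transpose(ctx, w);       // view, not contiguous
        return ggml_cont_2d(ctx, wt, w->ne[1], w->ne[0]);       // copy + reshape in one op
    }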
@@ -6413,7 +6698,7 @@ struct ggml_tensor * ggml_reshape(
         struct ggml_tensor * a,
         struct ggml_tensor * b) {
     GGML_ASSERT(ggml_is_contiguous(a));
-    GGML_ASSERT(ggml_is_contiguous(b));
+    // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous.
     GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));

     bool is_node = false;
@@ -6786,7 +7071,6 @@ struct ggml_tensor * ggml_get_rows_back(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
-    result->src[2] = c;

     return result;
 }
@@ -6968,7 +7252,7 @@ struct ggml_tensor * ggml_soft_max_back_inplace(
 static struct ggml_tensor * ggml_rope_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        int n_past,
+        struct ggml_tensor * b,
         int n_dims,
         int mode,
         int n_ctx,
@@ -6977,7 +7261,10 @@ static struct ggml_tensor * ggml_rope_impl(
         float xpos_base,
         bool xpos_down,
         bool inplace) {
-    GGML_ASSERT(
+    GGML_ASSERT(ggml_is_vector(b));
+    GGML_ASSERT(b->type == GGML_TYPE_I32);
+    GGML_ASSERT(a->ne[2] == b->ne[0]);
+
     bool is_node = false;

     if (a->grad) {
@@ -6986,7 +7273,7 @@ static struct ggml_tensor * ggml_rope_impl(

     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

-    int32_t params[8] = { n_past, n_dims, mode, n_ctx };
+    int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
     memcpy(params + 4, &freq_base, sizeof(float));
     memcpy(params + 5, &freq_scale, sizeof(float));
     memcpy(params + 6, &xpos_base, sizeof(float));
@@ -6996,6 +7283,7 @@ static struct ggml_tensor * ggml_rope_impl(
     result->op   = GGML_OP_ROPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
+    result->src[1] = b;

     return result;
 }
@@ -7003,55 +7291,55 @@ static struct ggml_tensor * ggml_rope_impl(
 struct ggml_tensor * ggml_rope(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        int n_past,
+        struct ggml_tensor * b,
         int n_dims,
         int mode,
         int n_ctx) {
-    return ggml_rope_impl(ctx, a,
+    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
 }

 struct ggml_tensor * ggml_rope_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        int n_past,
+        struct ggml_tensor * b,
         int n_dims,
         int mode,
         int n_ctx) {
-    return ggml_rope_impl(ctx, a,
+    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
 }

 struct ggml_tensor * ggml_rope_custom(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        int n_past,
+        struct ggml_tensor * b,
         int n_dims,
         int mode,
         int n_ctx,
         float freq_base,
         float freq_scale) {
-    return ggml_rope_impl(ctx, a,
+    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
 }

 struct ggml_tensor * ggml_rope_custom_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        int n_past,
+        struct ggml_tensor * b,
         int n_dims,
         int mode,
         int n_ctx,
         float freq_base,
         float freq_scale) {
-    return ggml_rope_impl(ctx, a,
+    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
 }

 struct ggml_tensor * ggml_rope_xpos_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        int n_past,
+        struct ggml_tensor * b,
         int n_dims,
         float base,
         bool down) {
-    return ggml_rope_impl(ctx, a,
+    return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
 }

 // ggml_rope_back
@@ -7059,7 +7347,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
 struct ggml_tensor * ggml_rope_back(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        int n_past,
+        struct ggml_tensor * b,
         int n_dims,
         int mode,
         int n_ctx,
@@ -7067,7 +7355,10 @@ struct ggml_tensor * ggml_rope_back(
         float freq_scale,
         float xpos_base,
         bool xpos_down) {
-    GGML_ASSERT(
+    GGML_ASSERT(ggml_is_vector(b));
+    GGML_ASSERT(b->type == GGML_TYPE_I32);
+    GGML_ASSERT(a->ne[2] == b->ne[0]);
+
     GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");

     bool is_node = false;
@@ -7078,7 +7369,7 @@ struct ggml_tensor * ggml_rope_back(

     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);

-    int32_t params[8] = { n_past, n_dims, mode, n_ctx };
+    int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
     memcpy(params + 4, &freq_base, sizeof(float));
     memcpy(params + 5, &freq_scale, sizeof(float));
     memcpy(params + 6, &xpos_base, sizeof(float));
@@ -7088,6 +7379,7 @@ struct ggml_tensor * ggml_rope_back(
     result->op   = GGML_OP_ROPE_BACK;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
+    result->src[1] = b;

     return result;
 }
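With this release the RoPE operators take an I32 vector of per-token positions (b) instead of a single n_past scalar; ggml_rope_impl asserts that b is an I32 vector with b->ne[0] == a->ne[2]. A hedged sketch of building such a positions tensor (the helper name, shapes and the assumption that the context allocates tensor data are illustrative):

    #include "ggml.h"

    /* Apply RoPE to cur (e.g. [n_rot, n_head, n_tokens, 1]) with positions
     * n_past .. n_past + n_tokens - 1. */
    struct ggml_tensor * rope_with_positions(struct ggml_context * ctx,
                                             struct ggml_tensor  * cur,
                                             int n_past, int n_dims, int mode, int n_ctx) {
        const int64_t n_tokens = cur->ne[2];

        struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
        for (int64_t i = 0; i < n_tokens; ++i) {
            ggml_set_i32_1d(pos, (int) i, n_past + (int32_t) i);   // position of token i
        }

        return ggml_rope(ctx, cur, pos, n_dims, mode, n_ctx);
    }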
@@ -7484,27 +7776,30 @@ struct ggml_tensor * ggml_flash_attn_back(

     // d shape [D,N,ne2,ne3]
     // q shape [D,N,ne2,ne3]
-    // k shape [D,M,
-    // v shape [M,D,
+    // k shape [D,M,kvne2,ne3]
+    // v shape [M,D,kvne2,ne3]

-    const int64_t
-    const int64_t
-    const int64_t
-    const int64_t
-    const int64_t
+    const int64_t     D = q->ne[0];
+    const int64_t     N = q->ne[1];
+    const int64_t     M = k->ne[1];
+    const int64_t   ne2 = q->ne[2];
+    const int64_t   ne3 = q->ne[3];
+    const int64_t kvne2 = k->ne[2];

     GGML_ASSERT(k->ne[0] == D);
     GGML_ASSERT(v->ne[0] == M);
     GGML_ASSERT(v->ne[1] == D);
     GGML_ASSERT(d->ne[0] == D);
     GGML_ASSERT(d->ne[1] == N);
-    GGML_ASSERT(k->ne[2] ==
+    GGML_ASSERT(k->ne[2] == kvne2);
     GGML_ASSERT(k->ne[3] == ne3);
-    GGML_ASSERT(v->ne[2] ==
+    GGML_ASSERT(v->ne[2] == kvne2);
     GGML_ASSERT(v->ne[3] == ne3);
     GGML_ASSERT(d->ne[2] == ne2);
     GGML_ASSERT(d->ne[3] == ne3);

+    GGML_ASSERT(ne2 % kvne2 == 0);
+
     bool is_node = false;

     if (q->grad || k->grad || v->grad) {
@@ -7514,14 +7809,23 @@ struct ggml_tensor * ggml_flash_attn_back(
     }

     // store gradients of q, k and v as continuous tensors concatenated in result.
-    // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
-    // gradq->data = result->data
-    // gradk->data = result->data + nb0*D*N*ne2*ne3
-    // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
     // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
-    int64_t
+    const int64_t elem_q = ggml_nelements(q);
+    const int64_t elem_k = ggml_nelements(k);
+    const int64_t elem_v = ggml_nelements(v);

-
+    enum ggml_type result_type = GGML_TYPE_F32;
+    GGML_ASSERT(ggml_blck_size(result_type) == 1);
+    const size_t tsize = ggml_type_size(result_type);
+
+    const size_t offs_q = 0;
+    const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
+    const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
+    const size_t end    = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN);
+
+    const size_t nelements = (end + tsize - 1)/tsize;
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements);

     int32_t masked_i = masked ? 1 : 0;
     ggml_set_op_params(result, &masked_i, sizeof(masked_i));
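The gradients of q, k and v are now packed into one flat F32 tensor at alignment-padded offsets. A standalone sketch of the same size arithmetic (the PAD macro mirrors ggml's GGML_PAD rounding; the alignment value and element counts are made-up stand-ins):

    #include <stdio.h>

    #define PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))   // round x up to a multiple of n

    int main(void) {
        const size_t align  = 16;            // stand-in for ggml.c's internal GGML_MEM_ALIGN
        const size_t tsize  = sizeof(float); // result_type is GGML_TYPE_F32, block size 1
        const size_t elem_q = 64*32, elem_k = 64*128, elem_v = 128*64; // made-up element counts

        const size_t offs_q = 0;
        const size_t offs_k = offs_q + PAD(elem_q * tsize, align);
        const size_t offs_v = offs_k + PAD(elem_k * tsize, align);
        const size_t end    = offs_v + PAD(elem_v * tsize, align);

        // number of F32 elements backing the concatenated q/k/v gradients
        printf("offs_k=%zu offs_v=%zu nelements=%zu\n", offs_k, offs_v, (end + tsize - 1)/tsize);
        return 0;
    }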
@@ -8214,7 +8518,7 @@ static void ggml_compute_forward_dup_f16(
         return;
     }

-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS

     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
@@ -8485,7 +8789,7 @@ static void ggml_compute_forward_dup_f32(
         return;
     }

-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS

     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
@@ -8766,7 +9070,7 @@ static void ggml_compute_forward_add_f32(

     const int nr = ggml_nrows(src0);

-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS

     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
@@ -8798,8 +9102,6 @@ static void ggml_compute_forward_add_f32(
 #else
             ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
 #endif
-                // }
-            // }
         }
     } else {
         // src1 is not contiguous
@@ -8841,7 +9143,7 @@ static void ggml_compute_forward_add_f16_f32(

     const int nr = ggml_nrows(src0);

-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS

     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -8895,7 +9197,7 @@ static void ggml_compute_forward_add_f16_f16(

     const int nr = ggml_nrows(src0);

-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS

     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F16);
@@ -8946,14 +9248,15 @@ static void ggml_compute_forward_add_q_f32(

     const int nr = ggml_nrows(src0);

-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS

     const int ith = params->ith;
     const int nth = params->nth;

     const enum ggml_type type = src0->type;
+    const enum ggml_type dtype = dst->type;
     ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
-    ggml_from_float_t const quantize_row_q = type_traits[
+    ggml_from_float_t const quantize_row_q = type_traits[dtype].from_float;

     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == ggml_type_size(type));
@@ -8965,7 +9268,6 @@ static void ggml_compute_forward_add_q_f32(
     GGML_ASSERT(nb2 <= nb3);

     GGML_ASSERT(ggml_is_quantized(src0->type));
-    GGML_ASSERT(dst->type == src0->type);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);

     // rows per thread
@@ -9003,7 +9305,11 @@ static void ggml_compute_forward_add_q_f32(
         // add src1
         ggml_vec_acc_f32(ne00, wdata, src1_row);
         // quantize row to dst
-        quantize_row_q(wdata, dst_row, ne00);
+        if (quantize_row_q != NULL) {
+            quantize_row_q(wdata, dst_row, ne00);
+        } else {
+            memcpy(dst_row, wdata, ne0*nb0);
+        }
     }
 }

@@ -9068,7 +9374,7 @@ static void ggml_compute_forward_add1_f32(

     const int nr = ggml_nrows(src0);

-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS

     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
@@ -9123,7 +9429,7 @@ static void ggml_compute_forward_add1_f16_f32(

     const int nr = ggml_nrows(src0);

-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS

     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -9173,7 +9479,7 @@ static void ggml_compute_forward_add1_f16_f16(

     const int nr = ggml_nrows(src0);

-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS

     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F16);
@@ -9223,7 +9529,7 @@ static void ggml_compute_forward_add1_q_f32(

     const int nr = ggml_nrows(src0);

-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS

     const enum ggml_type type = src0->type;
     ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
@@ -9351,8 +9657,8 @@ static void ggml_compute_forward_acc_f32(
     const int nr = ggml_nrows(src1);
     const int nc = src1->ne[0];

-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
-    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb);
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)

     // src0 and dst as viewed during acc
     const size_t nb0 = ggml_element_size(src0);
@@ -9441,7 +9747,7 @@ static void ggml_compute_forward_sub_f32(

     const int nr = ggml_nrows(src0);

-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS

     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
@@ -9531,7 +9837,7 @@ static void ggml_compute_forward_mul_f32(

     const int64_t nr = ggml_nrows(src0);

-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS

     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
@@ -9622,7 +9928,7 @@ static void ggml_compute_forward_div_f32(

     const int nr = ggml_nrows(src0);

-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS

     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
@@ -9831,8 +10137,8 @@ static void ggml_compute_forward_sum_f32(
     assert(ggml_is_scalar(dst));
     assert(src0->nb[0] == sizeof(float));

-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb);
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)

     ggml_float sum     = 0;
     ggml_float row_sum = 0;
@@ -9863,8 +10169,8 @@ static void ggml_compute_forward_sum_f16(

     assert(src0->nb[0] == sizeof(ggml_fp16_t));

-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb);
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)

     float sum = 0;
     float row_sum = 0;
@@ -9917,7 +10223,7 @@ static void ggml_compute_forward_sum_rows_f32(
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(dst->nb[0] == sizeof(float));

-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS

     GGML_ASSERT(ne0 == 1);
     GGML_ASSERT(ne1 == ne01);
@@ -9967,7 +10273,7 @@ static void ggml_compute_forward_mean_f32(

     assert(src0->nb[0] == sizeof(float));

-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS

     assert(ne0 == 1);
     assert(ne1 == ne01);
@@ -10067,7 +10373,7 @@ static void ggml_compute_forward_repeat_f32(
         return;
     }

-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS

     // guaranteed to be an integer due to the check in ggml_can_repeat
     const int nr0 = (int)(ne0/ne00);
@@ -10099,11 +10405,61 @@ static void ggml_compute_forward_repeat_f32(
         }
     }
 }

+static void ggml_compute_forward_repeat_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+              struct ggml_tensor * dst) {
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_can_repeat(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne0/ne00);
+    const int nr1 = (int)(ne1/ne01);
+    const int nr2 = (int)(ne2/ne02);
+    const int nr3 = (int)(ne3/ne03);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // TODO: maybe this is not optimal?
+    for (int i3 = 0; i3 < nr3; i3++) {
+        for (int k3 = 0; k3 < ne03; k3++) {
+            for (int i2 = 0; i2 < nr2; i2++) {
+                for (int k2 = 0; k2 < ne02; k2++) {
+                    for (int i1 = 0; i1 < nr1; i1++) {
+                        for (int k1 = 0; k1 < ne01; k1++) {
+                            for (int i0 = 0; i0 < nr0; i0++) {
+                                ggml_fp16_t * y = (ggml_fp16_t *) ((char *)  dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
+                                ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + (k3)*nb03 + (k2)*nb02 + (k1)*nb01);
+                                // ggml_vec_cpy_f16(ne00, y, x)
+                                for (int i = 0; i < ne00; ++i) {
+                                    y[i] = x[i];
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
 static void ggml_compute_forward_repeat(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
     switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_repeat_f16(params, src0, dst);
+            } break;
         case GGML_TYPE_F32:
             {
                 ggml_compute_forward_repeat_f32(params, src0, dst);
@@ -10128,7 +10484,7 @@ static void ggml_compute_forward_repeat_back_f32(
         return;
     }

-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS

     // guaranteed to be an integer due to the check in ggml_can_repeat
     const int nr0 = (int)(ne00/ne0);
@@ -10206,7 +10562,7 @@ static void ggml_compute_forward_concat_f32(

     const int ith = params->ith;

-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS

     // TODO: support for transposed / permuted tensors
     GGML_ASSERT(nb0 == sizeof(float));
@@ -10808,7 +11164,7 @@ static void ggml_compute_forward_norm_f32(
     const int ith = params->ith;
     const int nth = params->nth;

-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS

     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
@@ -10877,7 +11233,7 @@ static void ggml_compute_forward_rms_norm_f32(
     const int ith = params->ith;
     const int nth = params->nth;

-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS

     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
@@ -10942,7 +11298,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
     const int ith = params->ith;
     const int nth = params->nth;

-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS

     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
@@ -11117,7 +11473,7 @@ static void ggml_compute_forward_group_norm_f32(
     const int ith = params->ith;
     const int nth = params->nth;

-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS

     const float eps = 1e-6f; // TODO: make this a parameter

@@ -11228,7 +11584,7 @@ static void ggml_compute_forward_mul_mat(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);

-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS

     const int ith = params->ith;
     const int nth = params->nth;
@@ -11443,10 +11799,10 @@ static void ggml_compute_forward_out_prod_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
               struct ggml_tensor * dst) {
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
+    // int64_t t0 = ggml_perf_time_us();
+    // UNUSED(t0);

-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS

     const int ith = params->ith;
     const int nth = params->nth;
@@ -11485,6 +11841,146 @@ static void ggml_compute_forward_out_prod_f32(
         return;
     }

+    // dst[:,:,:,:] = 0
+    // for i2,i3:
+    //   for i1:
+    //     for i01:
+    //       for i0:
+    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
+    // parallelize by last three dimensions
+
+    // total rows in dst
+    const int64_t nr = ne1*ne2*ne3;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    // block-tiling attempt
+    const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
+    const int64_t blck_1 = 16;
+
+    for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
+        const int64_t bir1 = MIN(bir + blck_1, ir1);
+        for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
+            const int64_t bne01 = MIN(bi01 + blck_0, ne01);
+            for (int64_t ir = bir; ir < bir1; ++ir) {
+                // dst indices
+                const int64_t i3 = ir/(ne2*ne1);
+                const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+                const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                const int64_t i02 = i2;
+                const int64_t i03 = i3;
+
+                //const int64_t i10 = i1;
+                const int64_t i12 = i2;
+                const int64_t i13 = i3;
+
+#if GGML_VEC_MAD_UNROLL > 2
+                const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
+                for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
+                    const int64_t i11 = i01;
+
+                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+
+                    ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
+                }
+                for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
+                    const int64_t i11 = i01;
+
+                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+
+                    ggml_vec_mad_f32(ne0, d, s0, *s1);
+                }
+#else
+                for (int64_t i01 = bi01; i01 < bne01; ++i01) {
+                    const int64_t i11 = i01;
+
+                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+
+                    ggml_vec_mad_f32(ne0, d, s0, *s1);
+                }
+#endif
+            }
+        }
+    }
+
+
+    //int64_t t1 = ggml_perf_time_us();
+    //static int64_t acc = 0;
+    //acc += t1 - t0;
+    //if (t1 - t0 > 10) {
+    //    printf("\n");
+    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
+
+    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+    //}
+}
+
+static void ggml_compute_forward_out_prod_q_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    // int64_t t0 = ggml_perf_time_us();
+    // UNUSED(t0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const enum ggml_type type = src0->type;
+    ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    // we don't support permuted src0 dim0
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+
+    // dst dim0 cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+    // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
+    // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+
+    if (params->type == GGML_TASK_INIT) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
     // parallelize by last three dimensions

     // total rows in dst
@@ -11504,6 +12000,8 @@ static void ggml_compute_forward_out_prod_f32(
     //   for i0:
     //     dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]

+    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
+
     for (int64_t ir = ir0; ir < ir1; ++ir) {
         // dst indices
         const int64_t i3 = ir/(ne2*ne1);
@@ -11524,10 +12022,8 @@ static void ggml_compute_forward_out_prod_f32(
         float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
         float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));

-
-
-        //     d[i0] += s0[i0] * s1[i1];
-        // }
+        dequantize_row_q(s0, wdata, ne0);
+        ggml_vec_mad_f32(ne0, d, wdata, *s1);
     }
 }

@@ -11556,10 +12052,13 @@ static void ggml_compute_forward_out_prod(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
-        case
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
             {
-
-                // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
+                ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F16:
             {
@@ -11677,8 +12176,8 @@ static void ggml_compute_forward_set_f32(
     const int nr = ggml_nrows(src1);
     const int nc = src1->ne[0];

-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
-    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb);
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)

     // src0 and dst as viewed during set
     const size_t nb0 = ggml_element_size(src0);
@@ -11947,14 +12446,15 @@ static void ggml_compute_forward_get_rows_back_f32_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
               struct ggml_tensor * dst) {
     GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_are_same_shape(opt0, dst));
-    GGML_ASSERT(ggml_is_contiguous(opt0));
     GGML_ASSERT(ggml_is_contiguous(dst));

-    ggml_compute_forward_dup_same_cont(params, opt0, dst);
+    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+    if (params->type == GGML_TASK_INIT) {
+        memset(dst->data, 0, ggml_nbytes(dst));
+    }

     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -11980,11 +12480,8 @@ static void ggml_compute_forward_get_rows_back_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
               struct ggml_tensor * dst) {
     GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_are_same_shape(opt0, dst));
-    GGML_ASSERT(ggml_is_contiguous(opt0));
     GGML_ASSERT(ggml_is_contiguous(dst));

     // ggml_compute_forward_dup_same_cont(params, opt0, dst);
@@ -12018,16 +12515,15 @@ static void ggml_compute_forward_get_rows_back(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
               struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, opt0, dst);
+                ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_get_rows_back_f32(params, src0, src1, opt0, dst);
+                ggml_compute_forward_get_rows_back_f32(params, src0, src1, dst);
             } break;
         default:
             {
@@ -12068,7 +12564,7 @@ static void ggml_compute_forward_diag_f32(

     // TODO: handle transposed/permuted matrices

-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS

     GGML_ASSERT(ne00 == ne0);
     GGML_ASSERT(ne00 == ne1);
@@ -12456,13 +12952,11 @@ static void ggml_compute_forward_alibi_f16(
         return;
     }

-    const int n_past = ((int32_t *) dst->op_params)[0];
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_head = ((int32_t *) dst->op_params)[1];
     float max_bias;
     memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

-    assert(n_past >= 0);
-
     const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
     const int ne1 = src0->ne[1]; // seq_len_without_past
     const int ne2 = src0->ne[2]; // n_head -> this is k
@@ -12477,7 +12971,7 @@ static void ggml_compute_forward_alibi_f16(
     //const int nb3 = src0->nb[3];

     GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
+    //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
     GGML_ASSERT(n_head == ne2);

     // add alibi to src0 (KQ_scaled)
@@ -12623,8 +13117,8 @@ static void ggml_compute_forward_clamp(
|
|
12623
13117
|
static void ggml_compute_forward_rope_f32(
|
12624
13118
|
const struct ggml_compute_params * params,
|
12625
13119
|
const struct ggml_tensor * src0,
|
13120
|
+
const struct ggml_tensor * src1,
|
12626
13121
|
struct ggml_tensor * dst) {
|
12627
|
-
|
12628
13122
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
12629
13123
|
return;
|
12630
13124
|
}
|
@@ -12634,9 +13128,9 @@ static void ggml_compute_forward_rope_f32(
|
|
12634
13128
|
|
12635
13129
|
// these two only relevant for xPos RoPE:
|
12636
13130
|
float xpos_base;
|
12637
|
-
bool  xpos_down;
|
13131
|
+
bool xpos_down;
|
12638
13132
|
|
12639
|
-
const int n_past = ((int32_t *) dst->op_params)[0];
|
13133
|
+
//const int n_past = ((int32_t *) dst->op_params)[0];
|
12640
13134
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
12641
13135
|
const int mode = ((int32_t *) dst->op_params)[2];
|
12642
13136
|
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
@@ -12645,9 +13139,7 @@ static void ggml_compute_forward_rope_f32(
|
|
12645
13139
|
memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
|
12646
13140
|
memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
|
12647
13141
|
|
12648
|
-
|
12649
|
-
|
12650
|
-
GGML_TENSOR_UNARY_OP_LOCALS;
|
13142
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
12651
13143
|
|
12652
13144
|
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
12653
13145
|
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
@@ -12677,9 +13169,11 @@ static void ggml_compute_forward_rope_f32(
|
|
12677
13169
|
const bool is_neox = mode & 2;
|
12678
13170
|
const bool is_glm = mode & 4;
|
12679
13171
|
|
13172
|
+
const int32_t * pos = (const int32_t *) src1->data;
|
13173
|
+
|
12680
13174
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
12681
|
-
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
12682
|
-
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
13175
|
+
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
13176
|
+
const int64_t p = pos[i2];
|
12683
13177
|
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
12684
13178
|
if (ir++ < ir0) continue;
|
12685
13179
|
if (ir > ir1) break;
|
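Note on the RoPE hunks: the kernel now reads one int32 position per row from src1 (pos[i2]) instead of deriving positions from a single n_past. A minimal sketch of the per-position rotation it applies, assuming the usual theta = p * freq_base^(-2i/n_dims) schedule and covering only the non-neox path (names here are hypothetical, this is not the ggml kernel):

    #include <math.h>
    #include <stdio.h>

    // Rotate consecutive pairs (x[i0], x[i0+1]) of one row by angles derived
    // from an explicit token position p, mirroring the per-i2 work above.
    static void rope_row(float *x, int n_dims, int pos,
                         float freq_base, float freq_scale) {
        float theta = freq_scale * (float) pos;
        const float theta_scale = powf(freq_base, -2.0f / n_dims);
        for (int i0 = 0; i0 < n_dims; i0 += 2) {
            const float c  = cosf(theta);
            const float s  = sinf(theta);
            const float x0 = x[i0];
            const float x1 = x[i0 + 1];
            x[i0]     = x0 * c - x1 * s;
            x[i0 + 1] = x0 * s + x1 * c;
            theta *= theta_scale;
        }
    }

    int main(void) {
        float row[4] = {1.0f, 0.0f, 1.0f, 0.0f};
        const int pos[3] = {0, 1, 2};   // stands in for the int32 src1 tensor
        rope_row(row, 4, pos[2], 10000.0f, 1.0f);
        printf("%f %f %f %f\n", row[0], row[1], row[2], row[3]);
        return 0;
    }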
@@ -12716,7 +13210,7 @@ static void ggml_compute_forward_rope_f32(
|
|
12716
13210
|
const float cos_theta = cosf(theta);
|
12717
13211
|
const float sin_theta = sinf(theta);
|
12718
13212
|
// zeta scaling for xPos only:
|
12719
|
-
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
|
13213
|
+
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
|
12720
13214
|
if (xpos_down) zeta = 1.0f / zeta;
|
12721
13215
|
|
12722
13216
|
theta *= theta_scale;
|
@@ -12761,8 +13255,8 @@ static void ggml_compute_forward_rope_f32(
|
|
12761
13255
|
static void ggml_compute_forward_rope_f16(
|
12762
13256
|
const struct ggml_compute_params * params,
|
12763
13257
|
const struct ggml_tensor * src0,
|
13258
|
+
const struct ggml_tensor * src1,
|
12764
13259
|
struct ggml_tensor * dst) {
|
12765
|
-
|
12766
13260
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
12767
13261
|
return;
|
12768
13262
|
}
|
@@ -12770,16 +13264,14 @@ static void ggml_compute_forward_rope_f16(
|
|
12770
13264
|
float freq_base;
|
12771
13265
|
float freq_scale;
|
12772
13266
|
|
12773
|
-
const int n_past = ((int32_t *) dst->op_params)[0];
|
13267
|
+
//const int n_past = ((int32_t *) dst->op_params)[0];
|
12774
13268
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
12775
13269
|
const int mode = ((int32_t *) dst->op_params)[2];
|
12776
13270
|
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
12777
13271
|
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
12778
13272
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
12779
13273
|
|
12780
|
-
|
12781
|
-
|
12782
|
-
GGML_TENSOR_UNARY_OP_LOCALS;
|
13274
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
12783
13275
|
|
12784
13276
|
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
12785
13277
|
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
@@ -12809,9 +13301,11 @@ static void ggml_compute_forward_rope_f16(
|
|
12809
13301
|
const bool is_neox = mode & 2;
|
12810
13302
|
const bool is_glm = mode & 4;
|
12811
13303
|
|
13304
|
+
const int32_t * pos = (const int32_t *) src1->data;
|
13305
|
+
|
12812
13306
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
12813
|
-
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
12814
|
-
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
13307
|
+
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
13308
|
+
const int64_t p = pos[i2];
|
12815
13309
|
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
12816
13310
|
if (ir++ < ir0) continue;
|
12817
13311
|
if (ir > ir1) break;
|
@@ -12890,15 +13384,16 @@ static void ggml_compute_forward_rope_f16(
|
|
12890
13384
|
static void ggml_compute_forward_rope(
|
12891
13385
|
const struct ggml_compute_params * params,
|
12892
13386
|
const struct ggml_tensor * src0,
|
13387
|
+
const struct ggml_tensor * src1,
|
12893
13388
|
struct ggml_tensor * dst) {
|
12894
13389
|
switch (src0->type) {
|
12895
13390
|
case GGML_TYPE_F16:
|
12896
13391
|
{
|
12897
|
-
ggml_compute_forward_rope_f16(params, src0, dst);
|
13392
|
+
ggml_compute_forward_rope_f16(params, src0, src1, dst);
|
12898
13393
|
} break;
|
12899
13394
|
case GGML_TYPE_F32:
|
12900
13395
|
{
|
12901
|
-
ggml_compute_forward_rope_f32(params, src0, dst);
|
13396
|
+
ggml_compute_forward_rope_f32(params, src0, src1, dst);
|
12902
13397
|
} break;
|
12903
13398
|
default:
|
12904
13399
|
{
|
@@ -12912,6 +13407,7 @@ static void ggml_compute_forward_rope(
|
|
12912
13407
|
static void ggml_compute_forward_rope_back_f32(
|
12913
13408
|
const struct ggml_compute_params * params,
|
12914
13409
|
const struct ggml_tensor * src0,
|
13410
|
+
const struct ggml_tensor * src1,
|
12915
13411
|
struct ggml_tensor * dst) {
|
12916
13412
|
|
12917
13413
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
@@ -12929,7 +13425,7 @@ static void ggml_compute_forward_rope_back_f32(
|
|
12929
13425
|
float xpos_base;
|
12930
13426
|
bool xpos_down;
|
12931
13427
|
|
12932
|
-
const int n_past = ((int32_t *) dst->op_params)[0];
|
13428
|
+
//const int n_past = ((int32_t *) dst->op_params)[0];
|
12933
13429
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
12934
13430
|
const int mode = ((int32_t *) dst->op_params)[2];
|
12935
13431
|
const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx);
|
@@ -12938,9 +13434,7 @@ static void ggml_compute_forward_rope_back_f32(
|
|
12938
13434
|
memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
|
12939
13435
|
memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
|
12940
13436
|
|
12941
|
-
|
12942
|
-
|
12943
|
-
GGML_TENSOR_UNARY_OP_LOCALS;
|
13437
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
12944
13438
|
|
12945
13439
|
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
12946
13440
|
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
@@ -12966,9 +13460,11 @@ static void ggml_compute_forward_rope_back_f32(
|
|
12966
13460
|
|
12967
13461
|
const bool is_neox = mode & 2;
|
12968
13462
|
|
13463
|
+
const int32_t * pos = (const int32_t *) src1->data;
|
13464
|
+
|
12969
13465
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
12970
|
-
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
12971
|
-
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
13466
|
+
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
13467
|
+
const int64_t p = pos[i2];
|
12972
13468
|
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
12973
13469
|
if (ir++ < ir0) continue;
|
12974
13470
|
if (ir > ir1) break;
|
@@ -12980,7 +13476,7 @@ static void ggml_compute_forward_rope_back_f32(
|
|
12980
13476
|
const float cos_theta = cosf(theta);
|
12981
13477
|
const float sin_theta = sinf(theta);
|
12982
13478
|
// zeta scaling for xPos only:
|
12983
|
-
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
|
13479
|
+
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
|
12984
13480
|
if (xpos_down) zeta = 1.0f / zeta;
|
12985
13481
|
|
12986
13482
|
theta *= theta_scale;
|
@@ -13023,6 +13519,7 @@ static void ggml_compute_forward_rope_back_f32(
|
|
13023
13519
|
static void ggml_compute_forward_rope_back_f16(
|
13024
13520
|
const struct ggml_compute_params * params,
|
13025
13521
|
const struct ggml_tensor * src0,
|
13522
|
+
const struct ggml_tensor * src1,
|
13026
13523
|
struct ggml_tensor * dst) {
|
13027
13524
|
|
13028
13525
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
@@ -13033,13 +13530,11 @@ static void ggml_compute_forward_rope_back_f16(
|
|
13033
13530
|
// dx = rope_back(dy, src1)
|
13034
13531
|
// src0 is dy, src1 contains options
|
13035
13532
|
|
13036
|
-
const int n_past = ((int32_t *) dst->op_params)[0];
|
13533
|
+
//const int n_past = ((int32_t *) dst->op_params)[0];
|
13037
13534
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
13038
13535
|
const int mode = ((int32_t *) dst->op_params)[2];
|
13039
13536
|
|
13040
|
-
|
13041
|
-
|
13042
|
-
GGML_TENSOR_UNARY_OP_LOCALS;
|
13537
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
13043
13538
|
|
13044
13539
|
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
13045
13540
|
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
@@ -13065,9 +13560,11 @@ static void ggml_compute_forward_rope_back_f16(
|
|
13065
13560
|
|
13066
13561
|
const bool is_neox = mode & 2;
|
13067
13562
|
|
13563
|
+
const int32_t * pos = (const int32_t *) src1->data;
|
13564
|
+
|
13068
13565
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
13069
|
-
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
13070
|
-
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
13566
|
+
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
13567
|
+
const int64_t p = pos[i2];
|
13071
13568
|
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
13072
13569
|
if (ir++ < ir0) continue;
|
13073
13570
|
if (ir > ir1) break;
|
@@ -13119,15 +13616,16 @@ static void ggml_compute_forward_rope_back_f16(
|
|
13119
13616
|
static void ggml_compute_forward_rope_back(
|
13120
13617
|
const struct ggml_compute_params * params,
|
13121
13618
|
const struct ggml_tensor * src0,
|
13619
|
+
const struct ggml_tensor * src1,
|
13122
13620
|
struct ggml_tensor * dst) {
|
13123
13621
|
switch (src0->type) {
|
13124
13622
|
case GGML_TYPE_F16:
|
13125
13623
|
{
|
13126
|
-
ggml_compute_forward_rope_back_f16(params, src0, dst);
|
13624
|
+
ggml_compute_forward_rope_back_f16(params, src0, src1, dst);
|
13127
13625
|
} break;
|
13128
13626
|
case GGML_TYPE_F32:
|
13129
13627
|
{
|
13130
|
-
ggml_compute_forward_rope_back_f32(params, src0, dst);
|
13628
|
+
ggml_compute_forward_rope_back_f32(params, src0, src1, dst);
|
13131
13629
|
} break;
|
13132
13630
|
default:
|
13133
13631
|
{
|
@@ -13150,7 +13648,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
|
|
13150
13648
|
int64_t t0 = ggml_perf_time_us();
|
13151
13649
|
UNUSED(t0);
|
13152
13650
|
|
13153
|
-
GGML_TENSOR_BINARY_OP_LOCALS
|
13651
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
13154
13652
|
|
13155
13653
|
const int ith = params->ith;
|
13156
13654
|
const int nth = params->nth;
|
@@ -13241,7 +13739,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32(
|
|
13241
13739
|
int64_t t0 = ggml_perf_time_us();
|
13242
13740
|
UNUSED(t0);
|
13243
13741
|
|
13244
|
-
GGML_TENSOR_BINARY_OP_LOCALS
|
13742
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
13245
13743
|
|
13246
13744
|
const int ith = params->ith;
|
13247
13745
|
const int nth = params->nth;
|
@@ -13353,7 +13851,7 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
|
|
13353
13851
|
int64_t t0 = ggml_perf_time_us();
|
13354
13852
|
UNUSED(t0);
|
13355
13853
|
|
13356
|
-
GGML_TENSOR_BINARY_OP_LOCALS
|
13854
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
13357
13855
|
|
13358
13856
|
const int ith = params->ith;
|
13359
13857
|
const int nth = params->nth;
|
@@ -13444,7 +13942,7 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
|
|
13444
13942
|
int64_t t0 = ggml_perf_time_us();
|
13445
13943
|
UNUSED(t0);
|
13446
13944
|
|
13447
|
-
GGML_TENSOR_BINARY_OP_LOCALS
|
13945
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
13448
13946
|
|
13449
13947
|
const int ith = params->ith;
|
13450
13948
|
const int nth = params->nth;
|
@@ -13562,7 +14060,7 @@ static void ggml_compute_forward_conv_1d(
|
|
13562
14060
|
ggml_compute_forward_conv_1d_s2_ph(params, src0, src1, dst);
|
13563
14061
|
} else {
|
13564
14062
|
GGML_ASSERT(false); // only stride 1 and 2 supported
|
13565
|
-
}
|
14063
|
+
}
|
13566
14064
|
}
|
13567
14065
|
|
13568
14066
|
// ggml_compute_forward_conv_2d
|
@@ -13579,7 +14077,7 @@ static void ggml_compute_forward_conv_2d_f16_f32(
|
|
13579
14077
|
int64_t t0 = ggml_perf_time_us();
|
13580
14078
|
UNUSED(t0);
|
13581
14079
|
|
13582
|
-
GGML_TENSOR_BINARY_OP_LOCALS
|
14080
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
13583
14081
|
|
13584
14082
|
const int ith = params->ith;
|
13585
14083
|
const int nth = params->nth;
|
@@ -13699,7 +14197,7 @@ static void ggml_compute_forward_conv_transpose_2d(
|
|
13699
14197
|
int64_t t0 = ggml_perf_time_us();
|
13700
14198
|
UNUSED(t0);
|
13701
14199
|
|
13702
|
-
GGML_TENSOR_BINARY_OP_LOCALS
|
14200
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
13703
14201
|
|
13704
14202
|
const int ith = params->ith;
|
13705
14203
|
const int nth = params->nth;
|
@@ -13958,7 +14456,7 @@ static void ggml_compute_forward_upscale_f32(
|
|
13958
14456
|
|
13959
14457
|
const int ith = params->ith;
|
13960
14458
|
|
13961
|
-
GGML_TENSOR_UNARY_OP_LOCALS
|
14459
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
13962
14460
|
|
13963
14461
|
const int scale_factor = dst->op_params[0];
|
13964
14462
|
|
@@ -14010,14 +14508,14 @@ static void ggml_compute_forward_flash_attn_f32(
|
|
14010
14508
|
int64_t t0 = ggml_perf_time_us();
|
14011
14509
|
UNUSED(t0);
|
14012
14510
|
|
14013
|
-
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
14014
|
-
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
14015
|
-
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
14016
|
-
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
14017
|
-
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
14018
|
-
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
14019
|
-
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
14020
|
-
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
14511
|
+
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
14512
|
+
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
14513
|
+
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
14514
|
+
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
14515
|
+
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
14516
|
+
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
14517
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
14518
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
14021
14519
|
|
14022
14520
|
const int ith = params->ith;
|
14023
14521
|
const int nth = params->nth;
|
@@ -14087,10 +14585,11 @@ static void ggml_compute_forward_flash_attn_f32(
|
|
14087
14585
|
S[i] = -INFINITY;
|
14088
14586
|
}
|
14089
14587
|
|
14090
|
-
|
14588
|
+
const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
|
14589
|
+
for (int64_t ic = 0; ic < masked_begin; ++ic) {
|
14091
14590
|
// k indices
|
14092
14591
|
const int ik3 = iq3;
|
14093
|
-
const int ik2 = iq2;
|
14592
|
+
const int ik2 = iq2 % nek2;
|
14094
14593
|
const int ik1 = ic;
|
14095
14594
|
|
14096
14595
|
// S indices
|
@@ -14103,20 +14602,18 @@ static void ggml_compute_forward_flash_attn_f32(
|
|
14103
14602
|
}
|
14104
14603
|
|
14105
14604
|
// scale
|
14106
|
-
ggml_vec_scale_f32(nek1, S, scale);
|
14605
|
+
ggml_vec_scale_f32(masked_begin, S, scale);
|
14107
14606
|
|
14108
|
-
|
14109
|
-
|
14110
|
-
if (i > P + iq1) {
|
14111
|
-
S[i] = -INFINITY;
|
14112
|
-
}
|
14113
|
-
}
|
14607
|
+
for (int64_t i = masked_begin; i < M; i++) {
|
14608
|
+
S[i] = -INFINITY;
|
14114
14609
|
}
|
14115
14610
|
|
14116
14611
|
// softmax
|
14612
|
+
// exclude known -INF S[..] values from max and loop
|
14613
|
+
// dont forget to set their SW values to zero
|
14117
14614
|
{
|
14118
14615
|
float max = -INFINITY;
|
14119
|
-
ggml_vec_max_f32(M, &max, S);
|
14616
|
+
ggml_vec_max_f32(masked_begin, &max, S);
|
14120
14617
|
|
14121
14618
|
ggml_float sum = 0.0;
|
14122
14619
|
{
|
@@ -14130,10 +14627,15 @@ static void ggml_compute_forward_flash_attn_f32(
|
|
14130
14627
|
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
|
14131
14628
|
|
14132
14629
|
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
|
14630
|
+
if (i >= masked_begin) {
|
14631
|
+
break;
|
14632
|
+
}
|
14133
14633
|
float * SS = S + i;
|
14134
14634
|
|
14135
14635
|
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
|
14136
|
-
if (SS[j] == -INFINITY) {
|
14636
|
+
if (i + j >= masked_begin) {
|
14637
|
+
break;
|
14638
|
+
} else if (SS[j] == -INFINITY) {
|
14137
14639
|
SS[j] = 0.0f;
|
14138
14640
|
} else {
|
14139
14641
|
#ifndef GGML_FLASH_ATTN_EXP_FP16
|
@@ -14158,10 +14660,10 @@ static void ggml_compute_forward_flash_attn_f32(
|
|
14158
14660
|
assert(sum > 0.0);
|
14159
14661
|
|
14160
14662
|
sum = 1.0/sum;
|
14161
|
-
ggml_vec_scale_f32(M, S, sum);
|
14663
|
+
ggml_vec_scale_f32(masked_begin, S, sum);
|
14162
14664
|
|
14163
14665
|
#ifndef NDEBUG
|
14164
|
-
for (int i = 0; i < M; ++i) {
|
14666
|
+
for (int i = 0; i < masked_begin; ++i) {
|
14165
14667
|
assert(!isnan(S[i]));
|
14166
14668
|
assert(!isinf(S[i]));
|
14167
14669
|
}
|
@@ -14174,9 +14676,13 @@ static void ggml_compute_forward_flash_attn_f32(
|
|
14174
14676
|
const int i2 = iq2;
|
14175
14677
|
const int i3 = iq3;
|
14176
14678
|
|
14177
|
-
|
14178
|
-
|
14179
|
-
|
14679
|
+
// v indices
|
14680
|
+
const int iv2 = iq2 % nev2;
|
14681
|
+
const int iv3 = iq3;
|
14682
|
+
|
14683
|
+
ggml_vec_dot_f32(masked_begin,
|
14684
|
+
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
14685
|
+
(float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
14180
14686
|
S);
|
14181
14687
|
}
|
14182
14688
|
}
|
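Note on the f32 flash-attention hunks above: the post-hoc -INF masking loop is replaced by a masked_begin bound, so the dot products, the scaling and the softmax all stop at the last column a causal query may attend to. A self-contained sketch of that idea for a single score row (plain scalar C, hypothetical names, not the ggml code):

    #include <math.h>
    #include <stdio.h>

    // For query row iq1 with P past tokens and M total key columns, only
    // columns [0, masked_begin) can attend under a causal mask.
    static void masked_softmax_row(float *S, int M, int P, int iq1, int masked) {
        const int masked_begin = masked ? (P + iq1 + 1) : M;

        float max = -INFINITY;
        for (int i = 0; i < masked_begin; ++i) {
            if (S[i] > max) max = S[i];
        }
        float sum = 0.0f;
        for (int i = 0; i < masked_begin; ++i) {
            S[i] = expf(S[i] - max);
            sum += S[i];
        }
        for (int i = 0; i < masked_begin; ++i) {
            S[i] /= sum;
        }
        // everything at or past masked_begin contributes nothing
        for (int i = masked_begin; i < M; ++i) {
            S[i] = 0.0f;
        }
    }

    int main(void) {
        float S[4] = {0.5f, 1.5f, 9.0f, 9.0f};
        masked_softmax_row(S, 4, 1, 0, /*masked=*/1);  // only columns 0..1 survive
        printf("%f %f %f %f\n", S[0], S[1], S[2], S[3]);
        return 0;
    }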
@@ -14192,14 +14698,14 @@ static void ggml_compute_forward_flash_attn_f16(
|
|
14192
14698
|
int64_t t0 = ggml_perf_time_us();
|
14193
14699
|
UNUSED(t0);
|
14194
14700
|
|
14195
|
-
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
14196
|
-
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
14197
|
-
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
14198
|
-
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
14199
|
-
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
14200
|
-
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
14201
|
-
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
14202
|
-
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
14701
|
+
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
14702
|
+
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
14703
|
+
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
14704
|
+
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
14705
|
+
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
14706
|
+
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
14707
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
14708
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
14203
14709
|
|
14204
14710
|
const int ith = params->ith;
|
14205
14711
|
const int nth = params->nth;
|
@@ -14273,7 +14779,7 @@ static void ggml_compute_forward_flash_attn_f16(
|
|
14273
14779
|
for (int64_t ic = 0; ic < nek1; ++ic) {
|
14274
14780
|
// k indices
|
14275
14781
|
const int ik3 = iq3;
|
14276
|
-
const int ik2 = iq2;
|
14782
|
+
const int ik2 = iq2 % nek2;
|
14277
14783
|
const int ik1 = ic;
|
14278
14784
|
|
14279
14785
|
// S indices
|
@@ -14288,7 +14794,7 @@ static void ggml_compute_forward_flash_attn_f16(
|
|
14288
14794
|
for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
|
14289
14795
|
// k indices
|
14290
14796
|
const int ik3 = iq3;
|
14291
|
-
const int ik2 = iq2;
|
14797
|
+
const int ik2 = iq2 % nek2;
|
14292
14798
|
const int ik1 = ic;
|
14293
14799
|
|
14294
14800
|
// S indices
|
@@ -14313,6 +14819,8 @@ static void ggml_compute_forward_flash_attn_f16(
|
|
14313
14819
|
}
|
14314
14820
|
|
14315
14821
|
// softmax
|
14822
|
+
// todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
|
14823
|
+
// dont forget to set their S values to zero
|
14316
14824
|
{
|
14317
14825
|
float max = -INFINITY;
|
14318
14826
|
ggml_vec_max_f32(M, &max, S);
|
@@ -14369,6 +14877,7 @@ static void ggml_compute_forward_flash_attn_f16(
|
|
14369
14877
|
S16[i] = GGML_FP32_TO_FP16(S[i]);
|
14370
14878
|
}
|
14371
14879
|
|
14880
|
+
// todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
|
14372
14881
|
if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
|
14373
14882
|
for (int64_t ic = 0; ic < nev1; ++ic) {
|
14374
14883
|
// dst indices
|
@@ -14376,9 +14885,13 @@ static void ggml_compute_forward_flash_attn_f16(
|
|
14376
14885
|
const int i2 = iq2;
|
14377
14886
|
const int i3 = iq3;
|
14378
14887
|
|
14379
|
-
|
14380
|
-
|
14381
|
-
|
14888
|
+
// v indices
|
14889
|
+
const int iv2 = iq2 % nev2;
|
14890
|
+
const int iv3 = iq3;
|
14891
|
+
|
14892
|
+
ggml_vec_dot_f16(nev0,
|
14893
|
+
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
14894
|
+
(ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
14382
14895
|
S16);
|
14383
14896
|
}
|
14384
14897
|
} else {
|
@@ -14388,9 +14901,13 @@ static void ggml_compute_forward_flash_attn_f16(
|
|
14388
14901
|
const int i2 = iq2;
|
14389
14902
|
const int i3 = iq3;
|
14390
14903
|
|
14391
|
-
|
14392
|
-
|
14393
|
-
|
14904
|
+
// v indices
|
14905
|
+
const int iv2 = iq2 % nev2;
|
14906
|
+
const int iv3 = iq3;
|
14907
|
+
|
14908
|
+
ggml_vec_dot_f16_unroll(nev0, nbv1,
|
14909
|
+
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
14910
|
+
((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
14394
14911
|
S16);
|
14395
14912
|
}
|
14396
14913
|
}
|
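Note on the f16 flash-attention hunks: indexing k and v with iq2 % nek2 and iq2 % nev2 lets several query heads share one k/v head (grouped / broadcast KV). The mapping itself is just a modulo, shown here with invented sizes:

    #include <stdio.h>

    int main(void) {
        const int n_q_heads  = 8;   // plays the role of neq2 in the kernel
        const int n_kv_heads = 2;   // plays the role of nek2 / nev2

        // each query head iq2 reads from k/v head iq2 % n_kv_heads
        for (int iq2 = 0; iq2 < n_q_heads; ++iq2) {
            const int ik2 = iq2 % n_kv_heads;
            printf("query head %d -> kv head %d\n", iq2, ik2);
        }
        return 0;
    }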
@@ -14433,18 +14950,18 @@ static void ggml_compute_forward_flash_ff_f16(
|
|
14433
14950
|
int64_t t0 = ggml_perf_time_us();
|
14434
14951
|
UNUSED(t0);
|
14435
14952
|
|
14436
|
-
GGML_TENSOR_LOCALS(int64_t, nea, a, ne)
|
14437
|
-
GGML_TENSOR_LOCALS(size_t, nba, a, nb)
|
14438
|
-
GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne)
|
14439
|
-
GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb)
|
14440
|
-
GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne)
|
14441
|
-
GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb)
|
14442
|
-
GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne)
|
14443
|
-
GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb)
|
14444
|
-
GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne)
|
14445
|
-
GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb)
|
14446
|
-
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
14447
|
-
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
14953
|
+
GGML_TENSOR_LOCALS(int64_t, nea, a, ne)
|
14954
|
+
GGML_TENSOR_LOCALS(size_t, nba, a, nb)
|
14955
|
+
GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne)
|
14956
|
+
GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb)
|
14957
|
+
GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne)
|
14958
|
+
GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb)
|
14959
|
+
GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne)
|
14960
|
+
GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb)
|
14961
|
+
GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne)
|
14962
|
+
GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb)
|
14963
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
14964
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
14448
14965
|
|
14449
14966
|
const int ith = params->ith;
|
14450
14967
|
const int nth = params->nth;
|
@@ -14592,16 +15109,16 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|
14592
15109
|
int64_t t0 = ggml_perf_time_us();
|
14593
15110
|
UNUSED(t0);
|
14594
15111
|
|
14595
|
-
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
14596
|
-
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
14597
|
-
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
14598
|
-
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
14599
|
-
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
14600
|
-
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
14601
|
-
GGML_TENSOR_LOCALS(int64_t, ned, d, ne)
|
14602
|
-
GGML_TENSOR_LOCALS(size_t, nbd, d, nb)
|
14603
|
-
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
14604
|
-
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
15112
|
+
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
15113
|
+
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
15114
|
+
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
15115
|
+
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
15116
|
+
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
15117
|
+
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
15118
|
+
GGML_TENSOR_LOCALS(int64_t, ned, d, ne)
|
15119
|
+
GGML_TENSOR_LOCALS(size_t, nbd, d, nb)
|
15120
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
15121
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
14605
15122
|
|
14606
15123
|
const int ith = params->ith;
|
14607
15124
|
const int nth = params->nth;
|
@@ -14649,10 +15166,37 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|
14649
15166
|
return;
|
14650
15167
|
}
|
14651
15168
|
|
14652
|
-
|
15169
|
+
const int64_t elem_q = ggml_nelements(q);
|
15170
|
+
const int64_t elem_k = ggml_nelements(k);
|
14653
15171
|
|
14654
|
-
|
14655
|
-
|
15172
|
+
enum ggml_type result_type = dst->type;
|
15173
|
+
GGML_ASSERT(ggml_blck_size(result_type) == 1);
|
15174
|
+
const size_t tsize = ggml_type_size(result_type);
|
15175
|
+
|
15176
|
+
const size_t offs_q = 0;
|
15177
|
+
const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
|
15178
|
+
const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
|
15179
|
+
|
15180
|
+
void * grad_q = (char *) dst->data;
|
15181
|
+
void * grad_k = (char *) dst->data + offs_k;
|
15182
|
+
void * grad_v = (char *) dst->data + offs_v;
|
15183
|
+
|
15184
|
+
const size_t nbgq1 = nb0*neq0;
|
15185
|
+
const size_t nbgq2 = nb0*neq0*neq1;
|
15186
|
+
const size_t nbgq3 = nb0*neq0*neq1*neq2;
|
15187
|
+
|
15188
|
+
const size_t nbgk1 = nb0*nek0;
|
15189
|
+
const size_t nbgk2 = nb0*nek0*nek1;
|
15190
|
+
const size_t nbgk3 = nb0*nek0*nek1*neq2;
|
15191
|
+
|
15192
|
+
const size_t nbgv1 = nb0*nev0;
|
15193
|
+
const size_t nbgv2 = nb0*nev0*nev1;
|
15194
|
+
const size_t nbgv3 = nb0*nev0*nev1*neq2;
|
15195
|
+
|
15196
|
+
// parallelize by k rows using ggml_vec_dot_f32
|
15197
|
+
|
15198
|
+
// total rows in k
|
15199
|
+
const int nr = nek2*nek3;
|
14656
15200
|
|
14657
15201
|
// rows per thread
|
14658
15202
|
const int dr = (nr + nth - 1)/nth;
|
@@ -14665,268 +15209,243 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|
14665
15209
|
|
14666
15210
|
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
|
14667
15211
|
|
15212
|
+
// how often k2 (and v2) is repeated in q2
|
15213
|
+
int nrep = neq2/nek2;
|
15214
|
+
|
14668
15215
|
for (int ir = ir0; ir < ir1; ++ir) {
|
14669
15216
|
// q indices
|
14670
|
-
const int iq3 = ir/(neq2);
|
14671
|
-
const int iq2 = ir - iq3*neq2;
|
14672
|
-
for ( int iq1 = 0; iq1 < neq1; ++iq1) {
|
15217
|
+
const int ik3 = ir/(nek2);
|
15218
|
+
const int ik2 = ir - ik3*nek2;
|
14673
15219
|
|
15220
|
+
const int iq3 = ik3;
|
15221
|
+
const int id3 = ik3;
|
15222
|
+
const int iv3 = ik3;
|
15223
|
+
const int iv2 = ik2;
|
14674
15224
|
|
14675
|
-
|
14676
|
-
|
14677
|
-
|
14678
|
-
float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
|
15225
|
+
for (int irep = 0; irep < nrep; ++irep) {
|
15226
|
+
const int iq2 = ik2 + irep*nek2;
|
15227
|
+
const int id2 = iq2;
|
14679
15228
|
|
14680
|
-
|
14681
|
-
|
14682
|
-
|
15229
|
+
// (ik2 + irep*nek2) % nek2 == ik2
|
15230
|
+
for (int iq1 = 0; iq1 < neq1; ++iq1) {
|
15231
|
+
const int id1 = iq1;
|
14683
15232
|
|
14684
|
-
|
14685
|
-
//
|
14686
|
-
|
14687
|
-
|
14688
|
-
const int ik1 = ic;
|
15233
|
+
// not sure about CACHE_LINE_SIZE_F32..
|
15234
|
+
// - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
|
15235
|
+
float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
|
15236
|
+
float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
|
14689
15237
|
|
14690
|
-
|
14691
|
-
|
15238
|
+
for (int i = M; i < Mup; ++i) {
|
15239
|
+
S[i] = -INFINITY;
|
15240
|
+
}
|
14692
15241
|
|
14693
|
-
|
14694
|
-
|
14695
|
-
|
14696
|
-
|
14697
|
-
}
|
15242
|
+
const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
|
15243
|
+
for (int64_t ic = 0; ic < masked_begin; ++ic) {
|
15244
|
+
// k indices
|
15245
|
+
const int ik1 = ic;
|
14698
15246
|
|
14699
|
-
|
14700
|
-
|
15247
|
+
// S indices
|
15248
|
+
const int i1 = ik1;
|
14701
15249
|
|
14702
|
-
|
14703
|
-
|
14704
|
-
|
14705
|
-
|
14706
|
-
}
|
15250
|
+
ggml_vec_dot_f32(neq0,
|
15251
|
+
S + i1,
|
15252
|
+
(float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
15253
|
+
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
14707
15254
|
}
|
14708
|
-
}
|
14709
15255
|
|
14710
|
-
|
14711
|
-
|
14712
|
-
float max = -INFINITY;
|
14713
|
-
ggml_vec_max_f32(M, &max, S);
|
15256
|
+
// scale
|
15257
|
+
ggml_vec_scale_f32(masked_begin, S, scale);
|
14714
15258
|
|
14715
|
-
|
15259
|
+
for (int64_t i = masked_begin; i < M; i++) {
|
15260
|
+
S[i] = -INFINITY;
|
15261
|
+
}
|
15262
|
+
|
15263
|
+
// softmax
|
15264
|
+
// exclude known -INF S[..] values from max and loop
|
15265
|
+
// dont forget to set their SM values to zero
|
14716
15266
|
{
|
15267
|
+
float max = -INFINITY;
|
15268
|
+
ggml_vec_max_f32(masked_begin, &max, S);
|
15269
|
+
|
15270
|
+
ggml_float sum = 0.0;
|
15271
|
+
{
|
14717
15272
|
#ifdef GGML_SOFT_MAX_ACCELERATE
|
14718
|
-
|
14719
|
-
|
14720
|
-
|
14721
|
-
|
15273
|
+
max = -max;
|
15274
|
+
vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
|
15275
|
+
vvexpf(SM, SM, &Mup);
|
15276
|
+
ggml_vec_sum_f32(Mup, &sum, SM);
|
14722
15277
|
#else
|
14723
|
-
|
14724
|
-
|
14725
|
-
|
14726
|
-
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
|
14727
|
-
float * SR = S + i;
|
14728
|
-
float * SW = SM + i;
|
15278
|
+
uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
|
15279
|
+
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
|
14729
15280
|
|
14730
|
-
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
|
14731
|
-
if (SR[j] == -INFINITY) {
|
14732
|
-
|
14733
|
-
}
|
15281
|
+
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
|
15282
|
+
if (i >= masked_begin) {
|
15283
|
+
break;
|
15284
|
+
}
|
15285
|
+
float * SR = S + i;
|
15286
|
+
float * SW = SM + i;
|
15287
|
+
|
15288
|
+
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
|
15289
|
+
if (i + j >= masked_begin) {
|
15290
|
+
break;
|
15291
|
+
} else if (SR[j] == -INFINITY) {
|
15292
|
+
SW[j] = 0.0f;
|
15293
|
+
} else {
|
14734
15294
|
#ifndef GGML_FLASH_ATTN_EXP_FP16
|
14735
|
-
|
15295
|
+
const float val = expf(SR[j] - max);
|
14736
15296
|
#else
|
14737
|
-
|
14738
|
-
|
14739
|
-
|
15297
|
+
ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
|
15298
|
+
memcpy(&scvt[j], &s, sizeof(uint16_t));
|
15299
|
+
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
|
14740
15300
|
#endif
|
14741
|
-
|
14742
|
-
|
15301
|
+
sump[j] += (ggml_float)val;
|
15302
|
+
SW[j] = val;
|
15303
|
+
}
|
14743
15304
|
}
|
14744
15305
|
}
|
14745
|
-
}
|
14746
15306
|
|
14747
|
-
|
14748
|
-
|
14749
|
-
|
15307
|
+
for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
|
15308
|
+
sum += sump[i];
|
15309
|
+
}
|
14750
15310
|
#endif
|
14751
|
-
|
14752
|
-
|
14753
|
-
assert(sum > 0.0);
|
14754
|
-
|
14755
|
-
sum = 1.0/sum;
|
14756
|
-
ggml_vec_scale_f32(M, SM, sum);
|
14757
|
-
|
14758
|
-
}
|
14759
|
-
|
14760
|
-
// step-by-step explanation
|
14761
|
-
{
|
14762
|
-
// forward-process shape grads from backward process
|
14763
|
-
// parallel_for iq2,iq3:
|
14764
|
-
// k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur]
|
14765
|
-
// q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
|
14766
|
-
// v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur]
|
14767
|
-
// for iq1:
|
14768
|
-
// kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
|
14769
|
-
// qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
|
14770
|
-
// vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
|
14771
|
-
// S0 = -Inf [D,1,1,1]
|
14772
|
-
// ~S1[i] = dot(kcur[:D,i], qcur)
|
14773
|
-
// S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
|
14774
|
-
// S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
|
14775
|
-
// S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
14776
|
-
// S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
|
14777
|
-
// ~S5[i] = dot(vcur[:,i], S4)
|
14778
|
-
// S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
|
14779
|
-
// ~dst[i,iq1,iq2,iq3] = S5[i] ^
|
14780
|
-
// dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
|
14781
|
-
// dst backward-/ grad[dst] = d
|
14782
|
-
//
|
14783
|
-
// output gradients with their dependencies:
|
14784
|
-
//
|
14785
|
-
// grad[kcur] = grad[S1].T @ qcur
|
14786
|
-
// grad[S1] = diag_mask_zero(grad[S3], P) * scale
|
14787
|
-
// grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
14788
|
-
// grad[S4] = grad[S5] @ vcur
|
14789
|
-
// grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
|
14790
|
-
// grad[qcur] = grad[S1] @ kcur
|
14791
|
-
// grad[vcur] = grad[S5].T @ S4
|
14792
|
-
// grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
|
14793
|
-
//
|
14794
|
-
// in post-order:
|
14795
|
-
//
|
14796
|
-
// S1 = qcur @ kcur.T
|
14797
|
-
// S2 = S1 * scale
|
14798
|
-
// S3 = diag_mask_inf(S2, P)
|
14799
|
-
// S4 = softmax(S3)
|
14800
|
-
// grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
|
14801
|
-
// grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
14802
|
-
// grad[S1] = diag_mask_zero(grad[S3], P) * scale
|
14803
|
-
// grad[qcur] = grad[S1] @ kcur
|
14804
|
-
// grad[kcur] = grad[S1].T @ qcur
|
14805
|
-
// grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
|
14806
|
-
//
|
14807
|
-
// using less variables (SM=S4):
|
14808
|
-
//
|
14809
|
-
// S = diag_mask_inf(qcur @ kcur.T * scale, P)
|
14810
|
-
// SM = softmax(S)
|
14811
|
-
// S = d[:D,iq1,iq2,iq3] @ vcur
|
14812
|
-
// dot_SM_gradSM = dot(SM, S)
|
14813
|
-
// S = SM * (S - dot(SM, S))
|
14814
|
-
// S = diag_mask_zero(S, P) * scale
|
14815
|
-
//
|
14816
|
-
// grad[q][:D,iq1,iq2,iq3] += S @ kcur
|
14817
|
-
// grad[k][:D,:M,iq2,iq3] += S.T @ qcur
|
14818
|
-
// grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
|
14819
|
-
}
|
14820
|
-
|
14821
|
-
// S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
|
14822
|
-
// S = d[:D,iq1,iq2,iq3] @ vcur
|
14823
|
-
// S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
|
14824
|
-
ggml_vec_set_f32(M, S, 0);
|
14825
|
-
for (int64_t ic = 0; ic < D; ++ic) {
|
14826
|
-
// dst indices
|
14827
|
-
const int i1 = iq1;
|
14828
|
-
const int i2 = iq2;
|
14829
|
-
const int i3 = iq3;
|
15311
|
+
}
|
14830
15312
|
|
14831
|
-
|
14832
|
-
S,
|
14833
|
-
(float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
|
14834
|
-
*(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
|
14835
|
-
}
|
15313
|
+
assert(sum > 0.0);
|
14836
15314
|
|
14837
|
-
|
14838
|
-
|
14839
|
-
ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
|
14840
|
-
ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
|
14841
|
-
ggml_vec_mul_f32 (M, S, S, SM);
|
15315
|
+
sum = 1.0/sum;
|
15316
|
+
ggml_vec_scale_f32(masked_begin, SM, sum);
|
14842
15317
|
|
14843
|
-
// S = diag_mask_zero(S, P) * scale
|
14844
|
-
if (masked) {
|
14845
|
-
// for (int64_t i = P + iq1 + 1; i < M; i++) {
|
14846
|
-
// S[i] = 0;
|
14847
|
-
// }
|
14848
|
-
for (int64_t i = P; i < M; i++) {
|
14849
|
-
if (i > P + iq1) {
|
14850
|
-
S[i] = 0;
|
14851
|
-
}
|
14852
15318
|
}
|
14853
|
-
}
|
14854
|
-
ggml_vec_scale_f32(M, S, scale);
|
14855
|
-
|
14856
|
-
void * grad_q = (char *) dst->data;
|
14857
|
-
void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
|
14858
|
-
void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
|
14859
|
-
|
14860
|
-
const size_t nbgq1 = nb0*neq0;
|
14861
|
-
const size_t nbgq2 = nb0*neq0*neq1;
|
14862
|
-
const size_t nbgq3 = nb0*neq0*neq1*neq2;
|
14863
|
-
|
14864
|
-
const size_t nbgk1 = nb0*nek0;
|
14865
|
-
const size_t nbgk2 = nb0*nek0*nek1;
|
14866
|
-
const size_t nbgk3 = nb0*nek0*nek1*neq2;
|
14867
|
-
|
14868
|
-
const size_t nbgv1 = nb0*nev0;
|
14869
|
-
const size_t nbgv2 = nb0*nev0*nev1;
|
14870
|
-
const size_t nbgv3 = nb0*nev0*nev1*neq2;
|
14871
|
-
|
14872
|
-
// S shape [M,1]
|
14873
|
-
// SM shape [M,1]
|
14874
|
-
// kcur shape [D,M]
|
14875
|
-
// qcur shape [D,1]
|
14876
|
-
// vcur shape [M,D]
|
14877
|
-
//
|
14878
|
-
// grad[q][:D,iq1,iq2,iq3] += S @ kcur
|
14879
|
-
// grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
|
14880
|
-
// grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
|
14881
|
-
//
|
14882
|
-
//// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
|
14883
|
-
//// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
|
14884
|
-
for (int64_t ic = 0; ic < M; ++ic) {
|
14885
|
-
// dst indices
|
14886
|
-
const int i1 = iq1;
|
14887
|
-
const int i2 = iq2;
|
14888
|
-
const int i3 = iq3;
|
14889
15319
|
|
14890
|
-
|
14891
|
-
|
14892
|
-
|
14893
|
-
|
14894
|
-
|
15320
|
+
// step-by-step explanation
|
15321
|
+
{
|
15322
|
+
// forward-process shape grads from backward process
|
15323
|
+
// parallel_for ik2,ik3:
|
15324
|
+
// for irep:
|
15325
|
+
// iq2 = ik2 + irep*nek2
|
15326
|
+
// k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur]
|
15327
|
+
// q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
|
15328
|
+
// v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur]
|
15329
|
+
// for iq1:
|
15330
|
+
// kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
|
15331
|
+
// qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
|
15332
|
+
// vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
|
15333
|
+
// S0 = -Inf [D,1,1,1]
|
15334
|
+
// ~S1[i] = dot(kcur[:D,i], qcur)
|
15335
|
+
// S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
|
15336
|
+
// S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
|
15337
|
+
// S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
15338
|
+
// S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
|
15339
|
+
// ~S5[i] = dot(vcur[:,i], S4)
|
15340
|
+
// S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3]
|
15341
|
+
// ~dst[i,iq1,iq2,iq3] = S5[i] ^
|
15342
|
+
// dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
|
15343
|
+
// dst backward-/ grad[dst] = d
|
15344
|
+
//
|
15345
|
+
// output gradients with their dependencies:
|
15346
|
+
//
|
15347
|
+
// grad[kcur] = grad[S1].T @ qcur
|
15348
|
+
// grad[S1] = diag_mask_zero(grad[S3], P) * scale
|
15349
|
+
// grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
15350
|
+
// grad[S4] = grad[S5] @ vcur
|
15351
|
+
// grad[S4] = d[:D,id1,id2,id3] @ vcur
|
15352
|
+
// grad[qcur] = grad[S1] @ kcur
|
15353
|
+
// grad[vcur] = grad[S5].T @ S4
|
15354
|
+
// grad[vcur] = d[:D,id1,id2,id3].T @ S4
|
15355
|
+
//
|
15356
|
+
// in post-order:
|
15357
|
+
//
|
15358
|
+
// S1 = qcur @ kcur.T
|
15359
|
+
// S2 = S1 * scale
|
15360
|
+
// S3 = diag_mask_inf(S2, P)
|
15361
|
+
// S4 = softmax(S3)
|
15362
|
+
// grad[S4] = d[:D,id1,id2,id3] @ vcur
|
15363
|
+
// grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
15364
|
+
// grad[S1] = diag_mask_zero(grad[S3], P) * scale
|
15365
|
+
// grad[qcur] = grad[S1] @ kcur
|
15366
|
+
// grad[kcur] = grad[S1].T @ qcur
|
15367
|
+
// grad[vcur] = d[:D,id1,id2,id3].T @ S4
|
15368
|
+
//
|
15369
|
+
// using less variables (SM=S4):
|
15370
|
+
//
|
15371
|
+
// S = diag_mask_inf(qcur @ kcur.T * scale, P)
|
15372
|
+
// SM = softmax(S)
|
15373
|
+
// S = d[:D,iq1,iq2,iq3] @ vcur
|
15374
|
+
// dot_SM_gradSM = dot(SM, S)
|
15375
|
+
// S = SM * (S - dot(SM, S))
|
15376
|
+
// S = diag_mask_zero(S, P) * scale
|
15377
|
+
//
|
15378
|
+
// grad[q][:D,iq1,iq2,iq3] += S @ kcur
|
15379
|
+
// grad[k][:D,:M,ik2,ik3] += S.T @ qcur
|
15380
|
+
// grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
|
15381
|
+
}
|
14895
15382
|
|
14896
|
-
|
14897
|
-
|
14898
|
-
|
14899
|
-
|
14900
|
-
//
|
14901
|
-
|
14902
|
-
|
14903
|
-
|
15383
|
+
// S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
|
15384
|
+
// S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
|
15385
|
+
// for ic:
|
15386
|
+
// S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
|
15387
|
+
// exclude known future zero S[..] values from operation
|
15388
|
+
ggml_vec_set_f32(masked_begin, S, 0);
|
15389
|
+
for (int64_t ic = 0; ic < D; ++ic) {
|
15390
|
+
ggml_vec_mad_f32(masked_begin,
|
15391
|
+
S,
|
15392
|
+
(float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
15393
|
+
*(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
|
15394
|
+
}
|
14904
15395
|
|
14905
|
-
//
|
14906
|
-
|
14907
|
-
|
14908
|
-
|
14909
|
-
|
14910
|
-
|
14911
|
-
|
14912
|
-
|
15396
|
+
// S = SM * (S - dot(SM, S))
|
15397
|
+
float dot_SM_gradSM = 0;
|
15398
|
+
ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S);
|
15399
|
+
ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
|
15400
|
+
ggml_vec_mul_f32 (masked_begin, S, S, SM);
|
15401
|
+
|
15402
|
+
// S = diag_mask_zero(S, P) * scale
|
15403
|
+
// already done by above ggml_vec_set_f32
|
15404
|
+
|
15405
|
+
// exclude known zero S[..] values from operation
|
15406
|
+
ggml_vec_scale_f32(masked_begin, S, scale);
|
15407
|
+
|
15408
|
+
// S shape [M,1]
|
15409
|
+
// SM shape [M,1]
|
15410
|
+
// kcur shape [D,M]
|
15411
|
+
// qcur shape [D,1]
|
15412
|
+
// vcur shape [M,D]
|
15413
|
+
|
15414
|
+
// grad[q][:D,iq1,iq2,iq3] += S @ kcur
|
15415
|
+
// grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
|
15416
|
+
// for ic:
|
15417
|
+
// grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
|
15418
|
+
// exclude known zero S[..] values from loop
|
15419
|
+
for (int64_t ic = 0; ic < masked_begin; ++ic) {
|
15420
|
+
ggml_vec_mad_f32(D,
|
15421
|
+
(float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
|
15422
|
+
(float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
15423
|
+
S[ic]);
|
15424
|
+
}
|
14913
15425
|
|
14914
|
-
|
14915
|
-
|
14916
|
-
|
14917
|
-
|
14918
|
-
//
|
14919
|
-
|
14920
|
-
|
14921
|
-
|
15426
|
+
// grad[k][:D,:M,iq2,iq3] += S.T @ qcur
|
15427
|
+
// for ic:
|
15428
|
+
// grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
|
15429
|
+
// grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
|
15430
|
+
// exclude known zero S[..] values from loop
|
15431
|
+
for (int64_t ic = 0; ic < masked_begin; ++ic) {
|
15432
|
+
ggml_vec_mad_f32(D,
|
15433
|
+
(float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
|
15434
|
+
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
|
15435
|
+
S[ic]);
|
15436
|
+
}
|
14922
15437
|
|
14923
|
-
//
|
14924
|
-
//
|
14925
|
-
//
|
14926
|
-
|
14927
|
-
|
14928
|
-
|
14929
|
-
|
15438
|
+
// grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
|
15439
|
+
// for ic:
|
15440
|
+
// grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
|
15441
|
+
// grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
|
15442
|
+
// exclude known zero SM[..] values from mad
|
15443
|
+
for (int64_t ic = 0; ic < D; ++ic) {
|
15444
|
+
ggml_vec_mad_f32(masked_begin,
|
15445
|
+
(float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
|
15446
|
+
SM,
|
15447
|
+
*(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
|
15448
|
+
}
|
14930
15449
|
}
|
14931
15450
|
}
|
14932
15451
|
}
|
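Note on the rewritten flash-attention backward above: the comments keep the identity grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])), and the vector ops are merely clipped to masked_begin. As a standalone check of that softmax-backward identity (hypothetical names, not the ggml code):

    #include <stdio.h>

    // Softmax backward for one row: given y = softmax(x) and dy = dL/dy,
    // dL/dx = y * (dy - dot(y, dy)). This is the S = SM * (S - dot(SM, S))
    // step in the kernel, written out for a small fixed-size row.
    static void softmax_backward_row(const float *y, const float *dy,
                                     float *dx, int n) {
        float dot = 0.0f;
        for (int i = 0; i < n; ++i) {
            dot += y[i] * dy[i];
        }
        for (int i = 0; i < n; ++i) {
            dx[i] = y[i] * (dy[i] - dot);
        }
    }

    int main(void) {
        const float y[3]  = {0.2f, 0.3f, 0.5f};  // an already softmax-ed row
        const float dy[3] = {1.0f, 0.0f, 0.0f};
        float dx[3];
        softmax_backward_row(y, dy, dx, 3);
        printf("%f %f %f\n", dx[0], dx[1], dx[2]);
        return 0;
    }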
@@ -14962,8 +15481,8 @@ static void ggml_compute_forward_win_part_f32(
|
|
14962
15481
|
return;
|
14963
15482
|
}
|
14964
15483
|
|
14965
|
-
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
14966
|
-
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
15484
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
15485
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
14967
15486
|
|
14968
15487
|
const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
|
14969
15488
|
const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
|
@@ -15024,8 +15543,8 @@ static void ggml_compute_forward_win_unpart_f32(
|
|
15024
15543
|
return;
|
15025
15544
|
}
|
15026
15545
|
|
15027
|
-
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
15028
|
-
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
15546
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
15547
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
15029
15548
|
|
15030
15549
|
const int32_t w = ((const int32_t *)(dst->op_params))[0];
|
15031
15550
|
|
@@ -15142,7 +15661,7 @@ static void ggml_compute_forward_get_rel_pos_f16(
|
|
15142
15661
|
|
15143
15662
|
// ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
|
15144
15663
|
|
15145
|
-
GGML_TENSOR_UNARY_OP_LOCALS
|
15664
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
15146
15665
|
|
15147
15666
|
const int64_t w = ne1;
|
15148
15667
|
|
@@ -15840,7 +16359,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
15840
16359
|
} break;
|
15841
16360
|
case GGML_OP_GET_ROWS_BACK:
|
15842
16361
|
{
|
15843
|
-
ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
|
16362
|
+
ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor);
|
15844
16363
|
} break;
|
15845
16364
|
case GGML_OP_DIAG:
|
15846
16365
|
{
|
@@ -15864,11 +16383,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
15864
16383
|
} break;
|
15865
16384
|
case GGML_OP_ROPE:
|
15866
16385
|
{
|
15867
|
-
ggml_compute_forward_rope(params, tensor->src[0], tensor);
|
16386
|
+
ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor);
|
15868
16387
|
} break;
|
15869
16388
|
case GGML_OP_ROPE_BACK:
|
15870
16389
|
{
|
15871
|
-
ggml_compute_forward_rope_back(params, tensor->src[0], tensor);
|
16390
|
+
ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor);
|
15872
16391
|
} break;
|
15873
16392
|
case GGML_OP_ALIBI:
|
15874
16393
|
{
|
@@ -16013,7 +16532,218 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
16013
16532
|
|
16014
16533
|
////////////////////////////////////////////////////////////////////////////////
|
16015
16534
|
|
16016
|
-
|
16535
|
+
static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
|
16536
|
+
|
16537
|
+
static size_t hash(void * p) {
|
16538
|
+
return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
|
16539
|
+
}
|
16540
|
+
|
16541
|
+
static size_t hash_find(void * hash_table[], void * p) {
|
16542
|
+
size_t h = hash(p);
|
16543
|
+
|
16544
|
+
// linear probing
|
16545
|
+
size_t i = h;
|
16546
|
+
while (hash_table[i] != NULL && hash_table[i] != p) {
|
16547
|
+
i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
|
16548
|
+
if (i == h) {
|
16549
|
+
// visited all hash table entries -> not found
|
16550
|
+
return GGML_GRAPH_HASHTABLE_SIZE;
|
16551
|
+
}
|
16552
|
+
}
|
16553
|
+
return i;
|
16554
|
+
}
|
16555
|
+
|
16556
|
+
static bool hash_insert(void * hash_table[], void * p) {
|
16557
|
+
size_t i = hash_find(hash_table, p);
|
16558
|
+
|
16559
|
+
GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
16560
|
+
|
16561
|
+
if (hash_table[i] == p) {
|
16562
|
+
return true;
|
16563
|
+
}
|
16564
|
+
|
16565
|
+
// insert
|
16566
|
+
GGML_ASSERT(hash_table[i] == NULL);
|
16567
|
+
hash_table[i] = p;
|
16568
|
+
return false;
|
16569
|
+
}
|
16570
|
+
|
16571
|
+
static bool hash_contains(void * hash_table[], void * p) {
|
16572
|
+
size_t i = hash_find(hash_table, p);
|
16573
|
+
return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
|
16574
|
+
}
|
16575
|
+
|
16576
|
+
struct hash_map {
|
16577
|
+
void * keys[GGML_GRAPH_HASHTABLE_SIZE];
|
16578
|
+
void * vals[GGML_GRAPH_HASHTABLE_SIZE];
|
16579
|
+
};
|
16580
|
+
|
16581
|
+
static struct hash_map * new_hash_map(void) {
|
16582
|
+
struct hash_map * result = malloc(sizeof(struct hash_map));
|
16583
|
+
for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
|
16584
|
+
result->keys[i] = NULL;
|
16585
|
+
result->vals[i] = NULL;
|
16586
|
+
}
|
16587
|
+
return result;
|
16588
|
+
}
|
16589
|
+
|
16590
|
+
static void free_hash_map(struct hash_map * map) {
|
16591
|
+
free(map);
|
16592
|
+
}
|
16593
|
+
|
16594
|
+
// gradient checkpointing
|
16595
|
+
|
16596
|
+
static struct ggml_tensor * ggml_recompute_graph_node(
|
16597
|
+
struct ggml_context * ctx,
|
16598
|
+
struct ggml_cgraph * graph,
|
16599
|
+
struct hash_map * replacements,
|
16600
|
+
struct ggml_tensor * node) {
|
16601
|
+
|
16602
|
+
if (node == NULL) {
|
16603
|
+
return NULL;
|
16604
|
+
}
|
16605
|
+
|
16606
|
+
if (node->is_param) {
|
16607
|
+
return node;
|
16608
|
+
}
|
16609
|
+
|
16610
|
+
if (!hash_contains(graph->visited_hash_table, node)) {
|
16611
|
+
return node;
|
16612
|
+
}
|
16613
|
+
|
16614
|
+
int count_children = 0;
|
16615
|
+
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
16616
|
+
if (node->src[k]) {
|
16617
|
+
++count_children;
|
16618
|
+
}
|
16619
|
+
}
|
16620
|
+
|
16621
|
+
if (count_children == 0) {
|
16622
|
+
return node;
|
16623
|
+
}
|
16624
|
+
|
16625
|
+
size_t i = hash_find(replacements->keys, node);
|
16626
|
+
GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
16627
|
+
if (replacements->keys[i] == node) {
|
16628
|
+
return (struct ggml_tensor *) replacements->vals[i];
|
16629
|
+
}
|
16630
|
+
|
16631
|
+
struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
|
16632
|
+
|
16633
|
+
// insert clone into replacements
|
16634
|
+
GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
|
16635
|
+
replacements->keys[i] = node;
|
16636
|
+
replacements->vals[i] = clone;
|
16637
|
+
|
16638
|
+
clone->op = node->op;
|
16639
|
+
clone->grad = node->grad;
|
16640
|
+
clone->is_param = node->is_param;
|
16641
|
+
clone->extra = node->extra;
|
16642
|
+
for (int k = 0; k < GGML_MAX_DIMS; ++k) {
|
16643
|
+
clone->nb[k] = node->nb[k];
|
16644
|
+
}
|
16645
|
+
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
16646
|
+
clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
|
16647
|
+
}
|
16648
|
+
if (node->view_src != NULL) {
|
16649
|
+
clone->data = (node->view_src->data == NULL)
|
16650
|
+
? NULL // view_src not yet allocated
|
16651
|
+
: (char *) node->view_src->data // view_src already allocated
|
16652
|
+
+ node->view_offs;
|
16653
|
+
clone->view_src = node->view_src;
|
16654
|
+
clone->view_offs = node->view_offs;
|
16655
|
+
}
|
16656
|
+
|
16657
|
+
GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
|
16658
|
+
GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
|
16659
|
+
memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
|
16660
|
+
ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
|
16661
|
+
|
16662
|
+
return clone;
|
16663
|
+
}
|
16664
|
+
|
16665
|
+
void ggml_build_backward_gradient_checkpointing(
|
16666
|
+
struct ggml_context * ctx,
|
16667
|
+
struct ggml_cgraph * gf,
|
16668
|
+
struct ggml_cgraph * gb,
|
16669
|
+
struct ggml_cgraph * gb_tmp,
|
16670
|
+
struct ggml_tensor * * checkpoints,
|
16671
|
+
int n_checkpoints) {
|
16672
|
+
*gb_tmp = *gf;
|
16673
|
+
ggml_build_backward_expand(ctx, gf, gb_tmp, true);
|
16674
|
+
|
16675
|
+
if (n_checkpoints <= 0) {
|
16676
|
+
*gb = *gb_tmp;
|
16677
|
+
return;
|
16678
|
+
}
|
16679
|
+
|
16680
|
+
struct hash_map * replacements = new_hash_map();
|
16681
|
+
|
16682
|
+
// insert checkpoints in replacements
|
16683
|
+
for (int i = 0; i < n_checkpoints; ++i) {
|
16684
|
+
size_t k = hash_find(replacements->keys, checkpoints[i]);
|
16685
|
+
GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
16686
|
+
GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
|
16687
|
+
replacements->keys[k] = checkpoints[i];
|
16688
|
+
replacements->vals[k] = checkpoints[i];
|
16689
|
+
}
|
16690
|
+
|
16691
|
+
*gb = *gf;
|
16692
|
+
// rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
|
16693
|
+
// replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
|
16694
|
+
// by recomputing them from checkpoints
|
16695
|
+
for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
|
16696
|
+
struct ggml_tensor * node = gb_tmp->nodes[i];
|
16697
|
+
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
16698
|
+
// insert new tensors recomputing src, reusing already made replacements,
|
16699
|
+
// remember replacements: remember new tensors with mapping from corresponding gf nodes
|
16700
|
+
// recurse for input tensors,
|
16701
|
+
// unless (i.e. terminating when) input tensors are replacments (like checkpoints)
|
16702
|
+
node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
|
16703
|
+
}
|
16704
|
+
// insert rewritten backward node with replacements made into resulting backward graph gb
|
16705
|
+
ggml_build_forward_expand(gb, node);
|
16706
|
+
}
|
16707
|
+
|
16708
|
+
free_hash_map(replacements);
|
16709
|
+
}
|
16710
|
+
|
16711
|
+
// functions to change gradients considering the case that input a might be initial gradient with zero value
|
16712
|
+
|
16713
|
+
static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
|
16714
|
+
if (hash_contains(zero_table, a)) {
|
16715
|
+
return b;
|
16716
|
+
} else {
|
16717
|
+
return ggml_add_impl(ctx, a, b, false);
|
16718
|
+
}
|
16719
|
+
}
|
16720
|
+
|
16721
|
+
static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, void * zero_table[]) {
|
16722
|
+
if (hash_contains(zero_table, a)) {
|
16723
|
+
struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0));
|
16724
|
+
return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
|
16725
|
+
} else {
|
16726
|
+
return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
|
16727
|
+
}
|
16728
|
+
}
|
16729
|
+
|
16730
|
+
static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
|
16731
|
+
if (hash_contains(zero_table, a)) {
|
16732
|
+
return ggml_repeat(ctx, b, a);
|
16733
|
+
} else {
|
16734
|
+
return ggml_add1_impl(ctx, a, b, false);
|
16735
|
+
}
|
16736
|
+
}
|
16737
|
+
|
16738
|
+
static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
|
16739
|
+
if (hash_contains(zero_table, a)) {
|
16740
|
+
return ggml_neg(ctx, b);
|
16741
|
+
} else {
|
16742
|
+
return ggml_sub_impl(ctx, a, b, false);
|
16743
|
+
}
|
16744
|
+
}
|
16745
|
+
|
16746
|
+
static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, void * zero_table[]) {
|
16017
16747
|
struct ggml_tensor * src0 = tensor->src[0];
|
16018
16748
|
struct ggml_tensor * src1 = tensor->src[1];
|
16019
16749
|
|
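Note on ggml_build_backward_gradient_checkpointing above: backward nodes that reference non-checkpoint intermediates get those inputs replaced by recomputed clones, found or created through the replacements map, with recursion stopping at checkpoints (which map to themselves). A toy sketch of that memoized-recompute pattern over a fake node graph (none of these names are ggml API):

    #include <stdio.h>
    #include <string.h>

    #define MAX_NODES 16

    // Toy stand-in for the replacements map: node index -> id of its recomputed
    // clone, -1 meaning "not replaced yet". Checkpoints map to themselves.
    static int replacement[MAX_NODES];

    // parent[i] is the single input of node i in this toy graph (-1 for leaves).
    static int recompute(int node, const int *parent, int *n_clones) {
        if (node < 0)               return node;             // no input
        if (replacement[node] >= 0) return replacement[node]; // checkpoint or done

        // recurse into the input first, then register a clone for this node
        (void) recompute(parent[node], parent, n_clones);
        replacement[node] = MAX_NODES + (*n_clones)++;        // pretend clone id
        return replacement[node];
    }

    int main(void) {
        memset(replacement, -1, sizeof(replacement));
        const int parent[4] = {-1, 0, 1, 2};  // chain 0 -> 1 -> 2 -> 3
        replacement[1] = 1;                   // node 1 is a checkpoint: kept as-is
        int n_clones = 0;
        const int r = recompute(3, parent, &n_clones);
        printf("node 3 recomputed as clone id %d using %d clones\n", r, n_clones);
        return 0;
    }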
@@ -16021,34 +16751,34 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16021
16751
|
case GGML_OP_DUP:
|
16022
16752
|
{
|
16023
16753
|
if (src0->grad) {
|
16024
|
-
src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
|
16754
|
+
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
16025
16755
|
}
|
16026
16756
|
} break;
|
16027
16757
|
case GGML_OP_ADD:
|
16028
16758
|
{
|
16029
16759
|
if (src0->grad) {
|
16030
|
-
src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
|
16760
|
+
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
16031
16761
|
}
|
16032
16762
|
if (src1->grad) {
|
16033
|
-
src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace);
|
16763
|
+
src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
|
16034
16764
|
}
|
16035
16765
|
} break;
|
16036
16766
|
case GGML_OP_ADD1:
|
16037
16767
|
{
|
16038
16768
|
if (src0->grad) {
|
16039
|
-
src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
|
16769
|
+
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
16040
16770
|
}
|
16041
16771
|
if (src1->grad) {
|
16042
|
-
src1->grad = ggml_add_impl(ctx,
|
16772
|
+
src1->grad = ggml_add_or_set(ctx,
|
16043
16773
|
src1->grad,
|
16044
16774
|
ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
|
16045
|
-
|
16775
|
+
zero_table);
|
16046
16776
|
}
|
16047
16777
|
} break;
|
16048
16778
|
case GGML_OP_ACC:
|
16049
16779
|
{
|
16050
16780
|
if (src0->grad) {
|
16051
|
-
src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
|
16781
|
+
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
16052
16782
|
}
|
16053
16783
|
if (src1->grad) {
|
16054
16784
|
const size_t nb1 = ((int32_t *) tensor->op_params)[0];
|
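Note on the backward hunks that follow: ggml_add_or_set and its siblings skip the accumulation when the existing gradient is still the initial zero placeholder tracked in zero_table, writing the new contribution directly instead. Reduced to a scalar stand-in, the decision looks like this (hypothetical names, not the ggml API):

    #include <stdbool.h>
    #include <stdio.h>

    // Scalar stand-in for ggml_add_or_set: "grad" plays the tensor, "is_zero"
    // plays membership in zero_table (gradient not yet materialized).
    static float add_or_set(float grad, bool is_zero, float contribution) {
        return is_zero ? contribution            // first write: just take b
                       : grad + contribution;    // later writes: accumulate a + b
    }

    int main(void) {
        float g = 0.0f;
        bool  first = true;
        g = add_or_set(g, first, 1.5f); first = false;  // -> 1.5
        g = add_or_set(g, first, 2.0f);                 // -> 3.5
        printf("accumulated gradient: %g\n", g);
        return 0;
    }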
@@ -16065,117 +16795,117 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16065
16795
|
nb1, nb2, nb3, offset);
|
16066
16796
|
|
16067
16797
|
src1->grad =
|
16068
|
-
|
16798
|
+
ggml_add_or_set(ctx,
|
16069
16799
|
src1->grad,
|
16070
16800
|
ggml_reshape(ctx,
|
16071
16801
|
ggml_cont(ctx, tensor_grad_view),
|
16072
16802
|
src1->grad),
|
16073
|
-
|
16803
|
+
zero_table);
|
16074
16804
|
}
|
16075
16805
|
} break;
|
16076
16806
|
case GGML_OP_SUB:
|
16077
16807
|
{
|
16078
16808
|
if (src0->grad) {
|
16079
|
-
src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
|
16809
|
+
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
16080
16810
|
}
|
16081
16811
|
if (src1->grad) {
|
16082
|
-
src1->grad = ggml_sub_impl(ctx, src1->grad, tensor->grad, inplace);
|
16812
|
+
src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table);
|
16083
16813
|
}
|
16084
16814
|
} break;
|
16085
16815
|
case GGML_OP_MUL:
|
16086
16816
|
{
|
16087
16817
|
if (src0->grad) {
|
16088
16818
|
src0->grad =
|
16089
|
-
|
16819
|
+
ggml_add_or_set(ctx,
|
16090
16820
|
src0->grad,
|
16091
16821
|
ggml_mul(ctx, src1, tensor->grad),
|
16092
|
-
|
16822
|
+
zero_table);
|
16093
16823
|
}
|
16094
16824
|
if (src1->grad) {
|
16095
16825
|
src1->grad =
|
16096
|
-
|
16826
|
+
ggml_add_or_set(ctx,
|
16097
16827
|
src1->grad,
|
16098
16828
|
ggml_mul(ctx, src0, tensor->grad),
|
16099
|
-
|
16829
|
+
zero_table);
|
16100
16830
|
}
|
16101
16831
|
} break;
|
16102
16832
|
case GGML_OP_DIV:
|
16103
16833
|
{
|
16104
16834
|
if (src0->grad) {
|
16105
16835
|
src0->grad =
|
16106
|
-
|
16836
|
+
ggml_add_or_set(ctx,
|
16107
16837
|
src0->grad,
|
16108
16838
|
ggml_div(ctx, tensor->grad, src1),
|
16109
|
-
|
16839
|
+
zero_table);
|
16110
16840
|
}
|
16111
16841
|
if (src1->grad) {
|
16112
16842
|
src1->grad =
|
16113
|
-
|
16843
|
+
ggml_sub_or_set(ctx,
|
16114
16844
|
src1->grad,
|
16115
16845
|
ggml_mul(ctx,
|
16116
16846
|
tensor->grad,
|
16117
16847
|
ggml_div(ctx, tensor, src1)),
|
16118
|
-
|
16848
|
+
zero_table);
|
16119
16849
|
}
|
16120
16850
|
} break;
|
16121
16851
|
case GGML_OP_SQR:
|
16122
16852
|
{
|
16123
16853
|
if (src0->grad) {
|
16124
16854
|
src0->grad =
|
16125
|
-
|
16855
|
+
ggml_add_or_set(ctx,
|
16126
16856
|
src0->grad,
|
16127
16857
|
ggml_scale(ctx,
|
16128
16858
|
ggml_mul(ctx, src0, tensor->grad),
|
16129
16859
|
ggml_new_f32(ctx, 2.0f)),
|
16130
|
-
|
16860
|
+
zero_table);
|
16131
16861
|
}
|
16132
16862
|
} break;
|
16133
16863
|
case GGML_OP_SQRT:
|
16134
16864
|
{
|
16135
16865
|
if (src0->grad) {
|
16136
16866
|
src0->grad =
|
16137
|
-
|
16867
|
+
ggml_add_or_set(ctx,
|
16138
16868
|
src0->grad,
|
16139
16869
|
ggml_scale(ctx,
|
16140
16870
|
ggml_div(ctx,
|
16141
16871
|
tensor->grad,
|
16142
16872
|
tensor),
|
16143
16873
|
ggml_new_f32(ctx, 0.5f)),
|
16144
|
-
|
16874
|
+
zero_table);
|
16145
16875
|
}
|
16146
16876
|
} break;
|
16147
16877
|
case GGML_OP_LOG:
|
16148
16878
|
{
|
16149
16879
|
if (src0->grad) {
|
16150
16880
|
src0->grad =
|
16151
|
-
|
16881
|
+
ggml_add_or_set(ctx,
|
16152
16882
|
src0->grad,
|
16153
16883
|
ggml_div(ctx,
|
16154
16884
|
tensor->grad,
|
16155
16885
|
src0),
|
16156
|
-
|
16886
|
+
zero_table);
|
16157
16887
|
}
|
16158
16888
|
} break;
|
16159
16889
|
case GGML_OP_SUM:
|
16160
16890
|
{
|
16161
16891
|
if (src0->grad) {
|
16162
16892
|
src0->grad =
|
16163
|
-
|
16893
|
+
ggml_add1_or_set(ctx,
|
16164
16894
|
src0->grad,
|
16165
16895
|
tensor->grad,
|
16166
|
-
|
16896
|
+
zero_table);
|
16167
16897
|
}
|
16168
16898
|
} break;
|
16169
16899
|
case GGML_OP_SUM_ROWS:
|
16170
16900
|
{
|
16171
16901
|
if (src0->grad) {
|
16172
16902
|
src0->grad =
|
16173
|
-
|
16903
|
+
ggml_add_or_set(ctx,
|
16174
16904
|
src0->grad,
|
16175
16905
|
ggml_repeat(ctx,
|
16176
16906
|
tensor->grad,
|
16177
16907
|
src0->grad),
|
16178
|
-
|
16908
|
+
zero_table);
|
16179
16909
|
}
|
16180
16910
|
} break;
|
16181
16911
|
case GGML_OP_MEAN:
|
@@ -16187,20 +16917,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16187
16917
|
{
|
16188
16918
|
// necessary for llama
|
16189
16919
|
if (src0->grad) {
|
16190
|
-
src0->grad =
|
16920
|
+
src0->grad = ggml_add_or_set(ctx,
|
16191
16921
|
src0->grad,
|
16192
16922
|
ggml_repeat_back(ctx, tensor->grad, src0->grad),
|
16193
|
-
|
16923
|
+
zero_table);
|
16194
16924
|
}
|
16195
16925
|
} break;
|
16196
16926
|
case GGML_OP_REPEAT_BACK:
|
16197
16927
|
{
|
16198
16928
|
if (src0->grad) {
|
16199
16929
|
// TODO: test this
|
16200
|
-
src0->grad =
|
16930
|
+
src0->grad = ggml_add_or_set(ctx,
|
16201
16931
|
src0->grad,
|
16202
16932
|
ggml_repeat(ctx, tensor->grad, src0->grad),
|
16203
|
-
|
16933
|
+
zero_table);
|
16204
16934
|
}
|
16205
16935
|
} break;
|
16206
16936
|
case GGML_OP_CONCAT:
|
@@ -16222,10 +16952,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16222
16952
|
float eps;
|
16223
16953
|
memcpy(&eps, tensor->op_params, sizeof(float));
|
16224
16954
|
|
16225
|
-
src0->grad =
|
16955
|
+
src0->grad = ggml_add_or_set(ctx,
|
16226
16956
|
src0->grad,
|
16227
16957
|
ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
|
16228
|
-
|
16958
|
+
zero_table);
|
16229
16959
|
}
|
16230
16960
|
} break;
|
16231
16961
|
case GGML_OP_RMS_NORM_BACK:
|
@@ -16249,37 +16979,49 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16249
16979
|
// ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
|
16250
16980
|
// ds1 = t.T.dot(dt)
|
16251
16981
|
|
16252
|
-
// tensor.shape [m,p]
|
16253
|
-
// src0.shape [n,m]
|
16254
|
-
// src1.shape [n,p]
|
16982
|
+
// tensor.shape [m,p,qq,rr]
|
16983
|
+
// src0.shape [n,m,q1,r1]
|
16984
|
+
// src1.shape [n,p,qq,rr]
|
16255
16985
|
|
16256
16986
|
// necessary for llama
|
16257
16987
|
if (src0->grad) {
|
16988
|
+
struct ggml_tensor * s1_tg =
|
16989
|
+
ggml_out_prod(ctx, // [n,m,qq,rr]
|
16990
|
+
src1, // [n,p,qq,rr]
|
16991
|
+
tensor->grad); // [m,p,qq,rr]
|
16992
|
+
const int64_t qq = s1_tg->ne[2];
|
16993
|
+
const int64_t rr = s1_tg->ne[3];
|
16994
|
+
const int64_t q1 = src0->ne[2];
|
16995
|
+
const int64_t r1 = src0->ne[3];
|
16996
|
+
const bool ne2_broadcasted = qq > q1;
|
16997
|
+
const bool ne3_broadcasted = rr > r1;
|
16998
|
+
if (ne2_broadcasted || ne3_broadcasted) {
|
16999
|
+
// sum broadcast repetitions of s1_tg into shape of src0
|
17000
|
+
s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
|
17001
|
+
}
|
16258
17002
|
src0->grad =
|
16259
|
-
|
16260
|
-
src0->grad,
|
16261
|
-
|
16262
|
-
|
16263
|
-
tensor->grad), // [m,p]
|
16264
|
-
inplace);
|
17003
|
+
ggml_add_or_set(ctx,
|
17004
|
+
src0->grad, // [n,m,q1,r1]
|
17005
|
+
s1_tg, // [n,m,q1,r1]
|
17006
|
+
zero_table);
|
16265
17007
|
}
|
16266
17008
|
if (src1->grad) {
|
16267
17009
|
src1->grad =
|
16268
|
-
|
16269
|
-
src1->grad,
|
16270
|
-
// ggml_mul_mat(ctx, // [n,p]
|
16271
|
-
// ggml_cont(ctx, // [m,n]
|
16272
|
-
// ggml_transpose(ctx, src0)), // [m,n]
|
16273
|
-
// tensor->grad), // [m,p]
|
17010
|
+
ggml_add_or_set(ctx,
|
17011
|
+
src1->grad, // [n,p,qq,rr]
|
17012
|
+
// ggml_mul_mat(ctx, // [n,p,qq,rr]
|
17013
|
+
// ggml_cont(ctx, // [m,n,q1,r1]
|
17014
|
+
// ggml_transpose(ctx, src0)), // [m,n,q1,r1]
|
17015
|
+
// tensor->grad), // [m,p,qq,rr]
|
16274
17016
|
|
16275
17017
|
// // when src0 is bigger than tensor->grad (this is mostly the case in llama),
|
16276
17018
|
// // avoid transpose of src0, rather transpose smaller tensor->grad
|
16277
17019
|
// // and then use ggml_out_prod
|
16278
|
-
ggml_out_prod(ctx, // [n,p]
|
16279
|
-
src0, // [n,m]
|
16280
|
-
ggml_transpose(ctx, // [p,m]
|
16281
|
-
tensor->grad)), // [m,p]
|
16282
|
-
|
17020
|
+
ggml_out_prod(ctx, // [n,p,qq,rr]
|
17021
|
+
src0, // [n,m,q1,r1]
|
17022
|
+
ggml_transpose(ctx, // [p,m,qq,rr]
|
17023
|
+
tensor->grad)), // [m,p,qq,rr]
|
17024
|
+
zero_table);
|
16283
17025
|
}
|
16284
17026
|
} break;
|
16285
17027
|
case GGML_OP_OUT_PROD:
|
@@ -16291,17 +17033,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16291
17033
|
// necessary for llama
|
16292
17034
|
if (src0->grad) {
|
16293
17035
|
src0->grad =
|
16294
|
-
|
17036
|
+
ggml_add_or_set(ctx,
|
16295
17037
|
src0->grad,
|
16296
17038
|
ggml_scale_impl(ctx, tensor->grad, src1, false),
|
16297
|
-
|
17039
|
+
zero_table);
|
16298
17040
|
}
|
16299
17041
|
if (src1->grad) {
|
16300
17042
|
src1->grad =
|
16301
|
-
|
17043
|
+
ggml_add_or_set(ctx,
|
16302
17044
|
src1->grad,
|
16303
17045
|
ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)),
|
16304
|
-
|
17046
|
+
zero_table);
|
16305
17047
|
}
|
16306
17048
|
} break;
|
16307
17049
|
case GGML_OP_SET:
|
@@ -16328,23 +17070,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16328
17070
|
}
|
16329
17071
|
|
16330
17072
|
if (src0->grad) {
|
16331
|
-
src0->grad =
|
17073
|
+
src0->grad = ggml_add_or_set(ctx,
|
16332
17074
|
src0->grad,
|
16333
17075
|
ggml_acc_impl(ctx,
|
16334
17076
|
tensor->grad,
|
16335
17077
|
ggml_neg(ctx, tensor_grad_view),
|
16336
17078
|
nb1, nb2, nb3, offset, false),
|
16337
|
-
|
17079
|
+
zero_table);
|
16338
17080
|
}
|
16339
17081
|
|
16340
17082
|
if (src1->grad) {
|
16341
17083
|
src1->grad =
|
16342
|
-
|
17084
|
+
ggml_add_or_set(ctx,
|
16343
17085
|
src1->grad,
|
16344
17086
|
ggml_reshape(ctx,
|
16345
17087
|
ggml_cont(ctx, tensor_grad_view),
|
16346
17088
|
src1->grad),
|
16347
|
-
|
17089
|
+
zero_table);
|
16348
17090
|
}
|
16349
17091
|
} break;
|
16350
17092
|
case GGML_OP_CPY:
|
@@ -16355,7 +17097,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16355
17097
|
// tensor = src0 * 1 + src1 * 0
|
16356
17098
|
if (src0->grad) {
|
16357
17099
|
// dsrc0 = dtensor * 1
|
16358
|
-
src0->grad =
|
17100
|
+
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
16359
17101
|
}
|
16360
17102
|
if (src1->grad) {
|
16361
17103
|
// dsrc1 = dtensor * 0 -> noop
|
@@ -16367,7 +17109,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16367
17109
|
if (src0->grad) {
|
16368
17110
|
GGML_ASSERT(ggml_is_contiguous(src0->grad));
|
16369
17111
|
GGML_ASSERT(ggml_is_contiguous(tensor->grad));
|
16370
|
-
src0->grad =
|
17112
|
+
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
16371
17113
|
}
|
16372
17114
|
} break;
|
16373
17115
|
case GGML_OP_RESHAPE:
|
@@ -16375,9 +17117,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16375
17117
|
// necessary for llama
|
16376
17118
|
if (src0->grad) {
|
16377
17119
|
src0->grad =
|
16378
|
-
|
16379
|
-
ggml_reshape(ctx,
|
16380
|
-
|
17120
|
+
ggml_add_or_set(ctx, src0->grad,
|
17121
|
+
ggml_reshape(ctx,
|
17122
|
+
ggml_is_contiguous(tensor->grad)
|
17123
|
+
? tensor->grad
|
17124
|
+
: ggml_cont(ctx, tensor->grad),
|
17125
|
+
src0->grad),
|
17126
|
+
zero_table);
|
16381
17127
|
}
|
16382
17128
|
} break;
|
16383
17129
|
case GGML_OP_VIEW:
|
@@ -16406,7 +17152,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16406
17152
|
nb3 = (nb3 / n0) * ng;
|
16407
17153
|
}
|
16408
17154
|
|
16409
|
-
src0->grad =
|
17155
|
+
src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table);
|
16410
17156
|
}
|
16411
17157
|
} break;
|
16412
17158
|
case GGML_OP_PERMUTE:
|
@@ -16424,14 +17170,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16424
17170
|
axes_backward[axis2] = 2;
|
16425
17171
|
axes_backward[axis3] = 3;
|
16426
17172
|
src0->grad =
|
16427
|
-
|
17173
|
+
ggml_add_or_set(ctx, src0->grad,
|
16428
17174
|
ggml_permute(ctx,
|
16429
17175
|
tensor->grad,
|
16430
17176
|
axes_backward[0],
|
16431
17177
|
axes_backward[1],
|
16432
17178
|
axes_backward[2],
|
16433
17179
|
axes_backward[3]),
|
16434
|
-
|
17180
|
+
zero_table);
|
16435
17181
|
}
|
16436
17182
|
} break;
|
16437
17183
|
case GGML_OP_TRANSPOSE:
|
@@ -16439,9 +17185,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16439
17185
|
// necessary for llama
|
16440
17186
|
if (src0->grad) {
|
16441
17187
|
src0->grad =
|
16442
|
-
|
17188
|
+
ggml_add_or_set(ctx, src0->grad,
|
16443
17189
|
ggml_transpose(ctx, tensor->grad),
|
16444
|
-
|
17190
|
+
zero_table);
|
16445
17191
|
}
|
16446
17192
|
} break;
|
16447
17193
|
case GGML_OP_GET_ROWS:
|
@@ -16449,9 +17195,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16449
17195
|
// necessary for llama (only for tokenizer)
|
16450
17196
|
if (src0->grad) {
|
16451
17197
|
src0->grad =
|
16452
|
-
|
17198
|
+
ggml_add_or_set(ctx, src0->grad,
|
17199
|
+
// last ggml_get_rows_back argument src0->grad is only
|
17200
|
+
// necessary to setup correct output shape
|
16453
17201
|
ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
|
16454
|
-
|
17202
|
+
zero_table);
|
16455
17203
|
}
|
16456
17204
|
if (src1->grad) {
|
16457
17205
|
// noop
|
@@ -16471,9 +17219,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16471
17219
|
if (src0->grad) {
|
16472
17220
|
const int n_past = ((int32_t *) tensor->op_params)[0];
|
16473
17221
|
src0->grad =
|
16474
|
-
|
17222
|
+
ggml_add_or_set(ctx, src0->grad,
|
16475
17223
|
ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
|
16476
|
-
|
17224
|
+
zero_table);
|
16477
17225
|
}
|
16478
17226
|
} break;
|
16479
17227
|
case GGML_OP_DIAG_MASK_ZERO:
|
@@ -16482,9 +17230,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16482
17230
|
if (src0->grad) {
|
16483
17231
|
const int n_past = ((int32_t *) tensor->op_params)[0];
|
16484
17232
|
src0->grad =
|
16485
|
-
|
17233
|
+
ggml_add_or_set(ctx, src0->grad,
|
16486
17234
|
ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
|
16487
|
-
|
17235
|
+
zero_table);
|
16488
17236
|
}
|
16489
17237
|
} break;
|
16490
17238
|
case GGML_OP_SOFT_MAX:
|
@@ -16492,9 +17240,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16492
17240
|
// necessary for llama
|
16493
17241
|
if (src0->grad) {
|
16494
17242
|
src0->grad =
|
16495
|
-
|
17243
|
+
ggml_add_or_set(ctx, src0->grad,
|
16496
17244
|
ggml_soft_max_back(ctx, tensor->grad, tensor),
|
16497
|
-
|
17245
|
+
zero_table);
|
16498
17246
|
}
|
16499
17247
|
|
16500
17248
|
} break;
|
@@ -16506,7 +17254,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16506
17254
|
{
|
16507
17255
|
// necessary for llama
|
16508
17256
|
if (src0->grad) {
|
16509
|
-
const int n_past = ((int32_t *) tensor->op_params)[0];
|
17257
|
+
//const int n_past = ((int32_t *) tensor->op_params)[0];
|
16510
17258
|
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
16511
17259
|
const int mode = ((int32_t *) tensor->op_params)[2];
|
16512
17260
|
const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
@@ -16519,11 +17267,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16519
17267
|
memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
|
16520
17268
|
memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
|
16521
17269
|
|
16522
|
-
src0->grad =
|
17270
|
+
src0->grad = ggml_add_or_set(ctx,
|
16523
17271
|
src0->grad,
|
16524
17272
|
ggml_rope_back(ctx,
|
16525
17273
|
tensor->grad,
|
16526
|
-
|
17274
|
+
src1,
|
16527
17275
|
n_dims,
|
16528
17276
|
mode,
|
16529
17277
|
n_ctx,
|
@@ -16531,13 +17279,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16531
17279
|
freq_scale,
|
16532
17280
|
xpos_base,
|
16533
17281
|
xpos_down),
|
16534
|
-
|
17282
|
+
zero_table);
|
16535
17283
|
}
|
16536
17284
|
} break;
|
16537
17285
|
case GGML_OP_ROPE_BACK:
|
16538
17286
|
{
|
16539
17287
|
if (src0->grad) {
|
16540
|
-
const int n_past = ((int32_t *) tensor->op_params)[0];
|
17288
|
+
//const int n_past = ((int32_t *) tensor->op_params)[0];
|
16541
17289
|
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
16542
17290
|
const int mode = ((int32_t *) tensor->op_params)[2];
|
16543
17291
|
const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
@@ -16550,11 +17298,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16550
17298
|
memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
|
16551
17299
|
memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
|
16552
17300
|
|
16553
|
-
src0->grad =
|
17301
|
+
src0->grad = ggml_add_or_set(ctx,
|
16554
17302
|
src0->grad,
|
16555
17303
|
ggml_rope_impl(ctx,
|
16556
17304
|
tensor->grad,
|
16557
|
-
|
17305
|
+
src1,
|
16558
17306
|
n_dims,
|
16559
17307
|
mode,
|
16560
17308
|
n_ctx,
|
@@ -16563,7 +17311,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16563
17311
|
xpos_base,
|
16564
17312
|
xpos_down,
|
16565
17313
|
false),
|
16566
|
-
|
17314
|
+
zero_table);
|
16567
17315
|
}
|
16568
17316
|
} break;
|
16569
17317
|
case GGML_OP_ALIBI:
|
@@ -16614,145 +17362,42 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16614
17362
|
masked);
|
16615
17363
|
}
|
16616
17364
|
|
16617
|
-
|
16618
|
-
|
16619
|
-
|
16620
|
-
|
16621
|
-
|
16622
|
-
|
16623
|
-
|
16624
|
-
|
16625
|
-
|
16626
|
-
|
16627
|
-
|
16628
|
-
|
16629
|
-
offset);
|
16630
|
-
} break;
|
16631
|
-
case 3:
|
16632
|
-
{
|
16633
|
-
grad_q = ggml_view_3d(ctx,
|
16634
|
-
flash_grad,
|
16635
|
-
src0->ne[0],
|
16636
|
-
src0->ne[1],
|
16637
|
-
src0->ne[2],
|
16638
|
-
nb0*src0->ne[0],
|
16639
|
-
nb0*src0->ne[0]*src0->ne[1],
|
16640
|
-
offset);
|
16641
|
-
} break;
|
16642
|
-
case 4:
|
16643
|
-
{
|
16644
|
-
grad_q = ggml_view_4d(ctx,
|
16645
|
-
flash_grad,
|
16646
|
-
src0->ne[0],
|
16647
|
-
src0->ne[1],
|
16648
|
-
src0->ne[2],
|
16649
|
-
src0->ne[3],
|
16650
|
-
nb0*src0->ne[0],
|
16651
|
-
nb0*src0->ne[0]*src0->ne[1],
|
16652
|
-
nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
|
16653
|
-
offset);
|
16654
|
-
} break;
|
16655
|
-
}
|
17365
|
+
struct ggml_tensor * src2 = tensor->src[2];
|
17366
|
+
const int64_t elem_q = ggml_nelements(src0);
|
17367
|
+
const int64_t elem_k = ggml_nelements(src1);
|
17368
|
+
const int64_t elem_v = ggml_nelements(src2);
|
17369
|
+
|
17370
|
+
enum ggml_type result_type = flash_grad->type;
|
17371
|
+
GGML_ASSERT(ggml_blck_size(result_type) == 1);
|
17372
|
+
const size_t tsize = ggml_type_size(result_type);
|
17373
|
+
|
17374
|
+
const size_t offs_q = 0;
|
17375
|
+
const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
|
17376
|
+
const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
|
16656
17377
|
|
16657
|
-
|
17378
|
+
if (src0->grad) {
|
17379
|
+
struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q);
|
17380
|
+
struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0);
|
17381
|
+
src0->grad = ggml_add_or_set(ctx,
|
16658
17382
|
src0->grad,
|
16659
17383
|
grad_q,
|
16660
|
-
|
17384
|
+
zero_table);
|
16661
17385
|
}
|
16662
|
-
|
16663
17386
|
if (src1->grad) {
|
16664
|
-
struct ggml_tensor *
|
16665
|
-
|
16666
|
-
|
16667
|
-
switch(src1->n_dims) {
|
16668
|
-
case 2:
|
16669
|
-
{
|
16670
|
-
grad_k = ggml_view_2d(ctx,
|
16671
|
-
flash_grad,
|
16672
|
-
src1->ne[0],
|
16673
|
-
src1->ne[1],
|
16674
|
-
nb0*src1->ne[0],
|
16675
|
-
offset);
|
16676
|
-
} break;
|
16677
|
-
case 3:
|
16678
|
-
{
|
16679
|
-
grad_k = ggml_view_3d(ctx,
|
16680
|
-
flash_grad,
|
16681
|
-
src1->ne[0],
|
16682
|
-
src1->ne[1],
|
16683
|
-
src1->ne[2],
|
16684
|
-
nb0*src1->ne[0],
|
16685
|
-
nb0*src1->ne[0]*src1->ne[1],
|
16686
|
-
offset);
|
16687
|
-
} break;
|
16688
|
-
case 4:
|
16689
|
-
{
|
16690
|
-
grad_k = ggml_view_4d(ctx,
|
16691
|
-
flash_grad,
|
16692
|
-
src1->ne[0],
|
16693
|
-
src1->ne[1],
|
16694
|
-
src1->ne[2],
|
16695
|
-
src1->ne[3],
|
16696
|
-
nb0*src1->ne[0],
|
16697
|
-
nb0*src1->ne[0]*src1->ne[1],
|
16698
|
-
nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
|
16699
|
-
offset);
|
16700
|
-
} break;
|
16701
|
-
}
|
16702
|
-
|
16703
|
-
src1->grad = ggml_add_impl(ctx,
|
17387
|
+
struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
|
17388
|
+
struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1);
|
17389
|
+
src1->grad = ggml_add_or_set(ctx,
|
16704
17390
|
src1->grad,
|
16705
17391
|
grad_k,
|
16706
|
-
|
17392
|
+
zero_table);
|
16707
17393
|
}
|
16708
|
-
|
16709
|
-
|
16710
|
-
|
16711
|
-
|
16712
|
-
|
16713
|
-
const size_t nb0 = flash_grad->nb[0];
|
16714
|
-
const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
|
16715
|
-
+ nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
|
16716
|
-
switch(opt0->n_dims) {
|
16717
|
-
case 2:
|
16718
|
-
{
|
16719
|
-
grad_v = ggml_view_2d(ctx,
|
16720
|
-
flash_grad,
|
16721
|
-
opt0->ne[0],
|
16722
|
-
opt0->ne[1],
|
16723
|
-
nb0*opt0->ne[0],
|
16724
|
-
offset);
|
16725
|
-
} break;
|
16726
|
-
case 3:
|
16727
|
-
{
|
16728
|
-
grad_v = ggml_view_3d(ctx,
|
16729
|
-
flash_grad,
|
16730
|
-
opt0->ne[0],
|
16731
|
-
opt0->ne[1],
|
16732
|
-
opt0->ne[2],
|
16733
|
-
nb0*opt0->ne[0],
|
16734
|
-
nb0*opt0->ne[0]*opt0->ne[1],
|
16735
|
-
offset);
|
16736
|
-
} break;
|
16737
|
-
case 4:
|
16738
|
-
{
|
16739
|
-
grad_v = ggml_view_4d(ctx,
|
16740
|
-
flash_grad,
|
16741
|
-
opt0->ne[0],
|
16742
|
-
opt0->ne[1],
|
16743
|
-
opt0->ne[2],
|
16744
|
-
opt0->ne[3],
|
16745
|
-
nb0*opt0->ne[0],
|
16746
|
-
nb0*opt0->ne[0]*opt0->ne[1],
|
16747
|
-
nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
|
16748
|
-
offset);
|
16749
|
-
} break;
|
16750
|
-
}
|
16751
|
-
|
16752
|
-
opt0->grad = ggml_add_impl(ctx,
|
16753
|
-
opt0->grad,
|
17394
|
+
if (src2->grad) {
|
17395
|
+
struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
|
17396
|
+
struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2);
|
17397
|
+
src2->grad = ggml_add_or_set(ctx,
|
17398
|
+
src2->grad,
|
16754
17399
|
grad_v,
|
16755
|
-
|
17400
|
+
zero_table);
|
16756
17401
|
}
|
16757
17402
|
} break;
|
16758
17403
|
case GGML_OP_FLASH_FF:
|
@@ -16772,12 +17417,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16772
17417
|
{
|
16773
17418
|
if (src0->grad) {
|
16774
17419
|
src0->grad =
|
16775
|
-
|
17420
|
+
ggml_add_or_set(ctx,
|
16776
17421
|
src0->grad,
|
16777
17422
|
ggml_mul(ctx,
|
16778
17423
|
ggml_sgn(ctx, src0),
|
16779
17424
|
tensor->grad),
|
16780
|
-
|
17425
|
+
zero_table);
|
16781
17426
|
}
|
16782
17427
|
} break;
|
16783
17428
|
case GGML_UNARY_OP_SGN:
|
@@ -16789,7 +17434,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16789
17434
|
case GGML_UNARY_OP_NEG:
|
16790
17435
|
{
|
16791
17436
|
if (src0->grad) {
|
16792
|
-
src0->grad =
|
17437
|
+
src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
16793
17438
|
}
|
16794
17439
|
} break;
|
16795
17440
|
case GGML_UNARY_OP_STEP:
|
@@ -16809,12 +17454,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16809
17454
|
case GGML_UNARY_OP_RELU:
|
16810
17455
|
{
|
16811
17456
|
if (src0->grad) {
|
16812
|
-
src0->grad =
|
17457
|
+
src0->grad = ggml_add_or_set(ctx,
|
16813
17458
|
src0->grad,
|
16814
17459
|
ggml_mul(ctx,
|
16815
17460
|
ggml_step(ctx, src0),
|
16816
17461
|
tensor->grad),
|
16817
|
-
|
17462
|
+
zero_table);
|
16818
17463
|
}
|
16819
17464
|
} break;
|
16820
17465
|
case GGML_UNARY_OP_GELU:
|
@@ -16829,10 +17474,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16829
17474
|
{
|
16830
17475
|
// necessary for llama
|
16831
17476
|
if (src0->grad) {
|
16832
|
-
src0->grad =
|
17477
|
+
src0->grad = ggml_add_or_set(ctx,
|
16833
17478
|
src0->grad,
|
16834
17479
|
ggml_silu_back(ctx, src0, tensor->grad),
|
16835
|
-
|
17480
|
+
zero_table);
|
16836
17481
|
}
|
16837
17482
|
} break;
|
16838
17483
|
default:
|
@@ -16855,13 +17500,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16855
17500
|
case GGML_OP_CROSS_ENTROPY_LOSS:
|
16856
17501
|
{
|
16857
17502
|
if (src0->grad) {
|
16858
|
-
src0->grad =
|
17503
|
+
src0->grad = ggml_add_or_set(ctx,
|
16859
17504
|
src0->grad,
|
16860
17505
|
ggml_cross_entropy_loss_back(ctx,
|
16861
17506
|
src0,
|
16862
17507
|
src1,
|
16863
17508
|
tensor->grad),
|
16864
|
-
|
17509
|
+
zero_table);
|
16865
17510
|
}
|
16866
17511
|
} break;
|
16867
17512
|
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
@@ -16877,34 +17522,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16877
17522
|
GGML_ASSERT(false);
|
16878
17523
|
} break;
|
16879
17524
|
}
|
16880
|
-
}
|
16881
|
-
|
16882
|
-
static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
|
16883
|
-
|
16884
|
-
static size_t hash(void * p) {
|
16885
|
-
return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
|
16886
|
-
}
|
16887
17525
|
|
16888
|
-
|
16889
|
-
|
16890
|
-
|
16891
|
-
// linear probing
|
16892
|
-
size_t i = h;
|
16893
|
-
while (hash_table[i] != NULL && hash_table[i] != p) {
|
16894
|
-
i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
|
16895
|
-
if (i == h) {
|
16896
|
-
// hash table is full
|
16897
|
-
GGML_ASSERT(false);
|
17526
|
+
for (int i = 0; i < GGML_MAX_SRC; ++i) {
|
17527
|
+
if (tensor->src[i] && tensor->src[i]->grad) {
|
17528
|
+
GGML_ASSERT(ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad));
|
16898
17529
|
}
|
16899
17530
|
}
|
16900
|
-
|
16901
|
-
if (hash_table[i] == p) {
|
16902
|
-
return true;
|
16903
|
-
}
|
16904
|
-
|
16905
|
-
// insert
|
16906
|
-
hash_table[i] = p;
|
16907
|
-
return false;
|
16908
17531
|
}
|
16909
17532
|
|
16910
17533
|
static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
|
@@ -16922,8 +17545,12 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
     }

     for (int i = 0; i < GGML_MAX_SRC; ++i) {
-
-
+        const int k =
+            (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
+            (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
+            /* unknown order, just fall back to using i*/ i;
+        if (node->src[k]) {
+            ggml_visit_parents(cgraph, node->src[k]);
         }
     }

@@ -16982,6 +17609,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
         /*.grads        =*/ { NULL },
         /*.leafs        =*/ { NULL },
         /*.hash_table   =*/ { NULL },
+        /*.order        =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
         /*.perf_runs    =*/ 0,
         /*.perf_cycles  =*/ 0,
         /*.perf_time_us =*/ 0,
@@ -17007,12 +17635,22 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
         }
     }

+    // remember original gradients which start with zero values
+    void ** zero_table = malloc(sizeof(void *) * GGML_GRAPH_HASHTABLE_SIZE);
+    memset(zero_table, 0, sizeof(void*) * GGML_GRAPH_HASHTABLE_SIZE);
+    for (int i = 0; i < gf->n_nodes; i++) {
+        if (gf->grads[i]) {
+            hash_insert(zero_table, gf->grads[i]);
+        }
+    }
+
     for (int i = gf->n_nodes - 1; i >= 0; i--) {
         struct ggml_tensor * node = gf->nodes[i];

-        //
+        // inplace operations to add gradients are not created by ggml_compute_backward
+        // use allocator to automatically make inplace operations
         if (node->grad) {
-            ggml_compute_backward(ctx, node,
+            ggml_compute_backward(ctx, node, zero_table);
         }
     }

@@ -17024,6 +17662,8 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
             ggml_build_forward_expand(gb, node->grad);
         }
     }
+
+    free(zero_table);
 }

 struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
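For orientation, here is a rough end-to-end sketch of the API this function sits in: build a forward graph for a scalar loss, build the backward graph (which is where the zero-table bookkeeping above runs), then evaluate both. The graph-building and compute calls follow the signatures visible in the surrounding hunk headers; the toy tensors and the memory sizing are assumptions.

// Illustrative only: a tiny f = sum(a*b) graph with gradients for "a".
#include "ggml.h"
#include <stdlib.h>

void sketch_backward_pass(void) {
    struct ggml_init_params ip = { /*.mem_size =*/ 128*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    ggml_set_param(ctx, a);                       // mark "a" as a trainable parameter

    struct ggml_tensor * f = ggml_sum(ctx, ggml_mul(ctx, a, b));

    struct ggml_cgraph gf = ggml_build_forward(f);
    struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, /*keep =*/ true);

    ggml_set_f32(a, 1.5f);
    ggml_set_f32(b, 2.0f);

    struct ggml_cplan plan = ggml_graph_plan(&gb, /*n_threads =*/ 1);
    if (plan.work_size > 0) {
        plan.work_data = malloc(plan.work_size); // a real caller should check and free this
    }

    ggml_graph_reset(&gf);
    ggml_set_f32(f->grad, 1.0f);                  // seed df/df = 1
    ggml_graph_compute(&gb, &plan);

    // a->grad now holds df/da (elementwise equal to b for this toy graph)
    free(plan.work_data);
    ggml_free(ctx);
}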
@@ -17043,6 +17683,7 @@ struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
         /*.grads        =*/ { NULL },
         /*.leafs        =*/ { NULL },
         /*.hash_table   =*/ { NULL },
+        /*.order        =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
         /*.perf_runs    =*/ 0,
         /*.perf_cycles  =*/ 0,
         /*.perf_time_us =*/ 0,
@@ -17433,7 +18074,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
             } break;
         case GGML_OP_CONCAT:
         case GGML_OP_MUL_MAT:
-        case GGML_OP_OUT_PROD:
             {
                 n_tasks = n_threads;

@@ -17475,6 +18115,18 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                     cur = 0;
                 }

+                work_size = MAX(work_size, cur);
+            } break;
+        case GGML_OP_OUT_PROD:
+            {
+                n_tasks = n_threads;
+
+                size_t cur = 0;
+
+                if (ggml_is_quantized(node->src[0]->type)) {
+                    cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
+                }
+
                 work_size = MAX(work_size, cur);
             } break;
         case GGML_OP_SCALE:
@@ -18568,7 +19220,7 @@ static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float *
 }

 static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
-
+    int64_t i = 0;
     for (int p = 0; p < np; ++p) {
         const int64_t ne = ggml_nelements(ps[p]) ;
         // TODO: add function to get all elements at once
@@ -18578,6 +19230,17 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
         }
     }
 }

+static void ggml_opt_acc_grad(int np, struct ggml_tensor * const ps[], float * g, float scale) {
+    int64_t i = 0;
+    for (int p = 0; p < np; ++p) {
+        const int64_t ne = ggml_nelements(ps[p]) ;
+        // TODO: add function to get all elements at once
+        for (int64_t j = 0; j < ne; ++j) {
+            g[i++] += ggml_get_f32_1d(ps[p]->grad, j) * scale;
+        }
+    }
+}
+
 //
 // ADAM
 //
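ggml_opt_acc_grad above is the piece that lets the Adam and L-BFGS loops average gradients over several evaluations. Below is a stripped-down sketch of that accumulation pattern, independent of the ggml tensor types; eval_loss_and_grad is a hypothetical callback standing in for the "compute graph, then ggml_opt_acc_grad" sequence in the real optimizer loops.

// Gradient accumulation sketch mirroring the pattern in this diff:
// accumulate scaled per-step gradients into g and average the loss.
#include <stdlib.h>
#include <string.h>

typedef float (*eval_fn)(void * user_data, float * step_grad, size_t nx);

static float accumulate_gradients(eval_fn eval_loss_and_grad, void * user_data,
                                  float * g, size_t nx, int n_accum) {
    const float accum_norm = 1.0f / (float) n_accum;   // same normalization as accum_norm in the diff
    float * step_grad = (float *) calloc(nx, sizeof(float));
    float fx = 0.0f;

    memset(g, 0, sizeof(float) * nx);                   // analogue of ggml_set_zero(opt->adam.g)
    for (int step = 0; step < n_accum; ++step) {
        fx += eval_loss_and_grad(user_data, step_grad, nx);
        for (size_t i = 0; i < nx; ++i) {
            g[i] += step_grad[i] * accum_norm;          // g[i] += grad * scale, as in ggml_opt_acc_grad
        }
    }
    free(step_grad);
    return fx * accum_norm;                             // averaged loss, matching "fx *= accum_norm"
}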
@@ -18626,26 +19289,43 @@ static enum ggml_opt_result ggml_opt_adam(
|
|
18626
19289
|
const float eps = params.adam.eps;
|
18627
19290
|
const float gclip = params.adam.gclip;
|
18628
19291
|
const int decay_min_ndim = params.adam.decay_min_ndim;
|
19292
|
+
const int n_accum = MAX(1, params.n_gradient_accumulation);
|
19293
|
+
const float accum_norm = 1.0f / (float) n_accum;
|
18629
19294
|
|
19295
|
+
float * g = opt->adam.g->data; // gradients
|
18630
19296
|
float * m = opt->adam.m->data; // first moment
|
18631
19297
|
float * v = opt->adam.v->data; // second moment
|
18632
19298
|
|
18633
19299
|
float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
|
18634
19300
|
|
18635
|
-
if (callback) {
|
18636
|
-
callback(callback_data, &sched);
|
18637
|
-
}
|
18638
|
-
|
18639
|
-
// compute the function value
|
18640
|
-
ggml_graph_reset (gf);
|
18641
|
-
ggml_set_f32 (f->grad, 1.0f);
|
18642
|
-
|
18643
19301
|
struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads);
|
18644
19302
|
struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size);
|
18645
19303
|
cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
|
18646
|
-
ggml_graph_compute(gb, &cplan);
|
18647
19304
|
|
18648
|
-
|
19305
|
+
bool cancel = false;
|
19306
|
+
|
19307
|
+
// compute the function value
|
19308
|
+
float fx = 0;
|
19309
|
+
ggml_set_zero(opt->adam.g);
|
19310
|
+
for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
|
19311
|
+
if (callback) {
|
19312
|
+
callback(callback_data, accum_step, &sched, &cancel);
|
19313
|
+
if (cancel) {
|
19314
|
+
break;
|
19315
|
+
}
|
19316
|
+
}
|
19317
|
+
// ggml_graph_reset (gf);
|
19318
|
+
ggml_set_f32 (f->grad, 1.0f);
|
19319
|
+
ggml_graph_compute(gb, &cplan);
|
19320
|
+
ggml_opt_acc_grad(np, ps, g, accum_norm);
|
19321
|
+
fx += ggml_get_f32_1d(f, 0);
|
19322
|
+
}
|
19323
|
+
if (cancel) {
|
19324
|
+
return GGML_OPT_DID_NOT_CONVERGE;
|
19325
|
+
}
|
19326
|
+
fx *= accum_norm;
|
19327
|
+
|
19328
|
+
opt->adam.fx_prev = fx;
|
18649
19329
|
opt->adam.fx_best = opt->adam.fx_prev;
|
18650
19330
|
if (pf) {
|
18651
19331
|
pf[opt->iter % params.past] = opt->adam.fx_prev;
|
@@ -18668,6 +19348,9 @@ static enum ggml_opt_result ggml_opt_adam(
|
|
18668
19348
|
|
18669
19349
|
// run the optimizer
|
18670
19350
|
for (int t = 0; t < params.adam.n_iter; ++t) {
|
19351
|
+
if (cancel) {
|
19352
|
+
break;
|
19353
|
+
}
|
18671
19354
|
opt->iter = iter0 + t + 1;
|
18672
19355
|
GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
|
18673
19356
|
|
@@ -18690,12 +19373,8 @@ static enum ggml_opt_result ggml_opt_adam(
|
|
18690
19373
|
if (gclip > 0.0f) {
|
18691
19374
|
// gradient clipping
|
18692
19375
|
ggml_float sum = 0.0;
|
18693
|
-
for (
|
18694
|
-
|
18695
|
-
for (int64_t j = 0; j < ne; ++j) {
|
18696
|
-
float g = ggml_get_f32_1d(ps[p]->grad, j);
|
18697
|
-
sum += (ggml_float)(g*g);
|
18698
|
-
}
|
19376
|
+
for (int64_t i = 0; i < nx; ++i) {
|
19377
|
+
sum += (ggml_float)(g[i]*g[i]);
|
18699
19378
|
}
|
18700
19379
|
ggml_float norm = sqrt(sum);
|
18701
19380
|
if (norm > (ggml_float) gclip) {
|
@@ -18709,10 +19388,10 @@ static enum ggml_opt_result ggml_opt_adam(
|
|
18709
19388
|
const int64_t ne = ggml_nelements(ps[p]);
|
18710
19389
|
const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0f) * sched;
|
18711
19390
|
for (int64_t j = 0; j < ne; ++j) {
|
18712
|
-
float x
|
18713
|
-
float
|
18714
|
-
m[i] = m[i]*beta1 +
|
18715
|
-
v[i] = v[i]*beta2 +
|
19391
|
+
float x = ggml_get_f32_1d(ps[p], j);
|
19392
|
+
float g_ = g[i]*gnorm;
|
19393
|
+
m[i] = m[i]*beta1 + g_*(1.0f - beta1);
|
19394
|
+
v[i] = v[i]*beta2 + g_*g_*(1.0f - beta2);
|
18716
19395
|
float mh = m[i]*beta1h;
|
18717
19396
|
float vh = v[i]*beta2h;
|
18718
19397
|
vh = sqrtf(vh) + eps;
|
@@ -18723,16 +19402,26 @@ static enum ggml_opt_result ggml_opt_adam(
|
|
18723
19402
|
}
|
18724
19403
|
}
|
18725
19404
|
|
18726
|
-
|
18727
|
-
|
19405
|
+
fx = 0;
|
19406
|
+
ggml_set_zero(opt->adam.g);
|
19407
|
+
for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
|
19408
|
+
if (callback) {
|
19409
|
+
callback(callback_data, accum_step, &sched, &cancel);
|
19410
|
+
if (cancel) {
|
19411
|
+
break;
|
19412
|
+
}
|
19413
|
+
}
|
19414
|
+
// ggml_graph_reset (gf);
|
19415
|
+
ggml_set_f32 (f->grad, 1.0f);
|
19416
|
+
ggml_graph_compute(gb, &cplan);
|
19417
|
+
ggml_opt_acc_grad(np, ps, g, accum_norm);
|
19418
|
+
fx += ggml_get_f32_1d(f, 0);
|
18728
19419
|
}
|
19420
|
+
if (cancel) {
|
19421
|
+
break;
|
19422
|
+
}
|
19423
|
+
fx *= accum_norm;
|
18729
19424
|
|
18730
|
-
ggml_graph_reset (gf);
|
18731
|
-
ggml_set_f32 (f->grad, 1.0f);
|
18732
|
-
|
18733
|
-
ggml_graph_compute(gb, &cplan);
|
18734
|
-
|
18735
|
-
const float fx = ggml_get_f32_1d(f, 0);
|
18736
19425
|
opt->loss_after = fx;
|
18737
19426
|
|
18738
19427
|
|
@@ -18812,11 +19501,11 @@ static enum ggml_opt_result linesearch_backtracking(
|
|
18812
19501
|
float * step,
|
18813
19502
|
const float * xp,
|
18814
19503
|
struct ggml_tensor * f,
|
18815
|
-
struct ggml_cgraph * gf,
|
18816
19504
|
struct ggml_cgraph * gb,
|
18817
19505
|
struct ggml_cplan * cplan,
|
18818
19506
|
const int np,
|
18819
19507
|
struct ggml_tensor * ps[],
|
19508
|
+
bool * cancel,
|
18820
19509
|
ggml_opt_callback callback,
|
18821
19510
|
void * callback_data) {
|
18822
19511
|
int count = 0;
|
@@ -18830,6 +19519,9 @@ static enum ggml_opt_result linesearch_backtracking(
|
|
18830
19519
|
const float dec = 0.5f;
|
18831
19520
|
const float inc = 2.1f;
|
18832
19521
|
|
19522
|
+
const int n_accum = MAX(1, params->n_gradient_accumulation);
|
19523
|
+
const float accum_norm = 1.0f / (float) n_accum;
|
19524
|
+
|
18833
19525
|
if (*step <= 0.f) {
|
18834
19526
|
return GGML_LINESEARCH_INVALID_PARAMETERS;
|
18835
19527
|
}
|
@@ -18846,13 +19538,7 @@ static enum ggml_opt_result linesearch_backtracking(
|
|
18846
19538
|
finit = *fx;
|
18847
19539
|
dgtest = params->lbfgs.ftol*dginit;
|
18848
19540
|
|
18849
|
-
while (
|
18850
|
-
if (callback) {
|
18851
|
-
// LBFG-S does not support learning rate -> ignore learning schedule
|
18852
|
-
float sched = 0;
|
18853
|
-
callback(callback_data, &sched);
|
18854
|
-
}
|
18855
|
-
|
19541
|
+
while (!*cancel) {
|
18856
19542
|
ggml_vec_cpy_f32(nx, x, xp);
|
18857
19543
|
ggml_vec_mad_f32(nx, x, d, *step);
|
18858
19544
|
|
@@ -18860,14 +19546,28 @@ static enum ggml_opt_result linesearch_backtracking(
|
|
18860
19546
|
{
|
18861
19547
|
ggml_opt_set_params(np, ps, x);
|
18862
19548
|
|
18863
|
-
|
18864
|
-
|
18865
|
-
|
18866
|
-
|
18867
|
-
|
18868
|
-
|
19549
|
+
*fx = 0;
|
19550
|
+
memset(g, 0, sizeof(float)*nx);
|
19551
|
+
for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
|
19552
|
+
if (callback) {
|
19553
|
+
// LBFG-S does not support learning rate -> ignore learning schedule
|
19554
|
+
float sched = 0;
|
19555
|
+
callback(callback_data, accum_step, &sched, cancel);
|
19556
|
+
if (*cancel) {
|
19557
|
+
break;
|
19558
|
+
}
|
19559
|
+
}
|
19560
|
+
// ggml_graph_reset (gf);
|
19561
|
+
ggml_set_f32 (f->grad, 1.0f);
|
19562
|
+
ggml_graph_compute(gb, cplan);
|
19563
|
+
ggml_opt_acc_grad(np, ps, g, accum_norm);
|
19564
|
+
*fx += ggml_get_f32_1d(f, 0);
|
19565
|
+
}
|
19566
|
+
if (*cancel) {
|
19567
|
+
break;
|
19568
|
+
}
|
19569
|
+
*fx *= accum_norm;
|
18869
19570
|
|
18870
|
-
*fx = ggml_get_f32_1d(f, 0);
|
18871
19571
|
}
|
18872
19572
|
|
18873
19573
|
++count;
|
@@ -18913,7 +19613,7 @@ static enum ggml_opt_result linesearch_backtracking(
|
|
18913
19613
|
(*step) *= width;
|
18914
19614
|
}
|
18915
19615
|
|
18916
|
-
|
19616
|
+
GGML_UNREACHABLE();
|
18917
19617
|
}
|
18918
19618
|
|
18919
19619
|
static enum ggml_opt_result ggml_opt_lbfgs(
|
@@ -18968,6 +19668,9 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|
18968
19668
|
|
18969
19669
|
float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
|
18970
19670
|
|
19671
|
+
const int n_accum = MAX(1, params.n_gradient_accumulation);
|
19672
|
+
const float accum_norm = 1.0f / (float) n_accum;
|
19673
|
+
|
18971
19674
|
float fx = 0.0f; // cost function value
|
18972
19675
|
float xnorm = 0.0f; // ||x||
|
18973
19676
|
float gnorm = 0.0f; // ||g||
|
@@ -18981,24 +19684,33 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|
18981
19684
|
float * lm_s = opt->lbfgs.lms->data;
|
18982
19685
|
float * lm_y = opt->lbfgs.lmy->data;
|
18983
19686
|
|
18984
|
-
|
18985
|
-
// LBFG-S does not support learning rate -> ignore learning schedule
|
18986
|
-
float sched = 0;
|
18987
|
-
callback(callback_data, &sched);
|
18988
|
-
}
|
19687
|
+
bool cancel = false;
|
18989
19688
|
|
18990
19689
|
// evaluate the function value and its gradient
|
18991
19690
|
{
|
18992
19691
|
ggml_opt_set_params(np, ps, x);
|
18993
19692
|
|
18994
|
-
|
18995
|
-
|
18996
|
-
|
18997
|
-
|
18998
|
-
|
18999
|
-
|
19000
|
-
|
19001
|
-
|
19693
|
+
fx = 0;
|
19694
|
+
memset(g, 0, sizeof(float)*nx);
|
19695
|
+
for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
|
19696
|
+
if (callback) {
|
19697
|
+
// LBFG-S does not support learning rate -> ignore learning schedule
|
19698
|
+
float sched = 0;
|
19699
|
+
callback(callback_data, accum_step, &sched, &cancel);
|
19700
|
+
if (cancel) {
|
19701
|
+
break;
|
19702
|
+
}
|
19703
|
+
}
|
19704
|
+
// ggml_graph_reset (gf);
|
19705
|
+
ggml_set_f32 (f->grad, 1.0f);
|
19706
|
+
ggml_graph_compute(gb, &cplan);
|
19707
|
+
ggml_opt_acc_grad(np, ps, g, accum_norm);
|
19708
|
+
fx += ggml_get_f32_1d(f, 0);
|
19709
|
+
}
|
19710
|
+
if (cancel) {
|
19711
|
+
return GGML_OPT_DID_NOT_CONVERGE;
|
19712
|
+
}
|
19713
|
+
fx *= accum_norm;
|
19002
19714
|
|
19003
19715
|
opt->loss_before = fx;
|
19004
19716
|
opt->loss_after = fx;
|
@@ -19056,7 +19768,10 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|
19056
19768
|
ggml_vec_cpy_f32(nx, xp, x);
|
19057
19769
|
ggml_vec_cpy_f32(nx, gp, g);
|
19058
19770
|
|
19059
|
-
ls = linesearch_backtracking(¶ms, nx, x, &fx, g, d, step, xp, f,
|
19771
|
+
ls = linesearch_backtracking(¶ms, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data);
|
19772
|
+
if (!cancel) {
|
19773
|
+
break;
|
19774
|
+
}
|
19060
19775
|
|
19061
19776
|
if (ls < 0) {
|
19062
19777
|
// linesearch failed - go back to the previous point and return
|
@@ -19165,7 +19880,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         step[0] = 1.0;
     }

-
+    GGML_UNREACHABLE();
 }

 struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
@@ -19185,6 +19900,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
             .print_forward_graph  = true,
             .print_backward_graph = true,

+            .n_gradient_accumulation = 1,
+
             .adam = {
                 .n_iter = 10000,
                 .sched  = 1.000f,
@@ -19213,6 +19930,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
             .print_forward_graph  = true,
             .print_backward_graph = true,

+            .n_gradient_accumulation = 1,
+
             .lbfgs = {
                 .m      = 6,
                 .n_iter = 100,
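Since both parameter sets now default n_gradient_accumulation to 1, enabling accumulation is a one-field override for callers of the public optimizer entry point. A hedged sketch; the loss tensor f and the graph behind it are assumed to be built elsewhere.

// Illustrative: enable 4-step gradient accumulation on the default Adam params.
// ggml_opt_default_params and the n_gradient_accumulation field appear in this diff;
// ggml_opt is the public single-call optimizer entry point.
#include "ggml.h"

enum ggml_opt_result sketch_optimize(struct ggml_context * ctx, struct ggml_tensor * f) {
    struct ggml_opt_params params = ggml_opt_default_params(GGML_OPT_ADAM);
    params.n_gradient_accumulation = 4;   // average gradients over 4 evaluations per step
    params.adam.n_iter             = 100; // keep the sketch short
    return ggml_opt(ctx, params, f);
}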
@@ -19243,13 +19962,32 @@ GGML_API void ggml_opt_init(
|
|
19243
19962
|
opt->iter = 0;
|
19244
19963
|
opt->nx = nx;
|
19245
19964
|
opt->just_initialized = true;
|
19965
|
+
if (opt->ctx == NULL) {
|
19966
|
+
struct ggml_init_params ctx_opt_params;
|
19967
|
+
if (opt->params.type == GGML_OPT_ADAM) {
|
19968
|
+
ctx_opt_params.mem_size = GGML_MEM_ALIGN*3 + ggml_tensor_overhead()*3 + ggml_type_size(GGML_TYPE_F32)*nx*3;
|
19969
|
+
if (opt->params.past > 0) {
|
19970
|
+
ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
|
19971
|
+
}
|
19972
|
+
} else if (opt->params.type == GGML_OPT_LBFGS) {
|
19973
|
+
ctx_opt_params.mem_size = GGML_MEM_ALIGN*9 + ggml_tensor_overhead()*9 + ggml_type_size(GGML_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2);
|
19974
|
+
if (opt->params.past > 0) {
|
19975
|
+
ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
|
19976
|
+
}
|
19977
|
+
}
|
19978
|
+
ctx_opt_params.mem_buffer = NULL;
|
19979
|
+
ctx_opt_params.no_alloc = false;
|
19980
|
+
|
19981
|
+
opt->ctx = ggml_init(ctx_opt_params);
|
19982
|
+
}
|
19246
19983
|
switch (opt->params.type) {
|
19247
19984
|
case GGML_OPT_ADAM:
|
19248
19985
|
{
|
19249
|
-
opt->adam.
|
19250
|
-
opt->adam.
|
19986
|
+
opt->adam.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
|
19987
|
+
opt->adam.m = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
|
19988
|
+
opt->adam.v = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
|
19251
19989
|
opt->adam.pf = params.past > 0
|
19252
|
-
? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
|
19990
|
+
? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
|
19253
19991
|
: NULL;
|
19254
19992
|
ggml_set_zero(opt->adam.m);
|
19255
19993
|
ggml_set_zero(opt->adam.v);
|
@@ -19259,18 +19997,18 @@ GGML_API void ggml_opt_init(
|
|
19259
19997
|
} break;
|
19260
19998
|
case GGML_OPT_LBFGS:
|
19261
19999
|
{
|
19262
|
-
opt->lbfgs.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
19263
|
-
opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
19264
|
-
opt->lbfgs.g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
19265
|
-
opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
19266
|
-
opt->lbfgs.d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
20000
|
+
opt->lbfgs.x = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
|
20001
|
+
opt->lbfgs.xp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
|
20002
|
+
opt->lbfgs.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
|
20003
|
+
opt->lbfgs.gp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
|
20004
|
+
opt->lbfgs.d = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
|
19267
20005
|
opt->lbfgs.pf = params.past > 0
|
19268
|
-
? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
|
20006
|
+
? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
|
19269
20007
|
: NULL;
|
19270
|
-
opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
|
19271
|
-
opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
|
19272
|
-
opt->lbfgs.lms = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
|
19273
|
-
opt->lbfgs.lmy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
|
20008
|
+
opt->lbfgs.lmal = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
|
20009
|
+
opt->lbfgs.lmys = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
|
20010
|
+
opt->lbfgs.lms = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
|
20011
|
+
opt->lbfgs.lmy = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
|
19274
20012
|
ggml_set_zero(opt->lbfgs.x);
|
19275
20013
|
ggml_set_zero(opt->lbfgs.xp);
|
19276
20014
|
ggml_set_zero(opt->lbfgs.g);
|
@@ -19876,10 +20614,10 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
                         } break;
                     case GGUF_TYPE_ARRAY:
                     case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break;
-                }
+                }
             } break;
         case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type");
-        }
+        }

         if (!ok) {
             break;
@@ -20155,78 +20893,94 @@ int gguf_find_key(const struct gguf_context * ctx, const char * key) {
|
|
20155
20893
|
return keyfound;
|
20156
20894
|
}
|
20157
20895
|
|
20158
|
-
const char * gguf_get_key(const struct gguf_context * ctx, int
|
20159
|
-
return ctx->kv[
|
20896
|
+
const char * gguf_get_key(const struct gguf_context * ctx, int key_id) {
|
20897
|
+
return ctx->kv[key_id].key.data;
|
20160
20898
|
}
|
20161
20899
|
|
20162
|
-
enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int
|
20163
|
-
return ctx->kv[
|
20900
|
+
enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) {
|
20901
|
+
return ctx->kv[key_id].type;
|
20164
20902
|
}
|
20165
20903
|
|
20166
|
-
enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int
|
20167
|
-
|
20904
|
+
enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) {
|
20905
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
|
20906
|
+
return ctx->kv[key_id].value.arr.type;
|
20168
20907
|
}
|
20169
20908
|
|
20170
|
-
const void * gguf_get_arr_data(const struct gguf_context * ctx, int
|
20171
|
-
|
20909
|
+
const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) {
|
20910
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
|
20911
|
+
return ctx->kv[key_id].value.arr.data;
|
20172
20912
|
}
|
20173
20913
|
|
20174
20914
|
const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
|
20915
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
|
20175
20916
|
struct gguf_kv * kv = &ctx->kv[key_id];
|
20176
20917
|
struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
|
20177
20918
|
return str->data;
|
20178
20919
|
}
|
20179
20920
|
|
20180
|
-
int gguf_get_arr_n(const struct gguf_context * ctx, int
|
20181
|
-
|
20921
|
+
int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
|
20922
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
|
20923
|
+
return ctx->kv[key_id].value.arr.n;
|
20182
20924
|
}
|
20183
20925
|
|
20184
|
-
uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int
|
20185
|
-
|
20926
|
+
uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) {
|
20927
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8);
|
20928
|
+
return ctx->kv[key_id].value.uint8;
|
20186
20929
|
}
|
20187
20930
|
|
20188
|
-
int8_t gguf_get_val_i8(const struct gguf_context * ctx, int
|
20189
|
-
|
20931
|
+
int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) {
|
20932
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8);
|
20933
|
+
return ctx->kv[key_id].value.int8;
|
20190
20934
|
}
|
20191
20935
|
|
20192
|
-
uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int
|
20193
|
-
|
20936
|
+
uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) {
|
20937
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16);
|
20938
|
+
return ctx->kv[key_id].value.uint16;
|
20194
20939
|
}
|
20195
20940
|
|
20196
|
-
int16_t gguf_get_val_i16(const struct gguf_context * ctx, int
|
20197
|
-
|
20941
|
+
int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) {
|
20942
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16);
|
20943
|
+
return ctx->kv[key_id].value.int16;
|
20198
20944
|
}
|
20199
20945
|
|
20200
|
-
uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int
|
20201
|
-
|
20946
|
+
uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) {
|
20947
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32);
|
20948
|
+
return ctx->kv[key_id].value.uint32;
|
20202
20949
|
}
|
20203
20950
|
|
20204
|
-
int32_t gguf_get_val_i32(const struct gguf_context * ctx, int
|
20205
|
-
|
20951
|
+
int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) {
|
20952
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32);
|
20953
|
+
return ctx->kv[key_id].value.int32;
|
20206
20954
|
}
|
20207
20955
|
|
20208
|
-
float gguf_get_val_f32(const struct gguf_context * ctx, int
|
20209
|
-
|
20956
|
+
float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) {
|
20957
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32);
|
20958
|
+
return ctx->kv[key_id].value.float32;
|
20210
20959
|
}
|
20211
20960
|
|
20212
|
-
uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int
|
20213
|
-
|
20961
|
+
uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) {
|
20962
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64);
|
20963
|
+
return ctx->kv[key_id].value.uint64;
|
20214
20964
|
}
|
20215
20965
|
|
20216
|
-
int64_t gguf_get_val_i64(const struct gguf_context * ctx, int
|
20217
|
-
|
20966
|
+
int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) {
|
20967
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64);
|
20968
|
+
return ctx->kv[key_id].value.int64;
|
20218
20969
|
}
|
20219
20970
|
|
20220
|
-
double gguf_get_val_f64(const struct gguf_context * ctx, int
|
20221
|
-
|
20971
|
+
double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) {
|
20972
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64);
|
20973
|
+
return ctx->kv[key_id].value.float64;
|
20222
20974
|
}
|
20223
20975
|
|
20224
|
-
bool gguf_get_val_bool(const struct gguf_context * ctx, int
|
20225
|
-
|
20976
|
+
bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) {
|
20977
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL);
|
20978
|
+
return ctx->kv[key_id].value.bool_;
|
20226
20979
|
}
|
20227
20980
|
|
20228
|
-
const char * gguf_get_val_str
|
20229
|
-
|
20981
|
+
const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
|
20982
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
|
20983
|
+
return ctx->kv[key_id].value.str.data;
|
20230
20984
|
}
|
20231
20985
|
|
20232
20986
|
int gguf_get_n_tensors(const struct gguf_context * ctx) {
|
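The typed gguf_get_val_* accessors in the hunk above now assert that the stored type matches the accessor, so a caller is expected to look the key up and check its declared type first. A small usage sketch under that assumption; the key name "example.count" is made up.

// Illustrative: read an optional uint32 key from a gguf context, checking the
// declared type before calling the now-asserting typed accessor.
#include "ggml.h"
#include <stdint.h>

uint32_t sketch_read_count(const struct gguf_context * gguf_ctx) {
    const int key_id = gguf_find_key(gguf_ctx, "example.count");
    if (key_id < 0) {
        return 0; // key not present
    }
    if (gguf_get_kv_type(gguf_ctx, key_id) != GGUF_TYPE_UINT32) {
        return 0; // unexpected type; the accessor below would assert
    }
    return gguf_get_val_u32(gguf_ctx, key_id);
}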
@@ -20591,10 +21345,10 @@ static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf *
                         } break;
                     case GGUF_TYPE_ARRAY:
                     case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break;
-                }
+                }
             } break;
         case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type");
-        }
+        }
     }

     // write tensor infos