llama_cpp 0.5.3 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +6 -5
- data/examples/chat.rb +13 -13
- data/examples/embedding.rb +9 -9
- data/ext/llama_cpp/llama_cpp.cpp +547 -272
- data/ext/llama_cpp/src/ggml-alloc.c +8 -2
- data/ext/llama_cpp/src/ggml-alloc.h +1 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +209 -82
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.m +163 -84
- data/ext/llama_cpp/src/ggml-metal.metal +121 -38
- data/ext/llama_cpp/src/ggml.c +1596 -842
- data/ext/llama_cpp/src/ggml.h +116 -35
- data/ext/llama_cpp/src/llama.cpp +1015 -586
- data/ext/llama_cpp/src/llama.h +304 -119
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +5 -9
- data/sig/llama_cpp.rbs +65 -34
- metadata +3 -3
data/ext/llama_cpp/src/ggml.c
CHANGED
@@ -89,7 +89,9 @@ static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(vo
 
 static int pthread_join(pthread_t thread, void * unused) {
     (void) unused;
-
+    int ret = (int) WaitForSingleObject(thread, INFINITE);
+    CloseHandle(thread);
+    return ret;
 }
 
 static int sched_yield (void) {
@@ -134,6 +136,7 @@ typedef void * thread_ret_t;
 
 #define GGML_SOFT_MAX_UNROLL 4
 #define GGML_VEC_DOT_UNROLL 2
+#define GGML_VEC_MAD_UNROLL 32
 
 //
 // logging
@@ -242,18 +245,18 @@ inline static void * ggml_aligned_malloc(size_t size) {
 //
 
 #define GGML_TENSOR_UNARY_OP_LOCALS \
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
-    GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
-    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
-    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
+    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
 
 #define GGML_TENSOR_BINARY_OP_LOCALS \
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
-    GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
-    GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
-    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
-    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
+    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
 
 #if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
@@ -1863,7 +1866,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
 #define GGML_F16x8_ADD vaddq_f16
 #define GGML_F16x8_MUL vmulq_f16
 #define GGML_F16x8_REDUCE(res, x) \
-{
+do { \
     int offset = GGML_F16_ARR >> 1; \
     for (int i = 0; i < offset; ++i) { \
         x[i] = vaddq_f16(x[i], x[offset+i]); \
@@ -1879,7 +1882,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
     const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
     const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
     res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \
-}
+} while (0)
 
 #define GGML_F16_VEC GGML_F16x8
 #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
@@ -1940,7 +1943,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
 #define GGML_F32x8_ADD _mm256_add_ps
 #define GGML_F32x8_MUL _mm256_mul_ps
 #define GGML_F32x8_REDUCE(res, x) \
-{
+do { \
     int offset = GGML_F32_ARR >> 1; \
     for (int i = 0; i < offset; ++i) { \
         x[i] = _mm256_add_ps(x[i], x[offset+i]); \
@@ -1957,7 +1960,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
                                  _mm256_extractf128_ps(x[0], 1)); \
     const __m128 t1 = _mm_hadd_ps(t0, t0); \
     res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \
-}
+} while (0)
 // TODO: is this optimal ?
 
 #define GGML_F32_VEC GGML_F32x8
@@ -3707,6 +3710,58 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
 #endif
 }
 
+// xs and vs are byte strides of x and v
+inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
+
+    const float * restrict x[GGML_VEC_MAD_UNROLL];
+    const float * restrict v[GGML_VEC_MAD_UNROLL];
+
+    for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
+        x[i] = (const float *) ((const char *) xv + i*xs);
+        v[i] = (const float *) ((const char *) vv + i*vs);
+    }
+
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];
+
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        vx[k] = GGML_F32_VEC_SET1(v[k][0]);
+    }
+
+    GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+
+            for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+                ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
+                ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
+            }
+
+            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+        }
+    }
+
+    // leftovers
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        for (int i = np; i < n; ++i) {
+            y[i] += x[k][i]*v[k][0];
+        }
+    }
+#else
+    // scalar
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        for (int i = 0; i < n; ++i) {
+            y[i] += x[k][i]*v[k][0];
+        }
+    }
+#endif
+}
+
 //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
 inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
 #if defined(GGML_USE_ACCELERATE)
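For readers following the new `ggml_vec_mad_f32_unroll` routine above: it fuses `GGML_VEC_MAD_UNROLL` multiply-accumulate passes over the same output row, matching the scalar fallback in the hunk. A minimal reference sketch of the same semantics on plain arrays (the function name and array-of-pointers layout here are illustrative, not from ggml):

```c
// Sketch only: y[i] += sum over k of x[k][i] * v[k][0],
// i.e. several ggml_vec_mad_f32 calls fused into one pass.
static void mad_unroll_reference(int n, int unroll, float *y,
                                 const float *const *x, const float *const *v) {
    for (int k = 0; k < unroll; ++k) {
        for (int i = 0; i < n; ++i) {
            y[i] += x[k][i] * v[k][0];
        }
    }
}
```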
@@ -4392,10 +4447,9 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
 static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
-    return
-
-
-        (t0->ne[3] == t1->ne[3]);
+    return (t0->ne[1] == t1->ne[1]) &&
+           (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
+           (t1->ne[3]%t0->ne[3] == 0);
 }
 
 enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
@@ -5065,43 +5119,78 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
     return tensor;
 }
 
+void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) {
+    const int64_t ne2 = tensor->ne[2];
+    const int64_t ne1 = tensor->ne[1];
+    const int64_t ne0 = tensor->ne[0];
+
+    const int64_t i3_ = (i/(ne2*ne1*ne0));
+    const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0);
+    const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0;
+    const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0);
+
+    if (i0) {
+        * i0 = i0_;
+    }
+    if (i1) {
+        * i1 = i1_;
+    }
+    if (i2) {
+        * i2 = i2_;
+    }
+    if (i3) {
+        * i3 = i3_;
+    }
+}
+
 int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
+    if (!ggml_is_contiguous(tensor)) {
+        int64_t id[4] = { 0, 0, 0, 0 };
+        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+        return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]);
+    }
     switch (tensor->type) {
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
                 return ((int8_t *)(tensor->data))[i];
-            }
+            }
         case GGML_TYPE_I16:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
                 return ((int16_t *)(tensor->data))[i];
-            }
+            }
         case GGML_TYPE_I32:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
                 return ((int32_t *)(tensor->data))[i];
-            }
+            }
         case GGML_TYPE_F16:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
                 return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
-            }
+            }
         case GGML_TYPE_F32:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(float));
                 return ((float *)(tensor->data))[i];
-            }
+            }
         default:
             {
                 GGML_ASSERT(false);
-            }
+            }
     }
 
     return 0.0f;
 }
 
 void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
+    if (!ggml_is_contiguous(tensor)) {
+        int64_t id[4] = { 0, 0, 0, 0 };
+        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+        ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value);
+        return;
+    }
     switch (tensor->type) {
         case GGML_TYPE_I8:
             {
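The 1-D accessors above now route non-contiguous tensors through the new `ggml_unravel_index`, which splits a flat element index into per-dimension coordinates using row-major arithmetic. A small sketch of the same arithmetic on plain arrays (names are illustrative only):

```c
#include <stdint.h>

// Unravel flat index i into (out[0]..out[3]) for extents ne[0..3],
// mirroring the divisions used in ggml_unravel_index above.
static void unravel4(int64_t i, const int64_t ne[4], int64_t out[4]) {
    out[3] = i / (ne[2]*ne[1]*ne[0]);
    i     -= out[3] * ne[2]*ne[1]*ne[0];
    out[2] = i / (ne[1]*ne[0]);
    i     -= out[2] * ne[1]*ne[0];
    out[1] = i / ne[0];
    out[0] = i - out[1]*ne[0];
}
```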
@@ -5135,43 +5224,104 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
     }
 }
 
+int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
+    void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            return ((int8_t *) data)[0];
+        case GGML_TYPE_I16:
+            return ((int16_t *) data)[0];
+        case GGML_TYPE_I32:
+            return ((int32_t *) data)[0];
+        case GGML_TYPE_F16:
+            return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
+        case GGML_TYPE_F32:
+            return ((float *) data)[0];
+        default:
+            GGML_ASSERT(false);
+    }
+
+    return 0.0f;
+}
+
+void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) {
+    void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                ((int8_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_I16:
+            {
+                ((int16_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_I32:
+            {
+                ((int32_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ((float *)(data))[0] = value;
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
+    if (!ggml_is_contiguous(tensor)) {
+        int64_t id[4] = { 0, 0, 0, 0 };
+        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+        return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]);
+    }
     switch (tensor->type) {
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
                 return ((int8_t *)(tensor->data))[i];
-            }
+            }
         case GGML_TYPE_I16:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
                 return ((int16_t *)(tensor->data))[i];
-            }
+            }
         case GGML_TYPE_I32:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
                 return ((int32_t *)(tensor->data))[i];
-            }
+            }
         case GGML_TYPE_F16:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
                 return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
-            }
+            }
         case GGML_TYPE_F32:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(float));
                 return ((float *)(tensor->data))[i];
-            }
+            }
         default:
             {
                 GGML_ASSERT(false);
-            }
+            }
     }
 
     return 0.0f;
 }
 
 void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
+    if (!ggml_is_contiguous(tensor)) {
+        int64_t id[4] = { 0, 0, 0, 0 };
+        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+        ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value);
+        return;
+    }
     switch (tensor->type) {
         case GGML_TYPE_I8:
             {
@@ -5205,6 +5355,56 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
     }
 }
 
+float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
+    void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            return ((int8_t *) data)[0];
+        case GGML_TYPE_I16:
+            return ((int16_t *) data)[0];
+        case GGML_TYPE_I32:
+            return ((int32_t *) data)[0];
+        case GGML_TYPE_F16:
+            return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
+        case GGML_TYPE_F32:
+            return ((float *) data)[0];
+        default:
+            GGML_ASSERT(false);
+    }
+
+    return 0.0f;
+}
+
+void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) {
+    void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                ((int8_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_I16:
+            {
+                ((int16_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_I32:
+            {
+                ((int32_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ((float *)(data))[0] = value;
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 void * ggml_get_data(const struct ggml_tensor * tensor) {
     return tensor->data;
 }
@@ -5347,6 +5547,44 @@ struct ggml_tensor * ggml_add_inplace(
     return ggml_add_impl(ctx, a, b, true);
 }
 
+// ggml_add_cast
+
+static struct ggml_tensor * ggml_add_cast_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        enum ggml_type type) {
+    // TODO: support less-strict constraint
+    //       GGML_ASSERT(ggml_can_repeat(b, a));
+    GGML_ASSERT(ggml_can_repeat_rows(b, a));
+    GGML_ASSERT(ggml_is_quantized(a->type)); // currently only supported for quantized input
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        // TODO: support backward pass for broadcasting
+        GGML_ASSERT(ggml_are_same_shape(a, b));
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, type, a->n_dims, a->ne);
+
+    result->op = GGML_OP_ADD;
+    result->grad = is_node ? ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne) : NULL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_add_cast(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        enum ggml_type type) {
+    return ggml_add_cast_impl(ctx, a, b, type);
+}
+
 // ggml_add1
 
 static struct ggml_tensor * ggml_add1_impl(
@@ -5783,7 +6021,6 @@ struct ggml_tensor * ggml_repeat(
     result->op = GGML_OP_REPEAT;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = b;
 
     return result;
 }
@@ -5811,7 +6048,6 @@ struct ggml_tensor * ggml_repeat_back(
     result->op = GGML_OP_REPEAT_BACK;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = b;
 
     return result;
 }
@@ -6186,8 +6422,9 @@ struct ggml_tensor * ggml_out_prod(
         is_node = true;
     }
 
-
-
+    // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]
+    const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
 
     result->op = GGML_OP_OUT_PROD;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6406,6 +6643,54 @@ struct ggml_tensor * ggml_cont_inplace(
     return ggml_cont_impl(ctx, a, true);
 }
 
+
+// make contiguous, with new shape
+GGML_API struct ggml_tensor * ggml_cont_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        int64_t ne0) {
+    return ggml_cont_4d(ctx, a, ne0, 1, 1, 1);
+}
+
+GGML_API struct ggml_tensor * ggml_cont_2d(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        int64_t ne0,
+        int64_t ne1) {
+    return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1);
+}
+
+GGML_API struct ggml_tensor * ggml_cont_3d(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        int64_t ne0,
+        int64_t ne1,
+        int64_t ne2) {
+    return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1);
+}
+
+struct ggml_tensor * ggml_cont_4d(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        int64_t ne0,
+        int64_t ne1,
+        int64_t ne2,
+        int64_t ne3) {
+    GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3));
+
+    bool is_node = false;
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
+    ggml_format_name(result, "%s (cont)", a->name);
+
+    result->op = GGML_OP_CONT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
+
 // ggml_reshape
 
 struct ggml_tensor * ggml_reshape(
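The new `ggml_cont_1d/2d/3d/4d` helpers above combine a make-contiguous copy with an explicit target shape; the element count must match. A hedged usage sketch (the wrapper name, context, and shapes are made up for illustration):

```c
#include "ggml.h"

// Illustrative only: turn a permuted/strided view `t` holding 32 elements
// into a contiguous 8x4 tensor inside an existing context `ctx`.
struct ggml_tensor * make_contiguous_8x4(struct ggml_context * ctx,
                                         struct ggml_tensor  * t) {
    return ggml_cont_2d(ctx, t, 8, 4);
}
```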
@@ -6413,7 +6698,7 @@ struct ggml_tensor * ggml_reshape(
         struct ggml_tensor * a,
         struct ggml_tensor * b) {
     GGML_ASSERT(ggml_is_contiguous(a));
-
+    // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous.
     GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
 
     bool is_node = false;
@@ -6786,7 +7071,6 @@ struct ggml_tensor * ggml_get_rows_back(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
-    result->src[2] = c;
 
     return result;
 }
@@ -6968,7 +7252,7 @@ struct ggml_tensor * ggml_soft_max_back_inplace(
 static struct ggml_tensor * ggml_rope_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-
+        struct ggml_tensor * b,
         int n_dims,
         int mode,
         int n_ctx,
@@ -6977,7 +7261,10 @@ static struct ggml_tensor * ggml_rope_impl(
         float xpos_base,
         bool xpos_down,
         bool inplace) {
-    GGML_ASSERT(
+    GGML_ASSERT(ggml_is_vector(b));
+    GGML_ASSERT(b->type == GGML_TYPE_I32);
+    GGML_ASSERT(a->ne[2] == b->ne[0]);
+
     bool is_node = false;
 
     if (a->grad) {
@@ -6986,7 +7273,7 @@ static struct ggml_tensor * ggml_rope_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    int32_t params[8] = { n_past, n_dims, mode, n_ctx };
+    int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
     memcpy(params + 4, &freq_base, sizeof(float));
     memcpy(params + 5, &freq_scale, sizeof(float));
     memcpy(params + 6, &xpos_base, sizeof(float));
@@ -6996,6 +7283,7 @@ static struct ggml_tensor * ggml_rope_impl(
     result->op = GGML_OP_ROPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
+    result->src[1] = b;
 
     return result;
 }
@@ -7003,55 +7291,55 @@ static struct ggml_tensor * ggml_rope_impl(
 struct ggml_tensor * ggml_rope(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-
+        struct ggml_tensor * b,
         int n_dims,
         int mode,
         int n_ctx) {
-    return ggml_rope_impl(ctx, a,
+    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
 }
 
 struct ggml_tensor * ggml_rope_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-
+        struct ggml_tensor * b,
         int n_dims,
         int mode,
         int n_ctx) {
-    return ggml_rope_impl(ctx, a,
+    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
 }
 
 struct ggml_tensor * ggml_rope_custom(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-
+        struct ggml_tensor * b,
         int n_dims,
         int mode,
         int n_ctx,
         float freq_base,
         float freq_scale) {
-    return ggml_rope_impl(ctx, a,
+    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
 }
 
 struct ggml_tensor * ggml_rope_custom_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-
+        struct ggml_tensor * b,
         int n_dims,
         int mode,
         int n_ctx,
         float freq_base,
         float freq_scale) {
-    return ggml_rope_impl(ctx, a,
+    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
 }
 
 struct ggml_tensor * ggml_rope_xpos_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-
+        struct ggml_tensor * b,
         int n_dims,
         float base,
         bool down) {
-    return ggml_rope_impl(ctx, a,
+    return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
 }
 
 // ggml_rope_back
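As the hunk above shows, every rope entry point now takes a second tensor `b` holding one I32 position per token (with `b->ne[0] == a->ne[2]`) instead of relying on a scalar `n_past` op parameter. A hedged sketch of how a caller might build such a positions tensor under the new signature (the wrapper name and variables are illustrative; it assumes a context where tensor data is allocated):

```c
#include "ggml.h"

// Illustrative only: apply RoPE with explicit per-token positions.
// `ctx`, `cur` (the activations) and n_tokens/n_past are assumed to exist.
struct ggml_tensor * rope_with_positions(struct ggml_context * ctx,
                                         struct ggml_tensor  * cur,
                                         int n_tokens, int n_past,
                                         int n_dims, int mode, int n_ctx) {
    struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
    for (int i = 0; i < n_tokens; ++i) {
        ggml_set_i32_1d(pos, i, n_past + i);   // one absolute position per token
    }
    return ggml_rope(ctx, cur, pos, n_dims, mode, n_ctx);
}
```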
@@ -7059,7 +7347,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
 struct ggml_tensor * ggml_rope_back(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-
+        struct ggml_tensor * b,
         int n_dims,
         int mode,
         int n_ctx,
@@ -7067,7 +7355,10 @@ struct ggml_tensor * ggml_rope_back(
         float freq_scale,
         float xpos_base,
         bool xpos_down) {
-    GGML_ASSERT(
+    GGML_ASSERT(ggml_is_vector(b));
+    GGML_ASSERT(b->type == GGML_TYPE_I32);
+    GGML_ASSERT(a->ne[2] == b->ne[0]);
+
     GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
 
     bool is_node = false;
@@ -7078,7 +7369,7 @@ struct ggml_tensor * ggml_rope_back(
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
-    int32_t params[8] = { n_past, n_dims, mode, n_ctx };
+    int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
     memcpy(params + 4, &freq_base, sizeof(float));
     memcpy(params + 5, &freq_scale, sizeof(float));
     memcpy(params + 6, &xpos_base, sizeof(float));
@@ -7088,6 +7379,7 @@ struct ggml_tensor * ggml_rope_back(
     result->op = GGML_OP_ROPE_BACK;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
+    result->src[1] = b;
 
     return result;
 }
@@ -7484,27 +7776,30 @@ struct ggml_tensor * ggml_flash_attn_back(
 
     // d shape [D,N,ne2,ne3]
     // q shape [D,N,ne2,ne3]
-    // k shape [D,M,
-    // v shape [M,D,
+    // k shape [D,M,kvne2,ne3]
+    // v shape [M,D,kvne2,ne3]
 
-    const int64_t
-    const int64_t
-    const int64_t
-    const int64_t
-    const int64_t
+    const int64_t D = q->ne[0];
+    const int64_t N = q->ne[1];
+    const int64_t M = k->ne[1];
+    const int64_t ne2 = q->ne[2];
+    const int64_t ne3 = q->ne[3];
+    const int64_t kvne2 = k->ne[2];
 
     GGML_ASSERT(k->ne[0] == D);
     GGML_ASSERT(v->ne[0] == M);
     GGML_ASSERT(v->ne[1] == D);
     GGML_ASSERT(d->ne[0] == D);
     GGML_ASSERT(d->ne[1] == N);
-    GGML_ASSERT(k->ne[2] ==
+    GGML_ASSERT(k->ne[2] == kvne2);
     GGML_ASSERT(k->ne[3] == ne3);
-    GGML_ASSERT(v->ne[2] ==
+    GGML_ASSERT(v->ne[2] == kvne2);
     GGML_ASSERT(v->ne[3] == ne3);
     GGML_ASSERT(d->ne[2] == ne2);
     GGML_ASSERT(d->ne[3] == ne3);
 
+    GGML_ASSERT(ne2 % kvne2 == 0);
+
     bool is_node = false;
 
     if (q->grad || k->grad || v->grad) {
@@ -7514,14 +7809,23 @@ struct ggml_tensor * ggml_flash_attn_back(
     }
 
     // store gradients of q, k and v as continuous tensors concatenated in result.
-    // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
-    // gradq->data = result->data
-    // gradk->data = result->data + nb0*D*N*ne2*ne3
-    // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
     // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
-    int64_t
+    const int64_t elem_q = ggml_nelements(q);
+    const int64_t elem_k = ggml_nelements(k);
+    const int64_t elem_v = ggml_nelements(v);
 
-
+    enum ggml_type result_type = GGML_TYPE_F32;
+    GGML_ASSERT(ggml_blck_size(result_type) == 1);
+    const size_t tsize = ggml_type_size(result_type);
+
+    const size_t offs_q = 0;
+    const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
+    const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
+    const size_t end    = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN);
+
+    const size_t nelements = (end + tsize - 1)/tsize;
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements);
 
     int32_t masked_i = masked ? 1 : 0;
     ggml_set_op_params(result, &masked_i, sizeof(masked_i));
@@ -8214,7 +8518,7 @@ static void ggml_compute_forward_dup_f16(
         return;
     }
 
-    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
@@ -8485,7 +8789,7 @@ static void ggml_compute_forward_dup_f32(
         return;
     }
 
-    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
@@ -8766,7 +9070,7 @@ static void ggml_compute_forward_add_f32(
 
     const int nr = ggml_nrows(src0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
@@ -8798,8 +9102,6 @@ static void ggml_compute_forward_add_f32(
 #else
             ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
 #endif
-                // }
-            // }
         }
     } else {
         // src1 is not contiguous
@@ -8841,7 +9143,7 @@ static void ggml_compute_forward_add_f16_f32(
 
     const int nr = ggml_nrows(src0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -8895,7 +9197,7 @@ static void ggml_compute_forward_add_f16_f16(
 
     const int nr = ggml_nrows(src0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F16);
@@ -8946,14 +9248,15 @@ static void ggml_compute_forward_add_q_f32(
 
     const int nr = ggml_nrows(src0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
     const int nth = params->nth;
 
     const enum ggml_type type = src0->type;
+    const enum ggml_type dtype = dst->type;
     ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
-    ggml_from_float_t const quantize_row_q = type_traits[
+    ggml_from_float_t const quantize_row_q = type_traits[dtype].from_float;
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == ggml_type_size(type));
@@ -8965,7 +9268,6 @@ static void ggml_compute_forward_add_q_f32(
     GGML_ASSERT(nb2 <= nb3);
 
     GGML_ASSERT(ggml_is_quantized(src0->type));
-    GGML_ASSERT(dst->type == src0->type);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
     // rows per thread
@@ -9003,7 +9305,11 @@ static void ggml_compute_forward_add_q_f32(
         // add src1
         ggml_vec_acc_f32(ne00, wdata, src1_row);
         // quantize row to dst
-        quantize_row_q
+        if (quantize_row_q != NULL) {
+            quantize_row_q(wdata, dst_row, ne00);
+        } else {
+            memcpy(dst_row, wdata, ne0*nb0);
+        }
     }
 }
 
@@ -9068,7 +9374,7 @@ static void ggml_compute_forward_add1_f32(
 
     const int nr = ggml_nrows(src0);
 
-    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
@@ -9123,7 +9429,7 @@ static void ggml_compute_forward_add1_f16_f32(
 
     const int nr = ggml_nrows(src0);
 
-    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -9173,7 +9479,7 @@ static void ggml_compute_forward_add1_f16_f16(
 
     const int nr = ggml_nrows(src0);
 
-    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F16);
@@ -9223,7 +9529,7 @@ static void ggml_compute_forward_add1_q_f32(
 
     const int nr = ggml_nrows(src0);
 
-    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     const enum ggml_type type = src0->type;
     ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
@@ -9351,8 +9657,8 @@ static void ggml_compute_forward_acc_f32(
     const int nr = ggml_nrows(src1);
     const int nc = src1->ne[0];
 
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
-    GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+    GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
 
     // src0 and dst as viewed during acc
     const size_t nb0 = ggml_element_size(src0);
@@ -9441,7 +9747,7 @@ static void ggml_compute_forward_sub_f32(
 
     const int nr = ggml_nrows(src0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
@@ -9531,7 +9837,7 @@ static void ggml_compute_forward_mul_f32(
 
     const int64_t nr = ggml_nrows(src0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
@@ -9622,7 +9928,7 @@ static void ggml_compute_forward_div_f32(
 
     const int nr = ggml_nrows(src0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
@@ -9831,8 +10137,8 @@ static void ggml_compute_forward_sum_f32(
     assert(ggml_is_scalar(dst));
     assert(src0->nb[0] == sizeof(float));
 
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
-    GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
 
     ggml_float sum = 0;
     ggml_float row_sum = 0;
@@ -9863,8 +10169,8 @@ static void ggml_compute_forward_sum_f16(
 
     assert(src0->nb[0] == sizeof(ggml_fp16_t));
 
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
-    GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
 
     float sum = 0;
     float row_sum = 0;
@@ -9917,7 +10223,7 @@ static void ggml_compute_forward_sum_rows_f32(
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(dst->nb[0] == sizeof(float));
 
-    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     GGML_ASSERT(ne0 == 1);
     GGML_ASSERT(ne1 == ne01);
@@ -9967,7 +10273,7 @@ static void ggml_compute_forward_mean_f32(
 
     assert(src0->nb[0] == sizeof(float));
 
-    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     assert(ne0 == 1);
     assert(ne1 == ne01);
@@ -10067,7 +10373,7 @@ static void ggml_compute_forward_repeat_f32(
         return;
     }
 
-    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     // guaranteed to be an integer due to the check in ggml_can_repeat
     const int nr0 = (int)(ne0/ne00);
@@ -10099,11 +10405,61 @@ static void ggml_compute_forward_repeat_f32(
     }
 }
 
+static void ggml_compute_forward_repeat_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_can_repeat(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne0/ne00);
+    const int nr1 = (int)(ne1/ne01);
+    const int nr2 = (int)(ne2/ne02);
+    const int nr3 = (int)(ne3/ne03);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // TODO: maybe this is not optimal?
+    for (int i3 = 0; i3 < nr3; i3++) {
+        for (int k3 = 0; k3 < ne03; k3++) {
+            for (int i2 = 0; i2 < nr2; i2++) {
+                for (int k2 = 0; k2 < ne02; k2++) {
+                    for (int i1 = 0; i1 < nr1; i1++) {
+                        for (int k1 = 0; k1 < ne01; k1++) {
+                            for (int i0 = 0; i0 < nr0; i0++) {
+                                ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
+                                ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
+                                // ggml_vec_cpy_f16(ne00, y, x)
+                                for (int i = 0; i < ne00; ++i) {
+                                    y[i] = x[i];
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
 static void ggml_compute_forward_repeat(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
     switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_repeat_f16(params, src0, dst);
+            } break;
         case GGML_TYPE_F32:
             {
                 ggml_compute_forward_repeat_f32(params, src0, dst);
@@ -10128,7 +10484,7 @@ static void ggml_compute_forward_repeat_back_f32(
         return;
     }
 
-    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     // guaranteed to be an integer due to the check in ggml_can_repeat
     const int nr0 = (int)(ne00/ne0);
@@ -10206,7 +10562,7 @@ static void ggml_compute_forward_concat_f32(
 
     const int ith = params->ith;
 
-    GGML_TENSOR_BINARY_OP_LOCALS
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     // TODO: support for transposed / permuted tensors
     GGML_ASSERT(nb0 == sizeof(float));
@@ -10808,7 +11164,7 @@ static void ggml_compute_forward_norm_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
@@ -10877,7 +11233,7 @@ static void ggml_compute_forward_rms_norm_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
@@ -10942,7 +11298,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_TENSOR_BINARY_OP_LOCALS
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
@@ -11117,7 +11473,7 @@ static void ggml_compute_forward_group_norm_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     const float eps = 1e-6f; // TODO: make this a parameter
 
@@ -11228,7 +11584,7 @@ static void ggml_compute_forward_mul_mat(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -11443,10 +11799,10 @@ static void ggml_compute_forward_out_prod_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
               struct ggml_tensor * dst) {
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
+    // int64_t t0 = ggml_perf_time_us();
+    // UNUSED(t0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -11485,6 +11841,146 @@ static void ggml_compute_forward_out_prod_f32(
         return;
     }
 
+    // dst[:,:,:,:] = 0
+    // for i2,i3:
+    //   for i1:
+    //     for i01:
+    //       for i0:
+    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
+    // parallelize by last three dimensions
+
+    // total rows in dst
+    const int64_t nr = ne1*ne2*ne3;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    // block-tiling attempt
+    const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
+    const int64_t blck_1 = 16;
+
+    for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
+        const int64_t bir1 = MIN(bir + blck_1, ir1);
+        for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
+            const int64_t bne01 = MIN(bi01 + blck_0, ne01);
+            for (int64_t ir = bir; ir < bir1; ++ir) {
+                // dst indices
+                const int64_t i3 = ir/(ne2*ne1);
+                const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+                const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                const int64_t i02 = i2;
+                const int64_t i03 = i3;
+
+                //const int64_t i10 = i1;
+                const int64_t i12 = i2;
+                const int64_t i13 = i3;
+
+#if GGML_VEC_MAD_UNROLL > 2
+                const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
+                for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
+                    const int64_t i11 = i01;
+
+                    float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
+                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float * d  = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+
+                    ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
+                }
+                for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
+                    const int64_t i11 = i01;
+
+                    float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
+                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float * d  = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+
+                    ggml_vec_mad_f32(ne0, d, s0, *s1);
+                }
+#else
+                for (int64_t i01 = bi01; i01 < bne01; ++i01) {
+                    const int64_t i11 = i01;
+
+                    float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
+                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float * d  = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+
+                    ggml_vec_mad_f32(ne0, d, s0, *s1);
+                }
+#endif
+            }
+        }
+    }
+
+
+    //int64_t t1 = ggml_perf_time_us();
+    //static int64_t acc = 0;
+    //acc += t1 - t0;
+    //if (t1 - t0 > 10) {
+    //    printf("\n");
+    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
+
+    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+    //}
+}
+
+static void ggml_compute_forward_out_prod_q_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    // int64_t t0 = ggml_perf_time_us();
+    // UNUSED(t0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const enum ggml_type type = src0->type;
+    ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    // we don't support permuted src0 dim0
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+
+    // dst dim0 cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+    // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
+    // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+
+    if (params->type == GGML_TASK_INIT) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
     // parallelize by last three dimensions
 
     // total rows in dst
@@ -11504,6 +12000,8 @@ static void ggml_compute_forward_out_prod_f32(
     //       for i0:
     //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
 
+    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
+
     for (int64_t ir = ir0; ir < ir1; ++ir) {
         // dst indices
         const int64_t i3 = ir/(ne2*ne1);
@@ -11524,10 +12022,8 @@ static void ggml_compute_forward_out_prod_f32(
             float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
             float * d  = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
 
-
-
-            //     d[i0] += s0[i0] * s1[i1];
-            // }
+            dequantize_row_q(s0, wdata, ne0);
+            ggml_vec_mad_f32(ne0, d, wdata, *s1);
         }
     }
 
@@ -11556,10 +12052,13 @@ static void ggml_compute_forward_out_prod(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
-        case
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
             {
-
-                // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
+                ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F16:
             {
@@ -11677,8 +12176,8 @@ static void ggml_compute_forward_set_f32(
     const int nr = ggml_nrows(src1);
     const int nc = src1->ne[0];
 
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
-    GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+    GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
 
     // src0 and dst as viewed during set
     const size_t nb0 = ggml_element_size(src0);
@@ -11947,14 +12446,15 @@ static void ggml_compute_forward_get_rows_back_f32_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
               struct ggml_tensor * dst) {
     GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_are_same_shape(opt0, dst));
-    GGML_ASSERT(ggml_is_contiguous(opt0));
     GGML_ASSERT(ggml_is_contiguous(dst));
 
-    ggml_compute_forward_dup_same_cont(params, opt0, dst);
+    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+    if (params->type == GGML_TASK_INIT) {
+        memset(dst->data, 0, ggml_nbytes(dst));
+    }
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -11980,11 +12480,8 @@ static void ggml_compute_forward_get_rows_back_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
               struct ggml_tensor * dst) {
     GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_are_same_shape(opt0, dst));
-    GGML_ASSERT(ggml_is_contiguous(opt0));
     GGML_ASSERT(ggml_is_contiguous(dst));
 
     // ggml_compute_forward_dup_same_cont(params, opt0, dst);
@@ -12018,16 +12515,15 @@ static void ggml_compute_forward_get_rows_back(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
               struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1,
+                ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_get_rows_back_f32(params, src0, src1,
+                ggml_compute_forward_get_rows_back_f32(params, src0, src1, dst);
             } break;
         default:
             {
@@ -12068,7 +12564,7 @@ static void ggml_compute_forward_diag_f32(
 
     // TODO: handle transposed/permuted matrices
 
-    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     GGML_ASSERT(ne00 == ne0);
     GGML_ASSERT(ne00 == ne1);
@@ -12456,13 +12952,11 @@ static void ggml_compute_forward_alibi_f16(
         return;
     }
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_head = ((int32_t *) dst->op_params)[1];
     float max_bias;
     memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
-    assert(n_past >= 0);
-
     const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
     const int ne1 = src0->ne[1]; // seq_len_without_past
     const int ne2 = src0->ne[2]; // n_head -> this is k
@@ -12477,7 +12971,7 @@ static void ggml_compute_forward_alibi_f16(
     //const int nb3 = src0->nb[3];
 
     GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
+    //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
     GGML_ASSERT(n_head == ne2);
 
     // add alibi to src0 (KQ_scaled)
@@ -12623,8 +13117,8 @@ static void ggml_compute_forward_clamp(
 static void ggml_compute_forward_rope_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
               struct ggml_tensor * dst) {
-
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
@@ -12634,9 +13128,9 @@ static void ggml_compute_forward_rope_f32(
|
|
12634
13128
|
|
12635
13129
|
// these two only relevant for xPos RoPE:
|
12636
13130
|
float xpos_base;
|
12637
|
-
bool
|
13131
|
+
bool xpos_down;
|
12638
13132
|
|
12639
|
-
const int n_past = ((int32_t *) dst->op_params)[0];
|
13133
|
+
//const int n_past = ((int32_t *) dst->op_params)[0];
|
12640
13134
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
12641
13135
|
const int mode = ((int32_t *) dst->op_params)[2];
|
12642
13136
|
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
@@ -12645,9 +13139,7 @@ static void ggml_compute_forward_rope_f32(
|
|
12645
13139
|
memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
|
12646
13140
|
memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
|
12647
13141
|
|
12648
|
-
|
12649
|
-
|
12650
|
-
GGML_TENSOR_UNARY_OP_LOCALS;
|
13142
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
12651
13143
|
|
12652
13144
|
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
12653
13145
|
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
@@ -12677,9 +13169,11 @@ static void ggml_compute_forward_rope_f32(
|
|
12677
13169
|
const bool is_neox = mode & 2;
|
12678
13170
|
const bool is_glm = mode & 4;
|
12679
13171
|
|
13172
|
+
const int32_t * pos = (const int32_t *) src1->data;
|
13173
|
+
|
12680
13174
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
12681
|
-
for (int64_t i2 =
|
12682
|
-
const int64_t p =
|
13175
|
+
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
13176
|
+
const int64_t p = pos[i2];
|
12683
13177
|
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
12684
13178
|
if (ir++ < ir0) continue;
|
12685
13179
|
if (ir > ir1) break;
|
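
In the hunks above, ggml_compute_forward_rope_f32 now receives token positions through src1 (an int32 buffer read as pos[i2]) instead of deriving them from a single n_past op parameter, so each row can carry an arbitrary position. A caller that previously passed n_past would now typically fill the buffer with n_past + i for i in [0, ne2). A small self-contained sketch of the rotation applied at one position; the function and names are illustrative only, not part of ggml:

#include <math.h>

// Rotate one (x0, x1) feature pair by the angle used for position p and a
// given base angle theta_base; this mirrors the cos/sin update in
// ggml_compute_forward_rope_f32, with p now taken from the src1 buffer.
static void rope_rotate_pair(float * x0, float * x1, int p, float theta_base) {
    const float theta     = (float) p * theta_base;
    const float cos_theta = cosf(theta);
    const float sin_theta = sinf(theta);
    const float v0 = *x0;
    const float v1 = *x1;
    *x0 = v0 * cos_theta - v1 * sin_theta;
    *x1 = v0 * sin_theta + v1 * cos_theta;
}
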
@@ -12716,7 +13210,7 @@ static void ggml_compute_forward_rope_f32(
|
|
12716
13210
|
const float cos_theta = cosf(theta);
|
12717
13211
|
const float sin_theta = sinf(theta);
|
12718
13212
|
// zeta scaling for xPos only:
|
12719
|
-
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0),
|
13213
|
+
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
|
12720
13214
|
if (xpos_down) zeta = 1.0f / zeta;
|
12721
13215
|
|
12722
13216
|
theta *= theta_scale;
|
@@ -12761,8 +13255,8 @@ static void ggml_compute_forward_rope_f32(
|
|
12761
13255
|
static void ggml_compute_forward_rope_f16(
|
12762
13256
|
const struct ggml_compute_params * params,
|
12763
13257
|
const struct ggml_tensor * src0,
|
13258
|
+
const struct ggml_tensor * src1,
|
12764
13259
|
struct ggml_tensor * dst) {
|
12765
|
-
|
12766
13260
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
12767
13261
|
return;
|
12768
13262
|
}
|
@@ -12770,16 +13264,14 @@ static void ggml_compute_forward_rope_f16(
|
|
12770
13264
|
float freq_base;
|
12771
13265
|
float freq_scale;
|
12772
13266
|
|
12773
|
-
const int n_past = ((int32_t *) dst->op_params)[0];
|
13267
|
+
//const int n_past = ((int32_t *) dst->op_params)[0];
|
12774
13268
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
12775
13269
|
const int mode = ((int32_t *) dst->op_params)[2];
|
12776
13270
|
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
12777
13271
|
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
12778
13272
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
12779
13273
|
|
12780
|
-
|
12781
|
-
|
12782
|
-
GGML_TENSOR_UNARY_OP_LOCALS;
|
13274
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
12783
13275
|
|
12784
13276
|
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
12785
13277
|
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
@@ -12809,9 +13301,11 @@ static void ggml_compute_forward_rope_f16(
|
|
12809
13301
|
const bool is_neox = mode & 2;
|
12810
13302
|
const bool is_glm = mode & 4;
|
12811
13303
|
|
13304
|
+
const int32_t * pos = (const int32_t *) src1->data;
|
13305
|
+
|
12812
13306
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
12813
|
-
for (int64_t i2 =
|
12814
|
-
const int64_t p =
|
13307
|
+
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
13308
|
+
const int64_t p = pos[i2];
|
12815
13309
|
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
12816
13310
|
if (ir++ < ir0) continue;
|
12817
13311
|
if (ir > ir1) break;
|
@@ -12890,15 +13384,16 @@ static void ggml_compute_forward_rope_f16(
|
|
12890
13384
|
static void ggml_compute_forward_rope(
|
12891
13385
|
const struct ggml_compute_params * params,
|
12892
13386
|
const struct ggml_tensor * src0,
|
13387
|
+
const struct ggml_tensor * src1,
|
12893
13388
|
struct ggml_tensor * dst) {
|
12894
13389
|
switch (src0->type) {
|
12895
13390
|
case GGML_TYPE_F16:
|
12896
13391
|
{
|
12897
|
-
ggml_compute_forward_rope_f16(params, src0, dst);
|
13392
|
+
ggml_compute_forward_rope_f16(params, src0, src1, dst);
|
12898
13393
|
} break;
|
12899
13394
|
case GGML_TYPE_F32:
|
12900
13395
|
{
|
12901
|
-
ggml_compute_forward_rope_f32(params, src0, dst);
|
13396
|
+
ggml_compute_forward_rope_f32(params, src0, src1, dst);
|
12902
13397
|
} break;
|
12903
13398
|
default:
|
12904
13399
|
{
|
@@ -12912,6 +13407,7 @@ static void ggml_compute_forward_rope(
|
|
12912
13407
|
static void ggml_compute_forward_rope_back_f32(
|
12913
13408
|
const struct ggml_compute_params * params,
|
12914
13409
|
const struct ggml_tensor * src0,
|
13410
|
+
const struct ggml_tensor * src1,
|
12915
13411
|
struct ggml_tensor * dst) {
|
12916
13412
|
|
12917
13413
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
@@ -12929,7 +13425,7 @@ static void ggml_compute_forward_rope_back_f32(
|
|
12929
13425
|
float xpos_base;
|
12930
13426
|
bool xpos_down;
|
12931
13427
|
|
12932
|
-
const int n_past = ((int32_t *) dst->op_params)[0];
|
13428
|
+
//const int n_past = ((int32_t *) dst->op_params)[0];
|
12933
13429
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
12934
13430
|
const int mode = ((int32_t *) dst->op_params)[2];
|
12935
13431
|
const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx);
|
@@ -12938,9 +13434,7 @@ static void ggml_compute_forward_rope_back_f32(
|
|
12938
13434
|
memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
|
12939
13435
|
memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
|
12940
13436
|
|
12941
|
-
|
12942
|
-
|
12943
|
-
GGML_TENSOR_UNARY_OP_LOCALS;
|
13437
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
12944
13438
|
|
12945
13439
|
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
12946
13440
|
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
@@ -12966,9 +13460,11 @@ static void ggml_compute_forward_rope_back_f32(
|
|
12966
13460
|
|
12967
13461
|
const bool is_neox = mode & 2;
|
12968
13462
|
|
13463
|
+
const int32_t * pos = (const int32_t *) src1->data;
|
13464
|
+
|
12969
13465
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
12970
|
-
for (int64_t i2 =
|
12971
|
-
const int64_t p =
|
13466
|
+
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
13467
|
+
const int64_t p = pos[i2];
|
12972
13468
|
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
12973
13469
|
if (ir++ < ir0) continue;
|
12974
13470
|
if (ir > ir1) break;
|
@@ -12980,7 +13476,7 @@ static void ggml_compute_forward_rope_back_f32(
|
|
12980
13476
|
const float cos_theta = cosf(theta);
|
12981
13477
|
const float sin_theta = sinf(theta);
|
12982
13478
|
// zeta scaling for xPos only:
|
12983
|
-
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0),
|
13479
|
+
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
|
12984
13480
|
if (xpos_down) zeta = 1.0f / zeta;
|
12985
13481
|
|
12986
13482
|
theta *= theta_scale;
|
@@ -13023,6 +13519,7 @@ static void ggml_compute_forward_rope_back_f32(
|
|
13023
13519
|
static void ggml_compute_forward_rope_back_f16(
|
13024
13520
|
const struct ggml_compute_params * params,
|
13025
13521
|
const struct ggml_tensor * src0,
|
13522
|
+
const struct ggml_tensor * src1,
|
13026
13523
|
struct ggml_tensor * dst) {
|
13027
13524
|
|
13028
13525
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
@@ -13033,13 +13530,11 @@ static void ggml_compute_forward_rope_back_f16(
|
|
13033
13530
|
// dx = rope_back(dy, src1)
|
13034
13531
|
// src0 is dy, src1 contains options
|
13035
13532
|
|
13036
|
-
const int n_past = ((int32_t *) dst->op_params)[0];
|
13533
|
+
//const int n_past = ((int32_t *) dst->op_params)[0];
|
13037
13534
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
13038
13535
|
const int mode = ((int32_t *) dst->op_params)[2];
|
13039
13536
|
|
13040
|
-
|
13041
|
-
|
13042
|
-
GGML_TENSOR_UNARY_OP_LOCALS;
|
13537
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
13043
13538
|
|
13044
13539
|
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
13045
13540
|
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
@@ -13065,9 +13560,11 @@ static void ggml_compute_forward_rope_back_f16(
|
|
13065
13560
|
|
13066
13561
|
const bool is_neox = mode & 2;
|
13067
13562
|
|
13563
|
+
const int32_t * pos = (const int32_t *) src1->data;
|
13564
|
+
|
13068
13565
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
13069
|
-
for (int64_t i2 =
|
13070
|
-
const int64_t p =
|
13566
|
+
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
13567
|
+
const int64_t p = pos[i2];
|
13071
13568
|
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
13072
13569
|
if (ir++ < ir0) continue;
|
13073
13570
|
if (ir > ir1) break;
|
@@ -13119,15 +13616,16 @@ static void ggml_compute_forward_rope_back_f16(
|
|
13119
13616
|
static void ggml_compute_forward_rope_back(
|
13120
13617
|
const struct ggml_compute_params * params,
|
13121
13618
|
const struct ggml_tensor * src0,
|
13619
|
+
const struct ggml_tensor * src1,
|
13122
13620
|
struct ggml_tensor * dst) {
|
13123
13621
|
switch (src0->type) {
|
13124
13622
|
case GGML_TYPE_F16:
|
13125
13623
|
{
|
13126
|
-
ggml_compute_forward_rope_back_f16(params, src0, dst);
|
13624
|
+
ggml_compute_forward_rope_back_f16(params, src0, src1, dst);
|
13127
13625
|
} break;
|
13128
13626
|
case GGML_TYPE_F32:
|
13129
13627
|
{
|
13130
|
-
ggml_compute_forward_rope_back_f32(params, src0, dst);
|
13628
|
+
ggml_compute_forward_rope_back_f32(params, src0, src1, dst);
|
13131
13629
|
} break;
|
13132
13630
|
default:
|
13133
13631
|
{
|
@@ -13150,7 +13648,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
|
|
13150
13648
|
int64_t t0 = ggml_perf_time_us();
|
13151
13649
|
UNUSED(t0);
|
13152
13650
|
|
13153
|
-
GGML_TENSOR_BINARY_OP_LOCALS
|
13651
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
13154
13652
|
|
13155
13653
|
const int ith = params->ith;
|
13156
13654
|
const int nth = params->nth;
|
@@ -13241,7 +13739,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32(
|
|
13241
13739
|
int64_t t0 = ggml_perf_time_us();
|
13242
13740
|
UNUSED(t0);
|
13243
13741
|
|
13244
|
-
GGML_TENSOR_BINARY_OP_LOCALS
|
13742
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
13245
13743
|
|
13246
13744
|
const int ith = params->ith;
|
13247
13745
|
const int nth = params->nth;
|
@@ -13353,7 +13851,7 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
|
|
13353
13851
|
int64_t t0 = ggml_perf_time_us();
|
13354
13852
|
UNUSED(t0);
|
13355
13853
|
|
13356
|
-
GGML_TENSOR_BINARY_OP_LOCALS
|
13854
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
13357
13855
|
|
13358
13856
|
const int ith = params->ith;
|
13359
13857
|
const int nth = params->nth;
|
@@ -13444,7 +13942,7 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
|
|
13444
13942
|
int64_t t0 = ggml_perf_time_us();
|
13445
13943
|
UNUSED(t0);
|
13446
13944
|
|
13447
|
-
GGML_TENSOR_BINARY_OP_LOCALS
|
13945
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
13448
13946
|
|
13449
13947
|
const int ith = params->ith;
|
13450
13948
|
const int nth = params->nth;
|
@@ -13562,7 +14060,7 @@ static void ggml_compute_forward_conv_1d(
|
|
13562
14060
|
ggml_compute_forward_conv_1d_s2_ph(params, src0, src1, dst);
|
13563
14061
|
} else {
|
13564
14062
|
GGML_ASSERT(false); // only stride 1 and 2 supported
|
13565
|
-
}
|
14063
|
+
}
|
13566
14064
|
}
|
13567
14065
|
|
13568
14066
|
// ggml_compute_forward_conv_2d
|
@@ -13579,7 +14077,7 @@ static void ggml_compute_forward_conv_2d_f16_f32(
|
|
13579
14077
|
int64_t t0 = ggml_perf_time_us();
|
13580
14078
|
UNUSED(t0);
|
13581
14079
|
|
13582
|
-
GGML_TENSOR_BINARY_OP_LOCALS
|
14080
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
13583
14081
|
|
13584
14082
|
const int ith = params->ith;
|
13585
14083
|
const int nth = params->nth;
|
@@ -13699,7 +14197,7 @@ static void ggml_compute_forward_conv_transpose_2d(
|
|
13699
14197
|
int64_t t0 = ggml_perf_time_us();
|
13700
14198
|
UNUSED(t0);
|
13701
14199
|
|
13702
|
-
GGML_TENSOR_BINARY_OP_LOCALS
|
14200
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
13703
14201
|
|
13704
14202
|
const int ith = params->ith;
|
13705
14203
|
const int nth = params->nth;
|
@@ -13958,7 +14456,7 @@ static void ggml_compute_forward_upscale_f32(
|
|
13958
14456
|
|
13959
14457
|
const int ith = params->ith;
|
13960
14458
|
|
13961
|
-
GGML_TENSOR_UNARY_OP_LOCALS
|
14459
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
13962
14460
|
|
13963
14461
|
const int scale_factor = dst->op_params[0];
|
13964
14462
|
|
@@ -14010,14 +14508,14 @@ static void ggml_compute_forward_flash_attn_f32(
|
|
14010
14508
|
int64_t t0 = ggml_perf_time_us();
|
14011
14509
|
UNUSED(t0);
|
14012
14510
|
|
14013
|
-
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
14014
|
-
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
14015
|
-
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
14016
|
-
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
14017
|
-
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
14018
|
-
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
14019
|
-
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
14020
|
-
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
14511
|
+
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
14512
|
+
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
14513
|
+
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
14514
|
+
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
14515
|
+
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
14516
|
+
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
14517
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
14518
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
14021
14519
|
|
14022
14520
|
const int ith = params->ith;
|
14023
14521
|
const int nth = params->nth;
|
@@ -14087,10 +14585,11 @@ static void ggml_compute_forward_flash_attn_f32(
|
|
14087
14585
|
S[i] = -INFINITY;
|
14088
14586
|
}
|
14089
14587
|
|
14090
|
-
|
14588
|
+
const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
|
14589
|
+
for (int64_t ic = 0; ic < masked_begin; ++ic) {
|
14091
14590
|
// k indices
|
14092
14591
|
const int ik3 = iq3;
|
14093
|
-
const int ik2 = iq2;
|
14592
|
+
const int ik2 = iq2 % nek2;
|
14094
14593
|
const int ik1 = ic;
|
14095
14594
|
|
14096
14595
|
// S indices
|
@@ -14103,20 +14602,18 @@ static void ggml_compute_forward_flash_attn_f32(
|
|
14103
14602
|
}
|
14104
14603
|
|
14105
14604
|
// scale
|
14106
|
-
ggml_vec_scale_f32(
|
14605
|
+
ggml_vec_scale_f32(masked_begin, S, scale);
|
14107
14606
|
|
14108
|
-
|
14109
|
-
|
14110
|
-
if (i > P + iq1) {
|
14111
|
-
S[i] = -INFINITY;
|
14112
|
-
}
|
14113
|
-
}
|
14607
|
+
for (int64_t i = masked_begin; i < M; i++) {
|
14608
|
+
S[i] = -INFINITY;
|
14114
14609
|
}
|
14115
14610
|
|
14116
14611
|
// softmax
|
14612
|
+
// exclude known -INF S[..] values from max and loop
|
14613
|
+
// don't forget to set their SW values to zero
|
14117
14614
|
{
|
14118
14615
|
float max = -INFINITY;
|
14119
|
-
ggml_vec_max_f32(
|
14616
|
+
ggml_vec_max_f32(masked_begin, &max, S);
|
14120
14617
|
|
14121
14618
|
ggml_float sum = 0.0;
|
14122
14619
|
{
|
@@ -14130,10 +14627,15 @@ static void ggml_compute_forward_flash_attn_f32(
|
|
14130
14627
|
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
|
14131
14628
|
|
14132
14629
|
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
|
14630
|
+
if (i >= masked_begin) {
|
14631
|
+
break;
|
14632
|
+
}
|
14133
14633
|
float * SS = S + i;
|
14134
14634
|
|
14135
14635
|
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
|
14136
|
-
if (
|
14636
|
+
if (i + j >= masked_begin) {
|
14637
|
+
break;
|
14638
|
+
} else if (SS[j] == -INFINITY) {
|
14137
14639
|
SS[j] = 0.0f;
|
14138
14640
|
} else {
|
14139
14641
|
#ifndef GGML_FLASH_ATTN_EXP_FP16
|
@@ -14158,10 +14660,10 @@ static void ggml_compute_forward_flash_attn_f32(
|
|
14158
14660
|
assert(sum > 0.0);
|
14159
14661
|
|
14160
14662
|
sum = 1.0/sum;
|
14161
|
-
ggml_vec_scale_f32(
|
14663
|
+
ggml_vec_scale_f32(masked_begin, S, sum);
|
14162
14664
|
|
14163
14665
|
#ifndef NDEBUG
|
14164
|
-
for (int i = 0; i <
|
14666
|
+
for (int i = 0; i < masked_begin; ++i) {
|
14165
14667
|
assert(!isnan(S[i]));
|
14166
14668
|
assert(!isinf(S[i]));
|
14167
14669
|
}
|
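
The masked_begin changes above confine the dot products, the scaling, and the softmax to the unmasked prefix of each row; everything past masked_begin is known to be -INF and ends up with zero weight. A compact standalone sketch of that prefix softmax (illustrative, assumes masked_begin >= 1):

#include <math.h>

// Numerically stable softmax over S[0..masked_begin), with the masked tail
// forced to zero so it contributes nothing to the weighted sum over V.
static void softmax_prefix(float * S, int M, int masked_begin) {
    float max = -INFINITY;
    for (int i = 0; i < masked_begin; ++i) {
        if (S[i] > max) max = S[i];
    }
    float sum = 0.0f;
    for (int i = 0; i < masked_begin; ++i) {
        S[i]  = expf(S[i] - max);
        sum  += S[i];
    }
    for (int i = 0; i < masked_begin; ++i) {
        S[i] /= sum;
    }
    for (int i = masked_begin; i < M; ++i) {
        S[i] = 0.0f;
    }
}
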
@@ -14174,9 +14676,13 @@ static void ggml_compute_forward_flash_attn_f32(
|
|
14174
14676
|
const int i2 = iq2;
|
14175
14677
|
const int i3 = iq3;
|
14176
14678
|
|
14177
|
-
|
14178
|
-
|
14179
|
-
|
14679
|
+
// v indices
|
14680
|
+
const int iv2 = iq2 % nev2;
|
14681
|
+
const int iv3 = iq3;
|
14682
|
+
|
14683
|
+
ggml_vec_dot_f32(masked_begin,
|
14684
|
+
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
14685
|
+
(float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
14180
14686
|
S);
|
14181
14687
|
}
|
14182
14688
|
}
|
@@ -14192,14 +14698,14 @@ static void ggml_compute_forward_flash_attn_f16(
|
|
14192
14698
|
int64_t t0 = ggml_perf_time_us();
|
14193
14699
|
UNUSED(t0);
|
14194
14700
|
|
14195
|
-
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
14196
|
-
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
14197
|
-
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
14198
|
-
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
14199
|
-
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
14200
|
-
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
14201
|
-
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
14202
|
-
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
14701
|
+
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
14702
|
+
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
14703
|
+
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
14704
|
+
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
14705
|
+
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
14706
|
+
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
14707
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
14708
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
14203
14709
|
|
14204
14710
|
const int ith = params->ith;
|
14205
14711
|
const int nth = params->nth;
|
@@ -14273,7 +14779,7 @@ static void ggml_compute_forward_flash_attn_f16(
|
|
14273
14779
|
for (int64_t ic = 0; ic < nek1; ++ic) {
|
14274
14780
|
// k indices
|
14275
14781
|
const int ik3 = iq3;
|
14276
|
-
const int ik2 = iq2;
|
14782
|
+
const int ik2 = iq2 % nek2;
|
14277
14783
|
const int ik1 = ic;
|
14278
14784
|
|
14279
14785
|
// S indices
|
@@ -14288,7 +14794,7 @@ static void ggml_compute_forward_flash_attn_f16(
|
|
14288
14794
|
for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
|
14289
14795
|
// k indices
|
14290
14796
|
const int ik3 = iq3;
|
14291
|
-
const int ik2 = iq2;
|
14797
|
+
const int ik2 = iq2 % nek2;
|
14292
14798
|
const int ik1 = ic;
|
14293
14799
|
|
14294
14800
|
// S indices
|
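
The ik2 = iq2 % nek2 indexing above lets several query heads share a single K/V head when neq2 is a multiple of nek2, which is the broadcasting needed for grouped or multi-query attention layouts. A trivial worked example of the mapping (the head counts here are made up):

#include <stdio.h>

int main(void) {
    const int neq2 = 8;   // query heads (hypothetical)
    const int nek2 = 2;   // shared K/V heads (hypothetical)
    for (int iq2 = 0; iq2 < neq2; ++iq2) {
        printf("q head %d -> kv head %d\n", iq2, iq2 % nek2);
    }
    return 0;
}
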
@@ -14313,6 +14819,8 @@ static void ggml_compute_forward_flash_attn_f16(
|
|
14313
14819
|
}
|
14314
14820
|
|
14315
14821
|
// softmax
|
14822
|
+
// todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
|
14823
|
+
// don't forget to set their S values to zero
|
14316
14824
|
{
|
14317
14825
|
float max = -INFINITY;
|
14318
14826
|
ggml_vec_max_f32(M, &max, S);
|
@@ -14369,6 +14877,7 @@ static void ggml_compute_forward_flash_attn_f16(
|
|
14369
14877
|
S16[i] = GGML_FP32_TO_FP16(S[i]);
|
14370
14878
|
}
|
14371
14879
|
|
14880
|
+
// todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
|
14372
14881
|
if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
|
14373
14882
|
for (int64_t ic = 0; ic < nev1; ++ic) {
|
14374
14883
|
// dst indices
|
@@ -14376,9 +14885,13 @@ static void ggml_compute_forward_flash_attn_f16(
|
|
14376
14885
|
const int i2 = iq2;
|
14377
14886
|
const int i3 = iq3;
|
14378
14887
|
|
14379
|
-
|
14380
|
-
|
14381
|
-
|
14888
|
+
// v indices
|
14889
|
+
const int iv2 = iq2 % nev2;
|
14890
|
+
const int iv3 = iq3;
|
14891
|
+
|
14892
|
+
ggml_vec_dot_f16(nev0,
|
14893
|
+
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
14894
|
+
(ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
14382
14895
|
S16);
|
14383
14896
|
}
|
14384
14897
|
} else {
|
@@ -14388,9 +14901,13 @@ static void ggml_compute_forward_flash_attn_f16(
|
|
14388
14901
|
const int i2 = iq2;
|
14389
14902
|
const int i3 = iq3;
|
14390
14903
|
|
14391
|
-
|
14392
|
-
|
14393
|
-
|
14904
|
+
// v indices
|
14905
|
+
const int iv2 = iq2 % nev2;
|
14906
|
+
const int iv3 = iq3;
|
14907
|
+
|
14908
|
+
ggml_vec_dot_f16_unroll(nev0, nbv1,
|
14909
|
+
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
14910
|
+
((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
14394
14911
|
S16);
|
14395
14912
|
}
|
14396
14913
|
}
|
@@ -14433,18 +14950,18 @@ static void ggml_compute_forward_flash_ff_f16(
|
|
14433
14950
|
int64_t t0 = ggml_perf_time_us();
|
14434
14951
|
UNUSED(t0);
|
14435
14952
|
|
14436
|
-
GGML_TENSOR_LOCALS(int64_t, nea, a, ne)
|
14437
|
-
GGML_TENSOR_LOCALS(size_t, nba, a, nb)
|
14438
|
-
GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne)
|
14439
|
-
GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb)
|
14440
|
-
GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne)
|
14441
|
-
GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb)
|
14442
|
-
GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne)
|
14443
|
-
GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb)
|
14444
|
-
GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne)
|
14445
|
-
GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb)
|
14446
|
-
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
14447
|
-
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
14953
|
+
GGML_TENSOR_LOCALS(int64_t, nea, a, ne)
|
14954
|
+
GGML_TENSOR_LOCALS(size_t, nba, a, nb)
|
14955
|
+
GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne)
|
14956
|
+
GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb)
|
14957
|
+
GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne)
|
14958
|
+
GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb)
|
14959
|
+
GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne)
|
14960
|
+
GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb)
|
14961
|
+
GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne)
|
14962
|
+
GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb)
|
14963
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
14964
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
14448
14965
|
|
14449
14966
|
const int ith = params->ith;
|
14450
14967
|
const int nth = params->nth;
|
@@ -14592,16 +15109,16 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|
14592
15109
|
int64_t t0 = ggml_perf_time_us();
|
14593
15110
|
UNUSED(t0);
|
14594
15111
|
|
14595
|
-
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
14596
|
-
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
14597
|
-
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
14598
|
-
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
14599
|
-
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
14600
|
-
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
14601
|
-
GGML_TENSOR_LOCALS(int64_t, ned, d, ne)
|
14602
|
-
GGML_TENSOR_LOCALS(size_t, nbd, d, nb)
|
14603
|
-
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
14604
|
-
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
15112
|
+
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
15113
|
+
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
15114
|
+
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
15115
|
+
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
15116
|
+
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
15117
|
+
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
15118
|
+
GGML_TENSOR_LOCALS(int64_t, ned, d, ne)
|
15119
|
+
GGML_TENSOR_LOCALS(size_t, nbd, d, nb)
|
15120
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
15121
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
14605
15122
|
|
14606
15123
|
const int ith = params->ith;
|
14607
15124
|
const int nth = params->nth;
|
@@ -14649,10 +15166,37 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|
14649
15166
|
return;
|
14650
15167
|
}
|
14651
15168
|
|
14652
|
-
|
15169
|
+
const int64_t elem_q = ggml_nelements(q);
|
15170
|
+
const int64_t elem_k = ggml_nelements(k);
|
14653
15171
|
|
14654
|
-
|
14655
|
-
|
15172
|
+
enum ggml_type result_type = dst->type;
|
15173
|
+
GGML_ASSERT(ggml_blck_size(result_type) == 1);
|
15174
|
+
const size_t tsize = ggml_type_size(result_type);
|
15175
|
+
|
15176
|
+
const size_t offs_q = 0;
|
15177
|
+
const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
|
15178
|
+
const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
|
15179
|
+
|
15180
|
+
void * grad_q = (char *) dst->data;
|
15181
|
+
void * grad_k = (char *) dst->data + offs_k;
|
15182
|
+
void * grad_v = (char *) dst->data + offs_v;
|
15183
|
+
|
15184
|
+
const size_t nbgq1 = nb0*neq0;
|
15185
|
+
const size_t nbgq2 = nb0*neq0*neq1;
|
15186
|
+
const size_t nbgq3 = nb0*neq0*neq1*neq2;
|
15187
|
+
|
15188
|
+
const size_t nbgk1 = nb0*nek0;
|
15189
|
+
const size_t nbgk2 = nb0*nek0*nek1;
|
15190
|
+
const size_t nbgk3 = nb0*nek0*nek1*neq2;
|
15191
|
+
|
15192
|
+
const size_t nbgv1 = nb0*nev0;
|
15193
|
+
const size_t nbgv2 = nb0*nev0*nev1;
|
15194
|
+
const size_t nbgv3 = nb0*nev0*nev1*neq2;
|
15195
|
+
|
15196
|
+
// parallelize by k rows using ggml_vec_dot_f32
|
15197
|
+
|
15198
|
+
// total rows in k
|
15199
|
+
const int nr = nek2*nek3;
|
14656
15200
|
|
14657
15201
|
// rows per thread
|
14658
15202
|
const int dr = (nr + nth - 1)/nth;
|
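
The block above packs grad[q], grad[k] and grad[v] into a single dst buffer, with offs_k and offs_v rounded up to GGML_MEM_ALIGN via GGML_PAD, and derives the nbg* strides from nb0 and the tensor shapes. A small sketch of that offset arithmetic with a local round-up macro standing in for GGML_PAD (the sizes are hypothetical):

#include <stdio.h>
#include <stddef.h>

// Round x up to the next multiple of n, the way GGML_PAD is used above.
#define PAD_TO(x, n) ((((x) + (n) - 1) / (n)) * (n))

int main(void) {
    const size_t align  = 16;            // stand-in for GGML_MEM_ALIGN
    const size_t tsize  = sizeof(float); // the result type is F32 here
    const size_t elem_q = 64 * 32;       // hypothetical ggml_nelements(q)
    const size_t elem_k = 64 * 128;      // hypothetical ggml_nelements(k)

    const size_t offs_q = 0;
    const size_t offs_k = offs_q + PAD_TO(elem_q * tsize, align);
    const size_t offs_v = offs_k + PAD_TO(elem_k * tsize, align);

    printf("offs_q=%zu offs_k=%zu offs_v=%zu\n", offs_q, offs_k, offs_v);
    return 0;
}
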
@@ -14665,268 +15209,243 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|
14665
15209
|
|
14666
15210
|
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
|
14667
15211
|
|
15212
|
+
// how often k2 (and v2) is repeated in q2
|
15213
|
+
int nrep = neq2/nek2;
|
15214
|
+
|
14668
15215
|
for (int ir = ir0; ir < ir1; ++ir) {
|
14669
15216
|
// q indices
|
14670
|
-
const int
|
14671
|
-
const int
|
14672
|
-
for ( int iq1 = 0; iq1 < neq1; ++iq1) {
|
15217
|
+
const int ik3 = ir/(nek2);
|
15218
|
+
const int ik2 = ir - ik3*nek2;
|
14673
15219
|
|
15220
|
+
const int iq3 = ik3;
|
15221
|
+
const int id3 = ik3;
|
15222
|
+
const int iv3 = ik3;
|
15223
|
+
const int iv2 = ik2;
|
14674
15224
|
|
14675
|
-
|
14676
|
-
|
14677
|
-
|
14678
|
-
float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
|
15225
|
+
for (int irep = 0; irep < nrep; ++irep) {
|
15226
|
+
const int iq2 = ik2 + irep*nek2;
|
15227
|
+
const int id2 = iq2;
|
14679
15228
|
|
14680
|
-
|
14681
|
-
|
14682
|
-
|
15229
|
+
// (ik2 + irep*nek2) % nek2 == ik2
|
15230
|
+
for (int iq1 = 0; iq1 < neq1; ++iq1) {
|
15231
|
+
const int id1 = iq1;
|
14683
15232
|
|
14684
|
-
|
14685
|
-
//
|
14686
|
-
|
14687
|
-
|
14688
|
-
const int ik1 = ic;
|
15233
|
+
// not sure about CACHE_LINE_SIZE_F32..
|
15234
|
+
// - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
|
15235
|
+
float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
|
15236
|
+
float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
|
14689
15237
|
|
14690
|
-
|
14691
|
-
|
15238
|
+
for (int i = M; i < Mup; ++i) {
|
15239
|
+
S[i] = -INFINITY;
|
15240
|
+
}
|
14692
15241
|
|
14693
|
-
|
14694
|
-
|
14695
|
-
|
14696
|
-
|
14697
|
-
}
|
15242
|
+
const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
|
15243
|
+
for (int64_t ic = 0; ic < masked_begin; ++ic) {
|
15244
|
+
// k indices
|
15245
|
+
const int ik1 = ic;
|
14698
15246
|
|
14699
|
-
|
14700
|
-
|
15247
|
+
// S indices
|
15248
|
+
const int i1 = ik1;
|
14701
15249
|
|
14702
|
-
|
14703
|
-
|
14704
|
-
|
14705
|
-
|
14706
|
-
}
|
15250
|
+
ggml_vec_dot_f32(neq0,
|
15251
|
+
S + i1,
|
15252
|
+
(float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
15253
|
+
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
14707
15254
|
}
|
14708
|
-
}
|
14709
15255
|
|
14710
|
-
|
14711
|
-
|
14712
|
-
float max = -INFINITY;
|
14713
|
-
ggml_vec_max_f32(M, &max, S);
|
15256
|
+
// scale
|
15257
|
+
ggml_vec_scale_f32(masked_begin, S, scale);
|
14714
15258
|
|
14715
|
-
|
15259
|
+
for (int64_t i = masked_begin; i < M; i++) {
|
15260
|
+
S[i] = -INFINITY;
|
15261
|
+
}
|
15262
|
+
|
15263
|
+
// softmax
|
15264
|
+
// exclude known -INF S[..] values from max and loop
|
15265
|
+
// dont forget to set their SM values to zero
|
14716
15266
|
{
|
15267
|
+
float max = -INFINITY;
|
15268
|
+
ggml_vec_max_f32(masked_begin, &max, S);
|
15269
|
+
|
15270
|
+
ggml_float sum = 0.0;
|
15271
|
+
{
|
14717
15272
|
#ifdef GGML_SOFT_MAX_ACCELERATE
|
14718
|
-
|
14719
|
-
|
14720
|
-
|
14721
|
-
|
15273
|
+
max = -max;
|
15274
|
+
vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
|
15275
|
+
vvexpf(SM, SM, &Mup);
|
15276
|
+
ggml_vec_sum_f32(Mup, &sum, SM);
|
14722
15277
|
#else
|
14723
|
-
|
14724
|
-
|
14725
|
-
|
14726
|
-
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
|
14727
|
-
float * SR = S + i;
|
14728
|
-
float * SW = SM + i;
|
15278
|
+
uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
|
15279
|
+
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
|
14729
15280
|
|
14730
|
-
for (int
|
14731
|
-
if (
|
14732
|
-
|
14733
|
-
}
|
15281
|
+
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
|
15282
|
+
if (i >= masked_begin) {
|
15283
|
+
break;
|
15284
|
+
}
|
15285
|
+
float * SR = S + i;
|
15286
|
+
float * SW = SM + i;
|
15287
|
+
|
15288
|
+
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
|
15289
|
+
if (i + j >= masked_begin) {
|
15290
|
+
break;
|
15291
|
+
} else if (SR[j] == -INFINITY) {
|
15292
|
+
SW[j] = 0.0f;
|
15293
|
+
} else {
|
14734
15294
|
#ifndef GGML_FLASH_ATTN_EXP_FP16
|
14735
|
-
|
15295
|
+
const float val = expf(SR[j] - max);
|
14736
15296
|
#else
|
14737
|
-
|
14738
|
-
|
14739
|
-
|
15297
|
+
ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
|
15298
|
+
memcpy(&scvt[j], &s, sizeof(uint16_t));
|
15299
|
+
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
|
14740
15300
|
#endif
|
14741
|
-
|
14742
|
-
|
15301
|
+
sump[j] += (ggml_float)val;
|
15302
|
+
SW[j] = val;
|
15303
|
+
}
|
14743
15304
|
}
|
14744
15305
|
}
|
14745
|
-
}
|
14746
15306
|
|
14747
|
-
|
14748
|
-
|
14749
|
-
|
15307
|
+
for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
|
15308
|
+
sum += sump[i];
|
15309
|
+
}
|
14750
15310
|
#endif
|
14751
|
-
|
14752
|
-
|
14753
|
-
assert(sum > 0.0);
|
14754
|
-
|
14755
|
-
sum = 1.0/sum;
|
14756
|
-
ggml_vec_scale_f32(M, SM, sum);
|
14757
|
-
|
14758
|
-
}
|
14759
|
-
|
14760
|
-
// step-by-step explanation
|
14761
|
-
{
|
14762
|
-
// forward-process shape grads from backward process
|
14763
|
-
// parallel_for iq2,iq3:
|
14764
|
-
// k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur]
|
14765
|
-
// q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
|
14766
|
-
// v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur]
|
14767
|
-
// for iq1:
|
14768
|
-
// kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
|
14769
|
-
// qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
|
14770
|
-
// vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
|
14771
|
-
// S0 = -Inf [D,1,1,1]
|
14772
|
-
// ~S1[i] = dot(kcur[:D,i], qcur)
|
14773
|
-
// S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
|
14774
|
-
// S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
|
14775
|
-
// S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
14776
|
-
// S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
|
14777
|
-
// ~S5[i] = dot(vcur[:,i], S4)
|
14778
|
-
// S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
|
14779
|
-
// ~dst[i,iq1,iq2,iq3] = S5[i] ^
|
14780
|
-
// dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
|
14781
|
-
// dst backward-/ grad[dst] = d
|
14782
|
-
//
|
14783
|
-
// output gradients with their dependencies:
|
14784
|
-
//
|
14785
|
-
// grad[kcur] = grad[S1].T @ qcur
|
14786
|
-
// grad[S1] = diag_mask_zero(grad[S3], P) * scale
|
14787
|
-
// grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
14788
|
-
// grad[S4] = grad[S5] @ vcur
|
14789
|
-
// grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
|
14790
|
-
// grad[qcur] = grad[S1] @ kcur
|
14791
|
-
// grad[vcur] = grad[S5].T @ S4
|
14792
|
-
// grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
|
14793
|
-
//
|
14794
|
-
// in post-order:
|
14795
|
-
//
|
14796
|
-
// S1 = qcur @ kcur.T
|
14797
|
-
// S2 = S1 * scale
|
14798
|
-
// S3 = diag_mask_inf(S2, P)
|
14799
|
-
// S4 = softmax(S3)
|
14800
|
-
// grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
|
14801
|
-
// grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
14802
|
-
// grad[S1] = diag_mask_zero(grad[S3], P) * scale
|
14803
|
-
// grad[qcur] = grad[S1] @ kcur
|
14804
|
-
// grad[kcur] = grad[S1].T @ qcur
|
14805
|
-
// grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
|
14806
|
-
//
|
14807
|
-
// using less variables (SM=S4):
|
14808
|
-
//
|
14809
|
-
// S = diag_mask_inf(qcur @ kcur.T * scale, P)
|
14810
|
-
// SM = softmax(S)
|
14811
|
-
// S = d[:D,iq1,iq2,iq3] @ vcur
|
14812
|
-
// dot_SM_gradSM = dot(SM, S)
|
14813
|
-
// S = SM * (S - dot(SM, S))
|
14814
|
-
// S = diag_mask_zero(S, P) * scale
|
14815
|
-
//
|
14816
|
-
// grad[q][:D,iq1,iq2,iq3] += S @ kcur
|
14817
|
-
// grad[k][:D,:M,iq2,iq3] += S.T @ qcur
|
14818
|
-
// grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
|
14819
|
-
}
|
14820
|
-
|
14821
|
-
// S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
|
14822
|
-
// S = d[:D,iq1,iq2,iq3] @ vcur
|
14823
|
-
// S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
|
14824
|
-
ggml_vec_set_f32(M, S, 0);
|
14825
|
-
for (int64_t ic = 0; ic < D; ++ic) {
|
14826
|
-
// dst indices
|
14827
|
-
const int i1 = iq1;
|
14828
|
-
const int i2 = iq2;
|
14829
|
-
const int i3 = iq3;
|
15311
|
+
}
|
14830
15312
|
|
14831
|
-
|
14832
|
-
S,
|
14833
|
-
(float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
|
14834
|
-
*(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
|
14835
|
-
}
|
15313
|
+
assert(sum > 0.0);
|
14836
15314
|
|
14837
|
-
|
14838
|
-
|
14839
|
-
ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
|
14840
|
-
ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
|
14841
|
-
ggml_vec_mul_f32 (M, S, S, SM);
|
15315
|
+
sum = 1.0/sum;
|
15316
|
+
ggml_vec_scale_f32(masked_begin, SM, sum);
|
14842
15317
|
|
14843
|
-
// S = diag_mask_zero(S, P) * scale
|
14844
|
-
if (masked) {
|
14845
|
-
// for (int64_t i = P + iq1 + 1; i < M; i++) {
|
14846
|
-
// S[i] = 0;
|
14847
|
-
// }
|
14848
|
-
for (int64_t i = P; i < M; i++) {
|
14849
|
-
if (i > P + iq1) {
|
14850
|
-
S[i] = 0;
|
14851
|
-
}
|
14852
15318
|
}
|
14853
|
-
}
|
14854
|
-
ggml_vec_scale_f32(M, S, scale);
|
14855
|
-
|
14856
|
-
void * grad_q = (char *) dst->data;
|
14857
|
-
void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
|
14858
|
-
void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
|
14859
|
-
|
14860
|
-
const size_t nbgq1 = nb0*neq0;
|
14861
|
-
const size_t nbgq2 = nb0*neq0*neq1;
|
14862
|
-
const size_t nbgq3 = nb0*neq0*neq1*neq2;
|
14863
|
-
|
14864
|
-
const size_t nbgk1 = nb0*nek0;
|
14865
|
-
const size_t nbgk2 = nb0*nek0*nek1;
|
14866
|
-
const size_t nbgk3 = nb0*nek0*nek1*neq2;
|
14867
|
-
|
14868
|
-
const size_t nbgv1 = nb0*nev0;
|
14869
|
-
const size_t nbgv2 = nb0*nev0*nev1;
|
14870
|
-
const size_t nbgv3 = nb0*nev0*nev1*neq2;
|
14871
|
-
|
14872
|
-
// S shape [M,1]
|
14873
|
-
// SM shape [M,1]
|
14874
|
-
// kcur shape [D,M]
|
14875
|
-
// qcur shape [D,1]
|
14876
|
-
// vcur shape [M,D]
|
14877
|
-
//
|
14878
|
-
// grad[q][:D,iq1,iq2,iq3] += S @ kcur
|
14879
|
-
// grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
|
14880
|
-
// grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
|
14881
|
-
//
|
14882
|
-
//// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
|
14883
|
-
//// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
|
14884
|
-
for (int64_t ic = 0; ic < M; ++ic) {
|
14885
|
-
// dst indices
|
14886
|
-
const int i1 = iq1;
|
14887
|
-
const int i2 = iq2;
|
14888
|
-
const int i3 = iq3;
|
14889
15319
|
|
14890
|
-
|
14891
|
-
|
14892
|
-
|
14893
|
-
|
14894
|
-
|
15320
|
+
// step-by-step explanation
|
15321
|
+
{
|
15322
|
+
// forward-process shape grads from backward process
|
15323
|
+
// parallel_for ik2,ik3:
|
15324
|
+
// for irep:
|
15325
|
+
// iq2 = ik2 + irep*nek2
|
15326
|
+
// k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur]
|
15327
|
+
// q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
|
15328
|
+
// v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur]
|
15329
|
+
// for iq1:
|
15330
|
+
// kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
|
15331
|
+
// qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
|
15332
|
+
// vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
|
15333
|
+
// S0 = -Inf [D,1,1,1]
|
15334
|
+
// ~S1[i] = dot(kcur[:D,i], qcur)
|
15335
|
+
// S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
|
15336
|
+
// S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
|
15337
|
+
// S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
15338
|
+
// S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
|
15339
|
+
// ~S5[i] = dot(vcur[:,i], S4)
|
15340
|
+
// S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3]
|
15341
|
+
// ~dst[i,iq1,iq2,iq3] = S5[i] ^
|
15342
|
+
// dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
|
15343
|
+
// dst backward-/ grad[dst] = d
|
15344
|
+
//
|
15345
|
+
// output gradients with their dependencies:
|
15346
|
+
//
|
15347
|
+
// grad[kcur] = grad[S1].T @ qcur
|
15348
|
+
// grad[S1] = diag_mask_zero(grad[S3], P) * scale
|
15349
|
+
// grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
15350
|
+
// grad[S4] = grad[S5] @ vcur
|
15351
|
+
// grad[S4] = d[:D,id1,id2,id3] @ vcur
|
15352
|
+
// grad[qcur] = grad[S1] @ kcur
|
15353
|
+
// grad[vcur] = grad[S5].T @ S4
|
15354
|
+
// grad[vcur] = d[:D,id1,id2,id3].T @ S4
|
15355
|
+
//
|
15356
|
+
// in post-order:
|
15357
|
+
//
|
15358
|
+
// S1 = qcur @ kcur.T
|
15359
|
+
// S2 = S1 * scale
|
15360
|
+
// S3 = diag_mask_inf(S2, P)
|
15361
|
+
// S4 = softmax(S3)
|
15362
|
+
// grad[S4] = d[:D,id1,id2,id3] @ vcur
|
15363
|
+
// grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
15364
|
+
// grad[S1] = diag_mask_zero(grad[S3], P) * scale
|
15365
|
+
// grad[qcur] = grad[S1] @ kcur
|
15366
|
+
// grad[kcur] = grad[S1].T @ qcur
|
15367
|
+
// grad[vcur] = d[:D,id1,id2,id3].T @ S4
|
15368
|
+
//
|
15369
|
+
// using less variables (SM=S4):
|
15370
|
+
//
|
15371
|
+
// S = diag_mask_inf(qcur @ kcur.T * scale, P)
|
15372
|
+
// SM = softmax(S)
|
15373
|
+
// S = d[:D,iq1,iq2,iq3] @ vcur
|
15374
|
+
// dot_SM_gradSM = dot(SM, S)
|
15375
|
+
// S = SM * (S - dot(SM, S))
|
15376
|
+
// S = diag_mask_zero(S, P) * scale
|
15377
|
+
//
|
15378
|
+
// grad[q][:D,iq1,iq2,iq3] += S @ kcur
|
15379
|
+
// grad[k][:D,:M,ik2,ik3] += S.T @ qcur
|
15380
|
+
// grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
|
15381
|
+
}
|
14895
15382
|
|
14896
|
-
|
14897
|
-
|
14898
|
-
|
14899
|
-
|
14900
|
-
//
|
14901
|
-
|
14902
|
-
|
14903
|
-
|
15383
|
+
// S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
|
15384
|
+
// S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
|
15385
|
+
// for ic:
|
15386
|
+
// S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
|
15387
|
+
// exclude known future zero S[..] values from operation
|
15388
|
+
ggml_vec_set_f32(masked_begin, S, 0);
|
15389
|
+
for (int64_t ic = 0; ic < D; ++ic) {
|
15390
|
+
ggml_vec_mad_f32(masked_begin,
|
15391
|
+
S,
|
15392
|
+
(float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
15393
|
+
*(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
|
15394
|
+
}
|
14904
15395
|
|
14905
|
-
//
|
14906
|
-
|
14907
|
-
|
14908
|
-
|
14909
|
-
|
14910
|
-
|
14911
|
-
|
14912
|
-
|
15396
|
+
// S = SM * (S - dot(SM, S))
|
15397
|
+
float dot_SM_gradSM = 0;
|
15398
|
+
ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S);
|
15399
|
+
ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
|
15400
|
+
ggml_vec_mul_f32 (masked_begin, S, S, SM);
|
15401
|
+
|
15402
|
+
// S = diag_mask_zero(S, P) * scale
|
15403
|
+
// already done by above ggml_vec_set_f32
|
15404
|
+
|
15405
|
+
// exclude known zero S[..] values from operation
|
15406
|
+
ggml_vec_scale_f32(masked_begin, S, scale);
|
15407
|
+
|
15408
|
+
// S shape [M,1]
|
15409
|
+
// SM shape [M,1]
|
15410
|
+
// kcur shape [D,M]
|
15411
|
+
// qcur shape [D,1]
|
15412
|
+
// vcur shape [M,D]
|
15413
|
+
|
15414
|
+
// grad[q][:D,iq1,iq2,iq3] += S @ kcur
|
15415
|
+
// grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
|
15416
|
+
// for ic:
|
15417
|
+
// grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
|
15418
|
+
// exclude known zero S[..] values from loop
|
15419
|
+
for (int64_t ic = 0; ic < masked_begin; ++ic) {
|
15420
|
+
ggml_vec_mad_f32(D,
|
15421
|
+
(float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
|
15422
|
+
(float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
15423
|
+
S[ic]);
|
15424
|
+
}
|
14913
15425
|
|
14914
|
-
|
14915
|
-
|
14916
|
-
|
14917
|
-
|
14918
|
-
//
|
14919
|
-
|
14920
|
-
|
14921
|
-
|
15426
|
+
// grad[k][:D,:M,iq2,iq3] += S.T @ qcur
|
15427
|
+
// for ic:
|
15428
|
+
// grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
|
15429
|
+
// grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
|
15430
|
+
// exclude known zero S[..] values from loop
|
15431
|
+
for (int64_t ic = 0; ic < masked_begin; ++ic) {
|
15432
|
+
ggml_vec_mad_f32(D,
|
15433
|
+
(float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
|
15434
|
+
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
|
15435
|
+
S[ic]);
|
15436
|
+
}
|
14922
15437
|
|
14923
|
-
//
|
14924
|
-
//
|
14925
|
-
//
|
14926
|
-
|
14927
|
-
|
14928
|
-
|
14929
|
-
|
15438
|
+
// grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
|
15439
|
+
// for ic:
|
15440
|
+
// grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
|
15441
|
+
// grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
|
15442
|
+
// exclude known zero SM[..] values from mad
|
15443
|
+
for (int64_t ic = 0; ic < D; ++ic) {
|
15444
|
+
ggml_vec_mad_f32(masked_begin,
|
15445
|
+
(float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
|
15446
|
+
SM,
|
15447
|
+
*(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
|
15448
|
+
}
|
14930
15449
|
}
|
14931
15450
|
}
|
14932
15451
|
}
|
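
The rewritten backward loop above leans on the softmax gradient identity spelled out in its comments: with SM = softmax(S) and upstream gradient gradSM, the input gradient is SM * (gradSM - dot(SM, gradSM)), computed in place over the unmasked prefix by ggml_vec_dot_f32, ggml_vec_acc1_f32 and ggml_vec_mul_f32. A standalone restatement of that identity (illustrative only):

// Given y = softmax(x) and dL/dy in dy, write dL/dx into dx:
// dx[i] = y[i] * (dy[i] - sum_j y[j] * dy[j]).
static void softmax_backward(float * dx, const float * y, const float * dy, int n) {
    float dot = 0.0f;
    for (int i = 0; i < n; ++i) {
        dot += y[i] * dy[i];
    }
    for (int i = 0; i < n; ++i) {
        dx[i] = y[i] * (dy[i] - dot);
    }
}
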
@@ -14962,8 +15481,8 @@ static void ggml_compute_forward_win_part_f32(
|
|
14962
15481
|
return;
|
14963
15482
|
}
|
14964
15483
|
|
14965
|
-
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
14966
|
-
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
15484
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
15485
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
14967
15486
|
|
14968
15487
|
const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
|
14969
15488
|
const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
|
@@ -15024,8 +15543,8 @@ static void ggml_compute_forward_win_unpart_f32(
|
|
15024
15543
|
return;
|
15025
15544
|
}
|
15026
15545
|
|
15027
|
-
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
15028
|
-
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
15546
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
15547
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
15029
15548
|
|
15030
15549
|
const int32_t w = ((const int32_t *)(dst->op_params))[0];
|
15031
15550
|
|
@@ -15142,7 +15661,7 @@ static void ggml_compute_forward_get_rel_pos_f16(
|
|
15142
15661
|
|
15143
15662
|
// ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
|
15144
15663
|
|
15145
|
-
GGML_TENSOR_UNARY_OP_LOCALS
|
15664
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
15146
15665
|
|
15147
15666
|
const int64_t w = ne1;
|
15148
15667
|
|
@@ -15840,7 +16359,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
15840
16359
|
} break;
|
15841
16360
|
case GGML_OP_GET_ROWS_BACK:
|
15842
16361
|
{
|
15843
|
-
ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor
|
16362
|
+
ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor);
|
15844
16363
|
} break;
|
15845
16364
|
case GGML_OP_DIAG:
|
15846
16365
|
{
|
@@ -15864,11 +16383,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
15864
16383
|
} break;
|
15865
16384
|
case GGML_OP_ROPE:
|
15866
16385
|
{
|
15867
|
-
ggml_compute_forward_rope(params, tensor->src[0], tensor);
|
16386
|
+
ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor);
|
15868
16387
|
} break;
|
15869
16388
|
case GGML_OP_ROPE_BACK:
|
15870
16389
|
{
|
15871
|
-
ggml_compute_forward_rope_back(params, tensor->src[0], tensor);
|
16390
|
+
ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor);
|
15872
16391
|
} break;
|
15873
16392
|
case GGML_OP_ALIBI:
|
15874
16393
|
{
|
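
The hunk below introduces a small open-addressing hash table (linear probing over tensor pointers) that the graph code uses as a visited set and, via struct hash_map, as a pointer-to-pointer map for gradient checkpointing. A stand-alone miniature of the same probing scheme, with a tiny table size and local names (a full table would spin here; the real hash_find returns GGML_GRAPH_HASHTABLE_SIZE as a sentinel instead):

#include <stdio.h>
#include <stddef.h>

#define TABLE_SIZE 8   // illustrative only; the real table is GGML_GRAPH_HASHTABLE_SIZE

// Returns 1 if p was already present, 0 if it was inserted now.
static int set_insert(void * table[TABLE_SIZE], void * p) {
    size_t i = (size_t) p % TABLE_SIZE;
    while (table[i] != NULL && table[i] != p) {
        i = (i + 1) % TABLE_SIZE;   // linear probing on collision
    }
    if (table[i] == p) {
        return 1;
    }
    table[i] = p;
    return 0;
}

int main(void) {
    void * table[TABLE_SIZE] = { 0 };
    int a, b;
    printf("%d\n", set_insert(table, &a)); // 0: newly inserted
    printf("%d\n", set_insert(table, &b)); // 0: newly inserted
    printf("%d\n", set_insert(table, &a)); // 1: already present
    return 0;
}
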
@@ -16013,7 +16532,218 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
16013
16532
|
|
16014
16533
|
////////////////////////////////////////////////////////////////////////////////
|
16015
16534
|
|
16016
|
-
|
16535
|
+
static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
|
16536
|
+
|
16537
|
+
static size_t hash(void * p) {
|
16538
|
+
return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
|
16539
|
+
}
|
16540
|
+
|
16541
|
+
static size_t hash_find(void * hash_table[], void * p) {
|
16542
|
+
size_t h = hash(p);
|
16543
|
+
|
16544
|
+
// linear probing
|
16545
|
+
size_t i = h;
|
16546
|
+
while (hash_table[i] != NULL && hash_table[i] != p) {
|
16547
|
+
i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
|
16548
|
+
if (i == h) {
|
16549
|
+
// visited all hash table entries -> not found
|
16550
|
+
return GGML_GRAPH_HASHTABLE_SIZE;
|
16551
|
+
}
|
16552
|
+
}
|
16553
|
+
return i;
|
16554
|
+
}
|
16555
|
+
|
16556
|
+
static bool hash_insert(void * hash_table[], void * p) {
|
16557
|
+
size_t i = hash_find(hash_table, p);
|
16558
|
+
|
16559
|
+
GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
16560
|
+
|
16561
|
+
if (hash_table[i] == p) {
|
16562
|
+
return true;
|
16563
|
+
}
|
16564
|
+
|
16565
|
+
// insert
|
16566
|
+
GGML_ASSERT(hash_table[i] == NULL);
|
16567
|
+
hash_table[i] = p;
|
16568
|
+
return false;
|
16569
|
+
}
|
16570
|
+
|
16571
|
+
static bool hash_contains(void * hash_table[], void * p) {
|
16572
|
+
size_t i = hash_find(hash_table, p);
|
16573
|
+
return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
|
16574
|
+
}
|
16575
|
+
|
16576
|
+
struct hash_map {
|
16577
|
+
void * keys[GGML_GRAPH_HASHTABLE_SIZE];
|
16578
|
+
void * vals[GGML_GRAPH_HASHTABLE_SIZE];
|
16579
|
+
};
|
16580
|
+
|
16581
|
+
static struct hash_map * new_hash_map(void) {
|
16582
|
+
struct hash_map * result = malloc(sizeof(struct hash_map));
|
16583
|
+
for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
|
16584
|
+
result->keys[i] = NULL;
|
16585
|
+
result->vals[i] = NULL;
|
16586
|
+
}
|
16587
|
+
return result;
|
16588
|
+
}
|
16589
|
+
|
16590
|
+
static void free_hash_map(struct hash_map * map) {
|
16591
|
+
free(map);
|
16592
|
+
}
|
16593
|
+
|
16594
|
+
// gradient checkpointing
|
16595
|
+
|
16596
|
+
static struct ggml_tensor * ggml_recompute_graph_node(
|
16597
|
+
struct ggml_context * ctx,
|
16598
|
+
struct ggml_cgraph * graph,
|
16599
|
+
struct hash_map * replacements,
|
16600
|
+
struct ggml_tensor * node) {
|
16601
|
+
|
16602
|
+
if (node == NULL) {
|
16603
|
+
return NULL;
|
16604
|
+
}
|
16605
|
+
|
16606
|
+
if (node->is_param) {
|
16607
|
+
return node;
|
16608
|
+
}
|
16609
|
+
|
16610
|
+
if (!hash_contains(graph->visited_hash_table, node)) {
|
16611
|
+
return node;
|
16612
|
+
}
|
16613
|
+
|
16614
|
+
int count_children = 0;
|
16615
|
+
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
16616
|
+
if (node->src[k]) {
|
16617
|
+
++count_children;
|
16618
|
+
}
|
16619
|
+
}
|
16620
|
+
|
16621
|
+
if (count_children == 0) {
|
16622
|
+
return node;
|
16623
|
+
}
|
16624
|
+
|
16625
|
+
size_t i = hash_find(replacements->keys, node);
|
16626
|
+
GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
16627
|
+
if (replacements->keys[i] == node) {
|
16628
|
+
return (struct ggml_tensor *) replacements->vals[i];
|
16629
|
+
}
|
16630
|
+
|
16631
|
+
struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
|
16632
|
+
|
16633
|
+
// insert clone into replacements
|
16634
|
+
GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
|
16635
|
+
replacements->keys[i] = node;
|
16636
|
+
replacements->vals[i] = clone;
|
16637
|
+
|
16638
|
+
clone->op = node->op;
|
16639
|
+
clone->grad = node->grad;
|
16640
|
+
clone->is_param = node->is_param;
|
16641
|
+
clone->extra = node->extra;
|
16642
|
+
for (int k = 0; k < GGML_MAX_DIMS; ++k) {
|
16643
|
+
clone->nb[k] = node->nb[k];
|
16644
|
+
}
|
16645
|
+
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
16646
|
+
clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
|
16647
|
+
}
|
16648
|
+
if (node->view_src != NULL) {
|
16649
|
+
clone->data = (node->view_src->data == NULL)
|
16650
|
+
? NULL // view_src not yet allocated
|
16651
|
+
: (char *) node->view_src->data // view_src already allocated
|
16652
|
+
+ node->view_offs;
|
16653
|
+
clone->view_src = node->view_src;
|
16654
|
+
clone->view_offs = node->view_offs;
|
16655
|
+
}
|
16656
|
+
|
16657
|
+
GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
|
16658
|
+
GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
|
16659
|
+
memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
|
16660
|
+
ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
|
16661
|
+
|
16662
|
+
return clone;
|
16663
|
+
}
|
16664
|
+
|
16665
|
+
void ggml_build_backward_gradient_checkpointing(
|
16666
|
+
struct ggml_context * ctx,
|
16667
|
+
struct ggml_cgraph * gf,
|
16668
|
+
struct ggml_cgraph * gb,
|
16669
|
+
struct ggml_cgraph * gb_tmp,
|
16670
|
+
struct ggml_tensor * * checkpoints,
|
16671
|
+
int n_checkpoints) {
|
16672
|
+
*gb_tmp = *gf;
|
16673
|
+
ggml_build_backward_expand(ctx, gf, gb_tmp, true);
|
16674
|
+
|
16675
|
+
if (n_checkpoints <= 0) {
|
16676
|
+
*gb = *gb_tmp;
|
16677
|
+
return;
|
16678
|
+
}
|
16679
|
+
|
16680
|
+
struct hash_map * replacements = new_hash_map();
|
16681
|
+
|
16682
|
+
// insert checkpoints in replacements
|
16683
|
+
for (int i = 0; i < n_checkpoints; ++i) {
|
16684
|
+
size_t k = hash_find(replacements->keys, checkpoints[i]);
|
16685
|
+
GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
16686
|
+
GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
|
16687
|
+
replacements->keys[k] = checkpoints[i];
|
16688
|
+
replacements->vals[k] = checkpoints[i];
|
16689
|
+
}
|
16690
|
+
|
16691
|
+
*gb = *gf;
|
16692
|
+
// rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
|
16693
|
+
// replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
|
16694
|
+
// by recomputing them from checkpoints
|
16695
|
+
for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
|
16696
|
+
struct ggml_tensor * node = gb_tmp->nodes[i];
|
16697
|
+
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
16698
|
+
// insert new tensors recomputing src, reusing already made replacements,
|
16699
|
+
// remember replacements: remember new tensors with mapping from corresponding gf nodes
|
16700
|
+
// recurse for input tensors,
|
16701
|
+
// unless (i.e. terminating when) input tensors are replacements (like checkpoints)
|
16702
|
+
node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
|
16703
|
+
}
|
16704
|
+
// insert rewritten backward node with replacements made into resulting backward graph gb
|
16705
|
+
ggml_build_forward_expand(gb, node);
|
16706
|
+
}
|
16707
|
+
|
16708
|
+
free_hash_map(replacements);
|
16709
|
+
}
|
16710
|
+
|
16711
|
+
// functions to change gradients considering the case that input a might be initial gradient with zero value
|
16712
|
+
|
16713
|
+
static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
|
16714
|
+
if (hash_contains(zero_table, a)) {
|
16715
|
+
return b;
|
16716
|
+
} else {
|
16717
|
+
return ggml_add_impl(ctx, a, b, false);
|
16718
|
+
}
|
16719
|
+
}
|
16720
|
+
|
16721
|
+
static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, void * zero_table[]) {
|
16722
|
+
if (hash_contains(zero_table, a)) {
|
16723
|
+
struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0));
|
16724
|
+
return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
|
16725
|
+
} else {
|
16726
|
+
return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
|
16727
|
+
}
|
16728
|
+
}
|
16729
|
+
|
16730
|
+
static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
|
16731
|
+
if (hash_contains(zero_table, a)) {
|
16732
|
+
return ggml_repeat(ctx, b, a);
|
16733
|
+
} else {
|
16734
|
+
return ggml_add1_impl(ctx, a, b, false);
|
16735
|
+
}
|
16736
|
+
}
|
16737
|
+
|
16738
|
+
static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
|
16739
|
+
if (hash_contains(zero_table, a)) {
|
16740
|
+
return ggml_neg(ctx, b);
|
16741
|
+
} else {
|
16742
|
+
return ggml_sub_impl(ctx, a, b, false);
|
16743
|
+
}
|
16744
|
+
}
|
16745
|
+
|
16746
|
+
static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, void * zero_table[]) {
|
16017
16747
|
struct ggml_tensor * src0 = tensor->src[0];
|
16018
16748
|
struct ggml_tensor * src1 = tensor->src[1];
|
16019
16749
|
|
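
ggml_build_backward_gradient_checkpointing, added above, rebuilds the backward graph so that intermediate activations are recomputed from a caller-chosen list of checkpoint tensors instead of being kept alive for the whole backward pass. A hedged usage sketch, assuming the declaration shipped in this version's ggml.h; the wrapper, the checkpoint choice and the parameter names are hypothetical, only the checkpointing entry point itself comes from this file:

#include "ggml.h"

// Hypothetical helper: keep one checkpoint per layer (its output tensor) and
// let everything in between be recomputed during the backward pass.
static void build_graphs_with_checkpointing(
        struct ggml_context * ctx,
        struct ggml_cgraph  * gf,        // forward graph, already expanded
        struct ggml_cgraph  * gb,        // receives the final backward graph
        struct ggml_cgraph  * gb_tmp,    // scratch graph used internally
        struct ggml_tensor * * layer_outputs,
        int                   n_layers) {
    ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp, layer_outputs, n_layers);
}
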
@@ -16021,34 +16751,34 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16021
16751
|
case GGML_OP_DUP:
|
16022
16752
|
{
|
16023
16753
|
if (src0->grad) {
|
16024
|
-
src0->grad =
|
16754
|
+
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
16025
16755
|
}
|
16026
16756
|
} break;
|
16027
16757
|
case GGML_OP_ADD:
|
16028
16758
|
{
|
16029
16759
|
if (src0->grad) {
|
16030
|
-
src0->grad =
|
16760
|
+
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
16031
16761
|
}
|
16032
16762
|
if (src1->grad) {
|
16033
|
-
src1->grad =
|
16763
|
+
src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
|
16034
16764
|
}
|
16035
16765
|
} break;
|
16036
16766
|
case GGML_OP_ADD1:
|
16037
16767
|
{
|
16038
16768
|
if (src0->grad) {
|
16039
|
-
src0->grad =
|
16769
|
+
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
16040
16770
|
}
|
16041
16771
|
if (src1->grad) {
|
16042
|
-
src1->grad =
|
16772
|
+
src1->grad = ggml_add_or_set(ctx,
|
16043
16773
|
src1->grad,
|
16044
16774
|
ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
|
16045
|
-
|
16775
|
+
zero_table);
|
16046
16776
|
}
|
16047
16777
|
} break;
|
16048
16778
|
case GGML_OP_ACC:
|
16049
16779
|
{
|
16050
16780
|
if (src0->grad) {
|
16051
|
-
src0->grad =
|
16781
|
+
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
16052
16782
|
}
|
16053
16783
|
if (src1->grad) {
|
16054
16784
|
const size_t nb1 = ((int32_t *) tensor->op_params)[0];
|
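
The surrounding hunks route every gradient accumulation in ggml_compute_backward through the ggml_add_or_set, ggml_sub_or_set and ggml_add1_or_set helpers added earlier: zero_table marks gradients that are still the implicit zero, so the first contribution simply replaces them instead of being added to a zero tensor. The same idiom in a self-contained form (plain C, names invented for the sketch):

#include <string.h>
#include <stddef.h>

// If the accumulator is still the implicit zero, take the contribution as-is;
// otherwise add to it. The flag plays the role of membership in zero_table.
static void add_or_set(float * acc, const float * contrib, size_t n, int * acc_is_zero) {
    if (*acc_is_zero) {
        memcpy(acc, contrib, n * sizeof(float));
        *acc_is_zero = 0;
    } else {
        for (size_t i = 0; i < n; ++i) {
            acc[i] += contrib[i];
        }
    }
}
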
@@ -16065,117 +16795,117 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16065
16795
|
nb1, nb2, nb3, offset);
|
16066
16796
|
|
16067
16797
|
src1->grad =
|
16068
|
-
|
16798
|
+
ggml_add_or_set(ctx,
|
16069
16799
|
src1->grad,
|
16070
16800
|
ggml_reshape(ctx,
|
16071
16801
|
ggml_cont(ctx, tensor_grad_view),
|
16072
16802
|
src1->grad),
|
16073
|
-
|
16803
|
+
zero_table);
|
16074
16804
|
}
|
16075
16805
|
} break;
|
16076
16806
|
case GGML_OP_SUB:
|
16077
16807
|
{
|
16078
16808
|
if (src0->grad) {
|
16079
|
-
src0->grad =
|
16809
|
+
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
16080
16810
|
}
|
16081
16811
|
if (src1->grad) {
|
16082
|
-
src1->grad =
|
16812
|
+
src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table);
|
16083
16813
|
}
|
16084
16814
|
} break;
|
16085
16815
|
case GGML_OP_MUL:
|
16086
16816
|
{
|
16087
16817
|
if (src0->grad) {
|
16088
16818
|
src0->grad =
|
16089
|
-
|
16819
|
+
ggml_add_or_set(ctx,
|
16090
16820
|
src0->grad,
|
16091
16821
|
ggml_mul(ctx, src1, tensor->grad),
|
16092
|
-
|
16822
|
+
zero_table);
|
16093
16823
|
}
|
16094
16824
|
if (src1->grad) {
|
16095
16825
|
src1->grad =
|
16096
|
-
|
16826
|
+
ggml_add_or_set(ctx,
|
16097
16827
|
src1->grad,
|
16098
16828
|
ggml_mul(ctx, src0, tensor->grad),
|
16099
|
-
|
16829
|
+
zero_table);
|
16100
16830
|
}
|
16101
16831
|
} break;
|
16102
16832
|
case GGML_OP_DIV:
|
16103
16833
|
{
|
16104
16834
|
if (src0->grad) {
|
16105
16835
|
src0->grad =
|
16106
|
-
|
16836
|
+
ggml_add_or_set(ctx,
|
16107
16837
|
src0->grad,
|
16108
16838
|
ggml_div(ctx, tensor->grad, src1),
|
16109
|
-
|
16839
|
+
zero_table);
|
16110
16840
|
}
|
16111
16841
|
if (src1->grad) {
|
16112
16842
|
src1->grad =
|
16113
|
-
|
16843
|
+
ggml_sub_or_set(ctx,
|
16114
16844
|
src1->grad,
|
16115
16845
|
ggml_mul(ctx,
|
16116
16846
|
tensor->grad,
|
16117
16847
|
ggml_div(ctx, tensor, src1)),
|
16118
|
-
|
16848
|
+
zero_table);
|
16119
16849
|
}
|
16120
16850
|
} break;
|
16121
16851
|
case GGML_OP_SQR:
|
16122
16852
|
{
|
16123
16853
|
if (src0->grad) {
|
16124
16854
|
src0->grad =
|
16125
|
-
|
16855
|
+
ggml_add_or_set(ctx,
|
16126
16856
|
src0->grad,
|
16127
16857
|
ggml_scale(ctx,
|
16128
16858
|
ggml_mul(ctx, src0, tensor->grad),
|
16129
16859
|
ggml_new_f32(ctx, 2.0f)),
|
16130
|
-
|
16860
|
+
zero_table);
|
16131
16861
|
}
|
16132
16862
|
} break;
|
16133
16863
|
case GGML_OP_SQRT:
|
16134
16864
|
{
|
16135
16865
|
if (src0->grad) {
|
16136
16866
|
src0->grad =
|
16137
|
-
|
16867
|
+
ggml_add_or_set(ctx,
|
16138
16868
|
src0->grad,
|
16139
16869
|
ggml_scale(ctx,
|
16140
16870
|
ggml_div(ctx,
|
16141
16871
|
tensor->grad,
|
16142
16872
|
tensor),
|
16143
16873
|
ggml_new_f32(ctx, 0.5f)),
|
16144
|
-
|
16874
|
+
zero_table);
|
16145
16875
|
}
|
16146
16876
|
} break;
|
16147
16877
|
case GGML_OP_LOG:
|
16148
16878
|
{
|
16149
16879
|
if (src0->grad) {
|
16150
16880
|
src0->grad =
|
16151
|
-
|
16881
|
+
ggml_add_or_set(ctx,
|
16152
16882
|
src0->grad,
|
16153
16883
|
ggml_div(ctx,
|
16154
16884
|
tensor->grad,
|
16155
16885
|
src0),
|
16156
|
-
|
16886
|
+
zero_table);
|
16157
16887
|
}
|
16158
16888
|
} break;
|
16159
16889
|
case GGML_OP_SUM:
|
16160
16890
|
{
|
16161
16891
|
if (src0->grad) {
|
16162
16892
|
src0->grad =
|
16163
|
-
|
16893
|
+
ggml_add1_or_set(ctx,
|
16164
16894
|
src0->grad,
|
16165
16895
|
tensor->grad,
|
16166
|
-
|
16896
|
+
zero_table);
|
16167
16897
|
}
|
16168
16898
|
} break;
|
16169
16899
|
case GGML_OP_SUM_ROWS:
|
16170
16900
|
{
|
16171
16901
|
if (src0->grad) {
|
16172
16902
|
src0->grad =
|
16173
|
-
|
16903
|
+
ggml_add_or_set(ctx,
|
16174
16904
|
src0->grad,
|
16175
16905
|
ggml_repeat(ctx,
|
16176
16906
|
tensor->grad,
|
16177
16907
|
src0->grad),
|
16178
|
-
|
16908
|
+
zero_table);
|
16179
16909
|
}
|
16180
16910
|
} break;
|
16181
16911
|
case GGML_OP_MEAN:
|
@@ -16187,20 +16917,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16187
16917
|
{
|
16188
16918
|
// necessary for llama
|
16189
16919
|
if (src0->grad) {
|
16190
|
-
src0->grad =
|
16920
|
+
src0->grad = ggml_add_or_set(ctx,
|
16191
16921
|
src0->grad,
|
16192
16922
|
ggml_repeat_back(ctx, tensor->grad, src0->grad),
|
16193
|
-
|
16923
|
+
zero_table);
|
16194
16924
|
}
|
16195
16925
|
} break;
|
16196
16926
|
case GGML_OP_REPEAT_BACK:
|
16197
16927
|
{
|
16198
16928
|
if (src0->grad) {
|
16199
16929
|
// TODO: test this
|
16200
|
-
src0->grad =
|
16930
|
+
src0->grad = ggml_add_or_set(ctx,
|
16201
16931
|
src0->grad,
|
16202
16932
|
ggml_repeat(ctx, tensor->grad, src0->grad),
|
16203
|
-
|
16933
|
+
zero_table);
|
16204
16934
|
}
|
16205
16935
|
} break;
|
16206
16936
|
case GGML_OP_CONCAT:
|
@@ -16222,10 +16952,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16222
16952
|
float eps;
|
16223
16953
|
memcpy(&eps, tensor->op_params, sizeof(float));
|
16224
16954
|
|
16225
|
-
src0->grad =
|
16955
|
+
src0->grad = ggml_add_or_set(ctx,
|
16226
16956
|
src0->grad,
|
16227
16957
|
ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
|
16228
|
-
|
16958
|
+
zero_table);
|
16229
16959
|
}
|
16230
16960
|
} break;
|
16231
16961
|
case GGML_OP_RMS_NORM_BACK:
|
@@ -16249,37 +16979,49 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16249
16979
|
// ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
|
16250
16980
|
// ds1 = t.T.dot(dt)
|
16251
16981
|
|
16252
|
-
// tensor.shape [m,p]
|
16253
|
-
// src0.shape [n,m]
|
16254
|
-
// src1.shape [n,p]
|
16982
|
+
// tensor.shape [m,p,qq,rr]
|
16983
|
+
// src0.shape [n,m,q1,r1]
|
16984
|
+
// src1.shape [n,p,qq,rr]
|
16255
16985
|
|
16256
16986
|
// necessary for llama
|
16257
16987
|
if (src0->grad) {
|
16988
|
+
struct ggml_tensor * s1_tg =
|
16989
|
+
ggml_out_prod(ctx, // [n,m,qq,rr]
|
16990
|
+
src1, // [n,p,qq,rr]
|
16991
|
+
tensor->grad); // [m,p,qq,rr]
|
16992
|
+
const int64_t qq = s1_tg->ne[2];
|
16993
|
+
const int64_t rr = s1_tg->ne[3];
|
16994
|
+
const int64_t q1 = src0->ne[2];
|
16995
|
+
const int64_t r1 = src0->ne[3];
|
16996
|
+
const bool ne2_broadcasted = qq > q1;
|
16997
|
+
const bool ne3_broadcasted = rr > r1;
|
16998
|
+
if (ne2_broadcasted || ne3_broadcasted) {
|
16999
|
+
// sum broadcast repetitions of s1_tg into shape of src0
|
17000
|
+
s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
|
17001
|
+
}
|
16258
17002
|
src0->grad =
|
16259
|
-
|
16260
|
-
src0->grad,
|
16261
|
-
|
16262
|
-
|
16263
|
-
tensor->grad), // [m,p]
|
16264
|
-
inplace);
|
17003
|
+
ggml_add_or_set(ctx,
|
17004
|
+
src0->grad, // [n,m,q1,r1]
|
17005
|
+
s1_tg, // [n,m,q1,r1]
|
17006
|
+
zero_table);
|
16265
17007
|
}
|
16266
17008
|
if (src1->grad) {
|
16267
17009
|
src1->grad =
|
16268
|
-
|
16269
|
-
src1->grad,
|
16270
|
-
// ggml_mul_mat(ctx, // [n,p]
|
16271
|
-
// ggml_cont(ctx, // [m,n]
|
16272
|
-
// ggml_transpose(ctx, src0)), // [m,n]
|
16273
|
-
// tensor->grad), // [m,p]
|
17010
|
+
ggml_add_or_set(ctx,
|
17011
|
+
src1->grad, // [n,p,qq,rr]
|
17012
|
+
// ggml_mul_mat(ctx, // [n,p,qq,rr]
|
17013
|
+
// ggml_cont(ctx, // [m,n,q1,r1]
|
17014
|
+
// ggml_transpose(ctx, src0)), // [m,n,q1,r1]
|
17015
|
+
// tensor->grad), // [m,p,qq,rr]
|
16274
17016
|
|
16275
17017
|
// // when src0 is bigger than tensor->grad (this is mostly the case in llama),
|
16276
17018
|
// // avoid transpose of src0, rather transpose smaller tensor->grad
|
16277
17019
|
// // and then use ggml_out_prod
|
16278
|
-
ggml_out_prod(ctx, // [n,p]
|
16279
|
-
src0, // [n,m]
|
16280
|
-
ggml_transpose(ctx, // [p,m]
|
16281
|
-
tensor->grad)), // [m,p]
|
16282
|
-
|
17020
|
+
ggml_out_prod(ctx, // [n,p,qq,rr]
|
17021
|
+
src0, // [n,m,q1,r1]
|
17022
|
+
ggml_transpose(ctx, // [p,m,qq,rr]
|
17023
|
+
tensor->grad)), // [m,p,qq,rr]
|
17024
|
+
zero_table);
|
16283
17025
|
}
|
16284
17026
|
} break;
|
16285
17027
|
case GGML_OP_OUT_PROD:
|
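In the GGML_OP_MUL_MAT branch above, the backward pass is now broadcast-aware: the outer product of src1 and tensor->grad comes out with the broadcasted outer dimensions [n,m,qq,rr], and when qq > q1 or rr > r1 the repetitions are summed back into src0's shape with ggml_repeat_back before being accumulated into src0->grad. A small numeric sketch of that reduction over one broadcast axis, with plain C arrays instead of ggml tensors:

    #include <stdio.h>

    /* collapse a gradient of shape [n, reps] back to shape [n] by summing the
       broadcast repetitions -- the same reduction ggml_repeat_back performs for
       the broadcasted outer dimensions in the hunk above */
    static void sum_broadcast(const float * g_big, float * g_small, int n, int reps) {
        for (int i = 0; i < n; ++i) {
            float acc = 0.0f;
            for (int r = 0; r < reps; ++r) {
                acc += g_big[r*n + i];
            }
            g_small[i] = acc;
        }
    }

    int main(void) {
        const float g_big[6] = { 1, 2, 3,   10, 20, 30 };  /* [n=3, reps=2] */
        float g_small[3];
        sum_broadcast(g_big, g_small, 3, 2);
        printf("%g %g %g\n", g_small[0], g_small[1], g_small[2]); /* 11 22 33 */
        return 0;
    }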
@@ -16291,17 +17033,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16291
17033
|
// necessary for llama
|
16292
17034
|
if (src0->grad) {
|
16293
17035
|
src0->grad =
|
16294
|
-
|
17036
|
+
ggml_add_or_set(ctx,
|
16295
17037
|
src0->grad,
|
16296
17038
|
ggml_scale_impl(ctx, tensor->grad, src1, false),
|
16297
|
-
|
17039
|
+
zero_table);
|
16298
17040
|
}
|
16299
17041
|
if (src1->grad) {
|
16300
17042
|
src1->grad =
|
16301
|
-
|
17043
|
+
ggml_add_or_set(ctx,
|
16302
17044
|
src1->grad,
|
16303
17045
|
ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)),
|
16304
|
-
|
17046
|
+
zero_table);
|
16305
17047
|
}
|
16306
17048
|
} break;
|
16307
17049
|
case GGML_OP_SET:
|
@@ -16328,23 +17070,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16328
17070
|
}
|
16329
17071
|
|
16330
17072
|
if (src0->grad) {
|
16331
|
-
src0->grad =
|
17073
|
+
src0->grad = ggml_add_or_set(ctx,
|
16332
17074
|
src0->grad,
|
16333
17075
|
ggml_acc_impl(ctx,
|
16334
17076
|
tensor->grad,
|
16335
17077
|
ggml_neg(ctx, tensor_grad_view),
|
16336
17078
|
nb1, nb2, nb3, offset, false),
|
16337
|
-
|
17079
|
+
zero_table);
|
16338
17080
|
}
|
16339
17081
|
|
16340
17082
|
if (src1->grad) {
|
16341
17083
|
src1->grad =
|
16342
|
-
|
17084
|
+
ggml_add_or_set(ctx,
|
16343
17085
|
src1->grad,
|
16344
17086
|
ggml_reshape(ctx,
|
16345
17087
|
ggml_cont(ctx, tensor_grad_view),
|
16346
17088
|
src1->grad),
|
16347
|
-
|
17089
|
+
zero_table);
|
16348
17090
|
}
|
16349
17091
|
} break;
|
16350
17092
|
case GGML_OP_CPY:
|
@@ -16355,7 +17097,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16355
17097
|
// tensor = src0 * 1 + src1 * 0
|
16356
17098
|
if (src0->grad) {
|
16357
17099
|
// dsrc0 = dtensor * 1
|
16358
|
-
src0->grad =
|
17100
|
+
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
16359
17101
|
}
|
16360
17102
|
if (src1->grad) {
|
16361
17103
|
// dsrc1 = dtensor * 0 -> noop
|
@@ -16367,7 +17109,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16367
17109
|
if (src0->grad) {
|
16368
17110
|
GGML_ASSERT(ggml_is_contiguous(src0->grad));
|
16369
17111
|
GGML_ASSERT(ggml_is_contiguous(tensor->grad));
|
16370
|
-
src0->grad =
|
17112
|
+
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
16371
17113
|
}
|
16372
17114
|
} break;
|
16373
17115
|
case GGML_OP_RESHAPE:
|
@@ -16375,9 +17117,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16375
17117
|
// necessary for llama
|
16376
17118
|
if (src0->grad) {
|
16377
17119
|
src0->grad =
|
16378
|
-
|
16379
|
-
ggml_reshape(ctx,
|
16380
|
-
|
17120
|
+
ggml_add_or_set(ctx, src0->grad,
|
17121
|
+
ggml_reshape(ctx,
|
17122
|
+
ggml_is_contiguous(tensor->grad)
|
17123
|
+
? tensor->grad
|
17124
|
+
: ggml_cont(ctx, tensor->grad),
|
17125
|
+
src0->grad),
|
17126
|
+
zero_table);
|
16381
17127
|
}
|
16382
17128
|
} break;
|
16383
17129
|
case GGML_OP_VIEW:
|
@@ -16406,7 +17152,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16406
17152
|
nb3 = (nb3 / n0) * ng;
|
16407
17153
|
}
|
16408
17154
|
|
16409
|
-
src0->grad =
|
17155
|
+
src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table);
|
16410
17156
|
}
|
16411
17157
|
} break;
|
16412
17158
|
case GGML_OP_PERMUTE:
|
@@ -16424,14 +17170,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16424
17170
|
axes_backward[axis2] = 2;
|
16425
17171
|
axes_backward[axis3] = 3;
|
16426
17172
|
src0->grad =
|
16427
|
-
|
17173
|
+
ggml_add_or_set(ctx, src0->grad,
|
16428
17174
|
ggml_permute(ctx,
|
16429
17175
|
tensor->grad,
|
16430
17176
|
axes_backward[0],
|
16431
17177
|
axes_backward[1],
|
16432
17178
|
axes_backward[2],
|
16433
17179
|
axes_backward[3]),
|
16434
|
-
|
17180
|
+
zero_table);
|
16435
17181
|
}
|
16436
17182
|
} break;
|
16437
17183
|
case GGML_OP_TRANSPOSE:
|
@@ -16439,9 +17185,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16439
17185
|
// necessary for llama
|
16440
17186
|
if (src0->grad) {
|
16441
17187
|
src0->grad =
|
16442
|
-
|
17188
|
+
ggml_add_or_set(ctx, src0->grad,
|
16443
17189
|
ggml_transpose(ctx, tensor->grad),
|
16444
|
-
|
17190
|
+
zero_table);
|
16445
17191
|
}
|
16446
17192
|
} break;
|
16447
17193
|
case GGML_OP_GET_ROWS:
|
@@ -16449,9 +17195,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16449
17195
|
// necessary for llama (only for tokenizer)
|
16450
17196
|
if (src0->grad) {
|
16451
17197
|
src0->grad =
|
16452
|
-
|
17198
|
+
ggml_add_or_set(ctx, src0->grad,
|
17199
|
+
// last ggml_get_rows_back argument src0->grad is only
|
17200
|
+
// necessary to setup correct output shape
|
16453
17201
|
ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
|
16454
|
-
|
17202
|
+
zero_table);
|
16455
17203
|
}
|
16456
17204
|
if (src1->grad) {
|
16457
17205
|
// noop
|
@@ -16471,9 +17219,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16471
17219
|
if (src0->grad) {
|
16472
17220
|
const int n_past = ((int32_t *) tensor->op_params)[0];
|
16473
17221
|
src0->grad =
|
16474
|
-
|
17222
|
+
ggml_add_or_set(ctx, src0->grad,
|
16475
17223
|
ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
|
16476
|
-
|
17224
|
+
zero_table);
|
16477
17225
|
}
|
16478
17226
|
} break;
|
16479
17227
|
case GGML_OP_DIAG_MASK_ZERO:
|
@@ -16482,9 +17230,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16482
17230
|
if (src0->grad) {
|
16483
17231
|
const int n_past = ((int32_t *) tensor->op_params)[0];
|
16484
17232
|
src0->grad =
|
16485
|
-
|
17233
|
+
ggml_add_or_set(ctx, src0->grad,
|
16486
17234
|
ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
|
16487
|
-
|
17235
|
+
zero_table);
|
16488
17236
|
}
|
16489
17237
|
} break;
|
16490
17238
|
case GGML_OP_SOFT_MAX:
|
@@ -16492,9 +17240,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16492
17240
|
// necessary for llama
|
16493
17241
|
if (src0->grad) {
|
16494
17242
|
src0->grad =
|
16495
|
-
|
17243
|
+
ggml_add_or_set(ctx, src0->grad,
|
16496
17244
|
ggml_soft_max_back(ctx, tensor->grad, tensor),
|
16497
|
-
|
17245
|
+
zero_table);
|
16498
17246
|
}
|
16499
17247
|
|
16500
17248
|
} break;
|
@@ -16506,7 +17254,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16506
17254
|
{
|
16507
17255
|
// necessary for llama
|
16508
17256
|
if (src0->grad) {
|
16509
|
-
const int n_past = ((int32_t *) tensor->op_params)[0];
|
17257
|
+
//const int n_past = ((int32_t *) tensor->op_params)[0];
|
16510
17258
|
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
16511
17259
|
const int mode = ((int32_t *) tensor->op_params)[2];
|
16512
17260
|
const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
@@ -16519,11 +17267,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16519
17267
|
memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
|
16520
17268
|
memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
|
16521
17269
|
|
16522
|
-
src0->grad =
|
17270
|
+
src0->grad = ggml_add_or_set(ctx,
|
16523
17271
|
src0->grad,
|
16524
17272
|
ggml_rope_back(ctx,
|
16525
17273
|
tensor->grad,
|
16526
|
-
|
17274
|
+
src1,
|
16527
17275
|
n_dims,
|
16528
17276
|
mode,
|
16529
17277
|
n_ctx,
|
@@ -16531,13 +17279,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16531
17279
|
freq_scale,
|
16532
17280
|
xpos_base,
|
16533
17281
|
xpos_down),
|
16534
|
-
|
17282
|
+
zero_table);
|
16535
17283
|
}
|
16536
17284
|
} break;
|
16537
17285
|
case GGML_OP_ROPE_BACK:
|
16538
17286
|
{
|
16539
17287
|
if (src0->grad) {
|
16540
|
-
const int n_past = ((int32_t *) tensor->op_params)[0];
|
17288
|
+
//const int n_past = ((int32_t *) tensor->op_params)[0];
|
16541
17289
|
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
16542
17290
|
const int mode = ((int32_t *) tensor->op_params)[2];
|
16543
17291
|
const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
@@ -16550,11 +17298,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16550
17298
|
memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
|
16551
17299
|
memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
|
16552
17300
|
|
16553
|
-
src0->grad =
|
17301
|
+
src0->grad = ggml_add_or_set(ctx,
|
16554
17302
|
src0->grad,
|
16555
17303
|
ggml_rope_impl(ctx,
|
16556
17304
|
tensor->grad,
|
16557
|
-
|
17305
|
+
src1,
|
16558
17306
|
n_dims,
|
16559
17307
|
mode,
|
16560
17308
|
n_ctx,
|
@@ -16563,7 +17311,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16563
17311
|
xpos_base,
|
16564
17312
|
xpos_down,
|
16565
17313
|
false),
|
16566
|
-
|
17314
|
+
zero_table);
|
16567
17315
|
}
|
16568
17316
|
} break;
|
16569
17317
|
case GGML_OP_ALIBI:
|
@@ -16614,145 +17362,42 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16614
17362
|
masked);
|
16615
17363
|
}
|
16616
17364
|
|
16617
|
-
|
16618
|
-
|
16619
|
-
|
16620
|
-
|
16621
|
-
|
16622
|
-
|
16623
|
-
|
16624
|
-
|
16625
|
-
|
16626
|
-
|
16627
|
-
|
16628
|
-
|
16629
|
-
offset);
|
16630
|
-
} break;
|
16631
|
-
case 3:
|
16632
|
-
{
|
16633
|
-
grad_q = ggml_view_3d(ctx,
|
16634
|
-
flash_grad,
|
16635
|
-
src0->ne[0],
|
16636
|
-
src0->ne[1],
|
16637
|
-
src0->ne[2],
|
16638
|
-
nb0*src0->ne[0],
|
16639
|
-
nb0*src0->ne[0]*src0->ne[1],
|
16640
|
-
offset);
|
16641
|
-
} break;
|
16642
|
-
case 4:
|
16643
|
-
{
|
16644
|
-
grad_q = ggml_view_4d(ctx,
|
16645
|
-
flash_grad,
|
16646
|
-
src0->ne[0],
|
16647
|
-
src0->ne[1],
|
16648
|
-
src0->ne[2],
|
16649
|
-
src0->ne[3],
|
16650
|
-
nb0*src0->ne[0],
|
16651
|
-
nb0*src0->ne[0]*src0->ne[1],
|
16652
|
-
nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
|
16653
|
-
offset);
|
16654
|
-
} break;
|
16655
|
-
}
|
17365
|
+
struct ggml_tensor * src2 = tensor->src[2];
|
17366
|
+
const int64_t elem_q = ggml_nelements(src0);
|
17367
|
+
const int64_t elem_k = ggml_nelements(src1);
|
17368
|
+
const int64_t elem_v = ggml_nelements(src2);
|
17369
|
+
|
17370
|
+
enum ggml_type result_type = flash_grad->type;
|
17371
|
+
GGML_ASSERT(ggml_blck_size(result_type) == 1);
|
17372
|
+
const size_t tsize = ggml_type_size(result_type);
|
17373
|
+
|
17374
|
+
const size_t offs_q = 0;
|
17375
|
+
const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
|
17376
|
+
const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
|
16656
17377
|
|
16657
|
-
|
17378
|
+
if (src0->grad) {
|
17379
|
+
struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q);
|
17380
|
+
struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0);
|
17381
|
+
src0->grad = ggml_add_or_set(ctx,
|
16658
17382
|
src0->grad,
|
16659
17383
|
grad_q,
|
16660
|
-
|
17384
|
+
zero_table);
|
16661
17385
|
}
|
16662
|
-
|
16663
17386
|
if (src1->grad) {
|
16664
|
-
struct ggml_tensor *
|
16665
|
-
|
16666
|
-
|
16667
|
-
switch(src1->n_dims) {
|
16668
|
-
case 2:
|
16669
|
-
{
|
16670
|
-
grad_k = ggml_view_2d(ctx,
|
16671
|
-
flash_grad,
|
16672
|
-
src1->ne[0],
|
16673
|
-
src1->ne[1],
|
16674
|
-
nb0*src1->ne[0],
|
16675
|
-
offset);
|
16676
|
-
} break;
|
16677
|
-
case 3:
|
16678
|
-
{
|
16679
|
-
grad_k = ggml_view_3d(ctx,
|
16680
|
-
flash_grad,
|
16681
|
-
src1->ne[0],
|
16682
|
-
src1->ne[1],
|
16683
|
-
src1->ne[2],
|
16684
|
-
nb0*src1->ne[0],
|
16685
|
-
nb0*src1->ne[0]*src1->ne[1],
|
16686
|
-
offset);
|
16687
|
-
} break;
|
16688
|
-
case 4:
|
16689
|
-
{
|
16690
|
-
grad_k = ggml_view_4d(ctx,
|
16691
|
-
flash_grad,
|
16692
|
-
src1->ne[0],
|
16693
|
-
src1->ne[1],
|
16694
|
-
src1->ne[2],
|
16695
|
-
src1->ne[3],
|
16696
|
-
nb0*src1->ne[0],
|
16697
|
-
nb0*src1->ne[0]*src1->ne[1],
|
16698
|
-
nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
|
16699
|
-
offset);
|
16700
|
-
} break;
|
16701
|
-
}
|
16702
|
-
|
16703
|
-
src1->grad = ggml_add_impl(ctx,
|
17387
|
+
struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
|
17388
|
+
struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1);
|
17389
|
+
src1->grad = ggml_add_or_set(ctx,
|
16704
17390
|
src1->grad,
|
16705
17391
|
grad_k,
|
16706
|
-
|
17392
|
+
zero_table);
|
16707
17393
|
}
|
16708
|
-
|
16709
|
-
|
16710
|
-
|
16711
|
-
|
16712
|
-
|
16713
|
-
const size_t nb0 = flash_grad->nb[0];
|
16714
|
-
const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
|
16715
|
-
+ nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
|
16716
|
-
switch(opt0->n_dims) {
|
16717
|
-
case 2:
|
16718
|
-
{
|
16719
|
-
grad_v = ggml_view_2d(ctx,
|
16720
|
-
flash_grad,
|
16721
|
-
opt0->ne[0],
|
16722
|
-
opt0->ne[1],
|
16723
|
-
nb0*opt0->ne[0],
|
16724
|
-
offset);
|
16725
|
-
} break;
|
16726
|
-
case 3:
|
16727
|
-
{
|
16728
|
-
grad_v = ggml_view_3d(ctx,
|
16729
|
-
flash_grad,
|
16730
|
-
opt0->ne[0],
|
16731
|
-
opt0->ne[1],
|
16732
|
-
opt0->ne[2],
|
16733
|
-
nb0*opt0->ne[0],
|
16734
|
-
nb0*opt0->ne[0]*opt0->ne[1],
|
16735
|
-
offset);
|
16736
|
-
} break;
|
16737
|
-
case 4:
|
16738
|
-
{
|
16739
|
-
grad_v = ggml_view_4d(ctx,
|
16740
|
-
flash_grad,
|
16741
|
-
opt0->ne[0],
|
16742
|
-
opt0->ne[1],
|
16743
|
-
opt0->ne[2],
|
16744
|
-
opt0->ne[3],
|
16745
|
-
nb0*opt0->ne[0],
|
16746
|
-
nb0*opt0->ne[0]*opt0->ne[1],
|
16747
|
-
nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
|
16748
|
-
offset);
|
16749
|
-
} break;
|
16750
|
-
}
|
16751
|
-
|
16752
|
-
opt0->grad = ggml_add_impl(ctx,
|
16753
|
-
opt0->grad,
|
17394
|
+
if (src2->grad) {
|
17395
|
+
struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
|
17396
|
+
struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2);
|
17397
|
+
src2->grad = ggml_add_or_set(ctx,
|
17398
|
+
src2->grad,
|
16754
17399
|
grad_v,
|
16755
|
-
|
17400
|
+
zero_table);
|
16756
17401
|
}
|
16757
17402
|
} break;
|
16758
17403
|
case GGML_OP_FLASH_FF:
|
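The GGML_OP_FLASH_ATTN backward case above replaces the old per-rank ggml_view_2d/3d/4d plumbing: flash_grad now packs dq, dk and dv back to back, each segment padded to GGML_MEM_ALIGN, and the gradients are recovered with ggml_view_1d at offs_q/offs_k/offs_v followed by a reshape to the source shapes. A quick sketch of the offset arithmetic, with made-up element counts and a placeholder alignment value:

    #include <stdio.h>

    #define MEM_ALIGN 16   /* placeholder for GGML_MEM_ALIGN */
    #define PAD(x, n) (((x) + (n) - 1) & ~((size_t)(n) - 1))

    int main(void) {
        /* example element counts for q, k and v (made-up values) */
        const size_t elem_q = 4096, elem_k = 4096, elem_v = 2048;
        const size_t tsize  = 4;   /* F32 element size */

        /* the packed flash_grad buffer stores dq, dk and dv back to back,
           each segment padded to the alignment, as in the hunk above */
        const size_t offs_q = 0;
        const size_t offs_k = offs_q + PAD(elem_q * tsize, MEM_ALIGN);
        const size_t offs_v = offs_k + PAD(elem_k * tsize, MEM_ALIGN);

        printf("offs_q=%zu offs_k=%zu offs_v=%zu total>=%zu\n",
               offs_q, offs_k, offs_v, offs_v + elem_v * tsize);
        return 0;
    }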
@@ -16772,12 +17417,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16772
17417
|
{
|
16773
17418
|
if (src0->grad) {
|
16774
17419
|
src0->grad =
|
16775
|
-
|
17420
|
+
ggml_add_or_set(ctx,
|
16776
17421
|
src0->grad,
|
16777
17422
|
ggml_mul(ctx,
|
16778
17423
|
ggml_sgn(ctx, src0),
|
16779
17424
|
tensor->grad),
|
16780
|
-
|
17425
|
+
zero_table);
|
16781
17426
|
}
|
16782
17427
|
} break;
|
16783
17428
|
case GGML_UNARY_OP_SGN:
|
@@ -16789,7 +17434,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16789
17434
|
case GGML_UNARY_OP_NEG:
|
16790
17435
|
{
|
16791
17436
|
if (src0->grad) {
|
16792
|
-
src0->grad =
|
17437
|
+
src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
16793
17438
|
}
|
16794
17439
|
} break;
|
16795
17440
|
case GGML_UNARY_OP_STEP:
|
@@ -16809,12 +17454,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16809
17454
|
case GGML_UNARY_OP_RELU:
|
16810
17455
|
{
|
16811
17456
|
if (src0->grad) {
|
16812
|
-
src0->grad =
|
17457
|
+
src0->grad = ggml_add_or_set(ctx,
|
16813
17458
|
src0->grad,
|
16814
17459
|
ggml_mul(ctx,
|
16815
17460
|
ggml_step(ctx, src0),
|
16816
17461
|
tensor->grad),
|
16817
|
-
|
17462
|
+
zero_table);
|
16818
17463
|
}
|
16819
17464
|
} break;
|
16820
17465
|
case GGML_UNARY_OP_GELU:
|
@@ -16829,10 +17474,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16829
17474
|
{
|
16830
17475
|
// necessary for llama
|
16831
17476
|
if (src0->grad) {
|
16832
|
-
src0->grad =
|
17477
|
+
src0->grad = ggml_add_or_set(ctx,
|
16833
17478
|
src0->grad,
|
16834
17479
|
ggml_silu_back(ctx, src0, tensor->grad),
|
16835
|
-
|
17480
|
+
zero_table);
|
16836
17481
|
}
|
16837
17482
|
} break;
|
16838
17483
|
default:
|
@@ -16855,13 +17500,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16855
17500
|
case GGML_OP_CROSS_ENTROPY_LOSS:
|
16856
17501
|
{
|
16857
17502
|
if (src0->grad) {
|
16858
|
-
src0->grad =
|
17503
|
+
src0->grad = ggml_add_or_set(ctx,
|
16859
17504
|
src0->grad,
|
16860
17505
|
ggml_cross_entropy_loss_back(ctx,
|
16861
17506
|
src0,
|
16862
17507
|
src1,
|
16863
17508
|
tensor->grad),
|
16864
|
-
|
17509
|
+
zero_table);
|
16865
17510
|
}
|
16866
17511
|
} break;
|
16867
17512
|
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
@@ -16877,34 +17522,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
16877
17522
|
GGML_ASSERT(false);
|
16878
17523
|
} break;
|
16879
17524
|
}
|
16880
|
-
}
|
16881
|
-
|
16882
|
-
static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
|
16883
|
-
|
16884
|
-
static size_t hash(void * p) {
|
16885
|
-
return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
|
16886
|
-
}
|
16887
17525
|
|
16888
|
-
|
16889
|
-
|
16890
|
-
|
16891
|
-
// linear probing
|
16892
|
-
size_t i = h;
|
16893
|
-
while (hash_table[i] != NULL && hash_table[i] != p) {
|
16894
|
-
i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
|
16895
|
-
if (i == h) {
|
16896
|
-
// hash table is full
|
16897
|
-
GGML_ASSERT(false);
|
17526
|
+
for (int i = 0; i < GGML_MAX_SRC; ++i) {
|
17527
|
+
if (tensor->src[i] && tensor->src[i]->grad) {
|
17528
|
+
GGML_ASSERT(ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad));
|
16898
17529
|
}
|
16899
17530
|
}
|
16900
|
-
|
16901
|
-
if (hash_table[i] == p) {
|
16902
|
-
return true;
|
16903
|
-
}
|
16904
|
-
|
16905
|
-
// insert
|
16906
|
-
hash_table[i] = p;
|
16907
|
-
return false;
|
16908
17531
|
}
|
16909
17532
|
|
16910
17533
|
static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
|
@@ -16922,8 +17545,12 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
|
|
16922
17545
|
}
|
16923
17546
|
|
16924
17547
|
for (int i = 0; i < GGML_MAX_SRC; ++i) {
|
16925
|
-
|
16926
|
-
|
17548
|
+
const int k =
|
17549
|
+
(cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
|
17550
|
+
(cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
|
17551
|
+
/* unknown order, just fall back to using i*/ i;
|
17552
|
+
if (node->src[k]) {
|
17553
|
+
ggml_visit_parents(cgraph, node->src[k]);
|
16927
17554
|
}
|
16928
17555
|
}
|
16929
17556
|
|
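The ggml_visit_parents change above makes the traversal order of node->src configurable through cgraph->order, so a graph can be expanded left-to-right or right-to-left over its source slots. The index mapping is the only moving part; a tiny sketch with a local enum standing in for the new GGML_CGRAPH_EVAL_ORDER values:

    #include <stdio.h>

    enum eval_order { ORDER_LEFT_TO_RIGHT, ORDER_RIGHT_TO_LEFT };

    #define MAX_SRC 10

    /* map loop index i to the source slot k that should be visited next */
    static int src_index(enum eval_order order, int i) {
        return order == ORDER_RIGHT_TO_LEFT ? MAX_SRC - 1 - i : i;
    }

    int main(void) {
        for (int i = 0; i < 3; ++i) {
            printf("L2R: %d  R2L: %d\n",
                   src_index(ORDER_LEFT_TO_RIGHT, i),
                   src_index(ORDER_RIGHT_TO_LEFT, i));
        }
        return 0;
    }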
@@ -16982,6 +17609,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
|
|
16982
17609
|
/*.grads =*/ { NULL },
|
16983
17610
|
/*.leafs =*/ { NULL },
|
16984
17611
|
/*.hash_table =*/ { NULL },
|
17612
|
+
/*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
|
16985
17613
|
/*.perf_runs =*/ 0,
|
16986
17614
|
/*.perf_cycles =*/ 0,
|
16987
17615
|
/*.perf_time_us =*/ 0,
|
@@ -17007,12 +17635,22 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
|
|
17007
17635
|
}
|
17008
17636
|
}
|
17009
17637
|
|
17638
|
+
// remember original gradients which start with zero values
|
17639
|
+
void ** zero_table = malloc(sizeof(void *) * GGML_GRAPH_HASHTABLE_SIZE);
|
17640
|
+
memset(zero_table, 0, sizeof(void*) * GGML_GRAPH_HASHTABLE_SIZE);
|
17641
|
+
for (int i = 0; i < gf->n_nodes; i++) {
|
17642
|
+
if (gf->grads[i]) {
|
17643
|
+
hash_insert(zero_table, gf->grads[i]);
|
17644
|
+
}
|
17645
|
+
}
|
17646
|
+
|
17010
17647
|
for (int i = gf->n_nodes - 1; i >= 0; i--) {
|
17011
17648
|
struct ggml_tensor * node = gf->nodes[i];
|
17012
17649
|
|
17013
|
-
//
|
17650
|
+
// inplace operations to add gradients are not created by ggml_compute_backward
|
17651
|
+
// use allocator to automatically make inplace operations
|
17014
17652
|
if (node->grad) {
|
17015
|
-
ggml_compute_backward(ctx, node,
|
17653
|
+
ggml_compute_backward(ctx, node, zero_table);
|
17016
17654
|
}
|
17017
17655
|
}
|
17018
17656
|
|
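ggml_build_backward_expand now records every gradient tensor of the forward graph in a temporary zero_table before walking the nodes in reverse; ggml_compute_backward uses it to tell "gradient still zero" apart from "gradient already accumulated", and the allocator later turns the resulting add nodes into in-place updates. A rough sketch of that lifecycle, with a no-op insert standing in for the ggml hash-set helper:

    #include <stdlib.h>
    #include <string.h>

    #define TABLE_SIZE 1024               /* stands in for GGML_GRAPH_HASHTABLE_SIZE */

    struct toy_tensor;                    /* opaque stand-in for struct ggml_tensor */

    /* stand-in for the graph hash-set insert (hash_insert in ggml.c) */
    static void table_insert(void ** table, struct toy_tensor * t) { (void) table; (void) t; }

    static void build_backward(struct toy_tensor ** grads, int n_nodes) {
        /* remember which gradients start out as zero-valued tensors */
        void ** zero_table = malloc(sizeof(void *) * TABLE_SIZE);
        memset(zero_table, 0, sizeof(void *) * TABLE_SIZE);
        for (int i = 0; i < n_nodes; i++) {
            if (grads[i]) {
                table_insert(zero_table, grads[i]);
            }
        }

        /* ... walk the nodes in reverse, passing zero_table to each backward step ... */

        free(zero_table);                 /* the table only lives for this build pass */
    }

    int main(void) {
        build_backward(NULL, 0);          /* no nodes: nothing to record */
        return 0;
    }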
@@ -17024,6 +17662,8 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
|
|
17024
17662
|
ggml_build_forward_expand(gb, node->grad);
|
17025
17663
|
}
|
17026
17664
|
}
|
17665
|
+
|
17666
|
+
free(zero_table);
|
17027
17667
|
}
|
17028
17668
|
|
17029
17669
|
struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
|
@@ -17043,6 +17683,7 @@ struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
|
|
17043
17683
|
/*.grads =*/ { NULL },
|
17044
17684
|
/*.leafs =*/ { NULL },
|
17045
17685
|
/*.hash_table =*/ { NULL },
|
17686
|
+
/*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
|
17046
17687
|
/*.perf_runs =*/ 0,
|
17047
17688
|
/*.perf_cycles =*/ 0,
|
17048
17689
|
/*.perf_time_us =*/ 0,
|
@@ -17433,7 +18074,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
|
17433
18074
|
} break;
|
17434
18075
|
case GGML_OP_CONCAT:
|
17435
18076
|
case GGML_OP_MUL_MAT:
|
17436
|
-
case GGML_OP_OUT_PROD:
|
17437
18077
|
{
|
17438
18078
|
n_tasks = n_threads;
|
17439
18079
|
|
@@ -17475,6 +18115,18 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
|
17475
18115
|
cur = 0;
|
17476
18116
|
}
|
17477
18117
|
|
18118
|
+
work_size = MAX(work_size, cur);
|
18119
|
+
} break;
|
18120
|
+
case GGML_OP_OUT_PROD:
|
18121
|
+
{
|
18122
|
+
n_tasks = n_threads;
|
18123
|
+
|
18124
|
+
size_t cur = 0;
|
18125
|
+
|
18126
|
+
if (ggml_is_quantized(node->src[0]->type)) {
|
18127
|
+
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
|
18128
|
+
}
|
18129
|
+
|
17478
18130
|
work_size = MAX(work_size, cur);
|
17479
18131
|
} break;
|
17480
18132
|
case GGML_OP_SCALE:
|
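GGML_OP_OUT_PROD gets its own branch in ggml_graph_plan above instead of sharing the MUL_MAT one: when src0 is quantized, each task needs an F32 row buffer of src0->ne[0] elements for dequantization. A back-of-the-envelope version of that sizing with placeholder numbers:

    #include <stdio.h>

    int main(void) {
        const size_t f32_size = 4;        /* ggml_type_size(GGML_TYPE_F32) */
        const size_t ne0      = 4096;     /* row length of the quantized src0 (example value) */
        const size_t n_tasks  = 8;        /* one dequantization buffer per thread */

        const size_t cur = f32_size * ne0 * n_tasks;
        printf("out_prod work buffer: %zu bytes\n", cur); /* 131072 */
        return 0;
    }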
@@ -18568,7 +19220,7 @@ static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float *
|
|
18568
19220
|
}
|
18569
19221
|
|
18570
19222
|
static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
|
18571
|
-
|
19223
|
+
int64_t i = 0;
|
18572
19224
|
for (int p = 0; p < np; ++p) {
|
18573
19225
|
const int64_t ne = ggml_nelements(ps[p]) ;
|
18574
19226
|
// TODO: add function to get all elements at once
|
@@ -18578,6 +19230,17 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
|
|
18578
19230
|
}
|
18579
19231
|
}
|
18580
19232
|
|
19233
|
+
static void ggml_opt_acc_grad(int np, struct ggml_tensor * const ps[], float * g, float scale) {
|
19234
|
+
int64_t i = 0;
|
19235
|
+
for (int p = 0; p < np; ++p) {
|
19236
|
+
const int64_t ne = ggml_nelements(ps[p]) ;
|
19237
|
+
// TODO: add function to get all elements at once
|
19238
|
+
for (int64_t j = 0; j < ne; ++j) {
|
19239
|
+
g[i++] += ggml_get_f32_1d(ps[p]->grad, j) * scale;
|
19240
|
+
}
|
19241
|
+
}
|
19242
|
+
}
|
19243
|
+
|
18581
19244
|
//
|
18582
19245
|
// ADAM
|
18583
19246
|
//
|
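ggml_opt_acc_grad is the accumulating counterpart of ggml_opt_get_grad just above it: instead of copying the current gradients into g it adds them, scaled by 1/n_accum, so several forward/backward passes can be averaged before a single optimizer step. A minimal sketch of the arithmetic with plain arrays and made-up gradient values:

    #include <stdio.h>

    int main(void) {
        const int   n_accum    = 4;                 /* params.n_gradient_accumulation */
        const float accum_norm = 1.0f / n_accum;
        const float micro_grad[4] = { 0.8f, 1.2f, 1.0f, 1.0f };  /* per-micro-batch dL/dw */

        float g = 0.0f;                             /* matches ggml_set_zero(opt->adam.g) */
        for (int step = 0; step < n_accum; ++step) {
            g += micro_grad[step] * accum_norm;     /* what ggml_opt_acc_grad accumulates */
        }
        printf("averaged gradient: %g\n", g);       /* 1.0 */
        return 0;
    }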
@@ -18626,26 +19289,43 @@ static enum ggml_opt_result ggml_opt_adam(
|
|
18626
19289
|
const float eps = params.adam.eps;
|
18627
19290
|
const float gclip = params.adam.gclip;
|
18628
19291
|
const int decay_min_ndim = params.adam.decay_min_ndim;
|
19292
|
+
const int n_accum = MAX(1, params.n_gradient_accumulation);
|
19293
|
+
const float accum_norm = 1.0f / (float) n_accum;
|
18629
19294
|
|
19295
|
+
float * g = opt->adam.g->data; // gradients
|
18630
19296
|
float * m = opt->adam.m->data; // first moment
|
18631
19297
|
float * v = opt->adam.v->data; // second moment
|
18632
19298
|
|
18633
19299
|
float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
|
18634
19300
|
|
18635
|
-
if (callback) {
|
18636
|
-
callback(callback_data, &sched);
|
18637
|
-
}
|
18638
|
-
|
18639
|
-
// compute the function value
|
18640
|
-
ggml_graph_reset (gf);
|
18641
|
-
ggml_set_f32 (f->grad, 1.0f);
|
18642
|
-
|
18643
19301
|
struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads);
|
18644
19302
|
struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size);
|
18645
19303
|
cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
|
18646
|
-
ggml_graph_compute(gb, &cplan);
|
18647
19304
|
|
18648
|
-
|
19305
|
+
bool cancel = false;
|
19306
|
+
|
19307
|
+
// compute the function value
|
19308
|
+
float fx = 0;
|
19309
|
+
ggml_set_zero(opt->adam.g);
|
19310
|
+
for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
|
19311
|
+
if (callback) {
|
19312
|
+
callback(callback_data, accum_step, &sched, &cancel);
|
19313
|
+
if (cancel) {
|
19314
|
+
break;
|
19315
|
+
}
|
19316
|
+
}
|
19317
|
+
// ggml_graph_reset (gf);
|
19318
|
+
ggml_set_f32 (f->grad, 1.0f);
|
19319
|
+
ggml_graph_compute(gb, &cplan);
|
19320
|
+
ggml_opt_acc_grad(np, ps, g, accum_norm);
|
19321
|
+
fx += ggml_get_f32_1d(f, 0);
|
19322
|
+
}
|
19323
|
+
if (cancel) {
|
19324
|
+
return GGML_OPT_DID_NOT_CONVERGE;
|
19325
|
+
}
|
19326
|
+
fx *= accum_norm;
|
19327
|
+
|
19328
|
+
opt->adam.fx_prev = fx;
|
18649
19329
|
opt->adam.fx_best = opt->adam.fx_prev;
|
18650
19330
|
if (pf) {
|
18651
19331
|
pf[opt->iter % params.past] = opt->adam.fx_prev;
|
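The Adam changes above also alter the optimizer callback contract: it now runs once per accumulation step and receives a pointer to a cancel flag, and a set flag makes the optimizer bail out (returning GGML_OPT_DID_NOT_CONVERGE from the initial evaluation, or breaking out of the iteration loop). A hedged sketch of a callback with that shape; the typedef and struct below are written out locally for illustration rather than taken from ggml.h:

    #include <stdbool.h>
    #include <stdio.h>

    /* local stand-in for the four-argument optimizer callback used above */
    typedef void (*opt_callback)(void * data, int accum_step, float * sched, bool * cancel);

    struct train_state { int max_steps; int steps_done; };

    static void my_callback(void * data, int accum_step, float * sched, bool * cancel) {
        struct train_state * st = (struct train_state *) data;
        if (accum_step == 0) {
            st->steps_done++;              /* count one optimizer step per first accum pass */
        }
        *sched = 1.0f;                     /* learning-rate schedule multiplier */
        *cancel = st->steps_done > st->max_steps;
    }

    int main(void) {
        struct train_state st = { 2, 0 };
        float sched = 0.0f; bool cancel = false;
        for (int iter = 0; iter < 5 && !cancel; ++iter) {
            for (int accum_step = 0; accum_step < 4 && !cancel; ++accum_step) {
                my_callback(&st, accum_step, &sched, &cancel);
            }
        }
        printf("cancelled after %d steps\n", st.steps_done);
        return 0;
    }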
@@ -18668,6 +19348,9 @@ static enum ggml_opt_result ggml_opt_adam(
|
|
18668
19348
|
|
18669
19349
|
// run the optimizer
|
18670
19350
|
for (int t = 0; t < params.adam.n_iter; ++t) {
|
19351
|
+
if (cancel) {
|
19352
|
+
break;
|
19353
|
+
}
|
18671
19354
|
opt->iter = iter0 + t + 1;
|
18672
19355
|
GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
|
18673
19356
|
|
@@ -18690,12 +19373,8 @@ static enum ggml_opt_result ggml_opt_adam(
|
|
18690
19373
|
if (gclip > 0.0f) {
|
18691
19374
|
// gradient clipping
|
18692
19375
|
ggml_float sum = 0.0;
|
18693
|
-
for (
|
18694
|
-
|
18695
|
-
for (int64_t j = 0; j < ne; ++j) {
|
18696
|
-
float g = ggml_get_f32_1d(ps[p]->grad, j);
|
18697
|
-
sum += (ggml_float)(g*g);
|
18698
|
-
}
|
19376
|
+
for (int64_t i = 0; i < nx; ++i) {
|
19377
|
+
sum += (ggml_float)(g[i]*g[i]);
|
18699
19378
|
}
|
18700
19379
|
ggml_float norm = sqrt(sum);
|
18701
19380
|
if (norm > (ggml_float) gclip) {
|
@@ -18709,10 +19388,10 @@ static enum ggml_opt_result ggml_opt_adam(
|
|
18709
19388
|
const int64_t ne = ggml_nelements(ps[p]);
|
18710
19389
|
const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0f) * sched;
|
18711
19390
|
for (int64_t j = 0; j < ne; ++j) {
|
18712
|
-
float x
|
18713
|
-
float
|
18714
|
-
m[i] = m[i]*beta1 +
|
18715
|
-
v[i] = v[i]*beta2 +
|
19391
|
+
float x = ggml_get_f32_1d(ps[p], j);
|
19392
|
+
float g_ = g[i]*gnorm;
|
19393
|
+
m[i] = m[i]*beta1 + g_*(1.0f - beta1);
|
19394
|
+
v[i] = v[i]*beta2 + g_*g_*(1.0f - beta2);
|
18716
19395
|
float mh = m[i]*beta1h;
|
18717
19396
|
float vh = v[i]*beta2h;
|
18718
19397
|
vh = sqrtf(vh) + eps;
|
@@ -18723,16 +19402,26 @@ static enum ggml_opt_result ggml_opt_adam(
|
|
18723
19402
|
}
|
18724
19403
|
}
|
18725
19404
|
|
18726
|
-
|
18727
|
-
|
19405
|
+
fx = 0;
|
19406
|
+
ggml_set_zero(opt->adam.g);
|
19407
|
+
for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
|
19408
|
+
if (callback) {
|
19409
|
+
callback(callback_data, accum_step, &sched, &cancel);
|
19410
|
+
if (cancel) {
|
19411
|
+
break;
|
19412
|
+
}
|
19413
|
+
}
|
19414
|
+
// ggml_graph_reset (gf);
|
19415
|
+
ggml_set_f32 (f->grad, 1.0f);
|
19416
|
+
ggml_graph_compute(gb, &cplan);
|
19417
|
+
ggml_opt_acc_grad(np, ps, g, accum_norm);
|
19418
|
+
fx += ggml_get_f32_1d(f, 0);
|
18728
19419
|
}
|
19420
|
+
if (cancel) {
|
19421
|
+
break;
|
19422
|
+
}
|
19423
|
+
fx *= accum_norm;
|
18729
19424
|
|
18730
|
-
ggml_graph_reset (gf);
|
18731
|
-
ggml_set_f32 (f->grad, 1.0f);
|
18732
|
-
|
18733
|
-
ggml_graph_compute(gb, &cplan);
|
18734
|
-
|
18735
|
-
const float fx = ggml_get_f32_1d(f, 0);
|
18736
19425
|
opt->loss_after = fx;
|
18737
19426
|
|
18738
19427
|
|
@@ -18812,11 +19501,11 @@ static enum ggml_opt_result linesearch_backtracking(
|
|
18812
19501
|
float * step,
|
18813
19502
|
const float * xp,
|
18814
19503
|
struct ggml_tensor * f,
|
18815
|
-
struct ggml_cgraph * gf,
|
18816
19504
|
struct ggml_cgraph * gb,
|
18817
19505
|
struct ggml_cplan * cplan,
|
18818
19506
|
const int np,
|
18819
19507
|
struct ggml_tensor * ps[],
|
19508
|
+
bool * cancel,
|
18820
19509
|
ggml_opt_callback callback,
|
18821
19510
|
void * callback_data) {
|
18822
19511
|
int count = 0;
|
@@ -18830,6 +19519,9 @@ static enum ggml_opt_result linesearch_backtracking(
|
|
18830
19519
|
const float dec = 0.5f;
|
18831
19520
|
const float inc = 2.1f;
|
18832
19521
|
|
19522
|
+
const int n_accum = MAX(1, params->n_gradient_accumulation);
|
19523
|
+
const float accum_norm = 1.0f / (float) n_accum;
|
19524
|
+
|
18833
19525
|
if (*step <= 0.f) {
|
18834
19526
|
return GGML_LINESEARCH_INVALID_PARAMETERS;
|
18835
19527
|
}
|
@@ -18846,13 +19538,7 @@ static enum ggml_opt_result linesearch_backtracking(
|
|
18846
19538
|
finit = *fx;
|
18847
19539
|
dgtest = params->lbfgs.ftol*dginit;
|
18848
19540
|
|
18849
|
-
while (
|
18850
|
-
if (callback) {
|
18851
|
-
// LBFG-S does not support learning rate -> ignore learning schedule
|
18852
|
-
float sched = 0;
|
18853
|
-
callback(callback_data, &sched);
|
18854
|
-
}
|
18855
|
-
|
19541
|
+
while (!*cancel) {
|
18856
19542
|
ggml_vec_cpy_f32(nx, x, xp);
|
18857
19543
|
ggml_vec_mad_f32(nx, x, d, *step);
|
18858
19544
|
|
@@ -18860,14 +19546,28 @@ static enum ggml_opt_result linesearch_backtracking(
|
|
18860
19546
|
{
|
18861
19547
|
ggml_opt_set_params(np, ps, x);
|
18862
19548
|
|
18863
|
-
|
18864
|
-
|
18865
|
-
|
18866
|
-
|
18867
|
-
|
18868
|
-
|
19549
|
+
*fx = 0;
|
19550
|
+
memset(g, 0, sizeof(float)*nx);
|
19551
|
+
for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
|
19552
|
+
if (callback) {
|
19553
|
+
// LBFG-S does not support learning rate -> ignore learning schedule
|
19554
|
+
float sched = 0;
|
19555
|
+
callback(callback_data, accum_step, &sched, cancel);
|
19556
|
+
if (*cancel) {
|
19557
|
+
break;
|
19558
|
+
}
|
19559
|
+
}
|
19560
|
+
// ggml_graph_reset (gf);
|
19561
|
+
ggml_set_f32 (f->grad, 1.0f);
|
19562
|
+
ggml_graph_compute(gb, cplan);
|
19563
|
+
ggml_opt_acc_grad(np, ps, g, accum_norm);
|
19564
|
+
*fx += ggml_get_f32_1d(f, 0);
|
19565
|
+
}
|
19566
|
+
if (*cancel) {
|
19567
|
+
break;
|
19568
|
+
}
|
19569
|
+
*fx *= accum_norm;
|
18869
19570
|
|
18870
|
-
*fx = ggml_get_f32_1d(f, 0);
|
18871
19571
|
}
|
18872
19572
|
|
18873
19573
|
++count;
|
@@ -18913,7 +19613,7 @@ static enum ggml_opt_result linesearch_backtracking(
|
|
18913
19613
|
(*step) *= width;
|
18914
19614
|
}
|
18915
19615
|
|
18916
|
-
|
19616
|
+
GGML_UNREACHABLE();
|
18917
19617
|
}
|
18918
19618
|
|
18919
19619
|
static enum ggml_opt_result ggml_opt_lbfgs(
|
@@ -18968,6 +19668,9 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|
18968
19668
|
|
18969
19669
|
float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
|
18970
19670
|
|
19671
|
+
const int n_accum = MAX(1, params.n_gradient_accumulation);
|
19672
|
+
const float accum_norm = 1.0f / (float) n_accum;
|
19673
|
+
|
18971
19674
|
float fx = 0.0f; // cost function value
|
18972
19675
|
float xnorm = 0.0f; // ||x||
|
18973
19676
|
float gnorm = 0.0f; // ||g||
|
@@ -18981,24 +19684,33 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|
18981
19684
|
float * lm_s = opt->lbfgs.lms->data;
|
18982
19685
|
float * lm_y = opt->lbfgs.lmy->data;
|
18983
19686
|
|
18984
|
-
|
18985
|
-
// LBFG-S does not support learning rate -> ignore learning schedule
|
18986
|
-
float sched = 0;
|
18987
|
-
callback(callback_data, &sched);
|
18988
|
-
}
|
19687
|
+
bool cancel = false;
|
18989
19688
|
|
18990
19689
|
// evaluate the function value and its gradient
|
18991
19690
|
{
|
18992
19691
|
ggml_opt_set_params(np, ps, x);
|
18993
19692
|
|
18994
|
-
|
18995
|
-
|
18996
|
-
|
18997
|
-
|
18998
|
-
|
18999
|
-
|
19000
|
-
|
19001
|
-
|
19693
|
+
fx = 0;
|
19694
|
+
memset(g, 0, sizeof(float)*nx);
|
19695
|
+
for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
|
19696
|
+
if (callback) {
|
19697
|
+
// LBFG-S does not support learning rate -> ignore learning schedule
|
19698
|
+
float sched = 0;
|
19699
|
+
callback(callback_data, accum_step, &sched, &cancel);
|
19700
|
+
if (cancel) {
|
19701
|
+
break;
|
19702
|
+
}
|
19703
|
+
}
|
19704
|
+
// ggml_graph_reset (gf);
|
19705
|
+
ggml_set_f32 (f->grad, 1.0f);
|
19706
|
+
ggml_graph_compute(gb, &cplan);
|
19707
|
+
ggml_opt_acc_grad(np, ps, g, accum_norm);
|
19708
|
+
fx += ggml_get_f32_1d(f, 0);
|
19709
|
+
}
|
19710
|
+
if (cancel) {
|
19711
|
+
return GGML_OPT_DID_NOT_CONVERGE;
|
19712
|
+
}
|
19713
|
+
fx *= accum_norm;
|
19002
19714
|
|
19003
19715
|
opt->loss_before = fx;
|
19004
19716
|
opt->loss_after = fx;
|
@@ -19056,7 +19768,10 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|
19056
19768
|
ggml_vec_cpy_f32(nx, xp, x);
|
19057
19769
|
ggml_vec_cpy_f32(nx, gp, g);
|
19058
19770
|
|
19059
|
-
ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f,
|
19771
|
+
ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data);
|
19772
|
+
if (!cancel) {
|
19773
|
+
break;
|
19774
|
+
}
|
19060
19775
|
|
19061
19776
|
if (ls < 0) {
|
19062
19777
|
// linesearch failed - go back to the previous point and return
|
@@ -19165,7 +19880,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|
19165
19880
|
step[0] = 1.0;
|
19166
19881
|
}
|
19167
19882
|
|
19168
|
-
|
19883
|
+
GGML_UNREACHABLE();
|
19169
19884
|
}
|
19170
19885
|
|
19171
19886
|
struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
|
@@ -19185,6 +19900,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
|
|
19185
19900
|
.print_forward_graph = true,
|
19186
19901
|
.print_backward_graph = true,
|
19187
19902
|
|
19903
|
+
.n_gradient_accumulation = 1,
|
19904
|
+
|
19188
19905
|
.adam = {
|
19189
19906
|
.n_iter = 10000,
|
19190
19907
|
.sched = 1.000f,
|
@@ -19213,6 +19930,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
|
|
19213
19930
|
.print_forward_graph = true,
|
19214
19931
|
.print_backward_graph = true,
|
19215
19932
|
|
19933
|
+
.n_gradient_accumulation = 1,
|
19934
|
+
|
19216
19935
|
.lbfgs = {
|
19217
19936
|
.m = 6,
|
19218
19937
|
.n_iter = 100,
|
@@ -19243,13 +19962,32 @@ GGML_API void ggml_opt_init(
|
|
19243
19962
|
opt->iter = 0;
|
19244
19963
|
opt->nx = nx;
|
19245
19964
|
opt->just_initialized = true;
|
19965
|
+
if (opt->ctx == NULL) {
|
19966
|
+
struct ggml_init_params ctx_opt_params;
|
19967
|
+
if (opt->params.type == GGML_OPT_ADAM) {
|
19968
|
+
ctx_opt_params.mem_size = GGML_MEM_ALIGN*3 + ggml_tensor_overhead()*3 + ggml_type_size(GGML_TYPE_F32)*nx*3;
|
19969
|
+
if (opt->params.past > 0) {
|
19970
|
+
ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
|
19971
|
+
}
|
19972
|
+
} else if (opt->params.type == GGML_OPT_LBFGS) {
|
19973
|
+
ctx_opt_params.mem_size = GGML_MEM_ALIGN*9 + ggml_tensor_overhead()*9 + ggml_type_size(GGML_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2);
|
19974
|
+
if (opt->params.past > 0) {
|
19975
|
+
ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
|
19976
|
+
}
|
19977
|
+
}
|
19978
|
+
ctx_opt_params.mem_buffer = NULL;
|
19979
|
+
ctx_opt_params.no_alloc = false;
|
19980
|
+
|
19981
|
+
opt->ctx = ggml_init(ctx_opt_params);
|
19982
|
+
}
|
19246
19983
|
switch (opt->params.type) {
|
19247
19984
|
case GGML_OPT_ADAM:
|
19248
19985
|
{
|
19249
|
-
opt->adam.
|
19250
|
-
opt->adam.
|
19986
|
+
opt->adam.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
|
19987
|
+
opt->adam.m = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
|
19988
|
+
opt->adam.v = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
|
19251
19989
|
opt->adam.pf = params.past > 0
|
19252
|
-
? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
|
19990
|
+
? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
|
19253
19991
|
: NULL;
|
19254
19992
|
ggml_set_zero(opt->adam.m);
|
19255
19993
|
ggml_set_zero(opt->adam.v);
|
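ggml_opt_init above now allocates the optimizer state in a dedicated opt->ctx, sized from nx and the optimizer type, instead of placing the tensors in the caller's context. A back-of-the-envelope version of the Adam sizing expression with placeholder values for the alignment and per-tensor overhead constants:

    #include <stdio.h>

    int main(void) {
        /* placeholder values; the real constants come from GGML_MEM_ALIGN,
           ggml_tensor_overhead() and ggml_type_size(GGML_TYPE_F32) */
        const size_t mem_align       = 16;
        const size_t tensor_overhead = 320;
        const size_t f32_size        = 4;
        const size_t nx              = 1000000;  /* number of parameters (example) */
        const size_t past            = 0;        /* size of the past-loss history  */

        /* Adam keeps three nx-sized F32 tensors: g, m and v */
        size_t mem_size = mem_align*3 + tensor_overhead*3 + f32_size*nx*3;
        if (past > 0) {
            mem_size += mem_align + tensor_overhead + f32_size*past;
        }
        printf("opt->ctx needs about %zu bytes for Adam\n", mem_size);
        return 0;
    }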
@@ -19259,18 +19997,18 @@ GGML_API void ggml_opt_init(
|
|
19259
19997
|
} break;
|
19260
19998
|
case GGML_OPT_LBFGS:
|
19261
19999
|
{
|
19262
|
-
opt->lbfgs.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
19263
|
-
opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
19264
|
-
opt->lbfgs.g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
19265
|
-
opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
19266
|
-
opt->lbfgs.d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
20000
|
+
opt->lbfgs.x = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
|
20001
|
+
opt->lbfgs.xp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
|
20002
|
+
opt->lbfgs.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
|
20003
|
+
opt->lbfgs.gp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
|
20004
|
+
opt->lbfgs.d = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
|
19267
20005
|
opt->lbfgs.pf = params.past > 0
|
19268
|
-
? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
|
20006
|
+
? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
|
19269
20007
|
: NULL;
|
19270
|
-
opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
|
19271
|
-
opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
|
19272
|
-
opt->lbfgs.lms = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
|
19273
|
-
opt->lbfgs.lmy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
|
20008
|
+
opt->lbfgs.lmal = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
|
20009
|
+
opt->lbfgs.lmys = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
|
20010
|
+
opt->lbfgs.lms = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
|
20011
|
+
opt->lbfgs.lmy = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
|
19274
20012
|
ggml_set_zero(opt->lbfgs.x);
|
19275
20013
|
ggml_set_zero(opt->lbfgs.xp);
|
19276
20014
|
ggml_set_zero(opt->lbfgs.g);
|
@@ -19876,10 +20614,10 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
|
19876
20614
|
} break;
|
19877
20615
|
case GGUF_TYPE_ARRAY:
|
19878
20616
|
case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break;
|
19879
|
-
}
|
20617
|
+
}
|
19880
20618
|
} break;
|
19881
20619
|
case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type");
|
19882
|
-
}
|
20620
|
+
}
|
19883
20621
|
|
19884
20622
|
if (!ok) {
|
19885
20623
|
break;
|
@@ -20155,78 +20893,94 @@ int gguf_find_key(const struct gguf_context * ctx, const char * key) {
|
|
20155
20893
|
return keyfound;
|
20156
20894
|
}
|
20157
20895
|
|
20158
|
-
const char * gguf_get_key(const struct gguf_context * ctx, int
|
20159
|
-
return ctx->kv[
|
20896
|
+
const char * gguf_get_key(const struct gguf_context * ctx, int key_id) {
|
20897
|
+
return ctx->kv[key_id].key.data;
|
20160
20898
|
}
|
20161
20899
|
|
20162
|
-
enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int
|
20163
|
-
return ctx->kv[
|
20900
|
+
enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) {
|
20901
|
+
return ctx->kv[key_id].type;
|
20164
20902
|
}
|
20165
20903
|
|
20166
|
-
enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int
|
20167
|
-
|
20904
|
+
enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) {
|
20905
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
|
20906
|
+
return ctx->kv[key_id].value.arr.type;
|
20168
20907
|
}
|
20169
20908
|
|
20170
|
-
const void * gguf_get_arr_data(const struct gguf_context * ctx, int
|
20171
|
-
|
20909
|
+
const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) {
|
20910
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
|
20911
|
+
return ctx->kv[key_id].value.arr.data;
|
20172
20912
|
}
|
20173
20913
|
|
20174
20914
|
const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
|
20915
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
|
20175
20916
|
struct gguf_kv * kv = &ctx->kv[key_id];
|
20176
20917
|
struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
|
20177
20918
|
return str->data;
|
20178
20919
|
}
|
20179
20920
|
|
20180
|
-
int gguf_get_arr_n(const struct gguf_context * ctx, int
|
20181
|
-
|
20921
|
+
int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
|
20922
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
|
20923
|
+
return ctx->kv[key_id].value.arr.n;
|
20182
20924
|
}
|
20183
20925
|
|
20184
|
-
uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int
|
20185
|
-
|
20926
|
+
uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) {
|
20927
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8);
|
20928
|
+
return ctx->kv[key_id].value.uint8;
|
20186
20929
|
}
|
20187
20930
|
|
20188
|
-
int8_t gguf_get_val_i8(const struct gguf_context * ctx, int
|
20189
|
-
|
20931
|
+
int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) {
|
20932
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8);
|
20933
|
+
return ctx->kv[key_id].value.int8;
|
20190
20934
|
}
|
20191
20935
|
|
20192
|
-
uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int
|
20193
|
-
|
20936
|
+
uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) {
|
20937
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16);
|
20938
|
+
return ctx->kv[key_id].value.uint16;
|
20194
20939
|
}
|
20195
20940
|
|
20196
|
-
int16_t gguf_get_val_i16(const struct gguf_context * ctx, int
|
20197
|
-
|
20941
|
+
int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) {
|
20942
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16);
|
20943
|
+
return ctx->kv[key_id].value.int16;
|
20198
20944
|
}
|
20199
20945
|
|
20200
|
-
uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int
|
20201
|
-
|
20946
|
+
uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) {
|
20947
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32);
|
20948
|
+
return ctx->kv[key_id].value.uint32;
|
20202
20949
|
}
|
20203
20950
|
|
20204
|
-
int32_t gguf_get_val_i32(const struct gguf_context * ctx, int
|
20205
|
-
|
20951
|
+
int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) {
|
20952
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32);
|
20953
|
+
return ctx->kv[key_id].value.int32;
|
20206
20954
|
}
|
20207
20955
|
|
20208
|
-
float gguf_get_val_f32(const struct gguf_context * ctx, int
|
20209
|
-
|
20956
|
+
float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) {
|
20957
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32);
|
20958
|
+
return ctx->kv[key_id].value.float32;
|
20210
20959
|
}
|
20211
20960
|
|
20212
|
-
uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int
|
20213
|
-
|
20961
|
+
uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) {
|
20962
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64);
|
20963
|
+
return ctx->kv[key_id].value.uint64;
|
20214
20964
|
}
|
20215
20965
|
|
20216
|
-
int64_t gguf_get_val_i64(const struct gguf_context * ctx, int
|
20217
|
-
|
20966
|
+
int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) {
|
20967
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64);
|
20968
|
+
return ctx->kv[key_id].value.int64;
|
20218
20969
|
}
|
20219
20970
|
|
20220
|
-
double gguf_get_val_f64(const struct gguf_context * ctx, int
|
20221
|
-
|
20971
|
+
double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) {
|
20972
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64);
|
20973
|
+
return ctx->kv[key_id].value.float64;
|
20222
20974
|
}
|
20223
20975
|
|
20224
|
-
bool gguf_get_val_bool(const struct gguf_context * ctx, int
|
20225
|
-
|
20976
|
+
bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) {
|
20977
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL);
|
20978
|
+
return ctx->kv[key_id].value.bool_;
|
20226
20979
|
}
|
20227
20980
|
|
20228
|
-
const char * gguf_get_val_str
|
20229
|
-
|
20981
|
+
const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
|
20982
|
+
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
|
20983
|
+
return ctx->kv[key_id].value.str.data;
|
20230
20984
|
}
|
20231
20985
|
|
20232
20986
|
int gguf_get_n_tensors(const struct gguf_context * ctx) {
|
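The gguf accessors above now assert that the stored value type matches the getter (gguf_get_val_u32 requires GGUF_TYPE_UINT32, and so on), so a mismatched read aborts instead of silently reinterpreting the union. Callers can avoid tripping the assert by checking the kv type first; a hedged helper sketch, assuming gguf_find_key returns a negative id for a missing key and using a caller-supplied default:

    #include <stdint.h>
    #include "ggml.h"     /* the gguf API is declared in ggml.h in this version */

    /* Read an optional uint32 key without tripping the new type asserts.
       The key name passed in by the caller is arbitrary; nothing here is
       specific to a particular model file. */
    uint32_t read_u32_or_default(const struct gguf_context * ctx, const char * key, uint32_t dflt) {
        const int key_id = gguf_find_key(ctx, key);
        if (key_id < 0) {
            return dflt;                                       /* key not present */
        }
        if (gguf_get_kv_type(ctx, key_id) != GGUF_TYPE_UINT32) {
            return dflt;                                       /* wrong type: would hit GGML_ASSERT */
        }
        return gguf_get_val_u32(ctx, key_id);
    }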
@@ -20591,10 +21345,10 @@ static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf *
|
|
20591
21345
|
} break;
|
20592
21346
|
case GGUF_TYPE_ARRAY:
|
20593
21347
|
case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break;
|
20594
|
-
}
|
21348
|
+
}
|
20595
21349
|
} break;
|
20596
21350
|
case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type");
|
20597
|
-
}
|
21351
|
+
}
|
20598
21352
|
}
|
20599
21353
|
|
20600
21354
|
// write tensor infos
|