llama_cpp 0.5.2 → 0.6.0

@@ -89,7 +89,9 @@ static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(vo
 
  static int pthread_join(pthread_t thread, void * unused) {
  (void) unused;
- return (int) WaitForSingleObject(thread, INFINITE);
+ int ret = (int) WaitForSingleObject(thread, INFINITE);
+ CloseHandle(thread);
+ return ret;
  }
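The Windows pthread shim previously returned straight from `WaitForSingleObject`, so every joined thread leaked its kernel handle. The new version keeps the wait status and releases the handle. A standalone sketch of the same pattern (the function name here is illustrative, not part of ggml):

```c
#include <windows.h>

// Wait for a thread created with CreateThread, then release its handle.
// Returns the wait status; without CloseHandle the kernel object would
// outlive the thread and accumulate over repeated create/join cycles.
static int join_and_close(HANDLE thread) {
    int ret = (int) WaitForSingleObject(thread, INFINITE);
    CloseHandle(thread);
    return ret;
}
```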
 
  static int sched_yield (void) {
@@ -134,6 +136,7 @@ typedef void * thread_ret_t;
 
  #define GGML_SOFT_MAX_UNROLL 4
  #define GGML_VEC_DOT_UNROLL 2
+ #define GGML_VEC_MAD_UNROLL 32
 
  //
  // logging
@@ -242,18 +245,18 @@ inline static void * ggml_aligned_malloc(size_t size) {
  //
 
  #define GGML_TENSOR_UNARY_OP_LOCALS \
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \
- GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); \
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); \
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
 
  #define GGML_TENSOR_BINARY_OP_LOCALS \
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \
- GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); \
- GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); \
- GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); \
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); \
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+ GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
 
  #if defined(GGML_USE_ACCELERATE)
  #include <Accelerate/Accelerate.h>
@@ -1863,7 +1866,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
  #define GGML_F16x8_ADD vaddq_f16
  #define GGML_F16x8_MUL vmulq_f16
  #define GGML_F16x8_REDUCE(res, x) \
- { \
+ do { \
  int offset = GGML_F16_ARR >> 1; \
  for (int i = 0; i < offset; ++i) { \
  x[i] = vaddq_f16(x[i], x[offset+i]); \
@@ -1879,7 +1882,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
  const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
  const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
  res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \
- }
+ } while (0)
 
  #define GGML_F16_VEC GGML_F16x8
  #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
@@ -1940,7 +1943,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
  #define GGML_F32x8_ADD _mm256_add_ps
  #define GGML_F32x8_MUL _mm256_mul_ps
  #define GGML_F32x8_REDUCE(res, x) \
- { \
+ do { \
  int offset = GGML_F32_ARR >> 1; \
  for (int i = 0; i < offset; ++i) { \
  x[i] = _mm256_add_ps(x[i], x[offset+i]); \
@@ -1957,7 +1960,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
  _mm256_extractf128_ps(x[0], 1)); \
  const __m128 t1 = _mm_hadd_ps(t0, t0); \
  res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \
- }
+ } while (0)
  // TODO: is this optimal ?
 
  #define GGML_F32_VEC GGML_F32x8
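Wrapping the multi-statement `*_REDUCE` macros in `do { ... } while (0)` makes each expansion a single statement, so a trailing semicolon at the call site is safe even inside an unbraced `if`/`else`; the companion change above drops the trailing semicolons from `GGML_TENSOR_*_OP_LOCALS` so those invocations no longer expand to an extra empty statement. A minimal illustration of the pitfall, independent of ggml:

```c
#include <stdio.h>

// Brace-only version: the block plus the caller's ';' is two statements,
// so the 'else' below would be orphaned and the code would not compile.
#define SWAP_BAD(a, b) { int t = (a); (a) = (b); (b) = t; }

// do/while(0) version: expands to exactly one statement.
#define SWAP_OK(a, b) do { int t = (a); (a) = (b); (b) = t; } while (0)

int main(void) {
    int x = 1, y = 2;
    if (x < y)
        SWAP_OK(x, y);        // with SWAP_BAD this line breaks the if/else pairing
    else
        printf("already ordered\n");
    printf("%d %d\n", x, y);
    return 0;
}
```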
@@ -3707,6 +3710,58 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
  #endif
  }
 
+ // xs and vs are byte strides of x and v
+ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
+
+ const float * restrict x[GGML_VEC_MAD_UNROLL];
+ const float * restrict v[GGML_VEC_MAD_UNROLL];
+
+ for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
+ x[i] = (const float *) ((const char *) xv + i*xs);
+ v[i] = (const float *) ((const char *) vv + i*vs);
+ }
+
+ #if defined(GGML_SIMD)
+ const int np = (n & ~(GGML_F32_STEP - 1));
+
+ GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];
+
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ vx[k] = GGML_F32_VEC_SET1(v[k][0]);
+ }
+
+ GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];
+
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
+ ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
+ }
+
+ GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+ }
+ }
+
+ // leftovers
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ for (int i = np; i < n; ++i) {
+ y[i] += x[k][i]*v[k][0];
+ }
+ }
+ #else
+ // scalar
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ for (int i = 0; i < n; ++i) {
+ y[i] += x[k][i]*v[k][0];
+ }
+ }
+ #endif
+ }
+
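`ggml_vec_mad_f32_unroll` accumulates `GGML_VEC_MAD_UNROLL` (32) scaled rows of `x` into a single output row `y` in one pass, which is what the block-tiled `ggml_out_prod` kernels later in this diff rely on. Stripped of the SIMD macros, it is equivalent to this scalar reference (a reading aid, not the library code):

```c
// Scalar reference: y[i] += sum over k of x_k[i] * v_k[0], where x_k and v_k
// are the k-th rows reached via the byte strides xs and vs.
static void vec_mad_unroll_ref(int n, int xs, int vs,
                               float *y, const float *xv, const float *vv) {
    for (int k = 0; k < 32; ++k) {   // GGML_VEC_MAD_UNROLL
        const float *x = (const float *)((const char *)xv + (size_t)k*xs);
        const float  v = *(const float *)((const char *)vv + (size_t)k*vs);
        for (int i = 0; i < n; ++i) {
            y[i] += x[i]*v;
        }
    }
}
```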
  //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
  inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
  #if defined(GGML_USE_ACCELERATE)
@@ -4303,10 +4358,21 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
  }
 
  size_t ggml_nbytes(const struct ggml_tensor * tensor) {
- size_t nbytes = tensor->ne[0]*tensor->nb[0]/ggml_blck_size(tensor->type);
- for (int i = 1; i < GGML_MAX_DIMS; ++i) {
- nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+ size_t nbytes;
+ size_t blck_size = ggml_blck_size(tensor->type);
+ if (blck_size == 1) {
+ nbytes = ggml_type_size(tensor->type);
+ for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+ nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+ }
  }
+ else {
+ nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
+ for (int i = 1; i < GGML_MAX_DIMS; ++i) {
+ nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+ }
+ }
+
  return nbytes;
  }
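For non-quantized types (block size 1) the size now starts from the element size and walks every dimension, which also gives a correct answer for strided views; for quantized types the first dimension is still counted in blocks. A self-contained illustration of the unit-block branch, with made-up shape and stride values:

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

// Standalone sketch of the blck_size == 1 branch (local names, not the ggml
// API): start from one element and add the span covered by each dimension.
static size_t nbytes_unit_block(size_t type_size, const int64_t ne[4], const size_t nb[4]) {
    size_t nbytes = type_size;
    for (int i = 0; i < 4; ++i) {
        nbytes += (ne[i] - 1)*nb[i];
    }
    return nbytes;
}

int main(void) {
    const int64_t ne[4] = { 8, 4, 1, 1 };
    const size_t  nb[4] = { 4, 32, 128, 128 };   // contiguous F32 strides
    // 4 + 7*4 + 3*32 + 0 + 0 = 128 bytes, i.e. exactly 8*4 floats
    assert(nbytes_unit_block(sizeof(float), ne, nb) == 128);
    return 0;
}
```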
 
@@ -4381,10 +4447,9 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
  static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
- return
- (t0->ne[1] == t1->ne[1]) &&
- (t0->ne[2] == t1->ne[2]) &&
- (t0->ne[3] == t1->ne[3]);
+ return (t0->ne[1] == t1->ne[1]) &&
+ (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
+ (t1->ne[3]%t0->ne[3] == 0);
  }
 
  enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
@@ -5054,43 +5119,78 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
  return tensor;
  }
 
+ void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) {
+ const int64_t ne2 = tensor->ne[2];
+ const int64_t ne1 = tensor->ne[1];
+ const int64_t ne0 = tensor->ne[0];
+
+ const int64_t i3_ = (i/(ne2*ne1*ne0));
+ const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0);
+ const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0;
+ const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0);
+
+ if (i0) {
+ * i0 = i0_;
+ }
+ if (i1) {
+ * i1 = i1_;
+ }
+ if (i2) {
+ * i2 = i2_;
+ }
+ if (i3) {
+ * i3 = i3_;
+ }
+ }
+
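`ggml_unravel_index` converts a flat element index into per-dimension coordinates, so the `_1d` getters and setters below can fall back to the new `_nd` accessors when a tensor is not contiguous (the `_nd` variants index through the byte strides `nb[]` instead of assuming a packed layout). A worked example with assumed extents:

```c
#include <assert.h>
#include <stdint.h>

// Row-major unravel of a flat index, same arithmetic as ggml_unravel_index;
// the extents here are made up for the example.
int main(void) {
    const int64_t ne0 = 5, ne1 = 4, ne2 = 3;
    const int64_t i   = 47;

    const int64_t i3 = i/(ne2*ne1*ne0);                            // 47/60 = 0
    const int64_t i2 = (i - i3*ne2*ne1*ne0)/(ne1*ne0);             // 47/20 = 2
    const int64_t i1 = (i - i3*ne2*ne1*ne0 - i2*ne1*ne0)/ne0;      //  7/5  = 1
    const int64_t i0 = (i - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); //  7-5  = 2

    assert(i0 == 2 && i1 == 1 && i2 == 2 && i3 == 0);
    return 0;
}
```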
5057
5146
  int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
5147
+ if (!ggml_is_contiguous(tensor)) {
5148
+ int64_t id[4] = { 0, 0, 0, 0 };
5149
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
5150
+ return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]);
5151
+ }
5058
5152
  switch (tensor->type) {
5059
5153
  case GGML_TYPE_I8:
5060
5154
  {
5061
5155
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
5062
5156
  return ((int8_t *)(tensor->data))[i];
5063
- } break;
5157
+ }
5064
5158
  case GGML_TYPE_I16:
5065
5159
  {
5066
5160
  GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
5067
5161
  return ((int16_t *)(tensor->data))[i];
5068
- } break;
5162
+ }
5069
5163
  case GGML_TYPE_I32:
5070
5164
  {
5071
5165
  GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
5072
5166
  return ((int32_t *)(tensor->data))[i];
5073
- } break;
5167
+ }
5074
5168
  case GGML_TYPE_F16:
5075
5169
  {
5076
5170
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
5077
5171
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
5078
- } break;
5172
+ }
5079
5173
  case GGML_TYPE_F32:
5080
5174
  {
5081
5175
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
5082
5176
  return ((float *)(tensor->data))[i];
5083
- } break;
5177
+ }
5084
5178
  default:
5085
5179
  {
5086
5180
  GGML_ASSERT(false);
5087
- } break;
5181
+ }
5088
5182
  }
5089
5183
 
5090
5184
  return 0.0f;
5091
5185
  }
5092
5186
 
5093
5187
  void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
5188
+ if (!ggml_is_contiguous(tensor)) {
5189
+ int64_t id[4] = { 0, 0, 0, 0 };
5190
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
5191
+ ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value);
5192
+ return;
5193
+ }
5094
5194
  switch (tensor->type) {
5095
5195
  case GGML_TYPE_I8:
5096
5196
  {
@@ -5124,43 +5224,104 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
5124
5224
  }
5125
5225
  }
5126
5226
 
5227
+ int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
5228
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
5229
+ switch (tensor->type) {
5230
+ case GGML_TYPE_I8:
5231
+ return ((int8_t *) data)[0];
5232
+ case GGML_TYPE_I16:
5233
+ return ((int16_t *) data)[0];
5234
+ case GGML_TYPE_I32:
5235
+ return ((int32_t *) data)[0];
5236
+ case GGML_TYPE_F16:
5237
+ return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
5238
+ case GGML_TYPE_F32:
5239
+ return ((float *) data)[0];
5240
+ default:
5241
+ GGML_ASSERT(false);
5242
+ }
5243
+
5244
+ return 0.0f;
5245
+ }
5246
+
5247
+ void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) {
5248
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
5249
+ switch (tensor->type) {
5250
+ case GGML_TYPE_I8:
5251
+ {
5252
+ ((int8_t *)(data))[0] = value;
5253
+ } break;
5254
+ case GGML_TYPE_I16:
5255
+ {
5256
+ ((int16_t *)(data))[0] = value;
5257
+ } break;
5258
+ case GGML_TYPE_I32:
5259
+ {
5260
+ ((int32_t *)(data))[0] = value;
5261
+ } break;
5262
+ case GGML_TYPE_F16:
5263
+ {
5264
+ ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
5265
+ } break;
5266
+ case GGML_TYPE_F32:
5267
+ {
5268
+ ((float *)(data))[0] = value;
5269
+ } break;
5270
+ default:
5271
+ {
5272
+ GGML_ASSERT(false);
5273
+ } break;
5274
+ }
5275
+ }
5276
+
5127
5277
  float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
5278
+ if (!ggml_is_contiguous(tensor)) {
5279
+ int64_t id[4] = { 0, 0, 0, 0 };
5280
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
5281
+ return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]);
5282
+ }
5128
5283
  switch (tensor->type) {
5129
5284
  case GGML_TYPE_I8:
5130
5285
  {
5131
5286
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
5132
5287
  return ((int8_t *)(tensor->data))[i];
5133
- } break;
5288
+ }
5134
5289
  case GGML_TYPE_I16:
5135
5290
  {
5136
5291
  GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
5137
5292
  return ((int16_t *)(tensor->data))[i];
5138
- } break;
5293
+ }
5139
5294
  case GGML_TYPE_I32:
5140
5295
  {
5141
5296
  GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
5142
5297
  return ((int32_t *)(tensor->data))[i];
5143
- } break;
5298
+ }
5144
5299
  case GGML_TYPE_F16:
5145
5300
  {
5146
5301
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
5147
5302
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
5148
- } break;
5303
+ }
5149
5304
  case GGML_TYPE_F32:
5150
5305
  {
5151
5306
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
5152
5307
  return ((float *)(tensor->data))[i];
5153
- } break;
5308
+ }
5154
5309
  default:
5155
5310
  {
5156
5311
  GGML_ASSERT(false);
5157
- } break;
5312
+ }
5158
5313
  }
5159
5314
 
5160
5315
  return 0.0f;
5161
5316
  }
5162
5317
 
5163
5318
  void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
5319
+ if (!ggml_is_contiguous(tensor)) {
5320
+ int64_t id[4] = { 0, 0, 0, 0 };
5321
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
5322
+ ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value);
5323
+ return;
5324
+ }
5164
5325
  switch (tensor->type) {
5165
5326
  case GGML_TYPE_I8:
5166
5327
  {
@@ -5194,6 +5355,56 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
5194
5355
  }
5195
5356
  }
5196
5357
 
5358
+ float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
5359
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
5360
+ switch (tensor->type) {
5361
+ case GGML_TYPE_I8:
5362
+ return ((int8_t *) data)[0];
5363
+ case GGML_TYPE_I16:
5364
+ return ((int16_t *) data)[0];
5365
+ case GGML_TYPE_I32:
5366
+ return ((int32_t *) data)[0];
5367
+ case GGML_TYPE_F16:
5368
+ return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
5369
+ case GGML_TYPE_F32:
5370
+ return ((float *) data)[0];
5371
+ default:
5372
+ GGML_ASSERT(false);
5373
+ }
5374
+
5375
+ return 0.0f;
5376
+ }
5377
+
5378
+ void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) {
5379
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
5380
+ switch (tensor->type) {
5381
+ case GGML_TYPE_I8:
5382
+ {
5383
+ ((int8_t *)(data))[0] = value;
5384
+ } break;
5385
+ case GGML_TYPE_I16:
5386
+ {
5387
+ ((int16_t *)(data))[0] = value;
5388
+ } break;
5389
+ case GGML_TYPE_I32:
5390
+ {
5391
+ ((int32_t *)(data))[0] = value;
5392
+ } break;
5393
+ case GGML_TYPE_F16:
5394
+ {
5395
+ ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
5396
+ } break;
5397
+ case GGML_TYPE_F32:
5398
+ {
5399
+ ((float *)(data))[0] = value;
5400
+ } break;
5401
+ default:
5402
+ {
5403
+ GGML_ASSERT(false);
5404
+ } break;
5405
+ }
5406
+ }
5407
+
5197
5408
  void * ggml_get_data(const struct ggml_tensor * tensor) {
5198
5409
  return tensor->data;
5199
5410
  }
@@ -5336,6 +5547,44 @@ struct ggml_tensor * ggml_add_inplace(
  return ggml_add_impl(ctx, a, b, true);
  }
 
+ // ggml_add_cast
+
+ static struct ggml_tensor * ggml_add_cast_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ enum ggml_type type) {
+ // TODO: support less-strict constraint
+ // GGML_ASSERT(ggml_can_repeat(b, a));
+ GGML_ASSERT(ggml_can_repeat_rows(b, a));
+ GGML_ASSERT(ggml_is_quantized(a->type)); // currently only supported for quantized input
+
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ // TODO: support backward pass for broadcasting
+ GGML_ASSERT(ggml_are_same_shape(a, b));
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor(ctx, type, a->n_dims, a->ne);
+
+ result->op = GGML_OP_ADD;
+ result->grad = is_node ? ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+ }
+
+ struct ggml_tensor * ggml_add_cast(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ enum ggml_type type) {
+ return ggml_add_cast_impl(ctx, a, b, type);
+ }
+
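`ggml_add_cast` adds `b` (repeated over rows where allowed) to a quantized `a` and stores the sum in a caller-chosen type instead of re-quantizing it to `a`'s type; this pairs with the `ggml_compute_forward_add_q_f32` change further down, which now picks `quantize_row_q` from the destination type. A minimal usage sketch, assuming a prepared context `ctx` and illustrative dimensions that respect the quantization block size:

```c
// Hypothetical shapes; n_embd must be a multiple of the Q4_0 block size (32).
struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, n_embd, n_out); // quantized weights
struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  n_embd, n_out); // float update
struct ggml_tensor * c = ggml_add_cast(ctx, a, b, GGML_TYPE_F32);                // sum kept in F32
```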
  // ggml_add1
 
  static struct ggml_tensor * ggml_add1_impl(
@@ -5772,7 +6021,6 @@ struct ggml_tensor * ggml_repeat(
5772
6021
  result->op = GGML_OP_REPEAT;
5773
6022
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5774
6023
  result->src[0] = a;
5775
- result->src[1] = b;
5776
6024
 
5777
6025
  return result;
5778
6026
  }
@@ -5800,7 +6048,6 @@ struct ggml_tensor * ggml_repeat_back(
5800
6048
  result->op = GGML_OP_REPEAT_BACK;
5801
6049
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5802
6050
  result->src[0] = a;
5803
- result->src[1] = b;
5804
6051
 
5805
6052
  return result;
5806
6053
  }
@@ -6175,8 +6422,9 @@ struct ggml_tensor * ggml_out_prod(
  is_node = true;
  }
 
- const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+ // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]
+ const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
 
  result->op = GGML_OP_OUT_PROD;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6395,6 +6643,54 @@ struct ggml_tensor * ggml_cont_inplace(
  return ggml_cont_impl(ctx, a, true);
  }
 
+
+ // make contiguous, with new shape
+ GGML_API struct ggml_tensor * ggml_cont_1d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0) {
+ return ggml_cont_4d(ctx, a, ne0, 1, 1, 1);
+ }
+
+ GGML_API struct ggml_tensor * ggml_cont_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1) {
+ return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1);
+ }
+
+ GGML_API struct ggml_tensor * ggml_cont_3d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2) {
+ return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1);
+ }
+
+ struct ggml_tensor * ggml_cont_4d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ int64_t ne3) {
+ GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3));
+
+ bool is_node = false;
+
+ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
+ ggml_format_name(result, "%s (cont)", a->name);
+
+ result->op = GGML_OP_CONT;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+ }
+
+
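The `ggml_cont_1d/2d/3d/4d` helpers make a tensor contiguous and give it a new shape in a single `GGML_OP_CONT` node, instead of chaining `ggml_cont` with a reshape. A hedged usage sketch (names and shapes are illustrative, only the element count must match):

```c
// Hypothetical usage: turn a permuted, non-contiguous view into a packed
// 2-D tensor in one op.
struct ggml_tensor * view = ggml_permute(ctx, x, 0, 2, 1, 3);
struct ggml_tensor * flat = ggml_cont_2d(ctx, view,
                                         view->ne[0],
                                         ggml_nelements(view)/view->ne[0]);
```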
6398
6694
  // ggml_reshape
6399
6695
 
6400
6696
  struct ggml_tensor * ggml_reshape(
@@ -6402,7 +6698,7 @@ struct ggml_tensor * ggml_reshape(
6402
6698
  struct ggml_tensor * a,
6403
6699
  struct ggml_tensor * b) {
6404
6700
  GGML_ASSERT(ggml_is_contiguous(a));
6405
- GGML_ASSERT(ggml_is_contiguous(b));
6701
+ // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous.
6406
6702
  GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
6407
6703
 
6408
6704
  bool is_node = false;
@@ -6775,7 +7071,6 @@ struct ggml_tensor * ggml_get_rows_back(
6775
7071
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6776
7072
  result->src[0] = a;
6777
7073
  result->src[1] = b;
6778
- result->src[2] = c;
6779
7074
 
6780
7075
  return result;
6781
7076
  }
@@ -6957,7 +7252,7 @@ struct ggml_tensor * ggml_soft_max_back_inplace(
  static struct ggml_tensor * ggml_rope_impl(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
- int n_past,
+ struct ggml_tensor * b,
  int n_dims,
  int mode,
  int n_ctx,
@@ -6966,7 +7261,10 @@ static struct ggml_tensor * ggml_rope_impl(
  float xpos_base,
  bool xpos_down,
  bool inplace) {
- GGML_ASSERT(n_past >= 0);
+ GGML_ASSERT(ggml_is_vector(b));
+ GGML_ASSERT(b->type == GGML_TYPE_I32);
+ GGML_ASSERT(a->ne[2] == b->ne[0]);
+
  bool is_node = false;
 
  if (a->grad) {
@@ -6975,7 +7273,7 @@ static struct ggml_tensor * ggml_rope_impl(
 
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
- int32_t params[8] = { n_past, n_dims, mode, n_ctx };
+ int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
  memcpy(params + 4, &freq_base, sizeof(float));
  memcpy(params + 5, &freq_scale, sizeof(float));
  memcpy(params + 6, &xpos_base, sizeof(float));
@@ -6985,6 +7283,7 @@ static struct ggml_tensor * ggml_rope_impl(
  result->op = GGML_OP_ROPE;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
  result->src[0] = a;
+ result->src[1] = b;
 
  return result;
  }
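This is the most visible ggml API change in this release: the RoPE operators no longer take a scalar `n_past` but an `I32` tensor `b` holding one position per slot along dim 2 of `a` (`a->ne[2] == b->ne[0]`), and the compute kernels further down read those positions via `const int32_t * pos = (const int32_t *) src1->data`. Callers that previously passed `n_past` can reproduce the old behaviour with a position tensor; a sketch assuming an allocating context and illustrative variable names:

```c
// Build an I32 tensor of absolute positions, one entry per token slot.
struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
for (int i = 0; i < n_tokens; ++i) {
    ((int32_t *) pos->data)[i] = n_past + i;   // same effect as the old scalar n_past
}
// old: cur = ggml_rope_inplace(ctx, cur, n_past, n_dims, mode, n_ctx);
cur = ggml_rope_inplace(ctx, cur, pos, n_dims, mode, n_ctx);
```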
@@ -6992,55 +7291,55 @@ static struct ggml_tensor * ggml_rope_impl(
6992
7291
  struct ggml_tensor * ggml_rope(
6993
7292
  struct ggml_context * ctx,
6994
7293
  struct ggml_tensor * a,
6995
- int n_past,
7294
+ struct ggml_tensor * b,
6996
7295
  int n_dims,
6997
7296
  int mode,
6998
7297
  int n_ctx) {
6999
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
7298
+ return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
7000
7299
  }
7001
7300
 
7002
7301
  struct ggml_tensor * ggml_rope_inplace(
7003
7302
  struct ggml_context * ctx,
7004
7303
  struct ggml_tensor * a,
7005
- int n_past,
7304
+ struct ggml_tensor * b,
7006
7305
  int n_dims,
7007
7306
  int mode,
7008
7307
  int n_ctx) {
7009
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
7308
+ return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
7010
7309
  }
7011
7310
 
7012
7311
  struct ggml_tensor * ggml_rope_custom(
7013
7312
  struct ggml_context * ctx,
7014
7313
  struct ggml_tensor * a,
7015
- int n_past,
7314
+ struct ggml_tensor * b,
7016
7315
  int n_dims,
7017
7316
  int mode,
7018
7317
  int n_ctx,
7019
7318
  float freq_base,
7020
7319
  float freq_scale) {
7021
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
7320
+ return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
7022
7321
  }
7023
7322
 
7024
7323
  struct ggml_tensor * ggml_rope_custom_inplace(
7025
7324
  struct ggml_context * ctx,
7026
7325
  struct ggml_tensor * a,
7027
- int n_past,
7326
+ struct ggml_tensor * b,
7028
7327
  int n_dims,
7029
7328
  int mode,
7030
7329
  int n_ctx,
7031
7330
  float freq_base,
7032
7331
  float freq_scale) {
7033
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
7332
+ return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
7034
7333
  }
7035
7334
 
7036
7335
  struct ggml_tensor * ggml_rope_xpos_inplace(
7037
7336
  struct ggml_context * ctx,
7038
7337
  struct ggml_tensor * a,
7039
- int n_past,
7338
+ struct ggml_tensor * b,
7040
7339
  int n_dims,
7041
7340
  float base,
7042
7341
  bool down) {
7043
- return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
7342
+ return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
7044
7343
  }
7045
7344
 
7046
7345
  // ggml_rope_back
@@ -7048,7 +7347,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
7048
7347
  struct ggml_tensor * ggml_rope_back(
7049
7348
  struct ggml_context * ctx,
7050
7349
  struct ggml_tensor * a,
7051
- int n_past,
7350
+ struct ggml_tensor * b,
7052
7351
  int n_dims,
7053
7352
  int mode,
7054
7353
  int n_ctx,
@@ -7056,7 +7355,10 @@ struct ggml_tensor * ggml_rope_back(
7056
7355
  float freq_scale,
7057
7356
  float xpos_base,
7058
7357
  bool xpos_down) {
7059
- GGML_ASSERT(n_past >= 0);
7358
+ GGML_ASSERT(ggml_is_vector(b));
7359
+ GGML_ASSERT(b->type == GGML_TYPE_I32);
7360
+ GGML_ASSERT(a->ne[2] == b->ne[0]);
7361
+
7060
7362
  GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
7061
7363
 
7062
7364
  bool is_node = false;
@@ -7067,7 +7369,7 @@ struct ggml_tensor * ggml_rope_back(
7067
7369
 
7068
7370
  struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
7069
7371
 
7070
- int32_t params[8] = { n_past, n_dims, mode, n_ctx };
7372
+ int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
7071
7373
  memcpy(params + 4, &freq_base, sizeof(float));
7072
7374
  memcpy(params + 5, &freq_scale, sizeof(float));
7073
7375
  memcpy(params + 6, &xpos_base, sizeof(float));
@@ -7077,6 +7379,7 @@ struct ggml_tensor * ggml_rope_back(
7077
7379
  result->op = GGML_OP_ROPE_BACK;
7078
7380
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7079
7381
  result->src[0] = a;
7382
+ result->src[1] = b;
7080
7383
 
7081
7384
  return result;
7082
7385
  }
@@ -7473,27 +7776,30 @@ struct ggml_tensor * ggml_flash_attn_back(
7473
7776
 
7474
7777
  // d shape [D,N,ne2,ne3]
7475
7778
  // q shape [D,N,ne2,ne3]
7476
- // k shape [D,M,ne2,ne3]
7477
- // v shape [M,D,ne2,ne3]
7779
+ // k shape [D,M,kvne2,ne3]
7780
+ // v shape [M,D,kvne2,ne3]
7478
7781
 
7479
- const int64_t D = q->ne[0];
7480
- const int64_t N = q->ne[1];
7481
- const int64_t M = k->ne[1];
7482
- const int64_t ne2 = q->ne[2];
7483
- const int64_t ne3 = q->ne[3];
7782
+ const int64_t D = q->ne[0];
7783
+ const int64_t N = q->ne[1];
7784
+ const int64_t M = k->ne[1];
7785
+ const int64_t ne2 = q->ne[2];
7786
+ const int64_t ne3 = q->ne[3];
7787
+ const int64_t kvne2 = k->ne[2];
7484
7788
 
7485
7789
  GGML_ASSERT(k->ne[0] == D);
7486
7790
  GGML_ASSERT(v->ne[0] == M);
7487
7791
  GGML_ASSERT(v->ne[1] == D);
7488
7792
  GGML_ASSERT(d->ne[0] == D);
7489
7793
  GGML_ASSERT(d->ne[1] == N);
7490
- GGML_ASSERT(k->ne[2] == ne2);
7794
+ GGML_ASSERT(k->ne[2] == kvne2);
7491
7795
  GGML_ASSERT(k->ne[3] == ne3);
7492
- GGML_ASSERT(v->ne[2] == ne2);
7796
+ GGML_ASSERT(v->ne[2] == kvne2);
7493
7797
  GGML_ASSERT(v->ne[3] == ne3);
7494
7798
  GGML_ASSERT(d->ne[2] == ne2);
7495
7799
  GGML_ASSERT(d->ne[3] == ne3);
7496
7800
 
7801
+ GGML_ASSERT(ne2 % kvne2 == 0);
7802
+
7497
7803
  bool is_node = false;
7498
7804
 
7499
7805
  if (q->grad || k->grad || v->grad) {
@@ -7503,14 +7809,23 @@ struct ggml_tensor * ggml_flash_attn_back(
  }
 
  // store gradients of q, k and v as continuous tensors concatenated in result.
- // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
- // gradq->data = result->data
- // gradk->data = result->data + nb0*D*N*ne2*ne3
- // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
  // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
- int64_t ne[4] = {D,M+N+M,ne2,ne3};
+ const int64_t elem_q = ggml_nelements(q);
+ const int64_t elem_k = ggml_nelements(k);
+ const int64_t elem_v = ggml_nelements(v);
 
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+ enum ggml_type result_type = GGML_TYPE_F32;
+ GGML_ASSERT(ggml_blck_size(result_type) == 1);
+ const size_t tsize = ggml_type_size(result_type);
+
+ const size_t offs_q = 0;
+ const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
+ const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
+ const size_t end = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN);
+
+ const size_t nelements = (end + tsize - 1)/tsize;
+
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements);
 
  int32_t masked_i = masked ? 1 : 0;
  ggml_set_op_params(result, &masked_i, sizeof(masked_i));
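`ggml_flash_attn_back` now sizes its result as one flat F32 buffer holding the q, k and v gradients back to back, each segment padded to `GGML_MEM_ALIGN`, instead of forcing them into a single `[D, M+N+M, ne2, ne3]` tensor; together with the relaxed `kvne2` asserts above this lets k/v have fewer heads than q (grouped-query style broadcasting). As a worked example with made-up element counts and `GGML_MEM_ALIGN` assumed to be 16: elem_q = 7, elem_k = 6, elem_v = 5 gives offs_k = GGML_PAD(28, 16) = 32, offs_v = 32 + GGML_PAD(24, 16) = 64 and end = 64 + GGML_PAD(20, 16) = 96, so the result tensor holds (96 + 3)/4 = 24 floats and every gradient block starts on an aligned offset that the consumer of the result must compute the same way.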
@@ -8203,7 +8518,7 @@ static void ggml_compute_forward_dup_f16(
8203
8518
  return;
8204
8519
  }
8205
8520
 
8206
- GGML_TENSOR_UNARY_OP_LOCALS;
8521
+ GGML_TENSOR_UNARY_OP_LOCALS
8207
8522
 
8208
8523
  const int ith = params->ith; // thread index
8209
8524
  const int nth = params->nth; // number of threads
@@ -8474,7 +8789,7 @@ static void ggml_compute_forward_dup_f32(
8474
8789
  return;
8475
8790
  }
8476
8791
 
8477
- GGML_TENSOR_UNARY_OP_LOCALS;
8792
+ GGML_TENSOR_UNARY_OP_LOCALS
8478
8793
 
8479
8794
  const int ith = params->ith; // thread index
8480
8795
  const int nth = params->nth; // number of threads
@@ -8755,7 +9070,7 @@ static void ggml_compute_forward_add_f32(
8755
9070
 
8756
9071
  const int nr = ggml_nrows(src0);
8757
9072
 
8758
- GGML_TENSOR_BINARY_OP_LOCALS;
9073
+ GGML_TENSOR_BINARY_OP_LOCALS
8759
9074
 
8760
9075
  GGML_ASSERT( nb0 == sizeof(float));
8761
9076
  GGML_ASSERT(nb00 == sizeof(float));
@@ -8787,8 +9102,6 @@ static void ggml_compute_forward_add_f32(
8787
9102
  #else
8788
9103
  ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
8789
9104
  #endif
8790
- // }
8791
- // }
8792
9105
  }
8793
9106
  } else {
8794
9107
  // src1 is not contiguous
@@ -8830,7 +9143,7 @@ static void ggml_compute_forward_add_f16_f32(
8830
9143
 
8831
9144
  const int nr = ggml_nrows(src0);
8832
9145
 
8833
- GGML_TENSOR_BINARY_OP_LOCALS;
9146
+ GGML_TENSOR_BINARY_OP_LOCALS
8834
9147
 
8835
9148
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
8836
9149
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -8884,7 +9197,7 @@ static void ggml_compute_forward_add_f16_f16(
8884
9197
 
8885
9198
  const int nr = ggml_nrows(src0);
8886
9199
 
8887
- GGML_TENSOR_BINARY_OP_LOCALS;
9200
+ GGML_TENSOR_BINARY_OP_LOCALS
8888
9201
 
8889
9202
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
8890
9203
  GGML_ASSERT(src1->type == GGML_TYPE_F16);
@@ -8935,14 +9248,15 @@ static void ggml_compute_forward_add_q_f32(
8935
9248
 
8936
9249
  const int nr = ggml_nrows(src0);
8937
9250
 
8938
- GGML_TENSOR_BINARY_OP_LOCALS;
9251
+ GGML_TENSOR_BINARY_OP_LOCALS
8939
9252
 
8940
9253
  const int ith = params->ith;
8941
9254
  const int nth = params->nth;
8942
9255
 
8943
9256
  const enum ggml_type type = src0->type;
9257
+ const enum ggml_type dtype = dst->type;
8944
9258
  ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
8945
- ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
9259
+ ggml_from_float_t const quantize_row_q = type_traits[dtype].from_float;
8946
9260
 
8947
9261
  // we don't support permuted src0 or src1
8948
9262
  GGML_ASSERT(nb00 == ggml_type_size(type));
@@ -8954,7 +9268,6 @@ static void ggml_compute_forward_add_q_f32(
8954
9268
  GGML_ASSERT(nb2 <= nb3);
8955
9269
 
8956
9270
  GGML_ASSERT(ggml_is_quantized(src0->type));
8957
- GGML_ASSERT(dst->type == src0->type);
8958
9271
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
8959
9272
 
8960
9273
  // rows per thread
@@ -8992,7 +9305,11 @@ static void ggml_compute_forward_add_q_f32(
8992
9305
  // add src1
8993
9306
  ggml_vec_acc_f32(ne00, wdata, src1_row);
8994
9307
  // quantize row to dst
8995
- quantize_row_q(wdata, dst_row, ne00);
9308
+ if (quantize_row_q != NULL) {
9309
+ quantize_row_q(wdata, dst_row, ne00);
9310
+ } else {
9311
+ memcpy(dst_row, wdata, ne0*nb0);
9312
+ }
8996
9313
  }
8997
9314
  }
8998
9315
 
@@ -9057,7 +9374,7 @@ static void ggml_compute_forward_add1_f32(
9057
9374
 
9058
9375
  const int nr = ggml_nrows(src0);
9059
9376
 
9060
- GGML_TENSOR_UNARY_OP_LOCALS;
9377
+ GGML_TENSOR_UNARY_OP_LOCALS
9061
9378
 
9062
9379
  GGML_ASSERT( nb0 == sizeof(float));
9063
9380
  GGML_ASSERT(nb00 == sizeof(float));
@@ -9112,7 +9429,7 @@ static void ggml_compute_forward_add1_f16_f32(
9112
9429
 
9113
9430
  const int nr = ggml_nrows(src0);
9114
9431
 
9115
- GGML_TENSOR_UNARY_OP_LOCALS;
9432
+ GGML_TENSOR_UNARY_OP_LOCALS
9116
9433
 
9117
9434
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
9118
9435
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -9162,7 +9479,7 @@ static void ggml_compute_forward_add1_f16_f16(
9162
9479
 
9163
9480
  const int nr = ggml_nrows(src0);
9164
9481
 
9165
- GGML_TENSOR_UNARY_OP_LOCALS;
9482
+ GGML_TENSOR_UNARY_OP_LOCALS
9166
9483
 
9167
9484
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
9168
9485
  GGML_ASSERT(src1->type == GGML_TYPE_F16);
@@ -9212,7 +9529,7 @@ static void ggml_compute_forward_add1_q_f32(
9212
9529
 
9213
9530
  const int nr = ggml_nrows(src0);
9214
9531
 
9215
- GGML_TENSOR_UNARY_OP_LOCALS;
9532
+ GGML_TENSOR_UNARY_OP_LOCALS
9216
9533
 
9217
9534
  const enum ggml_type type = src0->type;
9218
9535
  ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
@@ -9340,8 +9657,8 @@ static void ggml_compute_forward_acc_f32(
9340
9657
  const int nr = ggml_nrows(src1);
9341
9658
  const int nc = src1->ne[0];
9342
9659
 
9343
- GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
9344
- GGML_TENSOR_LOCALS(size_t, nb1, src1, nb);
9660
+ GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
9661
+ GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
9345
9662
 
9346
9663
  // src0 and dst as viewed during acc
9347
9664
  const size_t nb0 = ggml_element_size(src0);
@@ -9430,7 +9747,7 @@ static void ggml_compute_forward_sub_f32(
9430
9747
 
9431
9748
  const int nr = ggml_nrows(src0);
9432
9749
 
9433
- GGML_TENSOR_BINARY_OP_LOCALS;
9750
+ GGML_TENSOR_BINARY_OP_LOCALS
9434
9751
 
9435
9752
  GGML_ASSERT( nb0 == sizeof(float));
9436
9753
  GGML_ASSERT(nb00 == sizeof(float));
@@ -9520,7 +9837,7 @@ static void ggml_compute_forward_mul_f32(
9520
9837
 
9521
9838
  const int64_t nr = ggml_nrows(src0);
9522
9839
 
9523
- GGML_TENSOR_BINARY_OP_LOCALS;
9840
+ GGML_TENSOR_BINARY_OP_LOCALS
9524
9841
 
9525
9842
  GGML_ASSERT( nb0 == sizeof(float));
9526
9843
  GGML_ASSERT(nb00 == sizeof(float));
@@ -9611,7 +9928,7 @@ static void ggml_compute_forward_div_f32(
9611
9928
 
9612
9929
  const int nr = ggml_nrows(src0);
9613
9930
 
9614
- GGML_TENSOR_BINARY_OP_LOCALS;
9931
+ GGML_TENSOR_BINARY_OP_LOCALS
9615
9932
 
9616
9933
  GGML_ASSERT( nb0 == sizeof(float));
9617
9934
  GGML_ASSERT(nb00 == sizeof(float));
@@ -9820,8 +10137,8 @@ static void ggml_compute_forward_sum_f32(
9820
10137
  assert(ggml_is_scalar(dst));
9821
10138
  assert(src0->nb[0] == sizeof(float));
9822
10139
 
9823
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
9824
- GGML_TENSOR_LOCALS(size_t, nb0, src0, nb);
10140
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
10141
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
9825
10142
 
9826
10143
  ggml_float sum = 0;
9827
10144
  ggml_float row_sum = 0;
@@ -9852,8 +10169,8 @@ static void ggml_compute_forward_sum_f16(
9852
10169
 
9853
10170
  assert(src0->nb[0] == sizeof(ggml_fp16_t));
9854
10171
 
9855
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
9856
- GGML_TENSOR_LOCALS(size_t, nb0, src0, nb);
10172
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
10173
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
9857
10174
 
9858
10175
  float sum = 0;
9859
10176
  float row_sum = 0;
@@ -9906,7 +10223,7 @@ static void ggml_compute_forward_sum_rows_f32(
9906
10223
  GGML_ASSERT(src0->nb[0] == sizeof(float));
9907
10224
  GGML_ASSERT(dst->nb[0] == sizeof(float));
9908
10225
 
9909
- GGML_TENSOR_UNARY_OP_LOCALS;
10226
+ GGML_TENSOR_UNARY_OP_LOCALS
9910
10227
 
9911
10228
  GGML_ASSERT(ne0 == 1);
9912
10229
  GGML_ASSERT(ne1 == ne01);
@@ -9956,7 +10273,7 @@ static void ggml_compute_forward_mean_f32(
9956
10273
 
9957
10274
  assert(src0->nb[0] == sizeof(float));
9958
10275
 
9959
- GGML_TENSOR_UNARY_OP_LOCALS;
10276
+ GGML_TENSOR_UNARY_OP_LOCALS
9960
10277
 
9961
10278
  assert(ne0 == 1);
9962
10279
  assert(ne1 == ne01);
@@ -10056,7 +10373,7 @@ static void ggml_compute_forward_repeat_f32(
10056
10373
  return;
10057
10374
  }
10058
10375
 
10059
- GGML_TENSOR_UNARY_OP_LOCALS;
10376
+ GGML_TENSOR_UNARY_OP_LOCALS
10060
10377
 
10061
10378
  // guaranteed to be an integer due to the check in ggml_can_repeat
10062
10379
  const int nr0 = (int)(ne0/ne00);
@@ -10088,11 +10405,61 @@ static void ggml_compute_forward_repeat_f32(
10088
10405
  }
10089
10406
  }
10090
10407
 
10408
+ static void ggml_compute_forward_repeat_f16(
10409
+ const struct ggml_compute_params * params,
10410
+ const struct ggml_tensor * src0,
10411
+ struct ggml_tensor * dst) {
10412
+ GGML_ASSERT(params->ith == 0);
10413
+ GGML_ASSERT(ggml_can_repeat(src0, dst));
10414
+
10415
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
10416
+ return;
10417
+ }
10418
+
10419
+ GGML_TENSOR_UNARY_OP_LOCALS;
10420
+
10421
+ // guaranteed to be an integer due to the check in ggml_can_repeat
10422
+ const int nr0 = (int)(ne0/ne00);
10423
+ const int nr1 = (int)(ne1/ne01);
10424
+ const int nr2 = (int)(ne2/ne02);
10425
+ const int nr3 = (int)(ne3/ne03);
10426
+
10427
+ // TODO: support for transposed / permuted tensors
10428
+ GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
10429
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
10430
+
10431
+ // TODO: maybe this is not optimal?
10432
+ for (int i3 = 0; i3 < nr3; i3++) {
10433
+ for (int k3 = 0; k3 < ne03; k3++) {
10434
+ for (int i2 = 0; i2 < nr2; i2++) {
10435
+ for (int k2 = 0; k2 < ne02; k2++) {
10436
+ for (int i1 = 0; i1 < nr1; i1++) {
10437
+ for (int k1 = 0; k1 < ne01; k1++) {
10438
+ for (int i0 = 0; i0 < nr0; i0++) {
10439
+ ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
10440
+ ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
10441
+ // ggml_vec_cpy_f16(ne00, y, x)
10442
+ for (int i = 0; i < ne00; ++i) {
10443
+ y[i] = x[i];
10444
+ }
10445
+ }
10446
+ }
10447
+ }
10448
+ }
10449
+ }
10450
+ }
10451
+ }
10452
+ }
10453
+
10091
10454
  static void ggml_compute_forward_repeat(
10092
10455
  const struct ggml_compute_params * params,
10093
10456
  const struct ggml_tensor * src0,
10094
10457
  struct ggml_tensor * dst) {
10095
10458
  switch (src0->type) {
10459
+ case GGML_TYPE_F16:
10460
+ {
10461
+ ggml_compute_forward_repeat_f16(params, src0, dst);
10462
+ } break;
10096
10463
  case GGML_TYPE_F32:
10097
10464
  {
10098
10465
  ggml_compute_forward_repeat_f32(params, src0, dst);
@@ -10117,7 +10484,7 @@ static void ggml_compute_forward_repeat_back_f32(
10117
10484
  return;
10118
10485
  }
10119
10486
 
10120
- GGML_TENSOR_UNARY_OP_LOCALS;
10487
+ GGML_TENSOR_UNARY_OP_LOCALS
10121
10488
 
10122
10489
  // guaranteed to be an integer due to the check in ggml_can_repeat
10123
10490
  const int nr0 = (int)(ne00/ne0);
@@ -10195,7 +10562,7 @@ static void ggml_compute_forward_concat_f32(
10195
10562
 
10196
10563
  const int ith = params->ith;
10197
10564
 
10198
- GGML_TENSOR_BINARY_OP_LOCALS;
10565
+ GGML_TENSOR_BINARY_OP_LOCALS
10199
10566
 
10200
10567
  // TODO: support for transposed / permuted tensors
10201
10568
  GGML_ASSERT(nb0 == sizeof(float));
@@ -10797,7 +11164,7 @@ static void ggml_compute_forward_norm_f32(
10797
11164
  const int ith = params->ith;
10798
11165
  const int nth = params->nth;
10799
11166
 
10800
- GGML_TENSOR_UNARY_OP_LOCALS;
11167
+ GGML_TENSOR_UNARY_OP_LOCALS
10801
11168
 
10802
11169
  float eps;
10803
11170
  memcpy(&eps, dst->op_params, sizeof(float));
@@ -10866,7 +11233,7 @@ static void ggml_compute_forward_rms_norm_f32(
10866
11233
  const int ith = params->ith;
10867
11234
  const int nth = params->nth;
10868
11235
 
10869
- GGML_TENSOR_UNARY_OP_LOCALS;
11236
+ GGML_TENSOR_UNARY_OP_LOCALS
10870
11237
 
10871
11238
  float eps;
10872
11239
  memcpy(&eps, dst->op_params, sizeof(float));
@@ -10931,7 +11298,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
10931
11298
  const int ith = params->ith;
10932
11299
  const int nth = params->nth;
10933
11300
 
10934
- GGML_TENSOR_BINARY_OP_LOCALS;
11301
+ GGML_TENSOR_BINARY_OP_LOCALS
10935
11302
 
10936
11303
  float eps;
10937
11304
  memcpy(&eps, dst->op_params, sizeof(float));
@@ -11106,7 +11473,7 @@ static void ggml_compute_forward_group_norm_f32(
11106
11473
  const int ith = params->ith;
11107
11474
  const int nth = params->nth;
11108
11475
 
11109
- GGML_TENSOR_UNARY_OP_LOCALS;
11476
+ GGML_TENSOR_UNARY_OP_LOCALS
11110
11477
 
11111
11478
  const float eps = 1e-6f; // TODO: make this a parameter
11112
11479
 
@@ -11217,7 +11584,7 @@ static void ggml_compute_forward_mul_mat(
11217
11584
  int64_t t0 = ggml_perf_time_us();
11218
11585
  UNUSED(t0);
11219
11586
 
11220
- GGML_TENSOR_BINARY_OP_LOCALS;
11587
+ GGML_TENSOR_BINARY_OP_LOCALS
11221
11588
 
11222
11589
  const int ith = params->ith;
11223
11590
  const int nth = params->nth;
@@ -11432,10 +11799,10 @@ static void ggml_compute_forward_out_prod_f32(
11432
11799
  const struct ggml_tensor * src0,
11433
11800
  const struct ggml_tensor * src1,
11434
11801
  struct ggml_tensor * dst) {
11435
- int64_t t0 = ggml_perf_time_us();
11436
- UNUSED(t0);
11802
+ // int64_t t0 = ggml_perf_time_us();
11803
+ // UNUSED(t0);
11437
11804
 
11438
- GGML_TENSOR_BINARY_OP_LOCALS;
11805
+ GGML_TENSOR_BINARY_OP_LOCALS
11439
11806
 
11440
11807
  const int ith = params->ith;
11441
11808
  const int nth = params->nth;
@@ -11474,6 +11841,146 @@ static void ggml_compute_forward_out_prod_f32(
11474
11841
  return;
11475
11842
  }
11476
11843
 
11844
+ // dst[:,:,:,:] = 0
11845
+ // for i2,i3:
11846
+ // for i1:
11847
+ // for i01:
11848
+ // for i0:
11849
+ // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
11850
+
11851
+ // parallelize by last three dimensions
11852
+
11853
+ // total rows in dst
11854
+ const int64_t nr = ne1*ne2*ne3;
11855
+
11856
+ // rows per thread
11857
+ const int64_t dr = (nr + nth - 1)/nth;
11858
+
11859
+ // row range for this thread
11860
+ const int64_t ir0 = dr*ith;
11861
+ const int64_t ir1 = MIN(ir0 + dr, nr);
11862
+
11863
+ // block-tiling attempt
11864
+ const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
11865
+ const int64_t blck_1 = 16;
11866
+
11867
+ for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
11868
+ const int64_t bir1 = MIN(bir + blck_1, ir1);
11869
+ for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
11870
+ const int64_t bne01 = MIN(bi01 + blck_0, ne01);
11871
+ for (int64_t ir = bir; ir < bir1; ++ir) {
11872
+ // dst indices
11873
+ const int64_t i3 = ir/(ne2*ne1);
11874
+ const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
11875
+ const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
11876
+
11877
+ const int64_t i02 = i2;
11878
+ const int64_t i03 = i3;
11879
+
11880
+ //const int64_t i10 = i1;
11881
+ const int64_t i12 = i2;
11882
+ const int64_t i13 = i3;
11883
+
11884
+ #if GGML_VEC_MAD_UNROLL > 2
11885
+ const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
11886
+ for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
11887
+ const int64_t i11 = i01;
11888
+
11889
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
11890
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
11891
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
11892
+
11893
+ ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
11894
+ }
11895
+ for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
11896
+ const int64_t i11 = i01;
11897
+
11898
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
11899
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
11900
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
11901
+
11902
+ ggml_vec_mad_f32(ne0, d, s0, *s1);
11903
+ }
11904
+ #else
11905
+ for (int64_t i01 = bi01; i01 < bne01; ++i01) {
11906
+ const int64_t i11 = i01;
11907
+
11908
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
11909
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
11910
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
11911
+
11912
+ ggml_vec_mad_f32(ne0, d, s0, *s1);
11913
+ }
11914
+ #endif
11915
+ }
11916
+ }
11917
+ }
11918
+
11919
+
11920
+ //int64_t t1 = ggml_perf_time_us();
11921
+ //static int64_t acc = 0;
11922
+ //acc += t1 - t0;
11923
+ //if (t1 - t0 > 10) {
11924
+ // printf("\n");
11925
+ // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
11926
+ // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
11927
+ // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
11928
+ // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
11929
+
11930
+ // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
11931
+ //}
11932
+ }
11933
+
11934
+ static void ggml_compute_forward_out_prod_q_f32(
11935
+ const struct ggml_compute_params * params,
11936
+ const struct ggml_tensor * src0,
11937
+ const struct ggml_tensor * src1,
11938
+ struct ggml_tensor * dst) {
11939
+ // int64_t t0 = ggml_perf_time_us();
11940
+ // UNUSED(t0);
11941
+
11942
+ GGML_TENSOR_BINARY_OP_LOCALS;
11943
+
11944
+ const int ith = params->ith;
11945
+ const int nth = params->nth;
11946
+
11947
+ const enum ggml_type type = src0->type;
11948
+ ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
11949
+
11950
+ GGML_ASSERT(ne02 == ne12);
11951
+ GGML_ASSERT(ne03 == ne13);
11952
+ GGML_ASSERT(ne2 == ne12);
11953
+ GGML_ASSERT(ne3 == ne13);
11954
+
11955
+ // we don't support permuted src0 dim0
11956
+ GGML_ASSERT(nb00 == ggml_type_size(type));
11957
+
11958
+ // dst dim0 cannot be transposed or permuted
11959
+ GGML_ASSERT(nb0 == sizeof(float));
11960
+ // GGML_ASSERT(nb0 <= nb1);
11961
+ // GGML_ASSERT(nb1 <= nb2);
11962
+ // GGML_ASSERT(nb2 <= nb3);
11963
+
11964
+ GGML_ASSERT(ne0 == ne00);
11965
+ GGML_ASSERT(ne1 == ne10);
11966
+ GGML_ASSERT(ne2 == ne02);
11967
+ GGML_ASSERT(ne3 == ne03);
11968
+
11969
+ // nb01 >= nb00 - src0 is not transposed
11970
+ // compute by src0 rows
11971
+
11972
+ // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
11973
+ // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
11974
+
11975
+ if (params->type == GGML_TASK_INIT) {
11976
+ ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
11977
+ return;
11978
+ }
11979
+
11980
+ if (params->type == GGML_TASK_FINALIZE) {
11981
+ return;
11982
+ }
11983
+
11477
11984
  // parallelize by last three dimensions
11478
11985
 
11479
11986
  // total rows in dst
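The rewritten `ggml_compute_forward_out_prod_f32` above block-tiles the row loops (`blck_0` tied to `GGML_VEC_MAD_UNROLL`, `blck_1 = 16`) and feeds them through `ggml_vec_mad_f32_unroll`, while the new `ggml_compute_forward_out_prod_q_f32` dequantizes each `src0` row into a per-thread scratch buffer (`params->wdata`, see the next hunk) before the multiply-add; that is what allows the K-quant types to be enabled in the dispatch switch below. Both kernels compute the same outer product; a scalar 2-D reference under assumed contiguous layout (i0 fastest), for reading only:

```c
// dst[i0, i1] = sum over i01 of src0[i0, i01] * src1[i1, i01]
static void out_prod_ref(int ne0, int ne1, int ne01,
                         const float *src0,   // ne0 x ne01, i0 fastest
                         const float *src1,   // ne1 x ne01, i1 fastest
                         float *dst) {        // ne0 x ne1,  i0 fastest
    for (int i1 = 0; i1 < ne1; ++i1) {
        for (int i0 = 0; i0 < ne0; ++i0) {
            float sum = 0.0f;
            for (int i01 = 0; i01 < ne01; ++i01) {
                sum += src0[i01*ne0 + i0] * src1[i01*ne1 + i1];
            }
            dst[i1*ne0 + i0] = sum;
        }
    }
}
```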
@@ -11493,6 +12000,8 @@ static void ggml_compute_forward_out_prod_f32(
11493
12000
  // for i0:
11494
12001
  // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
11495
12002
 
12003
+ float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
12004
+
11496
12005
  for (int64_t ir = ir0; ir < ir1; ++ir) {
11497
12006
  // dst indices
11498
12007
  const int64_t i3 = ir/(ne2*ne1);
@@ -11513,10 +12022,8 @@ static void ggml_compute_forward_out_prod_f32(
11513
12022
  float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
11514
12023
  float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
11515
12024
 
11516
- ggml_vec_mad_f32(ne0, d, s0, *s1);
11517
- // for (int64_t i0 = 0; i0 < ne0; ++i0) {
11518
- // d[i0] += s0[i0] * s1[i1];
11519
- // }
12025
+ dequantize_row_q(s0, wdata, ne0);
12026
+ ggml_vec_mad_f32(ne0, d, wdata, *s1);
11520
12027
  }
11521
12028
  }
11522
12029
 
@@ -11545,10 +12052,13 @@ static void ggml_compute_forward_out_prod(
11545
12052
  case GGML_TYPE_Q5_0:
11546
12053
  case GGML_TYPE_Q5_1:
11547
12054
  case GGML_TYPE_Q8_0:
11548
- case GGML_TYPE_Q8_1:
12055
+ case GGML_TYPE_Q2_K:
12056
+ case GGML_TYPE_Q3_K:
12057
+ case GGML_TYPE_Q4_K:
12058
+ case GGML_TYPE_Q5_K:
12059
+ case GGML_TYPE_Q6_K:
11549
12060
  {
11550
- GGML_ASSERT(false); // todo
11551
- // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
12061
+ ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
11552
12062
  } break;
11553
12063
  case GGML_TYPE_F16:
11554
12064
  {
@@ -11666,8 +12176,8 @@ static void ggml_compute_forward_set_f32(
11666
12176
  const int nr = ggml_nrows(src1);
11667
12177
  const int nc = src1->ne[0];
11668
12178
 
11669
- GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
11670
- GGML_TENSOR_LOCALS(size_t, nb1, src1, nb);
12179
+ GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
12180
+ GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
11671
12181
 
11672
12182
  // src0 and dst as viewed during set
11673
12183
  const size_t nb0 = ggml_element_size(src0);
@@ -11936,14 +12446,15 @@ static void ggml_compute_forward_get_rows_back_f32_f16(
11936
12446
  const struct ggml_compute_params * params,
11937
12447
  const struct ggml_tensor * src0,
11938
12448
  const struct ggml_tensor * src1,
11939
- const struct ggml_tensor * opt0,
11940
12449
  struct ggml_tensor * dst) {
11941
12450
  GGML_ASSERT(params->ith == 0);
11942
- GGML_ASSERT(ggml_are_same_shape(opt0, dst));
11943
- GGML_ASSERT(ggml_is_contiguous(opt0));
11944
12451
  GGML_ASSERT(ggml_is_contiguous(dst));
11945
12452
 
11946
- ggml_compute_forward_dup_same_cont(params, opt0, dst);
12453
+ // ggml_compute_forward_dup_same_cont(params, opt0, dst);
12454
+
12455
+ if (params->type == GGML_TASK_INIT) {
12456
+ memset(dst->data, 0, ggml_nbytes(dst));
12457
+ }
11947
12458
 
11948
12459
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
11949
12460
  return;
@@ -11969,11 +12480,8 @@ static void ggml_compute_forward_get_rows_back_f32(
11969
12480
  const struct ggml_compute_params * params,
11970
12481
  const struct ggml_tensor * src0,
11971
12482
  const struct ggml_tensor * src1,
11972
- const struct ggml_tensor * opt0,
11973
12483
  struct ggml_tensor * dst) {
11974
12484
  GGML_ASSERT(params->ith == 0);
11975
- GGML_ASSERT(ggml_are_same_shape(opt0, dst));
11976
- GGML_ASSERT(ggml_is_contiguous(opt0));
11977
12485
  GGML_ASSERT(ggml_is_contiguous(dst));
11978
12486
 
11979
12487
  // ggml_compute_forward_dup_same_cont(params, opt0, dst);
@@ -12007,16 +12515,15 @@ static void ggml_compute_forward_get_rows_back(
12007
12515
  const struct ggml_compute_params * params,
12008
12516
  const struct ggml_tensor * src0,
12009
12517
  const struct ggml_tensor * src1,
12010
- const struct ggml_tensor * opt0,
12011
12518
  struct ggml_tensor * dst) {
12012
12519
  switch (src0->type) {
12013
12520
  case GGML_TYPE_F16:
12014
12521
  {
12015
- ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, opt0, dst);
12522
+ ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, dst);
12016
12523
  } break;
12017
12524
  case GGML_TYPE_F32:
12018
12525
  {
12019
- ggml_compute_forward_get_rows_back_f32(params, src0, src1, opt0, dst);
12526
+ ggml_compute_forward_get_rows_back_f32(params, src0, src1, dst);
12020
12527
  } break;
12021
12528
  default:
12022
12529
  {
@@ -12057,7 +12564,7 @@ static void ggml_compute_forward_diag_f32(
12057
12564
 
12058
12565
  // TODO: handle transposed/permuted matrices
12059
12566
 
12060
- GGML_TENSOR_UNARY_OP_LOCALS;
12567
+ GGML_TENSOR_UNARY_OP_LOCALS
12061
12568
 
12062
12569
  GGML_ASSERT(ne00 == ne0);
12063
12570
  GGML_ASSERT(ne00 == ne1);
@@ -12445,13 +12952,11 @@ static void ggml_compute_forward_alibi_f16(
12445
12952
  return;
12446
12953
  }
12447
12954
 
12448
- const int n_past = ((int32_t *) dst->op_params)[0];
12955
+ //const int n_past = ((int32_t *) dst->op_params)[0];
12449
12956
  const int n_head = ((int32_t *) dst->op_params)[1];
12450
12957
  float max_bias;
12451
12958
  memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
12452
12959
 
12453
- assert(n_past >= 0);
12454
-
12455
12960
  const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
12456
12961
  const int ne1 = src0->ne[1]; // seq_len_without_past
12457
12962
  const int ne2 = src0->ne[2]; // n_head -> this is k
@@ -12466,7 +12971,7 @@ static void ggml_compute_forward_alibi_f16(
12466
12971
  //const int nb3 = src0->nb[3];
12467
12972
 
12468
12973
  GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
12469
- GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
12974
+ //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
12470
12975
  GGML_ASSERT(n_head == ne2);
12471
12976
 
12472
12977
  // add alibi to src0 (KQ_scaled)
@@ -12612,8 +13117,8 @@ static void ggml_compute_forward_clamp(
12612
13117
  static void ggml_compute_forward_rope_f32(
12613
13118
  const struct ggml_compute_params * params,
12614
13119
  const struct ggml_tensor * src0,
13120
+ const struct ggml_tensor * src1,
12615
13121
  struct ggml_tensor * dst) {
12616
-
12617
13122
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
12618
13123
  return;
12619
13124
  }
@@ -12623,9 +13128,9 @@ static void ggml_compute_forward_rope_f32(
12623
13128
 
12624
13129
  // these two only relevant for xPos RoPE:
12625
13130
  float xpos_base;
12626
- bool xpos_down;
13131
+ bool xpos_down;
12627
13132
 
12628
- const int n_past = ((int32_t *) dst->op_params)[0];
13133
+ //const int n_past = ((int32_t *) dst->op_params)[0];
12629
13134
  const int n_dims = ((int32_t *) dst->op_params)[1];
12630
13135
  const int mode = ((int32_t *) dst->op_params)[2];
12631
13136
  const int n_ctx = ((int32_t *) dst->op_params)[3];
@@ -12634,9 +13139,7 @@ static void ggml_compute_forward_rope_f32(
12634
13139
  memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
12635
13140
  memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
12636
13141
 
12637
- assert(n_past >= 0);
12638
-
12639
- GGML_TENSOR_UNARY_OP_LOCALS;
13142
+ GGML_TENSOR_UNARY_OP_LOCALS
12640
13143
 
12641
13144
  //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
12642
13145
  //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -12666,9 +13169,11 @@ static void ggml_compute_forward_rope_f32(
12666
13169
  const bool is_neox = mode & 2;
12667
13170
  const bool is_glm = mode & 4;
12668
13171
 
13172
+ const int32_t * pos = (const int32_t *) src1->data;
13173
+
12669
13174
  for (int64_t i3 = 0; i3 < ne3; i3++) {
12670
- for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
12671
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
13175
+ for (int64_t i2 = 0; i2 < ne2; i2++) {
13176
+ const int64_t p = pos[i2];
12672
13177
  for (int64_t i1 = 0; i1 < ne1; i1++) {
12673
13178
  if (ir++ < ir0) continue;
12674
13179
  if (ir > ir1) break;
@@ -12705,7 +13210,7 @@ static void ggml_compute_forward_rope_f32(
12705
13210
  const float cos_theta = cosf(theta);
12706
13211
  const float sin_theta = sinf(theta);
12707
13212
  // zeta scaling for xPos only:
12708
- float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
13213
+ float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
12709
13214
  if (xpos_down) zeta = 1.0f / zeta;
12710
13215
 
12711
13216
  theta *= theta_scale;
@@ -12750,8 +13255,8 @@ static void ggml_compute_forward_rope_f32(
12750
13255
  static void ggml_compute_forward_rope_f16(
12751
13256
  const struct ggml_compute_params * params,
12752
13257
  const struct ggml_tensor * src0,
13258
+ const struct ggml_tensor * src1,
12753
13259
  struct ggml_tensor * dst) {
12754
-
12755
13260
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
12756
13261
  return;
12757
13262
  }
@@ -12759,16 +13264,14 @@ static void ggml_compute_forward_rope_f16(
12759
13264
  float freq_base;
12760
13265
  float freq_scale;
12761
13266
 
12762
- const int n_past = ((int32_t *) dst->op_params)[0];
13267
+ //const int n_past = ((int32_t *) dst->op_params)[0];
12763
13268
  const int n_dims = ((int32_t *) dst->op_params)[1];
12764
13269
  const int mode = ((int32_t *) dst->op_params)[2];
12765
13270
  const int n_ctx = ((int32_t *) dst->op_params)[3];
12766
13271
  memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
12767
13272
  memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
12768
13273
 
12769
- assert(n_past >= 0);
12770
-
12771
- GGML_TENSOR_UNARY_OP_LOCALS;
13274
+ GGML_TENSOR_UNARY_OP_LOCALS
12772
13275
 
12773
13276
  //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
12774
13277
  //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -12798,9 +13301,11 @@ static void ggml_compute_forward_rope_f16(
12798
13301
  const bool is_neox = mode & 2;
12799
13302
  const bool is_glm = mode & 4;
12800
13303
 
13304
+ const int32_t * pos = (const int32_t *) src1->data;
13305
+
12801
13306
  for (int64_t i3 = 0; i3 < ne3; i3++) {
12802
- for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
12803
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
13307
+ for (int64_t i2 = 0; i2 < ne2; i2++) {
13308
+ const int64_t p = pos[i2];
12804
13309
  for (int64_t i1 = 0; i1 < ne1; i1++) {
12805
13310
  if (ir++ < ir0) continue;
12806
13311
  if (ir > ir1) break;
@@ -12879,15 +13384,16 @@ static void ggml_compute_forward_rope_f16(
12879
13384
  static void ggml_compute_forward_rope(
12880
13385
  const struct ggml_compute_params * params,
12881
13386
  const struct ggml_tensor * src0,
13387
+ const struct ggml_tensor * src1,
12882
13388
  struct ggml_tensor * dst) {
12883
13389
  switch (src0->type) {
12884
13390
  case GGML_TYPE_F16:
12885
13391
  {
12886
- ggml_compute_forward_rope_f16(params, src0, dst);
13392
+ ggml_compute_forward_rope_f16(params, src0, src1, dst);
12887
13393
  } break;
12888
13394
  case GGML_TYPE_F32:
12889
13395
  {
12890
- ggml_compute_forward_rope_f32(params, src0, dst);
13396
+ ggml_compute_forward_rope_f32(params, src0, src1, dst);
12891
13397
  } break;
12892
13398
  default:
12893
13399
  {
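
Taken together, the RoPE hunks above switch the forward kernels from a scalar n_past offset to an explicit position tensor: src1 carries one int32 position per token, and the rotation for row i2 starts from pos[i2]. A rough, self-contained sketch of the per-row rotation under that convention (plain pair-wise rotation as in the non-neox path; rope_row and its parameters are illustrative, not the ggml signatures):

    #include <math.h>
    #include <stdint.h>

    /* Illustrative sketch: rotate one row of n_dims floats in place.
     * p is the absolute position of this token, i.e. what the new src1
     * argument supplies per row in the diff above. */
    static void rope_row(float * row, int n_dims, int32_t p, float freq_base, float freq_scale) {
        const float theta_scale = powf(freq_base, -2.0f / n_dims);
        float theta = freq_scale * (float) p;
        for (int i0 = 0; i0 < n_dims; i0 += 2) {
            const float cos_theta = cosf(theta);
            const float sin_theta = sinf(theta);
            const float x0 = row[i0];
            const float x1 = row[i0 + 1];
            row[i0]     = x0 * cos_theta - x1 * sin_theta;
            row[i0 + 1] = x0 * sin_theta + x1 * cos_theta;
            theta *= theta_scale;   /* next pair uses a smaller frequency */
        }
    }
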
@@ -12901,6 +13407,7 @@ static void ggml_compute_forward_rope(
12901
13407
  static void ggml_compute_forward_rope_back_f32(
12902
13408
  const struct ggml_compute_params * params,
12903
13409
  const struct ggml_tensor * src0,
13410
+ const struct ggml_tensor * src1,
12904
13411
  struct ggml_tensor * dst) {
12905
13412
 
12906
13413
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -12918,7 +13425,7 @@ static void ggml_compute_forward_rope_back_f32(
12918
13425
  float xpos_base;
12919
13426
  bool xpos_down;
12920
13427
 
12921
- const int n_past = ((int32_t *) dst->op_params)[0];
13428
+ //const int n_past = ((int32_t *) dst->op_params)[0];
12922
13429
  const int n_dims = ((int32_t *) dst->op_params)[1];
12923
13430
  const int mode = ((int32_t *) dst->op_params)[2];
12924
13431
  const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx);
@@ -12927,9 +13434,7 @@ static void ggml_compute_forward_rope_back_f32(
12927
13434
  memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
12928
13435
  memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
12929
13436
 
12930
- assert(n_past >= 0);
12931
-
12932
- GGML_TENSOR_UNARY_OP_LOCALS;
13437
+ GGML_TENSOR_UNARY_OP_LOCALS
12933
13438
 
12934
13439
  //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
12935
13440
  //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -12955,9 +13460,11 @@ static void ggml_compute_forward_rope_back_f32(
12955
13460
 
12956
13461
  const bool is_neox = mode & 2;
12957
13462
 
13463
+ const int32_t * pos = (const int32_t *) src1->data;
13464
+
12958
13465
  for (int64_t i3 = 0; i3 < ne3; i3++) {
12959
- for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
12960
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
13466
+ for (int64_t i2 = 0; i2 < ne2; i2++) {
13467
+ const int64_t p = pos[i2];
12961
13468
  for (int64_t i1 = 0; i1 < ne1; i1++) {
12962
13469
  if (ir++ < ir0) continue;
12963
13470
  if (ir > ir1) break;
@@ -12969,7 +13476,7 @@ static void ggml_compute_forward_rope_back_f32(
12969
13476
  const float cos_theta = cosf(theta);
12970
13477
  const float sin_theta = sinf(theta);
12971
13478
  // zeta scaling for xPos only:
12972
- float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
13479
+ float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
12973
13480
  if (xpos_down) zeta = 1.0f / zeta;
12974
13481
 
12975
13482
  theta *= theta_scale;
@@ -13012,6 +13519,7 @@ static void ggml_compute_forward_rope_back_f32(
13012
13519
  static void ggml_compute_forward_rope_back_f16(
13013
13520
  const struct ggml_compute_params * params,
13014
13521
  const struct ggml_tensor * src0,
13522
+ const struct ggml_tensor * src1,
13015
13523
  struct ggml_tensor * dst) {
13016
13524
 
13017
13525
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -13022,13 +13530,11 @@ static void ggml_compute_forward_rope_back_f16(
13022
13530
  // dx = rope_back(dy, src1)
13023
13531
  // src0 is dy, src1 contains options
13024
13532
 
13025
- const int n_past = ((int32_t *) dst->op_params)[0];
13533
+ //const int n_past = ((int32_t *) dst->op_params)[0];
13026
13534
  const int n_dims = ((int32_t *) dst->op_params)[1];
13027
13535
  const int mode = ((int32_t *) dst->op_params)[2];
13028
13536
 
13029
- assert(n_past >= 0);
13030
-
13031
- GGML_TENSOR_UNARY_OP_LOCALS;
13537
+ GGML_TENSOR_UNARY_OP_LOCALS
13032
13538
 
13033
13539
  //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
13034
13540
  //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -13054,9 +13560,11 @@ static void ggml_compute_forward_rope_back_f16(
13054
13560
 
13055
13561
  const bool is_neox = mode & 2;
13056
13562
 
13563
+ const int32_t * pos = (const int32_t *) src1->data;
13564
+
13057
13565
  for (int64_t i3 = 0; i3 < ne3; i3++) {
13058
- for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
13059
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
13566
+ for (int64_t i2 = 0; i2 < ne2; i2++) {
13567
+ const int64_t p = pos[i2];
13060
13568
  for (int64_t i1 = 0; i1 < ne1; i1++) {
13061
13569
  if (ir++ < ir0) continue;
13062
13570
  if (ir > ir1) break;
@@ -13108,15 +13616,16 @@ static void ggml_compute_forward_rope_back_f16(
13108
13616
  static void ggml_compute_forward_rope_back(
13109
13617
  const struct ggml_compute_params * params,
13110
13618
  const struct ggml_tensor * src0,
13619
+ const struct ggml_tensor * src1,
13111
13620
  struct ggml_tensor * dst) {
13112
13621
  switch (src0->type) {
13113
13622
  case GGML_TYPE_F16:
13114
13623
  {
13115
- ggml_compute_forward_rope_back_f16(params, src0, dst);
13624
+ ggml_compute_forward_rope_back_f16(params, src0, src1, dst);
13116
13625
  } break;
13117
13626
  case GGML_TYPE_F32:
13118
13627
  {
13119
- ggml_compute_forward_rope_back_f32(params, src0, dst);
13628
+ ggml_compute_forward_rope_back_f32(params, src0, src1, dst);
13120
13629
  } break;
13121
13630
  default:
13122
13631
  {
@@ -13139,7 +13648,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
13139
13648
  int64_t t0 = ggml_perf_time_us();
13140
13649
  UNUSED(t0);
13141
13650
 
13142
- GGML_TENSOR_BINARY_OP_LOCALS;
13651
+ GGML_TENSOR_BINARY_OP_LOCALS
13143
13652
 
13144
13653
  const int ith = params->ith;
13145
13654
  const int nth = params->nth;
@@ -13230,7 +13739,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32(
13230
13739
  int64_t t0 = ggml_perf_time_us();
13231
13740
  UNUSED(t0);
13232
13741
 
13233
- GGML_TENSOR_BINARY_OP_LOCALS;
13742
+ GGML_TENSOR_BINARY_OP_LOCALS
13234
13743
 
13235
13744
  const int ith = params->ith;
13236
13745
  const int nth = params->nth;
@@ -13342,7 +13851,7 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
13342
13851
  int64_t t0 = ggml_perf_time_us();
13343
13852
  UNUSED(t0);
13344
13853
 
13345
- GGML_TENSOR_BINARY_OP_LOCALS;
13854
+ GGML_TENSOR_BINARY_OP_LOCALS
13346
13855
 
13347
13856
  const int ith = params->ith;
13348
13857
  const int nth = params->nth;
@@ -13433,7 +13942,7 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
13433
13942
  int64_t t0 = ggml_perf_time_us();
13434
13943
  UNUSED(t0);
13435
13944
 
13436
- GGML_TENSOR_BINARY_OP_LOCALS;
13945
+ GGML_TENSOR_BINARY_OP_LOCALS
13437
13946
 
13438
13947
  const int ith = params->ith;
13439
13948
  const int nth = params->nth;
@@ -13551,7 +14060,7 @@ static void ggml_compute_forward_conv_1d(
13551
14060
  ggml_compute_forward_conv_1d_s2_ph(params, src0, src1, dst);
13552
14061
  } else {
13553
14062
  GGML_ASSERT(false); // only stride 1 and 2 supported
13554
- };
14063
+ }
13555
14064
  }
13556
14065
 
13557
14066
  // ggml_compute_forward_conv_2d
@@ -13568,7 +14077,7 @@ static void ggml_compute_forward_conv_2d_f16_f32(
13568
14077
  int64_t t0 = ggml_perf_time_us();
13569
14078
  UNUSED(t0);
13570
14079
 
13571
- GGML_TENSOR_BINARY_OP_LOCALS;
14080
+ GGML_TENSOR_BINARY_OP_LOCALS
13572
14081
 
13573
14082
  const int ith = params->ith;
13574
14083
  const int nth = params->nth;
@@ -13688,7 +14197,7 @@ static void ggml_compute_forward_conv_transpose_2d(
13688
14197
  int64_t t0 = ggml_perf_time_us();
13689
14198
  UNUSED(t0);
13690
14199
 
13691
- GGML_TENSOR_BINARY_OP_LOCALS;
14200
+ GGML_TENSOR_BINARY_OP_LOCALS
13692
14201
 
13693
14202
  const int ith = params->ith;
13694
14203
  const int nth = params->nth;
@@ -13947,7 +14456,7 @@ static void ggml_compute_forward_upscale_f32(
13947
14456
 
13948
14457
  const int ith = params->ith;
13949
14458
 
13950
- GGML_TENSOR_UNARY_OP_LOCALS;
14459
+ GGML_TENSOR_UNARY_OP_LOCALS
13951
14460
 
13952
14461
  const int scale_factor = dst->op_params[0];
13953
14462
 
@@ -13999,14 +14508,14 @@ static void ggml_compute_forward_flash_attn_f32(
13999
14508
  int64_t t0 = ggml_perf_time_us();
14000
14509
  UNUSED(t0);
14001
14510
 
14002
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne);
14003
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb);
14004
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne);
14005
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb);
14006
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne);
14007
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb);
14008
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
14009
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
14511
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
14512
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
14513
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
14514
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
14515
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
14516
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
14517
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
14518
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
14010
14519
 
14011
14520
  const int ith = params->ith;
14012
14521
  const int nth = params->nth;
@@ -14076,10 +14585,11 @@ static void ggml_compute_forward_flash_attn_f32(
14076
14585
  S[i] = -INFINITY;
14077
14586
  }
14078
14587
 
14079
- for (int64_t ic = 0; ic < nek1; ++ic) {
14588
+ const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
14589
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
14080
14590
  // k indices
14081
14591
  const int ik3 = iq3;
14082
- const int ik2 = iq2;
14592
+ const int ik2 = iq2 % nek2;
14083
14593
  const int ik1 = ic;
14084
14594
 
14085
14595
  // S indices
@@ -14092,20 +14602,18 @@ static void ggml_compute_forward_flash_attn_f32(
14092
14602
  }
14093
14603
 
14094
14604
  // scale
14095
- ggml_vec_scale_f32(nek1, S, scale);
14605
+ ggml_vec_scale_f32(masked_begin, S, scale);
14096
14606
 
14097
- if (masked) {
14098
- for (int64_t i = P; i < M; i++) {
14099
- if (i > P + iq1) {
14100
- S[i] = -INFINITY;
14101
- }
14102
- }
14607
+ for (int64_t i = masked_begin; i < M; i++) {
14608
+ S[i] = -INFINITY;
14103
14609
  }
14104
14610
 
14105
14611
  // softmax
14612
+ // exclude known -INF S[..] values from max and loop
14613
+ // don't forget to set their SW values to zero

14106
14614
  {
14107
14615
  float max = -INFINITY;
14108
- ggml_vec_max_f32(M, &max, S);
14616
+ ggml_vec_max_f32(masked_begin, &max, S);
14109
14617
 
14110
14618
  ggml_float sum = 0.0;
14111
14619
  {
@@ -14119,10 +14627,15 @@ static void ggml_compute_forward_flash_attn_f32(
14119
14627
  ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
14120
14628
 
14121
14629
  for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
14630
+ if (i >= masked_begin) {
14631
+ break;
14632
+ }
14122
14633
  float * SS = S + i;
14123
14634
 
14124
14635
  for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
14125
- if (SS[j] == -INFINITY) {
14636
+ if (i + j >= masked_begin) {
14637
+ break;
14638
+ } else if (SS[j] == -INFINITY) {
14126
14639
  SS[j] = 0.0f;
14127
14640
  } else {
14128
14641
  #ifndef GGML_FLASH_ATTN_EXP_FP16
@@ -14147,10 +14660,10 @@ static void ggml_compute_forward_flash_attn_f32(
14147
14660
  assert(sum > 0.0);
14148
14661
 
14149
14662
  sum = 1.0/sum;
14150
- ggml_vec_scale_f32(M, S, sum);
14663
+ ggml_vec_scale_f32(masked_begin, S, sum);
14151
14664
 
14152
14665
  #ifndef NDEBUG
14153
- for (int i = 0; i < M; ++i) {
14666
+ for (int i = 0; i < masked_begin; ++i) {
14154
14667
  assert(!isnan(S[i]));
14155
14668
  assert(!isinf(S[i]));
14156
14669
  }
@@ -14163,9 +14676,13 @@ static void ggml_compute_forward_flash_attn_f32(
14163
14676
  const int i2 = iq2;
14164
14677
  const int i3 = iq3;
14165
14678
 
14166
- ggml_vec_dot_f32(nek1,
14167
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14168
- (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
14679
+ // v indices
14680
+ const int iv2 = iq2 % nev2;
14681
+ const int iv3 = iq3;
14682
+
14683
+ ggml_vec_dot_f32(masked_begin,
14684
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14685
+ (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
14169
14686
  S);
14170
14687
  }
14171
14688
  }
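
The flash-attention hunks above replace per-element masking with a precomputed masked_begin = P + iq1 + 1: scaling, the running max, the softmax and the final dot products only visit the first masked_begin scores, and everything beyond is known to be zero. A standalone sketch of that pattern on one score row (plain loops rather than the ggml vector helpers):

    #include <math.h>

    /* Illustrative: softmax over S[0..n), where entries at or beyond
     * masked_begin are masked out by the causal mask. */
    static void causal_softmax(float * S, int n, int masked_begin) {
        float max = -INFINITY;
        for (int i = 0; i < masked_begin; ++i) {
            if (S[i] > max) max = S[i];
        }
        float sum = 0.0f;
        for (int i = 0; i < masked_begin; ++i) {
            S[i] = expf(S[i] - max);
            sum += S[i];
        }
        for (int i = 0; i < masked_begin; ++i) {
            S[i] /= sum;
        }
        for (int i = masked_begin; i < n; ++i) {
            S[i] = 0.0f;   /* masked positions contribute nothing downstream */
        }
    }
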
@@ -14181,14 +14698,14 @@ static void ggml_compute_forward_flash_attn_f16(
14181
14698
  int64_t t0 = ggml_perf_time_us();
14182
14699
  UNUSED(t0);
14183
14700
 
14184
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne);
14185
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb);
14186
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne);
14187
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb);
14188
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne);
14189
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb);
14190
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
14191
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
14701
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
14702
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
14703
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
14704
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
14705
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
14706
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
14707
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
14708
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
14192
14709
 
14193
14710
  const int ith = params->ith;
14194
14711
  const int nth = params->nth;
@@ -14262,7 +14779,7 @@ static void ggml_compute_forward_flash_attn_f16(
14262
14779
  for (int64_t ic = 0; ic < nek1; ++ic) {
14263
14780
  // k indices
14264
14781
  const int ik3 = iq3;
14265
- const int ik2 = iq2;
14782
+ const int ik2 = iq2 % nek2;
14266
14783
  const int ik1 = ic;
14267
14784
 
14268
14785
  // S indices
@@ -14277,7 +14794,7 @@ static void ggml_compute_forward_flash_attn_f16(
14277
14794
  for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
14278
14795
  // k indices
14279
14796
  const int ik3 = iq3;
14280
- const int ik2 = iq2;
14797
+ const int ik2 = iq2 % nek2;
14281
14798
  const int ik1 = ic;
14282
14799
 
14283
14800
  // S indices
@@ -14302,6 +14819,8 @@ static void ggml_compute_forward_flash_attn_f16(
14302
14819
  }
14303
14820
 
14304
14821
  // softmax
14822
+ // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
14823
+ // don't forget to set their S values to zero
14305
14824
  {
14306
14825
  float max = -INFINITY;
14307
14826
  ggml_vec_max_f32(M, &max, S);
@@ -14358,6 +14877,7 @@ static void ggml_compute_forward_flash_attn_f16(
14358
14877
  S16[i] = GGML_FP32_TO_FP16(S[i]);
14359
14878
  }
14360
14879
 
14880
+ // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
14361
14881
  if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
14362
14882
  for (int64_t ic = 0; ic < nev1; ++ic) {
14363
14883
  // dst indices
@@ -14365,9 +14885,13 @@ static void ggml_compute_forward_flash_attn_f16(
14365
14885
  const int i2 = iq2;
14366
14886
  const int i3 = iq3;
14367
14887
 
14368
- ggml_vec_dot_f16(nek1,
14369
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14370
- (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
14888
+ // v indices
14889
+ const int iv2 = iq2 % nev2;
14890
+ const int iv3 = iq3;
14891
+
14892
+ ggml_vec_dot_f16(nev0,
14893
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14894
+ (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
14371
14895
  S16);
14372
14896
  }
14373
14897
  } else {
@@ -14377,9 +14901,13 @@ static void ggml_compute_forward_flash_attn_f16(
14377
14901
  const int i2 = iq2;
14378
14902
  const int i3 = iq3;
14379
14903
 
14380
- ggml_vec_dot_f16_unroll(nek1, nbv1,
14381
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14382
- ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
14904
+ // v indices
14905
+ const int iv2 = iq2 % nev2;
14906
+ const int iv3 = iq3;
14907
+
14908
+ ggml_vec_dot_f16_unroll(nev0, nbv1,
14909
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14910
+ ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
14383
14911
  S16);
14384
14912
  }
14385
14913
  }
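
The iq2 % nek2 and iq2 % nev2 indexing introduced in this f16 path (and in the f32 path above) lets several query heads share a single K/V head: K and V may carry fewer heads than Q and are reused cyclically. A one-line illustration of the mapping, assuming the query-head count is a multiple of the K/V-head count:

    /* Illustrative: map query head iq2 to the K/V head it reads from. */
    static int kv_head_for(int iq2, int n_kv_heads) {
        /* e.g. query heads 0, 8, 16, 24 all map to K/V head 0 when n_kv_heads == 8 */
        return iq2 % n_kv_heads;
    }
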
@@ -14422,18 +14950,18 @@ static void ggml_compute_forward_flash_ff_f16(
14422
14950
  int64_t t0 = ggml_perf_time_us();
14423
14951
  UNUSED(t0);
14424
14952
 
14425
- GGML_TENSOR_LOCALS(int64_t, nea, a, ne);
14426
- GGML_TENSOR_LOCALS(size_t, nba, a, nb);
14427
- GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne);
14428
- GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb);
14429
- GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne);
14430
- GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb);
14431
- GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne);
14432
- GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb);
14433
- GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne);
14434
- GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb);
14435
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
14436
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
14953
+ GGML_TENSOR_LOCALS(int64_t, nea, a, ne)
14954
+ GGML_TENSOR_LOCALS(size_t, nba, a, nb)
14955
+ GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne)
14956
+ GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb)
14957
+ GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne)
14958
+ GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb)
14959
+ GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne)
14960
+ GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb)
14961
+ GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne)
14962
+ GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb)
14963
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
14964
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
14437
14965
 
14438
14966
  const int ith = params->ith;
14439
14967
  const int nth = params->nth;
@@ -14581,16 +15109,16 @@ static void ggml_compute_forward_flash_attn_back_f32(
14581
15109
  int64_t t0 = ggml_perf_time_us();
14582
15110
  UNUSED(t0);
14583
15111
 
14584
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne);
14585
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb);
14586
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne);
14587
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb);
14588
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne);
14589
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb);
14590
- GGML_TENSOR_LOCALS(int64_t, ned, d, ne);
14591
- GGML_TENSOR_LOCALS(size_t, nbd, d, nb);
14592
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
14593
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
15112
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15113
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15114
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15115
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15116
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15117
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15118
+ GGML_TENSOR_LOCALS(int64_t, ned, d, ne)
15119
+ GGML_TENSOR_LOCALS(size_t, nbd, d, nb)
15120
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15121
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
14594
15122
 
14595
15123
  const int ith = params->ith;
14596
15124
  const int nth = params->nth;
@@ -14638,10 +15166,37 @@ static void ggml_compute_forward_flash_attn_back_f32(
14638
15166
  return;
14639
15167
  }
14640
15168
 
14641
- // parallelize by q rows using ggml_vec_dot_f32
15169
+ const int64_t elem_q = ggml_nelements(q);
15170
+ const int64_t elem_k = ggml_nelements(k);
14642
15171
 
14643
- // total rows in q
14644
- const int nr = neq2*neq3;
15172
+ enum ggml_type result_type = dst->type;
15173
+ GGML_ASSERT(ggml_blck_size(result_type) == 1);
15174
+ const size_t tsize = ggml_type_size(result_type);
15175
+
15176
+ const size_t offs_q = 0;
15177
+ const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
15178
+ const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
15179
+
15180
+ void * grad_q = (char *) dst->data;
15181
+ void * grad_k = (char *) dst->data + offs_k;
15182
+ void * grad_v = (char *) dst->data + offs_v;
15183
+
15184
+ const size_t nbgq1 = nb0*neq0;
15185
+ const size_t nbgq2 = nb0*neq0*neq1;
15186
+ const size_t nbgq3 = nb0*neq0*neq1*neq2;
15187
+
15188
+ const size_t nbgk1 = nb0*nek0;
15189
+ const size_t nbgk2 = nb0*nek0*nek1;
15190
+ const size_t nbgk3 = nb0*nek0*nek1*neq2;
15191
+
15192
+ const size_t nbgv1 = nb0*nev0;
15193
+ const size_t nbgv2 = nb0*nev0*nev1;
15194
+ const size_t nbgv3 = nb0*nev0*nev1*neq2;
15195
+
15196
+ // parallelize by k rows using ggml_vec_dot_f32
15197
+
15198
+ // total rows in k
15199
+ const int nr = nek2*nek3;
14645
15200
 
14646
15201
  // rows per thread
14647
15202
  const int dr = (nr + nth - 1)/nth;
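
In the rewritten backward pass above, the q, k and v gradients are packed into a single dst buffer, and their byte offsets are now derived from the element counts rounded up to GGML_MEM_ALIGN rather than from D*N and D*M products. A hedged sketch of that layout computation; ALIGN_UP re-implements the rounding locally so the snippet stands alone:

    #include <stddef.h>

    #define ALIGN_UP(x, a) (((x) + (a) - 1) / (a) * (a))   /* same idea as GGML_PAD */

    /* Illustrative: offsets of grad_q / grad_k / grad_v inside one packed buffer. */
    static void packed_grad_offsets(size_t elem_q, size_t elem_k, size_t type_size, size_t mem_align,
                                    size_t * offs_q, size_t * offs_k, size_t * offs_v) {
        *offs_q = 0;
        *offs_k = *offs_q + ALIGN_UP(elem_q * type_size, mem_align);
        *offs_v = *offs_k + ALIGN_UP(elem_k * type_size, mem_align);
    }
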
@@ -14654,268 +15209,243 @@ static void ggml_compute_forward_flash_attn_back_f32(
14654
15209
 
14655
15210
  //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
14656
15211
 
15212
+ // how often k2 (and v2) is repeated in q2
15213
+ int nrep = neq2/nek2;
15214
+
14657
15215
  for (int ir = ir0; ir < ir1; ++ir) {
14658
15216
  // q indices
14659
- const int iq3 = ir/(neq2);
14660
- const int iq2 = ir - iq3*neq2;
14661
- for ( int iq1 = 0; iq1 < neq1; ++iq1) {
15217
+ const int ik3 = ir/(nek2);
15218
+ const int ik2 = ir - ik3*nek2;
14662
15219
 
15220
+ const int iq3 = ik3;
15221
+ const int id3 = ik3;
15222
+ const int iv3 = ik3;
15223
+ const int iv2 = ik2;
14663
15224
 
14664
- // not sure about CACHE_LINE_SIZE_F32..
14665
- // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
14666
- float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
14667
- float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
15225
+ for (int irep = 0; irep < nrep; ++irep) {
15226
+ const int iq2 = ik2 + irep*nek2;
15227
+ const int id2 = iq2;
14668
15228
 
14669
- for (int i = M; i < Mup; ++i) {
14670
- S[i] = -INFINITY;
14671
- }
15229
+ // (ik2 + irep*nek2) % nek2 == ik2
15230
+ for (int iq1 = 0; iq1 < neq1; ++iq1) {
15231
+ const int id1 = iq1;
14672
15232
 
14673
- for (int64_t ic = 0; ic < nek1; ++ic) {
14674
- // k indices
14675
- const int ik3 = iq3;
14676
- const int ik2 = iq2;
14677
- const int ik1 = ic;
15233
+ // not sure about CACHE_LINE_SIZE_F32..
15234
+ // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
15235
+ float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
15236
+ float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
14678
15237
 
14679
- // S indices
14680
- const int i1 = ik1;
15238
+ for (int i = M; i < Mup; ++i) {
15239
+ S[i] = -INFINITY;
15240
+ }
14681
15241
 
14682
- ggml_vec_dot_f32(neq0,
14683
- S + i1,
14684
- (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
14685
- (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
14686
- }
15242
+ const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
15243
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
15244
+ // k indices
15245
+ const int ik1 = ic;
14687
15246
 
14688
- // scale
14689
- ggml_vec_scale_f32(nek1, S, scale);
15247
+ // S indices
15248
+ const int i1 = ik1;
14690
15249
 
14691
- if (masked) {
14692
- for (int64_t i = P; i < M; i++) {
14693
- if (i > P + iq1) {
14694
- S[i] = -INFINITY;
14695
- }
15250
+ ggml_vec_dot_f32(neq0,
15251
+ S + i1,
15252
+ (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
15253
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
14696
15254
  }
14697
- }
14698
15255
 
14699
- // softmax
14700
- {
14701
- float max = -INFINITY;
14702
- ggml_vec_max_f32(M, &max, S);
15256
+ // scale
15257
+ ggml_vec_scale_f32(masked_begin, S, scale);
14703
15258
 
14704
- ggml_float sum = 0.0;
15259
+ for (int64_t i = masked_begin; i < M; i++) {
15260
+ S[i] = -INFINITY;
15261
+ }
15262
+
15263
+ // softmax
15264
+ // exclude known -INF S[..] values from max and loop
15265
+ // don't forget to set their SM values to zero
14705
15266
  {
15267
+ float max = -INFINITY;
15268
+ ggml_vec_max_f32(masked_begin, &max, S);
15269
+
15270
+ ggml_float sum = 0.0;
15271
+ {
14706
15272
  #ifdef GGML_SOFT_MAX_ACCELERATE
14707
- max = -max;
14708
- vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
14709
- vvexpf(SM, SM, &Mup);
14710
- ggml_vec_sum_f32(Mup, &sum, SM);
15273
+ max = -max;
15274
+ vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
15275
+ vvexpf(SM, SM, &Mup);
15276
+ ggml_vec_sum_f32(Mup, &sum, SM);
14711
15277
  #else
14712
- uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
14713
- ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
15278
+ uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
15279
+ ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
14714
15280
 
14715
- for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
14716
- float * SR = S + i;
14717
- float * SW = SM + i;
14718
-
14719
- for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
14720
- if (SR[j] == -INFINITY) {
14721
- SW[j] = 0.0f;
14722
- } else {
15281
+ for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
15282
+ if (i >= masked_begin) {
15283
+ break;
15284
+ }
15285
+ float * SR = S + i;
15286
+ float * SW = SM + i;
15287
+
15288
+ for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
15289
+ if (i + j >= masked_begin) {
15290
+ break;
15291
+ } else if (SR[j] == -INFINITY) {
15292
+ SW[j] = 0.0f;
15293
+ } else {
14723
15294
  #ifndef GGML_FLASH_ATTN_EXP_FP16
14724
- const float val = expf(SR[j] - max);
15295
+ const float val = expf(SR[j] - max);
14725
15296
  #else
14726
- ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
14727
- memcpy(&scvt[j], &s, sizeof(uint16_t));
14728
- const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
15297
+ ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
15298
+ memcpy(&scvt[j], &s, sizeof(uint16_t));
15299
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
14729
15300
  #endif
14730
- sump[j] += (ggml_float)val;
14731
- SW[j] = val;
15301
+ sump[j] += (ggml_float)val;
15302
+ SW[j] = val;
15303
+ }
14732
15304
  }
14733
15305
  }
14734
- }
14735
15306
 
14736
- for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
14737
- sum += sump[i];
14738
- }
15307
+ for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
15308
+ sum += sump[i];
15309
+ }
14739
15310
  #endif
14740
- }
14741
-
14742
- assert(sum > 0.0);
14743
-
14744
- sum = 1.0/sum;
14745
- ggml_vec_scale_f32(M, SM, sum);
14746
-
14747
- }
14748
-
14749
- // step-by-step explanation
14750
- {
14751
- // forward-process shape grads from backward process
14752
- // parallel_for iq2,iq3:
14753
- // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur]
14754
- // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
14755
- // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur]
14756
- // for iq1:
14757
- // kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
14758
- // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
14759
- // vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
14760
- // S0 = -Inf [D,1,1,1]
14761
- // ~S1[i] = dot(kcur[:D,i], qcur)
14762
- // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
14763
- // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
14764
- // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
14765
- // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
14766
- // ~S5[i] = dot(vcur[:,i], S4)
14767
- // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
14768
- // ~dst[i,iq1,iq2,iq3] = S5[i] ^
14769
- // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
14770
- // dst backward-/ grad[dst] = d
14771
- //
14772
- // output gradients with their dependencies:
14773
- //
14774
- // grad[kcur] = grad[S1].T @ qcur
14775
- // grad[S1] = diag_mask_zero(grad[S3], P) * scale
14776
- // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
14777
- // grad[S4] = grad[S5] @ vcur
14778
- // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
14779
- // grad[qcur] = grad[S1] @ kcur
14780
- // grad[vcur] = grad[S5].T @ S4
14781
- // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
14782
- //
14783
- // in post-order:
14784
- //
14785
- // S1 = qcur @ kcur.T
14786
- // S2 = S1 * scale
14787
- // S3 = diag_mask_inf(S2, P)
14788
- // S4 = softmax(S3)
14789
- // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
14790
- // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
14791
- // grad[S1] = diag_mask_zero(grad[S3], P) * scale
14792
- // grad[qcur] = grad[S1] @ kcur
14793
- // grad[kcur] = grad[S1].T @ qcur
14794
- // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
14795
- //
14796
- // using less variables (SM=S4):
14797
- //
14798
- // S = diag_mask_inf(qcur @ kcur.T * scale, P)
14799
- // SM = softmax(S)
14800
- // S = d[:D,iq1,iq2,iq3] @ vcur
14801
- // dot_SM_gradSM = dot(SM, S)
14802
- // S = SM * (S - dot(SM, S))
14803
- // S = diag_mask_zero(S, P) * scale
14804
- //
14805
- // grad[q][:D,iq1,iq2,iq3] += S @ kcur
14806
- // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
14807
- // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
14808
- }
14809
-
14810
- // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
14811
- // S = d[:D,iq1,iq2,iq3] @ vcur
14812
- // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
14813
- ggml_vec_set_f32(M, S, 0);
14814
- for (int64_t ic = 0; ic < D; ++ic) {
14815
- // dst indices
14816
- const int i1 = iq1;
14817
- const int i2 = iq2;
14818
- const int i3 = iq3;
15311
+ }
14819
15312
 
14820
- ggml_vec_mad_f32(M,
14821
- S,
14822
- (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
14823
- *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
14824
- }
15313
+ assert(sum > 0.0);
14825
15314
 
14826
- // S = SM * (S - dot(SM, S))
14827
- float dot_SM_gradSM = 0;
14828
- ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
14829
- ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
14830
- ggml_vec_mul_f32 (M, S, S, SM);
15315
+ sum = 1.0/sum;
15316
+ ggml_vec_scale_f32(masked_begin, SM, sum);
14831
15317
 
14832
- // S = diag_mask_zero(S, P) * scale
14833
- if (masked) {
14834
- // for (int64_t i = P + iq1 + 1; i < M; i++) {
14835
- // S[i] = 0;
14836
- // }
14837
- for (int64_t i = P; i < M; i++) {
14838
- if (i > P + iq1) {
14839
- S[i] = 0;
14840
- }
14841
15318
  }
14842
- }
14843
- ggml_vec_scale_f32(M, S, scale);
14844
-
14845
- void * grad_q = (char *) dst->data;
14846
- void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
14847
- void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
14848
-
14849
- const size_t nbgq1 = nb0*neq0;
14850
- const size_t nbgq2 = nb0*neq0*neq1;
14851
- const size_t nbgq3 = nb0*neq0*neq1*neq2;
14852
-
14853
- const size_t nbgk1 = nb0*nek0;
14854
- const size_t nbgk2 = nb0*nek0*nek1;
14855
- const size_t nbgk3 = nb0*nek0*nek1*neq2;
14856
-
14857
- const size_t nbgv1 = nb0*nev0;
14858
- const size_t nbgv2 = nb0*nev0*nev1;
14859
- const size_t nbgv3 = nb0*nev0*nev1*neq2;
14860
-
14861
- // S shape [M,1]
14862
- // SM shape [M,1]
14863
- // kcur shape [D,M]
14864
- // qcur shape [D,1]
14865
- // vcur shape [M,D]
14866
- //
14867
- // grad[q][:D,iq1,iq2,iq3] += S @ kcur
14868
- // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
14869
- // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
14870
- //
14871
- //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
14872
- //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
14873
- for (int64_t ic = 0; ic < M; ++ic) {
14874
- // dst indices
14875
- const int i1 = iq1;
14876
- const int i2 = iq2;
14877
- const int i3 = iq3;
14878
15319
 
14879
- ggml_vec_mad_f32(D,
14880
- (float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
14881
- (float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)),
14882
- S[ic]);
14883
- }
15320
+ // step-by-step explanation
15321
+ {
15322
+ // forward-process shape grads from backward process
15323
+ // parallel_for ik2,ik3:
15324
+ // for irep:
15325
+ // iq2 = ik2 + irep*nek2
15326
+ // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur]
15327
+ // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
15328
+ // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur]
15329
+ // for iq1:
15330
+ // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
15331
+ // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
15332
+ // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
15333
+ // S0 = -Inf [D,1,1,1]
15334
+ // ~S1[i] = dot(kcur[:D,i], qcur)
15335
+ // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
15336
+ // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
15337
+ // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
15338
+ // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
15339
+ // ~S5[i] = dot(vcur[:,i], S4)
15340
+ // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3]
15341
+ // ~dst[i,iq1,iq2,iq3] = S5[i] ^
15342
+ // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
15343
+ // dst backward-/ grad[dst] = d
15344
+ //
15345
+ // output gradients with their dependencies:
15346
+ //
15347
+ // grad[kcur] = grad[S1].T @ qcur
15348
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
15349
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
15350
+ // grad[S4] = grad[S5] @ vcur
15351
+ // grad[S4] = d[:D,id1,id2,id3] @ vcur
15352
+ // grad[qcur] = grad[S1] @ kcur
15353
+ // grad[vcur] = grad[S5].T @ S4
15354
+ // grad[vcur] = d[:D,id1,id2,id3].T @ S4
15355
+ //
15356
+ // in post-order:
15357
+ //
15358
+ // S1 = qcur @ kcur.T
15359
+ // S2 = S1 * scale
15360
+ // S3 = diag_mask_inf(S2, P)
15361
+ // S4 = softmax(S3)
15362
+ // grad[S4] = d[:D,id1,id2,id3] @ vcur
15363
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
15364
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
15365
+ // grad[qcur] = grad[S1] @ kcur
15366
+ // grad[kcur] = grad[S1].T @ qcur
15367
+ // grad[vcur] = d[:D,id1,id2,id3].T @ S4
15368
+ //
15369
+ // using less variables (SM=S4):
15370
+ //
15371
+ // S = diag_mask_inf(qcur @ kcur.T * scale, P)
15372
+ // SM = softmax(S)
15373
+ // S = d[:D,iq1,iq2,iq3] @ vcur
15374
+ // dot_SM_gradSM = dot(SM, S)
15375
+ // S = SM * (S - dot(SM, S))
15376
+ // S = diag_mask_zero(S, P) * scale
15377
+ //
15378
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
15379
+ // grad[k][:D,:M,ik2,ik3] += S.T @ qcur
15380
+ // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
15381
+ }
14884
15382
 
14885
- // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
14886
- // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
14887
- // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
14888
- for (int64_t ic = 0; ic < M; ++ic) {
14889
- // dst indices
14890
- const int i1 = iq1;
14891
- const int i2 = iq2;
14892
- const int i3 = iq3;
15383
+ // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
15384
+ // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
15385
+ // for ic:
15386
+ // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
15387
+ // exclude known future zero S[..] values from operation
15388
+ ggml_vec_set_f32(masked_begin, S, 0);
15389
+ for (int64_t ic = 0; ic < D; ++ic) {
15390
+ ggml_vec_mad_f32(masked_begin,
15391
+ S,
15392
+ (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
15393
+ *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
15394
+ }
14893
15395
 
14894
- // ggml_vec_set_f32(D,
14895
- // (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
14896
- // 0);
14897
- ggml_vec_mad_f32(D,
14898
- (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
14899
- (float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)),
14900
- S[ic]);
14901
- }
15396
+ // S = SM * (S - dot(SM, S))
15397
+ float dot_SM_gradSM = 0;
15398
+ ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S);
15399
+ ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
15400
+ ggml_vec_mul_f32 (masked_begin, S, S, SM);
15401
+
15402
+ // S = diag_mask_zero(S, P) * scale
15403
+ // already done by above ggml_vec_set_f32
15404
+
15405
+ // exclude known zero S[..] values from operation
15406
+ ggml_vec_scale_f32(masked_begin, S, scale);
15407
+
15408
+ // S shape [M,1]
15409
+ // SM shape [M,1]
15410
+ // kcur shape [D,M]
15411
+ // qcur shape [D,1]
15412
+ // vcur shape [M,D]
15413
+
15414
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
15415
+ // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
15416
+ // for ic:
15417
+ // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
15418
+ // exclude known zero S[..] values from loop
15419
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
15420
+ ggml_vec_mad_f32(D,
15421
+ (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
15422
+ (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
15423
+ S[ic]);
15424
+ }
14902
15425
 
14903
- // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
14904
- // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
14905
- // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M]
14906
- for (int64_t ic = 0; ic < D; ++ic) {
14907
- // dst indices
14908
- const int i1 = iq1;
14909
- const int i2 = iq2;
14910
- const int i3 = iq3;
15426
+ // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
15427
+ // for ic:
15428
+ // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
15429
+ // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
15430
+ // exclude known zero S[..] values from loop
15431
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
15432
+ ggml_vec_mad_f32(D,
15433
+ (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
15434
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
15435
+ S[ic]);
15436
+ }
14911
15437
 
14912
- // ggml_vec_set_f32(M,
14913
- // (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
14914
- // 0);
14915
- ggml_vec_mad_f32(M,
14916
- (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
14917
- SM,
14918
- *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
15438
+ // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
15439
+ // for ic:
15440
+ // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
15441
+ // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
15442
+ // exclude known zero SM[..] values from mad
15443
+ for (int64_t ic = 0; ic < D; ++ic) {
15444
+ ggml_vec_mad_f32(masked_begin,
15445
+ (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
15446
+ SM,
15447
+ *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
15448
+ }
14919
15449
  }
14920
15450
  }
14921
15451
  }
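
The dot/acc1/mul sequence in the hunk above is the usual softmax backward identity: with SM = softmax(S) and incoming gradient dSM, the gradient with respect to the logits is SM * (dSM - dot(SM, dSM)). Written out with plain loops instead of the ggml vector helpers:

    /* Illustrative: softmax backward for one row of length n.
     * SM   - softmax output
     * dSM  - gradient w.r.t. SM, overwritten with the gradient w.r.t. the logits */
    static void softmax_backward(int n, const float * SM, float * dSM) {
        float dot = 0.0f;
        for (int i = 0; i < n; ++i) {
            dot += SM[i] * dSM[i];
        }
        for (int i = 0; i < n; ++i) {
            dSM[i] = SM[i] * (dSM[i] - dot);
        }
    }
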
@@ -14951,8 +15481,8 @@ static void ggml_compute_forward_win_part_f32(
14951
15481
  return;
14952
15482
  }
14953
15483
 
14954
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
14955
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
15484
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
15485
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
14956
15486
 
14957
15487
  const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
14958
15488
  const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
@@ -15013,8 +15543,8 @@ static void ggml_compute_forward_win_unpart_f32(
15013
15543
  return;
15014
15544
  }
15015
15545
 
15016
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
15017
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
15546
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
15547
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15018
15548
 
15019
15549
  const int32_t w = ((const int32_t *)(dst->op_params))[0];
15020
15550
 
@@ -15131,7 +15661,7 @@ static void ggml_compute_forward_get_rel_pos_f16(
15131
15661
 
15132
15662
  // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
15133
15663
 
15134
- GGML_TENSOR_UNARY_OP_LOCALS;
15664
+ GGML_TENSOR_UNARY_OP_LOCALS
15135
15665
 
15136
15666
  const int64_t w = ne1;
15137
15667
 
@@ -15829,7 +16359,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
15829
16359
  } break;
15830
16360
  case GGML_OP_GET_ROWS_BACK:
15831
16361
  {
15832
- ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
16362
+ ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor);
15833
16363
  } break;
15834
16364
  case GGML_OP_DIAG:
15835
16365
  {
@@ -15853,11 +16383,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
15853
16383
  } break;
15854
16384
  case GGML_OP_ROPE:
15855
16385
  {
15856
- ggml_compute_forward_rope(params, tensor->src[0], tensor);
16386
+ ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor);
15857
16387
  } break;
15858
16388
  case GGML_OP_ROPE_BACK:
15859
16389
  {
15860
- ggml_compute_forward_rope_back(params, tensor->src[0], tensor);
16390
+ ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor);
15861
16391
  } break;
15862
16392
  case GGML_OP_ALIBI:
15863
16393
  {
@@ -16002,7 +16532,218 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
16002
16532
 
16003
16533
  ////////////////////////////////////////////////////////////////////////////////
16004
16534
 
16005
- static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
16535
+ static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
16536
+
16537
+ static size_t hash(void * p) {
16538
+ return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
16539
+ }
16540
+
16541
+ static size_t hash_find(void * hash_table[], void * p) {
16542
+ size_t h = hash(p);
16543
+
16544
+ // linear probing
16545
+ size_t i = h;
16546
+ while (hash_table[i] != NULL && hash_table[i] != p) {
16547
+ i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
16548
+ if (i == h) {
16549
+ // visited all hash table entries -> not found
16550
+ return GGML_GRAPH_HASHTABLE_SIZE;
16551
+ }
16552
+ }
16553
+ return i;
16554
+ }
16555
+
16556
+ static bool hash_insert(void * hash_table[], void * p) {
16557
+ size_t i = hash_find(hash_table, p);
16558
+
16559
+ GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
16560
+
16561
+ if (hash_table[i] == p) {
16562
+ return true;
16563
+ }
16564
+
16565
+ // insert
16566
+ GGML_ASSERT(hash_table[i] == NULL);
16567
+ hash_table[i] = p;
16568
+ return false;
16569
+ }
16570
+
16571
+ static bool hash_contains(void * hash_table[], void * p) {
16572
+ size_t i = hash_find(hash_table, p);
16573
+ return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
16574
+ }
16575
+
16576
+ struct hash_map {
16577
+ void * keys[GGML_GRAPH_HASHTABLE_SIZE];
16578
+ void * vals[GGML_GRAPH_HASHTABLE_SIZE];
16579
+ };
16580
+
16581
+ static struct hash_map * new_hash_map(void) {
16582
+ struct hash_map * result = malloc(sizeof(struct hash_map));
16583
+ for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
16584
+ result->keys[i] = NULL;
16585
+ result->vals[i] = NULL;
16586
+ }
16587
+ return result;
16588
+ }
16589
+
16590
+ static void free_hash_map(struct hash_map * map) {
16591
+ free(map);
16592
+ }
16593
+
16594
+ // gradient checkpointing
16595
+
16596
+ static struct ggml_tensor * ggml_recompute_graph_node(
16597
+ struct ggml_context * ctx,
16598
+ struct ggml_cgraph * graph,
16599
+ struct hash_map * replacements,
16600
+ struct ggml_tensor * node) {
16601
+
16602
+ if (node == NULL) {
16603
+ return NULL;
16604
+ }
16605
+
16606
+ if (node->is_param) {
16607
+ return node;
16608
+ }
16609
+
16610
+ if (!hash_contains(graph->visited_hash_table, node)) {
16611
+ return node;
16612
+ }
16613
+
16614
+ int count_children = 0;
16615
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
16616
+ if (node->src[k]) {
16617
+ ++count_children;
16618
+ }
16619
+ }
16620
+
16621
+ if (count_children == 0) {
16622
+ return node;
16623
+ }
16624
+
16625
+ size_t i = hash_find(replacements->keys, node);
16626
+ GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
16627
+ if (replacements->keys[i] == node) {
16628
+ return (struct ggml_tensor *) replacements->vals[i];
16629
+ }
16630
+
16631
+ struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
16632
+
16633
+ // insert clone into replacements
16634
+ GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
16635
+ replacements->keys[i] = node;
16636
+ replacements->vals[i] = clone;
16637
+
16638
+ clone->op = node->op;
16639
+ clone->grad = node->grad;
16640
+ clone->is_param = node->is_param;
16641
+ clone->extra = node->extra;
16642
+ for (int k = 0; k < GGML_MAX_DIMS; ++k) {
16643
+ clone->nb[k] = node->nb[k];
16644
+ }
16645
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
16646
+ clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
16647
+ }
16648
+ if (node->view_src != NULL) {
16649
+ clone->data = (node->view_src->data == NULL)
16650
+ ? NULL // view_src not yet allocated
16651
+ : (char *) node->view_src->data // view_src already allocated
16652
+ + node->view_offs;
16653
+ clone->view_src = node->view_src;
16654
+ clone->view_offs = node->view_offs;
16655
+ }
16656
+
16657
+ GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
16658
+ GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
16659
+ memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
16660
+ ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
16661
+
16662
+ return clone;
16663
+ }
16664
+
16665
+ void ggml_build_backward_gradient_checkpointing(
16666
+ struct ggml_context * ctx,
16667
+ struct ggml_cgraph * gf,
16668
+ struct ggml_cgraph * gb,
16669
+ struct ggml_cgraph * gb_tmp,
16670
+ struct ggml_tensor * * checkpoints,
16671
+ int n_checkpoints) {
16672
+ *gb_tmp = *gf;
16673
+ ggml_build_backward_expand(ctx, gf, gb_tmp, true);
16674
+
16675
+ if (n_checkpoints <= 0) {
16676
+ *gb = *gb_tmp;
16677
+ return;
16678
+ }
16679
+
16680
+ struct hash_map * replacements = new_hash_map();
16681
+
16682
+ // insert checkpoints in replacements
16683
+ for (int i = 0; i < n_checkpoints; ++i) {
16684
+ size_t k = hash_find(replacements->keys, checkpoints[i]);
16685
+ GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
16686
+ GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
16687
+ replacements->keys[k] = checkpoints[i];
16688
+ replacements->vals[k] = checkpoints[i];
16689
+ }
16690
+
16691
+ *gb = *gf;
16692
+ // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
16693
+ // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
16694
+ // by recomputing them from checkpoints
16695
+ for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
16696
+ struct ggml_tensor * node = gb_tmp->nodes[i];
16697
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
16698
+ // insert new tensors recomputing src, reusing already made replacements,
16699
+ // remember replacements: remember new tensors with mapping from corresponding gf nodes
16700
+ // recurse for input tensors,
16701
+ // unless (i.e. terminating when) input tensors are replacements (like checkpoints)
16702
+ node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
16703
+ }
16704
+ // insert rewritten backward node with replacements made into resulting backward graph gb
16705
+ ggml_build_forward_expand(gb, node);
16706
+ }
16707
+
16708
+ free_hash_map(replacements);
16709
+ }
16710
+
16711
+ // functions to change gradients considering the case that input a might be initial gradient with zero value
16712
+
16713
+ static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
16714
+ if (hash_contains(zero_table, a)) {
16715
+ return b;
16716
+ } else {
16717
+ return ggml_add_impl(ctx, a, b, false);
16718
+ }
16719
+ }
16720
+
16721
+ static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, void * zero_table[]) {
16722
+ if (hash_contains(zero_table, a)) {
16723
+ struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0));
16724
+ return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
16725
+ } else {
16726
+ return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
16727
+ }
16728
+ }
16729
+
16730
+ static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
16731
+ if (hash_contains(zero_table, a)) {
16732
+ return ggml_repeat(ctx, b, a);
16733
+ } else {
16734
+ return ggml_add1_impl(ctx, a, b, false);
16735
+ }
16736
+ }
16737
+
16738
+ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
16739
+ if (hash_contains(zero_table, a)) {
16740
+ return ggml_neg(ctx, b);
16741
+ } else {
16742
+ return ggml_sub_impl(ctx, a, b, false);
16743
+ }
16744
+ }
16745
+
16746
+ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, void * zero_table[]) {
16006
16747
  struct ggml_tensor * src0 = tensor->src[0];
16007
16748
  struct ggml_tensor * src1 = tensor->src[1];
16008
16749
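
The visited-node bookkeeping added above is an open-addressing hash table keyed on tensor pointers with linear probing. A tiny self-contained version of the same probing scheme (the fixed TABLE_SIZE and the function names are illustrative; ggml sizes its table from GGML_GRAPH_HASHTABLE_SIZE):

    #include <stddef.h>

    #define TABLE_SIZE 1024   /* illustrative; must exceed the number of keys */

    /* Illustrative: find the slot for key p, or TABLE_SIZE if the table is full. */
    static size_t probe(void * table[TABLE_SIZE], void * p) {
        size_t h = (size_t) p % TABLE_SIZE;
        size_t i = h;
        while (table[i] != NULL && table[i] != p) {
            i = (i + 1) % TABLE_SIZE;   /* linear probing: try the next slot */
            if (i == h) {
                return TABLE_SIZE;      /* wrapped around, table is full */
            }
        }
        return i;                        /* empty slot or existing entry for p */
    }

    /* Returns 1 if p was already present, 0 if newly inserted, -1 if full. */
    static int insert(void * table[TABLE_SIZE], void * p) {
        size_t i = probe(table, p);
        if (i == TABLE_SIZE) return -1;
        if (table[i] == p)   return 1;
        table[i] = p;
        return 0;
    }
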
 
@@ -16010,34 +16751,34 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16010
16751
  case GGML_OP_DUP:
16011
16752
  {
16012
16753
  if (src0->grad) {
16013
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
16754
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16014
16755
  }
16015
16756
  } break;
16016
16757
  case GGML_OP_ADD:
16017
16758
  {
16018
16759
  if (src0->grad) {
16019
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
16760
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16020
16761
  }
16021
16762
  if (src1->grad) {
16022
- src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace);
16763
+ src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
16023
16764
  }
16024
16765
  } break;
16025
16766
  case GGML_OP_ADD1:
16026
16767
  {
16027
16768
  if (src0->grad) {
16028
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
16769
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16029
16770
  }
16030
16771
  if (src1->grad) {
16031
- src1->grad = ggml_add_impl(ctx,
16772
+ src1->grad = ggml_add_or_set(ctx,
16032
16773
  src1->grad,
16033
16774
  ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
16034
- inplace);
16775
+ zero_table);
16035
16776
  }
16036
16777
  } break;
16037
16778
  case GGML_OP_ACC:
16038
16779
  {
16039
16780
  if (src0->grad) {
16040
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
16781
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16041
16782
  }
16042
16783
  if (src1->grad) {
16043
16784
  const size_t nb1 = ((int32_t *) tensor->op_params)[0];
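
The ggml_add_or_set / ggml_sub_or_set substitutions in this hunk follow one rule: if the destination gradient is still the zero tensor recorded in zero_table at graph-build time, the new contribution simply replaces it, otherwise it is accumulated. A scalar analogue of that decision (illustrative only, not the tensor-level API):

    #include <stdbool.h>

    /* Illustrative scalar analogue of ggml_add_or_set: skip the add when the
     * accumulator is known to still be the initial zero value. */
    static float add_or_set(float grad_so_far, float contribution, bool grad_is_initial_zero) {
        return grad_is_initial_zero ? contribution : grad_so_far + contribution;
    }
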
@@ -16054,117 +16795,117 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16054
16795
  nb1, nb2, nb3, offset);
16055
16796
 
16056
16797
  src1->grad =
16057
- ggml_add_impl(ctx,
16798
+ ggml_add_or_set(ctx,
16058
16799
  src1->grad,
16059
16800
  ggml_reshape(ctx,
16060
16801
  ggml_cont(ctx, tensor_grad_view),
16061
16802
  src1->grad),
16062
- inplace);
16803
+ zero_table);
16063
16804
  }
16064
16805
  } break;
16065
16806
  case GGML_OP_SUB:
16066
16807
  {
16067
16808
  if (src0->grad) {
16068
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
16809
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16069
16810
  }
16070
16811
  if (src1->grad) {
16071
- src1->grad = ggml_sub_impl(ctx, src1->grad, tensor->grad, inplace);
16812
+ src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table);
16072
16813
  }
16073
16814
  } break;
16074
16815
  case GGML_OP_MUL:
16075
16816
  {
16076
16817
  if (src0->grad) {
16077
16818
  src0->grad =
16078
- ggml_add_impl(ctx,
16819
+ ggml_add_or_set(ctx,
16079
16820
  src0->grad,
16080
16821
  ggml_mul(ctx, src1, tensor->grad),
16081
- inplace);
16822
+ zero_table);
16082
16823
  }
16083
16824
  if (src1->grad) {
16084
16825
  src1->grad =
16085
- ggml_add_impl(ctx,
16826
+ ggml_add_or_set(ctx,
16086
16827
  src1->grad,
16087
16828
  ggml_mul(ctx, src0, tensor->grad),
16088
- inplace);
16829
+ zero_table);
16089
16830
  }
16090
16831
  } break;
16091
16832
  case GGML_OP_DIV:
16092
16833
  {
16093
16834
  if (src0->grad) {
16094
16835
  src0->grad =
16095
- ggml_add_impl(ctx,
16836
+ ggml_add_or_set(ctx,
16096
16837
  src0->grad,
16097
16838
  ggml_div(ctx, tensor->grad, src1),
16098
- inplace);
16839
+ zero_table);
16099
16840
  }
16100
16841
  if (src1->grad) {
16101
16842
  src1->grad =
16102
- ggml_sub_impl(ctx,
16843
+ ggml_sub_or_set(ctx,
16103
16844
  src1->grad,
16104
16845
  ggml_mul(ctx,
16105
16846
  tensor->grad,
16106
16847
  ggml_div(ctx, tensor, src1)),
16107
- inplace);
16848
+ zero_table);
16108
16849
  }
16109
16850
  } break;
16110
16851
  case GGML_OP_SQR:
16111
16852
  {
16112
16853
  if (src0->grad) {
16113
16854
  src0->grad =
16114
- ggml_add_impl(ctx,
16855
+ ggml_add_or_set(ctx,
16115
16856
  src0->grad,
16116
16857
  ggml_scale(ctx,
16117
16858
  ggml_mul(ctx, src0, tensor->grad),
16118
16859
  ggml_new_f32(ctx, 2.0f)),
16119
- inplace);
16860
+ zero_table);
16120
16861
  }
16121
16862
  } break;
16122
16863
  case GGML_OP_SQRT:
16123
16864
  {
16124
16865
  if (src0->grad) {
16125
16866
  src0->grad =
16126
- ggml_add_impl(ctx,
16867
+ ggml_add_or_set(ctx,
16127
16868
  src0->grad,
16128
16869
  ggml_scale(ctx,
16129
16870
  ggml_div(ctx,
16130
16871
  tensor->grad,
16131
16872
  tensor),
16132
16873
  ggml_new_f32(ctx, 0.5f)),
16133
- inplace);
16874
+ zero_table);
16134
16875
  }
16135
16876
  } break;
16136
16877
  case GGML_OP_LOG:
16137
16878
  {
16138
16879
  if (src0->grad) {
16139
16880
  src0->grad =
16140
- ggml_add_impl(ctx,
16881
+ ggml_add_or_set(ctx,
16141
16882
  src0->grad,
16142
16883
  ggml_div(ctx,
16143
16884
  tensor->grad,
16144
16885
  src0),
16145
- inplace);
16886
+ zero_table);
16146
16887
  }
16147
16888
  } break;
16148
16889
  case GGML_OP_SUM:
16149
16890
  {
16150
16891
  if (src0->grad) {
16151
16892
  src0->grad =
16152
- ggml_add1_impl(ctx,
16893
+ ggml_add1_or_set(ctx,
16153
16894
  src0->grad,
16154
16895
  tensor->grad,
16155
- inplace);
16896
+ zero_table);
16156
16897
  }
16157
16898
  } break;
16158
16899
  case GGML_OP_SUM_ROWS:
16159
16900
  {
16160
16901
  if (src0->grad) {
16161
16902
  src0->grad =
16162
- ggml_add_impl(ctx,
16903
+ ggml_add_or_set(ctx,
16163
16904
  src0->grad,
16164
16905
  ggml_repeat(ctx,
16165
16906
  tensor->grad,
16166
16907
  src0->grad),
16167
- inplace);
16908
+ zero_table);
16168
16909
  }
16169
16910
  } break;
16170
16911
  case GGML_OP_MEAN:
@@ -16176,20 +16917,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16176
16917
  {
16177
16918
  // necessary for llama
16178
16919
  if (src0->grad) {
16179
- src0->grad = ggml_add_impl(ctx,
16920
+ src0->grad = ggml_add_or_set(ctx,
16180
16921
  src0->grad,
16181
16922
  ggml_repeat_back(ctx, tensor->grad, src0->grad),
16182
- inplace);
16923
+ zero_table);
16183
16924
  }
16184
16925
  } break;
16185
16926
  case GGML_OP_REPEAT_BACK:
16186
16927
  {
16187
16928
  if (src0->grad) {
16188
16929
  // TODO: test this
16189
- src0->grad = ggml_add_impl(ctx,
16930
+ src0->grad = ggml_add_or_set(ctx,
16190
16931
  src0->grad,
16191
16932
  ggml_repeat(ctx, tensor->grad, src0->grad),
16192
- inplace);
16933
+ zero_table);
16193
16934
  }
16194
16935
  } break;
16195
16936
  case GGML_OP_CONCAT:
@@ -16211,10 +16952,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16211
16952
  float eps;
16212
16953
  memcpy(&eps, tensor->op_params, sizeof(float));
16213
16954
 
16214
- src0->grad = ggml_add_impl(ctx,
16955
+ src0->grad = ggml_add_or_set(ctx,
16215
16956
  src0->grad,
16216
16957
  ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
16217
- inplace);
16958
+ zero_table);
16218
16959
  }
16219
16960
  } break;
16220
16961
  case GGML_OP_RMS_NORM_BACK:
@@ -16238,37 +16979,49 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16238
16979
  // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
16239
16980
  // ds1 = t.T.dot(dt)
16240
16981
 
16241
- // tensor.shape [m,p]
16242
- // src0.shape [n,m]
16243
- // src1.shape [n,p]
16982
+ // tensor.shape [m,p,qq,rr]
16983
+ // src0.shape [n,m,q1,r1]
16984
+ // src1.shape [n,p,qq,rr]
16244
16985
 
16245
16986
  // necessary for llama
16246
16987
  if (src0->grad) {
16988
+ struct ggml_tensor * s1_tg =
16989
+ ggml_out_prod(ctx, // [n,m,qq,rr]
16990
+ src1, // [n,p,qq,rr]
16991
+ tensor->grad); // [m,p,qq,rr]
16992
+ const int64_t qq = s1_tg->ne[2];
16993
+ const int64_t rr = s1_tg->ne[3];
16994
+ const int64_t q1 = src0->ne[2];
16995
+ const int64_t r1 = src0->ne[3];
16996
+ const bool ne2_broadcasted = qq > q1;
16997
+ const bool ne3_broadcasted = rr > r1;
16998
+ if (ne2_broadcasted || ne3_broadcasted) {
16999
+ // sum broadcast repetitions of s1_tg into shape of src0
17000
+ s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
17001
+ }
16247
17002
  src0->grad =
16248
- ggml_add_impl(ctx,
16249
- src0->grad,
16250
- ggml_out_prod(ctx, // [n,m]
16251
- src1, // [n,p]
16252
- tensor->grad), // [m,p]
16253
- inplace);
17003
+ ggml_add_or_set(ctx,
17004
+ src0->grad, // [n,m,q1,r1]
17005
+ s1_tg, // [n,m,q1,r1]
17006
+ zero_table);
16254
17007
  }
16255
17008
  if (src1->grad) {
16256
17009
  src1->grad =
16257
- ggml_add_impl(ctx,
16258
- src1->grad,
16259
- // ggml_mul_mat(ctx, // [n,p]
16260
- // ggml_cont(ctx, // [m,n]
16261
- // ggml_transpose(ctx, src0)), // [m,n]
16262
- // tensor->grad), // [m,p]
17010
+ ggml_add_or_set(ctx,
17011
+ src1->grad, // [n,p,qq,rr]
17012
+ // ggml_mul_mat(ctx, // [n,p,qq,rr]
17013
+ // ggml_cont(ctx, // [m,n,q1,r1]
17014
+ // ggml_transpose(ctx, src0)), // [m,n,q1,r1]
17015
+ // tensor->grad), // [m,p,qq,rr]
16263
17016
 
16264
17017
  // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
16265
17018
  // // avoid transpose of src0, rather transpose smaller tensor->grad
16266
17019
  // // and then use ggml_out_prod
16267
- ggml_out_prod(ctx, // [n,p]
16268
- src0, // [n,m]
16269
- ggml_transpose(ctx, // [p,m]
16270
- tensor->grad)), // [m,p]
16271
- inplace);
17020
+ ggml_out_prod(ctx, // [n,p,qq,rr]
17021
+ src0, // [n,m,q1,r1]
17022
+ ggml_transpose(ctx, // [p,m,qq,rr]
17023
+ tensor->grad)), // [m,p,qq,rr]
17024
+ zero_table);
16272
17025
  }
16273
17026
  } break;
16274
17027
  case GGML_OP_OUT_PROD:
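The MUL_MAT backward pass now annotates the shapes with all four dimensions and handles broadcasting of src0 over dims 2 and 3: when src1 and tensor repeat src0 along those dimensions (qq > q1 or rr > r1), the per-slice contribution s1_tg is folded back into src0's shape with ggml_repeat_back. Written out (shapes as annotated in the comments above, assuming the usual broadcast semantics):

    // for each broadcast slice (i2, i3) of the product, an outer-product gradient
    //   out_prod(src1[:, :, i2, i3], d tensor[:, :, i2, i3])   // shape [n, m]
    // is produced; all slices that were computed from the same src0 slice are then
    // summed into that slice, which is what ggml_repeat_back does here.

In the common case q1 = r1 = 1 (one weight matrix shared by every slice), the sum simply runs over all qq*rr slices.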
@@ -16280,17 +17033,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16280
17033
  // necessary for llama
16281
17034
  if (src0->grad) {
16282
17035
  src0->grad =
16283
- ggml_add_impl(ctx,
17036
+ ggml_add_or_set(ctx,
16284
17037
  src0->grad,
16285
17038
  ggml_scale_impl(ctx, tensor->grad, src1, false),
16286
- inplace);
17039
+ zero_table);
16287
17040
  }
16288
17041
  if (src1->grad) {
16289
17042
  src1->grad =
16290
- ggml_add_impl(ctx,
17043
+ ggml_add_or_set(ctx,
16291
17044
  src1->grad,
16292
17045
  ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)),
16293
- inplace);
17046
+ zero_table);
16294
17047
  }
16295
17048
  } break;
16296
17049
  case GGML_OP_SET:
@@ -16317,23 +17070,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16317
17070
  }
16318
17071
 
16319
17072
  if (src0->grad) {
16320
- src0->grad = ggml_add_impl(ctx,
17073
+ src0->grad = ggml_add_or_set(ctx,
16321
17074
  src0->grad,
16322
17075
  ggml_acc_impl(ctx,
16323
17076
  tensor->grad,
16324
17077
  ggml_neg(ctx, tensor_grad_view),
16325
17078
  nb1, nb2, nb3, offset, false),
16326
- inplace);
17079
+ zero_table);
16327
17080
  }
16328
17081
 
16329
17082
  if (src1->grad) {
16330
17083
  src1->grad =
16331
- ggml_add_impl(ctx,
17084
+ ggml_add_or_set(ctx,
16332
17085
  src1->grad,
16333
17086
  ggml_reshape(ctx,
16334
17087
  ggml_cont(ctx, tensor_grad_view),
16335
17088
  src1->grad),
16336
- inplace);
17089
+ zero_table);
16337
17090
  }
16338
17091
  } break;
16339
17092
  case GGML_OP_CPY:
@@ -16344,7 +17097,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16344
17097
  // tensor = src0 * 1 + src1 * 0
16345
17098
  if (src0->grad) {
16346
17099
  // dsrc0 = dtensor * 1
16347
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
17100
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16348
17101
  }
16349
17102
  if (src1->grad) {
16350
17103
  // dsrc1 = dtensor * 0 -> noop
@@ -16356,7 +17109,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16356
17109
  if (src0->grad) {
16357
17110
  GGML_ASSERT(ggml_is_contiguous(src0->grad));
16358
17111
  GGML_ASSERT(ggml_is_contiguous(tensor->grad));
16359
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
17112
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16360
17113
  }
16361
17114
  } break;
16362
17115
  case GGML_OP_RESHAPE:
@@ -16364,9 +17117,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16364
17117
  // necessary for llama
16365
17118
  if (src0->grad) {
16366
17119
  src0->grad =
16367
- ggml_add_impl(ctx, src0->grad,
16368
- ggml_reshape(ctx, tensor->grad, src0->grad),
16369
- inplace);
17120
+ ggml_add_or_set(ctx, src0->grad,
17121
+ ggml_reshape(ctx,
17122
+ ggml_is_contiguous(tensor->grad)
17123
+ ? tensor->grad
17124
+ : ggml_cont(ctx, tensor->grad),
17125
+ src0->grad),
17126
+ zero_table);
16370
17127
  }
16371
17128
  } break;
16372
17129
  case GGML_OP_VIEW:
@@ -16395,7 +17152,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16395
17152
  nb3 = (nb3 / n0) * ng;
16396
17153
  }
16397
17154
 
16398
- src0->grad = ggml_acc_impl(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, inplace);
17155
+ src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table);
16399
17156
  }
16400
17157
  } break;
16401
17158
  case GGML_OP_PERMUTE:
@@ -16413,14 +17170,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16413
17170
  axes_backward[axis2] = 2;
16414
17171
  axes_backward[axis3] = 3;
16415
17172
  src0->grad =
16416
- ggml_add_impl(ctx, src0->grad,
17173
+ ggml_add_or_set(ctx, src0->grad,
16417
17174
  ggml_permute(ctx,
16418
17175
  tensor->grad,
16419
17176
  axes_backward[0],
16420
17177
  axes_backward[1],
16421
17178
  axes_backward[2],
16422
17179
  axes_backward[3]),
16423
- inplace);
17180
+ zero_table);
16424
17181
  }
16425
17182
  } break;
16426
17183
  case GGML_OP_TRANSPOSE:
@@ -16428,9 +17185,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16428
17185
  // necessary for llama
16429
17186
  if (src0->grad) {
16430
17187
  src0->grad =
16431
- ggml_add_impl(ctx, src0->grad,
17188
+ ggml_add_or_set(ctx, src0->grad,
16432
17189
  ggml_transpose(ctx, tensor->grad),
16433
- inplace);
17190
+ zero_table);
16434
17191
  }
16435
17192
  } break;
16436
17193
  case GGML_OP_GET_ROWS:
@@ -16438,9 +17195,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16438
17195
  // necessary for llama (only for tokenizer)
16439
17196
  if (src0->grad) {
16440
17197
  src0->grad =
16441
- ggml_add_impl(ctx, src0->grad,
17198
+ ggml_add_or_set(ctx, src0->grad,
17199
+ // last ggml_get_rows_back argument src0->grad is only
17200
+ // necessary to setup correct output shape
16442
17201
  ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
16443
- inplace);
17202
+ zero_table);
16444
17203
  }
16445
17204
  if (src1->grad) {
16446
17205
  // noop
@@ -16460,9 +17219,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16460
17219
  if (src0->grad) {
16461
17220
  const int n_past = ((int32_t *) tensor->op_params)[0];
16462
17221
  src0->grad =
16463
- ggml_add_impl(ctx, src0->grad,
17222
+ ggml_add_or_set(ctx, src0->grad,
16464
17223
  ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
16465
- inplace);
17224
+ zero_table);
16466
17225
  }
16467
17226
  } break;
16468
17227
  case GGML_OP_DIAG_MASK_ZERO:
@@ -16471,9 +17230,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16471
17230
  if (src0->grad) {
16472
17231
  const int n_past = ((int32_t *) tensor->op_params)[0];
16473
17232
  src0->grad =
16474
- ggml_add_impl(ctx, src0->grad,
17233
+ ggml_add_or_set(ctx, src0->grad,
16475
17234
  ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
16476
- inplace);
17235
+ zero_table);
16477
17236
  }
16478
17237
  } break;
16479
17238
  case GGML_OP_SOFT_MAX:
@@ -16481,9 +17240,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16481
17240
  // necessary for llama
16482
17241
  if (src0->grad) {
16483
17242
  src0->grad =
16484
- ggml_add_impl(ctx, src0->grad,
17243
+ ggml_add_or_set(ctx, src0->grad,
16485
17244
  ggml_soft_max_back(ctx, tensor->grad, tensor),
16486
- inplace);
17245
+ zero_table);
16487
17246
  }
16488
17247
 
16489
17248
  } break;
@@ -16495,7 +17254,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16495
17254
  {
16496
17255
  // necessary for llama
16497
17256
  if (src0->grad) {
16498
- const int n_past = ((int32_t *) tensor->op_params)[0];
17257
+ //const int n_past = ((int32_t *) tensor->op_params)[0];
16499
17258
  const int n_dims = ((int32_t *) tensor->op_params)[1];
16500
17259
  const int mode = ((int32_t *) tensor->op_params)[2];
16501
17260
  const int n_ctx = ((int32_t *) tensor->op_params)[3];
@@ -16508,11 +17267,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16508
17267
  memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
16509
17268
  memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
16510
17269
 
16511
- src0->grad = ggml_add_impl(ctx,
17270
+ src0->grad = ggml_add_or_set(ctx,
16512
17271
  src0->grad,
16513
17272
  ggml_rope_back(ctx,
16514
17273
  tensor->grad,
16515
- n_past,
17274
+ src1,
16516
17275
  n_dims,
16517
17276
  mode,
16518
17277
  n_ctx,
@@ -16520,13 +17279,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16520
17279
  freq_scale,
16521
17280
  xpos_base,
16522
17281
  xpos_down),
16523
- inplace);
17282
+ zero_table);
16524
17283
  }
16525
17284
  } break;
16526
17285
  case GGML_OP_ROPE_BACK:
16527
17286
  {
16528
17287
  if (src0->grad) {
16529
- const int n_past = ((int32_t *) tensor->op_params)[0];
17288
+ //const int n_past = ((int32_t *) tensor->op_params)[0];
16530
17289
  const int n_dims = ((int32_t *) tensor->op_params)[1];
16531
17290
  const int mode = ((int32_t *) tensor->op_params)[2];
16532
17291
  const int n_ctx = ((int32_t *) tensor->op_params)[3];
@@ -16539,11 +17298,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16539
17298
  memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
16540
17299
  memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
16541
17300
 
16542
- src0->grad = ggml_add_impl(ctx,
17301
+ src0->grad = ggml_add_or_set(ctx,
16543
17302
  src0->grad,
16544
17303
  ggml_rope_impl(ctx,
16545
17304
  tensor->grad,
16546
- n_past,
17305
+ src1,
16547
17306
  n_dims,
16548
17307
  mode,
16549
17308
  n_ctx,
@@ -16552,7 +17311,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16552
17311
  xpos_base,
16553
17312
  xpos_down,
16554
17313
  false),
16555
- inplace);
17314
+ zero_table);
16556
17315
  }
16557
17316
  } break;
16558
17317
  case GGML_OP_ALIBI:
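In both the ROPE and ROPE_BACK branches the n_past scalar is no longer read from op_params (the line is kept but commented out); instead src1 is forwarded to ggml_rope_back / ggml_rope_impl, which in this release appears to carry the per-token positions. For reference, RoPE rotates each channel pair (x[2i], x[2i+1]) by an angle of roughly theta_i = freq_scale * pos * freq_base^(-2*i/n_dims), where pos is the token position; the backward pass applies the same rotation with the opposite sign, so it needs exactly the positions tensor that the forward pass used.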
@@ -16603,145 +17362,42 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16603
17362
  masked);
16604
17363
  }
16605
17364
 
16606
- if (src0->grad) {
16607
- struct ggml_tensor * grad_q = NULL;
16608
- const size_t nb0 = flash_grad->nb[0];
16609
- const size_t offset = 0;
16610
- switch(src0->n_dims) {
16611
- case 2:
16612
- {
16613
- grad_q = ggml_view_2d(ctx,
16614
- flash_grad,
16615
- src0->ne[0],
16616
- src0->ne[1],
16617
- nb0*src0->ne[0],
16618
- offset);
16619
- } break;
16620
- case 3:
16621
- {
16622
- grad_q = ggml_view_3d(ctx,
16623
- flash_grad,
16624
- src0->ne[0],
16625
- src0->ne[1],
16626
- src0->ne[2],
16627
- nb0*src0->ne[0],
16628
- nb0*src0->ne[0]*src0->ne[1],
16629
- offset);
16630
- } break;
16631
- case 4:
16632
- {
16633
- grad_q = ggml_view_4d(ctx,
16634
- flash_grad,
16635
- src0->ne[0],
16636
- src0->ne[1],
16637
- src0->ne[2],
16638
- src0->ne[3],
16639
- nb0*src0->ne[0],
16640
- nb0*src0->ne[0]*src0->ne[1],
16641
- nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
16642
- offset);
16643
- } break;
16644
- }
17365
+ struct ggml_tensor * src2 = tensor->src[2];
17366
+ const int64_t elem_q = ggml_nelements(src0);
17367
+ const int64_t elem_k = ggml_nelements(src1);
17368
+ const int64_t elem_v = ggml_nelements(src2);
17369
+
17370
+ enum ggml_type result_type = flash_grad->type;
17371
+ GGML_ASSERT(ggml_blck_size(result_type) == 1);
17372
+ const size_t tsize = ggml_type_size(result_type);
17373
+
17374
+ const size_t offs_q = 0;
17375
+ const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
17376
+ const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
16645
17377
 
16646
- src0->grad = ggml_add_impl(ctx,
17378
+ if (src0->grad) {
17379
+ struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q);
17380
+ struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0);
17381
+ src0->grad = ggml_add_or_set(ctx,
16647
17382
  src0->grad,
16648
17383
  grad_q,
16649
- inplace);
17384
+ zero_table);
16650
17385
  }
16651
-
16652
17386
  if (src1->grad) {
16653
- struct ggml_tensor * grad_k = NULL;
16654
- const size_t nb0 = flash_grad->nb[0];
16655
- const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
16656
- switch(src1->n_dims) {
16657
- case 2:
16658
- {
16659
- grad_k = ggml_view_2d(ctx,
16660
- flash_grad,
16661
- src1->ne[0],
16662
- src1->ne[1],
16663
- nb0*src1->ne[0],
16664
- offset);
16665
- } break;
16666
- case 3:
16667
- {
16668
- grad_k = ggml_view_3d(ctx,
16669
- flash_grad,
16670
- src1->ne[0],
16671
- src1->ne[1],
16672
- src1->ne[2],
16673
- nb0*src1->ne[0],
16674
- nb0*src1->ne[0]*src1->ne[1],
16675
- offset);
16676
- } break;
16677
- case 4:
16678
- {
16679
- grad_k = ggml_view_4d(ctx,
16680
- flash_grad,
16681
- src1->ne[0],
16682
- src1->ne[1],
16683
- src1->ne[2],
16684
- src1->ne[3],
16685
- nb0*src1->ne[0],
16686
- nb0*src1->ne[0]*src1->ne[1],
16687
- nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
16688
- offset);
16689
- } break;
16690
- }
16691
-
16692
- src1->grad = ggml_add_impl(ctx,
17387
+ struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
17388
+ struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1);
17389
+ src1->grad = ggml_add_or_set(ctx,
16693
17390
  src1->grad,
16694
17391
  grad_k,
16695
- inplace);
17392
+ zero_table);
16696
17393
  }
16697
-
16698
- struct ggml_tensor * opt0 = tensor->src[2];
16699
-
16700
- if (opt0->grad) {
16701
- struct ggml_tensor * grad_v = NULL;
16702
- const size_t nb0 = flash_grad->nb[0];
16703
- const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
16704
- + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
16705
- switch(opt0->n_dims) {
16706
- case 2:
16707
- {
16708
- grad_v = ggml_view_2d(ctx,
16709
- flash_grad,
16710
- opt0->ne[0],
16711
- opt0->ne[1],
16712
- nb0*opt0->ne[0],
16713
- offset);
16714
- } break;
16715
- case 3:
16716
- {
16717
- grad_v = ggml_view_3d(ctx,
16718
- flash_grad,
16719
- opt0->ne[0],
16720
- opt0->ne[1],
16721
- opt0->ne[2],
16722
- nb0*opt0->ne[0],
16723
- nb0*opt0->ne[0]*opt0->ne[1],
16724
- offset);
16725
- } break;
16726
- case 4:
16727
- {
16728
- grad_v = ggml_view_4d(ctx,
16729
- flash_grad,
16730
- opt0->ne[0],
16731
- opt0->ne[1],
16732
- opt0->ne[2],
16733
- opt0->ne[3],
16734
- nb0*opt0->ne[0],
16735
- nb0*opt0->ne[0]*opt0->ne[1],
16736
- nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
16737
- offset);
16738
- } break;
16739
- }
16740
-
16741
- opt0->grad = ggml_add_impl(ctx,
16742
- opt0->grad,
17394
+ if (src2->grad) {
17395
+ struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
17396
+ struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2);
17397
+ src2->grad = ggml_add_or_set(ctx,
17398
+ src2->grad,
16743
17399
  grad_v,
16744
- inplace);
17400
+ zero_table);
16745
17401
  }
16746
17402
  } break;
16747
17403
  case GGML_OP_FLASH_FF:
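The FLASH_ATTN backward case drops the dimension-specific ggml_view_2d/3d/4d plumbing and the hand-computed byte offsets for grad_q, grad_k and grad_v. Instead, flash_grad is treated as one flat buffer that stores the q, k and v gradients back to back at GGML_MEM_ALIGN-padded offsets; each region is taken as a 1-D view and reshaped to the corresponding operand (src0, src1, src2 -- the latter replacing the old opt0 naming). A sketch of the padding arithmetic (the GGML_PAD definition below is an assumption based on how it is used here, i.e. "round up to a multiple of n"):

    // assumed definition: round x up to the next multiple of n
    #define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))

    // packed layout inside flash_grad (tsize = ggml_type_size(flash_grad->type)):
    //   [ elem_q values | pad | elem_k values | pad | elem_v values ]
    //   offs_q = 0
    //   offs_k = GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN)
    //   offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN)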
@@ -16761,12 +17417,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16761
17417
  {
16762
17418
  if (src0->grad) {
16763
17419
  src0->grad =
16764
- ggml_add_impl(ctx,
17420
+ ggml_add_or_set(ctx,
16765
17421
  src0->grad,
16766
17422
  ggml_mul(ctx,
16767
17423
  ggml_sgn(ctx, src0),
16768
17424
  tensor->grad),
16769
- inplace);
17425
+ zero_table);
16770
17426
  }
16771
17427
  } break;
16772
17428
  case GGML_UNARY_OP_SGN:
@@ -16778,7 +17434,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16778
17434
  case GGML_UNARY_OP_NEG:
16779
17435
  {
16780
17436
  if (src0->grad) {
16781
- src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace);
17437
+ src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table);
16782
17438
  }
16783
17439
  } break;
16784
17440
  case GGML_UNARY_OP_STEP:
@@ -16798,12 +17454,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16798
17454
  case GGML_UNARY_OP_RELU:
16799
17455
  {
16800
17456
  if (src0->grad) {
16801
- src0->grad = ggml_add_impl(ctx,
17457
+ src0->grad = ggml_add_or_set(ctx,
16802
17458
  src0->grad,
16803
17459
  ggml_mul(ctx,
16804
17460
  ggml_step(ctx, src0),
16805
17461
  tensor->grad),
16806
- inplace);
17462
+ zero_table);
16807
17463
  }
16808
17464
  } break;
16809
17465
  case GGML_UNARY_OP_GELU:
@@ -16818,10 +17474,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16818
17474
  {
16819
17475
  // necessary for llama
16820
17476
  if (src0->grad) {
16821
- src0->grad = ggml_add_impl(ctx,
17477
+ src0->grad = ggml_add_or_set(ctx,
16822
17478
  src0->grad,
16823
17479
  ggml_silu_back(ctx, src0, tensor->grad),
16824
- inplace);
17480
+ zero_table);
16825
17481
  }
16826
17482
  } break;
16827
17483
  default:
@@ -16844,13 +17500,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16844
17500
  case GGML_OP_CROSS_ENTROPY_LOSS:
16845
17501
  {
16846
17502
  if (src0->grad) {
16847
- src0->grad = ggml_add_impl(ctx,
17503
+ src0->grad = ggml_add_or_set(ctx,
16848
17504
  src0->grad,
16849
17505
  ggml_cross_entropy_loss_back(ctx,
16850
17506
  src0,
16851
17507
  src1,
16852
17508
  tensor->grad),
16853
- inplace);
17509
+ zero_table);
16854
17510
  }
16855
17511
  } break;
16856
17512
  case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
@@ -16866,34 +17522,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16866
17522
  GGML_ASSERT(false);
16867
17523
  } break;
16868
17524
  }
16869
- }
16870
17525
 
16871
- static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
16872
-
16873
- static size_t hash(void * p) {
16874
- return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
16875
- }
16876
-
16877
- static bool hash_insert(void * hash_table[], void * p) {
16878
- size_t h = hash(p);
16879
-
16880
- // linear probing
16881
- size_t i = h;
16882
- while (hash_table[i] != NULL && hash_table[i] != p) {
16883
- i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
16884
- if (i == h) {
16885
- // hash table is full
16886
- GGML_ASSERT(false);
17526
+ for (int i = 0; i < GGML_MAX_SRC; ++i) {
17527
+ if (tensor->src[i] && tensor->src[i]->grad) {
17528
+ GGML_ASSERT(ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad));
16887
17529
  }
16888
17530
  }
16889
-
16890
- if (hash_table[i] == p) {
16891
- return true;
16892
- }
16893
-
16894
- // insert
16895
- hash_table[i] = p;
16896
- return false;
16897
17531
  }
16898
17532
 
16899
17533
  static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
@@ -16911,8 +17545,12 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
16911
17545
  }
16912
17546
 
16913
17547
  for (int i = 0; i < GGML_MAX_SRC; ++i) {
16914
- if (node->src[i]) {
16915
- ggml_visit_parents(cgraph, node->src[i]);
17548
+ const int k =
17549
+ (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
17550
+ (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
17551
+ /* unknown order, just fall back to using i*/ i;
17552
+ if (node->src[k]) {
17553
+ ggml_visit_parents(cgraph, node->src[k]);
16916
17554
  }
16917
17555
  }
16918
17556
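ggml_visit_parents now respects cgraph->order when walking a node's sources, and both ggml_build_forward and ggml_new_graph initialise the new field to GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, so the default traversal is unchanged. The mapping is simply:

    // source index visited at step i, for GGML_MAX_SRC sources:
    //   LEFT_TO_RIGHT: k = i
    //   RIGHT_TO_LEFT: k = GGML_MAX_SRC - 1 - i

Reversing the visit order changes the topological order of cgraph->nodes, which presumably lets callers pick whichever ordering gives better buffer reuse in the graph allocator.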
 
@@ -16971,6 +17609,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
16971
17609
  /*.grads =*/ { NULL },
16972
17610
  /*.leafs =*/ { NULL },
16973
17611
  /*.hash_table =*/ { NULL },
17612
+ /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
16974
17613
  /*.perf_runs =*/ 0,
16975
17614
  /*.perf_cycles =*/ 0,
16976
17615
  /*.perf_time_us =*/ 0,
@@ -16996,12 +17635,22 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
16996
17635
  }
16997
17636
  }
16998
17637
 
17638
+ // remember original gradients which start with zero values
17639
+ void ** zero_table = malloc(sizeof(void *) * GGML_GRAPH_HASHTABLE_SIZE);
17640
+ memset(zero_table, 0, sizeof(void*) * GGML_GRAPH_HASHTABLE_SIZE);
17641
+ for (int i = 0; i < gf->n_nodes; i++) {
17642
+ if (gf->grads[i]) {
17643
+ hash_insert(zero_table, gf->grads[i]);
17644
+ }
17645
+ }
17646
+
16999
17647
  for (int i = gf->n_nodes - 1; i >= 0; i--) {
17000
17648
  struct ggml_tensor * node = gf->nodes[i];
17001
17649
 
17002
- // because we detached the grad nodes from the original graph, we can afford inplace operations
17650
+ // inplace operations to add gradients are not created by ggml_compute_backward
17651
+ // use allocator to automatically make inplace operations
17003
17652
  if (node->grad) {
17004
- ggml_compute_backward(ctx, node, keep);
17653
+ ggml_compute_backward(ctx, node, zero_table);
17005
17654
  }
17006
17655
  }
17007
17656
 
@@ -17013,6 +17662,8 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
17013
17662
  ggml_build_forward_expand(gb, node->grad);
17014
17663
  }
17015
17664
  }
17665
+
17666
+ free(zero_table);
17016
17667
  }
17017
17668
 
17018
17669
  struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
@@ -17032,6 +17683,7 @@ struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
17032
17683
  /*.grads =*/ { NULL },
17033
17684
  /*.leafs =*/ { NULL },
17034
17685
  /*.hash_table =*/ { NULL },
17686
+ /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
17035
17687
  /*.perf_runs =*/ 0,
17036
17688
  /*.perf_cycles =*/ 0,
17037
17689
  /*.perf_time_us =*/ 0,
@@ -17283,10 +17935,18 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
17283
17935
  } else {
17284
17936
  // wait for other threads to finish
17285
17937
  const int last = node_n;
17286
- do {
17287
- //sched_yield();
17938
+ while (true) {
17939
+ // TODO: this sched_yield can have significant impact on the performance - either positive or negative
17940
+ // depending on the workload and the operating system.
17941
+ // since it is not clear what is the best approach, it should potentially become user-configurable
17942
+ // ref: https://github.com/ggerganov/ggml/issues/291
17943
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
17944
+ sched_yield();
17945
+ #endif
17946
+
17288
17947
  node_n = atomic_load(&state->shared->node_n);
17289
- } while (node_n == last);
17948
+ if (node_n != last) break;
17949
+ };
17290
17950
  }
17291
17951
 
17292
17952
  // check if we should stop
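The worker wait loop is rewritten from a do/while into an explicit loop that re-reads the shared node counter, and sched_yield() is now compiled in only when a BLAS backend (Accelerate or OpenBLAS) is enabled; otherwise the thread busy-spins on the atomic, per the TODO about the yield's workload- and OS-dependent cost. The same wait pattern in isolation, with generic names rather than the ggml internals:

    #include <stdatomic.h>
    #include <stdbool.h>
    #include <sched.h>

    // spin until the shared counter moves past `last`; optionally yield while waiting
    static int wait_for_progress(atomic_int * node_n, int last) {
        while (true) {
        #if defined(USE_SCHED_YIELD)
            sched_yield();        // give up the timeslice; helps when threads are oversubscribed
        #endif
            const int n = atomic_load(node_n);
            if (n != last) {
                return n;         // a new graph node became available
            }
        }
    }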
@@ -17414,7 +18074,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
17414
18074
  } break;
17415
18075
  case GGML_OP_CONCAT:
17416
18076
  case GGML_OP_MUL_MAT:
17417
- case GGML_OP_OUT_PROD:
17418
18077
  {
17419
18078
  n_tasks = n_threads;
17420
18079
 
@@ -17456,6 +18115,18 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
17456
18115
  cur = 0;
17457
18116
  }
17458
18117
 
18118
+ work_size = MAX(work_size, cur);
18119
+ } break;
18120
+ case GGML_OP_OUT_PROD:
18121
+ {
18122
+ n_tasks = n_threads;
18123
+
18124
+ size_t cur = 0;
18125
+
18126
+ if (ggml_is_quantized(node->src[0]->type)) {
18127
+ cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
18128
+ }
18129
+
17459
18130
  work_size = MAX(work_size, cur);
17460
18131
  } break;
17461
18132
  case GGML_OP_SCALE:
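GGML_OP_OUT_PROD leaves the shared MUL_MAT branch of the planner and gets its own work-size estimate: when src0 is quantized, ne[0] F32 values are reserved per task, presumably as a per-thread scratch buffer for dequantising the quantized data before the outer product. As a rough example with hypothetical sizes, ne[0] = 4096 and 8 threads reserve 4 * 4096 * 8 = 128 KiB of work buffer.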
@@ -18337,10 +19008,11 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
18337
19008
  for (int i = 0; i < cgraph->n_leafs; i++) {
18338
19009
  struct ggml_tensor * node = cgraph->leafs[i];
18339
19010
 
18340
- GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
19011
+ GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n",
18341
19012
  i,
18342
19013
  node->ne[0], node->ne[1],
18343
- ggml_op_name(node->op));
19014
+ ggml_op_name(node->op),
19015
+ ggml_get_name(node));
18344
19016
  }
18345
19017
 
18346
19018
  for (int i = 0; i < GGML_OP_COUNT; i++) {
@@ -18548,7 +19220,7 @@ static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float *
18548
19220
  }
18549
19221
 
18550
19222
  static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
18551
- int i = 0;
19223
+ int64_t i = 0;
18552
19224
  for (int p = 0; p < np; ++p) {
18553
19225
  const int64_t ne = ggml_nelements(ps[p]) ;
18554
19226
  // TODO: add function to get all elements at once
@@ -18558,6 +19230,17 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
18558
19230
  }
18559
19231
  }
18560
19232
 
19233
+ static void ggml_opt_acc_grad(int np, struct ggml_tensor * const ps[], float * g, float scale) {
19234
+ int64_t i = 0;
19235
+ for (int p = 0; p < np; ++p) {
19236
+ const int64_t ne = ggml_nelements(ps[p]) ;
19237
+ // TODO: add function to get all elements at once
19238
+ for (int64_t j = 0; j < ne; ++j) {
19239
+ g[i++] += ggml_get_f32_1d(ps[p]->grad, j) * scale;
19240
+ }
19241
+ }
19242
+ }
19243
+
18561
19244
  //
18562
19245
  // ADAM
18563
19246
  //
@@ -18606,26 +19289,43 @@ static enum ggml_opt_result ggml_opt_adam(
18606
19289
  const float eps = params.adam.eps;
18607
19290
  const float gclip = params.adam.gclip;
18608
19291
  const int decay_min_ndim = params.adam.decay_min_ndim;
19292
+ const int n_accum = MAX(1, params.n_gradient_accumulation);
19293
+ const float accum_norm = 1.0f / (float) n_accum;
18609
19294
 
19295
+ float * g = opt->adam.g->data; // gradients
18610
19296
  float * m = opt->adam.m->data; // first moment
18611
19297
  float * v = opt->adam.v->data; // second moment
18612
19298
 
18613
19299
  float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
18614
19300
 
18615
- if (callback) {
18616
- callback(callback_data, &sched);
18617
- }
18618
-
18619
- // compute the function value
18620
- ggml_graph_reset (gf);
18621
- ggml_set_f32 (f->grad, 1.0f);
18622
-
18623
19301
  struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads);
18624
19302
  struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size);
18625
19303
  cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
18626
- ggml_graph_compute(gb, &cplan);
18627
19304
 
18628
- opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
19305
+ bool cancel = false;
19306
+
19307
+ // compute the function value
19308
+ float fx = 0;
19309
+ ggml_set_zero(opt->adam.g);
19310
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
19311
+ if (callback) {
19312
+ callback(callback_data, accum_step, &sched, &cancel);
19313
+ if (cancel) {
19314
+ break;
19315
+ }
19316
+ }
19317
+ // ggml_graph_reset (gf);
19318
+ ggml_set_f32 (f->grad, 1.0f);
19319
+ ggml_graph_compute(gb, &cplan);
19320
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
19321
+ fx += ggml_get_f32_1d(f, 0);
19322
+ }
19323
+ if (cancel) {
19324
+ return GGML_OPT_DID_NOT_CONVERGE;
19325
+ }
19326
+ fx *= accum_norm;
19327
+
19328
+ opt->adam.fx_prev = fx;
18629
19329
  opt->adam.fx_best = opt->adam.fx_prev;
18630
19330
  if (pf) {
18631
19331
  pf[opt->iter % params.past] = opt->adam.fx_prev;
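Both optimizers gain gradient accumulation: the loss and the gradients are accumulated over params.n_gradient_accumulation forward/backward passes, each weighted by accum_norm = 1/n_accum, before a single parameter update. For Adam the gradients are gathered into the new opt->adam.g buffer with ggml_opt_acc_grad (defined above) instead of being re-read element by element from each parameter's grad tensor, and the optimizer callback now receives the accum_step index plus a cancel flag that aborts the run cleanly. The overall scheme, reduced to a self-contained sketch (eval_loss_and_grad is a placeholder for one forward+backward pass, not a ggml function):

    #include <stdint.h>
    #include <string.h>

    // run n_accum micro-steps, averaging both the gradient (into g) and the loss
    float accumulate_gradients(int n_accum, int64_t n, float * g, float * grad,
                               float (*eval_loss_and_grad)(float * grad)) {
        const float accum_norm = 1.0f / (float) n_accum;
        float fx = 0.0f;
        memset(g, 0, sizeof(float)*n);
        for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
            fx += eval_loss_and_grad(grad);        // micro-batch forward + backward fills grad[]
            for (int64_t i = 0; i < n; ++i) {
                g[i] += grad[i] * accum_norm;      // running average of the gradient
            }
        }
        return fx * accum_norm;                    // averaged loss over the accumulation window
    }

With n_accum = 4 this behaves like a batch four times larger per update, at the cost of four graph evaluations per Adam step.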
@@ -18648,6 +19348,9 @@ static enum ggml_opt_result ggml_opt_adam(
18648
19348
 
18649
19349
  // run the optimizer
18650
19350
  for (int t = 0; t < params.adam.n_iter; ++t) {
19351
+ if (cancel) {
19352
+ break;
19353
+ }
18651
19354
  opt->iter = iter0 + t + 1;
18652
19355
  GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
18653
19356
 
@@ -18670,12 +19373,8 @@ static enum ggml_opt_result ggml_opt_adam(
18670
19373
  if (gclip > 0.0f) {
18671
19374
  // gradient clipping
18672
19375
  ggml_float sum = 0.0;
18673
- for (int p = 0; p < np; ++p) {
18674
- const int64_t ne = ggml_nelements(ps[p]);
18675
- for (int64_t j = 0; j < ne; ++j) {
18676
- float g = ggml_get_f32_1d(ps[p]->grad, j);
18677
- sum += (ggml_float)(g*g);
18678
- }
19376
+ for (int64_t i = 0; i < nx; ++i) {
19377
+ sum += (ggml_float)(g[i]*g[i]);
18679
19378
  }
18680
19379
  ggml_float norm = sqrt(sum);
18681
19380
  if (norm > (ggml_float) gclip) {
@@ -18689,10 +19388,10 @@ static enum ggml_opt_result ggml_opt_adam(
18689
19388
  const int64_t ne = ggml_nelements(ps[p]);
18690
19389
  const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0f) * sched;
18691
19390
  for (int64_t j = 0; j < ne; ++j) {
18692
- float x = ggml_get_f32_1d(ps[p], j);
18693
- float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm;
18694
- m[i] = m[i]*beta1 + g*(1.0f - beta1);
18695
- v[i] = v[i]*beta2 + g*g*(1.0f - beta2);
19391
+ float x = ggml_get_f32_1d(ps[p], j);
19392
+ float g_ = g[i]*gnorm;
19393
+ m[i] = m[i]*beta1 + g_*(1.0f - beta1);
19394
+ v[i] = v[i]*beta2 + g_*g_*(1.0f - beta2);
18696
19395
  float mh = m[i]*beta1h;
18697
19396
  float vh = v[i]*beta2h;
18698
19397
  vh = sqrtf(vh) + eps;
@@ -18703,16 +19402,26 @@ static enum ggml_opt_result ggml_opt_adam(
18703
19402
  }
18704
19403
  }
18705
19404
 
18706
- if (callback) {
18707
- callback(callback_data, &sched);
19405
+ fx = 0;
19406
+ ggml_set_zero(opt->adam.g);
19407
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
19408
+ if (callback) {
19409
+ callback(callback_data, accum_step, &sched, &cancel);
19410
+ if (cancel) {
19411
+ break;
19412
+ }
19413
+ }
19414
+ // ggml_graph_reset (gf);
19415
+ ggml_set_f32 (f->grad, 1.0f);
19416
+ ggml_graph_compute(gb, &cplan);
19417
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
19418
+ fx += ggml_get_f32_1d(f, 0);
18708
19419
  }
19420
+ if (cancel) {
19421
+ break;
19422
+ }
19423
+ fx *= accum_norm;
18709
19424
 
18710
- ggml_graph_reset (gf);
18711
- ggml_set_f32 (f->grad, 1.0f);
18712
-
18713
- ggml_graph_compute(gb, &cplan);
18714
-
18715
- const float fx = ggml_get_f32_1d(f, 0);
18716
19425
  opt->loss_after = fx;
18717
19426
 
18718
19427
 
@@ -18792,11 +19501,11 @@ static enum ggml_opt_result linesearch_backtracking(
18792
19501
  float * step,
18793
19502
  const float * xp,
18794
19503
  struct ggml_tensor * f,
18795
- struct ggml_cgraph * gf,
18796
19504
  struct ggml_cgraph * gb,
18797
19505
  struct ggml_cplan * cplan,
18798
19506
  const int np,
18799
19507
  struct ggml_tensor * ps[],
19508
+ bool * cancel,
18800
19509
  ggml_opt_callback callback,
18801
19510
  void * callback_data) {
18802
19511
  int count = 0;
@@ -18810,6 +19519,9 @@ static enum ggml_opt_result linesearch_backtracking(
18810
19519
  const float dec = 0.5f;
18811
19520
  const float inc = 2.1f;
18812
19521
 
19522
+ const int n_accum = MAX(1, params->n_gradient_accumulation);
19523
+ const float accum_norm = 1.0f / (float) n_accum;
19524
+
18813
19525
  if (*step <= 0.f) {
18814
19526
  return GGML_LINESEARCH_INVALID_PARAMETERS;
18815
19527
  }
@@ -18826,13 +19538,7 @@ static enum ggml_opt_result linesearch_backtracking(
18826
19538
  finit = *fx;
18827
19539
  dgtest = params->lbfgs.ftol*dginit;
18828
19540
 
18829
- while (true) {
18830
- if (callback) {
18831
- // LBFG-S does not support learning rate -> ignore learning schedule
18832
- float sched = 0;
18833
- callback(callback_data, &sched);
18834
- }
18835
-
19541
+ while (!*cancel) {
18836
19542
  ggml_vec_cpy_f32(nx, x, xp);
18837
19543
  ggml_vec_mad_f32(nx, x, d, *step);
18838
19544
 
@@ -18840,14 +19546,28 @@ static enum ggml_opt_result linesearch_backtracking(
18840
19546
  {
18841
19547
  ggml_opt_set_params(np, ps, x);
18842
19548
 
18843
- ggml_graph_reset (gf);
18844
- ggml_set_f32 (f->grad, 1.0f);
18845
-
18846
- ggml_graph_compute(gb, cplan);
18847
-
18848
- ggml_opt_get_grad(np, ps, g);
19549
+ *fx = 0;
19550
+ memset(g, 0, sizeof(float)*nx);
19551
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
19552
+ if (callback) {
19553
+ // LBFG-S does not support learning rate -> ignore learning schedule
19554
+ float sched = 0;
19555
+ callback(callback_data, accum_step, &sched, cancel);
19556
+ if (*cancel) {
19557
+ break;
19558
+ }
19559
+ }
19560
+ // ggml_graph_reset (gf);
19561
+ ggml_set_f32 (f->grad, 1.0f);
19562
+ ggml_graph_compute(gb, cplan);
19563
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
19564
+ *fx += ggml_get_f32_1d(f, 0);
19565
+ }
19566
+ if (*cancel) {
19567
+ break;
19568
+ }
19569
+ *fx *= accum_norm;
18849
19570
 
18850
- *fx = ggml_get_f32_1d(f, 0);
18851
19571
  }
18852
19572
 
18853
19573
  ++count;
@@ -18893,7 +19613,7 @@ static enum ggml_opt_result linesearch_backtracking(
18893
19613
  (*step) *= width;
18894
19614
  }
18895
19615
 
18896
- return GGML_LINESEARCH_FAIL;
19616
+ GGML_UNREACHABLE();
18897
19617
  }
18898
19618
 
18899
19619
  static enum ggml_opt_result ggml_opt_lbfgs(
@@ -18948,6 +19668,9 @@ static enum ggml_opt_result ggml_opt_lbfgs(
18948
19668
 
18949
19669
  float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
18950
19670
 
19671
+ const int n_accum = MAX(1, params.n_gradient_accumulation);
19672
+ const float accum_norm = 1.0f / (float) n_accum;
19673
+
18951
19674
  float fx = 0.0f; // cost function value
18952
19675
  float xnorm = 0.0f; // ||x||
18953
19676
  float gnorm = 0.0f; // ||g||
@@ -18961,24 +19684,33 @@ static enum ggml_opt_result ggml_opt_lbfgs(
18961
19684
  float * lm_s = opt->lbfgs.lms->data;
18962
19685
  float * lm_y = opt->lbfgs.lmy->data;
18963
19686
 
18964
- if (callback) {
18965
- // LBFG-S does not support learning rate -> ignore learning schedule
18966
- float sched = 0;
18967
- callback(callback_data, &sched);
18968
- }
19687
+ bool cancel = false;
18969
19688
 
18970
19689
  // evaluate the function value and its gradient
18971
19690
  {
18972
19691
  ggml_opt_set_params(np, ps, x);
18973
19692
 
18974
- ggml_graph_reset (gf);
18975
- ggml_set_f32 (f->grad, 1.0f);
18976
-
18977
- ggml_graph_compute(gb, &cplan);
18978
-
18979
- ggml_opt_get_grad(np, ps, g);
18980
-
18981
- fx = ggml_get_f32_1d(f, 0);
19693
+ fx = 0;
19694
+ memset(g, 0, sizeof(float)*nx);
19695
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
19696
+ if (callback) {
19697
+ // LBFG-S does not support learning rate -> ignore learning schedule
19698
+ float sched = 0;
19699
+ callback(callback_data, accum_step, &sched, &cancel);
19700
+ if (cancel) {
19701
+ break;
19702
+ }
19703
+ }
19704
+ // ggml_graph_reset (gf);
19705
+ ggml_set_f32 (f->grad, 1.0f);
19706
+ ggml_graph_compute(gb, &cplan);
19707
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
19708
+ fx += ggml_get_f32_1d(f, 0);
19709
+ }
19710
+ if (cancel) {
19711
+ return GGML_OPT_DID_NOT_CONVERGE;
19712
+ }
19713
+ fx *= accum_norm;
18982
19714
 
18983
19715
  opt->loss_before = fx;
18984
19716
  opt->loss_after = fx;
@@ -19036,7 +19768,10 @@ static enum ggml_opt_result ggml_opt_lbfgs(
19036
19768
  ggml_vec_cpy_f32(nx, xp, x);
19037
19769
  ggml_vec_cpy_f32(nx, gp, g);
19038
19770
 
19039
- ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gf, gb, &cplan, np, ps, callback, callback_data);
19771
+ ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data);
19772
+ if (!cancel) {
19773
+ break;
19774
+ }
19040
19775
 
19041
19776
  if (ls < 0) {
19042
19777
  // linesearch failed - go back to the previous point and return
@@ -19145,7 +19880,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
19145
19880
  step[0] = 1.0;
19146
19881
  }
19147
19882
 
19148
- return GGML_OPT_DID_NOT_CONVERGE;
19883
+ GGML_UNREACHABLE();
19149
19884
  }
19150
19885
 
19151
19886
  struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
@@ -19165,6 +19900,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
19165
19900
  .print_forward_graph = true,
19166
19901
  .print_backward_graph = true,
19167
19902
 
19903
+ .n_gradient_accumulation = 1,
19904
+
19168
19905
  .adam = {
19169
19906
  .n_iter = 10000,
19170
19907
  .sched = 1.000f,
@@ -19193,6 +19930,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
19193
19930
  .print_forward_graph = true,
19194
19931
  .print_backward_graph = true,
19195
19932
 
19933
+ .n_gradient_accumulation = 1,
19934
+
19196
19935
  .lbfgs = {
19197
19936
  .m = 6,
19198
19937
  .n_iter = 100,
@@ -19223,13 +19962,32 @@ GGML_API void ggml_opt_init(
19223
19962
  opt->iter = 0;
19224
19963
  opt->nx = nx;
19225
19964
  opt->just_initialized = true;
19965
+ if (opt->ctx == NULL) {
19966
+ struct ggml_init_params ctx_opt_params;
19967
+ if (opt->params.type == GGML_OPT_ADAM) {
19968
+ ctx_opt_params.mem_size = GGML_MEM_ALIGN*3 + ggml_tensor_overhead()*3 + ggml_type_size(GGML_TYPE_F32)*nx*3;
19969
+ if (opt->params.past > 0) {
19970
+ ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
19971
+ }
19972
+ } else if (opt->params.type == GGML_OPT_LBFGS) {
19973
+ ctx_opt_params.mem_size = GGML_MEM_ALIGN*9 + ggml_tensor_overhead()*9 + ggml_type_size(GGML_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2);
19974
+ if (opt->params.past > 0) {
19975
+ ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
19976
+ }
19977
+ }
19978
+ ctx_opt_params.mem_buffer = NULL;
19979
+ ctx_opt_params.no_alloc = false;
19980
+
19981
+ opt->ctx = ggml_init(ctx_opt_params);
19982
+ }
19226
19983
  switch (opt->params.type) {
19227
19984
  case GGML_OPT_ADAM:
19228
19985
  {
19229
- opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19230
- opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19986
+ opt->adam.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
19987
+ opt->adam.m = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
19988
+ opt->adam.v = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
19231
19989
  opt->adam.pf = params.past > 0
19232
- ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
19990
+ ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
19233
19991
  : NULL;
19234
19992
  ggml_set_zero(opt->adam.m);
19235
19993
  ggml_set_zero(opt->adam.v);
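ggml_opt_init now allocates a dedicated opt->ctx for the optimizer state instead of carving it out of the caller's context, sized from the expressions above. For Adam that is three F32 tensors of nx elements (g, m, v), i.e. about 12 bytes per parameter plus three tensor overheads and alignment pads: nx = 1,000,000 parameters costs roughly 3 * 4 * 1e6 = 12 MB of optimizer state (plus the past-values tensor when params.past > 0). The L-BFGS sizing mirrors its state: five nx-length vectors (x, xp, g, gp, d), the two m-length history scalars, and the two nx*m history matrices, matching the nx*5 + m*2 + nx*m*2 term.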
@@ -19239,18 +19997,18 @@ GGML_API void ggml_opt_init(
19239
19997
  } break;
19240
19998
  case GGML_OPT_LBFGS:
19241
19999
  {
19242
- opt->lbfgs.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19243
- opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19244
- opt->lbfgs.g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19245
- opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19246
- opt->lbfgs.d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
20000
+ opt->lbfgs.x = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
20001
+ opt->lbfgs.xp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
20002
+ opt->lbfgs.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
20003
+ opt->lbfgs.gp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
20004
+ opt->lbfgs.d = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
19247
20005
  opt->lbfgs.pf = params.past > 0
19248
- ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
20006
+ ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
19249
20007
  : NULL;
19250
- opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
19251
- opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
19252
- opt->lbfgs.lms = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
19253
- opt->lbfgs.lmy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
20008
+ opt->lbfgs.lmal = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
20009
+ opt->lbfgs.lmys = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
20010
+ opt->lbfgs.lms = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
20011
+ opt->lbfgs.lmy = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
19254
20012
  ggml_set_zero(opt->lbfgs.x);
19255
20013
  ggml_set_zero(opt->lbfgs.xp);
19256
20014
  ggml_set_zero(opt->lbfgs.g);
@@ -19856,10 +20614,10 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
19856
20614
  } break;
19857
20615
  case GGUF_TYPE_ARRAY:
19858
20616
  case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break;
19859
- };
20617
+ }
19860
20618
  } break;
19861
20619
  case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type");
19862
- };
20620
+ }
19863
20621
 
19864
20622
  if (!ok) {
19865
20623
  break;
@@ -20099,27 +20857,27 @@ const char * gguf_type_name(enum gguf_type type) {
20099
20857
  return GGUF_TYPE_NAME[type];
20100
20858
  }
20101
20859
 
20102
- int gguf_get_version(struct gguf_context * ctx) {
20860
+ int gguf_get_version(const struct gguf_context * ctx) {
20103
20861
  return ctx->header.version;
20104
20862
  }
20105
20863
 
20106
- size_t gguf_get_alignment(struct gguf_context * ctx) {
20864
+ size_t gguf_get_alignment(const struct gguf_context * ctx) {
20107
20865
  return ctx->alignment;
20108
20866
  }
20109
20867
 
20110
- size_t gguf_get_data_offset(struct gguf_context * ctx) {
20868
+ size_t gguf_get_data_offset(const struct gguf_context * ctx) {
20111
20869
  return ctx->offset;
20112
20870
  }
20113
20871
 
20114
- void * gguf_get_data(struct gguf_context * ctx) {
20872
+ void * gguf_get_data(const struct gguf_context * ctx) {
20115
20873
  return ctx->data;
20116
20874
  }
20117
20875
 
20118
- int gguf_get_n_kv(struct gguf_context * ctx) {
20876
+ int gguf_get_n_kv(const struct gguf_context * ctx) {
20119
20877
  return ctx->header.n_kv;
20120
20878
  }
20121
20879
 
20122
- int gguf_find_key(struct gguf_context * ctx, const char * key) {
20880
+ int gguf_find_key(const struct gguf_context * ctx, const char * key) {
20123
20881
  // return -1 if key not found
20124
20882
  int keyfound = -1;
20125
20883
 
@@ -20135,85 +20893,101 @@ int gguf_find_key(struct gguf_context * ctx, const char * key) {
20135
20893
  return keyfound;
20136
20894
  }
20137
20895
 
20138
- const char * gguf_get_key(struct gguf_context * ctx, int i) {
20139
- return ctx->kv[i].key.data;
20896
+ const char * gguf_get_key(const struct gguf_context * ctx, int key_id) {
20897
+ return ctx->kv[key_id].key.data;
20140
20898
  }
20141
20899
 
20142
- enum gguf_type gguf_get_kv_type(struct gguf_context * ctx, int i) {
20143
- return ctx->kv[i].type;
20900
+ enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) {
20901
+ return ctx->kv[key_id].type;
20144
20902
  }
20145
20903
 
20146
- enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i) {
20147
- return ctx->kv[i].value.arr.type;
20904
+ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) {
20905
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
20906
+ return ctx->kv[key_id].value.arr.type;
20148
20907
  }
20149
20908
 
20150
- const void * gguf_get_arr_data(struct gguf_context * ctx, int i) {
20151
- return ctx->kv[i].value.arr.data;
20909
+ const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) {
20910
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
20911
+ return ctx->kv[key_id].value.arr.data;
20152
20912
  }
20153
20913
 
20154
- const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i) {
20914
+ const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
20915
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
20155
20916
  struct gguf_kv * kv = &ctx->kv[key_id];
20156
20917
  struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
20157
20918
  return str->data;
20158
20919
  }
20159
20920
 
20160
- int gguf_get_arr_n(struct gguf_context * ctx, int i) {
20161
- return ctx->kv[i].value.arr.n;
20921
+ int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
20922
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
20923
+ return ctx->kv[key_id].value.arr.n;
20162
20924
  }
20163
20925
 
20164
- uint8_t gguf_get_val_u8(struct gguf_context * ctx, int i) {
20165
- return ctx->kv[i].value.uint8;
20926
+ uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) {
20927
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8);
20928
+ return ctx->kv[key_id].value.uint8;
20166
20929
  }
20167
20930
 
20168
- int8_t gguf_get_val_i8(struct gguf_context * ctx, int i) {
20169
- return ctx->kv[i].value.int8;
20931
+ int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) {
20932
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8);
20933
+ return ctx->kv[key_id].value.int8;
20170
20934
  }
20171
20935
 
20172
- uint16_t gguf_get_val_u16(struct gguf_context * ctx, int i) {
20173
- return ctx->kv[i].value.uint16;
20936
+ uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) {
20937
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16);
20938
+ return ctx->kv[key_id].value.uint16;
20174
20939
  }
20175
20940
 
20176
- int16_t gguf_get_val_i16(struct gguf_context * ctx, int i) {
20177
- return ctx->kv[i].value.int16;
20941
+ int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) {
20942
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16);
20943
+ return ctx->kv[key_id].value.int16;
20178
20944
  }
20179
20945
 
20180
- uint32_t gguf_get_val_u32(struct gguf_context * ctx, int i) {
20181
- return ctx->kv[i].value.uint32;
20946
+ uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) {
20947
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32);
20948
+ return ctx->kv[key_id].value.uint32;
20182
20949
  }
20183
20950
 
20184
- int32_t gguf_get_val_i32(struct gguf_context * ctx, int i) {
20185
- return ctx->kv[i].value.int32;
20951
+ int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) {
20952
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32);
20953
+ return ctx->kv[key_id].value.int32;
20186
20954
  }
20187
20955
 
20188
- float gguf_get_val_f32(struct gguf_context * ctx, int i) {
20189
- return ctx->kv[i].value.float32;
20956
+ float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) {
20957
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32);
20958
+ return ctx->kv[key_id].value.float32;
20190
20959
  }
20191
20960
 
20192
- uint64_t gguf_get_val_u64(struct gguf_context * ctx, int i) {
20193
- return ctx->kv[i].value.uint64;
20961
+ uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) {
20962
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64);
20963
+ return ctx->kv[key_id].value.uint64;
20194
20964
  }
20195
20965
 
20196
- int64_t gguf_get_val_i64(struct gguf_context * ctx, int i) {
20197
- return ctx->kv[i].value.int64;
20966
+ int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) {
20967
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64);
20968
+ return ctx->kv[key_id].value.int64;
20198
20969
  }
20199
20970
 
20200
- double gguf_get_val_f64(struct gguf_context * ctx, int i) {
20201
- return ctx->kv[i].value.float64;
20971
+ double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) {
20972
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64);
20973
+ return ctx->kv[key_id].value.float64;
20202
20974
  }
20203
20975
 
20204
- bool gguf_get_val_bool(struct gguf_context * ctx, int i) {
20205
- return ctx->kv[i].value.bool_;
20976
+ bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) {
20977
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL);
20978
+ return ctx->kv[key_id].value.bool_;
20206
20979
  }
20207
20980
 
20208
- const char * gguf_get_val_str (struct gguf_context * ctx, int i) {
20209
- return ctx->kv[i].value.str.data;
20981
+ const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
20982
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
20983
+ return ctx->kv[key_id].value.str.data;
20210
20984
  }
20211
20985
 
20212
- int gguf_get_n_tensors(struct gguf_context * ctx) {
20986
+ int gguf_get_n_tensors(const struct gguf_context * ctx) {
20213
20987
  return ctx->header.n_tensors;
20214
20988
  }
20215
20989
 
20216
- int gguf_find_tensor(struct gguf_context * ctx, const char * name) {
20990
+ int gguf_find_tensor(const struct gguf_context * ctx, const char * name) {
20217
20991
  // return -1 if tensor not found
20218
20992
  int tensorfound = -1;
20219
20993
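The gguf getters become const-correct, rename their index parameter to key_id, and assert the stored kv type before touching the union, so a type mismatch now fails loudly instead of silently returning a reinterpreted value. A typical read path under the updated API (the key name is only an example):

    #include "ggml.h"
    #include <stdio.h>

    // read a string-typed metadata key from a GGUF file
    int main(void) {
        struct gguf_init_params params = { /*.no_alloc =*/ false, /*.ctx =*/ NULL };
        struct gguf_context * gctx = gguf_init_from_file("model.gguf", params);
        if (!gctx) {
            return 1;
        }

        const int id = gguf_find_key(gctx, "general.architecture");
        if (id >= 0 && gguf_get_kv_type(gctx, id) == GGUF_TYPE_STRING) {
            printf("architecture: %s\n", gguf_get_val_str(gctx, id));
        }

        gguf_free(gctx);
        return 0;
    }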
 
@@ -20229,11 +21003,11 @@ int gguf_find_tensor(struct gguf_context * ctx, const char * name) {
20229
21003
  return tensorfound;
20230
21004
  }
20231
21005
 
20232
- size_t gguf_get_tensor_offset(struct gguf_context * ctx, int i) {
21006
+ size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i) {
20233
21007
  return ctx->infos[i].offset;
20234
21008
  }
20235
21009
 
20236
- char * gguf_get_tensor_name(struct gguf_context * ctx, int i) {
21010
+ char * gguf_get_tensor_name(const struct gguf_context * ctx, int i) {
20237
21011
  return ctx->infos[i].name.data;
20238
21012
  }
20239
21013
 
@@ -20516,7 +21290,7 @@ static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_si
20516
21290
  buf->offset += el_size;
20517
21291
  }
20518
21292
 
20519
- static void gguf_write_to_buf(struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
21293
+ static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
20520
21294
  // write header
20521
21295
  gguf_bwrite_el(buf, &ctx->header.magic, sizeof(ctx->header.magic));
20522
21296
  gguf_bwrite_el(buf, &ctx->header.version, sizeof(ctx->header.version));
@@ -20571,10 +21345,10 @@ static void gguf_write_to_buf(struct gguf_context * ctx, struct gguf_buf * buf,
20571
21345
  } break;
20572
21346
  case GGUF_TYPE_ARRAY:
20573
21347
  case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break;
20574
- };
21348
+ }
20575
21349
  } break;
20576
21350
  case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type");
20577
- };
21351
+ }
20578
21352
  }
20579
21353
 
20580
21354
  // write tensor infos
@@ -20631,7 +21405,7 @@ static void gguf_write_to_buf(struct gguf_context * ctx, struct gguf_buf * buf,
20631
21405
  }
20632
21406
  }
20633
21407
 
20634
- void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only_meta) {
21408
+ void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
20635
21409
  FILE * file = fopen(fname, "wb");
20636
21410
  if (!file) {
20637
21411
  GGML_ASSERT(false && "failed to open file for writing");
@@ -20648,7 +21422,7 @@ void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only
20648
21422
  fclose(file);
20649
21423
  }
20650
21424
 
20651
- size_t gguf_get_meta_size(struct gguf_context * ctx) {
21425
+ size_t gguf_get_meta_size(const struct gguf_context * ctx) {
20652
21426
  // no allocs - only compute size
20653
21427
  struct gguf_buf buf = gguf_buf_init(0);
20654
21428
 
@@ -20657,7 +21431,7 @@ size_t gguf_get_meta_size(struct gguf_context * ctx) {
20657
21431
  return buf.offset;
20658
21432
  }
20659
21433
 
20660
- void gguf_get_meta_data(struct gguf_context * ctx, void * data) {
21434
+ void gguf_get_meta_data(const struct gguf_context * ctx, void * data) {
20661
21435
  struct gguf_buf buf = gguf_buf_init(16*1024);
20662
21436
 
20663
21437
  gguf_write_to_buf(ctx, &buf, true);
@@ -20733,6 +21507,14 @@ int ggml_cpu_has_arm_fma(void) {
20733
21507
  #endif
20734
21508
  }
20735
21509
 
21510
+ int ggml_cpu_has_metal(void) {
21511
+ #if defined(GGML_USE_METAL)
21512
+ return 1;
21513
+ #else
21514
+ return 0;
21515
+ #endif
21516
+ }
21517
+
20736
21518
  int ggml_cpu_has_f16c(void) {
20737
21519
  #if defined(__F16C__)
20738
21520
  return 1;