llama_cpp 0.5.3 → 0.6.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -89,7 +89,9 @@ static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(vo
89
89
 
90
90
  static int pthread_join(pthread_t thread, void * unused) {
91
91
  (void) unused;
92
- return (int) WaitForSingleObject(thread, INFINITE);
92
+ int ret = (int) WaitForSingleObject(thread, INFINITE);
93
+ CloseHandle(thread);
94
+ return ret;
93
95
  }
94
96
 
95
97
  static int sched_yield (void) {
@@ -134,6 +136,7 @@ typedef void * thread_ret_t;
134
136
 
135
137
  #define GGML_SOFT_MAX_UNROLL 4
136
138
  #define GGML_VEC_DOT_UNROLL 2
139
+ #define GGML_VEC_MAD_UNROLL 32
137
140
 
138
141
  //
139
142
  // logging
@@ -242,18 +245,18 @@ inline static void * ggml_aligned_malloc(size_t size) {
242
245
  //
243
246
 
244
247
  #define GGML_TENSOR_UNARY_OP_LOCALS \
245
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \
246
- GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); \
247
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); \
248
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
248
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
249
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
250
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
251
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
249
252
 
250
253
  #define GGML_TENSOR_BINARY_OP_LOCALS \
251
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \
252
- GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); \
253
- GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); \
254
- GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); \
255
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); \
256
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
254
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
255
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
256
+ GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
257
+ GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
258
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
259
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
257
260
 
258
261
  #if defined(GGML_USE_ACCELERATE)
259
262
  #include <Accelerate/Accelerate.h>
@@ -1863,7 +1866,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
1863
1866
  #define GGML_F16x8_ADD vaddq_f16
1864
1867
  #define GGML_F16x8_MUL vmulq_f16
1865
1868
  #define GGML_F16x8_REDUCE(res, x) \
1866
- { \
1869
+ do { \
1867
1870
  int offset = GGML_F16_ARR >> 1; \
1868
1871
  for (int i = 0; i < offset; ++i) { \
1869
1872
  x[i] = vaddq_f16(x[i], x[offset+i]); \
@@ -1879,7 +1882,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
1879
1882
  const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
1880
1883
  const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
1881
1884
  res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \
1882
- }
1885
+ } while (0)
1883
1886
 
1884
1887
  #define GGML_F16_VEC GGML_F16x8
1885
1888
  #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
@@ -1940,7 +1943,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
1940
1943
  #define GGML_F32x8_ADD _mm256_add_ps
1941
1944
  #define GGML_F32x8_MUL _mm256_mul_ps
1942
1945
  #define GGML_F32x8_REDUCE(res, x) \
1943
- { \
1946
+ do { \
1944
1947
  int offset = GGML_F32_ARR >> 1; \
1945
1948
  for (int i = 0; i < offset; ++i) { \
1946
1949
  x[i] = _mm256_add_ps(x[i], x[offset+i]); \
@@ -1957,7 +1960,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
1957
1960
  _mm256_extractf128_ps(x[0], 1)); \
1958
1961
  const __m128 t1 = _mm_hadd_ps(t0, t0); \
1959
1962
  res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \
1960
- }
1963
+ } while (0)
1961
1964
  // TODO: is this optimal ?
1962
1965
 
1963
1966
  #define GGML_F32_VEC GGML_F32x8
@@ -3707,6 +3710,58 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
3707
3710
  #endif
3708
3711
  }
3709
3712
 
3713
+ // xs and vs are byte strides of x and v
3714
+ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
3715
+
3716
+ const float * restrict x[GGML_VEC_MAD_UNROLL];
3717
+ const float * restrict v[GGML_VEC_MAD_UNROLL];
3718
+
3719
+ for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
3720
+ x[i] = (const float *) ((const char *) xv + i*xs);
3721
+ v[i] = (const float *) ((const char *) vv + i*vs);
3722
+ }
3723
+
3724
+ #if defined(GGML_SIMD)
3725
+ const int np = (n & ~(GGML_F32_STEP - 1));
3726
+
3727
+ GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];
3728
+
3729
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
3730
+ vx[k] = GGML_F32_VEC_SET1(v[k][0]);
3731
+ }
3732
+
3733
+ GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
3734
+ GGML_F32_VEC ay[GGML_F32_ARR];
3735
+
3736
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
3737
+ for (int j = 0; j < GGML_F32_ARR; j++) {
3738
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
3739
+
3740
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
3741
+ ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
3742
+ ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
3743
+ }
3744
+
3745
+ GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
3746
+ }
3747
+ }
3748
+
3749
+ // leftovers
3750
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
3751
+ for (int i = np; i < n; ++i) {
3752
+ y[i] += x[k][i]*v[k][0];
3753
+ }
3754
+ }
3755
+ #else
3756
+ // scalar
3757
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
3758
+ for (int i = 0; i < n; ++i) {
3759
+ y[i] += x[k][i]*v[k][0];
3760
+ }
3761
+ }
3762
+ #endif
3763
+ }
3764
+
3710
3765
  //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
3711
3766
  inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
3712
3767
  #if defined(GGML_USE_ACCELERATE)
@@ -4392,10 +4447,9 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
4392
4447
  static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
4393
4448
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
4394
4449
 
4395
- return
4396
- (t0->ne[1] == t1->ne[1]) &&
4397
- (t0->ne[2] == t1->ne[2]) &&
4398
- (t0->ne[3] == t1->ne[3]);
4450
+ return (t0->ne[1] == t1->ne[1]) &&
4451
+ (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
4452
+ (t1->ne[3]%t0->ne[3] == 0);
4399
4453
  }
4400
4454
 
4401
4455
  enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
@@ -5065,43 +5119,78 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
5065
5119
  return tensor;
5066
5120
  }
5067
5121
 
5122
+ void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) {
5123
+ const int64_t ne2 = tensor->ne[2];
5124
+ const int64_t ne1 = tensor->ne[1];
5125
+ const int64_t ne0 = tensor->ne[0];
5126
+
5127
+ const int64_t i3_ = (i/(ne2*ne1*ne0));
5128
+ const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0);
5129
+ const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0;
5130
+ const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0);
5131
+
5132
+ if (i0) {
5133
+ * i0 = i0_;
5134
+ }
5135
+ if (i1) {
5136
+ * i1 = i1_;
5137
+ }
5138
+ if (i2) {
5139
+ * i2 = i2_;
5140
+ }
5141
+ if (i3) {
5142
+ * i3 = i3_;
5143
+ }
5144
+ }
5145
+
5068
5146
  int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
5147
+ if (!ggml_is_contiguous(tensor)) {
5148
+ int64_t id[4] = { 0, 0, 0, 0 };
5149
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
5150
+ return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]);
5151
+ }
5069
5152
  switch (tensor->type) {
5070
5153
  case GGML_TYPE_I8:
5071
5154
  {
5072
5155
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
5073
5156
  return ((int8_t *)(tensor->data))[i];
5074
- } break;
5157
+ }
5075
5158
  case GGML_TYPE_I16:
5076
5159
  {
5077
5160
  GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
5078
5161
  return ((int16_t *)(tensor->data))[i];
5079
- } break;
5162
+ }
5080
5163
  case GGML_TYPE_I32:
5081
5164
  {
5082
5165
  GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
5083
5166
  return ((int32_t *)(tensor->data))[i];
5084
- } break;
5167
+ }
5085
5168
  case GGML_TYPE_F16:
5086
5169
  {
5087
5170
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
5088
5171
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
5089
- } break;
5172
+ }
5090
5173
  case GGML_TYPE_F32:
5091
5174
  {
5092
5175
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
5093
5176
  return ((float *)(tensor->data))[i];
5094
- } break;
5177
+ }
5095
5178
  default:
5096
5179
  {
5097
5180
  GGML_ASSERT(false);
5098
- } break;
5181
+ }
5099
5182
  }
5100
5183
 
5101
5184
  return 0.0f;
5102
5185
  }
5103
5186
 
5104
5187
  void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
5188
+ if (!ggml_is_contiguous(tensor)) {
5189
+ int64_t id[4] = { 0, 0, 0, 0 };
5190
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
5191
+ ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value);
5192
+ return;
5193
+ }
5105
5194
  switch (tensor->type) {
5106
5195
  case GGML_TYPE_I8:
5107
5196
  {
@@ -5135,43 +5224,104 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
5135
5224
  }
5136
5225
  }
5137
5226
 
5227
+ int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
5228
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
5229
+ switch (tensor->type) {
5230
+ case GGML_TYPE_I8:
5231
+ return ((int8_t *) data)[0];
5232
+ case GGML_TYPE_I16:
5233
+ return ((int16_t *) data)[0];
5234
+ case GGML_TYPE_I32:
5235
+ return ((int32_t *) data)[0];
5236
+ case GGML_TYPE_F16:
5237
+ return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
5238
+ case GGML_TYPE_F32:
5239
+ return ((float *) data)[0];
5240
+ default:
5241
+ GGML_ASSERT(false);
5242
+ }
5243
+
5244
+ return 0.0f;
5245
+ }
5246
+
5247
+ void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) {
5248
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
5249
+ switch (tensor->type) {
5250
+ case GGML_TYPE_I8:
5251
+ {
5252
+ ((int8_t *)(data))[0] = value;
5253
+ } break;
5254
+ case GGML_TYPE_I16:
5255
+ {
5256
+ ((int16_t *)(data))[0] = value;
5257
+ } break;
5258
+ case GGML_TYPE_I32:
5259
+ {
5260
+ ((int32_t *)(data))[0] = value;
5261
+ } break;
5262
+ case GGML_TYPE_F16:
5263
+ {
5264
+ ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
5265
+ } break;
5266
+ case GGML_TYPE_F32:
5267
+ {
5268
+ ((float *)(data))[0] = value;
5269
+ } break;
5270
+ default:
5271
+ {
5272
+ GGML_ASSERT(false);
5273
+ } break;
5274
+ }
5275
+ }
5276
+
5138
5277
  float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
5278
+ if (!ggml_is_contiguous(tensor)) {
5279
+ int64_t id[4] = { 0, 0, 0, 0 };
5280
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
5281
+ return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]);
5282
+ }
5139
5283
  switch (tensor->type) {
5140
5284
  case GGML_TYPE_I8:
5141
5285
  {
5142
5286
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
5143
5287
  return ((int8_t *)(tensor->data))[i];
5144
- } break;
5288
+ }
5145
5289
  case GGML_TYPE_I16:
5146
5290
  {
5147
5291
  GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
5148
5292
  return ((int16_t *)(tensor->data))[i];
5149
- } break;
5293
+ }
5150
5294
  case GGML_TYPE_I32:
5151
5295
  {
5152
5296
  GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
5153
5297
  return ((int32_t *)(tensor->data))[i];
5154
- } break;
5298
+ }
5155
5299
  case GGML_TYPE_F16:
5156
5300
  {
5157
5301
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
5158
5302
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
5159
- } break;
5303
+ }
5160
5304
  case GGML_TYPE_F32:
5161
5305
  {
5162
5306
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
5163
5307
  return ((float *)(tensor->data))[i];
5164
- } break;
5308
+ }
5165
5309
  default:
5166
5310
  {
5167
5311
  GGML_ASSERT(false);
5168
- } break;
5312
+ }
5169
5313
  }
5170
5314
 
5171
5315
  return 0.0f;
5172
5316
  }
5173
5317
 
5174
5318
  void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
5319
+ if (!ggml_is_contiguous(tensor)) {
5320
+ int64_t id[4] = { 0, 0, 0, 0 };
5321
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
5322
+ ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value);
5323
+ return;
5324
+ }
5175
5325
  switch (tensor->type) {
5176
5326
  case GGML_TYPE_I8:
5177
5327
  {
@@ -5205,6 +5355,56 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
5205
5355
  }
5206
5356
  }
5207
5357
 
5358
+ float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
5359
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
5360
+ switch (tensor->type) {
5361
+ case GGML_TYPE_I8:
5362
+ return ((int8_t *) data)[0];
5363
+ case GGML_TYPE_I16:
5364
+ return ((int16_t *) data)[0];
5365
+ case GGML_TYPE_I32:
5366
+ return ((int32_t *) data)[0];
5367
+ case GGML_TYPE_F16:
5368
+ return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
5369
+ case GGML_TYPE_F32:
5370
+ return ((float *) data)[0];
5371
+ default:
5372
+ GGML_ASSERT(false);
5373
+ }
5374
+
5375
+ return 0.0f;
5376
+ }
5377
+
5378
+ void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) {
5379
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
5380
+ switch (tensor->type) {
5381
+ case GGML_TYPE_I8:
5382
+ {
5383
+ ((int8_t *)(data))[0] = value;
5384
+ } break;
5385
+ case GGML_TYPE_I16:
5386
+ {
5387
+ ((int16_t *)(data))[0] = value;
5388
+ } break;
5389
+ case GGML_TYPE_I32:
5390
+ {
5391
+ ((int32_t *)(data))[0] = value;
5392
+ } break;
5393
+ case GGML_TYPE_F16:
5394
+ {
5395
+ ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
5396
+ } break;
5397
+ case GGML_TYPE_F32:
5398
+ {
5399
+ ((float *)(data))[0] = value;
5400
+ } break;
5401
+ default:
5402
+ {
5403
+ GGML_ASSERT(false);
5404
+ } break;
5405
+ }
5406
+ }
5407
+
5208
5408
  void * ggml_get_data(const struct ggml_tensor * tensor) {
5209
5409
  return tensor->data;
5210
5410
  }
@@ -5347,6 +5547,44 @@ struct ggml_tensor * ggml_add_inplace(
5347
5547
  return ggml_add_impl(ctx, a, b, true);
5348
5548
  }
5349
5549
 
5550
+ // ggml_add_cast
5551
+
5552
+ static struct ggml_tensor * ggml_add_cast_impl(
5553
+ struct ggml_context * ctx,
5554
+ struct ggml_tensor * a,
5555
+ struct ggml_tensor * b,
5556
+ enum ggml_type type) {
5557
+ // TODO: support less-strict constraint
5558
+ // GGML_ASSERT(ggml_can_repeat(b, a));
5559
+ GGML_ASSERT(ggml_can_repeat_rows(b, a));
5560
+ GGML_ASSERT(ggml_is_quantized(a->type)); // currently only supported for quantized input
5561
+
5562
+ bool is_node = false;
5563
+
5564
+ if (a->grad || b->grad) {
5565
+ // TODO: support backward pass for broadcasting
5566
+ GGML_ASSERT(ggml_are_same_shape(a, b));
5567
+ is_node = true;
5568
+ }
5569
+
5570
+ struct ggml_tensor * result = ggml_new_tensor(ctx, type, a->n_dims, a->ne);
5571
+
5572
+ result->op = GGML_OP_ADD;
5573
+ result->grad = is_node ? ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne) : NULL;
5574
+ result->src[0] = a;
5575
+ result->src[1] = b;
5576
+
5577
+ return result;
5578
+ }
5579
+
5580
+ struct ggml_tensor * ggml_add_cast(
5581
+ struct ggml_context * ctx,
5582
+ struct ggml_tensor * a,
5583
+ struct ggml_tensor * b,
5584
+ enum ggml_type type) {
5585
+ return ggml_add_cast_impl(ctx, a, b, type);
5586
+ }
5587
+
5350
5588
  // ggml_add1
5351
5589
 
5352
5590
  static struct ggml_tensor * ggml_add1_impl(
@@ -5783,7 +6021,6 @@ struct ggml_tensor * ggml_repeat(
5783
6021
  result->op = GGML_OP_REPEAT;
5784
6022
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5785
6023
  result->src[0] = a;
5786
- result->src[1] = b;
5787
6024
 
5788
6025
  return result;
5789
6026
  }
@@ -5811,7 +6048,6 @@ struct ggml_tensor * ggml_repeat_back(
5811
6048
  result->op = GGML_OP_REPEAT_BACK;
5812
6049
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5813
6050
  result->src[0] = a;
5814
- result->src[1] = b;
5815
6051
 
5816
6052
  return result;
5817
6053
  }
@@ -6186,8 +6422,9 @@ struct ggml_tensor * ggml_out_prod(
6186
6422
  is_node = true;
6187
6423
  }
6188
6424
 
6189
- const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
6190
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
6425
+ // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]
6426
+ const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };
6427
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
6191
6428
 
6192
6429
  result->op = GGML_OP_OUT_PROD;
6193
6430
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6406,6 +6643,54 @@ struct ggml_tensor * ggml_cont_inplace(
6406
6643
  return ggml_cont_impl(ctx, a, true);
6407
6644
  }
6408
6645
 
6646
+
6647
+ // make contiguous, with new shape
6648
+ GGML_API struct ggml_tensor * ggml_cont_1d(
6649
+ struct ggml_context * ctx,
6650
+ struct ggml_tensor * a,
6651
+ int64_t ne0) {
6652
+ return ggml_cont_4d(ctx, a, ne0, 1, 1, 1);
6653
+ }
6654
+
6655
+ GGML_API struct ggml_tensor * ggml_cont_2d(
6656
+ struct ggml_context * ctx,
6657
+ struct ggml_tensor * a,
6658
+ int64_t ne0,
6659
+ int64_t ne1) {
6660
+ return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1);
6661
+ }
6662
+
6663
+ GGML_API struct ggml_tensor * ggml_cont_3d(
6664
+ struct ggml_context * ctx,
6665
+ struct ggml_tensor * a,
6666
+ int64_t ne0,
6667
+ int64_t ne1,
6668
+ int64_t ne2) {
6669
+ return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1);
6670
+ }
6671
+
6672
+ struct ggml_tensor * ggml_cont_4d(
6673
+ struct ggml_context * ctx,
6674
+ struct ggml_tensor * a,
6675
+ int64_t ne0,
6676
+ int64_t ne1,
6677
+ int64_t ne2,
6678
+ int64_t ne3) {
6679
+ GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3));
6680
+
6681
+ bool is_node = false;
6682
+
6683
+ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
6684
+ ggml_format_name(result, "%s (cont)", a->name);
6685
+
6686
+ result->op = GGML_OP_CONT;
6687
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6688
+ result->src[0] = a;
6689
+
6690
+ return result;
6691
+ }
6692
+
6693
+
6409
6694
  // ggml_reshape
6410
6695
 
6411
6696
  struct ggml_tensor * ggml_reshape(
@@ -6413,7 +6698,7 @@ struct ggml_tensor * ggml_reshape(
6413
6698
  struct ggml_tensor * a,
6414
6699
  struct ggml_tensor * b) {
6415
6700
  GGML_ASSERT(ggml_is_contiguous(a));
6416
- GGML_ASSERT(ggml_is_contiguous(b));
6701
+ // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous.
6417
6702
  GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
6418
6703
 
6419
6704
  bool is_node = false;
@@ -6786,7 +7071,6 @@ struct ggml_tensor * ggml_get_rows_back(
6786
7071
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6787
7072
  result->src[0] = a;
6788
7073
  result->src[1] = b;
6789
- result->src[2] = c;
6790
7074
 
6791
7075
  return result;
6792
7076
  }
@@ -6968,7 +7252,7 @@ struct ggml_tensor * ggml_soft_max_back_inplace(
6968
7252
  static struct ggml_tensor * ggml_rope_impl(
6969
7253
  struct ggml_context * ctx,
6970
7254
  struct ggml_tensor * a,
6971
- int n_past,
7255
+ struct ggml_tensor * b,
6972
7256
  int n_dims,
6973
7257
  int mode,
6974
7258
  int n_ctx,
@@ -6977,7 +7261,10 @@ static struct ggml_tensor * ggml_rope_impl(
6977
7261
  float xpos_base,
6978
7262
  bool xpos_down,
6979
7263
  bool inplace) {
6980
- GGML_ASSERT(n_past >= 0);
7264
+ GGML_ASSERT(ggml_is_vector(b));
7265
+ GGML_ASSERT(b->type == GGML_TYPE_I32);
7266
+ GGML_ASSERT(a->ne[2] == b->ne[0]);
7267
+
6981
7268
  bool is_node = false;
6982
7269
 
6983
7270
  if (a->grad) {
@@ -6986,7 +7273,7 @@ static struct ggml_tensor * ggml_rope_impl(
6986
7273
 
6987
7274
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
6988
7275
 
6989
- int32_t params[8] = { n_past, n_dims, mode, n_ctx };
7276
+ int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
6990
7277
  memcpy(params + 4, &freq_base, sizeof(float));
6991
7278
  memcpy(params + 5, &freq_scale, sizeof(float));
6992
7279
  memcpy(params + 6, &xpos_base, sizeof(float));
@@ -6996,6 +7283,7 @@ static struct ggml_tensor * ggml_rope_impl(
6996
7283
  result->op = GGML_OP_ROPE;
6997
7284
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6998
7285
  result->src[0] = a;
7286
+ result->src[1] = b;
6999
7287
 
7000
7288
  return result;
7001
7289
  }
@@ -7003,55 +7291,55 @@ static struct ggml_tensor * ggml_rope_impl(
7003
7291
  struct ggml_tensor * ggml_rope(
7004
7292
  struct ggml_context * ctx,
7005
7293
  struct ggml_tensor * a,
7006
- int n_past,
7294
+ struct ggml_tensor * b,
7007
7295
  int n_dims,
7008
7296
  int mode,
7009
7297
  int n_ctx) {
7010
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
7298
+ return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
7011
7299
  }
7012
7300
 
7013
7301
  struct ggml_tensor * ggml_rope_inplace(
7014
7302
  struct ggml_context * ctx,
7015
7303
  struct ggml_tensor * a,
7016
- int n_past,
7304
+ struct ggml_tensor * b,
7017
7305
  int n_dims,
7018
7306
  int mode,
7019
7307
  int n_ctx) {
7020
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
7308
+ return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
7021
7309
  }
7022
7310
 
7023
7311
  struct ggml_tensor * ggml_rope_custom(
7024
7312
  struct ggml_context * ctx,
7025
7313
  struct ggml_tensor * a,
7026
- int n_past,
7314
+ struct ggml_tensor * b,
7027
7315
  int n_dims,
7028
7316
  int mode,
7029
7317
  int n_ctx,
7030
7318
  float freq_base,
7031
7319
  float freq_scale) {
7032
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
7320
+ return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
7033
7321
  }
7034
7322
 
7035
7323
  struct ggml_tensor * ggml_rope_custom_inplace(
7036
7324
  struct ggml_context * ctx,
7037
7325
  struct ggml_tensor * a,
7038
- int n_past,
7326
+ struct ggml_tensor * b,
7039
7327
  int n_dims,
7040
7328
  int mode,
7041
7329
  int n_ctx,
7042
7330
  float freq_base,
7043
7331
  float freq_scale) {
7044
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
7332
+ return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
7045
7333
  }
7046
7334
 
7047
7335
  struct ggml_tensor * ggml_rope_xpos_inplace(
7048
7336
  struct ggml_context * ctx,
7049
7337
  struct ggml_tensor * a,
7050
- int n_past,
7338
+ struct ggml_tensor * b,
7051
7339
  int n_dims,
7052
7340
  float base,
7053
7341
  bool down) {
7054
- return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
7342
+ return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
7055
7343
  }
7056
7344
 
7057
7345
  // ggml_rope_back
@@ -7059,7 +7347,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
7059
7347
  struct ggml_tensor * ggml_rope_back(
7060
7348
  struct ggml_context * ctx,
7061
7349
  struct ggml_tensor * a,
7062
- int n_past,
7350
+ struct ggml_tensor * b,
7063
7351
  int n_dims,
7064
7352
  int mode,
7065
7353
  int n_ctx,
@@ -7067,7 +7355,10 @@ struct ggml_tensor * ggml_rope_back(
7067
7355
  float freq_scale,
7068
7356
  float xpos_base,
7069
7357
  bool xpos_down) {
7070
- GGML_ASSERT(n_past >= 0);
7358
+ GGML_ASSERT(ggml_is_vector(b));
7359
+ GGML_ASSERT(b->type == GGML_TYPE_I32);
7360
+ GGML_ASSERT(a->ne[2] == b->ne[0]);
7361
+
7071
7362
  GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
7072
7363
 
7073
7364
  bool is_node = false;
@@ -7078,7 +7369,7 @@ struct ggml_tensor * ggml_rope_back(
7078
7369
 
7079
7370
  struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
7080
7371
 
7081
- int32_t params[8] = { n_past, n_dims, mode, n_ctx };
7372
+ int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
7082
7373
  memcpy(params + 4, &freq_base, sizeof(float));
7083
7374
  memcpy(params + 5, &freq_scale, sizeof(float));
7084
7375
  memcpy(params + 6, &xpos_base, sizeof(float));
@@ -7088,6 +7379,7 @@ struct ggml_tensor * ggml_rope_back(
7088
7379
  result->op = GGML_OP_ROPE_BACK;
7089
7380
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7090
7381
  result->src[0] = a;
7382
+ result->src[1] = b;
7091
7383
 
7092
7384
  return result;
7093
7385
  }
@@ -7484,27 +7776,30 @@ struct ggml_tensor * ggml_flash_attn_back(
7484
7776
 
7485
7777
  // d shape [D,N,ne2,ne3]
7486
7778
  // q shape [D,N,ne2,ne3]
7487
- // k shape [D,M,ne2,ne3]
7488
- // v shape [M,D,ne2,ne3]
7779
+ // k shape [D,M,kvne2,ne3]
7780
+ // v shape [M,D,kvne2,ne3]
7489
7781
 
7490
- const int64_t D = q->ne[0];
7491
- const int64_t N = q->ne[1];
7492
- const int64_t M = k->ne[1];
7493
- const int64_t ne2 = q->ne[2];
7494
- const int64_t ne3 = q->ne[3];
7782
+ const int64_t D = q->ne[0];
7783
+ const int64_t N = q->ne[1];
7784
+ const int64_t M = k->ne[1];
7785
+ const int64_t ne2 = q->ne[2];
7786
+ const int64_t ne3 = q->ne[3];
7787
+ const int64_t kvne2 = k->ne[2];
7495
7788
 
7496
7789
  GGML_ASSERT(k->ne[0] == D);
7497
7790
  GGML_ASSERT(v->ne[0] == M);
7498
7791
  GGML_ASSERT(v->ne[1] == D);
7499
7792
  GGML_ASSERT(d->ne[0] == D);
7500
7793
  GGML_ASSERT(d->ne[1] == N);
7501
- GGML_ASSERT(k->ne[2] == ne2);
7794
+ GGML_ASSERT(k->ne[2] == kvne2);
7502
7795
  GGML_ASSERT(k->ne[3] == ne3);
7503
- GGML_ASSERT(v->ne[2] == ne2);
7796
+ GGML_ASSERT(v->ne[2] == kvne2);
7504
7797
  GGML_ASSERT(v->ne[3] == ne3);
7505
7798
  GGML_ASSERT(d->ne[2] == ne2);
7506
7799
  GGML_ASSERT(d->ne[3] == ne3);
7507
7800
 
7801
+ GGML_ASSERT(ne2 % kvne2 == 0);
7802
+
7508
7803
  bool is_node = false;
7509
7804
 
7510
7805
  if (q->grad || k->grad || v->grad) {
@@ -7514,14 +7809,23 @@ struct ggml_tensor * ggml_flash_attn_back(
7514
7809
  }
7515
7810
 
7516
7811
  // store gradients of q, k and v as continuous tensors concatenated in result.
7517
- // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
7518
- // gradq->data = result->data
7519
- // gradk->data = result->data + nb0*D*N*ne2*ne3
7520
- // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
7521
7812
  // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
7522
- int64_t ne[4] = {D,M+N+M,ne2,ne3};
7813
+ const int64_t elem_q = ggml_nelements(q);
7814
+ const int64_t elem_k = ggml_nelements(k);
7815
+ const int64_t elem_v = ggml_nelements(v);
7523
7816
 
7524
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
7817
+ enum ggml_type result_type = GGML_TYPE_F32;
7818
+ GGML_ASSERT(ggml_blck_size(result_type) == 1);
7819
+ const size_t tsize = ggml_type_size(result_type);
7820
+
7821
+ const size_t offs_q = 0;
7822
+ const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
7823
+ const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
7824
+ const size_t end = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN);
7825
+
7826
+ const size_t nelements = (end + tsize - 1)/tsize;
7827
+
7828
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements);
7525
7829
 
7526
7830
  int32_t masked_i = masked ? 1 : 0;
7527
7831
  ggml_set_op_params(result, &masked_i, sizeof(masked_i));
@@ -8214,7 +8518,7 @@ static void ggml_compute_forward_dup_f16(
8214
8518
  return;
8215
8519
  }
8216
8520
 
8217
- GGML_TENSOR_UNARY_OP_LOCALS;
8521
+ GGML_TENSOR_UNARY_OP_LOCALS
8218
8522
 
8219
8523
  const int ith = params->ith; // thread index
8220
8524
  const int nth = params->nth; // number of threads
@@ -8485,7 +8789,7 @@ static void ggml_compute_forward_dup_f32(
8485
8789
  return;
8486
8790
  }
8487
8791
 
8488
- GGML_TENSOR_UNARY_OP_LOCALS;
8792
+ GGML_TENSOR_UNARY_OP_LOCALS
8489
8793
 
8490
8794
  const int ith = params->ith; // thread index
8491
8795
  const int nth = params->nth; // number of threads
@@ -8766,7 +9070,7 @@ static void ggml_compute_forward_add_f32(
8766
9070
 
8767
9071
  const int nr = ggml_nrows(src0);
8768
9072
 
8769
- GGML_TENSOR_BINARY_OP_LOCALS;
9073
+ GGML_TENSOR_BINARY_OP_LOCALS
8770
9074
 
8771
9075
  GGML_ASSERT( nb0 == sizeof(float));
8772
9076
  GGML_ASSERT(nb00 == sizeof(float));
@@ -8798,8 +9102,6 @@ static void ggml_compute_forward_add_f32(
8798
9102
  #else
8799
9103
  ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
8800
9104
  #endif
8801
- // }
8802
- // }
8803
9105
  }
8804
9106
  } else {
8805
9107
  // src1 is not contiguous
@@ -8841,7 +9143,7 @@ static void ggml_compute_forward_add_f16_f32(
8841
9143
 
8842
9144
  const int nr = ggml_nrows(src0);
8843
9145
 
8844
- GGML_TENSOR_BINARY_OP_LOCALS;
9146
+ GGML_TENSOR_BINARY_OP_LOCALS
8845
9147
 
8846
9148
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
8847
9149
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -8895,7 +9197,7 @@ static void ggml_compute_forward_add_f16_f16(
8895
9197
 
8896
9198
  const int nr = ggml_nrows(src0);
8897
9199
 
8898
- GGML_TENSOR_BINARY_OP_LOCALS;
9200
+ GGML_TENSOR_BINARY_OP_LOCALS
8899
9201
 
8900
9202
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
8901
9203
  GGML_ASSERT(src1->type == GGML_TYPE_F16);
@@ -8946,14 +9248,15 @@ static void ggml_compute_forward_add_q_f32(
8946
9248
 
8947
9249
  const int nr = ggml_nrows(src0);
8948
9250
 
8949
- GGML_TENSOR_BINARY_OP_LOCALS;
9251
+ GGML_TENSOR_BINARY_OP_LOCALS
8950
9252
 
8951
9253
  const int ith = params->ith;
8952
9254
  const int nth = params->nth;
8953
9255
 
8954
9256
  const enum ggml_type type = src0->type;
9257
+ const enum ggml_type dtype = dst->type;
8955
9258
  ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
8956
- ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
9259
+ ggml_from_float_t const quantize_row_q = type_traits[dtype].from_float;
8957
9260
 
8958
9261
  // we don't support permuted src0 or src1
8959
9262
  GGML_ASSERT(nb00 == ggml_type_size(type));
@@ -8965,7 +9268,6 @@ static void ggml_compute_forward_add_q_f32(
8965
9268
  GGML_ASSERT(nb2 <= nb3);
8966
9269
 
8967
9270
  GGML_ASSERT(ggml_is_quantized(src0->type));
8968
- GGML_ASSERT(dst->type == src0->type);
8969
9271
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
8970
9272
 
8971
9273
  // rows per thread
@@ -9003,7 +9305,11 @@ static void ggml_compute_forward_add_q_f32(
9003
9305
  // add src1
9004
9306
  ggml_vec_acc_f32(ne00, wdata, src1_row);
9005
9307
  // quantize row to dst
9006
- quantize_row_q(wdata, dst_row, ne00);
9308
+ if (quantize_row_q != NULL) {
9309
+ quantize_row_q(wdata, dst_row, ne00);
9310
+ } else {
9311
+ memcpy(dst_row, wdata, ne0*nb0);
9312
+ }
9007
9313
  }
9008
9314
  }
9009
9315
 
@@ -9068,7 +9374,7 @@ static void ggml_compute_forward_add1_f32(
9068
9374
 
9069
9375
  const int nr = ggml_nrows(src0);
9070
9376
 
9071
- GGML_TENSOR_UNARY_OP_LOCALS;
9377
+ GGML_TENSOR_UNARY_OP_LOCALS
9072
9378
 
9073
9379
  GGML_ASSERT( nb0 == sizeof(float));
9074
9380
  GGML_ASSERT(nb00 == sizeof(float));
@@ -9123,7 +9429,7 @@ static void ggml_compute_forward_add1_f16_f32(
9123
9429
 
9124
9430
  const int nr = ggml_nrows(src0);
9125
9431
 
9126
- GGML_TENSOR_UNARY_OP_LOCALS;
9432
+ GGML_TENSOR_UNARY_OP_LOCALS
9127
9433
 
9128
9434
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
9129
9435
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -9173,7 +9479,7 @@ static void ggml_compute_forward_add1_f16_f16(
9173
9479
 
9174
9480
  const int nr = ggml_nrows(src0);
9175
9481
 
9176
- GGML_TENSOR_UNARY_OP_LOCALS;
9482
+ GGML_TENSOR_UNARY_OP_LOCALS
9177
9483
 
9178
9484
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
9179
9485
  GGML_ASSERT(src1->type == GGML_TYPE_F16);
@@ -9223,7 +9529,7 @@ static void ggml_compute_forward_add1_q_f32(
9223
9529
 
9224
9530
  const int nr = ggml_nrows(src0);
9225
9531
 
9226
- GGML_TENSOR_UNARY_OP_LOCALS;
9532
+ GGML_TENSOR_UNARY_OP_LOCALS
9227
9533
 
9228
9534
  const enum ggml_type type = src0->type;
9229
9535
  ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
@@ -9351,8 +9657,8 @@ static void ggml_compute_forward_acc_f32(
9351
9657
  const int nr = ggml_nrows(src1);
9352
9658
  const int nc = src1->ne[0];
9353
9659
 
9354
- GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
9355
- GGML_TENSOR_LOCALS(size_t, nb1, src1, nb);
9660
+ GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
9661
+ GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
9356
9662
 
9357
9663
  // src0 and dst as viewed during acc
9358
9664
  const size_t nb0 = ggml_element_size(src0);
@@ -9441,7 +9747,7 @@ static void ggml_compute_forward_sub_f32(
9441
9747
 
9442
9748
  const int nr = ggml_nrows(src0);
9443
9749
 
9444
- GGML_TENSOR_BINARY_OP_LOCALS;
9750
+ GGML_TENSOR_BINARY_OP_LOCALS
9445
9751
 
9446
9752
  GGML_ASSERT( nb0 == sizeof(float));
9447
9753
  GGML_ASSERT(nb00 == sizeof(float));
@@ -9531,7 +9837,7 @@ static void ggml_compute_forward_mul_f32(
9531
9837
 
9532
9838
  const int64_t nr = ggml_nrows(src0);
9533
9839
 
9534
- GGML_TENSOR_BINARY_OP_LOCALS;
9840
+ GGML_TENSOR_BINARY_OP_LOCALS
9535
9841
 
9536
9842
  GGML_ASSERT( nb0 == sizeof(float));
9537
9843
  GGML_ASSERT(nb00 == sizeof(float));
@@ -9622,7 +9928,7 @@ static void ggml_compute_forward_div_f32(
9622
9928
 
9623
9929
  const int nr = ggml_nrows(src0);
9624
9930
 
9625
- GGML_TENSOR_BINARY_OP_LOCALS;
9931
+ GGML_TENSOR_BINARY_OP_LOCALS
9626
9932
 
9627
9933
  GGML_ASSERT( nb0 == sizeof(float));
9628
9934
  GGML_ASSERT(nb00 == sizeof(float));
@@ -9831,8 +10137,8 @@ static void ggml_compute_forward_sum_f32(
9831
10137
  assert(ggml_is_scalar(dst));
9832
10138
  assert(src0->nb[0] == sizeof(float));
9833
10139
 
9834
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
9835
- GGML_TENSOR_LOCALS(size_t, nb0, src0, nb);
10140
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
10141
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
9836
10142
 
9837
10143
  ggml_float sum = 0;
9838
10144
  ggml_float row_sum = 0;
@@ -9863,8 +10169,8 @@ static void ggml_compute_forward_sum_f16(
9863
10169
 
9864
10170
  assert(src0->nb[0] == sizeof(ggml_fp16_t));
9865
10171
 
9866
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
9867
- GGML_TENSOR_LOCALS(size_t, nb0, src0, nb);
10172
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
10173
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
9868
10174
 
9869
10175
  float sum = 0;
9870
10176
  float row_sum = 0;
@@ -9917,7 +10223,7 @@ static void ggml_compute_forward_sum_rows_f32(
9917
10223
  GGML_ASSERT(src0->nb[0] == sizeof(float));
9918
10224
  GGML_ASSERT(dst->nb[0] == sizeof(float));
9919
10225
 
9920
- GGML_TENSOR_UNARY_OP_LOCALS;
10226
+ GGML_TENSOR_UNARY_OP_LOCALS
9921
10227
 
9922
10228
  GGML_ASSERT(ne0 == 1);
9923
10229
  GGML_ASSERT(ne1 == ne01);
@@ -9967,7 +10273,7 @@ static void ggml_compute_forward_mean_f32(
9967
10273
 
9968
10274
  assert(src0->nb[0] == sizeof(float));
9969
10275
 
9970
- GGML_TENSOR_UNARY_OP_LOCALS;
10276
+ GGML_TENSOR_UNARY_OP_LOCALS
9971
10277
 
9972
10278
  assert(ne0 == 1);
9973
10279
  assert(ne1 == ne01);
@@ -10067,7 +10373,7 @@ static void ggml_compute_forward_repeat_f32(
10067
10373
  return;
10068
10374
  }
10069
10375
 
10070
- GGML_TENSOR_UNARY_OP_LOCALS;
10376
+ GGML_TENSOR_UNARY_OP_LOCALS
10071
10377
 
10072
10378
  // guaranteed to be an integer due to the check in ggml_can_repeat
10073
10379
  const int nr0 = (int)(ne0/ne00);
@@ -10099,11 +10405,61 @@ static void ggml_compute_forward_repeat_f32(
10099
10405
  }
10100
10406
  }
10101
10407
 
10408
+ static void ggml_compute_forward_repeat_f16(
10409
+ const struct ggml_compute_params * params,
10410
+ const struct ggml_tensor * src0,
10411
+ struct ggml_tensor * dst) {
10412
+ GGML_ASSERT(params->ith == 0);
10413
+ GGML_ASSERT(ggml_can_repeat(src0, dst));
10414
+
10415
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
10416
+ return;
10417
+ }
10418
+
10419
+ GGML_TENSOR_UNARY_OP_LOCALS;
10420
+
10421
+ // guaranteed to be an integer due to the check in ggml_can_repeat
10422
+ const int nr0 = (int)(ne0/ne00);
10423
+ const int nr1 = (int)(ne1/ne01);
10424
+ const int nr2 = (int)(ne2/ne02);
10425
+ const int nr3 = (int)(ne3/ne03);
10426
+
10427
+ // TODO: support for transposed / permuted tensors
10428
+ GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
10429
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
10430
+
10431
+ // TODO: maybe this is not optimal?
10432
+ for (int i3 = 0; i3 < nr3; i3++) {
10433
+ for (int k3 = 0; k3 < ne03; k3++) {
10434
+ for (int i2 = 0; i2 < nr2; i2++) {
10435
+ for (int k2 = 0; k2 < ne02; k2++) {
10436
+ for (int i1 = 0; i1 < nr1; i1++) {
10437
+ for (int k1 = 0; k1 < ne01; k1++) {
10438
+ for (int i0 = 0; i0 < nr0; i0++) {
10439
+ ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
10440
+ ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
10441
+ // ggml_vec_cpy_f16(ne00, y, x)
10442
+ for (int i = 0; i < ne00; ++i) {
10443
+ y[i] = x[i];
10444
+ }
10445
+ }
10446
+ }
10447
+ }
10448
+ }
10449
+ }
10450
+ }
10451
+ }
10452
+ }
10453
+
10102
10454
  static void ggml_compute_forward_repeat(
10103
10455
  const struct ggml_compute_params * params,
10104
10456
  const struct ggml_tensor * src0,
10105
10457
  struct ggml_tensor * dst) {
10106
10458
  switch (src0->type) {
10459
+ case GGML_TYPE_F16:
10460
+ {
10461
+ ggml_compute_forward_repeat_f16(params, src0, dst);
10462
+ } break;
10107
10463
  case GGML_TYPE_F32:
10108
10464
  {
10109
10465
  ggml_compute_forward_repeat_f32(params, src0, dst);
@@ -10128,7 +10484,7 @@ static void ggml_compute_forward_repeat_back_f32(
10128
10484
  return;
10129
10485
  }
10130
10486
 
10131
- GGML_TENSOR_UNARY_OP_LOCALS;
10487
+ GGML_TENSOR_UNARY_OP_LOCALS
10132
10488
 
10133
10489
  // guaranteed to be an integer due to the check in ggml_can_repeat
10134
10490
  const int nr0 = (int)(ne00/ne0);
@@ -10206,7 +10562,7 @@ static void ggml_compute_forward_concat_f32(
10206
10562
 
10207
10563
  const int ith = params->ith;
10208
10564
 
10209
- GGML_TENSOR_BINARY_OP_LOCALS;
10565
+ GGML_TENSOR_BINARY_OP_LOCALS
10210
10566
 
10211
10567
  // TODO: support for transposed / permuted tensors
10212
10568
  GGML_ASSERT(nb0 == sizeof(float));
@@ -10808,7 +11164,7 @@ static void ggml_compute_forward_norm_f32(
10808
11164
  const int ith = params->ith;
10809
11165
  const int nth = params->nth;
10810
11166
 
10811
- GGML_TENSOR_UNARY_OP_LOCALS;
11167
+ GGML_TENSOR_UNARY_OP_LOCALS
10812
11168
 
10813
11169
  float eps;
10814
11170
  memcpy(&eps, dst->op_params, sizeof(float));
@@ -10877,7 +11233,7 @@ static void ggml_compute_forward_rms_norm_f32(
10877
11233
  const int ith = params->ith;
10878
11234
  const int nth = params->nth;
10879
11235
 
10880
- GGML_TENSOR_UNARY_OP_LOCALS;
11236
+ GGML_TENSOR_UNARY_OP_LOCALS
10881
11237
 
10882
11238
  float eps;
10883
11239
  memcpy(&eps, dst->op_params, sizeof(float));
@@ -10942,7 +11298,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
10942
11298
  const int ith = params->ith;
10943
11299
  const int nth = params->nth;
10944
11300
 
10945
- GGML_TENSOR_BINARY_OP_LOCALS;
11301
+ GGML_TENSOR_BINARY_OP_LOCALS
10946
11302
 
10947
11303
  float eps;
10948
11304
  memcpy(&eps, dst->op_params, sizeof(float));
@@ -11117,7 +11473,7 @@ static void ggml_compute_forward_group_norm_f32(
11117
11473
  const int ith = params->ith;
11118
11474
  const int nth = params->nth;
11119
11475
 
11120
- GGML_TENSOR_UNARY_OP_LOCALS;
11476
+ GGML_TENSOR_UNARY_OP_LOCALS
11121
11477
 
11122
11478
  const float eps = 1e-6f; // TODO: make this a parameter
11123
11479
 
@@ -11228,7 +11584,7 @@ static void ggml_compute_forward_mul_mat(
11228
11584
  int64_t t0 = ggml_perf_time_us();
11229
11585
  UNUSED(t0);
11230
11586
 
11231
- GGML_TENSOR_BINARY_OP_LOCALS;
11587
+ GGML_TENSOR_BINARY_OP_LOCALS
11232
11588
 
11233
11589
  const int ith = params->ith;
11234
11590
  const int nth = params->nth;
@@ -11443,10 +11799,10 @@ static void ggml_compute_forward_out_prod_f32(
11443
11799
  const struct ggml_tensor * src0,
11444
11800
  const struct ggml_tensor * src1,
11445
11801
  struct ggml_tensor * dst) {
11446
- int64_t t0 = ggml_perf_time_us();
11447
- UNUSED(t0);
11802
+ // int64_t t0 = ggml_perf_time_us();
11803
+ // UNUSED(t0);
11448
11804
 
11449
- GGML_TENSOR_BINARY_OP_LOCALS;
11805
+ GGML_TENSOR_BINARY_OP_LOCALS
11450
11806
 
11451
11807
  const int ith = params->ith;
11452
11808
  const int nth = params->nth;
@@ -11485,6 +11841,146 @@ static void ggml_compute_forward_out_prod_f32(
11485
11841
  return;
11486
11842
  }
11487
11843
 
11844
+ // dst[:,:,:,:] = 0
11845
+ // for i2,i3:
11846
+ // for i1:
11847
+ // for i01:
11848
+ // for i0:
11849
+ // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
11850
+
11851
+ // parallelize by last three dimensions
11852
+
11853
+ // total rows in dst
11854
+ const int64_t nr = ne1*ne2*ne3;
11855
+
11856
+ // rows per thread
11857
+ const int64_t dr = (nr + nth - 1)/nth;
11858
+
11859
+ // row range for this thread
11860
+ const int64_t ir0 = dr*ith;
11861
+ const int64_t ir1 = MIN(ir0 + dr, nr);
11862
+
11863
+ // block-tiling attempt
11864
+ const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
11865
+ const int64_t blck_1 = 16;
11866
+
11867
+ for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
11868
+ const int64_t bir1 = MIN(bir + blck_1, ir1);
11869
+ for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
11870
+ const int64_t bne01 = MIN(bi01 + blck_0, ne01);
11871
+ for (int64_t ir = bir; ir < bir1; ++ir) {
11872
+ // dst indices
11873
+ const int64_t i3 = ir/(ne2*ne1);
11874
+ const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
11875
+ const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
11876
+
11877
+ const int64_t i02 = i2;
11878
+ const int64_t i03 = i3;
11879
+
11880
+ //const int64_t i10 = i1;
11881
+ const int64_t i12 = i2;
11882
+ const int64_t i13 = i3;
11883
+
11884
+ #if GGML_VEC_MAD_UNROLL > 2
11885
+ const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
11886
+ for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
11887
+ const int64_t i11 = i01;
11888
+
11889
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
11890
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
11891
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
11892
+
11893
+ ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
11894
+ }
11895
+ for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
11896
+ const int64_t i11 = i01;
11897
+
11898
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
11899
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
11900
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
11901
+
11902
+ ggml_vec_mad_f32(ne0, d, s0, *s1);
11903
+ }
11904
+ #else
11905
+ for (int64_t i01 = bi01; i01 < bne01; ++i01) {
11906
+ const int64_t i11 = i01;
11907
+
11908
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
11909
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
11910
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
11911
+
11912
+ ggml_vec_mad_f32(ne0, d, s0, *s1);
11913
+ }
11914
+ #endif
11915
+ }
11916
+ }
11917
+ }
11918
+
11919
+
11920
+ //int64_t t1 = ggml_perf_time_us();
11921
+ //static int64_t acc = 0;
11922
+ //acc += t1 - t0;
11923
+ //if (t1 - t0 > 10) {
11924
+ // printf("\n");
11925
+ // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
11926
+ // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
11927
+ // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
11928
+ // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
11929
+
11930
+ // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
11931
+ //}
11932
+ }
11933
+
11934
+ static void ggml_compute_forward_out_prod_q_f32(
11935
+ const struct ggml_compute_params * params,
11936
+ const struct ggml_tensor * src0,
11937
+ const struct ggml_tensor * src1,
11938
+ struct ggml_tensor * dst) {
11939
+ // int64_t t0 = ggml_perf_time_us();
11940
+ // UNUSED(t0);
11941
+
11942
+ GGML_TENSOR_BINARY_OP_LOCALS;
11943
+
11944
+ const int ith = params->ith;
11945
+ const int nth = params->nth;
11946
+
11947
+ const enum ggml_type type = src0->type;
11948
+ ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
11949
+
11950
+ GGML_ASSERT(ne02 == ne12);
11951
+ GGML_ASSERT(ne03 == ne13);
11952
+ GGML_ASSERT(ne2 == ne12);
11953
+ GGML_ASSERT(ne3 == ne13);
11954
+
11955
+ // we don't support permuted src0 dim0
11956
+ GGML_ASSERT(nb00 == ggml_type_size(type));
11957
+
11958
+ // dst dim0 cannot be transposed or permuted
11959
+ GGML_ASSERT(nb0 == sizeof(float));
11960
+ // GGML_ASSERT(nb0 <= nb1);
11961
+ // GGML_ASSERT(nb1 <= nb2);
11962
+ // GGML_ASSERT(nb2 <= nb3);
11963
+
11964
+ GGML_ASSERT(ne0 == ne00);
11965
+ GGML_ASSERT(ne1 == ne10);
11966
+ GGML_ASSERT(ne2 == ne02);
11967
+ GGML_ASSERT(ne3 == ne03);
11968
+
11969
+ // nb01 >= nb00 - src0 is not transposed
11970
+ // compute by src0 rows
11971
+
11972
+ // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
11973
+ // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
11974
+
11975
+ if (params->type == GGML_TASK_INIT) {
11976
+ ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
11977
+ return;
11978
+ }
11979
+
11980
+ if (params->type == GGML_TASK_FINALIZE) {
11981
+ return;
11982
+ }
11983
+
11488
11984
  // parallelize by last three dimensions
11489
11985
 
11490
11986
  // total rows in dst
@@ -11504,6 +12000,8 @@ static void ggml_compute_forward_out_prod_f32(
11504
12000
  // for i0:
11505
12001
  // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
11506
12002
 
12003
+ float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
12004
+
11507
12005
  for (int64_t ir = ir0; ir < ir1; ++ir) {
11508
12006
  // dst indices
11509
12007
  const int64_t i3 = ir/(ne2*ne1);
@@ -11524,10 +12022,8 @@ static void ggml_compute_forward_out_prod_f32(
11524
12022
  float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
11525
12023
  float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
11526
12024
 
11527
- ggml_vec_mad_f32(ne0, d, s0, *s1);
11528
- // for (int64_t i0 = 0; i0 < ne0; ++i0) {
11529
- // d[i0] += s0[i0] * s1[i1];
11530
- // }
12025
+ dequantize_row_q(s0, wdata, ne0);
12026
+ ggml_vec_mad_f32(ne0, d, wdata, *s1);
11531
12027
  }
11532
12028
  }
11533
12029
 
@@ -11556,10 +12052,13 @@ static void ggml_compute_forward_out_prod(
11556
12052
  case GGML_TYPE_Q5_0:
11557
12053
  case GGML_TYPE_Q5_1:
11558
12054
  case GGML_TYPE_Q8_0:
11559
- case GGML_TYPE_Q8_1:
12055
+ case GGML_TYPE_Q2_K:
12056
+ case GGML_TYPE_Q3_K:
12057
+ case GGML_TYPE_Q4_K:
12058
+ case GGML_TYPE_Q5_K:
12059
+ case GGML_TYPE_Q6_K:
11560
12060
  {
11561
- GGML_ASSERT(false); // todo
11562
- // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
12061
+ ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
11563
12062
  } break;
11564
12063
  case GGML_TYPE_F16:
11565
12064
  {
@@ -11677,8 +12176,8 @@ static void ggml_compute_forward_set_f32(
11677
12176
  const int nr = ggml_nrows(src1);
11678
12177
  const int nc = src1->ne[0];
11679
12178
 
11680
- GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
11681
- GGML_TENSOR_LOCALS(size_t, nb1, src1, nb);
12179
+ GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
12180
+ GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
11682
12181
 
11683
12182
  // src0 and dst as viewed during set
11684
12183
  const size_t nb0 = ggml_element_size(src0);
@@ -11947,14 +12446,15 @@ static void ggml_compute_forward_get_rows_back_f32_f16(
11947
12446
  const struct ggml_compute_params * params,
11948
12447
  const struct ggml_tensor * src0,
11949
12448
  const struct ggml_tensor * src1,
11950
- const struct ggml_tensor * opt0,
11951
12449
  struct ggml_tensor * dst) {
11952
12450
  GGML_ASSERT(params->ith == 0);
11953
- GGML_ASSERT(ggml_are_same_shape(opt0, dst));
11954
- GGML_ASSERT(ggml_is_contiguous(opt0));
11955
12451
  GGML_ASSERT(ggml_is_contiguous(dst));
11956
12452
 
11957
- ggml_compute_forward_dup_same_cont(params, opt0, dst);
12453
+ // ggml_compute_forward_dup_same_cont(params, opt0, dst);
12454
+
12455
+ if (params->type == GGML_TASK_INIT) {
12456
+ memset(dst->data, 0, ggml_nbytes(dst));
12457
+ }
11958
12458
 
11959
12459
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
11960
12460
  return;
@@ -11980,11 +12480,8 @@ static void ggml_compute_forward_get_rows_back_f32(
11980
12480
  const struct ggml_compute_params * params,
11981
12481
  const struct ggml_tensor * src0,
11982
12482
  const struct ggml_tensor * src1,
11983
- const struct ggml_tensor * opt0,
11984
12483
  struct ggml_tensor * dst) {
11985
12484
  GGML_ASSERT(params->ith == 0);
11986
- GGML_ASSERT(ggml_are_same_shape(opt0, dst));
11987
- GGML_ASSERT(ggml_is_contiguous(opt0));
11988
12485
  GGML_ASSERT(ggml_is_contiguous(dst));
11989
12486
 
11990
12487
  // ggml_compute_forward_dup_same_cont(params, opt0, dst);
@@ -12018,16 +12515,15 @@ static void ggml_compute_forward_get_rows_back(
12018
12515
  const struct ggml_compute_params * params,
12019
12516
  const struct ggml_tensor * src0,
12020
12517
  const struct ggml_tensor * src1,
12021
- const struct ggml_tensor * opt0,
12022
12518
  struct ggml_tensor * dst) {
12023
12519
  switch (src0->type) {
12024
12520
  case GGML_TYPE_F16:
12025
12521
  {
12026
- ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, opt0, dst);
12522
+ ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, dst);
12027
12523
  } break;
12028
12524
  case GGML_TYPE_F32:
12029
12525
  {
12030
- ggml_compute_forward_get_rows_back_f32(params, src0, src1, opt0, dst);
12526
+ ggml_compute_forward_get_rows_back_f32(params, src0, src1, dst);
12031
12527
  } break;
12032
12528
  default:
12033
12529
  {
@@ -12068,7 +12564,7 @@ static void ggml_compute_forward_diag_f32(
12068
12564
 
12069
12565
  // TODO: handle transposed/permuted matrices
12070
12566
 
12071
- GGML_TENSOR_UNARY_OP_LOCALS;
12567
+ GGML_TENSOR_UNARY_OP_LOCALS
12072
12568
 
12073
12569
  GGML_ASSERT(ne00 == ne0);
12074
12570
  GGML_ASSERT(ne00 == ne1);
@@ -12456,13 +12952,11 @@ static void ggml_compute_forward_alibi_f16(
12456
12952
  return;
12457
12953
  }
12458
12954
 
12459
- const int n_past = ((int32_t *) dst->op_params)[0];
12955
+ //const int n_past = ((int32_t *) dst->op_params)[0];
12460
12956
  const int n_head = ((int32_t *) dst->op_params)[1];
12461
12957
  float max_bias;
12462
12958
  memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
12463
12959
 
12464
- assert(n_past >= 0);
12465
-
12466
12960
  const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
12467
12961
  const int ne1 = src0->ne[1]; // seq_len_without_past
12468
12962
  const int ne2 = src0->ne[2]; // n_head -> this is k
@@ -12477,7 +12971,7 @@ static void ggml_compute_forward_alibi_f16(
12477
12971
  //const int nb3 = src0->nb[3];
12478
12972
 
12479
12973
  GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
12480
- GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
12974
+ //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
12481
12975
  GGML_ASSERT(n_head == ne2);
12482
12976
 
12483
12977
  // add alibi to src0 (KQ_scaled)
@@ -12623,8 +13117,8 @@ static void ggml_compute_forward_clamp(
12623
13117
  static void ggml_compute_forward_rope_f32(
12624
13118
  const struct ggml_compute_params * params,
12625
13119
  const struct ggml_tensor * src0,
13120
+ const struct ggml_tensor * src1,
12626
13121
  struct ggml_tensor * dst) {
12627
-
12628
13122
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
12629
13123
  return;
12630
13124
  }
@@ -12634,9 +13128,9 @@ static void ggml_compute_forward_rope_f32(
12634
13128
 
12635
13129
  // these two only relevant for xPos RoPE:
12636
13130
  float xpos_base;
12637
- bool xpos_down;
13131
+ bool xpos_down;
12638
13132
 
12639
- const int n_past = ((int32_t *) dst->op_params)[0];
13133
+ //const int n_past = ((int32_t *) dst->op_params)[0];
12640
13134
  const int n_dims = ((int32_t *) dst->op_params)[1];
12641
13135
  const int mode = ((int32_t *) dst->op_params)[2];
12642
13136
  const int n_ctx = ((int32_t *) dst->op_params)[3];
@@ -12645,9 +13139,7 @@ static void ggml_compute_forward_rope_f32(
12645
13139
  memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
12646
13140
  memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
12647
13141
 
12648
- assert(n_past >= 0);
12649
-
12650
- GGML_TENSOR_UNARY_OP_LOCALS;
13142
+ GGML_TENSOR_UNARY_OP_LOCALS
12651
13143
 
12652
13144
  //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
12653
13145
  //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -12677,9 +13169,11 @@ static void ggml_compute_forward_rope_f32(
12677
13169
  const bool is_neox = mode & 2;
12678
13170
  const bool is_glm = mode & 4;
12679
13171
 
13172
+ const int32_t * pos = (const int32_t *) src1->data;
13173
+
12680
13174
  for (int64_t i3 = 0; i3 < ne3; i3++) {
12681
- for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
12682
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
13175
+ for (int64_t i2 = 0; i2 < ne2; i2++) {
13176
+ const int64_t p = pos[i2];
12683
13177
  for (int64_t i1 = 0; i1 < ne1; i1++) {
12684
13178
  if (ir++ < ir0) continue;
12685
13179
  if (ir > ir1) break;
@@ -12716,7 +13210,7 @@ static void ggml_compute_forward_rope_f32(
12716
13210
  const float cos_theta = cosf(theta);
12717
13211
  const float sin_theta = sinf(theta);
12718
13212
  // zeta scaling for xPos only:
12719
- float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
13213
+ float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
12720
13214
  if (xpos_down) zeta = 1.0f / zeta;
12721
13215
 
12722
13216
  theta *= theta_scale;
@@ -12761,8 +13255,8 @@ static void ggml_compute_forward_rope_f32(
12761
13255
  static void ggml_compute_forward_rope_f16(
12762
13256
  const struct ggml_compute_params * params,
12763
13257
  const struct ggml_tensor * src0,
13258
+ const struct ggml_tensor * src1,
12764
13259
  struct ggml_tensor * dst) {
12765
-
12766
13260
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
12767
13261
  return;
12768
13262
  }
@@ -12770,16 +13264,14 @@ static void ggml_compute_forward_rope_f16(
12770
13264
  float freq_base;
12771
13265
  float freq_scale;
12772
13266
 
12773
- const int n_past = ((int32_t *) dst->op_params)[0];
13267
+ //const int n_past = ((int32_t *) dst->op_params)[0];
12774
13268
  const int n_dims = ((int32_t *) dst->op_params)[1];
12775
13269
  const int mode = ((int32_t *) dst->op_params)[2];
12776
13270
  const int n_ctx = ((int32_t *) dst->op_params)[3];
12777
13271
  memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
12778
13272
  memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
12779
13273
 
12780
- assert(n_past >= 0);
12781
-
12782
- GGML_TENSOR_UNARY_OP_LOCALS;
13274
+ GGML_TENSOR_UNARY_OP_LOCALS
12783
13275
 
12784
13276
  //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
12785
13277
  //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -12809,9 +13301,11 @@ static void ggml_compute_forward_rope_f16(
12809
13301
  const bool is_neox = mode & 2;
12810
13302
  const bool is_glm = mode & 4;
12811
13303
 
13304
+ const int32_t * pos = (const int32_t *) src1->data;
13305
+
12812
13306
  for (int64_t i3 = 0; i3 < ne3; i3++) {
12813
- for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
12814
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
13307
+ for (int64_t i2 = 0; i2 < ne2; i2++) {
13308
+ const int64_t p = pos[i2];
12815
13309
  for (int64_t i1 = 0; i1 < ne1; i1++) {
12816
13310
  if (ir++ < ir0) continue;
12817
13311
  if (ir > ir1) break;
@@ -12890,15 +13384,16 @@ static void ggml_compute_forward_rope_f16(
12890
13384
  static void ggml_compute_forward_rope(
12891
13385
  const struct ggml_compute_params * params,
12892
13386
  const struct ggml_tensor * src0,
13387
+ const struct ggml_tensor * src1,
12893
13388
  struct ggml_tensor * dst) {
12894
13389
  switch (src0->type) {
12895
13390
  case GGML_TYPE_F16:
12896
13391
  {
12897
- ggml_compute_forward_rope_f16(params, src0, dst);
13392
+ ggml_compute_forward_rope_f16(params, src0, src1, dst);
12898
13393
  } break;
12899
13394
  case GGML_TYPE_F32:
12900
13395
  {
12901
- ggml_compute_forward_rope_f32(params, src0, dst);
13396
+ ggml_compute_forward_rope_f32(params, src0, src1, dst);
12902
13397
  } break;
12903
13398
  default:
12904
13399
  {
@@ -12912,6 +13407,7 @@ static void ggml_compute_forward_rope(
12912
13407
  static void ggml_compute_forward_rope_back_f32(
12913
13408
  const struct ggml_compute_params * params,
12914
13409
  const struct ggml_tensor * src0,
13410
+ const struct ggml_tensor * src1,
12915
13411
  struct ggml_tensor * dst) {
12916
13412
 
12917
13413
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -12929,7 +13425,7 @@ static void ggml_compute_forward_rope_back_f32(
12929
13425
  float xpos_base;
12930
13426
  bool xpos_down;
12931
13427
 
12932
- const int n_past = ((int32_t *) dst->op_params)[0];
13428
+ //const int n_past = ((int32_t *) dst->op_params)[0];
12933
13429
  const int n_dims = ((int32_t *) dst->op_params)[1];
12934
13430
  const int mode = ((int32_t *) dst->op_params)[2];
12935
13431
  const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx);
@@ -12938,9 +13434,7 @@ static void ggml_compute_forward_rope_back_f32(
12938
13434
  memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
12939
13435
  memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
12940
13436
 
12941
- assert(n_past >= 0);
12942
-
12943
- GGML_TENSOR_UNARY_OP_LOCALS;
13437
+ GGML_TENSOR_UNARY_OP_LOCALS
12944
13438
 
12945
13439
  //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
12946
13440
  //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -12966,9 +13460,11 @@ static void ggml_compute_forward_rope_back_f32(
12966
13460
 
12967
13461
  const bool is_neox = mode & 2;
12968
13462
 
13463
+ const int32_t * pos = (const int32_t *) src1->data;
13464
+
12969
13465
  for (int64_t i3 = 0; i3 < ne3; i3++) {
12970
- for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
12971
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
13466
+ for (int64_t i2 = 0; i2 < ne2; i2++) {
13467
+ const int64_t p = pos[i2];
12972
13468
  for (int64_t i1 = 0; i1 < ne1; i1++) {
12973
13469
  if (ir++ < ir0) continue;
12974
13470
  if (ir > ir1) break;
@@ -12980,7 +13476,7 @@ static void ggml_compute_forward_rope_back_f32(
12980
13476
  const float cos_theta = cosf(theta);
12981
13477
  const float sin_theta = sinf(theta);
12982
13478
  // zeta scaling for xPos only:
12983
- float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
13479
+ float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
12984
13480
  if (xpos_down) zeta = 1.0f / zeta;
12985
13481
 
12986
13482
  theta *= theta_scale;
@@ -13023,6 +13519,7 @@ static void ggml_compute_forward_rope_back_f32(
13023
13519
  static void ggml_compute_forward_rope_back_f16(
13024
13520
  const struct ggml_compute_params * params,
13025
13521
  const struct ggml_tensor * src0,
13522
+ const struct ggml_tensor * src1,
13026
13523
  struct ggml_tensor * dst) {
13027
13524
 
13028
13525
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -13033,13 +13530,11 @@ static void ggml_compute_forward_rope_back_f16(
13033
13530
  // dx = rope_back(dy, src1)
13034
13531
  // src0 is dy, src1 contains options
13035
13532
 
13036
- const int n_past = ((int32_t *) dst->op_params)[0];
13533
+ //const int n_past = ((int32_t *) dst->op_params)[0];
13037
13534
  const int n_dims = ((int32_t *) dst->op_params)[1];
13038
13535
  const int mode = ((int32_t *) dst->op_params)[2];
13039
13536
 
13040
- assert(n_past >= 0);
13041
-
13042
- GGML_TENSOR_UNARY_OP_LOCALS;
13537
+ GGML_TENSOR_UNARY_OP_LOCALS
13043
13538
 
13044
13539
  //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
13045
13540
  //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -13065,9 +13560,11 @@ static void ggml_compute_forward_rope_back_f16(
13065
13560
 
13066
13561
  const bool is_neox = mode & 2;
13067
13562
 
13563
+ const int32_t * pos = (const int32_t *) src1->data;
13564
+
13068
13565
  for (int64_t i3 = 0; i3 < ne3; i3++) {
13069
- for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
13070
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
13566
+ for (int64_t i2 = 0; i2 < ne2; i2++) {
13567
+ const int64_t p = pos[i2];
13071
13568
  for (int64_t i1 = 0; i1 < ne1; i1++) {
13072
13569
  if (ir++ < ir0) continue;
13073
13570
  if (ir > ir1) break;
@@ -13119,15 +13616,16 @@ static void ggml_compute_forward_rope_back_f16(
13119
13616
  static void ggml_compute_forward_rope_back(
13120
13617
  const struct ggml_compute_params * params,
13121
13618
  const struct ggml_tensor * src0,
13619
+ const struct ggml_tensor * src1,
13122
13620
  struct ggml_tensor * dst) {
13123
13621
  switch (src0->type) {
13124
13622
  case GGML_TYPE_F16:
13125
13623
  {
13126
- ggml_compute_forward_rope_back_f16(params, src0, dst);
13624
+ ggml_compute_forward_rope_back_f16(params, src0, src1, dst);
13127
13625
  } break;
13128
13626
  case GGML_TYPE_F32:
13129
13627
  {
13130
- ggml_compute_forward_rope_back_f32(params, src0, dst);
13628
+ ggml_compute_forward_rope_back_f32(params, src0, src1, dst);
13131
13629
  } break;
13132
13630
  default:
13133
13631
  {
@@ -13150,7 +13648,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
13150
13648
  int64_t t0 = ggml_perf_time_us();
13151
13649
  UNUSED(t0);
13152
13650
 
13153
- GGML_TENSOR_BINARY_OP_LOCALS;
13651
+ GGML_TENSOR_BINARY_OP_LOCALS
13154
13652
 
13155
13653
  const int ith = params->ith;
13156
13654
  const int nth = params->nth;
@@ -13241,7 +13739,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32(
13241
13739
  int64_t t0 = ggml_perf_time_us();
13242
13740
  UNUSED(t0);
13243
13741
 
13244
- GGML_TENSOR_BINARY_OP_LOCALS;
13742
+ GGML_TENSOR_BINARY_OP_LOCALS
13245
13743
 
13246
13744
  const int ith = params->ith;
13247
13745
  const int nth = params->nth;
@@ -13353,7 +13851,7 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
13353
13851
  int64_t t0 = ggml_perf_time_us();
13354
13852
  UNUSED(t0);
13355
13853
 
13356
- GGML_TENSOR_BINARY_OP_LOCALS;
13854
+ GGML_TENSOR_BINARY_OP_LOCALS
13357
13855
 
13358
13856
  const int ith = params->ith;
13359
13857
  const int nth = params->nth;
@@ -13444,7 +13942,7 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
13444
13942
  int64_t t0 = ggml_perf_time_us();
13445
13943
  UNUSED(t0);
13446
13944
 
13447
- GGML_TENSOR_BINARY_OP_LOCALS;
13945
+ GGML_TENSOR_BINARY_OP_LOCALS
13448
13946
 
13449
13947
  const int ith = params->ith;
13450
13948
  const int nth = params->nth;
@@ -13562,7 +14060,7 @@ static void ggml_compute_forward_conv_1d(
13562
14060
  ggml_compute_forward_conv_1d_s2_ph(params, src0, src1, dst);
13563
14061
  } else {
13564
14062
  GGML_ASSERT(false); // only stride 1 and 2 supported
13565
- };
14063
+ }
13566
14064
  }
13567
14065
 
13568
14066
  // ggml_compute_forward_conv_2d
@@ -13579,7 +14077,7 @@ static void ggml_compute_forward_conv_2d_f16_f32(
13579
14077
  int64_t t0 = ggml_perf_time_us();
13580
14078
  UNUSED(t0);
13581
14079
 
13582
- GGML_TENSOR_BINARY_OP_LOCALS;
14080
+ GGML_TENSOR_BINARY_OP_LOCALS
13583
14081
 
13584
14082
  const int ith = params->ith;
13585
14083
  const int nth = params->nth;
@@ -13699,7 +14197,7 @@ static void ggml_compute_forward_conv_transpose_2d(
13699
14197
  int64_t t0 = ggml_perf_time_us();
13700
14198
  UNUSED(t0);
13701
14199
 
13702
- GGML_TENSOR_BINARY_OP_LOCALS;
14200
+ GGML_TENSOR_BINARY_OP_LOCALS
13703
14201
 
13704
14202
  const int ith = params->ith;
13705
14203
  const int nth = params->nth;
@@ -13958,7 +14456,7 @@ static void ggml_compute_forward_upscale_f32(
13958
14456
 
13959
14457
  const int ith = params->ith;
13960
14458
 
13961
- GGML_TENSOR_UNARY_OP_LOCALS;
14459
+ GGML_TENSOR_UNARY_OP_LOCALS
13962
14460
 
13963
14461
  const int scale_factor = dst->op_params[0];
13964
14462
 
@@ -14010,14 +14508,14 @@ static void ggml_compute_forward_flash_attn_f32(
14010
14508
  int64_t t0 = ggml_perf_time_us();
14011
14509
  UNUSED(t0);
14012
14510
 
14013
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne);
14014
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb);
14015
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne);
14016
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb);
14017
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne);
14018
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb);
14019
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
14020
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
14511
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
14512
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
14513
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
14514
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
14515
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
14516
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
14517
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
14518
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
14021
14519
 
14022
14520
  const int ith = params->ith;
14023
14521
  const int nth = params->nth;
@@ -14087,10 +14585,11 @@ static void ggml_compute_forward_flash_attn_f32(
14087
14585
  S[i] = -INFINITY;
14088
14586
  }
14089
14587
 
14090
- for (int64_t ic = 0; ic < nek1; ++ic) {
14588
+ const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
14589
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
14091
14590
  // k indices
14092
14591
  const int ik3 = iq3;
14093
- const int ik2 = iq2;
14592
+ const int ik2 = iq2 % nek2;
14094
14593
  const int ik1 = ic;
14095
14594
 
14096
14595
  // S indices
@@ -14103,20 +14602,18 @@ static void ggml_compute_forward_flash_attn_f32(
14103
14602
  }
14104
14603
 
14105
14604
  // scale
14106
- ggml_vec_scale_f32(nek1, S, scale);
14605
+ ggml_vec_scale_f32(masked_begin, S, scale);
14107
14606
 
14108
- if (masked) {
14109
- for (int64_t i = P; i < M; i++) {
14110
- if (i > P + iq1) {
14111
- S[i] = -INFINITY;
14112
- }
14113
- }
14607
+ for (int64_t i = masked_begin; i < M; i++) {
14608
+ S[i] = -INFINITY;
14114
14609
  }
14115
14610
 
14116
14611
  // softmax
14612
+ // exclude known -INF S[..] values from max and loop
14613
+ // don't forget to set their SW values to zero
14117
14614
  {
14118
14615
  float max = -INFINITY;
14119
- ggml_vec_max_f32(M, &max, S);
14616
+ ggml_vec_max_f32(masked_begin, &max, S);
14120
14617
 
14121
14618
  ggml_float sum = 0.0;
14122
14619
  {
@@ -14130,10 +14627,15 @@ static void ggml_compute_forward_flash_attn_f32(
14130
14627
  ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
14131
14628
 
14132
14629
  for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
14630
+ if (i >= masked_begin) {
14631
+ break;
14632
+ }
14133
14633
  float * SS = S + i;
14134
14634
 
14135
14635
  for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
14136
- if (SS[j] == -INFINITY) {
14636
+ if (i + j >= masked_begin) {
14637
+ break;
14638
+ } else if (SS[j] == -INFINITY) {
14137
14639
  SS[j] = 0.0f;
14138
14640
  } else {
14139
14641
  #ifndef GGML_FLASH_ATTN_EXP_FP16
@@ -14158,10 +14660,10 @@ static void ggml_compute_forward_flash_attn_f32(
14158
14660
  assert(sum > 0.0);
14159
14661
 
14160
14662
  sum = 1.0/sum;
14161
- ggml_vec_scale_f32(M, S, sum);
14663
+ ggml_vec_scale_f32(masked_begin, S, sum);
14162
14664
 
14163
14665
  #ifndef NDEBUG
14164
- for (int i = 0; i < M; ++i) {
14666
+ for (int i = 0; i < masked_begin; ++i) {
14165
14667
  assert(!isnan(S[i]));
14166
14668
  assert(!isinf(S[i]));
14167
14669
  }
@@ -14174,9 +14676,13 @@ static void ggml_compute_forward_flash_attn_f32(
14174
14676
  const int i2 = iq2;
14175
14677
  const int i3 = iq3;
14176
14678
 
14177
- ggml_vec_dot_f32(nek1,
14178
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14179
- (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
14679
+ // v indices
14680
+ const int iv2 = iq2 % nev2;
14681
+ const int iv3 = iq3;
14682
+
14683
+ ggml_vec_dot_f32(masked_begin,
14684
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14685
+ (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
14180
14686
  S);
14181
14687
  }
14182
14688
  }
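Taken together, the f32 attention hunks above replace the old "compute all M scores, then overwrite the masked ones with -INF" flow with an explicit bound masked_begin (P + iq1 + 1 in the causal case, M otherwise), so the K dot products, the scaling, the softmax and the final V accumulation all stop at that bound. A standalone sketch of the same pruning idea on one score row; plain C with hypothetical names, no ggml types:

#include <math.h>

// softmax over a causally masked score row: only the first `valid` entries are
// touched; everything past them is known to end up exactly zero
static void causal_softmax_row(float * s, int valid, int m) {
    float max = -INFINITY;
    for (int i = 0; i < valid; ++i) { if (s[i] > max) max = s[i]; }
    float sum = 0.0f;
    for (int i = 0; i < valid; ++i) { s[i] = expf(s[i] - max); sum += s[i]; }
    for (int i = 0; i < valid; ++i) { s[i] /= sum; }
    for (int i = valid; i < m; ++i) { s[i] = 0.0f; } // masked tail contributes nothing
}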
@@ -14192,14 +14698,14 @@ static void ggml_compute_forward_flash_attn_f16(
14192
14698
  int64_t t0 = ggml_perf_time_us();
14193
14699
  UNUSED(t0);
14194
14700
 
14195
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne);
14196
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb);
14197
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne);
14198
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb);
14199
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne);
14200
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb);
14201
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
14202
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
14701
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
14702
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
14703
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
14704
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
14705
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
14706
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
14707
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
14708
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
14203
14709
 
14204
14710
  const int ith = params->ith;
14205
14711
  const int nth = params->nth;
@@ -14273,7 +14779,7 @@ static void ggml_compute_forward_flash_attn_f16(
14273
14779
  for (int64_t ic = 0; ic < nek1; ++ic) {
14274
14780
  // k indices
14275
14781
  const int ik3 = iq3;
14276
- const int ik2 = iq2;
14782
+ const int ik2 = iq2 % nek2;
14277
14783
  const int ik1 = ic;
14278
14784
 
14279
14785
  // S indices
@@ -14288,7 +14794,7 @@ static void ggml_compute_forward_flash_attn_f16(
14288
14794
  for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
14289
14795
  // k indices
14290
14796
  const int ik3 = iq3;
14291
- const int ik2 = iq2;
14797
+ const int ik2 = iq2 % nek2;
14292
14798
  const int ik1 = ic;
14293
14799
 
14294
14800
  // S indices
@@ -14313,6 +14819,8 @@ static void ggml_compute_forward_flash_attn_f16(
14313
14819
  }
14314
14820
 
14315
14821
  // softmax
14822
+ // todo: exclude known -INF S[..] values from max and loop, assuming their results are zero.
14823
+ // don't forget to set their S values to zero
14316
14824
  {
14317
14825
  float max = -INFINITY;
14318
14826
  ggml_vec_max_f32(M, &max, S);
@@ -14369,6 +14877,7 @@ static void ggml_compute_forward_flash_attn_f16(
14369
14877
  S16[i] = GGML_FP32_TO_FP16(S[i]);
14370
14878
  }
14371
14879
 
14880
+ // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
14372
14881
  if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
14373
14882
  for (int64_t ic = 0; ic < nev1; ++ic) {
14374
14883
  // dst indices
@@ -14376,9 +14885,13 @@ static void ggml_compute_forward_flash_attn_f16(
14376
14885
  const int i2 = iq2;
14377
14886
  const int i3 = iq3;
14378
14887
 
14379
- ggml_vec_dot_f16(nek1,
14380
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14381
- (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
14888
+ // v indices
14889
+ const int iv2 = iq2 % nev2;
14890
+ const int iv3 = iq3;
14891
+
14892
+ ggml_vec_dot_f16(nev0,
14893
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14894
+ (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
14382
14895
  S16);
14383
14896
  }
14384
14897
  } else {
@@ -14388,9 +14901,13 @@ static void ggml_compute_forward_flash_attn_f16(
14388
14901
  const int i2 = iq2;
14389
14902
  const int i3 = iq3;
14390
14903
 
14391
- ggml_vec_dot_f16_unroll(nek1, nbv1,
14392
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14393
- ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
14904
+ // v indices
14905
+ const int iv2 = iq2 % nev2;
14906
+ const int iv3 = iq3;
14907
+
14908
+ ggml_vec_dot_f16_unroll(nev0, nbv1,
14909
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14910
+ ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
14394
14911
  S16);
14395
14912
  }
14396
14913
  }
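The `iq2 % nek2` / `iq2 % nev2` indexing introduced in the f16 hunks above (and in the f32 path earlier) lets K and V carry fewer heads than Q, which is what grouped-query / multi-query attention needs: several query heads read the same key/value head. A one-line sketch of the mapping the kernel relies on; the helper is hypothetical, not part of ggml:

// with n_head query heads and n_head_kv (<= n_head) key/value heads laid out so
// that broadcasting by modulo is valid, query head h uses this kv head:
static int kv_head_for_query_head(int h, int n_head_kv) {
    return h % n_head_kv;
}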
@@ -14433,18 +14950,18 @@ static void ggml_compute_forward_flash_ff_f16(
14433
14950
  int64_t t0 = ggml_perf_time_us();
14434
14951
  UNUSED(t0);
14435
14952
 
14436
- GGML_TENSOR_LOCALS(int64_t, nea, a, ne);
14437
- GGML_TENSOR_LOCALS(size_t, nba, a, nb);
14438
- GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne);
14439
- GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb);
14440
- GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne);
14441
- GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb);
14442
- GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne);
14443
- GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb);
14444
- GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne);
14445
- GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb);
14446
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
14447
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
14953
+ GGML_TENSOR_LOCALS(int64_t, nea, a, ne)
14954
+ GGML_TENSOR_LOCALS(size_t, nba, a, nb)
14955
+ GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne)
14956
+ GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb)
14957
+ GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne)
14958
+ GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb)
14959
+ GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne)
14960
+ GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb)
14961
+ GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne)
14962
+ GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb)
14963
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
14964
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
14448
14965
 
14449
14966
  const int ith = params->ith;
14450
14967
  const int nth = params->nth;
@@ -14592,16 +15109,16 @@ static void ggml_compute_forward_flash_attn_back_f32(
14592
15109
  int64_t t0 = ggml_perf_time_us();
14593
15110
  UNUSED(t0);
14594
15111
 
14595
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne);
14596
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb);
14597
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne);
14598
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb);
14599
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne);
14600
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb);
14601
- GGML_TENSOR_LOCALS(int64_t, ned, d, ne);
14602
- GGML_TENSOR_LOCALS(size_t, nbd, d, nb);
14603
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
14604
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
15112
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15113
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15114
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15115
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15116
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15117
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15118
+ GGML_TENSOR_LOCALS(int64_t, ned, d, ne)
15119
+ GGML_TENSOR_LOCALS(size_t, nbd, d, nb)
15120
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15121
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
14605
15122
 
14606
15123
  const int ith = params->ith;
14607
15124
  const int nth = params->nth;
@@ -14649,10 +15166,37 @@ static void ggml_compute_forward_flash_attn_back_f32(
14649
15166
  return;
14650
15167
  }
14651
15168
 
14652
- // parallelize by q rows using ggml_vec_dot_f32
15169
+ const int64_t elem_q = ggml_nelements(q);
15170
+ const int64_t elem_k = ggml_nelements(k);
14653
15171
 
14654
- // total rows in q
14655
- const int nr = neq2*neq3;
15172
+ enum ggml_type result_type = dst->type;
15173
+ GGML_ASSERT(ggml_blck_size(result_type) == 1);
15174
+ const size_t tsize = ggml_type_size(result_type);
15175
+
15176
+ const size_t offs_q = 0;
15177
+ const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
15178
+ const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
15179
+
15180
+ void * grad_q = (char *) dst->data;
15181
+ void * grad_k = (char *) dst->data + offs_k;
15182
+ void * grad_v = (char *) dst->data + offs_v;
15183
+
15184
+ const size_t nbgq1 = nb0*neq0;
15185
+ const size_t nbgq2 = nb0*neq0*neq1;
15186
+ const size_t nbgq3 = nb0*neq0*neq1*neq2;
15187
+
15188
+ const size_t nbgk1 = nb0*nek0;
15189
+ const size_t nbgk2 = nb0*nek0*nek1;
15190
+ const size_t nbgk3 = nb0*nek0*nek1*neq2;
15191
+
15192
+ const size_t nbgv1 = nb0*nev0;
15193
+ const size_t nbgv2 = nb0*nev0*nev1;
15194
+ const size_t nbgv3 = nb0*nev0*nev1*neq2;
15195
+
15196
+ // parallelize by k rows using ggml_vec_dot_f32
15197
+
15198
+ // total rows in k
15199
+ const int nr = nek2*nek3;
14656
15200
 
14657
15201
  // rows per thread
14658
15202
  const int dr = (nr + nth - 1)/nth;
@@ -14665,268 +15209,243 @@ static void ggml_compute_forward_flash_attn_back_f32(
14665
15209
 
14666
15210
  //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
14667
15211
 
15212
+ // how often k2 (and v2) is repeated in q2
15213
+ int nrep = neq2/nek2;
15214
+
14668
15215
  for (int ir = ir0; ir < ir1; ++ir) {
14669
15216
  // q indices
14670
- const int iq3 = ir/(neq2);
14671
- const int iq2 = ir - iq3*neq2;
14672
- for ( int iq1 = 0; iq1 < neq1; ++iq1) {
15217
+ const int ik3 = ir/(nek2);
15218
+ const int ik2 = ir - ik3*nek2;
14673
15219
 
15220
+ const int iq3 = ik3;
15221
+ const int id3 = ik3;
15222
+ const int iv3 = ik3;
15223
+ const int iv2 = ik2;
14674
15224
 
14675
- // not sure about CACHE_LINE_SIZE_F32..
14676
- // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
14677
- float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
14678
- float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
15225
+ for (int irep = 0; irep < nrep; ++irep) {
15226
+ const int iq2 = ik2 + irep*nek2;
15227
+ const int id2 = iq2;
14679
15228
 
14680
- for (int i = M; i < Mup; ++i) {
14681
- S[i] = -INFINITY;
14682
- }
15229
+ // (ik2 + irep*nek2) % nek2 == ik2
15230
+ for (int iq1 = 0; iq1 < neq1; ++iq1) {
15231
+ const int id1 = iq1;
14683
15232
 
14684
- for (int64_t ic = 0; ic < nek1; ++ic) {
14685
- // k indices
14686
- const int ik3 = iq3;
14687
- const int ik2 = iq2;
14688
- const int ik1 = ic;
15233
+ // not sure about CACHE_LINE_SIZE_F32..
15234
+ // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
15235
+ float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
15236
+ float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
14689
15237
 
14690
- // S indices
14691
- const int i1 = ik1;
15238
+ for (int i = M; i < Mup; ++i) {
15239
+ S[i] = -INFINITY;
15240
+ }
14692
15241
 
14693
- ggml_vec_dot_f32(neq0,
14694
- S + i1,
14695
- (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
14696
- (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
14697
- }
15242
+ const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
15243
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
15244
+ // k indices
15245
+ const int ik1 = ic;
14698
15246
 
14699
- // scale
14700
- ggml_vec_scale_f32(nek1, S, scale);
15247
+ // S indices
15248
+ const int i1 = ik1;
14701
15249
 
14702
- if (masked) {
14703
- for (int64_t i = P; i < M; i++) {
14704
- if (i > P + iq1) {
14705
- S[i] = -INFINITY;
14706
- }
15250
+ ggml_vec_dot_f32(neq0,
15251
+ S + i1,
15252
+ (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
15253
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
14707
15254
  }
14708
- }
14709
15255
 
14710
- // softmax
14711
- {
14712
- float max = -INFINITY;
14713
- ggml_vec_max_f32(M, &max, S);
15256
+ // scale
15257
+ ggml_vec_scale_f32(masked_begin, S, scale);
14714
15258
 
14715
- ggml_float sum = 0.0;
15259
+ for (int64_t i = masked_begin; i < M; i++) {
15260
+ S[i] = -INFINITY;
15261
+ }
15262
+
15263
+ // softmax
15264
+ // exclude known -INF S[..] values from max and loop
15265
+ // don't forget to set their SM values to zero
14716
15266
  {
15267
+ float max = -INFINITY;
15268
+ ggml_vec_max_f32(masked_begin, &max, S);
15269
+
15270
+ ggml_float sum = 0.0;
15271
+ {
14717
15272
  #ifdef GGML_SOFT_MAX_ACCELERATE
14718
- max = -max;
14719
- vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
14720
- vvexpf(SM, SM, &Mup);
14721
- ggml_vec_sum_f32(Mup, &sum, SM);
15273
+ max = -max;
15274
+ vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
15275
+ vvexpf(SM, SM, &Mup);
15276
+ ggml_vec_sum_f32(Mup, &sum, SM);
14722
15277
  #else
14723
- uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
14724
- ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
14725
-
14726
- for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
14727
- float * SR = S + i;
14728
- float * SW = SM + i;
15278
+ uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
15279
+ ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
14729
15280
 
14730
- for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
14731
- if (SR[j] == -INFINITY) {
14732
- SW[j] = 0.0f;
14733
- } else {
15281
+ for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
15282
+ if (i >= masked_begin) {
15283
+ break;
15284
+ }
15285
+ float * SR = S + i;
15286
+ float * SW = SM + i;
15287
+
15288
+ for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
15289
+ if (i + j >= masked_begin) {
15290
+ break;
15291
+ } else if (SR[j] == -INFINITY) {
15292
+ SW[j] = 0.0f;
15293
+ } else {
14734
15294
  #ifndef GGML_FLASH_ATTN_EXP_FP16
14735
- const float val = expf(SR[j] - max);
15295
+ const float val = expf(SR[j] - max);
14736
15296
  #else
14737
- ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
14738
- memcpy(&scvt[j], &s, sizeof(uint16_t));
14739
- const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
15297
+ ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
15298
+ memcpy(&scvt[j], &s, sizeof(uint16_t));
15299
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
14740
15300
  #endif
14741
- sump[j] += (ggml_float)val;
14742
- SW[j] = val;
15301
+ sump[j] += (ggml_float)val;
15302
+ SW[j] = val;
15303
+ }
14743
15304
  }
14744
15305
  }
14745
- }
14746
15306
 
14747
- for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
14748
- sum += sump[i];
14749
- }
15307
+ for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
15308
+ sum += sump[i];
15309
+ }
14750
15310
  #endif
14751
- }
14752
-
14753
- assert(sum > 0.0);
14754
-
14755
- sum = 1.0/sum;
14756
- ggml_vec_scale_f32(M, SM, sum);
14757
-
14758
- }
14759
-
14760
- // step-by-step explanation
14761
- {
14762
- // forward-process shape grads from backward process
14763
- // parallel_for iq2,iq3:
14764
- // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur]
14765
- // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
14766
- // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur]
14767
- // for iq1:
14768
- // kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
14769
- // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
14770
- // vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
14771
- // S0 = -Inf [D,1,1,1]
14772
- // ~S1[i] = dot(kcur[:D,i], qcur)
14773
- // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
14774
- // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
14775
- // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
14776
- // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
14777
- // ~S5[i] = dot(vcur[:,i], S4)
14778
- // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
14779
- // ~dst[i,iq1,iq2,iq3] = S5[i] ^
14780
- // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
14781
- // dst backward-/ grad[dst] = d
14782
- //
14783
- // output gradients with their dependencies:
14784
- //
14785
- // grad[kcur] = grad[S1].T @ qcur
14786
- // grad[S1] = diag_mask_zero(grad[S3], P) * scale
14787
- // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
14788
- // grad[S4] = grad[S5] @ vcur
14789
- // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
14790
- // grad[qcur] = grad[S1] @ kcur
14791
- // grad[vcur] = grad[S5].T @ S4
14792
- // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
14793
- //
14794
- // in post-order:
14795
- //
14796
- // S1 = qcur @ kcur.T
14797
- // S2 = S1 * scale
14798
- // S3 = diag_mask_inf(S2, P)
14799
- // S4 = softmax(S3)
14800
- // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
14801
- // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
14802
- // grad[S1] = diag_mask_zero(grad[S3], P) * scale
14803
- // grad[qcur] = grad[S1] @ kcur
14804
- // grad[kcur] = grad[S1].T @ qcur
14805
- // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
14806
- //
14807
- // using less variables (SM=S4):
14808
- //
14809
- // S = diag_mask_inf(qcur @ kcur.T * scale, P)
14810
- // SM = softmax(S)
14811
- // S = d[:D,iq1,iq2,iq3] @ vcur
14812
- // dot_SM_gradSM = dot(SM, S)
14813
- // S = SM * (S - dot(SM, S))
14814
- // S = diag_mask_zero(S, P) * scale
14815
- //
14816
- // grad[q][:D,iq1,iq2,iq3] += S @ kcur
14817
- // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
14818
- // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
14819
- }
14820
-
14821
- // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
14822
- // S = d[:D,iq1,iq2,iq3] @ vcur
14823
- // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
14824
- ggml_vec_set_f32(M, S, 0);
14825
- for (int64_t ic = 0; ic < D; ++ic) {
14826
- // dst indices
14827
- const int i1 = iq1;
14828
- const int i2 = iq2;
14829
- const int i3 = iq3;
15311
+ }
14830
15312
 
14831
- ggml_vec_mad_f32(M,
14832
- S,
14833
- (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
14834
- *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
14835
- }
15313
+ assert(sum > 0.0);
14836
15314
 
14837
- // S = SM * (S - dot(SM, S))
14838
- float dot_SM_gradSM = 0;
14839
- ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
14840
- ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
14841
- ggml_vec_mul_f32 (M, S, S, SM);
15315
+ sum = 1.0/sum;
15316
+ ggml_vec_scale_f32(masked_begin, SM, sum);
14842
15317
 
14843
- // S = diag_mask_zero(S, P) * scale
14844
- if (masked) {
14845
- // for (int64_t i = P + iq1 + 1; i < M; i++) {
14846
- // S[i] = 0;
14847
- // }
14848
- for (int64_t i = P; i < M; i++) {
14849
- if (i > P + iq1) {
14850
- S[i] = 0;
14851
- }
14852
15318
  }
14853
- }
14854
- ggml_vec_scale_f32(M, S, scale);
14855
-
14856
- void * grad_q = (char *) dst->data;
14857
- void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
14858
- void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
14859
-
14860
- const size_t nbgq1 = nb0*neq0;
14861
- const size_t nbgq2 = nb0*neq0*neq1;
14862
- const size_t nbgq3 = nb0*neq0*neq1*neq2;
14863
-
14864
- const size_t nbgk1 = nb0*nek0;
14865
- const size_t nbgk2 = nb0*nek0*nek1;
14866
- const size_t nbgk3 = nb0*nek0*nek1*neq2;
14867
-
14868
- const size_t nbgv1 = nb0*nev0;
14869
- const size_t nbgv2 = nb0*nev0*nev1;
14870
- const size_t nbgv3 = nb0*nev0*nev1*neq2;
14871
-
14872
- // S shape [M,1]
14873
- // SM shape [M,1]
14874
- // kcur shape [D,M]
14875
- // qcur shape [D,1]
14876
- // vcur shape [M,D]
14877
- //
14878
- // grad[q][:D,iq1,iq2,iq3] += S @ kcur
14879
- // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
14880
- // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
14881
- //
14882
- //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
14883
- //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
14884
- for (int64_t ic = 0; ic < M; ++ic) {
14885
- // dst indices
14886
- const int i1 = iq1;
14887
- const int i2 = iq2;
14888
- const int i3 = iq3;
14889
15319
 
14890
- ggml_vec_mad_f32(D,
14891
- (float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
14892
- (float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)),
14893
- S[ic]);
14894
- }
15320
+ // step-by-step explanation
15321
+ {
15322
+ // forward-process shape grads from backward process
15323
+ // parallel_for ik2,ik3:
15324
+ // for irep:
15325
+ // iq2 = ik2 + irep*nek2
15326
+ // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur]
15327
+ // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
15328
+ // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur]
15329
+ // for iq1:
15330
+ // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
15331
+ // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
15332
+ // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
15333
+ // S0 = -Inf [D,1,1,1]
15334
+ // ~S1[i] = dot(kcur[:D,i], qcur)
15335
+ // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
15336
+ // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
15337
+ // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
15338
+ // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
15339
+ // ~S5[i] = dot(vcur[:,i], S4)
15340
+ // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3]
15341
+ // ~dst[i,iq1,iq2,iq3] = S5[i] ^
15342
+ // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
15343
+ // dst backward-/ grad[dst] = d
15344
+ //
15345
+ // output gradients with their dependencies:
15346
+ //
15347
+ // grad[kcur] = grad[S1].T @ qcur
15348
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
15349
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
15350
+ // grad[S4] = grad[S5] @ vcur
15351
+ // grad[S4] = d[:D,id1,id2,id3] @ vcur
15352
+ // grad[qcur] = grad[S1] @ kcur
15353
+ // grad[vcur] = grad[S5].T @ S4
15354
+ // grad[vcur] = d[:D,id1,id2,id3].T @ S4
15355
+ //
15356
+ // in post-order:
15357
+ //
15358
+ // S1 = qcur @ kcur.T
15359
+ // S2 = S1 * scale
15360
+ // S3 = diag_mask_inf(S2, P)
15361
+ // S4 = softmax(S3)
15362
+ // grad[S4] = d[:D,id1,id2,id3] @ vcur
15363
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
15364
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
15365
+ // grad[qcur] = grad[S1] @ kcur
15366
+ // grad[kcur] = grad[S1].T @ qcur
15367
+ // grad[vcur] = d[:D,id1,id2,id3].T @ S4
15368
+ //
15369
+ // using less variables (SM=S4):
15370
+ //
15371
+ // S = diag_mask_inf(qcur @ kcur.T * scale, P)
15372
+ // SM = softmax(S)
15373
+ // S = d[:D,iq1,iq2,iq3] @ vcur
15374
+ // dot_SM_gradSM = dot(SM, S)
15375
+ // S = SM * (S - dot(SM, S))
15376
+ // S = diag_mask_zero(S, P) * scale
15377
+ //
15378
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
15379
+ // grad[k][:D,:M,ik2,ik3] += S.T @ qcur
15380
+ // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
15381
+ }
14895
15382
 
14896
- // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
14897
- // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
14898
- // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
14899
- for (int64_t ic = 0; ic < M; ++ic) {
14900
- // dst indices
14901
- const int i1 = iq1;
14902
- const int i2 = iq2;
14903
- const int i3 = iq3;
15383
+ // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
15384
+ // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
15385
+ // for ic:
15386
+ // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
15387
+ // exclude known future zero S[..] values from operation
15388
+ ggml_vec_set_f32(masked_begin, S, 0);
15389
+ for (int64_t ic = 0; ic < D; ++ic) {
15390
+ ggml_vec_mad_f32(masked_begin,
15391
+ S,
15392
+ (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
15393
+ *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
15394
+ }
14904
15395
 
14905
- // ggml_vec_set_f32(D,
14906
- // (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
14907
- // 0);
14908
- ggml_vec_mad_f32(D,
14909
- (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
14910
- (float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)),
14911
- S[ic]);
14912
- }
15396
+ // S = SM * (S - dot(SM, S))
15397
+ float dot_SM_gradSM = 0;
15398
+ ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S);
15399
+ ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
15400
+ ggml_vec_mul_f32 (masked_begin, S, S, SM);
15401
+
15402
+ // S = diag_mask_zero(S, P) * scale
15403
+ // already done by above ggml_vec_set_f32
15404
+
15405
+ // exclude known zero S[..] values from operation
15406
+ ggml_vec_scale_f32(masked_begin, S, scale);
15407
+
15408
+ // S shape [M,1]
15409
+ // SM shape [M,1]
15410
+ // kcur shape [D,M]
15411
+ // qcur shape [D,1]
15412
+ // vcur shape [M,D]
15413
+
15414
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
15415
+ // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
15416
+ // for ic:
15417
+ // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
15418
+ // exclude known zero S[..] values from loop
15419
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
15420
+ ggml_vec_mad_f32(D,
15421
+ (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
15422
+ (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
15423
+ S[ic]);
15424
+ }
14913
15425
 
14914
- // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
14915
- // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
14916
- // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M]
14917
- for (int64_t ic = 0; ic < D; ++ic) {
14918
- // dst indices
14919
- const int i1 = iq1;
14920
- const int i2 = iq2;
14921
- const int i3 = iq3;
15426
+ // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
15427
+ // for ic:
15428
+ // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
15429
+ // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
15430
+ // exclude known zero S[..] values from loop
15431
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
15432
+ ggml_vec_mad_f32(D,
15433
+ (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
15434
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
15435
+ S[ic]);
15436
+ }
14922
15437
 
14923
- // ggml_vec_set_f32(M,
14924
- // (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
14925
- // 0);
14926
- ggml_vec_mad_f32(M,
14927
- (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
14928
- SM,
14929
- *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
15438
+ // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
15439
+ // for ic:
15440
+ // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
15441
+ // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
15442
+ // exclude known zero SM[..] values from mad
15443
+ for (int64_t ic = 0; ic < D; ++ic) {
15444
+ ggml_vec_mad_f32(masked_begin,
15445
+ (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
15446
+ SM,
15447
+ *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
15448
+ }
14930
15449
  }
14931
15450
  }
14932
15451
  }
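The rewritten backward kernel above also changes the work split: threads now own rows of K (nr = nek2*nek3) and loop over nrep = neq2/nek2 query-head groups that broadcast onto the same K/V head, so the grad_q, grad_k and grad_v slices a thread accumulates into are not shared with any other thread. A compact sketch of just that loop nest; plain C with a hypothetical callback, mirroring the index arithmetic above:

// iterate the (ik2, ik3, iq2) triples owned by thread `ith` out of `nth`
static void for_each_kv_row(int ith, int nth, int nek2, int nek3, int nrep,
                            void (*fn)(int ik2, int ik3, int iq2, void * ud), void * ud) {
    const int nr  = nek2*nek3;            // total kv rows
    const int dr  = (nr + nth - 1)/nth;   // rows per thread
    const int ir0 = dr*ith;
    const int ir1 = ir0 + dr < nr ? ir0 + dr : nr;
    for (int ir = ir0; ir < ir1; ++ir) {
        const int ik3 = ir/nek2;
        const int ik2 = ir - ik3*nek2;
        for (int irep = 0; irep < nrep; ++irep) {
            fn(ik2, ik3, ik2 + irep*nek2, ud);   // iq2 = ik2 + irep*nek2
        }
    }
}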
@@ -14962,8 +15481,8 @@ static void ggml_compute_forward_win_part_f32(
14962
15481
  return;
14963
15482
  }
14964
15483
 
14965
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
14966
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
15484
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
15485
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
14967
15486
 
14968
15487
  const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
14969
15488
  const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
@@ -15024,8 +15543,8 @@ static void ggml_compute_forward_win_unpart_f32(
15024
15543
  return;
15025
15544
  }
15026
15545
 
15027
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
15028
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
15546
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
15547
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15029
15548
 
15030
15549
  const int32_t w = ((const int32_t *)(dst->op_params))[0];
15031
15550
 
@@ -15142,7 +15661,7 @@ static void ggml_compute_forward_get_rel_pos_f16(
15142
15661
 
15143
15662
  // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
15144
15663
 
15145
- GGML_TENSOR_UNARY_OP_LOCALS;
15664
+ GGML_TENSOR_UNARY_OP_LOCALS
15146
15665
 
15147
15666
  const int64_t w = ne1;
15148
15667
 
@@ -15840,7 +16359,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
15840
16359
  } break;
15841
16360
  case GGML_OP_GET_ROWS_BACK:
15842
16361
  {
15843
- ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
16362
+ ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor);
15844
16363
  } break;
15845
16364
  case GGML_OP_DIAG:
15846
16365
  {
@@ -15864,11 +16383,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
15864
16383
  } break;
15865
16384
  case GGML_OP_ROPE:
15866
16385
  {
15867
- ggml_compute_forward_rope(params, tensor->src[0], tensor);
16386
+ ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor);
15868
16387
  } break;
15869
16388
  case GGML_OP_ROPE_BACK:
15870
16389
  {
15871
- ggml_compute_forward_rope_back(params, tensor->src[0], tensor);
16390
+ ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor);
15872
16391
  } break;
15873
16392
  case GGML_OP_ALIBI:
15874
16393
  {
@@ -16013,7 +16532,218 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
16013
16532
 
16014
16533
  ////////////////////////////////////////////////////////////////////////////////
16015
16534
 
16016
- static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
16535
+ static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
16536
+
16537
+ static size_t hash(void * p) {
16538
+ return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
16539
+ }
16540
+
16541
+ static size_t hash_find(void * hash_table[], void * p) {
16542
+ size_t h = hash(p);
16543
+
16544
+ // linear probing
16545
+ size_t i = h;
16546
+ while (hash_table[i] != NULL && hash_table[i] != p) {
16547
+ i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
16548
+ if (i == h) {
16549
+ // visited all hash table entries -> not found
16550
+ return GGML_GRAPH_HASHTABLE_SIZE;
16551
+ }
16552
+ }
16553
+ return i;
16554
+ }
16555
+
16556
+ static bool hash_insert(void * hash_table[], void * p) {
16557
+ size_t i = hash_find(hash_table, p);
16558
+
16559
+ GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
16560
+
16561
+ if (hash_table[i] == p) {
16562
+ return true;
16563
+ }
16564
+
16565
+ // insert
16566
+ GGML_ASSERT(hash_table[i] == NULL);
16567
+ hash_table[i] = p;
16568
+ return false;
16569
+ }
16570
+
16571
+ static bool hash_contains(void * hash_table[], void * p) {
16572
+ size_t i = hash_find(hash_table, p);
16573
+ return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
16574
+ }
16575
+
16576
+ struct hash_map {
16577
+ void * keys[GGML_GRAPH_HASHTABLE_SIZE];
16578
+ void * vals[GGML_GRAPH_HASHTABLE_SIZE];
16579
+ };
16580
+
16581
+ static struct hash_map * new_hash_map(void) {
16582
+ struct hash_map * result = malloc(sizeof(struct hash_map));
16583
+ for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
16584
+ result->keys[i] = NULL;
16585
+ result->vals[i] = NULL;
16586
+ }
16587
+ return result;
16588
+ }
16589
+
16590
+ static void free_hash_map(struct hash_map * map) {
16591
+ free(map);
16592
+ }
16593
+
16594
+ // gradient checkpointing
16595
+
16596
+ static struct ggml_tensor * ggml_recompute_graph_node(
16597
+ struct ggml_context * ctx,
16598
+ struct ggml_cgraph * graph,
16599
+ struct hash_map * replacements,
16600
+ struct ggml_tensor * node) {
16601
+
16602
+ if (node == NULL) {
16603
+ return NULL;
16604
+ }
16605
+
16606
+ if (node->is_param) {
16607
+ return node;
16608
+ }
16609
+
16610
+ if (!hash_contains(graph->visited_hash_table, node)) {
16611
+ return node;
16612
+ }
16613
+
16614
+ int count_children = 0;
16615
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
16616
+ if (node->src[k]) {
16617
+ ++count_children;
16618
+ }
16619
+ }
16620
+
16621
+ if (count_children == 0) {
16622
+ return node;
16623
+ }
16624
+
16625
+ size_t i = hash_find(replacements->keys, node);
16626
+ GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
16627
+ if (replacements->keys[i] == node) {
16628
+ return (struct ggml_tensor *) replacements->vals[i];
16629
+ }
16630
+
16631
+ struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
16632
+
16633
+ // insert clone into replacements
16634
+ GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
16635
+ replacements->keys[i] = node;
16636
+ replacements->vals[i] = clone;
16637
+
16638
+ clone->op = node->op;
16639
+ clone->grad = node->grad;
16640
+ clone->is_param = node->is_param;
16641
+ clone->extra = node->extra;
16642
+ for (int k = 0; k < GGML_MAX_DIMS; ++k) {
16643
+ clone->nb[k] = node->nb[k];
16644
+ }
16645
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
16646
+ clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
16647
+ }
16648
+ if (node->view_src != NULL) {
16649
+ clone->data = (node->view_src->data == NULL)
16650
+ ? NULL // view_src not yet allocated
16651
+ : (char *) node->view_src->data // view_src already allocated
16652
+ + node->view_offs;
16653
+ clone->view_src = node->view_src;
16654
+ clone->view_offs = node->view_offs;
16655
+ }
16656
+
16657
+ GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
16658
+ GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
16659
+ memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
16660
+ ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
16661
+
16662
+ return clone;
16663
+ }
16664
+
16665
+ void ggml_build_backward_gradient_checkpointing(
16666
+ struct ggml_context * ctx,
16667
+ struct ggml_cgraph * gf,
16668
+ struct ggml_cgraph * gb,
16669
+ struct ggml_cgraph * gb_tmp,
16670
+ struct ggml_tensor * * checkpoints,
16671
+ int n_checkpoints) {
16672
+ *gb_tmp = *gf;
16673
+ ggml_build_backward_expand(ctx, gf, gb_tmp, true);
16674
+
16675
+ if (n_checkpoints <= 0) {
16676
+ *gb = *gb_tmp;
16677
+ return;
16678
+ }
16679
+
16680
+ struct hash_map * replacements = new_hash_map();
16681
+
16682
+ // insert checkpoints in replacements
16683
+ for (int i = 0; i < n_checkpoints; ++i) {
16684
+ size_t k = hash_find(replacements->keys, checkpoints[i]);
16685
+ GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
16686
+ GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
16687
+ replacements->keys[k] = checkpoints[i];
16688
+ replacements->vals[k] = checkpoints[i];
16689
+ }
16690
+
16691
+ *gb = *gf;
16692
+ // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
16693
+ // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
16694
+ // by recomputing them from checkpoints
16695
+ for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
16696
+ struct ggml_tensor * node = gb_tmp->nodes[i];
16697
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
16698
+ // insert new tensors recomputing src, reusing already made replacements,
16699
+ // remember replacements: remember new tensors with mapping from corresponding gf nodes
16700
+ // recurse for input tensors,
16701
+ // unless (i.e. terminating when) input tensors are replacements (like checkpoints)
16702
+ node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
16703
+ }
16704
+ // insert rewritten backward node with replacements made into resulting backward graph gb
16705
+ ggml_build_forward_expand(gb, node);
16706
+ }
16707
+
16708
+ free_hash_map(replacements);
16709
+ }
16710
+
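ggml_build_backward_gradient_checkpointing above wires gradient checkpointing into graph construction: the backward graph is first built normally into gb_tmp, then every non-checkpoint intermediate it references is replaced by a clone that is recomputed from the nearest checkpoints, so those activations need not stay resident between the forward and backward passes. A hedged calling sketch; only the function's own signature comes from this diff, the wrapper and the idea of using per-layer inputs as checkpoints are assumptions:

// assumed caller-side helper: one checkpoint per layer boundary trades roughly one
// extra forward recomputation for activation memory proportional to n_checkpoints
static void build_checkpointed_backward(
        struct ggml_context * ctx,
        struct ggml_cgraph  * gf,           // forward graph, already expanded to the loss
        struct ggml_cgraph  * gb,           // receives the rewritten backward graph
        struct ggml_cgraph  * gb_tmp,       // scratch graph
        struct ggml_tensor ** layer_inputs, // e.g. the input tensor of each layer
        int n_layers) {
    ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp, layer_inputs, n_layers);
}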
16711
+ // functions to change gradients, handling the case that input a might be the initial zero-valued gradient
16712
+
16713
+ static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
16714
+ if (hash_contains(zero_table, a)) {
16715
+ return b;
16716
+ } else {
16717
+ return ggml_add_impl(ctx, a, b, false);
16718
+ }
16719
+ }
16720
+
16721
+ static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, void * zero_table[]) {
16722
+ if (hash_contains(zero_table, a)) {
16723
+ struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0));
16724
+ return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
16725
+ } else {
16726
+ return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
16727
+ }
16728
+ }
16729
+
16730
+ static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
16731
+ if (hash_contains(zero_table, a)) {
16732
+ return ggml_repeat(ctx, b, a);
16733
+ } else {
16734
+ return ggml_add1_impl(ctx, a, b, false);
16735
+ }
16736
+ }
16737
+
16738
+ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
16739
+ if (hash_contains(zero_table, a)) {
16740
+ return ggml_neg(ctx, b);
16741
+ } else {
16742
+ return ggml_sub_impl(ctx, a, b, false);
16743
+ }
16744
+ }
16745
+
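These *_or_set helpers exist because, with the new scheme below, a parameter's gradient starts out as an explicit zero tensor that is recorded in zero_table; the first real contribution can then replace it outright instead of building an add-with-zero node. A scalar analogue of the pattern, illustrative only and not ggml code:

// when the accumulator is known to still be its initial zero, "add" degenerates
// to plain assignment; otherwise it is an ordinary accumulation
static float add_or_set(float acc, float contribution, int acc_is_initial_zero) {
    return acc_is_initial_zero ? contribution : acc + contribution;
}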
16746
+ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, void * zero_table[]) {
16017
16747
  struct ggml_tensor * src0 = tensor->src[0];
16018
16748
  struct ggml_tensor * src1 = tensor->src[1];
16019
16749
 
@@ -16021,34 +16751,34 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16021
16751
  case GGML_OP_DUP:
16022
16752
  {
16023
16753
  if (src0->grad) {
16024
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
16754
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16025
16755
  }
16026
16756
  } break;
16027
16757
  case GGML_OP_ADD:
16028
16758
  {
16029
16759
  if (src0->grad) {
16030
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
16760
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16031
16761
  }
16032
16762
  if (src1->grad) {
16033
- src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace);
16763
+ src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
16034
16764
  }
16035
16765
  } break;
16036
16766
  case GGML_OP_ADD1:
16037
16767
  {
16038
16768
  if (src0->grad) {
16039
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
16769
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16040
16770
  }
16041
16771
  if (src1->grad) {
16042
- src1->grad = ggml_add_impl(ctx,
16772
+ src1->grad = ggml_add_or_set(ctx,
16043
16773
  src1->grad,
16044
16774
  ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
16045
- inplace);
16775
+ zero_table);
16046
16776
  }
16047
16777
  } break;
16048
16778
  case GGML_OP_ACC:
16049
16779
  {
16050
16780
  if (src0->grad) {
16051
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
16781
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16052
16782
  }
16053
16783
  if (src1->grad) {
16054
16784
  const size_t nb1 = ((int32_t *) tensor->op_params)[0];
@@ -16065,117 +16795,117 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16065
16795
  nb1, nb2, nb3, offset);
16066
16796
 
16067
16797
  src1->grad =
16068
- ggml_add_impl(ctx,
16798
+ ggml_add_or_set(ctx,
16069
16799
  src1->grad,
16070
16800
  ggml_reshape(ctx,
16071
16801
  ggml_cont(ctx, tensor_grad_view),
16072
16802
  src1->grad),
16073
- inplace);
16803
+ zero_table);
16074
16804
  }
16075
16805
  } break;
16076
16806
  case GGML_OP_SUB:
16077
16807
  {
16078
16808
  if (src0->grad) {
16079
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
16809
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16080
16810
  }
16081
16811
  if (src1->grad) {
16082
- src1->grad = ggml_sub_impl(ctx, src1->grad, tensor->grad, inplace);
16812
+ src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table);
16083
16813
  }
16084
16814
  } break;
16085
16815
  case GGML_OP_MUL:
16086
16816
  {
16087
16817
  if (src0->grad) {
16088
16818
  src0->grad =
16089
- ggml_add_impl(ctx,
16819
+ ggml_add_or_set(ctx,
16090
16820
  src0->grad,
16091
16821
  ggml_mul(ctx, src1, tensor->grad),
16092
- inplace);
16822
+ zero_table);
16093
16823
  }
16094
16824
  if (src1->grad) {
16095
16825
  src1->grad =
16096
- ggml_add_impl(ctx,
16826
+ ggml_add_or_set(ctx,
16097
16827
  src1->grad,
16098
16828
  ggml_mul(ctx, src0, tensor->grad),
16099
- inplace);
16829
+ zero_table);
16100
16830
  }
16101
16831
  } break;
16102
16832
  case GGML_OP_DIV:
16103
16833
  {
16104
16834
  if (src0->grad) {
16105
16835
  src0->grad =
16106
- ggml_add_impl(ctx,
16836
+ ggml_add_or_set(ctx,
16107
16837
  src0->grad,
16108
16838
  ggml_div(ctx, tensor->grad, src1),
16109
- inplace);
16839
+ zero_table);
16110
16840
  }
16111
16841
  if (src1->grad) {
16112
16842
  src1->grad =
16113
- ggml_sub_impl(ctx,
16843
+ ggml_sub_or_set(ctx,
16114
16844
  src1->grad,
16115
16845
  ggml_mul(ctx,
16116
16846
  tensor->grad,
16117
16847
  ggml_div(ctx, tensor, src1)),
16118
- inplace);
16848
+ zero_table);
16119
16849
  }
16120
16850
  } break;
16121
16851
  case GGML_OP_SQR:
16122
16852
  {
16123
16853
  if (src0->grad) {
16124
16854
  src0->grad =
16125
- ggml_add_impl(ctx,
16855
+ ggml_add_or_set(ctx,
16126
16856
  src0->grad,
16127
16857
  ggml_scale(ctx,
16128
16858
  ggml_mul(ctx, src0, tensor->grad),
16129
16859
  ggml_new_f32(ctx, 2.0f)),
16130
- inplace);
16860
+ zero_table);
16131
16861
  }
16132
16862
  } break;
16133
16863
  case GGML_OP_SQRT:
16134
16864
  {
16135
16865
  if (src0->grad) {
16136
16866
  src0->grad =
16137
- ggml_add_impl(ctx,
16867
+ ggml_add_or_set(ctx,
16138
16868
  src0->grad,
16139
16869
  ggml_scale(ctx,
16140
16870
  ggml_div(ctx,
16141
16871
  tensor->grad,
16142
16872
  tensor),
16143
16873
  ggml_new_f32(ctx, 0.5f)),
16144
- inplace);
16874
+ zero_table);
16145
16875
  }
16146
16876
  } break;
16147
16877
  case GGML_OP_LOG:
16148
16878
  {
16149
16879
  if (src0->grad) {
16150
16880
  src0->grad =
16151
- ggml_add_impl(ctx,
16881
+ ggml_add_or_set(ctx,
16152
16882
  src0->grad,
16153
16883
  ggml_div(ctx,
16154
16884
  tensor->grad,
16155
16885
  src0),
16156
- inplace);
16886
+ zero_table);
16157
16887
  }
16158
16888
  } break;
16159
16889
  case GGML_OP_SUM:
16160
16890
  {
16161
16891
  if (src0->grad) {
16162
16892
  src0->grad =
16163
- ggml_add1_impl(ctx,
16893
+ ggml_add1_or_set(ctx,
16164
16894
  src0->grad,
16165
16895
  tensor->grad,
16166
- inplace);
16896
+ zero_table);
16167
16897
  }
16168
16898
  } break;
16169
16899
  case GGML_OP_SUM_ROWS:
16170
16900
  {
16171
16901
  if (src0->grad) {
16172
16902
  src0->grad =
16173
- ggml_add_impl(ctx,
16903
+ ggml_add_or_set(ctx,
16174
16904
  src0->grad,
16175
16905
  ggml_repeat(ctx,
16176
16906
  tensor->grad,
16177
16907
  src0->grad),
16178
- inplace);
16908
+ zero_table);
16179
16909
  }
16180
16910
  } break;
16181
16911
  case GGML_OP_MEAN:
@@ -16187,20 +16917,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16187
16917
  {
16188
16918
  // necessary for llama
16189
16919
  if (src0->grad) {
16190
- src0->grad = ggml_add_impl(ctx,
16920
+ src0->grad = ggml_add_or_set(ctx,
16191
16921
  src0->grad,
16192
16922
  ggml_repeat_back(ctx, tensor->grad, src0->grad),
16193
- inplace);
16923
+ zero_table);
16194
16924
  }
16195
16925
  } break;
16196
16926
  case GGML_OP_REPEAT_BACK:
16197
16927
  {
16198
16928
  if (src0->grad) {
16199
16929
  // TODO: test this
16200
- src0->grad = ggml_add_impl(ctx,
16930
+ src0->grad = ggml_add_or_set(ctx,
16201
16931
  src0->grad,
16202
16932
  ggml_repeat(ctx, tensor->grad, src0->grad),
16203
- inplace);
16933
+ zero_table);
16204
16934
  }
16205
16935
  } break;
16206
16936
  case GGML_OP_CONCAT:
@@ -16222,10 +16952,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16222
16952
  float eps;
16223
16953
  memcpy(&eps, tensor->op_params, sizeof(float));
16224
16954
 
16225
- src0->grad = ggml_add_impl(ctx,
16955
+ src0->grad = ggml_add_or_set(ctx,
16226
16956
  src0->grad,
16227
16957
  ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
16228
- inplace);
16958
+ zero_table);
16229
16959
  }
16230
16960
  } break;
16231
16961
  case GGML_OP_RMS_NORM_BACK:
@@ -16249,37 +16979,49 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16249
16979
  // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
16250
16980
  // ds1 = t.T.dot(dt)
16251
16981
 
16252
- // tensor.shape [m,p]
16253
- // src0.shape [n,m]
16254
- // src1.shape [n,p]
16982
+ // tensor.shape [m,p,qq,rr]
16983
+ // src0.shape [n,m,q1,r1]
16984
+ // src1.shape [n,p,qq,rr]
16255
16985
 
16256
16986
  // necessary for llama
16257
16987
  if (src0->grad) {
16988
+ struct ggml_tensor * s1_tg =
16989
+ ggml_out_prod(ctx, // [n,m,qq,rr]
16990
+ src1, // [n,p,qq,rr]
16991
+ tensor->grad); // [m,p,qq,rr]
16992
+ const int64_t qq = s1_tg->ne[2];
16993
+ const int64_t rr = s1_tg->ne[3];
16994
+ const int64_t q1 = src0->ne[2];
16995
+ const int64_t r1 = src0->ne[3];
16996
+ const bool ne2_broadcasted = qq > q1;
16997
+ const bool ne3_broadcasted = rr > r1;
16998
+ if (ne2_broadcasted || ne3_broadcasted) {
16999
+ // sum broadcast repetitions of s1_tg into shape of src0
17000
+ s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
17001
+ }
16258
17002
  src0->grad =
16259
- ggml_add_impl(ctx,
16260
- src0->grad,
16261
- ggml_out_prod(ctx, // [n,m]
16262
- src1, // [n,p]
16263
- tensor->grad), // [m,p]
16264
- inplace);
17003
+ ggml_add_or_set(ctx,
17004
+ src0->grad, // [n,m,q1,r1]
17005
+ s1_tg, // [n,m,q1,r1]
17006
+ zero_table);
16265
17007
  }
16266
17008
  if (src1->grad) {
16267
17009
  src1->grad =
16268
- ggml_add_impl(ctx,
16269
- src1->grad,
16270
- // ggml_mul_mat(ctx, // [n,p]
16271
- // ggml_cont(ctx, // [m,n]
16272
- // ggml_transpose(ctx, src0)), // [m,n]
16273
- // tensor->grad), // [m,p]
17010
+ ggml_add_or_set(ctx,
17011
+ src1->grad, // [n,p,qq,rr]
17012
+ // ggml_mul_mat(ctx, // [n,p,qq,rr]
17013
+ // ggml_cont(ctx, // [m,n,q1,r1]
17014
+ // ggml_transpose(ctx, src0)), // [m,n,q1,r1]
17015
+ // tensor->grad), // [m,p,qq,rr]
16274
17016
 
16275
17017
  // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
16276
17018
  // // avoid transpose of src0, rather transpose smaller tensor->grad
16277
17019
  // // and then use ggml_out_prod
16278
- ggml_out_prod(ctx, // [n,p]
16279
- src0, // [n,m]
16280
- ggml_transpose(ctx, // [p,m]
16281
- tensor->grad)), // [m,p]
16282
- inplace);
17020
+ ggml_out_prod(ctx, // [n,p,qq,rr]
17021
+ src0, // [n,m,q1,r1]
17022
+ ggml_transpose(ctx, // [p,m,qq,rr]
17023
+ tensor->grad)), // [m,p,qq,rr]
17024
+ zero_table);
16283
17025
  }
16284
17026
  } break;
16285
17027
  case GGML_OP_OUT_PROD:
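The MUL_MAT backward hunk above now handles broadcasting over the two outer dimensions. A worked shape example (the concrete numbers are illustrative; shapes use ggml's ne[] order):

// src0   : [n=64, m=32, q1=1, r1=1]   weights shared across the batch
// src1   : [n=64, p=16, qq=8, rr=2]   batched activations
// tensor : [m=32, p=16, qq=8, rr=2]   = mul_mat(src0, src1)
// s1_tg  : [n=64, m=32, qq=8, rr=2]   = ggml_out_prod(src1, tensor->grad)
// qq > q1 and rr > r1, so the 8*2 broadcast copies of the weight gradient in s1_tg
// are summed back into [n=64, m=32, 1, 1] by ggml_repeat_back before being
// accumulated into src0->grad, keeping src0->grad the same shape as src0.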
@@ -16291,17 +17033,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16291
17033
  // necessary for llama
16292
17034
  if (src0->grad) {
16293
17035
  src0->grad =
16294
- ggml_add_impl(ctx,
17036
+ ggml_add_or_set(ctx,
16295
17037
  src0->grad,
16296
17038
  ggml_scale_impl(ctx, tensor->grad, src1, false),
16297
- inplace);
17039
+ zero_table);
16298
17040
  }
16299
17041
  if (src1->grad) {
16300
17042
  src1->grad =
16301
- ggml_add_impl(ctx,
17043
+ ggml_add_or_set(ctx,
16302
17044
  src1->grad,
16303
17045
  ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)),
16304
- inplace);
17046
+ zero_table);
16305
17047
  }
16306
17048
  } break;
16307
17049
  case GGML_OP_SET:
@@ -16328,23 +17070,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16328
17070
  }
16329
17071
 
16330
17072
  if (src0->grad) {
16331
- src0->grad = ggml_add_impl(ctx,
17073
+ src0->grad = ggml_add_or_set(ctx,
16332
17074
  src0->grad,
16333
17075
  ggml_acc_impl(ctx,
16334
17076
  tensor->grad,
16335
17077
  ggml_neg(ctx, tensor_grad_view),
16336
17078
  nb1, nb2, nb3, offset, false),
16337
- inplace);
17079
+ zero_table);
16338
17080
  }
16339
17081
 
16340
17082
  if (src1->grad) {
16341
17083
  src1->grad =
16342
- ggml_add_impl(ctx,
17084
+ ggml_add_or_set(ctx,
16343
17085
  src1->grad,
16344
17086
  ggml_reshape(ctx,
16345
17087
  ggml_cont(ctx, tensor_grad_view),
16346
17088
  src1->grad),
16347
- inplace);
17089
+ zero_table);
16348
17090
  }
16349
17091
  } break;
16350
17092
  case GGML_OP_CPY:
@@ -16355,7 +17097,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16355
17097
  // tensor = src0 * 1 + src1 * 0
16356
17098
  if (src0->grad) {
16357
17099
  // dsrc0 = dtensor * 1
16358
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
17100
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16359
17101
  }
16360
17102
  if (src1->grad) {
16361
17103
  // dsrc1 = dtensor * 0 -> noop
@@ -16367,7 +17109,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16367
17109
  if (src0->grad) {
16368
17110
  GGML_ASSERT(ggml_is_contiguous(src0->grad));
16369
17111
  GGML_ASSERT(ggml_is_contiguous(tensor->grad));
16370
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
17112
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16371
17113
  }
16372
17114
  } break;
16373
17115
  case GGML_OP_RESHAPE:
@@ -16375,9 +17117,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16375
17117
  // necessary for llama
16376
17118
  if (src0->grad) {
16377
17119
  src0->grad =
16378
- ggml_add_impl(ctx, src0->grad,
16379
- ggml_reshape(ctx, tensor->grad, src0->grad),
16380
- inplace);
17120
+ ggml_add_or_set(ctx, src0->grad,
17121
+ ggml_reshape(ctx,
17122
+ ggml_is_contiguous(tensor->grad)
17123
+ ? tensor->grad
17124
+ : ggml_cont(ctx, tensor->grad),
17125
+ src0->grad),
17126
+ zero_table);
16381
17127
  }
16382
17128
  } break;
16383
17129
  case GGML_OP_VIEW:
@@ -16406,7 +17152,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16406
17152
  nb3 = (nb3 / n0) * ng;
16407
17153
  }
16408
17154
 
16409
- src0->grad = ggml_acc_impl(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, inplace);
17155
+ src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table);
16410
17156
  }
16411
17157
  } break;
16412
17158
  case GGML_OP_PERMUTE:
@@ -16424,14 +17170,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16424
17170
  axes_backward[axis2] = 2;
16425
17171
  axes_backward[axis3] = 3;
16426
17172
  src0->grad =
16427
- ggml_add_impl(ctx, src0->grad,
17173
+ ggml_add_or_set(ctx, src0->grad,
16428
17174
  ggml_permute(ctx,
16429
17175
  tensor->grad,
16430
17176
  axes_backward[0],
16431
17177
  axes_backward[1],
16432
17178
  axes_backward[2],
16433
17179
  axes_backward[3]),
16434
- inplace);
17180
+ zero_table);
16435
17181
  }
16436
17182
  } break;
16437
17183
  case GGML_OP_TRANSPOSE:
@@ -16439,9 +17185,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16439
17185
  // necessary for llama
16440
17186
  if (src0->grad) {
16441
17187
  src0->grad =
16442
- ggml_add_impl(ctx, src0->grad,
17188
+ ggml_add_or_set(ctx, src0->grad,
16443
17189
  ggml_transpose(ctx, tensor->grad),
16444
- inplace);
17190
+ zero_table);
16445
17191
  }
16446
17192
  } break;
16447
17193
  case GGML_OP_GET_ROWS:
@@ -16449,9 +17195,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16449
17195
  // necessary for llama (only for tokenizer)
16450
17196
  if (src0->grad) {
16451
17197
  src0->grad =
16452
- ggml_add_impl(ctx, src0->grad,
17198
+ ggml_add_or_set(ctx, src0->grad,
17199
+ // last ggml_get_rows_back argument src0->grad is only
17200
+ // necessary to setup correct output shape
16453
17201
  ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
16454
- inplace);
17202
+ zero_table);
16455
17203
  }
16456
17204
  if (src1->grad) {
16457
17205
  // noop
@@ -16471,9 +17219,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16471
17219
  if (src0->grad) {
16472
17220
  const int n_past = ((int32_t *) tensor->op_params)[0];
16473
17221
  src0->grad =
16474
- ggml_add_impl(ctx, src0->grad,
17222
+ ggml_add_or_set(ctx, src0->grad,
16475
17223
  ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
16476
- inplace);
17224
+ zero_table);
16477
17225
  }
16478
17226
  } break;
16479
17227
  case GGML_OP_DIAG_MASK_ZERO:
@@ -16482,9 +17230,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16482
17230
  if (src0->grad) {
16483
17231
  const int n_past = ((int32_t *) tensor->op_params)[0];
16484
17232
  src0->grad =
16485
- ggml_add_impl(ctx, src0->grad,
17233
+ ggml_add_or_set(ctx, src0->grad,
16486
17234
  ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
16487
- inplace);
17235
+ zero_table);
16488
17236
  }
16489
17237
  } break;
16490
17238
  case GGML_OP_SOFT_MAX:
@@ -16492,9 +17240,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16492
17240
  // necessary for llama
16493
17241
  if (src0->grad) {
16494
17242
  src0->grad =
16495
- ggml_add_impl(ctx, src0->grad,
17243
+ ggml_add_or_set(ctx, src0->grad,
16496
17244
  ggml_soft_max_back(ctx, tensor->grad, tensor),
16497
- inplace);
17245
+ zero_table);
16498
17246
  }
16499
17247
 
16500
17248
  } break;
@@ -16506,7 +17254,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16506
17254
  {
16507
17255
  // necessary for llama
16508
17256
  if (src0->grad) {
16509
- const int n_past = ((int32_t *) tensor->op_params)[0];
17257
+ //const int n_past = ((int32_t *) tensor->op_params)[0];
16510
17258
  const int n_dims = ((int32_t *) tensor->op_params)[1];
16511
17259
  const int mode = ((int32_t *) tensor->op_params)[2];
16512
17260
  const int n_ctx = ((int32_t *) tensor->op_params)[3];
@@ -16519,11 +17267,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16519
17267
  memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
16520
17268
  memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
16521
17269
 
16522
- src0->grad = ggml_add_impl(ctx,
17270
+ src0->grad = ggml_add_or_set(ctx,
16523
17271
  src0->grad,
16524
17272
  ggml_rope_back(ctx,
16525
17273
  tensor->grad,
16526
- n_past,
17274
+ src1,
16527
17275
  n_dims,
16528
17276
  mode,
16529
17277
  n_ctx,
@@ -16531,13 +17279,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16531
17279
  freq_scale,
16532
17280
  xpos_base,
16533
17281
  xpos_down),
16534
- inplace);
17282
+ zero_table);
16535
17283
  }
16536
17284
  } break;
16537
17285
  case GGML_OP_ROPE_BACK:
16538
17286
  {
16539
17287
  if (src0->grad) {
16540
- const int n_past = ((int32_t *) tensor->op_params)[0];
17288
+ //const int n_past = ((int32_t *) tensor->op_params)[0];
16541
17289
  const int n_dims = ((int32_t *) tensor->op_params)[1];
16542
17290
  const int mode = ((int32_t *) tensor->op_params)[2];
16543
17291
  const int n_ctx = ((int32_t *) tensor->op_params)[3];
@@ -16550,11 +17298,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16550
17298
  memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
16551
17299
  memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
16552
17300
 
16553
- src0->grad = ggml_add_impl(ctx,
17301
+ src0->grad = ggml_add_or_set(ctx,
16554
17302
  src0->grad,
16555
17303
  ggml_rope_impl(ctx,
16556
17304
  tensor->grad,
16557
- n_past,
17305
+ src1,
16558
17306
  n_dims,
16559
17307
  mode,
16560
17308
  n_ctx,
@@ -16563,7 +17311,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16563
17311
  xpos_base,
16564
17312
  xpos_down,
16565
17313
  false),
16566
- inplace);
17314
+ zero_table);
16567
17315
  }
16568
17316
  } break;
16569
17317
  case GGML_OP_ALIBI:
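In the ROPE and ROPE_BACK cases above, the scalar n_past read from op_params is no longer used; the positions now arrive as a tensor (src1), which is forwarded to ggml_rope_back/ggml_rope_impl. A hedged sketch of how a caller would build such a positions tensor (the helper name and the n_past/n_tokens parameters are illustrative, not taken from this diff):

static struct ggml_tensor * make_rope_positions(struct ggml_context * ctx, int n_past, int n_tokens) {
    struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
    for (int i = 0; i < n_tokens; ++i) {
        ((int32_t *) pos->data)[i] = n_past + i; // absolute position of each token
    }
    return pos; // this is what shows up as src1 in the backward cases above
}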
@@ -16614,145 +17362,42 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16614
17362
  masked);
16615
17363
  }
16616
17364
 
16617
- if (src0->grad) {
16618
- struct ggml_tensor * grad_q = NULL;
16619
- const size_t nb0 = flash_grad->nb[0];
16620
- const size_t offset = 0;
16621
- switch(src0->n_dims) {
16622
- case 2:
16623
- {
16624
- grad_q = ggml_view_2d(ctx,
16625
- flash_grad,
16626
- src0->ne[0],
16627
- src0->ne[1],
16628
- nb0*src0->ne[0],
16629
- offset);
16630
- } break;
16631
- case 3:
16632
- {
16633
- grad_q = ggml_view_3d(ctx,
16634
- flash_grad,
16635
- src0->ne[0],
16636
- src0->ne[1],
16637
- src0->ne[2],
16638
- nb0*src0->ne[0],
16639
- nb0*src0->ne[0]*src0->ne[1],
16640
- offset);
16641
- } break;
16642
- case 4:
16643
- {
16644
- grad_q = ggml_view_4d(ctx,
16645
- flash_grad,
16646
- src0->ne[0],
16647
- src0->ne[1],
16648
- src0->ne[2],
16649
- src0->ne[3],
16650
- nb0*src0->ne[0],
16651
- nb0*src0->ne[0]*src0->ne[1],
16652
- nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
16653
- offset);
16654
- } break;
16655
- }
17365
+ struct ggml_tensor * src2 = tensor->src[2];
17366
+ const int64_t elem_q = ggml_nelements(src0);
17367
+ const int64_t elem_k = ggml_nelements(src1);
17368
+ const int64_t elem_v = ggml_nelements(src2);
17369
+
17370
+ enum ggml_type result_type = flash_grad->type;
17371
+ GGML_ASSERT(ggml_blck_size(result_type) == 1);
17372
+ const size_t tsize = ggml_type_size(result_type);
17373
+
17374
+ const size_t offs_q = 0;
17375
+ const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
17376
+ const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
16656
17377
 
16657
- src0->grad = ggml_add_impl(ctx,
17378
+ if (src0->grad) {
17379
+ struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q);
17380
+ struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0);
17381
+ src0->grad = ggml_add_or_set(ctx,
16658
17382
  src0->grad,
16659
17383
  grad_q,
16660
- inplace);
17384
+ zero_table);
16661
17385
  }
16662
-
16663
17386
  if (src1->grad) {
16664
- struct ggml_tensor * grad_k = NULL;
16665
- const size_t nb0 = flash_grad->nb[0];
16666
- const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
16667
- switch(src1->n_dims) {
16668
- case 2:
16669
- {
16670
- grad_k = ggml_view_2d(ctx,
16671
- flash_grad,
16672
- src1->ne[0],
16673
- src1->ne[1],
16674
- nb0*src1->ne[0],
16675
- offset);
16676
- } break;
16677
- case 3:
16678
- {
16679
- grad_k = ggml_view_3d(ctx,
16680
- flash_grad,
16681
- src1->ne[0],
16682
- src1->ne[1],
16683
- src1->ne[2],
16684
- nb0*src1->ne[0],
16685
- nb0*src1->ne[0]*src1->ne[1],
16686
- offset);
16687
- } break;
16688
- case 4:
16689
- {
16690
- grad_k = ggml_view_4d(ctx,
16691
- flash_grad,
16692
- src1->ne[0],
16693
- src1->ne[1],
16694
- src1->ne[2],
16695
- src1->ne[3],
16696
- nb0*src1->ne[0],
16697
- nb0*src1->ne[0]*src1->ne[1],
16698
- nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
16699
- offset);
16700
- } break;
16701
- }
16702
-
16703
- src1->grad = ggml_add_impl(ctx,
17387
+ struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
17388
+ struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1);
17389
+ src1->grad = ggml_add_or_set(ctx,
16704
17390
  src1->grad,
16705
17391
  grad_k,
16706
- inplace);
17392
+ zero_table);
16707
17393
  }
16708
-
16709
- struct ggml_tensor * opt0 = tensor->src[2];
16710
-
16711
- if (opt0->grad) {
16712
- struct ggml_tensor * grad_v = NULL;
16713
- const size_t nb0 = flash_grad->nb[0];
16714
- const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
16715
- + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
16716
- switch(opt0->n_dims) {
16717
- case 2:
16718
- {
16719
- grad_v = ggml_view_2d(ctx,
16720
- flash_grad,
16721
- opt0->ne[0],
16722
- opt0->ne[1],
16723
- nb0*opt0->ne[0],
16724
- offset);
16725
- } break;
16726
- case 3:
16727
- {
16728
- grad_v = ggml_view_3d(ctx,
16729
- flash_grad,
16730
- opt0->ne[0],
16731
- opt0->ne[1],
16732
- opt0->ne[2],
16733
- nb0*opt0->ne[0],
16734
- nb0*opt0->ne[0]*opt0->ne[1],
16735
- offset);
16736
- } break;
16737
- case 4:
16738
- {
16739
- grad_v = ggml_view_4d(ctx,
16740
- flash_grad,
16741
- opt0->ne[0],
16742
- opt0->ne[1],
16743
- opt0->ne[2],
16744
- opt0->ne[3],
16745
- nb0*opt0->ne[0],
16746
- nb0*opt0->ne[0]*opt0->ne[1],
16747
- nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
16748
- offset);
16749
- } break;
16750
- }
16751
-
16752
- opt0->grad = ggml_add_impl(ctx,
16753
- opt0->grad,
17394
+ if (src2->grad) {
17395
+ struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
17396
+ struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2);
17397
+ src2->grad = ggml_add_or_set(ctx,
17398
+ src2->grad,
16754
17399
  grad_v,
16755
- inplace);
17400
+ zero_table);
16756
17401
  }
16757
17402
  } break;
16758
17403
  case GGML_OP_FLASH_FF:
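The FLASH_ATTN backward hunk replaces the per-rank ggml_view_2d/3d/4d extraction of the q, k and v gradients with three 1-d views into flash_grad at padded offsets, each reshaped to the shape of the corresponding source. GGML_PAD is not defined in this excerpt; it is presumably the usual round-up macro, along these lines:

// assumed definition: round x up to a multiple of n (n a power of two)
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

// with F32 gradients (tsize == 4) and GGML_MEM_ALIGN == 16 (its usual value):
//   offs_q = 0
//   offs_k = GGML_PAD(elem_q * 4, 16)
//   offs_v = offs_k + GGML_PAD(elem_k * 4, 16)
// so the three gradients live back-to-back in one aligned buffer, and ggml_reshape
// restores the original q/k/v shapes regardless of their number of dimensions.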
@@ -16772,12 +17417,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16772
17417
  {
16773
17418
  if (src0->grad) {
16774
17419
  src0->grad =
16775
- ggml_add_impl(ctx,
17420
+ ggml_add_or_set(ctx,
16776
17421
  src0->grad,
16777
17422
  ggml_mul(ctx,
16778
17423
  ggml_sgn(ctx, src0),
16779
17424
  tensor->grad),
16780
- inplace);
17425
+ zero_table);
16781
17426
  }
16782
17427
  } break;
16783
17428
  case GGML_UNARY_OP_SGN:
@@ -16789,7 +17434,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16789
17434
  case GGML_UNARY_OP_NEG:
16790
17435
  {
16791
17436
  if (src0->grad) {
16792
- src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace);
17437
+ src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table);
16793
17438
  }
16794
17439
  } break;
16795
17440
  case GGML_UNARY_OP_STEP:
@@ -16809,12 +17454,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16809
17454
  case GGML_UNARY_OP_RELU:
16810
17455
  {
16811
17456
  if (src0->grad) {
16812
- src0->grad = ggml_add_impl(ctx,
17457
+ src0->grad = ggml_add_or_set(ctx,
16813
17458
  src0->grad,
16814
17459
  ggml_mul(ctx,
16815
17460
  ggml_step(ctx, src0),
16816
17461
  tensor->grad),
16817
- inplace);
17462
+ zero_table);
16818
17463
  }
16819
17464
  } break;
16820
17465
  case GGML_UNARY_OP_GELU:
@@ -16829,10 +17474,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16829
17474
  {
16830
17475
  // necessary for llama
16831
17476
  if (src0->grad) {
16832
- src0->grad = ggml_add_impl(ctx,
17477
+ src0->grad = ggml_add_or_set(ctx,
16833
17478
  src0->grad,
16834
17479
  ggml_silu_back(ctx, src0, tensor->grad),
16835
- inplace);
17480
+ zero_table);
16836
17481
  }
16837
17482
  } break;
16838
17483
  default:
@@ -16855,13 +17500,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16855
17500
  case GGML_OP_CROSS_ENTROPY_LOSS:
16856
17501
  {
16857
17502
  if (src0->grad) {
16858
- src0->grad = ggml_add_impl(ctx,
17503
+ src0->grad = ggml_add_or_set(ctx,
16859
17504
  src0->grad,
16860
17505
  ggml_cross_entropy_loss_back(ctx,
16861
17506
  src0,
16862
17507
  src1,
16863
17508
  tensor->grad),
16864
- inplace);
17509
+ zero_table);
16865
17510
  }
16866
17511
  } break;
16867
17512
  case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
@@ -16877,34 +17522,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16877
17522
  GGML_ASSERT(false);
16878
17523
  } break;
16879
17524
  }
16880
- }
16881
-
16882
- static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
16883
-
16884
- static size_t hash(void * p) {
16885
- return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
16886
- }
16887
17525
 
16888
- static bool hash_insert(void * hash_table[], void * p) {
16889
- size_t h = hash(p);
16890
-
16891
- // linear probing
16892
- size_t i = h;
16893
- while (hash_table[i] != NULL && hash_table[i] != p) {
16894
- i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
16895
- if (i == h) {
16896
- // hash table is full
16897
- GGML_ASSERT(false);
17526
+ for (int i = 0; i < GGML_MAX_SRC; ++i) {
17527
+ if (tensor->src[i] && tensor->src[i]->grad) {
17528
+ GGML_ASSERT(ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad));
16898
17529
  }
16899
17530
  }
16900
-
16901
- if (hash_table[i] == p) {
16902
- return true;
16903
- }
16904
-
16905
- // insert
16906
- hash_table[i] = p;
16907
- return false;
16908
17531
  }
16909
17532
 
16910
17533
  static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
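The static hash() and hash_insert() helpers are removed from this spot, but hash_insert() is still called further down when building the zero_table, so they have presumably just moved earlier in the file. A lookup counterpart consistent with the linear-probing insert shown above, which the *_or_set helpers would need, could look like this (an assumption, not code from this diff):

static bool hash_contains(void * hash_table[], void * p) {
    const size_t h = hash(p);
    size_t i = h;
    while (hash_table[i] != NULL && hash_table[i] != p) {
        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
        if (i == h) {
            return false; // wrapped around: table is full and p is not in it
        }
    }
    return hash_table[i] == p;
}

The new loop at the end of ggml_compute_backward also asserts that every gradient has the same shape as its tensor, which the repeat_back/reshape changes above are there to guarantee.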
@@ -16922,8 +17545,12 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
16922
17545
  }
16923
17546
 
16924
17547
  for (int i = 0; i < GGML_MAX_SRC; ++i) {
16925
- if (node->src[i]) {
16926
- ggml_visit_parents(cgraph, node->src[i]);
17548
+ const int k =
17549
+ (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
17550
+ (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
17551
+ /* unknown order, just fall back to using i*/ i;
17552
+ if (node->src[k]) {
17553
+ ggml_visit_parents(cgraph, node->src[k]);
16927
17554
  }
16928
17555
  }
16929
17556
 
@@ -16982,6 +17609,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
16982
17609
  /*.grads =*/ { NULL },
16983
17610
  /*.leafs =*/ { NULL },
16984
17611
  /*.hash_table =*/ { NULL },
17612
+ /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
16985
17613
  /*.perf_runs =*/ 0,
16986
17614
  /*.perf_cycles =*/ 0,
16987
17615
  /*.perf_time_us =*/ 0,
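ggml_cgraph gains an order field (initialized to GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT here and in ggml_new_graph below), and ggml_visit_parents honours it when walking node->src[], so the topological order of the built graph can be flipped. Assumed usage (the loss tensor is illustrative):

struct ggml_cgraph * gf = ggml_new_graph(ctx);
gf->order = GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT; // visit src[GGML_MAX_SRC-1] ... src[0]
ggml_build_forward_expand(gf, loss);

Flipping the order changes where intermediate results appear in the node list, which is presumably useful for the allocator and for building checkpointed training graphs.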
@@ -17007,12 +17635,22 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
17007
17635
  }
17008
17636
  }
17009
17637
 
17638
+ // remember original gradients which start with zero values
17639
+ void ** zero_table = malloc(sizeof(void *) * GGML_GRAPH_HASHTABLE_SIZE);
17640
+ memset(zero_table, 0, sizeof(void*) * GGML_GRAPH_HASHTABLE_SIZE);
17641
+ for (int i = 0; i < gf->n_nodes; i++) {
17642
+ if (gf->grads[i]) {
17643
+ hash_insert(zero_table, gf->grads[i]);
17644
+ }
17645
+ }
17646
+
17010
17647
  for (int i = gf->n_nodes - 1; i >= 0; i--) {
17011
17648
  struct ggml_tensor * node = gf->nodes[i];
17012
17649
 
17013
- // because we detached the grad nodes from the original graph, we can afford inplace operations
17650
+ // inplace operations to add gradients are not created by ggml_compute_backward
17651
+ // use allocator to automatically make inplace operations
17014
17652
  if (node->grad) {
17015
- ggml_compute_backward(ctx, node, keep);
17653
+ ggml_compute_backward(ctx, node, zero_table);
17016
17654
  }
17017
17655
  }
17018
17656
 
@@ -17024,6 +17662,8 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
17024
17662
  ggml_build_forward_expand(gb, node->grad);
17025
17663
  }
17026
17664
  }
17665
+
17666
+ free(zero_table);
17027
17667
  }
17028
17668
 
17029
17669
  struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
@@ -17043,6 +17683,7 @@ struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
17043
17683
  /*.grads =*/ { NULL },
17044
17684
  /*.leafs =*/ { NULL },
17045
17685
  /*.hash_table =*/ { NULL },
17686
+ /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
17046
17687
  /*.perf_runs =*/ 0,
17047
17688
  /*.perf_cycles =*/ 0,
17048
17689
  /*.perf_time_us =*/ 0,
@@ -17433,7 +18074,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
17433
18074
  } break;
17434
18075
  case GGML_OP_CONCAT:
17435
18076
  case GGML_OP_MUL_MAT:
17436
- case GGML_OP_OUT_PROD:
17437
18077
  {
17438
18078
  n_tasks = n_threads;
17439
18079
 
@@ -17475,6 +18115,18 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
17475
18115
  cur = 0;
17476
18116
  }
17477
18117
 
18118
+ work_size = MAX(work_size, cur);
18119
+ } break;
18120
+ case GGML_OP_OUT_PROD:
18121
+ {
18122
+ n_tasks = n_threads;
18123
+
18124
+ size_t cur = 0;
18125
+
18126
+ if (ggml_is_quantized(node->src[0]->type)) {
18127
+ cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
18128
+ }
18129
+
17478
18130
  work_size = MAX(work_size, cur);
17479
18131
  } break;
17480
18132
  case GGML_OP_SCALE:
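GGML_OP_OUT_PROD is split out of the MUL_MAT branch of ggml_graph_plan because it now needs its own work buffer when src0 is quantized: each thread dequantizes one row into F32 scratch. Illustrative sizing (the numbers are assumptions, the formula is the one in the hunk above):

// cur = ggml_type_size(GGML_TYPE_F32) * src0->ne[0] * n_tasks
// e.g. a quantized src0 with ne[0] = 4096 and n_tasks = 8 threads:
//   4 bytes * 4096 * 8 = 131072 bytes (128 KiB), independent of the number of rows.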
@@ -18568,7 +19220,7 @@ static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float *
18568
19220
  }
18569
19221
 
18570
19222
  static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
18571
- int i = 0;
19223
+ int64_t i = 0;
18572
19224
  for (int p = 0; p < np; ++p) {
18573
19225
  const int64_t ne = ggml_nelements(ps[p]) ;
18574
19226
  // TODO: add function to get all elements at once
@@ -18578,6 +19230,17 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
18578
19230
  }
18579
19231
  }
18580
19232
 
19233
+ static void ggml_opt_acc_grad(int np, struct ggml_tensor * const ps[], float * g, float scale) {
19234
+ int64_t i = 0;
19235
+ for (int p = 0; p < np; ++p) {
19236
+ const int64_t ne = ggml_nelements(ps[p]) ;
19237
+ // TODO: add function to get all elements at once
19238
+ for (int64_t j = 0; j < ne; ++j) {
19239
+ g[i++] += ggml_get_f32_1d(ps[p]->grad, j) * scale;
19240
+ }
19241
+ }
19242
+ }
19243
+
18581
19244
  //
18582
19245
  // ADAM
18583
19246
  //
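The new ggml_opt_acc_grad above is the building block for gradient accumulation: it adds each parameter's gradient into the flat g buffer, scaled by accum_norm = 1/n_accum, so after n_accum accumulation steps g holds the mean gradient over the steps. A small worked example (values are illustrative):

// n_accum = 4, per-step gradient of one element: 1.0, 2.0, 3.0, 2.0
// after the four ggml_opt_acc_grad(np, ps, g, 0.25f) calls:
//   g[i] = 0.25*(1.0 + 2.0 + 3.0 + 2.0) = 2.0   (the mean)
// the loss is averaged the same way: fx is summed per step and multiplied by accum_norm.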
@@ -18626,26 +19289,43 @@ static enum ggml_opt_result ggml_opt_adam(
18626
19289
  const float eps = params.adam.eps;
18627
19290
  const float gclip = params.adam.gclip;
18628
19291
  const int decay_min_ndim = params.adam.decay_min_ndim;
19292
+ const int n_accum = MAX(1, params.n_gradient_accumulation);
19293
+ const float accum_norm = 1.0f / (float) n_accum;
18629
19294
 
19295
+ float * g = opt->adam.g->data; // gradients
18630
19296
  float * m = opt->adam.m->data; // first moment
18631
19297
  float * v = opt->adam.v->data; // second moment
18632
19298
 
18633
19299
  float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
18634
19300
 
18635
- if (callback) {
18636
- callback(callback_data, &sched);
18637
- }
18638
-
18639
- // compute the function value
18640
- ggml_graph_reset (gf);
18641
- ggml_set_f32 (f->grad, 1.0f);
18642
-
18643
19301
  struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads);
18644
19302
  struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size);
18645
19303
  cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
18646
- ggml_graph_compute(gb, &cplan);
18647
19304
 
18648
- opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
19305
+ bool cancel = false;
19306
+
19307
+ // compute the function value
19308
+ float fx = 0;
19309
+ ggml_set_zero(opt->adam.g);
19310
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
19311
+ if (callback) {
19312
+ callback(callback_data, accum_step, &sched, &cancel);
19313
+ if (cancel) {
19314
+ break;
19315
+ }
19316
+ }
19317
+ // ggml_graph_reset (gf);
19318
+ ggml_set_f32 (f->grad, 1.0f);
19319
+ ggml_graph_compute(gb, &cplan);
19320
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
19321
+ fx += ggml_get_f32_1d(f, 0);
19322
+ }
19323
+ if (cancel) {
19324
+ return GGML_OPT_DID_NOT_CONVERGE;
19325
+ }
19326
+ fx *= accum_norm;
19327
+
19328
+ opt->adam.fx_prev = fx;
18649
19329
  opt->adam.fx_best = opt->adam.fx_prev;
18650
19330
  if (pf) {
18651
19331
  pf[opt->iter % params.past] = opt->adam.fx_prev;
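Both the initial evaluation above and the per-iteration evaluation below now run the forward/backward graph n_accum times, accumulating gradients with ggml_opt_acc_grad and polling the callback for cancellation at every accumulation step. The callback is invoked as callback(callback_data, accum_step, &sched, &cancel), so its type has presumably changed to something like:

// assumed signature, inferred from the call sites in this diff
typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel);

Setting *cancel makes Adam return GGML_OPT_DID_NOT_CONVERGE before the first iteration, or break out of the iteration loop afterwards.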
@@ -18668,6 +19348,9 @@ static enum ggml_opt_result ggml_opt_adam(
18668
19348
 
18669
19349
  // run the optimizer
18670
19350
  for (int t = 0; t < params.adam.n_iter; ++t) {
19351
+ if (cancel) {
19352
+ break;
19353
+ }
18671
19354
  opt->iter = iter0 + t + 1;
18672
19355
  GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
18673
19356
 
@@ -18690,12 +19373,8 @@ static enum ggml_opt_result ggml_opt_adam(
18690
19373
  if (gclip > 0.0f) {
18691
19374
  // gradient clipping
18692
19375
  ggml_float sum = 0.0;
18693
- for (int p = 0; p < np; ++p) {
18694
- const int64_t ne = ggml_nelements(ps[p]);
18695
- for (int64_t j = 0; j < ne; ++j) {
18696
- float g = ggml_get_f32_1d(ps[p]->grad, j);
18697
- sum += (ggml_float)(g*g);
18698
- }
19376
+ for (int64_t i = 0; i < nx; ++i) {
19377
+ sum += (ggml_float)(g[i]*g[i]);
18699
19378
  }
18700
19379
  ggml_float norm = sqrt(sum);
18701
19380
  if (norm > (ggml_float) gclip) {
@@ -18709,10 +19388,10 @@ static enum ggml_opt_result ggml_opt_adam(
18709
19388
  const int64_t ne = ggml_nelements(ps[p]);
18710
19389
  const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0f) * sched;
18711
19390
  for (int64_t j = 0; j < ne; ++j) {
18712
- float x = ggml_get_f32_1d(ps[p], j);
18713
- float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm;
18714
- m[i] = m[i]*beta1 + g*(1.0f - beta1);
18715
- v[i] = v[i]*beta2 + g*g*(1.0f - beta2);
19391
+ float x = ggml_get_f32_1d(ps[p], j);
19392
+ float g_ = g[i]*gnorm;
19393
+ m[i] = m[i]*beta1 + g_*(1.0f - beta1);
19394
+ v[i] = v[i]*beta2 + g_*g_*(1.0f - beta2);
18716
19395
  float mh = m[i]*beta1h;
18717
19396
  float vh = v[i]*beta2h;
18718
19397
  vh = sqrtf(vh) + eps;
@@ -18723,16 +19402,26 @@ static enum ggml_opt_result ggml_opt_adam(
18723
19402
  }
18724
19403
  }
18725
19404
 
18726
- if (callback) {
18727
- callback(callback_data, &sched);
19405
+ fx = 0;
19406
+ ggml_set_zero(opt->adam.g);
19407
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
19408
+ if (callback) {
19409
+ callback(callback_data, accum_step, &sched, &cancel);
19410
+ if (cancel) {
19411
+ break;
19412
+ }
19413
+ }
19414
+ // ggml_graph_reset (gf);
19415
+ ggml_set_f32 (f->grad, 1.0f);
19416
+ ggml_graph_compute(gb, &cplan);
19417
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
19418
+ fx += ggml_get_f32_1d(f, 0);
18728
19419
  }
19420
+ if (cancel) {
19421
+ break;
19422
+ }
19423
+ fx *= accum_norm;
18729
19424
 
18730
- ggml_graph_reset (gf);
18731
- ggml_set_f32 (f->grad, 1.0f);
18732
-
18733
- ggml_graph_compute(gb, &cplan);
18734
-
18735
- const float fx = ggml_get_f32_1d(f, 0);
18736
19425
  opt->loss_after = fx;
18737
19426
 
18738
19427
 
@@ -18812,11 +19501,11 @@ static enum ggml_opt_result linesearch_backtracking(
18812
19501
  float * step,
18813
19502
  const float * xp,
18814
19503
  struct ggml_tensor * f,
18815
- struct ggml_cgraph * gf,
18816
19504
  struct ggml_cgraph * gb,
18817
19505
  struct ggml_cplan * cplan,
18818
19506
  const int np,
18819
19507
  struct ggml_tensor * ps[],
19508
+ bool * cancel,
18820
19509
  ggml_opt_callback callback,
18821
19510
  void * callback_data) {
18822
19511
  int count = 0;
@@ -18830,6 +19519,9 @@ static enum ggml_opt_result linesearch_backtracking(
18830
19519
  const float dec = 0.5f;
18831
19520
  const float inc = 2.1f;
18832
19521
 
19522
+ const int n_accum = MAX(1, params->n_gradient_accumulation);
19523
+ const float accum_norm = 1.0f / (float) n_accum;
19524
+
18833
19525
  if (*step <= 0.f) {
18834
19526
  return GGML_LINESEARCH_INVALID_PARAMETERS;
18835
19527
  }
@@ -18846,13 +19538,7 @@ static enum ggml_opt_result linesearch_backtracking(
18846
19538
  finit = *fx;
18847
19539
  dgtest = params->lbfgs.ftol*dginit;
18848
19540
 
18849
- while (true) {
18850
- if (callback) {
18851
- // LBFG-S does not support learning rate -> ignore learning schedule
18852
- float sched = 0;
18853
- callback(callback_data, &sched);
18854
- }
18855
-
19541
+ while (!*cancel) {
18856
19542
  ggml_vec_cpy_f32(nx, x, xp);
18857
19543
  ggml_vec_mad_f32(nx, x, d, *step);
18858
19544
 
@@ -18860,14 +19546,28 @@ static enum ggml_opt_result linesearch_backtracking(
18860
19546
  {
18861
19547
  ggml_opt_set_params(np, ps, x);
18862
19548
 
18863
- ggml_graph_reset (gf);
18864
- ggml_set_f32 (f->grad, 1.0f);
18865
-
18866
- ggml_graph_compute(gb, cplan);
18867
-
18868
- ggml_opt_get_grad(np, ps, g);
19549
+ *fx = 0;
19550
+ memset(g, 0, sizeof(float)*nx);
19551
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
19552
+ if (callback) {
19553
+ // LBFG-S does not support learning rate -> ignore learning schedule
19554
+ float sched = 0;
19555
+ callback(callback_data, accum_step, &sched, cancel);
19556
+ if (*cancel) {
19557
+ break;
19558
+ }
19559
+ }
19560
+ // ggml_graph_reset (gf);
19561
+ ggml_set_f32 (f->grad, 1.0f);
19562
+ ggml_graph_compute(gb, cplan);
19563
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
19564
+ *fx += ggml_get_f32_1d(f, 0);
19565
+ }
19566
+ if (*cancel) {
19567
+ break;
19568
+ }
19569
+ *fx *= accum_norm;
18869
19570
 
18870
- *fx = ggml_get_f32_1d(f, 0);
18871
19571
  }
18872
19572
 
18873
19573
  ++count;
@@ -18913,7 +19613,7 @@ static enum ggml_opt_result linesearch_backtracking(
18913
19613
  (*step) *= width;
18914
19614
  }
18915
19615
 
18916
- return GGML_LINESEARCH_FAIL;
19616
+ GGML_UNREACHABLE();
18917
19617
  }
18918
19618
 
18919
19619
  static enum ggml_opt_result ggml_opt_lbfgs(
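linesearch_backtracking now loops while (!*cancel), and the unreachable fall-through return GGML_LINESEARCH_FAIL (and, further down, ggml_opt_lbfgs's return GGML_OPT_DID_NOT_CONVERGE at the end of its loop) is replaced by GGML_UNREACHABLE(). The macro is not defined in this excerpt; presumably it is a thin wrapper along these lines:

// assumed definition
#define GGML_UNREACHABLE() GGML_ASSERT(!"statement should not be reached")

In other words, the loops above can only be left through an explicit return, so falling out of them is now flagged as a bug instead of silently reported as a failed line search.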
@@ -18968,6 +19668,9 @@ static enum ggml_opt_result ggml_opt_lbfgs(
18968
19668
 
18969
19669
  float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
18970
19670
 
19671
+ const int n_accum = MAX(1, params.n_gradient_accumulation);
19672
+ const float accum_norm = 1.0f / (float) n_accum;
19673
+
18971
19674
  float fx = 0.0f; // cost function value
18972
19675
  float xnorm = 0.0f; // ||x||
18973
19676
  float gnorm = 0.0f; // ||g||
@@ -18981,24 +19684,33 @@ static enum ggml_opt_result ggml_opt_lbfgs(
18981
19684
  float * lm_s = opt->lbfgs.lms->data;
18982
19685
  float * lm_y = opt->lbfgs.lmy->data;
18983
19686
 
18984
- if (callback) {
18985
- // LBFG-S does not support learning rate -> ignore learning schedule
18986
- float sched = 0;
18987
- callback(callback_data, &sched);
18988
- }
19687
+ bool cancel = false;
18989
19688
 
18990
19689
  // evaluate the function value and its gradient
18991
19690
  {
18992
19691
  ggml_opt_set_params(np, ps, x);
18993
19692
 
18994
- ggml_graph_reset (gf);
18995
- ggml_set_f32 (f->grad, 1.0f);
18996
-
18997
- ggml_graph_compute(gb, &cplan);
18998
-
18999
- ggml_opt_get_grad(np, ps, g);
19000
-
19001
- fx = ggml_get_f32_1d(f, 0);
19693
+ fx = 0;
19694
+ memset(g, 0, sizeof(float)*nx);
19695
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
19696
+ if (callback) {
19697
+ // LBFG-S does not support learning rate -> ignore learning schedule
19698
+ float sched = 0;
19699
+ callback(callback_data, accum_step, &sched, &cancel);
19700
+ if (cancel) {
19701
+ break;
19702
+ }
19703
+ }
19704
+ // ggml_graph_reset (gf);
19705
+ ggml_set_f32 (f->grad, 1.0f);
19706
+ ggml_graph_compute(gb, &cplan);
19707
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
19708
+ fx += ggml_get_f32_1d(f, 0);
19709
+ }
19710
+ if (cancel) {
19711
+ return GGML_OPT_DID_NOT_CONVERGE;
19712
+ }
19713
+ fx *= accum_norm;
19002
19714
 
19003
19715
  opt->loss_before = fx;
19004
19716
  opt->loss_after = fx;
@@ -19056,7 +19768,10 @@ static enum ggml_opt_result ggml_opt_lbfgs(
19056
19768
  ggml_vec_cpy_f32(nx, xp, x);
19057
19769
  ggml_vec_cpy_f32(nx, gp, g);
19058
19770
 
19059
- ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gf, gb, &cplan, np, ps, callback, callback_data);
19771
+ ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data);
19772
+ if (!cancel) {
19773
+ break;
19774
+ }
19060
19775
 
19061
19776
  if (ls < 0) {
19062
19777
  // linesearch failed - go back to the previous point and return
@@ -19165,7 +19880,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
19165
19880
  step[0] = 1.0;
19166
19881
  }
19167
19882
 
19168
- return GGML_OPT_DID_NOT_CONVERGE;
19883
+ GGML_UNREACHABLE();
19169
19884
  }
19170
19885
 
19171
19886
  struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
@@ -19185,6 +19900,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
19185
19900
  .print_forward_graph = true,
19186
19901
  .print_backward_graph = true,
19187
19902
 
19903
+ .n_gradient_accumulation = 1,
19904
+
19188
19905
  .adam = {
19189
19906
  .n_iter = 10000,
19190
19907
  .sched = 1.000f,
@@ -19213,6 +19930,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
19213
19930
  .print_forward_graph = true,
19214
19931
  .print_backward_graph = true,
19215
19932
 
19933
+ .n_gradient_accumulation = 1,
19934
+
19216
19935
  .lbfgs = {
19217
19936
  .m = 6,
19218
19937
  .n_iter = 100,
@@ -19243,13 +19962,32 @@ GGML_API void ggml_opt_init(
19243
19962
  opt->iter = 0;
19244
19963
  opt->nx = nx;
19245
19964
  opt->just_initialized = true;
19965
+ if (opt->ctx == NULL) {
19966
+ struct ggml_init_params ctx_opt_params;
19967
+ if (opt->params.type == GGML_OPT_ADAM) {
19968
+ ctx_opt_params.mem_size = GGML_MEM_ALIGN*3 + ggml_tensor_overhead()*3 + ggml_type_size(GGML_TYPE_F32)*nx*3;
19969
+ if (opt->params.past > 0) {
19970
+ ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
19971
+ }
19972
+ } else if (opt->params.type == GGML_OPT_LBFGS) {
19973
+ ctx_opt_params.mem_size = GGML_MEM_ALIGN*9 + ggml_tensor_overhead()*9 + ggml_type_size(GGML_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2);
19974
+ if (opt->params.past > 0) {
19975
+ ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
19976
+ }
19977
+ }
19978
+ ctx_opt_params.mem_buffer = NULL;
19979
+ ctx_opt_params.no_alloc = false;
19980
+
19981
+ opt->ctx = ggml_init(ctx_opt_params);
19982
+ }
19246
19983
  switch (opt->params.type) {
19247
19984
  case GGML_OPT_ADAM:
19248
19985
  {
19249
- opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19250
- opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19986
+ opt->adam.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
19987
+ opt->adam.m = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
19988
+ opt->adam.v = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
19251
19989
  opt->adam.pf = params.past > 0
19252
- ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
19990
+ ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
19253
19991
  : NULL;
19254
19992
  ggml_set_zero(opt->adam.m);
19255
19993
  ggml_set_zero(opt->adam.v);
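ggml_opt_init now allocates the optimizer state in its own context (opt->ctx) instead of the caller's ctx, sizing it from nx and the optimizer type. Rough illustrative sizing for Adam with nx = 1,000,000 parameters and params.past == 0 (the parameter count is an assumption):

// 3 F32 tensors (g, m, v) of nx elements: 3 * 4 * 1,000,000 = 12,000,000 bytes (~11.4 MiB)
// plus 3 * ggml_tensor_overhead() and 3 * GGML_MEM_ALIGN bytes of slack,
// which is exactly what the new ctx_opt_params.mem_size expression above reserves.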
@@ -19259,18 +19997,18 @@ GGML_API void ggml_opt_init(
19259
19997
  } break;
19260
19998
  case GGML_OPT_LBFGS:
19261
19999
  {
19262
- opt->lbfgs.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19263
- opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19264
- opt->lbfgs.g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19265
- opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19266
- opt->lbfgs.d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
20000
+ opt->lbfgs.x = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
20001
+ opt->lbfgs.xp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
20002
+ opt->lbfgs.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
20003
+ opt->lbfgs.gp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
20004
+ opt->lbfgs.d = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
19267
20005
  opt->lbfgs.pf = params.past > 0
19268
- ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
20006
+ ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
19269
20007
  : NULL;
19270
- opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
19271
- opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
19272
- opt->lbfgs.lms = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
19273
- opt->lbfgs.lmy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
20008
+ opt->lbfgs.lmal = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
20009
+ opt->lbfgs.lmys = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
20010
+ opt->lbfgs.lms = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
20011
+ opt->lbfgs.lmy = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
19274
20012
  ggml_set_zero(opt->lbfgs.x);
19275
20013
  ggml_set_zero(opt->lbfgs.xp);
19276
20014
  ggml_set_zero(opt->lbfgs.g);
@@ -19876,10 +20614,10 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
19876
20614
  } break;
19877
20615
  case GGUF_TYPE_ARRAY:
19878
20616
  case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break;
19879
- };
20617
+ }
19880
20618
  } break;
19881
20619
  case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type");
19882
- };
20620
+ }
19883
20621
 
19884
20622
  if (!ok) {
19885
20623
  break;
@@ -20155,78 +20893,94 @@ int gguf_find_key(const struct gguf_context * ctx, const char * key) {
20155
20893
  return keyfound;
20156
20894
  }
20157
20895
 
20158
- const char * gguf_get_key(const struct gguf_context * ctx, int i) {
20159
- return ctx->kv[i].key.data;
20896
+ const char * gguf_get_key(const struct gguf_context * ctx, int key_id) {
20897
+ return ctx->kv[key_id].key.data;
20160
20898
  }
20161
20899
 
20162
- enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int i) {
20163
- return ctx->kv[i].type;
20900
+ enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) {
20901
+ return ctx->kv[key_id].type;
20164
20902
  }
20165
20903
 
20166
- enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i) {
20167
- return ctx->kv[i].value.arr.type;
20904
+ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) {
20905
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
20906
+ return ctx->kv[key_id].value.arr.type;
20168
20907
  }
20169
20908
 
20170
- const void * gguf_get_arr_data(const struct gguf_context * ctx, int i) {
20171
- return ctx->kv[i].value.arr.data;
20909
+ const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) {
20910
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
20911
+ return ctx->kv[key_id].value.arr.data;
20172
20912
  }
20173
20913
 
20174
20914
  const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
20915
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
20175
20916
  struct gguf_kv * kv = &ctx->kv[key_id];
20176
20917
  struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
20177
20918
  return str->data;
20178
20919
  }
20179
20920
 
20180
- int gguf_get_arr_n(const struct gguf_context * ctx, int i) {
20181
- return ctx->kv[i].value.arr.n;
20921
+ int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
20922
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
20923
+ return ctx->kv[key_id].value.arr.n;
20182
20924
  }
20183
20925
 
20184
- uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int i) {
20185
- return ctx->kv[i].value.uint8;
20926
+ uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) {
20927
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8);
20928
+ return ctx->kv[key_id].value.uint8;
20186
20929
  }
20187
20930
 
20188
- int8_t gguf_get_val_i8(const struct gguf_context * ctx, int i) {
20189
- return ctx->kv[i].value.int8;
20931
+ int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) {
20932
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8);
20933
+ return ctx->kv[key_id].value.int8;
20190
20934
  }
20191
20935
 
20192
- uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int i) {
20193
- return ctx->kv[i].value.uint16;
20936
+ uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) {
20937
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16);
20938
+ return ctx->kv[key_id].value.uint16;
20194
20939
  }
20195
20940
 
20196
- int16_t gguf_get_val_i16(const struct gguf_context * ctx, int i) {
20197
- return ctx->kv[i].value.int16;
20941
+ int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) {
20942
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16);
20943
+ return ctx->kv[key_id].value.int16;
20198
20944
  }
20199
20945
 
20200
- uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int i) {
20201
- return ctx->kv[i].value.uint32;
20946
+ uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) {
20947
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32);
20948
+ return ctx->kv[key_id].value.uint32;
20202
20949
  }
20203
20950
 
20204
- int32_t gguf_get_val_i32(const struct gguf_context * ctx, int i) {
20205
- return ctx->kv[i].value.int32;
20951
+ int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) {
20952
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32);
20953
+ return ctx->kv[key_id].value.int32;
20206
20954
  }
20207
20955
 
20208
- float gguf_get_val_f32(const struct gguf_context * ctx, int i) {
20209
- return ctx->kv[i].value.float32;
20956
+ float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) {
20957
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32);
20958
+ return ctx->kv[key_id].value.float32;
20210
20959
  }
20211
20960
 
20212
- uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int i) {
20213
- return ctx->kv[i].value.uint64;
20961
+ uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) {
20962
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64);
20963
+ return ctx->kv[key_id].value.uint64;
20214
20964
  }
20215
20965
 
20216
- int64_t gguf_get_val_i64(const struct gguf_context * ctx, int i) {
20217
- return ctx->kv[i].value.int64;
20966
+ int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) {
20967
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64);
20968
+ return ctx->kv[key_id].value.int64;
20218
20969
  }
20219
20970
 
20220
- double gguf_get_val_f64(const struct gguf_context * ctx, int i) {
20221
- return ctx->kv[i].value.float64;
20971
+ double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) {
20972
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64);
20973
+ return ctx->kv[key_id].value.float64;
20222
20974
  }
20223
20975
 
20224
- bool gguf_get_val_bool(const struct gguf_context * ctx, int i) {
20225
- return ctx->kv[i].value.bool_;
20976
+ bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) {
20977
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL);
20978
+ return ctx->kv[key_id].value.bool_;
20226
20979
  }
20227
20980
 
20228
- const char * gguf_get_val_str (const struct gguf_context * ctx, int i) {
20229
- return ctx->kv[i].value.str.data;
20981
+ const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
20982
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
20983
+ return ctx->kv[key_id].value.str.data;
20230
20984
  }
20231
20985
 
20232
20986
  int gguf_get_n_tensors(const struct gguf_context * ctx) {
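The gguf getters above are renamed to take a key_id and now GGML_ASSERT that the stored type matches the accessor (GGUF_TYPE_ARRAY for the array getters, the scalar type otherwise). Assumed usage pattern with the stricter getters (the key name is illustrative): check the type before reading, since a mismatch now aborts instead of returning a reinterpreted value.

const int key_id = gguf_find_key(ctx, "general.architecture");
if (key_id >= 0 && gguf_get_kv_type(ctx, key_id) == GGUF_TYPE_STRING) {
    const char * arch = gguf_get_val_str(ctx, key_id);
    // use arch ...
}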
@@ -20591,10 +21345,10 @@ static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf *
20591
21345
  } break;
20592
21346
  case GGUF_TYPE_ARRAY:
20593
21347
  case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break;
20594
- };
21348
+ }
20595
21349
  } break;
20596
21350
  case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type");
20597
- };
21351
+ }
20598
21352
  }
20599
21353
 
20600
21354
  // write tensor infos