llama_cpp 0.5.3 → 0.6.0

This diff compares the contents of publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
@@ -89,7 +89,9 @@ static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(vo
 
 static int pthread_join(pthread_t thread, void * unused) {
     (void) unused;
-    return (int) WaitForSingleObject(thread, INFINITE);
+    int ret = (int) WaitForSingleObject(thread, INFINITE);
+    CloseHandle(thread);
+    return ret;
 }
 
 static int sched_yield (void) {
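Note: the new CloseHandle call fixes a thread-handle leak in the Windows pthread shim; handles returned by _beginthreadex stay allocated until they are explicitly closed, even after the thread has exited. A minimal sketch of the join-and-close pattern (illustrative only, assumes <windows.h> is included as in the surrounding shim):

    static int win32_thread_join(HANDLE thread) {
        // Wait for the thread to finish, then release the kernel object so that
        // repeatedly creating worker threads does not exhaust the handle quota.
        DWORD rc = WaitForSingleObject(thread, INFINITE);
        CloseHandle(thread);
        return rc == WAIT_OBJECT_0 ? 0 : -1;
    }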
@@ -134,6 +136,7 @@ typedef void * thread_ret_t;
 
 #define GGML_SOFT_MAX_UNROLL 4
 #define GGML_VEC_DOT_UNROLL  2
+#define GGML_VEC_MAD_UNROLL  32
 
 //
 // logging
@@ -242,18 +245,18 @@ inline static void * ggml_aligned_malloc(size_t size) {
 //
 
 #define GGML_TENSOR_UNARY_OP_LOCALS \
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb); \
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne); \
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb);
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
 
 #define GGML_TENSOR_BINARY_OP_LOCALS \
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb); \
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); \
-    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb); \
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne); \
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb);
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
 
 #if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
@@ -1863,7 +1866,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
 #define GGML_F16x8_ADD vaddq_f16
 #define GGML_F16x8_MUL vmulq_f16
 #define GGML_F16x8_REDUCE(res, x) \
-{ \
+do { \
     int offset = GGML_F16_ARR >> 1; \
     for (int i = 0; i < offset; ++i) { \
         x[i] = vaddq_f16(x[i], x[offset+i]); \
@@ -1879,7 +1882,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
     const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
     const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
     res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \
-}
+} while (0)
 
 #define GGML_F16_VEC      GGML_F16x8
 #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
@@ -1940,7 +1943,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
 #define GGML_F32x8_ADD _mm256_add_ps
 #define GGML_F32x8_MUL _mm256_mul_ps
 #define GGML_F32x8_REDUCE(res, x) \
-{ \
+do { \
     int offset = GGML_F32_ARR >> 1; \
     for (int i = 0; i < offset; ++i) { \
         x[i] = _mm256_add_ps(x[i], x[offset+i]); \
@@ -1957,7 +1960,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
                                  _mm256_extractf128_ps(x[0], 1)); \
     const __m128 t1 = _mm_hadd_ps(t0, t0); \
     res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \
-}
+} while (0)
 // TODO: is this optimal ?
 
 #define GGML_F32_VEC GGML_F32x8
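Note: rewriting the multi-statement REDUCE macros as do { ... } while (0) is the standard C idiom for making a macro expansion behave like a single statement, so it can be followed by a semicolon inside an unbraced if/else without changing control flow. A small illustration with a hypothetical macro (not from this diff):

    #define RESET_BAD(x)  { (x) = 0; flush_logs(); }              /* "if (c) RESET_BAD(v); else ..." misparses */
    #define RESET_GOOD(x) do { (x) = 0; flush_logs(); } while (0) /* one statement; the trailing ';' is consumed */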
@@ -3707,6 +3710,58 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
 #endif
 }
 
+// xs and vs are byte strides of x and v
+inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
+
+    const float * restrict x[GGML_VEC_MAD_UNROLL];
+    const float * restrict v[GGML_VEC_MAD_UNROLL];
+
+    for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
+        x[i] = (const float *) ((const char *) xv + i*xs);
+        v[i] = (const float *) ((const char *) vv + i*vs);
+    }
+
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];
+
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        vx[k] = GGML_F32_VEC_SET1(v[k][0]);
+    }
+
+    GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+
+            for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+                ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
+                ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
+            }
+
+            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+        }
+    }
+
+    // leftovers
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        for (int i = np; i < n; ++i) {
+            y[i] += x[k][i]*v[k][0];
+        }
+    }
+#else
+    // scalar
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        for (int i = 0; i < n; ++i) {
+            y[i] += x[k][i]*v[k][0];
+        }
+    }
+#endif
+}
+
 //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
 inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
 #if defined(GGML_USE_ACCELERATE)
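Note: ggml_vec_mad_f32_unroll fuses GGML_VEC_MAD_UNROLL (32) axpy updates into a single pass over y; xs and vs are byte strides, so x[k] and v[k] address 32 consecutive rows of the source matrices. Functionally it is equivalent to calling ggml_vec_mad_f32 once per row; a scalar reference sketch (illustrative only, the function name is hypothetical):

    // y[i] += sum over k of x_k[i] * v_k[0], for k = 0..31
    static void vec_mad_unroll_ref(int n, int xs, int vs, float * y,
                                   const float * xv, const float * vv) {
        for (int k = 0; k < 32; ++k) {
            const float * x = (const float *) ((const char *) xv + (size_t) k*xs);
            const float   v = *(const float *) ((const char *) vv + (size_t) k*vs);
            for (int i = 0; i < n; ++i) {
                y[i] += x[i]*v;
            }
        }
    }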
@@ -4392,10 +4447,9 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
 static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
-    return
-        (t0->ne[1] == t1->ne[1]) &&
-        (t0->ne[2] == t1->ne[2]) &&
-        (t0->ne[3] == t1->ne[3]);
+    return (t0->ne[1]           == t1->ne[1]) &&
+           (t1->ne[2]%t0->ne[2] == 0)         && // verify t0 is broadcastable
+           (t1->ne[3]%t0->ne[3] == 0);
 }
 
 enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
@@ -5065,43 +5119,78 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
5065
5119
  return tensor;
5066
5120
  }
5067
5121
 
5122
+ void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) {
5123
+ const int64_t ne2 = tensor->ne[2];
5124
+ const int64_t ne1 = tensor->ne[1];
5125
+ const int64_t ne0 = tensor->ne[0];
5126
+
5127
+ const int64_t i3_ = (i/(ne2*ne1*ne0));
5128
+ const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0);
5129
+ const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0;
5130
+ const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0);
5131
+
5132
+ if (i0) {
5133
+ * i0 = i0_;
5134
+ }
5135
+ if (i1) {
5136
+ * i1 = i1_;
5137
+ }
5138
+ if (i2) {
5139
+ * i2 = i2_;
5140
+ }
5141
+ if (i3) {
5142
+ * i3 = i3_;
5143
+ }
5144
+ }
5145
+
5068
5146
  int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
5147
+ if (!ggml_is_contiguous(tensor)) {
5148
+ int64_t id[4] = { 0, 0, 0, 0 };
5149
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
5150
+ return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]);
5151
+ }
5069
5152
  switch (tensor->type) {
5070
5153
  case GGML_TYPE_I8:
5071
5154
  {
5072
5155
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
5073
5156
  return ((int8_t *)(tensor->data))[i];
5074
- } break;
5157
+ }
5075
5158
  case GGML_TYPE_I16:
5076
5159
  {
5077
5160
  GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
5078
5161
  return ((int16_t *)(tensor->data))[i];
5079
- } break;
5162
+ }
5080
5163
  case GGML_TYPE_I32:
5081
5164
  {
5082
5165
  GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
5083
5166
  return ((int32_t *)(tensor->data))[i];
5084
- } break;
5167
+ }
5085
5168
  case GGML_TYPE_F16:
5086
5169
  {
5087
5170
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
5088
5171
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
5089
- } break;
5172
+ }
5090
5173
  case GGML_TYPE_F32:
5091
5174
  {
5092
5175
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
5093
5176
  return ((float *)(tensor->data))[i];
5094
- } break;
5177
+ }
5095
5178
  default:
5096
5179
  {
5097
5180
  GGML_ASSERT(false);
5098
- } break;
5181
+ }
5099
5182
  }
5100
5183
 
5101
5184
  return 0.0f;
5102
5185
  }
5103
5186
 
5104
5187
  void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
5188
+ if (!ggml_is_contiguous(tensor)) {
5189
+ int64_t id[4] = { 0, 0, 0, 0 };
5190
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
5191
+ ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value);
5192
+ return;
5193
+ }
5105
5194
  switch (tensor->type) {
5106
5195
  case GGML_TYPE_I8:
5107
5196
  {
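Note: ggml_unravel_index converts a flat element index into (i0, i1, i2, i3) coordinates so that the *_1d accessors above keep working on non-contiguous tensors. For example, with ne = {4, 3, 2, 1} a flat index of 17 unravels to i3 = 17/24 = 0, i2 = 17/12 = 1, i1 = (17 - 12)/4 = 1, i0 = 17 - 12 - 4 = 1, i.e. 17 = 1 + 1*4 + 1*12. A hedged usage sketch (t stands for any valid ggml_tensor pointer):

    int64_t i0, i1, i2, i3;
    ggml_unravel_index(t, 17, &i0, &i1, &i2, &i3);   // any output pointer may be NULL
    float v = ggml_get_f32_nd(t, i0, i1, i2, i3);    // strided read via the nb[] byte strides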
@@ -5135,43 +5224,104 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
5135
5224
  }
5136
5225
  }
5137
5226
 
5227
+ int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
5228
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
5229
+ switch (tensor->type) {
5230
+ case GGML_TYPE_I8:
5231
+ return ((int8_t *) data)[0];
5232
+ case GGML_TYPE_I16:
5233
+ return ((int16_t *) data)[0];
5234
+ case GGML_TYPE_I32:
5235
+ return ((int32_t *) data)[0];
5236
+ case GGML_TYPE_F16:
5237
+ return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
5238
+ case GGML_TYPE_F32:
5239
+ return ((float *) data)[0];
5240
+ default:
5241
+ GGML_ASSERT(false);
5242
+ }
5243
+
5244
+ return 0.0f;
5245
+ }
5246
+
5247
+ void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) {
5248
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
5249
+ switch (tensor->type) {
5250
+ case GGML_TYPE_I8:
5251
+ {
5252
+ ((int8_t *)(data))[0] = value;
5253
+ } break;
5254
+ case GGML_TYPE_I16:
5255
+ {
5256
+ ((int16_t *)(data))[0] = value;
5257
+ } break;
5258
+ case GGML_TYPE_I32:
5259
+ {
5260
+ ((int32_t *)(data))[0] = value;
5261
+ } break;
5262
+ case GGML_TYPE_F16:
5263
+ {
5264
+ ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
5265
+ } break;
5266
+ case GGML_TYPE_F32:
5267
+ {
5268
+ ((float *)(data))[0] = value;
5269
+ } break;
5270
+ default:
5271
+ {
5272
+ GGML_ASSERT(false);
5273
+ } break;
5274
+ }
5275
+ }
5276
+
5138
5277
  float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
5278
+ if (!ggml_is_contiguous(tensor)) {
5279
+ int64_t id[4] = { 0, 0, 0, 0 };
5280
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
5281
+ return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]);
5282
+ }
5139
5283
  switch (tensor->type) {
5140
5284
  case GGML_TYPE_I8:
5141
5285
  {
5142
5286
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
5143
5287
  return ((int8_t *)(tensor->data))[i];
5144
- } break;
5288
+ }
5145
5289
  case GGML_TYPE_I16:
5146
5290
  {
5147
5291
  GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
5148
5292
  return ((int16_t *)(tensor->data))[i];
5149
- } break;
5293
+ }
5150
5294
  case GGML_TYPE_I32:
5151
5295
  {
5152
5296
  GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
5153
5297
  return ((int32_t *)(tensor->data))[i];
5154
- } break;
5298
+ }
5155
5299
  case GGML_TYPE_F16:
5156
5300
  {
5157
5301
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
5158
5302
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
5159
- } break;
5303
+ }
5160
5304
  case GGML_TYPE_F32:
5161
5305
  {
5162
5306
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
5163
5307
  return ((float *)(tensor->data))[i];
5164
- } break;
5308
+ }
5165
5309
  default:
5166
5310
  {
5167
5311
  GGML_ASSERT(false);
5168
- } break;
5312
+ }
5169
5313
  }
5170
5314
 
5171
5315
  return 0.0f;
5172
5316
  }
5173
5317
 
5174
5318
  void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
5319
+ if (!ggml_is_contiguous(tensor)) {
5320
+ int64_t id[4] = { 0, 0, 0, 0 };
5321
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
5322
+ ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value);
5323
+ return;
5324
+ }
5175
5325
  switch (tensor->type) {
5176
5326
  case GGML_TYPE_I8:
5177
5327
  {
@@ -5205,6 +5355,56 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
5205
5355
  }
5206
5356
  }
5207
5357
 
5358
+ float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
5359
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
5360
+ switch (tensor->type) {
5361
+ case GGML_TYPE_I8:
5362
+ return ((int8_t *) data)[0];
5363
+ case GGML_TYPE_I16:
5364
+ return ((int16_t *) data)[0];
5365
+ case GGML_TYPE_I32:
5366
+ return ((int32_t *) data)[0];
5367
+ case GGML_TYPE_F16:
5368
+ return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
5369
+ case GGML_TYPE_F32:
5370
+ return ((float *) data)[0];
5371
+ default:
5372
+ GGML_ASSERT(false);
5373
+ }
5374
+
5375
+ return 0.0f;
5376
+ }
5377
+
5378
+ void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) {
5379
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
5380
+ switch (tensor->type) {
5381
+ case GGML_TYPE_I8:
5382
+ {
5383
+ ((int8_t *)(data))[0] = value;
5384
+ } break;
5385
+ case GGML_TYPE_I16:
5386
+ {
5387
+ ((int16_t *)(data))[0] = value;
5388
+ } break;
5389
+ case GGML_TYPE_I32:
5390
+ {
5391
+ ((int32_t *)(data))[0] = value;
5392
+ } break;
5393
+ case GGML_TYPE_F16:
5394
+ {
5395
+ ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
5396
+ } break;
5397
+ case GGML_TYPE_F32:
5398
+ {
5399
+ ((float *)(data))[0] = value;
5400
+ } break;
5401
+ default:
5402
+ {
5403
+ GGML_ASSERT(false);
5404
+ } break;
5405
+ }
5406
+ }
5407
+
5208
5408
  void * ggml_get_data(const struct ggml_tensor * tensor) {
5209
5409
  return tensor->data;
5210
5410
  }
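Note: the new *_nd getters and setters index through the per-dimension byte strides nb[] instead of assuming a contiguous layout, which is what makes them safe for views and permuted tensors. The address computation is identical for get and set; a sketch of the offset math (assumes an f32 tensor):

    // byte offset of element (i0,i1,i2,i3), valid for any stride layout
    char * p = (char *) tensor->data
             + i0*tensor->nb[0] + i1*tensor->nb[1]
             + i2*tensor->nb[2] + i3*tensor->nb[3];
    float v = *(float *) p;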
@@ -5347,6 +5547,44 @@ struct ggml_tensor * ggml_add_inplace(
5347
5547
  return ggml_add_impl(ctx, a, b, true);
5348
5548
  }
5349
5549
 
5550
+ // ggml_add_cast
5551
+
5552
+ static struct ggml_tensor * ggml_add_cast_impl(
5553
+ struct ggml_context * ctx,
5554
+ struct ggml_tensor * a,
5555
+ struct ggml_tensor * b,
5556
+ enum ggml_type type) {
5557
+ // TODO: support less-strict constraint
5558
+ // GGML_ASSERT(ggml_can_repeat(b, a));
5559
+ GGML_ASSERT(ggml_can_repeat_rows(b, a));
5560
+ GGML_ASSERT(ggml_is_quantized(a->type)); // currently only supported for quantized input
5561
+
5562
+ bool is_node = false;
5563
+
5564
+ if (a->grad || b->grad) {
5565
+ // TODO: support backward pass for broadcasting
5566
+ GGML_ASSERT(ggml_are_same_shape(a, b));
5567
+ is_node = true;
5568
+ }
5569
+
5570
+ struct ggml_tensor * result = ggml_new_tensor(ctx, type, a->n_dims, a->ne);
5571
+
5572
+ result->op = GGML_OP_ADD;
5573
+ result->grad = is_node ? ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne) : NULL;
5574
+ result->src[0] = a;
5575
+ result->src[1] = b;
5576
+
5577
+ return result;
5578
+ }
5579
+
5580
+ struct ggml_tensor * ggml_add_cast(
5581
+ struct ggml_context * ctx,
5582
+ struct ggml_tensor * a,
5583
+ struct ggml_tensor * b,
5584
+ enum ggml_type type) {
5585
+ return ggml_add_cast_impl(ctx, a, b, type);
5586
+ }
5587
+
5350
5588
  // ggml_add1
5351
5589
 
5352
5590
  static struct ggml_tensor * ggml_add1_impl(
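Note: ggml_add_cast behaves like ggml_add but lets the caller choose the type of the result tensor; per the asserts above, a must currently be quantized and b must be row-broadcastable to a. A hedged usage sketch (tensor names are illustrative): accumulate an f32 delta onto quantized weights while keeping the sum in f32:

    struct ggml_tensor * acc = ggml_add_cast(ctx, w_q4_0, delta_f32, GGML_TYPE_F32);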
@@ -5783,7 +6021,6 @@ struct ggml_tensor * ggml_repeat(
5783
6021
  result->op = GGML_OP_REPEAT;
5784
6022
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5785
6023
  result->src[0] = a;
5786
- result->src[1] = b;
5787
6024
 
5788
6025
  return result;
5789
6026
  }
@@ -5811,7 +6048,6 @@ struct ggml_tensor * ggml_repeat_back(
5811
6048
  result->op = GGML_OP_REPEAT_BACK;
5812
6049
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5813
6050
  result->src[0] = a;
5814
- result->src[1] = b;
5815
6051
 
5816
6052
  return result;
5817
6053
  }
@@ -6186,8 +6422,9 @@ struct ggml_tensor * ggml_out_prod(
         is_node = true;
     }
 
-    const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+    // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]
+    const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
 
     result->op   = GGML_OP_OUT_PROD;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6406,6 +6643,54 @@ struct ggml_tensor * ggml_cont_inplace(
6406
6643
  return ggml_cont_impl(ctx, a, true);
6407
6644
  }
6408
6645
 
6646
+
6647
+ // make contiguous, with new shape
6648
+ GGML_API struct ggml_tensor * ggml_cont_1d(
6649
+ struct ggml_context * ctx,
6650
+ struct ggml_tensor * a,
6651
+ int64_t ne0) {
6652
+ return ggml_cont_4d(ctx, a, ne0, 1, 1, 1);
6653
+ }
6654
+
6655
+ GGML_API struct ggml_tensor * ggml_cont_2d(
6656
+ struct ggml_context * ctx,
6657
+ struct ggml_tensor * a,
6658
+ int64_t ne0,
6659
+ int64_t ne1) {
6660
+ return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1);
6661
+ }
6662
+
6663
+ GGML_API struct ggml_tensor * ggml_cont_3d(
6664
+ struct ggml_context * ctx,
6665
+ struct ggml_tensor * a,
6666
+ int64_t ne0,
6667
+ int64_t ne1,
6668
+ int64_t ne2) {
6669
+ return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1);
6670
+ }
6671
+
6672
+ struct ggml_tensor * ggml_cont_4d(
6673
+ struct ggml_context * ctx,
6674
+ struct ggml_tensor * a,
6675
+ int64_t ne0,
6676
+ int64_t ne1,
6677
+ int64_t ne2,
6678
+ int64_t ne3) {
6679
+ GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3));
6680
+
6681
+ bool is_node = false;
6682
+
6683
+ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
6684
+ ggml_format_name(result, "%s (cont)", a->name);
6685
+
6686
+ result->op = GGML_OP_CONT;
6687
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6688
+ result->src[0] = a;
6689
+
6690
+ return result;
6691
+ }
6692
+
6693
+
6409
6694
  // ggml_reshape
6410
6695
 
6411
6696
  struct ggml_tensor * ggml_reshape(
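Note: ggml_cont_1d/2d/3d/4d combine ggml_cont with a reshape, copying a possibly non-contiguous tensor into a freshly allocated contiguous tensor of the requested shape; the only requirement is that the element counts match. A hedged one-line example (t stands for any 2-D tensor):

    struct ggml_tensor * t2 = ggml_cont_2d(ctx, ggml_transpose(ctx, t), t->ne[1], t->ne[0]);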
@@ -6413,7 +6698,7 @@ struct ggml_tensor * ggml_reshape(
         struct ggml_tensor * a,
         struct ggml_tensor * b) {
     GGML_ASSERT(ggml_is_contiguous(a));
-    GGML_ASSERT(ggml_is_contiguous(b));
+    // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous.
     GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
 
     bool is_node = false;
@@ -6786,7 +7071,6 @@ struct ggml_tensor * ggml_get_rows_back(
6786
7071
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6787
7072
  result->src[0] = a;
6788
7073
  result->src[1] = b;
6789
- result->src[2] = c;
6790
7074
 
6791
7075
  return result;
6792
7076
  }
@@ -6968,7 +7252,7 @@ struct ggml_tensor * ggml_soft_max_back_inplace(
6968
7252
  static struct ggml_tensor * ggml_rope_impl(
6969
7253
  struct ggml_context * ctx,
6970
7254
  struct ggml_tensor * a,
6971
- int n_past,
7255
+ struct ggml_tensor * b,
6972
7256
  int n_dims,
6973
7257
  int mode,
6974
7258
  int n_ctx,
@@ -6977,7 +7261,10 @@ static struct ggml_tensor * ggml_rope_impl(
         float xpos_base,
         bool  xpos_down,
         bool  inplace) {
-    GGML_ASSERT(n_past >= 0);
+    GGML_ASSERT(ggml_is_vector(b));
+    GGML_ASSERT(b->type == GGML_TYPE_I32);
+    GGML_ASSERT(a->ne[2] == b->ne[0]);
+
     bool is_node = false;
 
     if (a->grad) {
@@ -6986,7 +7273,7 @@ static struct ggml_tensor * ggml_rope_impl(
6986
7273
 
6987
7274
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
6988
7275
 
6989
- int32_t params[8] = { n_past, n_dims, mode, n_ctx };
7276
+ int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
6990
7277
  memcpy(params + 4, &freq_base, sizeof(float));
6991
7278
  memcpy(params + 5, &freq_scale, sizeof(float));
6992
7279
  memcpy(params + 6, &xpos_base, sizeof(float));
@@ -6996,6 +7283,7 @@ static struct ggml_tensor * ggml_rope_impl(
     result->op   = GGML_OP_ROPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
+    result->src[1] = b;
 
     return result;
 }
@@ -7003,55 +7291,55 @@ static struct ggml_tensor * ggml_rope_impl(
7003
7291
  struct ggml_tensor * ggml_rope(
7004
7292
  struct ggml_context * ctx,
7005
7293
  struct ggml_tensor * a,
7006
- int n_past,
7294
+ struct ggml_tensor * b,
7007
7295
  int n_dims,
7008
7296
  int mode,
7009
7297
  int n_ctx) {
7010
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
7298
+ return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
7011
7299
  }
7012
7300
 
7013
7301
  struct ggml_tensor * ggml_rope_inplace(
7014
7302
  struct ggml_context * ctx,
7015
7303
  struct ggml_tensor * a,
7016
- int n_past,
7304
+ struct ggml_tensor * b,
7017
7305
  int n_dims,
7018
7306
  int mode,
7019
7307
  int n_ctx) {
7020
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
7308
+ return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
7021
7309
  }
7022
7310
 
7023
7311
  struct ggml_tensor * ggml_rope_custom(
7024
7312
  struct ggml_context * ctx,
7025
7313
  struct ggml_tensor * a,
7026
- int n_past,
7314
+ struct ggml_tensor * b,
7027
7315
  int n_dims,
7028
7316
  int mode,
7029
7317
  int n_ctx,
7030
7318
  float freq_base,
7031
7319
  float freq_scale) {
7032
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
7320
+ return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
7033
7321
  }
7034
7322
 
7035
7323
  struct ggml_tensor * ggml_rope_custom_inplace(
7036
7324
  struct ggml_context * ctx,
7037
7325
  struct ggml_tensor * a,
7038
- int n_past,
7326
+ struct ggml_tensor * b,
7039
7327
  int n_dims,
7040
7328
  int mode,
7041
7329
  int n_ctx,
7042
7330
  float freq_base,
7043
7331
  float freq_scale) {
7044
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
7332
+ return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
7045
7333
  }
7046
7334
 
7047
7335
  struct ggml_tensor * ggml_rope_xpos_inplace(
7048
7336
  struct ggml_context * ctx,
7049
7337
  struct ggml_tensor * a,
7050
- int n_past,
7338
+ struct ggml_tensor * b,
7051
7339
  int n_dims,
7052
7340
  float base,
7053
7341
  bool down) {
7054
- return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
7342
+ return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
7055
7343
  }
7056
7344
 
7057
7345
  // ggml_rope_back
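Note: every ggml_rope* entry point now takes a 1-D GGML_TYPE_I32 tensor of positions (one entry per slice along a->ne[2]) instead of the scalar n_past, which is what allows different positions per token in a batch. A hedged usage sketch of the new call shape (variable names are illustrative):

    // positions n_past, n_past+1, ..., n_past+n_tokens-1, one per token
    struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
    for (int i = 0; i < n_tokens; ++i) {
        ggml_set_i32_1d(pos, i, n_past + i);
    }
    cur = ggml_rope(ctx, cur, pos, n_dims, mode, n_ctx);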
@@ -7059,7 +7347,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
7059
7347
  struct ggml_tensor * ggml_rope_back(
7060
7348
  struct ggml_context * ctx,
7061
7349
  struct ggml_tensor * a,
7062
- int n_past,
7350
+ struct ggml_tensor * b,
7063
7351
  int n_dims,
7064
7352
  int mode,
7065
7353
  int n_ctx,
@@ -7067,7 +7355,10 @@ struct ggml_tensor * ggml_rope_back(
7067
7355
  float freq_scale,
7068
7356
  float xpos_base,
7069
7357
  bool xpos_down) {
7070
- GGML_ASSERT(n_past >= 0);
7358
+ GGML_ASSERT(ggml_is_vector(b));
7359
+ GGML_ASSERT(b->type == GGML_TYPE_I32);
7360
+ GGML_ASSERT(a->ne[2] == b->ne[0]);
7361
+
7071
7362
  GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
7072
7363
 
7073
7364
  bool is_node = false;
@@ -7078,7 +7369,7 @@ struct ggml_tensor * ggml_rope_back(
7078
7369
 
7079
7370
  struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
7080
7371
 
7081
- int32_t params[8] = { n_past, n_dims, mode, n_ctx };
7372
+ int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
7082
7373
  memcpy(params + 4, &freq_base, sizeof(float));
7083
7374
  memcpy(params + 5, &freq_scale, sizeof(float));
7084
7375
  memcpy(params + 6, &xpos_base, sizeof(float));
@@ -7088,6 +7379,7 @@ struct ggml_tensor * ggml_rope_back(
7088
7379
  result->op = GGML_OP_ROPE_BACK;
7089
7380
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7090
7381
  result->src[0] = a;
7382
+ result->src[1] = b;
7091
7383
 
7092
7384
  return result;
7093
7385
  }
@@ -7484,27 +7776,30 @@ struct ggml_tensor * ggml_flash_attn_back(
7484
7776
 
7485
7777
  // d shape [D,N,ne2,ne3]
7486
7778
  // q shape [D,N,ne2,ne3]
7487
- // k shape [D,M,ne2,ne3]
7488
- // v shape [M,D,ne2,ne3]
7779
+ // k shape [D,M,kvne2,ne3]
7780
+ // v shape [M,D,kvne2,ne3]
7489
7781
 
7490
- const int64_t D = q->ne[0];
7491
- const int64_t N = q->ne[1];
7492
- const int64_t M = k->ne[1];
7493
- const int64_t ne2 = q->ne[2];
7494
- const int64_t ne3 = q->ne[3];
7782
+ const int64_t D = q->ne[0];
7783
+ const int64_t N = q->ne[1];
7784
+ const int64_t M = k->ne[1];
7785
+ const int64_t ne2 = q->ne[2];
7786
+ const int64_t ne3 = q->ne[3];
7787
+ const int64_t kvne2 = k->ne[2];
7495
7788
 
7496
7789
  GGML_ASSERT(k->ne[0] == D);
7497
7790
  GGML_ASSERT(v->ne[0] == M);
7498
7791
  GGML_ASSERT(v->ne[1] == D);
7499
7792
  GGML_ASSERT(d->ne[0] == D);
7500
7793
  GGML_ASSERT(d->ne[1] == N);
7501
- GGML_ASSERT(k->ne[2] == ne2);
7794
+ GGML_ASSERT(k->ne[2] == kvne2);
7502
7795
  GGML_ASSERT(k->ne[3] == ne3);
7503
- GGML_ASSERT(v->ne[2] == ne2);
7796
+ GGML_ASSERT(v->ne[2] == kvne2);
7504
7797
  GGML_ASSERT(v->ne[3] == ne3);
7505
7798
  GGML_ASSERT(d->ne[2] == ne2);
7506
7799
  GGML_ASSERT(d->ne[3] == ne3);
7507
7800
 
7801
+ GGML_ASSERT(ne2 % kvne2 == 0);
7802
+
7508
7803
  bool is_node = false;
7509
7804
 
7510
7805
  if (q->grad || k->grad || v->grad) {
@@ -7514,14 +7809,23 @@ struct ggml_tensor * ggml_flash_attn_back(
     }
 
     // store gradients of q, k and v as continuous tensors concatenated in result.
-    // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
-    // gradq->data = result->data
-    // gradk->data = result->data + nb0*D*N*ne2*ne3
-    // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
     // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
-    int64_t ne[4] = {D,M+N+M,ne2,ne3};
+    const int64_t elem_q = ggml_nelements(q);
+    const int64_t elem_k = ggml_nelements(k);
+    const int64_t elem_v = ggml_nelements(v);
 
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+    enum ggml_type result_type = GGML_TYPE_F32;
+    GGML_ASSERT(ggml_blck_size(result_type) == 1);
+    const size_t tsize = ggml_type_size(result_type);
+
+    const size_t offs_q = 0;
+    const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
+    const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
+    const size_t end    = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN);
+
+    const size_t nelements = (end + tsize - 1)/tsize;
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements);
 
     int32_t masked_i = masked ? 1 : 0;
     ggml_set_op_params(result, &masked_i, sizeof(masked_i));
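Note: the gradient buffer for ggml_flash_attn_back is now sized from the padded element counts of q, k and v rather than the fixed {D, M+N+M, ne2, ne3} shape, so broadcast k/v heads (kvne2 != ne2) fit. The layout follows the usual pack-with-alignment pattern; a sketch of the offset math (GGML_PAD rounds its first argument up to a multiple of the second):

    // one flat f32 buffer: [ grad_q | pad | grad_k | pad | grad_v | pad ]
    size_t offs_q = 0;
    size_t offs_k = offs_q + GGML_PAD(elem_q*sizeof(float), GGML_MEM_ALIGN);
    size_t offs_v = offs_k + GGML_PAD(elem_k*sizeof(float), GGML_MEM_ALIGN);
    size_t total  = offs_v + GGML_PAD(elem_v*sizeof(float), GGML_MEM_ALIGN);
    size_t n_f32  = (total + sizeof(float) - 1)/sizeof(float);  // length of the 1-D result tensor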
@@ -8214,7 +8518,7 @@ static void ggml_compute_forward_dup_f16(
8214
8518
  return;
8215
8519
  }
8216
8520
 
8217
- GGML_TENSOR_UNARY_OP_LOCALS;
8521
+ GGML_TENSOR_UNARY_OP_LOCALS
8218
8522
 
8219
8523
  const int ith = params->ith; // thread index
8220
8524
  const int nth = params->nth; // number of threads
@@ -8485,7 +8789,7 @@ static void ggml_compute_forward_dup_f32(
8485
8789
  return;
8486
8790
  }
8487
8791
 
8488
- GGML_TENSOR_UNARY_OP_LOCALS;
8792
+ GGML_TENSOR_UNARY_OP_LOCALS
8489
8793
 
8490
8794
  const int ith = params->ith; // thread index
8491
8795
  const int nth = params->nth; // number of threads
@@ -8766,7 +9070,7 @@ static void ggml_compute_forward_add_f32(
8766
9070
 
8767
9071
  const int nr = ggml_nrows(src0);
8768
9072
 
8769
- GGML_TENSOR_BINARY_OP_LOCALS;
9073
+ GGML_TENSOR_BINARY_OP_LOCALS
8770
9074
 
8771
9075
  GGML_ASSERT( nb0 == sizeof(float));
8772
9076
  GGML_ASSERT(nb00 == sizeof(float));
@@ -8798,8 +9102,6 @@ static void ggml_compute_forward_add_f32(
8798
9102
  #else
8799
9103
  ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
8800
9104
  #endif
8801
- // }
8802
- // }
8803
9105
  }
8804
9106
  } else {
8805
9107
  // src1 is not contiguous
@@ -8841,7 +9143,7 @@ static void ggml_compute_forward_add_f16_f32(
8841
9143
 
8842
9144
  const int nr = ggml_nrows(src0);
8843
9145
 
8844
- GGML_TENSOR_BINARY_OP_LOCALS;
9146
+ GGML_TENSOR_BINARY_OP_LOCALS
8845
9147
 
8846
9148
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
8847
9149
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -8895,7 +9197,7 @@ static void ggml_compute_forward_add_f16_f16(
8895
9197
 
8896
9198
  const int nr = ggml_nrows(src0);
8897
9199
 
8898
- GGML_TENSOR_BINARY_OP_LOCALS;
9200
+ GGML_TENSOR_BINARY_OP_LOCALS
8899
9201
 
8900
9202
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
8901
9203
  GGML_ASSERT(src1->type == GGML_TYPE_F16);
@@ -8946,14 +9248,15 @@ static void ggml_compute_forward_add_q_f32(
8946
9248
 
8947
9249
  const int nr = ggml_nrows(src0);
8948
9250
 
8949
- GGML_TENSOR_BINARY_OP_LOCALS;
9251
+ GGML_TENSOR_BINARY_OP_LOCALS
8950
9252
 
8951
9253
  const int ith = params->ith;
8952
9254
  const int nth = params->nth;
8953
9255
 
8954
9256
  const enum ggml_type type = src0->type;
9257
+ const enum ggml_type dtype = dst->type;
8955
9258
  ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
8956
- ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
9259
+ ggml_from_float_t const quantize_row_q = type_traits[dtype].from_float;
8957
9260
 
8958
9261
  // we don't support permuted src0 or src1
8959
9262
  GGML_ASSERT(nb00 == ggml_type_size(type));
@@ -8965,7 +9268,6 @@ static void ggml_compute_forward_add_q_f32(
8965
9268
  GGML_ASSERT(nb2 <= nb3);
8966
9269
 
8967
9270
  GGML_ASSERT(ggml_is_quantized(src0->type));
8968
- GGML_ASSERT(dst->type == src0->type);
8969
9271
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
8970
9272
 
8971
9273
  // rows per thread
@@ -9003,7 +9305,11 @@ static void ggml_compute_forward_add_q_f32(
9003
9305
  // add src1
9004
9306
  ggml_vec_acc_f32(ne00, wdata, src1_row);
9005
9307
  // quantize row to dst
9006
- quantize_row_q(wdata, dst_row, ne00);
9308
+ if (quantize_row_q != NULL) {
9309
+ quantize_row_q(wdata, dst_row, ne00);
9310
+ } else {
9311
+ memcpy(dst_row, wdata, ne0*nb0);
9312
+ }
9007
9313
  }
9008
9314
  }
9009
9315
 
@@ -9068,7 +9374,7 @@ static void ggml_compute_forward_add1_f32(
9068
9374
 
9069
9375
  const int nr = ggml_nrows(src0);
9070
9376
 
9071
- GGML_TENSOR_UNARY_OP_LOCALS;
9377
+ GGML_TENSOR_UNARY_OP_LOCALS
9072
9378
 
9073
9379
  GGML_ASSERT( nb0 == sizeof(float));
9074
9380
  GGML_ASSERT(nb00 == sizeof(float));
@@ -9123,7 +9429,7 @@ static void ggml_compute_forward_add1_f16_f32(
9123
9429
 
9124
9430
  const int nr = ggml_nrows(src0);
9125
9431
 
9126
- GGML_TENSOR_UNARY_OP_LOCALS;
9432
+ GGML_TENSOR_UNARY_OP_LOCALS
9127
9433
 
9128
9434
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
9129
9435
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -9173,7 +9479,7 @@ static void ggml_compute_forward_add1_f16_f16(
9173
9479
 
9174
9480
  const int nr = ggml_nrows(src0);
9175
9481
 
9176
- GGML_TENSOR_UNARY_OP_LOCALS;
9482
+ GGML_TENSOR_UNARY_OP_LOCALS
9177
9483
 
9178
9484
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
9179
9485
  GGML_ASSERT(src1->type == GGML_TYPE_F16);
@@ -9223,7 +9529,7 @@ static void ggml_compute_forward_add1_q_f32(
9223
9529
 
9224
9530
  const int nr = ggml_nrows(src0);
9225
9531
 
9226
- GGML_TENSOR_UNARY_OP_LOCALS;
9532
+ GGML_TENSOR_UNARY_OP_LOCALS
9227
9533
 
9228
9534
  const enum ggml_type type = src0->type;
9229
9535
  ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
@@ -9351,8 +9657,8 @@ static void ggml_compute_forward_acc_f32(
9351
9657
  const int nr = ggml_nrows(src1);
9352
9658
  const int nc = src1->ne[0];
9353
9659
 
9354
- GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
9355
- GGML_TENSOR_LOCALS(size_t, nb1, src1, nb);
9660
+ GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
9661
+ GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
9356
9662
 
9357
9663
  // src0 and dst as viewed during acc
9358
9664
  const size_t nb0 = ggml_element_size(src0);
@@ -9441,7 +9747,7 @@ static void ggml_compute_forward_sub_f32(
9441
9747
 
9442
9748
  const int nr = ggml_nrows(src0);
9443
9749
 
9444
- GGML_TENSOR_BINARY_OP_LOCALS;
9750
+ GGML_TENSOR_BINARY_OP_LOCALS
9445
9751
 
9446
9752
  GGML_ASSERT( nb0 == sizeof(float));
9447
9753
  GGML_ASSERT(nb00 == sizeof(float));
@@ -9531,7 +9837,7 @@ static void ggml_compute_forward_mul_f32(
9531
9837
 
9532
9838
  const int64_t nr = ggml_nrows(src0);
9533
9839
 
9534
- GGML_TENSOR_BINARY_OP_LOCALS;
9840
+ GGML_TENSOR_BINARY_OP_LOCALS
9535
9841
 
9536
9842
  GGML_ASSERT( nb0 == sizeof(float));
9537
9843
  GGML_ASSERT(nb00 == sizeof(float));
@@ -9622,7 +9928,7 @@ static void ggml_compute_forward_div_f32(
9622
9928
 
9623
9929
  const int nr = ggml_nrows(src0);
9624
9930
 
9625
- GGML_TENSOR_BINARY_OP_LOCALS;
9931
+ GGML_TENSOR_BINARY_OP_LOCALS
9626
9932
 
9627
9933
  GGML_ASSERT( nb0 == sizeof(float));
9628
9934
  GGML_ASSERT(nb00 == sizeof(float));
@@ -9831,8 +10137,8 @@ static void ggml_compute_forward_sum_f32(
9831
10137
  assert(ggml_is_scalar(dst));
9832
10138
  assert(src0->nb[0] == sizeof(float));
9833
10139
 
9834
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
9835
- GGML_TENSOR_LOCALS(size_t, nb0, src0, nb);
10140
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
10141
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
9836
10142
 
9837
10143
  ggml_float sum = 0;
9838
10144
  ggml_float row_sum = 0;
@@ -9863,8 +10169,8 @@ static void ggml_compute_forward_sum_f16(
9863
10169
 
9864
10170
  assert(src0->nb[0] == sizeof(ggml_fp16_t));
9865
10171
 
9866
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
9867
- GGML_TENSOR_LOCALS(size_t, nb0, src0, nb);
10172
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
10173
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
9868
10174
 
9869
10175
  float sum = 0;
9870
10176
  float row_sum = 0;
@@ -9917,7 +10223,7 @@ static void ggml_compute_forward_sum_rows_f32(
9917
10223
  GGML_ASSERT(src0->nb[0] == sizeof(float));
9918
10224
  GGML_ASSERT(dst->nb[0] == sizeof(float));
9919
10225
 
9920
- GGML_TENSOR_UNARY_OP_LOCALS;
10226
+ GGML_TENSOR_UNARY_OP_LOCALS
9921
10227
 
9922
10228
  GGML_ASSERT(ne0 == 1);
9923
10229
  GGML_ASSERT(ne1 == ne01);
@@ -9967,7 +10273,7 @@ static void ggml_compute_forward_mean_f32(
9967
10273
 
9968
10274
  assert(src0->nb[0] == sizeof(float));
9969
10275
 
9970
- GGML_TENSOR_UNARY_OP_LOCALS;
10276
+ GGML_TENSOR_UNARY_OP_LOCALS
9971
10277
 
9972
10278
  assert(ne0 == 1);
9973
10279
  assert(ne1 == ne01);
@@ -10067,7 +10373,7 @@ static void ggml_compute_forward_repeat_f32(
10067
10373
  return;
10068
10374
  }
10069
10375
 
10070
- GGML_TENSOR_UNARY_OP_LOCALS;
10376
+ GGML_TENSOR_UNARY_OP_LOCALS
10071
10377
 
10072
10378
  // guaranteed to be an integer due to the check in ggml_can_repeat
10073
10379
  const int nr0 = (int)(ne0/ne00);
@@ -10099,11 +10405,61 @@ static void ggml_compute_forward_repeat_f32(
10099
10405
  }
10100
10406
  }
10101
10407
 
10408
+ static void ggml_compute_forward_repeat_f16(
10409
+ const struct ggml_compute_params * params,
10410
+ const struct ggml_tensor * src0,
10411
+ struct ggml_tensor * dst) {
10412
+ GGML_ASSERT(params->ith == 0);
10413
+ GGML_ASSERT(ggml_can_repeat(src0, dst));
10414
+
10415
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
10416
+ return;
10417
+ }
10418
+
10419
+ GGML_TENSOR_UNARY_OP_LOCALS;
10420
+
10421
+ // guaranteed to be an integer due to the check in ggml_can_repeat
10422
+ const int nr0 = (int)(ne0/ne00);
10423
+ const int nr1 = (int)(ne1/ne01);
10424
+ const int nr2 = (int)(ne2/ne02);
10425
+ const int nr3 = (int)(ne3/ne03);
10426
+
10427
+ // TODO: support for transposed / permuted tensors
10428
+ GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
10429
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
10430
+
10431
+ // TODO: maybe this is not optimal?
10432
+ for (int i3 = 0; i3 < nr3; i3++) {
10433
+ for (int k3 = 0; k3 < ne03; k3++) {
10434
+ for (int i2 = 0; i2 < nr2; i2++) {
10435
+ for (int k2 = 0; k2 < ne02; k2++) {
10436
+ for (int i1 = 0; i1 < nr1; i1++) {
10437
+ for (int k1 = 0; k1 < ne01; k1++) {
10438
+ for (int i0 = 0; i0 < nr0; i0++) {
10439
+ ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
10440
+ ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
10441
+ // ggml_vec_cpy_f16(ne00, y, x)
10442
+ for (int i = 0; i < ne00; ++i) {
10443
+ y[i] = x[i];
10444
+ }
10445
+ }
10446
+ }
10447
+ }
10448
+ }
10449
+ }
10450
+ }
10451
+ }
10452
+ }
10453
+
10102
10454
  static void ggml_compute_forward_repeat(
10103
10455
  const struct ggml_compute_params * params,
10104
10456
  const struct ggml_tensor * src0,
10105
10457
  struct ggml_tensor * dst) {
10106
10458
  switch (src0->type) {
10459
+ case GGML_TYPE_F16:
10460
+ {
10461
+ ggml_compute_forward_repeat_f16(params, src0, dst);
10462
+ } break;
10107
10463
  case GGML_TYPE_F32:
10108
10464
  {
10109
10465
  ggml_compute_forward_repeat_f32(params, src0, dst);
@@ -10128,7 +10484,7 @@ static void ggml_compute_forward_repeat_back_f32(
10128
10484
  return;
10129
10485
  }
10130
10486
 
10131
- GGML_TENSOR_UNARY_OP_LOCALS;
10487
+ GGML_TENSOR_UNARY_OP_LOCALS
10132
10488
 
10133
10489
  // guaranteed to be an integer due to the check in ggml_can_repeat
10134
10490
  const int nr0 = (int)(ne00/ne0);
@@ -10206,7 +10562,7 @@ static void ggml_compute_forward_concat_f32(
10206
10562
 
10207
10563
  const int ith = params->ith;
10208
10564
 
10209
- GGML_TENSOR_BINARY_OP_LOCALS;
10565
+ GGML_TENSOR_BINARY_OP_LOCALS
10210
10566
 
10211
10567
  // TODO: support for transposed / permuted tensors
10212
10568
  GGML_ASSERT(nb0 == sizeof(float));
@@ -10808,7 +11164,7 @@ static void ggml_compute_forward_norm_f32(
10808
11164
  const int ith = params->ith;
10809
11165
  const int nth = params->nth;
10810
11166
 
10811
- GGML_TENSOR_UNARY_OP_LOCALS;
11167
+ GGML_TENSOR_UNARY_OP_LOCALS
10812
11168
 
10813
11169
  float eps;
10814
11170
  memcpy(&eps, dst->op_params, sizeof(float));
@@ -10877,7 +11233,7 @@ static void ggml_compute_forward_rms_norm_f32(
10877
11233
  const int ith = params->ith;
10878
11234
  const int nth = params->nth;
10879
11235
 
10880
- GGML_TENSOR_UNARY_OP_LOCALS;
11236
+ GGML_TENSOR_UNARY_OP_LOCALS
10881
11237
 
10882
11238
  float eps;
10883
11239
  memcpy(&eps, dst->op_params, sizeof(float));
@@ -10942,7 +11298,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
10942
11298
  const int ith = params->ith;
10943
11299
  const int nth = params->nth;
10944
11300
 
10945
- GGML_TENSOR_BINARY_OP_LOCALS;
11301
+ GGML_TENSOR_BINARY_OP_LOCALS
10946
11302
 
10947
11303
  float eps;
10948
11304
  memcpy(&eps, dst->op_params, sizeof(float));
@@ -11117,7 +11473,7 @@ static void ggml_compute_forward_group_norm_f32(
11117
11473
  const int ith = params->ith;
11118
11474
  const int nth = params->nth;
11119
11475
 
11120
- GGML_TENSOR_UNARY_OP_LOCALS;
11476
+ GGML_TENSOR_UNARY_OP_LOCALS
11121
11477
 
11122
11478
  const float eps = 1e-6f; // TODO: make this a parameter
11123
11479
 
@@ -11228,7 +11584,7 @@ static void ggml_compute_forward_mul_mat(
11228
11584
  int64_t t0 = ggml_perf_time_us();
11229
11585
  UNUSED(t0);
11230
11586
 
11231
- GGML_TENSOR_BINARY_OP_LOCALS;
11587
+ GGML_TENSOR_BINARY_OP_LOCALS
11232
11588
 
11233
11589
  const int ith = params->ith;
11234
11590
  const int nth = params->nth;
@@ -11443,10 +11799,10 @@ static void ggml_compute_forward_out_prod_f32(
11443
11799
  const struct ggml_tensor * src0,
11444
11800
  const struct ggml_tensor * src1,
11445
11801
  struct ggml_tensor * dst) {
11446
- int64_t t0 = ggml_perf_time_us();
11447
- UNUSED(t0);
11802
+ // int64_t t0 = ggml_perf_time_us();
11803
+ // UNUSED(t0);
11448
11804
 
11449
- GGML_TENSOR_BINARY_OP_LOCALS;
11805
+ GGML_TENSOR_BINARY_OP_LOCALS
11450
11806
 
11451
11807
  const int ith = params->ith;
11452
11808
  const int nth = params->nth;
@@ -11485,6 +11841,146 @@ static void ggml_compute_forward_out_prod_f32(
11485
11841
  return;
11486
11842
  }
11487
11843
 
11844
+ // dst[:,:,:,:] = 0
11845
+ // for i2,i3:
11846
+ // for i1:
11847
+ // for i01:
11848
+ // for i0:
11849
+ // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
11850
+
11851
+ // parallelize by last three dimensions
11852
+
11853
+ // total rows in dst
11854
+ const int64_t nr = ne1*ne2*ne3;
11855
+
11856
+ // rows per thread
11857
+ const int64_t dr = (nr + nth - 1)/nth;
11858
+
11859
+ // row range for this thread
11860
+ const int64_t ir0 = dr*ith;
11861
+ const int64_t ir1 = MIN(ir0 + dr, nr);
11862
+
11863
+ // block-tiling attempt
11864
+ const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
11865
+ const int64_t blck_1 = 16;
11866
+
11867
+ for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
11868
+ const int64_t bir1 = MIN(bir + blck_1, ir1);
11869
+ for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
11870
+ const int64_t bne01 = MIN(bi01 + blck_0, ne01);
11871
+ for (int64_t ir = bir; ir < bir1; ++ir) {
11872
+ // dst indices
11873
+ const int64_t i3 = ir/(ne2*ne1);
11874
+ const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
11875
+ const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
11876
+
11877
+ const int64_t i02 = i2;
11878
+ const int64_t i03 = i3;
11879
+
11880
+ //const int64_t i10 = i1;
11881
+ const int64_t i12 = i2;
11882
+ const int64_t i13 = i3;
11883
+
11884
+ #if GGML_VEC_MAD_UNROLL > 2
11885
+ const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
11886
+ for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
11887
+ const int64_t i11 = i01;
11888
+
11889
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
11890
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
11891
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
11892
+
11893
+ ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
11894
+ }
11895
+ for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
11896
+ const int64_t i11 = i01;
11897
+
11898
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
11899
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
11900
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
11901
+
11902
+ ggml_vec_mad_f32(ne0, d, s0, *s1);
11903
+ }
11904
+ #else
11905
+ for (int64_t i01 = bi01; i01 < bne01; ++i01) {
11906
+ const int64_t i11 = i01;
11907
+
11908
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
11909
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
11910
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
11911
+
11912
+ ggml_vec_mad_f32(ne0, d, s0, *s1);
11913
+ }
11914
+ #endif
11915
+ }
11916
+ }
11917
+ }
11918
+
11919
+
11920
+ //int64_t t1 = ggml_perf_time_us();
11921
+ //static int64_t acc = 0;
11922
+ //acc += t1 - t0;
11923
+ //if (t1 - t0 > 10) {
11924
+ // printf("\n");
11925
+ // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
11926
+ // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
11927
+ // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
11928
+ // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
11929
+
11930
+ // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
11931
+ //}
11932
+ }
11933
+
11934
+ static void ggml_compute_forward_out_prod_q_f32(
11935
+ const struct ggml_compute_params * params,
11936
+ const struct ggml_tensor * src0,
11937
+ const struct ggml_tensor * src1,
11938
+ struct ggml_tensor * dst) {
11939
+ // int64_t t0 = ggml_perf_time_us();
11940
+ // UNUSED(t0);
11941
+
11942
+ GGML_TENSOR_BINARY_OP_LOCALS;
11943
+
11944
+ const int ith = params->ith;
11945
+ const int nth = params->nth;
11946
+
11947
+ const enum ggml_type type = src0->type;
11948
+ ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
11949
+
11950
+ GGML_ASSERT(ne02 == ne12);
11951
+ GGML_ASSERT(ne03 == ne13);
11952
+ GGML_ASSERT(ne2 == ne12);
11953
+ GGML_ASSERT(ne3 == ne13);
11954
+
11955
+ // we don't support permuted src0 dim0
11956
+ GGML_ASSERT(nb00 == ggml_type_size(type));
11957
+
11958
+ // dst dim0 cannot be transposed or permuted
11959
+ GGML_ASSERT(nb0 == sizeof(float));
11960
+ // GGML_ASSERT(nb0 <= nb1);
11961
+ // GGML_ASSERT(nb1 <= nb2);
11962
+ // GGML_ASSERT(nb2 <= nb3);
11963
+
11964
+ GGML_ASSERT(ne0 == ne00);
11965
+ GGML_ASSERT(ne1 == ne10);
11966
+ GGML_ASSERT(ne2 == ne02);
11967
+ GGML_ASSERT(ne3 == ne03);
11968
+
11969
+ // nb01 >= nb00 - src0 is not transposed
11970
+ // compute by src0 rows
11971
+
11972
+ // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
11973
+ // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
11974
+
11975
+ if (params->type == GGML_TASK_INIT) {
11976
+ ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
11977
+ return;
11978
+ }
11979
+
11980
+ if (params->type == GGML_TASK_FINALIZE) {
11981
+ return;
11982
+ }
11983
+
11488
11984
  // parallelize by last three dimensions
11489
11985
 
11490
11986
  // total rows in dst
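Note: the reworked out_prod kernels keep the same accumulation, dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3], but tile the i01 loop so that 32 rows at a time can be fed to ggml_vec_mad_f32_unroll; the new quantized variant first dequantizes each src0 row into a per-thread scratch row. A sketch of that scratch-row step, as it appears later in this diff:

    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;  // per-thread f32 row
    dequantize_row_q(s0, wdata, ne0);       // quantized src0 row -> f32
    ggml_vec_mad_f32(ne0, d, wdata, *s1);   // d += wdata * (*s1)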
@@ -11504,6 +12000,8 @@ static void ggml_compute_forward_out_prod_f32(
11504
12000
  // for i0:
11505
12001
  // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
11506
12002
 
12003
+ float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
12004
+
11507
12005
  for (int64_t ir = ir0; ir < ir1; ++ir) {
11508
12006
  // dst indices
11509
12007
  const int64_t i3 = ir/(ne2*ne1);
@@ -11524,10 +12022,8 @@ static void ggml_compute_forward_out_prod_f32(
11524
12022
  float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
11525
12023
  float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
11526
12024
 
11527
- ggml_vec_mad_f32(ne0, d, s0, *s1);
11528
- // for (int64_t i0 = 0; i0 < ne0; ++i0) {
11529
- // d[i0] += s0[i0] * s1[i1];
11530
- // }
12025
+ dequantize_row_q(s0, wdata, ne0);
12026
+ ggml_vec_mad_f32(ne0, d, wdata, *s1);
11531
12027
  }
11532
12028
  }
11533
12029
 
@@ -11556,10 +12052,13 @@ static void ggml_compute_forward_out_prod(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
             {
-                GGML_ASSERT(false); // todo
-                // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
+                ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F16:
             {
@@ -11677,8 +12176,8 @@ static void ggml_compute_forward_set_f32(
11677
12176
  const int nr = ggml_nrows(src1);
11678
12177
  const int nc = src1->ne[0];
11679
12178
 
11680
- GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
11681
- GGML_TENSOR_LOCALS(size_t, nb1, src1, nb);
12179
+ GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
12180
+ GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
11682
12181
 
11683
12182
  // src0 and dst as viewed during set
11684
12183
  const size_t nb0 = ggml_element_size(src0);
@@ -11947,14 +12446,15 @@ static void ggml_compute_forward_get_rows_back_f32_f16(
11947
12446
  const struct ggml_compute_params * params,
11948
12447
  const struct ggml_tensor * src0,
11949
12448
  const struct ggml_tensor * src1,
11950
- const struct ggml_tensor * opt0,
11951
12449
  struct ggml_tensor * dst) {
11952
12450
  GGML_ASSERT(params->ith == 0);
11953
- GGML_ASSERT(ggml_are_same_shape(opt0, dst));
11954
- GGML_ASSERT(ggml_is_contiguous(opt0));
11955
12451
  GGML_ASSERT(ggml_is_contiguous(dst));
11956
12452
 
11957
- ggml_compute_forward_dup_same_cont(params, opt0, dst);
12453
+ // ggml_compute_forward_dup_same_cont(params, opt0, dst);
12454
+
12455
+ if (params->type == GGML_TASK_INIT) {
12456
+ memset(dst->data, 0, ggml_nbytes(dst));
12457
+ }
11958
12458
 
11959
12459
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
11960
12460
  return;
@@ -11980,11 +12480,8 @@ static void ggml_compute_forward_get_rows_back_f32(
11980
12480
  const struct ggml_compute_params * params,
11981
12481
  const struct ggml_tensor * src0,
11982
12482
  const struct ggml_tensor * src1,
11983
- const struct ggml_tensor * opt0,
11984
12483
  struct ggml_tensor * dst) {
11985
12484
  GGML_ASSERT(params->ith == 0);
11986
- GGML_ASSERT(ggml_are_same_shape(opt0, dst));
11987
- GGML_ASSERT(ggml_is_contiguous(opt0));
11988
12485
  GGML_ASSERT(ggml_is_contiguous(dst));
11989
12486
 
11990
12487
  // ggml_compute_forward_dup_same_cont(params, opt0, dst);
@@ -12018,16 +12515,15 @@ static void ggml_compute_forward_get_rows_back(
12018
12515
  const struct ggml_compute_params * params,
12019
12516
  const struct ggml_tensor * src0,
12020
12517
  const struct ggml_tensor * src1,
12021
- const struct ggml_tensor * opt0,
12022
12518
  struct ggml_tensor * dst) {
12023
12519
  switch (src0->type) {
12024
12520
  case GGML_TYPE_F16:
12025
12521
  {
12026
- ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, opt0, dst);
12522
+ ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, dst);
12027
12523
  } break;
12028
12524
  case GGML_TYPE_F32:
12029
12525
  {
12030
- ggml_compute_forward_get_rows_back_f32(params, src0, src1, opt0, dst);
12526
+ ggml_compute_forward_get_rows_back_f32(params, src0, src1, dst);
12031
12527
  } break;
12032
12528
  default:
12033
12529
  {
@@ -12068,7 +12564,7 @@ static void ggml_compute_forward_diag_f32(
12068
12564
 
12069
12565
  // TODO: handle transposed/permuted matrices
12070
12566
 
12071
- GGML_TENSOR_UNARY_OP_LOCALS;
12567
+ GGML_TENSOR_UNARY_OP_LOCALS
12072
12568
 
12073
12569
  GGML_ASSERT(ne00 == ne0);
12074
12570
  GGML_ASSERT(ne00 == ne1);
@@ -12456,13 +12952,11 @@ static void ggml_compute_forward_alibi_f16(
12456
12952
  return;
12457
12953
  }
12458
12954
 
12459
- const int n_past = ((int32_t *) dst->op_params)[0];
12955
+ //const int n_past = ((int32_t *) dst->op_params)[0];
12460
12956
  const int n_head = ((int32_t *) dst->op_params)[1];
12461
12957
  float max_bias;
12462
12958
  memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
12463
12959
 
12464
- assert(n_past >= 0);
12465
-
12466
12960
  const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
12467
12961
  const int ne1 = src0->ne[1]; // seq_len_without_past
12468
12962
  const int ne2 = src0->ne[2]; // n_head -> this is k
@@ -12477,7 +12971,7 @@ static void ggml_compute_forward_alibi_f16(
12477
12971
  //const int nb3 = src0->nb[3];
12478
12972
 
12479
12973
  GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
12480
- GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
12974
+ //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
12481
12975
  GGML_ASSERT(n_head == ne2);
12482
12976
 
12483
12977
  // add alibi to src0 (KQ_scaled)
@@ -12623,8 +13117,8 @@ static void ggml_compute_forward_clamp(
12623
13117
  static void ggml_compute_forward_rope_f32(
12624
13118
  const struct ggml_compute_params * params,
12625
13119
  const struct ggml_tensor * src0,
13120
+ const struct ggml_tensor * src1,
12626
13121
  struct ggml_tensor * dst) {
12627
-
12628
13122
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
12629
13123
  return;
12630
13124
  }
@@ -12634,9 +13128,9 @@ static void ggml_compute_forward_rope_f32(
12634
13128
 
12635
13129
  // these two only relevant for xPos RoPE:
12636
13130
  float xpos_base;
12637
- bool xpos_down;
13131
+ bool xpos_down;
12638
13132
 
12639
- const int n_past = ((int32_t *) dst->op_params)[0];
13133
+ //const int n_past = ((int32_t *) dst->op_params)[0];
12640
13134
  const int n_dims = ((int32_t *) dst->op_params)[1];
12641
13135
  const int mode = ((int32_t *) dst->op_params)[2];
12642
13136
  const int n_ctx = ((int32_t *) dst->op_params)[3];
@@ -12645,9 +13139,7 @@ static void ggml_compute_forward_rope_f32(
12645
13139
  memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
12646
13140
  memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
12647
13141
 
12648
- assert(n_past >= 0);
12649
-
12650
- GGML_TENSOR_UNARY_OP_LOCALS;
13142
+ GGML_TENSOR_UNARY_OP_LOCALS
12651
13143
 
12652
13144
  //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
12653
13145
  //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -12677,9 +13169,11 @@ static void ggml_compute_forward_rope_f32(
12677
13169
  const bool is_neox = mode & 2;
12678
13170
  const bool is_glm = mode & 4;
12679
13171
 
13172
+ const int32_t * pos = (const int32_t *) src1->data;
13173
+
12680
13174
  for (int64_t i3 = 0; i3 < ne3; i3++) {
12681
- for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
12682
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
13175
+ for (int64_t i2 = 0; i2 < ne2; i2++) {
13176
+ const int64_t p = pos[i2];
12683
13177
  for (int64_t i1 = 0; i1 < ne1; i1++) {
12684
13178
  if (ir++ < ir0) continue;
12685
13179
  if (ir > ir1) break;
@@ -12716,7 +13210,7 @@ static void ggml_compute_forward_rope_f32(
12716
13210
  const float cos_theta = cosf(theta);
12717
13211
  const float sin_theta = sinf(theta);
12718
13212
  // zeta scaling for xPos only:
12719
- float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
13213
+ float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
12720
13214
  if (xpos_down) zeta = 1.0f / zeta;
12721
13215
 
12722
13216
  theta *= theta_scale;
@@ -12761,8 +13255,8 @@ static void ggml_compute_forward_rope_f32(
12761
13255
  static void ggml_compute_forward_rope_f16(
12762
13256
  const struct ggml_compute_params * params,
12763
13257
  const struct ggml_tensor * src0,
13258
+ const struct ggml_tensor * src1,
12764
13259
  struct ggml_tensor * dst) {
12765
-
12766
13260
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
12767
13261
  return;
12768
13262
  }
@@ -12770,16 +13264,14 @@ static void ggml_compute_forward_rope_f16(
12770
13264
  float freq_base;
12771
13265
  float freq_scale;
12772
13266
 
12773
- const int n_past = ((int32_t *) dst->op_params)[0];
13267
+ //const int n_past = ((int32_t *) dst->op_params)[0];
12774
13268
  const int n_dims = ((int32_t *) dst->op_params)[1];
12775
13269
  const int mode = ((int32_t *) dst->op_params)[2];
12776
13270
  const int n_ctx = ((int32_t *) dst->op_params)[3];
12777
13271
  memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
12778
13272
  memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
12779
13273
 
12780
- assert(n_past >= 0);
12781
-
12782
- GGML_TENSOR_UNARY_OP_LOCALS;
13274
+ GGML_TENSOR_UNARY_OP_LOCALS
12783
13275
 
12784
13276
  //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
12785
13277
  //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -12809,9 +13301,11 @@ static void ggml_compute_forward_rope_f16(
12809
13301
  const bool is_neox = mode & 2;
12810
13302
  const bool is_glm = mode & 4;
12811
13303
 
13304
+ const int32_t * pos = (const int32_t *) src1->data;
13305
+
12812
13306
  for (int64_t i3 = 0; i3 < ne3; i3++) {
12813
- for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
12814
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
13307
+ for (int64_t i2 = 0; i2 < ne2; i2++) {
13308
+ const int64_t p = pos[i2];
12815
13309
  for (int64_t i1 = 0; i1 < ne1; i1++) {
12816
13310
  if (ir++ < ir0) continue;
12817
13311
  if (ir > ir1) break;
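Annotation: both the f32 and f16 RoPE kernels above now read one rotation position per row from src1 (pos[i2]) instead of deriving it from the n_past op parameter. A minimal sketch of how a caller might fill such a position tensor; ctx, n_tokens and n_past are illustrative names, not taken from this diff:

    // build an I32 tensor holding one position per token in the batch;
    // equivalent to the old behaviour where row i used position n_past + i
    struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
    for (int i = 0; i < n_tokens; ++i) {
        ((int32_t *) pos->data)[i] = n_past + i;
    }
    // pos is then wired in as src1 of the rope op, and the kernels index pos[i2]

Because the positions are now explicit data, non-contiguous or reordered positions (e.g. for batched decoding) can be expressed without changing the kernels.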
@@ -12890,15 +13384,16 @@ static void ggml_compute_forward_rope_f16(
12890
13384
  static void ggml_compute_forward_rope(
12891
13385
  const struct ggml_compute_params * params,
12892
13386
  const struct ggml_tensor * src0,
13387
+ const struct ggml_tensor * src1,
12893
13388
  struct ggml_tensor * dst) {
12894
13389
  switch (src0->type) {
12895
13390
  case GGML_TYPE_F16:
12896
13391
  {
12897
- ggml_compute_forward_rope_f16(params, src0, dst);
13392
+ ggml_compute_forward_rope_f16(params, src0, src1, dst);
12898
13393
  } break;
12899
13394
  case GGML_TYPE_F32:
12900
13395
  {
12901
- ggml_compute_forward_rope_f32(params, src0, dst);
13396
+ ggml_compute_forward_rope_f32(params, src0, src1, dst);
12902
13397
  } break;
12903
13398
  default:
12904
13399
  {
@@ -12912,6 +13407,7 @@ static void ggml_compute_forward_rope(
12912
13407
  static void ggml_compute_forward_rope_back_f32(
12913
13408
  const struct ggml_compute_params * params,
12914
13409
  const struct ggml_tensor * src0,
13410
+ const struct ggml_tensor * src1,
12915
13411
  struct ggml_tensor * dst) {
12916
13412
 
12917
13413
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -12929,7 +13425,7 @@ static void ggml_compute_forward_rope_back_f32(
12929
13425
  float xpos_base;
12930
13426
  bool xpos_down;
12931
13427
 
12932
- const int n_past = ((int32_t *) dst->op_params)[0];
13428
+ //const int n_past = ((int32_t *) dst->op_params)[0];
12933
13429
  const int n_dims = ((int32_t *) dst->op_params)[1];
12934
13430
  const int mode = ((int32_t *) dst->op_params)[2];
12935
13431
  const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx);
@@ -12938,9 +13434,7 @@ static void ggml_compute_forward_rope_back_f32(
12938
13434
  memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
12939
13435
  memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
12940
13436
 
12941
- assert(n_past >= 0);
12942
-
12943
- GGML_TENSOR_UNARY_OP_LOCALS;
13437
+ GGML_TENSOR_UNARY_OP_LOCALS
12944
13438
 
12945
13439
  //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
12946
13440
  //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -12966,9 +13460,11 @@ static void ggml_compute_forward_rope_back_f32(
12966
13460
 
12967
13461
  const bool is_neox = mode & 2;
12968
13462
 
13463
+ const int32_t * pos = (const int32_t *) src1->data;
13464
+
12969
13465
  for (int64_t i3 = 0; i3 < ne3; i3++) {
12970
- for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
12971
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
13466
+ for (int64_t i2 = 0; i2 < ne2; i2++) {
13467
+ const int64_t p = pos[i2];
12972
13468
  for (int64_t i1 = 0; i1 < ne1; i1++) {
12973
13469
  if (ir++ < ir0) continue;
12974
13470
  if (ir > ir1) break;
@@ -12980,7 +13476,7 @@ static void ggml_compute_forward_rope_back_f32(
12980
13476
  const float cos_theta = cosf(theta);
12981
13477
  const float sin_theta = sinf(theta);
12982
13478
  // zeta scaling for xPos only:
12983
- float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
13479
+ float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
12984
13480
  if (xpos_down) zeta = 1.0f / zeta;
12985
13481
 
12986
13482
  theta *= theta_scale;
@@ -13023,6 +13519,7 @@ static void ggml_compute_forward_rope_back_f32(
13023
13519
  static void ggml_compute_forward_rope_back_f16(
13024
13520
  const struct ggml_compute_params * params,
13025
13521
  const struct ggml_tensor * src0,
13522
+ const struct ggml_tensor * src1,
13026
13523
  struct ggml_tensor * dst) {
13027
13524
 
13028
13525
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -13033,13 +13530,11 @@ static void ggml_compute_forward_rope_back_f16(
13033
13530
  // dx = rope_back(dy, src1)
13034
13531
  // src0 is dy, src1 contains options
13035
13532
 
13036
- const int n_past = ((int32_t *) dst->op_params)[0];
13533
+ //const int n_past = ((int32_t *) dst->op_params)[0];
13037
13534
  const int n_dims = ((int32_t *) dst->op_params)[1];
13038
13535
  const int mode = ((int32_t *) dst->op_params)[2];
13039
13536
 
13040
- assert(n_past >= 0);
13041
-
13042
- GGML_TENSOR_UNARY_OP_LOCALS;
13537
+ GGML_TENSOR_UNARY_OP_LOCALS
13043
13538
 
13044
13539
  //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
13045
13540
  //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -13065,9 +13560,11 @@ static void ggml_compute_forward_rope_back_f16(
13065
13560
 
13066
13561
  const bool is_neox = mode & 2;
13067
13562
 
13563
+ const int32_t * pos = (const int32_t *) src1->data;
13564
+
13068
13565
  for (int64_t i3 = 0; i3 < ne3; i3++) {
13069
- for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
13070
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
13566
+ for (int64_t i2 = 0; i2 < ne2; i2++) {
13567
+ const int64_t p = pos[i2];
13071
13568
  for (int64_t i1 = 0; i1 < ne1; i1++) {
13072
13569
  if (ir++ < ir0) continue;
13073
13570
  if (ir > ir1) break;
@@ -13119,15 +13616,16 @@ static void ggml_compute_forward_rope_back_f16(
13119
13616
  static void ggml_compute_forward_rope_back(
13120
13617
  const struct ggml_compute_params * params,
13121
13618
  const struct ggml_tensor * src0,
13619
+ const struct ggml_tensor * src1,
13122
13620
  struct ggml_tensor * dst) {
13123
13621
  switch (src0->type) {
13124
13622
  case GGML_TYPE_F16:
13125
13623
  {
13126
- ggml_compute_forward_rope_back_f16(params, src0, dst);
13624
+ ggml_compute_forward_rope_back_f16(params, src0, src1, dst);
13127
13625
  } break;
13128
13626
  case GGML_TYPE_F32:
13129
13627
  {
13130
- ggml_compute_forward_rope_back_f32(params, src0, dst);
13628
+ ggml_compute_forward_rope_back_f32(params, src0, src1, dst);
13131
13629
  } break;
13132
13630
  default:
13133
13631
  {
@@ -13150,7 +13648,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
13150
13648
  int64_t t0 = ggml_perf_time_us();
13151
13649
  UNUSED(t0);
13152
13650
 
13153
- GGML_TENSOR_BINARY_OP_LOCALS;
13651
+ GGML_TENSOR_BINARY_OP_LOCALS
13154
13652
 
13155
13653
  const int ith = params->ith;
13156
13654
  const int nth = params->nth;
@@ -13241,7 +13739,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32(
13241
13739
  int64_t t0 = ggml_perf_time_us();
13242
13740
  UNUSED(t0);
13243
13741
 
13244
- GGML_TENSOR_BINARY_OP_LOCALS;
13742
+ GGML_TENSOR_BINARY_OP_LOCALS
13245
13743
 
13246
13744
  const int ith = params->ith;
13247
13745
  const int nth = params->nth;
@@ -13353,7 +13851,7 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
13353
13851
  int64_t t0 = ggml_perf_time_us();
13354
13852
  UNUSED(t0);
13355
13853
 
13356
- GGML_TENSOR_BINARY_OP_LOCALS;
13854
+ GGML_TENSOR_BINARY_OP_LOCALS
13357
13855
 
13358
13856
  const int ith = params->ith;
13359
13857
  const int nth = params->nth;
@@ -13444,7 +13942,7 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
13444
13942
  int64_t t0 = ggml_perf_time_us();
13445
13943
  UNUSED(t0);
13446
13944
 
13447
- GGML_TENSOR_BINARY_OP_LOCALS;
13945
+ GGML_TENSOR_BINARY_OP_LOCALS
13448
13946
 
13449
13947
  const int ith = params->ith;
13450
13948
  const int nth = params->nth;
@@ -13562,7 +14060,7 @@ static void ggml_compute_forward_conv_1d(
13562
14060
  ggml_compute_forward_conv_1d_s2_ph(params, src0, src1, dst);
13563
14061
  } else {
13564
14062
  GGML_ASSERT(false); // only stride 1 and 2 supported
13565
- };
14063
+ }
13566
14064
  }
13567
14065
 
13568
14066
  // ggml_compute_forward_conv_2d
@@ -13579,7 +14077,7 @@ static void ggml_compute_forward_conv_2d_f16_f32(
13579
14077
  int64_t t0 = ggml_perf_time_us();
13580
14078
  UNUSED(t0);
13581
14079
 
13582
- GGML_TENSOR_BINARY_OP_LOCALS;
14080
+ GGML_TENSOR_BINARY_OP_LOCALS
13583
14081
 
13584
14082
  const int ith = params->ith;
13585
14083
  const int nth = params->nth;
@@ -13699,7 +14197,7 @@ static void ggml_compute_forward_conv_transpose_2d(
13699
14197
  int64_t t0 = ggml_perf_time_us();
13700
14198
  UNUSED(t0);
13701
14199
 
13702
- GGML_TENSOR_BINARY_OP_LOCALS;
14200
+ GGML_TENSOR_BINARY_OP_LOCALS
13703
14201
 
13704
14202
  const int ith = params->ith;
13705
14203
  const int nth = params->nth;
@@ -13958,7 +14456,7 @@ static void ggml_compute_forward_upscale_f32(
13958
14456
 
13959
14457
  const int ith = params->ith;
13960
14458
 
13961
- GGML_TENSOR_UNARY_OP_LOCALS;
14459
+ GGML_TENSOR_UNARY_OP_LOCALS
13962
14460
 
13963
14461
  const int scale_factor = dst->op_params[0];
13964
14462
 
@@ -14010,14 +14508,14 @@ static void ggml_compute_forward_flash_attn_f32(
14010
14508
  int64_t t0 = ggml_perf_time_us();
14011
14509
  UNUSED(t0);
14012
14510
 
14013
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne);
14014
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb);
14015
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne);
14016
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb);
14017
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne);
14018
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb);
14019
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
14020
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
14511
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
14512
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
14513
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
14514
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
14515
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
14516
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
14517
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
14518
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
14021
14519
 
14022
14520
  const int ith = params->ith;
14023
14521
  const int nth = params->nth;
@@ -14087,10 +14585,11 @@ static void ggml_compute_forward_flash_attn_f32(
14087
14585
  S[i] = -INFINITY;
14088
14586
  }
14089
14587
 
14090
- for (int64_t ic = 0; ic < nek1; ++ic) {
14588
+ const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
14589
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
14091
14590
  // k indices
14092
14591
  const int ik3 = iq3;
14093
- const int ik2 = iq2;
14592
+ const int ik2 = iq2 % nek2;
14094
14593
  const int ik1 = ic;
14095
14594
 
14096
14595
  // S indices
@@ -14103,20 +14602,18 @@ static void ggml_compute_forward_flash_attn_f32(
14103
14602
  }
14104
14603
 
14105
14604
  // scale
14106
- ggml_vec_scale_f32(nek1, S, scale);
14605
+ ggml_vec_scale_f32(masked_begin, S, scale);
14107
14606
 
14108
- if (masked) {
14109
- for (int64_t i = P; i < M; i++) {
14110
- if (i > P + iq1) {
14111
- S[i] = -INFINITY;
14112
- }
14113
- }
14607
+ for (int64_t i = masked_begin; i < M; i++) {
14608
+ S[i] = -INFINITY;
14114
14609
  }
14115
14610
 
14116
14611
  // softmax
14612
+ // exclude known -INF S[..] values from max and loop
14613
+ // dont forget to set their SW values to zero
+ // don't forget to set their SW values to zero
14117
14614
  {
14118
14615
  float max = -INFINITY;
14119
- ggml_vec_max_f32(M, &max, S);
14616
+ ggml_vec_max_f32(masked_begin, &max, S);
14120
14617
 
14121
14618
  ggml_float sum = 0.0;
14122
14619
  {
@@ -14130,10 +14627,15 @@ static void ggml_compute_forward_flash_attn_f32(
14130
14627
  ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
14131
14628
 
14132
14629
  for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
14630
+ if (i >= masked_begin) {
14631
+ break;
14632
+ }
14133
14633
  float * SS = S + i;
14134
14634
 
14135
14635
  for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
14136
- if (SS[j] == -INFINITY) {
14636
+ if (i + j >= masked_begin) {
14637
+ break;
14638
+ } else if (SS[j] == -INFINITY) {
14137
14639
  SS[j] = 0.0f;
14138
14640
  } else {
14139
14641
  #ifndef GGML_FLASH_ATTN_EXP_FP16
@@ -14158,10 +14660,10 @@ static void ggml_compute_forward_flash_attn_f32(
14158
14660
  assert(sum > 0.0);
14159
14661
 
14160
14662
  sum = 1.0/sum;
14161
- ggml_vec_scale_f32(M, S, sum);
14663
+ ggml_vec_scale_f32(masked_begin, S, sum);
14162
14664
 
14163
14665
  #ifndef NDEBUG
14164
- for (int i = 0; i < M; ++i) {
14666
+ for (int i = 0; i < masked_begin; ++i) {
14165
14667
  assert(!isnan(S[i]));
14166
14668
  assert(!isinf(S[i]));
14167
14669
  }
@@ -14174,9 +14676,13 @@ static void ggml_compute_forward_flash_attn_f32(
14174
14676
  const int i2 = iq2;
14175
14677
  const int i3 = iq3;
14176
14678
 
14177
- ggml_vec_dot_f32(nek1,
14178
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14179
- (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
14679
+ // v indices
14680
+ const int iv2 = iq2 % nev2;
14681
+ const int iv3 = iq3;
14682
+
14683
+ ggml_vec_dot_f32(masked_begin,
14684
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14685
+ (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
14180
14686
  S);
14181
14687
  }
14182
14688
  }
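Annotation: the f32 flash-attention changes above replace per-element -INF masking with an explicit upper bound, masked_begin, and then restrict the scale, the softmax max/sum loops, and the final dot products to that prefix. A tiny sketch of the bound as a standalone helper; the function name is hypothetical, the formula is the one used above:

    // number of key positions query row iq1 may attend to under a causal mask;
    // P is the number of cached past tokens, M = P + N is the full key length
    static int64_t attn_row_length(bool masked, int64_t P, int64_t iq1, int64_t M) {
        return masked ? (P + iq1 + 1) : M;
    }

Everything at or beyond that bound is known to be -INF before the softmax and zero after it, which is why the tail can be skipped rather than recomputed.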
@@ -14192,14 +14698,14 @@ static void ggml_compute_forward_flash_attn_f16(
14192
14698
  int64_t t0 = ggml_perf_time_us();
14193
14699
  UNUSED(t0);
14194
14700
 
14195
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne);
14196
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb);
14197
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne);
14198
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb);
14199
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne);
14200
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb);
14201
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
14202
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
14701
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
14702
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
14703
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
14704
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
14705
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
14706
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
14707
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
14708
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
14203
14709
 
14204
14710
  const int ith = params->ith;
14205
14711
  const int nth = params->nth;
@@ -14273,7 +14779,7 @@ static void ggml_compute_forward_flash_attn_f16(
14273
14779
  for (int64_t ic = 0; ic < nek1; ++ic) {
14274
14780
  // k indices
14275
14781
  const int ik3 = iq3;
14276
- const int ik2 = iq2;
14782
+ const int ik2 = iq2 % nek2;
14277
14783
  const int ik1 = ic;
14278
14784
 
14279
14785
  // S indices
@@ -14288,7 +14794,7 @@ static void ggml_compute_forward_flash_attn_f16(
14288
14794
  for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
14289
14795
  // k indices
14290
14796
  const int ik3 = iq3;
14291
- const int ik2 = iq2;
14797
+ const int ik2 = iq2 % nek2;
14292
14798
  const int ik1 = ic;
14293
14799
 
14294
14800
  // S indices
@@ -14313,6 +14819,8 @@ static void ggml_compute_forward_flash_attn_f16(
14313
14819
  }
14314
14820
 
14315
14821
  // softmax
14822
+ // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
14823
+ // dont forget to set their S values to zero
+ // don't forget to set their S values to zero
14316
14824
  {
14317
14825
  float max = -INFINITY;
14318
14826
  ggml_vec_max_f32(M, &max, S);
@@ -14369,6 +14877,7 @@ static void ggml_compute_forward_flash_attn_f16(
14369
14877
  S16[i] = GGML_FP32_TO_FP16(S[i]);
14370
14878
  }
14371
14879
 
14880
+ // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
14372
14881
  if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
14373
14882
  for (int64_t ic = 0; ic < nev1; ++ic) {
14374
14883
  // dst indices
@@ -14376,9 +14885,13 @@ static void ggml_compute_forward_flash_attn_f16(
14376
14885
  const int i2 = iq2;
14377
14886
  const int i3 = iq3;
14378
14887
 
14379
- ggml_vec_dot_f16(nek1,
14380
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14381
- (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
14888
+ // v indices
14889
+ const int iv2 = iq2 % nev2;
14890
+ const int iv3 = iq3;
14891
+
14892
+ ggml_vec_dot_f16(nev0,
14893
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14894
+ (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
14382
14895
  S16);
14383
14896
  }
14384
14897
  } else {
@@ -14388,9 +14901,13 @@ static void ggml_compute_forward_flash_attn_f16(
14388
14901
  const int i2 = iq2;
14389
14902
  const int i3 = iq3;
14390
14903
 
14391
- ggml_vec_dot_f16_unroll(nek1, nbv1,
14392
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14393
- ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
14904
+ // v indices
14905
+ const int iv2 = iq2 % nev2;
14906
+ const int iv3 = iq3;
14907
+
14908
+ ggml_vec_dot_f16_unroll(nev0, nbv1,
14909
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14910
+ ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
14394
14911
  S16);
14395
14912
  }
14396
14913
  }
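Annotation: in both flash-attention kernels the key/value head index is now taken modulo the K/V head count (ik2 = iq2 % nek2, iv2 = iq2 % nev2), so a K/V tensor with fewer heads than Q is broadcast across query heads, as in grouped/multi-query attention. A small sketch of the index mapping; the head counts are chosen only for illustration:

    // example: 32 query heads, 8 key/value heads
    // ik2 = iq2 % n_head_kv, so query heads 0, 8, 16, 24 all read kv head 0,
    // heads 1, 9, 17, 25 read kv head 1, and so on
    const int ik2 = iq2 % n_head_kv;   // kv head reused across query heads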
@@ -14433,18 +14950,18 @@ static void ggml_compute_forward_flash_ff_f16(
14433
14950
  int64_t t0 = ggml_perf_time_us();
14434
14951
  UNUSED(t0);
14435
14952
 
14436
- GGML_TENSOR_LOCALS(int64_t, nea, a, ne);
14437
- GGML_TENSOR_LOCALS(size_t, nba, a, nb);
14438
- GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne);
14439
- GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb);
14440
- GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne);
14441
- GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb);
14442
- GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne);
14443
- GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb);
14444
- GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne);
14445
- GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb);
14446
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
14447
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
14953
+ GGML_TENSOR_LOCALS(int64_t, nea, a, ne)
14954
+ GGML_TENSOR_LOCALS(size_t, nba, a, nb)
14955
+ GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne)
14956
+ GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb)
14957
+ GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne)
14958
+ GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb)
14959
+ GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne)
14960
+ GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb)
14961
+ GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne)
14962
+ GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb)
14963
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
14964
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
14448
14965
 
14449
14966
  const int ith = params->ith;
14450
14967
  const int nth = params->nth;
@@ -14592,16 +15109,16 @@ static void ggml_compute_forward_flash_attn_back_f32(
14592
15109
  int64_t t0 = ggml_perf_time_us();
14593
15110
  UNUSED(t0);
14594
15111
 
14595
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne);
14596
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb);
14597
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne);
14598
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb);
14599
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne);
14600
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb);
14601
- GGML_TENSOR_LOCALS(int64_t, ned, d, ne);
14602
- GGML_TENSOR_LOCALS(size_t, nbd, d, nb);
14603
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
14604
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
15112
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15113
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15114
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15115
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15116
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15117
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15118
+ GGML_TENSOR_LOCALS(int64_t, ned, d, ne)
15119
+ GGML_TENSOR_LOCALS(size_t, nbd, d, nb)
15120
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15121
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
14605
15122
 
14606
15123
  const int ith = params->ith;
14607
15124
  const int nth = params->nth;
@@ -14649,10 +15166,37 @@ static void ggml_compute_forward_flash_attn_back_f32(
14649
15166
  return;
14650
15167
  }
14651
15168
 
14652
- // parallelize by q rows using ggml_vec_dot_f32
15169
+ const int64_t elem_q = ggml_nelements(q);
15170
+ const int64_t elem_k = ggml_nelements(k);
14653
15171
 
14654
- // total rows in q
14655
- const int nr = neq2*neq3;
15172
+ enum ggml_type result_type = dst->type;
15173
+ GGML_ASSERT(ggml_blck_size(result_type) == 1);
15174
+ const size_t tsize = ggml_type_size(result_type);
15175
+
15176
+ const size_t offs_q = 0;
15177
+ const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
15178
+ const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
15179
+
15180
+ void * grad_q = (char *) dst->data;
15181
+ void * grad_k = (char *) dst->data + offs_k;
15182
+ void * grad_v = (char *) dst->data + offs_v;
15183
+
15184
+ const size_t nbgq1 = nb0*neq0;
15185
+ const size_t nbgq2 = nb0*neq0*neq1;
15186
+ const size_t nbgq3 = nb0*neq0*neq1*neq2;
15187
+
15188
+ const size_t nbgk1 = nb0*nek0;
15189
+ const size_t nbgk2 = nb0*nek0*nek1;
15190
+ const size_t nbgk3 = nb0*nek0*nek1*neq2;
15191
+
15192
+ const size_t nbgv1 = nb0*nev0;
15193
+ const size_t nbgv2 = nb0*nev0*nev1;
15194
+ const size_t nbgv3 = nb0*nev0*nev1*neq2;
15195
+
15196
+ // parallelize by k rows using ggml_vec_dot_f32
15197
+
15198
+ // total rows in k
15199
+ const int nr = nek2*nek3;
14656
15200
 
14657
15201
  // rows per thread
14658
15202
  const int dr = (nr + nth - 1)/nth;
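Annotation: the backward kernel above now derives the q/k/v gradient buffers from aligned offsets into a single dst->data allocation instead of fixed D*N and D*M strides, and parallelizes over k rows so that each thread owns the k/v gradients it accumulates into. A sketch of the same layout computation in isolation, assuming GGML_PAD rounds a byte size up to a multiple of the given alignment (as defined in ggml.h):

    // one contiguous dst buffer holds grad_q | grad_k | grad_v,
    // each segment padded out to GGML_MEM_ALIGN bytes
    const size_t offs_q = 0;
    const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
    const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
    // e.g. with a 16-byte alignment, GGML_PAD(100, 16) == 112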
@@ -14665,268 +15209,243 @@ static void ggml_compute_forward_flash_attn_back_f32(
14665
15209
 
14666
15210
  //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
14667
15211
 
15212
+ // how often k2 (and v2) is repeated in q2
15213
+ int nrep = neq2/nek2;
15214
+
14668
15215
  for (int ir = ir0; ir < ir1; ++ir) {
14669
15216
  // q indices
14670
- const int iq3 = ir/(neq2);
14671
- const int iq2 = ir - iq3*neq2;
14672
- for ( int iq1 = 0; iq1 < neq1; ++iq1) {
15217
+ const int ik3 = ir/(nek2);
15218
+ const int ik2 = ir - ik3*nek2;
14673
15219
 
15220
+ const int iq3 = ik3;
15221
+ const int id3 = ik3;
15222
+ const int iv3 = ik3;
15223
+ const int iv2 = ik2;
14674
15224
 
14675
- // not sure about CACHE_LINE_SIZE_F32..
14676
- // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
14677
- float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
14678
- float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
15225
+ for (int irep = 0; irep < nrep; ++irep) {
15226
+ const int iq2 = ik2 + irep*nek2;
15227
+ const int id2 = iq2;
14679
15228
 
14680
- for (int i = M; i < Mup; ++i) {
14681
- S[i] = -INFINITY;
14682
- }
15229
+ // (ik2 + irep*nek2) % nek2 == ik2
15230
+ for (int iq1 = 0; iq1 < neq1; ++iq1) {
15231
+ const int id1 = iq1;
14683
15232
 
14684
- for (int64_t ic = 0; ic < nek1; ++ic) {
14685
- // k indices
14686
- const int ik3 = iq3;
14687
- const int ik2 = iq2;
14688
- const int ik1 = ic;
15233
+ // not sure about CACHE_LINE_SIZE_F32..
15234
+ // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
15235
+ float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
15236
+ float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
14689
15237
 
14690
- // S indices
14691
- const int i1 = ik1;
15238
+ for (int i = M; i < Mup; ++i) {
15239
+ S[i] = -INFINITY;
15240
+ }
14692
15241
 
14693
- ggml_vec_dot_f32(neq0,
14694
- S + i1,
14695
- (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
14696
- (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
14697
- }
15242
+ const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
15243
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
15244
+ // k indices
15245
+ const int ik1 = ic;
14698
15246
 
14699
- // scale
14700
- ggml_vec_scale_f32(nek1, S, scale);
15247
+ // S indices
15248
+ const int i1 = ik1;
14701
15249
 
14702
- if (masked) {
14703
- for (int64_t i = P; i < M; i++) {
14704
- if (i > P + iq1) {
14705
- S[i] = -INFINITY;
14706
- }
15250
+ ggml_vec_dot_f32(neq0,
15251
+ S + i1,
15252
+ (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
15253
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
14707
15254
  }
14708
- }
14709
15255
 
14710
- // softmax
14711
- {
14712
- float max = -INFINITY;
14713
- ggml_vec_max_f32(M, &max, S);
15256
+ // scale
15257
+ ggml_vec_scale_f32(masked_begin, S, scale);
14714
15258
 
14715
- ggml_float sum = 0.0;
15259
+ for (int64_t i = masked_begin; i < M; i++) {
15260
+ S[i] = -INFINITY;
15261
+ }
15262
+
15263
+ // softmax
15264
+ // exclude known -INF S[..] values from max and loop
15265
+ // dont forget to set their SM values to zero
+ // don't forget to set their SM values to zero
14716
15266
  {
15267
+ float max = -INFINITY;
15268
+ ggml_vec_max_f32(masked_begin, &max, S);
15269
+
15270
+ ggml_float sum = 0.0;
15271
+ {
14717
15272
  #ifdef GGML_SOFT_MAX_ACCELERATE
14718
- max = -max;
14719
- vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
14720
- vvexpf(SM, SM, &Mup);
14721
- ggml_vec_sum_f32(Mup, &sum, SM);
15273
+ max = -max;
15274
+ vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
15275
+ vvexpf(SM, SM, &Mup);
15276
+ ggml_vec_sum_f32(Mup, &sum, SM);
14722
15277
  #else
14723
- uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
14724
- ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
14725
-
14726
- for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
14727
- float * SR = S + i;
14728
- float * SW = SM + i;
15278
+ uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
15279
+ ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
14729
15280
 
14730
- for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
14731
- if (SR[j] == -INFINITY) {
14732
- SW[j] = 0.0f;
14733
- } else {
15281
+ for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
15282
+ if (i >= masked_begin) {
15283
+ break;
15284
+ }
15285
+ float * SR = S + i;
15286
+ float * SW = SM + i;
15287
+
15288
+ for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
15289
+ if (i + j >= masked_begin) {
15290
+ break;
15291
+ } else if (SR[j] == -INFINITY) {
15292
+ SW[j] = 0.0f;
15293
+ } else {
14734
15294
  #ifndef GGML_FLASH_ATTN_EXP_FP16
14735
- const float val = expf(SR[j] - max);
15295
+ const float val = expf(SR[j] - max);
14736
15296
  #else
14737
- ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
14738
- memcpy(&scvt[j], &s, sizeof(uint16_t));
14739
- const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
15297
+ ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
15298
+ memcpy(&scvt[j], &s, sizeof(uint16_t));
15299
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
14740
15300
  #endif
14741
- sump[j] += (ggml_float)val;
14742
- SW[j] = val;
15301
+ sump[j] += (ggml_float)val;
15302
+ SW[j] = val;
15303
+ }
14743
15304
  }
14744
15305
  }
14745
- }
14746
15306
 
14747
- for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
14748
- sum += sump[i];
14749
- }
15307
+ for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
15308
+ sum += sump[i];
15309
+ }
14750
15310
  #endif
14751
- }
14752
-
14753
- assert(sum > 0.0);
14754
-
14755
- sum = 1.0/sum;
14756
- ggml_vec_scale_f32(M, SM, sum);
14757
-
14758
- }
14759
-
14760
- // step-by-step explanation
14761
- {
14762
- // forward-process shape grads from backward process
14763
- // parallel_for iq2,iq3:
14764
- // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur]
14765
- // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
14766
- // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur]
14767
- // for iq1:
14768
- // kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
14769
- // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
14770
- // vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
14771
- // S0 = -Inf [D,1,1,1]
14772
- // ~S1[i] = dot(kcur[:D,i], qcur)
14773
- // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
14774
- // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
14775
- // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
14776
- // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
14777
- // ~S5[i] = dot(vcur[:,i], S4)
14778
- // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
14779
- // ~dst[i,iq1,iq2,iq3] = S5[i] ^
14780
- // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
14781
- // dst backward-/ grad[dst] = d
14782
- //
14783
- // output gradients with their dependencies:
14784
- //
14785
- // grad[kcur] = grad[S1].T @ qcur
14786
- // grad[S1] = diag_mask_zero(grad[S3], P) * scale
14787
- // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
14788
- // grad[S4] = grad[S5] @ vcur
14789
- // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
14790
- // grad[qcur] = grad[S1] @ kcur
14791
- // grad[vcur] = grad[S5].T @ S4
14792
- // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
14793
- //
14794
- // in post-order:
14795
- //
14796
- // S1 = qcur @ kcur.T
14797
- // S2 = S1 * scale
14798
- // S3 = diag_mask_inf(S2, P)
14799
- // S4 = softmax(S3)
14800
- // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
14801
- // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
14802
- // grad[S1] = diag_mask_zero(grad[S3], P) * scale
14803
- // grad[qcur] = grad[S1] @ kcur
14804
- // grad[kcur] = grad[S1].T @ qcur
14805
- // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
14806
- //
14807
- // using less variables (SM=S4):
14808
- //
14809
- // S = diag_mask_inf(qcur @ kcur.T * scale, P)
14810
- // SM = softmax(S)
14811
- // S = d[:D,iq1,iq2,iq3] @ vcur
14812
- // dot_SM_gradSM = dot(SM, S)
14813
- // S = SM * (S - dot(SM, S))
14814
- // S = diag_mask_zero(S, P) * scale
14815
- //
14816
- // grad[q][:D,iq1,iq2,iq3] += S @ kcur
14817
- // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
14818
- // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
14819
- }
14820
-
14821
- // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
14822
- // S = d[:D,iq1,iq2,iq3] @ vcur
14823
- // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
14824
- ggml_vec_set_f32(M, S, 0);
14825
- for (int64_t ic = 0; ic < D; ++ic) {
14826
- // dst indices
14827
- const int i1 = iq1;
14828
- const int i2 = iq2;
14829
- const int i3 = iq3;
15311
+ }
14830
15312
 
14831
- ggml_vec_mad_f32(M,
14832
- S,
14833
- (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
14834
- *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
14835
- }
15313
+ assert(sum > 0.0);
14836
15314
 
14837
- // S = SM * (S - dot(SM, S))
14838
- float dot_SM_gradSM = 0;
14839
- ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
14840
- ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
14841
- ggml_vec_mul_f32 (M, S, S, SM);
15315
+ sum = 1.0/sum;
15316
+ ggml_vec_scale_f32(masked_begin, SM, sum);
14842
15317
 
14843
- // S = diag_mask_zero(S, P) * scale
14844
- if (masked) {
14845
- // for (int64_t i = P + iq1 + 1; i < M; i++) {
14846
- // S[i] = 0;
14847
- // }
14848
- for (int64_t i = P; i < M; i++) {
14849
- if (i > P + iq1) {
14850
- S[i] = 0;
14851
- }
14852
15318
  }
14853
- }
14854
- ggml_vec_scale_f32(M, S, scale);
14855
-
14856
- void * grad_q = (char *) dst->data;
14857
- void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
14858
- void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
14859
-
14860
- const size_t nbgq1 = nb0*neq0;
14861
- const size_t nbgq2 = nb0*neq0*neq1;
14862
- const size_t nbgq3 = nb0*neq0*neq1*neq2;
14863
-
14864
- const size_t nbgk1 = nb0*nek0;
14865
- const size_t nbgk2 = nb0*nek0*nek1;
14866
- const size_t nbgk3 = nb0*nek0*nek1*neq2;
14867
-
14868
- const size_t nbgv1 = nb0*nev0;
14869
- const size_t nbgv2 = nb0*nev0*nev1;
14870
- const size_t nbgv3 = nb0*nev0*nev1*neq2;
14871
-
14872
- // S shape [M,1]
14873
- // SM shape [M,1]
14874
- // kcur shape [D,M]
14875
- // qcur shape [D,1]
14876
- // vcur shape [M,D]
14877
- //
14878
- // grad[q][:D,iq1,iq2,iq3] += S @ kcur
14879
- // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
14880
- // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
14881
- //
14882
- //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
14883
- //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
14884
- for (int64_t ic = 0; ic < M; ++ic) {
14885
- // dst indices
14886
- const int i1 = iq1;
14887
- const int i2 = iq2;
14888
- const int i3 = iq3;
14889
15319
 
14890
- ggml_vec_mad_f32(D,
14891
- (float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
14892
- (float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)),
14893
- S[ic]);
14894
- }
15320
+ // step-by-step explanation
15321
+ {
15322
+ // forward-process shape grads from backward process
15323
+ // parallel_for ik2,ik3:
15324
+ // for irep:
15325
+ // iq2 = ik2 + irep*nek2
15326
+ // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur]
15327
+ // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
15328
+ // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur]
15329
+ // for iq1:
15330
+ // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
15331
+ // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
15332
+ // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
15333
+ // S0 = -Inf [D,1,1,1]
15334
+ // ~S1[i] = dot(kcur[:D,i], qcur)
15335
+ // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
15336
+ // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
15337
+ // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
15338
+ // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
15339
+ // ~S5[i] = dot(vcur[:,i], S4)
15340
+ // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3]
15341
+ // ~dst[i,iq1,iq2,iq3] = S5[i] ^
15342
+ // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
15343
+ // dst backward-/ grad[dst] = d
15344
+ //
15345
+ // output gradients with their dependencies:
15346
+ //
15347
+ // grad[kcur] = grad[S1].T @ qcur
15348
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
15349
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
15350
+ // grad[S4] = grad[S5] @ vcur
15351
+ // grad[S4] = d[:D,id1,id2,id3] @ vcur
15352
+ // grad[qcur] = grad[S1] @ kcur
15353
+ // grad[vcur] = grad[S5].T @ S4
15354
+ // grad[vcur] = d[:D,id1,id2,id3].T @ S4
15355
+ //
15356
+ // in post-order:
15357
+ //
15358
+ // S1 = qcur @ kcur.T
15359
+ // S2 = S1 * scale
15360
+ // S3 = diag_mask_inf(S2, P)
15361
+ // S4 = softmax(S3)
15362
+ // grad[S4] = d[:D,id1,id2,id3] @ vcur
15363
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
15364
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
15365
+ // grad[qcur] = grad[S1] @ kcur
15366
+ // grad[kcur] = grad[S1].T @ qcur
15367
+ // grad[vcur] = d[:D,id1,id2,id3].T @ S4
15368
+ //
15369
+ // using less variables (SM=S4):
15370
+ //
15371
+ // S = diag_mask_inf(qcur @ kcur.T * scale, P)
15372
+ // SM = softmax(S)
15373
+ // S = d[:D,iq1,iq2,iq3] @ vcur
15374
+ // dot_SM_gradSM = dot(SM, S)
15375
+ // S = SM * (S - dot(SM, S))
15376
+ // S = diag_mask_zero(S, P) * scale
15377
+ //
15378
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
15379
+ // grad[k][:D,:M,ik2,ik3] += S.T @ qcur
15380
+ // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
15381
+ }
14895
15382
 
14896
- // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
14897
- // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
14898
- // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
14899
- for (int64_t ic = 0; ic < M; ++ic) {
14900
- // dst indices
14901
- const int i1 = iq1;
14902
- const int i2 = iq2;
14903
- const int i3 = iq3;
15383
+ // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
15384
+ // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
15385
+ // for ic:
15386
+ // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
15387
+ // exclude known future zero S[..] values from operation
15388
+ ggml_vec_set_f32(masked_begin, S, 0);
15389
+ for (int64_t ic = 0; ic < D; ++ic) {
15390
+ ggml_vec_mad_f32(masked_begin,
15391
+ S,
15392
+ (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
15393
+ *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
15394
+ }
14904
15395
 
14905
- // ggml_vec_set_f32(D,
14906
- // (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
14907
- // 0);
14908
- ggml_vec_mad_f32(D,
14909
- (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
14910
- (float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)),
14911
- S[ic]);
14912
- }
15396
+ // S = SM * (S - dot(SM, S))
15397
+ float dot_SM_gradSM = 0;
15398
+ ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S);
15399
+ ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
15400
+ ggml_vec_mul_f32 (masked_begin, S, S, SM);
15401
+
15402
+ // S = diag_mask_zero(S, P) * scale
15403
+ // already done by above ggml_vec_set_f32
15404
+
15405
+ // exclude known zero S[..] values from operation
15406
+ ggml_vec_scale_f32(masked_begin, S, scale);
15407
+
15408
+ // S shape [M,1]
15409
+ // SM shape [M,1]
15410
+ // kcur shape [D,M]
15411
+ // qcur shape [D,1]
15412
+ // vcur shape [M,D]
15413
+
15414
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
15415
+ // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
15416
+ // for ic:
15417
+ // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
15418
+ // exclude known zero S[..] values from loop
15419
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
15420
+ ggml_vec_mad_f32(D,
15421
+ (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
15422
+ (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
15423
+ S[ic]);
15424
+ }
14913
15425
 
14914
- // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
14915
- // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
14916
- // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M]
14917
- for (int64_t ic = 0; ic < D; ++ic) {
14918
- // dst indices
14919
- const int i1 = iq1;
14920
- const int i2 = iq2;
14921
- const int i3 = iq3;
15426
+ // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
15427
+ // for ic:
15428
+ // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
15429
+ // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
15430
+ // exclude known zero S[..] values from loop
15431
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
15432
+ ggml_vec_mad_f32(D,
15433
+ (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
15434
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
15435
+ S[ic]);
15436
+ }
14922
15437
 
14923
- // ggml_vec_set_f32(M,
14924
- // (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
14925
- // 0);
14926
- ggml_vec_mad_f32(M,
14927
- (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
14928
- SM,
14929
- *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
15438
+ // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
15439
+ // for ic:
15440
+ // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
15441
+ // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
15442
+ // exclude known zero SM[..] values from mad
15443
+ for (int64_t ic = 0; ic < D; ++ic) {
15444
+ ggml_vec_mad_f32(masked_begin,
15445
+ (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
15446
+ SM,
15447
+ *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
15448
+ }
14930
15449
  }
14931
15450
  }
14932
15451
  }
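Annotation: the rewritten backward pass above relies on the standard softmax Jacobian-vector product. Writing y = softmax(s) for the forward result (SM in the code) and dy for the incoming gradient, the line S = SM * (S - dot(SM, S)) is the identity

    \frac{\partial y_i}{\partial s_j} = y_i (\delta_{ij} - y_j)
    \quad\Rightarrow\quad
    \frac{\partial L}{\partial s_i} = y_i \left( \frac{\partial L}{\partial y_i} - \sum_j y_j \frac{\partial L}{\partial y_j} \right)

Since y_i = 0 at every masked position, those terms contribute nothing, which is why all of the gradient loops can stop at masked_begin.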
@@ -14962,8 +15481,8 @@ static void ggml_compute_forward_win_part_f32(
14962
15481
  return;
14963
15482
  }
14964
15483
 
14965
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
14966
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
15484
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
15485
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
14967
15486
 
14968
15487
  const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
14969
15488
  const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
@@ -15024,8 +15543,8 @@ static void ggml_compute_forward_win_unpart_f32(
15024
15543
  return;
15025
15544
  }
15026
15545
 
15027
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
15028
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
15546
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
15547
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15029
15548
 
15030
15549
  const int32_t w = ((const int32_t *)(dst->op_params))[0];
15031
15550
 
@@ -15142,7 +15661,7 @@ static void ggml_compute_forward_get_rel_pos_f16(
15142
15661
 
15143
15662
  // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
15144
15663
 
15145
- GGML_TENSOR_UNARY_OP_LOCALS;
15664
+ GGML_TENSOR_UNARY_OP_LOCALS
15146
15665
 
15147
15666
  const int64_t w = ne1;
15148
15667
 
@@ -15840,7 +16359,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
15840
16359
  } break;
15841
16360
  case GGML_OP_GET_ROWS_BACK:
15842
16361
  {
15843
- ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
16362
+ ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor);
15844
16363
  } break;
15845
16364
  case GGML_OP_DIAG:
15846
16365
  {
@@ -15864,11 +16383,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
15864
16383
  } break;
15865
16384
  case GGML_OP_ROPE:
15866
16385
  {
15867
- ggml_compute_forward_rope(params, tensor->src[0], tensor);
16386
+ ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor);
15868
16387
  } break;
15869
16388
  case GGML_OP_ROPE_BACK:
15870
16389
  {
15871
- ggml_compute_forward_rope_back(params, tensor->src[0], tensor);
16390
+ ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor);
15872
16391
  } break;
15873
16392
  case GGML_OP_ALIBI:
15874
16393
  {
@@ -16013,7 +16532,218 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
16013
16532
 
16014
16533
  ////////////////////////////////////////////////////////////////////////////////
16015
16534
 
16016
- static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
16535
+ static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
16536
+
16537
+ static size_t hash(void * p) {
16538
+ return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
16539
+ }
16540
+
16541
+ static size_t hash_find(void * hash_table[], void * p) {
16542
+ size_t h = hash(p);
16543
+
16544
+ // linear probing
16545
+ size_t i = h;
16546
+ while (hash_table[i] != NULL && hash_table[i] != p) {
16547
+ i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
16548
+ if (i == h) {
16549
+ // visited all hash table entries -> not found
16550
+ return GGML_GRAPH_HASHTABLE_SIZE;
16551
+ }
16552
+ }
16553
+ return i;
16554
+ }
16555
+
16556
+ static bool hash_insert(void * hash_table[], void * p) {
16557
+ size_t i = hash_find(hash_table, p);
16558
+
16559
+ GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
16560
+
16561
+ if (hash_table[i] == p) {
16562
+ return true;
16563
+ }
16564
+
16565
+ // insert
16566
+ GGML_ASSERT(hash_table[i] == NULL);
16567
+ hash_table[i] = p;
16568
+ return false;
16569
+ }
16570
+
16571
+ static bool hash_contains(void * hash_table[], void * p) {
16572
+ size_t i = hash_find(hash_table, p);
16573
+ return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
16574
+ }
16575
+
16576
+ struct hash_map {
16577
+ void * keys[GGML_GRAPH_HASHTABLE_SIZE];
16578
+ void * vals[GGML_GRAPH_HASHTABLE_SIZE];
16579
+ };
16580
+
16581
+ static struct hash_map * new_hash_map(void) {
16582
+ struct hash_map * result = malloc(sizeof(struct hash_map));
16583
+ for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
16584
+ result->keys[i] = NULL;
16585
+ result->vals[i] = NULL;
16586
+ }
16587
+ return result;
16588
+ }
16589
+
16590
+ static void free_hash_map(struct hash_map * map) {
16591
+ free(map);
16592
+ }
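Annotation: the helpers above implement a fixed-size, open-addressed set/map over tensor pointers with linear probing; hash_insert reports whether a pointer was already present, which is how graph traversal is deduplicated. A small usage sketch; the local table and some_tensor are placeholders for illustration (the graph itself stores such a table in its visited_hash_table field):

    // track which tensors have already been seen while walking a graph
    static void * visited[GGML_GRAPH_HASHTABLE_SIZE]; // zero-initialized

    if (!hash_insert(visited, (void *) some_tensor)) {
        // hash_insert returned false -> newly inserted, first visit: process it
    }
    // membership can also be tested without inserting:
    const bool seen = hash_contains(visited, (void *) some_tensor);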
16593
+
16594
+ // gradient checkpointing
16595
+
16596
+ static struct ggml_tensor * ggml_recompute_graph_node(
16597
+ struct ggml_context * ctx,
16598
+ struct ggml_cgraph * graph,
16599
+ struct hash_map * replacements,
16600
+ struct ggml_tensor * node) {
16601
+
16602
+ if (node == NULL) {
16603
+ return NULL;
16604
+ }
16605
+
16606
+ if (node->is_param) {
16607
+ return node;
16608
+ }
16609
+
16610
+ if (!hash_contains(graph->visited_hash_table, node)) {
16611
+ return node;
16612
+ }
16613
+
16614
+ int count_children = 0;
16615
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
16616
+ if (node->src[k]) {
16617
+ ++count_children;
16618
+ }
16619
+ }
16620
+
16621
+ if (count_children == 0) {
16622
+ return node;
16623
+ }
16624
+
16625
+ size_t i = hash_find(replacements->keys, node);
16626
+ GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
16627
+ if (replacements->keys[i] == node) {
16628
+ return (struct ggml_tensor *) replacements->vals[i];
16629
+ }
16630
+
16631
+ struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
16632
+
16633
+ // insert clone into replacements
16634
+ GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
16635
+ replacements->keys[i] = node;
16636
+ replacements->vals[i] = clone;
16637
+
16638
+ clone->op = node->op;
16639
+ clone->grad = node->grad;
16640
+ clone->is_param = node->is_param;
16641
+ clone->extra = node->extra;
16642
+ for (int k = 0; k < GGML_MAX_DIMS; ++k) {
16643
+ clone->nb[k] = node->nb[k];
16644
+ }
16645
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
16646
+ clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
16647
+ }
16648
+ if (node->view_src != NULL) {
16649
+ clone->data = (node->view_src->data == NULL)
16650
+ ? NULL // view_src not yet allocated
16651
+ : (char *) node->view_src->data // view_src already allocated
16652
+ + node->view_offs;
16653
+ clone->view_src = node->view_src;
16654
+ clone->view_offs = node->view_offs;
16655
+ }
16656
+
16657
+ GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
16658
+ GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
16659
+ memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
16660
+ ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
16661
+
16662
+ return clone;
16663
+ }
16664
+
16665
+ void ggml_build_backward_gradient_checkpointing(
16666
+ struct ggml_context * ctx,
16667
+ struct ggml_cgraph * gf,
16668
+ struct ggml_cgraph * gb,
16669
+ struct ggml_cgraph * gb_tmp,
16670
+ struct ggml_tensor * * checkpoints,
16671
+ int n_checkpoints) {
16672
+ *gb_tmp = *gf;
16673
+ ggml_build_backward_expand(ctx, gf, gb_tmp, true);
16674
+
16675
+ if (n_checkpoints <= 0) {
16676
+ *gb = *gb_tmp;
16677
+ return;
16678
+ }
16679
+
16680
+ struct hash_map * replacements = new_hash_map();
16681
+
16682
+ // insert checkpoints in replacements
16683
+ for (int i = 0; i < n_checkpoints; ++i) {
16684
+ size_t k = hash_find(replacements->keys, checkpoints[i]);
16685
+ GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
16686
+ GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
16687
+ replacements->keys[k] = checkpoints[i];
16688
+ replacements->vals[k] = checkpoints[i];
16689
+ }
16690
+
16691
+ *gb = *gf;
16692
+ // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
16693
+ // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
16694
+ // by recomputing them from checkpoints
16695
+ for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
16696
+ struct ggml_tensor * node = gb_tmp->nodes[i];
16697
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
16698
+ // insert new tensors recomputing src, reusing already made replacements,
16699
+ // remember replacements: remember new tensors with mapping from corresponding gf nodes
16700
+ // recurse for input tensors,
16701
+ // unless (i.e. terminating when) input tensors are replacments (like checkpoints)
+ // unless (i.e. terminating when) input tensors are replacements (like checkpoints)
16702
+ node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
16703
+ }
16704
+ // insert rewritten backward node with replacements made into resulting backward graph gb
16705
+ ggml_build_forward_expand(gb, node);
16706
+ }
16707
+
16708
+ free_hash_map(replacements);
16709
+ }
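Annotation: ggml_build_backward_gradient_checkpointing builds the backward graph gb so that intermediates between the given checkpoint tensors are recomputed (via the cloned nodes from ggml_recompute_graph_node) rather than kept alive from the forward pass. A rough usage sketch; loss, checkpoints and n_checkpoints are placeholders, and it assumes the by-value ggml_build_forward API of this ggml version:

    struct ggml_cgraph gf     = ggml_build_forward(loss); // forward graph
    struct ggml_cgraph gb     = gf;                       // backward graph, filled below
    struct ggml_cgraph gb_tmp = gf;                       // scratch graph

    // checkpoints: tensors to keep; everything between them is recomputed
    ggml_build_backward_gradient_checkpointing(
        ctx, &gf, &gb, &gb_tmp, checkpoints, n_checkpoints);

Memory use then scales with the spacing of the checkpoints instead of with the full depth of the forward graph, at the cost of re-running the forward ops between checkpoints during the backward pass.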
16710
+
16711
+ // functions to change gradients considering the case that input a might be initial gradient with zero value
16712
+
16713
+ static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
16714
+ if (hash_contains(zero_table, a)) {
16715
+ return b;
16716
+ } else {
16717
+ return ggml_add_impl(ctx, a, b, false);
16718
+ }
16719
+ }
16720
+
16721
+ static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, void * zero_table[]) {
16722
+ if (hash_contains(zero_table, a)) {
16723
+ struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0));
16724
+ return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
16725
+ } else {
16726
+ return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
16727
+ }
16728
+ }
16729
+
16730
+ static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
16731
+ if (hash_contains(zero_table, a)) {
16732
+ return ggml_repeat(ctx, b, a);
16733
+ } else {
16734
+ return ggml_add1_impl(ctx, a, b, false);
16735
+ }
16736
+ }
16737
+
16738
+ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
16739
+ if (hash_contains(zero_table, a)) {
16740
+ return ggml_neg(ctx, b);
16741
+ } else {
16742
+ return ggml_sub_impl(ctx, a, b, false);
16743
+ }
16744
+ }
16745
+
16746
+ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, void * zero_table[]) {
16017
16747
  struct ggml_tensor * src0 = tensor->src[0];
16018
16748
  struct ggml_tensor * src1 = tensor->src[1];
16019
16749
 
@@ -16021,34 +16751,34 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16021
16751
  case GGML_OP_DUP:
16022
16752
  {
16023
16753
  if (src0->grad) {
16024
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
16754
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16025
16755
  }
16026
16756
  } break;
16027
16757
  case GGML_OP_ADD:
16028
16758
  {
16029
16759
  if (src0->grad) {
16030
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
16760
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16031
16761
  }
16032
16762
  if (src1->grad) {
16033
- src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace);
16763
+ src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
16034
16764
  }
16035
16765
  } break;
16036
16766
  case GGML_OP_ADD1:
16037
16767
  {
16038
16768
  if (src0->grad) {
16039
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
16769
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16040
16770
  }
16041
16771
  if (src1->grad) {
16042
- src1->grad = ggml_add_impl(ctx,
16772
+ src1->grad = ggml_add_or_set(ctx,
16043
16773
  src1->grad,
16044
16774
  ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
16045
- inplace);
16775
+ zero_table);
16046
16776
  }
16047
16777
  } break;
16048
16778
  case GGML_OP_ACC:
16049
16779
  {
16050
16780
  if (src0->grad) {
16051
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
16781
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16052
16782
  }
16053
16783
  if (src1->grad) {
16054
16784
  const size_t nb1 = ((int32_t *) tensor->op_params)[0];
@@ -16065,117 +16795,117 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16065
16795
  nb1, nb2, nb3, offset);
16066
16796
 
16067
16797
  src1->grad =
16068
- ggml_add_impl(ctx,
16798
+ ggml_add_or_set(ctx,
16069
16799
  src1->grad,
16070
16800
  ggml_reshape(ctx,
16071
16801
  ggml_cont(ctx, tensor_grad_view),
16072
16802
  src1->grad),
16073
- inplace);
16803
+ zero_table);
16074
16804
  }
16075
16805
  } break;
16076
16806
  case GGML_OP_SUB:
16077
16807
  {
16078
16808
  if (src0->grad) {
16079
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
16809
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16080
16810
  }
16081
16811
  if (src1->grad) {
16082
- src1->grad = ggml_sub_impl(ctx, src1->grad, tensor->grad, inplace);
16812
+ src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table);
16083
16813
  }
16084
16814
  } break;
16085
16815
  case GGML_OP_MUL:
16086
16816
  {
16087
16817
  if (src0->grad) {
16088
16818
  src0->grad =
16089
- ggml_add_impl(ctx,
16819
+ ggml_add_or_set(ctx,
16090
16820
  src0->grad,
16091
16821
  ggml_mul(ctx, src1, tensor->grad),
16092
- inplace);
16822
+ zero_table);
16093
16823
  }
16094
16824
  if (src1->grad) {
16095
16825
  src1->grad =
16096
- ggml_add_impl(ctx,
16826
+ ggml_add_or_set(ctx,
16097
16827
  src1->grad,
16098
16828
  ggml_mul(ctx, src0, tensor->grad),
16099
- inplace);
16829
+ zero_table);
16100
16830
  }
16101
16831
  } break;
16102
16832
  case GGML_OP_DIV:
16103
16833
  {
16104
16834
  if (src0->grad) {
16105
16835
  src0->grad =
16106
- ggml_add_impl(ctx,
16836
+ ggml_add_or_set(ctx,
16107
16837
  src0->grad,
16108
16838
  ggml_div(ctx, tensor->grad, src1),
16109
- inplace);
16839
+ zero_table);
16110
16840
  }
16111
16841
  if (src1->grad) {
16112
16842
  src1->grad =
16113
- ggml_sub_impl(ctx,
16843
+ ggml_sub_or_set(ctx,
16114
16844
  src1->grad,
16115
16845
  ggml_mul(ctx,
16116
16846
  tensor->grad,
16117
16847
  ggml_div(ctx, tensor, src1)),
16118
- inplace);
16848
+ zero_table);
16119
16849
  }
16120
16850
  } break;
16121
16851
  case GGML_OP_SQR:
16122
16852
  {
16123
16853
  if (src0->grad) {
16124
16854
  src0->grad =
16125
- ggml_add_impl(ctx,
16855
+ ggml_add_or_set(ctx,
16126
16856
  src0->grad,
16127
16857
  ggml_scale(ctx,
16128
16858
  ggml_mul(ctx, src0, tensor->grad),
16129
16859
  ggml_new_f32(ctx, 2.0f)),
16130
- inplace);
16860
+ zero_table);
16131
16861
  }
16132
16862
  } break;
16133
16863
  case GGML_OP_SQRT:
16134
16864
  {
16135
16865
  if (src0->grad) {
16136
16866
  src0->grad =
16137
- ggml_add_impl(ctx,
16867
+ ggml_add_or_set(ctx,
16138
16868
  src0->grad,
16139
16869
  ggml_scale(ctx,
16140
16870
  ggml_div(ctx,
16141
16871
  tensor->grad,
16142
16872
  tensor),
16143
16873
  ggml_new_f32(ctx, 0.5f)),
16144
- inplace);
16874
+ zero_table);
16145
16875
  }
16146
16876
  } break;
16147
16877
  case GGML_OP_LOG:
16148
16878
  {
16149
16879
  if (src0->grad) {
16150
16880
  src0->grad =
16151
- ggml_add_impl(ctx,
16881
+ ggml_add_or_set(ctx,
16152
16882
  src0->grad,
16153
16883
  ggml_div(ctx,
16154
16884
  tensor->grad,
16155
16885
  src0),
16156
- inplace);
16886
+ zero_table);
16157
16887
  }
16158
16888
  } break;
16159
16889
  case GGML_OP_SUM:
16160
16890
  {
16161
16891
  if (src0->grad) {
16162
16892
  src0->grad =
16163
- ggml_add1_impl(ctx,
16893
+ ggml_add1_or_set(ctx,
16164
16894
  src0->grad,
16165
16895
  tensor->grad,
16166
- inplace);
16896
+ zero_table);
16167
16897
  }
16168
16898
  } break;
16169
16899
  case GGML_OP_SUM_ROWS:
16170
16900
  {
16171
16901
  if (src0->grad) {
16172
16902
  src0->grad =
16173
- ggml_add_impl(ctx,
16903
+ ggml_add_or_set(ctx,
16174
16904
  src0->grad,
16175
16905
  ggml_repeat(ctx,
16176
16906
  tensor->grad,
16177
16907
  src0->grad),
16178
- inplace);
16908
+ zero_table);
16179
16909
  }
16180
16910
  } break;
16181
16911
  case GGML_OP_MEAN:
@@ -16187,20 +16917,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16187
16917
  {
16188
16918
  // necessary for llama
16189
16919
  if (src0->grad) {
16190
- src0->grad = ggml_add_impl(ctx,
16920
+ src0->grad = ggml_add_or_set(ctx,
16191
16921
  src0->grad,
16192
16922
  ggml_repeat_back(ctx, tensor->grad, src0->grad),
16193
- inplace);
16923
+ zero_table);
16194
16924
  }
16195
16925
  } break;
16196
16926
  case GGML_OP_REPEAT_BACK:
16197
16927
  {
16198
16928
  if (src0->grad) {
16199
16929
  // TODO: test this
16200
- src0->grad = ggml_add_impl(ctx,
16930
+ src0->grad = ggml_add_or_set(ctx,
16201
16931
  src0->grad,
16202
16932
  ggml_repeat(ctx, tensor->grad, src0->grad),
16203
- inplace);
16933
+ zero_table);
16204
16934
  }
16205
16935
  } break;
16206
16936
  case GGML_OP_CONCAT:
@@ -16222,10 +16952,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16222
16952
  float eps;
16223
16953
  memcpy(&eps, tensor->op_params, sizeof(float));
16224
16954
 
16225
- src0->grad = ggml_add_impl(ctx,
16955
+ src0->grad = ggml_add_or_set(ctx,
16226
16956
  src0->grad,
16227
16957
  ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
16228
- inplace);
16958
+ zero_table);
16229
16959
  }
16230
16960
  } break;
16231
16961
  case GGML_OP_RMS_NORM_BACK:
@@ -16249,37 +16979,49 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16249
16979
  // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
16250
16980
  // ds1 = t.T.dot(dt)
16251
16981
 
16252
- // tensor.shape [m,p]
16253
- // src0.shape [n,m]
16254
- // src1.shape [n,p]
16982
+ // tensor.shape [m,p,qq,rr]
16983
+ // src0.shape [n,m,q1,r1]
16984
+ // src1.shape [n,p,qq,rr]
16255
16985
 
16256
16986
  // necessary for llama
16257
16987
  if (src0->grad) {
16988
+ struct ggml_tensor * s1_tg =
16989
+ ggml_out_prod(ctx, // [n,m,qq,rr]
16990
+ src1, // [n,p,qq,rr]
16991
+ tensor->grad); // [m,p,qq,rr]
16992
+ const int64_t qq = s1_tg->ne[2];
16993
+ const int64_t rr = s1_tg->ne[3];
16994
+ const int64_t q1 = src0->ne[2];
16995
+ const int64_t r1 = src0->ne[3];
16996
+ const bool ne2_broadcasted = qq > q1;
16997
+ const bool ne3_broadcasted = rr > r1;
16998
+ if (ne2_broadcasted || ne3_broadcasted) {
16999
+ // sum broadcast repetitions of s1_tg into shape of src0
17000
+ s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
17001
+ }
16258
17002
  src0->grad =
16259
- ggml_add_impl(ctx,
16260
- src0->grad,
16261
- ggml_out_prod(ctx, // [n,m]
16262
- src1, // [n,p]
16263
- tensor->grad), // [m,p]
16264
- inplace);
17003
+ ggml_add_or_set(ctx,
17004
+ src0->grad, // [n,m,q1,r1]
17005
+ s1_tg, // [n,m,q1,r1]
17006
+ zero_table);
16265
17007
  }
16266
17008
  if (src1->grad) {
16267
17009
  src1->grad =
16268
- ggml_add_impl(ctx,
16269
- src1->grad,
16270
- // ggml_mul_mat(ctx, // [n,p]
16271
- // ggml_cont(ctx, // [m,n]
16272
- // ggml_transpose(ctx, src0)), // [m,n]
16273
- // tensor->grad), // [m,p]
17010
+ ggml_add_or_set(ctx,
17011
+ src1->grad, // [n,p,qq,rr]
17012
+ // ggml_mul_mat(ctx, // [n,p,qq,rr]
17013
+ // ggml_cont(ctx, // [m,n,q1,r1]
17014
+ // ggml_transpose(ctx, src0)), // [m,n,q1,r1]
17015
+ // tensor->grad), // [m,p,qq,rr]
16274
17016
 
16275
17017
  // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
16276
17018
  // // avoid transpose of src0, rather transpose smaller tensor->grad
16277
17019
  // // and then use ggml_out_prod
16278
- ggml_out_prod(ctx, // [n,p]
16279
- src0, // [n,m]
16280
- ggml_transpose(ctx, // [p,m]
16281
- tensor->grad)), // [m,p]
16282
- inplace);
17020
+ ggml_out_prod(ctx, // [n,p,qq,rr]
17021
+ src0, // [n,m,q1,r1]
17022
+ ggml_transpose(ctx, // [p,m,qq,rr]
17023
+ tensor->grad)), // [m,p,qq,rr]
17024
+ zero_table);
16283
17025
  }
16284
17026
  } break;
16285
17027
  case GGML_OP_OUT_PROD:
@@ -16291,17 +17033,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16291
17033
  // necessary for llama
16292
17034
  if (src0->grad) {
16293
17035
  src0->grad =
16294
- ggml_add_impl(ctx,
17036
+ ggml_add_or_set(ctx,
16295
17037
  src0->grad,
16296
17038
  ggml_scale_impl(ctx, tensor->grad, src1, false),
16297
- inplace);
17039
+ zero_table);
16298
17040
  }
16299
17041
  if (src1->grad) {
16300
17042
  src1->grad =
16301
- ggml_add_impl(ctx,
17043
+ ggml_add_or_set(ctx,
16302
17044
  src1->grad,
16303
17045
  ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)),
16304
- inplace);
17046
+ zero_table);
16305
17047
  }
16306
17048
  } break;
16307
17049
  case GGML_OP_SET:
@@ -16328,23 +17070,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16328
17070
  }
16329
17071
 
16330
17072
  if (src0->grad) {
16331
- src0->grad = ggml_add_impl(ctx,
17073
+ src0->grad = ggml_add_or_set(ctx,
16332
17074
  src0->grad,
16333
17075
  ggml_acc_impl(ctx,
16334
17076
  tensor->grad,
16335
17077
  ggml_neg(ctx, tensor_grad_view),
16336
17078
  nb1, nb2, nb3, offset, false),
16337
- inplace);
17079
+ zero_table);
16338
17080
  }
16339
17081
 
16340
17082
  if (src1->grad) {
16341
17083
  src1->grad =
16342
- ggml_add_impl(ctx,
17084
+ ggml_add_or_set(ctx,
16343
17085
  src1->grad,
16344
17086
  ggml_reshape(ctx,
16345
17087
  ggml_cont(ctx, tensor_grad_view),
16346
17088
  src1->grad),
16347
- inplace);
17089
+ zero_table);
16348
17090
  }
16349
17091
  } break;
16350
17092
  case GGML_OP_CPY:
@@ -16355,7 +17097,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16355
17097
  // tensor = src0 * 1 + src1 * 0
16356
17098
  if (src0->grad) {
16357
17099
  // dsrc0 = dtensor * 1
16358
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
17100
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16359
17101
  }
16360
17102
  if (src1->grad) {
16361
17103
  // dsrc1 = dtensor * 0 -> noop
@@ -16367,7 +17109,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16367
17109
  if (src0->grad) {
16368
17110
  GGML_ASSERT(ggml_is_contiguous(src0->grad));
16369
17111
  GGML_ASSERT(ggml_is_contiguous(tensor->grad));
16370
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
17112
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16371
17113
  }
16372
17114
  } break;
16373
17115
  case GGML_OP_RESHAPE:
@@ -16375,9 +17117,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16375
17117
  // necessary for llama
16376
17118
  if (src0->grad) {
16377
17119
  src0->grad =
16378
- ggml_add_impl(ctx, src0->grad,
16379
- ggml_reshape(ctx, tensor->grad, src0->grad),
16380
- inplace);
17120
+ ggml_add_or_set(ctx, src0->grad,
17121
+ ggml_reshape(ctx,
17122
+ ggml_is_contiguous(tensor->grad)
17123
+ ? tensor->grad
17124
+ : ggml_cont(ctx, tensor->grad),
17125
+ src0->grad),
17126
+ zero_table);
16381
17127
  }
16382
17128
  } break;
16383
17129
  case GGML_OP_VIEW:
@@ -16406,7 +17152,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16406
17152
  nb3 = (nb3 / n0) * ng;
16407
17153
  }
16408
17154
 
16409
- src0->grad = ggml_acc_impl(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, inplace);
17155
+ src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table);
16410
17156
  }
16411
17157
  } break;
16412
17158
  case GGML_OP_PERMUTE:
@@ -16424,14 +17170,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16424
17170
  axes_backward[axis2] = 2;
16425
17171
  axes_backward[axis3] = 3;
16426
17172
  src0->grad =
16427
- ggml_add_impl(ctx, src0->grad,
17173
+ ggml_add_or_set(ctx, src0->grad,
16428
17174
  ggml_permute(ctx,
16429
17175
  tensor->grad,
16430
17176
  axes_backward[0],
16431
17177
  axes_backward[1],
16432
17178
  axes_backward[2],
16433
17179
  axes_backward[3]),
16434
- inplace);
17180
+ zero_table);
16435
17181
  }
16436
17182
  } break;
16437
17183
  case GGML_OP_TRANSPOSE:
@@ -16439,9 +17185,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16439
17185
  // necessary for llama
16440
17186
  if (src0->grad) {
16441
17187
  src0->grad =
16442
- ggml_add_impl(ctx, src0->grad,
17188
+ ggml_add_or_set(ctx, src0->grad,
16443
17189
  ggml_transpose(ctx, tensor->grad),
16444
- inplace);
17190
+ zero_table);
16445
17191
  }
16446
17192
  } break;
16447
17193
  case GGML_OP_GET_ROWS:
@@ -16449,9 +17195,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16449
17195
  // necessary for llama (only for tokenizer)
16450
17196
  if (src0->grad) {
16451
17197
  src0->grad =
16452
- ggml_add_impl(ctx, src0->grad,
17198
+ ggml_add_or_set(ctx, src0->grad,
17199
+ // last ggml_get_rows_back argument src0->grad is only
17200
+ // necessary to setup correct output shape
16453
17201
  ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
16454
- inplace);
17202
+ zero_table);
16455
17203
  }
16456
17204
  if (src1->grad) {
16457
17205
  // noop
@@ -16471,9 +17219,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16471
17219
  if (src0->grad) {
16472
17220
  const int n_past = ((int32_t *) tensor->op_params)[0];
16473
17221
  src0->grad =
16474
- ggml_add_impl(ctx, src0->grad,
17222
+ ggml_add_or_set(ctx, src0->grad,
16475
17223
  ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
16476
- inplace);
17224
+ zero_table);
16477
17225
  }
16478
17226
  } break;
16479
17227
  case GGML_OP_DIAG_MASK_ZERO:
@@ -16482,9 +17230,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16482
17230
  if (src0->grad) {
16483
17231
  const int n_past = ((int32_t *) tensor->op_params)[0];
16484
17232
  src0->grad =
16485
- ggml_add_impl(ctx, src0->grad,
17233
+ ggml_add_or_set(ctx, src0->grad,
16486
17234
  ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
16487
- inplace);
17235
+ zero_table);
16488
17236
  }
16489
17237
  } break;
16490
17238
  case GGML_OP_SOFT_MAX:
@@ -16492,9 +17240,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16492
17240
  // necessary for llama
16493
17241
  if (src0->grad) {
16494
17242
  src0->grad =
16495
- ggml_add_impl(ctx, src0->grad,
17243
+ ggml_add_or_set(ctx, src0->grad,
16496
17244
  ggml_soft_max_back(ctx, tensor->grad, tensor),
16497
- inplace);
17245
+ zero_table);
16498
17246
  }
16499
17247
 
16500
17248
  } break;
@@ -16506,7 +17254,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16506
17254
  {
16507
17255
  // necessary for llama
16508
17256
  if (src0->grad) {
16509
- const int n_past = ((int32_t *) tensor->op_params)[0];
17257
+ //const int n_past = ((int32_t *) tensor->op_params)[0];
16510
17258
  const int n_dims = ((int32_t *) tensor->op_params)[1];
16511
17259
  const int mode = ((int32_t *) tensor->op_params)[2];
16512
17260
  const int n_ctx = ((int32_t *) tensor->op_params)[3];
@@ -16519,11 +17267,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16519
17267
  memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
16520
17268
  memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
16521
17269
 
16522
- src0->grad = ggml_add_impl(ctx,
17270
+ src0->grad = ggml_add_or_set(ctx,
16523
17271
  src0->grad,
16524
17272
  ggml_rope_back(ctx,
16525
17273
  tensor->grad,
16526
- n_past,
17274
+ src1,
16527
17275
  n_dims,
16528
17276
  mode,
16529
17277
  n_ctx,
@@ -16531,13 +17279,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16531
17279
  freq_scale,
16532
17280
  xpos_base,
16533
17281
  xpos_down),
16534
- inplace);
17282
+ zero_table);
16535
17283
  }
16536
17284
  } break;
16537
17285
  case GGML_OP_ROPE_BACK:
16538
17286
  {
16539
17287
  if (src0->grad) {
16540
- const int n_past = ((int32_t *) tensor->op_params)[0];
17288
+ //const int n_past = ((int32_t *) tensor->op_params)[0];
16541
17289
  const int n_dims = ((int32_t *) tensor->op_params)[1];
16542
17290
  const int mode = ((int32_t *) tensor->op_params)[2];
16543
17291
  const int n_ctx = ((int32_t *) tensor->op_params)[3];
@@ -16550,11 +17298,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16550
17298
  memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
16551
17299
  memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
16552
17300
 
16553
- src0->grad = ggml_add_impl(ctx,
17301
+ src0->grad = ggml_add_or_set(ctx,
16554
17302
  src0->grad,
16555
17303
  ggml_rope_impl(ctx,
16556
17304
  tensor->grad,
16557
- n_past,
17305
+ src1,
16558
17306
  n_dims,
16559
17307
  mode,
16560
17308
  n_ctx,
@@ -16563,7 +17311,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16563
17311
  xpos_base,
16564
17312
  xpos_down,
16565
17313
  false),
16566
- inplace);
17314
+ zero_table);
16567
17315
  }
16568
17316
  } break;
16569
17317
  case GGML_OP_ALIBI:
@@ -16614,145 +17362,42 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16614
17362
  masked);
16615
17363
  }
16616
17364
 
16617
- if (src0->grad) {
16618
- struct ggml_tensor * grad_q = NULL;
16619
- const size_t nb0 = flash_grad->nb[0];
16620
- const size_t offset = 0;
16621
- switch(src0->n_dims) {
16622
- case 2:
16623
- {
16624
- grad_q = ggml_view_2d(ctx,
16625
- flash_grad,
16626
- src0->ne[0],
16627
- src0->ne[1],
16628
- nb0*src0->ne[0],
16629
- offset);
16630
- } break;
16631
- case 3:
16632
- {
16633
- grad_q = ggml_view_3d(ctx,
16634
- flash_grad,
16635
- src0->ne[0],
16636
- src0->ne[1],
16637
- src0->ne[2],
16638
- nb0*src0->ne[0],
16639
- nb0*src0->ne[0]*src0->ne[1],
16640
- offset);
16641
- } break;
16642
- case 4:
16643
- {
16644
- grad_q = ggml_view_4d(ctx,
16645
- flash_grad,
16646
- src0->ne[0],
16647
- src0->ne[1],
16648
- src0->ne[2],
16649
- src0->ne[3],
16650
- nb0*src0->ne[0],
16651
- nb0*src0->ne[0]*src0->ne[1],
16652
- nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
16653
- offset);
16654
- } break;
16655
- }
17365
+ struct ggml_tensor * src2 = tensor->src[2];
17366
+ const int64_t elem_q = ggml_nelements(src0);
17367
+ const int64_t elem_k = ggml_nelements(src1);
17368
+ const int64_t elem_v = ggml_nelements(src2);
17369
+
17370
+ enum ggml_type result_type = flash_grad->type;
17371
+ GGML_ASSERT(ggml_blck_size(result_type) == 1);
17372
+ const size_t tsize = ggml_type_size(result_type);
17373
+
17374
+ const size_t offs_q = 0;
17375
+ const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
17376
+ const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
16656
17377
 
16657
- src0->grad = ggml_add_impl(ctx,
17378
+ if (src0->grad) {
17379
+ struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q);
17380
+ struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0);
17381
+ src0->grad = ggml_add_or_set(ctx,
16658
17382
  src0->grad,
16659
17383
  grad_q,
16660
- inplace);
17384
+ zero_table);
16661
17385
  }
16662
-
16663
17386
  if (src1->grad) {
16664
- struct ggml_tensor * grad_k = NULL;
16665
- const size_t nb0 = flash_grad->nb[0];
16666
- const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
16667
- switch(src1->n_dims) {
16668
- case 2:
16669
- {
16670
- grad_k = ggml_view_2d(ctx,
16671
- flash_grad,
16672
- src1->ne[0],
16673
- src1->ne[1],
16674
- nb0*src1->ne[0],
16675
- offset);
16676
- } break;
16677
- case 3:
16678
- {
16679
- grad_k = ggml_view_3d(ctx,
16680
- flash_grad,
16681
- src1->ne[0],
16682
- src1->ne[1],
16683
- src1->ne[2],
16684
- nb0*src1->ne[0],
16685
- nb0*src1->ne[0]*src1->ne[1],
16686
- offset);
16687
- } break;
16688
- case 4:
16689
- {
16690
- grad_k = ggml_view_4d(ctx,
16691
- flash_grad,
16692
- src1->ne[0],
16693
- src1->ne[1],
16694
- src1->ne[2],
16695
- src1->ne[3],
16696
- nb0*src1->ne[0],
16697
- nb0*src1->ne[0]*src1->ne[1],
16698
- nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
16699
- offset);
16700
- } break;
16701
- }
16702
-
16703
- src1->grad = ggml_add_impl(ctx,
17387
+ struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
17388
+ struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1);
17389
+ src1->grad = ggml_add_or_set(ctx,
16704
17390
  src1->grad,
16705
17391
  grad_k,
16706
- inplace);
17392
+ zero_table);
16707
17393
  }
16708
-
16709
- struct ggml_tensor * opt0 = tensor->src[2];
16710
-
16711
- if (opt0->grad) {
16712
- struct ggml_tensor * grad_v = NULL;
16713
- const size_t nb0 = flash_grad->nb[0];
16714
- const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
16715
- + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
16716
- switch(opt0->n_dims) {
16717
- case 2:
16718
- {
16719
- grad_v = ggml_view_2d(ctx,
16720
- flash_grad,
16721
- opt0->ne[0],
16722
- opt0->ne[1],
16723
- nb0*opt0->ne[0],
16724
- offset);
16725
- } break;
16726
- case 3:
16727
- {
16728
- grad_v = ggml_view_3d(ctx,
16729
- flash_grad,
16730
- opt0->ne[0],
16731
- opt0->ne[1],
16732
- opt0->ne[2],
16733
- nb0*opt0->ne[0],
16734
- nb0*opt0->ne[0]*opt0->ne[1],
16735
- offset);
16736
- } break;
16737
- case 4:
16738
- {
16739
- grad_v = ggml_view_4d(ctx,
16740
- flash_grad,
16741
- opt0->ne[0],
16742
- opt0->ne[1],
16743
- opt0->ne[2],
16744
- opt0->ne[3],
16745
- nb0*opt0->ne[0],
16746
- nb0*opt0->ne[0]*opt0->ne[1],
16747
- nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
16748
- offset);
16749
- } break;
16750
- }
16751
-
16752
- opt0->grad = ggml_add_impl(ctx,
16753
- opt0->grad,
17394
+ if (src2->grad) {
17395
+ struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
17396
+ struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2);
17397
+ src2->grad = ggml_add_or_set(ctx,
17398
+ src2->grad,
16754
17399
  grad_v,
16755
- inplace);
17400
+ zero_table);
16756
17401
  }
16757
17402
  } break;
16758
17403
  case GGML_OP_FLASH_FF:
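Editor's note: the FLASH_ATTN backward above replaces the per-dimensionality `ggml_view_2d/3d/4d` slicing of `flash_grad` with three 1-d views at aligned offsets followed by reshapes to the shapes of q, k and v. A minimal sketch of the offset arithmetic, with simplified stand-ins for `GGML_PAD` and `GGML_MEM_ALIGN`, is shown below.

    #include <stdint.h>
    #include <stdio.h>

    /* q, k and v gradients are packed back-to-back, each padded to the alignment */
    #define MEM_ALIGN 16
    #define PAD(x, n) (((x) + (n) - 1) & ~((size_t)(n) - 1))

    int main(void) {
        const int64_t elem_q = 10, elem_k = 7, elem_v = 5;  /* example element counts */
        const size_t  tsize  = 4;                           /* sizeof(float) */

        const size_t offs_q = 0;
        const size_t offs_k = offs_q + PAD((size_t) elem_q * tsize, MEM_ALIGN);
        const size_t offs_v = offs_k + PAD((size_t) elem_k * tsize, MEM_ALIGN);
        const size_t total  = offs_v + PAD((size_t) elem_v * tsize, MEM_ALIGN);

        /* each gradient is then a 1-d view of elem_* values at its offset,
         * reshaped to the shape of q, k or v respectively */
        printf("offs_q=%zu offs_k=%zu offs_v=%zu total=%zu\n",
               offs_q, offs_k, offs_v, total);
        return 0;
    }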
@@ -16772,12 +17417,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16772
17417
  {
16773
17418
  if (src0->grad) {
16774
17419
  src0->grad =
16775
- ggml_add_impl(ctx,
17420
+ ggml_add_or_set(ctx,
16776
17421
  src0->grad,
16777
17422
  ggml_mul(ctx,
16778
17423
  ggml_sgn(ctx, src0),
16779
17424
  tensor->grad),
16780
- inplace);
17425
+ zero_table);
16781
17426
  }
16782
17427
  } break;
16783
17428
  case GGML_UNARY_OP_SGN:
@@ -16789,7 +17434,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16789
17434
  case GGML_UNARY_OP_NEG:
16790
17435
  {
16791
17436
  if (src0->grad) {
16792
- src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace);
17437
+ src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table);
16793
17438
  }
16794
17439
  } break;
16795
17440
  case GGML_UNARY_OP_STEP:
@@ -16809,12 +17454,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16809
17454
  case GGML_UNARY_OP_RELU:
16810
17455
  {
16811
17456
  if (src0->grad) {
16812
- src0->grad = ggml_add_impl(ctx,
17457
+ src0->grad = ggml_add_or_set(ctx,
16813
17458
  src0->grad,
16814
17459
  ggml_mul(ctx,
16815
17460
  ggml_step(ctx, src0),
16816
17461
  tensor->grad),
16817
- inplace);
17462
+ zero_table);
16818
17463
  }
16819
17464
  } break;
16820
17465
  case GGML_UNARY_OP_GELU:
@@ -16829,10 +17474,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16829
17474
  {
16830
17475
  // necessary for llama
16831
17476
  if (src0->grad) {
16832
- src0->grad = ggml_add_impl(ctx,
17477
+ src0->grad = ggml_add_or_set(ctx,
16833
17478
  src0->grad,
16834
17479
  ggml_silu_back(ctx, src0, tensor->grad),
16835
- inplace);
17480
+ zero_table);
16836
17481
  }
16837
17482
  } break;
16838
17483
  default:
@@ -16855,13 +17500,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16855
17500
  case GGML_OP_CROSS_ENTROPY_LOSS:
16856
17501
  {
16857
17502
  if (src0->grad) {
16858
- src0->grad = ggml_add_impl(ctx,
17503
+ src0->grad = ggml_add_or_set(ctx,
16859
17504
  src0->grad,
16860
17505
  ggml_cross_entropy_loss_back(ctx,
16861
17506
  src0,
16862
17507
  src1,
16863
17508
  tensor->grad),
16864
- inplace);
17509
+ zero_table);
16865
17510
  }
16866
17511
  } break;
16867
17512
  case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
@@ -16877,34 +17522,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16877
17522
  GGML_ASSERT(false);
16878
17523
  } break;
16879
17524
  }
16880
- }
16881
-
16882
- static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
16883
-
16884
- static size_t hash(void * p) {
16885
- return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
16886
- }
16887
17525
 
16888
- static bool hash_insert(void * hash_table[], void * p) {
16889
- size_t h = hash(p);
16890
-
16891
- // linear probing
16892
- size_t i = h;
16893
- while (hash_table[i] != NULL && hash_table[i] != p) {
16894
- i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
16895
- if (i == h) {
16896
- // hash table is full
16897
- GGML_ASSERT(false);
17526
+ for (int i = 0; i < GGML_MAX_SRC; ++i) {
17527
+ if (tensor->src[i] && tensor->src[i]->grad) {
17528
+ GGML_ASSERT(ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad));
16898
17529
  }
16899
17530
  }
16900
-
16901
- if (hash_table[i] == p) {
16902
- return true;
16903
- }
16904
-
16905
- // insert
16906
- hash_table[i] = p;
16907
- return false;
16908
17531
  }
16909
17532
 
16910
17533
  static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
@@ -16922,8 +17545,12 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
16922
17545
  }
16923
17546
 
16924
17547
  for (int i = 0; i < GGML_MAX_SRC; ++i) {
16925
- if (node->src[i]) {
16926
- ggml_visit_parents(cgraph, node->src[i]);
17548
+ const int k =
17549
+ (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
17550
+ (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
17551
+ /* unknown order, just fall back to using i*/ i;
17552
+ if (node->src[k]) {
17553
+ ggml_visit_parents(cgraph, node->src[k]);
16927
17554
  }
16928
17555
  }
16929
17556
 
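Editor's note: `ggml_visit_parents` now walks a node's sources in an order controlled by the graph's new `order` field, which lets the allocator see the sources in either direction. The standalone sketch below mirrors only the index selection; the enum and `MAX_SRC` value are simplified stand-ins for the package's `GGML_CGRAPH_EVAL_ORDER_*` and `GGML_MAX_SRC`.

    #include <stdio.h>

    enum eval_order { LEFT_TO_RIGHT, RIGHT_TO_LEFT };
    #define MAX_SRC 6

    int main(void) {
        enum eval_order order = RIGHT_TO_LEFT;
        for (int i = 0; i < MAX_SRC; ++i) {
            const int k = (order == LEFT_TO_RIGHT) ? i
                        : (order == RIGHT_TO_LEFT) ? (MAX_SRC - 1 - i)
                        : i;  /* unknown order: fall back to i */
            printf("%d ", k);  /* prints 5 4 3 2 1 0 */
        }
        printf("\n");
        return 0;
    }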
@@ -16982,6 +17609,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
16982
17609
  /*.grads =*/ { NULL },
16983
17610
  /*.leafs =*/ { NULL },
16984
17611
  /*.hash_table =*/ { NULL },
17612
+ /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
16985
17613
  /*.perf_runs =*/ 0,
16986
17614
  /*.perf_cycles =*/ 0,
16987
17615
  /*.perf_time_us =*/ 0,
@@ -17007,12 +17635,22 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
17007
17635
  }
17008
17636
  }
17009
17637
 
17638
+ // remember original gradients which start with zero values
17639
+ void ** zero_table = malloc(sizeof(void *) * GGML_GRAPH_HASHTABLE_SIZE);
17640
+ memset(zero_table, 0, sizeof(void*) * GGML_GRAPH_HASHTABLE_SIZE);
17641
+ for (int i = 0; i < gf->n_nodes; i++) {
17642
+ if (gf->grads[i]) {
17643
+ hash_insert(zero_table, gf->grads[i]);
17644
+ }
17645
+ }
17646
+
17010
17647
  for (int i = gf->n_nodes - 1; i >= 0; i--) {
17011
17648
  struct ggml_tensor * node = gf->nodes[i];
17012
17649
 
17013
- // because we detached the grad nodes from the original graph, we can afford inplace operations
17650
+ // inplace operations to add gradients are not created by ggml_compute_backward
17651
+ // use allocator to automatically make inplace operations
17014
17652
  if (node->grad) {
17015
- ggml_compute_backward(ctx, node, keep);
17653
+ ggml_compute_backward(ctx, node, zero_table);
17016
17654
  }
17017
17655
  }
17018
17656
 
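Editor's note: `ggml_build_backward_expand` now builds a `zero_table` hash set (pointer-keyed, open addressing with linear probing, as in the `hash_insert` shown removed from its old location earlier in this diff) containing the freshly created zero-valued gradient tensors, and passes it to `ggml_compute_backward` instead of the old `keep`/`inplace` flag. A minimal standalone sketch of that hash set follows; `TABLE_SIZE` here is an arbitrary example, not the package's `GGML_GRAPH_HASHTABLE_SIZE`.

    #include <stdbool.h>
    #include <stddef.h>
    #include <stdio.h>
    #include <stdlib.h>
    #include <string.h>

    #define TABLE_SIZE 257

    static size_t hash_ptr(void *p) { return (size_t)p % TABLE_SIZE; }

    /* returns true if p was already present, false if it was inserted */
    static bool hash_insert(void *table[], void *p) {
        size_t h = hash_ptr(p);
        size_t i = h;
        while (table[i] != NULL && table[i] != p) {
            i = (i + 1) % TABLE_SIZE;          /* linear probing */
            if (i == h) { fprintf(stderr, "table full\n"); exit(1); }
        }
        if (table[i] == p) return true;
        table[i] = p;
        return false;
    }

    int main(void) {
        void *zero_table[TABLE_SIZE];
        memset(zero_table, 0, sizeof(zero_table));

        int a, b;  /* stand-ins for gradient tensor pointers */
        printf("%d\n", hash_insert(zero_table, &a));  /* 0: inserted */
        printf("%d\n", hash_insert(zero_table, &b));  /* 0: inserted */
        printf("%d\n", hash_insert(zero_table, &a));  /* 1: already present */
        return 0;
    }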
@@ -17024,6 +17662,8 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
17024
17662
  ggml_build_forward_expand(gb, node->grad);
17025
17663
  }
17026
17664
  }
17665
+
17666
+ free(zero_table);
17027
17667
  }
17028
17668
 
17029
17669
  struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
@@ -17043,6 +17683,7 @@ struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
17043
17683
  /*.grads =*/ { NULL },
17044
17684
  /*.leafs =*/ { NULL },
17045
17685
  /*.hash_table =*/ { NULL },
17686
+ /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
17046
17687
  /*.perf_runs =*/ 0,
17047
17688
  /*.perf_cycles =*/ 0,
17048
17689
  /*.perf_time_us =*/ 0,
@@ -17433,7 +18074,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
17433
18074
  } break;
17434
18075
  case GGML_OP_CONCAT:
17435
18076
  case GGML_OP_MUL_MAT:
17436
- case GGML_OP_OUT_PROD:
17437
18077
  {
17438
18078
  n_tasks = n_threads;
17439
18079
 
@@ -17475,6 +18115,18 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
17475
18115
  cur = 0;
17476
18116
  }
17477
18117
 
18118
+ work_size = MAX(work_size, cur);
18119
+ } break;
18120
+ case GGML_OP_OUT_PROD:
18121
+ {
18122
+ n_tasks = n_threads;
18123
+
18124
+ size_t cur = 0;
18125
+
18126
+ if (ggml_is_quantized(node->src[0]->type)) {
18127
+ cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
18128
+ }
18129
+
17478
18130
  work_size = MAX(work_size, cur);
17479
18131
  } break;
17480
18132
  case GGML_OP_SCALE:
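Editor's note: `GGML_OP_OUT_PROD` gets its own work-size case in `ggml_graph_plan`: when `src0` is quantized, each thread needs an F32 scratch row to dequantize into. A small arithmetic sketch of that sizing, with example values, is below.

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const size_t  f32_size = sizeof(float);  /* ggml_type_size(GGML_TYPE_F32) */
        const int64_t ne0      = 4096;           /* example row length of src0 */
        const int     n_tasks  = 8;              /* one scratch row per thread */

        const size_t cur = f32_size * (size_t) ne0 * (size_t) n_tasks;
        printf("work buffer: %zu bytes\n", cur);  /* 131072 bytes for this example */
        return 0;
    }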
@@ -18568,7 +19220,7 @@ static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float *
18568
19220
  }
18569
19221
 
18570
19222
  static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
18571
- int i = 0;
19223
+ int64_t i = 0;
18572
19224
  for (int p = 0; p < np; ++p) {
18573
19225
  const int64_t ne = ggml_nelements(ps[p]) ;
18574
19226
  // TODO: add function to get all elements at once
@@ -18578,6 +19230,17 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
18578
19230
  }
18579
19231
  }
18580
19232
 
19233
+ static void ggml_opt_acc_grad(int np, struct ggml_tensor * const ps[], float * g, float scale) {
19234
+ int64_t i = 0;
19235
+ for (int p = 0; p < np; ++p) {
19236
+ const int64_t ne = ggml_nelements(ps[p]) ;
19237
+ // TODO: add function to get all elements at once
19238
+ for (int64_t j = 0; j < ne; ++j) {
19239
+ g[i++] += ggml_get_f32_1d(ps[p]->grad, j) * scale;
19240
+ }
19241
+ }
19242
+ }
19243
+
18581
19244
  //
18582
19245
  // ADAM
18583
19246
  //
@@ -18626,26 +19289,43 @@ static enum ggml_opt_result ggml_opt_adam(
18626
19289
  const float eps = params.adam.eps;
18627
19290
  const float gclip = params.adam.gclip;
18628
19291
  const int decay_min_ndim = params.adam.decay_min_ndim;
19292
+ const int n_accum = MAX(1, params.n_gradient_accumulation);
19293
+ const float accum_norm = 1.0f / (float) n_accum;
18629
19294
 
19295
+ float * g = opt->adam.g->data; // gradients
18630
19296
  float * m = opt->adam.m->data; // first moment
18631
19297
  float * v = opt->adam.v->data; // second moment
18632
19298
 
18633
19299
  float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
18634
19300
 
18635
- if (callback) {
18636
- callback(callback_data, &sched);
18637
- }
18638
-
18639
- // compute the function value
18640
- ggml_graph_reset (gf);
18641
- ggml_set_f32 (f->grad, 1.0f);
18642
-
18643
19301
  struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads);
18644
19302
  struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size);
18645
19303
  cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
18646
- ggml_graph_compute(gb, &cplan);
18647
19304
 
18648
- opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
19305
+ bool cancel = false;
19306
+
19307
+ // compute the function value
19308
+ float fx = 0;
19309
+ ggml_set_zero(opt->adam.g);
19310
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
19311
+ if (callback) {
19312
+ callback(callback_data, accum_step, &sched, &cancel);
19313
+ if (cancel) {
19314
+ break;
19315
+ }
19316
+ }
19317
+ // ggml_graph_reset (gf);
19318
+ ggml_set_f32 (f->grad, 1.0f);
19319
+ ggml_graph_compute(gb, &cplan);
19320
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
19321
+ fx += ggml_get_f32_1d(f, 0);
19322
+ }
19323
+ if (cancel) {
19324
+ return GGML_OPT_DID_NOT_CONVERGE;
19325
+ }
19326
+ fx *= accum_norm;
19327
+
19328
+ opt->adam.fx_prev = fx;
18649
19329
  opt->adam.fx_best = opt->adam.fx_prev;
18650
19330
  if (pf) {
18651
19331
  pf[opt->iter % params.past] = opt->adam.fx_prev;
@@ -18668,6 +19348,9 @@ static enum ggml_opt_result ggml_opt_adam(
18668
19348
 
18669
19349
  // run the optimizer
18670
19350
  for (int t = 0; t < params.adam.n_iter; ++t) {
19351
+ if (cancel) {
19352
+ break;
19353
+ }
18671
19354
  opt->iter = iter0 + t + 1;
18672
19355
  GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
18673
19356
 
@@ -18690,12 +19373,8 @@ static enum ggml_opt_result ggml_opt_adam(
18690
19373
  if (gclip > 0.0f) {
18691
19374
  // gradient clipping
18692
19375
  ggml_float sum = 0.0;
18693
- for (int p = 0; p < np; ++p) {
18694
- const int64_t ne = ggml_nelements(ps[p]);
18695
- for (int64_t j = 0; j < ne; ++j) {
18696
- float g = ggml_get_f32_1d(ps[p]->grad, j);
18697
- sum += (ggml_float)(g*g);
18698
- }
19376
+ for (int64_t i = 0; i < nx; ++i) {
19377
+ sum += (ggml_float)(g[i]*g[i]);
18699
19378
  }
18700
19379
  ggml_float norm = sqrt(sum);
18701
19380
  if (norm > (ggml_float) gclip) {
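Editor's note: gradient clipping in the Adam path now computes the squared norm over the flat accumulated gradient array `g` instead of iterating per-parameter tensors. The standalone sketch below illustrates the global-norm clipping pattern; the rescale by `gclip/norm` is the usual follow-up and is assumed here, since the hunk is cut off right after the norm test.

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        float       g[4]  = {3.0f, 4.0f, 0.0f, 0.0f};  /* example gradient, ||g|| = 5 */
        const int   nx    = 4;
        const float gclip = 2.5f;

        double sum = 0.0;
        for (int i = 0; i < nx; ++i) {
            sum += (double) g[i] * (double) g[i];
        }
        const double norm  = sqrt(sum);
        const float  gnorm = (norm > gclip) ? (float) (gclip / norm) : 1.0f;

        for (int i = 0; i < nx; ++i) {
            g[i] *= gnorm;  /* in the diff this factor is folded into g_ = g[i]*gnorm */
        }
        printf("norm=%g gnorm=%g -> g = {%g, %g, %g, %g}\n",
               norm, gnorm, g[0], g[1], g[2], g[3]);
        return 0;
    }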
@@ -18709,10 +19388,10 @@ static enum ggml_opt_result ggml_opt_adam(
18709
19388
  const int64_t ne = ggml_nelements(ps[p]);
18710
19389
  const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0f) * sched;
18711
19390
  for (int64_t j = 0; j < ne; ++j) {
18712
- float x = ggml_get_f32_1d(ps[p], j);
18713
- float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm;
18714
- m[i] = m[i]*beta1 + g*(1.0f - beta1);
18715
- v[i] = v[i]*beta2 + g*g*(1.0f - beta2);
19391
+ float x = ggml_get_f32_1d(ps[p], j);
19392
+ float g_ = g[i]*gnorm;
19393
+ m[i] = m[i]*beta1 + g_*(1.0f - beta1);
19394
+ v[i] = v[i]*beta2 + g_*g_*(1.0f - beta2);
18716
19395
  float mh = m[i]*beta1h;
18717
19396
  float vh = v[i]*beta2h;
18718
19397
  vh = sqrtf(vh) + eps;
@@ -18723,16 +19402,26 @@ static enum ggml_opt_result ggml_opt_adam(
18723
19402
  }
18724
19403
  }
18725
19404
 
18726
- if (callback) {
18727
- callback(callback_data, &sched);
19405
+ fx = 0;
19406
+ ggml_set_zero(opt->adam.g);
19407
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
19408
+ if (callback) {
19409
+ callback(callback_data, accum_step, &sched, &cancel);
19410
+ if (cancel) {
19411
+ break;
19412
+ }
19413
+ }
19414
+ // ggml_graph_reset (gf);
19415
+ ggml_set_f32 (f->grad, 1.0f);
19416
+ ggml_graph_compute(gb, &cplan);
19417
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
19418
+ fx += ggml_get_f32_1d(f, 0);
18728
19419
  }
19420
+ if (cancel) {
19421
+ break;
19422
+ }
19423
+ fx *= accum_norm;
18729
19424
 
18730
- ggml_graph_reset (gf);
18731
- ggml_set_f32 (f->grad, 1.0f);
18732
-
18733
- ggml_graph_compute(gb, &cplan);
18734
-
18735
- const float fx = ggml_get_f32_1d(f, 0);
18736
19425
  opt->loss_after = fx;
18737
19426
 
18738
19427
 
@@ -18812,11 +19501,11 @@ static enum ggml_opt_result linesearch_backtracking(
18812
19501
  float * step,
18813
19502
  const float * xp,
18814
19503
  struct ggml_tensor * f,
18815
- struct ggml_cgraph * gf,
18816
19504
  struct ggml_cgraph * gb,
18817
19505
  struct ggml_cplan * cplan,
18818
19506
  const int np,
18819
19507
  struct ggml_tensor * ps[],
19508
+ bool * cancel,
18820
19509
  ggml_opt_callback callback,
18821
19510
  void * callback_data) {
18822
19511
  int count = 0;
@@ -18830,6 +19519,9 @@ static enum ggml_opt_result linesearch_backtracking(
18830
19519
  const float dec = 0.5f;
18831
19520
  const float inc = 2.1f;
18832
19521
 
19522
+ const int n_accum = MAX(1, params->n_gradient_accumulation);
19523
+ const float accum_norm = 1.0f / (float) n_accum;
19524
+
18833
19525
  if (*step <= 0.f) {
18834
19526
  return GGML_LINESEARCH_INVALID_PARAMETERS;
18835
19527
  }
@@ -18846,13 +19538,7 @@ static enum ggml_opt_result linesearch_backtracking(
18846
19538
  finit = *fx;
18847
19539
  dgtest = params->lbfgs.ftol*dginit;
18848
19540
 
18849
- while (true) {
18850
- if (callback) {
18851
- // LBFG-S does not support learning rate -> ignore learning schedule
18852
- float sched = 0;
18853
- callback(callback_data, &sched);
18854
- }
18855
-
19541
+ while (!*cancel) {
18856
19542
  ggml_vec_cpy_f32(nx, x, xp);
18857
19543
  ggml_vec_mad_f32(nx, x, d, *step);
18858
19544
 
@@ -18860,14 +19546,28 @@ static enum ggml_opt_result linesearch_backtracking(
18860
19546
  {
18861
19547
  ggml_opt_set_params(np, ps, x);
18862
19548
 
18863
- ggml_graph_reset (gf);
18864
- ggml_set_f32 (f->grad, 1.0f);
18865
-
18866
- ggml_graph_compute(gb, cplan);
18867
-
18868
- ggml_opt_get_grad(np, ps, g);
19549
+ *fx = 0;
19550
+ memset(g, 0, sizeof(float)*nx);
19551
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
19552
+ if (callback) {
19553
+ // LBFG-S does not support learning rate -> ignore learning schedule
19554
+ float sched = 0;
19555
+ callback(callback_data, accum_step, &sched, cancel);
19556
+ if (*cancel) {
19557
+ break;
19558
+ }
19559
+ }
19560
+ // ggml_graph_reset (gf);
19561
+ ggml_set_f32 (f->grad, 1.0f);
19562
+ ggml_graph_compute(gb, cplan);
19563
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
19564
+ *fx += ggml_get_f32_1d(f, 0);
19565
+ }
19566
+ if (*cancel) {
19567
+ break;
19568
+ }
19569
+ *fx *= accum_norm;
18869
19570
 
18870
- *fx = ggml_get_f32_1d(f, 0);
18871
19571
  }
18872
19572
 
18873
19573
  ++count;
@@ -18913,7 +19613,7 @@ static enum ggml_opt_result linesearch_backtracking(
18913
19613
  (*step) *= width;
18914
19614
  }
18915
19615
 
18916
- return GGML_LINESEARCH_FAIL;
19616
+ GGML_UNREACHABLE();
18917
19617
  }
18918
19618
 
18919
19619
  static enum ggml_opt_result ggml_opt_lbfgs(
@@ -18968,6 +19668,9 @@ static enum ggml_opt_result ggml_opt_lbfgs(
18968
19668
 
18969
19669
  float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
18970
19670
 
19671
+ const int n_accum = MAX(1, params.n_gradient_accumulation);
19672
+ const float accum_norm = 1.0f / (float) n_accum;
19673
+
18971
19674
  float fx = 0.0f; // cost function value
18972
19675
  float xnorm = 0.0f; // ||x||
18973
19676
  float gnorm = 0.0f; // ||g||
@@ -18981,24 +19684,33 @@ static enum ggml_opt_result ggml_opt_lbfgs(
18981
19684
  float * lm_s = opt->lbfgs.lms->data;
18982
19685
  float * lm_y = opt->lbfgs.lmy->data;
18983
19686
 
18984
- if (callback) {
18985
- // LBFG-S does not support learning rate -> ignore learning schedule
18986
- float sched = 0;
18987
- callback(callback_data, &sched);
18988
- }
19687
+ bool cancel = false;
18989
19688
 
18990
19689
  // evaluate the function value and its gradient
18991
19690
  {
18992
19691
  ggml_opt_set_params(np, ps, x);
18993
19692
 
18994
- ggml_graph_reset (gf);
18995
- ggml_set_f32 (f->grad, 1.0f);
18996
-
18997
- ggml_graph_compute(gb, &cplan);
18998
-
18999
- ggml_opt_get_grad(np, ps, g);
19000
-
19001
- fx = ggml_get_f32_1d(f, 0);
19693
+ fx = 0;
19694
+ memset(g, 0, sizeof(float)*nx);
19695
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
19696
+ if (callback) {
19697
+ // LBFG-S does not support learning rate -> ignore learning schedule
19698
+ float sched = 0;
19699
+ callback(callback_data, accum_step, &sched, &cancel);
19700
+ if (cancel) {
19701
+ break;
19702
+ }
19703
+ }
19704
+ // ggml_graph_reset (gf);
19705
+ ggml_set_f32 (f->grad, 1.0f);
19706
+ ggml_graph_compute(gb, &cplan);
19707
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
19708
+ fx += ggml_get_f32_1d(f, 0);
19709
+ }
19710
+ if (cancel) {
19711
+ return GGML_OPT_DID_NOT_CONVERGE;
19712
+ }
19713
+ fx *= accum_norm;
19002
19714
 
19003
19715
  opt->loss_before = fx;
19004
19716
  opt->loss_after = fx;
@@ -19056,7 +19768,10 @@ static enum ggml_opt_result ggml_opt_lbfgs(
19056
19768
  ggml_vec_cpy_f32(nx, xp, x);
19057
19769
  ggml_vec_cpy_f32(nx, gp, g);
19058
19770
 
19059
- ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gf, gb, &cplan, np, ps, callback, callback_data);
19771
+ ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data);
19772
+ if (!cancel) {
19773
+ break;
19774
+ }
19060
19775
 
19061
19776
  if (ls < 0) {
19062
19777
  // linesearch failed - go back to the previous point and return
@@ -19165,7 +19880,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
19165
19880
  step[0] = 1.0;
19166
19881
  }
19167
19882
 
19168
- return GGML_OPT_DID_NOT_CONVERGE;
19883
+ GGML_UNREACHABLE();
19169
19884
  }
19170
19885
 
19171
19886
  struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
@@ -19185,6 +19900,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
19185
19900
  .print_forward_graph = true,
19186
19901
  .print_backward_graph = true,
19187
19902
 
19903
+ .n_gradient_accumulation = 1,
19904
+
19188
19905
  .adam = {
19189
19906
  .n_iter = 10000,
19190
19907
  .sched = 1.000f,
@@ -19213,6 +19930,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
19213
19930
  .print_forward_graph = true,
19214
19931
  .print_backward_graph = true,
19215
19932
 
19933
+ .n_gradient_accumulation = 1,
19934
+
19216
19935
  .lbfgs = {
19217
19936
  .m = 6,
19218
19937
  .n_iter = 100,
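Editor's note: both default parameter sets now carry `n_gradient_accumulation = 1`. A hedged usage sketch follows (it assumes this package's `ggml.h` and is not taken from the diff), showing how a caller would opt into accumulation.

    #include "ggml.h"

    int main(void) {
        /* n_gradient_accumulation defaults to 1; raising it averages gradients
         * over that many accumulation steps per optimizer iteration */
        struct ggml_opt_params params = ggml_opt_default_params(GGML_OPT_ADAM);
        params.n_gradient_accumulation = 4;
        (void) params;  /* pass to ggml_opt / ggml_opt_resume as usual */
        return 0;
    }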
@@ -19243,13 +19962,32 @@ GGML_API void ggml_opt_init(
19243
19962
  opt->iter = 0;
19244
19963
  opt->nx = nx;
19245
19964
  opt->just_initialized = true;
19965
+ if (opt->ctx == NULL) {
19966
+ struct ggml_init_params ctx_opt_params;
19967
+ if (opt->params.type == GGML_OPT_ADAM) {
19968
+ ctx_opt_params.mem_size = GGML_MEM_ALIGN*3 + ggml_tensor_overhead()*3 + ggml_type_size(GGML_TYPE_F32)*nx*3;
19969
+ if (opt->params.past > 0) {
19970
+ ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
19971
+ }
19972
+ } else if (opt->params.type == GGML_OPT_LBFGS) {
19973
+ ctx_opt_params.mem_size = GGML_MEM_ALIGN*9 + ggml_tensor_overhead()*9 + ggml_type_size(GGML_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2);
19974
+ if (opt->params.past > 0) {
19975
+ ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
19976
+ }
19977
+ }
19978
+ ctx_opt_params.mem_buffer = NULL;
19979
+ ctx_opt_params.no_alloc = false;
19980
+
19981
+ opt->ctx = ggml_init(ctx_opt_params);
19982
+ }
19246
19983
  switch (opt->params.type) {
19247
19984
  case GGML_OPT_ADAM:
19248
19985
  {
19249
- opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19250
- opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19986
+ opt->adam.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
19987
+ opt->adam.m = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
19988
+ opt->adam.v = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
19251
19989
  opt->adam.pf = params.past > 0
19252
- ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
19990
+ ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
19253
19991
  : NULL;
19254
19992
  ggml_set_zero(opt->adam.m);
19255
19993
  ggml_set_zero(opt->adam.v);
@@ -19259,18 +19997,18 @@ GGML_API void ggml_opt_init(
19259
19997
  } break;
19260
19998
  case GGML_OPT_LBFGS:
19261
19999
  {
19262
- opt->lbfgs.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19263
- opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19264
- opt->lbfgs.g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19265
- opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19266
- opt->lbfgs.d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
20000
+ opt->lbfgs.x = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
20001
+ opt->lbfgs.xp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
20002
+ opt->lbfgs.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
20003
+ opt->lbfgs.gp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
20004
+ opt->lbfgs.d = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
19267
20005
  opt->lbfgs.pf = params.past > 0
19268
- ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
20006
+ ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
19269
20007
  : NULL;
19270
- opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
19271
- opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
19272
- opt->lbfgs.lms = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
19273
- opt->lbfgs.lmy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
20008
+ opt->lbfgs.lmal = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
20009
+ opt->lbfgs.lmys = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
20010
+ opt->lbfgs.lms = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
20011
+ opt->lbfgs.lmy = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
19274
20012
  ggml_set_zero(opt->lbfgs.x);
19275
20013
  ggml_set_zero(opt->lbfgs.xp);
19276
20014
  ggml_set_zero(opt->lbfgs.g);
@@ -19876,10 +20614,10 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
19876
20614
  } break;
19877
20615
  case GGUF_TYPE_ARRAY:
19878
20616
  case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break;
19879
- };
20617
+ }
19880
20618
  } break;
19881
20619
  case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type");
19882
- };
20620
+ }
19883
20621
 
19884
20622
  if (!ok) {
19885
20623
  break;
@@ -20155,78 +20893,94 @@ int gguf_find_key(const struct gguf_context * ctx, const char * key) {
20155
20893
  return keyfound;
20156
20894
  }
20157
20895
 
20158
- const char * gguf_get_key(const struct gguf_context * ctx, int i) {
20159
- return ctx->kv[i].key.data;
20896
+ const char * gguf_get_key(const struct gguf_context * ctx, int key_id) {
20897
+ return ctx->kv[key_id].key.data;
20160
20898
  }
20161
20899
 
20162
- enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int i) {
20163
- return ctx->kv[i].type;
20900
+ enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) {
20901
+ return ctx->kv[key_id].type;
20164
20902
  }
20165
20903
 
20166
- enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i) {
20167
- return ctx->kv[i].value.arr.type;
20904
+ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) {
20905
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
20906
+ return ctx->kv[key_id].value.arr.type;
20168
20907
  }
20169
20908
 
20170
- const void * gguf_get_arr_data(const struct gguf_context * ctx, int i) {
20171
- return ctx->kv[i].value.arr.data;
20909
+ const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) {
20910
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
20911
+ return ctx->kv[key_id].value.arr.data;
20172
20912
  }
20173
20913
 
20174
20914
  const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
20915
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
20175
20916
  struct gguf_kv * kv = &ctx->kv[key_id];
20176
20917
  struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
20177
20918
  return str->data;
20178
20919
  }
20179
20920
 
20180
- int gguf_get_arr_n(const struct gguf_context * ctx, int i) {
20181
- return ctx->kv[i].value.arr.n;
20921
+ int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
20922
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
20923
+ return ctx->kv[key_id].value.arr.n;
20182
20924
  }
20183
20925
 
20184
- uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int i) {
20185
- return ctx->kv[i].value.uint8;
20926
+ uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) {
20927
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8);
20928
+ return ctx->kv[key_id].value.uint8;
20186
20929
  }
20187
20930
 
20188
- int8_t gguf_get_val_i8(const struct gguf_context * ctx, int i) {
20189
- return ctx->kv[i].value.int8;
20931
+ int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) {
20932
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8);
20933
+ return ctx->kv[key_id].value.int8;
20190
20934
  }
20191
20935
 
20192
- uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int i) {
20193
- return ctx->kv[i].value.uint16;
20936
+ uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) {
20937
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16);
20938
+ return ctx->kv[key_id].value.uint16;
20194
20939
  }
20195
20940
 
20196
- int16_t gguf_get_val_i16(const struct gguf_context * ctx, int i) {
20197
- return ctx->kv[i].value.int16;
20941
+ int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) {
20942
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16);
20943
+ return ctx->kv[key_id].value.int16;
20198
20944
  }
20199
20945
 
20200
- uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int i) {
20201
- return ctx->kv[i].value.uint32;
20946
+ uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) {
20947
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32);
20948
+ return ctx->kv[key_id].value.uint32;
20202
20949
  }
20203
20950
 
20204
- int32_t gguf_get_val_i32(const struct gguf_context * ctx, int i) {
20205
- return ctx->kv[i].value.int32;
20951
+ int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) {
20952
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32);
20953
+ return ctx->kv[key_id].value.int32;
20206
20954
  }
20207
20955
 
20208
- float gguf_get_val_f32(const struct gguf_context * ctx, int i) {
20209
- return ctx->kv[i].value.float32;
20956
+ float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) {
20957
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32);
20958
+ return ctx->kv[key_id].value.float32;
20210
20959
  }
20211
20960
 
20212
- uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int i) {
20213
- return ctx->kv[i].value.uint64;
20961
+ uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) {
20962
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64);
20963
+ return ctx->kv[key_id].value.uint64;
20214
20964
  }
20215
20965
 
20216
- int64_t gguf_get_val_i64(const struct gguf_context * ctx, int i) {
20217
- return ctx->kv[i].value.int64;
20966
+ int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) {
20967
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64);
20968
+ return ctx->kv[key_id].value.int64;
20218
20969
  }
20219
20970
 
20220
- double gguf_get_val_f64(const struct gguf_context * ctx, int i) {
20221
- return ctx->kv[i].value.float64;
20971
+ double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) {
20972
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64);
20973
+ return ctx->kv[key_id].value.float64;
20222
20974
  }
20223
20975
 
20224
- bool gguf_get_val_bool(const struct gguf_context * ctx, int i) {
20225
- return ctx->kv[i].value.bool_;
20976
+ bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) {
20977
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL);
20978
+ return ctx->kv[key_id].value.bool_;
20226
20979
  }
20227
20980
 
20228
- const char * gguf_get_val_str (const struct gguf_context * ctx, int i) {
20229
- return ctx->kv[i].value.str.data;
20981
+ const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
20982
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
20983
+ return ctx->kv[key_id].value.str.data;
20230
20984
  }
20231
20985
 
20232
20986
  int gguf_get_n_tensors(const struct gguf_context * ctx) {
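Editor's note: the GGUF value getters are renamed to take `key_id` and now assert the stored kv type, turning silently mis-typed reads into hard failures. The hedged usage sketch below (assuming this package's `ggml.h`, with a placeholder file name and key) shows the pattern of checking the type before calling a typed getter.

    #include "ggml.h"
    #include <stdio.h>

    int main(void) {
        struct gguf_init_params params = {
            /*.no_alloc =*/ false,
            /*.ctx      =*/ NULL,
        };
        struct gguf_context * ctx = gguf_init_from_file("model.gguf", params);  /* placeholder path */
        if (ctx == NULL) {
            return 1;
        }

        /* check the stored type first: the typed getters now assert it */
        const int key_id = gguf_find_key(ctx, "general.alignment");
        if (key_id >= 0 && gguf_get_kv_type(ctx, key_id) == GGUF_TYPE_UINT32) {
            printf("alignment: %u\n", gguf_get_val_u32(ctx, key_id));
        }

        gguf_free(ctx);
        return 0;
    }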
@@ -20591,10 +21345,10 @@ static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf *
20591
21345
  } break;
20592
21346
  case GGUF_TYPE_ARRAY:
20593
21347
  case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break;
20594
- };
21348
+ }
20595
21349
  } break;
20596
21350
  case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type");
20597
- };
21351
+ }
20598
21352
  }
20599
21353
 
20600
21354
  // write tensor infos