llama_cpp 0.5.2 → 0.6.0

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
@@ -89,7 +89,9 @@ static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(vo
 
  static int pthread_join(pthread_t thread, void * unused) {
  (void) unused;
- return (int) WaitForSingleObject(thread, INFINITE);
+ int ret = (int) WaitForSingleObject(thread, INFINITE);
+ CloseHandle(thread);
+ return ret;
  }
 
  static int sched_yield (void) {
@@ -134,6 +136,7 @@ typedef void * thread_ret_t;
 
  #define GGML_SOFT_MAX_UNROLL 4
  #define GGML_VEC_DOT_UNROLL 2
+ #define GGML_VEC_MAD_UNROLL 32
 
  //
  // logging
@@ -242,18 +245,18 @@ inline static void * ggml_aligned_malloc(size_t size) {
  //
 
  #define GGML_TENSOR_UNARY_OP_LOCALS \
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \
- GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); \
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); \
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
 
  #define GGML_TENSOR_BINARY_OP_LOCALS \
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \
- GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); \
- GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); \
- GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); \
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); \
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+ GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
 
  #if defined(GGML_USE_ACCELERATE)
  #include <Accelerate/Accelerate.h>
@@ -1863,7 +1866,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
  #define GGML_F16x8_ADD vaddq_f16
  #define GGML_F16x8_MUL vmulq_f16
  #define GGML_F16x8_REDUCE(res, x) \
- { \
+ do { \
  int offset = GGML_F16_ARR >> 1; \
  for (int i = 0; i < offset; ++i) { \
  x[i] = vaddq_f16(x[i], x[offset+i]); \
@@ -1879,7 +1882,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
  const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
  const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
  res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \
- }
+ } while (0)
 
  #define GGML_F16_VEC GGML_F16x8
  #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
@@ -1940,7 +1943,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
  #define GGML_F32x8_ADD _mm256_add_ps
  #define GGML_F32x8_MUL _mm256_mul_ps
  #define GGML_F32x8_REDUCE(res, x) \
- { \
+ do { \
  int offset = GGML_F32_ARR >> 1; \
  for (int i = 0; i < offset; ++i) { \
  x[i] = _mm256_add_ps(x[i], x[offset+i]); \
@@ -1957,7 +1960,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
  _mm256_extractf128_ps(x[0], 1)); \
  const __m128 t1 = _mm_hadd_ps(t0, t0); \
  res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \
- }
+ } while (0)
  // TODO: is this optimal ?
 
  #define GGML_F32_VEC GGML_F32x8
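The GGML_F16x8_REDUCE and GGML_F32x8_REDUCE hunks above replace a bare { ... } macro body with do { ... } while (0). A minimal illustration of why that matters (the RESET macros below are made up for this sketch, not part of the diff):

    #define BAD_RESET(x)  { (x) = 0; }
    #define GOOD_RESET(x) do { (x) = 0; } while (0)

    static void reset_example(int flag, int v) {
        // if (flag) BAD_RESET(v); else v = 1;  // does not compile: the ';' after the
        //                                      // block leaves the 'else' without an 'if'
        if (flag) GOOD_RESET(v); else v = 1;    // the do/while form absorbs the trailing ';'
        (void) v;
    }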
@@ -3707,6 +3710,58 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
  #endif
  }
 
+ // xs and vs are byte strides of x and v
+ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
+
+ const float * restrict x[GGML_VEC_MAD_UNROLL];
+ const float * restrict v[GGML_VEC_MAD_UNROLL];
+
+ for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
+ x[i] = (const float *) ((const char *) xv + i*xs);
+ v[i] = (const float *) ((const char *) vv + i*vs);
+ }
+
+ #if defined(GGML_SIMD)
+ const int np = (n & ~(GGML_F32_STEP - 1));
+
+ GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];
+
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ vx[k] = GGML_F32_VEC_SET1(v[k][0]);
+ }
+
+ GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];
+
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
+ ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
+ }
+
+ GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+ }
+ }
+
+ // leftovers
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ for (int i = np; i < n; ++i) {
+ y[i] += x[k][i]*v[k][0];
+ }
+ }
+ #else
+ // scalar
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ for (int i = 0; i < n; ++i) {
+ y[i] += x[k][i]*v[k][0];
+ }
+ }
+ #endif
+ }
+
  //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
  inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
  #if defined(GGML_USE_ACCELERATE)
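The new ggml_vec_mad_f32_unroll fuses GGML_VEC_MAD_UNROLL multiply-add passes over the same output row y, reading the k-th input row and scale at byte offsets k*xs and k*vs. A scalar sketch of what it computes, reference semantics only (the real function takes the GGML_SIMD path when available; the _ref name here is made up):

    #ifndef GGML_VEC_MAD_UNROLL
    #define GGML_VEC_MAD_UNROLL 32   // same unroll factor as the new define above
    #endif

    // Equivalent to GGML_VEC_MAD_UNROLL successive calls to ggml_vec_mad_f32.
    static void vec_mad_f32_unroll_ref(int n, int xs, int vs,
                                       float * y, const float * xv, const float * vv) {
        for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
            const float * x = (const float *) ((const char *) xv + k*xs);   // k-th x row
            const float * v = (const float *) ((const char *) vv + k*vs);   // k-th scale
            for (int i = 0; i < n; ++i) {
                y[i] += x[i]*v[0];   // y += x_k * v_k[0] for every unrolled row k
            }
        }
    }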
@@ -4303,10 +4358,21 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
  }
 
  size_t ggml_nbytes(const struct ggml_tensor * tensor) {
- size_t nbytes = tensor->ne[0]*tensor->nb[0]/ggml_blck_size(tensor->type);
- for (int i = 1; i < GGML_MAX_DIMS; ++i) {
- nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+ size_t nbytes;
+ size_t blck_size = ggml_blck_size(tensor->type);
+ if (blck_size == 1) {
+ nbytes = ggml_type_size(tensor->type);
+ for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+ nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+ }
  }
+ else {
+ nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
+ for (int i = 1; i < GGML_MAX_DIMS; ++i) {
+ nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+ }
+ }
+
  return nbytes;
  }
 
@@ -4381,10 +4447,9 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
  static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
- return
- (t0->ne[1] == t1->ne[1]) &&
- (t0->ne[2] == t1->ne[2]) &&
- (t0->ne[3] == t1->ne[3]);
+ return (t0->ne[1] == t1->ne[1]) &&
+ (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
+ (t1->ne[3]%t0->ne[3] == 0);
  }
 
  enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
@@ -5054,43 +5119,78 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
  return tensor;
  }
 
+ void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) {
+ const int64_t ne2 = tensor->ne[2];
+ const int64_t ne1 = tensor->ne[1];
+ const int64_t ne0 = tensor->ne[0];
+
+ const int64_t i3_ = (i/(ne2*ne1*ne0));
+ const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0);
+ const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0;
+ const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0);
+
+ if (i0) {
+ * i0 = i0_;
+ }
+ if (i1) {
+ * i1 = i1_;
+ }
+ if (i2) {
+ * i2 = i2_;
+ }
+ if (i3) {
+ * i3 = i3_;
+ }
+ }
+
  int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
+ if (!ggml_is_contiguous(tensor)) {
+ int64_t id[4] = { 0, 0, 0, 0 };
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+ return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]);
+ }
  switch (tensor->type) {
  case GGML_TYPE_I8:
  {
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
  return ((int8_t *)(tensor->data))[i];
- } break;
+ }
  case GGML_TYPE_I16:
  {
  GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
  return ((int16_t *)(tensor->data))[i];
- } break;
+ }
  case GGML_TYPE_I32:
  {
  GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
  return ((int32_t *)(tensor->data))[i];
- } break;
+ }
  case GGML_TYPE_F16:
  {
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
- } break;
+ }
  case GGML_TYPE_F32:
  {
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
  return ((float *)(tensor->data))[i];
- } break;
+ }
  default:
  {
  GGML_ASSERT(false);
- } break;
+ }
  }
 
  return 0.0f;
  }
 
  void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
+ if (!ggml_is_contiguous(tensor)) {
+ int64_t id[4] = { 0, 0, 0, 0 };
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+ ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value);
+ return;
+ }
  switch (tensor->type) {
  case GGML_TYPE_I8:
  {
@@ -5124,43 +5224,104 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
  }
  }
 
+ int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ return ((int8_t *) data)[0];
+ case GGML_TYPE_I16:
+ return ((int16_t *) data)[0];
+ case GGML_TYPE_I32:
+ return ((int32_t *) data)[0];
+ case GGML_TYPE_F16:
+ return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
+ case GGML_TYPE_F32:
+ return ((float *) data)[0];
+ default:
+ GGML_ASSERT(false);
+ }
+
+ return 0.0f;
+ }
+
+ void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) {
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ ((int8_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_I16:
+ {
+ ((int16_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_I32:
+ {
+ ((int32_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_F16:
+ {
+ ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ((float *)(data))[0] = value;
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+ }
+
  float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
+ if (!ggml_is_contiguous(tensor)) {
+ int64_t id[4] = { 0, 0, 0, 0 };
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+ return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]);
+ }
  switch (tensor->type) {
  case GGML_TYPE_I8:
  {
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
  return ((int8_t *)(tensor->data))[i];
- } break;
+ }
  case GGML_TYPE_I16:
  {
  GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
  return ((int16_t *)(tensor->data))[i];
- } break;
+ }
  case GGML_TYPE_I32:
  {
  GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
  return ((int32_t *)(tensor->data))[i];
- } break;
+ }
  case GGML_TYPE_F16:
  {
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
- } break;
+ }
  case GGML_TYPE_F32:
  {
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
  return ((float *)(tensor->data))[i];
- } break;
+ }
  default:
  {
  GGML_ASSERT(false);
- } break;
+ }
  }
 
  return 0.0f;
  }
 
  void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
+ if (!ggml_is_contiguous(tensor)) {
+ int64_t id[4] = { 0, 0, 0, 0 };
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+ ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value);
+ return;
+ }
  switch (tensor->type) {
  case GGML_TYPE_I8:
  {
@@ -5194,6 +5355,56 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
  }
  }
 
+ float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ return ((int8_t *) data)[0];
+ case GGML_TYPE_I16:
+ return ((int16_t *) data)[0];
+ case GGML_TYPE_I32:
+ return ((int32_t *) data)[0];
+ case GGML_TYPE_F16:
+ return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
+ case GGML_TYPE_F32:
+ return ((float *) data)[0];
+ default:
+ GGML_ASSERT(false);
+ }
+
+ return 0.0f;
+ }
+
+ void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) {
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ ((int8_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_I16:
+ {
+ ((int16_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_I32:
+ {
+ ((int32_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_F16:
+ {
+ ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ((float *)(data))[0] = value;
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+ }
+
  void * ggml_get_data(const struct ggml_tensor * tensor) {
  return tensor->data;
  }
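Together with ggml_unravel_index, the new _nd accessors above let the existing *_1d getters and setters handle non-contiguous (for example permuted) tensors. A small usage sketch, assuming an already-initialized ggml_context named ctx; the demo function name is made up for illustration:

    #include "ggml.h"

    static void nd_accessor_demo(struct ggml_context * ctx) {
        struct ggml_tensor * t  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);
        struct ggml_tensor * tp = ggml_permute(ctx, t, 1, 0, 2, 3);     // non-contiguous view

        ggml_set_f32_nd(tp, 2, 1, 0, 0, 7.0f);                          // write by explicit indices
        const float a = ggml_get_f32_nd(tp, 2, 1, 0, 0);                // read the same element back

        int64_t i0, i1, i2, i3;
        ggml_unravel_index(tp, 5, &i0, &i1, &i2, &i3);                  // flat index -> (i0,i1,i2,i3)
        const float b = ggml_get_f32_1d(tp, 5);                         // 1d getter now takes the _nd path

        (void) a; (void) b;
    }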
@@ -5336,6 +5547,44 @@ struct ggml_tensor * ggml_add_inplace(
5336
5547
  return ggml_add_impl(ctx, a, b, true);
5337
5548
  }
5338
5549
 
5550
+ // ggml_add_cast
5551
+
5552
+ static struct ggml_tensor * ggml_add_cast_impl(
5553
+ struct ggml_context * ctx,
5554
+ struct ggml_tensor * a,
5555
+ struct ggml_tensor * b,
5556
+ enum ggml_type type) {
5557
+ // TODO: support less-strict constraint
5558
+ // GGML_ASSERT(ggml_can_repeat(b, a));
5559
+ GGML_ASSERT(ggml_can_repeat_rows(b, a));
5560
+ GGML_ASSERT(ggml_is_quantized(a->type)); // currently only supported for quantized input
5561
+
5562
+ bool is_node = false;
5563
+
5564
+ if (a->grad || b->grad) {
5565
+ // TODO: support backward pass for broadcasting
5566
+ GGML_ASSERT(ggml_are_same_shape(a, b));
5567
+ is_node = true;
5568
+ }
5569
+
5570
+ struct ggml_tensor * result = ggml_new_tensor(ctx, type, a->n_dims, a->ne);
5571
+
5572
+ result->op = GGML_OP_ADD;
5573
+ result->grad = is_node ? ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne) : NULL;
5574
+ result->src[0] = a;
5575
+ result->src[1] = b;
5576
+
5577
+ return result;
5578
+ }
5579
+
5580
+ struct ggml_tensor * ggml_add_cast(
5581
+ struct ggml_context * ctx,
5582
+ struct ggml_tensor * a,
5583
+ struct ggml_tensor * b,
5584
+ enum ggml_type type) {
5585
+ return ggml_add_cast_impl(ctx, a, b, type);
5586
+ }
5587
+
5339
5588
  // ggml_add1
5340
5589
 
5341
5590
  static struct ggml_tensor * ggml_add1_impl(
@@ -5772,7 +6021,6 @@ struct ggml_tensor * ggml_repeat(
5772
6021
  result->op = GGML_OP_REPEAT;
5773
6022
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5774
6023
  result->src[0] = a;
5775
- result->src[1] = b;
5776
6024
 
5777
6025
  return result;
5778
6026
  }
@@ -5800,7 +6048,6 @@ struct ggml_tensor * ggml_repeat_back(
5800
6048
  result->op = GGML_OP_REPEAT_BACK;
5801
6049
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5802
6050
  result->src[0] = a;
5803
- result->src[1] = b;
5804
6051
 
5805
6052
  return result;
5806
6053
  }
@@ -6175,8 +6422,9 @@ struct ggml_tensor * ggml_out_prod(
6175
6422
  is_node = true;
6176
6423
  }
6177
6424
 
6178
- const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
6179
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
6425
+ // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]
6426
+ const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };
6427
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
6180
6428
 
6181
6429
  result->op = GGML_OP_OUT_PROD;
6182
6430
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6395,6 +6643,54 @@ struct ggml_tensor * ggml_cont_inplace(
6395
6643
  return ggml_cont_impl(ctx, a, true);
6396
6644
  }
6397
6645
 
6646
+
6647
+ // make contiguous, with new shape
6648
+ GGML_API struct ggml_tensor * ggml_cont_1d(
6649
+ struct ggml_context * ctx,
6650
+ struct ggml_tensor * a,
6651
+ int64_t ne0) {
6652
+ return ggml_cont_4d(ctx, a, ne0, 1, 1, 1);
6653
+ }
6654
+
6655
+ GGML_API struct ggml_tensor * ggml_cont_2d(
6656
+ struct ggml_context * ctx,
6657
+ struct ggml_tensor * a,
6658
+ int64_t ne0,
6659
+ int64_t ne1) {
6660
+ return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1);
6661
+ }
6662
+
6663
+ GGML_API struct ggml_tensor * ggml_cont_3d(
6664
+ struct ggml_context * ctx,
6665
+ struct ggml_tensor * a,
6666
+ int64_t ne0,
6667
+ int64_t ne1,
6668
+ int64_t ne2) {
6669
+ return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1);
6670
+ }
6671
+
6672
+ struct ggml_tensor * ggml_cont_4d(
6673
+ struct ggml_context * ctx,
6674
+ struct ggml_tensor * a,
6675
+ int64_t ne0,
6676
+ int64_t ne1,
6677
+ int64_t ne2,
6678
+ int64_t ne3) {
6679
+ GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3));
6680
+
6681
+ bool is_node = false;
6682
+
6683
+ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
6684
+ ggml_format_name(result, "%s (cont)", a->name);
6685
+
6686
+ result->op = GGML_OP_CONT;
6687
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6688
+ result->src[0] = a;
6689
+
6690
+ return result;
6691
+ }
6692
+
6693
+
6398
6694
  // ggml_reshape
6399
6695
 
6400
6696
  struct ggml_tensor * ggml_reshape(
@@ -6402,7 +6698,7 @@ struct ggml_tensor * ggml_reshape(
6402
6698
  struct ggml_tensor * a,
6403
6699
  struct ggml_tensor * b) {
6404
6700
  GGML_ASSERT(ggml_is_contiguous(a));
6405
- GGML_ASSERT(ggml_is_contiguous(b));
6701
+ // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous.
6406
6702
  GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
6407
6703
 
6408
6704
  bool is_node = false;
@@ -6775,7 +7071,6 @@ struct ggml_tensor * ggml_get_rows_back(
6775
7071
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6776
7072
  result->src[0] = a;
6777
7073
  result->src[1] = b;
6778
- result->src[2] = c;
6779
7074
 
6780
7075
  return result;
6781
7076
  }
@@ -6957,7 +7252,7 @@ struct ggml_tensor * ggml_soft_max_back_inplace(
6957
7252
  static struct ggml_tensor * ggml_rope_impl(
6958
7253
  struct ggml_context * ctx,
6959
7254
  struct ggml_tensor * a,
6960
- int n_past,
7255
+ struct ggml_tensor * b,
6961
7256
  int n_dims,
6962
7257
  int mode,
6963
7258
  int n_ctx,
@@ -6966,7 +7261,10 @@ static struct ggml_tensor * ggml_rope_impl(
6966
7261
  float xpos_base,
6967
7262
  bool xpos_down,
6968
7263
  bool inplace) {
6969
- GGML_ASSERT(n_past >= 0);
7264
+ GGML_ASSERT(ggml_is_vector(b));
7265
+ GGML_ASSERT(b->type == GGML_TYPE_I32);
7266
+ GGML_ASSERT(a->ne[2] == b->ne[0]);
7267
+
6970
7268
  bool is_node = false;
6971
7269
 
6972
7270
  if (a->grad) {
@@ -6975,7 +7273,7 @@ static struct ggml_tensor * ggml_rope_impl(
6975
7273
 
6976
7274
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
6977
7275
 
6978
- int32_t params[8] = { n_past, n_dims, mode, n_ctx };
7276
+ int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
6979
7277
  memcpy(params + 4, &freq_base, sizeof(float));
6980
7278
  memcpy(params + 5, &freq_scale, sizeof(float));
6981
7279
  memcpy(params + 6, &xpos_base, sizeof(float));
@@ -6985,6 +7283,7 @@ static struct ggml_tensor * ggml_rope_impl(
6985
7283
  result->op = GGML_OP_ROPE;
6986
7284
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6987
7285
  result->src[0] = a;
7286
+ result->src[1] = b;
6988
7287
 
6989
7288
  return result;
6990
7289
  }
@@ -6992,55 +7291,55 @@ static struct ggml_tensor * ggml_rope_impl(
6992
7291
  struct ggml_tensor * ggml_rope(
6993
7292
  struct ggml_context * ctx,
6994
7293
  struct ggml_tensor * a,
6995
- int n_past,
7294
+ struct ggml_tensor * b,
6996
7295
  int n_dims,
6997
7296
  int mode,
6998
7297
  int n_ctx) {
6999
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
7298
+ return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
7000
7299
  }
7001
7300
 
7002
7301
  struct ggml_tensor * ggml_rope_inplace(
7003
7302
  struct ggml_context * ctx,
7004
7303
  struct ggml_tensor * a,
7005
- int n_past,
7304
+ struct ggml_tensor * b,
7006
7305
  int n_dims,
7007
7306
  int mode,
7008
7307
  int n_ctx) {
7009
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
7308
+ return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
7010
7309
  }
7011
7310
 
7012
7311
  struct ggml_tensor * ggml_rope_custom(
7013
7312
  struct ggml_context * ctx,
7014
7313
  struct ggml_tensor * a,
7015
- int n_past,
7314
+ struct ggml_tensor * b,
7016
7315
  int n_dims,
7017
7316
  int mode,
7018
7317
  int n_ctx,
7019
7318
  float freq_base,
7020
7319
  float freq_scale) {
7021
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
7320
+ return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
7022
7321
  }
7023
7322
 
7024
7323
  struct ggml_tensor * ggml_rope_custom_inplace(
7025
7324
  struct ggml_context * ctx,
7026
7325
  struct ggml_tensor * a,
7027
- int n_past,
7326
+ struct ggml_tensor * b,
7028
7327
  int n_dims,
7029
7328
  int mode,
7030
7329
  int n_ctx,
7031
7330
  float freq_base,
7032
7331
  float freq_scale) {
7033
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
7332
+ return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
7034
7333
  }
7035
7334
 
7036
7335
  struct ggml_tensor * ggml_rope_xpos_inplace(
7037
7336
  struct ggml_context * ctx,
7038
7337
  struct ggml_tensor * a,
7039
- int n_past,
7338
+ struct ggml_tensor * b,
7040
7339
  int n_dims,
7041
7340
  float base,
7042
7341
  bool down) {
7043
- return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
7342
+ return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
7044
7343
  }
7045
7344
 
7046
7345
  // ggml_rope_back
@@ -7048,7 +7347,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
7048
7347
  struct ggml_tensor * ggml_rope_back(
7049
7348
  struct ggml_context * ctx,
7050
7349
  struct ggml_tensor * a,
7051
- int n_past,
7350
+ struct ggml_tensor * b,
7052
7351
  int n_dims,
7053
7352
  int mode,
7054
7353
  int n_ctx,
@@ -7056,7 +7355,10 @@ struct ggml_tensor * ggml_rope_back(
7056
7355
  float freq_scale,
7057
7356
  float xpos_base,
7058
7357
  bool xpos_down) {
7059
- GGML_ASSERT(n_past >= 0);
7358
+ GGML_ASSERT(ggml_is_vector(b));
7359
+ GGML_ASSERT(b->type == GGML_TYPE_I32);
7360
+ GGML_ASSERT(a->ne[2] == b->ne[0]);
7361
+
7060
7362
  GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
7061
7363
 
7062
7364
  bool is_node = false;
@@ -7067,7 +7369,7 @@ struct ggml_tensor * ggml_rope_back(
7067
7369
 
7068
7370
  struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
7069
7371
 
7070
- int32_t params[8] = { n_past, n_dims, mode, n_ctx };
7372
+ int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
7071
7373
  memcpy(params + 4, &freq_base, sizeof(float));
7072
7374
  memcpy(params + 5, &freq_scale, sizeof(float));
7073
7375
  memcpy(params + 6, &xpos_base, sizeof(float));
@@ -7077,6 +7379,7 @@ struct ggml_tensor * ggml_rope_back(
7077
7379
  result->op = GGML_OP_ROPE_BACK;
7078
7380
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7079
7381
  result->src[0] = a;
7382
+ result->src[1] = b;
7080
7383
 
7081
7384
  return result;
7082
7385
  }
@@ -7473,27 +7776,30 @@ struct ggml_tensor * ggml_flash_attn_back(
7473
7776
 
7474
7777
  // d shape [D,N,ne2,ne3]
7475
7778
  // q shape [D,N,ne2,ne3]
7476
- // k shape [D,M,ne2,ne3]
7477
- // v shape [M,D,ne2,ne3]
7779
+ // k shape [D,M,kvne2,ne3]
7780
+ // v shape [M,D,kvne2,ne3]
7478
7781
 
7479
- const int64_t D = q->ne[0];
7480
- const int64_t N = q->ne[1];
7481
- const int64_t M = k->ne[1];
7482
- const int64_t ne2 = q->ne[2];
7483
- const int64_t ne3 = q->ne[3];
7782
+ const int64_t D = q->ne[0];
7783
+ const int64_t N = q->ne[1];
7784
+ const int64_t M = k->ne[1];
7785
+ const int64_t ne2 = q->ne[2];
7786
+ const int64_t ne3 = q->ne[3];
7787
+ const int64_t kvne2 = k->ne[2];
7484
7788
 
7485
7789
  GGML_ASSERT(k->ne[0] == D);
7486
7790
  GGML_ASSERT(v->ne[0] == M);
7487
7791
  GGML_ASSERT(v->ne[1] == D);
7488
7792
  GGML_ASSERT(d->ne[0] == D);
7489
7793
  GGML_ASSERT(d->ne[1] == N);
7490
- GGML_ASSERT(k->ne[2] == ne2);
7794
+ GGML_ASSERT(k->ne[2] == kvne2);
7491
7795
  GGML_ASSERT(k->ne[3] == ne3);
7492
- GGML_ASSERT(v->ne[2] == ne2);
7796
+ GGML_ASSERT(v->ne[2] == kvne2);
7493
7797
  GGML_ASSERT(v->ne[3] == ne3);
7494
7798
  GGML_ASSERT(d->ne[2] == ne2);
7495
7799
  GGML_ASSERT(d->ne[3] == ne3);
7496
7800
 
7801
+ GGML_ASSERT(ne2 % kvne2 == 0);
7802
+
7497
7803
  bool is_node = false;
7498
7804
 
7499
7805
  if (q->grad || k->grad || v->grad) {
@@ -7503,14 +7809,23 @@ struct ggml_tensor * ggml_flash_attn_back(
7503
7809
  }
7504
7810
 
7505
7811
  // store gradients of q, k and v as continuous tensors concatenated in result.
7506
- // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
7507
- // gradq->data = result->data
7508
- // gradk->data = result->data + nb0*D*N*ne2*ne3
7509
- // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
7510
7812
  // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
7511
- int64_t ne[4] = {D,M+N+M,ne2,ne3};
7813
+ const int64_t elem_q = ggml_nelements(q);
7814
+ const int64_t elem_k = ggml_nelements(k);
7815
+ const int64_t elem_v = ggml_nelements(v);
7512
7816
 
7513
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
7817
+ enum ggml_type result_type = GGML_TYPE_F32;
7818
+ GGML_ASSERT(ggml_blck_size(result_type) == 1);
7819
+ const size_t tsize = ggml_type_size(result_type);
7820
+
7821
+ const size_t offs_q = 0;
7822
+ const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
7823
+ const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
7824
+ const size_t end = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN);
7825
+
7826
+ const size_t nelements = (end + tsize - 1)/tsize;
7827
+
7828
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements);
7514
7829
 
7515
7830
  int32_t masked_i = masked ? 1 : 0;
7516
7831
  ggml_set_op_params(result, &masked_i, sizeof(masked_i));
@@ -8203,7 +8518,7 @@ static void ggml_compute_forward_dup_f16(
8203
8518
  return;
8204
8519
  }
8205
8520
 
8206
- GGML_TENSOR_UNARY_OP_LOCALS;
8521
+ GGML_TENSOR_UNARY_OP_LOCALS
8207
8522
 
8208
8523
  const int ith = params->ith; // thread index
8209
8524
  const int nth = params->nth; // number of threads
@@ -8474,7 +8789,7 @@ static void ggml_compute_forward_dup_f32(
8474
8789
  return;
8475
8790
  }
8476
8791
 
8477
- GGML_TENSOR_UNARY_OP_LOCALS;
8792
+ GGML_TENSOR_UNARY_OP_LOCALS
8478
8793
 
8479
8794
  const int ith = params->ith; // thread index
8480
8795
  const int nth = params->nth; // number of threads
@@ -8755,7 +9070,7 @@ static void ggml_compute_forward_add_f32(
8755
9070
 
8756
9071
  const int nr = ggml_nrows(src0);
8757
9072
 
8758
- GGML_TENSOR_BINARY_OP_LOCALS;
9073
+ GGML_TENSOR_BINARY_OP_LOCALS
8759
9074
 
8760
9075
  GGML_ASSERT( nb0 == sizeof(float));
8761
9076
  GGML_ASSERT(nb00 == sizeof(float));
@@ -8787,8 +9102,6 @@ static void ggml_compute_forward_add_f32(
8787
9102
  #else
8788
9103
  ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
8789
9104
  #endif
8790
- // }
8791
- // }
8792
9105
  }
8793
9106
  } else {
8794
9107
  // src1 is not contiguous
@@ -8830,7 +9143,7 @@ static void ggml_compute_forward_add_f16_f32(
8830
9143
 
8831
9144
  const int nr = ggml_nrows(src0);
8832
9145
 
8833
- GGML_TENSOR_BINARY_OP_LOCALS;
9146
+ GGML_TENSOR_BINARY_OP_LOCALS
8834
9147
 
8835
9148
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
8836
9149
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -8884,7 +9197,7 @@ static void ggml_compute_forward_add_f16_f16(
8884
9197
 
8885
9198
  const int nr = ggml_nrows(src0);
8886
9199
 
8887
- GGML_TENSOR_BINARY_OP_LOCALS;
9200
+ GGML_TENSOR_BINARY_OP_LOCALS
8888
9201
 
8889
9202
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
8890
9203
  GGML_ASSERT(src1->type == GGML_TYPE_F16);
@@ -8935,14 +9248,15 @@ static void ggml_compute_forward_add_q_f32(
8935
9248
 
8936
9249
  const int nr = ggml_nrows(src0);
8937
9250
 
8938
- GGML_TENSOR_BINARY_OP_LOCALS;
9251
+ GGML_TENSOR_BINARY_OP_LOCALS
8939
9252
 
8940
9253
  const int ith = params->ith;
8941
9254
  const int nth = params->nth;
8942
9255
 
8943
9256
  const enum ggml_type type = src0->type;
9257
+ const enum ggml_type dtype = dst->type;
8944
9258
  ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
8945
- ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
9259
+ ggml_from_float_t const quantize_row_q = type_traits[dtype].from_float;
8946
9260
 
8947
9261
  // we don't support permuted src0 or src1
8948
9262
  GGML_ASSERT(nb00 == ggml_type_size(type));
@@ -8954,7 +9268,6 @@ static void ggml_compute_forward_add_q_f32(
8954
9268
  GGML_ASSERT(nb2 <= nb3);
8955
9269
 
8956
9270
  GGML_ASSERT(ggml_is_quantized(src0->type));
8957
- GGML_ASSERT(dst->type == src0->type);
8958
9271
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
8959
9272
 
8960
9273
  // rows per thread
@@ -8992,7 +9305,11 @@ static void ggml_compute_forward_add_q_f32(
8992
9305
  // add src1
8993
9306
  ggml_vec_acc_f32(ne00, wdata, src1_row);
8994
9307
  // quantize row to dst
8995
- quantize_row_q(wdata, dst_row, ne00);
9308
+ if (quantize_row_q != NULL) {
9309
+ quantize_row_q(wdata, dst_row, ne00);
9310
+ } else {
9311
+ memcpy(dst_row, wdata, ne0*nb0);
9312
+ }
8996
9313
  }
8997
9314
  }
8998
9315
 
@@ -9057,7 +9374,7 @@ static void ggml_compute_forward_add1_f32(
9057
9374
 
9058
9375
  const int nr = ggml_nrows(src0);
9059
9376
 
9060
- GGML_TENSOR_UNARY_OP_LOCALS;
9377
+ GGML_TENSOR_UNARY_OP_LOCALS
9061
9378
 
9062
9379
  GGML_ASSERT( nb0 == sizeof(float));
9063
9380
  GGML_ASSERT(nb00 == sizeof(float));
@@ -9112,7 +9429,7 @@ static void ggml_compute_forward_add1_f16_f32(
9112
9429
 
9113
9430
  const int nr = ggml_nrows(src0);
9114
9431
 
9115
- GGML_TENSOR_UNARY_OP_LOCALS;
9432
+ GGML_TENSOR_UNARY_OP_LOCALS
9116
9433
 
9117
9434
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
9118
9435
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -9162,7 +9479,7 @@ static void ggml_compute_forward_add1_f16_f16(
9162
9479
 
9163
9480
  const int nr = ggml_nrows(src0);
9164
9481
 
9165
- GGML_TENSOR_UNARY_OP_LOCALS;
9482
+ GGML_TENSOR_UNARY_OP_LOCALS
9166
9483
 
9167
9484
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
9168
9485
  GGML_ASSERT(src1->type == GGML_TYPE_F16);
@@ -9212,7 +9529,7 @@ static void ggml_compute_forward_add1_q_f32(
9212
9529
 
9213
9530
  const int nr = ggml_nrows(src0);
9214
9531
 
9215
- GGML_TENSOR_UNARY_OP_LOCALS;
9532
+ GGML_TENSOR_UNARY_OP_LOCALS
9216
9533
 
9217
9534
  const enum ggml_type type = src0->type;
9218
9535
  ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
@@ -9340,8 +9657,8 @@ static void ggml_compute_forward_acc_f32(
9340
9657
  const int nr = ggml_nrows(src1);
9341
9658
  const int nc = src1->ne[0];
9342
9659
 
9343
- GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
9344
- GGML_TENSOR_LOCALS(size_t, nb1, src1, nb);
9660
+ GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
9661
+ GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
9345
9662
 
9346
9663
  // src0 and dst as viewed during acc
9347
9664
  const size_t nb0 = ggml_element_size(src0);
@@ -9430,7 +9747,7 @@ static void ggml_compute_forward_sub_f32(
9430
9747
 
9431
9748
  const int nr = ggml_nrows(src0);
9432
9749
 
9433
- GGML_TENSOR_BINARY_OP_LOCALS;
9750
+ GGML_TENSOR_BINARY_OP_LOCALS
9434
9751
 
9435
9752
  GGML_ASSERT( nb0 == sizeof(float));
9436
9753
  GGML_ASSERT(nb00 == sizeof(float));
@@ -9520,7 +9837,7 @@ static void ggml_compute_forward_mul_f32(
9520
9837
 
9521
9838
  const int64_t nr = ggml_nrows(src0);
9522
9839
 
9523
- GGML_TENSOR_BINARY_OP_LOCALS;
9840
+ GGML_TENSOR_BINARY_OP_LOCALS
9524
9841
 
9525
9842
  GGML_ASSERT( nb0 == sizeof(float));
9526
9843
  GGML_ASSERT(nb00 == sizeof(float));
@@ -9611,7 +9928,7 @@ static void ggml_compute_forward_div_f32(
9611
9928
 
9612
9929
  const int nr = ggml_nrows(src0);
9613
9930
 
9614
- GGML_TENSOR_BINARY_OP_LOCALS;
9931
+ GGML_TENSOR_BINARY_OP_LOCALS
9615
9932
 
9616
9933
  GGML_ASSERT( nb0 == sizeof(float));
9617
9934
  GGML_ASSERT(nb00 == sizeof(float));
@@ -9820,8 +10137,8 @@ static void ggml_compute_forward_sum_f32(
9820
10137
  assert(ggml_is_scalar(dst));
9821
10138
  assert(src0->nb[0] == sizeof(float));
9822
10139
 
9823
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
9824
- GGML_TENSOR_LOCALS(size_t, nb0, src0, nb);
10140
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
10141
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
9825
10142
 
9826
10143
  ggml_float sum = 0;
9827
10144
  ggml_float row_sum = 0;
@@ -9852,8 +10169,8 @@ static void ggml_compute_forward_sum_f16(
9852
10169
 
9853
10170
  assert(src0->nb[0] == sizeof(ggml_fp16_t));
9854
10171
 
9855
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
9856
- GGML_TENSOR_LOCALS(size_t, nb0, src0, nb);
10172
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
10173
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
9857
10174
 
9858
10175
  float sum = 0;
9859
10176
  float row_sum = 0;
@@ -9906,7 +10223,7 @@ static void ggml_compute_forward_sum_rows_f32(
9906
10223
  GGML_ASSERT(src0->nb[0] == sizeof(float));
9907
10224
  GGML_ASSERT(dst->nb[0] == sizeof(float));
9908
10225
 
9909
- GGML_TENSOR_UNARY_OP_LOCALS;
10226
+ GGML_TENSOR_UNARY_OP_LOCALS
9910
10227
 
9911
10228
  GGML_ASSERT(ne0 == 1);
9912
10229
  GGML_ASSERT(ne1 == ne01);
@@ -9956,7 +10273,7 @@ static void ggml_compute_forward_mean_f32(
9956
10273
 
9957
10274
  assert(src0->nb[0] == sizeof(float));
9958
10275
 
9959
- GGML_TENSOR_UNARY_OP_LOCALS;
10276
+ GGML_TENSOR_UNARY_OP_LOCALS
9960
10277
 
9961
10278
  assert(ne0 == 1);
9962
10279
  assert(ne1 == ne01);
@@ -10056,7 +10373,7 @@ static void ggml_compute_forward_repeat_f32(
10056
10373
  return;
10057
10374
  }
10058
10375
 
10059
- GGML_TENSOR_UNARY_OP_LOCALS;
10376
+ GGML_TENSOR_UNARY_OP_LOCALS
10060
10377
 
10061
10378
  // guaranteed to be an integer due to the check in ggml_can_repeat
10062
10379
  const int nr0 = (int)(ne0/ne00);
@@ -10088,11 +10405,61 @@ static void ggml_compute_forward_repeat_f32(
10088
10405
  }
10089
10406
  }
10090
10407
 
10408
+ static void ggml_compute_forward_repeat_f16(
10409
+ const struct ggml_compute_params * params,
10410
+ const struct ggml_tensor * src0,
10411
+ struct ggml_tensor * dst) {
10412
+ GGML_ASSERT(params->ith == 0);
10413
+ GGML_ASSERT(ggml_can_repeat(src0, dst));
10414
+
10415
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
10416
+ return;
10417
+ }
10418
+
10419
+ GGML_TENSOR_UNARY_OP_LOCALS;
10420
+
10421
+ // guaranteed to be an integer due to the check in ggml_can_repeat
10422
+ const int nr0 = (int)(ne0/ne00);
10423
+ const int nr1 = (int)(ne1/ne01);
10424
+ const int nr2 = (int)(ne2/ne02);
10425
+ const int nr3 = (int)(ne3/ne03);
10426
+
10427
+ // TODO: support for transposed / permuted tensors
10428
+ GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
10429
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
10430
+
10431
+ // TODO: maybe this is not optimal?
10432
+ for (int i3 = 0; i3 < nr3; i3++) {
10433
+ for (int k3 = 0; k3 < ne03; k3++) {
10434
+ for (int i2 = 0; i2 < nr2; i2++) {
10435
+ for (int k2 = 0; k2 < ne02; k2++) {
10436
+ for (int i1 = 0; i1 < nr1; i1++) {
10437
+ for (int k1 = 0; k1 < ne01; k1++) {
10438
+ for (int i0 = 0; i0 < nr0; i0++) {
10439
+ ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
10440
+ ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
10441
+ // ggml_vec_cpy_f16(ne00, y, x)
10442
+ for (int i = 0; i < ne00; ++i) {
10443
+ y[i] = x[i];
10444
+ }
10445
+ }
10446
+ }
10447
+ }
10448
+ }
10449
+ }
10450
+ }
10451
+ }
10452
+ }
10453
+
10091
10454
  static void ggml_compute_forward_repeat(
10092
10455
  const struct ggml_compute_params * params,
10093
10456
  const struct ggml_tensor * src0,
10094
10457
  struct ggml_tensor * dst) {
10095
10458
  switch (src0->type) {
10459
+ case GGML_TYPE_F16:
10460
+ {
10461
+ ggml_compute_forward_repeat_f16(params, src0, dst);
10462
+ } break;
10096
10463
  case GGML_TYPE_F32:
10097
10464
  {
10098
10465
  ggml_compute_forward_repeat_f32(params, src0, dst);
@@ -10117,7 +10484,7 @@ static void ggml_compute_forward_repeat_back_f32(
10117
10484
  return;
10118
10485
  }
10119
10486
 
10120
- GGML_TENSOR_UNARY_OP_LOCALS;
10487
+ GGML_TENSOR_UNARY_OP_LOCALS
10121
10488
 
10122
10489
  // guaranteed to be an integer due to the check in ggml_can_repeat
10123
10490
  const int nr0 = (int)(ne00/ne0);
@@ -10195,7 +10562,7 @@ static void ggml_compute_forward_concat_f32(
10195
10562
 
10196
10563
  const int ith = params->ith;
10197
10564
 
10198
- GGML_TENSOR_BINARY_OP_LOCALS;
10565
+ GGML_TENSOR_BINARY_OP_LOCALS
10199
10566
 
10200
10567
  // TODO: support for transposed / permuted tensors
10201
10568
  GGML_ASSERT(nb0 == sizeof(float));
@@ -10797,7 +11164,7 @@ static void ggml_compute_forward_norm_f32(
10797
11164
  const int ith = params->ith;
10798
11165
  const int nth = params->nth;
10799
11166
 
10800
- GGML_TENSOR_UNARY_OP_LOCALS;
11167
+ GGML_TENSOR_UNARY_OP_LOCALS
10801
11168
 
10802
11169
  float eps;
10803
11170
  memcpy(&eps, dst->op_params, sizeof(float));
@@ -10866,7 +11233,7 @@ static void ggml_compute_forward_rms_norm_f32(
10866
11233
  const int ith = params->ith;
10867
11234
  const int nth = params->nth;
10868
11235
 
10869
- GGML_TENSOR_UNARY_OP_LOCALS;
11236
+ GGML_TENSOR_UNARY_OP_LOCALS
10870
11237
 
10871
11238
  float eps;
10872
11239
  memcpy(&eps, dst->op_params, sizeof(float));
@@ -10931,7 +11298,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
10931
11298
  const int ith = params->ith;
10932
11299
  const int nth = params->nth;
10933
11300
 
10934
- GGML_TENSOR_BINARY_OP_LOCALS;
11301
+ GGML_TENSOR_BINARY_OP_LOCALS
10935
11302
 
10936
11303
  float eps;
10937
11304
  memcpy(&eps, dst->op_params, sizeof(float));
@@ -11106,7 +11473,7 @@ static void ggml_compute_forward_group_norm_f32(
11106
11473
  const int ith = params->ith;
11107
11474
  const int nth = params->nth;
11108
11475
 
11109
- GGML_TENSOR_UNARY_OP_LOCALS;
11476
+ GGML_TENSOR_UNARY_OP_LOCALS
11110
11477
 
11111
11478
  const float eps = 1e-6f; // TODO: make this a parameter
11112
11479
 
@@ -11217,7 +11584,7 @@ static void ggml_compute_forward_mul_mat(
11217
11584
  int64_t t0 = ggml_perf_time_us();
11218
11585
  UNUSED(t0);
11219
11586
 
11220
- GGML_TENSOR_BINARY_OP_LOCALS;
11587
+ GGML_TENSOR_BINARY_OP_LOCALS
11221
11588
 
11222
11589
  const int ith = params->ith;
11223
11590
  const int nth = params->nth;
@@ -11432,10 +11799,10 @@ static void ggml_compute_forward_out_prod_f32(
11432
11799
  const struct ggml_tensor * src0,
11433
11800
  const struct ggml_tensor * src1,
11434
11801
  struct ggml_tensor * dst) {
11435
- int64_t t0 = ggml_perf_time_us();
11436
- UNUSED(t0);
11802
+ // int64_t t0 = ggml_perf_time_us();
11803
+ // UNUSED(t0);
11437
11804
 
11438
- GGML_TENSOR_BINARY_OP_LOCALS;
11805
+ GGML_TENSOR_BINARY_OP_LOCALS
11439
11806
 
11440
11807
  const int ith = params->ith;
11441
11808
  const int nth = params->nth;
@@ -11474,6 +11841,146 @@ static void ggml_compute_forward_out_prod_f32(
11474
11841
  return;
11475
11842
  }
11476
11843
 
11844
+ // dst[:,:,:,:] = 0
11845
+ // for i2,i3:
11846
+ // for i1:
11847
+ // for i01:
11848
+ // for i0:
11849
+ // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
11850
+
11851
+ // parallelize by last three dimensions
11852
+
11853
+ // total rows in dst
11854
+ const int64_t nr = ne1*ne2*ne3;
11855
+
11856
+ // rows per thread
11857
+ const int64_t dr = (nr + nth - 1)/nth;
11858
+
11859
+ // row range for this thread
11860
+ const int64_t ir0 = dr*ith;
11861
+ const int64_t ir1 = MIN(ir0 + dr, nr);
11862
+
11863
+ // block-tiling attempt
11864
+ const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
11865
+ const int64_t blck_1 = 16;
11866
+
11867
+ for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
11868
+ const int64_t bir1 = MIN(bir + blck_1, ir1);
11869
+ for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
11870
+ const int64_t bne01 = MIN(bi01 + blck_0, ne01);
11871
+ for (int64_t ir = bir; ir < bir1; ++ir) {
11872
+ // dst indices
11873
+ const int64_t i3 = ir/(ne2*ne1);
11874
+ const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
11875
+ const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
11876
+
11877
+ const int64_t i02 = i2;
11878
+ const int64_t i03 = i3;
11879
+
11880
+ //const int64_t i10 = i1;
11881
+ const int64_t i12 = i2;
11882
+ const int64_t i13 = i3;
11883
+
11884
+ #if GGML_VEC_MAD_UNROLL > 2
11885
+ const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
11886
+ for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
11887
+ const int64_t i11 = i01;
11888
+
11889
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
11890
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
11891
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
11892
+
11893
+ ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
11894
+ }
11895
+ for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
11896
+ const int64_t i11 = i01;
11897
+
11898
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
11899
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
11900
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
11901
+
11902
+ ggml_vec_mad_f32(ne0, d, s0, *s1);
11903
+ }
11904
+ #else
11905
+ for (int64_t i01 = bi01; i01 < bne01; ++i01) {
11906
+ const int64_t i11 = i01;
11907
+
11908
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
11909
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
11910
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
11911
+
11912
+ ggml_vec_mad_f32(ne0, d, s0, *s1);
11913
+ }
11914
+ #endif
11915
+ }
11916
+ }
11917
+ }
11918
+
11919
+
11920
+ //int64_t t1 = ggml_perf_time_us();
11921
+ //static int64_t acc = 0;
11922
+ //acc += t1 - t0;
11923
+ //if (t1 - t0 > 10) {
11924
+ // printf("\n");
11925
+ // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
11926
+ // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
11927
+ // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
11928
+ // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
11929
+
11930
+ // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
11931
+ //}
11932
+ }
11933
+
11934
+ static void ggml_compute_forward_out_prod_q_f32(
11935
+ const struct ggml_compute_params * params,
11936
+ const struct ggml_tensor * src0,
11937
+ const struct ggml_tensor * src1,
11938
+ struct ggml_tensor * dst) {
11939
+ // int64_t t0 = ggml_perf_time_us();
11940
+ // UNUSED(t0);
11941
+
11942
+ GGML_TENSOR_BINARY_OP_LOCALS;
11943
+
11944
+ const int ith = params->ith;
11945
+ const int nth = params->nth;
11946
+
11947
+ const enum ggml_type type = src0->type;
11948
+ ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
11949
+
11950
+ GGML_ASSERT(ne02 == ne12);
11951
+ GGML_ASSERT(ne03 == ne13);
11952
+ GGML_ASSERT(ne2 == ne12);
11953
+ GGML_ASSERT(ne3 == ne13);
11954
+
11955
+ // we don't support permuted src0 dim0
11956
+ GGML_ASSERT(nb00 == ggml_type_size(type));
11957
+
11958
+ // dst dim0 cannot be transposed or permuted
11959
+ GGML_ASSERT(nb0 == sizeof(float));
11960
+ // GGML_ASSERT(nb0 <= nb1);
11961
+ // GGML_ASSERT(nb1 <= nb2);
11962
+ // GGML_ASSERT(nb2 <= nb3);
11963
+
11964
+ GGML_ASSERT(ne0 == ne00);
11965
+ GGML_ASSERT(ne1 == ne10);
11966
+ GGML_ASSERT(ne2 == ne02);
11967
+ GGML_ASSERT(ne3 == ne03);
11968
+
11969
+ // nb01 >= nb00 - src0 is not transposed
11970
+ // compute by src0 rows
11971
+
11972
+ // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
11973
+ // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
11974
+
11975
+ if (params->type == GGML_TASK_INIT) {
11976
+ ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
11977
+ return;
11978
+ }
11979
+
11980
+ if (params->type == GGML_TASK_FINALIZE) {
11981
+ return;
11982
+ }
11983
+
11477
11984
  // parallelize by last three dimensions
11478
11985
 
11479
11986
  // total rows in dst
@@ -11493,6 +12000,8 @@ static void ggml_compute_forward_out_prod_f32(
11493
12000
  // for i0:
11494
12001
  // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
11495
12002
 
12003
+ float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
12004
+
11496
12005
  for (int64_t ir = ir0; ir < ir1; ++ir) {
11497
12006
  // dst indices
11498
12007
  const int64_t i3 = ir/(ne2*ne1);
@@ -11513,10 +12022,8 @@ static void ggml_compute_forward_out_prod_f32(
11513
12022
  float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
11514
12023
  float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
11515
12024
 
11516
- ggml_vec_mad_f32(ne0, d, s0, *s1);
11517
- // for (int64_t i0 = 0; i0 < ne0; ++i0) {
11518
- // d[i0] += s0[i0] * s1[i1];
11519
- // }
12025
+ dequantize_row_q(s0, wdata, ne0);
12026
+ ggml_vec_mad_f32(ne0, d, wdata, *s1);
11520
12027
  }
11521
12028
  }
11522
12029
 
@@ -11545,10 +12052,13 @@ static void ggml_compute_forward_out_prod(
11545
12052
  case GGML_TYPE_Q5_0:
11546
12053
  case GGML_TYPE_Q5_1:
11547
12054
  case GGML_TYPE_Q8_0:
11548
- case GGML_TYPE_Q8_1:
12055
+ case GGML_TYPE_Q2_K:
12056
+ case GGML_TYPE_Q3_K:
12057
+ case GGML_TYPE_Q4_K:
12058
+ case GGML_TYPE_Q5_K:
12059
+ case GGML_TYPE_Q6_K:
11549
12060
  {
11550
- GGML_ASSERT(false); // todo
11551
- // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
12061
+ ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
11552
12062
  } break;
11553
12063
  case GGML_TYPE_F16:
11554
12064
  {
@@ -11666,8 +12176,8 @@ static void ggml_compute_forward_set_f32(
11666
12176
  const int nr = ggml_nrows(src1);
11667
12177
  const int nc = src1->ne[0];
11668
12178
 
11669
- GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
11670
- GGML_TENSOR_LOCALS(size_t, nb1, src1, nb);
12179
+ GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
12180
+ GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
11671
12181
 
11672
12182
  // src0 and dst as viewed during set
11673
12183
  const size_t nb0 = ggml_element_size(src0);
@@ -11936,14 +12446,15 @@ static void ggml_compute_forward_get_rows_back_f32_f16(
11936
12446
  const struct ggml_compute_params * params,
11937
12447
  const struct ggml_tensor * src0,
11938
12448
  const struct ggml_tensor * src1,
11939
- const struct ggml_tensor * opt0,
11940
12449
  struct ggml_tensor * dst) {
11941
12450
  GGML_ASSERT(params->ith == 0);
11942
- GGML_ASSERT(ggml_are_same_shape(opt0, dst));
11943
- GGML_ASSERT(ggml_is_contiguous(opt0));
11944
12451
  GGML_ASSERT(ggml_is_contiguous(dst));
11945
12452
 
11946
- ggml_compute_forward_dup_same_cont(params, opt0, dst);
12453
+ // ggml_compute_forward_dup_same_cont(params, opt0, dst);
12454
+
12455
+ if (params->type == GGML_TASK_INIT) {
12456
+ memset(dst->data, 0, ggml_nbytes(dst));
12457
+ }
11947
12458
 
11948
12459
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
11949
12460
  return;
@@ -11969,11 +12480,8 @@ static void ggml_compute_forward_get_rows_back_f32(
11969
12480
  const struct ggml_compute_params * params,
11970
12481
  const struct ggml_tensor * src0,
11971
12482
  const struct ggml_tensor * src1,
11972
- const struct ggml_tensor * opt0,
11973
12483
  struct ggml_tensor * dst) {
11974
12484
  GGML_ASSERT(params->ith == 0);
11975
- GGML_ASSERT(ggml_are_same_shape(opt0, dst));
11976
- GGML_ASSERT(ggml_is_contiguous(opt0));
11977
12485
  GGML_ASSERT(ggml_is_contiguous(dst));
11978
12486
 
11979
12487
  // ggml_compute_forward_dup_same_cont(params, opt0, dst);
@@ -12007,16 +12515,15 @@ static void ggml_compute_forward_get_rows_back(
12007
12515
  const struct ggml_compute_params * params,
12008
12516
  const struct ggml_tensor * src0,
12009
12517
  const struct ggml_tensor * src1,
12010
- const struct ggml_tensor * opt0,
12011
12518
  struct ggml_tensor * dst) {
12012
12519
  switch (src0->type) {
12013
12520
  case GGML_TYPE_F16:
12014
12521
  {
12015
- ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, opt0, dst);
12522
+ ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, dst);
12016
12523
  } break;
12017
12524
  case GGML_TYPE_F32:
12018
12525
  {
12019
- ggml_compute_forward_get_rows_back_f32(params, src0, src1, opt0, dst);
12526
+ ggml_compute_forward_get_rows_back_f32(params, src0, src1, dst);
12020
12527
  } break;
12021
12528
  default:
12022
12529
  {
@@ -12057,7 +12564,7 @@ static void ggml_compute_forward_diag_f32(
12057
12564
 
12058
12565
  // TODO: handle transposed/permuted matrices
12059
12566
 
12060
- GGML_TENSOR_UNARY_OP_LOCALS;
12567
+ GGML_TENSOR_UNARY_OP_LOCALS
12061
12568
 
12062
12569
  GGML_ASSERT(ne00 == ne0);
12063
12570
  GGML_ASSERT(ne00 == ne1);
@@ -12445,13 +12952,11 @@ static void ggml_compute_forward_alibi_f16(
12445
12952
  return;
12446
12953
  }
12447
12954
 
12448
- const int n_past = ((int32_t *) dst->op_params)[0];
12955
+ //const int n_past = ((int32_t *) dst->op_params)[0];
12449
12956
  const int n_head = ((int32_t *) dst->op_params)[1];
12450
12957
  float max_bias;
12451
12958
  memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
12452
12959
 
12453
- assert(n_past >= 0);
12454
-
12455
12960
  const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
12456
12961
  const int ne1 = src0->ne[1]; // seq_len_without_past
12457
12962
  const int ne2 = src0->ne[2]; // n_head -> this is k
@@ -12466,7 +12971,7 @@ static void ggml_compute_forward_alibi_f16(
12466
12971
  //const int nb3 = src0->nb[3];
12467
12972
 
12468
12973
  GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
12469
- GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
12974
+ //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
12470
12975
  GGML_ASSERT(n_head == ne2);
12471
12976
 
12472
12977
  // add alibi to src0 (KQ_scaled)
@@ -12612,8 +13117,8 @@ static void ggml_compute_forward_clamp(
12612
13117
  static void ggml_compute_forward_rope_f32(
12613
13118
  const struct ggml_compute_params * params,
12614
13119
  const struct ggml_tensor * src0,
13120
+ const struct ggml_tensor * src1,
12615
13121
  struct ggml_tensor * dst) {
12616
-
12617
13122
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
12618
13123
  return;
12619
13124
  }
@@ -12623,9 +13128,9 @@ static void ggml_compute_forward_rope_f32(
12623
13128
 
12624
13129
  // these two only relevant for xPos RoPE:
12625
13130
  float xpos_base;
12626
- bool xpos_down;
13131
+ bool xpos_down;
12627
13132
 
12628
- const int n_past = ((int32_t *) dst->op_params)[0];
13133
+ //const int n_past = ((int32_t *) dst->op_params)[0];
12629
13134
  const int n_dims = ((int32_t *) dst->op_params)[1];
12630
13135
  const int mode = ((int32_t *) dst->op_params)[2];
12631
13136
  const int n_ctx = ((int32_t *) dst->op_params)[3];
@@ -12634,9 +13139,7 @@ static void ggml_compute_forward_rope_f32(
12634
13139
  memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
12635
13140
  memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
12636
13141
 
12637
- assert(n_past >= 0);
12638
-
12639
- GGML_TENSOR_UNARY_OP_LOCALS;
13142
+ GGML_TENSOR_UNARY_OP_LOCALS
12640
13143
 
12641
13144
  //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
12642
13145
  //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -12666,9 +13169,11 @@ static void ggml_compute_forward_rope_f32(
12666
13169
  const bool is_neox = mode & 2;
12667
13170
  const bool is_glm = mode & 4;
12668
13171
 
13172
+ const int32_t * pos = (const int32_t *) src1->data;
13173
+
12669
13174
  for (int64_t i3 = 0; i3 < ne3; i3++) {
12670
- for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
12671
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
13175
+ for (int64_t i2 = 0; i2 < ne2; i2++) {
13176
+ const int64_t p = pos[i2];
12672
13177
  for (int64_t i1 = 0; i1 < ne1; i1++) {
12673
13178
  if (ir++ < ir0) continue;
12674
13179
  if (ir > ir1) break;
@@ -12705,7 +13210,7 @@ static void ggml_compute_forward_rope_f32(
12705
13210
  const float cos_theta = cosf(theta);
12706
13211
  const float sin_theta = sinf(theta);
12707
13212
  // zeta scaling for xPos only:
12708
- float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
13213
+ float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
12709
13214
  if (xpos_down) zeta = 1.0f / zeta;
12710
13215
 
12711
13216
  theta *= theta_scale;
@@ -12750,8 +13255,8 @@ static void ggml_compute_forward_rope_f32(
12750
13255
  static void ggml_compute_forward_rope_f16(
12751
13256
  const struct ggml_compute_params * params,
12752
13257
  const struct ggml_tensor * src0,
13258
+ const struct ggml_tensor * src1,
12753
13259
  struct ggml_tensor * dst) {
12754
-
12755
13260
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
12756
13261
  return;
12757
13262
  }
@@ -12759,16 +13264,14 @@ static void ggml_compute_forward_rope_f16(
12759
13264
  float freq_base;
12760
13265
  float freq_scale;
12761
13266
 
12762
- const int n_past = ((int32_t *) dst->op_params)[0];
13267
+ //const int n_past = ((int32_t *) dst->op_params)[0];
12763
13268
  const int n_dims = ((int32_t *) dst->op_params)[1];
12764
13269
  const int mode = ((int32_t *) dst->op_params)[2];
12765
13270
  const int n_ctx = ((int32_t *) dst->op_params)[3];
12766
13271
  memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
12767
13272
  memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
12768
13273
 
12769
- assert(n_past >= 0);
12770
-
12771
- GGML_TENSOR_UNARY_OP_LOCALS;
13274
+ GGML_TENSOR_UNARY_OP_LOCALS
12772
13275
 
12773
13276
  //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
12774
13277
  //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -12798,9 +13301,11 @@ static void ggml_compute_forward_rope_f16(
12798
13301
  const bool is_neox = mode & 2;
12799
13302
  const bool is_glm = mode & 4;
12800
13303
 
13304
+ const int32_t * pos = (const int32_t *) src1->data;
13305
+
12801
13306
  for (int64_t i3 = 0; i3 < ne3; i3++) {
12802
- for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
12803
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
13307
+ for (int64_t i2 = 0; i2 < ne2; i2++) {
13308
+ const int64_t p = pos[i2];
12804
13309
  for (int64_t i1 = 0; i1 < ne1; i1++) {
12805
13310
  if (ir++ < ir0) continue;
12806
13311
  if (ir > ir1) break;
@@ -12879,15 +13384,16 @@ static void ggml_compute_forward_rope_f16(
12879
13384
  static void ggml_compute_forward_rope(
12880
13385
  const struct ggml_compute_params * params,
12881
13386
  const struct ggml_tensor * src0,
13387
+ const struct ggml_tensor * src1,
12882
13388
  struct ggml_tensor * dst) {
12883
13389
  switch (src0->type) {
12884
13390
  case GGML_TYPE_F16:
12885
13391
  {
12886
- ggml_compute_forward_rope_f16(params, src0, dst);
13392
+ ggml_compute_forward_rope_f16(params, src0, src1, dst);
12887
13393
  } break;
12888
13394
  case GGML_TYPE_F32:
12889
13395
  {
12890
- ggml_compute_forward_rope_f32(params, src0, dst);
13396
+ ggml_compute_forward_rope_f32(params, src0, src1, dst);
12891
13397
  } break;
12892
13398
  default:
12893
13399
  {
@@ -12901,6 +13407,7 @@ static void ggml_compute_forward_rope(
12901
13407
  static void ggml_compute_forward_rope_back_f32(
12902
13408
  const struct ggml_compute_params * params,
12903
13409
  const struct ggml_tensor * src0,
13410
+ const struct ggml_tensor * src1,
12904
13411
  struct ggml_tensor * dst) {
12905
13412
 
12906
13413
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -12918,7 +13425,7 @@ static void ggml_compute_forward_rope_back_f32(
12918
13425
  float xpos_base;
12919
13426
  bool xpos_down;
12920
13427
 
12921
- const int n_past = ((int32_t *) dst->op_params)[0];
13428
+ //const int n_past = ((int32_t *) dst->op_params)[0];
12922
13429
  const int n_dims = ((int32_t *) dst->op_params)[1];
12923
13430
  const int mode = ((int32_t *) dst->op_params)[2];
12924
13431
  const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx);
@@ -12927,9 +13434,7 @@ static void ggml_compute_forward_rope_back_f32(
12927
13434
  memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
12928
13435
  memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
12929
13436
 
12930
- assert(n_past >= 0);
12931
-
12932
- GGML_TENSOR_UNARY_OP_LOCALS;
13437
+ GGML_TENSOR_UNARY_OP_LOCALS
12933
13438
 
12934
13439
  //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
12935
13440
  //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -12955,9 +13460,11 @@ static void ggml_compute_forward_rope_back_f32(
12955
13460
 
12956
13461
  const bool is_neox = mode & 2;
12957
13462
 
13463
+ const int32_t * pos = (const int32_t *) src1->data;
13464
+
12958
13465
  for (int64_t i3 = 0; i3 < ne3; i3++) {
12959
- for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
12960
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
13466
+ for (int64_t i2 = 0; i2 < ne2; i2++) {
13467
+ const int64_t p = pos[i2];
12961
13468
  for (int64_t i1 = 0; i1 < ne1; i1++) {
12962
13469
  if (ir++ < ir0) continue;
12963
13470
  if (ir > ir1) break;
@@ -12969,7 +13476,7 @@ static void ggml_compute_forward_rope_back_f32(
12969
13476
  const float cos_theta = cosf(theta);
12970
13477
  const float sin_theta = sinf(theta);
12971
13478
  // zeta scaling for xPos only:
12972
- float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
13479
+ float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
12973
13480
  if (xpos_down) zeta = 1.0f / zeta;
12974
13481
 
12975
13482
  theta *= theta_scale;
@@ -13012,6 +13519,7 @@ static void ggml_compute_forward_rope_back_f32(
13012
13519
  static void ggml_compute_forward_rope_back_f16(
13013
13520
  const struct ggml_compute_params * params,
13014
13521
  const struct ggml_tensor * src0,
13522
+ const struct ggml_tensor * src1,
13015
13523
  struct ggml_tensor * dst) {
13016
13524
 
13017
13525
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -13022,13 +13530,11 @@ static void ggml_compute_forward_rope_back_f16(
13022
13530
  // dx = rope_back(dy, src1)
13023
13531
  // src0 is dy, src1 contains options
13024
13532
 
13025
- const int n_past = ((int32_t *) dst->op_params)[0];
13533
+ //const int n_past = ((int32_t *) dst->op_params)[0];
13026
13534
  const int n_dims = ((int32_t *) dst->op_params)[1];
13027
13535
  const int mode = ((int32_t *) dst->op_params)[2];
13028
13536
 
13029
- assert(n_past >= 0);
13030
-
13031
- GGML_TENSOR_UNARY_OP_LOCALS;
13537
+ GGML_TENSOR_UNARY_OP_LOCALS
13032
13538
 
13033
13539
  //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
13034
13540
  //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -13054,9 +13560,11 @@ static void ggml_compute_forward_rope_back_f16(
13054
13560
 
13055
13561
  const bool is_neox = mode & 2;
13056
13562
 
13563
+ const int32_t * pos = (const int32_t *) src1->data;
13564
+
13057
13565
  for (int64_t i3 = 0; i3 < ne3; i3++) {
13058
- for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
13059
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
13566
+ for (int64_t i2 = 0; i2 < ne2; i2++) {
13567
+ const int64_t p = pos[i2];
13060
13568
  for (int64_t i1 = 0; i1 < ne1; i1++) {
13061
13569
  if (ir++ < ir0) continue;
13062
13570
  if (ir > ir1) break;
@@ -13108,15 +13616,16 @@ static void ggml_compute_forward_rope_back_f16(
13108
13616
  static void ggml_compute_forward_rope_back(
13109
13617
  const struct ggml_compute_params * params,
13110
13618
  const struct ggml_tensor * src0,
13619
+ const struct ggml_tensor * src1,
13111
13620
  struct ggml_tensor * dst) {
13112
13621
  switch (src0->type) {
13113
13622
  case GGML_TYPE_F16:
13114
13623
  {
13115
- ggml_compute_forward_rope_back_f16(params, src0, dst);
13624
+ ggml_compute_forward_rope_back_f16(params, src0, src1, dst);
13116
13625
  } break;
13117
13626
  case GGML_TYPE_F32:
13118
13627
  {
13119
- ggml_compute_forward_rope_back_f32(params, src0, dst);
13628
+ ggml_compute_forward_rope_back_f32(params, src0, src1, dst);
13120
13629
  } break;
13121
13630
  default:
13122
13631
  {
@@ -13139,7 +13648,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
13139
13648
  int64_t t0 = ggml_perf_time_us();
13140
13649
  UNUSED(t0);
13141
13650
 
13142
- GGML_TENSOR_BINARY_OP_LOCALS;
13651
+ GGML_TENSOR_BINARY_OP_LOCALS
13143
13652
 
13144
13653
  const int ith = params->ith;
13145
13654
  const int nth = params->nth;
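
The recurring `GGML_TENSOR_BINARY_OP_LOCALS;` → `GGML_TENSOR_BINARY_OP_LOCALS` change in these conv hunks matches the macro definitions earlier in the diff, which dropped the backslash-continued semicolons: each `GGML_TENSOR_LOCALS(...)` expansion presumably terminates its own declarations, so a further `;` at the call site would only add a redundant empty statement. A simplified, hypothetical illustration of the pattern (not the real macro):

    #include <stdint.h>

    #define LOCALS \
        const int64_t ne00 = ne[0]; \
        const int64_t ne01 = ne[1];

    static void f(const int64_t * ne) {
        LOCALS              // already expands to fully terminated declarations, no trailing ';' needed
        (void) ne00; (void) ne01;
    }
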
@@ -13230,7 +13739,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32(
13230
13739
  int64_t t0 = ggml_perf_time_us();
13231
13740
  UNUSED(t0);
13232
13741
 
13233
- GGML_TENSOR_BINARY_OP_LOCALS;
13742
+ GGML_TENSOR_BINARY_OP_LOCALS
13234
13743
 
13235
13744
  const int ith = params->ith;
13236
13745
  const int nth = params->nth;
@@ -13342,7 +13851,7 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
13342
13851
  int64_t t0 = ggml_perf_time_us();
13343
13852
  UNUSED(t0);
13344
13853
 
13345
- GGML_TENSOR_BINARY_OP_LOCALS;
13854
+ GGML_TENSOR_BINARY_OP_LOCALS
13346
13855
 
13347
13856
  const int ith = params->ith;
13348
13857
  const int nth = params->nth;
@@ -13433,7 +13942,7 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
13433
13942
  int64_t t0 = ggml_perf_time_us();
13434
13943
  UNUSED(t0);
13435
13944
 
13436
- GGML_TENSOR_BINARY_OP_LOCALS;
13945
+ GGML_TENSOR_BINARY_OP_LOCALS
13437
13946
 
13438
13947
  const int ith = params->ith;
13439
13948
  const int nth = params->nth;
@@ -13551,7 +14060,7 @@ static void ggml_compute_forward_conv_1d(
13551
14060
  ggml_compute_forward_conv_1d_s2_ph(params, src0, src1, dst);
13552
14061
  } else {
13553
14062
  GGML_ASSERT(false); // only stride 1 and 2 supported
13554
- };
14063
+ }
13555
14064
  }
13556
14065
 
13557
14066
  // ggml_compute_forward_conv_2d
@@ -13568,7 +14077,7 @@ static void ggml_compute_forward_conv_2d_f16_f32(
13568
14077
  int64_t t0 = ggml_perf_time_us();
13569
14078
  UNUSED(t0);
13570
14079
 
13571
- GGML_TENSOR_BINARY_OP_LOCALS;
14080
+ GGML_TENSOR_BINARY_OP_LOCALS
13572
14081
 
13573
14082
  const int ith = params->ith;
13574
14083
  const int nth = params->nth;
@@ -13688,7 +14197,7 @@ static void ggml_compute_forward_conv_transpose_2d(
13688
14197
  int64_t t0 = ggml_perf_time_us();
13689
14198
  UNUSED(t0);
13690
14199
 
13691
- GGML_TENSOR_BINARY_OP_LOCALS;
14200
+ GGML_TENSOR_BINARY_OP_LOCALS
13692
14201
 
13693
14202
  const int ith = params->ith;
13694
14203
  const int nth = params->nth;
@@ -13947,7 +14456,7 @@ static void ggml_compute_forward_upscale_f32(
13947
14456
 
13948
14457
  const int ith = params->ith;
13949
14458
 
13950
- GGML_TENSOR_UNARY_OP_LOCALS;
14459
+ GGML_TENSOR_UNARY_OP_LOCALS
13951
14460
 
13952
14461
  const int scale_factor = dst->op_params[0];
13953
14462
 
@@ -13999,14 +14508,14 @@ static void ggml_compute_forward_flash_attn_f32(
13999
14508
  int64_t t0 = ggml_perf_time_us();
14000
14509
  UNUSED(t0);
14001
14510
 
14002
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne);
14003
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb);
14004
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne);
14005
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb);
14006
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne);
14007
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb);
14008
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
14009
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
14511
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
14512
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
14513
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
14514
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
14515
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
14516
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
14517
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
14518
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
14010
14519
 
14011
14520
  const int ith = params->ith;
14012
14521
  const int nth = params->nth;
@@ -14076,10 +14585,11 @@ static void ggml_compute_forward_flash_attn_f32(
14076
14585
  S[i] = -INFINITY;
14077
14586
  }
14078
14587
 
14079
- for (int64_t ic = 0; ic < nek1; ++ic) {
14588
+ const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
14589
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
14080
14590
  // k indices
14081
14591
  const int ik3 = iq3;
14082
- const int ik2 = iq2;
14592
+ const int ik2 = iq2 % nek2;
14083
14593
  const int ik1 = ic;
14084
14594
 
14085
14595
  // S indices
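
`masked_begin` is the first key column excluded by the causal mask: with `P` cached tokens and query row `iq1`, keys `0 .. P + iq1` are visible, so only the first `P + iq1 + 1` scores need to be computed, scaled and soft-maxed, and the tail can be written as `-INFINITY` once instead of being scanned. A small worked instance with assumed sizes:

    // P = 4 cached tokens, query row iq1 = 2, M = 16 key columns:
    //   masked_begin = 4 + 2 + 1 = 7
    //   S[0..6]  -> dot products, scaled by `scale`, exponentiated in the softmax
    //   S[7..15] -> -INFINITY, contributing nothing to the softmax sum or the V dot products
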
@@ -14092,20 +14602,18 @@ static void ggml_compute_forward_flash_attn_f32(
14092
14602
  }
14093
14603
 
14094
14604
  // scale
14095
- ggml_vec_scale_f32(nek1, S, scale);
14605
+ ggml_vec_scale_f32(masked_begin, S, scale);
14096
14606
 
14097
- if (masked) {
14098
- for (int64_t i = P; i < M; i++) {
14099
- if (i > P + iq1) {
14100
- S[i] = -INFINITY;
14101
- }
14102
- }
14607
+ for (int64_t i = masked_begin; i < M; i++) {
14608
+ S[i] = -INFINITY;
14103
14609
  }
14104
14610
 
14105
14611
  // softmax
14612
+ // exclude known -INF S[..] values from max and loop
14613
+ // dont forget to set their SW values to zero
+ // don't forget to set their SW values to zero
14106
14614
  {
14107
14615
  float max = -INFINITY;
14108
- ggml_vec_max_f32(M, &max, S);
14616
+ ggml_vec_max_f32(masked_begin, &max, S);
14109
14617
 
14110
14618
  ggml_float sum = 0.0;
14111
14619
  {
@@ -14119,10 +14627,15 @@ static void ggml_compute_forward_flash_attn_f32(
14119
14627
  ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
14120
14628
 
14121
14629
  for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
14630
+ if (i >= masked_begin) {
14631
+ break;
14632
+ }
14122
14633
  float * SS = S + i;
14123
14634
 
14124
14635
  for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
14125
- if (SS[j] == -INFINITY) {
14636
+ if (i + j >= masked_begin) {
14637
+ break;
14638
+ } else if (SS[j] == -INFINITY) {
14126
14639
  SS[j] = 0.0f;
14127
14640
  } else {
14128
14641
  #ifndef GGML_FLASH_ATTN_EXP_FP16
@@ -14147,10 +14660,10 @@ static void ggml_compute_forward_flash_attn_f32(
14147
14660
  assert(sum > 0.0);
14148
14661
 
14149
14662
  sum = 1.0/sum;
14150
- ggml_vec_scale_f32(M, S, sum);
14663
+ ggml_vec_scale_f32(masked_begin, S, sum);
14151
14664
 
14152
14665
  #ifndef NDEBUG
14153
- for (int i = 0; i < M; ++i) {
14666
+ for (int i = 0; i < masked_begin; ++i) {
14154
14667
  assert(!isnan(S[i]));
14155
14668
  assert(!isinf(S[i]));
14156
14669
  }
@@ -14163,9 +14676,13 @@ static void ggml_compute_forward_flash_attn_f32(
14163
14676
  const int i2 = iq2;
14164
14677
  const int i3 = iq3;
14165
14678
 
14166
- ggml_vec_dot_f32(nek1,
14167
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14168
- (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
14679
+ // v indices
14680
+ const int iv2 = iq2 % nev2;
14681
+ const int iv3 = iq3;
14682
+
14683
+ ggml_vec_dot_f32(masked_begin,
14684
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14685
+ (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
14169
14686
  S);
14170
14687
  }
14171
14688
  }
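
The `iq2 % nek2` and `iq2 % nev2` indices in these flash-attention hunks let K and V tensors with fewer heads than Q be broadcast across the query heads (as in grouped-query attention); the mapping wraps in blocks of `nek2`. A tiny illustration with hypothetical head counts:

    // hypothetical head counts: 32 query heads sharing 8 key/value heads
    const int neq2 = 32, nek2 = 8;
    for (int iq2 = 0; iq2 < neq2; ++iq2) {
        const int ik2 = iq2 % nek2;  // q heads 0..7 -> kv heads 0..7, q heads 8..15 -> kv 0..7 again, ...
        (void) ik2;
    }
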
@@ -14181,14 +14698,14 @@ static void ggml_compute_forward_flash_attn_f16(
14181
14698
  int64_t t0 = ggml_perf_time_us();
14182
14699
  UNUSED(t0);
14183
14700
 
14184
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne);
14185
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb);
14186
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne);
14187
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb);
14188
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne);
14189
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb);
14190
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
14191
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
14701
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
14702
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
14703
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
14704
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
14705
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
14706
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
14707
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
14708
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
14192
14709
 
14193
14710
  const int ith = params->ith;
14194
14711
  const int nth = params->nth;
@@ -14262,7 +14779,7 @@ static void ggml_compute_forward_flash_attn_f16(
14262
14779
  for (int64_t ic = 0; ic < nek1; ++ic) {
14263
14780
  // k indices
14264
14781
  const int ik3 = iq3;
14265
- const int ik2 = iq2;
14782
+ const int ik2 = iq2 % nek2;
14266
14783
  const int ik1 = ic;
14267
14784
 
14268
14785
  // S indices
@@ -14277,7 +14794,7 @@ static void ggml_compute_forward_flash_attn_f16(
14277
14794
  for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
14278
14795
  // k indices
14279
14796
  const int ik3 = iq3;
14280
- const int ik2 = iq2;
14797
+ const int ik2 = iq2 % nek2;
14281
14798
  const int ik1 = ic;
14282
14799
 
14283
14800
  // S indices
@@ -14302,6 +14819,8 @@ static void ggml_compute_forward_flash_attn_f16(
14302
14819
  }
14303
14820
 
14304
14821
  // softmax
14822
+ // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
14823
+ // dont forget to set their S values to zero
+ // don't forget to set their S values to zero
14305
14824
  {
14306
14825
  float max = -INFINITY;
14307
14826
  ggml_vec_max_f32(M, &max, S);
@@ -14358,6 +14877,7 @@ static void ggml_compute_forward_flash_attn_f16(
14358
14877
  S16[i] = GGML_FP32_TO_FP16(S[i]);
14359
14878
  }
14360
14879
 
14880
+ // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
14361
14881
  if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
14362
14882
  for (int64_t ic = 0; ic < nev1; ++ic) {
14363
14883
  // dst indices
@@ -14365,9 +14885,13 @@ static void ggml_compute_forward_flash_attn_f16(
14365
14885
  const int i2 = iq2;
14366
14886
  const int i3 = iq3;
14367
14887
 
14368
- ggml_vec_dot_f16(nek1,
14369
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14370
- (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
14888
+ // v indices
14889
+ const int iv2 = iq2 % nev2;
14890
+ const int iv3 = iq3;
14891
+
14892
+ ggml_vec_dot_f16(nev0,
14893
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14894
+ (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
14371
14895
  S16);
14372
14896
  }
14373
14897
  } else {
@@ -14377,9 +14901,13 @@ static void ggml_compute_forward_flash_attn_f16(
14377
14901
  const int i2 = iq2;
14378
14902
  const int i3 = iq3;
14379
14903
 
14380
- ggml_vec_dot_f16_unroll(nek1, nbv1,
14381
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14382
- ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
14904
+ // v indices
14905
+ const int iv2 = iq2 % nev2;
14906
+ const int iv3 = iq3;
14907
+
14908
+ ggml_vec_dot_f16_unroll(nev0, nbv1,
14909
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
14910
+ ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
14383
14911
  S16);
14384
14912
  }
14385
14913
  }
@@ -14422,18 +14950,18 @@ static void ggml_compute_forward_flash_ff_f16(
14422
14950
  int64_t t0 = ggml_perf_time_us();
14423
14951
  UNUSED(t0);
14424
14952
 
14425
- GGML_TENSOR_LOCALS(int64_t, nea, a, ne);
14426
- GGML_TENSOR_LOCALS(size_t, nba, a, nb);
14427
- GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne);
14428
- GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb);
14429
- GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne);
14430
- GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb);
14431
- GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne);
14432
- GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb);
14433
- GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne);
14434
- GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb);
14435
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
14436
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
14953
+ GGML_TENSOR_LOCALS(int64_t, nea, a, ne)
14954
+ GGML_TENSOR_LOCALS(size_t, nba, a, nb)
14955
+ GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne)
14956
+ GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb)
14957
+ GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne)
14958
+ GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb)
14959
+ GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne)
14960
+ GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb)
14961
+ GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne)
14962
+ GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb)
14963
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
14964
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
14437
14965
 
14438
14966
  const int ith = params->ith;
14439
14967
  const int nth = params->nth;
@@ -14581,16 +15109,16 @@ static void ggml_compute_forward_flash_attn_back_f32(
14581
15109
  int64_t t0 = ggml_perf_time_us();
14582
15110
  UNUSED(t0);
14583
15111
 
14584
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne);
14585
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb);
14586
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne);
14587
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb);
14588
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne);
14589
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb);
14590
- GGML_TENSOR_LOCALS(int64_t, ned, d, ne);
14591
- GGML_TENSOR_LOCALS(size_t, nbd, d, nb);
14592
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
14593
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
15112
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15113
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15114
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15115
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15116
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15117
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15118
+ GGML_TENSOR_LOCALS(int64_t, ned, d, ne)
15119
+ GGML_TENSOR_LOCALS(size_t, nbd, d, nb)
15120
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15121
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
14594
15122
 
14595
15123
  const int ith = params->ith;
14596
15124
  const int nth = params->nth;
@@ -14638,10 +15166,37 @@ static void ggml_compute_forward_flash_attn_back_f32(
14638
15166
  return;
14639
15167
  }
14640
15168
 
14641
- // parallelize by q rows using ggml_vec_dot_f32
15169
+ const int64_t elem_q = ggml_nelements(q);
15170
+ const int64_t elem_k = ggml_nelements(k);
14642
15171
 
14643
- // total rows in q
14644
- const int nr = neq2*neq3;
15172
+ enum ggml_type result_type = dst->type;
15173
+ GGML_ASSERT(ggml_blck_size(result_type) == 1);
15174
+ const size_t tsize = ggml_type_size(result_type);
15175
+
15176
+ const size_t offs_q = 0;
15177
+ const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
15178
+ const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
15179
+
15180
+ void * grad_q = (char *) dst->data;
15181
+ void * grad_k = (char *) dst->data + offs_k;
15182
+ void * grad_v = (char *) dst->data + offs_v;
15183
+
15184
+ const size_t nbgq1 = nb0*neq0;
15185
+ const size_t nbgq2 = nb0*neq0*neq1;
15186
+ const size_t nbgq3 = nb0*neq0*neq1*neq2;
15187
+
15188
+ const size_t nbgk1 = nb0*nek0;
15189
+ const size_t nbgk2 = nb0*nek0*nek1;
15190
+ const size_t nbgk3 = nb0*nek0*nek1*neq2;
15191
+
15192
+ const size_t nbgv1 = nb0*nev0;
15193
+ const size_t nbgv2 = nb0*nev0*nev1;
15194
+ const size_t nbgv3 = nb0*nev0*nev1*neq2;
15195
+
15196
+ // parallelize by k rows using ggml_vec_dot_f32
15197
+
15198
+ // total rows in k
15199
+ const int nr = nek2*nek3;
14645
15200
 
14646
15201
  // rows per thread
14647
15202
  const int dr = (nr + nth - 1)/nth;
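
Two related changes here: the three gradients are packed into `dst->data` at aligned offsets, and the work is split over K rows rather than Q rows, so the `nrep = neq2/nek2` query heads that share one K/V head are all handled by the thread owning that head and the `grad_k`/`grad_v` accumulations never race across threads. A sketch of the packed layout, reusing the diff's own `elem_q`, `elem_k` and `tsize` names:

    // byte ranges inside dst->data (GGML_PAD rounds up to GGML_MEM_ALIGN):
    //   [0,       offs_k)  grad_q  -- elem_q elements of `tsize` bytes, padded
    //   [offs_k,  offs_v)  grad_k  -- elem_k elements, padded
    //   [offs_v,  ...   )  grad_v
    const size_t offs_k = GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
    const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
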
@@ -14654,268 +15209,243 @@ static void ggml_compute_forward_flash_attn_back_f32(
14654
15209
 
14655
15210
  //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
14656
15211
 
15212
+ // how often k2 (and v2) is repeated in q2
15213
+ int nrep = neq2/nek2;
15214
+
14657
15215
  for (int ir = ir0; ir < ir1; ++ir) {
14658
15216
  // q indices
14659
- const int iq3 = ir/(neq2);
14660
- const int iq2 = ir - iq3*neq2;
14661
- for ( int iq1 = 0; iq1 < neq1; ++iq1) {
15217
+ const int ik3 = ir/(nek2);
15218
+ const int ik2 = ir - ik3*nek2;
14662
15219
 
15220
+ const int iq3 = ik3;
15221
+ const int id3 = ik3;
15222
+ const int iv3 = ik3;
15223
+ const int iv2 = ik2;
14663
15224
 
14664
- // not sure about CACHE_LINE_SIZE_F32..
14665
- // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
14666
- float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
14667
- float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
15225
+ for (int irep = 0; irep < nrep; ++irep) {
15226
+ const int iq2 = ik2 + irep*nek2;
15227
+ const int id2 = iq2;
14668
15228
 
14669
- for (int i = M; i < Mup; ++i) {
14670
- S[i] = -INFINITY;
14671
- }
15229
+ // (ik2 + irep*nek2) % nek2 == ik2
15230
+ for (int iq1 = 0; iq1 < neq1; ++iq1) {
15231
+ const int id1 = iq1;
14672
15232
 
14673
- for (int64_t ic = 0; ic < nek1; ++ic) {
14674
- // k indices
14675
- const int ik3 = iq3;
14676
- const int ik2 = iq2;
14677
- const int ik1 = ic;
15233
+ // not sure about CACHE_LINE_SIZE_F32..
15234
+ // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
15235
+ float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
15236
+ float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
14678
15237
 
14679
- // S indices
14680
- const int i1 = ik1;
15238
+ for (int i = M; i < Mup; ++i) {
15239
+ S[i] = -INFINITY;
15240
+ }
14681
15241
 
14682
- ggml_vec_dot_f32(neq0,
14683
- S + i1,
14684
- (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
14685
- (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
14686
- }
15242
+ const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
15243
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
15244
+ // k indices
15245
+ const int ik1 = ic;
14687
15246
 
14688
- // scale
14689
- ggml_vec_scale_f32(nek1, S, scale);
15247
+ // S indices
15248
+ const int i1 = ik1;
14690
15249
 
14691
- if (masked) {
14692
- for (int64_t i = P; i < M; i++) {
14693
- if (i > P + iq1) {
14694
- S[i] = -INFINITY;
14695
- }
15250
+ ggml_vec_dot_f32(neq0,
15251
+ S + i1,
15252
+ (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
15253
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
14696
15254
  }
14697
- }
14698
15255
 
14699
- // softmax
14700
- {
14701
- float max = -INFINITY;
14702
- ggml_vec_max_f32(M, &max, S);
15256
+ // scale
15257
+ ggml_vec_scale_f32(masked_begin, S, scale);
14703
15258
 
14704
- ggml_float sum = 0.0;
15259
+ for (int64_t i = masked_begin; i < M; i++) {
15260
+ S[i] = -INFINITY;
15261
+ }
15262
+
15263
+ // softmax
15264
+ // exclude known -INF S[..] values from max and loop
15265
+ // dont forget to set their SM values to zero
+ // don't forget to set their SM values to zero
14705
15266
  {
15267
+ float max = -INFINITY;
15268
+ ggml_vec_max_f32(masked_begin, &max, S);
15269
+
15270
+ ggml_float sum = 0.0;
15271
+ {
14706
15272
  #ifdef GGML_SOFT_MAX_ACCELERATE
14707
- max = -max;
14708
- vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
14709
- vvexpf(SM, SM, &Mup);
14710
- ggml_vec_sum_f32(Mup, &sum, SM);
15273
+ max = -max;
15274
+ vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
15275
+ vvexpf(SM, SM, &Mup);
15276
+ ggml_vec_sum_f32(Mup, &sum, SM);
14711
15277
  #else
14712
- uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
14713
- ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
15278
+ uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
15279
+ ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
14714
15280
 
14715
- for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
14716
- float * SR = S + i;
14717
- float * SW = SM + i;
14718
-
14719
- for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
14720
- if (SR[j] == -INFINITY) {
14721
- SW[j] = 0.0f;
14722
- } else {
15281
+ for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
15282
+ if (i >= masked_begin) {
15283
+ break;
15284
+ }
15285
+ float * SR = S + i;
15286
+ float * SW = SM + i;
15287
+
15288
+ for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
15289
+ if (i + j >= masked_begin) {
15290
+ break;
15291
+ } else if (SR[j] == -INFINITY) {
15292
+ SW[j] = 0.0f;
15293
+ } else {
14723
15294
  #ifndef GGML_FLASH_ATTN_EXP_FP16
14724
- const float val = expf(SR[j] - max);
15295
+ const float val = expf(SR[j] - max);
14725
15296
  #else
14726
- ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
14727
- memcpy(&scvt[j], &s, sizeof(uint16_t));
14728
- const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
15297
+ ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
15298
+ memcpy(&scvt[j], &s, sizeof(uint16_t));
15299
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
14729
15300
  #endif
14730
- sump[j] += (ggml_float)val;
14731
- SW[j] = val;
15301
+ sump[j] += (ggml_float)val;
15302
+ SW[j] = val;
15303
+ }
14732
15304
  }
14733
15305
  }
14734
- }
14735
15306
 
14736
- for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
14737
- sum += sump[i];
14738
- }
15307
+ for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
15308
+ sum += sump[i];
15309
+ }
14739
15310
  #endif
14740
- }
14741
-
14742
- assert(sum > 0.0);
14743
-
14744
- sum = 1.0/sum;
14745
- ggml_vec_scale_f32(M, SM, sum);
14746
-
14747
- }
14748
-
14749
- // step-by-step explanation
14750
- {
14751
- // forward-process shape grads from backward process
14752
- // parallel_for iq2,iq3:
14753
- // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur]
14754
- // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
14755
- // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur]
14756
- // for iq1:
14757
- // kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
14758
- // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
14759
- // vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
14760
- // S0 = -Inf [D,1,1,1]
14761
- // ~S1[i] = dot(kcur[:D,i], qcur)
14762
- // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
14763
- // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
14764
- // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
14765
- // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
14766
- // ~S5[i] = dot(vcur[:,i], S4)
14767
- // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
14768
- // ~dst[i,iq1,iq2,iq3] = S5[i] ^
14769
- // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
14770
- // dst backward-/ grad[dst] = d
14771
- //
14772
- // output gradients with their dependencies:
14773
- //
14774
- // grad[kcur] = grad[S1].T @ qcur
14775
- // grad[S1] = diag_mask_zero(grad[S3], P) * scale
14776
- // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
14777
- // grad[S4] = grad[S5] @ vcur
14778
- // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
14779
- // grad[qcur] = grad[S1] @ kcur
14780
- // grad[vcur] = grad[S5].T @ S4
14781
- // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
14782
- //
14783
- // in post-order:
14784
- //
14785
- // S1 = qcur @ kcur.T
14786
- // S2 = S1 * scale
14787
- // S3 = diag_mask_inf(S2, P)
14788
- // S4 = softmax(S3)
14789
- // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
14790
- // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
14791
- // grad[S1] = diag_mask_zero(grad[S3], P) * scale
14792
- // grad[qcur] = grad[S1] @ kcur
14793
- // grad[kcur] = grad[S1].T @ qcur
14794
- // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
14795
- //
14796
- // using less variables (SM=S4):
14797
- //
14798
- // S = diag_mask_inf(qcur @ kcur.T * scale, P)
14799
- // SM = softmax(S)
14800
- // S = d[:D,iq1,iq2,iq3] @ vcur
14801
- // dot_SM_gradSM = dot(SM, S)
14802
- // S = SM * (S - dot(SM, S))
14803
- // S = diag_mask_zero(S, P) * scale
14804
- //
14805
- // grad[q][:D,iq1,iq2,iq3] += S @ kcur
14806
- // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
14807
- // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
14808
- }
14809
-
14810
- // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
14811
- // S = d[:D,iq1,iq2,iq3] @ vcur
14812
- // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
14813
- ggml_vec_set_f32(M, S, 0);
14814
- for (int64_t ic = 0; ic < D; ++ic) {
14815
- // dst indices
14816
- const int i1 = iq1;
14817
- const int i2 = iq2;
14818
- const int i3 = iq3;
15311
+ }
14819
15312
 
14820
- ggml_vec_mad_f32(M,
14821
- S,
14822
- (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
14823
- *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
14824
- }
15313
+ assert(sum > 0.0);
14825
15314
 
14826
- // S = SM * (S - dot(SM, S))
14827
- float dot_SM_gradSM = 0;
14828
- ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
14829
- ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
14830
- ggml_vec_mul_f32 (M, S, S, SM);
15315
+ sum = 1.0/sum;
15316
+ ggml_vec_scale_f32(masked_begin, SM, sum);
14831
15317
 
14832
- // S = diag_mask_zero(S, P) * scale
14833
- if (masked) {
14834
- // for (int64_t i = P + iq1 + 1; i < M; i++) {
14835
- // S[i] = 0;
14836
- // }
14837
- for (int64_t i = P; i < M; i++) {
14838
- if (i > P + iq1) {
14839
- S[i] = 0;
14840
- }
14841
15318
  }
14842
- }
14843
- ggml_vec_scale_f32(M, S, scale);
14844
-
14845
- void * grad_q = (char *) dst->data;
14846
- void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
14847
- void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
14848
-
14849
- const size_t nbgq1 = nb0*neq0;
14850
- const size_t nbgq2 = nb0*neq0*neq1;
14851
- const size_t nbgq3 = nb0*neq0*neq1*neq2;
14852
-
14853
- const size_t nbgk1 = nb0*nek0;
14854
- const size_t nbgk2 = nb0*nek0*nek1;
14855
- const size_t nbgk3 = nb0*nek0*nek1*neq2;
14856
-
14857
- const size_t nbgv1 = nb0*nev0;
14858
- const size_t nbgv2 = nb0*nev0*nev1;
14859
- const size_t nbgv3 = nb0*nev0*nev1*neq2;
14860
-
14861
- // S shape [M,1]
14862
- // SM shape [M,1]
14863
- // kcur shape [D,M]
14864
- // qcur shape [D,1]
14865
- // vcur shape [M,D]
14866
- //
14867
- // grad[q][:D,iq1,iq2,iq3] += S @ kcur
14868
- // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
14869
- // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
14870
- //
14871
- //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
14872
- //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
14873
- for (int64_t ic = 0; ic < M; ++ic) {
14874
- // dst indices
14875
- const int i1 = iq1;
14876
- const int i2 = iq2;
14877
- const int i3 = iq3;
14878
15319
 
14879
- ggml_vec_mad_f32(D,
14880
- (float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
14881
- (float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)),
14882
- S[ic]);
14883
- }
15320
+ // step-by-step explanation
15321
+ {
15322
+ // forward-process shape grads from backward process
15323
+ // parallel_for ik2,ik3:
15324
+ // for irep:
15325
+ // iq2 = ik2 + irep*nek2
15326
+ // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur]
15327
+ // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
15328
+ // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur]
15329
+ // for iq1:
15330
+ // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
15331
+ // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
15332
+ // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
15333
+ // S0 = -Inf [D,1,1,1]
15334
+ // ~S1[i] = dot(kcur[:D,i], qcur)
15335
+ // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
15336
+ // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
15337
+ // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
15338
+ // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
15339
+ // ~S5[i] = dot(vcur[:,i], S4)
15340
+ // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3]
15341
+ // ~dst[i,iq1,iq2,iq3] = S5[i] ^
15342
+ // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
15343
+ // dst backward-/ grad[dst] = d
15344
+ //
15345
+ // output gradients with their dependencies:
15346
+ //
15347
+ // grad[kcur] = grad[S1].T @ qcur
15348
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
15349
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
15350
+ // grad[S4] = grad[S5] @ vcur
15351
+ // grad[S4] = d[:D,id1,id2,id3] @ vcur
15352
+ // grad[qcur] = grad[S1] @ kcur
15353
+ // grad[vcur] = grad[S5].T @ S4
15354
+ // grad[vcur] = d[:D,id1,id2,id3].T @ S4
15355
+ //
15356
+ // in post-order:
15357
+ //
15358
+ // S1 = qcur @ kcur.T
15359
+ // S2 = S1 * scale
15360
+ // S3 = diag_mask_inf(S2, P)
15361
+ // S4 = softmax(S3)
15362
+ // grad[S4] = d[:D,id1,id2,id3] @ vcur
15363
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
15364
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
15365
+ // grad[qcur] = grad[S1] @ kcur
15366
+ // grad[kcur] = grad[S1].T @ qcur
15367
+ // grad[vcur] = d[:D,id1,id2,id3].T @ S4
15368
+ //
15369
+ // using less variables (SM=S4):
15370
+ //
15371
+ // S = diag_mask_inf(qcur @ kcur.T * scale, P)
15372
+ // SM = softmax(S)
15373
+ // S = d[:D,iq1,iq2,iq3] @ vcur
15374
+ // dot_SM_gradSM = dot(SM, S)
15375
+ // S = SM * (S - dot(SM, S))
15376
+ // S = diag_mask_zero(S, P) * scale
15377
+ //
15378
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
15379
+ // grad[k][:D,:M,ik2,ik3] += S.T @ qcur
15380
+ // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
15381
+ }
14884
15382
 
14885
- // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
14886
- // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
14887
- // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
14888
- for (int64_t ic = 0; ic < M; ++ic) {
14889
- // dst indices
14890
- const int i1 = iq1;
14891
- const int i2 = iq2;
14892
- const int i3 = iq3;
15383
+ // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
15384
+ // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
15385
+ // for ic:
15386
+ // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
15387
+ // exclude known future zero S[..] values from operation
15388
+ ggml_vec_set_f32(masked_begin, S, 0);
15389
+ for (int64_t ic = 0; ic < D; ++ic) {
15390
+ ggml_vec_mad_f32(masked_begin,
15391
+ S,
15392
+ (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
15393
+ *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
15394
+ }
14893
15395
 
14894
- // ggml_vec_set_f32(D,
14895
- // (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
14896
- // 0);
14897
- ggml_vec_mad_f32(D,
14898
- (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
14899
- (float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)),
14900
- S[ic]);
14901
- }
15396
+ // S = SM * (S - dot(SM, S))
15397
+ float dot_SM_gradSM = 0;
15398
+ ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S);
15399
+ ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
15400
+ ggml_vec_mul_f32 (masked_begin, S, S, SM);
15401
+
15402
+ // S = diag_mask_zero(S, P) * scale
15403
+ // already done by above ggml_vec_set_f32
15404
+
15405
+ // exclude known zero S[..] values from operation
15406
+ ggml_vec_scale_f32(masked_begin, S, scale);
15407
+
15408
+ // S shape [M,1]
15409
+ // SM shape [M,1]
15410
+ // kcur shape [D,M]
15411
+ // qcur shape [D,1]
15412
+ // vcur shape [M,D]
15413
+
15414
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
15415
+ // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
15416
+ // for ic:
15417
+ // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
15418
+ // exclude known zero S[..] values from loop
15419
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
15420
+ ggml_vec_mad_f32(D,
15421
+ (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
15422
+ (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
15423
+ S[ic]);
15424
+ }
14902
15425
 
14903
- // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
14904
- // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
14905
- // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M]
14906
- for (int64_t ic = 0; ic < D; ++ic) {
14907
- // dst indices
14908
- const int i1 = iq1;
14909
- const int i2 = iq2;
14910
- const int i3 = iq3;
15426
+ // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
15427
+ // for ic:
15428
+ // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
15429
+ // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
15430
+ // exclude known zero S[..] values from loop
15431
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
15432
+ ggml_vec_mad_f32(D,
15433
+ (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
15434
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
15435
+ S[ic]);
15436
+ }
14911
15437
 
14912
- // ggml_vec_set_f32(M,
14913
- // (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
14914
- // 0);
14915
- ggml_vec_mad_f32(M,
14916
- (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
14917
- SM,
14918
- *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
15438
+ // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
15439
+ // for ic:
15440
+ // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
15441
+ // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
15442
+ // exclude known zero SM[..] values from mad
15443
+ for (int64_t ic = 0; ic < D; ++ic) {
15444
+ ggml_vec_mad_f32(masked_begin,
15445
+ (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
15446
+ SM,
15447
+ *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
15448
+ }
14919
15449
  }
14920
15450
  }
14921
15451
  }
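
The identity `grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))` used in the step-by-step comment (and implemented through `dot_SM_gradSM` above) is the standard softmax backward rule. With \(y = \mathrm{softmax}(s)\) and upstream gradient \(g = \partial L / \partial y\):

\[
\frac{\partial L}{\partial s_i}
= \sum_j g_j \frac{\partial y_j}{\partial s_i}
= \sum_j g_j\, y_j (\delta_{ij} - y_i)
= y_i \Bigl( g_i - \sum_j y_j g_j \Bigr),
\]

which in the code's notation is exactly `SM * (S - dot(SM, S))`.
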
@@ -14951,8 +15481,8 @@ static void ggml_compute_forward_win_part_f32(
14951
15481
  return;
14952
15482
  }
14953
15483
 
14954
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
14955
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
15484
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
15485
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
14956
15486
 
14957
15487
  const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
14958
15488
  const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
@@ -15013,8 +15543,8 @@ static void ggml_compute_forward_win_unpart_f32(
15013
15543
  return;
15014
15544
  }
15015
15545
 
15016
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
15017
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
15546
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
15547
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15018
15548
 
15019
15549
  const int32_t w = ((const int32_t *)(dst->op_params))[0];
15020
15550
 
@@ -15131,7 +15661,7 @@ static void ggml_compute_forward_get_rel_pos_f16(
15131
15661
 
15132
15662
  // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
15133
15663
 
15134
- GGML_TENSOR_UNARY_OP_LOCALS;
15664
+ GGML_TENSOR_UNARY_OP_LOCALS
15135
15665
 
15136
15666
  const int64_t w = ne1;
15137
15667
 
@@ -15829,7 +16359,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
15829
16359
  } break;
15830
16360
  case GGML_OP_GET_ROWS_BACK:
15831
16361
  {
15832
- ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
16362
+ ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor);
15833
16363
  } break;
15834
16364
  case GGML_OP_DIAG:
15835
16365
  {
@@ -15853,11 +16383,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
15853
16383
  } break;
15854
16384
  case GGML_OP_ROPE:
15855
16385
  {
15856
- ggml_compute_forward_rope(params, tensor->src[0], tensor);
16386
+ ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor);
15857
16387
  } break;
15858
16388
  case GGML_OP_ROPE_BACK:
15859
16389
  {
15860
- ggml_compute_forward_rope_back(params, tensor->src[0], tensor);
16390
+ ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor);
15861
16391
  } break;
15862
16392
  case GGML_OP_ALIBI:
15863
16393
  {
@@ -16002,7 +16532,218 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
16002
16532
 
16003
16533
  ////////////////////////////////////////////////////////////////////////////////
16004
16534
 
16005
- static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
16535
+ static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
16536
+
16537
+ static size_t hash(void * p) {
16538
+ return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
16539
+ }
16540
+
16541
+ static size_t hash_find(void * hash_table[], void * p) {
16542
+ size_t h = hash(p);
16543
+
16544
+ // linear probing
16545
+ size_t i = h;
16546
+ while (hash_table[i] != NULL && hash_table[i] != p) {
16547
+ i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
16548
+ if (i == h) {
16549
+ // visited all hash table entries -> not found
16550
+ return GGML_GRAPH_HASHTABLE_SIZE;
16551
+ }
16552
+ }
16553
+ return i;
16554
+ }
16555
+
16556
+ static bool hash_insert(void * hash_table[], void * p) {
16557
+ size_t i = hash_find(hash_table, p);
16558
+
16559
+ GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
16560
+
16561
+ if (hash_table[i] == p) {
16562
+ return true;
16563
+ }
16564
+
16565
+ // insert
16566
+ GGML_ASSERT(hash_table[i] == NULL);
16567
+ hash_table[i] = p;
16568
+ return false;
16569
+ }
16570
+
16571
+ static bool hash_contains(void * hash_table[], void * p) {
16572
+ size_t i = hash_find(hash_table, p);
16573
+ return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
16574
+ }
16575
+
16576
+ struct hash_map {
16577
+ void * keys[GGML_GRAPH_HASHTABLE_SIZE];
16578
+ void * vals[GGML_GRAPH_HASHTABLE_SIZE];
16579
+ };
16580
+
16581
+ static struct hash_map * new_hash_map(void) {
16582
+ struct hash_map * result = malloc(sizeof(struct hash_map));
16583
+ for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
16584
+ result->keys[i] = NULL;
16585
+ result->vals[i] = NULL;
16586
+ }
16587
+ return result;
16588
+ }
16589
+
16590
+ static void free_hash_map(struct hash_map * map) {
16591
+ free(map);
16592
+ }
16593
+
16594
+ // gradient checkpointing
16595
+
16596
+ static struct ggml_tensor * ggml_recompute_graph_node(
16597
+ struct ggml_context * ctx,
16598
+ struct ggml_cgraph * graph,
16599
+ struct hash_map * replacements,
16600
+ struct ggml_tensor * node) {
16601
+
16602
+ if (node == NULL) {
16603
+ return NULL;
16604
+ }
16605
+
16606
+ if (node->is_param) {
16607
+ return node;
16608
+ }
16609
+
16610
+ if (!hash_contains(graph->visited_hash_table, node)) {
16611
+ return node;
16612
+ }
16613
+
16614
+ int count_children = 0;
16615
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
16616
+ if (node->src[k]) {
16617
+ ++count_children;
16618
+ }
16619
+ }
16620
+
16621
+ if (count_children == 0) {
16622
+ return node;
16623
+ }
16624
+
16625
+ size_t i = hash_find(replacements->keys, node);
16626
+ GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
16627
+ if (replacements->keys[i] == node) {
16628
+ return (struct ggml_tensor *) replacements->vals[i];
16629
+ }
16630
+
16631
+ struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
16632
+
16633
+ // insert clone into replacements
16634
+ GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
16635
+ replacements->keys[i] = node;
16636
+ replacements->vals[i] = clone;
16637
+
16638
+ clone->op = node->op;
16639
+ clone->grad = node->grad;
16640
+ clone->is_param = node->is_param;
16641
+ clone->extra = node->extra;
16642
+ for (int k = 0; k < GGML_MAX_DIMS; ++k) {
16643
+ clone->nb[k] = node->nb[k];
16644
+ }
16645
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
16646
+ clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
16647
+ }
16648
+ if (node->view_src != NULL) {
16649
+ clone->data = (node->view_src->data == NULL)
16650
+ ? NULL // view_src not yet allocated
16651
+ : (char *) node->view_src->data // view_src already allocated
16652
+ + node->view_offs;
16653
+ clone->view_src = node->view_src;
16654
+ clone->view_offs = node->view_offs;
16655
+ }
16656
+
16657
+ GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
16658
+ GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
16659
+ memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
16660
+ ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
16661
+
16662
+ return clone;
16663
+ }
16664
+
16665
+ void ggml_build_backward_gradient_checkpointing(
16666
+ struct ggml_context * ctx,
16667
+ struct ggml_cgraph * gf,
16668
+ struct ggml_cgraph * gb,
16669
+ struct ggml_cgraph * gb_tmp,
16670
+ struct ggml_tensor * * checkpoints,
16671
+ int n_checkpoints) {
16672
+ *gb_tmp = *gf;
16673
+ ggml_build_backward_expand(ctx, gf, gb_tmp, true);
16674
+
16675
+ if (n_checkpoints <= 0) {
16676
+ *gb = *gb_tmp;
16677
+ return;
16678
+ }
16679
+
16680
+ struct hash_map * replacements = new_hash_map();
16681
+
16682
+ // insert checkpoints in replacements
16683
+ for (int i = 0; i < n_checkpoints; ++i) {
16684
+ size_t k = hash_find(replacements->keys, checkpoints[i]);
16685
+ GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
16686
+ GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
16687
+ replacements->keys[k] = checkpoints[i];
16688
+ replacements->vals[k] = checkpoints[i];
16689
+ }
16690
+
16691
+ *gb = *gf;
16692
+ // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
16693
+ // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
16694
+ // by recomputing them from checkpoints
16695
+ for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
16696
+ struct ggml_tensor * node = gb_tmp->nodes[i];
16697
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
16698
+ // insert new tensors recomputing src, reusing already made replacements,
16699
+ // remember replacements: remember new tensors with mapping from corresponding gf nodes
16700
+ // recurse for input tensors,
16701
+ // unless (i.e. terminating when) input tensors are replacments (like checkpoints)
+ // unless (i.e. terminating when) input tensors are replacements (like checkpoints)
16702
+ node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
16703
+ }
16704
+ // insert rewritten backward node with replacements made into resulting backward graph gb
16705
+ ggml_build_forward_expand(gb, node);
16706
+ }
16707
+
16708
+ free_hash_map(replacements);
16709
+ }
16710
+
16711
+ // functions to change gradients considering the case that input a might be initial gradient with zero value
16712
+
16713
+ static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
16714
+ if (hash_contains(zero_table, a)) {
16715
+ return b;
16716
+ } else {
16717
+ return ggml_add_impl(ctx, a, b, false);
16718
+ }
16719
+ }
16720
+
16721
+ static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, void * zero_table[]) {
16722
+ if (hash_contains(zero_table, a)) {
16723
+ struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0));
16724
+ return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
16725
+ } else {
16726
+ return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
16727
+ }
16728
+ }
16729
+
16730
+ static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
16731
+ if (hash_contains(zero_table, a)) {
16732
+ return ggml_repeat(ctx, b, a);
16733
+ } else {
16734
+ return ggml_add1_impl(ctx, a, b, false);
16735
+ }
16736
+ }
16737
+
16738
+ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
16739
+ if (hash_contains(zero_table, a)) {
16740
+ return ggml_neg(ctx, b);
16741
+ } else {
16742
+ return ggml_sub_impl(ctx, a, b, false);
16743
+ }
16744
+ }
16745
+
16746
+ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, void * zero_table[]) {
16006
16747
  struct ggml_tensor * src0 = tensor->src[0];
16007
16748
  struct ggml_tensor * src1 = tensor->src[1];
16008
16749
 
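
The `*_or_set` helpers above exist because, with the reworked backward pass, a tensor's gradient may still be the initial zero placeholder recorded in `zero_table`; in that case the accumulated gradient is simply the new contribution, and no `0 + x` node has to be added to the graph. Conceptually (names taken from the diff, the shorthand itself is only illustrative):

    // gradient accumulation as ggml_compute_backward now performs it, in sketch form:
    // src0->grad = hash_contains(zero_table, src0->grad)
    //            ? contribution                                          // first write: adopt it
    //            : ggml_add_impl(ctx, src0->grad, contribution, false);  // later writes: real add node
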
@@ -16010,34 +16751,34 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16010
16751
  case GGML_OP_DUP:
16011
16752
  {
16012
16753
  if (src0->grad) {
16013
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
16754
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16014
16755
  }
16015
16756
  } break;
16016
16757
  case GGML_OP_ADD:
16017
16758
  {
16018
16759
  if (src0->grad) {
16019
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
16760
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16020
16761
  }
16021
16762
  if (src1->grad) {
16022
- src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace);
16763
+ src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
16023
16764
  }
16024
16765
  } break;
16025
16766
  case GGML_OP_ADD1:
16026
16767
  {
16027
16768
  if (src0->grad) {
16028
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
16769
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16029
16770
  }
16030
16771
  if (src1->grad) {
16031
- src1->grad = ggml_add_impl(ctx,
16772
+ src1->grad = ggml_add_or_set(ctx,
16032
16773
  src1->grad,
16033
16774
  ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
16034
- inplace);
16775
+ zero_table);
16035
16776
  }
16036
16777
  } break;
16037
16778
  case GGML_OP_ACC:
16038
16779
  {
16039
16780
  if (src0->grad) {
16040
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
16781
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16041
16782
  }
16042
16783
  if (src1->grad) {
16043
16784
  const size_t nb1 = ((int32_t *) tensor->op_params)[0];
@@ -16054,117 +16795,117 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16054
16795
  nb1, nb2, nb3, offset);
16055
16796
 
16056
16797
  src1->grad =
16057
- ggml_add_impl(ctx,
16798
+ ggml_add_or_set(ctx,
16058
16799
  src1->grad,
16059
16800
  ggml_reshape(ctx,
16060
16801
  ggml_cont(ctx, tensor_grad_view),
16061
16802
  src1->grad),
16062
- inplace);
16803
+ zero_table);
16063
16804
  }
16064
16805
  } break;
16065
16806
  case GGML_OP_SUB:
16066
16807
  {
16067
16808
  if (src0->grad) {
16068
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
16809
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16069
16810
  }
16070
16811
  if (src1->grad) {
16071
- src1->grad = ggml_sub_impl(ctx, src1->grad, tensor->grad, inplace);
16812
+ src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table);
16072
16813
  }
16073
16814
  } break;
16074
16815
  case GGML_OP_MUL:
16075
16816
  {
16076
16817
  if (src0->grad) {
16077
16818
  src0->grad =
16078
- ggml_add_impl(ctx,
16819
+ ggml_add_or_set(ctx,
16079
16820
  src0->grad,
16080
16821
  ggml_mul(ctx, src1, tensor->grad),
16081
- inplace);
16822
+ zero_table);
16082
16823
  }
16083
16824
  if (src1->grad) {
16084
16825
  src1->grad =
16085
- ggml_add_impl(ctx,
16826
+ ggml_add_or_set(ctx,
16086
16827
  src1->grad,
16087
16828
  ggml_mul(ctx, src0, tensor->grad),
16088
- inplace);
16829
+ zero_table);
16089
16830
  }
16090
16831
  } break;
16091
16832
  case GGML_OP_DIV:
16092
16833
  {
16093
16834
  if (src0->grad) {
16094
16835
  src0->grad =
16095
- ggml_add_impl(ctx,
16836
+ ggml_add_or_set(ctx,
16096
16837
  src0->grad,
16097
16838
  ggml_div(ctx, tensor->grad, src1),
16098
- inplace);
16839
+ zero_table);
16099
16840
  }
16100
16841
  if (src1->grad) {
16101
16842
  src1->grad =
16102
- ggml_sub_impl(ctx,
16843
+ ggml_sub_or_set(ctx,
16103
16844
  src1->grad,
16104
16845
  ggml_mul(ctx,
16105
16846
  tensor->grad,
16106
16847
  ggml_div(ctx, tensor, src1)),
16107
- inplace);
16848
+ zero_table);
16108
16849
  }
16109
16850
  } break;
16110
16851
  case GGML_OP_SQR:
16111
16852
  {
16112
16853
  if (src0->grad) {
16113
16854
  src0->grad =
16114
- ggml_add_impl(ctx,
16855
+ ggml_add_or_set(ctx,
16115
16856
  src0->grad,
16116
16857
  ggml_scale(ctx,
16117
16858
  ggml_mul(ctx, src0, tensor->grad),
16118
16859
  ggml_new_f32(ctx, 2.0f)),
16119
- inplace);
16860
+ zero_table);
16120
16861
  }
16121
16862
  } break;
16122
16863
  case GGML_OP_SQRT:
16123
16864
  {
16124
16865
  if (src0->grad) {
16125
16866
  src0->grad =
16126
- ggml_add_impl(ctx,
16867
+ ggml_add_or_set(ctx,
16127
16868
  src0->grad,
16128
16869
  ggml_scale(ctx,
16129
16870
  ggml_div(ctx,
16130
16871
  tensor->grad,
16131
16872
  tensor),
16132
16873
  ggml_new_f32(ctx, 0.5f)),
16133
- inplace);
16874
+ zero_table);
16134
16875
  }
16135
16876
  } break;
16136
16877
  case GGML_OP_LOG:
16137
16878
  {
16138
16879
  if (src0->grad) {
16139
16880
  src0->grad =
16140
- ggml_add_impl(ctx,
16881
+ ggml_add_or_set(ctx,
16141
16882
  src0->grad,
16142
16883
  ggml_div(ctx,
16143
16884
  tensor->grad,
16144
16885
  src0),
16145
- inplace);
16886
+ zero_table);
16146
16887
  }
16147
16888
  } break;
16148
16889
  case GGML_OP_SUM:
16149
16890
  {
16150
16891
  if (src0->grad) {
16151
16892
  src0->grad =
16152
- ggml_add1_impl(ctx,
16893
+ ggml_add1_or_set(ctx,
16153
16894
  src0->grad,
16154
16895
  tensor->grad,
16155
- inplace);
16896
+ zero_table);
16156
16897
  }
16157
16898
  } break;
16158
16899
  case GGML_OP_SUM_ROWS:
16159
16900
  {
16160
16901
  if (src0->grad) {
16161
16902
  src0->grad =
16162
- ggml_add_impl(ctx,
16903
+ ggml_add_or_set(ctx,
16163
16904
  src0->grad,
16164
16905
  ggml_repeat(ctx,
16165
16906
  tensor->grad,
16166
16907
  src0->grad),
16167
- inplace);
16908
+ zero_table);
16168
16909
  }
16169
16910
  } break;
16170
16911
  case GGML_OP_MEAN:
@@ -16176,20 +16917,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16176
16917
  {
16177
16918
  // necessary for llama
16178
16919
  if (src0->grad) {
16179
- src0->grad = ggml_add_impl(ctx,
16920
+ src0->grad = ggml_add_or_set(ctx,
16180
16921
  src0->grad,
16181
16922
  ggml_repeat_back(ctx, tensor->grad, src0->grad),
16182
- inplace);
16923
+ zero_table);
16183
16924
  }
16184
16925
  } break;
16185
16926
  case GGML_OP_REPEAT_BACK:
16186
16927
  {
16187
16928
  if (src0->grad) {
16188
16929
  // TODO: test this
16189
- src0->grad = ggml_add_impl(ctx,
16930
+ src0->grad = ggml_add_or_set(ctx,
16190
16931
  src0->grad,
16191
16932
  ggml_repeat(ctx, tensor->grad, src0->grad),
16192
- inplace);
16933
+ zero_table);
16193
16934
  }
16194
16935
  } break;
16195
16936
  case GGML_OP_CONCAT:
@@ -16211,10 +16952,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16211
16952
  float eps;
16212
16953
  memcpy(&eps, tensor->op_params, sizeof(float));
16213
16954
 
16214
- src0->grad = ggml_add_impl(ctx,
16955
+ src0->grad = ggml_add_or_set(ctx,
16215
16956
  src0->grad,
16216
16957
  ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
16217
- inplace);
16958
+ zero_table);
16218
16959
  }
16219
16960
  } break;
16220
16961
  case GGML_OP_RMS_NORM_BACK:
@@ -16238,37 +16979,49 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16238
16979
  // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
16239
16980
  // ds1 = t.T.dot(dt)
16240
16981
 
16241
- // tensor.shape [m,p]
16242
- // src0.shape [n,m]
16243
- // src1.shape [n,p]
16982
+ // tensor.shape [m,p,qq,rr]
16983
+ // src0.shape [n,m,q1,r1]
16984
+ // src1.shape [n,p,qq,rr]
16244
16985
 
16245
16986
  // necessary for llama
16246
16987
  if (src0->grad) {
16988
+ struct ggml_tensor * s1_tg =
16989
+ ggml_out_prod(ctx, // [n,m,qq,rr]
16990
+ src1, // [n,p,qq,rr]
16991
+ tensor->grad); // [m,p,qq,rr]
16992
+ const int64_t qq = s1_tg->ne[2];
16993
+ const int64_t rr = s1_tg->ne[3];
16994
+ const int64_t q1 = src0->ne[2];
16995
+ const int64_t r1 = src0->ne[3];
16996
+ const bool ne2_broadcasted = qq > q1;
16997
+ const bool ne3_broadcasted = rr > r1;
16998
+ if (ne2_broadcasted || ne3_broadcasted) {
16999
+ // sum broadcast repetitions of s1_tg into shape of src0
17000
+ s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
17001
+ }
16247
17002
  src0->grad =
16248
- ggml_add_impl(ctx,
16249
- src0->grad,
16250
- ggml_out_prod(ctx, // [n,m]
16251
- src1, // [n,p]
16252
- tensor->grad), // [m,p]
16253
- inplace);
17003
+ ggml_add_or_set(ctx,
17004
+ src0->grad, // [n,m,q1,r1]
17005
+ s1_tg, // [n,m,q1,r1]
17006
+ zero_table);
16254
17007
  }
16255
17008
  if (src1->grad) {
16256
17009
  src1->grad =
16257
- ggml_add_impl(ctx,
16258
- src1->grad,
16259
- // ggml_mul_mat(ctx, // [n,p]
16260
- // ggml_cont(ctx, // [m,n]
16261
- // ggml_transpose(ctx, src0)), // [m,n]
16262
- // tensor->grad), // [m,p]
17010
+ ggml_add_or_set(ctx,
17011
+ src1->grad, // [n,p,qq,rr]
17012
+ // ggml_mul_mat(ctx, // [n,p,qq,rr]
17013
+ // ggml_cont(ctx, // [m,n,q1,r1]
17014
+ // ggml_transpose(ctx, src0)), // [m,n,q1,r1]
17015
+ // tensor->grad), // [m,p,qq,rr]
16263
17016
 
16264
17017
  // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
16265
17018
  // // avoid transpose of src0, rather transpose smaller tensor->grad
16266
17019
  // // and then use ggml_out_prod
16267
- ggml_out_prod(ctx, // [n,p]
16268
- src0, // [n,m]
16269
- ggml_transpose(ctx, // [p,m]
16270
- tensor->grad)), // [m,p]
16271
- inplace);
17020
+ ggml_out_prod(ctx, // [n,p,qq,rr]
17021
+ src0, // [n,m,q1,r1]
17022
+ ggml_transpose(ctx, // [p,m,qq,rr]
17023
+ tensor->grad)), // [m,p,qq,rr]
17024
+ zero_table);
16272
17025
  }
16273
17026
  } break;
16274
17027
  case GGML_OP_OUT_PROD:
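Editor's note: the MUL_MAT hunk above extends the shape comments from 2-D ([m,p]) to 4-D ([m,p,qq,rr]) and handles the case where src0 carries fewer batch repetitions (q1, r1) than the output (qq, rr). The added ggml_repeat_back call applies the standard rule that the gradient of a broadcast operand is the sum over its broadcast copies; as a sketch in symbols, with $k = (qq \cdot rr)/(q_1 \cdot r_1)$ repetitions,

    $$\nabla_{S_0} L \;=\; \sum_{b=1}^{k} \nabla_{S_0^{(b)}} L,$$

which is exactly the reduction performed by summing s1_tg back into the shape of src0.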
@@ -16280,17 +17033,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16280
17033
  // necessary for llama
16281
17034
  if (src0->grad) {
16282
17035
  src0->grad =
16283
- ggml_add_impl(ctx,
17036
+ ggml_add_or_set(ctx,
16284
17037
  src0->grad,
16285
17038
  ggml_scale_impl(ctx, tensor->grad, src1, false),
16286
- inplace);
17039
+ zero_table);
16287
17040
  }
16288
17041
  if (src1->grad) {
16289
17042
  src1->grad =
16290
- ggml_add_impl(ctx,
17043
+ ggml_add_or_set(ctx,
16291
17044
  src1->grad,
16292
17045
  ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)),
16293
- inplace);
17046
+ zero_table);
16294
17047
  }
16295
17048
  } break;
16296
17049
  case GGML_OP_SET:
@@ -16317,23 +17070,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16317
17070
  }
16318
17071
 
16319
17072
  if (src0->grad) {
16320
- src0->grad = ggml_add_impl(ctx,
17073
+ src0->grad = ggml_add_or_set(ctx,
16321
17074
  src0->grad,
16322
17075
  ggml_acc_impl(ctx,
16323
17076
  tensor->grad,
16324
17077
  ggml_neg(ctx, tensor_grad_view),
16325
17078
  nb1, nb2, nb3, offset, false),
16326
- inplace);
17079
+ zero_table);
16327
17080
  }
16328
17081
 
16329
17082
  if (src1->grad) {
16330
17083
  src1->grad =
16331
- ggml_add_impl(ctx,
17084
+ ggml_add_or_set(ctx,
16332
17085
  src1->grad,
16333
17086
  ggml_reshape(ctx,
16334
17087
  ggml_cont(ctx, tensor_grad_view),
16335
17088
  src1->grad),
16336
- inplace);
17089
+ zero_table);
16337
17090
  }
16338
17091
  } break;
16339
17092
  case GGML_OP_CPY:
@@ -16344,7 +17097,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16344
17097
  // tensor = src0 * 1 + src1 * 0
16345
17098
  if (src0->grad) {
16346
17099
  // dsrc0 = dtensor * 1
16347
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
17100
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16348
17101
  }
16349
17102
  if (src1->grad) {
16350
17103
  // dsrc1 = dtensor * 0 -> noop
@@ -16356,7 +17109,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16356
17109
  if (src0->grad) {
16357
17110
  GGML_ASSERT(ggml_is_contiguous(src0->grad));
16358
17111
  GGML_ASSERT(ggml_is_contiguous(tensor->grad));
16359
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
17112
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
16360
17113
  }
16361
17114
  } break;
16362
17115
  case GGML_OP_RESHAPE:
@@ -16364,9 +17117,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16364
17117
  // necessary for llama
16365
17118
  if (src0->grad) {
16366
17119
  src0->grad =
16367
- ggml_add_impl(ctx, src0->grad,
16368
- ggml_reshape(ctx, tensor->grad, src0->grad),
16369
- inplace);
17120
+ ggml_add_or_set(ctx, src0->grad,
17121
+ ggml_reshape(ctx,
17122
+ ggml_is_contiguous(tensor->grad)
17123
+ ? tensor->grad
17124
+ : ggml_cont(ctx, tensor->grad),
17125
+ src0->grad),
17126
+ zero_table);
16370
17127
  }
16371
17128
  } break;
16372
17129
  case GGML_OP_VIEW:
@@ -16395,7 +17152,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16395
17152
  nb3 = (nb3 / n0) * ng;
16396
17153
  }
16397
17154
 
16398
- src0->grad = ggml_acc_impl(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, inplace);
17155
+ src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table);
16399
17156
  }
16400
17157
  } break;
16401
17158
  case GGML_OP_PERMUTE:
@@ -16413,14 +17170,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16413
17170
  axes_backward[axis2] = 2;
16414
17171
  axes_backward[axis3] = 3;
16415
17172
  src0->grad =
16416
- ggml_add_impl(ctx, src0->grad,
17173
+ ggml_add_or_set(ctx, src0->grad,
16417
17174
  ggml_permute(ctx,
16418
17175
  tensor->grad,
16419
17176
  axes_backward[0],
16420
17177
  axes_backward[1],
16421
17178
  axes_backward[2],
16422
17179
  axes_backward[3]),
16423
- inplace);
17180
+ zero_table);
16424
17181
  }
16425
17182
  } break;
16426
17183
  case GGML_OP_TRANSPOSE:
@@ -16428,9 +17185,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16428
17185
  // necessary for llama
16429
17186
  if (src0->grad) {
16430
17187
  src0->grad =
16431
- ggml_add_impl(ctx, src0->grad,
17188
+ ggml_add_or_set(ctx, src0->grad,
16432
17189
  ggml_transpose(ctx, tensor->grad),
16433
- inplace);
17190
+ zero_table);
16434
17191
  }
16435
17192
  } break;
16436
17193
  case GGML_OP_GET_ROWS:
@@ -16438,9 +17195,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16438
17195
  // necessary for llama (only for tokenizer)
16439
17196
  if (src0->grad) {
16440
17197
  src0->grad =
16441
- ggml_add_impl(ctx, src0->grad,
17198
+ ggml_add_or_set(ctx, src0->grad,
17199
+ // last ggml_get_rows_back argument src0->grad is only
17200
+ // necessary to setup correct output shape
16442
17201
  ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
16443
- inplace);
17202
+ zero_table);
16444
17203
  }
16445
17204
  if (src1->grad) {
16446
17205
  // noop
@@ -16460,9 +17219,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16460
17219
  if (src0->grad) {
16461
17220
  const int n_past = ((int32_t *) tensor->op_params)[0];
16462
17221
  src0->grad =
16463
- ggml_add_impl(ctx, src0->grad,
17222
+ ggml_add_or_set(ctx, src0->grad,
16464
17223
  ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
16465
- inplace);
17224
+ zero_table);
16466
17225
  }
16467
17226
  } break;
16468
17227
  case GGML_OP_DIAG_MASK_ZERO:
@@ -16471,9 +17230,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16471
17230
  if (src0->grad) {
16472
17231
  const int n_past = ((int32_t *) tensor->op_params)[0];
16473
17232
  src0->grad =
16474
- ggml_add_impl(ctx, src0->grad,
17233
+ ggml_add_or_set(ctx, src0->grad,
16475
17234
  ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
16476
- inplace);
17235
+ zero_table);
16477
17236
  }
16478
17237
  } break;
16479
17238
  case GGML_OP_SOFT_MAX:
@@ -16481,9 +17240,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16481
17240
  // necessary for llama
16482
17241
  if (src0->grad) {
16483
17242
  src0->grad =
16484
- ggml_add_impl(ctx, src0->grad,
17243
+ ggml_add_or_set(ctx, src0->grad,
16485
17244
  ggml_soft_max_back(ctx, tensor->grad, tensor),
16486
- inplace);
17245
+ zero_table);
16487
17246
  }
16488
17247
 
16489
17248
  } break;
@@ -16495,7 +17254,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16495
17254
  {
16496
17255
  // necessary for llama
16497
17256
  if (src0->grad) {
16498
- const int n_past = ((int32_t *) tensor->op_params)[0];
17257
+ //const int n_past = ((int32_t *) tensor->op_params)[0];
16499
17258
  const int n_dims = ((int32_t *) tensor->op_params)[1];
16500
17259
  const int mode = ((int32_t *) tensor->op_params)[2];
16501
17260
  const int n_ctx = ((int32_t *) tensor->op_params)[3];
@@ -16508,11 +17267,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16508
17267
  memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
16509
17268
  memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
16510
17269
 
16511
- src0->grad = ggml_add_impl(ctx,
17270
+ src0->grad = ggml_add_or_set(ctx,
16512
17271
  src0->grad,
16513
17272
  ggml_rope_back(ctx,
16514
17273
  tensor->grad,
16515
- n_past,
17274
+ src1,
16516
17275
  n_dims,
16517
17276
  mode,
16518
17277
  n_ctx,
@@ -16520,13 +17279,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16520
17279
  freq_scale,
16521
17280
  xpos_base,
16522
17281
  xpos_down),
16523
- inplace);
17282
+ zero_table);
16524
17283
  }
16525
17284
  } break;
16526
17285
  case GGML_OP_ROPE_BACK:
16527
17286
  {
16528
17287
  if (src0->grad) {
16529
- const int n_past = ((int32_t *) tensor->op_params)[0];
17288
+ //const int n_past = ((int32_t *) tensor->op_params)[0];
16530
17289
  const int n_dims = ((int32_t *) tensor->op_params)[1];
16531
17290
  const int mode = ((int32_t *) tensor->op_params)[2];
16532
17291
  const int n_ctx = ((int32_t *) tensor->op_params)[3];
@@ -16539,11 +17298,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16539
17298
  memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
16540
17299
  memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
16541
17300
 
16542
- src0->grad = ggml_add_impl(ctx,
17301
+ src0->grad = ggml_add_or_set(ctx,
16543
17302
  src0->grad,
16544
17303
  ggml_rope_impl(ctx,
16545
17304
  tensor->grad,
16546
- n_past,
17305
+ src1,
16547
17306
  n_dims,
16548
17307
  mode,
16549
17308
  n_ctx,
@@ -16552,7 +17311,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16552
17311
  xpos_base,
16553
17312
  xpos_down,
16554
17313
  false),
16555
- inplace);
17314
+ zero_table);
16556
17315
  }
16557
17316
  } break;
16558
17317
  case GGML_OP_ALIBI:
@@ -16603,145 +17362,42 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16603
17362
  masked);
16604
17363
  }
16605
17364
 
16606
- if (src0->grad) {
16607
- struct ggml_tensor * grad_q = NULL;
16608
- const size_t nb0 = flash_grad->nb[0];
16609
- const size_t offset = 0;
16610
- switch(src0->n_dims) {
16611
- case 2:
16612
- {
16613
- grad_q = ggml_view_2d(ctx,
16614
- flash_grad,
16615
- src0->ne[0],
16616
- src0->ne[1],
16617
- nb0*src0->ne[0],
16618
- offset);
16619
- } break;
16620
- case 3:
16621
- {
16622
- grad_q = ggml_view_3d(ctx,
16623
- flash_grad,
16624
- src0->ne[0],
16625
- src0->ne[1],
16626
- src0->ne[2],
16627
- nb0*src0->ne[0],
16628
- nb0*src0->ne[0]*src0->ne[1],
16629
- offset);
16630
- } break;
16631
- case 4:
16632
- {
16633
- grad_q = ggml_view_4d(ctx,
16634
- flash_grad,
16635
- src0->ne[0],
16636
- src0->ne[1],
16637
- src0->ne[2],
16638
- src0->ne[3],
16639
- nb0*src0->ne[0],
16640
- nb0*src0->ne[0]*src0->ne[1],
16641
- nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
16642
- offset);
16643
- } break;
16644
- }
17365
+ struct ggml_tensor * src2 = tensor->src[2];
17366
+ const int64_t elem_q = ggml_nelements(src0);
17367
+ const int64_t elem_k = ggml_nelements(src1);
17368
+ const int64_t elem_v = ggml_nelements(src2);
17369
+
17370
+ enum ggml_type result_type = flash_grad->type;
17371
+ GGML_ASSERT(ggml_blck_size(result_type) == 1);
17372
+ const size_t tsize = ggml_type_size(result_type);
17373
+
17374
+ const size_t offs_q = 0;
17375
+ const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
17376
+ const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
16645
17377
 
16646
- src0->grad = ggml_add_impl(ctx,
17378
+ if (src0->grad) {
17379
+ struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q);
17380
+ struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0);
17381
+ src0->grad = ggml_add_or_set(ctx,
16647
17382
  src0->grad,
16648
17383
  grad_q,
16649
- inplace);
17384
+ zero_table);
16650
17385
  }
16651
-
16652
17386
  if (src1->grad) {
16653
- struct ggml_tensor * grad_k = NULL;
16654
- const size_t nb0 = flash_grad->nb[0];
16655
- const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
16656
- switch(src1->n_dims) {
16657
- case 2:
16658
- {
16659
- grad_k = ggml_view_2d(ctx,
16660
- flash_grad,
16661
- src1->ne[0],
16662
- src1->ne[1],
16663
- nb0*src1->ne[0],
16664
- offset);
16665
- } break;
16666
- case 3:
16667
- {
16668
- grad_k = ggml_view_3d(ctx,
16669
- flash_grad,
16670
- src1->ne[0],
16671
- src1->ne[1],
16672
- src1->ne[2],
16673
- nb0*src1->ne[0],
16674
- nb0*src1->ne[0]*src1->ne[1],
16675
- offset);
16676
- } break;
16677
- case 4:
16678
- {
16679
- grad_k = ggml_view_4d(ctx,
16680
- flash_grad,
16681
- src1->ne[0],
16682
- src1->ne[1],
16683
- src1->ne[2],
16684
- src1->ne[3],
16685
- nb0*src1->ne[0],
16686
- nb0*src1->ne[0]*src1->ne[1],
16687
- nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
16688
- offset);
16689
- } break;
16690
- }
16691
-
16692
- src1->grad = ggml_add_impl(ctx,
17387
+ struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
17388
+ struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1);
17389
+ src1->grad = ggml_add_or_set(ctx,
16693
17390
  src1->grad,
16694
17391
  grad_k,
16695
- inplace);
17392
+ zero_table);
16696
17393
  }
16697
-
16698
- struct ggml_tensor * opt0 = tensor->src[2];
16699
-
16700
- if (opt0->grad) {
16701
- struct ggml_tensor * grad_v = NULL;
16702
- const size_t nb0 = flash_grad->nb[0];
16703
- const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
16704
- + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
16705
- switch(opt0->n_dims) {
16706
- case 2:
16707
- {
16708
- grad_v = ggml_view_2d(ctx,
16709
- flash_grad,
16710
- opt0->ne[0],
16711
- opt0->ne[1],
16712
- nb0*opt0->ne[0],
16713
- offset);
16714
- } break;
16715
- case 3:
16716
- {
16717
- grad_v = ggml_view_3d(ctx,
16718
- flash_grad,
16719
- opt0->ne[0],
16720
- opt0->ne[1],
16721
- opt0->ne[2],
16722
- nb0*opt0->ne[0],
16723
- nb0*opt0->ne[0]*opt0->ne[1],
16724
- offset);
16725
- } break;
16726
- case 4:
16727
- {
16728
- grad_v = ggml_view_4d(ctx,
16729
- flash_grad,
16730
- opt0->ne[0],
16731
- opt0->ne[1],
16732
- opt0->ne[2],
16733
- opt0->ne[3],
16734
- nb0*opt0->ne[0],
16735
- nb0*opt0->ne[0]*opt0->ne[1],
16736
- nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
16737
- offset);
16738
- } break;
16739
- }
16740
-
16741
- opt0->grad = ggml_add_impl(ctx,
16742
- opt0->grad,
17394
+ if (src2->grad) {
17395
+ struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
17396
+ struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2);
17397
+ src2->grad = ggml_add_or_set(ctx,
17398
+ src2->grad,
16743
17399
  grad_v,
16744
- inplace);
17400
+ zero_table);
16745
17401
  }
16746
17402
  } break;
16747
17403
  case GGML_OP_FLASH_FF:
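Editor's note: the rewritten FLASH_ATTN backward above drops the per-rank ggml_view_2d/3d/4d dispatch; the gradients for q, k and v are now taken as three 1-D views into flash_grad at GGML_MEM_ALIGN-padded offsets and then reshaped to the source shapes. GGML_PAD(x, n) rounds x up to the next multiple of n, so with hypothetical sizes elem_q = 101 F32 elements (tsize = 4) and GGML_MEM_ALIGN = 16 the offsets come out as:

    offs_q = 0
    offs_k = GGML_PAD(101 * 4, 16) = GGML_PAD(404, 16) = 416
    offs_v = offs_k + GGML_PAD(elem_k * 4, 16)

ggml_reshape then reinterprets each flat view with the shape of the corresponding source tensor (src0, src1, src2).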
@@ -16761,12 +17417,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16761
17417
  {
16762
17418
  if (src0->grad) {
16763
17419
  src0->grad =
16764
- ggml_add_impl(ctx,
17420
+ ggml_add_or_set(ctx,
16765
17421
  src0->grad,
16766
17422
  ggml_mul(ctx,
16767
17423
  ggml_sgn(ctx, src0),
16768
17424
  tensor->grad),
16769
- inplace);
17425
+ zero_table);
16770
17426
  }
16771
17427
  } break;
16772
17428
  case GGML_UNARY_OP_SGN:
@@ -16778,7 +17434,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16778
17434
  case GGML_UNARY_OP_NEG:
16779
17435
  {
16780
17436
  if (src0->grad) {
16781
- src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace);
17437
+ src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table);
16782
17438
  }
16783
17439
  } break;
16784
17440
  case GGML_UNARY_OP_STEP:
@@ -16798,12 +17454,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16798
17454
  case GGML_UNARY_OP_RELU:
16799
17455
  {
16800
17456
  if (src0->grad) {
16801
- src0->grad = ggml_add_impl(ctx,
17457
+ src0->grad = ggml_add_or_set(ctx,
16802
17458
  src0->grad,
16803
17459
  ggml_mul(ctx,
16804
17460
  ggml_step(ctx, src0),
16805
17461
  tensor->grad),
16806
- inplace);
17462
+ zero_table);
16807
17463
  }
16808
17464
  } break;
16809
17465
  case GGML_UNARY_OP_GELU:
@@ -16818,10 +17474,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16818
17474
  {
16819
17475
  // necessary for llama
16820
17476
  if (src0->grad) {
16821
- src0->grad = ggml_add_impl(ctx,
17477
+ src0->grad = ggml_add_or_set(ctx,
16822
17478
  src0->grad,
16823
17479
  ggml_silu_back(ctx, src0, tensor->grad),
16824
- inplace);
17480
+ zero_table);
16825
17481
  }
16826
17482
  } break;
16827
17483
  default:
@@ -16844,13 +17500,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16844
17500
  case GGML_OP_CROSS_ENTROPY_LOSS:
16845
17501
  {
16846
17502
  if (src0->grad) {
16847
- src0->grad = ggml_add_impl(ctx,
17503
+ src0->grad = ggml_add_or_set(ctx,
16848
17504
  src0->grad,
16849
17505
  ggml_cross_entropy_loss_back(ctx,
16850
17506
  src0,
16851
17507
  src1,
16852
17508
  tensor->grad),
16853
- inplace);
17509
+ zero_table);
16854
17510
  }
16855
17511
  } break;
16856
17512
  case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
@@ -16866,34 +17522,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
16866
17522
  GGML_ASSERT(false);
16867
17523
  } break;
16868
17524
  }
16869
- }
16870
17525
 
16871
- static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
16872
-
16873
- static size_t hash(void * p) {
16874
- return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
16875
- }
16876
-
16877
- static bool hash_insert(void * hash_table[], void * p) {
16878
- size_t h = hash(p);
16879
-
16880
- // linear probing
16881
- size_t i = h;
16882
- while (hash_table[i] != NULL && hash_table[i] != p) {
16883
- i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
16884
- if (i == h) {
16885
- // hash table is full
16886
- GGML_ASSERT(false);
17526
+ for (int i = 0; i < GGML_MAX_SRC; ++i) {
17527
+ if (tensor->src[i] && tensor->src[i]->grad) {
17528
+ GGML_ASSERT(ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad));
16887
17529
  }
16888
17530
  }
16889
-
16890
- if (hash_table[i] == p) {
16891
- return true;
16892
- }
16893
-
16894
- // insert
16895
- hash_table[i] = p;
16896
- return false;
16897
17531
  }
16898
17532
 
16899
17533
  static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
@@ -16911,8 +17545,12 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
16911
17545
  }
16912
17546
 
16913
17547
  for (int i = 0; i < GGML_MAX_SRC; ++i) {
16914
- if (node->src[i]) {
16915
- ggml_visit_parents(cgraph, node->src[i]);
17548
+ const int k =
17549
+ (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
17550
+ (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
17551
+ /* unknown order, just fall back to using i*/ i;
17552
+ if (node->src[k]) {
17553
+ ggml_visit_parents(cgraph, node->src[k]);
16916
17554
  }
16917
17555
  }
16918
17556
 
@@ -16971,6 +17609,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
16971
17609
  /*.grads =*/ { NULL },
16972
17610
  /*.leafs =*/ { NULL },
16973
17611
  /*.hash_table =*/ { NULL },
17612
+ /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
16974
17613
  /*.perf_runs =*/ 0,
16975
17614
  /*.perf_cycles =*/ 0,
16976
17615
  /*.perf_time_us =*/ 0,
@@ -16996,12 +17635,22 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
16996
17635
  }
16997
17636
  }
16998
17637
 
17638
+ // remember original gradients which start with zero values
17639
+ void ** zero_table = malloc(sizeof(void *) * GGML_GRAPH_HASHTABLE_SIZE);
17640
+ memset(zero_table, 0, sizeof(void*) * GGML_GRAPH_HASHTABLE_SIZE);
17641
+ for (int i = 0; i < gf->n_nodes; i++) {
17642
+ if (gf->grads[i]) {
17643
+ hash_insert(zero_table, gf->grads[i]);
17644
+ }
17645
+ }
17646
+
16999
17647
  for (int i = gf->n_nodes - 1; i >= 0; i--) {
17000
17648
  struct ggml_tensor * node = gf->nodes[i];
17001
17649
 
17002
- // because we detached the grad nodes from the original graph, we can afford inplace operations
17650
+ // inplace operations to add gradients are not created by ggml_compute_backward
17651
+ // use allocator to automatically make inplace operations
17003
17652
  if (node->grad) {
17004
- ggml_compute_backward(ctx, node, keep);
17653
+ ggml_compute_backward(ctx, node, zero_table);
17005
17654
  }
17006
17655
  }
17007
17656
 
@@ -17013,6 +17662,8 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
17013
17662
  ggml_build_forward_expand(gb, node->grad);
17014
17663
  }
17015
17664
  }
17665
+
17666
+ free(zero_table);
17016
17667
  }
17017
17668
 
17018
17669
  struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
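Editor's note: for context on how these builders are typically driven, a minimal, hypothetical call sequence using only functions visible in this diff might be:

    // sketch: build the forward graph for a scalar loss f, then its backward graph.
    // with this change the gradients created by ggml_build_backward_expand start out zero,
    // and the first accumulation into each of them becomes a plain set.
    struct ggml_cgraph gf = ggml_build_forward(f);
    struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, /*keep =*/ false);

    struct ggml_cplan cplan = ggml_graph_plan(&gb, /*n_threads =*/ 4);
    // ... point cplan.work_data at a buffer of cplan.work_size bytes,
    // set f->grad to 1.0f (ggml_set_f32), then:
    ggml_graph_compute(&gb, &cplan);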
@@ -17032,6 +17683,7 @@ struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
17032
17683
  /*.grads =*/ { NULL },
17033
17684
  /*.leafs =*/ { NULL },
17034
17685
  /*.hash_table =*/ { NULL },
17686
+ /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
17035
17687
  /*.perf_runs =*/ 0,
17036
17688
  /*.perf_cycles =*/ 0,
17037
17689
  /*.perf_time_us =*/ 0,
@@ -17283,10 +17935,18 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
17283
17935
  } else {
17284
17936
  // wait for other threads to finish
17285
17937
  const int last = node_n;
17286
- do {
17287
- //sched_yield();
17938
+ while (true) {
17939
+ // TODO: this sched_yield can have significant impact on the performance - either positive or negative
17940
+ // depending on the workload and the operating system.
17941
+ // since it is not clear what is the best approach, it should potentially become user-configurable
17942
+ // ref: https://github.com/ggerganov/ggml/issues/291
17943
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
17944
+ sched_yield();
17945
+ #endif
17946
+
17288
17947
  node_n = atomic_load(&state->shared->node_n);
17289
- } while (node_n == last);
17948
+ if (node_n != last) break;
17949
+ };
17290
17950
  }
17291
17951
 
17292
17952
  // check if we should stop
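Editor's note: the wait above turns the unconditional do/while poll into an explicit spin that only calls sched_yield() when ggml is built against Accelerate or OpenBLAS, per the TODO about the yield being workload- and OS-dependent. A standalone illustration of the resulting pattern (not the ggml code itself, which polls state->shared->node_n):

    #include <sched.h>
    #include <stdatomic.h>

    // spin until the shared counter moves past "last", optionally yielding the time slice
    static void wait_for_next_node(atomic_int * node_n, int last) {
        while (atomic_load(node_n) == last) {
    #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
            sched_yield(); // with a BLAS backend the worker can afford to give up the CPU
    #endif
        }
    }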
@@ -17414,7 +18074,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
17414
18074
  } break;
17415
18075
  case GGML_OP_CONCAT:
17416
18076
  case GGML_OP_MUL_MAT:
17417
- case GGML_OP_OUT_PROD:
17418
18077
  {
17419
18078
  n_tasks = n_threads;
17420
18079
 
@@ -17456,6 +18115,18 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
17456
18115
  cur = 0;
17457
18116
  }
17458
18117
 
18118
+ work_size = MAX(work_size, cur);
18119
+ } break;
18120
+ case GGML_OP_OUT_PROD:
18121
+ {
18122
+ n_tasks = n_threads;
18123
+
18124
+ size_t cur = 0;
18125
+
18126
+ if (ggml_is_quantized(node->src[0]->type)) {
18127
+ cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
18128
+ }
18129
+
17459
18130
  work_size = MAX(work_size, cur);
17460
18131
  } break;
17461
18132
  case GGML_OP_SCALE:
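Editor's note: the new GGML_OP_OUT_PROD branch above moves the op out of the shared CONCAT/MUL_MAT group and reserves a per-thread F32 row for dequantizing src0 when it is quantized. As a rough, hypothetical sizing example with ne[0] = 4096 and n_tasks = 8:

    cur = ggml_type_size(GGML_TYPE_F32) * ne[0] * n_tasks = 4 * 4096 * 8 = 131072 bytes (128 KiB)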
@@ -18337,10 +19008,11 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
18337
19008
  for (int i = 0; i < cgraph->n_leafs; i++) {
18338
19009
  struct ggml_tensor * node = cgraph->leafs[i];
18339
19010
 
18340
- GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
19011
+ GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n",
18341
19012
  i,
18342
19013
  node->ne[0], node->ne[1],
18343
- ggml_op_name(node->op));
19014
+ ggml_op_name(node->op),
19015
+ ggml_get_name(node));
18344
19016
  }
18345
19017
 
18346
19018
  for (int i = 0; i < GGML_OP_COUNT; i++) {
@@ -18548,7 +19220,7 @@ static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float *
18548
19220
  }
18549
19221
 
18550
19222
  static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
18551
- int i = 0;
19223
+ int64_t i = 0;
18552
19224
  for (int p = 0; p < np; ++p) {
18553
19225
  const int64_t ne = ggml_nelements(ps[p]) ;
18554
19226
  // TODO: add function to get all elements at once
@@ -18558,6 +19230,17 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
18558
19230
  }
18559
19231
  }
18560
19232
 
19233
+ static void ggml_opt_acc_grad(int np, struct ggml_tensor * const ps[], float * g, float scale) {
19234
+ int64_t i = 0;
19235
+ for (int p = 0; p < np; ++p) {
19236
+ const int64_t ne = ggml_nelements(ps[p]) ;
19237
+ // TODO: add function to get all elements at once
19238
+ for (int64_t j = 0; j < ne; ++j) {
19239
+ g[i++] += ggml_get_f32_1d(ps[p]->grad, j) * scale;
19240
+ }
19241
+ }
19242
+ }
19243
+
18561
19244
  //
18562
19245
  // ADAM
18563
19246
  //
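Editor's note: the added ggml_opt_acc_grad accumulates scaled per-step gradients instead of copying them once, and the flat index i is widened to int64_t so the total parameter count is no longer limited by a 32-bit int (about 2.1 billion elements on typical platforms). With N = n_gradient_accumulation steps and scale = 1/N, the optimizers below effectively work with the averaged gradient and loss:

    $$g = \frac{1}{N}\sum_{s=1}^{N} \nabla L_s, \qquad f_x = \frac{1}{N}\sum_{s=1}^{N} L_s$$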
@@ -18606,26 +19289,43 @@ static enum ggml_opt_result ggml_opt_adam(
18606
19289
  const float eps = params.adam.eps;
18607
19290
  const float gclip = params.adam.gclip;
18608
19291
  const int decay_min_ndim = params.adam.decay_min_ndim;
19292
+ const int n_accum = MAX(1, params.n_gradient_accumulation);
19293
+ const float accum_norm = 1.0f / (float) n_accum;
18609
19294
 
19295
+ float * g = opt->adam.g->data; // gradients
18610
19296
  float * m = opt->adam.m->data; // first moment
18611
19297
  float * v = opt->adam.v->data; // second moment
18612
19298
 
18613
19299
  float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
18614
19300
 
18615
- if (callback) {
18616
- callback(callback_data, &sched);
18617
- }
18618
-
18619
- // compute the function value
18620
- ggml_graph_reset (gf);
18621
- ggml_set_f32 (f->grad, 1.0f);
18622
-
18623
19301
  struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads);
18624
19302
  struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size);
18625
19303
  cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
18626
- ggml_graph_compute(gb, &cplan);
18627
19304
 
18628
- opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
19305
+ bool cancel = false;
19306
+
19307
+ // compute the function value
19308
+ float fx = 0;
19309
+ ggml_set_zero(opt->adam.g);
19310
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
19311
+ if (callback) {
19312
+ callback(callback_data, accum_step, &sched, &cancel);
19313
+ if (cancel) {
19314
+ break;
19315
+ }
19316
+ }
19317
+ // ggml_graph_reset (gf);
19318
+ ggml_set_f32 (f->grad, 1.0f);
19319
+ ggml_graph_compute(gb, &cplan);
19320
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
19321
+ fx += ggml_get_f32_1d(f, 0);
19322
+ }
19323
+ if (cancel) {
19324
+ return GGML_OPT_DID_NOT_CONVERGE;
19325
+ }
19326
+ fx *= accum_norm;
19327
+
19328
+ opt->adam.fx_prev = fx;
18629
19329
  opt->adam.fx_best = opt->adam.fx_prev;
18630
19330
  if (pf) {
18631
19331
  pf[opt->iter % params.past] = opt->adam.fx_prev;
@@ -18648,6 +19348,9 @@ static enum ggml_opt_result ggml_opt_adam(
18648
19348
 
18649
19349
  // run the optimizer
18650
19350
  for (int t = 0; t < params.adam.n_iter; ++t) {
19351
+ if (cancel) {
19352
+ break;
19353
+ }
18651
19354
  opt->iter = iter0 + t + 1;
18652
19355
  GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
18653
19356
 
@@ -18670,12 +19373,8 @@ static enum ggml_opt_result ggml_opt_adam(
18670
19373
  if (gclip > 0.0f) {
18671
19374
  // gradient clipping
18672
19375
  ggml_float sum = 0.0;
18673
- for (int p = 0; p < np; ++p) {
18674
- const int64_t ne = ggml_nelements(ps[p]);
18675
- for (int64_t j = 0; j < ne; ++j) {
18676
- float g = ggml_get_f32_1d(ps[p]->grad, j);
18677
- sum += (ggml_float)(g*g);
18678
- }
19376
+ for (int64_t i = 0; i < nx; ++i) {
19377
+ sum += (ggml_float)(g[i]*g[i]);
18679
19378
  }
18680
19379
  ggml_float norm = sqrt(sum);
18681
19380
  if (norm > (ggml_float) gclip) {
@@ -18689,10 +19388,10 @@ static enum ggml_opt_result ggml_opt_adam(
18689
19388
  const int64_t ne = ggml_nelements(ps[p]);
18690
19389
  const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0f) * sched;
18691
19390
  for (int64_t j = 0; j < ne; ++j) {
18692
- float x = ggml_get_f32_1d(ps[p], j);
18693
- float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm;
18694
- m[i] = m[i]*beta1 + g*(1.0f - beta1);
18695
- v[i] = v[i]*beta2 + g*g*(1.0f - beta2);
19391
+ float x = ggml_get_f32_1d(ps[p], j);
19392
+ float g_ = g[i]*gnorm;
19393
+ m[i] = m[i]*beta1 + g_*(1.0f - beta1);
19394
+ v[i] = v[i]*beta2 + g_*g_*(1.0f - beta2);
18696
19395
  float mh = m[i]*beta1h;
18697
19396
  float vh = v[i]*beta2h;
18698
19397
  vh = sqrtf(vh) + eps;
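Editor's note: the two hunks above switch gradient clipping to the flat accumulated buffer g[] and feed the same buffer into the Adam moment updates. Both follow the standard formulations (with the bias corrections folded into beta1h and beta2h in the code):

    $$\|g\| = \sqrt{\textstyle\sum_i g_i^2}, \qquad g \leftarrow g \cdot \min\!\left(1, \frac{\mathrm{gclip}}{\|g\|}\right)$$
    $$m \leftarrow \beta_1 m + (1-\beta_1)\,g, \qquad v \leftarrow \beta_2 v + (1-\beta_2)\,g^2, \qquad x \leftarrow x - \alpha \cdot \frac{\hat m}{\sqrt{\hat v} + \epsilon}$$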
@@ -18703,16 +19402,26 @@ static enum ggml_opt_result ggml_opt_adam(
18703
19402
  }
18704
19403
  }
18705
19404
 
18706
- if (callback) {
18707
- callback(callback_data, &sched);
19405
+ fx = 0;
19406
+ ggml_set_zero(opt->adam.g);
19407
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
19408
+ if (callback) {
19409
+ callback(callback_data, accum_step, &sched, &cancel);
19410
+ if (cancel) {
19411
+ break;
19412
+ }
19413
+ }
19414
+ // ggml_graph_reset (gf);
19415
+ ggml_set_f32 (f->grad, 1.0f);
19416
+ ggml_graph_compute(gb, &cplan);
19417
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
19418
+ fx += ggml_get_f32_1d(f, 0);
18708
19419
  }
19420
+ if (cancel) {
19421
+ break;
19422
+ }
19423
+ fx *= accum_norm;
18709
19424
 
18710
- ggml_graph_reset (gf);
18711
- ggml_set_f32 (f->grad, 1.0f);
18712
-
18713
- ggml_graph_compute(gb, &cplan);
18714
-
18715
- const float fx = ggml_get_f32_1d(f, 0);
18716
19425
  opt->loss_after = fx;
18717
19426
 
18718
19427
 
@@ -18792,11 +19501,11 @@ static enum ggml_opt_result linesearch_backtracking(
18792
19501
  float * step,
18793
19502
  const float * xp,
18794
19503
  struct ggml_tensor * f,
18795
- struct ggml_cgraph * gf,
18796
19504
  struct ggml_cgraph * gb,
18797
19505
  struct ggml_cplan * cplan,
18798
19506
  const int np,
18799
19507
  struct ggml_tensor * ps[],
19508
+ bool * cancel,
18800
19509
  ggml_opt_callback callback,
18801
19510
  void * callback_data) {
18802
19511
  int count = 0;
@@ -18810,6 +19519,9 @@ static enum ggml_opt_result linesearch_backtracking(
18810
19519
  const float dec = 0.5f;
18811
19520
  const float inc = 2.1f;
18812
19521
 
19522
+ const int n_accum = MAX(1, params->n_gradient_accumulation);
19523
+ const float accum_norm = 1.0f / (float) n_accum;
19524
+
18813
19525
  if (*step <= 0.f) {
18814
19526
  return GGML_LINESEARCH_INVALID_PARAMETERS;
18815
19527
  }
@@ -18826,13 +19538,7 @@ static enum ggml_opt_result linesearch_backtracking(
18826
19538
  finit = *fx;
18827
19539
  dgtest = params->lbfgs.ftol*dginit;
18828
19540
 
18829
- while (true) {
18830
- if (callback) {
18831
- // LBFG-S does not support learning rate -> ignore learning schedule
18832
- float sched = 0;
18833
- callback(callback_data, &sched);
18834
- }
18835
-
19541
+ while (!*cancel) {
18836
19542
  ggml_vec_cpy_f32(nx, x, xp);
18837
19543
  ggml_vec_mad_f32(nx, x, d, *step);
18838
19544
 
@@ -18840,14 +19546,28 @@ static enum ggml_opt_result linesearch_backtracking(
18840
19546
  {
18841
19547
  ggml_opt_set_params(np, ps, x);
18842
19548
 
18843
- ggml_graph_reset (gf);
18844
- ggml_set_f32 (f->grad, 1.0f);
18845
-
18846
- ggml_graph_compute(gb, cplan);
18847
-
18848
- ggml_opt_get_grad(np, ps, g);
19549
+ *fx = 0;
19550
+ memset(g, 0, sizeof(float)*nx);
19551
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
19552
+ if (callback) {
19553
+ // L-BFGS does not support learning rate -> ignore learning schedule
19554
+ float sched = 0;
19555
+ callback(callback_data, accum_step, &sched, cancel);
19556
+ if (*cancel) {
19557
+ break;
19558
+ }
19559
+ }
19560
+ // ggml_graph_reset (gf);
19561
+ ggml_set_f32 (f->grad, 1.0f);
19562
+ ggml_graph_compute(gb, cplan);
19563
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
19564
+ *fx += ggml_get_f32_1d(f, 0);
19565
+ }
19566
+ if (*cancel) {
19567
+ break;
19568
+ }
19569
+ *fx *= accum_norm;
18849
19570
 
18850
- *fx = ggml_get_f32_1d(f, 0);
18851
19571
  }
18852
19572
 
18853
19573
  ++count;
@@ -18893,7 +19613,7 @@ static enum ggml_opt_result linesearch_backtracking(
18893
19613
  (*step) *= width;
18894
19614
  }
18895
19615
 
18896
- return GGML_LINESEARCH_FAIL;
19616
+ GGML_UNREACHABLE();
18897
19617
  }
18898
19618
 
18899
19619
  static enum ggml_opt_result ggml_opt_lbfgs(
@@ -18948,6 +19668,9 @@ static enum ggml_opt_result ggml_opt_lbfgs(
18948
19668
 
18949
19669
  float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
18950
19670
 
19671
+ const int n_accum = MAX(1, params.n_gradient_accumulation);
19672
+ const float accum_norm = 1.0f / (float) n_accum;
19673
+
18951
19674
  float fx = 0.0f; // cost function value
18952
19675
  float xnorm = 0.0f; // ||x||
18953
19676
  float gnorm = 0.0f; // ||g||
@@ -18961,24 +19684,33 @@ static enum ggml_opt_result ggml_opt_lbfgs(
18961
19684
  float * lm_s = opt->lbfgs.lms->data;
18962
19685
  float * lm_y = opt->lbfgs.lmy->data;
18963
19686
 
18964
- if (callback) {
18965
- // LBFG-S does not support learning rate -> ignore learning schedule
18966
- float sched = 0;
18967
- callback(callback_data, &sched);
18968
- }
19687
+ bool cancel = false;
18969
19688
 
18970
19689
  // evaluate the function value and its gradient
18971
19690
  {
18972
19691
  ggml_opt_set_params(np, ps, x);
18973
19692
 
18974
- ggml_graph_reset (gf);
18975
- ggml_set_f32 (f->grad, 1.0f);
18976
-
18977
- ggml_graph_compute(gb, &cplan);
18978
-
18979
- ggml_opt_get_grad(np, ps, g);
18980
-
18981
- fx = ggml_get_f32_1d(f, 0);
19693
+ fx = 0;
19694
+ memset(g, 0, sizeof(float)*nx);
19695
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
19696
+ if (callback) {
19697
+ // LBFG-S does not support learning rate -> ignore learning schedule
19698
+ float sched = 0;
19699
+ callback(callback_data, accum_step, &sched, &cancel);
19700
+ if (cancel) {
19701
+ break;
19702
+ }
19703
+ }
19704
+ // ggml_graph_reset (gf);
19705
+ ggml_set_f32 (f->grad, 1.0f);
19706
+ ggml_graph_compute(gb, &cplan);
19707
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
19708
+ fx += ggml_get_f32_1d(f, 0);
19709
+ }
19710
+ if (cancel) {
19711
+ return GGML_OPT_DID_NOT_CONVERGE;
19712
+ }
19713
+ fx *= accum_norm;
18982
19714
 
18983
19715
  opt->loss_before = fx;
18984
19716
  opt->loss_after = fx;
@@ -19036,7 +19768,10 @@ static enum ggml_opt_result ggml_opt_lbfgs(
19036
19768
  ggml_vec_cpy_f32(nx, xp, x);
19037
19769
  ggml_vec_cpy_f32(nx, gp, g);
19038
19770
 
19039
- ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gf, gb, &cplan, np, ps, callback, callback_data);
19771
+ ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data);
19772
+ if (!cancel) {
19773
+ break;
19774
+ }
19040
19775
 
19041
19776
  if (ls < 0) {
19042
19777
  // linesearch failed - go back to the previous point and return
@@ -19145,7 +19880,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
19145
19880
  step[0] = 1.0;
19146
19881
  }
19147
19882
 
19148
- return GGML_OPT_DID_NOT_CONVERGE;
19883
+ GGML_UNREACHABLE();
19149
19884
  }
19150
19885
 
19151
19886
  struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
@@ -19165,6 +19900,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
19165
19900
  .print_forward_graph = true,
19166
19901
  .print_backward_graph = true,
19167
19902
 
19903
+ .n_gradient_accumulation = 1,
19904
+
19168
19905
  .adam = {
19169
19906
  .n_iter = 10000,
19170
19907
  .sched = 1.000f,
@@ -19193,6 +19930,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
19193
19930
  .print_forward_graph = true,
19194
19931
  .print_backward_graph = true,
19195
19932
 
19933
+ .n_gradient_accumulation = 1,
19934
+
19196
19935
  .lbfgs = {
19197
19936
  .m = 6,
19198
19937
  .n_iter = 100,
@@ -19223,13 +19962,32 @@ GGML_API void ggml_opt_init(
19223
19962
  opt->iter = 0;
19224
19963
  opt->nx = nx;
19225
19964
  opt->just_initialized = true;
19965
+ if (opt->ctx == NULL) {
19966
+ struct ggml_init_params ctx_opt_params;
19967
+ if (opt->params.type == GGML_OPT_ADAM) {
19968
+ ctx_opt_params.mem_size = GGML_MEM_ALIGN*3 + ggml_tensor_overhead()*3 + ggml_type_size(GGML_TYPE_F32)*nx*3;
19969
+ if (opt->params.past > 0) {
19970
+ ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
19971
+ }
19972
+ } else if (opt->params.type == GGML_OPT_LBFGS) {
19973
+ ctx_opt_params.mem_size = GGML_MEM_ALIGN*9 + ggml_tensor_overhead()*9 + ggml_type_size(GGML_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2);
19974
+ if (opt->params.past > 0) {
19975
+ ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
19976
+ }
19977
+ }
19978
+ ctx_opt_params.mem_buffer = NULL;
19979
+ ctx_opt_params.no_alloc = false;
19980
+
19981
+ opt->ctx = ggml_init(ctx_opt_params);
19982
+ }
19226
19983
  switch (opt->params.type) {
19227
19984
  case GGML_OPT_ADAM:
19228
19985
  {
19229
- opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19230
- opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19986
+ opt->adam.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
19987
+ opt->adam.m = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
19988
+ opt->adam.v = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
19231
19989
  opt->adam.pf = params.past > 0
19232
- ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
19990
+ ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
19233
19991
  : NULL;
19234
19992
  ggml_set_zero(opt->adam.m);
19235
19993
  ggml_set_zero(opt->adam.v);
@@ -19239,18 +19997,18 @@ GGML_API void ggml_opt_init(
19239
19997
  } break;
19240
19998
  case GGML_OPT_LBFGS:
19241
19999
  {
19242
- opt->lbfgs.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19243
- opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19244
- opt->lbfgs.g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19245
- opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
19246
- opt->lbfgs.d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
20000
+ opt->lbfgs.x = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
20001
+ opt->lbfgs.xp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
20002
+ opt->lbfgs.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
20003
+ opt->lbfgs.gp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
20004
+ opt->lbfgs.d = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
19247
20005
  opt->lbfgs.pf = params.past > 0
19248
- ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
20006
+ ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
19249
20007
  : NULL;
19250
- opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
19251
- opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
19252
- opt->lbfgs.lms = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
19253
- opt->lbfgs.lmy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
20008
+ opt->lbfgs.lmal = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
20009
+ opt->lbfgs.lmys = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
20010
+ opt->lbfgs.lms = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
20011
+ opt->lbfgs.lmy = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
19254
20012
  ggml_set_zero(opt->lbfgs.x);
19255
20013
  ggml_set_zero(opt->lbfgs.xp);
19256
20014
  ggml_set_zero(opt->lbfgs.g);
@@ -19856,10 +20614,10 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
19856
20614
  } break;
19857
20615
  case GGUF_TYPE_ARRAY:
19858
20616
  case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break;
19859
- };
20617
+ }
19860
20618
  } break;
19861
20619
  case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type");
19862
- };
20620
+ }
19863
20621
 
19864
20622
  if (!ok) {
19865
20623
  break;
@@ -20099,27 +20857,27 @@ const char * gguf_type_name(enum gguf_type type) {
20099
20857
  return GGUF_TYPE_NAME[type];
20100
20858
  }
20101
20859
 
20102
- int gguf_get_version(struct gguf_context * ctx) {
20860
+ int gguf_get_version(const struct gguf_context * ctx) {
20103
20861
  return ctx->header.version;
20104
20862
  }
20105
20863
 
20106
- size_t gguf_get_alignment(struct gguf_context * ctx) {
20864
+ size_t gguf_get_alignment(const struct gguf_context * ctx) {
20107
20865
  return ctx->alignment;
20108
20866
  }
20109
20867
 
20110
- size_t gguf_get_data_offset(struct gguf_context * ctx) {
20868
+ size_t gguf_get_data_offset(const struct gguf_context * ctx) {
20111
20869
  return ctx->offset;
20112
20870
  }
20113
20871
 
20114
- void * gguf_get_data(struct gguf_context * ctx) {
20872
+ void * gguf_get_data(const struct gguf_context * ctx) {
20115
20873
  return ctx->data;
20116
20874
  }
20117
20875
 
20118
- int gguf_get_n_kv(struct gguf_context * ctx) {
20876
+ int gguf_get_n_kv(const struct gguf_context * ctx) {
20119
20877
  return ctx->header.n_kv;
20120
20878
  }
20121
20879
 
20122
- int gguf_find_key(struct gguf_context * ctx, const char * key) {
20880
+ int gguf_find_key(const struct gguf_context * ctx, const char * key) {
20123
20881
  // return -1 if key not found
20124
20882
  int keyfound = -1;
20125
20883
 
@@ -20135,85 +20893,101 @@ int gguf_find_key(struct gguf_context * ctx, const char * key) {
20135
20893
  return keyfound;
20136
20894
  }
20137
20895
 
20138
- const char * gguf_get_key(struct gguf_context * ctx, int i) {
20139
- return ctx->kv[i].key.data;
20896
+ const char * gguf_get_key(const struct gguf_context * ctx, int key_id) {
20897
+ return ctx->kv[key_id].key.data;
20140
20898
  }
20141
20899
 
20142
- enum gguf_type gguf_get_kv_type(struct gguf_context * ctx, int i) {
20143
- return ctx->kv[i].type;
20900
+ enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) {
20901
+ return ctx->kv[key_id].type;
20144
20902
  }
20145
20903
 
20146
- enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i) {
20147
- return ctx->kv[i].value.arr.type;
20904
+ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) {
20905
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
20906
+ return ctx->kv[key_id].value.arr.type;
20148
20907
  }
20149
20908
 
20150
- const void * gguf_get_arr_data(struct gguf_context * ctx, int i) {
20151
- return ctx->kv[i].value.arr.data;
20909
+ const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) {
20910
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
20911
+ return ctx->kv[key_id].value.arr.data;
20152
20912
  }
20153
20913
 
20154
- const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i) {
20914
+ const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
20915
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
20155
20916
  struct gguf_kv * kv = &ctx->kv[key_id];
20156
20917
  struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
20157
20918
  return str->data;
20158
20919
  }
20159
20920
 
20160
- int gguf_get_arr_n(struct gguf_context * ctx, int i) {
20161
- return ctx->kv[i].value.arr.n;
20921
+ int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
20922
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
20923
+ return ctx->kv[key_id].value.arr.n;
20162
20924
  }
20163
20925
 
20164
- uint8_t gguf_get_val_u8(struct gguf_context * ctx, int i) {
20165
- return ctx->kv[i].value.uint8;
20926
+ uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) {
20927
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8);
20928
+ return ctx->kv[key_id].value.uint8;
20166
20929
  }
20167
20930
 
20168
- int8_t gguf_get_val_i8(struct gguf_context * ctx, int i) {
20169
- return ctx->kv[i].value.int8;
20931
+ int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) {
20932
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8);
20933
+ return ctx->kv[key_id].value.int8;
20170
20934
  }
20171
20935
 
20172
- uint16_t gguf_get_val_u16(struct gguf_context * ctx, int i) {
20173
- return ctx->kv[i].value.uint16;
20936
+ uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) {
20937
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16);
20938
+ return ctx->kv[key_id].value.uint16;
20174
20939
  }
20175
20940
 
20176
- int16_t gguf_get_val_i16(struct gguf_context * ctx, int i) {
20177
- return ctx->kv[i].value.int16;
20941
+ int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) {
20942
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16);
20943
+ return ctx->kv[key_id].value.int16;
20178
20944
  }
20179
20945
 
20180
- uint32_t gguf_get_val_u32(struct gguf_context * ctx, int i) {
20181
- return ctx->kv[i].value.uint32;
20946
+ uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) {
20947
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32);
20948
+ return ctx->kv[key_id].value.uint32;
20182
20949
  }
20183
20950
 
20184
- int32_t gguf_get_val_i32(struct gguf_context * ctx, int i) {
20185
- return ctx->kv[i].value.int32;
20951
+ int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) {
20952
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32);
20953
+ return ctx->kv[key_id].value.int32;
20186
20954
  }
20187
20955
 
20188
- float gguf_get_val_f32(struct gguf_context * ctx, int i) {
20189
- return ctx->kv[i].value.float32;
20956
+ float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) {
20957
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32);
20958
+ return ctx->kv[key_id].value.float32;
20190
20959
  }
20191
20960
 
20192
- uint64_t gguf_get_val_u64(struct gguf_context * ctx, int i) {
20193
- return ctx->kv[i].value.uint64;
20961
+ uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) {
20962
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64);
20963
+ return ctx->kv[key_id].value.uint64;
20194
20964
  }
20195
20965
 
20196
- int64_t gguf_get_val_i64(struct gguf_context * ctx, int i) {
20197
- return ctx->kv[i].value.int64;
20966
+ int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) {
20967
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64);
20968
+ return ctx->kv[key_id].value.int64;
20198
20969
  }
20199
20970
 
20200
- double gguf_get_val_f64(struct gguf_context * ctx, int i) {
20201
- return ctx->kv[i].value.float64;
20971
+ double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) {
20972
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64);
20973
+ return ctx->kv[key_id].value.float64;
20202
20974
  }
20203
20975
 
20204
- bool gguf_get_val_bool(struct gguf_context * ctx, int i) {
20205
- return ctx->kv[i].value.bool_;
20976
+ bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) {
20977
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL);
20978
+ return ctx->kv[key_id].value.bool_;
20206
20979
  }
20207
20980
 
20208
- const char * gguf_get_val_str (struct gguf_context * ctx, int i) {
20209
- return ctx->kv[i].value.str.data;
20981
+ const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
20982
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
20983
+ return ctx->kv[key_id].value.str.data;
20210
20984
  }
20211
20985
 
20212
- int gguf_get_n_tensors(struct gguf_context * ctx) {
20986
+ int gguf_get_n_tensors(const struct gguf_context * ctx) {
20213
20987
  return ctx->header.n_tensors;
20214
20988
  }
20215
20989
 
20216
- int gguf_find_tensor(struct gguf_context * ctx, const char * name) {
20990
+ int gguf_find_tensor(const struct gguf_context * ctx, const char * name) {
20217
20991
  // return -1 if tensor not found
20218
20992
  int tensorfound = -1;
20219
20993
 
@@ -20229,11 +21003,11 @@ int gguf_find_tensor(struct gguf_context * ctx, const char * name) {
20229
21003
  return tensorfound;
20230
21004
  }
20231
21005
 
20232
- size_t gguf_get_tensor_offset(struct gguf_context * ctx, int i) {
21006
+ size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i) {
20233
21007
  return ctx->infos[i].offset;
20234
21008
  }
20235
21009
 
20236
- char * gguf_get_tensor_name(struct gguf_context * ctx, int i) {
21010
+ char * gguf_get_tensor_name(const struct gguf_context * ctx, int i) {
20237
21011
  return ctx->infos[i].name.data;
20238
21012
  }
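Editor's note: each typed GGUF getter above now takes a const context and asserts the stored type before touching the value union, so calling e.g. gguf_get_val_str on an integer key aborts instead of silently reading the wrong union member. A small, hypothetical usage sketch (the key name is illustrative) that stays on the right side of the new assertions:

    const int key_id = gguf_find_key(ctx, "general.name");   // "general.name" is an illustrative key
    if (key_id >= 0 && gguf_get_kv_type(ctx, key_id) == GGUF_TYPE_STRING) {
        const char * name = gguf_get_val_str(ctx, key_id);    // safe: the type was checked first
        // ... use name ...
    }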
20239
21013
 
@@ -20516,7 +21290,7 @@ static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_si
20516
21290
  buf->offset += el_size;
20517
21291
  }
20518
21292
 
20519
- static void gguf_write_to_buf(struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
21293
+ static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
20520
21294
  // write header
20521
21295
  gguf_bwrite_el(buf, &ctx->header.magic, sizeof(ctx->header.magic));
20522
21296
  gguf_bwrite_el(buf, &ctx->header.version, sizeof(ctx->header.version));
@@ -20571,10 +21345,10 @@ static void gguf_write_to_buf(struct gguf_context * ctx, struct gguf_buf * buf,
20571
21345
  } break;
20572
21346
  case GGUF_TYPE_ARRAY:
20573
21347
  case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break;
20574
- };
21348
+ }
20575
21349
  } break;
20576
21350
  case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type");
20577
- };
21351
+ }
20578
21352
  }
20579
21353
 
20580
21354
  // write tensor infos
@@ -20631,7 +21405,7 @@ static void gguf_write_to_buf(struct gguf_context * ctx, struct gguf_buf * buf,
20631
21405
  }
20632
21406
  }
20633
21407
 
20634
- void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only_meta) {
21408
+ void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
20635
21409
  FILE * file = fopen(fname, "wb");
20636
21410
  if (!file) {
20637
21411
  GGML_ASSERT(false && "failed to open file for writing");
@@ -20648,7 +21422,7 @@ void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only
20648
21422
  fclose(file);
20649
21423
  }
20650
21424
 
20651
- size_t gguf_get_meta_size(struct gguf_context * ctx) {
21425
+ size_t gguf_get_meta_size(const struct gguf_context * ctx) {
20652
21426
  // no allocs - only compute size
20653
21427
  struct gguf_buf buf = gguf_buf_init(0);
20654
21428
 
@@ -20657,7 +21431,7 @@ size_t gguf_get_meta_size(struct gguf_context * ctx) {
20657
21431
  return buf.offset;
20658
21432
  }
20659
21433
 
20660
- void gguf_get_meta_data(struct gguf_context * ctx, void * data) {
21434
+ void gguf_get_meta_data(const struct gguf_context * ctx, void * data) {
20661
21435
  struct gguf_buf buf = gguf_buf_init(16*1024);
20662
21436
 
20663
21437
  gguf_write_to_buf(ctx, &buf, true);
@@ -20733,6 +21507,14 @@ int ggml_cpu_has_arm_fma(void) {
20733
21507
  #endif
20734
21508
  }
20735
21509
 
21510
+ int ggml_cpu_has_metal(void) {
21511
+ #if defined(GGML_USE_METAL)
21512
+ return 1;
21513
+ #else
21514
+ return 0;
21515
+ #endif
21516
+ }
21517
+
20736
21518
  int ggml_cpu_has_f16c(void) {
20737
21519
  #if defined(__F16C__)
20738
21520
  return 1;