llama_cpp 0.2.0 → 0.2.1 (diff of the bundled ggml source)

@@ -3603,6 +3603,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3603
3603
  "SUM_ROWS",
3604
3604
  "MEAN",
3605
3605
  "REPEAT",
3606
+ "REPEAT_BACK",
3606
3607
  "ABS",
3607
3608
  "SGN",
3608
3609
  "NEG",
@@ -3616,6 +3617,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3616
3617
  "RMS_NORM_BACK",
3617
3618
 
3618
3619
  "MUL_MAT",
3620
+ "OUT_PROD",
3619
3621
 
3620
3622
  "SCALE",
3621
3623
  "SET",
@@ -3631,6 +3633,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3631
3633
  "DIAG_MASK_INF",
3632
3634
  "DIAG_MASK_ZERO",
3633
3635
  "SOFT_MAX",
3636
+ "SOFT_MAX_BACK",
3634
3637
  "ROPE",
3635
3638
  "ROPE_BACK",
3636
3639
  "ALIBI",
@@ -3640,13 +3643,16 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3640
3643
 
3641
3644
  "FLASH_ATTN",
3642
3645
  "FLASH_FF",
3646
+ "FLASH_ATTN_BACK",
3643
3647
 
3644
3648
  "MAP_UNARY",
3645
3649
  "MAP_BINARY",
3646
- };
3647
3650
 
3648
- static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
3651
+ "CROSS_ENTROPY_LOSS",
3652
+ "CROSS_ENTROPY_LOSS_BACK",
3653
+ };
3649
3654
 
3655
+ static_assert(GGML_OP_COUNT == 57, "GGML_OP_COUNT != 57");
3650
3656
 
3651
3657
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3652
3658
  "none",
@@ -3665,6 +3671,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3665
3671
  "Σx_k",
3666
3672
  "Σx/n",
3667
3673
  "repeat(x)",
3674
+ "repeat_back(x)",
3668
3675
  "abs(x)",
3669
3676
  "sgn(x)",
3670
3677
  "-x",
@@ -3677,6 +3684,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3677
3684
  "rms_norm(x)",
3678
3685
  "rms_norm_back(x)",
3679
3686
 
3687
+ "X*Y",
3680
3688
  "X*Y",
3681
3689
 
3682
3690
  "x*v",
@@ -3693,6 +3701,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3693
3701
  "diag_mask_inf(x)",
3694
3702
  "diag_mask_zero(x)",
3695
3703
  "soft_max(x)",
3704
+ "soft_max_back(x)",
3696
3705
  "rope(x)",
3697
3706
  "rope_back(x)",
3698
3707
  "alibi(x)",
@@ -3702,12 +3711,16 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3702
3711
 
3703
3712
  "flash_attn(x)",
3704
3713
  "flash_ff(x)",
3714
+ "flash_attn_back(x)",
3705
3715
 
3706
3716
  "f(x)",
3707
3717
  "f(x,y)",
3718
+
3719
+ "cross_entropy_loss(x,y)",
3720
+ "cross_entropy_loss_back(x,y)",
3708
3721
  };
3709
3722
 
3710
- static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
3723
+ static_assert(GGML_OP_COUNT == 57, "GGML_OP_COUNT != 57");
3711
3724
 
3712
3725
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
3713
3726
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
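
The two string tables above are indexed directly by the ggml_op enum value, so they must stay in sync with GGML_OP_COUNT (now 57). A minimal hedged sketch of that indexing, using a hypothetical helper name that is not part of the diff:

    // hypothetical helper inside ggml.c, not in the diff: printable name of an op
    static const char * op_name(enum ggml_op op) {
        // valid because GGML_OP_NAME has exactly GGML_OP_COUNT entries,
        // which the static_asserts above now pin at 57
        return GGML_OP_NAME[op];
    }
    // op_name(GGML_OP_OUT_PROD) == "OUT_PROD", op_name(GGML_OP_SOFT_MAX_BACK) == "SOFT_MAX_BACK"
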
@@ -3870,6 +3883,15 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
3870
3883
  (t0->ne[3] == t1->ne[3]);
3871
3884
  }
3872
3885
 
3886
+ static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
3887
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3888
+
3889
+ return
3890
+ (t0->ne[1] == t1->ne[1]) &&
3891
+ (t0->ne[2] == t1->ne[2]) &&
3892
+ (t0->ne[3] == t1->ne[3]);
3893
+ }
3894
+
3873
3895
  bool ggml_is_quantized(enum ggml_type type) {
3874
3896
  return GGML_IS_QUANTIZED[type];
3875
3897
  }
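
The new shape check only constrains dims 1–3; dim 0 of each operand becomes a dim of the result. A hedged usage sketch (ctx and the concrete sizes are illustrative, not from the diff):

    // a: [m, k], b: [n, k]  ->  ggml_out_prod(ctx, a, b): [m, n]  (batch dims 2 and 3 must match)
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 16); // m = 8,  k = 16
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 16); // n = 4,  k = 16
    struct ggml_tensor * p = ggml_out_prod(ctx, a, b);                      // ne = {8, 4, 1, 1}
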
@@ -3917,6 +3939,12 @@ bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
3917
3939
  tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
3918
3940
  }
3919
3941
 
3942
+ bool ggml_is_permuted(const struct ggml_tensor * tensor) {
3943
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3944
+
3945
+ return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
3946
+ }
3947
+
3920
3948
  static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
3921
3949
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3922
3950
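
ggml_is_permuted flags tensors whose strides are no longer in non-decreasing order, which is exactly what ggml_permute produces. A hedged example (ctx assumed, not part of the diff):

    struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3); // contiguous: nb[0] <= nb[1]
    struct ggml_tensor * p = ggml_permute(ctx, t, 1, 0, 2, 3);             // strides reordered, data untouched
    // ggml_is_permuted(t) == false, ggml_is_permuted(p) == true
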
 
@@ -4693,7 +4721,7 @@ struct ggml_tensor * ggml_add_impl(
4693
4721
 
4694
4722
  bool is_node = false;
4695
4723
 
4696
- if (!inplace && (a->grad || b->grad)) {
4724
+ if (a->grad || b->grad) {
4697
4725
  is_node = true;
4698
4726
  }
4699
4727
 
@@ -4733,7 +4761,7 @@ struct ggml_tensor * ggml_add1_impl(
4733
4761
 
4734
4762
  bool is_node = false;
4735
4763
 
4736
- if (!inplace && (a->grad || b->grad)) {
4764
+ if (a->grad || b->grad) {
4737
4765
  is_node = true;
4738
4766
  }
4739
4767
 
@@ -5159,6 +5187,34 @@ struct ggml_tensor * ggml_repeat(
5159
5187
  return result;
5160
5188
  }
5161
5189
 
5190
+ // ggml_repeat_back
5191
+
5192
+ struct ggml_tensor * ggml_repeat_back(
5193
+ struct ggml_context * ctx,
5194
+ struct ggml_tensor * a,
5195
+ struct ggml_tensor * b) {
5196
+ GGML_ASSERT(ggml_can_repeat(b, a));
5197
+
5198
+ bool is_node = false;
5199
+
5200
+ if (a->grad) {
5201
+ is_node = true;
5202
+ }
5203
+
5204
+ if (ggml_are_same_shape(a, b) && !is_node) {
5205
+ return a;
5206
+ }
5207
+
5208
+ struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
5209
+
5210
+ result->op = GGML_OP_REPEAT_BACK;
5211
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5212
+ result->src0 = a;
5213
+ result->src1 = b;
5214
+
5215
+ return result;
5216
+ }
5217
+
5162
5218
  // ggml_abs
5163
5219
 
5164
5220
  struct ggml_tensor * ggml_abs_impl(
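
ggml_repeat_back is the reduction that undoes ggml_repeat: elements that were broadcast out are summed back into the smaller shape, which is what the gradient of REPEAT needs. A hedged usage sketch (ctx and sizes are illustrative, not from the diff):

    struct ggml_tensor * small = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 3);
    struct ggml_tensor * big   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 6);
    struct ggml_tensor * rep   = ggml_repeat(ctx, small, big);       // [2,3] broadcast to [4,6]
    // in the backward pass the gradient of rep (same [4,6] shape) is what gets passed here:
    struct ggml_tensor * red   = ggml_repeat_back(ctx, rep, small);  // summed back down to [2,3]
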
@@ -5536,6 +5592,32 @@ struct ggml_tensor * ggml_mul_mat(
5536
5592
  return result;
5537
5593
  }
5538
5594
 
5595
+ // ggml_out_prod
5596
+
5597
+ struct ggml_tensor * ggml_out_prod(
5598
+ struct ggml_context * ctx,
5599
+ struct ggml_tensor * a,
5600
+ struct ggml_tensor * b) {
5601
+ GGML_ASSERT(ggml_can_out_prod(a, b));
5602
+ GGML_ASSERT(!ggml_is_transposed(a));
5603
+
5604
+ bool is_node = false;
5605
+
5606
+ if (a->grad || b->grad) {
5607
+ is_node = true;
5608
+ }
5609
+
5610
+ const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
5611
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
5612
+
5613
+ result->op = GGML_OP_OUT_PROD;
5614
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5615
+ result->src0 = a;
5616
+ result->src1 = b;
5617
+
5618
+ return result;
5619
+ }
5620
+
5539
5621
  // ggml_scale
5540
5622
 
5541
5623
  struct ggml_tensor * ggml_scale_impl(
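
What OUT_PROD computes, written out as a hedged plain-C reference for the 2-D F32 case. The function name and the flat-array layout are my own assumptions, following the dst[i0,i1] += src0[i0,i01] * src1[i1,i01] comment in the kernel later in this diff:

    // dst[i0, i1] = sum over i01 of a[i0, i01] * b[i1, i01]   (i.e. A * B^T)
    static void out_prod_ref_2d(int m, int n, int k,
                                const float * a,    // m x k, dim 0 fastest
                                const float * b,    // n x k, dim 0 fastest
                                float       * dst)  // m x n, dim 0 fastest
    {
        for (int i1 = 0; i1 < n; ++i1) {
            for (int i0 = 0; i0 < m; ++i0) {
                float s = 0.0f;
                for (int i01 = 0; i01 < k; ++i01) {
                    s += a[i01*m + i0] * b[i01*n + i1];
                }
                dst[i1*m + i0] = s;
            }
        }
    }
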
@@ -5548,7 +5630,7 @@ struct ggml_tensor * ggml_scale_impl(
5548
5630
 
5549
5631
  bool is_node = false;
5550
5632
 
5551
- if (!inplace && (a->grad || b->grad)) {
5633
+ if (a->grad || b->grad) {
5552
5634
  is_node = true;
5553
5635
  }
5554
5636
 
@@ -5591,7 +5673,7 @@ struct ggml_tensor * ggml_set_impl(
5591
5673
 
5592
5674
  bool is_node = false;
5593
5675
 
5594
- if (!inplace && (a->grad || b->grad)) {
5676
+ if (a->grad || b->grad) {
5595
5677
  is_node = true;
5596
5678
  }
5597
5679
 
@@ -5913,10 +5995,6 @@ struct ggml_tensor * ggml_view_1d(
5913
5995
  result->src1 = NULL;
5914
5996
  result->opt[0] = offs;
5915
5997
 
5916
- if (is_node) {
5917
- memcpy(result->padding, &offset, sizeof(offset));
5918
- }
5919
-
5920
5998
  return result;
5921
5999
  }
5922
6000
 
@@ -5957,10 +6035,6 @@ struct ggml_tensor * ggml_view_2d(
5957
6035
  result->src1 = NULL;
5958
6036
  result->opt[0] = offs;
5959
6037
 
5960
- if (is_node) {
5961
- memcpy(result->padding, &offset, sizeof(offset));
5962
- }
5963
-
5964
6038
  return result;
5965
6039
  }
5966
6040
 
@@ -6003,10 +6077,6 @@ struct ggml_tensor * ggml_view_3d(
6003
6077
  result->src1 = NULL;
6004
6078
  result->opt[0] = offs;
6005
6079
 
6006
- if (is_node) {
6007
- memcpy(result->padding, &offset, sizeof(offset));
6008
- }
6009
-
6010
6080
  return result;
6011
6081
  }
6012
6082
 
@@ -6051,10 +6121,6 @@ struct ggml_tensor * ggml_view_4d(
6051
6121
  result->src1 = NULL;
6052
6122
  result->opt[0] = offs;
6053
6123
 
6054
- if (is_node) {
6055
- memcpy(result->padding, &offset, sizeof(offset));
6056
- }
6057
-
6058
6124
  return result;
6059
6125
  }
6060
6126
 
@@ -6116,10 +6182,18 @@ struct ggml_tensor * ggml_permute(
6116
6182
  result->src1 = NULL;
6117
6183
 
6118
6184
  if (is_node) {
6119
- result->padding[0] = axis0;
6120
- result->padding[1] = axis1;
6121
- result->padding[2] = axis2;
6122
- result->padding[3] = axis3;
6185
+ ggml_scratch_save(ctx);
6186
+
6187
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
6188
+
6189
+ ((int32_t *) b->data)[0] = axis0;
6190
+ ((int32_t *) b->data)[1] = axis1;
6191
+ ((int32_t *) b->data)[2] = axis2;
6192
+ ((int32_t *) b->data)[3] = axis3;
6193
+
6194
+ ggml_scratch_load(ctx);
6195
+
6196
+ result->opt[0] = b;
6123
6197
  }
6124
6198
 
6125
6199
  return result;
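
The permute axes used to be stashed in result->padding; they now live in a small I32 tensor at result->opt[0]. A hedged sketch of reading them back, e.g. in the backward pass (not part of the diff):

    const int32_t * axes = (const int32_t *) result->opt[0]->data;
    // axes[0] == axis0, axes[1] == axis1, axes[2] == axis2, axes[3] == axis3
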
@@ -6359,6 +6433,44 @@ struct ggml_tensor * ggml_soft_max_inplace(
6359
6433
  return ggml_soft_max_impl(ctx, a, true);
6360
6434
  }
6361
6435
 
6436
+
6437
+ // ggml_soft_max_back
6438
+
6439
+ struct ggml_tensor * ggml_soft_max_back_impl(
6440
+ struct ggml_context * ctx,
6441
+ struct ggml_tensor * a,
6442
+ struct ggml_tensor * b,
6443
+ bool inplace) {
6444
+ bool is_node = false;
6445
+
6446
+ if (a->grad || b->grad) {
6447
+ is_node = true; // TODO : implement backward pass
6448
+ }
6449
+
6450
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
6451
+
6452
+ result->op = GGML_OP_SOFT_MAX_BACK;
6453
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6454
+ result->src0 = a;
6455
+ result->src1 = b;
6456
+
6457
+ return result;
6458
+ }
6459
+
6460
+ struct ggml_tensor * ggml_soft_max_back(
6461
+ struct ggml_context * ctx,
6462
+ struct ggml_tensor * a,
6463
+ struct ggml_tensor * b) {
6464
+ return ggml_soft_max_back_impl(ctx, a, b, false);
6465
+ }
6466
+
6467
+ struct ggml_tensor * ggml_soft_max_back_inplace(
6468
+ struct ggml_context * ctx,
6469
+ struct ggml_tensor * a,
6470
+ struct ggml_tensor * b) {
6471
+ return ggml_soft_max_back_impl(ctx, a, b, true);
6472
+ }
6473
+
6362
6474
  // ggml_rope
6363
6475
 
6364
6476
  struct ggml_tensor * ggml_rope_impl(
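
SOFT_MAX_BACK applies the softmax Jacobian to an upstream gradient: dx_k = y_k * (dy_k - dot(y, dy)). A hedged single-row reference (the f32 kernel later in this diff takes dy as src0 and y as src1; the name here is mine):

    static void soft_max_back_ref(int n, const float * dy, const float * y, float * dx) {
        float dot_y_dy = 0.0f;
        for (int i = 0; i < n; ++i) {
            dot_y_dy += y[i]*dy[i];
        }
        for (int i = 0; i < n; ++i) {
            dx[i] = y[i]*(dy[i] - dot_y_dy);
        }
    }
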
@@ -6371,7 +6483,7 @@ struct ggml_tensor * ggml_rope_impl(
6371
6483
  GGML_ASSERT(n_past >= 0);
6372
6484
  bool is_node = false;
6373
6485
 
6374
- if (!inplace && a->grad) {
6486
+ if (a->grad) {
6375
6487
  is_node = true;
6376
6488
  }
6377
6489
 
@@ -6425,8 +6537,7 @@ struct ggml_tensor * ggml_rope_back(
6425
6537
  bool is_node = false;
6426
6538
 
6427
6539
  if (a->grad) {
6428
- GGML_ASSERT(false); // TODO: implement backward
6429
- is_node = true;
6540
+ is_node = false; // TODO: implement backward
6430
6541
  }
6431
6542
 
6432
6543
  struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
@@ -6591,7 +6702,6 @@ struct ggml_tensor * ggml_flash_attn(
6591
6702
  bool is_node = false;
6592
6703
 
6593
6704
  if (q->grad || k->grad || v->grad) {
6594
- GGML_ASSERT(false); // TODO: implement backward
6595
6705
  is_node = true;
6596
6706
  }
6597
6707
 
@@ -6623,7 +6733,6 @@ struct ggml_tensor * ggml_flash_ff(
6623
6733
  bool is_node = false;
6624
6734
 
6625
6735
  if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
6626
- GGML_ASSERT(false); // TODO: implement backward
6627
6736
  is_node = true;
6628
6737
  }
6629
6738
 
@@ -6641,6 +6750,71 @@ struct ggml_tensor * ggml_flash_ff(
6641
6750
  return result;
6642
6751
  }
6643
6752
 
6753
+ // ggml_flash_attn_back
6754
+
6755
+ struct ggml_tensor * ggml_flash_attn_back(
6756
+ struct ggml_context * ctx,
6757
+ struct ggml_tensor * q,
6758
+ struct ggml_tensor * k,
6759
+ struct ggml_tensor * v,
6760
+ struct ggml_tensor * d,
6761
+ bool masked) {
6762
+ GGML_ASSERT(ggml_can_mul_mat(k, q));
6763
+ // TODO: check if vT can be multiplied by (k*qT)
6764
+
6765
+ // d shape [D,N,ne2,ne3]
6766
+ // q shape [D,N,ne2,ne3]
6767
+ // k shape [D,M,ne2,ne3]
6768
+ // v shape [M,D,ne2,ne3]
6769
+
6770
+ const int64_t D = q->ne[0];
6771
+ const int64_t N = q->ne[1];
6772
+ const int64_t M = k->ne[1];
6773
+ const int64_t ne2 = q->ne[2];
6774
+ const int64_t ne3 = q->ne[3];
6775
+
6776
+ GGML_ASSERT(k->ne[0] == D);
6777
+ GGML_ASSERT(v->ne[0] == M);
6778
+ GGML_ASSERT(v->ne[1] == D);
6779
+ GGML_ASSERT(d->ne[0] == D);
6780
+ GGML_ASSERT(d->ne[1] == N);
6781
+ GGML_ASSERT(k->ne[2] == ne2);
6782
+ GGML_ASSERT(k->ne[3] == ne3);
6783
+ GGML_ASSERT(v->ne[2] == ne2);
6784
+ GGML_ASSERT(v->ne[3] == ne3);
6785
+ GGML_ASSERT(d->ne[2] == ne2);
6786
+ GGML_ASSERT(d->ne[3] == ne3);
6787
+
6788
+ bool is_node = false;
6789
+
6790
+ if (q->grad || k->grad || v->grad) {
6791
+ // when using this operation (in backwards pass) these grads are set.
6792
+ // we don't want to create (big) grad of our result, so is_node is false.
6793
+ is_node = false;
6794
+ }
6795
+
6796
+ // store gradients of q, k and v as continuous tensors concatenated in result.
6797
+ // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
6798
+ // gradq->data = result->data
6799
+ // gradk->data = result->data + nb0*D*N*ne2*ne3
6800
+ // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
6801
+ // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
6802
+ int64_t ne[4] = {D,M+N+M,ne2,ne3};
6803
+
6804
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
6805
+
6806
+ result->op = GGML_OP_FLASH_ATTN_BACK;
6807
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6808
+ result->src0 = q;
6809
+ result->src1 = k;
6810
+ result->opt[0] = v;
6811
+ result->opt[1] = d;
6812
+ result->opt[2] = ggml_new_i32(ctx, masked ? 1 : 0);
6813
+
6814
+ return result;
6815
+ }
6816
+
6817
+
6644
6818
  // ggml_map_unary
6645
6819
 
6646
6820
  struct ggml_tensor * ggml_map_unary_impl_f32(
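
The backward op packs grad[q], grad[k] and grad[v] into a single F32 result; the offsets follow the layout comment above. A hedged sketch of slicing them back out with 1-D views (q, k, v, d, ctx and the sizes D, N, M, ne2, ne3 are assumed from the shape comment; none of this is in the diff):

    struct ggml_tensor * r = ggml_flash_attn_back(ctx, q, k, v, d, /*masked*/ true);
    const size_t es = ggml_element_size(r); // sizeof(float) for this F32 result
    struct ggml_tensor * grad_q = ggml_view_1d(ctx, r, D*N*ne2*ne3, 0);
    struct ggml_tensor * grad_k = ggml_view_1d(ctx, r, D*M*ne2*ne3, es*D*N*ne2*ne3);
    struct ggml_tensor * grad_v = ggml_view_1d(ctx, r, M*D*ne2*ne3, es*(D*N + D*M)*ne2*ne3);
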
@@ -6725,6 +6899,50 @@ struct ggml_tensor * ggml_map_binary_inplace_f32(
6725
6899
  return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
6726
6900
  }
6727
6901
 
6902
+ // ggml_cross_entropy_loss
6903
+
6904
+ struct ggml_tensor * ggml_cross_entropy_loss(
6905
+ struct ggml_context * ctx,
6906
+ struct ggml_tensor * a,
6907
+ struct ggml_tensor * b) {
6908
+ GGML_ASSERT(ggml_are_same_shape(a, b));
6909
+ bool is_node = false;
6910
+
6911
+ if (a->grad || b->grad) {
6912
+ is_node = true;
6913
+ }
6914
+
6915
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
6916
+
6917
+ result->op = GGML_OP_CROSS_ENTROPY_LOSS;
6918
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6919
+ result->src0 = a;
6920
+ result->src1 = b;
6921
+
6922
+ return result;
6923
+ }
6924
+
6925
+ // ggml_cross_entropy_loss_back
6926
+
6927
+ struct ggml_tensor * ggml_cross_entropy_loss_back(
6928
+ struct ggml_context * ctx,
6929
+ struct ggml_tensor * a,
6930
+ struct ggml_tensor * b,
6931
+ struct ggml_tensor * c) {
6932
+ GGML_ASSERT(ggml_are_same_shape(a, b));
6933
+ GGML_ASSERT(ggml_is_scalar(c));
6934
+
6935
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
6936
+
6937
+ result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
6938
+ result->grad = NULL;
6939
+ result->src0 = a;
6940
+ result->src1 = b;
6941
+ result->opt[0] = c;
6942
+
6943
+ return result;
6944
+ }
6945
+
6728
6946
  ////////////////////////////////////////////////////////////////////////////////
6729
6947
 
6730
6948
  void ggml_set_param(
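
GGML_OP_CROSS_ENTROPY_LOSS reduces to a single scalar: a softmax over each row of a, re-scaled into [eps, 1] to avoid log(0), weighted by b, summed, and negated. A hedged single-row reference in plain C (name and signature are mine, mirroring the f32 kernel further down in this diff):

    #include <math.h>

    //   loss = - sum_i  target[i] * log( eps + (1 - eps) * softmax(logits)[i] )
    static float cross_entropy_loss_ref(int n, const float * logits, const float * target) {
        const float eps = 1e-9f;
        float max = logits[0];
        for (int i = 1; i < n; ++i) if (logits[i] > max) max = logits[i];
        float sum = 0.0f;
        for (int i = 0; i < n; ++i) sum += expf(logits[i] - max);
        float loss = 0.0f;
        for (int i = 0; i < n; ++i) {
            const float p = eps + (1.0f - eps)*expf(logits[i] - max)/sum;
            loss -= target[i]*logf(p);
        }
        return loss;
    }
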
@@ -8875,6 +9093,99 @@ static void ggml_compute_forward_repeat(
8875
9093
  }
8876
9094
  }
8877
9095
 
9096
+ // ggml_compute_forward_repeat_back
9097
+
9098
+ static void ggml_compute_forward_repeat_back_f32(
9099
+ const struct ggml_compute_params * params,
9100
+ const struct ggml_tensor * src0,
9101
+ struct ggml_tensor * dst) {
9102
+ GGML_ASSERT(params->ith == 0);
9103
+ GGML_ASSERT(ggml_can_repeat(dst, src0));
9104
+
9105
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
9106
+ return;
9107
+ }
9108
+
9109
+ const int64_t ne0 = dst->ne[0];
9110
+ const int64_t ne1 = dst->ne[1];
9111
+ const int64_t ne2 = dst->ne[2];
9112
+ const int64_t ne3 = dst->ne[3];
9113
+
9114
+ const int64_t ne00 = src0->ne[0];
9115
+ const int64_t ne01 = src0->ne[1];
9116
+ const int64_t ne02 = src0->ne[2];
9117
+ const int64_t ne03 = src0->ne[3];
9118
+
9119
+ const size_t nb0 = dst->nb[0];
9120
+ const size_t nb1 = dst->nb[1];
9121
+ const size_t nb2 = dst->nb[2];
9122
+ const size_t nb3 = dst->nb[3];
9123
+
9124
+ const size_t nb00 = src0->nb[0];
9125
+ const size_t nb01 = src0->nb[1];
9126
+ const size_t nb02 = src0->nb[2];
9127
+ const size_t nb03 = src0->nb[3];
9128
+
9129
+ // guaranteed to be an integer due to the check in ggml_can_repeat
9130
+ const int nr0 = (int)(ne00/ne0);
9131
+ const int nr1 = (int)(ne01/ne1);
9132
+ const int nr2 = (int)(ne02/ne2);
9133
+ const int nr3 = (int)(ne03/ne3);
9134
+
9135
+ // TODO: support for transposed / permuted tensors
9136
+ GGML_ASSERT(nb0 == sizeof(float));
9137
+ GGML_ASSERT(nb00 == sizeof(float));
9138
+
9139
+ if (ggml_is_contiguous(dst)) {
9140
+ ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
9141
+ } else {
9142
+ for (int k3 = 0; k3 < ne3; k3++) {
9143
+ for (int k2 = 0; k2 < ne2; k2++) {
9144
+ for (int k1 = 0; k1 < ne1; k1++) {
9145
+ ggml_vec_set_f32(ne0,
9146
+ (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
9147
+ 0);
9148
+ }
9149
+ }
9150
+ }
9151
+ }
9152
+
9153
+ // TODO: maybe this is not optimal?
9154
+ for (int i3 = 0; i3 < nr3; i3++) {
9155
+ for (int k3 = 0; k3 < ne3; k3++) {
9156
+ for (int i2 = 0; i2 < nr2; i2++) {
9157
+ for (int k2 = 0; k2 < ne2; k2++) {
9158
+ for (int i1 = 0; i1 < nr1; i1++) {
9159
+ for (int k1 = 0; k1 < ne1; k1++) {
9160
+ for (int i0 = 0; i0 < nr0; i0++) {
9161
+ ggml_vec_acc_f32(ne0,
9162
+ (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1),
9163
+ (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
9164
+ }
9165
+ }
9166
+ }
9167
+ }
9168
+ }
9169
+ }
9170
+ }
9171
+ }
9172
+
9173
+ static void ggml_compute_forward_repeat_back(
9174
+ const struct ggml_compute_params * params,
9175
+ const struct ggml_tensor * src0,
9176
+ struct ggml_tensor * dst) {
9177
+ switch (src0->type) {
9178
+ case GGML_TYPE_F32:
9179
+ {
9180
+ ggml_compute_forward_repeat_back_f32(params, src0, dst);
9181
+ } break;
9182
+ default:
9183
+ {
9184
+ GGML_ASSERT(false);
9185
+ } break;
9186
+ }
9187
+ }
9188
+
8878
9189
  // ggml_compute_forward_abs
8879
9190
 
8880
9191
  static void ggml_compute_forward_abs_f32(
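
A small worked example of the accumulation above (not from the diff): with dst of length 3 and src0 = [1, 2, 3, 4, 5, 6] produced by repeating a length-3 row twice (nr0 = 2), repeat_back first zeroes dst and then accumulates both copies, giving dst = [1+4, 2+5, 3+6] = [5, 7, 9].
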
@@ -10249,18 +10560,188 @@ static void ggml_compute_forward_mul_mat(
10249
10560
  }
10250
10561
  }
10251
10562
 
10252
- // ggml_compute_forward_scale
10563
+ // ggml_compute_forward_out_prod
10253
10564
 
10254
- static void ggml_compute_forward_scale_f32(
10565
+
10566
+ static void ggml_compute_forward_out_prod_f32(
10255
10567
  const struct ggml_compute_params * params,
10256
10568
  const struct ggml_tensor * src0,
10257
10569
  const struct ggml_tensor * src1,
10258
- struct ggml_tensor * dst) {
10259
- GGML_ASSERT(ggml_is_contiguous(src0));
10260
- GGML_ASSERT(ggml_is_contiguous(dst));
10261
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
10262
- GGML_ASSERT(ggml_is_scalar(src1));
10263
-
10570
+ struct ggml_tensor * dst) {
10571
+ int64_t t0 = ggml_perf_time_us();
10572
+ UNUSED(t0);
10573
+
10574
+ const int64_t ne00 = src0->ne[0];
10575
+ const int64_t ne01 = src0->ne[1];
10576
+ const int64_t ne02 = src0->ne[2];
10577
+ const int64_t ne03 = src0->ne[3];
10578
+
10579
+ const int64_t ne10 = src1->ne[0];
10580
+ //const int64_t ne11 = src1->ne[1];
10581
+ const int64_t ne12 = src1->ne[2];
10582
+ const int64_t ne13 = src1->ne[3];
10583
+
10584
+ const int64_t ne0 = dst->ne[0];
10585
+ const int64_t ne1 = dst->ne[1];
10586
+ const int64_t ne2 = dst->ne[2];
10587
+ const int64_t ne3 = dst->ne[3];
10588
+
10589
+ const int nb00 = src0->nb[0];
10590
+ const int nb01 = src0->nb[1];
10591
+ const int nb02 = src0->nb[2];
10592
+ const int nb03 = src0->nb[3];
10593
+
10594
+ const int nb10 = src1->nb[0];
10595
+ const int nb11 = src1->nb[1];
10596
+ const int nb12 = src1->nb[2];
10597
+ const int nb13 = src1->nb[3];
10598
+
10599
+ const int nb0 = dst->nb[0];
10600
+ const int nb1 = dst->nb[1];
10601
+ const int nb2 = dst->nb[2];
10602
+ const int nb3 = dst->nb[3];
10603
+
10604
+ const int ith = params->ith;
10605
+ const int nth = params->nth;
10606
+
10607
+ GGML_ASSERT(ne02 == ne12);
10608
+ GGML_ASSERT(ne03 == ne13);
10609
+ GGML_ASSERT(ne2 == ne12);
10610
+ GGML_ASSERT(ne3 == ne13);
10611
+
10612
+ // we don't support permuted src0 or src1
10613
+ GGML_ASSERT(nb00 == sizeof(float));
10614
+
10615
+ // dst cannot be transposed or permuted
10616
+ GGML_ASSERT(nb0 == sizeof(float));
10617
+ // GGML_ASSERT(nb0 <= nb1);
10618
+ // GGML_ASSERT(nb1 <= nb2);
10619
+ // GGML_ASSERT(nb2 <= nb3);
10620
+
10621
+ GGML_ASSERT(ne0 == ne00);
10622
+ GGML_ASSERT(ne1 == ne10);
10623
+ GGML_ASSERT(ne2 == ne02);
10624
+ GGML_ASSERT(ne3 == ne03);
10625
+
10626
+ // nb01 >= nb00 - src0 is not transposed
10627
+ // compute by src0 rows
10628
+
10629
+ // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
10630
+ // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
10631
+
10632
+ if (params->type == GGML_TASK_INIT) {
10633
+ ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
10634
+ return;
10635
+ }
10636
+
10637
+ if (params->type == GGML_TASK_FINALIZE) {
10638
+ return;
10639
+ }
10640
+
10641
+ // parallelize by last three dimensions
10642
+
10643
+ // total rows in dst
10644
+ const int64_t nr = ne1*ne2*ne3;
10645
+
10646
+ // rows per thread
10647
+ const int64_t dr = (nr + nth - 1)/nth;
10648
+
10649
+ // row range for this thread
10650
+ const int64_t ir0 = dr*ith;
10651
+ const int64_t ir1 = MIN(ir0 + dr, nr);
10652
+
10653
+ // dst[:,:,:,:] = 0
10654
+ // for i2,i3:
10655
+ // for i1:
10656
+ // for i01:
10657
+ // for i0:
10658
+ // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
10659
+
10660
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
10661
+ // dst indices
10662
+ const int64_t i3 = ir/(ne2*ne1);
10663
+ const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
10664
+ const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
10665
+
10666
+ const int64_t i02 = i2;
10667
+ const int64_t i03 = i3;
10668
+
10669
+ //const int64_t i10 = i1;
10670
+ const int64_t i12 = i2;
10671
+ const int64_t i13 = i3;
10672
+
10673
+ for (int64_t i01 = 0; i01 < ne01; ++i01) {
10674
+ const int64_t i11 = i01;
10675
+
10676
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
10677
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
10678
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
10679
+
10680
+ ggml_vec_mad_f32(ne0, d, s0, *s1);
10681
+ // for (int64_t i0 = 0; i0 < ne0; ++i0) {
10682
+ // d[i0] += s0[i0] * s1[i1];
10683
+ // }
10684
+ }
10685
+ }
10686
+
10687
+ //int64_t t1 = ggml_perf_time_us();
10688
+ //static int64_t acc = 0;
10689
+ //acc += t1 - t0;
10690
+ //if (t1 - t0 > 10) {
10691
+ // printf("\n");
10692
+ // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
10693
+ // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
10694
+ // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
10695
+ // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
10696
+
10697
+ // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
10698
+ //}
10699
+ }
10700
+
10701
+ static void ggml_compute_forward_out_prod(
10702
+ const struct ggml_compute_params * params,
10703
+ const struct ggml_tensor * src0,
10704
+ const struct ggml_tensor * src1,
10705
+ struct ggml_tensor * dst) {
10706
+ switch (src0->type) {
10707
+ case GGML_TYPE_Q4_0:
10708
+ case GGML_TYPE_Q4_1:
10709
+ case GGML_TYPE_Q5_0:
10710
+ case GGML_TYPE_Q5_1:
10711
+ case GGML_TYPE_Q8_0:
10712
+ case GGML_TYPE_Q8_1:
10713
+ {
10714
+ GGML_ASSERT(false); // todo
10715
+ // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
10716
+ } break;
10717
+ case GGML_TYPE_F16:
10718
+ {
10719
+ GGML_ASSERT(false); // todo
10720
+ // ggml_compute_forward_out_prod_f16_f32(params, src0, src1, dst);
10721
+ } break;
10722
+ case GGML_TYPE_F32:
10723
+ {
10724
+ ggml_compute_forward_out_prod_f32(params, src0, src1, dst);
10725
+ } break;
10726
+ default:
10727
+ {
10728
+ GGML_ASSERT(false);
10729
+ } break;
10730
+ }
10731
+ }
10732
+
10733
+ // ggml_compute_forward_scale
10734
+
10735
+ static void ggml_compute_forward_scale_f32(
10736
+ const struct ggml_compute_params * params,
10737
+ const struct ggml_tensor * src0,
10738
+ const struct ggml_tensor * src1,
10739
+ struct ggml_tensor * dst) {
10740
+ GGML_ASSERT(ggml_is_contiguous(src0));
10741
+ GGML_ASSERT(ggml_is_contiguous(dst));
10742
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
10743
+ GGML_ASSERT(ggml_is_scalar(src1));
10744
+
10264
10745
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
10265
10746
  return;
10266
10747
  }
@@ -10671,7 +11152,11 @@ static void ggml_compute_forward_get_rows_back_f32(
10671
11152
  GGML_ASSERT(ggml_is_contiguous(opt0));
10672
11153
  GGML_ASSERT(ggml_is_contiguous(dst));
10673
11154
 
10674
- ggml_compute_forward_dup_same_cont(params, opt0, dst);
11155
+ // ggml_compute_forward_dup_same_cont(params, opt0, dst);
11156
+
11157
+ if (params->type == GGML_TASK_INIT) {
11158
+ memset(dst->data, 0, ggml_nbytes(dst));
11159
+ }
10675
11160
 
10676
11161
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
10677
11162
  return;
@@ -10815,8 +11300,8 @@ static void ggml_compute_forward_diag_mask_f32(
10815
11300
  const struct ggml_tensor * src1,
10816
11301
  struct ggml_tensor * dst,
10817
11302
  const float value) {
10818
- assert(src1->type == GGML_TYPE_I32);
10819
- assert(ggml_nelements(src1) == 2);
11303
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
11304
+ GGML_ASSERT(ggml_nelements(src1) == 2);
10820
11305
 
10821
11306
  const int ith = params->ith;
10822
11307
  const int nth = params->nth;
@@ -10824,7 +11309,7 @@ static void ggml_compute_forward_diag_mask_f32(
10824
11309
  const int n_past = ((int32_t *) src1->data)[0];
10825
11310
  const bool inplace = (bool)((int32_t *) src1->data)[1];
10826
11311
 
10827
- assert(n_past >= 0);
11312
+ GGML_ASSERT(n_past >= 0);
10828
11313
 
10829
11314
  if (!inplace && (params->type == GGML_TASK_INIT)) {
10830
11315
  // memcpy needs to be synchronized across threads to avoid race conditions.
@@ -10848,8 +11333,8 @@ static void ggml_compute_forward_diag_mask_f32(
10848
11333
  const int nr = src0->ne[1];
10849
11334
  const int nz = n/nr;
10850
11335
 
10851
- assert( dst->nb[0] == sizeof(float));
10852
- assert(src0->nb[0] == sizeof(float));
11336
+ GGML_ASSERT( dst->nb[0] == sizeof(float));
11337
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
10853
11338
 
10854
11339
  for (int k = 0; k < nz; k++) {
10855
11340
  for (int j = ith; j < nr; j += nth) {
@@ -10985,6 +11470,101 @@ static void ggml_compute_forward_soft_max(
10985
11470
  }
10986
11471
  }
10987
11472
 
11473
+ // ggml_compute_forward_soft_max_back
11474
+
11475
+ static void ggml_compute_forward_soft_max_back_f32(
11476
+ const struct ggml_compute_params * params,
11477
+ const struct ggml_tensor * src0,
11478
+ const struct ggml_tensor * src1,
11479
+ struct ggml_tensor * dst) {
11480
+ GGML_ASSERT(ggml_is_contiguous(src0));
11481
+ GGML_ASSERT(ggml_is_contiguous(src1));
11482
+ GGML_ASSERT(ggml_is_contiguous(dst));
11483
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
11484
+ GGML_ASSERT(ggml_are_same_shape(src1, dst));
11485
+
11486
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
11487
+ return;
11488
+ }
11489
+
11490
+ // TODO: handle transposed/permuted matrices
11491
+
11492
+ const int ith = params->ith;
11493
+ const int nth = params->nth;
11494
+
11495
+ const int nc = src0->ne[0];
11496
+ const int nr = ggml_nrows(src0);
11497
+
11498
+ // rows per thread
11499
+ const int dr = (nr + nth - 1)/nth;
11500
+
11501
+ // row range for this thread
11502
+ const int ir0 = dr*ith;
11503
+ const int ir1 = MIN(ir0 + dr, nr);
11504
+
11505
+ for (int i1 = ir0; i1 < ir1; i1++) {
11506
+ float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
11507
+ float *y = (float *)((char *) src1->data + i1*src1->nb[1]);
11508
+ float *dx = (float *)((char *) dst->data + i1*dst->nb[1]);
11509
+
11510
+ #ifndef NDEBUG
11511
+ for (int i = 0; i < nc; ++i) {
11512
+ //printf("p[%d] = %f\n", i, p[i]);
11513
+ assert(!isnan(dy[i]));
11514
+ assert(!isnan(y[i]));
11515
+ }
11516
+ #endif
11517
+ // Jii = yi - yi*yi
11518
+ // Jij = -yi*yj
11519
+ // J = diag(y)-y.T*y
11520
+ // dx = J * dy
11521
+ // dxk = sum_i(Jki * dyi)
11522
+ // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
11523
+ // dxk = sum_i(-yk*yi * dyi) + yk*dyk
11524
+ // dxk = -yk * sum_i(yi * dyi) + yk*dyk
11525
+ // dxk = -yk * dot(y, dy) + yk*dyk
11526
+ // dxk = yk * (- dot(y, dy) + dyk)
11527
+ // dxk = yk * (dyk - dot(y, dy))
11528
+ //
11529
+ // post-order:
11530
+ // dot_y_dy := dot(y, dy)
11531
+ // dx := dy
11532
+ // dx := dx - dot_y_dy
11533
+ // dx := dx * y
11534
+
11535
+ // linear runtime, no additional memory
11536
+ float dot_y_dy = 0;
11537
+ ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy);
11538
+ ggml_vec_cpy_f32 (nc, dx, dy);
11539
+ ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
11540
+ ggml_vec_mul_f32 (nc, dx, dx, y);
11541
+
11542
+ #ifndef NDEBUG
11543
+ for (int i = 0; i < nc; ++i) {
11544
+ assert(!isnan(dx[i]));
11545
+ assert(!isinf(dx[i]));
11546
+ }
11547
+ #endif
11548
+ }
11549
+ }
11550
+
11551
+ static void ggml_compute_forward_soft_max_back(
11552
+ const struct ggml_compute_params * params,
11553
+ const struct ggml_tensor * src0,
11554
+ const struct ggml_tensor * src1,
11555
+ struct ggml_tensor * dst) {
11556
+ switch (src0->type) {
11557
+ case GGML_TYPE_F32:
11558
+ {
11559
+ ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst);
11560
+ } break;
11561
+ default:
11562
+ {
11563
+ GGML_ASSERT(false);
11564
+ } break;
11565
+ }
11566
+ }
11567
+
10988
11568
  // ggml_compute_forward_alibi
10989
11569
 
10990
11570
  static void ggml_compute_forward_alibi_f32(
@@ -12938,42 +13518,616 @@ static void ggml_compute_forward_flash_ff(
12938
13518
  }
12939
13519
  }
12940
13520
 
12941
- // ggml_compute_forward_map_unary
13521
+ // ggml_compute_forward_flash_attn_back
12942
13522
 
12943
- static void ggml_compute_forward_map_unary_f32(
13523
+ static void ggml_compute_forward_flash_attn_back_f32(
12944
13524
  const struct ggml_compute_params * params,
12945
- const struct ggml_tensor * src0,
12946
- struct ggml_tensor * dst,
12947
- const ggml_unary_op_f32_t fun) {
12948
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
13525
+ const struct ggml_tensor * q,
13526
+ const struct ggml_tensor * k,
13527
+ const struct ggml_tensor * v,
13528
+ const struct ggml_tensor * d,
13529
+ const bool masked,
13530
+ struct ggml_tensor * dst) {
13531
+ int64_t t0 = ggml_perf_time_us();
13532
+ UNUSED(t0);
12949
13533
 
12950
- if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
12951
- return;
12952
- }
13534
+ const int64_t neq0 = q->ne[0];
13535
+ const int64_t neq1 = q->ne[1];
13536
+ const int64_t neq2 = q->ne[2];
13537
+ const int64_t neq3 = q->ne[3];
12953
13538
 
12954
- const int n = ggml_nrows(src0);
12955
- const int nc = src0->ne[0];
13539
+ const int64_t nek0 = k->ne[0];
13540
+ const int64_t nek1 = k->ne[1];
13541
+ //const int64_t nek2 = k->ne[2];
13542
+ //const int64_t nek3 = k->ne[3];
12956
13543
 
12957
- assert( dst->nb[0] == sizeof(float));
12958
- assert(src0->nb[0] == sizeof(float));
13544
+ const int64_t nev0 = v->ne[0];
13545
+ const int64_t nev1 = v->ne[1];
13546
+ //const int64_t nev2 = v->ne[2];
13547
+ //const int64_t nev3 = v->ne[3];
12959
13548
 
12960
- for (int i = 0; i < n; i++) {
12961
- fun(nc,
12962
- (float *) ((char *) dst->data + i*( dst->nb[1])),
12963
- (float *) ((char *) src0->data + i*(src0->nb[1])));
12964
- }
12965
- }
13549
+ const int64_t ned0 = d->ne[0];
13550
+ const int64_t ned1 = d->ne[1];
13551
+ //const int64_t ned2 = d->ne[2];
13552
+ //const int64_t ned3 = d->ne[3];
12966
13553
 
13554
+ const int64_t ne0 = dst->ne[0];
13555
+ const int64_t ne1 = dst->ne[1];
13556
+ const int64_t ne2 = dst->ne[2];
13557
+ const int64_t ne3 = dst->ne[3];
12967
13558
 
12968
- static void ggml_compute_forward_map_unary(
12969
- const struct ggml_compute_params * params,
12970
- const struct ggml_tensor * src0,
12971
- struct ggml_tensor * dst,
12972
- const ggml_unary_op_f32_t fun) {
13559
+ const int nbk0 = k->nb[0];
13560
+ const int nbk1 = k->nb[1];
13561
+ const int nbk2 = k->nb[2];
13562
+ const int nbk3 = k->nb[3];
13563
+
13564
+ const int nbq0 = q->nb[0];
13565
+ const int nbq1 = q->nb[1];
13566
+ const int nbq2 = q->nb[2];
13567
+ const int nbq3 = q->nb[3];
13568
+
13569
+ const int nbv0 = v->nb[0];
13570
+ const int nbv1 = v->nb[1];
13571
+ const int nbv2 = v->nb[2];
13572
+ const int nbv3 = v->nb[3];
13573
+
13574
+ const int nbd0 = d->nb[0];
13575
+ const int nbd1 = d->nb[1];
13576
+ const int nbd2 = d->nb[2];
13577
+ const int nbd3 = d->nb[3];
13578
+
13579
+ const int nb0 = dst->nb[0];
13580
+ const int nb1 = dst->nb[1];
13581
+ const int nb2 = dst->nb[2];
13582
+ const int nb3 = dst->nb[3];
13583
+
13584
+ const int ith = params->ith;
13585
+ const int nth = params->nth;
13586
+
13587
+ const int64_t D = neq0;
13588
+ const int64_t N = neq1;
13589
+ const int64_t P = nek1 - N;
13590
+ const int64_t M = P + N;
13591
+
13592
+ const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
13593
+ const int mxDM = MAX(D, Mup);
13594
+
13595
+ // GGML_ASSERT(ne0 == D);
13596
+ // GGML_ASSERT(ne1 == N);
13597
+ GGML_ASSERT(P >= 0);
13598
+
13599
+ GGML_ASSERT(nbq0 == sizeof(float));
13600
+ GGML_ASSERT(nbk0 == sizeof(float));
13601
+ GGML_ASSERT(nbv0 == sizeof(float));
13602
+
13603
+ GGML_ASSERT(neq0 == D);
13604
+ GGML_ASSERT(nek0 == D);
13605
+ GGML_ASSERT(nev1 == D);
13606
+ GGML_ASSERT(ned0 == D);
13607
+
13608
+ GGML_ASSERT(neq1 == N);
13609
+ GGML_ASSERT(nek1 == N + P);
13610
+ GGML_ASSERT(nev1 == D);
13611
+ GGML_ASSERT(ned1 == N);
13612
+
13613
+ // dst cannot be transposed or permuted
13614
+ GGML_ASSERT(nb0 == sizeof(float));
13615
+ GGML_ASSERT(nb0 <= nb1);
13616
+ GGML_ASSERT(nb1 <= nb2);
13617
+ GGML_ASSERT(nb2 <= nb3);
13618
+
13619
+ if (params->type == GGML_TASK_INIT) {
13620
+ if (ith == 0) {
13621
+ memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
13622
+ }
13623
+ return;
13624
+ }
13625
+
13626
+ if (params->type == GGML_TASK_FINALIZE) {
13627
+ return;
13628
+ }
13629
+
13630
+ // parallelize by q rows using ggml_vec_dot_f32
13631
+
13632
+ // total rows in q
13633
+ const int nr = neq2*neq3;
13634
+
13635
+ // rows per thread
13636
+ const int dr = (nr + nth - 1)/nth;
13637
+
13638
+ // row range for this thread
13639
+ const int ir0 = dr*ith;
13640
+ const int ir1 = MIN(ir0 + dr, nr);
13641
+
13642
+ const float scale = 1.0f/sqrtf(D);
13643
+
13644
+ //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
13645
+
13646
+ for (int ir = ir0; ir < ir1; ++ir) {
13647
+ // q indices
13648
+ const int iq3 = ir/(neq2);
13649
+ const int iq2 = ir - iq3*neq2;
13650
+ for ( int iq1 = 0; iq1 < neq1; ++iq1) {
13651
+
13652
+
13653
+ // not sure about CACHE_LINE_SIZE_F32..
13654
+ // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
13655
+ float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
13656
+ float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
13657
+
13658
+ for (int i = M; i < Mup; ++i) {
13659
+ S[i] = -INFINITY;
13660
+ }
13661
+
13662
+ for (int64_t ic = 0; ic < nek1; ++ic) {
13663
+ // k indices
13664
+ const int ik3 = iq3;
13665
+ const int ik2 = iq2;
13666
+ const int ik1 = ic;
13667
+
13668
+ // S indices
13669
+ const int i1 = ik1;
13670
+
13671
+ ggml_vec_dot_f32(neq0,
13672
+ S + i1,
13673
+ (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
13674
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
13675
+ }
13676
+
13677
+ // scale
13678
+ ggml_vec_scale_f32(nek1, S, scale);
13679
+
13680
+ if (masked) {
13681
+ for (int64_t i = P; i < M; i++) {
13682
+ if (i > P + iq1) {
13683
+ S[i] = -INFINITY;
13684
+ }
13685
+ }
13686
+ }
13687
+
13688
+ // softmax
13689
+ {
13690
+ float max = -INFINITY;
13691
+ ggml_vec_max_f32(M, &max, S);
13692
+
13693
+ ggml_float sum = 0.0;
13694
+ {
13695
+ #ifdef GGML_SOFT_MAX_ACCELERATE
13696
+ max = -max;
13697
+ vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
13698
+ vvexpf(SM, SM, &Mup);
13699
+ ggml_vec_sum_f32(Mup, &sum, SM);
13700
+ #else
13701
+ uint16_t scvt[GGML_SOFT_MAX_UNROLL];
13702
+ ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
13703
+
13704
+ for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
13705
+ float * SR = S + i;
13706
+ float * SW = SM + i;
13707
+
13708
+ for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
13709
+ if (SR[j] == -INFINITY) {
13710
+ SW[j] = 0.0f;
13711
+ } else {
13712
+ ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
13713
+ memcpy(&scvt[j], &s, sizeof(uint16_t));
13714
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
13715
+ sump[j] += (ggml_float)val;
13716
+ SW[j] = val;
13717
+ }
13718
+ }
13719
+ }
13720
+
13721
+ for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
13722
+ sum += sump[i];
13723
+ }
13724
+ #endif
13725
+ }
13726
+
13727
+ assert(sum > 0.0);
13728
+
13729
+ sum = 1.0/sum;
13730
+ ggml_vec_scale_f32(M, SM, sum);
13731
+
13732
+ }
13733
+
13734
+ // step-by-step explanation
13735
+ {
13736
+ // forward-process shape grads from backward process
13737
+ // parallel_for iq2,iq3:
13738
+ // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur]
13739
+ // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
13740
+ // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur]
13741
+ // for iq1:
13742
+ // kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
13743
+ // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
13744
+ // vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
13745
+ // S0 = -Inf [D,1,1,1]
13746
+ // ~S1[i] = dot(kcur[:D,i], qcur)
13747
+ // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
13748
+ // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
13749
+ // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
13750
+ // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
13751
+ // ~S5[i] = dot(vcur[:,i], S4)
13752
+ // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
13753
+ // ~dst[i,iq1,iq2,iq3] = S5[i] ^
13754
+ // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
13755
+ // dst backward-/ grad[dst] = d
13756
+ //
13757
+ // output gradients with their dependencies:
13758
+ //
13759
+ // grad[kcur] = grad[S1].T @ qcur
13760
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
13761
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
13762
+ // grad[S4] = grad[S5] @ vcur
13763
+ // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
13764
+ // grad[qcur] = grad[S1] @ kcur
13765
+ // grad[vcur] = grad[S5].T @ S4
13766
+ // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
13767
+ //
13768
+ // in post-order:
13769
+ //
13770
+ // S1 = qcur @ kcur.T
13771
+ // S2 = S1 * scale
13772
+ // S3 = diag_mask_inf(S2, P)
13773
+ // S4 = softmax(S3)
13774
+ // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
13775
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
13776
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
13777
+ // grad[qcur] = grad[S1] @ kcur
13778
+ // grad[kcur] = grad[S1].T @ qcur
13779
+ // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
13780
+ //
13781
+ // using less variables (SM=S4):
13782
+ //
13783
+ // S = diag_mask_inf(qcur @ kcur.T * scale, P)
13784
+ // SM = softmax(S)
13785
+ // S = d[:D,iq1,iq2,iq3] @ vcur
13786
+ // dot_SM_gradSM = dot(SM, S)
13787
+ // S = SM * (S - dot(SM, S))
13788
+ // S = diag_mask_zero(S, P) * scale
13789
+ //
13790
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
13791
+ // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
13792
+ // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
13793
+ }
13794
+
13795
+ // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
13796
+ // S = d[:D,iq1,iq2,iq3] @ vcur
13797
+ // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
13798
+ ggml_vec_set_f32(M, S, 0);
13799
+ for (int64_t ic = 0; ic < D; ++ic) {
13800
+ // dst indices
13801
+ const int i1 = iq1;
13802
+ const int i2 = iq2;
13803
+ const int i3 = iq3;
13804
+
13805
+ ggml_vec_mad_f32(M,
13806
+ S,
13807
+ (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
13808
+ *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
13809
+ }
13810
+
13811
+ // S = SM * (S - dot(SM, S))
13812
+ float dot_SM_gradSM = 0;
13813
+ ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
13814
+ ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
13815
+ ggml_vec_mul_f32 (M, S, S, SM);
13816
+
13817
+ // S = diag_mask_zero(S, P) * scale
13818
+ if (masked) {
13819
+ // for (int64_t i = P + iq1 + 1; i < M; i++) {
13820
+ // S[i] = 0;
13821
+ // }
13822
+ for (int64_t i = P; i < M; i++) {
13823
+ if (i > P + iq1) {
13824
+ S[i] = 0;
13825
+ }
13826
+ }
13827
+ }
13828
+ ggml_vec_scale_f32(M, S, scale);
13829
+
13830
+ void * grad_q = (char *) dst->data;
13831
+ void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
13832
+ void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
13833
+
13834
+ const size_t nbgq1 = nb0*neq0;
13835
+ const size_t nbgq2 = nb0*neq0*neq1;
13836
+ const size_t nbgq3 = nb0*neq0*neq1*neq2;
13837
+
13838
+ const size_t nbgk1 = nb0*nek0;
13839
+ const size_t nbgk2 = nb0*nek0*nek1;
13840
+ const size_t nbgk3 = nb0*nek0*nek1*neq2;
13841
+
13842
+ const size_t nbgv1 = nb0*nev0;
13843
+ const size_t nbgv2 = nb0*nev0*nev1;
13844
+ const size_t nbgv3 = nb0*nev0*nev1*neq2;
13845
+
13846
+ // S shape [M,1]
13847
+ // SM shape [M,1]
13848
+ // kcur shape [D,M]
13849
+ // qcur shape [D,1]
13850
+ // vcur shape [M,D]
13851
+ //
13852
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
13853
+ // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
13854
+ // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
13855
+ //
13856
+ //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
13857
+ //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
13858
+ for (int64_t ic = 0; ic < M; ++ic) {
13859
+ // dst indices
13860
+ const int i1 = iq1;
13861
+ const int i2 = iq2;
13862
+ const int i3 = iq3;
13863
+
13864
+ ggml_vec_mad_f32(D,
13865
+ (float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
13866
+ (float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)),
13867
+ S[ic]);
13868
+ }
13869
+
13870
+ // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
13871
+ // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
13872
+ // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
13873
+ for (int64_t ic = 0; ic < M; ++ic) {
13874
+ // dst indices
13875
+ const int i1 = iq1;
13876
+ const int i2 = iq2;
13877
+ const int i3 = iq3;
13878
+
13879
+ // ggml_vec_set_f32(D,
13880
+ // (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
13881
+ // 0);
13882
+ ggml_vec_mad_f32(D,
13883
+ (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
13884
+ (float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)),
13885
+ S[ic]);
13886
+ }
13887
+
13888
+ // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
13889
+ // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
13890
+ // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M]
13891
+ for (int64_t ic = 0; ic < D; ++ic) {
13892
+ // dst indices
13893
+ const int i1 = iq1;
13894
+ const int i2 = iq2;
13895
+ const int i3 = iq3;
13896
+
13897
+ // ggml_vec_set_f32(M,
13898
+ // (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
13899
+ // 0);
13900
+ ggml_vec_mad_f32(M,
13901
+ (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
13902
+ SM,
13903
+ *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
13904
+ }
13905
+ }
13906
+ }
13907
+ }
13908
+
13909
+ static void ggml_compute_forward_flash_attn_back(
13910
+ const struct ggml_compute_params * params,
13911
+ const struct ggml_tensor * q,
13912
+ const struct ggml_tensor * k,
13913
+ const struct ggml_tensor * v,
13914
+ const struct ggml_tensor * d,
13915
+ const bool masked,
13916
+ struct ggml_tensor * dst) {
13917
+ switch (q->type) {
13918
+ case GGML_TYPE_F32:
13919
+ {
13920
+ ggml_compute_forward_flash_attn_back_f32(params, q, k, v, d, masked, dst);
13921
+ } break;
13922
+ default:
13923
+ {
13924
+ GGML_ASSERT(false);
13925
+ } break;
13926
+ }
13927
+ }
13928
+
13929
+ // ggml_compute_forward_map_unary
13930
+
13931
+ static void ggml_compute_forward_map_unary_f32(
13932
+ const struct ggml_compute_params * params,
13933
+ const struct ggml_tensor * src0,
13934
+ struct ggml_tensor * dst,
13935
+ const ggml_unary_op_f32_t fun) {
13936
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
13937
+
13938
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
13939
+ return;
13940
+ }
13941
+
13942
+ const int n = ggml_nrows(src0);
13943
+ const int nc = src0->ne[0];
13944
+
13945
+ assert( dst->nb[0] == sizeof(float));
13946
+ assert(src0->nb[0] == sizeof(float));
13947
+
13948
+ for (int i = 0; i < n; i++) {
13949
+ fun(nc,
13950
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
13951
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
13952
+ }
13953
+ }
13954
+
13955
+
13956
+ static void ggml_compute_forward_map_unary(
13957
+ const struct ggml_compute_params * params,
13958
+ const struct ggml_tensor * src0,
13959
+ struct ggml_tensor * dst,
13960
+ const ggml_unary_op_f32_t fun) {
13961
+ switch (src0->type) {
13962
+ case GGML_TYPE_F32:
13963
+ {
13964
+ ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
13965
+ } break;
13966
+ default:
13967
+ {
13968
+ GGML_ASSERT(false);
13969
+ } break;
13970
+ }
13971
+ }
13972
+
13973
+ // ggml_compute_forward_map_binary
13974
+
13975
+ static void ggml_compute_forward_map_binary_f32(
13976
+ const struct ggml_compute_params * params,
13977
+ const struct ggml_tensor * src0,
13978
+ const struct ggml_tensor * src1,
13979
+ struct ggml_tensor * dst,
13980
+ const ggml_binary_op_f32_t fun) {
13981
+ assert(params->ith == 0);
13982
+ assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
13983
+
13984
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
13985
+ return;
13986
+ }
13987
+
13988
+ const int n = ggml_nrows(src0);
13989
+ const int nc = src0->ne[0];
13990
+
13991
+ assert( dst->nb[0] == sizeof(float));
13992
+ assert(src0->nb[0] == sizeof(float));
13993
+ assert(src1->nb[0] == sizeof(float));
13994
+
13995
+ for (int i = 0; i < n; i++) {
13996
+ fun(nc,
13997
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
13998
+ (float *) ((char *) src0->data + i*(src0->nb[1])),
13999
+ (float *) ((char *) src1->data + i*(src1->nb[1])));
14000
+ }
14001
+ }
14002
+
14003
+
14004
+ static void ggml_compute_forward_map_binary(
14005
+ const struct ggml_compute_params * params,
14006
+ const struct ggml_tensor * src0,
14007
+ const struct ggml_tensor * src1,
14008
+ struct ggml_tensor * dst,
14009
+ const ggml_binary_op_f32_t fun) {
14010
+ switch (src0->type) {
14011
+ case GGML_TYPE_F32:
14012
+ {
14013
+ ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
14014
+ } break;
14015
+ default:
14016
+ {
14017
+ GGML_ASSERT(false);
14018
+ } break;
14019
+ }
14020
+ }
14021
+
14022
+ // ggml_compute_forward_cross_entropy_loss
14023
+
14024
+ static void ggml_compute_forward_cross_entropy_loss_f32(
14025
+ const struct ggml_compute_params * params,
14026
+ const struct ggml_tensor * src0,
14027
+ const struct ggml_tensor * src1,
14028
+ struct ggml_tensor * dst) {
14029
+ GGML_ASSERT(ggml_is_contiguous(src0));
14030
+ GGML_ASSERT(ggml_is_contiguous(src1));
14031
+ GGML_ASSERT(ggml_is_scalar(dst));
14032
+ GGML_ASSERT(ggml_are_same_shape(src0, src1));
14033
+
14034
+ const int ith = params->ith;
14035
+ const int nth = params->nth;
14036
+
14037
+ float * sums = (float *) params->wdata;
14038
+
14039
+ // TODO: handle transposed/permuted matrices
14040
+ const int nc = src0->ne[0];
14041
+ const int nr = ggml_nrows(src0);
14042
+
14043
+ if (params->type == GGML_TASK_INIT) {
14044
+ if (ith == 0) {
14045
+ memset(sums, 0, sizeof(float) * (nth + nth * nc));
14046
+ }
14047
+ return;
14048
+ }
14049
+
14050
+ if (params->type == GGML_TASK_FINALIZE) {
14051
+ if (ith == 0) {
14052
+ float * dp = (float *) dst->data;
14053
+ ggml_vec_sum_f32(nth, dp, sums);
14054
+ dp[0] *= -1.0f;
14055
+ }
14056
+ return;
14057
+ }
14058
+
14059
+ const double eps = 1e-9;
14060
+
14061
+ // rows per thread
14062
+ const int dr = (nr + nth - 1)/nth;
14063
+
14064
+ // row range for this thread
14065
+ const int ir0 = dr*ith;
14066
+ const int ir1 = MIN(ir0 + dr, nr);
14067
+
14068
+ for (int i1 = ir0; i1 < ir1; i1++) {
14069
+ float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
14070
+ float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
14071
+ float * st = (float *) params->wdata + nth + ith*nc;
14072
+
14073
+ #ifndef NDEBUG
14074
+ for (int i = 0; i < nc; ++i) {
14075
+ //printf("p[%d] = %f\n", i, p[i]);
14076
+ assert(!isnan(s0[i]));
14077
+ assert(!isnan(s1[i]));
14078
+ }
14079
+ #endif
14080
+ // soft_max
14081
+ ggml_float sum = 0.0;
14082
+ {
14083
+ float max = -INFINITY;
14084
+ ggml_vec_max_f32(nc, &max, s0);
14085
+
14086
+ uint16_t scvt;
14087
+ for (int i = 0; i < nc; i++) {
14088
+ if (s0[i] == -INFINITY) {
14089
+ st[i] = 0.0f;
14090
+ } else {
14091
+ // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
14092
+ ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
14093
+ memcpy(&scvt, &s, sizeof(scvt));
14094
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
14095
+ sum += (ggml_float)val;
14096
+ st[i] = val;
14097
+ }
14098
+ }
14099
+
14100
+ assert(sum > 0.0);
14101
+ // sum = 1.0/sum;
14102
+ }
14103
+ // avoid log(0) by rescaling from [0..1] to [eps..1]
14104
+ sum = (1.0 - eps) / sum;
14105
+ ggml_vec_scale_f32(nc, st, sum);
14106
+ ggml_vec_add1_f32(nc, st, st, eps);
14107
+ ggml_vec_log_f32(nc, st, st);
14108
+ ggml_vec_mul_f32(nc, st, st, s1);
14109
+
14110
+ ggml_vec_sum_f32(nc, sums + ith, st);
14111
+
14112
+ #ifndef NDEBUG
14113
+ for (int i = 0; i < nc; ++i) {
14114
+ assert(!isnan(st[i]));
14115
+ assert(!isinf(st[i]));
14116
+ }
14117
+ #endif
14118
+ }
14119
+
14120
+ }
14121
+
14122
+ static void ggml_compute_forward_cross_entropy_loss(
14123
+ const struct ggml_compute_params * params,
14124
+ const struct ggml_tensor * src0,
14125
+ const struct ggml_tensor * src1,
14126
+ struct ggml_tensor * dst) {
12973
14127
  switch (src0->type) {
12974
14128
  case GGML_TYPE_F32:
12975
14129
  {
12976
- ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
14130
+ ggml_compute_forward_cross_entropy_loss_f32(params, src0, src1, dst);
12977
14131
  } break;
12978
14132
  default:
12979
14133
  {
@@ -12982,47 +14136,160 @@ static void ggml_compute_forward_map_unary(
12982
14136
  }
12983
14137
  }
12984
14138
 
12985
- // ggml_compute_forward_map_binary
14139
+ // ggml_compute_forward_cross_entropy_loss_back
12986
14140
 
12987
- static void ggml_compute_forward_map_binary_f32(
14141
+ static void ggml_compute_forward_cross_entropy_loss_back_f32(
12988
14142
  const struct ggml_compute_params * params,
12989
14143
  const struct ggml_tensor * src0,
12990
14144
  const struct ggml_tensor * src1,
12991
- struct ggml_tensor * dst,
12992
- const ggml_binary_op_f32_t fun) {
12993
- assert(params->ith == 0);
12994
- assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
14145
+ const struct ggml_tensor * opt0,
14146
+ struct ggml_tensor * dst) {
14147
+ GGML_ASSERT(ggml_is_contiguous(dst));
14148
+ GGML_ASSERT(ggml_is_contiguous(src0));
14149
+ GGML_ASSERT(ggml_is_contiguous(src1));
14150
+ GGML_ASSERT(ggml_is_contiguous(opt0));
14151
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
14152
+
14153
+ const int64_t ith = params->ith;
14154
+ const int64_t nth = params->nth;
12995
14155
 
12996
14156
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
12997
14157
  return;
12998
14158
  }
12999
14159
 
13000
- const int n = ggml_nrows(src0);
13001
- const int nc = src0->ne[0];
14160
+ const float eps = 1e-9f;
13002
14161
 
13003
- assert( dst->nb[0] == sizeof(float));
13004
- assert(src0->nb[0] == sizeof(float));
13005
- assert(src1->nb[0] == sizeof(float));
14162
+ // TODO: handle transposed/permuted matrices
14163
+ const int64_t nc = src0->ne[0];
14164
+ const int64_t nr = ggml_nrows(src0);
13006
14165
 
13007
- for (int i = 0; i < n; i++) {
13008
- fun(nc,
13009
- (float *) ((char *) dst->data + i*( dst->nb[1])),
13010
- (float *) ((char *) src0->data + i*(src0->nb[1])),
13011
- (float *) ((char *) src1->data + i*(src1->nb[1])));
14166
+ // rows per thread
14167
+ const int64_t dr = (nr + nth - 1)/nth;
14168
+
14169
+ // row range for this thread
14170
+ const int64_t ir0 = dr*ith;
14171
+ const int64_t ir1 = MIN(ir0 + dr, nr);
14172
+
14173
+ float * d = (float *) opt0->data;
14174
+
14175
+ for (int64_t i1 = ir0; i1 < ir1; i1++) {
14176
+ float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
14177
+ float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
14178
+ float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
14179
+ float * sm = (float *) params->wdata + ith*nc;
14180
+
14181
+ #ifndef NDEBUG
14182
+ for (int i = 0; i < nc; ++i) {
14183
+ //printf("p[%d] = %f\n", i, p[i]);
14184
+ assert(!isnan(s0[i]));
14185
+ assert(!isnan(s1[i]));
14186
+ }
14187
+ #endif
14188
+ // step by step explanation:
14189
+ {
14190
+ //float * sums = (float *) params->wdata;
14191
+
14192
+ // forward pass with annotated gradients from backward pass
14193
+ // (built by going in reverse operation order, adding to gradients of current operation args)
14194
+ // st0 = exp(s0-max(s0)) grad[st0] = grad[st1]*(1.0 - eps)/sum
14195
+ // from softmax_back: grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
14196
+ // ggml_vec_scale_f32(nc, st, sum); // st1 = st0*/sum = softmax(s0) grad[st1] = grad[st2]*(1.0 - eps)
14197
+ // ggml_vec_scale_f32(nc, st, (1.0f - eps)); // st2 = st1*(1.0 - eps) grad[st2] = grad[st3]
14198
+ // ggml_vec_add1_f32(nc, st, st, eps); // st3 = st2 + eps grad[st3] = grad[st4]/st3
14199
+ // ggml_vec_log_f32(nc, st, st); // st4 = log(st3) grad[st4] = grad[st5] * s1
14200
+ // ggml_vec_mul_f32(nc, st, st, s1); // st5 = st4 * s1 grad[st5] = grad[sums[ith]]
14201
+ // ggml_vec_sum_f32(nc, sums + ith, st); // sums[ith] = st5 grad[sums[ith]] = grad[cross_entropy_loss] = -grad[cel]
14202
+
14203
+ // substitute into grad[st1], because we can reuse softmax_back from this point on
14204
+ // grad[st1] = -grad[cel]*s1*(1.0 - eps)/(eps + softmax(s0)*(1.0 - eps))
14205
+ // postorder:
14206
+ // grad[st1] := softmax(s0)
14207
+ // grad[st1] := grad[st1]*(1.0 - eps)
14208
+ // grad[st1] := grad[st1] + eps
14209
+ // grad[st1] := s1 / grad[st1]
14210
+ // grad[st1] := grad[st1]*(1.0-eps)*-grad[cel]
14211
+
14212
+ // src0 gradients by going through softmax_back
14213
+ // grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
14214
+ // from softmax_back:
14215
+ // dxk = yk * (dyk - dot(y, dy))
14216
+ // dot_y_dy := dot(y, dy)
14217
+ // dx := dy
14218
+ // dx := dx - dot_y_dy
14219
+ // dx := dx * y
14220
+ // postorder:
14221
+ // dot_st1_dst1 := dot(st1, grad[st1])
14222
+ // grad[s0] := grad[st1]
14223
+ // grad[s0] := grad[s0] - dot_st1_dst1
14224
+ // grad[s0] := grad[s0] * st1
14225
+
14226
+ // prepend postorder from grad[st1] directly using grad[s0] as memory location, as we will grad[s0] := grad[st1]
14227
+ // sm := softmax(s0)
14228
+ // grad[s0] := sm*(1.0 - eps)
14229
+ // grad[s0] := grad[s0] + eps
14230
+ // grad[s0] := s1 / grad[s0]
14231
+ // grad[s0] := grad[s0]*(1.0-eps)*-grad[cel]
14232
+ // dot_st1_dst1 := dot(sm, grad[s0])
14233
+ // grad[s0] := grad[s0] - dot_st1_dst1
14234
+ // grad[s0] := grad[s0] * sm
14235
+ }
14236
+
14237
+ // soft_max
14238
+ ggml_float sum = 0.0;
14239
+ {
14240
+ float max = -INFINITY;
14241
+ ggml_vec_max_f32(nc, &max, s0);
14242
+
14243
+ uint16_t scvt;
14244
+ for (int i = 0; i < nc; i++) {
14245
+ if (s0[i] == -INFINITY) {
14246
+ sm[i] = 0.0f;
14247
+ } else {
14248
+ // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
14249
+ ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
14250
+ memcpy(&scvt, &s, sizeof(scvt));
14251
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
14252
+ sum += (ggml_float)val;
14253
+ sm[i] = val;
14254
+ }
14255
+ }
14256
+
14257
+ assert(sum > 0.0);
14258
+ sum = 1.0/sum;
14259
+ }
14260
+
14261
+ float dot_st1_dst1 = 0;
14262
+ ggml_vec_scale_f32(nc, sm, sum);
14263
+ ggml_vec_cpy_f32 (nc, ds0, sm);
14264
+ ggml_vec_scale_f32(nc, ds0, (1.0f - eps));
14265
+ ggml_vec_add1_f32 (nc, ds0, ds0, eps);
14266
+ ggml_vec_div_f32 (nc, ds0, s1, ds0);
14267
+ ggml_vec_scale_f32(nc, ds0, -(1.0f - eps)*d[0]);
14268
+ ggml_vec_dot_f32 (nc, &dot_st1_dst1, sm, ds0);
14269
+ ggml_vec_acc1_f32 (nc, ds0, -dot_st1_dst1);
14270
+ ggml_vec_mul_f32 (nc, ds0, ds0, sm);
14271
+
14272
+ #ifndef NDEBUG
14273
+ for (int i = 0; i < nc; ++i) {
14274
+ assert(!isnan(sm[i]));
14275
+ assert(!isinf(sm[i]));
14276
+ assert(!isnan(ds0[i]));
14277
+ assert(!isinf(ds0[i]));
14278
+ }
14279
+ #endif
13012
14280
  }
13013
14281
  }
13014
14282
 
13015
-
13016
- static void ggml_compute_forward_map_binary(
14283
+ static void ggml_compute_forward_cross_entropy_loss_back(
13017
14284
  const struct ggml_compute_params * params,
13018
14285
  const struct ggml_tensor * src0,
13019
14286
  const struct ggml_tensor * src1,
13020
- struct ggml_tensor * dst,
13021
- const ggml_binary_op_f32_t fun) {
14287
+ const struct ggml_tensor * opt0,
14288
+ struct ggml_tensor * dst) {
13022
14289
  switch (src0->type) {
13023
14290
  case GGML_TYPE_F32:
13024
14291
  {
13025
- ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
14292
+ ggml_compute_forward_cross_entropy_loss_back_f32(params, src0, src1, opt0, dst);
13026
14293
  } break;
13027
14294
  default:
13028
14295
  {
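For reference, with eps = 0 the gradient assembled by the new kernel above reduces to the textbook softmax-cross-entropy form: for L = -sum_i s1_i*log(softmax(s0)_i), dL/ds0_k = softmax(s0)_k*sum_i(s1_i) - s1_k, scaled by the incoming gradient d[0]. A minimal single-row sketch in plain C, without the eps smoothing or threading of the kernel above (cross_entropy_grad_row is a hypothetical helper, not a ggml routine):

    #include <math.h>

    // Reference gradient of L = -sum_i s1[i]*log(softmax(s0)[i]) w.r.t. s0 for one row.
    // d is the upstream gradient of the scalar loss; ds0 receives the result.
    static void cross_entropy_grad_row(int nc, const float * s0, const float * s1,
                                       float d, float * ds0) {
        // numerically stable softmax: subtract the row maximum before exponentiating
        float max = -INFINITY;
        for (int i = 0; i < nc; ++i) if (s0[i] > max) max = s0[i];
        float sum = 0.0f;
        for (int i = 0; i < nc; ++i) { ds0[i] = expf(s0[i] - max); sum += ds0[i]; }
        float s1_sum = 0.0f;
        for (int i = 0; i < nc; ++i) s1_sum += s1[i];
        // dL/ds0_k = softmax(s0)_k * sum(s1) - s1_k, scaled by the upstream gradient
        for (int i = 0; i < nc; ++i) ds0[i] = d*(ds0[i]/sum*s1_sum - s1[i]);
    }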
@@ -13031,6 +14298,7 @@ static void ggml_compute_forward_map_binary(
13031
14298
  }
13032
14299
  }
13033
14300
 
14301
+
13034
14302
  /////////////////////////////////
13035
14303
 
13036
14304
  static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -13102,6 +14370,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13102
14370
  {
13103
14371
  ggml_compute_forward_repeat(params, tensor->src0, tensor);
13104
14372
  } break;
14373
+ case GGML_OP_REPEAT_BACK:
14374
+ {
14375
+ ggml_compute_forward_repeat_back(params, tensor->src0, tensor);
14376
+ } break;
13105
14377
  case GGML_OP_ABS:
13106
14378
  {
13107
14379
  ggml_compute_forward_abs(params, tensor->src0, tensor);
@@ -13150,6 +14422,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13150
14422
  {
13151
14423
  ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
13152
14424
  } break;
14425
+ case GGML_OP_OUT_PROD:
14426
+ {
14427
+ ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor);
14428
+ } break;
13153
14429
  case GGML_OP_SCALE:
13154
14430
  {
13155
14431
  ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
@@ -13206,6 +14482,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13206
14482
  {
13207
14483
  ggml_compute_forward_soft_max(params, tensor->src0, tensor);
13208
14484
  } break;
14485
+ case GGML_OP_SOFT_MAX_BACK:
14486
+ {
14487
+ ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
14488
+ } break;
13209
14489
  case GGML_OP_ROPE:
13210
14490
  {
13211
14491
  ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
@@ -13241,6 +14521,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13241
14521
  {
13242
14522
  ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
13243
14523
  } break;
14524
+ case GGML_OP_FLASH_ATTN_BACK:
14525
+ {
14526
+ int32_t t = ggml_get_i32_1d(tensor->opt[2], 0);
14527
+ GGML_ASSERT(t == 0 || t == 1);
14528
+ bool masked = t != 0;
14529
+ ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor);
14530
+ } break;
13244
14531
  case GGML_OP_MAP_UNARY:
13245
14532
  {
13246
14533
  const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
@@ -13253,6 +14540,16 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13253
14540
  ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
13254
14541
  }
13255
14542
  break;
14543
+ case GGML_OP_CROSS_ENTROPY_LOSS:
14544
+ {
14545
+ ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor);
14546
+ }
14547
+ break;
14548
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
14549
+ {
14550
+ ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
14551
+ }
14552
+ break;
13256
14553
  case GGML_OP_NONE:
13257
14554
  {
13258
14555
  // nop
@@ -13391,11 +14688,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13391
14688
  src0->grad =
13392
14689
  ggml_add_impl(ctx,
13393
14690
  src0->grad,
13394
- ggml_mul(ctx,
13395
- tensor->grad, // this was not caught by test_grad because in test_grad tensor->grad is 1
14691
+ ggml_scale(ctx,
13396
14692
  ggml_div(ctx,
13397
- ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor),
13398
- tensor)),
14693
+ tensor->grad,
14694
+ tensor),
14695
+ ggml_new_f32(ctx, 0.5f)),
13399
14696
  inplace);
13400
14697
  }
13401
14698
  } break;
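The rewritten SQRT backward above follows directly from y = sqrt(x) => dy/dx = 1/(2*sqrt(x)) = 0.5/y, so dx = 0.5*dy/y, which is exactly scale(div(tensor->grad, tensor), 0.5). Element-wise sketch (hypothetical helper, illustration only):

    // dx[i] += 0.5*dy[i]/y[i] -- the same quantity the graph ops above build
    static void sqrt_backward_row(int n, const float * y, const float * dy, float * dx) {
        for (int i = 0; i < n; ++i) {
            dx[i] += 0.5f*dy[i]/y[i];
        }
    }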
@@ -13441,43 +14738,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13441
14738
  {
13442
14739
  // necessary for llama
13443
14740
  if (src0->grad) {
13444
- GGML_ASSERT(src0->n_dims == 1 || src0->n_dims == 2);
13445
- const int nc = tensor->ne[0];
13446
- const int nr = tensor->ne[1];
13447
- const int nc0 = src0->ne[0];
13448
- const int nr0 = src0->ne[1];
13449
- const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
13450
- const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
13451
- // tensor->grad [nc,nr,1,1]
13452
- // reshape [nc0,nc/nc0,nr0,nr/nr0]
13453
- // permute [nc0,nr0,nc/nc0,nr/nr0]
13454
- // substitute [nc0,nr0,ncr,nrr]
13455
- // reshape [nc0*nr0,ncr*nrr,1,1]
13456
- // transpose [ncr*nrr,nc0*nr0,1,1]
13457
- // sum rows [1,nc0*nr0,1,1]
13458
- // transpose [nc0*nr0,1,1]
13459
- // reshape [nc0,nr0,1,1] reshape_1d or reshape_2d
13460
- // add to src0->grad
13461
-
13462
- int64_t ne[4] = {nc0,ncr,nr0,nrr};
13463
-
13464
- struct ggml_tensor* F00 = tensor->grad;
13465
- struct ggml_tensor* F01 = ggml_reshape (ctx, F00, ggml_new_tensor(ctx,tensor->grad->type,4,ne));
13466
- struct ggml_tensor* F02 = ggml_permute (ctx, F01, 0,2,1,3);
13467
- struct ggml_tensor* F03 = ggml_cont (ctx, F02);
13468
- struct ggml_tensor* F04 = ggml_reshape_2d(ctx, F03, nc0*nr0, ncr*nrr);
13469
- struct ggml_tensor* F05 = ggml_transpose (ctx, F04);
13470
- struct ggml_tensor* F06 = ggml_cont (ctx, F05);
13471
- struct ggml_tensor* F07 = ggml_sum_rows (ctx, F06);
13472
- struct ggml_tensor* F08 = ggml_transpose (ctx, F07);
13473
- struct ggml_tensor* F09 = ggml_cont (ctx, F08);
13474
- struct ggml_tensor* F10 = ggml_reshape (ctx, F09, src0->grad);
13475
-
13476
- src0->grad =
13477
- ggml_add_impl(ctx,
13478
- src0->grad,
13479
- F10,
13480
- inplace);
14741
+ src0->grad = ggml_add_impl(ctx,
14742
+ src0->grad,
14743
+ ggml_repeat_back(ctx, tensor->grad, src0->grad),
14744
+ inplace);
14745
+ }
14746
+ } break;
14747
+ case GGML_OP_REPEAT_BACK:
14748
+ {
14749
+ if (src0->grad) {
14750
+ // TODO: test this
14751
+ src0->grad = ggml_add_impl(ctx,
14752
+ src0->grad,
14753
+ ggml_repeat(ctx, tensor->grad, src0->grad),
14754
+ inplace);
13481
14755
  }
13482
14756
  } break;
13483
14757
  case GGML_OP_ABS:
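REPEAT and the new REPEAT_BACK form an adjoint pair: the forward broadcasts the source over the larger shape, and the backward sums the incoming gradient over all copies back into the source shape, which is why the manual reshape/permute/sum chain removed in the hunk above could be replaced by a single ggml_repeat_back. A 1-D sketch of the pair (hypothetical helpers):

    // forward: dst has n*k elements, dst[i] = src[i % n]
    static void repeat_1d(int n, int k, const float * src, float * dst) {
        for (int i = 0; i < n*k; ++i) dst[i] = src[i % n];
    }

    // backward of repeat: every copy contributes its gradient to the same source element
    static void repeat_back_1d(int n, int k, const float * d_dst, float * d_src) {
        for (int j = 0; j < n; ++j) d_src[j] = 0.0f;
        for (int i = 0; i < n*k; ++i) d_src[i % n] += d_dst[i];
    }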
@@ -13584,38 +14858,37 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13584
14858
 
13585
14859
  // necessary for llama
13586
14860
  if (src0->grad) {
13587
- // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
13588
14861
  src0->grad =
13589
14862
  ggml_add_impl(ctx,
13590
14863
  src0->grad,
13591
- // ds0 = dt.dot(s1.T)
13592
- // ggml_out_prod(ctx, // [n,m]
13593
- // src1, // [n,p]
13594
- // tensor->grad), // [m,p]
13595
- // for now just using A*B==(B.T*A.T).T
13596
- ggml_cont(ctx, // [n,m]
13597
- ggml_transpose(ctx, // [n,m]
13598
- ggml_mul_mat(ctx, // [m,n]
13599
- ggml_cont(ctx, // [p,m]
13600
- ggml_transpose(ctx, // [p,m]
13601
- tensor->grad)), // [m,p]
13602
- ggml_cont(ctx, // [p,n]
13603
- ggml_transpose(ctx, // [p,n]
13604
- src1))))), // [n,p]
14864
+ ggml_out_prod(ctx, // [n,m]
14865
+ src1, // [n,p]
14866
+ tensor->grad), // [m,p]
13605
14867
  inplace);
13606
14868
  }
13607
14869
  if (src1->grad) {
13608
14870
  src1->grad =
13609
14871
  ggml_add_impl(ctx,
13610
14872
  src1->grad,
13611
- // ds1 = s0.T.dot(dt):
13612
- ggml_mul_mat(ctx, // [n,p]
13613
- ggml_cont(ctx, // [m,n]
13614
- ggml_transpose(ctx, src0)), // [m,n]
13615
- tensor->grad), // [m,p]
14873
+ // ggml_mul_mat(ctx, // [n,p]
14874
+ // ggml_cont(ctx, // [m,n]
14875
+ // ggml_transpose(ctx, src0)), // [m,n]
14876
+ // tensor->grad), // [m,p]
14877
+
14878
+ // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
14879
+ // // avoid transpose of src0, rather transpose smaller tensor->grad
14880
+ // // and then use ggml_out_prod
14881
+ ggml_out_prod(ctx, // [n,p]
14882
+ src0, // [n,m]
14883
+ ggml_transpose(ctx, // [p,m]
14884
+ tensor->grad)), // [m,p]
13616
14885
  inplace);
13617
14886
  }
13618
14887
  } break;
14888
+ case GGML_OP_OUT_PROD:
14889
+ {
14890
+ GGML_ASSERT(false); // TODO: not implemented
14891
+ } break;
13619
14892
  case GGML_OP_SCALE:
13620
14893
  {
13621
14894
  // necessary for llama
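The MUL_MAT backward rewritten in the hunk above relies on the standard matrix-calculus identities: for C = A*B, dA = dC*B^T and dB = A^T*dC. GGML_OP_OUT_PROD forms these products as sums of rank-1 (outer-product) updates without materializing contiguous transposed copies, which is what the removed cont/transpose chain did. A row-major reference sketch of dA = dC*B^T, using conventional indices purely for illustration (matmul_grad_A is a hypothetical helper, not ggml's layout-aware kernel):

    // C = A*B with A: m x k, B: k x n, C: m x n (row-major)
    // dA[i][j] = sum_t dC[i][t]*B[j][t], i.e. a sum of outer products over t
    static void matmul_grad_A(int m, int k, int n,
                              const float * dC, const float * B, float * dA) {
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < k; ++j) {
                float acc = 0.0f;
                for (int t = 0; t < n; ++t) acc += dC[i*n + t]*B[j*n + t];
                dA[i*k + j] += acc;
            }
        }
    }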
@@ -13717,7 +14990,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13717
14990
  // necessary for llama
13718
14991
  if (src0->grad) {
13719
14992
  size_t offset;
13720
- memcpy(&offset, tensor->padding, sizeof(offset));
14993
+
14994
+ GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->opt[0]));
14995
+ memcpy(&offset, tensor->opt[0]->data, sizeof(offset));
13721
14996
 
13722
14997
  size_t nb1 = tensor->nb[1];
13723
14998
  size_t nb2 = tensor->nb[2];
@@ -13744,10 +15019,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13744
15019
  {
13745
15020
  // necessary for llama
13746
15021
  if (src0->grad) {
13747
- int axis0 = tensor->padding[0] & 0x3;
13748
- int axis1 = tensor->padding[1] & 0x3;
13749
- int axis2 = tensor->padding[2] & 0x3;
13750
- int axis3 = tensor->padding[3] & 0x3;
15022
+ int32_t * axes = (int32_t *) tensor->opt[0]->data;
15023
+ int axis0 = axes[0] & 0x3;
15024
+ int axis1 = axes[1] & 0x3;
15025
+ int axis2 = axes[2] & 0x3;
15026
+ int axis3 = axes[3] & 0x3;
13751
15027
  int axes_backward[4] = {0,0,0,0};
13752
15028
  axes_backward[axis0] = 0;
13753
15029
  axes_backward[axis1] = 1;
@@ -13831,50 +15107,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13831
15107
  {
13832
15108
  // necessary for llama
13833
15109
  if (src0->grad) {
13834
- // y = softmax(x)
13835
- //
13836
- // Jii = yi - yi*yi
13837
- // Jij = -yi*yj
13838
- // J = diag(y)-y.*y
13839
- // dx = J * dy
13840
- // dxk = sum(Jkj * dyk)
13841
-
13842
- int64_t ne2[4] = {
13843
- tensor->ne[0],
13844
- 1,
13845
- tensor->ne[1]*tensor->ne[2],
13846
- tensor->ne[3]
13847
- };
13848
- struct ggml_tensor * tensor2 = ggml_cont(ctx,
13849
- ggml_reshape_4d(ctx,
13850
- ggml_cont(ctx, tensor),
13851
- ne2[0], ne2[1], ne2[2], ne2[3]));
13852
-
13853
- struct ggml_tensor * grad2 = ggml_cont(ctx,
13854
- ggml_reshape_4d(ctx,
13855
- ggml_cont(ctx, tensor->grad),
13856
- ne2[0], ne2[1], ne2[2], ne2[3]));
13857
-
13858
- struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
13859
- ggml_permute(ctx, // [1,ne0,ne1*ne2,ne3]
13860
- tensor2, // [ne0,1,ne1*ne2,ne3]
13861
- 1, 0, 2, 3));
13862
-
13863
15110
  src0->grad =
13864
- ggml_add_impl(ctx,
13865
- src0->grad, // [ne0,ne1,ne2,ne3]
13866
- ggml_reshape(ctx, // [ne0,ne1,ne2,ne3]
13867
- ggml_mul_mat(ctx, // [ne0,1,ne1*ne2,ne3]
13868
- ggml_sub(ctx, // [ne0,ne0,ne1*ne2,ne3]
13869
- ggml_diag(ctx, // [ne0,ne0,ne1*ne2,ne3]
13870
- tensor2), // [ne0,1,ne1*ne2,ne3]
13871
- ggml_mul_mat(ctx, // [ne0,ne0,ne1*ne2,ne3]
13872
- tensor2_t, // [1,ne0,ne1*ne2,ne3]
13873
- tensor2_t)), // [1,ne0,ne1*ne2,ne3]
13874
- grad2), // [ne0,1,ne1*ne2,ne3]
13875
- src0->grad),
13876
- inplace);
15111
+ ggml_add_impl(ctx, src0->grad,
15112
+ ggml_soft_max_back(ctx, tensor->grad, tensor),
15113
+ inplace);
13877
15114
  }
15115
+
15116
+ } break;
15117
+ case GGML_OP_SOFT_MAX_BACK:
15118
+ {
15119
+ GGML_ASSERT(false); // TODO: not implemented
13878
15120
  } break;
13879
15121
  case GGML_OP_ROPE:
13880
15122
  {
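The removed block above built the full softmax Jacobian J = diag(y) - y*y^T explicitly; ggml_soft_max_back instead uses the contracted form dx_k = y_k*(dy_k - dot(y, dy)), which needs only one dot product per row. Single-row sketch (hypothetical helper):

    // y = softmax(x) for one row; dy is the upstream gradient, dx the result
    static void soft_max_back_row(int n, const float * y, const float * dy, float * dx) {
        float dot = 0.0f;
        for (int i = 0; i < n; ++i) dot += y[i]*dy[i];
        for (int i = 0; i < n; ++i) dx[i] = y[i]*(dy[i] - dot);
    }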
@@ -13929,17 +15171,190 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13929
15171
  } break;
13930
15172
  case GGML_OP_FLASH_ATTN:
13931
15173
  {
13932
- GGML_ASSERT(false); // not supported
15174
+ struct ggml_tensor * flash_grad = NULL;
15175
+ if (src0->grad || src1->grad || tensor->opt[0]->grad) {
15176
+ int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
15177
+ GGML_ASSERT(t == 0 || t == 1);
15178
+ bool masked = t != 0;
15179
+ flash_grad =
15180
+ ggml_flash_attn_back(ctx,
15181
+ src0,
15182
+ src1,
15183
+ tensor->opt[0],
15184
+ tensor->grad,
15185
+ masked);
15186
+ }
15187
+
15188
+ if (src0->grad) {
15189
+ struct ggml_tensor * grad_q = NULL;
15190
+ const size_t nb0 = flash_grad->nb[0];
15191
+ const size_t offset = 0;
15192
+ switch(src0->n_dims) {
15193
+ case 2:
15194
+ {
15195
+ grad_q = ggml_view_2d(ctx,
15196
+ flash_grad,
15197
+ src0->ne[0],
15198
+ src0->ne[1],
15199
+ nb0*src0->ne[0],
15200
+ offset);
15201
+ } break;
15202
+ case 3:
15203
+ {
15204
+ grad_q = ggml_view_3d(ctx,
15205
+ flash_grad,
15206
+ src0->ne[0],
15207
+ src0->ne[1],
15208
+ src0->ne[2],
15209
+ nb0*src0->ne[0],
15210
+ nb0*src0->ne[0]*src0->ne[1],
15211
+ offset);
15212
+ } break;
15213
+ case 4:
15214
+ {
15215
+ grad_q = ggml_view_4d(ctx,
15216
+ flash_grad,
15217
+ src0->ne[0],
15218
+ src0->ne[1],
15219
+ src0->ne[2],
15220
+ src0->ne[3],
15221
+ nb0*src0->ne[0],
15222
+ nb0*src0->ne[0]*src0->ne[1],
15223
+ nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
15224
+ offset);
15225
+ } break;
15226
+ }
15227
+
15228
+ src0->grad = ggml_add_impl(ctx,
15229
+ src0->grad,
15230
+ grad_q,
15231
+ inplace);
15232
+ }
15233
+
15234
+ if (src1->grad) {
15235
+ struct ggml_tensor * grad_k = NULL;
15236
+ const size_t nb0 = flash_grad->nb[0];
15237
+ const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
15238
+ switch(src1->n_dims) {
15239
+ case 2:
15240
+ {
15241
+ grad_k = ggml_view_2d(ctx,
15242
+ flash_grad,
15243
+ src1->ne[0],
15244
+ src1->ne[1],
15245
+ nb0*src1->ne[0],
15246
+ offset);
15247
+ } break;
15248
+ case 3:
15249
+ {
15250
+ grad_k = ggml_view_3d(ctx,
15251
+ flash_grad,
15252
+ src1->ne[0],
15253
+ src1->ne[1],
15254
+ src1->ne[2],
15255
+ nb0*src1->ne[0],
15256
+ nb0*src1->ne[0]*src1->ne[1],
15257
+ offset);
15258
+ } break;
15259
+ case 4:
15260
+ {
15261
+ grad_k = ggml_view_4d(ctx,
15262
+ flash_grad,
15263
+ src1->ne[0],
15264
+ src1->ne[1],
15265
+ src1->ne[2],
15266
+ src1->ne[3],
15267
+ nb0*src1->ne[0],
15268
+ nb0*src1->ne[0]*src1->ne[1],
15269
+ nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
15270
+ offset);
15271
+ } break;
15272
+ }
15273
+
15274
+ src1->grad = ggml_add_impl(ctx,
15275
+ src1->grad,
15276
+ grad_k,
15277
+ inplace);
15278
+ }
15279
+
15280
+ struct ggml_tensor * opt0 = tensor->opt[0];
15281
+
15282
+ if (opt0->grad) {
15283
+ struct ggml_tensor * grad_v = NULL;
15284
+ const size_t nb0 = flash_grad->nb[0];
15285
+ const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
15286
+ + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
15287
+ switch(opt0->n_dims) {
15288
+ case 2:
15289
+ {
15290
+ grad_v = ggml_view_2d(ctx,
15291
+ flash_grad,
15292
+ opt0->ne[0],
15293
+ opt0->ne[1],
15294
+ nb0*opt0->ne[0],
15295
+ offset);
15296
+ } break;
15297
+ case 3:
15298
+ {
15299
+ grad_v = ggml_view_3d(ctx,
15300
+ flash_grad,
15301
+ opt0->ne[0],
15302
+ opt0->ne[1],
15303
+ opt0->ne[2],
15304
+ nb0*opt0->ne[0],
15305
+ nb0*opt0->ne[0]*opt0->ne[1],
15306
+ offset);
15307
+ } break;
15308
+ case 4:
15309
+ {
15310
+ grad_v = ggml_view_4d(ctx,
15311
+ flash_grad,
15312
+ opt0->ne[0],
15313
+ opt0->ne[1],
15314
+ opt0->ne[2],
15315
+ opt0->ne[3],
15316
+ nb0*opt0->ne[0],
15317
+ nb0*opt0->ne[0]*opt0->ne[1],
15318
+ nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
15319
+ offset);
15320
+ } break;
15321
+ }
15322
+
15323
+ opt0->grad = ggml_add_impl(ctx,
15324
+ opt0->grad,
15325
+ grad_v,
15326
+ inplace);
15327
+ }
13933
15328
  } break;
13934
15329
  case GGML_OP_FLASH_FF:
13935
15330
  {
13936
15331
  GGML_ASSERT(false); // not supported
13937
15332
  } break;
15333
+ case GGML_OP_FLASH_ATTN_BACK:
15334
+ {
15335
+ GGML_ASSERT(false); // not supported
15336
+ } break;
13938
15337
  case GGML_OP_MAP_UNARY:
13939
15338
  case GGML_OP_MAP_BINARY:
13940
15339
  {
13941
15340
  GGML_ASSERT(false); // not supported
13942
15341
  } break;
15342
+ case GGML_OP_CROSS_ENTROPY_LOSS:
15343
+ {
15344
+ if (src0->grad) {
15345
+ src0->grad = ggml_add_impl(ctx,
15346
+ src0->grad,
15347
+ ggml_cross_entropy_loss_back(ctx,
15348
+ src0,
15349
+ src1,
15350
+ tensor->grad),
15351
+ inplace);
15352
+ }
15353
+ } break;
15354
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
15355
+ {
15356
+ GGML_ASSERT(false); // not supported
15357
+ } break;
13943
15358
  case GGML_OP_NONE:
13944
15359
  {
13945
15360
  // nop
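The FLASH_ATTN backward above slices one packed result from ggml_flash_attn_back into the three gradients, laid out back to back as [grad_q | grad_k | grad_v]; the view offsets are just running element counts times the element size nb0. A sketch of that arithmetic (hypothetical helper; nelem_q and nelem_k stand for ne[0]*ne[1]*ne[2]*ne[3] of q and k):

    #include <stddef.h>

    static void flash_grad_offsets(size_t nb0, size_t nelem_q, size_t nelem_k,
                                   size_t * off_q, size_t * off_k, size_t * off_v) {
        *off_q = 0;                         // grad_q starts at the beginning
        *off_k = nb0*nelem_q;               // grad_k follows all of grad_q
        *off_v = nb0*(nelem_q + nelem_k);   // grad_v follows grad_q and grad_k
    }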
@@ -14316,6 +15731,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14316
15731
  case GGML_OP_SUM_ROWS:
14317
15732
  case GGML_OP_MEAN:
14318
15733
  case GGML_OP_REPEAT:
15734
+ case GGML_OP_REPEAT_BACK:
14319
15735
  case GGML_OP_ABS:
14320
15736
  case GGML_OP_SGN:
14321
15737
  case GGML_OP_NEG:
@@ -14335,6 +15751,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14335
15751
  node->n_tasks = n_threads;
14336
15752
  } break;
14337
15753
  case GGML_OP_MUL_MAT:
15754
+ case GGML_OP_OUT_PROD:
14338
15755
  {
14339
15756
  node->n_tasks = n_threads;
14340
15757
 
@@ -14417,6 +15834,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14417
15834
  } break;
14418
15835
  case GGML_OP_DIAG_MASK_INF:
14419
15836
  case GGML_OP_SOFT_MAX:
15837
+ case GGML_OP_SOFT_MAX_BACK:
14420
15838
  case GGML_OP_ROPE:
14421
15839
  case GGML_OP_ROPE_BACK:
14422
15840
  {
@@ -14496,6 +15914,27 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14496
15914
  cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
14497
15915
  }
14498
15916
 
15917
+ work_size = MAX(work_size, cur);
15918
+ } break;
15919
+ case GGML_OP_FLASH_ATTN_BACK:
15920
+ {
15921
+ node->n_tasks = n_threads;
15922
+
15923
+ size_t cur = 0;
15924
+
15925
+ const int64_t D = node->src0->ne[0];
15926
+ const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
15927
+ const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
15928
+ if (node->src1->type == GGML_TYPE_F32) {
15929
+ cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
15930
+ cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
15931
+ }
15932
+
15933
+ if (node->src1->type == GGML_TYPE_F16) {
15934
+ cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
15935
+ cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
15936
+ }
15937
+
14499
15938
  work_size = MAX(work_size, cur);
14500
15939
  } break;
14501
15940
  case GGML_OP_MAP_UNARY:
@@ -14503,6 +15942,22 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14503
15942
  {
14504
15943
  node->n_tasks = 1;
14505
15944
  } break;
15945
+ case GGML_OP_CROSS_ENTROPY_LOSS:
15946
+ {
15947
+ node->n_tasks = n_threads;
15948
+
15949
+ size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
15950
+
15951
+ work_size = MAX(work_size, cur);
15952
+ } break;
15953
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
15954
+ {
15955
+ node->n_tasks = n_threads;
15956
+
15957
+ size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
15958
+
15959
+ work_size = MAX(work_size, cur);
15960
+ } break;
14506
15961
  case GGML_OP_NONE:
14507
15962
  {
14508
15963
  node->n_tasks = 1;
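The CROSS_ENTROPY_LOSS work size added above appears to reserve one float per thread for the partial loss sums plus one scratch row of ne[0] floats per thread, matching the wdata usage in the forward kernel. Sketch of that arithmetic (an assumption for illustration, not a library function):

    #include <stddef.h>
    #include <stdint.h>

    static size_t cross_entropy_work_size(size_t float_size, int n_tasks, int64_t ne0) {
        return float_size*((size_t) n_tasks + (size_t) ne0*(size_t) n_tasks);
    }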
@@ -15478,6 +16933,7 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
15478
16933
 
15479
16934
  static enum ggml_opt_result ggml_opt_adam(
15480
16935
  struct ggml_context * ctx,
16936
+ struct ggml_opt_context * opt,
15481
16937
  struct ggml_opt_params params,
15482
16938
  struct ggml_tensor * f,
15483
16939
  struct ggml_cgraph * gf,
@@ -15503,25 +16959,29 @@ static enum ggml_opt_result ggml_opt_adam(
15503
16959
  }
15504
16960
  }
15505
16961
 
16962
+ if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) {
16963
+ int iter = opt->iter;
16964
+ ggml_opt_init(opt->ctx, opt, params, nx);
16965
+ opt->iter = iter;
16966
+ }
16967
+
15506
16968
  // constants
15507
- const float alpha = params.adam.alpha;
16969
+ const float sched = params.adam.sched;
16970
+ const float decay = params.adam.decay * sched;
16971
+ const float alpha = params.adam.alpha * sched;
15508
16972
  const float beta1 = params.adam.beta1;
15509
16973
  const float beta2 = params.adam.beta2;
15510
16974
  const float eps = params.adam.eps;
15511
16975
 
15512
- float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // view of the parameters
15513
- float * g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient
15514
- float * g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient squared
15515
- float * m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment
15516
- float * v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment
15517
- float * mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment hat
15518
- float * vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment hat
16976
+ float * x = opt->adam.x->data; // view of the parameters
16977
+ float * g1 = opt->adam.g1->data; // gradient
16978
+ float * g2 = opt->adam.g2->data; // gradient squared
16979
+ float * m = opt->adam.m->data; // first moment
16980
+ float * v = opt->adam.v->data; // second moment
16981
+ float * mh = opt->adam.mh->data; // first moment hat
16982
+ float * vh = opt->adam.vh->data; // second moment hat
15519
16983
 
15520
- float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
15521
-
15522
- // initialize
15523
- ggml_vec_set_f32(nx, m, 0.0f);
15524
- ggml_vec_set_f32(nx, v, 0.0f);
16984
+ float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
15525
16985
 
15526
16986
  // update view
15527
16987
  ggml_opt_get_params(np, ps, x);
@@ -15531,16 +16991,27 @@ static enum ggml_opt_result ggml_opt_adam(
15531
16991
  ggml_set_f32 (f->grad, 1.0f);
15532
16992
  ggml_graph_compute(ctx, gb);
15533
16993
 
15534
- float fx_prev = ggml_get_f32_1d(f, 0);
16994
+ opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
16995
+ opt->adam.fx_best = opt->adam.fx_prev;
15535
16996
  if (pf) {
15536
- pf[0] = fx_prev;
16997
+ pf[opt->iter % params.past] = opt->adam.fx_prev;
15537
16998
  }
15538
16999
 
15539
- int n_no_improvement = 0;
15540
- float fx_best = fx_prev;
17000
+ // initialize
17001
+ if (opt->just_initialized) {
17002
+ opt->adam.n_no_improvement = 0;
17003
+ opt->just_initialized = false;
17004
+ }
17005
+
17006
+ float * fx_best = &opt->adam.fx_best;
17007
+ float * fx_prev = &opt->adam.fx_prev;
17008
+ int * n_no_improvement = &opt->adam.n_no_improvement;
17009
+
17010
+ int iter0 = opt->iter;
15541
17011
 
15542
17012
  // run the optimizer
15543
17013
  for (int t = 0; t < params.adam.n_iter; ++t) {
17014
+ opt->iter = iter0 + t + 1;
15544
17015
  GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
15545
17016
 
15546
17017
  GGML_PRINT_DEBUG ("f = %10.6f\n", ggml_get_f32_1d(f, 0));
@@ -15574,17 +17045,22 @@ static enum ggml_opt_result ggml_opt_adam(
15574
17045
 
15575
17046
  // m^hat = m_t / (1 - beta1^t)
15576
17047
  // v^hat = v_t / (1 - beta2^t)
15577
- // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps)
17048
+ // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1)
17049
+ // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1
17050
+ // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
17051
+ // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps)
17052
+ // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay)
15578
17053
  ggml_vec_cpy_f32 (nx, mh, m);
15579
17054
  ggml_vec_cpy_f32 (nx, vh, v);
15580
17055
 
15581
- ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, t + 1)));
15582
- ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, t + 1)));
17056
+ ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter)));
17057
+ ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, opt->iter)));
15583
17058
 
15584
17059
  ggml_vec_sqrt_f32 (nx, vh, vh);
15585
17060
  ggml_vec_acc1_f32 (nx, vh, eps);
15586
17061
 
15587
17062
  ggml_vec_div_f32 (nx, mh, mh, vh);
17063
+ ggml_vec_scale_f32(nx, x, 1.0f - decay);
15588
17064
  ggml_vec_sub_f32 (nx, x, x, mh);
15589
17065
 
15590
17066
  // update the parameters
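The derivation in the comments above folds the new weight decay into a single fused update: with decay = params.adam.decay*sched and alpha = params.adam.alpha*sched (as set in the constants block of this function), the parameters are first shrunk by (1 - decay) and then moved by the bias-corrected Adam step. Element-wise sketch (hypothetical helper; step[i] already equals alpha*mhat[i]/(sqrt(vhat[i]) + eps), as computed by the vector ops above):

    static void adam_decay_update(int n, float * x, const float * step, float decay) {
        for (int i = 0; i < n; ++i) {
            x[i] = x[i]*(1.0f - decay) - step[i];
        }
    }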
@@ -15598,7 +17074,7 @@ static enum ggml_opt_result ggml_opt_adam(
15598
17074
  const float fx = ggml_get_f32_1d(f, 0);
15599
17075
 
15600
17076
  // check convergence
15601
- if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) {
17077
+ if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
15602
17078
  GGML_PRINT_DEBUG("converged\n");
15603
17079
 
15604
17080
  return GGML_OPT_OK;
@@ -15607,32 +17083,32 @@ static enum ggml_opt_result ggml_opt_adam(
15607
17083
  // delta-based convergence test
15608
17084
  if (pf != NULL) {
15609
17085
  // need at least params.past iterations to start checking for convergence
15610
- if (params.past <= t) {
15611
- const float rate = (pf[t%params.past] - fx)/fx;
17086
+ if (params.past <= iter0 + t) {
17087
+ const float rate = (pf[(iter0 + t)%params.past] - fx)/fx;
15612
17088
 
15613
17089
  if (fabsf(rate) < params.delta) {
15614
17090
  return GGML_OPT_OK;
15615
17091
  }
15616
17092
  }
15617
17093
 
15618
- pf[t%params.past] = fx;
17094
+ pf[(iter0 + t)%params.past] = fx;
15619
17095
  }
15620
17096
 
15621
17097
  // check for improvement
15622
17098
  if (params.max_no_improvement > 0) {
15623
- if (fx_best > fx) {
15624
- fx_best = fx;
15625
- n_no_improvement = 0;
17099
+ if (fx_best[0] > fx) {
17100
+ fx_best[0] = fx;
17101
+ n_no_improvement[0] = 0;
15626
17102
  } else {
15627
- ++n_no_improvement;
17103
+ ++n_no_improvement[0];
15628
17104
 
15629
- if (n_no_improvement >= params.max_no_improvement) {
17105
+ if (n_no_improvement[0] >= params.max_no_improvement) {
15630
17106
  return GGML_OPT_OK;
15631
17107
  }
15632
17108
  }
15633
17109
  }
15634
17110
 
15635
- fx_prev = fx;
17111
+ fx_prev[0] = fx;
15636
17112
 
15637
17113
  {
15638
17114
  const int64_t t_end_cpu = ggml_cycles();
@@ -15771,6 +17247,7 @@ static enum ggml_opt_result linesearch_backtracking(
15771
17247
 
15772
17248
  static enum ggml_opt_result ggml_opt_lbfgs(
15773
17249
  struct ggml_context * ctx,
17250
+ struct ggml_opt_context * opt,
15774
17251
  struct ggml_opt_params params,
15775
17252
  struct ggml_tensor * f,
15776
17253
  struct ggml_cgraph * gf,
@@ -15803,31 +17280,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15803
17280
  }
15804
17281
  }
15805
17282
 
15806
- float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current parameters
15807
- float * xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous parameters
15808
- float * g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current gradient
15809
- float * gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous gradient
15810
- float * d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // search direction
17283
+ if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) {
17284
+ int iter = opt->iter;
17285
+ ggml_opt_init(ctx, opt, params, nx);
17286
+ opt->iter = iter;
17287
+ }
17288
+
17289
+ float * x = opt->lbfgs.x->data; // current parameters
17290
+ float * xp = opt->lbfgs.xp->data; // previous parameters
17291
+ float * g = opt->lbfgs.g->data; // current gradient
17292
+ float * gp = opt->lbfgs.gp->data; // previous gradient
17293
+ float * d = opt->lbfgs.d->data; // search direction
15811
17294
 
15812
- float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
17295
+ float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
15813
17296
 
15814
17297
  float fx = 0.0f; // cost function value
15815
17298
  float xnorm = 0.0f; // ||x||
15816
17299
  float gnorm = 0.0f; // ||g||
15817
- float step = 0.0f;
15818
17300
 
15819
17301
  // initialize x from the graph nodes
15820
17302
  ggml_opt_get_params(np, ps, x);
15821
17303
 
15822
17304
  // the L-BFGS memory
15823
- struct ggml_lbfgs_iteration_data * lm = alloca(sizeof(struct ggml_lbfgs_iteration_data)*m);
15824
-
15825
- for (int i = 0; i < m; ++i) {
15826
- lm[i].alpha = 0.0f;
15827
- lm[i].ys = 0.0f;
15828
- lm[i].s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
15829
- lm[i].y = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
15830
- }
17305
+ float * lm_alpha = opt->lbfgs.lmal->data;
17306
+ float * lm_ys = opt->lbfgs.lmys->data;
17307
+ float * lm_s = opt->lbfgs.lms->data;
17308
+ float * lm_y = opt->lbfgs.lmy->data;
15831
17309
 
15832
17310
  // evaluate the function value and its gradient
15833
17311
  {
@@ -15842,12 +17320,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15842
17320
  fx = ggml_get_f32_1d(f, 0);
15843
17321
  }
15844
17322
 
15845
- if (pf) {
15846
- pf[0] = fx;
15847
- }
15848
-
15849
- float fx_best = fx;
15850
-
15851
17323
  // search direction = -gradient
15852
17324
  ggml_vec_neg_f32(nx, d, g);
15853
17325
 
@@ -15864,26 +17336,43 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15864
17336
  return GGML_OPT_OK;
15865
17337
  }
15866
17338
 
15867
- // initial step
15868
- ggml_vec_norm_inv_f32(nx, &step, d);
17339
+ if (opt->just_initialized) {
17340
+ if (pf) {
17341
+ pf[0] = fx;
17342
+ }
17343
+ opt->lbfgs.fx_best = fx;
17344
+
17345
+ // initial step
17346
+ ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d);
17347
+ opt->lbfgs.j = 0;
17348
+ opt->lbfgs.k = 1;
17349
+ opt->lbfgs.end = 0;
17350
+ opt->lbfgs.n_no_improvement = 0;
17351
+ opt->just_initialized = false;
17352
+ }
17353
+
17354
+ float * fx_best = &opt->lbfgs.fx_best;
17355
+ float * step = &opt->lbfgs.step;
17356
+ int * j = &opt->lbfgs.j;
17357
+ int * k = &opt->lbfgs.k;
17358
+ int * end = &opt->lbfgs.end;
17359
+ int * n_no_improvement = &opt->lbfgs.n_no_improvement;
15869
17360
 
15870
- int j = 0;
15871
- int k = 1;
15872
- int ls = 0;
15873
- int end = 0;
15874
- int bound = 0;
15875
- int n_no_improvement = 0;
17361
+ int ls = 0;
17362
+ int bound = 0;
15876
17363
 
15877
17364
  float ys = 0.0f;
15878
17365
  float yy = 0.0f;
15879
17366
  float beta = 0.0f;
15880
17367
 
17368
+ int it = 0;
17369
+
15881
17370
  while (true) {
15882
17371
  // store the current position and gradient vectors
15883
17372
  ggml_vec_cpy_f32(nx, xp, x);
15884
17373
  ggml_vec_cpy_f32(nx, gp, g);
15885
17374
 
15886
- ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, &step, xp, f, gf, gb, np, ps);
17375
+ ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
15887
17376
 
15888
17377
  if (ls < 0) {
15889
17378
  // linesearch failed - go back to the previous point and return
@@ -15909,32 +17398,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15909
17398
  // delta-based convergence test
15910
17399
  if (pf != NULL) {
15911
17400
  // need at least params.past iterations to start checking for convergence
15912
- if (params.past <= k) {
15913
- const float rate = (pf[k%params.past] - fx)/fx;
17401
+ if (params.past <= k[0]) {
17402
+ const float rate = (pf[k[0]%params.past] - fx)/fx;
15914
17403
 
15915
17404
  if (fabsf(rate) < params.delta) {
15916
17405
  return GGML_OPT_OK;
15917
17406
  }
15918
17407
  }
15919
17408
 
15920
- pf[k%params.past] = fx;
17409
+ pf[k[0]%params.past] = fx;
15921
17410
  }
15922
17411
 
15923
17412
  // check for improvement
15924
17413
  if (params.max_no_improvement > 0) {
15925
- if (fx < fx_best) {
15926
- fx_best = fx;
15927
- n_no_improvement = 0;
17414
+ if (fx < fx_best[0]) {
17415
+ fx_best[0] = fx;
17416
+ n_no_improvement[0] = 0;
15928
17417
  } else {
15929
- n_no_improvement++;
17418
+ n_no_improvement[0]++;
15930
17419
 
15931
- if (n_no_improvement >= params.max_no_improvement) {
17420
+ if (n_no_improvement[0] >= params.max_no_improvement) {
15932
17421
  return GGML_OPT_OK;
15933
17422
  }
15934
17423
  }
15935
17424
  }
15936
17425
 
15937
- if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < k + 1) {
17426
+ if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) {
15938
17427
  // reached the maximum number of iterations
15939
17428
  return GGML_OPT_DID_NOT_CONVERGE;
15940
17429
  }
@@ -15943,50 +17432,51 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15943
17432
  // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}.
15944
17433
  // y_{k+1} = g_{k+1} - g_{k}.
15945
17434
  //
15946
- ggml_vec_sub_f32(nx, lm[end].s, x, xp);
15947
- ggml_vec_sub_f32(nx, lm[end].y, g, gp);
17435
+ ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp);
17436
+ ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp);
15948
17437
 
15949
17438
  // compute scalars ys and yy:
15950
17439
  // ys = y^t \cdot s -> 1 / \rho.
15951
17440
  // yy = y^t \cdot y.
15952
17441
  //
15953
- ggml_vec_dot_f32(nx, &ys, lm[end].y, lm[end].s);
15954
- ggml_vec_dot_f32(nx, &yy, lm[end].y, lm[end].y);
17442
+ ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0] *nx]);
17443
+ ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);
15955
17444
 
15956
- lm[end].ys = ys;
17445
+ lm_ys[end[0]] = ys;
15957
17446
 
15958
17447
  // find new search direction
15959
17448
  // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS
15960
17449
 
15961
- bound = (m <= k) ? m : k;
15962
- k++;
15963
- end = (end + 1)%m;
17450
+ bound = (m <= k[0]) ? m : k[0];
17451
+ k[0]++;
17452
+ it++;
17453
+ end[0] = (end[0] + 1)%m;
15964
17454
 
15965
17455
  // initialize search direction with -g
15966
17456
  ggml_vec_neg_f32(nx, d, g);
15967
17457
 
15968
- j = end;
17458
+ j[0] = end[0];
15969
17459
  for (int i = 0; i < bound; ++i) {
15970
- j = (j + m - 1) % m;
17460
+ j[0] = (j[0] + m - 1) % m;
15971
17461
  // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
15972
- ggml_vec_dot_f32(nx, &lm[j].alpha, lm[j].s, d);
15973
- lm[j].alpha /= lm[j].ys;
17462
+ ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d);
17463
+ lm_alpha[j[0]] /= lm_ys[j[0]];
15974
17464
  // q_{i} = q_{i+1} - \alpha_{i} y_{i}
15975
- ggml_vec_mad_f32(nx, d, lm[j].y, -lm[j].alpha);
17465
+ ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
15976
17466
  }
15977
17467
 
15978
17468
  ggml_vec_scale_f32(nx, d, ys/yy);
15979
17469
 
15980
17470
  for (int i = 0; i < bound; ++i) {
15981
17471
  // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
15982
- ggml_vec_dot_f32(nx, &beta, lm[j].y, d);
15983
- beta /= lm[j].ys;
17472
+ ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d);
17473
+ beta /= lm_ys[j[0]];
15984
17474
  // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
15985
- ggml_vec_mad_f32(nx, d, lm[j].s, lm[j].alpha - beta);
15986
- j = (j + 1)%m;
17475
+ ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
17476
+ j[0] = (j[0] + 1)%m;
15987
17477
  }
15988
17478
 
15989
- step = 1.0;
17479
+ step[0] = 1.0;
15990
17480
  }
15991
17481
 
15992
17482
  return GGML_OPT_DID_NOT_CONVERGE;
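The L-BFGS history is now stored in flat tensors: lm_s and lm_y hold m rows of length nx, so ring-buffer entry j lives at offset j*nx, while lm_alpha and lm_ys hold the per-entry scalars. A compact sketch of the two-loop recursion over these flattened buffers, mirroring the loops above without the persistent-state indirection (hypothetical helper):

    #include <stdint.h>

    // d enters as -g and leaves as the L-BFGS search direction
    static void lbfgs_direction(int64_t nx, int m, int bound, int end,
                                const float * lm_s, const float * lm_y,
                                const float * lm_ys, float * alpha,
                                float ys, float yy, float * d) {
        int j = end;
        for (int i = 0; i < bound; ++i) {
            j = (j + m - 1) % m;
            float a = 0.0f;                                      // alpha_j = rho_j * s_j . q
            for (int64_t t = 0; t < nx; ++t) a += lm_s[j*nx + t]*d[t];
            alpha[j] = a/lm_ys[j];
            for (int64_t t = 0; t < nx; ++t) d[t] -= alpha[j]*lm_y[j*nx + t];
        }
        for (int64_t t = 0; t < nx; ++t) d[t] *= ys/yy;          // initial Hessian scaling
        for (int i = 0; i < bound; ++i) {
            float beta = 0.0f;                                   // beta_j = rho_j * y_j . d
            for (int64_t t = 0; t < nx; ++t) beta += lm_y[j*nx + t]*d[t];
            beta /= lm_ys[j];
            for (int64_t t = 0; t < nx; ++t) d[t] += (alpha[j] - beta)*lm_s[j*nx + t];
            j = (j + 1) % m;
        }
    }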
@@ -16011,6 +17501,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
16011
17501
 
16012
17502
  .adam = {
16013
17503
  .n_iter = 10000,
17504
+ .sched = 1.000f,
17505
+ .decay = 0.001f,
16014
17506
  .alpha = 0.001f,
16015
17507
  .beta1 = 0.9f,
16016
17508
  .beta2 = 0.999f,
@@ -16053,6 +17545,71 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
16053
17545
  return result;
16054
17546
  }
16055
17547
 
17548
+ GGML_API void ggml_opt_init(
17549
+ struct ggml_context * ctx,
17550
+ struct ggml_opt_context * opt,
17551
+ struct ggml_opt_params params,
17552
+ int64_t nx) {
17553
+ opt->ctx = ctx;
17554
+ opt->params = params;
17555
+ opt->iter = 0;
17556
+ opt->nx = nx;
17557
+ opt->just_initialized = true;
17558
+ switch (opt->params.type) {
17559
+ case GGML_OPT_ADAM:
17560
+ {
17561
+ opt->adam.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17562
+ opt->adam.g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17563
+ opt->adam.g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17564
+ opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17565
+ opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17566
+ opt->adam.mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17567
+ opt->adam.vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17568
+ opt->adam.pf = params.past > 0
17569
+ ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
17570
+ : NULL;
17571
+ ggml_set_zero(opt->adam.x);
17572
+ ggml_set_zero(opt->adam.g1);
17573
+ ggml_set_zero(opt->adam.g2);
17574
+ ggml_set_zero(opt->adam.m);
17575
+ ggml_set_zero(opt->adam.v);
17576
+ ggml_set_zero(opt->adam.mh);
17577
+ ggml_set_zero(opt->adam.vh);
17578
+ if (opt->adam.pf) {
17579
+ ggml_set_zero(opt->adam.pf);
17580
+ }
17581
+ } break;
17582
+ case GGML_OPT_LBFGS:
17583
+ {
17584
+ opt->lbfgs.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17585
+ opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17586
+ opt->lbfgs.g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17587
+ opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17588
+ opt->lbfgs.d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17589
+ opt->lbfgs.pf = params.past > 0
17590
+ ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
17591
+ : NULL;
17592
+ opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
17593
+ opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
17594
+ opt->lbfgs.lms = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
17595
+ opt->lbfgs.lmy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
17596
+ ggml_set_zero(opt->lbfgs.x);
17597
+ ggml_set_zero(opt->lbfgs.xp);
17598
+ ggml_set_zero(opt->lbfgs.g);
17599
+ ggml_set_zero(opt->lbfgs.gp);
17600
+ ggml_set_zero(opt->lbfgs.d);
17601
17602
+ if (opt->lbfgs.pf) {
17603
+ ggml_set_zero(opt->lbfgs.pf);
17604
+ }
17605
+ ggml_set_zero(opt->lbfgs.lmal);
17606
+ ggml_set_zero(opt->lbfgs.lmys);
17607
+ ggml_set_zero(opt->lbfgs.lms);
17608
+ ggml_set_zero(opt->lbfgs.lmy);
17609
+ } break;
17610
+ }
17611
+ }
17612
+
16056
17613
  enum ggml_opt_result ggml_opt(
16057
17614
  struct ggml_context * ctx,
16058
17615
  struct ggml_opt_params params,
@@ -16075,33 +17632,65 @@ enum ggml_opt_result ggml_opt(
16075
17632
 
16076
17633
  enum ggml_opt_result result = GGML_OPT_OK;
16077
17634
 
17635
+ struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
17636
+
17637
+ ggml_opt_init(ctx, opt, params, 0);
17638
+ result = ggml_opt_resume(ctx, opt, f);
17639
+
17640
+ if (free_ctx) {
17641
+ ggml_free(ctx);
17642
+ }
17643
+
17644
+ return result;
17645
+ }
17646
+
17647
+ enum ggml_opt_result ggml_opt_resume(
17648
+ struct ggml_context * ctx,
17649
+ struct ggml_opt_context * opt,
17650
+ struct ggml_tensor * f) {
17651
+
17652
+ // build forward + backward compute graphs
17653
+ struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
17654
+ struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
17655
+
17656
+ struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
17657
+ struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
17658
+
17659
+ *gf = ggml_build_forward (f);
17660
+ *gb = ggml_build_backward(ctx, gf, true);
17661
+
17662
+ return ggml_opt_resume_g(ctx, opt, f, gf, gb);
17663
+ }
17664
+
17665
+ enum ggml_opt_result ggml_opt_resume_g(
17666
+ struct ggml_context * ctx,
17667
+ struct ggml_opt_context * opt,
17668
+ struct ggml_tensor * f,
17669
+ struct ggml_cgraph * gf,
17670
+ struct ggml_cgraph * gb) {
17671
+
16078
17672
  // build forward + backward compute graphs
16079
- struct ggml_cgraph gf = ggml_build_forward (f);
16080
- struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, true);
17673
+ enum ggml_opt_result result = GGML_OPT_OK;
16081
17674
 
16082
- switch (params.type) {
17675
+ switch (opt->params.type) {
16083
17676
  case GGML_OPT_ADAM:
16084
17677
  {
16085
- result = ggml_opt_adam(ctx, params, f, &gf, &gb);
17678
+ result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
16086
17679
  } break;
16087
17680
  case GGML_OPT_LBFGS:
16088
17681
  {
16089
- result = ggml_opt_lbfgs(ctx, params, f, &gf, &gb);
17682
+ result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb);
16090
17683
  } break;
16091
17684
  }
16092
17685
 
16093
- if (params.print_forward_graph) {
16094
- ggml_graph_print (&gf);
16095
- ggml_graph_dump_dot(&gf, NULL, "opt-forward.dot");
16096
- }
16097
-
16098
- if (params.print_backward_graph) {
16099
- ggml_graph_print (&gb);
16100
- ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot");
17686
+ if (opt->params.print_forward_graph) {
17687
+ ggml_graph_print (gf);
17688
+ ggml_graph_dump_dot(gf, NULL, "opt-forward.dot");
16101
17689
  }
16102
17690
 
16103
- if (free_ctx) {
16104
- ggml_free(ctx);
17691
+ if (opt->params.print_backward_graph) {
17692
+ ggml_graph_print (gb);
17693
+ ggml_graph_dump_dot(gb, gf, "opt-backward.dot");
16105
17694
  }
16106
17695
 
16107
17696
  return result;
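Taken together, ggml_opt_init and ggml_opt_resume let optimizer state (iteration count, Adam moments, L-BFGS history) survive across calls instead of being rebuilt every time. A hedged usage sketch, assuming these declarations are exposed through ggml.h as the GGML_API marker above suggests, and that f is a scalar loss whose parameters were registered with ggml_set_param (optimize_in_rounds is a hypothetical caller, not library code):

    #include <stdint.h>
    #include "ggml.h"

    static enum ggml_opt_result optimize_in_rounds(struct ggml_context * ctx,
                                                   struct ggml_tensor  * f,
                                                   int64_t nx, int rounds) {
        struct ggml_opt_params params = ggml_opt_default_params(GGML_OPT_ADAM);
        params.adam.n_iter = 16;                 // short bursts per resume

        struct ggml_opt_context opt;
        ggml_opt_init(ctx, &opt, params, nx);    // allocates moments/history inside ctx

        enum ggml_opt_result res = GGML_OPT_OK;
        for (int r = 0; r < rounds; ++r) {
            // each call rebuilds the graphs but resumes from the stored iteration and moments
            res = ggml_opt_resume(ctx, &opt, f);
            if (res != GGML_OPT_OK) {
                break;
            }
        }
        return res;
    }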
@@ -16301,6 +17890,18 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
16301
17890
  result = ggml_quantize_q6_K(src + start, block, n, n, hist);
16302
17891
  } break;
16303
17892
  #endif
17893
+ case GGML_TYPE_F16:
17894
+ {
17895
+ int elemsize = sizeof(ggml_fp16_t);
17896
+ ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
17897
+ result = n * elemsize;
17898
+ } break;
17899
+ case GGML_TYPE_F32:
17900
+ {
17901
+ int elemsize = sizeof(float);
17902
+ result = n * elemsize;
17903
+ memcpy((uint8_t *)dst + start * elemsize, src + start, result);
17904
+ } break;
16304
17905
  default:
16305
17906
  assert(false);
16306
17907
  }
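The two new cases make ggml_quantize_chunk usable as a uniform conversion entry point even when the target type is not actually quantized: F16 converts the chunk row-wise and F32 simply copies it, with the return value being the number of output bytes. A hedged usage sketch for the F16 path, assuming the parameter order implied by the call sites in this hunk (type, src, dst, start, n, hist):

    #include <stdint.h>
    #include "ggml.h"

    static size_t convert_chunk_f16(const float * src, void * dst, int start, int n) {
        int64_t hist[16] = {0};   // histogram is unused for F16/F32 but kept for the common signature
        return ggml_quantize_chunk(GGML_TYPE_F16, src, dst, start, n, hist);
    }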