llama_cpp 0.2.0 → 0.2.1

This diff shows the changes between the two publicly released versions of the package as published to its public registry; the hunks below are from the bundled ggml.c source.
@@ -3603,6 +3603,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3603
3603
  "SUM_ROWS",
3604
3604
  "MEAN",
3605
3605
  "REPEAT",
3606
+ "REPEAT_BACK",
3606
3607
  "ABS",
3607
3608
  "SGN",
3608
3609
  "NEG",
@@ -3616,6 +3617,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3616
3617
  "RMS_NORM_BACK",
3617
3618
 
3618
3619
  "MUL_MAT",
3620
+ "OUT_PROD",
3619
3621
 
3620
3622
  "SCALE",
3621
3623
  "SET",
@@ -3631,6 +3633,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3631
3633
  "DIAG_MASK_INF",
3632
3634
  "DIAG_MASK_ZERO",
3633
3635
  "SOFT_MAX",
3636
+ "SOFT_MAX_BACK",
3634
3637
  "ROPE",
3635
3638
  "ROPE_BACK",
3636
3639
  "ALIBI",
@@ -3640,13 +3643,16 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3640
3643
 
3641
3644
  "FLASH_ATTN",
3642
3645
  "FLASH_FF",
3646
+ "FLASH_ATTN_BACK",
3643
3647
 
3644
3648
  "MAP_UNARY",
3645
3649
  "MAP_BINARY",
3646
- };
3647
3650
 
3648
- static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
3651
+ "CROSS_ENTROPY_LOSS",
3652
+ "CROSS_ENTROPY_LOSS_BACK",
3653
+ };
3649
3654
 
3655
+ static_assert(GGML_OP_COUNT == 57, "GGML_OP_COUNT != 57");
3650
3656
 
3651
3657
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3652
3658
  "none",
@@ -3665,6 +3671,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3665
3671
  "Σx_k",
3666
3672
  "Σx/n",
3667
3673
  "repeat(x)",
3674
+ "repeat_back(x)",
3668
3675
  "abs(x)",
3669
3676
  "sgn(x)",
3670
3677
  "-x",
@@ -3677,6 +3684,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3677
3684
  "rms_norm(x)",
3678
3685
  "rms_norm_back(x)",
3679
3686
 
3687
+ "X*Y",
3680
3688
  "X*Y",
3681
3689
 
3682
3690
  "x*v",
@@ -3693,6 +3701,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3693
3701
  "diag_mask_inf(x)",
3694
3702
  "diag_mask_zero(x)",
3695
3703
  "soft_max(x)",
3704
+ "soft_max_back(x)",
3696
3705
  "rope(x)",
3697
3706
  "rope_back(x)",
3698
3707
  "alibi(x)",
@@ -3702,12 +3711,16 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3702
3711
 
3703
3712
  "flash_attn(x)",
3704
3713
  "flash_ff(x)",
3714
+ "flash_attn_back(x)",
3705
3715
 
3706
3716
  "f(x)",
3707
3717
  "f(x,y)",
3718
+
3719
+ "cross_entropy_loss(x,y)",
3720
+ "cross_entropy_loss_back(x,y)",
3708
3721
  };
3709
3722
 
3710
- static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
3723
+ static_assert(GGML_OP_COUNT == 57, "GGML_OP_COUNT != 57");
3711
3724
 
3712
3725
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
3713
3726
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
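
Note: taken together, the table updates above register six new operators - REPEAT_BACK, OUT_PROD, SOFT_MAX_BACK, FLASH_ATTN_BACK, CROSS_ENTROPY_LOSS and CROSS_ENTROPY_LOSS_BACK - which is why both GGML_OP_COUNT assertions move from 51 to 57. The hunks that follow add the corresponding graph constructors and forward kernels.
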
@@ -3870,6 +3883,15 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
3870
3883
  (t0->ne[3] == t1->ne[3]);
3871
3884
  }
3872
3885
 
3886
+ static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
3887
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3888
+
3889
+ return
3890
+ (t0->ne[1] == t1->ne[1]) &&
3891
+ (t0->ne[2] == t1->ne[2]) &&
3892
+ (t0->ne[3] == t1->ne[3]);
3893
+ }
3894
+
3873
3895
  bool ggml_is_quantized(enum ggml_type type) {
3874
3896
  return GGML_IS_QUANTIZED[type];
3875
3897
  }
@@ -3917,6 +3939,12 @@ bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
3917
3939
  tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
3918
3940
  }
3919
3941
 
3942
+ bool ggml_is_permuted(const struct ggml_tensor * tensor) {
3943
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3944
+
3945
+ return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
3946
+ }
3947
+
3920
3948
  static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
3921
3949
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3922
3950
 
@@ -4693,7 +4721,7 @@ struct ggml_tensor * ggml_add_impl(
4693
4721
 
4694
4722
  bool is_node = false;
4695
4723
 
4696
- if (!inplace && (a->grad || b->grad)) {
4724
+ if (a->grad || b->grad) {
4697
4725
  is_node = true;
4698
4726
  }
4699
4727
 
@@ -4733,7 +4761,7 @@ struct ggml_tensor * ggml_add1_impl(
4733
4761
 
4734
4762
  bool is_node = false;
4735
4763
 
4736
- if (!inplace && (a->grad || b->grad)) {
4764
+ if (a->grad || b->grad) {
4737
4765
  is_node = true;
4738
4766
  }
4739
4767
 
@@ -5159,6 +5187,34 @@ struct ggml_tensor * ggml_repeat(
5159
5187
  return result;
5160
5188
  }
5161
5189
 
5190
+ // ggml_repeat_back
5191
+
5192
+ struct ggml_tensor * ggml_repeat_back(
5193
+ struct ggml_context * ctx,
5194
+ struct ggml_tensor * a,
5195
+ struct ggml_tensor * b) {
5196
+ GGML_ASSERT(ggml_can_repeat(b, a));
5197
+
5198
+ bool is_node = false;
5199
+
5200
+ if (a->grad) {
5201
+ is_node = true;
5202
+ }
5203
+
5204
+ if (ggml_are_same_shape(a, b) && !is_node) {
5205
+ return a;
5206
+ }
5207
+
5208
+ struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
5209
+
5210
+ result->op = GGML_OP_REPEAT_BACK;
5211
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5212
+ result->src0 = a;
5213
+ result->src1 = b;
5214
+
5215
+ return result;
5216
+ }
5217
+
5162
5218
  // ggml_abs
5163
5219
 
5164
5220
  struct ggml_tensor * ggml_abs_impl(
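
A minimal usage sketch for the ggml_repeat_back constructor added above (illustrative only; assumes a valid struct ggml_context * ctx and the matching declaration in this version's ggml.h):

    // forward: tile a [2]-tensor into a [6]-tensor
    struct ggml_tensor * small = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2);
    struct ggml_tensor * shape = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 6);
    struct ggml_tensor * big   = ggml_repeat(ctx, small, shape);            // ne = [6]
    // backward: sum the three tiled copies back into shape [2] - the gradient
    // of ggml_repeat with respect to `small`
    struct ggml_tensor * acc   = ggml_repeat_back(ctx, big, small);         // ne = [2]
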
@@ -5536,6 +5592,32 @@ struct ggml_tensor * ggml_mul_mat(
5536
5592
  return result;
5537
5593
  }
5538
5594
 
5595
+ // ggml_out_prod
5596
+
5597
+ struct ggml_tensor * ggml_out_prod(
5598
+ struct ggml_context * ctx,
5599
+ struct ggml_tensor * a,
5600
+ struct ggml_tensor * b) {
5601
+ GGML_ASSERT(ggml_can_out_prod(a, b));
5602
+ GGML_ASSERT(!ggml_is_transposed(a));
5603
+
5604
+ bool is_node = false;
5605
+
5606
+ if (a->grad || b->grad) {
5607
+ is_node = true;
5608
+ }
5609
+
5610
+ const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
5611
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
5612
+
5613
+ result->op = GGML_OP_OUT_PROD;
5614
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5615
+ result->src0 = a;
5616
+ result->src1 = b;
5617
+
5618
+ return result;
5619
+ }
5620
+
5539
5621
  // ggml_scale
5540
5622
 
5541
5623
  struct ggml_tensor * ggml_scale_impl(
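
A shape-level sketch for the ggml_out_prod constructor added above (illustrative only; assumes a valid struct ggml_context * ctx). Per ggml_can_out_prod, the arguments must agree in dims 1-3 and the result combines their leading dims:

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);  // ne = [4, 3]
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 5, 3);  // ne = [5, 3]
    // ggml_can_out_prod(a, b) holds: a->ne[1] == b->ne[1] (dims 2 and 3 are both 1)
    struct ggml_tensor * c = ggml_out_prod(ctx, a, b);                      // ne = [4, 5]
    // element-wise (see the f32 kernel further below):
    //   c[i0,i1] = sum over j of a[i0,j] * b[i1,j]
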
@@ -5548,7 +5630,7 @@ struct ggml_tensor * ggml_scale_impl(
5548
5630
 
5549
5631
  bool is_node = false;
5550
5632
 
5551
- if (!inplace && (a->grad || b->grad)) {
5633
+ if (a->grad || b->grad) {
5552
5634
  is_node = true;
5553
5635
  }
5554
5636
 
@@ -5591,7 +5673,7 @@ struct ggml_tensor * ggml_set_impl(
5591
5673
 
5592
5674
  bool is_node = false;
5593
5675
 
5594
- if (!inplace && (a->grad || b->grad)) {
5676
+ if (a->grad || b->grad) {
5595
5677
  is_node = true;
5596
5678
  }
5597
5679
 
@@ -5913,10 +5995,6 @@ struct ggml_tensor * ggml_view_1d(
5913
5995
  result->src1 = NULL;
5914
5996
  result->opt[0] = offs;
5915
5997
 
5916
- if (is_node) {
5917
- memcpy(result->padding, &offset, sizeof(offset));
5918
- }
5919
-
5920
5998
  return result;
5921
5999
  }
5922
6000
 
@@ -5957,10 +6035,6 @@ struct ggml_tensor * ggml_view_2d(
5957
6035
  result->src1 = NULL;
5958
6036
  result->opt[0] = offs;
5959
6037
 
5960
- if (is_node) {
5961
- memcpy(result->padding, &offset, sizeof(offset));
5962
- }
5963
-
5964
6038
  return result;
5965
6039
  }
5966
6040
 
@@ -6003,10 +6077,6 @@ struct ggml_tensor * ggml_view_3d(
6003
6077
  result->src1 = NULL;
6004
6078
  result->opt[0] = offs;
6005
6079
 
6006
- if (is_node) {
6007
- memcpy(result->padding, &offset, sizeof(offset));
6008
- }
6009
-
6010
6080
  return result;
6011
6081
  }
6012
6082
 
@@ -6051,10 +6121,6 @@ struct ggml_tensor * ggml_view_4d(
6051
6121
  result->src1 = NULL;
6052
6122
  result->opt[0] = offs;
6053
6123
 
6054
- if (is_node) {
6055
- memcpy(result->padding, &offset, sizeof(offset));
6056
- }
6057
-
6058
6124
  return result;
6059
6125
  }
6060
6126
 
@@ -6116,10 +6182,18 @@ struct ggml_tensor * ggml_permute(
6116
6182
  result->src1 = NULL;
6117
6183
 
6118
6184
  if (is_node) {
6119
- result->padding[0] = axis0;
6120
- result->padding[1] = axis1;
6121
- result->padding[2] = axis2;
6122
- result->padding[3] = axis3;
6185
+ ggml_scratch_save(ctx);
6186
+
6187
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
6188
+
6189
+ ((int32_t *) b->data)[0] = axis0;
6190
+ ((int32_t *) b->data)[1] = axis1;
6191
+ ((int32_t *) b->data)[2] = axis2;
6192
+ ((int32_t *) b->data)[3] = axis3;
6193
+
6194
+ ggml_scratch_load(ctx);
6195
+
6196
+ result->opt[0] = b;
6123
6197
  }
6124
6198
 
6125
6199
  return result;
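
Note: this hunk and the four ggml_view_* hunks above stop stashing metadata in the tensor's padding field. The view offset is already carried in result->opt[0] (offs), and ggml_permute now records its four axes in a small I32 tensor - allocated outside the scratch buffer via ggml_scratch_save/ggml_scratch_load and stored in result->opt[0] - presumably so the backward pass can recover them.
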
@@ -6359,6 +6433,44 @@ struct ggml_tensor * ggml_soft_max_inplace(
6359
6433
  return ggml_soft_max_impl(ctx, a, true);
6360
6434
  }
6361
6435
 
6436
+
6437
+ // ggml_soft_max_back
6438
+
6439
+ struct ggml_tensor * ggml_soft_max_back_impl(
6440
+ struct ggml_context * ctx,
6441
+ struct ggml_tensor * a,
6442
+ struct ggml_tensor * b,
6443
+ bool inplace) {
6444
+ bool is_node = false;
6445
+
6446
+ if (a->grad || b->grad) {
6447
+ is_node = true; // TODO : implement backward pass
6448
+ }
6449
+
6450
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
6451
+
6452
+ result->op = GGML_OP_SOFT_MAX_BACK;
6453
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6454
+ result->src0 = a;
6455
+ result->src1 = b;
6456
+
6457
+ return result;
6458
+ }
6459
+
6460
+ struct ggml_tensor * ggml_soft_max_back(
6461
+ struct ggml_context * ctx,
6462
+ struct ggml_tensor * a,
6463
+ struct ggml_tensor * b) {
6464
+ return ggml_soft_max_back_impl(ctx, a, b, false);
6465
+ }
6466
+
6467
+ struct ggml_tensor * ggml_soft_max_back_inplace(
6468
+ struct ggml_context * ctx,
6469
+ struct ggml_tensor * a,
6470
+ struct ggml_tensor * b) {
6471
+ return ggml_soft_max_back_impl(ctx, a, b, true);
6472
+ }
6473
+
6362
6474
  // ggml_rope
6363
6475
 
6364
6476
  struct ggml_tensor * ggml_rope_impl(
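
A minimal sketch for the ggml_soft_max_back constructor added above (illustrative only; assumes a valid struct ggml_context * ctx and an f32 tensor `logits`). The first argument is the upstream gradient and the second is the softmax output, matching the src0/src1 roles in the f32 kernel further below:

    struct ggml_tensor * y  = ggml_soft_max(ctx, logits);                  // y = softmax(logits)
    struct ggml_tensor * dy = ggml_new_tensor(ctx, GGML_TYPE_F32,
                                              logits->n_dims, logits->ne); // dL/dy from upstream
    // row-wise: dx_k = y_k * (dy_k - dot(y, dy))
    struct ggml_tensor * dx = ggml_soft_max_back(ctx, dy, y);
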
@@ -6371,7 +6483,7 @@ struct ggml_tensor * ggml_rope_impl(
6371
6483
  GGML_ASSERT(n_past >= 0);
6372
6484
  bool is_node = false;
6373
6485
 
6374
- if (!inplace && a->grad) {
6486
+ if (a->grad) {
6375
6487
  is_node = true;
6376
6488
  }
6377
6489
 
@@ -6425,8 +6537,7 @@ struct ggml_tensor * ggml_rope_back(
6425
6537
  bool is_node = false;
6426
6538
 
6427
6539
  if (a->grad) {
6428
- GGML_ASSERT(false); // TODO: implement backward
6429
- is_node = true;
6540
+ is_node = false; // TODO: implement backward
6430
6541
  }
6431
6542
 
6432
6543
  struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
@@ -6591,7 +6702,6 @@ struct ggml_tensor * ggml_flash_attn(
6591
6702
  bool is_node = false;
6592
6703
 
6593
6704
  if (q->grad || k->grad || v->grad) {
6594
- GGML_ASSERT(false); // TODO: implement backward
6595
6705
  is_node = true;
6596
6706
  }
6597
6707
 
@@ -6623,7 +6733,6 @@ struct ggml_tensor * ggml_flash_ff(
6623
6733
  bool is_node = false;
6624
6734
 
6625
6735
  if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
6626
- GGML_ASSERT(false); // TODO: implement backward
6627
6736
  is_node = true;
6628
6737
  }
6629
6738
 
@@ -6641,6 +6750,71 @@ struct ggml_tensor * ggml_flash_ff(
6641
6750
  return result;
6642
6751
  }
6643
6752
 
6753
+ // ggml_flash_attn_back
6754
+
6755
+ struct ggml_tensor * ggml_flash_attn_back(
6756
+ struct ggml_context * ctx,
6757
+ struct ggml_tensor * q,
6758
+ struct ggml_tensor * k,
6759
+ struct ggml_tensor * v,
6760
+ struct ggml_tensor * d,
6761
+ bool masked) {
6762
+ GGML_ASSERT(ggml_can_mul_mat(k, q));
6763
+ // TODO: check if vT can be multiplied by (k*qT)
6764
+
6765
+ // d shape [D,N,ne2,ne3]
6766
+ // q shape [D,N,ne2,ne3]
6767
+ // k shape [D,M,ne2,ne3]
6768
+ // v shape [M,D,ne2,ne3]
6769
+
6770
+ const int64_t D = q->ne[0];
6771
+ const int64_t N = q->ne[1];
6772
+ const int64_t M = k->ne[1];
6773
+ const int64_t ne2 = q->ne[2];
6774
+ const int64_t ne3 = q->ne[3];
6775
+
6776
+ GGML_ASSERT(k->ne[0] == D);
6777
+ GGML_ASSERT(v->ne[0] == M);
6778
+ GGML_ASSERT(v->ne[1] == D);
6779
+ GGML_ASSERT(d->ne[0] == D);
6780
+ GGML_ASSERT(d->ne[1] == N);
6781
+ GGML_ASSERT(k->ne[2] == ne2);
6782
+ GGML_ASSERT(k->ne[3] == ne3);
6783
+ GGML_ASSERT(v->ne[2] == ne2);
6784
+ GGML_ASSERT(v->ne[3] == ne3);
6785
+ GGML_ASSERT(d->ne[2] == ne2);
6786
+ GGML_ASSERT(d->ne[3] == ne3);
6787
+
6788
+ bool is_node = false;
6789
+
6790
+ if (q->grad || k->grad || v->grad) {
6791
+ // when using this operation (in backwards pass) these grads are set.
6792
+ // we don't want to create (big) grad of our result, so is_node is false.
6793
+ is_node = false;
6794
+ }
6795
+
6796
+ // store gradients of q, k and v as continuous tensors concatenated in result.
6797
+ // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
6798
+ // gradq->data = result->data
6799
+ // gradk->data = result->data + nb0*D*N*ne2*ne3
6800
+ // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
6801
+ // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
6802
+ int64_t ne[4] = {D,M+N+M,ne2,ne3};
6803
+
6804
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
6805
+
6806
+ result->op = GGML_OP_FLASH_ATTN_BACK;
6807
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6808
+ result->src0 = q;
6809
+ result->src1 = k;
6810
+ result->opt[0] = v;
6811
+ result->opt[1] = d;
6812
+ result->opt[2] = ggml_new_i32(ctx, masked ? 1 : 0);
6813
+
6814
+ return result;
6815
+ }
6816
+
6817
+
6644
6818
  // ggml_map_unary
6645
6819
 
6646
6820
  struct ggml_tensor * ggml_map_unary_impl_f32(
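
A sketch of how the packed result of ggml_flash_attn_back could be split back into the three gradients, following the byte offsets documented in the comment above (hypothetical helper code, not part of this diff; assumes a valid struct ggml_context * ctx and q/k/v/d tensors with the shapes listed in the asserts):

    struct ggml_tensor * r = ggml_flash_attn_back(ctx, q, k, v, d, /*masked=*/true);
    const int64_t D = q->ne[0], N = q->ne[1], M = k->ne[1];
    const int64_t ne2 = q->ne[2], ne3 = q->ne[3];
    const size_t  es  = ggml_element_size(r);  // r is F32, so es == sizeof(float)
    struct ggml_tensor * grad_q = ggml_view_1d(ctx, r, D*N*ne2*ne3, 0);
    struct ggml_tensor * grad_k = ggml_view_1d(ctx, r, D*M*ne2*ne3, es*D*N*ne2*ne3);
    struct ggml_tensor * grad_v = ggml_view_1d(ctx, r, M*D*ne2*ne3, es*(D*N + D*M)*ne2*ne3);
    // note: grad_v is stored transposed (M-major), as the comment above points out
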
@@ -6725,6 +6899,50 @@ struct ggml_tensor * ggml_map_binary_inplace_f32(
6725
6899
  return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
6726
6900
  }
6727
6901
 
6902
+ // ggml_cross_entropy_loss
6903
+
6904
+ struct ggml_tensor * ggml_cross_entropy_loss(
6905
+ struct ggml_context * ctx,
6906
+ struct ggml_tensor * a,
6907
+ struct ggml_tensor * b) {
6908
+ GGML_ASSERT(ggml_are_same_shape(a, b));
6909
+ bool is_node = false;
6910
+
6911
+ if (a->grad || b->grad) {
6912
+ is_node = true;
6913
+ }
6914
+
6915
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
6916
+
6917
+ result->op = GGML_OP_CROSS_ENTROPY_LOSS;
6918
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6919
+ result->src0 = a;
6920
+ result->src1 = b;
6921
+
6922
+ return result;
6923
+ }
6924
+
6925
+ // ggml_cross_entropy_loss_back
6926
+
6927
+ struct ggml_tensor * ggml_cross_entropy_loss_back(
6928
+ struct ggml_context * ctx,
6929
+ struct ggml_tensor * a,
6930
+ struct ggml_tensor * b,
6931
+ struct ggml_tensor * c) {
6932
+ GGML_ASSERT(ggml_are_same_shape(a, b));
6933
+ GGML_ASSERT(ggml_is_scalar(c));
6934
+
6935
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
6936
+
6937
+ result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
6938
+ result->grad = NULL;
6939
+ result->src0 = a;
6940
+ result->src1 = b;
6941
+ result->opt[0] = c;
6942
+
6943
+ return result;
6944
+ }
6945
+
6728
6946
  ////////////////////////////////////////////////////////////////////////////////
6729
6947
 
6730
6948
  void ggml_set_param(
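
A self-contained usage sketch for the new cross-entropy loss op (illustrative only, not part of the package; assumes the ggml.h shipped with this version declares ggml_cross_entropy_loss and the usual graph API):

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        struct ggml_init_params ip = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * logits = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        struct ggml_tensor * labels = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        for (int i = 0; i < 4; ++i) {
            ggml_set_f32_1d(logits, i, (float) i);            // arbitrary logits 0,1,2,3
            ggml_set_f32_1d(labels, i, i == 3 ? 1.0f : 0.0f); // one-hot target
        }

        struct ggml_tensor * loss = ggml_cross_entropy_loss(ctx, logits, labels); // 1-element result
        struct ggml_cgraph gf = ggml_build_forward(loss);
        ggml_graph_compute(ctx, &gf);

        printf("cross entropy loss = %f\n", ggml_get_f32_1d(loss, 0));
        ggml_free(ctx);
        return 0;
    }
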
@@ -8875,6 +9093,99 @@ static void ggml_compute_forward_repeat(
8875
9093
  }
8876
9094
  }
8877
9095
 
9096
+ // ggml_compute_forward_repeat_back
9097
+
9098
+ static void ggml_compute_forward_repeat_back_f32(
9099
+ const struct ggml_compute_params * params,
9100
+ const struct ggml_tensor * src0,
9101
+ struct ggml_tensor * dst) {
9102
+ GGML_ASSERT(params->ith == 0);
9103
+ GGML_ASSERT(ggml_can_repeat(dst, src0));
9104
+
9105
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
9106
+ return;
9107
+ }
9108
+
9109
+ const int64_t ne0 = dst->ne[0];
9110
+ const int64_t ne1 = dst->ne[1];
9111
+ const int64_t ne2 = dst->ne[2];
9112
+ const int64_t ne3 = dst->ne[3];
9113
+
9114
+ const int64_t ne00 = src0->ne[0];
9115
+ const int64_t ne01 = src0->ne[1];
9116
+ const int64_t ne02 = src0->ne[2];
9117
+ const int64_t ne03 = src0->ne[3];
9118
+
9119
+ const size_t nb0 = dst->nb[0];
9120
+ const size_t nb1 = dst->nb[1];
9121
+ const size_t nb2 = dst->nb[2];
9122
+ const size_t nb3 = dst->nb[3];
9123
+
9124
+ const size_t nb00 = src0->nb[0];
9125
+ const size_t nb01 = src0->nb[1];
9126
+ const size_t nb02 = src0->nb[2];
9127
+ const size_t nb03 = src0->nb[3];
9128
+
9129
+ // guaranteed to be an integer due to the check in ggml_can_repeat
9130
+ const int nr0 = (int)(ne00/ne0);
9131
+ const int nr1 = (int)(ne01/ne1);
9132
+ const int nr2 = (int)(ne02/ne2);
9133
+ const int nr3 = (int)(ne03/ne3);
9134
+
9135
+ // TODO: support for transposed / permuted tensors
9136
+ GGML_ASSERT(nb0 == sizeof(float));
9137
+ GGML_ASSERT(nb00 == sizeof(float));
9138
+
9139
+ if (ggml_is_contiguous(dst)) {
9140
+ ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
9141
+ } else {
9142
+ for (int k3 = 0; k3 < ne3; k3++) {
9143
+ for (int k2 = 0; k2 < ne2; k2++) {
9144
+ for (int k1 = 0; k1 < ne1; k1++) {
9145
+ ggml_vec_set_f32(ne0,
9146
+ (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
9147
+ 0);
9148
+ }
9149
+ }
9150
+ }
9151
+ }
9152
+
9153
+ // TODO: maybe this is not optimal?
9154
+ for (int i3 = 0; i3 < nr3; i3++) {
9155
+ for (int k3 = 0; k3 < ne3; k3++) {
9156
+ for (int i2 = 0; i2 < nr2; i2++) {
9157
+ for (int k2 = 0; k2 < ne2; k2++) {
9158
+ for (int i1 = 0; i1 < nr1; i1++) {
9159
+ for (int k1 = 0; k1 < ne1; k1++) {
9160
+ for (int i0 = 0; i0 < nr0; i0++) {
9161
+ ggml_vec_acc_f32(ne0,
9162
+ (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1),
9163
+ (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
9164
+ }
9165
+ }
9166
+ }
9167
+ }
9168
+ }
9169
+ }
9170
+ }
9171
+ }
9172
+
9173
+ static void ggml_compute_forward_repeat_back(
9174
+ const struct ggml_compute_params * params,
9175
+ const struct ggml_tensor * src0,
9176
+ struct ggml_tensor * dst) {
9177
+ switch (src0->type) {
9178
+ case GGML_TYPE_F32:
9179
+ {
9180
+ ggml_compute_forward_repeat_back_f32(params, src0, dst);
9181
+ } break;
9182
+ default:
9183
+ {
9184
+ GGML_ASSERT(false);
9185
+ } break;
9186
+ }
9187
+ }
9188
+
8878
9189
  // ggml_compute_forward_abs
8879
9190
 
8880
9191
  static void ggml_compute_forward_abs_f32(
@@ -10249,18 +10560,188 @@ static void ggml_compute_forward_mul_mat(
10249
10560
  }
10250
10561
  }
10251
10562
 
10252
- // ggml_compute_forward_scale
10563
+ // ggml_compute_forward_out_prod
10253
10564
 
10254
- static void ggml_compute_forward_scale_f32(
10565
+
10566
+ static void ggml_compute_forward_out_prod_f32(
10255
10567
  const struct ggml_compute_params * params,
10256
10568
  const struct ggml_tensor * src0,
10257
10569
  const struct ggml_tensor * src1,
10258
- struct ggml_tensor * dst) {
10259
- GGML_ASSERT(ggml_is_contiguous(src0));
10260
- GGML_ASSERT(ggml_is_contiguous(dst));
10261
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
10262
- GGML_ASSERT(ggml_is_scalar(src1));
10263
-
10570
+ struct ggml_tensor * dst) {
10571
+ int64_t t0 = ggml_perf_time_us();
10572
+ UNUSED(t0);
10573
+
10574
+ const int64_t ne00 = src0->ne[0];
10575
+ const int64_t ne01 = src0->ne[1];
10576
+ const int64_t ne02 = src0->ne[2];
10577
+ const int64_t ne03 = src0->ne[3];
10578
+
10579
+ const int64_t ne10 = src1->ne[0];
10580
+ //const int64_t ne11 = src1->ne[1];
10581
+ const int64_t ne12 = src1->ne[2];
10582
+ const int64_t ne13 = src1->ne[3];
10583
+
10584
+ const int64_t ne0 = dst->ne[0];
10585
+ const int64_t ne1 = dst->ne[1];
10586
+ const int64_t ne2 = dst->ne[2];
10587
+ const int64_t ne3 = dst->ne[3];
10588
+
10589
+ const int nb00 = src0->nb[0];
10590
+ const int nb01 = src0->nb[1];
10591
+ const int nb02 = src0->nb[2];
10592
+ const int nb03 = src0->nb[3];
10593
+
10594
+ const int nb10 = src1->nb[0];
10595
+ const int nb11 = src1->nb[1];
10596
+ const int nb12 = src1->nb[2];
10597
+ const int nb13 = src1->nb[3];
10598
+
10599
+ const int nb0 = dst->nb[0];
10600
+ const int nb1 = dst->nb[1];
10601
+ const int nb2 = dst->nb[2];
10602
+ const int nb3 = dst->nb[3];
10603
+
10604
+ const int ith = params->ith;
10605
+ const int nth = params->nth;
10606
+
10607
+ GGML_ASSERT(ne02 == ne12);
10608
+ GGML_ASSERT(ne03 == ne13);
10609
+ GGML_ASSERT(ne2 == ne12);
10610
+ GGML_ASSERT(ne3 == ne13);
10611
+
10612
+ // we don't support permuted src0 or src1
10613
+ GGML_ASSERT(nb00 == sizeof(float));
10614
+
10615
+ // dst cannot be transposed or permuted
10616
+ GGML_ASSERT(nb0 == sizeof(float));
10617
+ // GGML_ASSERT(nb0 <= nb1);
10618
+ // GGML_ASSERT(nb1 <= nb2);
10619
+ // GGML_ASSERT(nb2 <= nb3);
10620
+
10621
+ GGML_ASSERT(ne0 == ne00);
10622
+ GGML_ASSERT(ne1 == ne10);
10623
+ GGML_ASSERT(ne2 == ne02);
10624
+ GGML_ASSERT(ne3 == ne03);
10625
+
10626
+ // nb01 >= nb00 - src0 is not transposed
10627
+ // compute by src0 rows
10628
+
10629
+ // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
10630
+ // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
10631
+
10632
+ if (params->type == GGML_TASK_INIT) {
10633
+ ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
10634
+ return;
10635
+ }
10636
+
10637
+ if (params->type == GGML_TASK_FINALIZE) {
10638
+ return;
10639
+ }
10640
+
10641
+ // parallelize by last three dimensions
10642
+
10643
+ // total rows in dst
10644
+ const int64_t nr = ne1*ne2*ne3;
10645
+
10646
+ // rows per thread
10647
+ const int64_t dr = (nr + nth - 1)/nth;
10648
+
10649
+ // row range for this thread
10650
+ const int64_t ir0 = dr*ith;
10651
+ const int64_t ir1 = MIN(ir0 + dr, nr);
10652
+
10653
+ // dst[:,:,:,:] = 0
10654
+ // for i2,i3:
10655
+ // for i1:
10656
+ // for i01:
10657
+ // for i0:
10658
+ // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
10659
+
10660
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
10661
+ // dst indices
10662
+ const int64_t i3 = ir/(ne2*ne1);
10663
+ const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
10664
+ const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
10665
+
10666
+ const int64_t i02 = i2;
10667
+ const int64_t i03 = i3;
10668
+
10669
+ //const int64_t i10 = i1;
10670
+ const int64_t i12 = i2;
10671
+ const int64_t i13 = i3;
10672
+
10673
+ for (int64_t i01 = 0; i01 < ne01; ++i01) {
10674
+ const int64_t i11 = i01;
10675
+
10676
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
10677
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
10678
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
10679
+
10680
+ ggml_vec_mad_f32(ne0, d, s0, *s1);
10681
+ // for (int64_t i0 = 0; i0 < ne0; ++i0) {
10682
+ // d[i0] += s0[i0] * s1[i1];
10683
+ // }
10684
+ }
10685
+ }
10686
+
10687
+ //int64_t t1 = ggml_perf_time_us();
10688
+ //static int64_t acc = 0;
10689
+ //acc += t1 - t0;
10690
+ //if (t1 - t0 > 10) {
10691
+ // printf("\n");
10692
+ // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
10693
+ // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
10694
+ // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
10695
+ // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
10696
+
10697
+ // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
10698
+ //}
10699
+ }
10700
+
10701
+ static void ggml_compute_forward_out_prod(
10702
+ const struct ggml_compute_params * params,
10703
+ const struct ggml_tensor * src0,
10704
+ const struct ggml_tensor * src1,
10705
+ struct ggml_tensor * dst) {
10706
+ switch (src0->type) {
10707
+ case GGML_TYPE_Q4_0:
10708
+ case GGML_TYPE_Q4_1:
10709
+ case GGML_TYPE_Q5_0:
10710
+ case GGML_TYPE_Q5_1:
10711
+ case GGML_TYPE_Q8_0:
10712
+ case GGML_TYPE_Q8_1:
10713
+ {
10714
+ GGML_ASSERT(false); // todo
10715
+ // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
10716
+ } break;
10717
+ case GGML_TYPE_F16:
10718
+ {
10719
+ GGML_ASSERT(false); // todo
10720
+ // ggml_compute_forward_out_prod_f16_f32(params, src0, src1, dst);
10721
+ } break;
10722
+ case GGML_TYPE_F32:
10723
+ {
10724
+ ggml_compute_forward_out_prod_f32(params, src0, src1, dst);
10725
+ } break;
10726
+ default:
10727
+ {
10728
+ GGML_ASSERT(false);
10729
+ } break;
10730
+ }
10731
+ }
10732
+
10733
+ // ggml_compute_forward_scale
10734
+
10735
+ static void ggml_compute_forward_scale_f32(
10736
+ const struct ggml_compute_params * params,
10737
+ const struct ggml_tensor * src0,
10738
+ const struct ggml_tensor * src1,
10739
+ struct ggml_tensor * dst) {
10740
+ GGML_ASSERT(ggml_is_contiguous(src0));
10741
+ GGML_ASSERT(ggml_is_contiguous(dst));
10742
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
10743
+ GGML_ASSERT(ggml_is_scalar(src1));
10744
+
10264
10745
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
10265
10746
  return;
10266
10747
  }
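
For reference, a naive single-threaded restatement of what the f32 kernel above computes for contiguous 2-D inputs (hypothetical check code, not from this diff; indices follow the in-code comment dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]):

    // a: ne = [an0, k], b: ne = [bn0, k], d: ne = [an0, bn0], all contiguous f32
    static void out_prod_ref_f32(const float * a, int64_t an0,
                                 const float * b, int64_t bn0,
                                 int64_t k, float * d) {
        for (int64_t i1 = 0; i1 < bn0; ++i1) {
            for (int64_t i0 = 0; i0 < an0; ++i0) {
                float sum = 0.0f;
                for (int64_t j = 0; j < k; ++j) {
                    sum += a[j*an0 + i0] * b[j*bn0 + i1]; // sum of outer products over the shared dim
                }
                d[i1*an0 + i0] = sum;
            }
        }
    }
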
@@ -10671,7 +11152,11 @@ static void ggml_compute_forward_get_rows_back_f32(
10671
11152
  GGML_ASSERT(ggml_is_contiguous(opt0));
10672
11153
  GGML_ASSERT(ggml_is_contiguous(dst));
10673
11154
 
10674
- ggml_compute_forward_dup_same_cont(params, opt0, dst);
11155
+ // ggml_compute_forward_dup_same_cont(params, opt0, dst);
11156
+
11157
+ if (params->type == GGML_TASK_INIT) {
11158
+ memset(dst->data, 0, ggml_nbytes(dst));
11159
+ }
10675
11160
 
10676
11161
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
10677
11162
  return;
@@ -10815,8 +11300,8 @@ static void ggml_compute_forward_diag_mask_f32(
10815
11300
  const struct ggml_tensor * src1,
10816
11301
  struct ggml_tensor * dst,
10817
11302
  const float value) {
10818
- assert(src1->type == GGML_TYPE_I32);
10819
- assert(ggml_nelements(src1) == 2);
11303
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
11304
+ GGML_ASSERT(ggml_nelements(src1) == 2);
10820
11305
 
10821
11306
  const int ith = params->ith;
10822
11307
  const int nth = params->nth;
@@ -10824,7 +11309,7 @@ static void ggml_compute_forward_diag_mask_f32(
10824
11309
  const int n_past = ((int32_t *) src1->data)[0];
10825
11310
  const bool inplace = (bool)((int32_t *) src1->data)[1];
10826
11311
 
10827
- assert(n_past >= 0);
11312
+ GGML_ASSERT(n_past >= 0);
10828
11313
 
10829
11314
  if (!inplace && (params->type == GGML_TASK_INIT)) {
10830
11315
  // memcpy needs to be synchronized across threads to avoid race conditions.
@@ -10848,8 +11333,8 @@ static void ggml_compute_forward_diag_mask_f32(
10848
11333
  const int nr = src0->ne[1];
10849
11334
  const int nz = n/nr;
10850
11335
 
10851
- assert( dst->nb[0] == sizeof(float));
10852
- assert(src0->nb[0] == sizeof(float));
11336
+ GGML_ASSERT( dst->nb[0] == sizeof(float));
11337
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
10853
11338
 
10854
11339
  for (int k = 0; k < nz; k++) {
10855
11340
  for (int j = ith; j < nr; j += nth) {
@@ -10985,6 +11470,101 @@ static void ggml_compute_forward_soft_max(
10985
11470
  }
10986
11471
  }
10987
11472
 
11473
+ // ggml_compute_forward_soft_max_back
11474
+
11475
+ static void ggml_compute_forward_soft_max_back_f32(
11476
+ const struct ggml_compute_params * params,
11477
+ const struct ggml_tensor * src0,
11478
+ const struct ggml_tensor * src1,
11479
+ struct ggml_tensor * dst) {
11480
+ GGML_ASSERT(ggml_is_contiguous(src0));
11481
+ GGML_ASSERT(ggml_is_contiguous(src1));
11482
+ GGML_ASSERT(ggml_is_contiguous(dst));
11483
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
11484
+ GGML_ASSERT(ggml_are_same_shape(src1, dst));
11485
+
11486
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
11487
+ return;
11488
+ }
11489
+
11490
+ // TODO: handle transposed/permuted matrices
11491
+
11492
+ const int ith = params->ith;
11493
+ const int nth = params->nth;
11494
+
11495
+ const int nc = src0->ne[0];
11496
+ const int nr = ggml_nrows(src0);
11497
+
11498
+ // rows per thread
11499
+ const int dr = (nr + nth - 1)/nth;
11500
+
11501
+ // row range for this thread
11502
+ const int ir0 = dr*ith;
11503
+ const int ir1 = MIN(ir0 + dr, nr);
11504
+
11505
+ for (int i1 = ir0; i1 < ir1; i1++) {
11506
+ float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
11507
+ float *y = (float *)((char *) src1->data + i1*src1->nb[1]);
11508
+ float *dx = (float *)((char *) dst->data + i1*dst->nb[1]);
11509
+
11510
+ #ifndef NDEBUG
11511
+ for (int i = 0; i < nc; ++i) {
11512
+ //printf("p[%d] = %f\n", i, p[i]);
11513
+ assert(!isnan(dy[i]));
11514
+ assert(!isnan(y[i]));
11515
+ }
11516
+ #endif
11517
+ // Jii = yi - yi*yi
11518
+ // Jij = -yi*yj
11519
+ // J = diag(y)-y.T*y
11520
+ // dx = J * dy
11521
+ // dxk = sum_i(Jki * dyi)
11522
+ // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
11523
+ // dxk = sum_i(-yk*yi * dyi) + yk*dyk
11524
+ // dxk = -yk * sum_i(yi * dyi) + yk*dyk
11525
+ // dxk = -yk * dot(y, dy) + yk*dyk
11526
+ // dxk = yk * (- dot(y, dy) + dyk)
11527
+ // dxk = yk * (dyk - dot(y, dy))
11528
+ //
11529
+ // post-order:
11530
+ // dot_y_dy := dot(y, dy)
11531
+ // dx := dy
11532
+ // dx := dx - dot_y_dy
11533
+ // dx := dx * y
11534
+
11535
+ // linear runtime, no additional memory
11536
+ float dot_y_dy = 0;
11537
+ ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy);
11538
+ ggml_vec_cpy_f32 (nc, dx, dy);
11539
+ ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
11540
+ ggml_vec_mul_f32 (nc, dx, dx, y);
11541
+
11542
+ #ifndef NDEBUG
11543
+ for (int i = 0; i < nc; ++i) {
11544
+ assert(!isnan(dx[i]));
11545
+ assert(!isinf(dx[i]));
11546
+ }
11547
+ #endif
11548
+ }
11549
+ }
11550
+
11551
+ static void ggml_compute_forward_soft_max_back(
11552
+ const struct ggml_compute_params * params,
11553
+ const struct ggml_tensor * src0,
11554
+ const struct ggml_tensor * src1,
11555
+ struct ggml_tensor * dst) {
11556
+ switch (src0->type) {
11557
+ case GGML_TYPE_F32:
11558
+ {
11559
+ ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst);
11560
+ } break;
11561
+ default:
11562
+ {
11563
+ GGML_ASSERT(false);
11564
+ } break;
11565
+ }
11566
+ }
11567
+
10988
11568
  // ggml_compute_forward_alibi
10989
11569
 
10990
11570
  static void ggml_compute_forward_alibi_f32(
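
Worked check of the formula used above, dx_k = y_k * (dy_k - dot(y, dy)): for a row y = (0.25, 0.75) with upstream gradient dy = (1, 0), dot(y, dy) = 0.25, so dx = (0.25*(1 - 0.25), 0.75*(0 - 0.25)) = (0.1875, -0.1875). The components sum to zero, as expected for a gradient with respect to softmax inputs (adding a constant to every input leaves the softmax unchanged).
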
@@ -12938,42 +13518,616 @@ static void ggml_compute_forward_flash_ff(
12938
13518
  }
12939
13519
  }
12940
13520
 
12941
- // ggml_compute_forward_map_unary
13521
+ // ggml_compute_forward_flash_attn_back
12942
13522
 
12943
- static void ggml_compute_forward_map_unary_f32(
13523
+ static void ggml_compute_forward_flash_attn_back_f32(
12944
13524
  const struct ggml_compute_params * params,
12945
- const struct ggml_tensor * src0,
12946
- struct ggml_tensor * dst,
12947
- const ggml_unary_op_f32_t fun) {
12948
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
13525
+ const struct ggml_tensor * q,
13526
+ const struct ggml_tensor * k,
13527
+ const struct ggml_tensor * v,
13528
+ const struct ggml_tensor * d,
13529
+ const bool masked,
13530
+ struct ggml_tensor * dst) {
13531
+ int64_t t0 = ggml_perf_time_us();
13532
+ UNUSED(t0);
12949
13533
 
12950
- if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
12951
- return;
12952
- }
13534
+ const int64_t neq0 = q->ne[0];
13535
+ const int64_t neq1 = q->ne[1];
13536
+ const int64_t neq2 = q->ne[2];
13537
+ const int64_t neq3 = q->ne[3];
12953
13538
 
12954
- const int n = ggml_nrows(src0);
12955
- const int nc = src0->ne[0];
13539
+ const int64_t nek0 = k->ne[0];
13540
+ const int64_t nek1 = k->ne[1];
13541
+ //const int64_t nek2 = k->ne[2];
13542
+ //const int64_t nek3 = k->ne[3];
12956
13543
 
12957
- assert( dst->nb[0] == sizeof(float));
12958
- assert(src0->nb[0] == sizeof(float));
13544
+ const int64_t nev0 = v->ne[0];
13545
+ const int64_t nev1 = v->ne[1];
13546
+ //const int64_t nev2 = v->ne[2];
13547
+ //const int64_t nev3 = v->ne[3];
12959
13548
 
12960
- for (int i = 0; i < n; i++) {
12961
- fun(nc,
12962
- (float *) ((char *) dst->data + i*( dst->nb[1])),
12963
- (float *) ((char *) src0->data + i*(src0->nb[1])));
12964
- }
12965
- }
13549
+ const int64_t ned0 = d->ne[0];
13550
+ const int64_t ned1 = d->ne[1];
13551
+ //const int64_t ned2 = d->ne[2];
13552
+ //const int64_t ned3 = d->ne[3];
12966
13553
 
13554
+ const int64_t ne0 = dst->ne[0];
13555
+ const int64_t ne1 = dst->ne[1];
13556
+ const int64_t ne2 = dst->ne[2];
13557
+ const int64_t ne3 = dst->ne[3];
12967
13558
 
12968
- static void ggml_compute_forward_map_unary(
12969
- const struct ggml_compute_params * params,
12970
- const struct ggml_tensor * src0,
12971
- struct ggml_tensor * dst,
12972
- const ggml_unary_op_f32_t fun) {
13559
+ const int nbk0 = k->nb[0];
13560
+ const int nbk1 = k->nb[1];
13561
+ const int nbk2 = k->nb[2];
13562
+ const int nbk3 = k->nb[3];
13563
+
13564
+ const int nbq0 = q->nb[0];
13565
+ const int nbq1 = q->nb[1];
13566
+ const int nbq2 = q->nb[2];
13567
+ const int nbq3 = q->nb[3];
13568
+
13569
+ const int nbv0 = v->nb[0];
13570
+ const int nbv1 = v->nb[1];
13571
+ const int nbv2 = v->nb[2];
13572
+ const int nbv3 = v->nb[3];
13573
+
13574
+ const int nbd0 = d->nb[0];
13575
+ const int nbd1 = d->nb[1];
13576
+ const int nbd2 = d->nb[2];
13577
+ const int nbd3 = d->nb[3];
13578
+
13579
+ const int nb0 = dst->nb[0];
13580
+ const int nb1 = dst->nb[1];
13581
+ const int nb2 = dst->nb[2];
13582
+ const int nb3 = dst->nb[3];
13583
+
13584
+ const int ith = params->ith;
13585
+ const int nth = params->nth;
13586
+
13587
+ const int64_t D = neq0;
13588
+ const int64_t N = neq1;
13589
+ const int64_t P = nek1 - N;
13590
+ const int64_t M = P + N;
13591
+
13592
+ const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
13593
+ const int mxDM = MAX(D, Mup);
13594
+
13595
+ // GGML_ASSERT(ne0 == D);
13596
+ // GGML_ASSERT(ne1 == N);
13597
+ GGML_ASSERT(P >= 0);
13598
+
13599
+ GGML_ASSERT(nbq0 == sizeof(float));
13600
+ GGML_ASSERT(nbk0 == sizeof(float));
13601
+ GGML_ASSERT(nbv0 == sizeof(float));
13602
+
13603
+ GGML_ASSERT(neq0 == D);
13604
+ GGML_ASSERT(nek0 == D);
13605
+ GGML_ASSERT(nev1 == D);
13606
+ GGML_ASSERT(ned0 == D);
13607
+
13608
+ GGML_ASSERT(neq1 == N);
13609
+ GGML_ASSERT(nek1 == N + P);
13610
+ GGML_ASSERT(nev1 == D);
13611
+ GGML_ASSERT(ned1 == N);
13612
+
13613
+ // dst cannot be transposed or permuted
13614
+ GGML_ASSERT(nb0 == sizeof(float));
13615
+ GGML_ASSERT(nb0 <= nb1);
13616
+ GGML_ASSERT(nb1 <= nb2);
13617
+ GGML_ASSERT(nb2 <= nb3);
13618
+
13619
+ if (params->type == GGML_TASK_INIT) {
13620
+ if (ith == 0) {
13621
+ memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
13622
+ }
13623
+ return;
13624
+ }
13625
+
13626
+ if (params->type == GGML_TASK_FINALIZE) {
13627
+ return;
13628
+ }
13629
+
13630
+ // parallelize by q rows using ggml_vec_dot_f32
13631
+
13632
+ // total rows in q
13633
+ const int nr = neq2*neq3;
13634
+
13635
+ // rows per thread
13636
+ const int dr = (nr + nth - 1)/nth;
13637
+
13638
+ // row range for this thread
13639
+ const int ir0 = dr*ith;
13640
+ const int ir1 = MIN(ir0 + dr, nr);
13641
+
13642
+ const float scale = 1.0f/sqrtf(D);
13643
+
13644
+ //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
13645
+
13646
+ for (int ir = ir0; ir < ir1; ++ir) {
13647
+ // q indices
13648
+ const int iq3 = ir/(neq2);
13649
+ const int iq2 = ir - iq3*neq2;
13650
+ for ( int iq1 = 0; iq1 < neq1; ++iq1) {
13651
+
13652
+
13653
+ // not sure about CACHE_LINE_SIZE_F32..
13654
+ // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
13655
+ float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
13656
+ float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
13657
+
13658
+ for (int i = M; i < Mup; ++i) {
13659
+ S[i] = -INFINITY;
13660
+ }
13661
+
13662
+ for (int64_t ic = 0; ic < nek1; ++ic) {
13663
+ // k indices
13664
+ const int ik3 = iq3;
13665
+ const int ik2 = iq2;
13666
+ const int ik1 = ic;
13667
+
13668
+ // S indices
13669
+ const int i1 = ik1;
13670
+
13671
+ ggml_vec_dot_f32(neq0,
13672
+ S + i1,
13673
+ (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
13674
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
13675
+ }
13676
+
13677
+ // scale
13678
+ ggml_vec_scale_f32(nek1, S, scale);
13679
+
13680
+ if (masked) {
13681
+ for (int64_t i = P; i < M; i++) {
13682
+ if (i > P + iq1) {
13683
+ S[i] = -INFINITY;
13684
+ }
13685
+ }
13686
+ }
13687
+
13688
+ // softmax
13689
+ {
13690
+ float max = -INFINITY;
13691
+ ggml_vec_max_f32(M, &max, S);
13692
+
13693
+ ggml_float sum = 0.0;
13694
+ {
13695
+ #ifdef GGML_SOFT_MAX_ACCELERATE
13696
+ max = -max;
13697
+ vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
13698
+ vvexpf(SM, SM, &Mup);
13699
+ ggml_vec_sum_f32(Mup, &sum, SM);
13700
+ #else
13701
+ uint16_t scvt[GGML_SOFT_MAX_UNROLL];
13702
+ ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
13703
+
13704
+ for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
13705
+ float * SR = S + i;
13706
+ float * SW = SM + i;
13707
+
13708
+ for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
13709
+ if (SR[j] == -INFINITY) {
13710
+ SW[j] = 0.0f;
13711
+ } else {
13712
+ ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
13713
+ memcpy(&scvt[j], &s, sizeof(uint16_t));
13714
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
13715
+ sump[j] += (ggml_float)val;
13716
+ SW[j] = val;
13717
+ }
13718
+ }
13719
+ }
13720
+
13721
+ for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
13722
+ sum += sump[i];
13723
+ }
13724
+ #endif
13725
+ }
13726
+
13727
+ assert(sum > 0.0);
13728
+
13729
+ sum = 1.0/sum;
13730
+ ggml_vec_scale_f32(M, SM, sum);
13731
+
13732
+ }
13733
+
13734
+ // step-by-step explanation
13735
+ {
13736
+ // forward-process shape grads from backward process
13737
+ // parallel_for iq2,iq3:
13738
+ // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur]
13739
+ // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
13740
+ // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur]
13741
+ // for iq1:
13742
+ // kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
13743
+ // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
13744
+ // vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
13745
+ // S0 = -Inf [D,1,1,1]
13746
+ // ~S1[i] = dot(kcur[:D,i], qcur)
13747
+ // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
13748
+ // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
13749
+ // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
13750
+ // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
13751
+ // ~S5[i] = dot(vcur[:,i], S4)
13752
+ // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
13753
+ // ~dst[i,iq1,iq2,iq3] = S5[i] ^
13754
+ // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
13755
+ // dst backward-/ grad[dst] = d
13756
+ //
13757
+ // output gradients with their dependencies:
13758
+ //
13759
+ // grad[kcur] = grad[S1].T @ qcur
13760
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
13761
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
13762
+ // grad[S4] = grad[S5] @ vcur
13763
+ // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
13764
+ // grad[qcur] = grad[S1] @ kcur
13765
+ // grad[vcur] = grad[S5].T @ S4
13766
+ // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
13767
+ //
13768
+ // in post-order:
13769
+ //
13770
+ // S1 = qcur @ kcur.T
13771
+ // S2 = S1 * scale
13772
+ // S3 = diag_mask_inf(S2, P)
13773
+ // S4 = softmax(S3)
13774
+ // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
13775
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
13776
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
13777
+ // grad[qcur] = grad[S1] @ kcur
13778
+ // grad[kcur] = grad[S1].T @ qcur
13779
+ // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
13780
+ //
13781
+ // using less variables (SM=S4):
13782
+ //
13783
+ // S = diag_mask_inf(qcur @ kcur.T * scale, P)
13784
+ // SM = softmax(S)
13785
+ // S = d[:D,iq1,iq2,iq3] @ vcur
13786
+ // dot_SM_gradSM = dot(SM, S)
13787
+ // S = SM * (S - dot(SM, S))
13788
+ // S = diag_mask_zero(S, P) * scale
13789
+ //
13790
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
13791
+ // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
13792
+ // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
13793
+ }
13794
+
13795
+ // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
13796
+ // S = d[:D,iq1,iq2,iq3] @ vcur
13797
+ // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
13798
+ ggml_vec_set_f32(M, S, 0);
13799
+ for (int64_t ic = 0; ic < D; ++ic) {
13800
+ // dst indices
13801
+ const int i1 = iq1;
13802
+ const int i2 = iq2;
13803
+ const int i3 = iq3;
13804
+
13805
+ ggml_vec_mad_f32(M,
13806
+ S,
13807
+ (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
13808
+ *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
13809
+ }
13810
+
13811
+ // S = SM * (S - dot(SM, S))
13812
+ float dot_SM_gradSM = 0;
13813
+ ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
13814
+ ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
13815
+ ggml_vec_mul_f32 (M, S, S, SM);
13816
+
13817
+ // S = diag_mask_zero(S, P) * scale
13818
+ if (masked) {
13819
+ // for (int64_t i = P + iq1 + 1; i < M; i++) {
13820
+ // S[i] = 0;
13821
+ // }
13822
+ for (int64_t i = P; i < M; i++) {
13823
+ if (i > P + iq1) {
13824
+ S[i] = 0;
13825
+ }
13826
+ }
13827
+ }
13828
+ ggml_vec_scale_f32(M, S, scale);
13829
+
13830
+ void * grad_q = (char *) dst->data;
13831
+ void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
13832
+ void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
13833
+
13834
+ const size_t nbgq1 = nb0*neq0;
13835
+ const size_t nbgq2 = nb0*neq0*neq1;
13836
+ const size_t nbgq3 = nb0*neq0*neq1*neq2;
13837
+
13838
+ const size_t nbgk1 = nb0*nek0;
13839
+ const size_t nbgk2 = nb0*nek0*nek1;
13840
+ const size_t nbgk3 = nb0*nek0*nek1*neq2;
13841
+
13842
+ const size_t nbgv1 = nb0*nev0;
13843
+ const size_t nbgv2 = nb0*nev0*nev1;
13844
+ const size_t nbgv3 = nb0*nev0*nev1*neq2;
13845
+
13846
+ // S shape [M,1]
13847
+ // SM shape [M,1]
13848
+ // kcur shape [D,M]
13849
+ // qcur shape [D,1]
13850
+ // vcur shape [M,D]
13851
+ //
13852
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
13853
+ // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
13854
+ // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
13855
+ //
13856
+ //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
13857
+ //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
13858
+ for (int64_t ic = 0; ic < M; ++ic) {
13859
+ // dst indices
13860
+ const int i1 = iq1;
13861
+ const int i2 = iq2;
13862
+ const int i3 = iq3;
13863
+
13864
+ ggml_vec_mad_f32(D,
13865
+ (float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
13866
+ (float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)),
13867
+ S[ic]);
13868
+ }
13869
+
13870
+ // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
13871
+ // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
13872
+ // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
13873
+ for (int64_t ic = 0; ic < M; ++ic) {
13874
+ // dst indices
13875
+ const int i1 = iq1;
13876
+ const int i2 = iq2;
13877
+ const int i3 = iq3;
13878
+
13879
+ // ggml_vec_set_f32(D,
13880
+ // (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
13881
+ // 0);
13882
+ ggml_vec_mad_f32(D,
13883
+ (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
13884
+ (float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)),
13885
+ S[ic]);
13886
+ }
13887
+
13888
+ // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
13889
+ // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
13890
+ // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M]
13891
+ for (int64_t ic = 0; ic < D; ++ic) {
13892
+ // dst indices
13893
+ const int i1 = iq1;
13894
+ const int i2 = iq2;
13895
+ const int i3 = iq3;
13896
+
13897
+ // ggml_vec_set_f32(M,
13898
+ // (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
13899
+ // 0);
13900
+ ggml_vec_mad_f32(M,
13901
+ (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
13902
+ SM,
13903
+ *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
13904
+ }
13905
+ }
13906
+ }
13907
+ }
13908
+
13909
+ static void ggml_compute_forward_flash_attn_back(
13910
+ const struct ggml_compute_params * params,
13911
+ const struct ggml_tensor * q,
13912
+ const struct ggml_tensor * k,
13913
+ const struct ggml_tensor * v,
13914
+ const struct ggml_tensor * d,
13915
+ const bool masked,
13916
+ struct ggml_tensor * dst) {
13917
+ switch (q->type) {
13918
+ case GGML_TYPE_F32:
13919
+ {
13920
+ ggml_compute_forward_flash_attn_back_f32(params, q, k, v, d, masked, dst);
13921
+ } break;
13922
+ default:
13923
+ {
13924
+ GGML_ASSERT(false);
13925
+ } break;
13926
+ }
13927
+ }
13928
+
13929
+ // ggml_compute_forward_map_unary
13930
+
13931
+ static void ggml_compute_forward_map_unary_f32(
13932
+ const struct ggml_compute_params * params,
13933
+ const struct ggml_tensor * src0,
13934
+ struct ggml_tensor * dst,
13935
+ const ggml_unary_op_f32_t fun) {
13936
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
13937
+
13938
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
13939
+ return;
13940
+ }
13941
+
13942
+ const int n = ggml_nrows(src0);
13943
+ const int nc = src0->ne[0];
13944
+
13945
+ assert( dst->nb[0] == sizeof(float));
13946
+ assert(src0->nb[0] == sizeof(float));
13947
+
13948
+ for (int i = 0; i < n; i++) {
13949
+ fun(nc,
13950
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
13951
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
13952
+ }
13953
+ }
13954
+
13955
+
13956
+ static void ggml_compute_forward_map_unary(
13957
+ const struct ggml_compute_params * params,
13958
+ const struct ggml_tensor * src0,
13959
+ struct ggml_tensor * dst,
13960
+ const ggml_unary_op_f32_t fun) {
13961
+ switch (src0->type) {
13962
+ case GGML_TYPE_F32:
13963
+ {
13964
+ ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
13965
+ } break;
13966
+ default:
13967
+ {
13968
+ GGML_ASSERT(false);
13969
+ } break;
13970
+ }
13971
+ }
13972
+
13973
+ // ggml_compute_forward_map_binary
13974
+
13975
+ static void ggml_compute_forward_map_binary_f32(
13976
+ const struct ggml_compute_params * params,
13977
+ const struct ggml_tensor * src0,
13978
+ const struct ggml_tensor * src1,
13979
+ struct ggml_tensor * dst,
13980
+ const ggml_binary_op_f32_t fun) {
13981
+ assert(params->ith == 0);
13982
+ assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
13983
+
13984
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
13985
+ return;
13986
+ }
13987
+
13988
+ const int n = ggml_nrows(src0);
13989
+ const int nc = src0->ne[0];
13990
+
13991
+ assert( dst->nb[0] == sizeof(float));
13992
+ assert(src0->nb[0] == sizeof(float));
13993
+ assert(src1->nb[0] == sizeof(float));
13994
+
13995
+ for (int i = 0; i < n; i++) {
13996
+ fun(nc,
13997
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
13998
+ (float *) ((char *) src0->data + i*(src0->nb[1])),
13999
+ (float *) ((char *) src1->data + i*(src1->nb[1])));
14000
+ }
14001
+ }
14002
+
14003
+
14004
+ static void ggml_compute_forward_map_binary(
14005
+ const struct ggml_compute_params * params,
14006
+ const struct ggml_tensor * src0,
14007
+ const struct ggml_tensor * src1,
14008
+ struct ggml_tensor * dst,
14009
+ const ggml_binary_op_f32_t fun) {
14010
+ switch (src0->type) {
14011
+ case GGML_TYPE_F32:
14012
+ {
14013
+ ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
14014
+ } break;
14015
+ default:
14016
+ {
14017
+ GGML_ASSERT(false);
14018
+ } break;
14019
+ }
14020
+ }
14021
+
14022
+ // ggml_compute_forward_cross_entropy_loss
14023
+
14024
+ static void ggml_compute_forward_cross_entropy_loss_f32(
14025
+ const struct ggml_compute_params * params,
14026
+ const struct ggml_tensor * src0,
14027
+ const struct ggml_tensor * src1,
14028
+ struct ggml_tensor * dst) {
14029
+ GGML_ASSERT(ggml_is_contiguous(src0));
14030
+ GGML_ASSERT(ggml_is_contiguous(src1));
14031
+ GGML_ASSERT(ggml_is_scalar(dst));
14032
+ GGML_ASSERT(ggml_are_same_shape(src0, src1));
14033
+
14034
+ const int ith = params->ith;
14035
+ const int nth = params->nth;
14036
+
14037
+ float * sums = (float *) params->wdata;
14038
+
14039
+ // TODO: handle transposed/permuted matrices
14040
+ const int nc = src0->ne[0];
14041
+ const int nr = ggml_nrows(src0);
14042
+
14043
+ if (params->type == GGML_TASK_INIT) {
14044
+ if (ith == 0) {
14045
+ memset(sums, 0, sizeof(float) * (nth + nth * nc));
14046
+ }
14047
+ return;
14048
+ }
14049
+
14050
+ if (params->type == GGML_TASK_FINALIZE) {
14051
+ if (ith == 0) {
14052
+ float * dp = (float *) dst->data;
14053
+ ggml_vec_sum_f32(nth, dp, sums);
14054
+ dp[0] *= -1.0f;
14055
+ }
14056
+ return;
14057
+ }
14058
+
14059
+ const double eps = 1e-9;
14060
+
14061
+ // rows per thread
14062
+ const int dr = (nr + nth - 1)/nth;
14063
+
14064
+ // row range for this thread
14065
+ const int ir0 = dr*ith;
14066
+ const int ir1 = MIN(ir0 + dr, nr);
14067
+
14068
+ for (int i1 = ir0; i1 < ir1; i1++) {
14069
+ float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
14070
+ float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
14071
+ float * st = (float *) params->wdata + nth + ith*nc;
14072
+
14073
+ #ifndef NDEBUG
14074
+ for (int i = 0; i < nc; ++i) {
14075
+ //printf("p[%d] = %f\n", i, p[i]);
14076
+ assert(!isnan(s0[i]));
14077
+ assert(!isnan(s1[i]));
14078
+ }
14079
+ #endif
14080
+ // soft_max
14081
+ ggml_float sum = 0.0;
14082
+ {
14083
+ float max = -INFINITY;
14084
+ ggml_vec_max_f32(nc, &max, s0);
14085
+
14086
+ uint16_t scvt;
14087
+ for (int i = 0; i < nc; i++) {
14088
+ if (s0[i] == -INFINITY) {
14089
+ st[i] = 0.0f;
14090
+ } else {
14091
+ // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
14092
+ ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
14093
+ memcpy(&scvt, &s, sizeof(scvt));
14094
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
14095
+ sum += (ggml_float)val;
14096
+ st[i] = val;
14097
+ }
14098
+ }
14099
+
14100
+ assert(sum > 0.0);
14101
+ // sum = 1.0/sum;
14102
+ }
14103
+ // avoid log(0) by rescaling from [0..1] to [eps..1]
14104
+ sum = (1.0 - eps) / sum;
14105
+ ggml_vec_scale_f32(nc, st, sum);
14106
+ ggml_vec_add1_f32(nc, st, st, eps);
14107
+ ggml_vec_log_f32(nc, st, st);
14108
+ ggml_vec_mul_f32(nc, st, st, s1);
14109
+
14110
+ ggml_vec_sum_f32(nc, sums + ith, st);
14111
+
14112
+ #ifndef NDEBUG
14113
+ for (int i = 0; i < nc; ++i) {
14114
+ assert(!isnan(st[i]));
14115
+ assert(!isinf(st[i]));
14116
+ }
14117
+ #endif
14118
+ }
14119
+
14120
+ }
14121
+
14122
+ static void ggml_compute_forward_cross_entropy_loss(
14123
+ const struct ggml_compute_params * params,
14124
+ const struct ggml_tensor * src0,
14125
+ const struct ggml_tensor * src1,
14126
+ struct ggml_tensor * dst) {
12973
14127
  switch (src0->type) {
12974
14128
  case GGML_TYPE_F32:
12975
14129
  {
12976
- ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
14130
+ ggml_compute_forward_cross_entropy_loss_f32(params, src0, src1, dst);
12977
14131
  } break;
12978
14132
  default:
12979
14133
  {
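
In equation form, the cross-entropy forward kernel above accumulates, per row, sum_i s1[i] * log(eps + (1 - eps) * softmax(s0)[i]) into a per-thread slot of params->wdata; the FINALIZE step adds the partial sums and negates, so the scalar result is minus the total. Rescaling the softmax from [0, 1] to [eps, 1] keeps log() away from zero, as the in-code comment notes.
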
@@ -12982,47 +14136,160 @@ static void ggml_compute_forward_map_unary(
12982
14136
  }
12983
14137
  }
12984
14138
 
12985
- // ggml_compute_forward_map_binary
14139
+ // ggml_compute_forward_cross_entropy_loss_back
12986
14140
 
12987
- static void ggml_compute_forward_map_binary_f32(
14141
+ static void ggml_compute_forward_cross_entropy_loss_back_f32(
12988
14142
  const struct ggml_compute_params * params,
12989
14143
  const struct ggml_tensor * src0,
12990
14144
  const struct ggml_tensor * src1,
12991
- struct ggml_tensor * dst,
12992
- const ggml_binary_op_f32_t fun) {
12993
- assert(params->ith == 0);
12994
- assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
14145
+ const struct ggml_tensor * opt0,
14146
+ struct ggml_tensor * dst) {
14147
+ GGML_ASSERT(ggml_is_contiguous(dst));
14148
+ GGML_ASSERT(ggml_is_contiguous(src0));
14149
+ GGML_ASSERT(ggml_is_contiguous(src1));
14150
+ GGML_ASSERT(ggml_is_contiguous(opt0));
14151
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
14152
+
14153
+ const int64_t ith = params->ith;
14154
+ const int64_t nth = params->nth;
12995
14155
 
12996
14156
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
12997
14157
  return;
12998
14158
  }
12999
14159
 
13000
- const int n = ggml_nrows(src0);
13001
- const int nc = src0->ne[0];
14160
+ const float eps = 1e-9f;
13002
14161
 
13003
- assert( dst->nb[0] == sizeof(float));
13004
- assert(src0->nb[0] == sizeof(float));
13005
- assert(src1->nb[0] == sizeof(float));
14162
+ // TODO: handle transposed/permuted matrices
14163
+ const int64_t nc = src0->ne[0];
14164
+ const int64_t nr = ggml_nrows(src0);
13006
14165
 
13007
- for (int i = 0; i < n; i++) {
13008
- fun(nc,
13009
- (float *) ((char *) dst->data + i*( dst->nb[1])),
13010
- (float *) ((char *) src0->data + i*(src0->nb[1])),
13011
- (float *) ((char *) src1->data + i*(src1->nb[1])));
14166
+ // rows per thread
14167
+ const int64_t dr = (nr + nth - 1)/nth;
14168
+
14169
+ // row range for this thread
14170
+ const int64_t ir0 = dr*ith;
14171
+ const int64_t ir1 = MIN(ir0 + dr, nr);
14172
+
14173
+ float * d = (float *) opt0->data;
14174
+
14175
+ for (int64_t i1 = ir0; i1 < ir1; i1++) {
14176
+ float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
14177
+ float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
14178
+ float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
14179
+ float * sm = (float *) params->wdata + ith*nc;
14180
+
14181
+ #ifndef NDEBUG
14182
+ for (int i = 0; i < nc; ++i) {
14183
+ //printf("p[%d] = %f\n", i, p[i]);
14184
+ assert(!isnan(s0[i]));
14185
+ assert(!isnan(s1[i]));
14186
+ }
14187
+ #endif
14188
+ // step by step explanation:
14189
+ {
14190
+ //float * sums = (float *) params->wdata;
14191
+
14192
+ // forward pass with annotated gradients from backward pass
14193
+ // (built by going in reverse operation order, adding to gradients of current operation args)
14194
+ // st0 = exp(s0-max(s0)) grad[st0] = grad[st1]*(1.0 - eps)/sum
14195
+ // from softmax_back: grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
14196
+ // ggml_vec_scale_f32(nc, st, sum); // st1 = st0*/sum = softmax(s0) grad[st1] = grad[st2]*(1.0 - eps)
14197
+ // ggml_vec_scale_f32(nc, st, (1.0f - eps)); // st2 = st1*(1.0 - eps) grad[st2] = grad[st3]
14198
+ // ggml_vec_add1_f32(nc, st, st, eps); // st3 = st2 + eps grad[st3] = grad[st4]/st3
14199
+ // ggml_vec_log_f32(nc, st, st); // st4 = log(st3) grad[st4] = grad[st5] * s1
14200
+ // ggml_vec_mul_f32(nc, st, st, s1); // st5 = st4 * s1 grad[st5] = grad[sums[ith]]
14201
+ // ggml_vec_sum_f32(nc, sums + ith, st); // sums[ith] = st5 grad[sums[ith]] = grad[cross_entropy_loss] = -grad[cel]
14202
+
14203
+ // substitute into grad[st1], because we can reuse softmax_back from this point on
14204
+ // grad[st1] = -grad[cel]*s1*(1.0 - eps)/(eps + softmax(s0)*(1.0 - eps))
14205
+ // postorder:
14206
+ // grad[st1] := softmax(s0)
14207
+ // grad[st1] := grad[st1]*(1.0 - eps)
14208
+ // grad[st1] := grad[st1] + eps
14209
+ // grad[st1] := s1 / grad[st1]
14210
+ // grad[st1] := grad[st1]*(1.0-eps)*-grad[cel]
14211
+
14212
+ // src0 gradients by going through softmax_back
14213
+ // grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
14214
+ // from softmax_back:
14215
+ // dxk = yk * (dyk - dot(y, dy))
14216
+ // dot_y_dy := dot(y, dy)
14217
+ // dx := dy
14218
+ // dx := dx - dot_y_dy
14219
+ // dx := dx * y
14220
+ // postorder:
14221
+ // dot_st1_dst1 := dot(st1, grad[st1])
14222
+ // grad[s0] := grad[st1]
14223
+ // grad[s0] := grad[s0] - dot_st1_dst1
14224
+ // grad[s0] := grad[s0] * st1
14225
+
14226
+ // prepend postorder from grad[st1] directly using grad[s0] as memory location, as we will grad[s0] := grad[st1]
14227
+ // sm := softmax(s0)
14228
+ // grad[s0] := sm*(1.0 - eps)
14229
+ // grad[s0] := grad[s0] + eps
14230
+ // grad[s0] := s1 / grad[s0]
14231
+ // grad[s0] := grad[s0]*(1.0-eps)*-grad[cel]
14232
+ // dot_st1_dst1 := dot(sm, grad[s0])
14233
+ // grad[s0] := grad[s0] - dot_st1_dst1
14234
+ // grad[s0] := grad[s0] * sm
14235
+ }
14236
+
14237
+ // soft_max
14238
+ ggml_float sum = 0.0;
14239
+ {
14240
+ float max = -INFINITY;
14241
+ ggml_vec_max_f32(nc, &max, s0);
14242
+
14243
+ uint16_t scvt;
14244
+ for (int i = 0; i < nc; i++) {
14245
+ if (s0[i] == -INFINITY) {
14246
+ sm[i] = 0.0f;
14247
+ } else {
14248
+ // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
14249
+ ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
14250
+ memcpy(&scvt, &s, sizeof(scvt));
14251
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
14252
+ sum += (ggml_float)val;
14253
+ sm[i] = val;
14254
+ }
14255
+ }
14256
+
14257
+ assert(sum > 0.0);
14258
+ sum = 1.0/sum;
14259
+ }
14260
+
14261
+ float dot_st1_dst1 = 0;
14262
+ ggml_vec_scale_f32(nc, sm, sum);
14263
+ ggml_vec_cpy_f32 (nc, ds0, sm);
14264
+ ggml_vec_scale_f32(nc, ds0, (1.0f - eps));
14265
+ ggml_vec_add1_f32 (nc, ds0, ds0, eps);
14266
+ ggml_vec_div_f32 (nc, ds0, s1, ds0);
14267
+ ggml_vec_scale_f32(nc, ds0, -(1.0f - eps)*d[0]);
14268
+ ggml_vec_dot_f32 (nc, &dot_st1_dst1, sm, ds0);
14269
+ ggml_vec_acc1_f32 (nc, ds0, -dot_st1_dst1);
14270
+ ggml_vec_mul_f32 (nc, ds0, ds0, sm);
14271
+
14272
+ #ifndef NDEBUG
14273
+ for (int i = 0; i < nc; ++i) {
14274
+ assert(!isnan(sm[i]));
14275
+ assert(!isinf(sm[i]));
14276
+ assert(!isnan(ds0[i]));
14277
+ assert(!isinf(ds0[i]));
14278
+ }
14279
+ #endif
13012
14280
  }
13013
14281
  }
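
In the eps -> 0 limit, and when each target row s1 is a normalized distribution, the per-row gradient produced by ggml_compute_forward_cross_entropy_loss_back_f32 above reduces to d * (softmax(s0) - s1), where d is the incoming scalar gradient read from opt0; the eps terms only guard the log() in the forward pass. A minimal standalone sketch of that closed form, in plain C with an illustrative function name and without the eps guard:

    #include <math.h>

    // ds0[k] = d_loss * (softmax(s0)[k] - s1[k]) for one row of nc logits
    static void cross_entropy_row_grad(int nc, const float * s0, const float * s1,
                                       float d_loss, float * ds0) {
        float max = s0[0];                         // max-subtraction for stability
        for (int i = 1; i < nc; i++) if (s0[i] > max) max = s0[i];
        double sum = 0.0;
        for (int i = 0; i < nc; i++) { ds0[i] = expf(s0[i] - max); sum += ds0[i]; }
        for (int i = 0; i < nc; i++) {
            ds0[i] = d_loss * ((float)(ds0[i]/sum) - s1[i]);
        }
    }
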
13014
14282
 
13015
-
13016
- static void ggml_compute_forward_map_binary(
14283
+ static void ggml_compute_forward_cross_entropy_loss_back(
13017
14284
  const struct ggml_compute_params * params,
13018
14285
  const struct ggml_tensor * src0,
13019
14286
  const struct ggml_tensor * src1,
13020
- struct ggml_tensor * dst,
13021
- const ggml_binary_op_f32_t fun) {
14287
+ const struct ggml_tensor * opt0,
14288
+ struct ggml_tensor * dst) {
13022
14289
  switch (src0->type) {
13023
14290
  case GGML_TYPE_F32:
13024
14291
  {
13025
- ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
14292
+ ggml_compute_forward_cross_entropy_loss_back_f32(params, src0, src1, opt0, dst);
13026
14293
  } break;
13027
14294
  default:
13028
14295
  {
@@ -13031,6 +14298,7 @@ static void ggml_compute_forward_map_binary(
13031
14298
  }
13032
14299
  }
13033
14300
 
14301
+
13034
14302
  /////////////////////////////////
13035
14303
 
13036
14304
  static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -13102,6 +14370,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13102
14370
  {
13103
14371
  ggml_compute_forward_repeat(params, tensor->src0, tensor);
13104
14372
  } break;
14373
+ case GGML_OP_REPEAT_BACK:
14374
+ {
14375
+ ggml_compute_forward_repeat_back(params, tensor->src0, tensor);
14376
+ } break;
13105
14377
  case GGML_OP_ABS:
13106
14378
  {
13107
14379
  ggml_compute_forward_abs(params, tensor->src0, tensor);
@@ -13150,6 +14422,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13150
14422
  {
13151
14423
  ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
13152
14424
  } break;
14425
+ case GGML_OP_OUT_PROD:
14426
+ {
14427
+ ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor);
14428
+ } break;
13153
14429
  case GGML_OP_SCALE:
13154
14430
  {
13155
14431
  ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
@@ -13206,6 +14482,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13206
14482
  {
13207
14483
  ggml_compute_forward_soft_max(params, tensor->src0, tensor);
13208
14484
  } break;
14485
+ case GGML_OP_SOFT_MAX_BACK:
14486
+ {
14487
+ ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
14488
+ } break;
13209
14489
  case GGML_OP_ROPE:
13210
14490
  {
13211
14491
  ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
@@ -13241,6 +14521,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13241
14521
  {
13242
14522
  ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
13243
14523
  } break;
14524
+ case GGML_OP_FLASH_ATTN_BACK:
14525
+ {
14526
+ int32_t t = ggml_get_i32_1d(tensor->opt[2], 0);
14527
+ GGML_ASSERT(t == 0 || t == 1);
14528
+ bool masked = t != 0;
14529
+ ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor);
14530
+ } break;
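
The masked flag reaches this dispatcher as a 1-element I32 tensor in tensor->opt[2], decoded with ggml_get_i32_1d, which is how scalar op options generally travel through the graph in this file. A hedged sketch of the round trip; the two wrapper names are illustrative, while ggml_new_i32 and ggml_get_i32_1d are existing ggml helpers:

    #include <stdbool.h>
    #include "ggml.h"

    // pack a boolean option as a 1-element I32 tensor when building the op ...
    static struct ggml_tensor * pack_bool_opt(struct ggml_context * ctx, bool flag) {
        return ggml_new_i32(ctx, flag ? 1 : 0);
    }

    // ... and decode it again at compute time, as the case above does
    static bool unpack_bool_opt(struct ggml_tensor * t) {
        return ggml_get_i32_1d(t, 0) != 0;
    }
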
13244
14531
  case GGML_OP_MAP_UNARY:
13245
14532
  {
13246
14533
  const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
@@ -13253,6 +14540,16 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13253
14540
  ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
13254
14541
  }
13255
14542
  break;
14543
+ case GGML_OP_CROSS_ENTROPY_LOSS:
14544
+ {
14545
+ ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor);
14546
+ }
14547
+ break;
14548
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
14549
+ {
14550
+ ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
14551
+ }
14552
+ break;
13256
14553
  case GGML_OP_NONE:
13257
14554
  {
13258
14555
  // nop
@@ -13391,11 +14688,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13391
14688
  src0->grad =
13392
14689
  ggml_add_impl(ctx,
13393
14690
  src0->grad,
13394
- ggml_mul(ctx,
13395
- tensor->grad, // this was not catched by test_grad because in test_grad tensor->grad is 1
14691
+ ggml_scale(ctx,
13396
14692
  ggml_div(ctx,
13397
- ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor),
13398
- tensor)),
14693
+ tensor->grad,
14694
+ tensor),
14695
+ ggml_new_f32(ctx, 0.5f)),
13399
14696
  inplace);
13400
14697
  }
13401
14698
  } break;
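
The rewritten SQRT backward above encodes d/dx sqrt(x) = 1/(2*sqrt(x)) as grad_x += 0.5 * grad_y / y, reusing the already-computed output y; the removed ggml_mul form dropped tensor->grad, which the gradient tests did not catch because there the incoming gradient is all ones. A hedged elementwise sketch with an illustrative function name:

    // dL/dx = 0.5 * dL/dy / y for y = sqrt(x); accumulates like ggml_add_impl
    static void sqrt_backward(int n, const float * y, const float * dy, float * dx_acc) {
        for (int i = 0; i < n; i++) {
            dx_acc[i] += 0.5f * dy[i] / y[i];
        }
    }
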
@@ -13441,43 +14738,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13441
14738
  {
13442
14739
  // necessary for llama
13443
14740
  if (src0->grad) {
13444
- GGML_ASSERT(src0->n_dims == 1 || src0->n_dims == 2);
13445
- const int nc = tensor->ne[0];
13446
- const int nr = tensor->ne[1];
13447
- const int nc0 = src0->ne[0];
13448
- const int nr0 = src0->ne[1];
13449
- const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
13450
- const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
13451
- // tensor->grad [nc,nr,1,1]
13452
- // reshape [nc0,nc/nc0,nr0,nr/nr0]
13453
- // permute [nc0,nr0,nc/nc0,nr/nr0]
13454
- // substitute [nc0,nr0,ncr,nrr]
13455
- // reshape [nc0*nr0,ncr*nrr,1,1]
13456
- // transpose [ncr*nrr,nc0*nr0,1,1]
13457
- // sum rows [1,nc0*nr0,1,1]
13458
- // transpose [nc0*nr0,1,1]
13459
- // reshape [nc0,nr0,1,1] reshape_1d or reshape_2d
13460
- // add to src0->grad
13461
-
13462
- int64_t ne[4] = {nc0,ncr,nr0,nrr};
13463
-
13464
- struct ggml_tensor* F00 = tensor->grad;
13465
- struct ggml_tensor* F01 = ggml_reshape (ctx, F00, ggml_new_tensor(ctx,tensor->grad->type,4,ne));
13466
- struct ggml_tensor* F02 = ggml_permute (ctx, F01, 0,2,1,3);
13467
- struct ggml_tensor* F03 = ggml_cont (ctx, F02);
13468
- struct ggml_tensor* F04 = ggml_reshape_2d(ctx, F03, nc0*nr0, ncr*nrr);
13469
- struct ggml_tensor* F05 = ggml_transpose (ctx, F04);
13470
- struct ggml_tensor* F06 = ggml_cont (ctx, F05);
13471
- struct ggml_tensor* F07 = ggml_sum_rows (ctx, F06);
13472
- struct ggml_tensor* F08 = ggml_transpose (ctx, F07);
13473
- struct ggml_tensor* F09 = ggml_cont (ctx, F08);
13474
- struct ggml_tensor* F10 = ggml_reshape (ctx, F09, src0->grad);
13475
-
13476
- src0->grad =
13477
- ggml_add_impl(ctx,
13478
- src0->grad,
13479
- F10,
13480
- inplace);
14741
+ src0->grad = ggml_add_impl(ctx,
14742
+ src0->grad,
14743
+ ggml_repeat_back(ctx, tensor->grad, src0->grad),
14744
+ inplace);
14745
+ }
14746
+ } break;
14747
+ case GGML_OP_REPEAT_BACK:
14748
+ {
14749
+ if (src0->grad) {
14750
+ // TODO: test this
14751
+ src0->grad = ggml_add_impl(ctx,
14752
+ src0->grad,
14753
+ ggml_repeat(ctx, tensor->grad, src0->grad),
14754
+ inplace);
13481
14755
  }
13482
14756
  } break;
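
ggml_repeat_back replaces the old reshape/permute/sum_rows chain: it folds the gradient of the repeated tensor back onto src0's shape by summing over all repetitions, and conversely the backward of the new REPEAT_BACK op is a plain ggml_repeat. A hedged 1-D sketch of that reduction, with illustrative names:

    // sum nrep copies of an n0-wide gradient back onto the original n0 elements
    static void repeat_back_1d(int n0, int nrep, const float * dy, float * dx_acc) {
        for (int r = 0; r < nrep; r++) {
            for (int i = 0; i < n0; i++) {
                dx_acc[i] += dy[r*n0 + i];
            }
        }
    }
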
13483
14757
  case GGML_OP_ABS:
@@ -13584,38 +14858,37 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13584
14858
 
13585
14859
  // necessary for llama
13586
14860
  if (src0->grad) {
13587
- // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
13588
14861
  src0->grad =
13589
14862
  ggml_add_impl(ctx,
13590
14863
  src0->grad,
13591
- // ds0 = dt.dot(s1.T)
13592
- // ggml_out_prod(ctx, // [n,m]
13593
- // src1, // [n,p]
13594
- // tensor->grad), // [m,p]
13595
- // for now just using A*B==(B.T*A.T).T
13596
- ggml_cont(ctx, // [n,m]
13597
- ggml_transpose(ctx, // [n,m]
13598
- ggml_mul_mat(ctx, // [m,n]
13599
- ggml_cont(ctx, // [p,m]
13600
- ggml_transpose(ctx, // [p,m]
13601
- tensor->grad)), // [m,p]
13602
- ggml_cont(ctx, // [p,n]
13603
- ggml_transpose(ctx, // [p,n]
13604
- src1))))), // [n,p]
14864
+ ggml_out_prod(ctx, // [n,m]
14865
+ src1, // [n,p]
14866
+ tensor->grad), // [m,p]
13605
14867
  inplace);
13606
14868
  }
13607
14869
  if (src1->grad) {
13608
14870
  src1->grad =
13609
14871
  ggml_add_impl(ctx,
13610
14872
  src1->grad,
13611
- // ds1 = s0.T.dot(dt):
13612
- ggml_mul_mat(ctx, // [n,p]
13613
- ggml_cont(ctx, // [m,n]
13614
- ggml_transpose(ctx, src0)), // [m,n]
13615
- tensor->grad), // [m,p]
14873
+ // ggml_mul_mat(ctx, // [n,p]
14874
+ // ggml_cont(ctx, // [m,n]
14875
+ // ggml_transpose(ctx, src0)), // [m,n]
14876
+ // tensor->grad), // [m,p]
14877
+
14878
+ // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
14879
+ // // avoid transpose of src0, rather transpose smaller tensor->grad
14880
+ // // and then use ggml_out_prod
14881
+ ggml_out_prod(ctx, // [n,p]
14882
+ src0, // [n,m]
14883
+ ggml_transpose(ctx, // [p,m]
14884
+ tensor->grad)), // [m,p]
13616
14885
  inplace);
13617
14886
  }
13618
14887
  } break;
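
In plain matrix terms the MUL_MAT backward is dA = dC*B^T and dB = A^T*dC for C = A*B; the two ggml_out_prod calls above express exactly that without the ggml_cont/ggml_transpose copies of the old formulation, transposing only the smaller tensor->grad as the kept comment explains. A hedged row-major sketch of the same identities (standard math layout, not ggml's ne[]/nb[] convention; the function name is illustrative):

    // A:[m,k], B:[k,n], dC:[m,n]; accumulates dA:[m,k] and dB:[k,n]
    static void matmul_backward(int m, int k, int n,
                                const float * A, const float * B, const float * dC,
                                float * dA, float * dB) {
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < k; j++) {
                float acc = 0.0f;
                for (int t = 0; t < n; t++) acc += dC[i*n + t] * B[j*n + t]; // dC * B^T
                dA[i*k + j] += acc;
            }
        }
        for (int i = 0; i < k; i++) {
            for (int j = 0; j < n; j++) {
                float acc = 0.0f;
                for (int t = 0; t < m; t++) acc += A[t*k + i] * dC[t*n + j]; // A^T * dC
                dB[i*n + j] += acc;
            }
        }
    }
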
14888
+ case GGML_OP_OUT_PROD:
14889
+ {
14890
+ GGML_ASSERT(false); // TODO: not implemented
14891
+ } break;
13619
14892
  case GGML_OP_SCALE:
13620
14893
  {
13621
14894
  // necessary for llama
@@ -13717,7 +14990,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13717
14990
  // necessary for llama
13718
14991
  if (src0->grad) {
13719
14992
  size_t offset;
13720
- memcpy(&offset, tensor->padding, sizeof(offset));
14993
+
14994
+ GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->opt[0]));
14995
+ memcpy(&offset, tensor->opt[0]->data, sizeof(offset));
13721
14996
 
13722
14997
  size_t nb1 = tensor->nb[1];
13723
14998
  size_t nb2 = tensor->nb[2];
@@ -13744,10 +15019,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13744
15019
  {
13745
15020
  // necessary for llama
13746
15021
  if (src0->grad) {
13747
- int axis0 = tensor->padding[0] & 0x3;
13748
- int axis1 = tensor->padding[1] & 0x3;
13749
- int axis2 = tensor->padding[2] & 0x3;
13750
- int axis3 = tensor->padding[3] & 0x3;
15022
+ int32_t * axes = (int32_t *) tensor->opt[0]->data;
15023
+ int axis0 = axes[0] & 0x3;
15024
+ int axis1 = axes[1] & 0x3;
15025
+ int axis2 = axes[2] & 0x3;
15026
+ int axis3 = axes[3] & 0x3;
13751
15027
  int axes_backward[4] = {0,0,0,0};
13752
15028
  axes_backward[axis0] = 0;
13753
15029
  axes_backward[axis1] = 1;
@@ -13831,50 +15107,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13831
15107
  {
13832
15108
  // necessary for llama
13833
15109
  if (src0->grad) {
13834
- // y = softmax(x)
13835
- //
13836
- // Jii = yi - yi*yi
13837
- // Jij = -yi*yj
13838
- // J = diag(y)-y.*y
13839
- // dx = J * dy
13840
- // dxk = sum(Jkj * dyk)
13841
-
13842
- int64_t ne2[4] = {
13843
- tensor->ne[0],
13844
- 1,
13845
- tensor->ne[1]*tensor->ne[2],
13846
- tensor->ne[3]
13847
- };
13848
- struct ggml_tensor * tensor2 = ggml_cont(ctx,
13849
- ggml_reshape_4d(ctx,
13850
- ggml_cont(ctx, tensor),
13851
- ne2[0], ne2[1], ne2[2], ne2[3]));
13852
-
13853
- struct ggml_tensor * grad2 = ggml_cont(ctx,
13854
- ggml_reshape_4d(ctx,
13855
- ggml_cont(ctx, tensor->grad),
13856
- ne2[0], ne2[1], ne2[2], ne2[3]));
13857
-
13858
- struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
13859
- ggml_permute(ctx, // [1,ne0,ne1*ne2,ne3]
13860
- tensor2, // [ne0,1,ne1*ne2,ne3]
13861
- 1, 0, 2, 3));
13862
-
13863
15110
  src0->grad =
13864
- ggml_add_impl(ctx,
13865
- src0->grad, // [ne0,ne1,ne2,ne3]
13866
- ggml_reshape(ctx, // [ne0,ne1,ne2,ne3]
13867
- ggml_mul_mat(ctx, // [ne0,1,ne1*ne2,ne3]
13868
- ggml_sub(ctx, // [ne0,ne0,ne1*ne2,ne3]
13869
- ggml_diag(ctx, // [ne0,ne0,ne1*ne2,ne3]
13870
- tensor2), // [ne0,1,ne1*ne2,ne3]
13871
- ggml_mul_mat(ctx, // [ne0,ne0,ne1*ne2,ne3]
13872
- tensor2_t, // [1,ne0,ne1*ne2,ne3]
13873
- tensor2_t)), // [1,ne0,ne1*ne2,ne3]
13874
- grad2), // [ne0,1,ne1*ne2,ne3]
13875
- src0->grad),
13876
- inplace);
15111
+ ggml_add_impl(ctx, src0->grad,
15112
+ ggml_soft_max_back(ctx, tensor->grad, tensor),
15113
+ inplace);
13877
15114
  }
15115
+
15116
+ } break;
15117
+ case GGML_OP_SOFT_MAX_BACK:
15118
+ {
15119
+ GGML_ASSERT(false); // TODO: not implemented
13878
15120
  } break;
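
The explicit diag(y) - y*y^T Jacobian construction is gone; ggml_soft_max_back stands for the per-row Jacobian-vector product dx_k = y_k * (dy_k - dot(y, dy)) with y = softmax(x), the same identity the cross-entropy backward comments reuse. A hedged per-row sketch with an illustrative function name:

    // dx_k = y_k * (dy_k - dot(y, dy)) for one row, where y = softmax(x)
    static void soft_max_backward_row(int n, const float * y, const float * dy, float * dx) {
        float dot = 0.0f;
        for (int i = 0; i < n; i++) dot += y[i]*dy[i];
        for (int i = 0; i < n; i++) dx[i] = y[i]*(dy[i] - dot);
    }
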
13879
15121
  case GGML_OP_ROPE:
13880
15122
  {
@@ -13929,17 +15171,190 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13929
15171
  } break;
13930
15172
  case GGML_OP_FLASH_ATTN:
13931
15173
  {
13932
- GGML_ASSERT(false); // not supported
15174
+ struct ggml_tensor * flash_grad = NULL;
15175
+ if (src0->grad || src1->grad || tensor->opt[0]->grad) {
15176
+ int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
15177
+ GGML_ASSERT(t == 0 || t == 1);
15178
+ bool masked = t != 0;
15179
+ flash_grad =
15180
+ ggml_flash_attn_back(ctx,
15181
+ src0,
15182
+ src1,
15183
+ tensor->opt[0],
15184
+ tensor->grad,
15185
+ masked);
15186
+ }
15187
+
15188
+ if (src0->grad) {
15189
+ struct ggml_tensor * grad_q = NULL;
15190
+ const size_t nb0 = flash_grad->nb[0];
15191
+ const size_t offset = 0;
15192
+ switch(src0->n_dims) {
15193
+ case 2:
15194
+ {
15195
+ grad_q = ggml_view_2d(ctx,
15196
+ flash_grad,
15197
+ src0->ne[0],
15198
+ src0->ne[1],
15199
+ nb0*src0->ne[0],
15200
+ offset);
15201
+ } break;
15202
+ case 3:
15203
+ {
15204
+ grad_q = ggml_view_3d(ctx,
15205
+ flash_grad,
15206
+ src0->ne[0],
15207
+ src0->ne[1],
15208
+ src0->ne[2],
15209
+ nb0*src0->ne[0],
15210
+ nb0*src0->ne[0]*src0->ne[1],
15211
+ offset);
15212
+ } break;
15213
+ case 4:
15214
+ {
15215
+ grad_q = ggml_view_4d(ctx,
15216
+ flash_grad,
15217
+ src0->ne[0],
15218
+ src0->ne[1],
15219
+ src0->ne[2],
15220
+ src0->ne[3],
15221
+ nb0*src0->ne[0],
15222
+ nb0*src0->ne[0]*src0->ne[1],
15223
+ nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
15224
+ offset);
15225
+ } break;
15226
+ }
15227
+
15228
+ src0->grad = ggml_add_impl(ctx,
15229
+ src0->grad,
15230
+ grad_q,
15231
+ inplace);
15232
+ }
15233
+
15234
+ if (src1->grad) {
15235
+ struct ggml_tensor * grad_k = NULL;
15236
+ const size_t nb0 = flash_grad->nb[0];
15237
+ const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
15238
+ switch(src1->n_dims) {
15239
+ case 2:
15240
+ {
15241
+ grad_k = ggml_view_2d(ctx,
15242
+ flash_grad,
15243
+ src1->ne[0],
15244
+ src1->ne[1],
15245
+ nb0*src1->ne[0],
15246
+ offset);
15247
+ } break;
15248
+ case 3:
15249
+ {
15250
+ grad_k = ggml_view_3d(ctx,
15251
+ flash_grad,
15252
+ src1->ne[0],
15253
+ src1->ne[1],
15254
+ src1->ne[2],
15255
+ nb0*src1->ne[0],
15256
+ nb0*src1->ne[0]*src1->ne[1],
15257
+ offset);
15258
+ } break;
15259
+ case 4:
15260
+ {
15261
+ grad_k = ggml_view_4d(ctx,
15262
+ flash_grad,
15263
+ src1->ne[0],
15264
+ src1->ne[1],
15265
+ src1->ne[2],
15266
+ src1->ne[3],
15267
+ nb0*src1->ne[0],
15268
+ nb0*src1->ne[0]*src1->ne[1],
15269
+ nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
15270
+ offset);
15271
+ } break;
15272
+ }
15273
+
15274
+ src1->grad = ggml_add_impl(ctx,
15275
+ src1->grad,
15276
+ grad_k,
15277
+ inplace);
15278
+ }
15279
+
15280
+ struct ggml_tensor * opt0 = tensor->opt[0];
15281
+
15282
+ if (opt0->grad) {
15283
+ struct ggml_tensor * grad_v = NULL;
15284
+ const size_t nb0 = flash_grad->nb[0];
15285
+ const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
15286
+ + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
15287
+ switch(opt0->n_dims) {
15288
+ case 2:
15289
+ {
15290
+ grad_v = ggml_view_2d(ctx,
15291
+ flash_grad,
15292
+ opt0->ne[0],
15293
+ opt0->ne[1],
15294
+ nb0*opt0->ne[0],
15295
+ offset);
15296
+ } break;
15297
+ case 3:
15298
+ {
15299
+ grad_v = ggml_view_3d(ctx,
15300
+ flash_grad,
15301
+ opt0->ne[0],
15302
+ opt0->ne[1],
15303
+ opt0->ne[2],
15304
+ nb0*opt0->ne[0],
15305
+ nb0*opt0->ne[0]*opt0->ne[1],
15306
+ offset);
15307
+ } break;
15308
+ case 4:
15309
+ {
15310
+ grad_v = ggml_view_4d(ctx,
15311
+ flash_grad,
15312
+ opt0->ne[0],
15313
+ opt0->ne[1],
15314
+ opt0->ne[2],
15315
+ opt0->ne[3],
15316
+ nb0*opt0->ne[0],
15317
+ nb0*opt0->ne[0]*opt0->ne[1],
15318
+ nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
15319
+ offset);
15320
+ } break;
15321
+ }
15322
+
15323
+ opt0->grad = ggml_add_impl(ctx,
15324
+ opt0->grad,
15325
+ grad_v,
15326
+ inplace);
15327
+ }
13933
15328
  } break;
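
The ggml_view_2d/3d/4d calls above assume ggml_flash_attn_back returns one contiguous buffer holding grad_q, grad_k and grad_v back to back, so each slice is addressed by a cumulative byte offset (nb0 is the element size). A hedged sketch of that offset arithmetic, with illustrative helper names:

    #include <stddef.h>
    #include <stdint.h>

    // element count of a tensor with up to 4 dimensions
    static int64_t nelements4(const int64_t ne[4]) {
        return ne[0]*ne[1]*ne[2]*ne[3];
    }

    // grad_q starts at byte 0, grad_k after all of q, grad_v after q and k
    static size_t flash_grad_offset_k(size_t nb0, const int64_t q_ne[4]) {
        return nb0 * (size_t) nelements4(q_ne);
    }
    static size_t flash_grad_offset_v(size_t nb0, const int64_t q_ne[4], const int64_t k_ne[4]) {
        return flash_grad_offset_k(nb0, q_ne) + nb0 * (size_t) nelements4(k_ne);
    }
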
13934
15329
  case GGML_OP_FLASH_FF:
13935
15330
  {
13936
15331
  GGML_ASSERT(false); // not supported
13937
15332
  } break;
15333
+ case GGML_OP_FLASH_ATTN_BACK:
15334
+ {
15335
+ GGML_ASSERT(false); // not supported
15336
+ } break;
13938
15337
  case GGML_OP_MAP_UNARY:
13939
15338
  case GGML_OP_MAP_BINARY:
13940
15339
  {
13941
15340
  GGML_ASSERT(false); // not supported
13942
15341
  } break;
15342
+ case GGML_OP_CROSS_ENTROPY_LOSS:
15343
+ {
15344
+ if (src0->grad) {
15345
+ src0->grad = ggml_add_impl(ctx,
15346
+ src0->grad,
15347
+ ggml_cross_entropy_loss_back(ctx,
15348
+ src0,
15349
+ src1,
15350
+ tensor->grad),
15351
+ inplace);
15352
+ }
15353
+ } break;
15354
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
15355
+ {
15356
+ GGML_ASSERT(false); // not supported
15357
+ } break;
13943
15358
  case GGML_OP_NONE:
13944
15359
  {
13945
15360
  // nop
@@ -14316,6 +15731,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14316
15731
  case GGML_OP_SUM_ROWS:
14317
15732
  case GGML_OP_MEAN:
14318
15733
  case GGML_OP_REPEAT:
15734
+ case GGML_OP_REPEAT_BACK:
14319
15735
  case GGML_OP_ABS:
14320
15736
  case GGML_OP_SGN:
14321
15737
  case GGML_OP_NEG:
@@ -14335,6 +15751,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14335
15751
  node->n_tasks = n_threads;
14336
15752
  } break;
14337
15753
  case GGML_OP_MUL_MAT:
15754
+ case GGML_OP_OUT_PROD:
14338
15755
  {
14339
15756
  node->n_tasks = n_threads;
14340
15757
 
@@ -14417,6 +15834,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14417
15834
  } break;
14418
15835
  case GGML_OP_DIAG_MASK_INF:
14419
15836
  case GGML_OP_SOFT_MAX:
15837
+ case GGML_OP_SOFT_MAX_BACK:
14420
15838
  case GGML_OP_ROPE:
14421
15839
  case GGML_OP_ROPE_BACK:
14422
15840
  {
@@ -14496,6 +15914,27 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14496
15914
  cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
14497
15915
  }
14498
15916
 
15917
+ work_size = MAX(work_size, cur);
15918
+ } break;
15919
+ case GGML_OP_FLASH_ATTN_BACK:
15920
+ {
15921
+ node->n_tasks = n_threads;
15922
+
15923
+ size_t cur = 0;
15924
+
15925
+ const int64_t D = node->src0->ne[0];
15926
+ const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
15927
+ const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
15928
+ if (node->src1->type == GGML_TYPE_F32) {
15929
+ cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
15930
+ cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
15931
+ }
15932
+
15933
+ if (node->src1->type == GGML_TYPE_F16) {
15934
+ cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
15935
+ cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
15936
+ }
15937
+
14499
15938
  work_size = MAX(work_size, cur);
14500
15939
  } break;
14501
15940
  case GGML_OP_MAP_UNARY:
@@ -14503,6 +15942,22 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14503
15942
  {
14504
15943
  node->n_tasks = 1;
14505
15944
  } break;
15945
+ case GGML_OP_CROSS_ENTROPY_LOSS:
15946
+ {
15947
+ node->n_tasks = n_threads;
15948
+
15949
+ size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
15950
+
15951
+ work_size = MAX(work_size, cur);
15952
+ } break;
15953
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
15954
+ {
15955
+ node->n_tasks = n_threads;
15956
+
15957
+ size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
15958
+
15959
+ work_size = MAX(work_size, cur);
15960
+ } break;
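
The scheduler sizes a single shared work buffer for the whole graph, so each new op only reports its worst case: FLASH_ATTN_BACK reserves 2*MAX(D, ne11) floats per task twice over (the S and SM scratch rows), CROSS_ENTROPY_LOSS one partial sum per task plus a softmax row of ne[0] floats per task, and CROSS_ENTROPY_LOSS_BACK just the softmax row. A hedged restatement of the two cross-entropy sizes, with illustrative helper names and the node's type size (F32 in practice) passed in as float_size:

    #include <stddef.h>
    #include <stdint.h>

    // forward: per-task partial sum + per-task softmax row of ne0 floats
    static size_t cross_entropy_loss_work_size(size_t float_size, int64_t ne0, int n_tasks) {
        return float_size * ((size_t) n_tasks + (size_t) ne0 * (size_t) n_tasks);
    }

    // backward: only the per-task softmax row
    static size_t cross_entropy_loss_back_work_size(size_t float_size, int64_t ne0, int n_tasks) {
        return float_size * (size_t) ne0 * (size_t) n_tasks;
    }
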
14506
15961
  case GGML_OP_NONE:
14507
15962
  {
14508
15963
  node->n_tasks = 1;
@@ -15478,6 +16933,7 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
15478
16933
 
15479
16934
  static enum ggml_opt_result ggml_opt_adam(
15480
16935
  struct ggml_context * ctx,
16936
+ struct ggml_opt_context * opt,
15481
16937
  struct ggml_opt_params params,
15482
16938
  struct ggml_tensor * f,
15483
16939
  struct ggml_cgraph * gf,
@@ -15503,25 +16959,29 @@ static enum ggml_opt_result ggml_opt_adam(
15503
16959
  }
15504
16960
  }
15505
16961
 
16962
+ if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) {
16963
+ int iter = opt->iter;
16964
+ ggml_opt_init(opt->ctx, opt, params, nx);
16965
+ opt->iter = iter;
16966
+ }
16967
+
15506
16968
  // constants
15507
- const float alpha = params.adam.alpha;
16969
+ const float sched = params.adam.sched;
16970
+ const float decay = params.adam.decay * sched;
16971
+ const float alpha = params.adam.alpha * sched;
15508
16972
  const float beta1 = params.adam.beta1;
15509
16973
  const float beta2 = params.adam.beta2;
15510
16974
  const float eps = params.adam.eps;
15511
16975
 
15512
- float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // view of the parameters
15513
- float * g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient
15514
- float * g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient squared
15515
- float * m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment
15516
- float * v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment
15517
- float * mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment hat
15518
- float * vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment hat
16976
+ float * x = opt->adam.x->data; // view of the parameters
16977
+ float * g1 = opt->adam.g1->data; // gradient
16978
+ float * g2 = opt->adam.g2->data; // gradient squared
16979
+ float * m = opt->adam.m->data; // first moment
16980
+ float * v = opt->adam.v->data; // second moment
16981
+ float * mh = opt->adam.mh->data; // first moment hat
16982
+ float * vh = opt->adam.vh->data; // second moment hat
15519
16983
 
15520
- float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
15521
-
15522
- // initialize
15523
- ggml_vec_set_f32(nx, m, 0.0f);
15524
- ggml_vec_set_f32(nx, v, 0.0f);
16984
+ float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
15525
16985
 
15526
16986
  // update view
15527
16987
  ggml_opt_get_params(np, ps, x);
@@ -15531,16 +16991,27 @@ static enum ggml_opt_result ggml_opt_adam(
15531
16991
  ggml_set_f32 (f->grad, 1.0f);
15532
16992
  ggml_graph_compute(ctx, gb);
15533
16993
 
15534
- float fx_prev = ggml_get_f32_1d(f, 0);
16994
+ opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
16995
+ opt->adam.fx_best = opt->adam.fx_prev;
15535
16996
  if (pf) {
15536
- pf[0] = fx_prev;
16997
+ pf[opt->iter % params.past] = opt->adam.fx_prev;
15537
16998
  }
15538
16999
 
15539
- int n_no_improvement = 0;
15540
- float fx_best = fx_prev;
17000
+ // initialize
17001
+ if (opt->just_initialized) {
17002
+ opt->adam.n_no_improvement = 0;
17003
+ opt->just_initialized = false;
17004
+ }
17005
+
17006
+ float * fx_best = &opt->adam.fx_best;
17007
+ float * fx_prev = &opt->adam.fx_prev;
17008
+ int * n_no_improvement = &opt->adam.n_no_improvement;
17009
+
17010
+ int iter0 = opt->iter;
15541
17011
 
15542
17012
  // run the optimizer
15543
17013
  for (int t = 0; t < params.adam.n_iter; ++t) {
17014
+ opt->iter = iter0 + t + 1;
15544
17015
  GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
15545
17016
 
15546
17017
  GGML_PRINT_DEBUG ("f = %10.6f\n", ggml_get_f32_1d(f, 0));
@@ -15574,17 +17045,22 @@ static enum ggml_opt_result ggml_opt_adam(
15574
17045
 
15575
17046
  // m^hat = m_t / (1 - beta1^t)
15576
17047
  // v^hat = v_t / (1 - beta2^t)
15577
- // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps)
17048
+ // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1)
17049
+ // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1
17050
+ // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
17051
+ // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps)
17052
+ // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay)
15578
17053
  ggml_vec_cpy_f32 (nx, mh, m);
15579
17054
  ggml_vec_cpy_f32 (nx, vh, v);
15580
17055
 
15581
- ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, t + 1)));
15582
- ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, t + 1)));
17056
+ ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter)));
17057
+ ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, opt->iter)));
15583
17058
 
15584
17059
  ggml_vec_sqrt_f32 (nx, vh, vh);
15585
17060
  ggml_vec_acc1_f32 (nx, vh, eps);
15586
17061
 
15587
17062
  ggml_vec_div_f32 (nx, mh, mh, vh);
17063
+ ggml_vec_scale_f32(nx, x, 1.0f - decay);
15588
17064
  ggml_vec_sub_f32 (nx, x, x, mh);
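
The derivation in the comments boils down to decoupled, AdamW-style weight decay under a global schedule: sched scales both the learning rate and the decay, the parameters are first shrunk by (1 - sched*decay) via the new ggml_vec_scale_f32 call, and only then is the bias-corrected Adam step subtracted. A hedged per-element sketch, assuming alpha and decay already carry the sched factor as in the constants above and that m and v hold the current biased moments; the function name is illustrative:

    #include <math.h>

    // x <- x*(1 - decay) - alpha * m^hat / (sqrt(v^hat) + eps)
    static void adam_step(int n, float * x, const float * m, const float * v,
                          float alpha, float beta1, float beta2, float eps,
                          float decay, int iter) {
        const float b1t = 1.0f - powf(beta1, (float) iter);   // bias-correction terms
        const float b2t = 1.0f - powf(beta2, (float) iter);
        for (int i = 0; i < n; i++) {
            const float mh = m[i] / b1t;
            const float vh = v[i] / b2t;
            x[i] = x[i]*(1.0f - decay) - alpha * mh / (sqrtf(vh) + eps);
        }
    }
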
15589
17065
 
15590
17066
  // update the parameters
@@ -15598,7 +17074,7 @@ static enum ggml_opt_result ggml_opt_adam(
15598
17074
  const float fx = ggml_get_f32_1d(f, 0);
15599
17075
 
15600
17076
  // check convergence
15601
- if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) {
17077
+ if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
15602
17078
  GGML_PRINT_DEBUG("converged\n");
15603
17079
 
15604
17080
  return GGML_OPT_OK;
@@ -15607,32 +17083,32 @@ static enum ggml_opt_result ggml_opt_adam(
15607
17083
  // delta-based convergence test
15608
17084
  if (pf != NULL) {
15609
17085
  // need at least params.past iterations to start checking for convergence
15610
- if (params.past <= t) {
15611
- const float rate = (pf[t%params.past] - fx)/fx;
17086
+ if (params.past <= iter0 + t) {
17087
+ const float rate = (pf[(iter0 + t)%params.past] - fx)/fx;
15612
17088
 
15613
17089
  if (fabsf(rate) < params.delta) {
15614
17090
  return GGML_OPT_OK;
15615
17091
  }
15616
17092
  }
15617
17093
 
15618
- pf[t%params.past] = fx;
17094
+ pf[(iter0 + t)%params.past] = fx;
15619
17095
  }
15620
17096
 
15621
17097
  // check for improvement
15622
17098
  if (params.max_no_improvement > 0) {
15623
- if (fx_best > fx) {
15624
- fx_best = fx;
15625
- n_no_improvement = 0;
17099
+ if (fx_best[0] > fx) {
17100
+ fx_best[0] = fx;
17101
+ n_no_improvement[0] = 0;
15626
17102
  } else {
15627
- ++n_no_improvement;
17103
+ ++n_no_improvement[0];
15628
17104
 
15629
- if (n_no_improvement >= params.max_no_improvement) {
17105
+ if (n_no_improvement[0] >= params.max_no_improvement) {
15630
17106
  return GGML_OPT_OK;
15631
17107
  }
15632
17108
  }
15633
17109
  }
15634
17110
 
15635
- fx_prev = fx;
17111
+ fx_prev[0] = fx;
15636
17112
 
15637
17113
  {
15638
17114
  const int64_t t_end_cpu = ggml_cycles();
@@ -15771,6 +17247,7 @@ static enum ggml_opt_result linesearch_backtracking(
15771
17247
 
15772
17248
  static enum ggml_opt_result ggml_opt_lbfgs(
15773
17249
  struct ggml_context * ctx,
17250
+ struct ggml_opt_context * opt,
15774
17251
  struct ggml_opt_params params,
15775
17252
  struct ggml_tensor * f,
15776
17253
  struct ggml_cgraph * gf,
@@ -15803,31 +17280,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15803
17280
  }
15804
17281
  }
15805
17282
 
15806
- float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current parameters
15807
- float * xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous parameters
15808
- float * g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current gradient
15809
- float * gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous gradient
15810
- float * d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // search direction
17283
+ if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) {
17284
+ int iter = opt->iter;
17285
+ ggml_opt_init(ctx, opt, params, nx);
17286
+ opt->iter = iter;
17287
+ }
17288
+
17289
+ float * x = opt->lbfgs.x->data; // current parameters
17290
+ float * xp = opt->lbfgs.xp->data; // previous parameters
17291
+ float * g = opt->lbfgs.g->data; // current gradient
17292
+ float * gp = opt->lbfgs.gp->data; // previous gradient
17293
+ float * d = opt->lbfgs.d->data; // search direction
15811
17294
 
15812
- float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
17295
+ float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
15813
17296
 
15814
17297
  float fx = 0.0f; // cost function value
15815
17298
  float xnorm = 0.0f; // ||x||
15816
17299
  float gnorm = 0.0f; // ||g||
15817
- float step = 0.0f;
15818
17300
 
15819
17301
  // initialize x from the graph nodes
15820
17302
  ggml_opt_get_params(np, ps, x);
15821
17303
 
15822
17304
  // the L-BFGS memory
15823
- struct ggml_lbfgs_iteration_data * lm = alloca(sizeof(struct ggml_lbfgs_iteration_data)*m);
15824
-
15825
- for (int i = 0; i < m; ++i) {
15826
- lm[i].alpha = 0.0f;
15827
- lm[i].ys = 0.0f;
15828
- lm[i].s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
15829
- lm[i].y = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
15830
- }
17305
+ float * lm_alpha = opt->lbfgs.lmal->data;
17306
+ float * lm_ys = opt->lbfgs.lmys->data;
17307
+ float * lm_s = opt->lbfgs.lms->data;
17308
+ float * lm_y = opt->lbfgs.lmy->data;
15831
17309
 
15832
17310
  // evaluate the function value and its gradient
15833
17311
  {
@@ -15842,12 +17320,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15842
17320
  fx = ggml_get_f32_1d(f, 0);
15843
17321
  }
15844
17322
 
15845
- if (pf) {
15846
- pf[0] = fx;
15847
- }
15848
-
15849
- float fx_best = fx;
15850
-
15851
17323
  // search direction = -gradient
15852
17324
  ggml_vec_neg_f32(nx, d, g);
15853
17325
 
@@ -15864,26 +17336,43 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15864
17336
  return GGML_OPT_OK;
15865
17337
  }
15866
17338
 
15867
- // initial step
15868
- ggml_vec_norm_inv_f32(nx, &step, d);
17339
+ if (opt->just_initialized) {
17340
+ if (pf) {
17341
+ pf[0] = fx;
17342
+ }
17343
+ opt->lbfgs.fx_best = fx;
17344
+
17345
+ // initial step
17346
+ ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d);
17347
+ opt->lbfgs.j = 0;
17348
+ opt->lbfgs.k = 1;
17349
+ opt->lbfgs.end = 0;
17350
+ opt->lbfgs.n_no_improvement = 0;
17351
+ opt->just_initialized = false;
17352
+ }
17353
+
17354
+ float * fx_best = &opt->lbfgs.fx_best;
17355
+ float * step = &opt->lbfgs.step;
17356
+ int * j = &opt->lbfgs.j;
17357
+ int * k = &opt->lbfgs.k;
17358
+ int * end = &opt->lbfgs.end;
17359
+ int * n_no_improvement = &opt->lbfgs.n_no_improvement;
15869
17360
 
15870
- int j = 0;
15871
- int k = 1;
15872
- int ls = 0;
15873
- int end = 0;
15874
- int bound = 0;
15875
- int n_no_improvement = 0;
17361
+ int ls = 0;
17362
+ int bound = 0;
15876
17363
 
15877
17364
  float ys = 0.0f;
15878
17365
  float yy = 0.0f;
15879
17366
  float beta = 0.0f;
15880
17367
 
17368
+ int it = 0;
17369
+
15881
17370
  while (true) {
15882
17371
  // store the current position and gradient vectors
15883
17372
  ggml_vec_cpy_f32(nx, xp, x);
15884
17373
  ggml_vec_cpy_f32(nx, gp, g);
15885
17374
 
15886
- ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, &step, xp, f, gf, gb, np, ps);
17375
+ ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
15887
17376
 
15888
17377
  if (ls < 0) {
15889
17378
  // linesearch failed - go back to the previous point and return
@@ -15909,32 +17398,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15909
17398
  // delta-based convergence test
15910
17399
  if (pf != NULL) {
15911
17400
  // need at least params.past iterations to start checking for convergence
15912
- if (params.past <= k) {
15913
- const float rate = (pf[k%params.past] - fx)/fx;
17401
+ if (params.past <= k[0]) {
17402
+ const float rate = (pf[k[0]%params.past] - fx)/fx;
15914
17403
 
15915
17404
  if (fabsf(rate) < params.delta) {
15916
17405
  return GGML_OPT_OK;
15917
17406
  }
15918
17407
  }
15919
17408
 
15920
- pf[k%params.past] = fx;
17409
+ pf[k[0]%params.past] = fx;
15921
17410
  }
15922
17411
 
15923
17412
  // check for improvement
15924
17413
  if (params.max_no_improvement > 0) {
15925
- if (fx < fx_best) {
15926
- fx_best = fx;
15927
- n_no_improvement = 0;
17414
+ if (fx < fx_best[0]) {
17415
+ fx_best[0] = fx;
17416
+ n_no_improvement[0] = 0;
15928
17417
  } else {
15929
- n_no_improvement++;
17418
+ n_no_improvement[0]++;
15930
17419
 
15931
- if (n_no_improvement >= params.max_no_improvement) {
17420
+ if (n_no_improvement[0] >= params.max_no_improvement) {
15932
17421
  return GGML_OPT_OK;
15933
17422
  }
15934
17423
  }
15935
17424
  }
15936
17425
 
15937
- if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < k + 1) {
17426
+ if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) {
15938
17427
  // reached the maximum number of iterations
15939
17428
  return GGML_OPT_DID_NOT_CONVERGE;
15940
17429
  }
@@ -15943,50 +17432,51 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15943
17432
  // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}.
15944
17433
  // y_{k+1} = g_{k+1} - g_{k}.
15945
17434
  //
15946
- ggml_vec_sub_f32(nx, lm[end].s, x, xp);
15947
- ggml_vec_sub_f32(nx, lm[end].y, g, gp);
17435
+ ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp);
17436
+ ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp);
15948
17437
 
15949
17438
  // compute scalars ys and yy:
15950
17439
  // ys = y^t \cdot s -> 1 / \rho.
15951
17440
  // yy = y^t \cdot y.
15952
17441
  //
15953
- ggml_vec_dot_f32(nx, &ys, lm[end].y, lm[end].s);
15954
- ggml_vec_dot_f32(nx, &yy, lm[end].y, lm[end].y);
17442
+ ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0] *nx]);
17443
+ ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);
15955
17444
 
15956
- lm[end].ys = ys;
17445
+ lm_ys[end[0]] = ys;
15957
17446
 
15958
17447
  // find new search direction
15959
17448
  // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS
15960
17449
 
15961
- bound = (m <= k) ? m : k;
15962
- k++;
15963
- end = (end + 1)%m;
17450
+ bound = (m <= k[0]) ? m : k[0];
17451
+ k[0]++;
17452
+ it++;
17453
+ end[0] = (end[0] + 1)%m;
15964
17454
 
15965
17455
  // initialize search direction with -g
15966
17456
  ggml_vec_neg_f32(nx, d, g);
15967
17457
 
15968
- j = end;
17458
+ j[0] = end[0];
15969
17459
  for (int i = 0; i < bound; ++i) {
15970
- j = (j + m - 1) % m;
17460
+ j[0] = (j[0] + m - 1) % m;
15971
17461
  // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
15972
- ggml_vec_dot_f32(nx, &lm[j].alpha, lm[j].s, d);
15973
- lm[j].alpha /= lm[j].ys;
17462
+ ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d);
17463
+ lm_alpha[j[0]] /= lm_ys[j[0]];
15974
17464
  // q_{i} = q_{i+1} - \alpha_{i} y_{i}
15975
- ggml_vec_mad_f32(nx, d, lm[j].y, -lm[j].alpha);
17465
+ ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
15976
17466
  }
15977
17467
 
15978
17468
  ggml_vec_scale_f32(nx, d, ys/yy);
15979
17469
 
15980
17470
  for (int i = 0; i < bound; ++i) {
15981
17471
  // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
15982
- ggml_vec_dot_f32(nx, &beta, lm[j].y, d);
15983
- beta /= lm[j].ys;
17472
+ ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d);
17473
+ beta /= lm_ys[j[0]];
15984
17474
  // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
15985
- ggml_vec_mad_f32(nx, d, lm[j].s, lm[j].alpha - beta);
15986
- j = (j + 1)%m;
17475
+ ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
17476
+ j[0] = (j[0] + 1)%m;
15987
17477
  }
15988
17478
 
15989
- step = 1.0;
17479
+ step[0] = 1.0;
15990
17480
  }
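
The per-iteration ggml_lbfgs_iteration_data structs are gone; the correction history now lives in the opt context as flat F32 tensors, with lm_s and lm_y shaped [nx, m] so that correction i is the nx-float slice at offset i*nx, and j, k, end and step kept as scalars so the two-loop recursion can be resumed across calls. A hedged sketch of the flat indexing, with an illustrative helper name:

    #include <stdint.h>

    // correction i of an [nx, m] history tensor is the nx-float row starting at i*nx
    static float * lm_row(float * lm, int i, int64_t nx) {
        return &lm[(int64_t) i * nx];
    }
    // e.g. the s_{k+1} update above is ggml_vec_sub_f32(nx, lm_row(lm_s, end[0], nx), x, xp);
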
15991
17481
 
15992
17482
  return GGML_OPT_DID_NOT_CONVERGE;
@@ -16011,6 +17501,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
16011
17501
 
16012
17502
  .adam = {
16013
17503
  .n_iter = 10000,
17504
+ .sched = 1.000f,
17505
+ .decay = 0.001f,
16014
17506
  .alpha = 0.001f,
16015
17507
  .beta1 = 0.9f,
16016
17508
  .beta2 = 0.999f,
@@ -16053,6 +17545,71 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
16053
17545
  return result;
16054
17546
  }
16055
17547
 
17548
+ GGML_API void ggml_opt_init(
17549
+ struct ggml_context * ctx,
17550
+ struct ggml_opt_context * opt,
17551
+ struct ggml_opt_params params,
17552
+ int64_t nx) {
17553
+ opt->ctx = ctx;
17554
+ opt->params = params;
17555
+ opt->iter = 0;
17556
+ opt->nx = nx;
17557
+ opt->just_initialized = true;
17558
+ switch (opt->params.type) {
17559
+ case GGML_OPT_ADAM:
17560
+ {
17561
+ opt->adam.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17562
+ opt->adam.g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17563
+ opt->adam.g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17564
+ opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17565
+ opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17566
+ opt->adam.mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17567
+ opt->adam.vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17568
+ opt->adam.pf = params.past > 0
17569
+ ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
17570
+ : NULL;
17571
+ ggml_set_zero(opt->adam.x);
17572
+ ggml_set_zero(opt->adam.g1);
17573
+ ggml_set_zero(opt->adam.g2);
17574
+ ggml_set_zero(opt->adam.m);
17575
+ ggml_set_zero(opt->adam.v);
17576
+ ggml_set_zero(opt->adam.mh);
17577
+ ggml_set_zero(opt->adam.vh);
17578
+ if (opt->adam.pf) {
17579
+ ggml_set_zero(opt->adam.pf);
17580
+ }
17581
+ } break;
17582
+ case GGML_OPT_LBFGS:
17583
+ {
17584
+ opt->lbfgs.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17585
+ opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17586
+ opt->lbfgs.g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17587
+ opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17588
+ opt->lbfgs.d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17589
+ opt->lbfgs.pf = params.past > 0
17590
+ ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
17591
+ : NULL;
17592
+ opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
17593
+ opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
17594
+ opt->lbfgs.lms = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
17595
+ opt->lbfgs.lmy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
17596
+ ggml_set_zero(opt->lbfgs.x);
17597
+ ggml_set_zero(opt->lbfgs.xp);
17598
+ ggml_set_zero(opt->lbfgs.g);
17599
+ ggml_set_zero(opt->lbfgs.gp);
17600
+ ggml_set_zero(opt->lbfgs.d);
17602
+ if (opt->lbfgs.pf) {
17603
+ ggml_set_zero(opt->lbfgs.pf);
17604
+ }
17605
+ ggml_set_zero(opt->lbfgs.lmal);
17606
+ ggml_set_zero(opt->lbfgs.lmys);
17607
+ ggml_set_zero(opt->lbfgs.lms);
17608
+ ggml_set_zero(opt->lbfgs.lmy);
17609
+ } break;
17610
+ }
17611
+ }
17612
+
16056
17613
  enum ggml_opt_result ggml_opt(
16057
17614
  struct ggml_context * ctx,
16058
17615
  struct ggml_opt_params params,
@@ -16075,33 +17632,65 @@ enum ggml_opt_result ggml_opt(
16075
17632
 
16076
17633
  enum ggml_opt_result result = GGML_OPT_OK;
16077
17634
 
17635
+ struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
17636
+
17637
+ ggml_opt_init(ctx, opt, params, 0);
17638
+ result = ggml_opt_resume(ctx, opt, f);
17639
+
17640
+ if (free_ctx) {
17641
+ ggml_free(ctx);
17642
+ }
17643
+
17644
+ return result;
17645
+ }
17646
+
17647
+ enum ggml_opt_result ggml_opt_resume(
17648
+ struct ggml_context * ctx,
17649
+ struct ggml_opt_context * opt,
17650
+ struct ggml_tensor * f) {
17651
+
17652
+ // build forward + backward compute graphs
17653
+ struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
17654
+ struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
17655
+
17656
+ struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
17657
+ struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
17658
+
17659
+ *gf = ggml_build_forward (f);
17660
+ *gb = ggml_build_backward(ctx, gf, true);
17661
+
17662
+ return ggml_opt_resume_g(ctx, opt, f, gf, gb);
17663
+ }
17664
+
17665
+ enum ggml_opt_result ggml_opt_resume_g(
17666
+ struct ggml_context * ctx,
17667
+ struct ggml_opt_context * opt,
17668
+ struct ggml_tensor * f,
17669
+ struct ggml_cgraph * gf,
17670
+ struct ggml_cgraph * gb) {
17671
+
16078
17672
  // build forward + backward compute graphs
16079
- struct ggml_cgraph gf = ggml_build_forward (f);
16080
- struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, true);
17673
+ enum ggml_opt_result result = GGML_OPT_OK;
16081
17674
 
16082
- switch (params.type) {
17675
+ switch (opt->params.type) {
16083
17676
  case GGML_OPT_ADAM:
16084
17677
  {
16085
- result = ggml_opt_adam(ctx, params, f, &gf, &gb);
17678
+ result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
16086
17679
  } break;
16087
17680
  case GGML_OPT_LBFGS:
16088
17681
  {
16089
- result = ggml_opt_lbfgs(ctx, params, f, &gf, &gb);
17682
+ result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb);
16090
17683
  } break;
16091
17684
  }
16092
17685
 
16093
- if (params.print_forward_graph) {
16094
- ggml_graph_print (&gf);
16095
- ggml_graph_dump_dot(&gf, NULL, "opt-forward.dot");
16096
- }
16097
-
16098
- if (params.print_backward_graph) {
16099
- ggml_graph_print (&gb);
16100
- ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot");
17686
+ if (opt->params.print_forward_graph) {
17687
+ ggml_graph_print (gf);
17688
+ ggml_graph_dump_dot(gf, NULL, "opt-forward.dot");
16101
17689
  }
16102
17690
 
16103
- if (free_ctx) {
16104
- ggml_free(ctx);
17691
+ if (opt->params.print_backward_graph) {
17692
+ ggml_graph_print (gb);
17693
+ ggml_graph_dump_dot(gb, gf, "opt-backward.dot");
16105
17694
  }
16106
17695
 
16107
17696
  return result;
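
ggml_opt is now a thin wrapper that allocates a ggml_opt_context, initializes it and calls ggml_opt_resume; a caller that keeps the context itself preserves the Adam moments, the past-loss history and the iteration counter across calls. A hedged usage sketch under those assumptions (and assuming ggml.h exposes struct ggml_opt_context and the resume entry points, as the GGML_API markers above suggest); the training loop, the scalar loss tensor f and n_epochs are placeholders:

    #include "ggml.h"

    // resumable training loop: ctx holds the model parameters, f the scalar loss
    static enum ggml_opt_result train_resumable(struct ggml_context * ctx,
                                                struct ggml_tensor * f,
                                                int n_epochs) {
        struct ggml_opt_params params = ggml_opt_default_params(GGML_OPT_ADAM);
        struct ggml_opt_context opt;
        ggml_opt_init(ctx, &opt, params, 0);      // nx is re-derived on the first resume

        enum ggml_opt_result res = GGML_OPT_OK;
        for (int epoch = 0; epoch < n_epochs; ++epoch) {
            res = ggml_opt_resume(ctx, &opt, f);  // opt.iter and the moments carry over
        }
        return res;
    }
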
@@ -16301,6 +17890,18 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
16301
17890
  result = ggml_quantize_q6_K(src + start, block, n, n, hist);
16302
17891
  } break;
16303
17892
  #endif
17893
+ case GGML_TYPE_F16:
17894
+ {
17895
+ int elemsize = sizeof(ggml_fp16_t);
17896
+ ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
17897
+ result = n * elemsize;
17898
+ } break;
17899
+ case GGML_TYPE_F32:
17900
+ {
17901
+ int elemsize = sizeof(float);
17902
+ result = n * elemsize;
17903
+ memcpy((uint8_t *)dst + start * elemsize, src + start, result);
17904
+ } break;
16304
17905
  default:
16305
17906
  assert(false);
16306
17907
  }