llama_cpp 0.2.0 → 0.2.1
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -0
- data/examples/README.md +60 -0
- data/examples/chat.rb +195 -0
- data/ext/llama_cpp/llama_cpp.cpp +52 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +697 -130
- data/ext/llama_cpp/src/ggml-cuda.h +4 -1
- data/ext/llama_cpp/src/ggml-metal.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.m +548 -497
- data/ext/llama_cpp/src/ggml-metal.metal +425 -122
- data/ext/llama_cpp/src/ggml-opencl.cpp +3 -32
- data/ext/llama_cpp/src/ggml-opencl.h +1 -2
- data/ext/llama_cpp/src/ggml.c +1904 -303
- data/ext/llama_cpp/src/ggml.h +126 -2
- data/ext/llama_cpp/src/llama.cpp +212 -108
- data/ext/llama_cpp/src/llama.h +12 -3
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +3 -0
- metadata +4 -2
data/ext/llama_cpp/src/ggml.c
CHANGED
@@ -3603,6 +3603,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "SUM_ROWS",
     "MEAN",
     "REPEAT",
+    "REPEAT_BACK",
     "ABS",
     "SGN",
     "NEG",
@@ -3616,6 +3617,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "RMS_NORM_BACK",
 
     "MUL_MAT",
+    "OUT_PROD",
 
     "SCALE",
     "SET",
@@ -3631,6 +3633,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "DIAG_MASK_INF",
     "DIAG_MASK_ZERO",
     "SOFT_MAX",
+    "SOFT_MAX_BACK",
     "ROPE",
     "ROPE_BACK",
     "ALIBI",
@@ -3640,13 +3643,16 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
 
     "FLASH_ATTN",
     "FLASH_FF",
+    "FLASH_ATTN_BACK",
 
     "MAP_UNARY",
     "MAP_BINARY",
-};
 
-
+    "CROSS_ENTROPY_LOSS",
+    "CROSS_ENTROPY_LOSS_BACK",
+};
 
+static_assert(GGML_OP_COUNT == 57, "GGML_OP_COUNT != 57");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3665,6 +3671,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "Σx_k",
     "Σx/n",
     "repeat(x)",
+    "repeat_back(x)",
     "abs(x)",
     "sgn(x)",
     "-x",
@@ -3677,6 +3684,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "rms_norm(x)",
     "rms_norm_back(x)",
 
+    "X*Y",
     "X*Y",
 
     "x*v",
@@ -3693,6 +3701,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "diag_mask_inf(x)",
     "diag_mask_zero(x)",
     "soft_max(x)",
+    "soft_max_back(x)",
     "rope(x)",
     "rope_back(x)",
     "alibi(x)",
@@ -3702,12 +3711,16 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
     "flash_attn(x)",
     "flash_ff(x)",
+    "flash_attn_back(x)",
 
     "f(x)",
     "f(x,y)",
+
+    "cross_entropy_loss(x,y)",
+    "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT ==
+static_assert(GGML_OP_COUNT == 57, "GGML_OP_COUNT != 57");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3870,6 +3883,15 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
         (t0->ne[3] == t1->ne[3]);
 }
 
+static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        (t0->ne[1] == t1->ne[1]) &&
+        (t0->ne[2] == t1->ne[2]) &&
+        (t0->ne[3] == t1->ne[3]);
+}
+
 bool ggml_is_quantized(enum ggml_type type) {
     return GGML_IS_QUANTIZED[type];
 }
@@ -3917,6 +3939,12 @@ bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
+bool ggml_is_permuted(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
+}
+
 static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
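Not part of the diff: a minimal sketch of how the new ggml_is_permuted() check could be used by a caller, assuming the pre-existing ggml_cont() API; the helper name is illustrative only.

    #include "ggml.h"

    // hypothetical helper: make sure a tensor is laid out in memory order
    // before handing it to a kernel that assumes in-order strides
    static struct ggml_tensor * ensure_natural_layout(struct ggml_context * ctx, struct ggml_tensor * t) {
        if (ggml_is_permuted(t)) {
            // ggml_cont() materializes a contiguous copy with natural strides
            return ggml_cont(ctx, t);
        }
        return t;
    }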
@@ -4693,7 +4721,7 @@ struct ggml_tensor * ggml_add_impl(
 
     bool is_node = false;
 
-    if (
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -4733,7 +4761,7 @@ struct ggml_tensor * ggml_add1_impl(
 
     bool is_node = false;
 
-    if (
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -5159,6 +5187,34 @@ struct ggml_tensor * ggml_repeat(
     return result;
 }
 
+// ggml_repeat_back
+
+struct ggml_tensor * ggml_repeat_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_can_repeat(b, a));
+
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    if (ggml_are_same_shape(a, b) && !is_node) {
+        return a;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
+
+    result->op   = GGML_OP_REPEAT_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_abs
 
 struct ggml_tensor * ggml_abs_impl(
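For orientation (not in the diff): ggml_repeat_back is the adjoint of ggml_repeat, reducing a tensor of the broadcast shape back to the original shape by summation. A hedged sketch of the pairing, with illustrative tensor names:

    // forward: broadcast bias across all rows of x
    struct ggml_tensor * bias_rep  = ggml_repeat(ctx, bias, x);            // shape of x
    // backward (conceptually): sum an incoming gradient of x's shape back to bias's shape
    struct ggml_tensor * bias_grad = ggml_repeat_back(ctx, grad_in, bias); // shape of bias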
@@ -5536,6 +5592,32 @@ struct ggml_tensor * ggml_mul_mat(
     return result;
 }
 
+// ggml_out_prod
+
+struct ggml_tensor * ggml_out_prod(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_can_out_prod(a, b));
+    GGML_ASSERT(!ggml_is_transposed(a));
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+
+    result->op   = GGML_OP_OUT_PROD;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_scale
 
 struct ggml_tensor * ggml_scale_impl(
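Reading the shape line above, ggml_out_prod(ctx, a, b) yields a tensor of shape { a->ne[0], b->ne[0], ... }, i.e. the outer products of matching columns of a and b, accumulated over the shared second index. A hedged usage sketch with illustrative names (not from the diff):

    // a: [n, k], b: [m, k]  ->  acc: [n, m]
    // acc[i,j] = sum_k a[i,k] * b[j,k]   (a times b-transposed)
    struct ggml_tensor * acc = ggml_out_prod(ctx, a, b);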
@@ -5548,7 +5630,7 @@ struct ggml_tensor * ggml_scale_impl(
 
     bool is_node = false;
 
-    if (
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -5591,7 +5673,7 @@ struct ggml_tensor * ggml_set_impl(
 
     bool is_node = false;
 
-    if (
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -5913,10 +5995,6 @@ struct ggml_tensor * ggml_view_1d(
     result->src1 = NULL;
     result->opt[0] = offs;
 
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
-
     return result;
 }
 
@@ -5957,10 +6035,6 @@ struct ggml_tensor * ggml_view_2d(
     result->src1 = NULL;
     result->opt[0] = offs;
 
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
-
     return result;
 }
 
@@ -6003,10 +6077,6 @@ struct ggml_tensor * ggml_view_3d(
     result->src1 = NULL;
     result->opt[0] = offs;
 
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
-
     return result;
 }
 
@@ -6051,10 +6121,6 @@ struct ggml_tensor * ggml_view_4d(
     result->src1 = NULL;
     result->opt[0] = offs;
 
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
-
     return result;
 }
 
@@ -6116,10 +6182,18 @@ struct ggml_tensor * ggml_permute(
     result->src1 = NULL;
 
     if (is_node) {
-
-
-
-
+        ggml_scratch_save(ctx);
+
+        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
+
+        ((int32_t *) b->data)[0] = axis0;
+        ((int32_t *) b->data)[1] = axis1;
+        ((int32_t *) b->data)[2] = axis2;
+        ((int32_t *) b->data)[3] = axis3;
+
+        ggml_scratch_load(ctx);
+
+        result->opt[0] = b;
     }
 
     return result;
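The permute axes are now stashed in a small GGML_TYPE_I32 tensor hung off result->opt[0] instead of the tensor's padding bytes. A hedged sketch (illustrative, not from the diff) of how a backward pass could read them back:

    // recover the axes that ggml_permute() stored in opt[0]
    const struct ggml_tensor * axes = result->opt[0];    // GGML_TYPE_I32, 4 elements
    const int axis0 = ((const int32_t *) axes->data)[0];
    const int axis1 = ((const int32_t *) axes->data)[1];
    const int axis2 = ((const int32_t *) axes->data)[2];
    const int axis3 = ((const int32_t *) axes->data)[3];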
@@ -6359,6 +6433,44 @@ struct ggml_tensor * ggml_soft_max_inplace(
     return ggml_soft_max_impl(ctx, a, true);
 }
 
+
+// ggml_soft_max_back
+
+struct ggml_tensor * ggml_soft_max_back_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        bool inplace) {
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true; // TODO : implement backward pass
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SOFT_MAX_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_soft_max_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_soft_max_back_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_soft_max_back_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_soft_max_back_impl(ctx, a, b, true);
+}
+
 // ggml_rope
 
 struct ggml_tensor * ggml_rope_impl(
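A hedged usage sketch (not from the diff): judging from the compute kernel added later in this file, the first argument is the upstream gradient and the second is the softmax output itself.

    // assuming src0 = dL/dy and src1 = y = softmax(x), as the f32 kernel below suggests
    struct ggml_tensor * dx = ggml_soft_max_back(ctx, dy, y);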
@@ -6371,7 +6483,7 @@ struct ggml_tensor * ggml_rope_impl(
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
 
-    if (
+    if (a->grad) {
         is_node = true;
     }
 
@@ -6425,8 +6537,7 @@ struct ggml_tensor * ggml_rope_back(
     bool is_node = false;
 
     if (a->grad) {
-
-        is_node = true;
+        is_node = false; // TODO: implement backward
     }
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
@@ -6591,7 +6702,6 @@ struct ggml_tensor * ggml_flash_attn(
     bool is_node = false;
 
     if (q->grad || k->grad || v->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -6623,7 +6733,6 @@ struct ggml_tensor * ggml_flash_ff(
     bool is_node = false;
 
     if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -6641,6 +6750,71 @@ struct ggml_tensor * ggml_flash_ff(
     return result;
 }
 
+// ggml_flash_attn_back
+
+struct ggml_tensor * ggml_flash_attn_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor * q,
+        struct ggml_tensor * k,
+        struct ggml_tensor * v,
+        struct ggml_tensor * d,
+        bool masked) {
+    GGML_ASSERT(ggml_can_mul_mat(k, q));
+    // TODO: check if vT can be multiplied by (k*qT)
+
+    // d shape [D,N,ne2,ne3]
+    // q shape [D,N,ne2,ne3]
+    // k shape [D,M,ne2,ne3]
+    // v shape [M,D,ne2,ne3]
+
+    const int64_t D = q->ne[0];
+    const int64_t N = q->ne[1];
+    const int64_t M = k->ne[1];
+    const int64_t ne2 = q->ne[2];
+    const int64_t ne3 = q->ne[3];
+
+    GGML_ASSERT(k->ne[0] == D);
+    GGML_ASSERT(v->ne[0] == M);
+    GGML_ASSERT(v->ne[1] == D);
+    GGML_ASSERT(d->ne[0] == D);
+    GGML_ASSERT(d->ne[1] == N);
+    GGML_ASSERT(k->ne[2] == ne2);
+    GGML_ASSERT(k->ne[3] == ne3);
+    GGML_ASSERT(v->ne[2] == ne2);
+    GGML_ASSERT(v->ne[3] == ne3);
+    GGML_ASSERT(d->ne[2] == ne2);
+    GGML_ASSERT(d->ne[3] == ne3);
+
+    bool is_node = false;
+
+    if (q->grad || k->grad || v->grad) {
+        // when using this operation (in backwards pass) these grads are set.
+        // we don't want to create (big) grad of our result, so is_node is false.
+        is_node = false;
+    }
+
+    // store gradients of q, k and v as continuous tensors concatenated in result.
+    // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
+    // gradq->data = result->data
+    // gradk->data = result->data + nb0*D*N*ne2*ne3
+    // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
+    // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
+    int64_t ne[4] = {D,M+N+M,ne2,ne3};
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op   = GGML_OP_FLASH_ATTN_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = q;
+    result->src1 = k;
+    result->opt[0] = v;
+    result->opt[1] = d;
+    result->opt[2] = ggml_new_i32(ctx, masked ? 1 : 0);
+
+    return result;
+}
+
+
 // ggml_map_unary
 
 struct ggml_tensor * ggml_map_unary_impl_f32(
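Since ggml_flash_attn_back packs the q, k and v gradients back-to-back in one result tensor, a caller has to slice them out again. A hedged sketch (illustrative, not from the diff), using the offsets documented in the comments above and the existing ggml_view_1d/ggml_element_size API:

    const int64_t D = q->ne[0], N = q->ne[1], M = k->ne[1];
    const int64_t ne2 = q->ne[2], ne3 = q->ne[3];
    const size_t  elem = ggml_element_size(result);   // nb0
    struct ggml_tensor * grad_q = ggml_view_1d(ctx, result, D*N*ne2*ne3, 0);
    struct ggml_tensor * grad_k = ggml_view_1d(ctx, result, D*M*ne2*ne3, elem*D*N*ne2*ne3);
    struct ggml_tensor * grad_v = ggml_view_1d(ctx, result, M*D*ne2*ne3, elem*(D*N + D*M)*ne2*ne3);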
@@ -6725,6 +6899,50 @@ struct ggml_tensor * ggml_map_binary_inplace_f32(
     return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
 }
 
+// ggml_cross_entropy_loss
+
+struct ggml_tensor * ggml_cross_entropy_loss(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
+
+    result->op   = GGML_OP_CROSS_ENTROPY_LOSS;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+// ggml_cross_entropy_loss_back
+
+struct ggml_tensor * ggml_cross_entropy_loss_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        struct ggml_tensor * c) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_is_scalar(c));
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
+    result->grad = NULL;
+    result->src0 = a;
+    result->src1 = b;
+    result->opt[0] = c;
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 void ggml_set_param(
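A hedged sketch of how the new loss op would be wired into a graph (illustrative names such as logits, labels and weights are not from the diff; only ggml_cross_entropy_loss and ggml_set_param are taken from this file):

    // logits and labels are assumed to be same-shaped F32 tensors
    ggml_set_param(ctx, weights);   // mark trainable tensors before building the graph
    struct ggml_tensor * loss = ggml_cross_entropy_loss(ctx, logits, labels);  // 1-element result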
@@ -8875,6 +9093,99 @@ static void ggml_compute_forward_repeat(
     }
 }
 
+// ggml_compute_forward_repeat_back
+
+static void ggml_compute_forward_repeat_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_can_repeat(dst, src0));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne00/ne0);
+    const int nr1 = (int)(ne01/ne1);
+    const int nr2 = (int)(ne02/ne2);
+    const int nr3 = (int)(ne03/ne3);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    if (ggml_is_contiguous(dst)) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+    } else {
+        for (int k3 = 0; k3 < ne3; k3++) {
+            for (int k2 = 0; k2 < ne2; k2++) {
+                for (int k1 = 0; k1 < ne1; k1++) {
+                    ggml_vec_set_f32(ne0,
+                        (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
+                        0);
+                }
+            }
+        }
+    }
+
+    // TODO: maybe this is not optimal?
+    for (int i3 = 0; i3 < nr3; i3++) {
+        for (int k3 = 0; k3 < ne3; k3++) {
+            for (int i2 = 0; i2 < nr2; i2++) {
+                for (int k2 = 0; k2 < ne2; k2++) {
+                    for (int i1 = 0; i1 < nr1; i1++) {
+                        for (int k1 = 0; k1 < ne1; k1++) {
+                            for (int i0 = 0; i0 < nr0; i0++) {
+                                ggml_vec_acc_f32(ne0,
+                                    (float *) ((char *)  dst->data + (         k3)*nb3  + (         k2)*nb2  + (         k1)*nb1),
+                                    (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_repeat_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_repeat_back_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_abs
 
 static void ggml_compute_forward_abs_f32(
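A standalone illustration (not ggml code) of what the nested accumulation loops above do in the simplest 1-D case: every repeated copy of a row contributes its gradient back to the original row.

    // sum an [nr*n]-element repeated row back into an [n]-element row
    void repeat_back_1d(float * dst, const float * src, int n, int nr) {
        for (int i = 0; i < n; ++i) dst[i] = 0.0f;
        for (int r = 0; r < nr; ++r) {
            for (int i = 0; i < n; ++i) {
                dst[i] += src[r*n + i];   // each repeat adds its contribution
            }
        }
    }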
@@ -10249,18 +10560,188 @@ static void ggml_compute_forward_mul_mat(
     }
 }
 
-//
+// ggml_compute_forward_out_prod
 
-
+
+static void ggml_compute_forward_out_prod_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-
-
-
-
-
-
+        struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne10 = src1->ne[0];
+    //const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+    // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
+    // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+
+    if (params->type == GGML_TASK_INIT) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by last three dimensions
+
+    // total rows in dst
+    const int64_t nr = ne1*ne2*ne3;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    // dst[:,:,:,:] = 0
+    // for i2,i3:
+    //   for i1:
+    //     for i01:
+    //       for i0:
+    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        // dst indices
+        const int64_t i3 = ir/(ne2*ne1);
+        const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+        const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        const int64_t i02 = i2;
+        const int64_t i03 = i3;
+
+        //const int64_t i10 = i1;
+        const int64_t i12 = i2;
+        const int64_t i13 = i3;
+
+        for (int64_t i01 = 0; i01 < ne01; ++i01) {
+            const int64_t i11 = i01;
+
+            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+            float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
+
+            ggml_vec_mad_f32(ne0, d, s0, *s1);
+            // for (int64_t i0 = 0; i0 < ne0; ++i0) {
+            //     d[i0] += s0[i0] * s1[i1];
+            // }
+        }
+    }
+
+    //int64_t t1 = ggml_perf_time_us();
+    //static int64_t acc = 0;
+    //acc += t1 - t0;
+    //if (t1 - t0 > 10) {
+    //    printf("\n");
+    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
+
+    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+    //}
+}
+
+static void ggml_compute_forward_out_prod(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+            {
+                GGML_ASSERT(false); // todo
+                // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                GGML_ASSERT(false); // todo
+                // ggml_compute_forward_out_prod_f16_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_out_prod_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_scale
+
+static void ggml_compute_forward_scale_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
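A standalone illustration (not ggml code) of the inner accumulation above on plain dense arrays: dst[i0,i1] += src0[i0,i01] * src1[i1,i01], i.e. dst = src0 * src1^T over the shared index.

    // a: ne0 x k (column-major over i01), b: ne1 x k, dst: ne0 x ne1
    void out_prod_ref(float * dst, const float * a, const float * b,
                      int ne0, int ne1, int k) {
        for (int i1 = 0; i1 < ne1; ++i1) {
            for (int i0 = 0; i0 < ne0; ++i0) {
                float s = 0.0f;
                for (int i01 = 0; i01 < k; ++i01) {
                    s += a[i01*ne0 + i0] * b[i01*ne1 + i1];
                }
                dst[i1*ne0 + i0] = s;
            }
        }
    }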
@@ -10671,7 +11152,11 @@ static void ggml_compute_forward_get_rows_back_f32(
     GGML_ASSERT(ggml_is_contiguous(opt0));
     GGML_ASSERT(ggml_is_contiguous(dst));
 
-    ggml_compute_forward_dup_same_cont(params, opt0, dst);
+    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+    if (params->type == GGML_TASK_INIT) {
+        memset(dst->data, 0, ggml_nbytes(dst));
+    }
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -10815,8 +11300,8 @@ static void ggml_compute_forward_diag_mask_f32(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst,
         const float value) {
-
-
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_nelements(src1) == 2);
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -10824,7 +11309,7 @@ static void ggml_compute_forward_diag_mask_f32(
     const int n_past = ((int32_t *) src1->data)[0];
     const bool inplace = (bool)((int32_t *) src1->data)[1];
 
-
+    GGML_ASSERT(n_past >= 0);
 
     if (!inplace && (params->type == GGML_TASK_INIT)) {
         // memcpy needs to be synchronized across threads to avoid race conditions.
@@ -10848,8 +11333,8 @@ static void ggml_compute_forward_diag_mask_f32(
     const int nr = src0->ne[1];
     const int nz = n/nr;
 
-
-
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     for (int k = 0; k < nz; k++) {
         for (int j = ith; j < nr; j += nth) {
@@ -10985,6 +11470,101 @@ static void ggml_compute_forward_soft_max(
     }
 }
 
+// ggml_compute_forward_soft_max_back
+
+static void ggml_compute_forward_soft_max_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_are_same_shape(src1, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // TODO: handle transposed/permuted matrices
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float *y  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float *dx = (float *)((char *) dst->data  + i1*dst->nb[1]);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(dy[i]));
+            assert(!isnan(y[i]));
+        }
+#endif
+        // Jii = yi - yi*yi
+        // Jij = -yi*yj
+        // J = diag(y)-y.T*y
+        // dx = J * dy
+        // dxk = sum_i(Jki * dyi)
+        // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
+        // dxk = sum_i(-yk*yi * dyi) + yk*dyk
+        // dxk = -yk * sum_i(yi * dyi) + yk*dyk
+        // dxk = -yk * dot(y, dy) + yk*dyk
+        // dxk = yk * (- dot(y, dy) + dyk)
+        // dxk = yk * (dyk - dot(y, dy))
+        //
+        // post-order:
+        // dot_y_dy := dot(y, dy)
+        // dx := dy
+        // dx := dx - dot_y_dy
+        // dx := dx * y
+
+        // linear runtime, no additional memory
+        float dot_y_dy = 0;
+        ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy);
+        ggml_vec_cpy_f32 (nc, dx, dy);
+        ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
+        ggml_vec_mul_f32 (nc, dx, dx, y);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(dx[i]));
+            assert(!isinf(dx[i]));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_soft_max_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_alibi
 
 static void ggml_compute_forward_alibi_f32(
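In compact form, the loop body above is the softmax Jacobian-vector product. With y = softmax(x) and upstream gradient dy:

    dx_k = \sum_i \frac{\partial y_i}{\partial x_k}\, dy_i
         = \sum_i \left( y_i [i = k] - y_i y_k \right) dy_i
         = y_k \left( dy_k - \langle y, dy \rangle \right)

which is exactly the dot / copy / acc1 / mul sequence of ggml_vec_* calls in the kernel.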
@@ -12938,42 +13518,616 @@ static void ggml_compute_forward_flash_ff(
     }
 }
 
-//
+// ggml_compute_forward_flash_attn_back
 
-static void
+static void ggml_compute_forward_flash_attn_back_f32(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor *
-        struct ggml_tensor *
-        const
-
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * d,
+        const bool masked,
+        struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
 
-
-
-
+    const int64_t neq0 = q->ne[0];
+    const int64_t neq1 = q->ne[1];
+    const int64_t neq2 = q->ne[2];
+    const int64_t neq3 = q->ne[3];
 
-    const
-    const
+    const int64_t nek0 = k->ne[0];
+    const int64_t nek1 = k->ne[1];
+    //const int64_t nek2 = k->ne[2];
+    //const int64_t nek3 = k->ne[3];
 
-
-
+    const int64_t nev0 = v->ne[0];
+    const int64_t nev1 = v->ne[1];
+    //const int64_t nev2 = v->ne[2];
+    //const int64_t nev3 = v->ne[3];
 
-
-
-
-
-}
-}
+    const int64_t ned0 = d->ne[0];
+    const int64_t ned1 = d->ne[1];
+    //const int64_t ned2 = d->ne[2];
+    //const int64_t ned3 = d->ne[3];
 
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
 
-
-
-
-
-
+    const int nbk0 = k->nb[0];
+    const int nbk1 = k->nb[1];
+    const int nbk2 = k->nb[2];
+    const int nbk3 = k->nb[3];
+
+    const int nbq0 = q->nb[0];
+    const int nbq1 = q->nb[1];
+    const int nbq2 = q->nb[2];
+    const int nbq3 = q->nb[3];
+
+    const int nbv0 = v->nb[0];
+    const int nbv1 = v->nb[1];
+    const int nbv2 = v->nb[2];
+    const int nbv3 = v->nb[3];
+
+    const int nbd0 = d->nb[0];
+    const int nbd1 = d->nb[1];
+    const int nbd2 = d->nb[2];
+    const int nbd3 = d->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t D = neq0;
+    const int64_t N = neq1;
+    const int64_t P = nek1 - N;
+    const int64_t M = P + N;
+
+    const int Mup  = ggml_up(M, GGML_SOFT_MAX_UNROLL);
+    const int mxDM = MAX(D, Mup);
+
+    // GGML_ASSERT(ne0 == D);
+    // GGML_ASSERT(ne1 == N);
+    GGML_ASSERT(P >= 0);
+
+    GGML_ASSERT(nbq0 == sizeof(float));
+    GGML_ASSERT(nbk0 == sizeof(float));
+    GGML_ASSERT(nbv0 == sizeof(float));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned0 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nek1 == N + P);
+    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    if (params->type == GGML_TASK_INIT) {
+        if (ith == 0) {
+            memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
+        }
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int nr = neq2*neq3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const float scale = 1.0f/sqrtf(D);
+
+    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // q indices
+        const int iq3 = ir/(neq2);
+        const int iq2 = ir - iq3*neq2;
+        for ( int iq1 = 0; iq1 < neq1; ++iq1) {
+
+
+            // not sure about CACHE_LINE_SIZE_F32..
+            // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
+            float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
+            float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
+
+            for (int i = M; i < Mup; ++i) {
+                S[i] = -INFINITY;
+            }
+
+            for (int64_t ic = 0; ic < nek1; ++ic) {
+                // k indices
+                const int ik3 = iq3;
+                const int ik2 = iq2;
+                const int ik1 = ic;
+
+                // S indices
+                const int i1 = ik1;
+
+                ggml_vec_dot_f32(neq0,
+                        S + i1,
+                        (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                        (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+            }
+
+            // scale
+            ggml_vec_scale_f32(nek1, S, scale);
+
+            if (masked) {
+                for (int64_t i = P; i < M; i++) {
+                    if (i > P + iq1) {
+                        S[i] = -INFINITY;
+                    }
+                }
+            }
+
+            // softmax
+            {
+                float max = -INFINITY;
+                ggml_vec_max_f32(M, &max, S);
+
+                ggml_float sum = 0.0;
+                {
+#ifdef GGML_SOFT_MAX_ACCELERATE
+                    max = -max;
+                    vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
+                    vvexpf(SM, SM, &Mup);
+                    ggml_vec_sum_f32(Mup, &sum, SM);
+#else
+                    uint16_t   scvt[GGML_SOFT_MAX_UNROLL];
+                    ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
+
+                    for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
+                        float * SR =  S + i;
+                        float * SW = SM + i;
+
+                        for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
+                            if (SR[j] == -INFINITY) {
+                                SW[j] = 0.0f;
+                            } else {
+                                ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
+                                memcpy(&scvt[j], &s, sizeof(uint16_t));
+                                const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+                                sump[j] += (ggml_float)val;
+                                SW[j] = val;
+                            }
+                        }
+                    }
+
+                    for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
+                        sum += sump[i];
+                    }
+#endif
+                }
+
+                assert(sum > 0.0);
+
+                sum = 1.0/sum;
+                ggml_vec_scale_f32(M, SM, sum);
+
+            }
+
+            // step-by-step explanation
+            {
+                // forward-process                    shape      grads from backward process
+                // parallel_for iq2,iq3:
+                //  k[:D,:M,:,:]                     [D,M,:,:]  grad[k][:D,:M,iq2,iq3]  += grad[kcur]
+                //  q[:D,:N,:,:]                     [D,N,:,:]  grad[q][:D,iq1,iq2,iq3] += grad[qcur]
+                //  v[:M,:D,:,:]                     [M,D,:,:]  grad[v][:M,:D,iq2,iq3]  += grad[vcur]
+                //  for iq1:
+                //   kcur   = k[:D,:M,iq2,iq3]       [D,M,1,1]  grad[kcur] = grad[S1].T @ qcur
+                //   qcur   = q[:D,iq1,iq2,iq3]      [D,1,1,1]  grad[qcur] = grad[S1]   @ kcur
+                //   vcur   = v[:M,:D,iq2,iq3]       [M,D,1,1]  grad[vcur] = grad[S5].T @ S4
+                //   S0     = -Inf                   [D,1,1,1]
+                //  ~S1[i]  = dot(kcur[:D,i], qcur)
+                //   S1     = qcur @ kcur.T          [M,1,1,1]  grad[S1]   = grad[S2] * scale
+                //   S2     = S1 * scale             [M,1,1,1]  grad[S2]   = diag_mask_zero(grad[S3], P)
+                //   S3     = diag_mask_inf(S2, P)   [M,1,1,1]  grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                //   S4     = softmax(S3)            [M,1,1,1]  grad[S4]   = grad[S5] @ vcur
+                //  ~S5[i]  = dot(vcur[:,i], S4)
+                //   S5     = S4 @ vcur.T            [D,1,1,1]  grad[S5]   = d[:D,iq1,iq2,iq3]
+                //  ~dst[i,iq1,iq2,iq3]  = S5[i]                ^
+                //   dst[:D,iq1,iq2,iq3] = S5                   | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
+                // dst                               backward-/  grad[dst]                  = d
+                //
+                // output gradients with their dependencies:
+                //
+                // grad[kcur] = grad[S1].T @ qcur
+                // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                // grad[S4]   = grad[S5] @ vcur
+                // grad[S4]   = d[:D,iq1,iq2,iq3] @ vcur
+                // grad[qcur] = grad[S1]   @ kcur
+                // grad[vcur] = grad[S5].T @ S4
+                // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
+                //
+                // in post-order:
+                //
+                // S1         = qcur @ kcur.T
+                // S2         = S1 * scale
+                // S3         = diag_mask_inf(S2, P)
+                // S4         = softmax(S3)
+                // grad[S4]   = d[:D,iq1,iq2,iq3] @ vcur
+                // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                // grad[qcur] = grad[S1]   @ kcur
+                // grad[kcur] = grad[S1].T @ qcur
+                // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
+                //
+                // using less variables (SM=S4):
+                //
+                // S             = diag_mask_inf(qcur @ kcur.T * scale, P)
+                // SM            = softmax(S)
+                // S             = d[:D,iq1,iq2,iq3] @ vcur
+                // dot_SM_gradSM = dot(SM, S)
+                // S             = SM * (S - dot(SM, S))
+                // S             = diag_mask_zero(S, P) * scale
+                //
+                // grad[q][:D,iq1,iq2,iq3] += S   @ kcur
+                // grad[k][:D,:M,iq2,iq3]  += S.T @ qcur
+                // grad[v][:M,:D,iq2,iq3]  += d[:D,iq1,iq2,iq3].T @ SM
+            }
+
+            // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
+            // S = d[:D,iq1,iq2,iq3] @ vcur
+            // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
+            ggml_vec_set_f32(M, S, 0);
+            for (int64_t ic = 0; ic < D; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                ggml_vec_mad_f32(M,
+                        S,
+                         (float *) ((char *) v->data + (          ic*nbv1 + i2*nbv2 + i3*nbv3)),
+                        *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
+            }
+
+            // S = SM * (S - dot(SM, S))
+            float dot_SM_gradSM = 0;
+            ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
+            ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
+            ggml_vec_mul_f32 (M, S, S, SM);
+
+            // S = diag_mask_zero(S, P) * scale
+            if (masked) {
+                // for (int64_t i = P + iq1 + 1; i < M; i++) {
+                //     S[i] = 0;
+                // }
+                for (int64_t i = P; i < M; i++) {
+                    if (i > P + iq1) {
+                        S[i] = 0;
+                    }
+                }
+            }
+            ggml_vec_scale_f32(M, S, scale);
+
+            void * grad_q = (char *) dst->data;
+            void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
+            void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
+
+            const size_t nbgq1 = nb0*neq0;
+            const size_t nbgq2 = nb0*neq0*neq1;
+            const size_t nbgq3 = nb0*neq0*neq1*neq2;
+
+            const size_t nbgk1 = nb0*nek0;
+            const size_t nbgk2 = nb0*nek0*nek1;
+            const size_t nbgk3 = nb0*nek0*nek1*neq2;
+
+            const size_t nbgv1 = nb0*nev0;
+            const size_t nbgv2 = nb0*nev0*nev1;
+            const size_t nbgv3 = nb0*nev0*nev1*neq2;
+
+            // S    shape [M,1]
+            // SM   shape [M,1]
+            // kcur shape [D,M]
+            // qcur shape [D,1]
+            // vcur shape [M,D]
+            //
+            // grad[q][:D,iq1,iq2,iq3] += S @ kcur
+            // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
+            // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
+            //
+            //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
+            //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
+            for (int64_t ic = 0; ic < M; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                ggml_vec_mad_f32(D,
+                        (float *) ((char *) grad_q  + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
+                        (float *) ((char *) k->data + (ic*nbk1  + i2*nbk2  + i3*nbk3)),
+                        S[ic]);
+            }
+
+            // grad[k][:D,:M,iq2,iq3] += S.T       @ qcur
+            // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
+            // grad[k][:D,ic,iq2,iq3] += S[ic]     * qcur[:D,0]
+            for (int64_t ic = 0; ic < M; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                // ggml_vec_set_f32(D,
+                //         (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
+                //         0);
+                ggml_vec_mad_f32(D,
+                        (float *) ((char *) grad_k  + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
+                        (float *) ((char *) q->data + (i1*nbq1  + i2*nbq2  + i3*nbq3)),
+                        S[ic]);
+            }
+
+            // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T       @ SM
+            // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
+            // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3]         * SM[:M]
+            for (int64_t ic = 0; ic < D; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                // ggml_vec_set_f32(M,
+                //         (float *) ((char *) grad_v + (ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
+                //         0);
+                ggml_vec_mad_f32(M,
+                        (float *) ((char *) grad_v   + (          ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
+                        SM,
+                        *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1  + i2*nbd2  + i3*nbd3)));
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_flash_attn_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * d,
+        const bool masked,
+        struct ggml_tensor * dst) {
+    switch (q->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_flash_attn_back_f32(params, q, k, v, d, masked, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_map_unary
+
+static void ggml_compute_forward_map_unary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        const ggml_unary_op_f32_t fun) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_unary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        const ggml_unary_op_f32_t fun) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_map_binary
+
+static void ggml_compute_forward_map_binary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        const ggml_binary_op_f32_t fun) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_binary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        const ggml_binary_op_f32_t fun) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_cross_entropy_loss
+
+static void ggml_compute_forward_cross_entropy_loss_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    float * sums = (float *) params->wdata;
+
+    // TODO: handle transposed/permuted matrices
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    if (params->type == GGML_TASK_INIT) {
+        if (ith == 0) {
+            memset(sums, 0, sizeof(float) * (nth + nth * nc));
+        }
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        if (ith == 0) {
+            float * dp = (float *) dst->data;
+            ggml_vec_sum_f32(nth, dp, sums);
+            dp[0] *= -1.0f;
+        }
+        return;
+    }
+
+    const double eps = 1e-9;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float * st = (float *) params->wdata + nth + ith*nc;
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(s0[i]));
+            assert(!isnan(s1[i]));
+        }
+#endif
+        // soft_max
+        ggml_float sum = 0.0;
+        {
+            float max = -INFINITY;
+            ggml_vec_max_f32(nc, &max, s0);
+
+            uint16_t scvt;
+            for (int i = 0; i < nc; i++) {
+                if (s0[i] == -INFINITY) {
+                    st[i] = 0.0f;
+                } else {
+                    // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
+                    ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
+                    memcpy(&scvt, &s, sizeof(scvt));
+                    const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+                    sum += (ggml_float)val;
+                    st[i] = val;
+                }
+            }
+
+            assert(sum > 0.0);
+            // sum = 1.0/sum;
+        }
+        // avoid log(0) by rescaling from [0..1] to [eps..1]
+        sum = (1.0 - eps) / sum;
+        ggml_vec_scale_f32(nc, st, sum);
+        ggml_vec_add1_f32(nc, st, st, eps);
+        ggml_vec_log_f32(nc, st, st);
+        ggml_vec_mul_f32(nc, st, st, s1);
+
+        ggml_vec_sum_f32(nc, sums + ith, st);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(st[i]));
+            assert(!isinf(st[i]));
+        }
+#endif
+    }
+
+}
+
+static void ggml_compute_forward_cross_entropy_loss(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-
+                ggml_compute_forward_cross_entropy_loss_f32(params, src0, src1, dst);
             } break;
         default:
             {
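A standalone illustration (not ggml code) of the per-row loss computed above: a numerically stable softmax rescaled into [eps..1] to avoid log(0), then the sum of target * log(prob), negated as ggml does in the FINALIZE step.

    #include <math.h>

    float cross_entropy_row_ref(const float * s0, const float * s1, int nc, float eps) {
        float maxv = s0[0];
        for (int i = 1; i < nc; ++i) if (s0[i] > maxv) maxv = s0[i];
        float sum = 0.0f;
        for (int i = 0; i < nc; ++i) sum += expf(s0[i] - maxv);
        float acc = 0.0f;
        for (int i = 0; i < nc; ++i) {
            const float p = eps + (1.0f - eps) * expf(s0[i] - maxv) / sum;  // rescale into [eps..1]
            acc += s1[i] * logf(p);
        }
        return -acc;   // negated, mirroring dp[0] *= -1.0f in the FINALIZE step
    }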
@@ -12982,47 +14136,160 @@ static void ggml_compute_forward_map_unary(
     }
 }
 
-//
+// ggml_compute_forward_cross_entropy_loss_back
 
-static void
+static void ggml_compute_forward_cross_entropy_loss_back_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        struct ggml_tensor *
-
-
-
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(opt0));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int64_t ith = params->ith;
+    const int64_t nth = params->nth;
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const
-    const int nc = src0->ne[0];
+    const float eps = 1e-9f;
 
-
-
-
+    // TODO: handle transposed/permuted matrices
+    const int64_t nc = src0->ne[0];
+    const int64_t nr = ggml_nrows(src0);
 
-
-
-
-
-
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    float * d = (float *) opt0->data;
+
+    for (int64_t i1 = ir0; i1 < ir1; i1++) {
+        float * ds0 = (float *)((char *) dst->data  + i1*dst->nb[1]);
+        float * s0  = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * s1  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float * sm  = (float *) params->wdata + ith*nc;
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(s0[i]));
+            assert(!isnan(s1[i]));
+        }
+#endif
+        // step by step explanation:
+        {
+            //float * sums = (float *) params->wdata;
+
+            // forward pass with annotated gradients from backward pass
+            // (built by going in reverse operation order, adding to gradients of current operation args)
+            // st0 = exp(s0-max(s0))                                          grad[st0] = grad[st1]*(1.0 - eps)/sum
+            //   from softmax_back: grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
+            // ggml_vec_scale_f32(nc, st, sum);          // st1 = st0*/sum = softmax(s0)  grad[st1] = grad[st2]*(1.0 - eps)
+            // ggml_vec_scale_f32(nc, st, (1.0f - eps)); // st2 = st1*(1.0 - eps)         grad[st2] = grad[st3]
+            // ggml_vec_add1_f32(nc, st, st, eps);       // st3 = st2 + eps               grad[st3] = grad[st4]/st3
+            // ggml_vec_log_f32(nc, st, st);             // st4 = log(st3)                grad[st4] = grad[st5] * s1
+            // ggml_vec_mul_f32(nc, st, st, s1);         // st5 = st4 * s1                grad[st5] = grad[sums[ith]]
+            // ggml_vec_sum_f32(nc, sums + ith, st);     // sums[ith] = st5               grad[sums[ith]] = grad[cross_entropy_loss] = -grad[cel]
+
+            // substitute into grad[st1], because we can reuse softmax_back from this point on
+            // grad[st1] = -grad[cel]*s1*(1.0 - eps)/(eps + softmax(s0)*(1.0 - eps))
+            // postorder:
+            // grad[st1] := softmax(s0)
+            // grad[st1] := grad[st1]*(1.0 - eps)
+            // grad[st1] := grad[st1] + eps
+            // grad[st1] := s1 / grad[st1]
+            // grad[st1] := grad[st1]*(1.0-eps)*-grad[cel]
+
+            // src0 gradients by going through softmax_back
+            // grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
+            // from softmax_back:
+            // dxk = yk * (dyk - dot(y, dy))
+            // dot_y_dy := dot(y, dy)
+            // dx := dy
+            // dx := dx - dot_y_dy
+            // dx := dx * y
+            // postorder:
+            // dot_st1_dst1 := dot(st1, grad[st1])
+            // grad[s0] := grad[st1]
+            // grad[s0] := grad[s0] - dot_st1_dst1
+            // grad[s0] := grad[s0] * st1
+
+            // prepend postorder from grad[st1] directly using grad[s0] as memory location, as we will grad[s0] := grad[st1]
+            // sm           := softmax(s0)
+            // grad[s0]     := sm*(1.0 - eps)
+            // grad[s0]     := grad[s0] + eps
+            // grad[s0]     := s1 / grad[s0]
+            // grad[s0]     := grad[s0]*(1.0-eps)*-grad[cel]
+            // dot_st1_dst1 := dot(sm, grad[s0])
+            // grad[s0]     := grad[s0] - dot_st1_dst1
+            // grad[s0]     := grad[s0] * sm
+        }
+
+        // soft_max
+        ggml_float sum = 0.0;
+        {
+            float max = -INFINITY;
+            ggml_vec_max_f32(nc, &max, s0);
+
+            uint16_t scvt;
+            for (int i = 0; i < nc; i++) {
+                if (s0[i] == -INFINITY) {
+                    sm[i] = 0.0f;
+                } else {
+                    // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
+                    ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
+                    memcpy(&scvt, &s, sizeof(scvt));
+                    const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+                    sum += (ggml_float)val;
+                    sm[i] = val;
+                }
+            }
+
+            assert(sum > 0.0);
+            sum = 1.0/sum;
+        }
+
+        float dot_st1_dst1 = 0;
+        ggml_vec_scale_f32(nc, sm, sum);
+        ggml_vec_cpy_f32  (nc, ds0, sm);
+        ggml_vec_scale_f32(nc, ds0, (1.0f - eps));
+        ggml_vec_add1_f32 (nc, ds0, ds0, eps);
+        ggml_vec_div_f32  (nc, ds0, s1, ds0);
+        ggml_vec_scale_f32(nc, ds0, -(1.0f - eps)*d[0]);
+        ggml_vec_dot_f32  (nc, &dot_st1_dst1, sm, ds0);
+        ggml_vec_acc1_f32 (nc, ds0, -dot_st1_dst1);
+        ggml_vec_mul_f32  (nc, ds0, ds0, sm);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(sm[i]));
+            assert(!isinf(sm[i]));
+            assert(!isnan(ds0[i]));
+            assert(!isinf(ds0[i]));
+        }
+#endif
     }
 }
 
-
-static void ggml_compute_forward_map_binary(
+static void ggml_compute_forward_cross_entropy_loss_back(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        struct ggml_tensor *
-
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-
+                ggml_compute_forward_cross_entropy_loss_back_f32(params, src0, src1, opt0, dst);
             } break;
        default:
            {
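The long comment block in the backward kernel derives the logit gradient by pushing grad[st1] = -grad[cel] * s1 * (1-eps) / (eps + (1-eps)*softmax(s0)) through the standard softmax backward rule. A compact, single-row sketch of the same math in plain C is shown below; the helper name and signature are illustrative only, and the softmax is recomputed inline instead of being cached in a work buffer as the kernel above does.

```c
#include <math.h>

// Hypothetical reference: gradient of the eps-rescaled cross entropy w.r.t.
// the logits of one row. d_loss is the incoming scalar gradient (d[0] above).
static void cross_entropy_back_row(float * dlogits, const float * logits,
                                   const float * labels, float d_loss, int nc) {
    const float eps = 1e-9f;

    // sm = softmax(logits), computed with max subtraction
    float max = logits[0];
    for (int i = 1; i < nc; ++i) if (logits[i] > max) max = logits[i];
    double sum = 0.0;
    for (int i = 0; i < nc; ++i) sum += exp(logits[i] - max);

    // gradient w.r.t. the softmax output, plus the dot product needed by softmax_back
    double dot = 0.0;
    for (int i = 0; i < nc; ++i) {
        const double sm = exp(logits[i] - max) / sum;
        dlogits[i] = (float)(-d_loss * labels[i] * (1.0 - eps) / (eps + (1.0 - eps) * sm));
        dot += sm * dlogits[i];
    }
    // softmax backward: dx_i = sm_i * (dy_i - dot(sm, dy))
    for (int i = 0; i < nc; ++i) {
        const double sm = exp(logits[i] - max) / sum;
        dlogits[i] = (float)(sm * (dlogits[i] - dot));
    }
}
```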
@@ -13031,6 +14298,7 @@ static void ggml_compute_forward_map_binary(
     }
 }
 
+
 /////////////////////////////////
 
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -13102,6 +14370,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_repeat(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_REPEAT_BACK:
+            {
+                ggml_compute_forward_repeat_back(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_ABS:
             {
                 ggml_compute_forward_abs(params, tensor->src0, tensor);
@@ -13150,6 +14422,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
             } break;
+        case GGML_OP_OUT_PROD:
+            {
+                ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_SCALE:
             {
                 ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
@@ -13206,6 +14482,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_soft_max(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_SOFT_MAX_BACK:
+            {
+                ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_ROPE:
             {
                 ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
@@ -13241,6 +14521,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
             } break;
+        case GGML_OP_FLASH_ATTN_BACK:
+            {
+                int32_t t = ggml_get_i32_1d(tensor->opt[2], 0);
+                GGML_ASSERT(t == 0 || t == 1);
+                bool masked = t != 0;
+                ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
@@ -13253,6 +14540,16 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
             }
             break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            {
+                ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor);
+            }
+            break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            {
+                ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -13391,11 +14688,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 src0->grad =
                     ggml_add_impl(ctx,
                         src0->grad,
-
-                        tensor->grad, // this was not catched by test_grad because in test_grad tensor->grad is 1
+                        ggml_scale(ctx,
                             ggml_div(ctx,
-
-                                tensor)
+                                tensor->grad,
+                                tensor),
+                            ggml_new_f32(ctx, 0.5f)),
                         inplace);
             }
         } break;
@@ -13441,43 +14738,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         {
             // necessary for llama
             if (src0->grad) {
-
-
-
-
-
-
-
-
-
-                //
-
-
-
-
-                // transpose [nc0*nr0,1,1]
-                // reshape [nc0,nr0,1,1] reshape_1d or reshape_2d
-                // add to src0->grad
-
-                int64_t ne[4] = {nc0,ncr,nr0,nrr};
-
-                struct ggml_tensor* F00 = tensor->grad;
-                struct ggml_tensor* F01 = ggml_reshape   (ctx, F00, ggml_new_tensor(ctx,tensor->grad->type,4,ne));
-                struct ggml_tensor* F02 = ggml_permute   (ctx, F01, 0,2,1,3);
-                struct ggml_tensor* F03 = ggml_cont      (ctx, F02);
-                struct ggml_tensor* F04 = ggml_reshape_2d(ctx, F03, nc0*nr0, ncr*nrr);
-                struct ggml_tensor* F05 = ggml_transpose (ctx, F04);
-                struct ggml_tensor* F06 = ggml_cont      (ctx, F05);
-                struct ggml_tensor* F07 = ggml_sum_rows  (ctx, F06);
-                struct ggml_tensor* F08 = ggml_transpose (ctx, F07);
-                struct ggml_tensor* F09 = ggml_cont      (ctx, F08);
-                struct ggml_tensor* F10 = ggml_reshape   (ctx, F09, src0->grad);
-
-                src0->grad =
-                    ggml_add_impl(ctx,
-                        src0->grad,
-                        F10,
-                        inplace);
+                src0->grad = ggml_add_impl(ctx,
+                        src0->grad,
+                        ggml_repeat_back(ctx, tensor->grad, src0->grad),
+                        inplace);
+            }
+        } break;
+        case GGML_OP_REPEAT_BACK:
+        {
+            if (src0->grad) {
+                // TODO: test this
+                src0->grad = ggml_add_impl(ctx,
+                        src0->grad,
+                        ggml_repeat(ctx, tensor->grad, src0->grad),
+                        inplace);
             }
         } break;
         case GGML_OP_ABS:
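The change above replaces the old reshape/permute/sum_rows chain with the new GGML_OP_REPEAT_BACK op: the backward of repeat sums each repeated tile of the incoming gradient back into the source shape, and (tentatively, per the TODO) the backward of repeat_back is repeat itself. A tiny 1-D illustration of that duality, with illustrative helper names that are not part of ggml:

```c
// Hypothetical 1-D sketch of the repeat / repeat_back pair:
// forward repeats src k times; backward sums the k gradient copies.
static void repeat_1d(float * dst, const float * src, int n, int k) {
    for (int r = 0; r < k; ++r)
        for (int i = 0; i < n; ++i)
            dst[r*n + i] = src[i];
}

static void repeat_back_1d(float * dsrc, const float * ddst, int n, int k) {
    for (int i = 0; i < n; ++i) dsrc[i] = 0.0f;
    for (int r = 0; r < k; ++r)
        for (int i = 0; i < n; ++i)
            dsrc[i] += ddst[r*n + i];
}
```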
@@ -13584,38 +14858,37 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
 
             // necessary for llama
             if (src0->grad) {
-                // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
                 src0->grad =
                     ggml_add_impl(ctx,
                         src0->grad,
-                        //
-
-
-                        // tensor->grad),        // [m,p]
-                        // for now just using A*B==(B.T*A.T).T
-                        ggml_cont(ctx,                      // [n,m]
-                            ggml_transpose(ctx,             // [n,m]
-                                ggml_mul_mat(ctx,           // [m,n]
-                                    ggml_cont(ctx,          // [p,m]
-                                        ggml_transpose(ctx, // [p,m]
-                                            tensor->grad)), // [m,p]
-                                    ggml_cont(ctx,          // [p,n]
-                                        ggml_transpose(ctx, // [p,n]
-                                            src1))))),      // [n,p]
+                        ggml_out_prod(ctx, // [n,m]
+                            src1,          // [n,p]
+                            tensor->grad), // [m,p]
                         inplace);
             }
             if (src1->grad) {
                 src1->grad =
                     ggml_add_impl(ctx,
                         src1->grad,
-                        //
-
-
-
-
+                        // ggml_mul_mat(ctx,                   // [n,p]
+                        //     ggml_cont(ctx,                  // [m,n]
+                        //         ggml_transpose(ctx, src0)), // [m,n]
+                        //     tensor->grad),                  // [m,p]
+
+                        // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
+                        // // avoid transpose of src0, rather transpose smaller tensor->grad
+                        // // and then use ggml_out_prod
+                        ggml_out_prod(ctx,      // [n,p]
+                            src0,               // [n,m]
+                            ggml_transpose(ctx, // [p,m]
+                                tensor->grad)), // [m,p]
                         inplace);
             }
         } break;
+        case GGML_OP_OUT_PROD:
+        {
+            GGML_ASSERT(false); // TODO: not implemented
+        } break;
         case GGML_OP_SCALE:
         {
             // necessary for llama
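The shape comments above use ggml's convention that the shared dimension comes first: src0 is [n,m], src1 is [n,p] and the mul_mat result is [m,p]. Written as ordinary matrices with A of size k x m and B of size k x p and C = A^T B, the gradients are dA = B dC^T and dB = A dC, which is exactly what the two ggml_out_prod calls express without materializing a transpose of the large operand. A hedged dense reference of that algebra (illustrative helper, not ggml code):

```c
// Hypothetical dense reference for the mul_mat backward above.
// A is k x m, B is k x p, C = A^T * B is m x p, all row-major.
static void matmul_backward(const float * A, const float * B, const float * dC,
                            float * dA, float * dB, int k, int m, int p) {
    // dA[a][i] = sum_j B[a][j] * dC[i][j]   (dA = B * dC^T)
    for (int a = 0; a < k; ++a)
        for (int i = 0; i < m; ++i) {
            float acc = 0.0f;
            for (int j = 0; j < p; ++j) acc += B[a*p + j] * dC[i*p + j];
            dA[a*m + i] = acc;
        }
    // dB[a][j] = sum_i A[a][i] * dC[i][j]   (dB = A * dC)
    for (int a = 0; a < k; ++a)
        for (int j = 0; j < p; ++j) {
            float acc = 0.0f;
            for (int i = 0; i < m; ++i) acc += A[a*m + i] * dC[i*p + j];
            dB[a*p + j] = acc;
        }
}
```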
@@ -13717,7 +14990,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             // necessary for llama
             if (src0->grad) {
                 size_t offset;
-
+
+                GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->opt[0]));
+                memcpy(&offset, tensor->opt[0]->data, sizeof(offset));
 
                 size_t nb1 = tensor->nb[1];
                 size_t nb2 = tensor->nb[2];
@@ -13744,10 +15019,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         {
             // necessary for llama
             if (src0->grad) {
-
-                int
-                int
-                int
+                int32_t * axes = (int32_t *) tensor->opt[0]->data;
+                int axis0 = axes[0] & 0x3;
+                int axis1 = axes[1] & 0x3;
+                int axis2 = axes[2] & 0x3;
+                int axis3 = axes[3] & 0x3;
                 int axes_backward[4] = {0,0,0,0};
                 axes_backward[axis0] = 0;
                 axes_backward[axis1] = 1;
@@ -13831,50 +15107,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         {
             // necessary for llama
             if (src0->grad) {
-                // y = softmax(x)
-                //
-                // Jii = yi - yi*yi
-                // Jij = -yi*yj
-                // J = diag(y)-y.*y
-                // dx = J * dy
-                // dxk = sum(Jkj * dyk)
-
-                int64_t ne2[4] = {
-                    tensor->ne[0],
-                    1,
-                    tensor->ne[1]*tensor->ne[2],
-                    tensor->ne[3]
-                };
-                struct ggml_tensor * tensor2 = ggml_cont(ctx,
-                    ggml_reshape_4d(ctx,
-                        ggml_cont(ctx, tensor),
-                        ne2[0], ne2[1], ne2[2], ne2[3]));
-
-                struct ggml_tensor * grad2 = ggml_cont(ctx,
-                    ggml_reshape_4d(ctx,
-                        ggml_cont(ctx, tensor->grad),
-                        ne2[0], ne2[1], ne2[2], ne2[3]));
-
-                struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
-                    ggml_permute(ctx,                           // [1,ne0,ne1*ne2,ne3]
-                        tensor2,                                // [ne0,1,ne1*ne2,ne3]
-                        1, 0, 2, 3));
-
                 src0->grad =
-                    ggml_add_impl(ctx,
-
-
-                        ggml_mul_mat(ctx,             // [ne0,1,ne1*ne2,ne3]
-                            ggml_sub(ctx,             // [ne0,ne0,ne1*ne2,ne3]
-                                ggml_diag(ctx,        // [ne0,ne0,ne1*ne2,ne3]
-                                    tensor2),         // [ne0,1,ne1*ne2,ne3]
-                                ggml_mul_mat(ctx,     // [ne0,ne0,ne1*ne2,ne3]
-                                    tensor2_t,        // [1,ne0,ne1*ne2,ne3]
-                                    tensor2_t)),      // [1,ne0,ne1*ne2,ne3]
-                            grad2),                   // [ne0,1,ne1*ne2,ne3]
-                        src0->grad),
-                        inplace);
+                    ggml_add_impl(ctx, src0->grad,
+                        ggml_soft_max_back(ctx, tensor->grad, tensor),
+                        inplace);
             }
+
+        } break;
+        case GGML_OP_SOFT_MAX_BACK:
+        {
+            GGML_ASSERT(false); // TODO: not implemented
         } break;
         case GGML_OP_ROPE:
         {
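The removed code built the full softmax Jacobian diag(y) - y*y^T and multiplied it by the gradient; the replacement ggml_soft_max_back op uses the standard contraction that never forms the Jacobian: dx_i = y_i * (dy_i - dot(y, dy)). A minimal per-row sketch of that rule (illustrative helper name, not part of ggml):

```c
// Hypothetical reference for soft_max_back on one row of length n:
// given y = softmax(x) and the incoming gradient dy,
// dx_i = y_i * (dy_i - dot(y, dy)).
static void soft_max_back_row(float * dx, const float * y, const float * dy, int n) {
    float dot = 0.0f;
    for (int i = 0; i < n; ++i) dot += y[i]*dy[i];
    for (int i = 0; i < n; ++i) dx[i] = y[i]*(dy[i] - dot);
}
```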
@@ -13929,17 +15171,190 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         } break;
         case GGML_OP_FLASH_ATTN:
         {
-
+            struct ggml_tensor * flash_grad = NULL;
+            if (src0->grad || src1->grad || tensor->opt[0]->grad) {
+                int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
+                GGML_ASSERT(t == 0 || t == 1);
+                bool masked = t != 0;
+                flash_grad =
+                    ggml_flash_attn_back(ctx,
+                        src0,
+                        src1,
+                        tensor->opt[0],
+                        tensor->grad,
+                        masked);
+            }
+
+            if (src0->grad) {
+                struct ggml_tensor * grad_q = NULL;
+                const size_t nb0    = flash_grad->nb[0];
+                const size_t offset = 0;
+                switch(src0->n_dims) {
+                    case 2:
+                        {
+                            grad_q = ggml_view_2d(ctx,
+                                flash_grad,
+                                src0->ne[0],
+                                src0->ne[1],
+                                nb0*src0->ne[0],
+                                offset);
+                        } break;
+                    case 3:
+                        {
+                            grad_q = ggml_view_3d(ctx,
+                                flash_grad,
+                                src0->ne[0],
+                                src0->ne[1],
+                                src0->ne[2],
+                                nb0*src0->ne[0],
+                                nb0*src0->ne[0]*src0->ne[1],
+                                offset);
+                        } break;
+                    case 4:
+                        {
+                            grad_q = ggml_view_4d(ctx,
+                                flash_grad,
+                                src0->ne[0],
+                                src0->ne[1],
+                                src0->ne[2],
+                                src0->ne[3],
+                                nb0*src0->ne[0],
+                                nb0*src0->ne[0]*src0->ne[1],
+                                nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
+                                offset);
+                        } break;
+                }
+
+                src0->grad = ggml_add_impl(ctx,
+                        src0->grad,
+                        grad_q,
+                        inplace);
+            }
+
+            if (src1->grad) {
+                struct ggml_tensor * grad_k = NULL;
+                const size_t nb0    = flash_grad->nb[0];
+                const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
+                switch(src1->n_dims) {
+                    case 2:
+                        {
+                            grad_k = ggml_view_2d(ctx,
+                                flash_grad,
+                                src1->ne[0],
+                                src1->ne[1],
+                                nb0*src1->ne[0],
+                                offset);
+                        } break;
+                    case 3:
+                        {
+                            grad_k = ggml_view_3d(ctx,
+                                flash_grad,
+                                src1->ne[0],
+                                src1->ne[1],
+                                src1->ne[2],
+                                nb0*src1->ne[0],
+                                nb0*src1->ne[0]*src1->ne[1],
+                                offset);
+                        } break;
+                    case 4:
+                        {
+                            grad_k = ggml_view_4d(ctx,
+                                flash_grad,
+                                src1->ne[0],
+                                src1->ne[1],
+                                src1->ne[2],
+                                src1->ne[3],
+                                nb0*src1->ne[0],
+                                nb0*src1->ne[0]*src1->ne[1],
+                                nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
+                                offset);
+                        } break;
+                }
+
+                src1->grad = ggml_add_impl(ctx,
+                        src1->grad,
+                        grad_k,
+                        inplace);
+            }
+
+            struct ggml_tensor * opt0 = tensor->opt[0];
+
+            if (opt0->grad) {
+                struct ggml_tensor * grad_v = NULL;
+                const size_t nb0    = flash_grad->nb[0];
+                const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
+                                    + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
+                switch(opt0->n_dims) {
+                    case 2:
+                        {
+                            grad_v = ggml_view_2d(ctx,
+                                flash_grad,
+                                opt0->ne[0],
+                                opt0->ne[1],
+                                nb0*opt0->ne[0],
+                                offset);
+                        } break;
+                    case 3:
+                        {
+                            grad_v = ggml_view_3d(ctx,
+                                flash_grad,
+                                opt0->ne[0],
+                                opt0->ne[1],
+                                opt0->ne[2],
+                                nb0*opt0->ne[0],
+                                nb0*opt0->ne[0]*opt0->ne[1],
+                                offset);
+                        } break;
+                    case 4:
+                        {
+                            grad_v = ggml_view_4d(ctx,
+                                flash_grad,
+                                opt0->ne[0],
+                                opt0->ne[1],
+                                opt0->ne[2],
+                                opt0->ne[3],
+                                nb0*opt0->ne[0],
+                                nb0*opt0->ne[0]*opt0->ne[1],
+                                nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
+                                offset);
+                        } break;
+                }
+
+                opt0->grad = ggml_add_impl(ctx,
+                        opt0->grad,
+                        grad_v,
+                        inplace);
+            }
         } break;
         case GGML_OP_FLASH_FF:
         {
             GGML_ASSERT(false); // not supported
         } break;
+        case GGML_OP_FLASH_ATTN_BACK:
+        {
+            GGML_ASSERT(false); // not supported
+        } break;
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         {
             GGML_ASSERT(false); // not supported
         } break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+        {
+            if (src0->grad) {
+                src0->grad = ggml_add_impl(ctx,
+                    src0->grad,
+                    ggml_cross_entropy_loss_back(ctx,
+                        src0,
+                        src1,
+                        tensor->grad),
+                    inplace);
+            }
+        } break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+        {
+            GGML_ASSERT(false); // not supported
+        } break;
         case GGML_OP_NONE:
         {
             // nop
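The flash-attention backward op returns dQ, dK and dV concatenated in one contiguous tensor; the code above carves the three gradients out with ggml_view_2d/3d/4d at byte offsets 0, |Q| and |Q|+|K|. A small sketch of just that offset arithmetic, with a hypothetical struct and helper name used only for illustration:

```c
#include <stddef.h>
#include <stdint.h>

// Hypothetical illustration of how the concatenated flash_attn_back result is
// sliced: dq starts at byte 0, dk after all of q, dv after q and k.
struct grad_offsets { size_t dq, dk, dv; };

static struct grad_offsets flash_attn_grad_offsets(
        size_t nb0,                                   // element size in bytes
        const int64_t q_ne[4], const int64_t k_ne[4]) {
    struct grad_offsets o;
    o.dq = 0;
    o.dk = nb0 * q_ne[0]*q_ne[1]*q_ne[2]*q_ne[3];
    o.dv = o.dk + nb0 * k_ne[0]*k_ne[1]*k_ne[2]*k_ne[3];
    return o;
}
```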
@@ -14316,6 +15731,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
             case GGML_OP_SUM_ROWS:
             case GGML_OP_MEAN:
             case GGML_OP_REPEAT:
+            case GGML_OP_REPEAT_BACK:
             case GGML_OP_ABS:
             case GGML_OP_SGN:
             case GGML_OP_NEG:
@@ -14335,6 +15751,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     node->n_tasks = n_threads;
                 } break;
             case GGML_OP_MUL_MAT:
+            case GGML_OP_OUT_PROD:
                 {
                     node->n_tasks = n_threads;
 
@@ -14417,6 +15834,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 } break;
             case GGML_OP_DIAG_MASK_INF:
             case GGML_OP_SOFT_MAX:
+            case GGML_OP_SOFT_MAX_BACK:
             case GGML_OP_ROPE:
             case GGML_OP_ROPE_BACK:
                 {
@@ -14496,6 +15914,27 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
                     }
 
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_FLASH_ATTN_BACK:
+                {
+                    node->n_tasks = n_threads;
+
+                    size_t cur = 0;
+
+                    const int64_t    D = node->src0->ne[0];
+                    const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+                    const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
+                    if (node->src1->type == GGML_TYPE_F32) {
+                        cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
+                        cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
+                    }
+
+                    if (node->src1->type == GGML_TYPE_F16) {
+                        cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
+                        cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
+                    }
+
                     work_size = MAX(work_size, cur);
                 } break;
             case GGML_OP_MAP_UNARY:
@@ -14503,6 +15942,22 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 {
                     node->n_tasks = 1;
                 } break;
+            case GGML_OP_CROSS_ENTROPY_LOSS:
+                {
+                    node->n_tasks = n_threads;
+
+                    size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
+
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+                {
+                    node->n_tasks = n_threads;
+
+                    size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
+
+                    work_size = MAX(work_size, cur);
+                } break;
             case GGML_OP_NONE:
                 {
                     node->n_tasks = 1;
@@ -15478,6 +16933,7 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
 
 static enum ggml_opt_result ggml_opt_adam(
         struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
         struct ggml_opt_params params,
         struct ggml_tensor * f,
         struct ggml_cgraph * gf,
@@ -15503,25 +16959,29 @@ static enum ggml_opt_result ggml_opt_adam(
         }
     }
 
+    if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) {
+        int iter = opt->iter;
+        ggml_opt_init(opt->ctx, opt, params, nx);
+        opt->iter = iter;
+    }
+
     // constants
-    const float
+    const float sched = params.adam.sched;
+    const float decay = params.adam.decay * sched;
+    const float alpha = params.adam.alpha * sched;
     const float beta1 = params.adam.beta1;
     const float beta2 = params.adam.beta2;
     const float eps   = params.adam.eps;
 
-    float * x  =
-    float * g1 =
-    float * g2 =
-    float * m  =
-    float * v  =
-    float * mh =
-    float * vh =
+    float * x  = opt->adam.x->data;  // view of the parameters
+    float * g1 = opt->adam.g1->data; // gradient
+    float * g2 = opt->adam.g2->data; // gradient squared
+    float * m  = opt->adam.m->data;  // first moment
+    float * v  = opt->adam.v->data;  // second moment
+    float * mh = opt->adam.mh->data; // first moment hat
+    float * vh = opt->adam.vh->data; // second moment hat
 
-    float * pf = params.past > 0 ?
-
-    // initialize
-    ggml_vec_set_f32(nx, m, 0.0f);
-    ggml_vec_set_f32(nx, v, 0.0f);
+    float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
 
     // update view
     ggml_opt_get_params(np, ps, x);
@@ -15531,16 +16991,27 @@ static enum ggml_opt_result ggml_opt_adam(
     ggml_set_f32      (f->grad, 1.0f);
     ggml_graph_compute(ctx, gb);
 
-
+    opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
+    opt->adam.fx_best = opt->adam.fx_prev;
     if (pf) {
-        pf[
+        pf[opt->iter % params.past] = opt->adam.fx_prev;
     }
 
-
-
+    // initialize
+    if (opt->just_initialized) {
+        opt->adam.n_no_improvement = 0;
+        opt->just_initialized = false;
+    }
+
+    float * fx_best = &opt->adam.fx_best;
+    float * fx_prev = &opt->adam.fx_prev;
+    int * n_no_improvement = &opt->adam.n_no_improvement;
+
+    int iter0 = opt->iter;
 
     // run the optimizer
     for (int t = 0; t < params.adam.n_iter; ++t) {
+        opt->iter = iter0 + t + 1;
         GGML_PRINT_DEBUG  ("=== iter %d ===\n", t);
 
         GGML_PRINT_DEBUG  ("f      = %10.6f\n", ggml_get_f32_1d(f, 0));
@@ -15574,17 +17045,22 @@ static enum ggml_opt_result ggml_opt_adam(
 
         // m^hat = m_t / (1 - beta1^t)
         // v^hat = v_t / (1 - beta2^t)
-        // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps)
+        // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1)
+        // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1
+        // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
+        // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps)
+        // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay)
         ggml_vec_cpy_f32  (nx, mh, m);
         ggml_vec_cpy_f32  (nx, vh, v);
 
-        ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1,
-        ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2,
+        ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter)));
+        ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, opt->iter)));
 
         ggml_vec_sqrt_f32 (nx, vh, vh);
         ggml_vec_acc1_f32 (nx, vh, eps);
 
         ggml_vec_div_f32  (nx, mh, mh, vh);
+        ggml_vec_scale_f32(nx, x,  1.0f - decay);
         ggml_vec_sub_f32  (nx, x,  x,  mh);
 
         // update the parameters
@@ -15598,7 +17074,7 @@ static enum ggml_opt_result ggml_opt_adam(
         const float fx = ggml_get_f32_1d(f, 0);
 
         // check convergence
-        if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) {
+        if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
             GGML_PRINT_DEBUG("converged\n");
 
             return GGML_OPT_OK;
@@ -15607,32 +17083,32 @@ static enum ggml_opt_result ggml_opt_adam(
         // delta-based convergence test
         if (pf != NULL) {
             // need at least params.past iterations to start checking for convergence
-            if (params.past <= t) {
-                const float rate = (pf[t%params.past] - fx)/fx;
+            if (params.past <= iter0 + t) {
+                const float rate = (pf[(iter0 + t)%params.past] - fx)/fx;
 
                 if (fabsf(rate) < params.delta) {
                     return GGML_OPT_OK;
                 }
             }
 
-            pf[t%params.past] = fx;
+            pf[(iter0 + t)%params.past] = fx;
         }
 
         // check for improvement
         if (params.max_no_improvement > 0) {
-            if (fx_best > fx) {
-                fx_best = fx;
-                n_no_improvement = 0;
+            if (fx_best[0] > fx) {
+                fx_best[0] = fx;
+                n_no_improvement[0] = 0;
             } else {
-                ++n_no_improvement;
+                ++n_no_improvement[0];
 
-                if (n_no_improvement >= params.max_no_improvement) {
+                if (n_no_improvement[0] >= params.max_no_improvement) {
                     return GGML_OPT_OK;
                 }
             }
         }
 
-        fx_prev = fx;
+        fx_prev[0] = fx;
 
         {
             const int64_t t_end_cpu = ggml_cycles();
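The comment chain in the hunk above rewrites the Adam step so that the new `sched` factor scales both the learning rate and the new decoupled weight decay, and the decay is applied by scaling x once before subtracting the Adam update. A scalar sketch of the per-parameter update, matching that derivation; the helper name is illustrative and m, v are assumed to hold the already-updated moments:

```c
#include <math.h>

// Hypothetical scalar version of the update performed above for one parameter.
static float adam_decay_step(float x, float m, float v, int iter,
                             float alpha, float beta1, float beta2,
                             float eps, float decay, float sched) {
    const float a  = alpha * sched;                            // scheduled learning rate
    const float wd = decay * sched;                            // scheduled weight decay
    const float mh = m / (1.0f - powf(beta1, (float) iter));   // bias-corrected first moment
    const float vh = v / (1.0f - powf(beta2, (float) iter));   // bias-corrected second moment
    x *= (1.0f - wd);                                          // decoupled weight decay
    return x - a * mh / (sqrtf(vh) + eps);                     // Adam step
}
```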
@@ -15771,6 +17247,7 @@ static enum ggml_opt_result linesearch_backtracking(
 
 static enum ggml_opt_result ggml_opt_lbfgs(
         struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
         struct ggml_opt_params params,
         struct ggml_tensor * f,
         struct ggml_cgraph * gf,
@@ -15803,31 +17280,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         }
     }
 
-
-
-
-
-
+    if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) {
+        int iter = opt->iter;
+        ggml_opt_init(ctx, opt, params, nx);
+        opt->iter = iter;
+    }
+
+    float * x  = opt->lbfgs.x->data;  // current parameters
+    float * xp = opt->lbfgs.xp->data; // previous parameters
+    float * g  = opt->lbfgs.g->data;  // current gradient
+    float * gp = opt->lbfgs.gp->data; // previous gradient
+    float * d  = opt->lbfgs.d->data;  // search direction
 
-    float * pf = params.past > 0 ?
+    float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
 
     float fx    = 0.0f; // cost function value
     float xnorm = 0.0f; // ||x||
     float gnorm = 0.0f; // ||g||
-    float step = 0.0f;
 
     // initialize x from the graph nodes
     ggml_opt_get_params(np, ps, x);
 
     // the L-BFGS memory
-
-
-
-
-        lm[i].ys = 0.0f;
-        lm[i].s  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
-        lm[i].y  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
-    }
+    float * lm_alpha = opt->lbfgs.lmal->data;
+    float * lm_ys    = opt->lbfgs.lmys->data;
+    float * lm_s     = opt->lbfgs.lms->data;
+    float * lm_y     = opt->lbfgs.lmy->data;
 
     // evaluate the function value and its gradient
     {
@@ -15842,12 +17320,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         fx = ggml_get_f32_1d(f, 0);
     }
 
-    if (pf) {
-        pf[0] = fx;
-    }
-
-    float fx_best = fx;
-
     // search direction = -gradient
     ggml_vec_neg_f32(nx, d, g);
 
@@ -15864,26 +17336,43 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         return GGML_OPT_OK;
     }
 
-
-
+    if (opt->just_initialized) {
+        if (pf) {
+            pf[0] = fx;
+        }
+        opt->lbfgs.fx_best = fx;
+
+        // initial step
+        ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d);
+        opt->lbfgs.j                = 0;
+        opt->lbfgs.k                = 1;
+        opt->lbfgs.end              = 0;
+        opt->lbfgs.n_no_improvement = 0;
+        opt->just_initialized = false;
+    }
+
+    float * fx_best        = &opt->lbfgs.fx_best;
+    float * step           = &opt->lbfgs.step;
+    int * j                = &opt->lbfgs.j;
+    int * k                = &opt->lbfgs.k;
+    int * end              = &opt->lbfgs.end;
+    int * n_no_improvement = &opt->lbfgs.n_no_improvement;
 
-    int
-    int
-    int ls = 0;
-    int end = 0;
-    int bound = 0;
-    int n_no_improvement = 0;
+    int ls     = 0;
+    int bound  = 0;
 
     float ys   = 0.0f;
     float yy   = 0.0f;
     float beta = 0.0f;
 
+    int it = 0;
+
     while (true) {
         // store the current position and gradient vectors
         ggml_vec_cpy_f32(nx, xp, x);
         ggml_vec_cpy_f32(nx, gp, g);
 
-        ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d,
+        ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
 
         if (ls < 0) {
             // linesearch failed - go back to the previous point and return
@@ -15909,32 +17398,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         // delta-based convergence test
         if (pf != NULL) {
             // need at least params.past iterations to start checking for convergence
-            if (params.past <= k) {
-                const float rate = (pf[k%params.past] - fx)/fx;
+            if (params.past <= k[0]) {
+                const float rate = (pf[k[0]%params.past] - fx)/fx;
 
                 if (fabsf(rate) < params.delta) {
                     return GGML_OPT_OK;
                 }
             }
 
-            pf[k%params.past] = fx;
+            pf[k[0]%params.past] = fx;
         }
 
         // check for improvement
         if (params.max_no_improvement > 0) {
-            if (fx < fx_best) {
-                fx_best = fx;
-                n_no_improvement = 0;
+            if (fx < fx_best[0]) {
+                fx_best[0] = fx;
+                n_no_improvement[0] = 0;
             } else {
-                n_no_improvement++;
+                n_no_improvement[0]++;
 
-                if (n_no_improvement >= params.max_no_improvement) {
+                if (n_no_improvement[0] >= params.max_no_improvement) {
                     return GGML_OPT_OK;
                 }
             }
         }
 
-        if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter <
+        if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) {
             // reached the maximum number of iterations
             return GGML_OPT_DID_NOT_CONVERGE;
         }
@@ -15943,50 +17432,51 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}.
         // y_{k+1} = g_{k+1} - g_{k}.
         //
-        ggml_vec_sub_f32(nx,
-        ggml_vec_sub_f32(nx,
+        ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp);
+        ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp);
 
         // compute scalars ys and yy:
         //     ys = y^t \cdot s    -> 1 / \rho.
         //     yy = y^t \cdot y.
         //
-        ggml_vec_dot_f32(nx, &ys,
-        ggml_vec_dot_f32(nx, &yy,
+        ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0] *nx]);
+        ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);
 
-
+        lm_ys[end[0]] = ys;
 
         // find new search direction
         //   ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS
 
-        bound = (m <= k) ? m : k;
-        k++;
-
+        bound = (m <= k[0]) ? m : k[0];
+        k[0]++;
+        it++;
+        end[0] = (end[0] + 1)%m;
 
         // initialize search direction with -g
         ggml_vec_neg_f32(nx, d, g);
 
-        j = end;
+        j[0] = end[0];
         for (int i = 0; i < bound; ++i) {
-            j = (j + m - 1) % m;
+            j[0] = (j[0] + m - 1) % m;
             // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
-            ggml_vec_dot_f32(nx, &
-
+            ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d);
+            lm_alpha[j[0]] /= lm_ys[j[0]];
             // q_{i} = q_{i+1} - \alpha_{i} y_{i}
-            ggml_vec_mad_f32(nx, d,
+            ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
         }
 
         ggml_vec_scale_f32(nx, d, ys/yy);
 
         for (int i = 0; i < bound; ++i) {
             // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
-            ggml_vec_dot_f32(nx, &beta,
-            beta /=
+            ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d);
+            beta /= lm_ys[j[0]];
            // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
-            ggml_vec_mad_f32(nx, d,
-            j = (j + 1)%m;
+            ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
+            j[0] = (j[0] + 1)%m;
         }
 
-        step = 1.0;
+        step[0] = 1.0;
     }
 
     return GGML_OPT_DID_NOT_CONVERGE;
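The rewrite above replaces the per-call `lm[i]` history with flat lm_s/lm_y/lm_ys/lm_alpha buffers stored in the optimizer context, so the ring-buffer indices (j, k, end) and the step length survive across calls. For reference, a compact sketch of the two-loop recursion that this layout supports; the function and variable names are illustrative, with s and y holding `bound` history vectors of length nx (newest last) and d starting as the negated gradient:

```c
// Hypothetical sketch of the L-BFGS two-loop recursion over flat history arrays.
static void lbfgs_two_loop(float * d, const float * s, const float * y,
                           const float * ys, float * alpha, int bound, int nx) {
    for (int i = bound - 1; i >= 0; --i) {
        float dot = 0.0f;
        for (int t = 0; t < nx; ++t) dot += s[i*nx + t] * d[t];
        alpha[i] = dot / ys[i];
        for (int t = 0; t < nx; ++t) d[t] -= alpha[i] * y[i*nx + t];
    }
    // scale by gamma = ys/yy of the most recent pair, as in the code above
    if (bound > 0) {
        float yy = 0.0f;
        for (int t = 0; t < nx; ++t) yy += y[(bound-1)*nx + t] * y[(bound-1)*nx + t];
        const float gamma = ys[bound-1] / yy;
        for (int t = 0; t < nx; ++t) d[t] *= gamma;
    }
    for (int i = 0; i < bound; ++i) {
        float beta = 0.0f;
        for (int t = 0; t < nx; ++t) beta += y[i*nx + t] * d[t];
        beta /= ys[i];
        for (int t = 0; t < nx; ++t) d[t] += (alpha[i] - beta) * s[i*nx + t];
    }
}
```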
@@ -16011,6 +17501,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
 
             .adam = {
                 .n_iter = 10000,
+                .sched  = 1.000f,
+                .decay  = 0.001f,
                 .alpha  = 0.001f,
                 .beta1  = 0.9f,
                 .beta2  = 0.999f,
@@ -16053,6 +17545,71 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
     return result;
 }
 
+GGML_API void ggml_opt_init(
+        struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
+        struct ggml_opt_params params,
+        int64_t nx) {
+    opt->ctx = ctx;
+    opt->params = params;
+    opt->iter = 0;
+    opt->nx = nx;
+    opt->just_initialized = true;
+    switch (opt->params.type) {
+        case GGML_OPT_ADAM:
+            {
+                opt->adam.x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.m  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.v  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.pf = params.past > 0
+                    ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+                    : NULL;
+                ggml_set_zero(opt->adam.x);
+                ggml_set_zero(opt->adam.g1);
+                ggml_set_zero(opt->adam.g2);
+                ggml_set_zero(opt->adam.m);
+                ggml_set_zero(opt->adam.v);
+                ggml_set_zero(opt->adam.mh);
+                ggml_set_zero(opt->adam.vh);
+                if (opt->adam.pf) {
+                    ggml_set_zero(opt->adam.pf);
+                }
+            } break;
+        case GGML_OPT_LBFGS:
+            {
+                opt->lbfgs.x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.g  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.d  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.pf = params.past > 0
+                    ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+                    : NULL;
+                opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
+                opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
+                opt->lbfgs.lms  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+                opt->lbfgs.lmy  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+                ggml_set_zero(opt->lbfgs.x);
+                ggml_set_zero(opt->lbfgs.xp);
+                ggml_set_zero(opt->lbfgs.g);
+                ggml_set_zero(opt->lbfgs.gp);
+                ggml_set_zero(opt->lbfgs.d);
+                ggml_set_zero(opt->lbfgs.pf);
+                if (opt->lbfgs.pf) {
+                    ggml_set_zero(opt->lbfgs.pf);
+                }
+                ggml_set_zero(opt->lbfgs.lmal);
+                ggml_set_zero(opt->lbfgs.lmys);
+                ggml_set_zero(opt->lbfgs.lms);
+                ggml_set_zero(opt->lbfgs.lmy);
+            } break;
+    }
+}
+
 enum ggml_opt_result ggml_opt(
         struct ggml_context * ctx,
         struct ggml_opt_params params,
@@ -16075,33 +17632,65 @@ enum ggml_opt_result ggml_opt(
 
     enum ggml_opt_result result = GGML_OPT_OK;
 
+    struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
+
+    ggml_opt_init(ctx, opt, params, 0);
+    result = ggml_opt_resume(ctx, opt, f);
+
+    if (free_ctx) {
+        ggml_free(ctx);
+    }
+
+    return result;
+}
+
+enum ggml_opt_result ggml_opt_resume(
+        struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
+        struct ggml_tensor * f) {
+
+    // build forward + backward compute graphs
+    struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
+    struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
+
+    struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
+    struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
+
+    *gf = ggml_build_forward (f);
+    *gb = ggml_build_backward(ctx, gf, true);
+
+    return ggml_opt_resume_g(ctx, opt, f, gf, gb);
+}
+
+enum ggml_opt_result ggml_opt_resume_g(
+        struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
+        struct ggml_tensor * f,
+        struct ggml_cgraph * gf,
+        struct ggml_cgraph * gb) {
+
     // build forward + backward compute graphs
-
-    struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, true);
+    enum ggml_opt_result result = GGML_OPT_OK;
 
-    switch (params.type) {
+    switch (opt->params.type) {
         case GGML_OPT_ADAM:
             {
-                result = ggml_opt_adam(ctx, params, f,
+                result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
             } break;
         case GGML_OPT_LBFGS:
             {
-                result = ggml_opt_lbfgs(ctx, params, f,
+                result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb);
             } break;
     }
 
-    if (params.print_forward_graph) {
-        ggml_graph_print   (
-        ggml_graph_dump_dot(
-    }
-
-    if (params.print_backward_graph) {
-        ggml_graph_print   (&gb);
-        ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot");
+    if (opt->params.print_forward_graph) {
+        ggml_graph_print   (gf);
+        ggml_graph_dump_dot(gf, NULL, "opt-forward.dot");
     }
 
-    if (
-
+    if (opt->params.print_backward_graph) {
+        ggml_graph_print   (gb);
+        ggml_graph_dump_dot(gb, gf, "opt-backward.dot");
     }
 
     return result;
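Splitting the old ggml_opt body into ggml_opt_init, ggml_opt_resume and ggml_opt_resume_g lets a caller keep one ggml_opt_context alive and continue the Adam moments, L-BFGS history and iteration counter across calls. A hedged usage sketch under the assumption that a scalar loss tensor `f` has already been built in a context `ctx`, with `nx` trainable parameters; the loop variables are illustrative only:

```c
// Hypothetical caller of the resumable optimizer API added above.
struct ggml_opt_context opt;
struct ggml_opt_params params = ggml_opt_default_params(GGML_OPT_ADAM);
params.adam.n_iter = 16;                  // a few iterations per call

ggml_opt_init(ctx, &opt, params, nx);     // allocate moments/history once

for (int epoch = 0; epoch < n_epochs; ++epoch) {
    // ... load the next batch into the graph's input tensors ...
    enum ggml_opt_result res = ggml_opt_resume(ctx, &opt, f);
    if (res != GGML_OPT_OK && res != GGML_OPT_DID_NOT_CONVERGE) {
        break;                            // optimizer state persists inside `opt`
    }
}
```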
@@ -16301,6 +17890,18 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 result = ggml_quantize_q6_K(src + start, block, n, n, hist);
             } break;
 #endif
+        case GGML_TYPE_F16:
+            {
+                int elemsize = sizeof(ggml_fp16_t);
+                ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
+                result = n * elemsize;
+            } break;
+        case GGML_TYPE_F32:
+            {
+                int elemsize = sizeof(float);
+                result = n * elemsize;
+                memcpy((uint8_t *)dst + start * elemsize, src + start, result);
+            } break;
         default:
             assert(false);
     }
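With these two cases, ggml_quantize_chunk also handles the non-quantized target types: F16 converts the chunk row-wise and reports n * sizeof(ggml_fp16_t) bytes written, while F32 is a plain memcpy. An element-wise sketch of what the F16 branch amounts to, using ggml's real GGML_FP32_TO_FP16 macro but with an illustrative helper name:

```c
// Hypothetical element-wise equivalent of the F16 branch above:
// convert n floats starting at `start` and return the bytes written.
static size_t quantize_chunk_f16_ref(const float * src, ggml_fp16_t * dst,
                                     int start, int n) {
    for (int i = 0; i < n; ++i) {
        dst[start + i] = GGML_FP32_TO_FP16(src[start + i]);
    }
    return (size_t) n * sizeof(ggml_fp16_t);
}
```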