llama_cpp 0.2.0 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -0
- data/examples/README.md +60 -0
- data/examples/chat.rb +195 -0
- data/ext/llama_cpp/llama_cpp.cpp +52 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +697 -130
- data/ext/llama_cpp/src/ggml-cuda.h +4 -1
- data/ext/llama_cpp/src/ggml-metal.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.m +548 -497
- data/ext/llama_cpp/src/ggml-metal.metal +425 -122
- data/ext/llama_cpp/src/ggml-opencl.cpp +3 -32
- data/ext/llama_cpp/src/ggml-opencl.h +1 -2
- data/ext/llama_cpp/src/ggml.c +1904 -303
- data/ext/llama_cpp/src/ggml.h +126 -2
- data/ext/llama_cpp/src/llama.cpp +212 -108
- data/ext/llama_cpp/src/llama.h +12 -3
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +3 -0
- metadata +4 -2
data/ext/llama_cpp/src/ggml.c
CHANGED
@@ -3603,6 +3603,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|
3603
3603
|
"SUM_ROWS",
|
3604
3604
|
"MEAN",
|
3605
3605
|
"REPEAT",
|
3606
|
+
"REPEAT_BACK",
|
3606
3607
|
"ABS",
|
3607
3608
|
"SGN",
|
3608
3609
|
"NEG",
|
@@ -3616,6 +3617,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|
3616
3617
|
"RMS_NORM_BACK",
|
3617
3618
|
|
3618
3619
|
"MUL_MAT",
|
3620
|
+
"OUT_PROD",
|
3619
3621
|
|
3620
3622
|
"SCALE",
|
3621
3623
|
"SET",
|
@@ -3631,6 +3633,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|
3631
3633
|
"DIAG_MASK_INF",
|
3632
3634
|
"DIAG_MASK_ZERO",
|
3633
3635
|
"SOFT_MAX",
|
3636
|
+
"SOFT_MAX_BACK",
|
3634
3637
|
"ROPE",
|
3635
3638
|
"ROPE_BACK",
|
3636
3639
|
"ALIBI",
|
@@ -3640,13 +3643,16 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|
3640
3643
|
|
3641
3644
|
"FLASH_ATTN",
|
3642
3645
|
"FLASH_FF",
|
3646
|
+
"FLASH_ATTN_BACK",
|
3643
3647
|
|
3644
3648
|
"MAP_UNARY",
|
3645
3649
|
"MAP_BINARY",
|
3646
|
-
};
|
3647
3650
|
|
3648
|
-
|
3651
|
+
"CROSS_ENTROPY_LOSS",
|
3652
|
+
"CROSS_ENTROPY_LOSS_BACK",
|
3653
|
+
};
|
3649
3654
|
|
3655
|
+
static_assert(GGML_OP_COUNT == 57, "GGML_OP_COUNT != 57");
|
3650
3656
|
|
3651
3657
|
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
3652
3658
|
"none",
|
@@ -3665,6 +3671,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
3665
3671
|
"Σx_k",
|
3666
3672
|
"Σx/n",
|
3667
3673
|
"repeat(x)",
|
3674
|
+
"repeat_back(x)",
|
3668
3675
|
"abs(x)",
|
3669
3676
|
"sgn(x)",
|
3670
3677
|
"-x",
|
@@ -3677,6 +3684,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
3677
3684
|
"rms_norm(x)",
|
3678
3685
|
"rms_norm_back(x)",
|
3679
3686
|
|
3687
|
+
"X*Y",
|
3680
3688
|
"X*Y",
|
3681
3689
|
|
3682
3690
|
"x*v",
|
@@ -3693,6 +3701,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
3693
3701
|
"diag_mask_inf(x)",
|
3694
3702
|
"diag_mask_zero(x)",
|
3695
3703
|
"soft_max(x)",
|
3704
|
+
"soft_max_back(x)",
|
3696
3705
|
"rope(x)",
|
3697
3706
|
"rope_back(x)",
|
3698
3707
|
"alibi(x)",
|
@@ -3702,12 +3711,16 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
3702
3711
|
|
3703
3712
|
"flash_attn(x)",
|
3704
3713
|
"flash_ff(x)",
|
3714
|
+
"flash_attn_back(x)",
|
3705
3715
|
|
3706
3716
|
"f(x)",
|
3707
3717
|
"f(x,y)",
|
3718
|
+
|
3719
|
+
"cross_entropy_loss(x,y)",
|
3720
|
+
"cross_entropy_loss_back(x,y)",
|
3708
3721
|
};
|
3709
3722
|
|
3710
|
-
static_assert(GGML_OP_COUNT ==
|
3723
|
+
static_assert(GGML_OP_COUNT == 57, "GGML_OP_COUNT != 57");
|
3711
3724
|
|
3712
3725
|
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
|
3713
3726
|
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
|
@@ -3870,6 +3883,15 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
|
|
3870
3883
|
(t0->ne[3] == t1->ne[3]);
|
3871
3884
|
}
|
3872
3885
|
|
3886
|
+
static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
|
3887
|
+
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
3888
|
+
|
3889
|
+
return
|
3890
|
+
(t0->ne[1] == t1->ne[1]) &&
|
3891
|
+
(t0->ne[2] == t1->ne[2]) &&
|
3892
|
+
(t0->ne[3] == t1->ne[3]);
|
3893
|
+
}
|
3894
|
+
|
3873
3895
|
bool ggml_is_quantized(enum ggml_type type) {
|
3874
3896
|
return GGML_IS_QUANTIZED[type];
|
3875
3897
|
}
|
@@ -3917,6 +3939,12 @@ bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
|
|
3917
3939
|
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
|
3918
3940
|
}
|
3919
3941
|
|
3942
|
+
bool ggml_is_permuted(const struct ggml_tensor * tensor) {
|
3943
|
+
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
3944
|
+
|
3945
|
+
return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
|
3946
|
+
}
|
3947
|
+
|
3920
3948
|
static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
|
3921
3949
|
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
3922
3950
|
|
@@ -4693,7 +4721,7 @@ struct ggml_tensor * ggml_add_impl(
|
|
4693
4721
|
|
4694
4722
|
bool is_node = false;
|
4695
4723
|
|
4696
|
-
if (
|
4724
|
+
if (a->grad || b->grad) {
|
4697
4725
|
is_node = true;
|
4698
4726
|
}
|
4699
4727
|
|
@@ -4733,7 +4761,7 @@ struct ggml_tensor * ggml_add1_impl(
|
|
4733
4761
|
|
4734
4762
|
bool is_node = false;
|
4735
4763
|
|
4736
|
-
if (
|
4764
|
+
if (a->grad || b->grad) {
|
4737
4765
|
is_node = true;
|
4738
4766
|
}
|
4739
4767
|
|
@@ -5159,6 +5187,34 @@ struct ggml_tensor * ggml_repeat(
|
|
5159
5187
|
return result;
|
5160
5188
|
}
|
5161
5189
|
|
5190
|
+
// ggml_repeat_back
|
5191
|
+
|
5192
|
+
struct ggml_tensor * ggml_repeat_back(
|
5193
|
+
struct ggml_context * ctx,
|
5194
|
+
struct ggml_tensor * a,
|
5195
|
+
struct ggml_tensor * b) {
|
5196
|
+
GGML_ASSERT(ggml_can_repeat(b, a));
|
5197
|
+
|
5198
|
+
bool is_node = false;
|
5199
|
+
|
5200
|
+
if (a->grad) {
|
5201
|
+
is_node = true;
|
5202
|
+
}
|
5203
|
+
|
5204
|
+
if (ggml_are_same_shape(a, b) && !is_node) {
|
5205
|
+
return a;
|
5206
|
+
}
|
5207
|
+
|
5208
|
+
struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
|
5209
|
+
|
5210
|
+
result->op = GGML_OP_REPEAT_BACK;
|
5211
|
+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
5212
|
+
result->src0 = a;
|
5213
|
+
result->src1 = b;
|
5214
|
+
|
5215
|
+
return result;
|
5216
|
+
}
|
5217
|
+
|
5162
5218
|
// ggml_abs
|
5163
5219
|
|
5164
5220
|
struct ggml_tensor * ggml_abs_impl(
|
@@ -5536,6 +5592,32 @@ struct ggml_tensor * ggml_mul_mat(
|
|
5536
5592
|
return result;
|
5537
5593
|
}
|
5538
5594
|
|
5595
|
+
// ggml_out_prod
|
5596
|
+
|
5597
|
+
struct ggml_tensor * ggml_out_prod(
|
5598
|
+
struct ggml_context * ctx,
|
5599
|
+
struct ggml_tensor * a,
|
5600
|
+
struct ggml_tensor * b) {
|
5601
|
+
GGML_ASSERT(ggml_can_out_prod(a, b));
|
5602
|
+
GGML_ASSERT(!ggml_is_transposed(a));
|
5603
|
+
|
5604
|
+
bool is_node = false;
|
5605
|
+
|
5606
|
+
if (a->grad || b->grad) {
|
5607
|
+
is_node = true;
|
5608
|
+
}
|
5609
|
+
|
5610
|
+
const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
|
5611
|
+
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
|
5612
|
+
|
5613
|
+
result->op = GGML_OP_OUT_PROD;
|
5614
|
+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
5615
|
+
result->src0 = a;
|
5616
|
+
result->src1 = b;
|
5617
|
+
|
5618
|
+
return result;
|
5619
|
+
}
|
5620
|
+
|
5539
5621
|
// ggml_scale
|
5540
5622
|
|
5541
5623
|
struct ggml_tensor * ggml_scale_impl(
|
@@ -5548,7 +5630,7 @@ struct ggml_tensor * ggml_scale_impl(
|
|
5548
5630
|
|
5549
5631
|
bool is_node = false;
|
5550
5632
|
|
5551
|
-
if (
|
5633
|
+
if (a->grad || b->grad) {
|
5552
5634
|
is_node = true;
|
5553
5635
|
}
|
5554
5636
|
|
@@ -5591,7 +5673,7 @@ struct ggml_tensor * ggml_set_impl(
|
|
5591
5673
|
|
5592
5674
|
bool is_node = false;
|
5593
5675
|
|
5594
|
-
if (
|
5676
|
+
if (a->grad || b->grad) {
|
5595
5677
|
is_node = true;
|
5596
5678
|
}
|
5597
5679
|
|
@@ -5913,10 +5995,6 @@ struct ggml_tensor * ggml_view_1d(
|
|
5913
5995
|
result->src1 = NULL;
|
5914
5996
|
result->opt[0] = offs;
|
5915
5997
|
|
5916
|
-
if (is_node) {
|
5917
|
-
memcpy(result->padding, &offset, sizeof(offset));
|
5918
|
-
}
|
5919
|
-
|
5920
5998
|
return result;
|
5921
5999
|
}
|
5922
6000
|
|
@@ -5957,10 +6035,6 @@ struct ggml_tensor * ggml_view_2d(
|
|
5957
6035
|
result->src1 = NULL;
|
5958
6036
|
result->opt[0] = offs;
|
5959
6037
|
|
5960
|
-
if (is_node) {
|
5961
|
-
memcpy(result->padding, &offset, sizeof(offset));
|
5962
|
-
}
|
5963
|
-
|
5964
6038
|
return result;
|
5965
6039
|
}
|
5966
6040
|
|
@@ -6003,10 +6077,6 @@ struct ggml_tensor * ggml_view_3d(
|
|
6003
6077
|
result->src1 = NULL;
|
6004
6078
|
result->opt[0] = offs;
|
6005
6079
|
|
6006
|
-
if (is_node) {
|
6007
|
-
memcpy(result->padding, &offset, sizeof(offset));
|
6008
|
-
}
|
6009
|
-
|
6010
6080
|
return result;
|
6011
6081
|
}
|
6012
6082
|
|
@@ -6051,10 +6121,6 @@ struct ggml_tensor * ggml_view_4d(
|
|
6051
6121
|
result->src1 = NULL;
|
6052
6122
|
result->opt[0] = offs;
|
6053
6123
|
|
6054
|
-
if (is_node) {
|
6055
|
-
memcpy(result->padding, &offset, sizeof(offset));
|
6056
|
-
}
|
6057
|
-
|
6058
6124
|
return result;
|
6059
6125
|
}
|
6060
6126
|
|
@@ -6116,10 +6182,18 @@ struct ggml_tensor * ggml_permute(
|
|
6116
6182
|
result->src1 = NULL;
|
6117
6183
|
|
6118
6184
|
if (is_node) {
|
6119
|
-
|
6120
|
-
|
6121
|
-
|
6122
|
-
|
6185
|
+
ggml_scratch_save(ctx);
|
6186
|
+
|
6187
|
+
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
|
6188
|
+
|
6189
|
+
((int32_t *) b->data)[0] = axis0;
|
6190
|
+
((int32_t *) b->data)[1] = axis1;
|
6191
|
+
((int32_t *) b->data)[2] = axis2;
|
6192
|
+
((int32_t *) b->data)[3] = axis3;
|
6193
|
+
|
6194
|
+
ggml_scratch_load(ctx);
|
6195
|
+
|
6196
|
+
result->opt[0] = b;
|
6123
6197
|
}
|
6124
6198
|
|
6125
6199
|
return result;
|
@@ -6359,6 +6433,44 @@ struct ggml_tensor * ggml_soft_max_inplace(
|
|
6359
6433
|
return ggml_soft_max_impl(ctx, a, true);
|
6360
6434
|
}
|
6361
6435
|
|
6436
|
+
|
6437
|
+
// ggml_soft_max_back
|
6438
|
+
|
6439
|
+
struct ggml_tensor * ggml_soft_max_back_impl(
|
6440
|
+
struct ggml_context * ctx,
|
6441
|
+
struct ggml_tensor * a,
|
6442
|
+
struct ggml_tensor * b,
|
6443
|
+
bool inplace) {
|
6444
|
+
bool is_node = false;
|
6445
|
+
|
6446
|
+
if (a->grad || b->grad) {
|
6447
|
+
is_node = true; // TODO : implement backward pass
|
6448
|
+
}
|
6449
|
+
|
6450
|
+
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
6451
|
+
|
6452
|
+
result->op = GGML_OP_SOFT_MAX_BACK;
|
6453
|
+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
6454
|
+
result->src0 = a;
|
6455
|
+
result->src1 = b;
|
6456
|
+
|
6457
|
+
return result;
|
6458
|
+
}
|
6459
|
+
|
6460
|
+
struct ggml_tensor * ggml_soft_max_back(
|
6461
|
+
struct ggml_context * ctx,
|
6462
|
+
struct ggml_tensor * a,
|
6463
|
+
struct ggml_tensor * b) {
|
6464
|
+
return ggml_soft_max_back_impl(ctx, a, b, false);
|
6465
|
+
}
|
6466
|
+
|
6467
|
+
struct ggml_tensor * ggml_soft_max_back_inplace(
|
6468
|
+
struct ggml_context * ctx,
|
6469
|
+
struct ggml_tensor * a,
|
6470
|
+
struct ggml_tensor * b) {
|
6471
|
+
return ggml_soft_max_back_impl(ctx, a, b, true);
|
6472
|
+
}
|
6473
|
+
|
6362
6474
|
// ggml_rope
|
6363
6475
|
|
6364
6476
|
struct ggml_tensor * ggml_rope_impl(
|
@@ -6371,7 +6483,7 @@ struct ggml_tensor * ggml_rope_impl(
|
|
6371
6483
|
GGML_ASSERT(n_past >= 0);
|
6372
6484
|
bool is_node = false;
|
6373
6485
|
|
6374
|
-
if (
|
6486
|
+
if (a->grad) {
|
6375
6487
|
is_node = true;
|
6376
6488
|
}
|
6377
6489
|
|
@@ -6425,8 +6537,7 @@ struct ggml_tensor * ggml_rope_back(
|
|
6425
6537
|
bool is_node = false;
|
6426
6538
|
|
6427
6539
|
if (a->grad) {
|
6428
|
-
|
6429
|
-
is_node = true;
|
6540
|
+
is_node = false; // TODO: implement backward
|
6430
6541
|
}
|
6431
6542
|
|
6432
6543
|
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
@@ -6591,7 +6702,6 @@ struct ggml_tensor * ggml_flash_attn(
|
|
6591
6702
|
bool is_node = false;
|
6592
6703
|
|
6593
6704
|
if (q->grad || k->grad || v->grad) {
|
6594
|
-
GGML_ASSERT(false); // TODO: implement backward
|
6595
6705
|
is_node = true;
|
6596
6706
|
}
|
6597
6707
|
|
@@ -6623,7 +6733,6 @@ struct ggml_tensor * ggml_flash_ff(
|
|
6623
6733
|
bool is_node = false;
|
6624
6734
|
|
6625
6735
|
if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
|
6626
|
-
GGML_ASSERT(false); // TODO: implement backward
|
6627
6736
|
is_node = true;
|
6628
6737
|
}
|
6629
6738
|
|
@@ -6641,6 +6750,71 @@ struct ggml_tensor * ggml_flash_ff(
|
|
6641
6750
|
return result;
|
6642
6751
|
}
|
6643
6752
|
|
6753
|
+
// ggml_flash_attn_back
|
6754
|
+
|
6755
|
+
struct ggml_tensor * ggml_flash_attn_back(
|
6756
|
+
struct ggml_context * ctx,
|
6757
|
+
struct ggml_tensor * q,
|
6758
|
+
struct ggml_tensor * k,
|
6759
|
+
struct ggml_tensor * v,
|
6760
|
+
struct ggml_tensor * d,
|
6761
|
+
bool masked) {
|
6762
|
+
GGML_ASSERT(ggml_can_mul_mat(k, q));
|
6763
|
+
// TODO: check if vT can be multiplied by (k*qT)
|
6764
|
+
|
6765
|
+
// d shape [D,N,ne2,ne3]
|
6766
|
+
// q shape [D,N,ne2,ne3]
|
6767
|
+
// k shape [D,M,ne2,ne3]
|
6768
|
+
// v shape [M,D,ne2,ne3]
|
6769
|
+
|
6770
|
+
const int64_t D = q->ne[0];
|
6771
|
+
const int64_t N = q->ne[1];
|
6772
|
+
const int64_t M = k->ne[1];
|
6773
|
+
const int64_t ne2 = q->ne[2];
|
6774
|
+
const int64_t ne3 = q->ne[3];
|
6775
|
+
|
6776
|
+
GGML_ASSERT(k->ne[0] == D);
|
6777
|
+
GGML_ASSERT(v->ne[0] == M);
|
6778
|
+
GGML_ASSERT(v->ne[1] == D);
|
6779
|
+
GGML_ASSERT(d->ne[0] == D);
|
6780
|
+
GGML_ASSERT(d->ne[1] == N);
|
6781
|
+
GGML_ASSERT(k->ne[2] == ne2);
|
6782
|
+
GGML_ASSERT(k->ne[3] == ne3);
|
6783
|
+
GGML_ASSERT(v->ne[2] == ne2);
|
6784
|
+
GGML_ASSERT(v->ne[3] == ne3);
|
6785
|
+
GGML_ASSERT(d->ne[2] == ne2);
|
6786
|
+
GGML_ASSERT(d->ne[3] == ne3);
|
6787
|
+
|
6788
|
+
bool is_node = false;
|
6789
|
+
|
6790
|
+
if (q->grad || k->grad || v->grad) {
|
6791
|
+
// when using this operation (in backwards pass) these grads are set.
|
6792
|
+
// we don't want to create (big) grad of our result, so is_node is false.
|
6793
|
+
is_node = false;
|
6794
|
+
}
|
6795
|
+
|
6796
|
+
// store gradients of q, k and v as continuous tensors concatenated in result.
|
6797
|
+
// q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
|
6798
|
+
// gradq->data = result->data
|
6799
|
+
// gradk->data = result->data + nb0*D*N*ne2*ne3
|
6800
|
+
// gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
|
6801
|
+
// note: v and gradv are actually transposed, i.e. v->ne[0] != D.
|
6802
|
+
int64_t ne[4] = {D,M+N+M,ne2,ne3};
|
6803
|
+
|
6804
|
+
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
6805
|
+
|
6806
|
+
result->op = GGML_OP_FLASH_ATTN_BACK;
|
6807
|
+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
6808
|
+
result->src0 = q;
|
6809
|
+
result->src1 = k;
|
6810
|
+
result->opt[0] = v;
|
6811
|
+
result->opt[1] = d;
|
6812
|
+
result->opt[2] = ggml_new_i32(ctx, masked ? 1 : 0);
|
6813
|
+
|
6814
|
+
return result;
|
6815
|
+
}
|
6816
|
+
|
6817
|
+
|
6644
6818
|
// ggml_map_unary
|
6645
6819
|
|
6646
6820
|
struct ggml_tensor * ggml_map_unary_impl_f32(
|
@@ -6725,6 +6899,50 @@ struct ggml_tensor * ggml_map_binary_inplace_f32(
|
|
6725
6899
|
return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
|
6726
6900
|
}
|
6727
6901
|
|
6902
|
+
// ggml_cross_entropy_loss
|
6903
|
+
|
6904
|
+
struct ggml_tensor * ggml_cross_entropy_loss(
|
6905
|
+
struct ggml_context * ctx,
|
6906
|
+
struct ggml_tensor * a,
|
6907
|
+
struct ggml_tensor * b) {
|
6908
|
+
GGML_ASSERT(ggml_are_same_shape(a, b));
|
6909
|
+
bool is_node = false;
|
6910
|
+
|
6911
|
+
if (a->grad || b->grad) {
|
6912
|
+
is_node = true;
|
6913
|
+
}
|
6914
|
+
|
6915
|
+
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
|
6916
|
+
|
6917
|
+
result->op = GGML_OP_CROSS_ENTROPY_LOSS;
|
6918
|
+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
6919
|
+
result->src0 = a;
|
6920
|
+
result->src1 = b;
|
6921
|
+
|
6922
|
+
return result;
|
6923
|
+
}
|
6924
|
+
|
6925
|
+
// ggml_cross_entropy_loss_back
|
6926
|
+
|
6927
|
+
struct ggml_tensor * ggml_cross_entropy_loss_back(
|
6928
|
+
struct ggml_context * ctx,
|
6929
|
+
struct ggml_tensor * a,
|
6930
|
+
struct ggml_tensor * b,
|
6931
|
+
struct ggml_tensor * c) {
|
6932
|
+
GGML_ASSERT(ggml_are_same_shape(a, b));
|
6933
|
+
GGML_ASSERT(ggml_is_scalar(c));
|
6934
|
+
|
6935
|
+
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
6936
|
+
|
6937
|
+
result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
|
6938
|
+
result->grad = NULL;
|
6939
|
+
result->src0 = a;
|
6940
|
+
result->src1 = b;
|
6941
|
+
result->opt[0] = c;
|
6942
|
+
|
6943
|
+
return result;
|
6944
|
+
}
|
6945
|
+
|
6728
6946
|
////////////////////////////////////////////////////////////////////////////////
|
6729
6947
|
|
6730
6948
|
void ggml_set_param(
|
@@ -8875,6 +9093,99 @@ static void ggml_compute_forward_repeat(
|
|
8875
9093
|
}
|
8876
9094
|
}
|
8877
9095
|
|
9096
|
+
// ggml_compute_forward_repeat_back
|
9097
|
+
|
9098
|
+
static void ggml_compute_forward_repeat_back_f32(
|
9099
|
+
const struct ggml_compute_params * params,
|
9100
|
+
const struct ggml_tensor * src0,
|
9101
|
+
struct ggml_tensor * dst) {
|
9102
|
+
GGML_ASSERT(params->ith == 0);
|
9103
|
+
GGML_ASSERT(ggml_can_repeat(dst, src0));
|
9104
|
+
|
9105
|
+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
9106
|
+
return;
|
9107
|
+
}
|
9108
|
+
|
9109
|
+
const int64_t ne0 = dst->ne[0];
|
9110
|
+
const int64_t ne1 = dst->ne[1];
|
9111
|
+
const int64_t ne2 = dst->ne[2];
|
9112
|
+
const int64_t ne3 = dst->ne[3];
|
9113
|
+
|
9114
|
+
const int64_t ne00 = src0->ne[0];
|
9115
|
+
const int64_t ne01 = src0->ne[1];
|
9116
|
+
const int64_t ne02 = src0->ne[2];
|
9117
|
+
const int64_t ne03 = src0->ne[3];
|
9118
|
+
|
9119
|
+
const size_t nb0 = dst->nb[0];
|
9120
|
+
const size_t nb1 = dst->nb[1];
|
9121
|
+
const size_t nb2 = dst->nb[2];
|
9122
|
+
const size_t nb3 = dst->nb[3];
|
9123
|
+
|
9124
|
+
const size_t nb00 = src0->nb[0];
|
9125
|
+
const size_t nb01 = src0->nb[1];
|
9126
|
+
const size_t nb02 = src0->nb[2];
|
9127
|
+
const size_t nb03 = src0->nb[3];
|
9128
|
+
|
9129
|
+
// guaranteed to be an integer due to the check in ggml_can_repeat
|
9130
|
+
const int nr0 = (int)(ne00/ne0);
|
9131
|
+
const int nr1 = (int)(ne01/ne1);
|
9132
|
+
const int nr2 = (int)(ne02/ne2);
|
9133
|
+
const int nr3 = (int)(ne03/ne3);
|
9134
|
+
|
9135
|
+
// TODO: support for transposed / permuted tensors
|
9136
|
+
GGML_ASSERT(nb0 == sizeof(float));
|
9137
|
+
GGML_ASSERT(nb00 == sizeof(float));
|
9138
|
+
|
9139
|
+
if (ggml_is_contiguous(dst)) {
|
9140
|
+
ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
|
9141
|
+
} else {
|
9142
|
+
for (int k3 = 0; k3 < ne3; k3++) {
|
9143
|
+
for (int k2 = 0; k2 < ne2; k2++) {
|
9144
|
+
for (int k1 = 0; k1 < ne1; k1++) {
|
9145
|
+
ggml_vec_set_f32(ne0,
|
9146
|
+
(float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
|
9147
|
+
0);
|
9148
|
+
}
|
9149
|
+
}
|
9150
|
+
}
|
9151
|
+
}
|
9152
|
+
|
9153
|
+
// TODO: maybe this is not optimal?
|
9154
|
+
for (int i3 = 0; i3 < nr3; i3++) {
|
9155
|
+
for (int k3 = 0; k3 < ne3; k3++) {
|
9156
|
+
for (int i2 = 0; i2 < nr2; i2++) {
|
9157
|
+
for (int k2 = 0; k2 < ne2; k2++) {
|
9158
|
+
for (int i1 = 0; i1 < nr1; i1++) {
|
9159
|
+
for (int k1 = 0; k1 < ne1; k1++) {
|
9160
|
+
for (int i0 = 0; i0 < nr0; i0++) {
|
9161
|
+
ggml_vec_acc_f32(ne0,
|
9162
|
+
(float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1),
|
9163
|
+
(float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
|
9164
|
+
}
|
9165
|
+
}
|
9166
|
+
}
|
9167
|
+
}
|
9168
|
+
}
|
9169
|
+
}
|
9170
|
+
}
|
9171
|
+
}
|
9172
|
+
|
9173
|
+
static void ggml_compute_forward_repeat_back(
|
9174
|
+
const struct ggml_compute_params * params,
|
9175
|
+
const struct ggml_tensor * src0,
|
9176
|
+
struct ggml_tensor * dst) {
|
9177
|
+
switch (src0->type) {
|
9178
|
+
case GGML_TYPE_F32:
|
9179
|
+
{
|
9180
|
+
ggml_compute_forward_repeat_back_f32(params, src0, dst);
|
9181
|
+
} break;
|
9182
|
+
default:
|
9183
|
+
{
|
9184
|
+
GGML_ASSERT(false);
|
9185
|
+
} break;
|
9186
|
+
}
|
9187
|
+
}
|
9188
|
+
|
8878
9189
|
// ggml_compute_forward_abs
|
8879
9190
|
|
8880
9191
|
static void ggml_compute_forward_abs_f32(
|
@@ -10249,18 +10560,188 @@ static void ggml_compute_forward_mul_mat(
|
|
10249
10560
|
}
|
10250
10561
|
}
|
10251
10562
|
|
10252
|
-
//
|
10563
|
+
// ggml_compute_forward_out_prod
|
10253
10564
|
|
10254
|
-
|
10565
|
+
|
10566
|
+
static void ggml_compute_forward_out_prod_f32(
|
10255
10567
|
const struct ggml_compute_params * params,
|
10256
10568
|
const struct ggml_tensor * src0,
|
10257
10569
|
const struct ggml_tensor * src1,
|
10258
|
-
|
10259
|
-
|
10260
|
-
|
10261
|
-
|
10262
|
-
|
10263
|
-
|
10570
|
+
struct ggml_tensor * dst) {
|
10571
|
+
int64_t t0 = ggml_perf_time_us();
|
10572
|
+
UNUSED(t0);
|
10573
|
+
|
10574
|
+
const int64_t ne00 = src0->ne[0];
|
10575
|
+
const int64_t ne01 = src0->ne[1];
|
10576
|
+
const int64_t ne02 = src0->ne[2];
|
10577
|
+
const int64_t ne03 = src0->ne[3];
|
10578
|
+
|
10579
|
+
const int64_t ne10 = src1->ne[0];
|
10580
|
+
//const int64_t ne11 = src1->ne[1];
|
10581
|
+
const int64_t ne12 = src1->ne[2];
|
10582
|
+
const int64_t ne13 = src1->ne[3];
|
10583
|
+
|
10584
|
+
const int64_t ne0 = dst->ne[0];
|
10585
|
+
const int64_t ne1 = dst->ne[1];
|
10586
|
+
const int64_t ne2 = dst->ne[2];
|
10587
|
+
const int64_t ne3 = dst->ne[3];
|
10588
|
+
|
10589
|
+
const int nb00 = src0->nb[0];
|
10590
|
+
const int nb01 = src0->nb[1];
|
10591
|
+
const int nb02 = src0->nb[2];
|
10592
|
+
const int nb03 = src0->nb[3];
|
10593
|
+
|
10594
|
+
const int nb10 = src1->nb[0];
|
10595
|
+
const int nb11 = src1->nb[1];
|
10596
|
+
const int nb12 = src1->nb[2];
|
10597
|
+
const int nb13 = src1->nb[3];
|
10598
|
+
|
10599
|
+
const int nb0 = dst->nb[0];
|
10600
|
+
const int nb1 = dst->nb[1];
|
10601
|
+
const int nb2 = dst->nb[2];
|
10602
|
+
const int nb3 = dst->nb[3];
|
10603
|
+
|
10604
|
+
const int ith = params->ith;
|
10605
|
+
const int nth = params->nth;
|
10606
|
+
|
10607
|
+
GGML_ASSERT(ne02 == ne12);
|
10608
|
+
GGML_ASSERT(ne03 == ne13);
|
10609
|
+
GGML_ASSERT(ne2 == ne12);
|
10610
|
+
GGML_ASSERT(ne3 == ne13);
|
10611
|
+
|
10612
|
+
// we don't support permuted src0 or src1
|
10613
|
+
GGML_ASSERT(nb00 == sizeof(float));
|
10614
|
+
|
10615
|
+
// dst cannot be transposed or permuted
|
10616
|
+
GGML_ASSERT(nb0 == sizeof(float));
|
10617
|
+
// GGML_ASSERT(nb0 <= nb1);
|
10618
|
+
// GGML_ASSERT(nb1 <= nb2);
|
10619
|
+
// GGML_ASSERT(nb2 <= nb3);
|
10620
|
+
|
10621
|
+
GGML_ASSERT(ne0 == ne00);
|
10622
|
+
GGML_ASSERT(ne1 == ne10);
|
10623
|
+
GGML_ASSERT(ne2 == ne02);
|
10624
|
+
GGML_ASSERT(ne3 == ne03);
|
10625
|
+
|
10626
|
+
// nb01 >= nb00 - src0 is not transposed
|
10627
|
+
// compute by src0 rows
|
10628
|
+
|
10629
|
+
// TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
|
10630
|
+
// TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
10631
|
+
|
10632
|
+
if (params->type == GGML_TASK_INIT) {
|
10633
|
+
ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
|
10634
|
+
return;
|
10635
|
+
}
|
10636
|
+
|
10637
|
+
if (params->type == GGML_TASK_FINALIZE) {
|
10638
|
+
return;
|
10639
|
+
}
|
10640
|
+
|
10641
|
+
// parallelize by last three dimensions
|
10642
|
+
|
10643
|
+
// total rows in dst
|
10644
|
+
const int64_t nr = ne1*ne2*ne3;
|
10645
|
+
|
10646
|
+
// rows per thread
|
10647
|
+
const int64_t dr = (nr + nth - 1)/nth;
|
10648
|
+
|
10649
|
+
// row range for this thread
|
10650
|
+
const int64_t ir0 = dr*ith;
|
10651
|
+
const int64_t ir1 = MIN(ir0 + dr, nr);
|
10652
|
+
|
10653
|
+
// dst[:,:,:,:] = 0
|
10654
|
+
// for i2,i3:
|
10655
|
+
// for i1:
|
10656
|
+
// for i01:
|
10657
|
+
// for i0:
|
10658
|
+
// dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
|
10659
|
+
|
10660
|
+
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
10661
|
+
// dst indices
|
10662
|
+
const int64_t i3 = ir/(ne2*ne1);
|
10663
|
+
const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
|
10664
|
+
const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
|
10665
|
+
|
10666
|
+
const int64_t i02 = i2;
|
10667
|
+
const int64_t i03 = i3;
|
10668
|
+
|
10669
|
+
//const int64_t i10 = i1;
|
10670
|
+
const int64_t i12 = i2;
|
10671
|
+
const int64_t i13 = i3;
|
10672
|
+
|
10673
|
+
for (int64_t i01 = 0; i01 < ne01; ++i01) {
|
10674
|
+
const int64_t i11 = i01;
|
10675
|
+
|
10676
|
+
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
10677
|
+
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
10678
|
+
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
10679
|
+
|
10680
|
+
ggml_vec_mad_f32(ne0, d, s0, *s1);
|
10681
|
+
// for (int64_t i0 = 0; i0 < ne0; ++i0) {
|
10682
|
+
// d[i0] += s0[i0] * s1[i1];
|
10683
|
+
// }
|
10684
|
+
}
|
10685
|
+
}
|
10686
|
+
|
10687
|
+
//int64_t t1 = ggml_perf_time_us();
|
10688
|
+
//static int64_t acc = 0;
|
10689
|
+
//acc += t1 - t0;
|
10690
|
+
//if (t1 - t0 > 10) {
|
10691
|
+
// printf("\n");
|
10692
|
+
// printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
|
10693
|
+
// printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
|
10694
|
+
// printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
|
10695
|
+
// printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
|
10696
|
+
|
10697
|
+
// printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
|
10698
|
+
//}
|
10699
|
+
}
|
10700
|
+
|
10701
|
+
static void ggml_compute_forward_out_prod(
|
10702
|
+
const struct ggml_compute_params * params,
|
10703
|
+
const struct ggml_tensor * src0,
|
10704
|
+
const struct ggml_tensor * src1,
|
10705
|
+
struct ggml_tensor * dst) {
|
10706
|
+
switch (src0->type) {
|
10707
|
+
case GGML_TYPE_Q4_0:
|
10708
|
+
case GGML_TYPE_Q4_1:
|
10709
|
+
case GGML_TYPE_Q5_0:
|
10710
|
+
case GGML_TYPE_Q5_1:
|
10711
|
+
case GGML_TYPE_Q8_0:
|
10712
|
+
case GGML_TYPE_Q8_1:
|
10713
|
+
{
|
10714
|
+
GGML_ASSERT(false); // todo
|
10715
|
+
// ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
|
10716
|
+
} break;
|
10717
|
+
case GGML_TYPE_F16:
|
10718
|
+
{
|
10719
|
+
GGML_ASSERT(false); // todo
|
10720
|
+
// ggml_compute_forward_out_prod_f16_f32(params, src0, src1, dst);
|
10721
|
+
} break;
|
10722
|
+
case GGML_TYPE_F32:
|
10723
|
+
{
|
10724
|
+
ggml_compute_forward_out_prod_f32(params, src0, src1, dst);
|
10725
|
+
} break;
|
10726
|
+
default:
|
10727
|
+
{
|
10728
|
+
GGML_ASSERT(false);
|
10729
|
+
} break;
|
10730
|
+
}
|
10731
|
+
}
|
10732
|
+
|
10733
|
+
// ggml_compute_forward_scale
|
10734
|
+
|
10735
|
+
static void ggml_compute_forward_scale_f32(
|
10736
|
+
const struct ggml_compute_params * params,
|
10737
|
+
const struct ggml_tensor * src0,
|
10738
|
+
const struct ggml_tensor * src1,
|
10739
|
+
struct ggml_tensor * dst) {
|
10740
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
10741
|
+
GGML_ASSERT(ggml_is_contiguous(dst));
|
10742
|
+
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
10743
|
+
GGML_ASSERT(ggml_is_scalar(src1));
|
10744
|
+
|
10264
10745
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
10265
10746
|
return;
|
10266
10747
|
}
|
@@ -10671,7 +11152,11 @@ static void ggml_compute_forward_get_rows_back_f32(
|
|
10671
11152
|
GGML_ASSERT(ggml_is_contiguous(opt0));
|
10672
11153
|
GGML_ASSERT(ggml_is_contiguous(dst));
|
10673
11154
|
|
10674
|
-
ggml_compute_forward_dup_same_cont(params, opt0, dst);
|
11155
|
+
// ggml_compute_forward_dup_same_cont(params, opt0, dst);
|
11156
|
+
|
11157
|
+
if (params->type == GGML_TASK_INIT) {
|
11158
|
+
memset(dst->data, 0, ggml_nbytes(dst));
|
11159
|
+
}
|
10675
11160
|
|
10676
11161
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
10677
11162
|
return;
|
@@ -10815,8 +11300,8 @@ static void ggml_compute_forward_diag_mask_f32(
|
|
10815
11300
|
const struct ggml_tensor * src1,
|
10816
11301
|
struct ggml_tensor * dst,
|
10817
11302
|
const float value) {
|
10818
|
-
|
10819
|
-
|
11303
|
+
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
11304
|
+
GGML_ASSERT(ggml_nelements(src1) == 2);
|
10820
11305
|
|
10821
11306
|
const int ith = params->ith;
|
10822
11307
|
const int nth = params->nth;
|
@@ -10824,7 +11309,7 @@ static void ggml_compute_forward_diag_mask_f32(
|
|
10824
11309
|
const int n_past = ((int32_t *) src1->data)[0];
|
10825
11310
|
const bool inplace = (bool)((int32_t *) src1->data)[1];
|
10826
11311
|
|
10827
|
-
|
11312
|
+
GGML_ASSERT(n_past >= 0);
|
10828
11313
|
|
10829
11314
|
if (!inplace && (params->type == GGML_TASK_INIT)) {
|
10830
11315
|
// memcpy needs to be synchronized across threads to avoid race conditions.
|
@@ -10848,8 +11333,8 @@ static void ggml_compute_forward_diag_mask_f32(
|
|
10848
11333
|
const int nr = src0->ne[1];
|
10849
11334
|
const int nz = n/nr;
|
10850
11335
|
|
10851
|
-
|
10852
|
-
|
11336
|
+
GGML_ASSERT( dst->nb[0] == sizeof(float));
|
11337
|
+
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
10853
11338
|
|
10854
11339
|
for (int k = 0; k < nz; k++) {
|
10855
11340
|
for (int j = ith; j < nr; j += nth) {
|
@@ -10985,6 +11470,101 @@ static void ggml_compute_forward_soft_max(
|
|
10985
11470
|
}
|
10986
11471
|
}
|
10987
11472
|
|
11473
|
+
// ggml_compute_forward_soft_max_back
|
11474
|
+
|
11475
|
+
static void ggml_compute_forward_soft_max_back_f32(
|
11476
|
+
const struct ggml_compute_params * params,
|
11477
|
+
const struct ggml_tensor * src0,
|
11478
|
+
const struct ggml_tensor * src1,
|
11479
|
+
struct ggml_tensor * dst) {
|
11480
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
11481
|
+
GGML_ASSERT(ggml_is_contiguous(src1));
|
11482
|
+
GGML_ASSERT(ggml_is_contiguous(dst));
|
11483
|
+
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
11484
|
+
GGML_ASSERT(ggml_are_same_shape(src1, dst));
|
11485
|
+
|
11486
|
+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
11487
|
+
return;
|
11488
|
+
}
|
11489
|
+
|
11490
|
+
// TODO: handle transposed/permuted matrices
|
11491
|
+
|
11492
|
+
const int ith = params->ith;
|
11493
|
+
const int nth = params->nth;
|
11494
|
+
|
11495
|
+
const int nc = src0->ne[0];
|
11496
|
+
const int nr = ggml_nrows(src0);
|
11497
|
+
|
11498
|
+
// rows per thread
|
11499
|
+
const int dr = (nr + nth - 1)/nth;
|
11500
|
+
|
11501
|
+
// row range for this thread
|
11502
|
+
const int ir0 = dr*ith;
|
11503
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
11504
|
+
|
11505
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
11506
|
+
float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
|
11507
|
+
float *y = (float *)((char *) src1->data + i1*src1->nb[1]);
|
11508
|
+
float *dx = (float *)((char *) dst->data + i1*dst->nb[1]);
|
11509
|
+
|
11510
|
+
#ifndef NDEBUG
|
11511
|
+
for (int i = 0; i < nc; ++i) {
|
11512
|
+
//printf("p[%d] = %f\n", i, p[i]);
|
11513
|
+
assert(!isnan(dy[i]));
|
11514
|
+
assert(!isnan(y[i]));
|
11515
|
+
}
|
11516
|
+
#endif
|
11517
|
+
// Jii = yi - yi*yi
|
11518
|
+
// Jij = -yi*yj
|
11519
|
+
// J = diag(y)-y.T*y
|
11520
|
+
// dx = J * dy
|
11521
|
+
// dxk = sum_i(Jki * dyi)
|
11522
|
+
// dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
|
11523
|
+
// dxk = sum_i(-yk*yi * dyi) + yk*dyk
|
11524
|
+
// dxk = -yk * sum_i(yi * dyi) + yk*dyk
|
11525
|
+
// dxk = -yk * dot(y, dy) + yk*dyk
|
11526
|
+
// dxk = yk * (- dot(y, dy) + dyk)
|
11527
|
+
// dxk = yk * (dyk - dot(y, dy))
|
11528
|
+
//
|
11529
|
+
// post-order:
|
11530
|
+
// dot_y_dy := dot(y, dy)
|
11531
|
+
// dx := dy
|
11532
|
+
// dx := dx - dot_y_dy
|
11533
|
+
// dx := dx * y
|
11534
|
+
|
11535
|
+
// linear runtime, no additional memory
|
11536
|
+
float dot_y_dy = 0;
|
11537
|
+
ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy);
|
11538
|
+
ggml_vec_cpy_f32 (nc, dx, dy);
|
11539
|
+
ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
|
11540
|
+
ggml_vec_mul_f32 (nc, dx, dx, y);
|
11541
|
+
|
11542
|
+
#ifndef NDEBUG
|
11543
|
+
for (int i = 0; i < nc; ++i) {
|
11544
|
+
assert(!isnan(dx[i]));
|
11545
|
+
assert(!isinf(dx[i]));
|
11546
|
+
}
|
11547
|
+
#endif
|
11548
|
+
}
|
11549
|
+
}
|
11550
|
+
|
11551
|
+
static void ggml_compute_forward_soft_max_back(
|
11552
|
+
const struct ggml_compute_params * params,
|
11553
|
+
const struct ggml_tensor * src0,
|
11554
|
+
const struct ggml_tensor * src1,
|
11555
|
+
struct ggml_tensor * dst) {
|
11556
|
+
switch (src0->type) {
|
11557
|
+
case GGML_TYPE_F32:
|
11558
|
+
{
|
11559
|
+
ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst);
|
11560
|
+
} break;
|
11561
|
+
default:
|
11562
|
+
{
|
11563
|
+
GGML_ASSERT(false);
|
11564
|
+
} break;
|
11565
|
+
}
|
11566
|
+
}
|
11567
|
+
|
10988
11568
|
// ggml_compute_forward_alibi
|
10989
11569
|
|
10990
11570
|
static void ggml_compute_forward_alibi_f32(
|
@@ -12938,42 +13518,616 @@ static void ggml_compute_forward_flash_ff(
|
|
12938
13518
|
}
|
12939
13519
|
}
|
12940
13520
|
|
12941
|
-
//
|
13521
|
+
// ggml_compute_forward_flash_attn_back
|
12942
13522
|
|
12943
|
-
static void
|
13523
|
+
static void ggml_compute_forward_flash_attn_back_f32(
|
12944
13524
|
const struct ggml_compute_params * params,
|
12945
|
-
const struct ggml_tensor *
|
12946
|
-
struct ggml_tensor *
|
12947
|
-
const
|
12948
|
-
|
13525
|
+
const struct ggml_tensor * q,
|
13526
|
+
const struct ggml_tensor * k,
|
13527
|
+
const struct ggml_tensor * v,
|
13528
|
+
const struct ggml_tensor * d,
|
13529
|
+
const bool masked,
|
13530
|
+
struct ggml_tensor * dst) {
|
13531
|
+
int64_t t0 = ggml_perf_time_us();
|
13532
|
+
UNUSED(t0);
|
12949
13533
|
|
12950
|
-
|
12951
|
-
|
12952
|
-
|
13534
|
+
const int64_t neq0 = q->ne[0];
|
13535
|
+
const int64_t neq1 = q->ne[1];
|
13536
|
+
const int64_t neq2 = q->ne[2];
|
13537
|
+
const int64_t neq3 = q->ne[3];
|
12953
13538
|
|
12954
|
-
const
|
12955
|
-
const
|
13539
|
+
const int64_t nek0 = k->ne[0];
|
13540
|
+
const int64_t nek1 = k->ne[1];
|
13541
|
+
//const int64_t nek2 = k->ne[2];
|
13542
|
+
//const int64_t nek3 = k->ne[3];
|
12956
13543
|
|
12957
|
-
|
12958
|
-
|
13544
|
+
const int64_t nev0 = v->ne[0];
|
13545
|
+
const int64_t nev1 = v->ne[1];
|
13546
|
+
//const int64_t nev2 = v->ne[2];
|
13547
|
+
//const int64_t nev3 = v->ne[3];
|
12959
13548
|
|
12960
|
-
|
12961
|
-
|
12962
|
-
|
12963
|
-
|
12964
|
-
}
|
12965
|
-
}
|
13549
|
+
const int64_t ned0 = d->ne[0];
|
13550
|
+
const int64_t ned1 = d->ne[1];
|
13551
|
+
//const int64_t ned2 = d->ne[2];
|
13552
|
+
//const int64_t ned3 = d->ne[3];
|
12966
13553
|
|
13554
|
+
const int64_t ne0 = dst->ne[0];
|
13555
|
+
const int64_t ne1 = dst->ne[1];
|
13556
|
+
const int64_t ne2 = dst->ne[2];
|
13557
|
+
const int64_t ne3 = dst->ne[3];
|
12967
13558
|
|
12968
|
-
|
12969
|
-
|
12970
|
-
|
12971
|
-
|
12972
|
-
|
13559
|
+
const int nbk0 = k->nb[0];
|
13560
|
+
const int nbk1 = k->nb[1];
|
13561
|
+
const int nbk2 = k->nb[2];
|
13562
|
+
const int nbk3 = k->nb[3];
|
13563
|
+
|
13564
|
+
const int nbq0 = q->nb[0];
|
13565
|
+
const int nbq1 = q->nb[1];
|
13566
|
+
const int nbq2 = q->nb[2];
|
13567
|
+
const int nbq3 = q->nb[3];
|
13568
|
+
|
13569
|
+
const int nbv0 = v->nb[0];
|
13570
|
+
const int nbv1 = v->nb[1];
|
13571
|
+
const int nbv2 = v->nb[2];
|
13572
|
+
const int nbv3 = v->nb[3];
|
13573
|
+
|
13574
|
+
const int nbd0 = d->nb[0];
|
13575
|
+
const int nbd1 = d->nb[1];
|
13576
|
+
const int nbd2 = d->nb[2];
|
13577
|
+
const int nbd3 = d->nb[3];
|
13578
|
+
|
13579
|
+
const int nb0 = dst->nb[0];
|
13580
|
+
const int nb1 = dst->nb[1];
|
13581
|
+
const int nb2 = dst->nb[2];
|
13582
|
+
const int nb3 = dst->nb[3];
|
13583
|
+
|
13584
|
+
const int ith = params->ith;
|
13585
|
+
const int nth = params->nth;
|
13586
|
+
|
13587
|
+
const int64_t D = neq0;
|
13588
|
+
const int64_t N = neq1;
|
13589
|
+
const int64_t P = nek1 - N;
|
13590
|
+
const int64_t M = P + N;
|
13591
|
+
|
13592
|
+
const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
|
13593
|
+
const int mxDM = MAX(D, Mup);
|
13594
|
+
|
13595
|
+
// GGML_ASSERT(ne0 == D);
|
13596
|
+
// GGML_ASSERT(ne1 == N);
|
13597
|
+
GGML_ASSERT(P >= 0);
|
13598
|
+
|
13599
|
+
GGML_ASSERT(nbq0 == sizeof(float));
|
13600
|
+
GGML_ASSERT(nbk0 == sizeof(float));
|
13601
|
+
GGML_ASSERT(nbv0 == sizeof(float));
|
13602
|
+
|
13603
|
+
GGML_ASSERT(neq0 == D);
|
13604
|
+
GGML_ASSERT(nek0 == D);
|
13605
|
+
GGML_ASSERT(nev1 == D);
|
13606
|
+
GGML_ASSERT(ned0 == D);
|
13607
|
+
|
13608
|
+
GGML_ASSERT(neq1 == N);
|
13609
|
+
GGML_ASSERT(nek1 == N + P);
|
13610
|
+
GGML_ASSERT(nev1 == D);
|
13611
|
+
GGML_ASSERT(ned1 == N);
|
13612
|
+
|
13613
|
+
// dst cannot be transposed or permuted
|
13614
|
+
GGML_ASSERT(nb0 == sizeof(float));
|
13615
|
+
GGML_ASSERT(nb0 <= nb1);
|
13616
|
+
GGML_ASSERT(nb1 <= nb2);
|
13617
|
+
GGML_ASSERT(nb2 <= nb3);
|
13618
|
+
|
13619
|
+
if (params->type == GGML_TASK_INIT) {
|
13620
|
+
if (ith == 0) {
|
13621
|
+
memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
|
13622
|
+
}
|
13623
|
+
return;
|
13624
|
+
}
|
13625
|
+
|
13626
|
+
if (params->type == GGML_TASK_FINALIZE) {
|
13627
|
+
return;
|
13628
|
+
}
|
13629
|
+
|
13630
|
+
// parallelize by q rows using ggml_vec_dot_f32
|
13631
|
+
|
13632
|
+
// total rows in q
|
13633
|
+
const int nr = neq2*neq3;
|
13634
|
+
|
13635
|
+
// rows per thread
|
13636
|
+
const int dr = (nr + nth - 1)/nth;
|
13637
|
+
|
13638
|
+
// row range for this thread
|
13639
|
+
const int ir0 = dr*ith;
|
13640
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
13641
|
+
|
13642
|
+
const float scale = 1.0f/sqrtf(D);
|
13643
|
+
|
13644
|
+
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
|
13645
|
+
|
13646
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
13647
|
+
// q indices
|
13648
|
+
const int iq3 = ir/(neq2);
|
13649
|
+
const int iq2 = ir - iq3*neq2;
|
13650
|
+
for ( int iq1 = 0; iq1 < neq1; ++iq1) {
|
13651
|
+
|
13652
|
+
|
13653
|
+
// not sure about CACHE_LINE_SIZE_F32..
|
13654
|
+
// - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
|
13655
|
+
float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
|
13656
|
+
float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
|
13657
|
+
|
13658
|
+
for (int i = M; i < Mup; ++i) {
|
13659
|
+
S[i] = -INFINITY;
|
13660
|
+
}
|
13661
|
+
|
13662
|
+
for (int64_t ic = 0; ic < nek1; ++ic) {
|
13663
|
+
// k indices
|
13664
|
+
const int ik3 = iq3;
|
13665
|
+
const int ik2 = iq2;
|
13666
|
+
const int ik1 = ic;
|
13667
|
+
|
13668
|
+
// S indices
|
13669
|
+
const int i1 = ik1;
|
13670
|
+
|
13671
|
+
ggml_vec_dot_f32(neq0,
|
13672
|
+
S + i1,
|
13673
|
+
(float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
13674
|
+
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
13675
|
+
}
|
13676
|
+
|
13677
|
+
// scale
|
13678
|
+
ggml_vec_scale_f32(nek1, S, scale);
|
13679
|
+
|
13680
|
+
if (masked) {
|
13681
|
+
for (int64_t i = P; i < M; i++) {
|
13682
|
+
if (i > P + iq1) {
|
13683
|
+
S[i] = -INFINITY;
|
13684
|
+
}
|
13685
|
+
}
|
13686
|
+
}
|
13687
|
+
|
13688
|
+
// softmax
|
13689
|
+
{
|
13690
|
+
float max = -INFINITY;
|
13691
|
+
ggml_vec_max_f32(M, &max, S);
|
13692
|
+
|
13693
|
+
ggml_float sum = 0.0;
|
13694
|
+
{
|
13695
|
+
#ifdef GGML_SOFT_MAX_ACCELERATE
|
13696
|
+
max = -max;
|
13697
|
+
vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
|
13698
|
+
vvexpf(SM, SM, &Mup);
|
13699
|
+
ggml_vec_sum_f32(Mup, &sum, SM);
|
13700
|
+
#else
|
13701
|
+
uint16_t scvt[GGML_SOFT_MAX_UNROLL];
|
13702
|
+
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
|
13703
|
+
|
13704
|
+
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
|
13705
|
+
float * SR = S + i;
|
13706
|
+
float * SW = SM + i;
|
13707
|
+
|
13708
|
+
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
|
13709
|
+
if (SR[j] == -INFINITY) {
|
13710
|
+
SW[j] = 0.0f;
|
13711
|
+
} else {
|
13712
|
+
ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
|
13713
|
+
memcpy(&scvt[j], &s, sizeof(uint16_t));
|
13714
|
+
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
|
13715
|
+
sump[j] += (ggml_float)val;
|
13716
|
+
SW[j] = val;
|
13717
|
+
}
|
13718
|
+
}
|
13719
|
+
}
|
13720
|
+
|
13721
|
+
for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
|
13722
|
+
sum += sump[i];
|
13723
|
+
}
|
13724
|
+
#endif
|
13725
|
+
}
|
13726
|
+
|
13727
|
+
assert(sum > 0.0);
|
13728
|
+
|
13729
|
+
sum = 1.0/sum;
|
13730
|
+
ggml_vec_scale_f32(M, SM, sum);
|
13731
|
+
|
13732
|
+
}
|
13733
|
+
|
13734
|
+
// step-by-step explanation
|
13735
|
+
{
|
13736
|
+
// forward-process shape grads from backward process
|
13737
|
+
// parallel_for iq2,iq3:
|
13738
|
+
// k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur]
|
13739
|
+
// q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
|
13740
|
+
// v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur]
|
13741
|
+
// for iq1:
|
13742
|
+
// kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
|
13743
|
+
// qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
|
13744
|
+
// vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
|
13745
|
+
// S0 = -Inf [D,1,1,1]
|
13746
|
+
// ~S1[i] = dot(kcur[:D,i], qcur)
|
13747
|
+
// S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
|
13748
|
+
// S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
|
13749
|
+
// S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
13750
|
+
// S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
|
13751
|
+
// ~S5[i] = dot(vcur[:,i], S4)
|
13752
|
+
// S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
|
13753
|
+
// ~dst[i,iq1,iq2,iq3] = S5[i] ^
|
13754
|
+
// dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
|
13755
|
+
// dst backward-/ grad[dst] = d
|
13756
|
+
//
|
13757
|
+
// output gradients with their dependencies:
|
13758
|
+
//
|
13759
|
+
// grad[kcur] = grad[S1].T @ qcur
|
13760
|
+
// grad[S1] = diag_mask_zero(grad[S3], P) * scale
|
13761
|
+
// grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
13762
|
+
// grad[S4] = grad[S5] @ vcur
|
13763
|
+
// grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
|
13764
|
+
// grad[qcur] = grad[S1] @ kcur
|
13765
|
+
// grad[vcur] = grad[S5].T @ S4
|
13766
|
+
// grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
|
13767
|
+
//
|
13768
|
+
// in post-order:
|
13769
|
+
//
|
13770
|
+
// S1 = qcur @ kcur.T
|
13771
|
+
// S2 = S1 * scale
|
13772
|
+
// S3 = diag_mask_inf(S2, P)
|
13773
|
+
// S4 = softmax(S3)
|
13774
|
+
// grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
|
13775
|
+
// grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
13776
|
+
// grad[S1] = diag_mask_zero(grad[S3], P) * scale
|
13777
|
+
// grad[qcur] = grad[S1] @ kcur
|
13778
|
+
// grad[kcur] = grad[S1].T @ qcur
|
13779
|
+
// grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
|
13780
|
+
//
|
13781
|
+
// using less variables (SM=S4):
|
13782
|
+
//
|
13783
|
+
// S = diag_mask_inf(qcur @ kcur.T * scale, P)
|
13784
|
+
// SM = softmax(S)
|
13785
|
+
// S = d[:D,iq1,iq2,iq3] @ vcur
|
13786
|
+
// dot_SM_gradSM = dot(SM, S)
|
13787
|
+
// S = SM * (S - dot(SM, S))
|
13788
|
+
// S = diag_mask_zero(S, P) * scale
|
13789
|
+
//
|
13790
|
+
// grad[q][:D,iq1,iq2,iq3] += S @ kcur
|
13791
|
+
// grad[k][:D,:M,iq2,iq3] += S.T @ qcur
|
13792
|
+
// grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
|
13793
|
+
}
|
13794
|
+
|
13795
|
+
// S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
|
13796
|
+
// S = d[:D,iq1,iq2,iq3] @ vcur
|
13797
|
+
// S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
|
13798
|
+
ggml_vec_set_f32(M, S, 0);
|
13799
|
+
for (int64_t ic = 0; ic < D; ++ic) {
|
13800
|
+
// dst indices
|
13801
|
+
const int i1 = iq1;
|
13802
|
+
const int i2 = iq2;
|
13803
|
+
const int i3 = iq3;
|
13804
|
+
|
13805
|
+
ggml_vec_mad_f32(M,
|
13806
|
+
S,
|
13807
|
+
(float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
|
13808
|
+
*(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
|
13809
|
+
}
|
13810
|
+
|
13811
|
+
// S = SM * (S - dot(SM, S))
|
13812
|
+
float dot_SM_gradSM = 0;
|
13813
|
+
ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
|
13814
|
+
ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
|
13815
|
+
ggml_vec_mul_f32 (M, S, S, SM);
|
13816
|
+
|
13817
|
+
// S = diag_mask_zero(S, P) * scale
|
13818
|
+
if (masked) {
|
13819
|
+
// for (int64_t i = P + iq1 + 1; i < M; i++) {
|
13820
|
+
// S[i] = 0;
|
13821
|
+
// }
|
13822
|
+
for (int64_t i = P; i < M; i++) {
|
13823
|
+
if (i > P + iq1) {
|
13824
|
+
S[i] = 0;
|
13825
|
+
}
|
13826
|
+
}
|
13827
|
+
}
|
13828
|
+
ggml_vec_scale_f32(M, S, scale);
|
13829
|
+
|
13830
|
+
void * grad_q = (char *) dst->data;
|
13831
|
+
void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
|
13832
|
+
void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
|
13833
|
+
|
13834
|
+
const size_t nbgq1 = nb0*neq0;
|
13835
|
+
const size_t nbgq2 = nb0*neq0*neq1;
|
13836
|
+
const size_t nbgq3 = nb0*neq0*neq1*neq2;
|
13837
|
+
|
13838
|
+
const size_t nbgk1 = nb0*nek0;
|
13839
|
+
const size_t nbgk2 = nb0*nek0*nek1;
|
13840
|
+
const size_t nbgk3 = nb0*nek0*nek1*neq2;
|
13841
|
+
|
13842
|
+
const size_t nbgv1 = nb0*nev0;
|
13843
|
+
const size_t nbgv2 = nb0*nev0*nev1;
|
13844
|
+
const size_t nbgv3 = nb0*nev0*nev1*neq2;
|
13845
|
+
|
13846
|
+
// S shape [M,1]
|
13847
|
+
// SM shape [M,1]
|
13848
|
+
// kcur shape [D,M]
|
13849
|
+
// qcur shape [D,1]
|
13850
|
+
// vcur shape [M,D]
|
13851
|
+
//
|
13852
|
+
// grad[q][:D,iq1,iq2,iq3] += S @ kcur
|
13853
|
+
// grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
|
13854
|
+
// grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
|
13855
|
+
//
|
13856
|
+
//// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
|
13857
|
+
//// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
|
13858
|
+
for (int64_t ic = 0; ic < M; ++ic) {
|
13859
|
+
// dst indices
|
13860
|
+
const int i1 = iq1;
|
13861
|
+
const int i2 = iq2;
|
13862
|
+
const int i3 = iq3;
|
13863
|
+
|
13864
|
+
ggml_vec_mad_f32(D,
|
13865
|
+
(float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
|
13866
|
+
(float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)),
|
13867
|
+
S[ic]);
|
13868
|
+
}
|
13869
|
+
|
13870
|
+
// grad[k][:D,:M,iq2,iq3] += S.T @ qcur
|
13871
|
+
// grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
|
13872
|
+
// grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
|
13873
|
+
for (int64_t ic = 0; ic < M; ++ic) {
|
13874
|
+
// dst indices
|
13875
|
+
const int i1 = iq1;
|
13876
|
+
const int i2 = iq2;
|
13877
|
+
const int i3 = iq3;
|
13878
|
+
|
13879
|
+
// ggml_vec_set_f32(D,
|
13880
|
+
// (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
|
13881
|
+
// 0);
|
13882
|
+
ggml_vec_mad_f32(D,
|
13883
|
+
(float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
|
13884
|
+
(float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)),
|
13885
|
+
S[ic]);
|
13886
|
+
}
|
13887
|
+
|
13888
|
+
// grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
|
13889
|
+
// grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
|
13890
|
+
// grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M]
|
13891
|
+
for (int64_t ic = 0; ic < D; ++ic) {
|
13892
|
+
// dst indices
|
13893
|
+
const int i1 = iq1;
|
13894
|
+
const int i2 = iq2;
|
13895
|
+
const int i3 = iq3;
|
13896
|
+
|
13897
|
+
// ggml_vec_set_f32(M,
|
13898
|
+
// (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
|
13899
|
+
// 0);
|
13900
|
+
ggml_vec_mad_f32(M,
|
13901
|
+
(float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
|
13902
|
+
SM,
|
13903
|
+
*(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
|
13904
|
+
}
|
13905
|
+
}
|
13906
|
+
}
|
13907
|
+
}
|
13908
|
+
|
13909
|
+
static void ggml_compute_forward_flash_attn_back(
|
13910
|
+
const struct ggml_compute_params * params,
|
13911
|
+
const struct ggml_tensor * q,
|
13912
|
+
const struct ggml_tensor * k,
|
13913
|
+
const struct ggml_tensor * v,
|
13914
|
+
const struct ggml_tensor * d,
|
13915
|
+
const bool masked,
|
13916
|
+
struct ggml_tensor * dst) {
|
13917
|
+
switch (q->type) {
|
13918
|
+
case GGML_TYPE_F32:
|
13919
|
+
{
|
13920
|
+
ggml_compute_forward_flash_attn_back_f32(params, q, k, v, d, masked, dst);
|
13921
|
+
} break;
|
13922
|
+
default:
|
13923
|
+
{
|
13924
|
+
GGML_ASSERT(false);
|
13925
|
+
} break;
|
13926
|
+
}
|
13927
|
+
}
|
13928
|
+
|
13929
|
+
// ggml_compute_forward_map_unary
|
13930
|
+
|
13931
|
+
static void ggml_compute_forward_map_unary_f32(
|
13932
|
+
const struct ggml_compute_params * params,
|
13933
|
+
const struct ggml_tensor * src0,
|
13934
|
+
struct ggml_tensor * dst,
|
13935
|
+
const ggml_unary_op_f32_t fun) {
|
13936
|
+
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
13937
|
+
|
13938
|
+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
13939
|
+
return;
|
13940
|
+
}
|
13941
|
+
|
13942
|
+
const int n = ggml_nrows(src0);
|
13943
|
+
const int nc = src0->ne[0];
|
13944
|
+
|
13945
|
+
assert( dst->nb[0] == sizeof(float));
|
13946
|
+
assert(src0->nb[0] == sizeof(float));
|
13947
|
+
|
13948
|
+
for (int i = 0; i < n; i++) {
|
13949
|
+
fun(nc,
|
13950
|
+
(float *) ((char *) dst->data + i*( dst->nb[1])),
|
13951
|
+
(float *) ((char *) src0->data + i*(src0->nb[1])));
|
13952
|
+
}
|
13953
|
+
}
|
13954
|
+
|
13955
|
+
|
13956
|
+
static void ggml_compute_forward_map_unary(
|
13957
|
+
const struct ggml_compute_params * params,
|
13958
|
+
const struct ggml_tensor * src0,
|
13959
|
+
struct ggml_tensor * dst,
|
13960
|
+
const ggml_unary_op_f32_t fun) {
|
13961
|
+
switch (src0->type) {
|
13962
|
+
case GGML_TYPE_F32:
|
13963
|
+
{
|
13964
|
+
ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
|
13965
|
+
} break;
|
13966
|
+
default:
|
13967
|
+
{
|
13968
|
+
GGML_ASSERT(false);
|
13969
|
+
} break;
|
13970
|
+
}
|
13971
|
+
}
|
13972
|
+
|
13973
|
+
// ggml_compute_forward_map_binary
|
13974
|
+
|
13975
|
+
static void ggml_compute_forward_map_binary_f32(
|
13976
|
+
const struct ggml_compute_params * params,
|
13977
|
+
const struct ggml_tensor * src0,
|
13978
|
+
const struct ggml_tensor * src1,
|
13979
|
+
struct ggml_tensor * dst,
|
13980
|
+
const ggml_binary_op_f32_t fun) {
|
13981
|
+
assert(params->ith == 0);
|
13982
|
+
assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
|
13983
|
+
|
13984
|
+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
13985
|
+
return;
|
13986
|
+
}
|
13987
|
+
|
13988
|
+
const int n = ggml_nrows(src0);
|
13989
|
+
const int nc = src0->ne[0];
|
13990
|
+
|
13991
|
+
assert( dst->nb[0] == sizeof(float));
|
13992
|
+
assert(src0->nb[0] == sizeof(float));
|
13993
|
+
assert(src1->nb[0] == sizeof(float));
|
13994
|
+
|
13995
|
+
for (int i = 0; i < n; i++) {
|
13996
|
+
fun(nc,
|
13997
|
+
(float *) ((char *) dst->data + i*( dst->nb[1])),
|
13998
|
+
(float *) ((char *) src0->data + i*(src0->nb[1])),
|
13999
|
+
(float *) ((char *) src1->data + i*(src1->nb[1])));
|
14000
|
+
}
|
14001
|
+
}
|
14002
|
+
|
14003
|
+
|
14004
|
+
static void ggml_compute_forward_map_binary(
|
14005
|
+
const struct ggml_compute_params * params,
|
14006
|
+
const struct ggml_tensor * src0,
|
14007
|
+
const struct ggml_tensor * src1,
|
14008
|
+
struct ggml_tensor * dst,
|
14009
|
+
const ggml_binary_op_f32_t fun) {
|
14010
|
+
switch (src0->type) {
|
14011
|
+
case GGML_TYPE_F32:
|
14012
|
+
{
|
14013
|
+
ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
|
14014
|
+
} break;
|
14015
|
+
default:
|
14016
|
+
{
|
14017
|
+
GGML_ASSERT(false);
|
14018
|
+
} break;
|
14019
|
+
}
|
14020
|
+
}
|
14021
|
+
|
14022
|
+
// ggml_compute_forward_cross_entropy_loss
|
14023
|
+
|
14024
|
+
static void ggml_compute_forward_cross_entropy_loss_f32(
|
14025
|
+
const struct ggml_compute_params * params,
|
14026
|
+
const struct ggml_tensor * src0,
|
14027
|
+
const struct ggml_tensor * src1,
|
14028
|
+
struct ggml_tensor * dst) {
|
14029
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
14030
|
+
GGML_ASSERT(ggml_is_contiguous(src1));
|
14031
|
+
GGML_ASSERT(ggml_is_scalar(dst));
|
14032
|
+
GGML_ASSERT(ggml_are_same_shape(src0, src1));
|
14033
|
+
|
14034
|
+
const int ith = params->ith;
|
14035
|
+
const int nth = params->nth;
|
14036
|
+
|
14037
|
+
float * sums = (float *) params->wdata;
|
14038
|
+
|
14039
|
+
// TODO: handle transposed/permuted matrices
|
14040
|
+
const int nc = src0->ne[0];
|
14041
|
+
const int nr = ggml_nrows(src0);
|
14042
|
+
|
14043
|
+
if (params->type == GGML_TASK_INIT) {
|
14044
|
+
if (ith == 0) {
|
14045
|
+
memset(sums, 0, sizeof(float) * (nth + nth * nc));
|
14046
|
+
}
|
14047
|
+
return;
|
14048
|
+
}
|
14049
|
+
|
14050
|
+
if (params->type == GGML_TASK_FINALIZE) {
|
14051
|
+
if (ith == 0) {
|
14052
|
+
float * dp = (float *) dst->data;
|
14053
|
+
ggml_vec_sum_f32(nth, dp, sums);
|
14054
|
+
dp[0] *= -1.0f;
|
14055
|
+
}
|
14056
|
+
return;
|
14057
|
+
}
|
14058
|
+
|
14059
|
+
const double eps = 1e-9;
|
14060
|
+
|
14061
|
+
// rows per thread
|
14062
|
+
const int dr = (nr + nth - 1)/nth;
|
14063
|
+
|
14064
|
+
// row range for this thread
|
14065
|
+
const int ir0 = dr*ith;
|
14066
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
14067
|
+
|
14068
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
14069
|
+
float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
|
14070
|
+
float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
|
14071
|
+
float * st = (float *) params->wdata + nth + ith*nc;
|
14072
|
+
|
14073
|
+
#ifndef NDEBUG
|
14074
|
+
for (int i = 0; i < nc; ++i) {
|
14075
|
+
//printf("p[%d] = %f\n", i, p[i]);
|
14076
|
+
assert(!isnan(s0[i]));
|
14077
|
+
assert(!isnan(s1[i]));
|
14078
|
+
}
|
14079
|
+
#endif
|
14080
|
+
// soft_max
|
14081
|
+
ggml_float sum = 0.0;
|
14082
|
+
{
|
14083
|
+
float max = -INFINITY;
|
14084
|
+
ggml_vec_max_f32(nc, &max, s0);
|
14085
|
+
|
14086
|
+
uint16_t scvt;
|
14087
|
+
for (int i = 0; i < nc; i++) {
|
14088
|
+
if (s0[i] == -INFINITY) {
|
14089
|
+
st[i] = 0.0f;
|
14090
|
+
} else {
|
14091
|
+
// const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
|
14092
|
+
ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
|
14093
|
+
memcpy(&scvt, &s, sizeof(scvt));
|
14094
|
+
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
|
14095
|
+
sum += (ggml_float)val;
|
14096
|
+
st[i] = val;
|
14097
|
+
}
|
14098
|
+
}
|
14099
|
+
|
14100
|
+
assert(sum > 0.0);
|
14101
|
+
// sum = 1.0/sum;
|
14102
|
+
}
|
14103
|
+
// avoid log(0) by rescaling from [0..1] to [eps..1]
|
14104
|
+
sum = (1.0 - eps) / sum;
|
14105
|
+
ggml_vec_scale_f32(nc, st, sum);
|
14106
|
+
ggml_vec_add1_f32(nc, st, st, eps);
|
14107
|
+
ggml_vec_log_f32(nc, st, st);
|
14108
|
+
ggml_vec_mul_f32(nc, st, st, s1);
|
14109
|
+
|
14110
|
+
ggml_vec_sum_f32(nc, sums + ith, st);
|
14111
|
+
|
14112
|
+
#ifndef NDEBUG
|
14113
|
+
for (int i = 0; i < nc; ++i) {
|
14114
|
+
assert(!isnan(st[i]));
|
14115
|
+
assert(!isinf(st[i]));
|
14116
|
+
}
|
14117
|
+
#endif
|
14118
|
+
}
|
14119
|
+
|
14120
|
+
}
|
14121
|
+
|
14122
|
+
static void ggml_compute_forward_cross_entropy_loss(
|
14123
|
+
const struct ggml_compute_params * params,
|
14124
|
+
const struct ggml_tensor * src0,
|
14125
|
+
const struct ggml_tensor * src1,
|
14126
|
+
struct ggml_tensor * dst) {
|
12973
14127
|
switch (src0->type) {
|
12974
14128
|
case GGML_TYPE_F32:
|
12975
14129
|
{
|
12976
|
-
|
14130
|
+
ggml_compute_forward_cross_entropy_loss_f32(params, src0, src1, dst);
|
12977
14131
|
} break;
|
12978
14132
|
default:
|
12979
14133
|
{
|
@@ -12982,47 +14136,160 @@ static void ggml_compute_forward_map_unary(
|
|
12982
14136
|
}
|
12983
14137
|
}
|
12984
14138
|
|
12985
|
-
//
|
14139
|
+
// ggml_compute_forward_cross_entropy_loss_back
|
12986
14140
|
|
12987
|
-
static void
|
14141
|
+
static void ggml_compute_forward_cross_entropy_loss_back_f32(
|
12988
14142
|
const struct ggml_compute_params * params,
|
12989
14143
|
const struct ggml_tensor * src0,
|
12990
14144
|
const struct ggml_tensor * src1,
|
12991
|
-
struct ggml_tensor *
|
12992
|
-
|
12993
|
-
|
12994
|
-
|
14145
|
+
const struct ggml_tensor * opt0,
|
14146
|
+
struct ggml_tensor * dst) {
|
14147
|
+
GGML_ASSERT(ggml_is_contiguous(dst));
|
14148
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
14149
|
+
GGML_ASSERT(ggml_is_contiguous(src1));
|
14150
|
+
GGML_ASSERT(ggml_is_contiguous(opt0));
|
14151
|
+
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
|
14152
|
+
|
14153
|
+
const int64_t ith = params->ith;
|
14154
|
+
const int64_t nth = params->nth;
|
12995
14155
|
|
12996
14156
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
12997
14157
|
return;
|
12998
14158
|
}
|
12999
14159
|
|
13000
|
-
const
|
13001
|
-
const int nc = src0->ne[0];
|
14160
|
+
const float eps = 1e-9f;
|
13002
14161
|
|
13003
|
-
|
13004
|
-
|
13005
|
-
|
14162
|
+
// TODO: handle transposed/permuted matrices
|
14163
|
+
const int64_t nc = src0->ne[0];
|
14164
|
+
const int64_t nr = ggml_nrows(src0);
|
13006
14165
|
|
13007
|
-
|
13008
|
-
|
13009
|
-
|
13010
|
-
|
13011
|
-
|
14166
|
+
// rows per thread
|
14167
|
+
const int64_t dr = (nr + nth - 1)/nth;
|
14168
|
+
|
14169
|
+
// row range for this thread
|
14170
|
+
const int64_t ir0 = dr*ith;
|
14171
|
+
const int64_t ir1 = MIN(ir0 + dr, nr);
|
14172
|
+
|
14173
|
+
float * d = (float *) opt0->data;
|
14174
|
+
|
14175
|
+
for (int64_t i1 = ir0; i1 < ir1; i1++) {
|
14176
|
+
float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
|
14177
|
+
float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
|
14178
|
+
float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
|
14179
|
+
float * sm = (float *) params->wdata + ith*nc;
|
14180
|
+
|
14181
|
+
#ifndef NDEBUG
|
14182
|
+
for (int i = 0; i < nc; ++i) {
|
14183
|
+
//printf("p[%d] = %f\n", i, p[i]);
|
14184
|
+
assert(!isnan(s0[i]));
|
14185
|
+
assert(!isnan(s1[i]));
|
14186
|
+
}
|
14187
|
+
#endif
|
14188
|
+
// step by step explanation:
|
14189
|
+
{
|
14190
|
+
//float * sums = (float *) params->wdata;
|
14191
|
+
|
14192
|
+
// forward pass with annotated gradients from backward pass
|
14193
|
+
// (built by going in reverse operation order, adding to gradients of current operation args)
|
14194
|
+
// st0 = exp(s0-max(s0)) grad[st0] = grad[st1]*(1.0 - eps)/sum
|
14195
|
+
// from softmax_back: grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
|
14196
|
+
// ggml_vec_scale_f32(nc, st, sum); // st1 = st0*/sum = softmax(s0) grad[st1] = grad[st2]*(1.0 - eps)
|
14197
|
+
// ggml_vec_scale_f32(nc, st, (1.0f - eps)); // st2 = st1*(1.0 - eps) grad[st2] = grad[st3]
|
14198
|
+
// ggml_vec_add1_f32(nc, st, st, eps); // st3 = st2 + eps grad[st3] = grad[st4]/st3
|
14199
|
+
// ggml_vec_log_f32(nc, st, st); // st4 = log(st3) grad[st4] = grad[st5] * s1
|
14200
|
+
// ggml_vec_mul_f32(nc, st, st, s1); // st5 = st4 * s1 grad[st5] = grad[sums[ith]]
|
14201
|
+
// ggml_vec_sum_f32(nc, sums + ith, st); // sums[ith] = st5 grad[sums[ith]] = grad[cross_entropy_loss] = -grad[cel]
|
14202
|
+
|
14203
|
+
// substitute into grad[st1], because we can reuse softmax_back from this point on
|
14204
|
+
// grad[st1] = -grad[cel]*s1*(1.0 - eps)/(eps + softmax(s0)*(1.0 - eps))
|
14205
|
+
// postorder:
|
14206
|
+
// grad[st1] := softmax(s0)
|
14207
|
+
// grad[st1] := grad[st1]*(1.0 - eps)
|
14208
|
+
// grad[st1] := grad[st1] + eps
|
14209
|
+
// grad[st1] := s1 / grad[st1]
|
14210
|
+
// grad[st1] := grad[st1]*(1.0-eps)*-grad[cel]
|
14211
|
+
|
14212
|
+
// src0 gradients by going through softmax_back
|
14213
|
+
// grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
|
14214
|
+
// from softmax_back:
|
14215
|
+
// dxk = yk * (dyk - dot(y, dy))
|
14216
|
+
// dot_y_dy := dot(y, dy)
|
14217
|
+
// dx := dy
|
14218
|
+
// dx := dx - dot_y_dy
|
14219
|
+
// dx := dx * y
|
14220
|
+
// postorder:
|
14221
|
+
// dot_st1_dst1 := dot(st1, grad[st1])
|
14222
|
+
// grad[s0] := grad[st1]
|
14223
|
+
// grad[s0] := grad[s0] - dot_st1_dst1
|
14224
|
+
// grad[s0] := grad[s0] * st1
|
14225
|
+
|
14226
|
+
// prepend postorder from grad[st1] directly using grad[s0] as memory location, as we will grad[s0] := grad[st1]
|
14227
|
+
// sm := softmax(s0)
|
14228
|
+
// grad[s0] := sm*(1.0 - eps)
|
14229
|
+
// grad[s0] := grad[s0] + eps
|
14230
|
+
// grad[s0] := s1 / grad[s0]
|
14231
|
+
// grad[s0] := grad[s0]*(1.0-eps)*-grad[cel]
|
14232
|
+
// dot_st1_dst1 := dot(sm, grad[s0])
|
14233
|
+
// grad[s0] := grad[s0] - dot_st1_dst1
|
14234
|
+
// grad[s0] := grad[s0] * sm
|
14235
|
+
}
|
14236
|
+
|
14237
|
+
// soft_max
|
14238
|
+
ggml_float sum = 0.0;
|
14239
|
+
{
|
14240
|
+
float max = -INFINITY;
|
14241
|
+
ggml_vec_max_f32(nc, &max, s0);
|
14242
|
+
|
14243
|
+
uint16_t scvt;
|
14244
|
+
for (int i = 0; i < nc; i++) {
|
14245
|
+
if (s0[i] == -INFINITY) {
|
14246
|
+
sm[i] = 0.0f;
|
14247
|
+
} else {
|
14248
|
+
// const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
|
14249
|
+
ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
|
14250
|
+
memcpy(&scvt, &s, sizeof(scvt));
|
14251
|
+
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
|
14252
|
+
sum += (ggml_float)val;
|
14253
|
+
sm[i] = val;
|
14254
|
+
}
|
14255
|
+
}
|
14256
|
+
|
14257
|
+
assert(sum > 0.0);
|
14258
|
+
sum = 1.0/sum;
|
14259
|
+
}
|
14260
|
+
|
14261
|
+
float dot_st1_dst1 = 0;
|
14262
|
+
ggml_vec_scale_f32(nc, sm, sum);
|
14263
|
+
ggml_vec_cpy_f32 (nc, ds0, sm);
|
14264
|
+
ggml_vec_scale_f32(nc, ds0, (1.0f - eps));
|
14265
|
+
ggml_vec_add1_f32 (nc, ds0, ds0, eps);
|
14266
|
+
ggml_vec_div_f32 (nc, ds0, s1, ds0);
|
14267
|
+
ggml_vec_scale_f32(nc, ds0, -(1.0f - eps)*d[0]);
|
14268
|
+
ggml_vec_dot_f32 (nc, &dot_st1_dst1, sm, ds0);
|
14269
|
+
ggml_vec_acc1_f32 (nc, ds0, -dot_st1_dst1);
|
14270
|
+
ggml_vec_mul_f32 (nc, ds0, ds0, sm);
|
14271
|
+
|
14272
|
+
#ifndef NDEBUG
|
14273
|
+
for (int i = 0; i < nc; ++i) {
|
14274
|
+
assert(!isnan(sm[i]));
|
14275
|
+
assert(!isinf(sm[i]));
|
14276
|
+
assert(!isnan(ds0[i]));
|
14277
|
+
assert(!isinf(ds0[i]));
|
14278
|
+
}
|
14279
|
+
#endif
|
13012
14280
|
}
|
13013
14281
|
}
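The postorder comments above compress a short derivation. Written out, with y = softmax(s0), the eps-rescaled probabilities p, and d the incoming scalar gradient read from opt0:

$$
p_i = \epsilon + (1-\epsilon)\,y_i, \qquad
L = -\sum_i s_{1,i}\,\log p_i, \qquad
g_i \;\equiv\; d\,\frac{\partial L}{\partial y_i} \;=\; -\,d\,(1-\epsilon)\,\frac{s_{1,i}}{p_i}
$$

$$
\frac{\partial y_i}{\partial s_{0,k}} = y_i\,(\delta_{ik} - y_k)
\quad\Longrightarrow\quad
d\,\frac{\partial L}{\partial s_{0,k}}
  = \sum_i g_i\, y_i\,(\delta_{ik} - y_k)
  = y_k\Big(g_k - \sum_i y_i\, g_i\Big)
$$

which is the soft_max_back form y * (g - dot(y, g)) that the final ggml_vec_dot / ggml_vec_acc1 / ggml_vec_mul calls evaluate.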
 
-
-static void ggml_compute_forward_map_binary(
+static void ggml_compute_forward_cross_entropy_loss_back(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        struct ggml_tensor *
-
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-
+                ggml_compute_forward_cross_entropy_loss_back_f32(params, src0, src1, opt0, dst);
             } break;
         default:
             {
@@ -13031,6 +14298,7 @@ static void ggml_compute_forward_map_binary(
|
|
13031
14298
|
}
|
13032
14299
|
}
|
13033
14300
|
|
14301
|
+
|
13034
14302
|
/////////////////////////////////
|
13035
14303
|
|
13036
14304
|
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
|
@@ -13102,6 +14370,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
13102
14370
|
{
|
13103
14371
|
ggml_compute_forward_repeat(params, tensor->src0, tensor);
|
13104
14372
|
} break;
|
+        case GGML_OP_REPEAT_BACK:
+            {
+                ggml_compute_forward_repeat_back(params, tensor->src0, tensor);
+            } break;
|
13105
14377
|
case GGML_OP_ABS:
|
13106
14378
|
{
|
13107
14379
|
ggml_compute_forward_abs(params, tensor->src0, tensor);
|
@@ -13150,6 +14422,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
13150
14422
|
{
|
13151
14423
|
ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
|
13152
14424
|
} break;
|
+        case GGML_OP_OUT_PROD:
+            {
+                ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor);
+            } break;
|
13153
14429
|
case GGML_OP_SCALE:
|
13154
14430
|
{
|
13155
14431
|
ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
|
@@ -13206,6 +14482,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
13206
14482
|
{
|
13207
14483
|
ggml_compute_forward_soft_max(params, tensor->src0, tensor);
|
13208
14484
|
} break;
|
+        case GGML_OP_SOFT_MAX_BACK:
+            {
+                ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
+            } break;
|
13209
14489
|
case GGML_OP_ROPE:
|
13210
14490
|
{
|
13211
14491
|
ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
|
@@ -13241,6 +14521,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
13241
14521
|
{
|
13242
14522
|
ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
|
13243
14523
|
} break;
|
+        case GGML_OP_FLASH_ATTN_BACK:
+            {
+                int32_t t = ggml_get_i32_1d(tensor->opt[2], 0);
+                GGML_ASSERT(t == 0 || t == 1);
+                bool masked = t != 0;
+                ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor);
+            } break;
|
13244
14531
|
case GGML_OP_MAP_UNARY:
|
13245
14532
|
{
|
13246
14533
|
const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
|
@@ -13253,6 +14540,16 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
13253
14540
|
ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
|
13254
14541
|
}
|
13255
14542
|
break;
|
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            {
+                ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor);
+            }
+            break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            {
+                ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
+            }
+            break;
|
13256
14553
|
case GGML_OP_NONE:
|
13257
14554
|
{
|
13258
14555
|
// nop
|
@@ -13391,11 +14688,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
13391
14688
|
src0->grad =
|
13392
14689
|
ggml_add_impl(ctx,
|
13393
14690
|
src0->grad,
|
13394
|
-
|
13395
|
-
tensor->grad, // this was not catched by test_grad because in test_grad tensor->grad is 1
|
14691
|
+
ggml_scale(ctx,
|
13396
14692
|
ggml_div(ctx,
|
13397
|
-
|
13398
|
-
tensor)
|
14693
|
+
tensor->grad,
|
14694
|
+
tensor),
|
14695
|
+
ggml_new_f32(ctx, 0.5f)),
|
13399
14696
|
inplace);
|
13400
14697
|
}
|
13401
14698
|
} break;
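The corrected gradient built above is just the chain rule for the square root:

$$
y = \sqrt{x}
\;\Rightarrow\;
\frac{\partial y}{\partial x} = \frac{1}{2\sqrt{x}} = \frac{1}{2y}
\;\Rightarrow\;
\bar{x} \;{+}{=}\; \tfrac{1}{2}\,\frac{\bar{y}}{y}
$$

which is what the ggml_scale(ggml_div(tensor->grad, tensor), 0.5) chain accumulates into src0->grad.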
|
@@ -13441,43 +14738,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
13441
14738
|
{
|
13442
14739
|
// necessary for llama
|
13443
14740
|
if (src0->grad) {
|
13444
|
-
|
13445
|
-
|
13446
|
-
|
13447
|
-
|
13448
|
-
|
13449
|
-
|
13450
|
-
|
13451
|
-
|
13452
|
-
|
13453
|
-
//
|
13454
|
-
|
13455
|
-
|
13456
|
-
|
13457
|
-
|
13458
|
-
// transpose [nc0*nr0,1,1]
|
13459
|
-
// reshape [nc0,nr0,1,1] reshape_1d or reshape_2d
|
13460
|
-
// add to src0->grad
|
13461
|
-
|
13462
|
-
int64_t ne[4] = {nc0,ncr,nr0,nrr};
|
13463
|
-
|
13464
|
-
struct ggml_tensor* F00 = tensor->grad;
|
13465
|
-
struct ggml_tensor* F01 = ggml_reshape (ctx, F00, ggml_new_tensor(ctx,tensor->grad->type,4,ne));
|
13466
|
-
struct ggml_tensor* F02 = ggml_permute (ctx, F01, 0,2,1,3);
|
13467
|
-
struct ggml_tensor* F03 = ggml_cont (ctx, F02);
|
13468
|
-
struct ggml_tensor* F04 = ggml_reshape_2d(ctx, F03, nc0*nr0, ncr*nrr);
|
13469
|
-
struct ggml_tensor* F05 = ggml_transpose (ctx, F04);
|
13470
|
-
struct ggml_tensor* F06 = ggml_cont (ctx, F05);
|
13471
|
-
struct ggml_tensor* F07 = ggml_sum_rows (ctx, F06);
|
13472
|
-
struct ggml_tensor* F08 = ggml_transpose (ctx, F07);
|
13473
|
-
struct ggml_tensor* F09 = ggml_cont (ctx, F08);
|
13474
|
-
struct ggml_tensor* F10 = ggml_reshape (ctx, F09, src0->grad);
|
13475
|
-
|
13476
|
-
src0->grad =
|
13477
|
-
ggml_add_impl(ctx,
|
13478
|
-
src0->grad,
|
13479
|
-
F10,
|
13480
|
-
inplace);
|
14741
|
+
src0->grad = ggml_add_impl(ctx,
|
14742
|
+
src0->grad,
|
14743
|
+
ggml_repeat_back(ctx, tensor->grad, src0->grad),
|
14744
|
+
inplace);
|
14745
|
+
}
|
14746
|
+
} break;
|
14747
|
+
case GGML_OP_REPEAT_BACK:
|
14748
|
+
{
|
14749
|
+
if (src0->grad) {
|
14750
|
+
// TODO: test this
|
14751
|
+
src0->grad = ggml_add_impl(ctx,
|
14752
|
+
src0->grad,
|
14753
|
+
ggml_repeat(ctx, tensor->grad, src0->grad),
|
14754
|
+
inplace);
|
13481
14755
|
}
|
13482
14756
|
} break;
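Because ggml_repeat broadcasts a small tensor across a larger shape, its gradient must sum every copy's gradient back into the originating element, which is what ggml_repeat_back expresses directly in place of the removed reshape/permute/sum_rows chain. A minimal 1-D sketch of those semantics (plain C, illustrative names):

    #include <stdio.h>

    /* Forward: repeat src (n elements) k times into dst (n*k elements). */
    static void repeat_1d(const float * src, int n, int k, float * dst) {
        for (int r = 0; r < k; ++r)
            for (int i = 0; i < n; ++i)
                dst[r*n + i] = src[i];
    }

    /* Backward: sum the gradient of every copy back into grad_src (n elements). */
    static void repeat_back_1d(const float * grad_dst, int n, int k, float * grad_src) {
        for (int i = 0; i < n; ++i) grad_src[i] = 0.0f;
        for (int r = 0; r < k; ++r)
            for (int i = 0; i < n; ++i)
                grad_src[i] += grad_dst[r*n + i];
    }

    int main(void) {
        float src[2] = {10.0f, 20.0f}, dst[6];
        repeat_1d(src, 2, 3, dst);                        /* 10 20 10 20 10 20 */

        float grad_dst[6] = {1, 2, 3, 4, 5, 6};           /* upstream gradient */
        float grad_src[2];
        repeat_back_1d(grad_dst, 2, 3, grad_src);
        printf("%.1f %.1f\n", grad_src[0], grad_src[1]);  /* 9.0 12.0 */
        return 0;
    }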
|
13483
14757
|
case GGML_OP_ABS:
|
@@ -13584,38 +14858,37 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
13584
14858
|
|
13585
14859
|
// necessary for llama
|
13586
14860
|
if (src0->grad) {
|
13587
|
-
// TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
|
13588
14861
|
src0->grad =
|
13589
14862
|
ggml_add_impl(ctx,
|
13590
14863
|
src0->grad,
|
13591
|
-
//
|
13592
|
-
|
13593
|
-
|
13594
|
-
// tensor->grad), // [m,p]
|
13595
|
-
// for now just using A*B==(B.T*A.T).T
|
13596
|
-
ggml_cont(ctx, // [n,m]
|
13597
|
-
ggml_transpose(ctx, // [n,m]
|
13598
|
-
ggml_mul_mat(ctx, // [m,n]
|
13599
|
-
ggml_cont(ctx, // [p,m]
|
13600
|
-
ggml_transpose(ctx, // [p,m]
|
13601
|
-
tensor->grad)), // [m,p]
|
13602
|
-
ggml_cont(ctx, // [p,n]
|
13603
|
-
ggml_transpose(ctx, // [p,n]
|
13604
|
-
src1))))), // [n,p]
|
14864
|
+
ggml_out_prod(ctx, // [n,m]
|
14865
|
+
src1, // [n,p]
|
14866
|
+
tensor->grad), // [m,p]
|
13605
14867
|
inplace);
|
13606
14868
|
}
|
13607
14869
|
if (src1->grad) {
|
13608
14870
|
src1->grad =
|
13609
14871
|
ggml_add_impl(ctx,
|
13610
14872
|
src1->grad,
|
13611
|
-
//
|
13612
|
-
|
13613
|
-
|
13614
|
-
|
13615
|
-
|
14873
|
+
// ggml_mul_mat(ctx, // [n,p]
|
14874
|
+
// ggml_cont(ctx, // [m,n]
|
14875
|
+
// ggml_transpose(ctx, src0)), // [m,n]
|
14876
|
+
// tensor->grad), // [m,p]
|
14877
|
+
|
14878
|
+
// // when src0 is bigger than tensor->grad (this is mostly the case in llama),
|
14879
|
+
// // avoid transpose of src0, rather transpose smaller tensor->grad
|
14880
|
+
// // and then use ggml_out_prod
|
14881
|
+
ggml_out_prod(ctx, // [n,p]
|
14882
|
+
src0, // [n,m]
|
14883
|
+
ggml_transpose(ctx, // [p,m]
|
14884
|
+
tensor->grad)), // [m,p]
|
13616
14885
|
inplace);
|
13617
14886
|
}
|
13618
14887
|
} break;
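For a matrix product Z = X·Y the gradients are dX = dZ·Yᵀ and dY = Xᵀ·dZ, and a product against a transposed factor such as dZ·Yᵀ can be accumulated as a sum of outer products of matching columns — the shape of computation ggml_out_prod provides, which lets the backward pass above drop the cont/transpose/mul_mat chain it replaces. A plain-C, row-major check of that column-outer-product identity on a tiny example (illustrative layout, not ggml's memory order):

    #include <stdio.h>
    #include <math.h>

    #define M 2
    #define N 3
    #define P 2

    int main(void) {
        const float Y[N][P]  = {{1, 2}, {3, 4}, {5, 6}};
        const float dZ[M][P] = {{1, 0}, {2, 1}};

        float dX_matmul[M][N] = {{0}};
        float dX_outer [M][N] = {{0}};

        /* dX = dZ * Y^T */
        for (int i = 0; i < M; ++i)
            for (int j = 0; j < N; ++j)
                for (int k = 0; k < P; ++k)
                    dX_matmul[i][j] += dZ[i][k] * Y[j][k];

        /* same result as a sum over k of outer products of the k-th columns of dZ and Y */
        for (int k = 0; k < P; ++k)
            for (int i = 0; i < M; ++i)
                for (int j = 0; j < N; ++j)
                    dX_outer[i][j] += dZ[i][k] * Y[j][k];

        for (int i = 0; i < M; ++i)
            for (int j = 0; j < N; ++j)
                if (fabsf(dX_matmul[i][j] - dX_outer[i][j]) > 1e-6f)
                    return 1;

        printf("matmul-transpose and outer-product accumulation agree\n");
        return 0;
    }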
+        case GGML_OP_OUT_PROD:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
|
13619
14892
|
case GGML_OP_SCALE:
|
13620
14893
|
{
|
13621
14894
|
// necessary for llama
|
@@ -13717,7 +14990,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
13717
14990
|
// necessary for llama
|
13718
14991
|
if (src0->grad) {
|
13719
14992
|
size_t offset;
|
13720
|
-
|
14993
|
+
|
14994
|
+
GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->opt[0]));
|
14995
|
+
memcpy(&offset, tensor->opt[0]->data, sizeof(offset));
|
13721
14996
|
|
13722
14997
|
size_t nb1 = tensor->nb[1];
|
13723
14998
|
size_t nb2 = tensor->nb[2];
|
@@ -13744,10 +15019,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
13744
15019
|
{
|
13745
15020
|
// necessary for llama
|
13746
15021
|
if (src0->grad) {
|
13747
|
-
|
13748
|
-
int
|
13749
|
-
int
|
13750
|
-
int
|
15022
|
+
int32_t * axes = (int32_t *) tensor->opt[0]->data;
|
15023
|
+
int axis0 = axes[0] & 0x3;
|
15024
|
+
int axis1 = axes[1] & 0x3;
|
15025
|
+
int axis2 = axes[2] & 0x3;
|
15026
|
+
int axis3 = axes[3] & 0x3;
|
13751
15027
|
int axes_backward[4] = {0,0,0,0};
|
13752
15028
|
axes_backward[axis0] = 0;
|
13753
15029
|
axes_backward[axis1] = 1;
|
@@ -13831,50 +15107,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
13831
15107
|
{
|
13832
15108
|
// necessary for llama
|
13833
15109
|
if (src0->grad) {
|
13834
|
-
// y = softmax(x)
|
13835
|
-
//
|
13836
|
-
// Jii = yi - yi*yi
|
13837
|
-
// Jij = -yi*yj
|
13838
|
-
// J = diag(y)-y.*y
|
13839
|
-
// dx = J * dy
|
13840
|
-
// dxk = sum(Jkj * dyk)
|
13841
|
-
|
13842
|
-
int64_t ne2[4] = {
|
13843
|
-
tensor->ne[0],
|
13844
|
-
1,
|
13845
|
-
tensor->ne[1]*tensor->ne[2],
|
13846
|
-
tensor->ne[3]
|
13847
|
-
};
|
13848
|
-
struct ggml_tensor * tensor2 = ggml_cont(ctx,
|
13849
|
-
ggml_reshape_4d(ctx,
|
13850
|
-
ggml_cont(ctx, tensor),
|
13851
|
-
ne2[0], ne2[1], ne2[2], ne2[3]));
|
13852
|
-
|
13853
|
-
struct ggml_tensor * grad2 = ggml_cont(ctx,
|
13854
|
-
ggml_reshape_4d(ctx,
|
13855
|
-
ggml_cont(ctx, tensor->grad),
|
13856
|
-
ne2[0], ne2[1], ne2[2], ne2[3]));
|
13857
|
-
|
13858
|
-
struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
|
13859
|
-
ggml_permute(ctx, // [1,ne0,ne1*ne2,ne3]
|
13860
|
-
tensor2, // [ne0,1,ne1*ne2,ne3]
|
13861
|
-
1, 0, 2, 3));
|
13862
|
-
|
13863
15110
|
src0->grad =
|
13864
|
-
ggml_add_impl(ctx,
|
13865
|
-
|
13866
|
-
|
13867
|
-
ggml_mul_mat(ctx, // [ne0,1,ne1*ne2,ne3]
|
13868
|
-
ggml_sub(ctx, // [ne0,ne0,ne1*ne2,ne3]
|
13869
|
-
ggml_diag(ctx, // [ne0,ne0,ne1*ne2,ne3]
|
13870
|
-
tensor2), // [ne0,1,ne1*ne2,ne3]
|
13871
|
-
ggml_mul_mat(ctx, // [ne0,ne0,ne1*ne2,ne3]
|
13872
|
-
tensor2_t, // [1,ne0,ne1*ne2,ne3]
|
13873
|
-
tensor2_t)), // [1,ne0,ne1*ne2,ne3]
|
13874
|
-
grad2), // [ne0,1,ne1*ne2,ne3]
|
13875
|
-
src0->grad),
|
13876
|
-
inplace);
|
15111
|
+
ggml_add_impl(ctx, src0->grad,
|
15112
|
+
ggml_soft_max_back(ctx, tensor->grad, tensor),
|
15113
|
+
inplace);
|
13877
15114
|
}
|
15115
|
+
|
15116
|
+
} break;
|
15117
|
+
case GGML_OP_SOFT_MAX_BACK:
|
15118
|
+
{
|
15119
|
+
GGML_ASSERT(false); // TODO: not implemented
|
13878
15120
|
} break;
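The SOFT_MAX gradient above is now applied in vector form, dx = y * (dy - dot(y, dy)), via ggml_soft_max_back, instead of materializing the [ne0, ne0] Jacobian with ggml_diag and two matrix products as the removed code did. A single-row sketch of that Jacobian-vector product in plain C (illustrative names):

    #include <stdio.h>

    /* Given y = softmax(x) and upstream dy: dx_k = y_k * (dy_k - sum_i y_i * dy_i). */
    static void soft_max_back_row(const float * y, const float * dy, float * dx, int n) {
        float dot = 0.0f;
        for (int i = 0; i < n; ++i) dot += y[i] * dy[i];
        for (int k = 0; k < n; ++k) dx[k] = y[k] * (dy[k] - dot);
    }

    int main(void) {
        const float y[3]  = {0.2f, 0.3f, 0.5f}; /* already softmax-normalized */
        const float dy[3] = {1.0f, 0.0f, 0.0f};
        float dx[3];
        soft_max_back_row(y, dy, dx, 3);
        printf("%f %f %f\n", dx[0], dx[1], dx[2]);
        return 0;
    }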
|
13879
15121
|
case GGML_OP_ROPE:
|
13880
15122
|
{
|
@@ -13929,17 +15171,190 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
13929
15171
|
} break;
|
13930
15172
|
case GGML_OP_FLASH_ATTN:
|
13931
15173
|
{
|
13932
|
-
|
15174
|
+
struct ggml_tensor * flash_grad = NULL;
|
15175
|
+
if (src0->grad || src1->grad || tensor->opt[0]->grad) {
|
15176
|
+
int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
|
15177
|
+
GGML_ASSERT(t == 0 || t == 1);
|
15178
|
+
bool masked = t != 0;
|
15179
|
+
flash_grad =
|
15180
|
+
ggml_flash_attn_back(ctx,
|
15181
|
+
src0,
|
15182
|
+
src1,
|
15183
|
+
tensor->opt[0],
|
15184
|
+
tensor->grad,
|
15185
|
+
masked);
|
15186
|
+
}
|
15187
|
+
|
15188
|
+
if (src0->grad) {
|
15189
|
+
struct ggml_tensor * grad_q = NULL;
|
15190
|
+
const size_t nb0 = flash_grad->nb[0];
|
15191
|
+
const size_t offset = 0;
|
15192
|
+
switch(src0->n_dims) {
|
15193
|
+
case 2:
|
15194
|
+
{
|
15195
|
+
grad_q = ggml_view_2d(ctx,
|
15196
|
+
flash_grad,
|
15197
|
+
src0->ne[0],
|
15198
|
+
src0->ne[1],
|
15199
|
+
nb0*src0->ne[0],
|
15200
|
+
offset);
|
15201
|
+
} break;
|
15202
|
+
case 3:
|
15203
|
+
{
|
15204
|
+
grad_q = ggml_view_3d(ctx,
|
15205
|
+
flash_grad,
|
15206
|
+
src0->ne[0],
|
15207
|
+
src0->ne[1],
|
15208
|
+
src0->ne[2],
|
15209
|
+
nb0*src0->ne[0],
|
15210
|
+
nb0*src0->ne[0]*src0->ne[1],
|
15211
|
+
offset);
|
15212
|
+
} break;
|
15213
|
+
case 4:
|
15214
|
+
{
|
15215
|
+
grad_q = ggml_view_4d(ctx,
|
15216
|
+
flash_grad,
|
15217
|
+
src0->ne[0],
|
15218
|
+
src0->ne[1],
|
15219
|
+
src0->ne[2],
|
15220
|
+
src0->ne[3],
|
15221
|
+
nb0*src0->ne[0],
|
15222
|
+
nb0*src0->ne[0]*src0->ne[1],
|
15223
|
+
nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
|
15224
|
+
offset);
|
15225
|
+
} break;
|
15226
|
+
}
|
15227
|
+
|
15228
|
+
src0->grad = ggml_add_impl(ctx,
|
15229
|
+
src0->grad,
|
15230
|
+
grad_q,
|
15231
|
+
inplace);
|
15232
|
+
}
|
15233
|
+
|
15234
|
+
if (src1->grad) {
|
15235
|
+
struct ggml_tensor * grad_k = NULL;
|
15236
|
+
const size_t nb0 = flash_grad->nb[0];
|
15237
|
+
const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
|
15238
|
+
switch(src1->n_dims) {
|
15239
|
+
case 2:
|
15240
|
+
{
|
15241
|
+
grad_k = ggml_view_2d(ctx,
|
15242
|
+
flash_grad,
|
15243
|
+
src1->ne[0],
|
15244
|
+
src1->ne[1],
|
15245
|
+
nb0*src1->ne[0],
|
15246
|
+
offset);
|
15247
|
+
} break;
|
15248
|
+
case 3:
|
15249
|
+
{
|
15250
|
+
grad_k = ggml_view_3d(ctx,
|
15251
|
+
flash_grad,
|
15252
|
+
src1->ne[0],
|
15253
|
+
src1->ne[1],
|
15254
|
+
src1->ne[2],
|
15255
|
+
nb0*src1->ne[0],
|
15256
|
+
nb0*src1->ne[0]*src1->ne[1],
|
15257
|
+
offset);
|
15258
|
+
} break;
|
15259
|
+
case 4:
|
15260
|
+
{
|
15261
|
+
grad_k = ggml_view_4d(ctx,
|
15262
|
+
flash_grad,
|
15263
|
+
src1->ne[0],
|
15264
|
+
src1->ne[1],
|
15265
|
+
src1->ne[2],
|
15266
|
+
src1->ne[3],
|
15267
|
+
nb0*src1->ne[0],
|
15268
|
+
nb0*src1->ne[0]*src1->ne[1],
|
15269
|
+
nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
|
15270
|
+
offset);
|
15271
|
+
} break;
|
15272
|
+
}
|
15273
|
+
|
15274
|
+
src1->grad = ggml_add_impl(ctx,
|
15275
|
+
src1->grad,
|
15276
|
+
grad_k,
|
15277
|
+
inplace);
|
15278
|
+
}
|
15279
|
+
|
15280
|
+
struct ggml_tensor * opt0 = tensor->opt[0];
|
15281
|
+
|
15282
|
+
if (opt0->grad) {
|
15283
|
+
struct ggml_tensor * grad_v = NULL;
|
15284
|
+
const size_t nb0 = flash_grad->nb[0];
|
15285
|
+
const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
|
15286
|
+
+ nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
|
15287
|
+
switch(opt0->n_dims) {
|
15288
|
+
case 2:
|
15289
|
+
{
|
15290
|
+
grad_v = ggml_view_2d(ctx,
|
15291
|
+
flash_grad,
|
15292
|
+
opt0->ne[0],
|
15293
|
+
opt0->ne[1],
|
15294
|
+
nb0*opt0->ne[0],
|
15295
|
+
offset);
|
15296
|
+
} break;
|
15297
|
+
case 3:
|
15298
|
+
{
|
15299
|
+
grad_v = ggml_view_3d(ctx,
|
15300
|
+
flash_grad,
|
15301
|
+
opt0->ne[0],
|
15302
|
+
opt0->ne[1],
|
15303
|
+
opt0->ne[2],
|
15304
|
+
nb0*opt0->ne[0],
|
15305
|
+
nb0*opt0->ne[0]*opt0->ne[1],
|
15306
|
+
offset);
|
15307
|
+
} break;
|
15308
|
+
case 4:
|
15309
|
+
{
|
15310
|
+
grad_v = ggml_view_4d(ctx,
|
15311
|
+
flash_grad,
|
15312
|
+
opt0->ne[0],
|
15313
|
+
opt0->ne[1],
|
15314
|
+
opt0->ne[2],
|
15315
|
+
opt0->ne[3],
|
15316
|
+
nb0*opt0->ne[0],
|
15317
|
+
nb0*opt0->ne[0]*opt0->ne[1],
|
15318
|
+
nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
|
15319
|
+
offset);
|
15320
|
+
} break;
|
15321
|
+
}
|
15322
|
+
|
15323
|
+
opt0->grad = ggml_add_impl(ctx,
|
15324
|
+
opt0->grad,
|
15325
|
+
grad_v,
|
15326
|
+
inplace);
|
15327
|
+
}
|
13933
15328
|
} break;
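ggml_flash_attn_back returns the gradients for q, k and v packed back-to-back in one contiguous result, which is why the views above use byte offsets 0, nbytes(q) and nbytes(q)+nbytes(k) into flash_grad. A small sketch of the same slicing arithmetic on a flat buffer (plain C, illustrative sizes, not the ggml API):

    #include <stdio.h>
    #include <stdlib.h>

    /* Three gradients packed into one flat buffer and recovered by offset,
     * mirroring how the views above slice flash_grad into dq, dk, dv. */
    int main(void) {
        const size_t nq = 4*8, nk = 4*8, nv = 8*4;  /* element counts of q, k, v */
        float * flash_grad = calloc(nq + nk + nv, sizeof(float));
        if (!flash_grad) return 1;

        float * grad_q = flash_grad;                /* offset 0                         */
        float * grad_k = flash_grad + nq;           /* offset nbytes(q) in byte terms   */
        float * grad_v = flash_grad + nq + nk;      /* offset nbytes(q) + nbytes(k)     */

        grad_q[0] = 1.0f; grad_k[0] = 2.0f; grad_v[0] = 3.0f;
        printf("%.1f %.1f %.1f\n", flash_grad[0], flash_grad[nq], flash_grad[nq + nk]);

        free(flash_grad);
        return 0;
    }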
|
13934
15329
|
case GGML_OP_FLASH_FF:
|
13935
15330
|
{
|
13936
15331
|
GGML_ASSERT(false); // not supported
|
13937
15332
|
} break;
|
+        case GGML_OP_FLASH_ATTN_BACK:
+            {
+                GGML_ASSERT(false); // not supported
+            } break;
|
13938
15337
|
case GGML_OP_MAP_UNARY:
|
13939
15338
|
case GGML_OP_MAP_BINARY:
|
13940
15339
|
{
|
13941
15340
|
GGML_ASSERT(false); // not supported
|
13942
15341
|
} break;
|
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            {
+                if (src0->grad) {
+                    src0->grad = ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_cross_entropy_loss_back(ctx,
+                                    src0,
+                                    src1,
+                                    tensor->grad),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            {
+                GGML_ASSERT(false); // not supported
+            } break;
|
13943
15358
|
case GGML_OP_NONE:
|
13944
15359
|
{
|
13945
15360
|
// nop
|
@@ -14316,6 +15731,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
14316
15731
|
case GGML_OP_SUM_ROWS:
|
14317
15732
|
case GGML_OP_MEAN:
|
14318
15733
|
case GGML_OP_REPEAT:
|
15734
|
+
case GGML_OP_REPEAT_BACK:
|
14319
15735
|
case GGML_OP_ABS:
|
14320
15736
|
case GGML_OP_SGN:
|
14321
15737
|
case GGML_OP_NEG:
|
@@ -14335,6 +15751,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
14335
15751
|
node->n_tasks = n_threads;
|
14336
15752
|
} break;
|
14337
15753
|
case GGML_OP_MUL_MAT:
|
15754
|
+
case GGML_OP_OUT_PROD:
|
14338
15755
|
{
|
14339
15756
|
node->n_tasks = n_threads;
|
14340
15757
|
|
@@ -14417,6 +15834,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
14417
15834
|
} break;
|
14418
15835
|
case GGML_OP_DIAG_MASK_INF:
|
14419
15836
|
case GGML_OP_SOFT_MAX:
|
15837
|
+
case GGML_OP_SOFT_MAX_BACK:
|
14420
15838
|
case GGML_OP_ROPE:
|
14421
15839
|
case GGML_OP_ROPE_BACK:
|
14422
15840
|
{
|
@@ -14496,6 +15914,27 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
14496
15914
|
cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
|
14497
15915
|
}
|
14498
15916
|
|
+                        work_size = MAX(work_size, cur);
+                    } break;
+                case GGML_OP_FLASH_ATTN_BACK:
+                    {
+                        node->n_tasks = n_threads;
+
+                        size_t cur = 0;
+
+                        const int64_t    D = node->src0->ne[0];
+                        const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+                        const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
+                        if (node->src1->type == GGML_TYPE_F32) {
+                            cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
+                            cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
+                        }
+
+                        if (node->src1->type == GGML_TYPE_F16) {
+                            cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
+                            cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
+                        }
+
|
14499
15938
|
work_size = MAX(work_size, cur);
|
14500
15939
|
} break;
|
14501
15940
|
case GGML_OP_MAP_UNARY:
|
@@ -14503,6 +15942,22 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
14503
15942
|
{
|
14504
15943
|
node->n_tasks = 1;
|
14505
15944
|
} break;
|
+                case GGML_OP_CROSS_ENTROPY_LOSS:
+                    {
+                        node->n_tasks = n_threads;
+
+                        size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
+
+                        work_size = MAX(work_size, cur);
+                    } break;
+                case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+                    {
+                        node->n_tasks = n_threads;
+
+                        size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
+
+                        work_size = MAX(work_size, cur);
+                    } break;
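These sizes mirror the kernels' scratch needs: the forward op keeps one partial sum per thread plus one nc-wide softmax row per thread (the nth + nth*nc floats it memsets during INIT), while the backward op only needs the nc-wide softmax row per thread. A quick restatement of that accounting (plain C, illustrative sizes):

    #include <stdio.h>

    /* Work-buffer sizing for the cross-entropy ops in the float32 case:
     * forward : nth partial sums + nth rows of nc scratch values
     * backward: nth rows of nc scratch values (the per-row softmax) */
    int main(void) {
        const int nth = 8;      /* number of threads                  */
        const int nc  = 32000;  /* row width, e.g. a vocabulary size  */

        const size_t fwd_bytes = sizeof(float) * ((size_t) nth + (size_t) nth * nc);
        const size_t bwd_bytes = sizeof(float) * ((size_t) nth * nc);

        printf("forward work buffer:  %zu bytes\n", fwd_bytes);
        printf("backward work buffer: %zu bytes\n", bwd_bytes);
        return 0;
    }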
|
14506
15961
|
case GGML_OP_NONE:
|
14507
15962
|
{
|
14508
15963
|
node->n_tasks = 1;
|
@@ -15478,6 +16933,7 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
|
|
15478
16933
|
|
15479
16934
|
static enum ggml_opt_result ggml_opt_adam(
|
15480
16935
|
struct ggml_context * ctx,
|
16936
|
+
struct ggml_opt_context * opt,
|
15481
16937
|
struct ggml_opt_params params,
|
15482
16938
|
struct ggml_tensor * f,
|
15483
16939
|
struct ggml_cgraph * gf,
|
@@ -15503,25 +16959,29 @@ static enum ggml_opt_result ggml_opt_adam(
|
|
15503
16959
|
}
|
15504
16960
|
}
|
15505
16961
|
|
16962
|
+
if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) {
|
16963
|
+
int iter = opt->iter;
|
16964
|
+
ggml_opt_init(opt->ctx, opt, params, nx);
|
16965
|
+
opt->iter = iter;
|
16966
|
+
}
|
16967
|
+
|
15506
16968
|
// constants
|
15507
|
-
const float
|
16969
|
+
const float sched = params.adam.sched;
|
16970
|
+
const float decay = params.adam.decay * sched;
|
16971
|
+
const float alpha = params.adam.alpha * sched;
|
15508
16972
|
const float beta1 = params.adam.beta1;
|
15509
16973
|
const float beta2 = params.adam.beta2;
|
15510
16974
|
const float eps = params.adam.eps;
|
15511
16975
|
|
15512
|
-
float * x =
|
15513
|
-
float * g1 =
|
15514
|
-
float * g2 =
|
15515
|
-
float * m =
|
15516
|
-
float * v =
|
15517
|
-
float * mh =
|
15518
|
-
float * vh =
|
16976
|
+
float * x = opt->adam.x->data; // view of the parameters
|
16977
|
+
float * g1 = opt->adam.g1->data; // gradient
|
16978
|
+
float * g2 = opt->adam.g2->data; // gradient squared
|
16979
|
+
float * m = opt->adam.m->data; // first moment
|
16980
|
+
float * v = opt->adam.v->data; // second moment
|
16981
|
+
float * mh = opt->adam.mh->data; // first moment hat
|
16982
|
+
float * vh = opt->adam.vh->data; // second moment hat
|
15519
16983
|
|
15520
|
-
float * pf = params.past > 0 ?
|
15521
|
-
|
15522
|
-
// initialize
|
15523
|
-
ggml_vec_set_f32(nx, m, 0.0f);
|
15524
|
-
ggml_vec_set_f32(nx, v, 0.0f);
|
16984
|
+
float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
|
15525
16985
|
|
15526
16986
|
// update view
|
15527
16987
|
ggml_opt_get_params(np, ps, x);
|
@@ -15531,16 +16991,27 @@ static enum ggml_opt_result ggml_opt_adam(
|
|
15531
16991
|
ggml_set_f32 (f->grad, 1.0f);
|
15532
16992
|
ggml_graph_compute(ctx, gb);
|
15533
16993
|
|
15534
|
-
|
16994
|
+
opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
|
16995
|
+
opt->adam.fx_best = opt->adam.fx_prev;
|
15535
16996
|
if (pf) {
|
15536
|
-
pf[
|
16997
|
+
pf[opt->iter % params.past] = opt->adam.fx_prev;
|
15537
16998
|
}
|
15538
16999
|
|
15539
|
-
|
15540
|
-
|
17000
|
+
// initialize
|
17001
|
+
if (opt->just_initialized) {
|
17002
|
+
opt->adam.n_no_improvement = 0;
|
17003
|
+
opt->just_initialized = false;
|
17004
|
+
}
|
17005
|
+
|
17006
|
+
float * fx_best = &opt->adam.fx_best;
|
17007
|
+
float * fx_prev = &opt->adam.fx_prev;
|
17008
|
+
int * n_no_improvement = &opt->adam.n_no_improvement;
|
17009
|
+
|
17010
|
+
int iter0 = opt->iter;
|
15541
17011
|
|
15542
17012
|
// run the optimizer
|
15543
17013
|
for (int t = 0; t < params.adam.n_iter; ++t) {
|
17014
|
+
opt->iter = iter0 + t + 1;
|
15544
17015
|
GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
|
15545
17016
|
|
15546
17017
|
GGML_PRINT_DEBUG ("f = %10.6f\n", ggml_get_f32_1d(f, 0));
|
@@ -15574,17 +17045,22 @@ static enum ggml_opt_result ggml_opt_adam(
|
|
15574
17045
|
|
15575
17046
|
// m^hat = m_t / (1 - beta1^t)
|
15576
17047
|
// v^hat = v_t / (1 - beta2^t)
|
15577
|
-
// x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps)
|
17048
|
+
// x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1)
|
17049
|
+
// x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1
|
17050
|
+
// x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
|
17051
|
+
// x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps)
|
17052
|
+
// x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay)
|
15578
17053
|
ggml_vec_cpy_f32 (nx, mh, m);
|
15579
17054
|
ggml_vec_cpy_f32 (nx, vh, v);
|
15580
17055
|
|
15581
|
-
ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1,
|
15582
|
-
ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2,
|
17056
|
+
ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter)));
|
17057
|
+
ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, opt->iter)));
|
15583
17058
|
|
15584
17059
|
ggml_vec_sqrt_f32 (nx, vh, vh);
|
15585
17060
|
ggml_vec_acc1_f32 (nx, vh, eps);
|
15586
17061
|
|
15587
17062
|
ggml_vec_div_f32 (nx, mh, mh, vh);
|
17063
|
+
ggml_vec_scale_f32(nx, x, 1.0f - decay);
|
15588
17064
|
ggml_vec_sub_f32 (nx, x, x, mh);
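With sched and decay, the update above is Adam with decoupled (AdamW-style) weight decay: alpha and decay are pre-scaled by sched near the top of the function, x is first shrunk by (1 - decay) and then moved by the bias-corrected moment ratio. A scalar sketch of one such step (plain C; the constants are illustrative beyond the defaults visible in this diff):

    #include <stdio.h>
    #include <math.h>

    /* One Adam step with decoupled weight decay on a single parameter:
     * x = x*(1 - decay) - alpha*m_hat/(sqrt(v_hat) + eps). */
    int main(void) {
        const float sched = 1.0f, decay0 = 0.001f, alpha0 = 0.001f;
        const float beta1 = 0.9f, beta2 = 0.999f, eps = 1e-8f;
        const float decay = decay0 * sched;
        const float alpha = alpha0 * sched;

        float x = 0.5f, m = 0.0f, v = 0.0f;
        const float g = 0.2f; /* gradient at x              */
        const int   t = 1;    /* iteration count (opt->iter) */

        m = beta1 * m + (1.0f - beta1) * g;
        v = beta2 * v + (1.0f - beta2) * g * g;

        const float mh = m * (alpha / (1.0f - powf(beta1, t)));
        const float vh = sqrtf(v / (1.0f - powf(beta2, t))) + eps;

        x = x * (1.0f - decay) - mh / vh;
        printf("x after one step: %f\n", x);
        return 0;
    }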
|
15589
17065
|
|
15590
17066
|
// update the parameters
|
@@ -15598,7 +17074,7 @@ static enum ggml_opt_result ggml_opt_adam(
|
|
15598
17074
|
const float fx = ggml_get_f32_1d(f, 0);
|
15599
17075
|
|
15600
17076
|
// check convergence
|
15601
|
-
if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) {
|
17077
|
+
if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
|
15602
17078
|
GGML_PRINT_DEBUG("converged\n");
|
15603
17079
|
|
15604
17080
|
return GGML_OPT_OK;
|
@@ -15607,32 +17083,32 @@ static enum ggml_opt_result ggml_opt_adam(
|
|
15607
17083
|
// delta-based convergence test
|
15608
17084
|
if (pf != NULL) {
|
15609
17085
|
// need at least params.past iterations to start checking for convergence
|
15610
|
-
if (params.past <= t) {
|
15611
|
-
const float rate = (pf[t%params.past] - fx)/fx;
|
17086
|
+
if (params.past <= iter0 + t) {
|
17087
|
+
const float rate = (pf[(iter0 + t)%params.past] - fx)/fx;
|
15612
17088
|
|
15613
17089
|
if (fabsf(rate) < params.delta) {
|
15614
17090
|
return GGML_OPT_OK;
|
15615
17091
|
}
|
15616
17092
|
}
|
15617
17093
|
|
15618
|
-
pf[t%params.past] = fx;
|
17094
|
+
pf[(iter0 + t)%params.past] = fx;
|
15619
17095
|
}
|
15620
17096
|
|
15621
17097
|
// check for improvement
|
15622
17098
|
if (params.max_no_improvement > 0) {
|
15623
|
-
if (fx_best > fx) {
|
15624
|
-
fx_best = fx;
|
15625
|
-
n_no_improvement = 0;
|
17099
|
+
if (fx_best[0] > fx) {
|
17100
|
+
fx_best[0] = fx;
|
17101
|
+
n_no_improvement[0] = 0;
|
15626
17102
|
} else {
|
15627
|
-
++n_no_improvement;
|
17103
|
+
++n_no_improvement[0];
|
15628
17104
|
|
15629
|
-
if (n_no_improvement >= params.max_no_improvement) {
|
17105
|
+
if (n_no_improvement[0] >= params.max_no_improvement) {
|
15630
17106
|
return GGML_OPT_OK;
|
15631
17107
|
}
|
15632
17108
|
}
|
15633
17109
|
}
|
15634
17110
|
|
15635
|
-
fx_prev = fx;
|
17111
|
+
fx_prev[0] = fx;
|
15636
17112
|
|
15637
17113
|
{
|
15638
17114
|
const int64_t t_end_cpu = ggml_cycles();
|
@@ -15771,6 +17247,7 @@ static enum ggml_opt_result linesearch_backtracking(
|
|
15771
17247
|
|
15772
17248
|
static enum ggml_opt_result ggml_opt_lbfgs(
|
15773
17249
|
struct ggml_context * ctx,
|
17250
|
+
struct ggml_opt_context * opt,
|
15774
17251
|
struct ggml_opt_params params,
|
15775
17252
|
struct ggml_tensor * f,
|
15776
17253
|
struct ggml_cgraph * gf,
|
@@ -15803,31 +17280,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|
15803
17280
|
}
|
15804
17281
|
}
|
15805
17282
|
|
15806
|
-
|
15807
|
-
|
15808
|
-
|
15809
|
-
|
15810
|
-
|
17283
|
+
if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) {
|
17284
|
+
int iter = opt->iter;
|
17285
|
+
ggml_opt_init(ctx, opt, params, nx);
|
17286
|
+
opt->iter = iter;
|
17287
|
+
}
|
17288
|
+
|
17289
|
+
float * x = opt->lbfgs.x->data; // current parameters
|
17290
|
+
float * xp = opt->lbfgs.xp->data; // previous parameters
|
17291
|
+
float * g = opt->lbfgs.g->data; // current gradient
|
17292
|
+
float * gp = opt->lbfgs.gp->data; // previous gradient
|
17293
|
+
float * d = opt->lbfgs.d->data; // search direction
|
15811
17294
|
|
15812
|
-
float * pf = params.past > 0 ?
|
17295
|
+
float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
|
15813
17296
|
|
15814
17297
|
float fx = 0.0f; // cost function value
|
15815
17298
|
float xnorm = 0.0f; // ||x||
|
15816
17299
|
float gnorm = 0.0f; // ||g||
|
15817
|
-
float step = 0.0f;
|
15818
17300
|
|
15819
17301
|
// initialize x from the graph nodes
|
15820
17302
|
ggml_opt_get_params(np, ps, x);
|
15821
17303
|
|
15822
17304
|
// the L-BFGS memory
|
15823
|
-
|
15824
|
-
|
15825
|
-
|
15826
|
-
|
15827
|
-
lm[i].ys = 0.0f;
|
15828
|
-
lm[i].s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
|
15829
|
-
lm[i].y = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
|
15830
|
-
}
|
17305
|
+
float * lm_alpha = opt->lbfgs.lmal->data;
|
17306
|
+
float * lm_ys = opt->lbfgs.lmys->data;
|
17307
|
+
float * lm_s = opt->lbfgs.lms->data;
|
17308
|
+
float * lm_y = opt->lbfgs.lmy->data;
|
15831
17309
|
|
15832
17310
|
// evaluate the function value and its gradient
|
15833
17311
|
{
|
@@ -15842,12 +17320,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|
15842
17320
|
fx = ggml_get_f32_1d(f, 0);
|
15843
17321
|
}
|
15844
17322
|
|
15845
|
-
if (pf) {
|
15846
|
-
pf[0] = fx;
|
15847
|
-
}
|
15848
|
-
|
15849
|
-
float fx_best = fx;
|
15850
|
-
|
15851
17323
|
// search direction = -gradient
|
15852
17324
|
ggml_vec_neg_f32(nx, d, g);
|
15853
17325
|
|
@@ -15864,26 +17336,43 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|
15864
17336
|
return GGML_OPT_OK;
|
15865
17337
|
}
|
15866
17338
|
|
15867
|
-
|
15868
|
-
|
17339
|
+
if (opt->just_initialized) {
|
17340
|
+
if (pf) {
|
17341
|
+
pf[0] = fx;
|
17342
|
+
}
|
17343
|
+
opt->lbfgs.fx_best = fx;
|
17344
|
+
|
17345
|
+
// initial step
|
17346
|
+
ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d);
|
17347
|
+
opt->lbfgs.j = 0;
|
17348
|
+
opt->lbfgs.k = 1;
|
17349
|
+
opt->lbfgs.end = 0;
|
17350
|
+
opt->lbfgs.n_no_improvement = 0;
|
17351
|
+
opt->just_initialized = false;
|
17352
|
+
}
|
17353
|
+
|
17354
|
+
float * fx_best = &opt->lbfgs.fx_best;
|
17355
|
+
float * step = &opt->lbfgs.step;
|
17356
|
+
int * j = &opt->lbfgs.j;
|
17357
|
+
int * k = &opt->lbfgs.k;
|
17358
|
+
int * end = &opt->lbfgs.end;
|
17359
|
+
int * n_no_improvement = &opt->lbfgs.n_no_improvement;
|
15869
17360
|
|
15870
|
-
int
|
15871
|
-
int
|
15872
|
-
int ls = 0;
|
15873
|
-
int end = 0;
|
15874
|
-
int bound = 0;
|
15875
|
-
int n_no_improvement = 0;
|
17361
|
+
int ls = 0;
|
17362
|
+
int bound = 0;
|
15876
17363
|
|
15877
17364
|
float ys = 0.0f;
|
15878
17365
|
float yy = 0.0f;
|
15879
17366
|
float beta = 0.0f;
|
15880
17367
|
|
17368
|
+
int it = 0;
|
17369
|
+
|
15881
17370
|
while (true) {
|
15882
17371
|
// store the current position and gradient vectors
|
15883
17372
|
ggml_vec_cpy_f32(nx, xp, x);
|
15884
17373
|
ggml_vec_cpy_f32(nx, gp, g);
|
15885
17374
|
|
15886
|
-
ls = linesearch_backtracking(ctx, ¶ms, nx, x, &fx, g, d,
|
17375
|
+
ls = linesearch_backtracking(ctx, ¶ms, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
|
15887
17376
|
|
15888
17377
|
if (ls < 0) {
|
15889
17378
|
// linesearch failed - go back to the previous point and return
|
@@ -15909,32 +17398,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|
15909
17398
|
// delta-based convergence test
|
15910
17399
|
if (pf != NULL) {
|
15911
17400
|
// need at least params.past iterations to start checking for convergence
|
15912
|
-
if (params.past <= k) {
|
15913
|
-
const float rate = (pf[k%params.past] - fx)/fx;
|
17401
|
+
if (params.past <= k[0]) {
|
17402
|
+
const float rate = (pf[k[0]%params.past] - fx)/fx;
|
15914
17403
|
|
15915
17404
|
if (fabsf(rate) < params.delta) {
|
15916
17405
|
return GGML_OPT_OK;
|
15917
17406
|
}
|
15918
17407
|
}
|
15919
17408
|
|
15920
|
-
pf[k%params.past] = fx;
|
17409
|
+
pf[k[0]%params.past] = fx;
|
15921
17410
|
}
|
15922
17411
|
|
15923
17412
|
// check for improvement
|
15924
17413
|
if (params.max_no_improvement > 0) {
|
15925
|
-
if (fx < fx_best) {
|
15926
|
-
fx_best = fx;
|
15927
|
-
n_no_improvement = 0;
|
17414
|
+
if (fx < fx_best[0]) {
|
17415
|
+
fx_best[0] = fx;
|
17416
|
+
n_no_improvement[0] = 0;
|
15928
17417
|
} else {
|
15929
|
-
n_no_improvement++;
|
17418
|
+
n_no_improvement[0]++;
|
15930
17419
|
|
15931
|
-
if (n_no_improvement >= params.max_no_improvement) {
|
17420
|
+
if (n_no_improvement[0] >= params.max_no_improvement) {
|
15932
17421
|
return GGML_OPT_OK;
|
15933
17422
|
}
|
15934
17423
|
}
|
15935
17424
|
}
|
15936
17425
|
|
15937
|
-
if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter <
|
17426
|
+
if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) {
|
15938
17427
|
// reached the maximum number of iterations
|
15939
17428
|
return GGML_OPT_DID_NOT_CONVERGE;
|
15940
17429
|
}
|
@@ -15943,50 +17432,51 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|
15943
17432
|
// s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}.
|
15944
17433
|
// y_{k+1} = g_{k+1} - g_{k}.
|
15945
17434
|
//
|
15946
|
-
ggml_vec_sub_f32(nx,
|
15947
|
-
ggml_vec_sub_f32(nx,
|
17435
|
+
ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp);
|
17436
|
+
ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp);
|
15948
17437
|
|
15949
17438
|
// compute scalars ys and yy:
|
15950
17439
|
// ys = y^t \cdot s -> 1 / \rho.
|
15951
17440
|
// yy = y^t \cdot y.
|
15952
17441
|
//
|
15953
|
-
ggml_vec_dot_f32(nx, &ys,
|
15954
|
-
ggml_vec_dot_f32(nx, &yy,
|
17442
|
+
ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0] *nx]);
|
17443
|
+
ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);
|
15955
17444
|
|
15956
|
-
|
17445
|
+
lm_ys[end[0]] = ys;
|
15957
17446
|
|
15958
17447
|
// find new search direction
|
15959
17448
|
// ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS
|
15960
17449
|
|
15961
|
-
bound = (m <= k) ? m : k;
|
15962
|
-
k++;
|
15963
|
-
|
17450
|
+
bound = (m <= k[0]) ? m : k[0];
|
17451
|
+
k[0]++;
|
17452
|
+
it++;
|
17453
|
+
end[0] = (end[0] + 1)%m;
|
15964
17454
|
|
15965
17455
|
// initialize search direction with -g
|
15966
17456
|
ggml_vec_neg_f32(nx, d, g);
|
15967
17457
|
|
15968
|
-
j = end;
|
17458
|
+
j[0] = end[0];
|
15969
17459
|
for (int i = 0; i < bound; ++i) {
|
15970
|
-
j = (j + m - 1) % m;
|
17460
|
+
j[0] = (j[0] + m - 1) % m;
|
15971
17461
|
// \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
|
15972
|
-
ggml_vec_dot_f32(nx, &
|
15973
|
-
|
17462
|
+
ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d);
|
17463
|
+
lm_alpha[j[0]] /= lm_ys[j[0]];
|
15974
17464
|
// q_{i} = q_{i+1} - \alpha_{i} y_{i}
|
15975
|
-
ggml_vec_mad_f32(nx, d,
|
17465
|
+
ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
|
15976
17466
|
}
|
15977
17467
|
|
15978
17468
|
ggml_vec_scale_f32(nx, d, ys/yy);
|
15979
17469
|
|
15980
17470
|
for (int i = 0; i < bound; ++i) {
|
15981
17471
|
// \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
|
15982
|
-
ggml_vec_dot_f32(nx, &beta,
|
15983
|
-
beta /=
|
17472
|
+
ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d);
|
17473
|
+
beta /= lm_ys[j[0]];
|
15984
17474
|
// \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
|
15985
|
-
ggml_vec_mad_f32(nx, d,
|
15986
|
-
j = (j + 1)%m;
|
17475
|
+
ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
|
17476
|
+
j[0] = (j[0] + 1)%m;
|
15987
17477
|
}
|
15988
17478
|
|
15989
|
-
step = 1.0;
|
17479
|
+
step[0] = 1.0;
|
15990
17480
|
}
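The search-direction update above is the standard L-BFGS two-loop recursion, with the (s, y) history now living in the flat lm_s/lm_y buffers (nx floats per slot) indexed by the persistent ring position end[0]. A compact standalone sketch of the recursion over such a ring buffer (plain C, illustrative, not the ggml API):

    #include <stdio.h>

    #define NX 2   /* number of parameters           */
    #define M  4   /* history size (params.lbfgs.m)  */

    static float dotf(const float * a, const float * b, int n) {
        float s = 0.0f;
        for (int i = 0; i < n; ++i) s += a[i]*b[i];
        return s;
    }

    /* Two-loop recursion: overwrite d with the approximate -H*g direction,
     * using the `bound` most recent (s, y) pairs stored in a ring buffer. */
    static void lbfgs_direction(int bound, int end, const float g[NX],
                                float lm_s[M][NX], float lm_y[M][NX],
                                float lm_ys[M], float lm_alpha[M], float d[NX]) {
        for (int k = 0; k < NX; ++k) d[k] = -g[k];   /* start from -gradient */

        int j = end;
        for (int i = 0; i < bound; ++i) {            /* newest -> oldest */
            j = (j + M - 1) % M;
            lm_alpha[j] = dotf(lm_s[j], d, NX) / lm_ys[j];
            for (int k = 0; k < NX; ++k) d[k] -= lm_alpha[j]*lm_y[j][k];
        }

        const int recent = (end + M - 1) % M;        /* gamma = ys/yy of newest pair */
        const float gamma = lm_ys[recent] / dotf(lm_y[recent], lm_y[recent], NX);
        for (int k = 0; k < NX; ++k) d[k] *= gamma;

        for (int i = 0; i < bound; ++i) {            /* oldest -> newest */
            const float beta = dotf(lm_y[j], d, NX) / lm_ys[j];
            for (int k = 0; k < NX; ++k) d[k] += (lm_alpha[j] - beta)*lm_s[j][k];
            j = (j + 1) % M;
        }
    }

    int main(void) {
        float lm_s[M][NX] = {{1, 0}}, lm_y[M][NX] = {{2, 0}};
        float lm_ys[M] = {2.0f}, lm_alpha[M] = {0};
        float g[NX] = {1.0f, 1.0f}, d[NX];
        lbfgs_direction(1 /*bound*/, 1 /*end*/, g, lm_s, lm_y, lm_ys, lm_alpha, d);
        printf("d = (%f, %f)\n", d[0], d[1]);        /* (-0.5, -0.5) */
        return 0;
    }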
|
15991
17481
|
|
15992
17482
|
return GGML_OPT_DID_NOT_CONVERGE;
|
@@ -16011,6 +17501,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
|
|
16011
17501
|
|
16012
17502
|
.adam = {
|
16013
17503
|
.n_iter = 10000,
|
17504
|
+
.sched = 1.000f,
|
17505
|
+
.decay = 0.001f,
|
16014
17506
|
.alpha = 0.001f,
|
16015
17507
|
.beta1 = 0.9f,
|
16016
17508
|
.beta2 = 0.999f,
|
@@ -16053,6 +17545,71 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
|
|
16053
17545
|
return result;
|
16054
17546
|
}
|
16055
17547
|
|
17548
|
+
GGML_API void ggml_opt_init(
|
17549
|
+
struct ggml_context * ctx,
|
17550
|
+
struct ggml_opt_context * opt,
|
17551
|
+
struct ggml_opt_params params,
|
17552
|
+
int64_t nx) {
|
17553
|
+
opt->ctx = ctx;
|
17554
|
+
opt->params = params;
|
17555
|
+
opt->iter = 0;
|
17556
|
+
opt->nx = nx;
|
17557
|
+
opt->just_initialized = true;
|
17558
|
+
switch (opt->params.type) {
|
17559
|
+
case GGML_OPT_ADAM:
|
17560
|
+
{
|
17561
|
+
opt->adam.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
17562
|
+
opt->adam.g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
17563
|
+
opt->adam.g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
17564
|
+
opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
17565
|
+
opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
17566
|
+
opt->adam.mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
17567
|
+
opt->adam.vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
17568
|
+
opt->adam.pf = params.past > 0
|
17569
|
+
? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
|
17570
|
+
: NULL;
|
17571
|
+
ggml_set_zero(opt->adam.x);
|
17572
|
+
ggml_set_zero(opt->adam.g1);
|
17573
|
+
ggml_set_zero(opt->adam.g2);
|
17574
|
+
ggml_set_zero(opt->adam.m);
|
17575
|
+
ggml_set_zero(opt->adam.v);
|
17576
|
+
ggml_set_zero(opt->adam.mh);
|
17577
|
+
ggml_set_zero(opt->adam.vh);
|
17578
|
+
if (opt->adam.pf) {
|
17579
|
+
ggml_set_zero(opt->adam.pf);
|
17580
|
+
}
|
17581
|
+
} break;
|
17582
|
+
case GGML_OPT_LBFGS:
|
17583
|
+
{
|
17584
|
+
opt->lbfgs.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
17585
|
+
opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
17586
|
+
opt->lbfgs.g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
17587
|
+
opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
17588
|
+
opt->lbfgs.d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
|
17589
|
+
opt->lbfgs.pf = params.past > 0
|
17590
|
+
? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
|
17591
|
+
: NULL;
|
17592
|
+
opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
|
17593
|
+
opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
|
17594
|
+
opt->lbfgs.lms = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
|
17595
|
+
opt->lbfgs.lmy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
|
17596
|
+
ggml_set_zero(opt->lbfgs.x);
|
17597
|
+
ggml_set_zero(opt->lbfgs.xp);
|
17598
|
+
ggml_set_zero(opt->lbfgs.g);
|
17599
|
+
ggml_set_zero(opt->lbfgs.gp);
|
17600
|
+
ggml_set_zero(opt->lbfgs.d);
|
17601
|
+
ggml_set_zero(opt->lbfgs.pf);
|
17602
|
+
if (opt->lbfgs.pf) {
|
17603
|
+
ggml_set_zero(opt->lbfgs.pf);
|
17604
|
+
}
|
17605
|
+
ggml_set_zero(opt->lbfgs.lmal);
|
17606
|
+
ggml_set_zero(opt->lbfgs.lmys);
|
17607
|
+
ggml_set_zero(opt->lbfgs.lms);
|
17608
|
+
ggml_set_zero(opt->lbfgs.lmy);
|
17609
|
+
} break;
|
17610
|
+
}
|
17611
|
+
}
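With the optimizer state held in a ggml_opt_context, a run can be split across several calls: ggml_opt_init sizes the state for nx parameters and ggml_opt_resume continues from wherever the previous call stopped. A minimal usage sketch, assuming a toy objective f(x) = sum(x²); the context size, iteration counts and choice of ops are illustrative, not prescriptive:

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        struct ggml_init_params ip = { .mem_size = 16*1024*1024, .mem_buffer = NULL, .no_alloc = false };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        ggml_set_param(ctx, x);
        ggml_set_f32(x, 2.0f);

        struct ggml_tensor * f = ggml_sum(ctx, ggml_sqr(ctx, x));

        struct ggml_opt_params params = ggml_opt_default_params(GGML_OPT_ADAM);
        params.adam.n_iter = 16;

        struct ggml_opt_context opt;
        ggml_opt_init(ctx, &opt, params, ggml_nelements(x));

        ggml_opt_resume(ctx, &opt, f);   // first 16 iterations
        ggml_opt_resume(ctx, &opt, f);   // continues; opt.iter keeps counting

        printf("f = %f after %d iterations\n", ggml_get_f32_1d(f, 0), opt.iter);

        ggml_free(ctx);
        return 0;
    }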
|
17612
|
+
|
16056
17613
|
enum ggml_opt_result ggml_opt(
|
16057
17614
|
struct ggml_context * ctx,
|
16058
17615
|
struct ggml_opt_params params,
|
@@ -16075,33 +17632,65 @@ enum ggml_opt_result ggml_opt(
|
|
16075
17632
|
|
16076
17633
|
enum ggml_opt_result result = GGML_OPT_OK;
|
16077
17634
|
|
17635
|
+
struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
|
17636
|
+
|
17637
|
+
ggml_opt_init(ctx, opt, params, 0);
|
17638
|
+
result = ggml_opt_resume(ctx, opt, f);
|
17639
|
+
|
17640
|
+
if (free_ctx) {
|
17641
|
+
ggml_free(ctx);
|
17642
|
+
}
|
17643
|
+
|
17644
|
+
return result;
|
17645
|
+
}
|
17646
|
+
|
17647
|
+
enum ggml_opt_result ggml_opt_resume(
|
17648
|
+
struct ggml_context * ctx,
|
17649
|
+
struct ggml_opt_context * opt,
|
17650
|
+
struct ggml_tensor * f) {
|
17651
|
+
|
17652
|
+
// build forward + backward compute graphs
|
17653
|
+
struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
|
17654
|
+
struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
|
17655
|
+
|
17656
|
+
struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
|
17657
|
+
struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
|
17658
|
+
|
17659
|
+
*gf = ggml_build_forward (f);
|
17660
|
+
*gb = ggml_build_backward(ctx, gf, true);
|
17661
|
+
|
17662
|
+
return ggml_opt_resume_g(ctx, opt, f, gf, gb);
|
17663
|
+
}
|
17664
|
+
|
17665
|
+
enum ggml_opt_result ggml_opt_resume_g(
|
17666
|
+
struct ggml_context * ctx,
|
17667
|
+
struct ggml_opt_context * opt,
|
17668
|
+
struct ggml_tensor * f,
|
17669
|
+
struct ggml_cgraph * gf,
|
17670
|
+
struct ggml_cgraph * gb) {
|
17671
|
+
|
16078
17672
|
// build forward + backward compute graphs
|
16079
|
-
|
16080
|
-
struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, true);
|
17673
|
+
enum ggml_opt_result result = GGML_OPT_OK;
|
16081
17674
|
|
16082
|
-
switch (params.type) {
|
17675
|
+
switch (opt->params.type) {
|
16083
17676
|
case GGML_OPT_ADAM:
|
16084
17677
|
{
|
16085
|
-
result = ggml_opt_adam(ctx, params, f,
|
17678
|
+
result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
|
16086
17679
|
} break;
|
16087
17680
|
case GGML_OPT_LBFGS:
|
16088
17681
|
{
|
16089
|
-
result = ggml_opt_lbfgs(ctx, params, f,
|
17682
|
+
result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb);
|
16090
17683
|
} break;
|
16091
17684
|
}
|
16092
17685
|
|
16093
|
-
if (params.print_forward_graph) {
|
16094
|
-
ggml_graph_print (
|
16095
|
-
ggml_graph_dump_dot(
|
16096
|
-
}
|
16097
|
-
|
16098
|
-
if (params.print_backward_graph) {
|
16099
|
-
ggml_graph_print (&gb);
|
16100
|
-
ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot");
|
17686
|
+
if (opt->params.print_forward_graph) {
|
17687
|
+
ggml_graph_print (gf);
|
17688
|
+
ggml_graph_dump_dot(gf, NULL, "opt-forward.dot");
|
16101
17689
|
}
|
16102
17690
|
|
16103
|
-
if (
|
16104
|
-
|
17691
|
+
if (opt->params.print_backward_graph) {
|
17692
|
+
ggml_graph_print (gb);
|
17693
|
+
ggml_graph_dump_dot(gb, gf, "opt-backward.dot");
|
16105
17694
|
}
|
16106
17695
|
|
16107
17696
|
return result;
|
@@ -16301,6 +17890,18 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                     result = ggml_quantize_q6_K(src + start, block, n, n, hist);
                 } break;
 #endif
+        case GGML_TYPE_F16:
+            {
+                int elemsize = sizeof(ggml_fp16_t);
+                ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
+                result = n * elemsize;
+            } break;
+        case GGML_TYPE_F32:
+            {
+                int elemsize = sizeof(float);
+                result = n * elemsize;
+                memcpy((uint8_t *)dst + start * elemsize, src + start, result);
+            } break;
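For F16 the chunk is converted element-wise and for F32 it is copied verbatim, so in both branches the reported size is simply the element count times the element size. A small plain-C restatement of that accounting (no ggml calls, illustrative values):

    #include <stdio.h>
    #include <stdint.h>
    #include <string.h>

    /* n elements of the chunk are written verbatim (F32) or converted to
     * 2-byte halves (F16), so the reported size is n * element size. */
    int main(void) {
        const int n = 1024;                 /* elements in this chunk */
        float   src[1024];
        uint8_t dst[1024 * sizeof(float)];

        for (int i = 0; i < n; ++i) src[i] = (float) i;

        size_t result = n * sizeof(float);  /* F32 path: straight copy */
        memcpy(dst, src, result);
        printf("F32 chunk: %zu bytes\n", result);

        printf("F16 chunk: %zu bytes\n", (size_t) n * sizeof(uint16_t));
        return 0;
    }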
16304
17905
|
default:
|
16305
17906
|
assert(false);
|
16306
17907
|
}
|