llama_cpp 0.1.4 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,6 +3,10 @@
 
  #include "ggml.h"
 
+ #ifdef GGML_USE_K_QUANTS
+ #include "k_quants.h"
+ #endif
+
  #if defined(_MSC_VER) || defined(__MINGW32__)
  #include <malloc.h> // using malloc.h with MSC/MINGW
  #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@@ -21,6 +25,10 @@
  #include <float.h>
  #include <limits.h>
 
+ #ifdef GGML_USE_METAL
+ #include <unistd.h>
+ #endif
+
  // if C99 - static_assert is noop
  // ref: https://stackoverflow.com/a/53923785/4039976
  #ifndef static_assert
@@ -121,7 +129,11 @@ typedef void* thread_ret_t;
  #else
  inline static void* ggml_aligned_malloc(size_t size) {
  void* aligned_memory = NULL;
+ #ifdef GGML_USE_METAL
+ int result = posix_memalign(&aligned_memory, getpagesize(), size);
+ #else
  int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
+ #endif
  if (result != 0) {
  // Handle allocation failure
  return NULL;
@@ -403,21 +415,27 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) {
  //
 
  #if defined(_MSC_VER) || defined(__MINGW32__)
- static int64_t timer_freq;
+ static int64_t timer_freq, timer_start;
  void ggml_time_init(void) {
- LARGE_INTEGER frequency;
- QueryPerformanceFrequency(&frequency);
- timer_freq = frequency.QuadPart;
+ LARGE_INTEGER t;
+ QueryPerformanceFrequency(&t);
+ timer_freq = t.QuadPart;
+
+ // The multiplication by 1000 or 1000000 below can cause an overflow if timer_freq
+ // and the uptime is high enough.
+ // We subtract the program start time to reduce the likelihood of that happening.
+ QueryPerformanceCounter(&t);
+ timer_start = t.QuadPart;
  }
  int64_t ggml_time_ms(void) {
  LARGE_INTEGER t;
  QueryPerformanceCounter(&t);
- return (t.QuadPart * 1000) / timer_freq;
+ return ((t.QuadPart-timer_start) * 1000) / timer_freq;
  }
  int64_t ggml_time_us(void) {
  LARGE_INTEGER t;
  QueryPerformanceCounter(&t);
- return (t.QuadPart * 1000000) / timer_freq;
+ return ((t.QuadPart-timer_start) * 1000000) / timer_freq;
  }
  #else
  void ggml_time_init(void) {}
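
The overflow mentioned in the new comment is easy to quantify. A rough sketch of the arithmetic, assuming a QueryPerformanceCounter frequency of 10 MHz (a common but not guaranteed value; this snippet is an illustration, not part of the package):

    // ggml_time_us() now computes ((now - timer_start) * 1000000) / timer_freq.
    // Without the subtraction, now * 1000000 overflows int64_t once the counter
    // exceeds INT64_MAX / 1000000, about 9.2e12 ticks - roughly 10.7 days of
    // system uptime at 10 MHz. Measuring from program start resets that budget.
    #include <stdint.h>

    static int64_t elapsed_us(int64_t now, int64_t start, int64_t freq) {
        return ((now - start) * 1000000) / freq;
    }
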
@@ -474,6 +492,8 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
  // quantization
  //
 
+ #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
  #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
  // multiply int8_t, add results pairwise twice
  static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
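
The MM256_SET_M128I macro introduced above appears to be a portable stand-in for _mm256_set_m128i, which some older compilers do not provide: it casts the low 128-bit half into a 256-bit register and inserts the high half into lane 1. A small usage sketch under that reading (illustration only, compile with AVX enabled):

    #include <immintrin.h>

    #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

    // hi ends up in bytes 16..31, lo in bytes 0..15 - same result as _mm256_set_m128i(hi, lo)
    static inline __m256i combine_halves(__m128i hi, __m128i lo) {
        return MM256_SET_M128I(hi, lo);
    }
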
@@ -533,7 +553,7 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
  static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
  {
  const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
- const __m256i bytes = _mm256_set_m128i(_mm_srli_epi16(tmp, 4), tmp);
+ const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
  const __m256i lowMask = _mm256_set1_epi8( 0xF );
  return _mm256_and_si256(lowMask, bytes);
  }
@@ -606,7 +626,7 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
  bytesh = _mm_or_si128(bytesh, bit_mask);
  bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
  bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
- return _mm256_set_m128i(bytesh, bytesl);
+ return MM256_SET_M128I(bytesh, bytesl);
  }
 
  // Unpack 32 4-bit fields into 32 bytes
@@ -619,7 +639,7 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
  const __m128i lowMask = _mm_set1_epi8(0xF);
  tmpl = _mm_and_si128(lowMask, tmpl);
  tmph = _mm_and_si128(lowMask, tmph);
- return _mm256_set_m128i(tmph, tmpl);
+ return MM256_SET_M128I(tmph, tmpl);
  }
 
  // add int16_t pairwise and return as float vector
@@ -627,7 +647,7 @@ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
  const __m128i ones = _mm_set1_epi16(1);
  const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
  const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
- const __m256i summed_pairs = _mm256_set_m128i(summed_pairsh, summed_pairsl);
+ const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
  return _mm256_cvtepi32_ps(summed_pairs);
  }
 
@@ -1565,6 +1585,48 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
  .vec_dot_q = NULL, // TODO
  .vec_dot_type = GGML_TYPE_Q8_1,
  },
+ #ifdef GGML_USE_K_QUANTS
+ [GGML_TYPE_Q2_K] = {
+ .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q2_K,
+ .quantize_row_q = quantize_row_q2_K,
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q2_K_reference,
+ .quantize_row_q_dot = quantize_row_q8_K,
+ .vec_dot_q = ggml_vec_dot_q2_K_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ },
+ [GGML_TYPE_Q3_K] = {
+ .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q3_K,
+ .quantize_row_q = quantize_row_q3_K,
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q3_K_reference,
+ .quantize_row_q_dot = quantize_row_q8_K,
+ .vec_dot_q = ggml_vec_dot_q3_K_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ },
+ [GGML_TYPE_Q4_K] = {
+ .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q4_K,
+ .quantize_row_q = quantize_row_q4_K,
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_K_reference,
+ .quantize_row_q_dot = quantize_row_q8_K,
+ .vec_dot_q = ggml_vec_dot_q4_K_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ },
+ [GGML_TYPE_Q5_K] = {
+ .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q5_K,
+ .quantize_row_q = quantize_row_q5_K,
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_K_reference,
+ .quantize_row_q_dot = quantize_row_q8_K,
+ .vec_dot_q = ggml_vec_dot_q5_K_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ },
+ [GGML_TYPE_Q6_K] = {
+ .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q6_K,
+ .quantize_row_q = quantize_row_q6_K,
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q6_K_reference,
+ .quantize_row_q_dot = quantize_row_q8_K,
+ .vec_dot_q = ggml_vec_dot_q6_K_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ },
+ #endif
  };
 
  // For internal test use
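
The quantize_fns table above is the single dispatch point that makes the new k-quant formats usable by the rest of ggml: one entry per tensor type, bundling the quantize, dequantize and dot-product kernels. A hedged sketch of how such an entry is consumed (the function below is hypothetical; ggml's actual call sites go through ggml_internal_get_quantize_fn and the matrix-multiply paths):

    // Illustration only - assumes the quantize_fns table and the k-quant kernels
    // declared in k_quants.h are in scope.
    static void dequantize_one_row(enum ggml_type type, const void * src, float * dst, int k) {
        const quantize_fns_t fns = quantize_fns[type];   // e.g. GGML_TYPE_Q4_K
        fns.dequantize_row_q(src, dst, k);               // -> dequantize_row_q4_K(...)
    }
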
@@ -2290,7 +2352,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
  const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
 
  // Convert int32_t to float
- __m256 p = _mm256_cvtepi32_ps(_mm256_set_m128i(i32_0, i32_1));
+ __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
 
  // Apply the scale, and accumulate
  acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
@@ -2766,7 +2828,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
  __m128i bxh = _mm256_extractf128_si256(bx, 1);
  bxl = _mm_or_si128(bxl, bxhil);
  bxh = _mm_or_si128(bxh, bxhih);
- bx = _mm256_set_m128i(bxh, bxl);
+ bx = MM256_SET_M128I(bxh, bxl);
 
  const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
@@ -3022,7 +3084,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
  __m128i bxh = _mm256_extractf128_si256(bx, 1);
  bxl = _mm_or_si128(bxl, bxhil);
  bxh = _mm_or_si128(bxh, bxhih);
- bx = _mm256_set_m128i(bxh, bxl);
+ bx = MM256_SET_M128I(bxh, bxl);
 
  const __m256 dy = _mm256_set1_ps(y[i].d);
  const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
@@ -3444,11 +3506,19 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
  [GGML_TYPE_Q5_1] = QK5_1,
  [GGML_TYPE_Q8_0] = QK8_0,
  [GGML_TYPE_Q8_1] = QK8_1,
+ #ifdef GGML_USE_K_QUANTS
+ [GGML_TYPE_Q2_K] = QK_K,
+ [GGML_TYPE_Q3_K] = QK_K,
+ [GGML_TYPE_Q4_K] = QK_K,
+ [GGML_TYPE_Q5_K] = QK_K,
+ [GGML_TYPE_Q6_K] = QK_K,
+ [GGML_TYPE_Q8_K] = QK_K,
+ #endif
  [GGML_TYPE_I8] = 1,
  [GGML_TYPE_I16] = 1,
  [GGML_TYPE_I32] = 1,
  };
- static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated");
+ static_assert(GGML_TYPE_COUNT == 19, "GGML_BLCK_SIZE is outdated");
 
  static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
  [GGML_TYPE_F32] = sizeof(float),
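
Each of these per-type tables is paired with a static_assert on GGML_TYPE_COUNT, so the six new k-quant enum values (13 -> 19) cannot be added without revisiting every table that is indexed by type. A minimal sketch of the pattern with made-up names (not package code):

    #include <assert.h>

    enum demo_type { DEMO_A, DEMO_B, DEMO_C, DEMO_TYPE_COUNT };

    static const int DEMO_BLCK_SIZE[DEMO_TYPE_COUNT] = {
        [DEMO_A] = 1,
        [DEMO_B] = 32,
        [DEMO_C] = 256,
    };
    // compilation fails here if demo_type grows but the table above is forgotten
    static_assert(DEMO_TYPE_COUNT == 3, "DEMO_BLCK_SIZE is outdated");
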
@@ -3459,11 +3529,19 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
  [GGML_TYPE_Q5_1] = sizeof(block_q5_1),
  [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
  [GGML_TYPE_Q8_1] = sizeof(block_q8_1),
+ #ifdef GGML_USE_K_QUANTS
+ [GGML_TYPE_Q2_K] = sizeof(block_q2_K),
+ [GGML_TYPE_Q3_K] = sizeof(block_q3_K),
+ [GGML_TYPE_Q4_K] = sizeof(block_q4_K),
+ [GGML_TYPE_Q5_K] = sizeof(block_q5_K),
+ [GGML_TYPE_Q6_K] = sizeof(block_q6_K),
+ [GGML_TYPE_Q8_K] = sizeof(block_q8_K),
+ #endif
  [GGML_TYPE_I8] = sizeof(int8_t),
  [GGML_TYPE_I16] = sizeof(int16_t),
  [GGML_TYPE_I32] = sizeof(int32_t),
  };
- static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated");
+ static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_SIZE is outdated");
 
 
  static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3475,11 +3553,17 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
  [GGML_TYPE_Q5_1] = "q5_1",
  [GGML_TYPE_Q8_0] = "q8_0",
  [GGML_TYPE_Q8_1] = "q8_1",
+ [GGML_TYPE_Q2_K] = "q2_K",
+ [GGML_TYPE_Q3_K] = "q3_K",
+ [GGML_TYPE_Q4_K] = "q4_K",
+ [GGML_TYPE_Q5_K] = "q5_K",
+ [GGML_TYPE_Q6_K] = "q6_K",
+ [GGML_TYPE_Q8_K] = "q8_K",
  [GGML_TYPE_I8] = "i8",
  [GGML_TYPE_I16] = "i16",
  [GGML_TYPE_I32] = "i32",
  };
- static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated");
+ static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_NAME is outdated");
 
  static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
  [GGML_TYPE_F32] = false,
@@ -3490,11 +3574,17 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
  [GGML_TYPE_Q5_1] = true,
  [GGML_TYPE_Q8_0] = true,
  [GGML_TYPE_Q8_1] = true,
+ [GGML_TYPE_Q2_K] = true,
+ [GGML_TYPE_Q3_K] = true,
+ [GGML_TYPE_Q4_K] = true,
+ [GGML_TYPE_Q5_K] = true,
+ [GGML_TYPE_Q6_K] = true,
+ [GGML_TYPE_Q8_K] = true,
  [GGML_TYPE_I8] = false,
  [GGML_TYPE_I16] = false,
  [GGML_TYPE_I32] = false,
  };
- static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated");
+ static_assert(GGML_TYPE_COUNT == 19, "GGML_IS_QUANTIZED is outdated");
 
  static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
  "NONE",
@@ -3513,6 +3603,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3513
3603
  "SUM_ROWS",
3514
3604
  "MEAN",
3515
3605
  "REPEAT",
3606
+ "REPEAT_BACK",
3516
3607
  "ABS",
3517
3608
  "SGN",
3518
3609
  "NEG",
@@ -3526,6 +3617,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3526
3617
  "RMS_NORM_BACK",
3527
3618
 
3528
3619
  "MUL_MAT",
3620
+ "OUT_PROD",
3529
3621
 
3530
3622
  "SCALE",
3531
3623
  "SET",
@@ -3541,6 +3633,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3541
3633
  "DIAG_MASK_INF",
3542
3634
  "DIAG_MASK_ZERO",
3543
3635
  "SOFT_MAX",
3636
+ "SOFT_MAX_BACK",
3544
3637
  "ROPE",
3545
3638
  "ROPE_BACK",
3546
3639
  "ALIBI",
@@ -3550,13 +3643,16 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3550
3643
 
3551
3644
  "FLASH_ATTN",
3552
3645
  "FLASH_FF",
3646
+ "FLASH_ATTN_BACK",
3553
3647
 
3554
3648
  "MAP_UNARY",
3555
3649
  "MAP_BINARY",
3556
- };
3557
3650
 
3558
- static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
3651
+ "CROSS_ENTROPY_LOSS",
3652
+ "CROSS_ENTROPY_LOSS_BACK",
3653
+ };
3559
3654
 
3655
+ static_assert(GGML_OP_COUNT == 57, "GGML_OP_COUNT != 57");
3560
3656
 
3561
3657
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3562
3658
  "none",
@@ -3575,6 +3671,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3575
3671
  "Σx_k",
3576
3672
  "Σx/n",
3577
3673
  "repeat(x)",
3674
+ "repeat_back(x)",
3578
3675
  "abs(x)",
3579
3676
  "sgn(x)",
3580
3677
  "-x",
@@ -3587,6 +3684,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3587
3684
  "rms_norm(x)",
3588
3685
  "rms_norm_back(x)",
3589
3686
 
3687
+ "X*Y",
3590
3688
  "X*Y",
3591
3689
 
3592
3690
  "x*v",
@@ -3603,6 +3701,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3603
3701
  "diag_mask_inf(x)",
3604
3702
  "diag_mask_zero(x)",
3605
3703
  "soft_max(x)",
3704
+ "soft_max_back(x)",
3606
3705
  "rope(x)",
3607
3706
  "rope_back(x)",
3608
3707
  "alibi(x)",
@@ -3612,12 +3711,16 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3612
3711
 
3613
3712
  "flash_attn(x)",
3614
3713
  "flash_ff(x)",
3714
+ "flash_attn_back(x)",
3615
3715
 
3616
3716
  "f(x)",
3617
3717
  "f(x,y)",
3718
+
3719
+ "cross_entropy_loss(x,y)",
3720
+ "cross_entropy_loss_back(x,y)",
3618
3721
  };
3619
3722
 
3620
- static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
3723
+ static_assert(GGML_OP_COUNT == 57, "GGML_OP_COUNT != 57");
3621
3724
 
3622
3725
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
3623
3726
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3631,6 +3734,7 @@ struct ggml_context {
3631
3734
  void * mem_buffer;
3632
3735
  bool mem_buffer_owned;
3633
3736
  bool no_alloc;
3737
+ bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
3634
3738
 
3635
3739
  int n_objects;
3636
3740
 
@@ -3647,26 +3751,6 @@ struct ggml_context_container {
3647
3751
  struct ggml_context context;
3648
3752
  };
3649
3753
 
3650
- //
3651
- // compute types
3652
- //
3653
-
3654
- enum ggml_task_type {
3655
- GGML_TASK_INIT = 0,
3656
- GGML_TASK_COMPUTE,
3657
- GGML_TASK_FINALIZE,
3658
- };
3659
-
3660
- struct ggml_compute_params {
3661
- enum ggml_task_type type;
3662
-
3663
- int ith, nth;
3664
-
3665
- // work buffer for all threads
3666
- size_t wsize;
3667
- void * wdata;
3668
- };
3669
-
3670
3754
  //
3671
3755
  // ggml state
3672
3756
  //
@@ -3723,7 +3807,7 @@ int64_t ggml_nelements(const struct ggml_tensor * tensor) {
3723
3807
  return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
3724
3808
  }
3725
3809
 
3726
- int ggml_nrows(const struct ggml_tensor * tensor) {
3810
+ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
3727
3811
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3728
3812
 
3729
3813
  return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
@@ -3732,7 +3816,20 @@ int ggml_nrows(const struct ggml_tensor * tensor) {
3732
3816
  size_t ggml_nbytes(const struct ggml_tensor * tensor) {
3733
3817
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3734
3818
 
3735
- return (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type];
3819
+ // this should handle cases where the tensor is not contiguous in memory
3820
+ // probaby just:
3821
+ //
3822
+ // return tensor->ne[3]*tensor->nb[3]
3823
+ //
3824
+ // is enough, but just in case, adding the second part
3825
+
3826
+ return MAX(tensor->ne[3]*tensor->nb[3], (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type]);
3827
+ }
3828
+
3829
+ size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
3830
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3831
+
3832
+ return (nrows_split*tensor->ne[0]*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type];
3736
3833
  }
3737
3834
 
3738
3835
  int ggml_blck_size(enum ggml_type type) {
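
The reworked ggml_nbytes() takes the larger of the stride-based extent (ne[3]*nb[3]) and the packed element count, so tensors that are not contiguous in memory no longer under-report their size. A rough sketch of the two quantities for a plain F32 tensor (illustration only; the real function also divides by the block size for quantized types):

    static size_t nbytes_f32_sketch(const int64_t ne[4], const size_t nb[4]) {
        const size_t by_stride = ne[3]*nb[3];                            // honors padded/view strides
        const size_t by_count  = ne[0]*ne[1]*ne[2]*ne[3]*sizeof(float);  // tightly packed size
        return by_stride > by_count ? by_stride : by_count;
    }
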
@@ -3786,6 +3883,15 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
3786
3883
  (t0->ne[3] == t1->ne[3]);
3787
3884
  }
3788
3885
 
3886
+ static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
3887
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3888
+
3889
+ return
3890
+ (t0->ne[1] == t1->ne[1]) &&
3891
+ (t0->ne[2] == t1->ne[2]) &&
3892
+ (t0->ne[3] == t1->ne[3]);
3893
+ }
3894
+
3789
3895
  bool ggml_is_quantized(enum ggml_type type) {
3790
3896
  return GGML_IS_QUANTIZED[type];
3791
3897
  }
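
ggml_can_out_prod() pairs with the new ggml_out_prod() further down: it requires the second dimensions of the two operands to match, and the result takes its shape from the two first dimensions, i.e. an a*b^T style accumulation over the shared index (this matches the dst[i0,i1] += src0[i0,i01] * src1[i1,i01] loop in the forward kernel added later in this diff). A scalar sketch of that contraction, written as an assumption about the intended semantics rather than a quote from the package:

    // a: [ne00, ne01], b: [ne10, ne11] with ne01 == ne11; dst: [ne00, ne10]
    // ggml stores ne[0] as the fastest-varying dimension, hence the index math.
    static void out_prod_sketch(int ne00, int ne10, int ne01,
                                const float * a, const float * b, float * dst) {
        for (int i1 = 0; i1 < ne10; ++i1) {
            for (int i0 = 0; i0 < ne00; ++i0) {
                float sum = 0.0f;
                for (int j = 0; j < ne01; ++j) {
                    sum += a[j*ne00 + i0] * b[j*ne10 + i1];
                }
                dst[i1*ne00 + i0] = sum;
            }
        }
    }
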
@@ -3801,6 +3907,11 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
3801
3907
  case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
3802
3908
  case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
3803
3909
  case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
3910
+ case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break;
3911
+ case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break;
3912
+ case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
3913
+ case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break;
3914
+ case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
3804
3915
  case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
3805
3916
  case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
3806
3917
  }
@@ -3814,11 +3925,11 @@ size_t ggml_tensor_overhead(void) {
3814
3925
  return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE + 16;
3815
3926
  }
3816
3927
 
3817
- static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
3928
+ bool ggml_is_transposed(const struct ggml_tensor * tensor) {
3818
3929
  return tensor->nb[0] > tensor->nb[1];
3819
3930
  }
3820
3931
 
3821
- static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
3932
+ bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
3822
3933
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3823
3934
 
3824
3935
  return
@@ -3828,6 +3939,12 @@ static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
3828
3939
  tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
3829
3940
  }
3830
3941
 
3942
+ bool ggml_is_permuted(const struct ggml_tensor * tensor) {
3943
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3944
+
3945
+ return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
3946
+ }
3947
+
3831
3948
  static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
3832
3949
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3833
3950
 
@@ -3967,6 +4084,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3967
4084
  /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size),
3968
4085
  /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
3969
4086
  /*.no_alloc =*/ params.no_alloc,
4087
+ /*.no_alloc_save =*/ params.no_alloc,
3970
4088
  /*.n_objects =*/ 0,
3971
4089
  /*.objects_begin =*/ NULL,
3972
4090
  /*.objects_end =*/ NULL,
@@ -4044,11 +4162,18 @@ size_t ggml_get_mem_size(struct ggml_context * ctx) {
4044
4162
  // operators when using scratch buffers
4045
4163
  // TODO: implement a better way
4046
4164
  void ggml_scratch_save(struct ggml_context * ctx) {
4165
+ // this is needed to allow opt tensors to store their data
4166
+ // TODO: again, need to find a better way
4167
+ ctx->no_alloc_save = ctx->no_alloc;
4168
+ ctx->no_alloc = false;
4169
+
4047
4170
  ctx->scratch_save = ctx->scratch;
4048
4171
  ctx->scratch.data = NULL;
4049
4172
  }
4050
4173
 
4051
4174
  void ggml_scratch_load(struct ggml_context * ctx) {
4175
+ ctx->no_alloc = ctx->no_alloc_save;
4176
+
4052
4177
  ctx->scratch = ctx->scratch_save;
4053
4178
  }
4054
4179
 
@@ -4157,6 +4282,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
4157
4282
  /*.perf_time_us =*/ 0,
4158
4283
  /*.data =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
4159
4284
  /*.name =*/ { 0 },
4285
+ /*.extra =*/ NULL,
4160
4286
  /*.pad =*/ { 0 },
4161
4287
  };
4162
4288
 
@@ -4595,7 +4721,7 @@ struct ggml_tensor * ggml_add_impl(
4595
4721
 
4596
4722
  bool is_node = false;
4597
4723
 
4598
- if (!inplace && (a->grad || b->grad)) {
4724
+ if (a->grad || b->grad) {
4599
4725
  is_node = true;
4600
4726
  }
4601
4727
 
@@ -4635,7 +4761,7 @@ struct ggml_tensor * ggml_add1_impl(
4635
4761
 
4636
4762
  bool is_node = false;
4637
4763
 
4638
- if (!inplace && (a->grad || b->grad)) {
4764
+ if (a->grad || b->grad) {
4639
4765
  is_node = true;
4640
4766
  }
4641
4767
 
@@ -5061,6 +5187,34 @@ struct ggml_tensor * ggml_repeat(
5061
5187
  return result;
5062
5188
  }
5063
5189
 
5190
+ // ggml_repeat_back
5191
+
5192
+ struct ggml_tensor * ggml_repeat_back(
5193
+ struct ggml_context * ctx,
5194
+ struct ggml_tensor * a,
5195
+ struct ggml_tensor * b) {
5196
+ GGML_ASSERT(ggml_can_repeat(b, a));
5197
+
5198
+ bool is_node = false;
5199
+
5200
+ if (a->grad) {
5201
+ is_node = true;
5202
+ }
5203
+
5204
+ if (ggml_are_same_shape(a, b) && !is_node) {
5205
+ return a;
5206
+ }
5207
+
5208
+ struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
5209
+
5210
+ result->op = GGML_OP_REPEAT_BACK;
5211
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5212
+ result->src0 = a;
5213
+ result->src1 = b;
5214
+
5215
+ return result;
5216
+ }
5217
+
5064
5218
  // ggml_abs
5065
5219
 
5066
5220
  struct ggml_tensor * ggml_abs_impl(
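
ggml_repeat_back(), added above, is the reduction that undoes ggml_repeat() in the backward pass: where repeat tiles b up to a's shape, repeat_back sums the tiles of a back down to b's shape (the forward kernel later in this diff accumulates with ggml_vec_acc_f32). As a one-dimensional illustration: repeating [x0, x1] three times gives [x0, x1, x0, x1, x0, x1], so repeat_back of a gradient [g0 .. g5] onto a length-2 target is [g0+g2+g4, g1+g3+g5].
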
@@ -5438,6 +5592,32 @@ struct ggml_tensor * ggml_mul_mat(
5438
5592
  return result;
5439
5593
  }
5440
5594
 
5595
+ // ggml_out_prod
5596
+
5597
+ struct ggml_tensor * ggml_out_prod(
5598
+ struct ggml_context * ctx,
5599
+ struct ggml_tensor * a,
5600
+ struct ggml_tensor * b) {
5601
+ GGML_ASSERT(ggml_can_out_prod(a, b));
5602
+ GGML_ASSERT(!ggml_is_transposed(a));
5603
+
5604
+ bool is_node = false;
5605
+
5606
+ if (a->grad || b->grad) {
5607
+ is_node = true;
5608
+ }
5609
+
5610
+ const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
5611
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
5612
+
5613
+ result->op = GGML_OP_OUT_PROD;
5614
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5615
+ result->src0 = a;
5616
+ result->src1 = b;
5617
+
5618
+ return result;
5619
+ }
5620
+
5441
5621
  // ggml_scale
5442
5622
 
5443
5623
  struct ggml_tensor * ggml_scale_impl(
@@ -5450,7 +5630,7 @@ struct ggml_tensor * ggml_scale_impl(
5450
5630
 
5451
5631
  bool is_node = false;
5452
5632
 
5453
- if (!inplace && (a->grad || b->grad)) {
5633
+ if (a->grad || b->grad) {
5454
5634
  is_node = true;
5455
5635
  }
5456
5636
 
@@ -5493,7 +5673,7 @@ struct ggml_tensor * ggml_set_impl(
5493
5673
 
5494
5674
  bool is_node = false;
5495
5675
 
5496
- if (!inplace && (a->grad || b->grad)) {
5676
+ if (a->grad || b->grad) {
5497
5677
  is_node = true;
5498
5678
  }
5499
5679
 
@@ -5802,14 +5982,18 @@ struct ggml_tensor * ggml_view_1d(
5802
5982
 
5803
5983
  struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset);
5804
5984
 
5985
+ ggml_scratch_save(ctx);
5986
+
5987
+ struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
5988
+ memcpy(offs->data, &offset, 2*sizeof(int32_t));
5989
+
5990
+ ggml_scratch_load(ctx);
5991
+
5805
5992
  result->op = GGML_OP_VIEW;
5806
5993
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5807
5994
  result->src0 = a;
5808
5995
  result->src1 = NULL;
5809
-
5810
- if (is_node) {
5811
- memcpy(result->padding, &offset, sizeof(offset));
5812
- }
5996
+ result->opt[0] = offs;
5813
5997
 
5814
5998
  return result;
5815
5999
  }
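
The view operators now always record the byte offset in a small I32 helper tensor stored in result->opt[0] (allocated between ggml_scratch_save/ggml_scratch_load so it lands outside the scratch buffer), instead of writing it into result->padding only when a gradient is needed. A hedged sketch of how that offset could be read back, assuming a 64-bit build where the two int32 slots hold one size_t (illustration, not the package's backward code):

    #include <stdint.h>
    #include <string.h>

    static size_t view_offset_sketch(const struct ggml_tensor * view) {
        size_t offset = 0;
        memcpy(&offset, view->opt[0]->data, 2*sizeof(int32_t));  // mirrors the memcpy in ggml_view_*d
        return offset;
    }
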
@@ -5834,6 +6018,13 @@ struct ggml_tensor * ggml_view_2d(
5834
6018
 
5835
6019
  struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset);
5836
6020
 
6021
+ ggml_scratch_save(ctx);
6022
+
6023
+ struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
6024
+ memcpy(offs->data, &offset, 2*sizeof(int32_t));
6025
+
6026
+ ggml_scratch_load(ctx);
6027
+
5837
6028
  result->nb[1] = nb1;
5838
6029
  result->nb[2] = result->nb[1]*ne1;
5839
6030
  result->nb[3] = result->nb[2];
@@ -5842,10 +6033,7 @@ struct ggml_tensor * ggml_view_2d(
5842
6033
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5843
6034
  result->src0 = a;
5844
6035
  result->src1 = NULL;
5845
-
5846
- if (is_node) {
5847
- memcpy(result->padding, &offset, sizeof(offset));
5848
- }
6036
+ result->opt[0] = offs;
5849
6037
 
5850
6038
  return result;
5851
6039
  }
@@ -5872,6 +6060,13 @@ struct ggml_tensor * ggml_view_3d(
5872
6060
 
5873
6061
  struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset);
5874
6062
 
6063
+ ggml_scratch_save(ctx);
6064
+
6065
+ struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
6066
+ memcpy(offs->data, &offset, 2*sizeof(int32_t));
6067
+
6068
+ ggml_scratch_load(ctx);
6069
+
5875
6070
  result->nb[1] = nb1;
5876
6071
  result->nb[2] = nb2;
5877
6072
  result->nb[3] = result->nb[2]*ne2;
@@ -5880,10 +6075,7 @@ struct ggml_tensor * ggml_view_3d(
5880
6075
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5881
6076
  result->src0 = a;
5882
6077
  result->src1 = NULL;
5883
-
5884
- if (is_node) {
5885
- memcpy(result->padding, &offset, sizeof(offset));
5886
- }
6078
+ result->opt[0] = offs;
5887
6079
 
5888
6080
  return result;
5889
6081
  }
@@ -5912,6 +6104,13 @@ struct ggml_tensor * ggml_view_4d(
5912
6104
 
5913
6105
  struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, (char *) a->data + offset);
5914
6106
 
6107
+ ggml_scratch_save(ctx);
6108
+
6109
+ struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
6110
+ memcpy(offs->data, &offset, 2*sizeof(int32_t));
6111
+
6112
+ ggml_scratch_load(ctx);
6113
+
5915
6114
  result->nb[1] = nb1;
5916
6115
  result->nb[2] = nb2;
5917
6116
  result->nb[3] = nb3;
@@ -5920,10 +6119,7 @@ struct ggml_tensor * ggml_view_4d(
5920
6119
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5921
6120
  result->src0 = a;
5922
6121
  result->src1 = NULL;
5923
-
5924
- if (is_node) {
5925
- memcpy(result->padding, &offset, sizeof(offset));
5926
- }
6122
+ result->opt[0] = offs;
5927
6123
 
5928
6124
  return result;
5929
6125
  }
@@ -5986,10 +6182,18 @@ struct ggml_tensor * ggml_permute(
5986
6182
  result->src1 = NULL;
5987
6183
 
5988
6184
  if (is_node) {
5989
- result->padding[0] = axis0;
5990
- result->padding[1] = axis1;
5991
- result->padding[2] = axis2;
5992
- result->padding[3] = axis3;
6185
+ ggml_scratch_save(ctx);
6186
+
6187
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
6188
+
6189
+ ((int32_t *) b->data)[0] = axis0;
6190
+ ((int32_t *) b->data)[1] = axis1;
6191
+ ((int32_t *) b->data)[2] = axis2;
6192
+ ((int32_t *) b->data)[3] = axis3;
6193
+
6194
+ ggml_scratch_load(ctx);
6195
+
6196
+ result->opt[0] = b;
5993
6197
  }
5994
6198
 
5995
6199
  return result;
@@ -6229,6 +6433,44 @@ struct ggml_tensor * ggml_soft_max_inplace(
6229
6433
  return ggml_soft_max_impl(ctx, a, true);
6230
6434
  }
6231
6435
 
6436
+
6437
+ // ggml_soft_max_back
6438
+
6439
+ struct ggml_tensor * ggml_soft_max_back_impl(
6440
+ struct ggml_context * ctx,
6441
+ struct ggml_tensor * a,
6442
+ struct ggml_tensor * b,
6443
+ bool inplace) {
6444
+ bool is_node = false;
6445
+
6446
+ if (a->grad || b->grad) {
6447
+ is_node = true; // TODO : implement backward pass
6448
+ }
6449
+
6450
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
6451
+
6452
+ result->op = GGML_OP_SOFT_MAX_BACK;
6453
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6454
+ result->src0 = a;
6455
+ result->src1 = b;
6456
+
6457
+ return result;
6458
+ }
6459
+
6460
+ struct ggml_tensor * ggml_soft_max_back(
6461
+ struct ggml_context * ctx,
6462
+ struct ggml_tensor * a,
6463
+ struct ggml_tensor * b) {
6464
+ return ggml_soft_max_back_impl(ctx, a, b, false);
6465
+ }
6466
+
6467
+ struct ggml_tensor * ggml_soft_max_back_inplace(
6468
+ struct ggml_context * ctx,
6469
+ struct ggml_tensor * a,
6470
+ struct ggml_tensor * b) {
6471
+ return ggml_soft_max_back_impl(ctx, a, b, true);
6472
+ }
6473
+
6232
6474
  // ggml_rope
6233
6475
 
6234
6476
  struct ggml_tensor * ggml_rope_impl(
@@ -6241,7 +6483,7 @@ struct ggml_tensor * ggml_rope_impl(
6241
6483
  GGML_ASSERT(n_past >= 0);
6242
6484
  bool is_node = false;
6243
6485
 
6244
- if (!inplace && a->grad) {
6486
+ if (a->grad) {
6245
6487
  is_node = true;
6246
6488
  }
6247
6489
 
@@ -6295,8 +6537,7 @@ struct ggml_tensor * ggml_rope_back(
6295
6537
  bool is_node = false;
6296
6538
 
6297
6539
  if (a->grad) {
6298
- GGML_ASSERT(false); // TODO: implement backward
6299
- is_node = true;
6540
+ is_node = false; // TODO: implement backward
6300
6541
  }
6301
6542
 
6302
6543
  struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
@@ -6461,7 +6702,6 @@ struct ggml_tensor * ggml_flash_attn(
6461
6702
  bool is_node = false;
6462
6703
 
6463
6704
  if (q->grad || k->grad || v->grad) {
6464
- GGML_ASSERT(false); // TODO: implement backward
6465
6705
  is_node = true;
6466
6706
  }
6467
6707
 
@@ -6493,7 +6733,6 @@ struct ggml_tensor * ggml_flash_ff(
6493
6733
  bool is_node = false;
6494
6734
 
6495
6735
  if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
6496
- GGML_ASSERT(false); // TODO: implement backward
6497
6736
  is_node = true;
6498
6737
  }
6499
6738
 
@@ -6511,6 +6750,71 @@ struct ggml_tensor * ggml_flash_ff(
6511
6750
  return result;
6512
6751
  }
6513
6752
 
6753
+ // ggml_flash_attn_back
6754
+
6755
+ struct ggml_tensor * ggml_flash_attn_back(
6756
+ struct ggml_context * ctx,
6757
+ struct ggml_tensor * q,
6758
+ struct ggml_tensor * k,
6759
+ struct ggml_tensor * v,
6760
+ struct ggml_tensor * d,
6761
+ bool masked) {
6762
+ GGML_ASSERT(ggml_can_mul_mat(k, q));
6763
+ // TODO: check if vT can be multiplied by (k*qT)
6764
+
6765
+ // d shape [D,N,ne2,ne3]
6766
+ // q shape [D,N,ne2,ne3]
6767
+ // k shape [D,M,ne2,ne3]
6768
+ // v shape [M,D,ne2,ne3]
6769
+
6770
+ const int64_t D = q->ne[0];
6771
+ const int64_t N = q->ne[1];
6772
+ const int64_t M = k->ne[1];
6773
+ const int64_t ne2 = q->ne[2];
6774
+ const int64_t ne3 = q->ne[3];
6775
+
6776
+ GGML_ASSERT(k->ne[0] == D);
6777
+ GGML_ASSERT(v->ne[0] == M);
6778
+ GGML_ASSERT(v->ne[1] == D);
6779
+ GGML_ASSERT(d->ne[0] == D);
6780
+ GGML_ASSERT(d->ne[1] == N);
6781
+ GGML_ASSERT(k->ne[2] == ne2);
6782
+ GGML_ASSERT(k->ne[3] == ne3);
6783
+ GGML_ASSERT(v->ne[2] == ne2);
6784
+ GGML_ASSERT(v->ne[3] == ne3);
6785
+ GGML_ASSERT(d->ne[2] == ne2);
6786
+ GGML_ASSERT(d->ne[3] == ne3);
6787
+
6788
+ bool is_node = false;
6789
+
6790
+ if (q->grad || k->grad || v->grad) {
6791
+ // when using this operation (in backwards pass) these grads are set.
6792
+ // we don't want to create (big) grad of our result, so is_node is false.
6793
+ is_node = false;
6794
+ }
6795
+
6796
+ // store gradients of q, k and v as continuous tensors concatenated in result.
6797
+ // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
6798
+ // gradq->data = result->data
6799
+ // gradk->data = result->data + nb0*D*N*ne2*ne3
6800
+ // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
6801
+ // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
6802
+ int64_t ne[4] = {D,M+N+M,ne2,ne3};
6803
+
6804
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
6805
+
6806
+ result->op = GGML_OP_FLASH_ATTN_BACK;
6807
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6808
+ result->src0 = q;
6809
+ result->src1 = k;
6810
+ result->opt[0] = v;
6811
+ result->opt[1] = d;
6812
+ result->opt[2] = ggml_new_i32(ctx, masked ? 1 : 0);
6813
+
6814
+ return result;
6815
+ }
6816
+
6817
+
6514
6818
  // ggml_map_unary
6515
6819
 
6516
6820
  struct ggml_tensor * ggml_map_unary_impl_f32(
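
ggml_flash_attn_back() returns the gradients of q, k and v packed into a single F32 tensor of shape {D, M+N+M, ne2, ne3}, exactly as the comment block above lays out. Restating that layout as arithmetic (no new information, just the offsets spelled out):

    // with nb0 == sizeof(float):
    //   gradq = result->data                                          (D*N*ne2*ne3 floats)
    //   gradk = result->data + nb0*D*N*ne2*ne3                        (D*M*ne2*ne3 floats)
    //   gradv = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3      (M*D*ne2*ne3 floats, stored transposed like v)
    // total = D*(N + 2*M)*ne2*ne3 floats, which is what ne = {D, M+N+M, ne2, ne3} provides.
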
@@ -6595,6 +6899,50 @@ struct ggml_tensor * ggml_map_binary_inplace_f32(
6595
6899
  return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
6596
6900
  }
6597
6901
 
6902
+ // ggml_cross_entropy_loss
6903
+
6904
+ struct ggml_tensor * ggml_cross_entropy_loss(
6905
+ struct ggml_context * ctx,
6906
+ struct ggml_tensor * a,
6907
+ struct ggml_tensor * b) {
6908
+ GGML_ASSERT(ggml_are_same_shape(a, b));
6909
+ bool is_node = false;
6910
+
6911
+ if (a->grad || b->grad) {
6912
+ is_node = true;
6913
+ }
6914
+
6915
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
6916
+
6917
+ result->op = GGML_OP_CROSS_ENTROPY_LOSS;
6918
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6919
+ result->src0 = a;
6920
+ result->src1 = b;
6921
+
6922
+ return result;
6923
+ }
6924
+
6925
+ // ggml_cross_entropy_loss_back
6926
+
6927
+ struct ggml_tensor * ggml_cross_entropy_loss_back(
6928
+ struct ggml_context * ctx,
6929
+ struct ggml_tensor * a,
6930
+ struct ggml_tensor * b,
6931
+ struct ggml_tensor * c) {
6932
+ GGML_ASSERT(ggml_are_same_shape(a, b));
6933
+ GGML_ASSERT(ggml_is_scalar(c));
6934
+
6935
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
6936
+
6937
+ result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
6938
+ result->grad = NULL;
6939
+ result->src0 = a;
6940
+ result->src1 = b;
6941
+ result->opt[0] = c;
6942
+
6943
+ return result;
6944
+ }
6945
+
6598
6946
  ////////////////////////////////////////////////////////////////////////////////
6599
6947
 
6600
6948
  void ggml_set_param(
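
ggml_cross_entropy_loss() collapses two same-shape tensors into a single-element tensor, which is what makes it usable as a scalar training objective together with the new *_back ops. A minimal usage sketch; the tensor names are invented and the surrounding setup (context creation, thread count, graph execution details) is assumed rather than taken from the package:

    // logits and labels: same-shape F32 tensors previously created in ctx
    struct ggml_tensor * loss = ggml_cross_entropy_loss(ctx, logits, labels);

    struct ggml_cgraph gf = ggml_build_forward(loss);
    ggml_graph_compute(ctx, &gf);

    const float loss_value = ggml_get_f32_1d(loss, 0);   // the scalar result
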
@@ -7584,6 +7932,11 @@ static void ggml_compute_forward_add(
7584
7932
  case GGML_TYPE_Q5_0:
7585
7933
  case GGML_TYPE_Q5_1:
7586
7934
  case GGML_TYPE_Q8_0:
7935
+ case GGML_TYPE_Q2_K:
7936
+ case GGML_TYPE_Q3_K:
7937
+ case GGML_TYPE_Q4_K:
7938
+ case GGML_TYPE_Q5_K:
7939
+ case GGML_TYPE_Q6_K:
7587
7940
  {
7588
7941
  ggml_compute_forward_add_q_f32(params, src0, src1, dst);
7589
7942
  } break;
@@ -7887,6 +8240,11 @@ static void ggml_compute_forward_add1(
7887
8240
  case GGML_TYPE_Q5_1:
7888
8241
  case GGML_TYPE_Q8_0:
7889
8242
  case GGML_TYPE_Q8_1:
8243
+ case GGML_TYPE_Q2_K:
8244
+ case GGML_TYPE_Q3_K:
8245
+ case GGML_TYPE_Q4_K:
8246
+ case GGML_TYPE_Q5_K:
8247
+ case GGML_TYPE_Q6_K:
7890
8248
  {
7891
8249
  ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
7892
8250
  } break;
@@ -8009,6 +8367,11 @@ static void ggml_compute_forward_acc(
8009
8367
  case GGML_TYPE_Q5_1:
8010
8368
  case GGML_TYPE_Q8_0:
8011
8369
  case GGML_TYPE_Q8_1:
8370
+ case GGML_TYPE_Q2_K:
8371
+ case GGML_TYPE_Q3_K:
8372
+ case GGML_TYPE_Q4_K:
8373
+ case GGML_TYPE_Q5_K:
8374
+ case GGML_TYPE_Q6_K:
8012
8375
  default:
8013
8376
  {
8014
8377
  GGML_ASSERT(false);
@@ -8127,10 +8490,10 @@ static void ggml_compute_forward_mul_f32(
8127
8490
  const int ith = params->ith;
8128
8491
  const int nth = params->nth;
8129
8492
 
8130
- #ifdef GGML_USE_CUBLAS
8131
- if (src1->backend == GGML_BACKEND_CUDA) {
8493
+ #ifdef GGML_USE_CLBLAST
8494
+ if (src1->backend == GGML_BACKEND_GPU) {
8132
8495
  if (ith == 0) {
8133
- ggml_cuda_mul(src0, src1, dst);
8496
+ ggml_cl_mul(src0, src1, dst);
8134
8497
  }
8135
8498
  return;
8136
8499
  }
@@ -8730,29 +9093,122 @@ static void ggml_compute_forward_repeat(
8730
9093
  }
8731
9094
  }
8732
9095
 
8733
- // ggml_compute_forward_abs
9096
+ // ggml_compute_forward_repeat_back
8734
9097
 
8735
- static void ggml_compute_forward_abs_f32(
9098
+ static void ggml_compute_forward_repeat_back_f32(
8736
9099
  const struct ggml_compute_params * params,
8737
9100
  const struct ggml_tensor * src0,
8738
9101
  struct ggml_tensor * dst) {
8739
- assert(params->ith == 0);
8740
- assert(ggml_are_same_shape(src0, dst));
9102
+ GGML_ASSERT(params->ith == 0);
9103
+ GGML_ASSERT(ggml_can_repeat(dst, src0));
8741
9104
 
8742
9105
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
8743
9106
  return;
8744
9107
  }
8745
9108
 
8746
- const int n = ggml_nrows(src0);
8747
- const int nc = src0->ne[0];
9109
+ const int64_t ne0 = dst->ne[0];
9110
+ const int64_t ne1 = dst->ne[1];
9111
+ const int64_t ne2 = dst->ne[2];
9112
+ const int64_t ne3 = dst->ne[3];
8748
9113
 
8749
- assert(dst->nb[0] == sizeof(float));
8750
- assert(src0->nb[0] == sizeof(float));
9114
+ const int64_t ne00 = src0->ne[0];
9115
+ const int64_t ne01 = src0->ne[1];
9116
+ const int64_t ne02 = src0->ne[2];
9117
+ const int64_t ne03 = src0->ne[3];
8751
9118
 
8752
- for (int i = 0; i < n; i++) {
8753
- ggml_vec_abs_f32(nc,
8754
- (float *) ((char *) dst->data + i*( dst->nb[1])),
8755
- (float *) ((char *) src0->data + i*(src0->nb[1])));
9119
+ const size_t nb0 = dst->nb[0];
9120
+ const size_t nb1 = dst->nb[1];
9121
+ const size_t nb2 = dst->nb[2];
9122
+ const size_t nb3 = dst->nb[3];
9123
+
9124
+ const size_t nb00 = src0->nb[0];
9125
+ const size_t nb01 = src0->nb[1];
9126
+ const size_t nb02 = src0->nb[2];
9127
+ const size_t nb03 = src0->nb[3];
9128
+
9129
+ // guaranteed to be an integer due to the check in ggml_can_repeat
9130
+ const int nr0 = (int)(ne00/ne0);
9131
+ const int nr1 = (int)(ne01/ne1);
9132
+ const int nr2 = (int)(ne02/ne2);
9133
+ const int nr3 = (int)(ne03/ne3);
9134
+
9135
+ // TODO: support for transposed / permuted tensors
9136
+ GGML_ASSERT(nb0 == sizeof(float));
9137
+ GGML_ASSERT(nb00 == sizeof(float));
9138
+
9139
+ if (ggml_is_contiguous(dst)) {
9140
+ ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
9141
+ } else {
9142
+ for (int k3 = 0; k3 < ne3; k3++) {
9143
+ for (int k2 = 0; k2 < ne2; k2++) {
9144
+ for (int k1 = 0; k1 < ne1; k1++) {
9145
+ ggml_vec_set_f32(ne0,
9146
+ (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
9147
+ 0);
9148
+ }
9149
+ }
9150
+ }
9151
+ }
9152
+
9153
+ // TODO: maybe this is not optimal?
9154
+ for (int i3 = 0; i3 < nr3; i3++) {
9155
+ for (int k3 = 0; k3 < ne3; k3++) {
9156
+ for (int i2 = 0; i2 < nr2; i2++) {
9157
+ for (int k2 = 0; k2 < ne2; k2++) {
9158
+ for (int i1 = 0; i1 < nr1; i1++) {
9159
+ for (int k1 = 0; k1 < ne1; k1++) {
9160
+ for (int i0 = 0; i0 < nr0; i0++) {
9161
+ ggml_vec_acc_f32(ne0,
9162
+ (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1),
9163
+ (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
9164
+ }
9165
+ }
9166
+ }
9167
+ }
9168
+ }
9169
+ }
9170
+ }
9171
+ }
9172
+
9173
+ static void ggml_compute_forward_repeat_back(
9174
+ const struct ggml_compute_params * params,
9175
+ const struct ggml_tensor * src0,
9176
+ struct ggml_tensor * dst) {
9177
+ switch (src0->type) {
9178
+ case GGML_TYPE_F32:
9179
+ {
9180
+ ggml_compute_forward_repeat_back_f32(params, src0, dst);
9181
+ } break;
9182
+ default:
9183
+ {
9184
+ GGML_ASSERT(false);
9185
+ } break;
9186
+ }
9187
+ }
9188
+
9189
+ // ggml_compute_forward_abs
9190
+
9191
+ static void ggml_compute_forward_abs_f32(
9192
+ const struct ggml_compute_params * params,
9193
+ const struct ggml_tensor * src0,
9194
+ struct ggml_tensor * dst) {
9195
+ assert(params->ith == 0);
9196
+ assert(ggml_are_same_shape(src0, dst));
9197
+
9198
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
9199
+ return;
9200
+ }
9201
+
9202
+ const int n = ggml_nrows(src0);
9203
+ const int nc = src0->ne[0];
9204
+
9205
+ assert(dst->nb[0] == sizeof(float));
9206
+ assert(src0->nb[0] == sizeof(float));
9207
+
9208
+ for (int i = 0; i < n; i++) {
9209
+ ggml_vec_abs_f32(nc,
9210
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
9211
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
8756
9212
  }
8757
9213
  }
8758
9214
 
@@ -9245,7 +9701,7 @@ static void ggml_compute_forward_rms_norm_f32(
9245
9701
  sum += (ggml_float)(x[i00] * x[i00]);
9246
9702
  }
9247
9703
 
9248
- float mean = sum/ne00;
9704
+ const float mean = sum/ne00;
9249
9705
 
9250
9706
  float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
9251
9707
 
@@ -9568,14 +10024,7 @@ static void ggml_compute_forward_mul_mat_f32(
9568
10024
  // nb01 >= nb00 - src0 is not transposed
9569
10025
  // compute by src0 rows
9570
10026
 
9571
- #if defined(GGML_USE_CUBLAS)
9572
- if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
9573
- if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
9574
- ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
9575
- }
9576
- return;
9577
- }
9578
- #elif defined(GGML_USE_CLBLAST)
10027
+ #if defined(GGML_USE_CLBLAST)
9579
10028
  if (ggml_cl_can_mul_mat(src0, src1, dst)) {
9580
10029
  if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
9581
10030
  ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
@@ -9740,14 +10189,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
9740
10189
  // nb01 >= nb00 - src0 is not transposed
9741
10190
  // compute by src0 rows
9742
10191
 
9743
- #if defined(GGML_USE_CUBLAS)
9744
- if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
9745
- if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
9746
- ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
9747
- }
9748
- return;
9749
- }
9750
- #elif defined(GGML_USE_CLBLAST)
10192
+ #if defined(GGML_USE_CLBLAST)
9751
10193
  if (ggml_cl_can_mul_mat(src0, src1, dst)) {
9752
10194
  if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
9753
10195
  ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
@@ -9952,14 +10394,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
9952
10394
  // nb01 >= nb00 - src0 is not transposed
9953
10395
  // compute by src0 rows
9954
10396
 
9955
- #if defined(GGML_USE_CUBLAS)
9956
- if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
9957
- if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
9958
- ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
9959
- }
9960
- return;
9961
- }
9962
- #elif defined(GGML_USE_CLBLAST)
10397
+ #if defined(GGML_USE_CLBLAST)
9963
10398
  if (ggml_cl_can_mul_mat(src0, src1, dst)) {
9964
10399
  if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
9965
10400
  ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
@@ -10102,6 +10537,11 @@ static void ggml_compute_forward_mul_mat(
10102
10537
  case GGML_TYPE_Q5_1:
10103
10538
  case GGML_TYPE_Q8_0:
10104
10539
  case GGML_TYPE_Q8_1:
10540
+ case GGML_TYPE_Q2_K:
10541
+ case GGML_TYPE_Q3_K:
10542
+ case GGML_TYPE_Q4_K:
10543
+ case GGML_TYPE_Q5_K:
10544
+ case GGML_TYPE_Q6_K:
10105
10545
  {
10106
10546
  ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
10107
10547
  } break;
@@ -10120,6 +10560,176 @@ static void ggml_compute_forward_mul_mat(
10120
10560
  }
10121
10561
  }
10122
10562
 
10563
+ // ggml_compute_forward_out_prod
10564
+
10565
+
10566
+ static void ggml_compute_forward_out_prod_f32(
10567
+ const struct ggml_compute_params * params,
10568
+ const struct ggml_tensor * src0,
10569
+ const struct ggml_tensor * src1,
10570
+ struct ggml_tensor * dst) {
10571
+ int64_t t0 = ggml_perf_time_us();
10572
+ UNUSED(t0);
10573
+
10574
+ const int64_t ne00 = src0->ne[0];
10575
+ const int64_t ne01 = src0->ne[1];
10576
+ const int64_t ne02 = src0->ne[2];
10577
+ const int64_t ne03 = src0->ne[3];
10578
+
10579
+ const int64_t ne10 = src1->ne[0];
10580
+ //const int64_t ne11 = src1->ne[1];
10581
+ const int64_t ne12 = src1->ne[2];
10582
+ const int64_t ne13 = src1->ne[3];
10583
+
10584
+ const int64_t ne0 = dst->ne[0];
10585
+ const int64_t ne1 = dst->ne[1];
10586
+ const int64_t ne2 = dst->ne[2];
10587
+ const int64_t ne3 = dst->ne[3];
10588
+
10589
+ const int nb00 = src0->nb[0];
10590
+ const int nb01 = src0->nb[1];
10591
+ const int nb02 = src0->nb[2];
10592
+ const int nb03 = src0->nb[3];
10593
+
10594
+ const int nb10 = src1->nb[0];
10595
+ const int nb11 = src1->nb[1];
10596
+ const int nb12 = src1->nb[2];
10597
+ const int nb13 = src1->nb[3];
10598
+
10599
+ const int nb0 = dst->nb[0];
10600
+ const int nb1 = dst->nb[1];
10601
+ const int nb2 = dst->nb[2];
10602
+ const int nb3 = dst->nb[3];
10603
+
10604
+ const int ith = params->ith;
10605
+ const int nth = params->nth;
10606
+
10607
+ GGML_ASSERT(ne02 == ne12);
10608
+ GGML_ASSERT(ne03 == ne13);
10609
+ GGML_ASSERT(ne2 == ne12);
10610
+ GGML_ASSERT(ne3 == ne13);
10611
+
10612
+ // we don't support permuted src0 or src1
10613
+ GGML_ASSERT(nb00 == sizeof(float));
10614
+
10615
+ // dst cannot be transposed or permuted
10616
+ GGML_ASSERT(nb0 == sizeof(float));
10617
+ // GGML_ASSERT(nb0 <= nb1);
10618
+ // GGML_ASSERT(nb1 <= nb2);
10619
+ // GGML_ASSERT(nb2 <= nb3);
10620
+
10621
+ GGML_ASSERT(ne0 == ne00);
10622
+ GGML_ASSERT(ne1 == ne10);
10623
+ GGML_ASSERT(ne2 == ne02);
10624
+ GGML_ASSERT(ne3 == ne03);
10625
+
10626
+ // nb01 >= nb00 - src0 is not transposed
10627
+ // compute by src0 rows
10628
+
10629
+ // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
10630
+ // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
10631
+
10632
+ if (params->type == GGML_TASK_INIT) {
10633
+ ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
10634
+ return;
10635
+ }
10636
+
10637
+ if (params->type == GGML_TASK_FINALIZE) {
10638
+ return;
10639
+ }
10640
+
10641
+ // parallelize by last three dimensions
10642
+
10643
+ // total rows in dst
10644
+ const int64_t nr = ne1*ne2*ne3;
10645
+
10646
+ // rows per thread
10647
+ const int64_t dr = (nr + nth - 1)/nth;
10648
+
10649
+ // row range for this thread
10650
+ const int64_t ir0 = dr*ith;
10651
+ const int64_t ir1 = MIN(ir0 + dr, nr);
10652
+
10653
+ // dst[:,:,:,:] = 0
10654
+ // for i2,i3:
10655
+ // for i1:
10656
+ // for i01:
10657
+ // for i0:
10658
+ // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
10659
+
10660
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
10661
+ // dst indices
10662
+ const int64_t i3 = ir/(ne2*ne1);
10663
+ const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
10664
+ const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
10665
+
10666
+ const int64_t i02 = i2;
10667
+ const int64_t i03 = i3;
10668
+
10669
+ //const int64_t i10 = i1;
10670
+ const int64_t i12 = i2;
10671
+ const int64_t i13 = i3;
10672
+
10673
+ for (int64_t i01 = 0; i01 < ne01; ++i01) {
10674
+ const int64_t i11 = i01;
10675
+
10676
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
10677
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
10678
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
10679
+
10680
+ ggml_vec_mad_f32(ne0, d, s0, *s1);
10681
+ // for (int64_t i0 = 0; i0 < ne0; ++i0) {
10682
+ // d[i0] += s0[i0] * s1[i1];
10683
+ // }
10684
+ }
10685
+ }
10686
+
10687
+ //int64_t t1 = ggml_perf_time_us();
10688
+ //static int64_t acc = 0;
10689
+ //acc += t1 - t0;
10690
+ //if (t1 - t0 > 10) {
10691
+ // printf("\n");
10692
+ // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
10693
+ // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
10694
+ // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
10695
+ // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
10696
+
10697
+ // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
10698
+ //}
10699
+ }
10700
+
10701
+ static void ggml_compute_forward_out_prod(
10702
+ const struct ggml_compute_params * params,
10703
+ const struct ggml_tensor * src0,
10704
+ const struct ggml_tensor * src1,
10705
+ struct ggml_tensor * dst) {
10706
+ switch (src0->type) {
10707
+ case GGML_TYPE_Q4_0:
10708
+ case GGML_TYPE_Q4_1:
10709
+ case GGML_TYPE_Q5_0:
10710
+ case GGML_TYPE_Q5_1:
10711
+ case GGML_TYPE_Q8_0:
10712
+ case GGML_TYPE_Q8_1:
10713
+ {
10714
+ GGML_ASSERT(false); // todo
10715
+ // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
10716
+ } break;
10717
+ case GGML_TYPE_F16:
10718
+ {
10719
+ GGML_ASSERT(false); // todo
10720
+ // ggml_compute_forward_out_prod_f16_f32(params, src0, src1, dst);
10721
+ } break;
10722
+ case GGML_TYPE_F32:
10723
+ {
10724
+ ggml_compute_forward_out_prod_f32(params, src0, src1, dst);
10725
+ } break;
10726
+ default:
10727
+ {
10728
+ GGML_ASSERT(false);
10729
+ } break;
10730
+ }
10731
+ }
10732
+
10123
10733
  // ggml_compute_forward_scale
10124
10734
 
10125
10735
  static void ggml_compute_forward_scale_f32(
@@ -10285,6 +10895,11 @@ static void ggml_compute_forward_set(
10285
10895
  case GGML_TYPE_Q5_1:
10286
10896
  case GGML_TYPE_Q8_0:
10287
10897
  case GGML_TYPE_Q8_1:
10898
+ case GGML_TYPE_Q2_K:
10899
+ case GGML_TYPE_Q3_K:
10900
+ case GGML_TYPE_Q4_K:
10901
+ case GGML_TYPE_Q5_K:
10902
+ case GGML_TYPE_Q6_K:
10288
10903
  default:
10289
10904
  {
10290
10905
  GGML_ASSERT(false);
@@ -10450,6 +11065,11 @@ static void ggml_compute_forward_get_rows(
10450
11065
  case GGML_TYPE_Q5_1:
10451
11066
  case GGML_TYPE_Q8_0:
10452
11067
  case GGML_TYPE_Q8_1:
11068
+ case GGML_TYPE_Q2_K:
11069
+ case GGML_TYPE_Q3_K:
11070
+ case GGML_TYPE_Q4_K:
11071
+ case GGML_TYPE_Q5_K:
11072
+ case GGML_TYPE_Q6_K:
10453
11073
  {
10454
11074
  ggml_compute_forward_get_rows_q(params, src0, src1, dst);
10455
11075
  } break;
@@ -10532,7 +11152,11 @@ static void ggml_compute_forward_get_rows_back_f32(
10532
11152
  GGML_ASSERT(ggml_is_contiguous(opt0));
10533
11153
  GGML_ASSERT(ggml_is_contiguous(dst));
10534
11154
 
10535
- ggml_compute_forward_dup_same_cont(params, opt0, dst);
11155
+ // ggml_compute_forward_dup_same_cont(params, opt0, dst);
11156
+
11157
+ if (params->type == GGML_TASK_INIT) {
11158
+ memset(dst->data, 0, ggml_nbytes(dst));
11159
+ }
10536
11160
 
10537
11161
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
10538
11162
  return;
@@ -10676,8 +11300,8 @@ static void ggml_compute_forward_diag_mask_f32(
10676
11300
  const struct ggml_tensor * src1,
10677
11301
  struct ggml_tensor * dst,
10678
11302
  const float value) {
10679
- assert(src1->type == GGML_TYPE_I32);
10680
- assert(ggml_nelements(src1) == 2);
11303
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
11304
+ GGML_ASSERT(ggml_nelements(src1) == 2);
10681
11305
 
10682
11306
  const int ith = params->ith;
10683
11307
  const int nth = params->nth;
@@ -10685,7 +11309,7 @@ static void ggml_compute_forward_diag_mask_f32(
10685
11309
  const int n_past = ((int32_t *) src1->data)[0];
10686
11310
  const bool inplace = (bool)((int32_t *) src1->data)[1];
10687
11311
 
10688
- assert(n_past >= 0);
11312
+ GGML_ASSERT(n_past >= 0);
10689
11313
 
10690
11314
  if (!inplace && (params->type == GGML_TASK_INIT)) {
10691
11315
  // memcpy needs to be synchronized across threads to avoid race conditions.
@@ -10709,8 +11333,8 @@ static void ggml_compute_forward_diag_mask_f32(
10709
11333
  const int nr = src0->ne[1];
10710
11334
  const int nz = n/nr;
10711
11335
 
10712
- assert( dst->nb[0] == sizeof(float));
10713
- assert(src0->nb[0] == sizeof(float));
11336
+ GGML_ASSERT( dst->nb[0] == sizeof(float));
11337
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
10714
11338
 
10715
11339
  for (int k = 0; k < nz; k++) {
10716
11340
  for (int j = ith; j < nr; j += nth) {
@@ -10846,42 +11470,137 @@ static void ggml_compute_forward_soft_max(
10846
11470
  }
10847
11471
  }
10848
11472
 
10849
- // ggml_compute_forward_alibi
11473
+ // ggml_compute_forward_soft_max_back
10850
11474
 
10851
- static void ggml_compute_forward_alibi_f32(
11475
+ static void ggml_compute_forward_soft_max_back_f32(
10852
11476
  const struct ggml_compute_params * params,
10853
11477
  const struct ggml_tensor * src0,
10854
11478
  const struct ggml_tensor * src1,
10855
11479
  struct ggml_tensor * dst) {
10856
- assert(params->ith == 0);
10857
- assert(src1->type == GGML_TYPE_I32);
10858
- assert(ggml_nelements(src1) == 3);
11480
+ GGML_ASSERT(ggml_is_contiguous(src0));
11481
+ GGML_ASSERT(ggml_is_contiguous(src1));
11482
+ GGML_ASSERT(ggml_is_contiguous(dst));
11483
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
11484
+ GGML_ASSERT(ggml_are_same_shape(src1, dst));
10859
11485
 
10860
11486
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
10861
11487
  return;
10862
11488
  }
10863
11489
 
10864
- const int n_past = ((int32_t *) src1->data)[0];
10865
- const int n_head = ((int32_t *) src1->data)[1];
10866
- const float max_bias = ((float *) src1->data)[2];
11490
+ // TODO: handle transposed/permuted matrices
10867
11491
 
10868
- assert(n_past >= 0);
11492
+ const int ith = params->ith;
11493
+ const int nth = params->nth;
10869
11494
 
10870
- const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
10871
- const int ne1 = src0->ne[1]; // seq_len_without_past
10872
- //const int ne2 = src0->ne[2]; // n_head -> this is k
10873
- //const int ne3 = src0->ne[3]; // 1 -> bsz
11495
+ const int nc = src0->ne[0];
11496
+ const int nr = ggml_nrows(src0);
10874
11497
 
10875
- const int n = ggml_nrows(src0);
10876
- const int ne2_ne3 = n/ne1; // ne2*ne3
11498
+ // rows per thread
11499
+ const int dr = (nr + nth - 1)/nth;
10877
11500
 
10878
- const int nb0 = src0->nb[0];
10879
- const int nb1 = src0->nb[1];
10880
- const int nb2 = src0->nb[2];
10881
- //const int nb3 = src0->nb[3];
11501
+ // row range for this thread
11502
+ const int ir0 = dr*ith;
11503
+ const int ir1 = MIN(ir0 + dr, nr);
10882
11504
 
10883
- assert(nb0 == sizeof(float));
10884
- assert(ne1 + n_past == ne0); (void) n_past;
11505
+ for (int i1 = ir0; i1 < ir1; i1++) {
11506
+ float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
11507
+ float *y = (float *)((char *) src1->data + i1*src1->nb[1]);
11508
+ float *dx = (float *)((char *) dst->data + i1*dst->nb[1]);
11509
+
11510
+ #ifndef NDEBUG
11511
+ for (int i = 0; i < nc; ++i) {
11512
+ //printf("p[%d] = %f\n", i, p[i]);
11513
+ assert(!isnan(dy[i]));
11514
+ assert(!isnan(y[i]));
11515
+ }
11516
+ #endif
11517
+ // Jii = yi - yi*yi
11518
+ // Jij = -yi*yj
11519
+ // J = diag(y)-y.T*y
11520
+ // dx = J * dy
11521
+ // dxk = sum_i(Jki * dyi)
11522
+ // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
11523
+ // dxk = sum_i(-yk*yi * dyi) + yk*dyk
11524
+ // dxk = -yk * sum_i(yi * dyi) + yk*dyk
11525
+ // dxk = -yk * dot(y, dy) + yk*dyk
11526
+ // dxk = yk * (- dot(y, dy) + dyk)
11527
+ // dxk = yk * (dyk - dot(y, dy))
11528
+ //
11529
+ // post-order:
11530
+ // dot_y_dy := dot(y, dy)
11531
+ // dx := dy
11532
+ // dx := dx - dot_y_dy
11533
+ // dx := dx * y
11534
+
11535
+ // linear runtime, no additional memory
11536
+ float dot_y_dy = 0;
11537
+ ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy);
11538
+ ggml_vec_cpy_f32 (nc, dx, dy);
11539
+ ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
11540
+ ggml_vec_mul_f32 (nc, dx, dx, y);
11541
+
11542
+ #ifndef NDEBUG
11543
+ for (int i = 0; i < nc; ++i) {
11544
+ assert(!isnan(dx[i]));
11545
+ assert(!isinf(dx[i]));
11546
+ }
11547
+ #endif
11548
+ }
11549
+ }
11550
+
11551
+ static void ggml_compute_forward_soft_max_back(
11552
+ const struct ggml_compute_params * params,
11553
+ const struct ggml_tensor * src0,
11554
+ const struct ggml_tensor * src1,
11555
+ struct ggml_tensor * dst) {
11556
+ switch (src0->type) {
11557
+ case GGML_TYPE_F32:
11558
+ {
11559
+ ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst);
11560
+ } break;
11561
+ default:
11562
+ {
11563
+ GGML_ASSERT(false);
11564
+ } break;
11565
+ }
11566
+ }
11567
+
11568
+ // ggml_compute_forward_alibi
11569
+
11570
+ static void ggml_compute_forward_alibi_f32(
11571
+ const struct ggml_compute_params * params,
11572
+ const struct ggml_tensor * src0,
11573
+ const struct ggml_tensor * src1,
11574
+ struct ggml_tensor * dst) {
11575
+ assert(params->ith == 0);
11576
+ assert(src1->type == GGML_TYPE_I32);
11577
+ assert(ggml_nelements(src1) == 3);
11578
+
11579
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
11580
+ return;
11581
+ }
11582
+
11583
+ const int n_past = ((int32_t *) src1->data)[0];
11584
+ const int n_head = ((int32_t *) src1->data)[1];
11585
+ const float max_bias = ((float *) src1->data)[2];
11586
+
11587
+ assert(n_past >= 0);
11588
+
11589
+ const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
11590
+ const int ne1 = src0->ne[1]; // seq_len_without_past
11591
+ //const int ne2 = src0->ne[2]; // n_head -> this is k
11592
+ //const int ne3 = src0->ne[3]; // 1 -> bsz
11593
+
11594
+ const int n = ggml_nrows(src0);
11595
+ const int ne2_ne3 = n/ne1; // ne2*ne3
11596
+
11597
+ const int nb0 = src0->nb[0];
11598
+ const int nb1 = src0->nb[1];
11599
+ const int nb2 = src0->nb[2];
11600
+ //const int nb3 = src0->nb[3];
11601
+
11602
+ assert(nb0 == sizeof(float));
11603
+ assert(ne1 + n_past == ne0); (void) n_past;
10885
11604
 
10886
11605
  // add alibi to src0 (KQ_scaled)
10887
11606
  const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@@ -10996,6 +11715,12 @@ static void ggml_compute_forward_alibi(
10996
11715
  case GGML_TYPE_Q5_1:
10997
11716
  case GGML_TYPE_Q8_0:
10998
11717
  case GGML_TYPE_Q8_1:
11718
+ case GGML_TYPE_Q2_K:
11719
+ case GGML_TYPE_Q3_K:
11720
+ case GGML_TYPE_Q4_K:
11721
+ case GGML_TYPE_Q5_K:
11722
+ case GGML_TYPE_Q6_K:
11723
+ case GGML_TYPE_Q8_K:
10999
11724
  case GGML_TYPE_I8:
11000
11725
  case GGML_TYPE_I16:
11001
11726
  case GGML_TYPE_I32:
@@ -11067,6 +11792,12 @@ static void ggml_compute_forward_clamp(
11067
11792
  case GGML_TYPE_Q5_1:
11068
11793
  case GGML_TYPE_Q8_0:
11069
11794
  case GGML_TYPE_Q8_1:
11795
+ case GGML_TYPE_Q2_K:
11796
+ case GGML_TYPE_Q3_K:
11797
+ case GGML_TYPE_Q4_K:
11798
+ case GGML_TYPE_Q5_K:
11799
+ case GGML_TYPE_Q6_K:
11800
+ case GGML_TYPE_Q8_K:
11070
11801
  case GGML_TYPE_I8:
11071
11802
  case GGML_TYPE_I16:
11072
11803
  case GGML_TYPE_I32:
@@ -11156,7 +11887,7 @@ static void ggml_compute_forward_rope_f32(
11156
11887
  theta *= theta_scale;
11157
11888
 
11158
11889
  const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11159
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11890
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11160
11891
 
11161
11892
  const float x0 = src[0];
11162
11893
  const float x1 = src[1];
@@ -11177,7 +11908,7 @@ static void ggml_compute_forward_rope_f32(
11177
11908
  const int64_t i0 = ib*n_dims + ic/2;
11178
11909
 
11179
11910
  const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11180
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11911
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11181
11912
 
11182
11913
  const float x0 = src[0];
11183
11914
  const float x1 = src[n_dims/2];
@@ -12787,6 +13518,414 @@ static void ggml_compute_forward_flash_ff(
12787
13518
  }
12788
13519
  }
12789
13520
 
13521
+ // ggml_compute_forward_flash_attn_back
13522
+
13523
+ static void ggml_compute_forward_flash_attn_back_f32(
13524
+ const struct ggml_compute_params * params,
13525
+ const struct ggml_tensor * q,
13526
+ const struct ggml_tensor * k,
13527
+ const struct ggml_tensor * v,
13528
+ const struct ggml_tensor * d,
13529
+ const bool masked,
13530
+ struct ggml_tensor * dst) {
13531
+ int64_t t0 = ggml_perf_time_us();
13532
+ UNUSED(t0);
13533
+
13534
+ const int64_t neq0 = q->ne[0];
13535
+ const int64_t neq1 = q->ne[1];
13536
+ const int64_t neq2 = q->ne[2];
13537
+ const int64_t neq3 = q->ne[3];
13538
+
13539
+ const int64_t nek0 = k->ne[0];
13540
+ const int64_t nek1 = k->ne[1];
13541
+ //const int64_t nek2 = k->ne[2];
13542
+ //const int64_t nek3 = k->ne[3];
13543
+
13544
+ const int64_t nev0 = v->ne[0];
13545
+ const int64_t nev1 = v->ne[1];
13546
+ //const int64_t nev2 = v->ne[2];
13547
+ //const int64_t nev3 = v->ne[3];
13548
+
13549
+ const int64_t ned0 = d->ne[0];
13550
+ const int64_t ned1 = d->ne[1];
13551
+ //const int64_t ned2 = d->ne[2];
13552
+ //const int64_t ned3 = d->ne[3];
13553
+
13554
+ const int64_t ne0 = dst->ne[0];
13555
+ const int64_t ne1 = dst->ne[1];
13556
+ const int64_t ne2 = dst->ne[2];
13557
+ const int64_t ne3 = dst->ne[3];
13558
+
13559
+ const int nbk0 = k->nb[0];
13560
+ const int nbk1 = k->nb[1];
13561
+ const int nbk2 = k->nb[2];
13562
+ const int nbk3 = k->nb[3];
13563
+
13564
+ const int nbq0 = q->nb[0];
13565
+ const int nbq1 = q->nb[1];
13566
+ const int nbq2 = q->nb[2];
13567
+ const int nbq3 = q->nb[3];
13568
+
13569
+ const int nbv0 = v->nb[0];
13570
+ const int nbv1 = v->nb[1];
13571
+ const int nbv2 = v->nb[2];
13572
+ const int nbv3 = v->nb[3];
13573
+
13574
+ const int nbd0 = d->nb[0];
13575
+ const int nbd1 = d->nb[1];
13576
+ const int nbd2 = d->nb[2];
13577
+ const int nbd3 = d->nb[3];
13578
+
13579
+ const int nb0 = dst->nb[0];
13580
+ const int nb1 = dst->nb[1];
13581
+ const int nb2 = dst->nb[2];
13582
+ const int nb3 = dst->nb[3];
13583
+
13584
+ const int ith = params->ith;
13585
+ const int nth = params->nth;
13586
+
13587
+ const int64_t D = neq0;
13588
+ const int64_t N = neq1;
13589
+ const int64_t P = nek1 - N;
13590
+ const int64_t M = P + N;
13591
+
13592
+ const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
13593
+ const int mxDM = MAX(D, Mup);
13594
+
13595
+ // GGML_ASSERT(ne0 == D);
13596
+ // GGML_ASSERT(ne1 == N);
13597
+ GGML_ASSERT(P >= 0);
13598
+
13599
+ GGML_ASSERT(nbq0 == sizeof(float));
13600
+ GGML_ASSERT(nbk0 == sizeof(float));
13601
+ GGML_ASSERT(nbv0 == sizeof(float));
13602
+
13603
+ GGML_ASSERT(neq0 == D);
13604
+ GGML_ASSERT(nek0 == D);
13605
+ GGML_ASSERT(nev1 == D);
13606
+ GGML_ASSERT(ned0 == D);
13607
+
13608
+ GGML_ASSERT(neq1 == N);
13609
+ GGML_ASSERT(nek1 == N + P);
13610
+ GGML_ASSERT(nev1 == D);
13611
+ GGML_ASSERT(ned1 == N);
13612
+
13613
+ // dst cannot be transposed or permuted
13614
+ GGML_ASSERT(nb0 == sizeof(float));
13615
+ GGML_ASSERT(nb0 <= nb1);
13616
+ GGML_ASSERT(nb1 <= nb2);
13617
+ GGML_ASSERT(nb2 <= nb3);
13618
+
13619
+ if (params->type == GGML_TASK_INIT) {
13620
+ if (ith == 0) {
13621
+ memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
13622
+ }
13623
+ return;
13624
+ }
13625
+
13626
+ if (params->type == GGML_TASK_FINALIZE) {
13627
+ return;
13628
+ }
13629
+
13630
+ // parallelize by q rows using ggml_vec_dot_f32
13631
+
13632
+ // total rows in q
13633
+ const int nr = neq2*neq3;
13634
+
13635
+ // rows per thread
13636
+ const int dr = (nr + nth - 1)/nth;
13637
+
13638
+ // row range for this thread
13639
+ const int ir0 = dr*ith;
13640
+ const int ir1 = MIN(ir0 + dr, nr);
13641
+
13642
+ const float scale = 1.0f/sqrtf(D);
13643
+
13644
+ //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
13645
+
13646
+ for (int ir = ir0; ir < ir1; ++ir) {
13647
+ // q indices
13648
+ const int iq3 = ir/(neq2);
13649
+ const int iq2 = ir - iq3*neq2;
13650
+ for ( int iq1 = 0; iq1 < neq1; ++iq1) {
13651
+
13652
+
13653
+ // not sure about CACHE_LINE_SIZE_F32..
13654
+ // - maybe it should not be multiplied by 2, and should be left out of the 1*(..) offset used for SM?
13655
+ float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
13656
+ float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
13657
+
13658
+ for (int i = M; i < Mup; ++i) {
13659
+ S[i] = -INFINITY;
13660
+ }
13661
+
13662
+ for (int64_t ic = 0; ic < nek1; ++ic) {
13663
+ // k indices
13664
+ const int ik3 = iq3;
13665
+ const int ik2 = iq2;
13666
+ const int ik1 = ic;
13667
+
13668
+ // S indices
13669
+ const int i1 = ik1;
13670
+
13671
+ ggml_vec_dot_f32(neq0,
13672
+ S + i1,
13673
+ (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
13674
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
13675
+ }
13676
+
13677
+ // scale
13678
+ ggml_vec_scale_f32(nek1, S, scale);
13679
+
13680
+ if (masked) {
13681
+ for (int64_t i = P; i < M; i++) {
13682
+ if (i > P + iq1) {
13683
+ S[i] = -INFINITY;
13684
+ }
13685
+ }
13686
+ }
13687
+
13688
+ // softmax
13689
+ {
13690
+ float max = -INFINITY;
13691
+ ggml_vec_max_f32(M, &max, S);
13692
+
13693
+ ggml_float sum = 0.0;
13694
+ {
13695
+ #ifdef GGML_SOFT_MAX_ACCELERATE
13696
+ max = -max;
13697
+ vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
13698
+ vvexpf(SM, SM, &Mup);
13699
+ ggml_vec_sum_f32(Mup, &sum, SM);
13700
+ #else
13701
+ uint16_t scvt[GGML_SOFT_MAX_UNROLL];
13702
+ ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
13703
+
13704
+ for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
13705
+ float * SR = S + i;
13706
+ float * SW = SM + i;
13707
+
13708
+ for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
13709
+ if (SR[j] == -INFINITY) {
13710
+ SW[j] = 0.0f;
13711
+ } else {
13712
+ ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
13713
+ memcpy(&scvt[j], &s, sizeof(uint16_t));
13714
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
13715
+ sump[j] += (ggml_float)val;
13716
+ SW[j] = val;
13717
+ }
13718
+ }
13719
+ }
13720
+
13721
+ for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
13722
+ sum += sump[i];
13723
+ }
13724
+ #endif
13725
+ }
13726
+
13727
+ assert(sum > 0.0);
13728
+
13729
+ sum = 1.0/sum;
13730
+ ggml_vec_scale_f32(M, SM, sum);
13731
+
13732
+ }
13733
+
13734
+ // step-by-step explanation
13735
+ {
13736
+ // forward-process shape grads from backward process
13737
+ // parallel_for iq2,iq3:
13738
+ // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur]
13739
+ // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
13740
+ // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur]
13741
+ // for iq1:
13742
+ // kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
13743
+ // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
13744
+ // vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
13745
+ // S0 = -Inf [D,1,1,1]
13746
+ // ~S1[i] = dot(kcur[:D,i], qcur)
13747
+ // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
13748
+ // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
13749
+ // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
13750
+ // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
13751
+ // ~S5[i] = dot(vcur[:,i], S4)
13752
+ // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
13753
+ // ~dst[i,iq1,iq2,iq3] = S5[i] ^
13754
+ // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
13755
+ // dst backward-/ grad[dst] = d
13756
+ //
13757
+ // output gradients with their dependencies:
13758
+ //
13759
+ // grad[kcur] = grad[S1].T @ qcur
13760
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
13761
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
13762
+ // grad[S4] = grad[S5] @ vcur
13763
+ // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
13764
+ // grad[qcur] = grad[S1] @ kcur
13765
+ // grad[vcur] = grad[S5].T @ S4
13766
+ // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
13767
+ //
13768
+ // in post-order:
13769
+ //
13770
+ // S1 = qcur @ kcur.T
13771
+ // S2 = S1 * scale
13772
+ // S3 = diag_mask_inf(S2, P)
13773
+ // S4 = softmax(S3)
13774
+ // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
13775
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
13776
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
13777
+ // grad[qcur] = grad[S1] @ kcur
13778
+ // grad[kcur] = grad[S1].T @ qcur
13779
+ // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
13780
+ //
13781
+ // using less variables (SM=S4):
13782
+ //
13783
+ // S = diag_mask_inf(qcur @ kcur.T * scale, P)
13784
+ // SM = softmax(S)
13785
+ // S = d[:D,iq1,iq2,iq3] @ vcur
13786
+ // dot_SM_gradSM = dot(SM, S)
13787
+ // S = SM * (S - dot(SM, S))
13788
+ // S = diag_mask_zero(S, P) * scale
13789
+ //
13790
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
13791
+ // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
13792
+ // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
13793
+ }
13794
+
13795
+ // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
13796
+ // S = d[:D,iq1,iq2,iq3] @ vcur
13797
+ // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
13798
+ ggml_vec_set_f32(M, S, 0);
13799
+ for (int64_t ic = 0; ic < D; ++ic) {
13800
+ // dst indices
13801
+ const int i1 = iq1;
13802
+ const int i2 = iq2;
13803
+ const int i3 = iq3;
13804
+
13805
+ ggml_vec_mad_f32(M,
13806
+ S,
13807
+ (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
13808
+ *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
13809
+ }
13810
+
13811
+ // S = SM * (S - dot(SM, S))
13812
+ float dot_SM_gradSM = 0;
13813
+ ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
13814
+ ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
13815
+ ggml_vec_mul_f32 (M, S, S, SM);
13816
+
13817
+ // S = diag_mask_zero(S, P) * scale
13818
+ if (masked) {
13819
+ // for (int64_t i = P + iq1 + 1; i < M; i++) {
13820
+ // S[i] = 0;
13821
+ // }
13822
+ for (int64_t i = P; i < M; i++) {
13823
+ if (i > P + iq1) {
13824
+ S[i] = 0;
13825
+ }
13826
+ }
13827
+ }
13828
+ ggml_vec_scale_f32(M, S, scale);
13829
+
13830
+ void * grad_q = (char *) dst->data;
13831
+ void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
13832
+ void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
13833
+
13834
+ const size_t nbgq1 = nb0*neq0;
13835
+ const size_t nbgq2 = nb0*neq0*neq1;
13836
+ const size_t nbgq3 = nb0*neq0*neq1*neq2;
13837
+
13838
+ const size_t nbgk1 = nb0*nek0;
13839
+ const size_t nbgk2 = nb0*nek0*nek1;
13840
+ const size_t nbgk3 = nb0*nek0*nek1*neq2;
13841
+
13842
+ const size_t nbgv1 = nb0*nev0;
13843
+ const size_t nbgv2 = nb0*nev0*nev1;
13844
+ const size_t nbgv3 = nb0*nev0*nev1*neq2;
13845
+
13846
+ // S shape [M,1]
13847
+ // SM shape [M,1]
13848
+ // kcur shape [D,M]
13849
+ // qcur shape [D,1]
13850
+ // vcur shape [M,D]
13851
+ //
13852
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
13853
+ // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
13854
+ // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
13855
+ //
13856
+ //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
13857
+ //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
13858
+ for (int64_t ic = 0; ic < M; ++ic) {
13859
+ // dst indices
13860
+ const int i1 = iq1;
13861
+ const int i2 = iq2;
13862
+ const int i3 = iq3;
13863
+
13864
+ ggml_vec_mad_f32(D,
13865
+ (float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
13866
+ (float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)),
13867
+ S[ic]);
13868
+ }
13869
+
13870
+ // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
13871
+ // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
13872
+ // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
13873
+ for (int64_t ic = 0; ic < M; ++ic) {
13874
+ // dst indices
13875
+ const int i1 = iq1;
13876
+ const int i2 = iq2;
13877
+ const int i3 = iq3;
13878
+
13879
+ // ggml_vec_set_f32(D,
13880
+ // (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
13881
+ // 0);
13882
+ ggml_vec_mad_f32(D,
13883
+ (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
13884
+ (float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)),
13885
+ S[ic]);
13886
+ }
13887
+
13888
+ // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
13889
+ // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
13890
+ // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M]
13891
+ for (int64_t ic = 0; ic < D; ++ic) {
13892
+ // dst indices
13893
+ const int i1 = iq1;
13894
+ const int i2 = iq2;
13895
+ const int i3 = iq3;
13896
+
13897
+ // ggml_vec_set_f32(M,
13898
+ // (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
13899
+ // 0);
13900
+ ggml_vec_mad_f32(M,
13901
+ (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
13902
+ SM,
13903
+ *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
13904
+ }
13905
+ }
13906
+ }
13907
+ }
13908
+
13909
+ static void ggml_compute_forward_flash_attn_back(
13910
+ const struct ggml_compute_params * params,
13911
+ const struct ggml_tensor * q,
13912
+ const struct ggml_tensor * k,
13913
+ const struct ggml_tensor * v,
13914
+ const struct ggml_tensor * d,
13915
+ const bool masked,
13916
+ struct ggml_tensor * dst) {
13917
+ switch (q->type) {
13918
+ case GGML_TYPE_F32:
13919
+ {
13920
+ ggml_compute_forward_flash_attn_back_f32(params, q, k, v, d, masked, dst);
13921
+ } break;
13922
+ default:
13923
+ {
13924
+ GGML_ASSERT(false);
13925
+ } break;
13926
+ }
13927
+ }
13928
+
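Condensing the long in-code derivation above: per query row, writing s for the masked, scaled softmax row (SM in the code), o for the attention output, and K, V for the key/value matrices (ggml stores v transposed, which only changes the indexing, not the math), the kernel accumulates

\[
\nabla s = \nabla o \, V^{\top}, \qquad
g = \mathrm{scale}\cdot\mathrm{mask}_0\!\big(s \odot (\nabla s - \langle s, \nabla s\rangle)\big),
\]
\[
\nabla q \mathrel{{+}{=}} g\,K, \qquad
\nabla K \mathrel{{+}{=}} g^{\top} q, \qquad
\nabla V \mathrel{{+}{=}} s^{\top}\,\nabla o .
\]

In the loop, S first holds the softmax gradient, is overwritten in place with g, and the three ggml_vec_mad_f32 loops perform the accumulations into the packed grad_q/grad_k/grad_v regions of dst.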
12790
13929
  // ggml_compute_forward_map_unary
12791
13930
 
12792
13931
  static void ggml_compute_forward_map_unary_f32(
@@ -12849,29 +13988,308 @@ static void ggml_compute_forward_map_binary_f32(
12849
13988
  const int n = ggml_nrows(src0);
12850
13989
  const int nc = src0->ne[0];
12851
13990
 
12852
- assert( dst->nb[0] == sizeof(float));
12853
- assert(src0->nb[0] == sizeof(float));
12854
- assert(src1->nb[0] == sizeof(float));
13991
+ assert( dst->nb[0] == sizeof(float));
13992
+ assert(src0->nb[0] == sizeof(float));
13993
+ assert(src1->nb[0] == sizeof(float));
13994
+
13995
+ for (int i = 0; i < n; i++) {
13996
+ fun(nc,
13997
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
13998
+ (float *) ((char *) src0->data + i*(src0->nb[1])),
13999
+ (float *) ((char *) src1->data + i*(src1->nb[1])));
14000
+ }
14001
+ }
14002
+
14003
+
14004
+ static void ggml_compute_forward_map_binary(
14005
+ const struct ggml_compute_params * params,
14006
+ const struct ggml_tensor * src0,
14007
+ const struct ggml_tensor * src1,
14008
+ struct ggml_tensor * dst,
14009
+ const ggml_binary_op_f32_t fun) {
14010
+ switch (src0->type) {
14011
+ case GGML_TYPE_F32:
14012
+ {
14013
+ ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
14014
+ } break;
14015
+ default:
14016
+ {
14017
+ GGML_ASSERT(false);
14018
+ } break;
14019
+ }
14020
+ }
14021
+
14022
+ // ggml_compute_forward_cross_entropy_loss
14023
+
14024
+ static void ggml_compute_forward_cross_entropy_loss_f32(
14025
+ const struct ggml_compute_params * params,
14026
+ const struct ggml_tensor * src0,
14027
+ const struct ggml_tensor * src1,
14028
+ struct ggml_tensor * dst) {
14029
+ GGML_ASSERT(ggml_is_contiguous(src0));
14030
+ GGML_ASSERT(ggml_is_contiguous(src1));
14031
+ GGML_ASSERT(ggml_is_scalar(dst));
14032
+ GGML_ASSERT(ggml_are_same_shape(src0, src1));
14033
+
14034
+ const int ith = params->ith;
14035
+ const int nth = params->nth;
14036
+
14037
+ float * sums = (float *) params->wdata;
14038
+
14039
+ // TODO: handle transposed/permuted matrices
14040
+ const int nc = src0->ne[0];
14041
+ const int nr = ggml_nrows(src0);
14042
+
14043
+ if (params->type == GGML_TASK_INIT) {
14044
+ if (ith == 0) {
14045
+ memset(sums, 0, sizeof(float) * (nth + nth * nc));
14046
+ }
14047
+ return;
14048
+ }
14049
+
14050
+ if (params->type == GGML_TASK_FINALIZE) {
14051
+ if (ith == 0) {
14052
+ float * dp = (float *) dst->data;
14053
+ ggml_vec_sum_f32(nth, dp, sums);
14054
+ dp[0] *= -1.0f;
14055
+ }
14056
+ return;
14057
+ }
14058
+
14059
+ const double eps = 1e-9;
14060
+
14061
+ // rows per thread
14062
+ const int dr = (nr + nth - 1)/nth;
14063
+
14064
+ // row range for this thread
14065
+ const int ir0 = dr*ith;
14066
+ const int ir1 = MIN(ir0 + dr, nr);
14067
+
14068
+ for (int i1 = ir0; i1 < ir1; i1++) {
14069
+ float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
14070
+ float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
14071
+ float * st = (float *) params->wdata + nth + ith*nc;
14072
+
14073
+ #ifndef NDEBUG
14074
+ for (int i = 0; i < nc; ++i) {
14075
+ //printf("p[%d] = %f\n", i, p[i]);
14076
+ assert(!isnan(s0[i]));
14077
+ assert(!isnan(s1[i]));
14078
+ }
14079
+ #endif
14080
+ // soft_max
14081
+ ggml_float sum = 0.0;
14082
+ {
14083
+ float max = -INFINITY;
14084
+ ggml_vec_max_f32(nc, &max, s0);
14085
+
14086
+ uint16_t scvt;
14087
+ for (int i = 0; i < nc; i++) {
14088
+ if (s0[i] == -INFINITY) {
14089
+ st[i] = 0.0f;
14090
+ } else {
14091
+ // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
14092
+ ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
14093
+ memcpy(&scvt, &s, sizeof(scvt));
14094
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
14095
+ sum += (ggml_float)val;
14096
+ st[i] = val;
14097
+ }
14098
+ }
14099
+
14100
+ assert(sum > 0.0);
14101
+ // sum = 1.0/sum;
14102
+ }
14103
+ // avoid log(0) by rescaling from [0..1] to [eps..1]
14104
+ sum = (1.0 - eps) / sum;
14105
+ ggml_vec_scale_f32(nc, st, sum);
14106
+ ggml_vec_add1_f32(nc, st, st, eps);
14107
+ ggml_vec_log_f32(nc, st, st);
14108
+ ggml_vec_mul_f32(nc, st, st, s1);
14109
+
14110
+ ggml_vec_sum_f32(nc, sums + ith, st);
14111
+
14112
+ #ifndef NDEBUG
14113
+ for (int i = 0; i < nc; ++i) {
14114
+ assert(!isnan(st[i]));
14115
+ assert(!isinf(st[i]));
14116
+ }
14117
+ #endif
14118
+ }
14119
+
14120
+ }
14121
+
14122
+ static void ggml_compute_forward_cross_entropy_loss(
14123
+ const struct ggml_compute_params * params,
14124
+ const struct ggml_tensor * src0,
14125
+ const struct ggml_tensor * src1,
14126
+ struct ggml_tensor * dst) {
14127
+ switch (src0->type) {
14128
+ case GGML_TYPE_F32:
14129
+ {
14130
+ ggml_compute_forward_cross_entropy_loss_f32(params, src0, src1, dst);
14131
+ } break;
14132
+ default:
14133
+ {
14134
+ GGML_ASSERT(false);
14135
+ } break;
14136
+ }
14137
+ }
14138
+
14139
+ // ggml_compute_forward_cross_entropy_loss_back
14140
+
14141
+ static void ggml_compute_forward_cross_entropy_loss_back_f32(
14142
+ const struct ggml_compute_params * params,
14143
+ const struct ggml_tensor * src0,
14144
+ const struct ggml_tensor * src1,
14145
+ const struct ggml_tensor * opt0,
14146
+ struct ggml_tensor * dst) {
14147
+ GGML_ASSERT(ggml_is_contiguous(dst));
14148
+ GGML_ASSERT(ggml_is_contiguous(src0));
14149
+ GGML_ASSERT(ggml_is_contiguous(src1));
14150
+ GGML_ASSERT(ggml_is_contiguous(opt0));
14151
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
14152
+
14153
+ const int64_t ith = params->ith;
14154
+ const int64_t nth = params->nth;
14155
+
14156
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
14157
+ return;
14158
+ }
14159
+
14160
+ const float eps = 1e-9f;
14161
+
14162
+ // TODO: handle transposed/permuted matrices
14163
+ const int64_t nc = src0->ne[0];
14164
+ const int64_t nr = ggml_nrows(src0);
14165
+
14166
+ // rows per thread
14167
+ const int64_t dr = (nr + nth - 1)/nth;
14168
+
14169
+ // row range for this thread
14170
+ const int64_t ir0 = dr*ith;
14171
+ const int64_t ir1 = MIN(ir0 + dr, nr);
14172
+
14173
+ float * d = (float *) opt0->data;
14174
+
14175
+ for (int64_t i1 = ir0; i1 < ir1; i1++) {
14176
+ float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
14177
+ float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
14178
+ float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
14179
+ float * sm = (float *) params->wdata + ith*nc;
14180
+
14181
+ #ifndef NDEBUG
14182
+ for (int i = 0; i < nc; ++i) {
14183
+ //printf("p[%d] = %f\n", i, p[i]);
14184
+ assert(!isnan(s0[i]));
14185
+ assert(!isnan(s1[i]));
14186
+ }
14187
+ #endif
14188
+ // step by step explanation:
14189
+ {
14190
+ //float * sums = (float *) params->wdata;
14191
+
14192
+ // forward pass with annotated gradients from backward pass
14193
+ // (built by going in reverse operation order, adding to gradients of current operation args)
14194
+ // st0 = exp(s0-max(s0)) grad[st0] = grad[st1]*(1.0 - eps)/sum
14195
+ // from softmax_back: grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
14196
+ // ggml_vec_scale_f32(nc, st, sum); // st1 = st0*/sum = softmax(s0) grad[st1] = grad[st2]*(1.0 - eps)
14197
+ // ggml_vec_scale_f32(nc, st, (1.0f - eps)); // st2 = st1*(1.0 - eps) grad[st2] = grad[st3]
14198
+ // ggml_vec_add1_f32(nc, st, st, eps); // st3 = st2 + eps grad[st3] = grad[st4]/st3
14199
+ // ggml_vec_log_f32(nc, st, st); // st4 = log(st3) grad[st4] = grad[st5] * s1
14200
+ // ggml_vec_mul_f32(nc, st, st, s1); // st5 = st4 * s1 grad[st5] = grad[sums[ith]]
14201
+ // ggml_vec_sum_f32(nc, sums + ith, st); // sums[ith] = st5 grad[sums[ith]] = grad[cross_entropy_loss] = -grad[cel]
14202
+
14203
+ // substitute into grad[st1], because we can reuse softmax_back from this point on
14204
+ // grad[st1] = -grad[cel]*s1*(1.0 - eps)/(eps + softmax(s0)*(1.0 - eps))
14205
+ // postorder:
14206
+ // grad[st1] := softmax(s0)
14207
+ // grad[st1] := grad[st1]*(1.0 - eps)
14208
+ // grad[st1] := grad[st1] + eps
14209
+ // grad[st1] := s1 / grad[st1]
14210
+ // grad[st1] := grad[st1]*(1.0-eps)*-grad[cel]
14211
+
14212
+ // src0 gradients by going through softmax_back
14213
+ // grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
14214
+ // from softmax_back:
14215
+ // dxk = yk * (dyk - dot(y, dy))
14216
+ // dot_y_dy := dot(y, dy)
14217
+ // dx := dy
14218
+ // dx := dx - dot_y_dy
14219
+ // dx := dx * y
14220
+ // postorder:
14221
+ // dot_st1_dst1 := dot(st1, grad[st1])
14222
+ // grad[s0] := grad[st1]
14223
+ // grad[s0] := grad[s0] - dot_st1_dst1
14224
+ // grad[s0] := grad[s0] * st1
14225
+
14226
+ // prepend the postorder from grad[st1], writing directly into grad[s0], since we end up setting grad[s0] := grad[st1]
14227
+ // sm := softmax(s0)
14228
+ // grad[s0] := sm*(1.0 - eps)
14229
+ // grad[s0] := grad[s0] + eps
14230
+ // grad[s0] := s1 / grad[s0]
14231
+ // grad[s0] := grad[s0]*(1.0-eps)*-grad[cel]
14232
+ // dot_st1_dst1 := dot(sm, grad[s0])
14233
+ // grad[s0] := grad[s0] - dot_st1_dst1
14234
+ // grad[s0] := grad[s0] * sm
14235
+ }
14236
+
14237
+ // soft_max
14238
+ ggml_float sum = 0.0;
14239
+ {
14240
+ float max = -INFINITY;
14241
+ ggml_vec_max_f32(nc, &max, s0);
14242
+
14243
+ uint16_t scvt;
14244
+ for (int i = 0; i < nc; i++) {
14245
+ if (s0[i] == -INFINITY) {
14246
+ sm[i] = 0.0f;
14247
+ } else {
14248
+ // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
14249
+ ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
14250
+ memcpy(&scvt, &s, sizeof(scvt));
14251
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
14252
+ sum += (ggml_float)val;
14253
+ sm[i] = val;
14254
+ }
14255
+ }
14256
+
14257
+ assert(sum > 0.0);
14258
+ sum = 1.0/sum;
14259
+ }
12855
14260
 
12856
- for (int i = 0; i < n; i++) {
12857
- fun(nc,
12858
- (float *) ((char *) dst->data + i*( dst->nb[1])),
12859
- (float *) ((char *) src0->data + i*(src0->nb[1])),
12860
- (float *) ((char *) src1->data + i*(src1->nb[1])));
14261
+ float dot_st1_dst1 = 0;
14262
+ ggml_vec_scale_f32(nc, sm, sum);
14263
+ ggml_vec_cpy_f32 (nc, ds0, sm);
14264
+ ggml_vec_scale_f32(nc, ds0, (1.0f - eps));
14265
+ ggml_vec_add1_f32 (nc, ds0, ds0, eps);
14266
+ ggml_vec_div_f32 (nc, ds0, s1, ds0);
14267
+ ggml_vec_scale_f32(nc, ds0, -(1.0f - eps)*d[0]);
14268
+ ggml_vec_dot_f32 (nc, &dot_st1_dst1, sm, ds0);
14269
+ ggml_vec_acc1_f32 (nc, ds0, -dot_st1_dst1);
14270
+ ggml_vec_mul_f32 (nc, ds0, ds0, sm);
14271
+
14272
+ #ifndef NDEBUG
14273
+ for (int i = 0; i < nc; ++i) {
14274
+ assert(!isnan(sm[i]));
14275
+ assert(!isinf(sm[i]));
14276
+ assert(!isnan(ds0[i]));
14277
+ assert(!isinf(ds0[i]));
14278
+ }
14279
+ #endif
12861
14280
  }
12862
14281
  }
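The explanation block and the chain of ggml_vec_* calls above implement one closed-form row gradient. With s = softmax(x) (held in sm), target row t (src1), scalar upstream gradient d (opt0) and the same eps as the forward pass:

\[
\frac{\partial L}{\partial x_k} = s_k\big(g_k - \langle s, g\rangle\big),
\qquad
g_k = -\,d\,(1-\varepsilon)\,\frac{t_k}{(1-\varepsilon)\,s_k + \varepsilon},
\]

i.e. the softmax-backward recipe applied to g: ds0 is first built up to g, then dot_st1_dst1 = \(\langle s, g\rangle\) is subtracted, and the result is multiplied by sm.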
12863
14282
 
12864
-
12865
- static void ggml_compute_forward_map_binary(
14283
+ static void ggml_compute_forward_cross_entropy_loss_back(
12866
14284
  const struct ggml_compute_params * params,
12867
14285
  const struct ggml_tensor * src0,
12868
14286
  const struct ggml_tensor * src1,
12869
- struct ggml_tensor * dst,
12870
- const ggml_binary_op_f32_t fun) {
14287
+ const struct ggml_tensor * opt0,
14288
+ struct ggml_tensor * dst) {
12871
14289
  switch (src0->type) {
12872
14290
  case GGML_TYPE_F32:
12873
14291
  {
12874
- ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
14292
+ ggml_compute_forward_cross_entropy_loss_back_f32(params, src0, src1, opt0, dst);
12875
14293
  } break;
12876
14294
  default:
12877
14295
  {
@@ -12880,11 +14298,21 @@ static void ggml_compute_forward_map_binary(
12880
14298
  }
12881
14299
  }
12882
14300
 
14301
+
12883
14302
  /////////////////////////////////
12884
14303
 
12885
14304
  static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
12886
14305
  GGML_ASSERT(params);
12887
14306
 
14307
+ #ifdef GGML_USE_CUBLAS
14308
+ bool skip_cpu = ggml_cuda_compute_forward(params, tensor);
14309
+ if (skip_cpu) {
14310
+ return;
14311
+ }
14312
+ GGML_ASSERT(tensor->src0->backend == GGML_BACKEND_CPU);
14313
+ GGML_ASSERT(tensor->src1 == NULL || tensor->src1->backend == GGML_BACKEND_CPU);
14314
+ #endif // GGML_USE_CUBLAS
14315
+
12888
14316
  switch (tensor->op) {
12889
14317
  case GGML_OP_DUP:
12890
14318
  {
@@ -12942,6 +14370,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
12942
14370
  {
12943
14371
  ggml_compute_forward_repeat(params, tensor->src0, tensor);
12944
14372
  } break;
14373
+ case GGML_OP_REPEAT_BACK:
14374
+ {
14375
+ ggml_compute_forward_repeat_back(params, tensor->src0, tensor);
14376
+ } break;
12945
14377
  case GGML_OP_ABS:
12946
14378
  {
12947
14379
  ggml_compute_forward_abs(params, tensor->src0, tensor);
@@ -12990,6 +14422,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
12990
14422
  {
12991
14423
  ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
12992
14424
  } break;
14425
+ case GGML_OP_OUT_PROD:
14426
+ {
14427
+ ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor);
14428
+ } break;
12993
14429
  case GGML_OP_SCALE:
12994
14430
  {
12995
14431
  ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
@@ -13046,6 +14482,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13046
14482
  {
13047
14483
  ggml_compute_forward_soft_max(params, tensor->src0, tensor);
13048
14484
  } break;
14485
+ case GGML_OP_SOFT_MAX_BACK:
14486
+ {
14487
+ ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
14488
+ } break;
13049
14489
  case GGML_OP_ROPE:
13050
14490
  {
13051
14491
  ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
@@ -13081,6 +14521,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13081
14521
  {
13082
14522
  ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
13083
14523
  } break;
14524
+ case GGML_OP_FLASH_ATTN_BACK:
14525
+ {
14526
+ int32_t t = ggml_get_i32_1d(tensor->opt[2], 0);
14527
+ GGML_ASSERT(t == 0 || t == 1);
14528
+ bool masked = t != 0;
14529
+ ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor);
14530
+ } break;
13084
14531
  case GGML_OP_MAP_UNARY:
13085
14532
  {
13086
14533
  const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
@@ -13093,6 +14540,16 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13093
14540
  ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
13094
14541
  }
13095
14542
  break;
14543
+ case GGML_OP_CROSS_ENTROPY_LOSS:
14544
+ {
14545
+ ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor);
14546
+ }
14547
+ break;
14548
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
14549
+ {
14550
+ ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
14551
+ }
14552
+ break;
13096
14553
  case GGML_OP_NONE:
13097
14554
  {
13098
14555
  // nop
@@ -13231,11 +14688,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13231
14688
  src0->grad =
13232
14689
  ggml_add_impl(ctx,
13233
14690
  src0->grad,
13234
- ggml_mul(ctx,
13235
- tensor->grad, // this was not catched by test_grad because in test_grad tensor->grad is 1
14691
+ ggml_scale(ctx,
13236
14692
  ggml_div(ctx,
13237
- ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor),
13238
- tensor)),
14693
+ tensor->grad,
14694
+ tensor),
14695
+ ggml_new_f32(ctx, 0.5f)),
13239
14696
  inplace);
13240
14697
  }
13241
14698
  } break;
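The rewritten rule above (evidently the square-root case; the hunk header only shows the enclosing ggml_compute_backward) follows directly from the derivative of the square root. With y = sqrt(x),

\[
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y}\cdot\frac{1}{2\sqrt{x}} = \frac{1}{2}\cdot\frac{\nabla y}{y},
\]

so the new code computes the same value as the old one, but as a single ggml_scale of ggml_div(tensor->grad, tensor) by 0.5 instead of multiplying by a repeated 0.5 tensor.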
@@ -13281,43 +14738,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13281
14738
  {
13282
14739
  // necessary for llama
13283
14740
  if (src0->grad) {
13284
- GGML_ASSERT(src0->n_dims == 1 || src0->n_dims == 2);
13285
- const int nc = tensor->ne[0];
13286
- const int nr = tensor->ne[1];
13287
- const int nc0 = src0->ne[0];
13288
- const int nr0 = src0->ne[1];
13289
- const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
13290
- const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
13291
- // tensor->grad [nc,nr,1,1]
13292
- // reshape [nc0,nc/nc0,nr0,nr/nr0]
13293
- // permute [nc0,nr0,nc/nc0,nr/nr0]
13294
- // substitute [nc0,nr0,ncr,nrr]
13295
- // reshape [nc0*nr0,ncr*nrr,1,1]
13296
- // transpose [ncr*nrr,nc0*nr0,1,1]
13297
- // sum rows [1,nc0*nr0,1,1]
13298
- // transpose [nc0*nr0,1,1]
13299
- // reshape [nc0,nr0,1,1] reshape_1d or reshape_2d
13300
- // add to src0->grad
13301
-
13302
- int64_t ne[4] = {nc0,ncr,nr0,nrr};
13303
-
13304
- struct ggml_tensor* F00 = tensor->grad;
13305
- struct ggml_tensor* F01 = ggml_reshape (ctx, F00, ggml_new_tensor(ctx,tensor->grad->type,4,ne));
13306
- struct ggml_tensor* F02 = ggml_permute (ctx, F01, 0,2,1,3);
13307
- struct ggml_tensor* F03 = ggml_cont (ctx, F02);
13308
- struct ggml_tensor* F04 = ggml_reshape_2d(ctx, F03, nc0*nr0, ncr*nrr);
13309
- struct ggml_tensor* F05 = ggml_transpose (ctx, F04);
13310
- struct ggml_tensor* F06 = ggml_cont (ctx, F05);
13311
- struct ggml_tensor* F07 = ggml_sum_rows (ctx, F06);
13312
- struct ggml_tensor* F08 = ggml_transpose (ctx, F07);
13313
- struct ggml_tensor* F09 = ggml_cont (ctx, F08);
13314
- struct ggml_tensor* F10 = ggml_reshape (ctx, F09, src0->grad);
13315
-
13316
- src0->grad =
13317
- ggml_add_impl(ctx,
13318
- src0->grad,
13319
- F10,
13320
- inplace);
14741
+ src0->grad = ggml_add_impl(ctx,
14742
+ src0->grad,
14743
+ ggml_repeat_back(ctx, tensor->grad, src0->grad),
14744
+ inplace);
14745
+ }
14746
+ } break;
14747
+ case GGML_OP_REPEAT_BACK:
14748
+ {
14749
+ if (src0->grad) {
14750
+ // TODO: test this
14751
+ src0->grad = ggml_add_impl(ctx,
14752
+ src0->grad,
14753
+ ggml_repeat(ctx, tensor->grad, src0->grad),
14754
+ inplace);
13321
14755
  }
13322
14756
  } break;
13323
14757
  case GGML_OP_ABS:
@@ -13424,38 +14858,37 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13424
14858
 
13425
14859
  // necessary for llama
13426
14860
  if (src0->grad) {
13427
- // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
13428
14861
  src0->grad =
13429
14862
  ggml_add_impl(ctx,
13430
14863
  src0->grad,
13431
- // ds0 = dt.dot(s1.T)
13432
- // ggml_out_prod(ctx, // [n,m]
13433
- // src1, // [n,p]
13434
- // tensor->grad), // [m,p]
13435
- // for now just using A*B==(B.T*A.T).T
13436
- ggml_cont(ctx, // [n,m]
13437
- ggml_transpose(ctx, // [n,m]
13438
- ggml_mul_mat(ctx, // [m,n]
13439
- ggml_cont(ctx, // [p,m]
13440
- ggml_transpose(ctx, // [p,m]
13441
- tensor->grad)), // [m,p]
13442
- ggml_cont(ctx, // [p,n]
13443
- ggml_transpose(ctx, // [p,n]
13444
- src1))))), // [n,p]
14864
+ ggml_out_prod(ctx, // [n,m]
14865
+ src1, // [n,p]
14866
+ tensor->grad), // [m,p]
13445
14867
  inplace);
13446
14868
  }
13447
14869
  if (src1->grad) {
13448
14870
  src1->grad =
13449
14871
  ggml_add_impl(ctx,
13450
14872
  src1->grad,
13451
- // ds1 = s0.T.dot(dt):
13452
- ggml_mul_mat(ctx, // [n,p]
13453
- ggml_cont(ctx, // [m,n]
13454
- ggml_transpose(ctx, src0)), // [m,n]
13455
- tensor->grad), // [m,p]
14873
+ // ggml_mul_mat(ctx, // [n,p]
14874
+ // ggml_cont(ctx, // [m,n]
14875
+ // ggml_transpose(ctx, src0)), // [m,n]
14876
+ // tensor->grad), // [m,p]
14877
+
14878
+ // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
14879
+ // // avoid transpose of src0, rather transpose smaller tensor->grad
14880
+ // // and then use ggml_out_prod
14881
+ ggml_out_prod(ctx, // [n,p]
14882
+ src0, // [n,m]
14883
+ ggml_transpose(ctx, // [p,m]
14884
+ tensor->grad)), // [m,p]
13456
14885
  inplace);
13457
14886
  }
13458
14887
  } break;
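Restating the shape annotations in the comments above: with src0 of shape [n,m], src1 of shape [n,p] and tensor = ggml_mul_mat(src0, src1) of shape [m,p] (ggml's ne ordering), both parameter gradients reduce to single ggml_out_prod calls,

\[
\nabla\,\mathrm{src0} = \mathrm{out\_prod}(\mathrm{src1}, \nabla\mathrm{tensor}) \;\in\; [n,m],
\qquad
\nabla\,\mathrm{src1} = \mathrm{out\_prod}(\mathrm{src0}, \nabla\mathrm{tensor}^{\top}) \;\in\; [n,p],
\]

which removes the ggml_cont/ggml_transpose chains of the previous formulation and, as the retained comment notes, avoids transposing the usually larger src0.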
14888
+ case GGML_OP_OUT_PROD:
14889
+ {
14890
+ GGML_ASSERT(false); // TODO: not implemented
14891
+ } break;
13459
14892
  case GGML_OP_SCALE:
13460
14893
  {
13461
14894
  // necessary for llama
@@ -13557,7 +14990,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13557
14990
  // necessary for llama
13558
14991
  if (src0->grad) {
13559
14992
  size_t offset;
13560
- memcpy(&offset, tensor->padding, sizeof(offset));
14993
+
14994
+ GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->opt[0]));
14995
+ memcpy(&offset, tensor->opt[0]->data, sizeof(offset));
13561
14996
 
13562
14997
  size_t nb1 = tensor->nb[1];
13563
14998
  size_t nb2 = tensor->nb[2];
@@ -13584,10 +15019,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13584
15019
  {
13585
15020
  // necessary for llama
13586
15021
  if (src0->grad) {
13587
- int axis0 = tensor->padding[0] & 0x3;
13588
- int axis1 = tensor->padding[1] & 0x3;
13589
- int axis2 = tensor->padding[2] & 0x3;
13590
- int axis3 = tensor->padding[3] & 0x3;
15022
+ int32_t * axes = (int32_t *) tensor->opt[0]->data;
15023
+ int axis0 = axes[0] & 0x3;
15024
+ int axis1 = axes[1] & 0x3;
15025
+ int axis2 = axes[2] & 0x3;
15026
+ int axis3 = axes[3] & 0x3;
13591
15027
  int axes_backward[4] = {0,0,0,0};
13592
15028
  axes_backward[axis0] = 0;
13593
15029
  axes_backward[axis1] = 1;
@@ -13671,50 +15107,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13671
15107
  {
13672
15108
  // necessary for llama
13673
15109
  if (src0->grad) {
13674
- // y = softmax(x)
13675
- //
13676
- // Jii = yi - yi*yi
13677
- // Jij = -yi*yj
13678
- // J = diag(y)-y.*y
13679
- // dx = J * dy
13680
- // dxk = sum(Jkj * dyk)
13681
-
13682
- int64_t ne2[4] = {
13683
- tensor->ne[0],
13684
- 1,
13685
- tensor->ne[1]*tensor->ne[2],
13686
- tensor->ne[3]
13687
- };
13688
- struct ggml_tensor * tensor2 = ggml_cont(ctx,
13689
- ggml_reshape_4d(ctx,
13690
- ggml_cont(ctx, tensor),
13691
- ne2[0], ne2[1], ne2[2], ne2[3]));
13692
-
13693
- struct ggml_tensor * grad2 = ggml_cont(ctx,
13694
- ggml_reshape_4d(ctx,
13695
- ggml_cont(ctx, tensor->grad),
13696
- ne2[0], ne2[1], ne2[2], ne2[3]));
13697
-
13698
- struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
13699
- ggml_permute(ctx, // [1,ne0,ne1*ne2,ne3]
13700
- tensor2, // [ne0,1,ne1*ne2,ne3]
13701
- 1, 0, 2, 3));
13702
-
13703
15110
  src0->grad =
13704
- ggml_add_impl(ctx,
13705
- src0->grad, // [ne0,ne1,ne2,ne3]
13706
- ggml_reshape(ctx, // [ne0,ne1,ne2,ne3]
13707
- ggml_mul_mat(ctx, // [ne0,1,ne1*ne2,ne3]
13708
- ggml_sub(ctx, // [ne0,ne0,ne1*ne2,ne3]
13709
- ggml_diag(ctx, // [ne0,ne0,ne1*ne2,ne3]
13710
- tensor2), // [ne0,1,ne1*ne2,ne3]
13711
- ggml_mul_mat(ctx, // [ne0,ne0,ne1*ne2,ne3]
13712
- tensor2_t, // [1,ne0,ne1*ne2,ne3]
13713
- tensor2_t)), // [1,ne0,ne1*ne2,ne3]
13714
- grad2), // [ne0,1,ne1*ne2,ne3]
13715
- src0->grad),
13716
- inplace);
15111
+ ggml_add_impl(ctx, src0->grad,
15112
+ ggml_soft_max_back(ctx, tensor->grad, tensor),
15113
+ inplace);
13717
15114
  }
15115
+
15116
+ } break;
15117
+ case GGML_OP_SOFT_MAX_BACK:
15118
+ {
15119
+ GGML_ASSERT(false); // TODO: not implemented
13718
15120
  } break;
13719
15121
  case GGML_OP_ROPE:
13720
15122
  {
@@ -13769,17 +15171,190 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13769
15171
  } break;
13770
15172
  case GGML_OP_FLASH_ATTN:
13771
15173
  {
13772
- GGML_ASSERT(false); // not supported
15174
+ struct ggml_tensor * flash_grad = NULL;
15175
+ if (src0->grad || src1->grad || tensor->opt[0]->grad) {
15176
+ int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
15177
+ GGML_ASSERT(t == 0 || t == 1);
15178
+ bool masked = t != 0;
15179
+ flash_grad =
15180
+ ggml_flash_attn_back(ctx,
15181
+ src0,
15182
+ src1,
15183
+ tensor->opt[0],
15184
+ tensor->grad,
15185
+ masked);
15186
+ }
15187
+
15188
+ if (src0->grad) {
15189
+ struct ggml_tensor * grad_q = NULL;
15190
+ const size_t nb0 = flash_grad->nb[0];
15191
+ const size_t offset = 0;
15192
+ switch(src0->n_dims) {
15193
+ case 2:
15194
+ {
15195
+ grad_q = ggml_view_2d(ctx,
15196
+ flash_grad,
15197
+ src0->ne[0],
15198
+ src0->ne[1],
15199
+ nb0*src0->ne[0],
15200
+ offset);
15201
+ } break;
15202
+ case 3:
15203
+ {
15204
+ grad_q = ggml_view_3d(ctx,
15205
+ flash_grad,
15206
+ src0->ne[0],
15207
+ src0->ne[1],
15208
+ src0->ne[2],
15209
+ nb0*src0->ne[0],
15210
+ nb0*src0->ne[0]*src0->ne[1],
15211
+ offset);
15212
+ } break;
15213
+ case 4:
15214
+ {
15215
+ grad_q = ggml_view_4d(ctx,
15216
+ flash_grad,
15217
+ src0->ne[0],
15218
+ src0->ne[1],
15219
+ src0->ne[2],
15220
+ src0->ne[3],
15221
+ nb0*src0->ne[0],
15222
+ nb0*src0->ne[0]*src0->ne[1],
15223
+ nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
15224
+ offset);
15225
+ } break;
15226
+ }
15227
+
15228
+ src0->grad = ggml_add_impl(ctx,
15229
+ src0->grad,
15230
+ grad_q,
15231
+ inplace);
15232
+ }
15233
+
15234
+ if (src1->grad) {
15235
+ struct ggml_tensor * grad_k = NULL;
15236
+ const size_t nb0 = flash_grad->nb[0];
15237
+ const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
15238
+ switch(src1->n_dims) {
15239
+ case 2:
15240
+ {
15241
+ grad_k = ggml_view_2d(ctx,
15242
+ flash_grad,
15243
+ src1->ne[0],
15244
+ src1->ne[1],
15245
+ nb0*src1->ne[0],
15246
+ offset);
15247
+ } break;
15248
+ case 3:
15249
+ {
15250
+ grad_k = ggml_view_3d(ctx,
15251
+ flash_grad,
15252
+ src1->ne[0],
15253
+ src1->ne[1],
15254
+ src1->ne[2],
15255
+ nb0*src1->ne[0],
15256
+ nb0*src1->ne[0]*src1->ne[1],
15257
+ offset);
15258
+ } break;
15259
+ case 4:
15260
+ {
15261
+ grad_k = ggml_view_4d(ctx,
15262
+ flash_grad,
15263
+ src1->ne[0],
15264
+ src1->ne[1],
15265
+ src1->ne[2],
15266
+ src1->ne[3],
15267
+ nb0*src1->ne[0],
15268
+ nb0*src1->ne[0]*src1->ne[1],
15269
+ nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
15270
+ offset);
15271
+ } break;
15272
+ }
15273
+
15274
+ src1->grad = ggml_add_impl(ctx,
15275
+ src1->grad,
15276
+ grad_k,
15277
+ inplace);
15278
+ }
15279
+
15280
+ struct ggml_tensor * opt0 = tensor->opt[0];
15281
+
15282
+ if (opt0->grad) {
15283
+ struct ggml_tensor * grad_v = NULL;
15284
+ const size_t nb0 = flash_grad->nb[0];
15285
+ const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
15286
+ + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
15287
+ switch(opt0->n_dims) {
15288
+ case 2:
15289
+ {
15290
+ grad_v = ggml_view_2d(ctx,
15291
+ flash_grad,
15292
+ opt0->ne[0],
15293
+ opt0->ne[1],
15294
+ nb0*opt0->ne[0],
15295
+ offset);
15296
+ } break;
15297
+ case 3:
15298
+ {
15299
+ grad_v = ggml_view_3d(ctx,
15300
+ flash_grad,
15301
+ opt0->ne[0],
15302
+ opt0->ne[1],
15303
+ opt0->ne[2],
15304
+ nb0*opt0->ne[0],
15305
+ nb0*opt0->ne[0]*opt0->ne[1],
15306
+ offset);
15307
+ } break;
15308
+ case 4:
15309
+ {
15310
+ grad_v = ggml_view_4d(ctx,
15311
+ flash_grad,
15312
+ opt0->ne[0],
15313
+ opt0->ne[1],
15314
+ opt0->ne[2],
15315
+ opt0->ne[3],
15316
+ nb0*opt0->ne[0],
15317
+ nb0*opt0->ne[0]*opt0->ne[1],
15318
+ nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
15319
+ offset);
15320
+ } break;
15321
+ }
15322
+
15323
+ opt0->grad = ggml_add_impl(ctx,
15324
+ opt0->grad,
15325
+ grad_v,
15326
+ inplace);
15327
+ }
13773
15328
  } break;
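The three view offsets above assume, consistently with ggml_compute_forward_flash_attn_back_f32 earlier in this diff, that ggml_flash_attn_back packs the q, k and v gradients one after another in a single contiguous buffer, so each offset is a running sum of element counts times the element size. A small hypothetical helper (not a ggml API) to make that layout explicit:

    #include <stddef.h>
    #include <stdint.h>

    // Byte offsets of grad_q, grad_k and grad_v inside the packed flash_grad buffer.
    // nb0 is the element size (sizeof(float) here); q_ne/k_ne are the ne[] arrays of q and k.
    static void flash_grad_offsets(const int64_t q_ne[4], const int64_t k_ne[4], size_t nb0,
                                   size_t * off_q, size_t * off_k, size_t * off_v) {
        const size_t q_elems = (size_t)(q_ne[0]*q_ne[1]*q_ne[2]*q_ne[3]);
        const size_t k_elems = (size_t)(k_ne[0]*k_ne[1]*k_ne[2]*k_ne[3]);
        *off_q = 0;                          // grad_q starts at the beginning
        *off_k = nb0*q_elems;                // grad_k follows all of grad_q
        *off_v = nb0*(q_elems + k_elems);    // grad_v follows grad_q and grad_k
    }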
13774
15329
  case GGML_OP_FLASH_FF:
13775
15330
  {
13776
15331
  GGML_ASSERT(false); // not supported
13777
15332
  } break;
15333
+ case GGML_OP_FLASH_ATTN_BACK:
15334
+ {
15335
+ GGML_ASSERT(false); // not supported
15336
+ } break;
13778
15337
  case GGML_OP_MAP_UNARY:
13779
15338
  case GGML_OP_MAP_BINARY:
13780
15339
  {
13781
15340
  GGML_ASSERT(false); // not supported
13782
15341
  } break;
15342
+ case GGML_OP_CROSS_ENTROPY_LOSS:
15343
+ {
15344
+ if (src0->grad) {
15345
+ src0->grad = ggml_add_impl(ctx,
15346
+ src0->grad,
15347
+ ggml_cross_entropy_loss_back(ctx,
15348
+ src0,
15349
+ src1,
15350
+ tensor->grad),
15351
+ inplace);
15352
+ }
15353
+ } break;
15354
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
15355
+ {
15356
+ GGML_ASSERT(false); // not supported
15357
+ } break;
13783
15358
  case GGML_OP_NONE:
13784
15359
  {
13785
15360
  // nop
@@ -14156,6 +15731,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14156
15731
  case GGML_OP_SUM_ROWS:
14157
15732
  case GGML_OP_MEAN:
14158
15733
  case GGML_OP_REPEAT:
15734
+ case GGML_OP_REPEAT_BACK:
14159
15735
  case GGML_OP_ABS:
14160
15736
  case GGML_OP_SGN:
14161
15737
  case GGML_OP_NEG:
@@ -14175,6 +15751,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14175
15751
  node->n_tasks = n_threads;
14176
15752
  } break;
14177
15753
  case GGML_OP_MUL_MAT:
15754
+ case GGML_OP_OUT_PROD:
14178
15755
  {
14179
15756
  node->n_tasks = n_threads;
14180
15757
 
@@ -14191,7 +15768,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14191
15768
  if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
14192
15769
  node->n_tasks = 1; // TODO: this actually is doing nothing
14193
15770
  // the threads are still spinning
14194
- cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
14195
15771
  }
14196
15772
  else
14197
15773
  #elif defined(GGML_USE_CLBLAST)
@@ -14258,6 +15834,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14258
15834
  } break;
14259
15835
  case GGML_OP_DIAG_MASK_INF:
14260
15836
  case GGML_OP_SOFT_MAX:
15837
+ case GGML_OP_SOFT_MAX_BACK:
14261
15838
  case GGML_OP_ROPE:
14262
15839
  case GGML_OP_ROPE_BACK:
14263
15840
  {
@@ -14337,6 +15914,27 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14337
15914
  cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
14338
15915
  }
14339
15916
 
15917
+ work_size = MAX(work_size, cur);
15918
+ } break;
15919
+ case GGML_OP_FLASH_ATTN_BACK:
15920
+ {
15921
+ node->n_tasks = n_threads;
15922
+
15923
+ size_t cur = 0;
15924
+
15925
+ const int64_t D = node->src0->ne[0];
15926
+ const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
15927
+ const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
15928
+ if (node->src1->type == GGML_TYPE_F32) {
15929
+ cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
15930
+ cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
15931
+ }
15932
+
15933
+ if (node->src1->type == GGML_TYPE_F16) {
15934
+ cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
15935
+ cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
15936
+ }
15937
+
14340
15938
  work_size = MAX(work_size, cur);
14341
15939
  } break;
14342
15940
  case GGML_OP_MAP_UNARY:
@@ -14344,6 +15942,22 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14344
15942
  {
14345
15943
  node->n_tasks = 1;
14346
15944
  } break;
15945
+ case GGML_OP_CROSS_ENTROPY_LOSS:
15946
+ {
15947
+ node->n_tasks = n_threads;
15948
+
15949
+ size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
15950
+
15951
+ work_size = MAX(work_size, cur);
15952
+ } break;
15953
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
15954
+ {
15955
+ node->n_tasks = n_threads;
15956
+
15957
+ size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
15958
+
15959
+ work_size = MAX(work_size, cur);
15960
+ } break;
14347
15961
  case GGML_OP_NONE:
14348
15962
  {
14349
15963
  node->n_tasks = 1;
@@ -14581,7 +16195,7 @@ static void ggml_graph_export_leaf(const struct ggml_tensor * tensor, FILE * fou
14581
16195
  const int64_t * ne = tensor->ne;
14582
16196
  const size_t * nb = tensor->nb;
14583
16197
 
14584
- fprintf(fout, "%-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %16p %16s\n",
16198
+ fprintf(fout, "%-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
14585
16199
  ggml_type_name(tensor->type),
14586
16200
  ggml_op_name (tensor->op),
14587
16201
  tensor->n_dims,
@@ -14595,7 +16209,7 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
14595
16209
  const int64_t * ne = tensor->ne;
14596
16210
  const size_t * nb = tensor->nb;
14597
16211
 
14598
- fprintf(fout, "%-6s %-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %8d %16p %16s\n",
16212
+ fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %8d %16p %32s\n",
14599
16213
  arg,
14600
16214
  ggml_type_name(tensor->type),
14601
16215
  ggml_op_name (tensor->op),
@@ -14608,8 +16222,8 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
14608
16222
  }
14609
16223
 
14610
16224
  void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
14611
- assert(cgraph->work == NULL);
14612
- assert(cgraph->work_size == 0);
16225
+ //assert(cgraph->work == NULL);
16226
+ //assert(cgraph->work_size == 0);
14613
16227
 
14614
16228
  uint64_t size_eval = 0;
14615
16229
 
@@ -14624,11 +16238,11 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
14624
16238
  FILE * fout = stdout;
14625
16239
 
14626
16240
  fprintf(fout, "\n");
14627
- fprintf(fout, "%-16s %8x\n", "magic", GGML_FILE_MAGIC);
14628
- fprintf(fout, "%-16s %8d\n", "version", GGML_FILE_VERSION);
14629
- fprintf(fout, "%-16s %8d\n", "leafs", cgraph->n_leafs);
14630
- fprintf(fout, "%-16s %8d\n", "nodes", cgraph->n_nodes);
14631
- fprintf(fout, "%-16s %8llu\n", "eval", size_eval);
16241
+ fprintf(fout, "%-16s %8x\n", "magic", GGML_FILE_MAGIC);
16242
+ fprintf(fout, "%-16s %8d\n", "version", GGML_FILE_VERSION);
16243
+ fprintf(fout, "%-16s %8d\n", "leafs", cgraph->n_leafs);
16244
+ fprintf(fout, "%-16s %8d\n", "nodes", cgraph->n_nodes);
16245
+ fprintf(fout, "%-16s %" PRIu64 "\n", "eval", size_eval);
14632
16246
 
14633
16247
  // header
14634
16248
  fprintf(fout, "\n");
@@ -14830,7 +16444,6 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
14830
16444
  // read file into data
14831
16445
  {
14832
16446
  FILE * fin = fopen(fname, "rb");
14833
-
14834
16447
  if (!fin) {
14835
16448
  fprintf(stderr, "%s: failed to open %s\n", __func__, fname);
14836
16449
  return result;
@@ -14862,7 +16475,11 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
14862
16475
 
14863
16476
  data = ggml_new_tensor_1d(*ctx_data, GGML_TYPE_I8, fsize);
14864
16477
 
14865
- fread(data->data, sizeof(char), fsize, fin);
16478
+ const size_t ret = fread(data->data, sizeof(char), fsize, fin);
16479
+ if (ret != fsize) {
16480
+ fprintf(stderr, "%s: failed to read %s\n", __func__, fname);
16481
+ return result;
16482
+ }
14866
16483
 
14867
16484
  fclose(fin);
14868
16485
  }
@@ -14970,6 +16587,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
14970
16587
  op = *(const uint32_t *) ptr; ptr += sizeof(op);
14971
16588
  n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims);
14972
16589
 
16590
+ enum ggml_op eop = (enum ggml_op) op;
16591
+
14973
16592
  int64_t ne[GGML_MAX_DIMS];
14974
16593
  size_t nb[GGML_MAX_DIMS];
14975
16594
 
@@ -14984,42 +16603,77 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
14984
16603
  nb[j] = nb_cur;
14985
16604
  }
14986
16605
 
14987
- struct ggml_tensor * tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne);
14988
-
14989
- tensor->op = (enum ggml_op) op;
16606
+ uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur); // TODO: not yet used
14990
16607
 
14991
- uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur);
16608
+ const char * ptr_name = ptr; ptr += GGML_MAX_NAME;
14992
16609
 
14993
- memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME;
16610
+ const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += (2 + GGML_MAX_OPT)*sizeof(int32_t);
14994
16611
 
14995
- for (int j = 0; j < GGML_MAX_DIMS; ++j) {
14996
- tensor->nb[j] = nb[j];
14997
- }
16612
+ struct ggml_tensor * args[2 + GGML_MAX_OPT] = { NULL };
14998
16613
 
14999
16614
  // parse args
15000
- {
15001
- struct ggml_tensor ** args[2 + GGML_MAX_OPT] = {
15002
- &tensor->src0,
15003
- &tensor->src1,
15004
- };
16615
+ for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
16616
+ const int32_t arg_idx = ptr_arg_idx[j];
15005
16617
 
15006
- for (int j = 0; j < GGML_MAX_OPT; ++j) {
15007
- args[2 + j] = &tensor->opt[j];
16618
+ if (arg_idx == -1) {
16619
+ continue;
15008
16620
  }
15009
16621
 
15010
- for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
15011
- const int32_t arg_idx = *(const int32_t *) ptr; ptr += sizeof(arg_idx);
16622
+ if (arg_idx < GGML_MAX_NODES) {
16623
+ args[j] = result.leafs[arg_idx];
16624
+ } else {
16625
+ args[j] = result.nodes[arg_idx - GGML_MAX_NODES];
16626
+ }
16627
+ }
15012
16628
 
15013
- if (arg_idx == -1) {
15014
- continue;
15015
- }
16629
+ // create the tensor
16630
+ // "view" operations are handled differently
16631
+ // TODO: handle inplace ops - currently a copy is always made
16632
+
16633
+ struct ggml_tensor * tensor = NULL;
16634
+
16635
+ switch (eop) {
16636
+ // TODO: implement other view ops
16637
+ case GGML_OP_RESHAPE:
16638
+ {
16639
+ tensor = ggml_reshape_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3]);
16640
+ } break;
16641
+ case GGML_OP_VIEW:
16642
+ {
16643
+ tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
16644
+
16645
+ uint64_t offs;
16646
+ memcpy(&offs, args[2]->data, sizeof(offs));
16647
+
16648
+ tensor->data = ((char *) tensor->data) + offs;
16649
+ } break;
16650
+ case GGML_OP_TRANSPOSE:
16651
+ {
16652
+ tensor = ggml_transpose(*ctx_eval, args[0]);
16653
+ } break;
16654
+ case GGML_OP_PERMUTE:
16655
+ {
16656
+ tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
16657
+ } break;
16658
+ default:
16659
+ {
16660
+ tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne);
16661
+
16662
+ tensor->op = eop;
16663
+ } break;
16664
+ }
15016
16665
 
15017
- if (arg_idx < GGML_MAX_NODES) {
15018
- *args[j] = result.leafs[arg_idx];
15019
- } else {
15020
- *args[j] = result.nodes[arg_idx - GGML_MAX_NODES];
15021
- }
15022
- }
16666
+ memcpy(tensor->name, ptr_name, GGML_MAX_NAME);
16667
+
16668
+ for (int j = 0; j < GGML_MAX_DIMS; ++j) {
16669
+ tensor->nb[j] = nb[j];
16670
+ }
16671
+
16672
+ tensor->src0 = args[0];
16673
+ tensor->src1 = args[1];
16674
+
16675
+ for (int j = 0; j < GGML_MAX_OPT; ++j) {
16676
+ tensor->opt[j] = args[2 + j];
15023
16677
  }
15024
16678
 
15025
16679
  result.nodes[i] = tensor;
@@ -15279,6 +16933,7 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
15279
16933
 
15280
16934
  static enum ggml_opt_result ggml_opt_adam(
15281
16935
  struct ggml_context * ctx,
16936
+ struct ggml_opt_context * opt,
15282
16937
  struct ggml_opt_params params,
15283
16938
  struct ggml_tensor * f,
15284
16939
  struct ggml_cgraph * gf,
@@ -15304,25 +16959,29 @@ static enum ggml_opt_result ggml_opt_adam(
15304
16959
  }
15305
16960
  }
15306
16961
 
16962
+ if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) {
16963
+ int iter = opt->iter;
16964
+ ggml_opt_init(opt->ctx, opt, params, nx);
16965
+ opt->iter = iter;
16966
+ }
16967
+
15307
16968
  // constants
15308
- const float alpha = params.adam.alpha;
16969
+ const float sched = params.adam.sched;
16970
+ const float decay = params.adam.decay * sched;
16971
+ const float alpha = params.adam.alpha * sched;
15309
16972
  const float beta1 = params.adam.beta1;
15310
16973
  const float beta2 = params.adam.beta2;
15311
16974
  const float eps = params.adam.eps;
15312
16975
 
15313
- float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // view of the parameters
15314
- float * g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient
15315
- float * g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient squared
15316
- float * m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment
15317
- float * v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment
15318
- float * mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment hat
15319
- float * vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment hat
15320
-
15321
- float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
16976
+ float * x = opt->adam.x->data; // view of the parameters
16977
+ float * g1 = opt->adam.g1->data; // gradient
16978
+ float * g2 = opt->adam.g2->data; // gradient squared
16979
+ float * m = opt->adam.m->data; // first moment
16980
+ float * v = opt->adam.v->data; // second moment
16981
+ float * mh = opt->adam.mh->data; // first moment hat
16982
+ float * vh = opt->adam.vh->data; // second moment hat
15322
16983
 
15323
- // initialize
15324
- ggml_vec_set_f32(nx, m, 0.0f);
15325
- ggml_vec_set_f32(nx, v, 0.0f);
16984
+ float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
15326
16985
 
15327
16986
  // update view
15328
16987
  ggml_opt_get_params(np, ps, x);
@@ -15332,16 +16991,27 @@ static enum ggml_opt_result ggml_opt_adam(
15332
16991
  ggml_set_f32 (f->grad, 1.0f);
15333
16992
  ggml_graph_compute(ctx, gb);
15334
16993
 
15335
- float fx_prev = ggml_get_f32_1d(f, 0);
16994
+ opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
16995
+ opt->adam.fx_best = opt->adam.fx_prev;
15336
16996
  if (pf) {
15337
- pf[0] = fx_prev;
16997
+ pf[opt->iter % params.past] = opt->adam.fx_prev;
16998
+ }
16999
+
17000
+ // initialize
17001
+ if (opt->just_initialized) {
17002
+ opt->adam.n_no_improvement = 0;
17003
+ opt->just_initialized = false;
15338
17004
  }
15339
17005
 
15340
- int n_no_improvement = 0;
15341
- float fx_best = fx_prev;
17006
+ float * fx_best = &opt->adam.fx_best;
17007
+ float * fx_prev = &opt->adam.fx_prev;
17008
+ int * n_no_improvement = &opt->adam.n_no_improvement;
17009
+
17010
+ int iter0 = opt->iter;
15342
17011
 
15343
17012
  // run the optimizer
15344
17013
  for (int t = 0; t < params.adam.n_iter; ++t) {
17014
+ opt->iter = iter0 + t + 1;
15345
17015
  GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
15346
17016
 
15347
17017
  GGML_PRINT_DEBUG ("f = %10.6f\n", ggml_get_f32_1d(f, 0));
@@ -15375,17 +17045,22 @@ static enum ggml_opt_result ggml_opt_adam(
15375
17045
 
15376
17046
  // m^hat = m_t / (1 - beta1^t)
15377
17047
  // v^hat = v_t / (1 - beta2^t)
15378
- // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps)
17048
+ // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1)
17049
+ // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1
17050
+ // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
17051
+ // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps)
17052
+ // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay)
15379
17053
  ggml_vec_cpy_f32 (nx, mh, m);
15380
17054
  ggml_vec_cpy_f32 (nx, vh, v);
15381
17055
 
15382
- ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, t + 1)));
15383
- ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, t + 1)));
17056
+ ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter)));
17057
+ ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, opt->iter)));
15384
17058
 
15385
17059
  ggml_vec_sqrt_f32 (nx, vh, vh);
15386
17060
  ggml_vec_acc1_f32 (nx, vh, eps);
15387
17061
 
15388
17062
  ggml_vec_div_f32 (nx, mh, mh, vh);
17063
+ ggml_vec_scale_f32(nx, x, 1.0f - decay);
15389
17064
  ggml_vec_sub_f32 (nx, x, x, mh);
15390
17065
 
15391
17066
  // update the parameters
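
Editor's note: in plain C, the ggml_vec_* sequence above implements exactly the derivation in the comment — a decoupled weight-decay term folded into the same pass as the bias-corrected Adam step. A standalone sketch, assuming alpha and decay have already been multiplied by sched as in the constants at the top of ggml_opt_adam:

    #include <math.h>
    #include <stdint.h>

    // One Adam step with decoupled weight decay, mirroring the ggml_vec_* calls above.
    // x, m, v have length nx; iter is the global 1-based iteration count (opt->iter).
    static void adam_step_with_decay(int64_t nx, float * x, const float * m, const float * v,
                                     float alpha, float decay, float beta1, float beta2,
                                     float eps, int iter) {
        const float bc1 = 1.0f - powf(beta1, iter); // bias correction for the first moment
        const float bc2 = 1.0f - powf(beta2, iter); // bias correction for the second moment

        for (int64_t i = 0; i < nx; ++i) {
            const float mh = alpha*(m[i]/bc1);      // alpha * m^hat
            const float vh = sqrtf(v[i]/bc2) + eps; // sqrt(v^hat) + eps
            // x_t = x_{t-1}*(1 - sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
            x[i] = x[i]*(1.0f - decay) - mh/vh;
        }
    }

Using opt->iter instead of the local loop counter for the bias correction is what keeps the correction consistent when optimization is resumed across multiple calls.
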
@@ -15399,7 +17074,7 @@ static enum ggml_opt_result ggml_opt_adam(
15399
17074
  const float fx = ggml_get_f32_1d(f, 0);
15400
17075
 
15401
17076
  // check convergence
15402
- if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) {
17077
+ if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
15403
17078
  GGML_PRINT_DEBUG("converged\n");
15404
17079
 
15405
17080
  return GGML_OPT_OK;
@@ -15408,32 +17083,32 @@ static enum ggml_opt_result ggml_opt_adam(
15408
17083
  // delta-based convergence test
15409
17084
  if (pf != NULL) {
15410
17085
  // need at least params.past iterations to start checking for convergence
15411
- if (params.past <= t) {
15412
- const float rate = (pf[t%params.past] - fx)/fx;
17086
+ if (params.past <= iter0 + t) {
17087
+ const float rate = (pf[(iter0 + t)%params.past] - fx)/fx;
15413
17088
 
15414
17089
  if (fabsf(rate) < params.delta) {
15415
17090
  return GGML_OPT_OK;
15416
17091
  }
15417
17092
  }
15418
17093
 
15419
- pf[t%params.past] = fx;
17094
+ pf[(iter0 + t)%params.past] = fx;
15420
17095
  }
15421
17096
 
15422
17097
  // check for improvement
15423
17098
  if (params.max_no_improvement > 0) {
15424
- if (fx_best > fx) {
15425
- fx_best = fx;
15426
- n_no_improvement = 0;
17099
+ if (fx_best[0] > fx) {
17100
+ fx_best[0] = fx;
17101
+ n_no_improvement[0] = 0;
15427
17102
  } else {
15428
- ++n_no_improvement;
17103
+ ++n_no_improvement[0];
15429
17104
 
15430
- if (n_no_improvement >= params.max_no_improvement) {
17105
+ if (n_no_improvement[0] >= params.max_no_improvement) {
15431
17106
  return GGML_OPT_OK;
15432
17107
  }
15433
17108
  }
15434
17109
  }
15435
17110
 
15436
- fx_prev = fx;
17111
+ fx_prev[0] = fx;
15437
17112
 
15438
17113
  {
15439
17114
  const int64_t t_end_cpu = ggml_cycles();
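
Editor's note: because pf now outlives a single call, the delta-based test above indexes the ring buffer with the global iteration count (iter0 + t) rather than the local loop variable. A minimal sketch of that test; the caller stores fx back into the same slot afterwards:

    #include <math.h>
    #include <stdbool.h>

    // Delta-based convergence test over a ring buffer of the last `past` function values.
    // `iter` is the global iteration count (iter0 + t above).
    static bool converged_over_window(const float * pf, int past, int iter, float fx, float delta) {
        if (past <= iter) {
            const float rate = (pf[iter % past] - fx)/fx;
            if (fabsf(rate) < delta) {
                return true;
            }
        }
        return false; // caller then writes fx into pf[iter % past]
    }
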
@@ -15572,6 +17247,7 @@ static enum ggml_opt_result linesearch_backtracking(
15572
17247
 
15573
17248
  static enum ggml_opt_result ggml_opt_lbfgs(
15574
17249
  struct ggml_context * ctx,
17250
+ struct ggml_opt_context * opt,
15575
17251
  struct ggml_opt_params params,
15576
17252
  struct ggml_tensor * f,
15577
17253
  struct ggml_cgraph * gf,
@@ -15604,31 +17280,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15604
17280
  }
15605
17281
  }
15606
17282
 
15607
- float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current parameters
15608
- float * xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous parameters
15609
- float * g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current gradient
15610
- float * gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous gradient
15611
- float * d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // search direction
17283
+ if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) {
17284
+ int iter = opt->iter;
17285
+ ggml_opt_init(ctx, opt, params, nx);
17286
+ opt->iter = iter;
17287
+ }
17288
+
17289
+ float * x = opt->lbfgs.x->data; // current parameters
17290
+ float * xp = opt->lbfgs.xp->data; // previous parameters
17291
+ float * g = opt->lbfgs.g->data; // current gradient
17292
+ float * gp = opt->lbfgs.gp->data; // previous gradient
17293
+ float * d = opt->lbfgs.d->data; // search direction
15612
17294
 
15613
- float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
17295
+ float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
15614
17296
 
15615
17297
  float fx = 0.0f; // cost function value
15616
17298
  float xnorm = 0.0f; // ||x||
15617
17299
  float gnorm = 0.0f; // ||g||
15618
- float step = 0.0f;
15619
17300
 
15620
17301
  // initialize x from the graph nodes
15621
17302
  ggml_opt_get_params(np, ps, x);
15622
17303
 
15623
17304
  // the L-BFGS memory
15624
- struct ggml_lbfgs_iteration_data * lm = alloca(sizeof(struct ggml_lbfgs_iteration_data)*m);
15625
-
15626
- for (int i = 0; i < m; ++i) {
15627
- lm[i].alpha = 0.0f;
15628
- lm[i].ys = 0.0f;
15629
- lm[i].s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
15630
- lm[i].y = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
15631
- }
17305
+ float * lm_alpha = opt->lbfgs.lmal->data;
17306
+ float * lm_ys = opt->lbfgs.lmys->data;
17307
+ float * lm_s = opt->lbfgs.lms->data;
17308
+ float * lm_y = opt->lbfgs.lmy->data;
15632
17309
 
15633
17310
  // evaluate the function value and its gradient
15634
17311
  {
@@ -15643,12 +17320,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15643
17320
  fx = ggml_get_f32_1d(f, 0);
15644
17321
  }
15645
17322
 
15646
- if (pf) {
15647
- pf[0] = fx;
15648
- }
15649
-
15650
- float fx_best = fx;
15651
-
15652
17323
  // search direction = -gradient
15653
17324
  ggml_vec_neg_f32(nx, d, g);
15654
17325
 
@@ -15665,26 +17336,43 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15665
17336
  return GGML_OPT_OK;
15666
17337
  }
15667
17338
 
15668
- // initial step
15669
- ggml_vec_norm_inv_f32(nx, &step, d);
17339
+ if (opt->just_initialized) {
17340
+ if (pf) {
17341
+ pf[0] = fx;
17342
+ }
17343
+ opt->lbfgs.fx_best = fx;
17344
+
17345
+ // initial step
17346
+ ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d);
17347
+ opt->lbfgs.j = 0;
17348
+ opt->lbfgs.k = 1;
17349
+ opt->lbfgs.end = 0;
17350
+ opt->lbfgs.n_no_improvement = 0;
17351
+ opt->just_initialized = false;
17352
+ }
17353
+
17354
+ float * fx_best = &opt->lbfgs.fx_best;
17355
+ float * step = &opt->lbfgs.step;
17356
+ int * j = &opt->lbfgs.j;
17357
+ int * k = &opt->lbfgs.k;
17358
+ int * end = &opt->lbfgs.end;
17359
+ int * n_no_improvement = &opt->lbfgs.n_no_improvement;
15670
17360
 
15671
- int j = 0;
15672
- int k = 1;
15673
- int ls = 0;
15674
- int end = 0;
15675
- int bound = 0;
15676
- int n_no_improvement = 0;
17361
+ int ls = 0;
17362
+ int bound = 0;
15677
17363
 
15678
17364
  float ys = 0.0f;
15679
17365
  float yy = 0.0f;
15680
17366
  float beta = 0.0f;
15681
17367
 
17368
+ int it = 0;
17369
+
15682
17370
  while (true) {
15683
17371
  // store the current position and gradient vectors
15684
17372
  ggml_vec_cpy_f32(nx, xp, x);
15685
17373
  ggml_vec_cpy_f32(nx, gp, g);
15686
17374
 
15687
- ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, &step, xp, f, gf, gb, np, ps);
17375
+ ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
15688
17376
 
15689
17377
  if (ls < 0) {
15690
17378
  // linesearch failed - go back to the previous point and return
@@ -15710,32 +17398,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15710
17398
  // delta-based convergence test
15711
17399
  if (pf != NULL) {
15712
17400
  // need at least params.past iterations to start checking for convergence
15713
- if (params.past <= k) {
15714
- const float rate = (pf[k%params.past] - fx)/fx;
17401
+ if (params.past <= k[0]) {
17402
+ const float rate = (pf[k[0]%params.past] - fx)/fx;
15715
17403
 
15716
17404
  if (fabsf(rate) < params.delta) {
15717
17405
  return GGML_OPT_OK;
15718
17406
  }
15719
17407
  }
15720
17408
 
15721
- pf[k%params.past] = fx;
17409
+ pf[k[0]%params.past] = fx;
15722
17410
  }
15723
17411
 
15724
17412
  // check for improvement
15725
17413
  if (params.max_no_improvement > 0) {
15726
- if (fx < fx_best) {
15727
- fx_best = fx;
15728
- n_no_improvement = 0;
17414
+ if (fx < fx_best[0]) {
17415
+ fx_best[0] = fx;
17416
+ n_no_improvement[0] = 0;
15729
17417
  } else {
15730
- n_no_improvement++;
17418
+ n_no_improvement[0]++;
15731
17419
 
15732
- if (n_no_improvement >= params.max_no_improvement) {
17420
+ if (n_no_improvement[0] >= params.max_no_improvement) {
15733
17421
  return GGML_OPT_OK;
15734
17422
  }
15735
17423
  }
15736
17424
  }
15737
17425
 
15738
- if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < k + 1) {
17426
+ if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) {
15739
17427
  // reached the maximum number of iterations
15740
17428
  return GGML_OPT_DID_NOT_CONVERGE;
15741
17429
  }
@@ -15744,50 +17432,51 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15744
17432
  // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}.
15745
17433
  // y_{k+1} = g_{k+1} - g_{k}.
15746
17434
  //
15747
- ggml_vec_sub_f32(nx, lm[end].s, x, xp);
15748
- ggml_vec_sub_f32(nx, lm[end].y, g, gp);
17435
+ ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp);
17436
+ ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp);
15749
17437
 
15750
17438
  // compute scalars ys and yy:
15751
17439
  // ys = y^t \cdot s -> 1 / \rho.
15752
17440
  // yy = y^t \cdot y.
15753
17441
  //
15754
- ggml_vec_dot_f32(nx, &ys, lm[end].y, lm[end].s);
15755
- ggml_vec_dot_f32(nx, &yy, lm[end].y, lm[end].y);
17442
+ ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0]*nx]);
17443
+ ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);
15756
17444
 
15757
- lm[end].ys = ys;
17445
+ lm_ys[end[0]] = ys;
15758
17446
 
15759
17447
  // find new search direction
15760
17448
  // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS
15761
17449
 
15762
- bound = (m <= k) ? m : k;
15763
- k++;
15764
- end = (end + 1)%m;
17450
+ bound = (m <= k[0]) ? m : k[0];
17451
+ k[0]++;
17452
+ it++;
17453
+ end[0] = (end[0] + 1)%m;
15765
17454
 
15766
17455
  // initialize search direction with -g
15767
17456
  ggml_vec_neg_f32(nx, d, g);
15768
17457
 
15769
- j = end;
17458
+ j[0] = end[0];
15770
17459
  for (int i = 0; i < bound; ++i) {
15771
- j = (j + m - 1) % m;
17460
+ j[0] = (j[0] + m - 1) % m;
15772
17461
  // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
15773
- ggml_vec_dot_f32(nx, &lm[j].alpha, lm[j].s, d);
15774
- lm[j].alpha /= lm[j].ys;
17462
+ ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d);
17463
+ lm_alpha[j[0]] /= lm_ys[j[0]];
15775
17464
  // q_{i} = q_{i+1} - \alpha_{i} y_{i}
15776
- ggml_vec_mad_f32(nx, d, lm[j].y, -lm[j].alpha);
17465
+ ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
15777
17466
  }
15778
17467
 
15779
17468
  ggml_vec_scale_f32(nx, d, ys/yy);
15780
17469
 
15781
17470
  for (int i = 0; i < bound; ++i) {
15782
17471
  // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
15783
- ggml_vec_dot_f32(nx, &beta, lm[j].y, d);
15784
- beta /= lm[j].ys;
17472
+ ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d);
17473
+ beta /= lm_ys[j[0]];
15785
17474
  // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
15786
- ggml_vec_mad_f32(nx, d, lm[j].s, lm[j].alpha - beta);
15787
- j = (j + 1)%m;
17475
+ ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
17476
+ j[0] = (j[0] + 1)%m;
15788
17477
  }
15789
17478
 
15790
- step = 1.0;
17479
+ step[0] = 1.0;
15791
17480
  }
15792
17481
 
15793
17482
  return GGML_OPT_DID_NOT_CONVERGE;
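
Editor's note: the change above replaces the alloca'd array of ggml_lbfgs_iteration_data structs with four flat F32 tensors, so history entry j of the s/y matrices is the contiguous row starting at j*nx. A plain-C sketch of the same two-loop recursion on those flat buffers (d must already hold -g; ys and yy are the dot products computed above):

    #include <stdint.h>

    // Two-loop L-BFGS recursion on the flat history buffers used above.
    // lm_s, lm_y: m rows of length nx (row j starts at j*nx); lm_alpha, lm_ys: one scalar per row.
    static void lbfgs_two_loop(int64_t nx, int m, int bound, int end,
                               float * d, const float * lm_s, const float * lm_y,
                               float * lm_alpha, const float * lm_ys,
                               float ys, float yy) {
        int j = end;

        // first loop: walk the history backwards, d -= alpha_j * y_j
        for (int i = 0; i < bound; ++i) {
            j = (j + m - 1) % m;
            float dot = 0.0f;
            for (int64_t t = 0; t < nx; ++t) dot += lm_s[j*nx + t]*d[t];
            lm_alpha[j] = dot/lm_ys[j];
            for (int64_t t = 0; t < nx; ++t) d[t] -= lm_alpha[j]*lm_y[j*nx + t];
        }

        // scale by the initial Hessian approximation gamma = ys/yy
        for (int64_t t = 0; t < nx; ++t) d[t] *= ys/yy;

        // second loop: walk forwards, d += (alpha_j - beta_j) * s_j
        for (int i = 0; i < bound; ++i) {
            float dot = 0.0f;
            for (int64_t t = 0; t < nx; ++t) dot += lm_y[j*nx + t]*d[t];
            const float beta = dot/lm_ys[j];
            for (int64_t t = 0; t < nx; ++t) d[t] += (lm_alpha[j] - beta)*lm_s[j*nx + t];
            j = (j + 1) % m;
        }
    }

Keeping the history in tensors owned by ggml_opt_context (rather than alloca) is what allows the L-BFGS state, like the Adam state, to survive across resumed calls.
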
@@ -15812,6 +17501,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
15812
17501
 
15813
17502
  .adam = {
15814
17503
  .n_iter = 10000,
17504
+ .sched = 1.000f,
17505
+ .decay = 0.001f,
15815
17506
  .alpha = 0.001f,
15816
17507
  .beta1 = 0.9f,
15817
17508
  .beta2 = 0.999f,
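
Editor's note: the two new Adam defaults above (sched = 1.0, decay = 0.001) can be overridden per run. A minimal usage sketch, assuming ctx already holds parameter tensors marked with ggml_set_param and f is the scalar loss built from them:

    #include "ggml.h"

    // Minimal sketch: run Adam with an explicit schedule and weight decay.
    static enum ggml_opt_result run_adam(struct ggml_context * ctx, struct ggml_tensor * f) {
        struct ggml_opt_params params = ggml_opt_default_params(GGML_OPT_ADAM);

        params.adam.sched = 0.5f;   // example value: scales both alpha and decay for this run
        params.adam.decay = 0.001f; // decoupled weight decay; 0.0f disables it

        return ggml_opt(ctx, params, f);
    }
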
@@ -15854,6 +17545,71 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
15854
17545
  return result;
15855
17546
  }
15856
17547
 
17548
+ GGML_API void ggml_opt_init(
17549
+ struct ggml_context * ctx,
17550
+ struct ggml_opt_context * opt,
17551
+ struct ggml_opt_params params,
17552
+ int64_t nx) {
17553
+ opt->ctx = ctx;
17554
+ opt->params = params;
17555
+ opt->iter = 0;
17556
+ opt->nx = nx;
17557
+ opt->just_initialized = true;
17558
+ switch (opt->params.type) {
17559
+ case GGML_OPT_ADAM:
17560
+ {
17561
+ opt->adam.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17562
+ opt->adam.g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17563
+ opt->adam.g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17564
+ opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17565
+ opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17566
+ opt->adam.mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17567
+ opt->adam.vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17568
+ opt->adam.pf = params.past > 0
17569
+ ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
17570
+ : NULL;
17571
+ ggml_set_zero(opt->adam.x);
17572
+ ggml_set_zero(opt->adam.g1);
17573
+ ggml_set_zero(opt->adam.g2);
17574
+ ggml_set_zero(opt->adam.m);
17575
+ ggml_set_zero(opt->adam.v);
17576
+ ggml_set_zero(opt->adam.mh);
17577
+ ggml_set_zero(opt->adam.vh);
17578
+ if (opt->adam.pf) {
17579
+ ggml_set_zero(opt->adam.pf);
17580
+ }
17581
+ } break;
17582
+ case GGML_OPT_LBFGS:
17583
+ {
17584
+ opt->lbfgs.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17585
+ opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17586
+ opt->lbfgs.g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17587
+ opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17588
+ opt->lbfgs.d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17589
+ opt->lbfgs.pf = params.past > 0
17590
+ ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
17591
+ : NULL;
17592
+ opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
17593
+ opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
17594
+ opt->lbfgs.lms = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
17595
+ opt->lbfgs.lmy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
17596
+ ggml_set_zero(opt->lbfgs.x);
17597
+ ggml_set_zero(opt->lbfgs.xp);
17598
+ ggml_set_zero(opt->lbfgs.g);
17599
+ ggml_set_zero(opt->lbfgs.gp);
17600
+ ggml_set_zero(opt->lbfgs.d);
17601
+ ggml_set_zero(opt->lbfgs.pf);
17603
+ ggml_set_zero(opt->lbfgs.pf);
17604
+ }
17605
+ ggml_set_zero(opt->lbfgs.lmal);
17606
+ ggml_set_zero(opt->lbfgs.lmys);
17607
+ ggml_set_zero(opt->lbfgs.lms);
17608
+ ggml_set_zero(opt->lbfgs.lmy);
17609
+ } break;
17610
+ }
17611
+ }
17612
+
15857
17613
  enum ggml_opt_result ggml_opt(
15858
17614
  struct ggml_context * ctx,
15859
17615
  struct ggml_opt_params params,
@@ -15876,33 +17632,65 @@ enum ggml_opt_result ggml_opt(
15876
17632
 
15877
17633
  enum ggml_opt_result result = GGML_OPT_OK;
15878
17634
 
17635
+ struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
17636
+
17637
+ ggml_opt_init(ctx, opt, params, 0);
17638
+ result = ggml_opt_resume(ctx, opt, f);
17639
+
17640
+ if (free_ctx) {
17641
+ ggml_free(ctx);
17642
+ }
17643
+
17644
+ return result;
17645
+ }
17646
+
17647
+ enum ggml_opt_result ggml_opt_resume(
17648
+ struct ggml_context * ctx,
17649
+ struct ggml_opt_context * opt,
17650
+ struct ggml_tensor * f) {
17651
+
17652
+ // build forward + backward compute graphs
17653
+ struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32] + (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
17654
+ struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32] + (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
17655
+
17656
+ struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
17657
+ struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
17658
+
17659
+ *gf = ggml_build_forward (f);
17660
+ *gb = ggml_build_backward(ctx, gf, true);
17661
+
17662
+ return ggml_opt_resume_g(ctx, opt, f, gf, gb);
17663
+ }
17664
+
17665
+ enum ggml_opt_result ggml_opt_resume_g(
17666
+ struct ggml_context * ctx,
17667
+ struct ggml_opt_context * opt,
17668
+ struct ggml_tensor * f,
17669
+ struct ggml_cgraph * gf,
17670
+ struct ggml_cgraph * gb) {
17671
+
15879
17672
  // build forward + backward compute graphs
15880
- struct ggml_cgraph gf = ggml_build_forward (f);
15881
- struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, true);
17673
+ enum ggml_opt_result result = GGML_OPT_OK;
15882
17674
 
15883
- switch (params.type) {
17675
+ switch (opt->params.type) {
15884
17676
  case GGML_OPT_ADAM:
15885
17677
  {
15886
- result = ggml_opt_adam(ctx, params, f, &gf, &gb);
17678
+ result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
15887
17679
  } break;
15888
17680
  case GGML_OPT_LBFGS:
15889
17681
  {
15890
- result = ggml_opt_lbfgs(ctx, params, f, &gf, &gb);
17682
+ result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb);
15891
17683
  } break;
15892
17684
  }
15893
17685
 
15894
- if (params.print_forward_graph) {
15895
- ggml_graph_print (&gf);
15896
- ggml_graph_dump_dot(&gf, NULL, "opt-forward.dot");
15897
- }
15898
-
15899
- if (params.print_backward_graph) {
15900
- ggml_graph_print (&gb);
15901
- ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot");
17686
+ if (opt->params.print_forward_graph) {
17687
+ ggml_graph_print (gf);
17688
+ ggml_graph_dump_dot(gf, NULL, "opt-forward.dot");
15902
17689
  }
15903
17690
 
15904
- if (free_ctx) {
15905
- ggml_free(ctx);
17691
+ if (opt->params.print_backward_graph) {
17692
+ ggml_graph_print (gb);
17693
+ ggml_graph_dump_dot(gb, gf, "opt-backward.dot");
15906
17694
  }
15907
17695
 
15908
17696
  return result;
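
Editor's note: ggml_opt keeps its old one-shot signature by allocating a temporary ggml_opt_context, while ggml_opt_resume / ggml_opt_resume_g let callers keep the optimizer state (moments, iteration counter, best value) alive across calls. A hedged usage sketch; build_loss is a hypothetical helper that rebuilds the scalar loss for the current batch, and nx is the total number of optimized parameters:

    #include <stdint.h>

    #include "ggml.h"

    // Hypothetical helper: rebuilds the scalar loss tensor for batch i (not part of ggml).
    struct ggml_tensor * build_loss(struct ggml_context * ctx, int i);

    // Stateful training loop: initialize once, then resume per batch so that the
    // Adam moments and opt.iter carry over between calls.
    static void train_batches(struct ggml_context * ctx, int64_t nx, int n_batches) {
        struct ggml_opt_params params = ggml_opt_default_params(GGML_OPT_ADAM);

        struct ggml_opt_context opt;
        ggml_opt_init(ctx, &opt, params, nx); // allocates x/g1/g2/m/v/mh/vh (and pf when past > 0)

        for (int i = 0; i < n_batches; ++i) {
            struct ggml_tensor * f = build_loss(ctx, i);
            ggml_opt_resume(ctx, &opt, f); // builds forward+backward graphs and runs the optimizer
        }
    }
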
@@ -16070,6 +17858,50 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
16070
17858
  block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
16071
17859
  result = ggml_quantize_q8_0(src + start, block, n, n, hist);
16072
17860
  } break;
17861
+ #ifdef GGML_USE_K_QUANTS
17862
+ case GGML_TYPE_Q2_K:
17863
+ {
17864
+ GGML_ASSERT(start % QK_K == 0);
17865
+ block_q2_K * block = (block_q2_K*)dst + start / QK_K;
17866
+ result = ggml_quantize_q2_K(src + start, block, n, n, hist);
17867
+ } break;
17868
+ case GGML_TYPE_Q3_K:
17869
+ {
17870
+ GGML_ASSERT(start % QK_K == 0);
17871
+ block_q3_K * block = (block_q3_K*)dst + start / QK_K;
17872
+ result = ggml_quantize_q3_K(src + start, block, n, n, hist);
17873
+ } break;
17874
+ case GGML_TYPE_Q4_K:
17875
+ {
17876
+ GGML_ASSERT(start % QK_K == 0);
17877
+ block_q4_K * block = (block_q4_K*)dst + start / QK_K;
17878
+ result = ggml_quantize_q4_K(src + start, block, n, n, hist);
17879
+ } break;
17880
+ case GGML_TYPE_Q5_K:
17881
+ {
17882
+ GGML_ASSERT(start % QK_K == 0);
17883
+ block_q5_K * block = (block_q5_K*)dst + start / QK_K;
17884
+ result = ggml_quantize_q5_K(src + start, block, n, n, hist);
17885
+ } break;
17886
+ case GGML_TYPE_Q6_K:
17887
+ {
17888
+ GGML_ASSERT(start % QK_K == 0);
17889
+ block_q6_K * block = (block_q6_K*)dst + start / QK_K;
17890
+ result = ggml_quantize_q6_K(src + start, block, n, n, hist);
17891
+ } break;
17892
+ #endif
17893
+ case GGML_TYPE_F16:
17894
+ {
17895
+ int elemsize = sizeof(ggml_fp16_t);
17896
+ ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
17897
+ result = n * elemsize;
17898
+ } break;
17899
+ case GGML_TYPE_F32:
17900
+ {
17901
+ int elemsize = sizeof(float);
17902
+ result = n * elemsize;
17903
+ memcpy((uint8_t *)dst + start * elemsize, src + start, result);
17904
+ } break;
16073
17905
  default:
16074
17906
  assert(false);
16075
17907
  }
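
Editor's note: with the k-quant, F16 and F32 cases added above, ggml_quantize_chunk covers every type written at quantization time. A hedged usage sketch, assuming the (type, src, dst, start, n, hist) signature of this version, a 16-bucket histogram, a build with GGML_USE_K_QUANTS, and n being a multiple of QK_K for the k-quant types:

    #include <stdint.h>

    #include "ggml.h"

    // Quantize n floats (starting at offset 0) into Q4_K blocks.
    // Returns the number of bytes written into dst; hist accumulates a
    // 16-bucket histogram of the quantized values.
    static size_t quantize_q4_k(const float * src, void * dst, int n) {
        int64_t hist[16] = {0};
        return ggml_quantize_chunk(GGML_TYPE_Q4_K, src, dst, /*start=*/0, n, hist);
    }
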