llama_cpp 0.0.1 → 0.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,6 +16,7 @@
16
16
  #include <stdlib.h>
17
17
  #include <string.h>
18
18
  #include <stdint.h>
19
+ #include <inttypes.h>
19
20
  #include <stdio.h>
20
21
  #include <float.h>
21
22
 
@@ -79,6 +80,19 @@ static int sched_yield (void) {
79
80
  typedef void* thread_ret_t;
80
81
  #endif
81
82
 
83
+ // __FMA__ and __F16C__ are not defined by MSVC, but they are implied by AVX2/AVX512
84
+ #if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
85
+ #ifndef __FMA__
86
+ #define __FMA__
87
+ #endif
88
+ #ifndef __F16C__
89
+ #define __F16C__
90
+ #endif
91
+ #ifndef __SSE3__
92
+ #define __SSE3__
93
+ #endif
94
+ #endif
95
+
82
96
  #ifdef __HAIKU__
83
97
  #define static_assert(cond, msg) _Static_assert(cond, msg)
84
98
  #endif
@@ -172,8 +186,13 @@ typedef double ggml_float;
172
186
 
173
187
  #ifdef __F16C__
174
188
 
189
+ #ifdef _MSC_VER
190
+ #define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
191
+ #define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
192
+ #else
175
193
  #define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
176
194
  #define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
195
+ #endif
177
196
 
178
197
  #elif defined(__POWER9_VECTOR__)
179
198
 
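The new _MSC_VER branch uses the F16C/SSE intrinsics directly because MSVC does not ship the GCC/Clang-style _cvtsh_ss/_cvtss_sh helpers. Below is a minimal round-trip sketch of the same conversions, assuming an x86-64 compiler with F16C enabled (e.g. -mf16c on GCC/Clang); the function names and main() are illustrative, not part of the package:

    #include <stdio.h>
    #include <stdint.h>
    #include <immintrin.h>

    // same conversions as the MSVC branch in the diff: one fp16 <-> fp32 round trip
    static float fp16_to_fp32(uint16_t h) {
        return _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(h)));
    }

    static uint16_t fp32_to_fp16(float f) {
        // rounding-mode immediate 0 = round to nearest, as used in the diff
        return (uint16_t) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(f), 0), 0);
    }

    int main(void) {
        const float x = 0.333333f;
        const uint16_t h = fp32_to_fp16(x);
        printf("%f -> 0x%04x -> %f\n", x, h, fp16_to_fp32(h));
        return 0;
    }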
@@ -443,6 +462,39 @@ static inline __m128i packNibbles( __m256i bytes )
443
462
  __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
444
463
  return _mm_packus_epi16( r0, r1 );
445
464
  }
465
+ #elif __AVX__
466
+ static inline __m128i bytesFromNibbles( const uint8_t* rsi )
467
+ {
468
+ // Load 8 bytes from memory
469
+ __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
470
+
471
+ // Expand bytes into uint16_t values
472
+ __m128i bytes = _mm_cvtepu8_epi16( tmp );
473
+
474
+ // Unpack values into individual bytes
475
+ const __m128i lowMask = _mm_set1_epi8( 0xF );
476
+ __m128i high = _mm_andnot_si128( lowMask, bytes );
477
+ __m128i low = _mm_and_si128( lowMask, bytes );
478
+ high = _mm_slli_epi16( high, 4 );
479
+ bytes = _mm_or_si128( low, high );
480
+ return bytes;
481
+ }
482
+
483
+ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
484
+ {
485
+ // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
486
+ const __m128i lowByte = _mm_set1_epi16( 0xFF );
487
+ __m128i high = _mm_andnot_si128( lowByte, bytes1 );
488
+ __m128i low = _mm_and_si128( lowByte, bytes1 );
489
+ high = _mm_srli_epi16( high, 4 );
490
+ bytes1 = _mm_or_si128( low, high );
491
+ high = _mm_andnot_si128( lowByte, bytes2 );
492
+ low = _mm_and_si128( lowByte, bytes2 );
493
+ high = _mm_srli_epi16( high, 4 );
494
+ bytes2 = _mm_or_si128( low, high );
495
+
496
+ return _mm_packus_epi16( bytes1, bytes2);
497
+ }
446
498
  #endif
447
499
 
448
500
  // method 5
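For reference, the SSE bytesFromNibbles/packNibbles pair added above performs the same transformation as this plain-C sketch (names are illustrative): each byte of a quantized block stores two 4-bit values, low nibble first.

    #include <stdint.h>
    #include <stdio.h>

    // Expand 8 packed bytes (16 nibbles) into 16 separate bytes, low nibble first.
    static void bytes_from_nibbles(const uint8_t * in, uint8_t * out) {
        for (int i = 0; i < 8; i++) {
            out[2*i + 0] = in[i] & 0x0F;   // low nibble
            out[2*i + 1] = in[i] >> 4;     // high nibble
        }
    }

    // Inverse: pack 16 bytes (each < 16) back into 8 bytes.
    static void pack_nibbles(const uint8_t * in, uint8_t * out) {
        for (int i = 0; i < 8; i++) {
            out[i] = (uint8_t)(in[2*i + 0] | (in[2*i + 1] << 4));
        }
    }

    int main(void) {
        const uint8_t packed[8] = { 0x21, 0x43, 0x65, 0x87, 0xA9, 0xCB, 0xED, 0x0F };
        uint8_t unpacked[16], repacked[8];
        bytes_from_nibbles(packed, unpacked);
        pack_nibbles(unpacked, repacked);
        for (int i = 0; i < 8; i++) printf("%02x ", repacked[i]); // prints the original bytes
        printf("\n");
        return 0;
    }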
@@ -491,8 +543,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
491
543
  const uint8_t vi0 = (int8_t)roundf(v0) + 8;
492
544
  const uint8_t vi1 = (int8_t)roundf(v1) + 8;
493
545
 
494
- assert(vi0 >= 0 && vi0 < 16);
495
- assert(vi1 >= 0 && vi1 < 16);
546
+ assert(vi0 < 16);
547
+ assert(vi1 < 16);
496
548
 
497
549
  pp[l/2] = vi0 | (vi1 << 4);
498
550
  }
@@ -546,10 +598,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
546
598
  }
547
599
  }
548
600
  #elif __ARM_NEON
549
- uint8_t pp[QK/2];
550
601
  for (int i = 0; i < nb; i++) {
551
- float amax = 0.0f; // absolute max
552
-
553
602
  float32x4_t srcv [8];
554
603
  float32x4_t asrcv[8];
555
604
  float32x4_t amaxv[8];
@@ -561,7 +610,8 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
561
610
  for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
562
611
  for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
563
612
 
564
- amax = MAX(
613
+ // absolute max
614
+ const float amax = MAX(
565
615
  MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)),
566
616
  MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
567
617
 
@@ -575,11 +625,9 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
575
625
  const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
576
626
  const int32x4_t vi = vcvtq_s32_f32(vf);
577
627
 
578
- pp[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
579
- pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
628
+ y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
629
+ y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
580
630
  }
581
-
582
- memcpy(y[i].qs, pp, sizeof(pp));
583
631
  }
584
632
  #elif defined(__AVX2__)
585
633
  for (int i = 0; i < nb; i++) {
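The NEON path above scales each value into roughly [-8, 7], then adds 8.5f and converts with vcvtq_s32_f32. Because that conversion truncates toward zero, adding 8.5 applies the +8 offset and rounds to the nearest integer (halves round up) in a single step. A scalar sketch of the same trick, assuming the scaled input stays within the block's range:

    #include <stdio.h>

    int main(void) {
        // v is a value already scaled by id = 7/amax, so it lies in [-8, 7]
        const float samples[] = { -7.9f, -3.5f, -0.2f, 0.49f, 3.5f, 6.99f };
        for (int i = 0; i < 6; i++) {
            const float v = samples[i];
            const int   q = (int)(v + 8.5f);   // offset + round-to-nearest in one step
            printf("v = %6.2f  ->  q = %2d  (stored nibble, 0..15)\n", v, q);
        }
        return 0;
    }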
@@ -646,8 +694,81 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
646
694
  __m128i res = packNibbles( i0 );
647
695
  _mm_storeu_si128( ( __m128i* )y[i].qs, res );
648
696
  }
697
+ #elif defined(__AVX__)
698
+ for (int i = 0; i < nb; i++) {
699
+ // Load elements into 4 AVX vectors
700
+ __m256 v0 = _mm256_loadu_ps( x );
701
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
702
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
703
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
704
+ x += 32;
705
+
706
+ // Compute max(abs(e)) for the block
707
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
708
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
709
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
710
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
711
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
712
+
713
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
714
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
715
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
716
+ const float maxScalar = _mm_cvtss_f32( max4 );
717
+
718
+ // Quantize these floats
719
+ const float d = maxScalar / 7.0f;
720
+ y[i].d = d;
721
+ const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
722
+ const __m256 mul = _mm256_set1_ps( id );
723
+
724
+ // Apply the multiplier
725
+ v0 = _mm256_mul_ps( v0, mul );
726
+ v1 = _mm256_mul_ps( v1, mul );
727
+ v2 = _mm256_mul_ps( v2, mul );
728
+ v3 = _mm256_mul_ps( v3, mul );
729
+
730
+ // Round to nearest integer
731
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
732
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
733
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
734
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
735
+
736
+ // Convert floats to integers
737
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
738
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
739
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
740
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
741
+
742
+ // Since AVX lacks some of the integer instructions we need,
743
+ // we split the registers in half and use the SSE counterparts of the AVX2 intrinsics
744
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
745
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
746
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
747
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
748
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
749
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
750
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
751
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
752
+
753
+ // Convert int32 to int16
754
+ ni0 = _mm_packs_epi32( ni0, ni1 );
755
+ ni2 = _mm_packs_epi32( ni2, ni3 );
756
+ ni4 = _mm_packs_epi32( ni4, ni5 );
757
+ ni6 = _mm_packs_epi32( ni6, ni7 );
758
+ // Convert int16 to int8
759
+ ni0 = _mm_packs_epi16( ni0, ni2 );
760
+ ni4 = _mm_packs_epi16( ni4, ni6 );
761
+
762
+ // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
763
+ const __m128i off = _mm_set1_epi8( 8);
764
+ ni0 = _mm_add_epi8( ni0, off );
765
+ ni4 = _mm_add_epi8( ni4, off );
766
+
767
+ // Compress the vector into 4 bit/value, and store
768
+ __m128i res = packNibbles( ni0, ni4 );
769
+ _mm_storeu_si128( ( __m128i* )y[i].qs, res );
770
+ }
649
771
  #elif defined(__wasm_simd128__)
650
- uint8_t pp[QK/2];
651
772
  for (int i = 0; i < nb; i++) {
652
773
  float amax = 0.0f; // absolute max
653
774
 
@@ -676,11 +797,9 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
676
797
  const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
677
798
  const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
678
799
 
679
- pp[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
680
- pp[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
800
+ y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
801
+ y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
681
802
  }
682
-
683
- memcpy(y[i].qs, pp, sizeof(pp));
684
803
  }
685
804
  #else
686
805
  // scalar
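All of the SIMD branches here (NEON, AVX2, and the new AVX path) implement the same per-block scheme as the scalar reference shown earlier in the diff: find the absolute maximum of 32 values, derive the scale d = amax/7, quantize each value to 4 bits with an offset of 8, and pack pairs of nibbles. A compact stand-alone sketch of one block, with hypothetical names rather than the package's own functions:

    #include <stdint.h>
    #include <math.h>
    #include <stdio.h>

    #define QK 32  // values per block, as in ggml

    // One q4_0 block: a single scale d plus 32 4-bit values packed into 16 bytes.
    static void quantize_block_q4_0(const float * x, float * d_out, uint8_t qs[QK/2]) {
        float amax = 0.0f;                       // absolute max of the block
        for (int l = 0; l < QK; l++) {
            const float a = fabsf(x[l]);
            if (a > amax) amax = a;
        }

        const float d  = amax / 7.0f;            // scale: quantized values land in [-7, 7]
        const float id = d ? 1.0f/d : 0.0f;      // inverse scale (0 for an all-zero block)
        *d_out = d;

        for (int l = 0; l < QK; l += 2) {
            const uint8_t v0 = (uint8_t)((int8_t)roundf(x[l + 0]*id) + 8);  // offset into [1, 15]
            const uint8_t v1 = (uint8_t)((int8_t)roundf(x[l + 1]*id) + 8);
            qs[l/2] = v0 | (v1 << 4);            // two nibbles per byte, low nibble first
        }
    }

    int main(void) {
        float x[QK], d;
        uint8_t qs[QK/2];
        for (int l = 0; l < QK; l++) x[l] = 0.1f*(l - 16);  // sample data
        quantize_block_q4_0(x, &d, qs);
        printf("d = %f, first byte = 0x%02x\n", d, qs[0]);
        return 0;
    }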
@@ -719,8 +838,8 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
719
838
  const uint8_t vi0 = roundf(v0);
720
839
  const uint8_t vi1 = roundf(v1);
721
840
 
722
- assert(vi0 >= 0 && vi0 < 16);
723
- assert(vi1 >= 0 && vi1 < 16);
841
+ assert(vi0 < 16);
842
+ assert(vi1 < 16);
724
843
 
725
844
  pp[l/2] = vi0 | (vi1 << 4);
726
845
  }
@@ -732,11 +851,11 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
732
851
  static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
733
852
  assert(k % QK == 0);
734
853
 
735
- #if defined(__AVX2__)
736
854
  const int nb = k / QK;
737
855
 
738
856
  block_q4_1 * restrict y = vy;
739
857
 
858
+ #if defined(__AVX2__)
740
859
  for (int i = 0; i < nb; i++) {
741
860
  // Load elements into 4 AVX vectors
742
861
  __m256 v0 = _mm256_loadu_ps( x );
@@ -810,6 +929,41 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
810
929
  __m128i res = packNibbles( i0 );
811
930
  _mm_storeu_si128( ( __m128i* )y[i].qs, res );
812
931
  }
932
+ #elif __ARM_NEON
933
+ for (int i = 0; i < nb; i++) {
934
+ float32x4_t srcv[8];
935
+ float32x4_t minv[8];
936
+ float32x4_t maxv[8];
937
+
938
+ for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
939
+
940
+ for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
941
+ for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
942
+ for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l + 4]);
943
+
944
+ for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l + 1]);
945
+ for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l + 2]);
946
+ for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l + 4]);
947
+
948
+ const float min = vminvq_f32(minv[0]);
949
+ const float max = vmaxvq_f32(maxv[0]);
950
+
951
+ const float d = (max - min) / ((1 << 4) - 1);
952
+ const float id = d ? 1.0f/d : 0.0f;
953
+
954
+ y[i].d = d;
955
+ y[i].m = min;
956
+
957
+ const float32x4_t minv0 = vdupq_n_f32(min);
958
+
959
+ for (int l = 0; l < 8; l++) {
960
+ const float32x4_t v = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id);
961
+ const int32x4_t vi = vcvtq_s32_f32(v);
962
+
963
+ y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
964
+ y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
965
+ }
966
+ }
813
967
  #else
814
968
  // scalar
815
969
  quantize_row_q4_1_reference(x, vy, k);
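The new NEON branch follows the same q4_1 scheme as the reference path: per block of 32 values, store the minimum m and the step d = (max - min)/15, then quantize each value as (x - m)/d into [0, 15]. A scalar sketch of one block (illustrative names, truncation as in the NEON path):

    #include <stdint.h>
    #include <stdio.h>

    #define QK 32  // values per block

    // One q4_1 block: scale d, offset m, and 32 4-bit values packed into 16 bytes.
    static void quantize_block_q4_1(const float * x, float * d_out, float * m_out, uint8_t qs[QK/2]) {
        float min = x[0];
        float max = x[0];
        for (int l = 1; l < QK; l++) {
            if (x[l] < min) min = x[l];
            if (x[l] > max) max = x[l];
        }

        const float d  = (max - min) / ((1 << 4) - 1);  // 15 steps between min and max
        const float id = d ? 1.0f/d : 0.0f;
        *d_out = d;
        *m_out = min;

        for (int l = 0; l < QK; l += 2) {
            const uint8_t v0 = (uint8_t)((x[l + 0] - min)*id);  // truncate into [0, 15]
            const uint8_t v1 = (uint8_t)((x[l + 1] - min)*id);
            qs[l/2] = v0 | (v1 << 4);
        }
    }

    int main(void) {
        float x[QK], d, m;
        uint8_t qs[QK/2];
        for (int l = 0; l < QK; l++) x[l] = 1.0f + 0.05f*l;
        quantize_block_q4_1(x, &d, &m, qs);
        printf("d = %f, m = %f, first byte = 0x%02x\n", d, m, qs[0]);
        return 0;
    }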
@@ -970,6 +1124,50 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
970
1124
  }
971
1125
  }
972
1126
  }
1127
+ #elif defined(__ARM_NEON)
1128
+ for (int i = 0; i < nb; i++) {
1129
+ const float32x4_t vd = vdupq_n_f32(x[i].d);
1130
+ const float32x4_t vm = vdupq_n_f32(x[i].m);
1131
+
1132
+ const uint8_t * restrict pp = x[i].qs;
1133
+
1134
+ for (int l = 0; l < QK; l += 16) {
1135
+ // Load 16x4-bit integers into 8x8-bit integers
1136
+ const uint8x8_t v8 = vld1_u8(pp + l/2);
1137
+
1138
+ // Expand 4-bit qs to 8-bit bytes
1139
+ const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
1140
+ const uint8x8_t v1 = vshr_n_u8(v8, 4);
1141
+
1142
+ // Interleave and combine
1143
+ const uint8x8_t vx_0 = vzip1_u8(v0, v1);
1144
+ const uint8x8_t vx_1 = vzip2_u8(v0, v1);
1145
+
1146
+ const uint8x16_t vq = vcombine_u8(vx_0, vx_1);
1147
+
1148
+ // convert to 2x uint16x8_t
1149
+ const uint16x8_t vi_0 = vmovl_u8(vget_low_u8 (vq));
1150
+ const uint16x8_t vi_1 = vmovl_u8(vget_high_u8(vq));
1151
+
1152
+ // convert to 4x float32x4_t
1153
+ const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0)));
1154
+ const float32x4_t vf_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_0)));
1155
+ const float32x4_t vf_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_1)));
1156
+ const float32x4_t vf_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_1)));
1157
+
1158
+ // multiply by d and add m
1159
+ const float32x4_t r0 = vmlaq_f32(vm, vf_0, vd);
1160
+ const float32x4_t r1 = vmlaq_f32(vm, vf_1, vd);
1161
+ const float32x4_t r2 = vmlaq_f32(vm, vf_2, vd);
1162
+ const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
1163
+
1164
+ // Store
1165
+ vst1q_f32(y + i*QK + l + 0, r0);
1166
+ vst1q_f32(y + i*QK + l + 4, r1);
1167
+ vst1q_f32(y + i*QK + l + 8, r2);
1168
+ vst1q_f32(y + i*QK + l + 12, r3);
1169
+ }
1170
+ }
973
1171
  #else
974
1172
  for (int i = 0; i < nb; i++) {
975
1173
  const float d = x[i].d;
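The NEON dequantization above is the inverse of the q4_1 quantizer: unpack each nibble q and compute y = d*q + m. The same operation in scalar form (an illustrative sketch, block layout as in the diff):

    #include <stdint.h>
    #include <stdio.h>

    #define QK 32

    // Reconstruct 32 floats from one q4_1 block (scale d, offset m, 16 packed bytes).
    static void dequantize_block_q4_1(float d, float m, const uint8_t qs[QK/2], float * y) {
        for (int l = 0; l < QK; l += 2) {
            const uint8_t b  = qs[l/2];
            const uint8_t q0 = b & 0x0F;   // low nibble
            const uint8_t q1 = b >> 4;     // high nibble
            y[l + 0] = d*q0 + m;
            y[l + 1] = d*q1 + m;
        }
    }

    int main(void) {
        uint8_t qs[QK/2];
        float   y[QK];
        for (int i = 0; i < QK/2; i++) qs[i] = (uint8_t)(i | ((15 - i) << 4));
        dequantize_block_q4_1(0.1f, -0.5f, qs, y);
        printf("y[0] = %f, y[1] = %f\n", y[0], y[1]);
        return 0;
    }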
@@ -1207,7 +1405,7 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
1207
1405
  _mm256_storeu_ps(arr, y);
1208
1406
 
1209
1407
  for (int i = 0; i < 8; i++)
1210
- x[i] = GGML_FP16_TO_FP32(arr[i]);
1408
+ x[i] = GGML_FP32_TO_FP16(arr[i]);
1211
1409
  }
1212
1410
  #define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x)
1213
1411
  #define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
@@ -1636,7 +1834,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1636
1834
  const block_q4_0 * restrict x = vx;
1637
1835
  const block_q4_0 * restrict y = vy;
1638
1836
 
1639
- ggml_float sumf = 0.0;
1837
+ float sumf = 0.0;
1640
1838
 
1641
1839
  #if defined(__ARM_NEON)
1642
1840
  float sum0 = 0.0f;
@@ -1731,7 +1929,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1731
1929
  #endif
1732
1930
  }
1733
1931
 
1734
- sumf = (ggml_float)(sum0 + sum1);
1932
+ sumf = sum0 + sum1;
1735
1933
  #elif defined(__AVX512F__)
1736
1934
  // Initialize accumulator with zeros
1737
1935
  __m512 acc0 = _mm512_setzero_ps();
@@ -1739,7 +1937,6 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1739
1937
 
1740
1938
  const int superblock_size = 8;
1741
1939
  const int superblock_count = nb / superblock_size;
1742
- const int remainder = nb % superblock_size;
1743
1940
 
1744
1941
  for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
1745
1942
  int i = superblock_ix * superblock_size;
@@ -1765,36 +1962,116 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1765
1962
  // Initialize accumulator with zeros
1766
1963
  __m256 acc = _mm256_setzero_ps();
1767
1964
 
1965
+ /* Prepare the constants we will need during execution */
1966
+ const __m256i lowMask = _mm256_set1_epi8( 0xF );
1967
+ const __m256i offset_8 = _mm256_set1_epi16( 8 );
1968
+
1969
+ #define UNROLL_COUNT 8
1970
+ // make sure the block count is a multiple of the unroll count
1971
+ assert(nb % UNROLL_COUNT == 0);
1972
+
1973
+ // Main loop
1974
+ for (int i = 0; i < nb; i+=UNROLL_COUNT) {
1975
+
1976
+ // This loop will be unrolled by the compiler
1977
+ for (int u=0;u<UNROLL_COUNT;u++) {
1978
+ /* Compute combined scale for the block */
1979
+ const __m256 scale = _mm256_mul_ps(
1980
+ _mm256_broadcast_ss( &x[i+u].d ),
1981
+ _mm256_broadcast_ss( &y[i+u].d ) );
1982
+
1983
+ /* get input from x
1984
+ Input: 32 Nibbles (16 bytes) at *x[i+u]
1985
+ Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
1986
+
1987
+ /* Load 16 bytes from memory */
1988
+ const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
1989
+ /* Expand bytes into uint16_t values */
1990
+ const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
1991
+ /* Unpack values into individual bytes */
1992
+ __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
1993
+ const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
1994
+ __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
1995
+ /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
1996
+ x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
1997
+ x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
1998
+
1999
+ /* get input from y
2000
+ Input: 32 Nibbles (16 bytes) at *y[i+u]
2001
+ Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
2002
+
2003
+ /* Load 16 bytes from memory */
2004
+ const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
2005
+ /* Expand bytes into uint16_t values */
2006
+ const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
2007
+ /* Unpack values into individual bytes */
2008
+ const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
2009
+ __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
2010
+ __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
2011
+ /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
2012
+ y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
2013
+ y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
2014
+
2015
+ /* Compute products of int16_t integers, add pairwise, store as int32_t */
2016
+ __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
2017
+ __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
2018
+
2019
+ /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
2020
+ __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
2021
+
2022
+ /* Convert the vector of 8 int32_t values to 8 floats */
2023
+ __m256 q = _mm256_cvtepi32_ps( xy_q );
2024
+
2025
+ /* Multiply q with scale and accumulate */
2026
+ acc = _mm256_fmadd_ps( scale, q, acc );
2027
+ }
2028
+
2029
+ }
2030
+
2031
+ // Return horizontal sum of the acc vector
2032
+ __m128 res = _mm256_extractf128_ps( acc, 1 );
2033
+ res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2034
+ res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2035
+ res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2036
+
2037
+ sumf = _mm_cvtss_f32( res );
2038
+ #elif defined(__AVX__)
2039
+ // Initialize accumulator with zeros
2040
+ __m256 acc = _mm256_setzero_ps();
2041
+
1768
2042
  // Main loop
1769
2043
  for (int i = 0; i < nb; ++i) {
1770
2044
  // Compute combined scale for the block
1771
2045
  const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
1772
2046
 
1773
- // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
1774
- __m256i bx = bytesFromNibbles( x[i].qs );
1775
- __m256i by = bytesFromNibbles( y[i].qs );
2047
+ __m128i i32[2];
2048
+ for (int j = 0; j < 2; ++j) {
2049
+ // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2050
+ __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
2051
+ __m128i by = bytesFromNibbles( y[i].qs + 8*j );
1776
2052
 
1777
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
1778
- const __m256i off = _mm256_set1_epi8( 8 );
1779
- bx = _mm256_sub_epi8( bx, off );
1780
- by = _mm256_sub_epi8( by, off );
2053
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2054
+ const __m128i off = _mm_set1_epi8( 8 );
2055
+ bx = _mm_sub_epi8( bx, off );
2056
+ by = _mm_sub_epi8( by, off );
1781
2057
 
1782
- // Sign-extend first 16 signed bytes into int16_t
1783
- __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
1784
- __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
1785
- // Compute products of int16_t integers, add pairwise
1786
- __m256i i32 = _mm256_madd_epi16( x16, y16 );
2058
+ // Get absolute values of x vectors
2059
+ const __m128i ax = _mm_sign_epi8(bx, bx);
1787
2060
 
1788
- // Sign-extend last 16 signed bytes into int16_t vectors
1789
- x16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
1790
- y16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
1791
- // Accumulate products of int16_t integers
1792
- i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16, y16 ) );
2061
+ // Sign the values of the y vectors
2062
+ const __m128i sy = _mm_sign_epi8(by, bx);
2063
+
2064
+ // Perform multiplication and create 16-bit values
2065
+ const __m128i dot = _mm_maddubs_epi16(ax, sy);
2066
+
2067
+ const __m128i ones = _mm_set1_epi16(1);
2068
+ i32[j] = _mm_madd_epi16(ones, dot);
2069
+ }
1793
2070
 
1794
2071
  // Convert int32_t to float
1795
- __m256 p = _mm256_cvtepi32_ps( i32 );
2072
+ __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
1796
2073
  // Apply the scale, and accumulate
1797
- acc = _mm256_fmadd_ps( d, p, acc );
2074
+ acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
1798
2075
  }
1799
2076
 
1800
2077
  // Return horizontal sum of the acc vector
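Both new branches compute the same per-block result as the scalar reference: for each pair of blocks, sum (x_q - 8)*(y_q - 8) over the 32 quantized values and multiply by the combined scale d_x*d_y. The AVX path additionally relies on the identity a*b == |a| * (b with a's sign applied), which is what the _mm_sign_epi8/_mm_maddubs_epi16 pair implements, since _mm_maddubs_epi16 requires its first operand to be unsigned. A scalar sketch of one block pair (illustrative names):

    #include <stdint.h>
    #include <stdio.h>

    #define QK 32

    // Dot product of two q4_0 blocks: combined scale times the sum of products of
    // the offset-corrected 4-bit values.
    static float vec_dot_block_q4_0(float dx, const uint8_t * xs,
                                    float dy, const uint8_t * ys) {
        int sum = 0;
        for (int j = 0; j < QK/2; j++) {
            const int x0 = (xs[j] & 0x0F) - 8;   // low nibble, back into [-8, 7]
            const int x1 = (xs[j] >>   4) - 8;   // high nibble
            const int y0 = (ys[j] & 0x0F) - 8;
            const int y1 = (ys[j] >>   4) - 8;
            sum += x0*y0 + x1*y1;
        }
        return dx*dy*(float)sum;
    }

    int main(void) {
        uint8_t xs[QK/2], ys[QK/2];
        for (int j = 0; j < QK/2; j++) { xs[j] = 0x9C; ys[j] = 0x37; }
        printf("block dot = %f\n", vec_dot_block_q4_0(0.5f, xs, 0.25f, ys));
        return 0;
    }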
@@ -1944,7 +2221,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
1944
2221
  // Compute cross scales for the block
1945
2222
  const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
1946
2223
  const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
1947
- const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0b10101010 );
2224
+ const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
1948
2225
 
1949
2226
  // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
1950
2227
  __m256i bx = bytesFromNibbles( x[i].qs );
@@ -1990,6 +2267,45 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
1990
2267
  res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
1991
2268
 
1992
2269
  sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
2270
+ #elif defined(__ARM_NEON)
2271
+ float sum00 = 0.0f;
2272
+ float sum01 = 0.0f;
2273
+ float sum10 = 0.0f;
2274
+ float sum11 = 0.0f;
2275
+
2276
+ for (int i = 0; i < nb; ++i) {
2277
+ const block_q4_1 * restrict x0 = &x[i + 0];
2278
+ const block_q4_1 * restrict y0 = &y[i + 0];
2279
+
2280
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2281
+
2282
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2283
+ const uint8x16_t v1_0 = vld1q_u8(y0->qs);
2284
+
2285
+ // and with 0xf
2286
+ const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
2287
+ const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
2288
+
2289
+ const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
2290
+ const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
2291
+
2292
+ // dot product into uint16x8_t
2293
+ const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
2294
+ const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
2295
+
2296
+ const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
2297
+ const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
2298
+
2299
+ const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
2300
+ const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
2301
+
2302
+ sum00 += x0->m*y0->m;
2303
+ sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
2304
+ sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
2305
+ sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
2306
+ }
2307
+
2308
+ sumf = QK*sum00 + sum01 + sum10 + sum11;
1993
2309
  #else
1994
2310
  // scalar
1995
2311
  for (int i = 0; i < nb; i++) {
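The four partial sums in the NEON branch come from expanding the product of two affinely quantized values: with x = d0*q0 + m0 and y = d1*q1 + m1, summing over the QK values of a block gives Σxy = d0*d1*Σq0*q1 + d0*m1*Σq0 + m0*d1*Σq1 + QK*m0*m1, i.e. sum11 + sum01 + sum10 + QK*sum00 as accumulated above. A scalar sketch of the decomposition for one block (illustrative values):

    #include <stdint.h>
    #include <stdio.h>

    #define QK 32

    // Verify that the factored q4_1 dot product matches the direct one for a block.
    int main(void) {
        const float d0 = 0.10f, m0 = -0.5f;   // block x: scale and offset
        const float d1 = 0.25f, m1 =  1.0f;   // block y: scale and offset
        uint8_t q0[QK], q1[QK];
        for (int l = 0; l < QK; l++) { q0[l] = (uint8_t)(l % 16); q1[l] = (uint8_t)((l*7) % 16); }

        // direct: sum of (d0*q0 + m0)*(d1*q1 + m1)
        float direct = 0.0f;
        // factored partial sums, as accumulated in the NEON branch
        float s_qq = 0.0f, s_q0 = 0.0f, s_q1 = 0.0f;
        for (int l = 0; l < QK; l++) {
            direct += (d0*q0[l] + m0)*(d1*q1[l] + m1);
            s_qq   += (float)(q0[l]*q1[l]);
            s_q0   += (float)q0[l];
            s_q1   += (float)q1[l];
        }
        const float factored = d0*d1*s_qq + d0*m1*s_q0 + m0*d1*s_q1 + QK*m0*m1;

        printf("direct = %f, factored = %f\n", direct, factored);
        return 0;
    }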
@@ -2401,8 +2717,9 @@ struct ggml_context {
2401
2717
  void * mem_buffer;
2402
2718
  bool mem_buffer_owned;
2403
2719
  bool mem_buffer_mlocked;
2720
+ bool no_alloc;
2404
2721
 
2405
- int n_objects;
2722
+ int n_objects;
2406
2723
 
2407
2724
  struct ggml_object * objects_begin;
2408
2725
  struct ggml_object * objects_end;
@@ -2487,7 +2804,7 @@ void ggml_print_objects(const struct ggml_context * ctx) {
2487
2804
  GGML_PRINT("%s: --- end ---\n", __func__);
2488
2805
  }
2489
2806
 
2490
- int ggml_nelements(const struct ggml_tensor * tensor) {
2807
+ int64_t ggml_nelements(const struct ggml_tensor * tensor) {
2491
2808
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
2492
2809
 
2493
2810
  return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
@@ -2619,6 +2936,9 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
2619
2936
  static bool is_first_call = true;
2620
2937
 
2621
2938
  if (is_first_call) {
2939
+ // initialize time system (required on Windows)
2940
+ ggml_time_init();
2941
+
2622
2942
  // initialize GELU, SILU and EXP F32 tables
2623
2943
  {
2624
2944
  const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
@@ -2684,6 +3004,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
2684
3004
  /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
2685
3005
  /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
2686
3006
  /*.mem_buffer_mlocked =*/ false,
3007
+ /*.no_alloc =*/ params.no_alloc,
2687
3008
  /*.n_objects =*/ 0,
2688
3009
  /*.objects_begin =*/ NULL,
2689
3010
  /*.objects_end =*/ NULL,
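The new no_alloc flag is threaded from ggml_init_params into the context so that ggml_new_tensor_impl (further down in this diff) can skip reserving data space, which is useful when tensor data lives elsewhere, for example in a memory-mapped model file. A hedged usage sketch; the header name and the exact set of ggml_init_params fields beyond the three shown in this diff are assumptions:

    #include <stdbool.h>
    #include "ggml.h"   // header name assumed; shipped alongside ggml.c in this package

    // Build tensor metadata only: with no_alloc set, tensors created from this context
    // get no data buffer, so the caller is expected to point tensor->data elsewhere.
    struct ggml_context * make_metadata_ctx(void) {
        struct ggml_init_params params = {
            .mem_size   = 16*1024*1024,  // scratch for object/tensor headers only
            .mem_buffer = NULL,          // let ggml allocate the scratch buffer itself
            .no_alloc   = true,          // new in this version: skip tensor data allocation
        };
        return ggml_init(params);
    }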
@@ -2751,36 +3072,47 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch)
2751
3072
  return result;
2752
3073
  }
2753
3074
 
3075
+ #ifdef __APPLE__
3076
+ #define MLOCK_SUGGESTION \
3077
+ "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
3078
+ "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n"
3079
+ #else
3080
+ #define MLOCK_SUGGESTION \
3081
+ "Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n"
3082
+ #endif
3083
+
2754
3084
  bool ggml_mlock_supported(void) {
2755
3085
  return GGML_MLOCK_SUPPORT;
2756
3086
  }
2757
3087
 
3088
+ bool ggml_mlock(
3089
+ struct ggml_context * ctx,
3090
+ const void *opt_extra_addr,
3091
+ size_t opt_extra_len,
3092
+ char **err_p) {
3093
+ // TODO: Use SetProcessWorkingSetSize() + VirtualLock() on WIN32
2758
3094
  #if GGML_MLOCK_SUPPORT
2759
- #ifdef __APPLE__
2760
- #define MLOCK_SUGGESTION "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or\n" \
2761
- "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l)."
2762
- #else
2763
- #define MLOCK_SUGGESTION "Try increasing RLIMIT_MLOCK (ulimit -l)."
2764
- #endif
2765
- bool ggml_mlock(struct ggml_context * ctx, char ** err_p) {
2766
3095
  if (ctx->mem_buffer_mlocked) {
2767
3096
  return true;
2768
3097
  }
2769
- if (mlock(ctx->mem_buffer, ctx->mem_size)) {
2770
- int ret = asprintf(err_p, "failed to mlock %zu-byte buffer: %s\n" MLOCK_SUGGESTION,
2771
- ctx->mem_size, strerror(errno));
2772
- GGML_ASSERT(ret >= 0);
3098
+ if (mlock(ctx->mem_buffer, ctx->mem_size) ||
3099
+ (opt_extra_len &&
3100
+ mlock(opt_extra_addr, opt_extra_len))) {
3101
+ if ((*err_p = malloc(1024))) {
3102
+ snprintf(*err_p, 1024,
3103
+ "failed to mlock %zu-byte buffer: %s\n" MLOCK_SUGGESTION,
3104
+ ctx->mem_size + opt_extra_len,
3105
+ strerror(errno));
3106
+ }
2773
3107
  return false;
2774
3108
  }
2775
3109
  ctx->mem_buffer_mlocked = true;
2776
3110
  return true;
2777
- }
2778
3111
  #else // GGML_MLOCK_SUPPORT
2779
- bool ggml_mlock(struct ggml_context * ctx, char ** err_p) {
2780
3112
  *err_p = strdup("can't mlock because it's not supported on this system");
2781
3113
  return false;
2782
- }
2783
3114
  #endif // GGML_MLOCK_SUPPORT
3115
+ }
2784
3116
 
2785
3117
  ////////////////////////////////////////////////////////////////////////////////
2786
3118
 
@@ -2788,7 +3120,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
2788
3120
  struct ggml_context * ctx,
2789
3121
  enum ggml_type type,
2790
3122
  int n_dims,
2791
- const int* ne,
3123
+ const int64_t* ne,
2792
3124
  void* data) {
2793
3125
  // always insert objects at the end of the context's memory pool
2794
3126
  struct ggml_object * obj_cur = ctx->objects_end;
@@ -2799,7 +3131,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
2799
3131
 
2800
3132
  size_t size_needed = 0;
2801
3133
 
2802
- if (data == NULL) {
3134
+ if (data == NULL && !ctx->no_alloc) {
2803
3135
  size_needed += GGML_TYPE_SIZE[type]*(ne[0]/GGML_BLCK_SIZE[type]);
2804
3136
  for (int i = 1; i < n_dims; i++) {
2805
3137
  size_needed *= ne[i];
@@ -2883,11 +3215,12 @@ struct ggml_tensor * ggml_new_tensor_impl(
2883
3215
  /*.perf_runs =*/ 0,
2884
3216
  /*.perf_cycles =*/ 0,
2885
3217
  /*.perf_time_us =*/ 0,
2886
- /*.data =*/ data == NULL ? (void *)(result + 1) : data,
3218
+ /*.data =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
2887
3219
  /*.pad =*/ { 0 },
2888
3220
  };
2889
3221
 
2890
- ggml_assert_aligned(result->data);
3222
+ // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
3223
+ //ggml_assert_aligned(result->data);
2891
3224
 
2892
3225
  for (int i = 0; i < n_dims; i++) {
2893
3226
  result->ne[i] = ne[i];
@@ -2908,44 +3241,44 @@ struct ggml_tensor * ggml_new_tensor(
2908
3241
  struct ggml_context * ctx,
2909
3242
  enum ggml_type type,
2910
3243
  int n_dims,
2911
- const int * ne) {
3244
+ const int64_t * ne) {
2912
3245
  return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL);
2913
3246
  }
2914
3247
 
2915
3248
  struct ggml_tensor * ggml_new_tensor_1d(
2916
3249
  struct ggml_context * ctx,
2917
3250
  enum ggml_type type,
2918
- int ne0) {
3251
+ int64_t ne0) {
2919
3252
  return ggml_new_tensor(ctx, type, 1, &ne0);
2920
3253
  }
2921
3254
 
2922
3255
  struct ggml_tensor * ggml_new_tensor_2d(
2923
3256
  struct ggml_context * ctx,
2924
3257
  enum ggml_type type,
2925
- int ne0,
2926
- int ne1) {
2927
- const int ne[2] = { ne0, ne1 };
3258
+ int64_t ne0,
3259
+ int64_t ne1) {
3260
+ const int64_t ne[2] = { ne0, ne1 };
2928
3261
  return ggml_new_tensor(ctx, type, 2, ne);
2929
3262
  }
2930
3263
 
2931
3264
  struct ggml_tensor * ggml_new_tensor_3d(
2932
3265
  struct ggml_context * ctx,
2933
3266
  enum ggml_type type,
2934
- int ne0,
2935
- int ne1,
2936
- int ne2) {
2937
- const int ne[3] = { ne0, ne1, ne2 };
3267
+ int64_t ne0,
3268
+ int64_t ne1,
3269
+ int64_t ne2) {
3270
+ const int64_t ne[3] = { ne0, ne1, ne2 };
2938
3271
  return ggml_new_tensor(ctx, type, 3, ne);
2939
3272
  }
2940
3273
 
2941
3274
  struct ggml_tensor * ggml_new_tensor_4d(
2942
3275
  struct ggml_context * ctx,
2943
3276
  enum ggml_type type,
2944
- int ne0,
2945
- int ne1,
2946
- int ne2,
2947
- int ne3) {
2948
- const int ne[4] = { ne0, ne1, ne2, ne3 };
3277
+ int64_t ne0,
3278
+ int64_t ne1,
3279
+ int64_t ne2,
3280
+ int64_t ne3) {
3281
+ const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
2949
3282
  return ggml_new_tensor(ctx, type, 4, ne);
2950
3283
  }
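Widening the ne arguments and arrays from int to int64_t matters because the element count of a large tensor can overflow 32-bit arithmetic even when each individual dimension fits in an int; ggml_nelements is widened to int64_t in the same change. A small sketch of the failure mode being avoided:

    #include <inttypes.h>   // PRId64 -- the same header this diff adds to ggml.c
    #include <stdio.h>

    int main(void) {
        // A plausible large matrix: 65536 x 65536 elements.
        const int64_t ne0 = 65536, ne1 = 65536;
        const int64_t n   = ne0*ne1;   // 4294967296 elements: does not fit in a 32-bit int

        printf("elements: %" PRId64 " (INT32_MAX is %" PRId32 ")\n", n, INT32_MAX);
        return 0;
    }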
2951
3284
 
@@ -3288,7 +3621,14 @@ float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
3288
3621
  struct ggml_tensor * ggml_view_tensor(
3289
3622
  struct ggml_context * ctx,
3290
3623
  const struct ggml_tensor * src) {
3291
- return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data);
3624
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data);
3625
+
3626
+ result->nb[0] = src->nb[0];
3627
+ result->nb[1] = src->nb[1];
3628
+ result->nb[2] = src->nb[2];
3629
+ result->nb[3] = src->nb[3];
3630
+
3631
+ return result;
3292
3632
  }
3293
3633
 
3294
3634
  ////////////////////////////////////////////////////////////////////////////////
@@ -3592,7 +3932,7 @@ struct ggml_tensor * ggml_mean(
3592
3932
  is_node = true;
3593
3933
  }
3594
3934
 
3595
- int ne[GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] };
3935
+ int64_t ne[GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] };
3596
3936
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, ne);
3597
3937
 
3598
3938
  result->op = GGML_OP_MEAN;
@@ -3953,7 +4293,7 @@ struct ggml_tensor * ggml_mul_mat(
3953
4293
  is_node = true;
3954
4294
  }
3955
4295
 
3956
- const int ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
4296
+ const int64_t ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
3957
4297
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
3958
4298
 
3959
4299
  result->op = GGML_OP_MUL_MAT;
@@ -4078,8 +4418,8 @@ struct ggml_tensor * ggml_reshape(
4078
4418
  struct ggml_tensor * ggml_reshape_2d(
4079
4419
  struct ggml_context * ctx,
4080
4420
  struct ggml_tensor * a,
4081
- int ne0,
4082
- int ne1) {
4421
+ int64_t ne0,
4422
+ int64_t ne1) {
4083
4423
  GGML_ASSERT(ggml_is_contiguous(a));
4084
4424
  GGML_ASSERT(ggml_nelements(a) == ne0*ne1);
4085
4425
 
@@ -4090,7 +4430,7 @@ struct ggml_tensor * ggml_reshape_2d(
4090
4430
  is_node = true;
4091
4431
  }
4092
4432
 
4093
- const int ne[2] = { ne0, ne1 };
4433
+ const int64_t ne[2] = { ne0, ne1 };
4094
4434
  struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a->data);
4095
4435
 
4096
4436
  result->op = GGML_OP_RESHAPE;
@@ -4104,9 +4444,9 @@ struct ggml_tensor * ggml_reshape_2d(
4104
4444
  struct ggml_tensor * ggml_reshape_3d(
4105
4445
  struct ggml_context * ctx,
4106
4446
  struct ggml_tensor * a,
4107
- int ne0,
4108
- int ne1,
4109
- int ne2) {
4447
+ int64_t ne0,
4448
+ int64_t ne1,
4449
+ int64_t ne2) {
4110
4450
  GGML_ASSERT(ggml_is_contiguous(a));
4111
4451
  GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2);
4112
4452
 
@@ -4117,7 +4457,7 @@ struct ggml_tensor * ggml_reshape_3d(
4117
4457
  is_node = true;
4118
4458
  }
4119
4459
 
4120
- const int ne[3] = { ne0, ne1, ne2 };
4460
+ const int64_t ne[3] = { ne0, ne1, ne2 };
4121
4461
  struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a->data);
4122
4462
 
4123
4463
  result->op = GGML_OP_RESHAPE;
@@ -4133,7 +4473,7 @@ struct ggml_tensor * ggml_reshape_3d(
4133
4473
  struct ggml_tensor * ggml_view_1d(
4134
4474
  struct ggml_context * ctx,
4135
4475
  struct ggml_tensor * a,
4136
- int ne0,
4476
+ int64_t ne0,
4137
4477
  size_t offset) {
4138
4478
  if (a->grad) {
4139
4479
  GGML_ASSERT(false); // gradient propagation is not supported
@@ -4154,15 +4494,15 @@ struct ggml_tensor * ggml_view_1d(
4154
4494
  struct ggml_tensor * ggml_view_2d(
4155
4495
  struct ggml_context * ctx,
4156
4496
  struct ggml_tensor * a,
4157
- int ne0,
4158
- int ne1,
4497
+ int64_t ne0,
4498
+ int64_t ne1,
4159
4499
  size_t nb1,
4160
4500
  size_t offset) {
4161
4501
  if (a->grad) {
4162
4502
  GGML_ASSERT(false); // gradient propagation is not supported
4163
4503
  }
4164
4504
 
4165
- const int ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 };
4505
+ const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 };
4166
4506
 
4167
4507
  struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset);
4168
4508
 
@@ -4178,6 +4518,37 @@ struct ggml_tensor * ggml_view_2d(
4178
4518
  return result;
4179
4519
  }
4180
4520
 
4521
+ // ggml_view_3d
4522
+
4523
+ struct ggml_tensor * ggml_view_3d(
4524
+ struct ggml_context * ctx,
4525
+ struct ggml_tensor * a,
4526
+ int64_t ne0,
4527
+ int64_t ne1,
4528
+ int64_t ne2,
4529
+ size_t nb1,
4530
+ size_t nb2,
4531
+ size_t offset) {
4532
+ if (a->grad) {
4533
+ GGML_ASSERT(false); // gradient propagation is not supported
4534
+ }
4535
+
4536
+ const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, 1 };
4537
+
4538
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset);
4539
+
4540
+ result->nb[1] = nb1;
4541
+ result->nb[2] = nb2;
4542
+ result->nb[3] = result->nb[2]*ne2;
4543
+
4544
+ result->op = GGML_OP_VIEW;
4545
+ result->grad = NULL;
4546
+ result->src0 = a;
4547
+ result->src1 = NULL; // TODO: maybe store the offset here?
4548
+
4549
+ return result;
4550
+ }
4551
+
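ggml_view_3d mirrors the existing 1-d/2-d view helpers: it reuses the parent's data pointer at the given byte offset and lets the caller choose the row (nb1) and plane (nb2) strides in bytes, with nb3 derived as nb2*ne2. The addressing it sets up is ordinary byte-strided indexing, sketched here in plain C (names are illustrative):

    #include <stdint.h>
    #include <stdio.h>

    // Address of element (i0, i1, i2) in a 3-d view with byte strides nb0/nb1/nb2,
    // starting at byte `offset` inside the parent tensor's buffer.
    static float * view3d_addr(char * base, size_t offset,
                               size_t nb0, size_t nb1, size_t nb2,
                               int64_t i0, int64_t i1, int64_t i2) {
        return (float *)(base + offset + i0*nb0 + i1*nb1 + i2*nb2);
    }

    int main(void) {
        // A 4 x 3 x 2 float tensor laid out contiguously.
        float data[2][3][4];
        for (int i2 = 0; i2 < 2; i2++)
            for (int i1 = 0; i1 < 3; i1++)
                for (int i0 = 0; i0 < 4; i0++)
                    data[i2][i1][i0] = 100.0f*i2 + 10.0f*i1 + i0;

        const size_t nb0 = sizeof(float);   // element stride
        const size_t nb1 = 4*nb0;           // row stride
        const size_t nb2 = 3*nb1;           // plane stride

        // View the second plane only (offset = 1*nb2), same row/plane strides.
        float * p = view3d_addr((char *)data, 1*nb2, nb0, nb1, nb2, 2, 1, 0);
        printf("%f\n", *p);   // prints 112.0: plane 1, row 1, column 2
        return 0;
    }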
4181
4552
  // ggml_permute
4182
4553
 
4183
4554
  struct ggml_tensor * ggml_permute(
@@ -4393,7 +4764,7 @@ struct ggml_tensor * ggml_conv_1d_1s(
4393
4764
  is_node = true;
4394
4765
  }
4395
4766
 
4396
- const int ne[4] = { b->ne[0], a->ne[2], 1, 1, };
4767
+ const int64_t ne[4] = { b->ne[0], a->ne[2], 1, 1, };
4397
4768
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
4398
4769
 
4399
4770
  result->op = GGML_OP_CONV_1D_1S;
@@ -4420,7 +4791,7 @@ struct ggml_tensor * ggml_conv_1d_2s(
4420
4791
  is_node = true;
4421
4792
  }
4422
4793
 
4423
- const int ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, };
4794
+ const int64_t ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, };
4424
4795
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
4425
4796
 
4426
4797
  result->op = GGML_OP_CONV_1D_2S;
@@ -4513,102 +4884,112 @@ static void ggml_compute_forward_dup_f16(
4513
4884
  const struct ggml_tensor * src0,
4514
4885
  struct ggml_tensor * dst) {
4515
4886
  GGML_ASSERT(params->ith == 0);
4516
- GGML_ASSERT(ggml_is_contiguous(dst));
4517
4887
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
4518
4888
 
4519
4889
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
4520
4890
  return;
4521
4891
  }
4522
4892
 
4523
- const int ne00 = src0->ne[0];
4524
- const int ne01 = src0->ne[1];
4525
- const int ne02 = src0->ne[2];
4526
- const int ne03 = src0->ne[3];
4893
+ const int64_t ne00 = src0->ne[0];
4894
+ const int64_t ne01 = src0->ne[1];
4895
+ const int64_t ne02 = src0->ne[2];
4896
+ const int64_t ne03 = src0->ne[3];
4527
4897
 
4528
4898
  const size_t nb00 = src0->nb[0];
4529
4899
  const size_t nb01 = src0->nb[1];
4530
4900
  const size_t nb02 = src0->nb[2];
4531
4901
  const size_t nb03 = src0->nb[3];
4532
4902
 
4533
- if (ggml_is_contiguous(src0) && src0->type == dst->type) {
4903
+ const size_t nb0 = dst->nb[0];
4904
+ const size_t nb1 = dst->nb[1];
4905
+ const size_t nb2 = dst->nb[2];
4906
+ const size_t nb3 = dst->nb[3];
4907
+
4908
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
4534
4909
  memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
4535
4910
  return;
4536
4911
  }
4537
4912
 
4538
- if (src0->nb[0] == sizeof(ggml_fp16_t)) {
4539
- if (dst->type == GGML_TYPE_F16) {
4540
- size_t id = 0;
4541
- const size_t rs = ne00*nb00;
4542
-
4543
- for (int i03 = 0; i03 < ne03; i03++) {
4544
- for (int i02 = 0; i02 < ne02; i02++) {
4545
- for (int i01 = 0; i01 < ne01; i01++) {
4546
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
4547
- char * dst_ptr = (char *) dst->data + id*rs;
4548
-
4549
- memcpy(dst_ptr, src0_ptr, rs);
4550
-
4551
- id++;
4552
- }
4913
+ if (src0->type == dst->type &&
4914
+ src0->ne[0] == dst->ne[0] &&
4915
+ src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
4916
+ // copy by rows
4917
+ const size_t rs = ne00*nb00;
4918
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
4919
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
4920
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
4921
+ memcpy(
4922
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
4923
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
4924
+ rs);
4553
4925
  }
4554
4926
  }
4555
- } else if (dst->type == GGML_TYPE_F32) {
4556
- size_t id = 0;
4557
- float * dst_ptr = (float *) dst->data;
4558
-
4559
- for (int i03 = 0; i03 < ne03; i03++) {
4560
- for (int i02 = 0; i02 < ne02; i02++) {
4561
- for (int i01 = 0; i01 < ne01; i01++) {
4562
- for (int i00 = 0; i00 < ne00; i00++) {
4563
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4564
-
4565
- dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
4566
- id++;
4567
- }
4568
- }
4569
- }
4570
- }
4571
- } else {
4572
- GGML_ASSERT(false); // TODO: implement
4573
4927
  }
4574
- } else {
4575
- //printf("%s: this is not optimal - fix me\n", __func__);
4576
-
4577
- if (dst->type == GGML_TYPE_F32) {
4578
- size_t id = 0;
4579
- float * dst_ptr = (float *) dst->data;
4580
-
4581
- for (int i03 = 0; i03 < ne03; i03++) {
4582
- for (int i02 = 0; i02 < ne02; i02++) {
4583
- for (int i01 = 0; i01 < ne01; i01++) {
4584
- for (int i00 = 0; i00 < ne00; i00++) {
4585
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4928
+ return;
4929
+ }
4586
4930
 
4587
- dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
4588
- id++;
4931
+ // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
4932
+
4933
+ // dst counters
4934
+ int64_t i10 = 0;
4935
+ int64_t i11 = 0;
4936
+ int64_t i12 = 0;
4937
+ int64_t i13 = 0;
4938
+
4939
+ if (dst->type == GGML_TYPE_F16) {
4940
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
4941
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
4942
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
4943
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
4944
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4945
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
4946
+
4947
+ memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
4948
+
4949
+ if (++i10 == ne00) {
4950
+ i10 = 0;
4951
+ if (++i11 == ne01) {
4952
+ i11 = 0;
4953
+ if (++i12 == ne02) {
4954
+ i12 = 0;
4955
+ if (++i13 == ne03) {
4956
+ i13 = 0;
4957
+ }
4958
+ }
4959
+ }
4589
4960
  }
4590
4961
  }
4591
4962
  }
4592
4963
  }
4593
- } else if (dst->type == GGML_TYPE_F16) {
4594
- size_t id = 0;
4595
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
4596
-
4597
- for (int i03 = 0; i03 < ne03; i03++) {
4598
- for (int i02 = 0; i02 < ne02; i02++) {
4599
- for (int i01 = 0; i01 < ne01; i01++) {
4600
- for (int i00 = 0; i00 < ne00; i00++) {
4601
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4602
-
4603
- dst_ptr[id] = *src0_ptr;
4604
- id++;
4964
+ }
4965
+ } else if (dst->type == GGML_TYPE_F32) {
4966
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
4967
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
4968
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
4969
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
4970
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4971
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
4972
+
4973
+ *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
4974
+
4975
+ if (++i10 == ne00) {
4976
+ i10 = 0;
4977
+ if (++i11 == ne01) {
4978
+ i11 = 0;
4979
+ if (++i12 == ne02) {
4980
+ i12 = 0;
4981
+ if (++i13 == ne03) {
4982
+ i13 = 0;
4983
+ }
4984
+ }
4985
+ }
4605
4986
  }
4606
4987
  }
4607
4988
  }
4608
4989
  }
4609
- } else {
4610
- GGML_ASSERT(false); // TODO: implement
4611
4990
  }
4991
+ } else {
4992
+ GGML_ASSERT(false); // TODO: implement
4612
4993
  }
4613
4994
  }
4614
4995
 
@@ -4617,102 +4998,92 @@ static void ggml_compute_forward_dup_f32(
4617
4998
  const struct ggml_tensor * src0,
4618
4999
  struct ggml_tensor * dst) {
4619
5000
  GGML_ASSERT(params->ith == 0);
4620
- GGML_ASSERT(ggml_is_contiguous(dst));
4621
5001
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
4622
5002
 
4623
5003
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
4624
5004
  return;
4625
5005
  }
4626
5006
 
4627
- const int ne00 = src0->ne[0];
4628
- const int ne01 = src0->ne[1];
4629
- const int ne02 = src0->ne[2];
4630
- const int ne03 = src0->ne[3];
5007
+ const int64_t ne00 = src0->ne[0];
5008
+ const int64_t ne01 = src0->ne[1];
5009
+ const int64_t ne02 = src0->ne[2];
5010
+ const int64_t ne03 = src0->ne[3];
4631
5011
 
4632
5012
  const size_t nb00 = src0->nb[0];
4633
5013
  const size_t nb01 = src0->nb[1];
4634
5014
  const size_t nb02 = src0->nb[2];
4635
5015
  const size_t nb03 = src0->nb[3];
4636
5016
 
4637
- if (ggml_is_contiguous(src0) && src0->type == dst->type) {
5017
+ const size_t nb0 = dst->nb[0];
5018
+ const size_t nb1 = dst->nb[1];
5019
+ const size_t nb2 = dst->nb[2];
5020
+ const size_t nb3 = dst->nb[3];
5021
+
5022
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
4638
5023
  memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
4639
5024
  return;
4640
5025
  }
4641
5026
 
4642
- if (src0->nb[0] == sizeof(float)) {
4643
- if (dst->type == GGML_TYPE_F32) {
4644
- size_t id = 0;
4645
- const size_t rs = ne00*nb00;
4646
-
4647
- for (int i03 = 0; i03 < ne03; i03++) {
4648
- for (int i02 = 0; i02 < ne02; i02++) {
4649
- for (int i01 = 0; i01 < ne01; i01++) {
4650
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
4651
- char * dst_ptr = (char *) dst->data + id*rs;
4652
-
4653
- memcpy(dst_ptr, src0_ptr, rs);
4654
-
4655
- id++;
4656
- }
4657
- }
4658
- }
4659
- } else if (dst->type == GGML_TYPE_F16) {
4660
- size_t id = 0;
4661
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
4662
-
4663
- for (int i03 = 0; i03 < ne03; i03++) {
4664
- for (int i02 = 0; i02 < ne02; i02++) {
4665
- for (int i01 = 0; i01 < ne01; i01++) {
4666
- for (int i00 = 0; i00 < ne00; i00++) {
4667
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4668
-
4669
- dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
4670
- id++;
5027
+ // dst counters
5028
+ int64_t i10 = 0;
5029
+ int64_t i11 = 0;
5030
+ int64_t i12 = 0;
5031
+ int64_t i13 = 0;
5032
+
5033
+ if (dst->type == GGML_TYPE_F32) {
5034
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5035
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5036
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
5037
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
5038
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5039
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5040
+
5041
+ memcpy(dst_ptr, src0_ptr, sizeof(float));
5042
+
5043
+ if (++i10 == dst->ne[0]) {
5044
+ i10 = 0;
5045
+ if (++i11 == dst->ne[1]) {
5046
+ i11 = 0;
5047
+ if (++i12 == dst->ne[2]) {
5048
+ i12 = 0;
5049
+ if (++i13 == dst->ne[3]) {
5050
+ i13 = 0;
5051
+ }
5052
+ }
5053
+ }
4671
5054
  }
4672
5055
  }
4673
5056
  }
4674
5057
  }
4675
- } else {
4676
- GGML_ASSERT(false); // TODO: implement
4677
5058
  }
4678
- } else {
4679
- //printf("%s: this is not optimal - fix me\n", __func__);
4680
-
4681
- if (dst->type == GGML_TYPE_F32) {
4682
- size_t id = 0;
4683
- float * dst_ptr = (float *) dst->data;
4684
-
4685
- for (int i03 = 0; i03 < ne03; i03++) {
4686
- for (int i02 = 0; i02 < ne02; i02++) {
4687
- for (int i01 = 0; i01 < ne01; i01++) {
4688
- for (int i00 = 0; i00 < ne00; i00++) {
4689
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4690
-
4691
- dst_ptr[id] = *src0_ptr;
4692
- id++;
4693
- }
4694
- }
4695
- }
4696
- }
4697
- } else if (dst->type == GGML_TYPE_F16) {
4698
- size_t id = 0;
4699
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
4700
-
4701
- for (int i03 = 0; i03 < ne03; i03++) {
4702
- for (int i02 = 0; i02 < ne02; i02++) {
4703
- for (int i01 = 0; i01 < ne01; i01++) {
4704
- for (int i00 = 0; i00 < ne00; i00++) {
4705
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4706
-
4707
- dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
4708
- id++;
5059
+ } else if (dst->type == GGML_TYPE_F16) {
5060
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5061
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5062
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
5063
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
5064
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5065
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5066
+
5067
+ *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
5068
+
5069
+ if (++i10 == dst->ne[0]) {
5070
+ i10 = 0;
5071
+ if (++i11 == dst->ne[1]) {
5072
+ i11 = 0;
5073
+ if (++i12 == dst->ne[2]) {
5074
+ i12 = 0;
5075
+ if (++i13 == dst->ne[3]) {
5076
+ i13 = 0;
5077
+ }
5078
+ }
5079
+ }
4709
5080
  }
4710
5081
  }
4711
5082
  }
4712
5083
  }
4713
- } else {
4714
- GGML_ASSERT(false); // TODO: implement
4715
5084
  }
5085
+ } else {
5086
+ GGML_ASSERT(false); // TODO: implement
4716
5087
  }
4717
5088
  }
4718
5089
 
@@ -5087,18 +5458,18 @@ static void ggml_compute_forward_sum_f32(
5087
5458
  assert(ggml_is_scalar(dst));
5088
5459
  assert(src0->nb[0] == sizeof(float));
5089
5460
 
5090
- const int ne00 = src0->ne[0];
5091
- const int ne01 = src0->ne[1];
5092
- const int ne02 = src0->ne[2];
5093
- const int ne03 = src0->ne[3];
5461
+ const int64_t ne00 = src0->ne[0];
5462
+ const int64_t ne01 = src0->ne[1];
5463
+ const int64_t ne02 = src0->ne[2];
5464
+ const int64_t ne03 = src0->ne[3];
5094
5465
 
5095
5466
  const size_t nb01 = src0->nb[1];
5096
5467
  const size_t nb02 = src0->nb[2];
5097
5468
  const size_t nb03 = src0->nb[3];
5098
5469
 
5099
- for (int i03 = 0; i03 < ne03; i03++) {
5100
- for (int i02 = 0; i02 < ne02; i02++) {
5101
- for (int i01 = 0; i01 < ne01; i01++) {
5470
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5471
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5472
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
5102
5473
  ggml_vec_sum_f32(ne00,
5103
5474
  (float *) (dst->data),
5104
5475
  (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
@@ -5143,19 +5514,19 @@ static void ggml_compute_forward_mean_f32(
5143
5514
 
5144
5515
  assert(src0->nb[0] == sizeof(float));
5145
5516
 
5146
- const int ne00 = src0->ne[0];
5147
- const int ne01 = src0->ne[1];
5148
- const int ne02 = src0->ne[2];
5149
- const int ne03 = src0->ne[3];
5517
+ const int64_t ne00 = src0->ne[0];
5518
+ const int64_t ne01 = src0->ne[1];
5519
+ const int64_t ne02 = src0->ne[2];
5520
+ const int64_t ne03 = src0->ne[3];
5150
5521
 
5151
5522
  const size_t nb01 = src0->nb[1];
5152
5523
  const size_t nb02 = src0->nb[2];
5153
5524
  const size_t nb03 = src0->nb[3];
5154
5525
 
5155
- const int ne0 = dst->ne[0];
5156
- const int ne1 = dst->ne[1];
5157
- const int ne2 = dst->ne[2];
5158
- const int ne3 = dst->ne[3];
5526
+ const int64_t ne0 = dst->ne[0];
5527
+ const int64_t ne1 = dst->ne[1];
5528
+ const int64_t ne2 = dst->ne[2];
5529
+ const int64_t ne3 = dst->ne[3];
5159
5530
 
5160
5531
  assert(ne0 == 1);
5161
5532
  assert(ne1 == ne01);
@@ -5171,9 +5542,9 @@ static void ggml_compute_forward_mean_f32(
5171
5542
  const size_t nb2 = dst->nb[2];
5172
5543
  const size_t nb3 = dst->nb[3];
5173
5544
 
5174
- for (int i03 = 0; i03 < ne03; i03++) {
5175
- for (int i02 = 0; i02 < ne02; i02++) {
5176
- for (int i01 = 0; i01 < ne01; i01++) {
5545
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5546
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5547
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
5177
5548
  ggml_vec_sum_f32(ne00,
5178
5549
  (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5179
5550
  (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
@@ -5660,10 +6031,10 @@ static void ggml_compute_forward_norm_f32(
5660
6031
  const int ith = params->ith;
5661
6032
  const int nth = params->nth;
5662
6033
 
5663
- const int ne00 = src0->ne[0];
5664
- const int ne01 = src0->ne[1];
5665
- const int ne02 = src0->ne[2];
5666
- const int ne03 = src0->ne[3];
6034
+ const int64_t ne00 = src0->ne[0];
6035
+ const int64_t ne01 = src0->ne[1];
6036
+ const int64_t ne02 = src0->ne[2];
6037
+ const int64_t ne03 = src0->ne[3];
5667
6038
 
5668
6039
  const size_t nb01 = src0->nb[1];
5669
6040
  const size_t nb02 = src0->nb[2];
@@ -5676,13 +6047,13 @@ static void ggml_compute_forward_norm_f32(
5676
6047
  const float eps = 1e-5f; // TODO: make this a parameter
5677
6048
 
5678
6049
  // TODO: optimize
5679
- for (int i03 = 0; i03 < ne03; i03++) {
5680
- for (int i02 = 0; i02 < ne02; i02++) {
5681
- for (int i01 = ith; i01 < ne01; i01 += nth) {
6050
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
6051
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
6052
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
5682
6053
  const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5683
6054
 
5684
6055
  ggml_float sum = 0.0;
5685
- for (int i00 = 0; i00 < ne00; i00++) {
6056
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
5686
6057
  sum += (ggml_float)x[i00];
5687
6058
  }
5688
6059
 
@@ -5691,7 +6062,7 @@ static void ggml_compute_forward_norm_f32(
5691
6062
  float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
5692
6063
 
5693
6064
  ggml_float sum2 = 0.0;
5694
- for (int i00 = 0; i00 < ne00; i00++) {
6065
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
5695
6066
  float v = x[i00] - mean;
5696
6067
  y[i00] = v;
5697
6068
  sum2 += (ggml_float)(v*v);
@@ -5743,10 +6114,10 @@ static void ggml_compute_forward_rms_norm_f32(
5743
6114
  const int ith = params->ith;
5744
6115
  const int nth = params->nth;
5745
6116
 
5746
- const int ne00 = src0->ne[0];
5747
- const int ne01 = src0->ne[1];
5748
- const int ne02 = src0->ne[2];
5749
- const int ne03 = src0->ne[3];
6117
+ const int64_t ne00 = src0->ne[0];
6118
+ const int64_t ne01 = src0->ne[1];
6119
+ const int64_t ne02 = src0->ne[2];
6120
+ const int64_t ne03 = src0->ne[3];
5750
6121
 
5751
6122
  const size_t nb01 = src0->nb[1];
5752
6123
  const size_t nb02 = src0->nb[2];
@@ -5759,13 +6130,13 @@ static void ggml_compute_forward_rms_norm_f32(
5759
6130
  const float eps = 1e-6f; // TODO: make this a parameter
5760
6131
 
5761
6132
  // TODO: optimize
5762
- for (int i03 = 0; i03 < ne03; i03++) {
5763
- for (int i02 = 0; i02 < ne02; i02++) {
5764
- for (int i01 = ith; i01 < ne01; i01 += nth) {
6133
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
6134
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
6135
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
5765
6136
  const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5766
6137
 
5767
6138
  ggml_float sum = 0.0;
5768
- for (int i00 = 0; i00 < ne00; i00++) {
6139
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
5769
6140
  sum += (ggml_float)(x[i00] * x[i00]);
5770
6141
  }
5771
6142
 
@@ -5818,13 +6189,13 @@ static bool ggml_compute_forward_mul_mat_use_blas(
5818
6189
  const struct ggml_tensor * src0,
5819
6190
  const struct ggml_tensor * src1,
5820
6191
  struct ggml_tensor * dst) {
5821
- //const int ne00 = src0->ne[0];
5822
- //const int ne01 = src0->ne[1];
6192
+ //const int64_t ne00 = src0->ne[0];
6193
+ //const int64_t ne01 = src0->ne[1];
5823
6194
 
5824
- const int ne10 = src1->ne[0];
6195
+ const int64_t ne10 = src1->ne[0];
5825
6196
 
5826
- const int ne0 = dst->ne[0];
5827
- const int ne1 = dst->ne[1];
6197
+ const int64_t ne0 = dst->ne[0];
6198
+ const int64_t ne1 = dst->ne[1];
5828
6199
 
5829
6200
  // TODO: find the optimal values for these
5830
6201
  if (ggml_is_contiguous(src0) &&
@@ -5846,23 +6217,23 @@ static void ggml_compute_forward_mul_mat_f32(
5846
6217
  int64_t t0 = ggml_perf_time_us();
5847
6218
  UNUSED(t0);
5848
6219
 
5849
- const int ne00 = src0->ne[0];
5850
- const int ne01 = src0->ne[1];
5851
- const int ne02 = src0->ne[2];
5852
- const int ne03 = src0->ne[3];
6220
+ const int64_t ne00 = src0->ne[0];
6221
+ const int64_t ne01 = src0->ne[1];
6222
+ const int64_t ne02 = src0->ne[2];
6223
+ const int64_t ne03 = src0->ne[3];
5853
6224
 
5854
6225
  #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
5855
- const int ne10 = src1->ne[0];
6226
+ const int64_t ne10 = src1->ne[0];
5856
6227
  #endif
5857
- const int ne11 = src1->ne[1];
6228
+ const int64_t ne11 = src1->ne[1];
5858
6229
  #ifndef NDEBUG
5859
- const int ne12 = src1->ne[2];
5860
- const int ne13 = src1->ne[3];
6230
+ const int64_t ne12 = src1->ne[2];
6231
+ const int64_t ne13 = src1->ne[3];
5861
6232
 
5862
- const int ne0 = dst->ne[0];
5863
- const int ne1 = dst->ne[1];
5864
- const int ne2 = dst->ne[2];
5865
- const int ne3 = dst->ne[3];
6233
+ const int64_t ne0 = dst->ne[0];
6234
+ const int64_t ne1 = dst->ne[1];
6235
+ const int64_t ne2 = dst->ne[2];
6236
+ const int64_t ne3 = dst->ne[3];
5866
6237
 
5867
6238
  const int nb00 = src0->nb[0];
5868
6239
  #endif
@@ -5922,8 +6293,8 @@ static void ggml_compute_forward_mul_mat_f32(
5922
6293
  return;
5923
6294
  }
5924
6295
 
5925
- for (int i03 = 0; i03 < ne03; i03++) {
5926
- for (int i02 = 0; i02 < ne02; i02++) {
6296
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
6297
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5927
6298
  const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
5928
6299
  const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
5929
6300
 
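Note: this hunk only selects the (i02, i03) slice pointers x and y; the actual per-slice product is performed by a BLAS call outside the shown context (this loop sits in the GGML_USE_ACCELERATE / GGML_USE_OPENBLAS branch). As a semantic reference only, ggml's matrix product dots rows of src0 with rows of src1 (src1 acting transposed); a naive sketch for one slice, with the dst layout assumed from the surrounding code:

    #include <stdint.h>

    // Naive reference for one (i02, i03) slice: dst has ne0 = ne01, ne1 = ne11,
    // and dst[i1][i0] is the dot product of src0 row i0 and src1 row i1
    // (both of length ne00 == ne10). A sketch, not the BLAS call itself.
    static void mul_mat_slice_ref(float * dst, const float * x, const float * y,
                                  int64_t ne00, int64_t ne01, int64_t ne11) {
        for (int64_t i1 = 0; i1 < ne11; i1++) {
            for (int64_t i0 = 0; i0 < ne01; i0++) {
                double sum = 0.0;
                for (int64_t k = 0; k < ne00; k++) {
                    sum += (double) x[i0*ne00 + k] * y[i1*ne00 + k];
                }
                dst[i1*ne01 + i0] = (float) sum;
            }
        }
    }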
@@ -5970,7 +6341,7 @@ static void ggml_compute_forward_mul_mat_f32(
5970
6341
  const int i02 = (ir - i03*ne02*ne01)/ne01;
5971
6342
  const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
5972
6343
 
5973
- for (int ic = 0; ic < ne11; ++ic) {
6344
+ for (int64_t ic = 0; ic < ne11; ++ic) {
5974
6345
  // src1 indices
5975
6346
  const int i13 = i03;
5976
6347
  const int i12 = i02;
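Note: the two index lines above recover (i02, i01) from the flat row counter ir; the matching i03 division sits just above the shown context. The arithmetic, pulled out as a stand-alone helper:

    #include <stdint.h>

    // Unflatten a row index ir in [0, ne03*ne02*ne01) back into (i03, i02, i01),
    // mirroring the index arithmetic in the hunk above (the i03 line is cut off
    // by the hunk context but follows from the two shown divisions).
    static void unravel_row(int64_t ir, int64_t ne01, int64_t ne02,
                            int64_t * i01, int64_t * i02, int64_t * i03) {
        *i03 = ir/(ne02*ne01);
        *i02 = (ir - *i03*ne02*ne01)/ne01;
        *i01 =  ir - *i03*ne02*ne01 - *i02*ne01;
    }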
@@ -6011,21 +6382,21 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6011
6382
  int64_t t0 = ggml_perf_time_us();
6012
6383
  UNUSED(t0);
6013
6384
 
6014
- const int ne00 = src0->ne[0];
6015
- const int ne01 = src0->ne[1];
6016
- const int ne02 = src0->ne[2];
6017
- const int ne03 = src0->ne[3];
6385
+ const int64_t ne00 = src0->ne[0];
6386
+ const int64_t ne01 = src0->ne[1];
6387
+ const int64_t ne02 = src0->ne[2];
6388
+ const int64_t ne03 = src0->ne[3];
6018
6389
 
6019
- const int ne10 = src1->ne[0];
6020
- const int ne11 = src1->ne[1];
6021
- const int ne12 = src1->ne[2];
6022
- const int ne13 = src1->ne[3];
6390
+ const int64_t ne10 = src1->ne[0];
6391
+ const int64_t ne11 = src1->ne[1];
6392
+ const int64_t ne12 = src1->ne[2];
6393
+ const int64_t ne13 = src1->ne[3];
6023
6394
 
6024
- const int ne0 = dst->ne[0];
6025
- const int ne1 = dst->ne[1];
6026
- const int ne2 = dst->ne[2];
6027
- const int ne3 = dst->ne[3];
6028
- //const int ne = ne0*ne1*ne2*ne3;
6395
+ const int64_t ne0 = dst->ne[0];
6396
+ const int64_t ne1 = dst->ne[1];
6397
+ const int64_t ne2 = dst->ne[2];
6398
+ const int64_t ne3 = dst->ne[3];
6399
+ //const int64_t ne = ne0*ne1*ne2*ne3;
6029
6400
 
6030
6401
  const int nb00 = src0->nb[0];
6031
6402
  const int nb01 = src0->nb[1];
@@ -6085,12 +6456,12 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6085
6456
 
6086
6457
  float * const wdata = params->wdata;
6087
6458
 
6088
- for (int i03 = 0; i03 < ne03; i03++) {
6089
- for (int i02 = 0; i02 < ne02; i02++) {
6459
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
6460
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
6090
6461
  {
6091
6462
  size_t id = 0;
6092
- for (int i01 = 0; i01 < ne01; ++i01) {
6093
- for (int i00 = 0; i00 < ne00; ++i00) {
6463
+ for (int64_t i01 = 0; i01 < ne01; ++i01) {
6464
+ for (int64_t i00 = 0; i00 < ne00; ++i00) {
6094
6465
  wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
6095
6466
  }
6096
6467
  }
@@ -6120,10 +6491,10 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6120
6491
  ggml_fp16_t * const wdata = params->wdata;
6121
6492
 
6122
6493
  size_t id = 0;
6123
- for (int i13 = 0; i13 < ne13; ++i13) {
6124
- for (int i12 = 0; i12 < ne12; ++i12) {
6125
- for (int i11 = 0; i11 < ne11; ++i11) {
6126
- for (int i10 = 0; i10 < ne10; ++i10) {
6494
+ for (int64_t i13 = 0; i13 < ne13; ++i13) {
6495
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
6496
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
6497
+ for (int64_t i10 = 0; i10 < ne10; ++i10) {
6127
6498
  wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
6128
6499
  }
6129
6500
  }
@@ -6175,7 +6546,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6175
6546
 
6176
6547
  float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
6177
6548
 
6178
- for (int ic = 0; ic < ne11; ++ic) {
6549
+ for (int64_t ic = 0; ic < ne11; ++ic) {
6179
6550
  ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
6180
6551
  }
6181
6552
  }
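Note: the loop above produces one output column per src1 row ic via ggml_vec_dot_f16. A scalar stand-in for what that dot product computes; plain float arrays are used so the sketch stays self-contained (the real kernel reads ggml_fp16_t operands and widens each element with GGML_FP16_TO_FP32 before multiplying):

    #include <stdint.h>

    // Scalar stand-in for the per-column dot product: n products accumulated
    // into a single float. In the real kernel both operands are half precision.
    static float vec_dot_ref(int64_t n, const float * a, const float * b) {
        double sum = 0.0;
        for (int64_t i = 0; i < n; i++) {
            sum += (double) a[i] * b[i];
        }
        return (float) sum;
    }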
@@ -6224,20 +6595,20 @@ static void ggml_compute_forward_mul_mat_q_f32(
6224
6595
  int64_t t0 = ggml_perf_time_us();
6225
6596
  UNUSED(t0);
6226
6597
 
6227
- const int ne00 = src0->ne[0];
6228
- const int ne01 = src0->ne[1];
6229
- const int ne02 = src0->ne[2];
6230
- const int ne03 = src0->ne[3];
6598
+ const int64_t ne00 = src0->ne[0];
6599
+ const int64_t ne01 = src0->ne[1];
6600
+ const int64_t ne02 = src0->ne[2];
6601
+ const int64_t ne03 = src0->ne[3];
6231
6602
 
6232
- const int ne10 = src1->ne[0];
6233
- const int ne11 = src1->ne[1];
6234
- const int ne12 = src1->ne[2];
6235
- const int ne13 = src1->ne[3];
6603
+ const int64_t ne10 = src1->ne[0];
6604
+ const int64_t ne11 = src1->ne[1];
6605
+ const int64_t ne12 = src1->ne[2];
6606
+ const int64_t ne13 = src1->ne[3];
6236
6607
 
6237
- const int ne0 = dst->ne[0];
6238
- const int ne1 = dst->ne[1];
6239
- const int ne2 = dst->ne[2];
6240
- const int ne3 = dst->ne[3];
6608
+ const int64_t ne0 = dst->ne[0];
6609
+ const int64_t ne1 = dst->ne[1];
6610
+ const int64_t ne2 = dst->ne[2];
6611
+ const int64_t ne3 = dst->ne[3];
6241
6612
 
6242
6613
  const int nb00 = src0->nb[0];
6243
6614
  const int nb01 = src0->nb[1];
@@ -6301,11 +6672,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
6301
6672
  float * const wdata = params->wdata;
6302
6673
  dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
6303
6674
 
6304
- for (int i03 = 0; i03 < ne03; i03++) {
6305
- for (int i02 = 0; i02 < ne02; i02++) {
6675
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
6676
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
6306
6677
  {
6307
6678
  size_t id = 0;
6308
- for (int i01 = 0; i01 < ne01; ++i01) {
6679
+ for (int64_t i01 = 0; i01 < ne01; ++i01) {
6309
6680
  dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
6310
6681
  id += ne00;
6311
6682
  }
@@ -6335,9 +6706,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
6335
6706
  char * wdata = params->wdata;
6336
6707
  const size_t row_size = ne10*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
6337
6708
 
6338
- for (int i13 = 0; i13 < ne13; ++i13) {
6339
- for (int i12 = 0; i12 < ne12; ++i12) {
6340
- for (int i11 = 0; i11 < ne11; ++i11) {
6709
+ for (int64_t i13 = 0; i13 < ne13; ++i13) {
6710
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
6711
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
6341
6712
  quantize_row_q((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
6342
6713
  wdata += row_size;
6343
6714
  }
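Note: row_size above is the byte length of one quantized row of src1: ne10 values packed at GGML_TYPE_SIZE[type] bytes per GGML_BLCK_SIZE[type] values. A worked example, assuming the q4_0 layout used elsewhere in this file (one float scale plus 32 packed 4-bit values, i.e. 20 bytes per block of 32) and a hypothetical row length:

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t ne10      = 4096;       // hypothetical row length
        const size_t  type_size = 4 + 32/2;   // q4_0 block: float scale + 16 nibble bytes = 20
        const size_t  blck_size = 32;         // values per block

        const size_t row_size = (size_t) ne10*type_size/blck_size;
        printf("row_size = %zu bytes\n", row_size);   // 4096*20/32 = 2560
        return 0;
    }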
@@ -6386,7 +6757,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
6386
6757
 
6387
6758
  assert(ne00 % 32 == 0);
6388
6759
 
6389
- for (int ic = 0; ic < ne11; ++ic) {
6760
+ for (int64_t ic = 0; ic < ne11; ++ic) {
6390
6761
  vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
6391
6762
  }
6392
6763
  }
@@ -6867,7 +7238,6 @@ static void ggml_compute_forward_rope_f32(
6867
7238
  const struct ggml_tensor * src0,
6868
7239
  const struct ggml_tensor * src1,
6869
7240
  struct ggml_tensor * dst) {
6870
- assert(params->ith == 0);
6871
7241
  assert(src1->type == GGML_TYPE_I32);
6872
7242
  assert(ggml_nelements(src1) == 3);
6873
7243
 
@@ -6879,10 +7249,10 @@ static void ggml_compute_forward_rope_f32(
6879
7249
  const int n_dims = ((int32_t *) src1->data)[1];
6880
7250
  const int mode = ((int32_t *) src1->data)[2];
6881
7251
 
6882
- //const int ne0 = src0->ne[0];
6883
- const int ne1 = src0->ne[1];
6884
- const int ne2 = src0->ne[2];
6885
- const int ne3 = src0->ne[3];
7252
+ //const int64_t ne0 = src0->ne[0];
7253
+ const int64_t ne1 = src0->ne[1];
7254
+ const int64_t ne2 = src0->ne[2];
7255
+ const int64_t ne3 = src0->ne[3];
6886
7256
 
6887
7257
  const int nb0 = src0->nb[0];
6888
7258
  const int nb1 = src0->nb[1];
@@ -6894,11 +7264,28 @@ static void ggml_compute_forward_rope_f32(
6894
7264
 
6895
7265
  assert(nb0 == sizeof(float));
6896
7266
 
6897
- // TODO: optimize
6898
- for (int i3 = 0; i3 < ne3; i3++) {
6899
- for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
7267
+ const int ith = params->ith;
7268
+ const int nth = params->nth;
7269
+
7270
+ const int nr = ggml_nrows(src0);
7271
+
7272
+ // rows per thread
7273
+ const int dr = (nr + nth - 1)/nth;
7274
+
7275
+ // row range for this thread
7276
+ const int ir0 = dr*ith;
7277
+ const int ir1 = MIN(ir0 + dr, nr);
7278
+
7279
+ // row index used to determine which thread to use
7280
+ int ir = 0;
7281
+
7282
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
7283
+ for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
6900
7284
  const int p = (mode == 0 ? n_past + i2 : i2);
6901
- for (int i1 = 0; i1 < ne1; i1++) {
7285
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
7286
+ if (ir++ < ir0) continue;
7287
+ if (ir > ir1) break;
7288
+
6902
7289
  for (int i0 = 0; i0 < n_dims; i0 += 2) {
6903
7290
  const float theta = powf(10000.0, ((float)-i0)/n_dims);
6904
7291
 
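Note: this hunk parallelizes the RoPE kernel: instead of running on a single thread (the removed assert above), rows are split into contiguous chunks of dr rows and each thread keeps only the range [ir0, ir1). The scheduler change further down in this diff sets n_tasks = n_threads for GGML_OP_ROPE, so these ranges actually differ per thread. The partition arithmetic on its own, with hypothetical counts:

    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    int main(void) {
        const int nr  = 10;   // total rows (hypothetical)
        const int nth = 4;    // threads    (hypothetical)

        const int dr = (nr + nth - 1)/nth;        // rows per thread, rounded up
        for (int ith = 0; ith < nth; ith++) {
            const int ir0 = dr*ith;
            const int ir1 = MIN(ir0 + dr, nr);
            printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
        }
        return 0;
    }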
@@ -6924,7 +7311,6 @@ static void ggml_compute_forward_rope_f16(
6924
7311
  const struct ggml_tensor * src0,
6925
7312
  const struct ggml_tensor * src1,
6926
7313
  struct ggml_tensor * dst) {
6927
- assert(params->ith == 0);
6928
7314
  assert(src1->type == GGML_TYPE_I32);
6929
7315
  assert(ggml_nelements(src1) == 3);
6930
7316
 
@@ -6936,10 +7322,10 @@ static void ggml_compute_forward_rope_f16(
6936
7322
  const int n_dims = ((int32_t *) src1->data)[1];
6937
7323
  const int mode = ((int32_t *) src1->data)[2];
6938
7324
 
6939
- //const int ne0 = src0->ne[0];
6940
- const int ne1 = src0->ne[1];
6941
- const int ne2 = src0->ne[2];
6942
- const int ne3 = src0->ne[3];
7325
+ //const int64_t ne0 = src0->ne[0];
7326
+ const int64_t ne1 = src0->ne[1];
7327
+ const int64_t ne2 = src0->ne[2];
7328
+ const int64_t ne3 = src0->ne[3];
6943
7329
 
6944
7330
  const int nb0 = src0->nb[0];
6945
7331
  const int nb1 = src0->nb[1];
@@ -6951,10 +7337,28 @@ static void ggml_compute_forward_rope_f16(
6951
7337
 
6952
7338
  assert(nb0 == sizeof(ggml_fp16_t));
6953
7339
 
6954
- for (int i3 = 0; i3 < ne3; i3++) {
6955
- for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
7340
+ const int ith = params->ith;
7341
+ const int nth = params->nth;
7342
+
7343
+ const int nr = ggml_nrows(src0);
7344
+
7345
+ // rows per thread
7346
+ const int dr = (nr + nth - 1)/nth;
7347
+
7348
+ // row range for this thread
7349
+ const int ir0 = dr*ith;
7350
+ const int ir1 = MIN(ir0 + dr, nr);
7351
+
7352
+ // row index used to determine which thread to use
7353
+ int ir = 0;
7354
+
7355
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
7356
+ for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
6956
7357
  const int p = (mode == 0 ? n_past + i2 : i2);
6957
- for (int i1 = 0; i1 < ne1; i1++) {
7358
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
7359
+ if (ir++ < ir0) continue;
7360
+ if (ir > ir1) break;
7361
+
6958
7362
  for (int i0 = 0; i0 < n_dims; i0 += 2) {
6959
7363
  const float theta = powf(10000.0, ((float)-i0)/n_dims);
6960
7364
 
@@ -7015,21 +7419,21 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32(
7015
7419
  int64_t t0 = ggml_perf_time_us();
7016
7420
  UNUSED(t0);
7017
7421
 
7018
- const int ne00 = src0->ne[0];
7019
- const int ne01 = src0->ne[1];
7020
- const int ne02 = src0->ne[2];
7021
- //const int ne03 = src0->ne[3];
7422
+ const int64_t ne00 = src0->ne[0];
7423
+ const int64_t ne01 = src0->ne[1];
7424
+ const int64_t ne02 = src0->ne[2];
7425
+ //const int64_t ne03 = src0->ne[3];
7022
7426
 
7023
- const int ne10 = src1->ne[0];
7024
- const int ne11 = src1->ne[1];
7025
- //const int ne12 = src1->ne[2];
7026
- //const int ne13 = src1->ne[3];
7427
+ const int64_t ne10 = src1->ne[0];
7428
+ const int64_t ne11 = src1->ne[1];
7429
+ //const int64_t ne12 = src1->ne[2];
7430
+ //const int64_t ne13 = src1->ne[3];
7027
7431
 
7028
- //const int ne0 = dst->ne[0];
7029
- //const int ne1 = dst->ne[1];
7030
- //const int ne2 = dst->ne[2];
7031
- //const int ne3 = dst->ne[3];
7032
- //const int ne = ne0*ne1*ne2*ne3;
7432
+ //const int64_t ne0 = dst->ne[0];
7433
+ //const int64_t ne1 = dst->ne[1];
7434
+ //const int64_t ne2 = dst->ne[2];
7435
+ //const int64_t ne3 = dst->ne[3];
7436
+ //const int64_t ne = ne0*ne1*ne2*ne3;
7033
7437
 
7034
7438
  const int nb00 = src0->nb[0];
7035
7439
  const int nb01 = src0->nb[1];
@@ -7066,11 +7470,11 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32(
7066
7470
  {
7067
7471
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
7068
7472
 
7069
- for (int i02 = 0; i02 < ne02; i02++) {
7070
- for (int i01 = 0; i01 < ne01; i01++) {
7473
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
7474
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
7071
7475
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
7072
7476
  ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
7073
- for (int i00 = 0; i00 < ne00; i00++) {
7477
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
7074
7478
  dst_data[i00*ew0 + i01] = src[i00];
7075
7479
  }
7076
7480
  }
@@ -7081,10 +7485,10 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32(
7081
7485
  {
7082
7486
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
7083
7487
 
7084
- for (int i11 = 0; i11 < ne11; i11++) {
7488
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
7085
7489
  const float * const src = (float *)((char *) src1->data + i11*nb11);
7086
7490
  ggml_fp16_t * dst_data = wdata;
7087
- for (int i10 = 0; i10 < ne10; i10++) {
7491
+ for (int64_t i10 = 0; i10 < ne10; i10++) {
7088
7492
  dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
7089
7493
  }
7090
7494
  }
@@ -7109,7 +7513,7 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32(
7109
7513
 
7110
7514
  for (int i1 = ir0; i1 < ir1; i1++) {
7111
7515
  float * dst_data = (float *)((char *) dst->data + i1*nb1);
7112
- for (int i0 = 0; i0 < ne10; ++i0) {
7516
+ for (int64_t i0 = 0; i0 < ne10; ++i0) {
7113
7517
  dst_data[i0] = 0;
7114
7518
  for (int k = -nh; k <= nh; k++) {
7115
7519
  float v = 0.0f;
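Note: the output loop above walks rows [ir0, ir1) and, for each output position i0, accumulates taps k = -nh .. nh; the inner dot product over the repacked wdata buffers is cut off by the hunk context. A deliberately simplified scalar model (single channel, zero padding at the edges, ignoring the ew0 interleaving) of what a stride-1 tap loop like this computes:

    #include <stdint.h>

    // Simplified single-channel model of a stride-1 "same" convolution with
    // half-width nh. The real kernel works on repacked, padded wdata buffers
    // and uses ggml_vec_dot; this sketch only shows the tap arithmetic.
    static void conv_1d_1s_sketch(float * dst, const float * src, const float * w,
                                  int64_t n, int nh) {
        for (int64_t i0 = 0; i0 < n; i0++) {
            float acc = 0.0f;
            for (int k = -nh; k <= nh; k++) {
                const int64_t j = i0 + k;
                if (j >= 0 && j < n) {
                    acc += w[k + nh]*src[j];
                }
            }
            dst[i0] = acc;
        }
    }

The conv_1d_2s variants further down follow the same pattern with stride 2: the loop steps i0 by 2 and writes dst_data[i0/2].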
@@ -7135,21 +7539,21 @@ static void ggml_compute_forward_conv_1d_1s_f32(
7135
7539
  int64_t t0 = ggml_perf_time_us();
7136
7540
  UNUSED(t0);
7137
7541
 
7138
- const int ne00 = src0->ne[0];
7139
- const int ne01 = src0->ne[1];
7140
- const int ne02 = src0->ne[2];
7141
- //const int ne03 = src0->ne[3];
7542
+ const int64_t ne00 = src0->ne[0];
7543
+ const int64_t ne01 = src0->ne[1];
7544
+ const int64_t ne02 = src0->ne[2];
7545
+ //const int64_t ne03 = src0->ne[3];
7142
7546
 
7143
- const int ne10 = src1->ne[0];
7144
- const int ne11 = src1->ne[1];
7145
- //const int ne12 = src1->ne[2];
7146
- //const int ne13 = src1->ne[3];
7547
+ const int64_t ne10 = src1->ne[0];
7548
+ const int64_t ne11 = src1->ne[1];
7549
+ //const int64_t ne12 = src1->ne[2];
7550
+ //const int64_t ne13 = src1->ne[3];
7147
7551
 
7148
- //const int ne0 = dst->ne[0];
7149
- //const int ne1 = dst->ne[1];
7150
- //const int ne2 = dst->ne[2];
7151
- //const int ne3 = dst->ne[3];
7152
- //const int ne = ne0*ne1*ne2*ne3;
7552
+ //const int64_t ne0 = dst->ne[0];
7553
+ //const int64_t ne1 = dst->ne[1];
7554
+ //const int64_t ne2 = dst->ne[2];
7555
+ //const int64_t ne3 = dst->ne[3];
7556
+ //const int64_t ne = ne0*ne1*ne2*ne3;
7153
7557
 
7154
7558
  const int nb00 = src0->nb[0];
7155
7559
  const int nb01 = src0->nb[1];
@@ -7186,11 +7590,11 @@ static void ggml_compute_forward_conv_1d_1s_f32(
7186
7590
  {
7187
7591
  float * const wdata = (float *) params->wdata + 0;
7188
7592
 
7189
- for (int i02 = 0; i02 < ne02; i02++) {
7190
- for (int i01 = 0; i01 < ne01; i01++) {
7593
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
7594
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
7191
7595
  const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
7192
7596
  float * dst_data = wdata + i02*ew0*ne00;
7193
- for (int i00 = 0; i00 < ne00; i00++) {
7597
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
7194
7598
  dst_data[i00*ew0 + i01] = src[i00];
7195
7599
  }
7196
7600
  }
@@ -7201,10 +7605,10 @@ static void ggml_compute_forward_conv_1d_1s_f32(
7201
7605
  {
7202
7606
  float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
7203
7607
 
7204
- for (int i11 = 0; i11 < ne11; i11++) {
7608
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
7205
7609
  const float * const src = (float *)((char *) src1->data + i11*nb11);
7206
7610
  float * dst_data = wdata;
7207
- for (int i10 = 0; i10 < ne10; i10++) {
7611
+ for (int64_t i10 = 0; i10 < ne10; i10++) {
7208
7612
  dst_data[(i10 + nh)*ew0 + i11] = src[i10];
7209
7613
  }
7210
7614
  }
@@ -7229,7 +7633,7 @@ static void ggml_compute_forward_conv_1d_1s_f32(
7229
7633
 
7230
7634
  for (int i1 = ir0; i1 < ir1; i1++) {
7231
7635
  float * dst_data = (float *)((char *) dst->data + i1*nb1);
7232
- for (int i0 = 0; i0 < ne10; ++i0) {
7636
+ for (int64_t i0 = 0; i0 < ne10; ++i0) {
7233
7637
  dst_data[i0] = 0;
7234
7638
  for (int k = -nh; k <= nh; k++) {
7235
7639
  float v = 0.0f;
@@ -7283,21 +7687,21 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32(
7283
7687
  int64_t t0 = ggml_perf_time_us();
7284
7688
  UNUSED(t0);
7285
7689
 
7286
- const int ne00 = src0->ne[0];
7287
- const int ne01 = src0->ne[1];
7288
- const int ne02 = src0->ne[2];
7289
- //const int ne03 = src0->ne[3];
7690
+ const int64_t ne00 = src0->ne[0];
7691
+ const int64_t ne01 = src0->ne[1];
7692
+ const int64_t ne02 = src0->ne[2];
7693
+ //const int64_t ne03 = src0->ne[3];
7290
7694
 
7291
- const int ne10 = src1->ne[0];
7292
- const int ne11 = src1->ne[1];
7293
- //const int ne12 = src1->ne[2];
7294
- //const int ne13 = src1->ne[3];
7695
+ const int64_t ne10 = src1->ne[0];
7696
+ const int64_t ne11 = src1->ne[1];
7697
+ //const int64_t ne12 = src1->ne[2];
7698
+ //const int64_t ne13 = src1->ne[3];
7295
7699
 
7296
- //const int ne0 = dst->ne[0];
7297
- //const int ne1 = dst->ne[1];
7298
- //const int ne2 = dst->ne[2];
7299
- //const int ne3 = dst->ne[3];
7300
- //const int ne = ne0*ne1*ne2*ne3;
7700
+ //const int64_t ne0 = dst->ne[0];
7701
+ //const int64_t ne1 = dst->ne[1];
7702
+ //const int64_t ne2 = dst->ne[2];
7703
+ //const int64_t ne3 = dst->ne[3];
7704
+ //const int64_t ne = ne0*ne1*ne2*ne3;
7301
7705
 
7302
7706
  const int nb00 = src0->nb[0];
7303
7707
  const int nb01 = src0->nb[1];
@@ -7334,11 +7738,11 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32(
7334
7738
  {
7335
7739
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
7336
7740
 
7337
- for (int i02 = 0; i02 < ne02; i02++) {
7338
- for (int i01 = 0; i01 < ne01; i01++) {
7741
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
7742
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
7339
7743
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
7340
7744
  ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
7341
- for (int i00 = 0; i00 < ne00; i00++) {
7745
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
7342
7746
  dst_data[i00*ew0 + i01] = src[i00];
7343
7747
  }
7344
7748
  }
@@ -7349,10 +7753,10 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32(
7349
7753
  {
7350
7754
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
7351
7755
 
7352
- for (int i11 = 0; i11 < ne11; i11++) {
7756
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
7353
7757
  const float * const src = (float *)((char *) src1->data + i11*nb11);
7354
7758
  ggml_fp16_t * dst_data = wdata;
7355
- for (int i10 = 0; i10 < ne10; i10++) {
7759
+ for (int64_t i10 = 0; i10 < ne10; i10++) {
7356
7760
  dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
7357
7761
  }
7358
7762
  }
@@ -7377,7 +7781,7 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32(
7377
7781
 
7378
7782
  for (int i1 = ir0; i1 < ir1; i1++) {
7379
7783
  float * dst_data = (float *)((char *) dst->data + i1*nb1);
7380
- for (int i0 = 0; i0 < ne10; i0 += 2) {
7784
+ for (int64_t i0 = 0; i0 < ne10; i0 += 2) {
7381
7785
  dst_data[i0/2] = 0;
7382
7786
  for (int k = -nh; k <= nh; k++) {
7383
7787
  float v = 0.0f;
@@ -7403,21 +7807,21 @@ static void ggml_compute_forward_conv_1d_2s_f32(
7403
7807
  int64_t t0 = ggml_perf_time_us();
7404
7808
  UNUSED(t0);
7405
7809
 
7406
- const int ne00 = src0->ne[0];
7407
- const int ne01 = src0->ne[1];
7408
- const int ne02 = src0->ne[2];
7409
- //const int ne03 = src0->ne[3];
7810
+ const int64_t ne00 = src0->ne[0];
7811
+ const int64_t ne01 = src0->ne[1];
7812
+ const int64_t ne02 = src0->ne[2];
7813
+ //const int64_t ne03 = src0->ne[3];
7410
7814
 
7411
- const int ne10 = src1->ne[0];
7412
- const int ne11 = src1->ne[1];
7413
- //const int ne12 = src1->ne[2];
7414
- //const int ne13 = src1->ne[3];
7815
+ const int64_t ne10 = src1->ne[0];
7816
+ const int64_t ne11 = src1->ne[1];
7817
+ //const int64_t ne12 = src1->ne[2];
7818
+ //const int64_t ne13 = src1->ne[3];
7415
7819
 
7416
- //const int ne0 = dst->ne[0];
7417
- //const int ne1 = dst->ne[1];
7418
- //const int ne2 = dst->ne[2];
7419
- //const int ne3 = dst->ne[3];
7420
- //const int ne = ne0*ne1*ne2*ne3;
7820
+ //const int64_t ne0 = dst->ne[0];
7821
+ //const int64_t ne1 = dst->ne[1];
7822
+ //const int64_t ne2 = dst->ne[2];
7823
+ //const int64_t ne3 = dst->ne[3];
7824
+ //const int64_t ne = ne0*ne1*ne2*ne3;
7421
7825
 
7422
7826
  const int nb00 = src0->nb[0];
7423
7827
  const int nb01 = src0->nb[1];
@@ -7454,11 +7858,11 @@ static void ggml_compute_forward_conv_1d_2s_f32(
7454
7858
  {
7455
7859
  float * const wdata = (float *) params->wdata + 0;
7456
7860
 
7457
- for (int i02 = 0; i02 < ne02; i02++) {
7458
- for (int i01 = 0; i01 < ne01; i01++) {
7861
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
7862
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
7459
7863
  const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
7460
7864
  float * dst_data = wdata + i02*ew0*ne00;
7461
- for (int i00 = 0; i00 < ne00; i00++) {
7865
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
7462
7866
  dst_data[i00*ew0 + i01] = src[i00];
7463
7867
  }
7464
7868
  }
@@ -7469,10 +7873,10 @@ static void ggml_compute_forward_conv_1d_2s_f32(
7469
7873
  {
7470
7874
  float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
7471
7875
 
7472
- for (int i11 = 0; i11 < ne11; i11++) {
7876
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
7473
7877
  const float * const src = (float *)((char *) src1->data + i11*nb11);
7474
7878
  float * dst_data = wdata;
7475
- for (int i10 = 0; i10 < ne10; i10++) {
7879
+ for (int64_t i10 = 0; i10 < ne10; i10++) {
7476
7880
  dst_data[(i10 + nh)*ew0 + i11] = src[i10];
7477
7881
  }
7478
7882
  }
@@ -7497,7 +7901,7 @@ static void ggml_compute_forward_conv_1d_2s_f32(
7497
7901
 
7498
7902
  for (int i1 = ir0; i1 < ir1; i1++) {
7499
7903
  float * dst_data = (float *)((char *) dst->data + i1*nb1);
7500
- for (int i0 = 0; i0 < ne10; i0 += 2) {
7904
+ for (int64_t i0 = 0; i0 < ne10; i0 += 2) {
7501
7905
  dst_data[i0/2] = 0;
7502
7906
  for (int k = -nh; k <= nh; k++) {
7503
7907
  float v = 0.0f;
@@ -7549,25 +7953,25 @@ static void ggml_compute_forward_flash_attn_f32(
7549
7953
  int64_t t0 = ggml_perf_time_us();
7550
7954
  UNUSED(t0);
7551
7955
 
7552
- const int neq0 = q->ne[0];
7553
- const int neq1 = q->ne[1];
7554
- const int neq2 = q->ne[2];
7555
- const int neq3 = q->ne[3];
7956
+ const int64_t neq0 = q->ne[0];
7957
+ const int64_t neq1 = q->ne[1];
7958
+ const int64_t neq2 = q->ne[2];
7959
+ const int64_t neq3 = q->ne[3];
7556
7960
 
7557
- const int nek0 = k->ne[0];
7558
- const int nek1 = k->ne[1];
7559
- //const int nek2 = k->ne[2];
7560
- //const int nek3 = k->ne[3];
7961
+ const int64_t nek0 = k->ne[0];
7962
+ const int64_t nek1 = k->ne[1];
7963
+ //const int64_t nek2 = k->ne[2];
7964
+ //const int64_t nek3 = k->ne[3];
7561
7965
 
7562
- //const int nev0 = v->ne[0];
7563
- const int nev1 = v->ne[1];
7564
- //const int nev2 = v->ne[2];
7565
- //const int nev3 = v->ne[3];
7966
+ //const int64_t nev0 = v->ne[0];
7967
+ const int64_t nev1 = v->ne[1];
7968
+ //const int64_t nev2 = v->ne[2];
7969
+ //const int64_t nev3 = v->ne[3];
7566
7970
 
7567
- const int ne0 = dst->ne[0];
7568
- const int ne1 = dst->ne[1];
7569
- //const int ne2 = dst->ne[2];
7570
- //const int ne3 = dst->ne[3];
7971
+ const int64_t ne0 = dst->ne[0];
7972
+ const int64_t ne1 = dst->ne[1];
7973
+ //const int64_t ne2 = dst->ne[2];
7974
+ //const int64_t ne3 = dst->ne[3];
7571
7975
 
7572
7976
  const int nbk0 = k->nb[0];
7573
7977
  const int nbk1 = k->nb[1];
@@ -7592,10 +7996,10 @@ static void ggml_compute_forward_flash_attn_f32(
7592
7996
  const int ith = params->ith;
7593
7997
  const int nth = params->nth;
7594
7998
 
7595
- const int D = neq0;
7596
- const int N = neq1;
7597
- const int P = nek1 - N;
7598
- const int M = P + N;
7999
+ const int64_t D = neq0;
8000
+ const int64_t N = neq1;
8001
+ const int64_t P = nek1 - N;
8002
+ const int64_t M = P + N;
7599
8003
 
7600
8004
  const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
7601
8005
 
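Note on the renamed quantities: D (= neq0) reads as the per-head width of q, N (= neq1) as the number of query positions in this call, P = nek1 - N as the cached past positions, and M = P + N as the key positions each query row is scored against; Mup pads M up to a multiple of GGML_SOFT_MAX_UNROLL for the vectorized soft-max. A worked example with hypothetical sizes (the unroll factor of 4 is an assumption, not taken from this diff):

    #include <stdio.h>

    static int round_up(int x, int n) { return ((x + n - 1)/n)*n; }   // ggml_up-style

    int main(void) {
        const int neq0 = 128;   // D: per-head width           (hypothetical)
        const int neq1 = 7;     // N: new query positions      (hypothetical)
        const int nek1 = 39;    // keys available, past + new  (hypothetical)

        const int P = nek1 - neq1;    // 32 cached positions
        const int M = P + neq1;       // 39 scores per query row

        printf("D=%d N=%d P=%d M=%d Mup=%d\n",
               neq0, neq1, P, M, round_up(M, 4 /* assumed unroll */));
        return 0;
    }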
@@ -7657,7 +8061,7 @@ static void ggml_compute_forward_flash_attn_f32(
7657
8061
  S[i] = -INFINITY;
7658
8062
  }
7659
8063
 
7660
- for (int ic = 0; ic < nek1; ++ic) {
8064
+ for (int64_t ic = 0; ic < nek1; ++ic) {
7661
8065
  // k indices
7662
8066
  const int ik3 = iq3;
7663
8067
  const int ik2 = iq2;
@@ -7676,7 +8080,7 @@ static void ggml_compute_forward_flash_attn_f32(
7676
8080
  ggml_vec_scale_f32(nek1, S, scale);
7677
8081
 
7678
8082
  if (masked) {
7679
- for (int i = P; i < M; i++) {
8083
+ for (int64_t i = P; i < M; i++) {
7680
8084
  if (i > P + iq1) {
7681
8085
  S[i] = -INFINITY;
7682
8086
  }
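Note: the widened loop above is the causal mask: for query row iq1, scores at key positions beyond P + iq1 are set to -INFINITY before the soft-max, so attention cannot look at positions after the current token. The same logic as a stand-alone helper, with the int64_t bounds from the hunk:

    #include <math.h>
    #include <stdint.h>

    // Causal masking of one score row S of length M: positions P + iq1 + 1 .. M - 1
    // are forced to -INFINITY and vanish after the soft-max.
    static void mask_scores(float * S, int64_t P, int64_t M, int64_t iq1) {
        for (int64_t i = P; i < M; i++) {
            if (i > P + iq1) {
                S[i] = -INFINITY;
            }
        }
    }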
@@ -7734,7 +8138,7 @@ static void ggml_compute_forward_flash_attn_f32(
7734
8138
  #endif
7735
8139
  }
7736
8140
 
7737
- for (int ic = 0; ic < nev1; ++ic) {
8141
+ for (int64_t ic = 0; ic < nev1; ++ic) {
7738
8142
  // dst indices
7739
8143
  const int i1 = iq1;
7740
8144
  const int i2 = iq2;
@@ -7758,25 +8162,25 @@ static void ggml_compute_forward_flash_attn_f16(
7758
8162
  int64_t t0 = ggml_perf_time_us();
7759
8163
  UNUSED(t0);
7760
8164
 
7761
- const int neq0 = q->ne[0];
7762
- const int neq1 = q->ne[1];
7763
- const int neq2 = q->ne[2];
7764
- const int neq3 = q->ne[3];
8165
+ const int64_t neq0 = q->ne[0];
8166
+ const int64_t neq1 = q->ne[1];
8167
+ const int64_t neq2 = q->ne[2];
8168
+ const int64_t neq3 = q->ne[3];
7765
8169
 
7766
- const int nek0 = k->ne[0];
7767
- const int nek1 = k->ne[1];
7768
- //const int nek2 = k->ne[2];
7769
- //const int nek3 = k->ne[3];
8170
+ const int64_t nek0 = k->ne[0];
8171
+ const int64_t nek1 = k->ne[1];
8172
+ //const int64_t nek2 = k->ne[2];
8173
+ //const int64_t nek3 = k->ne[3];
7770
8174
 
7771
- //const int nev0 = v->ne[0];
7772
- const int nev1 = v->ne[1];
7773
- //const int nev2 = v->ne[2];
7774
- //const int nev3 = v->ne[3];
8175
+ //const int64_t nev0 = v->ne[0];
8176
+ const int64_t nev1 = v->ne[1];
8177
+ //const int64_t nev2 = v->ne[2];
8178
+ //const int64_t nev3 = v->ne[3];
7775
8179
 
7776
- const int ne0 = dst->ne[0];
7777
- const int ne1 = dst->ne[1];
7778
- //const int ne2 = dst->ne[2];
7779
- //const int ne3 = dst->ne[3];
8180
+ const int64_t ne0 = dst->ne[0];
8181
+ const int64_t ne1 = dst->ne[1];
8182
+ //const int64_t ne2 = dst->ne[2];
8183
+ //const int64_t ne3 = dst->ne[3];
7780
8184
 
7781
8185
  const int nbk0 = k->nb[0];
7782
8186
  const int nbk1 = k->nb[1];
@@ -7801,10 +8205,10 @@ static void ggml_compute_forward_flash_attn_f16(
7801
8205
  const int ith = params->ith;
7802
8206
  const int nth = params->nth;
7803
8207
 
7804
- const int D = neq0;
7805
- const int N = neq1;
7806
- const int P = nek1 - N;
7807
- const int M = P + N;
8208
+ const int64_t D = neq0;
8209
+ const int64_t N = neq1;
8210
+ const int64_t P = nek1 - N;
8211
+ const int64_t M = P + N;
7808
8212
 
7809
8213
  const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
7810
8214
 
@@ -7867,7 +8271,7 @@ static void ggml_compute_forward_flash_attn_f16(
7867
8271
  }
7868
8272
 
7869
8273
  if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
7870
- for (int ic = 0; ic < nek1; ++ic) {
8274
+ for (int64_t ic = 0; ic < nek1; ++ic) {
7871
8275
  // k indices
7872
8276
  const int ik3 = iq3;
7873
8277
  const int ik2 = iq2;
@@ -7882,7 +8286,7 @@ static void ggml_compute_forward_flash_attn_f16(
7882
8286
  (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
7883
8287
  }
7884
8288
  } else {
7885
- for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
8289
+ for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
7886
8290
  // k indices
7887
8291
  const int ik3 = iq3;
7888
8292
  const int ik2 = iq2;
@@ -7902,7 +8306,7 @@ static void ggml_compute_forward_flash_attn_f16(
7902
8306
  ggml_vec_scale_f32(nek1, S, scale);
7903
8307
 
7904
8308
  if (masked) {
7905
- for (int i = P; i < M; i++) {
8309
+ for (int64_t i = P; i < M; i++) {
7906
8310
  if (i > P + iq1) {
7907
8311
  S[i] = -INFINITY;
7908
8312
  }
@@ -7962,12 +8366,12 @@ static void ggml_compute_forward_flash_attn_f16(
7962
8366
 
7963
8367
  ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
7964
8368
 
7965
- for (int i = 0; i < M; i++) {
8369
+ for (int64_t i = 0; i < M; i++) {
7966
8370
  S16[i] = GGML_FP32_TO_FP16(S[i]);
7967
8371
  }
7968
8372
 
7969
8373
  if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
7970
- for (int ic = 0; ic < nev1; ++ic) {
8374
+ for (int64_t ic = 0; ic < nev1; ++ic) {
7971
8375
  // dst indices
7972
8376
  const int i1 = iq1;
7973
8377
  const int i2 = iq2;
@@ -7979,7 +8383,7 @@ static void ggml_compute_forward_flash_attn_f16(
7979
8383
  S16);
7980
8384
  }
7981
8385
  } else {
7982
- for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
8386
+ for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
7983
8387
  // dst indices
7984
8388
  const int i1 = iq1;
7985
8389
  const int i2 = iq2;
@@ -8035,35 +8439,35 @@ static void ggml_compute_forward_flash_ff_f16(
8035
8439
  int64_t t0 = ggml_perf_time_us();
8036
8440
  UNUSED(t0);
8037
8441
 
8038
- const int nea0 = a->ne[0];
8039
- const int nea1 = a->ne[1];
8040
- const int nea2 = a->ne[2];
8041
- const int nea3 = a->ne[3];
8442
+ const int64_t nea0 = a->ne[0];
8443
+ const int64_t nea1 = a->ne[1];
8444
+ const int64_t nea2 = a->ne[2];
8445
+ const int64_t nea3 = a->ne[3];
8042
8446
 
8043
- const int neb00 = b0->ne[0];
8044
- const int neb01 = b0->ne[1];
8045
- //const int neb02 = b0->ne[2];
8046
- //const int neb03 = b0->ne[3];
8447
+ const int64_t neb00 = b0->ne[0];
8448
+ const int64_t neb01 = b0->ne[1];
8449
+ //const int64_t neb02 = b0->ne[2];
8450
+ //const int64_t neb03 = b0->ne[3];
8047
8451
 
8048
- const int neb10 = b1->ne[0];
8049
- const int neb11 = b1->ne[1];
8050
- //const int neb12 = b1->ne[2];
8051
- //const int neb13 = b1->ne[3];
8452
+ const int64_t neb10 = b1->ne[0];
8453
+ const int64_t neb11 = b1->ne[1];
8454
+ //const int64_t neb12 = b1->ne[2];
8455
+ //const int64_t neb13 = b1->ne[3];
8052
8456
 
8053
- const int nec00 = c0->ne[0];
8054
- const int nec01 = c0->ne[1];
8055
- //const int nec02 = c0->ne[2];
8056
- //const int nec03 = c0->ne[3];
8457
+ const int64_t nec00 = c0->ne[0];
8458
+ const int64_t nec01 = c0->ne[1];
8459
+ //const int64_t nec02 = c0->ne[2];
8460
+ //const int64_t nec03 = c0->ne[3];
8057
8461
 
8058
- const int nec10 = c1->ne[0];
8059
- const int nec11 = c1->ne[1];
8060
- //const int nec12 = c1->ne[2];
8061
- //const int nec13 = c1->ne[3];
8462
+ const int64_t nec10 = c1->ne[0];
8463
+ const int64_t nec11 = c1->ne[1];
8464
+ //const int64_t nec12 = c1->ne[2];
8465
+ //const int64_t nec13 = c1->ne[3];
8062
8466
 
8063
- const int ne0 = dst->ne[0];
8064
- const int ne1 = dst->ne[1];
8065
- const int ne2 = dst->ne[2];
8066
- //const int ne3 = dst->ne[3];
8467
+ const int64_t ne0 = dst->ne[0];
8468
+ const int64_t ne1 = dst->ne[1];
8469
+ const int64_t ne2 = dst->ne[2];
8470
+ //const int64_t ne3 = dst->ne[3];
8067
8471
 
8068
8472
  const int nba0 = a->nb[0];
8069
8473
  const int nba1 = a->nb[1];
@@ -8098,9 +8502,9 @@ static void ggml_compute_forward_flash_ff_f16(
8098
8502
  const int ith = params->ith;
8099
8503
  const int nth = params->nth;
8100
8504
 
8101
- const int D = nea0;
8102
- //const int N = nea1;
8103
- const int M = neb01;
8505
+ const int64_t D = nea0;
8506
+ //const int64_t N = nea1;
8507
+ const int64_t M = neb01;
8104
8508
 
8105
8509
  GGML_ASSERT(ne0 == nea0);
8106
8510
  GGML_ASSERT(ne1 == nea1);
@@ -8156,7 +8560,7 @@ static void ggml_compute_forward_flash_ff_f16(
8156
8560
 
8157
8561
  float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
8158
8562
 
8159
- for (int ic = 0; ic < neb01; ++ic) {
8563
+ for (int64_t ic = 0; ic < neb01; ++ic) {
8160
8564
  // b0 indices
8161
8565
  const int ib03 = ia3;
8162
8566
  const int ib02 = ia2;
@@ -8176,7 +8580,7 @@ static void ggml_compute_forward_flash_ff_f16(
8176
8580
 
8177
8581
  ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
8178
8582
 
8179
- for (int i = 0; i < M; i++) {
8583
+ for (int64_t i = 0; i < M; i++) {
8180
8584
  S16[i] = GGML_FP32_TO_FP16(S[i]);
8181
8585
  }
8182
8586
 
@@ -8188,7 +8592,7 @@ static void ggml_compute_forward_flash_ff_f16(
8188
8592
  const int i2 = ia2;
8189
8593
  const int i3 = ia3;
8190
8594
 
8191
- for (int ic = 0; ic < nec01; ++ic) {
8595
+ for (int64_t ic = 0; ic < nec01; ++ic) {
8192
8596
 
8193
8597
  ggml_vec_dot_f16(neb01,
8194
8598
  (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
@@ -9053,7 +9457,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9053
9457
  } break;
9054
9458
  case GGML_OP_ROPE:
9055
9459
  {
9056
- node->n_tasks = 1;
9460
+ node->n_tasks = n_threads;
9057
9461
  } break;
9058
9462
  case GGML_OP_CONV_1D_1S:
9059
9463
  case GGML_OP_CONV_1D_2S:
@@ -9091,7 +9495,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9091
9495
 
9092
9496
  size_t cur = 0;
9093
9497
 
9094
- const int ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
9498
+ const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
9095
9499
 
9096
9500
  if (node->src1->type == GGML_TYPE_F32) {
9097
9501
  cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
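Note: this work-size estimate reserves one padded score row per task (the sizing block appears to belong to the flash-attention case; the case label itself is outside the shown context): ne11 is src1->ne[1] rounded up to a multiple of GGML_SOFT_MAX_UNROLL, and each task gets ne11 floats. A worked example with hypothetical numbers (the unroll factor of 4 is an assumption):

    #include <stdio.h>

    int main(void) {
        const long long ne11_src = 513;   // src1->ne[1]          (hypothetical)
        const long long unroll   = 4;     // GGML_SOFT_MAX_UNROLL (assumed)
        const long long n_tasks  = 8;     // worker count         (hypothetical)

        const long long ne11 = ((ne11_src + unroll - 1)/unroll)*unroll;   // ggml_up
        const long long cur  = (long long) sizeof(float)*ne11*n_tasks;

        printf("ne11 = %lld -> cur = %lld bytes\n", ne11, cur);   // 516 -> 16512
        return 0;
    }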
@@ -9350,7 +9754,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
9350
9754
 
9351
9755
  perf_total_per_op_us[node->op] += node->perf_time_us;
9352
9756
 
9353
- GGML_PRINT(" - %3d: [ %6d, %6d, %6d] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
9757
+ GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 ", %" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
9354
9758
  i,
9355
9759
  node->ne[0], node->ne[1], node->ne[2],
9356
9760
  GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
@@ -9364,7 +9768,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
9364
9768
  for (int i = 0; i < cgraph->n_leafs; i++) {
9365
9769
  struct ggml_tensor * node = cgraph->leafs[i];
9366
9770
 
9367
- GGML_PRINT(" - %3d: [ %6d, %6d] %8s\n",
9771
+ GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 "] %8s\n",
9368
9772
  i,
9369
9773
  node->ne[0], node->ne[1],
9370
9774
  GGML_OP_LABEL[node->op]);
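Note: with the ne fields now int64_t, the old %d / %6d conversions in the two logging hunks above (and the graph-dump hunks below) would no longer match their arguments; passing an int64_t where printf expects int is undefined behaviour, hence the switch to the PRId64 macros from <inttypes.h>. Minimal usage:

    #include <inttypes.h>
    #include <stdio.h>

    int main(void) {
        // PRId64 expands to the correct conversion specifier for int64_t
        // on the target platform ("ld", "lld", ...).
        int64_t ne0 = 4096, ne1 = 4096, ne2 = 64;   // hypothetical extents
        printf("[ %" PRId64 ", %" PRId64 ", %" PRId64 " ]\n", ne0, ne1, ne2);
        return 0;
    }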
@@ -9435,7 +9839,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
9435
9839
 
9436
9840
  fprintf(fp, " \"%p\" [ \
9437
9841
  style = filled; fillcolor = %s; shape = record; \
9438
- label=\"%d [%d, %d] | <x>%s",
9842
+ label=\"%d [%" PRId64 ", %" PRId64 "] | <x>%s",
9439
9843
  (void *) node, color,
9440
9844
  i, node->ne[0], node->ne[1],
9441
9845
  GGML_OP_SYMBOL[node->op]);
@@ -9460,7 +9864,7 @@ label=\"<x>%.1e\"; ]\n",
9460
9864
  } else {
9461
9865
  fprintf(fp, " \"%p\" [ \
9462
9866
  style = filled; fillcolor = %s; shape = record; \
9463
- label=\"<x>CONST %d [%d, %d]\"; ]\n",
9867
+ label=\"<x>CONST %d [%" PRId64 ", %" PRId64 "]\"; ]\n",
9464
9868
  (void *) node, color,
9465
9869
  i, node->ne[0], node->ne[1]);
9466
9870
  }
@@ -9524,9 +9928,9 @@ label=\"<x>CONST %d [%d, %d]\"; ]\n",
9524
9928
  static void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float * x) {
9525
9929
  int i = 0;
9526
9930
  for (int p = 0; p < np; ++p) {
9527
- const int ne = ggml_nelements(ps[p]) ;
9931
+ const int64_t ne = ggml_nelements(ps[p]) ;
9528
9932
  // TODO: add function to set tensor from array
9529
- for (int j = 0; j < ne; ++j) {
9933
+ for (int64_t j = 0; j < ne; ++j) {
9530
9934
  ggml_set_f32_1d(ps[p], j, x[i++]);
9531
9935
  }
9532
9936
  }
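Note: the optimizer-facing helpers here view all trainable tensors as one flat float vector; this hunk (and the two matching get hunks below) only widen the per-tensor element count and index to int64_t. The copy pattern on its own, using plain buffers as hypothetical stand-ins for ggml tensors:

    #include <stdint.h>

    // Hypothetical stand-in for a parameter tensor: a flat buffer plus its length.
    struct flat_param { float * data; int64_t ne; };

    // Scatter a packed optimizer vector x back into the parameter buffers,
    // mirroring the loop structure of ggml_opt_set_params above.
    static void params_from_vector(struct flat_param * ps, int np, const float * x) {
        int64_t i = 0;
        for (int p = 0; p < np; ++p) {
            for (int64_t j = 0; j < ps[p].ne; ++j) {
                ps[p].data[j] = x[i++];
            }
        }
    }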
@@ -9535,9 +9939,9 @@ static void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const f
9535
9939
  static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) {
9536
9940
  int i = 0;
9537
9941
  for (int p = 0; p < np; ++p) {
9538
- const int ne = ggml_nelements(ps[p]) ;
9942
+ const int64_t ne = ggml_nelements(ps[p]) ;
9539
9943
  // TODO: add function to get all elements at once
9540
- for (int j = 0; j < ne; ++j) {
9944
+ for (int64_t j = 0; j < ne; ++j) {
9541
9945
  x[i++] = ggml_get_f32_1d(ps[p], j);
9542
9946
  }
9543
9947
  }
@@ -9546,9 +9950,9 @@ static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float *
9546
9950
  static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
9547
9951
  int i = 0;
9548
9952
  for (int p = 0; p < np; ++p) {
9549
- const int ne = ggml_nelements(ps[p]) ;
9953
+ const int64_t ne = ggml_nelements(ps[p]) ;
9550
9954
  // TODO: add function to get all elements at once
9551
- for (int j = 0; j < ne; ++j) {
9955
+ for (int64_t j = 0; j < ne; ++j) {
9552
9956
  g[i++] = ggml_get_f32_1d(ps[p]->grad, j);
9553
9957
  }
9554
9958
  }
@@ -10146,6 +10550,7 @@ enum ggml_opt_result ggml_opt(
10146
10550
  struct ggml_init_params params_ctx = {
10147
10551
  .mem_size = 16*1024*1024,
10148
10552
  .mem_buffer = NULL,
10553
+ .no_alloc = false,
10149
10554
  };
10150
10555
 
10151
10556
  ctx = ggml_init(params_ctx);
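Note: the added .no_alloc field is new in this release's init parameters. With a C designated initializer, members that are not named are zero-initialized, so the field would already default to false here; naming it explicitly documents the intent and keeps the initializer in sync with the struct. A minimal stand-in illustrating the pattern (the struct below is a sketch with the same field names, not the library definition):

    #include <stdbool.h>
    #include <stddef.h>

    // Sketch of a designated initializer with the same three fields as above.
    // Unnamed members would be zero-initialized, so .no_alloc = false is the
    // default either way; spelling it out makes the choice explicit.
    struct init_params_sketch {
        size_t mem_size;
        void * mem_buffer;
        bool   no_alloc;
    };

    static const struct init_params_sketch opt_ctx_params = {
        .mem_size   = 16*1024*1024,
        .mem_buffer = NULL,
        .no_alloc   = false,
    };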