llama_cpp 0.0.1 → 0.0.3

Sign up to get free protection for your applications and to get access to all the features.
@@ -16,6 +16,7 @@
16
16
  #include <stdlib.h>
17
17
  #include <string.h>
18
18
  #include <stdint.h>
19
+ #include <inttypes.h>
19
20
  #include <stdio.h>
20
21
  #include <float.h>
21
22
 
@@ -79,6 +80,19 @@ static int sched_yield (void) {
79
80
  typedef void* thread_ret_t;
80
81
  #endif
81
82
 
83
+ // __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
84
+ #if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
85
+ #ifndef __FMA__
86
+ #define __FMA__
87
+ #endif
88
+ #ifndef __F16C__
89
+ #define __F16C__
90
+ #endif
91
+ #ifndef __SSE3__
92
+ #define __SSE3__
93
+ #endif
94
+ #endif
95
+
82
96
  #ifdef __HAIKU__
83
97
  #define static_assert(cond, msg) _Static_assert(cond, msg)
84
98
  #endif
@@ -172,8 +186,13 @@ typedef double ggml_float;
172
186
 
173
187
  #ifdef __F16C__
174
188
 
189
+ #ifdef _MSC_VER
190
+ #define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
191
+ #define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
192
+ #else
175
193
  #define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
176
194
  #define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
195
+ #endif
177
196
 
178
197
  #elif defined(__POWER9_VECTOR__)
179
198
 
@@ -443,6 +462,39 @@ static inline __m128i packNibbles( __m256i bytes )
443
462
  __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
444
463
  return _mm_packus_epi16( r0, r1 );
445
464
  }
465
+ #elif __AVX__
466
+ static inline __m128i bytesFromNibbles( const uint8_t* rsi )
467
+ {
468
+ // Load 8 bytes from memory
469
+ __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
470
+
471
+ // Expand bytes into uint16_t values
472
+ __m128i bytes = _mm_cvtepu8_epi16( tmp );
473
+
474
+ // Unpack values into individual bytes
475
+ const __m128i lowMask = _mm_set1_epi8( 0xF );
476
+ __m128i high = _mm_andnot_si128( lowMask, bytes );
477
+ __m128i low = _mm_and_si128( lowMask, bytes );
478
+ high = _mm_slli_epi16( high, 4 );
479
+ bytes = _mm_or_si128( low, high );
480
+ return bytes;
481
+ }
482
+
483
+ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
484
+ {
485
+ // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
486
+ const __m128i lowByte = _mm_set1_epi16( 0xFF );
487
+ __m128i high = _mm_andnot_si128( lowByte, bytes1 );
488
+ __m128i low = _mm_and_si128( lowByte, bytes1 );
489
+ high = _mm_srli_epi16( high, 4 );
490
+ bytes1 = _mm_or_si128( low, high );
491
+ high = _mm_andnot_si128( lowByte, bytes2 );
492
+ low = _mm_and_si128( lowByte, bytes2 );
493
+ high = _mm_srli_epi16( high, 4 );
494
+ bytes2 = _mm_or_si128( low, high );
495
+
496
+ return _mm_packus_epi16( bytes1, bytes2);
497
+ }
446
498
  #endif
447
499
 
448
500
  // method 5
@@ -491,8 +543,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
491
543
  const uint8_t vi0 = (int8_t)roundf(v0) + 8;
492
544
  const uint8_t vi1 = (int8_t)roundf(v1) + 8;
493
545
 
494
- assert(vi0 >= 0 && vi0 < 16);
495
- assert(vi1 >= 0 && vi1 < 16);
546
+ assert(vi0 < 16);
547
+ assert(vi1 < 16);
496
548
 
497
549
  pp[l/2] = vi0 | (vi1 << 4);
498
550
  }
@@ -546,10 +598,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
546
598
  }
547
599
  }
548
600
  #elif __ARM_NEON
549
- uint8_t pp[QK/2];
550
601
  for (int i = 0; i < nb; i++) {
551
- float amax = 0.0f; // absolute max
552
-
553
602
  float32x4_t srcv [8];
554
603
  float32x4_t asrcv[8];
555
604
  float32x4_t amaxv[8];
@@ -561,7 +610,8 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
561
610
  for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
562
611
  for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
563
612
 
564
- amax = MAX(
613
+ // absolute max
614
+ const float amax = MAX(
565
615
  MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)),
566
616
  MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
567
617
 
@@ -575,11 +625,9 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
575
625
  const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
576
626
  const int32x4_t vi = vcvtq_s32_f32(vf);
577
627
 
578
- pp[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
579
- pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
628
+ y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
629
+ y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
580
630
  }
581
-
582
- memcpy(y[i].qs, pp, sizeof(pp));
583
631
  }
584
632
  #elif defined(__AVX2__)
585
633
  for (int i = 0; i < nb; i++) {
@@ -646,8 +694,81 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
646
694
  __m128i res = packNibbles( i0 );
647
695
  _mm_storeu_si128( ( __m128i* )y[i].qs, res );
648
696
  }
697
+ #elif defined(__AVX__)
698
+ for (int i = 0; i < nb; i++) {
699
+ // Load elements into 4 AVX vectors
700
+ __m256 v0 = _mm256_loadu_ps( x );
701
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
702
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
703
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
704
+ x += 32;
705
+
706
+ // Compute max(abs(e)) for the block
707
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
708
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
709
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
710
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
711
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
712
+
713
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
714
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
715
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
716
+ const float maxScalar = _mm_cvtss_f32( max4 );
717
+
718
+ // Quantize these floats
719
+ const float d = maxScalar / 7.0f;
720
+ y[i].d = d;
721
+ const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
722
+ const __m256 mul = _mm256_set1_ps( id );
723
+
724
+ // Apply the multiplier
725
+ v0 = _mm256_mul_ps( v0, mul );
726
+ v1 = _mm256_mul_ps( v1, mul );
727
+ v2 = _mm256_mul_ps( v2, mul );
728
+ v3 = _mm256_mul_ps( v3, mul );
729
+
730
+ // Round to nearest integer
731
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
732
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
733
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
734
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
735
+
736
+ // Convert floats to integers
737
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
738
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
739
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
740
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
741
+
742
+ // Since we don't have in AVX some necessary functions,
743
+ // we split the registers in half and call AVX2 analogs from SSE
744
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
745
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
746
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
747
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
748
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
749
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
750
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
751
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
752
+
753
+ // Convert int32 to int16
754
+ ni0 = _mm_packs_epi32( ni0, ni1 );
755
+ ni2 = _mm_packs_epi32( ni2, ni3 );
756
+ ni4 = _mm_packs_epi32( ni4, ni5 );
757
+ ni6 = _mm_packs_epi32( ni6, ni7 );
758
+ // Convert int16 to int8
759
+ ni0 = _mm_packs_epi16( ni0, ni2 );
760
+ ni4 = _mm_packs_epi16( ni4, ni6 );
761
+
762
+ // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
763
+ const __m128i off = _mm_set1_epi8( 8);
764
+ ni0 = _mm_add_epi8( ni0, off );
765
+ ni4 = _mm_add_epi8( ni4, off );
766
+
767
+ // Compress the vector into 4 bit/value, and store
768
+ __m128i res = packNibbles( ni0, ni4 );
769
+ _mm_storeu_si128( ( __m128i* )y[i].qs, res );
770
+ }
649
771
  #elif defined(__wasm_simd128__)
650
- uint8_t pp[QK/2];
651
772
  for (int i = 0; i < nb; i++) {
652
773
  float amax = 0.0f; // absolute max
653
774
 
@@ -676,11 +797,9 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
676
797
  const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
677
798
  const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
678
799
 
679
- pp[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
680
- pp[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
800
+ y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
801
+ y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
681
802
  }
682
-
683
- memcpy(y[i].qs, pp, sizeof(pp));
684
803
  }
685
804
  #else
686
805
  // scalar
@@ -719,8 +838,8 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
719
838
  const uint8_t vi0 = roundf(v0);
720
839
  const uint8_t vi1 = roundf(v1);
721
840
 
722
- assert(vi0 >= 0 && vi0 < 16);
723
- assert(vi1 >= 0 && vi1 < 16);
841
+ assert(vi0 < 16);
842
+ assert(vi1 < 16);
724
843
 
725
844
  pp[l/2] = vi0 | (vi1 << 4);
726
845
  }
@@ -732,11 +851,11 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
732
851
  static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
733
852
  assert(k % QK == 0);
734
853
 
735
- #if defined(__AVX2__)
736
854
  const int nb = k / QK;
737
855
 
738
856
  block_q4_1 * restrict y = vy;
739
857
 
858
+ #if defined(__AVX2__)
740
859
  for (int i = 0; i < nb; i++) {
741
860
  // Load elements into 4 AVX vectors
742
861
  __m256 v0 = _mm256_loadu_ps( x );
@@ -810,6 +929,41 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
810
929
  __m128i res = packNibbles( i0 );
811
930
  _mm_storeu_si128( ( __m128i* )y[i].qs, res );
812
931
  }
932
+ #elif __ARM_NEON
933
+ for (int i = 0; i < nb; i++) {
934
+ float32x4_t srcv[8];
935
+ float32x4_t minv[8];
936
+ float32x4_t maxv[8];
937
+
938
+ for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
939
+
940
+ for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
941
+ for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
942
+ for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l + 4]);
943
+
944
+ for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l + 1]);
945
+ for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l + 2]);
946
+ for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l + 4]);
947
+
948
+ const float min = vminvq_f32(minv[0]);
949
+ const float max = vmaxvq_f32(maxv[0]);
950
+
951
+ const float d = (max - min) / ((1 << 4) - 1);
952
+ const float id = d ? 1.0f/d : 0.0f;
953
+
954
+ y[i].d = d;
955
+ y[i].m = min;
956
+
957
+ const float32x4_t minv0 = vdupq_n_f32(min);
958
+
959
+ for (int l = 0; l < 8; l++) {
960
+ const float32x4_t v = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id);
961
+ const int32x4_t vi = vcvtq_s32_f32(v);
962
+
963
+ y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
964
+ y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
965
+ }
966
+ }
813
967
  #else
814
968
  // scalar
815
969
  quantize_row_q4_1_reference(x, vy, k);
@@ -970,6 +1124,50 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
970
1124
  }
971
1125
  }
972
1126
  }
1127
+ #elif defined(__ARM_NEON)
1128
+ for (int i = 0; i < nb; i++) {
1129
+ const float32x4_t vd = vdupq_n_f32(x[i].d);
1130
+ const float32x4_t vm = vdupq_n_f32(x[i].m);
1131
+
1132
+ const uint8_t * restrict pp = x[i].qs;
1133
+
1134
+ for (int l = 0; l < QK; l += 16) {
1135
+ // Load 16x4-bit integers into 8x8-bit integers
1136
+ const uint8x8_t v8 = vld1_u8(pp + l/2);
1137
+
1138
+ // Expand 4-bit qs to 8-bit bytes
1139
+ const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
1140
+ const uint8x8_t v1 = vshr_n_u8(v8, 4);
1141
+
1142
+ // Interleave and combine
1143
+ const uint8x8_t vx_0 = vzip1_u8(v0, v1);
1144
+ const uint8x8_t vx_1 = vzip2_u8(v0, v1);
1145
+
1146
+ const uint8x16_t vq = vcombine_u8(vx_0, vx_1);
1147
+
1148
+ // convert to 2x uint16x8_t
1149
+ const uint16x8_t vi_0 = vmovl_u8(vget_low_u8 (vq));
1150
+ const uint16x8_t vi_1 = vmovl_u8(vget_high_u8(vq));
1151
+
1152
+ // convert to 4x float32x4_t
1153
+ const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0)));
1154
+ const float32x4_t vf_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_0)));
1155
+ const float32x4_t vf_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_1)));
1156
+ const float32x4_t vf_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_1)));
1157
+
1158
+ // multiply by d and add m
1159
+ const float32x4_t r0 = vmlaq_f32(vm, vf_0, vd);
1160
+ const float32x4_t r1 = vmlaq_f32(vm, vf_1, vd);
1161
+ const float32x4_t r2 = vmlaq_f32(vm, vf_2, vd);
1162
+ const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
1163
+
1164
+ // Store
1165
+ vst1q_f32(y + i*QK + l + 0, r0);
1166
+ vst1q_f32(y + i*QK + l + 4, r1);
1167
+ vst1q_f32(y + i*QK + l + 8, r2);
1168
+ vst1q_f32(y + i*QK + l + 12, r3);
1169
+ }
1170
+ }
973
1171
  #else
974
1172
  for (int i = 0; i < nb; i++) {
975
1173
  const float d = x[i].d;
@@ -1207,7 +1405,7 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
1207
1405
  _mm256_storeu_ps(arr, y);
1208
1406
 
1209
1407
  for (int i = 0; i < 8; i++)
1210
- x[i] = GGML_FP16_TO_FP32(arr[i]);
1408
+ x[i] = GGML_FP32_TO_FP16(arr[i]);
1211
1409
  }
1212
1410
  #define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x)
1213
1411
  #define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
@@ -1636,7 +1834,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1636
1834
  const block_q4_0 * restrict x = vx;
1637
1835
  const block_q4_0 * restrict y = vy;
1638
1836
 
1639
- ggml_float sumf = 0.0;
1837
+ float sumf = 0.0;
1640
1838
 
1641
1839
  #if defined(__ARM_NEON)
1642
1840
  float sum0 = 0.0f;
@@ -1731,7 +1929,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1731
1929
  #endif
1732
1930
  }
1733
1931
 
1734
- sumf = (ggml_float)(sum0 + sum1);
1932
+ sumf = sum0 + sum1;
1735
1933
  #elif defined(__AVX512F__)
1736
1934
  // Initialize accumulator with zeros
1737
1935
  __m512 acc0 = _mm512_setzero_ps();
@@ -1739,7 +1937,6 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1739
1937
 
1740
1938
  const int superblock_size = 8;
1741
1939
  const int superblock_count = nb / superblock_size;
1742
- const int remainder = nb % superblock_size;
1743
1940
 
1744
1941
  for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
1745
1942
  int i = superblock_ix * superblock_size;
@@ -1765,36 +1962,116 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1765
1962
  // Initialize accumulator with zeros
1766
1963
  __m256 acc = _mm256_setzero_ps();
1767
1964
 
1965
+ /* Prepare the constants we will need during execution */
1966
+ const __m256i lowMask = _mm256_set1_epi8( 0xF );
1967
+ const __m256i offset_8 = _mm256_set1_epi16( 8 );
1968
+
1969
+ #define UNROLL_COUNT 8
1970
+ // make sure we only unroll multiples of the block count
1971
+ assert(nb % UNROLL_COUNT == 0);
1972
+
1973
+ // Main loop
1974
+ for (int i = 0; i < nb; i+=UNROLL_COUNT) {
1975
+
1976
+ // This loop will be unrolled by the compiler
1977
+ for (int u=0;u<UNROLL_COUNT;u++) {
1978
+ /* Compute combined scale for the block */
1979
+ const __m256 scale = _mm256_mul_ps(
1980
+ _mm256_broadcast_ss( &x[i+u].d ),
1981
+ _mm256_broadcast_ss( &y[i+u].d ) );
1982
+
1983
+ /* get input from x
1984
+ Input: 32 Nibbles (16 bytes) at *x[i+u]
1985
+ Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
1986
+
1987
+ /* Load 16 bytes from memory */
1988
+ const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
1989
+ /* Expand bytes into uint16_t values */
1990
+ const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
1991
+ /* Unpack values into individual bytes */
1992
+ __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
1993
+ const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
1994
+ __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
1995
+ /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
1996
+ x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
1997
+ x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
1998
+
1999
+ /* get input from y
2000
+ Input: 32 Nibbles (16 bytes) at *y[i+u]
2001
+ Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
2002
+
2003
+ /* Load 16 bytes from memory */
2004
+ const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
2005
+ /* Expand bytes into uint16_t values */
2006
+ const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
2007
+ /* Unpack values into individual bytes */
2008
+ const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
2009
+ __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
2010
+ __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
2011
+ /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
2012
+ y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
2013
+ y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
2014
+
2015
+ /* Compute products of int16_t integers, add pairwise, store as int32_t */
2016
+ __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
2017
+ __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
2018
+
2019
+ /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
2020
+ __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
2021
+
2022
+ /* Convert to vectore of 8 int32_t to 8 floats */
2023
+ __m256 q = _mm256_cvtepi32_ps( xy_q );
2024
+
2025
+ /* Multiply q with scale and accumulate */
2026
+ acc = _mm256_fmadd_ps( scale, q, acc );
2027
+ }
2028
+
2029
+ }
2030
+
2031
+ // Return horizontal sum of the acc vector
2032
+ __m128 res = _mm256_extractf128_ps( acc, 1 );
2033
+ res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2034
+ res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2035
+ res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2036
+
2037
+ sumf = _mm_cvtss_f32( res );
2038
+ #elif defined(__AVX__)
2039
+ // Initialize accumulator with zeros
2040
+ __m256 acc = _mm256_setzero_ps();
2041
+
1768
2042
  // Main loop
1769
2043
  for (int i = 0; i < nb; ++i) {
1770
2044
  // Compute combined scale for the block
1771
2045
  const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
1772
2046
 
1773
- // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
1774
- __m256i bx = bytesFromNibbles( x[i].qs );
1775
- __m256i by = bytesFromNibbles( y[i].qs );
2047
+ __m128i i32[2];
2048
+ for (int j = 0; j < 2; ++j) {
2049
+ // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2050
+ __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
2051
+ __m128i by = bytesFromNibbles( y[i].qs + 8*j );
1776
2052
 
1777
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
1778
- const __m256i off = _mm256_set1_epi8( 8 );
1779
- bx = _mm256_sub_epi8( bx, off );
1780
- by = _mm256_sub_epi8( by, off );
2053
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2054
+ const __m128i off = _mm_set1_epi8( 8 );
2055
+ bx = _mm_sub_epi8( bx, off );
2056
+ by = _mm_sub_epi8( by, off );
1781
2057
 
1782
- // Sign-extend first 16 signed bytes into int16_t
1783
- __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
1784
- __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
1785
- // Compute products of int16_t integers, add pairwise
1786
- __m256i i32 = _mm256_madd_epi16( x16, y16 );
2058
+ // Get absolute values of x vectors
2059
+ const __m128i ax = _mm_sign_epi8(bx, bx);
1787
2060
 
1788
- // Sign-extend last 16 signed bytes into int16_t vectors
1789
- x16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
1790
- y16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
1791
- // Accumulate products of int16_t integers
1792
- i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16, y16 ) );
2061
+ // Sign the values of the y vectors
2062
+ const __m128i sy = _mm_sign_epi8(by, bx);
2063
+
2064
+ // Perform multiplication and create 16-bit values
2065
+ const __m128i dot = _mm_maddubs_epi16(ax, sy);
2066
+
2067
+ const __m128i ones = _mm_set1_epi16(1);
2068
+ i32[j] = _mm_madd_epi16(ones, dot);
2069
+ }
1793
2070
 
1794
2071
  // Convert int32_t to float
1795
- __m256 p = _mm256_cvtepi32_ps( i32 );
2072
+ __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
1796
2073
  // Apply the scale, and accumulate
1797
- acc = _mm256_fmadd_ps( d, p, acc );
2074
+ acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
1798
2075
  }
1799
2076
 
1800
2077
  // Return horizontal sum of the acc vector
@@ -1944,7 +2221,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
1944
2221
  // Compute cross scales for the block
1945
2222
  const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
1946
2223
  const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
1947
- const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0b10101010 );
2224
+ const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
1948
2225
 
1949
2226
  // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
1950
2227
  __m256i bx = bytesFromNibbles( x[i].qs );
@@ -1990,6 +2267,45 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
1990
2267
  res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
1991
2268
 
1992
2269
  sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
2270
+ #elif defined(__ARM_NEON)
2271
+ float sum00 = 0.0f;
2272
+ float sum01 = 0.0f;
2273
+ float sum10 = 0.0f;
2274
+ float sum11 = 0.0f;
2275
+
2276
+ for (int i = 0; i < nb; ++i) {
2277
+ const block_q4_1 * restrict x0 = &x[i + 0];
2278
+ const block_q4_1 * restrict y0 = &y[i + 0];
2279
+
2280
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2281
+
2282
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2283
+ const uint8x16_t v1_0 = vld1q_u8(y0->qs);
2284
+
2285
+ // and with 0xf
2286
+ const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
2287
+ const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
2288
+
2289
+ const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
2290
+ const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
2291
+
2292
+ // dot product into uint16x8_t
2293
+ const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
2294
+ const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
2295
+
2296
+ const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
2297
+ const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
2298
+
2299
+ const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
2300
+ const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
2301
+
2302
+ sum00 += x0->m*y0->m;
2303
+ sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
2304
+ sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
2305
+ sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
2306
+ }
2307
+
2308
+ sumf = QK*sum00 + sum01 + sum10 + sum11;
1993
2309
  #else
1994
2310
  // scalar
1995
2311
  for (int i = 0; i < nb; i++) {
@@ -2401,8 +2717,9 @@ struct ggml_context {
2401
2717
  void * mem_buffer;
2402
2718
  bool mem_buffer_owned;
2403
2719
  bool mem_buffer_mlocked;
2720
+ bool no_alloc;
2404
2721
 
2405
- int n_objects;
2722
+ int n_objects;
2406
2723
 
2407
2724
  struct ggml_object * objects_begin;
2408
2725
  struct ggml_object * objects_end;
@@ -2487,7 +2804,7 @@ void ggml_print_objects(const struct ggml_context * ctx) {
2487
2804
  GGML_PRINT("%s: --- end ---\n", __func__);
2488
2805
  }
2489
2806
 
2490
- int ggml_nelements(const struct ggml_tensor * tensor) {
2807
+ int64_t ggml_nelements(const struct ggml_tensor * tensor) {
2491
2808
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
2492
2809
 
2493
2810
  return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
@@ -2619,6 +2936,9 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
2619
2936
  static bool is_first_call = true;
2620
2937
 
2621
2938
  if (is_first_call) {
2939
+ // initialize time system (required on Windows)
2940
+ ggml_time_init();
2941
+
2622
2942
  // initialize GELU, SILU and EXP F32 tables
2623
2943
  {
2624
2944
  const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
@@ -2684,6 +3004,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
2684
3004
  /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
2685
3005
  /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
2686
3006
  /*.mem_buffer_mlocked =*/ false,
3007
+ /*.no_alloc =*/ params.no_alloc,
2687
3008
  /*.n_objects =*/ 0,
2688
3009
  /*.objects_begin =*/ NULL,
2689
3010
  /*.objects_end =*/ NULL,
@@ -2751,36 +3072,47 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch)
2751
3072
  return result;
2752
3073
  }
2753
3074
 
3075
+ #ifdef __APPLE__
3076
+ #define MLOCK_SUGGESTION \
3077
+ "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
3078
+ "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n"
3079
+ #else
3080
+ #define MLOCK_SUGGESTION \
3081
+ "Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n"
3082
+ #endif
3083
+
2754
3084
  bool ggml_mlock_supported(void) {
2755
3085
  return GGML_MLOCK_SUPPORT;
2756
3086
  }
2757
3087
 
3088
+ bool ggml_mlock(
3089
+ struct ggml_context * ctx,
3090
+ const void *opt_extra_addr,
3091
+ size_t opt_extra_len,
3092
+ char **err_p) {
3093
+ // TODO: Use SetProcessWorkingSetSize() + VirtualLock() on WIN32
2758
3094
  #if GGML_MLOCK_SUPPORT
2759
- #ifdef __APPLE__
2760
- #define MLOCK_SUGGESTION "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or\n" \
2761
- "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l)."
2762
- #else
2763
- #define MLOCK_SUGGESTION "Try increasing RLIMIT_MLOCK (ulimit -l)."
2764
- #endif
2765
- bool ggml_mlock(struct ggml_context * ctx, char ** err_p) {
2766
3095
  if (ctx->mem_buffer_mlocked) {
2767
3096
  return true;
2768
3097
  }
2769
- if (mlock(ctx->mem_buffer, ctx->mem_size)) {
2770
- int ret = asprintf(err_p, "failed to mlock %zu-byte buffer: %s\n" MLOCK_SUGGESTION,
2771
- ctx->mem_size, strerror(errno));
2772
- GGML_ASSERT(ret >= 0);
3098
+ if (mlock(ctx->mem_buffer, ctx->mem_size) ||
3099
+ (opt_extra_len &&
3100
+ mlock(opt_extra_addr, opt_extra_len))) {
3101
+ if ((*err_p = malloc(1024))) {
3102
+ snprintf(*err_p, 1024,
3103
+ "failed to mlock %zu-byte buffer: %s\n" MLOCK_SUGGESTION,
3104
+ ctx->mem_size + opt_extra_len,
3105
+ strerror(errno));
3106
+ }
2773
3107
  return false;
2774
3108
  }
2775
3109
  ctx->mem_buffer_mlocked = true;
2776
3110
  return true;
2777
- }
2778
3111
  #else // GGML_MLOCK_SUPPORT
2779
- bool ggml_mlock(struct ggml_context * ctx, char ** err_p) {
2780
3112
  *err_p = strdup("can't mlock because it's not supported on this system");
2781
3113
  return false;
2782
- }
2783
3114
  #endif // GGML_MLOCK_SUPPORT
3115
+ }
2784
3116
 
2785
3117
  ////////////////////////////////////////////////////////////////////////////////
2786
3118
 
@@ -2788,7 +3120,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
2788
3120
  struct ggml_context * ctx,
2789
3121
  enum ggml_type type,
2790
3122
  int n_dims,
2791
- const int* ne,
3123
+ const int64_t* ne,
2792
3124
  void* data) {
2793
3125
  // always insert objects at the end of the context's memory pool
2794
3126
  struct ggml_object * obj_cur = ctx->objects_end;
@@ -2799,7 +3131,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
2799
3131
 
2800
3132
  size_t size_needed = 0;
2801
3133
 
2802
- if (data == NULL) {
3134
+ if (data == NULL && !ctx->no_alloc) {
2803
3135
  size_needed += GGML_TYPE_SIZE[type]*(ne[0]/GGML_BLCK_SIZE[type]);
2804
3136
  for (int i = 1; i < n_dims; i++) {
2805
3137
  size_needed *= ne[i];
@@ -2883,11 +3215,12 @@ struct ggml_tensor * ggml_new_tensor_impl(
2883
3215
  /*.perf_runs =*/ 0,
2884
3216
  /*.perf_cycles =*/ 0,
2885
3217
  /*.perf_time_us =*/ 0,
2886
- /*.data =*/ data == NULL ? (void *)(result + 1) : data,
3218
+ /*.data =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
2887
3219
  /*.pad =*/ { 0 },
2888
3220
  };
2889
3221
 
2890
- ggml_assert_aligned(result->data);
3222
+ // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
3223
+ //ggml_assert_aligned(result->data);
2891
3224
 
2892
3225
  for (int i = 0; i < n_dims; i++) {
2893
3226
  result->ne[i] = ne[i];
@@ -2908,44 +3241,44 @@ struct ggml_tensor * ggml_new_tensor(
2908
3241
  struct ggml_context * ctx,
2909
3242
  enum ggml_type type,
2910
3243
  int n_dims,
2911
- const int * ne) {
3244
+ const int64_t * ne) {
2912
3245
  return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL);
2913
3246
  }
2914
3247
 
2915
3248
  struct ggml_tensor * ggml_new_tensor_1d(
2916
3249
  struct ggml_context * ctx,
2917
3250
  enum ggml_type type,
2918
- int ne0) {
3251
+ int64_t ne0) {
2919
3252
  return ggml_new_tensor(ctx, type, 1, &ne0);
2920
3253
  }
2921
3254
 
2922
3255
  struct ggml_tensor * ggml_new_tensor_2d(
2923
3256
  struct ggml_context * ctx,
2924
3257
  enum ggml_type type,
2925
- int ne0,
2926
- int ne1) {
2927
- const int ne[2] = { ne0, ne1 };
3258
+ int64_t ne0,
3259
+ int64_t ne1) {
3260
+ const int64_t ne[2] = { ne0, ne1 };
2928
3261
  return ggml_new_tensor(ctx, type, 2, ne);
2929
3262
  }
2930
3263
 
2931
3264
  struct ggml_tensor * ggml_new_tensor_3d(
2932
3265
  struct ggml_context * ctx,
2933
3266
  enum ggml_type type,
2934
- int ne0,
2935
- int ne1,
2936
- int ne2) {
2937
- const int ne[3] = { ne0, ne1, ne2 };
3267
+ int64_t ne0,
3268
+ int64_t ne1,
3269
+ int64_t ne2) {
3270
+ const int64_t ne[3] = { ne0, ne1, ne2 };
2938
3271
  return ggml_new_tensor(ctx, type, 3, ne);
2939
3272
  }
2940
3273
 
2941
3274
  struct ggml_tensor * ggml_new_tensor_4d(
2942
3275
  struct ggml_context * ctx,
2943
3276
  enum ggml_type type,
2944
- int ne0,
2945
- int ne1,
2946
- int ne2,
2947
- int ne3) {
2948
- const int ne[4] = { ne0, ne1, ne2, ne3 };
3277
+ int64_t ne0,
3278
+ int64_t ne1,
3279
+ int64_t ne2,
3280
+ int64_t ne3) {
3281
+ const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
2949
3282
  return ggml_new_tensor(ctx, type, 4, ne);
2950
3283
  }
2951
3284
 
@@ -3288,7 +3621,14 @@ float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
3288
3621
  struct ggml_tensor * ggml_view_tensor(
3289
3622
  struct ggml_context * ctx,
3290
3623
  const struct ggml_tensor * src) {
3291
- return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data);
3624
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data);
3625
+
3626
+ result->nb[0] = src->nb[0];
3627
+ result->nb[1] = src->nb[1];
3628
+ result->nb[2] = src->nb[2];
3629
+ result->nb[3] = src->nb[3];
3630
+
3631
+ return result;
3292
3632
  }
3293
3633
 
3294
3634
  ////////////////////////////////////////////////////////////////////////////////
@@ -3592,7 +3932,7 @@ struct ggml_tensor * ggml_mean(
3592
3932
  is_node = true;
3593
3933
  }
3594
3934
 
3595
- int ne[GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] };
3935
+ int64_t ne[GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] };
3596
3936
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, ne);
3597
3937
 
3598
3938
  result->op = GGML_OP_MEAN;
@@ -3953,7 +4293,7 @@ struct ggml_tensor * ggml_mul_mat(
3953
4293
  is_node = true;
3954
4294
  }
3955
4295
 
3956
- const int ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
4296
+ const int64_t ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
3957
4297
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
3958
4298
 
3959
4299
  result->op = GGML_OP_MUL_MAT;
@@ -4078,8 +4418,8 @@ struct ggml_tensor * ggml_reshape(
4078
4418
  struct ggml_tensor * ggml_reshape_2d(
4079
4419
  struct ggml_context * ctx,
4080
4420
  struct ggml_tensor * a,
4081
- int ne0,
4082
- int ne1) {
4421
+ int64_t ne0,
4422
+ int64_t ne1) {
4083
4423
  GGML_ASSERT(ggml_is_contiguous(a));
4084
4424
  GGML_ASSERT(ggml_nelements(a) == ne0*ne1);
4085
4425
 
@@ -4090,7 +4430,7 @@ struct ggml_tensor * ggml_reshape_2d(
4090
4430
  is_node = true;
4091
4431
  }
4092
4432
 
4093
- const int ne[2] = { ne0, ne1 };
4433
+ const int64_t ne[2] = { ne0, ne1 };
4094
4434
  struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a->data);
4095
4435
 
4096
4436
  result->op = GGML_OP_RESHAPE;
@@ -4104,9 +4444,9 @@ struct ggml_tensor * ggml_reshape_2d(
4104
4444
  struct ggml_tensor * ggml_reshape_3d(
4105
4445
  struct ggml_context * ctx,
4106
4446
  struct ggml_tensor * a,
4107
- int ne0,
4108
- int ne1,
4109
- int ne2) {
4447
+ int64_t ne0,
4448
+ int64_t ne1,
4449
+ int64_t ne2) {
4110
4450
  GGML_ASSERT(ggml_is_contiguous(a));
4111
4451
  GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2);
4112
4452
 
@@ -4117,7 +4457,7 @@ struct ggml_tensor * ggml_reshape_3d(
4117
4457
  is_node = true;
4118
4458
  }
4119
4459
 
4120
- const int ne[3] = { ne0, ne1, ne2 };
4460
+ const int64_t ne[3] = { ne0, ne1, ne2 };
4121
4461
  struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a->data);
4122
4462
 
4123
4463
  result->op = GGML_OP_RESHAPE;
@@ -4133,7 +4473,7 @@ struct ggml_tensor * ggml_reshape_3d(
4133
4473
  struct ggml_tensor * ggml_view_1d(
4134
4474
  struct ggml_context * ctx,
4135
4475
  struct ggml_tensor * a,
4136
- int ne0,
4476
+ int64_t ne0,
4137
4477
  size_t offset) {
4138
4478
  if (a->grad) {
4139
4479
  GGML_ASSERT(false); // gradient propagation is not supported
@@ -4154,15 +4494,15 @@ struct ggml_tensor * ggml_view_1d(
4154
4494
  struct ggml_tensor * ggml_view_2d(
4155
4495
  struct ggml_context * ctx,
4156
4496
  struct ggml_tensor * a,
4157
- int ne0,
4158
- int ne1,
4497
+ int64_t ne0,
4498
+ int64_t ne1,
4159
4499
  size_t nb1,
4160
4500
  size_t offset) {
4161
4501
  if (a->grad) {
4162
4502
  GGML_ASSERT(false); // gradient propagation is not supported
4163
4503
  }
4164
4504
 
4165
- const int ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 };
4505
+ const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 };
4166
4506
 
4167
4507
  struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset);
4168
4508
 
@@ -4178,6 +4518,37 @@ struct ggml_tensor * ggml_view_2d(
4178
4518
  return result;
4179
4519
  }
4180
4520
 
4521
+ // ggml_view_3d
4522
+
4523
+ struct ggml_tensor * ggml_view_3d(
4524
+ struct ggml_context * ctx,
4525
+ struct ggml_tensor * a,
4526
+ int64_t ne0,
4527
+ int64_t ne1,
4528
+ int64_t ne2,
4529
+ size_t nb1,
4530
+ size_t nb2,
4531
+ size_t offset) {
4532
+ if (a->grad) {
4533
+ GGML_ASSERT(false); // gradient propagation is not supported
4534
+ }
4535
+
4536
+ const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, 1 };
4537
+
4538
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset);
4539
+
4540
+ result->nb[1] = nb1;
4541
+ result->nb[2] = nb2;
4542
+ result->nb[3] = result->nb[2]*ne2;
4543
+
4544
+ result->op = GGML_OP_VIEW;
4545
+ result->grad = NULL;
4546
+ result->src0 = a;
4547
+ result->src1 = NULL; // TODO: maybe store the offset here?
4548
+
4549
+ return result;
4550
+ }
4551
+
4181
4552
  // ggml_permute
4182
4553
 
4183
4554
  struct ggml_tensor * ggml_permute(
@@ -4393,7 +4764,7 @@ struct ggml_tensor * ggml_conv_1d_1s(
4393
4764
  is_node = true;
4394
4765
  }
4395
4766
 
4396
- const int ne[4] = { b->ne[0], a->ne[2], 1, 1, };
4767
+ const int64_t ne[4] = { b->ne[0], a->ne[2], 1, 1, };
4397
4768
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
4398
4769
 
4399
4770
  result->op = GGML_OP_CONV_1D_1S;
@@ -4420,7 +4791,7 @@ struct ggml_tensor * ggml_conv_1d_2s(
4420
4791
  is_node = true;
4421
4792
  }
4422
4793
 
4423
- const int ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, };
4794
+ const int64_t ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, };
4424
4795
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
4425
4796
 
4426
4797
  result->op = GGML_OP_CONV_1D_2S;
@@ -4513,102 +4884,112 @@ static void ggml_compute_forward_dup_f16(
4513
4884
  const struct ggml_tensor * src0,
4514
4885
  struct ggml_tensor * dst) {
4515
4886
  GGML_ASSERT(params->ith == 0);
4516
- GGML_ASSERT(ggml_is_contiguous(dst));
4517
4887
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
4518
4888
 
4519
4889
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
4520
4890
  return;
4521
4891
  }
4522
4892
 
4523
- const int ne00 = src0->ne[0];
4524
- const int ne01 = src0->ne[1];
4525
- const int ne02 = src0->ne[2];
4526
- const int ne03 = src0->ne[3];
4893
+ const int64_t ne00 = src0->ne[0];
4894
+ const int64_t ne01 = src0->ne[1];
4895
+ const int64_t ne02 = src0->ne[2];
4896
+ const int64_t ne03 = src0->ne[3];
4527
4897
 
4528
4898
  const size_t nb00 = src0->nb[0];
4529
4899
  const size_t nb01 = src0->nb[1];
4530
4900
  const size_t nb02 = src0->nb[2];
4531
4901
  const size_t nb03 = src0->nb[3];
4532
4902
 
4533
- if (ggml_is_contiguous(src0) && src0->type == dst->type) {
4903
+ const size_t nb0 = dst->nb[0];
4904
+ const size_t nb1 = dst->nb[1];
4905
+ const size_t nb2 = dst->nb[2];
4906
+ const size_t nb3 = dst->nb[3];
4907
+
4908
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
4534
4909
  memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
4535
4910
  return;
4536
4911
  }
4537
4912
 
4538
- if (src0->nb[0] == sizeof(ggml_fp16_t)) {
4539
- if (dst->type == GGML_TYPE_F16) {
4540
- size_t id = 0;
4541
- const size_t rs = ne00*nb00;
4542
-
4543
- for (int i03 = 0; i03 < ne03; i03++) {
4544
- for (int i02 = 0; i02 < ne02; i02++) {
4545
- for (int i01 = 0; i01 < ne01; i01++) {
4546
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
4547
- char * dst_ptr = (char *) dst->data + id*rs;
4548
-
4549
- memcpy(dst_ptr, src0_ptr, rs);
4550
-
4551
- id++;
4552
- }
4913
+ if (src0->type == dst->type &&
4914
+ src0->ne[0] == dst->ne[0] &&
4915
+ src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
4916
+ // copy by rows
4917
+ const size_t rs = ne00*nb00;
4918
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
4919
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
4920
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
4921
+ memcpy(
4922
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
4923
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
4924
+ rs);
4553
4925
  }
4554
4926
  }
4555
- } else if (dst->type == GGML_TYPE_F32) {
4556
- size_t id = 0;
4557
- float * dst_ptr = (float *) dst->data;
4558
-
4559
- for (int i03 = 0; i03 < ne03; i03++) {
4560
- for (int i02 = 0; i02 < ne02; i02++) {
4561
- for (int i01 = 0; i01 < ne01; i01++) {
4562
- for (int i00 = 0; i00 < ne00; i00++) {
4563
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4564
-
4565
- dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
4566
- id++;
4567
- }
4568
- }
4569
- }
4570
- }
4571
- } else {
4572
- GGML_ASSERT(false); // TODO: implement
4573
4927
  }
4574
- } else {
4575
- //printf("%s: this is not optimal - fix me\n", __func__);
4576
-
4577
- if (dst->type == GGML_TYPE_F32) {
4578
- size_t id = 0;
4579
- float * dst_ptr = (float *) dst->data;
4580
-
4581
- for (int i03 = 0; i03 < ne03; i03++) {
4582
- for (int i02 = 0; i02 < ne02; i02++) {
4583
- for (int i01 = 0; i01 < ne01; i01++) {
4584
- for (int i00 = 0; i00 < ne00; i00++) {
4585
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4928
+ return;
4929
+ }
4586
4930
 
4587
- dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
4588
- id++;
4931
+ // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
4932
+
4933
+ // dst counters
4934
+ int64_t i10 = 0;
4935
+ int64_t i11 = 0;
4936
+ int64_t i12 = 0;
4937
+ int64_t i13 = 0;
4938
+
4939
+ if (dst->type == GGML_TYPE_F16) {
4940
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
4941
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
4942
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
4943
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
4944
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4945
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
4946
+
4947
+ memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
4948
+
4949
+ if (++i10 == ne00) {
4950
+ i10 = 0;
4951
+ if (++i11 == ne01) {
4952
+ i11 = 0;
4953
+ if (++i12 == ne02) {
4954
+ i12 = 0;
4955
+ if (++i13 == ne03) {
4956
+ i13 = 0;
4957
+ }
4958
+ }
4959
+ }
4589
4960
  }
4590
4961
  }
4591
4962
  }
4592
4963
  }
4593
- } else if (dst->type == GGML_TYPE_F16) {
4594
- size_t id = 0;
4595
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
4596
-
4597
- for (int i03 = 0; i03 < ne03; i03++) {
4598
- for (int i02 = 0; i02 < ne02; i02++) {
4599
- for (int i01 = 0; i01 < ne01; i01++) {
4600
- for (int i00 = 0; i00 < ne00; i00++) {
4601
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4602
-
4603
- dst_ptr[id] = *src0_ptr;
4604
- id++;
4964
+ }
4965
+ } else if (dst->type == GGML_TYPE_F32) {
4966
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
4967
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
4968
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
4969
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
4970
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4971
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
4972
+
4973
+ *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
4974
+
4975
+ if (++i10 == ne00) {
4976
+ i10 = 0;
4977
+ if (++i11 == ne01) {
4978
+ i11 = 0;
4979
+ if (++i12 == ne02) {
4980
+ i12 = 0;
4981
+ if (++i13 == ne03) {
4982
+ i13 = 0;
4983
+ }
4984
+ }
4985
+ }
4605
4986
  }
4606
4987
  }
4607
4988
  }
4608
4989
  }
4609
- } else {
4610
- GGML_ASSERT(false); // TODO: implement
4611
4990
  }
4991
+ } else {
4992
+ GGML_ASSERT(false); // TODO: implement
4612
4993
  }
4613
4994
  }
4614
4995
 
@@ -4617,102 +4998,92 @@ static void ggml_compute_forward_dup_f32(
4617
4998
  const struct ggml_tensor * src0,
4618
4999
  struct ggml_tensor * dst) {
4619
5000
  GGML_ASSERT(params->ith == 0);
4620
- GGML_ASSERT(ggml_is_contiguous(dst));
4621
5001
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
4622
5002
 
4623
5003
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
4624
5004
  return;
4625
5005
  }
4626
5006
 
4627
- const int ne00 = src0->ne[0];
4628
- const int ne01 = src0->ne[1];
4629
- const int ne02 = src0->ne[2];
4630
- const int ne03 = src0->ne[3];
5007
+ const int64_t ne00 = src0->ne[0];
5008
+ const int64_t ne01 = src0->ne[1];
5009
+ const int64_t ne02 = src0->ne[2];
5010
+ const int64_t ne03 = src0->ne[3];
4631
5011
 
4632
5012
  const size_t nb00 = src0->nb[0];
4633
5013
  const size_t nb01 = src0->nb[1];
4634
5014
  const size_t nb02 = src0->nb[2];
4635
5015
  const size_t nb03 = src0->nb[3];
4636
5016
 
4637
- if (ggml_is_contiguous(src0) && src0->type == dst->type) {
5017
+ const size_t nb0 = dst->nb[0];
5018
+ const size_t nb1 = dst->nb[1];
5019
+ const size_t nb2 = dst->nb[2];
5020
+ const size_t nb3 = dst->nb[3];
5021
+
5022
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
4638
5023
  memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
4639
5024
  return;
4640
5025
  }
4641
5026
 
4642
- if (src0->nb[0] == sizeof(float)) {
4643
- if (dst->type == GGML_TYPE_F32) {
4644
- size_t id = 0;
4645
- const size_t rs = ne00*nb00;
4646
-
4647
- for (int i03 = 0; i03 < ne03; i03++) {
4648
- for (int i02 = 0; i02 < ne02; i02++) {
4649
- for (int i01 = 0; i01 < ne01; i01++) {
4650
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
4651
- char * dst_ptr = (char *) dst->data + id*rs;
4652
-
4653
- memcpy(dst_ptr, src0_ptr, rs);
4654
-
4655
- id++;
4656
- }
4657
- }
4658
- }
4659
- } else if (dst->type == GGML_TYPE_F16) {
4660
- size_t id = 0;
4661
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
4662
-
4663
- for (int i03 = 0; i03 < ne03; i03++) {
4664
- for (int i02 = 0; i02 < ne02; i02++) {
4665
- for (int i01 = 0; i01 < ne01; i01++) {
4666
- for (int i00 = 0; i00 < ne00; i00++) {
4667
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4668
-
4669
- dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
4670
- id++;
5027
+ // dst counters
5028
+ int64_t i10 = 0;
5029
+ int64_t i11 = 0;
5030
+ int64_t i12 = 0;
5031
+ int64_t i13 = 0;
5032
+
5033
+ if (dst->type == GGML_TYPE_F32) {
5034
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5035
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5036
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
5037
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
5038
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5039
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5040
+
5041
+ memcpy(dst_ptr, src0_ptr, sizeof(float));
5042
+
5043
+ if (++i10 == dst->ne[0]) {
5044
+ i10 = 0;
5045
+ if (++i11 == dst->ne[1]) {
5046
+ i11 = 0;
5047
+ if (++i12 == dst->ne[2]) {
5048
+ i12 = 0;
5049
+ if (++i13 == dst->ne[3]) {
5050
+ i13 = 0;
5051
+ }
5052
+ }
5053
+ }
4671
5054
  }
4672
5055
  }
4673
5056
  }
4674
5057
  }
4675
- } else {
4676
- GGML_ASSERT(false); // TODO: implement
4677
5058
  }
4678
- } else {
4679
- //printf("%s: this is not optimal - fix me\n", __func__);
4680
-
4681
- if (dst->type == GGML_TYPE_F32) {
4682
- size_t id = 0;
4683
- float * dst_ptr = (float *) dst->data;
4684
-
4685
- for (int i03 = 0; i03 < ne03; i03++) {
4686
- for (int i02 = 0; i02 < ne02; i02++) {
4687
- for (int i01 = 0; i01 < ne01; i01++) {
4688
- for (int i00 = 0; i00 < ne00; i00++) {
4689
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4690
-
4691
- dst_ptr[id] = *src0_ptr;
4692
- id++;
4693
- }
4694
- }
4695
- }
4696
- }
4697
- } else if (dst->type == GGML_TYPE_F16) {
4698
- size_t id = 0;
4699
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
4700
-
4701
- for (int i03 = 0; i03 < ne03; i03++) {
4702
- for (int i02 = 0; i02 < ne02; i02++) {
4703
- for (int i01 = 0; i01 < ne01; i01++) {
4704
- for (int i00 = 0; i00 < ne00; i00++) {
4705
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4706
-
4707
- dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
4708
- id++;
5059
+ } else if (dst->type == GGML_TYPE_F16) {
5060
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5061
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5062
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
5063
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
5064
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5065
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5066
+
5067
+ *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
5068
+
5069
+ if (++i10 == dst->ne[0]) {
5070
+ i10 = 0;
5071
+ if (++i11 == dst->ne[1]) {
5072
+ i11 = 0;
5073
+ if (++i12 == dst->ne[2]) {
5074
+ i12 = 0;
5075
+ if (++i13 == dst->ne[3]) {
5076
+ i13 = 0;
5077
+ }
5078
+ }
5079
+ }
4709
5080
  }
4710
5081
  }
4711
5082
  }
4712
5083
  }
4713
- } else {
4714
- GGML_ASSERT(false); // TODO: implement
4715
5084
  }
5085
+ } else {
5086
+ GGML_ASSERT(false); // TODO: implement
4716
5087
  }
4717
5088
  }
4718
5089
 
@@ -5087,18 +5458,18 @@ static void ggml_compute_forward_sum_f32(
5087
5458
  assert(ggml_is_scalar(dst));
5088
5459
  assert(src0->nb[0] == sizeof(float));
5089
5460
 
5090
- const int ne00 = src0->ne[0];
5091
- const int ne01 = src0->ne[1];
5092
- const int ne02 = src0->ne[2];
5093
- const int ne03 = src0->ne[3];
5461
+ const int64_t ne00 = src0->ne[0];
5462
+ const int64_t ne01 = src0->ne[1];
5463
+ const int64_t ne02 = src0->ne[2];
5464
+ const int64_t ne03 = src0->ne[3];
5094
5465
 
5095
5466
  const size_t nb01 = src0->nb[1];
5096
5467
  const size_t nb02 = src0->nb[2];
5097
5468
  const size_t nb03 = src0->nb[3];
5098
5469
 
5099
- for (int i03 = 0; i03 < ne03; i03++) {
5100
- for (int i02 = 0; i02 < ne02; i02++) {
5101
- for (int i01 = 0; i01 < ne01; i01++) {
5470
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5471
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5472
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
5102
5473
  ggml_vec_sum_f32(ne00,
5103
5474
  (float *) (dst->data),
5104
5475
  (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
@@ -5143,19 +5514,19 @@ static void ggml_compute_forward_mean_f32(
5143
5514
 
5144
5515
  assert(src0->nb[0] == sizeof(float));
5145
5516
 
5146
- const int ne00 = src0->ne[0];
5147
- const int ne01 = src0->ne[1];
5148
- const int ne02 = src0->ne[2];
5149
- const int ne03 = src0->ne[3];
5517
+ const int64_t ne00 = src0->ne[0];
5518
+ const int64_t ne01 = src0->ne[1];
5519
+ const int64_t ne02 = src0->ne[2];
5520
+ const int64_t ne03 = src0->ne[3];
5150
5521
 
5151
5522
  const size_t nb01 = src0->nb[1];
5152
5523
  const size_t nb02 = src0->nb[2];
5153
5524
  const size_t nb03 = src0->nb[3];
5154
5525
 
5155
- const int ne0 = dst->ne[0];
5156
- const int ne1 = dst->ne[1];
5157
- const int ne2 = dst->ne[2];
5158
- const int ne3 = dst->ne[3];
5526
+ const int64_t ne0 = dst->ne[0];
5527
+ const int64_t ne1 = dst->ne[1];
5528
+ const int64_t ne2 = dst->ne[2];
5529
+ const int64_t ne3 = dst->ne[3];
5159
5530
 
5160
5531
  assert(ne0 == 1);
5161
5532
  assert(ne1 == ne01);
@@ -5171,9 +5542,9 @@ static void ggml_compute_forward_mean_f32(
5171
5542
  const size_t nb2 = dst->nb[2];
5172
5543
  const size_t nb3 = dst->nb[3];
5173
5544
 
5174
- for (int i03 = 0; i03 < ne03; i03++) {
5175
- for (int i02 = 0; i02 < ne02; i02++) {
5176
- for (int i01 = 0; i01 < ne01; i01++) {
5545
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5546
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5547
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
5177
5548
  ggml_vec_sum_f32(ne00,
5178
5549
  (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5179
5550
  (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
@@ -5660,10 +6031,10 @@ static void ggml_compute_forward_norm_f32(
5660
6031
  const int ith = params->ith;
5661
6032
  const int nth = params->nth;
5662
6033
 
5663
- const int ne00 = src0->ne[0];
5664
- const int ne01 = src0->ne[1];
5665
- const int ne02 = src0->ne[2];
5666
- const int ne03 = src0->ne[3];
6034
+ const int64_t ne00 = src0->ne[0];
6035
+ const int64_t ne01 = src0->ne[1];
6036
+ const int64_t ne02 = src0->ne[2];
6037
+ const int64_t ne03 = src0->ne[3];
5667
6038
 
5668
6039
  const size_t nb01 = src0->nb[1];
5669
6040
  const size_t nb02 = src0->nb[2];
@@ -5676,13 +6047,13 @@ static void ggml_compute_forward_norm_f32(
5676
6047
  const float eps = 1e-5f; // TODO: make this a parameter
5677
6048
 
5678
6049
  // TODO: optimize
5679
- for (int i03 = 0; i03 < ne03; i03++) {
5680
- for (int i02 = 0; i02 < ne02; i02++) {
5681
- for (int i01 = ith; i01 < ne01; i01 += nth) {
6050
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
6051
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
6052
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
5682
6053
  const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5683
6054
 
5684
6055
  ggml_float sum = 0.0;
5685
- for (int i00 = 0; i00 < ne00; i00++) {
6056
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
5686
6057
  sum += (ggml_float)x[i00];
5687
6058
  }
5688
6059
 
@@ -5691,7 +6062,7 @@ static void ggml_compute_forward_norm_f32(
5691
6062
  float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
5692
6063
 
5693
6064
  ggml_float sum2 = 0.0;
5694
- for (int i00 = 0; i00 < ne00; i00++) {
6065
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
5695
6066
  float v = x[i00] - mean;
5696
6067
  y[i00] = v;
5697
6068
  sum2 += (ggml_float)(v*v);
@@ -5743,10 +6114,10 @@ static void ggml_compute_forward_rms_norm_f32(
5743
6114
  const int ith = params->ith;
5744
6115
  const int nth = params->nth;
5745
6116
 
5746
- const int ne00 = src0->ne[0];
5747
- const int ne01 = src0->ne[1];
5748
- const int ne02 = src0->ne[2];
5749
- const int ne03 = src0->ne[3];
6117
+ const int64_t ne00 = src0->ne[0];
6118
+ const int64_t ne01 = src0->ne[1];
6119
+ const int64_t ne02 = src0->ne[2];
6120
+ const int64_t ne03 = src0->ne[3];
5750
6121
 
5751
6122
  const size_t nb01 = src0->nb[1];
5752
6123
  const size_t nb02 = src0->nb[2];
@@ -5759,13 +6130,13 @@ static void ggml_compute_forward_rms_norm_f32(
5759
6130
  const float eps = 1e-6f; // TODO: make this a parameter
5760
6131
 
5761
6132
  // TODO: optimize
5762
- for (int i03 = 0; i03 < ne03; i03++) {
5763
- for (int i02 = 0; i02 < ne02; i02++) {
5764
- for (int i01 = ith; i01 < ne01; i01 += nth) {
6133
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
6134
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
6135
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
5765
6136
  const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5766
6137
 
5767
6138
  ggml_float sum = 0.0;
5768
- for (int i00 = 0; i00 < ne00; i00++) {
6139
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
5769
6140
  sum += (ggml_float)(x[i00] * x[i00]);
5770
6141
  }
5771
6142
 
@@ -5818,13 +6189,13 @@ static bool ggml_compute_forward_mul_mat_use_blas(
5818
6189
  const struct ggml_tensor * src0,
5819
6190
  const struct ggml_tensor * src1,
5820
6191
  struct ggml_tensor * dst) {
5821
- //const int ne00 = src0->ne[0];
5822
- //const int ne01 = src0->ne[1];
6192
+ //const int64_t ne00 = src0->ne[0];
6193
+ //const int64_t ne01 = src0->ne[1];
5823
6194
 
5824
- const int ne10 = src1->ne[0];
6195
+ const int64_t ne10 = src1->ne[0];
5825
6196
 
5826
- const int ne0 = dst->ne[0];
5827
- const int ne1 = dst->ne[1];
6197
+ const int64_t ne0 = dst->ne[0];
6198
+ const int64_t ne1 = dst->ne[1];
5828
6199
 
5829
6200
  // TODO: find the optimal values for these
5830
6201
  if (ggml_is_contiguous(src0) &&
@@ -5846,23 +6217,23 @@ static void ggml_compute_forward_mul_mat_f32(
5846
6217
  int64_t t0 = ggml_perf_time_us();
5847
6218
  UNUSED(t0);
5848
6219
 
5849
- const int ne00 = src0->ne[0];
5850
- const int ne01 = src0->ne[1];
5851
- const int ne02 = src0->ne[2];
5852
- const int ne03 = src0->ne[3];
6220
+ const int64_t ne00 = src0->ne[0];
6221
+ const int64_t ne01 = src0->ne[1];
6222
+ const int64_t ne02 = src0->ne[2];
6223
+ const int64_t ne03 = src0->ne[3];
5853
6224
 
5854
6225
  #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
5855
- const int ne10 = src1->ne[0];
6226
+ const int64_t ne10 = src1->ne[0];
5856
6227
  #endif
5857
- const int ne11 = src1->ne[1];
6228
+ const int64_t ne11 = src1->ne[1];
5858
6229
  #ifndef NDEBUG
5859
- const int ne12 = src1->ne[2];
5860
- const int ne13 = src1->ne[3];
6230
+ const int64_t ne12 = src1->ne[2];
6231
+ const int64_t ne13 = src1->ne[3];
5861
6232
 
5862
- const int ne0 = dst->ne[0];
5863
- const int ne1 = dst->ne[1];
5864
- const int ne2 = dst->ne[2];
5865
- const int ne3 = dst->ne[3];
6233
+ const int64_t ne0 = dst->ne[0];
6234
+ const int64_t ne1 = dst->ne[1];
6235
+ const int64_t ne2 = dst->ne[2];
6236
+ const int64_t ne3 = dst->ne[3];
5866
6237
 
5867
6238
  const int nb00 = src0->nb[0];
5868
6239
  #endif
@@ -5922,8 +6293,8 @@ static void ggml_compute_forward_mul_mat_f32(
5922
6293
  return;
5923
6294
  }
5924
6295
 
5925
- for (int i03 = 0; i03 < ne03; i03++) {
5926
- for (int i02 = 0; i02 < ne02; i02++) {
6296
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
6297
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5927
6298
  const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
5928
6299
  const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
5929
6300
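When ggml is built with Accelerate or OpenBLAS, the x/y slices gathered in this loop are handed to single-precision GEMM instead of the hand-written dot products. A hedged sketch of an equivalent row-major cblas_sgemm call for C = A * B^T, with generic m x k and n x k matrices (an illustration of the pattern, not the exact ggml invocation):

    #include <cblas.h>

    // C (m x n) = A (m x k) * B^T, where B is stored as n x k, all row-major.
    static void matmul_bt(const float * A, const float * B, float * C,
                          int m, int n, int k) {
        cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                    m, n, k,
                    1.0f,        // alpha
                    A, k,        // A and its leading dimension
                    B, k,        // B and its leading dimension
                    0.0f,        // beta: overwrite C
                    C, n);       // C and its leading dimension
    }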
 
@@ -5970,7 +6341,7 @@ static void ggml_compute_forward_mul_mat_f32(
5970
6341
  const int i02 = (ir - i03*ne02*ne01)/ne01;
5971
6342
  const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
5972
6343
 
5973
- for (int ic = 0; ic < ne11; ++ic) {
6344
+ for (int64_t ic = 0; ic < ne11; ++ic) {
5974
6345
  // src1 indices
5975
6346
  const int i13 = i03;
5976
6347
  const int i12 = i02;
@@ -6011,21 +6382,21 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6011
6382
  int64_t t0 = ggml_perf_time_us();
6012
6383
  UNUSED(t0);
6013
6384
 
6014
- const int ne00 = src0->ne[0];
6015
- const int ne01 = src0->ne[1];
6016
- const int ne02 = src0->ne[2];
6017
- const int ne03 = src0->ne[3];
6385
+ const int64_t ne00 = src0->ne[0];
6386
+ const int64_t ne01 = src0->ne[1];
6387
+ const int64_t ne02 = src0->ne[2];
6388
+ const int64_t ne03 = src0->ne[3];
6018
6389
 
6019
- const int ne10 = src1->ne[0];
6020
- const int ne11 = src1->ne[1];
6021
- const int ne12 = src1->ne[2];
6022
- const int ne13 = src1->ne[3];
6390
+ const int64_t ne10 = src1->ne[0];
6391
+ const int64_t ne11 = src1->ne[1];
6392
+ const int64_t ne12 = src1->ne[2];
6393
+ const int64_t ne13 = src1->ne[3];
6023
6394
 
6024
- const int ne0 = dst->ne[0];
6025
- const int ne1 = dst->ne[1];
6026
- const int ne2 = dst->ne[2];
6027
- const int ne3 = dst->ne[3];
6028
- //const int ne = ne0*ne1*ne2*ne3;
6395
+ const int64_t ne0 = dst->ne[0];
6396
+ const int64_t ne1 = dst->ne[1];
6397
+ const int64_t ne2 = dst->ne[2];
6398
+ const int64_t ne3 = dst->ne[3];
6399
+ //const int64_t ne = ne0*ne1*ne2*ne3;
6029
6400
 
6030
6401
  const int nb00 = src0->nb[0];
6031
6402
  const int nb01 = src0->nb[1];
@@ -6085,12 +6456,12 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6085
6456
 
6086
6457
  float * const wdata = params->wdata;
6087
6458
 
6088
- for (int i03 = 0; i03 < ne03; i03++) {
6089
- for (int i02 = 0; i02 < ne02; i02++) {
6459
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
6460
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
6090
6461
  {
6091
6462
  size_t id = 0;
6092
- for (int i01 = 0; i01 < ne01; ++i01) {
6093
- for (int i00 = 0; i00 < ne00; ++i00) {
6463
+ for (int64_t i01 = 0; i01 < ne01; ++i01) {
6464
+ for (int64_t i00 = 0; i00 < ne00; ++i00) {
6094
6465
  wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
6095
6466
  }
6096
6467
  }
@@ -6120,10 +6491,10 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6120
6491
  ggml_fp16_t * const wdata = params->wdata;
6121
6492
 
6122
6493
  size_t id = 0;
6123
- for (int i13 = 0; i13 < ne13; ++i13) {
6124
- for (int i12 = 0; i12 < ne12; ++i12) {
6125
- for (int i11 = 0; i11 < ne11; ++i11) {
6126
- for (int i10 = 0; i10 < ne10; ++i10) {
6494
+ for (int64_t i13 = 0; i13 < ne13; ++i13) {
6495
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
6496
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
6497
+ for (int64_t i10 = 0; i10 < ne10; ++i10) {
6127
6498
  wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
6128
6499
  }
6129
6500
  }
@@ -6175,7 +6546,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6175
6546
 
6176
6547
  float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
6177
6548
 
6178
- for (int ic = 0; ic < ne11; ++ic) {
6549
+ for (int64_t ic = 0; ic < ne11; ++ic) {
6179
6550
  ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
6180
6551
  }
6181
6552
  }
@@ -6224,20 +6595,20 @@ static void ggml_compute_forward_mul_mat_q_f32(
6224
6595
  int64_t t0 = ggml_perf_time_us();
6225
6596
  UNUSED(t0);
6226
6597
 
6227
- const int ne00 = src0->ne[0];
6228
- const int ne01 = src0->ne[1];
6229
- const int ne02 = src0->ne[2];
6230
- const int ne03 = src0->ne[3];
6598
+ const int64_t ne00 = src0->ne[0];
6599
+ const int64_t ne01 = src0->ne[1];
6600
+ const int64_t ne02 = src0->ne[2];
6601
+ const int64_t ne03 = src0->ne[3];
6231
6602
 
6232
- const int ne10 = src1->ne[0];
6233
- const int ne11 = src1->ne[1];
6234
- const int ne12 = src1->ne[2];
6235
- const int ne13 = src1->ne[3];
6603
+ const int64_t ne10 = src1->ne[0];
6604
+ const int64_t ne11 = src1->ne[1];
6605
+ const int64_t ne12 = src1->ne[2];
6606
+ const int64_t ne13 = src1->ne[3];
6236
6607
 
6237
- const int ne0 = dst->ne[0];
6238
- const int ne1 = dst->ne[1];
6239
- const int ne2 = dst->ne[2];
6240
- const int ne3 = dst->ne[3];
6608
+ const int64_t ne0 = dst->ne[0];
6609
+ const int64_t ne1 = dst->ne[1];
6610
+ const int64_t ne2 = dst->ne[2];
6611
+ const int64_t ne3 = dst->ne[3];
6241
6612
 
6242
6613
  const int nb00 = src0->nb[0];
6243
6614
  const int nb01 = src0->nb[1];
@@ -6301,11 +6672,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
6301
6672
  float * const wdata = params->wdata;
6302
6673
  dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
6303
6674
 
6304
- for (int i03 = 0; i03 < ne03; i03++) {
6305
- for (int i02 = 0; i02 < ne02; i02++) {
6675
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
6676
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
6306
6677
  {
6307
6678
  size_t id = 0;
6308
- for (int i01 = 0; i01 < ne01; ++i01) {
6679
+ for (int64_t i01 = 0; i01 < ne01; ++i01) {
6309
6680
  dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
6310
6681
  id += ne00;
6311
6682
  }
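Here dequantize_row_q expands each quantized row of src0 back to floats so the BLAS path can consume it. A sketch of what that means for the 4-bit q4_0 format; the struct below is a hypothetical mirror of the block layout (one fp32 scale plus QK packed nibbles), not the ggml definition itself:

    #include <stdint.h>

    #define QK 32

    typedef struct {
        float   d;           // per-block scaling factor
        uint8_t qs[QK / 2];  // two 4-bit values per byte
    } block_q4_0_sketch;

    // Dequantize one block into QK floats: each nibble is shifted back to
    // the signed range [-8, 7] and multiplied by the block scale.
    static void dequantize_block_q4_0(const block_q4_0_sketch * b, float * y) {
        for (int l = 0; l < QK / 2; l++) {
            const int8_t v0 = (b->qs[l] & 0x0F) - 8;  // low nibble
            const int8_t v1 = (b->qs[l] >> 4)  - 8;   // high nibble
            y[2*l + 0] = v0 * b->d;
            y[2*l + 1] = v1 * b->d;
        }
    }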
@@ -6335,9 +6706,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
6335
6706
  char * wdata = params->wdata;
6336
6707
  const size_t row_size = ne10*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
6337
6708
 
6338
- for (int i13 = 0; i13 < ne13; ++i13) {
6339
- for (int i12 = 0; i12 < ne12; ++i12) {
6340
- for (int i11 = 0; i11 < ne11; ++i11) {
6709
+ for (int64_t i13 = 0; i13 < ne13; ++i13) {
6710
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
6711
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
6341
6712
  quantize_row_q((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
6342
6713
  wdata += row_size;
6343
6714
  }
@@ -6386,7 +6757,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
6386
6757
 
6387
6758
  assert(ne00 % 32 == 0);
6388
6759
 
6389
- for (int ic = 0; ic < ne11; ++ic) {
6760
+ for (int64_t ic = 0; ic < ne11; ++ic) {
6390
6761
  vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
6391
6762
  }
6392
6763
  }
@@ -6867,7 +7238,6 @@ static void ggml_compute_forward_rope_f32(
6867
7238
  const struct ggml_tensor * src0,
6868
7239
  const struct ggml_tensor * src1,
6869
7240
  struct ggml_tensor * dst) {
6870
- assert(params->ith == 0);
6871
7241
  assert(src1->type == GGML_TYPE_I32);
6872
7242
  assert(ggml_nelements(src1) == 3);
6873
7243
 
@@ -6879,10 +7249,10 @@ static void ggml_compute_forward_rope_f32(
6879
7249
  const int n_dims = ((int32_t *) src1->data)[1];
6880
7250
  const int mode = ((int32_t *) src1->data)[2];
6881
7251
 
6882
- //const int ne0 = src0->ne[0];
6883
- const int ne1 = src0->ne[1];
6884
- const int ne2 = src0->ne[2];
6885
- const int ne3 = src0->ne[3];
7252
+ //const int64_t ne0 = src0->ne[0];
7253
+ const int64_t ne1 = src0->ne[1];
7254
+ const int64_t ne2 = src0->ne[2];
7255
+ const int64_t ne3 = src0->ne[3];
6886
7256
 
6887
7257
  const int nb0 = src0->nb[0];
6888
7258
  const int nb1 = src0->nb[1];
@@ -6894,11 +7264,28 @@ static void ggml_compute_forward_rope_f32(
6894
7264
 
6895
7265
  assert(nb0 == sizeof(float));
6896
7266
 
6897
- // TODO: optimize
6898
- for (int i3 = 0; i3 < ne3; i3++) {
6899
- for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
7267
+ const int ith = params->ith;
7268
+ const int nth = params->nth;
7269
+
7270
+ const int nr = ggml_nrows(src0);
7271
+
7272
+ // rows per thread
7273
+ const int dr = (nr + nth - 1)/nth;
7274
+
7275
+ // row range for this thread
7276
+ const int ir0 = dr*ith;
7277
+ const int ir1 = MIN(ir0 + dr, nr);
7278
+
7279
+ // row index used to determine which thread to use
7280
+ int ir = 0;
7281
+
7282
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
7283
+ for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
6900
7284
  const int p = (mode == 0 ? n_past + i2 : i2);
6901
- for (int i1 = 0; i1 < ne1; i1++) {
7285
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
7286
+ if (ir++ < ir0) continue;
7287
+ if (ir > ir1) break;
7288
+
6902
7289
  for (int i0 = 0; i0 < n_dims; i0 += 2) {
6903
7290
  const float theta = powf(10000.0, ((float)-i0)/n_dims);
6904
7291
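Besides widening the loop counters, this hunk parallelizes ggml_compute_forward_rope_f32 across rows: thread ith of nth handles rows [ir0, ir1), with dr rounded up so the last thread picks up any remainder when nr is not divisible by nth. The same partitioning pattern in isolation (a sketch, not the ggml code):

    #include <stdint.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    // Thread `ith` of `nth` gets rows [*ir0, *ir1) of an nr-row tensor.
    static void row_range_for_thread(int nr, int ith, int nth, int * ir0, int * ir1) {
        const int dr = (nr + nth - 1) / nth;  // rows per thread, rounded up
        *ir0 = dr * ith;
        *ir1 = MIN(*ir0 + dr, nr);            // clamp the last thread's range
    }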
 
@@ -6924,7 +7311,6 @@ static void ggml_compute_forward_rope_f16(
6924
7311
  const struct ggml_tensor * src0,
6925
7312
  const struct ggml_tensor * src1,
6926
7313
  struct ggml_tensor * dst) {
6927
- assert(params->ith == 0);
6928
7314
  assert(src1->type == GGML_TYPE_I32);
6929
7315
  assert(ggml_nelements(src1) == 3);
6930
7316
 
@@ -6936,10 +7322,10 @@ static void ggml_compute_forward_rope_f16(
6936
7322
  const int n_dims = ((int32_t *) src1->data)[1];
6937
7323
  const int mode = ((int32_t *) src1->data)[2];
6938
7324
 
6939
- //const int ne0 = src0->ne[0];
6940
- const int ne1 = src0->ne[1];
6941
- const int ne2 = src0->ne[2];
6942
- const int ne3 = src0->ne[3];
7325
+ //const int64_t ne0 = src0->ne[0];
7326
+ const int64_t ne1 = src0->ne[1];
7327
+ const int64_t ne2 = src0->ne[2];
7328
+ const int64_t ne3 = src0->ne[3];
6943
7329
 
6944
7330
  const int nb0 = src0->nb[0];
6945
7331
  const int nb1 = src0->nb[1];
@@ -6951,10 +7337,28 @@ static void ggml_compute_forward_rope_f16(
6951
7337
 
6952
7338
  assert(nb0 == sizeof(ggml_fp16_t));
6953
7339
 
6954
- for (int i3 = 0; i3 < ne3; i3++) {
6955
- for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
7340
+ const int ith = params->ith;
7341
+ const int nth = params->nth;
7342
+
7343
+ const int nr = ggml_nrows(src0);
7344
+
7345
+ // rows per thread
7346
+ const int dr = (nr + nth - 1)/nth;
7347
+
7348
+ // row range for this thread
7349
+ const int ir0 = dr*ith;
7350
+ const int ir1 = MIN(ir0 + dr, nr);
7351
+
7352
+ // row index used to determine which thread to use
7353
+ int ir = 0;
7354
+
7355
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
7356
+ for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
6956
7357
  const int p = (mode == 0 ? n_past + i2 : i2);
6957
- for (int i1 = 0; i1 < ne1; i1++) {
7358
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
7359
+ if (ir++ < ir0) continue;
7360
+ if (ir > ir1) break;
7361
+
6958
7362
  for (int i0 = 0; i0 < n_dims; i0 += 2) {
6959
7363
  const float theta = powf(10000.0, ((float)-i0)/n_dims);
6960
7364
 
@@ -7015,21 +7419,21 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32(
7015
7419
  int64_t t0 = ggml_perf_time_us();
7016
7420
  UNUSED(t0);
7017
7421
 
7018
- const int ne00 = src0->ne[0];
7019
- const int ne01 = src0->ne[1];
7020
- const int ne02 = src0->ne[2];
7021
- //const int ne03 = src0->ne[3];
7422
+ const int64_t ne00 = src0->ne[0];
7423
+ const int64_t ne01 = src0->ne[1];
7424
+ const int64_t ne02 = src0->ne[2];
7425
+ //const int64_t ne03 = src0->ne[3];
7022
7426
 
7023
- const int ne10 = src1->ne[0];
7024
- const int ne11 = src1->ne[1];
7025
- //const int ne12 = src1->ne[2];
7026
- //const int ne13 = src1->ne[3];
7427
+ const int64_t ne10 = src1->ne[0];
7428
+ const int64_t ne11 = src1->ne[1];
7429
+ //const int64_t ne12 = src1->ne[2];
7430
+ //const int64_t ne13 = src1->ne[3];
7027
7431
 
7028
- //const int ne0 = dst->ne[0];
7029
- //const int ne1 = dst->ne[1];
7030
- //const int ne2 = dst->ne[2];
7031
- //const int ne3 = dst->ne[3];
7032
- //const int ne = ne0*ne1*ne2*ne3;
7432
+ //const int64_t ne0 = dst->ne[0];
7433
+ //const int64_t ne1 = dst->ne[1];
7434
+ //const int64_t ne2 = dst->ne[2];
7435
+ //const int64_t ne3 = dst->ne[3];
7436
+ //const int64_t ne = ne0*ne1*ne2*ne3;
7033
7437
 
7034
7438
  const int nb00 = src0->nb[0];
7035
7439
  const int nb01 = src0->nb[1];
@@ -7066,11 +7470,11 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32(
7066
7470
  {
7067
7471
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
7068
7472
 
7069
- for (int i02 = 0; i02 < ne02; i02++) {
7070
- for (int i01 = 0; i01 < ne01; i01++) {
7473
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
7474
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
7071
7475
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
7072
7476
  ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
7073
- for (int i00 = 0; i00 < ne00; i00++) {
7477
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
7074
7478
  dst_data[i00*ew0 + i01] = src[i00];
7075
7479
  }
7076
7480
  }
@@ -7081,10 +7485,10 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32(
7081
7485
  {
7082
7486
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
7083
7487
 
7084
- for (int i11 = 0; i11 < ne11; i11++) {
7488
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
7085
7489
  const float * const src = (float *)((char *) src1->data + i11*nb11);
7086
7490
  ggml_fp16_t * dst_data = wdata;
7087
- for (int i10 = 0; i10 < ne10; i10++) {
7491
+ for (int64_t i10 = 0; i10 < ne10; i10++) {
7088
7492
  dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
7089
7493
  }
7090
7494
  }
@@ -7109,7 +7513,7 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32(
7109
7513
 
7110
7514
  for (int i1 = ir0; i1 < ir1; i1++) {
7111
7515
  float * dst_data = (float *)((char *) dst->data + i1*nb1);
7112
- for (int i0 = 0; i0 < ne10; ++i0) {
7516
+ for (int64_t i0 = 0; i0 < ne10; ++i0) {
7113
7517
  dst_data[i0] = 0;
7114
7518
  for (int k = -nh; k <= nh; k++) {
7115
7519
  float v = 0.0f;
@@ -7135,21 +7539,21 @@ static void ggml_compute_forward_conv_1d_1s_f32(
7135
7539
  int64_t t0 = ggml_perf_time_us();
7136
7540
  UNUSED(t0);
7137
7541
 
7138
- const int ne00 = src0->ne[0];
7139
- const int ne01 = src0->ne[1];
7140
- const int ne02 = src0->ne[2];
7141
- //const int ne03 = src0->ne[3];
7542
+ const int64_t ne00 = src0->ne[0];
7543
+ const int64_t ne01 = src0->ne[1];
7544
+ const int64_t ne02 = src0->ne[2];
7545
+ //const int64_t ne03 = src0->ne[3];
7142
7546
 
7143
- const int ne10 = src1->ne[0];
7144
- const int ne11 = src1->ne[1];
7145
- //const int ne12 = src1->ne[2];
7146
- //const int ne13 = src1->ne[3];
7547
+ const int64_t ne10 = src1->ne[0];
7548
+ const int64_t ne11 = src1->ne[1];
7549
+ //const int64_t ne12 = src1->ne[2];
7550
+ //const int64_t ne13 = src1->ne[3];
7147
7551
 
7148
- //const int ne0 = dst->ne[0];
7149
- //const int ne1 = dst->ne[1];
7150
- //const int ne2 = dst->ne[2];
7151
- //const int ne3 = dst->ne[3];
7152
- //const int ne = ne0*ne1*ne2*ne3;
7552
+ //const int64_t ne0 = dst->ne[0];
7553
+ //const int64_t ne1 = dst->ne[1];
7554
+ //const int64_t ne2 = dst->ne[2];
7555
+ //const int64_t ne3 = dst->ne[3];
7556
+ //const int64_t ne = ne0*ne1*ne2*ne3;
7153
7557
 
7154
7558
  const int nb00 = src0->nb[0];
7155
7559
  const int nb01 = src0->nb[1];
@@ -7186,11 +7590,11 @@ static void ggml_compute_forward_conv_1d_1s_f32(
7186
7590
  {
7187
7591
  float * const wdata = (float *) params->wdata + 0;
7188
7592
 
7189
- for (int i02 = 0; i02 < ne02; i02++) {
7190
- for (int i01 = 0; i01 < ne01; i01++) {
7593
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
7594
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
7191
7595
  const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
7192
7596
  float * dst_data = wdata + i02*ew0*ne00;
7193
- for (int i00 = 0; i00 < ne00; i00++) {
7597
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
7194
7598
  dst_data[i00*ew0 + i01] = src[i00];
7195
7599
  }
7196
7600
  }
@@ -7201,10 +7605,10 @@ static void ggml_compute_forward_conv_1d_1s_f32(
7201
7605
  {
7202
7606
  float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
7203
7607
 
7204
- for (int i11 = 0; i11 < ne11; i11++) {
7608
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
7205
7609
  const float * const src = (float *)((char *) src1->data + i11*nb11);
7206
7610
  float * dst_data = wdata;
7207
- for (int i10 = 0; i10 < ne10; i10++) {
7611
+ for (int64_t i10 = 0; i10 < ne10; i10++) {
7208
7612
  dst_data[(i10 + nh)*ew0 + i11] = src[i10];
7209
7613
  }
7210
7614
  }
@@ -7229,7 +7633,7 @@ static void ggml_compute_forward_conv_1d_1s_f32(
7229
7633
 
7230
7634
  for (int i1 = ir0; i1 < ir1; i1++) {
7231
7635
  float * dst_data = (float *)((char *) dst->data + i1*nb1);
7232
- for (int i0 = 0; i0 < ne10; ++i0) {
7636
+ for (int64_t i0 = 0; i0 < ne10; ++i0) {
7233
7637
  dst_data[i0] = 0;
7234
7638
  for (int k = -nh; k <= nh; k++) {
7235
7639
  float v = 0.0f;
@@ -7283,21 +7687,21 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32(
7283
7687
  int64_t t0 = ggml_perf_time_us();
7284
7688
  UNUSED(t0);
7285
7689
 
7286
- const int ne00 = src0->ne[0];
7287
- const int ne01 = src0->ne[1];
7288
- const int ne02 = src0->ne[2];
7289
- //const int ne03 = src0->ne[3];
7690
+ const int64_t ne00 = src0->ne[0];
7691
+ const int64_t ne01 = src0->ne[1];
7692
+ const int64_t ne02 = src0->ne[2];
7693
+ //const int64_t ne03 = src0->ne[3];
7290
7694
 
7291
- const int ne10 = src1->ne[0];
7292
- const int ne11 = src1->ne[1];
7293
- //const int ne12 = src1->ne[2];
7294
- //const int ne13 = src1->ne[3];
7695
+ const int64_t ne10 = src1->ne[0];
7696
+ const int64_t ne11 = src1->ne[1];
7697
+ //const int64_t ne12 = src1->ne[2];
7698
+ //const int64_t ne13 = src1->ne[3];
7295
7699
 
7296
- //const int ne0 = dst->ne[0];
7297
- //const int ne1 = dst->ne[1];
7298
- //const int ne2 = dst->ne[2];
7299
- //const int ne3 = dst->ne[3];
7300
- //const int ne = ne0*ne1*ne2*ne3;
7700
+ //const int64_t ne0 = dst->ne[0];
7701
+ //const int64_t ne1 = dst->ne[1];
7702
+ //const int64_t ne2 = dst->ne[2];
7703
+ //const int64_t ne3 = dst->ne[3];
7704
+ //const int64_t ne = ne0*ne1*ne2*ne3;
7301
7705
 
7302
7706
  const int nb00 = src0->nb[0];
7303
7707
  const int nb01 = src0->nb[1];
@@ -7334,11 +7738,11 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32(
7334
7738
  {
7335
7739
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
7336
7740
 
7337
- for (int i02 = 0; i02 < ne02; i02++) {
7338
- for (int i01 = 0; i01 < ne01; i01++) {
7741
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
7742
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
7339
7743
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
7340
7744
  ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
7341
- for (int i00 = 0; i00 < ne00; i00++) {
7745
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
7342
7746
  dst_data[i00*ew0 + i01] = src[i00];
7343
7747
  }
7344
7748
  }
@@ -7349,10 +7753,10 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32(
7349
7753
  {
7350
7754
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
7351
7755
 
7352
- for (int i11 = 0; i11 < ne11; i11++) {
7756
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
7353
7757
  const float * const src = (float *)((char *) src1->data + i11*nb11);
7354
7758
  ggml_fp16_t * dst_data = wdata;
7355
- for (int i10 = 0; i10 < ne10; i10++) {
7759
+ for (int64_t i10 = 0; i10 < ne10; i10++) {
7356
7760
  dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
7357
7761
  }
7358
7762
  }
@@ -7377,7 +7781,7 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32(
7377
7781
 
7378
7782
  for (int i1 = ir0; i1 < ir1; i1++) {
7379
7783
  float * dst_data = (float *)((char *) dst->data + i1*nb1);
7380
- for (int i0 = 0; i0 < ne10; i0 += 2) {
7784
+ for (int64_t i0 = 0; i0 < ne10; i0 += 2) {
7381
7785
  dst_data[i0/2] = 0;
7382
7786
  for (int k = -nh; k <= nh; k++) {
7383
7787
  float v = 0.0f;
@@ -7403,21 +7807,21 @@ static void ggml_compute_forward_conv_1d_2s_f32(
7403
7807
  int64_t t0 = ggml_perf_time_us();
7404
7808
  UNUSED(t0);
7405
7809
 
7406
- const int ne00 = src0->ne[0];
7407
- const int ne01 = src0->ne[1];
7408
- const int ne02 = src0->ne[2];
7409
- //const int ne03 = src0->ne[3];
7810
+ const int64_t ne00 = src0->ne[0];
7811
+ const int64_t ne01 = src0->ne[1];
7812
+ const int64_t ne02 = src0->ne[2];
7813
+ //const int64_t ne03 = src0->ne[3];
7410
7814
 
7411
- const int ne10 = src1->ne[0];
7412
- const int ne11 = src1->ne[1];
7413
- //const int ne12 = src1->ne[2];
7414
- //const int ne13 = src1->ne[3];
7815
+ const int64_t ne10 = src1->ne[0];
7816
+ const int64_t ne11 = src1->ne[1];
7817
+ //const int64_t ne12 = src1->ne[2];
7818
+ //const int64_t ne13 = src1->ne[3];
7415
7819
 
7416
- //const int ne0 = dst->ne[0];
7417
- //const int ne1 = dst->ne[1];
7418
- //const int ne2 = dst->ne[2];
7419
- //const int ne3 = dst->ne[3];
7420
- //const int ne = ne0*ne1*ne2*ne3;
7820
+ //const int64_t ne0 = dst->ne[0];
7821
+ //const int64_t ne1 = dst->ne[1];
7822
+ //const int64_t ne2 = dst->ne[2];
7823
+ //const int64_t ne3 = dst->ne[3];
7824
+ //const int64_t ne = ne0*ne1*ne2*ne3;
7421
7825
 
7422
7826
  const int nb00 = src0->nb[0];
7423
7827
  const int nb01 = src0->nb[1];
@@ -7454,11 +7858,11 @@ static void ggml_compute_forward_conv_1d_2s_f32(
7454
7858
  {
7455
7859
  float * const wdata = (float *) params->wdata + 0;
7456
7860
 
7457
- for (int i02 = 0; i02 < ne02; i02++) {
7458
- for (int i01 = 0; i01 < ne01; i01++) {
7861
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
7862
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
7459
7863
  const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
7460
7864
  float * dst_data = wdata + i02*ew0*ne00;
7461
- for (int i00 = 0; i00 < ne00; i00++) {
7865
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
7462
7866
  dst_data[i00*ew0 + i01] = src[i00];
7463
7867
  }
7464
7868
  }
@@ -7469,10 +7873,10 @@ static void ggml_compute_forward_conv_1d_2s_f32(
7469
7873
  {
7470
7874
  float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
7471
7875
 
7472
- for (int i11 = 0; i11 < ne11; i11++) {
7876
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
7473
7877
  const float * const src = (float *)((char *) src1->data + i11*nb11);
7474
7878
  float * dst_data = wdata;
7475
- for (int i10 = 0; i10 < ne10; i10++) {
7879
+ for (int64_t i10 = 0; i10 < ne10; i10++) {
7476
7880
  dst_data[(i10 + nh)*ew0 + i11] = src[i10];
7477
7881
  }
7478
7882
  }
@@ -7497,7 +7901,7 @@ static void ggml_compute_forward_conv_1d_2s_f32(
7497
7901
 
7498
7902
  for (int i1 = ir0; i1 < ir1; i1++) {
7499
7903
  float * dst_data = (float *)((char *) dst->data + i1*nb1);
7500
- for (int i0 = 0; i0 < ne10; i0 += 2) {
7904
+ for (int64_t i0 = 0; i0 < ne10; i0 += 2) {
7501
7905
  dst_data[i0/2] = 0;
7502
7906
  for (int k = -nh; k <= nh; k++) {
7503
7907
  float v = 0.0f;
@@ -7549,25 +7953,25 @@ static void ggml_compute_forward_flash_attn_f32(
7549
7953
  int64_t t0 = ggml_perf_time_us();
7550
7954
  UNUSED(t0);
7551
7955
 
7552
- const int neq0 = q->ne[0];
7553
- const int neq1 = q->ne[1];
7554
- const int neq2 = q->ne[2];
7555
- const int neq3 = q->ne[3];
7956
+ const int64_t neq0 = q->ne[0];
7957
+ const int64_t neq1 = q->ne[1];
7958
+ const int64_t neq2 = q->ne[2];
7959
+ const int64_t neq3 = q->ne[3];
7556
7960
 
7557
- const int nek0 = k->ne[0];
7558
- const int nek1 = k->ne[1];
7559
- //const int nek2 = k->ne[2];
7560
- //const int nek3 = k->ne[3];
7961
+ const int64_t nek0 = k->ne[0];
7962
+ const int64_t nek1 = k->ne[1];
7963
+ //const int64_t nek2 = k->ne[2];
7964
+ //const int64_t nek3 = k->ne[3];
7561
7965
 
7562
- //const int nev0 = v->ne[0];
7563
- const int nev1 = v->ne[1];
7564
- //const int nev2 = v->ne[2];
7565
- //const int nev3 = v->ne[3];
7966
+ //const int64_t nev0 = v->ne[0];
7967
+ const int64_t nev1 = v->ne[1];
7968
+ //const int64_t nev2 = v->ne[2];
7969
+ //const int64_t nev3 = v->ne[3];
7566
7970
 
7567
- const int ne0 = dst->ne[0];
7568
- const int ne1 = dst->ne[1];
7569
- //const int ne2 = dst->ne[2];
7570
- //const int ne3 = dst->ne[3];
7971
+ const int64_t ne0 = dst->ne[0];
7972
+ const int64_t ne1 = dst->ne[1];
7973
+ //const int64_t ne2 = dst->ne[2];
7974
+ //const int64_t ne3 = dst->ne[3];
7571
7975
 
7572
7976
  const int nbk0 = k->nb[0];
7573
7977
  const int nbk1 = k->nb[1];
@@ -7592,10 +7996,10 @@ static void ggml_compute_forward_flash_attn_f32(
7592
7996
  const int ith = params->ith;
7593
7997
  const int nth = params->nth;
7594
7998
 
7595
- const int D = neq0;
7596
- const int N = neq1;
7597
- const int P = nek1 - N;
7598
- const int M = P + N;
7999
+ const int64_t D = neq0;
8000
+ const int64_t N = neq1;
8001
+ const int64_t P = nek1 - N;
8002
+ const int64_t M = P + N;
7599
8003
 
7600
8004
  const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
7601
8005
 
@@ -7657,7 +8061,7 @@ static void ggml_compute_forward_flash_attn_f32(
7657
8061
  S[i] = -INFINITY;
7658
8062
  }
7659
8063
 
7660
- for (int ic = 0; ic < nek1; ++ic) {
8064
+ for (int64_t ic = 0; ic < nek1; ++ic) {
7661
8065
  // k indices
7662
8066
  const int ik3 = iq3;
7663
8067
  const int ik2 = iq2;
@@ -7676,7 +8080,7 @@ static void ggml_compute_forward_flash_attn_f32(
7676
8080
  ggml_vec_scale_f32(nek1, S, scale);
7677
8081
 
7678
8082
  if (masked) {
7679
- for (int i = P; i < M; i++) {
8083
+ for (int64_t i = P; i < M; i++) {
7680
8084
  if (i > P + iq1) {
7681
8085
  S[i] = -INFINITY;
7682
8086
  }
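The masked branch implements causal attention: for query position iq1, every key position past P + iq1 has its score forced to -INFINITY, so it contributes exp(-inf) = 0 to the softmax that follows. The same step as a standalone sketch:

    #include <math.h>
    #include <stdint.h>

    // Exclude future key positions from the softmax over scores S[P..M).
    static void mask_future_scores(float * S, int64_t P, int64_t M, int64_t iq1) {
        for (int64_t i = P; i < M; i++) {
            if (i > P + iq1) {
                S[i] = -INFINITY;
            }
        }
    }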
@@ -7734,7 +8138,7 @@ static void ggml_compute_forward_flash_attn_f32(
7734
8138
  #endif
7735
8139
  }
7736
8140
 
7737
- for (int ic = 0; ic < nev1; ++ic) {
8141
+ for (int64_t ic = 0; ic < nev1; ++ic) {
7738
8142
  // dst indices
7739
8143
  const int i1 = iq1;
7740
8144
  const int i2 = iq2;
@@ -7758,25 +8162,25 @@ static void ggml_compute_forward_flash_attn_f16(
7758
8162
  int64_t t0 = ggml_perf_time_us();
7759
8163
  UNUSED(t0);
7760
8164
 
7761
- const int neq0 = q->ne[0];
7762
- const int neq1 = q->ne[1];
7763
- const int neq2 = q->ne[2];
7764
- const int neq3 = q->ne[3];
8165
+ const int64_t neq0 = q->ne[0];
8166
+ const int64_t neq1 = q->ne[1];
8167
+ const int64_t neq2 = q->ne[2];
8168
+ const int64_t neq3 = q->ne[3];
7765
8169
 
7766
- const int nek0 = k->ne[0];
7767
- const int nek1 = k->ne[1];
7768
- //const int nek2 = k->ne[2];
7769
- //const int nek3 = k->ne[3];
8170
+ const int64_t nek0 = k->ne[0];
8171
+ const int64_t nek1 = k->ne[1];
8172
+ //const int64_t nek2 = k->ne[2];
8173
+ //const int64_t nek3 = k->ne[3];
7770
8174
 
7771
- //const int nev0 = v->ne[0];
7772
- const int nev1 = v->ne[1];
7773
- //const int nev2 = v->ne[2];
7774
- //const int nev3 = v->ne[3];
8175
+ //const int64_t nev0 = v->ne[0];
8176
+ const int64_t nev1 = v->ne[1];
8177
+ //const int64_t nev2 = v->ne[2];
8178
+ //const int64_t nev3 = v->ne[3];
7775
8179
 
7776
- const int ne0 = dst->ne[0];
7777
- const int ne1 = dst->ne[1];
7778
- //const int ne2 = dst->ne[2];
7779
- //const int ne3 = dst->ne[3];
8180
+ const int64_t ne0 = dst->ne[0];
8181
+ const int64_t ne1 = dst->ne[1];
8182
+ //const int64_t ne2 = dst->ne[2];
8183
+ //const int64_t ne3 = dst->ne[3];
7780
8184
 
7781
8185
  const int nbk0 = k->nb[0];
7782
8186
  const int nbk1 = k->nb[1];
@@ -7801,10 +8205,10 @@ static void ggml_compute_forward_flash_attn_f16(
7801
8205
  const int ith = params->ith;
7802
8206
  const int nth = params->nth;
7803
8207
 
7804
- const int D = neq0;
7805
- const int N = neq1;
7806
- const int P = nek1 - N;
7807
- const int M = P + N;
8208
+ const int64_t D = neq0;
8209
+ const int64_t N = neq1;
8210
+ const int64_t P = nek1 - N;
8211
+ const int64_t M = P + N;
7808
8212
 
7809
8213
  const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
7810
8214
 
@@ -7867,7 +8271,7 @@ static void ggml_compute_forward_flash_attn_f16(
7867
8271
  }
7868
8272
 
7869
8273
  if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
7870
- for (int ic = 0; ic < nek1; ++ic) {
8274
+ for (int64_t ic = 0; ic < nek1; ++ic) {
7871
8275
  // k indices
7872
8276
  const int ik3 = iq3;
7873
8277
  const int ik2 = iq2;
@@ -7882,7 +8286,7 @@ static void ggml_compute_forward_flash_attn_f16(
7882
8286
  (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
7883
8287
  }
7884
8288
  } else {
7885
- for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
8289
+ for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
7886
8290
  // k indices
7887
8291
  const int ik3 = iq3;
7888
8292
  const int ik2 = iq2;
@@ -7902,7 +8306,7 @@ static void ggml_compute_forward_flash_attn_f16(
7902
8306
  ggml_vec_scale_f32(nek1, S, scale);
7903
8307
 
7904
8308
  if (masked) {
7905
- for (int i = P; i < M; i++) {
8309
+ for (int64_t i = P; i < M; i++) {
7906
8310
  if (i > P + iq1) {
7907
8311
  S[i] = -INFINITY;
7908
8312
  }
@@ -7962,12 +8366,12 @@ static void ggml_compute_forward_flash_attn_f16(
7962
8366
 
7963
8367
  ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
7964
8368
 
7965
- for (int i = 0; i < M; i++) {
8369
+ for (int64_t i = 0; i < M; i++) {
7966
8370
  S16[i] = GGML_FP32_TO_FP16(S[i]);
7967
8371
  }
7968
8372
 
7969
8373
  if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
7970
- for (int ic = 0; ic < nev1; ++ic) {
8374
+ for (int64_t ic = 0; ic < nev1; ++ic) {
7971
8375
  // dst indices
7972
8376
  const int i1 = iq1;
7973
8377
  const int i2 = iq2;
@@ -7979,7 +8383,7 @@ static void ggml_compute_forward_flash_attn_f16(
7979
8383
  S16);
7980
8384
  }
7981
8385
  } else {
7982
- for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
8386
+ for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
7983
8387
  // dst indices
7984
8388
  const int i1 = iq1;
7985
8389
  const int i2 = iq2;
@@ -8035,35 +8439,35 @@ static void ggml_compute_forward_flash_ff_f16(
8035
8439
  int64_t t0 = ggml_perf_time_us();
8036
8440
  UNUSED(t0);
8037
8441
 
8038
- const int nea0 = a->ne[0];
8039
- const int nea1 = a->ne[1];
8040
- const int nea2 = a->ne[2];
8041
- const int nea3 = a->ne[3];
8442
+ const int64_t nea0 = a->ne[0];
8443
+ const int64_t nea1 = a->ne[1];
8444
+ const int64_t nea2 = a->ne[2];
8445
+ const int64_t nea3 = a->ne[3];
8042
8446
 
8043
- const int neb00 = b0->ne[0];
8044
- const int neb01 = b0->ne[1];
8045
- //const int neb02 = b0->ne[2];
8046
- //const int neb03 = b0->ne[3];
8447
+ const int64_t neb00 = b0->ne[0];
8448
+ const int64_t neb01 = b0->ne[1];
8449
+ //const int64_t neb02 = b0->ne[2];
8450
+ //const int64_t neb03 = b0->ne[3];
8047
8451
 
8048
- const int neb10 = b1->ne[0];
8049
- const int neb11 = b1->ne[1];
8050
- //const int neb12 = b1->ne[2];
8051
- //const int neb13 = b1->ne[3];
8452
+ const int64_t neb10 = b1->ne[0];
8453
+ const int64_t neb11 = b1->ne[1];
8454
+ //const int64_t neb12 = b1->ne[2];
8455
+ //const int64_t neb13 = b1->ne[3];
8052
8456
 
8053
- const int nec00 = c0->ne[0];
8054
- const int nec01 = c0->ne[1];
8055
- //const int nec02 = c0->ne[2];
8056
- //const int nec03 = c0->ne[3];
8457
+ const int64_t nec00 = c0->ne[0];
8458
+ const int64_t nec01 = c0->ne[1];
8459
+ //const int64_t nec02 = c0->ne[2];
8460
+ //const int64_t nec03 = c0->ne[3];
8057
8461
 
8058
- const int nec10 = c1->ne[0];
8059
- const int nec11 = c1->ne[1];
8060
- //const int nec12 = c1->ne[2];
8061
- //const int nec13 = c1->ne[3];
8462
+ const int64_t nec10 = c1->ne[0];
8463
+ const int64_t nec11 = c1->ne[1];
8464
+ //const int64_t nec12 = c1->ne[2];
8465
+ //const int64_t nec13 = c1->ne[3];
8062
8466
 
8063
- const int ne0 = dst->ne[0];
8064
- const int ne1 = dst->ne[1];
8065
- const int ne2 = dst->ne[2];
8066
- //const int ne3 = dst->ne[3];
8467
+ const int64_t ne0 = dst->ne[0];
8468
+ const int64_t ne1 = dst->ne[1];
8469
+ const int64_t ne2 = dst->ne[2];
8470
+ //const int64_t ne3 = dst->ne[3];
8067
8471
 
8068
8472
  const int nba0 = a->nb[0];
8069
8473
  const int nba1 = a->nb[1];
@@ -8098,9 +8502,9 @@ static void ggml_compute_forward_flash_ff_f16(
8098
8502
  const int ith = params->ith;
8099
8503
  const int nth = params->nth;
8100
8504
 
8101
- const int D = nea0;
8102
- //const int N = nea1;
8103
- const int M = neb01;
8505
+ const int64_t D = nea0;
8506
+ //const int64_t N = nea1;
8507
+ const int64_t M = neb01;
8104
8508
 
8105
8509
  GGML_ASSERT(ne0 == nea0);
8106
8510
  GGML_ASSERT(ne1 == nea1);
@@ -8156,7 +8560,7 @@ static void ggml_compute_forward_flash_ff_f16(
8156
8560
 
8157
8561
  float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
8158
8562
 
8159
- for (int ic = 0; ic < neb01; ++ic) {
8563
+ for (int64_t ic = 0; ic < neb01; ++ic) {
8160
8564
  // b0 indices
8161
8565
  const int ib03 = ia3;
8162
8566
  const int ib02 = ia2;
@@ -8176,7 +8580,7 @@ static void ggml_compute_forward_flash_ff_f16(
8176
8580
 
8177
8581
  ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
8178
8582
 
8179
- for (int i = 0; i < M; i++) {
8583
+ for (int64_t i = 0; i < M; i++) {
8180
8584
  S16[i] = GGML_FP32_TO_FP16(S[i]);
8181
8585
  }
8182
8586
 
@@ -8188,7 +8592,7 @@ static void ggml_compute_forward_flash_ff_f16(
8188
8592
  const int i2 = ia2;
8189
8593
  const int i3 = ia3;
8190
8594
 
8191
- for (int ic = 0; ic < nec01; ++ic) {
8595
+ for (int64_t ic = 0; ic < nec01; ++ic) {
8192
8596
 
8193
8597
  ggml_vec_dot_f16(neb01,
8194
8598
  (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
@@ -9053,7 +9457,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9053
9457
  } break;
9054
9458
  case GGML_OP_ROPE:
9055
9459
  {
9056
- node->n_tasks = 1;
9460
+ node->n_tasks = n_threads;
9057
9461
  } break;
9058
9462
  case GGML_OP_CONV_1D_1S:
9059
9463
  case GGML_OP_CONV_1D_2S:
@@ -9091,7 +9495,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9091
9495
 
9092
9496
  size_t cur = 0;
9093
9497
 
9094
- const int ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
9498
+ const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
9095
9499
 
9096
9500
  if (node->src1->type == GGML_TYPE_F32) {
9097
9501
  cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
@@ -9350,7 +9754,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
9350
9754
 
9351
9755
  perf_total_per_op_us[node->op] += node->perf_time_us;
9352
9756
 
9353
- GGML_PRINT(" - %3d: [ %6d, %6d, %6d] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
9757
+ GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 ", %" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
9354
9758
  i,
9355
9759
  node->ne[0], node->ne[1], node->ne[2],
9356
9760
  GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
@@ -9364,7 +9768,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
9364
9768
  for (int i = 0; i < cgraph->n_leafs; i++) {
9365
9769
  struct ggml_tensor * node = cgraph->leafs[i];
9366
9770
 
9367
- GGML_PRINT(" - %3d: [ %6d, %6d] %8s\n",
9771
+ GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 "] %8s\n",
9368
9772
  i,
9369
9773
  node->ne[0], node->ne[1],
9370
9774
  GGML_OP_LABEL[node->op]);
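Because the ne[] dimensions are now int64_t, the print format switches from "%d" to the PRId64 macro. PRId64 comes from <inttypes.h> and expands to the correct conversion specifier for int64_t on the current platform, keeping the format string portable across 32- and 64-bit targets. A minimal example with illustrative values:

    #include <inttypes.h>
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t ne0 = 4096, ne1 = 32000;
        // "%" PRId64 concatenates into the right specifier ("%ld" or "%lld").
        printf("[ %" PRId64 ", %" PRId64 " ]\n", ne0, ne1);
        return 0;
    }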
@@ -9435,7 +9839,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
9435
9839
 
9436
9840
  fprintf(fp, " \"%p\" [ \
9437
9841
  style = filled; fillcolor = %s; shape = record; \
9438
- label=\"%d [%d, %d] | <x>%s",
9842
+ label=\"%d [%" PRId64 ", %" PRId64 "] | <x>%s",
9439
9843
  (void *) node, color,
9440
9844
  i, node->ne[0], node->ne[1],
9441
9845
  GGML_OP_SYMBOL[node->op]);
@@ -9460,7 +9864,7 @@ label=\"<x>%.1e\"; ]\n",
9460
9864
  } else {
9461
9865
  fprintf(fp, " \"%p\" [ \
9462
9866
  style = filled; fillcolor = %s; shape = record; \
9463
- label=\"<x>CONST %d [%d, %d]\"; ]\n",
9867
+ label=\"<x>CONST %d [%" PRId64 ", %" PRId64 "]\"; ]\n",
9464
9868
  (void *) node, color,
9465
9869
  i, node->ne[0], node->ne[1]);
9466
9870
  }
@@ -9524,9 +9928,9 @@ label=\"<x>CONST %d [%d, %d]\"; ]\n",
9524
9928
  static void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float * x) {
9525
9929
  int i = 0;
9526
9930
  for (int p = 0; p < np; ++p) {
9527
- const int ne = ggml_nelements(ps[p]) ;
9931
+ const int64_t ne = ggml_nelements(ps[p]) ;
9528
9932
  // TODO: add function to set tensor from array
9529
- for (int j = 0; j < ne; ++j) {
9933
+ for (int64_t j = 0; j < ne; ++j) {
9530
9934
  ggml_set_f32_1d(ps[p], j, x[i++]);
9531
9935
  }
9532
9936
  }
@@ -9535,9 +9939,9 @@ static void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const f
9535
9939
  static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) {
9536
9940
  int i = 0;
9537
9941
  for (int p = 0; p < np; ++p) {
9538
- const int ne = ggml_nelements(ps[p]) ;
9942
+ const int64_t ne = ggml_nelements(ps[p]) ;
9539
9943
  // TODO: add function to get all elements at once
9540
- for (int j = 0; j < ne; ++j) {
9944
+ for (int64_t j = 0; j < ne; ++j) {
9541
9945
  x[i++] = ggml_get_f32_1d(ps[p], j);
9542
9946
  }
9543
9947
  }
@@ -9546,9 +9950,9 @@ static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float *
9546
9950
  static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
9547
9951
  int i = 0;
9548
9952
  for (int p = 0; p < np; ++p) {
9549
- const int ne = ggml_nelements(ps[p]) ;
9953
+ const int64_t ne = ggml_nelements(ps[p]) ;
9550
9954
  // TODO: add function to get all elements at once
9551
- for (int j = 0; j < ne; ++j) {
9955
+ for (int64_t j = 0; j < ne; ++j) {
9552
9956
  g[i++] = ggml_get_f32_1d(ps[p]->grad, j);
9553
9957
  }
9554
9958
  }
@@ -10146,6 +10550,7 @@ enum ggml_opt_result ggml_opt(
10146
10550
  struct ggml_init_params params_ctx = {
10147
10551
  .mem_size = 16*1024*1024,
10148
10552
  .mem_buffer = NULL,
10553
+ .no_alloc = false,
10149
10554
  };
10150
10555
 
10151
10556
  ctx = ggml_init(params_ctx);
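The added .no_alloc = false line spells out the new field at this call site; with C designated initializers, members that are omitted are zero-initialized anyway, so the behavior is unchanged and the line mainly documents intent. A small standalone illustration; the struct below is a hypothetical stand-in, not the ggml header:

    #include <stdbool.h>
    #include <stddef.h>

    struct init_params_sketch {
        size_t mem_size;
        void * mem_buffer;
        bool   no_alloc;
    };

    int main(void) {
        // Omitted members of a designated initializer are zero-initialized,
        // so .no_alloc would default to false; listing it makes that explicit.
        struct init_params_sketch p = {
            .mem_size   = 16*1024*1024,
            .mem_buffer = NULL,
            .no_alloc   = false,
        };
        (void) p;
        return 0;
    }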