llama_cpp 0.0.2 → 0.0.4

@@ -1,4 +1,4 @@
- // Defines CLOCK_MONOTONIC and asprintf on Linux
+ // Defines CLOCK_MONOTONIC on Linux
  #define _GNU_SOURCE
 
  #include "ggml.h"
@@ -16,6 +16,7 @@
  #include <stdlib.h>
  #include <string.h>
  #include <stdint.h>
+ #include <inttypes.h>
  #include <stdio.h>
  #include <float.h>
 
@@ -25,14 +26,9 @@
  #define static_assert(cond, msg) struct global_scope_noop_trick
  #endif
 
- #if defined _MSC_VER || defined(__MINGW32__)
+ #if defined(_WIN32)
 
- #if !defined(__MINGW32__)
- #include <Windows.h>
- #else
- // ref: https://github.com/ggerganov/whisper.cpp/issues/168
  #include <windows.h>
- #endif
 
  typedef volatile LONG atomic_int;
  typedef atomic_int atomic_bool;
@@ -54,6 +50,7 @@ typedef HANDLE pthread_t;
 
  typedef DWORD thread_ret_t;
  static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
+ (void) unused;
  HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
  if (handle == NULL)
  {
@@ -65,6 +62,7 @@ static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void
  }
 
  static int pthread_join(pthread_t thread, void* unused) {
+ (void) unused;
  return (int) WaitForSingleObject(thread, INFINITE);
  }
 
@@ -96,17 +94,6 @@ typedef void* thread_ret_t;
  #define static_assert(cond, msg) _Static_assert(cond, msg)
  #endif
 
- #define GGML_MLOCK_SUPPORT 0
-
- #ifdef __has_include
- #if __has_include(<sys/mman.h>)
- #undef GGML_MLOCK_SUPPORT
- #define GGML_MLOCK_SUPPORT 1
- #include <sys/mman.h>
- #endif
- #endif
-
-
  /*#define GGML_PERF*/
  #define GGML_DEBUG 0
  #define GGML_GELU_FP16
@@ -127,6 +114,14 @@ typedef void* thread_ret_t;
  #define GGML_MEM_ALIGN 16
  #endif
 
+ #if defined(_MSC_VER) || defined(__MINGW32__)
+ #define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
+ #define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr)
+ #else
+ #define GGML_ALIGNED_MALLOC(size) aligned_alloc(GGML_MEM_ALIGN, size)
+ #define GGML_ALIGNED_FREE(ptr) free(ptr)
+ #endif
+
  #define UNUSED(x) (void)(x)
  #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
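Note on the hunk above: it introduces the GGML_ALIGNED_MALLOC / GGML_ALIGNED_FREE macros in place of plain malloc/free. Below is a minimal sketch, not part of this diff, of how the macros resolve on each platform; the main() driver and the 1024-byte size are illustrative only. C11 aligned_alloc expects the size to be a multiple of the alignment, which is likely why ggml_init (further down in this diff) rounds mem_size up to a multiple of GGML_MEM_ALIGN.

#include <stdlib.h>
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h>   // _aligned_malloc / _aligned_free
#endif

#define GGML_MEM_ALIGN 16

#if defined(_MSC_VER) || defined(__MINGW32__)
#define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
#define GGML_ALIGNED_FREE(ptr)    _aligned_free(ptr)
#else
#define GGML_ALIGNED_MALLOC(size) aligned_alloc(GGML_MEM_ALIGN, size)
#define GGML_ALIGNED_FREE(ptr)    free(ptr)
#endif

int main(void) {
    void * buf = GGML_ALIGNED_MALLOC(1024); // 1024 is already a multiple of 16
    if (buf != NULL) {
        GGML_ALIGNED_FREE(buf);
    }
    return 0;
}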
 
@@ -241,12 +236,12 @@ static inline float fp32_from_bits(uint32_t w) {
  }
 
  static inline uint32_t fp32_to_bits(float f) {
- union {
- float as_value;
- uint32_t as_bits;
- } fp32;
- fp32.as_value = f;
- return fp32.as_bits;
+ union {
+ float as_value;
+ uint32_t as_bits;
+ } fp32;
+ fp32.as_value = f;
+ return fp32.as_bits;
  }
 
  static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
@@ -496,6 +491,77 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
  }
  #endif
 
+ #if __ARM_NEON
+
+ #if !defined(__aarch64__)
+
+ inline static uint16_t vaddvq_u8(uint8x16_t v) {
+ return
+ (uint16_t)vgetq_lane_u8(v, 0) + (uint16_t)vgetq_lane_u8(v, 1) +
+ (uint16_t)vgetq_lane_u8(v, 2) + (uint16_t)vgetq_lane_u8(v, 3) +
+ (uint16_t)vgetq_lane_u8(v, 4) + (uint16_t)vgetq_lane_u8(v, 5) +
+ (uint16_t)vgetq_lane_u8(v, 6) + (uint16_t)vgetq_lane_u8(v, 7) +
+ (uint16_t)vgetq_lane_u8(v, 8) + (uint16_t)vgetq_lane_u8(v, 9) +
+ (uint16_t)vgetq_lane_u8(v, 10) + (uint16_t)vgetq_lane_u8(v, 11) +
+ (uint16_t)vgetq_lane_u8(v, 12) + (uint16_t)vgetq_lane_u8(v, 13) +
+ (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
+ }
+
+ inline static int32_t vaddvq_s16(int16x8_t v) {
+ return
+ (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
+ (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
+ (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
+ (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
+ }
+
+ inline static uint32_t vaddvq_u16(uint16x8_t v) {
+ return
+ (uint32_t)vgetq_lane_u16(v, 0) + (uint32_t)vgetq_lane_u16(v, 1) +
+ (uint32_t)vgetq_lane_u16(v, 2) + (uint32_t)vgetq_lane_u16(v, 3) +
+ (uint32_t)vgetq_lane_u16(v, 4) + (uint32_t)vgetq_lane_u16(v, 5) +
+ (uint32_t)vgetq_lane_u16(v, 6) + (uint32_t)vgetq_lane_u16(v, 7);
+ }
+
+ inline static int32_t vaddvq_s32(int32x4_t v) {
+ return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
+ }
+
+ inline static float vaddvq_f32(float32x4_t v) {
+ return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
+ }
+
+ inline float vminvq_f32(float32x4_t v) {
+ return
+ MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+ MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+ }
+
+ inline float vmaxvq_f32(float32x4_t v) {
+ return
+ MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+ MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+ }
+
+ inline int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
+ return vget_low_s8(vcombine_s8(a, b));
+ }
+
+ inline int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
+ return vget_high_s8(vcombine_s8(a, b));
+ }
+
+ inline uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
+ return vget_low_u8(vcombine_u8(a, b));
+ }
+
+ inline uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
+ return vget_high_u8(vcombine_u8(a, b));
+ }
+
+ #endif
+ #endif
+
  // method 5
  // blocks of QK elements
  // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
@@ -609,10 +675,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
  for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
  for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
 
- // absolute max
- const float amax = MAX(
- MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)),
- MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
+ const float amax = vmaxvq_f32(amaxv[0]);
 
  const float d = amax / ((1 << 3) - 1);
  const float id = d ? 1.0f/d : 0.0f;
@@ -934,7 +997,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
  float32x4_t minv[8];
  float32x4_t maxv[8];
 
- for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
+ for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK + 4*l);
 
  for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
  for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
@@ -957,7 +1020,8 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
 
  for (int l = 0; l < 8; l++) {
  const float32x4_t v = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id);
- const int32x4_t vi = vcvtq_s32_f32(v);
+ const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(0.5f)); // needed to round to nearest
+ const int32x4_t vi = vcvtq_s32_f32(vf);
 
  y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
  y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
@@ -1225,15 +1289,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
  #define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
  #define GGML_F32x4_ADD vaddq_f32
  #define GGML_F32x4_MUL vmulq_f32
- #if defined(__ARM_FEATURE_QRDMX)
- #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
- #else
- #define GGML_F32x4_REDUCE_ONE(x) \
- (vgetq_lane_f32(x, 0) + \
- vgetq_lane_f32(x, 1) + \
- vgetq_lane_f32(x, 2) + \
- vgetq_lane_f32(x, 3))
- #endif
+ #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
  #define GGML_F32x4_REDUCE(res, x) \
  { \
  for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
@@ -1856,55 +1912,43 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
  // 4-bit -> 8-bit
  const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
  const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
-
  const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
  const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
 
  const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
  const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
-
  const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
  const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
 
  // sub 8
  const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
  const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
-
  const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
  const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
 
  const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
  const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
-
  const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
  const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
 
  #if defined(__ARM_FEATURE_DOTPROD)
- // dot product into int16x8_t
+ // dot product into int32x4_t
  int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
  int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
 
  p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
  p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
 
- // scalar
- #if defined(__ARM_FEATURE_QRDMX)
- sum0 += x0->d * y0->d * vaddvq_s32(p_0);
- sum1 += x1->d * y1->d * vaddvq_s32(p_1);
- #else
- sum0 += x0->d * y0->d * (vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
- sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
- #endif
+ sum0 += x0->d*y0->d*vaddvq_s32(p_0);
+ sum1 += x1->d*y1->d*vaddvq_s32(p_1);
  #else
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
  const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-
  const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
  const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
 
  const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
  const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-
  const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
  const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
 
@@ -1917,14 +1961,8 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
  const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
  const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
 
- // scalar
- #if defined(__ARM_FEATURE_QRDMX)
- sum0 += x0->d * y0->d * vaddvq_s16(p_0);
- sum1 += x1->d * y1->d * vaddvq_s16(p_1);
- #else
- sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
- sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
- #endif
+ sum0 += x0->d*y0->d*vaddvq_s16(p_0);
+ sum1 += x1->d*y1->d*vaddvq_s16(p_1);
  #endif
  }
 
@@ -1961,41 +1999,68 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
  // Initialize accumulator with zeros
  __m256 acc = _mm256_setzero_ps();
 
- // Main loop
- // TODO: figure a way to do this in a portable way
- #ifdef __GNUC__
- #pragma GCC unroll 16
- #endif
- for (int i = 0; i < nb; ++i) {
- // Compute combined scale for the block
- const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
-
- // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
- __m256i bx = bytesFromNibbles( x[i].qs );
- __m256i by = bytesFromNibbles( y[i].qs );
-
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
- const __m256i off = _mm256_set1_epi8( 8 );
- bx = _mm256_sub_epi8( bx, off );
- by = _mm256_sub_epi8( by, off );
-
- // Get absolute values of x vectors
- const __m256i ax = _mm256_sign_epi8(bx, bx);
-
- // Sign the values of the y vectors
- const __m256i sy = _mm256_sign_epi8(by, bx);
-
- // Perform multiplication and create 16-bit values
- const __m256i dot = _mm256_maddubs_epi16(ax, sy);
-
- const __m256i ones = _mm256_set1_epi16(1);
- const __m256i i32 = _mm256_madd_epi16(ones, dot);
+ /* Prepare the constants we will need during execution */
+ const __m256i lowMask = _mm256_set1_epi8( 0xF );
+ const __m256i offset_8 = _mm256_set1_epi16( 8 );
 
- // Convert int32_t to float
- const __m256 p = _mm256_cvtepi32_ps( i32 );
+ #define UNROLL_COUNT 8
+ // make sure we only unroll multiples of the block count
+ assert(nb % UNROLL_COUNT == 0);
 
- // Apply the scale, and accumulate
- acc = _mm256_fmadd_ps( d, p, acc );
+ // Main loop
+ for (int i = 0; i < nb; i+=UNROLL_COUNT) {
+ // This loop will be unrolled by the compiler
+ for (int u=0;u<UNROLL_COUNT;u++) {
+ /* Compute combined scale for the block */
+ const __m256 scale = _mm256_mul_ps(
+ _mm256_broadcast_ss( &x[i+u].d ),
+ _mm256_broadcast_ss( &y[i+u].d ) );
+
+ /* get input from x
+ Input: 32 Nibbles (16 bytes) at *x[i+u]
+ Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
+
+ /* Load 16 bytes from memory */
+ const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
+ /* Expand bytes into uint16_t values */
+ const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
+ /* Unpack values into individual bytes */
+ __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
+ const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
+ __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
+ /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
+ x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
+ x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
+
+ /* get input from y
+ Input: 32 Nibbles (16 bytes) at *y[i+u]
+ Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
+
+ /* Load 16 bytes from memory */
+ const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
+ /* Expand bytes into uint16_t values */
+ const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
+ /* Unpack values into individual bytes */
+ const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
+ __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
+ __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
+ /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
+ y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
+ y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
+
+ /* Compute products of int16_t integers, add pairwise, store as int32_t */
+ __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
+ __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
+
+ /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
+ __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
+
+ /* Convert to vectore of 8 int32_t to 8 floats */
+ __m256 q = _mm256_cvtepi32_ps( xy_q );
+
+ /* Multiply q with scale and accumulate */
+ acc = _mm256_fmadd_ps( scale, q, acc );
+ }
  }
 
  // Return horizontal sum of the acc vector
@@ -2025,7 +2090,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
  bx = _mm_sub_epi8( bx, off );
  by = _mm_sub_epi8( by, off );
 
- // Get absolute values of x vectors
+ // Get absolute values of x vectors
  const __m128i ax = _mm_sign_epi8(bx, bx);
 
  // Sign the values of the y vectors
@@ -2057,18 +2122,18 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
  float sum1 = 0.0f;
 
  for (int i = 0; i < nb; i += 2) {
- const block_q4_0 * restrict x0 = &px[i + 0];
- const block_q4_0 * restrict y0 = &py[i + 0];
- const block_q4_0 * restrict x1 = &px[i + 1];
- const block_q4_0 * restrict y1 = &py[i + 1];
+ const block_q4_0 * restrict x0 = &x[i + 0];
+ const block_q4_0 * restrict y0 = &y[i + 0];
+ const block_q4_0 * restrict x1 = &x[i + 1];
+ const block_q4_0 * restrict y1 = &y[i + 1];
 
  const v128_t m4b = wasm_u8x16_splat(0xf);
  const v128_t s8b = wasm_i8x16_splat(0x8);
 
- const v128_t v0_0 = wasm_v128_load(x0.qs);
- const v128_t v0_1 = wasm_v128_load(y0.qs);
- const v128_t v1_0 = wasm_v128_load(x1.qs);
- const v128_t v1_1 = wasm_v128_load(y1.qs);
+ const v128_t v0_0 = wasm_v128_load(x0->qs);
+ const v128_t v0_1 = wasm_v128_load(y0->qs);
+ const v128_t v1_0 = wasm_v128_load(x1->qs);
+ const v128_t v1_1 = wasm_v128_load(y1->qs);
 
  // 4-bit -> 8-bit
  const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
@@ -2140,18 +2205,20 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
  const uint8_t * restrict p0 = x[i].qs;
  const uint8_t * restrict p1 = y[i].qs;
 
+ int sumi = 0;
  for (int j = 0; j < QK/2; j++) {
  const uint8_t v0 = p0[j];
  const uint8_t v1 = p1[j];
 
- const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
- const float f1 = d0*((int8_t) (v0 >> 4) - 8);
+ const int8_t i0 = (int8_t) (v0 & 0xf) - 8;
+ const int8_t i1 = (int8_t) (v0 >> 4) - 8;
 
- const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
- const float f3 = d1*((int8_t) (v1 >> 4) - 8);
+ const int8_t i2 = (int8_t) (v1 & 0xf) - 8;
+ const int8_t i3 = (int8_t) (v1 >> 4) - 8;
 
- sumf += f0*f2 + f1*f3;
+ sumi += i0*i2 + i1*i3;
  }
+ sumf += d0 * d1 * sumi;
  }
  #endif
 
@@ -2243,36 +2310,71 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
  float sum10 = 0.0f;
  float sum11 = 0.0f;
 
- for (int i = 0; i < nb; ++i) {
+ for (int i = 0; i < nb; i += 2) {
  const block_q4_1 * restrict x0 = &x[i + 0];
  const block_q4_1 * restrict y0 = &y[i + 0];
+ const block_q4_1 * restrict x1 = &x[i + 1];
+ const block_q4_1 * restrict y1 = &y[i + 1];
 
  const uint8x16_t m4b = vdupq_n_u8(0xf);
 
  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
  const uint8x16_t v1_0 = vld1q_u8(y0->qs);
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+ const uint8x16_t v1_1 = vld1q_u8(y1->qs);
 
- // and with 0xf
+ // 4-bit -> 8-bit
  const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
  const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
-
  const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
  const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
 
- // dot product into uint16x8_t
+ const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
+ const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
+ const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
+ const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
+
+ sum00 += x0->m*y0->m;
+ sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
+ sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
+
+ sum00 += x1->m*y1->m;
+ sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
+ sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
+
+ #if defined(__ARM_FEATURE_DOTPROD)
+ // dot product into int32x4_t
+ uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
+ uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
+
+ p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
+ p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
+
+ sum11 += x0->d*y0->d*vaddvq_u32(p_0);
+ sum11 += x1->d*y1->d*vaddvq_u32(p_1);
+ #else
  const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
  const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
-
  const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
  const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
 
- const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
- const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
+ const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
+ const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
+ const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
+ const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
 
- sum00 += x0->m*y0->m;
- sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
- sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
- sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
+ const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
+ const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
+
+ const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
+ const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
+
+ const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
+ const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
+
+ sum11 += x0->d*y0->d*vaddvq_u16(p_0);
+ sum11 += x1->d*y1->d*vaddvq_u16(p_1);
+ #endif
  }
 
  sumf = QK*sum00 + sum01 + sum10 + sum11;
@@ -2548,29 +2650,38 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
  //
 
  static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
- QK,
- QK,
- 1,
- 1,
- 1,
- 1,
- 1,
+ [GGML_TYPE_F32] = 1,
+ [GGML_TYPE_F16] = 1,
+ [GGML_TYPE_Q4_0] = QK,
+ [GGML_TYPE_Q4_1] = QK,
+ [GGML_TYPE_I8] = 1,
+ [GGML_TYPE_I16] = 1,
+ [GGML_TYPE_I32] = 1,
  };
-
- static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
+ static_assert(GGML_TYPE_COUNT == 7, "GGML_BLCK_SIZE is outdated");
 
  static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
- sizeof(block_q4_0),
- sizeof(block_q4_1),
- sizeof(int8_t ),
- sizeof(int16_t),
- sizeof(int32_t),
- sizeof(ggml_fp16_t),
- sizeof(float ),
+ [GGML_TYPE_F32] = sizeof(float),
+ [GGML_TYPE_F16] = sizeof(ggml_fp16_t),
+ [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
+ [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
+ [GGML_TYPE_I8] = sizeof(int8_t),
+ [GGML_TYPE_I16] = sizeof(int16_t),
+ [GGML_TYPE_I32] = sizeof(int32_t),
  };
+ static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_SIZE is outdated");
 
- // don't forget to update the array above when adding new types
- static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
+
+ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
+ [GGML_TYPE_F32] = "f32",
+ [GGML_TYPE_F16] = "f16",
+ [GGML_TYPE_Q4_0] = "q4_0",
+ [GGML_TYPE_Q4_1] = "q4_1",
+ [GGML_TYPE_I8] = "i8",
+ [GGML_TYPE_I16] = "i16",
+ [GGML_TYPE_I32] = "i32",
+ };
+ static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_NAME is outdated");
 
  static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
  "NONE",
@@ -2599,6 +2710,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
 
  "SCALE",
  "CPY",
+ "CONT",
  "RESHAPE",
  "VIEW",
  "PERMUTE",
@@ -2612,9 +2724,12 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
 
  "FLASH_ATTN",
  "FLASH_FF",
+
+ "MAP_UNARY",
+ "MAP_BINARY",
  };
 
- static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+ static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
 
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "none",
@@ -2643,6 +2758,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
  "x*v",
  "x-\\>y",
+ "cont(x)",
  "reshape(x)",
  "view(x)",
  "permute(x)",
@@ -2656,24 +2772,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
  "flash_attn(x)",
  "flash_ff(x)",
- };
 
- static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
-
- //
- // ggml object
- //
-
- struct ggml_object {
- size_t offs;
- size_t size;
-
- struct ggml_object * next;
-
- char padding[8];
+ "f(x)",
+ "f(x,y)",
  };
 
- static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
+ static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
 
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -2686,7 +2790,6 @@ struct ggml_context {
  size_t mem_size;
  void * mem_buffer;
  bool mem_buffer_owned;
- bool mem_buffer_mlocked;
  bool no_alloc;
 
  int n_objects;
@@ -2774,7 +2877,7 @@ void ggml_print_objects(const struct ggml_context * ctx) {
  GGML_PRINT("%s: --- end ---\n", __func__);
  }
 
- int ggml_nelements(const struct ggml_tensor * tensor) {
+ int64_t ggml_nelements(const struct ggml_tensor * tensor) {
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
  return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
@@ -2804,6 +2907,11 @@ float ggml_type_sizef(enum ggml_type type) {
  return ((float)(GGML_TYPE_SIZE[type]))/GGML_BLCK_SIZE[type];
  }
 
+ const char * ggml_type_name(enum ggml_type type) {
+ return GGML_TYPE_NAME[type];
+ }
+
+
  size_t ggml_element_size(const struct ggml_tensor * tensor) {
  return GGML_TYPE_SIZE[tensor->type];
  }
@@ -2969,11 +3077,12 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
  return NULL;
  }
 
+ const size_t mem_size = (params.mem_size + GGML_MEM_ALIGN - 1) & ~(GGML_MEM_ALIGN - 1);
+
  *ctx = (struct ggml_context) {
- /*.mem_size =*/ params.mem_size,
- /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
+ /*.mem_size =*/ mem_size,
+ /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size),
  /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
- /*.mem_buffer_mlocked =*/ false,
  /*.no_alloc =*/ params.no_alloc,
  /*.n_objects =*/ 0,
  /*.objects_begin =*/ NULL,
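The mem_size line added in the hunk above rounds the requested pool size up to the next multiple of GGML_MEM_ALIGN, so that the GGML_ALIGNED_MALLOC path (aligned_alloc on non-Windows) receives a size that is a multiple of the alignment. A small worked example of the bit trick, with illustrative values:

// (n + A - 1) & ~(A - 1) rounds n up to the next multiple of A, for A a power of two.
// With A = GGML_MEM_ALIGN = 16:
//   n = 1000 -> (1000 + 15) & ~15 = 1015 & ~15 = 1008
//   n = 1024 -> (1024 + 15) & ~15 = 1039 & ~15 = 1024   (already aligned, unchanged)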
@@ -2982,7 +3091,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
  /*.scratch_save =*/ { 0, 0, NULL, },
  };
 
- GGML_ASSERT(ctx->mem_buffer != NULL); // check for allocation failure
+ GGML_ASSERT(ctx->mem_buffer != NULL);
 
  ggml_assert_aligned(ctx->mem_buffer);
 
@@ -3006,16 +3115,8 @@ void ggml_free(struct ggml_context * ctx) {
  GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
  __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
 
- #if GGML_MLOCK_SUPPORT
- if (ctx->mem_buffer_mlocked) {
- if (munlock(ctx->mem_buffer, ctx->mem_size)) {
- fprintf(stderr, "%s: failed to munlock buffer: %s\n", __func__, strerror(errno));
- }
- }
- #endif
-
  if (ctx->mem_buffer_owned) {
- free(ctx->mem_buffer);
+ GGML_ALIGNED_FREE(ctx->mem_buffer);
  }
 
  found = true;
@@ -3042,55 +3143,13 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch)
  return result;
  }
 
- #ifdef __APPLE__
- #define MLOCK_SUGGESTION \
- "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
- "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n"
- #else
- #define MLOCK_SUGGESTION \
- "Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n"
- #endif
-
- bool ggml_mlock_supported(void) {
- return GGML_MLOCK_SUPPORT;
- }
-
- bool ggml_mlock(
- struct ggml_context * ctx,
- const void *opt_extra_addr,
- size_t opt_extra_len,
- char **err_p) {
- // TODO: Use SetProcessWorkingSetSize() + VirtualLock() on WIN32
- #if GGML_MLOCK_SUPPORT
- if (ctx->mem_buffer_mlocked) {
- return true;
- }
- if (mlock(ctx->mem_buffer, ctx->mem_size) ||
- (opt_extra_len &&
- mlock(opt_extra_addr, opt_extra_len))) {
- if ((*err_p = malloc(1024))) {
- snprintf(*err_p, 1024,
- "failed to mlock %zu-byte buffer: %s\n" MLOCK_SUGGESTION,
- ctx->mem_size + opt_extra_len,
- strerror(errno));
- }
- return false;
- }
- ctx->mem_buffer_mlocked = true;
- return true;
- #else // GGML_MLOCK_SUPPORT
- *err_p = strdup("can't mlock because it's not supported on this system");
- return false;
- #endif // GGML_MLOCK_SUPPORT
- }
-
  ////////////////////////////////////////////////////////////////////////////////
 
  struct ggml_tensor * ggml_new_tensor_impl(
  struct ggml_context * ctx,
  enum ggml_type type,
  int n_dims,
- const int* ne,
+ const int64_t* ne,
  void* data) {
  // always insert objects at the end of the context's memory pool
  struct ggml_object * obj_cur = ctx->objects_end;
@@ -3189,7 +3248,8 @@ struct ggml_tensor * ggml_new_tensor_impl(
  /*.pad =*/ { 0 },
  };
 
- ggml_assert_aligned(result->data);
+ // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
+ //ggml_assert_aligned(result->data);
 
  for (int i = 0; i < n_dims; i++) {
  result->ne[i] = ne[i];
@@ -3210,44 +3270,44 @@ struct ggml_tensor * ggml_new_tensor(
  struct ggml_context * ctx,
  enum ggml_type type,
  int n_dims,
- const int * ne) {
+ const int64_t * ne) {
  return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL);
  }
 
  struct ggml_tensor * ggml_new_tensor_1d(
  struct ggml_context * ctx,
  enum ggml_type type,
- int ne0) {
+ int64_t ne0) {
  return ggml_new_tensor(ctx, type, 1, &ne0);
  }
 
  struct ggml_tensor * ggml_new_tensor_2d(
  struct ggml_context * ctx,
  enum ggml_type type,
- int ne0,
- int ne1) {
- const int ne[2] = { ne0, ne1 };
+ int64_t ne0,
+ int64_t ne1) {
+ const int64_t ne[2] = { ne0, ne1 };
  return ggml_new_tensor(ctx, type, 2, ne);
  }
 
  struct ggml_tensor * ggml_new_tensor_3d(
  struct ggml_context * ctx,
  enum ggml_type type,
- int ne0,
- int ne1,
- int ne2) {
- const int ne[3] = { ne0, ne1, ne2 };
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2) {
+ const int64_t ne[3] = { ne0, ne1, ne2 };
  return ggml_new_tensor(ctx, type, 3, ne);
  }
 
  struct ggml_tensor * ggml_new_tensor_4d(
  struct ggml_context * ctx,
  enum ggml_type type,
- int ne0,
- int ne1,
- int ne2,
- int ne3) {
- const int ne[4] = { ne0, ne1, ne2, ne3 };
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ int64_t ne3) {
+ const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
  return ggml_new_tensor(ctx, type, 4, ne);
  }
 
@@ -3590,7 +3650,14 @@ float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
  struct ggml_tensor * ggml_view_tensor(
  struct ggml_context * ctx,
  const struct ggml_tensor * src) {
- return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data);
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data);
+
+ result->nb[0] = src->nb[0];
+ result->nb[1] = src->nb[1];
+ result->nb[2] = src->nb[2];
+ result->nb[3] = src->nb[3];
+
+ return result;
  }
 
  ////////////////////////////////////////////////////////////////////////////////
@@ -3894,7 +3961,7 @@ struct ggml_tensor * ggml_mean(
  is_node = true;
  }
 
- int ne[GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] };
+ int64_t ne[GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] };
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, ne);
 
  result->op = GGML_OP_MEAN;
@@ -4255,7 +4322,7 @@ struct ggml_tensor * ggml_mul_mat(
  is_node = true;
  }
 
- const int ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
+ const int64_t ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
 
  result->op = GGML_OP_MUL_MAT;
@@ -4350,6 +4417,41 @@ struct ggml_tensor * ggml_cpy_inplace(
  return ggml_cpy_impl(ctx, a, b, true);
  }
 
+ // ggml_cont
+
+ struct ggml_tensor * ggml_cont_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && a->grad) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_CONT;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+ }
+
+ struct ggml_tensor * ggml_cont(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_cont_impl(ctx, a, false);
+ }
+
+ struct ggml_tensor * ggml_cont_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_cont_impl(ctx, a, true);
+ }
+
  // ggml_reshape
 
  struct ggml_tensor * ggml_reshape(
@@ -4380,8 +4482,8 @@ struct ggml_tensor * ggml_reshape(
  struct ggml_tensor * ggml_reshape_2d(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
- int ne0,
- int ne1) {
+ int64_t ne0,
+ int64_t ne1) {
  GGML_ASSERT(ggml_is_contiguous(a));
  GGML_ASSERT(ggml_nelements(a) == ne0*ne1);
 
@@ -4392,7 +4494,7 @@ struct ggml_tensor * ggml_reshape_2d(
  is_node = true;
  }
 
- const int ne[2] = { ne0, ne1 };
+ const int64_t ne[2] = { ne0, ne1 };
  struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a->data);
 
  result->op = GGML_OP_RESHAPE;
@@ -4406,9 +4508,9 @@ struct ggml_tensor * ggml_reshape_2d(
  struct ggml_tensor * ggml_reshape_3d(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
- int ne0,
- int ne1,
- int ne2) {
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2) {
  GGML_ASSERT(ggml_is_contiguous(a));
  GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2);
 
@@ -4419,7 +4521,7 @@ struct ggml_tensor * ggml_reshape_3d(
  is_node = true;
  }
 
- const int ne[3] = { ne0, ne1, ne2 };
+ const int64_t ne[3] = { ne0, ne1, ne2 };
  struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a->data);
 
  result->op = GGML_OP_RESHAPE;
@@ -4435,7 +4537,7 @@ struct ggml_tensor * ggml_reshape_3d(
  struct ggml_tensor * ggml_view_1d(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
- int ne0,
+ int64_t ne0,
  size_t offset) {
  if (a->grad) {
  GGML_ASSERT(false); // gradient propagation is not supported
@@ -4456,15 +4558,15 @@ struct ggml_tensor * ggml_view_1d(
  struct ggml_tensor * ggml_view_2d(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
- int ne0,
- int ne1,
+ int64_t ne0,
+ int64_t ne1,
  size_t nb1,
  size_t offset) {
  if (a->grad) {
  GGML_ASSERT(false); // gradient propagation is not supported
  }
 
- const int ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 };
+ const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 };
 
  struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset);
 
@@ -4480,6 +4582,37 @@ struct ggml_tensor * ggml_view_2d(
  return result;
  }
 
+ // ggml_view_3d
+
+ struct ggml_tensor * ggml_view_3d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ size_t nb1,
+ size_t nb2,
+ size_t offset) {
+ if (a->grad) {
+ GGML_ASSERT(false); // gradient propagation is not supported
+ }
+
+ const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, 1 };
+
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset);
+
+ result->nb[1] = nb1;
+ result->nb[2] = nb2;
+ result->nb[3] = result->nb[2]*ne2;
+
+ result->op = GGML_OP_VIEW;
+ result->grad = NULL;
+ result->src0 = a;
+ result->src1 = NULL; // TODO: maybe store the offset here?
+
+ return result;
+ }
+
  // ggml_permute
 
  struct ggml_tensor * ggml_permute(
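The hunk above adds ggml_view_3d. A rough sketch, not taken from this diff, of how the new call might be used; the ggml_init_params fields are assumed from the surrounding context of this change, and the tensor shapes are illustrative only.

#include "ggml.h"

// Take a strided 3-D view of the first 4 elements of every row of an
// 8 x 4 x 2 F32 tensor, reusing the parent's row/plane strides (nb[1], nb[2]).
static void view_3d_example(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 8, 4, 2);
    struct ggml_tensor * v = ggml_view_3d(ctx, a, 4, 4, 2, a->nb[1], a->nb[2], 0);
    (void) v; // v shares a's data; no copy is made

    ggml_free(ctx);
}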
@@ -4695,7 +4828,7 @@ struct ggml_tensor * ggml_conv_1d_1s(
  is_node = true;
  }
 
- const int ne[4] = { b->ne[0], a->ne[2], 1, 1, };
+ const int64_t ne[4] = { b->ne[0], a->ne[2], 1, 1, };
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
 
  result->op = GGML_OP_CONV_1D_1S;
@@ -4722,7 +4855,7 @@ struct ggml_tensor * ggml_conv_1d_2s(
  is_node = true;
  }
 
- const int ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, };
+ const int64_t ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, };
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
 
  result->op = GGML_OP_CONV_1D_2S;
@@ -4797,6 +4930,90 @@ struct ggml_tensor * ggml_flash_ff(
  return result;
  }
 
+ // ggml_map_unary
+
+ struct ggml_tensor * ggml_map_unary_impl_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ const ggml_unary_op_f32_t fun,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && a->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+ *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+ struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_MAP_UNARY;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->opt[0] = addr_tensor;
+
+ return result;
+ }
+
+ struct ggml_tensor * ggml_map_unary_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ const ggml_unary_op_f32_t fun) {
+ return ggml_map_unary_impl_f32(ctx, a, fun, false);
+ }
+
+ struct ggml_tensor * ggml_map_unary_inplace_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ const ggml_unary_op_f32_t fun) {
+ return ggml_map_unary_impl_f32(ctx, a, fun, true);
+ }
+
+ // ggml_map_binary
+
+ struct ggml_tensor * ggml_map_binary_impl_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ const ggml_binary_op_f32_t fun,
+ bool inplace) {
+ GGML_ASSERT(ggml_are_same_shape(a, b));
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+ *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+ struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_MAP_BINARY;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+ result->opt[0] = addr_tensor;
+
+ return result;
+ }
+
+ struct ggml_tensor * ggml_map_binary_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ const ggml_binary_op_f32_t fun) {
+ return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
+ }
+
+ struct ggml_tensor * ggml_map_binary_inplace_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ const ggml_binary_op_f32_t fun) {
+ return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
+ }
+
  ////////////////////////////////////////////////////////////////////////////////
 
  void ggml_set_param(
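The ggml_map_unary/ggml_map_binary additions above let a user-supplied C function run element-wise as a graph op; the function pointer is stashed in a small I32 tensor stored in opt[0]. A hedged usage sketch follows; the callback typedefs (ggml_unary_op_f32_t / ggml_binary_op_f32_t taking n, dst and src pointers) are assumed from the accompanying ggml.h and are not shown in this diff.

#include "ggml.h"

// Assumed callback shapes: operate on n floats, writing into dst.
static void scale_by_two(const int n, float * dst, const float * src) {
    for (int i = 0; i < n; ++i) {
        dst[i] = 2.0f*src[i];
    }
}

static void elementwise_add(const int n, float * dst, const float * x, const float * y) {
    for (int i = 0; i < n; ++i) {
        dst[i] = x[i] + y[i];
    }
}

// Build: out = (2*a) + b, using the new map ops.
static struct ggml_tensor * build_mapped(struct ggml_context * ctx,
                                         struct ggml_tensor * a,
                                         struct ggml_tensor * b) {
    struct ggml_tensor * doubled = ggml_map_unary_f32(ctx, a, scale_by_two);
    return ggml_map_binary_f32(ctx, doubled, b, elementwise_add);
}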
@@ -4815,102 +5032,191 @@ static void ggml_compute_forward_dup_f16(
4815
5032
  const struct ggml_tensor * src0,
4816
5033
  struct ggml_tensor * dst) {
4817
5034
  GGML_ASSERT(params->ith == 0);
4818
- GGML_ASSERT(ggml_is_contiguous(dst));
4819
5035
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
4820
5036
 
4821
5037
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
4822
5038
  return;
4823
5039
  }
4824
5040
 
4825
- const int ne00 = src0->ne[0];
4826
- const int ne01 = src0->ne[1];
4827
- const int ne02 = src0->ne[2];
4828
- const int ne03 = src0->ne[3];
5041
+ const int64_t ne00 = src0->ne[0];
5042
+ const int64_t ne01 = src0->ne[1];
5043
+ const int64_t ne02 = src0->ne[2];
5044
+ const int64_t ne03 = src0->ne[3];
4829
5045
 
4830
5046
  const size_t nb00 = src0->nb[0];
4831
5047
  const size_t nb01 = src0->nb[1];
4832
5048
  const size_t nb02 = src0->nb[2];
4833
5049
  const size_t nb03 = src0->nb[3];
4834
5050
 
4835
- if (ggml_is_contiguous(src0) && src0->type == dst->type) {
5051
+ const size_t nb0 = dst->nb[0];
5052
+ const size_t nb1 = dst->nb[1];
5053
+ const size_t nb2 = dst->nb[2];
5054
+ const size_t nb3 = dst->nb[3];
5055
+
5056
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
4836
5057
  memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
4837
5058
  return;
4838
5059
  }
4839
5060
 
4840
- if (src0->nb[0] == sizeof(ggml_fp16_t)) {
4841
- if (dst->type == GGML_TYPE_F16) {
4842
- size_t id = 0;
4843
- const size_t rs = ne00*nb00;
5061
+ if (src0->type == dst->type &&
5062
+ src0->ne[0] == dst->ne[0] &&
5063
+ src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
5064
+ // copy by rows
5065
+ const size_t rs = ne00*nb00;
5066
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5067
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5068
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
5069
+ memcpy(
5070
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5071
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
5072
+ rs);
5073
+ }
5074
+ }
5075
+ }
5076
+ return;
5077
+ }
4844
5078
 
4845
- for (int i03 = 0; i03 < ne03; i03++) {
4846
- for (int i02 = 0; i02 < ne02; i02++) {
4847
- for (int i01 = 0; i01 < ne01; i01++) {
4848
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
4849
- char * dst_ptr = (char *) dst->data + id*rs;
5079
+ // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
4850
5080
 
4851
- memcpy(dst_ptr, src0_ptr, rs);
5081
+ if (ggml_is_contiguous(dst)) {
5082
+ if (src0->nb[0] == sizeof(ggml_fp16_t)) {
5083
+ if (dst->type == GGML_TYPE_F16) {
5084
+ size_t id = 0;
5085
+ const size_t rs = ne00*nb00;
4852
5086
 
4853
- id++;
4854
- }
4855
- }
4856
- }
4857
- } else if (dst->type == GGML_TYPE_F32) {
4858
- size_t id = 0;
4859
- float * dst_ptr = (float *) dst->data;
5087
+ for (int i03 = 0; i03 < ne03; i03++) {
5088
+ for (int i02 = 0; i02 < ne02; i02++) {
5089
+ for (int i01 = 0; i01 < ne01; i01++) {
5090
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5091
+ char * dst_ptr = (char *) dst->data + id*rs;
4860
5092
 
4861
- for (int i03 = 0; i03 < ne03; i03++) {
4862
- for (int i02 = 0; i02 < ne02; i02++) {
4863
- for (int i01 = 0; i01 < ne01; i01++) {
4864
- for (int i00 = 0; i00 < ne00; i00++) {
4865
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5093
+ memcpy(dst_ptr, src0_ptr, rs);
4866
5094
 
4867
- dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
4868
5095
  id++;
4869
5096
  }
4870
5097
  }
4871
5098
  }
5099
+ } else if (dst->type == GGML_TYPE_F32) {
5100
+ size_t id = 0;
5101
+ float * dst_ptr = (float *) dst->data;
5102
+
5103
+ for (int i03 = 0; i03 < ne03; i03++) {
5104
+ for (int i02 = 0; i02 < ne02; i02++) {
5105
+ for (int i01 = 0; i01 < ne01; i01++) {
5106
+ for (int i00 = 0; i00 < ne00; i00++) {
5107
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5108
+
5109
+ dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
5110
+ id++;
5111
+ }
5112
+ }
5113
+ }
5114
+ }
5115
+ } else {
5116
+ GGML_ASSERT(false); // TODO: implement
4872
5117
  }
4873
5118
  } else {
4874
- GGML_ASSERT(false); // TODO: implement
4875
- }
4876
- } else {
4877
- //printf("%s: this is not optimal - fix me\n", __func__);
5119
+ //printf("%s: this is not optimal - fix me\n", __func__);
4878
5120
 
4879
- if (dst->type == GGML_TYPE_F32) {
4880
- size_t id = 0;
4881
- float * dst_ptr = (float *) dst->data;
5121
+ if (dst->type == GGML_TYPE_F32) {
5122
+ size_t id = 0;
5123
+ float * dst_ptr = (float *) dst->data;
4882
5124
 
4883
- for (int i03 = 0; i03 < ne03; i03++) {
4884
- for (int i02 = 0; i02 < ne02; i02++) {
4885
- for (int i01 = 0; i01 < ne01; i01++) {
4886
- for (int i00 = 0; i00 < ne00; i00++) {
4887
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5125
+ for (int i03 = 0; i03 < ne03; i03++) {
5126
+ for (int i02 = 0; i02 < ne02; i02++) {
5127
+ for (int i01 = 0; i01 < ne01; i01++) {
5128
+ for (int i00 = 0; i00 < ne00; i00++) {
5129
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4888
5130
 
4889
- dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
4890
- id++;
5131
+ dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
5132
+ id++;
5133
+ }
5134
+ }
5135
+ }
5136
+ }
5137
+ } else if (dst->type == GGML_TYPE_F16) {
5138
+ size_t id = 0;
5139
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
5140
+
5141
+ for (int i03 = 0; i03 < ne03; i03++) {
5142
+ for (int i02 = 0; i02 < ne02; i02++) {
5143
+ for (int i01 = 0; i01 < ne01; i01++) {
5144
+ for (int i00 = 0; i00 < ne00; i00++) {
5145
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5146
+
5147
+ dst_ptr[id] = *src0_ptr;
5148
+ id++;
5149
+ }
4891
5150
  }
4892
5151
  }
4893
5152
  }
5153
+ } else {
5154
+ GGML_ASSERT(false); // TODO: implement
4894
5155
  }
4895
- } else if (dst->type == GGML_TYPE_F16) {
4896
- size_t id = 0;
4897
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
4898
-
4899
- for (int i03 = 0; i03 < ne03; i03++) {
4900
- for (int i02 = 0; i02 < ne02; i02++) {
4901
- for (int i01 = 0; i01 < ne01; i01++) {
4902
- for (int i00 = 0; i00 < ne00; i00++) {
4903
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5156
+ }
5157
+ return;
5158
+ }
4904
5159
 
4905
- dst_ptr[id] = *src0_ptr;
4906
- id++;
5160
+ // dst counters
5161
+ int64_t i10 = 0;
5162
+ int64_t i11 = 0;
5163
+ int64_t i12 = 0;
5164
+ int64_t i13 = 0;
5165
+
5166
+ if (dst->type == GGML_TYPE_F16) {
5167
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5168
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5169
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
5170
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
5171
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5172
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5173
+
5174
+ memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
5175
+
5176
+ if (++i10 == ne00) {
5177
+ i10 = 0;
5178
+ if (++i11 == ne01) {
5179
+ i11 = 0;
5180
+ if (++i12 == ne02) {
5181
+ i12 = 0;
5182
+ if (++i13 == ne03) {
5183
+ i13 = 0;
5184
+ }
5185
+ }
5186
+ }
5187
+ }
5188
+ }
5189
+ }
5190
+ }
5191
+ }
5192
+ } else if (dst->type == GGML_TYPE_F32) {
5193
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5194
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5195
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
5196
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
5197
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5198
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5199
+
5200
+ *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
5201
+
5202
+ if (++i10 == ne00) {
5203
+ i10 = 0;
5204
+ if (++i11 == ne01) {
5205
+ i11 = 0;
5206
+ if (++i12 == ne02) {
5207
+ i12 = 0;
5208
+ if (++i13 == ne03) {
5209
+ i13 = 0;
5210
+ }
5211
+ }
5212
+ }
4907
5213
  }
4908
5214
  }
4909
5215
  }
4910
5216
  }
4911
- } else {
4912
- GGML_ASSERT(false); // TODO: implement
4913
5217
  }
5218
+ } else {
5219
+ GGML_ASSERT(false); // TODO: implement
4914
5220
  }
4915
5221
  }
4916
5222
 
@@ -4919,102 +5225,191 @@ static void ggml_compute_forward_dup_f32(
4919
5225
  const struct ggml_tensor * src0,
4920
5226
  struct ggml_tensor * dst) {
4921
5227
  GGML_ASSERT(params->ith == 0);
4922
- GGML_ASSERT(ggml_is_contiguous(dst));
4923
5228
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
4924
5229
 
4925
5230
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
4926
5231
  return;
4927
5232
  }
4928
5233
 
4929
- const int ne00 = src0->ne[0];
4930
- const int ne01 = src0->ne[1];
4931
- const int ne02 = src0->ne[2];
4932
- const int ne03 = src0->ne[3];
5234
+ const int64_t ne00 = src0->ne[0];
5235
+ const int64_t ne01 = src0->ne[1];
5236
+ const int64_t ne02 = src0->ne[2];
5237
+ const int64_t ne03 = src0->ne[3];
4933
5238
 
4934
5239
  const size_t nb00 = src0->nb[0];
4935
5240
  const size_t nb01 = src0->nb[1];
4936
5241
  const size_t nb02 = src0->nb[2];
4937
5242
  const size_t nb03 = src0->nb[3];
4938
5243
 
4939
- if (ggml_is_contiguous(src0) && src0->type == dst->type) {
5244
+ const size_t nb0 = dst->nb[0];
5245
+ const size_t nb1 = dst->nb[1];
5246
+ const size_t nb2 = dst->nb[2];
5247
+ const size_t nb3 = dst->nb[3];
5248
+
5249
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
4940
5250
  memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
4941
5251
  return;
4942
5252
  }
4943
5253
 
4944
- if (src0->nb[0] == sizeof(float)) {
4945
- if (dst->type == GGML_TYPE_F32) {
4946
- size_t id = 0;
4947
- const size_t rs = ne00*nb00;
4948
-
4949
- for (int i03 = 0; i03 < ne03; i03++) {
4950
- for (int i02 = 0; i02 < ne02; i02++) {
4951
- for (int i01 = 0; i01 < ne01; i01++) {
4952
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
4953
- char * dst_ptr = (char *) dst->data + id*rs;
4954
-
4955
- memcpy(dst_ptr, src0_ptr, rs);
4956
-
4957
- id++;
4958
- }
5254
+ if (src0->type == dst->type &&
5255
+ src0->ne[0] == dst->ne[0] &&
5256
+ src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
5257
+ // copy by rows
5258
+ const size_t rs = ne00*nb00;
5259
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5260
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5261
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
5262
+ memcpy(
5263
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5264
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
5265
+ rs);
4959
5266
  }
4960
5267
  }
4961
- } else if (dst->type == GGML_TYPE_F16) {
4962
- size_t id = 0;
4963
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
5268
+ }
5269
+ return;
5270
+ }
5271
+
5272
+ if (ggml_is_contiguous(dst)) {
5273
+ // TODO: simplify
5274
+ if (src0->nb[0] == sizeof(float)) {
5275
+ if (dst->type == GGML_TYPE_F32) {
5276
+ size_t id = 0;
5277
+ const size_t rs = ne00*nb00;
5278
+
5279
+ for (int i03 = 0; i03 < ne03; i03++) {
5280
+ for (int i02 = 0; i02 < ne02; i02++) {
5281
+ for (int i01 = 0; i01 < ne01; i01++) {
5282
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5283
+ char * dst_ptr = (char *) dst->data + id*rs;
4964
5284
 
4965
- for (int i03 = 0; i03 < ne03; i03++) {
4966
- for (int i02 = 0; i02 < ne02; i02++) {
4967
- for (int i01 = 0; i01 < ne01; i01++) {
4968
- for (int i00 = 0; i00 < ne00; i00++) {
4969
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5285
+ memcpy(dst_ptr, src0_ptr, rs);
4970
5286
 
4971
- dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
4972
5287
  id++;
4973
5288
  }
4974
5289
  }
4975
5290
  }
5291
+ } else if (dst->type == GGML_TYPE_F16) {
5292
+ size_t id = 0;
5293
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
5294
+
5295
+ for (int i03 = 0; i03 < ne03; i03++) {
5296
+ for (int i02 = 0; i02 < ne02; i02++) {
5297
+ for (int i01 = 0; i01 < ne01; i01++) {
5298
+ for (int i00 = 0; i00 < ne00; i00++) {
5299
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5300
+
5301
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
5302
+ id++;
5303
+ }
5304
+ }
5305
+ }
5306
+ }
5307
+ } else {
5308
+ GGML_ASSERT(false); // TODO: implement
4976
5309
  }
4977
5310
  } else {
4978
- GGML_ASSERT(false); // TODO: implement
4979
- }
4980
- } else {
4981
- //printf("%s: this is not optimal - fix me\n", __func__);
5311
+ //printf("%s: this is not optimal - fix me\n", __func__);
4982
5312
 
4983
- if (dst->type == GGML_TYPE_F32) {
4984
- size_t id = 0;
4985
- float * dst_ptr = (float *) dst->data;
5313
+ if (dst->type == GGML_TYPE_F32) {
5314
+ size_t id = 0;
5315
+ float * dst_ptr = (float *) dst->data;
4986
5316
 
4987
- for (int i03 = 0; i03 < ne03; i03++) {
4988
- for (int i02 = 0; i02 < ne02; i02++) {
4989
- for (int i01 = 0; i01 < ne01; i01++) {
4990
- for (int i00 = 0; i00 < ne00; i00++) {
4991
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5317
+ for (int i03 = 0; i03 < ne03; i03++) {
5318
+ for (int i02 = 0; i02 < ne02; i02++) {
5319
+ for (int i01 = 0; i01 < ne01; i01++) {
5320
+ for (int i00 = 0; i00 < ne00; i00++) {
5321
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4992
5322
 
4993
- dst_ptr[id] = *src0_ptr;
4994
- id++;
5323
+ dst_ptr[id] = *src0_ptr;
5324
+ id++;
5325
+ }
4995
5326
  }
4996
5327
  }
4997
5328
  }
5329
+ } else if (dst->type == GGML_TYPE_F16) {
5330
+ size_t id = 0;
5331
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
5332
+
5333
+ for (int i03 = 0; i03 < ne03; i03++) {
5334
+ for (int i02 = 0; i02 < ne02; i02++) {
5335
+ for (int i01 = 0; i01 < ne01; i01++) {
5336
+ for (int i00 = 0; i00 < ne00; i00++) {
5337
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5338
+
5339
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
5340
+ id++;
5341
+ }
5342
+ }
5343
+ }
5344
+ }
5345
+ } else {
5346
+ GGML_ASSERT(false); // TODO: implement
4998
5347
  }
4999
- } else if (dst->type == GGML_TYPE_F16) {
5000
- size_t id = 0;
5001
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
5348
+ }
5002
5349
 
5003
- for (int i03 = 0; i03 < ne03; i03++) {
5004
- for (int i02 = 0; i02 < ne02; i02++) {
5005
- for (int i01 = 0; i01 < ne01; i01++) {
5006
- for (int i00 = 0; i00 < ne00; i00++) {
5007
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5350
+ return;
5351
+ }
5008
5352
 
5009
- dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
5010
- id++;
5353
+ // dst counters
5354
+ int64_t i10 = 0;
5355
+ int64_t i11 = 0;
5356
+ int64_t i12 = 0;
5357
+ int64_t i13 = 0;
5358
+
5359
+ if (dst->type == GGML_TYPE_F32) {
5360
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5361
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5362
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
5363
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
5364
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5365
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5366
+
5367
+ memcpy(dst_ptr, src0_ptr, sizeof(float));
5368
+
5369
+ if (++i10 == dst->ne[0]) {
5370
+ i10 = 0;
5371
+ if (++i11 == dst->ne[1]) {
5372
+ i11 = 0;
5373
+ if (++i12 == dst->ne[2]) {
5374
+ i12 = 0;
5375
+ if (++i13 == dst->ne[3]) {
5376
+ i13 = 0;
5377
+ }
5378
+ }
5379
+ }
5380
+ }
5381
+ }
5382
+ }
5383
+ }
5384
+ }
5385
+ } else if (dst->type == GGML_TYPE_F16) {
5386
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5387
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5388
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
5389
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
5390
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5391
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5392
+
5393
+ *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
5394
+
5395
+ if (++i10 == dst->ne[0]) {
5396
+ i10 = 0;
5397
+ if (++i11 == dst->ne[1]) {
5398
+ i11 = 0;
5399
+ if (++i12 == dst->ne[2]) {
5400
+ i12 = 0;
5401
+ if (++i13 == dst->ne[3]) {
5402
+ i13 = 0;
5403
+ }
5404
+ }
5405
+ }
5011
5406
  }
5012
5407
  }
5013
5408
  }
5014
5409
  }
5015
- } else {
5016
- GGML_ASSERT(false); // TODO: implement
5017
5410
  }
5411
+ } else {
5412
+ GGML_ASSERT(false); // TODO: implement
5018
5413
  }
5019
5414
  }
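The non-contiguous fallback added above walks dst with a set of per-dimension counters (i10..i13) that roll over odometer-style in lockstep with the src loops, instead of deriving a flat index. A minimal standalone sketch of the same rollover pattern, using a made-up 4-dimensional shape that is not taken from the diff:

    #include <stdio.h>

    int main(void) {
        const long ne[4] = {3, 2, 2, 1}; // assumed dst shape, innermost dimension first
        long i0 = 0, i1 = 0, i2 = 0, i3 = 0;

        for (long n = 0; n < ne[0]*ne[1]*ne[2]*ne[3]; ++n) {
            printf("element %2ld -> dst coords (%ld, %ld, %ld, %ld)\n", n, i0, i1, i2, i3);
            // advance the innermost counter and carry into the next dimension on overflow,
            // the same pattern as the i10..i13 counters in the new dup path
            if (++i0 == ne[0]) { i0 = 0;
                if (++i1 == ne[1]) { i1 = 0;
                    if (++i2 == ne[2]) { i2 = 0;
                        if (++i3 == ne[3]) { i3 = 0; }
                    }
                }
            }
        }
        return 0;
    }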
5020
5415
 
@@ -5075,14 +5470,18 @@ static void ggml_compute_forward_add_f32(
5075
5470
  GGML_ASSERT(nb00 == sizeof(float));
5076
5471
 
5077
5472
  if (nb10 == sizeof(float)) {
5078
- const int j0 = (n/nth)*ith;
5079
- const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1);
5080
-
5081
- for (int j = j0; j < j1; j++) {
5473
+ for (int j = ith; j < n; j += nth) {
5474
+ #ifdef GGML_USE_ACCELERATE
5475
+ vDSP_vadd(
5476
+ (float *) ((char *) src0->data + j*nb01), 1,
5477
+ (float *) ((char *) src1->data + j*nb11), 1,
5478
+ (float *) ((char *) dst->data + j*nb1), 1, nc);
5479
+ #else
5082
5480
  ggml_vec_add_f32(nc,
5083
5481
  (float *) ((char *) dst->data + j*nb1),
5084
5482
  (float *) ((char *) src0->data + j*nb01),
5085
5483
  (float *) ((char *) src1->data + j*nb11));
5484
+ #endif
5086
5485
  }
5087
5486
  } else {
5088
5487
  // src1 is not contiguous
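The add kernel in the hunk above moves from giving each thread one contiguous block of rows to a strided round-robin split, and (with GGML_USE_ACCELERATE) routes each row through vDSP_vadd. A small sketch comparing the two partitioning schemes, with made-up row and thread counts:

    #include <stdio.h>

    int main(void) {
        const int n = 10, nth = 3; // assumed: 10 rows shared by 3 threads

        for (int ith = 0; ith < nth; ++ith) {
            // old scheme: contiguous block per thread, last thread takes the remainder
            const int j0 = (n/nth)*ith;
            const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1);
            printf("thread %d, old split: rows [%d, %d)\n", ith, j0, j1);

            // new scheme: thread ith handles rows ith, ith + nth, ith + 2*nth, ...
            printf("thread %d, new split:", ith);
            for (int j = ith; j < n; j += nth) {
                printf(" %d", j);
            }
            printf("\n");
        }
        return 0;
    }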
@@ -5389,18 +5788,18 @@ static void ggml_compute_forward_sum_f32(
5389
5788
  assert(ggml_is_scalar(dst));
5390
5789
  assert(src0->nb[0] == sizeof(float));
5391
5790
 
5392
- const int ne00 = src0->ne[0];
5393
- const int ne01 = src0->ne[1];
5394
- const int ne02 = src0->ne[2];
5395
- const int ne03 = src0->ne[3];
5791
+ const int64_t ne00 = src0->ne[0];
5792
+ const int64_t ne01 = src0->ne[1];
5793
+ const int64_t ne02 = src0->ne[2];
5794
+ const int64_t ne03 = src0->ne[3];
5396
5795
 
5397
5796
  const size_t nb01 = src0->nb[1];
5398
5797
  const size_t nb02 = src0->nb[2];
5399
5798
  const size_t nb03 = src0->nb[3];
5400
5799
 
5401
- for (int i03 = 0; i03 < ne03; i03++) {
5402
- for (int i02 = 0; i02 < ne02; i02++) {
5403
- for (int i01 = 0; i01 < ne01; i01++) {
5800
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5801
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5802
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
5404
5803
  ggml_vec_sum_f32(ne00,
5405
5804
  (float *) (dst->data),
5406
5805
  (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
@@ -5445,19 +5844,19 @@ static void ggml_compute_forward_mean_f32(
5445
5844
 
5446
5845
  assert(src0->nb[0] == sizeof(float));
5447
5846
 
5448
- const int ne00 = src0->ne[0];
5449
- const int ne01 = src0->ne[1];
5450
- const int ne02 = src0->ne[2];
5451
- const int ne03 = src0->ne[3];
5847
+ const int64_t ne00 = src0->ne[0];
5848
+ const int64_t ne01 = src0->ne[1];
5849
+ const int64_t ne02 = src0->ne[2];
5850
+ const int64_t ne03 = src0->ne[3];
5452
5851
 
5453
5852
  const size_t nb01 = src0->nb[1];
5454
5853
  const size_t nb02 = src0->nb[2];
5455
5854
  const size_t nb03 = src0->nb[3];
5456
5855
 
5457
- const int ne0 = dst->ne[0];
5458
- const int ne1 = dst->ne[1];
5459
- const int ne2 = dst->ne[2];
5460
- const int ne3 = dst->ne[3];
5856
+ const int64_t ne0 = dst->ne[0];
5857
+ const int64_t ne1 = dst->ne[1];
5858
+ const int64_t ne2 = dst->ne[2];
5859
+ const int64_t ne3 = dst->ne[3];
5461
5860
 
5462
5861
  assert(ne0 == 1);
5463
5862
  assert(ne1 == ne01);
@@ -5473,9 +5872,9 @@ static void ggml_compute_forward_mean_f32(
5473
5872
  const size_t nb2 = dst->nb[2];
5474
5873
  const size_t nb3 = dst->nb[3];
5475
5874
 
5476
- for (int i03 = 0; i03 < ne03; i03++) {
5477
- for (int i02 = 0; i02 < ne02; i02++) {
5478
- for (int i01 = 0; i01 < ne01; i01++) {
5875
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5876
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5877
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
5479
5878
  ggml_vec_sum_f32(ne00,
5480
5879
  (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5481
5880
  (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
@@ -5962,10 +6361,10 @@ static void ggml_compute_forward_norm_f32(
5962
6361
  const int ith = params->ith;
5963
6362
  const int nth = params->nth;
5964
6363
 
5965
- const int ne00 = src0->ne[0];
5966
- const int ne01 = src0->ne[1];
5967
- const int ne02 = src0->ne[2];
5968
- const int ne03 = src0->ne[3];
6364
+ const int64_t ne00 = src0->ne[0];
6365
+ const int64_t ne01 = src0->ne[1];
6366
+ const int64_t ne02 = src0->ne[2];
6367
+ const int64_t ne03 = src0->ne[3];
5969
6368
 
5970
6369
  const size_t nb01 = src0->nb[1];
5971
6370
  const size_t nb02 = src0->nb[2];
@@ -5978,13 +6377,13 @@ static void ggml_compute_forward_norm_f32(
5978
6377
  const float eps = 1e-5f; // TODO: make this a parameter
5979
6378
 
5980
6379
  // TODO: optimize
5981
- for (int i03 = 0; i03 < ne03; i03++) {
5982
- for (int i02 = 0; i02 < ne02; i02++) {
5983
- for (int i01 = ith; i01 < ne01; i01 += nth) {
6380
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
6381
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
6382
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
5984
6383
  const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5985
6384
 
5986
6385
  ggml_float sum = 0.0;
5987
- for (int i00 = 0; i00 < ne00; i00++) {
6386
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
5988
6387
  sum += (ggml_float)x[i00];
5989
6388
  }
5990
6389
 
@@ -5993,7 +6392,7 @@ static void ggml_compute_forward_norm_f32(
5993
6392
  float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
5994
6393
 
5995
6394
  ggml_float sum2 = 0.0;
5996
- for (int i00 = 0; i00 < ne00; i00++) {
6395
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
5997
6396
  float v = x[i00] - mean;
5998
6397
  y[i00] = v;
5999
6398
  sum2 += (ggml_float)(v*v);
@@ -6045,10 +6444,10 @@ static void ggml_compute_forward_rms_norm_f32(
6045
6444
  const int ith = params->ith;
6046
6445
  const int nth = params->nth;
6047
6446
 
6048
- const int ne00 = src0->ne[0];
6049
- const int ne01 = src0->ne[1];
6050
- const int ne02 = src0->ne[2];
6051
- const int ne03 = src0->ne[3];
6447
+ const int64_t ne00 = src0->ne[0];
6448
+ const int64_t ne01 = src0->ne[1];
6449
+ const int64_t ne02 = src0->ne[2];
6450
+ const int64_t ne03 = src0->ne[3];
6052
6451
 
6053
6452
  const size_t nb01 = src0->nb[1];
6054
6453
  const size_t nb02 = src0->nb[2];
@@ -6061,13 +6460,13 @@ static void ggml_compute_forward_rms_norm_f32(
6061
6460
  const float eps = 1e-6f; // TODO: make this a parameter
6062
6461
 
6063
6462
  // TODO: optimize
6064
- for (int i03 = 0; i03 < ne03; i03++) {
6065
- for (int i02 = 0; i02 < ne02; i02++) {
6066
- for (int i01 = ith; i01 < ne01; i01 += nth) {
6463
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
6464
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
6465
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
6067
6466
  const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
6068
6467
 
6069
6468
  ggml_float sum = 0.0;
6070
- for (int i00 = 0; i00 < ne00; i00++) {
6469
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
6071
6470
  sum += (ggml_float)(x[i00] * x[i00]);
6072
6471
  }
6073
6472
 
@@ -6120,13 +6519,13 @@ static bool ggml_compute_forward_mul_mat_use_blas(
6120
6519
  const struct ggml_tensor * src0,
6121
6520
  const struct ggml_tensor * src1,
6122
6521
  struct ggml_tensor * dst) {
6123
- //const int ne00 = src0->ne[0];
6124
- //const int ne01 = src0->ne[1];
6522
+ //const int64_t ne00 = src0->ne[0];
6523
+ //const int64_t ne01 = src0->ne[1];
6125
6524
 
6126
- const int ne10 = src1->ne[0];
6525
+ const int64_t ne10 = src1->ne[0];
6127
6526
 
6128
- const int ne0 = dst->ne[0];
6129
- const int ne1 = dst->ne[1];
6527
+ const int64_t ne0 = dst->ne[0];
6528
+ const int64_t ne1 = dst->ne[1];
6130
6529
 
6131
6530
  // TODO: find the optimal values for these
6132
6531
  if (ggml_is_contiguous(src0) &&
@@ -6148,23 +6547,23 @@ static void ggml_compute_forward_mul_mat_f32(
6148
6547
  int64_t t0 = ggml_perf_time_us();
6149
6548
  UNUSED(t0);
6150
6549
 
6151
- const int ne00 = src0->ne[0];
6152
- const int ne01 = src0->ne[1];
6153
- const int ne02 = src0->ne[2];
6154
- const int ne03 = src0->ne[3];
6550
+ const int64_t ne00 = src0->ne[0];
6551
+ const int64_t ne01 = src0->ne[1];
6552
+ const int64_t ne02 = src0->ne[2];
6553
+ const int64_t ne03 = src0->ne[3];
6155
6554
 
6156
6555
  #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
6157
- const int ne10 = src1->ne[0];
6556
+ const int64_t ne10 = src1->ne[0];
6158
6557
  #endif
6159
- const int ne11 = src1->ne[1];
6558
+ const int64_t ne11 = src1->ne[1];
6160
6559
  #ifndef NDEBUG
6161
- const int ne12 = src1->ne[2];
6162
- const int ne13 = src1->ne[3];
6560
+ const int64_t ne12 = src1->ne[2];
6561
+ const int64_t ne13 = src1->ne[3];
6163
6562
 
6164
- const int ne0 = dst->ne[0];
6165
- const int ne1 = dst->ne[1];
6166
- const int ne2 = dst->ne[2];
6167
- const int ne3 = dst->ne[3];
6563
+ const int64_t ne0 = dst->ne[0];
6564
+ const int64_t ne1 = dst->ne[1];
6565
+ const int64_t ne2 = dst->ne[2];
6566
+ const int64_t ne3 = dst->ne[3];
6168
6567
 
6169
6568
  const int nb00 = src0->nb[0];
6170
6569
  #endif
@@ -6224,8 +6623,8 @@ static void ggml_compute_forward_mul_mat_f32(
6224
6623
  return;
6225
6624
  }
6226
6625
 
6227
- for (int i03 = 0; i03 < ne03; i03++) {
6228
- for (int i02 = 0; i02 < ne02; i02++) {
6626
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
6627
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
6229
6628
  const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
6230
6629
  const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
6231
6630
 
@@ -6235,7 +6634,7 @@ static void ggml_compute_forward_mul_mat_f32(
6235
6634
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6236
6635
  ne11, ne01, ne10,
6237
6636
  1.0f, y, ne10,
6238
- x, ne10,
6637
+ x, ne00,
6239
6638
  0.0f, d, ne01);
6240
6639
  }
6241
6640
  }
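The BLAS branches in this diff change the leading dimension passed for x from ne10 to ne00. The inner dimensions of a valid ggml_mul_mat have to match, so the value itself should not change, but ne00 is the stride that actually belongs to x: for a row-major CBLAS operand the leading dimension is the row length of the stored matrix, regardless of the transpose flag. A small reference loop for what the call computes, with made-up shapes and data:

    #include <stdio.h>

    // reference for the sgemm call above: d = y * x^T, all matrices row-major
    // y: ne11 x ne10, x: ne01 x ne00 (ne10 == ne00), d: ne11 x ne01
    static void sgemm_ref(int ne11, int ne01, int ne10, int ne00,
                          const float * y, const float * x, float * d) {
        for (int i = 0; i < ne11; ++i) {
            for (int j = 0; j < ne01; ++j) {
                float sum = 0.0f;
                for (int k = 0; k < ne10; ++k) {
                    sum += y[i*ne10 + k]*x[j*ne00 + k]; // row stride of x is its own width ne00
                }
                d[i*ne01 + j] = sum;
            }
        }
    }

    int main(void) {
        const float y[2*3] = {1, 2, 3, 4, 5, 6}; // ne11 = 2, ne10 = 3
        const float x[2*3] = {1, 0, 1, 0, 1, 0}; // ne01 = 2, ne00 = 3
        float d[2*2];
        sgemm_ref(2, 2, 3, 3, y, x, d);
        printf("%g %g\n%g %g\n", d[0], d[1], d[2], d[3]);
        return 0;
    }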
@@ -6272,7 +6671,7 @@ static void ggml_compute_forward_mul_mat_f32(
6272
6671
  const int i02 = (ir - i03*ne02*ne01)/ne01;
6273
6672
  const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
6274
6673
 
6275
- for (int ic = 0; ic < ne11; ++ic) {
6674
+ for (int64_t ic = 0; ic < ne11; ++ic) {
6276
6675
  // src1 indices
6277
6676
  const int i13 = i03;
6278
6677
  const int i12 = i02;
@@ -6313,21 +6712,21 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6313
6712
  int64_t t0 = ggml_perf_time_us();
6314
6713
  UNUSED(t0);
6315
6714
 
6316
- const int ne00 = src0->ne[0];
6317
- const int ne01 = src0->ne[1];
6318
- const int ne02 = src0->ne[2];
6319
- const int ne03 = src0->ne[3];
6715
+ const int64_t ne00 = src0->ne[0];
6716
+ const int64_t ne01 = src0->ne[1];
6717
+ const int64_t ne02 = src0->ne[2];
6718
+ const int64_t ne03 = src0->ne[3];
6320
6719
 
6321
- const int ne10 = src1->ne[0];
6322
- const int ne11 = src1->ne[1];
6323
- const int ne12 = src1->ne[2];
6324
- const int ne13 = src1->ne[3];
6720
+ const int64_t ne10 = src1->ne[0];
6721
+ const int64_t ne11 = src1->ne[1];
6722
+ const int64_t ne12 = src1->ne[2];
6723
+ const int64_t ne13 = src1->ne[3];
6325
6724
 
6326
- const int ne0 = dst->ne[0];
6327
- const int ne1 = dst->ne[1];
6328
- const int ne2 = dst->ne[2];
6329
- const int ne3 = dst->ne[3];
6330
- //const int ne = ne0*ne1*ne2*ne3;
6725
+ const int64_t ne0 = dst->ne[0];
6726
+ const int64_t ne1 = dst->ne[1];
6727
+ const int64_t ne2 = dst->ne[2];
6728
+ const int64_t ne3 = dst->ne[3];
6729
+ //const int64_t ne = ne0*ne1*ne2*ne3;
6331
6730
 
6332
6731
  const int nb00 = src0->nb[0];
6333
6732
  const int nb01 = src0->nb[1];
@@ -6387,12 +6786,12 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6387
6786
 
6388
6787
  float * const wdata = params->wdata;
6389
6788
 
6390
- for (int i03 = 0; i03 < ne03; i03++) {
6391
- for (int i02 = 0; i02 < ne02; i02++) {
6789
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
6790
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
6392
6791
  {
6393
6792
  size_t id = 0;
6394
- for (int i01 = 0; i01 < ne01; ++i01) {
6395
- for (int i00 = 0; i00 < ne00; ++i00) {
6793
+ for (int64_t i01 = 0; i01 < ne01; ++i01) {
6794
+ for (int64_t i00 = 0; i00 < ne00; ++i00) {
6396
6795
  wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
6397
6796
  }
6398
6797
  }
@@ -6407,7 +6806,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6407
6806
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6408
6807
  ne11, ne01, ne10,
6409
6808
  1.0f, y, ne10,
6410
- x, ne10,
6809
+ x, ne00,
6411
6810
  0.0f, d, ne01);
6412
6811
  }
6413
6812
  }
@@ -6422,10 +6821,10 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6422
6821
  ggml_fp16_t * const wdata = params->wdata;
6423
6822
 
6424
6823
  size_t id = 0;
6425
- for (int i13 = 0; i13 < ne13; ++i13) {
6426
- for (int i12 = 0; i12 < ne12; ++i12) {
6427
- for (int i11 = 0; i11 < ne11; ++i11) {
6428
- for (int i10 = 0; i10 < ne10; ++i10) {
6824
+ for (int64_t i13 = 0; i13 < ne13; ++i13) {
6825
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
6826
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
6827
+ for (int64_t i10 = 0; i10 < ne10; ++i10) {
6429
6828
  wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
6430
6829
  }
6431
6830
  }
@@ -6477,7 +6876,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6477
6876
 
6478
6877
  float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
6479
6878
 
6480
- for (int ic = 0; ic < ne11; ++ic) {
6879
+ for (int64_t ic = 0; ic < ne11; ++ic) {
6481
6880
  ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
6482
6881
  }
6483
6882
  }
@@ -6495,29 +6894,27 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6495
6894
  //}
6496
6895
  }
6497
6896
 
6498
- typedef void (*dequantize_row_q_t)(const void * restrict x, float * restrict y, int k);
6499
- typedef void (*quantize_row_q_t)(const float * restrict x, void * restrict y, int k);
6500
- typedef void (*vec_dot_q_t)(const int n, float * restrict s, const void * restrict x, const void * restrict y);
6501
-
6502
- typedef struct {
6503
- dequantize_row_q_t dequantize_row_q;
6504
- quantize_row_q_t quantize_row_q;
6505
- vec_dot_q_t vec_dot_q;
6506
- } quantize_fns_t;
6507
-
6508
6897
  static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
6509
6898
  [GGML_TYPE_Q4_0] = {
6510
- .dequantize_row_q = dequantize_row_q4_0,
6511
- .quantize_row_q = quantize_row_q4_0,
6512
- .vec_dot_q = ggml_vec_dot_q4_0,
6899
+ .dequantize_row_q = dequantize_row_q4_0,
6900
+ .quantize_row_q = quantize_row_q4_0,
6901
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
6902
+ .vec_dot_q = ggml_vec_dot_q4_0,
6513
6903
  },
6514
6904
  [GGML_TYPE_Q4_1] = {
6515
- .dequantize_row_q = dequantize_row_q4_1,
6516
- .quantize_row_q = quantize_row_q4_1,
6517
- .vec_dot_q = ggml_vec_dot_q4_1,
6905
+ .dequantize_row_q = dequantize_row_q4_1,
6906
+ .quantize_row_q = quantize_row_q4_1,
6907
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
6908
+ .vec_dot_q = ggml_vec_dot_q4_1,
6518
6909
  },
6519
6910
  };
6520
6911
 
6912
+ // For internal test use
6913
+ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
6914
+ GGML_ASSERT(i < GGML_TYPE_COUNT);
6915
+ return quantize_fns[i];
6916
+ }
6917
+
6521
6918
  static void ggml_compute_forward_mul_mat_q_f32(
6522
6919
  const struct ggml_compute_params * params,
6523
6920
  const struct ggml_tensor * src0,
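The ggml_internal_get_quantize_fn accessor added above is marked as being for internal test use; a hedged sketch of a round-trip through the Q4_0 entry, assuming the quantize_fns_t struct and the accessor are declared in ggml.h for the test programs (the scratch buffer below is sized generously, since the quantized block layout is not part of this diff):

    #include <stdio.h>
    #include <stdint.h>
    #include "ggml.h"

    int main(void) {
        float   src[32], dst[32];
        uint8_t quant[64]; // generous scratch for one 32-element Q4_0 row

        for (int i = 0; i < 32; ++i) {
            src[i] = 0.1f*(float)i;
        }

        quantize_fns_t fns = ggml_internal_get_quantize_fn(GGML_TYPE_Q4_0);

        fns.quantize_row_q_reference(src, quant, 32); // scalar reference quantizer
        fns.dequantize_row_q(quant, dst, 32);

        for (int i = 0; i < 32; ++i) {
            printf("%f -> %f\n", src[i], dst[i]); // values come back only approximately equal
        }
        return 0;
    }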
@@ -6526,20 +6923,20 @@ static void ggml_compute_forward_mul_mat_q_f32(
6526
6923
  int64_t t0 = ggml_perf_time_us();
6527
6924
  UNUSED(t0);
6528
6925
 
6529
- const int ne00 = src0->ne[0];
6530
- const int ne01 = src0->ne[1];
6531
- const int ne02 = src0->ne[2];
6532
- const int ne03 = src0->ne[3];
6926
+ const int64_t ne00 = src0->ne[0];
6927
+ const int64_t ne01 = src0->ne[1];
6928
+ const int64_t ne02 = src0->ne[2];
6929
+ const int64_t ne03 = src0->ne[3];
6533
6930
 
6534
- const int ne10 = src1->ne[0];
6535
- const int ne11 = src1->ne[1];
6536
- const int ne12 = src1->ne[2];
6537
- const int ne13 = src1->ne[3];
6931
+ const int64_t ne10 = src1->ne[0];
6932
+ const int64_t ne11 = src1->ne[1];
6933
+ const int64_t ne12 = src1->ne[2];
6934
+ const int64_t ne13 = src1->ne[3];
6538
6935
 
6539
- const int ne0 = dst->ne[0];
6540
- const int ne1 = dst->ne[1];
6541
- const int ne2 = dst->ne[2];
6542
- const int ne3 = dst->ne[3];
6936
+ const int64_t ne0 = dst->ne[0];
6937
+ const int64_t ne1 = dst->ne[1];
6938
+ const int64_t ne2 = dst->ne[2];
6939
+ const int64_t ne3 = dst->ne[3];
6543
6940
 
6544
6941
  const int nb00 = src0->nb[0];
6545
6942
  const int nb01 = src0->nb[1];
@@ -6603,11 +7000,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
6603
7000
  float * const wdata = params->wdata;
6604
7001
  dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
6605
7002
 
6606
- for (int i03 = 0; i03 < ne03; i03++) {
6607
- for (int i02 = 0; i02 < ne02; i02++) {
7003
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
7004
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
6608
7005
  {
6609
7006
  size_t id = 0;
6610
- for (int i01 = 0; i01 < ne01; ++i01) {
7007
+ for (int64_t i01 = 0; i01 < ne01; ++i01) {
6611
7008
  dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
6612
7009
  id += ne00;
6613
7010
  }
@@ -6622,7 +7019,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
6622
7019
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6623
7020
  ne11, ne01, ne10,
6624
7021
  1.0f, y, ne10,
6625
- x, ne10,
7022
+ x, ne00,
6626
7023
  0.0f, d, ne01);
6627
7024
  }
6628
7025
  }
@@ -6637,9 +7034,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
6637
7034
  char * wdata = params->wdata;
6638
7035
  const size_t row_size = ne10*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
6639
7036
 
6640
- for (int i13 = 0; i13 < ne13; ++i13) {
6641
- for (int i12 = 0; i12 < ne12; ++i12) {
6642
- for (int i11 = 0; i11 < ne11; ++i11) {
7037
+ for (int64_t i13 = 0; i13 < ne13; ++i13) {
7038
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
7039
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
6643
7040
  quantize_row_q((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
6644
7041
  wdata += row_size;
6645
7042
  }
@@ -6688,7 +7085,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
6688
7085
 
6689
7086
  assert(ne00 % 32 == 0);
6690
7087
 
6691
- for (int ic = 0; ic < ne11; ++ic) {
7088
+ for (int64_t ic = 0; ic < ne11; ++ic) {
6692
7089
  vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
6693
7090
  }
6694
7091
  }
@@ -6832,6 +7229,15 @@ static void ggml_compute_forward_cpy(
6832
7229
  ggml_compute_forward_dup(params, src0, dst);
6833
7230
  }
6834
7231
 
7232
+ // ggml_compute_forward_cont
7233
+
7234
+ static void ggml_compute_forward_cont(
7235
+ const struct ggml_compute_params * params,
7236
+ const struct ggml_tensor * src0,
7237
+ struct ggml_tensor * dst) {
7238
+ ggml_compute_forward_dup(params, src0, dst);
7239
+ }
7240
+
6835
7241
  // ggml_compute_forward_reshape
6836
7242
 
6837
7243
  static void ggml_compute_forward_reshape(
@@ -7169,7 +7575,6 @@ static void ggml_compute_forward_rope_f32(
7169
7575
  const struct ggml_tensor * src0,
7170
7576
  const struct ggml_tensor * src1,
7171
7577
  struct ggml_tensor * dst) {
7172
- assert(params->ith == 0);
7173
7578
  assert(src1->type == GGML_TYPE_I32);
7174
7579
  assert(ggml_nelements(src1) == 3);
7175
7580
 
@@ -7181,10 +7586,10 @@ static void ggml_compute_forward_rope_f32(
7181
7586
  const int n_dims = ((int32_t *) src1->data)[1];
7182
7587
  const int mode = ((int32_t *) src1->data)[2];
7183
7588
 
7184
- //const int ne0 = src0->ne[0];
7185
- const int ne1 = src0->ne[1];
7186
- const int ne2 = src0->ne[2];
7187
- const int ne3 = src0->ne[3];
7589
+ //const int64_t ne0 = src0->ne[0];
7590
+ const int64_t ne1 = src0->ne[1];
7591
+ const int64_t ne2 = src0->ne[2];
7592
+ const int64_t ne3 = src0->ne[3];
7188
7593
 
7189
7594
  const int nb0 = src0->nb[0];
7190
7595
  const int nb1 = src0->nb[1];
@@ -7196,16 +7601,37 @@ static void ggml_compute_forward_rope_f32(
7196
7601
 
7197
7602
  assert(nb0 == sizeof(float));
7198
7603
 
7199
- // TODO: optimize
7200
- for (int i3 = 0; i3 < ne3; i3++) {
7201
- for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
7604
+ const int ith = params->ith;
7605
+ const int nth = params->nth;
7606
+
7607
+ const int nr = ggml_nrows(src0);
7608
+
7609
+ // rows per thread
7610
+ const int dr = (nr + nth - 1)/nth;
7611
+
7612
+ // row range for this thread
7613
+ const int ir0 = dr*ith;
7614
+ const int ir1 = MIN(ir0 + dr, nr);
7615
+
7616
+ // row index used to determine which thread to use
7617
+ int ir = 0;
7618
+
7619
+ const float theta_scale = powf(10000.0, -2.0f/n_dims);
7620
+
7621
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
7622
+ for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
7202
7623
  const int p = (mode == 0 ? n_past + i2 : i2);
7203
- for (int i1 = 0; i1 < ne1; i1++) {
7624
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
7625
+ if (ir++ < ir0) continue;
7626
+ if (ir > ir1) break;
7627
+
7628
+ float theta = (float)p;
7629
+
7204
7630
  for (int i0 = 0; i0 < n_dims; i0 += 2) {
7205
- const float theta = powf(10000.0, ((float)-i0)/n_dims);
7631
+ const float cos_theta = cosf(theta);
7632
+ const float sin_theta = sinf(theta);
7206
7633
 
7207
- const float cos_theta = cosf(p*theta);
7208
- const float sin_theta = sinf(p*theta);
7634
+ theta *= theta_scale;
7209
7635
 
7210
7636
  const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
7211
7637
  float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
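Besides being split across threads, the RoPE kernel above now updates theta incrementally (theta *= theta_scale once per pair of dimensions) instead of calling powf for every pair; both forms evaluate to p * 10000^(-i0/n_dims). A small standalone check of that equivalence, with assumed values for p and n_dims:

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const int   n_dims = 64;   // assumed rotation dimension
        const float p      = 7.0f; // assumed position, i.e. (n_past + i2) in the kernel
        const float theta_scale = powf(10000.0f, -2.0f/n_dims);

        float theta = p; // incremental form used by the new code
        for (int i0 = 0; i0 < n_dims; i0 += 2) {
            // old form: one powf per pair of dimensions
            const float theta_old = p*powf(10000.0f, ((float)-i0)/n_dims);
            printf("i0 = %2d  incremental = %f  direct = %f\n", i0, theta, theta_old);
            theta *= theta_scale;
        }
        return 0;
    }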
@@ -7226,7 +7652,6 @@ static void ggml_compute_forward_rope_f16(
7226
7652
  const struct ggml_tensor * src0,
7227
7653
  const struct ggml_tensor * src1,
7228
7654
  struct ggml_tensor * dst) {
7229
- assert(params->ith == 0);
7230
7655
  assert(src1->type == GGML_TYPE_I32);
7231
7656
  assert(ggml_nelements(src1) == 3);
7232
7657
 
@@ -7238,10 +7663,10 @@ static void ggml_compute_forward_rope_f16(
7238
7663
  const int n_dims = ((int32_t *) src1->data)[1];
7239
7664
  const int mode = ((int32_t *) src1->data)[2];
7240
7665
 
7241
- //const int ne0 = src0->ne[0];
7242
- const int ne1 = src0->ne[1];
7243
- const int ne2 = src0->ne[2];
7244
- const int ne3 = src0->ne[3];
7666
+ //const int64_t ne0 = src0->ne[0];
7667
+ const int64_t ne1 = src0->ne[1];
7668
+ const int64_t ne2 = src0->ne[2];
7669
+ const int64_t ne3 = src0->ne[3];
7245
7670
 
7246
7671
  const int nb0 = src0->nb[0];
7247
7672
  const int nb1 = src0->nb[1];
@@ -7253,15 +7678,37 @@ static void ggml_compute_forward_rope_f16(
7253
7678
 
7254
7679
  assert(nb0 == sizeof(ggml_fp16_t));
7255
7680
 
7256
- for (int i3 = 0; i3 < ne3; i3++) {
7257
- for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
7681
+ const int ith = params->ith;
7682
+ const int nth = params->nth;
7683
+
7684
+ const int nr = ggml_nrows(src0);
7685
+
7686
+ // rows per thread
7687
+ const int dr = (nr + nth - 1)/nth;
7688
+
7689
+ // row range for this thread
7690
+ const int ir0 = dr*ith;
7691
+ const int ir1 = MIN(ir0 + dr, nr);
7692
+
7693
+ // row index used to determine which thread to use
7694
+ int ir = 0;
7695
+
7696
+ const float theta_scale = powf(10000.0, -2.0f/n_dims);
7697
+
7698
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
7699
+ for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
7258
7700
  const int p = (mode == 0 ? n_past + i2 : i2);
7259
- for (int i1 = 0; i1 < ne1; i1++) {
7701
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
7702
+ if (ir++ < ir0) continue;
7703
+ if (ir > ir1) break;
7704
+
7705
+ float theta = (float)p;
7706
+
7260
7707
  for (int i0 = 0; i0 < n_dims; i0 += 2) {
7261
- const float theta = powf(10000.0, ((float)-i0)/n_dims);
7708
+ const float cos_theta = cosf(theta);
7709
+ const float sin_theta = sinf(theta);
7262
7710
 
7263
- const float cos_theta = cosf(p*theta);
7264
- const float sin_theta = sinf(p*theta);
7711
+ theta *= theta_scale;
7265
7712
 
7266
7713
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
7267
7714
  ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -7317,21 +7764,21 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32(
7317
7764
  int64_t t0 = ggml_perf_time_us();
7318
7765
  UNUSED(t0);
7319
7766
 
7320
- const int ne00 = src0->ne[0];
7321
- const int ne01 = src0->ne[1];
7322
- const int ne02 = src0->ne[2];
7323
- //const int ne03 = src0->ne[3];
7767
+ const int64_t ne00 = src0->ne[0];
7768
+ const int64_t ne01 = src0->ne[1];
7769
+ const int64_t ne02 = src0->ne[2];
7770
+ //const int64_t ne03 = src0->ne[3];
7324
7771
 
7325
- const int ne10 = src1->ne[0];
7326
- const int ne11 = src1->ne[1];
7327
- //const int ne12 = src1->ne[2];
7328
- //const int ne13 = src1->ne[3];
7772
+ const int64_t ne10 = src1->ne[0];
7773
+ const int64_t ne11 = src1->ne[1];
7774
+ //const int64_t ne12 = src1->ne[2];
7775
+ //const int64_t ne13 = src1->ne[3];
7329
7776
 
7330
- //const int ne0 = dst->ne[0];
7331
- //const int ne1 = dst->ne[1];
7332
- //const int ne2 = dst->ne[2];
7333
- //const int ne3 = dst->ne[3];
7334
- //const int ne = ne0*ne1*ne2*ne3;
7777
+ //const int64_t ne0 = dst->ne[0];
7778
+ //const int64_t ne1 = dst->ne[1];
7779
+ //const int64_t ne2 = dst->ne[2];
7780
+ //const int64_t ne3 = dst->ne[3];
7781
+ //const int64_t ne = ne0*ne1*ne2*ne3;
7335
7782
 
7336
7783
  const int nb00 = src0->nb[0];
7337
7784
  const int nb01 = src0->nb[1];
@@ -7368,11 +7815,11 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32(
7368
7815
  {
7369
7816
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
7370
7817
 
7371
- for (int i02 = 0; i02 < ne02; i02++) {
7372
- for (int i01 = 0; i01 < ne01; i01++) {
7818
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
7819
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
7373
7820
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
7374
7821
  ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
7375
- for (int i00 = 0; i00 < ne00; i00++) {
7822
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
7376
7823
  dst_data[i00*ew0 + i01] = src[i00];
7377
7824
  }
7378
7825
  }
@@ -7383,10 +7830,10 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32(
7383
7830
  {
7384
7831
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
7385
7832
 
7386
- for (int i11 = 0; i11 < ne11; i11++) {
7833
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
7387
7834
  const float * const src = (float *)((char *) src1->data + i11*nb11);
7388
7835
  ggml_fp16_t * dst_data = wdata;
7389
- for (int i10 = 0; i10 < ne10; i10++) {
7836
+ for (int64_t i10 = 0; i10 < ne10; i10++) {
7390
7837
  dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
7391
7838
  }
7392
7839
  }
@@ -7411,7 +7858,7 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32(
7411
7858
 
7412
7859
  for (int i1 = ir0; i1 < ir1; i1++) {
7413
7860
  float * dst_data = (float *)((char *) dst->data + i1*nb1);
7414
- for (int i0 = 0; i0 < ne10; ++i0) {
7861
+ for (int64_t i0 = 0; i0 < ne10; ++i0) {
7415
7862
  dst_data[i0] = 0;
7416
7863
  for (int k = -nh; k <= nh; k++) {
7417
7864
  float v = 0.0f;
@@ -7437,21 +7884,21 @@ static void ggml_compute_forward_conv_1d_1s_f32(
7437
7884
  int64_t t0 = ggml_perf_time_us();
7438
7885
  UNUSED(t0);
7439
7886
 
7440
- const int ne00 = src0->ne[0];
7441
- const int ne01 = src0->ne[1];
7442
- const int ne02 = src0->ne[2];
7443
- //const int ne03 = src0->ne[3];
7887
+ const int64_t ne00 = src0->ne[0];
7888
+ const int64_t ne01 = src0->ne[1];
7889
+ const int64_t ne02 = src0->ne[2];
7890
+ //const int64_t ne03 = src0->ne[3];
7444
7891
 
7445
- const int ne10 = src1->ne[0];
7446
- const int ne11 = src1->ne[1];
7447
- //const int ne12 = src1->ne[2];
7448
- //const int ne13 = src1->ne[3];
7892
+ const int64_t ne10 = src1->ne[0];
7893
+ const int64_t ne11 = src1->ne[1];
7894
+ //const int64_t ne12 = src1->ne[2];
7895
+ //const int64_t ne13 = src1->ne[3];
7449
7896
 
7450
- //const int ne0 = dst->ne[0];
7451
- //const int ne1 = dst->ne[1];
7452
- //const int ne2 = dst->ne[2];
7453
- //const int ne3 = dst->ne[3];
7454
- //const int ne = ne0*ne1*ne2*ne3;
7897
+ //const int64_t ne0 = dst->ne[0];
7898
+ //const int64_t ne1 = dst->ne[1];
7899
+ //const int64_t ne2 = dst->ne[2];
7900
+ //const int64_t ne3 = dst->ne[3];
7901
+ //const int64_t ne = ne0*ne1*ne2*ne3;
7455
7902
 
7456
7903
  const int nb00 = src0->nb[0];
7457
7904
  const int nb01 = src0->nb[1];
@@ -7488,11 +7935,11 @@ static void ggml_compute_forward_conv_1d_1s_f32(
7488
7935
  {
7489
7936
  float * const wdata = (float *) params->wdata + 0;
7490
7937
 
7491
- for (int i02 = 0; i02 < ne02; i02++) {
7492
- for (int i01 = 0; i01 < ne01; i01++) {
7938
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
7939
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
7493
7940
  const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
7494
7941
  float * dst_data = wdata + i02*ew0*ne00;
7495
- for (int i00 = 0; i00 < ne00; i00++) {
7942
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
7496
7943
  dst_data[i00*ew0 + i01] = src[i00];
7497
7944
  }
7498
7945
  }
@@ -7503,10 +7950,10 @@ static void ggml_compute_forward_conv_1d_1s_f32(
7503
7950
  {
7504
7951
  float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
7505
7952
 
7506
- for (int i11 = 0; i11 < ne11; i11++) {
7953
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
7507
7954
  const float * const src = (float *)((char *) src1->data + i11*nb11);
7508
7955
  float * dst_data = wdata;
7509
- for (int i10 = 0; i10 < ne10; i10++) {
7956
+ for (int64_t i10 = 0; i10 < ne10; i10++) {
7510
7957
  dst_data[(i10 + nh)*ew0 + i11] = src[i10];
7511
7958
  }
7512
7959
  }
@@ -7531,7 +7978,7 @@ static void ggml_compute_forward_conv_1d_1s_f32(
7531
7978
 
7532
7979
  for (int i1 = ir0; i1 < ir1; i1++) {
7533
7980
  float * dst_data = (float *)((char *) dst->data + i1*nb1);
7534
- for (int i0 = 0; i0 < ne10; ++i0) {
7981
+ for (int64_t i0 = 0; i0 < ne10; ++i0) {
7535
7982
  dst_data[i0] = 0;
7536
7983
  for (int k = -nh; k <= nh; k++) {
7537
7984
  float v = 0.0f;
@@ -7585,21 +8032,21 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32(
7585
8032
  int64_t t0 = ggml_perf_time_us();
7586
8033
  UNUSED(t0);
7587
8034
 
7588
- const int ne00 = src0->ne[0];
7589
- const int ne01 = src0->ne[1];
7590
- const int ne02 = src0->ne[2];
7591
- //const int ne03 = src0->ne[3];
8035
+ const int64_t ne00 = src0->ne[0];
8036
+ const int64_t ne01 = src0->ne[1];
8037
+ const int64_t ne02 = src0->ne[2];
8038
+ //const int64_t ne03 = src0->ne[3];
7592
8039
 
7593
- const int ne10 = src1->ne[0];
7594
- const int ne11 = src1->ne[1];
7595
- //const int ne12 = src1->ne[2];
7596
- //const int ne13 = src1->ne[3];
8040
+ const int64_t ne10 = src1->ne[0];
8041
+ const int64_t ne11 = src1->ne[1];
8042
+ //const int64_t ne12 = src1->ne[2];
8043
+ //const int64_t ne13 = src1->ne[3];
7597
8044
 
7598
- //const int ne0 = dst->ne[0];
7599
- //const int ne1 = dst->ne[1];
7600
- //const int ne2 = dst->ne[2];
7601
- //const int ne3 = dst->ne[3];
7602
- //const int ne = ne0*ne1*ne2*ne3;
8045
+ //const int64_t ne0 = dst->ne[0];
8046
+ //const int64_t ne1 = dst->ne[1];
8047
+ //const int64_t ne2 = dst->ne[2];
8048
+ //const int64_t ne3 = dst->ne[3];
8049
+ //const int64_t ne = ne0*ne1*ne2*ne3;
7603
8050
 
7604
8051
  const int nb00 = src0->nb[0];
7605
8052
  const int nb01 = src0->nb[1];
@@ -7636,11 +8083,11 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32(
7636
8083
  {
7637
8084
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
7638
8085
 
7639
- for (int i02 = 0; i02 < ne02; i02++) {
7640
- for (int i01 = 0; i01 < ne01; i01++) {
8086
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8087
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
7641
8088
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
7642
8089
  ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
7643
- for (int i00 = 0; i00 < ne00; i00++) {
8090
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
7644
8091
  dst_data[i00*ew0 + i01] = src[i00];
7645
8092
  }
7646
8093
  }
@@ -7651,10 +8098,10 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32(
7651
8098
  {
7652
8099
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
7653
8100
 
7654
- for (int i11 = 0; i11 < ne11; i11++) {
8101
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
7655
8102
  const float * const src = (float *)((char *) src1->data + i11*nb11);
7656
8103
  ggml_fp16_t * dst_data = wdata;
7657
- for (int i10 = 0; i10 < ne10; i10++) {
8104
+ for (int64_t i10 = 0; i10 < ne10; i10++) {
7658
8105
  dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
7659
8106
  }
7660
8107
  }
@@ -7679,7 +8126,7 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32(
7679
8126
 
7680
8127
  for (int i1 = ir0; i1 < ir1; i1++) {
7681
8128
  float * dst_data = (float *)((char *) dst->data + i1*nb1);
7682
- for (int i0 = 0; i0 < ne10; i0 += 2) {
8129
+ for (int64_t i0 = 0; i0 < ne10; i0 += 2) {
7683
8130
  dst_data[i0/2] = 0;
7684
8131
  for (int k = -nh; k <= nh; k++) {
7685
8132
  float v = 0.0f;
@@ -7705,21 +8152,21 @@ static void ggml_compute_forward_conv_1d_2s_f32(
7705
8152
  int64_t t0 = ggml_perf_time_us();
7706
8153
  UNUSED(t0);
7707
8154
 
7708
- const int ne00 = src0->ne[0];
7709
- const int ne01 = src0->ne[1];
7710
- const int ne02 = src0->ne[2];
7711
- //const int ne03 = src0->ne[3];
8155
+ const int64_t ne00 = src0->ne[0];
8156
+ const int64_t ne01 = src0->ne[1];
8157
+ const int64_t ne02 = src0->ne[2];
8158
+ //const int64_t ne03 = src0->ne[3];
7712
8159
 
7713
- const int ne10 = src1->ne[0];
7714
- const int ne11 = src1->ne[1];
7715
- //const int ne12 = src1->ne[2];
7716
- //const int ne13 = src1->ne[3];
8160
+ const int64_t ne10 = src1->ne[0];
8161
+ const int64_t ne11 = src1->ne[1];
8162
+ //const int64_t ne12 = src1->ne[2];
8163
+ //const int64_t ne13 = src1->ne[3];
7717
8164
 
7718
- //const int ne0 = dst->ne[0];
7719
- //const int ne1 = dst->ne[1];
7720
- //const int ne2 = dst->ne[2];
7721
- //const int ne3 = dst->ne[3];
7722
- //const int ne = ne0*ne1*ne2*ne3;
8165
+ //const int64_t ne0 = dst->ne[0];
8166
+ //const int64_t ne1 = dst->ne[1];
8167
+ //const int64_t ne2 = dst->ne[2];
8168
+ //const int64_t ne3 = dst->ne[3];
8169
+ //const int64_t ne = ne0*ne1*ne2*ne3;
7723
8170
 
7724
8171
  const int nb00 = src0->nb[0];
7725
8172
  const int nb01 = src0->nb[1];
@@ -7756,11 +8203,11 @@ static void ggml_compute_forward_conv_1d_2s_f32(
7756
8203
  {
7757
8204
  float * const wdata = (float *) params->wdata + 0;
7758
8205
 
7759
- for (int i02 = 0; i02 < ne02; i02++) {
7760
- for (int i01 = 0; i01 < ne01; i01++) {
8206
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8207
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
7761
8208
  const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
7762
8209
  float * dst_data = wdata + i02*ew0*ne00;
7763
- for (int i00 = 0; i00 < ne00; i00++) {
8210
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
7764
8211
  dst_data[i00*ew0 + i01] = src[i00];
7765
8212
  }
7766
8213
  }
@@ -7771,10 +8218,10 @@ static void ggml_compute_forward_conv_1d_2s_f32(
7771
8218
  {
7772
8219
  float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
7773
8220
 
7774
- for (int i11 = 0; i11 < ne11; i11++) {
8221
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
7775
8222
  const float * const src = (float *)((char *) src1->data + i11*nb11);
7776
8223
  float * dst_data = wdata;
7777
- for (int i10 = 0; i10 < ne10; i10++) {
8224
+ for (int64_t i10 = 0; i10 < ne10; i10++) {
7778
8225
  dst_data[(i10 + nh)*ew0 + i11] = src[i10];
7779
8226
  }
7780
8227
  }
@@ -7799,7 +8246,7 @@ static void ggml_compute_forward_conv_1d_2s_f32(
7799
8246
 
7800
8247
  for (int i1 = ir0; i1 < ir1; i1++) {
7801
8248
  float * dst_data = (float *)((char *) dst->data + i1*nb1);
7802
- for (int i0 = 0; i0 < ne10; i0 += 2) {
8249
+ for (int64_t i0 = 0; i0 < ne10; i0 += 2) {
7803
8250
  dst_data[i0/2] = 0;
7804
8251
  for (int k = -nh; k <= nh; k++) {
7805
8252
  float v = 0.0f;
@@ -7851,25 +8298,25 @@ static void ggml_compute_forward_flash_attn_f32(
7851
8298
  int64_t t0 = ggml_perf_time_us();
7852
8299
  UNUSED(t0);
7853
8300
 
7854
- const int neq0 = q->ne[0];
7855
- const int neq1 = q->ne[1];
7856
- const int neq2 = q->ne[2];
7857
- const int neq3 = q->ne[3];
8301
+ const int64_t neq0 = q->ne[0];
8302
+ const int64_t neq1 = q->ne[1];
8303
+ const int64_t neq2 = q->ne[2];
8304
+ const int64_t neq3 = q->ne[3];
7858
8305
 
7859
- const int nek0 = k->ne[0];
7860
- const int nek1 = k->ne[1];
7861
- //const int nek2 = k->ne[2];
7862
- //const int nek3 = k->ne[3];
8306
+ const int64_t nek0 = k->ne[0];
8307
+ const int64_t nek1 = k->ne[1];
8308
+ //const int64_t nek2 = k->ne[2];
8309
+ //const int64_t nek3 = k->ne[3];
7863
8310
 
7864
- //const int nev0 = v->ne[0];
7865
- const int nev1 = v->ne[1];
7866
- //const int nev2 = v->ne[2];
7867
- //const int nev3 = v->ne[3];
8311
+ //const int64_t nev0 = v->ne[0];
8312
+ const int64_t nev1 = v->ne[1];
8313
+ //const int64_t nev2 = v->ne[2];
8314
+ //const int64_t nev3 = v->ne[3];
7868
8315
 
7869
- const int ne0 = dst->ne[0];
7870
- const int ne1 = dst->ne[1];
7871
- //const int ne2 = dst->ne[2];
7872
- //const int ne3 = dst->ne[3];
8316
+ const int64_t ne0 = dst->ne[0];
8317
+ const int64_t ne1 = dst->ne[1];
8318
+ //const int64_t ne2 = dst->ne[2];
8319
+ //const int64_t ne3 = dst->ne[3];
7873
8320
 
7874
8321
  const int nbk0 = k->nb[0];
7875
8322
  const int nbk1 = k->nb[1];
@@ -7894,10 +8341,10 @@ static void ggml_compute_forward_flash_attn_f32(
7894
8341
  const int ith = params->ith;
7895
8342
  const int nth = params->nth;
7896
8343
 
7897
- const int D = neq0;
7898
- const int N = neq1;
7899
- const int P = nek1 - N;
7900
- const int M = P + N;
8344
+ const int64_t D = neq0;
8345
+ const int64_t N = neq1;
8346
+ const int64_t P = nek1 - N;
8347
+ const int64_t M = P + N;
7901
8348
 
7902
8349
  const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
7903
8350
 
@@ -7959,7 +8406,7 @@ static void ggml_compute_forward_flash_attn_f32(
7959
8406
  S[i] = -INFINITY;
7960
8407
  }
7961
8408
 
7962
- for (int ic = 0; ic < nek1; ++ic) {
8409
+ for (int64_t ic = 0; ic < nek1; ++ic) {
7963
8410
  // k indices
7964
8411
  const int ik3 = iq3;
7965
8412
  const int ik2 = iq2;
@@ -7978,7 +8425,7 @@ static void ggml_compute_forward_flash_attn_f32(
7978
8425
  ggml_vec_scale_f32(nek1, S, scale);
7979
8426
 
7980
8427
  if (masked) {
7981
- for (int i = P; i < M; i++) {
8428
+ for (int64_t i = P; i < M; i++) {
7982
8429
  if (i > P + iq1) {
7983
8430
  S[i] = -INFINITY;
7984
8431
  }
@@ -8036,7 +8483,7 @@ static void ggml_compute_forward_flash_attn_f32(
8036
8483
  #endif
8037
8484
  }
8038
8485
 
8039
- for (int ic = 0; ic < nev1; ++ic) {
8486
+ for (int64_t ic = 0; ic < nev1; ++ic) {
8040
8487
  // dst indices
8041
8488
  const int i1 = iq1;
8042
8489
  const int i2 = iq2;
@@ -8060,25 +8507,25 @@ static void ggml_compute_forward_flash_attn_f16(
8060
8507
  int64_t t0 = ggml_perf_time_us();
8061
8508
  UNUSED(t0);
8062
8509
 
8063
- const int neq0 = q->ne[0];
8064
- const int neq1 = q->ne[1];
8065
- const int neq2 = q->ne[2];
8066
- const int neq3 = q->ne[3];
8510
+ const int64_t neq0 = q->ne[0];
8511
+ const int64_t neq1 = q->ne[1];
8512
+ const int64_t neq2 = q->ne[2];
8513
+ const int64_t neq3 = q->ne[3];
8067
8514
 
8068
- const int nek0 = k->ne[0];
8069
- const int nek1 = k->ne[1];
8070
- //const int nek2 = k->ne[2];
8071
- //const int nek3 = k->ne[3];
8515
+ const int64_t nek0 = k->ne[0];
8516
+ const int64_t nek1 = k->ne[1];
8517
+ //const int64_t nek2 = k->ne[2];
8518
+ //const int64_t nek3 = k->ne[3];
8072
8519
 
8073
- //const int nev0 = v->ne[0];
8074
- const int nev1 = v->ne[1];
8075
- //const int nev2 = v->ne[2];
8076
- //const int nev3 = v->ne[3];
8520
+ //const int64_t nev0 = v->ne[0];
8521
+ const int64_t nev1 = v->ne[1];
8522
+ //const int64_t nev2 = v->ne[2];
8523
+ //const int64_t nev3 = v->ne[3];
8077
8524
 
8078
- const int ne0 = dst->ne[0];
8079
- const int ne1 = dst->ne[1];
8080
- //const int ne2 = dst->ne[2];
8081
- //const int ne3 = dst->ne[3];
8525
+ const int64_t ne0 = dst->ne[0];
8526
+ const int64_t ne1 = dst->ne[1];
8527
+ //const int64_t ne2 = dst->ne[2];
8528
+ //const int64_t ne3 = dst->ne[3];
8082
8529
 
8083
8530
  const int nbk0 = k->nb[0];
8084
8531
  const int nbk1 = k->nb[1];
@@ -8103,10 +8550,10 @@ static void ggml_compute_forward_flash_attn_f16(
8103
8550
  const int ith = params->ith;
8104
8551
  const int nth = params->nth;
8105
8552
 
8106
- const int D = neq0;
8107
- const int N = neq1;
8108
- const int P = nek1 - N;
8109
- const int M = P + N;
8553
+ const int64_t D = neq0;
8554
+ const int64_t N = neq1;
8555
+ const int64_t P = nek1 - N;
8556
+ const int64_t M = P + N;
8110
8557
 
8111
8558
  const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
8112
8559
 
@@ -8169,7 +8616,7 @@ static void ggml_compute_forward_flash_attn_f16(
8169
8616
  }
8170
8617
 
8171
8618
  if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
8172
- for (int ic = 0; ic < nek1; ++ic) {
8619
+ for (int64_t ic = 0; ic < nek1; ++ic) {
8173
8620
  // k indices
8174
8621
  const int ik3 = iq3;
8175
8622
  const int ik2 = iq2;
@@ -8184,7 +8631,7 @@ static void ggml_compute_forward_flash_attn_f16(
8184
8631
  (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
8185
8632
  }
8186
8633
  } else {
8187
- for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
8634
+ for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
8188
8635
  // k indices
8189
8636
  const int ik3 = iq3;
8190
8637
  const int ik2 = iq2;
@@ -8204,7 +8651,7 @@ static void ggml_compute_forward_flash_attn_f16(
8204
8651
  ggml_vec_scale_f32(nek1, S, scale);
8205
8652
 
8206
8653
  if (masked) {
8207
- for (int i = P; i < M; i++) {
8654
+ for (int64_t i = P; i < M; i++) {
8208
8655
  if (i > P + iq1) {
8209
8656
  S[i] = -INFINITY;
8210
8657
  }
@@ -8264,12 +8711,12 @@ static void ggml_compute_forward_flash_attn_f16(
8264
8711
 
8265
8712
  ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
8266
8713
 
8267
- for (int i = 0; i < M; i++) {
8714
+ for (int64_t i = 0; i < M; i++) {
8268
8715
  S16[i] = GGML_FP32_TO_FP16(S[i]);
8269
8716
  }
8270
8717
 
8271
8718
  if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
8272
- for (int ic = 0; ic < nev1; ++ic) {
8719
+ for (int64_t ic = 0; ic < nev1; ++ic) {
8273
8720
  // dst indices
8274
8721
  const int i1 = iq1;
8275
8722
  const int i2 = iq2;
@@ -8281,7 +8728,7 @@ static void ggml_compute_forward_flash_attn_f16(
8281
8728
  S16);
8282
8729
  }
8283
8730
  } else {
8284
- for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
8731
+ for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
8285
8732
  // dst indices
8286
8733
  const int i1 = iq1;
8287
8734
  const int i2 = iq2;
@@ -8337,35 +8784,35 @@ static void ggml_compute_forward_flash_ff_f16(
8337
8784
  int64_t t0 = ggml_perf_time_us();
8338
8785
  UNUSED(t0);
8339
8786
 
8340
- const int nea0 = a->ne[0];
8341
- const int nea1 = a->ne[1];
8342
- const int nea2 = a->ne[2];
8343
- const int nea3 = a->ne[3];
8787
+ const int64_t nea0 = a->ne[0];
8788
+ const int64_t nea1 = a->ne[1];
8789
+ const int64_t nea2 = a->ne[2];
8790
+ const int64_t nea3 = a->ne[3];
8344
8791
 
8345
- const int neb00 = b0->ne[0];
8346
- const int neb01 = b0->ne[1];
8347
- //const int neb02 = b0->ne[2];
8348
- //const int neb03 = b0->ne[3];
8792
+ const int64_t neb00 = b0->ne[0];
8793
+ const int64_t neb01 = b0->ne[1];
8794
+ //const int64_t neb02 = b0->ne[2];
8795
+ //const int64_t neb03 = b0->ne[3];
8349
8796
 
8350
- const int neb10 = b1->ne[0];
8351
- const int neb11 = b1->ne[1];
8352
- //const int neb12 = b1->ne[2];
8353
- //const int neb13 = b1->ne[3];
8797
+ const int64_t neb10 = b1->ne[0];
8798
+ const int64_t neb11 = b1->ne[1];
8799
+ //const int64_t neb12 = b1->ne[2];
8800
+ //const int64_t neb13 = b1->ne[3];
8354
8801
 
8355
- const int nec00 = c0->ne[0];
8356
- const int nec01 = c0->ne[1];
8357
- //const int nec02 = c0->ne[2];
8358
- //const int nec03 = c0->ne[3];
8802
+ const int64_t nec00 = c0->ne[0];
8803
+ const int64_t nec01 = c0->ne[1];
8804
+ //const int64_t nec02 = c0->ne[2];
8805
+ //const int64_t nec03 = c0->ne[3];
8359
8806
 
8360
- const int nec10 = c1->ne[0];
8361
- const int nec11 = c1->ne[1];
8362
- //const int nec12 = c1->ne[2];
8363
- //const int nec13 = c1->ne[3];
8807
+ const int64_t nec10 = c1->ne[0];
8808
+ const int64_t nec11 = c1->ne[1];
8809
+ //const int64_t nec12 = c1->ne[2];
8810
+ //const int64_t nec13 = c1->ne[3];
8364
8811
 
8365
- const int ne0 = dst->ne[0];
8366
- const int ne1 = dst->ne[1];
8367
- const int ne2 = dst->ne[2];
8368
- //const int ne3 = dst->ne[3];
8812
+ const int64_t ne0 = dst->ne[0];
8813
+ const int64_t ne1 = dst->ne[1];
8814
+ const int64_t ne2 = dst->ne[2];
8815
+ //const int64_t ne3 = dst->ne[3];
8369
8816
 
8370
8817
  const int nba0 = a->nb[0];
8371
8818
  const int nba1 = a->nb[1];
@@ -8400,9 +8847,9 @@ static void ggml_compute_forward_flash_ff_f16(
8400
8847
  const int ith = params->ith;
8401
8848
  const int nth = params->nth;
8402
8849
 
8403
- const int D = nea0;
8404
- //const int N = nea1;
8405
- const int M = neb01;
8850
+ const int64_t D = nea0;
8851
+ //const int64_t N = nea1;
8852
+ const int64_t M = neb01;
8406
8853
 
8407
8854
  GGML_ASSERT(ne0 == nea0);
8408
8855
  GGML_ASSERT(ne1 == nea1);
@@ -8458,7 +8905,7 @@ static void ggml_compute_forward_flash_ff_f16(
8458
8905
 
8459
8906
  float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
8460
8907
 
8461
- for (int ic = 0; ic < neb01; ++ic) {
8908
+ for (int64_t ic = 0; ic < neb01; ++ic) {
8462
8909
  // b0 indices
8463
8910
  const int ib03 = ia3;
8464
8911
  const int ib02 = ia2;
@@ -8478,7 +8925,7 @@ static void ggml_compute_forward_flash_ff_f16(
8478
8925
 
8479
8926
  ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
8480
8927
 
8481
- for (int i = 0; i < M; i++) {
8928
+ for (int64_t i = 0; i < M; i++) {
8482
8929
  S16[i] = GGML_FP32_TO_FP16(S[i]);
8483
8930
  }
8484
8931
 
@@ -8490,7 +8937,7 @@ static void ggml_compute_forward_flash_ff_f16(
8490
8937
  const int i2 = ia2;
8491
8938
  const int i3 = ia3;
8492
8939
 
8493
- for (int ic = 0; ic < nec01; ++ic) {
8940
+ for (int64_t ic = 0; ic < nec01; ++ic) {
8494
8941
 
8495
8942
  ggml_vec_dot_f16(neb01,
8496
8943
  (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
@@ -8535,6 +8982,111 @@ static void ggml_compute_forward_flash_ff(
8535
8982
  }
8536
8983
  }
8537
8984
 
8985
+ // ggml_compute_forward_map_unary
8986
+
8987
+ static void ggml_compute_forward_map_unary_f32(
8988
+ const struct ggml_compute_params * params,
8989
+ const struct ggml_tensor * src0,
8990
+ struct ggml_tensor * dst,
8991
+ const ggml_unary_op_f32_t fun) {
8992
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
8993
+
8994
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
8995
+ return;
8996
+ }
8997
+
8998
+ const int n = ggml_nrows(src0);
8999
+ const int nc = src0->ne[0];
9000
+
9001
+ assert( dst->nb[0] == sizeof(float));
9002
+ assert(src0->nb[0] == sizeof(float));
9003
+
9004
+ for (int i = 0; i < n; i++) {
9005
+ fun(nc,
9006
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
9007
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
9008
+ }
9009
+ }
9010
+
9011
+
9012
+ static void ggml_compute_forward_map_unary(
9013
+ const struct ggml_compute_params * params,
9014
+ const struct ggml_tensor * src0,
9015
+ struct ggml_tensor * dst,
9016
+ const ggml_unary_op_f32_t fun) {
9017
+ switch (src0->type) {
9018
+ case GGML_TYPE_F32:
9019
+ {
9020
+ ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
9021
+ } break;
9022
+ case GGML_TYPE_Q4_0:
9023
+ case GGML_TYPE_Q4_1:
9024
+ case GGML_TYPE_I8:
9025
+ case GGML_TYPE_I16:
9026
+ case GGML_TYPE_I32:
9027
+ case GGML_TYPE_F16:
9028
+ case GGML_TYPE_COUNT:
9029
+ {
9030
+ GGML_ASSERT(false);
9031
+ } break;
9032
+ }
9033
+ }
9034
+
9035
+ // ggml_compute_forward_map_binary
9036
+
9037
+ static void ggml_compute_forward_map_binary_f32(
9038
+ const struct ggml_compute_params * params,
9039
+ const struct ggml_tensor * src0,
9040
+ const struct ggml_tensor * src1,
9041
+ struct ggml_tensor * dst,
9042
+ const ggml_binary_op_f32_t fun) {
9043
+ assert(params->ith == 0);
9044
+ assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
9045
+
9046
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
9047
+ return;
9048
+ }
9049
+
9050
+ const int n = ggml_nrows(src0);
9051
+ const int nc = src0->ne[0];
9052
+
9053
+ assert( dst->nb[0] == sizeof(float));
9054
+ assert(src0->nb[0] == sizeof(float));
9055
+ assert(src1->nb[0] == sizeof(float));
9056
+
9057
+ for (int i = 0; i < n; i++) {
9058
+ fun(nc,
9059
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
9060
+ (float *) ((char *) src0->data + i*(src0->nb[1])),
9061
+ (float *) ((char *) src1->data + i*(src1->nb[1])));
9062
+ }
9063
+ }
9064
+
9065
+
9066
+ static void ggml_compute_forward_map_binary(
9067
+ const struct ggml_compute_params * params,
9068
+ const struct ggml_tensor * src0,
9069
+ const struct ggml_tensor * src1,
9070
+ struct ggml_tensor * dst,
9071
+ const ggml_binary_op_f32_t fun) {
9072
+ switch (src0->type) {
9073
+ case GGML_TYPE_F32:
9074
+ {
9075
+ ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
9076
+ } break;
9077
+ case GGML_TYPE_Q4_0:
9078
+ case GGML_TYPE_Q4_1:
9079
+ case GGML_TYPE_I8:
9080
+ case GGML_TYPE_I16:
9081
+ case GGML_TYPE_I32:
9082
+ case GGML_TYPE_F16:
9083
+ case GGML_TYPE_COUNT:
9084
+ {
9085
+ GGML_ASSERT(false);
9086
+ } break;
9087
+ }
9088
+ }
9089
+
8538
9090
  /////////////////////////////////
8539
9091
 
8540
9092
  static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
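The mapped-op kernels above run a user-supplied row function as a graph node (on a single task, per the n_tasks = 1 case later in this diff). A minimal sketch of a compatible callback, assuming the matching graph-building helper ggml_map_unary_f32 is exposed in ggml.h next to GGML_OP_MAP_UNARY:

    #include "ggml.h"

    // user-supplied op with the ggml_unary_op_f32_t shape: transforms one row of nc floats
    static void square_row(const int nc, float * dst, const float * src) {
        for (int i = 0; i < nc; ++i) {
            dst[i] = src[i]*src[i];
        }
    }

    // sketch of the call site (ctx and x come from the surrounding program):
    //   struct ggml_tensor * y = ggml_map_unary_f32(ctx, x, square_row);
    // during graph computation, ggml_compute_forward_map_unary calls square_row
    // once for every row of x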
@@ -8629,6 +9181,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
8629
9181
  {
8630
9182
  ggml_compute_forward_cpy(params, tensor->src0, tensor);
8631
9183
  } break;
9184
+ case GGML_OP_CONT:
9185
+ {
9186
+ ggml_compute_forward_cont(params, tensor->src0, tensor);
9187
+ } break;
8632
9188
  case GGML_OP_RESHAPE:
8633
9189
  {
8634
9190
  ggml_compute_forward_reshape(params, tensor->src0, tensor);
@@ -8680,6 +9236,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
8680
9236
  {
8681
9237
  ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
8682
9238
  } break;
9239
+ case GGML_OP_MAP_UNARY:
9240
+ {
9241
+ const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
9242
+ ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
9243
+ }
9244
+ break;
9245
+ case GGML_OP_MAP_BINARY:
9246
+ {
9247
+ const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->opt[0]->data);
9248
+ ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
9249
+ }
9250
+ break;
8683
9251
  case GGML_OP_NONE:
8684
9252
  {
8685
9253
  // nop
@@ -8873,8 +9441,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
8873
9441
  src1->grad =
8874
9442
  ggml_add_impl(ctx,
8875
9443
  src1->grad,
8876
- // TODO: fix transpose, the node will break the graph connections
8877
- ggml_mul_mat(ctx, ggml_transpose(ctx, src0), tensor->grad),
9444
+ ggml_mul_mat(ctx,
9445
+ ggml_cont(ctx, ggml_transpose(ctx, src0)),
9446
+ tensor->grad),
8878
9447
  inplace);
8879
9448
  }
8880
9449
  } break;
@@ -8886,6 +9455,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
8886
9455
  {
8887
9456
  GGML_ASSERT(false); // TODO: not implemented
8888
9457
  } break;
9458
+ case GGML_OP_CONT:
9459
+ {
9460
+ GGML_ASSERT(false); // TODO: not implemented
9461
+ } break;
8889
9462
  case GGML_OP_RESHAPE:
8890
9463
  {
8891
9464
  GGML_ASSERT(false); // TODO: not implemented
@@ -8934,6 +9507,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
              {
                  GGML_ASSERT(false); // not supported
              } break;
+         case GGML_OP_MAP_UNARY:
+         case GGML_OP_MAP_BINARY:
+             {
+                 GGML_ASSERT(false); // not supported
+             } break;
          case GGML_OP_NONE:
              {
                  // nop
@@ -9024,7 +9602,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
      struct ggml_cgraph result = {
          /*.n_nodes =*/ 0,
          /*.n_leafs =*/ 0,
-         /*.n_threads =*/ 0,
+         /*.n_threads =*/ GGML_DEFAULT_N_THREADS,
          /*.work_size =*/ 0,
          /*.work =*/ NULL,
          /*.nodes =*/ { NULL },
@@ -9340,6 +9918,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                      node->n_tasks = n_threads;
                  } break;
              case GGML_OP_CPY:
+             case GGML_OP_CONT:
              case GGML_OP_RESHAPE:
              case GGML_OP_VIEW:
              case GGML_OP_PERMUTE:
@@ -9355,7 +9934,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                  } break;
              case GGML_OP_ROPE:
                  {
-                     node->n_tasks = 1;
+                     node->n_tasks = n_threads;
                  } break;
              case GGML_OP_CONV_1D_1S:
              case GGML_OP_CONV_1D_2S:
@@ -9393,7 +9972,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)

                      size_t cur = 0;

-                     const int ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+                     const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);

                      if (node->src1->type == GGML_TYPE_F32) {
                          cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
@@ -9425,6 +10004,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)

                      work_size = MAX(work_size, cur);
                  } break;
+             case GGML_OP_MAP_UNARY:
+             case GGML_OP_MAP_BINARY:
+                 {
+                     node->n_tasks = 1;
+                 } break;
              case GGML_OP_NONE:
                  {
                      node->n_tasks = 1;
@@ -9643,8 +10227,8 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {

      GGML_PRINT("=== GRAPH ===\n");

-     GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads);
-     GGML_PRINT_DEBUG("total work size = %zu bytes\n",cgraph->work_size);
+     GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads);
+     GGML_PRINT_DEBUG("total work size = %zu bytes\n", cgraph->work_size);

      GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes);
      for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -9652,7 +10236,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {

          perf_total_per_op_us[node->op] += node->perf_time_us;

-         GGML_PRINT(" - %3d: [ %6d, %6d, %6d] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
+         GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 ", %" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
                  i,
                  node->ne[0], node->ne[1], node->ne[2],
                  GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
@@ -9666,7 +10250,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
      for (int i = 0; i < cgraph->n_leafs; i++) {
          struct ggml_tensor * node = cgraph->leafs[i];

-         GGML_PRINT(" - %3d: [ %6d, %6d] %8s\n",
+         GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 "] %8s\n",
                  i,
                  node->ne[0], node->ne[1],
                  GGML_OP_LABEL[node->op]);
@@ -9737,7 +10321,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph

          fprintf(fp, " \"%p\" [ \
  style = filled; fillcolor = %s; shape = record; \
- label=\"%d [%d, %d] | <x>%s",
+ label=\"%d [%" PRId64 ", %" PRId64 "] | <x>%s",
                  (void *) node, color,
                  i, node->ne[0], node->ne[1],
                  GGML_OP_SYMBOL[node->op]);
@@ -9762,7 +10346,7 @@ label=\"<x>%.1e\"; ]\n",
          } else {
              fprintf(fp, " \"%p\" [ \
  style = filled; fillcolor = %s; shape = record; \
- label=\"<x>CONST %d [%d, %d]\"; ]\n",
+ label=\"<x>CONST %d [%" PRId64 ", %" PRId64 "]\"; ]\n",
                      (void *) node, color,
                      i, node->ne[0], node->ne[1]);
          }
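The %d to %" PRId64 " changes in ggml_graph_print() and ggml_graph_dump_dot() follow from the tensor dimensions ne[] being widened to int64_t in this release; PRId64 comes from <inttypes.h>, now included at the top of the file, and expands to the correct printf conversion for int64_t on every platform. A standalone illustration with made-up values:

    #include <inttypes.h>
    #include <stdio.h>

    int main(void) {
        const int64_t ne0 = 4096, ne1 = 32000;               // example tensor dimensions
        printf("[ %" PRId64 ", %" PRId64 " ]\n", ne0, ne1);  // portable int64_t formatting
        return 0;
    }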
@@ -9826,9 +10410,9 @@ label=\"<x>CONST %d [%d, %d]\"; ]\n",
  static void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float * x) {
      int i = 0;
      for (int p = 0; p < np; ++p) {
-         const int ne = ggml_nelements(ps[p]) ;
+         const int64_t ne = ggml_nelements(ps[p]) ;
          // TODO: add function to set tensor from array
-         for (int j = 0; j < ne; ++j) {
+         for (int64_t j = 0; j < ne; ++j) {
              ggml_set_f32_1d(ps[p], j, x[i++]);
          }
      }
@@ -9837,9 +10421,9 @@ static void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const f
  static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) {
      int i = 0;
      for (int p = 0; p < np; ++p) {
-         const int ne = ggml_nelements(ps[p]) ;
+         const int64_t ne = ggml_nelements(ps[p]) ;
          // TODO: add function to get all elements at once
-         for (int j = 0; j < ne; ++j) {
+         for (int64_t j = 0; j < ne; ++j) {
              x[i++] = ggml_get_f32_1d(ps[p], j);
          }
      }
@@ -9848,9 +10432,9 @@ static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float *
  static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
      int i = 0;
      for (int p = 0; p < np; ++p) {
-         const int ne = ggml_nelements(ps[p]) ;
+         const int64_t ne = ggml_nelements(ps[p]) ;
          // TODO: add function to get all elements at once
-         for (int j = 0; j < ne; ++j) {
+         for (int64_t j = 0; j < ne; ++j) {
              g[i++] = ggml_get_f32_1d(ps[p]->grad, j);
          }
      }
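Likewise, the ggml_opt_* helpers now keep the result of ggml_nelements() in an int64_t and use a 64-bit loop index, so parameter tensors with more than INT_MAX elements no longer overflow the counter. The pattern, reduced to a sketch (t stands for any ggml_tensor *):

    const int64_t ne = ggml_nelements(t);  // element count can exceed INT_MAX
    for (int64_t j = 0; j < ne; ++j) {
        // per-element access, e.g. ggml_get_f32_1d(t, j)
    }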