llama_cpp 0.0.3 → 0.0.4 — changes to the vendored ggml.c

@@ -1,4 +1,4 @@
1
- // Defines CLOCK_MONOTONIC and asprintf on Linux
1
+ // Defines CLOCK_MONOTONIC on Linux
2
2
  #define _GNU_SOURCE
3
3
 
4
4
  #include "ggml.h"
@@ -26,14 +26,9 @@
26
26
  #define static_assert(cond, msg) struct global_scope_noop_trick
27
27
  #endif
28
28
 
29
- #if defined _MSC_VER || defined(__MINGW32__)
29
+ #if defined(_WIN32)
30
30
 
31
- #if !defined(__MINGW32__)
32
- #include <Windows.h>
33
- #else
34
- // ref: https://github.com/ggerganov/whisper.cpp/issues/168
35
31
  #include <windows.h>
36
- #endif
37
32
 
38
33
  typedef volatile LONG atomic_int;
39
34
  typedef atomic_int atomic_bool;
@@ -55,6 +50,7 @@ typedef HANDLE pthread_t;
55
50
 
56
51
  typedef DWORD thread_ret_t;
57
52
  static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
53
+ (void) unused;
58
54
  HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
59
55
  if (handle == NULL)
60
56
  {
@@ -66,6 +62,7 @@ static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void
66
62
  }
67
63
 
68
64
  static int pthread_join(pthread_t thread, void* unused) {
65
+ (void) unused;
69
66
  return (int) WaitForSingleObject(thread, INFINITE);
70
67
  }
71
68
 
@@ -97,17 +94,6 @@ typedef void* thread_ret_t;
97
94
  #define static_assert(cond, msg) _Static_assert(cond, msg)
98
95
  #endif
99
96
 
100
- #define GGML_MLOCK_SUPPORT 0
101
-
102
- #ifdef __has_include
103
- #if __has_include(<sys/mman.h>)
104
- #undef GGML_MLOCK_SUPPORT
105
- #define GGML_MLOCK_SUPPORT 1
106
- #include <sys/mman.h>
107
- #endif
108
- #endif
109
-
110
-
111
97
  /*#define GGML_PERF*/
112
98
  #define GGML_DEBUG 0
113
99
  #define GGML_GELU_FP16
@@ -128,6 +114,14 @@ typedef void* thread_ret_t;
128
114
  #define GGML_MEM_ALIGN 16
129
115
  #endif
130
116
 
117
+ #if defined(_MSC_VER) || defined(__MINGW32__)
118
+ #define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
119
+ #define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr)
120
+ #else
121
+ #define GGML_ALIGNED_MALLOC(size) aligned_alloc(GGML_MEM_ALIGN, size)
122
+ #define GGML_ALIGNED_FREE(ptr) free(ptr)
123
+ #endif
124
+
131
125
  #define UNUSED(x) (void)(x)
132
126
  #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
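Note: the new GGML_ALIGNED_MALLOC/GGML_ALIGNED_FREE pair selects _aligned_malloc/_aligned_free on MSVC and MinGW and C11 aligned_alloc/free everywhere else, replacing the plain malloc that ggml_init used before this release. A minimal usage sketch (hedged; n is a placeholder byte count, and C11 aligned_alloc expects the size to be a multiple of the alignment):

    size_t n = 1024;                      // already a multiple of GGML_MEM_ALIGN (16)
    void * buf = GGML_ALIGNED_MALLOC(n);  // 16-byte aligned on every supported platform
    /* ... use buf ... */
    GGML_ALIGNED_FREE(buf);               // must pair with the macro; plain free() is wrong for _aligned_malloc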
133
127
 
@@ -242,12 +236,12 @@ static inline float fp32_from_bits(uint32_t w) {
242
236
  }
243
237
 
244
238
  static inline uint32_t fp32_to_bits(float f) {
245
- union {
246
- float as_value;
247
- uint32_t as_bits;
248
- } fp32;
249
- fp32.as_value = f;
250
- return fp32.as_bits;
239
+ union {
240
+ float as_value;
241
+ uint32_t as_bits;
242
+ } fp32;
243
+ fp32.as_value = f;
244
+ return fp32.as_bits;
251
245
  }
252
246
 
253
247
  static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
@@ -497,6 +491,77 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
497
491
  }
498
492
  #endif
499
493
 
494
+ #if __ARM_NEON
495
+
496
+ #if !defined(__aarch64__)
497
+
498
+ inline static uint16_t vaddvq_u8(uint8x16_t v) {
499
+ return
500
+ (uint16_t)vgetq_lane_u8(v, 0) + (uint16_t)vgetq_lane_u8(v, 1) +
501
+ (uint16_t)vgetq_lane_u8(v, 2) + (uint16_t)vgetq_lane_u8(v, 3) +
502
+ (uint16_t)vgetq_lane_u8(v, 4) + (uint16_t)vgetq_lane_u8(v, 5) +
503
+ (uint16_t)vgetq_lane_u8(v, 6) + (uint16_t)vgetq_lane_u8(v, 7) +
504
+ (uint16_t)vgetq_lane_u8(v, 8) + (uint16_t)vgetq_lane_u8(v, 9) +
505
+ (uint16_t)vgetq_lane_u8(v, 10) + (uint16_t)vgetq_lane_u8(v, 11) +
506
+ (uint16_t)vgetq_lane_u8(v, 12) + (uint16_t)vgetq_lane_u8(v, 13) +
507
+ (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
508
+ }
509
+
510
+ inline static int32_t vaddvq_s16(int16x8_t v) {
511
+ return
512
+ (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
513
+ (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
514
+ (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
515
+ (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
516
+ }
517
+
518
+ inline static uint32_t vaddvq_u16(uint16x8_t v) {
519
+ return
520
+ (uint32_t)vgetq_lane_u16(v, 0) + (uint32_t)vgetq_lane_u16(v, 1) +
521
+ (uint32_t)vgetq_lane_u16(v, 2) + (uint32_t)vgetq_lane_u16(v, 3) +
522
+ (uint32_t)vgetq_lane_u16(v, 4) + (uint32_t)vgetq_lane_u16(v, 5) +
523
+ (uint32_t)vgetq_lane_u16(v, 6) + (uint32_t)vgetq_lane_u16(v, 7);
524
+ }
525
+
526
+ inline static int32_t vaddvq_s32(int32x4_t v) {
527
+ return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
528
+ }
529
+
530
+ inline static float vaddvq_f32(float32x4_t v) {
531
+ return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
532
+ }
533
+
534
+ inline float vminvq_f32(float32x4_t v) {
535
+ return
536
+ MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
537
+ MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
538
+ }
539
+
540
+ inline float vmaxvq_f32(float32x4_t v) {
541
+ return
542
+ MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
543
+ MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
544
+ }
545
+
546
+ inline int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
547
+ return vget_low_s8(vcombine_s8(a, b));
548
+ }
549
+
550
+ inline int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
551
+ return vget_high_s8(vcombine_s8(a, b));
552
+ }
553
+
554
+ inline uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
555
+ return vget_low_u8(vcombine_u8(a, b));
556
+ }
557
+
558
+ inline uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
559
+ return vget_high_u8(vcombine_u8(a, b));
560
+ }
561
+
562
+ #endif
563
+ #endif
564
+
500
565
  // method 5
501
566
  // blocks of QK elements
502
567
  // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
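Note: the block added above polyfills AArch64-only intrinsics (the horizontal reductions vaddvq_u8/s16/u16/s32/f32, vminvq_f32, vmaxvq_f32, and the vzip1/vzip2 shuffles) for 32-bit ARM builds, where <arm_neon.h> does not provide them. Because vaddvq_* is now always available, the __ARM_FEATURE_QRDMX special cases further down become unnecessary and are removed. A small equivalence sketch (assuming <arm_neon.h>):

    float32x4_t v = vdupq_n_f32(1.5f);
    float s = vaddvq_f32(v);   // 6.0f on both paths: one instruction on AArch64,
                               // four vgetq_lane_f32 additions through the fallback above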
@@ -610,10 +675,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
610
675
  for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
611
676
  for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
612
677
 
613
- // absolute max
614
- const float amax = MAX(
615
- MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)),
616
- MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
678
+ const float amax = vmaxvq_f32(amaxv[0]);
617
679
 
618
680
  const float d = amax / ((1 << 3) - 1);
619
681
  const float id = d ? 1.0f/d : 0.0f;
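Note: vmaxvq_f32 (native on AArch64, emulated above on 32-bit ARM) replaces the hand-written four-lane MAX reduction when computing the block's absolute maximum. The scale itself is unchanged; with amax the largest |x| in the block:

    d  = amax / ((1 << 3) - 1);   // = amax/7, so x/d falls in [-7, 7]
    id = d ? 1.0f/d : 0.0f;       // guard against an all-zero block

Each value then quantizes to a signed 4-bit code, stored shifted by +8 into an unsigned nibble (which is why the dot-product kernels below subtract 8 again).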
@@ -935,7 +997,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
935
997
  float32x4_t minv[8];
936
998
  float32x4_t maxv[8];
937
999
 
938
- for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
1000
+ for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK + 4*l);
939
1001
 
940
1002
  for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
941
1003
  for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
@@ -958,7 +1020,8 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
958
1020
 
959
1021
  for (int l = 0; l < 8; l++) {
960
1022
  const float32x4_t v = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id);
961
- const int32x4_t vi = vcvtq_s32_f32(v);
1023
+ const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(0.5f)); // needed to round to nearest
1024
+ const int32x4_t vi = vcvtq_s32_f32(vf);
962
1025
 
963
1026
  y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
964
1027
  y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
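Note: vcvtq_s32_f32 truncates toward zero, so the previous code effectively floored the non-negative Q4_1 values. Adding 0.5f before the conversion gives round-to-nearest, which is what the new "needed to round to nearest" comment refers to. Scalar illustration (a sketch, not part of the file):

    float v = 3.7f;
    int truncated = (int) v;           // 3  (old behaviour)
    int rounded   = (int) (v + 0.5f);  // 4  (what the patched NEON path computes)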
@@ -1226,15 +1289,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1226
1289
  #define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
1227
1290
  #define GGML_F32x4_ADD vaddq_f32
1228
1291
  #define GGML_F32x4_MUL vmulq_f32
1229
- #if defined(__ARM_FEATURE_QRDMX)
1230
- #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
1231
- #else
1232
- #define GGML_F32x4_REDUCE_ONE(x) \
1233
- (vgetq_lane_f32(x, 0) + \
1234
- vgetq_lane_f32(x, 1) + \
1235
- vgetq_lane_f32(x, 2) + \
1236
- vgetq_lane_f32(x, 3))
1237
- #endif
1292
+ #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
1238
1293
  #define GGML_F32x4_REDUCE(res, x) \
1239
1294
  { \
1240
1295
  for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
@@ -1857,55 +1912,43 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1857
1912
  // 4-bit -> 8-bit
1858
1913
  const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
1859
1914
  const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
1860
-
1861
1915
  const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
1862
1916
  const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
1863
1917
 
1864
1918
  const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
1865
1919
  const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
1866
-
1867
1920
  const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
1868
1921
  const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
1869
1922
 
1870
1923
  // sub 8
1871
1924
  const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
1872
1925
  const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
1873
-
1874
1926
  const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
1875
1927
  const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
1876
1928
 
1877
1929
  const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
1878
1930
  const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
1879
-
1880
1931
  const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
1881
1932
  const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
1882
1933
 
1883
1934
  #if defined(__ARM_FEATURE_DOTPROD)
1884
- // dot product into int16x8_t
1935
+ // dot product into int32x4_t
1885
1936
  int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
1886
1937
  int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
1887
1938
 
1888
1939
  p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
1889
1940
  p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
1890
1941
 
1891
- // scalar
1892
- #if defined(__ARM_FEATURE_QRDMX)
1893
- sum0 += x0->d * y0->d * vaddvq_s32(p_0);
1894
- sum1 += x1->d * y1->d * vaddvq_s32(p_1);
1942
+ sum0 += x0->d*y0->d*vaddvq_s32(p_0);
1943
+ sum1 += x1->d*y1->d*vaddvq_s32(p_1);
1895
1944
  #else
1896
- sum0 += x0->d * y0->d * (vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
1897
- sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
1898
- #endif
1899
- #else
1900
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
1945
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
1901
1946
  const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
1902
-
1903
1947
  const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
1904
1948
  const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
1905
1949
 
1906
1950
  const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
1907
1951
  const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
1908
-
1909
1952
  const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
1910
1953
  const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
1911
1954
 
@@ -1918,14 +1961,8 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1918
1961
  const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
1919
1962
  const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
1920
1963
 
1921
- // scalar
1922
- #if defined(__ARM_FEATURE_QRDMX)
1923
- sum0 += x0->d * y0->d * vaddvq_s16(p_0);
1924
- sum1 += x1->d * y1->d * vaddvq_s16(p_1);
1925
- #else
1926
- sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
1927
- sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
1928
- #endif
1964
+ sum0 += x0->d*y0->d*vaddvq_s16(p_0);
1965
+ sum1 += x1->d*y1->d*vaddvq_s16(p_1);
1929
1966
  #endif
1930
1967
  }
1931
1968
 
@@ -1962,7 +1999,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1962
1999
  // Initialize accumulator with zeros
1963
2000
  __m256 acc = _mm256_setzero_ps();
1964
2001
 
1965
- /* Prepare the constants we will need during execution */
2002
+ /* Prepare the constants we will need during execution */
1966
2003
  const __m256i lowMask = _mm256_set1_epi8( 0xF );
1967
2004
  const __m256i offset_8 = _mm256_set1_epi16( 8 );
1968
2005
 
@@ -1972,61 +2009,59 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1972
2009
 
1973
2010
  // Main loop
1974
2011
  for (int i = 0; i < nb; i+=UNROLL_COUNT) {
1975
-
1976
- // This loop will be unrolled by the compiler
2012
+ // This loop will be unrolled by the compiler
1977
2013
  for (int u=0;u<UNROLL_COUNT;u++) {
1978
- /* Compute combined scale for the block */
1979
- const __m256 scale = _mm256_mul_ps(
1980
- _mm256_broadcast_ss( &x[i+u].d ),
1981
- _mm256_broadcast_ss( &y[i+u].d ) );
1982
-
1983
- /* get input from x
1984
- Input: 32 Nibbles (16 bytes) at *x[i+u]
1985
- Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
1986
-
1987
- /* Load 16 bytes from memory */
1988
- const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
1989
- /* Expand bytes into uint16_t values */
1990
- const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
2014
+ /* Compute combined scale for the block */
2015
+ const __m256 scale = _mm256_mul_ps(
2016
+ _mm256_broadcast_ss( &x[i+u].d ),
2017
+ _mm256_broadcast_ss( &y[i+u].d ) );
2018
+
2019
+ /* get input from x
2020
+ Input: 32 Nibbles (16 bytes) at *x[i+u]
2021
+ Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
2022
+
2023
+ /* Load 16 bytes from memory */
2024
+ const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
2025
+ /* Expand bytes into uint16_t values */
2026
+ const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
1991
2027
  /* Unpack values into individual bytes */
1992
2028
  __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
1993
2029
  const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
1994
- __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
2030
+ __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
1995
2031
  /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
1996
- x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
1997
- x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
2032
+ x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
2033
+ x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
1998
2034
 
1999
- /* get input from y
2000
- Input: 32 Nibbles (16 bytes) at *y[i+u]
2001
- Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
2035
+ /* get input from y
2036
+ Input: 32 Nibbles (16 bytes) at *y[i+u]
2037
+ Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
2002
2038
 
2003
- /* Load 16 bytes from memory */
2004
- const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
2005
- /* Expand bytes into uint16_t values */
2006
- const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
2039
+ /* Load 16 bytes from memory */
2040
+ const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
2041
+ /* Expand bytes into uint16_t values */
2042
+ const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
2007
2043
  /* Unpack values into individual bytes */
2008
- const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
2009
- __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
2010
- __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
2044
+ const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
2045
+ __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
2046
+ __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
2011
2047
  /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
2012
- y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
2013
- y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
2048
+ y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
2049
+ y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
2014
2050
 
2015
- /* Compute products of int16_t integers, add pairwise, store as int32_t */
2016
- __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
2017
- __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
2051
+ /* Compute products of int16_t integers, add pairwise, store as int32_t */
2052
+ __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
2053
+ __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
2018
2054
 
2019
- /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
2020
- __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
2055
+ /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
2056
+ __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
2021
2057
 
2022
- /* Convert to vectore of 8 int32_t to 8 floats */
2023
- __m256 q = _mm256_cvtepi32_ps( xy_q );
2058
+ /* Convert to vectore of 8 int32_t to 8 floats */
2059
+ __m256 q = _mm256_cvtepi32_ps( xy_q );
2024
2060
 
2025
- /* Multiply q with scale and accumulate */
2026
- acc = _mm256_fmadd_ps( scale, q, acc );
2061
+ /* Multiply q with scale and accumulate */
2062
+ acc = _mm256_fmadd_ps( scale, q, acc );
2027
2063
  }
2028
-
2029
- }
2064
+ }
2030
2065
 
2031
2066
  // Return horizontal sum of the acc vector
2032
2067
  __m128 res = _mm256_extractf128_ps( acc, 1 );
@@ -2087,18 +2122,18 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2087
2122
  float sum1 = 0.0f;
2088
2123
 
2089
2124
  for (int i = 0; i < nb; i += 2) {
2090
- const block_q4_0 * restrict x0 = &px[i + 0];
2091
- const block_q4_0 * restrict y0 = &py[i + 0];
2092
- const block_q4_0 * restrict x1 = &px[i + 1];
2093
- const block_q4_0 * restrict y1 = &py[i + 1];
2125
+ const block_q4_0 * restrict x0 = &x[i + 0];
2126
+ const block_q4_0 * restrict y0 = &y[i + 0];
2127
+ const block_q4_0 * restrict x1 = &x[i + 1];
2128
+ const block_q4_0 * restrict y1 = &y[i + 1];
2094
2129
 
2095
2130
  const v128_t m4b = wasm_u8x16_splat(0xf);
2096
2131
  const v128_t s8b = wasm_i8x16_splat(0x8);
2097
2132
 
2098
- const v128_t v0_0 = wasm_v128_load(x0.qs);
2099
- const v128_t v0_1 = wasm_v128_load(y0.qs);
2100
- const v128_t v1_0 = wasm_v128_load(x1.qs);
2101
- const v128_t v1_1 = wasm_v128_load(y1.qs);
2133
+ const v128_t v0_0 = wasm_v128_load(x0->qs);
2134
+ const v128_t v0_1 = wasm_v128_load(y0->qs);
2135
+ const v128_t v1_0 = wasm_v128_load(x1->qs);
2136
+ const v128_t v1_1 = wasm_v128_load(y1->qs);
2102
2137
 
2103
2138
  // 4-bit -> 8-bit
2104
2139
  const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
@@ -2170,18 +2205,20 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2170
2205
  const uint8_t * restrict p0 = x[i].qs;
2171
2206
  const uint8_t * restrict p1 = y[i].qs;
2172
2207
 
2208
+ int sumi = 0;
2173
2209
  for (int j = 0; j < QK/2; j++) {
2174
2210
  const uint8_t v0 = p0[j];
2175
2211
  const uint8_t v1 = p1[j];
2176
2212
 
2177
- const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
2178
- const float f1 = d0*((int8_t) (v0 >> 4) - 8);
2213
+ const int8_t i0 = (int8_t) (v0 & 0xf) - 8;
2214
+ const int8_t i1 = (int8_t) (v0 >> 4) - 8;
2179
2215
 
2180
- const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
2181
- const float f3 = d1*((int8_t) (v1 >> 4) - 8);
2216
+ const int8_t i2 = (int8_t) (v1 & 0xf) - 8;
2217
+ const int8_t i3 = (int8_t) (v1 >> 4) - 8;
2182
2218
 
2183
- sumf += f0*f2 + f1*f3;
2219
+ sumi += i0*i2 + i1*i3;
2184
2220
  }
2221
+ sumf += d0 * d1 * sumi;
2185
2222
  }
2186
2223
  #endif
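Note: the scalar fallback now accumulates the raw 4-bit products in an integer (sumi) and applies the two block scales once per block, sumf += d0*d1*sumi, instead of multiplying every element by d0 and d1. A self-contained sketch of the per-block computation (hypothetical helper, assuming <stdint.h>; QK is the block size defined earlier in the file):

    static float q4_0_block_dot(float d0, float d1, const uint8_t * p0, const uint8_t * p1) {
        int sumi = 0;
        for (int j = 0; j < QK/2; j++) {
            const int i0 = (int)(p0[j] & 0xf) - 8, i1 = (int)(p0[j] >> 4) - 8;  // two 4-bit codes per byte
            const int i2 = (int)(p1[j] & 0xf) - 8, i3 = (int)(p1[j] >> 4) - 8;
            sumi += i0*i2 + i1*i3;        // integer math per element
        }
        return d0*d1*sumi;                // both scales applied once per block
    }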
2187
2224
 
@@ -2273,36 +2310,71 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2273
2310
  float sum10 = 0.0f;
2274
2311
  float sum11 = 0.0f;
2275
2312
 
2276
- for (int i = 0; i < nb; ++i) {
2313
+ for (int i = 0; i < nb; i += 2) {
2277
2314
  const block_q4_1 * restrict x0 = &x[i + 0];
2278
2315
  const block_q4_1 * restrict y0 = &y[i + 0];
2316
+ const block_q4_1 * restrict x1 = &x[i + 1];
2317
+ const block_q4_1 * restrict y1 = &y[i + 1];
2279
2318
 
2280
2319
  const uint8x16_t m4b = vdupq_n_u8(0xf);
2281
2320
 
2282
2321
  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2283
2322
  const uint8x16_t v1_0 = vld1q_u8(y0->qs);
2323
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2324
+ const uint8x16_t v1_1 = vld1q_u8(y1->qs);
2284
2325
 
2285
- // and with 0xf
2326
+ // 4-bit -> 8-bit
2286
2327
  const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
2287
2328
  const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
2288
-
2289
2329
  const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
2290
2330
  const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
2291
2331
 
2292
- // dot product into uint16x8_t
2332
+ const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
2333
+ const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
2334
+ const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
2335
+ const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
2336
+
2337
+ sum00 += x0->m*y0->m;
2338
+ sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
2339
+ sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
2340
+
2341
+ sum00 += x1->m*y1->m;
2342
+ sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
2343
+ sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
2344
+
2345
+ #if defined(__ARM_FEATURE_DOTPROD)
2346
+ // dot product into int32x4_t
2347
+ uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
2348
+ uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
2349
+
2350
+ p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
2351
+ p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
2352
+
2353
+ sum11 += x0->d*y0->d*vaddvq_u32(p_0);
2354
+ sum11 += x1->d*y1->d*vaddvq_u32(p_1);
2355
+ #else
2293
2356
  const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
2294
2357
  const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
2295
-
2296
2358
  const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
2297
2359
  const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
2298
2360
 
2299
- const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
2300
- const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
2361
+ const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
2362
+ const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
2363
+ const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
2364
+ const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
2301
2365
 
2302
- sum00 += x0->m*y0->m;
2303
- sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
2304
- sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
2305
- sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
2366
+ const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
2367
+ const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
2368
+
2369
+ const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
2370
+ const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
2371
+
2372
+ const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
2373
+ const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
2374
+
2375
+ sum11 += x0->d*y0->d*vaddvq_u16(p_0);
2376
+ sum11 += x1->d*y1->d*vaddvq_u16(p_1);
2377
+ #endif
2306
2378
  }
2307
2379
 
2308
2380
  sumf = QK*sum00 + sum01 + sum10 + sum11;
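Note: Q4_1 stores each value as x = d*q + m with q in [0, 15], so for a block of QK elements the dot product factors as

    sum(x*y) = d_x*d_y * sum(q_x*q_y)   -> sum11 (vdotq_u32 when __ARM_FEATURE_DOTPROD is set, vmull_u8 otherwise)
             + d_x*m_y * sum(q_x)       -> sum01
             + d_y*m_x * sum(q_y)       -> sum10
             + QK * m_x*m_y             -> QK*sum00

which is exactly the final line above; the rewritten hunk only widens the loop to two blocks per iteration and adds the dot-product instruction path.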
@@ -2578,29 +2650,38 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
2578
2650
  //
2579
2651
 
2580
2652
  static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
2581
- QK,
2582
- QK,
2583
- 1,
2584
- 1,
2585
- 1,
2586
- 1,
2587
- 1,
2653
+ [GGML_TYPE_F32] = 1,
2654
+ [GGML_TYPE_F16] = 1,
2655
+ [GGML_TYPE_Q4_0] = QK,
2656
+ [GGML_TYPE_Q4_1] = QK,
2657
+ [GGML_TYPE_I8] = 1,
2658
+ [GGML_TYPE_I16] = 1,
2659
+ [GGML_TYPE_I32] = 1,
2588
2660
  };
2589
-
2590
- static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
2661
+ static_assert(GGML_TYPE_COUNT == 7, "GGML_BLCK_SIZE is outdated");
2591
2662
 
2592
2663
  static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
2593
- sizeof(block_q4_0),
2594
- sizeof(block_q4_1),
2595
- sizeof(int8_t ),
2596
- sizeof(int16_t),
2597
- sizeof(int32_t),
2598
- sizeof(ggml_fp16_t),
2599
- sizeof(float ),
2664
+ [GGML_TYPE_F32] = sizeof(float),
2665
+ [GGML_TYPE_F16] = sizeof(ggml_fp16_t),
2666
+ [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
2667
+ [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
2668
+ [GGML_TYPE_I8] = sizeof(int8_t),
2669
+ [GGML_TYPE_I16] = sizeof(int16_t),
2670
+ [GGML_TYPE_I32] = sizeof(int32_t),
2600
2671
  };
2672
+ static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_SIZE is outdated");
2673
+
2601
2674
 
2602
- // don't forget to update the array above when adding new types
2603
- static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
2675
+ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
2676
+ [GGML_TYPE_F32] = "f32",
2677
+ [GGML_TYPE_F16] = "f16",
2678
+ [GGML_TYPE_Q4_0] = "q4_0",
2679
+ [GGML_TYPE_Q4_1] = "q4_1",
2680
+ [GGML_TYPE_I8] = "i8",
2681
+ [GGML_TYPE_I16] = "i16",
2682
+ [GGML_TYPE_I32] = "i32",
2683
+ };
2684
+ static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_NAME is outdated");
2604
2685
 
2605
2686
  static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
2606
2687
  "NONE",
@@ -2629,6 +2710,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
2629
2710
 
2630
2711
  "SCALE",
2631
2712
  "CPY",
2713
+ "CONT",
2632
2714
  "RESHAPE",
2633
2715
  "VIEW",
2634
2716
  "PERMUTE",
@@ -2642,9 +2724,12 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
2642
2724
 
2643
2725
  "FLASH_ATTN",
2644
2726
  "FLASH_FF",
2727
+
2728
+ "MAP_UNARY",
2729
+ "MAP_BINARY",
2645
2730
  };
2646
2731
 
2647
- static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
2732
+ static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
2648
2733
 
2649
2734
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2650
2735
  "none",
@@ -2673,6 +2758,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2673
2758
 
2674
2759
  "x*v",
2675
2760
  "x-\\>y",
2761
+ "cont(x)",
2676
2762
  "reshape(x)",
2677
2763
  "view(x)",
2678
2764
  "permute(x)",
@@ -2686,24 +2772,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2686
2772
 
2687
2773
  "flash_attn(x)",
2688
2774
  "flash_ff(x)",
2689
- };
2690
-
2691
- static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
2692
-
2693
- //
2694
- // ggml object
2695
- //
2696
-
2697
- struct ggml_object {
2698
- size_t offs;
2699
- size_t size;
2700
2775
 
2701
- struct ggml_object * next;
2702
-
2703
- char padding[8];
2776
+ "f(x)",
2777
+ "f(x,y)",
2704
2778
  };
2705
2779
 
2706
- static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
2780
+ static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
2707
2781
 
2708
2782
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
2709
2783
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -2716,7 +2790,6 @@ struct ggml_context {
2716
2790
  size_t mem_size;
2717
2791
  void * mem_buffer;
2718
2792
  bool mem_buffer_owned;
2719
- bool mem_buffer_mlocked;
2720
2793
  bool no_alloc;
2721
2794
 
2722
2795
  int n_objects;
@@ -2834,6 +2907,11 @@ float ggml_type_sizef(enum ggml_type type) {
2834
2907
  return ((float)(GGML_TYPE_SIZE[type]))/GGML_BLCK_SIZE[type];
2835
2908
  }
2836
2909
 
2910
+ const char * ggml_type_name(enum ggml_type type) {
2911
+ return GGML_TYPE_NAME[type];
2912
+ }
2913
+
2914
+
2837
2915
  size_t ggml_element_size(const struct ggml_tensor * tensor) {
2838
2916
  return GGML_TYPE_SIZE[tensor->type];
2839
2917
  }
@@ -2999,11 +3077,12 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
2999
3077
  return NULL;
3000
3078
  }
3001
3079
 
3080
+ const size_t mem_size = (params.mem_size + GGML_MEM_ALIGN - 1) & ~(GGML_MEM_ALIGN - 1);
3081
+
3002
3082
  *ctx = (struct ggml_context) {
3003
- /*.mem_size =*/ params.mem_size,
3004
- /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
3083
+ /*.mem_size =*/ mem_size,
3084
+ /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size),
3005
3085
  /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
3006
- /*.mem_buffer_mlocked =*/ false,
3007
3086
  /*.no_alloc =*/ params.no_alloc,
3008
3087
  /*.n_objects =*/ 0,
3009
3088
  /*.objects_begin =*/ NULL,
@@ -3012,7 +3091,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3012
3091
  /*.scratch_save =*/ { 0, 0, NULL, },
3013
3092
  };
3014
3093
 
3015
- GGML_ASSERT(ctx->mem_buffer != NULL); // check for allocation failure
3094
+ GGML_ASSERT(ctx->mem_buffer != NULL);
3016
3095
 
3017
3096
  ggml_assert_aligned(ctx->mem_buffer);
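Note: ggml_init now rounds the requested size up to a multiple of GGML_MEM_ALIGN before calling GGML_ALIGNED_MALLOC, because C11 aligned_alloc requires size % alignment == 0. Worked example with GGML_MEM_ALIGN = 16:

    (100 + 16 - 1) & ~(16 - 1)  ==  115 & ~15  ==  112   // next multiple of 16

The same hunk drops the mem_buffer_mlocked field from the context initializer; the mlock bookkeeping and ggml_mlock API are removed further down.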
3018
3097
 
@@ -3036,16 +3115,8 @@ void ggml_free(struct ggml_context * ctx) {
3036
3115
  GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
3037
3116
  __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
3038
3117
 
3039
- #if GGML_MLOCK_SUPPORT
3040
- if (ctx->mem_buffer_mlocked) {
3041
- if (munlock(ctx->mem_buffer, ctx->mem_size)) {
3042
- fprintf(stderr, "%s: failed to munlock buffer: %s\n", __func__, strerror(errno));
3043
- }
3044
- }
3045
- #endif
3046
-
3047
3118
  if (ctx->mem_buffer_owned) {
3048
- free(ctx->mem_buffer);
3119
+ GGML_ALIGNED_FREE(ctx->mem_buffer);
3049
3120
  }
3050
3121
 
3051
3122
  found = true;
@@ -3072,48 +3143,6 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch)
3072
3143
  return result;
3073
3144
  }
3074
3145
 
3075
- #ifdef __APPLE__
3076
- #define MLOCK_SUGGESTION \
3077
- "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
3078
- "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n"
3079
- #else
3080
- #define MLOCK_SUGGESTION \
3081
- "Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n"
3082
- #endif
3083
-
3084
- bool ggml_mlock_supported(void) {
3085
- return GGML_MLOCK_SUPPORT;
3086
- }
3087
-
3088
- bool ggml_mlock(
3089
- struct ggml_context * ctx,
3090
- const void *opt_extra_addr,
3091
- size_t opt_extra_len,
3092
- char **err_p) {
3093
- // TODO: Use SetProcessWorkingSetSize() + VirtualLock() on WIN32
3094
- #if GGML_MLOCK_SUPPORT
3095
- if (ctx->mem_buffer_mlocked) {
3096
- return true;
3097
- }
3098
- if (mlock(ctx->mem_buffer, ctx->mem_size) ||
3099
- (opt_extra_len &&
3100
- mlock(opt_extra_addr, opt_extra_len))) {
3101
- if ((*err_p = malloc(1024))) {
3102
- snprintf(*err_p, 1024,
3103
- "failed to mlock %zu-byte buffer: %s\n" MLOCK_SUGGESTION,
3104
- ctx->mem_size + opt_extra_len,
3105
- strerror(errno));
3106
- }
3107
- return false;
3108
- }
3109
- ctx->mem_buffer_mlocked = true;
3110
- return true;
3111
- #else // GGML_MLOCK_SUPPORT
3112
- *err_p = strdup("can't mlock because it's not supported on this system");
3113
- return false;
3114
- #endif // GGML_MLOCK_SUPPORT
3115
- }
3116
-
3117
3146
  ////////////////////////////////////////////////////////////////////////////////
3118
3147
 
3119
3148
  struct ggml_tensor * ggml_new_tensor_impl(
@@ -4388,6 +4417,41 @@ struct ggml_tensor * ggml_cpy_inplace(
4388
4417
  return ggml_cpy_impl(ctx, a, b, true);
4389
4418
  }
4390
4419
 
4420
+ // ggml_cont
4421
+
4422
+ struct ggml_tensor * ggml_cont_impl(
4423
+ struct ggml_context * ctx,
4424
+ struct ggml_tensor * a,
4425
+ bool inplace) {
4426
+ bool is_node = false;
4427
+
4428
+ if (!inplace && a->grad) {
4429
+ GGML_ASSERT(false); // TODO: implement backward
4430
+ is_node = true;
4431
+ }
4432
+
4433
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4434
+
4435
+ result->op = GGML_OP_CONT;
4436
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4437
+ result->src0 = a;
4438
+ result->src1 = NULL;
4439
+
4440
+ return result;
4441
+ }
4442
+
4443
+ struct ggml_tensor * ggml_cont(
4444
+ struct ggml_context * ctx,
4445
+ struct ggml_tensor * a) {
4446
+ return ggml_cont_impl(ctx, a, false);
4447
+ }
4448
+
4449
+ struct ggml_tensor * ggml_cont_inplace(
4450
+ struct ggml_context * ctx,
4451
+ struct ggml_tensor * a) {
4452
+ return ggml_cont_impl(ctx, a, true);
4453
+ }
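Note: ggml_cont/ggml_cont_inplace add a GGML_OP_CONT node that materializes a contiguous copy of its input (the forward pass simply reuses the dup kernels, see ggml_compute_forward_cont below). It exists so that non-contiguous views such as transposes can feed ops that assume contiguous rows. Usage sketch, based on the backward-pass change later in this diff (w and x are placeholder tensors):

    struct ggml_tensor * wt  = ggml_cont(ctx, ggml_transpose(ctx, w));
    struct ggml_tensor * out = ggml_mul_mat(ctx, wt, x);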
4454
+
4391
4455
  // ggml_reshape
4392
4456
 
4393
4457
  struct ggml_tensor * ggml_reshape(
@@ -4866,6 +4930,90 @@ struct ggml_tensor * ggml_flash_ff(
4866
4930
  return result;
4867
4931
  }
4868
4932
 
4933
+ // ggml_map_unary
4934
+
4935
+ struct ggml_tensor * ggml_map_unary_impl_f32(
4936
+ struct ggml_context * ctx,
4937
+ struct ggml_tensor * a,
4938
+ const ggml_unary_op_f32_t fun,
4939
+ bool inplace) {
4940
+ bool is_node = false;
4941
+
4942
+ if (!inplace && a->grad) {
4943
+ is_node = true;
4944
+ }
4945
+
4946
+ struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
4947
+ *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
4948
+ struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4949
+
4950
+ result->op = GGML_OP_MAP_UNARY;
4951
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4952
+ result->src0 = a;
4953
+ result->opt[0] = addr_tensor;
4954
+
4955
+ return result;
4956
+ }
4957
+
4958
+ struct ggml_tensor * ggml_map_unary_f32(
4959
+ struct ggml_context * ctx,
4960
+ struct ggml_tensor * a,
4961
+ const ggml_unary_op_f32_t fun) {
4962
+ return ggml_map_unary_impl_f32(ctx, a, fun, false);
4963
+ }
4964
+
4965
+ struct ggml_tensor * ggml_map_unary_inplace_f32(
4966
+ struct ggml_context * ctx,
4967
+ struct ggml_tensor * a,
4968
+ const ggml_unary_op_f32_t fun) {
4969
+ return ggml_map_unary_impl_f32(ctx, a, fun, true);
4970
+ }
4971
+
4972
+ // ggml_map_binary
4973
+
4974
+ struct ggml_tensor * ggml_map_binary_impl_f32(
4975
+ struct ggml_context * ctx,
4976
+ struct ggml_tensor * a,
4977
+ struct ggml_tensor * b,
4978
+ const ggml_binary_op_f32_t fun,
4979
+ bool inplace) {
4980
+ GGML_ASSERT(ggml_are_same_shape(a, b));
4981
+
4982
+ bool is_node = false;
4983
+
4984
+ if (!inplace && (a->grad || b->grad)) {
4985
+ is_node = true;
4986
+ }
4987
+
4988
+ struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
4989
+ *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
4990
+ struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4991
+
4992
+ result->op = GGML_OP_MAP_BINARY;
4993
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4994
+ result->src0 = a;
4995
+ result->src1 = b;
4996
+ result->opt[0] = addr_tensor;
4997
+
4998
+ return result;
4999
+ }
5000
+
5001
+ struct ggml_tensor * ggml_map_binary_f32(
5002
+ struct ggml_context * ctx,
5003
+ struct ggml_tensor * a,
5004
+ struct ggml_tensor * b,
5005
+ const ggml_binary_op_f32_t fun) {
5006
+ return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
5007
+ }
5008
+
5009
+ struct ggml_tensor * ggml_map_binary_inplace_f32(
5010
+ struct ggml_context * ctx,
5011
+ struct ggml_tensor * a,
5012
+ struct ggml_tensor * b,
5013
+ const ggml_binary_op_f32_t fun) {
5014
+ return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
5015
+ }
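Note: ggml_map_unary_f32 and ggml_map_binary_f32 let a caller run an arbitrary element-wise f32 function as a graph node; the function pointer is stashed in a small I32 tensor and read back from opt[0] at compute time, and the ops run single-threaded (n_tasks = 1 in the scheduling hunk near the end). Sketch of a hypothetical callback, with the signature inferred from how the forward pass calls it (count, dst row, src row):

    static void my_relu(const int n, float * dst, const float * src) {
        for (int i = 0; i < n; ++i) dst[i] = src[i] > 0.0f ? src[i] : 0.0f;
    }

    struct ggml_tensor * y = ggml_map_unary_f32(ctx, x, my_relu);   // x is an F32 tensor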
5016
+
4869
5017
  ////////////////////////////////////////////////////////////////////////////////
4870
5018
 
4871
5019
  void ggml_set_param(
@@ -4930,6 +5078,85 @@ static void ggml_compute_forward_dup_f16(
4930
5078
 
4931
5079
  // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
4932
5080
 
5081
+ if (ggml_is_contiguous(dst)) {
5082
+ if (src0->nb[0] == sizeof(ggml_fp16_t)) {
5083
+ if (dst->type == GGML_TYPE_F16) {
5084
+ size_t id = 0;
5085
+ const size_t rs = ne00*nb00;
5086
+
5087
+ for (int i03 = 0; i03 < ne03; i03++) {
5088
+ for (int i02 = 0; i02 < ne02; i02++) {
5089
+ for (int i01 = 0; i01 < ne01; i01++) {
5090
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5091
+ char * dst_ptr = (char *) dst->data + id*rs;
5092
+
5093
+ memcpy(dst_ptr, src0_ptr, rs);
5094
+
5095
+ id++;
5096
+ }
5097
+ }
5098
+ }
5099
+ } else if (dst->type == GGML_TYPE_F32) {
5100
+ size_t id = 0;
5101
+ float * dst_ptr = (float *) dst->data;
5102
+
5103
+ for (int i03 = 0; i03 < ne03; i03++) {
5104
+ for (int i02 = 0; i02 < ne02; i02++) {
5105
+ for (int i01 = 0; i01 < ne01; i01++) {
5106
+ for (int i00 = 0; i00 < ne00; i00++) {
5107
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5108
+
5109
+ dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
5110
+ id++;
5111
+ }
5112
+ }
5113
+ }
5114
+ }
5115
+ } else {
5116
+ GGML_ASSERT(false); // TODO: implement
5117
+ }
5118
+ } else {
5119
+ //printf("%s: this is not optimal - fix me\n", __func__);
5120
+
5121
+ if (dst->type == GGML_TYPE_F32) {
5122
+ size_t id = 0;
5123
+ float * dst_ptr = (float *) dst->data;
5124
+
5125
+ for (int i03 = 0; i03 < ne03; i03++) {
5126
+ for (int i02 = 0; i02 < ne02; i02++) {
5127
+ for (int i01 = 0; i01 < ne01; i01++) {
5128
+ for (int i00 = 0; i00 < ne00; i00++) {
5129
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5130
+
5131
+ dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
5132
+ id++;
5133
+ }
5134
+ }
5135
+ }
5136
+ }
5137
+ } else if (dst->type == GGML_TYPE_F16) {
5138
+ size_t id = 0;
5139
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
5140
+
5141
+ for (int i03 = 0; i03 < ne03; i03++) {
5142
+ for (int i02 = 0; i02 < ne02; i02++) {
5143
+ for (int i01 = 0; i01 < ne01; i01++) {
5144
+ for (int i00 = 0; i00 < ne00; i00++) {
5145
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5146
+
5147
+ dst_ptr[id] = *src0_ptr;
5148
+ id++;
5149
+ }
5150
+ }
5151
+ }
5152
+ }
5153
+ } else {
5154
+ GGML_ASSERT(false); // TODO: implement
5155
+ }
5156
+ }
5157
+ return;
5158
+ }
5159
+
4933
5160
  // dst counters
4934
5161
  int64_t i10 = 0;
4935
5162
  int64_t i11 = 0;
@@ -5024,6 +5251,105 @@ static void ggml_compute_forward_dup_f32(
5024
5251
  return;
5025
5252
  }
5026
5253
 
5254
+ if (src0->type == dst->type &&
5255
+ src0->ne[0] == dst->ne[0] &&
5256
+ src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
5257
+ // copy by rows
5258
+ const size_t rs = ne00*nb00;
5259
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5260
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5261
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
5262
+ memcpy(
5263
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5264
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
5265
+ rs);
5266
+ }
5267
+ }
5268
+ }
5269
+ return;
5270
+ }
5271
+
5272
+ if (ggml_is_contiguous(dst)) {
5273
+ // TODO: simplify
5274
+ if (src0->nb[0] == sizeof(float)) {
5275
+ if (dst->type == GGML_TYPE_F32) {
5276
+ size_t id = 0;
5277
+ const size_t rs = ne00*nb00;
5278
+
5279
+ for (int i03 = 0; i03 < ne03; i03++) {
5280
+ for (int i02 = 0; i02 < ne02; i02++) {
5281
+ for (int i01 = 0; i01 < ne01; i01++) {
5282
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5283
+ char * dst_ptr = (char *) dst->data + id*rs;
5284
+
5285
+ memcpy(dst_ptr, src0_ptr, rs);
5286
+
5287
+ id++;
5288
+ }
5289
+ }
5290
+ }
5291
+ } else if (dst->type == GGML_TYPE_F16) {
5292
+ size_t id = 0;
5293
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
5294
+
5295
+ for (int i03 = 0; i03 < ne03; i03++) {
5296
+ for (int i02 = 0; i02 < ne02; i02++) {
5297
+ for (int i01 = 0; i01 < ne01; i01++) {
5298
+ for (int i00 = 0; i00 < ne00; i00++) {
5299
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5300
+
5301
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
5302
+ id++;
5303
+ }
5304
+ }
5305
+ }
5306
+ }
5307
+ } else {
5308
+ GGML_ASSERT(false); // TODO: implement
5309
+ }
5310
+ } else {
5311
+ //printf("%s: this is not optimal - fix me\n", __func__);
5312
+
5313
+ if (dst->type == GGML_TYPE_F32) {
5314
+ size_t id = 0;
5315
+ float * dst_ptr = (float *) dst->data;
5316
+
5317
+ for (int i03 = 0; i03 < ne03; i03++) {
5318
+ for (int i02 = 0; i02 < ne02; i02++) {
5319
+ for (int i01 = 0; i01 < ne01; i01++) {
5320
+ for (int i00 = 0; i00 < ne00; i00++) {
5321
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5322
+
5323
+ dst_ptr[id] = *src0_ptr;
5324
+ id++;
5325
+ }
5326
+ }
5327
+ }
5328
+ }
5329
+ } else if (dst->type == GGML_TYPE_F16) {
5330
+ size_t id = 0;
5331
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
5332
+
5333
+ for (int i03 = 0; i03 < ne03; i03++) {
5334
+ for (int i02 = 0; i02 < ne02; i02++) {
5335
+ for (int i01 = 0; i01 < ne01; i01++) {
5336
+ for (int i00 = 0; i00 < ne00; i00++) {
5337
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5338
+
5339
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
5340
+ id++;
5341
+ }
5342
+ }
5343
+ }
5344
+ }
5345
+ } else {
5346
+ GGML_ASSERT(false); // TODO: implement
5347
+ }
5348
+ }
5349
+
5350
+ return;
5351
+ }
5352
+
5027
5353
  // dst counters
5028
5354
  int64_t i10 = 0;
5029
5355
  int64_t i11 = 0;
@@ -5144,14 +5470,18 @@ static void ggml_compute_forward_add_f32(
5144
5470
  GGML_ASSERT(nb00 == sizeof(float));
5145
5471
 
5146
5472
  if (nb10 == sizeof(float)) {
5147
- const int j0 = (n/nth)*ith;
5148
- const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1);
5149
-
5150
- for (int j = j0; j < j1; j++) {
5473
+ for (int j = ith; j < n; j += nth) {
5474
+ #ifdef GGML_USE_ACCELERATE
5475
+ vDSP_vadd(
5476
+ (float *) ((char *) src0->data + j*nb01), 1,
5477
+ (float *) ((char *) src1->data + j*nb11), 1,
5478
+ (float *) ((char *) dst->data + j*nb1), 1, nc);
5479
+ #else
5151
5480
  ggml_vec_add_f32(nc,
5152
5481
  (float *) ((char *) dst->data + j*nb1),
5153
5482
  (float *) ((char *) src0->data + j*nb01),
5154
5483
  (float *) ((char *) src1->data + j*nb11));
5484
+ #endif
5155
5485
  }
5156
5486
  } else {
5157
5487
  // src1 is not contiguous
@@ -6304,7 +6634,7 @@ static void ggml_compute_forward_mul_mat_f32(
6304
6634
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6305
6635
  ne11, ne01, ne10,
6306
6636
  1.0f, y, ne10,
6307
- x, ne10,
6637
+ x, ne00,
6308
6638
  0.0f, d, ne01);
6309
6639
  }
6310
6640
  }
@@ -6476,7 +6806,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6476
6806
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6477
6807
  ne11, ne01, ne10,
6478
6808
  1.0f, y, ne10,
6479
- x, ne10,
6809
+ x, ne00,
6480
6810
  0.0f, d, ne01);
6481
6811
  }
6482
6812
  }
@@ -6564,29 +6894,27 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6564
6894
  //}
6565
6895
  }
6566
6896
 
6567
- typedef void (*dequantize_row_q_t)(const void * restrict x, float * restrict y, int k);
6568
- typedef void (*quantize_row_q_t)(const float * restrict x, void * restrict y, int k);
6569
- typedef void (*vec_dot_q_t)(const int n, float * restrict s, const void * restrict x, const void * restrict y);
6570
-
6571
- typedef struct {
6572
- dequantize_row_q_t dequantize_row_q;
6573
- quantize_row_q_t quantize_row_q;
6574
- vec_dot_q_t vec_dot_q;
6575
- } quantize_fns_t;
6576
-
6577
6897
  static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
6578
6898
  [GGML_TYPE_Q4_0] = {
6579
- .dequantize_row_q = dequantize_row_q4_0,
6580
- .quantize_row_q = quantize_row_q4_0,
6581
- .vec_dot_q = ggml_vec_dot_q4_0,
6899
+ .dequantize_row_q = dequantize_row_q4_0,
6900
+ .quantize_row_q = quantize_row_q4_0,
6901
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
6902
+ .vec_dot_q = ggml_vec_dot_q4_0,
6582
6903
  },
6583
6904
  [GGML_TYPE_Q4_1] = {
6584
- .dequantize_row_q = dequantize_row_q4_1,
6585
- .quantize_row_q = quantize_row_q4_1,
6586
- .vec_dot_q = ggml_vec_dot_q4_1,
6905
+ .dequantize_row_q = dequantize_row_q4_1,
6906
+ .quantize_row_q = quantize_row_q4_1,
6907
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
6908
+ .vec_dot_q = ggml_vec_dot_q4_1,
6587
6909
  },
6588
6910
  };
6589
6911
 
6912
+ // For internal test use
6913
+ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
6914
+ GGML_ASSERT(i < GGML_TYPE_COUNT);
6915
+ return quantize_fns[i];
6916
+ }
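Note: the quantize_fns_t typedef and the row-function typedefs are removed from this file (presumably now declared in the public header), the table gains a quantize_row_q_reference member, and ggml_internal_get_quantize_fn exposes the table for the test suite. Usage sketch (hypothetical test code; the buffers are placeholders):

    quantize_fns_t fns = ggml_internal_get_quantize_fn(GGML_TYPE_Q4_0);
    fns.quantize_row_q(src_f32, q_buf, QK);      // quantize one block of QK floats
    fns.dequantize_row_q(q_buf, dst_f32, QK);    // and back to f32 for comparison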
6917
+
6590
6918
  static void ggml_compute_forward_mul_mat_q_f32(
6591
6919
  const struct ggml_compute_params * params,
6592
6920
  const struct ggml_tensor * src0,
@@ -6691,7 +7019,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
6691
7019
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6692
7020
  ne11, ne01, ne10,
6693
7021
  1.0f, y, ne10,
6694
- x, ne10,
7022
+ x, ne00,
6695
7023
  0.0f, d, ne01);
6696
7024
  }
6697
7025
  }
@@ -6901,6 +7229,15 @@ static void ggml_compute_forward_cpy(
6901
7229
  ggml_compute_forward_dup(params, src0, dst);
6902
7230
  }
6903
7231
 
7232
+ // ggml_compute_forward_cont
7233
+
7234
+ static void ggml_compute_forward_cont(
7235
+ const struct ggml_compute_params * params,
7236
+ const struct ggml_tensor * src0,
7237
+ struct ggml_tensor * dst) {
7238
+ ggml_compute_forward_dup(params, src0, dst);
7239
+ }
7240
+
6904
7241
  // ggml_compute_forward_reshape
6905
7242
 
6906
7243
  static void ggml_compute_forward_reshape(
@@ -7279,6 +7616,8 @@ static void ggml_compute_forward_rope_f32(
7279
7616
  // row index used to determine which thread to use
7280
7617
  int ir = 0;
7281
7618
 
7619
+ const float theta_scale = powf(10000.0, -2.0f/n_dims);
7620
+
7282
7621
  for (int64_t i3 = 0; i3 < ne3; i3++) {
7283
7622
  for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
7284
7623
  const int p = (mode == 0 ? n_past + i2 : i2);
@@ -7286,11 +7625,13 @@ static void ggml_compute_forward_rope_f32(
7286
7625
  if (ir++ < ir0) continue;
7287
7626
  if (ir > ir1) break;
7288
7627
 
7628
+ float theta = (float)p;
7629
+
7289
7630
  for (int i0 = 0; i0 < n_dims; i0 += 2) {
7290
- const float theta = powf(10000.0, ((float)-i0)/n_dims);
7631
+ const float cos_theta = cosf(theta);
7632
+ const float sin_theta = sinf(theta);
7291
7633
 
7292
- const float cos_theta = cosf(p*theta);
7293
- const float sin_theta = sinf(p*theta);
7634
+ theta *= theta_scale;
7294
7635
 
7295
7636
  const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
7296
7637
  float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
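Note: both RoPE kernels (f32 here, f16 in the next hunk) hoist the per-pair powf out of the inner loop. The old code computed theta = p * 10000^(-i0/n_dims) with one powf per dimension pair; the new code starts from theta = p and multiplies by theta_scale = 10000^(-2/n_dims) each iteration, producing the same geometric sequence with a single powf per row:

    // after k iterations (i0 = 2k):
    //   theta = p * theta_scale^k = p * 10000^(-2k/n_dims) = p * 10000^(-i0/n_dims)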
@@ -7352,6 +7693,8 @@ static void ggml_compute_forward_rope_f16(
7352
7693
  // row index used to determine which thread to use
7353
7694
  int ir = 0;
7354
7695
 
7696
+ const float theta_scale = powf(10000.0, -2.0f/n_dims);
7697
+
7355
7698
  for (int64_t i3 = 0; i3 < ne3; i3++) {
7356
7699
  for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
7357
7700
  const int p = (mode == 0 ? n_past + i2 : i2);
@@ -7359,11 +7702,13 @@ static void ggml_compute_forward_rope_f16(
7359
7702
  if (ir++ < ir0) continue;
7360
7703
  if (ir > ir1) break;
7361
7704
 
7705
+ float theta = (float)p;
7706
+
7362
7707
  for (int i0 = 0; i0 < n_dims; i0 += 2) {
7363
- const float theta = powf(10000.0, ((float)-i0)/n_dims);
7708
+ const float cos_theta = cosf(theta);
7709
+ const float sin_theta = sinf(theta);
7364
7710
 
7365
- const float cos_theta = cosf(p*theta);
7366
- const float sin_theta = sinf(p*theta);
7711
+ theta *= theta_scale;
7367
7712
 
7368
7713
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
7369
7714
  ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -8637,6 +8982,111 @@ static void ggml_compute_forward_flash_ff(
8637
8982
  }
8638
8983
  }
8639
8984
 
8985
+ // ggml_compute_forward_map_unary
8986
+
8987
+ static void ggml_compute_forward_map_unary_f32(
8988
+ const struct ggml_compute_params * params,
8989
+ const struct ggml_tensor * src0,
8990
+ struct ggml_tensor * dst,
8991
+ const ggml_unary_op_f32_t fun) {
8992
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
8993
+
8994
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
8995
+ return;
8996
+ }
8997
+
8998
+ const int n = ggml_nrows(src0);
8999
+ const int nc = src0->ne[0];
9000
+
9001
+ assert( dst->nb[0] == sizeof(float));
9002
+ assert(src0->nb[0] == sizeof(float));
9003
+
9004
+ for (int i = 0; i < n; i++) {
9005
+ fun(nc,
9006
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
9007
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
9008
+ }
9009
+ }
9010
+
9011
+
9012
+ static void ggml_compute_forward_map_unary(
9013
+ const struct ggml_compute_params * params,
9014
+ const struct ggml_tensor * src0,
9015
+ struct ggml_tensor * dst,
9016
+ const ggml_unary_op_f32_t fun) {
9017
+ switch (src0->type) {
9018
+ case GGML_TYPE_F32:
9019
+ {
9020
+ ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
9021
+ } break;
9022
+ case GGML_TYPE_Q4_0:
9023
+ case GGML_TYPE_Q4_1:
9024
+ case GGML_TYPE_I8:
9025
+ case GGML_TYPE_I16:
9026
+ case GGML_TYPE_I32:
9027
+ case GGML_TYPE_F16:
9028
+ case GGML_TYPE_COUNT:
9029
+ {
9030
+ GGML_ASSERT(false);
9031
+ } break;
9032
+ }
9033
+ }
9034
+
9035
+ // ggml_compute_forward_map_binary
9036
+
9037
+ static void ggml_compute_forward_map_binary_f32(
9038
+ const struct ggml_compute_params * params,
9039
+ const struct ggml_tensor * src0,
9040
+ const struct ggml_tensor * src1,
9041
+ struct ggml_tensor * dst,
9042
+ const ggml_binary_op_f32_t fun) {
9043
+ assert(params->ith == 0);
9044
+ assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
9045
+
9046
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
9047
+ return;
9048
+ }
9049
+
9050
+ const int n = ggml_nrows(src0);
9051
+ const int nc = src0->ne[0];
9052
+
9053
+ assert( dst->nb[0] == sizeof(float));
9054
+ assert(src0->nb[0] == sizeof(float));
9055
+ assert(src1->nb[0] == sizeof(float));
9056
+
9057
+ for (int i = 0; i < n; i++) {
9058
+ fun(nc,
9059
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
9060
+ (float *) ((char *) src0->data + i*(src0->nb[1])),
9061
+ (float *) ((char *) src1->data + i*(src1->nb[1])));
9062
+ }
9063
+ }
9064
+
9065
+
9066
+ static void ggml_compute_forward_map_binary(
9067
+ const struct ggml_compute_params * params,
9068
+ const struct ggml_tensor * src0,
9069
+ const struct ggml_tensor * src1,
9070
+ struct ggml_tensor * dst,
9071
+ const ggml_binary_op_f32_t fun) {
9072
+ switch (src0->type) {
9073
+ case GGML_TYPE_F32:
9074
+ {
9075
+ ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
9076
+ } break;
9077
+ case GGML_TYPE_Q4_0:
9078
+ case GGML_TYPE_Q4_1:
9079
+ case GGML_TYPE_I8:
9080
+ case GGML_TYPE_I16:
9081
+ case GGML_TYPE_I32:
9082
+ case GGML_TYPE_F16:
9083
+ case GGML_TYPE_COUNT:
9084
+ {
9085
+ GGML_ASSERT(false);
9086
+ } break;
9087
+ }
9088
+ }
9089
+
8640
9090
  /////////////////////////////////
8641
9091
 
8642
9092
  static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -8731,6 +9181,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
8731
9181
  {
8732
9182
  ggml_compute_forward_cpy(params, tensor->src0, tensor);
8733
9183
  } break;
9184
+ case GGML_OP_CONT:
9185
+ {
9186
+ ggml_compute_forward_cont(params, tensor->src0, tensor);
9187
+ } break;
8734
9188
  case GGML_OP_RESHAPE:
8735
9189
  {
8736
9190
  ggml_compute_forward_reshape(params, tensor->src0, tensor);
@@ -8782,6 +9236,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
8782
9236
  {
8783
9237
  ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
8784
9238
  } break;
9239
+ case GGML_OP_MAP_UNARY:
9240
+ {
9241
+ const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
9242
+ ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
9243
+ }
9244
+ break;
9245
+ case GGML_OP_MAP_BINARY:
9246
+ {
9247
+ const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->opt[0]->data);
9248
+ ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
9249
+ }
9250
+ break;
8785
9251
  case GGML_OP_NONE:
8786
9252
  {
8787
9253
  // nop
@@ -8975,8 +9441,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
8975
9441
  src1->grad =
8976
9442
  ggml_add_impl(ctx,
8977
9443
  src1->grad,
8978
- // TODO: fix transpose, the node will break the graph connections
8979
- ggml_mul_mat(ctx, ggml_transpose(ctx, src0), tensor->grad),
9444
+ ggml_mul_mat(ctx,
9445
+ ggml_cont(ctx, ggml_transpose(ctx, src0)),
9446
+ tensor->grad),
8980
9447
  inplace);
8981
9448
  }
8982
9449
  } break;
@@ -8988,6 +9455,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
8988
9455
  {
8989
9456
  GGML_ASSERT(false); // TODO: not implemented
8990
9457
  } break;
9458
+ case GGML_OP_CONT:
9459
+ {
9460
+ GGML_ASSERT(false); // TODO: not implemented
9461
+ } break;
8991
9462
  case GGML_OP_RESHAPE:
8992
9463
  {
8993
9464
  GGML_ASSERT(false); // TODO: not implemented
@@ -9036,6 +9507,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
9036
9507
  {
9037
9508
  GGML_ASSERT(false); // not supported
9038
9509
  } break;
9510
+ case GGML_OP_MAP_UNARY:
9511
+ case GGML_OP_MAP_BINARY:
9512
+ {
9513
+ GGML_ASSERT(false); // not supported
9514
+ } break;
9039
9515
  case GGML_OP_NONE:
9040
9516
  {
9041
9517
  // nop
@@ -9126,7 +9602,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
9126
9602
  struct ggml_cgraph result = {
9127
9603
  /*.n_nodes =*/ 0,
9128
9604
  /*.n_leafs =*/ 0,
9129
- /*.n_threads =*/ 0,
9605
+ /*.n_threads =*/ GGML_DEFAULT_N_THREADS,
9130
9606
  /*.work_size =*/ 0,
9131
9607
  /*.work =*/ NULL,
9132
9608
  /*.nodes =*/ { NULL },
@@ -9442,6 +9918,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9442
9918
  node->n_tasks = n_threads;
9443
9919
  } break;
9444
9920
  case GGML_OP_CPY:
9921
+ case GGML_OP_CONT:
9445
9922
  case GGML_OP_RESHAPE:
9446
9923
  case GGML_OP_VIEW:
9447
9924
  case GGML_OP_PERMUTE:
@@ -9527,6 +10004,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9527
10004
 
9528
10005
  work_size = MAX(work_size, cur);
9529
10006
  } break;
10007
+ case GGML_OP_MAP_UNARY:
10008
+ case GGML_OP_MAP_BINARY:
10009
+ {
10010
+ node->n_tasks = 1;
10011
+ } break;
9530
10012
  case GGML_OP_NONE:
9531
10013
  {
9532
10014
  node->n_tasks = 1;
@@ -9745,8 +10227,8 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
9745
10227
 
9746
10228
  GGML_PRINT("=== GRAPH ===\n");
9747
10229
 
9748
- GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads);
9749
- GGML_PRINT_DEBUG("total work size = %zu bytes\n",cgraph->work_size);
10230
+ GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads);
10231
+ GGML_PRINT_DEBUG("total work size = %zu bytes\n", cgraph->work_size);
9750
10232
 
9751
10233
  GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes);
9752
10234
  for (int i = 0; i < cgraph->n_nodes; i++) {