llama_cpp 0.0.3 → 0.0.4

This diff shows the changes between publicly released versions of the package as they appear in the supported public registries. It is provided for informational purposes only.
@@ -1,4 +1,4 @@
1
- // Defines CLOCK_MONOTONIC and asprintf on Linux
1
+ // Defines CLOCK_MONOTONIC on Linux
2
2
  #define _GNU_SOURCE
3
3
 
4
4
  #include "ggml.h"
@@ -26,14 +26,9 @@
26
26
  #define static_assert(cond, msg) struct global_scope_noop_trick
27
27
  #endif
28
28
 
29
- #if defined _MSC_VER || defined(__MINGW32__)
29
+ #if defined(_WIN32)
30
30
 
31
- #if !defined(__MINGW32__)
32
- #include <Windows.h>
33
- #else
34
- // ref: https://github.com/ggerganov/whisper.cpp/issues/168
35
31
  #include <windows.h>
36
- #endif
37
32
 
38
33
  typedef volatile LONG atomic_int;
39
34
  typedef atomic_int atomic_bool;
@@ -55,6 +50,7 @@ typedef HANDLE pthread_t;
55
50
 
56
51
  typedef DWORD thread_ret_t;
57
52
  static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
53
+ (void) unused;
58
54
  HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
59
55
  if (handle == NULL)
60
56
  {
@@ -66,6 +62,7 @@ static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void
66
62
  }
67
63
 
68
64
  static int pthread_join(pthread_t thread, void* unused) {
65
+ (void) unused;
69
66
  return (int) WaitForSingleObject(thread, INFINITE);
70
67
  }
71
68
 
@@ -97,17 +94,6 @@ typedef void* thread_ret_t;
97
94
  #define static_assert(cond, msg) _Static_assert(cond, msg)
98
95
  #endif
99
96
 
100
- #define GGML_MLOCK_SUPPORT 0
101
-
102
- #ifdef __has_include
103
- #if __has_include(<sys/mman.h>)
104
- #undef GGML_MLOCK_SUPPORT
105
- #define GGML_MLOCK_SUPPORT 1
106
- #include <sys/mman.h>
107
- #endif
108
- #endif
109
-
110
-
111
97
  /*#define GGML_PERF*/
112
98
  #define GGML_DEBUG 0
113
99
  #define GGML_GELU_FP16
@@ -128,6 +114,14 @@ typedef void* thread_ret_t;
128
114
  #define GGML_MEM_ALIGN 16
129
115
  #endif
130
116
 
117
+ #if defined(_MSC_VER) || defined(__MINGW32__)
118
+ #define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
119
+ #define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr)
120
+ #else
121
+ #define GGML_ALIGNED_MALLOC(size) aligned_alloc(GGML_MEM_ALIGN, size)
122
+ #define GGML_ALIGNED_FREE(ptr) free(ptr)
123
+ #endif
124
+
131
125
  #define UNUSED(x) (void)(x)
132
126
  #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
133
127
 
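
Note: the new GGML_ALIGNED_MALLOC/GGML_ALIGNED_FREE pair replaces plain malloc/free for the context buffer (see the ggml_init hunk further down). On Windows, memory obtained from _aligned_malloc must be released with _aligned_free rather than free, which is why both macros are switched together; the C11 aligned_alloc used on other platforms requires the size to be a multiple of the alignment, which the later ggml_init change takes care of. A minimal standalone sketch of the same pattern, not taken from the package:

    #include <stdlib.h>

    #define MEM_ALIGN 16

    #if defined(_MSC_VER) || defined(__MINGW32__)
    #include <malloc.h>
    #define ALIGNED_MALLOC(size) _aligned_malloc(size, MEM_ALIGN)
    #define ALIGNED_FREE(ptr)    _aligned_free(ptr)
    #else
    /* C11 aligned_alloc: size must be a multiple of the alignment */
    #define ALIGNED_MALLOC(size) aligned_alloc(MEM_ALIGN, size)
    #define ALIGNED_FREE(ptr)    free(ptr)
    #endif

    int main(void) {
        char * buf = ALIGNED_MALLOC(1024);   /* 1024 is already a multiple of 16 */
        buf[0] = 0;
        ALIGNED_FREE(buf);                   /* must match the allocator used above */
        return 0;
    }
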
@@ -242,12 +236,12 @@ static inline float fp32_from_bits(uint32_t w) {
242
236
  }
243
237
 
244
238
  static inline uint32_t fp32_to_bits(float f) {
245
- union {
246
- float as_value;
247
- uint32_t as_bits;
248
- } fp32;
249
- fp32.as_value = f;
250
- return fp32.as_bits;
239
+ union {
240
+ float as_value;
241
+ uint32_t as_bits;
242
+ } fp32;
243
+ fp32.as_value = f;
244
+ return fp32.as_bits;
251
245
  }
252
246
 
253
247
  static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
@@ -497,6 +491,77 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
497
491
  }
498
492
  #endif
499
493
 
494
+ #if __ARM_NEON
495
+
496
+ #if !defined(__aarch64__)
497
+
498
+ inline static uint16_t vaddvq_u8(uint8x16_t v) {
499
+ return
500
+ (uint16_t)vgetq_lane_u8(v, 0) + (uint16_t)vgetq_lane_u8(v, 1) +
501
+ (uint16_t)vgetq_lane_u8(v, 2) + (uint16_t)vgetq_lane_u8(v, 3) +
502
+ (uint16_t)vgetq_lane_u8(v, 4) + (uint16_t)vgetq_lane_u8(v, 5) +
503
+ (uint16_t)vgetq_lane_u8(v, 6) + (uint16_t)vgetq_lane_u8(v, 7) +
504
+ (uint16_t)vgetq_lane_u8(v, 8) + (uint16_t)vgetq_lane_u8(v, 9) +
505
+ (uint16_t)vgetq_lane_u8(v, 10) + (uint16_t)vgetq_lane_u8(v, 11) +
506
+ (uint16_t)vgetq_lane_u8(v, 12) + (uint16_t)vgetq_lane_u8(v, 13) +
507
+ (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
508
+ }
509
+
510
+ inline static int32_t vaddvq_s16(int16x8_t v) {
511
+ return
512
+ (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
513
+ (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
514
+ (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
515
+ (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
516
+ }
517
+
518
+ inline static uint32_t vaddvq_u16(uint16x8_t v) {
519
+ return
520
+ (uint32_t)vgetq_lane_u16(v, 0) + (uint32_t)vgetq_lane_u16(v, 1) +
521
+ (uint32_t)vgetq_lane_u16(v, 2) + (uint32_t)vgetq_lane_u16(v, 3) +
522
+ (uint32_t)vgetq_lane_u16(v, 4) + (uint32_t)vgetq_lane_u16(v, 5) +
523
+ (uint32_t)vgetq_lane_u16(v, 6) + (uint32_t)vgetq_lane_u16(v, 7);
524
+ }
525
+
526
+ inline static int32_t vaddvq_s32(int32x4_t v) {
527
+ return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
528
+ }
529
+
530
+ inline static float vaddvq_f32(float32x4_t v) {
531
+ return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
532
+ }
533
+
534
+ inline float vminvq_f32(float32x4_t v) {
535
+ return
536
+ MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
537
+ MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
538
+ }
539
+
540
+ inline float vmaxvq_f32(float32x4_t v) {
541
+ return
542
+ MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
543
+ MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
544
+ }
545
+
546
+ inline int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
547
+ return vget_low_s8(vcombine_s8(a, b));
548
+ }
549
+
550
+ inline int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
551
+ return vget_high_s8(vcombine_s8(a, b));
552
+ }
553
+
554
+ inline uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
555
+ return vget_low_u8(vcombine_u8(a, b));
556
+ }
557
+
558
+ inline uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
559
+ return vget_high_u8(vcombine_u8(a, b));
560
+ }
561
+
562
+ #endif
563
+ #endif
564
+
500
565
  // method 5
501
566
  // blocks of QK elements
502
567
  // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
@@ -610,10 +675,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
610
675
  for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
611
676
  for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
612
677
 
613
- // absolute max
614
- const float amax = MAX(
615
- MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)),
616
- MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
678
+ const float amax = vmaxvq_f32(amaxv[0]);
617
679
 
618
680
  const float d = amax / ((1 << 3) - 1);
619
681
  const float id = d ? 1.0f/d : 0.0f;
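
Note: quantize_row_q4_0 now finds the block's absolute maximum with a single vmaxvq_f32 instead of the hand-written MAX tree over vgetq_lane_f32. The large block added in the hunk above supplies scalar fallbacks for vmaxvq_f32, vaddvq_* and the vzip helpers, so the same call also compiles on 32-bit ARM, where only AArch64 provides these horizontal reductions natively. A plain-C sketch of the equivalence (arrays stand in for float32x4_t; illustrative only):

    #include <stdio.h>

    #define MAX(a, b) ((a) > (b) ? (a) : (b))

    /* The manual reduction that was removed from quantize_row_q4_0. */
    static float reduce_max_manual(const float v[4]) {
        return MAX(MAX(v[0], v[1]), MAX(v[2], v[3]));
    }

    /* What vmaxvq_f32 (or its new fallback) computes: the lane-wise maximum. */
    static float reduce_max_horizontal(const float v[4]) {
        float m = v[0];
        for (int i = 1; i < 4; i++) m = MAX(m, v[i]);
        return m;
    }

    int main(void) {
        const float amax[4] = { 0.5f, 2.25f, 1.0f, 1.75f };
        printf("%f %f\n", reduce_max_manual(amax), reduce_max_horizontal(amax)); /* both 2.250000 */
        return 0;
    }
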
@@ -935,7 +997,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
935
997
  float32x4_t minv[8];
936
998
  float32x4_t maxv[8];
937
999
 
938
- for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
1000
+ for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK + 4*l);
939
1001
 
940
1002
  for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
941
1003
  for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
@@ -958,7 +1020,8 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
958
1020
 
959
1021
  for (int l = 0; l < 8; l++) {
960
1022
  const float32x4_t v = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id);
961
- const int32x4_t vi = vcvtq_s32_f32(v);
1023
+ const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(0.5f)); // needed to round to nearest
1024
+ const int32x4_t vi = vcvtq_s32_f32(vf);
962
1025
 
963
1026
  y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
964
1027
  y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
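
Note: vcvtq_s32_f32 truncates toward zero, so the previous conversion always rounded the quantized value down. Adding 0.5f first turns the conversion into round-to-nearest for the non-negative values produced here (x minus the block minimum, scaled by id). A scalar illustration of the difference, standalone and not package code:

    #include <stdio.h>

    int main(void) {
        const float v = 7.9f;                    /* a scaled, non-negative quantization value */
        const int truncated = (int)(v);          /* 7, what vcvtq_s32_f32 does on its own */
        const int rounded   = (int)(v + 0.5f);   /* 8, the new round-to-nearest behaviour */
        printf("%d %d\n", truncated, rounded);
        return 0;
    }
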
@@ -1226,15 +1289,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1226
1289
  #define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
1227
1290
  #define GGML_F32x4_ADD vaddq_f32
1228
1291
  #define GGML_F32x4_MUL vmulq_f32
1229
- #if defined(__ARM_FEATURE_QRDMX)
1230
- #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
1231
- #else
1232
- #define GGML_F32x4_REDUCE_ONE(x) \
1233
- (vgetq_lane_f32(x, 0) + \
1234
- vgetq_lane_f32(x, 1) + \
1235
- vgetq_lane_f32(x, 2) + \
1236
- vgetq_lane_f32(x, 3))
1237
- #endif
1292
+ #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
1238
1293
  #define GGML_F32x4_REDUCE(res, x) \
1239
1294
  { \
1240
1295
  for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
@@ -1857,55 +1912,43 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1857
1912
  // 4-bit -> 8-bit
1858
1913
  const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
1859
1914
  const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
1860
-
1861
1915
  const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
1862
1916
  const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
1863
1917
 
1864
1918
  const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
1865
1919
  const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
1866
-
1867
1920
  const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
1868
1921
  const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
1869
1922
 
1870
1923
  // sub 8
1871
1924
  const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
1872
1925
  const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
1873
-
1874
1926
  const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
1875
1927
  const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
1876
1928
 
1877
1929
  const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
1878
1930
  const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
1879
-
1880
1931
  const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
1881
1932
  const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
1882
1933
 
1883
1934
  #if defined(__ARM_FEATURE_DOTPROD)
1884
- // dot product into int16x8_t
1935
+ // dot product into int32x4_t
1885
1936
  int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
1886
1937
  int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
1887
1938
 
1888
1939
  p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
1889
1940
  p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
1890
1941
 
1891
- // scalar
1892
- #if defined(__ARM_FEATURE_QRDMX)
1893
- sum0 += x0->d * y0->d * vaddvq_s32(p_0);
1894
- sum1 += x1->d * y1->d * vaddvq_s32(p_1);
1942
+ sum0 += x0->d*y0->d*vaddvq_s32(p_0);
1943
+ sum1 += x1->d*y1->d*vaddvq_s32(p_1);
1895
1944
  #else
1896
- sum0 += x0->d * y0->d * (vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
1897
- sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
1898
- #endif
1899
- #else
1900
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
1945
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
1901
1946
  const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
1902
-
1903
1947
  const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
1904
1948
  const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
1905
1949
 
1906
1950
  const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
1907
1951
  const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
1908
-
1909
1952
  const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
1910
1953
  const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
1911
1954
 
@@ -1918,14 +1961,8 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1918
1961
  const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
1919
1962
  const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
1920
1963
 
1921
- // scalar
1922
- #if defined(__ARM_FEATURE_QRDMX)
1923
- sum0 += x0->d * y0->d * vaddvq_s16(p_0);
1924
- sum1 += x1->d * y1->d * vaddvq_s16(p_1);
1925
- #else
1926
- sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
1927
- sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
1928
- #endif
1964
+ sum0 += x0->d*y0->d*vaddvq_s16(p_0);
1965
+ sum1 += x1->d*y1->d*vaddvq_s16(p_1);
1929
1966
  #endif
1930
1967
  }
1931
1968
 
@@ -1962,7 +1999,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1962
1999
  // Initialize accumulator with zeros
1963
2000
  __m256 acc = _mm256_setzero_ps();
1964
2001
 
1965
- /* Prepare the constants we will need during execution */
2002
+ /* Prepare the constants we will need during execution */
1966
2003
  const __m256i lowMask = _mm256_set1_epi8( 0xF );
1967
2004
  const __m256i offset_8 = _mm256_set1_epi16( 8 );
1968
2005
 
@@ -1972,61 +2009,59 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1972
2009
 
1973
2010
  // Main loop
1974
2011
  for (int i = 0; i < nb; i+=UNROLL_COUNT) {
1975
-
1976
- // This loop will be unrolled by the compiler
2012
+ // This loop will be unrolled by the compiler
1977
2013
  for (int u=0;u<UNROLL_COUNT;u++) {
1978
- /* Compute combined scale for the block */
1979
- const __m256 scale = _mm256_mul_ps(
1980
- _mm256_broadcast_ss( &x[i+u].d ),
1981
- _mm256_broadcast_ss( &y[i+u].d ) );
1982
-
1983
- /* get input from x
1984
- Input: 32 Nibbles (16 bytes) at *x[i+u]
1985
- Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
1986
-
1987
- /* Load 16 bytes from memory */
1988
- const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
1989
- /* Expand bytes into uint16_t values */
1990
- const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
2014
+ /* Compute combined scale for the block */
2015
+ const __m256 scale = _mm256_mul_ps(
2016
+ _mm256_broadcast_ss( &x[i+u].d ),
2017
+ _mm256_broadcast_ss( &y[i+u].d ) );
2018
+
2019
+ /* get input from x
2020
+ Input: 32 Nibbles (16 bytes) at *x[i+u]
2021
+ Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
2022
+
2023
+ /* Load 16 bytes from memory */
2024
+ const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
2025
+ /* Expand bytes into uint16_t values */
2026
+ const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
1991
2027
  /* Unpack values into individual bytes */
1992
2028
  __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
1993
2029
  const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
1994
- __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
2030
+ __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
1995
2031
  /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
1996
- x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
1997
- x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
2032
+ x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
2033
+ x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
1998
2034
 
1999
- /* get input from y
2000
- Input: 32 Nibbles (16 bytes) at *y[i+u]
2001
- Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
2035
+ /* get input from y
2036
+ Input: 32 Nibbles (16 bytes) at *y[i+u]
2037
+ Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
2002
2038
 
2003
- /* Load 16 bytes from memory */
2004
- const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
2005
- /* Expand bytes into uint16_t values */
2006
- const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
2039
+ /* Load 16 bytes from memory */
2040
+ const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
2041
+ /* Expand bytes into uint16_t values */
2042
+ const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
2007
2043
  /* Unpack values into individual bytes */
2008
- const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
2009
- __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
2010
- __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
2044
+ const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
2045
+ __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
2046
+ __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
2011
2047
  /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
2012
- y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
2013
- y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
2048
+ y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
2049
+ y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
2014
2050
 
2015
- /* Compute products of int16_t integers, add pairwise, store as int32_t */
2016
- __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
2017
- __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
2051
+ /* Compute products of int16_t integers, add pairwise, store as int32_t */
2052
+ __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
2053
+ __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
2018
2054
 
2019
- /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
2020
- __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
2055
+ /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
2056
+ __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
2021
2057
 
2022
- /* Convert to vectore of 8 int32_t to 8 floats */
2023
- __m256 q = _mm256_cvtepi32_ps( xy_q );
2058
+ /* Convert to vectore of 8 int32_t to 8 floats */
2059
+ __m256 q = _mm256_cvtepi32_ps( xy_q );
2024
2060
 
2025
- /* Multiply q with scale and accumulate */
2026
- acc = _mm256_fmadd_ps( scale, q, acc );
2061
+ /* Multiply q with scale and accumulate */
2062
+ acc = _mm256_fmadd_ps( scale, q, acc );
2027
2063
  }
2028
-
2029
- }
2064
+ }
2030
2065
 
2031
2066
  // Return horizontal sum of the acc vector
2032
2067
  __m128 res = _mm256_extractf128_ps( acc, 1 );
@@ -2087,18 +2122,18 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2087
2122
  float sum1 = 0.0f;
2088
2123
 
2089
2124
  for (int i = 0; i < nb; i += 2) {
2090
- const block_q4_0 * restrict x0 = &px[i + 0];
2091
- const block_q4_0 * restrict y0 = &py[i + 0];
2092
- const block_q4_0 * restrict x1 = &px[i + 1];
2093
- const block_q4_0 * restrict y1 = &py[i + 1];
2125
+ const block_q4_0 * restrict x0 = &x[i + 0];
2126
+ const block_q4_0 * restrict y0 = &y[i + 0];
2127
+ const block_q4_0 * restrict x1 = &x[i + 1];
2128
+ const block_q4_0 * restrict y1 = &y[i + 1];
2094
2129
 
2095
2130
  const v128_t m4b = wasm_u8x16_splat(0xf);
2096
2131
  const v128_t s8b = wasm_i8x16_splat(0x8);
2097
2132
 
2098
- const v128_t v0_0 = wasm_v128_load(x0.qs);
2099
- const v128_t v0_1 = wasm_v128_load(y0.qs);
2100
- const v128_t v1_0 = wasm_v128_load(x1.qs);
2101
- const v128_t v1_1 = wasm_v128_load(y1.qs);
2133
+ const v128_t v0_0 = wasm_v128_load(x0->qs);
2134
+ const v128_t v0_1 = wasm_v128_load(y0->qs);
2135
+ const v128_t v1_0 = wasm_v128_load(x1->qs);
2136
+ const v128_t v1_1 = wasm_v128_load(y1->qs);
2102
2137
 
2103
2138
  // 4-bit -> 8-bit
2104
2139
  const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
@@ -2170,18 +2205,20 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2170
2205
  const uint8_t * restrict p0 = x[i].qs;
2171
2206
  const uint8_t * restrict p1 = y[i].qs;
2172
2207
 
2208
+ int sumi = 0;
2173
2209
  for (int j = 0; j < QK/2; j++) {
2174
2210
  const uint8_t v0 = p0[j];
2175
2211
  const uint8_t v1 = p1[j];
2176
2212
 
2177
- const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
2178
- const float f1 = d0*((int8_t) (v0 >> 4) - 8);
2213
+ const int8_t i0 = (int8_t) (v0 & 0xf) - 8;
2214
+ const int8_t i1 = (int8_t) (v0 >> 4) - 8;
2179
2215
 
2180
- const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
2181
- const float f3 = d1*((int8_t) (v1 >> 4) - 8);
2216
+ const int8_t i2 = (int8_t) (v1 & 0xf) - 8;
2217
+ const int8_t i3 = (int8_t) (v1 >> 4) - 8;
2182
2218
 
2183
- sumf += f0*f2 + f1*f3;
2219
+ sumi += i0*i2 + i1*i3;
2184
2220
  }
2221
+ sumf += d0 * d1 * sumi;
2185
2222
  }
2186
2223
  #endif
2187
2224
 
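
Note: the scalar fallback of ggml_vec_dot_q4_0 now accumulates the products of the de-offset 4-bit values in an integer sumi and applies the block scales d0*d1 once per block, instead of converting every nibble to float and multiplying by the scales inside the inner loop. The result is the same because the scales are constant within a block: sum_j (d0*q0_j)*(d1*q1_j) = d0*d1 * sum_j q0_j*q1_j. A reduced sketch of the two formulations, with hypothetical values:

    #include <stdio.h>

    int main(void) {
        const float d0 = 0.05f, d1 = 0.02f;          /* per-block scales */
        const int   q0[4] = { 3, -8, 7, 1 };         /* de-offset 4-bit values */
        const int   q1[4] = { -2, 5, 6, -7 };

        float sumf_old = 0.0f;
        int   sumi_new = 0;
        for (int j = 0; j < 4; j++) {
            sumf_old += (d0*q0[j]) * (d1*q1[j]);     /* old: float math per element */
            sumi_new += q0[j]*q1[j];                 /* new: integer accumulation */
        }
        const float sumf_new = d0 * d1 * sumi_new;   /* scale once per block */
        printf("%f %f\n", sumf_old, sumf_new);       /* equal up to float rounding */
        return 0;
    }
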
@@ -2273,36 +2310,71 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2273
2310
  float sum10 = 0.0f;
2274
2311
  float sum11 = 0.0f;
2275
2312
 
2276
- for (int i = 0; i < nb; ++i) {
2313
+ for (int i = 0; i < nb; i += 2) {
2277
2314
  const block_q4_1 * restrict x0 = &x[i + 0];
2278
2315
  const block_q4_1 * restrict y0 = &y[i + 0];
2316
+ const block_q4_1 * restrict x1 = &x[i + 1];
2317
+ const block_q4_1 * restrict y1 = &y[i + 1];
2279
2318
 
2280
2319
  const uint8x16_t m4b = vdupq_n_u8(0xf);
2281
2320
 
2282
2321
  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2283
2322
  const uint8x16_t v1_0 = vld1q_u8(y0->qs);
2323
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2324
+ const uint8x16_t v1_1 = vld1q_u8(y1->qs);
2284
2325
 
2285
- // and with 0xf
2326
+ // 4-bit -> 8-bit
2286
2327
  const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
2287
2328
  const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
2288
-
2289
2329
  const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
2290
2330
  const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
2291
2331
 
2292
- // dot product into uint16x8_t
2332
+ const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
2333
+ const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
2334
+ const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
2335
+ const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
2336
+
2337
+ sum00 += x0->m*y0->m;
2338
+ sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
2339
+ sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
2340
+
2341
+ sum00 += x1->m*y1->m;
2342
+ sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
2343
+ sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
2344
+
2345
+ #if defined(__ARM_FEATURE_DOTPROD)
2346
+ // dot product into int32x4_t
2347
+ uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
2348
+ uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
2349
+
2350
+ p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
2351
+ p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
2352
+
2353
+ sum11 += x0->d*y0->d*vaddvq_u32(p_0);
2354
+ sum11 += x1->d*y1->d*vaddvq_u32(p_1);
2355
+ #else
2293
2356
  const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
2294
2357
  const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
2295
-
2296
2358
  const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
2297
2359
  const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
2298
2360
 
2299
- const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
2300
- const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
2361
+ const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
2362
+ const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
2363
+ const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
2364
+ const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
2301
2365
 
2302
- sum00 += x0->m*y0->m;
2303
- sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
2304
- sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
2305
- sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
2366
+ const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
2367
+ const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
2368
+
2369
+ const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
2370
+ const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
2371
+
2372
+ const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
2373
+ const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
2374
+
2375
+ sum11 += x0->d*y0->d*vaddvq_u16(p_0);
2376
+ sum11 += x1->d*y1->d*vaddvq_u16(p_1);
2377
+ #endif
2306
2378
  }
2307
2379
 
2308
2380
  sumf = QK*sum00 + sum01 + sum10 + sum11;
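
Note: for Q4_1 each block stores a scale d and a minimum m, and elements are reconstructed as x_j = d*q_j + m. The dot product of two blocks therefore expands to sum_j (d_x*q_xj + m_x)*(d_y*q_yj + m_y) = QK*m_x*m_y + m_y*d_x*sum(q_x) + m_x*d_y*sum(q_y) + d_x*d_y*sum(q_x*q_y), which is exactly what sum00, sum01, sum10 and sum11 accumulate before the final line combines them. The rewritten loop processes two blocks per iteration and, when the ARM dot-product extension is available, computes the q*q term with vdotq_u32. A scalar check of the identity, using hypothetical values and a shrunken block size:

    #include <stdio.h>

    #define QK 4   /* 32 in the real code; shrunk for the sketch */

    int main(void) {
        const float dx = 0.1f, mx = -0.3f, dy = 0.2f, my = 0.5f;
        const int qx[QK] = { 1, 7, 0, 15 };
        const int qy[QK] = { 3, 2, 9, 4 };

        float direct = 0.0f;
        int sum_qx = 0, sum_qy = 0, sum_qq = 0;
        for (int j = 0; j < QK; j++) {
            direct += (dx*qx[j] + mx) * (dy*qy[j] + my);
            sum_qx += qx[j];
            sum_qy += qy[j];
            sum_qq += qx[j]*qy[j];
        }
        const float decomposed = QK*mx*my + my*dx*sum_qx + mx*dy*sum_qy + dx*dy*sum_qq;
        printf("%f %f\n", direct, decomposed);   /* equal up to float rounding */
        return 0;
    }
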
@@ -2578,29 +2650,38 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
2578
2650
  //
2579
2651
 
2580
2652
  static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
2581
- QK,
2582
- QK,
2583
- 1,
2584
- 1,
2585
- 1,
2586
- 1,
2587
- 1,
2653
+ [GGML_TYPE_F32] = 1,
2654
+ [GGML_TYPE_F16] = 1,
2655
+ [GGML_TYPE_Q4_0] = QK,
2656
+ [GGML_TYPE_Q4_1] = QK,
2657
+ [GGML_TYPE_I8] = 1,
2658
+ [GGML_TYPE_I16] = 1,
2659
+ [GGML_TYPE_I32] = 1,
2588
2660
  };
2589
-
2590
- static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
2661
+ static_assert(GGML_TYPE_COUNT == 7, "GGML_BLCK_SIZE is outdated");
2591
2662
 
2592
2663
  static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
2593
- sizeof(block_q4_0),
2594
- sizeof(block_q4_1),
2595
- sizeof(int8_t ),
2596
- sizeof(int16_t),
2597
- sizeof(int32_t),
2598
- sizeof(ggml_fp16_t),
2599
- sizeof(float ),
2664
+ [GGML_TYPE_F32] = sizeof(float),
2665
+ [GGML_TYPE_F16] = sizeof(ggml_fp16_t),
2666
+ [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
2667
+ [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
2668
+ [GGML_TYPE_I8] = sizeof(int8_t),
2669
+ [GGML_TYPE_I16] = sizeof(int16_t),
2670
+ [GGML_TYPE_I32] = sizeof(int32_t),
2600
2671
  };
2672
+ static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_SIZE is outdated");
2673
+
2601
2674
 
2602
- // don't forget to update the array above when adding new types
2603
- static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
2675
+ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
2676
+ [GGML_TYPE_F32] = "f32",
2677
+ [GGML_TYPE_F16] = "f16",
2678
+ [GGML_TYPE_Q4_0] = "q4_0",
2679
+ [GGML_TYPE_Q4_1] = "q4_1",
2680
+ [GGML_TYPE_I8] = "i8",
2681
+ [GGML_TYPE_I16] = "i16",
2682
+ [GGML_TYPE_I32] = "i32",
2683
+ };
2684
+ static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_NAME is outdated");
2604
2685
 
2605
2686
  static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
2606
2687
  "NONE",
@@ -2629,6 +2710,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
2629
2710
 
2630
2711
  "SCALE",
2631
2712
  "CPY",
2713
+ "CONT",
2632
2714
  "RESHAPE",
2633
2715
  "VIEW",
2634
2716
  "PERMUTE",
@@ -2642,9 +2724,12 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
2642
2724
 
2643
2725
  "FLASH_ATTN",
2644
2726
  "FLASH_FF",
2727
+
2728
+ "MAP_UNARY",
2729
+ "MAP_BINARY",
2645
2730
  };
2646
2731
 
2647
- static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
2732
+ static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
2648
2733
 
2649
2734
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2650
2735
  "none",
@@ -2673,6 +2758,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2673
2758
 
2674
2759
  "x*v",
2675
2760
  "x-\\>y",
2761
+ "cont(x)",
2676
2762
  "reshape(x)",
2677
2763
  "view(x)",
2678
2764
  "permute(x)",
@@ -2686,24 +2772,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2686
2772
 
2687
2773
  "flash_attn(x)",
2688
2774
  "flash_ff(x)",
2689
- };
2690
-
2691
- static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
2692
-
2693
- //
2694
- // ggml object
2695
- //
2696
-
2697
- struct ggml_object {
2698
- size_t offs;
2699
- size_t size;
2700
2775
 
2701
- struct ggml_object * next;
2702
-
2703
- char padding[8];
2776
+ "f(x)",
2777
+ "f(x,y)",
2704
2778
  };
2705
2779
 
2706
- static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
2780
+ static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
2707
2781
 
2708
2782
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
2709
2783
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -2716,7 +2790,6 @@ struct ggml_context {
2716
2790
  size_t mem_size;
2717
2791
  void * mem_buffer;
2718
2792
  bool mem_buffer_owned;
2719
- bool mem_buffer_mlocked;
2720
2793
  bool no_alloc;
2721
2794
 
2722
2795
  int n_objects;
@@ -2834,6 +2907,11 @@ float ggml_type_sizef(enum ggml_type type) {
2834
2907
  return ((float)(GGML_TYPE_SIZE[type]))/GGML_BLCK_SIZE[type];
2835
2908
  }
2836
2909
 
2910
+ const char * ggml_type_name(enum ggml_type type) {
2911
+ return GGML_TYPE_NAME[type];
2912
+ }
2913
+
2914
+
2837
2915
  size_t ggml_element_size(const struct ggml_tensor * tensor) {
2838
2916
  return GGML_TYPE_SIZE[tensor->type];
2839
2917
  }
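
Note: GGML_BLCK_SIZE, GGML_TYPE_SIZE and the new GGML_TYPE_NAME table are now written with designated initializers, so each entry is keyed by its enum value and the static_asserts name the exact table that is stale when GGML_TYPE_COUNT changes; ggml_type_name simply indexes the name table. A possible caller-side use, sketched under the assumption that a tensor has been obtained elsewhere:

    #include <stdio.h>
    #include "ggml.h"

    /* Print a short description of a tensor's element type. For quantized
       types ggml_element_size reports the size of one block rather than one
       scalar element, as the tables above show. */
    static void print_tensor_type(const struct ggml_tensor * t) {
        printf("type=%s element_size=%zu\n",
               ggml_type_name(t->type),   /* e.g. "q4_0" */
               ggml_element_size(t));
    }
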
@@ -2999,11 +3077,12 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
2999
3077
  return NULL;
3000
3078
  }
3001
3079
 
3080
+ const size_t mem_size = (params.mem_size + GGML_MEM_ALIGN - 1) & ~(GGML_MEM_ALIGN - 1);
3081
+
3002
3082
  *ctx = (struct ggml_context) {
3003
- /*.mem_size =*/ params.mem_size,
3004
- /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
3083
+ /*.mem_size =*/ mem_size,
3084
+ /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size),
3005
3085
  /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
3006
- /*.mem_buffer_mlocked =*/ false,
3007
3086
  /*.no_alloc =*/ params.no_alloc,
3008
3087
  /*.n_objects =*/ 0,
3009
3088
  /*.objects_begin =*/ NULL,
@@ -3012,7 +3091,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3012
3091
  /*.scratch_save =*/ { 0, 0, NULL, },
3013
3092
  };
3014
3093
 
3015
- GGML_ASSERT(ctx->mem_buffer != NULL); // check for allocation failure
3094
+ GGML_ASSERT(ctx->mem_buffer != NULL);
3016
3095
 
3017
3096
  ggml_assert_aligned(ctx->mem_buffer);
3018
3097
 
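
Note: ggml_init now rounds the requested mem_size up to the next multiple of GGML_MEM_ALIGN before allocating, which matters for the non-Windows branch of GGML_ALIGNED_MALLOC because C11 aligned_alloc requires the size to be a multiple of the alignment. The expression (x + A - 1) & ~(A - 1) is the usual power-of-two round-up; for A = 16 a request of 1000 bytes becomes 1008. A standalone check of the formula, illustrative only:

    #include <assert.h>
    #include <stddef.h>

    #define MEM_ALIGN ((size_t)16)   /* must be a power of two */

    static size_t round_up(size_t x) {
        return (x + MEM_ALIGN - 1) & ~(MEM_ALIGN - 1);
    }

    int main(void) {
        assert(round_up(1000) == 1008);
        assert(round_up(1008) == 1008);   /* already aligned: unchanged */
        assert(round_up(1)    == 16);
        return 0;
    }
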
@@ -3036,16 +3115,8 @@ void ggml_free(struct ggml_context * ctx) {
3036
3115
  GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
3037
3116
  __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
3038
3117
 
3039
- #if GGML_MLOCK_SUPPORT
3040
- if (ctx->mem_buffer_mlocked) {
3041
- if (munlock(ctx->mem_buffer, ctx->mem_size)) {
3042
- fprintf(stderr, "%s: failed to munlock buffer: %s\n", __func__, strerror(errno));
3043
- }
3044
- }
3045
- #endif
3046
-
3047
3118
  if (ctx->mem_buffer_owned) {
3048
- free(ctx->mem_buffer);
3119
+ GGML_ALIGNED_FREE(ctx->mem_buffer);
3049
3120
  }
3050
3121
 
3051
3122
  found = true;
@@ -3072,48 +3143,6 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch)
3072
3143
  return result;
3073
3144
  }
3074
3145
 
3075
- #ifdef __APPLE__
3076
- #define MLOCK_SUGGESTION \
3077
- "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
3078
- "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n"
3079
- #else
3080
- #define MLOCK_SUGGESTION \
3081
- "Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n"
3082
- #endif
3083
-
3084
- bool ggml_mlock_supported(void) {
3085
- return GGML_MLOCK_SUPPORT;
3086
- }
3087
-
3088
- bool ggml_mlock(
3089
- struct ggml_context * ctx,
3090
- const void *opt_extra_addr,
3091
- size_t opt_extra_len,
3092
- char **err_p) {
3093
- // TODO: Use SetProcessWorkingSetSize() + VirtualLock() on WIN32
3094
- #if GGML_MLOCK_SUPPORT
3095
- if (ctx->mem_buffer_mlocked) {
3096
- return true;
3097
- }
3098
- if (mlock(ctx->mem_buffer, ctx->mem_size) ||
3099
- (opt_extra_len &&
3100
- mlock(opt_extra_addr, opt_extra_len))) {
3101
- if ((*err_p = malloc(1024))) {
3102
- snprintf(*err_p, 1024,
3103
- "failed to mlock %zu-byte buffer: %s\n" MLOCK_SUGGESTION,
3104
- ctx->mem_size + opt_extra_len,
3105
- strerror(errno));
3106
- }
3107
- return false;
3108
- }
3109
- ctx->mem_buffer_mlocked = true;
3110
- return true;
3111
- #else // GGML_MLOCK_SUPPORT
3112
- *err_p = strdup("can't mlock because it's not supported on this system");
3113
- return false;
3114
- #endif // GGML_MLOCK_SUPPORT
3115
- }
3116
-
3117
3146
  ////////////////////////////////////////////////////////////////////////////////
3118
3147
 
3119
3148
  struct ggml_tensor * ggml_new_tensor_impl(
@@ -4388,6 +4417,41 @@ struct ggml_tensor * ggml_cpy_inplace(
4388
4417
  return ggml_cpy_impl(ctx, a, b, true);
4389
4418
  }
4390
4419
 
4420
+ // ggml_cont
4421
+
4422
+ struct ggml_tensor * ggml_cont_impl(
4423
+ struct ggml_context * ctx,
4424
+ struct ggml_tensor * a,
4425
+ bool inplace) {
4426
+ bool is_node = false;
4427
+
4428
+ if (!inplace && a->grad) {
4429
+ GGML_ASSERT(false); // TODO: implement backward
4430
+ is_node = true;
4431
+ }
4432
+
4433
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4434
+
4435
+ result->op = GGML_OP_CONT;
4436
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4437
+ result->src0 = a;
4438
+ result->src1 = NULL;
4439
+
4440
+ return result;
4441
+ }
4442
+
4443
+ struct ggml_tensor * ggml_cont(
4444
+ struct ggml_context * ctx,
4445
+ struct ggml_tensor * a) {
4446
+ return ggml_cont_impl(ctx, a, false);
4447
+ }
4448
+
4449
+ struct ggml_tensor * ggml_cont_inplace(
4450
+ struct ggml_context * ctx,
4451
+ struct ggml_tensor * a) {
4452
+ return ggml_cont_impl(ctx, a, true);
4453
+ }
4454
+
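
Note: ggml_cont adds an explicit "make contiguous" node. GGML_OP_CONT is dispatched to ggml_compute_forward_dup further down, which this release extends with fast paths for contiguous destinations, and its first user is the mul_mat backward pass near the end of this diff, where the non-contiguous view produced by ggml_transpose has to be materialized before ggml_mul_mat can consume it. A small usage sketch of that pattern, with hypothetical tensor names:

    #include "ggml.h"

    /* ggml_transpose only swaps strides, so the result is a view with
       non-contiguous rows; ggml_cont inserts a copy that lays the data out
       contiguously again before the matrix multiply. */
    static struct ggml_tensor * mul_by_transpose(struct ggml_context * ctx,
                                                 struct ggml_tensor  * a,
                                                 struct ggml_tensor  * b) {
        return ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, a)), b);
    }
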
4391
4455
  // ggml_reshape
4392
4456
 
4393
4457
  struct ggml_tensor * ggml_reshape(
@@ -4866,6 +4930,90 @@ struct ggml_tensor * ggml_flash_ff(
4866
4930
  return result;
4867
4931
  }
4868
4932
 
4933
+ // ggml_map_unary
4934
+
4935
+ struct ggml_tensor * ggml_map_unary_impl_f32(
4936
+ struct ggml_context * ctx,
4937
+ struct ggml_tensor * a,
4938
+ const ggml_unary_op_f32_t fun,
4939
+ bool inplace) {
4940
+ bool is_node = false;
4941
+
4942
+ if (!inplace && a->grad) {
4943
+ is_node = true;
4944
+ }
4945
+
4946
+ struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
4947
+ *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
4948
+ struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4949
+
4950
+ result->op = GGML_OP_MAP_UNARY;
4951
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4952
+ result->src0 = a;
4953
+ result->opt[0] = addr_tensor;
4954
+
4955
+ return result;
4956
+ }
4957
+
4958
+ struct ggml_tensor * ggml_map_unary_f32(
4959
+ struct ggml_context * ctx,
4960
+ struct ggml_tensor * a,
4961
+ const ggml_unary_op_f32_t fun) {
4962
+ return ggml_map_unary_impl_f32(ctx, a, fun, false);
4963
+ }
4964
+
4965
+ struct ggml_tensor * ggml_map_unary_inplace_f32(
4966
+ struct ggml_context * ctx,
4967
+ struct ggml_tensor * a,
4968
+ const ggml_unary_op_f32_t fun) {
4969
+ return ggml_map_unary_impl_f32(ctx, a, fun, true);
4970
+ }
4971
+
4972
+ // ggml_map_binary
4973
+
4974
+ struct ggml_tensor * ggml_map_binary_impl_f32(
4975
+ struct ggml_context * ctx,
4976
+ struct ggml_tensor * a,
4977
+ struct ggml_tensor * b,
4978
+ const ggml_binary_op_f32_t fun,
4979
+ bool inplace) {
4980
+ GGML_ASSERT(ggml_are_same_shape(a, b));
4981
+
4982
+ bool is_node = false;
4983
+
4984
+ if (!inplace && (a->grad || b->grad)) {
4985
+ is_node = true;
4986
+ }
4987
+
4988
+ struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
4989
+ *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
4990
+ struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4991
+
4992
+ result->op = GGML_OP_MAP_BINARY;
4993
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4994
+ result->src0 = a;
4995
+ result->src1 = b;
4996
+ result->opt[0] = addr_tensor;
4997
+
4998
+ return result;
4999
+ }
5000
+
5001
+ struct ggml_tensor * ggml_map_binary_f32(
5002
+ struct ggml_context * ctx,
5003
+ struct ggml_tensor * a,
5004
+ struct ggml_tensor * b,
5005
+ const ggml_binary_op_f32_t fun) {
5006
+ return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
5007
+ }
5008
+
5009
+ struct ggml_tensor * ggml_map_binary_inplace_f32(
5010
+ struct ggml_context * ctx,
5011
+ struct ggml_tensor * a,
5012
+ struct ggml_tensor * b,
5013
+ const ggml_binary_op_f32_t fun) {
5014
+ return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
5015
+ }
5016
+
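
Note: ggml_map_unary_f32 and ggml_map_binary_f32 let a caller run an arbitrary float elementwise function as a graph node. The function pointer is carried through the graph as the data of a small GGML_TYPE_I32 tensor stored in opt[0], and the forward implementations further down call it once per row. A possible usage sketch; the exact ggml_unary_op_f32_t typedef lives in the header, and the row-function signature below is inferred from how the forward pass invokes it, so treat it as an assumption:

    #include <math.h>
    #include "ggml.h"

    /* Elementwise x -> 1/(1+exp(-x)), applied to one row of n floats. */
    static void sigmoid_row(const int n, float * dst, const float * src) {
        for (int i = 0; i < n; i++) {
            dst[i] = 1.0f / (1.0f + expf(-src[i]));
        }
    }

    /* Build a graph node that applies sigmoid_row to every row of a. */
    static struct ggml_tensor * sigmoid(struct ggml_context * ctx, struct ggml_tensor * a) {
        return ggml_map_unary_f32(ctx, a, sigmoid_row);
    }
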
4869
5017
  ////////////////////////////////////////////////////////////////////////////////
4870
5018
 
4871
5019
  void ggml_set_param(
@@ -4930,6 +5078,85 @@ static void ggml_compute_forward_dup_f16(
4930
5078
 
4931
5079
  // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
4932
5080
 
5081
+ if (ggml_is_contiguous(dst)) {
5082
+ if (src0->nb[0] == sizeof(ggml_fp16_t)) {
5083
+ if (dst->type == GGML_TYPE_F16) {
5084
+ size_t id = 0;
5085
+ const size_t rs = ne00*nb00;
5086
+
5087
+ for (int i03 = 0; i03 < ne03; i03++) {
5088
+ for (int i02 = 0; i02 < ne02; i02++) {
5089
+ for (int i01 = 0; i01 < ne01; i01++) {
5090
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5091
+ char * dst_ptr = (char *) dst->data + id*rs;
5092
+
5093
+ memcpy(dst_ptr, src0_ptr, rs);
5094
+
5095
+ id++;
5096
+ }
5097
+ }
5098
+ }
5099
+ } else if (dst->type == GGML_TYPE_F32) {
5100
+ size_t id = 0;
5101
+ float * dst_ptr = (float *) dst->data;
5102
+
5103
+ for (int i03 = 0; i03 < ne03; i03++) {
5104
+ for (int i02 = 0; i02 < ne02; i02++) {
5105
+ for (int i01 = 0; i01 < ne01; i01++) {
5106
+ for (int i00 = 0; i00 < ne00; i00++) {
5107
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5108
+
5109
+ dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
5110
+ id++;
5111
+ }
5112
+ }
5113
+ }
5114
+ }
5115
+ } else {
5116
+ GGML_ASSERT(false); // TODO: implement
5117
+ }
5118
+ } else {
5119
+ //printf("%s: this is not optimal - fix me\n", __func__);
5120
+
5121
+ if (dst->type == GGML_TYPE_F32) {
5122
+ size_t id = 0;
5123
+ float * dst_ptr = (float *) dst->data;
5124
+
5125
+ for (int i03 = 0; i03 < ne03; i03++) {
5126
+ for (int i02 = 0; i02 < ne02; i02++) {
5127
+ for (int i01 = 0; i01 < ne01; i01++) {
5128
+ for (int i00 = 0; i00 < ne00; i00++) {
5129
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5130
+
5131
+ dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
5132
+ id++;
5133
+ }
5134
+ }
5135
+ }
5136
+ }
5137
+ } else if (dst->type == GGML_TYPE_F16) {
5138
+ size_t id = 0;
5139
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
5140
+
5141
+ for (int i03 = 0; i03 < ne03; i03++) {
5142
+ for (int i02 = 0; i02 < ne02; i02++) {
5143
+ for (int i01 = 0; i01 < ne01; i01++) {
5144
+ for (int i00 = 0; i00 < ne00; i00++) {
5145
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5146
+
5147
+ dst_ptr[id] = *src0_ptr;
5148
+ id++;
5149
+ }
5150
+ }
5151
+ }
5152
+ }
5153
+ } else {
5154
+ GGML_ASSERT(false); // TODO: implement
5155
+ }
5156
+ }
5157
+ return;
5158
+ }
5159
+
4933
5160
  // dst counters
4934
5161
  int64_t i10 = 0;
4935
5162
  int64_t i11 = 0;
@@ -5024,6 +5251,105 @@ static void ggml_compute_forward_dup_f32(
5024
5251
  return;
5025
5252
  }
5026
5253
 
5254
+ if (src0->type == dst->type &&
5255
+ src0->ne[0] == dst->ne[0] &&
5256
+ src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
5257
+ // copy by rows
5258
+ const size_t rs = ne00*nb00;
5259
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5260
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5261
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
5262
+ memcpy(
5263
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5264
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
5265
+ rs);
5266
+ }
5267
+ }
5268
+ }
5269
+ return;
5270
+ }
5271
+
5272
+ if (ggml_is_contiguous(dst)) {
5273
+ // TODO: simplify
5274
+ if (src0->nb[0] == sizeof(float)) {
5275
+ if (dst->type == GGML_TYPE_F32) {
5276
+ size_t id = 0;
5277
+ const size_t rs = ne00*nb00;
5278
+
5279
+ for (int i03 = 0; i03 < ne03; i03++) {
5280
+ for (int i02 = 0; i02 < ne02; i02++) {
5281
+ for (int i01 = 0; i01 < ne01; i01++) {
5282
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5283
+ char * dst_ptr = (char *) dst->data + id*rs;
5284
+
5285
+ memcpy(dst_ptr, src0_ptr, rs);
5286
+
5287
+ id++;
5288
+ }
5289
+ }
5290
+ }
5291
+ } else if (dst->type == GGML_TYPE_F16) {
5292
+ size_t id = 0;
5293
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
5294
+
5295
+ for (int i03 = 0; i03 < ne03; i03++) {
5296
+ for (int i02 = 0; i02 < ne02; i02++) {
5297
+ for (int i01 = 0; i01 < ne01; i01++) {
5298
+ for (int i00 = 0; i00 < ne00; i00++) {
5299
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5300
+
5301
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
5302
+ id++;
5303
+ }
5304
+ }
5305
+ }
5306
+ }
5307
+ } else {
5308
+ GGML_ASSERT(false); // TODO: implement
5309
+ }
5310
+ } else {
5311
+ //printf("%s: this is not optimal - fix me\n", __func__);
5312
+
5313
+ if (dst->type == GGML_TYPE_F32) {
5314
+ size_t id = 0;
5315
+ float * dst_ptr = (float *) dst->data;
5316
+
5317
+ for (int i03 = 0; i03 < ne03; i03++) {
5318
+ for (int i02 = 0; i02 < ne02; i02++) {
5319
+ for (int i01 = 0; i01 < ne01; i01++) {
5320
+ for (int i00 = 0; i00 < ne00; i00++) {
5321
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5322
+
5323
+ dst_ptr[id] = *src0_ptr;
5324
+ id++;
5325
+ }
5326
+ }
5327
+ }
5328
+ }
5329
+ } else if (dst->type == GGML_TYPE_F16) {
5330
+ size_t id = 0;
5331
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
5332
+
5333
+ for (int i03 = 0; i03 < ne03; i03++) {
5334
+ for (int i02 = 0; i02 < ne02; i02++) {
5335
+ for (int i01 = 0; i01 < ne01; i01++) {
5336
+ for (int i00 = 0; i00 < ne00; i00++) {
5337
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5338
+
5339
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
5340
+ id++;
5341
+ }
5342
+ }
5343
+ }
5344
+ }
5345
+ } else {
5346
+ GGML_ASSERT(false); // TODO: implement
5347
+ }
5348
+ }
5349
+
5350
+ return;
5351
+ }
5352
+
5027
5353
  // dst counters
5028
5354
  int64_t i10 = 0;
5029
5355
  int64_t i11 = 0;
@@ -5144,14 +5470,18 @@ static void ggml_compute_forward_add_f32(
5144
5470
  GGML_ASSERT(nb00 == sizeof(float));
5145
5471
 
5146
5472
  if (nb10 == sizeof(float)) {
5147
- const int j0 = (n/nth)*ith;
5148
- const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1);
5149
-
5150
- for (int j = j0; j < j1; j++) {
5473
+ for (int j = ith; j < n; j += nth) {
5474
+ #ifdef GGML_USE_ACCELERATE
5475
+ vDSP_vadd(
5476
+ (float *) ((char *) src0->data + j*nb01), 1,
5477
+ (float *) ((char *) src1->data + j*nb11), 1,
5478
+ (float *) ((char *) dst->data + j*nb1), 1, nc);
5479
+ #else
5151
5480
  ggml_vec_add_f32(nc,
5152
5481
  (float *) ((char *) dst->data + j*nb1),
5153
5482
  (float *) ((char *) src0->data + j*nb01),
5154
5483
  (float *) ((char *) src1->data + j*nb11));
5484
+ #endif
5155
5485
  }
5156
5486
  } else {
5157
5487
  // src1 is not contiguous
@@ -6304,7 +6634,7 @@ static void ggml_compute_forward_mul_mat_f32(
6304
6634
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6305
6635
  ne11, ne01, ne10,
6306
6636
  1.0f, y, ne10,
6307
- x, ne10,
6637
+ x, ne00,
6308
6638
  0.0f, d, ne01);
6309
6639
  }
6310
6640
  }
@@ -6476,7 +6806,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6476
6806
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6477
6807
  ne11, ne01, ne10,
6478
6808
  1.0f, y, ne10,
6479
- x, ne10,
6809
+ x, ne00,
6480
6810
  0.0f, d, ne01);
6481
6811
  }
6482
6812
  }
@@ -6564,29 +6894,27 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6564
6894
  //}
6565
6895
  }
6566
6896
 
6567
- typedef void (*dequantize_row_q_t)(const void * restrict x, float * restrict y, int k);
6568
- typedef void (*quantize_row_q_t)(const float * restrict x, void * restrict y, int k);
6569
- typedef void (*vec_dot_q_t)(const int n, float * restrict s, const void * restrict x, const void * restrict y);
6570
-
6571
- typedef struct {
6572
- dequantize_row_q_t dequantize_row_q;
6573
- quantize_row_q_t quantize_row_q;
6574
- vec_dot_q_t vec_dot_q;
6575
- } quantize_fns_t;
6576
-
6577
6897
  static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
6578
6898
  [GGML_TYPE_Q4_0] = {
6579
- .dequantize_row_q = dequantize_row_q4_0,
6580
- .quantize_row_q = quantize_row_q4_0,
6581
- .vec_dot_q = ggml_vec_dot_q4_0,
6899
+ .dequantize_row_q = dequantize_row_q4_0,
6900
+ .quantize_row_q = quantize_row_q4_0,
6901
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
6902
+ .vec_dot_q = ggml_vec_dot_q4_0,
6582
6903
  },
6583
6904
  [GGML_TYPE_Q4_1] = {
6584
- .dequantize_row_q = dequantize_row_q4_1,
6585
- .quantize_row_q = quantize_row_q4_1,
6586
- .vec_dot_q = ggml_vec_dot_q4_1,
6905
+ .dequantize_row_q = dequantize_row_q4_1,
6906
+ .quantize_row_q = quantize_row_q4_1,
6907
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
6908
+ .vec_dot_q = ggml_vec_dot_q4_1,
6587
6909
  },
6588
6910
  };
6589
6911
 
6912
+ // For internal test use
6913
+ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
6914
+ GGML_ASSERT(i < GGML_TYPE_COUNT);
6915
+ return quantize_fns[i];
6916
+ }
6917
+
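
Note: the quantize_fns table gains a quantize_row_q_reference entry and is now exposed through ggml_internal_get_quantize_fn, which the comment marks as being for internal tests. One way a test might use it is to round-trip a row through quantize/dequantize and measure the error; the sketch below assumes quantize_fns_t and ggml_internal_get_quantize_fn are visible from the header and uses the field names and function-pointer signatures shown in this diff:

    #include <math.h>
    #include <stdlib.h>
    #include "ggml.h"

    /* Quantize a row, dequantize it again and return the worst-case absolute
       error. n must be a multiple of QK (32). */
    static float roundtrip_max_err(enum ggml_type type, const float * src, int n) {
        quantize_fns_t fns = ggml_internal_get_quantize_fn(type);

        void  * q = malloc((size_t)(ggml_type_sizef(type) * n));  /* bytes for n quantized values */
        float * d = malloc(sizeof(float) * (size_t) n);

        fns.quantize_row_q(src, q, n);
        fns.dequantize_row_q(q, d, n);

        float max_err = 0.0f;
        for (int i = 0; i < n; i++) {
            const float err = fabsf(src[i] - d[i]);
            if (err > max_err) max_err = err;
        }

        free(q);
        free(d);
        return max_err;
    }
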
6590
6918
  static void ggml_compute_forward_mul_mat_q_f32(
6591
6919
  const struct ggml_compute_params * params,
6592
6920
  const struct ggml_tensor * src0,
@@ -6691,7 +7019,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
6691
7019
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6692
7020
  ne11, ne01, ne10,
6693
7021
  1.0f, y, ne10,
6694
- x, ne10,
7022
+ x, ne00,
6695
7023
  0.0f, d, ne01);
6696
7024
  }
6697
7025
  }
@@ -6901,6 +7229,15 @@ static void ggml_compute_forward_cpy(
6901
7229
  ggml_compute_forward_dup(params, src0, dst);
6902
7230
  }
6903
7231
 
7232
+ // ggml_compute_forward_cont
7233
+
7234
+ static void ggml_compute_forward_cont(
7235
+ const struct ggml_compute_params * params,
7236
+ const struct ggml_tensor * src0,
7237
+ struct ggml_tensor * dst) {
7238
+ ggml_compute_forward_dup(params, src0, dst);
7239
+ }
7240
+
6904
7241
  // ggml_compute_forward_reshape
6905
7242
 
6906
7243
  static void ggml_compute_forward_reshape(
@@ -7279,6 +7616,8 @@ static void ggml_compute_forward_rope_f32(
7279
7616
  // row index used to determine which thread to use
7280
7617
  int ir = 0;
7281
7618
 
7619
+ const float theta_scale = powf(10000.0, -2.0f/n_dims);
7620
+
7282
7621
  for (int64_t i3 = 0; i3 < ne3; i3++) {
7283
7622
  for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
7284
7623
  const int p = (mode == 0 ? n_past + i2 : i2);
@@ -7286,11 +7625,13 @@ static void ggml_compute_forward_rope_f32(
7286
7625
  if (ir++ < ir0) continue;
7287
7626
  if (ir > ir1) break;
7288
7627
 
7628
+ float theta = (float)p;
7629
+
7289
7630
  for (int i0 = 0; i0 < n_dims; i0 += 2) {
7290
- const float theta = powf(10000.0, ((float)-i0)/n_dims);
7631
+ const float cos_theta = cosf(theta);
7632
+ const float sin_theta = sinf(theta);
7291
7633
 
7292
- const float cos_theta = cosf(p*theta);
7293
- const float sin_theta = sinf(p*theta);
7634
+ theta *= theta_scale;
7294
7635
 
7295
7636
  const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
7296
7637
  float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -7352,6 +7693,8 @@ static void ggml_compute_forward_rope_f16(
7352
7693
  // row index used to determine which thread to use
7353
7694
  int ir = 0;
7354
7695
 
7696
+ const float theta_scale = powf(10000.0, -2.0f/n_dims);
7697
+
7355
7698
  for (int64_t i3 = 0; i3 < ne3; i3++) {
7356
7699
  for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
7357
7700
  const int p = (mode == 0 ? n_past + i2 : i2);
@@ -7359,11 +7702,13 @@ static void ggml_compute_forward_rope_f16(
7359
7702
  if (ir++ < ir0) continue;
7360
7703
  if (ir > ir1) break;
7361
7704
 
7705
+ float theta = (float)p;
7706
+
7362
7707
  for (int i0 = 0; i0 < n_dims; i0 += 2) {
7363
- const float theta = powf(10000.0, ((float)-i0)/n_dims);
7708
+ const float cos_theta = cosf(theta);
7709
+ const float sin_theta = sinf(theta);
7364
7710
 
7365
- const float cos_theta = cosf(p*theta);
7366
- const float sin_theta = sinf(p*theta);
7711
+ theta *= theta_scale;
7367
7712
 
7368
7713
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
7369
7714
  ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
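
Note: both RoPE kernels (F32 above, F16 here) now hoist the angle computation out of the inner loop. Instead of evaluating powf(10000.0, -i0/n_dims) for every pair of dimensions, they start from theta = p and multiply by theta_scale = 10000^(-2/n_dims) after each pair, which produces the same sequence p * 10000^(-i0/n_dims) for i0 = 0, 2, 4, ... with one powf per row instead of one per pair. A quick numerical check of the equivalence, illustrative only:

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const int   n_dims = 64;
        const int   p      = 5;   /* position */
        const float theta_scale = powf(10000.0f, -2.0f/n_dims);

        float theta = (float)p;
        for (int i0 = 0; i0 < n_dims; i0 += 2) {
            const float theta_ref = p * powf(10000.0f, -((float)i0)/n_dims);
            /* the two values agree up to float rounding */
            printf("i0=%2d incremental=%.6f direct=%.6f\n", i0, theta, theta_ref);
            theta *= theta_scale;
        }
        return 0;
    }
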
@@ -8637,6 +8982,111 @@ static void ggml_compute_forward_flash_ff(
8637
8982
  }
8638
8983
  }
8639
8984
 
8985
+ // ggml_compute_forward_map_unary
8986
+
8987
+ static void ggml_compute_forward_map_unary_f32(
8988
+ const struct ggml_compute_params * params,
8989
+ const struct ggml_tensor * src0,
8990
+ struct ggml_tensor * dst,
8991
+ const ggml_unary_op_f32_t fun) {
8992
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
8993
+
8994
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
8995
+ return;
8996
+ }
8997
+
8998
+ const int n = ggml_nrows(src0);
8999
+ const int nc = src0->ne[0];
9000
+
9001
+ assert( dst->nb[0] == sizeof(float));
9002
+ assert(src0->nb[0] == sizeof(float));
9003
+
9004
+ for (int i = 0; i < n; i++) {
9005
+ fun(nc,
9006
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
9007
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
9008
+ }
9009
+ }
9010
+
9011
+
9012
+ static void ggml_compute_forward_map_unary(
9013
+ const struct ggml_compute_params * params,
9014
+ const struct ggml_tensor * src0,
9015
+ struct ggml_tensor * dst,
9016
+ const ggml_unary_op_f32_t fun) {
9017
+ switch (src0->type) {
9018
+ case GGML_TYPE_F32:
9019
+ {
9020
+ ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
9021
+ } break;
9022
+ case GGML_TYPE_Q4_0:
9023
+ case GGML_TYPE_Q4_1:
9024
+ case GGML_TYPE_I8:
9025
+ case GGML_TYPE_I16:
9026
+ case GGML_TYPE_I32:
9027
+ case GGML_TYPE_F16:
9028
+ case GGML_TYPE_COUNT:
9029
+ {
9030
+ GGML_ASSERT(false);
9031
+ } break;
9032
+ }
9033
+ }
9034
+
9035
+ // ggml_compute_forward_map_binary
9036
+
9037
+ static void ggml_compute_forward_map_binary_f32(
9038
+ const struct ggml_compute_params * params,
9039
+ const struct ggml_tensor * src0,
9040
+ const struct ggml_tensor * src1,
9041
+ struct ggml_tensor * dst,
9042
+ const ggml_binary_op_f32_t fun) {
9043
+ assert(params->ith == 0);
9044
+ assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
9045
+
9046
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
9047
+ return;
9048
+ }
9049
+
9050
+ const int n = ggml_nrows(src0);
9051
+ const int nc = src0->ne[0];
9052
+
9053
+ assert( dst->nb[0] == sizeof(float));
9054
+ assert(src0->nb[0] == sizeof(float));
9055
+ assert(src1->nb[0] == sizeof(float));
9056
+
9057
+ for (int i = 0; i < n; i++) {
9058
+ fun(nc,
9059
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
9060
+ (float *) ((char *) src0->data + i*(src0->nb[1])),
9061
+ (float *) ((char *) src1->data + i*(src1->nb[1])));
9062
+ }
9063
+ }
9064
+
9065
+
9066
+ static void ggml_compute_forward_map_binary(
9067
+ const struct ggml_compute_params * params,
9068
+ const struct ggml_tensor * src0,
9069
+ const struct ggml_tensor * src1,
9070
+ struct ggml_tensor * dst,
9071
+ const ggml_binary_op_f32_t fun) {
9072
+ switch (src0->type) {
9073
+ case GGML_TYPE_F32:
9074
+ {
9075
+ ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
9076
+ } break;
9077
+ case GGML_TYPE_Q4_0:
9078
+ case GGML_TYPE_Q4_1:
9079
+ case GGML_TYPE_I8:
9080
+ case GGML_TYPE_I16:
9081
+ case GGML_TYPE_I32:
9082
+ case GGML_TYPE_F16:
9083
+ case GGML_TYPE_COUNT:
9084
+ {
9085
+ GGML_ASSERT(false);
9086
+ } break;
9087
+ }
9088
+ }
9089
+
8640
9090
  /////////////////////////////////
8641
9091
 
8642
9092
  static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -8731,6 +9181,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
8731
9181
  {
8732
9182
  ggml_compute_forward_cpy(params, tensor->src0, tensor);
8733
9183
  } break;
9184
+ case GGML_OP_CONT:
9185
+ {
9186
+ ggml_compute_forward_cont(params, tensor->src0, tensor);
9187
+ } break;
8734
9188
  case GGML_OP_RESHAPE:
8735
9189
  {
8736
9190
  ggml_compute_forward_reshape(params, tensor->src0, tensor);
@@ -8782,6 +9236,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
8782
9236
  {
8783
9237
  ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
8784
9238
  } break;
9239
+ case GGML_OP_MAP_UNARY:
9240
+ {
9241
+ const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
9242
+ ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
9243
+ }
9244
+ break;
9245
+ case GGML_OP_MAP_BINARY:
9246
+ {
9247
+ const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->opt[0]->data);
9248
+ ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
9249
+ }
9250
+ break;
8785
9251
  case GGML_OP_NONE:
8786
9252
  {
8787
9253
  // nop
@@ -8975,8 +9441,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
8975
9441
  src1->grad =
8976
9442
  ggml_add_impl(ctx,
8977
9443
  src1->grad,
8978
- // TODO: fix transpose, the node will break the graph connections
8979
- ggml_mul_mat(ctx, ggml_transpose(ctx, src0), tensor->grad),
9444
+ ggml_mul_mat(ctx,
9445
+ ggml_cont(ctx, ggml_transpose(ctx, src0)),
9446
+ tensor->grad),
8980
9447
  inplace);
8981
9448
  }
8982
9449
  } break;
@@ -8988,6 +9455,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
8988
9455
  {
8989
9456
  GGML_ASSERT(false); // TODO: not implemented
8990
9457
  } break;
9458
+ case GGML_OP_CONT:
9459
+ {
9460
+ GGML_ASSERT(false); // TODO: not implemented
9461
+ } break;
8991
9462
  case GGML_OP_RESHAPE:
8992
9463
  {
8993
9464
  GGML_ASSERT(false); // TODO: not implemented
@@ -9036,6 +9507,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
9036
9507
  {
9037
9508
  GGML_ASSERT(false); // not supported
9038
9509
  } break;
9510
+ case GGML_OP_MAP_UNARY:
9511
+ case GGML_OP_MAP_BINARY:
9512
+ {
9513
+ GGML_ASSERT(false); // not supported
9514
+ } break;
9039
9515
  case GGML_OP_NONE:
9040
9516
  {
9041
9517
  // nop
@@ -9126,7 +9602,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
9126
9602
  struct ggml_cgraph result = {
9127
9603
  /*.n_nodes =*/ 0,
9128
9604
  /*.n_leafs =*/ 0,
9129
- /*.n_threads =*/ 0,
9605
+ /*.n_threads =*/ GGML_DEFAULT_N_THREADS,
9130
9606
  /*.work_size =*/ 0,
9131
9607
  /*.work =*/ NULL,
9132
9608
  /*.nodes =*/ { NULL },
@@ -9442,6 +9918,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9442
9918
  node->n_tasks = n_threads;
9443
9919
  } break;
9444
9920
  case GGML_OP_CPY:
9921
+ case GGML_OP_CONT:
9445
9922
  case GGML_OP_RESHAPE:
9446
9923
  case GGML_OP_VIEW:
9447
9924
  case GGML_OP_PERMUTE:
@@ -9527,6 +10004,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9527
10004
 
9528
10005
  work_size = MAX(work_size, cur);
9529
10006
  } break;
10007
+ case GGML_OP_MAP_UNARY:
10008
+ case GGML_OP_MAP_BINARY:
10009
+ {
10010
+ node->n_tasks = 1;
10011
+ } break;
9530
10012
  case GGML_OP_NONE:
9531
10013
  {
9532
10014
  node->n_tasks = 1;
@@ -9745,8 +10227,8 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
9745
10227
 
9746
10228
  GGML_PRINT("=== GRAPH ===\n");
9747
10229
 
9748
- GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads);
9749
- GGML_PRINT_DEBUG("total work size = %zu bytes\n",cgraph->work_size);
10230
+ GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads);
10231
+ GGML_PRINT_DEBUG("total work size = %zu bytes\n", cgraph->work_size);
9750
10232
 
9751
10233
  GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes);
9752
10234
  for (int i = 0; i < cgraph->n_nodes; i++) {