llama_cpp 0.14.6 → 0.15.0

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -14,6 +14,7 @@
 #include "ggml-cuda/cpy.cuh"
 #include "ggml-cuda/diagmask.cuh"
 #include "ggml-cuda/dmmv.cuh"
+#include "ggml-cuda/fattn.cuh"
 #include "ggml-cuda/getrows.cuh"
 #include "ggml-cuda/im2col.cuh"
 #include "ggml-cuda/mmq.cuh"
@@ -140,6 +141,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
         info.devices[id].cc = 100*prop.major + 10*prop.minor;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         info.devices[id].smpb = prop.sharedMemPerBlock;
+        info.devices[id].nsm  = prop.multiProcessorCount;
     }
 
     for (int id = 0; id < info.device_count; ++id) {
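The new nsm field caches the device's streaming-multiprocessor count next to the existing shared-memory size; the flash-attention kernels introduced in this release consult per-device information like this when choosing launch parameters. A minimal standalone sketch of where these values come from (plain CUDA runtime API, not the exact llama.cpp wiring):

    // sketch: query the properties that ggml_cuda_init() caches per device
    #include <cstdio>
    #include <cuda_runtime.h>

    int main() {
        int device_count = 0;
        cudaGetDeviceCount(&device_count);

        for (int id = 0; id < device_count; ++id) {
            cudaDeviceProp prop;
            cudaGetDeviceProperties(&prop, id);
            // prop.multiProcessorCount is what the hunk above stores in nsm
            std::printf("device %d: %d SMs, %zu bytes smem/block\n",
                        id, prop.multiProcessorCount, prop.sharedMemPerBlock);
        }
        return 0;
    }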
@@ -2290,6 +2292,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ARGSORT:
             ggml_cuda_op_argsort(ctx, dst);
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            ggml_cuda_flash_attn_ext(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -2564,6 +2569,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
+        case GGML_OP_FLASH_ATTN_EXT:
             return true;
         default:
             return false;
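These two CUDA-backend hunks wire up the new fused attention op: the compute switch dispatches GGML_OP_FLASH_ATTN_EXT to the kernel declared in fattn.cuh, and supports_op advertises it so the scheduler will place such nodes on the GPU. The op evaluates softmax(scale * QK^T + mask) * V in a single launch instead of separate matmul and softmax nodes. A hedged sketch of building such a node, assuming the ggml_flash_attn_ext() graph API that accompanies these kernels upstream (argument order is my reading of that API, not verified against this exact release):

    // sketch under assumptions: ggml_flash_attn_ext(ctx, q, k, v, mask, scale)
    #include <math.h>
    #include "ggml.h"

    static struct ggml_tensor * build_attn(struct ggml_context * ctx,
                                           struct ggml_tensor  * q,
                                           struct ggml_tensor  * k,
                                           struct ggml_tensor  * v,
                                           struct ggml_tensor  * mask) {
        // one fused node: softmax(scale * q @ k^T + mask) @ v,
        // with the conventional scale of 1/sqrt(head_dim)
        const float scale = 1.0f / sqrtf((float) q->ne[0]);
        return ggml_flash_attn_ext(ctx, q, k, v, mask, scale);
    }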
@@ -11,6 +11,12 @@
 #include <string.h> // memcpy
 #include <math.h>   // fabsf
 
+#undef MIN
+#undef MAX
+
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
 #ifdef __cplusplus
 extern "C" {
 #endif
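This hunk and the ones that follow are in ggml's shared implementation header. The #undef/#define pair makes the header self-sufficient while avoiding redefinition warnings when an earlier include already defined MIN/MAX. The standard caveat for function-like macros applies when reading code that uses them: each argument is textually pasted, so side effects run once per occurrence. A small sketch of the pitfall:

    // pitfall sketch: MAX() evaluates each argument twice
    #include <stdio.h>

    #define MAX(a, b) ((a) > (b) ? (a) : (b))

    int main(void) {
        int x = 1;
        int m = MAX(x++, 0);              // expands to ((x++) > (0) ? (x++) : (0))
        printf("m = %d, x = %d\n", m, x); // m = 2, x = 3 -- not m = 1, x = 2
        return 0;
    }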
@@ -45,7 +51,7 @@ extern "C" {
 // 16-bit float
 // on Arm, we use __fp16
 // on x86, we use uint16_t
-#if defined(__ARM_NEON) && !defined(_MSC_VER)
+#if defined(__ARM_NEON)
 
 // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
 //
@@ -53,8 +59,262 @@ extern "C" {
 //
 #include <arm_neon.h>
 
+#ifdef _MSC_VER
+
+typedef uint16_t ggml_fp16_internal_t;
+
+#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
+
+#else
+
 typedef __fp16 ggml_fp16_internal_t;
 
+#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
+
+#endif // _MSC_VER
+
+#if !defined(__aarch64__)
+
+// 32-bit ARM compatibility
+
+// vaddvq_s16
+// vpaddq_s16
+// vpaddq_s32
+// vaddvq_s32
+// vaddvq_f32
+// vmaxvq_f32
+// vcvtnq_s32_f32
+// vzip1_u8
+// vzip2_u8
+
+inline static int32_t vaddvq_s16(int16x8_t v) {
+    return
+        (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
+        (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
+        (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
+        (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
+}
+
+inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
+    int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
+    int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
+    return vcombine_s16(a0, b0);
+}
+
+inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
+    int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
+    int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
+    return vcombine_s32(a0, b0);
+}
+
+inline static int32_t vaddvq_s32(int32x4_t v) {
+    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
+}
+
+inline static float vaddvq_f32(float32x4_t v) {
+    return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
+}
+
+inline static float vmaxvq_f32(float32x4_t v) {
+    return
+        MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+            MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+}
+
+inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
+    int32x4_t res;
+
+    res[0] = roundf(vgetq_lane_f32(v, 0));
+    res[1] = roundf(vgetq_lane_f32(v, 1));
+    res[2] = roundf(vgetq_lane_f32(v, 2));
+    res[3] = roundf(vgetq_lane_f32(v, 3));
+
+    return res;
+}
+
+inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
+    uint8x8_t res;
+
+    res[0] = a[0]; res[1] = b[0];
+    res[2] = a[1]; res[3] = b[1];
+    res[4] = a[2]; res[5] = b[2];
+    res[6] = a[3]; res[7] = b[3];
+
+    return res;
+}
+
+inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
+    uint8x8_t res;
+
+    res[0] = a[4]; res[1] = b[4];
+    res[2] = a[5]; res[3] = b[5];
+    res[4] = a[6]; res[5] = b[6];
+    res[6] = a[7]; res[7] = b[7];
+
+    return res;
+}
+
+// vld1q_s16_x2
+// vld1q_u8_x2
+// vld1q_u8_x4
+// vld1q_s8_x2
+// vld1q_s8_x4
+// TODO: double-check these work correctly
+
+typedef struct ggml_int16x8x2_t {
+    int16x8_t val[2];
+} ggml_int16x8x2_t;
+
+inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
+    ggml_int16x8x2_t res;
+
+    res.val[0] = vld1q_s16(ptr + 0);
+    res.val[1] = vld1q_s16(ptr + 8);
+
+    return res;
+}
+
+typedef struct ggml_uint8x16x2_t {
+    uint8x16_t val[2];
+} ggml_uint8x16x2_t;
+
+inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
+    ggml_uint8x16x2_t res;
+
+    res.val[0] = vld1q_u8(ptr + 0);
+    res.val[1] = vld1q_u8(ptr + 16);
+
+    return res;
+}
+
+typedef struct ggml_uint8x16x4_t {
+    uint8x16_t val[4];
+} ggml_uint8x16x4_t;
+
+inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
+    ggml_uint8x16x4_t res;
+
+    res.val[0] = vld1q_u8(ptr + 0);
+    res.val[1] = vld1q_u8(ptr + 16);
+    res.val[2] = vld1q_u8(ptr + 32);
+    res.val[3] = vld1q_u8(ptr + 48);
+
+    return res;
+}
+
+typedef struct ggml_int8x16x2_t {
+    int8x16_t val[2];
+} ggml_int8x16x2_t;
+
+inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
+    ggml_int8x16x2_t res;
+
+    res.val[0] = vld1q_s8(ptr + 0);
+    res.val[1] = vld1q_s8(ptr + 16);
+
+    return res;
+}
+
+typedef struct ggml_int8x16x4_t {
+    int8x16_t val[4];
+} ggml_int8x16x4_t;
+
+inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
+    ggml_int8x16x4_t res;
+
+    res.val[0] = vld1q_s8(ptr + 0);
+    res.val[1] = vld1q_s8(ptr + 16);
+    res.val[2] = vld1q_s8(ptr + 32);
+    res.val[3] = vld1q_s8(ptr + 48);
+
+    return res;
+}
+
+// NOTE: not tested
+inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
+    int8x16_t res;
+
+    res[ 0] = a[b[ 0]];
+    res[ 1] = a[b[ 1]];
+    res[ 2] = a[b[ 2]];
+    res[ 3] = a[b[ 3]];
+    res[ 4] = a[b[ 4]];
+    res[ 5] = a[b[ 5]];
+    res[ 6] = a[b[ 6]];
+    res[ 7] = a[b[ 7]];
+    res[ 8] = a[b[ 8]];
+    res[ 9] = a[b[ 9]];
+    res[10] = a[b[10]];
+    res[11] = a[b[11]];
+    res[12] = a[b[12]];
+    res[13] = a[b[13]];
+    res[14] = a[b[14]];
+    res[15] = a[b[15]];
+
+    return res;
+}
+
+// NOTE: not tested
+inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
+    uint8x16_t res;
+
+    res[ 0] = a[b[ 0]];
+    res[ 1] = a[b[ 1]];
+    res[ 2] = a[b[ 2]];
+    res[ 3] = a[b[ 3]];
+    res[ 4] = a[b[ 4]];
+    res[ 5] = a[b[ 5]];
+    res[ 6] = a[b[ 6]];
+    res[ 7] = a[b[ 7]];
+    res[ 8] = a[b[ 8]];
+    res[ 9] = a[b[ 9]];
+    res[10] = a[b[10]];
+    res[11] = a[b[11]];
+    res[12] = a[b[12]];
+    res[13] = a[b[13]];
+    res[14] = a[b[14]];
+    res[15] = a[b[15]];
+
+    return res;
+}
+
+#else
+
+#define ggml_int16x8x2_t  int16x8x2_t
+#define ggml_uint8x16x2_t uint8x16x2_t
+#define ggml_uint8x16x4_t uint8x16x4_t
+#define ggml_int8x16x2_t  int8x16x2_t
+#define ggml_int8x16x4_t  int8x16x4_t
+
+#define ggml_vld1q_s16_x2 vld1q_s16_x2
+#define ggml_vld1q_u8_x2  vld1q_u8_x2
+#define ggml_vld1q_u8_x4  vld1q_u8_x4
+#define ggml_vld1q_s8_x2  vld1q_s8_x2
+#define ggml_vld1q_s8_x4  vld1q_s8_x4
+#define ggml_vqtbl1q_s8   vqtbl1q_s8
+#define ggml_vqtbl1q_u8   vqtbl1q_u8
+
+#endif // !defined(__aarch64__)
+
+#if !defined(__ARM_FEATURE_DOTPROD)
+
+inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
+    const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
+    const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
+
+    return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
+}
+
+#else
+
+#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
+
+#endif // !defined(__ARM_FEATURE_DOTPROD)
+
+#endif // defined(__ARM_NEON)
+
+#if defined(__ARM_NEON) && !defined(_MSC_VER)
+
 #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
 #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
 
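The ggml_vdotq_s32 fallback at the end of this hunk is easy to sanity-check: vmull_s8 widens the i8 products to i16 and vpaddlq_s16 pairwise-accumulates them into i32 lanes. Note the lane grouping differs from a true vdotq_s32 (which sums groups of four adjacent bytes per lane), but the horizontal sum over all lanes is identical, and ggml's quantized dot products reduce to that horizontal sum anyway. A scalar reference for that invariant, runnable anywhere without NEON (names are illustrative):

    // invariant: vaddvq_s32(ggml_vdotq_s32(acc, a, b)) == vaddvq_s32(acc) + dot(a, b)
    #include <stdint.h>
    #include <stdio.h>

    static int32_t dot_s8(const int8_t a[16], const int8_t b[16]) {
        int32_t sum = 0;
        for (int i = 0; i < 16; ++i) {
            sum += (int32_t) a[i] * (int32_t) b[i];
        }
        return sum;
    }

    int main(void) {
        int8_t a[16], b[16];
        for (int i = 0; i < 16; ++i) {
            a[i] = (int8_t) (i - 8);
            b[i] = (int8_t) (3 * i - 20);
        }
        printf("dot = %d\n", dot_s8(a, b)); // compare against the NEON path on ARM
        return 0;
    }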
@@ -75,8 +335,6 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
 
 #else
 
-typedef uint16_t ggml_fp16_internal_t;
-
 #ifdef __wasm_simd128__
 #include <wasm_simd128.h>
 #else
@@ -221,7 +479,7 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
 
 #endif // __F16C__
 
-#endif // __ARM_NEON
+#endif // defined(__ARM_NEON) && (!defined(__MSC_VER)
 
 // precomputed f32 table for f16 (256 KB)
 // defined in ggml.c, initialized in ggml_init()
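The trailing context mentions the precomputed f32 table for f16: an f16 value has 2^16 = 65536 possible bit patterns, and at 4 bytes of f32 each that is exactly 65536 × 4 = 262144 bytes = 256 KB, small enough to trade for a branch-free lookup. A sketch of the idea with illustrative names (ggml's actual table and its initialization live in ggml.c / ggml_init()):

    // sketch: table-driven f16 -> f32 conversion; names here are hypothetical
    #include <stdint.h>

    static float f16_to_f32_table[1 << 16]; // 65536 entries * 4 B = 256 KB

    // fill once at startup with any correct bitwise converter
    void init_f16_table(float (*convert)(uint16_t)) {
        for (uint32_t bits = 0; bits < (1u << 16); ++bits) {
            f16_to_f32_table[bits] = convert((uint16_t) bits);
        }
    }

    // afterwards a conversion is a single indexed load
    static inline float f16_to_f32(uint16_t h) {
        return f16_to_f32_table[h];
    }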
@@ -1427,6 +1427,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
         for (int i = node_start; i < node_end; ++i) {
             struct ggml_tensor * src0 = gf->nodes[i]->src[0];
             struct ggml_tensor * src1 = gf->nodes[i]->src[1];
+            struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2);
             struct ggml_tensor * dst = gf->nodes[i];
             GGML_ASSERT(dst->data != nullptr);
 
@@ -1559,6 +1560,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                 {
                     float scale;
                     memcpy(&scale, dst->op_params, sizeof(float));
+
+#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
+                    GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
+                    GGML_ASSERT(src2 == nullptr);
+
                     ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
                 } break;
             case GGML_OP_DIAG_MASK_INF:
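These Kompute-backend hunks stage the same release's soft_max changes defensively: src2 is fetched (and marked unused) so the code compiles against the new graph layout, while the asserts reject the F16 mask and src2 inputs the shader cannot handle yet, and the #pragma message lines surface that TODO at build time. For reference, a plain-C sketch of the row-wise semantics being dispatched here, softmax(scale * x + mask), assuming an optional F32 mask (this is the math, not the Kompute shader):

    // reference sketch: scaled, optionally masked softmax over one row
    #include <math.h>
    #include <stddef.h>

    void soft_max_row(float * dst, const float * x, const float * mask /* may be NULL */,
                      size_t n, float scale) {
        float max_val = -INFINITY;
        for (size_t i = 0; i < n; ++i) {
            dst[i] = scale * x[i] + (mask ? mask[i] : 0.0f);
            if (dst[i] > max_val) {
                max_val = dst[i];
            }
        }
        float sum = 0.0f;
        for (size_t i = 0; i < n; ++i) {
            dst[i] = expf(dst[i] - max_val); // subtract max for numerical stability
            sum += dst[i];
        }
        for (size_t i = 0; i < n; ++i) {
            dst[i] /= sum;
        }
    }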