@fugood/llama.node 0.0.1-alpha.4 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (84)
  1. package/CMakeLists.txt +42 -7
  2. package/README.md +10 -0
  3. package/bin/darwin/arm64/default.metallib +0 -0
  4. package/bin/darwin/arm64/llama-node.node +0 -0
  5. package/bin/darwin/x64/default.metallib +0 -0
  6. package/bin/darwin/x64/llama-node.node +0 -0
  7. package/bin/linux/arm64/llama-node.node +0 -0
  8. package/bin/linux/x64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  10. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  11. package/lib/binding.js +1 -1
  12. package/lib/binding.ts +16 -2
  13. package/lib/index.ts +2 -2
  14. package/package.json +15 -3
  15. package/src/DetokenizeWorker.cpp +22 -0
  16. package/src/DetokenizeWorker.h +19 -0
  17. package/src/EmbeddingWorker.cpp +46 -0
  18. package/src/EmbeddingWorker.h +23 -0
  19. package/src/LlamaCompletionWorker.cpp +5 -1
  20. package/src/LlamaCompletionWorker.h +4 -0
  21. package/src/LlamaContext.cpp +80 -1
  22. package/src/LlamaContext.h +3 -0
  23. package/src/TokenizeWorker.cpp +26 -0
  24. package/src/TokenizeWorker.h +23 -0
  25. package/src/common.hpp +12 -7
  26. package/src/llama.cpp/CMakeLists.txt +13 -7
  27. package/src/llama.cpp/common/common.cpp +221 -173
  28. package/src/llama.cpp/common/common.h +19 -8
  29. package/src/llama.cpp/common/json-schema-to-grammar.h +4 -0
  30. package/src/llama.cpp/common/log.h +2 -2
  31. package/src/llama.cpp/common/sampling.cpp +17 -1
  32. package/src/llama.cpp/common/sampling.h +28 -20
  33. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +17 -11
  34. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +5 -5
  35. package/src/llama.cpp/examples/finetune/finetune.cpp +1 -1
  36. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -4
  37. package/src/llama.cpp/examples/imatrix/imatrix.cpp +72 -39
  38. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +42 -3
  39. package/src/llama.cpp/examples/llava/clip.cpp +74 -23
  40. package/src/llama.cpp/examples/llava/llava-cli.cpp +37 -28
  41. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +0 -1
  42. package/src/llama.cpp/examples/lookup/lookup.cpp +0 -1
  43. package/src/llama.cpp/examples/main/main.cpp +10 -8
  44. package/src/llama.cpp/examples/perplexity/perplexity.cpp +175 -55
  45. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  46. package/src/llama.cpp/examples/quantize/quantize.cpp +74 -47
  47. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +1 -1
  48. package/src/llama.cpp/examples/server/server.cpp +97 -86
  49. package/src/llama.cpp/examples/server/utils.hpp +17 -15
  50. package/src/llama.cpp/ggml-backend.c +7 -5
  51. package/src/llama.cpp/ggml-impl.h +339 -4
  52. package/src/llama.cpp/ggml-kompute.cpp +7 -0
  53. package/src/llama.cpp/ggml-opencl.cpp +1 -0
  54. package/src/llama.cpp/ggml-quants.c +302 -293
  55. package/src/llama.cpp/ggml-sycl.cpp +28 -16
  56. package/src/llama.cpp/ggml-vulkan-shaders.hpp +46843 -39205
  57. package/src/llama.cpp/ggml-vulkan.cpp +951 -263
  58. package/src/llama.cpp/ggml.c +1469 -116
  59. package/src/llama.cpp/ggml.h +37 -7
  60. package/src/llama.cpp/llama.cpp +969 -432
  61. package/src/llama.cpp/llama.h +46 -14
  62. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf-update.txt +2 -0
  63. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf.txt +0 -1
  64. package/src/llama.cpp/requirements/requirements-convert.txt +2 -2
  65. package/src/llama.cpp/requirements.txt +1 -0
  66. package/src/llama.cpp/sgemm.cpp +134 -103
  67. package/src/llama.cpp/sgemm.h +4 -2
  68. package/src/llama.cpp/tests/CMakeLists.txt +96 -36
  69. package/src/llama.cpp/tests/test-backend-ops.cpp +56 -6
  70. package/src/llama.cpp/tests/test-chat-template.cpp +4 -0
  71. package/src/llama.cpp/tests/test-grammar-integration.cpp +225 -136
  72. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +1 -0
  73. package/src/llama.cpp/tests/test-tokenizer-0.cpp +292 -0
  74. package/src/llama.cpp/tests/{test-tokenizer-1-llama.cpp → test-tokenizer-1-spm.cpp} +1 -1
  75. package/src/llama.cpp/unicode-data.cpp +1188 -656
  76. package/src/llama.cpp/unicode-data.h +4 -3
  77. package/src/llama.cpp/unicode.cpp +590 -49
  78. package/src/llama.cpp/unicode.h +6 -3
  79. package/bin/win32/arm64/llama-node.node +0 -0
  80. package/bin/win32/arm64/node.lib +0 -0
  81. package/bin/win32/x64/llama-node.node +0 -0
  82. package/bin/win32/x64/node.lib +0 -0
  83. package/src/llama.cpp/tests/test-tokenizer-0-falcon.cpp +0 -187
  84. package/src/llama.cpp/tests/test-tokenizer-0-llama.cpp +0 -190
@@ -14,47 +14,6 @@
14
14
  #include <stdlib.h> // for qsort
15
15
  #include <stdio.h> // for GGML_ASSERT
16
16
 
17
- #ifdef __ARM_NEON
18
-
19
- // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
20
- //
21
- // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
22
- //
23
- #include <arm_neon.h>
24
-
25
- #else
26
-
27
- #ifdef __wasm_simd128__
28
- #include <wasm_simd128.h>
29
- #else
30
- #if defined(__POWER9_VECTOR__) || defined(__powerpc64__)
31
- #include <altivec.h>
32
- #undef bool
33
- #define bool _Bool
34
- #else
35
- #if defined(_MSC_VER) || defined(__MINGW32__)
36
- #include <intrin.h>
37
- #else
38
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
39
- #if !defined(__riscv)
40
- #include <immintrin.h>
41
- #endif
42
- #endif
43
- #endif
44
- #endif
45
- #endif
46
- #endif
47
-
48
- #ifdef __riscv_v_intrinsic
49
- #include <riscv_vector.h>
50
- #endif
51
-
52
- #undef MIN
53
- #undef MAX
54
-
55
- #define MIN(a, b) ((a) < (b) ? (a) : (b))
56
- #define MAX(a, b) ((a) > (b) ? (a) : (b))
57
-
58
17
  #define UNUSED GGML_UNUSED
59
18
 
60
19
  // some compilers don't provide _mm256_set_m128i, e.g. gcc 7
@@ -276,258 +235,6 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
276
235
  #endif // __AVX__ || __AVX2__ || __AVX512F__
277
236
  #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
278
237
 
279
- #if defined(__ARM_NEON)
280
-
281
- #ifdef _MSC_VER
282
-
283
- #define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
284
-
285
- #else
286
-
287
- #define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
288
-
289
- #endif
290
-
291
- #if !defined(__aarch64__)
292
-
293
- // 64-bit compatibility
294
-
295
- // vaddvq_s16
296
- // vpaddq_s16
297
- // vpaddq_s32
298
- // vaddvq_s32
299
- // vaddvq_f32
300
- // vmaxvq_f32
301
- // vcvtnq_s32_f32
302
- // vzip1_u8
303
- // vzip2_u8
304
-
305
- inline static int32_t vaddvq_s16(int16x8_t v) {
306
- return
307
- (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
308
- (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
309
- (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
310
- (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
311
- }
312
-
313
- inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
314
- int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
315
- int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
316
- return vcombine_s16(a0, b0);
317
- }
318
-
319
- inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
320
- int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
321
- int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
322
- return vcombine_s32(a0, b0);
323
- }
324
-
325
- inline static int32_t vaddvq_s32(int32x4_t v) {
326
- return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
327
- }
328
-
329
- inline static float vaddvq_f32(float32x4_t v) {
330
- return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
331
- }
332
-
333
- inline static float vmaxvq_f32(float32x4_t v) {
334
- return
335
- MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
336
- MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
337
- }
338
-
339
- inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
340
- int32x4_t res;
341
-
342
- res[0] = roundf(vgetq_lane_f32(v, 0));
343
- res[1] = roundf(vgetq_lane_f32(v, 1));
344
- res[2] = roundf(vgetq_lane_f32(v, 2));
345
- res[3] = roundf(vgetq_lane_f32(v, 3));
346
-
347
- return res;
348
- }
349
-
350
- inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
351
- uint8x8_t res;
352
-
353
- res[0] = a[0]; res[1] = b[0];
354
- res[2] = a[1]; res[3] = b[1];
355
- res[4] = a[2]; res[5] = b[2];
356
- res[6] = a[3]; res[7] = b[3];
357
-
358
- return res;
359
- }
360
-
361
- inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
362
- uint8x8_t res;
363
-
364
- res[0] = a[4]; res[1] = b[4];
365
- res[2] = a[5]; res[3] = b[5];
366
- res[4] = a[6]; res[5] = b[6];
367
- res[6] = a[7]; res[7] = b[7];
368
-
369
- return res;
370
- }
371
-
372
- // vld1q_s16_x2
373
- // vld1q_u8_x2
374
- // vld1q_u8_x4
375
- // vld1q_s8_x2
376
- // vld1q_s8_x4
377
- // TODO: double-check these work correctly
378
-
379
- typedef struct ggml_int16x8x2_t {
380
- int16x8_t val[2];
381
- } ggml_int16x8x2_t;
382
-
383
- inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
384
- ggml_int16x8x2_t res;
385
-
386
- res.val[0] = vld1q_s16(ptr + 0);
387
- res.val[1] = vld1q_s16(ptr + 8);
388
-
389
- return res;
390
- }
391
-
392
- typedef struct ggml_uint8x16x2_t {
393
- uint8x16_t val[2];
394
- } ggml_uint8x16x2_t;
395
-
396
- inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
397
- ggml_uint8x16x2_t res;
398
-
399
- res.val[0] = vld1q_u8(ptr + 0);
400
- res.val[1] = vld1q_u8(ptr + 16);
401
-
402
- return res;
403
- }
404
-
405
- typedef struct ggml_uint8x16x4_t {
406
- uint8x16_t val[4];
407
- } ggml_uint8x16x4_t;
408
-
409
- inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
410
- ggml_uint8x16x4_t res;
411
-
412
- res.val[0] = vld1q_u8(ptr + 0);
413
- res.val[1] = vld1q_u8(ptr + 16);
414
- res.val[2] = vld1q_u8(ptr + 32);
415
- res.val[3] = vld1q_u8(ptr + 48);
416
-
417
- return res;
418
- }
419
-
420
- typedef struct ggml_int8x16x2_t {
421
- int8x16_t val[2];
422
- } ggml_int8x16x2_t;
423
-
424
- inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
425
- ggml_int8x16x2_t res;
426
-
427
- res.val[0] = vld1q_s8(ptr + 0);
428
- res.val[1] = vld1q_s8(ptr + 16);
429
-
430
- return res;
431
- }
432
-
433
- typedef struct ggml_int8x16x4_t {
434
- int8x16_t val[4];
435
- } ggml_int8x16x4_t;
436
-
437
- inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
438
- ggml_int8x16x4_t res;
439
-
440
- res.val[0] = vld1q_s8(ptr + 0);
441
- res.val[1] = vld1q_s8(ptr + 16);
442
- res.val[2] = vld1q_s8(ptr + 32);
443
- res.val[3] = vld1q_s8(ptr + 48);
444
-
445
- return res;
446
- }
447
-
448
- // NOTE: not tested
449
- inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
450
- int8x16_t res;
451
-
452
- res[ 0] = a[b[ 0]];
453
- res[ 1] = a[b[ 1]];
454
- res[ 2] = a[b[ 2]];
455
- res[ 3] = a[b[ 3]];
456
- res[ 4] = a[b[ 4]];
457
- res[ 5] = a[b[ 5]];
458
- res[ 6] = a[b[ 6]];
459
- res[ 7] = a[b[ 7]];
460
- res[ 8] = a[b[ 8]];
461
- res[ 9] = a[b[ 9]];
462
- res[10] = a[b[10]];
463
- res[11] = a[b[11]];
464
- res[12] = a[b[12]];
465
- res[13] = a[b[13]];
466
- res[14] = a[b[14]];
467
- res[15] = a[b[15]];
468
-
469
- return res;
470
- }
471
-
472
- // NOTE: not tested
473
- inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
474
- uint8x16_t res;
475
-
476
- res[ 0] = a[b[ 0]];
477
- res[ 1] = a[b[ 1]];
478
- res[ 2] = a[b[ 2]];
479
- res[ 3] = a[b[ 3]];
480
- res[ 4] = a[b[ 4]];
481
- res[ 5] = a[b[ 5]];
482
- res[ 6] = a[b[ 6]];
483
- res[ 7] = a[b[ 7]];
484
- res[ 8] = a[b[ 8]];
485
- res[ 9] = a[b[ 9]];
486
- res[10] = a[b[10]];
487
- res[11] = a[b[11]];
488
- res[12] = a[b[12]];
489
- res[13] = a[b[13]];
490
- res[14] = a[b[14]];
491
- res[15] = a[b[15]];
492
-
493
- return res;
494
- }
495
-
496
- #else
497
-
498
- #define ggml_int16x8x2_t int16x8x2_t
499
- #define ggml_uint8x16x2_t uint8x16x2_t
500
- #define ggml_uint8x16x4_t uint8x16x4_t
501
- #define ggml_int8x16x2_t int8x16x2_t
502
- #define ggml_int8x16x4_t int8x16x4_t
503
-
504
- #define ggml_vld1q_s16_x2 vld1q_s16_x2
505
- #define ggml_vld1q_u8_x2 vld1q_u8_x2
506
- #define ggml_vld1q_u8_x4 vld1q_u8_x4
507
- #define ggml_vld1q_s8_x2 vld1q_s8_x2
508
- #define ggml_vld1q_s8_x4 vld1q_s8_x4
509
- #define ggml_vqtbl1q_s8 vqtbl1q_s8
510
- #define ggml_vqtbl1q_u8 vqtbl1q_u8
511
-
512
- #endif
513
-
514
- #if !defined(__ARM_FEATURE_DOTPROD)
515
-
516
- inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
517
- const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
518
- const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
519
-
520
- return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
521
- }
522
-
523
- #else
524
-
525
- #define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
526
-
527
- #endif
528
-
529
- #endif
530
-
531
238
  #if defined(__ARM_NEON) || defined(__wasm_simd128__)
532
239
  #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
533
240
  #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
@@ -12676,3 +12383,305 @@ void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k)
12676
12383
  block_iq2_s * restrict y = vy;
12677
12384
  quantize_row_iq2_s_reference(x, y, k);
12678
12385
  }
12386
+
12387
// Checks a single float for inf/NaN.
// Returns true when the value is neither; otherwise prints the kind of bad
// value together with index `i` to stderr and returns false.
static bool validate_float(float f, size_t i) {
    if (isinf(f)) {
        fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);
    } else if (isnan(f)) {
        fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i);
    } else {
        return true;
    }
    return false;
}
12400
+
12401
+ static bool isinf_fp16(ggml_fp16_t f) {
12402
+ return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) == 0;
12403
+ }
12404
+
12405
+ static bool isnan_fp16(ggml_fp16_t f) {
12406
+ return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) != 0;
12407
+ }
12408
+
12409
+ static bool validate_fp16(ggml_fp16_t f, size_t i) {
12410
+ if (isinf_fp16(f)) {
12411
+ fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);
12412
+ return false;
12413
+ }
12414
+
12415
+ if (isnan_fp16(f)) {
12416
+ fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i);
12417
+ return false;
12418
+ }
12419
+
12420
+ return true;
12421
+ }
12422
+
12423
// Checks the fp16 scale member `d` of each of the `nb` blocks of struct `type`
// stored at `data`. On the first inf/NaN scale (validate_fp16 prints the
// diagnostic) it makes the *enclosing* function return false.
// Wrapped in do { } while (0) so the expansion is a single statement (safe as
// an unbraced if/else body) and the local `q` does not leak into the caller.
#define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
    do { \
        const type * q = (const type *) (data); \
        for (size_t i = 0; i < (nb); ++i) { \
            if (!validate_fp16(q[i].d, i)) { \
                return false; \
            } \
        } \
    } while (0)
12430
+
12431
// Like VALIDATE_ROW_DATA_D_F16_IMPL, but checks two fp16 members per block.
// `d` and `m` are member *names* (or member expressions such as d[0]) that are
// pasted after `q[i].`. On the first inf/NaN it makes the *enclosing* function
// return false.
// Wrapped in do { } while (0) so the expansion is a single statement (safe as
// an unbraced if/else body) and the local `q` does not leak into the caller.
#define VALIDATE_ROW_DATA_DM_F16_IMPL(type, data, nb, d, m) \
    do { \
        const type * q = (const type *) (data); \
        for (size_t i = 0; i < (nb); ++i) { \
            if (!validate_fp16(q[i].d, i) || !validate_fp16(q[i].m, i)) { \
                return false; \
            } \
        } \
    } while (0)
12438
+
12439
+ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) {
12440
+ if (type < 0 || type >= GGML_TYPE_COUNT) {
12441
+ fprintf(stderr, "%s: invalid type %d\n", __func__, type);
12442
+ return false;
12443
+ }
12444
+
12445
+ if (nbytes % ggml_type_size(type) != 0) {
12446
+ fprintf(stderr, "%s: invalid size %zu for type %d\n", __func__, nbytes, type);
12447
+ return false;
12448
+ }
12449
+
12450
+ const size_t nb = nbytes/ggml_type_size(type);
12451
+
12452
+ switch (type) {
12453
+ case GGML_TYPE_BF16:
12454
+ {
12455
+ int nans = 0;
12456
+ int infs = 0;
12457
+ const unsigned short * f = (const unsigned short *) data;
12458
+ for (size_t i = 0; i < nb; ++i) {
12459
+ nans += (f[i] & 0x7fff) > 0x7f80;
12460
+ infs += (f[i] & 0x7fff) == 0x7f80;
12461
+ }
12462
+ if (nans) {
12463
+ fprintf(stderr, "%s: found %d NaNs in row of %zu BF16 values\n", __func__, nans, nb);
12464
+ return false;
12465
+ }
12466
+ if (infs) {
12467
+ fprintf(stderr, "%s: found %d infinities in row of %zu BF16 values\n", __func__, infs, nb);
12468
+ return false;
12469
+ }
12470
+ } break;
12471
+ case GGML_TYPE_F16:
12472
+ {
12473
+ const ggml_fp16_t * f = (const ggml_fp16_t *) data;
12474
+ size_t i = 0;
12475
+ #if defined(__AVX2__)
12476
+ for (; i + 15 < nb; i += 16) {
12477
+ __m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
12478
+ __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi16(0x7c00));
12479
+ __m256i cmp = _mm256_cmpeq_epi16(vexp, _mm256_set1_epi16(0x7c00));
12480
+ int mask = _mm256_movemask_epi8(cmp);
12481
+ if (mask) {
12482
+ for (size_t j = 0; j < 16; ++j) {
12483
+ if (!validate_fp16(f[i + j], i + j)) {
12484
+ return false;
12485
+ }
12486
+ }
12487
+ GGML_UNREACHABLE();
12488
+ }
12489
+ }
12490
+ #elif defined(__ARM_NEON)
12491
+ for (; i + 7 < nb; i += 8) {
12492
+ uint16x8_t v = vld1q_u16(f + i);
12493
+ uint16x8_t vexp = vandq_u16(v, vdupq_n_u16(0x7c00));
12494
+ uint16x8_t cmp = vceqq_u16(vexp, vdupq_n_u16(0x7c00));
12495
+ uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(cmp, 4)), 0);
12496
+ if (mask) {
12497
+ for (size_t j = 0; j < 8; ++j) {
12498
+ if (!validate_fp16(f[i + j], i + j)) {
12499
+ return false;
12500
+ }
12501
+ }
12502
+ GGML_UNREACHABLE();
12503
+ }
12504
+ }
12505
+ #endif
12506
+ for (; i < nb; ++i) {
12507
+ if (!validate_fp16(f[i], i)) {
12508
+ return false;
12509
+ }
12510
+ }
12511
+ } break;
12512
+ case GGML_TYPE_F32:
12513
+ {
12514
+ const float * f = (const float *) data;
12515
+ size_t i = 0;
12516
+ #if defined(__AVX2__)
12517
+ for (; i + 7 < nb; i += 8) {
12518
+ __m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
12519
+ __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi32(0x7f800000));
12520
+ __m256i cmp = _mm256_cmpeq_epi32(vexp, _mm256_set1_epi32(0x7f800000));
12521
+ int mask = _mm256_movemask_epi8(cmp);
12522
+ if (mask) {
12523
+ for (size_t j = 0; j < 8; ++j) {
12524
+ if (!validate_float(f[i + j], i + j)) {
12525
+ return false;
12526
+ }
12527
+ }
12528
+ GGML_UNREACHABLE();
12529
+ }
12530
+ }
12531
+ #elif defined(__ARM_NEON)
12532
+ for (; i + 3 < nb; i += 4) {
12533
+ uint32x4_t v = vld1q_u32((const uint32_t *)f + i);
12534
+ uint32x4_t vexp = vandq_u32(v, vdupq_n_u32(0x7f800000));
12535
+ uint32x4_t cmp = vceqq_u32(vexp, vdupq_n_u32(0x7f800000));
12536
+ uint64_t mask = vget_lane_u64(vreinterpret_u64_u16(vshrn_n_u32(cmp, 8)), 0);
12537
+ if (mask) {
12538
+ for (size_t j = 0; j < 4; ++j) {
12539
+ if (!validate_float(f[i + j], i + j)) {
12540
+ return false;
12541
+ }
12542
+ }
12543
+ GGML_UNREACHABLE();
12544
+ }
12545
+ }
12546
+ #endif
12547
+ for (; i < nb; ++i) {
12548
+ if (!validate_float(f[i], i)) {
12549
+ return false;
12550
+ }
12551
+ }
12552
+ } break;
12553
+ case GGML_TYPE_F64:
12554
+ {
12555
+ const double * f = (const double *) data;
12556
+ for (size_t i = 0; i < nb; ++i) {
12557
+ if (!validate_float(f[i], i)) {
12558
+ return false;
12559
+ }
12560
+ }
12561
+ } break;
12562
+ case GGML_TYPE_Q4_0:
12563
+ {
12564
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb);
12565
+ } break;
12566
+ case GGML_TYPE_Q4_1:
12567
+ {
12568
+ VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_1, data, nb, d, m);
12569
+ } break;
12570
+ case GGML_TYPE_Q5_0:
12571
+ {
12572
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_0, data, nb);
12573
+ } break;
12574
+ case GGML_TYPE_Q5_1:
12575
+ {
12576
+ VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_1, data, nb, d, m);
12577
+ } break;
12578
+ case GGML_TYPE_Q8_0:
12579
+ {
12580
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
12581
+ } break;
12582
+ case GGML_TYPE_Q2_K:
12583
+ {
12584
+ VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
12585
+ } break;
12586
+ case GGML_TYPE_Q3_K:
12587
+ {
12588
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_q3_K, data, nb);
12589
+ } break;
12590
+ case GGML_TYPE_Q4_K:
12591
+ {
12592
+ #ifdef GGML_QKK_64
12593
+ VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d[0], d[1]);
12594
+ #else
12595
+ VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d, dmin);
12596
+ #endif
12597
+ } break;
12598
+ case GGML_TYPE_Q5_K:
12599
+ {
12600
+ #ifdef GGML_QKK_64
12601
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_K, data, nb);
12602
+ #else
12603
+ VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_K, data, nb, d, dmin);
12604
+ #endif
12605
+ } break;
12606
+ case GGML_TYPE_Q6_K:
12607
+ {
12608
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_q6_K, data, nb);
12609
+ } break;
12610
+ case GGML_TYPE_Q8_K:
12611
+ {
12612
+ const block_q8_K * q = (const block_q8_K *) data;
12613
+ for (size_t i = 0; i < nb; ++i) {
12614
+ if (!validate_float(q[i].d, i)) {
12615
+ return false;
12616
+ }
12617
+ }
12618
+ } break;
12619
+ case GGML_TYPE_IQ1_S:
12620
+ {
12621
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb);
12622
+ } break;
12623
+ case GGML_TYPE_IQ1_M:
12624
+ {
12625
+ const block_iq1_m * q = (const block_iq1_m *) data;
12626
+ for (size_t i = 0; i < nb; ++i) {
12627
+ #if QK_K == 64
12628
+ if (!validate_fp16(q[i].d, i)) {
12629
+ return false;
12630
+ }
12631
+ #else
12632
+ iq1m_scale_t scale;
12633
+ const uint16_t * sc = (const uint16_t *)q[i].scales;
12634
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
12635
+ if (!validate_fp16(scale.f16, i)) {
12636
+ return false;
12637
+ }
12638
+ #endif
12639
+ }
12640
+ } break;
12641
+ case GGML_TYPE_IQ2_XXS:
12642
+ {
12643
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xxs, data, nb);
12644
+ } break;
12645
+ case GGML_TYPE_IQ2_XS:
12646
+ {
12647
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xs, data, nb);
12648
+ } break;
12649
+ case GGML_TYPE_IQ2_S:
12650
+ {
12651
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_s, data, nb);
12652
+ } break;
12653
+ case GGML_TYPE_IQ3_XXS:
12654
+ {
12655
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_xxs, data, nb);
12656
+ } break;
12657
+
12658
+ case GGML_TYPE_IQ3_S:
12659
+ {
12660
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_s, data, nb);
12661
+ } break;
12662
+ case GGML_TYPE_IQ4_XS:
12663
+ #if QK_K != 64
12664
+ {
12665
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_xs, data, nb);
12666
+ } break;
12667
+ #endif
12668
+ // with QK_K == 64, iq4_xs is iq4_nl
12669
+ case GGML_TYPE_IQ4_NL:
12670
+ {
12671
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
12672
+ } break;
12673
+ case GGML_TYPE_I8:
12674
+ case GGML_TYPE_I16:
12675
+ case GGML_TYPE_I32:
12676
+ case GGML_TYPE_I64:
12677
+ // nothing to validate
12678
+ break;
12679
+ default:
12680
+ {
12681
+ fprintf(stderr, "%s: invalid type %d\n", __func__, type);
12682
+ return false;
12683
+ }
12684
+ }
12685
+
12686
+ return true;
12687
+ }
@@ -8330,24 +8330,26 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
8330
8330
  const int blocks_per_row = ncols / qk;
8331
8331
  const int blocks_per_warp = vdr * WARP_SIZE / qi;
8332
8332
 
8333
- // partial sum for each thread
8333
+ const int qi_vdr = (qi / vdr); // N_threads processing 1 qk block
8334
+
8335
+ // partial sum for each thread
8334
8336
  float tmp = 0.0f;
8335
8337
 
8336
8338
  const block_q_t * x = (const block_q_t *) vx;
8337
8339
  const block_q8_1 * y = (const block_q8_1 *) vy;
8338
8340
 
8339
- for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
8341
+ for (int i = item_ct1.get_local_id(2) / qi_vdr; i < blocks_per_row;
8340
8342
  i += blocks_per_warp) {
8341
- const int ibx = row*blocks_per_row + i; // x block index
8343
+ const int ibx = row * blocks_per_row + i; // x block index
8342
8344
 
8343
- const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
8345
+ const int iby = i * (qk / QK8_1); // y block index that aligns with ibx
8344
8346
 
8345
- const int iqs =
8346
- vdr *
8347
- (item_ct1.get_local_id(2) %
8348
- (qi / vdr)); // x block quant index when casting the quants to int
8347
+ const int iqs =
8348
+ vdr *
8349
+ (item_ct1.get_local_id(2) -
8350
+ i * qi_vdr); // x block quant index when casting the quants to int
8349
8351
 
8350
- tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
8352
+ tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
8351
8353
  }
8352
8354
 
8353
8355
  // sum up partial sums and write back result
@@ -13416,11 +13418,16 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
13416
13418
  version += std::to_string(prop.get_minor_version());
13417
13419
 
13418
13420
  device_type = std::regex_replace(device_type, std::regex("ext_oneapi_"), "");
13421
+ std::string name = std::string(prop.get_name());
13422
+ name = std::regex_replace(name, std::regex("\\(R\\)"), "");
13423
+ name = std::regex_replace(name, std::regex("\\(TM\\)"), "");
13419
13424
 
13420
- fprintf(stderr, "|%2d|%18s|%45s|%10s|%11d|%8d|%7d|%15lu|\n", id, device_type.c_str(),
13421
- prop.get_name(), version.c_str(), prop.get_max_compute_units(),
13425
+ auto global_mem_size = prop.get_global_mem_size()/1000000;
13426
+
13427
+ fprintf(stderr, "|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(),
13428
+ name.c_str(), version.c_str(), prop.get_max_compute_units(),
13422
13429
  prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
13423
- prop.get_global_mem_size());
13430
+ global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
13424
13431
  }
13425
13432
 
13426
13433
  void ggml_backend_sycl_print_sycl_devices() {
@@ -13428,9 +13435,10 @@ void ggml_backend_sycl_print_sycl_devices() {
13428
13435
  int device_count = dpct::dev_mgr::instance().device_count();
13429
13436
  std::map<std::string, size_t> DeviceNums;
13430
13437
  fprintf(stderr, "found %d SYCL devices:\n", device_count);
13431
- fprintf(stderr, "| | | |Compute |Max compute|Max work|Max sub| |\n");
13432
- fprintf(stderr, "|ID| Device Type| Name|capability|units |group |group |Global mem size|\n");
13433
- fprintf(stderr, "|--|------------------|---------------------------------------------|----------|-----------|--------|-------|---------------|\n");
13438
+ fprintf(stderr, "| | | | |Max | |Max |Global | |\n");
13439
+ fprintf(stderr, "| | | | |compute|Max work|sub |mem | |\n");
13440
+ fprintf(stderr, "|ID| Device Type| Name|Version|units |group |group|size | Driver version|\n");
13441
+ fprintf(stderr, "|--|-------------------|---------------------------------------|-------|-------|--------|-----|-------|---------------------|\n");
13434
13442
  for (int id = 0; id < device_count; ++id) {
13435
13443
  sycl::device device = dpct::dev_mgr::instance().get_device(id);
13436
13444
  sycl::backend backend = device.get_backend();
@@ -14738,7 +14746,12 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
14738
14746
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
14739
14747
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
14740
14748
 
14749
+ const ggml_tensor * src2 = dst->src[2];
14750
+
14751
+ #pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
14752
+ #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
14741
14753
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
14754
+ GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
14742
14755
 
14743
14756
  const int64_t ne00 = src0->ne[0];
14744
14757
  const int64_t nrows_x = ggml_nrows(src0);
@@ -14754,7 +14767,6 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
14754
14767
  float * src2_dd = nullptr;
14755
14768
  sycl_pool_alloc<float> src2_f;
14756
14769
 
14757
- ggml_tensor * src2 = dst->src[2];
14758
14770
  const bool use_src2 = src2 != nullptr;
14759
14771
 
14760
14772
  if (use_src2) {