@fugood/llama.node 1.4.2 → 1.4.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. package/CMakeLists.txt +1 -1
  2. package/lib/binding.js +3 -0
  3. package/lib/binding.ts +10 -0
  4. package/lib/index.js +9 -0
  5. package/lib/index.ts +10 -0
  6. package/package.json +15 -15
  7. package/scripts/llama.cpp.patch +25 -11
  8. package/src/LlamaContext.cpp +24 -0
  9. package/src/LlamaContext.h +3 -0
  10. package/src/llama.cpp/CMakeLists.txt +21 -6
  11. package/src/llama.cpp/common/CMakeLists.txt +6 -0
  12. package/src/llama.cpp/common/arg.cpp +83 -22
  13. package/src/llama.cpp/common/chat-parser.cpp +40 -0
  14. package/src/llama.cpp/common/chat-peg-parser.cpp +110 -0
  15. package/src/llama.cpp/common/chat-peg-parser.h +105 -0
  16. package/src/llama.cpp/common/chat.cpp +40 -29
  17. package/src/llama.cpp/common/chat.h +10 -1
  18. package/src/llama.cpp/common/common.cpp +70 -7
  19. package/src/llama.cpp/common/common.h +23 -5
  20. package/src/llama.cpp/common/download.cpp +18 -8
  21. package/src/llama.cpp/common/download.h +3 -1
  22. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  23. package/src/llama.cpp/common/log.cpp +18 -27
  24. package/src/llama.cpp/common/log.h +19 -12
  25. package/src/llama.cpp/common/peg-parser.cpp +1712 -0
  26. package/src/llama.cpp/common/peg-parser.h +459 -0
  27. package/src/llama.cpp/common/unicode.cpp +64 -0
  28. package/src/llama.cpp/common/unicode.h +22 -0
  29. package/src/llama.cpp/ggml/CMakeLists.txt +52 -48
  30. package/src/llama.cpp/ggml/include/ggml-rpc.h +1 -2
  31. package/src/llama.cpp/ggml/include/ggml-zendnn.h +22 -0
  32. package/src/llama.cpp/ggml/include/ggml.h +29 -2
  33. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -4
  34. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  35. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +0 -2
  36. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +10 -13
  37. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  38. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +51 -125
  39. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  40. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +98 -12
  41. package/src/llama.cpp/src/CMakeLists.txt +1 -0
  42. package/src/llama.cpp/src/llama-arch.cpp +30 -1
  43. package/src/llama.cpp/src/llama-arch.h +3 -0
  44. package/src/llama.cpp/src/llama-graph.cpp +3 -6
  45. package/src/llama.cpp/src/llama-hparams.h +2 -2
  46. package/src/llama.cpp/src/llama-impl.h +1 -1
  47. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  48. package/src/llama.cpp/src/llama-model.cpp +54 -6
  49. package/src/llama.cpp/src/llama-quant.cpp +0 -29
  50. package/src/llama.cpp/src/llama-vocab.cpp +1 -2
  51. package/src/llama.cpp/src/models/deepseek2.cpp +18 -0
  52. package/src/llama.cpp/src/models/mistral3.cpp +160 -0
  53. package/src/llama.cpp/src/models/models.h +4 -0
  54. package/src/llama.cpp/src/unicode.cpp +2 -2

package/src/llama.cpp/ggml/CMakeLists.txt

@@ -175,11 +175,6 @@ option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requi
 set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
 set(GGML_CPU_POWERPC_CPUTYPE "" CACHE STRING "ggml: CPU type for PowerPC")
 
-
-if (MINGW)
-    set(GGML_WIN_VER "0xA00" CACHE STRING "ggml: Windows version")
-endif()
-
 # ggml core
 set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism")
 option(GGML_CPU "ggml: enable CPU backend" ON)
@@ -226,7 +221,7 @@ option(GGML_WEBGPU "ggml: use WebGPU"
 option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF)
 option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU)" OFF)
 option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF)
-
+option(GGML_WEBGPU_JSPI "ggml: use JSPI for WebGPU" ON)
 option(GGML_ZDNN "ggml: use zDNN" OFF)
 option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
 option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
@@ -258,6 +253,9 @@ option(GGML_HEXAGON "ggml: enable Hexagon backend"
 # toolchain for vulkan-shaders-gen
 set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")
 
+option(GGML_ZENDNN "ggml: use ZenDNN" OFF)
+option(ZENDNN_ROOT "ggml: path to ZenDNN installation" "")
+
 # extra artifacts
 option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE})
 option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE})
@@ -319,6 +317,7 @@ set(GGML_PUBLIC_HEADERS
     include/ggml-sycl.h
     include/ggml-vulkan.h
     include/ggml-webgpu.h
+    include/ggml-zendnn.h
     include/gguf.h)
 
 set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
@@ -408,62 +407,67 @@ if (MSVC)
         /wd4996 # Disable POSIX deprecation warnings
         /wd4702 # Unreachable code warnings
     )
-    function(disable_msvc_warnings target_name)
+    set(MSVC_COMPILE_OPTIONS
+        "$<$<COMPILE_LANGUAGE:C>:/utf-8>"
+        "$<$<COMPILE_LANGUAGE:CXX>:/utf-8>"
+    )
+    function(configure_msvc_target target_name)
         if(TARGET ${target_name})
             target_compile_options(${target_name} PRIVATE ${MSVC_WARNING_FLAGS})
+            target_compile_options(${target_name} PRIVATE ${MSVC_COMPILE_OPTIONS})
         endif()
     endfunction()
 
-    disable_msvc_warnings(ggml-base)
-    disable_msvc_warnings(ggml)
-    disable_msvc_warnings(ggml-cpu)
-    disable_msvc_warnings(ggml-cpu-x64)
-    disable_msvc_warnings(ggml-cpu-sse42)
-    disable_msvc_warnings(ggml-cpu-sandybridge)
-    disable_msvc_warnings(ggml-cpu-haswell)
-    disable_msvc_warnings(ggml-cpu-skylakex)
-    disable_msvc_warnings(ggml-cpu-icelake)
-    disable_msvc_warnings(ggml-cpu-alderlake)
+    configure_msvc_target(ggml-base)
+    configure_msvc_target(ggml)
+    configure_msvc_target(ggml-cpu)
+    configure_msvc_target(ggml-cpu-x64)
+    configure_msvc_target(ggml-cpu-sse42)
+    configure_msvc_target(ggml-cpu-sandybridge)
+    configure_msvc_target(ggml-cpu-haswell)
+    configure_msvc_target(ggml-cpu-skylakex)
+    configure_msvc_target(ggml-cpu-icelake)
+    configure_msvc_target(ggml-cpu-alderlake)
 
     if (GGML_BUILD_EXAMPLES)
-        disable_msvc_warnings(common-ggml)
-        disable_msvc_warnings(common)
+        configure_msvc_target(common-ggml)
+        configure_msvc_target(common)
 
-        disable_msvc_warnings(mnist-common)
-        disable_msvc_warnings(mnist-eval)
-        disable_msvc_warnings(mnist-train)
+        configure_msvc_target(mnist-common)
+        configure_msvc_target(mnist-eval)
+        configure_msvc_target(mnist-train)
 
-        disable_msvc_warnings(gpt-2-ctx)
-        disable_msvc_warnings(gpt-2-alloc)
-        disable_msvc_warnings(gpt-2-backend)
-        disable_msvc_warnings(gpt-2-sched)
-        disable_msvc_warnings(gpt-2-quantize)
-        disable_msvc_warnings(gpt-2-batched)
+        configure_msvc_target(gpt-2-ctx)
+        configure_msvc_target(gpt-2-alloc)
+        configure_msvc_target(gpt-2-backend)
+        configure_msvc_target(gpt-2-sched)
+        configure_msvc_target(gpt-2-quantize)
+        configure_msvc_target(gpt-2-batched)
 
-        disable_msvc_warnings(gpt-j)
-        disable_msvc_warnings(gpt-j-quantize)
+        configure_msvc_target(gpt-j)
+        configure_msvc_target(gpt-j-quantize)
 
-        disable_msvc_warnings(magika)
-        disable_msvc_warnings(yolov3-tiny)
-        disable_msvc_warnings(sam)
+        configure_msvc_target(magika)
+        configure_msvc_target(yolov3-tiny)
+        configure_msvc_target(sam)
 
-        disable_msvc_warnings(simple-ctx)
-        disable_msvc_warnings(simple-backend)
+        configure_msvc_target(simple-ctx)
+        configure_msvc_target(simple-backend)
     endif()
 
     if (GGML_BUILD_TESTS)
-        disable_msvc_warnings(test-mul-mat)
-        disable_msvc_warnings(test-arange)
-        disable_msvc_warnings(test-backend-ops)
-        disable_msvc_warnings(test-cont)
-        disable_msvc_warnings(test-conv-transpose)
-        disable_msvc_warnings(test-conv-transpose-1d)
-        disable_msvc_warnings(test-conv1d)
-        disable_msvc_warnings(test-conv2d)
-        disable_msvc_warnings(test-conv2d-dw)
-        disable_msvc_warnings(test-customop)
-        disable_msvc_warnings(test-dup)
-        disable_msvc_warnings(test-opt)
-        disable_msvc_warnings(test-pool)
+        configure_msvc_target(test-mul-mat)
+        configure_msvc_target(test-arange)
+        configure_msvc_target(test-backend-ops)
+        configure_msvc_target(test-cont)
+        configure_msvc_target(test-conv-transpose)
+        configure_msvc_target(test-conv-transpose-1d)
+        configure_msvc_target(test-conv1d)
+        configure_msvc_target(test-conv2d)
+        configure_msvc_target(test-conv2d-dw)
+        configure_msvc_target(test-customop)
+        configure_msvc_target(test-dup)
+        configure_msvc_target(test-opt)
+        configure_msvc_target(test-pool)
     endif ()
 endif()

package/src/llama.cpp/ggml/include/ggml-rpc.h

@@ -1,6 +1,5 @@
 #pragma once
 
-#include "ggml.h"
 #include "ggml-backend.h"
 
 #ifdef __cplusplus
@@ -8,7 +7,7 @@ extern "C" {
 #endif
 
 #define RPC_PROTO_MAJOR_VERSION 3
-#define RPC_PROTO_MINOR_VERSION 5
+#define RPC_PROTO_MINOR_VERSION 6
 #define RPC_PROTO_PATCH_VERSION 0
 #define GGML_RPC_MAX_SERVERS 16
 

package/src/llama.cpp/ggml/include/ggml-zendnn.h (new file)

@@ -0,0 +1,22 @@
+#pragma once
+
+#include "ggml-backend.h"
+#include "ggml.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// backend API
+GGML_BACKEND_API ggml_backend_t ggml_backend_zendnn_init(void);
+
+GGML_BACKEND_API bool ggml_backend_is_zendnn(ggml_backend_t backend);
+
+// number of threads used for zendnn operations
+GGML_BACKEND_API void ggml_backend_zendnn_set_n_threads(ggml_backend_t backend_zendnn, int n_threads);
+
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_zendnn_reg(void);
+
+#ifdef __cplusplus
+}
+#endif
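
The new header only declares the ZenDNN entry points. As a rough orientation, a minimal consumer of this API might look like the sketch below; the thread count, error handling, and the surrounding ggml-backend calls are illustrative assumptions, not part of this diff.

    #include "ggml-zendnn.h"   // declarations added in this release
    #include <stdio.h>

    int main(void) {
        // Initialize the ZenDNN backend (assumed to return NULL when unavailable).
        ggml_backend_t backend = ggml_backend_zendnn_init();
        if (backend == NULL) {
            fprintf(stderr, "ZenDNN backend not available\n");
            return 1;
        }
        // Confirm the backend type and configure its thread count via the new API.
        if (ggml_backend_is_zendnn(backend)) {
            ggml_backend_zendnn_set_n_threads(backend, 8);
        }
        // ... build a ggml graph and run it with ggml_backend_graph_compute() ...
        ggml_backend_free(backend);   // from ggml-backend.h
        return 0;
    }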

package/src/llama.cpp/ggml/include/ggml.h

@@ -204,6 +204,10 @@
 #    define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
 #endif
 
+#if defined(_WIN32) && !defined(_WIN32_WINNT)
+#    define _WIN32_WINNT 0x0A00
+#endif
+
 #include <stdbool.h>
 #include <stddef.h>
 #include <stdint.h>
@@ -2148,7 +2152,8 @@ extern "C" {
     };
 
     enum ggml_scale_flag {
-        GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8)
+        GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8),
+        GGML_SCALE_FLAG_ANTIALIAS = (1 << 9),
     };
 
     // interpolate
@@ -2191,6 +2196,15 @@ extern "C" {
             int p2,
             int p3);
 
+    // pad each dimension with values on the other side of the torus (looping around)
+    GGML_API struct ggml_tensor * ggml_pad_circular(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            int p0,
+            int p1,
+            int p2,
+            int p3);
+
     GGML_API struct ggml_tensor * ggml_pad_ext(
             struct ggml_context * ctx,
             struct ggml_tensor * a,
@@ -2204,6 +2218,19 @@ extern "C" {
             int rp3
             );
 
+    // pad each dimension with values on the other side of the torus (looping around)
+    GGML_API struct ggml_tensor * ggml_pad_ext_circular(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            int lp0,
+            int rp0,
+            int lp1,
+            int rp1,
+            int lp2,
+            int rp2,
+            int lp3,
+            int rp3);
+
     // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
     GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
             struct ggml_context * ctx,
@@ -2278,7 +2305,7 @@ extern "C" {
             float stop,
             float step);
 
-#define GGML_KQ_MASK_PAD 64
+#define GGML_KQ_MASK_PAD 1
 
     // q: [n_embd_k, n_batch, n_head, ne3 ]
    // k: [n_embd_k, n_kv, n_head_kv, ne3 ]

package/src/llama.cpp/ggml/src/CMakeLists.txt

@@ -127,10 +127,6 @@ if (NOT MSVC)
     endif()
 endif()
 
-if (MINGW)
-    add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
-endif()
-
 #
 # POSIX conformance
 #
@@ -444,6 +440,7 @@ ggml_add_backend(WebGPU)
 ggml_add_backend(zDNN)
 ggml_add_backend(OpenCL)
 ggml_add_backend(Hexagon)
+ggml_add_backend(ZenDNN)
 
 foreach (target ggml-base ggml)
     target_include_directories(${target} PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../include> $<INSTALL_INTERFACE:include>)

package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp

@@ -8,6 +8,10 @@
 #include <sys/sysctl.h>
 #endif
 
+#if !defined(HWCAP2_SVE2)
+#define HWCAP2_SVE2 (1 << 1)
+#endif
+
 #if !defined(HWCAP2_I8MM)
 #define HWCAP2_I8MM (1 << 13)
 #endif
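
The HWCAP2_SVE2 fallback mirrors the HWCAP2_I8MM one just below it: it lets the feature bit be tested even when the toolchain's headers predate it. A hedged sketch of the kind of runtime check this enables on linux/aarch64 (the wrapper function is illustrative, not code from this package):

    #if defined(__linux__) && defined(__aarch64__)
    #include <sys/auxv.h>

    #if !defined(HWCAP2_SVE2)
    #define HWCAP2_SVE2 (1 << 1)   // same fallback value this patch defines
    #endif

    // Hypothetical helper: non-zero when the kernel reports SVE2 for this process.
    static int cpu_has_sve2(void) {
        return (getauxval(AT_HWCAP2) & HWCAP2_SVE2) != 0;
    }
    #endif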

package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp

@@ -505,7 +505,6 @@ void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     constexpr int blocklen = 8;
 
     assert(n % qk == 0);
-    assert(nr % 4 == 0);
     assert(nc % ncols_interleaved == 0);
 
     UNUSED(nb);
@@ -645,7 +644,6 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
     constexpr int blocklen = 8;
 
     assert(n % qk == 0);
-    assert(nr % 4 == 0);
     assert(nc % ncols_interleaved == 0);
 
     UNUSED(nb);

package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c

@@ -683,22 +683,14 @@ bool ggml_is_numa(void) {
 }
 
 #if defined(__ARM_ARCH)
-
-#if defined(__linux__) && defined(__aarch64__)
-#include <sys/auxv.h>
-#endif
-
-static void ggml_init_arm_arch_features(void) {
 #if defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
-#if defined(__linux__)
-    ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
+#include <arm_sve.h>
+static void ggml_init_arm_arch_features(void) {
+    ggml_arm_arch_features.sve_cnt = svcntb();
+}
 #else
-    // TODO: add support of SVE for non-linux systems
-    #error "TODO: SVE is not supported on this platform. To use SVE, sve_cnt needs to be initialized here."
+static void ggml_init_arm_arch_features(void) {}
 #endif
-#endif
-}
-
 #endif // __ARM_ARCH
 
 struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
@@ -2706,6 +2698,11 @@ struct ggml_cplan ggml_graph_plan(
         n_threads = threadpool ? threadpool->n_threads_max : GGML_DEFAULT_N_THREADS;
     }
 
+#if defined(__EMSCRIPTEN__) && !defined(__EMSCRIPTEN_PTHREADS__)
+    // Emscripten without pthreads support can only use a single thread
+    n_threads = 1;
+#endif
+
     size_t work_size = 0;
 
     struct ggml_cplan cplan;

package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h (new file)

@@ -0,0 +1,333 @@
+#pragma once
+
+typedef vector unsigned char vec_t;
+typedef __vector_quad acc_t;
+
+template <typename TA>
+class tinyBLAS_Q0_PPC {
+  public:
+    tinyBLAS_Q0_PPC(int64_t k,
+                    const TA *A, int64_t lda,
+                    const block_q8_0 *B, int64_t ldb,
+                    float *C, int64_t ldc,
+                    int ith, int nth);
+
+    void matmul(int64_t m, int64_t n);
+    void matmul_tiled_q0(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
+        vec_t A_pack[mc*kc*2];
+        vec_t B_pack[nc*kc*2];
+        int comparray[mc*kc];
+        constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
+        int64_t ytiles = m / mc;
+        int64_t xtiles = n / nc;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles) {
+            end = tiles;
+        }
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = (job / xtiles) * mc;
+            int64_t jj = (job % xtiles) * nc;
+            for (int64_t kk = 0; kk < k; kk += kc) {
+                if constexpr(is_Ablock_q4) {
+                    packNormalInt4_large(A + ii*lda + kk, lda, mc, 4, (int8_t*)A_pack, comparray);
+                } else {
+                    packNormal_large<int8_t, vector signed char>(A + ii*lda + kk, lda, mc, 8, (int8_t*)A_pack, false, comparray);
+                }
+                packNormal_large<uint8_t, vector unsigned char>(B + jj*ldb + kk, ldb, nc, 8, (uint8_t*)B_pack, true);
+                KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack, comparray);
+            }
+        }
+    }
+
+  private:
+    inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
+        for (int I = 0; I < RM; I++) {
+            for (int J = 0; J < RN; J++) {
+                *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
+            }
+        }
+    }
+
+    inline void add_save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
+        for (int I = 0; I < RM; I++) {
+            for (int J = 0; J < RN; J++) {
+                float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
+                *c_ptr += *((float*)&fin_res[idx+I]+J);
+            }
+        }
+    }
+
+    template<typename ArrayType>
+    inline void compute(acc_t* ACC, int c_idx, int s_idx, ArrayType& comparray, vector float* vs, vector float* fin_res) {
+        vector signed int vec_C[4];
+        vector float CA[4] = {0};
+        vector float res[4] = {0};
+        __builtin_mma_disassemble_acc(vec_C, ACC);
+        for (int i = 0; i < 4; i++) {
+            CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
+            res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
+            fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
+        }
+    }
+
+    inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
+        const vector signed char lowMask = vec_splats((signed char)0xF);
+        const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+        const vector signed char v8 = vec_splats((signed char)0x8);
+        vector signed int vsum = {0};
+        vector signed int vsum2 = {0};
+        c[0] = vec_and(c[1], lowMask);
+        c[1] = vec_sr(c[1], v4);
+        c[0] = vec_sub(c[0], v8);
+        c[1] = vec_sub(c[1], v8);
+        vsum = vec_sum4s(c[0], vsum);
+        vsum2 = vec_sum4s(c[1], vsum2);
+        vsum = vec_add(vsum, vsum2);
+        *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+    }
+
+    template <typename V1, typename V2>
+    inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
+        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
+        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
+        V2 t1, t2, t3, t4, t5, t6, t7, t8;
+        vector unsigned char xor_vector;
+        uint8_t flip_vec = 0x80;
+        xor_vector = vec_splats(flip_vec);
+        t1 = vec_perm(s1, s2, swiz1);
+        t2 = vec_perm(s1, s2, swiz2);
+        t3 = vec_perm(s3, s4, swiz1);
+        t4 = vec_perm(s3, s4, swiz2);
+        t5 = vec_perm(t1, t3, swiz3);
+        t6 = vec_perm(t1, t3, swiz4);
+        t7 = vec_perm(t2, t4, swiz3);
+        t8 = vec_perm(t2, t4, swiz4);
+        if (flip == true) {
+            t5 = vec_xor(t5, xor_vector);
+            t6 = vec_xor(t6, xor_vector);
+            t7 = vec_xor(t7, xor_vector);
+            t8 = vec_xor(t8, xor_vector);
+        }
+        vec_xst(t5, 0, vecOffset);
+        vec_xst(t6, 0, vecOffset+16);
+        vec_xst(t7, 0, vecOffset+32);
+        vec_xst(t8, 0, vecOffset+48);
+    }
+
+    template<int RM, int RN>
+    inline void kernel(int64_t ii, int64_t jj) {
+        if constexpr(RM == 4 && RN == 8) {
+            KERNEL_4x8(ii,jj);
+        } else if constexpr(RM == 8 && RN == 4) {
+            KERNEL_8x4(ii,jj);
+        } else if constexpr(RM == 8 && RN == 8) {
+            KERNEL_8x8(ii,jj);
+        } else {
+            assert(false && "RN/RM values not supported");
+        }
+    }
+    template<int size>
+    void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray);
+    template<typename VA, typename VB>
+    void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip);
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n);
+    void KERNEL_4x8(int64_t ii, int64_t jj);
+    void KERNEL_8x4(int64_t ii, int64_t jj);
+    void KERNEL_8x8(int64_t ii, int64_t jj);
+    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN);
+    template <int RM, int RN>
+    void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n);
+
+    void compute_scale(int64_t ii, int64_t jj, int blk, vector float* vs){
+        for (int I = 0; I<8; I++) {
+            float a_scale = unhalf((A+((ii+I)*lda)+blk)->d);
+            for (int J = 0; J<4; J++) {
+                *((float*)&vs[I]+J) = (a_scale * unhalf((B+((jj+J)*ldb)+blk)->d));
+                *((float*)&vs[I+8]+J) = (a_scale * unhalf((B+((jj+J+4)*ldb)+blk)->d));
+            }
+        }
+    }
+
+    inline void process_q8_elements(const int8_t *qs, int *ca) {
+        vector signed char c1 = vec_xl(0, qs);
+        vector signed char c2 = vec_xl(16, qs);
+        vector signed int vsum1 = {0};
+        vector signed int vsum2 = {0};
+        vsum1 = vec_sum4s(c1, vsum1);
+        vsum2 = vec_sum4s(c2, vsum2);
+        vector signed int vsum = vec_add(vsum1, vsum2);
+        *ca = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+    }
+
+    template<typename VA, typename VB>
+    void packNormal_large(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip, int* comparray=nullptr) {
+        int64_t i, j;
+        block_q8_0 *aoffset = NULL;
+        VA *vecOffset = NULL;
+        block_q8_0* aoffsets[8];
+        __vector_pair arr[8];
+        VB c[8][2] = {0};
+        VB c1[8] = {0}; VB c2[8] = {0};
+        aoffset = const_cast<block_q8_0*>(a);
+        vecOffset = vec;
+        j = (rows >> 3);
+        int index = 0;
+        if (j > 0) {
+            do {
+                for (int it = 0; it < 8; it++)
+                    aoffsets[it] = aoffset + it*lda;
+                aoffset += 8 * lda;
+                for (int blk = 0; blk < kc; blk++) {
+                    for (int it = 0; it < 8; it++) {
+                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)(aoffsets[it]+blk)->qs);
+                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+                        c1[it] = c[it][0];
+                        c2[it] = c[it][1];
+                        if (comparray){
+                            process_q8_elements((aoffsets[it]+ blk)->qs, &comparray[index + 8*blk + it]);
+                        }
+                    }
+                    vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
+                    vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+                    vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
+                    vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
+                    vecOffset += 256;
+                }
+                j--;
+                index += 8*kc;
+            } while(j > 0);
+        }
+
+    }
+
+    void packNormalInt4_large(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, int*comparray) {
+        int64_t i, j;
+        TA *aoffset = NULL;
+        int8_t *vecOffset = NULL;
+        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
+        vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
+        aoffset = const_cast<TA*>(a);
+        vecOffset = vec;
+        int index = 0;
+        j = (rows >> 3);
+        if (j > 0) {
+            do {
+                aoffset1 = aoffset;
+                aoffset2 = aoffset1 + lda;
+                aoffset3 = aoffset2 + lda;
+                aoffset4 = aoffset3 + lda;
+                aoffset5 = aoffset4 + lda;
+                aoffset6 = aoffset5 + lda;
+                aoffset7 = aoffset6 + lda;
+                aoffset8 = aoffset7 + lda;
+                aoffset += 8 * lda;
+                for (int blk = 0; blk < kc; blk++) {
+                    c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset1+blk)->qs));
+                    c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset2+blk)->qs));
+                    c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset3+blk)->qs));
+                    c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset4+blk)->qs));
+                    c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset5+blk)->qs));
+                    c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset6+blk)->qs));
+                    c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset7+blk)->qs));
+                    c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset8+blk)->qs));
+
+                    process_q4_elements(c1, &comparray[index + 8*blk+0]);
+                    process_q4_elements(c2, &comparray[index + 8*blk+1]);
+                    process_q4_elements(c3, &comparray[index + 8*blk+2]);
+                    process_q4_elements(c4, &comparray[index + 8*blk+3]);
+                    process_q4_elements(c5, &comparray[index + 8*blk+4]);
+                    process_q4_elements(c6, &comparray[index + 8*blk+5]);
+                    process_q4_elements(c7, &comparray[index + 8*blk+6]);
+                    process_q4_elements(c8, &comparray[index + 8*blk+7]);
+                    vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
+                    vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
+                    vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
+                    vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
+                    vecOffset += 256;
+                }
+                j--;
+                index += 8*kc;
+            } while (j > 0);
+        }
+    }
+
+    void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t *vec_A, vec_t *vec_B, int *comparray) {
+        acc_t acc[8];
+        for (int i = 0; i < mc ; i += 8) {
+            for (int j = 0; j < nc; j += 8) {
+                vector float fin_res[16] = {0};
+                vector float vs[16] = {0};
+                for (int64_t kk = 0; kk < kc; kk+=2) {
+                    for (int x = 0; x < 8; x++) {
+                        __builtin_mma_xxsetaccz(&acc[x]);
+                    }
+                    int A_block_idx = (i/8)*(16*kc) + kk*16;
+                    int B_block_idx = (j/8)*(16*kc)+ kk*16;
+                    vec_t *A_block = &vec_A[A_block_idx];
+                    vec_t *B_block = &vec_B[B_block_idx];
+                    for (int x = 0; x < 8; x++) {
+                        __builtin_mma_xvi8ger4pp(&acc[0], A_block[x], B_block[x]);
+                        __builtin_mma_xvi8ger4pp(&acc[1], A_block[x + 8], B_block[x]);
+                        __builtin_mma_xvi8ger4pp(&acc[2], A_block[x], B_block[x+8]);
+                        __builtin_mma_xvi8ger4pp(&acc[3], A_block[x+8], B_block[x+8]);
+                    }
+                    compute_scale(ii+i, jj+j, l+kk, vs);
+                    int c_index = (i/8)*(8*kc)+ kk*8;
+                    int* c_block = &comparray[c_index];
+                    compute(&acc[0], 0, 0, c_block, vs, fin_res);
+                    compute(&acc[1], 4, 4, c_block, vs, fin_res);
+                    compute(&acc[2], 0, 8, c_block, vs, fin_res);
+                    compute(&acc[3], 4, 12, c_block, vs, fin_res);
+
+                    A_block_idx = (i/8)*(16*kc) + (kk+1)*16;
+                    B_block_idx = (j/8)*(16*kc)+ (kk+1)*16;
+                    A_block = &vec_A[A_block_idx];
+                    B_block = &vec_B[B_block_idx];
+                    for (int x = 0; x < 8; x++) {
+                        __builtin_mma_xvi8ger4pp(&acc[4], A_block[x], B_block[x]);
+                        __builtin_mma_xvi8ger4pp(&acc[5], A_block[x + 8], B_block[x]);
+                        __builtin_mma_xvi8ger4pp(&acc[6], A_block[x], B_block[x+8]);
+                        __builtin_mma_xvi8ger4pp(&acc[7], A_block[x+8], B_block[x+8]);
+                    }
+                    compute_scale(ii+i, jj+j, l+kk+1, vs);
+                    c_index = (i/8)*(8*kc)+ (kk+1)*8;
+                    c_block = &comparray[c_index];
+                    compute(&acc[4], 0, 0, c_block, vs, fin_res);
+                    compute(&acc[5], 4, 4, c_block, vs, fin_res);
+                    compute(&acc[6], 0, 8, c_block, vs, fin_res);
+                    compute(&acc[7], 4, 12, c_block, vs, fin_res);
+
+                }
+                if (l == 0) {
+                    save_res(ii+i, jj+j, 0, fin_res);
+                    save_res(ii+i+4, jj+j, 4, fin_res);
+                    save_res(ii+i, jj+j+4, 8, fin_res);
+                    save_res(ii+i+4, jj+j+4, 12, fin_res);
+                } else {
+                    add_save_res(ii+i, jj+j, 0, fin_res);
+                    add_save_res(ii+i+4, jj+j, 4, fin_res);
+                    add_save_res(ii+i, jj+j+4, 8, fin_res);
+                    add_save_res(ii+i+4, jj+j+4, 12, fin_res);
+                }
+            }
+        }
+    }
+
+    const TA *const A;
+    const block_q8_0 *const B;
+    float *C;
+    const int64_t k;
+    int64_t kc;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};