@fugood/llama.node 0.3.13 → 0.3.15

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/lib/binding.ts +1 -1
  18. package/package.json +1 -1
  19. package/src/LlamaContext.cpp +98 -76
  20. package/src/LlamaContext.h +1 -1
  21. package/src/common.hpp +1 -2
  22. package/src/llama.cpp/.github/workflows/build.yml +89 -10
  23. package/src/llama.cpp/.github/workflows/server.yml +2 -0
  24. package/src/llama.cpp/CMakeLists.txt +9 -1
  25. package/src/llama.cpp/cmake/common.cmake +2 -0
  26. package/src/llama.cpp/common/CMakeLists.txt +3 -3
  27. package/src/llama.cpp/common/arg.cpp +132 -13
  28. package/src/llama.cpp/common/chat.cpp +960 -266
  29. package/src/llama.cpp/common/chat.h +135 -0
  30. package/src/llama.cpp/common/common.cpp +33 -174
  31. package/src/llama.cpp/common/common.h +27 -67
  32. package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
  33. package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
  34. package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +37 -5
  35. package/src/llama.cpp/common/ngram-cache.cpp +1 -0
  36. package/src/llama.cpp/common/sampling.cpp +45 -7
  37. package/src/llama.cpp/common/speculative.cpp +10 -9
  38. package/src/llama.cpp/common/speculative.h +1 -1
  39. package/src/llama.cpp/docs/build.md +45 -7
  40. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
  41. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +4 -2
  42. package/src/llama.cpp/examples/embedding/embedding.cpp +2 -1
  43. package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
  44. package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
  45. package/src/llama.cpp/examples/imatrix/imatrix.cpp +3 -4
  46. package/src/llama.cpp/examples/infill/infill.cpp +2 -2
  47. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
  48. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +5 -5
  49. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  50. package/src/llama.cpp/examples/llava/clip.cpp +373 -107
  51. package/src/llama.cpp/examples/llava/clip.h +19 -3
  52. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
  53. package/src/llama.cpp/examples/llava/llava.cpp +4 -2
  54. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
  55. package/src/llama.cpp/examples/lookahead/lookahead.cpp +7 -6
  56. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  57. package/src/llama.cpp/examples/main/main.cpp +79 -34
  58. package/src/llama.cpp/examples/parallel/parallel.cpp +6 -5
  59. package/src/llama.cpp/examples/passkey/passkey.cpp +15 -14
  60. package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
  61. package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
  62. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
  63. package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
  64. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
  65. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
  66. package/src/llama.cpp/examples/run/run.cpp +196 -108
  67. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
  68. package/src/llama.cpp/examples/server/server.cpp +113 -101
  69. package/src/llama.cpp/examples/server/utils.hpp +94 -105
  70. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
  71. package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
  72. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  73. package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
  74. package/src/llama.cpp/examples/tts/tts.cpp +263 -151
  75. package/src/llama.cpp/ggml/CMakeLists.txt +14 -1
  76. package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
  77. package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
  78. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
  79. package/src/llama.cpp/ggml/include/ggml-cpu.h +3 -0
  80. package/src/llama.cpp/ggml/include/ggml.h +29 -1
  81. package/src/llama.cpp/ggml/src/CMakeLists.txt +15 -34
  82. package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
  83. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
  84. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
  85. package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
  86. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
  87. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -7
  88. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
  89. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +139 -16
  90. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
  91. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
  92. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
  93. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +151 -0
  94. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1546 -387
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1645 -113
  96. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +22 -0
  97. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
  98. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
  99. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  101. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +15 -2
  102. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
  103. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
  104. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
  105. package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
  106. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
  107. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +242 -0
  108. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -6
  109. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
  110. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +315 -138
  111. package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
  112. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
  113. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +5 -0
  114. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +2 -1
  115. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
  116. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +117 -36
  117. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
  118. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
  119. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
  120. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
  121. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +147 -16
  123. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
  124. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +307 -0
  125. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
  126. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +262 -746
  127. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
  128. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -78
  129. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
  130. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
  131. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +4 -1
  132. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  133. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  134. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
  135. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
  136. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +498 -188
  137. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
  138. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +16 -3
  139. package/src/llama.cpp/ggml/src/ggml.c +93 -5
  140. package/src/llama.cpp/include/llama.h +105 -27
  141. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
  142. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
  143. package/src/llama.cpp/requirements/requirements-all.txt +1 -0
  144. package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
  145. package/src/llama.cpp/requirements.txt +1 -0
  146. package/src/llama.cpp/src/CMakeLists.txt +5 -2
  147. package/src/llama.cpp/src/llama-adapter.cpp +19 -20
  148. package/src/llama.cpp/src/llama-adapter.h +11 -9
  149. package/src/llama.cpp/src/llama-arch.cpp +123 -16
  150. package/src/llama.cpp/src/llama-arch.h +19 -0
  151. package/src/llama.cpp/src/llama-batch.h +2 -2
  152. package/src/llama.cpp/src/llama-chat.cpp +1 -0
  153. package/src/llama.cpp/src/llama-context.cpp +2253 -1222
  154. package/src/llama.cpp/src/llama-context.h +214 -77
  155. package/src/llama.cpp/src/llama-cparams.h +1 -0
  156. package/src/llama.cpp/src/llama-grammar.cpp +182 -182
  157. package/src/llama.cpp/src/llama-grammar.h +12 -3
  158. package/src/llama.cpp/src/llama-graph.cpp +1662 -0
  159. package/src/llama.cpp/src/llama-graph.h +574 -0
  160. package/src/llama.cpp/src/llama-hparams.cpp +8 -0
  161. package/src/llama.cpp/src/llama-hparams.h +9 -0
  162. package/src/llama.cpp/src/llama-io.cpp +15 -0
  163. package/src/llama.cpp/src/llama-io.h +35 -0
  164. package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
  165. package/src/llama.cpp/src/llama-kv-cache.h +178 -109
  166. package/src/llama.cpp/src/llama-memory.cpp +1 -0
  167. package/src/llama.cpp/src/llama-memory.h +21 -0
  168. package/src/llama.cpp/src/llama-mmap.cpp +11 -1
  169. package/src/llama.cpp/src/llama-model.cpp +8230 -122
  170. package/src/llama.cpp/src/llama-model.h +34 -1
  171. package/src/llama.cpp/src/llama-quant.cpp +10 -1
  172. package/src/llama.cpp/src/llama-sampling.cpp +43 -10
  173. package/src/llama.cpp/src/llama-vocab.cpp +12 -0
  174. package/src/llama.cpp/src/llama.cpp +51 -9837
  175. package/src/llama.cpp/tests/test-backend-ops.cpp +247 -112
  176. package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
  177. package/src/llama.cpp/tests/test-chat.cpp +593 -395
  178. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
  179. package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
  180. package/src/llama.cpp/Sources/llama/llama.h +0 -4
  181. package/src/llama.cpp/common/chat.hpp +0 -55
  182. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
  183. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
  184. package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +0 -0
@@ -0,0 +1,17 @@
1
+ // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
2
+ // SPDX-License-Identifier: MIT
3
+ //
4
+
5
+ #pragma once
6
+
7
+ #include "ggml-alloc.h"
8
+
9
+ #ifdef __cplusplus
10
+ extern "C" {
11
+ #endif
12
+
13
+ ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void);
14
+
15
+ #ifdef __cplusplus
16
+ }
17
+ #endif
@@ -7,7 +7,7 @@ if (CUDAToolkit_FOUND)
7
7
 
8
8
  if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
9
9
  # native == GPUs available at build time
10
- # 52 == Maxwell, lowest CUDA 12 standard
10
+ # 50 == Maxwell, lowest CUDA 12 standard
11
11
  # 60 == P100, FP16 CUDA intrinsics
12
12
  # 61 == Pascal, __dp4a instruction (per-byte integer dot product)
13
13
  # 70 == V100, FP16 tensor cores
@@ -17,7 +17,7 @@ if (CUDAToolkit_FOUND)
17
17
  elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
18
18
  set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;80")
19
19
  else()
20
- set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75;80")
20
+ set(CMAKE_CUDA_ARCHITECTURES "50;61;70;75;80")
21
21
  endif()
22
22
  endif()
23
23
  message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
@@ -69,6 +69,10 @@ if (CUDAToolkit_FOUND)
69
69
  add_compile_definitions(GGML_CUDA_NO_VMM)
70
70
  endif()
71
71
 
72
+ if (NOT GGML_CUDA_FA)
73
+ add_compile_definitions(GGML_CUDA_NO_FA)
74
+ endif()
75
+
72
76
  if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
73
77
  add_compile_definitions(GGML_CUDA_F16)
74
78
  endif()
@@ -98,6 +102,15 @@ if (CUDAToolkit_FOUND)
98
102
 
99
103
  set(CUDA_FLAGS -use_fast_math)
100
104
 
105
+ if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
106
+ # Options are:
107
+ # - none (not recommended)
108
+ # - speed (nvcc's default)
109
+ # - balance
110
+ # - size
111
+ list(APPEND CUDA_FLAGS -compress-mode=${GGML_CUDA_COMPRESSION_MODE})
112
+ endif()
113
+
101
114
  if (GGML_FATAL_WARNINGS)
102
115
  list(APPEND CUDA_FLAGS -Werror all-warnings)
103
116
  endif()
@@ -112,7 +112,7 @@
112
112
  #define cudaGraphExecDestroy hipGraphExecDestroy
113
113
  #define cudaGraphLaunch hipGraphLaunch
114
114
  #define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
115
- #define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
115
+ #define cudaGraphExecUpdateResult hipGraphExecUpdateResult
116
116
  #define cudaGraphNodeType hipGraphNodeType
117
117
  #define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
118
118
  #define cudaGraphInstantiate hipGraphInstantiate
@@ -129,6 +129,7 @@
129
129
  #define cudaGraph_t hipGraph_t
130
130
  #define cudaStream_t hipStream_t
131
131
  #define cudaSuccess hipSuccess
132
+ #define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
132
133
  #define __trap() do { abort(); __builtin_unreachable(); } while(0)
133
134
  #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
134
135
  #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
@@ -119,7 +119,7 @@
119
119
  #define cudaGraphExecDestroy musaGraphExecDestroy
120
120
  #define cudaGraphExec_t musaGraphExec_t
121
121
  #define cudaGraphExecUpdate musaGraphExecUpdate
122
- #define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
122
+ #define cudaGraphExecUpdateResult musaGraphExecUpdateResult
123
123
  #define cudaGraphGetNodes musaGraphGetNodes
124
124
  #define cudaGraphInstantiate musaGraphInstantiate
125
125
  #define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
@@ -132,6 +132,8 @@
132
132
  #define cudaGraph_t musaGraph_t
133
133
  #define cudaKernelNodeParams musaKernelNodeParams
134
134
  #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
135
+ #define cudaStreamBeginCapture musaStreamBeginCapture
135
136
  #define cudaStreamEndCapture musaStreamEndCapture
137
+ #define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
136
138
 
137
139
  typedef mt_bfloat16 nv_bfloat16;
@@ -39,6 +39,12 @@ endif()
39
39
  find_package(hip REQUIRED)
40
40
  find_package(hipblas REQUIRED)
41
41
  find_package(rocblas REQUIRED)
42
+ if (GGML_HIP_ROCWMMA_FATTN)
43
+ CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA)
44
+ if (NOT ${FOUND_ROCWMMA})
45
+ message(FATAL_ERROR "rocwmma has not been found")
46
+ endif()
47
+ endif()
42
48
 
43
49
  if (${hip_VERSION} VERSION_LESS 5.5)
44
50
  message(FATAL_ERROR "At least ROCM/HIP V5.5 is required")
@@ -107,6 +113,14 @@ if (GGML_HIP_NO_VMM)
107
113
  add_compile_definitions(GGML_HIP_NO_VMM)
108
114
  endif()
109
115
 
116
+ if (GGML_HIP_ROCWMMA_FATTN)
117
+ add_compile_definitions(GGML_HIP_ROCWMMA_FATTN)
118
+ endif()
119
+
120
+ if (NOT GGML_CUDA_FA)
121
+ add_compile_definitions(GGML_CUDA_NO_FA)
122
+ endif()
123
+
110
124
  if (CXX_IS_HIPCC)
111
125
  set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
112
126
  target_link_libraries(ggml-hip PRIVATE hip::device)
@@ -16,7 +16,7 @@
16
16
  #include <arm_sve.h>
17
17
  #endif // __ARM_FEATURE_SVE
18
18
 
19
- #if defined(__ARM_NEON) && !defined(__CUDACC__)
19
+ #if defined(__ARM_NEON) && !defined(__CUDACC__) && !defined(__MUSACC__)
20
20
  // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
21
21
  //
22
22
  // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
@@ -27,12 +27,12 @@ configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h
27
27
  configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
28
28
  configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)
29
29
 
30
+ set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
30
31
  if (GGML_METAL_EMBED_LIBRARY)
31
32
  enable_language(ASM)
32
33
 
33
34
  add_compile_definitions(GGML_METAL_EMBED_LIBRARY)
34
35
 
35
- set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
36
36
  set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
37
37
  set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
38
38
 
@@ -88,12 +88,11 @@ else()
88
88
 
89
89
  add_custom_command(
90
90
  OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
91
- COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air
92
- COMMAND xcrun -sdk macosx metallib ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
93
- COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air
91
+ COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o - |
92
+ xcrun -sdk macosx metallib - -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
94
93
  COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h
95
94
  COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal
96
- DEPENDS ggml-metal.metal ggml-common.h
95
+ DEPENDS ggml-metal.metal ${METALLIB_COMMON}
97
96
  COMMENT "Compiling Metal kernels"
98
97
  )
99
98
 
@@ -285,4 +285,246 @@ typedef struct {
285
285
  float eps;
286
286
  } ggml_metal_kargs_rms_norm;
287
287
 
288
+ typedef struct {
289
+ int32_t ne00;
290
+ int32_t ne00_4;
291
+ uint64_t nb01;
292
+ float eps;
293
+ } ggml_metal_kargs_l2_norm;
294
+
295
+ typedef struct {
296
+ int64_t ne00;
297
+ int64_t ne01;
298
+ int64_t ne02;
299
+ uint64_t nb00;
300
+ uint64_t nb01;
301
+ uint64_t nb02;
302
+ int32_t n_groups;
303
+ float eps;
304
+ } ggml_metal_kargs_group_norm;
305
+
306
+ typedef struct {
307
+ int32_t IC;
308
+ int32_t IL;
309
+ int32_t K;
310
+ int32_t s0;
311
+ uint64_t nb0;
312
+ uint64_t nb1;
313
+ } ggml_metal_kargs_conv_transpose_1d;
314
+
315
+ typedef struct {
316
+ uint64_t ofs0;
317
+ uint64_t ofs1;
318
+ int32_t IW;
319
+ int32_t IH;
320
+ int32_t CHW;
321
+ int32_t s0;
322
+ int32_t s1;
323
+ int32_t p0;
324
+ int32_t p1;
325
+ int32_t d0;
326
+ int32_t d1;
327
+ int32_t N;
328
+ int32_t KH;
329
+ int32_t KW;
330
+ int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
331
+ } ggml_metal_kargs_im2col;
332
+
333
+ typedef struct {
334
+ int64_t ne00;
335
+ int64_t ne01;
336
+ int64_t ne02;
337
+ int64_t ne03;
338
+ uint64_t nb00;
339
+ uint64_t nb01;
340
+ uint64_t nb02;
341
+ uint64_t nb03;
342
+ int64_t ne10;
343
+ int64_t ne11;
344
+ int64_t ne12;
345
+ int64_t ne13;
346
+ uint64_t nb10;
347
+ uint64_t nb11;
348
+ uint64_t nb12;
349
+ uint64_t nb13;
350
+ int64_t ne0;
351
+ int64_t ne1;
352
+ int64_t ne2;
353
+ int64_t ne3;
354
+ uint64_t nb0;
355
+ uint64_t nb1;
356
+ uint64_t nb2;
357
+ uint64_t nb3;
358
+ } ggml_metal_kargs_sum_rows;
359
+
360
+ typedef struct {
361
+ int64_t ne00;
362
+ int64_t ne01;
363
+ int64_t ne02;
364
+ float scale;
365
+ float max_bias;
366
+ float m0;
367
+ float m1;
368
+ uint32_t n_head_log2;
369
+ } ggml_metal_kargs_soft_max;
370
+
371
+ typedef struct {
372
+ int64_t ne00;
373
+ int64_t ne01;
374
+ int n_past;
375
+ } ggml_metal_kargs_diag_mask_inf;
376
+
377
+ typedef struct {
378
+ int64_t ne00;
379
+ int64_t ne01;
380
+ int64_t ne02;
381
+ uint64_t nb00;
382
+ uint64_t nb01;
383
+ uint64_t nb02;
384
+ int64_t ne10;
385
+ int64_t ne11;
386
+ uint64_t nb10;
387
+ uint64_t nb11;
388
+ int64_t ne0;
389
+ int64_t ne1;
390
+ int64_t ne2;
391
+ uint64_t nb0;
392
+ uint64_t nb1;
393
+ uint64_t nb2;
394
+ } ggml_metal_kargs_ssm_conv;
395
+
396
+ typedef struct {
397
+ int64_t d_state;
398
+ int64_t d_inner;
399
+ int64_t n_seq_tokens;
400
+ int64_t n_seqs;
401
+ uint64_t nb00;
402
+ uint64_t nb01;
403
+ uint64_t nb02;
404
+ uint64_t nb10;
405
+ uint64_t nb11;
406
+ uint64_t nb12;
407
+ uint64_t nb13;
408
+ uint64_t nb20;
409
+ uint64_t nb21;
410
+ uint64_t nb22;
411
+ uint64_t nb30;
412
+ uint64_t nb31;
413
+ uint64_t nb40;
414
+ uint64_t nb41;
415
+ uint64_t nb42;
416
+ uint64_t nb50;
417
+ uint64_t nb51;
418
+ uint64_t nb52;
419
+ } ggml_metal_kargs_ssm_scan;
420
+
421
+ typedef struct {
422
+ int64_t ne00;
423
+ uint64_t nb01;
424
+ uint64_t nb02;
425
+ int64_t ne10;
426
+ uint64_t nb10;
427
+ uint64_t nb11;
428
+ uint64_t nb1;
429
+ uint64_t nb2;
430
+ } ggml_metal_kargs_get_rows;
431
+
432
+ typedef struct {
433
+ int64_t ne00;
434
+ int64_t ne01;
435
+ int64_t ne02;
436
+ int64_t ne03;
437
+ uint64_t nb00;
438
+ uint64_t nb01;
439
+ uint64_t nb02;
440
+ uint64_t nb03;
441
+ int64_t ne0;
442
+ int64_t ne1;
443
+ int64_t ne2;
444
+ int64_t ne3;
445
+ uint64_t nb0;
446
+ uint64_t nb1;
447
+ uint64_t nb2;
448
+ uint64_t nb3;
449
+ float sf0;
450
+ float sf1;
451
+ float sf2;
452
+ float sf3;
453
+ } ggml_metal_kargs_upscale;
454
+
455
+ typedef struct {
456
+ int64_t ne00;
457
+ int64_t ne01;
458
+ int64_t ne02;
459
+ int64_t ne03;
460
+ uint64_t nb00;
461
+ uint64_t nb01;
462
+ uint64_t nb02;
463
+ uint64_t nb03;
464
+ int64_t ne0;
465
+ int64_t ne1;
466
+ int64_t ne2;
467
+ int64_t ne3;
468
+ uint64_t nb0;
469
+ uint64_t nb1;
470
+ uint64_t nb2;
471
+ uint64_t nb3;
472
+ } ggml_metal_kargs_pad;
473
+
474
+ typedef struct {
475
+ int64_t ne00;
476
+ int64_t ne01;
477
+ int64_t ne02;
478
+ int64_t ne03;
479
+ uint64_t nb00;
480
+ uint64_t nb01;
481
+ uint64_t nb02;
482
+ uint64_t nb03;
483
+ int64_t ne0;
484
+ int64_t ne1;
485
+ int64_t ne2;
486
+ int64_t ne3;
487
+ uint64_t nb0;
488
+ uint64_t nb1;
489
+ uint64_t nb2;
490
+ uint64_t nb3;
491
+ int32_t p0;
492
+ int32_t p1;
493
+ } ggml_metal_kargs_pad_reflect_1d;
494
+
495
+ typedef struct {
496
+ uint64_t nb1;
497
+ int dim;
498
+ int max_period;
499
+ } ggml_metal_kargs_timestep_embedding;
500
+
501
+ typedef struct {
502
+ float slope;
503
+ } ggml_metal_kargs_leaky_relu;
504
+
505
+ typedef struct {
506
+ int64_t ncols;
507
+ int64_t ncols_pad;
508
+ } ggml_metal_kargs_argsort;
509
+
510
+ typedef struct {
511
+ int64_t ne0;
512
+ float start;
513
+ float step;
514
+ } ggml_metal_kargs_arange;
515
+
516
+ typedef struct {
517
+ int32_t k0;
518
+ int32_t k1;
519
+ int32_t s0;
520
+ int32_t s1;
521
+ int32_t p0;
522
+ int32_t p1;
523
+ int64_t IH;
524
+ int64_t IW;
525
+ int64_t OH;
526
+ int64_t OW;
527
+ int64_t parallel_elements;
528
+ } ggml_metal_kargs_pool_2d;
529
+
288
530
  #endif // GGML_METAL_IMPL
@@ -21,7 +21,7 @@ if (MUSAToolkit_FOUND)
21
21
  message(STATUS "MUSA Toolkit found")
22
22
 
23
23
  if (NOT DEFINED MUSA_ARCHITECTURES)
24
- set(MUSA_ARCHITECTURES "21;22")
24
+ set(MUSA_ARCHITECTURES "21;22;31")
25
25
  endif()
26
26
  message(STATUS "Using MUSA architectures: ${MUSA_ARCHITECTURES}")
27
27
 
@@ -49,7 +49,7 @@ if (MUSAToolkit_FOUND)
49
49
 
50
50
  set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX)
51
51
  foreach(SOURCE ${GGML_SOURCES_MUSA})
52
- set(COMPILE_FLAGS "-x musa -mtgpu")
52
+ set(COMPILE_FLAGS "-fsigned-char -x musa -mtgpu")
53
53
  foreach(ARCH ${MUSA_ARCHITECTURES})
54
54
  set(COMPILE_FLAGS "${COMPILE_FLAGS} --cuda-gpu-arch=mp_${ARCH}")
55
55
  endforeach()
@@ -67,10 +67,6 @@ if (MUSAToolkit_FOUND)
67
67
  add_compile_definitions(GGML_USE_MUSA)
68
68
  add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
69
69
 
70
- if (GGML_CUDA_GRAPHS)
71
- add_compile_definitions(GGML_CUDA_USE_GRAPHS)
72
- endif()
73
-
74
70
  if (GGML_CUDA_FORCE_MMQ)
75
71
  add_compile_definitions(GGML_CUDA_FORCE_MMQ)
76
72
  endif()
@@ -83,6 +79,10 @@ if (MUSAToolkit_FOUND)
83
79
  add_compile_definitions(GGML_CUDA_NO_VMM)
84
80
  endif()
85
81
 
82
+ if (NOT GGML_CUDA_FA)
83
+ add_compile_definitions(GGML_CUDA_NO_FA)
84
+ endif()
85
+
86
86
  if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
87
87
  add_compile_definitions(GGML_CUDA_F16)
88
88
  endif()
@@ -15,6 +15,7 @@ if (GGML_OPENCL_PROFILING)
15
15
  endif ()
16
16
 
17
17
  add_compile_definitions(GGML_OPENCL_SOA_Q)
18
+ add_compile_definitions(GGML_OPENCL_TARGET_VERSION=${GGML_OPENCL_TARGET_VERSION})
18
19
 
19
20
  if (GGML_OPENCL_USE_ADRENO_KERNELS)
20
21
  message(STATUS "OpenCL will use matmul kernels optimized for Adreno")