@novastera-oss/llamarn 0.2.7 → 0.3.0

Files changed (319)
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  13. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  18. package/cpp/LlamaCppModel.cpp +56 -22
  19. package/cpp/build-info.cpp +2 -2
  20. package/cpp/llama.cpp/CMakeLists.txt +1 -2
  21. package/cpp/llama.cpp/README.md +4 -5
  22. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  23. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  24. package/cpp/llama.cpp/common/arg.cpp +24 -0
  25. package/cpp/llama.cpp/common/chat.cpp +37 -20
  26. package/cpp/llama.cpp/common/chat.h +2 -0
  27. package/cpp/llama.cpp/common/common.cpp +3 -0
  28. package/cpp/llama.cpp/common/common.h +5 -0
  29. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  30. package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
  31. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  32. package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
  33. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  34. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  35. package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
  36. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  100. package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
  101. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
  103. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
  104. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
  105. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  133. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  134. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  135. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  136. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  137. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
  138. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  141. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  142. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  143. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  144. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  168. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  169. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  170. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  172. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  173. package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
  177. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
  178. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
  179. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  180. package/cpp/llama.cpp/include/llama.h +8 -43
  181. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  182. package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
  183. package/cpp/llama.cpp/src/llama-arch.h +36 -1
  184. package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
  185. package/cpp/llama.cpp/src/llama-batch.h +105 -70
  186. package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
  187. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  188. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  189. package/cpp/llama.cpp/src/llama-context.h +13 -13
  190. package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
  191. package/cpp/llama.cpp/src/llama-graph.h +78 -79
  192. package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
  193. package/cpp/llama.cpp/src/llama-hparams.h +11 -0
  194. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
  195. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  196. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
  197. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
  198. package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
  199. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
  200. package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
  201. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
  202. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  203. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  204. package/cpp/llama.cpp/src/llama-memory.h +21 -22
  205. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  206. package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
  207. package/cpp/llama.cpp/src/llama-model.h +40 -0
  208. package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
  209. package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
  210. package/cpp/llama.cpp/src/llama-vocab.h +42 -0
  211. package/cpp/rn-utils.h +3 -0
  212. package/ios/include/chat.h +2 -0
  213. package/ios/include/common.h +5 -0
  214. package/ios/include/llama.h +8 -43
  215. package/ios/libs/llama.xcframework/Info.plist +19 -19
  216. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  217. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  218. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  219. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  220. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
  221. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  223. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  224. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  225. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  231. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  232. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  248. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  250. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
  254. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
  255. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  261. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  262. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  263. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  264. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  265. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  267. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  268. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
  269. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
  270. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
  274. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  275. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  276. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  277. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  278. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  279. package/package.json +1 -1
  280. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  315. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  316. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  317. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  318. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  319. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp

@@ -254,14 +254,13 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        stream->submit([&](sycl::handler& cgh) {
-            cgh.parallel_for(
-                sycl::nd_range<3>(global_dims * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
-                    });
-        });
+        sycl_launch(stream, [&](sycl::handler & cgh) {
+            sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
+                                           nullptr, WARP_SIZE);
+                              });
+        });
     }
     else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
@@ -272,16 +271,15 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        stream->submit([&](sycl::handler& cgh) {
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
                 sycl::range<1>(work_group_size / WARP_SIZE), cgh);
-            cgh.parallel_for(
-                sycl::nd_range<3>(global_dims * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
-                    });
-        });
+            sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
+                                           get_pointer(s_sum_acc_ct1), work_group_size);
+                              });
+        });
     }
 }
 
@@ -290,18 +288,14 @@ static void group_norm_f32_sycl(const float* x, float* dst,
                                 const int ne_elements, queue_ptr stream, int device) {
     if (group_size < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        stream->submit([&](sycl::handler& cgh) {
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             const float eps_ct4 = eps;
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        group_norm_f32(
-                            x, dst, group_size, ne_elements, eps_ct4, item_ct1,
-                            nullptr, WARP_SIZE);
-                    });
-        });
+            sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1, nullptr,
+                                                 WARP_SIZE);
+                              });
+        });
     }
     else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
@@ -313,22 +307,18 @@ static void group_norm_f32_sycl(const float* x, float* dst,
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
 
-        stream->submit([&](sycl::handler& cgh) {
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                                                          cgh);
 
             const float eps_ct4 = eps;
 
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        group_norm_f32(x, dst, group_size, ne_elements,
-                                       eps_ct4, item_ct1,
-                                       get_pointer(s_sum_acc_ct1), work_group_size);
-                    });
-        });
+            sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1,
+                                                 get_pointer(s_sum_acc_ct1), work_group_size);
+                              });
+        });
     }
 }
 
@@ -340,14 +330,13 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
     const sycl::range<3> global_dims(nsamples, nchannels, nrows);
     if (ncols < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        stream->submit([&](sycl::handler& cgh) {
-            cgh.parallel_for(
-                sycl::nd_range<3>(global_dims * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
-                    });
-        });
+        sycl_launch(stream, [&](sycl::handler & cgh) {
+            sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
+                                               nullptr, WARP_SIZE);
+                              });
+        });
     }
     else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
@@ -358,16 +347,15 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        stream->submit([&](sycl::handler& cgh) {
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                                                          cgh);
-            cgh.parallel_for(
-                sycl::nd_range<3>(global_dims * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
-                    });
-        });
+            sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
+                                               get_pointer(s_sum_acc_ct1), work_group_size);
+                              });
+        });
     }
 }
 
@@ -378,16 +366,12 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
     // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
     if (ncols < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        stream->submit([&](sycl::handler& cgh) {
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        l2_norm_f32(x, dst, ncols, eps, item_ct1,
-                                    nullptr, WARP_SIZE);
-                    });
-        });
+        sycl_launch(stream, [&](sycl::handler & cgh) {
+            sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  l2_norm_f32(x, dst, ncols, eps, item_ct1, nullptr, WARP_SIZE);
+                              });
+        });
     }
     else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
@@ -398,18 +382,15 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        stream->submit([&](sycl::handler& cgh) {
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                                                          cgh);
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        l2_norm_f32(x, dst, ncols, eps, item_ct1,
-                                    get_pointer(s_sum_acc_ct1), work_group_size);
-                    });
-        });
+            sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  l2_norm_f32(x, dst, ncols, eps, item_ct1, get_pointer(s_sum_acc_ct1),
+                                              work_group_size);
+                              });
+        });
     }
 }
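Every hunk in this file (and in the softmax, tsembd, and wkv changes below) is the same mechanical migration: raw stream->submit(...) / cgh.parallel_for(...) call sites become the sycl_launch and sycl_parallel_for helpers. The sketch below is illustrative only, assuming the helpers simply forward to the native SYCL entry points; the real definitions live in the ggml-sycl common headers (see the dpct/helper.hpp change in the file list) and may add profiling or error handling:

    #include <sycl/sycl.hpp>
    #include <utility>

    // Hypothetical sketch (note the _sketch suffix), not the package's actual code.
    template <typename Fn>
    auto sycl_launch_sketch(sycl::queue * stream, Fn && cg_fn) {
        return stream->submit(std::forward<Fn>(cg_fn));       // enqueue the command group
    }

    template <int Dims, typename Kernel>
    void sycl_parallel_for_sketch(sycl::handler & cgh, const sycl::nd_range<Dims> & range, Kernel && k) {
        cgh.parallel_for(range, std::forward<Kernel>(k));     // launch inside a command group
    }

    template <int Dims, typename Kernel>
    void sycl_parallel_for_sketch(sycl::queue * stream, const sycl::nd_range<Dims> & range, Kernel && k) {
        stream->parallel_for(range, std::forward<Kernel>(k)); // queue-shortcut form used in rope.cpp and tsembd.cpp
    }

Centralizing the launch call sites this way gives the backend one place to hang instrumentation instead of touching every kernel.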
package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp

@@ -47,18 +47,17 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const
 
     const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
 
-    if (i0 >= n_dims) {
-        const int i = row * ne0 + i0;
-        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
-        return;
-    }
-
     const int row0 = row % ne1;
     const int channel0 = row / ne1;
 
     const int i = row * ne0 + i0;
     const int i2 = channel0 * s2 + row0 * s1 + i0;
 
+    if (i0 >= n_dims) {
+        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
+        return;
+    }
+
     const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
@@ -88,18 +87,17 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
 
     const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
 
-    if (i0 >= n_dims) {
-        const int i = row * ne0 + i0;
-        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
-        return;
-    }
-
     const int row0 = row % ne1;
     const int channel0 = row / ne1;
 
     const int i = row * ne0 + i0 / 2;
     const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
 
+    if (i0 >= n_dims) {
+        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
+        return;
+    }
+
     const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
@@ -129,17 +127,16 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
     }
     const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
 
-    if (i0 >= n_dims) {
-        const int i = row_dst*ne0 + i0;
-        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
-        return;
-    }
-
     const int row_x = row_dst % ne1;
     const int channel_x = row_dst / ne1;
     const int idst = (row_dst * ne0) + (i0 / 2);
     const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
 
+    if (i0 >= n_dims) {
+        *reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
+        return;
+    }
+
     const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
     const int sec_w = sections.v[1] + sections.v[0];
     const int sector = (i0 / 2) % sect_dims;
@@ -235,20 +232,22 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-            rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
-                                theta_scale, freq_factors, item_ct1);
-        });
+        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                          [=](sycl::nd_item<3> item_ct1) {
+                              rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+                                                  attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
+                          });
     } else {
         /*
         DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-            rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
-                               theta_scale, freq_factors, item_ct1);
-        });
+        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                          [=](sycl::nd_item<3> item_ct1) {
+                              rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+                                                 attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
+                          });
     }
 }
 
@@ -267,15 +266,17 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
     dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
 
     if (freq_factors == nullptr) {
-        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-            rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
-                                theta_scale, freq_factors, item_ct1);
-        });
+        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                          [=](sycl::nd_item<3> item_ct1) {
+                              rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+                                                  attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
+                          });
     } else {
-        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-            rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
-                               theta_scale, freq_factors, item_ct1);
-        });
+        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                          [=](sycl::nd_item<3> item_ct1) {
+                              rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+                                                 attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
+                          });
     }
 }
 
@@ -298,12 +299,12 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
     }
     // launch kernel
     if (freq_factors == nullptr) {
-        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
+        sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
             rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
                                  corr_dims, theta_scale, freq_factors, sections, item_ct1);
         });
     } else {
-        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
+        sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
             rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
                                 corr_dims, theta_scale, freq_factors, sections, item_ct1);
         });
@@ -333,12 +334,12 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
     }
     // launch kernel
     if (freq_factors == nullptr) {
-        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
+        sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
             rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
                                   corr_dims, theta_scale, freq_factors, sections, item_ct1);
         });
     } else {
-        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
+        sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
             rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
                                  corr_dims, theta_scale, freq_factors, sections, item_ct1);
         });
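The three kernel-side hunks above are a correctness fix, not just a reshuffle. When i0 >= n_dims (the tail of the head dimension that RoPE leaves unrotated), the old early-out copied from x at the destination index i = row * ne0 + i0, which silently assumes the source is laid out exactly like the contiguous destination; the source is actually addressed through the strides s1 and s2, so the copy read the wrong element for non-contiguous inputs. The fix moves the early-out below the stride computation and copies from the strided source offset (i2, or ix in rope_multi) instead. Worked through with hypothetical numbers for rope_norm:

    // Hypothetical sizes, only to illustrate the index math in rope_norm:
    // ne0 = 8 (row length), n_dims = 4 (rotated prefix), ne1 = 4 rows per channel,
    // source strides s1 = 16, s2 = 64 (a padded, non-contiguous input).
    // Take row = 1 (so row0 = 1, channel0 = 0) and a tail element i0 = 6 >= n_dims:
    //   dst index  i  = row * ne0 + i0              = 1 * 8  + 6 = 14
    //   src index  i2 = channel0*s2 + row0*s1 + i0  = 0 + 16 + 6 = 22
    // Old code copied x[14] -> dst[14]: wrong, since x[14] falls in row 0's padding.
    // New code copies x[22] -> dst[14]: the element that actually sits at
    // (channel 0, row 1, column 6) of the strided source.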
package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp (new file)

@@ -0,0 +1,131 @@
+#include "set_rows.hpp"
+
+namespace utils {
+template<typename T>
+static constexpr bool is_arithmetic_v() {
+    return std::is_arithmetic_v<T> || std::is_same_v<T, sycl::half> || std::is_same_v<T, sycl::ext::oneapi::bfloat16>;
+}
+}
+template<typename TIn, typename TOut>
+static inline std::enable_if_t<utils::is_arithmetic_v<TIn>() && utils::is_arithmetic_v<TOut>(), void>
+convert (const char* src, char* dst) {
+    auto src_val = *reinterpret_cast<const TIn*>(src);
+    auto dst_val = sycl::vec<TIn, 1>(src_val).template convert<TOut, sycl::rounding_mode::automatic>()[0];
+    *reinterpret_cast<TOut*>(dst) = dst_val;;
+}
+
+template<typename TIn, typename TOut>
+static void k_set_rows(
+        const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne11, const int64_t ne12,
+        const size_t nb01, const size_t nb02, const size_t nb03,
+        const size_t nb10, const size_t nb11, const size_t nb12,
+        const size_t nb1, const size_t nb2, const size_t nb3,
+        const size_t src_type_size, const size_t dst_type_size,
+        const sycl::nd_item<3> & item_ct1) {
+
+    const int i03 = item_ct1.get_group(0);
+    const int i02 = item_ct1.get_group(1);
+    const int i01 = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); // Row index
+
+    if (i01 >= ne01) {
+        return;
+    }
+
+    const int i12 = i03 % ne12;
+    const int i11 = i02 % ne11;
+    const int i10 = i01;
+
+    const int64_t dst_row = *(const int64_t *)((const char *)src1 + calculate_offset<3>({nb10, nb11, nb12}, {i10, i11, i12}));
+
+    const char * src0_row = src0 + calculate_offset<3>({nb01, nb02, nb03}, {i01, i02, i03});
+    char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
+
+    for (int col = item_ct1.get_local_id(0); col < ne00; col += item_ct1.get_local_range(0)) {
+        const char * src_elem = src0_row + col * src_type_size;
+        char * dst_elem = dst_row_ptr + col * dst_type_size;
+        convert<TIn, TOut>(src_elem, dst_elem);
+    }
+}
+
+template<typename TIn, typename TOut>
+static void set_rows_sycl(
+        const char * src0_d, const int64_t * src1_d, char * dst_d,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t ne11, const int64_t ne12, const size_t nb01, const size_t nb02, const size_t nb03,
+        const size_t nb10, const size_t nb11, const size_t nb12,
+        const size_t nb1, const size_t nb2, const size_t nb3,
+        const size_t src_type_size, const size_t dst_type_size,
+        queue_ptr stream) {
+
+    constexpr int max_threads_per_row = 64; // KEEPING 64 for now
+    const int threads_per_row = std::min((int)ne00, max_threads_per_row);
+
+    constexpr int max_threads_per_block = 64;
+    const int rows_per_block = std::max(1, max_threads_per_block / threads_per_row);
+
+    const sycl::range<3> block_size(1, rows_per_block, threads_per_row);
+    const sycl::range<3> grid_size(ne03, ne02, (ne01 + rows_per_block - 1) / rows_per_block);
+
+    sycl_parallel_for(
+        stream,
+        sycl::nd_range<3>(grid_size * block_size, block_size),
+        [=](sycl::nd_item<3> item_ct1) {
+            k_set_rows<TIn, TOut>(
+                src0_d, src1_d, dst_d,
+                ne00, ne01, ne11, ne12,
+                nb01, nb02, nb03,
+                nb10, nb11, nb12,
+                nb1, nb2, nb3,
+                src_type_size, dst_type_size,
+                item_ct1
+            );
+        }
+    );
+}
+
+
+void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I64);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t * src1_dd = static_cast<const int64_t *>(src1->data);
+
+    dpct::queue_ptr stream = ctx.stream();
+    switch (dst->type) {
+        case GGML_TYPE_F32:
+            set_rows_sycl<float, float>(
+                (const char *)src0->data, src1_dd, (char *)dst->data,
+                ne00, ne01, ne02, ne03,
+                ne11, ne12,
+                nb01, nb02, nb03,
+                nb10, nb11, nb12,
+                nb1, nb2, nb3,
+                sizeof(float), sizeof(float),
+                stream
+            );
+            break;
+        case GGML_TYPE_F16:
+            dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
+            set_rows_sycl<float, sycl::half>(
+                (const char *)src0->data, src1_dd, (char *)dst->data,
+                ne00, ne01, ne02, ne03,
+                ne11, ne12,
+                nb01, nb02, nb03,
+                nb10, nb11, nb12,
+                nb1, nb2, nb3,
+                sizeof(float), sizeof(sycl::half),
+                stream
+            );
+            break;
+        default:
+            GGML_ABORT("Unsupported tensor type!");
+            break;
+    }
+}
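One detail of the new file worth calling out is the launch geometry in set_rows_sycl: a fixed 64-work-item group budget is split between the columns of one row and additional rows, so narrow tensors still fill a work-group, and any leftover columns are covered by the strided loop in k_set_rows. The standalone sketch below reruns that arithmetic on made-up sizes (ne00 = 16, ne01 = 10 are hypothetical, not taken from the diff):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Mirrors the geometry computation in set_rows_sycl above, with invented sizes.
    int main() {
        const int64_t ne00 = 16;  // columns per row (hypothetical)
        const int64_t ne01 = 10;  // rows (hypothetical)

        constexpr int max_threads_per_row   = 64;
        const int threads_per_row = std::min((int)ne00, max_threads_per_row);            // 16

        constexpr int max_threads_per_block = 64;
        const int rows_per_block  = std::max(1, max_threads_per_block / threads_per_row); // 4

        const int row_groups = (int)((ne01 + rows_per_block - 1) / rows_per_block);       // ceil(10/4) = 3

        // Each group runs rows_per_block * threads_per_row = 64 work-items.
        std::printf("%d threads/row, %d rows/group, %d groups\n",
                    threads_per_row, rows_per_block, row_groups);
        return 0;
    }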
package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp (new file)

@@ -0,0 +1,8 @@
+#ifndef GGML_SYCL_SET_ROWS_HPP
+#define GGML_SYCL_SET_ROWS_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+#endif // GGML_SYCL_SET_ROWS_HPP
package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp

@@ -127,11 +127,11 @@ static void soft_max_f32_submitter(const float * x, const T * mask, float * dst,
                                    const int nrows_y, const float scale, const float max_bias, const float m0,
                                    const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
                                    const size_t n_local_scratch, queue_ptr stream) {
-    stream->submit([&](sycl::handler &cgh) {
+    sycl_launch(stream, [&](sycl::handler & cgh) {
         sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
 
-        cgh.parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+        sycl_parallel_for(
+            cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
                                                                              nrows_y, scale, max_bias, m0,
package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp

@@ -1,6 +1,7 @@
 #include "sycl_hw.hpp"
 
-
+// TODO: currently not used
+/*
 sycl_hw_info get_device_hw_info(sycl::device *device_ptr) {
   sycl_hw_info res;
   int32_t id = device_ptr->get_info<sycl::ext::intel::info::device::device_id>();
@@ -11,3 +12,4 @@ sycl_hw_info get_device_hw_info(sycl::device *device_ptr) {
 
   return res;
 }
+*/
package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp

@@ -10,6 +10,8 @@
 
 namespace syclex = sycl::ext::oneapi::experimental;
 
+// TODO: currently not used
+/*
 struct sycl_hw_info {
   syclex::architecture arch;
   int32_t device_id;
@@ -18,6 +20,7 @@
 bool is_in_vector(std::vector<int> &vec, int item);
 
 sycl_hw_info get_device_hw_info(sycl::device *device_ptr);
+*/
 
 
 #endif // SYCL_HW_HPP
package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp

@@ -45,14 +45,9 @@ static void timestep_embedding_f32_sycl(
     int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE;
     sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE);
     sycl::range<3> gridDim(1, ne00, num_blocks);
-    stream->parallel_for(
-        sycl::nd_range<3>(
-            gridDim * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) {
-            timestep_embedding_f32(
-                x, dst, nb1, dim, max_period, item_ct1
-            );
-        });
+    sycl_parallel_for(stream, sycl::nd_range<3>(gridDim * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+        timestep_embedding_f32(x, dst, nb1, dim, max_period, item_ct1);
+    });
 }
 
 void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp

@@ -207,12 +207,11 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
 
     // Submit kernel
     if (C / H == WKV_BLOCK_SIZE) {
-        stream->submit([&](sycl::handler& cgh) {
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
 
-            cgh.parallel_for(
-                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
+            sycl_parallel_for(
+                cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
                     rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(
                         B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
                         item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
@@ -220,12 +219,11 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
                 });
         });
     } else {
-        stream->submit([&](sycl::handler& cgh) {
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
 
-            cgh.parallel_for(
-                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
+            sycl_parallel_for(
+                cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
                     rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(
                         B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
                         item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
@@ -264,12 +262,11 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
 
     // Submit kernel
     if (C / H == WKV_BLOCK_SIZE) {
-        stream->submit([&](sycl::handler& cgh) {
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
 
-            cgh.parallel_for(
-                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
+            sycl_parallel_for(
+                cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
                     rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
                         B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
                         item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
@@ -277,12 +274,11 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
                 });
         });
     } else {
-        stream->submit([&](sycl::handler& cgh) {
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
 
-            cgh.parallel_for(
-                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
+            sycl_parallel_for(
+                cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
                     rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
                         B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
                         item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
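All four WKV launch sites keep the existing shared-memory idiom through the migration: a sycl::local_accessor declared inside the command group is handed to the kernel as a raw float* via get_multi_ptr. A minimal self-contained sketch of that idiom, with a hypothetical kernel (the function name and sizes are invented for illustration; n is assumed to be a multiple of 64):

    #include <sycl/sycl.hpp>

    // Sketch of the local_accessor -> raw-pointer pattern used above;
    // not the package's actual code.
    void launch_with_scratch(sycl::queue & q, float * out, size_t n) {
        q.submit([&](sycl::handler & cgh) {
            // Work-group-local scratch, sized at submission time
            // (the WKV code sizes it with shared_mem_size).
            sycl::local_accessor<float, 1> scratch(sycl::range<1>(64), cgh);
            cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(n), sycl::range<1>(64)),
                             [=](sycl::nd_item<1> it) {
                // Convert the accessor to a plain pointer, as the WKV kernels do.
                float * s = scratch.get_multi_ptr<sycl::access::decorated::no>().get();
                s[it.get_local_id(0)] = out[it.get_global_id(0)];
                sycl::group_barrier(it.get_group());
                // Toy use of the scratch: reverse each 64-element tile in place.
                out[it.get_global_id(0)] = s[63 - it.get_local_id(0)];
            });
        });
    }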