@novastera-oss/llamarn 0.2.7 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (319)
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  13. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  18. package/cpp/LlamaCppModel.cpp +56 -22
  19. package/cpp/build-info.cpp +2 -2
  20. package/cpp/llama.cpp/CMakeLists.txt +1 -2
  21. package/cpp/llama.cpp/README.md +4 -5
  22. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  23. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  24. package/cpp/llama.cpp/common/arg.cpp +24 -0
  25. package/cpp/llama.cpp/common/chat.cpp +37 -20
  26. package/cpp/llama.cpp/common/chat.h +2 -0
  27. package/cpp/llama.cpp/common/common.cpp +3 -0
  28. package/cpp/llama.cpp/common/common.h +5 -0
  29. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  30. package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
  31. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  32. package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
  33. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  34. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  35. package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
  36. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  100. package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
  101. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
  103. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
  104. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
  105. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  133. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  134. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  135. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  136. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  137. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
  138. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  141. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  142. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  143. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  144. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  168. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  169. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  170. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  172. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  173. package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
  177. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
  178. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
  179. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  180. package/cpp/llama.cpp/include/llama.h +8 -43
  181. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  182. package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
  183. package/cpp/llama.cpp/src/llama-arch.h +36 -1
  184. package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
  185. package/cpp/llama.cpp/src/llama-batch.h +105 -70
  186. package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
  187. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  188. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  189. package/cpp/llama.cpp/src/llama-context.h +13 -13
  190. package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
  191. package/cpp/llama.cpp/src/llama-graph.h +78 -79
  192. package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
  193. package/cpp/llama.cpp/src/llama-hparams.h +11 -0
  194. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
  195. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  196. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
  197. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
  198. package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
  199. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
  200. package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
  201. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
  202. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  203. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  204. package/cpp/llama.cpp/src/llama-memory.h +21 -22
  205. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  206. package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
  207. package/cpp/llama.cpp/src/llama-model.h +40 -0
  208. package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
  209. package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
  210. package/cpp/llama.cpp/src/llama-vocab.h +42 -0
  211. package/cpp/rn-utils.h +3 -0
  212. package/ios/include/chat.h +2 -0
  213. package/ios/include/common.h +5 -0
  214. package/ios/include/llama.h +8 -43
  215. package/ios/libs/llama.xcframework/Info.plist +19 -19
  216. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  217. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  218. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  219. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  220. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
  221. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  223. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  224. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  225. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  231. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  232. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  248. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  250. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
  254. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
  255. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  261. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  262. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  263. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  264. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  265. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  267. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  268. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
  269. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
  270. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
  274. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  275. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  276. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  277. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  278. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  279. package/package.json +1 -1
  280. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  315. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  316. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  317. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  318. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  319. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp

@@ -1,12 +1,19 @@
 #include "common.hpp"
+#include "ggml-sycl/presets.hpp"
 #include "ggml.h"
 #include "element_wise.hpp"
 
+#define SYCL_GLOBAL_ID_LOOP(K, ITEM) \
+    for (auto i = ITEM.get_global_id(0); i < (size_t)K; i += ITEM.get_global_range(0))
+
+#define SYCL_LOCAL_ID_CALC(ITEM, IDX) \
+    (ITEM.get_local_range(IDX) * ITEM.get_group(IDX) + ITEM.get_local_id(IDX))
+
+
 static void acc_f32(const float * x, const float * y, float * dst, const int ne,
                     const int ne10, const int ne11, const int ne12,
-                    const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+                    const int nb1, const int nb2, int offset, const sycl::nd_item<1> &item_ct1) {
+    const int i = SYCL_LOCAL_ID_CALC(item_ct1, 0);
     if (i >= ne) {
         return;
     }
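Note: the two macros added above drive the rest of this refactor. SYCL_GLOBAL_ID_LOOP is a grid-stride loop, so one kernel launch covers all k elements even when the launched ND-range is smaller than k; SYCL_LOCAL_ID_CALC is the usual flat block-local index. A minimal sketch of what a kernel body looks like once SYCL_GLOBAL_ID_LOOP(k, item_ct1) expands (illustrative only, using op_relu from the hunks below):

    // expansion of: SYCL_GLOBAL_ID_LOOP(k, item_ct1) { dst[i] = op_relu(x[i]); }
    for (auto i = item_ct1.get_global_id(0); i < (size_t)k; i += item_ct1.get_global_range(0)) {
        dst[i] = op_relu(x[i]);
    }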
@@ -21,248 +28,280 @@ static void acc_f32(const float * x, const float * y, float * dst, const int ne,
     }
 }
 
+/* Unary OP funcs */
 template<typename T>
-static void sgn(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
-    for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
-        dst[i] = x[i] > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x[i] < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
-    }
+static __dpct_inline__ T op_sgn(T x) {
+    return x > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
 }
 
 template<typename T>
-static void abs_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
-    for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
-        dst[i] = sycl::fabs(x[i]);
-    }
+static __dpct_inline__ T op_abs(T x) {
+    return sycl::fabs(x);
 }
 
 template<typename T>
-static void elu_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
-    for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
-        dst[i] = (x[i] > static_cast<T>(0.f)) ? x[i] : sycl::expm1(x[i]);
-    }
+static __dpct_inline__ T op_elu(T x) {
+    return (x > static_cast<T>(0.f)) ? x : sycl::expm1(x);
 }
 
 template<typename T>
-static void gelu(const T * x, T * dst, const int k,
-                 const sycl::nd_item<3> &item_ct1) {
+static __dpct_inline__ T op_gelu(T x) {
     const T GELU_COEF_A = static_cast<T>(0.044715f);
     const T SQRT_2_OVER_PI = static_cast<T>(0.79788456080286535587989211986876f);
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+    return static_cast<T>(0.5f) * x *
+           (static_cast<T>(1.0f) +
+            sycl::tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
+}
 
-    if (i >= k) {
-        return;
-    }
+template<typename T>
+static __dpct_inline__ T op_silu(T x) {
+    return x / (static_cast<T>(1.0f) + sycl::native::exp(-x));
+}
 
-    float xi = x[i];
-    dst[i] = static_cast<T>(0.5f) * xi *
-             (static_cast<T>(1.0f) +
-              sycl::tanh(SQRT_2_OVER_PI * xi * (static_cast<T>(1.0f) + GELU_COEF_A * xi * xi)));
+template<typename T>
+static __dpct_inline__ T op_gelu_quick(T x) {
+    const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
+    return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x)));
 }
 
 template<typename T>
-static void silu(const T * x, T * dst, const int k,
-                 const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+static __dpct_inline__ T op_gelu_erf(T x) {
+    const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
+    return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + sycl::erf(x * SQRT_2_INV));
+}
 
-    if (i >= k) {
-        return;
-    }
-    dst[i] = x[i] / (static_cast<T>(1.0f) + sycl::native::exp(-x[i]));
+template<typename T>
+static __dpct_inline__ T op_tanh(T x) {
+    return sycl::tanh(x);
 }
 
 template<typename T>
-static void gelu_quick(const T *x, T *dst, int k,
-                       const sycl::nd_item<3> &item_ct1) {
-    const float GELU_QUICK_COEF = -1.702f;
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-    if (i >= k) {
-        return;
-    }
-    dst[i] = x[i] * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i])));
+static __dpct_inline__ T op_relu(T x) {
+    return sycl::fmax(x, static_cast<T>(0));
 }
 
 template<typename T>
-static void gelu_erf(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
-    const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
-    for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
-        auto x_i = x[i];
-        dst[i] = static_cast<T>(0.5f) * x_i * (static_cast<T>(1.0f) + sycl::erf(x_i * SQRT_2_INV));
-    }
+static __dpct_inline__ T op_sigmoid(T x) {
+    return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(-x));
 }
 
 template<typename T>
-static void tanh(const T *x, T *dst, int k,
-                 const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sycl::tanh((x[i]));
+static __dpct_inline__ T op_sqrt(T x) {
+    return sycl::sqrt(x);
+}
+
+template<typename T>
+static __dpct_inline__ T op_sin(T x) {
+    return sycl::sin(x);
 }
 
 template<typename T>
-static void relu(const T * x, T * dst, const int k,
-                 const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+static __dpct_inline__ T op_cos(T x) {
+    return sycl::cos(x);
+}
 
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sycl::fmax((x[i]), static_cast<T>(0));
+template<typename T>
+static __dpct_inline__ T op_hardsigmoid(T x) {
+    return sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
 }
 
 template<typename T>
-static void sigmoid(const T * x, T * dst, const int k,
-                    const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+static __dpct_inline__ T op_hardswish(T x) {
+    return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
+}
 
-    if (i >= k) {
-        return;
+template<typename T>
+static __dpct_inline__ T op_exp(T x) {
+    return sycl::exp(x);
+}
+
+template<typename T>
+static __dpct_inline__ T op_log(T x) {
+    if (x <= static_cast<T>(0)) {
+        return neg_infinity<T>();
     }
-    dst[i] = 1.0f / (static_cast<T>(1.0f) + sycl::native::exp(-x[i]));
+    return sycl::log(x);
 }
 
 template<typename T>
-static void sqrt(const T * x, T * dst, const int k,
-                 const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+static __dpct_inline__ T op_neg(T x) {
+    return -x;
+}
 
-    if (i >= k) {
-        return;
+template<typename T>
+static __dpct_inline__ T op_step(T x) {
+    return (x > static_cast<T>(0.0f)) ? static_cast<T>(1.0f) : static_cast<T>(0.0f);
+}
+
+template<typename T>
+static __dpct_inline__ T op_leaky_relu(T x, float negative_slope) {
+    T neg_slope_T = static_cast<T>(negative_slope);
+    return sycl::fmax(x, static_cast<T>(0)) +
+           sycl::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
+}
+
+template<typename T>
+static __dpct_inline__ T op_sqr(T x) {
+    return x * x;
+}
+
+template<typename T>
+static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
+    return x < static_cast<T>(min_val) ? static_cast<T>(min_val) : (x > static_cast<T>(max_val) ? static_cast<T>(max_val) : x);
+}
+
+template<typename T>
+static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_sgn(x[i]);
     }
-    dst[i] = sycl::sqrt(x[i]);
 }
 
 template<typename T>
-static void sin(const T * x, T * dst, const int k,
-                const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+static void unary_op_abs_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_abs(x[i]);
+    }
+}
 
-    if (i >= k) {
-        return;
+template<typename T>
+static void unary_op_elu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_elu(x[i]);
     }
-    dst[i] = sycl::sin(x[i]);
 }
 
 template<typename T>
-static void cos(const T * x, T * dst, const int k,
-                const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+static void unary_op_gelu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_gelu(x[i]);
+    }
+}
 
-    if (i >= k) {
-        return;
+template<typename T>
+static void unary_op_silu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_silu(x[i]);
     }
-    dst[i] = sycl::cos(x[i]);
 }
 
 template<typename T>
-static void hardsigmoid(const T * x, T * dst, const int k,
-                        const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+static void unary_op_gelu_quick_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_gelu_quick(x[i]);
+    }
+}
 
-    if (i >= k) {
-        return;
+template<typename T>
+static void unary_op_gelu_erf_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_gelu_erf(x[i]);
+    }
+}
+
+template<typename T>
+static void unary_op_tanh_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_tanh(x[i]);
     }
-    dst[i] = sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x[i] + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
 }
 
 template<typename T>
-static void hardswish(const T * x, T * dst, const int k,
-                      const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+static void unary_op_relu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_relu(x[i]);
+    }
+}
 
-    if (i >= k) {
-        return;
+template<typename T>
+static void unary_op_sigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_sigmoid(x[i]);
     }
-    dst[i] = x[i] * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x[i] + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
 }
 
 template<typename T>
-static void exp(const T * x, T * dst, const int k,
-                const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+static void unary_op_sqrt_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_sqrt(x[i]);
+    }
+}
 
-    if (i >= k) {
-        return;
+template<typename T>
+static void unary_op_sin_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_sin(x[i]);
     }
-    dst[i] = sycl::exp(x[i]);
 }
 
 template<typename T>
-static void log(const T * x, T * dst, const int k,
-                const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+static void unary_op_cos_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_cos(x[i]);
+    }
+}
 
-    if (i >= k) {
-        return;
+template<typename T>
+static void unary_op_hardsigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_hardsigmoid(x[i]);
     }
-    T xi = x[i];
-    if (xi <= 0) {
-        dst[i] = neg_infinity<T>();
-    } else {
-        dst[i] = sycl::log(xi);
+}
+
+template<typename T>
+static void unary_op_hardswish_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_hardswish(x[i]);
     }
 }
 
 template<typename T>
-static void neg(const T * x, T * dst, const int k,
-                const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+static void unary_op_exp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_exp(x[i]);
+    }
+}
 
-    if (i >= k) {
-        return;
+template<typename T>
+static void unary_op_log_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_log(x[i]);
     }
-    dst[i] = -x[i];
 }
 
 template<typename T>
-static void step(const T * x, T * dst, const int k,
-                 const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+static void unary_op_neg_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_neg(x[i]);
+    }
+}
 
-    if (i >= k) {
-        return;
+template<typename T>
+static void unary_op_step_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_step(x[i]);
     }
-    dst[i] = x[i] > static_cast<T>(0.0f);
 }
 
 template<typename T>
-static void leaky_relu(const T *x, T *dst, const int k, const float negative_slope,
-                       const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-    if (i >= k) {
-        return;
+static void unary_op_leaky_relu_kernel(const T * x, T * dst, const int k, float negative_slope, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_leaky_relu(x[i], negative_slope);
     }
-    dst[i] = sycl::fmax((x[i]), static_cast<T>(0)) +
-             sycl::fmin((x[i]), static_cast<T>(0.0f)) * negative_slope;
 }
 
 template<typename T>
-static void sqr(const T * x, T * dst, const int k,
-                const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+static void unary_op_sqr_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_sqr(x[i]);
+    }
+}
 
-    if (i >= k) {
-        return;
+template<typename T>
+static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1, float min_val, float max_val) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_clamp(x[i], min_val, max_val);
     }
-    dst[i] = x[i] * x[i];
 }
 
 template<typename T>
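Note: the hunk above splits every element-wise op into a scalar __dpct_inline__ helper (op_sgn, op_abs, ..., op_clamp) plus a thin unary_op_*_kernel wrapper that applies it under the grid-stride loop, replacing the old one-kernel-per-op bodies and their duplicated bounds checks. A hypothetical host-side launcher for one of these kernels could look like the sketch below (illustrative only; the 256-thread work-group mirrors the old sgn_sycl launcher removed later in this diff, and relu_sycl_sketch is not a name from the package):

    // hypothetical launcher, assuming a 256-thread work-group
    template<typename T>
    static void relu_sycl_sketch(const T * x, T * dst, const int k, queue_ptr stream) {
        const int num_blocks = ceil_div(k, 256);  // ceil_div(a, b) == (a + b - 1) / b
        stream->parallel_for(
            sycl::nd_range<1>(sycl::range<1>(num_blocks * 256), sycl::range<1>(256)),
            [=](sycl::nd_item<1> item_ct1) {
                unary_op_relu_kernel(x, dst, k, item_ct1);
            });
    }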
@@ -281,10 +320,10 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01,
     int i12 = (index / (ne10 * ne11)) % ne12;
     int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
 
-    int i00 = i10 / sf0;
-    int i01 = i11 / sf1;
-    int i02 = i12 / sf2;
-    int i03 = i13 / sf3;
+    int i00 = static_cast<int>(i10 / sf0);
+    int i01 = static_cast<int>(i11 / sf1);
+    int i02 = static_cast<int>(i12 / sf2);
+    int i03 = static_cast<int>(i13 / sf3);
 
     dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
 }
@@ -292,8 +331,7 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01,
 template <typename T>
 static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02,
                 const sycl::nd_item<3> &item_ct1) {
-    int nidx = item_ct1.get_local_id(2) +
-               item_ct1.get_group(2) * item_ct1.get_local_range(2);
+    int nidx = SYCL_LOCAL_ID_CALC(item_ct1, 2);
     if (nidx >= ne0) {
         return;
     }
@@ -310,299 +348,72 @@ static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne
310
348
  }
311
349
  }
312
350
 
313
-
314
351
  template<typename T>
315
352
  static void clamp(const T * x, T * dst, const float min, const float max, const int k,
316
- const sycl::nd_item<3> &item_ct1) {
317
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
318
- item_ct1.get_local_id(2);
319
-
320
- if (i >= k) {
321
- return;
353
+ const sycl::nd_item<1> &item_ct1) {
354
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
355
+ dst[i] = x[i] < static_cast<T>(min) ? static_cast<T>(min) : (x[i] > static_cast<T>(max) ? static_cast<T>(max) : x[i]);
322
356
  }
323
-
324
- dst[i] = x[i] < static_cast<T>(min) ? static_cast<T>(min) : (x[i] > static_cast<T>(max) ? static_cast<T>(max) : x[i]);
325
- }
326
-
327
- static void acc_f32_sycl(const float *x, const float *y, float *dst,
328
- const int n_elements, const int ne10, const int ne11,
329
- const int ne12, const int nb1, const int nb2,
330
- const int offset, queue_ptr stream) {
331
- int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
332
- stream->parallel_for(
333
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
334
- sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
335
- sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
336
- [=](sycl::nd_item<3> item_ct1) {
337
- acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
338
- item_ct1);
339
- });
340
- }
341
-
342
- template<typename T>
343
- static void gelu_sycl(const T *x, T *dst, const int k,
344
- queue_ptr stream) {
345
- const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
346
- stream->parallel_for(
347
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
348
- sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
349
- sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
350
- [=](sycl::nd_item<3> item_ct1) {
351
- gelu(x, dst, k, item_ct1);
352
- });
353
- }
354
-
355
- template<typename T>
356
- static void silu_sycl(const T *x, T *dst, const int k,
357
- queue_ptr stream) {
358
- const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
359
- stream->parallel_for(
360
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
361
- sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
362
- sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
363
- [=](sycl::nd_item<3> item_ct1) {
364
- silu(x, dst, k, item_ct1);
365
- });
366
- }
367
-
368
- template<typename T>
369
- static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
370
- // hard code for now
371
- const int num_blocks = ceil_div(k, 256);
372
- stream->parallel_for(
373
- sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
374
- sgn(x, dst, k, item_ct1);
375
- });
376
- }
377
-
378
- template<typename T>
379
- static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
380
- // hard code for now
381
- const int num_blocks = ceil_div(k, 256);
382
- stream->parallel_for(
383
- sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
384
- abs_op(x, dst, k, item_ct1);
385
- });
386
- }
387
-
388
-
389
- template<typename T>
390
- static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
391
- // hard code for now
392
- const int num_blocks = ceil_div(k, 256);
393
- stream->parallel_for(
394
- sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
395
- elu_op(x, dst, k, item_ct1);
396
- });
397
- }
398
-
399
- template<typename T>
400
- static void gelu_quick_sycl(const T *x, T *dst, const int k,
401
- queue_ptr stream) {
402
- const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
403
- stream->parallel_for(
404
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
405
- sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
406
- sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
407
- [=](sycl::nd_item<3> item_ct1) {
408
- gelu_quick(x, dst, k, item_ct1);
409
- });
410
- }
411
-
412
-
413
- template<typename T>
414
- static void gelu_erf_sycl(const T *x, T *dst, const int k,
415
- queue_ptr stream) {
416
- const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
417
- stream->parallel_for(
418
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
419
- sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
420
- sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
421
- [=](sycl::nd_item<3> item_ct1) {
422
- gelu_erf(x, dst, k, item_ct1);
423
- });
424
- }
425
-
426
- template<typename T>
427
- static void tanh_sycl(const T *x, T *dst, const int k,
428
- queue_ptr stream) {
429
- const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
430
- stream->parallel_for(
431
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
432
- sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
433
- sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
434
- [=](sycl::nd_item<3> item_ct1) {
435
- tanh(x, dst, k, item_ct1);
436
- });
437
- }
438
-
439
- template<typename T>
440
- static void relu_sycl(const T *x, T *dst, const int k,
441
- queue_ptr stream) {
442
- const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
443
- stream->parallel_for(
444
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
445
- sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
446
- sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
447
- [=](sycl::nd_item<3> item_ct1) {
448
- relu(x, dst, k, item_ct1);
449
- });
450
- }
451
-
452
- template<typename T>
453
- static void hardsigmoid_sycl(const T *x, T *dst, const int k,
454
- queue_ptr stream) {
455
- const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
456
- stream->parallel_for(
457
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
458
- sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
459
- sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
460
- [=](sycl::nd_item<3> item_ct1) {
461
- hardsigmoid(x, dst, k, item_ct1);
462
- });
463
- }
464
-
465
- template<typename T>
466
- static void hardswish_sycl(const T *x, T *dst, const int k,
467
- queue_ptr stream) {
468
- const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
469
- stream->parallel_for(
470
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
471
- sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
472
- sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
473
- [=](sycl::nd_item<3> item_ct1) {
474
- hardswish(x, dst, k, item_ct1);
475
- });
476
- }
477
-
478
- template<typename T>
479
- static void exp_sycl(const T *x, T *dst, const int k,
480
- queue_ptr stream) {
481
- const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
482
- stream->parallel_for(
483
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
484
- sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
485
- sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
486
- [=](sycl::nd_item<3> item_ct1) {
487
- exp(x, dst, k, item_ct1);
488
- });
489
357
  }
490
358
 
491
359
  template<typename T>
492
- static void log_sycl(const T *x, T *dst, const int k,
493
- queue_ptr stream) {
494
- const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
495
- stream->parallel_for(
496
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
497
- sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
498
- sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
499
- [=](sycl::nd_item<3> item_ct1) {
500
- log(x, dst, k, item_ct1);
501
- });
502
- }
503
-
504
- template<typename T>
505
- static void neg_sycl(const T *x, T *dst, const int k,
506
- queue_ptr stream) {
507
- const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
508
- stream->parallel_for(
509
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
510
- sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
511
- sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
512
- [=](sycl::nd_item<3> item_ct1) {
513
- neg(x, dst, k, item_ct1);
514
- });
515
- }
516
-
517
- template<typename T>
518
- static void step_sycl(const T *x, T *dst, const int k,
519
- queue_ptr stream) {
520
- const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
521
- stream->parallel_for(
522
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
523
- sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
524
- sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
525
- [=](sycl::nd_item<3> item_ct1) {
526
- step(x, dst, k, item_ct1);
527
- });
528
- }
529
-
530
- template<typename T>
531
- static void sigmoid_sycl(const T *x, T *dst, const int k,
532
- queue_ptr stream) {
533
- const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE;
534
- stream->parallel_for(
535
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
536
- sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE),
537
- sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)),
538
- [=](sycl::nd_item<3> item_ct1) {
539
- sigmoid(x, dst, k, item_ct1);
540
- });
360
+ static void gated_op_fused_geglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
361
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
362
+ const int64_t j0 = (i / n) * o0 + (i % n);
363
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
364
+ dst[i] = op_gelu(x[j0]) * g[j1];
+ }
  }
 
  template<typename T>
- static void sqrt_sycl(const T *x, T *dst, const int k,
- queue_ptr stream) {
- const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE;
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- sqrt(x, dst, k, item_ct1);
- });
+ static void gated_op_fused_reglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+ const int64_t j0 = (i / n) * o0 + (i % n);
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+ dst[i] = op_relu(x[j0]) * g[j1];
+ }
  }
 
  template<typename T>
- static void sin_sycl(const T *x, T *dst, const int k,
- queue_ptr stream) {
- const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- sin(x, dst, k, item_ct1);
- });
+ static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+ const int64_t j0 = (i / n) * o0 + (i % n);
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+ dst[i] = op_silu(x[j0]) * g[j1];
+ }
  }
 
  template<typename T>
- static void cos_sycl(const T *x, T *dst, const int k,
- queue_ptr stream) {
- const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- cos(x, dst, k, item_ct1);
- });
+ static void gated_op_fused_geglu_erf(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+ const int64_t j0 = (i / n) * o0 + (i % n);
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+ dst[i] = op_gelu_erf(x[j0]) * g[j1];
+ }
  }
 
  template<typename T>
- static void leaky_relu_sycl(const T *x, T *dst, const int k,
- const float negative_slope,
- queue_ptr stream) {
- const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- leaky_relu(x, dst, k, negative_slope, item_ct1);
- });
+ static void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+ const int64_t j0 = (i / n) * o0 + (i % n);
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+ dst[i] = op_gelu_quick(x[j0]) * g[j1];
+ }
  }
 
- template<typename T>
- static void sqr_sycl(const T *x, T *dst, const int k,
- queue_ptr stream) {
- const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- sqr(x, dst, k, item_ct1);
+ namespace ggml_sycl_detail {
+ static void acc_f32_sycl(const float *x, const float *y, float *dst,
+ const int n_elements, const int ne10, const int ne11,
+ const int ne12, const int nb1, const int nb2,
+ const int offset, queue_ptr stream) {
+ int num_blocks = ceil_div(n_elements, SYCL_ACC_BLOCK_SIZE);
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) *
+ sycl::range<1>(SYCL_ACC_BLOCK_SIZE),
+ sycl::range<1>(SYCL_ACC_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
+ item_ct1);
  });
  }
 
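The gated_op_fused_* kernels above all share one index mapping: output element i lives at row i/n and column i%n, reads its value operand through row stride o0 and its gate operand through row stride o1, and collapses to a single index when both strides match. A minimal host-side C++ reference of that mapping, for illustration only (the function and parameter names here are invented, not part of the package):

    #include <cstdint>

    // Host-side reference for the gated kernels' index mapping:
    // k  - number of output elements
    // n  - row width (nc)
    // o0 - row stride, in elements, of the value operand x
    // o1 - row stride, in elements, of the gate operand g
    template <typename T, typename Op>
    static void gated_op_reference(const T * x, const T * g, T * dst,
                                   uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, Op op) {
        for (uint64_t i = 0; i < k; ++i) {
            const uint64_t row = i / n, col = i % n;
            const uint64_t j0 = row * o0 + col;                   // value index
            const uint64_t j1 = (o0 == o1) ? j0 : row * o1 + col; // gate index
            dst[i] = op(x[j0]) * g[j1];
        }
    }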
@@ -612,11 +423,10 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
  const int ne12, const int ne13, const float sf0, const float sf1,
  const float sf2, const float sf3, queue_ptr stream) {
  int dst_size = ne10 * ne11 * ne12 * ne13;
- int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
+ int num_blocks = ceil_div(dst_size, SYCL_UPSCALE_BLOCK_SIZE);
  sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
- stream->parallel_for(
- sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
- [=](sycl::nd_item<1> item_ct1) {
+ sycl_parallel_for<1>(
+ stream, sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
  upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
  });
  }
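The `(n + BLOCK - 1) / BLOCK` expressions that this patch retires are ceiling divisions; `ceil_div` presumably reduces to the helper sketched below. Its real definition lives elsewhere in the package, so treat this as an assumption from usage rather than the shipped code:

    // Rounds the block count up so the last, partially filled block
    // is still launched.
    template <typename T>
    constexpr T ceil_div(T x, T y) {
        return (x + y - 1) / y;
    }

    static_assert(ceil_div(10, 4) == 3, "10 elements need 3 blocks of 4");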
@@ -625,35 +435,19 @@ template<typename T>
  static void pad_sycl(const T *x, T *dst, const int ne00,
  const int ne01, const int ne02, const int ne0,
  const int ne1, const int ne2, queue_ptr stream) {
- int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
+ int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE);
  sycl::range<3> gridDim(ne2, ne1, num_blocks);
- stream->parallel_for(
- sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- pad(x, dst, ne0, ne00, ne01, ne02, item_ct1);
- });
+ sycl_parallel_for(stream,
+ sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
  }
 
- template<typename T>
- static void clamp_sycl(const T *x, T *dst, const float min,
- const float max, const int k,
- queue_ptr stream) {
- const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- clamp(x, dst, min, max, k, item_ct1);
- });
- }
-
- inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ template<typename KernelInvoker, typename... Args>
+ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
  #if defined (GGML_SYCL_F16)
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
  GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-
  #else
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
@@ -666,14 +460,14 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
  case GGML_TYPE_F16:
  {
  auto data_pts = cast_data<sycl::half>(dst);
- sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+ kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
  break;
  }
  #endif
  case GGML_TYPE_F32:
  {
  auto data_pts = cast_data<float>(dst);
- sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+ kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
  break;
  }
  default:
@@ -681,11 +475,11 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
  }
  }
 
- inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ template<typename KernelInvoker, typename... Args>
+ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
  #if defined (GGML_SYCL_F16)
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
  GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-
  #else
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
@@ -693,19 +487,66 @@ inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
  GGML_ASSERT(dst->src[0]->type == dst->type);
  dpct::queue_ptr main_stream = ctx.stream();
  SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+ const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+ void * src0_d = src0->data;
+ void * src1_d = src1 ? src1->data : src0->data;
+ const int64_t src0_o = src0->nb[1];
+ const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+ void * dst_d = dst->data;
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
+ GGML_ASSERT(src1->ne[0] == nc);
+ GGML_ASSERT(src0->type == src1->type);
+ }
  switch (dst->type) {
  #if defined (GGML_SYCL_F16)
  case GGML_TYPE_F16:
  {
- auto data_pts = cast_data<sycl::half>(dst);
- abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+ sycl::half * src0_p = (sycl::half *) src0_d;
+ sycl::half * src1_p = (sycl::half *) src1_d;
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+ kernel_invoker(src0_p,
+ src1_p,
+ (sycl::half *) dst_d,
+ ggml_nelements(dst),
+ nc,
+ src0_o / sizeof(sycl::half),
+ src1_o / sizeof(sycl::half),
+ main_stream,
+ std::forward<Args>(args)...);
  break;
  }
  #endif
  case GGML_TYPE_F32:
  {
- auto data_pts = cast_data<float>(dst);
- abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+ float * src0_p = (float *) src0_d;
+ float * src1_p = (float *) src1_d;
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ kernel_invoker(src0_p,
+ src1_p,
+ (float *) dst_d,
+ ggml_nelements(dst),
+ nc,
+ src0_o / sizeof(float),
+ src1_o / sizeof(float),
+ main_stream,
+ std::forward<Args>(args)...);
  break;
  }
  default:
@@ -713,32 +554,41 @@ inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
  }
  }
 
-
- inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ template<typename KernelInvoker, typename... Args>
+ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
  #if defined (GGML_SYCL_F16)
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
  GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-
  #else
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
  #endif
  GGML_ASSERT(dst->src[0]->type == dst->type);
+
  dpct::queue_ptr main_stream = ctx.stream();
  SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+
+ const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0];
+ const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1];
+ const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2];
+ const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3];
  switch (dst->type) {
  #if defined (GGML_SYCL_F16)
  case GGML_TYPE_F16:
  {
  auto data_pts = cast_data<sycl::half>(dst);
- elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+ kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],
+ (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,
+ main_stream, std::forward<Args>(args)...);
  break;
  }
  #endif
  case GGML_TYPE_F32:
  {
  auto data_pts = cast_data<float>(dst);
- elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+ kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],
+ (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,
+ main_stream, std::forward<Args>(args)...);
  break;
  }
  default:
@@ -746,7 +596,8 @@ inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
  }
  }
 
- inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ template<typename KernelInvoker, typename... Args>
+ static inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
  #if defined (GGML_SYCL_F16)
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
  GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
@@ -755,6 +606,7 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
  #endif
  GGML_ASSERT(dst->src[0]->type == dst->type);
+ GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
  dpct::queue_ptr main_stream = ctx.stream();
  SYCL_CHECK(ggml_sycl_set_device(ctx.device));
  switch (dst->type) {
@@ -762,14 +614,16 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
  case GGML_TYPE_F16:
  {
  auto data_pts = cast_data<sycl::half>(dst);
- silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+ kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
+ (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
  break;
  }
  #endif
  case GGML_TYPE_F32:
  {
  auto data_pts = cast_data<float>(dst);
- silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+ kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
+ (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
  break;
  }
  default:
@@ -777,655 +631,320 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
  }
  }
 
- inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- #if defined (GGML_SYCL_F16)
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
- #else
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- #endif
- GGML_ASSERT(dst->src[0]->type == dst->type);
- dpct::queue_ptr main_stream = ctx.stream();
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- switch (dst->type) {
- #if defined (GGML_SYCL_F16)
- case GGML_TYPE_F16:
- {
- auto data_pts = cast_data<sycl::half>(dst);
- gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- #endif
- case GGML_TYPE_F32:
- {
- auto data_pts = cast_data<float>(dst);
- gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- default:
- GGML_ABORT("GGML tensor type not supported!\n");
- }
+ } // namespace ggml_sycl_detail
+
+
+
+ static inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+ const int num_blocks = ceil_div(k_elements, 256);
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
+ sycl::range<1>(256)),
+ [=](sycl::nd_item<1> item_ct1) {
+ unary_op_sgn_kernel(src, dst_ptr, k_elements, item_ct1);
+ });
+ });
  }
 
- inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
- #if defined (GGML_SYCL_F16)
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
- #else
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- #endif
- GGML_ASSERT(dst->src[0]->type == dst->type);
- dpct::queue_ptr main_stream = ctx.stream();
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- switch (dst->type) {
- #if defined (GGML_SYCL_F16)
- case GGML_TYPE_F16:
- {
- auto data_pts = cast_data<sycl::half>(dst);
- gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- #endif
- case GGML_TYPE_F32:
- {
- auto data_pts = cast_data<float>(dst);
- gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- default:
- GGML_ABORT("GGML tensor type not supported!\n");
- }
- }
-
- inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
- #if defined (GGML_SYCL_F16)
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
- #else
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- #endif
- GGML_ASSERT(dst->src[0]->type == dst->type);
- dpct::queue_ptr main_stream = ctx.stream();
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- switch (dst->type) {
- #if defined (GGML_SYCL_F16)
- case GGML_TYPE_F16:
- {
- auto data_pts = cast_data<sycl::half>(dst);
- gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- #endif
- case GGML_TYPE_F32:
- {
- auto data_pts = cast_data<float>(dst);
- gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- default:
- GGML_ABORT("GGML tensor type not supported!\n");
- }
- }
-
-
- inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- #if defined (GGML_SYCL_F16)
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
- #else
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- #endif
- GGML_ASSERT(dst->src[0]->type == dst->type);
- dpct::queue_ptr main_stream = ctx.stream();
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- switch (dst->type) {
- #if defined (GGML_SYCL_F16)
- case GGML_TYPE_F16:
- {
- auto data_pts = cast_data<sycl::half>(dst);
- tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- #endif
- case GGML_TYPE_F32:
- {
- auto data_pts = cast_data<float>(dst);
- tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- default:
- GGML_ABORT("GGML tensor type not supported!\n");
- }
- }
-
- inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- #if defined (GGML_SYCL_F16)
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
- #else
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- #endif
- GGML_ASSERT(dst->src[0]->type == dst->type);
- dpct::queue_ptr main_stream = ctx.stream();
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-
- switch (dst->type) {
- #if defined (GGML_SYCL_F16)
- case GGML_TYPE_F16:
- {
- auto data_pts = cast_data<sycl::half>(dst);
- relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- #endif
- case GGML_TYPE_F32:
- {
- auto data_pts = cast_data<float>(dst);
- relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- default:
- GGML_ABORT("GGML tensor type not supported!\n");
- }
- }
-
- inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- #if defined (GGML_SYCL_F16)
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
- #else
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- #endif
- GGML_ASSERT(dst->src[0]->type == dst->type);
-
- dpct::queue_ptr main_stream = ctx.stream();
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-
- switch (dst->type) {
- #if defined (GGML_SYCL_F16)
- case GGML_TYPE_F16:
- {
- auto data_pts = cast_data<sycl::half>(dst);
- hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- #endif
- case GGML_TYPE_F32:
- {
- auto data_pts = cast_data<float>(dst);
- hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- default:
- GGML_ABORT("GGML tensor type not supported!\n");
- }
- }
-
- inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- #if defined (GGML_SYCL_F16)
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
- #else
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- #endif
- GGML_ASSERT(dst->src[0]->type == dst->type);
- dpct::queue_ptr main_stream = ctx.stream();
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- switch (dst->type) {
- #if defined (GGML_SYCL_F16)
- case GGML_TYPE_F16:
- {
- auto data_pts = cast_data<sycl::half>(dst);
- hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- #endif
- case GGML_TYPE_F32:
- {
- auto data_pts = cast_data<float>(dst);
- hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- default:
- GGML_ABORT("GGML tensor type not supported!\n");
- }
- }
-
- inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- #if defined (GGML_SYCL_F16)
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
- #else
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- #endif
- GGML_ASSERT(dst->src[0]->type == dst->type);
- dpct::queue_ptr main_stream = ctx.stream();
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- switch (dst->type) {
- #if defined (GGML_SYCL_F16)
- case GGML_TYPE_F16:
- {
- auto data_pts = cast_data<sycl::half>(dst);
- exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- #endif
- case GGML_TYPE_F32:
- {
- auto data_pts = cast_data<float>(dst);
- exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- default:
- GGML_ABORT("GGML tensor type not supported!\n");
- }
- }
-
- inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- #if defined (GGML_SYCL_F16)
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
- #else
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- #endif
- GGML_ASSERT(dst->src[0]->type == dst->type);
- dpct::queue_ptr main_stream = ctx.stream();
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- switch (dst->type) {
- #if defined (GGML_SYCL_F16)
- case GGML_TYPE_F16:
- {
- auto data_pts = cast_data<sycl::half>(dst);
- log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- #endif
- case GGML_TYPE_F32:
- {
- auto data_pts = cast_data<float>(dst);
- log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- default:
- GGML_ABORT("GGML tensor type not supported!\n");
- }
- }
-
- inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- #if defined (GGML_SYCL_F16)
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
- #else
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- #endif
- GGML_ASSERT(dst->src[0]->type == dst->type);
- dpct::queue_ptr main_stream = ctx.stream();
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- switch (dst->type) {
- #if defined (GGML_SYCL_F16)
- case GGML_TYPE_F16:
- {
- auto data_pts = cast_data<sycl::half>(dst);
- sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- #endif
- case GGML_TYPE_F32:
- {
- auto data_pts = cast_data<float>(dst);
- sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- default:
- GGML_ABORT("GGML tensor type not supported!\n");
- }
- }
-
- inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- #if defined (GGML_SYCL_F16)
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
- #else
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- #endif
- GGML_ASSERT(dst->src[0]->type == dst->type);
-
- dpct::queue_ptr main_stream = ctx.stream();
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- switch (dst->type) {
- #if defined (GGML_SYCL_F16)
- case GGML_TYPE_F16:
- {
- auto data_pts = cast_data<sycl::half>(dst);
- sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- #endif
- case GGML_TYPE_F32:
- {
- auto data_pts = cast_data<float>(dst);
- sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- default:
- GGML_ABORT("GGML tensor type not supported!\n");
- }
+ static inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+ const int num_blocks = ceil_div(k_elements, 256);
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
+ sycl::range<1>(256)),
+ [=](sycl::nd_item<1> item_ct1) {
+ unary_op_abs_kernel(src, dst_ptr, k_elements, item_ct1);
+ });
+ });
  }
 
- inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- #if defined (GGML_SYCL_F16)
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
- #else
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- #endif
- GGML_ASSERT(dst->src[0]->type == dst->type);
- dpct::queue_ptr main_stream = ctx.stream();
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- switch (dst->type) {
- #if defined (GGML_SYCL_F16)
- case GGML_TYPE_F16:
- {
- auto data_pts = cast_data<sycl::half>(dst);
- sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- #endif
- case GGML_TYPE_F32:
- {
- auto data_pts = cast_data<float>(dst);
- sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- default:
- GGML_ABORT("GGML tensor type not supported!\n");
- }
+ static inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+ const int num_blocks = ceil_div(k_elements, 256);
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
+ sycl::range<1>(256)),
+ [=](sycl::nd_item<1> item_ct1) {
+ unary_op_elu_kernel(src, dst_ptr, k_elements, item_ct1);
+ });
+ });
  }
 
- inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- #if defined (GGML_SYCL_F16)
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
- #else
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- #endif
- GGML_ASSERT(dst->src[0]->type == dst->type);
- dpct::queue_ptr main_stream = ctx.stream();
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- switch (dst->type) {
- #if defined (GGML_SYCL_F16)
- case GGML_TYPE_F16:
- {
- auto data_pts = cast_data<sycl::half>(dst);
- cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- #endif
- case GGML_TYPE_F32:
- {
- auto data_pts = cast_data<float>(dst);
- cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- default:
- GGML_ABORT("GGML tensor type not supported!\n");
- }
+ static inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+ const int num_blocks = ceil_div(k_elements, SYCL_SILU_BLOCK_SIZE);
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SILU_BLOCK_SIZE),
+ sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ unary_op_silu_kernel(src, dst_ptr, k_elements, item_ct1);
+ });
+ });
  }
 
- inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- #if defined (GGML_SYCL_F16)
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
- #else
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- #endif
- GGML_ASSERT(dst->src[0]->type == dst->type);
- dpct::queue_ptr main_stream = ctx.stream();
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- switch (dst->type) {
- #if defined (GGML_SYCL_F16)
- case GGML_TYPE_F16:
- {
- auto data_pts = cast_data<sycl::half>(dst);
- step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- #endif
- case GGML_TYPE_F32:
- {
- auto data_pts = cast_data<float>(dst);
- step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- default:
- GGML_ABORT("GGML tensor type not supported!\n");
- }
+ static inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+ const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
+ sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ unary_op_gelu_kernel(src, dst_ptr, k_elements, item_ct1);
+ });
+ });
  }
 
- inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- #if defined (GGML_SYCL_F16)
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
- #else
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- #endif
- GGML_ASSERT(dst->src[0]->type == dst->type);
- dpct::queue_ptr main_stream = ctx.stream();
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- switch (dst->type) {
- #if defined (GGML_SYCL_F16)
- case GGML_TYPE_F16:
- {
- auto data_pts = cast_data<sycl::half>(dst);
- neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- #endif
- case GGML_TYPE_F32:
- {
- auto data_pts = cast_data<float>(dst);
- neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- default:
- GGML_ABORT("GGML tensor type not supported!\n");
- }
+ static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+ const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
+ sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ unary_op_gelu_quick_kernel(src, dst_ptr, k_elements, item_ct1);
+ });
+ });
  }
 
- inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- #if defined (GGML_SYCL_F16)
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
- #else
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- #endif
+ static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+ const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
+ sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ unary_op_gelu_erf_kernel(src, dst_ptr, k_elements, item_ct1);
+ });
+ });
+ }
 
- GGML_ASSERT(dst->src[0]->type == dst->type);
- float negative_slope;
- memcpy(&negative_slope, dst->op_params, sizeof(float));
- dpct::queue_ptr main_stream = ctx.stream();
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- switch (dst->type) {
- #if defined (GGML_SYCL_F16)
- case GGML_TYPE_F16:
- {
- auto data_pts = cast_data<sycl::half>(dst);
- leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream);
- break;
- }
- #endif
- case GGML_TYPE_F32:
- {
- auto data_pts = cast_data<float>(dst);
- leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream);
- break;
- }
- default:
- GGML_ABORT("GGML tensor type not supported!\n");
- }
+ static inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+ const int num_blocks = ceil_div(k_elements, SYCL_TANH_BLOCK_SIZE);
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_TANH_BLOCK_SIZE),
+ sycl::range<1>(SYCL_TANH_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ unary_op_tanh_kernel(src, dst_ptr, k_elements, item_ct1);
+ });
+ });
  }
 
- inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- #if defined (GGML_SYCL_F16)
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
- #else
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- #endif
- GGML_ASSERT(dst->src[0]->type == dst->type);
- dpct::queue_ptr main_stream = ctx.stream();
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- switch (dst->type) {
- #if defined (GGML_SYCL_F16)
- case GGML_TYPE_F16:
- {
- auto data_pts = cast_data<sycl::half>(dst);
- sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- #endif
- case GGML_TYPE_F32:
- {
- auto data_pts = cast_data<float>(dst);
- sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- default:
- GGML_ABORT("GGML tensor type not supported!\n");
- }
+ static inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+ const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
+ sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ unary_op_relu_kernel(src, dst_ptr, k_elements, item_ct1);
+ });
+ });
  }
 
- inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- #if defined (GGML_SYCL_F16)
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
- #else
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- #endif
- GGML_ASSERT(dst->src[0]->type == dst->type);
+ static inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+ const int num_blocks = ceil_div(k_elements, SYCL_HARDSIGMOID_BLOCK_SIZE);
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE),
+ sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ unary_op_hardsigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
+ });
+ });
+ }
 
- dpct::queue_ptr main_stream = ctx.stream();
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ static inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+ const int num_blocks = ceil_div(k_elements, SYCL_HARDSWISH_BLOCK_SIZE);
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE),
+ sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ unary_op_hardswish_kernel(src, dst_ptr, k_elements, item_ct1);
+ });
+ });
+ }
 
- const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0];
- const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1];
- const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2];
- const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3];
- switch (dst->type) {
- #if defined (GGML_SYCL_F16)
- case GGML_TYPE_F16:
- {
- auto data_pts = cast_data<sycl::half>(dst);
- upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2],
- dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
- main_stream);
- break;
- }
- #endif
- case GGML_TYPE_F32:
- {
- auto data_pts = cast_data<float>(dst);
- upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2],
- dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
- main_stream);
- break;
- }
- default:
- GGML_ABORT("GGML tensor type not supported!\n");
- }
+ static inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+ const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE);
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
+ sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ unary_op_exp_kernel(src, dst_ptr, k_elements, item_ct1);
+ });
+ });
  }
 
- inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- #if defined (GGML_SYCL_F16)
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
- #else
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- #endif
- GGML_ASSERT(dst->src[0]->type == dst->type);
- GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
- dpct::queue_ptr main_stream = ctx.stream();
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- switch (dst->type) {
- #if defined (GGML_SYCL_F16)
- case GGML_TYPE_F16:
- {
- auto data_pts = cast_data<sycl::half>(dst);
- pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0],
- dst->ne[1], dst->ne[2], main_stream);
- break;
- }
- #endif
- case GGML_TYPE_F32:
- {
- auto data_pts = cast_data<float>(dst);
- pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0],
- dst->ne[1], dst->ne[2], main_stream);
- break;
- }
- default:
- GGML_ABORT("GGML tensor type not supported!\n");
- }
+ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+ const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE); // Using EXP block size
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
+ sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ unary_op_log_kernel(src, dst_ptr, k_elements, item_ct1);
+ });
+ });
  }
 
- inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- #if defined(GGML_SYCL_F16)
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
- #else
+ static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+ const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE);
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
+ sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ unary_op_neg_kernel(src, dst_ptr, k_elements, item_ct1);
+ });
+ });
+ }
 
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- #endif
- GGML_ASSERT(dst->src[0]->type == dst->type);
- dpct::queue_ptr main_stream = ctx.stream();
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- float min;
- float max;
- memcpy(&min, dst->op_params, sizeof(float));
- memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
+ static inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+ const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE); // Using NEG block size
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
+ sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ unary_op_step_kernel(src, dst_ptr, k_elements, item_ct1);
+ });
+ });
+ }
 
- switch (dst->type) {
- #if defined(GGML_SYCL_F16)
- case GGML_TYPE_F16:
- {
- auto data_pts = cast_data<sycl::half>(dst);
- clamp_sycl(data_pts.src, data_pts.dst, min, max, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- #endif
- case GGML_TYPE_F32:
- {
- auto data_pts = cast_data<float>(dst);
- clamp_sycl(data_pts.src, data_pts.dst, min, max, ggml_nelements(dst->src[0]), main_stream);
- break;
- }
- default:
- GGML_ABORT("GGML tensor type not supported!\n");
- }
+ static inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+ const int num_blocks = ceil_div(k_elements, SYCL_SIGMOID_BLOCK_SIZE);
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE),
+ sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ unary_op_sigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
+ });
+ });
+ }
+
+ static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+ const int num_blocks = ceil_div(k_elements, SYCL_SQRT_BLOCK_SIZE);
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE),
+ sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ unary_op_sqrt_kernel(src, dst_ptr, k_elements, item_ct1);
+ });
+ });
+ }
+
+ static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+ const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE);
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
+ sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ unary_op_sin_kernel(src, dst_ptr, k_elements, item_ct1);
+ });
+ });
  }
 
- inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+ static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+ const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE); // Using SIN block size
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
+ sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ unary_op_cos_kernel(src, dst_ptr, k_elements, item_ct1);
+ });
+ });
+ }
+
+ static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ float negative_slope;
+ memcpy(&negative_slope, dst->op_params, sizeof(float));
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float slope) {
+ const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
+ sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ unary_op_leaky_relu_kernel(src, dst_ptr, k_elements, slope, item_ct1);
+ });
+ }, negative_slope);
+ }
+
+ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+ const int num_blocks = ceil_div(k_elements, SYCL_SQR_BLOCK_SIZE);
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE),
+ sycl::range<1>(SYCL_SQR_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ unary_op_sqr_kernel(src, dst_ptr, k_elements, item_ct1);
+ });
+ });
+ }
+
+ static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_upscale(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int nb00, int nb01, int nb02, int nb03,
+ int ne10, int ne11, int ne12, int ne13, float sf0, float sf1, float sf2, float sf3,
+ queue_ptr stream) {
+ ggml_sycl_detail::upscale_sycl(src, dst_ptr, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, stream);
+ });
+ }
+
+ static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_pad(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2,
+ queue_ptr stream) {
+ ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream);
+ });
+ }
 
+ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ float min_val;
+ float max_val;
+ memcpy(&min_val, dst->op_params, sizeof(float));
+ memcpy(&max_val, (float *) dst->op_params + 1, sizeof(float));
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float min_arg, float max_arg) {
+ const int num_blocks = ceil_div(k_elements, SYCL_CLAMP_BLOCK_SIZE);
+ sycl_parallel_for(stream,
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE),
+ sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1);
+ });
+ }, min_val, max_val);
+ }
+
+ static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -1441,7 +960,62 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
  // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
  int offset = dst->op_params[3] / 4; // offset in bytes
 
- acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
+ ggml_sycl_detail::acc_f32_sycl(src0_dd, src1_dd, dst_dd, (int)ggml_nelements(dst), (int)dst->src[1]->ne[0], (int)dst->src[1]->ne[1], (int)dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
+ }
+
+ static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+ [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+ const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+ sycl_parallel_for(main_stream,
+ sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+ gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+ });
+ });
+ }
+
+ static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+ [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+ const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu
+ sycl_parallel_for(main_stream,
+ sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+ gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+ });
+ });
+ }
+
+ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+ [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+ const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu
+ sycl_parallel_for(main_stream,
+ sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+ gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+ });
+ });
+ }
+
+ static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+ [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+ const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+ sycl_parallel_for(main_stream,
+ sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+ gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+ });
+ });
+ }
+
+ static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+ [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+ const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+ sycl_parallel_for(main_stream,
+ sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+ gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+ });
+ });
  }
 
 
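When a GLU op runs in single-tensor mode (src1 absent), dispatch_ggml_sycl_op_fused_glu treats each src0 row as two nc-wide halves, and the swapped op param decides which half is the value and which the gate; the pointer bumps by nc in its F16/F32 cases implement exactly that. A small self-contained C++ illustration with invented values, not code from the package:

    #include <cstdio>

    int main() {
        // one row of 2*nc floats: first half | second half
        float row[8] = {0, 1, 2, 3, 4, 5, 6, 7};
        const int  nc      = 4;
        const bool swapped = false;

        // same selection as the src0_p/src1_p bumps in the dispatcher
        const float * value = row + (swapped ? nc : 0);
        const float * gate  = row + (swapped ? 0 : nc);

        for (int i = 0; i < nc; ++i) {
            printf("dst[%d] = act(%g) * %g\n", i, value[i], gate[i]);
        }
        return 0;
    }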
@@ -1569,3 +1143,28 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
  ggml_sycl_op_elu(ctx, dst);
  }
+
+ void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+ ggml_sycl_op_geglu(ctx, dst);
+ }
+
+ void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+ ggml_sycl_op_reglu(ctx, dst);
+ }
+
+ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+ ggml_sycl_op_swiglu(ctx, dst);
+ }
+
+ void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+ ggml_sycl_op_geglu_erf(ctx, dst);
+ }
+
+ void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+ ggml_sycl_op_geglu_quick(ctx, dst);
+ }
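Taken together, the changes in this file replace a long series of near-identical type-switch wrappers with a few dispatch templates plus per-op lambdas; scalars such as the clamp bounds or the leaky-relu slope travel through the trailing parameter pack. A condensed, SYCL-free sketch of that pattern under invented names, not the shipped implementation:

    #include <utility>

    // One dispatcher owns the element-type switch; the per-op lambda owns the
    // kernel launch; trailing arguments are perfectly forwarded to it.
    template <typename Invoker, typename... Args>
    static void dispatch_unary(bool f64, const void * src, void * dst, int k,
                               Invoker invoker, Args &&... args) {
        if (f64) {
            invoker(static_cast<const double *>(src), static_cast<double *>(dst), k,
                    std::forward<Args>(args)...);
        } else {
            invoker(static_cast<const float *>(src), static_cast<float *>(dst), k,
                    std::forward<Args>(args)...);
        }
    }

    // Usage, mirroring ggml_sycl_op_leaky_relu above: the generic lambda is
    // instantiated once per element type, and the slope rides in the pack.
    // dispatch_unary(false, x, y, n,
    //     [](const auto * s, auto * d, int n, float slope) {
    //         for (int i = 0; i < n; ++i) d[i] = s[i] > 0 ? s[i] : slope * s[i];
    //     }, 0.1f);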