@novastera-oss/llamarn 0.2.9 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (314)
  1. package/android/build.gradle +2 -1
  2. package/android/proguard-rules.pro +12 -0
  3. package/android/src/main/cpp/include/llama.h +15 -47
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/build-info.cpp +2 -2
  21. package/cpp/llama.cpp/CMakeLists.txt +0 -1
  22. package/cpp/llama.cpp/CMakePresets.json +11 -0
  23. package/cpp/llama.cpp/CODEOWNERS +1 -0
  24. package/cpp/llama.cpp/README.md +8 -8
  25. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  26. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  27. package/cpp/llama.cpp/common/arg.cpp +62 -1
  28. package/cpp/llama.cpp/common/chat.cpp +37 -20
  29. package/cpp/llama.cpp/common/chat.h +2 -0
  30. package/cpp/llama.cpp/common/common.cpp +22 -6
  31. package/cpp/llama.cpp/common/common.h +22 -4
  32. package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
  33. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
  34. package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
  35. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
  36. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  37. package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  38. package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
  39. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  40. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
  41. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
  42. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  44. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
  45. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
  46. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
  95. package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  97. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
  98. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
  99. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
  100. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
  101. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
  103. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  105. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  131. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
  132. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
  134. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  135. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  136. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  137. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  138. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  139. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  140. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  141. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  142. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  143. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  144. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  168. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
  169. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  170. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
  173. package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
  177. package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
  178. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
  179. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
  180. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
  181. package/cpp/llama.cpp/include/llama.h +15 -47
  182. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
  183. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
  184. package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
  185. package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
  186. package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
  187. package/cpp/llama.cpp/src/llama-arch.h +23 -1
  188. package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
  189. package/cpp/llama.cpp/src/llama-batch.h +31 -18
  190. package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
  191. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  192. package/cpp/llama.cpp/src/llama-context.cpp +180 -106
  193. package/cpp/llama.cpp/src/llama-context.h +26 -16
  194. package/cpp/llama.cpp/src/llama-cparams.h +3 -2
  195. package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
  196. package/cpp/llama.cpp/src/llama-graph.h +184 -122
  197. package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
  198. package/cpp/llama.cpp/src/llama-hparams.h +13 -2
  199. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
  200. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
  201. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
  202. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
  203. package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
  204. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
  205. package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
  206. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
  207. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  208. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  209. package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
  210. package/cpp/llama.cpp/src/llama-model.h +21 -4
  211. package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
  212. package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
  213. package/cpp/llama.cpp/src/llama-vocab.h +43 -0
  214. package/cpp/llama.cpp/src/unicode.cpp +207 -0
  215. package/cpp/llama.cpp/src/unicode.h +2 -0
  216. package/ios/include/chat.h +2 -0
  217. package/ios/include/common.h +22 -4
  218. package/ios/include/llama.h +15 -47
  219. package/ios/libs/llama.xcframework/Info.plist +13 -13
  220. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  221. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  223. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
  224. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
  225. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  231. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  232. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  248. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  250. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  254. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  255. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  261. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  262. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
  263. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  264. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
  265. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  267. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  268. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
  269. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
  270. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  274. package/package.json +4 -4
  275. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  276. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  277. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  278. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  279. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  280. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -13,7 +13,9 @@
13
13
  #include "types.comp"
14
14
  #include "flash_attn_base.comp"
15
15
 
16
- const uint32_t D_per_thread = D / D_split;
16
+ const uint32_t HSK_per_thread = HSK / D_split;
17
+ const uint32_t HSV_per_thread = HSV / D_split;
18
+
17
19
  const uint32_t row_split = 4;
18
20
  const uint32_t rows_per_thread = Br / row_split;
19
21
  const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
@@ -32,7 +34,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
32
34
  // Rows index by Q's dimension 2, and the first N rows are valid.
33
35
  D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
34
36
  {
35
- uint32_t offset = (iq2 + r) * D + c;
37
+ uint32_t offset = (iq2 + r) * HSV + c;
36
38
  data_o[o_offset + offset] = D_TYPE(elem);
37
39
  return elem;
38
40
  }
@@ -44,14 +46,14 @@ const uint32_t MatBc = 16;
44
46
  shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
45
47
  shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
46
48
 
47
- const uint32_t qstride = D / 4 + 2; // in units of f16vec4
49
+ const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4
48
50
  shared f16vec4 Qf[Br * qstride];
49
51
 
50
- // Avoid padding for D==256 to make it fit in 48KB shmem.
51
- const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
52
+ // Avoid padding for hsk==256 to make it fit in 48KB shmem.
53
+ const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
52
54
  shared ACC_TYPE sfsh[Bc * sfshstride];
53
55
 
54
- const uint32_t kshstride = D / 4 + 2; // in units of f16vec4
56
+ const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4
55
57
  shared f16vec4 ksh[Bc * kshstride];
56
58
 
57
59
  shared float slope[Br];
@@ -74,18 +76,18 @@ void main() {
74
76
 
75
77
  uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
76
78
 
77
- [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
78
- uint32_t d = (idx + tid) % (D / 4);
79
- uint32_t r = (idx + tid) / (D / 4);
80
- if (r < Br && d < D / 4 &&
79
+ [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
80
+ uint32_t d = (idx + tid) % (HSK / 4);
81
+ uint32_t r = (idx + tid) / (HSK / 4);
82
+ if (r < Br && d < HSK / 4 &&
81
83
  i * Br + r < N) {
82
84
  Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
83
85
  }
84
86
  }
85
87
  barrier();
86
88
 
87
- ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4];
88
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
89
+ ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
90
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
89
91
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
90
92
  Of[r][d] = ACC_TYPEV4(0.0);
91
93
  }
@@ -123,14 +125,18 @@ void main() {
123
125
  uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
124
126
  uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
125
127
  #endif
128
+ uint32_t m_offset = 0;
129
+ if (p.nem2 != 1 || p.nem3 != 1) {
130
+ m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
131
+ }
126
132
 
127
133
  [[dont_unroll]]
128
134
  for (uint32_t j = start_j; j < end_j; ++j) {
129
135
 
130
- [[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) {
131
- uint32_t d = (idx + tid) % (D / 4);
132
- uint32_t c = (idx + tid) / (D / 4);
133
- if (c < Bc && d < D / 4) {
136
+ [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
137
+ uint32_t d = (idx + tid) % (HSK / 4);
138
+ uint32_t c = (idx + tid) / (HSK / 4);
139
+ if (c < Bc && d < HSK / 4) {
134
140
  #if BLOCK_SIZE > 1
135
141
  uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
136
142
  uint ib = coord / BLOCK_SIZE;
@@ -145,14 +151,14 @@ void main() {
145
151
  }
146
152
  barrier();
147
153
 
148
- // K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br
149
- // Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
154
+ // K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br
155
+ // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
150
156
  // This is written transposed in order to allow for N being 8 if implementations need it
151
157
  coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
152
158
  coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
153
159
  coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
154
160
 
155
- for (uint32_t d = 0; d < D / 16; ++d) {
161
+ for (uint32_t d = 0; d < HSK / 16; ++d) {
156
162
  coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
157
163
 
158
164
  uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
@@ -176,12 +182,12 @@ void main() {
176
182
  barrier();
177
183
  }
178
184
 
179
- if (p.mask != 0) {
185
+ if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
180
186
  [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
181
187
  uint32_t c = (idx + tid) % Bc;
182
188
  uint32_t r = (idx + tid) / Bc;
183
189
  if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
184
- sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
190
+ sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
185
191
  }
186
192
  }
187
193
  barrier();
@@ -202,7 +208,7 @@ void main() {
202
208
  eMf[r] = exp(Moldf - Mf[r]);
203
209
  }
204
210
 
205
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
211
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
206
212
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
207
213
  Of[r][d] = float16_t(eMf[r]) * Of[r][d];
208
214
  }
@@ -217,7 +223,7 @@ void main() {
217
223
  Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
218
224
  Lf[r] += Pf[r];
219
225
  }
220
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
226
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
221
227
  #if BLOCK_SIZE > 1
222
228
  uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
223
229
  uint ib = coord / BLOCK_SIZE;
@@ -280,7 +286,7 @@ void main() {
280
286
  }
281
287
 
282
288
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
283
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
289
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
284
290
 
285
291
  Of[r][d] = float16_t(eMf[r]) * Of[r][d];
286
292
  tmpshv4[tid] = Of[r][d];
@@ -300,11 +306,11 @@ void main() {
300
306
  // If there is split_k, then the split_k resolve shader does the final
301
307
  // division by L. Store the intermediate O value and per-row m and L values.
302
308
  if (p.k_num > 1) {
303
- uint32_t o_offset = D * p.ne1 * split_k_index;
309
+ uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
304
310
 
305
311
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
306
312
  if (tile_row(r) < N) {
307
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
313
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
308
314
  [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
309
315
  perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
310
316
  }
@@ -312,7 +318,7 @@ void main() {
312
318
  }
313
319
  }
314
320
 
315
- o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
321
+ o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
316
322
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
317
323
  if (tile_row(r) < N) {
318
324
  perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -328,18 +334,18 @@ void main() {
328
334
  Lfrcp[r] = 1.0 / Lf[r];
329
335
  }
330
336
 
331
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
337
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
332
338
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
333
339
  Of[r][d] *= float16_t(Lfrcp[r]);
334
340
  }
335
341
  }
336
342
 
337
- uint32_t o_offset = iq3*p.ne2*p.ne1;
343
+ uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
338
344
 
339
345
  if (p.gqa_ratio > 1) {
340
346
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
341
347
  if (tile_row(r) < N) {
342
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
348
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
343
349
  [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
344
350
  perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
345
351
  }
@@ -349,9 +355,9 @@ void main() {
349
355
  } else {
350
356
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
351
357
  if (i * Br + tile_row(r) < N) {
352
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
358
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
353
359
  [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
354
- data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
360
+ data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
355
361
  }
356
362
  }
357
363
  }
@@ -61,8 +61,8 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
61
61
  // Rows index by Q's dimension 2, and the first N rows are valid.
62
62
  D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
63
63
  {
64
- if (r < N && c < D) {
65
- uint32_t offset = (iq2 + r) * D + c;
64
+ if (r < N && c < HSV) {
65
+ uint32_t offset = (iq2 + r) * HSV + c;
66
66
  data_o[o_offset + offset] = D_TYPE(elem);
67
67
  }
68
68
  return elem;
@@ -86,9 +86,9 @@ void main() {
86
86
  tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
87
87
  #endif
88
88
 
89
- tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D);
90
- tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
91
- tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
89
+ tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
90
+ tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
91
+ tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV);
92
92
 
93
93
  // hint to the compiler that strides are aligned for the aligned variant of the shader
94
94
  if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
@@ -104,16 +104,16 @@ void main() {
104
104
  tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
105
105
  tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
106
106
 
107
- coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q;
108
- coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
107
+ coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseAccumulator> Q;
108
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA> Qf16;
109
109
 
110
110
  uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
111
- coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D));
111
+ coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK));
112
112
 
113
- Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q);
113
+ Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA>(Q);
114
114
  Qf16 *= float16_t(p.scale);
115
115
 
116
- coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
116
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
117
117
 
118
118
  coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
119
119
 
@@ -130,15 +130,20 @@ void main() {
130
130
  coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
131
131
  }
132
132
 
133
+ uint32_t m_offset = 0;
134
+ if (p.nem2 != 1 || p.nem3 != 1) {
135
+ m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
136
+ }
137
+
133
138
  [[dont_unroll]]
134
139
  for (uint32_t j = start_j; j < end_j; ++j) {
135
140
 
136
141
  coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
137
142
 
138
- coopmat<float16_t, gl_ScopeWorkgroup, D, Bc, gl_MatrixUseB> K_T;
143
+ coopmat<float16_t, gl_ScopeWorkgroup, HSK, Bc, gl_MatrixUseB> K_T;
139
144
 
140
145
  uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
141
- coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC);
146
+ coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC);
142
147
  S = coopMatMulAdd(Qf16, K_T, S);
143
148
 
144
149
  if (p.logit_softcap != 0.0f) {
@@ -148,14 +153,14 @@ void main() {
148
153
  }
149
154
  }
150
155
 
151
- if (p.mask != 0) {
156
+ if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
152
157
  tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
153
158
  tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
154
159
  tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
155
160
 
156
161
  coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
157
162
 
158
- coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
163
+ coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
159
164
 
160
165
  S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
161
166
  }
@@ -203,42 +208,42 @@ void main() {
203
208
  rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
204
209
  rowsum = coopMatMulAdd(P_A, One, rowsum);
205
210
 
206
- coopmat<float16_t, gl_ScopeWorkgroup, Bc, D, gl_MatrixUseB> V;
211
+ coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV, gl_MatrixUseB> V;
207
212
  uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
208
- coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC);
213
+ coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC);
209
214
 
210
215
  L = eM*L + rowsum;
211
216
 
212
217
  // This is the "diagonal" matrix in the paper, but since we do componentwise
213
218
  // multiply rather than matrix multiply it has the diagonal element smeared
214
219
  // across the row
215
- coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> eMdiag;
220
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> eMdiag;
216
221
 
217
222
  // resize eM by using smear/reduce
218
223
  coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
219
224
 
220
225
  // multiply with fp16 accumulation, then add to O.
221
- coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
226
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
222
227
  PV = coopMatMulAdd(P_A, V, PV);
223
228
 
224
- O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
229
+ O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(PV);
225
230
  }
226
231
 
227
232
  // If there is split_k, then the split_k resolve shader does the final
228
233
  // division by L. Store the intermediate O value and per-row m and L values.
229
234
  if (p.k_num > 1) {
230
- coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
235
+ coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
231
236
 
232
- uint32_t o_offset = D * p.ne1 * split_k_index;
237
+ uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
233
238
  coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
234
239
 
235
- o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
240
+ o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
236
241
  coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
237
242
  coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
238
243
  return;
239
244
  }
240
245
 
241
- coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;
246
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Ldiag;
242
247
 
243
248
  // resize L by using smear/reduce
244
249
  coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
@@ -250,18 +255,18 @@ void main() {
250
255
 
251
256
  O = Ldiag*O;
252
257
 
253
- uint32_t o_offset = iq3*p.ne2*p.ne1;
258
+ uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
254
259
 
255
- coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
260
+ coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
256
261
  if (p.gqa_ratio > 1) {
257
262
  coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
258
263
  } else {
259
264
  tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
260
- tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
265
+ tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV);
261
266
 
262
267
  // permute dimensions
263
268
  tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
264
269
 
265
- coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute);
270
+ coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute);
266
271
  }
267
272
  }
@@ -2,9 +2,9 @@
2
2
 
3
3
  #extension GL_EXT_control_flow_attributes : enable
4
4
 
5
- #define BLOCK_SIZE 32
5
+ layout(constant_id = 0) const uint BLOCK_SIZE = 32;
6
6
 
7
- layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
7
+ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
8
8
 
9
9
  layout (binding = 0) readonly buffer A {float data_a[];};
10
10
  layout (binding = 1) writeonly buffer D {float data_d[];};
@@ -12,48 +12,80 @@ layout (binding = 1) writeonly buffer D {float data_d[];};
12
12
  layout (push_constant) uniform parameter {
13
13
  uint D;
14
14
  uint N;
15
+ uint ne3;
15
16
  uint k_num;
16
17
  } p;
17
18
 
19
+ shared float tmpsh[BLOCK_SIZE];
20
+
18
21
  void main() {
19
22
  // Each workgroup handles a row
20
23
  const uint n = gl_WorkGroupID.x;
21
24
  const uint tid = gl_LocalInvocationID.x;
25
+ const uint iq3 = gl_WorkGroupID.z;
22
26
 
23
27
  uint D = p.D;
24
28
  uint N = p.N;
25
29
  uint k_num = p.k_num;
26
30
 
27
- uint l_offset = D * N * k_num + n;
28
- uint m_offset = D * N * k_num + N + n;
31
+ uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
32
+ uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
29
33
  uint lm_stride = N * 2;
30
34
 
31
35
  // Compute the max m value for the row
32
36
  float m_max = -1.0/0.0;
33
- [[unroll]] for (uint k = 0; k < k_num; ++k) {
34
- float m = data_a[m_offset + k * lm_stride];
37
+ for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
38
+ float m = data_a[m_offset + (k + tid) * lm_stride];
35
39
  m_max = max(m_max, m);
36
40
  }
37
41
 
42
+ // reduce across the workgroup
43
+ tmpsh[tid] = m_max;
44
+ barrier();
45
+ [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
46
+ if (tid < s) {
47
+ m_max = max(m_max, tmpsh[tid + s]);
48
+ tmpsh[tid] = m_max;
49
+ }
50
+ barrier();
51
+ }
52
+ m_max = tmpsh[0];
53
+
54
+ barrier();
55
+
38
56
  // Compute L based on m_max
39
57
  float L = 0;
40
- [[unroll]] for (uint k = 0; k < k_num; ++k) {
41
- float l = data_a[l_offset + k * lm_stride];
42
- float m = data_a[m_offset + k * lm_stride];
58
+ for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
59
+ float l = data_a[l_offset + (k + tid) * lm_stride];
60
+ float m = data_a[m_offset + (k + tid) * lm_stride];
43
61
  L += exp(m - m_max) * l;
44
62
  }
45
63
 
64
+ // reduce across the workgroup
65
+ tmpsh[tid] = L;
66
+ barrier();
67
+ [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
68
+ if (tid < s) {
69
+ L += tmpsh[tid + s];
70
+ tmpsh[tid] = L;
71
+ }
72
+ barrier();
73
+ }
74
+ L = tmpsh[0];
75
+
46
76
  L = 1.0 / L;
47
77
 
78
+ // D dimension is split across workgroups in the y dimension
79
+ uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
48
80
  // Scale and sum the O contributions based on m_max and store the result to memory
49
- for (uint d = tid; d < D; d += BLOCK_SIZE) {
81
+ if (d < D) {
50
82
  float O = 0.0;
51
83
  [[unroll]] for (uint k = 0; k < k_num; ++k) {
52
- uint o_offset = D * N * k + D * n + d;
84
+ uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
53
85
  float m = data_a[m_offset + k * lm_stride];
54
86
  O += exp(m - m_max) * data_a[o_offset];
55
87
  }
56
88
  O *= L;
57
- data_d[D * n + d] = O;
89
+ data_d[iq3 * D * N + D * n + d] = O;
58
90
  }
59
91
  }
@@ -0,0 +1,13 @@
1
+ #version 450
2
+
3
+ #include "glu_head.comp"
4
+
5
+ const float GELU_COEF_A = 0.044715f;
6
+ const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
7
+
8
+ float op(float a, float b) {
9
+ const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a);
10
+ return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b;
11
+ }
12
+
13
+ #include "glu_main.comp"
@@ -0,0 +1,27 @@
1
+ #version 450
2
+
3
+ #include "glu_head.comp"
4
+
5
+ // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
6
+ // ref: https://www.johndcook.com/blog/python_erf/
7
+ const float p_erf = 0.3275911f;
8
+ const float a1_erf = 0.254829592f;
9
+ const float a2_erf = -0.284496736f;
10
+ const float a3_erf = 1.421413741f;
11
+ const float a4_erf = -1.453152027f;
12
+ const float a5_erf = 1.061405429f;
13
+
14
+ const float SQRT_2_INV = 0.70710678118654752440084436210484f;
15
+
16
+ float op(float a, float b) {
17
+ const float a_div_sqr2 = a * SQRT_2_INV;
18
+ const float sign_x = sign(a_div_sqr2);
19
+ const float x = abs(a_div_sqr2);
20
+ const float t = 1.0f / (1.0f + p_erf * x);
21
+ const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
22
+ const float erf_approx = sign_x * y;
23
+
24
+ return 0.5f * a * (1.0f + erf_approx) * b;
25
+ }
26
+
27
+ #include "glu_main.comp"
@@ -0,0 +1,11 @@
1
+ #version 450
2
+
3
+ #include "glu_head.comp"
4
+
5
+ const float GELU_QUICK_COEF = -1.702f;
6
+
7
+ float op(float a, float b) {
8
+ return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b;
9
+ }
10
+
11
+ #include "glu_main.comp"
@@ -0,0 +1,39 @@
1
+ #version 450
2
+
3
+ #include "generic_head.comp"
4
+ #include "types.comp"
5
+
6
+ #extension GL_EXT_control_flow_attributes : enable
7
+
8
+ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
9
+
10
+ layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
11
+ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
12
+
13
+ void main() {
14
+ // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
15
+ // ref: https://www.johndcook.com/blog/python_erf/
16
+ const float p_erf = 0.3275911f;
17
+ const float a1_erf = 0.254829592f;
18
+ const float a2_erf = -0.284496736f;
19
+ const float a3_erf = 1.421413741f;
20
+ const float a4_erf = -1.453152027f;
21
+ const float a5_erf = 1.061405429f;
22
+
23
+ const float SQRT_2_INV = 0.70710678118654752440084436210484f;
24
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
25
+
26
+ if (i >= p.KX) {
27
+ return;
28
+ }
29
+
30
+ const float a = float(data_a[i]);
31
+ const float a_div_sqr2 = a * SQRT_2_INV;
32
+ const float sign_x = sign(a_div_sqr2);
33
+ const float x = abs(a_div_sqr2);
34
+ const float t = 1.0f / (1.0f + p_erf * x);
35
+ const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
36
+ const float erf_approx = sign_x * y;
37
+
38
+ data_d[i] = D_TYPE(0.5f * a * (1.0f + erf_approx));
39
+ }
@@ -1,6 +1,8 @@
1
1
  #extension GL_EXT_shader_16bit_storage : require
2
2
  #extension GL_EXT_control_flow_attributes : require
3
3
 
4
+ #include "rte.comp"
5
+
4
6
  layout (push_constant) uniform parameter
5
7
  {
6
8
  uint ne;
@@ -0,0 +1,17 @@
1
+ #extension GL_EXT_shader_16bit_storage : require
2
+
3
+ #include "rte.comp"
4
+
5
+ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
6
+
7
+ layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
8
+ layout (binding = 1) readonly buffer B {A_TYPE data_b[];};
9
+ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
10
+
11
+ layout (push_constant) uniform parameter
12
+ {
13
+ uint N;
14
+ uint ne00;
15
+ uint ne20;
16
+ uint mode;
17
+ } p;
@@ -0,0 +1,29 @@
1
+ void main() {
2
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
3
+
4
+ if (i >= p.N) {
5
+ return;
6
+ }
7
+
8
+ const uint row = i / p.ne20;
9
+ const uint col = i - row * p.ne20;
10
+
11
+ if (p.mode == 0) {
12
+ // Default
13
+ const uint offset = p.ne00 / 2;
14
+ const uint idx = row * p.ne00 + col;
15
+
16
+ data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
17
+ } else if (p.mode == 1) {
18
+ // Swapped
19
+ const uint offset = p.ne00 / 2;
20
+ const uint idx = row * p.ne00 + col;
21
+
22
+ data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
23
+ } else {
24
+ // Split
25
+ const uint idx = row * p.ne00 + col;
26
+
27
+ data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
28
+ }
29
+ }
@@ -1,12 +1,9 @@
1
1
  #version 450
2
2
 
3
3
  #extension GL_EXT_shader_16bit_storage : require
4
- #extension GL_EXT_spirv_intrinsics: enable
5
4
  #extension GL_EXT_control_flow_attributes : require
6
5
 
7
- #if RTE16
8
- spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
9
- #endif
6
+ #include "rte.comp"
10
7
 
11
8
  layout (push_constant) uniform parameter
12
9
  {
@@ -43,12 +40,10 @@ void main() {
43
40
  const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
44
41
  const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
45
42
  const int oh_s1 = int(oh) * p.s1;
46
- const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
43
+ const uint ksize = p.OW * p.KH;
47
44
 
48
45
  const uint base_linear_idx = gidx * NUM_ITER;
49
46
 
50
- const uint max_ky = ksize / p.OW;
51
-
52
47
  uint current_kx = base_linear_idx / ksize;
53
48
  const uint rem = base_linear_idx - (current_kx * ksize);
54
49
  uint current_ky = rem / p.OW;
@@ -79,7 +74,7 @@ void main() {
79
74
 
80
75
  if (++current_ix == p.OW) {
81
76
  current_ix = 0;
82
- if (++current_ky == max_ky) {
77
+ if (++current_ky == p.KH) {
83
78
  current_ky = 0;
84
79
  current_kx++;
85
80
  }