llama_cpp 0.15.3 → 0.16.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (149)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/ext/llama_cpp/extconf.rb +1 -2
  4. data/ext/llama_cpp/llama_cpp.cpp +27 -3
  5. data/lib/llama_cpp/version.rb +2 -2
  6. data/sig/llama_cpp.rbs +15 -1
  7. data/vendor/tmp/llama.cpp/Makefile +66 -36
  8. data/vendor/tmp/llama.cpp/ggml-alloc.c +4 -4
  9. data/vendor/tmp/llama.cpp/ggml-backend.c +5 -5
  10. data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
  11. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +47 -0
  12. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +34 -0
  13. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +103 -0
  14. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +280 -0
  15. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +34 -0
  16. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +196 -0
  17. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +686 -0
  18. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +490 -0
  19. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +40 -0
  20. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +662 -0
  21. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +319 -0
  22. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +312 -0
  23. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +345 -0
  24. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +178 -0
  25. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +104 -0
  26. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +1564 -0
  27. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +404 -0
  28. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +221 -0
  29. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +49 -0
  30. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +94 -0
  31. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +45 -0
  32. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +271 -0
  33. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +31 -0
  34. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +205 -0
  35. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +40 -0
  36. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  37. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  38. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  39. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  40. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  41. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  42. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  43. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  44. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  45. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +10 -0
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +9 -0
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +10 -0
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +10 -0
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +8 -0
  127. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +47 -0
  128. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +266 -0
  129. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +51 -0
  130. data/vendor/tmp/llama.cpp/ggml-cuda.cu +35 -16
  131. data/vendor/tmp/llama.cpp/ggml-impl.h +4 -0
  132. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +21 -7
  133. data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
  134. data/vendor/tmp/llama.cpp/ggml-metal.m +99 -35
  135. data/vendor/tmp/llama.cpp/ggml-metal.metal +146 -80
  136. data/vendor/tmp/llama.cpp/ggml-quants.c +101 -11
  137. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +75 -58
  138. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +345 -227
  139. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +99301 -39793
  140. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +458 -329
  141. data/vendor/tmp/llama.cpp/ggml.c +301 -409
  142. data/vendor/tmp/llama.cpp/ggml.h +19 -23
  143. data/vendor/tmp/llama.cpp/llama.cpp +855 -651
  144. data/vendor/tmp/llama.cpp/llama.h +28 -48
  145. metadata +121 -6
  146. data/vendor/tmp/llama.cpp/ggml-mpi.c +0 -216
  147. data/vendor/tmp/llama.cpp/ggml-mpi.h +0 -39
  148. data/vendor/tmp/llama.cpp/ggml-opencl.cpp +0 -2305
  149. data/vendor/tmp/llama.cpp/ggml-opencl.h +0 -36
@@ -35,6 +35,10 @@ enum ggml_metal_kernel_type {
35
35
  GGML_METAL_KERNEL_TYPE_MUL_ROW,
36
36
  GGML_METAL_KERNEL_TYPE_DIV,
37
37
  GGML_METAL_KERNEL_TYPE_DIV_ROW,
38
+ GGML_METAL_KERNEL_TYPE_REPEAT_F32,
39
+ GGML_METAL_KERNEL_TYPE_REPEAT_F16,
40
+ GGML_METAL_KERNEL_TYPE_REPEAT_I32,
41
+ GGML_METAL_KERNEL_TYPE_REPEAT_I16,
38
42
  GGML_METAL_KERNEL_TYPE_SCALE,
39
43
  GGML_METAL_KERNEL_TYPE_SCALE_4,
40
44
  GGML_METAL_KERNEL_TYPE_CLAMP,
@@ -168,8 +172,10 @@ enum ggml_metal_kernel_type {
168
172
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
169
173
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
170
174
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
171
- GGML_METAL_KERNEL_TYPE_ROPE_F32,
172
- GGML_METAL_KERNEL_TYPE_ROPE_F16,
175
+ GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
176
+ GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
177
+ GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
178
+ GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
173
179
  GGML_METAL_KERNEL_TYPE_IM2COL_F16,
174
180
  GGML_METAL_KERNEL_TYPE_IM2COL_F32,
175
181
  GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -184,9 +190,9 @@ enum ggml_metal_kernel_type {
184
190
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
185
191
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
186
192
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
187
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
193
+ //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
188
194
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
189
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
195
+ //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
190
196
  GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
191
197
  GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
192
198
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -485,6 +491,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
485
491
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
486
492
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
487
493
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
494
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
495
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
496
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
497
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
488
498
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
489
499
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
490
500
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
@@ -618,8 +628,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
618
628
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
619
629
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
620
630
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
621
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
622
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
631
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
632
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
633
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
634
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
623
635
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
624
636
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
625
637
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
@@ -634,9 +646,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
634
646
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
635
647
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
636
648
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
637
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
649
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
638
650
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
639
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
651
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
640
652
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
641
653
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
642
654
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -746,6 +758,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
746
758
  case GGML_OP_ACC:
747
759
  case GGML_OP_MUL:
748
760
  case GGML_OP_DIV:
761
+ case GGML_OP_REPEAT:
749
762
  case GGML_OP_SCALE:
750
763
  case GGML_OP_CLAMP:
751
764
  case GGML_OP_SQR:
@@ -770,6 +783,15 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
770
783
  case GGML_OP_LEAKY_RELU:
771
784
  return true;
772
785
  case GGML_OP_FLASH_ATTN_EXT:
786
+ if (op->src[1]->type != GGML_TYPE_F16) {
787
+ return false;
788
+ }
789
+ if (op->src[2]->type != GGML_TYPE_F16) {
790
+ return false;
791
+ }
792
+ if (op->src[0]->ne[0] == 256) {
793
+ return false;
794
+ }
773
795
  return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
774
796
  case GGML_OP_MUL_MAT:
775
797
  case GGML_OP_MUL_MAT_ID:
@@ -976,10 +998,10 @@ static enum ggml_status ggml_metal_graph_compute(
976
998
  switch (dst->op) {
977
999
  case GGML_OP_CONCAT:
978
1000
  {
979
- const int64_t nb = ne00;
980
-
981
1001
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
982
1002
 
1003
+ const int32_t dim = ((int32_t *) dst->op_params)[0];
1004
+
983
1005
  [encoder setComputePipelineState:pipeline];
984
1006
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
985
1007
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -1008,7 +1030,7 @@ static enum ggml_status ggml_metal_graph_compute(
1008
1030
  [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1009
1031
  [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1010
1032
  [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1011
- [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
1033
+ [encoder setBytes:&dim length:sizeof(dim) atIndex:27];
1012
1034
 
1013
1035
  const int nth = MIN(1024, ne0);
1014
1036
 
@@ -1018,11 +1040,14 @@ static enum ggml_status ggml_metal_graph_compute(
1018
1040
  case GGML_OP_MUL:
1019
1041
  case GGML_OP_DIV:
1020
1042
  {
1043
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
1044
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1045
+
1021
1046
  const size_t offs = 0;
1022
1047
 
1023
1048
  bool bcast_row = false;
1024
1049
 
1025
- int64_t nb = ne00;
1050
+ int64_t nb = ne00; // used by the "row" kernels
1026
1051
 
1027
1052
  id<MTLComputePipelineState> pipeline = nil;
1028
1053
 
@@ -1091,6 +1116,42 @@ static enum ggml_status ggml_metal_graph_compute(
1091
1116
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1092
1117
  }
1093
1118
  } break;
1119
+ case GGML_OP_REPEAT:
1120
+ {
1121
+ id<MTLComputePipelineState> pipeline;
1122
+
1123
+ switch (src0t) {
1124
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
1125
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
1126
+ case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
1127
+ case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
1128
+ default: GGML_ASSERT(false);
1129
+ }
1130
+
1131
+ [encoder setComputePipelineState:pipeline];
1132
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1133
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1134
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1135
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1136
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1137
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1138
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1139
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1140
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1141
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1142
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
1143
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
1144
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
1145
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
1146
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
1147
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1148
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
1149
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
1150
+
1151
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
1152
+
1153
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1154
+ } break;
1094
1155
  case GGML_OP_ACC:
1095
1156
  {
1096
1157
  GGML_ASSERT(src0t == GGML_TYPE_F32);
@@ -1468,7 +1529,6 @@ static enum ggml_status ggml_metal_graph_compute(
1468
1529
  {
1469
1530
  GGML_ASSERT(ne00 == ne10);
1470
1531
 
1471
- // TODO: assert that dim2 and dim3 are contiguous
1472
1532
  GGML_ASSERT(ne12 % ne02 == 0);
1473
1533
  GGML_ASSERT(ne13 % ne03 == 0);
1474
1534
 
@@ -2136,6 +2196,7 @@ static enum ggml_status ggml_metal_graph_compute(
2136
2196
  case GGML_OP_RMS_NORM:
2137
2197
  {
2138
2198
  GGML_ASSERT(ne00 % 4 == 0);
2199
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
2139
2200
 
2140
2201
  float eps;
2141
2202
  memcpy(&eps, dst->op_params, sizeof(float));
@@ -2163,6 +2224,7 @@ static enum ggml_status ggml_metal_graph_compute(
2163
2224
  case GGML_OP_GROUP_NORM:
2164
2225
  {
2165
2226
  GGML_ASSERT(ne00 % 4 == 0);
2227
+ GGML_ASSERT(ggml_is_contiguous(src0));
2166
2228
 
2167
2229
  //float eps;
2168
2230
  //memcpy(&eps, dst->op_params, sizeof(float));
@@ -2196,6 +2258,8 @@ static enum ggml_status ggml_metal_graph_compute(
2196
2258
  } break;
2197
2259
  case GGML_OP_NORM:
2198
2260
  {
2261
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
2262
+
2199
2263
  float eps;
2200
2264
  memcpy(&eps, dst->op_params, sizeof(float));
2201
2265
 
@@ -2225,7 +2289,7 @@ static enum ggml_status ggml_metal_graph_compute(
2225
2289
  const int n_dims = ((int32_t *) dst->op_params)[1];
2226
2290
  const int mode = ((int32_t *) dst->op_params)[2];
2227
2291
  // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
2228
- const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
2292
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
2229
2293
 
2230
2294
  float freq_base;
2231
2295
  float freq_scale;
@@ -2242,22 +2306,23 @@ static enum ggml_status ggml_metal_graph_compute(
2242
2306
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
2243
2307
 
2244
2308
  const bool is_neox = mode & 2;
2245
- const bool is_glm = mode & 4;
2246
2309
 
2247
- GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
2310
+ id<MTLComputePipelineState> pipeline = nil;
2248
2311
 
2249
2312
  if (!is_neox) {
2250
- GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
2313
+ switch (src0->type) {
2314
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
2315
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
2316
+ default: GGML_ASSERT(false);
2317
+ };
2318
+ } else {
2319
+ switch (src0->type) {
2320
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
2321
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
2322
+ default: GGML_ASSERT(false);
2323
+ };
2251
2324
  }
2252
2325
 
2253
- id<MTLComputePipelineState> pipeline = nil;
2254
-
2255
- switch (src0->type) {
2256
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
2257
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
2258
- default: GGML_ASSERT(false);
2259
- };
2260
-
2261
2326
  [encoder setComputePipelineState:pipeline];
2262
2327
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2263
2328
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -2285,14 +2350,13 @@ static enum ggml_status ggml_metal_graph_compute(
2285
2350
  [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
2286
2351
  [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
2287
2352
  [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
2288
- [encoder setBytes:&mode length:sizeof( int) atIndex:22];
2289
- [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
2290
- [encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
2291
- [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
2292
- [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
2293
- [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
2294
- [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
2295
- [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
2353
+ [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
2354
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
2355
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
2356
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
2357
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
2358
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
2359
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
2296
2360
 
2297
2361
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2298
2362
  } break;
@@ -2573,7 +2637,7 @@ static enum ggml_status ggml_metal_graph_compute(
2573
2637
  case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
2574
2638
  case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
2575
2639
  case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
2576
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
2640
+ //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
2577
2641
  default:
2578
2642
  {
2579
2643
  GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
@@ -2586,7 +2650,7 @@ static enum ggml_status ggml_metal_graph_compute(
2586
2650
 
2587
2651
  switch (ne00) {
2588
2652
  case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
2589
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
2653
+ //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
2590
2654
  default:
2591
2655
  {
2592
2656
  GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
@@ -168,6 +168,53 @@ kernel void kernel_div(
168
168
  }
169
169
  }
170
170
 
171
+ template<typename T>
172
+ kernel void kernel_repeat(
173
+ device const char * src0,
174
+ device char * dst,
175
+ constant int64_t & ne00,
176
+ constant int64_t & ne01,
177
+ constant int64_t & ne02,
178
+ constant int64_t & ne03,
179
+ constant uint64_t & nb00,
180
+ constant uint64_t & nb01,
181
+ constant uint64_t & nb02,
182
+ constant uint64_t & nb03,
183
+ constant int64_t & ne0,
184
+ constant int64_t & ne1,
185
+ constant int64_t & ne2,
186
+ constant int64_t & ne3,
187
+ constant uint64_t & nb0,
188
+ constant uint64_t & nb1,
189
+ constant uint64_t & nb2,
190
+ constant uint64_t & nb3,
191
+ uint3 tgpig[[threadgroup_position_in_grid]],
192
+ uint3 tpitg[[thread_position_in_threadgroup]],
193
+ uint3 ntg[[threads_per_threadgroup]]) {
194
+ const int64_t i3 = tgpig.z;
195
+ const int64_t i2 = tgpig.y;
196
+ const int64_t i1 = tgpig.x;
197
+
198
+ const int64_t i03 = i3 % ne03;
199
+ const int64_t i02 = i2 % ne02;
200
+ const int64_t i01 = i1 % ne01;
201
+
202
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
203
+ device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ;
204
+
205
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
206
+ const int i00 = i0 % ne00;
207
+ *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
208
+ }
209
+ }
210
+
211
+ typedef decltype(kernel_repeat<float>) kernel_repeat_t;
212
+
213
+ template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
214
+ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
215
+ template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
216
+ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
217
+
171
218
  // assumption: src1 is a row
172
219
  // broadcast src1 into src0
173
220
  kernel void kernel_add_row(
@@ -1607,8 +1654,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
1607
1654
  // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
1608
1655
  static void rope_yarn(
1609
1656
  float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
1610
- thread float * cos_theta, thread float * sin_theta
1611
- ) {
1657
+ thread float * cos_theta, thread float * sin_theta) {
1612
1658
  // Get n-d rotational scaling corrected for extrapolation
1613
1659
  float theta_interp = freq_scale * theta_extrap;
1614
1660
  float theta = theta_interp;
@@ -1625,19 +1671,20 @@ static void rope_yarn(
1625
1671
 
1626
1672
  // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
1627
1673
  // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
1628
- static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
1629
- return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
1674
+ static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
1675
+ return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
1630
1676
  }
1631
1677
 
1632
1678
  static void rope_yarn_corr_dims(
1633
- int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
1679
+ int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
1634
1680
  ) {
1635
1681
  // start and end correction dims
1636
- dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
1637
- dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
1682
+ dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
1683
+ dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
1638
1684
  }
1639
1685
 
1640
- typedef void (rope_t)(
1686
+ template<typename T>
1687
+ kernel void kernel_rope_norm(
1641
1688
  device const void * src0,
1642
1689
  device const int32_t * src1,
1643
1690
  device const float * src2,
@@ -1660,8 +1707,7 @@ typedef void (rope_t)(
1660
1707
  constant uint64_t & nb3,
1661
1708
  constant int & n_past,
1662
1709
  constant int & n_dims,
1663
- constant int & mode,
1664
- constant int & n_orig_ctx,
1710
+ constant int & n_ctx_orig,
1665
1711
  constant float & freq_base,
1666
1712
  constant float & freq_scale,
1667
1713
  constant float & ext_factor,
@@ -1670,10 +1716,52 @@ typedef void (rope_t)(
1670
1716
  constant float & beta_slow,
1671
1717
  uint tiitg[[thread_index_in_threadgroup]],
1672
1718
  uint3 tptg[[threads_per_threadgroup]],
1673
- uint3 tgpig[[threadgroup_position_in_grid]]);
1719
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
1720
+ const int64_t i3 = tgpig[2];
1721
+ const int64_t i2 = tgpig[1];
1722
+ const int64_t i1 = tgpig[0];
1723
+
1724
+ float corr_dims[2];
1725
+ rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
1726
+
1727
+ device const int32_t * pos = src1;
1728
+
1729
+ const float theta_base = (float) pos[i2];
1730
+ const float inv_ndims = -1.f/n_dims;
1731
+
1732
+ float cos_theta;
1733
+ float sin_theta;
1734
+
1735
+ for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
1736
+ if (i0 < n_dims) {
1737
+ const int64_t ic = i0/2;
1738
+
1739
+ const float theta = theta_base * pow(freq_base, inv_ndims*i0);
1740
+
1741
+ const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
1742
+
1743
+ rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1744
+
1745
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1746
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1747
+
1748
+ const float x0 = src[0];
1749
+ const float x1 = src[1];
1750
+
1751
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
1752
+ dst_data[1] = x0*sin_theta + x1*cos_theta;
1753
+ } else {
1754
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1755
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1756
+
1757
+ dst_data[0] = src[0];
1758
+ dst_data[1] = src[1];
1759
+ }
1760
+ }
1761
+ }
1674
1762
 
1675
1763
  template<typename T>
1676
- kernel void kernel_rope(
1764
+ kernel void kernel_rope_neox(
1677
1765
  device const void * src0,
1678
1766
  device const int32_t * src1,
1679
1767
  device const float * src2,
@@ -1696,8 +1784,7 @@ kernel void kernel_rope(
1696
1784
  constant uint64_t & nb3,
1697
1785
  constant int & n_past,
1698
1786
  constant int & n_dims,
1699
- constant int & mode,
1700
- constant int & n_orig_ctx,
1787
+ constant int & n_ctx_orig,
1701
1788
  constant float & freq_base,
1702
1789
  constant float & freq_scale,
1703
1790
  constant float & ext_factor,
@@ -1711,73 +1798,53 @@ kernel void kernel_rope(
1711
1798
  const int64_t i2 = tgpig[1];
1712
1799
  const int64_t i1 = tgpig[0];
1713
1800
 
1714
- const bool is_neox = mode & 2;
1715
-
1716
1801
  float corr_dims[2];
1717
- rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
1802
+ rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
1718
1803
 
1719
1804
  device const int32_t * pos = src1;
1720
1805
 
1721
- const int64_t p = pos[i2];
1722
-
1723
- const float theta_0 = (float)p;
1806
+ const float theta_base = (float) pos[i2];
1724
1807
  const float inv_ndims = -1.f/n_dims;
1725
1808
 
1726
- if (!is_neox) {
1727
- for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
1728
-
1729
- const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
1730
- float cos_theta, sin_theta;
1731
- rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1732
-
1733
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1734
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1735
-
1736
- const T x0 = src[0];
1737
- const T x1 = src[1];
1809
+ float cos_theta;
1810
+ float sin_theta;
1738
1811
 
1739
- dst_data[0] = x0*cos_theta - x1*sin_theta;
1740
- dst_data[1] = x0*sin_theta + x1*cos_theta;
1741
- }
1742
- } else {
1743
- for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
1744
- if (ic < n_dims) {
1745
- const int64_t ib = 0;
1746
-
1747
- // simplified from `(ib * n_dims + ic) * inv_ndims`
1748
- const float cur_rot = inv_ndims*ic - ib;
1749
- const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
1750
-
1751
- const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
1812
+ for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
1813
+ if (i0 < n_dims) {
1814
+ const int64_t ic = i0/2;
1752
1815
 
1753
- float cos_theta, sin_theta;
1754
- rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
1816
+ const float theta = theta_base * pow(freq_base, inv_ndims*i0);
1755
1817
 
1756
- const int64_t i0 = ib*n_dims + ic/2;
1818
+ const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
1757
1819
 
1758
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1759
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1820
+ rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1760
1821
 
1761
- const float x0 = src[0];
1762
- const float x1 = src[n_dims/2];
1822
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
1823
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
1763
1824
 
1764
- dst_data[0] = x0*cos_theta - x1*sin_theta;
1765
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
1766
- } else {
1767
- const int64_t i0 = ic;
1825
+ const float x0 = src[0];
1826
+ const float x1 = src[n_dims/2];
1768
1827
 
1769
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1770
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1828
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
1829
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
1830
+ } else {
1831
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1832
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1771
1833
 
1772
- dst_data[0] = src[0];
1773
- dst_data[1] = src[1];
1774
- }
1834
+ dst_data[0] = src[0];
1835
+ dst_data[1] = src[1];
1775
1836
  }
1776
1837
  }
1777
1838
  }
1778
1839
 
1779
- template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
1780
- template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
1840
+ typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
1841
+ typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
1842
+
1843
+ template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
1844
+ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
1845
+
1846
+ template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
1847
+ template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
1781
1848
 
1782
1849
  typedef void (im2col_t)(
1783
1850
  device const float * x,
@@ -2418,7 +2485,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f
2418
2485
  template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
2419
2486
  template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
2420
2487
  template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
2421
- template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
2488
+ //template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
2422
2489
 
2423
2490
  template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
2424
2491
  kernel void kernel_flash_attn_ext_vec_f16(
@@ -2696,7 +2763,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
2696
2763
  }
2697
2764
 
2698
2765
  template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
2699
- template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
2766
+ //template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
2700
2767
 
2701
2768
  kernel void kernel_cpy_f16_f16(
2702
2769
  device const half * src0,
@@ -3319,31 +3386,30 @@ kernel void kernel_concat(
3319
3386
  constant uint64_t & nb1,
3320
3387
  constant uint64_t & nb2,
3321
3388
  constant uint64_t & nb3,
3389
+ constant int32_t & dim,
3322
3390
  uint3 tgpig[[threadgroup_position_in_grid]],
3323
3391
  uint3 tpitg[[thread_position_in_threadgroup]],
3324
3392
  uint3 ntg[[threads_per_threadgroup]]) {
3325
3393
 
3326
- const int64_t i03 = tgpig.z;
3327
- const int64_t i02 = tgpig.y;
3328
- const int64_t i01 = tgpig.x;
3394
+ const int64_t i3 = tgpig.z;
3395
+ const int64_t i2 = tgpig.y;
3396
+ const int64_t i1 = tgpig.x;
3329
3397
 
3330
- const int64_t i13 = i03 % ne13;
3331
- const int64_t i12 = i02 % ne12;
3332
- const int64_t i11 = i01 % ne11;
3398
+ int64_t o[4] = {0, 0, 0, 0};
3399
+ o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
3333
3400
 
3334
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
3335
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
3336
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
3401
+ device const float * x;
3337
3402
 
3338
3403
  for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
3339
- if (i02 < ne02) {
3340
- ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
3341
- src0_ptr += ntg.x*nb00;
3404
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
3405
+ x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
3342
3406
  } else {
3343
- ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
3344
- src1_ptr += ntg.x*nb10;
3407
+ x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
3345
3408
  }
3346
- dst_ptr += ntg.x*nb0;
3409
+
3410
+ device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
3411
+
3412
+ *y = *x;
3347
3413
  }
3348
3414
  }
3349
3415