llama_cpp 0.15.3 → 0.16.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (149) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/ext/llama_cpp/extconf.rb +1 -2
  4. data/ext/llama_cpp/llama_cpp.cpp +27 -3
  5. data/lib/llama_cpp/version.rb +2 -2
  6. data/sig/llama_cpp.rbs +15 -1
  7. data/vendor/tmp/llama.cpp/Makefile +66 -36
  8. data/vendor/tmp/llama.cpp/ggml-alloc.c +4 -4
  9. data/vendor/tmp/llama.cpp/ggml-backend.c +5 -5
  10. data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
  11. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +47 -0
  12. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +34 -0
  13. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +103 -0
  14. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +280 -0
  15. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +34 -0
  16. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +196 -0
  17. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +686 -0
  18. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +490 -0
  19. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +40 -0
  20. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +662 -0
  21. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +319 -0
  22. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +312 -0
  23. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +345 -0
  24. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +178 -0
  25. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +104 -0
  26. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +1564 -0
  27. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +404 -0
  28. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +221 -0
  29. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +49 -0
  30. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +94 -0
  31. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +45 -0
  32. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +271 -0
  33. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +31 -0
  34. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +205 -0
  35. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +40 -0
  36. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  37. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  38. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  39. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  40. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  41. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  42. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  43. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  44. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  45. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +10 -0
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +9 -0
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +10 -0
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +10 -0
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +8 -0
  127. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +47 -0
  128. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +266 -0
  129. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +51 -0
  130. data/vendor/tmp/llama.cpp/ggml-cuda.cu +35 -16
  131. data/vendor/tmp/llama.cpp/ggml-impl.h +4 -0
  132. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +21 -7
  133. data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
  134. data/vendor/tmp/llama.cpp/ggml-metal.m +99 -35
  135. data/vendor/tmp/llama.cpp/ggml-metal.metal +146 -80
  136. data/vendor/tmp/llama.cpp/ggml-quants.c +101 -11
  137. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +75 -58
  138. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +345 -227
  139. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +99301 -39793
  140. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +458 -329
  141. data/vendor/tmp/llama.cpp/ggml.c +301 -409
  142. data/vendor/tmp/llama.cpp/ggml.h +19 -23
  143. data/vendor/tmp/llama.cpp/llama.cpp +855 -651
  144. data/vendor/tmp/llama.cpp/llama.h +28 -48
  145. metadata +121 -6
  146. data/vendor/tmp/llama.cpp/ggml-mpi.c +0 -216
  147. data/vendor/tmp/llama.cpp/ggml-mpi.h +0 -39
  148. data/vendor/tmp/llama.cpp/ggml-opencl.cpp +0 -2305
  149. data/vendor/tmp/llama.cpp/ggml-opencl.h +0 -36
@@ -35,6 +35,10 @@ enum ggml_metal_kernel_type {
35
35
  GGML_METAL_KERNEL_TYPE_MUL_ROW,
36
36
  GGML_METAL_KERNEL_TYPE_DIV,
37
37
  GGML_METAL_KERNEL_TYPE_DIV_ROW,
38
+ GGML_METAL_KERNEL_TYPE_REPEAT_F32,
39
+ GGML_METAL_KERNEL_TYPE_REPEAT_F16,
40
+ GGML_METAL_KERNEL_TYPE_REPEAT_I32,
41
+ GGML_METAL_KERNEL_TYPE_REPEAT_I16,
38
42
  GGML_METAL_KERNEL_TYPE_SCALE,
39
43
  GGML_METAL_KERNEL_TYPE_SCALE_4,
40
44
  GGML_METAL_KERNEL_TYPE_CLAMP,
@@ -168,8 +172,10 @@ enum ggml_metal_kernel_type {
168
172
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
169
173
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
170
174
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
171
- GGML_METAL_KERNEL_TYPE_ROPE_F32,
172
- GGML_METAL_KERNEL_TYPE_ROPE_F16,
175
+ GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
176
+ GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
177
+ GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
178
+ GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
173
179
  GGML_METAL_KERNEL_TYPE_IM2COL_F16,
174
180
  GGML_METAL_KERNEL_TYPE_IM2COL_F32,
175
181
  GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -184,9 +190,9 @@ enum ggml_metal_kernel_type {
184
190
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
185
191
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
186
192
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
187
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
193
+ //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
188
194
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
189
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
195
+ //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
190
196
  GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
191
197
  GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
192
198
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -485,6 +491,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
485
491
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
486
492
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
487
493
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
494
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
495
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
496
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
497
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
488
498
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
489
499
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
490
500
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
@@ -618,8 +628,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
618
628
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
619
629
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
620
630
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
621
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
622
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
631
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
632
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
633
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
634
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
623
635
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
624
636
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
625
637
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
@@ -634,9 +646,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
634
646
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
635
647
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
636
648
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
637
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
649
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
638
650
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
639
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
651
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
640
652
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
641
653
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
642
654
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -746,6 +758,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
746
758
  case GGML_OP_ACC:
747
759
  case GGML_OP_MUL:
748
760
  case GGML_OP_DIV:
761
+ case GGML_OP_REPEAT:
749
762
  case GGML_OP_SCALE:
750
763
  case GGML_OP_CLAMP:
751
764
  case GGML_OP_SQR:
@@ -770,6 +783,15 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
770
783
  case GGML_OP_LEAKY_RELU:
771
784
  return true;
772
785
  case GGML_OP_FLASH_ATTN_EXT:
786
+ if (op->src[1]->type != GGML_TYPE_F16) {
787
+ return false;
788
+ }
789
+ if (op->src[2]->type != GGML_TYPE_F16) {
790
+ return false;
791
+ }
792
+ if (op->src[0]->ne[0] == 256) {
793
+ return false;
794
+ }
773
795
  return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
774
796
  case GGML_OP_MUL_MAT:
775
797
  case GGML_OP_MUL_MAT_ID:
@@ -976,10 +998,10 @@ static enum ggml_status ggml_metal_graph_compute(
976
998
  switch (dst->op) {
977
999
  case GGML_OP_CONCAT:
978
1000
  {
979
- const int64_t nb = ne00;
980
-
981
1001
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
982
1002
 
1003
+ const int32_t dim = ((int32_t *) dst->op_params)[0];
1004
+
983
1005
  [encoder setComputePipelineState:pipeline];
984
1006
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
985
1007
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -1008,7 +1030,7 @@ static enum ggml_status ggml_metal_graph_compute(
1008
1030
  [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1009
1031
  [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1010
1032
  [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1011
- [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
1033
+ [encoder setBytes:&dim length:sizeof(dim) atIndex:27];
1012
1034
 
1013
1035
  const int nth = MIN(1024, ne0);
1014
1036
 
@@ -1018,11 +1040,14 @@ static enum ggml_status ggml_metal_graph_compute(
1018
1040
  case GGML_OP_MUL:
1019
1041
  case GGML_OP_DIV:
1020
1042
  {
1043
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
1044
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1045
+
1021
1046
  const size_t offs = 0;
1022
1047
 
1023
1048
  bool bcast_row = false;
1024
1049
 
1025
- int64_t nb = ne00;
1050
+ int64_t nb = ne00; // used by the "row" kernels
1026
1051
 
1027
1052
  id<MTLComputePipelineState> pipeline = nil;
1028
1053
 
@@ -1091,6 +1116,42 @@ static enum ggml_status ggml_metal_graph_compute(
1091
1116
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1092
1117
  }
1093
1118
  } break;
1119
+ case GGML_OP_REPEAT:
1120
+ {
1121
+ id<MTLComputePipelineState> pipeline;
1122
+
1123
+ switch (src0t) {
1124
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
1125
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
1126
+ case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
1127
+ case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
1128
+ default: GGML_ASSERT(false);
1129
+ }
1130
+
1131
+ [encoder setComputePipelineState:pipeline];
1132
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1133
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1134
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1135
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1136
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1137
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1138
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1139
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1140
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1141
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1142
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
1143
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
1144
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
1145
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
1146
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
1147
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1148
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
1149
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
1150
+
1151
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
1152
+
1153
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1154
+ } break;
1094
1155
  case GGML_OP_ACC:
1095
1156
  {
1096
1157
  GGML_ASSERT(src0t == GGML_TYPE_F32);
@@ -1468,7 +1529,6 @@ static enum ggml_status ggml_metal_graph_compute(
1468
1529
  {
1469
1530
  GGML_ASSERT(ne00 == ne10);
1470
1531
 
1471
- // TODO: assert that dim2 and dim3 are contiguous
1472
1532
  GGML_ASSERT(ne12 % ne02 == 0);
1473
1533
  GGML_ASSERT(ne13 % ne03 == 0);
1474
1534
 
@@ -2136,6 +2196,7 @@ static enum ggml_status ggml_metal_graph_compute(
2136
2196
  case GGML_OP_RMS_NORM:
2137
2197
  {
2138
2198
  GGML_ASSERT(ne00 % 4 == 0);
2199
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
2139
2200
 
2140
2201
  float eps;
2141
2202
  memcpy(&eps, dst->op_params, sizeof(float));
@@ -2163,6 +2224,7 @@ static enum ggml_status ggml_metal_graph_compute(
2163
2224
  case GGML_OP_GROUP_NORM:
2164
2225
  {
2165
2226
  GGML_ASSERT(ne00 % 4 == 0);
2227
+ GGML_ASSERT(ggml_is_contiguous(src0));
2166
2228
 
2167
2229
  //float eps;
2168
2230
  //memcpy(&eps, dst->op_params, sizeof(float));
@@ -2196,6 +2258,8 @@ static enum ggml_status ggml_metal_graph_compute(
2196
2258
  } break;
2197
2259
  case GGML_OP_NORM:
2198
2260
  {
2261
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
2262
+
2199
2263
  float eps;
2200
2264
  memcpy(&eps, dst->op_params, sizeof(float));
2201
2265
 
@@ -2225,7 +2289,7 @@ static enum ggml_status ggml_metal_graph_compute(
2225
2289
  const int n_dims = ((int32_t *) dst->op_params)[1];
2226
2290
  const int mode = ((int32_t *) dst->op_params)[2];
2227
2291
  // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
2228
- const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
2292
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
2229
2293
 
2230
2294
  float freq_base;
2231
2295
  float freq_scale;
@@ -2242,22 +2306,23 @@ static enum ggml_status ggml_metal_graph_compute(
2242
2306
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
2243
2307
 
2244
2308
  const bool is_neox = mode & 2;
2245
- const bool is_glm = mode & 4;
2246
2309
 
2247
- GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
2310
+ id<MTLComputePipelineState> pipeline = nil;
2248
2311
 
2249
2312
  if (!is_neox) {
2250
- GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
2313
+ switch (src0->type) {
2314
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
2315
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
2316
+ default: GGML_ASSERT(false);
2317
+ };
2318
+ } else {
2319
+ switch (src0->type) {
2320
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
2321
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
2322
+ default: GGML_ASSERT(false);
2323
+ };
2251
2324
  }
2252
2325
 
2253
- id<MTLComputePipelineState> pipeline = nil;
2254
-
2255
- switch (src0->type) {
2256
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
2257
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
2258
- default: GGML_ASSERT(false);
2259
- };
2260
-
2261
2326
  [encoder setComputePipelineState:pipeline];
2262
2327
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2263
2328
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -2285,14 +2350,13 @@ static enum ggml_status ggml_metal_graph_compute(
2285
2350
  [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
2286
2351
  [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
2287
2352
  [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
2288
- [encoder setBytes:&mode length:sizeof( int) atIndex:22];
2289
- [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
2290
- [encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
2291
- [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
2292
- [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
2293
- [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
2294
- [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
2295
- [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
2353
+ [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
2354
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
2355
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
2356
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
2357
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
2358
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
2359
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
2296
2360
 
2297
2361
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2298
2362
  } break;
@@ -2573,7 +2637,7 @@ static enum ggml_status ggml_metal_graph_compute(
2573
2637
  case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
2574
2638
  case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
2575
2639
  case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
2576
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
2640
+ //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
2577
2641
  default:
2578
2642
  {
2579
2643
  GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
@@ -2586,7 +2650,7 @@ static enum ggml_status ggml_metal_graph_compute(
2586
2650
 
2587
2651
  switch (ne00) {
2588
2652
  case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
2589
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
2653
+ //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
2590
2654
  default:
2591
2655
  {
2592
2656
  GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
@@ -168,6 +168,53 @@ kernel void kernel_div(
168
168
  }
169
169
  }
170
170
 
171
+ template<typename T>
172
+ kernel void kernel_repeat(
173
+ device const char * src0,
174
+ device char * dst,
175
+ constant int64_t & ne00,
176
+ constant int64_t & ne01,
177
+ constant int64_t & ne02,
178
+ constant int64_t & ne03,
179
+ constant uint64_t & nb00,
180
+ constant uint64_t & nb01,
181
+ constant uint64_t & nb02,
182
+ constant uint64_t & nb03,
183
+ constant int64_t & ne0,
184
+ constant int64_t & ne1,
185
+ constant int64_t & ne2,
186
+ constant int64_t & ne3,
187
+ constant uint64_t & nb0,
188
+ constant uint64_t & nb1,
189
+ constant uint64_t & nb2,
190
+ constant uint64_t & nb3,
191
+ uint3 tgpig[[threadgroup_position_in_grid]],
192
+ uint3 tpitg[[thread_position_in_threadgroup]],
193
+ uint3 ntg[[threads_per_threadgroup]]) {
194
+ const int64_t i3 = tgpig.z;
195
+ const int64_t i2 = tgpig.y;
196
+ const int64_t i1 = tgpig.x;
197
+
198
+ const int64_t i03 = i3 % ne03;
199
+ const int64_t i02 = i2 % ne02;
200
+ const int64_t i01 = i1 % ne01;
201
+
202
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
203
+ device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ;
204
+
205
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
206
+ const int i00 = i0 % ne00;
207
+ *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
208
+ }
209
+ }
210
+
211
+ typedef decltype(kernel_repeat<float>) kernel_repeat_t;
212
+
213
+ template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
214
+ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
215
+ template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
216
+ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
217
+
171
218
  // assumption: src1 is a row
172
219
  // broadcast src1 into src0
173
220
  kernel void kernel_add_row(
@@ -1607,8 +1654,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
1607
1654
  // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
1608
1655
  static void rope_yarn(
1609
1656
  float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
1610
- thread float * cos_theta, thread float * sin_theta
1611
- ) {
1657
+ thread float * cos_theta, thread float * sin_theta) {
1612
1658
  // Get n-d rotational scaling corrected for extrapolation
1613
1659
  float theta_interp = freq_scale * theta_extrap;
1614
1660
  float theta = theta_interp;
@@ -1625,19 +1671,20 @@ static void rope_yarn(
1625
1671
 
1626
1672
  // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
1627
1673
  // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
1628
// Inverse of the rotation-count formula. Solving
// `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x yields
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`.
// Returns the (fractional) dimension index at which a rotary dim completes
// `n_rot` full rotations over the original training context `n_ctx_orig`.
static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
    const float rotations = n_rot * 2 * M_PI_F;
    return n_dims * log(n_ctx_orig / rotations) / (2 * log(base));
}
1631
1677
 
1632
1678
// Computes the start/end correction dimensions for YaRN: the range of rotary
// dims over which the kernel blends between interpolated and extrapolated
// angles. Both bounds are clamped into the valid index range [0, n_dims - 1]
// and written to dims[0] (start) and dims[1] (end).
static void rope_yarn_corr_dims(
    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
) {
    const float start = floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base));
    const float end   = ceil (rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base));

    dims[0] = max(0.0f, start);
    dims[1] = min(n_dims - 1.0f, end);
}
  }
1639
1685
 
1640
- typedef void (rope_t)(
1686
+ template<typename T>
1687
+ kernel void kernel_rope_norm(
1641
1688
  device const void * src0,
1642
1689
  device const int32_t * src1,
1643
1690
  device const float * src2,
@@ -1660,8 +1707,7 @@ typedef void (rope_t)(
1660
1707
  constant uint64_t & nb3,
1661
1708
  constant int & n_past,
1662
1709
  constant int & n_dims,
1663
- constant int & mode,
1664
- constant int & n_orig_ctx,
1710
+ constant int & n_ctx_orig,
1665
1711
  constant float & freq_base,
1666
1712
  constant float & freq_scale,
1667
1713
  constant float & ext_factor,
@@ -1670,10 +1716,52 @@ typedef void (rope_t)(
1670
1716
  constant float & beta_slow,
1671
1717
  uint tiitg[[thread_index_in_threadgroup]],
1672
1718
  uint3 tptg[[threads_per_threadgroup]],
1673
- uint3 tgpig[[threadgroup_position_in_grid]]);
1719
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
1720
+ const int64_t i3 = tgpig[2];
1721
+ const int64_t i2 = tgpig[1];
1722
+ const int64_t i1 = tgpig[0];
1723
+
1724
+ float corr_dims[2];
1725
+ rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
1726
+
1727
+ device const int32_t * pos = src1;
1728
+
1729
+ const float theta_base = (float) pos[i2];
1730
+ const float inv_ndims = -1.f/n_dims;
1731
+
1732
+ float cos_theta;
1733
+ float sin_theta;
1734
+
1735
+ for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
1736
+ if (i0 < n_dims) {
1737
+ const int64_t ic = i0/2;
1738
+
1739
+ const float theta = theta_base * pow(freq_base, inv_ndims*i0);
1740
+
1741
+ const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
1742
+
1743
+ rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1744
+
1745
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1746
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1747
+
1748
+ const float x0 = src[0];
1749
+ const float x1 = src[1];
1750
+
1751
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
1752
+ dst_data[1] = x0*sin_theta + x1*cos_theta;
1753
+ } else {
1754
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1755
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1756
+
1757
+ dst_data[0] = src[0];
1758
+ dst_data[1] = src[1];
1759
+ }
1760
+ }
1761
+ }
1674
1762
 
1675
1763
  template<typename T>
1676
- kernel void kernel_rope(
1764
+ kernel void kernel_rope_neox(
1677
1765
  device const void * src0,
1678
1766
  device const int32_t * src1,
1679
1767
  device const float * src2,
@@ -1696,8 +1784,7 @@ kernel void kernel_rope(
1696
1784
  constant uint64_t & nb3,
1697
1785
  constant int & n_past,
1698
1786
  constant int & n_dims,
1699
- constant int & mode,
1700
- constant int & n_orig_ctx,
1787
+ constant int & n_ctx_orig,
1701
1788
  constant float & freq_base,
1702
1789
  constant float & freq_scale,
1703
1790
  constant float & ext_factor,
@@ -1711,73 +1798,53 @@ kernel void kernel_rope(
1711
1798
  const int64_t i2 = tgpig[1];
1712
1799
  const int64_t i1 = tgpig[0];
1713
1800
 
1714
- const bool is_neox = mode & 2;
1715
-
1716
1801
  float corr_dims[2];
1717
- rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
1802
+ rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
1718
1803
 
1719
1804
  device const int32_t * pos = src1;
1720
1805
 
1721
- const int64_t p = pos[i2];
1722
-
1723
- const float theta_0 = (float)p;
1806
+ const float theta_base = (float) pos[i2];
1724
1807
  const float inv_ndims = -1.f/n_dims;
1725
1808
 
1726
- if (!is_neox) {
1727
- for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
1728
-
1729
- const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
1730
- float cos_theta, sin_theta;
1731
- rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1732
-
1733
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1734
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1735
-
1736
- const T x0 = src[0];
1737
- const T x1 = src[1];
1809
+ float cos_theta;
1810
+ float sin_theta;
1738
1811
 
1739
- dst_data[0] = x0*cos_theta - x1*sin_theta;
1740
- dst_data[1] = x0*sin_theta + x1*cos_theta;
1741
- }
1742
- } else {
1743
- for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
1744
- if (ic < n_dims) {
1745
- const int64_t ib = 0;
1746
-
1747
- // simplified from `(ib * n_dims + ic) * inv_ndims`
1748
- const float cur_rot = inv_ndims*ic - ib;
1749
- const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
1750
-
1751
- const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
1812
+ for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
1813
+ if (i0 < n_dims) {
1814
+ const int64_t ic = i0/2;
1752
1815
 
1753
- float cos_theta, sin_theta;
1754
- rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
1816
+ const float theta = theta_base * pow(freq_base, inv_ndims*i0);
1755
1817
 
1756
- const int64_t i0 = ib*n_dims + ic/2;
1818
+ const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
1757
1819
 
1758
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1759
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1820
+ rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1760
1821
 
1761
- const float x0 = src[0];
1762
- const float x1 = src[n_dims/2];
1822
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
1823
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
1763
1824
 
1764
- dst_data[0] = x0*cos_theta - x1*sin_theta;
1765
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
1766
- } else {
1767
- const int64_t i0 = ic;
1825
+ const float x0 = src[0];
1826
+ const float x1 = src[n_dims/2];
1768
1827
 
1769
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1770
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1828
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
1829
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
1830
+ } else {
1831
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1832
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1771
1833
 
1772
- dst_data[0] = src[0];
1773
- dst_data[1] = src[1];
1774
- }
1834
+ dst_data[0] = src[0];
1835
+ dst_data[1] = src[1];
1775
1836
  }
1776
1837
  }
1777
1838
  }
1778
1839
 
1779
- template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
1780
- template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
1840
+ typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
1841
+ typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
1842
+
1843
+ template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
1844
+ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
1845
+
1846
+ template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
1847
+ template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
1781
1848
 
1782
1849
  typedef void (im2col_t)(
1783
1850
  device const float * x,
@@ -2418,7 +2485,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f
2418
2485
  template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
2419
2486
  template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
2420
2487
  template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
2421
- template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
2488
+ //template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
2422
2489
 
2423
2490
  template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
2424
2491
  kernel void kernel_flash_attn_ext_vec_f16(
@@ -2696,7 +2763,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
2696
2763
  }
2697
2764
 
2698
2765
  template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
2699
- template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
2766
+ //template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
2700
2767
 
2701
2768
  kernel void kernel_cpy_f16_f16(
2702
2769
  device const half * src0,
@@ -3319,31 +3386,30 @@ kernel void kernel_concat(
3319
3386
  constant uint64_t & nb1,
3320
3387
  constant uint64_t & nb2,
3321
3388
  constant uint64_t & nb3,
3389
+ constant int32_t & dim,
3322
3390
  uint3 tgpig[[threadgroup_position_in_grid]],
3323
3391
  uint3 tpitg[[thread_position_in_threadgroup]],
3324
3392
  uint3 ntg[[threads_per_threadgroup]]) {
3325
3393
 
3326
- const int64_t i03 = tgpig.z;
3327
- const int64_t i02 = tgpig.y;
3328
- const int64_t i01 = tgpig.x;
3394
+ const int64_t i3 = tgpig.z;
3395
+ const int64_t i2 = tgpig.y;
3396
+ const int64_t i1 = tgpig.x;
3329
3397
 
3330
- const int64_t i13 = i03 % ne13;
3331
- const int64_t i12 = i02 % ne12;
3332
- const int64_t i11 = i01 % ne11;
3398
+ int64_t o[4] = {0, 0, 0, 0};
3399
+ o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
3333
3400
 
3334
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
3335
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
3336
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
3401
+ device const float * x;
3337
3402
 
3338
3403
  for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
3339
- if (i02 < ne02) {
3340
- ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
3341
- src0_ptr += ntg.x*nb00;
3404
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
3405
+ x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
3342
3406
  } else {
3343
- ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
3344
- src1_ptr += ntg.x*nb10;
3407
+ x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
3345
3408
  }
3346
- dst_ptr += ntg.x*nb0;
3409
+
3410
+ device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
3411
+
3412
+ *y = *x;
3347
3413
  }
3348
3414
  }
3349
3415