@fugood/llama.node 0.3.17 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. package/CMakeLists.txt +3 -1
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +39 -2
  19. package/lib/index.js +132 -1
  20. package/lib/index.ts +203 -3
  21. package/package.json +2 -1
  22. package/src/EmbeddingWorker.cpp +1 -1
  23. package/src/LlamaCompletionWorker.cpp +366 -19
  24. package/src/LlamaCompletionWorker.h +30 -10
  25. package/src/LlamaContext.cpp +213 -5
  26. package/src/LlamaContext.h +12 -0
  27. package/src/common.hpp +15 -0
  28. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
  29. package/src/llama.cpp/.github/workflows/build.yml +41 -762
  30. package/src/llama.cpp/.github/workflows/docker.yml +5 -2
  31. package/src/llama.cpp/.github/workflows/release.yml +716 -0
  32. package/src/llama.cpp/.github/workflows/server.yml +12 -12
  33. package/src/llama.cpp/CMakeLists.txt +5 -17
  34. package/src/llama.cpp/cmake/build-info.cmake +8 -2
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
  36. package/src/llama.cpp/common/CMakeLists.txt +31 -3
  37. package/src/llama.cpp/common/arg.cpp +48 -29
  38. package/src/llama.cpp/common/chat.cpp +128 -106
  39. package/src/llama.cpp/common/chat.h +2 -0
  40. package/src/llama.cpp/common/common.cpp +37 -1
  41. package/src/llama.cpp/common/common.h +18 -9
  42. package/src/llama.cpp/common/llguidance.cpp +1 -0
  43. package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
  44. package/src/llama.cpp/common/minja/minja.hpp +69 -36
  45. package/src/llama.cpp/common/regex-partial.cpp +204 -0
  46. package/src/llama.cpp/common/regex-partial.h +56 -0
  47. package/src/llama.cpp/common/sampling.cpp +57 -50
  48. package/src/llama.cpp/examples/CMakeLists.txt +2 -23
  49. package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
  50. package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
  51. package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
  52. package/src/llama.cpp/examples/training/finetune.cpp +96 -0
  53. package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
  54. package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
  55. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
  56. package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
  57. package/src/llama.cpp/ggml/include/ggml.h +10 -7
  58. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
  60. package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
  61. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
  62. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
  63. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
  64. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
  65. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
  66. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  67. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  68. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  69. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
  70. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
  71. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
  72. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
  73. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
  74. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
  75. package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
  76. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
  77. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
  78. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  79. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
  80. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
  81. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
  82. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
  83. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
  84. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  85. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
  86. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
  87. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  88. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
  89. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
  90. package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
  91. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
  92. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
  93. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
  94. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
  95. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
  96. package/src/llama.cpp/ggml/src/ggml.c +29 -20
  97. package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
  98. package/src/llama.cpp/include/llama.h +52 -11
  99. package/src/llama.cpp/requirements/requirements-all.txt +3 -3
  100. package/src/llama.cpp/scripts/xxd.cmake +1 -1
  101. package/src/llama.cpp/src/CMakeLists.txt +1 -0
  102. package/src/llama.cpp/src/llama-adapter.cpp +6 -0
  103. package/src/llama.cpp/src/llama-arch.cpp +3 -0
  104. package/src/llama.cpp/src/llama-batch.cpp +5 -1
  105. package/src/llama.cpp/src/llama-batch.h +2 -1
  106. package/src/llama.cpp/src/llama-chat.cpp +17 -7
  107. package/src/llama.cpp/src/llama-chat.h +1 -0
  108. package/src/llama.cpp/src/llama-context.cpp +389 -501
  109. package/src/llama.cpp/src/llama-context.h +44 -32
  110. package/src/llama.cpp/src/llama-cparams.h +1 -0
  111. package/src/llama.cpp/src/llama-graph.cpp +20 -38
  112. package/src/llama.cpp/src/llama-graph.h +12 -8
  113. package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
  114. package/src/llama.cpp/src/llama-kv-cache.h +271 -85
  115. package/src/llama.cpp/src/llama-memory.h +11 -1
  116. package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
  117. package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
  118. package/src/llama.cpp/src/llama-model-saver.h +37 -0
  119. package/src/llama.cpp/src/llama-model.cpp +316 -69
  120. package/src/llama.cpp/src/llama-model.h +8 -1
  121. package/src/llama.cpp/src/llama-quant.cpp +15 -13
  122. package/src/llama.cpp/src/llama-sampling.cpp +18 -6
  123. package/src/llama.cpp/src/llama-vocab.cpp +42 -4
  124. package/src/llama.cpp/src/llama-vocab.h +6 -0
  125. package/src/llama.cpp/src/llama.cpp +14 -0
  126. package/src/llama.cpp/tests/CMakeLists.txt +10 -2
  127. package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
  128. package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
  129. package/src/llama.cpp/tests/test-chat.cpp +3 -1
  130. package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
  131. package/src/llama.cpp/tests/test-opt.cpp +33 -21
  132. package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
  133. package/src/llama.cpp/tests/test-sampling.cpp +1 -1
  134. package/src/llama.cpp/tools/CMakeLists.txt +39 -0
  135. package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
  136. package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
  137. package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
  138. package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
  139. package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
  140. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
  141. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
  142. package/src/llama.cpp/tools/mtmd/clip.h +99 -0
  143. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
  144. package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
  145. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
  146. package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
  147. package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
  148. package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
  149. package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
  150. package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
  151. package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
  152. package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
  153. package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
  154. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
  155. package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
  156. package/src/llama.cpp/examples/infill/infill.cpp +0 -590
  157. package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
  158. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
  159. package/src/llama.cpp/examples/llava/clip.h +0 -135
  160. package/src/llama.cpp/examples/llava/llava.cpp +0 -586
  161. package/src/llama.cpp/examples/llava/llava.h +0 -49
  162. package/src/llama.cpp/examples/llava/mtmd.h +0 -168
  163. package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
  164. /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
  165. /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
  166. /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
  167. /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
  168. /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
  169. /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
  170. /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
  171. /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
  172. /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
  173. /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
  174. /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
  175. /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
  176. /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
  177. /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
  178. /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
  179. /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
  180. /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
  181. /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
  182. /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
  183. /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
  184. /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
  185. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
  186. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
  187. /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
  188. /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
  189. /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
  190. /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
  191. /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
  192. /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
  193. /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0
@@ -1054,6 +1054,493 @@ class tinyBLAS_Q0_AVX {
1054
1054
  } \
1055
1055
  } \
1056
1056
 
1057
+ template <typename TA, typename TB, typename TC>
1058
+ class tinyBLAS_BF16_PPC {
1059
+ public:
1060
+ tinyBLAS_BF16_PPC(int64_t k,
1061
+ const TA *A, int64_t lda,
1062
+ const TB *B, int64_t ldb,
1063
+ TC *C, int64_t ldc,
1064
+ int ith, int nth)
1065
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
1066
+ }
1067
+
1068
+ void matmul(int64_t m, int64_t n) {
1069
+ mnpack(0, m, 0, n);
1070
+ }
1071
+
1072
+ private:
1073
+ void vector_permute_store(vec_t *c, int numVec, unsigned char *vecOffset) {
1074
+ vec_t t[8], s[8];
1075
+ vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
1076
+ vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
1077
+ vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
1078
+ vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
1079
+
1080
+ if (numVec == 2) {
1081
+ t[0] = vec_perm(c[0], c[1], swiz1);
1082
+ t[1] = vec_perm(c[2], c[3], swiz1);
1083
+ s[0] = vec_perm(t[0], t[1], swiz3);
1084
+ s[1] = vec_perm(t[0], t[1], swiz4);
1085
+ vec_xst(s[0], 0, (vec_t*)vecOffset);
1086
+ vec_xst(s[1], 0, (vec_t*)(vecOffset + 16));
1087
+ } else if (numVec == 4) {
1088
+ t[0] = vec_perm(c[0], c[1], swiz1);
1089
+ t[1] = vec_perm(c[0], c[1], swiz2);
1090
+ t[2] = vec_perm(c[2], c[3], swiz1);
1091
+ t[3] = vec_perm(c[2], c[3], swiz2);
1092
+ s[0] = vec_perm(t[0], t[2], swiz3);
1093
+ s[1] = vec_perm(t[0], t[2], swiz4);
1094
+ s[2] = vec_perm(t[1], t[3], swiz3);
1095
+ s[3] = vec_perm(t[1], t[3], swiz4);
1096
+ for (int i = 0; i < 4; ++i)
1097
+ vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
1098
+ } else if (numVec == 8) {
1099
+ for (int i = 0; i < 4; i += 2) {
1100
+ t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
1101
+ t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
1102
+ }
1103
+ for (int i = 4; i < 8; i += 2) {
1104
+ t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
1105
+ t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
1106
+ }
1107
+ s[0] = vec_perm(t[0], t[2], swiz3);
1108
+ s[1] = vec_perm(t[0], t[2], swiz4);
1109
+ s[2] = vec_perm(t[1], t[3], swiz3);
1110
+ s[3] = vec_perm(t[1], t[3], swiz4);
1111
+ s[4] = vec_perm(t[4], t[6], swiz3);
1112
+ s[5] = vec_perm(t[4], t[6], swiz4);
1113
+ s[6] = vec_perm(t[5], t[7], swiz3);
1114
+ s[7] = vec_perm(t[5], t[7], swiz4);
1115
+ for (int i = 0; i < 8; ++i)
1116
+ vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
1117
+ }
1118
+ }
1119
+
1120
+ void packNormal(const TA* a, int64_t lda, int rows, int cols, unsigned char* vec) {
1121
+ int64_t i, j;
1122
+ TA *aoffset = NULL;
1123
+ unsigned char *vecOffset = NULL;
1124
+ TA * aoffsets[8];
1125
+ vector unsigned char c_arr[8];
1126
+ aoffset = const_cast<TA*>(a);
1127
+ vecOffset = vec;
1128
+ j = (rows >> 3);
1129
+ if (j > 0) {
1130
+ do {
1131
+ if (cols == 4) {
1132
+ aoffsets[0] = aoffset;
1133
+ for (int it = 1; it < 4; ++it)
1134
+ aoffsets[it] = aoffsets[it-1] + lda;
1135
+ aoffset += 4 * lda;
1136
+ for (int i = 0; i < 4; ++i)
1137
+ c_arr[i] = vec_xl(0, (vector unsigned char*)aoffsets[i]);
1138
+ vector_permute_store(c_arr, 4, vecOffset);
1139
+ for (int i = 0; i<4; i++)
1140
+ aoffsets[i] = aoffsets[i]+lda;
1141
+ vecOffset +=64;
1142
+ }
1143
+ i = (cols >> 3);
1144
+ if (i > 0) {
1145
+ aoffsets[0] = aoffset;
1146
+ for (int it = 1; it < 8; ++it) {
1147
+ aoffsets[it] = aoffsets[it-1] + lda;
1148
+ }
1149
+ aoffset += 8 * lda;
1150
+ do {
1151
+ for (int it = 0; it < 8; ++it)
1152
+ c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
1153
+ vector_permute_store(c_arr, 8, vecOffset);
1154
+ for (int it = 0; it < 8; ++it)
1155
+ aoffsets[it] = aoffsets[it] + 8*lda;
1156
+ vecOffset += 128;
1157
+ i--;
1158
+ } while(i > 0);
1159
+ }
1160
+ j--;
1161
+ } while(j > 0);
1162
+ }
1163
+ if (rows & 4) {
1164
+ aoffsets[0] = aoffset;
1165
+ for (int it = 1; it < 4; ++it)
1166
+ aoffsets[it] = aoffsets[it-1] + lda;
1167
+ aoffset += 4 * lda;
1168
+ if (cols == 4) {
1169
+ for (int it = 0; it < 4; ++it)
1170
+ c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
1171
+ vector_permute_store(c_arr, 2, vecOffset);
1172
+ for (int it = 0; it< 4; it++)
1173
+ aoffsets[it] = aoffsets[it] + lda;
1174
+ vecOffset += 32;
1175
+ }
1176
+ i = (cols >> 3);
1177
+ if (i > 0) {
1178
+ do {
1179
+ for (int it = 0; it < 4; ++it)
1180
+ c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
1181
+ vector_permute_store(c_arr, 4, vecOffset);
1182
+ for (int it = 0; it< 4; it++)
1183
+ aoffsets[it] = aoffsets[it] + 8*lda;
1184
+ vecOffset += 64;
1185
+ i--;
1186
+ } while(i > 0);
1187
+ }
1188
+ }
1189
+ if (rows & 3) {
1190
+ aoffsets[0] = aoffset;
1191
+ for (int it = 1; it < 4; ++it)
1192
+ aoffsets[it] = aoffsets[it-1] + lda;
1193
+ if (cols == 4) {
1194
+ switch(rows) {
1195
+ case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
1196
+ case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
1197
+ case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
1198
+ break;
1199
+ }
1200
+ vector_permute_store(c_arr, 2, vecOffset);
1201
+ for (int it = 0; it< 4; it++)
1202
+ aoffsets[it] = aoffsets[it] + lda;
1203
+ vecOffset += 32;
1204
+ }
1205
+ i = (cols >> 3);
1206
+ if (i > 0) {
1207
+ do {
1208
+ switch(rows) {
1209
+ case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
1210
+ case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
1211
+ case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
1212
+ break;
1213
+ }
1214
+ vector_permute_store(c_arr, 4, vecOffset);
1215
+ for (int it = 0; it <4; it++)
1216
+ aoffsets[it] = aoffsets[it] + 8* lda;
1217
+ vecOffset += 64;
1218
+ i--;
1219
+ } while(i > 0);
1220
+ }
1221
+ }
1222
+ }
1223
+
1224
+ void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1225
+ int64_t mc, nc, mp, np;
1226
+ int m_rem = MIN(m - m0, 8);
1227
+ int n_rem = MIN(n - n0, 8);
1228
+
1229
+ if (m_rem >= 8 && n_rem >= 8) {
1230
+ mc = 8;
1231
+ nc = 8;
1232
+ gemm<8,8>(m0, m, n0, n);
1233
+ } else if (m_rem >= 4 && n_rem >= 8) {
1234
+ mc = 4;
1235
+ nc = 8;
1236
+ gemm<4,8>(m0, m, n0, n);
1237
+ } else if (m_rem >=8 && n_rem >=4){
1238
+ mc = 8;
1239
+ nc = 4;
1240
+ gemm<8,4>(m0, m, n0, n);
1241
+ } else if ((m_rem < 4) && (n_rem >= 8)) {
1242
+ nc = 8;
1243
+ switch(m_rem) {
1244
+ case 1:
1245
+ mc = 1;
1246
+ gemm_Mx8<1>(m0, m, n0, n);
1247
+ break;
1248
+ case 2:
1249
+ mc = 2;
1250
+ gemm_Mx8<2>(m0, m, n0, n);
1251
+ break;
1252
+ case 3:
1253
+ mc = 3;
1254
+ gemm_Mx8<3>(m0, m, n0, n);
1255
+ break;
1256
+ default:
1257
+ return;
1258
+ }
1259
+ } else if (m_rem >= 4 && n_rem >= 4) {
1260
+ mc = 4;
1261
+ nc = 4;
1262
+ gemm_small<4, 4>(m0, m, n0, n);
1263
+ } else if ((m_rem > 4) && (n_rem < 4)) {
1264
+ mc = 4;
1265
+ switch(n_rem) {
1266
+ case 1:
1267
+ nc = 1;
1268
+ gemm_small<4, 1>(m0, m, n0, n);
1269
+ break;
1270
+ case 2:
1271
+ nc = 2;
1272
+ gemm_small<4, 2>(m0, m, n0, n);
1273
+ break;
1274
+ case 3:
1275
+ nc = 3;
1276
+ gemm_small<4, 3>(m0, m, n0, n);
1277
+ break;
1278
+
1279
+ default:
1280
+ return;
1281
+ }
1282
+ } else {
1283
+ switch((m_rem << 4) | n_rem) {
1284
+ case 0x43:
1285
+ mc = 4;
1286
+ nc = 3;
1287
+ gemm_small<4, 3>(m0, m, n0, n);
1288
+ break;
1289
+ case 0x42:
1290
+ mc = 4;
1291
+ nc = 2;
1292
+ gemm_small<4, 2>(m0, m, n0, n);
1293
+ break;
1294
+ case 0x41:
1295
+ mc = 4;
1296
+ nc = 1;
1297
+ gemm_small<4, 1>(m0, m, n0, n);
1298
+ break;
1299
+ case 0x34:
1300
+ mc = 3;
1301
+ nc = 4;
1302
+ gemm_small<3, 4>(m0, m, n0, n);
1303
+ break;
1304
+ case 0x33:
1305
+ mc = 3;
1306
+ nc = 3;
1307
+ gemm_small<3, 3>(m0, m, n0, n);
1308
+ break;
1309
+ case 0x32:
1310
+ mc = 3;
1311
+ nc = 2;
1312
+ gemm_small<3, 2>(m0, m, n0, n);
1313
+ break;
1314
+ case 0x31:
1315
+ mc = 3;
1316
+ nc = 1;
1317
+ gemm_small<3, 1>(m0, m, n0, n);
1318
+ break;
1319
+ case 0x24:
1320
+ mc = 2;
1321
+ nc = 4;
1322
+ gemm_small<2,4>(m0, m, n0, n);
1323
+ break;
1324
+ case 0x23:
1325
+ mc = 2;
1326
+ nc = 3;
1327
+ gemm_small<2, 3>(m0, m, n0, n);
1328
+ break;
1329
+ case 0x22:
1330
+ mc = 2;
1331
+ nc = 2;
1332
+ gemm_small<2, 2>(m0, m, n0, n);
1333
+ break;
1334
+ case 0x21:
1335
+ mc = 2;
1336
+ nc = 1;
1337
+ gemm_small<2, 1>(m0, m, n0, n);
1338
+ break;
1339
+ case 0x14:
1340
+ mc = 1;
1341
+ nc = 4;
1342
+ gemm_small<1, 4>(m0, m, n0, n);
1343
+ break;
1344
+ case 0x13:
1345
+ mc = 1;
1346
+ nc = 3;
1347
+ gemm_small<1, 3>(m0, m, n0, n);
1348
+ break;
1349
+ case 0x12:
1350
+ mc = 1;
1351
+ nc = 2;
1352
+ gemm_small<1, 2>(m0, m, n0, n);
1353
+ break;
1354
+ case 0x11:
1355
+ mc = 1;
1356
+ nc = 1;
1357
+ gemm_small<1, 1>(m0, m, n0, n);
1358
+ break;
1359
+ default:
1360
+ return;
1361
+ }
1362
+ }
1363
+ mp = m0 + (m - m0) / mc * mc;
1364
+ np = n0 + (n - n0) / nc * nc;
1365
+ mnpack(mp, m, n0, np);
1366
+ mnpack(m0, m, np, n);
1367
+ }
1368
+
1369
+ void KERNEL_4x8(int64_t ii, int64_t jj) {
1370
+ vec_t vec_A[4], vec_B[8] , vec_C[4];
1371
+ acc_t acc_0, acc_1;
1372
+ __builtin_mma_xxsetaccz(&acc_0);
1373
+ __builtin_mma_xxsetaccz(&acc_1);
1374
+ for (int l = 0; l < k; l+=8) {
1375
+ packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
1376
+ packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
1377
+ for (int x = 0; x < 4; x++) {
1378
+ __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
1379
+ __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
1380
+ }
1381
+ }
1382
+ SAVE_ACC(&acc_0, ii, jj);
1383
+ SAVE_ACC(&acc_1, ii, jj+4);
1384
+ }
1385
+
1386
+ void KERNEL_8x4(int64_t ii, int64_t jj) {
1387
+ vec_t vec_A[8], vec_B[4] , vec_C[4];
1388
+ acc_t acc_0, acc_1;
1389
+ __builtin_mma_xxsetaccz(&acc_0);
1390
+ __builtin_mma_xxsetaccz(&acc_1);
1391
+ for (int l = 0; l < k; l+=8) {
1392
+ packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
1393
+ packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
1394
+ for (int x = 0; x < 4; x++) {
1395
+ __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
1396
+ __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
1397
+ }
1398
+ }
1399
+ SAVE_ACC(&acc_0, ii, jj);
1400
+ SAVE_ACC(&acc_1, ii+4, jj);
1401
+ }
1402
+
1403
+
1404
+ void KERNEL_8x8(int64_t ii, int64_t jj) {
1405
+ vec_t vec_A[8], vec_B[8], vec_C[4];
1406
+ acc_t acc_0, acc_1, acc_2, acc_3;
1407
+ __builtin_mma_xxsetaccz(&acc_0);
1408
+ __builtin_mma_xxsetaccz(&acc_1);
1409
+ __builtin_mma_xxsetaccz(&acc_2);
1410
+ __builtin_mma_xxsetaccz(&acc_3);
1411
+ for (int l = 0; l < k; l+=8) {
1412
+ packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
1413
+ packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
1414
+ for (int x = 0; x < 4; x++) {
1415
+ __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
1416
+ __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]);
1417
+ __builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]);
1418
+ __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]);
1419
+ }
1420
+ }
1421
+
1422
+ SAVE_ACC(&acc_0, ii, jj);
1423
+ SAVE_ACC(&acc_1, ii, jj+4);
1424
+ SAVE_ACC(&acc_2, ii+4, jj);
1425
+ SAVE_ACC(&acc_3, ii+4, jj+4);
1426
+ }
1427
+
1428
+ template<int RM, int RN>
1429
+ void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1430
+ int64_t ytiles = (m - m0) / RM;
1431
+ int64_t xtiles = (n - n0) / RN;
1432
+ int64_t tiles = xtiles * ytiles;
1433
+ int64_t duty = (tiles + nth - 1) / nth;
1434
+ int64_t start = duty * ith;
1435
+ int64_t end = start + duty;
1436
+ if (end > tiles)
1437
+ end = tiles;
1438
+ for (int64_t job = start; job < end; ++job) {
1439
+ int64_t ii = m0 + job / xtiles * RM;
1440
+ int64_t jj = n0 + job % xtiles * RN;
1441
+ vec_t vec_C[4];
1442
+ acc_t acc_0;
1443
+ __builtin_mma_xxsetaccz(&acc_0);
1444
+ vec_t vec_A[2], vec_B[2];
1445
+ for (int l=0; l<k; l+=4) {
1446
+ packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
1447
+ packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
1448
+ for (int x = 0; x<2; x++) {
1449
+ __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
1450
+ }
1451
+ }
1452
+ __builtin_mma_disassemble_acc(vec_C, &acc_0);
1453
+ for (int I = 0; I < RM; I++) {
1454
+ for (int J = 0; J < RN; J++) {
1455
+ *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
1456
+ }
1457
+ }
1458
+ }
1459
+ }
1460
+
1461
+ template<int RM>
1462
+ void gemm_Mx8(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1463
+ int RN = 8;
1464
+ int64_t ytiles = (m - m0) / RM;
1465
+ int64_t xtiles = (n - n0) / RN;
1466
+ int64_t tiles = xtiles * ytiles;
1467
+ int64_t duty = (tiles + nth - 1) / nth;
1468
+ int64_t start = duty * ith;
1469
+ int64_t end = start + duty;
1470
+ if (end > tiles)
1471
+ end = tiles;
1472
+ for (int64_t job = start; job < end; ++job) {
1473
+ int64_t ii = m0 + job / xtiles * RM;
1474
+ int64_t jj = n0 + job % xtiles * RN;
1475
+ vec_t vec_C[4];
1476
+ acc_t acc_0, acc_1;
1477
+ __builtin_mma_xxsetaccz(&acc_0);
1478
+ __builtin_mma_xxsetaccz(&acc_1);
1479
+ vec_t vec_A[4], vec_B[8];
1480
+ for (int l=0; l<k; l+=8) {
1481
+ packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
1482
+ packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
1483
+ for (int x = 0; x<4; x++) {
1484
+ __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
1485
+ __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
1486
+ }
1487
+ }
1488
+ __builtin_mma_disassemble_acc(vec_C, &acc_0);
1489
+ for (int I = 0; I < RM; I++) {
1490
+ for (int J = 0; J < 4; J++) {
1491
+ *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
1492
+ }
1493
+ }
1494
+ __builtin_mma_disassemble_acc(vec_C, &acc_1);
1495
+ for (int I = 0; I < RM; I++) {
1496
+ for (int J = 0; J < 4; J++) {
1497
+ *((TC*)(C+ii+((jj+4+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
1498
+ }
1499
+ }
1500
+ }
1501
+ }
1502
+
1503
+ template<int RM, int RN>
1504
+ inline void kernel(int64_t ii, int64_t jj) {
1505
+ if constexpr(RM == 4 && RN == 8) {
1506
+ KERNEL_4x8(ii,jj);
1507
+ } else if constexpr(RM == 8 && RN == 8) {
1508
+ KERNEL_8x8(ii,jj);
1509
+ } else if constexpr(RM == 8 && RN == 4) {
1510
+ KERNEL_8x4(ii,jj);
1511
+ } else {
1512
+ static_assert(false, "RN/RM values not supported");
1513
+ }
1514
+ }
1515
+
1516
+ template <int RM, int RN>
1517
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1518
+ int64_t ytiles = (m - m0) / RM;
1519
+ int64_t xtiles = (n - n0) / RN;
1520
+ int64_t tiles = xtiles * ytiles;
1521
+ int64_t duty = (tiles + nth - 1) / nth;
1522
+ int64_t start = duty * ith;
1523
+ int64_t end = start + duty;
1524
+ if (end > tiles)
1525
+ end = tiles;
1526
+ for (int64_t job = start; job < end; ++job) {
1527
+ int64_t ii = m0 + job / xtiles * RM;
1528
+ int64_t jj = n0 + job % xtiles * RN;
1529
+ kernel<RM, RN>(ii, jj);
1530
+ }
1531
+ }
1532
+
1533
+ const TA *const A;
1534
+ const TB *const B;
1535
+ TC *C;
1536
+ const int64_t k;
1537
+ const int64_t lda;
1538
+ const int64_t ldb;
1539
+ const int64_t ldc;
1540
+ const int ith;
1541
+ const int nth;
1542
+ };
1543
+
1057
1544
  template <typename TA, typename TB, typename TC>
1058
1545
  class tinyBLAS_Q0_PPC {
1059
1546
  public:
@@ -2202,6 +2689,7 @@ class tinyBLAS_PPC {
2202
2689
  boffset = vec;
2203
2690
  j = (rows >> 3);
2204
2691
  if (j > 0) {
2692
+
2205
2693
  do {
2206
2694
  aoffset1 = aoffset;
2207
2695
  aoffset2 = aoffset1 + lda;
@@ -2875,9 +3363,22 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
2875
3363
  (float *)C, ldc};
2876
3364
  return tb.matmul(m, n);
2877
3365
  }
3366
+ #elif defined(__MMA__)
3367
+ if ((k % 8))
3368
+ return false;
3369
+ if(Btype == GGML_TYPE_BF16) {
3370
+ tinyBLAS_BF16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
3371
+ (const ggml_bf16_t *)A, lda,
3372
+ (const ggml_bf16_t *)B, ldb,
3373
+ (float *)C, ldc,
3374
+ params->ith, params->nth};
3375
+ tb.matmul(m, n);
3376
+ return true;
3377
+ }
2878
3378
  #endif
2879
3379
  return false;
2880
3380
  }
3381
+
2881
3382
  case GGML_TYPE_F16: {
2882
3383
  #if defined(__AVX512F__)
2883
3384
  if (Btype == GGML_TYPE_F16) {
@@ -8,19 +8,6 @@
8
8
 
9
9
  #include <float.h>
10
10
 
11
- #if defined(_MSC_VER)
12
- // disable "possible loss of data" to avoid hundreds of casts
13
- // we should just be careful :)
14
- #pragma warning(disable: 4244 4267)
15
-
16
- // disable POSIX deprecation warnings
17
- // these functions are never going away, anyway
18
- #pragma warning(disable: 4996)
19
-
20
- // unreachable code because of multiple instances of code after GGML_ABORT
21
- #pragma warning(disable: 4702)
22
- #endif
23
-
24
11
  // ggml_compute_forward_dup
25
12
 
26
13
  static void ggml_compute_forward_dup_same_cont(
@@ -2,12 +2,6 @@
2
2
 
3
3
  #include <cassert>
4
4
 
5
- #if defined(_MSC_VER)
6
- // disable "possible loss of data" to avoid hundreds of casts
7
- // we should just be careful :)
8
- #pragma warning(disable: 4244 4267)
9
- #endif
10
-
11
5
  // precomputed gelu table for f16 (128 KB)
12
6
  ggml_fp16_t ggml_table_gelu_f16[1 << 16];
13
7
 
@@ -12,12 +12,30 @@ if (CUDAToolkit_FOUND)
12
12
  # 61 == Pascal, __dp4a instruction (per-byte integer dot product)
13
13
  # 70 == V100, FP16 tensor cores
14
14
  # 75 == Turing, int8 tensor cores
15
+ # 80 == Ampere, asynchronous data loading, faster tensor core instructions
16
+ # 86 == RTX 3000, needs CUDA v11.1
17
+ # 89 == RTX 4000, needs CUDA v11.8
18
+ #
19
+ # XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run
20
+ # XX-real == compile CUDA code as device code for this specific architecture
21
+ # no suffix == compile as both PTX and device code
22
+ #
23
+ # The default behavior for a non-native is to build virtual architectures as needed to cover all features needed
24
+ # for best performance and to also build real architectures for the most commonly used GPUs.
15
25
  if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
16
26
  set(CMAKE_CUDA_ARCHITECTURES "native")
17
27
  elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
18
- set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;80")
28
+ if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
29
+ set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real")
30
+ else()
31
+ set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real")
32
+ endif()
19
33
  else()
20
- set(CMAKE_CUDA_ARCHITECTURES "50;61;70;75;80")
34
+ if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
35
+ set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real")
36
+ else()
37
+ set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real")
38
+ endif()
21
39
  endif()
22
40
  endif()
23
41
  message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
@@ -100,7 +118,7 @@ if (CUDAToolkit_FOUND)
100
118
 
101
119
  set(CUDA_CXX_FLAGS "")
102
120
 
103
- set(CUDA_FLAGS -use_fast_math)
121
+ set(CUDA_FLAGS -use_fast_math -extended-lambda)
104
122
 
105
123
  if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
106
124
  # Options are:
@@ -133,6 +151,7 @@ if (CUDAToolkit_FOUND)
133
151
  COMMAND ${NVCC_CMD} -Xcompiler "-dumpfullversion -dumpversion"
134
152
  OUTPUT_VARIABLE CUDA_CCVER
135
153
  ERROR_QUIET
154
+ OUTPUT_STRIP_TRAILING_WHITESPACE
136
155
  )
137
156
  else()
138
157
  if (CUDA_CCFULLVER MATCHES Apple)
@@ -143,7 +162,7 @@ if (CUDAToolkit_FOUND)
143
162
  string(REGEX REPLACE "^.* version ([0-9.]*).*$" "\\1" CUDA_CCVER ${CUDA_CCFULLVER})
144
163
  endif()
145
164
 
146
- message("-- CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}")
165
+ message(STATUS "CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}")
147
166
 
148
167
  ggml_get_flags(${CUDA_CCID} ${CUDA_CCVER})
149
168
  list(APPEND CUDA_CXX_FLAGS ${CXX_FLAGS} ${GF_CXX_FLAGS}) # This is passed to -Xcompiler later
@@ -207,6 +207,10 @@ typedef struct {
207
207
  float attn_factor;
208
208
  float beta_fast;
209
209
  float beta_slow;
210
+ int32_t sect_0;
211
+ int32_t sect_1;
212
+ int32_t sect_2;
213
+ int32_t sect_3;
210
214
  } ggml_metal_kargs_rope;
211
215
 
212
216
  typedef struct {
@@ -299,21 +303,42 @@ typedef struct {
299
303
  } ggml_metal_kargs_mul_mv_ext;
300
304
 
301
305
  typedef struct {
302
- int32_t nei0;
303
- int32_t nei1;
304
- uint64_t nbi1;
306
+ int32_t ne10;
307
+ int32_t ne11; // n_expert_used (bcast)
308
+ uint64_t nb11;
309
+ uint64_t nb12;
310
+ int32_t neh11; // n_tokens
311
+ uint64_t nbh11;
312
+ int32_t ne20; // n_expert_used
313
+ uint64_t nb21;
314
+ } ggml_metal_kargs_mul_mm_id_map0;
315
+
316
// Kernel-argument block for the "map1" stage of the Metal mul_mm_id path.
// NOTE(review): this header appears to be shared between host code and the
// Metal shaders, so the field order and types are presumably a binary layout
// contract. Do not reorder or resize fields without updating the kernel;
// TODO confirm.
// Field names presumably follow the ggml convention (ne* = element counts,
// nb* = byte strides); the "h" infix appears to refer to an intermediate
// buffer. Verify against the shader.
typedef struct {
    int32_t  ne20; // n_expert_used
    int32_t  neh0;
    int32_t  neh1;
    uint64_t nbh1;
    uint64_t nbh2;
    int32_t  ne0;
    uint64_t nb1;
    uint64_t nb2;
} ggml_metal_kargs_mul_mm_id_map1;
326
+
327
+ typedef struct {
305
328
  int32_t ne00;
306
329
  int32_t ne02;
307
330
  uint64_t nb01;
308
331
  uint64_t nb02;
309
- int32_t ne11;
310
- int32_t ne12;
311
- int32_t ne13;
312
- uint64_t nb10;
313
- uint64_t nb11;
314
- uint64_t nb12;
315
- int32_t ne0;
316
- int32_t ne1;
332
+ uint64_t nb03;
333
+ int32_t neh12;
334
+ uint64_t nbh10;
335
+ uint64_t nbh11;
336
+ uint64_t nbh12;
337
+ uint64_t nbh13;
338
+ int32_t neh0;
339
+ int32_t neh1;
340
+ int16_t r2;
341
+ int16_t r3;
317
342
  } ggml_metal_kargs_mul_mm_id;
318
343
 
319
344
  typedef struct {
@@ -4855,8 +4855,6 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
4855
4855
  if (!any_on_device) {
4856
4856
  return false;
4857
4857
  }
4858
- GGML_ASSERT(ggml_is_contiguous(src0));
4859
- GGML_ASSERT(ggml_is_contiguous(src1));
4860
4858
  func = ggml_cl_add;
4861
4859
  break;
4862
4860
  case GGML_OP_MUL: