cui-llama.rn 1.6.0 → 1.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. package/README.md +35 -7
  2. package/android/src/main/CMakeLists.txt +16 -11
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +4 -1
  4. package/android/src/main/jni.cpp +20 -4
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  13. package/cpp/LICENSE +21 -0
  14. package/cpp/chat.cpp +1 -1
  15. package/cpp/common.cpp +17 -2
  16. package/cpp/common.h +7 -3
  17. package/cpp/ggml-alloc.c +4 -1
  18. package/cpp/ggml-cpp.h +1 -1
  19. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  20. package/cpp/ggml-cpu/amx/amx.h +8 -0
  21. package/cpp/ggml-cpu/amx/common.h +91 -0
  22. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  23. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  24. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/binary-ops.h +1 -1
  25. package/cpp/ggml-cpu/common.h +72 -0
  26. package/cpp/{ggml-cpu-aarch64.cpp → ggml-cpu/ggml-cpu-aarch64.cpp} +809 -101
  27. package/cpp/{ggml-cpu.c → ggml-cpu/ggml-cpu.c} +109 -42
  28. package/cpp/{ggml-cpu.cpp → ggml-cpu/ggml-cpu.cpp} +3 -0
  29. package/cpp/{ops.cpp → ggml-cpu/ops.cpp} +246 -160
  30. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/ops.h +2 -20
  31. package/cpp/{sgemm.cpp → ggml-cpu/sgemm.cpp} +501 -0
  32. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/simd-mappings.h +7 -3
  33. package/{ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/unary-ops.h +1 -1
  34. package/cpp/ggml-cpu.h +5 -0
  35. package/cpp/ggml-impl.h +16 -9
  36. package/cpp/ggml-llama-sim.metallib +0 -0
  37. package/cpp/ggml-llama.metallib +0 -0
  38. package/cpp/ggml-metal.m +492 -47
  39. package/cpp/ggml.c +134 -244
  40. package/cpp/ggml.h +61 -94
  41. package/cpp/json-schema-to-grammar.cpp +3 -0
  42. package/cpp/llama-arch.cpp +46 -17
  43. package/cpp/llama-arch.h +9 -0
  44. package/cpp/llama-batch.cpp +5 -1
  45. package/cpp/llama-batch.h +2 -1
  46. package/cpp/llama-chat.cpp +31 -10
  47. package/cpp/llama-chat.h +3 -2
  48. package/cpp/llama-context.cpp +104 -489
  49. package/cpp/llama-context.h +14 -30
  50. package/cpp/llama-graph.cpp +69 -62
  51. package/cpp/llama-graph.h +21 -18
  52. package/cpp/llama-hparams.h +5 -0
  53. package/cpp/llama-kv-cache.cpp +1497 -391
  54. package/cpp/llama-kv-cache.h +272 -80
  55. package/cpp/llama-memory.h +11 -1
  56. package/cpp/llama-model.cpp +502 -176
  57. package/cpp/llama-model.h +13 -3
  58. package/cpp/llama-sampling.cpp +2 -1
  59. package/cpp/llama-vocab.cpp +8 -1
  60. package/cpp/llama.h +14 -11
  61. package/cpp/rn-llama.cpp +20 -172
  62. package/cpp/rn-llama.h +1 -5
  63. package/ios/CMakeLists.txt +13 -10
  64. package/ios/RNLlama.h +6 -0
  65. package/ios/RNLlama.mm +5 -0
  66. package/ios/RNLlamaContext.mm +26 -28
  67. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +7 -3
  68. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  69. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  70. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  71. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +61 -94
  72. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  73. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  74. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +3 -2
  75. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +14 -30
  76. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +21 -18
  77. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +5 -0
  78. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  79. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +11 -1
  80. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +13 -3
  81. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +14 -11
  82. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +1 -5
  83. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  84. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  85. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +7 -3
  86. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  87. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  88. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  89. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +61 -94
  90. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  91. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  92. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +3 -2
  93. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +14 -30
  94. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +21 -18
  95. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +5 -0
  96. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  97. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +11 -1
  98. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +13 -3
  99. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +14 -11
  100. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +1 -5
  101. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  102. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  103. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +7 -3
  104. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  105. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  106. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  107. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +61 -94
  108. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  109. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  110. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +3 -2
  111. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +14 -30
  112. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +21 -18
  113. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +5 -0
  114. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  115. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +11 -1
  116. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +13 -3
  117. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +14 -11
  118. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +1 -5
  119. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  120. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  121. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +7 -3
  122. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  123. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  124. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  125. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +61 -94
  126. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  127. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  128. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +3 -2
  129. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +14 -30
  130. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +21 -18
  131. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +5 -0
  132. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  133. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +11 -1
  134. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +13 -3
  135. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +14 -11
  136. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +1 -5
  137. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  138. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  139. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  140. package/lib/module/NativeRNLlama.js.map +1 -1
  141. package/lib/typescript/NativeRNLlama.d.ts +4 -0
  142. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  143. package/package.json +1 -1
  144. package/src/NativeRNLlama.ts +5 -0
  145. package/cpp/binary-ops.h +0 -16
  146. package/cpp/ops.h +0 -128
  147. package/cpp/simd-mappings.h +0 -888
  148. package/cpp/unary-ops.h +0 -28
  149. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  150. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  151. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  152. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  153. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  154. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +0 -128
  155. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  156. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  157. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +0 -802
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  162. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  163. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  164. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  165. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  166. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  167. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  168. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  169. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  170. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +0 -128
  171. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  172. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  173. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +0 -28
  174. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/vec.h +0 -802
  175. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +0 -16
  176. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  177. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  178. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  179. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  180. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +0 -128
  181. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  182. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +0 -888
  183. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  184. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  185. /package/cpp/{binary-ops.cpp → ggml-cpu/binary-ops.cpp} +0 -0
  186. /package/cpp/{ggml-cpu-aarch64.h → ggml-cpu/ggml-cpu-aarch64.h} +0 -0
  187. /package/cpp/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -0
  188. /package/cpp/{ggml-cpu-quants.c → ggml-cpu/ggml-cpu-quants.c} +0 -0
  189. /package/cpp/{ggml-cpu-quants.h → ggml-cpu/ggml-cpu-quants.h} +0 -0
  190. /package/cpp/{ggml-cpu-traits.cpp → ggml-cpu/ggml-cpu-traits.cpp} +0 -0
  191. /package/cpp/{ggml-cpu-traits.h → ggml-cpu/ggml-cpu-traits.h} +0 -0
  192. /package/cpp/{sgemm.h → ggml-cpu/sgemm.h} +0 -0
  193. /package/cpp/{unary-ops.cpp → ggml-cpu/unary-ops.cpp} +0 -0
  194. /package/cpp/{vec.cpp → ggml-cpu/vec.cpp} +0 -0
  195. /package/cpp/{vec.h → ggml-cpu/vec.h} +0 -0
@@ -215,7 +215,7 @@ static const struct lm_ggml_type_traits_cpu type_traits_cpu[LM_GGML_TYPE_COUNT]
  .nrows = 1,
  },
  [LM_GGML_TYPE_F16] = {
- .from_float = (lm_ggml_from_float_t) lm_ggml_fp32_to_fp16_row,
+ .from_float = (lm_ggml_from_float_t) lm_ggml_cpu_fp32_to_fp16,
  .vec_dot = (lm_ggml_vec_dot_t) lm_ggml_vec_dot_f16,
  .vec_dot_type = LM_GGML_TYPE_F16,
  .nrows = 1,
@@ -356,7 +356,7 @@ static const struct lm_ggml_type_traits_cpu type_traits_cpu[LM_GGML_TYPE_COUNT]
  .from_float = quantize_row_q8_K,
  },
  [LM_GGML_TYPE_BF16] = {
- .from_float = (lm_ggml_from_float_t) lm_ggml_fp32_to_bf16_row,
+ .from_float = (lm_ggml_from_float_t) lm_ggml_cpu_fp32_to_bf16,
  .vec_dot = (lm_ggml_vec_dot_t) lm_ggml_vec_dot_bf16,
  .vec_dot_type = LM_GGML_TYPE_BF16,
  .nrows = 1,
@@ -1932,6 +1932,10 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru
  {
  lm_ggml_compute_forward_im2col_back_f32(params, tensor);
  } break;
+ case LM_GGML_OP_CONV_2D_DW:
+ {
+ lm_ggml_compute_forward_conv_2d_dw(params, tensor);
+ } break;
  case LM_GGML_OP_CONV_TRANSPOSE_2D:
  {
  lm_ggml_compute_forward_conv_transpose_2d(params, tensor);
@@ -2027,41 +2031,6 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru
  {
  lm_ggml_compute_forward_rwkv_wkv7(params, tensor);
  } break;
- case LM_GGML_OP_MAP_UNARY:
- {
- lm_ggml_unary_op_f32_t fun;
- memcpy(&fun, tensor->op_params, sizeof(fun));
- lm_ggml_compute_forward_map_unary(params, tensor, fun);
- }
- break;
- case LM_GGML_OP_MAP_BINARY:
- {
- lm_ggml_binary_op_f32_t fun;
- memcpy(&fun, tensor->op_params, sizeof(fun));
- lm_ggml_compute_forward_map_binary(params, tensor, fun);
- }
- break;
- case LM_GGML_OP_MAP_CUSTOM1_F32:
- {
- lm_ggml_custom1_op_f32_t fun;
- memcpy(&fun, tensor->op_params, sizeof(fun));
- lm_ggml_compute_forward_map_custom1_f32(params, tensor, fun);
- }
- break;
- case LM_GGML_OP_MAP_CUSTOM2_F32:
- {
- lm_ggml_custom2_op_f32_t fun;
- memcpy(&fun, tensor->op_params, sizeof(fun));
- lm_ggml_compute_forward_map_custom2_f32(params, tensor, fun);
- }
- break;
- case LM_GGML_OP_MAP_CUSTOM3_F32:
- {
- lm_ggml_custom3_op_f32_t fun;
- memcpy(&fun, tensor->op_params, sizeof(fun));
- lm_ggml_compute_forward_map_custom3_f32(params, tensor, fun);
- }
- break;
  case LM_GGML_OP_MAP_CUSTOM1:
  {
  lm_ggml_compute_forward_map_custom1(params, tensor);
@@ -2077,6 +2046,11 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru
  lm_ggml_compute_forward_map_custom3(params, tensor);
  }
  break;
+ case LM_GGML_OP_CUSTOM:
+ {
+ lm_ggml_compute_forward_custom(params, tensor);
+ }
+ break;
  case LM_GGML_OP_CROSS_ENTROPY_LOSS:
  {
  lm_ggml_compute_forward_cross_entropy_loss(params, tensor);
@@ -2298,6 +2272,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) {
  } break;
  case LM_GGML_OP_IM2COL:
  case LM_GGML_OP_IM2COL_BACK:
+ case LM_GGML_OP_CONV_2D_DW:
  case LM_GGML_OP_CONV_TRANSPOSE_1D:
  case LM_GGML_OP_CONV_TRANSPOSE_2D:
  {
@@ -2328,11 +2303,6 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) {
  case LM_GGML_OP_WIN_PART:
  case LM_GGML_OP_WIN_UNPART:
  case LM_GGML_OP_GET_REL_POS:
- case LM_GGML_OP_MAP_UNARY:
- case LM_GGML_OP_MAP_BINARY:
- case LM_GGML_OP_MAP_CUSTOM1_F32:
- case LM_GGML_OP_MAP_CUSTOM2_F32:
- case LM_GGML_OP_MAP_CUSTOM3_F32:
  {
  n_tasks = 1;
  } break;
@@ -2366,6 +2336,16 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) {
  n_tasks = MIN(p.n_tasks, n_threads);
  }
  } break;
+ case LM_GGML_OP_CUSTOM:
+ {
+ struct lm_ggml_custom_op_params p;
+ memcpy(&p, node->op_params, sizeof(p));
+ if (p.n_tasks == LM_GGML_N_TASKS_MAX) {
+ n_tasks = n_threads;
+ } else {
+ n_tasks = MIN(p.n_tasks, n_threads);
+ }
+ } break;
  case LM_GGML_OP_CROSS_ENTROPY_LOSS:
  case LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK:
  case LM_GGML_OP_OPT_STEP_ADAMW:
@@ -3186,6 +3166,93 @@ enum lm_ggml_status lm_ggml_graph_compute_with_ctx(struct lm_ggml_context * ctx,
  return lm_ggml_graph_compute(cgraph, &cplan);
  }

+ void lm_ggml_cpu_fp32_to_fp16(const float * x, lm_ggml_fp16_t * y, int64_t n) {
+ int64_t i = 0;
+ #if defined(__F16C__)
+ #if defined(__AVX512F__)
+ for (; i + 15 < n; i += 16) {
+ __m512 x_vec = _mm512_loadu_ps(x + i);
+ __m256i y_vec = _mm512_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
+ _mm256_storeu_si256((__m256i *)(y + i), y_vec);
+ }
+ #endif
+ for (; i + 7 < n; i += 8) {
+ __m256 x_vec = _mm256_loadu_ps(x + i);
+ __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
+ _mm_storeu_si128((__m128i *)(y + i), y_vec);
+ }
+ for (; i + 3 < n; i += 4) {
+ __m128 x_vec = _mm_loadu_ps(x + i);
+ __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
+ _mm_storel_epi64((__m128i *)(y + i), y_vec);
+ }
+ #endif
+ for (; i < n; ++i) {
+ y[i] = LM_GGML_FP32_TO_FP16(x[i]);
+ }
+ }
+
+ void lm_ggml_cpu_fp16_to_fp32(const lm_ggml_fp16_t * x, float * y, int64_t n) {
+ int64_t i = 0;
+ #if defined(__F16C__)
+ #if defined(__AVX512F__)
+ for (; i + 15 < n; i += 16) {
+ __m256i x_vec = _mm256_loadu_si256((const __m256i *)(x + i));
+ __m512 y_vec = _mm512_cvtph_ps(x_vec);
+ _mm512_storeu_ps(y + i, y_vec);
+ }
+ #endif
+ for (; i + 7 < n; i += 8) {
+ __m128i x_vec = _mm_loadu_si128((const __m128i *)(x + i));
+ __m256 y_vec = _mm256_cvtph_ps(x_vec);
+ _mm256_storeu_ps(y + i, y_vec);
+ }
+ for (; i + 3 < n; i += 4) {
+ __m128i x_vec = _mm_loadl_epi64((const __m128i *)(x + i));
+ __m128 y_vec = _mm_cvtph_ps(x_vec);
+ _mm_storeu_ps(y + i, y_vec);
+ }
+ #endif
+ for (; i < n; ++i) {
+ y[i] = LM_GGML_FP16_TO_FP32(x[i]);
+ }
+ }
+
+ void lm_ggml_cpu_fp32_to_bf16(const float * x, lm_ggml_bf16_t * y, int64_t n) {
+ int64_t i = 0;
+ for (; i < n; ++i) {
+ y[i] = LM_GGML_FP32_TO_BF16(x[i]);
+ }
+ }
+
+ void lm_ggml_cpu_bf16_to_fp32(const lm_ggml_bf16_t * x, float * y, int64_t n) {
+ int64_t i = 0;
+ #if defined(__AVX2__)
+ #if defined(__AVX512F__)
+ for (; i + 15 < n; i += 16) {
+ _mm512_storeu_ps(y + i,
+ _mm512_castsi512_ps(
+ _mm512_slli_epi32(
+ _mm512_cvtepu16_epi32(
+ _mm256_loadu_si256(
+ (const __m256i *)(x + i))),
+ 16)));
+ }
+ #endif
+ for (; i + 7 < n; i += 8) {
+ _mm256_storeu_ps(y + i,
+ _mm256_castsi256_ps(
+ _mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(
+ _mm_loadu_si128(
+ (const __m128i *)(x + i))),
+ 16)));
+ }
+ #endif
+ for (; i < n; i++) {
+ y[i] = LM_GGML_BF16_TO_FP32(x[i]);
+ }
+ }

  int lm_ggml_cpu_has_avx(void) {
  #if defined(__AVX__)
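For orientation, the four lm_ggml_cpu_* helpers added in this hunk supersede the per-row converters (lm_ggml_fp32_to_fp16_row and friends) in the CPU backend's type traits, and they take an explicit element count. A minimal round-trip sketch, assuming the declarations are visible through ggml-cpu.h (that header only appears above as +5 -0, so the include is an assumption):

    // Sketch only: convert a small FP32 buffer to FP16 and back with the new helpers.
    #include "ggml-cpu.h"   // assumed to declare lm_ggml_cpu_fp32_to_fp16 / lm_ggml_cpu_fp16_to_fp32
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<float>          src = {0.5f, -1.25f, 3.0f, 100.0f};
        std::vector<lm_ggml_fp16_t> half(src.size());
        std::vector<float>          back(src.size());

        lm_ggml_cpu_fp32_to_fp16(src.data(), half.data(), (int64_t) src.size());
        lm_ggml_cpu_fp16_to_fp32(half.data(), back.data(), (int64_t) back.size());

        for (size_t i = 0; i < src.size(); ++i) {
            std::printf("%g -> %g\n", src[i], back[i]);   // values round-trip within FP16 precision
        }
        return 0;
    }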
@@ -4,6 +4,7 @@
  #include "ggml-cpu-aarch64.h"
  #include "ggml-cpu-traits.h"
  #include "ggml-impl.h"
+ #include "amx/amx.h"

  #include <cctype>
  #include <string>
@@ -424,6 +425,8 @@ static bool lm_ggml_backend_cpu_device_supports_op(lm_ggml_backend_dev_t dev, co
  }
  case LM_GGML_OP_IM2COL_BACK:
  return src0->type == LM_GGML_TYPE_F32 && src1->type == LM_GGML_TYPE_F32;
+ case LM_GGML_OP_GET_ROWS_BACK:
+ return src0->type == LM_GGML_TYPE_F32 || src0->type == LM_GGML_TYPE_F16;
  case LM_GGML_OP_OUT_PROD:
  return (src0->type == LM_GGML_TYPE_F32 || (lm_ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) &&
  src1->type == LM_GGML_TYPE_F32 && op->type == LM_GGML_TYPE_F32;
@@ -4222,7 +4222,7 @@ static void lm_ggml_compute_forward_get_rows_f16(

  LM_GGML_ASSERT(i01 >= 0 && i01 < ne01);

- lm_ggml_fp16_to_fp32_row(
+ lm_ggml_cpu_fp16_to_fp32(
  (const lm_ggml_fp16_t*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
  (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
  }
@@ -4263,7 +4263,7 @@ static void lm_ggml_compute_forward_get_rows_bf16(

  LM_GGML_ASSERT(i01 >= 0 && i01 < ne01);

- lm_ggml_bf16_to_fp32_row(
+ lm_ggml_cpu_bf16_to_fp32(
  (const lm_ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
  (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
  }
@@ -6064,6 +6064,178 @@ void lm_ggml_compute_forward_conv_transpose_2d(
  }
  }

+ // lm_ggml_compute_forward_conv_2d_dw
+
+ struct lm_ggml_conv_2d_dw_params {
+ int64_t channels;
+ int64_t batch;
+ int64_t src_w;
+ int64_t src_h;
+ int64_t dst_w;
+ int64_t dst_h;
+ int64_t knl_w;
+ int64_t knl_h;
+ int stride_x;
+ int stride_y;
+ int pad_x;
+ int pad_y;
+ int dilation_x;
+ int dilation_y;
+ };
+
+ static void lm_ggml_compute_forward_conv_2d_dw_cwhn(
+ const lm_ggml_compute_params * params,
+ const lm_ggml_tensor * src,
+ const lm_ggml_tensor * kernel,
+ lm_ggml_tensor * dst,
+ const lm_ggml_conv_2d_dw_params & p) {
+
+ const int64_t c = p.channels;
+ const float * knl_data = (const float *)kernel->data;
+
+ const int64_t rows_total = p.dst_h * p.batch;
+ const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
+ const int64_t row_start = params->ith * rows_per_thread;
+ const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
+
+ #ifdef LM_GGML_SIMD
+ const int64_t pkg_size = LM_GGML_F32_EPR;
+ const int64_t pkg_count = c / pkg_size;
+ const int64_t c_pkg_end = pkg_count * pkg_size;
+ #else
+ const int64_t c_pkg_end = 0;
+ #endif
+
+ for (int64_t row = row_start; row < row_end; ++row) {
+ const int64_t dst_y = row % p.dst_h;
+ const float * src_data = (const float *)src->data + (row / p.dst_h) * p.src_w * p.src_h * c;
+ for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
+ float * dst_data = (float *)dst->data + (row * p.dst_w + dst_x) * c;
+ const int64_t src_y_base = dst_y * p.stride_y - p.pad_y;
+ const int64_t src_x_base = dst_x * p.stride_x - p.pad_x;
+
+ #ifdef LM_GGML_SIMD
+ // Vectorized loop
+ for (int64_t c_i = 0; c_i < c_pkg_end; c_i += pkg_size) {
+ LM_GGML_F32_VEC sum = LM_GGML_F32_VEC_ZERO;
+ for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+ const int64_t src_y = src_y_base + knl_y * p.dilation_y;
+ if (src_y < 0 || src_y >= p.src_h) {
+ continue;
+ }
+ for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+ const int64_t src_x = src_x_base + knl_x * p.dilation_x;
+ if (src_x < 0 || src_x >= p.src_w) {
+ continue;
+ }
+ LM_GGML_F32_VEC k = LM_GGML_F32_VEC_LOAD(knl_data + (knl_y * p.knl_w + knl_x) * c + c_i);
+ LM_GGML_F32_VEC s = LM_GGML_F32_VEC_LOAD(src_data + (src_y * p.src_w + src_x) * c + c_i);
+ sum = LM_GGML_F32_VEC_FMA(sum, k, s);
+ }
+ }
+ LM_GGML_F32_VEC_STORE(dst_data + c_i, sum);
+ }
+ #endif
+ // Scalar loop
+ for (int64_t c_i = c_pkg_end; c_i < c; ++c_i) {
+ float sum = 0.0f;
+ for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+ const int64_t src_y = src_y_base + knl_y * p.dilation_y;
+ if (src_y < 0 || src_y >= p.src_h) {
+ continue;
+ }
+ for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+ const int64_t src_x = src_x_base + knl_x * p.dilation_x;
+ if (src_x < 0 || src_x >= p.src_w) {
+ continue;
+ }
+ sum += knl_data[(knl_y * p.knl_w + knl_x) * c + c_i]
+ * src_data[(src_y * p.src_w + src_x) * c + c_i];
+ }
+ }
+ dst_data[c_i] = sum;
+ }
+ }
+ }
+ }
+
+ static void lm_ggml_compute_forward_conv_2d_dw_whcn(
+ const lm_ggml_compute_params * params,
+ const lm_ggml_tensor * src,
+ const lm_ggml_tensor * kernel,
+ lm_ggml_tensor * dst,
+ const lm_ggml_conv_2d_dw_params & p) {
+
+ const int64_t n = p.channels * p.batch;
+ const int64_t per_thread = (n + params->nth - 1) / params->nth;
+ const int64_t start = params->ith * per_thread;
+ const int64_t end = MIN(start + per_thread, n);
+
+ for (int64_t i = start; i < end; ++i) {
+ const float * knl_data = (const float *)kernel->data + (i % p.channels) * p.knl_w * p.knl_h;
+ const float * src_data = (const float *)src->data + i * p.src_w * p.src_h;
+ float * dst_data = (float *)dst->data + i * p.dst_w * p.dst_h;
+
+ for (int64_t dst_y = 0; dst_y < p.dst_h; ++dst_y) {
+ for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
+
+ float sum = 0.0f;
+ for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+ const int64_t src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
+ if (src_y < 0 || src_y >= p.src_h) {
+ continue;
+ }
+ for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+ const int64_t src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
+ if (src_x < 0 || src_x >= p.src_w) {
+ continue;
+ }
+ sum += knl_data[knl_y * p.knl_w + knl_x]
+ * src_data[src_y * p.src_w + src_x];
+ }
+ }
+ dst_data[dst_y * p.dst_w + dst_x] = sum;
+ }
+ }
+ }
+ }
+
+ void lm_ggml_compute_forward_conv_2d_dw(
+ const lm_ggml_compute_params * params,
+ lm_ggml_tensor * dst) {
+
+ const lm_ggml_tensor * kernel = dst->src[0];
+ const lm_ggml_tensor * src = dst->src[1];
+ lm_ggml_conv_2d_dw_params p;
+ p.channels = src->ne[2];
+ p.batch = src->ne[3];
+ p.src_w = src->ne[0];
+ p.src_h = src->ne[1];
+ p.dst_w = dst->ne[0];
+ p.dst_h = dst->ne[1];
+ p.knl_w = kernel->ne[0];
+ p.knl_h = kernel->ne[1];
+ p.stride_x = dst->op_params[0];
+ p.stride_y = dst->op_params[1];
+ p.pad_x = dst->op_params[2];
+ p.pad_y = dst->op_params[3];
+ p.dilation_x = dst->op_params[4];
+ p.dilation_y = dst->op_params[5];
+
+ LM_GGML_ASSERT(kernel->ne[3] == p.channels);
+ LM_GGML_ASSERT(dst->ne[3] == p.batch);
+
+ if (lm_ggml_is_contiguous(src)) {
+ lm_ggml_compute_forward_conv_2d_dw_whcn(params, src, kernel, dst, p);
+ } else if (lm_ggml_is_contiguous_channels(src)) {
+ // kernel should also have channels most contiguous in memory
+ LM_GGML_ASSERT(kernel->nb[0] >= kernel->nb[2] && kernel->nb[1] >= kernel->nb[0]);
+ lm_ggml_compute_forward_conv_2d_dw_cwhn(params, src, kernel, dst, p);
+ } else {
+ LM_GGML_ABORT("non-contiguous memory layout not supported");
+ }
+ }
+
  // lm_ggml_compute_forward_pool_1d_sk_p0

  static void lm_ggml_compute_forward_pool_1d_sk_p0(
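For reference, lm_ggml_compute_forward_conv_2d_dw above pulls its geometry out of dst->op_params in the order stride_x, stride_y, pad_x, pad_y, dilation_x, dilation_y, and then dispatches to the WHCN or CWHN kernel depending on the source layout. The output size those parameters imply follows the usual convolution arithmetic; a small illustrative helper, not part of the package and with hypothetical names:

    #include <cstdint>

    // Spatial output size for one dimension of the depthwise conv, matching the
    // stride/pad/dilation packed in op_params[0..5].
    static int64_t conv_2d_dw_out_size(int64_t in, int64_t kernel,
                                       int stride, int pad, int dilation) {
        const int64_t eff = (int64_t) dilation * (kernel - 1) + 1;  // dilated kernel extent
        return (in + 2 * (int64_t) pad - eff) / stride + 1;
    }

    // Example: a 64x64 input with a 3x3 kernel, stride 1, pad 1, dilation 1 keeps its size:
    // conv_2d_dw_out_size(64, 3, 1, 1, 1) == 64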
@@ -6351,24 +6523,72 @@ static void lm_ggml_compute_forward_upscale_f32(
  const float sf2 = (float)ne2/src0->ne[2];
  const float sf3 = (float)ne3/src0->ne[3];

- // TODO: optimize
+ const lm_ggml_scale_mode mode = (lm_ggml_scale_mode) lm_ggml_get_op_params_i32(dst, 0);

- for (int64_t i3 = 0; i3 < ne3; i3++) {
- const int64_t i03 = i3 / sf3;
- for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
- const int64_t i02 = i2 / sf2;
- for (int64_t i1 = 0; i1 < ne1; i1++) {
- const int64_t i01 = i1 / sf1;
- for (int64_t i0 = 0; i0 < ne0; i0++) {
- const int64_t i00 = i0 / sf0;
+ if (mode == LM_GGML_SCALE_MODE_NEAREST) {
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
+ const int64_t i03 = i3 / sf3;
+ for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+ const int64_t i02 = i2 / sf2;
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
+ const int64_t i01 = i1 / sf1;
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
+ const int64_t i00 = i0 / sf0;

- const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+ const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);

- *y = *x;
+ *y = *x;
+ }
+ }
+ }
+ }
+ } else if (mode == LM_GGML_SCALE_MODE_BILINEAR) {
+ // setting a pixel offset of 0 would replicate the behavior of pytorch interpolate with align_corners=True
+ const float pixel_offset = 0.5f;
+
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
+ const int64_t i03 = i3 / sf3;
+ for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+ const int64_t i02 = i2 / sf2;
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
+ const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
+ int64_t y0 = (int64_t)floorf(y);
+ int64_t y1 = y0 + 1;
+
+ y0 = std::max(int64_t(0), std::min(y0, ne01 - 1));
+ y1 = std::max(int64_t(0), std::min(y1, ne01 - 1));
+
+ float dy = y - (float)y0;
+ dy = std::max(0.0f, std::min(dy, 1.0f));
+
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
+ const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
+ int64_t x0 = (int64_t)floorf(x);
+ int64_t x1 = x0 + 1;
+
+ x0 = std::max(int64_t(0), std::min(x0, ne00 - 1));
+ x1 = std::max(int64_t(0), std::min(x1, ne00 - 1));
+
+ float dx = x - (float)x0;
+ dx = std::max(0.0f, std::min(dx, 1.0f));
+
+ // fetch the four surrounding pixel values and interpolate
+ const float a = *(const float *)((const char *)src0->data + x0*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
+ const float b = *(const float *)((const char *)src0->data + x1*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
+ const float c = *(const float *)((const char *)src0->data + x0*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
+ const float d = *(const float *)((const char *)src0->data + x1*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
+
+ const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
+
+ float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+ *y_dst = val;
+ }
  }
  }
  }
+ } else {
+ LM_GGML_ABORT("unsupported upscale mode");
  }
  }

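The bilinear branch added above maps each output index i to a source coordinate (i + 0.5) / sf - 0.5, clamps the two neighbouring indices, and blends the four surrounding samples. The same blend in isolation, as a standalone sketch that is not part of the diff:

    // Bilinear blend of the four neighbours a=(x0,y0), b=(x1,y0), c=(x0,y1), d=(x1,y1),
    // with dx, dy the fractional offsets in [0, 1].
    static float bilerp(float a, float b, float c, float d, float dx, float dy) {
        return a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
    }

    // At dx = dy = 0 this returns a exactly; at dx = dy = 0.5 it returns the average of all four samples.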
@@ -8268,152 +8488,6 @@ void lm_ggml_compute_forward_rwkv_wkv7(
  }
  }

- // lm_ggml_compute_forward_map_unary
-
- static void lm_ggml_compute_forward_map_unary_f32(
- const lm_ggml_compute_params * params,
- lm_ggml_tensor * dst,
- const lm_ggml_unary_op_f32_t fun) {
-
- const lm_ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(lm_ggml_is_contiguous_1(src0));
- assert(lm_ggml_is_contiguous_1(dst));
- assert(lm_ggml_are_same_shape(src0, dst));
-
- const int n = lm_ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- for (int i = 0; i < n; i++) {
- fun(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])));
- }
- }
-
- void lm_ggml_compute_forward_map_unary(
- const lm_ggml_compute_params * params,
- lm_ggml_tensor * dst,
- const lm_ggml_unary_op_f32_t fun) {
-
- const lm_ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case LM_GGML_TYPE_F32:
- {
- lm_ggml_compute_forward_map_unary_f32(params, dst, fun);
- } break;
- default:
- {
- LM_GGML_ABORT("fatal error");
- }
- }
- }
-
- // lm_ggml_compute_forward_map_binary
-
- static void lm_ggml_compute_forward_map_binary_f32(
- const lm_ggml_compute_params * params,
- lm_ggml_tensor * dst,
- const lm_ggml_binary_op_f32_t fun) {
-
- const lm_ggml_tensor * src0 = dst->src[0];
- const lm_ggml_tensor * src1 = dst->src[1];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(lm_ggml_is_contiguous_1(src0));
- assert(lm_ggml_is_contiguous_1(src1));
- assert(lm_ggml_is_contiguous_1(dst));
- assert(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst));
-
- const int n = lm_ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- for (int i = 0; i < n; i++) {
- fun(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])),
- (float *) ((char *) src1->data + i*(src1->nb[1])));
- }
- }
-
- void lm_ggml_compute_forward_map_binary(
- const lm_ggml_compute_params * params,
- lm_ggml_tensor * dst,
- const lm_ggml_binary_op_f32_t fun) {
-
- const lm_ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case LM_GGML_TYPE_F32:
- {
- lm_ggml_compute_forward_map_binary_f32(params, dst, fun);
- } break;
- default:
- {
- LM_GGML_ABORT("fatal error");
- }
- }
- }
-
- // lm_ggml_compute_forward_map_custom1
-
- void lm_ggml_compute_forward_map_custom1_f32(
- const lm_ggml_compute_params * params,
- lm_ggml_tensor * dst,
- const lm_ggml_custom1_op_f32_t fun) {
-
- const lm_ggml_tensor * a = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- fun(dst, a);
- }
-
- // lm_ggml_compute_forward_map_custom2
-
- void lm_ggml_compute_forward_map_custom2_f32(
- const lm_ggml_compute_params * params,
- lm_ggml_tensor * dst,
- const lm_ggml_custom2_op_f32_t fun) {
-
- const lm_ggml_tensor * a = dst->src[0];
- const lm_ggml_tensor * b = dst->src[1];
-
- if (params->ith != 0) {
- return;
- }
-
- fun(dst, a, b);
- }
-
- // lm_ggml_compute_forward_map_custom3
-
- void lm_ggml_compute_forward_map_custom3_f32(
- const lm_ggml_compute_params * params,
- lm_ggml_tensor * dst,
- const lm_ggml_custom3_op_f32_t fun) {
-
- const lm_ggml_tensor * a = dst->src[0];
- const lm_ggml_tensor * b = dst->src[1];
- const lm_ggml_tensor * c = dst->src[1];
-
- if (params->ith != 0) {
- return;
- }
-
- fun(dst, a, b, c);
- }
-
  // lm_ggml_compute_forward_map_custom1

  void lm_ggml_compute_forward_map_custom1(
@@ -8459,6 +8533,18 @@ void lm_ggml_compute_forward_map_custom3(
  p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
  }

+ // lm_ggml_compute_forward_custom
+
+ void lm_ggml_compute_forward_custom(
+ const struct lm_ggml_compute_params * params,
+ struct lm_ggml_tensor * dst) {
+
+ struct lm_ggml_custom_op_params p;
+ memcpy(&p, dst->op_params, sizeof(p));
+
+ p.fun(dst, params->ith, params->nth, p.userdata);
+ }
+
  // lm_ggml_compute_forward_cross_entropy_loss

  static void lm_ggml_compute_forward_cross_entropy_loss_f32(
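The new LM_GGML_OP_CUSTOM forward above simply unpacks lm_ggml_custom_op_params from op_params and calls p.fun(dst, ith, nth, userdata). A hedged sketch of a callback with that call shape, splitting rows across threads the way the removed map_* helpers did; the exact callback typedef lives in ggml.h and is not shown in this diff, so the signature and names below are inferred, not authoritative:

    #include "ggml.h"   // assumed to provide lm_ggml_tensor and lm_ggml_nrows

    struct my_custom_state {
        float scale;    // hypothetical extra state handed through userdata
    };

    static void my_custom_op(struct lm_ggml_tensor * dst, int ith, int nth, void * userdata) {
        const my_custom_state * st = (const my_custom_state *) userdata;
        const int64_t nrows = lm_ggml_nrows(dst);
        const int64_t nc    = dst->ne[0];
        // interleave rows across the nth worker threads; assumes contiguous F32 rows
        for (int64_t r = ith; r < nrows; r += nth) {
            float * row = (float *) ((char *) dst->data + r * dst->nb[1]);
            for (int64_t i = 0; i < nc; ++i) {
                row[i] *= st->scale;
            }
        }
    }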