@novastera-oss/llamarn 0.2.7 → 0.2.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +56 -22
  11. package/cpp/build-info.cpp +2 -2
  12. package/cpp/llama.cpp/CMakeLists.txt +1 -1
  13. package/cpp/llama.cpp/common/arg.cpp +7 -0
  14. package/cpp/llama.cpp/common/common.cpp +3 -0
  15. package/cpp/llama.cpp/common/common.h +1 -0
  16. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  17. package/cpp/llama.cpp/convert_hf_to_gguf.py +118 -20
  18. package/cpp/llama.cpp/ggml/CMakeLists.txt +1 -0
  19. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  20. package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
  21. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -0
  22. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  23. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +31 -2
  24. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  25. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  26. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  27. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  28. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  29. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  30. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  31. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  32. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  33. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  34. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  35. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +83 -102
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +192 -67
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +211 -33
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
  48. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +54 -29
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +84 -31
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
  62. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +227 -41
  64. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +362 -182
  65. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  66. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +240 -535
  67. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  68. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  69. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  70. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  71. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  72. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  73. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  74. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  75. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
  76. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  77. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +45 -54
  78. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  79. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  80. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  81. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  82. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  83. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
  84. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  89. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  90. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +57 -1
  91. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  92. package/cpp/llama.cpp/ggml/src/ggml.c +69 -13
  93. package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
  94. package/cpp/llama.cpp/gguf-py/gguf/constants.py +76 -0
  95. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +21 -0
  96. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +64 -0
  97. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
  98. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  99. package/cpp/llama.cpp/include/llama.h +8 -3
  100. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  101. package/cpp/llama.cpp/src/llama-arch.cpp +55 -0
  102. package/cpp/llama.cpp/src/llama-arch.h +18 -0
  103. package/cpp/llama.cpp/src/llama-batch.cpp +570 -359
  104. package/cpp/llama.cpp/src/llama-batch.h +98 -70
  105. package/cpp/llama.cpp/src/llama-chat.cpp +11 -6
  106. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  107. package/cpp/llama.cpp/src/llama-context.h +13 -13
  108. package/cpp/llama.cpp/src/llama-graph.cpp +199 -252
  109. package/cpp/llama.cpp/src/llama-graph.h +44 -32
  110. package/cpp/llama.cpp/src/llama-hparams.cpp +4 -0
  111. package/cpp/llama.cpp/src/llama-hparams.h +8 -0
  112. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +51 -53
  113. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +19 -24
  114. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +110 -104
  115. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +17 -22
  116. package/cpp/llama.cpp/src/llama-kv-cells.h +35 -11
  117. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +66 -67
  118. package/cpp/llama.cpp/src/llama-memory-hybrid.h +16 -21
  119. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +69 -68
  120. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  121. package/cpp/llama.cpp/src/llama-memory.h +18 -22
  122. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  123. package/cpp/llama.cpp/src/llama-model.cpp +1006 -472
  124. package/cpp/llama.cpp/src/llama-model.h +22 -0
  125. package/cpp/llama.cpp/src/llama-quant.cpp +87 -5
  126. package/cpp/llama.cpp/src/llama-vocab.cpp +26 -3
  127. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  128. package/cpp/rn-utils.h +3 -0
  129. package/ios/include/common.h +1 -0
  130. package/ios/include/llama.h +8 -3
  131. package/ios/libs/llama.xcframework/Info.plist +19 -19
  132. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  133. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
  134. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  135. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
  136. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -3
  137. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  138. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  139. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  140. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
  141. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  142. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  143. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  144. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  145. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  146. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  147. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3744
  148. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  149. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
  150. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -3
  151. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  152. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
  153. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -3
  154. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  155. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  156. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
  157. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -3
  158. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  159. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  160. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  161. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
  162. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  163. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
  164. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -3
  165. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  166. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  167. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  168. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
  169. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  170. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  171. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  172. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  173. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  174. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4900
  175. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  176. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
  177. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -3
  178. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  179. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  180. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4871
  181. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3773
  182. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  183. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  184. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  185. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  186. package/package.json +1 -1
@@ -48,22 +48,28 @@ static struct ggml_backend_metal_device_context {
48
48
  int mtl_device_ref_count;
49
49
  id<MTLLibrary> mtl_library;
50
50
 
51
+ NSLock * mtl_lock;
52
+
51
53
  bool has_simdgroup_reduction;
52
54
  bool has_simdgroup_mm;
53
55
  bool has_residency_sets;
54
56
  bool has_bfloat;
55
57
  bool use_bfloat;
56
58
 
59
+ size_t max_size;
60
+
57
61
  char name[128];
58
62
  } g_ggml_ctx_dev_main = {
59
63
  /*.mtl_device =*/ nil,
60
64
  /*.mtl_device_ref_count =*/ 0,
61
65
  /*.mtl_library =*/ nil,
66
+ /*.mtl_lock =*/ nil,
62
67
  /*.has_simdgroup_reduction =*/ false,
63
68
  /*.has_simdgroup_mm =*/ false,
64
69
  /*.has_residency_sets =*/ false,
65
70
  /*.has_bfloat =*/ false,
66
71
  /*.use_bfloat =*/ false,
72
+ /*.max_size =*/ 0,
67
73
  /*.name =*/ "",
68
74
  };
69
75
 
@@ -71,6 +77,10 @@ static struct ggml_backend_metal_device_context {
71
77
  static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
72
78
  assert(ctx != NULL);
73
79
 
80
+ if (ctx->mtl_lock == nil) {
81
+ ctx->mtl_lock = [[NSLock alloc] init];
82
+ }
83
+
74
84
  if (ctx->mtl_device == nil) {
75
85
  ctx->mtl_device = MTLCreateSystemDefaultDevice();
76
86
  }
@@ -94,6 +104,8 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
94
104
  ctx->use_bfloat = false;
95
105
  #endif
96
106
 
107
+ ctx->max_size = ctx->mtl_device.maxBufferLength;
108
+
97
109
  strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
98
110
  }
99
111
 
@@ -110,6 +122,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
110
122
  ctx->mtl_device_ref_count--;
111
123
 
112
124
  if (ctx->mtl_device_ref_count == 0) {
125
+ if (ctx->mtl_lock) {
126
+ [ctx->mtl_lock release];
127
+ ctx->mtl_lock = nil;
128
+ }
129
+
113
130
  if (ctx->mtl_library) {
114
131
  [ctx->mtl_library release];
115
132
  ctx->mtl_library = nil;
@@ -185,6 +202,15 @@ enum ggml_metal_kernel_type {
185
202
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
186
203
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
187
204
  GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
205
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
206
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
207
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
208
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
209
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
210
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
211
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
212
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
213
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
188
214
  GGML_METAL_KERNEL_TYPE_RMS_NORM,
189
215
  GGML_METAL_KERNEL_TYPE_L2_NORM,
190
216
  GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -194,11 +220,14 @@ enum ggml_metal_kernel_type {
194
220
  GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
195
221
  GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
196
222
  GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
223
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4,
197
224
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
225
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4,
198
226
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
199
227
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
200
228
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
201
229
  GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
230
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4,
202
231
  GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
203
232
  GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
204
233
  GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
@@ -977,7 +1006,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
977
1006
  struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
978
1007
  struct ggml_backend_metal_device_context * ctx_dev = dev->context;
979
1008
 
980
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
1009
+ id<MTLDevice> device = ctx_dev->mtl_device;
981
1010
 
982
1011
  GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
983
1012
 
@@ -991,9 +1020,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
991
1020
  ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
992
1021
 
993
1022
  // load library
994
- if (ctx_dev->mtl_library == nil) {
995
- ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
1023
+ {
1024
+ [ctx_dev->mtl_lock lock];
1025
+
1026
+ if (ctx_dev->mtl_library == nil) {
1027
+ ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
1028
+ }
1029
+
1030
+ [ctx_dev->mtl_lock unlock];
996
1031
  }
1032
+
997
1033
  id<MTLLibrary> metal_library = ctx_dev->mtl_library;
998
1034
  if (metal_library == nil) {
999
1035
  GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
@@ -1142,6 +1178,15 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1142
1178
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
1143
1179
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
1144
1180
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
1181
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true);
1182
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true);
1183
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
1184
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true);
1185
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true);
1186
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true);
1187
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true);
1188
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
1189
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
1145
1190
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
1146
1191
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
1147
1192
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
@@ -1151,11 +1196,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1151
1196
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
1152
1197
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
1153
1198
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
1199
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, mul_mv_f32_f32_c4, true);
1154
1200
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
1201
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, use_bfloat);
1155
1202
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
1156
1203
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
1157
1204
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
1158
1205
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
1206
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, mul_mv_f16_f32_c4, true);
1159
1207
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
1160
1208
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
1161
1209
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
@@ -1605,6 +1653,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1605
1653
  const bool use_bfloat = ctx_dev->use_bfloat;
1606
1654
 
1607
1655
  if (!use_bfloat) {
1656
+ if (op->type == GGML_TYPE_BF16) {
1657
+ return false;
1658
+ }
1659
+
1608
1660
  for (size_t i = 0, n = 3; i < n; ++i) {
1609
1661
  if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
1610
1662
  return false;
@@ -1774,6 +1826,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1774
1826
  {
1775
1827
  return op->ne[3] == 1;
1776
1828
  }
1829
+ case GGML_OP_SET_ROWS:
1830
+ {
1831
+ if (op->src[0]->type != GGML_TYPE_F32) {
1832
+ return false;
1833
+ }
1834
+
1835
+ switch (op->type) {
1836
+ case GGML_TYPE_F32:
1837
+ case GGML_TYPE_F16:
1838
+ case GGML_TYPE_BF16:
1839
+ case GGML_TYPE_Q8_0:
1840
+ case GGML_TYPE_Q4_0:
1841
+ case GGML_TYPE_Q4_1:
1842
+ case GGML_TYPE_Q5_0:
1843
+ case GGML_TYPE_Q5_1:
1844
+ case GGML_TYPE_IQ4_NL:
1845
+ return true;
1846
+ default:
1847
+ return false;
1848
+ };
1849
+ }
1777
1850
  default:
1778
1851
  return false;
1779
1852
  }
@@ -2426,6 +2499,7 @@ static bool ggml_metal_encode_node(
2426
2499
  nth *= 2;
2427
2500
  }
2428
2501
 
2502
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
2429
2503
  nth = MIN(nth, ne00);
2430
2504
 
2431
2505
  ggml_metal_kargs_sum_rows args = {
@@ -3086,14 +3160,23 @@ static bool ggml_metal_encode_node(
3086
3160
  nsg = 1;
3087
3161
  nr0 = 1;
3088
3162
  nr1 = 4;
3089
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
3163
+ if (ne00 == 4) {
3164
+ nr0 = 32;
3165
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4].pipeline;
3166
+ } else {
3167
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
3168
+ }
3090
3169
  } break;
3091
3170
  case GGML_TYPE_F16:
3092
3171
  {
3093
3172
  nsg = 1;
3094
3173
  nr0 = 1;
3095
3174
  if (src1t == GGML_TYPE_F32) {
3096
- if (ne11 * ne12 < 4) {
3175
+ if (ne00 == 4) {
3176
+ nr0 = 32;
3177
+ nr1 = 4;
3178
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4].pipeline;
3179
+ } else if (ne11 * ne12 < 4) {
3097
3180
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
3098
3181
  } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
3099
3182
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
@@ -3112,7 +3195,11 @@ static bool ggml_metal_encode_node(
3112
3195
  nsg = 1;
3113
3196
  nr0 = 1;
3114
3197
  if (src1t == GGML_TYPE_F32) {
3115
- if (ne11 * ne12 < 4) {
3198
+ if (ne00 == 4) {
3199
+ nr0 = 32;
3200
+ nr1 = 4;
3201
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4].pipeline;
3202
+ } else if (ne11 * ne12 < 4) {
3116
3203
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
3117
3204
  } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
3118
3205
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
@@ -3733,13 +3820,74 @@ static bool ggml_metal_encode_node(
3733
3820
  };
3734
3821
 
3735
3822
  [encoder setComputePipelineState:pipeline];
3736
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3737
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3738
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
3739
- [encoder setBytes:&args length:sizeof(args) atIndex:3];
3823
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
3824
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3825
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
3826
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
3740
3827
 
3741
3828
  [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
3742
3829
  } break;
3830
+ case GGML_OP_SET_ROWS:
3831
+ {
3832
+ id<MTLComputePipelineState> pipeline = nil;
3833
+
3834
+ switch (dst->type) {
3835
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break;
3836
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break;
3837
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break;
3838
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline; break;
3839
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline; break;
3840
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline; break;
3841
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline; break;
3842
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline; break;
3843
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break;
3844
+ default: GGML_ABORT("not implemented");
3845
+ }
3846
+
3847
+ const int32_t nk0 = ne0/ggml_blck_size(dst->type);
3848
+
3849
+ int nth = 32; // SIMD width
3850
+
3851
+ while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
3852
+ nth *= 2;
3853
+ }
3854
+
3855
+ int nrptg = 1;
3856
+ if (nth > nk0) {
3857
+ nrptg = (nth + nk0 - 1)/nk0;
3858
+ nth = nk0;
3859
+
3860
+ if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
3861
+ nrptg--;
3862
+ }
3863
+ }
3864
+
3865
+ nth = MIN(nth, nk0);
3866
+
3867
+ ggml_metal_kargs_set_rows args = {
3868
+ /*.nk0 =*/ nk0,
3869
+ /*.ne01 =*/ ne01,
3870
+ /*.nb01 =*/ nb01,
3871
+ /*.nb02 =*/ nb02,
3872
+ /*.nb03 =*/ nb03,
3873
+ /*.ne11 =*/ ne11,
3874
+ /*.ne12 =*/ ne12,
3875
+ /*.nb10 =*/ nb10,
3876
+ /*.nb11 =*/ nb11,
3877
+ /*.nb12 =*/ nb12,
3878
+ /*.nb1 =*/ nb1,
3879
+ /*.nb2 =*/ nb2,
3880
+ /*.nb3 =*/ nb3,
3881
+ };
3882
+
3883
+ [encoder setComputePipelineState:pipeline];
3884
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
3885
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3886
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
3887
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
3888
+
3889
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
3890
+ } break;
3743
3891
  case GGML_OP_RMS_NORM:
3744
3892
  {
3745
3893
  GGML_ASSERT(ne00 % 4 == 0);
@@ -3756,6 +3904,7 @@ static bool ggml_metal_encode_node(
3756
3904
  nth *= 2;
3757
3905
  }
3758
3906
 
3907
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
3759
3908
  nth = MIN(nth, ne00/4);
3760
3909
 
3761
3910
  ggml_metal_kargs_rms_norm args = {
@@ -3792,6 +3941,7 @@ static bool ggml_metal_encode_node(
3792
3941
  nth *= 2;
3793
3942
  }
3794
3943
 
3944
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
3795
3945
  nth = MIN(nth, ne00/4);
3796
3946
 
3797
3947
  ggml_metal_kargs_l2_norm args = {
@@ -3864,6 +4014,7 @@ static bool ggml_metal_encode_node(
3864
4014
  nth *= 2;
3865
4015
  }
3866
4016
 
4017
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
3867
4018
  nth = MIN(nth, ne00/4);
3868
4019
 
3869
4020
  ggml_metal_kargs_norm args = {
@@ -4950,8 +5101,39 @@ static bool ggml_metal_encode_node(
4950
5101
  default: GGML_ABORT("not implemented");
4951
5102
  }
4952
5103
 
5104
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
5105
+
5106
+ // TODO: support
5107
+ //const int32_t nk00 = ne00/ggml_blck_size(dst->type);
5108
+ const int32_t nk00 = ne00;
5109
+
5110
+ int nth = 32; // SIMD width
5111
+
5112
+ while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
5113
+ nth *= 2;
5114
+ }
5115
+
5116
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
5117
+
5118
+ // when rows are small, we can batch them together in a single threadgroup
5119
+ int nrptg = 1;
5120
+
5121
+ // TODO: relax this constraint in the future
5122
+ if (ggml_blck_size(src0->type) == 1 && ggml_blck_size(dst->type) == 1) {
5123
+ if (nth > nk00) {
5124
+ nrptg = (nth + nk00 - 1)/nk00;
5125
+ nth = nk00;
5126
+
5127
+ if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
5128
+ nrptg--;
5129
+ }
5130
+ }
5131
+ }
5132
+
5133
+ nth = MIN(nth, nk00);
5134
+
4953
5135
  ggml_metal_kargs_cpy args = {
4954
- /*.ne00 =*/ ne00,
5136
+ /*.ne00 =*/ nk00,
4955
5137
  /*.ne01 =*/ ne01,
4956
5138
  /*.ne02 =*/ ne02,
4957
5139
  /*.ne03 =*/ ne03,
@@ -4974,11 +5156,7 @@ static bool ggml_metal_encode_node(
4974
5156
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
4975
5157
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
4976
5158
 
4977
- GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
4978
- int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
4979
-
4980
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
4981
-
5159
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
4982
5160
  } break;
4983
5161
  case GGML_OP_SET:
4984
5162
  {
@@ -5284,7 +5462,6 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
5284
5462
  }
5285
5463
 
5286
5464
  ggml_backend_metal_buffer_rset_free(ctx);
5287
- ggml_backend_metal_device_rel(buffer->buft->device->context);
5288
5465
 
5289
5466
  if (ctx->owned) {
5290
5467
  #if TARGET_OS_OSX
@@ -5393,7 +5570,10 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
5393
5570
  }
5394
5571
 
5395
5572
  struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context;
5396
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
5573
+
5574
+ GGML_ASSERT(ctx_dev->mtl_device != nil);
5575
+
5576
+ id<MTLDevice> device = ctx_dev->mtl_device;
5397
5577
 
5398
5578
  ctx->all_data = ggml_metal_host_malloc(size_aligned);
5399
5579
  ctx->all_size = size_aligned;
@@ -5416,14 +5596,12 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
5416
5596
  if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
5417
5597
  GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
5418
5598
  free(ctx);
5419
- ggml_backend_metal_device_rel(ctx_dev);
5420
5599
  return NULL;
5421
5600
  }
5422
5601
 
5423
5602
  if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
5424
5603
  GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
5425
5604
  free(ctx);
5426
- ggml_backend_metal_device_rel(ctx_dev);
5427
5605
  return NULL;
5428
5606
  }
5429
5607
 
@@ -5434,17 +5612,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
5434
5612
 
5435
5613
  static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
5436
5614
  return 32;
5615
+
5437
5616
  GGML_UNUSED(buft);
5438
5617
  }
5439
5618
 
5440
5619
  static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
5441
- id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
5442
- const size_t max_size = device.maxBufferLength;
5443
- ggml_backend_metal_device_rel(buft->device->context);
5620
+ const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size;
5444
5621
 
5445
5622
  return max_size;
5446
-
5447
- GGML_UNUSED(buft);
5448
5623
  }
5449
5624
 
5450
5625
  static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
@@ -5517,7 +5692,10 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
5517
5692
  }
5518
5693
 
5519
5694
  struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
5520
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
5695
+
5696
+ GGML_ASSERT(ctx_dev->mtl_device != nil);
5697
+
5698
+ id<MTLDevice> device = ctx_dev->mtl_device;
5521
5699
 
5522
5700
  // the buffer fits into the max buffer size allowed by the device
5523
5701
  if (size_aligned <= device.maxBufferLength) {
@@ -5573,7 +5751,6 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
5573
5751
  if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
5574
5752
  GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
5575
5753
  free(ctx);
5576
- ggml_backend_metal_device_rel(ctx_dev);
5577
5754
  return NULL;
5578
5755
  }
5579
5756
 
@@ -5589,10 +5766,8 @@ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
5589
5766
  }
5590
5767
 
5591
5768
  static void ggml_backend_metal_free(ggml_backend_t backend) {
5592
- struct ggml_backend_metal_context * ctx = backend->context;
5593
- struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
5769
+ struct ggml_backend_metal_context * ctx = backend->context;
5594
5770
 
5595
- ggml_backend_metal_device_rel(ctx_dev);
5596
5771
  ggml_metal_free(ctx);
5597
5772
 
5598
5773
  free(backend);
@@ -5732,6 +5907,8 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
5732
5907
 
5733
5908
  struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
5734
5909
 
5910
+ GGML_ASSERT(ctx_dev->mtl_device != nil);
5911
+
5735
5912
  return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
5736
5913
  }
5737
5914
 
@@ -5751,10 +5928,7 @@ static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
5751
5928
  }
5752
5929
 
5753
5930
  static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
5754
- // acq/rel just to populate ctx->name in case it hasn't been done yet
5755
5931
  struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
5756
- ggml_backend_metal_device_acq(ctx_dev);
5757
- ggml_backend_metal_device_rel(ctx_dev);
5758
5932
 
5759
5933
  return ctx_dev->name;
5760
5934
  }
@@ -5762,12 +5936,10 @@ static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t
5762
5936
  static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
5763
5937
  if (@available(macOS 10.12, iOS 16.0, *)) {
5764
5938
  struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
5765
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
5939
+ id<MTLDevice> device = ctx_dev->mtl_device;
5766
5940
 
5767
5941
  *total = device.recommendedMaxWorkingSetSize;
5768
5942
  *free = *total - device.currentAllocatedSize;
5769
-
5770
- ggml_backend_metal_device_rel(ctx_dev);
5771
5943
  } else {
5772
5944
  *free = 1;
5773
5945
  *total = 1;
@@ -5845,7 +6017,10 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
5845
6017
  }
5846
6018
 
5847
6019
  struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
5848
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
6020
+
6021
+ GGML_ASSERT(ctx_dev->mtl_device != nil);
6022
+
6023
+ id<MTLDevice> device = ctx_dev->mtl_device;
5849
6024
 
5850
6025
  // the buffer fits into the max buffer size allowed by the device
5851
6026
  if (size_aligned <= device.maxBufferLength) {
@@ -5901,7 +6076,6 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
5901
6076
  if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
5902
6077
  GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
5903
6078
  free(ctx);
5904
- ggml_backend_metal_device_rel(ctx_dev);
5905
6079
  return NULL;
5906
6080
  }
5907
6081
 
@@ -5915,8 +6089,9 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const
5915
6089
  }
5916
6090
 
5917
6091
  static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
5918
- return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
5919
- buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
6092
+ return
6093
+ buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
6094
+ buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
5920
6095
 
5921
6096
  GGML_UNUSED(dev);
5922
6097
  }
@@ -6001,8 +6176,19 @@ static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
6001
6176
  /* .get_proc_address = */ ggml_backend_metal_get_proc_address,
6002
6177
  };
6003
6178
 
6179
+ // called upon program exit
6180
+ static void ggml_metal_cleanup(void) {
6181
+ ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main);
6182
+ }
6183
+
6184
+ // TODO: make thread-safe
6004
6185
  ggml_backend_reg_t ggml_backend_metal_reg(void) {
6005
- // TODO: make this thread-safe somehow?
6186
+ ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
6187
+
6188
+ // register cleanup callback
6189
+ // TODO: not ideal, but not sure if there is a better way to do this in Objective-C
6190
+ atexit(ggml_metal_cleanup);
6191
+
6006
6192
  {
6007
6193
  g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
6008
6194
  /* .api_version = */ GGML_BACKEND_API_VERSION,