@novastera-oss/llamarn 0.2.6 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192) hide show
  1. package/android/src/main/cpp/include/llama.h +134 -36
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +2 -2
  11. package/cpp/LlamaCppModel.h +3 -3
  12. package/cpp/PureCppImpl.cpp +1 -1
  13. package/cpp/PureCppImpl.h +2 -2
  14. package/cpp/build-info.cpp +2 -2
  15. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  16. package/cpp/llama.cpp/Makefile +2 -2
  17. package/cpp/llama.cpp/README.md +32 -13
  18. package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
  19. package/cpp/llama.cpp/common/arg.cpp +30 -6
  20. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  21. package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
  22. package/cpp/llama.cpp/common/chat-parser.h +2 -0
  23. package/cpp/llama.cpp/common/chat.cpp +12 -9
  24. package/cpp/llama.cpp/common/chat.h +1 -1
  25. package/cpp/llama.cpp/common/common.cpp +50 -40
  26. package/cpp/llama.cpp/common/common.h +5 -2
  27. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  28. package/cpp/llama.cpp/convert_hf_to_gguf.py +97 -56
  29. package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
  30. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  31. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +47 -13
  32. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  33. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  34. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  35. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +10 -2
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +5 -8
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +6 -8
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  70. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  72. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
  73. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
  74. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
  76. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  84. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -38
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  94. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +431 -247
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  101. package/cpp/llama.cpp/ggml/src/ggml.c +0 -6
  102. package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
  103. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
  104. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
  105. package/cpp/llama.cpp/include/llama.h +134 -36
  106. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  107. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  108. package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
  109. package/cpp/llama.cpp/src/llama-arch.h +7 -1
  110. package/cpp/llama.cpp/src/llama-batch.cpp +270 -19
  111. package/cpp/llama.cpp/src/llama-batch.h +36 -11
  112. package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
  113. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  114. package/cpp/llama.cpp/src/llama-context.cpp +313 -213
  115. package/cpp/llama.cpp/src/llama-context.h +16 -12
  116. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  117. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  118. package/cpp/llama.cpp/src/llama-graph.cpp +249 -129
  119. package/cpp/llama.cpp/src/llama-graph.h +90 -34
  120. package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
  121. package/cpp/llama.cpp/src/llama-hparams.h +8 -2
  122. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +82 -50
  123. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  124. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +292 -174
  125. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +68 -38
  126. package/cpp/llama.cpp/src/llama-kv-cells.h +18 -13
  127. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
  128. package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
  129. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +266 -282
  130. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +54 -57
  131. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  132. package/cpp/llama.cpp/src/llama-memory.h +64 -23
  133. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  134. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  135. package/cpp/llama.cpp/src/llama-model.cpp +726 -141
  136. package/cpp/llama.cpp/src/llama-model.h +4 -0
  137. package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
  138. package/cpp/llama.cpp/src/llama-vocab.cpp +32 -23
  139. package/cpp/llama.cpp/src/llama.cpp +11 -7
  140. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  141. package/cpp/rn-completion.cpp +2 -2
  142. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  143. package/ios/include/chat.h +1 -1
  144. package/ios/include/common.h +5 -2
  145. package/ios/include/llama.h +134 -36
  146. package/ios/libs/llama.xcframework/Info.plist +18 -18
  147. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  148. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
  149. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +134 -36
  150. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  151. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  152. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  153. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
  154. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  155. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  156. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  157. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  158. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3624
  159. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +134 -36
  160. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +134 -36
  161. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  162. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +134 -36
  163. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  164. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  165. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  166. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
  167. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +134 -36
  168. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  169. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  170. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  171. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
  172. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  173. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  174. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  175. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4725
  176. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +134 -36
  177. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  178. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  179. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4746
  180. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3652
  181. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  182. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  183. package/package.json +1 -2
  184. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  185. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  186. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  187. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
  188. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
  189. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  190. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  191. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  192. /package/cpp/{rn-utils.hpp → rn-utils.h} +0 -0
@@ -25,15 +25,3 @@ add_executable(${TARGET} vulkan-shaders-gen.cpp)
25
25
  install(TARGETS ${TARGET} RUNTIME)
26
26
  target_compile_features(${TARGET} PRIVATE cxx_std_17)
27
27
  target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
28
-
29
- # Configure output directories for MSVC builds
30
- if(MSVC)
31
- # Get the main project's runtime output directory if possible
32
- if(DEFINED CMAKE_RUNTIME_OUTPUT_DIRECTORY)
33
- foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES})
34
- string(TOUPPER ${CONFIG} CONFIG)
35
- set_target_properties(${TARGET} PROPERTIES
36
- RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
37
- endforeach()
38
- endif()
39
- endif()
@@ -0,0 +1,98 @@
1
+ #version 450
2
+
3
+ #include "types.comp"
4
+
5
+ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin]
6
+ layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin]
7
+ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // dst - result [KL, Cout]
8
+
9
+ layout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in;
10
+
11
+ layout (push_constant) uniform parameter {
12
+ uint32_t Cout;
13
+ uint32_t Cin;
14
+ uint32_t K;
15
+ uint32_t L;
16
+ uint32_t KL;
17
+
18
+ uint32_t nb01;
19
+ uint32_t nb02;
20
+ uint32_t nb11;
21
+ uint32_t nb1;
22
+
23
+ int32_t s0;
24
+ } p;
25
+
26
+
27
+ uint32_t Cout_idx = gl_WorkGroupID.x;
28
+ const uint32_t bs = gl_WorkGroupSize.x;
29
+ uint32_t tid = gl_LocalInvocationID.x;
30
+ // Code is more straightforward if we assume it is bs*s0+K instead of (bs-1)*s0+K.
31
+ uint32_t tmp_len = bs*p.s0+p.K;
32
+ shared D_TYPE tmp[4096];
33
+
34
+ uint splitWork(uint workSize){
35
+ return (bs + workSize -1) / bs;
36
+ }
37
+
38
+ void main(){
39
+ for(uint32_t i = 0; i < splitWork(tmp_len); i++){
40
+ uint32_t idx = i*bs+tid;
41
+ if(idx < tmp_len){
42
+ tmp[idx] = 0.0;
43
+ }
44
+ }
45
+
46
+ uint32_t L_blocks = splitWork(p.L);
47
+ for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){
48
+ if(L_block_id > 0){
49
+ barrier();
50
+ // Shift values in tmp to the current processing window
51
+ for(int i = 0; i < splitWork(tmp_len); i++){
52
+ uint32_t idx = i*bs+tid;
53
+ if(idx >= bs*p.s0 && idx < tmp_len){
54
+ tmp[idx-bs*p.s0] = tmp[idx];
55
+ tmp[idx] = 0.0;
56
+ }else if(idx >= p.K && idx < bs*p.s0){
57
+ tmp[idx] = 0.0;
58
+ }
59
+ }
60
+ }
61
+ barrier();
62
+
63
+ // Save contributions of the block to tmp
64
+ uint32_t L_idx = L_block_id*bs + tid;
65
+ for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){
66
+ D_TYPE dp = 0.0;
67
+ for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){
68
+ A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02];
69
+ if(L_idx < p.L){
70
+ B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11];
71
+ dp = fma(elemKrn, elemInp, dp);
72
+ }
73
+ }
74
+ tmp[tid*p.s0 + K_idx] += dp;
75
+ barrier();
76
+ }
77
+
78
+ // Save the computed values except the last block that can have different size
79
+ uint32_t KLb_idx = L_block_id*bs*p.s0;
80
+ if(L_block_id < L_blocks-1){
81
+ for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){
82
+ uint32_t sh_idx = p.s0*tid+s0_idx;
83
+ uint32_t KL_idx = KLb_idx+sh_idx;
84
+ if(KL_idx < p.KL){
85
+ data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx];
86
+ }
87
+ }
88
+ }
89
+ }
90
+
91
+ for(uint32_t i = 0; i < splitWork(tmp_len); i++){
92
+ uint32_t idx = i*bs+tid;
93
+ uint32_t KL_idx = (L_blocks-1)*bs*p.s0+idx;
94
+ if(KL_idx < p.KL){
95
+ data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx];
96
+ }
97
+ }
98
+ }
@@ -622,6 +622,8 @@ void process_shaders() {
622
622
 
623
623
  string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
624
624
 
625
+ string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
626
+
625
627
  string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
626
628
 
627
629
  string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
@@ -888,12 +888,6 @@ struct ggml_context {
888
888
  struct ggml_object * objects_end;
889
889
  };
890
890
 
891
- struct ggml_context_container {
892
- bool used;
893
-
894
- struct ggml_context context;
895
- };
896
-
897
891
  //
898
892
  // data types
899
893
  //
@@ -291,6 +291,7 @@ class MODEL_ARCH(IntEnum):
291
291
  BERT = auto()
292
292
  NOMIC_BERT = auto()
293
293
  NOMIC_BERT_MOE = auto()
294
+ NEO_BERT = auto()
294
295
  JINA_BERT_V2 = auto()
295
296
  BLOOM = auto()
296
297
  STABLELM = auto()
@@ -343,6 +344,8 @@ class MODEL_ARCH(IntEnum):
343
344
  WAVTOKENIZER_DEC = auto()
344
345
  PLM = auto()
345
346
  BAILINGMOE = auto()
347
+ DOTS1 = auto()
348
+ ARCEE = auto()
346
349
 
347
350
 
348
351
  class VISION_PROJECTOR_TYPE(IntEnum):
@@ -571,6 +574,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
571
574
  MODEL_ARCH.BERT: "bert",
572
575
  MODEL_ARCH.NOMIC_BERT: "nomic-bert",
573
576
  MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe",
577
+ MODEL_ARCH.NEO_BERT: "neo-bert",
574
578
  MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
575
579
  MODEL_ARCH.BLOOM: "bloom",
576
580
  MODEL_ARCH.STABLELM: "stablelm",
@@ -623,6 +627,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
623
627
  MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
624
628
  MODEL_ARCH.PLM: "plm",
625
629
  MODEL_ARCH.BAILINGMOE: "bailingmoe",
630
+ MODEL_ARCH.DOTS1: "dots1",
631
+ MODEL_ARCH.ARCEE: "arcee",
626
632
  }
627
633
 
628
634
  VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -1077,6 +1083,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
1077
1083
  MODEL_TENSOR.FFN_UP_EXP,
1078
1084
  MODEL_TENSOR.LAYER_OUT_NORM,
1079
1085
  ],
1086
+ MODEL_ARCH.NEO_BERT: [
1087
+ MODEL_TENSOR.TOKEN_EMBD,
1088
+ MODEL_TENSOR.ATTN_NORM,
1089
+ MODEL_TENSOR.ATTN_QKV,
1090
+ MODEL_TENSOR.ATTN_OUT,
1091
+ MODEL_TENSOR.FFN_NORM,
1092
+ MODEL_TENSOR.FFN_DOWN,
1093
+ MODEL_TENSOR.FFN_UP,
1094
+ MODEL_TENSOR.ENC_OUTPUT_NORM,
1095
+ MODEL_TENSOR.CLS,
1096
+ MODEL_TENSOR.CLS_OUT,
1097
+ ],
1080
1098
  MODEL_ARCH.JINA_BERT_V2: [
1081
1099
  MODEL_TENSOR.TOKEN_EMBD,
1082
1100
  MODEL_TENSOR.TOKEN_EMBD_NORM,
@@ -2044,6 +2062,45 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
2044
2062
  MODEL_TENSOR.FFN_DOWN_SHEXP,
2045
2063
  MODEL_TENSOR.FFN_UP_SHEXP,
2046
2064
  ],
2065
+ MODEL_ARCH.DOTS1: [
2066
+ MODEL_TENSOR.TOKEN_EMBD,
2067
+ MODEL_TENSOR.OUTPUT_NORM,
2068
+ MODEL_TENSOR.OUTPUT,
2069
+ MODEL_TENSOR.ATTN_NORM,
2070
+ MODEL_TENSOR.ATTN_Q,
2071
+ MODEL_TENSOR.ATTN_Q_NORM,
2072
+ MODEL_TENSOR.ATTN_K,
2073
+ MODEL_TENSOR.ATTN_K_NORM,
2074
+ MODEL_TENSOR.ATTN_V,
2075
+ MODEL_TENSOR.ATTN_OUT,
2076
+ MODEL_TENSOR.FFN_EXP_PROBS_B,
2077
+ MODEL_TENSOR.FFN_NORM,
2078
+ MODEL_TENSOR.FFN_GATE,
2079
+ MODEL_TENSOR.FFN_GATE_EXP,
2080
+ MODEL_TENSOR.FFN_GATE_INP,
2081
+ MODEL_TENSOR.FFN_GATE_SHEXP,
2082
+ MODEL_TENSOR.FFN_DOWN,
2083
+ MODEL_TENSOR.FFN_DOWN_EXP,
2084
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
2085
+ MODEL_TENSOR.FFN_UP,
2086
+ MODEL_TENSOR.FFN_UP_EXP,
2087
+ MODEL_TENSOR.FFN_UP_SHEXP,
2088
+ ],
2089
+ MODEL_ARCH.ARCEE: [
2090
+ MODEL_TENSOR.TOKEN_EMBD,
2091
+ MODEL_TENSOR.OUTPUT_NORM,
2092
+ MODEL_TENSOR.OUTPUT,
2093
+ MODEL_TENSOR.ROPE_FREQS,
2094
+ MODEL_TENSOR.ATTN_NORM,
2095
+ MODEL_TENSOR.ATTN_Q,
2096
+ MODEL_TENSOR.ATTN_K,
2097
+ MODEL_TENSOR.ATTN_V,
2098
+ MODEL_TENSOR.ATTN_OUT,
2099
+ MODEL_TENSOR.ATTN_ROT_EMBD,
2100
+ MODEL_TENSOR.FFN_NORM,
2101
+ MODEL_TENSOR.FFN_DOWN,
2102
+ MODEL_TENSOR.FFN_UP,
2103
+ ],
2047
2104
  # TODO
2048
2105
  }
2049
2106
 
@@ -271,7 +271,7 @@ class GGUFWriter:
271
271
 
272
272
  def add_key_value(self, key: str, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None) -> None:
273
273
  if any(key in kv_data for kv_data in self.kv_data):
274
- raise ValueError(f'Duplicated key name {key!r}')
274
+ logger.warning(f'Duplicated key name {key!r}, overwriting it with new value {val!r} of type {vtype.name}')
275
275
 
276
276
  self.kv_data[0][key] = GGUFValue(value=val, type=vtype, sub_type=sub_type)
277
277
 
@@ -935,6 +935,9 @@ class GGUFWriter:
935
935
  def add_eom_token_id(self, id: int) -> None:
936
936
  self.add_uint32(Keys.Tokenizer.EOM_ID, id)
937
937
 
938
+ def add_classifier_output_labels(self, labels: Sequence[str]) -> None:
939
+ self.add_array(Keys.Classifier.OUTPUT_LABELS.format(arch=self.arch), labels)
940
+
938
941
  # for vision models
939
942
 
940
943
  def add_clip_has_vision_encoder(self, value: bool) -> None:
@@ -31,6 +31,7 @@ class TensorNameMap:
31
31
  "model.embeddings", # rwkv7
32
32
  "model.word_embeddings", # bailingmoe
33
33
  "language_model.model.embed_tokens", # llama4
34
+ "encoder", # neobert
34
35
  ),
35
36
 
36
37
  # Token type embeddings
@@ -134,6 +135,7 @@ class TensorNameMap:
134
135
  "rwkv.blocks.{bid}.ln1", # rwkv6
135
136
  "model.layers.{bid}.ln1", # rwkv7
136
137
  "model.layers.{bid}.input_layernorm", # llama4
138
+ "transformer_encoder.{bid}.attention_norm", # neobert
137
139
  ),
138
140
 
139
141
  # Attention norm 2
@@ -161,6 +163,7 @@ class TensorNameMap:
161
163
  "model.layers.{bid}.self_attn.qkv_proj", # phi3
162
164
  "encoder.layers.{bid}.self_attention.query_key_value", # chatglm
163
165
  "transformer.layers.{bid}.attn.qkv_proj", # openelm
166
+ "transformer_encoder.{bid}.qkv", # neobert
164
167
  ),
165
168
 
166
169
  # Attention query
@@ -236,6 +239,7 @@ class TensorNameMap:
236
239
  "transformer.layers.{bid}.attn.out_proj", # openelm
237
240
  "transformer.h.{bid}.attn.attention.out_proj", # exaone
238
241
  "model.layers.{bid}.self_attn.o_proj", # llama4
242
+ "transformer_encoder.{bid}.wo", # neobert
239
243
  ),
240
244
 
241
245
  # Attention output norm
@@ -276,6 +280,7 @@ class TensorNameMap:
276
280
  "encoder.layers.{bid}.post_attention_layernorm", # chatglm
277
281
  "transformer.layers.{bid}.ffn_norm", # openelm
278
282
  "model.layers.{bid}.post_attention_layernorm", # llama4
283
+ "transformer_encoder.{bid}.ffn_norm", # neobert
279
284
  ),
280
285
 
281
286
  # Post feed-forward norm
@@ -305,7 +310,7 @@ class TensorNameMap:
305
310
  ),
306
311
 
307
312
  MODEL_TENSOR.FFN_EXP_PROBS_B: (
308
- "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3
313
+ "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
309
314
  ),
310
315
 
311
316
  # Feed-forward up
@@ -333,11 +338,14 @@ class TensorNameMap:
333
338
  "encoder.layers.{bid}.mlp.fc11", # nomic-bert
334
339
  "encoder.layers.{bid}.mlp.fc1", # nomic-bert-moe
335
340
  "model.layers.{bid}.mlp.c_fc", # starcoder2
336
- "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
341
+ "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 (split up/gate, no longer used)
342
+ "encoder.layer.{bid}.mlp.gated_layers", # jina-bert-v2 (GEGLU)
343
+ "encoder.layer.{bid}.mlp.up_gated_layer", # jina-v2-code (GEGLU)
337
344
  "model.layers.{bid}.residual_mlp.w3", # arctic
338
345
  "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
339
346
  "transformer.h.{bid}.mlp.c_fc_1", # exaone
340
347
  "model.layers.{bid}.feed_forward.up_proj", # llama4
348
+ "transformer_encoder.{bid}.ffn.w12", # neobert
341
349
  ),
342
350
 
343
351
  MODEL_TENSOR.FFN_UP_EXP: (
@@ -370,7 +378,7 @@ class TensorNameMap:
370
378
  "model.layers.layers.{bid}.mlp.gate_proj", # plamo
371
379
  "model.layers.{bid}.feed_forward.w1", # internlm2
372
380
  "encoder.layers.{bid}.mlp.fc12", # nomic-bert
373
- "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2
381
+ "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 (split up/gate, no longer used)
374
382
  "transformer.h.{bid}.mlp.linear_1", # refact
375
383
  "model.layers.{bid}.residual_mlp.w1", # arctic
376
384
  "transformer.h.{bid}.mlp.c_fc_0", # exaone
@@ -420,6 +428,7 @@ class TensorNameMap:
420
428
  "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
421
429
  "model.layers.h.{bid}.mlp.c_proj", # exaone
422
430
  "model.layers.{bid}.feed_forward.down_proj", # llama4
431
+ "transformer_encoder.{bid}.ffn.w3", # neobert
423
432
  ),
424
433
 
425
434
  MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -830,12 +839,14 @@ class TensorNameMap:
830
839
  # TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg
831
840
  MODEL_TENSOR.ENC_OUTPUT_NORM: (
832
841
  "encoder.final_layer_norm", # t5
842
+ "layer_norm", # neobert
833
843
  ),
834
844
 
835
845
  MODEL_TENSOR.CLS: (
836
846
  "classifier", # jina
837
847
  "classifier.dense", # roberta
838
848
  "pre_classifier", # distillbert
849
+ "dense", # neobert
839
850
  ),
840
851
 
841
852
  MODEL_TENSOR.CLS_OUT: (