numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -70,7 +70,7 @@ extern "C" {
70
70
  "avx512fp16", "f16c", "fma", "bmi", "bmi2", "amx-tile", "amx-bf16", "amx-int8")
71
71
  #endif
72
72
 
73
- #pragma region i8 Header (for f32/f16 coarse+refine)
73
+ #pragma region I8 Header
74
74
 
75
75
  /**
76
76
  * i8 packed buffer header for AMX coarse+refine MaxSim (64 bytes).
@@ -92,9 +92,9 @@ typedef struct {
92
92
 
93
93
  NK_STATIC_ASSERT(sizeof(nk_maxsim_sapphireamx_i8_header_t) == 64, nk_maxsim_sapphireamx_i8_header_must_be_64_bytes);
94
94
 
95
- #pragma endregion
95
+ #pragma endregion I8 Header
96
96
 
97
- #pragma region Single Precision Floats
97
+ #pragma region F32 Floats
98
98
 
99
99
  NK_PUBLIC nk_size_t nk_maxsim_packed_size_f32_sapphireamx(nk_size_t vector_count, nk_size_t depth) {
100
100
  nk_size_t column_tile_count = nk_size_divide_round_up_(vector_count, 16);
@@ -108,7 +108,7 @@ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f32_sapphireamx(nk_size_t vector_count
108
108
  }
109
109
 
110
110
  NK_PUBLIC void nk_maxsim_pack_f32_sapphireamx( //
111
- nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
111
+ nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
112
112
 
113
113
  nk_size_t column_tile_count = nk_size_divide_round_up_(vector_count, 16);
114
114
  nk_size_t depth_tile_count = nk_size_divide_round_up_(depth, 64);
@@ -147,7 +147,7 @@ NK_PUBLIC void nk_maxsim_pack_f32_sapphireamx( //
147
147
 
148
148
  // Quantize vectors and scatter into A-side tiles, copy originals, compute inverse norms
149
149
  for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
150
- nk_f32_t const *source_vector = (nk_f32_t const *)((char const *)vectors + vector_index * stride);
150
+ nk_f32_t const *source_vector = (nk_f32_t const *)((char const *)vectors + vector_index * stride_in_bytes);
151
151
 
152
152
  // Pass 1: find absmax and norm_squared
153
153
  nk_f32_t absmax_f32 = 0.0f;
@@ -347,9 +347,9 @@ NK_PUBLIC void nk_maxsim_packed_f32_sapphireamx( //
347
347
  *result = total_angular_distance_f64;
348
348
  }
349
349
 
350
- #pragma endregion
350
+ #pragma endregion F32 Floats
351
351
 
352
- #pragma region Half Precision Floats
352
+ #pragma region F16 Floats
353
353
 
354
354
  NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_sapphireamx(nk_size_t vector_count, nk_size_t depth) {
355
355
  nk_size_t column_tile_count = nk_size_divide_round_up_(vector_count, 16);
@@ -363,7 +363,7 @@ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_sapphireamx(nk_size_t vector_count
363
363
  }
364
364
 
365
365
  NK_PUBLIC void nk_maxsim_pack_f16_sapphireamx( //
366
- nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
366
+ nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
367
367
 
368
368
  nk_size_t column_tile_count = nk_size_divide_round_up_(vector_count, 16);
369
369
  nk_size_t depth_tile_count = nk_size_divide_round_up_(depth, 64);
@@ -401,7 +401,7 @@ NK_PUBLIC void nk_maxsim_pack_f16_sapphireamx( //
401
401
  }
402
402
 
403
403
  // Quantize vectors and scatter into A-side tiles, copy originals, compute inverse norms
404
- nk_size_t const stride_elements = stride / sizeof(nk_f16_t);
404
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
405
405
  for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
406
406
  nk_f16_t const *source_vector = vectors + vector_index * stride_elements;
407
407
 
@@ -602,9 +602,9 @@ NK_PUBLIC void nk_maxsim_packed_f16_sapphireamx( //
602
602
  *result = (nk_f32_t)total_angular_distance_f64;
603
603
  }
604
604
 
605
- #pragma endregion
605
+ #pragma endregion F16 Floats
606
606
 
607
- #pragma region Brain Floats (Fused AMX)
607
+ #pragma region BF16 Floats
608
608
 
609
609
  /**
610
610
  * BF16 packed buffer header for AMX fused MaxSim (64 bytes).
@@ -635,10 +635,10 @@ NK_PUBLIC nk_size_t nk_maxsim_packed_size_bf16_sapphireamx(nk_size_t vector_coun
635
635
  }
636
636
 
637
637
  NK_PUBLIC void nk_maxsim_pack_bf16_sapphireamx( //
638
- nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
638
+ nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
639
639
 
640
640
  nk_size_t const tile_bytes = 1024;
641
- nk_size_t const stride_elements = stride / sizeof(nk_bf16_t);
641
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_bf16_t);
642
642
  nk_size_t column_tile_count = nk_size_divide_round_up_(vector_count, 16);
643
643
  nk_size_t depth_tile_count = nk_size_divide_round_up_(depth, 32);
644
644
 
@@ -860,7 +860,7 @@ NK_PUBLIC void nk_maxsim_packed_bf16_sapphireamx( //
860
860
  *result = (nk_f32_t)total_angular_distance_f64;
861
861
  }
862
862
 
863
- #pragma endregion
863
+ #pragma endregion BF16 Floats
864
864
 
865
865
  #if defined(__clang__)
866
866
  #pragma clang attribute pop
@@ -234,7 +234,7 @@ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f32_serial(nk_size_t vector_count, nk_
234
234
  }
235
235
 
236
236
  NK_PUBLIC void nk_maxsim_pack_bf16_serial( //
237
- nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
237
+ nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
238
238
 
239
239
  nk_size_t const element_bytes = sizeof(nk_bf16_t);
240
240
  nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 1, element_bytes);
@@ -246,7 +246,7 @@ NK_PUBLIC void nk_maxsim_pack_bf16_serial( //
246
246
  nk_size_t const original_stride = header->original_stride_bytes;
247
247
 
248
248
  for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
249
- char const *source_row = (char const *)vectors + vector_index * stride;
249
+ char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
250
250
  nk_f32_t norm_sq;
251
251
  nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
252
252
  (nk_maxsim_to_f32_t)nk_bf16_to_f32_serial,
@@ -260,7 +260,7 @@ NK_PUBLIC void nk_maxsim_pack_bf16_serial( //
260
260
  }
261
261
 
262
262
  NK_PUBLIC void nk_maxsim_pack_f32_serial( //
263
- nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
263
+ nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
264
264
 
265
265
  nk_size_t const element_bytes = sizeof(nk_f32_t);
266
266
  nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 1, element_bytes);
@@ -272,7 +272,7 @@ NK_PUBLIC void nk_maxsim_pack_f32_serial( //
272
272
  nk_size_t const original_stride = header->original_stride_bytes;
273
273
 
274
274
  for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
275
- char const *source_row = (char const *)vectors + vector_index * stride;
275
+ char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
276
276
  nk_f32_t norm_sq;
277
277
  nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f, nk_f32_to_f32_,
278
278
  &quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
@@ -289,7 +289,7 @@ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_serial(nk_size_t vector_count, nk_
289
289
  }
290
290
 
291
291
  NK_PUBLIC void nk_maxsim_pack_f16_serial( //
292
- nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
292
+ nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
293
293
 
294
294
  nk_size_t const element_bytes = sizeof(nk_f16_t);
295
295
  nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 1, element_bytes);
@@ -301,7 +301,7 @@ NK_PUBLIC void nk_maxsim_pack_f16_serial( //
301
301
  nk_size_t const original_stride = header->original_stride_bytes;
302
302
 
303
303
  for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
304
- char const *source_row = (char const *)vectors + vector_index * stride;
304
+ char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
305
305
  nk_f32_t norm_sq;
306
306
  nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
307
307
  (nk_maxsim_to_f32_t)nk_f16_to_f32_serial,