numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -1,157 +0,0 @@
1
- /**
2
- * @brief NEON FP16 implementations for the redesigned reduction API (moments + minmax).
3
- * @file include/numkong/reduce/neonhalf.h
4
- * @author Ash Vardanian
5
- * @date February 13, 2026
6
- *
7
- * @sa include/numkong/reduce.h
8
- *
9
- * @section reduce_neonhalf_new_design Design Notes
10
- *
11
- * Moments (sum + sum-of-squares) accumulate in f32 via vcvt_f32_f16 widening, giving
12
- * full f32 precision. The contiguous path processes 8 f16 elements per iteration, widening
13
- * to two f32x4 halves and using vfmaq_f32 for fused multiply-accumulate of squares.
14
- *
15
- * Minmax tracks min/max values as native f16x8 with u16x8 iteration counters (same width
16
- * as f16). The u16 counters wrap at 65536, so the dispatcher splits arrays larger than
17
- * 65536 * 8 = 524288 elements via recursive halving.
18
- */
19
- #ifndef NK_REDUCE_NEONHALF_H
20
- #define NK_REDUCE_NEONHALF_H
21
-
22
- #if NK_TARGET_ARM_
23
- #if NK_TARGET_NEONHALF
24
-
25
- #include "numkong/types.h"
26
- #include "numkong/cast/neon.h"
27
- #include "numkong/cast/serial.h"
28
- #include "numkong/reduce/serial.h"
29
-
30
- #if defined(__cplusplus)
31
- extern "C" {
32
- #endif
33
-
34
- #if defined(__clang__)
35
- #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
36
- #elif defined(__GNUC__)
37
- #pragma GCC push_options
38
- #pragma GCC target("arch=armv8.2-a+simd+fp16")
39
- #endif
40
-
41
- NK_INTERNAL void nk_reduce_moments_f16_neonhalf_contiguous_( //
42
- nk_f16_t const *data_ptr, nk_size_t count, //
43
- nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
44
- float32x4_t sum_f32x4 = vdupq_n_f32(0);
45
- float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
46
- nk_size_t idx = 0;
47
-
48
- for (; idx + 8 <= count; idx += 8) {
49
- float16x8_t data_f16x8 = vld1q_f16((nk_f16_for_arm_simd_t const *)(data_ptr + idx));
50
- float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(data_f16x8));
51
- float32x4_t high_f32x4 = vcvt_f32_f16(vget_high_f16(data_f16x8));
52
- sum_f32x4 = vaddq_f32(sum_f32x4, low_f32x4);
53
- sum_f32x4 = vaddq_f32(sum_f32x4, high_f32x4);
54
- sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4);
55
- sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, high_f32x4, high_f32x4);
56
- }
57
-
58
- // Scalar tail
59
- nk_f32_t sum = vaddvq_f32(sum_f32x4);
60
- nk_f32_t sumsq = vaddvq_f32(sumsq_f32x4);
61
- for (; idx < count; ++idx) {
62
- nk_f32_t value_f32;
63
- nk_f16_to_f32_serial(data_ptr + idx, &value_f32);
64
- sum += value_f32, sumsq += value_f32 * value_f32;
65
- }
66
- *sum_ptr = sum, *sumsq_ptr = sumsq;
67
- }
68
-
69
- NK_INTERNAL void nk_reduce_moments_f16_neonhalf_strided_( //
70
- nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
71
- nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
72
- float32x4_t sum_f32x4 = vdupq_n_f32(0);
73
- float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
74
- nk_size_t idx = 0;
75
-
76
- if (stride_elements == 2) {
77
- for (; idx + 8 <= count; idx += 8) {
78
- uint16x8x2_t loaded_u16x8x2 = vld2q_u16((uint16_t const *)(data_ptr + idx * 2));
79
- float16x8_t data_f16x8 = vreinterpretq_f16_u16(loaded_u16x8x2.val[0]);
80
- float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(data_f16x8));
81
- float32x4_t high_f32x4 = vcvt_f32_f16(vget_high_f16(data_f16x8));
82
- sum_f32x4 = vaddq_f32(sum_f32x4, low_f32x4);
83
- sum_f32x4 = vaddq_f32(sum_f32x4, high_f32x4);
84
- sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4);
85
- sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, high_f32x4, high_f32x4);
86
- }
87
- }
88
- else if (stride_elements == 3) {
89
- for (; idx + 8 <= count; idx += 8) {
90
- uint16x8x3_t loaded_u16x8x3 = vld3q_u16((uint16_t const *)(data_ptr + idx * 3));
91
- float16x8_t data_f16x8 = vreinterpretq_f16_u16(loaded_u16x8x3.val[0]);
92
- float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(data_f16x8));
93
- float32x4_t high_f32x4 = vcvt_f32_f16(vget_high_f16(data_f16x8));
94
- sum_f32x4 = vaddq_f32(sum_f32x4, low_f32x4);
95
- sum_f32x4 = vaddq_f32(sum_f32x4, high_f32x4);
96
- sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4);
97
- sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, high_f32x4, high_f32x4);
98
- }
99
- }
100
- else if (stride_elements == 4) {
101
- for (; idx + 8 <= count; idx += 8) {
102
- uint16x8x4_t loaded_u16x8x4 = vld4q_u16((uint16_t const *)(data_ptr + idx * 4));
103
- float16x8_t data_f16x8 = vreinterpretq_f16_u16(loaded_u16x8x4.val[0]);
104
- float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(data_f16x8));
105
- float32x4_t high_f32x4 = vcvt_f32_f16(vget_high_f16(data_f16x8));
106
- sum_f32x4 = vaddq_f32(sum_f32x4, low_f32x4);
107
- sum_f32x4 = vaddq_f32(sum_f32x4, high_f32x4);
108
- sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4);
109
- sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, high_f32x4, high_f32x4);
110
- }
111
- }
112
-
113
- // Scalar tail for remaining elements
114
- nk_f32_t sum = vaddvq_f32(sum_f32x4);
115
- nk_f32_t sumsq = vaddvq_f32(sumsq_f32x4);
116
- for (; idx < count; ++idx) {
117
- nk_f32_t value_f32;
118
- nk_f16_to_f32_serial((nk_f16_t const *)(data_ptr + idx * stride_elements), &value_f32);
119
- sum += value_f32, sumsq += value_f32 * value_f32;
120
- }
121
- *sum_ptr = sum, *sumsq_ptr = sumsq;
122
- }
123
-
124
- NK_PUBLIC void nk_reduce_moments_f16_neonhalf( //
125
- nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
126
- nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
127
- nk_size_t stride_elements = stride_bytes / sizeof(nk_f16_t);
128
- int aligned = (stride_bytes % sizeof(nk_f16_t) == 0);
129
- if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
130
- else if (!aligned) nk_reduce_moments_f16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
131
- else if (count > (nk_size_t)(NK_U16_MAX + 1) * 8) {
132
- nk_size_t left_count = count / 2;
133
- nk_f32_t left_sum_value, left_sumsq_value, right_sum_value, right_sumsq_value;
134
- nk_reduce_moments_f16_neonhalf(data_ptr, left_count, stride_bytes, &left_sum_value, &left_sumsq_value);
135
- nk_reduce_moments_f16_neonhalf(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
136
- &right_sum_value, &right_sumsq_value);
137
- *sum_ptr = left_sum_value + right_sum_value, *sumsq_ptr = left_sumsq_value + right_sumsq_value;
138
- }
139
- else if (stride_elements == 1) nk_reduce_moments_f16_neonhalf_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
140
- else if (stride_elements <= 4)
141
- nk_reduce_moments_f16_neonhalf_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
142
- else nk_reduce_moments_f16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
143
- }
144
-
145
- #if defined(__clang__)
146
- #pragma clang attribute pop
147
- #elif defined(__GNUC__)
148
- #pragma GCC pop_options
149
- #endif
150
-
151
- #if defined(__cplusplus)
152
- } // extern "C"
153
- #endif
154
-
155
- #endif // NK_TARGET_NEONHALF
156
- #endif // NK_TARGET_ARM_
157
- #endif // NK_REDUCE_NEONHALF_H
@@ -1,118 +0,0 @@
1
- /**
2
- * @brief SIMD-accelerated Spatial Similarity Measures for NEON FP16.
3
- * @file include/numkong/spatial/neonhalf.h
4
- * @author Ash Vardanian
5
- * @date December 27, 2025
6
- *
7
- * @sa include/numkong/spatial.h
8
- *
9
- * @section spatial_neonhalf_instructions ARM NEON FP16 Instructions (ARMv8.2-FP16)
10
- *
11
- * Intrinsic Instruction Latency Throughput
12
- * A76 M4+/V1+/Oryon
13
- * vfmaq_f16 FMLA (V.8H, V.8H, V.8H) 4cy 2/cy 4/cy
14
- * vcvt_f32_f16 FCVTL (V.4S, V.4H) 3cy 2/cy 4/cy
15
- * vld1q_f16 LD1 (V.8H) 4cy 2/cy 3/cy
16
- * vsubq_f16 FSUB (V.8H, V.8H, V.8H) 2cy 2/cy 4/cy
17
- * vaddvq_f32 FADDP+FADDP (V.4S) 4cy 1/cy 2/cy
18
- *
19
- * The ARMv8.2-FP16 extension enables native half-precision arithmetic, doubling the element count
20
- * per vector register (8x F16 vs 4x F32). For spatial distance computations like L2 and angular
21
- * distance, this halves memory bandwidth requirements.
22
- *
23
- * Inputs are widened from F16 to F32 for accumulation via FCVTL to preserve numerical precision
24
- * during the squared difference summation. The subtraction and FMA operations use F32 precision
25
- * in the accumulator to avoid catastrophic cancellation in distance computations.
26
- */
27
- #ifndef NK_SPATIAL_NEONHALF_H
28
- #define NK_SPATIAL_NEONHALF_H
29
-
30
- #if NK_TARGET_ARM_
31
- #if NK_TARGET_NEONHALF
32
-
33
- #include "numkong/types.h"
34
- #include "numkong/cast/serial.h" // `nk_partial_load_b16x4_serial_`
35
- #include "numkong/spatial/neon.h" // `nk_angular_normalize_f32_neon_`, `nk_f32_sqrt_neon`
36
-
37
- #if defined(__cplusplus)
38
- extern "C" {
39
- #endif
40
-
41
- #if defined(__clang__)
42
- #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
43
- #elif defined(__GNUC__)
44
- #pragma GCC push_options
45
- #pragma GCC target("arch=armv8.2-a+simd+fp16")
46
- #endif
47
-
48
- NK_PUBLIC void nk_sqeuclidean_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
49
- float32x4_t a_f32x4, b_f32x4;
50
- float32x4_t distance_sq_f32x4 = vdupq_n_f32(0);
51
-
52
- nk_sqeuclidean_f16_neonhalf_cycle:
53
- if (n < 4) {
54
- nk_b64_vec_t a_vec, b_vec;
55
- nk_partial_load_b16x4_serial_(a, &a_vec, n);
56
- nk_partial_load_b16x4_serial_(b, &b_vec, n);
57
- a_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(a_vec.u16x4));
58
- b_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(b_vec.u16x4));
59
- n = 0;
60
- }
61
- else {
62
- a_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)a));
63
- b_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)b));
64
- n -= 4, a += 4, b += 4;
65
- }
66
- float32x4_t diff_f32x4 = vsubq_f32(a_f32x4, b_f32x4);
67
- distance_sq_f32x4 = vfmaq_f32(distance_sq_f32x4, diff_f32x4, diff_f32x4);
68
- if (n) goto nk_sqeuclidean_f16_neonhalf_cycle;
69
-
70
- *result = vaddvq_f32(distance_sq_f32x4);
71
- }
72
- NK_PUBLIC void nk_euclidean_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
73
- nk_sqeuclidean_f16_neonhalf(a, b, n, result);
74
- *result = nk_f32_sqrt_neon(*result);
75
- }
76
-
77
- NK_PUBLIC void nk_angular_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
78
- float32x4_t dot_product_f32x4 = vdupq_n_f32(0), a_norm_sq_f32x4 = vdupq_n_f32(0), b_norm_sq_f32x4 = vdupq_n_f32(0);
79
- float32x4_t a_f32x4, b_f32x4;
80
-
81
- nk_angular_f16_neonhalf_cycle:
82
- if (n < 4) {
83
- nk_b64_vec_t a_vec, b_vec;
84
- nk_partial_load_b16x4_serial_(a, &a_vec, n);
85
- nk_partial_load_b16x4_serial_(b, &b_vec, n);
86
- a_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(a_vec.u16x4));
87
- b_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(b_vec.u16x4));
88
- n = 0;
89
- }
90
- else {
91
- a_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)a));
92
- b_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)b));
93
- n -= 4, a += 4, b += 4;
94
- }
95
- dot_product_f32x4 = vfmaq_f32(dot_product_f32x4, a_f32x4, b_f32x4);
96
- a_norm_sq_f32x4 = vfmaq_f32(a_norm_sq_f32x4, a_f32x4, a_f32x4);
97
- b_norm_sq_f32x4 = vfmaq_f32(b_norm_sq_f32x4, b_f32x4, b_f32x4);
98
- if (n) goto nk_angular_f16_neonhalf_cycle;
99
-
100
- nk_f32_t dot_product_f32 = vaddvq_f32(dot_product_f32x4);
101
- nk_f32_t a_norm_sq_f32 = vaddvq_f32(a_norm_sq_f32x4);
102
- nk_f32_t b_norm_sq_f32 = vaddvq_f32(b_norm_sq_f32x4);
103
- *result = nk_angular_normalize_f32_neon_(dot_product_f32, a_norm_sq_f32, b_norm_sq_f32);
104
- }
105
-
106
- #if defined(__clang__)
107
- #pragma clang attribute pop
108
- #elif defined(__GNUC__)
109
- #pragma GCC pop_options
110
- #endif
111
-
112
- #if defined(__cplusplus)
113
- } // extern "C"
114
- #endif
115
-
116
- #endif // NK_TARGET_NEONHALF
117
- #endif // NK_TARGET_ARM_
118
- #endif // NK_SPATIAL_NEONHALF_H
@@ -1,343 +0,0 @@
1
- /**
2
- * @brief SIMD-accelerated Spatial Similarity Measures for Sapphire Rapids.
3
- * @file include/numkong/spatial/sapphire.h
4
- * @author Ash Vardanian
5
- * @date December 27, 2025
6
- *
7
- * @sa include/numkong/spatial.h
8
- *
9
- * Sapphire Rapids adds native FP16 support via AVX-512 FP16 extension.
10
- * For e4m3 L2 distance, we can leverage F16 for the subtraction step:
11
- * - e4m3 differences fit in F16 (max |a−b| = 896 < 65504)
12
- * - But squared differences overflow F16 (896² = 802816 > 65504)
13
- * - So: subtract in F16, convert to F32, then square and accumulate
14
- *
15
- * For e2m3/e3m2 L2 distance, squared differences fit in FP16:
16
- * - E2M3: max |a−b| = 15, max (a−b)² = 225 < 65504, flush cadence = 4 (conservative for uniformity)
17
- * - E3M2: max |a−b| = 56, max (a−b)² = 3136 < 65504, flush cadence = 4
18
- * So the entire sub+square+accumulate stays in FP16 with periodic F32 flush.
19
- *
20
- * @section spatial_sapphire_instructions Relevant Instructions
21
- *
22
- * Intrinsic Instruction Sapphire Genoa
23
- * _mm256_sub_ph VSUBPH (YMM, YMM, YMM) 4cy @ p05 3cy @ p01
24
- * _mm512_cvtph_ps VCVTPH2PS (ZMM, YMM) 5cy @ p05 5cy @ p01
25
- * _mm512_fmadd_ps VFMADD (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p01
26
- * _mm512_reduce_add_ps (pseudo: VHADDPS chain) ~8cy ~8cy
27
- * _mm_maskz_loadu_epi8 VMOVDQU8 (XMM {K}, M128) 7cy @ p23 7cy @ p23
28
- */
29
- #ifndef NK_SPATIAL_SAPPHIRE_H
30
- #define NK_SPATIAL_SAPPHIRE_H
31
-
32
- #if NK_TARGET_X86_
33
- #if NK_TARGET_SAPPHIRE
34
-
35
- #include "numkong/types.h"
36
- #include "numkong/cast/sapphire.h" // `nk_e4m3x16_to_f16x16_sapphire_`
37
- #include "numkong/dot/sapphire.h" // `nk_e2m3x32_to_f16x32_sapphire_`, `nk_flush_f16_to_f32_sapphire_`
38
- #include "numkong/spatial/haswell.h" // `nk_angular_normalize_f32_haswell_`, `nk_f32_sqrt_haswell`
39
-
40
- #if defined(__cplusplus)
41
- extern "C" {
42
- #endif
43
-
44
- #if defined(__clang__)
45
- #pragma clang attribute push( \
46
- __attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,f16c,fma,bmi,bmi2"))), \
47
- apply_to = function)
48
- #elif defined(__GNUC__)
49
- #pragma GCC push_options
50
- #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "f16c", "fma", "bmi", "bmi2")
51
- #endif
52
-
53
- NK_PUBLIC void nk_sqeuclidean_e4m3_sapphire(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars,
54
- nk_size_t count_scalars, nk_f32_t *result) {
55
- __m512 sum_f32x16 = _mm512_setzero_ps();
56
-
57
- while (count_scalars > 0) {
58
- nk_size_t const n = count_scalars < 16 ? count_scalars : 16;
59
- __mmask16 const mask = (__mmask16)_bzhi_u32(0xFFFF, n);
60
- __m128i a_e4m3x16 = _mm_maskz_loadu_epi8(mask, a_scalars);
61
- __m128i b_e4m3x16 = _mm_maskz_loadu_epi8(mask, b_scalars);
62
-
63
- // Convert e4m3 → f16
64
- __m256h a_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(a_e4m3x16);
65
- __m256h b_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(b_e4m3x16);
66
-
67
- // Subtract in F16 − differences fit (max 896 < 65504)
68
- __m256h diff_f16x16 = _mm256_sub_ph(a_f16x16, b_f16x16);
69
-
70
- // Convert to F32 before squaring (896² = 802816 overflows F16!)
71
- __m512 diff_f32x16 = _mm512_cvtph_ps(_mm256_castph_si256(diff_f16x16));
72
-
73
- // Square and accumulate in F32
74
- sum_f32x16 = _mm512_fmadd_ps(diff_f32x16, diff_f32x16, sum_f32x16);
75
- a_scalars += n, b_scalars += n, count_scalars -= n;
76
- }
77
-
78
- *result = _mm512_reduce_add_ps(sum_f32x16);
79
- }
80
-
81
- NK_PUBLIC void nk_euclidean_e4m3_sapphire(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars,
82
- nk_size_t count_scalars, nk_f32_t *result) {
83
- nk_sqeuclidean_e4m3_sapphire(a_scalars, b_scalars, count_scalars, result);
84
- *result = nk_f32_sqrt_haswell(*result);
85
- }
86
-
87
- NK_PUBLIC void nk_sqeuclidean_e2m3_sapphire(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars,
88
- nk_size_t count_scalars, nk_f32_t *result) {
89
- __m512 sum_f32x16 = _mm512_setzero_ps();
90
-
91
- // Main loop: 4-way unrolled, 128 elements per flush
92
- while (count_scalars >= 128) {
93
- __m512h acc_f16x32 = _mm512_setzero_ph();
94
- __m512h a_f16x32, b_f16x32, diff_f16x32;
95
- // Iteration 1
96
- a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars));
97
- b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars));
98
- diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
99
- acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
100
- // Iteration 2
101
- a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 32));
102
- b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 32));
103
- diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
104
- acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
105
- // Iteration 3
106
- a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 64));
107
- b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 64));
108
- diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
109
- acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
110
- // Iteration 4
111
- a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 96));
112
- b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 96));
113
- diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
114
- acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
115
- // Flush to F32
116
- sum_f32x16 = nk_flush_f16_to_f32_sapphire_(acc_f16x32, sum_f32x16);
117
- a_scalars += 128, b_scalars += 128, count_scalars -= 128;
118
- }
119
-
120
- // Tail: remaining 0–127 elements, 32 at a time via masked loads
121
- __m512h acc_f16x32 = _mm512_setzero_ph();
122
- while (count_scalars > 0) {
123
- nk_size_t const n = count_scalars < 32 ? count_scalars : 32;
124
- __mmask32 const mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
125
- __m512h a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, a_scalars));
126
- __m512h b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, b_scalars));
127
- __m512h diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
128
- acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
129
- a_scalars += n, b_scalars += n, count_scalars -= n;
130
- }
131
- sum_f32x16 = nk_flush_f16_to_f32_sapphire_(acc_f16x32, sum_f32x16);
132
-
133
- *result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
134
- }
135
-
136
- NK_PUBLIC void nk_sqeuclidean_e3m2_sapphire(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars,
137
- nk_size_t count_scalars, nk_f32_t *result) {
138
- __m512 sum_f32x16 = _mm512_setzero_ps();
139
-
140
- // Main loop: 4-way unrolled, 128 elements per flush
141
- while (count_scalars >= 128) {
142
- __m512h acc_f16x32 = _mm512_setzero_ph();
143
- __m512h a_f16x32, b_f16x32, diff_f16x32;
144
- // Iteration 1
145
- a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars));
146
- b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars));
147
- diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
148
- acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
149
- // Iteration 2
150
- a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 32));
151
- b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 32));
152
- diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
153
- acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
154
- // Iteration 3
155
- a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 64));
156
- b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 64));
157
- diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
158
- acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
159
- // Iteration 4
160
- a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 96));
161
- b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 96));
162
- diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
163
- acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
164
- // Flush to F32
165
- sum_f32x16 = nk_flush_f16_to_f32_sapphire_(acc_f16x32, sum_f32x16);
166
- a_scalars += 128, b_scalars += 128, count_scalars -= 128;
167
- }
168
-
169
- // Tail: remaining 0–127 elements, 32 at a time via masked loads
170
- __m512h acc_f16x32 = _mm512_setzero_ph();
171
- while (count_scalars > 0) {
172
- nk_size_t const n = count_scalars < 32 ? count_scalars : 32;
173
- __mmask32 const mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
174
- __m512h a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, a_scalars));
175
- __m512h b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, b_scalars));
176
- __m512h diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
177
- acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
178
- a_scalars += n, b_scalars += n, count_scalars -= n;
179
- }
180
- sum_f32x16 = nk_flush_f16_to_f32_sapphire_(acc_f16x32, sum_f32x16);
181
-
182
- *result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
183
- }
184
-
185
- NK_PUBLIC void nk_euclidean_e2m3_sapphire(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars,
186
- nk_size_t count_scalars, nk_f32_t *result) {
187
- nk_sqeuclidean_e2m3_sapphire(a_scalars, b_scalars, count_scalars, result);
188
- *result = nk_f32_sqrt_haswell(*result);
189
- }
190
-
191
- NK_PUBLIC void nk_euclidean_e3m2_sapphire(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars,
192
- nk_size_t count_scalars, nk_f32_t *result) {
193
- nk_sqeuclidean_e3m2_sapphire(a_scalars, b_scalars, count_scalars, result);
194
- *result = nk_f32_sqrt_haswell(*result);
195
- }
196
-
197
- NK_PUBLIC void nk_angular_e2m3_sapphire(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
198
- nk_f32_t *result) {
199
- __m512 sum_dot_f32x16 = _mm512_setzero_ps();
200
- __m512 sum_a_f32x16 = _mm512_setzero_ps();
201
- __m512 sum_b_f32x16 = _mm512_setzero_ps();
202
-
203
- // Main loop: 4-way unrolled, 128 elements per flush
204
- while (count_scalars >= 128) {
205
- __m512h dot_acc = _mm512_setzero_ph();
206
- __m512h a_norm_acc = _mm512_setzero_ph();
207
- __m512h b_norm_acc = _mm512_setzero_ph();
208
- __m512h a_f16x32, b_f16x32;
209
- // Iteration 1
210
- a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars));
211
- b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars));
212
- dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
213
- a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
214
- b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
215
- // Iteration 2
216
- a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 32));
217
- b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 32));
218
- dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
219
- a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
220
- b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
221
- // Iteration 3
222
- a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 64));
223
- b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 64));
224
- dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
225
- a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
226
- b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
227
- // Iteration 4
228
- a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 96));
229
- b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 96));
230
- dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
231
- a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
232
- b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
233
- // Flush to F32
234
- sum_dot_f32x16 = nk_flush_f16_to_f32_sapphire_(dot_acc, sum_dot_f32x16);
235
- sum_a_f32x16 = nk_flush_f16_to_f32_sapphire_(a_norm_acc, sum_a_f32x16);
236
- sum_b_f32x16 = nk_flush_f16_to_f32_sapphire_(b_norm_acc, sum_b_f32x16);
237
- a_scalars += 128, b_scalars += 128, count_scalars -= 128;
238
- }
239
-
240
- // Tail: remaining 0–127 elements, 32 at a time via masked loads
241
- __m512h dot_acc = _mm512_setzero_ph();
242
- __m512h a_norm_acc = _mm512_setzero_ph();
243
- __m512h b_norm_acc = _mm512_setzero_ph();
244
- while (count_scalars > 0) {
245
- nk_size_t const n = count_scalars < 32 ? count_scalars : 32;
246
- __mmask32 const mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
247
- __m512h a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, a_scalars));
248
- __m512h b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, b_scalars));
249
- dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
250
- a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
251
- b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
252
- a_scalars += n, b_scalars += n, count_scalars -= n;
253
- }
254
- sum_dot_f32x16 = nk_flush_f16_to_f32_sapphire_(dot_acc, sum_dot_f32x16);
255
- sum_a_f32x16 = nk_flush_f16_to_f32_sapphire_(a_norm_acc, sum_a_f32x16);
256
- sum_b_f32x16 = nk_flush_f16_to_f32_sapphire_(b_norm_acc, sum_b_f32x16);
257
-
258
- nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(sum_dot_f32x16);
259
- nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(sum_a_f32x16);
260
- nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(sum_b_f32x16);
261
- *result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
262
- }
263
-
264
- NK_PUBLIC void nk_angular_e3m2_sapphire(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars, nk_size_t count_scalars,
265
- nk_f32_t *result) {
266
- __m512 sum_dot_f32x16 = _mm512_setzero_ps();
267
- __m512 sum_a_f32x16 = _mm512_setzero_ps();
268
- __m512 sum_b_f32x16 = _mm512_setzero_ps();
269
-
270
- // Main loop: 4-way unrolled, 128 elements per flush
271
- while (count_scalars >= 128) {
272
- __m512h dot_acc = _mm512_setzero_ph();
273
- __m512h a_norm_acc = _mm512_setzero_ph();
274
- __m512h b_norm_acc = _mm512_setzero_ph();
275
- __m512h a_f16x32, b_f16x32;
276
- // Iteration 1
277
- a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars));
278
- b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars));
279
- dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
280
- a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
281
- b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
282
- // Iteration 2
283
- a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 32));
284
- b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 32));
285
- dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
286
- a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
287
- b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
288
- // Iteration 3
289
- a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 64));
290
- b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 64));
291
- dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
292
- a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
293
- b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
294
- // Iteration 4
295
- a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 96));
296
- b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 96));
297
- dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
298
- a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
299
- b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
300
- // Flush to F32
301
- sum_dot_f32x16 = nk_flush_f16_to_f32_sapphire_(dot_acc, sum_dot_f32x16);
302
- sum_a_f32x16 = nk_flush_f16_to_f32_sapphire_(a_norm_acc, sum_a_f32x16);
303
- sum_b_f32x16 = nk_flush_f16_to_f32_sapphire_(b_norm_acc, sum_b_f32x16);
304
- a_scalars += 128, b_scalars += 128, count_scalars -= 128;
305
- }
306
-
307
- // Tail: remaining 0–127 elements, 32 at a time via masked loads
308
- __m512h dot_acc = _mm512_setzero_ph();
309
- __m512h a_norm_acc = _mm512_setzero_ph();
310
- __m512h b_norm_acc = _mm512_setzero_ph();
311
- while (count_scalars > 0) {
312
- nk_size_t const n = count_scalars < 32 ? count_scalars : 32;
313
- __mmask32 const mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
314
- __m512h a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, a_scalars));
315
- __m512h b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, b_scalars));
316
- dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
317
- a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
318
- b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
319
- a_scalars += n, b_scalars += n, count_scalars -= n;
320
- }
321
- sum_dot_f32x16 = nk_flush_f16_to_f32_sapphire_(dot_acc, sum_dot_f32x16);
322
- sum_a_f32x16 = nk_flush_f16_to_f32_sapphire_(a_norm_acc, sum_a_f32x16);
323
- sum_b_f32x16 = nk_flush_f16_to_f32_sapphire_(b_norm_acc, sum_b_f32x16);
324
-
325
- nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(sum_dot_f32x16);
326
- nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(sum_a_f32x16);
327
- nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(sum_b_f32x16);
328
- *result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
329
- }
330
-
331
- #if defined(__clang__)
332
- #pragma clang attribute pop
333
- #elif defined(__GNUC__)
334
- #pragma GCC pop_options
335
- #endif
336
-
337
- #if defined(__cplusplus)
338
- } // extern "C"
339
- #endif
340
-
341
- #endif // NK_TARGET_SAPPHIRE
342
- #endif // NK_TARGET_X86_
343
- #endif // NK_SPATIAL_SAPPHIRE_H
@@ -1,58 +0,0 @@
1
- /**
2
- * @brief Batched Spatial Distances for NEON FP16 (Half-Precision).
3
- * @file include/numkong/spatials/neonhalf.h
4
- * @author Ash Vardanian
5
- * @date February 23, 2026
6
- *
7
- * @sa include/numkong/spatials.h
8
- */
9
- #ifndef NK_SPATIALS_NEONHALF_H
10
- #define NK_SPATIALS_NEONHALF_H
11
-
12
- #if NK_TARGET_ARM_
13
- #if NK_TARGET_NEONHALF
14
-
15
- #include "numkong/spatial/neon.h"
16
- #include "numkong/dots/neonhalf.h"
17
-
18
- #if defined(__cplusplus)
19
- extern "C" {
20
- #endif
21
-
22
- #if defined(__clang__)
23
- #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
24
- #elif defined(__GNUC__)
25
- #pragma GCC push_options
26
- #pragma GCC target("arch=armv8.2-a+simd+fp16")
27
- #endif
28
-
29
- nk_define_cross_normalized_packed_(angular, f16, neonhalf, f16, f16, f32, /*norm_value_type=*/f32, f32, nk_b128_vec_t,
30
- nk_dots_packed_f16_neonhalf, nk_angular_through_f32_from_dot_neon_,
31
- nk_dots_reduce_sumsq_f16_, nk_load_b128_neon_, nk_partial_load_b32x4_serial_,
32
- nk_store_b128_neon_, nk_partial_store_b32x4_serial_, 1)
33
- nk_define_cross_normalized_packed_(euclidean, f16, neonhalf, f16, f16, f32, /*norm_value_type=*/f32, f32, nk_b128_vec_t,
34
- nk_dots_packed_f16_neonhalf, nk_euclidean_through_f32_from_dot_neon_,
35
- nk_dots_reduce_sumsq_f16_, nk_load_b128_neon_, nk_partial_load_b32x4_serial_,
36
- nk_store_b128_neon_, nk_partial_store_b32x4_serial_, 1)
37
- nk_define_cross_normalized_symmetric_(angular, f16, neonhalf, f16, f32, /*norm_value_type=*/f32, f32, nk_b128_vec_t,
38
- nk_dots_symmetric_f16_neonhalf, nk_angular_through_f32_from_dot_neon_,
39
- nk_dots_reduce_sumsq_f16_, nk_load_b128_neon_, nk_partial_load_b32x4_serial_,
40
- nk_store_b128_neon_, nk_partial_store_b32x4_serial_, 1)
41
- nk_define_cross_normalized_symmetric_(euclidean, f16, neonhalf, f16, f32, /*norm_value_type=*/f32, f32, nk_b128_vec_t,
42
- nk_dots_symmetric_f16_neonhalf, nk_euclidean_through_f32_from_dot_neon_,
43
- nk_dots_reduce_sumsq_f16_, nk_load_b128_neon_, nk_partial_load_b32x4_serial_,
44
- nk_store_b128_neon_, nk_partial_store_b32x4_serial_, 1)
45
-
46
- #if defined(__clang__)
47
- #pragma clang attribute pop
48
- #elif defined(__GNUC__)
49
- #pragma GCC pop_options
50
- #endif
51
-
52
- #if defined(__cplusplus)
53
- } // extern "C"
54
- #endif
55
-
56
- #endif // NK_TARGET_NEONHALF
57
- #endif // NK_TARGET_ARM_
58
- #endif // NK_SPATIALS_NEONHALF_H