numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315):
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -139,74 +139,6 @@ nk_angular_bf16_genoa_cycle:
139
139
  *result = nk_angular_normalize_f32_haswell_(dot_product_f32, a_norm_sq_f32, b_norm_sq_f32);
140
140
  }
141
141
 
142
- NK_PUBLIC void nk_sqeuclidean_e4m3_genoa(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
143
- __m512 a_sq_f32x16 = _mm512_setzero_ps();
144
- __m512 b_sq_f32x16 = _mm512_setzero_ps();
145
- __m512 ab_f32x16 = _mm512_setzero_ps();
146
- __m256i a_e4m3x32, b_e4m3x32;
147
-
148
- nk_sqeuclidean_e4m3_genoa_cycle:
149
- if (n < 32) {
150
- __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
151
- a_e4m3x32 = _mm256_maskz_loadu_epi8(mask, a);
152
- b_e4m3x32 = _mm256_maskz_loadu_epi8(mask, b);
153
- n = 0;
154
- }
155
- else {
156
- a_e4m3x32 = _mm256_loadu_epi8(a);
157
- b_e4m3x32 = _mm256_loadu_epi8(b);
158
- a += 32, b += 32, n -= 32;
159
- }
160
- __m512i a_bf16x32 = nk_e4m3x32_to_bf16x32_icelake_(a_e4m3x32);
161
- __m512i b_bf16x32 = nk_e4m3x32_to_bf16x32_icelake_(b_e4m3x32);
162
- a_sq_f32x16 = _mm512_dpbf16_ps(a_sq_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(a_bf16x32));
163
- b_sq_f32x16 = _mm512_dpbf16_ps(b_sq_f32x16, nk_m512bh_from_m512i_(b_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
164
- ab_f32x16 = _mm512_dpbf16_ps(ab_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
165
- if (n) goto nk_sqeuclidean_e4m3_genoa_cycle;
166
-
167
- // (a-b)² = a² + b² - 2ab
168
- __m512 sum_sq_f32x16 = _mm512_add_ps(a_sq_f32x16, b_sq_f32x16);
169
- *result = nk_reduce_add_f32x16_skylake_(_mm512_fnmadd_ps(_mm512_set1_ps(2.0f), ab_f32x16, sum_sq_f32x16));
170
- }
171
-
172
- NK_PUBLIC void nk_euclidean_e4m3_genoa(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
173
- nk_sqeuclidean_e4m3_genoa(a, b, n, result);
174
- *result = nk_f32_sqrt_haswell(*result);
175
- }
176
-
177
- NK_PUBLIC void nk_angular_e4m3_genoa(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
178
- __m512 dot_f32x16 = _mm512_setzero_ps();
179
- __m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
180
- __m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
181
- __m256i a_e4m3x32, b_e4m3x32;
182
-
183
- nk_angular_e4m3_genoa_cycle:
184
- if (n < 32) {
185
- __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
186
- a_e4m3x32 = _mm256_maskz_loadu_epi8(mask, a);
187
- b_e4m3x32 = _mm256_maskz_loadu_epi8(mask, b);
188
- n = 0;
189
- }
190
- else {
191
- a_e4m3x32 = _mm256_loadu_epi8(a);
192
- b_e4m3x32 = _mm256_loadu_epi8(b);
193
- a += 32, b += 32, n -= 32;
194
- }
195
- __m512i a_bf16x32 = nk_e4m3x32_to_bf16x32_icelake_(a_e4m3x32);
196
- __m512i b_bf16x32 = nk_e4m3x32_to_bf16x32_icelake_(b_e4m3x32);
197
- dot_f32x16 = _mm512_dpbf16_ps(dot_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
198
- a_norm_sq_f32x16 = _mm512_dpbf16_ps(a_norm_sq_f32x16, nk_m512bh_from_m512i_(a_bf16x32),
199
- nk_m512bh_from_m512i_(a_bf16x32));
200
- b_norm_sq_f32x16 = _mm512_dpbf16_ps(b_norm_sq_f32x16, nk_m512bh_from_m512i_(b_bf16x32),
201
- nk_m512bh_from_m512i_(b_bf16x32));
202
- if (n) goto nk_angular_e4m3_genoa_cycle;
203
-
204
- nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(dot_f32x16);
205
- nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(a_norm_sq_f32x16);
206
- nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(b_norm_sq_f32x16);
207
- *result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
208
- }
209
-
210
142
  NK_PUBLIC void nk_sqeuclidean_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
211
143
  __m512 a_sq_f32x16 = _mm512_setzero_ps();
212
144
  __m512 b_sq_f32x16 = _mm512_setzero_ps();
@@ -8,14 +8,14 @@
8
8
  *
9
9
  * @section spatial_haswell_instructions Key AVX2 Spatial Instructions
10
10
  *
11
- * Intrinsic Instruction Latency Throughput Ports
12
- * _mm256_fmadd_ps VFMADD (YMM, YMM, YMM) 5cy 0.5/cy p01
13
- * _mm256_mul_ps VMULPS (YMM, YMM, YMM) 5cy 0.5/cy p01
14
- * _mm256_add_ps VADDPS (YMM, YMM, YMM) 3cy 1/cy p01
15
- * _mm256_sub_ps VSUBPS (YMM, YMM, YMM) 3cy 1/cy p01
16
- * _mm_rsqrt_ps VRSQRTPS (XMM, XMM) 5cy 1/cy p0
17
- * _mm_sqrt_ps VSQRTPS (XMM, XMM) 11cy 7cy p0
18
- * _mm256_sqrt_ps VSQRTPS (YMM, YMM) 12cy 14cy p0
11
+ * Intrinsic Instruction Haswell Genoa
12
+ * _mm256_fmadd_ps VFMADD (YMM, YMM, YMM) 5cy @ p01 4cy @ p01
13
+ * _mm256_mul_ps VMULPS (YMM, YMM, YMM) 5cy @ p01 3cy @ p01
14
+ * _mm256_add_ps VADDPS (YMM, YMM, YMM) 3cy @ p01 3cy @ p23
15
+ * _mm256_sub_ps VSUBPS (YMM, YMM, YMM) 3cy @ p01 3cy @ p23
16
+ * _mm_rsqrt_ps VRSQRTPS (XMM, XMM) 5cy @ p0 4cy @ p01
17
+ * _mm_sqrt_ps VSQRTPS (XMM, XMM) 11cy @ p0 15cy @ p01
18
+ * _mm256_sqrt_ps VSQRTPS (YMM, YMM) 19cy @ p0 15cy @ p01
19
19
  *
20
20
  * For angular distance normalization, `_mm_rsqrt_ps` provides ~12-bit precision (1.5 x 2⁻¹² error).
21
21
  * Newton-Raphson refinement doubles precision to ~22-24 bits, sufficient for f32. For f64 we use
@@ -52,7 +52,7 @@ NK_INTERNAL __m128 nk_rsqrt_f32x4_haswell_(__m128 x) {
52
52
  }
53
53
 
54
54
  /** @brief Safe square root of 4 floats with zero-clamping for numerical stability. */
55
- NK_INTERNAL __m128 nk_safe_sqrt_f32x4_haswell_(__m128 x) { return _mm_sqrt_ps(_mm_max_ps(x, _mm_setzero_ps())); }
55
+ NK_INTERNAL __m128 nk_sqrt_f32x4_haswell_(__m128 x) { return _mm_sqrt_ps(_mm_max_ps(x, _mm_setzero_ps())); }
56
56
 
57
57
  /** @brief Angular from_dot: computes 1 − dot × rsqrt(query_sumsq × target_sumsq) for 4 pairs. */
58
58
  NK_INTERNAL void nk_angular_through_f32_from_dot_haswell_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
@@ -73,7 +73,7 @@ NK_INTERNAL void nk_euclidean_through_f32_from_dot_haswell_(nk_b128_vec_t dots,
73
73
  __m128 query_sumsq_f32x4 = _mm_set1_ps(query_sumsq);
74
74
  __m128 sum_sq_f32x4 = _mm_add_ps(query_sumsq_f32x4, target_sumsqs.xmm_ps);
75
75
  __m128 dist_sq_f32x4 = _mm_fnmadd_ps(_mm_set1_ps(2.0f), dots_f32x4, sum_sq_f32x4);
76
- results->xmm_ps = nk_safe_sqrt_f32x4_haswell_(dist_sq_f32x4);
76
+ results->xmm_ps = nk_sqrt_f32x4_haswell_(dist_sq_f32x4);
77
77
  }
78
78
 
79
79
  /** @brief Angular from_dot for native f64: 1 − dot / √(query_sumsq × target_sumsq) for 4 pairs. */
@@ -117,7 +117,7 @@ NK_INTERNAL void nk_euclidean_through_i32_from_dot_haswell_(nk_b128_vec_t dots,
117
117
  __m128 query_sumsq_f32x4 = _mm_set1_ps((nk_f32_t)query_sumsq);
118
118
  __m128 sum_sq_f32x4 = _mm_add_ps(query_sumsq_f32x4, _mm_cvtepi32_ps(target_sumsqs.xmm));
119
119
  __m128 dist_sq_f32x4 = _mm_fnmadd_ps(_mm_set1_ps(2.0f), dots_f32x4, sum_sq_f32x4);
120
- results->xmm_ps = nk_safe_sqrt_f32x4_haswell_(dist_sq_f32x4);
120
+ results->xmm_ps = nk_sqrt_f32x4_haswell_(dist_sq_f32x4);
121
121
  }
122
122
 
123
123
  /** @brief Angular from_dot for u32 accumulators: cast to f32, rsqrt+NR, clamp. 4 pairs. */
@@ -139,7 +139,7 @@ NK_INTERNAL void nk_euclidean_through_u32_from_dot_haswell_(nk_b128_vec_t dots,
139
139
  __m128 query_sumsq_f32x4 = _mm_set1_ps((nk_f32_t)query_sumsq);
140
140
  __m128 sum_sq_f32x4 = _mm_add_ps(query_sumsq_f32x4, _mm_cvtepi32_ps(target_sumsqs.xmm));
141
141
  __m128 dist_sq_f32x4 = _mm_fnmadd_ps(_mm_set1_ps(2.0f), dots_f32x4, sum_sq_f32x4);
142
- results->xmm_ps = nk_safe_sqrt_f32x4_haswell_(dist_sq_f32x4);
142
+ results->xmm_ps = nk_sqrt_f32x4_haswell_(dist_sq_f32x4);
143
143
  }
144
144
 
145
145
  NK_INTERNAL nk_f64_t nk_angular_normalize_f64_haswell_(nk_f64_t ab, nk_f64_t a2, nk_f64_t b2) {
@@ -173,28 +173,30 @@ NK_INTERNAL nk_f32_t nk_angular_normalize_f32_haswell_(nk_f32_t ab, nk_f32_t a2,
173
173
  else if (ab == 0.0f) return 1.0f;
174
174
 
175
175
  // Load the squares into an __m128 register for single-precision floating-point operations
176
- __m128 squares = _mm_set_ps(a2, b2, a2, b2); // We replicate to make use of full register
176
+ __m128 squares_f32x4 = _mm_set_ps(a2, b2, a2, b2); // We replicate to make use of full register
177
177
 
178
178
  // Compute the reciprocal square root of the squares using `_mm_rsqrt_ps` (single-precision)
179
- __m128 rsqrts = _mm_rsqrt_ps(squares);
179
+ __m128 rsqrts_f32x4 = _mm_rsqrt_ps(squares_f32x4);
180
180
 
181
181
  // Perform one iteration of Newton-Raphson refinement to improve the precision of rsqrt:
182
182
  // Formula: y' = y × (1.5 - 0.5 × x × y × y)
183
- __m128 half = _mm_set1_ps(0.5f);
184
- __m128 three_halves = _mm_set1_ps(1.5f);
185
- rsqrts = _mm_mul_ps(rsqrts,
186
- _mm_sub_ps(three_halves, _mm_mul_ps(half, _mm_mul_ps(squares, _mm_mul_ps(rsqrts, rsqrts)))));
183
+ __m128 half_f32x4 = _mm_set1_ps(0.5f);
184
+ __m128 three_halves_f32x4 = _mm_set1_ps(1.5f);
185
+ rsqrts_f32x4 = _mm_mul_ps(
186
+ rsqrts_f32x4,
187
+ _mm_sub_ps(three_halves_f32x4,
188
+ _mm_mul_ps(half_f32x4, _mm_mul_ps(squares_f32x4, _mm_mul_ps(rsqrts_f32x4, rsqrts_f32x4)))));
187
189
 
188
190
  // Extract the reciprocal square roots of a2 and b2 from the __m128 register
189
- nk_f32_t a2_reciprocal = _mm_cvtss_f32(_mm_shuffle_ps(rsqrts, rsqrts, _MM_SHUFFLE(0, 0, 0, 1)));
190
- nk_f32_t b2_reciprocal = _mm_cvtss_f32(rsqrts);
191
+ nk_f32_t a2_reciprocal = _mm_cvtss_f32(_mm_shuffle_ps(rsqrts_f32x4, rsqrts_f32x4, _MM_SHUFFLE(0, 0, 0, 1)));
192
+ nk_f32_t b2_reciprocal = _mm_cvtss_f32(rsqrts_f32x4);
191
193
 
192
194
  // Calculate the angular distance: 1 - dot_product × a2_reciprocal × b2_reciprocal
193
195
  nk_f32_t result = 1.0f - ab * a2_reciprocal * b2_reciprocal;
194
196
  return result > 0 ? result : 0;
195
197
  }
196
198
 
197
- #pragma region - Smaller Floats
199
+ #pragma region F16 and BF16 Floats
198
200
 
199
201
  NK_PUBLIC void nk_sqeuclidean_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
200
202
  __m256 a_f32x8, b_f32x8;
@@ -257,25 +259,32 @@ nk_angular_f16_haswell_cycle:
257
259
  }
258
260
 
259
261
  NK_PUBLIC void nk_sqeuclidean_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
260
- __m256 a_f32x8, b_f32x8;
262
+ __m256i a_bf16_i16x16, b_bf16_i16x16;
261
263
  __m256 distance_sq_f32x8 = _mm256_setzero_ps();
264
+ __m256i mask_high_u32x8 = _mm256_set1_epi32((int)0xFFFF0000);
262
265
 
263
266
  nk_sqeuclidean_bf16_haswell_cycle:
264
- if (n < 8) {
267
+ if (n < 16) {
265
268
  nk_b256_vec_t a_vec, b_vec;
266
- nk_partial_load_bf16x8_to_f32x8_haswell_(a, &a_vec, n);
267
- nk_partial_load_bf16x8_to_f32x8_haswell_(b, &b_vec, n);
268
- a_f32x8 = a_vec.ymm_ps;
269
- b_f32x8 = b_vec.ymm_ps;
269
+ nk_partial_load_b16x16_serial_(a, &a_vec, n);
270
+ nk_partial_load_b16x16_serial_(b, &b_vec, n);
271
+ a_bf16_i16x16 = a_vec.ymm;
272
+ b_bf16_i16x16 = b_vec.ymm;
270
273
  n = 0;
271
274
  }
272
275
  else {
273
- a_f32x8 = nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)a));
274
- b_f32x8 = nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)b));
275
- n -= 8, a += 8, b += 8;
276
- }
277
- __m256 diff_f32x8 = _mm256_sub_ps(a_f32x8, b_f32x8);
278
- distance_sq_f32x8 = _mm256_fmadd_ps(diff_f32x8, diff_f32x8, distance_sq_f32x8);
276
+ a_bf16_i16x16 = _mm256_loadu_si256((__m256i const *)a);
277
+ b_bf16_i16x16 = _mm256_loadu_si256((__m256i const *)b);
278
+ n -= 16, a += 16, b += 16;
279
+ }
280
+ __m256 a_even_f32x8 = _mm256_castsi256_ps(_mm256_slli_epi32(a_bf16_i16x16, 16));
281
+ __m256 b_even_f32x8 = _mm256_castsi256_ps(_mm256_slli_epi32(b_bf16_i16x16, 16));
282
+ __m256 diff_even_f32x8 = _mm256_sub_ps(a_even_f32x8, b_even_f32x8);
283
+ distance_sq_f32x8 = _mm256_fmadd_ps(diff_even_f32x8, diff_even_f32x8, distance_sq_f32x8);
284
+ __m256 a_odd_f32x8 = _mm256_castsi256_ps(_mm256_and_si256(a_bf16_i16x16, mask_high_u32x8));
285
+ __m256 b_odd_f32x8 = _mm256_castsi256_ps(_mm256_and_si256(b_bf16_i16x16, mask_high_u32x8));
286
+ __m256 diff_odd_f32x8 = _mm256_sub_ps(a_odd_f32x8, b_odd_f32x8);
287
+ distance_sq_f32x8 = _mm256_fmadd_ps(diff_odd_f32x8, diff_odd_f32x8, distance_sq_f32x8);
279
288
  if (n) goto nk_sqeuclidean_bf16_haswell_cycle;
280
289
 
281
290
  *result = nk_reduce_add_f32x8_haswell_(distance_sq_f32x8);
@@ -287,27 +296,35 @@ NK_PUBLIC void nk_euclidean_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b,
287
296
  }
288
297
 
289
298
  NK_PUBLIC void nk_angular_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
290
- __m256 a_f32x8, b_f32x8;
299
+ __m256i a_bf16_i16x16, b_bf16_i16x16;
291
300
  __m256 dot_product_f32x8 = _mm256_setzero_ps(), a_norm_sq_f32x8 = _mm256_setzero_ps(),
292
301
  b_norm_sq_f32x8 = _mm256_setzero_ps();
302
+ __m256i mask_high_u32x8 = _mm256_set1_epi32((int)0xFFFF0000);
293
303
 
294
304
  nk_angular_bf16_haswell_cycle:
295
- if (n < 8) {
305
+ if (n < 16) {
296
306
  nk_b256_vec_t a_vec, b_vec;
297
- nk_partial_load_bf16x8_to_f32x8_haswell_(a, &a_vec, n);
298
- nk_partial_load_bf16x8_to_f32x8_haswell_(b, &b_vec, n);
299
- a_f32x8 = a_vec.ymm_ps;
300
- b_f32x8 = b_vec.ymm_ps;
307
+ nk_partial_load_b16x16_serial_(a, &a_vec, n);
308
+ nk_partial_load_b16x16_serial_(b, &b_vec, n);
309
+ a_bf16_i16x16 = a_vec.ymm;
310
+ b_bf16_i16x16 = b_vec.ymm;
301
311
  n = 0;
302
312
  }
303
313
  else {
304
- a_f32x8 = nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)a));
305
- b_f32x8 = nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)b));
306
- n -= 8, a += 8, b += 8;
307
- }
308
- dot_product_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, dot_product_f32x8);
309
- a_norm_sq_f32x8 = _mm256_fmadd_ps(a_f32x8, a_f32x8, a_norm_sq_f32x8);
310
- b_norm_sq_f32x8 = _mm256_fmadd_ps(b_f32x8, b_f32x8, b_norm_sq_f32x8);
314
+ a_bf16_i16x16 = _mm256_loadu_si256((__m256i const *)a);
315
+ b_bf16_i16x16 = _mm256_loadu_si256((__m256i const *)b);
316
+ n -= 16, a += 16, b += 16;
317
+ }
318
+ __m256 a_even_f32x8 = _mm256_castsi256_ps(_mm256_slli_epi32(a_bf16_i16x16, 16));
319
+ __m256 b_even_f32x8 = _mm256_castsi256_ps(_mm256_slli_epi32(b_bf16_i16x16, 16));
320
+ dot_product_f32x8 = _mm256_fmadd_ps(a_even_f32x8, b_even_f32x8, dot_product_f32x8);
321
+ a_norm_sq_f32x8 = _mm256_fmadd_ps(a_even_f32x8, a_even_f32x8, a_norm_sq_f32x8);
322
+ b_norm_sq_f32x8 = _mm256_fmadd_ps(b_even_f32x8, b_even_f32x8, b_norm_sq_f32x8);
323
+ __m256 a_odd_f32x8 = _mm256_castsi256_ps(_mm256_and_si256(a_bf16_i16x16, mask_high_u32x8));
324
+ __m256 b_odd_f32x8 = _mm256_castsi256_ps(_mm256_and_si256(b_bf16_i16x16, mask_high_u32x8));
325
+ dot_product_f32x8 = _mm256_fmadd_ps(a_odd_f32x8, b_odd_f32x8, dot_product_f32x8);
326
+ a_norm_sq_f32x8 = _mm256_fmadd_ps(a_odd_f32x8, a_odd_f32x8, a_norm_sq_f32x8);
327
+ b_norm_sq_f32x8 = _mm256_fmadd_ps(b_odd_f32x8, b_odd_f32x8, b_norm_sq_f32x8);
311
328
  if (n) goto nk_angular_bf16_haswell_cycle;
312
329
 
313
330
  nk_f32_t dot_product_f32 = nk_reduce_add_f32x8_haswell_(dot_product_f32x8);
@@ -316,8 +333,8 @@ nk_angular_bf16_haswell_cycle:
316
333
  *result = nk_angular_normalize_f32_haswell_(dot_product_f32, a_norm_sq_f32, b_norm_sq_f32);
317
334
  }
318
335
 
319
- #pragma endregion - Smaller Floats
320
- #pragma region - Small Integers
336
+ #pragma endregion F16 and BF16 Floats
337
+ #pragma region I8 and U8 Integers
321
338
 
322
339
  NK_PUBLIC void nk_sqeuclidean_i8_haswell(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result) {
323
340
  // Optimized i8 L2-squared using saturating subtract + VPMADDWD
@@ -433,7 +450,8 @@ NK_PUBLIC void nk_angular_i8_haswell(nk_i8_t const *a, nk_i8_t const *b, nk_size
433
450
  b_norm_sq_i32 += b_element_i32 * b_element_i32;
434
451
  }
435
452
 
436
- *result = nk_angular_normalize_f32_haswell_(dot_product_i32, a_norm_sq_i32, b_norm_sq_i32);
453
+ *result = nk_angular_normalize_f32_haswell_((nk_f32_t)dot_product_i32, (nk_f32_t)a_norm_sq_i32,
454
+ (nk_f32_t)b_norm_sq_i32);
437
455
  }
438
456
 
439
457
  NK_PUBLIC void nk_sqeuclidean_u8_haswell(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
@@ -539,11 +557,12 @@ NK_PUBLIC void nk_angular_u8_haswell(nk_u8_t const *a, nk_u8_t const *b, nk_size
539
557
  b_norm_sq_i32 += b_element_i32 * b_element_i32;
540
558
  }
541
559
 
542
- *result = nk_angular_normalize_f32_haswell_(dot_product_i32, a_norm_sq_i32, b_norm_sq_i32);
560
+ *result = nk_angular_normalize_f32_haswell_((nk_f32_t)dot_product_i32, (nk_f32_t)a_norm_sq_i32,
561
+ (nk_f32_t)b_norm_sq_i32);
543
562
  }
544
563
 
545
- #pragma endregion - Small Integers
546
- #pragma region - Traditional Floats
564
+ #pragma endregion I8 and U8 Integers
565
+ #pragma region F32 and F64 Floats
547
566
 
548
567
  NK_PUBLIC void nk_sqeuclidean_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
549
568
  // Upcast to f64 for higher precision accumulation
@@ -693,8 +712,8 @@ nk_angular_f64_haswell_cycle:
693
712
  nk_reduce_add_f64x4_haswell_(a_norm_sq_f64x4), nk_reduce_add_f64x4_haswell_(b_norm_sq_f64x4));
694
713
  }
695
714
 
696
- #pragma endregion - Traditional Floats
697
- #pragma region - Smaller Floats
715
+ #pragma endregion F32 and F64 Floats
716
+ #pragma region FP8 Floats
698
717
 
699
718
  NK_PUBLIC void nk_sqeuclidean_e2m3_haswell(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
700
719
  __m256 distance_sq_f32x8 = _mm256_setzero_ps();
@@ -954,7 +973,7 @@ nk_angular_e5m2_haswell_cycle:
954
973
  } // extern "C"
955
974
  #endif
956
975
 
957
- #pragma endregion - Smaller Floats
976
+ #pragma endregion FP8 Floats
958
977
  #endif // NK_TARGET_HASWELL
959
978
  #endif // NK_TARGET_X86_
960
979
  #endif // NK_SPATIAL_HASWELL_H