numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,586 @@
1
+ /**
2
+ * @brief SIMD-accelerated Spatial Similarity Measures for Ice Lake.
3
+ * @file include/numkong/spatial/icelake.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/spatial.h
8
+ *
9
+ * @section spatial_icelake_instructions Key AVX-512 VNNI Spatial Instructions
10
+ *
11
+ * Intrinsic Instruction Ice Genoa
12
+ * _mm512_dpwssd_epi32 VPDPWSSD (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
13
+ * _mm512_cvtepi8_epi16 VPMOVSXBW (ZMM, YMM) 3cy @ p5 3cy @ p12
14
+ * _mm512_sub_epi16 VPSUBW (ZMM, ZMM, ZMM) 1cy @ p05 1cy @ p0123
15
+ * _mm512_reduce_add_epi32 (pseudo: shuffle chain) ~8cy ~8cy
16
+ *
17
+ * Ice Lake's VNNI enables efficient i8 distance computations via VPDPWSSD for squared differences.
18
+ * After widening i8 to i16, the same instruction computes both multiply and horizontal pair addition.
19
+ * This approach avoids the asymmetric VPDPBUSD issues with signed values like -128.
20
+ */
21
+ #ifndef NK_SPATIAL_ICELAKE_H
22
+ #define NK_SPATIAL_ICELAKE_H
23
+
24
+ #if NK_TARGET_X86_
25
+ #if NK_TARGET_ICELAKE
26
+
27
+ #include "numkong/types.h"
28
+
29
+ #if defined(__cplusplus)
30
+ extern "C" {
31
+ #endif
32
+
33
+ #if defined(__clang__)
34
+ #pragma clang attribute push( \
35
+ __attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512vnni,f16c,fma,bmi,bmi2"))), \
36
+ apply_to = function)
37
+ #elif defined(__GNUC__)
38
+ #pragma GCC push_options
39
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vnni", "f16c", "fma", "bmi", "bmi2")
40
+ #endif
41
+
42
+ NK_PUBLIC void nk_sqeuclidean_i8_icelake(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result) {
43
+ // Optimized i8 L2-squared using saturating subtract + DPWSSD
44
+ //
45
+ // Old approach (Haswell/Skylake):
46
+ // - Compute (a-b) as signed i8, then sign-extend i8→i16 using cvtepi8_epi16
47
+ // - Square using vpmaddwd on i16 values (32 elements/iteration)
48
+ // - Bottleneck: cvtepi8_epi16 (3cy latency @ p5) limits throughput
49
+ //
50
+ // New approach (Ice Lake+):
51
+ // - XOR with 0x80 to reinterpret signed i8 as unsigned u8
52
+ // - Compute |a-b| using unsigned saturating subtraction: diff = (a ⊖ b) | (b ⊖ a)
53
+ // - Zero-extend u8→u16 using unpacking (1cy latency @ p5)
54
+ // - Square using vpmaddwd on u16 values (64 elements/iteration)
55
+ // - Eliminates cvtepi8_epi16 bottleneck, doubles throughput
56
+ //
57
+ // Performance gain: 1.6-1.85× speedup
58
+ // - Processes 64 elements/iteration (2× improvement)
59
+ // - Faster zero-extension (unpack 1cy vs cvtepi8_epi16 3cy)
60
+ // - Correctness: |a-b|² = (a-b)², so unsigned absolute differences are valid
61
+ //
62
+ // The XOR bias is needed because subs_epu8 (unsigned) saturates to 0 when
63
+ // the result would be negative, so OR-ing both directions gives the true |a-b|.
64
+ // A naive subs_epi8 (signed) saturates to -128, corrupting the OR trick.
65
+ //
66
+ __m512i distance_sq_low_i32x16 = _mm512_setzero_si512();
67
+ __m512i distance_sq_high_i32x16 = _mm512_setzero_si512();
68
+ __m512i const zeros_i8x64 = _mm512_setzero_si512();
69
+ __m512i const bias_i8x64 = _mm512_set1_epi8((char)0x80);
70
+ __m512i diff_low_i16x32, diff_high_i16x32;
71
+ __m512i a_i8x64, b_i8x64, diff_u8x64;
72
+
73
+ nk_sqeuclidean_i8_icelake_cycle:
74
+ if (n < 64) {
75
+ __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n);
76
+ a_i8x64 = _mm512_maskz_loadu_epi8(mask, a);
77
+ b_i8x64 = _mm512_maskz_loadu_epi8(mask, b);
78
+ n = 0;
79
+ }
80
+ else {
81
+ a_i8x64 = _mm512_loadu_si512(a);
82
+ b_i8x64 = _mm512_loadu_si512(b);
83
+ a += 64, b += 64, n -= 64;
84
+ }
85
+
86
+ // Reinterpret signed i8 as unsigned u8 by flipping the sign bit
87
+ a_i8x64 = _mm512_xor_si512(a_i8x64, bias_i8x64);
88
+ b_i8x64 = _mm512_xor_si512(b_i8x64, bias_i8x64);
89
+
90
+ // Compute |a-b| using unsigned saturating subtraction
91
+ // subs_epu8 saturates to 0 if result would be negative
92
+ // OR-ing both directions gives absolute difference as unsigned
93
+ diff_u8x64 = _mm512_or_si512(_mm512_subs_epu8(a_i8x64, b_i8x64), _mm512_subs_epu8(b_i8x64, a_i8x64));
94
+
95
+ // Zero-extend to i16 using unpack (1cy @ p5, much faster than cvtepi8_epi16)
96
+ diff_low_i16x32 = _mm512_unpacklo_epi8(diff_u8x64, zeros_i8x64);
97
+ diff_high_i16x32 = _mm512_unpackhi_epi8(diff_u8x64, zeros_i8x64);
98
+
99
+ // Multiply and accumulate at i16 level, accumulate at i32 level
100
+ distance_sq_low_i32x16 = _mm512_dpwssd_epi32(distance_sq_low_i32x16, diff_low_i16x32, diff_low_i16x32);
101
+ distance_sq_high_i32x16 = _mm512_dpwssd_epi32(distance_sq_high_i32x16, diff_high_i16x32, diff_high_i16x32);
102
+ if (n) goto nk_sqeuclidean_i8_icelake_cycle;
103
+
104
+ *result = _mm512_reduce_add_epi32(_mm512_add_epi32(distance_sq_low_i32x16, distance_sq_high_i32x16));
105
+ }
106
+
107
+ NK_PUBLIC void nk_euclidean_i8_icelake(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
108
+ nk_u32_t d2;
109
+ nk_sqeuclidean_i8_icelake(a, b, n, &d2);
110
+ *result = nk_f32_sqrt_haswell((nk_f32_t)d2);
111
+ }
112
+
113
+ NK_PUBLIC void nk_angular_i8_icelake(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
114
+
115
+ __m512i dot_product_i32x16 = _mm512_setzero_si512();
116
+ __m512i a_norm_sq_i32x16 = _mm512_setzero_si512();
117
+ __m512i b_norm_sq_i32x16 = _mm512_setzero_si512();
118
+ __m512i a_i16x32, b_i16x32;
119
+ nk_angular_i8_icelake_cycle:
120
+ if (n < 32) {
121
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
122
+ a_i16x32 = _mm512_cvtepi8_epi16(_mm256_maskz_loadu_epi8(mask, a));
123
+ b_i16x32 = _mm512_cvtepi8_epi16(_mm256_maskz_loadu_epi8(mask, b));
124
+ n = 0;
125
+ }
126
+ else {
127
+ a_i16x32 = _mm512_cvtepi8_epi16(_mm256_loadu_si256((__m256i const *)a));
128
+ b_i16x32 = _mm512_cvtepi8_epi16(_mm256_loadu_si256((__m256i const *)b));
129
+ a += 32, b += 32, n -= 32;
130
+ }
131
+
132
+ // We can't directly use the `_mm512_dpbusd_epi32` intrinsic everywhere,
133
+ // as it's asymmetric with respect to the sign of the input arguments:
134
+ //
135
+ // Signed(ZeroExtend16(a.byte[4 × j]) × SignExtend16(b.byte[4 × j]))
136
+ //
137
+ // To compute the squares, we could just drop the sign bit of the second argument.
138
+ // But this would lead to big-big problems on values like `-128`!
139
+ // For dot-products we don't have the luxury of optimizing the sign bit away.
140
+ // Assuming this is an approximate kernel (with reciprocal square root approximations)
141
+ // in the end, we can allow clamping the value to [-127, 127] range.
142
+ //
143
+ // VNNI instruction performance (Ice Lake vs Zen4 Genoa):
144
+ //
145
+ // Instruction Ice Genoa
146
+ // VPDPBUSDS (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
147
+ // VPDPWSSDS (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
148
+ // VPMADDWD (ZMM, ZMM, ZMM) 5cy @ p05 3cy @ p01
149
+ //
150
+ // On Ice Lake, VNNI bottlenecks on port 0. On Genoa, dual-issue on p01 is faster.
151
+ //
152
+ // The old solution was complex replied on 1. and 2.:
153
+ //
154
+ // a_i8_abs_vec = _mm512_abs_epi8(a_i8_vec);
155
+ // b_i8_abs_vec = _mm512_abs_epi8(b_i8_vec);
156
+ // a2_i32_vec = _mm512_dpbusds_epi32(a2_i32_vec, a_i8_abs_vec, a_i8_abs_vec);
157
+ // b2_i32_vec = _mm512_dpbusds_epi32(b2_i32_vec, b_i8_abs_vec, b_i8_abs_vec);
158
+ // ab_i32_low_vec = _mm512_dpwssds_epi32( //
159
+ // ab_i32_low_vec, //
160
+ // _mm512_cvtepi8_epi16(_mm512_castsi512_si256(a_i8_vec)), //
161
+ // _mm512_cvtepi8_epi16(_mm512_castsi512_si256(b_i8_vec)));
162
+ // ab_i32_high_vec = _mm512_dpwssds_epi32( //
163
+ // ab_i32_high_vec, //
164
+ // _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(a_i8_vec, 1)), //
165
+ // _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(b_i8_vec, 1)));
166
+ //
167
+ // The new solution is simpler and relies on 3.:
168
+ dot_product_i32x16 = _mm512_add_epi32(dot_product_i32x16, _mm512_madd_epi16(a_i16x32, b_i16x32));
169
+ a_norm_sq_i32x16 = _mm512_add_epi32(a_norm_sq_i32x16, _mm512_madd_epi16(a_i16x32, a_i16x32));
170
+ b_norm_sq_i32x16 = _mm512_add_epi32(b_norm_sq_i32x16, _mm512_madd_epi16(b_i16x32, b_i16x32));
171
+ if (n) goto nk_angular_i8_icelake_cycle;
172
+
173
+ nk_i32_t dot_product_i32 = _mm512_reduce_add_epi32(dot_product_i32x16);
174
+ nk_i32_t a_norm_sq_i32 = _mm512_reduce_add_epi32(a_norm_sq_i32x16);
175
+ nk_i32_t b_norm_sq_i32 = _mm512_reduce_add_epi32(b_norm_sq_i32x16);
176
+ *result = nk_angular_normalize_f32_haswell_(dot_product_i32, a_norm_sq_i32, b_norm_sq_i32);
177
+ }
178
+ NK_PUBLIC void nk_sqeuclidean_u8_icelake(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
179
+ __m512i distance_sq_low_i32x16 = _mm512_setzero_si512();
180
+ __m512i distance_sq_high_i32x16 = _mm512_setzero_si512();
181
+ __m512i const zeros_i8x64 = _mm512_setzero_si512();
182
+ __m512i diff_low_i16x32, diff_high_i16x32;
183
+ __m512i a_u8x64, b_u8x64, diff_u8x64;
184
+
185
+ nk_sqeuclidean_u8_icelake_cycle:
186
+ if (n < 64) {
187
+ __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n);
188
+ a_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
189
+ b_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
190
+ n = 0;
191
+ }
192
+ else {
193
+ a_u8x64 = _mm512_loadu_si512(a);
194
+ b_u8x64 = _mm512_loadu_si512(b);
195
+ a += 64, b += 64, n -= 64;
196
+ }
197
+
198
+ // Substracting unsigned vectors in AVX-512 is done by saturating subtraction:
199
+ diff_u8x64 = _mm512_or_si512(_mm512_subs_epu8(a_u8x64, b_u8x64), _mm512_subs_epu8(b_u8x64, a_u8x64));
200
+ diff_low_i16x32 = _mm512_unpacklo_epi8(diff_u8x64, zeros_i8x64);
201
+ diff_high_i16x32 = _mm512_unpackhi_epi8(diff_u8x64, zeros_i8x64);
202
+
203
+ // Multiply and accumulate at `int16` level, accumulate at `int32` level:
204
+ distance_sq_low_i32x16 = _mm512_dpwssd_epi32(distance_sq_low_i32x16, diff_low_i16x32, diff_low_i16x32);
205
+ distance_sq_high_i32x16 = _mm512_dpwssd_epi32(distance_sq_high_i32x16, diff_high_i16x32, diff_high_i16x32);
206
+ if (n) goto nk_sqeuclidean_u8_icelake_cycle;
207
+
208
+ *result = _mm512_reduce_add_epi32(_mm512_add_epi32(distance_sq_low_i32x16, distance_sq_high_i32x16));
209
+ }
210
+ NK_PUBLIC void nk_euclidean_u8_icelake(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
211
+ nk_u32_t d2;
212
+ nk_sqeuclidean_u8_icelake(a, b, n, &d2);
213
+ *result = nk_f32_sqrt_haswell((nk_f32_t)d2);
214
+ }
215
+
216
+ NK_PUBLIC void nk_angular_u8_icelake(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
217
+
218
+ __m512i dot_product_low_i32x16 = _mm512_setzero_si512();
219
+ __m512i dot_product_high_i32x16 = _mm512_setzero_si512();
220
+ __m512i a_norm_sq_low_i32x16 = _mm512_setzero_si512();
221
+ __m512i a_norm_sq_high_i32x16 = _mm512_setzero_si512();
222
+ __m512i b_norm_sq_low_i32x16 = _mm512_setzero_si512();
223
+ __m512i b_norm_sq_high_i32x16 = _mm512_setzero_si512();
224
+ __m512i const zeros_i8x64 = _mm512_setzero_si512();
225
+ __m512i a_low_i16x32, a_high_i16x32, b_low_i16x32, b_high_i16x32;
226
+ __m512i a_u8x64, b_u8x64;
227
+
228
+ nk_angular_u8_icelake_cycle:
229
+ if (n < 64) {
230
+ __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n);
231
+ a_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
232
+ b_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
233
+ n = 0;
234
+ }
235
+ else {
236
+ a_u8x64 = _mm512_loadu_si512(a);
237
+ b_u8x64 = _mm512_loadu_si512(b);
238
+ a += 64, b += 64, n -= 64;
239
+ }
240
+
241
+ // Upcast `uint8` to `int16`. Unlike the signed version, we can use the unpacking
242
+ // instructions instead of extracts, as they are much faster and more efficient.
243
+ a_low_i16x32 = _mm512_unpacklo_epi8(a_u8x64, zeros_i8x64);
244
+ a_high_i16x32 = _mm512_unpackhi_epi8(a_u8x64, zeros_i8x64);
245
+ b_low_i16x32 = _mm512_unpacklo_epi8(b_u8x64, zeros_i8x64);
246
+ b_high_i16x32 = _mm512_unpackhi_epi8(b_u8x64, zeros_i8x64);
247
+
248
+ // Multiply and accumulate as `int16`, accumulate products as `int32`:
249
+ dot_product_low_i32x16 = _mm512_dpwssds_epi32(dot_product_low_i32x16, a_low_i16x32, b_low_i16x32);
250
+ dot_product_high_i32x16 = _mm512_dpwssds_epi32(dot_product_high_i32x16, a_high_i16x32, b_high_i16x32);
251
+ a_norm_sq_low_i32x16 = _mm512_dpwssds_epi32(a_norm_sq_low_i32x16, a_low_i16x32, a_low_i16x32);
252
+ a_norm_sq_high_i32x16 = _mm512_dpwssds_epi32(a_norm_sq_high_i32x16, a_high_i16x32, a_high_i16x32);
253
+ b_norm_sq_low_i32x16 = _mm512_dpwssds_epi32(b_norm_sq_low_i32x16, b_low_i16x32, b_low_i16x32);
254
+ b_norm_sq_high_i32x16 = _mm512_dpwssds_epi32(b_norm_sq_high_i32x16, b_high_i16x32, b_high_i16x32);
255
+ if (n) goto nk_angular_u8_icelake_cycle;
256
+
257
+ nk_i32_t dot_product_i32 = _mm512_reduce_add_epi32(
258
+ _mm512_add_epi32(dot_product_low_i32x16, dot_product_high_i32x16));
259
+ nk_i32_t a_norm_sq_i32 = _mm512_reduce_add_epi32(_mm512_add_epi32(a_norm_sq_low_i32x16, a_norm_sq_high_i32x16));
260
+ nk_i32_t b_norm_sq_i32 = _mm512_reduce_add_epi32(_mm512_add_epi32(b_norm_sq_low_i32x16, b_norm_sq_high_i32x16));
261
+ *result = nk_angular_normalize_f32_haswell_(dot_product_i32, a_norm_sq_i32, b_norm_sq_i32);
262
+ }
263
+
264
+ NK_PUBLIC void nk_sqeuclidean_i4_icelake(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_u32_t *result) {
265
+ // i4 values are packed as nibbles: two 4-bit signed values per byte.
266
+ // Parameter `n` is the number of 4-bit values (dimensions), not bytes.
267
+ n = nk_size_round_up_to_multiple_(n, 2);
268
+ nk_size_t n_bytes = n / 2;
269
+
270
+ // While `int8_t` covers the range [-128, 127], `int4_t` covers only [-8, 7].
271
+ // The absolute difference between two 4-bit integers is at most 15 and fits in `uint4_t`.
272
+ // Moreover, its square is at most 225, which fits into `uint8_t`.
273
+ //
274
+ // Instead of using lookup tables for sign extension and squaring, we use arithmetic:
275
+ //
276
+ // 1. XOR trick for sign extension: `signed = (nibble ^ 8) - 8`
277
+ // Maps [0,7] → [0,7] (positive) and [8,15] → [-8,-1] (negative).
278
+ //
279
+ // 2. For L2 squared: |a-b|² = diff * diff, using `_mm512_dpbusd_epi32`.
280
+ // After computing signed difference and taking abs, the result fits ∈ [0,15].
281
+ // We can then use DPBUSD to compute diff² efficiently without lookup tables.
282
+ //
283
+ // This approach avoids 8x VPSHUFB operations per iteration, replacing them with
284
+ // arithmetic operations that distribute better across execution ports.
285
+ __m512i const nibble_mask_u8x64 = _mm512_set1_epi8(0x0F);
286
+ __m512i const eight_i8x64 = _mm512_set1_epi8(8);
287
+
288
+ __m512i a_i4_vec, b_i4_vec;
289
+ __m512i a_low_u8x64, a_high_u8x64, b_low_u8x64, b_high_u8x64;
290
+ __m512i a_low_i8x64, a_high_i8x64, b_low_i8x64, b_high_i8x64;
291
+ __m512i diff_low_u8x64, diff_high_u8x64;
292
+ __m512i d2_i32x16 = _mm512_setzero_si512();
293
+
294
+ nk_sqeuclidean_i4_icelake_cycle:
295
+ if (n_bytes < 64) {
296
+ __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_bytes);
297
+ a_i4_vec = _mm512_maskz_loadu_epi8(mask, a);
298
+ b_i4_vec = _mm512_maskz_loadu_epi8(mask, b);
299
+ n_bytes = 0;
300
+ }
301
+ else {
302
+ a_i4_vec = _mm512_loadu_epi8(a);
303
+ b_i4_vec = _mm512_loadu_epi8(b);
304
+ a += 64, b += 64, n_bytes -= 64;
305
+ }
306
+
307
+ // Extract nibbles as unsigned [0,15]. VPSHUFB ignores high 4 bits of index,
308
+ // so no AND needed for low nibbles when used with lookup, but we need it here.
309
+ a_low_u8x64 = _mm512_and_si512(a_i4_vec, nibble_mask_u8x64);
310
+ a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_i4_vec, 4), nibble_mask_u8x64);
311
+ b_low_u8x64 = _mm512_and_si512(b_i4_vec, nibble_mask_u8x64);
312
+ b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_i4_vec, 4), nibble_mask_u8x64);
313
+
314
+ // Sign extend using XOR trick: signed = (nibble ^ 8) - 8
315
+ a_low_i8x64 = _mm512_sub_epi8(_mm512_xor_si512(a_low_u8x64, eight_i8x64), eight_i8x64);
316
+ a_high_i8x64 = _mm512_sub_epi8(_mm512_xor_si512(a_high_u8x64, eight_i8x64), eight_i8x64);
317
+ b_low_i8x64 = _mm512_sub_epi8(_mm512_xor_si512(b_low_u8x64, eight_i8x64), eight_i8x64);
318
+ b_high_i8x64 = _mm512_sub_epi8(_mm512_xor_si512(b_high_u8x64, eight_i8x64), eight_i8x64);
319
+
320
+ // Compute |a - b| for each nibble pair. Result is unsigned ∈ [0, 15].
321
+ diff_low_u8x64 = _mm512_abs_epi8(_mm512_sub_epi8(a_low_i8x64, b_low_i8x64));
322
+ diff_high_u8x64 = _mm512_abs_epi8(_mm512_sub_epi8(a_high_i8x64, b_high_i8x64));
323
+
324
+ // Square and accumulate using DPBUSD: diff² = diff * diff.
325
+ // DPBUSD computes u8*i8 products and sums groups of 4 into i32.
326
+ // Since diff is ∈ [0,15], it's safe for both u8 and i8 interpretation.
327
+ d2_i32x16 = _mm512_dpbusd_epi32(d2_i32x16, diff_low_u8x64, diff_low_u8x64);
328
+ d2_i32x16 = _mm512_dpbusd_epi32(d2_i32x16, diff_high_u8x64, diff_high_u8x64);
329
+ if (n_bytes) goto nk_sqeuclidean_i4_icelake_cycle;
330
+
331
+ *result = (nk_u32_t)_mm512_reduce_add_epi32(d2_i32x16);
332
+ }
333
+ NK_PUBLIC void nk_euclidean_i4_icelake(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_f32_t *result) {
334
+ nk_u32_t d2;
335
+ nk_sqeuclidean_i4_icelake(a, b, n, &d2);
336
+ *result = nk_f32_sqrt_haswell((nk_f32_t)d2);
337
+ }
338
+ NK_PUBLIC void nk_angular_i4_icelake(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_f32_t *result) {
339
+ // i4 values are packed as nibbles: two 4-bit signed values per byte.
340
+ // Parameter `n` is the number of 4-bit values (dimensions), not bytes.
341
+ n = nk_size_round_up_to_multiple_(n, 2);
342
+ nk_size_t n_bytes = n / 2;
343
+
344
+ // Angular distance for signed 4-bit integers requires computing:
345
+ // 1. Dot product: ∑(aᵢ × bᵢ)
346
+ // 2. Squared norms: ∑(aᵢ²) and ∑(bᵢ²)
347
+ //
348
+ // For signed i4 values in [-8, 7], we use DPBUSD for everything by leveraging
349
+ // an algebraic identity. Define x = a ^ 8 (XOR with 8), which maps:
350
+ // [0,7] → [8,15] and [8,15] → [0,7]
351
+ //
352
+ // The signed value is: a_signed = x - 8
353
+ //
354
+ // For two signed values:
355
+ // a_signed × b_signed = (ax - 8)(bx - 8) = ax × bx - 8 × ax - 8 × bx + 64
356
+ //
357
+ // Therefore:
358
+ // dot(a_signed, b_signed) = DPBUSD(ax, bx) - 8 × (∑(ax) + ∑(bx)) + 64 × n
359
+ //
360
+ // This avoids all i8 → i16 upcasts and uses DPBUSD directly on byte values!
361
+ // For norms, we use |x|² = x², computing abs then squaring with DPBUSD.
362
+ __m512i const nibble_mask_u8x64 = _mm512_set1_epi8(0x0F);
363
+ __m512i const eight_i8x64 = _mm512_set1_epi8(8);
364
+ __m512i const zeros_i8x64 = _mm512_setzero_si512();
365
+
366
+ __m512i a_i4_vec, b_i4_vec;
367
+ __m512i a_low_u8x64, a_high_u8x64, b_low_u8x64, b_high_u8x64;
368
+ __m512i ax_low_u8x64, ax_high_u8x64, bx_low_u8x64, bx_high_u8x64;
369
+ __m512i a_low_i8x64, a_high_i8x64, b_low_i8x64, b_high_i8x64;
370
+
371
+ // Accumulators for dot product (using biased values) and correction sums
372
+ __m512i ab_i32x16 = zeros_i8x64;
373
+ __m512i ax_sum_i64x8 = zeros_i8x64;
374
+ __m512i bx_sum_i64x8 = zeros_i8x64;
375
+ // Accumulators for squared norms
376
+ __m512i a2_i32x16 = zeros_i8x64;
377
+ __m512i b2_i32x16 = zeros_i8x64;
378
+
379
+ nk_angular_i4_icelake_cycle:
380
+ if (n_bytes < 64) {
381
+ __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_bytes);
382
+ a_i4_vec = _mm512_mask_loadu_epi8(_mm512_set1_epi8((char)0x88), mask, a);
383
+ b_i4_vec = _mm512_mask_loadu_epi8(_mm512_set1_epi8((char)0x88), mask, b);
384
+ n_bytes = 0;
385
+ }
386
+ else {
387
+ a_i4_vec = _mm512_loadu_epi8(a);
388
+ b_i4_vec = _mm512_loadu_epi8(b);
389
+ a += 64, b += 64, n_bytes -= 64;
390
+ }
391
+
392
+ // Extract nibbles as unsigned [0,15]
393
+ a_low_u8x64 = _mm512_and_si512(a_i4_vec, nibble_mask_u8x64);
394
+ a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_i4_vec, 4), nibble_mask_u8x64);
395
+ b_low_u8x64 = _mm512_and_si512(b_i4_vec, nibble_mask_u8x64);
396
+ b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_i4_vec, 4), nibble_mask_u8x64);
397
+
398
+ // Compute biased values: ax = a ^ 8 (still ∈ [0,15], just reordered)
399
+ ax_low_u8x64 = _mm512_xor_si512(a_low_u8x64, eight_i8x64);
400
+ ax_high_u8x64 = _mm512_xor_si512(a_high_u8x64, eight_i8x64);
401
+ bx_low_u8x64 = _mm512_xor_si512(b_low_u8x64, eight_i8x64);
402
+ bx_high_u8x64 = _mm512_xor_si512(b_high_u8x64, eight_i8x64);
403
+
404
+ // Dot product using DPBUSD on biased values (correction applied at end)
405
+ ab_i32x16 = _mm512_dpbusd_epi32(ab_i32x16, ax_low_u8x64, bx_low_u8x64);
406
+ ab_i32x16 = _mm512_dpbusd_epi32(ab_i32x16, ax_high_u8x64, bx_high_u8x64);
407
+
408
+ // Track sums for correction using SAD (sum of absolute differences with zero)
409
+ ax_sum_i64x8 = _mm512_add_epi64(ax_sum_i64x8, _mm512_sad_epu8(ax_low_u8x64, zeros_i8x64));
410
+ ax_sum_i64x8 = _mm512_add_epi64(ax_sum_i64x8, _mm512_sad_epu8(ax_high_u8x64, zeros_i8x64));
411
+ bx_sum_i64x8 = _mm512_add_epi64(bx_sum_i64x8, _mm512_sad_epu8(bx_low_u8x64, zeros_i8x64));
412
+ bx_sum_i64x8 = _mm512_add_epi64(bx_sum_i64x8, _mm512_sad_epu8(bx_high_u8x64, zeros_i8x64));
413
+
414
+ // For norms: convert to signed, take abs, then square with DPBUSD
415
+ a_low_i8x64 = _mm512_sub_epi8(ax_low_u8x64, eight_i8x64);
416
+ a_high_i8x64 = _mm512_sub_epi8(ax_high_u8x64, eight_i8x64);
417
+ b_low_i8x64 = _mm512_sub_epi8(bx_low_u8x64, eight_i8x64);
418
+ b_high_i8x64 = _mm512_sub_epi8(bx_high_u8x64, eight_i8x64);
419
+
420
+ __m512i a_low_abs_u8x64 = _mm512_abs_epi8(a_low_i8x64);
421
+ __m512i a_high_abs_u8x64 = _mm512_abs_epi8(a_high_i8x64);
422
+ __m512i b_low_abs_u8x64 = _mm512_abs_epi8(b_low_i8x64);
423
+ __m512i b_high_abs_u8x64 = _mm512_abs_epi8(b_high_i8x64);
424
+
425
+ // Squared norms: ‖x‖² = x², use DPBUSD for efficient squaring
426
+ a2_i32x16 = _mm512_dpbusd_epi32(a2_i32x16, a_low_abs_u8x64, a_low_abs_u8x64);
427
+ a2_i32x16 = _mm512_dpbusd_epi32(a2_i32x16, a_high_abs_u8x64, a_high_abs_u8x64);
428
+ b2_i32x16 = _mm512_dpbusd_epi32(b2_i32x16, b_low_abs_u8x64, b_low_abs_u8x64);
429
+ b2_i32x16 = _mm512_dpbusd_epi32(b2_i32x16, b_high_abs_u8x64, b_high_abs_u8x64);
430
+ if (n_bytes) goto nk_angular_i4_icelake_cycle;
431
+
432
+ // Apply algebraic correction for signed dot product:
433
+ // signed_dot = DPBUSD(ax, bx) - 8 × (∑(ax) + ∑(bx)) + 64 × n
434
+ nk_i64_t ax_sum = _mm512_reduce_add_epi64(ax_sum_i64x8);
435
+ nk_i64_t bx_sum = _mm512_reduce_add_epi64(bx_sum_i64x8);
436
+ nk_i32_t ab_raw = _mm512_reduce_add_epi32(ab_i32x16);
437
+ nk_i32_t ab = ab_raw - 8 * (nk_i32_t)(ax_sum + bx_sum) + 64 * (nk_i32_t)n;
438
+
439
+ nk_size_t n_bytes_total = nk_size_divide_round_up_(n, 2);
440
+ nk_i32_t norm_excess = 128 * (nk_i32_t)(nk_size_round_up_to_multiple_(n_bytes_total, 64) - n_bytes_total);
441
+ nk_i32_t a2 = _mm512_reduce_add_epi32(a2_i32x16) - norm_excess;
442
+ nk_i32_t b2 = _mm512_reduce_add_epi32(b2_i32x16) - norm_excess;
443
+ *result = nk_angular_normalize_f32_haswell_(ab, (nk_f32_t)a2, (nk_f32_t)b2);
444
+ }
445
+
446
+ NK_PUBLIC void nk_sqeuclidean_u4_icelake(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result) {
447
+ // u4 values are packed as nibbles: two 4-bit unsigned values per byte.
448
+ // Parameter `n` is the number of 4-bit values (dimensions), not bytes.
449
+ n = nk_size_round_up_to_multiple_(n, 2);
450
+ nk_size_t n_bytes = n / 2;
451
+
452
+ // For unsigned 4-bit integers ∈ [0, 15], the L2 squared distance is straightforward:
453
+ // 1. Extract nibbles as u8 values
454
+ // 2. Compute |a - b| using saturating subtraction: max(a,b) - min(a,b) = (a ⊖ b) | (b ⊖ a)
455
+ // 3. Square with DPBUSD: diff * diff
456
+ //
457
+ // No sign extension needed since values are unsigned.
458
+ __m512i const nibble_mask_u8x64 = _mm512_set1_epi8(0x0F);
459
+
460
+ __m512i a_u4_vec, b_u4_vec;
461
+ __m512i a_low_u8x64, a_high_u8x64, b_low_u8x64, b_high_u8x64;
462
+ __m512i diff_low_u8x64, diff_high_u8x64;
463
+ __m512i d2_i32x16 = _mm512_setzero_si512();
464
+
465
+ nk_sqeuclidean_u4_icelake_cycle:
466
+ if (n_bytes < 64) {
467
+ __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_bytes);
468
+ a_u4_vec = _mm512_maskz_loadu_epi8(mask, a);
469
+ b_u4_vec = _mm512_maskz_loadu_epi8(mask, b);
470
+ n_bytes = 0;
471
+ }
472
+ else {
473
+ a_u4_vec = _mm512_loadu_epi8(a);
474
+ b_u4_vec = _mm512_loadu_epi8(b);
475
+ a += 64, b += 64, n_bytes -= 64;
476
+ }
477
+
478
+ // Extract nibbles as unsigned [0,15]
479
+ a_low_u8x64 = _mm512_and_si512(a_u4_vec, nibble_mask_u8x64);
480
+ a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_u4_vec, 4), nibble_mask_u8x64);
481
+ b_low_u8x64 = _mm512_and_si512(b_u4_vec, nibble_mask_u8x64);
482
+ b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_u4_vec, 4), nibble_mask_u8x64);
483
+
484
+ // Absolute difference for unsigned: |a-b| = (a ⊖ b) | (b ⊖ a) where ⊖ is saturating sub
485
+ diff_low_u8x64 = _mm512_or_si512(_mm512_subs_epu8(a_low_u8x64, b_low_u8x64),
486
+ _mm512_subs_epu8(b_low_u8x64, a_low_u8x64));
487
+ diff_high_u8x64 = _mm512_or_si512(_mm512_subs_epu8(a_high_u8x64, b_high_u8x64),
488
+ _mm512_subs_epu8(b_high_u8x64, a_high_u8x64));
489
+
490
+ // Square and accumulate using DPBUSD
491
+ d2_i32x16 = _mm512_dpbusd_epi32(d2_i32x16, diff_low_u8x64, diff_low_u8x64);
492
+ d2_i32x16 = _mm512_dpbusd_epi32(d2_i32x16, diff_high_u8x64, diff_high_u8x64);
493
+ if (n_bytes) goto nk_sqeuclidean_u4_icelake_cycle;
494
+
495
+ *result = (nk_u32_t)_mm512_reduce_add_epi32(d2_i32x16);
496
+ }
497
+ NK_PUBLIC void nk_euclidean_u4_icelake(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_f32_t *result) {
498
+ nk_u32_t d2;
499
+ nk_sqeuclidean_u4_icelake(a, b, n, &d2);
500
+ *result = nk_f32_sqrt_haswell((nk_f32_t)d2);
501
+ }
502
+
503
+ NK_PUBLIC void nk_angular_u4_icelake(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_f32_t *result) {
504
+ // u4 values are packed as nibbles: two 4-bit unsigned values per byte.
505
+ // Parameter `n` is the number of 4-bit values (dimensions), not bytes.
506
+ n = nk_size_round_up_to_multiple_(n, 2);
507
+ nk_size_t n_bytes = n / 2;
508
+
509
+ // Angular distance for unsigned 4-bit integers ∈ [0, 15].
510
+ // Since values are unsigned and small, we can use DPBUSD directly for both
511
+ // dot product and norms without any sign handling.
512
+ //
513
+ // DPBUSD computes: ZeroExtend(a) * SignExtend(b), but for values ∈ [0, 15],
514
+ // sign extension is identity (no high bit set), so it works correctly.
515
+ __m512i const nibble_mask_u8x64 = _mm512_set1_epi8(0x0F);
516
+ __m512i const zeros_i8x64 = _mm512_setzero_si512();
517
+
518
+ __m512i a_u4_vec, b_u4_vec;
519
+ __m512i a_low_u8x64, a_high_u8x64, b_low_u8x64, b_high_u8x64;
520
+
521
+ __m512i ab_i32x16 = zeros_i8x64;
522
+ __m512i a2_i64x8 = zeros_i8x64;
523
+ __m512i b2_i64x8 = zeros_i8x64;
524
+
525
+ nk_angular_u4_icelake_cycle:
526
+ if (n_bytes < 64) {
527
+ __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_bytes);
528
+ a_u4_vec = _mm512_maskz_loadu_epi8(mask, a);
529
+ b_u4_vec = _mm512_maskz_loadu_epi8(mask, b);
530
+ n_bytes = 0;
531
+ }
532
+ else {
533
+ a_u4_vec = _mm512_loadu_epi8(a);
534
+ b_u4_vec = _mm512_loadu_epi8(b);
535
+ a += 64, b += 64, n_bytes -= 64;
536
+ }
537
+
538
+ // Extract nibbles as unsigned [0,15]
539
+ a_low_u8x64 = _mm512_and_si512(a_u4_vec, nibble_mask_u8x64);
540
+ a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_u4_vec, 4), nibble_mask_u8x64);
541
+ b_low_u8x64 = _mm512_and_si512(b_u4_vec, nibble_mask_u8x64);
542
+ b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_u4_vec, 4), nibble_mask_u8x64);
543
+
544
+ // Dot product with DPBUSD (safe for unsigned [0,15])
545
+ ab_i32x16 = _mm512_dpbusd_epi32(ab_i32x16, a_low_u8x64, b_low_u8x64);
546
+ ab_i32x16 = _mm512_dpbusd_epi32(ab_i32x16, a_high_u8x64, b_high_u8x64);
547
+
548
+ // Squared norms: compute a² per nibble using lookup table for efficiency
549
+ // Squares lookup: 0 → 0, 1 → 1, 2 → 4, ..., 15 → 225
550
+ __m512i const u4_squares_lookup_u8x64 = _mm512_set_epi8(
551
+ (char)225, (char)196, (char)169, (char)144, 121, 100, 81, 64, 49, 36, 25, 16, 9, 4, 1, 0, //
552
+ (char)225, (char)196, (char)169, (char)144, 121, 100, 81, 64, 49, 36, 25, 16, 9, 4, 1, 0, //
553
+ (char)225, (char)196, (char)169, (char)144, 121, 100, 81, 64, 49, 36, 25, 16, 9, 4, 1, 0, //
554
+ (char)225, (char)196, (char)169, (char)144, 121, 100, 81, 64, 49, 36, 25, 16, 9, 4, 1, 0);
555
+
556
+ __m512i a2_lo_u8x64 = _mm512_shuffle_epi8(u4_squares_lookup_u8x64, a_low_u8x64);
557
+ __m512i a2_hi_u8x64 = _mm512_shuffle_epi8(u4_squares_lookup_u8x64, a_high_u8x64);
558
+ __m512i b2_lo_u8x64 = _mm512_shuffle_epi8(u4_squares_lookup_u8x64, b_low_u8x64);
559
+ __m512i b2_hi_u8x64 = _mm512_shuffle_epi8(u4_squares_lookup_u8x64, b_high_u8x64);
560
+
561
+ // Accumulate low and high squares separately using SAD to avoid u8 overflow
562
+ a2_i64x8 = _mm512_add_epi64(a2_i64x8, _mm512_sad_epu8(a2_lo_u8x64, zeros_i8x64));
563
+ a2_i64x8 = _mm512_add_epi64(a2_i64x8, _mm512_sad_epu8(a2_hi_u8x64, zeros_i8x64));
564
+ b2_i64x8 = _mm512_add_epi64(b2_i64x8, _mm512_sad_epu8(b2_lo_u8x64, zeros_i8x64));
565
+ b2_i64x8 = _mm512_add_epi64(b2_i64x8, _mm512_sad_epu8(b2_hi_u8x64, zeros_i8x64));
566
+ if (n_bytes) goto nk_angular_u4_icelake_cycle;
567
+
568
+ nk_i32_t ab = _mm512_reduce_add_epi32(ab_i32x16);
569
+ nk_i64_t a2 = _mm512_reduce_add_epi64(a2_i64x8);
570
+ nk_i64_t b2 = _mm512_reduce_add_epi64(b2_i64x8);
571
+ *result = nk_angular_normalize_f32_haswell_(ab, (nk_f32_t)a2, (nk_f32_t)b2);
572
+ }
573
+
574
+ #if defined(__clang__)
575
+ #pragma clang attribute pop
576
+ #elif defined(__GNUC__)
577
+ #pragma GCC pop_options
578
+ #endif
579
+
580
+ #if defined(__cplusplus)
581
+ } // extern "C"
582
+ #endif
583
+
584
+ #endif // NK_TARGET_ICELAKE
585
+ #endif // NK_TARGET_X86_
586
+ #endif // NK_SPATIAL_ICELAKE_H