numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,290 @@
1
+ /**
2
+ * @brief SIMD-accelerated Spatial Similarity Measures for Genoa.
3
+ * @file include/numkong/spatial/genoa.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/spatial.h
8
+ */
9
+ #ifndef NK_SPATIAL_GENOA_H
10
+ #define NK_SPATIAL_GENOA_H
11
+
12
+ #if NK_TARGET_X86_
13
+ #if NK_TARGET_GENOA
14
+
15
+ #include "numkong/types.h"
16
+ #include "numkong/spatial/haswell.h" // `nk_angular_normalize_f32_haswell_`, `nk_f32_sqrt_haswell`
17
+ #include "numkong/reduce/skylake.h" // `nk_reduce_add_f32x16_skylake_`
18
+ #include "numkong/cast/icelake.h" // `nk_e4m3x32_to_bf16x32_icelake_`
19
+
20
+ #if defined(__cplusplus)
21
+ extern "C" {
22
+ #endif
23
+
24
+ #if defined(__clang__)
25
+ #pragma clang attribute push( \
26
+ __attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512bf16,f16c,fma,bmi,bmi2"))), \
27
+ apply_to = function)
28
+ #elif defined(__GNUC__)
29
+ #pragma GCC push_options
30
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512bf16", "f16c", "fma", "bmi", "bmi2")
31
+ #endif
32
+
33
+ NK_INTERNAL __m512i nk_substract_bf16x32_genoa_(__m512i a_i16, __m512i b_i16) {
34
+
35
+ nk_b512_vec_t d, a_f32_even, b_f32_even, d_f32_even, a_f32_odd, b_f32_odd, d_f32_odd;
36
+
37
+ // There are several approaches to perform subtraction in `bf16`. The first one is:
38
+ //
39
+ // Perform a couple of casts - each is a bitshift. To convert `bf16` to `f32`,
40
+ // expand it to 32-bit integers, then shift the bits by 16 to the left.
41
+ // Then subtract as floats, and shift back. During expansion, we will double the space,
42
+ // and should use separate registers for top and bottom halves.
43
+ // Some compilers don't have `_mm512_extracti32x8_epi32`, so we use `_mm512_extracti64x4_epi64`:
44
+ //
45
+ // a_f32_bot.fvec = _mm512_castsi512_ps(_mm512_slli_epi32(
46
+ // _mm512_cvtepu16_epi32(_mm512_castsi512_si256(a_i16)), 16));
47
+ // b_f32_bot.fvec = _mm512_castsi512_ps(_mm512_slli_epi32(
48
+ // _mm512_cvtepu16_epi32(_mm512_castsi512_si256(b_i16)), 16));
49
+ // a_f32_top.fvec =_mm512_castsi512_ps(
50
+ // _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(a_i16, 1)), 16));
51
+ // b_f32_top.fvec =_mm512_castsi512_ps(
52
+ // _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(b_i16, 1)), 16));
53
+ // d_f32_top.fvec = _mm512_sub_ps(a_f32_top.fvec, b_f32_top.fvec);
54
+ // d_f32_bot.fvec = _mm512_sub_ps(a_f32_bot.fvec, b_f32_bot.fvec);
55
+ // d.ivec = _mm512_castsi256_si512(_mm512_cvtepi32_epi16(
56
+ // _mm512_srli_epi32(_mm512_castps_si512(d_f32_bot.fvec), 16)));
57
+ // d.ivec = _mm512_inserti64x4(d.ivec, _mm512_cvtepi32_epi16(
58
+ // _mm512_srli_epi32(_mm512_castps_si512(d_f32_top.fvec), 16)), 1);
59
+ //
60
+ // Instead of using multple shifts and an insertion, we can achieve similar result with fewer expensive
61
+ // calls to `_mm512_permutex2var_epi16`, or a cheap `_mm512_mask_shuffle_epi8` and blend:
62
+ //
63
+ a_f32_odd.zmm = _mm512_and_si512(a_i16, _mm512_set1_epi32(0xFFFF0000));
64
+ a_f32_even.zmm = _mm512_slli_epi32(a_i16, 16);
65
+ b_f32_odd.zmm = _mm512_and_si512(b_i16, _mm512_set1_epi32(0xFFFF0000));
66
+ b_f32_even.zmm = _mm512_slli_epi32(b_i16, 16);
67
+
68
+ d_f32_odd.zmm_ps = _mm512_sub_ps(a_f32_odd.zmm_ps, b_f32_odd.zmm_ps);
69
+ d_f32_even.zmm_ps = _mm512_sub_ps(a_f32_even.zmm_ps, b_f32_even.zmm_ps);
70
+
71
+ d_f32_even.zmm = _mm512_srli_epi32(d_f32_even.zmm, 16);
72
+ d.zmm = _mm512_mask_blend_epi16(0x55555555, d_f32_odd.zmm, d_f32_even.zmm);
73
+
74
+ return d.zmm;
75
+ }
76
+
77
+ NK_PUBLIC void nk_sqeuclidean_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
78
+ __m512 a_sq_f32x16 = _mm512_setzero_ps();
79
+ __m512 b_sq_f32x16 = _mm512_setzero_ps();
80
+ __m512 ab_f32x16 = _mm512_setzero_ps();
81
+ __m512i a_bf16x32, b_bf16x32;
82
+
83
+ nk_sqeuclidean_bf16_genoa_cycle:
84
+ if (n < 32) {
85
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
86
+ a_bf16x32 = _mm512_maskz_loadu_epi16(mask, a);
87
+ b_bf16x32 = _mm512_maskz_loadu_epi16(mask, b);
88
+ n = 0;
89
+ }
90
+ else {
91
+ a_bf16x32 = _mm512_loadu_epi16(a);
92
+ b_bf16x32 = _mm512_loadu_epi16(b);
93
+ a += 32, b += 32, n -= 32;
94
+ }
95
+ a_sq_f32x16 = _mm512_dpbf16_ps(a_sq_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(a_bf16x32));
96
+ b_sq_f32x16 = _mm512_dpbf16_ps(b_sq_f32x16, nk_m512bh_from_m512i_(b_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
97
+ ab_f32x16 = _mm512_dpbf16_ps(ab_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
98
+ if (n) goto nk_sqeuclidean_bf16_genoa_cycle;
99
+
100
+ // (a-b)² = a² + b² - 2ab
101
+ __m512 sum_sq_f32x16 = _mm512_add_ps(a_sq_f32x16, b_sq_f32x16);
102
+ *result = nk_reduce_add_f32x16_skylake_(_mm512_fnmadd_ps(_mm512_set1_ps(2.0f), ab_f32x16, sum_sq_f32x16));
103
+ }
104
+
105
+ NK_PUBLIC void nk_euclidean_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
106
+ nk_sqeuclidean_bf16_genoa(a, b, n, result);
107
+ *result = nk_f32_sqrt_haswell(*result);
108
+ }
109
+
110
+ NK_PUBLIC void nk_angular_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
111
+ __m512 dot_product_f32x16 = _mm512_setzero_ps();
112
+ __m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
113
+ __m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
114
+ __m512i a_bf16x32, b_bf16x32;
115
+
116
+ nk_angular_bf16_genoa_cycle:
117
+ if (n < 32) {
118
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
119
+ a_bf16x32 = _mm512_maskz_loadu_epi16(mask, a);
120
+ b_bf16x32 = _mm512_maskz_loadu_epi16(mask, b);
121
+ n = 0;
122
+ }
123
+ else {
124
+ a_bf16x32 = _mm512_loadu_epi16(a);
125
+ b_bf16x32 = _mm512_loadu_epi16(b);
126
+ a += 32, b += 32, n -= 32;
127
+ }
128
+ dot_product_f32x16 = _mm512_dpbf16_ps(dot_product_f32x16, nk_m512bh_from_m512i_(a_bf16x32),
129
+ nk_m512bh_from_m512i_(b_bf16x32));
130
+ a_norm_sq_f32x16 = _mm512_dpbf16_ps(a_norm_sq_f32x16, nk_m512bh_from_m512i_(a_bf16x32),
131
+ nk_m512bh_from_m512i_(a_bf16x32));
132
+ b_norm_sq_f32x16 = _mm512_dpbf16_ps(b_norm_sq_f32x16, nk_m512bh_from_m512i_(b_bf16x32),
133
+ nk_m512bh_from_m512i_(b_bf16x32));
134
+ if (n) goto nk_angular_bf16_genoa_cycle;
135
+
136
+ nk_f32_t dot_product_f32 = nk_reduce_add_f32x16_skylake_(dot_product_f32x16);
137
+ nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(a_norm_sq_f32x16);
138
+ nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(b_norm_sq_f32x16);
139
+ *result = nk_angular_normalize_f32_haswell_(dot_product_f32, a_norm_sq_f32, b_norm_sq_f32);
140
+ }
141
+
142
+ NK_PUBLIC void nk_sqeuclidean_e4m3_genoa(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
143
+ __m512 a_sq_f32x16 = _mm512_setzero_ps();
144
+ __m512 b_sq_f32x16 = _mm512_setzero_ps();
145
+ __m512 ab_f32x16 = _mm512_setzero_ps();
146
+ __m256i a_e4m3x32, b_e4m3x32;
147
+
148
+ nk_sqeuclidean_e4m3_genoa_cycle:
149
+ if (n < 32) {
150
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
151
+ a_e4m3x32 = _mm256_maskz_loadu_epi8(mask, a);
152
+ b_e4m3x32 = _mm256_maskz_loadu_epi8(mask, b);
153
+ n = 0;
154
+ }
155
+ else {
156
+ a_e4m3x32 = _mm256_loadu_epi8(a);
157
+ b_e4m3x32 = _mm256_loadu_epi8(b);
158
+ a += 32, b += 32, n -= 32;
159
+ }
160
+ __m512i a_bf16x32 = nk_e4m3x32_to_bf16x32_icelake_(a_e4m3x32);
161
+ __m512i b_bf16x32 = nk_e4m3x32_to_bf16x32_icelake_(b_e4m3x32);
162
+ a_sq_f32x16 = _mm512_dpbf16_ps(a_sq_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(a_bf16x32));
163
+ b_sq_f32x16 = _mm512_dpbf16_ps(b_sq_f32x16, nk_m512bh_from_m512i_(b_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
164
+ ab_f32x16 = _mm512_dpbf16_ps(ab_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
165
+ if (n) goto nk_sqeuclidean_e4m3_genoa_cycle;
166
+
167
+ // (a-b)² = a² + b² - 2ab
168
+ __m512 sum_sq_f32x16 = _mm512_add_ps(a_sq_f32x16, b_sq_f32x16);
169
+ *result = nk_reduce_add_f32x16_skylake_(_mm512_fnmadd_ps(_mm512_set1_ps(2.0f), ab_f32x16, sum_sq_f32x16));
170
+ }
171
+
172
+ NK_PUBLIC void nk_euclidean_e4m3_genoa(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
173
+ nk_sqeuclidean_e4m3_genoa(a, b, n, result);
174
+ *result = nk_f32_sqrt_haswell(*result);
175
+ }
176
+
177
+ NK_PUBLIC void nk_angular_e4m3_genoa(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
178
+ __m512 dot_f32x16 = _mm512_setzero_ps();
179
+ __m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
180
+ __m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
181
+ __m256i a_e4m3x32, b_e4m3x32;
182
+
183
+ nk_angular_e4m3_genoa_cycle:
184
+ if (n < 32) {
185
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
186
+ a_e4m3x32 = _mm256_maskz_loadu_epi8(mask, a);
187
+ b_e4m3x32 = _mm256_maskz_loadu_epi8(mask, b);
188
+ n = 0;
189
+ }
190
+ else {
191
+ a_e4m3x32 = _mm256_loadu_epi8(a);
192
+ b_e4m3x32 = _mm256_loadu_epi8(b);
193
+ a += 32, b += 32, n -= 32;
194
+ }
195
+ __m512i a_bf16x32 = nk_e4m3x32_to_bf16x32_icelake_(a_e4m3x32);
196
+ __m512i b_bf16x32 = nk_e4m3x32_to_bf16x32_icelake_(b_e4m3x32);
197
+ dot_f32x16 = _mm512_dpbf16_ps(dot_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
198
+ a_norm_sq_f32x16 = _mm512_dpbf16_ps(a_norm_sq_f32x16, nk_m512bh_from_m512i_(a_bf16x32),
199
+ nk_m512bh_from_m512i_(a_bf16x32));
200
+ b_norm_sq_f32x16 = _mm512_dpbf16_ps(b_norm_sq_f32x16, nk_m512bh_from_m512i_(b_bf16x32),
201
+ nk_m512bh_from_m512i_(b_bf16x32));
202
+ if (n) goto nk_angular_e4m3_genoa_cycle;
203
+
204
+ nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(dot_f32x16);
205
+ nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(a_norm_sq_f32x16);
206
+ nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(b_norm_sq_f32x16);
207
+ *result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
208
+ }
209
+
210
+ NK_PUBLIC void nk_sqeuclidean_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
211
+ __m512 a_sq_f32x16 = _mm512_setzero_ps();
212
+ __m512 b_sq_f32x16 = _mm512_setzero_ps();
213
+ __m512 ab_f32x16 = _mm512_setzero_ps();
214
+ __m256i a_e5m2x32, b_e5m2x32;
215
+
216
+ nk_sqeuclidean_e5m2_genoa_cycle:
217
+ if (n < 32) {
218
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
219
+ a_e5m2x32 = _mm256_maskz_loadu_epi8(mask, a);
220
+ b_e5m2x32 = _mm256_maskz_loadu_epi8(mask, b);
221
+ n = 0;
222
+ }
223
+ else {
224
+ a_e5m2x32 = _mm256_loadu_epi8(a);
225
+ b_e5m2x32 = _mm256_loadu_epi8(b);
226
+ a += 32, b += 32, n -= 32;
227
+ }
228
+ __m512i a_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(a_e5m2x32);
229
+ __m512i b_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(b_e5m2x32);
230
+ a_sq_f32x16 = _mm512_dpbf16_ps(a_sq_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(a_bf16x32));
231
+ b_sq_f32x16 = _mm512_dpbf16_ps(b_sq_f32x16, nk_m512bh_from_m512i_(b_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
232
+ ab_f32x16 = _mm512_dpbf16_ps(ab_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
233
+ if (n) goto nk_sqeuclidean_e5m2_genoa_cycle;
234
+
235
+ // (a-b)² = a² + b² - 2ab
236
+ __m512 sum_sq_f32x16 = _mm512_add_ps(a_sq_f32x16, b_sq_f32x16);
237
+ *result = nk_reduce_add_f32x16_skylake_(_mm512_fnmadd_ps(_mm512_set1_ps(2.0f), ab_f32x16, sum_sq_f32x16));
238
+ }
239
+
240
+ NK_PUBLIC void nk_euclidean_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
241
+ nk_sqeuclidean_e5m2_genoa(a, b, n, result);
242
+ *result = nk_f32_sqrt_haswell(*result);
243
+ }
244
+
245
+ NK_PUBLIC void nk_angular_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
246
+ __m512 dot_f32x16 = _mm512_setzero_ps();
247
+ __m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
248
+ __m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
249
+ __m256i a_e5m2x32, b_e5m2x32;
250
+
251
+ nk_angular_e5m2_genoa_cycle:
252
+ if (n < 32) {
253
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
254
+ a_e5m2x32 = _mm256_maskz_loadu_epi8(mask, a);
255
+ b_e5m2x32 = _mm256_maskz_loadu_epi8(mask, b);
256
+ n = 0;
257
+ }
258
+ else {
259
+ a_e5m2x32 = _mm256_loadu_epi8(a);
260
+ b_e5m2x32 = _mm256_loadu_epi8(b);
261
+ a += 32, b += 32, n -= 32;
262
+ }
263
+ __m512i a_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(a_e5m2x32);
264
+ __m512i b_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(b_e5m2x32);
265
+ dot_f32x16 = _mm512_dpbf16_ps(dot_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
266
+ a_norm_sq_f32x16 = _mm512_dpbf16_ps(a_norm_sq_f32x16, nk_m512bh_from_m512i_(a_bf16x32),
267
+ nk_m512bh_from_m512i_(a_bf16x32));
268
+ b_norm_sq_f32x16 = _mm512_dpbf16_ps(b_norm_sq_f32x16, nk_m512bh_from_m512i_(b_bf16x32),
269
+ nk_m512bh_from_m512i_(b_bf16x32));
270
+ if (n) goto nk_angular_e5m2_genoa_cycle;
271
+
272
+ nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(dot_f32x16);
273
+ nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(a_norm_sq_f32x16);
274
+ nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(b_norm_sq_f32x16);
275
+ *result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
276
+ }
277
+
278
+ #if defined(__clang__)
279
+ #pragma clang attribute pop
280
+ #elif defined(__GNUC__)
281
+ #pragma GCC pop_options
282
+ #endif
283
+
284
+ #if defined(__cplusplus)
285
+ } // extern "C"
286
+ #endif
287
+
288
+ #endif // NK_TARGET_GENOA
289
+ #endif // NK_TARGET_X86_
290
+ #endif // NK_SPATIAL_GENOA_H