numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315) hide show
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -8,12 +8,11 @@
8
8
  *
9
9
  * @section spatial_skylake_instructions Key AVX-512 Spatial Instructions
10
10
  *
11
- * Intrinsic Instruction Latency Throughput Ports
12
- * _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy 0.5/cy p05
13
- * _mm512_sub_ps VSUBPS (ZMM, ZMM, ZMM) 4cy 0.5/cy p05
14
- * _mm512_rsqrt14_ps VRSQRT14PS (ZMM, ZMM) 4cy 1/cy p0
15
- * _mm512_sqrt_ps VSQRTPS (ZMM, ZMM) 12cy 3cy p0
16
- * _mm512_reduce_add_ps (sequence) ~8-10cy - -
11
+ * Intrinsic Instruction Skylake-X Genoa
12
+ * _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p01
13
+ * _mm512_sub_ps VSUBPS (ZMM, ZMM, ZMM) 4cy @ p05 3cy @ p23
14
+ * _mm512_rsqrt14_ps VRSQRT14PS (ZMM, ZMM) 7cy @ p0+p0+p05 5cy @ p01
15
+ * _mm512_sqrt_ps VSQRTPS (ZMM, ZMM) 20cy @ p0+p0+p05 15cy @ p01
17
16
  *
18
17
  * Distance computations benefit from Skylake-X's dual FMA units achieving 0.5cy throughput for
19
18
  * fused multiply-add operations. VRSQRT14PS provides ~14-bit precision reciprocal square root;
@@ -43,21 +42,21 @@ extern "C" {
43
42
 
44
43
  /** @brief Reciprocal square root of 16 floats with Newton-Raphson refinement (~28-bit precision). */
45
44
  NK_INTERNAL __m512 nk_rsqrt_f32x16_skylake_(__m512 x) {
46
- __m512 rsqrt = _mm512_rsqrt14_ps(x);
47
- __m512 nr = _mm512_mul_ps(_mm512_mul_ps(x, rsqrt), rsqrt);
48
- nr = _mm512_sub_ps(_mm512_set1_ps(3.0f), nr);
49
- return _mm512_mul_ps(_mm512_mul_ps(_mm512_set1_ps(0.5f), rsqrt), nr);
45
+ __m512 rsqrt_f32x16 = _mm512_rsqrt14_ps(x);
46
+ __m512 nr_f32x16 = _mm512_mul_ps(_mm512_mul_ps(x, rsqrt_f32x16), rsqrt_f32x16);
47
+ nr_f32x16 = _mm512_sub_ps(_mm512_set1_ps(3.0f), nr_f32x16);
48
+ return _mm512_mul_ps(_mm512_mul_ps(_mm512_set1_ps(0.5f), rsqrt_f32x16), nr_f32x16);
50
49
  }
51
50
 
52
51
  /** @brief Reciprocal square root of 8 doubles with Newton-Raphson refinement (~28-bit precision). */
53
52
  NK_INTERNAL __m512d nk_rsqrt_f64x8_skylake_(__m512d x) {
54
- __m512d rsqrt = _mm512_rsqrt14_pd(x);
55
- __m512d nr = _mm512_mul_pd(_mm512_mul_pd(x, rsqrt), rsqrt);
56
- nr = _mm512_sub_pd(_mm512_set1_pd(3.0), nr);
57
- return _mm512_mul_pd(_mm512_mul_pd(_mm512_set1_pd(0.5), rsqrt), nr);
53
+ __m512d rsqrt_f64x8 = _mm512_rsqrt14_pd(x);
54
+ __m512d nr_f64x8 = _mm512_mul_pd(_mm512_mul_pd(x, rsqrt_f64x8), rsqrt_f64x8);
55
+ nr_f64x8 = _mm512_sub_pd(_mm512_set1_pd(3.0), nr_f64x8);
56
+ return _mm512_mul_pd(_mm512_mul_pd(_mm512_set1_pd(0.5), rsqrt_f64x8), nr_f64x8);
58
57
  }
59
58
 
60
- #pragma region - Traditional Floats
59
+ #pragma region F32 and F64 Floats
61
60
 
62
61
  NK_PUBLIC void nk_sqeuclidean_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
63
62
  // Upcast to f64 for higher precision accumulation
@@ -282,8 +281,8 @@ NK_INTERNAL void nk_euclidean_through_f64_from_dot_skylake_(nk_b128_vec_t dots,
282
281
  results->xmm_ps = _mm256_cvtpd_ps(dist_f64x4);
283
282
  }
284
283
 
285
- #pragma endregion - Traditional Floats
286
- #pragma region - Smaller Floats
284
+ #pragma endregion F32 and F64 Floats
285
+ #pragma region F16 and BF16 Floats
287
286
 
288
287
  NK_PUBLIC void nk_sqeuclidean_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
289
288
  __m512 sum_f32x16 = _mm512_setzero_ps();
@@ -348,22 +347,22 @@ nk_angular_f16_skylake_cycle:
348
347
 
349
348
  NK_PUBLIC void nk_sqeuclidean_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
350
349
  __m512 sum_f32x16 = _mm512_setzero_ps();
351
- __m128i a_e4m3x16, b_e4m3x16;
350
+ __m128i a_e4m3_u8x16, b_e4m3_u8x16;
352
351
 
353
352
  nk_sqeuclidean_e4m3_skylake_cycle:
354
353
  if (n < 16) {
355
354
  __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
356
- a_e4m3x16 = _mm_maskz_loadu_epi8(mask, a);
357
- b_e4m3x16 = _mm_maskz_loadu_epi8(mask, b);
355
+ a_e4m3_u8x16 = _mm_maskz_loadu_epi8(mask, a);
356
+ b_e4m3_u8x16 = _mm_maskz_loadu_epi8(mask, b);
358
357
  n = 0;
359
358
  }
360
359
  else {
361
- a_e4m3x16 = _mm_loadu_si128((__m128i const *)a);
362
- b_e4m3x16 = _mm_loadu_si128((__m128i const *)b);
360
+ a_e4m3_u8x16 = _mm_loadu_si128((__m128i const *)a);
361
+ b_e4m3_u8x16 = _mm_loadu_si128((__m128i const *)b);
363
362
  a += 16, b += 16, n -= 16;
364
363
  }
365
- __m512 a_f32x16 = nk_e4m3x16_to_f32x16_skylake_(a_e4m3x16);
366
- __m512 b_f32x16 = nk_e4m3x16_to_f32x16_skylake_(b_e4m3x16);
364
+ __m512 a_f32x16 = nk_e4m3x16_to_f32x16_skylake_(a_e4m3_u8x16);
365
+ __m512 b_f32x16 = nk_e4m3x16_to_f32x16_skylake_(b_e4m3_u8x16);
367
366
  __m512 diff_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
368
367
  sum_f32x16 = _mm512_fmadd_ps(diff_f32x16, diff_f32x16, sum_f32x16);
369
368
  if (n) goto nk_sqeuclidean_e4m3_skylake_cycle;
@@ -380,22 +379,22 @@ NK_PUBLIC void nk_angular_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, n
380
379
  __m512 dot_f32x16 = _mm512_setzero_ps();
381
380
  __m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
382
381
  __m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
383
- __m128i a_e4m3x16, b_e4m3x16;
382
+ __m128i a_e4m3_u8x16, b_e4m3_u8x16;
384
383
 
385
384
  nk_angular_e4m3_skylake_cycle:
386
385
  if (n < 16) {
387
386
  __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
388
- a_e4m3x16 = _mm_maskz_loadu_epi8(mask, a);
389
- b_e4m3x16 = _mm_maskz_loadu_epi8(mask, b);
387
+ a_e4m3_u8x16 = _mm_maskz_loadu_epi8(mask, a);
388
+ b_e4m3_u8x16 = _mm_maskz_loadu_epi8(mask, b);
390
389
  n = 0;
391
390
  }
392
391
  else {
393
- a_e4m3x16 = _mm_loadu_si128((__m128i const *)a);
394
- b_e4m3x16 = _mm_loadu_si128((__m128i const *)b);
392
+ a_e4m3_u8x16 = _mm_loadu_si128((__m128i const *)a);
393
+ b_e4m3_u8x16 = _mm_loadu_si128((__m128i const *)b);
395
394
  a += 16, b += 16, n -= 16;
396
395
  }
397
- __m512 a_f32x16 = nk_e4m3x16_to_f32x16_skylake_(a_e4m3x16);
398
- __m512 b_f32x16 = nk_e4m3x16_to_f32x16_skylake_(b_e4m3x16);
396
+ __m512 a_f32x16 = nk_e4m3x16_to_f32x16_skylake_(a_e4m3_u8x16);
397
+ __m512 b_f32x16 = nk_e4m3x16_to_f32x16_skylake_(b_e4m3_u8x16);
399
398
  dot_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, dot_f32x16);
400
399
  a_norm_sq_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, a_norm_sq_f32x16);
401
400
  b_norm_sq_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, b_norm_sq_f32x16);
@@ -409,22 +408,22 @@ nk_angular_e4m3_skylake_cycle:
409
408
 
410
409
  NK_PUBLIC void nk_sqeuclidean_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
411
410
  __m512 sum_f32x16 = _mm512_setzero_ps();
412
- __m128i a_e5m2x16, b_e5m2x16;
411
+ __m128i a_e5m2_u8x16, b_e5m2_u8x16;
413
412
 
414
413
  nk_sqeuclidean_e5m2_skylake_cycle:
415
414
  if (n < 16) {
416
415
  __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
417
- a_e5m2x16 = _mm_maskz_loadu_epi8(mask, a);
418
- b_e5m2x16 = _mm_maskz_loadu_epi8(mask, b);
416
+ a_e5m2_u8x16 = _mm_maskz_loadu_epi8(mask, a);
417
+ b_e5m2_u8x16 = _mm_maskz_loadu_epi8(mask, b);
419
418
  n = 0;
420
419
  }
421
420
  else {
422
- a_e5m2x16 = _mm_loadu_si128((__m128i const *)a);
423
- b_e5m2x16 = _mm_loadu_si128((__m128i const *)b);
421
+ a_e5m2_u8x16 = _mm_loadu_si128((__m128i const *)a);
422
+ b_e5m2_u8x16 = _mm_loadu_si128((__m128i const *)b);
424
423
  a += 16, b += 16, n -= 16;
425
424
  }
426
- __m512 a_f32x16 = nk_e5m2x16_to_f32x16_skylake_(a_e5m2x16);
427
- __m512 b_f32x16 = nk_e5m2x16_to_f32x16_skylake_(b_e5m2x16);
425
+ __m512 a_f32x16 = nk_e5m2x16_to_f32x16_skylake_(a_e5m2_u8x16);
426
+ __m512 b_f32x16 = nk_e5m2x16_to_f32x16_skylake_(b_e5m2_u8x16);
428
427
  __m512 diff_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
429
428
  sum_f32x16 = _mm512_fmadd_ps(diff_f32x16, diff_f32x16, sum_f32x16);
430
429
  if (n) goto nk_sqeuclidean_e5m2_skylake_cycle;
@@ -441,22 +440,22 @@ NK_PUBLIC void nk_angular_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, n
441
440
  __m512 dot_f32x16 = _mm512_setzero_ps();
442
441
  __m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
443
442
  __m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
444
- __m128i a_e5m2x16, b_e5m2x16;
443
+ __m128i a_e5m2_u8x16, b_e5m2_u8x16;
445
444
 
446
445
  nk_angular_e5m2_skylake_cycle:
447
446
  if (n < 16) {
448
447
  __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
449
- a_e5m2x16 = _mm_maskz_loadu_epi8(mask, a);
450
- b_e5m2x16 = _mm_maskz_loadu_epi8(mask, b);
448
+ a_e5m2_u8x16 = _mm_maskz_loadu_epi8(mask, a);
449
+ b_e5m2_u8x16 = _mm_maskz_loadu_epi8(mask, b);
451
450
  n = 0;
452
451
  }
453
452
  else {
454
- a_e5m2x16 = _mm_loadu_si128((__m128i const *)a);
455
- b_e5m2x16 = _mm_loadu_si128((__m128i const *)b);
453
+ a_e5m2_u8x16 = _mm_loadu_si128((__m128i const *)a);
454
+ b_e5m2_u8x16 = _mm_loadu_si128((__m128i const *)b);
456
455
  a += 16, b += 16, n -= 16;
457
456
  }
458
- __m512 a_f32x16 = nk_e5m2x16_to_f32x16_skylake_(a_e5m2x16);
459
- __m512 b_f32x16 = nk_e5m2x16_to_f32x16_skylake_(b_e5m2x16);
457
+ __m512 a_f32x16 = nk_e5m2x16_to_f32x16_skylake_(a_e5m2_u8x16);
458
+ __m512 b_f32x16 = nk_e5m2x16_to_f32x16_skylake_(b_e5m2_u8x16);
460
459
  dot_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, dot_f32x16);
461
460
  a_norm_sq_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, a_norm_sq_f32x16);
462
461
  b_norm_sq_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, b_norm_sq_f32x16);
@@ -470,22 +469,22 @@ nk_angular_e5m2_skylake_cycle:
470
469
 
471
470
  NK_PUBLIC void nk_sqeuclidean_e2m3_skylake(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
472
471
  __m512 sum_f32x16 = _mm512_setzero_ps();
473
- __m128i a_e2m3x16, b_e2m3x16;
472
+ __m128i a_e2m3_u8x16, b_e2m3_u8x16;
474
473
 
475
474
  nk_sqeuclidean_e2m3_skylake_cycle:
476
475
  if (n < 16) {
477
476
  __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
478
- a_e2m3x16 = _mm_maskz_loadu_epi8(mask, a);
479
- b_e2m3x16 = _mm_maskz_loadu_epi8(mask, b);
477
+ a_e2m3_u8x16 = _mm_maskz_loadu_epi8(mask, a);
478
+ b_e2m3_u8x16 = _mm_maskz_loadu_epi8(mask, b);
480
479
  n = 0;
481
480
  }
482
481
  else {
483
- a_e2m3x16 = _mm_loadu_si128((__m128i const *)a);
484
- b_e2m3x16 = _mm_loadu_si128((__m128i const *)b);
482
+ a_e2m3_u8x16 = _mm_loadu_si128((__m128i const *)a);
483
+ b_e2m3_u8x16 = _mm_loadu_si128((__m128i const *)b);
485
484
  a += 16, b += 16, n -= 16;
486
485
  }
487
- __m512 a_f32x16 = nk_e2m3x16_to_f32x16_skylake_(a_e2m3x16);
488
- __m512 b_f32x16 = nk_e2m3x16_to_f32x16_skylake_(b_e2m3x16);
486
+ __m512 a_f32x16 = nk_e2m3x16_to_f32x16_skylake_(a_e2m3_u8x16);
487
+ __m512 b_f32x16 = nk_e2m3x16_to_f32x16_skylake_(b_e2m3_u8x16);
489
488
  __m512 diff_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
490
489
  sum_f32x16 = _mm512_fmadd_ps(diff_f32x16, diff_f32x16, sum_f32x16);
491
490
  if (n) goto nk_sqeuclidean_e2m3_skylake_cycle;
@@ -502,22 +501,22 @@ NK_PUBLIC void nk_angular_e2m3_skylake(nk_e2m3_t const *a, nk_e2m3_t const *b, n
502
501
  __m512 dot_f32x16 = _mm512_setzero_ps();
503
502
  __m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
504
503
  __m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
505
- __m128i a_e2m3x16, b_e2m3x16;
504
+ __m128i a_e2m3_u8x16, b_e2m3_u8x16;
506
505
 
507
506
  nk_angular_e2m3_skylake_cycle:
508
507
  if (n < 16) {
509
508
  __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
510
- a_e2m3x16 = _mm_maskz_loadu_epi8(mask, a);
511
- b_e2m3x16 = _mm_maskz_loadu_epi8(mask, b);
509
+ a_e2m3_u8x16 = _mm_maskz_loadu_epi8(mask, a);
510
+ b_e2m3_u8x16 = _mm_maskz_loadu_epi8(mask, b);
512
511
  n = 0;
513
512
  }
514
513
  else {
515
- a_e2m3x16 = _mm_loadu_si128((__m128i const *)a);
516
- b_e2m3x16 = _mm_loadu_si128((__m128i const *)b);
514
+ a_e2m3_u8x16 = _mm_loadu_si128((__m128i const *)a);
515
+ b_e2m3_u8x16 = _mm_loadu_si128((__m128i const *)b);
517
516
  a += 16, b += 16, n -= 16;
518
517
  }
519
- __m512 a_f32x16 = nk_e2m3x16_to_f32x16_skylake_(a_e2m3x16);
520
- __m512 b_f32x16 = nk_e2m3x16_to_f32x16_skylake_(b_e2m3x16);
518
+ __m512 a_f32x16 = nk_e2m3x16_to_f32x16_skylake_(a_e2m3_u8x16);
519
+ __m512 b_f32x16 = nk_e2m3x16_to_f32x16_skylake_(b_e2m3_u8x16);
521
520
  dot_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, dot_f32x16);
522
521
  a_norm_sq_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, a_norm_sq_f32x16);
523
522
  b_norm_sq_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, b_norm_sq_f32x16);
@@ -531,22 +530,22 @@ nk_angular_e2m3_skylake_cycle:
531
530
 
532
531
  NK_PUBLIC void nk_sqeuclidean_e3m2_skylake(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
533
532
  __m512 sum_f32x16 = _mm512_setzero_ps();
534
- __m128i a_e3m2x16, b_e3m2x16;
533
+ __m128i a_e3m2_u8x16, b_e3m2_u8x16;
535
534
 
536
535
  nk_sqeuclidean_e3m2_skylake_cycle:
537
536
  if (n < 16) {
538
537
  __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
539
- a_e3m2x16 = _mm_maskz_loadu_epi8(mask, a);
540
- b_e3m2x16 = _mm_maskz_loadu_epi8(mask, b);
538
+ a_e3m2_u8x16 = _mm_maskz_loadu_epi8(mask, a);
539
+ b_e3m2_u8x16 = _mm_maskz_loadu_epi8(mask, b);
541
540
  n = 0;
542
541
  }
543
542
  else {
544
- a_e3m2x16 = _mm_loadu_si128((__m128i const *)a);
545
- b_e3m2x16 = _mm_loadu_si128((__m128i const *)b);
543
+ a_e3m2_u8x16 = _mm_loadu_si128((__m128i const *)a);
544
+ b_e3m2_u8x16 = _mm_loadu_si128((__m128i const *)b);
546
545
  a += 16, b += 16, n -= 16;
547
546
  }
548
- __m512 a_f32x16 = nk_e3m2x16_to_f32x16_skylake_(a_e3m2x16);
549
- __m512 b_f32x16 = nk_e3m2x16_to_f32x16_skylake_(b_e3m2x16);
547
+ __m512 a_f32x16 = nk_e3m2x16_to_f32x16_skylake_(a_e3m2_u8x16);
548
+ __m512 b_f32x16 = nk_e3m2x16_to_f32x16_skylake_(b_e3m2_u8x16);
550
549
  __m512 diff_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
551
550
  sum_f32x16 = _mm512_fmadd_ps(diff_f32x16, diff_f32x16, sum_f32x16);
552
551
  if (n) goto nk_sqeuclidean_e3m2_skylake_cycle;
@@ -563,22 +562,22 @@ NK_PUBLIC void nk_angular_e3m2_skylake(nk_e3m2_t const *a, nk_e3m2_t const *b, n
563
562
  __m512 dot_f32x16 = _mm512_setzero_ps();
564
563
  __m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
565
564
  __m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
566
- __m128i a_e3m2x16, b_e3m2x16;
565
+ __m128i a_e3m2_u8x16, b_e3m2_u8x16;
567
566
 
568
567
  nk_angular_e3m2_skylake_cycle:
569
568
  if (n < 16) {
570
569
  __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
571
- a_e3m2x16 = _mm_maskz_loadu_epi8(mask, a);
572
- b_e3m2x16 = _mm_maskz_loadu_epi8(mask, b);
570
+ a_e3m2_u8x16 = _mm_maskz_loadu_epi8(mask, a);
571
+ b_e3m2_u8x16 = _mm_maskz_loadu_epi8(mask, b);
573
572
  n = 0;
574
573
  }
575
574
  else {
576
- a_e3m2x16 = _mm_loadu_si128((__m128i const *)a);
577
- b_e3m2x16 = _mm_loadu_si128((__m128i const *)b);
575
+ a_e3m2_u8x16 = _mm_loadu_si128((__m128i const *)a);
576
+ b_e3m2_u8x16 = _mm_loadu_si128((__m128i const *)b);
578
577
  a += 16, b += 16, n -= 16;
579
578
  }
580
- __m512 a_f32x16 = nk_e3m2x16_to_f32x16_skylake_(a_e3m2x16);
581
- __m512 b_f32x16 = nk_e3m2x16_to_f32x16_skylake_(b_e3m2x16);
579
+ __m512 a_f32x16 = nk_e3m2x16_to_f32x16_skylake_(a_e3m2_u8x16);
580
+ __m512 b_f32x16 = nk_e3m2x16_to_f32x16_skylake_(b_e3m2_u8x16);
582
581
  dot_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, dot_f32x16);
583
582
  a_norm_sq_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, a_norm_sq_f32x16);
584
583
  b_norm_sq_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, b_norm_sq_f32x16);
@@ -600,7 +599,7 @@ nk_angular_e3m2_skylake_cycle:
600
599
  } // extern "C"
601
600
  #endif
602
601
 
603
- #pragma endregion - Smaller Floats
602
+ #pragma endregion F16 and BF16 Floats
604
603
  #endif // NK_TARGET_SKYLAKE
605
604
  #endif // NK_TARGET_X86_
606
605
  #endif // NK_SPATIAL_SKYLAKE_H
@@ -8,19 +8,19 @@
8
8
  *
9
9
  * @section spatial_sve_instructions ARM SVE Instructions
10
10
  *
11
- * Intrinsic Instruction Latency Throughput
12
- * svld1_f32 LD1W (Z.S, P/Z, [Xn]) 4-6cy 2/cy
13
- * svsub_f32_x FSUB (Z.S, P/M, Z.S, Z.S) 3cy 2/cy
14
- * svmla_f32_x FMLA (Z.S, P/M, Z.S, Z.S) 4cy 2/cy
15
- * svaddv_f32 FADDV (S, P, Z.S) 6cy 1/cy
16
- * svdupq_n_f32 DUP (Z.S, #imm) 1cy 2/cy
17
- * svwhilelt_b32 WHILELT (P.S, Xn, Xm) 2cy 1/cy
18
- * svptrue_b32 PTRUE (P.S, pattern) 1cy 2/cy
19
- * svcntw CNTW (Xd) 1cy 2/cy
20
- * svld1_f64 LD1D (Z.D, P/Z, [Xn]) 4-6cy 2/cy
21
- * svsub_f64_x FSUB (Z.D, P/M, Z.D, Z.D) 3cy 2/cy
22
- * svmla_f64_x FMLA (Z.D, P/M, Z.D, Z.D) 4cy 2/cy
23
- * svaddv_f64 FADDV (D, P, Z.D) 6cy 1/cy
11
+ * Intrinsic Instruction V1
12
+ * svld1_f32 LD1W (Z.S, P/Z, [Xn]) 4-6cy @ 2p
13
+ * svsub_f32_x FSUB (Z.S, P/M, Z.S, Z.S) 3cy @ 2p
14
+ * svmla_f32_x FMLA (Z.S, P/M, Z.S, Z.S) 4cy @ 2p
15
+ * svaddv_f32 FADDV (S, P, Z.S) 6cy @ 1p
16
+ * svdupq_n_f32 DUP (Z.S, #imm) 1cy @ 2p
17
+ * svwhilelt_b32 WHILELT (P.S, Xn, Xm) 2cy @ 1p
18
+ * svptrue_b32 PTRUE (P.S, pattern) 1cy @ 2p
19
+ * svcntw CNTW (Xd) 1cy @ 2p
20
+ * svld1_f64 LD1D (Z.D, P/Z, [Xn]) 4-6cy @ 2p
21
+ * svsub_f64_x FSUB (Z.D, P/M, Z.D, Z.D) 3cy @ 2p
22
+ * svmla_f64_x FMLA (Z.D, P/M, Z.D, Z.D) 4cy @ 2p
23
+ * svaddv_f64 FADDV (D, P, Z.D) 6cy @ 1p
24
24
  *
25
25
  * SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
26
26
  * and Apple M4+ use 128-bit. Code using svcntb() adapts automatically, but wider vectors
@@ -63,10 +63,10 @@ extern "C" {
63
63
  * @param x Input vector (must be positive for meaningful results)
64
64
  * @return Approximate 1/sqrt(x) with ~23-bit mantissa accuracy
65
65
  */
66
- NK_INTERNAL svfloat32_t nk_rsqrt_f32x_sve_(svbool_t predicate, svfloat32_t x) NK_STREAMING_COMPATIBLE_ {
66
+ NK_INTERNAL svfloat32_t nk_rsqrt_f32x_sve_(svbool_t predicate_b32x, svfloat32_t x) NK_STREAMING_COMPATIBLE_ {
67
67
  svfloat32_t r = svrsqrte_f32(x);
68
- r = svmul_f32_x(predicate, r, svrsqrts_f32(svmul_f32_x(predicate, x, r), r));
69
- r = svmul_f32_x(predicate, r, svrsqrts_f32(svmul_f32_x(predicate, x, r), r));
68
+ r = svmul_f32_x(predicate_b32x, r, svrsqrts_f32(svmul_f32_x(predicate_b32x, x, r), r));
69
+ r = svmul_f32_x(predicate_b32x, r, svrsqrts_f32(svmul_f32_x(predicate_b32x, x, r), r));
70
70
  return r;
71
71
  }
72
72
 
@@ -79,29 +79,39 @@ NK_INTERNAL svfloat32_t nk_rsqrt_f32x_sve_(svbool_t predicate, svfloat32_t x) NK
79
79
  * Marked `__arm_streaming_compatible` so the helper is callable from both streaming
80
80
  * (SME) and non-streaming (SVE) contexts without mode transitions.
81
81
  *
82
- * @param predicate Active-lane mask
82
+ * @param predicate_b32x Active-lane mask
83
83
  * @param x Input vector (must be positive for meaningful results)
84
84
  * @return Approximate 1/sqrt(x) with ~52-bit mantissa accuracy
85
85
  */
86
- NK_INTERNAL svfloat64_t nk_rsqrt_f64x_sve_(svbool_t predicate, svfloat64_t x) NK_STREAMING_COMPATIBLE_ {
86
+ NK_INTERNAL svfloat64_t nk_rsqrt_f64x_sve_(svbool_t predicate_b64x, svfloat64_t x) NK_STREAMING_COMPATIBLE_ {
87
87
  svfloat64_t r = svrsqrte_f64(x);
88
- r = svmul_f64_x(predicate, r, svrsqrts_f64(svmul_f64_x(predicate, x, r), r));
89
- r = svmul_f64_x(predicate, r, svrsqrts_f64(svmul_f64_x(predicate, x, r), r));
90
- r = svmul_f64_x(predicate, r, svrsqrts_f64(svmul_f64_x(predicate, x, r), r));
88
+ r = svmul_f64_x(predicate_b64x, r, svrsqrts_f64(svmul_f64_x(predicate_b64x, x, r), r));
89
+ r = svmul_f64_x(predicate_b64x, r, svrsqrts_f64(svmul_f64_x(predicate_b64x, x, r), r));
90
+ r = svmul_f64_x(predicate_b64x, r, svrsqrts_f64(svmul_f64_x(predicate_b64x, x, r), r));
91
91
  return r;
92
92
  }
93
93
 
94
94
  NK_PUBLIC void nk_sqeuclidean_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
95
95
  nk_size_t i = 0;
96
- nk_size_t const vector_length = svcntd();
97
96
  svfloat64_t dist_sq_f64x = svdupq_n_f64(0.0, 0.0);
98
- for (; i < n; i += vector_length) {
99
- svbool_t predicate_f32x = svwhilelt_b32_u64(i, n);
100
- svbool_t predicate_f64x = svwhilelt_b64_u64(i, n);
101
- svfloat64_t a_f64x = svcvt_f64_f32_x(predicate_f64x, svld1_f32(predicate_f32x, a + i));
102
- svfloat64_t b_f64x = svcvt_f64_f32_x(predicate_f64x, svld1_f32(predicate_f32x, b + i));
103
- svfloat64_t diff_f64x = svsub_f64_x(predicate_f64x, a_f64x, b_f64x);
104
- dist_sq_f64x = svmla_f64_x(predicate_f64x, dist_sq_f64x, diff_f64x, diff_f64x);
97
+ for (; i < n; i += svcntw()) {
98
+ svbool_t predicate_b32x = svwhilelt_b32_u64(i, n);
99
+ svfloat32_t a_f32x = svld1_f32(predicate_b32x, a + i);
100
+ svfloat32_t b_f32x = svld1_f32(predicate_b32x, b + i);
101
+ nk_size_t remaining = n - i < svcntw() ? n - i : svcntw();
102
+
103
+ // svcvt_f64_f32_x widens only even-indexed f32 elements; svext by 1 shifts odd into even.
104
+ svbool_t pred_even_b64x = svwhilelt_b64_u64(0u, (remaining + 1) / 2);
105
+ svfloat64_t a_even_f64x = svcvt_f64_f32_x(pred_even_b64x, a_f32x);
106
+ svfloat64_t b_even_f64x = svcvt_f64_f32_x(pred_even_b64x, b_f32x);
107
+ svfloat64_t diff_even_f64x = svsub_f64_x(pred_even_b64x, a_even_f64x, b_even_f64x);
108
+ dist_sq_f64x = svmla_f64_m(pred_even_b64x, dist_sq_f64x, diff_even_f64x, diff_even_f64x);
109
+
110
+ svbool_t pred_odd_b64x = svwhilelt_b64_u64(0u, remaining / 2);
111
+ svfloat64_t a_odd_f64x = svcvt_f64_f32_x(pred_odd_b64x, svext_f32(a_f32x, a_f32x, 1));
112
+ svfloat64_t b_odd_f64x = svcvt_f64_f32_x(pred_odd_b64x, svext_f32(b_f32x, b_f32x, 1));
113
+ svfloat64_t diff_odd_f64x = svsub_f64_x(pred_odd_b64x, a_odd_f64x, b_odd_f64x);
114
+ dist_sq_f64x = svmla_f64_m(pred_odd_b64x, dist_sq_f64x, diff_odd_f64x, diff_odd_f64x);
105
115
  }
106
116
  nk_f64_t dist_sq_f64 = svaddv_f64(svptrue_b64(), dist_sq_f64x);
107
117
  *result = dist_sq_f64;
@@ -114,18 +124,29 @@ NK_PUBLIC void nk_euclidean_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_siz
114
124
 
115
125
  NK_PUBLIC void nk_angular_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
116
126
  nk_size_t i = 0;
117
- nk_size_t const vector_length = svcntd();
118
127
  svfloat64_t ab_f64x = svdupq_n_f64(0.0, 0.0);
119
128
  svfloat64_t a2_f64x = svdupq_n_f64(0.0, 0.0);
120
129
  svfloat64_t b2_f64x = svdupq_n_f64(0.0, 0.0);
121
- for (; i < n; i += vector_length) {
122
- svbool_t predicate_f32x = svwhilelt_b32_u64(i, n);
123
- svbool_t predicate_f64x = svwhilelt_b64_u64(i, n);
124
- svfloat64_t a_f64x = svcvt_f64_f32_x(predicate_f64x, svld1_f32(predicate_f32x, a + i));
125
- svfloat64_t b_f64x = svcvt_f64_f32_x(predicate_f64x, svld1_f32(predicate_f32x, b + i));
126
- ab_f64x = svmla_f64_x(predicate_f64x, ab_f64x, a_f64x, b_f64x);
127
- a2_f64x = svmla_f64_x(predicate_f64x, a2_f64x, a_f64x, a_f64x);
128
- b2_f64x = svmla_f64_x(predicate_f64x, b2_f64x, b_f64x, b_f64x);
130
+ for (; i < n; i += svcntw()) {
131
+ svbool_t predicate_b32x = svwhilelt_b32_u64(i, n);
132
+ svfloat32_t a_f32x = svld1_f32(predicate_b32x, a + i);
133
+ svfloat32_t b_f32x = svld1_f32(predicate_b32x, b + i);
134
+ nk_size_t remaining = n - i < svcntw() ? n - i : svcntw();
135
+
136
+ // svcvt_f64_f32_x widens only even-indexed f32 elements; svext by 1 shifts odd into even.
137
+ svbool_t pred_even_b64x = svwhilelt_b64_u64(0u, (remaining + 1) / 2);
138
+ svfloat64_t a_even_f64x = svcvt_f64_f32_x(pred_even_b64x, a_f32x);
139
+ svfloat64_t b_even_f64x = svcvt_f64_f32_x(pred_even_b64x, b_f32x);
140
+ ab_f64x = svmla_f64_m(pred_even_b64x, ab_f64x, a_even_f64x, b_even_f64x);
141
+ a2_f64x = svmla_f64_m(pred_even_b64x, a2_f64x, a_even_f64x, a_even_f64x);
142
+ b2_f64x = svmla_f64_m(pred_even_b64x, b2_f64x, b_even_f64x, b_even_f64x);
143
+
144
+ svbool_t pred_odd_b64x = svwhilelt_b64_u64(0u, remaining / 2);
145
+ svfloat64_t a_odd_f64x = svcvt_f64_f32_x(pred_odd_b64x, svext_f32(a_f32x, a_f32x, 1));
146
+ svfloat64_t b_odd_f64x = svcvt_f64_f32_x(pred_odd_b64x, svext_f32(b_f32x, b_f32x, 1));
147
+ ab_f64x = svmla_f64_m(pred_odd_b64x, ab_f64x, a_odd_f64x, b_odd_f64x);
148
+ a2_f64x = svmla_f64_m(pred_odd_b64x, a2_f64x, a_odd_f64x, a_odd_f64x);
149
+ b2_f64x = svmla_f64_m(pred_odd_b64x, b2_f64x, b_odd_f64x, b_odd_f64x);
129
150
  }
130
151
 
131
152
  nk_f64_t ab_f64 = svaddv_f64(svptrue_b64(), ab_f64x);
@@ -139,29 +160,29 @@ NK_PUBLIC void nk_sqeuclidean_f64_sve(nk_f64_t const *a, nk_f64_t const *b, nk_s
139
160
  nk_size_t i = 0;
140
161
  svfloat64_t sum_f64x = svdupq_n_f64(0.0, 0.0);
141
162
  svfloat64_t compensation_f64x = svdupq_n_f64(0.0, 0.0);
142
- svbool_t predicate_all_f64x = svptrue_b64();
163
+ svbool_t predicate_all_b64x = svptrue_b64();
143
164
  do {
144
- svbool_t predicate_f64x = svwhilelt_b64_u64(i, n);
145
- svfloat64_t a_f64x = svld1_f64(predicate_f64x, a + i);
146
- svfloat64_t b_f64x = svld1_f64(predicate_f64x, b + i);
147
- svfloat64_t diff_f64x = svsub_f64_x(predicate_f64x, a_f64x, b_f64x);
148
- svfloat64_t diff_sq_f64x = svmul_f64_x(predicate_f64x, diff_f64x, diff_f64x);
165
+ svbool_t predicate_b64x = svwhilelt_b64_u64(i, n);
166
+ svfloat64_t a_f64x = svld1_f64(predicate_b64x, a + i);
167
+ svfloat64_t b_f64x = svld1_f64(predicate_b64x, b + i);
168
+ svfloat64_t diff_f64x = svsub_f64_x(predicate_b64x, a_f64x, b_f64x);
169
+ svfloat64_t diff_sq_f64x = svmul_f64_x(predicate_b64x, diff_f64x, diff_f64x);
149
170
  // Neumaier: t = sum + x
150
- svfloat64_t t_f64x = svadd_f64_x(predicate_f64x, sum_f64x, diff_sq_f64x);
151
- svfloat64_t abs_sum_f64x = svabs_f64_x(predicate_f64x, sum_f64x);
171
+ svfloat64_t t_f64x = svadd_f64_m(predicate_b64x, sum_f64x, diff_sq_f64x);
172
+ svfloat64_t abs_sum_f64x = svabs_f64_x(predicate_b64x, sum_f64x);
152
173
  // diff_sq is already non-negative (it's a square), so svabs is unnecessary
153
- svbool_t sum_ge_x_f64x = svcmpge_f64(predicate_f64x, abs_sum_f64x, diff_sq_f64x);
174
+ svbool_t sum_ge_x_b64x = svcmpge_f64(predicate_b64x, abs_sum_f64x, diff_sq_f64x);
154
175
  // When |sum| >= |x|: comp += (sum - t) + x; when |x| > |sum|: comp += (x - t) + sum
155
- svfloat64_t comp_sum_large_f64x = svadd_f64_x(predicate_f64x, svsub_f64_x(predicate_f64x, sum_f64x, t_f64x),
176
+ svfloat64_t comp_sum_large_f64x = svadd_f64_x(predicate_b64x, svsub_f64_x(predicate_b64x, sum_f64x, t_f64x),
156
177
  diff_sq_f64x);
157
- svfloat64_t comp_x_large_f64x = svadd_f64_x(predicate_f64x, svsub_f64_x(predicate_f64x, diff_sq_f64x, t_f64x),
178
+ svfloat64_t comp_x_large_f64x = svadd_f64_x(predicate_b64x, svsub_f64_x(predicate_b64x, diff_sq_f64x, t_f64x),
158
179
  sum_f64x);
159
- svfloat64_t comp_update_f64x = svsel_f64(sum_ge_x_f64x, comp_sum_large_f64x, comp_x_large_f64x);
160
- compensation_f64x = svadd_f64_x(predicate_f64x, compensation_f64x, comp_update_f64x);
180
+ svfloat64_t comp_update_f64x = svsel_f64(sum_ge_x_b64x, comp_sum_large_f64x, comp_x_large_f64x);
181
+ compensation_f64x = svadd_f64_m(predicate_b64x, compensation_f64x, comp_update_f64x);
161
182
  sum_f64x = t_f64x;
162
183
  i += svcntd();
163
184
  } while (i < n);
164
- *result = nk_dot_stable_sum_f64_sve_(predicate_all_f64x, sum_f64x, compensation_f64x);
185
+ *result = nk_dot_stable_sum_f64_sve_(predicate_all_b64x, sum_f64x, compensation_f64x);
165
186
  }
166
187
 
167
188
  NK_PUBLIC void nk_euclidean_f64_sve(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
@@ -177,35 +198,35 @@ NK_PUBLIC void nk_angular_f64_sve(nk_f64_t const *a, nk_f64_t const *b, nk_size_
177
198
  svfloat64_t ab_compensation_f64x = svdupq_n_f64(0.0, 0.0);
178
199
  svfloat64_t a2_f64x = svdupq_n_f64(0.0, 0.0);
179
200
  svfloat64_t b2_f64x = svdupq_n_f64(0.0, 0.0);
180
- svbool_t predicate_all_f64x = svptrue_b64();
201
+ svbool_t predicate_all_b64x = svptrue_b64();
181
202
  do {
182
- svbool_t predicate_f64x = svwhilelt_b64_u64(i, n);
183
- svfloat64_t a_f64x = svld1_f64(predicate_f64x, a + i);
184
- svfloat64_t b_f64x = svld1_f64(predicate_f64x, b + i);
203
+ svbool_t predicate_b64x = svwhilelt_b64_u64(i, n);
204
+ svfloat64_t a_f64x = svld1_f64(predicate_b64x, a + i);
205
+ svfloat64_t b_f64x = svld1_f64(predicate_b64x, b + i);
185
206
  // TwoProd for ab: product = a*b, error = fma(a,b,-product) = -(product - a*b)
186
- svfloat64_t product_f64x = svmul_f64_x(predicate_f64x, a_f64x, b_f64x);
187
- svfloat64_t product_error_f64x = svneg_f64_x(predicate_f64x,
188
- svnmls_f64_x(predicate_f64x, product_f64x, a_f64x, b_f64x));
207
+ svfloat64_t product_f64x = svmul_f64_x(predicate_b64x, a_f64x, b_f64x);
208
+ svfloat64_t product_error_f64x = svneg_f64_x(predicate_b64x,
209
+ svnmls_f64_x(predicate_b64x, product_f64x, a_f64x, b_f64x));
189
210
  // TwoSum: (tentative_sum, sum_error) = TwoSum(sum, product)
190
- svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate_f64x, ab_sum_f64x, product_f64x);
191
- svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_f64x, tentative_sum_f64x, ab_sum_f64x);
211
+ svfloat64_t tentative_sum_f64x = svadd_f64_m(predicate_b64x, ab_sum_f64x, product_f64x);
212
+ svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_b64x, tentative_sum_f64x, ab_sum_f64x);
192
213
  svfloat64_t sum_error_f64x = svadd_f64_x(
193
- predicate_f64x,
194
- svsub_f64_x(predicate_f64x, ab_sum_f64x,
195
- svsub_f64_x(predicate_f64x, tentative_sum_f64x, virtual_addend_f64x)),
196
- svsub_f64_x(predicate_f64x, product_f64x, virtual_addend_f64x));
214
+ predicate_b64x,
215
+ svsub_f64_x(predicate_b64x, ab_sum_f64x,
216
+ svsub_f64_x(predicate_b64x, tentative_sum_f64x, virtual_addend_f64x)),
217
+ svsub_f64_x(predicate_b64x, product_f64x, virtual_addend_f64x));
197
218
  ab_sum_f64x = tentative_sum_f64x;
198
- ab_compensation_f64x = svadd_f64_x(predicate_f64x, ab_compensation_f64x,
199
- svadd_f64_x(predicate_f64x, sum_error_f64x, product_error_f64x));
219
+ ab_compensation_f64x = svadd_f64_m(predicate_b64x, ab_compensation_f64x,
220
+ svadd_f64_x(predicate_b64x, sum_error_f64x, product_error_f64x));
200
221
  // Simple FMA for self-products (no cancellation)
201
- a2_f64x = svmla_f64_x(predicate_f64x, a2_f64x, a_f64x, a_f64x);
202
- b2_f64x = svmla_f64_x(predicate_f64x, b2_f64x, b_f64x, b_f64x);
222
+ a2_f64x = svmla_f64_m(predicate_b64x, a2_f64x, a_f64x, a_f64x);
223
+ b2_f64x = svmla_f64_m(predicate_b64x, b2_f64x, b_f64x, b_f64x);
203
224
  i += svcntd();
204
225
  } while (i < n);
205
226
 
206
- nk_f64_t ab_f64 = nk_dot_stable_sum_f64_sve_(predicate_all_f64x, ab_sum_f64x, ab_compensation_f64x);
207
- nk_f64_t a2_f64 = svaddv_f64(predicate_all_f64x, a2_f64x);
208
- nk_f64_t b2_f64 = svaddv_f64(predicate_all_f64x, b2_f64x);
227
+ nk_f64_t ab_f64 = nk_dot_stable_sum_f64_sve_(predicate_all_b64x, ab_sum_f64x, ab_compensation_f64x);
228
+ nk_f64_t a2_f64 = svaddv_f64(predicate_all_b64x, a2_f64x);
229
+ nk_f64_t b2_f64 = svaddv_f64(predicate_all_b64x, b2_f64x);
209
230
  *result = nk_angular_normalize_f64_neon_(ab_f64, a2_f64, b2_f64);
210
231
  }
211
232
 
@@ -8,19 +8,19 @@
8
8
  *
9
9
  * @section spatial_svebfdot_instructions ARM SVE+BF16 Instructions
10
10
  *
11
- * Intrinsic Instruction Latency Throughput
12
- * svld1_bf16 LD1H (Z.H, P/Z, [Xn]) 4-6cy 2/cy
13
- * svld1_u16 LD1H (Z.H, P/Z, [Xn]) 4-6cy 2/cy
14
- * svbfdot_f32 BFDOT (Z.S, Z.H, Z.H) 4cy 2/cy
15
- * svmla_f32_x FMLA (Z.S, P/M, Z.S, Z.S) 4cy 2/cy
16
- * svsub_f32_x FSUB (Z.S, P/M, Z.S, Z.S) 3cy 2/cy
17
- * svaddv_f32 FADDV (S, P, Z.S) 6cy 1/cy
18
- * svunpklo_u32 UUNPKLO (Z.S, Z.H) 2cy 2/cy
19
- * svunpkhi_u32 UUNPKHI (Z.S, Z.H) 2cy 2/cy
20
- * svlsl_n_u32_x LSL (Z.S, P/M, Z.S, #imm) 2cy 2/cy
21
- * svwhilelt_b16 WHILELT (P.H, Xn, Xm) 2cy 1/cy
22
- * svwhilelt_b32 WHILELT (P.S, Xn, Xm) 2cy 1/cy
23
- * svcnth CNTH (Xd) 1cy 2/cy
11
+ * Intrinsic Instruction V1
12
+ * svld1_bf16 LD1H (Z.H, P/Z, [Xn]) 4-6cy @ 2p
13
+ * svld1_u16 LD1H (Z.H, P/Z, [Xn]) 4-6cy @ 2p
14
+ * svbfdot_f32 BFDOT (Z.S, Z.H, Z.H) 4cy @ 2p
15
+ * svmla_f32_x FMLA (Z.S, P/M, Z.S, Z.S) 4cy @ 2p
16
+ * svsub_f32_x FSUB (Z.S, P/M, Z.S, Z.S) 3cy @ 2p
17
+ * svaddv_f32 FADDV (S, P, Z.S) 6cy @ 1p
18
+ * svunpklo_u32 UUNPKLO (Z.S, Z.H) 2cy @ 2p
19
+ * svunpkhi_u32 UUNPKHI (Z.S, Z.H) 2cy @ 2p
20
+ * svlsl_n_u32_x LSL (Z.S, P/M, Z.S, #imm) 2cy @ 2p
21
+ * svwhilelt_b16 WHILELT (P.H, Xn, Xm) 2cy @ 1p
22
+ * svwhilelt_b32 WHILELT (P.S, Xn, Xm) 2cy @ 1p
23
+ * svcnth CNTH (Xd) 1cy @ 2p
24
24
  *
25
25
  * SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
26
26
  * and Apple M4+ use 128-bit. Code using svcntb() adapts automatically, but wider vectors
@@ -57,22 +57,22 @@ NK_PUBLIC void nk_sqeuclidean_bf16_svebfdot(nk_bf16_t const *a_enum, nk_bf16_t c
57
57
  nk_u16_t const *a = (nk_u16_t const *)(a_enum);
58
58
  nk_u16_t const *b = (nk_u16_t const *)(b_enum);
59
59
  do {
60
- svbool_t predicate_bf16x = svwhilelt_b16_u64(i, n);
61
- svuint16_t a_u16x = svld1_u16(predicate_bf16x, a + i);
62
- svuint16_t b_u16x = svld1_u16(predicate_bf16x, b + i);
60
+ svbool_t predicate_b16x = svwhilelt_b16_u64(i, n);
61
+ svuint16_t a_u16x = svld1_u16(predicate_b16x, a + i);
62
+ svuint16_t b_u16x = svld1_u16(predicate_b16x, b + i);
63
63
 
64
64
  // There is no `bf16` subtraction in SVE, so we need to convert to `u32` and shift.
65
- svbool_t predicate_low_f32x = svwhilelt_b32_u64(i, n);
66
- svbool_t predicate_high_f32x = svwhilelt_b32_u64(i + svcnth() / 2, n);
67
- svfloat32_t a_low_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(predicate_low_f32x, svunpklo_u32(a_u16x), 16));
68
- svfloat32_t a_high_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(predicate_high_f32x, svunpkhi_u32(a_u16x), 16));
69
- svfloat32_t b_low_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(predicate_low_f32x, svunpklo_u32(b_u16x), 16));
70
- svfloat32_t b_high_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(predicate_high_f32x, svunpkhi_u32(b_u16x), 16));
65
+ svbool_t predicate_low_b32x = svwhilelt_b32_u64(i, n);
66
+ svbool_t predicate_high_b32x = svwhilelt_b32_u64(i + svcnth() / 2, n);
67
+ svfloat32_t a_low_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(predicate_low_b32x, svunpklo_u32(a_u16x), 16));
68
+ svfloat32_t a_high_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(predicate_high_b32x, svunpkhi_u32(a_u16x), 16));
69
+ svfloat32_t b_low_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(predicate_low_b32x, svunpklo_u32(b_u16x), 16));
70
+ svfloat32_t b_high_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(predicate_high_b32x, svunpkhi_u32(b_u16x), 16));
71
71
 
72
- svfloat32_t a_minus_b_low_f32x = svsub_f32_x(predicate_low_f32x, a_low_f32x, b_low_f32x);
73
- svfloat32_t a_minus_b_high_f32x = svsub_f32_x(predicate_high_f32x, a_high_f32x, b_high_f32x);
74
- d2_low_f32x = svmla_f32_x(predicate_bf16x, d2_low_f32x, a_minus_b_low_f32x, a_minus_b_low_f32x);
75
- d2_high_f32x = svmla_f32_x(predicate_bf16x, d2_high_f32x, a_minus_b_high_f32x, a_minus_b_high_f32x);
72
+ svfloat32_t a_minus_b_low_f32x = svsub_f32_x(predicate_low_b32x, a_low_f32x, b_low_f32x);
73
+ svfloat32_t a_minus_b_high_f32x = svsub_f32_x(predicate_high_b32x, a_high_f32x, b_high_f32x);
74
+ d2_low_f32x = svmla_f32_m(predicate_low_b32x, d2_low_f32x, a_minus_b_low_f32x, a_minus_b_low_f32x);
75
+ d2_high_f32x = svmla_f32_m(predicate_high_b32x, d2_high_f32x, a_minus_b_high_f32x, a_minus_b_high_f32x);
76
76
  i += svcnth();
77
77
  } while (i < n);
78
78
  nk_f32_t d2 = svaddv_f32(svptrue_b32(), d2_low_f32x) + svaddv_f32(svptrue_b32(), d2_high_f32x);
@@ -92,9 +92,9 @@ NK_PUBLIC void nk_angular_bf16_svebfdot(nk_bf16_t const *a_enum, nk_bf16_t const
92
92
  nk_bf16_for_arm_simd_t const *a = (nk_bf16_for_arm_simd_t const *)(a_enum);
93
93
  nk_bf16_for_arm_simd_t const *b = (nk_bf16_for_arm_simd_t const *)(b_enum);
94
94
  do {
95
- svbool_t predicate_bf16x = svwhilelt_b16_u64(i, n);
96
- svbfloat16_t a_bf16x = svld1_bf16(predicate_bf16x, a + i);
97
- svbfloat16_t b_bf16x = svld1_bf16(predicate_bf16x, b + i);
95
+ svbool_t predicate_b16x = svwhilelt_b16_u64(i, n);
96
+ svbfloat16_t a_bf16x = svld1_bf16(predicate_b16x, a + i);
97
+ svbfloat16_t b_bf16x = svld1_bf16(predicate_b16x, b + i);
98
98
  ab_f32x = svbfdot_f32(ab_f32x, a_bf16x, b_bf16x);
99
99
  a2_f32x = svbfdot_f32(a2_f32x, a_bf16x, a_bf16x);
100
100
  b2_f32x = svbfdot_f32(b2_f32x, b_bf16x, b_bf16x);