numpy-2.3.5-cp313-cp313-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of numpy might be problematic. Click here for more details.

Files changed (897)
  1. numpy/__config__.py +170 -0
  2. numpy/__config__.pyi +102 -0
  3. numpy/__init__.cython-30.pxd +1241 -0
  4. numpy/__init__.pxd +1154 -0
  5. numpy/__init__.py +945 -0
  6. numpy/__init__.pyi +6147 -0
  7. numpy/_array_api_info.py +346 -0
  8. numpy/_array_api_info.pyi +207 -0
  9. numpy/_configtool.py +39 -0
  10. numpy/_configtool.pyi +1 -0
  11. numpy/_core/__init__.py +186 -0
  12. numpy/_core/__init__.pyi +2 -0
  13. numpy/_core/_add_newdocs.py +6967 -0
  14. numpy/_core/_add_newdocs.pyi +3 -0
  15. numpy/_core/_add_newdocs_scalars.py +390 -0
  16. numpy/_core/_add_newdocs_scalars.pyi +16 -0
  17. numpy/_core/_asarray.py +134 -0
  18. numpy/_core/_asarray.pyi +41 -0
  19. numpy/_core/_dtype.py +366 -0
  20. numpy/_core/_dtype.pyi +58 -0
  21. numpy/_core/_dtype_ctypes.py +120 -0
  22. numpy/_core/_dtype_ctypes.pyi +83 -0
  23. numpy/_core/_exceptions.py +162 -0
  24. numpy/_core/_exceptions.pyi +55 -0
  25. numpy/_core/_internal.py +958 -0
  26. numpy/_core/_internal.pyi +72 -0
  27. numpy/_core/_machar.py +355 -0
  28. numpy/_core/_machar.pyi +55 -0
  29. numpy/_core/_methods.py +255 -0
  30. numpy/_core/_methods.pyi +22 -0
  31. numpy/_core/_multiarray_tests.cpython-313-darwin.so +0 -0
  32. numpy/_core/_multiarray_umath.cpython-313-darwin.so +0 -0
  33. numpy/_core/_operand_flag_tests.cpython-313-darwin.so +0 -0
  34. numpy/_core/_rational_tests.cpython-313-darwin.so +0 -0
  35. numpy/_core/_simd.cpython-313-darwin.so +0 -0
  36. numpy/_core/_simd.pyi +25 -0
  37. numpy/_core/_string_helpers.py +100 -0
  38. numpy/_core/_string_helpers.pyi +12 -0
  39. numpy/_core/_struct_ufunc_tests.cpython-313-darwin.so +0 -0
  40. numpy/_core/_type_aliases.py +119 -0
  41. numpy/_core/_type_aliases.pyi +97 -0
  42. numpy/_core/_ufunc_config.py +491 -0
  43. numpy/_core/_ufunc_config.pyi +78 -0
  44. numpy/_core/_umath_tests.cpython-313-darwin.so +0 -0
  45. numpy/_core/arrayprint.py +1775 -0
  46. numpy/_core/arrayprint.pyi +238 -0
  47. numpy/_core/cversions.py +13 -0
  48. numpy/_core/defchararray.py +1427 -0
  49. numpy/_core/defchararray.pyi +1135 -0
  50. numpy/_core/einsumfunc.py +1498 -0
  51. numpy/_core/einsumfunc.pyi +184 -0
  52. numpy/_core/fromnumeric.py +4269 -0
  53. numpy/_core/fromnumeric.pyi +1750 -0
  54. numpy/_core/function_base.py +545 -0
  55. numpy/_core/function_base.pyi +278 -0
  56. numpy/_core/getlimits.py +748 -0
  57. numpy/_core/getlimits.pyi +3 -0
  58. numpy/_core/include/numpy/__multiarray_api.c +376 -0
  59. numpy/_core/include/numpy/__multiarray_api.h +1628 -0
  60. numpy/_core/include/numpy/__ufunc_api.c +54 -0
  61. numpy/_core/include/numpy/__ufunc_api.h +341 -0
  62. numpy/_core/include/numpy/_neighborhood_iterator_imp.h +90 -0
  63. numpy/_core/include/numpy/_numpyconfig.h +33 -0
  64. numpy/_core/include/numpy/_public_dtype_api_table.h +86 -0
  65. numpy/_core/include/numpy/arrayobject.h +7 -0
  66. numpy/_core/include/numpy/arrayscalars.h +196 -0
  67. numpy/_core/include/numpy/dtype_api.h +480 -0
  68. numpy/_core/include/numpy/halffloat.h +70 -0
  69. numpy/_core/include/numpy/ndarrayobject.h +304 -0
  70. numpy/_core/include/numpy/ndarraytypes.h +1950 -0
  71. numpy/_core/include/numpy/npy_2_compat.h +249 -0
  72. numpy/_core/include/numpy/npy_2_complexcompat.h +28 -0
  73. numpy/_core/include/numpy/npy_3kcompat.h +374 -0
  74. numpy/_core/include/numpy/npy_common.h +977 -0
  75. numpy/_core/include/numpy/npy_cpu.h +124 -0
  76. numpy/_core/include/numpy/npy_endian.h +78 -0
  77. numpy/_core/include/numpy/npy_math.h +602 -0
  78. numpy/_core/include/numpy/npy_no_deprecated_api.h +20 -0
  79. numpy/_core/include/numpy/npy_os.h +42 -0
  80. numpy/_core/include/numpy/numpyconfig.h +182 -0
  81. numpy/_core/include/numpy/random/LICENSE.txt +21 -0
  82. numpy/_core/include/numpy/random/bitgen.h +20 -0
  83. numpy/_core/include/numpy/random/distributions.h +209 -0
  84. numpy/_core/include/numpy/random/libdivide.h +2079 -0
  85. numpy/_core/include/numpy/ufuncobject.h +343 -0
  86. numpy/_core/include/numpy/utils.h +37 -0
  87. numpy/_core/lib/libnpymath.a +0 -0
  88. numpy/_core/lib/npy-pkg-config/mlib.ini +12 -0
  89. numpy/_core/lib/npy-pkg-config/npymath.ini +20 -0
  90. numpy/_core/lib/pkgconfig/numpy.pc +7 -0
  91. numpy/_core/memmap.py +363 -0
  92. numpy/_core/memmap.pyi +3 -0
  93. numpy/_core/multiarray.py +1762 -0
  94. numpy/_core/multiarray.pyi +1285 -0
  95. numpy/_core/numeric.py +2760 -0
  96. numpy/_core/numeric.pyi +882 -0
  97. numpy/_core/numerictypes.py +633 -0
  98. numpy/_core/numerictypes.pyi +197 -0
  99. numpy/_core/overrides.py +183 -0
  100. numpy/_core/overrides.pyi +48 -0
  101. numpy/_core/printoptions.py +32 -0
  102. numpy/_core/printoptions.pyi +28 -0
  103. numpy/_core/records.py +1089 -0
  104. numpy/_core/records.pyi +333 -0
  105. numpy/_core/shape_base.py +998 -0
  106. numpy/_core/shape_base.pyi +175 -0
  107. numpy/_core/strings.py +1829 -0
  108. numpy/_core/strings.pyi +511 -0
  109. numpy/_core/tests/_locales.py +72 -0
  110. numpy/_core/tests/_natype.py +205 -0
  111. numpy/_core/tests/data/astype_copy.pkl +0 -0
  112. numpy/_core/tests/data/generate_umath_validation_data.cpp +170 -0
  113. numpy/_core/tests/data/recarray_from_file.fits +0 -0
  114. numpy/_core/tests/data/umath-validation-set-README.txt +15 -0
  115. numpy/_core/tests/data/umath-validation-set-arccos.csv +1429 -0
  116. numpy/_core/tests/data/umath-validation-set-arccosh.csv +1429 -0
  117. numpy/_core/tests/data/umath-validation-set-arcsin.csv +1429 -0
  118. numpy/_core/tests/data/umath-validation-set-arcsinh.csv +1429 -0
  119. numpy/_core/tests/data/umath-validation-set-arctan.csv +1429 -0
  120. numpy/_core/tests/data/umath-validation-set-arctanh.csv +1429 -0
  121. numpy/_core/tests/data/umath-validation-set-cbrt.csv +1429 -0
  122. numpy/_core/tests/data/umath-validation-set-cos.csv +1375 -0
  123. numpy/_core/tests/data/umath-validation-set-cosh.csv +1429 -0
  124. numpy/_core/tests/data/umath-validation-set-exp.csv +412 -0
  125. numpy/_core/tests/data/umath-validation-set-exp2.csv +1429 -0
  126. numpy/_core/tests/data/umath-validation-set-expm1.csv +1429 -0
  127. numpy/_core/tests/data/umath-validation-set-log.csv +271 -0
  128. numpy/_core/tests/data/umath-validation-set-log10.csv +1629 -0
  129. numpy/_core/tests/data/umath-validation-set-log1p.csv +1429 -0
  130. numpy/_core/tests/data/umath-validation-set-log2.csv +1629 -0
  131. numpy/_core/tests/data/umath-validation-set-sin.csv +1370 -0
  132. numpy/_core/tests/data/umath-validation-set-sinh.csv +1429 -0
  133. numpy/_core/tests/data/umath-validation-set-tan.csv +1429 -0
  134. numpy/_core/tests/data/umath-validation-set-tanh.csv +1429 -0
  135. numpy/_core/tests/examples/cython/checks.pyx +373 -0
  136. numpy/_core/tests/examples/cython/meson.build +43 -0
  137. numpy/_core/tests/examples/cython/setup.py +39 -0
  138. numpy/_core/tests/examples/limited_api/limited_api1.c +17 -0
  139. numpy/_core/tests/examples/limited_api/limited_api2.pyx +11 -0
  140. numpy/_core/tests/examples/limited_api/limited_api_latest.c +19 -0
  141. numpy/_core/tests/examples/limited_api/meson.build +59 -0
  142. numpy/_core/tests/examples/limited_api/setup.py +24 -0
  143. numpy/_core/tests/test__exceptions.py +90 -0
  144. numpy/_core/tests/test_abc.py +54 -0
  145. numpy/_core/tests/test_api.py +654 -0
  146. numpy/_core/tests/test_argparse.py +92 -0
  147. numpy/_core/tests/test_array_api_info.py +113 -0
  148. numpy/_core/tests/test_array_coercion.py +911 -0
  149. numpy/_core/tests/test_array_interface.py +222 -0
  150. numpy/_core/tests/test_arraymethod.py +84 -0
  151. numpy/_core/tests/test_arrayobject.py +75 -0
  152. numpy/_core/tests/test_arrayprint.py +1328 -0
  153. numpy/_core/tests/test_casting_floatingpoint_errors.py +154 -0
  154. numpy/_core/tests/test_casting_unittests.py +817 -0
  155. numpy/_core/tests/test_conversion_utils.py +206 -0
  156. numpy/_core/tests/test_cpu_dispatcher.py +49 -0
  157. numpy/_core/tests/test_cpu_features.py +432 -0
  158. numpy/_core/tests/test_custom_dtypes.py +315 -0
  159. numpy/_core/tests/test_cython.py +351 -0
  160. numpy/_core/tests/test_datetime.py +2734 -0
  161. numpy/_core/tests/test_defchararray.py +825 -0
  162. numpy/_core/tests/test_deprecations.py +454 -0
  163. numpy/_core/tests/test_dlpack.py +190 -0
  164. numpy/_core/tests/test_dtype.py +1995 -0
  165. numpy/_core/tests/test_einsum.py +1317 -0
  166. numpy/_core/tests/test_errstate.py +131 -0
  167. numpy/_core/tests/test_extint128.py +217 -0
  168. numpy/_core/tests/test_function_base.py +503 -0
  169. numpy/_core/tests/test_getlimits.py +205 -0
  170. numpy/_core/tests/test_half.py +568 -0
  171. numpy/_core/tests/test_hashtable.py +35 -0
  172. numpy/_core/tests/test_indexerrors.py +125 -0
  173. numpy/_core/tests/test_indexing.py +1455 -0
  174. numpy/_core/tests/test_item_selection.py +167 -0
  175. numpy/_core/tests/test_limited_api.py +102 -0
  176. numpy/_core/tests/test_longdouble.py +369 -0
  177. numpy/_core/tests/test_machar.py +30 -0
  178. numpy/_core/tests/test_mem_overlap.py +930 -0
  179. numpy/_core/tests/test_mem_policy.py +452 -0
  180. numpy/_core/tests/test_memmap.py +246 -0
  181. numpy/_core/tests/test_multiarray.py +10577 -0
  182. numpy/_core/tests/test_multithreading.py +292 -0
  183. numpy/_core/tests/test_nditer.py +3498 -0
  184. numpy/_core/tests/test_nep50_promotions.py +287 -0
  185. numpy/_core/tests/test_numeric.py +4247 -0
  186. numpy/_core/tests/test_numerictypes.py +651 -0
  187. numpy/_core/tests/test_overrides.py +791 -0
  188. numpy/_core/tests/test_print.py +200 -0
  189. numpy/_core/tests/test_protocols.py +46 -0
  190. numpy/_core/tests/test_records.py +544 -0
  191. numpy/_core/tests/test_regression.py +2670 -0
  192. numpy/_core/tests/test_scalar_ctors.py +207 -0
  193. numpy/_core/tests/test_scalar_methods.py +246 -0
  194. numpy/_core/tests/test_scalarbuffer.py +153 -0
  195. numpy/_core/tests/test_scalarinherit.py +105 -0
  196. numpy/_core/tests/test_scalarmath.py +1176 -0
  197. numpy/_core/tests/test_scalarprint.py +403 -0
  198. numpy/_core/tests/test_shape_base.py +891 -0
  199. numpy/_core/tests/test_simd.py +1341 -0
  200. numpy/_core/tests/test_simd_module.py +103 -0
  201. numpy/_core/tests/test_stringdtype.py +1814 -0
  202. numpy/_core/tests/test_strings.py +1499 -0
  203. numpy/_core/tests/test_ufunc.py +3313 -0
  204. numpy/_core/tests/test_umath.py +4928 -0
  205. numpy/_core/tests/test_umath_accuracy.py +124 -0
  206. numpy/_core/tests/test_umath_complex.py +626 -0
  207. numpy/_core/tests/test_unicode.py +368 -0
  208. numpy/_core/umath.py +60 -0
  209. numpy/_core/umath.pyi +197 -0
  210. numpy/_distributor_init.py +15 -0
  211. numpy/_distributor_init.pyi +1 -0
  212. numpy/_expired_attrs_2_0.py +79 -0
  213. numpy/_expired_attrs_2_0.pyi +62 -0
  214. numpy/_globals.py +96 -0
  215. numpy/_globals.pyi +17 -0
  216. numpy/_pyinstaller/__init__.py +0 -0
  217. numpy/_pyinstaller/__init__.pyi +0 -0
  218. numpy/_pyinstaller/hook-numpy.py +36 -0
  219. numpy/_pyinstaller/hook-numpy.pyi +13 -0
  220. numpy/_pyinstaller/tests/__init__.py +16 -0
  221. numpy/_pyinstaller/tests/pyinstaller-smoke.py +32 -0
  222. numpy/_pyinstaller/tests/test_pyinstaller.py +35 -0
  223. numpy/_pytesttester.py +201 -0
  224. numpy/_pytesttester.pyi +18 -0
  225. numpy/_typing/__init__.py +148 -0
  226. numpy/_typing/_add_docstring.py +153 -0
  227. numpy/_typing/_array_like.py +106 -0
  228. numpy/_typing/_char_codes.py +213 -0
  229. numpy/_typing/_dtype_like.py +114 -0
  230. numpy/_typing/_extended_precision.py +15 -0
  231. numpy/_typing/_nbit.py +19 -0
  232. numpy/_typing/_nbit_base.py +94 -0
  233. numpy/_typing/_nbit_base.pyi +40 -0
  234. numpy/_typing/_nested_sequence.py +79 -0
  235. numpy/_typing/_scalars.py +20 -0
  236. numpy/_typing/_shape.py +8 -0
  237. numpy/_typing/_ufunc.py +7 -0
  238. numpy/_typing/_ufunc.pyi +941 -0
  239. numpy/_utils/__init__.py +95 -0
  240. numpy/_utils/__init__.pyi +30 -0
  241. numpy/_utils/_convertions.py +18 -0
  242. numpy/_utils/_convertions.pyi +4 -0
  243. numpy/_utils/_inspect.py +192 -0
  244. numpy/_utils/_inspect.pyi +71 -0
  245. numpy/_utils/_pep440.py +486 -0
  246. numpy/_utils/_pep440.pyi +121 -0
  247. numpy/char/__init__.py +2 -0
  248. numpy/char/__init__.pyi +111 -0
  249. numpy/conftest.py +258 -0
  250. numpy/core/__init__.py +33 -0
  251. numpy/core/__init__.pyi +0 -0
  252. numpy/core/_dtype.py +10 -0
  253. numpy/core/_dtype.pyi +0 -0
  254. numpy/core/_dtype_ctypes.py +10 -0
  255. numpy/core/_dtype_ctypes.pyi +0 -0
  256. numpy/core/_internal.py +27 -0
  257. numpy/core/_multiarray_umath.py +57 -0
  258. numpy/core/_utils.py +21 -0
  259. numpy/core/arrayprint.py +10 -0
  260. numpy/core/defchararray.py +10 -0
  261. numpy/core/einsumfunc.py +10 -0
  262. numpy/core/fromnumeric.py +10 -0
  263. numpy/core/function_base.py +10 -0
  264. numpy/core/getlimits.py +10 -0
  265. numpy/core/multiarray.py +25 -0
  266. numpy/core/numeric.py +12 -0
  267. numpy/core/numerictypes.py +10 -0
  268. numpy/core/overrides.py +10 -0
  269. numpy/core/overrides.pyi +7 -0
  270. numpy/core/records.py +10 -0
  271. numpy/core/shape_base.py +10 -0
  272. numpy/core/umath.py +10 -0
  273. numpy/ctypeslib/__init__.py +13 -0
  274. numpy/ctypeslib/__init__.pyi +33 -0
  275. numpy/ctypeslib/_ctypeslib.py +603 -0
  276. numpy/ctypeslib/_ctypeslib.pyi +245 -0
  277. numpy/doc/ufuncs.py +138 -0
  278. numpy/dtypes.py +41 -0
  279. numpy/dtypes.pyi +631 -0
  280. numpy/exceptions.py +247 -0
  281. numpy/exceptions.pyi +27 -0
  282. numpy/f2py/__init__.py +86 -0
  283. numpy/f2py/__init__.pyi +6 -0
  284. numpy/f2py/__main__.py +5 -0
  285. numpy/f2py/__version__.py +1 -0
  286. numpy/f2py/__version__.pyi +1 -0
  287. numpy/f2py/_backends/__init__.py +9 -0
  288. numpy/f2py/_backends/__init__.pyi +5 -0
  289. numpy/f2py/_backends/_backend.py +44 -0
  290. numpy/f2py/_backends/_backend.pyi +46 -0
  291. numpy/f2py/_backends/_distutils.py +76 -0
  292. numpy/f2py/_backends/_distutils.pyi +13 -0
  293. numpy/f2py/_backends/_meson.py +231 -0
  294. numpy/f2py/_backends/_meson.pyi +63 -0
  295. numpy/f2py/_backends/meson.build.template +55 -0
  296. numpy/f2py/_isocbind.py +62 -0
  297. numpy/f2py/_isocbind.pyi +13 -0
  298. numpy/f2py/_src_pyf.py +247 -0
  299. numpy/f2py/_src_pyf.pyi +29 -0
  300. numpy/f2py/auxfuncs.py +1004 -0
  301. numpy/f2py/auxfuncs.pyi +264 -0
  302. numpy/f2py/capi_maps.py +811 -0
  303. numpy/f2py/capi_maps.pyi +33 -0
  304. numpy/f2py/cb_rules.py +665 -0
  305. numpy/f2py/cb_rules.pyi +17 -0
  306. numpy/f2py/cfuncs.py +1563 -0
  307. numpy/f2py/cfuncs.pyi +31 -0
  308. numpy/f2py/common_rules.py +143 -0
  309. numpy/f2py/common_rules.pyi +9 -0
  310. numpy/f2py/crackfortran.py +3725 -0
  311. numpy/f2py/crackfortran.pyi +258 -0
  312. numpy/f2py/diagnose.py +149 -0
  313. numpy/f2py/diagnose.pyi +1 -0
  314. numpy/f2py/f2py2e.py +786 -0
  315. numpy/f2py/f2py2e.pyi +76 -0
  316. numpy/f2py/f90mod_rules.py +269 -0
  317. numpy/f2py/f90mod_rules.pyi +16 -0
  318. numpy/f2py/func2subr.py +329 -0
  319. numpy/f2py/func2subr.pyi +7 -0
  320. numpy/f2py/rules.py +1629 -0
  321. numpy/f2py/rules.pyi +43 -0
  322. numpy/f2py/setup.cfg +3 -0
  323. numpy/f2py/src/fortranobject.c +1436 -0
  324. numpy/f2py/src/fortranobject.h +173 -0
  325. numpy/f2py/symbolic.py +1516 -0
  326. numpy/f2py/symbolic.pyi +221 -0
  327. numpy/f2py/tests/__init__.py +16 -0
  328. numpy/f2py/tests/src/abstract_interface/foo.f90 +34 -0
  329. numpy/f2py/tests/src/abstract_interface/gh18403_mod.f90 +6 -0
  330. numpy/f2py/tests/src/array_from_pyobj/wrapmodule.c +235 -0
  331. numpy/f2py/tests/src/assumed_shape/.f2py_f2cmap +1 -0
  332. numpy/f2py/tests/src/assumed_shape/foo_free.f90 +34 -0
  333. numpy/f2py/tests/src/assumed_shape/foo_mod.f90 +41 -0
  334. numpy/f2py/tests/src/assumed_shape/foo_use.f90 +19 -0
  335. numpy/f2py/tests/src/assumed_shape/precision.f90 +4 -0
  336. numpy/f2py/tests/src/block_docstring/foo.f +6 -0
  337. numpy/f2py/tests/src/callback/foo.f +62 -0
  338. numpy/f2py/tests/src/callback/gh17797.f90 +7 -0
  339. numpy/f2py/tests/src/callback/gh18335.f90 +17 -0
  340. numpy/f2py/tests/src/callback/gh25211.f +10 -0
  341. numpy/f2py/tests/src/callback/gh25211.pyf +18 -0
  342. numpy/f2py/tests/src/callback/gh26681.f90 +18 -0
  343. numpy/f2py/tests/src/cli/gh_22819.pyf +6 -0
  344. numpy/f2py/tests/src/cli/hi77.f +3 -0
  345. numpy/f2py/tests/src/cli/hiworld.f90 +3 -0
  346. numpy/f2py/tests/src/common/block.f +11 -0
  347. numpy/f2py/tests/src/common/gh19161.f90 +10 -0
  348. numpy/f2py/tests/src/crackfortran/accesstype.f90 +13 -0
  349. numpy/f2py/tests/src/crackfortran/common_with_division.f +17 -0
  350. numpy/f2py/tests/src/crackfortran/data_common.f +8 -0
  351. numpy/f2py/tests/src/crackfortran/data_multiplier.f +5 -0
  352. numpy/f2py/tests/src/crackfortran/data_stmts.f90 +20 -0
  353. numpy/f2py/tests/src/crackfortran/data_with_comments.f +8 -0
  354. numpy/f2py/tests/src/crackfortran/foo_deps.f90 +6 -0
  355. numpy/f2py/tests/src/crackfortran/gh15035.f +16 -0
  356. numpy/f2py/tests/src/crackfortran/gh17859.f +12 -0
  357. numpy/f2py/tests/src/crackfortran/gh22648.pyf +7 -0
  358. numpy/f2py/tests/src/crackfortran/gh23533.f +5 -0
  359. numpy/f2py/tests/src/crackfortran/gh23598.f90 +4 -0
  360. numpy/f2py/tests/src/crackfortran/gh23598Warn.f90 +11 -0
  361. numpy/f2py/tests/src/crackfortran/gh23879.f90 +20 -0
  362. numpy/f2py/tests/src/crackfortran/gh27697.f90 +12 -0
  363. numpy/f2py/tests/src/crackfortran/gh2848.f90 +13 -0
  364. numpy/f2py/tests/src/crackfortran/operators.f90 +49 -0
  365. numpy/f2py/tests/src/crackfortran/privatemod.f90 +11 -0
  366. numpy/f2py/tests/src/crackfortran/publicmod.f90 +10 -0
  367. numpy/f2py/tests/src/crackfortran/pubprivmod.f90 +10 -0
  368. numpy/f2py/tests/src/crackfortran/unicode_comment.f90 +4 -0
  369. numpy/f2py/tests/src/f2cmap/.f2py_f2cmap +1 -0
  370. numpy/f2py/tests/src/f2cmap/isoFortranEnvMap.f90 +9 -0
  371. numpy/f2py/tests/src/isocintrin/isoCtests.f90 +34 -0
  372. numpy/f2py/tests/src/kind/foo.f90 +20 -0
  373. numpy/f2py/tests/src/mixed/foo.f +5 -0
  374. numpy/f2py/tests/src/mixed/foo_fixed.f90 +8 -0
  375. numpy/f2py/tests/src/mixed/foo_free.f90 +8 -0
  376. numpy/f2py/tests/src/modules/gh25337/data.f90 +8 -0
  377. numpy/f2py/tests/src/modules/gh25337/use_data.f90 +6 -0
  378. numpy/f2py/tests/src/modules/gh26920/two_mods_with_no_public_entities.f90 +21 -0
  379. numpy/f2py/tests/src/modules/gh26920/two_mods_with_one_public_routine.f90 +21 -0
  380. numpy/f2py/tests/src/modules/module_data_docstring.f90 +12 -0
  381. numpy/f2py/tests/src/modules/use_modules.f90 +20 -0
  382. numpy/f2py/tests/src/negative_bounds/issue_20853.f90 +7 -0
  383. numpy/f2py/tests/src/parameter/constant_array.f90 +45 -0
  384. numpy/f2py/tests/src/parameter/constant_both.f90 +57 -0
  385. numpy/f2py/tests/src/parameter/constant_compound.f90 +15 -0
  386. numpy/f2py/tests/src/parameter/constant_integer.f90 +22 -0
  387. numpy/f2py/tests/src/parameter/constant_non_compound.f90 +23 -0
  388. numpy/f2py/tests/src/parameter/constant_real.f90 +23 -0
  389. numpy/f2py/tests/src/quoted_character/foo.f +14 -0
  390. numpy/f2py/tests/src/regression/AB.inc +1 -0
  391. numpy/f2py/tests/src/regression/assignOnlyModule.f90 +25 -0
  392. numpy/f2py/tests/src/regression/datonly.f90 +17 -0
  393. numpy/f2py/tests/src/regression/f77comments.f +26 -0
  394. numpy/f2py/tests/src/regression/f77fixedform.f95 +5 -0
  395. numpy/f2py/tests/src/regression/f90continuation.f90 +9 -0
  396. numpy/f2py/tests/src/regression/incfile.f90 +5 -0
  397. numpy/f2py/tests/src/regression/inout.f90 +9 -0
  398. numpy/f2py/tests/src/regression/lower_f2py_fortran.f90 +5 -0
  399. numpy/f2py/tests/src/regression/mod_derived_types.f90 +23 -0
  400. numpy/f2py/tests/src/return_character/foo77.f +45 -0
  401. numpy/f2py/tests/src/return_character/foo90.f90 +48 -0
  402. numpy/f2py/tests/src/return_complex/foo77.f +45 -0
  403. numpy/f2py/tests/src/return_complex/foo90.f90 +48 -0
  404. numpy/f2py/tests/src/return_integer/foo77.f +56 -0
  405. numpy/f2py/tests/src/return_integer/foo90.f90 +59 -0
  406. numpy/f2py/tests/src/return_logical/foo77.f +56 -0
  407. numpy/f2py/tests/src/return_logical/foo90.f90 +59 -0
  408. numpy/f2py/tests/src/return_real/foo77.f +45 -0
  409. numpy/f2py/tests/src/return_real/foo90.f90 +48 -0
  410. numpy/f2py/tests/src/routines/funcfortranname.f +5 -0
  411. numpy/f2py/tests/src/routines/funcfortranname.pyf +11 -0
  412. numpy/f2py/tests/src/routines/subrout.f +4 -0
  413. numpy/f2py/tests/src/routines/subrout.pyf +10 -0
  414. numpy/f2py/tests/src/size/foo.f90 +44 -0
  415. numpy/f2py/tests/src/string/char.f90 +29 -0
  416. numpy/f2py/tests/src/string/fixed_string.f90 +34 -0
  417. numpy/f2py/tests/src/string/gh24008.f +8 -0
  418. numpy/f2py/tests/src/string/gh24662.f90 +7 -0
  419. numpy/f2py/tests/src/string/gh25286.f90 +14 -0
  420. numpy/f2py/tests/src/string/gh25286.pyf +12 -0
  421. numpy/f2py/tests/src/string/gh25286_bc.pyf +12 -0
  422. numpy/f2py/tests/src/string/scalar_string.f90 +9 -0
  423. numpy/f2py/tests/src/string/string.f +12 -0
  424. numpy/f2py/tests/src/value_attrspec/gh21665.f90 +9 -0
  425. numpy/f2py/tests/test_abstract_interface.py +26 -0
  426. numpy/f2py/tests/test_array_from_pyobj.py +678 -0
  427. numpy/f2py/tests/test_assumed_shape.py +50 -0
  428. numpy/f2py/tests/test_block_docstring.py +20 -0
  429. numpy/f2py/tests/test_callback.py +263 -0
  430. numpy/f2py/tests/test_character.py +641 -0
  431. numpy/f2py/tests/test_common.py +23 -0
  432. numpy/f2py/tests/test_crackfortran.py +421 -0
  433. numpy/f2py/tests/test_data.py +71 -0
  434. numpy/f2py/tests/test_docs.py +64 -0
  435. numpy/f2py/tests/test_f2cmap.py +17 -0
  436. numpy/f2py/tests/test_f2py2e.py +964 -0
  437. numpy/f2py/tests/test_isoc.py +56 -0
  438. numpy/f2py/tests/test_kind.py +53 -0
  439. numpy/f2py/tests/test_mixed.py +35 -0
  440. numpy/f2py/tests/test_modules.py +83 -0
  441. numpy/f2py/tests/test_parameter.py +129 -0
  442. numpy/f2py/tests/test_pyf_src.py +43 -0
  443. numpy/f2py/tests/test_quoted_character.py +18 -0
  444. numpy/f2py/tests/test_regression.py +187 -0
  445. numpy/f2py/tests/test_return_character.py +48 -0
  446. numpy/f2py/tests/test_return_complex.py +67 -0
  447. numpy/f2py/tests/test_return_integer.py +55 -0
  448. numpy/f2py/tests/test_return_logical.py +65 -0
  449. numpy/f2py/tests/test_return_real.py +109 -0
  450. numpy/f2py/tests/test_routines.py +29 -0
  451. numpy/f2py/tests/test_semicolon_split.py +75 -0
  452. numpy/f2py/tests/test_size.py +45 -0
  453. numpy/f2py/tests/test_string.py +100 -0
  454. numpy/f2py/tests/test_symbolic.py +495 -0
  455. numpy/f2py/tests/test_value_attrspec.py +15 -0
  456. numpy/f2py/tests/util.py +442 -0
  457. numpy/f2py/use_rules.py +99 -0
  458. numpy/f2py/use_rules.pyi +9 -0
  459. numpy/fft/__init__.py +215 -0
  460. numpy/fft/__init__.pyi +43 -0
  461. numpy/fft/_helper.py +235 -0
  462. numpy/fft/_helper.pyi +45 -0
  463. numpy/fft/_pocketfft.py +1693 -0
  464. numpy/fft/_pocketfft.pyi +138 -0
  465. numpy/fft/_pocketfft_umath.cpython-313-darwin.so +0 -0
  466. numpy/fft/helper.py +17 -0
  467. numpy/fft/helper.pyi +22 -0
  468. numpy/fft/tests/__init__.py +0 -0
  469. numpy/fft/tests/test_helper.py +167 -0
  470. numpy/fft/tests/test_pocketfft.py +589 -0
  471. numpy/lib/__init__.py +97 -0
  472. numpy/lib/__init__.pyi +44 -0
  473. numpy/lib/_array_utils_impl.py +62 -0
  474. numpy/lib/_array_utils_impl.pyi +26 -0
  475. numpy/lib/_arraypad_impl.py +890 -0
  476. numpy/lib/_arraypad_impl.pyi +89 -0
  477. numpy/lib/_arraysetops_impl.py +1260 -0
  478. numpy/lib/_arraysetops_impl.pyi +468 -0
  479. numpy/lib/_arrayterator_impl.py +224 -0
  480. numpy/lib/_arrayterator_impl.pyi +46 -0
  481. numpy/lib/_datasource.py +700 -0
  482. numpy/lib/_datasource.pyi +31 -0
  483. numpy/lib/_format_impl.py +1036 -0
  484. numpy/lib/_format_impl.pyi +26 -0
  485. numpy/lib/_function_base_impl.py +5844 -0
  486. numpy/lib/_function_base_impl.pyi +1164 -0
  487. numpy/lib/_histograms_impl.py +1085 -0
  488. numpy/lib/_histograms_impl.pyi +50 -0
  489. numpy/lib/_index_tricks_impl.py +1067 -0
  490. numpy/lib/_index_tricks_impl.pyi +208 -0
  491. numpy/lib/_iotools.py +900 -0
  492. numpy/lib/_iotools.pyi +114 -0
  493. numpy/lib/_nanfunctions_impl.py +2024 -0
  494. numpy/lib/_nanfunctions_impl.pyi +52 -0
  495. numpy/lib/_npyio_impl.py +2596 -0
  496. numpy/lib/_npyio_impl.pyi +301 -0
  497. numpy/lib/_polynomial_impl.py +1465 -0
  498. numpy/lib/_polynomial_impl.pyi +318 -0
  499. numpy/lib/_scimath_impl.py +642 -0
  500. numpy/lib/_scimath_impl.pyi +93 -0
  501. numpy/lib/_shape_base_impl.py +1301 -0
  502. numpy/lib/_shape_base_impl.pyi +235 -0
  503. numpy/lib/_stride_tricks_impl.py +549 -0
  504. numpy/lib/_stride_tricks_impl.pyi +74 -0
  505. numpy/lib/_twodim_base_impl.py +1201 -0
  506. numpy/lib/_twodim_base_impl.pyi +438 -0
  507. numpy/lib/_type_check_impl.py +699 -0
  508. numpy/lib/_type_check_impl.pyi +350 -0
  509. numpy/lib/_ufunclike_impl.py +207 -0
  510. numpy/lib/_ufunclike_impl.pyi +67 -0
  511. numpy/lib/_user_array_impl.py +299 -0
  512. numpy/lib/_user_array_impl.pyi +225 -0
  513. numpy/lib/_utils_impl.py +784 -0
  514. numpy/lib/_utils_impl.pyi +10 -0
  515. numpy/lib/_version.py +154 -0
  516. numpy/lib/_version.pyi +17 -0
  517. numpy/lib/array_utils.py +7 -0
  518. numpy/lib/array_utils.pyi +12 -0
  519. numpy/lib/format.py +24 -0
  520. numpy/lib/format.pyi +66 -0
  521. numpy/lib/introspect.py +95 -0
  522. numpy/lib/introspect.pyi +3 -0
  523. numpy/lib/mixins.py +180 -0
  524. numpy/lib/mixins.pyi +77 -0
  525. numpy/lib/npyio.py +1 -0
  526. numpy/lib/npyio.pyi +9 -0
  527. numpy/lib/recfunctions.py +1681 -0
  528. numpy/lib/recfunctions.pyi +435 -0
  529. numpy/lib/scimath.py +13 -0
  530. numpy/lib/scimath.pyi +30 -0
  531. numpy/lib/stride_tricks.py +1 -0
  532. numpy/lib/stride_tricks.pyi +6 -0
  533. numpy/lib/tests/__init__.py +0 -0
  534. numpy/lib/tests/data/py2-np0-objarr.npy +0 -0
  535. numpy/lib/tests/data/py2-objarr.npy +0 -0
  536. numpy/lib/tests/data/py2-objarr.npz +0 -0
  537. numpy/lib/tests/data/py3-objarr.npy +0 -0
  538. numpy/lib/tests/data/py3-objarr.npz +0 -0
  539. numpy/lib/tests/data/python3.npy +0 -0
  540. numpy/lib/tests/data/win64python2.npy +0 -0
  541. numpy/lib/tests/test__datasource.py +352 -0
  542. numpy/lib/tests/test__iotools.py +360 -0
  543. numpy/lib/tests/test__version.py +64 -0
  544. numpy/lib/tests/test_array_utils.py +32 -0
  545. numpy/lib/tests/test_arraypad.py +1415 -0
  546. numpy/lib/tests/test_arraysetops.py +1074 -0
  547. numpy/lib/tests/test_arrayterator.py +46 -0
  548. numpy/lib/tests/test_format.py +1054 -0
  549. numpy/lib/tests/test_function_base.py +4573 -0
  550. numpy/lib/tests/test_histograms.py +855 -0
  551. numpy/lib/tests/test_index_tricks.py +573 -0
  552. numpy/lib/tests/test_io.py +2848 -0
  553. numpy/lib/tests/test_loadtxt.py +1101 -0
  554. numpy/lib/tests/test_mixins.py +215 -0
  555. numpy/lib/tests/test_nanfunctions.py +1438 -0
  556. numpy/lib/tests/test_packbits.py +376 -0
  557. numpy/lib/tests/test_polynomial.py +320 -0
  558. numpy/lib/tests/test_recfunctions.py +1052 -0
  559. numpy/lib/tests/test_regression.py +231 -0
  560. numpy/lib/tests/test_shape_base.py +813 -0
  561. numpy/lib/tests/test_stride_tricks.py +656 -0
  562. numpy/lib/tests/test_twodim_base.py +559 -0
  563. numpy/lib/tests/test_type_check.py +473 -0
  564. numpy/lib/tests/test_ufunclike.py +97 -0
  565. numpy/lib/tests/test_utils.py +80 -0
  566. numpy/lib/user_array.py +1 -0
  567. numpy/lib/user_array.pyi +1 -0
  568. numpy/linalg/__init__.py +98 -0
  569. numpy/linalg/__init__.pyi +73 -0
  570. numpy/linalg/_linalg.py +3682 -0
  571. numpy/linalg/_linalg.pyi +475 -0
  572. numpy/linalg/_umath_linalg.cpython-313-darwin.so +0 -0
  573. numpy/linalg/_umath_linalg.pyi +61 -0
  574. numpy/linalg/lapack_lite.cpython-313-darwin.so +0 -0
  575. numpy/linalg/lapack_lite.pyi +141 -0
  576. numpy/linalg/linalg.py +17 -0
  577. numpy/linalg/linalg.pyi +69 -0
  578. numpy/linalg/tests/__init__.py +0 -0
  579. numpy/linalg/tests/test_deprecations.py +20 -0
  580. numpy/linalg/tests/test_linalg.py +2443 -0
  581. numpy/linalg/tests/test_regression.py +181 -0
  582. numpy/ma/API_CHANGES.txt +135 -0
  583. numpy/ma/LICENSE +24 -0
  584. numpy/ma/README.rst +236 -0
  585. numpy/ma/__init__.py +53 -0
  586. numpy/ma/__init__.pyi +458 -0
  587. numpy/ma/core.py +8933 -0
  588. numpy/ma/core.pyi +1462 -0
  589. numpy/ma/extras.py +2344 -0
  590. numpy/ma/extras.pyi +138 -0
  591. numpy/ma/mrecords.py +773 -0
  592. numpy/ma/mrecords.pyi +96 -0
  593. numpy/ma/tests/__init__.py +0 -0
  594. numpy/ma/tests/test_arrayobject.py +40 -0
  595. numpy/ma/tests/test_core.py +5886 -0
  596. numpy/ma/tests/test_deprecations.py +87 -0
  597. numpy/ma/tests/test_extras.py +1998 -0
  598. numpy/ma/tests/test_mrecords.py +497 -0
  599. numpy/ma/tests/test_old_ma.py +942 -0
  600. numpy/ma/tests/test_regression.py +100 -0
  601. numpy/ma/tests/test_subclassing.py +469 -0
  602. numpy/ma/testutils.py +294 -0
  603. numpy/matlib.py +380 -0
  604. numpy/matlib.pyi +582 -0
  605. numpy/matrixlib/__init__.py +12 -0
  606. numpy/matrixlib/__init__.pyi +5 -0
  607. numpy/matrixlib/defmatrix.py +1119 -0
  608. numpy/matrixlib/defmatrix.pyi +17 -0
  609. numpy/matrixlib/tests/__init__.py +0 -0
  610. numpy/matrixlib/tests/test_defmatrix.py +455 -0
  611. numpy/matrixlib/tests/test_interaction.py +360 -0
  612. numpy/matrixlib/tests/test_masked_matrix.py +240 -0
  613. numpy/matrixlib/tests/test_matrix_linalg.py +105 -0
  614. numpy/matrixlib/tests/test_multiarray.py +17 -0
  615. numpy/matrixlib/tests/test_numeric.py +18 -0
  616. numpy/matrixlib/tests/test_regression.py +31 -0
  617. numpy/polynomial/__init__.py +187 -0
  618. numpy/polynomial/__init__.pyi +25 -0
  619. numpy/polynomial/_polybase.py +1191 -0
  620. numpy/polynomial/_polybase.pyi +285 -0
  621. numpy/polynomial/_polytypes.pyi +892 -0
  622. numpy/polynomial/chebyshev.py +2003 -0
  623. numpy/polynomial/chebyshev.pyi +181 -0
  624. numpy/polynomial/hermite.py +1740 -0
  625. numpy/polynomial/hermite.pyi +107 -0
  626. numpy/polynomial/hermite_e.py +1642 -0
  627. numpy/polynomial/hermite_e.pyi +107 -0
  628. numpy/polynomial/laguerre.py +1675 -0
  629. numpy/polynomial/laguerre.pyi +100 -0
  630. numpy/polynomial/legendre.py +1605 -0
  631. numpy/polynomial/legendre.pyi +100 -0
  632. numpy/polynomial/polynomial.py +1616 -0
  633. numpy/polynomial/polynomial.pyi +89 -0
  634. numpy/polynomial/polyutils.py +759 -0
  635. numpy/polynomial/polyutils.pyi +423 -0
  636. numpy/polynomial/tests/__init__.py +0 -0
  637. numpy/polynomial/tests/test_chebyshev.py +623 -0
  638. numpy/polynomial/tests/test_classes.py +618 -0
  639. numpy/polynomial/tests/test_hermite.py +558 -0
  640. numpy/polynomial/tests/test_hermite_e.py +559 -0
  641. numpy/polynomial/tests/test_laguerre.py +540 -0
  642. numpy/polynomial/tests/test_legendre.py +571 -0
  643. numpy/polynomial/tests/test_polynomial.py +669 -0
  644. numpy/polynomial/tests/test_polyutils.py +128 -0
  645. numpy/polynomial/tests/test_printing.py +555 -0
  646. numpy/polynomial/tests/test_symbol.py +217 -0
  647. numpy/py.typed +0 -0
  648. numpy/random/LICENSE.md +71 -0
  649. numpy/random/__init__.pxd +14 -0
  650. numpy/random/__init__.py +213 -0
  651. numpy/random/__init__.pyi +124 -0
  652. numpy/random/_bounded_integers.cpython-313-darwin.so +0 -0
  653. numpy/random/_bounded_integers.pxd +29 -0
  654. numpy/random/_bounded_integers.pyi +1 -0
  655. numpy/random/_common.cpython-313-darwin.so +0 -0
  656. numpy/random/_common.pxd +107 -0
  657. numpy/random/_common.pyi +16 -0
  658. numpy/random/_examples/cffi/extending.py +44 -0
  659. numpy/random/_examples/cffi/parse.py +53 -0
  660. numpy/random/_examples/cython/extending.pyx +77 -0
  661. numpy/random/_examples/cython/extending_distributions.pyx +118 -0
  662. numpy/random/_examples/cython/meson.build +53 -0
  663. numpy/random/_examples/numba/extending.py +86 -0
  664. numpy/random/_examples/numba/extending_distributions.py +67 -0
  665. numpy/random/_generator.cpython-313-darwin.so +0 -0
  666. numpy/random/_generator.pyi +861 -0
  667. numpy/random/_mt19937.cpython-313-darwin.so +0 -0
  668. numpy/random/_mt19937.pyi +25 -0
  669. numpy/random/_pcg64.cpython-313-darwin.so +0 -0
  670. numpy/random/_pcg64.pyi +44 -0
  671. numpy/random/_philox.cpython-313-darwin.so +0 -0
  672. numpy/random/_philox.pyi +39 -0
  673. numpy/random/_pickle.py +88 -0
  674. numpy/random/_pickle.pyi +43 -0
  675. numpy/random/_sfc64.cpython-313-darwin.so +0 -0
  676. numpy/random/_sfc64.pyi +28 -0
  677. numpy/random/bit_generator.cpython-313-darwin.so +0 -0
  678. numpy/random/bit_generator.pxd +35 -0
  679. numpy/random/bit_generator.pyi +124 -0
  680. numpy/random/c_distributions.pxd +119 -0
  681. numpy/random/lib/libnpyrandom.a +0 -0
  682. numpy/random/mtrand.cpython-313-darwin.so +0 -0
  683. numpy/random/mtrand.pyi +703 -0
  684. numpy/random/tests/__init__.py +0 -0
  685. numpy/random/tests/data/__init__.py +0 -0
  686. numpy/random/tests/data/generator_pcg64_np121.pkl.gz +0 -0
  687. numpy/random/tests/data/generator_pcg64_np126.pkl.gz +0 -0
  688. numpy/random/tests/data/mt19937-testset-1.csv +1001 -0
  689. numpy/random/tests/data/mt19937-testset-2.csv +1001 -0
  690. numpy/random/tests/data/pcg64-testset-1.csv +1001 -0
  691. numpy/random/tests/data/pcg64-testset-2.csv +1001 -0
  692. numpy/random/tests/data/pcg64dxsm-testset-1.csv +1001 -0
  693. numpy/random/tests/data/pcg64dxsm-testset-2.csv +1001 -0
  694. numpy/random/tests/data/philox-testset-1.csv +1001 -0
  695. numpy/random/tests/data/philox-testset-2.csv +1001 -0
  696. numpy/random/tests/data/sfc64-testset-1.csv +1001 -0
  697. numpy/random/tests/data/sfc64-testset-2.csv +1001 -0
  698. numpy/random/tests/data/sfc64_np126.pkl.gz +0 -0
  699. numpy/random/tests/test_direct.py +592 -0
  700. numpy/random/tests/test_extending.py +127 -0
  701. numpy/random/tests/test_generator_mt19937.py +2809 -0
  702. numpy/random/tests/test_generator_mt19937_regressions.py +207 -0
  703. numpy/random/tests/test_random.py +1757 -0
  704. numpy/random/tests/test_randomstate.py +2130 -0
  705. numpy/random/tests/test_randomstate_regression.py +217 -0
  706. numpy/random/tests/test_regression.py +152 -0
  707. numpy/random/tests/test_seed_sequence.py +79 -0
  708. numpy/random/tests/test_smoke.py +819 -0
  709. numpy/rec/__init__.py +2 -0
  710. numpy/rec/__init__.pyi +23 -0
  711. numpy/strings/__init__.py +2 -0
  712. numpy/strings/__init__.pyi +97 -0
  713. numpy/testing/__init__.py +22 -0
  714. numpy/testing/__init__.pyi +102 -0
  715. numpy/testing/_private/__init__.py +0 -0
  716. numpy/testing/_private/__init__.pyi +0 -0
  717. numpy/testing/_private/extbuild.py +250 -0
  718. numpy/testing/_private/extbuild.pyi +25 -0
  719. numpy/testing/_private/utils.py +2752 -0
  720. numpy/testing/_private/utils.pyi +499 -0
  721. numpy/testing/overrides.py +84 -0
  722. numpy/testing/overrides.pyi +11 -0
  723. numpy/testing/print_coercion_tables.py +207 -0
  724. numpy/testing/print_coercion_tables.pyi +27 -0
  725. numpy/testing/tests/__init__.py +0 -0
  726. numpy/testing/tests/test_utils.py +1917 -0
  727. numpy/tests/__init__.py +0 -0
  728. numpy/tests/test__all__.py +10 -0
  729. numpy/tests/test_configtool.py +48 -0
  730. numpy/tests/test_ctypeslib.py +377 -0
  731. numpy/tests/test_lazyloading.py +38 -0
  732. numpy/tests/test_matlib.py +59 -0
  733. numpy/tests/test_numpy_config.py +46 -0
  734. numpy/tests/test_numpy_version.py +54 -0
  735. numpy/tests/test_public_api.py +806 -0
  736. numpy/tests/test_reloading.py +74 -0
  737. numpy/tests/test_scripts.py +49 -0
  738. numpy/tests/test_warnings.py +78 -0
  739. numpy/typing/__init__.py +201 -0
  740. numpy/typing/mypy_plugin.py +195 -0
  741. numpy/typing/tests/__init__.py +0 -0
  742. numpy/typing/tests/data/fail/arithmetic.pyi +126 -0
  743. numpy/typing/tests/data/fail/array_constructors.pyi +34 -0
  744. numpy/typing/tests/data/fail/array_like.pyi +15 -0
  745. numpy/typing/tests/data/fail/array_pad.pyi +6 -0
  746. numpy/typing/tests/data/fail/arrayprint.pyi +16 -0
  747. numpy/typing/tests/data/fail/arrayterator.pyi +14 -0
  748. numpy/typing/tests/data/fail/bitwise_ops.pyi +17 -0
  749. numpy/typing/tests/data/fail/char.pyi +65 -0
  750. numpy/typing/tests/data/fail/chararray.pyi +62 -0
  751. numpy/typing/tests/data/fail/comparisons.pyi +27 -0
  752. numpy/typing/tests/data/fail/constants.pyi +3 -0
  753. numpy/typing/tests/data/fail/datasource.pyi +15 -0
  754. numpy/typing/tests/data/fail/dtype.pyi +17 -0
  755. numpy/typing/tests/data/fail/einsumfunc.pyi +12 -0
  756. numpy/typing/tests/data/fail/flatiter.pyi +20 -0
  757. numpy/typing/tests/data/fail/fromnumeric.pyi +148 -0
  758. numpy/typing/tests/data/fail/histograms.pyi +12 -0
  759. numpy/typing/tests/data/fail/index_tricks.pyi +14 -0
  760. numpy/typing/tests/data/fail/lib_function_base.pyi +62 -0
  761. numpy/typing/tests/data/fail/lib_polynomial.pyi +29 -0
  762. numpy/typing/tests/data/fail/lib_utils.pyi +3 -0
  763. numpy/typing/tests/data/fail/lib_version.pyi +6 -0
  764. numpy/typing/tests/data/fail/linalg.pyi +48 -0
  765. numpy/typing/tests/data/fail/ma.pyi +143 -0
  766. numpy/typing/tests/data/fail/memmap.pyi +5 -0
  767. numpy/typing/tests/data/fail/modules.pyi +17 -0
  768. numpy/typing/tests/data/fail/multiarray.pyi +52 -0
  769. numpy/typing/tests/data/fail/ndarray.pyi +11 -0
  770. numpy/typing/tests/data/fail/ndarray_misc.pyi +36 -0
  771. numpy/typing/tests/data/fail/nditer.pyi +8 -0
  772. numpy/typing/tests/data/fail/nested_sequence.pyi +16 -0
  773. numpy/typing/tests/data/fail/npyio.pyi +24 -0
  774. numpy/typing/tests/data/fail/numerictypes.pyi +5 -0
  775. numpy/typing/tests/data/fail/random.pyi +62 -0
  776. numpy/typing/tests/data/fail/rec.pyi +17 -0
  777. numpy/typing/tests/data/fail/scalars.pyi +87 -0
  778. numpy/typing/tests/data/fail/shape.pyi +6 -0
  779. numpy/typing/tests/data/fail/shape_base.pyi +8 -0
  780. numpy/typing/tests/data/fail/stride_tricks.pyi +9 -0
  781. numpy/typing/tests/data/fail/strings.pyi +52 -0
  782. numpy/typing/tests/data/fail/testing.pyi +28 -0
  783. numpy/typing/tests/data/fail/twodim_base.pyi +32 -0
  784. numpy/typing/tests/data/fail/type_check.pyi +13 -0
  785. numpy/typing/tests/data/fail/ufunc_config.pyi +21 -0
  786. numpy/typing/tests/data/fail/ufunclike.pyi +21 -0
  787. numpy/typing/tests/data/fail/ufuncs.pyi +17 -0
  788. numpy/typing/tests/data/fail/warnings_and_errors.pyi +5 -0
  789. numpy/typing/tests/data/misc/extended_precision.pyi +9 -0
  790. numpy/typing/tests/data/mypy.ini +9 -0
  791. numpy/typing/tests/data/pass/arithmetic.py +612 -0
  792. numpy/typing/tests/data/pass/array_constructors.py +137 -0
  793. numpy/typing/tests/data/pass/array_like.py +43 -0
  794. numpy/typing/tests/data/pass/arrayprint.py +37 -0
  795. numpy/typing/tests/data/pass/arrayterator.py +27 -0
  796. numpy/typing/tests/data/pass/bitwise_ops.py +131 -0
  797. numpy/typing/tests/data/pass/comparisons.py +315 -0
  798. numpy/typing/tests/data/pass/dtype.py +57 -0
  799. numpy/typing/tests/data/pass/einsumfunc.py +36 -0
  800. numpy/typing/tests/data/pass/flatiter.py +19 -0
  801. numpy/typing/tests/data/pass/fromnumeric.py +272 -0
  802. numpy/typing/tests/data/pass/index_tricks.py +60 -0
  803. numpy/typing/tests/data/pass/lib_user_array.py +22 -0
  804. numpy/typing/tests/data/pass/lib_utils.py +19 -0
  805. numpy/typing/tests/data/pass/lib_version.py +18 -0
  806. numpy/typing/tests/data/pass/literal.py +51 -0
  807. numpy/typing/tests/data/pass/ma.py +174 -0
  808. numpy/typing/tests/data/pass/mod.py +149 -0
  809. numpy/typing/tests/data/pass/modules.py +45 -0
  810. numpy/typing/tests/data/pass/multiarray.py +76 -0
  811. numpy/typing/tests/data/pass/ndarray_conversion.py +87 -0
  812. numpy/typing/tests/data/pass/ndarray_misc.py +203 -0
  813. numpy/typing/tests/data/pass/ndarray_shape_manipulation.py +47 -0
  814. numpy/typing/tests/data/pass/nditer.py +4 -0
  815. numpy/typing/tests/data/pass/numeric.py +95 -0
  816. numpy/typing/tests/data/pass/numerictypes.py +17 -0
  817. numpy/typing/tests/data/pass/random.py +1497 -0
  818. numpy/typing/tests/data/pass/recfunctions.py +161 -0
  819. numpy/typing/tests/data/pass/scalars.py +248 -0
  820. numpy/typing/tests/data/pass/shape.py +19 -0
  821. numpy/typing/tests/data/pass/simple.py +168 -0
  822. numpy/typing/tests/data/pass/simple_py3.py +6 -0
  823. numpy/typing/tests/data/pass/ufunc_config.py +64 -0
  824. numpy/typing/tests/data/pass/ufunclike.py +47 -0
  825. numpy/typing/tests/data/pass/ufuncs.py +16 -0
  826. numpy/typing/tests/data/pass/warnings_and_errors.py +6 -0
  827. numpy/typing/tests/data/reveal/arithmetic.pyi +720 -0
  828. numpy/typing/tests/data/reveal/array_api_info.pyi +70 -0
  829. numpy/typing/tests/data/reveal/array_constructors.pyi +249 -0
  830. numpy/typing/tests/data/reveal/arraypad.pyi +22 -0
  831. numpy/typing/tests/data/reveal/arrayprint.pyi +25 -0
  832. numpy/typing/tests/data/reveal/arraysetops.pyi +74 -0
  833. numpy/typing/tests/data/reveal/arrayterator.pyi +27 -0
  834. numpy/typing/tests/data/reveal/bitwise_ops.pyi +167 -0
  835. numpy/typing/tests/data/reveal/char.pyi +224 -0
  836. numpy/typing/tests/data/reveal/chararray.pyi +137 -0
  837. numpy/typing/tests/data/reveal/comparisons.pyi +264 -0
  838. numpy/typing/tests/data/reveal/constants.pyi +14 -0
  839. numpy/typing/tests/data/reveal/ctypeslib.pyi +81 -0
  840. numpy/typing/tests/data/reveal/datasource.pyi +23 -0
  841. numpy/typing/tests/data/reveal/dtype.pyi +136 -0
  842. numpy/typing/tests/data/reveal/einsumfunc.pyi +39 -0
  843. numpy/typing/tests/data/reveal/emath.pyi +54 -0
  844. numpy/typing/tests/data/reveal/fft.pyi +37 -0
  845. numpy/typing/tests/data/reveal/flatiter.pyi +47 -0
  846. numpy/typing/tests/data/reveal/fromnumeric.pyi +347 -0
  847. numpy/typing/tests/data/reveal/getlimits.pyi +51 -0
  848. numpy/typing/tests/data/reveal/histograms.pyi +25 -0
  849. numpy/typing/tests/data/reveal/index_tricks.pyi +70 -0
  850. numpy/typing/tests/data/reveal/lib_function_base.pyi +213 -0
  851. numpy/typing/tests/data/reveal/lib_polynomial.pyi +144 -0
  852. numpy/typing/tests/data/reveal/lib_utils.pyi +17 -0
  853. numpy/typing/tests/data/reveal/lib_version.pyi +20 -0
  854. numpy/typing/tests/data/reveal/linalg.pyi +132 -0
  855. numpy/typing/tests/data/reveal/ma.pyi +369 -0
  856. numpy/typing/tests/data/reveal/matrix.pyi +73 -0
  857. numpy/typing/tests/data/reveal/memmap.pyi +19 -0
  858. numpy/typing/tests/data/reveal/mod.pyi +179 -0
  859. numpy/typing/tests/data/reveal/modules.pyi +51 -0
  860. numpy/typing/tests/data/reveal/multiarray.pyi +194 -0
  861. numpy/typing/tests/data/reveal/nbit_base_example.pyi +21 -0
  862. numpy/typing/tests/data/reveal/ndarray_assignability.pyi +77 -0
  863. numpy/typing/tests/data/reveal/ndarray_conversion.pyi +85 -0
  864. numpy/typing/tests/data/reveal/ndarray_misc.pyi +247 -0
  865. numpy/typing/tests/data/reveal/ndarray_shape_manipulation.pyi +39 -0
  866. numpy/typing/tests/data/reveal/nditer.pyi +49 -0
  867. numpy/typing/tests/data/reveal/nested_sequence.pyi +25 -0
  868. numpy/typing/tests/data/reveal/npyio.pyi +83 -0
  869. numpy/typing/tests/data/reveal/numeric.pyi +134 -0
  870. numpy/typing/tests/data/reveal/numerictypes.pyi +16 -0
  871. numpy/typing/tests/data/reveal/polynomial_polybase.pyi +220 -0
  872. numpy/typing/tests/data/reveal/polynomial_polyutils.pyi +219 -0
  873. numpy/typing/tests/data/reveal/polynomial_series.pyi +138 -0
  874. numpy/typing/tests/data/reveal/random.pyi +1546 -0
  875. numpy/typing/tests/data/reveal/rec.pyi +171 -0
  876. numpy/typing/tests/data/reveal/scalars.pyi +191 -0
  877. numpy/typing/tests/data/reveal/shape.pyi +13 -0
  878. numpy/typing/tests/data/reveal/shape_base.pyi +52 -0
  879. numpy/typing/tests/data/reveal/stride_tricks.pyi +27 -0
  880. numpy/typing/tests/data/reveal/strings.pyi +196 -0
  881. numpy/typing/tests/data/reveal/testing.pyi +198 -0
  882. numpy/typing/tests/data/reveal/twodim_base.pyi +145 -0
  883. numpy/typing/tests/data/reveal/type_check.pyi +67 -0
  884. numpy/typing/tests/data/reveal/ufunc_config.pyi +30 -0
  885. numpy/typing/tests/data/reveal/ufunclike.pyi +31 -0
  886. numpy/typing/tests/data/reveal/ufuncs.pyi +123 -0
  887. numpy/typing/tests/data/reveal/warnings_and_errors.pyi +11 -0
  888. numpy/typing/tests/test_isfile.py +32 -0
  889. numpy/typing/tests/test_runtime.py +102 -0
  890. numpy/typing/tests/test_typing.py +205 -0
  891. numpy/version.py +11 -0
  892. numpy/version.pyi +18 -0
  893. numpy-2.3.5.dist-info/LICENSE.txt +971 -0
  894. numpy-2.3.5.dist-info/METADATA +1093 -0
  895. numpy-2.3.5.dist-info/RECORD +897 -0
  896. numpy-2.3.5.dist-info/WHEEL +6 -0
  897. numpy-2.3.5.dist-info/entry_points.txt +13 -0
@@ -0,0 +1,1498 @@
1
+ """
2
+ Implementation of optimized einsum.
3
+
4
+ """
5
+ import itertools
6
+ import operator
7
+
8
+ from numpy._core.multiarray import c_einsum
9
+ from numpy._core.numeric import asanyarray, tensordot
10
+ from numpy._core.overrides import array_function_dispatch
11
+
12
+ __all__ = ['einsum', 'einsum_path']
13
+
14
+ # importing string for string.ascii_letters would be too slow
15
+ # the first import before caching has been measured to take 800 µs (#23777)
16
+ # imports begin with uppercase to mimic ASCII values to avoid sorting issues
17
+ einsum_symbols = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
18
+ einsum_symbols_set = set(einsum_symbols)
19
+
20
+
21
+ def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
22
+ """
23
+ Computes the number of FLOPS in the contraction.
24
+
25
+ Parameters
26
+ ----------
27
+ idx_contraction : iterable
28
+ The indices involved in the contraction
29
+ inner : bool
30
+ Does this contraction require an inner product?
31
+ num_terms : int
32
+ The number of terms in a contraction
33
+ size_dictionary : dict
34
+ The size of each of the indices in idx_contraction
35
+
36
+ Returns
37
+ -------
38
+ flop_count : int
39
+ The total number of FLOPS required for the contraction.
40
+
41
+ Examples
42
+ --------
43
+
44
+ >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
45
+ 30
46
+
47
+ >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
48
+ 60
49
+
50
+ """
51
+
52
+ overall_size = _compute_size_by_dict(idx_contraction, size_dictionary)
53
+ op_factor = max(1, num_terms - 1)
54
+ if inner:
55
+ op_factor += 1
56
+
57
+ return overall_size * op_factor
58
+
59
+ def _compute_size_by_dict(indices, idx_dict):
60
+ """
61
+ Computes the product of the elements in indices based on the dictionary
62
+ idx_dict.
63
+
64
+ Parameters
65
+ ----------
66
+ indices : iterable
67
+ Indices to base the product on.
68
+ idx_dict : dictionary
69
+ Dictionary of index sizes
70
+
71
+ Returns
72
+ -------
73
+ ret : int
74
+ The resulting product.
75
+
76
+ Examples
77
+ --------
78
+ >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
79
+ 90
80
+
81
+ """
82
+ ret = 1
83
+ for i in indices:
84
+ ret *= idx_dict[i]
85
+ return ret
86
+
87
+
88
+ def _find_contraction(positions, input_sets, output_set):
89
+ """
90
+ Finds the contraction for a given set of input and output sets.
91
+
92
+ Parameters
93
+ ----------
94
+ positions : iterable
95
+ Integer positions of terms used in the contraction.
96
+ input_sets : list
97
+ List of sets that represent the lhs side of the einsum subscript
98
+ output_set : set
99
+ Set that represents the rhs side of the overall einsum subscript
100
+
101
+ Returns
102
+ -------
103
+ new_result : set
104
+ The indices of the resulting contraction
105
+ remaining : list
106
+ List of sets that have not been contracted, the new set is appended to
107
+ the end of this list
108
+ idx_removed : set
109
+ Indices removed from the entire contraction
110
+ idx_contraction : set
111
+ The indices used in the current contraction
112
+
113
+ Examples
114
+ --------
115
+
116
+ # A simple dot product test case
117
+ >>> pos = (0, 1)
118
+ >>> isets = [set('ab'), set('bc')]
119
+ >>> oset = set('ac')
120
+ >>> _find_contraction(pos, isets, oset)
121
+ ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})
122
+
123
+ # A more complex case with additional terms in the contraction
124
+ >>> pos = (0, 2)
125
+ >>> isets = [set('abd'), set('ac'), set('bdc')]
126
+ >>> oset = set('ac')
127
+ >>> _find_contraction(pos, isets, oset)
128
+ ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
129
+ """
130
+
131
+ idx_contract = set()
132
+ idx_remain = output_set.copy()
133
+ remaining = []
134
+ for ind, value in enumerate(input_sets):
135
+ if ind in positions:
136
+ idx_contract |= value
137
+ else:
138
+ remaining.append(value)
139
+ idx_remain |= value
140
+
141
+ new_result = idx_remain & idx_contract
142
+ idx_removed = (idx_contract - new_result)
143
+ remaining.append(new_result)
144
+
145
+ return (new_result, remaining, idx_removed, idx_contract)
146
+
147
+
148
+ def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
149
+ """
150
+ Computes all possible pair contractions, sieves the results based
151
+ on ``memory_limit`` and returns the lowest cost path. This algorithm
152
+ scales factorial with respect to the elements in the list ``input_sets``.
153
+
154
+ Parameters
155
+ ----------
156
+ input_sets : list
157
+ List of sets that represent the lhs side of the einsum subscript
158
+ output_set : set
159
+ Set that represents the rhs side of the overall einsum subscript
160
+ idx_dict : dictionary
161
+ Dictionary of index sizes
162
+ memory_limit : int
163
+ The maximum number of elements in a temporary array
164
+
165
+ Returns
166
+ -------
167
+ path : list
168
+ The optimal contraction order within the memory limit constraint.
169
+
170
+ Examples
171
+ --------
172
+ >>> isets = [set('abd'), set('ac'), set('bdc')]
173
+ >>> oset = set()
174
+ >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
175
+ >>> _optimal_path(isets, oset, idx_sizes, 5000)
176
+ [(0, 2), (0, 1)]
177
+ """
178
+
179
+ full_results = [(0, [], input_sets)]
180
+ for iteration in range(len(input_sets) - 1):
181
+ iter_results = []
182
+
183
+ # Compute all unique pairs
184
+ for curr in full_results:
185
+ cost, positions, remaining = curr
186
+ for con in itertools.combinations(
187
+ range(len(input_sets) - iteration), 2
188
+ ):
189
+
190
+ # Find the contraction
191
+ cont = _find_contraction(con, remaining, output_set)
192
+ new_result, new_input_sets, idx_removed, idx_contract = cont
193
+
194
+ # Sieve the results based on memory_limit
195
+ new_size = _compute_size_by_dict(new_result, idx_dict)
196
+ if new_size > memory_limit:
197
+ continue
198
+
199
+ # Build (total_cost, positions, indices_remaining)
200
+ total_cost = cost + _flop_count(
201
+ idx_contract, idx_removed, len(con), idx_dict
202
+ )
203
+ new_pos = positions + [con]
204
+ iter_results.append((total_cost, new_pos, new_input_sets))
205
+
206
+ # Update combinatorial list, if we did not find anything return best
207
+ # path + remaining contractions
208
+ if iter_results:
209
+ full_results = iter_results
210
+ else:
211
+ path = min(full_results, key=lambda x: x[0])[1]
212
+ path += [tuple(range(len(input_sets) - iteration))]
213
+ return path
214
+
215
+ # If we have not found anything return single einsum contraction
216
+ if len(full_results) == 0:
217
+ return [tuple(range(len(input_sets)))]
218
+
219
+ path = min(full_results, key=lambda x: x[0])[1]
220
+ return path
221
+
222
+ def _parse_possible_contraction(
223
+ positions, input_sets, output_set, idx_dict,
224
+ memory_limit, path_cost, naive_cost
225
+ ):
226
+ """Compute the cost (removed size + flops) and resultant indices for
227
+ performing the contraction specified by ``positions``.
228
+
229
+ Parameters
230
+ ----------
231
+ positions : tuple of int
232
+ The locations of the proposed tensors to contract.
233
+ input_sets : list of sets
234
+ The indices found on each tensors.
235
+ output_set : set
236
+ The output indices of the expression.
237
+ idx_dict : dict
238
+ Mapping of each index to its size.
239
+ memory_limit : int
240
+ The total allowed size for an intermediary tensor.
241
+ path_cost : int
242
+ The contraction cost so far.
243
+ naive_cost : int
244
+ The cost of the unoptimized expression.
245
+
246
+ Returns
247
+ -------
248
+ cost : (int, int)
249
+ A tuple containing the size of any indices removed, and the flop cost.
250
+ positions : tuple of int
251
+ The locations of the proposed tensors to contract.
252
+ new_input_sets : list of sets
253
+ The resulting new list of indices if this proposed contraction
254
+ is performed.
255
+
256
+ """
257
+
258
+ # Find the contraction
259
+ contract = _find_contraction(positions, input_sets, output_set)
260
+ idx_result, new_input_sets, idx_removed, idx_contract = contract
261
+
262
+ # Sieve the results based on memory_limit
263
+ new_size = _compute_size_by_dict(idx_result, idx_dict)
264
+ if new_size > memory_limit:
265
+ return None
266
+
267
+ # Build sort tuple
268
+ old_sizes = (
269
+ _compute_size_by_dict(input_sets[p], idx_dict) for p in positions
270
+ )
271
+ removed_size = sum(old_sizes) - new_size
272
+
273
+ # NB: removed_size used to be just the size of any removed indices i.e.:
274
+ # helpers.compute_size_by_dict(idx_removed, idx_dict)
275
+ cost = _flop_count(idx_contract, idx_removed, len(positions), idx_dict)
276
+ sort = (-removed_size, cost)
277
+
278
+ # Sieve based on total cost as well
279
+ if (path_cost + cost) > naive_cost:
280
+ return None
281
+
282
+ # Add contraction to possible choices
283
+ return [sort, positions, new_input_sets]
284
+
285
+
286
+ def _update_other_results(results, best):
287
+ """Update the positions and provisional input_sets of ``results``
288
+ based on performing the contraction result ``best``. Remove any
289
+ involving the tensors contracted.
290
+
291
+ Parameters
292
+ ----------
293
+ results : list
294
+ List of contraction results produced by
295
+ ``_parse_possible_contraction``.
296
+ best : list
297
+ The best contraction of ``results`` i.e. the one that
298
+ will be performed.
299
+
300
+ Returns
301
+ -------
302
+ mod_results : list
303
+ The list of modified results, updated with outcome of
304
+ ``best`` contraction.
305
+ """
306
+
307
+ best_con = best[1]
308
+ bx, by = best_con
309
+ mod_results = []
310
+
311
+ for cost, (x, y), con_sets in results:
312
+
313
+ # Ignore results involving tensors just contracted
314
+ if x in best_con or y in best_con:
315
+ continue
316
+
317
+ # Update the input_sets
318
+ del con_sets[by - int(by > x) - int(by > y)]
319
+ del con_sets[bx - int(bx > x) - int(bx > y)]
320
+ con_sets.insert(-1, best[2][-1])
321
+
322
+ # Update the position indices
323
+ mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by)
324
+ mod_results.append((cost, mod_con, con_sets))
325
+
326
+ return mod_results
327
+
328
+ def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
329
+ """
330
+ Finds the path by contracting the best pair until the input list is
331
+ exhausted. The best pair is found by minimizing the tuple
332
+ ``(-prod(indices_removed), cost)``. What this amounts to is prioritizing
333
+ matrix multiplication or inner product operations, then Hadamard like
334
+ operations, and finally outer operations. Outer products are limited by
335
+ ``memory_limit``. This algorithm scales cubically with respect to the
336
+ number of elements in the list ``input_sets``.
337
+
338
+ Parameters
339
+ ----------
340
+ input_sets : list
341
+ List of sets that represent the lhs side of the einsum subscript
342
+ output_set : set
343
+ Set that represents the rhs side of the overall einsum subscript
344
+ idx_dict : dictionary
345
+ Dictionary of index sizes
346
+ memory_limit : int
347
+ The maximum number of elements in a temporary array
348
+
349
+ Returns
350
+ -------
351
+ path : list
352
+ The greedy contraction order within the memory limit constraint.
353
+
354
+ Examples
355
+ --------
356
+ >>> isets = [set('abd'), set('ac'), set('bdc')]
357
+ >>> oset = set()
358
+ >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
359
+ >>> _greedy_path(isets, oset, idx_sizes, 5000)
360
+ [(0, 2), (0, 1)]
361
+ """
362
+
363
+ # Handle trivial cases that leaked through
364
+ if len(input_sets) == 1:
365
+ return [(0,)]
366
+ elif len(input_sets) == 2:
367
+ return [(0, 1)]
368
+
369
+ # Build up a naive cost
370
+ contract = _find_contraction(
371
+ range(len(input_sets)), input_sets, output_set
372
+ )
373
+ idx_result, new_input_sets, idx_removed, idx_contract = contract
374
+ naive_cost = _flop_count(
375
+ idx_contract, idx_removed, len(input_sets), idx_dict
376
+ )
377
+
378
+ # Initially iterate over all pairs
379
+ comb_iter = itertools.combinations(range(len(input_sets)), 2)
380
+ known_contractions = []
381
+
382
+ path_cost = 0
383
+ path = []
384
+
385
+ for iteration in range(len(input_sets) - 1):
386
+
387
+ # Iterate over all pairs on the first step, only previously
388
+ # found pairs on subsequent steps
389
+ for positions in comb_iter:
390
+
391
+ # Always initially ignore outer products
392
+ if input_sets[positions[0]].isdisjoint(input_sets[positions[1]]):
393
+ continue
394
+
395
+ result = _parse_possible_contraction(
396
+ positions, input_sets, output_set, idx_dict,
397
+ memory_limit, path_cost, naive_cost
398
+ )
399
+ if result is not None:
400
+ known_contractions.append(result)
401
+
402
+ # If we do not have a inner contraction, rescan pairs
403
+ # including outer products
404
+ if len(known_contractions) == 0:
405
+
406
+ # Then check the outer products
407
+ for positions in itertools.combinations(
408
+ range(len(input_sets)), 2
409
+ ):
410
+ result = _parse_possible_contraction(
411
+ positions, input_sets, output_set, idx_dict,
412
+ memory_limit, path_cost, naive_cost
413
+ )
414
+ if result is not None:
415
+ known_contractions.append(result)
416
+
417
+ # If we still did not find any remaining contractions,
418
+ # default back to einsum like behavior
419
+ if len(known_contractions) == 0:
420
+ path.append(tuple(range(len(input_sets))))
421
+ break
422
+
423
+ # Sort based on first index
424
+ best = min(known_contractions, key=lambda x: x[0])
425
+
426
+ # Now propagate as many unused contractions as possible
427
+ # to the next iteration
428
+ known_contractions = _update_other_results(known_contractions, best)
429
+
430
+ # Next iteration only compute contractions with the new tensor
431
+ # All other contractions have been accounted for
432
+ input_sets = best[2]
433
+ new_tensor_pos = len(input_sets) - 1
434
+ comb_iter = ((i, new_tensor_pos) for i in range(new_tensor_pos))
435
+
436
+ # Update path and total cost
437
+ path.append(best[1])
438
+ path_cost += best[0][1]
439
+
440
+ return path
441
+
442
+
443
+ def _can_dot(inputs, result, idx_removed):
444
+ """
445
+ Checks if we can use BLAS (np.tensordot) call and its beneficial to do so.
446
+
447
+ Parameters
448
+ ----------
449
+ inputs : list of str
450
+ Specifies the subscripts for summation.
451
+ result : str
452
+ Resulting summation.
453
+ idx_removed : set
454
+ Indices that are removed in the summation
455
+
456
+
457
+ Returns
458
+ -------
459
+ type : bool
460
+ Returns true if BLAS should and can be used, else False
461
+
462
+ Notes
463
+ -----
464
+ If the operations is BLAS level 1 or 2 and is not already aligned
465
+ we default back to einsum as the memory movement to copy is more
466
+ costly than the operation itself.
467
+
468
+
469
+ Examples
470
+ --------
471
+
472
+ # Standard GEMM operation
473
+ >>> _can_dot(['ij', 'jk'], 'ik', set('j'))
474
+ True
475
+
476
+ # Can use the standard BLAS, but requires odd data movement
477
+ >>> _can_dot(['ijj', 'jk'], 'ik', set('j'))
478
+ False
479
+
480
+ # DDOT where the memory is not aligned
481
+ >>> _can_dot(['ijk', 'ikj'], '', set('ijk'))
482
+ False
483
+
484
+ """
485
+
486
+ # All `dot` calls remove indices
487
+ if len(idx_removed) == 0:
488
+ return False
489
+
490
+ # BLAS can only handle two operands
491
+ if len(inputs) != 2:
492
+ return False
493
+
494
+ input_left, input_right = inputs
495
+
496
+ for c in set(input_left + input_right):
497
+ # can't deal with repeated indices on same input or more than 2 total
498
+ nl, nr = input_left.count(c), input_right.count(c)
499
+ if (nl > 1) or (nr > 1) or (nl + nr > 2):
500
+ return False
501
+
502
+ # can't do implicit summation or dimension collapse e.g.
503
+ # "ab,bc->c" (implicitly sum over 'a')
504
+ # "ab,ca->ca" (take diagonal of 'a')
505
+ if nl + nr - 1 == int(c in result):
506
+ return False
507
+
508
+ # Build a few temporaries
509
+ set_left = set(input_left)
510
+ set_right = set(input_right)
511
+ keep_left = set_left - idx_removed
512
+ keep_right = set_right - idx_removed
513
+ rs = len(idx_removed)
514
+
515
+ # At this point we are a DOT, GEMV, or GEMM operation
516
+
517
+ # Handle inner products
518
+
519
+ # DDOT with aligned data
520
+ if input_left == input_right:
521
+ return True
522
+
523
+ # DDOT without aligned data (better to use einsum)
524
+ if set_left == set_right:
525
+ return False
526
+
527
+ # Handle the 4 possible (aligned) GEMV or GEMM cases
528
+
529
+ # GEMM or GEMV no transpose
530
+ if input_left[-rs:] == input_right[:rs]:
531
+ return True
532
+
533
+ # GEMM or GEMV transpose both
534
+ if input_left[:rs] == input_right[-rs:]:
535
+ return True
536
+
537
+ # GEMM or GEMV transpose right
538
+ if input_left[-rs:] == input_right[-rs:]:
539
+ return True
540
+
541
+ # GEMM or GEMV transpose left
542
+ if input_left[:rs] == input_right[:rs]:
543
+ return True
544
+
545
+ # Einsum is faster than GEMV if we have to copy data
546
+ if not keep_left or not keep_right:
547
+ return False
548
+
549
+ # We are a matrix-matrix product, but we need to copy data
550
+ return True
551
+
552
+
553
def _sublist_to_subscripts(sublist):
    # Convert one sublist of the interleaved einsum call form
    # (``op0, sublist0, op1, sublist1, ..., [sublistout]``) into a string of
    # subscript characters, expanding ``Ellipsis`` into "...".
    parts = []
    for s in sublist:
        if s is Ellipsis:
            parts.append("...")
            continue
        try:
            s = operator.index(s)
        except TypeError as e:
            raise TypeError(
                "For this input type lists must contain "
                "either int or Ellipsis"
            ) from e
        # Reject labels outside the symbol table: a negative index would
        # otherwise wrap around and silently select a symbol from the end of
        # ``einsum_symbols``, and an index past the end would surface as a
        # bare IndexError.
        if not 0 <= s < len(einsum_symbols):
            raise ValueError(
                "subscript is not within the valid range "
                f"[0, {len(einsum_symbols)})"
            )
        parts.append(einsum_symbols[s])
    return "".join(parts)


def _implicit_output_subscript(subscripts):
    # Build the implicit-mode output: every label that occurs exactly once
    # across the comma-separated input terms, in sorted order. Also validates
    # that every label is a known einsum symbol.
    labels = subscripts.replace(",", "")
    output = ""
    for s in sorted(set(labels)):
        if s not in einsum_symbols:
            raise ValueError(f"Character {s} is not a valid symbol.")
        if labels.count(s) == 1:
            output += s
    return output


def _parse_einsum_input(operands):
    """
    A reproduction of einsum c side einsum parsing in python.

    Two call forms are accepted:

    * ``(subscripts, op0, op1, ...)`` where ``subscripts`` is a string, and
    * ``(op0, sublist0, op1, sublist1, ..., [sublistout])`` where each
      sublist holds non-negative integer labels and/or ``Ellipsis``.

    Ellipses are replaced by explicit, otherwise unused labels, and in
    implicit mode (no ``->``) the output subscripts are constructed.

    Parameters
    ----------
    operands : sequence
        The positional arguments given to ``einsum``/``einsum_path``.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Raises
    ------
    ValueError
        If no operands are given, a label is invalid or out of range, ``->``
        or an ellipsis is malformed, the number of subscript terms differs
        from the number of operands, or an output label is repeated or
        missing from the inputs.
    TypeError
        If a sublist contains something other than an integer or Ellipsis.

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> np.random.seed(123)
    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> _parse_einsum_input(('...a,...a->...', a, b))
    ('za,yza', 'yz', [a, b]) # may vary

    >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('za,yza', 'yz', [a, b]) # may vary
    """

    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        subscripts = operands[0].replace(" ", "")
        operands = [asanyarray(v) for v in operands[1:]]

        # Ensure all characters are valid
        for s in subscripts:
            if s in '.,->':
                continue
            if s not in einsum_symbols:
                raise ValueError(f"Character {s} is not a valid symbol.")

    else:
        # Interleaved form: op0, sublist0, op1, sublist1, ..., [sublistout]
        n_pairs = len(operands) // 2
        operand_list = operands[:2 * n_pairs:2]
        subscript_list = operands[1:2 * n_pairs:2]
        # A trailing unpaired argument is the explicit output sublist.
        output_list = operands[-1] if len(operands) % 2 else None

        operands = [asanyarray(v) for v in operand_list]
        subscripts = ",".join(
            _sublist_to_subscripts(sub) for sub in subscript_list
        )
        if output_list is not None:
            subscripts += "->" + _sublist_to_subscripts(output_list)

    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Make sure number operands is equivalent to the number of terms. This is
    # checked before ellipsis expansion, which indexes ``operands`` per term.
    if len(subscripts.split("->")[0].split(",")) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the "
                         "number of operands.")

    # Parse ellipses
    if "." in subscripts:
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        # Sorted so the labels substituted for "..." do not depend on set
        # iteration order (str hashing is randomized per interpreter run).
        ellipse_inds = "".join(sorted(einsum_symbols_set - set(used)))
        longest = 0

        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(',')
            out_sub = False

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values
                if operands[num].shape == ():
                    ellipse_count = 0
                else:
                    ellipse_count = max(operands[num].ndim, 1)
                    ellipse_count -= (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace('...', '')
                else:
                    rep_inds = ellipse_inds[-ellipse_count:]
                    split_subscripts[num] = sub.replace('...', rep_inds)

        subscripts = ",".join(split_subscripts)
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses: broadcast labels first,
            # then the remaining once-occurring labels in sorted order.
            output_subscript = _implicit_output_subscript(subscripts)
            normal_inds = ''.join(sorted(set(output_subscript) -
                                         set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts = subscripts
        output_subscript = _implicit_output_subscript(subscripts)

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if output_subscript.count(char) != 1:
            raise ValueError(f"Output character {char} appeared more than "
                             "once in the output.")
        if char not in input_subscripts:
            raise ValueError(f"Output character {char} did not appear in the input")

    return (input_subscripts, output_subscript, operands)
730
+
731
+
732
+ def _einsum_path_dispatcher(*operands, optimize=None, einsum_call=None):
733
+ # NOTE: technically, we should only dispatch on array-like arguments, not
734
+ # subscripts (given as strings). But separating operands into
735
+ # arrays/subscripts is a little tricky/slow (given einsum's two supported
736
+ # signatures), so as a practical shortcut we dispatch on everything.
737
+ # Strings will be ignored for dispatching since they don't define
738
+ # __array_function__.
739
+ return operands
740
+
741
+
742
@array_function_dispatch(_einsum_path_dispatcher, module='numpy')
def einsum_path(*operands, optimize='greedy', einsum_call=False):
    """
    einsum_path(subscripts, *operands, optimize='greedy')

    Evaluates the lowest cost contraction order for an einsum expression by
    considering the creation of intermediate arrays.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation.
    *operands : list of array_like
        These are the arrays for the operation.
    optimize : {bool, list, tuple, 'greedy', 'optimal'}
        Choose the type of path. If a tuple is provided, the second argument is
        assumed to be the maximum intermediate size created. If only a single
        argument is provided the largest input or output array size is used
        as a maximum intermediate size.

        * if a list (or tuple) is given that starts with ``einsum_path``, its
          remaining entries are used as the contraction path
        * if False (or None) no optimization is taken
        * if True defaults to the 'greedy' algorithm
        * 'optimal' An algorithm that combinatorially explores all possible
          ways of contracting the listed tensors and chooses the least costly
          path. Scales exponentially with the number of terms in the
          contraction.
        * 'greedy' An algorithm that chooses the best pair contraction
          at each step. Effectively, this algorithm searches the largest inner,
          Hadamard, and then outer products at each step. Scales cubically with
          the number of terms in the contraction. Equivalent to the 'optimal'
          path for most contractions.

        Default is 'greedy'.

    Returns
    -------
    path : list
        A list representation of the einsum path: the string
        ``'einsum_path'`` followed by one tuple of operand positions per
        contraction.
    string_repr : str
        A printable representation of the einsum path.

    Raises
    ------
    TypeError
        If ``optimize`` is not one of the accepted forms.
    KeyError
        If a path algorithm other than 'greedy' or 'optimal' is named and a
        path search is actually required.
    ValueError
        If the subscripts are malformed or an operand's shape does not match
        its subscripts.
    RuntimeError
        If an explicit path leaves more than one operand uncontracted.

    Notes
    -----
    The resulting path indicates which terms of the input contraction should be
    contracted first, the result of this contraction is then appended to the
    end of the contraction list. This list can then be iterated over until all
    intermediate contractions are complete.

    See Also
    --------
    einsum, linalg.multi_dot

    Examples
    --------

    We can begin with a chain dot example. In this case, it is optimal to
    contract the ``b`` and ``c`` tensors first as represented by the first
    element of the path ``(1, 2)``. The resulting tensor is added to the end
    of the contraction and the remaining contraction ``(0, 1)`` is then
    completed.

    >>> np.random.seed(123)
    >>> a = np.random.rand(2, 2)
    >>> b = np.random.rand(2, 5)
    >>> c = np.random.rand(5, 2)
    >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
    >>> print(path_info[0])
    ['einsum_path', (1, 2), (0, 1)]
    >>> print(path_info[1])
      Complete contraction:  ij,jk,kl->il # may vary
             Naive scaling:  4
         Optimized scaling:  3
          Naive FLOP count:  1.600e+02
      Optimized FLOP count:  5.600e+01
       Theoretical speedup:  2.857
      Largest intermediate:  4.000e+00 elements
    --------------------------------------------------------------------------
    scaling                  current                                remaining
    --------------------------------------------------------------------------
       3                kl,jk->jl                                ij,jl->il
       3                jl,ij->il                                   il->il


    A more complex index transformation example.

    >>> I = np.random.rand(10, 10, 10, 10)
    >>> C = np.random.rand(10, 10)
    >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C,
    ...                            optimize='greedy')

    >>> print(path_info[0])
    ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
    >>> print(path_info[1])
      Complete contraction:  ea,fb,abcd,gc,hd->efgh # may vary
             Naive scaling:  8
         Optimized scaling:  5
          Naive FLOP count:  8.000e+08
      Optimized FLOP count:  8.000e+05
       Theoretical speedup:  1000.000
      Largest intermediate:  1.000e+04 elements
    --------------------------------------------------------------------------
    scaling                  current                                remaining
    --------------------------------------------------------------------------
       5            abcd,ea->bcde                      fb,gc,hd,bcde->efgh
       5            bcde,fb->cdef                         gc,hd,cdef->efgh
       5            cdef,gc->defg                            hd,defg->efgh
       5            defg,hd->efgh                               efgh->efgh
    """

    # Figure out what the path really is
    path_type = optimize
    if path_type is True:
        path_type = 'greedy'
    if path_type is None:
        path_type = False

    explicit_einsum_path = False
    memory_limit = None

    # No optimization or a named path algorithm
    if (path_type is False) or isinstance(path_type, str):
        pass

    # Given an explicit path
    elif len(path_type) and (path_type[0] == 'einsum_path'):
        explicit_einsum_path = True

    # Path tuple with memory limit
    elif ((len(path_type) == 2) and isinstance(path_type[0], str) and
            isinstance(path_type[1], (int, float))):
        memory_limit = int(path_type[1])
        path_type = path_type[0]

    else:
        raise TypeError(f"Did not understand the path: {path_type}")

    # Python side parsing
    input_subscripts, output_subscript, operands = (
        _parse_einsum_input(operands)
    )

    # Build a few useful list and sets
    input_list = input_subscripts.split(',')
    input_sets = [set(x) for x in input_list]
    output_set = set(output_subscript)
    indices = set(input_subscripts.replace(',', ''))

    # Get length of each unique dimension and ensure all dimensions are correct
    dimension_dict = {}
    broadcast_indices = [[] for _ in input_list]
    for tnum, term in enumerate(input_list):
        sh = operands[tnum].shape
        if len(sh) != len(term):
            raise ValueError("Einstein sum subscript %s does not contain the "
                             "correct number of indices for operand %d."
                             % (term, tnum))
        for char, dim in zip(term, sh):
            # Size-1 axes may broadcast; remember them so BLAS is skipped for
            # contractions over such labels below.
            if dim == 1:
                broadcast_indices[tnum].append(char)

            if char in dimension_dict:
                # For broadcasting cases we always want the largest dim size
                if dimension_dict[char] == 1:
                    dimension_dict[char] = dim
                elif dim not in (1, dimension_dict[char]):
                    raise ValueError("Size of label '%s' for operand %d (%d) "
                                     "does not match previous terms (%d)."
                                     % (char, tnum, dim, dimension_dict[char]))
            else:
                dimension_dict[char] = dim

    # Convert broadcast inds to sets
    broadcast_indices = [set(x) for x in broadcast_indices]

    # Largest input or output array; the default memory limit for path search
    max_size = max(_compute_size_by_dict(term, dimension_dict)
                   for term in input_list + [output_subscript])
    memory_arg = max_size if memory_limit is None else memory_limit

    # Compute naive cost
    # This isn't quite right, need to look into exactly how einsum does this
    inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0
    naive_cost = _flop_count(
        indices, inner_product, len(input_list), dimension_dict
    )

    # Compute the path
    if explicit_einsum_path:
        # list() so a tuple such as ('einsum_path', (0, 1)) can still be
        # prefixed with 'einsum_path' when the path is returned below.
        path = list(path_type[1:])
    elif (
        (path_type is False)
        or (len(input_list) in [1, 2])
        or (indices == output_set)
    ):
        # Nothing to be optimized, leave it to einsum
        path = [tuple(range(len(input_list)))]
    elif path_type == "greedy":
        path = _greedy_path(
            input_sets, output_set, dimension_dict, memory_arg
        )
    elif path_type == "optimal":
        path = _optimal_path(
            input_sets, output_set, dimension_dict, memory_arg
        )
    else:
        raise KeyError(f"Path name {path_type} not found")

    # Per-contraction cost, scaling, and intermediate size
    cost_list, scale_list, size_list, contraction_list = [], [], [], []

    # Build contraction tuples:
    # (positions, indices removed, einsum_str, remaining terms, use BLAS)
    for cnum, contract_inds in enumerate(path):
        # Make sure we remove inds from right to left
        contract_inds = tuple(sorted(contract_inds, reverse=True))

        contract = _find_contraction(contract_inds, input_sets, output_set)
        out_inds, input_sets, idx_removed, idx_contract = contract

        cost = _flop_count(
            idx_contract, idx_removed, len(contract_inds), dimension_dict
        )
        cost_list.append(cost)
        scale_list.append(len(idx_contract))
        size_list.append(_compute_size_by_dict(out_inds, dimension_dict))

        bcast = set()
        tmp_inputs = []
        for x in contract_inds:
            tmp_inputs.append(input_list.pop(x))
            bcast |= broadcast_indices.pop(x)

        new_bcast_inds = bcast - idx_removed

        # If we're broadcasting, nix blas
        if idx_removed & bcast:
            do_blas = False
        else:
            do_blas = _can_dot(tmp_inputs, out_inds, idx_removed)

        # Last contraction
        if cnum == len(path) - 1:
            idx_result = output_subscript
        else:
            sort_result = [(dimension_dict[ind], ind) for ind in out_inds]
            idx_result = "".join([x[1] for x in sorted(sort_result)])

        input_list.append(idx_result)
        broadcast_indices.append(new_bcast_inds)
        einsum_str = ",".join(tmp_inputs) + "->" + idx_result

        contraction = (
            contract_inds, idx_removed, einsum_str, input_list[:], do_blas
        )
        contraction_list.append(contraction)

    opt_cost = sum(cost_list) + 1

    if len(input_list) != 1:
        # Explicit "einsum_path" is usually trusted, but we detect this kind of
        # mistake in order to prevent from returning an intermediate value.
        raise RuntimeError(
            f"Invalid einsum_path is specified: {len(input_list) - 1} more "
            "operands has to be contracted.")

    # Hidden option used by ``einsum``: return the parsed operands and the
    # contraction list instead of the printable path.
    if einsum_call:
        return (operands, contraction_list)

    # Return the path along with a nice string representation
    overall_contraction = input_subscripts + "->" + output_subscript
    header = ("scaling", "current", "remaining")

    speedup = naive_cost / opt_cost
    # default=0 keeps an empty explicit path (single operand) printable
    max_i = max(size_list, default=0)

    summary = (
        ("Complete contraction", overall_contraction),
        ("Naive scaling", len(indices)),
        ("Optimized scaling", max(scale_list, default=0)),
        ("Naive FLOP count", f"{naive_cost:.3e}"),
        ("Optimized FLOP count", f"{opt_cost:.3e}"),
        ("Theoretical speedup", f"{speedup:3.3f}"),
        ("Largest intermediate", f"{max_i:.3e} elements"),
    )
    # Right-align "label:" behind a two-space margin so the values line up.
    label_width = 2 + max(len(label) + 1 for label, _ in summary)
    path_print = "".join(
        f"{label + ':':>{label_width}}  {value}\n" for label, value in summary
    )
    path_print += "-" * 74 + "\n"
    path_print += "%6s %24s %40s\n" % header
    path_print += "-" * 74

    for scale, contraction in zip(scale_list, contraction_list):
        _, _, einsum_str, remaining, _ = contraction
        remaining_str = ",".join(remaining) + "->" + output_subscript
        path_print += "\n%4d %24s %40s" % (scale, einsum_str, remaining_str)

    path = ['einsum_path'] + path
    return (path, path_print)
1047
+
1048
+
1049
+ def _einsum_dispatcher(*operands, out=None, optimize=None, **kwargs):
1050
+ # Arguably we dispatch on more arguments than we really should; see note in
1051
+ # _einsum_path_dispatcher for why.
1052
+ yield from operands
1053
+ yield out
1054
+
1055
+
1056
+ # Rewrite einsum to handle different cases
1057
+ @array_function_dispatch(_einsum_dispatcher, module='numpy')
1058
+ def einsum(*operands, out=None, optimize=False, **kwargs):
1059
+ """
1060
+ einsum(subscripts, *operands, out=None, dtype=None, order='K',
1061
+ casting='safe', optimize=False)
1062
+
1063
+ Evaluates the Einstein summation convention on the operands.
1064
+
1065
+ Using the Einstein summation convention, many common multi-dimensional,
1066
+ linear algebraic array operations can be represented in a simple fashion.
1067
+ In *implicit* mode `einsum` computes these values.
1068
+
1069
+ In *explicit* mode, `einsum` provides further flexibility to compute
1070
+ other array operations that might not be considered classical Einstein
1071
+ summation operations, by disabling, or forcing summation over specified
1072
+ subscript labels.
1073
+
1074
+ See the notes and examples for clarification.
1075
+
1076
+ Parameters
1077
+ ----------
1078
+ subscripts : str
1079
+ Specifies the subscripts for summation as comma separated list of
1080
+ subscript labels. An implicit (classical Einstein summation)
1081
+ calculation is performed unless the explicit indicator '->' is
1082
+ included as well as subscript labels of the precise output form.
1083
+ operands : list of array_like
1084
+ These are the arrays for the operation.
1085
+ out : ndarray, optional
1086
+ If provided, the calculation is done into this array.
1087
+ dtype : {data-type, None}, optional
1088
+ If provided, forces the calculation to use the data type specified.
1089
+ Note that you may have to also give a more liberal `casting`
1090
+ parameter to allow the conversions. Default is None.
1091
+ order : {'C', 'F', 'A', 'K'}, optional
1092
+ Controls the memory layout of the output. 'C' means it should
1093
+ be C contiguous. 'F' means it should be Fortran contiguous,
1094
+ 'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise.
1095
+ 'K' means it should be as close to the layout as the inputs as
1096
+ is possible, including arbitrarily permuted axes.
1097
+ Default is 'K'.
1098
+ casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
1099
+ Controls what kind of data casting may occur. Setting this to
1100
+ 'unsafe' is not recommended, as it can adversely affect accumulations.
1101
+
1102
+ * 'no' means the data types should not be cast at all.
1103
+ * 'equiv' means only byte-order changes are allowed.
1104
+ * 'safe' means only casts which can preserve values are allowed.
1105
+ * 'same_kind' means only safe casts or casts within a kind,
1106
+ like float64 to float32, are allowed.
1107
+ * 'unsafe' means any data conversions may be done.
1108
+
1109
+ Default is 'safe'.
1110
+ optimize : {False, True, 'greedy', 'optimal'}, optional
1111
+ Controls if intermediate optimization should occur. No optimization
1112
+ will occur if False and True will default to the 'greedy' algorithm.
1113
+ Also accepts an explicit contraction list from the ``np.einsum_path``
1114
+ function. See ``np.einsum_path`` for more details. Defaults to False.
1115
+
1116
+ Returns
1117
+ -------
1118
+ output : ndarray
1119
+ The calculation based on the Einstein summation convention.
1120
+
1121
+ See Also
1122
+ --------
1123
+ einsum_path, dot, inner, outer, tensordot, linalg.multi_dot
1124
+ einsum:
1125
+ Similar verbose interface is provided by the
1126
+ `einops <https://github.com/arogozhnikov/einops>`_ package to cover
1127
+ additional operations: transpose, reshape/flatten, repeat/tile,
1128
+ squeeze/unsqueeze and reductions.
1129
+ The `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_
1130
+ optimizes contraction order for einsum-like expressions
1131
+ in backend-agnostic manner.
1132
+
1133
+ Notes
1134
+ -----
1135
+ The Einstein summation convention can be used to compute
1136
+ many multi-dimensional, linear algebraic array operations. `einsum`
1137
+ provides a succinct way of representing these.
1138
+
1139
+ A non-exhaustive list of these operations,
1140
+ which can be computed by `einsum`, is shown below along with examples:
1141
+
1142
+ * Trace of an array, :py:func:`numpy.trace`.
1143
+ * Return a diagonal, :py:func:`numpy.diag`.
1144
+ * Array axis summations, :py:func:`numpy.sum`.
1145
+ * Transpositions and permutations, :py:func:`numpy.transpose`.
1146
+ * Matrix multiplication and dot product, :py:func:`numpy.matmul`
1147
+ :py:func:`numpy.dot`.
1148
+ * Vector inner and outer products, :py:func:`numpy.inner`
1149
+ :py:func:`numpy.outer`.
1150
+ * Broadcasting, element-wise and scalar multiplication,
1151
+ :py:func:`numpy.multiply`.
1152
+ * Tensor contractions, :py:func:`numpy.tensordot`.
1153
+ * Chained array operations, in efficient calculation order,
1154
+ :py:func:`numpy.einsum_path`.
1155
+
1156
+ The subscripts string is a comma-separated list of subscript labels,
1157
+ where each label refers to a dimension of the corresponding operand.
1158
+ Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)``
1159
+ is equivalent to :py:func:`np.inner(a,b) <numpy.inner>`. If a label
1160
+ appears only once, it is not summed, so ``np.einsum('i', a)``
1161
+ produces a view of ``a`` with no changes. A further example
1162
+ ``np.einsum('ij,jk', a, b)`` describes traditional matrix multiplication
1163
+ and is equivalent to :py:func:`np.matmul(a,b) <numpy.matmul>`.
1164
+ Repeated subscript labels in one operand take the diagonal.
1165
+ For example, ``np.einsum('ii', a)`` is equivalent to
1166
+ :py:func:`np.trace(a) <numpy.trace>`.
1167
+
1168
+ In *implicit mode*, the chosen subscripts are important
1169
+ since the axes of the output are reordered alphabetically. This
1170
+ means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while
1171
+ ``np.einsum('ji', a)`` takes its transpose. Additionally,
1172
+ ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while,
1173
+ ``np.einsum('ij,jh', a, b)`` returns the transpose of the
1174
+ multiplication since subscript 'h' precedes subscript 'i'.
1175
+
1176
+ In *explicit mode* the output can be directly controlled by
1177
+ specifying output subscript labels. This requires the
1178
+ identifier '->' as well as the list of output subscript labels.
1179
+ This feature increases the flexibility of the function since
1180
+ summing can be disabled or forced when required. The call
1181
+ ``np.einsum('i->', a)`` is like :py:func:`np.sum(a) <numpy.sum>`
1182
+ if ``a`` is a 1-D array, and ``np.einsum('ii->i', a)``
1183
+ is like :py:func:`np.diag(a) <numpy.diag>` if ``a`` is a square 2-D array.
1184
+ The difference is that `einsum` does not allow broadcasting by default.
1185
+ Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the
1186
+ order of the output subscript labels and therefore returns matrix
1187
+ multiplication, unlike the example above in implicit mode.
1188
+
1189
+ To enable and control broadcasting, use an ellipsis. Default
1190
+ NumPy-style broadcasting is done by adding an ellipsis
1191
+ to the left of each term, like ``np.einsum('...ii->...i', a)``.
1192
+ ``np.einsum('...i->...', a)`` is like
1193
+ :py:func:`np.sum(a, axis=-1) <numpy.sum>` for array ``a`` of any shape.
1194
+ To take the trace along the first and last axes,
1195
+ you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix
1196
+ product with the left-most indices instead of rightmost, one can do
1197
+ ``np.einsum('ij...,jk...->ik...', a, b)``.
1198
+
1199
+ When there is only one operand, no axes are summed, and no output
1200
+ parameter is provided, a view into the operand is returned instead
1201
+ of a new array. Thus, taking the diagonal as ``np.einsum('ii->i', a)``
1202
+ produces a view (changed in version 1.10.0).
1203
+
1204
+ `einsum` also provides an alternative way to provide the subscripts and
1205
+ operands as ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``.
1206
+ If the output shape is not provided in this format `einsum` will be
1207
+ calculated in implicit mode, otherwise it will be performed explicitly.
1208
+ The examples below have corresponding `einsum` calls with the two
1209
+ parameter methods.
1210
+
1211
+ Views returned from einsum are now writeable whenever the input array
1212
+ is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will now
1213
+ have the same effect as :py:func:`np.swapaxes(a, 0, 2) <numpy.swapaxes>`
1214
+ and ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal
1215
+ of a 2D array.
1216
+
1217
+ Added the ``optimize`` argument which will optimize the contraction order
1218
+ of an einsum expression. For a contraction with three or more operands
1219
+ this can greatly increase the computational efficiency at the cost of
1220
+ a larger memory footprint during computation.
1221
+
1222
+ Typically a 'greedy' algorithm is applied which empirical tests have shown
1223
+ returns the optimal path in the majority of cases. In some cases 'optimal'
1224
+ will return the superlative path through a more expensive, exhaustive
1225
+ search. For iterative calculations it may be advisable to calculate
1226
+ the optimal path once and reuse that path by supplying it as an argument.
1227
+ An example is given below.
1228
+
1229
+ See :py:func:`numpy.einsum_path` for more details.
1230
+
1231
+ Examples
1232
+ --------
1233
+ >>> a = np.arange(25).reshape(5,5)
1234
+ >>> b = np.arange(5)
1235
+ >>> c = np.arange(6).reshape(2,3)
1236
+
1237
+ Trace of a matrix:
1238
+
1239
+ >>> np.einsum('ii', a)
1240
+ 60
1241
+ >>> np.einsum(a, [0,0])
1242
+ 60
1243
+ >>> np.trace(a)
1244
+ 60
1245
+
1246
+ Extract the diagonal (requires explicit form):
1247
+
1248
+ >>> np.einsum('ii->i', a)
1249
+ array([ 0, 6, 12, 18, 24])
1250
+ >>> np.einsum(a, [0,0], [0])
1251
+ array([ 0, 6, 12, 18, 24])
1252
+ >>> np.diag(a)
1253
+ array([ 0, 6, 12, 18, 24])
1254
+
1255
+ Sum over an axis (requires explicit form):
1256
+
1257
+ >>> np.einsum('ij->i', a)
1258
+ array([ 10, 35, 60, 85, 110])
1259
+ >>> np.einsum(a, [0,1], [0])
1260
+ array([ 10, 35, 60, 85, 110])
1261
+ >>> np.sum(a, axis=1)
1262
+ array([ 10, 35, 60, 85, 110])
1263
+
1264
+ For higher dimensional arrays summing a single axis can be done
1265
+ with ellipsis:
1266
+
1267
+ >>> np.einsum('...j->...', a)
1268
+ array([ 10, 35, 60, 85, 110])
1269
+ >>> np.einsum(a, [Ellipsis,1], [Ellipsis])
1270
+ array([ 10, 35, 60, 85, 110])
1271
+
1272
+ Compute a matrix transpose, or reorder any number of axes:
1273
+
1274
+ >>> np.einsum('ji', c)
1275
+ array([[0, 3],
1276
+ [1, 4],
1277
+ [2, 5]])
1278
+ >>> np.einsum('ij->ji', c)
1279
+ array([[0, 3],
1280
+ [1, 4],
1281
+ [2, 5]])
1282
+ >>> np.einsum(c, [1,0])
1283
+ array([[0, 3],
1284
+ [1, 4],
1285
+ [2, 5]])
1286
+ >>> np.transpose(c)
1287
+ array([[0, 3],
1288
+ [1, 4],
1289
+ [2, 5]])
1290
+
1291
+ Vector inner products:
1292
+
1293
+ >>> np.einsum('i,i', b, b)
1294
+ 30
1295
+ >>> np.einsum(b, [0], b, [0])
1296
+ 30
1297
+ >>> np.inner(b,b)
1298
+ 30
1299
+
1300
+ Matrix vector multiplication:
1301
+
1302
+ >>> np.einsum('ij,j', a, b)
1303
+ array([ 30, 80, 130, 180, 230])
1304
+ >>> np.einsum(a, [0,1], b, [1])
1305
+ array([ 30, 80, 130, 180, 230])
1306
+ >>> np.dot(a, b)
1307
+ array([ 30, 80, 130, 180, 230])
1308
+ >>> np.einsum('...j,j', a, b)
1309
+ array([ 30, 80, 130, 180, 230])
1310
+
1311
+ Broadcasting and scalar multiplication:
1312
+
1313
+ >>> np.einsum('..., ...', 3, c)
1314
+ array([[ 0, 3, 6],
1315
+ [ 9, 12, 15]])
1316
+ >>> np.einsum(',ij', 3, c)
1317
+ array([[ 0, 3, 6],
1318
+ [ 9, 12, 15]])
1319
+ >>> np.einsum(3, [Ellipsis], c, [Ellipsis])
1320
+ array([[ 0, 3, 6],
1321
+ [ 9, 12, 15]])
1322
+ >>> np.multiply(3, c)
1323
+ array([[ 0, 3, 6],
1324
+ [ 9, 12, 15]])
1325
+
1326
+ Vector outer product:
1327
+
1328
+ >>> np.einsum('i,j', np.arange(2)+1, b)
1329
+ array([[0, 1, 2, 3, 4],
1330
+ [0, 2, 4, 6, 8]])
1331
+ >>> np.einsum(np.arange(2)+1, [0], b, [1])
1332
+ array([[0, 1, 2, 3, 4],
1333
+ [0, 2, 4, 6, 8]])
1334
+ >>> np.outer(np.arange(2)+1, b)
1335
+ array([[0, 1, 2, 3, 4],
1336
+ [0, 2, 4, 6, 8]])
1337
+
1338
+ Tensor contraction:
1339
+
1340
+ >>> a = np.arange(60.).reshape(3,4,5)
1341
+ >>> b = np.arange(24.).reshape(4,3,2)
1342
+ >>> np.einsum('ijk,jil->kl', a, b)
1343
+ array([[4400., 4730.],
1344
+ [4532., 4874.],
1345
+ [4664., 5018.],
1346
+ [4796., 5162.],
1347
+ [4928., 5306.]])
1348
+ >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
1349
+ array([[4400., 4730.],
1350
+ [4532., 4874.],
1351
+ [4664., 5018.],
1352
+ [4796., 5162.],
1353
+ [4928., 5306.]])
1354
+ >>> np.tensordot(a,b, axes=([1,0],[0,1]))
1355
+ array([[4400., 4730.],
1356
+ [4532., 4874.],
1357
+ [4664., 5018.],
1358
+ [4796., 5162.],
1359
+ [4928., 5306.]])
1360
+
1361
+ Writeable returned arrays (since version 1.10.0):
1362
+
1363
+ >>> a = np.zeros((3, 3))
1364
+ >>> np.einsum('ii->i', a)[:] = 1
1365
+ >>> a
1366
+ array([[1., 0., 0.],
1367
+ [0., 1., 0.],
1368
+ [0., 0., 1.]])
1369
+
1370
+ Example of ellipsis use:
1371
+
1372
+ >>> a = np.arange(6).reshape((3,2))
1373
+ >>> b = np.arange(12).reshape((4,3))
1374
+ >>> np.einsum('ki,jk->ij', a, b)
1375
+ array([[10, 28, 46, 64],
1376
+ [13, 40, 67, 94]])
1377
+ >>> np.einsum('ki,...k->i...', a, b)
1378
+ array([[10, 28, 46, 64],
1379
+ [13, 40, 67, 94]])
1380
+ >>> np.einsum('k...,jk', a, b)
1381
+ array([[10, 28, 46, 64],
1382
+ [13, 40, 67, 94]])
1383
+
1384
+ Chained array operations. For more complicated contractions, speed ups
1385
+ might be achieved by repeatedly computing a 'greedy' path or pre-computing
1386
+ the 'optimal' path and repeatedly applying it, using an `einsum_path`
1387
+ insertion (since version 1.12.0). Performance improvements can be
1388
+ particularly significant with larger arrays:
1389
+
1390
+ >>> a = np.ones(64).reshape(2,4,8)
1391
+
1392
+ Basic `einsum`: ~1520ms (benchmarked on 3.1GHz Intel i5.)
1393
+
1394
+ >>> for iteration in range(500):
1395
+ ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)
1396
+
1397
+ Sub-optimal `einsum` (due to repeated path calculation time): ~330ms
1398
+
1399
+ >>> for iteration in range(500):
1400
+ ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a,
1401
+ ... optimize='optimal')
1402
+
1403
+ Greedy `einsum` (faster optimal path approximation): ~160ms
1404
+
1405
+ >>> for iteration in range(500):
1406
+ ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')
1407
+
1408
+ Optimal `einsum` (best usage pattern in some use cases): ~110ms
1409
+
1410
+ >>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a,
1411
+ ... optimize='optimal')[0]
1412
+ >>> for iteration in range(500):
1413
+ ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)
1414
+
1415
+ """
1416
+ # Special handling if out is specified
1417
+ specified_out = out is not None
1418
+
1419
+ # If no optimization, run pure einsum
1420
+ if optimize is False:
1421
+ if specified_out:
1422
+ kwargs['out'] = out
1423
+ return c_einsum(*operands, **kwargs)
1424
+
1425
+ # Check the kwargs to avoid a more cryptic error later, without having to
1426
+ # repeat default values here
1427
+ valid_einsum_kwargs = ['dtype', 'order', 'casting']
1428
+ unknown_kwargs = [k for (k, v) in kwargs.items() if
1429
+ k not in valid_einsum_kwargs]
1430
+ if len(unknown_kwargs):
1431
+ raise TypeError(f"Did not understand the following kwargs: {unknown_kwargs}")
1432
+
1433
+ # Build the contraction list and operand
1434
+ operands, contraction_list = einsum_path(*operands, optimize=optimize,
1435
+ einsum_call=True)
1436
+
1437
+ # Handle order kwarg for output array, c_einsum allows mixed case
1438
+ output_order = kwargs.pop('order', 'K')
1439
+ if output_order.upper() == 'A':
1440
+ if all(arr.flags.f_contiguous for arr in operands):
1441
+ output_order = 'F'
1442
+ else:
1443
+ output_order = 'C'
1444
+
1445
+ # Start contraction loop
1446
+ for num, contraction in enumerate(contraction_list):
1447
+ inds, idx_rm, einsum_str, remaining, blas = contraction
1448
+ tmp_operands = [operands.pop(x) for x in inds]
1449
+
1450
+ # Do we need to deal with the output?
1451
+ handle_out = specified_out and ((num + 1) == len(contraction_list))
1452
+
1453
+ # Call tensordot if still possible
1454
+ if blas:
1455
+ # Checks have already been handled
1456
+ input_str, results_index = einsum_str.split('->')
1457
+ input_left, input_right = input_str.split(',')
1458
+
1459
+ tensor_result = input_left + input_right
1460
+ for s in idx_rm:
1461
+ tensor_result = tensor_result.replace(s, "")
1462
+
1463
+ # Find indices to contract over
1464
+ left_pos, right_pos = [], []
1465
+ for s in sorted(idx_rm):
1466
+ left_pos.append(input_left.find(s))
1467
+ right_pos.append(input_right.find(s))
1468
+
1469
+ # Contract!
1470
+ new_view = tensordot(
1471
+ *tmp_operands, axes=(tuple(left_pos), tuple(right_pos))
1472
+ )
1473
+
1474
+ # Build a new view if needed
1475
+ if (tensor_result != results_index) or handle_out:
1476
+ if handle_out:
1477
+ kwargs["out"] = out
1478
+ new_view = c_einsum(
1479
+ tensor_result + '->' + results_index, new_view, **kwargs
1480
+ )
1481
+
1482
+ # Call einsum
1483
+ else:
1484
+ # If out was specified
1485
+ if handle_out:
1486
+ kwargs["out"] = out
1487
+
1488
+ # Do the contraction
1489
+ new_view = c_einsum(einsum_str, *tmp_operands, **kwargs)
1490
+
1491
+ # Append new items and dereference what we can
1492
+ operands.append(new_view)
1493
+ del tmp_operands, new_view
1494
+
1495
+ if specified_out:
1496
+ return out
1497
+ else:
1498
+ return asanyarray(operands[0], order=output_order)