numpy 2.4.2__cp313-cp313t-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (929) hide show
  1. numpy/__config__.py +170 -0
  2. numpy/__config__.pyi +108 -0
  3. numpy/__init__.cython-30.pxd +1242 -0
  4. numpy/__init__.pxd +1155 -0
  5. numpy/__init__.py +942 -0
  6. numpy/__init__.pyi +6202 -0
  7. numpy/_array_api_info.py +346 -0
  8. numpy/_array_api_info.pyi +206 -0
  9. numpy/_configtool.py +39 -0
  10. numpy/_configtool.pyi +1 -0
  11. numpy/_core/__init__.py +203 -0
  12. numpy/_core/__init__.pyi +666 -0
  13. numpy/_core/_add_newdocs.py +7151 -0
  14. numpy/_core/_add_newdocs.pyi +2 -0
  15. numpy/_core/_add_newdocs_scalars.py +381 -0
  16. numpy/_core/_add_newdocs_scalars.pyi +16 -0
  17. numpy/_core/_asarray.py +130 -0
  18. numpy/_core/_asarray.pyi +43 -0
  19. numpy/_core/_dtype.py +366 -0
  20. numpy/_core/_dtype.pyi +56 -0
  21. numpy/_core/_dtype_ctypes.py +120 -0
  22. numpy/_core/_dtype_ctypes.pyi +83 -0
  23. numpy/_core/_exceptions.py +162 -0
  24. numpy/_core/_exceptions.pyi +54 -0
  25. numpy/_core/_internal.py +968 -0
  26. numpy/_core/_internal.pyi +61 -0
  27. numpy/_core/_methods.py +252 -0
  28. numpy/_core/_methods.pyi +22 -0
  29. numpy/_core/_multiarray_tests.cp313t-win32.lib +0 -0
  30. numpy/_core/_multiarray_tests.cp313t-win32.pyd +0 -0
  31. numpy/_core/_multiarray_umath.cp313t-win32.lib +0 -0
  32. numpy/_core/_multiarray_umath.cp313t-win32.pyd +0 -0
  33. numpy/_core/_operand_flag_tests.cp313t-win32.lib +0 -0
  34. numpy/_core/_operand_flag_tests.cp313t-win32.pyd +0 -0
  35. numpy/_core/_rational_tests.cp313t-win32.lib +0 -0
  36. numpy/_core/_rational_tests.cp313t-win32.pyd +0 -0
  37. numpy/_core/_simd.cp313t-win32.lib +0 -0
  38. numpy/_core/_simd.cp313t-win32.pyd +0 -0
  39. numpy/_core/_simd.pyi +35 -0
  40. numpy/_core/_string_helpers.py +100 -0
  41. numpy/_core/_string_helpers.pyi +12 -0
  42. numpy/_core/_struct_ufunc_tests.cp313t-win32.lib +0 -0
  43. numpy/_core/_struct_ufunc_tests.cp313t-win32.pyd +0 -0
  44. numpy/_core/_type_aliases.py +131 -0
  45. numpy/_core/_type_aliases.pyi +86 -0
  46. numpy/_core/_ufunc_config.py +515 -0
  47. numpy/_core/_ufunc_config.pyi +69 -0
  48. numpy/_core/_umath_tests.cp313t-win32.lib +0 -0
  49. numpy/_core/_umath_tests.cp313t-win32.pyd +0 -0
  50. numpy/_core/_umath_tests.pyi +47 -0
  51. numpy/_core/arrayprint.py +1779 -0
  52. numpy/_core/arrayprint.pyi +158 -0
  53. numpy/_core/cversions.py +13 -0
  54. numpy/_core/defchararray.py +1414 -0
  55. numpy/_core/defchararray.pyi +1150 -0
  56. numpy/_core/einsumfunc.py +1650 -0
  57. numpy/_core/einsumfunc.pyi +184 -0
  58. numpy/_core/fromnumeric.py +4233 -0
  59. numpy/_core/fromnumeric.pyi +1735 -0
  60. numpy/_core/function_base.py +547 -0
  61. numpy/_core/function_base.pyi +276 -0
  62. numpy/_core/getlimits.py +462 -0
  63. numpy/_core/getlimits.pyi +124 -0
  64. numpy/_core/include/numpy/__multiarray_api.c +376 -0
  65. numpy/_core/include/numpy/__multiarray_api.h +1628 -0
  66. numpy/_core/include/numpy/__ufunc_api.c +55 -0
  67. numpy/_core/include/numpy/__ufunc_api.h +349 -0
  68. numpy/_core/include/numpy/_neighborhood_iterator_imp.h +90 -0
  69. numpy/_core/include/numpy/_numpyconfig.h +33 -0
  70. numpy/_core/include/numpy/_public_dtype_api_table.h +86 -0
  71. numpy/_core/include/numpy/arrayobject.h +7 -0
  72. numpy/_core/include/numpy/arrayscalars.h +198 -0
  73. numpy/_core/include/numpy/dtype_api.h +547 -0
  74. numpy/_core/include/numpy/halffloat.h +70 -0
  75. numpy/_core/include/numpy/ndarrayobject.h +304 -0
  76. numpy/_core/include/numpy/ndarraytypes.h +1982 -0
  77. numpy/_core/include/numpy/npy_2_compat.h +249 -0
  78. numpy/_core/include/numpy/npy_2_complexcompat.h +28 -0
  79. numpy/_core/include/numpy/npy_3kcompat.h +374 -0
  80. numpy/_core/include/numpy/npy_common.h +989 -0
  81. numpy/_core/include/numpy/npy_cpu.h +126 -0
  82. numpy/_core/include/numpy/npy_endian.h +79 -0
  83. numpy/_core/include/numpy/npy_math.h +602 -0
  84. numpy/_core/include/numpy/npy_no_deprecated_api.h +20 -0
  85. numpy/_core/include/numpy/npy_os.h +42 -0
  86. numpy/_core/include/numpy/numpyconfig.h +185 -0
  87. numpy/_core/include/numpy/random/LICENSE.txt +21 -0
  88. numpy/_core/include/numpy/random/bitgen.h +20 -0
  89. numpy/_core/include/numpy/random/distributions.h +209 -0
  90. numpy/_core/include/numpy/random/libdivide.h +2079 -0
  91. numpy/_core/include/numpy/ufuncobject.h +343 -0
  92. numpy/_core/include/numpy/utils.h +37 -0
  93. numpy/_core/lib/npy-pkg-config/mlib.ini +12 -0
  94. numpy/_core/lib/npy-pkg-config/npymath.ini +20 -0
  95. numpy/_core/lib/npymath.lib +0 -0
  96. numpy/_core/lib/pkgconfig/numpy.pc +7 -0
  97. numpy/_core/memmap.py +363 -0
  98. numpy/_core/memmap.pyi +3 -0
  99. numpy/_core/multiarray.py +1740 -0
  100. numpy/_core/multiarray.pyi +1328 -0
  101. numpy/_core/numeric.py +2771 -0
  102. numpy/_core/numeric.pyi +1276 -0
  103. numpy/_core/numerictypes.py +633 -0
  104. numpy/_core/numerictypes.pyi +196 -0
  105. numpy/_core/overrides.py +188 -0
  106. numpy/_core/overrides.pyi +47 -0
  107. numpy/_core/printoptions.py +32 -0
  108. numpy/_core/printoptions.pyi +28 -0
  109. numpy/_core/records.py +1088 -0
  110. numpy/_core/records.pyi +340 -0
  111. numpy/_core/shape_base.py +996 -0
  112. numpy/_core/shape_base.pyi +182 -0
  113. numpy/_core/strings.py +1813 -0
  114. numpy/_core/strings.pyi +536 -0
  115. numpy/_core/tests/_locales.py +72 -0
  116. numpy/_core/tests/_natype.py +144 -0
  117. numpy/_core/tests/data/astype_copy.pkl +0 -0
  118. numpy/_core/tests/data/generate_umath_validation_data.cpp +170 -0
  119. numpy/_core/tests/data/recarray_from_file.fits +0 -0
  120. numpy/_core/tests/data/umath-validation-set-README.txt +15 -0
  121. numpy/_core/tests/data/umath-validation-set-arccos.csv +1429 -0
  122. numpy/_core/tests/data/umath-validation-set-arccosh.csv +1429 -0
  123. numpy/_core/tests/data/umath-validation-set-arcsin.csv +1429 -0
  124. numpy/_core/tests/data/umath-validation-set-arcsinh.csv +1429 -0
  125. numpy/_core/tests/data/umath-validation-set-arctan.csv +1429 -0
  126. numpy/_core/tests/data/umath-validation-set-arctanh.csv +1429 -0
  127. numpy/_core/tests/data/umath-validation-set-cbrt.csv +1429 -0
  128. numpy/_core/tests/data/umath-validation-set-cos.csv +1375 -0
  129. numpy/_core/tests/data/umath-validation-set-cosh.csv +1429 -0
  130. numpy/_core/tests/data/umath-validation-set-exp.csv +412 -0
  131. numpy/_core/tests/data/umath-validation-set-exp2.csv +1429 -0
  132. numpy/_core/tests/data/umath-validation-set-expm1.csv +1429 -0
  133. numpy/_core/tests/data/umath-validation-set-log.csv +271 -0
  134. numpy/_core/tests/data/umath-validation-set-log10.csv +1629 -0
  135. numpy/_core/tests/data/umath-validation-set-log1p.csv +1429 -0
  136. numpy/_core/tests/data/umath-validation-set-log2.csv +1629 -0
  137. numpy/_core/tests/data/umath-validation-set-sin.csv +1370 -0
  138. numpy/_core/tests/data/umath-validation-set-sinh.csv +1429 -0
  139. numpy/_core/tests/data/umath-validation-set-tan.csv +1429 -0
  140. numpy/_core/tests/data/umath-validation-set-tanh.csv +1429 -0
  141. numpy/_core/tests/examples/cython/checks.pyx +374 -0
  142. numpy/_core/tests/examples/cython/meson.build +43 -0
  143. numpy/_core/tests/examples/cython/setup.py +39 -0
  144. numpy/_core/tests/examples/limited_api/limited_api1.c +15 -0
  145. numpy/_core/tests/examples/limited_api/limited_api2.pyx +11 -0
  146. numpy/_core/tests/examples/limited_api/limited_api_latest.c +19 -0
  147. numpy/_core/tests/examples/limited_api/meson.build +63 -0
  148. numpy/_core/tests/examples/limited_api/setup.py +24 -0
  149. numpy/_core/tests/test__exceptions.py +90 -0
  150. numpy/_core/tests/test_abc.py +54 -0
  151. numpy/_core/tests/test_api.py +655 -0
  152. numpy/_core/tests/test_argparse.py +90 -0
  153. numpy/_core/tests/test_array_api_info.py +113 -0
  154. numpy/_core/tests/test_array_coercion.py +928 -0
  155. numpy/_core/tests/test_array_interface.py +222 -0
  156. numpy/_core/tests/test_arraymethod.py +84 -0
  157. numpy/_core/tests/test_arrayobject.py +95 -0
  158. numpy/_core/tests/test_arrayprint.py +1324 -0
  159. numpy/_core/tests/test_casting_floatingpoint_errors.py +154 -0
  160. numpy/_core/tests/test_casting_unittests.py +955 -0
  161. numpy/_core/tests/test_conversion_utils.py +209 -0
  162. numpy/_core/tests/test_cpu_dispatcher.py +48 -0
  163. numpy/_core/tests/test_cpu_features.py +450 -0
  164. numpy/_core/tests/test_custom_dtypes.py +393 -0
  165. numpy/_core/tests/test_cython.py +352 -0
  166. numpy/_core/tests/test_datetime.py +2792 -0
  167. numpy/_core/tests/test_defchararray.py +858 -0
  168. numpy/_core/tests/test_deprecations.py +460 -0
  169. numpy/_core/tests/test_dlpack.py +190 -0
  170. numpy/_core/tests/test_dtype.py +2110 -0
  171. numpy/_core/tests/test_einsum.py +1351 -0
  172. numpy/_core/tests/test_errstate.py +131 -0
  173. numpy/_core/tests/test_extint128.py +217 -0
  174. numpy/_core/tests/test_finfo.py +86 -0
  175. numpy/_core/tests/test_function_base.py +504 -0
  176. numpy/_core/tests/test_getlimits.py +171 -0
  177. numpy/_core/tests/test_half.py +593 -0
  178. numpy/_core/tests/test_hashtable.py +36 -0
  179. numpy/_core/tests/test_indexerrors.py +122 -0
  180. numpy/_core/tests/test_indexing.py +1692 -0
  181. numpy/_core/tests/test_item_selection.py +167 -0
  182. numpy/_core/tests/test_limited_api.py +102 -0
  183. numpy/_core/tests/test_longdouble.py +370 -0
  184. numpy/_core/tests/test_mem_overlap.py +933 -0
  185. numpy/_core/tests/test_mem_policy.py +453 -0
  186. numpy/_core/tests/test_memmap.py +248 -0
  187. numpy/_core/tests/test_multiarray.py +11008 -0
  188. numpy/_core/tests/test_multiprocessing.py +55 -0
  189. numpy/_core/tests/test_multithreading.py +406 -0
  190. numpy/_core/tests/test_nditer.py +3533 -0
  191. numpy/_core/tests/test_nep50_promotions.py +287 -0
  192. numpy/_core/tests/test_numeric.py +4301 -0
  193. numpy/_core/tests/test_numerictypes.py +650 -0
  194. numpy/_core/tests/test_overrides.py +800 -0
  195. numpy/_core/tests/test_print.py +202 -0
  196. numpy/_core/tests/test_protocols.py +46 -0
  197. numpy/_core/tests/test_records.py +544 -0
  198. numpy/_core/tests/test_regression.py +2677 -0
  199. numpy/_core/tests/test_scalar_ctors.py +203 -0
  200. numpy/_core/tests/test_scalar_methods.py +328 -0
  201. numpy/_core/tests/test_scalarbuffer.py +153 -0
  202. numpy/_core/tests/test_scalarinherit.py +105 -0
  203. numpy/_core/tests/test_scalarmath.py +1168 -0
  204. numpy/_core/tests/test_scalarprint.py +403 -0
  205. numpy/_core/tests/test_shape_base.py +904 -0
  206. numpy/_core/tests/test_simd.py +1345 -0
  207. numpy/_core/tests/test_simd_module.py +105 -0
  208. numpy/_core/tests/test_stringdtype.py +1855 -0
  209. numpy/_core/tests/test_strings.py +1523 -0
  210. numpy/_core/tests/test_ufunc.py +3405 -0
  211. numpy/_core/tests/test_umath.py +4962 -0
  212. numpy/_core/tests/test_umath_accuracy.py +132 -0
  213. numpy/_core/tests/test_umath_complex.py +631 -0
  214. numpy/_core/tests/test_unicode.py +369 -0
  215. numpy/_core/umath.py +60 -0
  216. numpy/_core/umath.pyi +232 -0
  217. numpy/_distributor_init.py +15 -0
  218. numpy/_distributor_init.pyi +1 -0
  219. numpy/_expired_attrs_2_0.py +78 -0
  220. numpy/_expired_attrs_2_0.pyi +61 -0
  221. numpy/_globals.py +121 -0
  222. numpy/_globals.pyi +17 -0
  223. numpy/_pyinstaller/__init__.py +0 -0
  224. numpy/_pyinstaller/__init__.pyi +0 -0
  225. numpy/_pyinstaller/hook-numpy.py +36 -0
  226. numpy/_pyinstaller/hook-numpy.pyi +6 -0
  227. numpy/_pyinstaller/tests/__init__.py +16 -0
  228. numpy/_pyinstaller/tests/pyinstaller-smoke.py +32 -0
  229. numpy/_pyinstaller/tests/test_pyinstaller.py +35 -0
  230. numpy/_pytesttester.py +201 -0
  231. numpy/_pytesttester.pyi +18 -0
  232. numpy/_typing/__init__.py +173 -0
  233. numpy/_typing/_add_docstring.py +153 -0
  234. numpy/_typing/_array_like.py +106 -0
  235. numpy/_typing/_char_codes.py +213 -0
  236. numpy/_typing/_dtype_like.py +114 -0
  237. numpy/_typing/_extended_precision.py +15 -0
  238. numpy/_typing/_nbit.py +19 -0
  239. numpy/_typing/_nbit_base.py +94 -0
  240. numpy/_typing/_nbit_base.pyi +39 -0
  241. numpy/_typing/_nested_sequence.py +79 -0
  242. numpy/_typing/_scalars.py +20 -0
  243. numpy/_typing/_shape.py +8 -0
  244. numpy/_typing/_ufunc.py +7 -0
  245. numpy/_typing/_ufunc.pyi +975 -0
  246. numpy/_utils/__init__.py +95 -0
  247. numpy/_utils/__init__.pyi +28 -0
  248. numpy/_utils/_convertions.py +18 -0
  249. numpy/_utils/_convertions.pyi +4 -0
  250. numpy/_utils/_inspect.py +192 -0
  251. numpy/_utils/_inspect.pyi +70 -0
  252. numpy/_utils/_pep440.py +486 -0
  253. numpy/_utils/_pep440.pyi +118 -0
  254. numpy/char/__init__.py +2 -0
  255. numpy/char/__init__.pyi +111 -0
  256. numpy/conftest.py +248 -0
  257. numpy/core/__init__.py +33 -0
  258. numpy/core/__init__.pyi +0 -0
  259. numpy/core/_dtype.py +10 -0
  260. numpy/core/_dtype.pyi +0 -0
  261. numpy/core/_dtype_ctypes.py +10 -0
  262. numpy/core/_dtype_ctypes.pyi +0 -0
  263. numpy/core/_internal.py +27 -0
  264. numpy/core/_multiarray_umath.py +57 -0
  265. numpy/core/_utils.py +21 -0
  266. numpy/core/arrayprint.py +10 -0
  267. numpy/core/defchararray.py +10 -0
  268. numpy/core/einsumfunc.py +10 -0
  269. numpy/core/fromnumeric.py +10 -0
  270. numpy/core/function_base.py +10 -0
  271. numpy/core/getlimits.py +10 -0
  272. numpy/core/multiarray.py +25 -0
  273. numpy/core/numeric.py +12 -0
  274. numpy/core/numerictypes.py +10 -0
  275. numpy/core/overrides.py +10 -0
  276. numpy/core/overrides.pyi +7 -0
  277. numpy/core/records.py +10 -0
  278. numpy/core/shape_base.py +10 -0
  279. numpy/core/umath.py +10 -0
  280. numpy/ctypeslib/__init__.py +13 -0
  281. numpy/ctypeslib/__init__.pyi +15 -0
  282. numpy/ctypeslib/_ctypeslib.py +603 -0
  283. numpy/ctypeslib/_ctypeslib.pyi +236 -0
  284. numpy/doc/ufuncs.py +138 -0
  285. numpy/dtypes.py +41 -0
  286. numpy/dtypes.pyi +630 -0
  287. numpy/exceptions.py +246 -0
  288. numpy/exceptions.pyi +27 -0
  289. numpy/f2py/__init__.py +86 -0
  290. numpy/f2py/__init__.pyi +5 -0
  291. numpy/f2py/__main__.py +5 -0
  292. numpy/f2py/__version__.py +1 -0
  293. numpy/f2py/__version__.pyi +1 -0
  294. numpy/f2py/_backends/__init__.py +9 -0
  295. numpy/f2py/_backends/__init__.pyi +5 -0
  296. numpy/f2py/_backends/_backend.py +44 -0
  297. numpy/f2py/_backends/_backend.pyi +46 -0
  298. numpy/f2py/_backends/_distutils.py +76 -0
  299. numpy/f2py/_backends/_distutils.pyi +13 -0
  300. numpy/f2py/_backends/_meson.py +244 -0
  301. numpy/f2py/_backends/_meson.pyi +62 -0
  302. numpy/f2py/_backends/meson.build.template +58 -0
  303. numpy/f2py/_isocbind.py +62 -0
  304. numpy/f2py/_isocbind.pyi +13 -0
  305. numpy/f2py/_src_pyf.py +247 -0
  306. numpy/f2py/_src_pyf.pyi +28 -0
  307. numpy/f2py/auxfuncs.py +1004 -0
  308. numpy/f2py/auxfuncs.pyi +262 -0
  309. numpy/f2py/capi_maps.py +811 -0
  310. numpy/f2py/capi_maps.pyi +33 -0
  311. numpy/f2py/cb_rules.py +665 -0
  312. numpy/f2py/cb_rules.pyi +17 -0
  313. numpy/f2py/cfuncs.py +1563 -0
  314. numpy/f2py/cfuncs.pyi +31 -0
  315. numpy/f2py/common_rules.py +143 -0
  316. numpy/f2py/common_rules.pyi +9 -0
  317. numpy/f2py/crackfortran.py +3725 -0
  318. numpy/f2py/crackfortran.pyi +266 -0
  319. numpy/f2py/diagnose.py +149 -0
  320. numpy/f2py/diagnose.pyi +1 -0
  321. numpy/f2py/f2py2e.py +788 -0
  322. numpy/f2py/f2py2e.pyi +74 -0
  323. numpy/f2py/f90mod_rules.py +269 -0
  324. numpy/f2py/f90mod_rules.pyi +16 -0
  325. numpy/f2py/func2subr.py +329 -0
  326. numpy/f2py/func2subr.pyi +7 -0
  327. numpy/f2py/rules.py +1629 -0
  328. numpy/f2py/rules.pyi +41 -0
  329. numpy/f2py/setup.cfg +3 -0
  330. numpy/f2py/src/fortranobject.c +1436 -0
  331. numpy/f2py/src/fortranobject.h +173 -0
  332. numpy/f2py/symbolic.py +1518 -0
  333. numpy/f2py/symbolic.pyi +219 -0
  334. numpy/f2py/tests/__init__.py +16 -0
  335. numpy/f2py/tests/src/abstract_interface/foo.f90 +34 -0
  336. numpy/f2py/tests/src/abstract_interface/gh18403_mod.f90 +6 -0
  337. numpy/f2py/tests/src/array_from_pyobj/wrapmodule.c +235 -0
  338. numpy/f2py/tests/src/assumed_shape/.f2py_f2cmap +1 -0
  339. numpy/f2py/tests/src/assumed_shape/foo_free.f90 +34 -0
  340. numpy/f2py/tests/src/assumed_shape/foo_mod.f90 +41 -0
  341. numpy/f2py/tests/src/assumed_shape/foo_use.f90 +19 -0
  342. numpy/f2py/tests/src/assumed_shape/precision.f90 +4 -0
  343. numpy/f2py/tests/src/block_docstring/foo.f +6 -0
  344. numpy/f2py/tests/src/callback/foo.f +62 -0
  345. numpy/f2py/tests/src/callback/gh17797.f90 +7 -0
  346. numpy/f2py/tests/src/callback/gh18335.f90 +17 -0
  347. numpy/f2py/tests/src/callback/gh25211.f +10 -0
  348. numpy/f2py/tests/src/callback/gh25211.pyf +18 -0
  349. numpy/f2py/tests/src/callback/gh26681.f90 +18 -0
  350. numpy/f2py/tests/src/cli/gh_22819.pyf +6 -0
  351. numpy/f2py/tests/src/cli/hi77.f +3 -0
  352. numpy/f2py/tests/src/cli/hiworld.f90 +3 -0
  353. numpy/f2py/tests/src/common/block.f +11 -0
  354. numpy/f2py/tests/src/common/gh19161.f90 +10 -0
  355. numpy/f2py/tests/src/crackfortran/accesstype.f90 +13 -0
  356. numpy/f2py/tests/src/crackfortran/common_with_division.f +17 -0
  357. numpy/f2py/tests/src/crackfortran/data_common.f +8 -0
  358. numpy/f2py/tests/src/crackfortran/data_multiplier.f +5 -0
  359. numpy/f2py/tests/src/crackfortran/data_stmts.f90 +20 -0
  360. numpy/f2py/tests/src/crackfortran/data_with_comments.f +8 -0
  361. numpy/f2py/tests/src/crackfortran/foo_deps.f90 +6 -0
  362. numpy/f2py/tests/src/crackfortran/gh15035.f +16 -0
  363. numpy/f2py/tests/src/crackfortran/gh17859.f +12 -0
  364. numpy/f2py/tests/src/crackfortran/gh22648.pyf +7 -0
  365. numpy/f2py/tests/src/crackfortran/gh23533.f +5 -0
  366. numpy/f2py/tests/src/crackfortran/gh23598.f90 +4 -0
  367. numpy/f2py/tests/src/crackfortran/gh23598Warn.f90 +11 -0
  368. numpy/f2py/tests/src/crackfortran/gh23879.f90 +20 -0
  369. numpy/f2py/tests/src/crackfortran/gh27697.f90 +12 -0
  370. numpy/f2py/tests/src/crackfortran/gh2848.f90 +13 -0
  371. numpy/f2py/tests/src/crackfortran/operators.f90 +49 -0
  372. numpy/f2py/tests/src/crackfortran/privatemod.f90 +11 -0
  373. numpy/f2py/tests/src/crackfortran/publicmod.f90 +10 -0
  374. numpy/f2py/tests/src/crackfortran/pubprivmod.f90 +10 -0
  375. numpy/f2py/tests/src/crackfortran/unicode_comment.f90 +4 -0
  376. numpy/f2py/tests/src/f2cmap/.f2py_f2cmap +1 -0
  377. numpy/f2py/tests/src/f2cmap/isoFortranEnvMap.f90 +9 -0
  378. numpy/f2py/tests/src/isocintrin/isoCtests.f90 +34 -0
  379. numpy/f2py/tests/src/kind/foo.f90 +20 -0
  380. numpy/f2py/tests/src/mixed/foo.f +5 -0
  381. numpy/f2py/tests/src/mixed/foo_fixed.f90 +8 -0
  382. numpy/f2py/tests/src/mixed/foo_free.f90 +8 -0
  383. numpy/f2py/tests/src/modules/gh25337/data.f90 +8 -0
  384. numpy/f2py/tests/src/modules/gh25337/use_data.f90 +6 -0
  385. numpy/f2py/tests/src/modules/gh26920/two_mods_with_no_public_entities.f90 +21 -0
  386. numpy/f2py/tests/src/modules/gh26920/two_mods_with_one_public_routine.f90 +21 -0
  387. numpy/f2py/tests/src/modules/module_data_docstring.f90 +12 -0
  388. numpy/f2py/tests/src/modules/use_modules.f90 +20 -0
  389. numpy/f2py/tests/src/negative_bounds/issue_20853.f90 +7 -0
  390. numpy/f2py/tests/src/parameter/constant_array.f90 +45 -0
  391. numpy/f2py/tests/src/parameter/constant_both.f90 +57 -0
  392. numpy/f2py/tests/src/parameter/constant_compound.f90 +15 -0
  393. numpy/f2py/tests/src/parameter/constant_integer.f90 +22 -0
  394. numpy/f2py/tests/src/parameter/constant_non_compound.f90 +23 -0
  395. numpy/f2py/tests/src/parameter/constant_real.f90 +23 -0
  396. numpy/f2py/tests/src/quoted_character/foo.f +14 -0
  397. numpy/f2py/tests/src/regression/AB.inc +1 -0
  398. numpy/f2py/tests/src/regression/assignOnlyModule.f90 +25 -0
  399. numpy/f2py/tests/src/regression/datonly.f90 +17 -0
  400. numpy/f2py/tests/src/regression/f77comments.f +26 -0
  401. numpy/f2py/tests/src/regression/f77fixedform.f95 +5 -0
  402. numpy/f2py/tests/src/regression/f90continuation.f90 +9 -0
  403. numpy/f2py/tests/src/regression/incfile.f90 +5 -0
  404. numpy/f2py/tests/src/regression/inout.f90 +9 -0
  405. numpy/f2py/tests/src/regression/lower_f2py_fortran.f90 +5 -0
  406. numpy/f2py/tests/src/regression/mod_derived_types.f90 +23 -0
  407. numpy/f2py/tests/src/return_character/foo77.f +45 -0
  408. numpy/f2py/tests/src/return_character/foo90.f90 +48 -0
  409. numpy/f2py/tests/src/return_complex/foo77.f +45 -0
  410. numpy/f2py/tests/src/return_complex/foo90.f90 +48 -0
  411. numpy/f2py/tests/src/return_integer/foo77.f +56 -0
  412. numpy/f2py/tests/src/return_integer/foo90.f90 +59 -0
  413. numpy/f2py/tests/src/return_logical/foo77.f +56 -0
  414. numpy/f2py/tests/src/return_logical/foo90.f90 +59 -0
  415. numpy/f2py/tests/src/return_real/foo77.f +45 -0
  416. numpy/f2py/tests/src/return_real/foo90.f90 +48 -0
  417. numpy/f2py/tests/src/routines/funcfortranname.f +5 -0
  418. numpy/f2py/tests/src/routines/funcfortranname.pyf +11 -0
  419. numpy/f2py/tests/src/routines/subrout.f +4 -0
  420. numpy/f2py/tests/src/routines/subrout.pyf +10 -0
  421. numpy/f2py/tests/src/size/foo.f90 +44 -0
  422. numpy/f2py/tests/src/string/char.f90 +29 -0
  423. numpy/f2py/tests/src/string/fixed_string.f90 +34 -0
  424. numpy/f2py/tests/src/string/gh24008.f +8 -0
  425. numpy/f2py/tests/src/string/gh24662.f90 +7 -0
  426. numpy/f2py/tests/src/string/gh25286.f90 +14 -0
  427. numpy/f2py/tests/src/string/gh25286.pyf +12 -0
  428. numpy/f2py/tests/src/string/gh25286_bc.pyf +12 -0
  429. numpy/f2py/tests/src/string/scalar_string.f90 +9 -0
  430. numpy/f2py/tests/src/string/string.f +12 -0
  431. numpy/f2py/tests/src/value_attrspec/gh21665.f90 +9 -0
  432. numpy/f2py/tests/test_abstract_interface.py +26 -0
  433. numpy/f2py/tests/test_array_from_pyobj.py +678 -0
  434. numpy/f2py/tests/test_assumed_shape.py +50 -0
  435. numpy/f2py/tests/test_block_docstring.py +20 -0
  436. numpy/f2py/tests/test_callback.py +263 -0
  437. numpy/f2py/tests/test_character.py +641 -0
  438. numpy/f2py/tests/test_common.py +23 -0
  439. numpy/f2py/tests/test_crackfortran.py +421 -0
  440. numpy/f2py/tests/test_data.py +71 -0
  441. numpy/f2py/tests/test_docs.py +66 -0
  442. numpy/f2py/tests/test_f2cmap.py +17 -0
  443. numpy/f2py/tests/test_f2py2e.py +983 -0
  444. numpy/f2py/tests/test_isoc.py +56 -0
  445. numpy/f2py/tests/test_kind.py +52 -0
  446. numpy/f2py/tests/test_mixed.py +35 -0
  447. numpy/f2py/tests/test_modules.py +83 -0
  448. numpy/f2py/tests/test_parameter.py +129 -0
  449. numpy/f2py/tests/test_pyf_src.py +43 -0
  450. numpy/f2py/tests/test_quoted_character.py +18 -0
  451. numpy/f2py/tests/test_regression.py +187 -0
  452. numpy/f2py/tests/test_return_character.py +48 -0
  453. numpy/f2py/tests/test_return_complex.py +67 -0
  454. numpy/f2py/tests/test_return_integer.py +55 -0
  455. numpy/f2py/tests/test_return_logical.py +65 -0
  456. numpy/f2py/tests/test_return_real.py +109 -0
  457. numpy/f2py/tests/test_routines.py +29 -0
  458. numpy/f2py/tests/test_semicolon_split.py +75 -0
  459. numpy/f2py/tests/test_size.py +45 -0
  460. numpy/f2py/tests/test_string.py +100 -0
  461. numpy/f2py/tests/test_symbolic.py +500 -0
  462. numpy/f2py/tests/test_value_attrspec.py +15 -0
  463. numpy/f2py/tests/util.py +442 -0
  464. numpy/f2py/use_rules.py +99 -0
  465. numpy/f2py/use_rules.pyi +9 -0
  466. numpy/fft/__init__.py +213 -0
  467. numpy/fft/__init__.pyi +38 -0
  468. numpy/fft/_helper.py +235 -0
  469. numpy/fft/_helper.pyi +44 -0
  470. numpy/fft/_pocketfft.py +1693 -0
  471. numpy/fft/_pocketfft.pyi +137 -0
  472. numpy/fft/_pocketfft_umath.cp313t-win32.lib +0 -0
  473. numpy/fft/_pocketfft_umath.cp313t-win32.pyd +0 -0
  474. numpy/fft/tests/__init__.py +0 -0
  475. numpy/fft/tests/test_helper.py +167 -0
  476. numpy/fft/tests/test_pocketfft.py +589 -0
  477. numpy/lib/__init__.py +97 -0
  478. numpy/lib/__init__.pyi +52 -0
  479. numpy/lib/_array_utils_impl.py +62 -0
  480. numpy/lib/_array_utils_impl.pyi +10 -0
  481. numpy/lib/_arraypad_impl.py +926 -0
  482. numpy/lib/_arraypad_impl.pyi +88 -0
  483. numpy/lib/_arraysetops_impl.py +1158 -0
  484. numpy/lib/_arraysetops_impl.pyi +462 -0
  485. numpy/lib/_arrayterator_impl.py +224 -0
  486. numpy/lib/_arrayterator_impl.pyi +45 -0
  487. numpy/lib/_datasource.py +700 -0
  488. numpy/lib/_datasource.pyi +30 -0
  489. numpy/lib/_format_impl.py +1036 -0
  490. numpy/lib/_format_impl.pyi +56 -0
  491. numpy/lib/_function_base_impl.py +5760 -0
  492. numpy/lib/_function_base_impl.pyi +2324 -0
  493. numpy/lib/_histograms_impl.py +1085 -0
  494. numpy/lib/_histograms_impl.pyi +40 -0
  495. numpy/lib/_index_tricks_impl.py +1048 -0
  496. numpy/lib/_index_tricks_impl.pyi +267 -0
  497. numpy/lib/_iotools.py +900 -0
  498. numpy/lib/_iotools.pyi +116 -0
  499. numpy/lib/_nanfunctions_impl.py +2006 -0
  500. numpy/lib/_nanfunctions_impl.pyi +48 -0
  501. numpy/lib/_npyio_impl.py +2583 -0
  502. numpy/lib/_npyio_impl.pyi +299 -0
  503. numpy/lib/_polynomial_impl.py +1465 -0
  504. numpy/lib/_polynomial_impl.pyi +338 -0
  505. numpy/lib/_scimath_impl.py +642 -0
  506. numpy/lib/_scimath_impl.pyi +93 -0
  507. numpy/lib/_shape_base_impl.py +1289 -0
  508. numpy/lib/_shape_base_impl.pyi +236 -0
  509. numpy/lib/_stride_tricks_impl.py +582 -0
  510. numpy/lib/_stride_tricks_impl.pyi +73 -0
  511. numpy/lib/_twodim_base_impl.py +1201 -0
  512. numpy/lib/_twodim_base_impl.pyi +408 -0
  513. numpy/lib/_type_check_impl.py +710 -0
  514. numpy/lib/_type_check_impl.pyi +348 -0
  515. numpy/lib/_ufunclike_impl.py +199 -0
  516. numpy/lib/_ufunclike_impl.pyi +60 -0
  517. numpy/lib/_user_array_impl.py +310 -0
  518. numpy/lib/_user_array_impl.pyi +226 -0
  519. numpy/lib/_utils_impl.py +784 -0
  520. numpy/lib/_utils_impl.pyi +22 -0
  521. numpy/lib/_version.py +153 -0
  522. numpy/lib/_version.pyi +17 -0
  523. numpy/lib/array_utils.py +7 -0
  524. numpy/lib/array_utils.pyi +6 -0
  525. numpy/lib/format.py +24 -0
  526. numpy/lib/format.pyi +24 -0
  527. numpy/lib/introspect.py +94 -0
  528. numpy/lib/introspect.pyi +3 -0
  529. numpy/lib/mixins.py +180 -0
  530. numpy/lib/mixins.pyi +78 -0
  531. numpy/lib/npyio.py +1 -0
  532. numpy/lib/npyio.pyi +5 -0
  533. numpy/lib/recfunctions.py +1681 -0
  534. numpy/lib/recfunctions.pyi +444 -0
  535. numpy/lib/scimath.py +13 -0
  536. numpy/lib/scimath.pyi +12 -0
  537. numpy/lib/stride_tricks.py +1 -0
  538. numpy/lib/stride_tricks.pyi +4 -0
  539. numpy/lib/tests/__init__.py +0 -0
  540. numpy/lib/tests/data/py2-np0-objarr.npy +0 -0
  541. numpy/lib/tests/data/py2-objarr.npy +0 -0
  542. numpy/lib/tests/data/py2-objarr.npz +0 -0
  543. numpy/lib/tests/data/py3-objarr.npy +0 -0
  544. numpy/lib/tests/data/py3-objarr.npz +0 -0
  545. numpy/lib/tests/data/python3.npy +0 -0
  546. numpy/lib/tests/data/win64python2.npy +0 -0
  547. numpy/lib/tests/test__datasource.py +328 -0
  548. numpy/lib/tests/test__iotools.py +358 -0
  549. numpy/lib/tests/test__version.py +64 -0
  550. numpy/lib/tests/test_array_utils.py +32 -0
  551. numpy/lib/tests/test_arraypad.py +1427 -0
  552. numpy/lib/tests/test_arraysetops.py +1302 -0
  553. numpy/lib/tests/test_arrayterator.py +45 -0
  554. numpy/lib/tests/test_format.py +1054 -0
  555. numpy/lib/tests/test_function_base.py +4756 -0
  556. numpy/lib/tests/test_histograms.py +855 -0
  557. numpy/lib/tests/test_index_tricks.py +693 -0
  558. numpy/lib/tests/test_io.py +2857 -0
  559. numpy/lib/tests/test_loadtxt.py +1099 -0
  560. numpy/lib/tests/test_mixins.py +215 -0
  561. numpy/lib/tests/test_nanfunctions.py +1438 -0
  562. numpy/lib/tests/test_packbits.py +376 -0
  563. numpy/lib/tests/test_polynomial.py +325 -0
  564. numpy/lib/tests/test_recfunctions.py +1042 -0
  565. numpy/lib/tests/test_regression.py +231 -0
  566. numpy/lib/tests/test_shape_base.py +813 -0
  567. numpy/lib/tests/test_stride_tricks.py +655 -0
  568. numpy/lib/tests/test_twodim_base.py +559 -0
  569. numpy/lib/tests/test_type_check.py +473 -0
  570. numpy/lib/tests/test_ufunclike.py +97 -0
  571. numpy/lib/tests/test_utils.py +80 -0
  572. numpy/lib/user_array.py +1 -0
  573. numpy/lib/user_array.pyi +1 -0
  574. numpy/linalg/__init__.py +95 -0
  575. numpy/linalg/__init__.pyi +71 -0
  576. numpy/linalg/_linalg.py +3657 -0
  577. numpy/linalg/_linalg.pyi +548 -0
  578. numpy/linalg/_umath_linalg.cp313t-win32.lib +0 -0
  579. numpy/linalg/_umath_linalg.cp313t-win32.pyd +0 -0
  580. numpy/linalg/_umath_linalg.pyi +60 -0
  581. numpy/linalg/lapack_lite.cp313t-win32.lib +0 -0
  582. numpy/linalg/lapack_lite.cp313t-win32.pyd +0 -0
  583. numpy/linalg/lapack_lite.pyi +143 -0
  584. numpy/linalg/tests/__init__.py +0 -0
  585. numpy/linalg/tests/test_deprecations.py +21 -0
  586. numpy/linalg/tests/test_linalg.py +2442 -0
  587. numpy/linalg/tests/test_regression.py +182 -0
  588. numpy/ma/API_CHANGES.txt +135 -0
  589. numpy/ma/LICENSE +24 -0
  590. numpy/ma/README.rst +236 -0
  591. numpy/ma/__init__.py +53 -0
  592. numpy/ma/__init__.pyi +458 -0
  593. numpy/ma/core.py +8929 -0
  594. numpy/ma/core.pyi +3733 -0
  595. numpy/ma/extras.py +2266 -0
  596. numpy/ma/extras.pyi +297 -0
  597. numpy/ma/mrecords.py +762 -0
  598. numpy/ma/mrecords.pyi +96 -0
  599. numpy/ma/tests/__init__.py +0 -0
  600. numpy/ma/tests/test_arrayobject.py +40 -0
  601. numpy/ma/tests/test_core.py +6008 -0
  602. numpy/ma/tests/test_deprecations.py +65 -0
  603. numpy/ma/tests/test_extras.py +1945 -0
  604. numpy/ma/tests/test_mrecords.py +495 -0
  605. numpy/ma/tests/test_old_ma.py +939 -0
  606. numpy/ma/tests/test_regression.py +83 -0
  607. numpy/ma/tests/test_subclassing.py +469 -0
  608. numpy/ma/testutils.py +294 -0
  609. numpy/ma/testutils.pyi +69 -0
  610. numpy/matlib.py +380 -0
  611. numpy/matlib.pyi +580 -0
  612. numpy/matrixlib/__init__.py +12 -0
  613. numpy/matrixlib/__init__.pyi +3 -0
  614. numpy/matrixlib/defmatrix.py +1119 -0
  615. numpy/matrixlib/defmatrix.pyi +218 -0
  616. numpy/matrixlib/tests/__init__.py +0 -0
  617. numpy/matrixlib/tests/test_defmatrix.py +455 -0
  618. numpy/matrixlib/tests/test_interaction.py +360 -0
  619. numpy/matrixlib/tests/test_masked_matrix.py +240 -0
  620. numpy/matrixlib/tests/test_matrix_linalg.py +110 -0
  621. numpy/matrixlib/tests/test_multiarray.py +17 -0
  622. numpy/matrixlib/tests/test_numeric.py +18 -0
  623. numpy/matrixlib/tests/test_regression.py +31 -0
  624. numpy/polynomial/__init__.py +187 -0
  625. numpy/polynomial/__init__.pyi +31 -0
  626. numpy/polynomial/_polybase.py +1191 -0
  627. numpy/polynomial/_polybase.pyi +262 -0
  628. numpy/polynomial/_polytypes.pyi +501 -0
  629. numpy/polynomial/chebyshev.py +2001 -0
  630. numpy/polynomial/chebyshev.pyi +180 -0
  631. numpy/polynomial/hermite.py +1738 -0
  632. numpy/polynomial/hermite.pyi +106 -0
  633. numpy/polynomial/hermite_e.py +1640 -0
  634. numpy/polynomial/hermite_e.pyi +106 -0
  635. numpy/polynomial/laguerre.py +1673 -0
  636. numpy/polynomial/laguerre.pyi +100 -0
  637. numpy/polynomial/legendre.py +1603 -0
  638. numpy/polynomial/legendre.pyi +100 -0
  639. numpy/polynomial/polynomial.py +1625 -0
  640. numpy/polynomial/polynomial.pyi +109 -0
  641. numpy/polynomial/polyutils.py +759 -0
  642. numpy/polynomial/polyutils.pyi +307 -0
  643. numpy/polynomial/tests/__init__.py +0 -0
  644. numpy/polynomial/tests/test_chebyshev.py +618 -0
  645. numpy/polynomial/tests/test_classes.py +613 -0
  646. numpy/polynomial/tests/test_hermite.py +553 -0
  647. numpy/polynomial/tests/test_hermite_e.py +554 -0
  648. numpy/polynomial/tests/test_laguerre.py +535 -0
  649. numpy/polynomial/tests/test_legendre.py +566 -0
  650. numpy/polynomial/tests/test_polynomial.py +691 -0
  651. numpy/polynomial/tests/test_polyutils.py +123 -0
  652. numpy/polynomial/tests/test_printing.py +557 -0
  653. numpy/polynomial/tests/test_symbol.py +217 -0
  654. numpy/py.typed +0 -0
  655. numpy/random/LICENSE.md +71 -0
  656. numpy/random/__init__.pxd +14 -0
  657. numpy/random/__init__.py +213 -0
  658. numpy/random/__init__.pyi +124 -0
  659. numpy/random/_bounded_integers.cp313t-win32.lib +0 -0
  660. numpy/random/_bounded_integers.cp313t-win32.pyd +0 -0
  661. numpy/random/_bounded_integers.pxd +38 -0
  662. numpy/random/_bounded_integers.pyi +1 -0
  663. numpy/random/_common.cp313t-win32.lib +0 -0
  664. numpy/random/_common.cp313t-win32.pyd +0 -0
  665. numpy/random/_common.pxd +110 -0
  666. numpy/random/_common.pyi +16 -0
  667. numpy/random/_examples/cffi/extending.py +44 -0
  668. numpy/random/_examples/cffi/parse.py +53 -0
  669. numpy/random/_examples/cython/extending.pyx +77 -0
  670. numpy/random/_examples/cython/extending_distributions.pyx +117 -0
  671. numpy/random/_examples/cython/meson.build +53 -0
  672. numpy/random/_examples/numba/extending.py +86 -0
  673. numpy/random/_examples/numba/extending_distributions.py +67 -0
  674. numpy/random/_generator.cp313t-win32.lib +0 -0
  675. numpy/random/_generator.cp313t-win32.pyd +0 -0
  676. numpy/random/_generator.pyi +862 -0
  677. numpy/random/_mt19937.cp313t-win32.lib +0 -0
  678. numpy/random/_mt19937.cp313t-win32.pyd +0 -0
  679. numpy/random/_mt19937.pyi +27 -0
  680. numpy/random/_pcg64.cp313t-win32.lib +0 -0
  681. numpy/random/_pcg64.cp313t-win32.pyd +0 -0
  682. numpy/random/_pcg64.pyi +41 -0
  683. numpy/random/_philox.cp313t-win32.lib +0 -0
  684. numpy/random/_philox.cp313t-win32.pyd +0 -0
  685. numpy/random/_philox.pyi +36 -0
  686. numpy/random/_pickle.py +88 -0
  687. numpy/random/_pickle.pyi +43 -0
  688. numpy/random/_sfc64.cp313t-win32.lib +0 -0
  689. numpy/random/_sfc64.cp313t-win32.pyd +0 -0
  690. numpy/random/_sfc64.pyi +25 -0
  691. numpy/random/bit_generator.cp313t-win32.lib +0 -0
  692. numpy/random/bit_generator.cp313t-win32.pyd +0 -0
  693. numpy/random/bit_generator.pxd +40 -0
  694. numpy/random/bit_generator.pyi +123 -0
  695. numpy/random/c_distributions.pxd +119 -0
  696. numpy/random/lib/npyrandom.lib +0 -0
  697. numpy/random/mtrand.cp313t-win32.lib +0 -0
  698. numpy/random/mtrand.cp313t-win32.pyd +0 -0
  699. numpy/random/mtrand.pyi +759 -0
  700. numpy/random/tests/__init__.py +0 -0
  701. numpy/random/tests/data/__init__.py +0 -0
  702. numpy/random/tests/data/generator_pcg64_np121.pkl.gz +0 -0
  703. numpy/random/tests/data/generator_pcg64_np126.pkl.gz +0 -0
  704. numpy/random/tests/data/mt19937-testset-1.csv +1001 -0
  705. numpy/random/tests/data/mt19937-testset-2.csv +1001 -0
  706. numpy/random/tests/data/pcg64-testset-1.csv +1001 -0
  707. numpy/random/tests/data/pcg64-testset-2.csv +1001 -0
  708. numpy/random/tests/data/pcg64dxsm-testset-1.csv +1001 -0
  709. numpy/random/tests/data/pcg64dxsm-testset-2.csv +1001 -0
  710. numpy/random/tests/data/philox-testset-1.csv +1001 -0
  711. numpy/random/tests/data/philox-testset-2.csv +1001 -0
  712. numpy/random/tests/data/sfc64-testset-1.csv +1001 -0
  713. numpy/random/tests/data/sfc64-testset-2.csv +1001 -0
  714. numpy/random/tests/data/sfc64_np126.pkl.gz +0 -0
  715. numpy/random/tests/test_direct.py +595 -0
  716. numpy/random/tests/test_extending.py +131 -0
  717. numpy/random/tests/test_generator_mt19937.py +2825 -0
  718. numpy/random/tests/test_generator_mt19937_regressions.py +221 -0
  719. numpy/random/tests/test_random.py +1724 -0
  720. numpy/random/tests/test_randomstate.py +2099 -0
  721. numpy/random/tests/test_randomstate_regression.py +213 -0
  722. numpy/random/tests/test_regression.py +175 -0
  723. numpy/random/tests/test_seed_sequence.py +79 -0
  724. numpy/random/tests/test_smoke.py +882 -0
  725. numpy/rec/__init__.py +2 -0
  726. numpy/rec/__init__.pyi +23 -0
  727. numpy/strings/__init__.py +2 -0
  728. numpy/strings/__init__.pyi +97 -0
  729. numpy/testing/__init__.py +22 -0
  730. numpy/testing/__init__.pyi +107 -0
  731. numpy/testing/_private/__init__.py +0 -0
  732. numpy/testing/_private/__init__.pyi +0 -0
  733. numpy/testing/_private/extbuild.py +250 -0
  734. numpy/testing/_private/extbuild.pyi +25 -0
  735. numpy/testing/_private/utils.py +2830 -0
  736. numpy/testing/_private/utils.pyi +505 -0
  737. numpy/testing/overrides.py +84 -0
  738. numpy/testing/overrides.pyi +10 -0
  739. numpy/testing/print_coercion_tables.py +207 -0
  740. numpy/testing/print_coercion_tables.pyi +26 -0
  741. numpy/testing/tests/__init__.py +0 -0
  742. numpy/testing/tests/test_utils.py +2123 -0
  743. numpy/tests/__init__.py +0 -0
  744. numpy/tests/test__all__.py +10 -0
  745. numpy/tests/test_configtool.py +51 -0
  746. numpy/tests/test_ctypeslib.py +383 -0
  747. numpy/tests/test_lazyloading.py +42 -0
  748. numpy/tests/test_matlib.py +59 -0
  749. numpy/tests/test_numpy_config.py +47 -0
  750. numpy/tests/test_numpy_version.py +54 -0
  751. numpy/tests/test_public_api.py +807 -0
  752. numpy/tests/test_reloading.py +76 -0
  753. numpy/tests/test_scripts.py +48 -0
  754. numpy/tests/test_warnings.py +79 -0
  755. numpy/typing/__init__.py +233 -0
  756. numpy/typing/__init__.pyi +3 -0
  757. numpy/typing/mypy_plugin.py +200 -0
  758. numpy/typing/tests/__init__.py +0 -0
  759. numpy/typing/tests/data/fail/arithmetic.pyi +126 -0
  760. numpy/typing/tests/data/fail/array_constructors.pyi +34 -0
  761. numpy/typing/tests/data/fail/array_like.pyi +15 -0
  762. numpy/typing/tests/data/fail/array_pad.pyi +6 -0
  763. numpy/typing/tests/data/fail/arrayprint.pyi +15 -0
  764. numpy/typing/tests/data/fail/arrayterator.pyi +14 -0
  765. numpy/typing/tests/data/fail/bitwise_ops.pyi +17 -0
  766. numpy/typing/tests/data/fail/char.pyi +63 -0
  767. numpy/typing/tests/data/fail/chararray.pyi +61 -0
  768. numpy/typing/tests/data/fail/comparisons.pyi +27 -0
  769. numpy/typing/tests/data/fail/constants.pyi +3 -0
  770. numpy/typing/tests/data/fail/datasource.pyi +16 -0
  771. numpy/typing/tests/data/fail/dtype.pyi +17 -0
  772. numpy/typing/tests/data/fail/einsumfunc.pyi +12 -0
  773. numpy/typing/tests/data/fail/flatiter.pyi +38 -0
  774. numpy/typing/tests/data/fail/fromnumeric.pyi +148 -0
  775. numpy/typing/tests/data/fail/histograms.pyi +12 -0
  776. numpy/typing/tests/data/fail/index_tricks.pyi +14 -0
  777. numpy/typing/tests/data/fail/lib_function_base.pyi +60 -0
  778. numpy/typing/tests/data/fail/lib_polynomial.pyi +29 -0
  779. numpy/typing/tests/data/fail/lib_utils.pyi +3 -0
  780. numpy/typing/tests/data/fail/lib_version.pyi +6 -0
  781. numpy/typing/tests/data/fail/linalg.pyi +52 -0
  782. numpy/typing/tests/data/fail/ma.pyi +155 -0
  783. numpy/typing/tests/data/fail/memmap.pyi +5 -0
  784. numpy/typing/tests/data/fail/modules.pyi +17 -0
  785. numpy/typing/tests/data/fail/multiarray.pyi +52 -0
  786. numpy/typing/tests/data/fail/ndarray.pyi +11 -0
  787. numpy/typing/tests/data/fail/ndarray_misc.pyi +49 -0
  788. numpy/typing/tests/data/fail/nditer.pyi +8 -0
  789. numpy/typing/tests/data/fail/nested_sequence.pyi +17 -0
  790. numpy/typing/tests/data/fail/npyio.pyi +24 -0
  791. numpy/typing/tests/data/fail/numerictypes.pyi +5 -0
  792. numpy/typing/tests/data/fail/random.pyi +62 -0
  793. numpy/typing/tests/data/fail/rec.pyi +17 -0
  794. numpy/typing/tests/data/fail/scalars.pyi +86 -0
  795. numpy/typing/tests/data/fail/shape.pyi +7 -0
  796. numpy/typing/tests/data/fail/shape_base.pyi +8 -0
  797. numpy/typing/tests/data/fail/stride_tricks.pyi +9 -0
  798. numpy/typing/tests/data/fail/strings.pyi +52 -0
  799. numpy/typing/tests/data/fail/testing.pyi +28 -0
  800. numpy/typing/tests/data/fail/twodim_base.pyi +39 -0
  801. numpy/typing/tests/data/fail/type_check.pyi +12 -0
  802. numpy/typing/tests/data/fail/ufunc_config.pyi +21 -0
  803. numpy/typing/tests/data/fail/ufunclike.pyi +21 -0
  804. numpy/typing/tests/data/fail/ufuncs.pyi +17 -0
  805. numpy/typing/tests/data/fail/warnings_and_errors.pyi +5 -0
  806. numpy/typing/tests/data/misc/extended_precision.pyi +9 -0
  807. numpy/typing/tests/data/mypy.ini +8 -0
  808. numpy/typing/tests/data/pass/arithmetic.py +614 -0
  809. numpy/typing/tests/data/pass/array_constructors.py +138 -0
  810. numpy/typing/tests/data/pass/array_like.py +43 -0
  811. numpy/typing/tests/data/pass/arrayprint.py +37 -0
  812. numpy/typing/tests/data/pass/arrayterator.py +28 -0
  813. numpy/typing/tests/data/pass/bitwise_ops.py +131 -0
  814. numpy/typing/tests/data/pass/comparisons.py +316 -0
  815. numpy/typing/tests/data/pass/dtype.py +57 -0
  816. numpy/typing/tests/data/pass/einsumfunc.py +36 -0
  817. numpy/typing/tests/data/pass/flatiter.py +26 -0
  818. numpy/typing/tests/data/pass/fromnumeric.py +272 -0
  819. numpy/typing/tests/data/pass/index_tricks.py +62 -0
  820. numpy/typing/tests/data/pass/lib_user_array.py +22 -0
  821. numpy/typing/tests/data/pass/lib_utils.py +19 -0
  822. numpy/typing/tests/data/pass/lib_version.py +18 -0
  823. numpy/typing/tests/data/pass/literal.py +52 -0
  824. numpy/typing/tests/data/pass/ma.py +199 -0
  825. numpy/typing/tests/data/pass/mod.py +149 -0
  826. numpy/typing/tests/data/pass/modules.py +45 -0
  827. numpy/typing/tests/data/pass/multiarray.py +77 -0
  828. numpy/typing/tests/data/pass/ndarray_conversion.py +81 -0
  829. numpy/typing/tests/data/pass/ndarray_misc.py +199 -0
  830. numpy/typing/tests/data/pass/ndarray_shape_manipulation.py +47 -0
  831. numpy/typing/tests/data/pass/nditer.py +4 -0
  832. numpy/typing/tests/data/pass/numeric.py +90 -0
  833. numpy/typing/tests/data/pass/numerictypes.py +17 -0
  834. numpy/typing/tests/data/pass/random.py +1498 -0
  835. numpy/typing/tests/data/pass/recfunctions.py +164 -0
  836. numpy/typing/tests/data/pass/scalars.py +249 -0
  837. numpy/typing/tests/data/pass/shape.py +19 -0
  838. numpy/typing/tests/data/pass/simple.py +170 -0
  839. numpy/typing/tests/data/pass/ufunc_config.py +64 -0
  840. numpy/typing/tests/data/pass/ufunclike.py +52 -0
  841. numpy/typing/tests/data/pass/ufuncs.py +16 -0
  842. numpy/typing/tests/data/pass/warnings_and_errors.py +6 -0
  843. numpy/typing/tests/data/reveal/arithmetic.pyi +719 -0
  844. numpy/typing/tests/data/reveal/array_api_info.pyi +70 -0
  845. numpy/typing/tests/data/reveal/array_constructors.pyi +279 -0
  846. numpy/typing/tests/data/reveal/arraypad.pyi +27 -0
  847. numpy/typing/tests/data/reveal/arrayprint.pyi +25 -0
  848. numpy/typing/tests/data/reveal/arraysetops.pyi +74 -0
  849. numpy/typing/tests/data/reveal/arrayterator.pyi +27 -0
  850. numpy/typing/tests/data/reveal/bitwise_ops.pyi +166 -0
  851. numpy/typing/tests/data/reveal/char.pyi +225 -0
  852. numpy/typing/tests/data/reveal/chararray.pyi +138 -0
  853. numpy/typing/tests/data/reveal/comparisons.pyi +264 -0
  854. numpy/typing/tests/data/reveal/constants.pyi +14 -0
  855. numpy/typing/tests/data/reveal/ctypeslib.pyi +81 -0
  856. numpy/typing/tests/data/reveal/datasource.pyi +23 -0
  857. numpy/typing/tests/data/reveal/dtype.pyi +132 -0
  858. numpy/typing/tests/data/reveal/einsumfunc.pyi +39 -0
  859. numpy/typing/tests/data/reveal/emath.pyi +54 -0
  860. numpy/typing/tests/data/reveal/fft.pyi +37 -0
  861. numpy/typing/tests/data/reveal/flatiter.pyi +86 -0
  862. numpy/typing/tests/data/reveal/fromnumeric.pyi +347 -0
  863. numpy/typing/tests/data/reveal/getlimits.pyi +53 -0
  864. numpy/typing/tests/data/reveal/histograms.pyi +25 -0
  865. numpy/typing/tests/data/reveal/index_tricks.pyi +70 -0
  866. numpy/typing/tests/data/reveal/lib_function_base.pyi +409 -0
  867. numpy/typing/tests/data/reveal/lib_polynomial.pyi +147 -0
  868. numpy/typing/tests/data/reveal/lib_utils.pyi +17 -0
  869. numpy/typing/tests/data/reveal/lib_version.pyi +20 -0
  870. numpy/typing/tests/data/reveal/linalg.pyi +154 -0
  871. numpy/typing/tests/data/reveal/ma.pyi +1098 -0
  872. numpy/typing/tests/data/reveal/matrix.pyi +73 -0
  873. numpy/typing/tests/data/reveal/memmap.pyi +19 -0
  874. numpy/typing/tests/data/reveal/mod.pyi +178 -0
  875. numpy/typing/tests/data/reveal/modules.pyi +51 -0
  876. numpy/typing/tests/data/reveal/multiarray.pyi +197 -0
  877. numpy/typing/tests/data/reveal/nbit_base_example.pyi +20 -0
  878. numpy/typing/tests/data/reveal/ndarray_assignability.pyi +82 -0
  879. numpy/typing/tests/data/reveal/ndarray_conversion.pyi +83 -0
  880. numpy/typing/tests/data/reveal/ndarray_misc.pyi +246 -0
  881. numpy/typing/tests/data/reveal/ndarray_shape_manipulation.pyi +47 -0
  882. numpy/typing/tests/data/reveal/nditer.pyi +49 -0
  883. numpy/typing/tests/data/reveal/nested_sequence.pyi +25 -0
  884. numpy/typing/tests/data/reveal/npyio.pyi +83 -0
  885. numpy/typing/tests/data/reveal/numeric.pyi +170 -0
  886. numpy/typing/tests/data/reveal/numerictypes.pyi +16 -0
  887. numpy/typing/tests/data/reveal/polynomial_polybase.pyi +217 -0
  888. numpy/typing/tests/data/reveal/polynomial_polyutils.pyi +218 -0
  889. numpy/typing/tests/data/reveal/polynomial_series.pyi +138 -0
  890. numpy/typing/tests/data/reveal/random.pyi +1546 -0
  891. numpy/typing/tests/data/reveal/rec.pyi +171 -0
  892. numpy/typing/tests/data/reveal/scalars.pyi +191 -0
  893. numpy/typing/tests/data/reveal/shape.pyi +13 -0
  894. numpy/typing/tests/data/reveal/shape_base.pyi +52 -0
  895. numpy/typing/tests/data/reveal/stride_tricks.pyi +27 -0
  896. numpy/typing/tests/data/reveal/strings.pyi +196 -0
  897. numpy/typing/tests/data/reveal/testing.pyi +198 -0
  898. numpy/typing/tests/data/reveal/twodim_base.pyi +225 -0
  899. numpy/typing/tests/data/reveal/type_check.pyi +67 -0
  900. numpy/typing/tests/data/reveal/ufunc_config.pyi +29 -0
  901. numpy/typing/tests/data/reveal/ufunclike.pyi +31 -0
  902. numpy/typing/tests/data/reveal/ufuncs.pyi +142 -0
  903. numpy/typing/tests/data/reveal/warnings_and_errors.pyi +11 -0
  904. numpy/typing/tests/test_isfile.py +38 -0
  905. numpy/typing/tests/test_runtime.py +110 -0
  906. numpy/typing/tests/test_typing.py +205 -0
  907. numpy/version.py +11 -0
  908. numpy/version.pyi +9 -0
  909. numpy-2.4.2.dist-info/METADATA +139 -0
  910. numpy-2.4.2.dist-info/RECORD +929 -0
  911. numpy-2.4.2.dist-info/WHEEL +4 -0
  912. numpy-2.4.2.dist-info/entry_points.txt +13 -0
  913. numpy-2.4.2.dist-info/licenses/LICENSE.txt +914 -0
  914. numpy-2.4.2.dist-info/licenses/numpy/_core/include/numpy/libdivide/LICENSE.txt +21 -0
  915. numpy-2.4.2.dist-info/licenses/numpy/_core/src/common/pythoncapi-compat/COPYING +14 -0
  916. numpy-2.4.2.dist-info/licenses/numpy/_core/src/highway/LICENSE +371 -0
  917. numpy-2.4.2.dist-info/licenses/numpy/_core/src/multiarray/dragon4_LICENSE.txt +27 -0
  918. numpy-2.4.2.dist-info/licenses/numpy/_core/src/npysort/x86-simd-sort/LICENSE.md +28 -0
  919. numpy-2.4.2.dist-info/licenses/numpy/_core/src/umath/svml/LICENSE +30 -0
  920. numpy-2.4.2.dist-info/licenses/numpy/fft/pocketfft/LICENSE.md +25 -0
  921. numpy-2.4.2.dist-info/licenses/numpy/linalg/lapack_lite/LICENSE.txt +48 -0
  922. numpy-2.4.2.dist-info/licenses/numpy/ma/LICENSE +24 -0
  923. numpy-2.4.2.dist-info/licenses/numpy/random/LICENSE.md +71 -0
  924. numpy-2.4.2.dist-info/licenses/numpy/random/src/distributions/LICENSE.md +61 -0
  925. numpy-2.4.2.dist-info/licenses/numpy/random/src/mt19937/LICENSE.md +61 -0
  926. numpy-2.4.2.dist-info/licenses/numpy/random/src/pcg64/LICENSE.md +22 -0
  927. numpy-2.4.2.dist-info/licenses/numpy/random/src/philox/LICENSE.md +31 -0
  928. numpy-2.4.2.dist-info/licenses/numpy/random/src/sfc64/LICENSE.md +27 -0
  929. numpy-2.4.2.dist-info/licenses/numpy/random/src/splitmix64/LICENSE.md +9 -0
@@ -0,0 +1,1650 @@
1
+ """
2
+ Implementation of optimized einsum.
3
+
4
+ """
5
+ import functools
6
+ import itertools
7
+ import operator
8
+
9
+ from numpy._core.multiarray import c_einsum, matmul
10
+ from numpy._core.numeric import asanyarray, reshape
11
+ from numpy._core.overrides import array_function_dispatch
12
+ from numpy._core.umath import multiply
13
+
14
+ __all__ = ['einsum', 'einsum_path']
15
+
16
# The letters are spelled out instead of using ``string.ascii_letters``:
# importing ``string`` would be too slow, the first import before caching
# has been measured to take 800 microseconds (#23777).
# The symbols begin with the uppercase letters so that their order mimics
# ASCII values, which avoids sorting issues.
einsum_symbols = (
    'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    'abcdefghijklmnopqrstuvwxyz'
)
einsum_symbols_set = set(einsum_symbols)
21
+
22
+
23
def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
    """
    Estimate the number of FLOPS needed for a contraction.

    Parameters
    ----------
    idx_contraction : iterable
        The indices involved in the contraction
    inner : bool
        Whether the contraction requires an inner product. Only the
        truthiness of the value is used.
    num_terms : int
        The number of terms in a contraction
    size_dictionary : dict
        The size of each of the indices in idx_contraction

    Returns
    -------
    flop_count : int
        The total number of FLOPS required for the contraction.

    Examples
    --------

    >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
    30

    >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
    60

    """
    iteration_space = _compute_size_by_dict(idx_contraction, size_dictionary)

    # One operation per extra term (never fewer than one), plus one more
    # when an inner product has to be accumulated.
    ops_per_element = max(1, num_terms - 1) + (1 if inner else 0)

    return iteration_space * ops_per_element
60
+
61
+ def _compute_size_by_dict(indices, idx_dict):
62
+ """
63
+ Computes the product of the elements in indices based on the dictionary
64
+ idx_dict.
65
+
66
+ Parameters
67
+ ----------
68
+ indices : iterable
69
+ Indices to base the product on.
70
+ idx_dict : dictionary
71
+ Dictionary of index sizes
72
+
73
+ Returns
74
+ -------
75
+ ret : int
76
+ The resulting product.
77
+
78
+ Examples
79
+ --------
80
+ >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
81
+ 90
82
+
83
+ """
84
+ ret = 1
85
+ for i in indices:
86
+ ret *= idx_dict[i]
87
+ return ret
88
+
89
+
90
+ def _find_contraction(positions, input_sets, output_set):
91
+ """
92
+ Finds the contraction for a given set of input and output sets.
93
+
94
+ Parameters
95
+ ----------
96
+ positions : iterable
97
+ Integer positions of terms used in the contraction.
98
+ input_sets : list
99
+ List of sets that represent the lhs side of the einsum subscript
100
+ output_set : set
101
+ Set that represents the rhs side of the overall einsum subscript
102
+
103
+ Returns
104
+ -------
105
+ new_result : set
106
+ The indices of the resulting contraction
107
+ remaining : list
108
+ List of sets that have not been contracted, the new set is appended to
109
+ the end of this list
110
+ idx_removed : set
111
+ Indices removed from the entire contraction
112
+ idx_contraction : set
113
+ The indices used in the current contraction
114
+
115
+ Examples
116
+ --------
117
+
118
+ # A simple dot product test case
119
+ >>> pos = (0, 1)
120
+ >>> isets = [set('ab'), set('bc')]
121
+ >>> oset = set('ac')
122
+ >>> _find_contraction(pos, isets, oset)
123
+ ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})
124
+
125
+ # A more complex case with additional terms in the contraction
126
+ >>> pos = (0, 2)
127
+ >>> isets = [set('abd'), set('ac'), set('bdc')]
128
+ >>> oset = set('ac')
129
+ >>> _find_contraction(pos, isets, oset)
130
+ ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
131
+ """
132
+
133
+ idx_contract = set()
134
+ idx_remain = output_set.copy()
135
+ remaining = []
136
+ for ind, value in enumerate(input_sets):
137
+ if ind in positions:
138
+ idx_contract |= value
139
+ else:
140
+ remaining.append(value)
141
+ idx_remain |= value
142
+
143
+ new_result = idx_remain & idx_contract
144
+ idx_removed = (idx_contract - new_result)
145
+ remaining.append(new_result)
146
+
147
+ return (new_result, remaining, idx_removed, idx_contract)
148
+
149
+
150
def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Computes all possible pair contractions, sieves the results based
    on ``memory_limit`` and returns the lowest cost path. This algorithm
    scales factorial with respect to the elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The optimal contraction order within the memory limit constraint.
        If at some depth no pair contraction fits within ``memory_limit``,
        the cheapest path found so far is returned with one final
        contraction of all remaining terms appended.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _optimal_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    num_inputs = len(input_sets)

    # Each candidate is (accumulated cost, path so far, remaining index sets).
    # The list is never empty: it starts with one entry and is only ever
    # replaced by a non-empty list of extended candidates.
    full_results = [(0, [], input_sets)]

    for iteration in range(num_inputs - 1):
        iter_results = []

        # Every candidate at this depth has the same number of remaining
        # terms, so the set of position pairs is computed once per depth.
        pairs = tuple(
            itertools.combinations(range(num_inputs - iteration), 2)
        )

        for cost, positions, remaining in full_results:
            for con in pairs:

                # Find the contraction
                cont = _find_contraction(con, remaining, output_set)
                new_result, new_input_sets, idx_removed, idx_contract = cont

                # Sieve the results based on memory_limit
                new_size = _compute_size_by_dict(new_result, idx_dict)
                if new_size > memory_limit:
                    continue

                # Build (total_cost, positions, indices_remaining)
                total_cost = cost + _flop_count(
                    idx_contract, idx_removed, len(con), idx_dict
                )
                iter_results.append(
                    (total_cost, positions + [con], new_input_sets)
                )

        if not iter_results:
            # No pair contraction fits in memory: take the cheapest path
            # found so far and contract everything that is left in one step.
            path = min(full_results, key=lambda x: x[0])[1]
            return path + [tuple(range(num_inputs - iteration))]

        full_results = iter_results

    return min(full_results, key=lambda x: x[0])[1]
223
+
224
def _parse_possible_contraction(
        positions, input_sets, output_set, idx_dict,
        memory_limit, path_cost, naive_cost
):
    """Evaluate the contraction given by ``positions``: compute its cost
    (removed size + flops) and the index sets that would remain afterwards.

    Parameters
    ----------
    positions : tuple of int
        The locations of the proposed tensors to contract.
    input_sets : list of sets
        The indices found on each tensors.
    output_set : set
        The output indices of the expression.
    idx_dict : dict
        Mapping of each index to its size.
    memory_limit : int
        The total allowed size for an intermediary tensor.
    path_cost : int
        The contraction cost so far.
    naive_cost : int
        The cost of the unoptimized expression.

    Returns
    -------
    cost : (int, int)
        A tuple containing the negated size removed by the contraction
        (summed input sizes minus result size), and the flop cost.
    positions : tuple of int
        The locations of the proposed tensors to contract.
    new_input_sets : list of sets
        The resulting new list of indices if this proposed contraction
        is performed.

    The three values are returned together in a list. ``None`` is returned
    instead when the result would exceed ``memory_limit`` or when the path
    would become more expensive than ``naive_cost``.

    """
    idx_result, new_input_sets, idx_removed, idx_contract = (
        _find_contraction(positions, input_sets, output_set)
    )

    # Reject intermediates that do not fit in memory
    new_size = _compute_size_by_dict(idx_result, idx_dict)
    if new_size > memory_limit:
        return None

    # Reject contractions that make the path costlier than the naive one
    cost = _flop_count(idx_contract, idx_removed, len(positions), idx_dict)
    if (path_cost + cost) > naive_cost:
        return None

    # The sort key favours the largest reduction in total size, then the
    # smallest flop count.
    combined_input_size = sum(
        _compute_size_by_dict(input_sets[p], idx_dict) for p in positions
    )
    removed_size = combined_input_size - new_size

    return [(-removed_size, cost), positions, new_input_sets]
286
+
287
+
288
+ def _update_other_results(results, best):
289
+ """Update the positions and provisional input_sets of ``results``
290
+ based on performing the contraction result ``best``. Remove any
291
+ involving the tensors contracted.
292
+
293
+ Parameters
294
+ ----------
295
+ results : list
296
+ List of contraction results produced by
297
+ ``_parse_possible_contraction``.
298
+ best : list
299
+ The best contraction of ``results`` i.e. the one that
300
+ will be performed.
301
+
302
+ Returns
303
+ -------
304
+ mod_results : list
305
+ The list of modified results, updated with outcome of
306
+ ``best`` contraction.
307
+ """
308
+
309
+ best_con = best[1]
310
+ bx, by = best_con
311
+ mod_results = []
312
+
313
+ for cost, (x, y), con_sets in results:
314
+
315
+ # Ignore results involving tensors just contracted
316
+ if x in best_con or y in best_con:
317
+ continue
318
+
319
+ # Update the input_sets
320
+ del con_sets[by - int(by > x) - int(by > y)]
321
+ del con_sets[bx - int(bx > x) - int(bx > y)]
322
+ con_sets.insert(-1, best[2][-1])
323
+
324
+ # Update the position indices
325
+ mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by)
326
+ mod_results.append((cost, mod_con, con_sets))
327
+
328
+ return mod_results
329
+
330
def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Builds a path by repeatedly contracting the best available pair until
    the input list is exhausted. The best pair is the one minimizing the
    tuple ``(-prod(indices_removed), cost)``. What this amounts to is
    prioritizing matrix multiplication or inner product operations, then
    Hadamard like operations, and finally outer operations. Outer products
    are limited by ``memory_limit``. This algorithm scales cubically with
    respect to the number of elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The greedy contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _greedy_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    num_inputs = len(input_sets)

    # One or two terms leave nothing to optimize
    if num_inputs == 1:
        return [(0,)]
    if num_inputs == 2:
        return [(0, 1)]

    # Cost of contracting every term in a single step; no path may
    # become more expensive than this.
    _, _, removed_all, contract_all = _find_contraction(
        range(num_inputs), input_sets, output_set
    )
    naive_cost = _flop_count(contract_all, removed_all, num_inputs, idx_dict)

    # The first pass looks at every pair; later passes only look at pairs
    # involving the most recently created tensor.
    pending_pairs = itertools.combinations(range(num_inputs), 2)
    candidates = []

    running_cost = 0
    path = []

    for _ in range(num_inputs - 1):

        for pair in pending_pairs:

            # Outer products are not considered at first
            if input_sets[pair[0]].isdisjoint(input_sets[pair[1]]):
                continue

            parsed = _parse_possible_contraction(
                pair, input_sets, output_set, idx_dict,
                memory_limit, running_cost, naive_cost
            )
            if parsed is not None:
                candidates.append(parsed)

        if not candidates:

            # No inner contraction is available: rescan all pairs, this
            # time allowing outer products.
            for pair in itertools.combinations(range(len(input_sets)), 2):
                parsed = _parse_possible_contraction(
                    pair, input_sets, output_set, idx_dict,
                    memory_limit, running_cost, naive_cost
                )
                if parsed is not None:
                    candidates.append(parsed)

            # Still nothing usable: contract everything that is left in
            # one einsum-like step and stop.
            if not candidates:
                path.append(tuple(range(len(input_sets))))
                break

        # Candidates are ranked by their leading sort tuple
        best = min(candidates, key=lambda entry: entry[0])

        # Keep the candidates that remain valid for the next pass
        candidates = _update_other_results(candidates, best)

        # Only pairs with the new tensor (stored last) still need to be
        # examined; all other pairs are already accounted for.
        input_sets = best[2]
        newest = len(input_sets) - 1
        pending_pairs = ((i, newest) for i in range(newest))

        # Record the step and its flop cost
        path.append(best[1])
        running_cost += best[0][1]

    return path
443
+
444
+
445
+ def _parse_einsum_input(operands):
446
+ """
447
+ A reproduction of einsum c side einsum parsing in python.
448
+
449
+ Returns
450
+ -------
451
+ input_strings : str
452
+ Parsed input strings
453
+ output_string : str
454
+ Parsed output string
455
+ operands : list of array_like
456
+ The operands to use in the numpy contraction
457
+
458
+ Examples
459
+ --------
460
+ The operand list is simplified to reduce printing:
461
+
462
+ >>> np.random.seed(123)
463
+ >>> a = np.random.rand(4, 4)
464
+ >>> b = np.random.rand(4, 4, 4)
465
+ >>> _parse_einsum_input(('...a,...a->...', a, b))
466
+ ('za,xza', 'xz', [a, b]) # may vary
467
+
468
+ >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
469
+ ('za,xza', 'xz', [a, b]) # may vary
470
+ """
471
+
472
+ if len(operands) == 0:
473
+ raise ValueError("No input operands")
474
+
475
+ if isinstance(operands[0], str):
476
+ subscripts = operands[0].replace(" ", "")
477
+ operands = [asanyarray(v) for v in operands[1:]]
478
+
479
+ # Ensure all characters are valid
480
+ for s in subscripts:
481
+ if s in '.,->':
482
+ continue
483
+ if s not in einsum_symbols:
484
+ raise ValueError(f"Character {s} is not a valid symbol.")
485
+
486
+ else:
487
+ tmp_operands = list(operands)
488
+ operand_list = []
489
+ subscript_list = []
490
+ for p in range(len(operands) // 2):
491
+ operand_list.append(tmp_operands.pop(0))
492
+ subscript_list.append(tmp_operands.pop(0))
493
+
494
+ output_list = tmp_operands[-1] if len(tmp_operands) else None
495
+ operands = [asanyarray(v) for v in operand_list]
496
+ subscripts = ""
497
+ last = len(subscript_list) - 1
498
+ for num, sub in enumerate(subscript_list):
499
+ for s in sub:
500
+ if s is Ellipsis:
501
+ subscripts += "..."
502
+ else:
503
+ try:
504
+ s = operator.index(s)
505
+ except TypeError as e:
506
+ raise TypeError(
507
+ "For this input type lists must contain "
508
+ "either int or Ellipsis"
509
+ ) from e
510
+ subscripts += einsum_symbols[s]
511
+ if num != last:
512
+ subscripts += ","
513
+
514
+ if output_list is not None:
515
+ subscripts += "->"
516
+ for s in output_list:
517
+ if s is Ellipsis:
518
+ subscripts += "..."
519
+ else:
520
+ try:
521
+ s = operator.index(s)
522
+ except TypeError as e:
523
+ raise TypeError(
524
+ "For this input type lists must contain "
525
+ "either int or Ellipsis"
526
+ ) from e
527
+ subscripts += einsum_symbols[s]
528
+ # Check for proper "->"
529
+ if ("-" in subscripts) or (">" in subscripts):
530
+ invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
531
+ if invalid or (subscripts.count("->") != 1):
532
+ raise ValueError("Subscripts can only contain one '->'.")
533
+
534
+ # Parse ellipses
535
+ if "." in subscripts:
536
+ used = subscripts.replace(".", "").replace(",", "").replace("->", "")
537
+ unused = list(einsum_symbols_set - set(used))
538
+ ellipse_inds = "".join(unused)
539
+ longest = 0
540
+
541
+ if "->" in subscripts:
542
+ input_tmp, output_sub = subscripts.split("->")
543
+ split_subscripts = input_tmp.split(",")
544
+ out_sub = True
545
+ else:
546
+ split_subscripts = subscripts.split(',')
547
+ out_sub = False
548
+
549
+ for num, sub in enumerate(split_subscripts):
550
+ if "." in sub:
551
+ if (sub.count(".") != 3) or (sub.count("...") != 1):
552
+ raise ValueError("Invalid Ellipses.")
553
+
554
+ # Take into account numerical values
555
+ if operands[num].shape == ():
556
+ ellipse_count = 0
557
+ else:
558
+ ellipse_count = max(operands[num].ndim, 1)
559
+ ellipse_count -= (len(sub) - 3)
560
+
561
+ if ellipse_count > longest:
562
+ longest = ellipse_count
563
+
564
+ if ellipse_count < 0:
565
+ raise ValueError("Ellipses lengths do not match.")
566
+ elif ellipse_count == 0:
567
+ split_subscripts[num] = sub.replace('...', '')
568
+ else:
569
+ rep_inds = ellipse_inds[-ellipse_count:]
570
+ split_subscripts[num] = sub.replace('...', rep_inds)
571
+
572
+ subscripts = ",".join(split_subscripts)
573
+ if longest == 0:
574
+ out_ellipse = ""
575
+ else:
576
+ out_ellipse = ellipse_inds[-longest:]
577
+
578
+ if out_sub:
579
+ subscripts += "->" + output_sub.replace("...", out_ellipse)
580
+ else:
581
+ # Special care for outputless ellipses
582
+ output_subscript = ""
583
+ tmp_subscripts = subscripts.replace(",", "")
584
+ for s in sorted(set(tmp_subscripts)):
585
+ if s not in (einsum_symbols):
586
+ raise ValueError(f"Character {s} is not a valid symbol.")
587
+ if tmp_subscripts.count(s) == 1:
588
+ output_subscript += s
589
+ normal_inds = ''.join(sorted(set(output_subscript) -
590
+ set(out_ellipse)))
591
+
592
+ subscripts += "->" + out_ellipse + normal_inds
593
+
594
+ # Build output string if does not exist
595
+ if "->" in subscripts:
596
+ input_subscripts, output_subscript = subscripts.split("->")
597
+ else:
598
+ input_subscripts = subscripts
599
+ # Build output subscripts
600
+ tmp_subscripts = subscripts.replace(",", "")
601
+ output_subscript = ""
602
+ for s in sorted(set(tmp_subscripts)):
603
+ if s not in einsum_symbols:
604
+ raise ValueError(f"Character {s} is not a valid symbol.")
605
+ if tmp_subscripts.count(s) == 1:
606
+ output_subscript += s
607
+
608
+ # Make sure output subscripts are in the input
609
+ for char in output_subscript:
610
+ if output_subscript.count(char) != 1:
611
+ raise ValueError("Output character %s appeared more than once in "
612
+ "the output." % char)
613
+ if char not in input_subscripts:
614
+ raise ValueError(f"Output character {char} did not appear in the input")
615
+
616
+ # Make sure number operands is equivalent to the number of terms
617
+ if len(input_subscripts.split(',')) != len(operands):
618
+ raise ValueError("Number of einsum subscripts must be equal to the "
619
+ "number of operands.")
620
+
621
+ return (input_subscripts, output_subscript, operands)
622
+
623
+
624
def _einsum_path_dispatcher(*operands, optimize=None, einsum_call=None):
    """Dispatcher paired with `einsum_path` via ``array_function_dispatch``.

    Returns every positional argument unchanged; ``optimize`` and
    ``einsum_call`` are accepted only so the signature lines up with
    `einsum_path` and are otherwise ignored.
    """
    # NOTE: technically, we should only dispatch on array-like arguments, not
    # subscripts (given as strings). But separating operands into
    # arrays/subscripts is a little tricky/slow (given einsum's two supported
    # signatures), so as a practical shortcut we dispatch on everything.
    # Strings will be ignored for dispatching since they don't define
    # __array_function__.
    return operands
632
+
633
+
634
@array_function_dispatch(_einsum_path_dispatcher, module='numpy')
def einsum_path(*operands, optimize='greedy', einsum_call=False):
    """
    einsum_path(subscripts, *operands, optimize='greedy')

    Evaluates the lowest cost contraction order for an einsum expression by
    considering the creation of intermediate arrays.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation.
    *operands : list of array_like
        These are the arrays for the operation.
    optimize : {bool, list, tuple, 'greedy', 'optimal'}
        Choose the type of path. If a tuple is provided, the second argument is
        assumed to be the maximum intermediate size created. If only a single
        argument is provided the largest input or output array size is used
        as a maximum intermediate size.

        * if a list is given that starts with ``einsum_path``, uses this as the
          contraction path
        * if False no optimization is taken
        * if True defaults to the 'greedy' algorithm
        * 'optimal' An algorithm that combinatorially explores all possible
          ways of contracting the listed tensors and chooses the least costly
          path. Scales exponentially with the number of terms in the
          contraction.
        * 'greedy' An algorithm that chooses the best pair contraction
          at each step. Effectively, this algorithm searches the largest inner,
          Hadamard, and then outer products at each step. Scales cubically with
          the number of terms in the contraction. Equivalent to the 'optimal'
          path for most contractions.

        Default is 'greedy'.

    Returns
    -------
    path : list of tuples
        A list representation of the einsum path.
    string_repr : str
        A printable representation of the einsum path.

    Raises
    ------
    TypeError
        If ``optimize`` is not one of the understood forms.
    ValueError
        If a subscript term does not have as many labels as its operand has
        dimensions, or if a label is used with incompatible sizes.
    KeyError
        If ``optimize`` names a path algorithm other than 'greedy' or
        'optimal'.
    RuntimeError
        If an explicit path does not contract the operands down to a single
        result.

    Notes
    -----
    The resulting path indicates which terms of the input contraction should be
    contracted first, the result of this contraction is then appended to the
    end of the contraction list. This list can then be iterated over until all
    intermediate contractions are complete.

    The ``einsum_call`` keyword is a hidden option intended for `einsum` only:
    when it is true the function returns ``(operands, contraction_list)``
    instead of the path and its printable summary.

    See Also
    --------
    einsum, linalg.multi_dot

    Examples
    --------

    We can begin with a chain dot example. In this case, it is optimal to
    contract the ``b`` and ``c`` tensors first as represented by the first
    element of the path ``(1, 2)``. The resulting tensor is added to the end
    of the contraction and the remaining contraction ``(0, 1)`` is then
    completed.

    >>> np.random.seed(123)
    >>> a = np.random.rand(2, 2)
    >>> b = np.random.rand(2, 5)
    >>> c = np.random.rand(5, 2)
    >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
    >>> print(path_info[0])
    ['einsum_path', (1, 2), (0, 1)]
    >>> print(path_info[1])
    Complete contraction: ij,jk,kl->il # may vary
    Naive scaling: 4
    Optimized scaling: 3
    Naive FLOP count: 1.600e+02
    Optimized FLOP count: 5.600e+01
    Theoretical speedup: 2.857
    Largest intermediate: 4.000e+00 elements
    -------------------------------------------------------------------------
    scaling current remaining
    -------------------------------------------------------------------------
    3 kl,jk->jl ij,jl->il
    3 jl,ij->il il->il


    A more complex index transformation example.

    >>> I = np.random.rand(10, 10, 10, 10)
    >>> C = np.random.rand(10, 10)
    >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C,
    ...                            optimize='greedy')

    >>> print(path_info[0])
    ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
    >>> print(path_info[1])
    Complete contraction: ea,fb,abcd,gc,hd->efgh # may vary
    Naive scaling: 8
    Optimized scaling: 5
    Naive FLOP count: 8.000e+08
    Optimized FLOP count: 8.000e+05
    Theoretical speedup: 1000.000
    Largest intermediate: 1.000e+04 elements
    --------------------------------------------------------------------------
    scaling current remaining
    --------------------------------------------------------------------------
    5 abcd,ea->bcde fb,gc,hd,bcde->efgh
    5 bcde,fb->cdef gc,hd,cdef->efgh
    5 cdef,gc->defg hd,defg->efgh
    5 defg,hd->efgh efgh->efgh
    """

    # Figure out what the path really is
    path_type = optimize
    if path_type is True:
        path_type = 'greedy'
    if path_type is None:
        path_type = False

    explicit_einsum_path = False
    memory_limit = None

    # No optimization or a named path algorithm
    if (path_type is False) or isinstance(path_type, str):
        pass

    # Given an explicit path
    elif len(path_type) and (path_type[0] == 'einsum_path'):
        explicit_einsum_path = True

    # Path tuple with memory limit
    elif ((len(path_type) == 2) and isinstance(path_type[0], str) and
            isinstance(path_type[1], (int, float))):
        memory_limit = int(path_type[1])
        path_type = path_type[0]

    else:
        raise TypeError(f"Did not understand the path: {str(path_type)}")

    # Hidden option, only einsum should call this
    einsum_call_arg = einsum_call

    # Python side parsing
    input_subscripts, output_subscript, operands = (
        _parse_einsum_input(operands)
    )

    # Build a few useful list and sets
    input_list = input_subscripts.split(',')
    num_inputs = len(input_list)
    input_sets = [set(x) for x in input_list]
    output_set = set(output_subscript)
    indices = set(input_subscripts.replace(',', ''))
    num_indices = len(indices)

    # Get length of each unique dimension and ensure all dimensions are correct
    dimension_dict = {}
    for tnum, term in enumerate(input_list):
        sh = operands[tnum].shape
        if len(sh) != len(term):
            # Report the whole offending term; indexing the joined subscript
            # string by operand number would only yield a single character.
            raise ValueError("Einstein sum subscript %s does not contain the "
                             "correct number of indices for operand %d."
                             % (term, tnum))
        for char, dim in zip(term, sh):
            if char in dimension_dict:
                # For broadcasting cases we always want the largest dim size
                if dimension_dict[char] == 1:
                    dimension_dict[char] = dim
                elif dim not in (1, dimension_dict[char]):
                    # First size is this operand's, second is the size
                    # recorded from the previous terms.
                    raise ValueError("Size of label '%s' for operand %d (%d) "
                                     "does not match previous terms (%d)."
                                     % (char, tnum, dim, dimension_dict[char]))
            else:
                dimension_dict[char] = dim

    # Compute size of each input array plus the output array
    size_list = [_compute_size_by_dict(term, dimension_dict)
                 for term in input_list + [output_subscript]]
    max_size = max(size_list)

    if memory_limit is None:
        memory_arg = max_size
    else:
        memory_arg = memory_limit

    # Compute the path
    if explicit_einsum_path:
        # Copy into a list so a tuple-typed explicit path can still be
        # concatenated with the 'einsum_path' marker when it is returned.
        path = list(path_type[1:])
    elif (
        (path_type is False)
        or (num_inputs in [1, 2])
        or (indices == output_set)
    ):
        # Nothing to be optimized, leave it to einsum
        path = [tuple(range(num_inputs))]
    elif path_type == "greedy":
        path = _greedy_path(
            input_sets, output_set, dimension_dict, memory_arg
        )
    elif path_type == "optimal":
        path = _optimal_path(
            input_sets, output_set, dimension_dict, memory_arg
        )
    else:
        raise KeyError(f"Path name {path_type} not found")

    cost_list, scale_list, size_list, contraction_list = [], [], [], []

    # Build contraction tuple (positions, einsum_str, remaining)
    last_step = len(path) - 1
    for cnum, contract_inds in enumerate(path):
        # Make sure we remove inds from right to left
        contract_inds = tuple(sorted(contract_inds, reverse=True))

        contract = _find_contraction(contract_inds, input_sets, output_set)
        out_inds, input_sets, idx_removed, idx_contract = contract

        if not einsum_call_arg:
            # these are only needed for printing info
            cost = _flop_count(
                idx_contract, idx_removed, len(contract_inds), dimension_dict
            )
            cost_list.append(cost)
            scale_list.append(len(idx_contract))
            size_list.append(_compute_size_by_dict(out_inds, dimension_dict))

        # Popping in descending position order keeps the remaining
        # positions valid.
        tmp_inputs = [input_list.pop(x) for x in contract_inds]

        if cnum == last_step:
            # Last contraction
            idx_result = output_subscript
        else:
            sort_result = [(dimension_dict[ind], ind) for ind in out_inds]
            idx_result = "".join([x[1] for x in sorted(sort_result)])

        input_list.append(idx_result)
        einsum_str = ",".join(tmp_inputs) + "->" + idx_result

        contraction = (contract_inds, einsum_str, input_list[:])
        contraction_list.append(contraction)

    if len(input_list) != 1:
        # Explicit "einsum_path" is usually trusted, but we detect this kind of
        # mistake in order to prevent from returning an intermediate value.
        raise RuntimeError(
            f"Invalid einsum_path is specified: {len(input_list) - 1} more "
            "operands has to be contracted.")

    if einsum_call_arg:
        return (operands, contraction_list)

    # Return the path along with a nice string representation
    overall_contraction = input_subscripts + "->" + output_subscript
    header = ("scaling", "current", "remaining")

    # Compute naive cost
    # This isn't quite right, need to look into exactly how einsum does this
    inner_product = (
        sum(len(set(x)) for x in input_subscripts.split(',')) - num_indices
    ) > 0
    naive_cost = _flop_count(
        indices, inner_product, num_inputs, dimension_dict
    )

    opt_cost = sum(cost_list) + 1
    speedup = naive_cost / opt_cost
    max_i = max(size_list)

    # NOTE(review): the label padding inside these literals is reproduced
    # exactly as found in this copy of the file; confirm the intended column
    # alignment before relying on the printed layout.
    path_print = f" Complete contraction: {overall_contraction}\n"
    path_print += f" Naive scaling: {num_indices}\n"
    path_print += " Optimized scaling: %d\n" % max(scale_list)
    path_print += f" Naive FLOP count: {naive_cost:.3e}\n"
    path_print += f" Optimized FLOP count: {opt_cost:.3e}\n"
    path_print += f" Theoretical speedup: {speedup:3.3f}\n"
    path_print += f" Largest intermediate: {max_i:.3e} elements\n"
    path_print += "-" * 74 + "\n"
    path_print += "%6s %24s %40s\n" % header
    path_print += "-" * 74

    for n, contraction in enumerate(contraction_list):
        _, einsum_str, remaining = contraction
        remaining_str = ",".join(remaining) + "->" + output_subscript
        path_run = (scale_list[n], einsum_str, remaining_str)
        path_print += "\n%4d %24s %40s" % path_run

    path = ['einsum_path'] + path
    return (path, path_print)
923
+
924
+
925
+ def _parse_eq_to_pure_multiplication(a_term, shape_a, b_term, shape_b, out):
926
+ """If there are no contracted indices, then we can directly transpose and
927
+ insert singleton dimensions into ``a`` and ``b`` such that (broadcast)
928
+ elementwise multiplication performs the einsum.
929
+
930
+ No need to cache this as it is within the cached
931
+ ``_parse_eq_to_batch_matmul``.
932
+
933
+ """
934
+ desired_a = ""
935
+ desired_b = ""
936
+ new_shape_a = []
937
+ new_shape_b = []
938
+ for ix in out:
939
+ if ix in a_term:
940
+ desired_a += ix
941
+ new_shape_a.append(shape_a[a_term.index(ix)])
942
+ else:
943
+ new_shape_a.append(1)
944
+ if ix in b_term:
945
+ desired_b += ix
946
+ new_shape_b.append(shape_b[b_term.index(ix)])
947
+ else:
948
+ new_shape_b.append(1)
949
+
950
+ if desired_a != a_term:
951
+ eq_a = f"{a_term}->{desired_a}"
952
+ else:
953
+ eq_a = None
954
+ if desired_b != b_term:
955
+ eq_b = f"{b_term}->{desired_b}"
956
+ else:
957
+ eq_b = None
958
+
959
+ return (
960
+ eq_a,
961
+ eq_b,
962
+ new_shape_a,
963
+ new_shape_b,
964
+ None, # new_shape_ab, not needed since not fusing
965
+ None, # perm_ab, not needed as we transpose a and b first
966
+ True, # pure_multiplication=True
967
+ )
968
+
969
+
970
+ @functools.lru_cache(2**12)
971
+ def _parse_eq_to_batch_matmul(eq, shape_a, shape_b):
972
+ """Cached parsing of a two term einsum equation into the necessary
973
+ sequence of arguments for contracttion via batched matrix multiplication.
974
+ The steps we need to specify are:
975
+
976
+ 1. Remove repeated and trivial indices from the left and right terms,
977
+ and transpose them, done as a single einsum.
978
+ 2. Fuse the remaining indices so we have two 3D tensors.
979
+ 3. Perform the batched matrix multiplication.
980
+ 4. Unfuse the output to get the desired final index order.
981
+
982
+ """
983
+ lhs, out = eq.split("->")
984
+ a_term, b_term = lhs.split(",")
985
+
986
+ if len(a_term) != len(shape_a):
987
+ raise ValueError(f"Term '{a_term}' does not match shape {shape_a}.")
988
+ if len(b_term) != len(shape_b):
989
+ raise ValueError(f"Term '{b_term}' does not match shape {shape_b}.")
990
+
991
+ sizes = {}
992
+ singletons = set()
993
+
994
+ # parse left term to unique indices with size > 1
995
+ left = {}
996
+ for ix, d in zip(a_term, shape_a):
997
+ if d == 1:
998
+ # everything (including broadcasting) works nicely if simply ignore
999
+ # such dimensions, but we do need to track if they appear in output
1000
+ # and thus should be reintroduced later
1001
+ singletons.add(ix)
1002
+ continue
1003
+ if sizes.setdefault(ix, d) != d:
1004
+ # set and check size
1005
+ raise ValueError(
1006
+ f"Index {ix} has mismatched sizes {sizes[ix]} and {d}."
1007
+ )
1008
+ left[ix] = True
1009
+
1010
+ # parse right term to unique indices with size > 1
1011
+ right = {}
1012
+ for ix, d in zip(b_term, shape_b):
1013
+ # broadcast indices (size 1 on one input and size != 1
1014
+ # on the other) should not be treated as singletons
1015
+ if d == 1:
1016
+ if ix not in left:
1017
+ singletons.add(ix)
1018
+ continue
1019
+ singletons.discard(ix)
1020
+
1021
+ if sizes.setdefault(ix, d) != d:
1022
+ # set and check size
1023
+ raise ValueError(
1024
+ f"Index {ix} has mismatched sizes {sizes[ix]} and {d}."
1025
+ )
1026
+ right[ix] = True
1027
+
1028
+ # now we classify the unique size > 1 indices only
1029
+ bat_inds = [] # appears on A, B, O
1030
+ con_inds = [] # appears on A, B, .
1031
+ a_keep = [] # appears on A, ., O
1032
+ b_keep = [] # appears on ., B, O
1033
+ # other indices (appearing on A or B only) will
1034
+ # be summed or traced out prior to the matmul
1035
+ for ix in left:
1036
+ if right.pop(ix, False):
1037
+ if ix in out:
1038
+ bat_inds.append(ix)
1039
+ else:
1040
+ con_inds.append(ix)
1041
+ elif ix in out:
1042
+ a_keep.append(ix)
1043
+ # now only indices unique to right remain
1044
+ for ix in right:
1045
+ if ix in out:
1046
+ b_keep.append(ix)
1047
+
1048
+ if not con_inds:
1049
+ # contraction is pure multiplication, prepare inputs differently
1050
+ return _parse_eq_to_pure_multiplication(
1051
+ a_term, shape_a, b_term, shape_b, out
1052
+ )
1053
+
1054
+ # only need the size one indices that appear in the output
1055
+ singletons = [ix for ix in out if ix in singletons]
1056
+
1057
+ # take diagonal, remove any trivial axes and transpose left
1058
+ desired_a = "".join((*bat_inds, *a_keep, *con_inds))
1059
+ if a_term != desired_a:
1060
+ eq_a = f"{a_term}->{desired_a}"
1061
+ else:
1062
+ eq_a = None
1063
+
1064
+ # take diagonal, remove any trivial axes and transpose right
1065
+ desired_b = "".join((*bat_inds, *con_inds, *b_keep))
1066
+ if b_term != desired_b:
1067
+ eq_b = f"{b_term}->{desired_b}"
1068
+ else:
1069
+ eq_b = None
1070
+
1071
+ # then we want to reshape
1072
+ if bat_inds:
1073
+ lgroups = (bat_inds, a_keep, con_inds)
1074
+ rgroups = (bat_inds, con_inds, b_keep)
1075
+ ogroups = (bat_inds, a_keep, b_keep)
1076
+ else:
1077
+ # avoid size 1 batch dimension if no batch indices
1078
+ lgroups = (a_keep, con_inds)
1079
+ rgroups = (con_inds, b_keep)
1080
+ ogroups = (a_keep, b_keep)
1081
+
1082
+ if any(len(group) != 1 for group in lgroups):
1083
+ # need to fuse 'kept' and contracted indices
1084
+ # (though could allow batch indices to be broadcast)
1085
+ new_shape_a = tuple(
1086
+ functools.reduce(operator.mul, (sizes[ix] for ix in ix_group), 1)
1087
+ for ix_group in lgroups
1088
+ )
1089
+ else:
1090
+ new_shape_a = None
1091
+
1092
+ if any(len(group) != 1 for group in rgroups):
1093
+ # need to fuse 'kept' and contracted indices
1094
+ # (though could allow batch indices to be broadcast)
1095
+ new_shape_b = tuple(
1096
+ functools.reduce(operator.mul, (sizes[ix] for ix in ix_group), 1)
1097
+ for ix_group in rgroups
1098
+ )
1099
+ else:
1100
+ new_shape_b = None
1101
+
1102
+ if any(len(group) != 1 for group in ogroups) or singletons:
1103
+ new_shape_ab = (1,) * len(singletons) + tuple(
1104
+ sizes[ix] for ix_group in ogroups for ix in ix_group
1105
+ )
1106
+ else:
1107
+ new_shape_ab = None
1108
+
1109
+ # then we might need to permute the matmul produced output:
1110
+ out_produced = "".join((*singletons, *bat_inds, *a_keep, *b_keep))
1111
+ if out_produced != out:
1112
+ perm_ab = tuple(out_produced.index(ix) for ix in out)
1113
+ else:
1114
+ perm_ab = None
1115
+
1116
+ return (
1117
+ eq_a,
1118
+ eq_b,
1119
+ new_shape_a,
1120
+ new_shape_b,
1121
+ new_shape_ab,
1122
+ perm_ab,
1123
+ False, # pure_multiplication=False
1124
+ )
1125
+
1126
+
1127
+ @functools.lru_cache(maxsize=64)
1128
+ def _parse_output_order(order, a_is_fcontig, b_is_fcontig):
1129
+ order = order.upper()
1130
+ if order == "K":
1131
+ return None
1132
+ elif order in "CF":
1133
+ return order
1134
+ elif order == "A":
1135
+ if a_is_fcontig and b_is_fcontig:
1136
+ return "F"
1137
+ else:
1138
+ return "C"
1139
+ else:
1140
+ raise ValueError(
1141
+ "ValueError: order must be one of "
1142
+ f"'C', 'F', 'A', or 'K' (got '{order}')"
1143
+ )
1144
+
1145
+
1146
def bmm_einsum(eq, a, b, out=None, **kwargs):
    """Perform arbitrary pairwise einsums using only ``matmul``, or
    ``multiply`` if no contracted indices are involved (plus maybe single term
    ``einsum`` to prepare the terms individually). The logic for each is cached
    based on the equation and array shape, and each step is only performed if
    necessary.

    Parameters
    ----------
    eq : str
        The einsum equation.
    a : array_like
        The first array to contract. Its ``shape`` and ``flags`` attributes
        are read directly.
    b : array_like
        The second array to contract. Its ``shape`` and ``flags`` attributes
        are read directly.
    out : ndarray, optional
        If given, the result is written into this array.
    **kwargs
        ``order`` (default 'K') is consumed here to choose the output memory
        layout; every other keyword is forwarded to ``multiply`` or
        ``matmul``.

    Returns
    -------
    array_like
        The contraction result, which is ``out`` itself when ``out`` had to
        be filled after the matmul.

    Notes
    -----
    A fuller description of this algorithm, and original source for this
    implementation, can be found at https://github.com/jcmgray/einsum_bmm.
    """
    (
        eq_a,
        eq_b,
        new_shape_a,
        new_shape_b,
        new_shape_ab,
        perm_ab,
        pure_multiplication,
    ) = _parse_eq_to_batch_matmul(eq, a.shape, b.shape)

    # n.b. one could special case various cases to call c_einsum directly here

    # need to handle `order` a little manually, since we do transpose
    # operations before and potentially after the ufunc calls
    output_order = _parse_output_order(
        kwargs.pop("order", "K"), a.flags.f_contiguous, b.flags.f_contiguous
    )

    # prepare left
    if eq_a is not None:
        # diagonals, sums, and transpose
        a = c_einsum(eq_a, a)
    if new_shape_a is not None:
        a = reshape(a, new_shape_a)

    # prepare right
    if eq_b is not None:
        # diagonals, sums, and transpose
        b = c_einsum(eq_b, b)
    if new_shape_b is not None:
        b = reshape(b, new_shape_b)

    if pure_multiplication:
        # no contracted indices
        if output_order is not None:
            kwargs["order"] = output_order

        # do the 'contraction' via multiplication!
        return multiply(a, b, out=out, **kwargs)

    # can only supply out here if no other reshaping / transposing
    matmul_out_compatible = (new_shape_ab is None) and (perm_ab is None)
    if matmul_out_compatible:
        kwargs["out"] = out

    # do the contraction!
    ab = matmul(a, b, **kwargs)

    # prepare the output
    if new_shape_ab is not None:
        ab = reshape(ab, new_shape_ab)
    if perm_ab is not None:
        ab = ab.transpose(perm_ab)

    if (out is not None) and (not matmul_out_compatible):
        # handle case where out is specified, but we also needed
        # to reshape / transpose ``ab`` after the matmul.
        # Ellipsis assignment also works for a 0-d ``out`` (e.g. 'i,i->'),
        # where ``out[:]`` would raise IndexError.
        out[...] = ab
        ab = out
    elif output_order is not None:
        ab = asanyarray(ab, order=output_order)

    return ab
1234
+
1235
+
1236
def _einsum_dispatcher(*operands, out=None, optimize=None, **kwargs):
    """Dispatcher paired with `einsum` via ``array_function_dispatch``.

    Yields every positional argument followed by ``out``; ``optimize`` and
    any other keyword arguments are ignored.
    """
    # Arguably we dispatch on more arguments than we really should; see note in
    # _einsum_path_dispatcher for why.
    yield from operands
    yield out
1241
+
1242
+
1243
+ # Rewrite einsum to handle different cases
1244
+ @array_function_dispatch(_einsum_dispatcher, module='numpy')
1245
+ def einsum(*operands, out=None, optimize=False, **kwargs):
1246
+ """
1247
+ einsum(subscripts, *operands, out=None, dtype=None, order='K',
1248
+ casting='safe', optimize=False)
1249
+
1250
+ Evaluates the Einstein summation convention on the operands.
1251
+
1252
+ Using the Einstein summation convention, many common multi-dimensional,
1253
+ linear algebraic array operations can be represented in a simple fashion.
1254
+ In *implicit* mode `einsum` computes these values.
1255
+
1256
+ In *explicit* mode, `einsum` provides further flexibility to compute
1257
+ other array operations that might not be considered classical Einstein
1258
+ summation operations, by disabling, or forcing summation over specified
1259
+ subscript labels.
1260
+
1261
+ See the notes and examples for clarification.
1262
+
1263
+ Parameters
1264
+ ----------
1265
+ subscripts : str
1266
+ Specifies the subscripts for summation as comma separated list of
1267
+ subscript labels. An implicit (classical Einstein summation)
1268
+ calculation is performed unless the explicit indicator '->' is
1269
+ included as well as subscript labels of the precise output form.
1270
+ operands : list of array_like
1271
+ These are the arrays for the operation.
1272
+ out : ndarray, optional
1273
+ If provided, the calculation is done into this array.
1274
+ dtype : {data-type, None}, optional
1275
+ If provided, forces the calculation to use the data type specified.
1276
+ Note that you may have to also give a more liberal `casting`
1277
+ parameter to allow the conversions. Default is None.
1278
+ order : {'C', 'F', 'A', 'K'}, optional
1279
+ Controls the memory layout of the output. 'C' means it should
1280
+ be C contiguous. 'F' means it should be Fortran contiguous,
1281
+ 'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise.
1282
+ 'K' means it should be as close to the layout as the inputs as
1283
+ is possible, including arbitrarily permuted axes.
1284
+ Default is 'K'.
1285
+ casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
1286
+ Controls what kind of data casting may occur. Setting this to
1287
+ 'unsafe' is not recommended, as it can adversely affect accumulations.
1288
+
1289
+ * 'no' means the data types should not be cast at all.
1290
+ * 'equiv' means only byte-order changes are allowed.
1291
+ * 'safe' means only casts which can preserve values are allowed.
1292
+ * 'same_kind' means only safe casts or casts within a kind,
1293
+ like float64 to float32, are allowed.
1294
+ * 'unsafe' means any data conversions may be done.
1295
+
1296
+ Default is 'safe'.
1297
+ optimize : {False, True, 'greedy', 'optimal'}, optional
1298
+ Controls if intermediate optimization should occur. No optimization
1299
+ will occur if False and True will default to the 'greedy' algorithm.
1300
+ Also accepts an explicit contraction list from the ``np.einsum_path``
1301
+ function. See ``np.einsum_path`` for more details. Defaults to False.
1302
+
1303
+ Returns
1304
+ -------
1305
+ output : ndarray
1306
+ The calculation based on the Einstein summation convention.
1307
+
1308
+ See Also
1309
+ --------
1310
+ einsum_path, dot, inner, outer, tensordot, linalg.multi_dot
1311
+ einsum:
1312
+ Similar verbose interface is provided by the
1313
+ `einops <https://github.com/arogozhnikov/einops>`_ package to cover
1314
+ additional operations: transpose, reshape/flatten, repeat/tile,
1315
+ squeeze/unsqueeze and reductions.
1316
+ The `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_
1317
+ optimizes contraction order for einsum-like expressions
1318
+ in backend-agnostic manner.
1319
+
1320
+ Notes
1321
+ -----
1322
+ The Einstein summation convention can be used to compute
1323
+ many multi-dimensional, linear algebraic array operations. `einsum`
1324
+ provides a succinct way of representing these.
1325
+
1326
+ A non-exhaustive list of these operations,
1327
+ which can be computed by `einsum`, is shown below along with examples:
1328
+
1329
+ * Trace of an array, :py:func:`numpy.trace`.
1330
+ * Return a diagonal, :py:func:`numpy.diag`.
1331
+ * Array axis summations, :py:func:`numpy.sum`.
1332
+ * Transpositions and permutations, :py:func:`numpy.transpose`.
1333
+ * Matrix multiplication and dot product, :py:func:`numpy.matmul`
1334
+ :py:func:`numpy.dot`.
1335
+ * Vector inner and outer products, :py:func:`numpy.inner`
1336
+ :py:func:`numpy.outer`.
1337
+ * Broadcasting, element-wise and scalar multiplication,
1338
+ :py:func:`numpy.multiply`.
1339
+ * Tensor contractions, :py:func:`numpy.tensordot`.
1340
+ * Chained array operations, in efficient calculation order,
1341
+ :py:func:`numpy.einsum_path`.
1342
+
1343
+ The subscripts string is a comma-separated list of subscript labels,
1344
+ where each label refers to a dimension of the corresponding operand.
1345
+ Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)``
1346
+ is equivalent to :py:func:`np.inner(a,b) <numpy.inner>`. If a label
1347
+ appears only once, it is not summed, so ``np.einsum('i', a)``
1348
+ produces a view of ``a`` with no changes. A further example
1349
+ ``np.einsum('ij,jk', a, b)`` describes traditional matrix multiplication
1350
+ and is equivalent to :py:func:`np.matmul(a,b) <numpy.matmul>`.
1351
+ Repeated subscript labels in one operand take the diagonal.
1352
+ For example, ``np.einsum('ii', a)`` is equivalent to
1353
+ :py:func:`np.trace(a) <numpy.trace>`.
1354
+
1355
+ In *implicit mode*, the chosen subscripts are important
1356
+ since the axes of the output are reordered alphabetically. This
1357
+ means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while
1358
+ ``np.einsum('ji', a)`` takes its transpose. Additionally,
1359
+ ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while,
1360
+ ``np.einsum('ij,jh', a, b)`` returns the transpose of the
1361
+ multiplication since subscript 'h' precedes subscript 'i'.
1362
+
1363
+ In *explicit mode* the output can be directly controlled by
1364
+ specifying output subscript labels. This requires the
1365
+ identifier '->' as well as the list of output subscript labels.
1366
+ This feature increases the flexibility of the function since
1367
+ summing can be disabled or forced when required. The call
1368
+ ``np.einsum('i->', a)`` is like :py:func:`np.sum(a) <numpy.sum>`
1369
+ if ``a`` is a 1-D array, and ``np.einsum('ii->i', a)``
1370
+ is like :py:func:`np.diag(a) <numpy.diag>` if ``a`` is a square 2-D array.
1371
+ The difference is that `einsum` does not allow broadcasting by default.
1372
+ Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the
1373
+ order of the output subscript labels and therefore returns matrix
1374
+ multiplication, unlike the example above in implicit mode.
1375
+
1376
+ To enable and control broadcasting, use an ellipsis. Default
1377
+ NumPy-style broadcasting is done by adding an ellipsis
1378
+ to the left of each term, like ``np.einsum('...ii->...i', a)``.
1379
+ ``np.einsum('...i->...', a)`` is like
1380
+ :py:func:`np.sum(a, axis=-1) <numpy.sum>` for array ``a`` of any shape.
1381
+ To take the trace along the first and last axes,
1382
+ you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix
1383
+ product with the left-most indices instead of rightmost, one can do
1384
+ ``np.einsum('ij...,jk...->ik...', a, b)``.
1385
+
1386
+ When there is only one operand, no axes are summed, and no output
1387
+ parameter is provided, a view into the operand is returned instead
1388
+ of a new array. Thus, taking the diagonal as ``np.einsum('ii->i', a)``
1389
+ produces a view (changed in version 1.10.0).
1390
+
1391
+ `einsum` also provides an alternative way to provide the subscripts and
1392
+ operands as ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``.
1393
+ If the output sublist is not provided in this format, `einsum` will be
1394
+ calculated in implicit mode; otherwise it will be performed explicitly.
1395
+ The examples below have corresponding `einsum` calls with the two
1396
+ parameter methods.
1397
+
1398
+ Views returned from einsum are writeable whenever the input array
1399
+ is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will
1400
+ have the same effect as :py:func:`np.swapaxes(a, 0, 2) <numpy.swapaxes>`
1401
+ and ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal
1402
+ of a 2D array.
1403
+
1404
+ The ``optimize`` argument will optimize the contraction order
1405
+ of an einsum expression. For a contraction with three or more operands
1406
+ this can greatly increase the computational efficiency at the cost of
1407
+ a larger memory footprint during computation.
1408
+
1409
+ Typically a 'greedy' algorithm is applied which empirical tests have shown
1410
+ returns the optimal path in the majority of cases. In some cases 'optimal'
1411
+ will return the best path through a more expensive, exhaustive
1412
+ search. For iterative calculations it may be advisable to calculate
1413
+ the optimal path once and reuse that path by supplying it as an argument.
1414
+ An example is given below.
1415
+
1416
+ See :py:func:`numpy.einsum_path` for more details.
1417
+
1418
+ Examples
1419
+ --------
1420
+ >>> a = np.arange(25).reshape(5,5)
1421
+ >>> b = np.arange(5)
1422
+ >>> c = np.arange(6).reshape(2,3)
1423
+
1424
+ Trace of a matrix:
1425
+
1426
+ >>> np.einsum('ii', a)
1427
+ 60
1428
+ >>> np.einsum(a, [0,0])
1429
+ 60
1430
+ >>> np.trace(a)
1431
+ 60
1432
+
1433
+ Extract the diagonal (requires explicit form):
1434
+
1435
+ >>> np.einsum('ii->i', a)
1436
+ array([ 0, 6, 12, 18, 24])
1437
+ >>> np.einsum(a, [0,0], [0])
1438
+ array([ 0, 6, 12, 18, 24])
1439
+ >>> np.diag(a)
1440
+ array([ 0, 6, 12, 18, 24])
1441
+
1442
+ Sum over an axis (requires explicit form):
1443
+
1444
+ >>> np.einsum('ij->i', a)
1445
+ array([ 10, 35, 60, 85, 110])
1446
+ >>> np.einsum(a, [0,1], [0])
1447
+ array([ 10, 35, 60, 85, 110])
1448
+ >>> np.sum(a, axis=1)
1449
+ array([ 10, 35, 60, 85, 110])
1450
+
1451
+ For higher dimensional arrays summing a single axis can be done
1452
+ with ellipsis:
1453
+
1454
+ >>> np.einsum('...j->...', a)
1455
+ array([ 10, 35, 60, 85, 110])
1456
+ >>> np.einsum(a, [Ellipsis,1], [Ellipsis])
1457
+ array([ 10, 35, 60, 85, 110])
1458
+
1459
+ Compute a matrix transpose, or reorder any number of axes:
1460
+
1461
+ >>> np.einsum('ji', c)
1462
+ array([[0, 3],
1463
+ [1, 4],
1464
+ [2, 5]])
1465
+ >>> np.einsum('ij->ji', c)
1466
+ array([[0, 3],
1467
+ [1, 4],
1468
+ [2, 5]])
1469
+ >>> np.einsum(c, [1,0])
1470
+ array([[0, 3],
1471
+ [1, 4],
1472
+ [2, 5]])
1473
+ >>> np.transpose(c)
1474
+ array([[0, 3],
1475
+ [1, 4],
1476
+ [2, 5]])
1477
+
1478
+ Vector inner products:
1479
+
1480
+ >>> np.einsum('i,i', b, b)
1481
+ 30
1482
+ >>> np.einsum(b, [0], b, [0])
1483
+ 30
1484
+ >>> np.inner(b,b)
1485
+ 30
1486
+
1487
+ Matrix vector multiplication:
1488
+
1489
+ >>> np.einsum('ij,j', a, b)
1490
+ array([ 30, 80, 130, 180, 230])
1491
+ >>> np.einsum(a, [0,1], b, [1])
1492
+ array([ 30, 80, 130, 180, 230])
1493
+ >>> np.dot(a, b)
1494
+ array([ 30, 80, 130, 180, 230])
1495
+ >>> np.einsum('...j,j', a, b)
1496
+ array([ 30, 80, 130, 180, 230])
1497
+
1498
+ Broadcasting and scalar multiplication:
1499
+
1500
+ >>> np.einsum('..., ...', 3, c)
1501
+ array([[ 0, 3, 6],
1502
+ [ 9, 12, 15]])
1503
+ >>> np.einsum(',ij', 3, c)
1504
+ array([[ 0, 3, 6],
1505
+ [ 9, 12, 15]])
1506
+ >>> np.einsum(3, [Ellipsis], c, [Ellipsis])
1507
+ array([[ 0, 3, 6],
1508
+ [ 9, 12, 15]])
1509
+ >>> np.multiply(3, c)
1510
+ array([[ 0, 3, 6],
1511
+ [ 9, 12, 15]])
1512
+
1513
+ Vector outer product:
1514
+
1515
+ >>> np.einsum('i,j', np.arange(2)+1, b)
1516
+ array([[0, 1, 2, 3, 4],
1517
+ [0, 2, 4, 6, 8]])
1518
+ >>> np.einsum(np.arange(2)+1, [0], b, [1])
1519
+ array([[0, 1, 2, 3, 4],
1520
+ [0, 2, 4, 6, 8]])
1521
+ >>> np.outer(np.arange(2)+1, b)
1522
+ array([[0, 1, 2, 3, 4],
1523
+ [0, 2, 4, 6, 8]])
1524
+
1525
+ Tensor contraction:
1526
+
1527
+ >>> a = np.arange(60.).reshape(3,4,5)
1528
+ >>> b = np.arange(24.).reshape(4,3,2)
1529
+ >>> np.einsum('ijk,jil->kl', a, b)
1530
+ array([[4400., 4730.],
1531
+ [4532., 4874.],
1532
+ [4664., 5018.],
1533
+ [4796., 5162.],
1534
+ [4928., 5306.]])
1535
+ >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
1536
+ array([[4400., 4730.],
1537
+ [4532., 4874.],
1538
+ [4664., 5018.],
1539
+ [4796., 5162.],
1540
+ [4928., 5306.]])
1541
+ >>> np.tensordot(a,b, axes=([1,0],[0,1]))
1542
+ array([[4400., 4730.],
1543
+ [4532., 4874.],
1544
+ [4664., 5018.],
1545
+ [4796., 5162.],
1546
+ [4928., 5306.]])
1547
+
1548
+ Writeable returned arrays (since version 1.10.0):
1549
+
1550
+ >>> a = np.zeros((3, 3))
1551
+ >>> np.einsum('ii->i', a)[:] = 1
1552
+ >>> a
1553
+ array([[1., 0., 0.],
1554
+ [0., 1., 0.],
1555
+ [0., 0., 1.]])
1556
+
1557
+ Example of ellipsis use:
1558
+
1559
+ >>> a = np.arange(6).reshape((3,2))
1560
+ >>> b = np.arange(12).reshape((4,3))
1561
+ >>> np.einsum('ki,jk->ij', a, b)
1562
+ array([[10, 28, 46, 64],
1563
+ [13, 40, 67, 94]])
1564
+ >>> np.einsum('ki,...k->i...', a, b)
1565
+ array([[10, 28, 46, 64],
1566
+ [13, 40, 67, 94]])
1567
+ >>> np.einsum('k...,jk', a, b)
1568
+ array([[10, 28, 46, 64],
1569
+ [13, 40, 67, 94]])
1570
+
1571
+ Chained array operations. For more complicated contractions, speed ups
1572
+ might be achieved by repeatedly computing a 'greedy' path or pre-computing
1573
+ the 'optimal' path and repeatedly applying it, using an `einsum_path`
1574
+ insertion (since version 1.12.0). Performance improvements can be
1575
+ particularly significant with larger arrays:
1576
+
1577
+ >>> a = np.ones(64).reshape(2,4,8)
1578
+
1579
+ Basic `einsum`: ~1520ms (benchmarked on a 3.1GHz Intel i5).
1580
+
1581
+ >>> for iteration in range(500):
1582
+ ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)
1583
+
1584
+ Sub-optimal `einsum` (due to repeated path calculation time): ~330ms
1585
+
1586
+ >>> for iteration in range(500):
1587
+ ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a,
1588
+ ... optimize='optimal')
1589
+
1590
+ Greedy `einsum` (faster optimal path approximation): ~160ms
1591
+
1592
+ >>> for iteration in range(500):
1593
+ ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')
1594
+
1595
+ Optimal `einsum` (best usage pattern in some use cases): ~110ms
1596
+
1597
+ >>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a,
1598
+ ... optimize='optimal')[0]
1599
+ >>> for iteration in range(500):
1600
+ ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)
1601
+
1602
+ """
1603
+ # Special handling if out is specified
1604
+ specified_out = out is not None
1605
+
1606
+ # If no optimization, run pure einsum
1607
+ if optimize is False:
1608
+ if specified_out:
1609
+ kwargs['out'] = out
1610
+ return c_einsum(*operands, **kwargs)
1611
+
1612
+ # Check the kwargs to avoid a more cryptic error later, without having to
1613
+ # repeat default values here
1614
+ valid_einsum_kwargs = ['dtype', 'order', 'casting']
1615
+ unknown_kwargs = [k for (k, v) in kwargs.items() if
1616
+ k not in valid_einsum_kwargs]
1617
+ if len(unknown_kwargs):
1618
+ raise TypeError(f"Did not understand the following kwargs: {unknown_kwargs}")
1619
+
1620
+ # Build the contraction list and operand
1621
+ operands, contraction_list = einsum_path(*operands, optimize=optimize,
1622
+ einsum_call=True)
1623
+
1624
+ # Start contraction loop
1625
+ for num, contraction in enumerate(contraction_list):
1626
+ inds, einsum_str, _ = contraction
1627
+ tmp_operands = [operands.pop(x) for x in inds]
1628
+
1629
+ # Do we need to deal with the output?
1630
+ handle_out = specified_out and ((num + 1) == len(contraction_list))
1631
+
1632
+ # If out was specified
1633
+ if handle_out:
1634
+ kwargs["out"] = out
1635
+
1636
+ if len(tmp_operands) == 2:
1637
+ # Call (batched) matrix multiplication if possible
1638
+ new_view = bmm_einsum(einsum_str, *tmp_operands, **kwargs)
1639
+ else:
1640
+ # Call einsum
1641
+ new_view = c_einsum(einsum_str, *tmp_operands, **kwargs)
1642
+
1643
+ # Append new items and dereference what we can
1644
+ operands.append(new_view)
1645
+ del tmp_operands, new_view
1646
+
1647
+ if specified_out:
1648
+ return out
1649
+ else:
1650
+ return operands[0]