netgen-mesher 6.2.2506.post35.dev0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (340) hide show
  1. netgen/NgOCC.py +7 -0
  2. netgen/__init__.py +114 -0
  3. netgen/__init__.pyi +22 -0
  4. netgen/__main__.py +53 -0
  5. netgen/cmake/NetgenConfig.cmake +79 -0
  6. netgen/cmake/netgen-targets-release.cmake +69 -0
  7. netgen/cmake/netgen-targets.cmake +146 -0
  8. netgen/config/__init__.py +1 -0
  9. netgen/config/__init__.pyi +52 -0
  10. netgen/config/__main__.py +4 -0
  11. netgen/config/config.py +68 -0
  12. netgen/config/config.pyi +54 -0
  13. netgen/csg.py +25 -0
  14. netgen/geom2d.py +178 -0
  15. netgen/gui.py +82 -0
  16. netgen/include/core/archive.hpp +1256 -0
  17. netgen/include/core/array.hpp +1760 -0
  18. netgen/include/core/autodiff.hpp +1131 -0
  19. netgen/include/core/autodiffdiff.hpp +733 -0
  20. netgen/include/core/bitarray.hpp +240 -0
  21. netgen/include/core/concurrentqueue.h +3619 -0
  22. netgen/include/core/exception.hpp +145 -0
  23. netgen/include/core/flags.hpp +199 -0
  24. netgen/include/core/hashtable.hpp +1281 -0
  25. netgen/include/core/localheap.hpp +318 -0
  26. netgen/include/core/logging.hpp +117 -0
  27. netgen/include/core/memtracer.hpp +221 -0
  28. netgen/include/core/mpi4py_pycapi.h +245 -0
  29. netgen/include/core/mpi_wrapper.hpp +643 -0
  30. netgen/include/core/ng_mpi.hpp +94 -0
  31. netgen/include/core/ng_mpi_generated_declarations.hpp +155 -0
  32. netgen/include/core/ng_mpi_native.hpp +25 -0
  33. netgen/include/core/ngcore.hpp +32 -0
  34. netgen/include/core/ngcore_api.hpp +152 -0
  35. netgen/include/core/ngstream.hpp +115 -0
  36. netgen/include/core/paje_trace.hpp +279 -0
  37. netgen/include/core/profiler.hpp +382 -0
  38. netgen/include/core/python_ngcore.hpp +457 -0
  39. netgen/include/core/ranges.hpp +109 -0
  40. netgen/include/core/register_archive.hpp +100 -0
  41. netgen/include/core/signal.hpp +82 -0
  42. netgen/include/core/simd.hpp +160 -0
  43. netgen/include/core/simd_arm64.hpp +407 -0
  44. netgen/include/core/simd_avx.hpp +394 -0
  45. netgen/include/core/simd_avx512.hpp +285 -0
  46. netgen/include/core/simd_generic.hpp +1053 -0
  47. netgen/include/core/simd_math.hpp +178 -0
  48. netgen/include/core/simd_sse.hpp +289 -0
  49. netgen/include/core/statushandler.hpp +37 -0
  50. netgen/include/core/symboltable.hpp +153 -0
  51. netgen/include/core/table.hpp +810 -0
  52. netgen/include/core/taskmanager.hpp +1161 -0
  53. netgen/include/core/type_traits.hpp +65 -0
  54. netgen/include/core/utils.hpp +385 -0
  55. netgen/include/core/version.hpp +102 -0
  56. netgen/include/core/xbool.hpp +47 -0
  57. netgen/include/csg/algprim.hpp +563 -0
  58. netgen/include/csg/brick.hpp +150 -0
  59. netgen/include/csg/csg.hpp +43 -0
  60. netgen/include/csg/csgeom.hpp +389 -0
  61. netgen/include/csg/csgparser.hpp +101 -0
  62. netgen/include/csg/curve2d.hpp +67 -0
  63. netgen/include/csg/edgeflw.hpp +112 -0
  64. netgen/include/csg/explicitcurve2d.hpp +113 -0
  65. netgen/include/csg/extrusion.hpp +185 -0
  66. netgen/include/csg/gencyl.hpp +70 -0
  67. netgen/include/csg/geoml.hpp +16 -0
  68. netgen/include/csg/identify.hpp +213 -0
  69. netgen/include/csg/manifold.hpp +29 -0
  70. netgen/include/csg/meshsurf.hpp +46 -0
  71. netgen/include/csg/polyhedra.hpp +121 -0
  72. netgen/include/csg/revolution.hpp +180 -0
  73. netgen/include/csg/singularref.hpp +84 -0
  74. netgen/include/csg/solid.hpp +295 -0
  75. netgen/include/csg/specpoin.hpp +194 -0
  76. netgen/include/csg/spline3d.hpp +99 -0
  77. netgen/include/csg/splinesurface.hpp +85 -0
  78. netgen/include/csg/surface.hpp +394 -0
  79. netgen/include/csg/triapprox.hpp +63 -0
  80. netgen/include/csg/vscsg.hpp +34 -0
  81. netgen/include/general/autodiff.hpp +356 -0
  82. netgen/include/general/autoptr.hpp +39 -0
  83. netgen/include/general/gzstream.h +121 -0
  84. netgen/include/general/hashtabl.hpp +1692 -0
  85. netgen/include/general/myadt.hpp +48 -0
  86. netgen/include/general/mystring.hpp +226 -0
  87. netgen/include/general/netgenout.hpp +205 -0
  88. netgen/include/general/ngarray.hpp +797 -0
  89. netgen/include/general/ngbitarray.hpp +149 -0
  90. netgen/include/general/ngpython.hpp +74 -0
  91. netgen/include/general/optmem.hpp +44 -0
  92. netgen/include/general/parthreads.hpp +138 -0
  93. netgen/include/general/seti.hpp +50 -0
  94. netgen/include/general/sort.hpp +47 -0
  95. netgen/include/general/spbita2d.hpp +59 -0
  96. netgen/include/general/stack.hpp +114 -0
  97. netgen/include/general/table.hpp +280 -0
  98. netgen/include/general/template.hpp +509 -0
  99. netgen/include/geom2d/csg2d.hpp +750 -0
  100. netgen/include/geom2d/geometry2d.hpp +280 -0
  101. netgen/include/geom2d/spline2d.hpp +234 -0
  102. netgen/include/geom2d/vsgeom2d.hpp +28 -0
  103. netgen/include/gprim/adtree.hpp +1392 -0
  104. netgen/include/gprim/geom2d.hpp +858 -0
  105. netgen/include/gprim/geom3d.hpp +749 -0
  106. netgen/include/gprim/geomfuncs.hpp +212 -0
  107. netgen/include/gprim/geomobjects.hpp +544 -0
  108. netgen/include/gprim/geomops.hpp +404 -0
  109. netgen/include/gprim/geomtest3d.hpp +101 -0
  110. netgen/include/gprim/gprim.hpp +33 -0
  111. netgen/include/gprim/spline.hpp +778 -0
  112. netgen/include/gprim/splinegeometry.hpp +73 -0
  113. netgen/include/gprim/transform3d.hpp +216 -0
  114. netgen/include/include/acisgeom.hpp +3 -0
  115. netgen/include/include/csg.hpp +1 -0
  116. netgen/include/include/geometry2d.hpp +1 -0
  117. netgen/include/include/gprim.hpp +1 -0
  118. netgen/include/include/incopengl.hpp +62 -0
  119. netgen/include/include/inctcl.hpp +13 -0
  120. netgen/include/include/incvis.hpp +6 -0
  121. netgen/include/include/linalg.hpp +1 -0
  122. netgen/include/include/meshing.hpp +1 -0
  123. netgen/include/include/myadt.hpp +1 -0
  124. netgen/include/include/mydefs.hpp +70 -0
  125. netgen/include/include/mystdlib.h +59 -0
  126. netgen/include/include/netgen_config.hpp +27 -0
  127. netgen/include/include/netgen_version.hpp +9 -0
  128. netgen/include/include/nginterface_v2_impl.hpp +395 -0
  129. netgen/include/include/ngsimd.hpp +1 -0
  130. netgen/include/include/occgeom.hpp +1 -0
  131. netgen/include/include/opti.hpp +1 -0
  132. netgen/include/include/parallel.hpp +1 -0
  133. netgen/include/include/stlgeom.hpp +1 -0
  134. netgen/include/include/visual.hpp +1 -0
  135. netgen/include/interface/rw_medit.hpp +11 -0
  136. netgen/include/interface/writeuser.hpp +80 -0
  137. netgen/include/linalg/densemat.hpp +414 -0
  138. netgen/include/linalg/linalg.hpp +29 -0
  139. netgen/include/linalg/opti.hpp +142 -0
  140. netgen/include/linalg/polynomial.hpp +47 -0
  141. netgen/include/linalg/vector.hpp +217 -0
  142. netgen/include/meshing/adfront2.hpp +274 -0
  143. netgen/include/meshing/adfront3.hpp +332 -0
  144. netgen/include/meshing/basegeom.hpp +370 -0
  145. netgen/include/meshing/bcfunctions.hpp +53 -0
  146. netgen/include/meshing/bisect.hpp +72 -0
  147. netgen/include/meshing/boundarylayer.hpp +113 -0
  148. netgen/include/meshing/classifyhpel.hpp +1984 -0
  149. netgen/include/meshing/clusters.hpp +46 -0
  150. netgen/include/meshing/curvedelems.hpp +274 -0
  151. netgen/include/meshing/delaunay2d.hpp +73 -0
  152. netgen/include/meshing/fieldlines.hpp +103 -0
  153. netgen/include/meshing/findip.hpp +198 -0
  154. netgen/include/meshing/findip2.hpp +103 -0
  155. netgen/include/meshing/geomsearch.hpp +69 -0
  156. netgen/include/meshing/global.hpp +54 -0
  157. netgen/include/meshing/hpref_hex.hpp +330 -0
  158. netgen/include/meshing/hpref_prism.hpp +3405 -0
  159. netgen/include/meshing/hpref_pyramid.hpp +154 -0
  160. netgen/include/meshing/hpref_quad.hpp +2082 -0
  161. netgen/include/meshing/hpref_segm.hpp +122 -0
  162. netgen/include/meshing/hpref_tet.hpp +4230 -0
  163. netgen/include/meshing/hpref_trig.hpp +848 -0
  164. netgen/include/meshing/hprefinement.hpp +366 -0
  165. netgen/include/meshing/improve2.hpp +178 -0
  166. netgen/include/meshing/improve3.hpp +151 -0
  167. netgen/include/meshing/localh.hpp +223 -0
  168. netgen/include/meshing/meshclass.hpp +1076 -0
  169. netgen/include/meshing/meshfunc.hpp +47 -0
  170. netgen/include/meshing/meshing.hpp +63 -0
  171. netgen/include/meshing/meshing2.hpp +163 -0
  172. netgen/include/meshing/meshing3.hpp +123 -0
  173. netgen/include/meshing/meshtool.hpp +90 -0
  174. netgen/include/meshing/meshtype.hpp +1930 -0
  175. netgen/include/meshing/msghandler.hpp +62 -0
  176. netgen/include/meshing/paralleltop.hpp +172 -0
  177. netgen/include/meshing/python_mesh.hpp +206 -0
  178. netgen/include/meshing/ruler2.hpp +172 -0
  179. netgen/include/meshing/ruler3.hpp +211 -0
  180. netgen/include/meshing/soldata.hpp +141 -0
  181. netgen/include/meshing/specials.hpp +17 -0
  182. netgen/include/meshing/surfacegeom.hpp +73 -0
  183. netgen/include/meshing/topology.hpp +1003 -0
  184. netgen/include/meshing/validate.hpp +21 -0
  185. netgen/include/meshing/visual_interface.hpp +71 -0
  186. netgen/include/mydefs.hpp +70 -0
  187. netgen/include/nginterface.h +474 -0
  188. netgen/include/nginterface_v2.hpp +406 -0
  189. netgen/include/nglib.h +697 -0
  190. netgen/include/nglib_occ.h +50 -0
  191. netgen/include/occ/occ_edge.hpp +47 -0
  192. netgen/include/occ/occ_face.hpp +52 -0
  193. netgen/include/occ/occ_solid.hpp +23 -0
  194. netgen/include/occ/occ_utils.hpp +376 -0
  195. netgen/include/occ/occ_vertex.hpp +30 -0
  196. netgen/include/occ/occgeom.hpp +659 -0
  197. netgen/include/occ/occmeshsurf.hpp +168 -0
  198. netgen/include/occ/vsocc.hpp +33 -0
  199. netgen/include/pybind11/LICENSE +29 -0
  200. netgen/include/pybind11/attr.h +722 -0
  201. netgen/include/pybind11/buffer_info.h +208 -0
  202. netgen/include/pybind11/cast.h +2361 -0
  203. netgen/include/pybind11/chrono.h +228 -0
  204. netgen/include/pybind11/common.h +2 -0
  205. netgen/include/pybind11/complex.h +74 -0
  206. netgen/include/pybind11/conduit/README.txt +15 -0
  207. netgen/include/pybind11/conduit/pybind11_conduit_v1.h +116 -0
  208. netgen/include/pybind11/conduit/pybind11_platform_abi_id.h +87 -0
  209. netgen/include/pybind11/conduit/wrap_include_python_h.h +72 -0
  210. netgen/include/pybind11/critical_section.h +56 -0
  211. netgen/include/pybind11/detail/class.h +823 -0
  212. netgen/include/pybind11/detail/common.h +1348 -0
  213. netgen/include/pybind11/detail/cpp_conduit.h +75 -0
  214. netgen/include/pybind11/detail/descr.h +226 -0
  215. netgen/include/pybind11/detail/dynamic_raw_ptr_cast_if_possible.h +39 -0
  216. netgen/include/pybind11/detail/exception_translation.h +71 -0
  217. netgen/include/pybind11/detail/function_record_pyobject.h +191 -0
  218. netgen/include/pybind11/detail/init.h +538 -0
  219. netgen/include/pybind11/detail/internals.h +799 -0
  220. netgen/include/pybind11/detail/native_enum_data.h +209 -0
  221. netgen/include/pybind11/detail/pybind11_namespace_macros.h +82 -0
  222. netgen/include/pybind11/detail/struct_smart_holder.h +378 -0
  223. netgen/include/pybind11/detail/type_caster_base.h +1591 -0
  224. netgen/include/pybind11/detail/typeid.h +65 -0
  225. netgen/include/pybind11/detail/using_smart_holder.h +22 -0
  226. netgen/include/pybind11/detail/value_and_holder.h +90 -0
  227. netgen/include/pybind11/eigen/common.h +9 -0
  228. netgen/include/pybind11/eigen/matrix.h +723 -0
  229. netgen/include/pybind11/eigen/tensor.h +521 -0
  230. netgen/include/pybind11/eigen.h +12 -0
  231. netgen/include/pybind11/embed.h +320 -0
  232. netgen/include/pybind11/eval.h +161 -0
  233. netgen/include/pybind11/functional.h +147 -0
  234. netgen/include/pybind11/gil.h +199 -0
  235. netgen/include/pybind11/gil_safe_call_once.h +102 -0
  236. netgen/include/pybind11/gil_simple.h +37 -0
  237. netgen/include/pybind11/iostream.h +265 -0
  238. netgen/include/pybind11/native_enum.h +67 -0
  239. netgen/include/pybind11/numpy.h +2312 -0
  240. netgen/include/pybind11/operators.h +202 -0
  241. netgen/include/pybind11/options.h +92 -0
  242. netgen/include/pybind11/pybind11.h +3645 -0
  243. netgen/include/pybind11/pytypes.h +2680 -0
  244. netgen/include/pybind11/stl/filesystem.h +114 -0
  245. netgen/include/pybind11/stl.h +666 -0
  246. netgen/include/pybind11/stl_bind.h +858 -0
  247. netgen/include/pybind11/subinterpreter.h +299 -0
  248. netgen/include/pybind11/trampoline_self_life_support.h +65 -0
  249. netgen/include/pybind11/type_caster_pyobject_ptr.h +61 -0
  250. netgen/include/pybind11/typing.h +298 -0
  251. netgen/include/pybind11/warnings.h +75 -0
  252. netgen/include/stlgeom/meshstlsurface.hpp +67 -0
  253. netgen/include/stlgeom/stlgeom.hpp +491 -0
  254. netgen/include/stlgeom/stlline.hpp +193 -0
  255. netgen/include/stlgeom/stltool.hpp +331 -0
  256. netgen/include/stlgeom/stltopology.hpp +419 -0
  257. netgen/include/stlgeom/vsstl.hpp +58 -0
  258. netgen/include/visualization/meshdoc.hpp +42 -0
  259. netgen/include/visualization/mvdraw.hpp +325 -0
  260. netgen/include/visualization/vispar.hpp +128 -0
  261. netgen/include/visualization/visual.hpp +28 -0
  262. netgen/include/visualization/visual_api.hpp +10 -0
  263. netgen/include/visualization/vssolution.hpp +399 -0
  264. netgen/lib/libnggui.lib +0 -0
  265. netgen/lib/ngcore.lib +0 -0
  266. netgen/lib/nglib.lib +0 -0
  267. netgen/lib/togl.lib +0 -0
  268. netgen/libnggui.dll +0 -0
  269. netgen/libngguipy.lib +0 -0
  270. netgen/libngguipy.pyd +0 -0
  271. netgen/libngpy/_NgOCC.pyi +1545 -0
  272. netgen/libngpy/__init__.pyi +7 -0
  273. netgen/libngpy/_csg.pyi +259 -0
  274. netgen/libngpy/_geom2d.pyi +323 -0
  275. netgen/libngpy/_meshing.pyi +1111 -0
  276. netgen/libngpy/_stl.pyi +131 -0
  277. netgen/libngpy.lib +0 -0
  278. netgen/libngpy.pyd +0 -0
  279. netgen/meshing.py +65 -0
  280. netgen/ngcore.dll +0 -0
  281. netgen/nglib.dll +0 -0
  282. netgen/occ.py +52 -0
  283. netgen/read_gmsh.py +259 -0
  284. netgen/read_meshio.py +22 -0
  285. netgen/stl.py +2 -0
  286. netgen/togl.dll +0 -0
  287. netgen/version.py +2 -0
  288. netgen/webgui.py +529 -0
  289. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/boundarycondition.geo +16 -0
  290. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/boxcyl.geo +32 -0
  291. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/circle_on_cube.geo +27 -0
  292. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/cone.geo +13 -0
  293. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/cube.geo +16 -0
  294. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/cubeandring.geo +55 -0
  295. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/cubeandspheres.geo +21 -0
  296. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/cubemcyl.geo +18 -0
  297. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/cubemsphere.geo +19 -0
  298. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/cylinder.geo +12 -0
  299. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/cylsphere.geo +12 -0
  300. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/doc/ng4.pdf +0 -0
  301. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/ellipsoid.geo +8 -0
  302. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/ellipticcyl.geo +10 -0
  303. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/extrusion.geo +99 -0
  304. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/fichera.geo +24 -0
  305. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/frame.step +11683 -0
  306. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/hinge.stl +8486 -0
  307. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/lshape3d.geo +26 -0
  308. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/manyholes.geo +26 -0
  309. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/manyholes2.geo +26 -0
  310. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/matrix.geo +27 -0
  311. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/ortho.geo +11 -0
  312. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/part1.stl +2662 -0
  313. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/period.geo +33 -0
  314. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/py_tutorials/exportNeutral.py +26 -0
  315. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/py_tutorials/mesh.py +19 -0
  316. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/py_tutorials/shaft.geo +65 -0
  317. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/revolution.geo +18 -0
  318. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/screw.step +1694 -0
  319. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/sculpture.geo +13 -0
  320. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/shaft.geo +65 -0
  321. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/shell.geo +10 -0
  322. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/sphere.geo +8 -0
  323. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/sphereincube.geo +17 -0
  324. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/square.in2d +35 -0
  325. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/squarecircle.in2d +48 -0
  326. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/squarehole.in2d +47 -0
  327. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/torus.geo +8 -0
  328. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/trafo.geo +57 -0
  329. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/twobricks.geo +15 -0
  330. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/twocubes.geo +18 -0
  331. netgen_mesher-6.2.2506.post35.dev0.data/data/share/netgen/twocyl.geo +16 -0
  332. netgen_mesher-6.2.2506.post35.dev0.dist-info/METADATA +15 -0
  333. netgen_mesher-6.2.2506.post35.dev0.dist-info/RECORD +340 -0
  334. netgen_mesher-6.2.2506.post35.dev0.dist-info/WHEEL +5 -0
  335. netgen_mesher-6.2.2506.post35.dev0.dist-info/entry_points.txt +2 -0
  336. netgen_mesher-6.2.2506.post35.dev0.dist-info/licenses/AUTHORS +1 -0
  337. netgen_mesher-6.2.2506.post35.dev0.dist-info/licenses/LICENSE +504 -0
  338. netgen_mesher-6.2.2506.post35.dev0.dist-info/top_level.txt +2 -0
  339. pyngcore/__init__.py +1 -0
  340. pyngcore/pyngcore.cp314-win_amd64.pyd +0 -0
@@ -0,0 +1,2312 @@
1
+ /*
2
+ pybind11/numpy.h: Basic NumPy support, vectorize() wrapper
3
+
4
+ Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
5
+
6
+ All rights reserved. Use of this source code is governed by a
7
+ BSD-style license that can be found in the LICENSE file.
8
+ */
9
+
10
+ #pragma once
11
+
12
+ #include "pybind11.h"
13
+ #include "detail/common.h"
14
+ #include "complex.h"
15
+ #include "gil_safe_call_once.h"
16
+ #include "pytypes.h"
17
+
18
+ #include <algorithm>
19
+ #include <array>
20
+ #include <cstdint>
21
+ #include <cstdlib>
22
+ #include <cstring>
23
+ #include <functional>
24
+ #include <numeric>
25
+ #include <sstream>
26
+ #include <string>
27
+ #include <type_traits>
28
+ #include <typeindex>
29
+ #include <utility>
30
+ #include <vector>
31
+
32
// Building in NumPy-1-only mode is rejected at compile time; this header targets the
// NumPy 2 C-API layout.
#if defined(PYBIND11_NUMPY_1_ONLY)
#    error "PYBIND11_NUMPY_1_ONLY is no longer supported (see PR #5595)."
#endif

/* This will be true on all flat address space platforms and allows us to reduce the
   whole npy_intp / ssize_t / Py_intptr_t business down to just ssize_t for all size
   and dimension types (e.g. shape, strides, indexing), instead of inflicting this
   upon the library user.
   Note that NumPy 2 now uses ssize_t for `npy_intp` to simplify this. */
static_assert(sizeof(::pybind11::ssize_t) == sizeof(Py_intptr_t), "ssize_t != Py_intptr_t");
static_assert(std::is_signed<Py_intptr_t>::value, "Py_intptr_t must be signed");
// We now can reinterpret_cast between py::ssize_t and Py_intptr_t (MSVC + PyPy cares)
44
+
45
+ PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
46
+
47
// Disable MSVC warning C4127 ("conditional expression is constant").
PYBIND11_WARNING_DISABLE_MSVC(4127)

// Forward declarations of the public NumPy wrapper types; their definitions follow later
// in this header.
class dtype; // Forward declaration
class array; // Forward declaration

template <typename>
struct numpy_scalar; // Forward declaration
54
+
55
+ PYBIND11_NAMESPACE_BEGIN(detail)
56
+
57
// Type names pybind11 reports for dtype and array handles ("numpy.dtype" and
// "numpy.ndarray" respectively).
template <>
struct handle_type_name<dtype> {
    static constexpr auto name = const_name("numpy.dtype");
};

template <>
struct handle_type_name<array> {
    static constexpr auto name = const_name("numpy.ndarray");
};
66
+
67
// Primary template, declared only. The defaulted second parameter lets enable_if-constrained
// partial specializations select per-type descriptors.
template <typename type, typename SFINAE = void>
struct npy_format_descriptor;
69
+
70
/* NumPy 1 proxy (always includes legacy fields) */
// NOTE: layout mirror of a NumPy 1.x descriptor object. Field order and types presumably
// must match NumPy's own struct exactly (objects are reinterpreted as this type) -- do not
// reorder, insert, or remove fields.
struct PyArrayDescr1_Proxy {
    PyObject_HEAD
    PyObject *typeobj;
    char kind;
    char type;
    char byteorder;
    char flags; // NumPy 1.x keeps the descriptor flags in this single char.
    int type_num;
    int elsize;    // int-sized in the 1.x layout (ssize_t in the NumPy 2 proxy).
    int alignment; // int-sized in the 1.x layout (ssize_t in the NumPy 2 proxy).
    char *subarray;
    PyObject *fields;
    PyObject *names;
};
85
+
86
// Version-independent view of a descriptor: only the leading fields, which have the same
// order and types in PyArrayDescr1_Proxy and PyArrayDescr2_Proxy. Do not add fields here.
struct PyArrayDescr_Proxy {
    PyObject_HEAD
    PyObject *typeobj;
    char kind;
    char type;
    char byteorder;
    char _former_flags;
    int type_num;
    /* Additional fields are NumPy version specific. */
};
96
+
97
/* NumPy 2 proxy, including legacy fields */
// NOTE: layout mirror of a NumPy 2.x descriptor object; field order and types presumably
// must match NumPy's struct exactly -- do not reorder, insert, or remove fields.
struct PyArrayDescr2_Proxy {
    PyObject_HEAD
    PyObject *typeobj;
    char kind;
    char type;
    char byteorder;
    char _former_flags; // The 1.x char-sized flags slot, superseded by `flags` below.
    int type_num;
    std::uint64_t flags;
    ssize_t elsize;
    ssize_t alignment;
    PyObject *metadata;
    Py_hash_t hash;
    void *reserved_null[2];
    /* The following fields only exist if 0 <= type_num < 2056 */
    char *subarray;
    PyObject *fields;
    PyObject *names;
};
117
+
118
// Proxy for the leading fields of a NumPy array object. Presumably mirrors NumPy's
// array-object struct, so the field order must not change.
struct PyArray_Proxy {
    PyObject_HEAD
    char *data;
    int nd;               // Number of dimensions.
    ssize_t *dimensions;  // Shape, `nd` entries.
    ssize_t *strides;     // Strides, `nd` entries.
    PyObject *base;
    PyObject *descr;
    int flags;
};
128
+
129
// Proxy for NumPy's void-scalar object. PyObject_VAR_HEAD is the variable-size object
// header; field order presumably must match NumPy's struct -- do not reorder.
struct PyVoidScalarObject_Proxy {
    PyObject_VAR_HEAD char *obval;
    PyArrayDescr_Proxy *descr;
    int flags;
    PyObject *base;
};
135
+
136
// Registry entry for one C++ type registered as a NumPy dtype.
struct numpy_type_info {
    PyObject *dtype_ptr;    // The registered dtype object.
    std::string format_str; // Format string stored alongside it (presumably buffer-protocol style).
};
140
+
141
+ struct numpy_internals {
142
+ std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;
143
+
144
+ numpy_type_info *get_type_info(const std::type_info &tinfo, bool throw_if_missing = true) {
145
+ auto it = registered_dtypes.find(std::type_index(tinfo));
146
+ if (it != registered_dtypes.end()) {
147
+ return &(it->second);
148
+ }
149
+ if (throw_if_missing) {
150
+ pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
151
+ }
152
+ return nullptr;
153
+ }
154
+
155
+ template <typename T>
156
+ numpy_type_info *get_type_info(bool throw_if_missing = true) {
157
+ return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
158
+ }
159
+ };
160
+
161
+ PYBIND11_NOINLINE void load_numpy_internals(numpy_internals *&ptr) {
162
+ ptr = &get_or_create_shared_data<numpy_internals>("_numpy_internals");
163
+ }
164
+
165
+ inline numpy_internals &get_numpy_internals() {
166
+ static numpy_internals *ptr = nullptr;
167
+ if (!ptr) {
168
+ load_numpy_internals(ptr);
169
+ }
170
+ return *ptr;
171
+ }
172
+
173
+ PYBIND11_NOINLINE module_ import_numpy_core_submodule(const char *submodule_name) {
174
+ module_ numpy = module_::import("numpy");
175
+ str version_string = numpy.attr("__version__");
176
+ module_ numpy_lib = module_::import("numpy.lib");
177
+ object numpy_version = numpy_lib.attr("NumpyVersion")(version_string);
178
+ int major_version = numpy_version.attr("major").cast<int>();
179
+
180
+ /* `numpy.core` was renamed to `numpy._core` in NumPy 2.0 as it officially
181
+ became a private module. */
182
+ std::string numpy_core_path = major_version >= 2 ? "numpy._core" : "numpy.core";
183
+ return module_::import((numpy_core_path + "." + submodule_name).c_str());
184
+ }
185
+
186
// Size-equality predicate: `same_size<T>::as<U>` is a bool_constant that is true exactly
// when sizeof(T) == sizeof(U).
template <typename T>
struct same_size {
    template <typename U>
    using as = bool_constant<sizeof(T) == sizeof(U)>;
};
191
+
192
/// Recursion terminator: no candidate type had the same size as `Concrete`.
template <typename Concrete>
constexpr int platform_lookup() {
    return -1;
}

/// Map `Concrete` onto a NumPy type number by size.
///
/// The candidate types `T, Ts...` are walked in order, paired with the type numbers
/// `I, Is...`. Returns the number paired with the first candidate whose sizeof equals
/// sizeof(Concrete), or -1 if none matches.
// Written as a single return statement so it stays a valid C++11 constexpr function.
template <typename Concrete, typename T, typename... Ts, typename... Ints>
constexpr int platform_lookup(int I, Ints... Is) {
    return sizeof(T) != sizeof(Concrete) ? platform_lookup<Concrete, Ts...>(Is...) : I;
}
202
+
203
+ struct npy_api {
204
+ // If you change this code, please review `normalized_dtype_num` below.
205
+ enum constants {
206
+ NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
207
+ NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
208
+ NPY_ARRAY_OWNDATA_ = 0x0004,
209
+ NPY_ARRAY_FORCECAST_ = 0x0010,
210
+ NPY_ARRAY_ENSUREARRAY_ = 0x0040,
211
+ NPY_ARRAY_ALIGNED_ = 0x0100,
212
+ NPY_ARRAY_WRITEABLE_ = 0x0400,
213
+ NPY_BOOL_ = 0,
214
+ NPY_BYTE_,
215
+ NPY_UBYTE_,
216
+ NPY_SHORT_,
217
+ NPY_USHORT_,
218
+ NPY_INT_,
219
+ NPY_UINT_,
220
+ NPY_LONG_,
221
+ NPY_ULONG_,
222
+ NPY_LONGLONG_,
223
+ NPY_ULONGLONG_,
224
+ NPY_FLOAT_,
225
+ NPY_DOUBLE_,
226
+ NPY_LONGDOUBLE_,
227
+ NPY_CFLOAT_,
228
+ NPY_CDOUBLE_,
229
+ NPY_CLONGDOUBLE_,
230
+ NPY_OBJECT_ = 17,
231
+ NPY_STRING_,
232
+ NPY_UNICODE_,
233
+ NPY_VOID_,
234
+ // Platform-dependent normalization
235
+ NPY_INT8_ = NPY_BYTE_,
236
+ NPY_UINT8_ = NPY_UBYTE_,
237
+ NPY_INT16_ = NPY_SHORT_,
238
+ NPY_UINT16_ = NPY_USHORT_,
239
+ // `npy_common.h` defines the integer aliases. In order, it checks:
240
+ // NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR
241
+ // and assigns the alias to the first matching size, so we should check in this order.
242
+ NPY_INT32_
243
+ = platform_lookup<std::int32_t, long, int, short>(NPY_LONG_, NPY_INT_, NPY_SHORT_),
244
+ NPY_UINT32_ = platform_lookup<std::uint32_t, unsigned long, unsigned int, unsigned short>(
245
+ NPY_ULONG_, NPY_UINT_, NPY_USHORT_),
246
+ NPY_INT64_
247
+ = platform_lookup<std::int64_t, long, long long, int>(NPY_LONG_, NPY_LONGLONG_, NPY_INT_),
248
+ NPY_UINT64_
249
+ = platform_lookup<std::uint64_t, unsigned long, unsigned long long, unsigned int>(
250
+ NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_),
251
+ NPY_FLOAT32_ = platform_lookup<float, double, float, long double>(
252
+ NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
253
+ NPY_FLOAT64_ = platform_lookup<double, double, float, long double>(
254
+ NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
255
+ NPY_COMPLEX64_
256
+ = platform_lookup<std::complex<float>,
257
+ std::complex<double>,
258
+ std::complex<float>,
259
+ std::complex<long double>>(NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
260
+ NPY_COMPLEX128_
261
+ = platform_lookup<std::complex<double>,
262
+ std::complex<double>,
263
+ std::complex<float>,
264
+ std::complex<long double>>(NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
265
+ NPY_CHAR_ = std::is_signed<char>::value ? NPY_BYTE_ : NPY_UBYTE_,
266
+ };
267
+
268
+ unsigned int PyArray_RUNTIME_VERSION_;
269
+
270
+ struct PyArray_Dims {
271
+ Py_intptr_t *ptr;
272
+ int len;
273
+ };
274
+
275
+ static npy_api &get() {
276
+ PYBIND11_CONSTINIT static gil_safe_call_once_and_store<npy_api> storage;
277
+ return storage.call_once_and_store_result(lookup).get_stored();
278
+ }
279
+
280
+ bool PyArray_Check_(PyObject *obj) const {
281
+ return PyObject_TypeCheck(obj, PyArray_Type_) != 0;
282
+ }
283
+ bool PyArrayDescr_Check_(PyObject *obj) const {
284
+ return PyObject_TypeCheck(obj, PyArrayDescr_Type_) != 0;
285
+ }
286
+
287
+ unsigned int (*PyArray_GetNDArrayCFeatureVersion_)();
288
+ PyObject *(*PyArray_DescrFromType_)(int);
289
+ PyObject *(*PyArray_TypeObjectFromType_)(int);
290
+ PyObject *(*PyArray_NewFromDescr_)(PyTypeObject *,
291
+ PyObject *,
292
+ int,
293
+ Py_intptr_t const *,
294
+ Py_intptr_t const *,
295
+ void *,
296
+ int,
297
+ PyObject *);
298
+ // Unused. Not removed because that affects ABI of the class.
299
+ PyObject *(*PyArray_DescrNewFromType_)(int);
300
+ int (*PyArray_CopyInto_)(PyObject *, PyObject *);
301
+ PyObject *(*PyArray_NewCopy_)(PyObject *, int);
302
+ PyTypeObject *PyArray_Type_;
303
+ PyTypeObject *PyVoidArrType_Type_;
304
+ PyTypeObject *PyArrayDescr_Type_;
305
+ PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
306
+ PyObject *(*PyArray_Scalar_)(void *, PyObject *, PyObject *);
307
+ void (*PyArray_ScalarAsCtype_)(PyObject *, void *);
308
+ PyObject *(*PyArray_FromAny_)(PyObject *, PyObject *, int, int, int, PyObject *);
309
+ int (*PyArray_DescrConverter_)(PyObject *, PyObject **);
310
+ bool (*PyArray_EquivTypes_)(PyObject *, PyObject *);
311
+ PyObject *(*PyArray_Squeeze_)(PyObject *);
312
+ // Unused. Not removed because that affects ABI of the class.
313
+ int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
314
+ PyObject *(*PyArray_Resize_)(PyObject *, PyArray_Dims *, int, int);
315
+ PyObject *(*PyArray_Newshape_)(PyObject *, PyArray_Dims *, int);
316
+ PyObject *(*PyArray_View_)(PyObject *, PyObject *, PyObject *);
317
+
318
+ private:
319
+ enum functions {
320
+ API_PyArray_GetNDArrayCFeatureVersion = 211,
321
+ API_PyArray_Type = 2,
322
+ API_PyArrayDescr_Type = 3,
323
+ API_PyVoidArrType_Type = 39,
324
+ API_PyArray_DescrFromType = 45,
325
+ API_PyArray_TypeObjectFromType = 46,
326
+ API_PyArray_DescrFromScalar = 57,
327
+ API_PyArray_Scalar = 60,
328
+ API_PyArray_ScalarAsCtype = 62,
329
+ API_PyArray_FromAny = 69,
330
+ API_PyArray_Resize = 80,
331
+ // CopyInto was slot 82 and 50 was effectively an alias. NumPy 2 removed 82.
332
+ API_PyArray_CopyInto = 50,
333
+ API_PyArray_NewCopy = 85,
334
+ API_PyArray_NewFromDescr = 94,
335
+ API_PyArray_DescrNewFromType = 96,
336
+ API_PyArray_Newshape = 135,
337
+ API_PyArray_Squeeze = 136,
338
+ API_PyArray_View = 137,
339
+ API_PyArray_DescrConverter = 174,
340
+ API_PyArray_EquivTypes = 182,
341
+ API_PyArray_SetBaseObject = 282
342
+ };
343
+
344
+ static npy_api lookup() {
345
+ module_ m = detail::import_numpy_core_submodule("multiarray");
346
+ auto c = m.attr("_ARRAY_API");
347
+ void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), nullptr);
348
+ if (api_ptr == nullptr) {
349
+ raise_from(PyExc_SystemError, "FAILURE obtaining numpy _ARRAY_API pointer.");
350
+ throw error_already_set();
351
+ }
352
+ npy_api api;
353
+ #define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
354
+ DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
355
+ api.PyArray_RUNTIME_VERSION_ = api.PyArray_GetNDArrayCFeatureVersion_();
356
+ if (api.PyArray_RUNTIME_VERSION_ < 0x7) {
357
+ pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0");
358
+ }
359
+ DECL_NPY_API(PyArray_Type);
360
+ DECL_NPY_API(PyVoidArrType_Type);
361
+ DECL_NPY_API(PyArrayDescr_Type);
362
+ DECL_NPY_API(PyArray_DescrFromType);
363
+ DECL_NPY_API(PyArray_TypeObjectFromType);
364
+ DECL_NPY_API(PyArray_DescrFromScalar);
365
+ DECL_NPY_API(PyArray_Scalar);
366
+ DECL_NPY_API(PyArray_ScalarAsCtype);
367
+ DECL_NPY_API(PyArray_FromAny);
368
+ DECL_NPY_API(PyArray_Resize);
369
+ DECL_NPY_API(PyArray_CopyInto);
370
+ DECL_NPY_API(PyArray_NewCopy);
371
+ DECL_NPY_API(PyArray_NewFromDescr);
372
+ DECL_NPY_API(PyArray_DescrNewFromType);
373
+ DECL_NPY_API(PyArray_Newshape);
374
+ DECL_NPY_API(PyArray_Squeeze);
375
+ DECL_NPY_API(PyArray_View);
376
+ DECL_NPY_API(PyArray_DescrConverter);
377
+ DECL_NPY_API(PyArray_EquivTypes);
378
+ DECL_NPY_API(PyArray_SetBaseObject);
379
+
380
+ #undef DECL_NPY_API
381
+ return api;
382
+ }
383
+ };
384
+
385
/// Trait that is true exactly for specializations of std::complex (cv-qualified
/// std::complex is not matched).
template <class Candidate>
struct is_complex : std::integral_constant<bool, false> {};

template <class Scalar>
struct is_complex<std::complex<Scalar>> : std::integral_constant<bool, true> {};
389
+
390
+ template <typename T, typename = void>
391
+ struct npy_format_descriptor_name;
392
+
393
+ template <typename T>
394
+ struct npy_format_descriptor_name<T, enable_if_t<std::is_integral<T>::value>> {
395
+ static constexpr auto name = const_name<std::is_same<T, bool>::value>(
396
+ const_name("numpy.bool"),
397
+ const_name<std::is_signed<T>::value>("numpy.int", "numpy.uint")
398
+ + const_name<sizeof(T) * 8>());
399
+ };
400
+
401
+ template <typename T>
402
+ struct npy_format_descriptor_name<T, enable_if_t<std::is_floating_point<T>::value>> {
403
+ static constexpr auto name = const_name < std::is_same<T, float>::value
404
+ || std::is_same<T, const float>::value
405
+ || std::is_same<T, double>::value
406
+ || std::is_same<T, const double>::value
407
+ > (const_name("numpy.float") + const_name<sizeof(T) * 8>(),
408
+ const_name("numpy.longdouble"));
409
+ };
410
+
411
+ template <typename T>
412
+ struct npy_format_descriptor_name<T, enable_if_t<is_complex<T>::value>> {
413
+ static constexpr auto name = const_name < std::is_same<typename T::value_type, float>::value
414
+ || std::is_same<typename T::value_type, const float>::value
415
+ || std::is_same<typename T::value_type, double>::value
416
+ || std::is_same<typename T::value_type, const double>::value
417
+ > (const_name("numpy.complex")
418
+ + const_name<sizeof(typename T::value_type) * 16>(),
419
+ const_name("numpy.longcomplex"));
420
+ };
421
+
422
+ template <typename T>
423
+ struct numpy_scalar_info {};
424
+
425
+ #define PYBIND11_NUMPY_SCALAR_IMPL(ctype_, typenum_) \
426
+ template <> \
427
+ struct numpy_scalar_info<ctype_> { \
428
+ static constexpr auto name = npy_format_descriptor_name<ctype_>::name; \
429
+ static constexpr int typenum = npy_api::typenum_##_; \
430
+ }
431
+
432
+ // boolean type
433
+ PYBIND11_NUMPY_SCALAR_IMPL(bool, NPY_BOOL);
434
+
435
+ // character types
436
+ PYBIND11_NUMPY_SCALAR_IMPL(char, NPY_CHAR);
437
+ PYBIND11_NUMPY_SCALAR_IMPL(signed char, NPY_BYTE);
438
+ PYBIND11_NUMPY_SCALAR_IMPL(unsigned char, NPY_UBYTE);
439
+
440
+ // signed integer types
441
+ PYBIND11_NUMPY_SCALAR_IMPL(std::int16_t, NPY_INT16);
442
+ PYBIND11_NUMPY_SCALAR_IMPL(std::int32_t, NPY_INT32);
443
+ PYBIND11_NUMPY_SCALAR_IMPL(std::int64_t, NPY_INT64);
444
+
445
+ // unsigned integer types
446
+ PYBIND11_NUMPY_SCALAR_IMPL(std::uint16_t, NPY_UINT16);
447
+ PYBIND11_NUMPY_SCALAR_IMPL(std::uint32_t, NPY_UINT32);
448
+ PYBIND11_NUMPY_SCALAR_IMPL(std::uint64_t, NPY_UINT64);
449
+
450
+ // floating point types
451
+ PYBIND11_NUMPY_SCALAR_IMPL(float, NPY_FLOAT);
452
+ PYBIND11_NUMPY_SCALAR_IMPL(double, NPY_DOUBLE);
453
+ PYBIND11_NUMPY_SCALAR_IMPL(long double, NPY_LONGDOUBLE);
454
+
455
+ // complex types
456
+ PYBIND11_NUMPY_SCALAR_IMPL(std::complex<float>, NPY_CFLOAT);
457
+ PYBIND11_NUMPY_SCALAR_IMPL(std::complex<double>, NPY_CDOUBLE);
458
+ PYBIND11_NUMPY_SCALAR_IMPL(std::complex<long double>, NPY_CLONGDOUBLE);
459
+
460
+ #undef PYBIND11_NUMPY_SCALAR_IMPL
461
+
462
+ // This table normalizes typenums by mapping NPY_INT_, NPY_LONG, ... to NPY_INT32_, NPY_INT64, ...
463
+ // This is needed to correctly handle situations where multiple typenums map to the same type,
464
+ // e.g. NPY_LONG_ may be equivalent to NPY_INT_ or NPY_LONGLONG_ despite having a different
465
+ // typenum. The normalized typenum should always match the values used in npy_format_descriptor.
466
+ // If you change this code, please review `enum constants` above.
467
+ static constexpr int normalized_dtype_num[npy_api::NPY_VOID_ + 1] = {
468
+ // NPY_BOOL_ =>
469
+ npy_api::NPY_BOOL_,
470
+ // NPY_BYTE_ =>
471
+ npy_api::NPY_BYTE_,
472
+ // NPY_UBYTE_ =>
473
+ npy_api::NPY_UBYTE_,
474
+ // NPY_SHORT_ =>
475
+ npy_api::NPY_INT16_,
476
+ // NPY_USHORT_ =>
477
+ npy_api::NPY_UINT16_,
478
+ // NPY_INT_ =>
479
+ sizeof(int) == sizeof(std::int16_t) ? npy_api::NPY_INT16_
480
+ : sizeof(int) == sizeof(std::int32_t) ? npy_api::NPY_INT32_
481
+ : sizeof(int) == sizeof(std::int64_t) ? npy_api::NPY_INT64_
482
+ : npy_api::NPY_INT_,
483
+ // NPY_UINT_ =>
484
+ sizeof(unsigned int) == sizeof(std::uint16_t) ? npy_api::NPY_UINT16_
485
+ : sizeof(unsigned int) == sizeof(std::uint32_t) ? npy_api::NPY_UINT32_
486
+ : sizeof(unsigned int) == sizeof(std::uint64_t) ? npy_api::NPY_UINT64_
487
+ : npy_api::NPY_UINT_,
488
+ // NPY_LONG_ =>
489
+ sizeof(long) == sizeof(std::int16_t) ? npy_api::NPY_INT16_
490
+ : sizeof(long) == sizeof(std::int32_t) ? npy_api::NPY_INT32_
491
+ : sizeof(long) == sizeof(std::int64_t) ? npy_api::NPY_INT64_
492
+ : npy_api::NPY_LONG_,
493
+ // NPY_ULONG_ =>
494
+ sizeof(unsigned long) == sizeof(std::uint16_t) ? npy_api::NPY_UINT16_
495
+ : sizeof(unsigned long) == sizeof(std::uint32_t) ? npy_api::NPY_UINT32_
496
+ : sizeof(unsigned long) == sizeof(std::uint64_t) ? npy_api::NPY_UINT64_
497
+ : npy_api::NPY_ULONG_,
498
+ // NPY_LONGLONG_ =>
499
+ sizeof(long long) == sizeof(std::int16_t) ? npy_api::NPY_INT16_
500
+ : sizeof(long long) == sizeof(std::int32_t) ? npy_api::NPY_INT32_
501
+ : sizeof(long long) == sizeof(std::int64_t) ? npy_api::NPY_INT64_
502
+ : npy_api::NPY_LONGLONG_,
503
+ // NPY_ULONGLONG_ =>
504
+ sizeof(unsigned long long) == sizeof(std::uint16_t) ? npy_api::NPY_UINT16_
505
+ : sizeof(unsigned long long) == sizeof(std::uint32_t) ? npy_api::NPY_UINT32_
506
+ : sizeof(unsigned long long) == sizeof(std::uint64_t) ? npy_api::NPY_UINT64_
507
+ : npy_api::NPY_ULONGLONG_,
508
+ // NPY_FLOAT_ =>
509
+ npy_api::NPY_FLOAT_,
510
+ // NPY_DOUBLE_ =>
511
+ npy_api::NPY_DOUBLE_,
512
+ // NPY_LONGDOUBLE_ =>
513
+ npy_api::NPY_LONGDOUBLE_,
514
+ // NPY_CFLOAT_ =>
515
+ npy_api::NPY_CFLOAT_,
516
+ // NPY_CDOUBLE_ =>
517
+ npy_api::NPY_CDOUBLE_,
518
+ // NPY_CLONGDOUBLE_ =>
519
+ npy_api::NPY_CLONGDOUBLE_,
520
+ // NPY_OBJECT_ =>
521
+ npy_api::NPY_OBJECT_,
522
+ // NPY_STRING_ =>
523
+ npy_api::NPY_STRING_,
524
+ // NPY_UNICODE_ =>
525
+ npy_api::NPY_UNICODE_,
526
+ // NPY_VOID_ =>
527
+ npy_api::NPY_VOID_,
528
+ };
529
+
530
+ inline PyArray_Proxy *array_proxy(void *ptr) { return reinterpret_cast<PyArray_Proxy *>(ptr); }
531
+
532
+ inline const PyArray_Proxy *array_proxy(const void *ptr) {
533
+ return reinterpret_cast<const PyArray_Proxy *>(ptr);
534
+ }
535
+
536
+ inline PyArrayDescr_Proxy *array_descriptor_proxy(PyObject *ptr) {
537
+ return reinterpret_cast<PyArrayDescr_Proxy *>(ptr);
538
+ }
539
+
540
+ inline const PyArrayDescr_Proxy *array_descriptor_proxy(const PyObject *ptr) {
541
+ return reinterpret_cast<const PyArrayDescr_Proxy *>(ptr);
542
+ }
543
+
544
+ inline const PyArrayDescr1_Proxy *array_descriptor1_proxy(const PyObject *ptr) {
545
+ return reinterpret_cast<const PyArrayDescr1_Proxy *>(ptr);
546
+ }
547
+
548
+ inline const PyArrayDescr2_Proxy *array_descriptor2_proxy(const PyObject *ptr) {
549
+ return reinterpret_cast<const PyArrayDescr2_Proxy *>(ptr);
550
+ }
551
+
552
+ inline bool check_flags(const void *ptr, int flag) {
553
+ return (flag == (array_proxy(ptr)->flags & flag));
554
+ }
555
+
556
/// Type trait: `value` is true exactly for `std::array<U, N>` specializations.
template <class T>
struct is_std_array : std::integral_constant<bool, false> {};

/// Matches `std::array<U, N>` for any element type and extent.
template <class U, size_t N>
struct is_std_array<std::array<U, N>> : std::integral_constant<bool, true> {};
560
+
561
+ template <typename T>
562
+ struct array_info_scalar {
563
+ using type = T;
564
+ static constexpr bool is_array = false;
565
+ static constexpr bool is_empty = false;
566
+ static constexpr auto extents = const_name("");
567
+ static void append_extents(list & /* shape */) {}
568
+ };
569
+ // Computes underlying type and a comma-separated list of extents for array
570
+ // types (any mix of std::array and built-in arrays). An array of char is
571
+ // treated as scalar because it gets special handling.
572
+ template <typename T>
573
+ struct array_info : array_info_scalar<T> {};
574
+ template <typename T, size_t N>
575
+ struct array_info<std::array<T, N>> {
576
+ using type = typename array_info<T>::type;
577
+ static constexpr bool is_array = true;
578
+ static constexpr bool is_empty = (N == 0) || array_info<T>::is_empty;
579
+ static constexpr size_t extent = N;
580
+
581
+ // appends the extents to shape
582
+ static void append_extents(list &shape) {
583
+ shape.append(N);
584
+ array_info<T>::append_extents(shape);
585
+ }
586
+
587
+ static constexpr auto extents = const_name<array_info<T>::is_array>(
588
+ ::pybind11::detail::concat(const_name<N>(), array_info<T>::extents), const_name<N>());
589
+ };
590
+ // For numpy we have special handling for arrays of characters, so we don't include
591
+ // the size in the array extents.
592
+ template <size_t N>
593
+ struct array_info<char[N]> : array_info_scalar<char[N]> {};
594
+ template <size_t N>
595
+ struct array_info<std::array<char, N>> : array_info_scalar<std::array<char, N>> {};
596
+ template <typename T, size_t N>
597
+ struct array_info<T[N]> : array_info<std::array<T, N>> {};
598
+ template <typename T>
599
+ using remove_all_extents_t = typename array_info<T>::type;
600
+
601
+ template <typename T>
602
+ using is_pod_struct
603
+ = all_of<std::is_standard_layout<T>, // since we're accessing directly in memory
604
+ // we need a standard layout type
605
+ #if defined(__GLIBCXX__) \
606
+ && (__GLIBCXX__ < 20150422 || __GLIBCXX__ == 20150426 || __GLIBCXX__ == 20150623 \
607
+ || __GLIBCXX__ == 20150626 || __GLIBCXX__ == 20160803)
608
+ // libstdc++ < 5 (including versions 4.8.5, 4.9.3 and 4.9.4 which were released after
609
+ // 5) don't implement is_trivially_copyable, so approximate it
610
+ std::is_trivially_destructible<T>,
611
+ satisfies_any_of<T, std::has_trivial_copy_constructor, std::has_trivial_copy_assign>,
612
+ #else
613
+ std::is_trivially_copyable<T>,
614
+ #endif
615
+ satisfies_none_of<T,
616
+ std::is_reference,
617
+ std::is_array,
618
+ is_std_array,
619
+ std::is_arithmetic,
620
+ is_complex,
621
+ std::is_enum>>;
622
+
623
+ // Replacement for std::is_pod (deprecated in C++20)
624
+ template <typename T>
625
+ using is_pod = all_of<std::is_standard_layout<T>, std::is_trivial<T>>;
626
+
627
/// Recursion terminator: no indices remain, so they contribute no offset.
template <ssize_t Dim = 0, typename Strides>
ssize_t byte_offset_unsafe(const Strides & /* strides */) {
    return 0;
}

/// Byte offset of the indices (i, index...) starting at axis `Dim`, i.e. the sum of
/// index_k * strides[Dim + k]. No bounds or rank checking is performed.
template <ssize_t Dim = 0, typename Strides, typename... Ix>
ssize_t byte_offset_unsafe(const Strides &strides, ssize_t i, Ix... index) {
    const ssize_t this_axis = strides[Dim] * i;
    return this_axis + byte_offset_unsafe<Dim + 1>(strides, index...);
}
635
+
636
+ /**
637
+ * Proxy class providing unsafe, unchecked const access to array data. This is constructed through
638
+ * the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`. `Dims`
639
+ * will be -1 for dimensions determined at runtime.
640
+ */
641
+ template <typename T, ssize_t Dims>
642
+ class unchecked_reference {
643
+ protected:
644
+ static constexpr bool Dynamic = Dims < 0;
645
+ const unsigned char *data_;
646
+ // Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to
647
+ // make large performance gains on big, nested loops, but requires compile-time dimensions
648
+ conditional_t<Dynamic, const ssize_t *, std::array<ssize_t, (size_t) Dims>> shape_, strides_;
649
+ const ssize_t dims_;
650
+
651
+ friend class pybind11::array;
652
+ // Constructor for compile-time dimensions:
653
+ template <bool Dyn = Dynamic>
654
+ unchecked_reference(const void *data,
655
+ const ssize_t *shape,
656
+ const ssize_t *strides,
657
+ enable_if_t<!Dyn, ssize_t>)
658
+ : data_{reinterpret_cast<const unsigned char *>(data)}, dims_{Dims} {
659
+ for (size_t i = 0; i < (size_t) dims_; i++) {
660
+ shape_[i] = shape[i];
661
+ strides_[i] = strides[i];
662
+ }
663
+ }
664
+ // Constructor for runtime dimensions:
665
+ template <bool Dyn = Dynamic>
666
+ unchecked_reference(const void *data,
667
+ const ssize_t *shape,
668
+ const ssize_t *strides,
669
+ enable_if_t<Dyn, ssize_t> dims)
670
+ : data_{reinterpret_cast<const unsigned char *>(data)}, shape_{shape}, strides_{strides},
671
+ dims_{dims} {}
672
+
673
+ public:
674
+ /**
675
+ * Unchecked const reference access to data at the given indices. For a compile-time known
676
+ * number of dimensions, this requires the correct number of arguments; for run-time
677
+ * dimensionality, this is not checked (and so is up to the caller to use safely).
678
+ */
679
+ template <typename... Ix>
680
+ const T &operator()(Ix... index) const {
681
+ static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic,
682
+ "Invalid number of indices for unchecked array reference");
683
+ return *reinterpret_cast<const T *>(data_
684
+ + byte_offset_unsafe(strides_, ssize_t(index)...));
685
+ }
686
+ /**
687
+ * Unchecked const reference access to data; this operator only participates if the reference
688
+ * is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`.
689
+ */
690
+ template <ssize_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
691
+ const T &operator[](ssize_t index) const {
692
+ return operator()(index);
693
+ }
694
+
695
+ /// Pointer access to the data at the given indices.
696
+ template <typename... Ix>
697
+ const T *data(Ix... ix) const {
698
+ return &operator()(ssize_t(ix)...);
699
+ }
700
+
701
+ /// Returns the item size, i.e. sizeof(T)
702
+ constexpr static ssize_t itemsize() { return sizeof(T); }
703
+
704
+ /// Returns the shape (i.e. size) of dimension `dim`
705
+ ssize_t shape(ssize_t dim) const { return shape_[(size_t) dim]; }
706
+
707
+ /// Returns the number of dimensions of the array
708
+ ssize_t ndim() const { return dims_; }
709
+
710
+ /// Returns the total number of elements in the referenced array, i.e. the product of the
711
+ /// shapes
712
+ template <bool Dyn = Dynamic>
713
+ enable_if_t<!Dyn, ssize_t> size() const {
714
+ return std::accumulate(
715
+ shape_.begin(), shape_.end(), (ssize_t) 1, std::multiplies<ssize_t>());
716
+ }
717
+ template <bool Dyn = Dynamic>
718
+ enable_if_t<Dyn, ssize_t> size() const {
719
+ return std::accumulate(shape_, shape_ + ndim(), (ssize_t) 1, std::multiplies<ssize_t>());
720
+ }
721
+
722
+ /// Returns the total number of bytes used by the referenced data. Note that the actual span
723
+ /// in memory may be larger if the referenced array has non-contiguous strides (e.g. for a
724
+ /// slice).
725
+ ssize_t nbytes() const { return size() * itemsize(); }
726
+ };
727
+
728
+ template <typename T, ssize_t Dims>
729
+ class unchecked_mutable_reference : public unchecked_reference<T, Dims> {
730
+ friend class pybind11::array;
731
+ using ConstBase = unchecked_reference<T, Dims>;
732
+ using ConstBase::ConstBase;
733
+ using ConstBase::Dynamic;
734
+
735
+ public:
736
+ // Bring in const-qualified versions from base class
737
+ using ConstBase::operator();
738
+ using ConstBase::operator[];
739
+
740
+ /// Mutable, unchecked access to data at the given indices.
741
+ template <typename... Ix>
742
+ T &operator()(Ix... index) {
743
+ static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic,
744
+ "Invalid number of indices for unchecked array reference");
745
+ return const_cast<T &>(ConstBase::operator()(index...));
746
+ }
747
+ /**
748
+ * Mutable, unchecked access data at the given index; this operator only participates if the
749
+ * reference is to a 1-dimensional array (or has runtime dimensions). When present, this is
750
+ * exactly equivalent to `obj(index)`.
751
+ */
752
+ template <ssize_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
753
+ T &operator[](ssize_t index) {
754
+ return operator()(index);
755
+ }
756
+
757
+ /// Mutable pointer access to the data at the given indices.
758
+ template <typename... Ix>
759
+ T *mutable_data(Ix... ix) {
760
+ return &operator()(ssize_t(ix)...);
761
+ }
762
+ };
763
+
764
+ template <typename T, ssize_t Dim>
765
+ struct type_caster<unchecked_reference<T, Dim>> {
766
+ static_assert(Dim == 0 && Dim > 0 /* always fail */,
767
+ "unchecked array proxy object is not castable");
768
+ };
769
+ template <typename T, ssize_t Dim>
770
+ struct type_caster<unchecked_mutable_reference<T, Dim>>
771
+ : type_caster<unchecked_reference<T, Dim>> {};
772
+
773
+ template <typename T>
774
+ struct type_caster<numpy_scalar<T>> {
775
+ using value_type = T;
776
+ using type_info = numpy_scalar_info<T>;
777
+
778
+ PYBIND11_TYPE_CASTER(numpy_scalar<T>, type_info::name);
779
+
780
+ static handle &target_type() {
781
+ static handle tp = npy_api::get().PyArray_TypeObjectFromType_(type_info::typenum);
782
+ return tp;
783
+ }
784
+
785
+ static handle &target_dtype() {
786
+ static handle tp = npy_api::get().PyArray_DescrFromType_(type_info::typenum);
787
+ return tp;
788
+ }
789
+
790
+ bool load(handle src, bool) {
791
+ if (isinstance(src, target_type())) {
792
+ npy_api::get().PyArray_ScalarAsCtype_(src.ptr(), &value.value);
793
+ return true;
794
+ }
795
+ return false;
796
+ }
797
+
798
+ static handle cast(numpy_scalar<T> src, return_value_policy, handle) {
799
+ return npy_api::get().PyArray_Scalar_(&src.value, target_dtype().ptr(), nullptr);
800
+ }
801
+ };
802
+
803
+ PYBIND11_NAMESPACE_END(detail)
804
+
805
/// Thin value wrapper that the pybind11 caster maps to a NumPy scalar of T's type number
/// instead of a plain Python number.
template <typename T>
struct numpy_scalar {
    using value_type = T;

    value_type value;

    numpy_scalar() = default;
    explicit numpy_scalar(value_type v) : value(v) {}

    /// Explicitly unwrap the stored value.
    explicit operator value_type() const { return value; }

    /// Replace the stored value.
    numpy_scalar &operator=(value_type v) {
        value = v;
        return *this;
    }

    friend bool operator==(const numpy_scalar &lhs, const numpy_scalar &rhs) {
        return lhs.value == rhs.value;
    }

    friend bool operator!=(const numpy_scalar &lhs, const numpy_scalar &rhs) {
        return !(lhs == rhs);
    }
};

/// Factory that deduces T: `make_scalar(1.5f)` yields `numpy_scalar<float>`.
template <typename T>
numpy_scalar<T> make_scalar(T value) {
    return numpy_scalar<T>{value};
}
831
+
832
+ class dtype : public object {
833
+ public:
834
+ PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_)
835
+
836
+ explicit dtype(const buffer_info &info) {
837
+ dtype descr(_dtype_from_pep3118()(pybind11::str(info.format)));
838
+ // If info.itemsize == 0, use the value calculated from the format string
839
+ m_ptr = descr.strip_padding(info.itemsize != 0 ? info.itemsize : descr.itemsize())
840
+ .release()
841
+ .ptr();
842
+ }
843
+
844
+ explicit dtype(const pybind11::str &format) : dtype(from_args(format)) {}
845
+
846
+ explicit dtype(const std::string &format) : dtype(pybind11::str(format)) {}
847
+
848
+ explicit dtype(const char *format) : dtype(pybind11::str(format)) {}
849
+
850
+ dtype(list names, list formats, list offsets, ssize_t itemsize) {
851
+ dict args;
852
+ args["names"] = std::move(names);
853
+ args["formats"] = std::move(formats);
854
+ args["offsets"] = std::move(offsets);
855
+ args["itemsize"] = pybind11::int_(itemsize);
856
+ m_ptr = from_args(args).release().ptr();
857
+ }
858
+
859
+ /// Return dtype for the given typenum (one of the NPY_TYPES).
860
+ /// https://numpy.org/devdocs/reference/c-api/array.html#c.PyArray_DescrFromType
861
+ explicit dtype(int typenum)
862
+ : object(detail::npy_api::get().PyArray_DescrFromType_(typenum), stolen_t{}) {
863
+ if (m_ptr == nullptr) {
864
+ throw error_already_set();
865
+ }
866
+ }
867
+
868
+ /// This is essentially the same as calling numpy.dtype(args) in Python.
869
+ static dtype from_args(const object &args) {
870
+ PyObject *ptr = nullptr;
871
+ if ((detail::npy_api::get().PyArray_DescrConverter_(args.ptr(), &ptr) == 0) || !ptr) {
872
+ throw error_already_set();
873
+ }
874
+ return reinterpret_steal<dtype>(ptr);
875
+ }
876
+
877
+ /// Return dtype associated with a C++ type.
878
+ template <typename T>
879
+ static dtype of() {
880
+ return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::dtype();
881
+ }
882
+
883
+ /// Return the type number associated with a C++ type.
884
+ /// This is the constexpr equivalent of `dtype::of<T>().num()`.
885
+ template <typename T>
886
+ static constexpr int num_of() {
887
+ return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::value;
888
+ }
889
+
890
+ /// Size of the data type in bytes.
891
+ ssize_t itemsize() const {
892
+ if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
893
+ return detail::array_descriptor1_proxy(m_ptr)->elsize;
894
+ }
895
+ return detail::array_descriptor2_proxy(m_ptr)->elsize;
896
+ }
897
+
898
+ /// Returns true for structured data types.
899
+ bool has_fields() const {
900
+ if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
901
+ return detail::array_descriptor1_proxy(m_ptr)->names != nullptr;
902
+ }
903
+ const auto *proxy = detail::array_descriptor2_proxy(m_ptr);
904
+ if (proxy->type_num < 0 || proxy->type_num >= 2056) {
905
+ return false;
906
+ }
907
+ return proxy->names != nullptr;
908
+ }
909
+
910
+ /// Single-character code for dtype's kind.
911
+ /// For example, floating point types are 'f' and integral types are 'i'.
912
+ char kind() const { return detail::array_descriptor_proxy(m_ptr)->kind; }
913
+
914
+ /// Single-character for dtype's type.
915
+ /// For example, ``float`` is 'f', ``double`` 'd', ``int`` 'i', and ``long`` 'l'.
916
+ char char_() const {
917
+ // Note: The signature, `dtype::char_` follows the naming of NumPy's
918
+ // public Python API (i.e., ``dtype.char``), rather than its internal
919
+ // C API (``PyArray_Descr::type``).
920
+ return detail::array_descriptor_proxy(m_ptr)->type;
921
+ }
922
+
923
+ /// Type number of dtype. Note that different values may be returned for equivalent types,
924
+ /// e.g. even though ``long`` may be equivalent to ``int`` or ``long long``, they still have
925
+ /// different type numbers. Consider using `normalized_num` to avoid this.
926
+ int num() const {
927
+ // Note: The signature, `dtype::num` follows the naming of NumPy's public
928
+ // Python API (i.e., ``dtype.num``), rather than its internal
929
+ // C API (``PyArray_Descr::type_num``).
930
+ return detail::array_descriptor_proxy(m_ptr)->type_num;
931
+ }
932
+
933
+ /// Type number of dtype, normalized to match the return value of `num_of` for equivalent
934
+ /// types. This function can be used to write switch statements that correctly handle
935
+ /// equivalent types with different type numbers.
936
+ int normalized_num() const {
937
+ int value = num();
938
+ if (value >= 0 && value <= detail::npy_api::NPY_VOID_) {
939
+ return detail::normalized_dtype_num[value];
940
+ }
941
+ return value;
942
+ }
943
+
944
+ /// Single character for byteorder
945
+ char byteorder() const { return detail::array_descriptor_proxy(m_ptr)->byteorder; }
946
+
947
+ /// Alignment of the data type
948
+ ssize_t alignment() const {
949
+ if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
950
+ return detail::array_descriptor1_proxy(m_ptr)->alignment;
951
+ }
952
+ return detail::array_descriptor2_proxy(m_ptr)->alignment;
953
+ }
954
+
955
+ /// Flags for the array descriptor
956
+ std::uint64_t flags() const {
957
+ if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
958
+ return (unsigned char) detail::array_descriptor1_proxy(m_ptr)->flags;
959
+ }
960
+ return detail::array_descriptor2_proxy(m_ptr)->flags;
961
+ }
962
+
963
+ private:
964
+ static object &_dtype_from_pep3118() {
965
+ PYBIND11_CONSTINIT static gil_safe_call_once_and_store<object> storage;
966
+ return storage
967
+ .call_once_and_store_result([]() {
968
+ return detail::import_numpy_core_submodule("_internal")
969
+ .attr("_dtype_from_pep3118");
970
+ })
971
+ .get_stored();
972
+ }
973
+
974
+ dtype strip_padding(ssize_t itemsize) {
975
+ // Recursively strip all void fields with empty names that are generated for
976
+ // padding fields (as of NumPy v1.11).
977
+ if (!has_fields()) {
978
+ return *this;
979
+ }
980
+
981
+ struct field_descr {
982
+ pybind11::str name;
983
+ object format;
984
+ pybind11::int_ offset;
985
+ field_descr(pybind11::str &&name, object &&format, pybind11::int_ &&offset)
986
+ : name{std::move(name)}, format{std::move(format)}, offset{std::move(offset)} {};
987
+ };
988
+ auto field_dict = attr("fields").cast<dict>();
989
+ std::vector<field_descr> field_descriptors;
990
+ field_descriptors.reserve(field_dict.size());
991
+
992
+ for (auto field : field_dict.attr("items")()) {
993
+ auto spec = field.cast<tuple>();
994
+ auto name = spec[0].cast<pybind11::str>();
995
+ auto spec_fo = spec[1].cast<tuple>();
996
+ auto format = spec_fo[0].cast<dtype>();
997
+ auto offset = spec_fo[1].cast<pybind11::int_>();
998
+ if ((len(name) == 0u) && format.kind() == 'V') {
999
+ continue;
1000
+ }
1001
+ field_descriptors.emplace_back(
1002
+ std::move(name), format.strip_padding(format.itemsize()), std::move(offset));
1003
+ }
1004
+
1005
+ std::sort(field_descriptors.begin(),
1006
+ field_descriptors.end(),
1007
+ [](const field_descr &a, const field_descr &b) {
1008
+ return a.offset.cast<int>() < b.offset.cast<int>();
1009
+ });
1010
+
1011
+ list names, formats, offsets;
1012
+ for (auto &descr : field_descriptors) {
1013
+ names.append(std::move(descr.name));
1014
+ formats.append(std::move(descr.format));
1015
+ offsets.append(std::move(descr.offset));
1016
+ }
1017
+ return dtype(std::move(names), std::move(formats), std::move(offsets), itemsize);
1018
+ }
1019
+ };
1020
+
1021
+ class array : public buffer {
1022
+ public:
1023
+ PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
1024
+
1025
+ enum {
1026
+ c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_,
1027
+ f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_,
1028
+ forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
1029
+ };
1030
+
1031
+ array() : array(0, static_cast<const double *>(nullptr)) {}
1032
+
1033
+ using ShapeContainer = detail::any_container<ssize_t>;
1034
+ using StridesContainer = detail::any_container<ssize_t>;
1035
+
1036
+ // Constructs an array taking shape/strides from arbitrary container types
1037
+ array(const pybind11::dtype &dt,
1038
+ ShapeContainer shape,
1039
+ StridesContainer strides,
1040
+ const void *ptr = nullptr,
1041
+ handle base = handle()) {
1042
+
1043
+ if (strides->empty()) {
1044
+ *strides = detail::c_strides(*shape, dt.itemsize());
1045
+ }
1046
+
1047
+ auto ndim = shape->size();
1048
+ if (ndim != strides->size()) {
1049
+ pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
1050
+ }
1051
+ auto descr = dt;
1052
+
1053
+ int flags = 0;
1054
+ if (base && ptr) {
1055
+ if (isinstance<array>(base)) {
1056
+ /* Copy flags from base (except ownership bit) */
1057
+ flags = reinterpret_borrow<array>(base).flags()
1058
+ & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
1059
+ } else {
1060
+ /* Writable by default, easy to downgrade later on if needed */
1061
+ flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
1062
+ }
1063
+ }
1064
+
1065
+ auto &api = detail::npy_api::get();
1066
+ auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
1067
+ api.PyArray_Type_,
1068
+ descr.release().ptr(),
1069
+ (int) ndim,
1070
+ // Use reinterpret_cast for PyPy on Windows (remove if fixed, checked on 7.3.1)
1071
+ reinterpret_cast<Py_intptr_t *>(shape->data()),
1072
+ reinterpret_cast<Py_intptr_t *>(strides->data()),
1073
+ const_cast<void *>(ptr),
1074
+ flags,
1075
+ nullptr));
1076
+ if (!tmp) {
1077
+ throw error_already_set();
1078
+ }
1079
+ if (ptr) {
1080
+ if (base) {
1081
+ api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr());
1082
+ } else {
1083
+ tmp = reinterpret_steal<object>(
1084
+ api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
1085
+ }
1086
+ }
1087
+ m_ptr = tmp.release().ptr();
1088
+ }
1089
+
1090
+ array(const pybind11::dtype &dt,
1091
+ ShapeContainer shape,
1092
+ const void *ptr = nullptr,
1093
+ handle base = handle())
1094
+ : array(dt, std::move(shape), {}, ptr, base) {}
1095
+
1096
+ template <typename T,
1097
+ typename
1098
+ = detail::enable_if_t<std::is_integral<T>::value && !std::is_same<bool, T>::value>>
1099
+ array(const pybind11::dtype &dt, T count, const void *ptr = nullptr, handle base = handle())
1100
+ : array(dt, {{count}}, ptr, base) {}
1101
+
1102
+ template <typename T>
1103
+ array(ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle())
1104
+ : array(pybind11::dtype::of<T>(),
1105
+ std::move(shape),
1106
+ std::move(strides),
1107
+ reinterpret_cast<const void *>(ptr),
1108
+ base) {}
1109
+
1110
+ template <typename T>
1111
+ array(ShapeContainer shape, const T *ptr, handle base = handle())
1112
+ : array(std::move(shape), {}, ptr, base) {}
1113
+
1114
+ template <typename T>
1115
+ explicit array(ssize_t count, const T *ptr, handle base = handle())
1116
+ : array({count}, {}, ptr, base) {}
1117
+
1118
+ explicit array(const buffer_info &info, handle base = handle())
1119
+ : array(pybind11::dtype(info), info.shape, info.strides, info.ptr, base) {}
1120
+
1121
+ /// Array descriptor (dtype)
1122
+ pybind11::dtype dtype() const {
1123
+ return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr);
1124
+ }
1125
+
1126
+ /// Total number of elements
1127
+ ssize_t size() const {
1128
+ return std::accumulate(shape(), shape() + ndim(), (ssize_t) 1, std::multiplies<ssize_t>());
1129
+ }
1130
+
1131
+ /// Byte size of a single element
1132
+ ssize_t itemsize() const { return dtype().itemsize(); }
1133
+
1134
+ /// Total number of bytes
1135
+ ssize_t nbytes() const { return size() * itemsize(); }
1136
+
1137
+ /// Number of dimensions
1138
+ ssize_t ndim() const { return detail::array_proxy(m_ptr)->nd; }
1139
+
1140
+ /// Base object
1141
+ object base() const { return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base); }
1142
+
1143
+ /// Dimensions of the array
1144
+ const ssize_t *shape() const { return detail::array_proxy(m_ptr)->dimensions; }
1145
+
1146
+ /// Dimension along a given axis
1147
+ ssize_t shape(ssize_t dim) const {
1148
+ if (dim >= ndim()) {
1149
+ fail_dim_check(dim, "invalid axis");
1150
+ }
1151
+ return shape()[dim];
1152
+ }
1153
+
1154
+ /// Strides of the array
1155
+ const ssize_t *strides() const { return detail::array_proxy(m_ptr)->strides; }
1156
+
1157
+ /// Stride along a given axis
1158
+ ssize_t strides(ssize_t dim) const {
1159
+ if (dim >= ndim()) {
1160
+ fail_dim_check(dim, "invalid axis");
1161
+ }
1162
+ return strides()[dim];
1163
+ }
1164
+
1165
+ /// Return the NumPy array flags
1166
+ int flags() const { return detail::array_proxy(m_ptr)->flags; }
1167
+
1168
+ /// If set, the array is writeable (otherwise the buffer is read-only)
1169
+ bool writeable() const {
1170
+ return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
1171
+ }
1172
+
1173
+ /// If set, the array owns the data (will be freed when the array is deleted)
1174
+ bool owndata() const {
1175
+ return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
1176
+ }
1177
+
1178
+ /// Pointer to the contained data. If index is not provided, points to the
1179
+ /// beginning of the buffer. May throw if the index would lead to out of bounds access.
1180
+ template <typename... Ix>
1181
+ const void *data(Ix... index) const {
1182
+ return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
1183
+ }
1184
+
1185
+ /// Mutable pointer to the contained data. If index is not provided, points to the
1186
+ /// beginning of the buffer. May throw if the index would lead to out of bounds access.
1187
+ /// May throw if the array is not writeable.
1188
+ template <typename... Ix>
1189
+ void *mutable_data(Ix... index) {
1190
+ check_writeable();
1191
+ return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
1192
+ }
1193
+
1194
+ /// Byte offset from beginning of the array to a given index (full or partial).
1195
+ /// May throw if the index would lead to out of bounds access.
1196
+ template <typename... Ix>
1197
+ ssize_t offset_at(Ix... index) const {
1198
+ if ((ssize_t) sizeof...(index) > ndim()) {
1199
+ fail_dim_check(sizeof...(index), "too many indices for an array");
1200
+ }
1201
+ return byte_offset(ssize_t(index)...);
1202
+ }
1203
+
1204
+ ssize_t offset_at() const { return 0; }
1205
+
1206
+ /// Item count from beginning of the array to a given index (full or partial).
1207
+ /// May throw if the index would lead to out of bounds access.
1208
+ template <typename... Ix>
1209
+ ssize_t index_at(Ix... index) const {
1210
+ return offset_at(index...) / itemsize();
1211
+ }
1212
+
1213
+ /**
1214
+ * Returns a proxy object that provides access to the array's data without bounds or
1215
+ * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
1216
+ * care: the array must not be destroyed or reshaped for the duration of the returned object,
1217
+ * and the caller must take care not to access invalid dimensions or dimension indices.
1218
+ */
1219
+ template <typename T, ssize_t Dims = -1>
1220
+ detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() & {
1221
+ if (Dims >= 0 && ndim() != Dims) {
1222
+ throw std::domain_error("array has incorrect number of dimensions: "
1223
+ + std::to_string(ndim()) + "; expected "
1224
+ + std::to_string(Dims));
1225
+ }
1226
+ return detail::unchecked_mutable_reference<T, Dims>(
1227
+ mutable_data(), shape(), strides(), ndim());
1228
+ }
1229
+
1230
+ /**
1231
+ * Returns a proxy object that provides const access to the array's data without bounds or
1232
+ * dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the
1233
+ * underlying array have the `writable` flag. Use with care: the array must not be destroyed
1234
+ * or reshaped for the duration of the returned object, and the caller must take care not to
1235
+ * access invalid dimensions or dimension indices.
1236
+ */
1237
+ template <typename T, ssize_t Dims = -1>
1238
+ detail::unchecked_reference<T, Dims> unchecked() const & {
1239
+ if (Dims >= 0 && ndim() != Dims) {
1240
+ throw std::domain_error("array has incorrect number of dimensions: "
1241
+ + std::to_string(ndim()) + "; expected "
1242
+ + std::to_string(Dims));
1243
+ }
1244
+ return detail::unchecked_reference<T, Dims>(data(), shape(), strides(), ndim());
1245
+ }
1246
+
1247
+ /// Return a new view with all of the dimensions of length 1 removed
1248
+ array squeeze() {
1249
+ auto &api = detail::npy_api::get();
1250
+ return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
1251
+ }
1252
+
1253
+ /// Resize array to given shape
1254
+ /// If refcheck is true and more that one reference exist to this array
1255
+ /// then resize will succeed only if it makes a reshape, i.e. original size doesn't change
1256
+ void resize(ShapeContainer new_shape, bool refcheck = true) {
1257
+ detail::npy_api::PyArray_Dims d
1258
+ = {// Use reinterpret_cast for PyPy on Windows (remove if fixed, checked on 7.3.1)
1259
+ reinterpret_cast<Py_intptr_t *>(new_shape->data()),
1260
+ int(new_shape->size())};
1261
+ // try to resize, set ordering param to -1 cause it's not used anyway
1262
+ auto new_array = reinterpret_steal<object>(
1263
+ detail::npy_api::get().PyArray_Resize_(m_ptr, &d, int(refcheck), -1));
1264
+ if (!new_array) {
1265
+ throw error_already_set();
1266
+ }
1267
+ if (isinstance<array>(new_array)) {
1268
+ *this = std::move(new_array);
1269
+ }
1270
+ }
1271
+
1272
+ /// Optional `order` parameter omitted, to be added as needed.
1273
+ array reshape(ShapeContainer new_shape) {
1274
+ detail::npy_api::PyArray_Dims d
1275
+ = {reinterpret_cast<Py_intptr_t *>(new_shape->data()), int(new_shape->size())};
1276
+ auto new_array
1277
+ = reinterpret_steal<array>(detail::npy_api::get().PyArray_Newshape_(m_ptr, &d, 0));
1278
+ if (!new_array) {
1279
+ throw error_already_set();
1280
+ }
1281
+ return new_array;
1282
+ }
1283
+
1284
+ /// Create a view of an array in a different data type.
1285
+ /// This function may fundamentally reinterpret the data in the array.
1286
+ /// It is the responsibility of the caller to ensure that this is safe.
1287
+ /// Only supports the `dtype` argument, the `type` argument is omitted,
1288
+ /// to be added as needed.
1289
+ array view(const std::string &dtype) {
1290
+ auto &api = detail::npy_api::get();
1291
+ auto new_view = reinterpret_steal<array>(api.PyArray_View_(
1292
+ m_ptr, dtype::from_args(pybind11::str(dtype)).release().ptr(), nullptr));
1293
+ if (!new_view) {
1294
+ throw error_already_set();
1295
+ }
1296
+ return new_view;
1297
+ }
1298
+
1299
+ /// Ensure that the argument is a NumPy array
1300
+ /// In case of an error, nullptr is returned and the Python error is cleared.
1301
+ static array ensure(handle h, int ExtraFlags = 0) {
1302
+ auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
1303
+ if (!result) {
1304
+ PyErr_Clear();
1305
+ }
1306
+ return result;
1307
+ }
1308
+
1309
+ protected:
1310
+ template <typename, typename>
1311
+ friend struct detail::npy_format_descriptor;
1312
+
1313
+ void fail_dim_check(ssize_t dim, const std::string &msg) const {
1314
+ throw index_error(msg + ": " + std::to_string(dim) + " (ndim = " + std::to_string(ndim())
1315
+ + ')');
1316
+ }
1317
+
1318
+ template <typename... Ix>
1319
+ ssize_t byte_offset(Ix... index) const {
1320
+ check_dimensions(index...);
1321
+ return detail::byte_offset_unsafe(strides(), ssize_t(index)...);
1322
+ }
1323
+
1324
+ void check_writeable() const {
1325
+ if (!writeable()) {
1326
+ throw std::domain_error("array is not writeable");
1327
+ }
1328
+ }
1329
+
1330
+ template <typename... Ix>
1331
+ void check_dimensions(Ix... index) const {
1332
+ check_dimensions_impl(ssize_t(0), shape(), ssize_t(index)...);
1333
+ }
1334
+
1335
+ void check_dimensions_impl(ssize_t, const ssize_t *) const {}
1336
+
1337
+ template <typename... Ix>
1338
+ void check_dimensions_impl(ssize_t axis, const ssize_t *shape, ssize_t i, Ix... index) const {
1339
+ if (i >= *shape) {
1340
+ throw index_error(std::string("index ") + std::to_string(i)
1341
+ + " is out of bounds for axis " + std::to_string(axis)
1342
+ + " with size " + std::to_string(*shape));
1343
+ }
1344
+ check_dimensions_impl(axis + 1, shape + 1, index...);
1345
+ }
1346
+
1347
+ /// Create array from any object -- always returns a new reference
1348
+ static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
1349
+ if (ptr == nullptr) {
1350
+ set_error(PyExc_ValueError, "cannot create a pybind11::array from a nullptr");
1351
+ return nullptr;
1352
+ }
1353
+ return detail::npy_api::get().PyArray_FromAny_(
1354
+ ptr, nullptr, 0, 0, detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
1355
+ }
1356
+ };
1357
+
1358
+ template <typename T, int ExtraFlags = array::forcecast>
1359
+ class array_t : public array {
1360
+ private:
1361
+ struct private_ctor {};
1362
+ // Delegating constructor needed when both moving and accessing in the same constructor
1363
+ array_t(private_ctor,
1364
+ ShapeContainer &&shape,
1365
+ StridesContainer &&strides,
1366
+ const T *ptr,
1367
+ handle base)
1368
+ : array(std::move(shape), std::move(strides), ptr, base) {}
1369
+
1370
+ public:
1371
+ static_assert(!detail::array_info<T>::is_array, "Array types cannot be used with array_t");
1372
+
1373
+ using value_type = T;
1374
+
1375
+ array_t() : array(0, static_cast<const T *>(nullptr)) {}
1376
+ array_t(handle h, borrowed_t) : array(h, borrowed_t{}) {}
1377
+ array_t(handle h, stolen_t) : array(h, stolen_t{}) {}
1378
+
1379
+ PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
1380
+ array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen_t{}) {
1381
+ if (!m_ptr) {
1382
+ PyErr_Clear();
1383
+ }
1384
+ if (!is_borrowed) {
1385
+ Py_XDECREF(h.ptr());
1386
+ }
1387
+ }
1388
+
1389
+ // NOLINTNEXTLINE(google-explicit-constructor)
1390
+ array_t(const object &o) : array(raw_array_t(o.ptr()), stolen_t{}) {
1391
+ if (!m_ptr) {
1392
+ throw error_already_set();
1393
+ }
1394
+ }
1395
+
1396
+ explicit array_t(const buffer_info &info, handle base = handle()) : array(info, base) {}
1397
+
1398
+ array_t(ShapeContainer shape,
1399
+ StridesContainer strides,
1400
+ const T *ptr = nullptr,
1401
+ handle base = handle())
1402
+ : array(std::move(shape), std::move(strides), ptr, base) {}
1403
+
1404
+ explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle())
1405
+ : array_t(private_ctor{},
1406
+ std::move(shape),
1407
+ (ExtraFlags & f_style) != 0 ? detail::f_strides(*shape, itemsize())
1408
+ : detail::c_strides(*shape, itemsize()),
1409
+ ptr,
1410
+ base) {}
1411
+
1412
+ explicit array_t(ssize_t count, const T *ptr = nullptr, handle base = handle())
1413
+ : array({count}, {}, ptr, base) {}
1414
+
1415
+ constexpr ssize_t itemsize() const { return sizeof(T); }
1416
+
1417
+ template <typename... Ix>
1418
+ ssize_t index_at(Ix... index) const {
1419
+ return offset_at(index...) / itemsize();
1420
+ }
1421
+
1422
+ template <typename... Ix>
1423
+ const T *data(Ix... index) const {
1424
+ return static_cast<const T *>(array::data(index...));
1425
+ }
1426
+
1427
+ template <typename... Ix>
1428
+ T *mutable_data(Ix... index) {
1429
+ return static_cast<T *>(array::mutable_data(index...));
1430
+ }
1431
+
1432
+ // Reference to element at a given index
1433
+ template <typename... Ix>
1434
+ const T &at(Ix... index) const {
1435
+ if ((ssize_t) sizeof...(index) != ndim()) {
1436
+ fail_dim_check(sizeof...(index), "index dimension mismatch");
1437
+ }
1438
+ return *(static_cast<const T *>(array::data())
1439
+ + byte_offset(ssize_t(index)...) / itemsize());
1440
+ }
1441
+
1442
+ // Mutable reference to element at a given index
1443
+ template <typename... Ix>
1444
+ T &mutable_at(Ix... index) {
1445
+ if ((ssize_t) sizeof...(index) != ndim()) {
1446
+ fail_dim_check(sizeof...(index), "index dimension mismatch");
1447
+ }
1448
+ return *(static_cast<T *>(array::mutable_data())
1449
+ + byte_offset(ssize_t(index)...) / itemsize());
1450
+ }
1451
+
1452
+ /**
1453
+ * Returns a proxy object that provides access to the array's data without bounds or
1454
+ * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
1455
+ * care: the array must not be destroyed or reshaped for the duration of the returned object,
1456
+ * and the caller must take care not to access invalid dimensions or dimension indices.
1457
+ */
1458
+ template <ssize_t Dims = -1>
1459
+ detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() & {
1460
+ return array::mutable_unchecked<T, Dims>();
1461
+ }
1462
+
1463
+ /**
1464
+ * Returns a proxy object that provides const access to the array's data without bounds or
1465
+ * dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the
1466
+ * underlying array have the `writable` flag. Use with care: the array must not be destroyed
1467
+ * or reshaped for the duration of the returned object, and the caller must take care not to
1468
+ * access invalid dimensions or dimension indices.
1469
+ */
1470
+ template <ssize_t Dims = -1>
1471
+ detail::unchecked_reference<T, Dims> unchecked() const & {
1472
+ return array::unchecked<T, Dims>();
1473
+ }
1474
+
1475
+ /// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
1476
+ /// it). In case of an error, nullptr is returned and the Python error is cleared.
1477
+ static array_t ensure(handle h) {
1478
+ auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
1479
+ if (!result) {
1480
+ PyErr_Clear();
1481
+ }
1482
+ return result;
1483
+ }
1484
+
1485
+ static bool check_(handle h) {
1486
+ const auto &api = detail::npy_api::get();
1487
+ return api.PyArray_Check_(h.ptr())
1488
+ && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr,
1489
+ dtype::of<T>().ptr())
1490
+ && detail::check_flags(h.ptr(), ExtraFlags & (array::c_style | array::f_style));
1491
+ }
1492
+
1493
+ protected:
1494
+ /// Create array from any object -- always returns a new reference
1495
+ static PyObject *raw_array_t(PyObject *ptr) {
1496
+ if (ptr == nullptr) {
1497
+ set_error(PyExc_ValueError, "cannot create a pybind11::array_t from a nullptr");
1498
+ return nullptr;
1499
+ }
1500
+ return detail::npy_api::get().PyArray_FromAny_(ptr,
1501
+ dtype::of<T>().release().ptr(),
1502
+ 0,
1503
+ 0,
1504
+ detail::npy_api::NPY_ARRAY_ENSUREARRAY_
1505
+ | ExtraFlags,
1506
+ nullptr);
1507
+ }
1508
+ };
1509
+
1510
+ template <typename T>
1511
+ struct format_descriptor<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
1512
+ static std::string format() {
1513
+ return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::format();
1514
+ }
1515
+ };
1516
+
1517
+ template <size_t N>
1518
+ struct format_descriptor<char[N]> {
1519
+ static std::string format() { return std::to_string(N) + 's'; }
1520
+ };
1521
+ template <size_t N>
1522
+ struct format_descriptor<std::array<char, N>> {
1523
+ static std::string format() { return std::to_string(N) + 's'; }
1524
+ };
1525
+
1526
+ template <typename T>
1527
+ struct format_descriptor<T, detail::enable_if_t<std::is_enum<T>::value>> {
1528
+ static std::string format() {
1529
+ return format_descriptor<
1530
+ typename std::remove_cv<typename std::underlying_type<T>::type>::type>::format();
1531
+ }
1532
+ };
1533
+
1534
+ template <typename T>
1535
+ struct format_descriptor<T, detail::enable_if_t<detail::array_info<T>::is_array>> {
1536
+ static std::string format() {
1537
+ using namespace detail;
1538
+ static constexpr auto extents = const_name("(") + array_info<T>::extents + const_name(")");
1539
+ return extents.text + format_descriptor<remove_all_extents_t<T>>::format();
1540
+ }
1541
+ };
1542
+
1543
+ PYBIND11_NAMESPACE_BEGIN(detail)
1544
+ template <typename T, int ExtraFlags>
1545
+ struct pyobject_caster<array_t<T, ExtraFlags>> {
1546
+ using type = array_t<T, ExtraFlags>;
1547
+
1548
+ bool load(handle src, bool convert) {
1549
+ if (!convert && !type::check_(src)) {
1550
+ return false;
1551
+ }
1552
+ value = type::ensure(src);
1553
+ return static_cast<bool>(value);
1554
+ }
1555
+
1556
+ static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
1557
+ return src.inc_ref();
1558
+ }
1559
+ PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name);
1560
+ };
1561
+
1562
+ template <typename T>
1563
+ struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
1564
+ static bool compare(const buffer_info &b) {
1565
+ return npy_api::get().PyArray_EquivTypes_(dtype::of<T>().ptr(), dtype(b).ptr());
1566
+ }
1567
+ };
1568
+
1569
+ template <typename T>
1570
+ struct npy_format_descriptor<
1571
+ T,
1572
+ enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>>
1573
+ : npy_format_descriptor_name<T> {
1574
+ private:
1575
+ // NB: the order here must match the one in common.h
1576
+ constexpr static const int values[15] = {npy_api::NPY_BOOL_,
1577
+ npy_api::NPY_BYTE_,
1578
+ npy_api::NPY_UBYTE_,
1579
+ npy_api::NPY_INT16_,
1580
+ npy_api::NPY_UINT16_,
1581
+ npy_api::NPY_INT32_,
1582
+ npy_api::NPY_UINT32_,
1583
+ npy_api::NPY_INT64_,
1584
+ npy_api::NPY_UINT64_,
1585
+ npy_api::NPY_FLOAT_,
1586
+ npy_api::NPY_DOUBLE_,
1587
+ npy_api::NPY_LONGDOUBLE_,
1588
+ npy_api::NPY_CFLOAT_,
1589
+ npy_api::NPY_CDOUBLE_,
1590
+ npy_api::NPY_CLONGDOUBLE_};
1591
+
1592
+ public:
1593
+ static constexpr int value = values[detail::is_fmt_numeric<T>::index];
1594
+
1595
+ static pybind11::dtype dtype() { return pybind11::dtype(/*typenum*/ value); }
1596
+ };
1597
+
1598
+ template <typename T>
1599
+ struct npy_format_descriptor<
1600
+ T,
1601
+ enable_if_t<is_same_ignoring_cvref<T, PyObject *>::value
1602
+ || ((std::is_same<T, handle>::value || std::is_same<T, object>::value)
1603
+ && sizeof(T) == sizeof(PyObject *))>> {
1604
+ static constexpr auto name = const_name("numpy.object_");
1605
+
1606
+ static constexpr int value = npy_api::NPY_OBJECT_;
1607
+
1608
+ static pybind11::dtype dtype() { return pybind11::dtype(/*typenum*/ value); }
1609
+ };
1610
+
1611
+ #define PYBIND11_DECL_CHAR_FMT \
1612
+ static constexpr auto name = const_name("S") + const_name<N>(); \
1613
+ static pybind11::dtype dtype() { \
1614
+ return pybind11::dtype(std::string("S") + std::to_string(N)); \
1615
+ }
1616
+ template <size_t N>
1617
+ struct npy_format_descriptor<char[N]> {
1618
+ PYBIND11_DECL_CHAR_FMT
1619
+ };
1620
+ template <size_t N>
1621
+ struct npy_format_descriptor<std::array<char, N>> {
1622
+ PYBIND11_DECL_CHAR_FMT
1623
+ };
1624
+ #undef PYBIND11_DECL_CHAR_FMT
1625
+
1626
+ template <typename T>
1627
+ struct npy_format_descriptor<T, enable_if_t<array_info<T>::is_array>> {
1628
+ private:
1629
+ using base_descr = npy_format_descriptor<typename array_info<T>::type>;
1630
+
1631
+ public:
1632
+ static_assert(!array_info<T>::is_empty, "Zero-sized arrays are not supported");
1633
+
1634
+ static constexpr auto name
1635
+ = const_name("(") + array_info<T>::extents + const_name(")") + base_descr::name;
1636
+ static pybind11::dtype dtype() {
1637
+ list shape;
1638
+ array_info<T>::append_extents(shape);
1639
+ return pybind11::dtype::from_args(
1640
+ pybind11::make_tuple(base_descr::dtype(), std::move(shape)));
1641
+ }
1642
+ };
1643
+
1644
+ template <typename T>
1645
+ struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
1646
+ private:
1647
+ using base_descr = npy_format_descriptor<typename std::underlying_type<T>::type>;
1648
+
1649
+ public:
1650
+ static constexpr auto name = base_descr::name;
1651
+ static pybind11::dtype dtype() { return base_descr::dtype(); }
1652
+ };
1653
+
1654
+ struct field_descriptor {
1655
+ const char *name;
1656
+ ssize_t offset;
1657
+ ssize_t size;
1658
+ std::string format;
1659
+ dtype descr;
1660
+ };
1661
+
1662
+ PYBIND11_NOINLINE void register_structured_dtype(any_container<field_descriptor> fields,
1663
+ const std::type_info &tinfo,
1664
+ ssize_t itemsize,
1665
+ bool (*direct_converter)(PyObject *, void *&)) {
1666
+
1667
+ auto &numpy_internals = get_numpy_internals();
1668
+ if (numpy_internals.get_type_info(tinfo, false)) {
1669
+ pybind11_fail("NumPy: dtype is already registered");
1670
+ }
1671
+
1672
+ // Use ordered fields because order matters as of NumPy 1.14:
1673
+ // https://docs.scipy.org/doc/numpy/release.html#multiple-field-indexing-assignment-of-structured-arrays
1674
+ std::vector<field_descriptor> ordered_fields(std::move(fields));
1675
+ std::sort(
1676
+ ordered_fields.begin(),
1677
+ ordered_fields.end(),
1678
+ [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
1679
+
1680
+ list names, formats, offsets;
1681
+ for (auto &field : ordered_fields) {
1682
+ if (!field.descr) {
1683
+ pybind11_fail(std::string("NumPy: unsupported field dtype: `") + field.name + "` @ "
1684
+ + tinfo.name());
1685
+ }
1686
+ names.append(pybind11::str(field.name));
1687
+ formats.append(field.descr);
1688
+ offsets.append(pybind11::int_(field.offset));
1689
+ }
1690
+ auto *dtype_ptr
1691
+ = pybind11::dtype(std::move(names), std::move(formats), std::move(offsets), itemsize)
1692
+ .release()
1693
+ .ptr();
1694
+
1695
+ // There is an existing bug in NumPy (as of v1.11): trailing bytes are
1696
+ // not encoded explicitly into the format string. This will supposedly
1697
+ // get fixed in v1.12; for further details, see these:
1698
+ // - https://github.com/numpy/numpy/issues/7797
1699
+ // - https://github.com/numpy/numpy/pull/7798
1700
+ // Because of this, we won't use numpy's logic to generate buffer format
1701
+ // strings and will just do it ourselves.
1702
+ ssize_t offset = 0;
1703
+ std::ostringstream oss;
1704
+ // mark the structure as unaligned with '^', because numpy and C++ don't
1705
+ // always agree about alignment (particularly for complex), and we're
1706
+ // explicitly listing all our padding. This depends on none of the fields
1707
+ // overriding the endianness. Putting the ^ in front of individual fields
1708
+ // isn't guaranteed to work due to https://github.com/numpy/numpy/issues/9049
1709
+ oss << "^T{";
1710
+ for (auto &field : ordered_fields) {
1711
+ if (field.offset > offset) {
1712
+ oss << (field.offset - offset) << 'x';
1713
+ }
1714
+ oss << field.format << ':' << field.name << ':';
1715
+ offset = field.offset + field.size;
1716
+ }
1717
+ if (itemsize > offset) {
1718
+ oss << (itemsize - offset) << 'x';
1719
+ }
1720
+ oss << '}';
1721
+ auto format_str = oss.str();
1722
+
1723
+ // Smoke test: verify that NumPy properly parses our buffer format string
1724
+ auto &api = npy_api::get();
1725
+ auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
1726
+ if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr())) {
1727
+ pybind11_fail("NumPy: invalid buffer descriptor!");
1728
+ }
1729
+
1730
+ auto tindex = std::type_index(tinfo);
1731
+ numpy_internals.registered_dtypes[tindex] = {dtype_ptr, std::move(format_str)};
1732
+ with_internals([tindex, &direct_converter](internals &internals) {
1733
+ internals.direct_conversions[tindex].push_back(direct_converter);
1734
+ });
1735
+ }
1736
+
1737
+ template <typename T, typename SFINAE>
1738
+ struct npy_format_descriptor {
1739
+ static_assert(is_pod_struct<T>::value,
1740
+ "Attempt to use a non-POD or unimplemented POD type as a numpy dtype");
1741
+
1742
+ static constexpr auto name = make_caster<T>::name;
1743
+
1744
+ static pybind11::dtype dtype() { return reinterpret_borrow<pybind11::dtype>(dtype_ptr()); }
1745
+
1746
+ static std::string format() {
1747
+ static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
1748
+ return format_str;
1749
+ }
1750
+
1751
+ static void register_dtype(any_container<field_descriptor> fields) {
1752
+ register_structured_dtype(std::move(fields),
1753
+ typeid(typename std::remove_cv<T>::type),
1754
+ sizeof(T),
1755
+ &direct_converter);
1756
+ }
1757
+
1758
+ private:
1759
+ static PyObject *dtype_ptr() {
1760
+ static PyObject *ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
1761
+ return ptr;
1762
+ }
1763
+
1764
+ static bool direct_converter(PyObject *obj, void *&value) {
1765
+ auto &api = npy_api::get();
1766
+ if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_)) {
1767
+ return false;
1768
+ }
1769
+ if (auto descr = reinterpret_steal<object>(api.PyArray_DescrFromScalar_(obj))) {
1770
+ if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) {
1771
+ value = ((PyVoidScalarObject_Proxy *) obj)->obval;
1772
+ return true;
1773
+ }
1774
+ }
1775
+ return false;
1776
+ }
1777
+ };
1778
+
1779
+ #ifdef __CLION_IDE__ // replace heavy macro with dummy code for the IDE (doesn't affect code)
1780
+ # define PYBIND11_NUMPY_DTYPE(Type, ...) ((void) 0)
1781
+ # define PYBIND11_NUMPY_DTYPE_EX(Type, ...) ((void) 0)
1782
+ #else
1783
+
1784
+ # define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \
1785
+ ::pybind11::detail::field_descriptor { \
1786
+ Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \
1787
+ ::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(), \
1788
+ ::pybind11::detail::npy_format_descriptor< \
1789
+ decltype(std::declval<T>().Field)>::dtype() \
1790
+ }
1791
+
1792
+ // Extract name, offset and format descriptor for a struct field
1793
+ # define PYBIND11_FIELD_DESCRIPTOR(T, Field) PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, #Field)
1794
+
1795
+ // The main idea of this macro is borrowed from https://github.com/swansontec/map-macro
1796
+ // (C) William Swanson, Paul Fultz
1797
+ # define PYBIND11_EVAL0(...) __VA_ARGS__
1798
+ # define PYBIND11_EVAL1(...) PYBIND11_EVAL0(PYBIND11_EVAL0(PYBIND11_EVAL0(__VA_ARGS__)))
1799
+ # define PYBIND11_EVAL2(...) PYBIND11_EVAL1(PYBIND11_EVAL1(PYBIND11_EVAL1(__VA_ARGS__)))
1800
+ # define PYBIND11_EVAL3(...) PYBIND11_EVAL2(PYBIND11_EVAL2(PYBIND11_EVAL2(__VA_ARGS__)))
1801
+ # define PYBIND11_EVAL4(...) PYBIND11_EVAL3(PYBIND11_EVAL3(PYBIND11_EVAL3(__VA_ARGS__)))
1802
+ # define PYBIND11_EVAL(...) PYBIND11_EVAL4(PYBIND11_EVAL4(PYBIND11_EVAL4(__VA_ARGS__)))
1803
+ # define PYBIND11_MAP_END(...)
1804
+ # define PYBIND11_MAP_OUT
1805
+ # define PYBIND11_MAP_COMMA ,
1806
+ # define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END
1807
+ # define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT
1808
+ # define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0(test, next, 0)
1809
+ # define PYBIND11_MAP_NEXT(test, next) PYBIND11_MAP_NEXT1(PYBIND11_MAP_GET_END test, next)
1810
+ # if defined(_MSC_VER) \
1811
+ && !defined(__clang__) // MSVC is not as eager to expand macros, hence this workaround
1812
+ # define PYBIND11_MAP_LIST_NEXT1(test, next) \
1813
+ PYBIND11_EVAL0(PYBIND11_MAP_NEXT0(test, PYBIND11_MAP_COMMA next, 0))
1814
+ # else
1815
+ # define PYBIND11_MAP_LIST_NEXT1(test, next) \
1816
+ PYBIND11_MAP_NEXT0(test, PYBIND11_MAP_COMMA next, 0)
1817
+ # endif
1818
+ # define PYBIND11_MAP_LIST_NEXT(test, next) \
1819
+ PYBIND11_MAP_LIST_NEXT1(PYBIND11_MAP_GET_END test, next)
1820
+ # define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \
1821
+ f(t, x) PYBIND11_MAP_LIST_NEXT(peek, PYBIND11_MAP_LIST1)(f, t, peek, __VA_ARGS__)
1822
+ # define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \
1823
+ f(t, x) PYBIND11_MAP_LIST_NEXT(peek, PYBIND11_MAP_LIST0)(f, t, peek, __VA_ARGS__)
1824
+ // PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ...
1825
+ # define PYBIND11_MAP_LIST(f, t, ...) \
1826
+ PYBIND11_EVAL(PYBIND11_MAP_LIST1(f, t, __VA_ARGS__, (), 0))
1827
+
1828
+ # define PYBIND11_NUMPY_DTYPE(Type, ...) \
1829
+ ::pybind11::detail::npy_format_descriptor<Type>::register_dtype( \
1830
+ ::std::vector<::pybind11::detail::field_descriptor>{ \
1831
+ PYBIND11_MAP_LIST(PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
1832
+
1833
+ # if defined(_MSC_VER) && !defined(__clang__)
1834
+ # define PYBIND11_MAP2_LIST_NEXT1(test, next) \
1835
+ PYBIND11_EVAL0(PYBIND11_MAP_NEXT0(test, PYBIND11_MAP_COMMA next, 0))
1836
+ # else
1837
+ # define PYBIND11_MAP2_LIST_NEXT1(test, next) \
1838
+ PYBIND11_MAP_NEXT0(test, PYBIND11_MAP_COMMA next, 0)
1839
+ # endif
1840
+ # define PYBIND11_MAP2_LIST_NEXT(test, next) \
1841
+ PYBIND11_MAP2_LIST_NEXT1(PYBIND11_MAP_GET_END test, next)
1842
+ # define PYBIND11_MAP2_LIST0(f, t, x1, x2, peek, ...) \
1843
+ f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT(peek, PYBIND11_MAP2_LIST1)(f, t, peek, __VA_ARGS__)
1844
+ # define PYBIND11_MAP2_LIST1(f, t, x1, x2, peek, ...) \
1845
+ f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT(peek, PYBIND11_MAP2_LIST0)(f, t, peek, __VA_ARGS__)
1846
+ // PYBIND11_MAP2_LIST(f, t, a1, a2, ...) expands to f(t, a1, a2), f(t, a3, a4), ...
1847
+ # define PYBIND11_MAP2_LIST(f, t, ...) \
1848
+ PYBIND11_EVAL(PYBIND11_MAP2_LIST1(f, t, __VA_ARGS__, (), 0))
1849
+
1850
+ # define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \
1851
+ ::pybind11::detail::npy_format_descriptor<Type>::register_dtype( \
1852
+ ::std::vector<::pybind11::detail::field_descriptor>{ \
1853
+ PYBIND11_MAP2_LIST(PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})
1854
+
1855
+ #endif // __CLION_IDE__
1856
+
1857
+ class common_iterator {
1858
+ public:
1859
+ using container_type = std::vector<ssize_t>;
1860
+ using value_type = container_type::value_type;
1861
+ using size_type = container_type::size_type;
1862
+
1863
+ common_iterator() : m_strides() {}
1864
+
1865
+ common_iterator(void *ptr, const container_type &strides, const container_type &shape)
1866
+ : p_ptr(reinterpret_cast<char *>(ptr)), m_strides(strides.size()) {
1867
+ m_strides.back() = static_cast<value_type>(strides.back());
1868
+ for (size_type i = m_strides.size() - 1; i != 0; --i) {
1869
+ size_type j = i - 1;
1870
+ auto s = static_cast<value_type>(shape[i]);
1871
+ m_strides[j] = strides[j] + m_strides[i] - strides[i] * s;
1872
+ }
1873
+ }
1874
+
1875
+ void increment(size_type dim) { p_ptr += m_strides[dim]; }
1876
+
1877
+ void *data() const { return p_ptr; }
1878
+
1879
+ private:
1880
+ char *p_ptr{nullptr};
1881
+ container_type m_strides;
1882
+ };
1883
+
1884
+ template <size_t N>
1885
+ class multi_array_iterator {
1886
+ public:
1887
+ using container_type = std::vector<ssize_t>;
1888
+
1889
+ multi_array_iterator(const std::array<buffer_info, N> &buffers, const container_type &shape)
1890
+ : m_shape(shape.size()), m_index(shape.size(), 0), m_common_iterator() {
1891
+
1892
+ // Manual copy to avoid conversion warning if using std::copy
1893
+ for (size_t i = 0; i < shape.size(); ++i) {
1894
+ m_shape[i] = shape[i];
1895
+ }
1896
+
1897
+ container_type strides(shape.size());
1898
+ for (size_t i = 0; i < N; ++i) {
1899
+ init_common_iterator(buffers[i], shape, m_common_iterator[i], strides);
1900
+ }
1901
+ }
1902
+
1903
+ multi_array_iterator &operator++() {
1904
+ for (size_t j = m_index.size(); j != 0; --j) {
1905
+ size_t i = j - 1;
1906
+ if (++m_index[i] != m_shape[i]) {
1907
+ increment_common_iterator(i);
1908
+ break;
1909
+ }
1910
+ m_index[i] = 0;
1911
+ }
1912
+ return *this;
1913
+ }
1914
+
1915
+ template <size_t K, class T = void>
1916
+ T *data() const {
1917
+ return reinterpret_cast<T *>(m_common_iterator[K].data());
1918
+ }
1919
+
1920
+ private:
1921
+ using common_iter = common_iterator;
1922
+
1923
+ void init_common_iterator(const buffer_info &buffer,
1924
+ const container_type &shape,
1925
+ common_iter &iterator,
1926
+ container_type &strides) {
1927
+ auto buffer_shape_iter = buffer.shape.rbegin();
1928
+ auto buffer_strides_iter = buffer.strides.rbegin();
1929
+ auto shape_iter = shape.rbegin();
1930
+ auto strides_iter = strides.rbegin();
1931
+
1932
+ while (buffer_shape_iter != buffer.shape.rend()) {
1933
+ if (*shape_iter == *buffer_shape_iter) {
1934
+ *strides_iter = *buffer_strides_iter;
1935
+ } else {
1936
+ *strides_iter = 0;
1937
+ }
1938
+
1939
+ ++buffer_shape_iter;
1940
+ ++buffer_strides_iter;
1941
+ ++shape_iter;
1942
+ ++strides_iter;
1943
+ }
1944
+
1945
+ std::fill(strides_iter, strides.rend(), 0);
1946
+ iterator = common_iter(buffer.ptr, strides, shape);
1947
+ }
1948
+
1949
+ void increment_common_iterator(size_t dim) {
1950
+ for (auto &iter : m_common_iterator) {
1951
+ iter.increment(dim);
1952
+ }
1953
+ }
1954
+
1955
+ container_type m_shape;
1956
+ container_type m_index;
1957
+ std::array<common_iter, N> m_common_iterator;
1958
+ };
1959
+
1960
+ enum class broadcast_trivial { non_trivial, c_trivial, f_trivial };
1961
+
1962
+ // Populates the shape and number of dimensions for the set of buffers. Returns a
1963
+ // broadcast_trivial enum value indicating whether the broadcast is "trivial"--that is, has each
1964
+ // buffer being either a singleton or a full-size, C-contiguous (`c_trivial`) or Fortran-contiguous
1965
+ // (`f_trivial`) storage buffer; returns `non_trivial` otherwise.
1966
+ template <size_t N>
1967
+ broadcast_trivial
1968
+ broadcast(const std::array<buffer_info, N> &buffers, ssize_t &ndim, std::vector<ssize_t> &shape) {
1969
+ ndim = std::accumulate(
1970
+ buffers.begin(), buffers.end(), ssize_t(0), [](ssize_t res, const buffer_info &buf) {
1971
+ return std::max(res, buf.ndim);
1972
+ });
1973
+
1974
+ shape.clear();
1975
+ shape.resize((size_t) ndim, 1);
1976
+
1977
+ // Figure out the output size, and make sure all input arrays conform (i.e. are either size 1
1978
+ // or the full size).
1979
+ for (size_t i = 0; i < N; ++i) {
1980
+ auto res_iter = shape.rbegin();
1981
+ auto end = buffers[i].shape.rend();
1982
+ for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != end;
1983
+ ++shape_iter, ++res_iter) {
1984
+ const auto &dim_size_in = *shape_iter;
1985
+ auto &dim_size_out = *res_iter;
1986
+
1987
+ // Each input dimension can either be 1 or `n`, but `n` values must match across
1988
+ // buffers
1989
+ if (dim_size_out == 1) {
1990
+ dim_size_out = dim_size_in;
1991
+ } else if (dim_size_in != 1 && dim_size_in != dim_size_out) {
1992
+ pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
1993
+ }
1994
+ }
1995
+ }
1996
+
1997
+ bool trivial_broadcast_c = true;
1998
+ bool trivial_broadcast_f = true;
1999
+ for (size_t i = 0; i < N && (trivial_broadcast_c || trivial_broadcast_f); ++i) {
2000
+ if (buffers[i].size == 1) {
2001
+ continue;
2002
+ }
2003
+
2004
+ // Require the same number of dimensions:
2005
+ if (buffers[i].ndim != ndim) {
2006
+ return broadcast_trivial::non_trivial;
2007
+ }
2008
+
2009
+ // Require all dimensions be full-size:
2010
+ if (!std::equal(buffers[i].shape.cbegin(), buffers[i].shape.cend(), shape.cbegin())) {
2011
+ return broadcast_trivial::non_trivial;
2012
+ }
2013
+
2014
+ // Check for C contiguity (but only if previous inputs were also C contiguous)
2015
+ if (trivial_broadcast_c) {
2016
+ ssize_t expect_stride = buffers[i].itemsize;
2017
+ auto end = buffers[i].shape.crend();
2018
+ for (auto shape_iter = buffers[i].shape.crbegin(),
2019
+ stride_iter = buffers[i].strides.crbegin();
2020
+ trivial_broadcast_c && shape_iter != end;
2021
+ ++shape_iter, ++stride_iter) {
2022
+ if (expect_stride == *stride_iter) {
2023
+ expect_stride *= *shape_iter;
2024
+ } else {
2025
+ trivial_broadcast_c = false;
2026
+ }
2027
+ }
2028
+ }
2029
+
2030
+ // Check for Fortran contiguity (if previous inputs were also F contiguous)
2031
+ if (trivial_broadcast_f) {
2032
+ ssize_t expect_stride = buffers[i].itemsize;
2033
+ auto end = buffers[i].shape.cend();
2034
+ for (auto shape_iter = buffers[i].shape.cbegin(),
2035
+ stride_iter = buffers[i].strides.cbegin();
2036
+ trivial_broadcast_f && shape_iter != end;
2037
+ ++shape_iter, ++stride_iter) {
2038
+ if (expect_stride == *stride_iter) {
2039
+ expect_stride *= *shape_iter;
2040
+ } else {
2041
+ trivial_broadcast_f = false;
2042
+ }
2043
+ }
2044
+ }
2045
+ }
2046
+
2047
+ return trivial_broadcast_c ? broadcast_trivial::c_trivial
2048
+ : trivial_broadcast_f ? broadcast_trivial::f_trivial
2049
+ : broadcast_trivial::non_trivial;
2050
+ }
2051
+
2052
+ template <typename T>
2053
+ struct vectorize_arg {
2054
+ static_assert(!std::is_rvalue_reference<T>::value,
2055
+ "Functions with rvalue reference arguments cannot be vectorized");
2056
+ // The wrapped function gets called with this type:
2057
+ using call_type = remove_reference_t<T>;
2058
+ // Is this a vectorized argument?
2059
+ static constexpr bool vectorize
2060
+ = satisfies_any_of<call_type, std::is_arithmetic, is_complex, is_pod>::value
2061
+ && satisfies_none_of<call_type,
2062
+ std::is_pointer,
2063
+ std::is_array,
2064
+ is_std_array,
2065
+ std::is_enum>::value
2066
+ && (!std::is_reference<T>::value
2067
+ || (std::is_lvalue_reference<T>::value && std::is_const<call_type>::value));
2068
+ // Accept this type: an array for vectorized types, otherwise the type as-is:
2069
+ using type = conditional_t<vectorize, array_t<remove_cv_t<call_type>, array::forcecast>, T>;
2070
+ };
2071
+
2072
+ // py::vectorize when a return type is present
2073
+ template <typename Func, typename Return, typename... Args>
2074
+ struct vectorize_returned_array {
2075
+ using Type = array_t<Return>;
2076
+
2077
+ static Type create(broadcast_trivial trivial, const std::vector<ssize_t> &shape) {
2078
+ if (trivial == broadcast_trivial::f_trivial) {
2079
+ return array_t<Return, array::f_style>(shape);
2080
+ }
2081
+ return array_t<Return>(shape);
2082
+ }
2083
+
2084
+ static Return *mutable_data(Type &array) { return array.mutable_data(); }
2085
+
2086
+ static Return call(Func &f, Args &...args) { return f(args...); }
2087
+
2088
+ static void call(Return *out, size_t i, Func &f, Args &...args) { out[i] = f(args...); }
2089
+ };
2090
+
2091
+ // py::vectorize when a return type is not present
2092
+ template <typename Func, typename... Args>
2093
+ struct vectorize_returned_array<Func, void, Args...> {
2094
+ using Type = none;
2095
+
2096
+ static Type create(broadcast_trivial, const std::vector<ssize_t> &) { return none(); }
2097
+
2098
+ static void *mutable_data(Type &) { return nullptr; }
2099
+
2100
+ static detail::void_type call(Func &f, Args &...args) {
2101
+ f(args...);
2102
+ return {};
2103
+ }
2104
+
2105
+ static void call(void *, size_t, Func &f, Args &...args) { f(args...); }
2106
+ };
2107
+
2108
+ template <typename Func, typename Return, typename... Args>
2109
+ struct vectorize_helper {
2110
+
2111
+ // NVCC for some reason breaks if NVectorized is private
2112
+ #ifdef __CUDACC__
2113
+ public:
2114
+ #else
2115
+ private:
2116
+ #endif
2117
+
2118
+ static constexpr size_t N = sizeof...(Args);
2119
+ static constexpr size_t NVectorized = constexpr_sum(vectorize_arg<Args>::vectorize...);
2120
+ static_assert(
2121
+ NVectorized >= 1,
2122
+ "pybind11::vectorize(...) requires a function with at least one vectorizable argument");
2123
+
2124
+ public:
2125
+ template <typename T,
2126
+ // SFINAE to prevent shadowing the copy constructor.
2127
+ typename = detail::enable_if_t<
2128
+ !std::is_same<vectorize_helper, typename std::decay<T>::type>::value>>
2129
+ explicit vectorize_helper(T &&f) : f(std::forward<T>(f)) {}
2130
+
2131
+ object operator()(typename vectorize_arg<Args>::type... args) {
2132
+ return run(args...,
2133
+ make_index_sequence<N>(),
2134
+ select_indices<vectorize_arg<Args>::vectorize...>(),
2135
+ make_index_sequence<NVectorized>());
2136
+ }
2137
+
2138
+ private:
2139
+ remove_reference_t<Func> f;
2140
+
2141
+ // Internal compiler error in MSVC 19.16.27025.1 (Visual Studio 2017 15.9.4), when compiling
2142
+ // with "/permissive-" flag when arg_call_types is manually inlined.
2143
+ using arg_call_types = std::tuple<typename vectorize_arg<Args>::call_type...>;
2144
+ template <size_t Index>
2145
+ using param_n_t = typename std::tuple_element<Index, arg_call_types>::type;
2146
+
2147
+ using returned_array = vectorize_returned_array<Func, Return, Args...>;
2148
+
2149
+ // Runs a vectorized function given arguments tuple and three index sequences:
2150
+ // - Index is the full set of 0 ... (N-1) argument indices;
2151
+ // - VIndex is the subset of argument indices with vectorized parameters, letting us access
2152
+ // vectorized arguments (anything not in this sequence is passed through)
2153
+ // - BIndex is a incremental sequence (beginning at 0) of the same size as VIndex, so that
2154
+ // we can store vectorized buffer_infos in an array (argument VIndex has its buffer at
2155
+ // index BIndex in the array).
2156
+ template <size_t... Index, size_t... VIndex, size_t... BIndex>
2157
+ object run(typename vectorize_arg<Args>::type &...args,
2158
+ index_sequence<Index...> i_seq,
2159
+ index_sequence<VIndex...> vi_seq,
2160
+ index_sequence<BIndex...> bi_seq) {
2161
+
2162
+ // Pointers to values the function was called with; the vectorized ones set here will start
2163
+ // out as array_t<T> pointers, but they will be changed them to T pointers before we make
2164
+ // call the wrapped function. Non-vectorized pointers are left as-is.
2165
+ std::array<void *, N> params{{reinterpret_cast<void *>(&args)...}};
2166
+
2167
+ // The array of `buffer_info`s of vectorized arguments:
2168
+ std::array<buffer_info, NVectorized> buffers{
2169
+ {reinterpret_cast<array *>(params[VIndex])->request()...}};
2170
+
2171
+ /* Determine dimensions parameters of output array */
2172
+ ssize_t nd = 0;
2173
+ std::vector<ssize_t> shape(0);
2174
+ auto trivial = broadcast(buffers, nd, shape);
2175
+ auto ndim = (size_t) nd;
2176
+
2177
+ size_t size
2178
+ = std::accumulate(shape.begin(), shape.end(), (size_t) 1, std::multiplies<size_t>());
2179
+
2180
+ // If all arguments are 0-dimension arrays (i.e. single values) return a plain value (i.e.
2181
+ // not wrapped in an array).
2182
+ if (size == 1 && ndim == 0) {
2183
+ PYBIND11_EXPAND_SIDE_EFFECTS(params[VIndex] = buffers[BIndex].ptr);
2184
+ return cast(
2185
+ returned_array::call(f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...));
2186
+ }
2187
+
2188
+ auto result = returned_array::create(trivial, shape);
2189
+
2190
+ PYBIND11_WARNING_PUSH
2191
+ #ifdef PYBIND11_DETECTED_CLANG_WITH_MISLEADING_CALL_STD_MOVE_EXPLICITLY_WARNING
2192
+ PYBIND11_WARNING_DISABLE_CLANG("-Wreturn-std-move")
2193
+ #endif
2194
+
2195
+ if (size == 0) {
2196
+ return result;
2197
+ }
2198
+
2199
+ /* Call the function */
2200
+ auto *mutable_data = returned_array::mutable_data(result);
2201
+ if (trivial == broadcast_trivial::non_trivial) {
2202
+ apply_broadcast(buffers, params, mutable_data, size, shape, i_seq, vi_seq, bi_seq);
2203
+ } else {
2204
+ apply_trivial(buffers, params, mutable_data, size, i_seq, vi_seq, bi_seq);
2205
+ }
2206
+
2207
+ return result;
2208
+ PYBIND11_WARNING_POP
2209
+ }
2210
+
2211
+ template <size_t... Index, size_t... VIndex, size_t... BIndex>
2212
+ void apply_trivial(std::array<buffer_info, NVectorized> &buffers,
2213
+ std::array<void *, N> &params,
2214
+ Return *out,
2215
+ size_t size,
2216
+ index_sequence<Index...>,
2217
+ index_sequence<VIndex...>,
2218
+ index_sequence<BIndex...>) {
2219
+
2220
+ // Initialize an array of mutable byte references and sizes with references set to the
2221
+ // appropriate pointer in `params`; as we iterate, we'll increment each pointer by its size
2222
+ // (except for singletons, which get an increment of 0).
2223
+ std::array<std::pair<unsigned char *&, const size_t>, NVectorized> vecparams{
2224
+ {std::pair<unsigned char *&, const size_t>(
2225
+ reinterpret_cast<unsigned char *&>(params[VIndex] = buffers[BIndex].ptr),
2226
+ buffers[BIndex].size == 1 ? 0 : sizeof(param_n_t<VIndex>))...}};
2227
+
2228
+ for (size_t i = 0; i < size; ++i) {
2229
+ returned_array::call(
2230
+ out, i, f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...);
2231
+ for (auto &x : vecparams) {
2232
+ x.first += x.second;
2233
+ }
2234
+ }
2235
+ }
2236
+
2237
+ template <size_t... Index, size_t... VIndex, size_t... BIndex>
2238
+ void apply_broadcast(std::array<buffer_info, NVectorized> &buffers,
2239
+ std::array<void *, N> &params,
2240
+ Return *out,
2241
+ size_t size,
2242
+ const std::vector<ssize_t> &output_shape,
2243
+ index_sequence<Index...>,
2244
+ index_sequence<VIndex...>,
2245
+ index_sequence<BIndex...>) {
2246
+
2247
+ multi_array_iterator<NVectorized> input_iter(buffers, output_shape);
2248
+
2249
+ for (size_t i = 0; i < size; ++i, ++input_iter) {
2250
+ PYBIND11_EXPAND_SIDE_EFFECTS((params[VIndex] = input_iter.template data<BIndex>()));
2251
+ returned_array::call(
2252
+ out, i, f, *reinterpret_cast<param_n_t<Index> *>(std::get<Index>(params))...);
2253
+ }
2254
+ }
2255
+ };
2256
+
2257
+ template <typename Func, typename Return, typename... Args>
2258
+ vectorize_helper<Func, Return, Args...> vectorize_extractor(const Func &f, Return (*)(Args...)) {
2259
+ return detail::vectorize_helper<Func, Return, Args...>(f);
2260
+ }
2261
+
2262
+ template <typename T, int Flags>
2263
+ struct handle_type_name<array_t<T, Flags>> {
2264
+ static constexpr auto name
2265
+ = io_name("typing.Annotated[numpy.typing.ArrayLike, ", "numpy.typing.NDArray[")
2266
+ + npy_format_descriptor<T>::name + const_name("]");
2267
+ };
2268
+
2269
+ PYBIND11_NAMESPACE_END(detail)
2270
+
2271
+ // Vanilla pointer vectorizer:
2272
+ template <typename Return, typename... Args>
2273
+ detail::vectorize_helper<Return (*)(Args...), Return, Args...> vectorize(Return (*f)(Args...)) {
2274
+ return detail::vectorize_helper<Return (*)(Args...), Return, Args...>(f);
2275
+ }
2276
+
2277
+ // lambda vectorizer:
2278
+ template <typename Func, detail::enable_if_t<detail::is_lambda<Func>::value, int> = 0>
2279
+ auto vectorize(Func &&f)
2280
+ -> decltype(detail::vectorize_extractor(std::forward<Func>(f),
2281
+ (detail::function_signature_t<Func> *) nullptr)) {
2282
+ return detail::vectorize_extractor(std::forward<Func>(f),
2283
+ (detail::function_signature_t<Func> *) nullptr);
2284
+ }
2285
+
2286
+ // Vectorize a class method (non-const):
2287
+ template <typename Return,
2288
+ typename Class,
2289
+ typename... Args,
2290
+ typename Helper = detail::vectorize_helper<
2291
+ decltype(std::mem_fn(std::declval<Return (Class::*)(Args...)>())),
2292
+ Return,
2293
+ Class *,
2294
+ Args...>>
2295
+ Helper vectorize(Return (Class::*f)(Args...)) {
2296
+ return Helper(std::mem_fn(f));
2297
+ }
2298
+
2299
+ // Vectorize a class method (const):
2300
+ template <typename Return,
2301
+ typename Class,
2302
+ typename... Args,
2303
+ typename Helper = detail::vectorize_helper<
2304
+ decltype(std::mem_fn(std::declval<Return (Class::*)(Args...) const>())),
2305
+ Return,
2306
+ const Class *,
2307
+ Args...>>
2308
+ Helper vectorize(Return (Class::*f)(Args...) const) {
2309
+ return Helper(std::mem_fn(f));
2310
+ }
2311
+
2312
+ PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)