pytensor 3.0.0__tar.gz → 3.0.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytensor-3.0.0/pytensor.egg-info → pytensor-3.0.2}/PKG-INFO +1 -1
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/__init__.py +18 -1
- pytensor-3.0.2/pytensor/_sparse_lazy.py +31 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/_version.py +3 -3
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/basic.py +4 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/compile/builders.py +65 -1
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/compile/sharedvalue.py +4 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/graph/features.py +30 -26
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/elemwise.py +1 -12
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/linalg/solvers.py +15 -1
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/elemwise.py +1 -12
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/compile_ops.py +4 -2
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/cython_support.py +5 -2
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/elemwise.py +0 -119
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/pytorch/dispatch/elemwise.py +29 -28
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/printing.py +28 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/scalar/math.py +5 -1
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/basic.py +34 -51
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/blas.py +4 -8
- pytensor-3.0.2/pytensor/tensor/linalg/_lazy.py +12 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/linalg/constructors.py +1 -2
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/linalg/decomposition/cholesky.py +1 -1
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/linalg/decomposition/eigen.py +1 -1
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/linalg/decomposition/lu.py +2 -3
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/linalg/decomposition/qr.py +4 -4
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/linalg/decomposition/schur.py +2 -3
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/linalg/products.py +28 -34
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/linalg/solvers/general.py +1 -1
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/linalg/solvers/linear_control.py +2 -2
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/linalg/solvers/psd.py +2 -2
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/linalg/solvers/triangular.py +2 -2
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/linalg/solvers/tridiagonal.py +5 -5
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/random/basic.py +3 -6
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/basic.py +225 -190
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/elemwise.py +4 -10
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/linalg/solvers.py +100 -1
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/math.py +195 -237
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/ofg.py +2 -1
- pytensor-3.0.2/pytensor/tensor/rewriting/special.py +99 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/subtensor_lift.py +3 -3
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/signal/conv.py +6 -2
- pytensor-3.0.2/pytensor/tensor/special.py +238 -0
- pytensor-3.0.2/pytensor/tensor/symbolic.py +10 -0
- pytensor-3.0.2/pytensor/tensor/xlogx.py +37 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/utils.py +75 -0
- {pytensor-3.0.0 → pytensor-3.0.2/pytensor.egg-info}/PKG-INFO +1 -1
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor.egg-info/SOURCES.txt +3 -0
- pytensor-3.0.2/tests/test_basic.py +65 -0
- pytensor-3.0.0/pytensor/tensor/rewriting/special.py +0 -175
- pytensor-3.0.0/pytensor/tensor/special.py +0 -819
- pytensor-3.0.0/pytensor/tensor/xlogx.py +0 -66
- pytensor-3.0.0/tests/test_basic.py +0 -34
- {pytensor-3.0.0 → pytensor-3.0.2}/LICENSE.txt +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/MANIFEST.in +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/README.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/.templates/PLACEHOLDER +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/.templates/layout.html +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/.templates/nb-badges.html +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/.templates/rendered_citation.html +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/LICENSE.txt +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/README.md +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/_thumbnails/autodiff/vector_jacobian_product.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/acknowledgement.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/bcast.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/bcast.svg +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/blog.md +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/conf.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/core_development_guide.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/css.inc +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/dev_start_guide.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/environment.yml +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/apply.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/apply.svg +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/apply2.svg +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/creating_a_c_op.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/creating_a_numba_jax_op.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/creating_an_op.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/ctype.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/extending_faq.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/extending_pytensor_solution_1.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/graph_rewriting.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/graphstructures.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/index.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/inplace.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/op.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/other_ops.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/pics/symbolic_graph_opt.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/pics/symbolic_graph_unopt.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/pipeline.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/scan.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/tips.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/type.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/unittest.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/extending/using_params.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/faq.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/gallery/applications/normalizing_flows_in_pytensor.ipynb +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/gallery/autodiff/vector_jacobian_product.ipynb +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/gallery/introduction/pytensor_intro.ipynb +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/gallery/optimize/root.ipynb +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/gallery/page_footer.md +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/gallery/rewrites/graph_rewrites.ipynb +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/gallery/scan/scan_tutorial.ipynb +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/generate_dtype_tensor_table.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/glossary.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/images/Elman_srnn.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/images/PyTensor.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/images/PyTensor_RGB.svg +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/images/PyTensor_logo.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/images/binder.svg +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/images/blocksparse.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/images/colab.svg +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/images/github.svg +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/images/lstm.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/images/lstm_memorycell.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/images/talk2010.gif +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/images/talk2010.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/index.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/install.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/internal/how_to_release.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/internal/index.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/internal/metadocumentation.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/introduction.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/compile/debugmode.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/compile/function.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/compile/index.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/compile/io.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/compile/mode.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/compile/nanguardmode.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/compile/opfromgraph.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/compile/ops.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/compile/profilemode.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/compile/shared.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/config.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/d3viz/examples/d3viz/css/d3-context-menu.css +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/d3viz/examples/d3viz/css/d3viz.css +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/d3viz/examples/d3viz/js/d3-context-menu.js +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/d3viz/examples/d3viz/js/d3.v3.min.js +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/d3viz/examples/d3viz/js/d3viz.js +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/d3viz/examples/d3viz/js/dagre-d3.min.js +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/d3viz/examples/d3viz/js/graphlib-dot.min.js +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/d3viz/examples/mlp.html +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/d3viz/examples/mlp.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/d3viz/examples/mlp2.html +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/d3viz/examples/mlp2.pdf +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/d3viz/examples/mlp2.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/d3viz/examples/ofg.html +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/d3viz/examples/ofg2.html +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/d3viz/index.ipynb +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/d3viz/index.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/d3viz/index_files/index_10_0.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/d3viz/index_files/index_11_0.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/d3viz/index_files/index_24_0.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/d3viz/index_files/index_25_0.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/graph/features.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/graph/fgraph.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/graph/graph.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/graph/index.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/graph/op.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/graph/replace.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/graph/type.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/graph/utils.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/index.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/misc/pkl_utils.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/printing.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/scalar/index.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/scan.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/sparse/index.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/sparse/sandbox.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/tensor/basic.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/tensor/basic_opt.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/tensor/bcast.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/tensor/bcast.svg +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/tensor/elemwise.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/tensor/extra_ops.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/tensor/fft.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/tensor/functional.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/tensor/index.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/tensor/io.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/tensor/linalg.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/tensor/math_opt.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/tensor/nlinalg.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/tensor/optimize.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/tensor/plot_fft.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/tensor/random.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/tensor/slinalg.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/tensor/utils.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/typed_list.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/xtensor/index.md +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/xtensor/linalg.md +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/xtensor/math.md +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/xtensor/module_functions.md +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/xtensor/random.md +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/xtensor/signal.md +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/library/xtensor/type.md +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/links.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/optimizations.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/pylintrc +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/robots.txt +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/troubleshooting.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/adding.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/adding_solution_1.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/aliasing.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/apply.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/apply.svg +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/bcast.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/broadcasting.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/conditions.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/debug_faq.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/dlogistic.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/examples.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/faq_tutorial.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/gradients.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/index.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/loading_and_saving.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/logistic.gp +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/logistic.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/loop.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/loop_solution_1.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/modes.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/modes_solution_1.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/multi_cores.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/nan_tutorial.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/pics/d3viz.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/pics/logreg_pydotprint_predict.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/pics/logreg_pydotprint_prediction.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/pics/logreg_pydotprint_train.png +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/printing_drawing.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/prng.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/profiling.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/profiling_example.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/profiling_example_out.prof +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/shape_info.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/sparse.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/tutorial/symbolic_graphs.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/doc/user_guide.rst +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pyproject.toml +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/bin/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/bin/pytensor_cache.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/breakpoint.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/compile/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/compile/aliasing.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/compile/compiledir.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/compile/compilelock.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/compile/debug/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/compile/debug/debugmode.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/compile/debug/dump.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/compile/debug/monitormode.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/compile/debug/nanguardmode.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/compile/debug/profiling.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/compile/executor.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/compile/io.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/compile/maker.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/compile/mode.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/compile/ops.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/compile/rebuild.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/configdefaults.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/configparser.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/d3viz/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/d3viz/css/d3-context-menu.css +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/d3viz/css/d3viz.css +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/d3viz/d3viz.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/d3viz/formatting.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/d3viz/html/template.html +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/d3viz/js/d3-context-menu.js +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/d3viz/js/d3.v3.min.js +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/d3viz/js/d3viz.js +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/d3viz/js/dagre-d3.min.js +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/d3viz/js/graphlib-dot.min.js +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/gradient.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/graph/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/graph/basic.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/graph/destroyhandler.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/graph/fg.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/graph/null_type.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/graph/op.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/graph/replace.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/graph/rewriting/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/graph/rewriting/basic.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/graph/rewriting/db.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/graph/rewriting/kanren.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/graph/rewriting/unify.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/graph/rewriting/utils.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/graph/traversal.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/graph/type.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/graph/utils.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/ifelse.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/ipython.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/basic.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/c/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/c/basic.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/c/c_code/lazylinker_c.c +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/c/c_code/pytensor_mod_helper.h +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/c/cmodule.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/c/cutils.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/c/cvm.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/c/exceptions.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/c/interface.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/c/lazylinker_c.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/c/op.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/c/params_type.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/c/type.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/basic.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/blas.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/blockwise.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/einsum.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/extra_ops.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/linalg/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/linalg/constructors.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/linalg/decomposition.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/linalg/inverse.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/linalg/products.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/linalg/summary.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/math.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/pad.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/random.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/scalar.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/scan.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/shape.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/signal/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/signal/conv.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/sort.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/sparse.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/subtensor.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/dispatch/tensor_basic.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/linker.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/jax/ops.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/basic.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/blas.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/blockwise.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/einsum.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/extra_ops.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/linalg/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/linalg/decomposition.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/linalg/inverse.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/linalg/products.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/linalg/solvers.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/linalg/summary.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/math.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/pad.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/scalar.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/shape.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/signal/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/signal/conv.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/sort.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/subtensor.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/dispatch/tensor_basic.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/mlx/linker.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/cache.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/basic.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/blockwise.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/extra_ops.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/_LAPACK.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/constructors.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/decomposition/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/decomposition/cholesky.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/decomposition/dispatch.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/decomposition/eigen.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/decomposition/lu.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/decomposition/qr.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/decomposition/qz.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/decomposition/schur.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/inverse.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/solvers/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/solvers/cholesky.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/solvers/dispatch.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/solvers/general.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/solvers/hermitian.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/solvers/linear_control.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/solvers/lu_solve.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/solvers/posdef.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/solvers/symmetric.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/solvers/triangular.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/solvers/tridiagonal.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/solvers/utils.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/summary.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/linalg/utils.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/random.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/scalar.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/scan.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/shape.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/signal/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/signal/conv.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/sort.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/sparse/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/sparse/basic.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/sparse/math.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/sparse/variable.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/string_codegen.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/subtensor.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/tensor_basic.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/typed_list.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/dispatch/vectorize_codegen.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/numba/linker.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/pytorch/dispatch/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/pytorch/dispatch/basic.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/pytorch/dispatch/blas.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/pytorch/dispatch/blockwise.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/pytorch/dispatch/extra_ops.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/pytorch/dispatch/linalg/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/pytorch/dispatch/linalg/decomposition.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/pytorch/dispatch/linalg/inverse.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/pytorch/dispatch/linalg/products.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/pytorch/dispatch/linalg/summary.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/pytorch/dispatch/math.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/pytorch/dispatch/scalar.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/pytorch/dispatch/shape.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/pytorch/dispatch/sort.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/pytorch/dispatch/subtensor.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/pytorch/linker.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/utils.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/link/vm.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/misc/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/misc/check_blas.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/misc/check_blas_many.sh +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/misc/check_duplicate_key.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/misc/elemwise_openmp_speedup.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/misc/elemwise_time_test.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/misc/frozendict.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/misc/may_share_memory.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/misc/ordered_set.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/misc/pkl_utils.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/npy_2_compat.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/py.typed +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/raise_op.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/scalar/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/scalar/basic.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/scalar/c_code/Faddeeva.cc +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/scalar/c_code/Faddeeva.hh +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/scalar/c_code/gamma.c +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/scalar/c_code/incbet.c +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/scalar/loop.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/scalar/sharedvar.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/scan/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/scan/basic.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/scan/checkpoints.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/scan/op.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/scan/rewriting.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/scan/scan_perform.pyx +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/scan/scan_perform_ext.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/scan/utils.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/scan/views.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/sparse/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/sparse/basic.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/sparse/linalg.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/sparse/math.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/sparse/rewriting.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/sparse/sharedvar.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/sparse/type.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/sparse/utils.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/sparse/variable.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/blas_c.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/blas_headers.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/blockwise.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/c_code/alt_blas_common.h +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/c_code/alt_blas_template.c +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/c_code/dimshuffle.c +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/einsum.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/elemwise.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/elemwise_cgen.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/exceptions.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/extra_ops.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/fft.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/fourier.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/functional.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/interpolate.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/linalg/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/linalg/decomposition/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/linalg/decomposition/svd.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/linalg/dtype_utils.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/linalg/inverse.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/linalg/solvers/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/linalg/solvers/core.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/linalg/solvers/lstsq.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/linalg/summary.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/math.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/nlinalg.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/optimize.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/pad.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/random/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/random/op.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/random/rewriting/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/random/rewriting/basic.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/random/rewriting/jax.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/random/rewriting/numba.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/random/type.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/random/utils.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/random/variable.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/reshape.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/blas.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/blas_c.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/blockwise.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/einsum.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/extra_ops.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/jax.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/linalg/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/linalg/decomposition.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/linalg/inverse.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/linalg/products.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/linalg/summary.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/linalg/utils.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/numba.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/optimize.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/reshape.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/shape.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/subtensor.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/rewriting/uncanonicalize.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/shape.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/sharedvar.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/signal/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/slinalg.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/sort.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/subtensor.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/type.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/type_other.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/utils.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/var.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/tensor/variable.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/typed_list/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/typed_list/basic.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/typed_list/rewriting.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/typed_list/type.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/basic.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/indexing.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/linalg.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/math.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/random/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/random/basic.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/random/type.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/random/variable.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/reduction.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/rewriting/__init__.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/rewriting/basic.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/rewriting/indexing.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/rewriting/math.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/rewriting/reduction.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/rewriting/shape.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/rewriting/utils.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/rewriting/vectorization.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/shape.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/signal.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/type.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor/xtensor/vectorization.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor.egg-info/dependency_links.txt +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor.egg-info/entry_points.txt +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor.egg-info/requires.txt +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/pytensor.egg-info/top_level.txt +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/scripts/mypy-failing.txt +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/scripts/slowest_tests/update-slowest-times-issue.sh +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/setup.cfg +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/setup.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/tests/link/c/c_code/test_cenum.h +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/tests/link/c/c_code/test_quadratic_function.c +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/tests/test_breakpoint.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/tests/test_config.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/tests/test_gradient.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/tests/test_ifelse.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/tests/test_printing.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/tests/test_raise_op.py +0 -0
- {pytensor-3.0.0 → pytensor-3.0.2}/tests/test_rop.py +0 -0
|
@@ -1,3 +1,6 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
|
|
3
|
+
|
|
1
4
|
__docformat__ = "restructuredtext en"
|
|
2
5
|
|
|
3
6
|
|
|
@@ -14,7 +17,6 @@ from pytensor.configdefaults import config
|
|
|
14
17
|
|
|
15
18
|
# isort: off
|
|
16
19
|
from pytensor import tensor
|
|
17
|
-
from pytensor import sparse
|
|
18
20
|
from pytensor.compile import (
|
|
19
21
|
In,
|
|
20
22
|
Mode,
|
|
@@ -32,9 +34,24 @@ from pytensor.scan.basic import scan
|
|
|
32
34
|
from pytensor.scan.views import map
|
|
33
35
|
from pytensor.compile.builders import OpFromGraph
|
|
34
36
|
from pytensor.link.jax.ops import wrap_jax
|
|
37
|
+
from pytensor import _sparse_lazy
|
|
35
38
|
# isort: on
|
|
36
39
|
|
|
37
40
|
|
|
41
|
+
def __getattr__(name):
|
|
42
|
+
if name == "sparse":
|
|
43
|
+
# During pytensor.sparse's own import, submodules may do
|
|
44
|
+
# `import pytensor.sparse.X as Y` which probes pytensor.sparse via
|
|
45
|
+
# getattr before the parent attribute has been set. Return the
|
|
46
|
+
# partially-loaded module from sys.modules to avoid re-entry.
|
|
47
|
+
if "pytensor.sparse" in sys.modules:
|
|
48
|
+
return sys.modules["pytensor.sparse"]
|
|
49
|
+
import pytensor.sparse as sparse
|
|
50
|
+
|
|
51
|
+
return sparse
|
|
52
|
+
raise AttributeError(f"module 'pytensor' has no attribute {name!r}")
|
|
53
|
+
|
|
54
|
+
|
|
38
55
|
# Some config variables are registered by submodules. Only after all those
|
|
39
56
|
# imports were executed, we can warn about remaining flags provided by the user
|
|
40
57
|
# through PYTENSOR_FLAGS.
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""Lazy registration of scipy.sparse handlers on pytensor's dispatchers.
|
|
2
|
+
|
|
3
|
+
Imported by `pytensor/__init__.py` so `pytensor.sparse` doesn't have to be
|
|
4
|
+
loaded eagerly at startup. The fallbacks match `type(x).__module__` against
|
|
5
|
+
`scipy.sparse` as a string, so this module itself doesn't import scipy.sparse;
|
|
6
|
+
the real handlers (and their scipy.sparse dependency) are pulled in only when
|
|
7
|
+
an actual scipy.sparse value is passed to `as_symbolic` / `shared`.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from pytensor.basic import _as_symbolic
|
|
11
|
+
from pytensor.compile.sharedvalue import shared_constructor
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _lazy_as_symbolic_sparse(x):
|
|
15
|
+
if type(x).__module__.startswith("scipy.sparse"):
|
|
16
|
+
from pytensor.sparse.basic import as_symbolic_sparse
|
|
17
|
+
|
|
18
|
+
return as_symbolic_sparse
|
|
19
|
+
return None
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _lazy_shared_sparse(x):
|
|
23
|
+
if type(x).__module__.startswith("scipy.sparse"):
|
|
24
|
+
from pytensor.sparse.sharedvar import sparse_constructor
|
|
25
|
+
|
|
26
|
+
return sparse_constructor
|
|
27
|
+
return None
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
_as_symbolic.register_lazy(_lazy_as_symbolic_sparse) # type: ignore[attr-defined]
|
|
31
|
+
shared_constructor.register_lazy(_lazy_shared_sparse) # type: ignore[attr-defined]
|
|
@@ -8,11 +8,11 @@ import json
|
|
|
8
8
|
|
|
9
9
|
version_json = '''
|
|
10
10
|
{
|
|
11
|
-
"date": "2026-05-
|
|
11
|
+
"date": "2026-05-12T11:06:13+0200",
|
|
12
12
|
"dirty": false,
|
|
13
13
|
"error": null,
|
|
14
|
-
"full-revisionid": "
|
|
15
|
-
"version": "3.0.
|
|
14
|
+
"full-revisionid": "a3e786e2a595d6b686b43686cc6587794b0c286d",
|
|
15
|
+
"version": "3.0.2"
|
|
16
16
|
}
|
|
17
17
|
''' # END VERSION_JSON
|
|
18
18
|
|
|
@@ -2,6 +2,7 @@ from functools import singledispatch
|
|
|
2
2
|
from typing import Any
|
|
3
3
|
|
|
4
4
|
from pytensor.graph import Variable
|
|
5
|
+
from pytensor.utils import add_lazy_dispatcher
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
def as_symbolic(x: Any, name: str | None = None, **kwargs) -> Variable:
|
|
@@ -40,3 +41,6 @@ def _as_symbolic(x: Any, **kwargs) -> Variable:
|
|
|
40
41
|
from pytensor.tensor import as_tensor_variable
|
|
41
42
|
|
|
42
43
|
return as_tensor_variable(x, **kwargs)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
add_lazy_dispatcher(_as_symbolic)
|
|
@@ -11,6 +11,7 @@ from itertools import chain
|
|
|
11
11
|
from typing import cast
|
|
12
12
|
|
|
13
13
|
from pytensor.compile.maker import function
|
|
14
|
+
from pytensor.compile.mode import get_mode
|
|
14
15
|
from pytensor.compile.rebuild import rebuild_collect_shared
|
|
15
16
|
from pytensor.compile.sharedvalue import SharedVariable
|
|
16
17
|
from pytensor.gradient import DisconnectedType, disconnected_type, grad, pushforward
|
|
@@ -917,7 +918,9 @@ class OpFromGraph(Op, HasInnerGraph):
|
|
|
917
918
|
if getattr(self, "_fn", None) is not None:
|
|
918
919
|
return self._fn
|
|
919
920
|
|
|
920
|
-
|
|
921
|
+
kwargs = self.kwargs.copy()
|
|
922
|
+
mode = get_mode(kwargs.pop("mode", None)).excluding("symbolic_op_recognition")
|
|
923
|
+
self._fn = function(self.inner_inputs, self.inner_outputs, mode=mode, **kwargs)
|
|
921
924
|
self._fn.trust_input = True
|
|
922
925
|
|
|
923
926
|
return self._fn
|
|
@@ -940,3 +943,64 @@ class OpFromGraph(Op, HasInnerGraph):
|
|
|
940
943
|
# zip strict not specified because we are in a hot loop
|
|
941
944
|
for output, variable in zip(outputs, variables):
|
|
942
945
|
output[0] = variable
|
|
946
|
+
|
|
947
|
+
|
|
948
|
+
class SymbolicOp(OpFromGraph):
|
|
949
|
+
r"""OpFromGraph subclass that builds the inner graph from input types.
|
|
950
|
+
|
|
951
|
+
Subclasses define the forward graph via :meth:`build_inner_graph` and
|
|
952
|
+
optionally override :meth:`pullback` / :meth:`pushforward`.
|
|
953
|
+
|
|
954
|
+
Override :meth:`filter_inputs` to coerce raw arguments (e.g. Python
|
|
955
|
+
scalars) into typed Variables at call sites.
|
|
956
|
+
|
|
957
|
+
Set the class attribute ``inline`` to control whether the inner graph is
|
|
958
|
+
inlined during compilation (default ``False``).
|
|
959
|
+
"""
|
|
960
|
+
|
|
961
|
+
inline: bool = False
|
|
962
|
+
|
|
963
|
+
def __init_subclass__(cls, **kwargs):
|
|
964
|
+
super().__init_subclass__(**kwargs)
|
|
965
|
+
if "__props__" in cls.__dict__:
|
|
966
|
+
# MetaType installs props-only __hash__ and __eq__ which ignores the inner graph
|
|
967
|
+
# override with fgraph-aware version
|
|
968
|
+
cls.__hash__ = OpFromGraph.__hash__
|
|
969
|
+
cls.__eq__ = OpFromGraph.__eq__
|
|
970
|
+
|
|
971
|
+
@staticmethod
|
|
972
|
+
def filter_inputs(*inputs):
|
|
973
|
+
return inputs
|
|
974
|
+
|
|
975
|
+
def build_inner_graph(self, *inputs) -> list[Variable]:
|
|
976
|
+
raise NotImplementedError
|
|
977
|
+
|
|
978
|
+
def __init__(self, input_types=None, **kwargs):
|
|
979
|
+
"""Construct op for the given input Types.
|
|
980
|
+
|
|
981
|
+
When input_types is None, construction is deferred until the first
|
|
982
|
+
__call__, which inspects the actual input types and builds the graph.
|
|
983
|
+
"""
|
|
984
|
+
for prop in getattr(type(self), "__props__", ()):
|
|
985
|
+
if prop in kwargs:
|
|
986
|
+
setattr(self, prop, kwargs.pop(prop))
|
|
987
|
+
self._init_kwargs = kwargs
|
|
988
|
+
if input_types is not None:
|
|
989
|
+
kwargs.setdefault("inline", type(self).inline)
|
|
990
|
+
kwargs.setdefault("strict", True)
|
|
991
|
+
dummy_inputs = [t() for t in input_types]
|
|
992
|
+
outputs = self.build_inner_graph(*dummy_inputs)
|
|
993
|
+
super().__init__(dummy_inputs, outputs, **kwargs)
|
|
994
|
+
|
|
995
|
+
def __call__(self, *inputs, **kwargs):
|
|
996
|
+
inputs = self.filter_inputs(*inputs)
|
|
997
|
+
input_types = tuple(inp.type for inp in inputs)
|
|
998
|
+
|
|
999
|
+
if hasattr(self, "fgraph") and input_types == tuple(self.input_types):
|
|
1000
|
+
return super().__call__(*inputs, **kwargs)
|
|
1001
|
+
|
|
1002
|
+
init_kwargs = dict(self._init_kwargs)
|
|
1003
|
+
for prop in getattr(type(self), "__props__", ()):
|
|
1004
|
+
init_kwargs[prop] = getattr(self, prop)
|
|
1005
|
+
op = type(self)(input_types=list(input_types), **init_kwargs)
|
|
1006
|
+
return super(SymbolicOp, op).__call__(*inputs, **kwargs)
|
|
@@ -10,6 +10,7 @@ from pytensor.graph.basic import Variable
|
|
|
10
10
|
from pytensor.graph.utils import add_tag_trace
|
|
11
11
|
from pytensor.link.basic import Container
|
|
12
12
|
from pytensor.link.c.type import generic
|
|
13
|
+
from pytensor.utils import add_lazy_dispatcher
|
|
13
14
|
|
|
14
15
|
|
|
15
16
|
if TYPE_CHECKING:
|
|
@@ -223,3 +224,6 @@ def shared_constructor(value, name=None, strict=False, allow_downcast=None, **kw
|
|
|
223
224
|
allow_downcast=allow_downcast,
|
|
224
225
|
name=name,
|
|
225
226
|
)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
add_lazy_dispatcher(shared_constructor)
|
|
@@ -510,8 +510,8 @@ class FullHistory(Feature):
|
|
|
510
510
|
from pytensor.graph.features import FullHistory
|
|
511
511
|
from pytensor.graph.rewriting.utils import rewrite_graph
|
|
512
512
|
|
|
513
|
-
x = pt.
|
|
514
|
-
out = pt.log(pt.exp(x) / pt.sum(pt.exp(x)))
|
|
513
|
+
x = pt.vector("x")
|
|
514
|
+
out = pt.log(pt.exp(x) / pt.sum(pt.exp(x), keepdims=True))
|
|
515
515
|
|
|
516
516
|
fg = FunctionGraph(outputs=[out])
|
|
517
517
|
history = FullHistory()
|
|
@@ -528,22 +528,24 @@ class FullHistory(Feature):
|
|
|
528
528
|
pytensor.dprint(history.next())
|
|
529
529
|
|
|
530
530
|
.. testoutput::
|
|
531
|
-
Log [id A]
|
|
532
|
-
└─ True_div [id B]
|
|
533
|
-
├─ Exp [id C]
|
|
531
|
+
Log [id A] 5
|
|
532
|
+
└─ True_div [id B] 4
|
|
533
|
+
├─ Exp [id C] 3
|
|
534
534
|
│ └─ x [id D]
|
|
535
|
-
└─
|
|
536
|
-
└─
|
|
537
|
-
└─
|
|
535
|
+
└─ ExpandDims{axis=0} [id E] 2
|
|
536
|
+
└─ Sum{axes=None} [id F] 1
|
|
537
|
+
└─ Exp [id G] 0
|
|
538
|
+
└─ x [id D]
|
|
538
539
|
>> MergeOptimizer
|
|
539
|
-
Log [id A]
|
|
540
|
-
└─ True_div [id B]
|
|
540
|
+
Log [id A] 4
|
|
541
|
+
└─ True_div [id B] 3
|
|
541
542
|
├─ Exp [id C] 0
|
|
542
543
|
│ └─ x [id D]
|
|
543
|
-
└─
|
|
544
|
-
└─
|
|
545
|
-
└─
|
|
546
|
-
|
|
544
|
+
└─ ExpandDims{axis=0} [id E] 2
|
|
545
|
+
└─ Sum{axes=None} [id F] 1
|
|
546
|
+
└─ Exp [id C] 0
|
|
547
|
+
└─ ···
|
|
548
|
+
>> local_softmax_stabilize
|
|
547
549
|
Log [id A] 1
|
|
548
550
|
└─ Softmax{axis=None} [id B] 0
|
|
549
551
|
└─ x [id C]
|
|
@@ -564,22 +566,24 @@ class FullHistory(Feature):
|
|
|
564
566
|
Log [id A] 1
|
|
565
567
|
└─ Softmax{axis=None} [id B] 0
|
|
566
568
|
└─ x [id C]
|
|
567
|
-
>>
|
|
568
|
-
Log [id A]
|
|
569
|
-
└─ True_div [id B]
|
|
569
|
+
>> local_softmax_stabilize
|
|
570
|
+
Log [id A] 4
|
|
571
|
+
└─ True_div [id B] 3
|
|
570
572
|
├─ Exp [id C] 0
|
|
571
573
|
│ └─ x [id D]
|
|
572
|
-
└─
|
|
573
|
-
└─
|
|
574
|
-
└─
|
|
574
|
+
└─ ExpandDims{axis=0} [id E] 2
|
|
575
|
+
└─ Sum{axes=None} [id F] 1
|
|
576
|
+
└─ Exp [id C] 0
|
|
577
|
+
└─ ···
|
|
575
578
|
>> MergeOptimizer
|
|
576
|
-
Log [id A]
|
|
577
|
-
└─ True_div [id B]
|
|
578
|
-
├─ Exp [id C]
|
|
579
|
+
Log [id A] 5
|
|
580
|
+
└─ True_div [id B] 4
|
|
581
|
+
├─ Exp [id C] 3
|
|
579
582
|
│ └─ x [id D]
|
|
580
|
-
└─
|
|
581
|
-
└─
|
|
582
|
-
└─
|
|
583
|
+
└─ ExpandDims{axis=0} [id E] 2
|
|
584
|
+
└─ Sum{axes=None} [id F] 1
|
|
585
|
+
└─ Exp [id G] 0
|
|
586
|
+
└─ x [id D]
|
|
583
587
|
|
|
584
588
|
|
|
585
589
|
.. testcode::
|
|
@@ -3,7 +3,7 @@ import jax.numpy as jnp
|
|
|
3
3
|
|
|
4
4
|
from pytensor.link.jax.dispatch.basic import jax_funcify
|
|
5
5
|
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
|
|
6
|
-
from pytensor.tensor.special import LogSoftmax, Softmax
|
|
6
|
+
from pytensor.tensor.special import LogSoftmax, Softmax
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
@jax_funcify.register(Elemwise)
|
|
@@ -94,17 +94,6 @@ def jax_funcify_Softmax(op, **kwargs):
|
|
|
94
94
|
return softmax
|
|
95
95
|
|
|
96
96
|
|
|
97
|
-
@jax_funcify.register(SoftmaxGrad)
|
|
98
|
-
def jax_funcify_SoftmaxGrad(op, **kwargs):
|
|
99
|
-
axis = op.axis
|
|
100
|
-
|
|
101
|
-
def softmax_grad(dy, sm):
|
|
102
|
-
dy_times_sm = dy * sm
|
|
103
|
-
return dy_times_sm - jnp.sum(dy_times_sm, axis=axis, keepdims=True) * sm
|
|
104
|
-
|
|
105
|
-
return softmax_grad
|
|
106
|
-
|
|
107
|
-
|
|
108
97
|
@jax_funcify.register(LogSoftmax)
|
|
109
98
|
def jax_funcify_LogSoftmax(op, **kwargs):
|
|
110
99
|
axis = op.axis
|
|
@@ -82,8 +82,22 @@ def jax_funcify_ChoSolve(op, **kwargs):
|
|
|
82
82
|
|
|
83
83
|
|
|
84
84
|
@jax_funcify.register(SolveSylvester)
|
|
85
|
-
def
|
|
85
|
+
def jax_funcify_SolveSylvester(op, **kwargs):
|
|
86
|
+
@jax.custom_vjp
|
|
86
87
|
def solve_sylvester(a, b, c):
|
|
87
88
|
return jax.scipy.linalg.solve_sylvester(a, b, c)
|
|
88
89
|
|
|
90
|
+
def _fwd(a, b, c):
|
|
91
|
+
x = jax.scipy.linalg.solve_sylvester(a, b, c)
|
|
92
|
+
return x, (a, b, x)
|
|
93
|
+
|
|
94
|
+
def _bwd(res, dx):
|
|
95
|
+
a, b, x = res
|
|
96
|
+
dc = jax.scipy.linalg.solve_sylvester(a.conj().T, b.conj().T, dx)
|
|
97
|
+
da = -dc @ x.conj().T
|
|
98
|
+
db = -x.conj().T @ dc
|
|
99
|
+
return da, db, dc
|
|
100
|
+
|
|
101
|
+
solve_sylvester.defvjp(_fwd, _bwd)
|
|
102
|
+
|
|
89
103
|
return solve_sylvester
|
|
@@ -14,7 +14,7 @@ from pytensor.scalar.basic import (
|
|
|
14
14
|
Mul,
|
|
15
15
|
)
|
|
16
16
|
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
|
|
17
|
-
from pytensor.tensor.special import LogSoftmax, Softmax
|
|
17
|
+
from pytensor.tensor.special import LogSoftmax, Softmax
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
@mlx_funcify.register(DimShuffle)
|
|
@@ -105,17 +105,6 @@ def mlx_funcify_Softmax(op, **kwargs):
|
|
|
105
105
|
return softmax
|
|
106
106
|
|
|
107
107
|
|
|
108
|
-
@mlx_funcify.register(SoftmaxGrad)
|
|
109
|
-
def mlx_funcify_SoftmaxGrad(op, **kwargs):
|
|
110
|
-
axis = op.axis
|
|
111
|
-
|
|
112
|
-
def softmax_grad(dy, sm):
|
|
113
|
-
dy_times_sm = dy * sm
|
|
114
|
-
return dy_times_sm - mx.sum(dy_times_sm, axis=axis, keepdims=True) * sm
|
|
115
|
-
|
|
116
|
-
return softmax_grad
|
|
117
|
-
|
|
118
|
-
|
|
119
108
|
@mlx_funcify.register(LogSoftmax)
|
|
120
109
|
def mlx_funcify_LogSoftmax(op, **kwargs):
|
|
121
110
|
axis = op.axis
|
|
@@ -50,7 +50,9 @@ def numba_deepcopy_tensor(x):
|
|
|
50
50
|
|
|
51
51
|
|
|
52
52
|
@register_funcify_and_cache_key(OpFromGraph)
|
|
53
|
-
def numba_funcify_OpFromGraph(
|
|
53
|
+
def numba_funcify_OpFromGraph(
|
|
54
|
+
op, node=None, mode=NUMBA.excluding("symbolic_op_recognition"), **kwargs
|
|
55
|
+
):
|
|
54
56
|
_ = kwargs.pop("storage_map", None)
|
|
55
57
|
|
|
56
58
|
# Apply inner rewrites
|
|
@@ -64,7 +66,7 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
|
|
|
64
66
|
input_specs=input_specs,
|
|
65
67
|
accept_inplace=True,
|
|
66
68
|
)
|
|
67
|
-
|
|
69
|
+
mode.optimizer(fgraph)
|
|
68
70
|
output_specs = [Out(o, borrow=False) for o in fgraph.outputs]
|
|
69
71
|
insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs)
|
|
70
72
|
fgraph_fn, fgraph_cache_key = numba_funcify_and_cache_key(
|
|
@@ -27,6 +27,7 @@ _C_TO_NUMPY: dict[str, DTypeLike] = {
|
|
|
27
27
|
"long double": np.longdouble,
|
|
28
28
|
"float complex": np.csingle,
|
|
29
29
|
"double complex": np.cdouble,
|
|
30
|
+
"Py_ssize_t": np.intp,
|
|
30
31
|
}
|
|
31
32
|
|
|
32
33
|
|
|
@@ -83,8 +84,10 @@ class Signature:
|
|
|
83
84
|
raw_args = groups["args"]
|
|
84
85
|
|
|
85
86
|
decl_expr = re.compile(
|
|
86
|
-
rb"\s*(?P<type>
|
|
87
|
-
rb"((
|
|
87
|
+
rb"\s*(?P<type>"
|
|
88
|
+
rb"((long )|(unsigned )|(signed )|(double )|)"
|
|
89
|
+
rb"((double)|(float)|(int)|(short)|(char)|(long)|(bool)|(complex))"
|
|
90
|
+
rb"|Py_ssize_t)"
|
|
88
91
|
rb"(\s(?P<name>[\w_]*))?\s*"
|
|
89
92
|
)
|
|
90
93
|
|
|
@@ -36,13 +36,10 @@ from pytensor.scalar.basic import (
|
|
|
36
36
|
Sub,
|
|
37
37
|
TrueDiv,
|
|
38
38
|
get_scalar_type,
|
|
39
|
-
maximum,
|
|
40
39
|
)
|
|
41
|
-
from pytensor.scalar.basic import add as add_as
|
|
42
40
|
from pytensor.tensor.blas import BatchedDot
|
|
43
41
|
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
|
|
44
42
|
from pytensor.tensor.math import Argmax, Dot, MulWithoutZeros, Sum
|
|
45
|
-
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
|
|
46
43
|
|
|
47
44
|
|
|
48
45
|
@singledispatch
|
|
@@ -505,122 +502,6 @@ def numba_funcify_DimShuffle(op: DimShuffle, node, **kwargs):
|
|
|
505
502
|
return dimshuffle, cache_version
|
|
506
503
|
|
|
507
504
|
|
|
508
|
-
@register_funcify_default_op_cache_key(Softmax)
|
|
509
|
-
def numba_funcify_Softmax(op, node, **kwargs):
|
|
510
|
-
ndim = node.inputs[0].type.ndim
|
|
511
|
-
inp_dtype = node.inputs[0].type.numpy_dtype
|
|
512
|
-
axis = op.axis
|
|
513
|
-
|
|
514
|
-
if ndim > 1 and axis is not None:
|
|
515
|
-
reduce_max_py = create_multiaxis_reducer(
|
|
516
|
-
maximum,
|
|
517
|
-
identity=-np.inf,
|
|
518
|
-
axes=(axis,),
|
|
519
|
-
ndim=ndim,
|
|
520
|
-
out_dtype=inp_dtype,
|
|
521
|
-
keepdims=True,
|
|
522
|
-
)
|
|
523
|
-
reduce_sum_py = create_multiaxis_reducer(
|
|
524
|
-
add_as,
|
|
525
|
-
identity=0.0,
|
|
526
|
-
axes=(axis,),
|
|
527
|
-
ndim=ndim,
|
|
528
|
-
out_dtype=inp_dtype,
|
|
529
|
-
keepdims=True,
|
|
530
|
-
)
|
|
531
|
-
|
|
532
|
-
jit_fn = numba_basic.numba_njit(boundscheck=False)
|
|
533
|
-
reduce_max = jit_fn(reduce_max_py)
|
|
534
|
-
reduce_sum = jit_fn(reduce_sum_py)
|
|
535
|
-
else:
|
|
536
|
-
reduce_max = np.max
|
|
537
|
-
reduce_sum = np.sum
|
|
538
|
-
|
|
539
|
-
@numba_basic.numba_njit(boundscheck=False)
|
|
540
|
-
def softmax(x):
|
|
541
|
-
z = reduce_max(x)
|
|
542
|
-
e_x = np.exp(x - z)
|
|
543
|
-
w = reduce_sum(e_x)
|
|
544
|
-
sm = e_x / w
|
|
545
|
-
return sm
|
|
546
|
-
|
|
547
|
-
cache_version = 1
|
|
548
|
-
return softmax, cache_version
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
@register_funcify_default_op_cache_key(SoftmaxGrad)
|
|
552
|
-
def numba_funcify_SoftmaxGrad(op, node, **kwargs):
|
|
553
|
-
ndim = node.inputs[0].type.ndim
|
|
554
|
-
inp_dtype = node.inputs[0].type.numpy_dtype
|
|
555
|
-
|
|
556
|
-
axis = op.axis
|
|
557
|
-
if ndim > 1 and axis is not None:
|
|
558
|
-
reduce_sum_py = create_multiaxis_reducer(
|
|
559
|
-
add_as,
|
|
560
|
-
identity=0.0,
|
|
561
|
-
axes=(axis,),
|
|
562
|
-
ndim=ndim,
|
|
563
|
-
out_dtype=inp_dtype,
|
|
564
|
-
keepdims=True,
|
|
565
|
-
)
|
|
566
|
-
|
|
567
|
-
jit_fn = numba_basic.numba_njit(boundscheck=False)
|
|
568
|
-
reduce_sum = jit_fn(reduce_sum_py)
|
|
569
|
-
else:
|
|
570
|
-
reduce_sum = np.sum
|
|
571
|
-
|
|
572
|
-
@numba_basic.numba_njit(boundscheck=False)
|
|
573
|
-
def softmax_grad(dy, sm):
|
|
574
|
-
dy_times_sm = dy * sm
|
|
575
|
-
sum_dy_times_sm = reduce_sum(dy_times_sm)
|
|
576
|
-
dx = dy_times_sm - sum_dy_times_sm * sm
|
|
577
|
-
return dx
|
|
578
|
-
|
|
579
|
-
cache_version = 1
|
|
580
|
-
return softmax_grad, cache_version
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
@register_funcify_default_op_cache_key(LogSoftmax)
|
|
584
|
-
def numba_funcify_LogSoftmax(op, node, **kwargs):
|
|
585
|
-
ndim = node.inputs[0].type.ndim
|
|
586
|
-
inp_dtype = node.inputs[0].type.numpy_dtype
|
|
587
|
-
axis = op.axis
|
|
588
|
-
|
|
589
|
-
if ndim > 1 and axis is not None:
|
|
590
|
-
reduce_max_py = create_multiaxis_reducer(
|
|
591
|
-
maximum,
|
|
592
|
-
identity=-np.inf,
|
|
593
|
-
axes=(axis,),
|
|
594
|
-
ndim=ndim,
|
|
595
|
-
out_dtype=inp_dtype,
|
|
596
|
-
keepdims=True,
|
|
597
|
-
)
|
|
598
|
-
reduce_sum_py = create_multiaxis_reducer(
|
|
599
|
-
add_as,
|
|
600
|
-
identity=0.0,
|
|
601
|
-
axes=(axis,),
|
|
602
|
-
ndim=ndim,
|
|
603
|
-
out_dtype=inp_dtype,
|
|
604
|
-
keepdims=True,
|
|
605
|
-
)
|
|
606
|
-
|
|
607
|
-
jit_fn = numba_basic.numba_njit(boundscheck=False)
|
|
608
|
-
reduce_max = jit_fn(reduce_max_py)
|
|
609
|
-
reduce_sum = jit_fn(reduce_sum_py)
|
|
610
|
-
else:
|
|
611
|
-
reduce_max = np.max
|
|
612
|
-
reduce_sum = np.sum
|
|
613
|
-
|
|
614
|
-
@numba_basic.numba_njit(boundscheck=False)
|
|
615
|
-
def log_softmax(x):
|
|
616
|
-
xdev = x - reduce_max(x)
|
|
617
|
-
lsm = xdev - np.log(reduce_sum(np.exp(xdev)))
|
|
618
|
-
return lsm
|
|
619
|
-
|
|
620
|
-
cache_version = 1
|
|
621
|
-
return log_softmax, cache_version
|
|
622
|
-
|
|
623
|
-
|
|
624
505
|
@register_funcify_default_op_cache_key(Argmax)
|
|
625
506
|
def numba_funcify_Argmax(op, node, **kwargs):
|
|
626
507
|
axis = op.axis
|
|
@@ -6,7 +6,7 @@ from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
|
|
|
6
6
|
from pytensor.scalar import ScalarLoop
|
|
7
7
|
from pytensor.tensor.elemwise import DimShuffle, Elemwise
|
|
8
8
|
from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum
|
|
9
|
-
from pytensor.tensor.special import LogSoftmax, Softmax
|
|
9
|
+
from pytensor.tensor.special import LogSoftmax, Softmax
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
@pytorch_funcify.register(Elemwise)
|
|
@@ -129,9 +129,34 @@ def pytorch_funcify_min(op, **kwargs):
|
|
|
129
129
|
return torch_min
|
|
130
130
|
|
|
131
131
|
|
|
132
|
+
def _pytorch_softmax_dispatch(torch_fn, axis):
|
|
133
|
+
if axis is None:
|
|
134
|
+
|
|
135
|
+
def fn(x):
|
|
136
|
+
return torch_fn(x.ravel(), dim=0).reshape(x.shape)
|
|
137
|
+
|
|
138
|
+
elif len(axis) == 1:
|
|
139
|
+
|
|
140
|
+
def fn(x):
|
|
141
|
+
return torch_fn(x, dim=axis[0])
|
|
142
|
+
|
|
143
|
+
else:
|
|
144
|
+
|
|
145
|
+
def fn(x):
|
|
146
|
+
orig_shape = x.shape
|
|
147
|
+
x = torch.movedim(x, axis, tuple(range(-len(axis), 0)))
|
|
148
|
+
unflatten_shape = x.shape[: -len(axis)]
|
|
149
|
+
x = x.reshape(*unflatten_shape, -1)
|
|
150
|
+
x = torch_fn(x, dim=-1)
|
|
151
|
+
x = x.reshape(*unflatten_shape, *[orig_shape[a] for a in axis])
|
|
152
|
+
x = torch.movedim(x, tuple(range(-len(axis), 0)), axis)
|
|
153
|
+
return x
|
|
154
|
+
|
|
155
|
+
return fn
|
|
156
|
+
|
|
157
|
+
|
|
132
158
|
@pytorch_funcify.register(Softmax)
|
|
133
159
|
def pytorch_funcify_Softmax(op, **kwargs):
|
|
134
|
-
axis = op.axis
|
|
135
160
|
dtype = kwargs["node"].inputs[0].dtype
|
|
136
161
|
|
|
137
162
|
if not dtype.startswith("float"):
|
|
@@ -139,18 +164,11 @@ def pytorch_funcify_Softmax(op, **kwargs):
|
|
|
139
164
|
"Pytorch Softmax is not currently implemented for non-float types."
|
|
140
165
|
)
|
|
141
166
|
|
|
142
|
-
|
|
143
|
-
if axis is not None:
|
|
144
|
-
return torch.softmax(x, dim=axis)
|
|
145
|
-
else:
|
|
146
|
-
return torch.softmax(x.ravel(), dim=0).reshape(x.shape)
|
|
147
|
-
|
|
148
|
-
return softmax
|
|
167
|
+
return _pytorch_softmax_dispatch(torch.softmax, op.axis)
|
|
149
168
|
|
|
150
169
|
|
|
151
170
|
@pytorch_funcify.register(LogSoftmax)
|
|
152
171
|
def pytorch_funcify_LogSoftmax(op, **kwargs):
|
|
153
|
-
axis = op.axis
|
|
154
172
|
dtype = kwargs["node"].inputs[0].dtype
|
|
155
173
|
|
|
156
174
|
if not dtype.startswith("float"):
|
|
@@ -158,24 +176,7 @@ def pytorch_funcify_LogSoftmax(op, **kwargs):
|
|
|
158
176
|
"Pytorch LogSoftmax is not currently implemented for non-float types."
|
|
159
177
|
)
|
|
160
178
|
|
|
161
|
-
|
|
162
|
-
if axis is not None:
|
|
163
|
-
return torch.log_softmax(x, dim=axis)
|
|
164
|
-
else:
|
|
165
|
-
return torch.log_softmax(x.ravel(), dim=0).reshape(x.shape)
|
|
166
|
-
|
|
167
|
-
return log_softmax
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
@pytorch_funcify.register(SoftmaxGrad)
|
|
171
|
-
def jax_funcify_SoftmaxGrad(op, **kwargs):
|
|
172
|
-
axis = op.axis
|
|
173
|
-
|
|
174
|
-
def softmax_grad(dy, sm):
|
|
175
|
-
dy_times_sm = dy * sm
|
|
176
|
-
return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm
|
|
177
|
-
|
|
178
|
-
return softmax_grad
|
|
179
|
+
return _pytorch_softmax_dispatch(torch.log_softmax, op.axis)
|
|
179
180
|
|
|
180
181
|
|
|
181
182
|
def elemwise_ravel_fn(base_fn, op, node, **kwargs):
|