pytensor 2.34.0__tar.gz → 2.35.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytensor-2.34.0/pytensor.egg-info → pytensor-2.35.0}/PKG-INFO +3 -4
- pytensor-2.35.0/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb +436 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/environment.yml +1 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/pyproject.toml +4 -5
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/_version.py +3 -3
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/bin/pytensor_cache.py +3 -6
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/compile/builders.py +1 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/compile/function/__init__.py +1 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/compile/function/pfunc.py +2 -3
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/compile/function/types.py +1 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/compile/mode.py +18 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/compile/monitormode.py +1 -2
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/compile/profiling.py +26 -37
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/configdefaults.py +1 -2
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/gradient.py +4 -5
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/graph/basic.py +8 -10
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/graph/features.py +1 -2
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/graph/fg.py +1 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/graph/rewriting/basic.py +5 -7
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/graph/rewriting/db.py +1 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/basic.py +1 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/c/basic.py +11 -13
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/c/cmodule.py +3 -3
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/c/op.py +2 -2
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/c/params_type.py +11 -6
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/c/type.py +9 -3
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/dispatch/random.py +4 -4
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/dispatch/scan.py +4 -4
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/dispatch/tensor_basic.py +10 -8
- pytensor-2.35.0/pytensor/link/mlx/__init__.py +1 -0
- pytensor-2.35.0/pytensor/link/mlx/dispatch/__init__.py +13 -0
- pytensor-2.35.0/pytensor/link/mlx/dispatch/basic.py +101 -0
- pytensor-2.35.0/pytensor/link/mlx/dispatch/blockwise.py +35 -0
- pytensor-2.35.0/pytensor/link/mlx/dispatch/core.py +321 -0
- pytensor-2.35.0/pytensor/link/mlx/dispatch/elemwise.py +446 -0
- pytensor-2.35.0/pytensor/link/mlx/dispatch/math.py +72 -0
- pytensor-2.35.0/pytensor/link/mlx/dispatch/shape.py +42 -0
- pytensor-2.35.0/pytensor/link/mlx/dispatch/signal/conv.py +33 -0
- pytensor-2.35.0/pytensor/link/mlx/dispatch/subtensor.py +105 -0
- pytensor-2.35.0/pytensor/link/mlx/linker.py +69 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/elemwise.py +1 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/linalg/_LAPACK.py +1 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/linalg/decomposition/lu.py +1 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/scalar.py +2 -2
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/subtensor.py +3 -3
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/pytorch/dispatch/basic.py +1 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/pytorch/dispatch/subtensor.py +1 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/pytorch/linker.py +4 -2
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/vm.py +6 -6
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/misc/ordered_set.py +1 -1
- pytensor-2.35.0/pytensor/npy_2_compat.py +22 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/printing.py +11 -5
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/scalar/basic.py +12 -13
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/scalar/loop.py +2 -2
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/scalar/math.py +1 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/scan/op.py +4 -5
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/scan/rewriting.py +1 -2
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/sparse/basic.py +8 -10
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/sparse/rewriting.py +1 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/_linalg/solve/rewriting.py +4 -4
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/basic.py +61 -61
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/blas.py +7 -10
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/conv/abstract_conv.py +13 -13
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/einsum.py +8 -9
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/elemwise.py +2 -2
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/elemwise_cgen.py +4 -4
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/extra_ops.py +24 -32
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/math.py +123 -131
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/nlinalg.py +7 -7
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/optimize.py +5 -5
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/pad.py +5 -5
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/random/rewriting/numba.py +1 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/rewriting/basic.py +2 -2
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/rewriting/elemwise.py +1 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/rewriting/extra_ops.py +3 -3
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/rewriting/linalg.py +4 -2
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/rewriting/math.py +3 -3
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/rewriting/shape.py +1 -2
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/rewriting/subtensor_lift.py +1 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/shape.py +3 -3
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/signal/conv.py +5 -5
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/slinalg.py +103 -36
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/special.py +9 -22
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/subtensor.py +5 -225
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/type_other.py +1 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/utils.py +1 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/utils.py +16 -5
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/xtensor/rewriting/shape.py +3 -3
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/xtensor/shape.py +1 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/xtensor/type.py +2 -2
- {pytensor-2.34.0 → pytensor-2.35.0/pytensor.egg-info}/PKG-INFO +3 -4
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor.egg-info/SOURCES.txt +13 -3
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor.egg-info/requires.txt +1 -1
- {pytensor-2.34.0 → pytensor-2.35.0}/scripts/slowest_tests/update-slowest-times-issue.sh +2 -2
- {pytensor-2.34.0 → pytensor-2.35.0}/tests/test_config.py +7 -5
- {pytensor-2.34.0 → pytensor-2.35.0}/tests/test_gradient.py +8 -8
- {pytensor-2.34.0 → pytensor-2.35.0}/tests/test_ifelse.py +2 -2
- {pytensor-2.34.0 → pytensor-2.35.0}/tests/test_printing.py +1 -1
- pytensor-2.34.0/pytensor/npy_2_compat.py +0 -308
- pytensor-2.34.0/pytensor/sparse/sandbox/sp.py +0 -447
- pytensor-2.34.0/pytensor/sparse/sandbox/sp2.py +0 -222
- {pytensor-2.34.0 → pytensor-2.35.0}/LICENSE.txt +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/MANIFEST.in +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/README.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/.templates/PLACEHOLDER +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/.templates/layout.html +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/.templates/nb-badges.html +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/.templates/rendered_citation.html +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/LICENSE.txt +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/README.md +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/_thumbnails/autodiff/vector_jacobian_product.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/acknowledgement.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/bcast.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/bcast.svg +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/blog.md +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/conf.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/core_development_guide.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/css.inc +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/dev_start_guide.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/apply.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/apply.svg +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/apply2.svg +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/creating_a_c_op.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/creating_a_numba_jax_op.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/creating_an_op.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/ctype.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/extending_faq.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/extending_pytensor_solution_1.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/graph_rewriting.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/graphstructures.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/index.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/inplace.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/op.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/other_ops.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/pics/symbolic_graph_opt.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/pics/symbolic_graph_unopt.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/pipeline.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/scan.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/tips.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/type.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/unittest.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/extending/using_params.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/faq.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/gallery/applications/normalizing_flows_in_pytensor.ipynb +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/gallery/autodiff/vector_jacobian_product.ipynb +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/gallery/introduction/pytensor_intro.ipynb +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/gallery/optimize/root.ipynb +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/gallery/page_footer.md +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/gallery/rewrites/graph_rewrites.ipynb +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/gallery/scan/scan_tutorial.ipynb +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/generate_dtype_tensor_table.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/glossary.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/images/Elman_srnn.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/images/PyTensor.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/images/PyTensor_RGB.svg +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/images/PyTensor_logo.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/images/binder.svg +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/images/blocksparse.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/images/colab.svg +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/images/github.svg +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/images/lstm.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/images/lstm_memorycell.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/images/talk2010.gif +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/images/talk2010.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/index.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/install.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/internal/how_to_release.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/internal/index.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/internal/metadocumentation.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/introduction.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/compile/debugmode.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/compile/function.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/compile/index.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/compile/io.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/compile/mode.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/compile/nanguardmode.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/compile/opfromgraph.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/compile/ops.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/compile/profilemode.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/compile/shared.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/config.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/d3viz/examples/d3viz/css/d3-context-menu.css +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/d3viz/examples/d3viz/css/d3viz.css +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/d3viz/examples/d3viz/js/d3-context-menu.js +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/d3viz/examples/d3viz/js/d3.v3.min.js +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/d3viz/examples/d3viz/js/d3viz.js +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/d3viz/examples/d3viz/js/dagre-d3.min.js +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/d3viz/examples/d3viz/js/graphlib-dot.min.js +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/d3viz/examples/mlp.html +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/d3viz/examples/mlp.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/d3viz/examples/mlp2.html +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/d3viz/examples/mlp2.pdf +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/d3viz/examples/mlp2.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/d3viz/examples/ofg.html +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/d3viz/examples/ofg2.html +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/d3viz/index.ipynb +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/d3viz/index.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/d3viz/index_files/index_10_0.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/d3viz/index_files/index_11_0.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/d3viz/index_files/index_24_0.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/d3viz/index_files/index_25_0.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/graph/features.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/graph/fgraph.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/graph/graph.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/graph/index.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/graph/op.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/graph/replace.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/graph/type.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/graph/utils.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/index.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/misc/pkl_utils.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/printing.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/scalar/index.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/scan.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/sparse/index.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/sparse/sandbox.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/tensor/basic.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/tensor/basic_opt.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/tensor/bcast.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/tensor/bcast.svg +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/tensor/conv.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/tensor/elemwise.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/tensor/extra_ops.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/tensor/fft.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/tensor/functional.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/tensor/index.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/tensor/io.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/tensor/math_opt.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/tensor/nlinalg.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/tensor/optimize.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/tensor/plot_fft.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/tensor/random/distributions.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/tensor/random/index.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/tensor/slinalg.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/tensor/utils.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/typed_list.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/xtensor/index.md +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/xtensor/linalg.md +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/xtensor/math.md +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/xtensor/module_functions.md +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/xtensor/random.md +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/library/xtensor/type.md +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/links.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/optimizations.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/pylintrc +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/robots.txt +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/troubleshooting.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/adding.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/adding_solution_1.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/aliasing.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/apply.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/apply.svg +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/bcast.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/broadcasting.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/conditions.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/debug_faq.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/dlogistic.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/examples.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/faq_tutorial.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/gradients.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/index.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/loading_and_saving.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/logistic.gp +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/logistic.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/loop.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/loop_solution_1.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/modes.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/modes_solution_1.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/multi_cores.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/nan_tutorial.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/pics/d3viz.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/pics/logreg_pydotprint_predict.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/pics/logreg_pydotprint_prediction.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/pics/logreg_pydotprint_train.png +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/printing_drawing.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/prng.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/profiling.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/profiling_example.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/profiling_example_out.prof +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/shape_info.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/sparse.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/tutorial/symbolic_graphs.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/doc/user_guide.rst +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/bin/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/breakpoint.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/compile/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/compile/compiledir.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/compile/compilelock.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/compile/debugmode.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/compile/io.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/compile/nanguardmode.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/compile/ops.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/compile/sharedvalue.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/configparser.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/d3viz/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/d3viz/css/d3-context-menu.css +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/d3viz/css/d3viz.css +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/d3viz/d3viz.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/d3viz/formatting.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/d3viz/html/template.html +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/d3viz/js/d3-context-menu.js +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/d3viz/js/d3.v3.min.js +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/d3viz/js/d3viz.js +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/d3viz/js/dagre-d3.min.js +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/d3viz/js/graphlib-dot.min.js +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/graph/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/graph/destroyhandler.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/graph/null_type.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/graph/op.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/graph/replace.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/graph/rewriting/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/graph/rewriting/kanren.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/graph/rewriting/unify.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/graph/rewriting/utils.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/graph/traversal.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/graph/type.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/graph/utils.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/ifelse.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/ipython.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/c/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/c/c_code/lazylinker_c.c +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/c/c_code/pytensor_mod_helper.h +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/c/cutils.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/c/cvm.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/c/exceptions.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/c/interface.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/c/lazylinker_c.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/dispatch/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/dispatch/basic.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/dispatch/blas.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/dispatch/blockwise.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/dispatch/einsum.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/dispatch/elemwise.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/dispatch/extra_ops.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/dispatch/math.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/dispatch/nlinalg.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/dispatch/pad.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/dispatch/scalar.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/dispatch/shape.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/dispatch/signal/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/dispatch/signal/conv.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/dispatch/slinalg.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/dispatch/sort.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/dispatch/sparse.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/dispatch/subtensor.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/linker.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/jax/ops.py +0 -0
- {pytensor-2.34.0/pytensor/link/numba/dispatch/linalg → pytensor-2.35.0/pytensor/link/mlx/dispatch/signal}/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/basic.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/blockwise.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/cython_support.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/extra_ops.py +0 -0
- {pytensor-2.34.0/pytensor/link/numba/dispatch/linalg/decomposition → pytensor-2.35.0/pytensor/link/numba/dispatch/linalg}/__init__.py +0 -0
- {pytensor-2.34.0/pytensor/link/numba/dispatch/linalg/solve → pytensor-2.35.0/pytensor/link/numba/dispatch/linalg/decomposition}/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/linalg/decomposition/cholesky.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/linalg/decomposition/qr.py +0 -0
- {pytensor-2.34.0/pytensor/misc → pytensor-2.35.0/pytensor/link/numba/dispatch/linalg/solve}/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/linalg/solve/cholesky.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/linalg/solve/general.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/linalg/solve/lu_solve.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/linalg/solve/norm.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/linalg/solve/posdef.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/linalg/solve/symmetric.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/linalg/solve/triangular.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/linalg/solve/utils.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/linalg/utils.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/nlinalg.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/random.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/scan.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/signal/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/signal/conv.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/slinalg.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/sparse.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/tensor_basic.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/dispatch/vectorize_codegen.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/numba/linker.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/pytorch/dispatch/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/pytorch/dispatch/blas.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/pytorch/dispatch/blockwise.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/pytorch/dispatch/elemwise.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/pytorch/dispatch/extra_ops.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/pytorch/dispatch/math.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/pytorch/dispatch/nlinalg.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/pytorch/dispatch/scalar.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/pytorch/dispatch/shape.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/pytorch/dispatch/slinalg.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/pytorch/dispatch/sort.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/link/utils.py +0 -0
- {pytensor-2.34.0/pytensor/sparse/sandbox → pytensor-2.35.0/pytensor/misc}/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/misc/check_blas.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/misc/check_blas_many.sh +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/misc/check_duplicate_key.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/misc/elemwise_openmp_speedup.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/misc/elemwise_time_test.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/misc/frozendict.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/misc/may_share_memory.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/misc/pkl_utils.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/py.typed +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/raise_op.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/scalar/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/scalar/c_code/Faddeeva.cc +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/scalar/c_code/Faddeeva.hh +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/scalar/c_code/gamma.c +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/scalar/c_code/incbet.c +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/scalar/sharedvar.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/scan/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/scan/basic.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/scan/checkpoints.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/scan/scan_perform.pyx +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/scan/scan_perform_ext.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/scan/utils.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/scan/views.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/sparse/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/sparse/sharedvar.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/sparse/type.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/sparse/utils.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/_linalg/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/_linalg/solve/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/_linalg/solve/tridiagonal.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/blas_c.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/blas_headers.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/blockwise.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/c_code/alt_blas_common.h +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/c_code/alt_blas_template.c +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/c_code/dimshuffle.c +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/conv/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/exceptions.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/fft.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/fourier.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/functional.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/inplace.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/interpolate.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/io.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/linalg.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/random/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/random/basic.py +28 -28
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/random/op.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/random/rewriting/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/random/rewriting/basic.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/random/rewriting/jax.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/random/type.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/random/utils.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/random/var.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/rewriting/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/rewriting/blas.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/rewriting/blas_c.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/rewriting/blockwise.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/rewriting/einsum.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/rewriting/jax.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/rewriting/numba.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/rewriting/ofg.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/rewriting/special.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/rewriting/subtensor.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/rewriting/uncanonicalize.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/sharedvar.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/signal/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/sort.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/type.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/var.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/variable.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/tensor/xlogx.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/typed_list/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/typed_list/basic.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/typed_list/rewriting.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/typed_list/type.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/updates.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/xtensor/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/xtensor/basic.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/xtensor/indexing.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/xtensor/linalg.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/xtensor/math.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/xtensor/random.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/xtensor/reduction.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/xtensor/rewriting/__init__.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/xtensor/rewriting/basic.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/xtensor/rewriting/indexing.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/xtensor/rewriting/math.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/xtensor/rewriting/reduction.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/xtensor/rewriting/utils.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/xtensor/rewriting/vectorization.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor/xtensor/vectorization.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor.egg-info/dependency_links.txt +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor.egg-info/entry_points.txt +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/pytensor.egg-info/top_level.txt +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/scripts/mypy-failing.txt +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/setup.cfg +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/setup.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/tests/link/c/c_code/test_cenum.h +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/tests/link/c/c_code/test_quadratic_function.c +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/tests/tensor/conv/c_code/corr3d_gemm.c +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/tests/tensor/conv/c_code/corr_gemm.c +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/tests/test_breakpoint.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/tests/test_raise_op.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/tests/test_rop.py +0 -0
- {pytensor-2.34.0 → pytensor-2.35.0}/tests/test_updates.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pytensor
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.35.0
|
|
4
4
|
Summary: Optimizing compiler for evaluating mathematical expressions on CPUs and GPUs.
|
|
5
5
|
Author-email: pymc-devs <pymc.devs@gmail.com>
|
|
6
6
|
License-Expression: BSD-3-Clause
|
|
@@ -21,16 +21,15 @@ Classifier: Operating System :: POSIX
|
|
|
21
21
|
Classifier: Operating System :: Unix
|
|
22
22
|
Classifier: Operating System :: MacOS
|
|
23
23
|
Classifier: Programming Language :: Python :: 3
|
|
24
|
-
Classifier: Programming Language :: Python :: 3.10
|
|
25
24
|
Classifier: Programming Language :: Python :: 3.11
|
|
26
25
|
Classifier: Programming Language :: Python :: 3.12
|
|
27
26
|
Classifier: Programming Language :: Python :: 3.13
|
|
28
|
-
Requires-Python: <3.14,>=3.
|
|
27
|
+
Requires-Python: <3.14,>=3.11
|
|
29
28
|
Description-Content-Type: text/x-rst
|
|
30
29
|
License-File: LICENSE.txt
|
|
31
30
|
Requires-Dist: setuptools>=59.0.0
|
|
32
31
|
Requires-Dist: scipy<2,>=1
|
|
33
|
-
Requires-Dist: numpy>=
|
|
32
|
+
Requires-Dist: numpy>=2.0
|
|
34
33
|
Requires-Dist: filelock>=3.15
|
|
35
34
|
Requires-Dist: etuples
|
|
36
35
|
Requires-Dist: logical-unification
|
|
@@ -0,0 +1,436 @@
|
|
|
1
|
+
{
|
|
2
|
+
"cells": [
|
|
3
|
+
{
|
|
4
|
+
"cell_type": "code",
|
|
5
|
+
"execution_count": 1,
|
|
6
|
+
"metadata": {},
|
|
7
|
+
"outputs": [
|
|
8
|
+
{
|
|
9
|
+
"name": "stdout",
|
|
10
|
+
"output_type": "stream",
|
|
11
|
+
"text": [
|
|
12
|
+
"Obtaining file:///Users/carlostrujillo/Documents/GitHub/pytensor\n",
|
|
13
|
+
" Installing build dependencies ... \u001b[?25ldone\n",
|
|
14
|
+
"\u001b[?25h Checking if build backend supports build_editable ... \u001b[?25ldone\n",
|
|
15
|
+
"\u001b[?25h Getting requirements to build editable ... \u001b[?25ldone\n",
|
|
16
|
+
"\u001b[?25h Preparing editable metadata (pyproject.toml) ... \u001b[?25ldone\n",
|
|
17
|
+
"\u001b[?25hBuilding wheels for collected packages: pytensor\n",
|
|
18
|
+
" Building editable for pytensor (pyproject.toml) ... \u001b[?25ldone\n",
|
|
19
|
+
"\u001b[?25h Created wheel for pytensor: filename=pytensor-2.31.7+80.g06ccf91ba.dirty-0.editable-cp312-cp312-macosx_11_0_arm64.whl size=7323 sha256=c09587a5f3141d49000666d2817c5a01436f13ff5a19aa3deda20f647660afee\n",
|
|
20
|
+
" Stored in directory: /private/var/folders/f0/rbz8xs8s17n3k3f_ccp31bvh0000gn/T/pip-ephem-wheel-cache-i00nb67k/wheels/52/f6/4c/e6784e2203d5405c94db1d544248730e598e4397674416af05\n",
|
|
21
|
+
"Successfully built pytensor\n",
|
|
22
|
+
"Installing collected packages: pytensor\n",
|
|
23
|
+
" Attempting uninstall: pytensor\n",
|
|
24
|
+
" Found existing installation: pytensor 2.31.7+80.g06ccf91ba.dirty\n",
|
|
25
|
+
" Uninstalling pytensor-2.31.7+80.g06ccf91ba.dirty:\n",
|
|
26
|
+
" Successfully uninstalled pytensor-2.31.7+80.g06ccf91ba.dirty\n",
|
|
27
|
+
"Successfully installed pytensor-2.31.7+80.g06ccf91ba.dirty\n",
|
|
28
|
+
"Note: you may need to restart the kernel to use updated packages.\n"
|
|
29
|
+
]
|
|
30
|
+
}
|
|
31
|
+
],
|
|
32
|
+
"source": [
|
|
33
|
+
"%pip install -e ../.. --no-deps"
|
|
34
|
+
]
|
|
35
|
+
},
|
|
36
|
+
{
|
|
37
|
+
"cell_type": "code",
|
|
38
|
+
"execution_count": 1,
|
|
39
|
+
"metadata": {},
|
|
40
|
+
"outputs": [],
|
|
41
|
+
"source": [
|
|
42
|
+
"import time\n",
|
|
43
|
+
"import numpy as np\n",
|
|
44
|
+
"import jax\n",
|
|
45
|
+
"import jax.numpy as jnp\n",
|
|
46
|
+
"\n",
|
|
47
|
+
"import pytensor\n",
|
|
48
|
+
"import pytensor.tensor as pt\n",
|
|
49
|
+
"from pytensor.compile.function import function\n",
|
|
50
|
+
"from pytensor.compile.mode import Mode\n",
|
|
51
|
+
"from pytensor.graph import RewriteDatabaseQuery\n",
|
|
52
|
+
"from pytensor.link.jax import JAXLinker\n"
|
|
53
|
+
]
|
|
54
|
+
},
|
|
55
|
+
{
|
|
56
|
+
"cell_type": "code",
|
|
57
|
+
"execution_count": 2,
|
|
58
|
+
"metadata": {},
|
|
59
|
+
"outputs": [],
|
|
60
|
+
"source": [
|
|
61
|
+
"# Configure JAX to use float32 for consistency with MLX\n",
|
|
62
|
+
"jax.config.update(\"jax_enable_x64\", False)\n",
|
|
63
|
+
"\n",
|
|
64
|
+
"# Set up PyTensor JAX mode\n",
|
|
65
|
+
"jax_optimizer = RewriteDatabaseQuery(include=[\"jax\"], exclude=[])\n",
|
|
66
|
+
"pytensor_jax_mode = \"JAX\"\n",
|
|
67
|
+
"\n",
|
|
68
|
+
"# Try to set up MLX mode\n",
|
|
69
|
+
"try:\n",
|
|
70
|
+
" from pytensor.link.mlx import MLXLinker\n",
|
|
71
|
+
" import mlx.core as mx\n",
|
|
72
|
+
" mlx_optimizer = RewriteDatabaseQuery(include=[\"mlx\"], exclude=[])\n",
|
|
73
|
+
" pytensor_mlx_mode = \"MLX\"\n",
|
|
74
|
+
" MLX_AVAILABLE = True\n",
|
|
75
|
+
"except ImportError:\n",
|
|
76
|
+
" MLX_AVAILABLE = False\n",
|
|
77
|
+
"\n",
|
|
78
|
+
"def timer_jax(func, N=1000):\n",
|
|
79
|
+
" \"\"\"Time function execution with proper JAX synchronization, repeated N times\"\"\"\n",
|
|
80
|
+
" def wrapper(*args, **kwargs):\n",
|
|
81
|
+
" times = []\n",
|
|
82
|
+
" for _ in range(N):\n",
|
|
83
|
+
" start = time.perf_counter()\n",
|
|
84
|
+
" result = func(*args, **kwargs)\n",
|
|
85
|
+
" if hasattr(result, 'block_until_ready'):\n",
|
|
86
|
+
" result.block_until_ready()\n",
|
|
87
|
+
" elif isinstance(result, (list, tuple)):\n",
|
|
88
|
+
" for r in result:\n",
|
|
89
|
+
" if hasattr(r, 'block_until_ready'):\n",
|
|
90
|
+
" r.block_until_ready()\n",
|
|
91
|
+
" end = time.perf_counter()\n",
|
|
92
|
+
" times.append(end - start)\n",
|
|
93
|
+
" \n",
|
|
94
|
+
" mean_time = np.mean(times)\n",
|
|
95
|
+
" std_time = np.std(times)\n",
|
|
96
|
+
" return result, mean_time, std_time\n",
|
|
97
|
+
" return wrapper\n",
|
|
98
|
+
"\n",
|
|
99
|
+
"def timer_mlx(func, N=1000):\n",
|
|
100
|
+
" \"\"\"Time function execution with proper MLX synchronization, repeated N times\"\"\"\n",
|
|
101
|
+
" def wrapper(*args, **kwargs):\n",
|
|
102
|
+
" times = []\n",
|
|
103
|
+
" for _ in range(N):\n",
|
|
104
|
+
" start = time.perf_counter()\n",
|
|
105
|
+
" result = func(*args, **kwargs)\n",
|
|
106
|
+
" # For MLX, we need to use mx.eval() to force computation\n",
|
|
107
|
+
" if MLX_AVAILABLE:\n",
|
|
108
|
+
" if isinstance(result, (list, tuple)):\n",
|
|
109
|
+
" mx.eval(*result)\n",
|
|
110
|
+
" else:\n",
|
|
111
|
+
" mx.eval(result)\n",
|
|
112
|
+
" end = time.perf_counter()\n",
|
|
113
|
+
" times.append(end - start)\n",
|
|
114
|
+
" \n",
|
|
115
|
+
" mean_time = np.mean(times)\n",
|
|
116
|
+
" std_time = np.std(times)\n",
|
|
117
|
+
" return result, mean_time, std_time\n",
|
|
118
|
+
" return wrapper\n",
|
|
119
|
+
"\n",
|
|
120
|
+
"def run_benchmark(N=1000):\n",
|
|
121
|
+
" \"\"\"Run comprehensive benchmark comparing PyTensor JAX vs MLX backends\"\"\"\n",
|
|
122
|
+
" import pandas as pd\n",
|
|
123
|
+
" \n",
|
|
124
|
+
" sizes = [2, 4, 1080, 2080, 3080]\n",
|
|
125
|
+
" results = []\n",
|
|
126
|
+
" \n",
|
|
127
|
+
" print(f\"Running benchmarks with N={N} repetitions per test...\")\n",
|
|
128
|
+
" \n",
|
|
129
|
+
" for size in sizes:\n",
|
|
130
|
+
" print(f\"Testing {size}x{size} matrices...\")\n",
|
|
131
|
+
" \n",
|
|
132
|
+
" # Generate test matrices with fixed seed for reproducibility\n",
|
|
133
|
+
" np.random.seed(42)\n",
|
|
134
|
+
" A = np.random.randn(size, size).astype(np.float32)\n",
|
|
135
|
+
" B = np.random.randn(size, size).astype(np.float32)\n",
|
|
136
|
+
" C = np.random.randn(size, size).astype(np.float32)\n",
|
|
137
|
+
"\n",
|
|
138
|
+
" pt_A = pt.matrix('A', dtype='float32')\n",
|
|
139
|
+
" pt_B = pt.matrix('B', dtype='float32') \n",
|
|
140
|
+
" pt_C = pt.matrix('C', dtype='float32')\n",
|
|
141
|
+
" result = pt.dot(pt.dot(pt_A, pt_B), pt_C)\n",
|
|
142
|
+
"\n",
|
|
143
|
+
"\n",
|
|
144
|
+
" f_jax = function([pt_A, pt_B, pt_C], result, mode=pytensor_jax_mode, trust_input=True)\n",
|
|
145
|
+
" f_mlx = function([pt_A, pt_B, pt_C], result, mode=pytensor_mlx_mode, trust_input=True)\n",
|
|
146
|
+
" f_jax(A, B, C)\n",
|
|
147
|
+
" f_mlx(A, B, C)\n",
|
|
148
|
+
" \n",
|
|
149
|
+
" # === TEST 1: Matrix Multiplication Chain ===\n",
|
|
150
|
+
" # PyTensor + JAX backend\n",
|
|
151
|
+
" @timer_jax\n",
|
|
152
|
+
" def pytensor_jax_matmul():\n",
|
|
153
|
+
" return f_jax(A, B, C)\n",
|
|
154
|
+
" \n",
|
|
155
|
+
" # PyTensor + MLX backend\n",
|
|
156
|
+
" @timer_mlx\n",
|
|
157
|
+
" def pytensor_mlx_matmul():\n",
|
|
158
|
+
" if not MLX_AVAILABLE:\n",
|
|
159
|
+
" return None, float('inf'), 0\n",
|
|
160
|
+
" return f_mlx(A, B, C)\n",
|
|
161
|
+
" \n",
|
|
162
|
+
" # Run matrix multiplication test\n",
|
|
163
|
+
" _, jax_mean, jax_std = pytensor_jax_matmul()\n",
|
|
164
|
+
" try:\n",
|
|
165
|
+
" _, mlx_mean, mlx_std = pytensor_mlx_matmul()\n",
|
|
166
|
+
" except Exception as e:\n",
|
|
167
|
+
" print(f\"MLX matmul error: {e}\")\n",
|
|
168
|
+
" mlx_mean, mlx_std = float('inf'), 0\n",
|
|
169
|
+
" \n",
|
|
170
|
+
" # Calculate percentage improvement (positive = MLX is faster, negative = MLX is slower)\n",
|
|
171
|
+
" if mlx_mean != float('inf') and mlx_mean > 0:\n",
|
|
172
|
+
" speedup_percentage = ((jax_mean - mlx_mean) / jax_mean) * 100\n",
|
|
173
|
+
" speedup_str = f'{speedup_percentage:+.1f}%'\n",
|
|
174
|
+
" else:\n",
|
|
175
|
+
" speedup_str = 'N/A'\n",
|
|
176
|
+
" \n",
|
|
177
|
+
" results.append({\n",
|
|
178
|
+
" 'Size': f'{size}x{size}',\n",
|
|
179
|
+
" 'Operation': 'Matrix Chain (A @ B @ C)',\n",
|
|
180
|
+
" 'PyTensor+JAX Mean (s)': f'{jax_mean:.6f}',\n",
|
|
181
|
+
" 'PyTensor+JAX Std (s)': f'{jax_std:.6f}',\n",
|
|
182
|
+
" 'PyTensor+MLX Mean (s)': f'{mlx_mean:.6f}' if mlx_mean != float('inf') else 'Error',\n",
|
|
183
|
+
" 'PyTensor+MLX Std (s)': f'{mlx_std:.6f}' if mlx_mean != float('inf') else 'N/A',\n",
|
|
184
|
+
" 'MLX Performance': speedup_str\n",
|
|
185
|
+
" })\n",
|
|
186
|
+
" \n",
|
|
187
|
+
" # === TEST 2: Element-wise Operations ===\n",
|
|
188
|
+
" # PyTensor + JAX\n",
|
|
189
|
+
" result = pt.sin(pt_A) + pt.cos(pt_B)\n",
|
|
190
|
+
" f_jax = function([pt_A, pt_B], result, mode=pytensor_jax_mode, trust_input=True)\n",
|
|
191
|
+
" f_mlx = function([pt_A, pt_B], result, mode=pytensor_mlx_mode, trust_input=True)\n",
|
|
192
|
+
" f_jax(A, B)\n",
|
|
193
|
+
" f_mlx(A, B)\n",
|
|
194
|
+
"\n",
|
|
195
|
+
" @timer_jax\n",
|
|
196
|
+
" def pytensor_jax_elemwise():\n",
|
|
197
|
+
" return f_jax(A, B)\n",
|
|
198
|
+
" \n",
|
|
199
|
+
" # PyTensor + MLX\n",
|
|
200
|
+
" @timer_mlx\n",
|
|
201
|
+
" def pytensor_mlx_elemwise():\n",
|
|
202
|
+
" if not MLX_AVAILABLE:\n",
|
|
203
|
+
" return None, float('inf'), 0\n",
|
|
204
|
+
" return f_mlx(A, B)\n",
|
|
205
|
+
" \n",
|
|
206
|
+
" # Run element-wise test\n",
|
|
207
|
+
" _, jax_mean, jax_std = pytensor_jax_elemwise()\n",
|
|
208
|
+
" try:\n",
|
|
209
|
+
" _, mlx_mean, mlx_std = pytensor_mlx_elemwise()\n",
|
|
210
|
+
" except Exception as e:\n",
|
|
211
|
+
" print(f\"MLX elemwise error: {e}\")\n",
|
|
212
|
+
" mlx_mean, mlx_std = float('inf'), 0\n",
|
|
213
|
+
" \n",
|
|
214
|
+
" # Calculate percentage improvement\n",
|
|
215
|
+
" if mlx_mean != float('inf') and mlx_mean > 0:\n",
|
|
216
|
+
" speedup_percentage = ((jax_mean - mlx_mean) / jax_mean) * 100\n",
|
|
217
|
+
" speedup_str = f'{speedup_percentage:+.1f}%'\n",
|
|
218
|
+
" else:\n",
|
|
219
|
+
" speedup_str = 'N/A'\n",
|
|
220
|
+
" \n",
|
|
221
|
+
" results.append({\n",
|
|
222
|
+
" 'Size': f'{size}x{size}',\n",
|
|
223
|
+
" 'Operation': 'Element-wise (sin(A) + cos(B))',\n",
|
|
224
|
+
" 'PyTensor+JAX Mean (s)': f'{jax_mean:.6f}',\n",
|
|
225
|
+
" 'PyTensor+JAX Std (s)': f'{jax_std:.6f}',\n",
|
|
226
|
+
" 'PyTensor+MLX Mean (s)': f'{mlx_mean:.6f}' if mlx_mean != float('inf') else 'Error',\n",
|
|
227
|
+
" 'PyTensor+MLX Std (s)': f'{mlx_std:.6f}' if mlx_mean != float('inf') else 'N/A',\n",
|
|
228
|
+
" 'MLX Performance': speedup_str\n",
|
|
229
|
+
" })\n",
|
|
230
|
+
" \n",
|
|
231
|
+
" # === TEST 3: Matrix Addition with Broadcasting ===\n",
|
|
232
|
+
" # PyTensor + JAX\n",
|
|
233
|
+
" result = pt_A + pt_B.T\n",
|
|
234
|
+
" f_jax = function([pt_A, pt_B], result, mode=pytensor_jax_mode, trust_input=True)\n",
|
|
235
|
+
" f_mlx = function([pt_A, pt_B], result, mode=pytensor_mlx_mode, trust_input=True)\n",
|
|
236
|
+
" f_jax(A, B)\n",
|
|
237
|
+
" f_mlx(A, B)\n",
|
|
238
|
+
" @timer_jax\n",
|
|
239
|
+
" def pytensor_jax_broadcast():\n",
|
|
240
|
+
" return f_jax(A, B)\n",
|
|
241
|
+
" \n",
|
|
242
|
+
" # PyTensor + MLX\n",
|
|
243
|
+
" @timer_mlx\n",
|
|
244
|
+
" def pytensor_mlx_broadcast():\n",
|
|
245
|
+
" if not MLX_AVAILABLE:\n",
|
|
246
|
+
" return None, float('inf'), 0\n",
|
|
247
|
+
" return f_mlx(A, B)\n",
|
|
248
|
+
" \n",
|
|
249
|
+
" # Run broadcasting test\n",
|
|
250
|
+
" _, jax_mean, jax_std = pytensor_jax_broadcast()\n",
|
|
251
|
+
" try:\n",
|
|
252
|
+
" _, mlx_mean, mlx_std = pytensor_mlx_broadcast()\n",
|
|
253
|
+
" except Exception as e:\n",
|
|
254
|
+
" print(f\"MLX broadcast error: {e}\")\n",
|
|
255
|
+
" mlx_mean, mlx_std = float('inf'), 0\n",
|
|
256
|
+
" \n",
|
|
257
|
+
" # Calculate percentage improvement\n",
|
|
258
|
+
" if mlx_mean != float('inf') and mlx_mean > 0:\n",
|
|
259
|
+
" speedup_percentage = ((jax_mean - mlx_mean) / jax_mean) * 100\n",
|
|
260
|
+
" speedup_str = f'{speedup_percentage:+.1f}%'\n",
|
|
261
|
+
" else:\n",
|
|
262
|
+
" speedup_str = 'N/A'\n",
|
|
263
|
+
" \n",
|
|
264
|
+
" results.append({\n",
|
|
265
|
+
" 'Size': f'{size}x{size}',\n",
|
|
266
|
+
" 'Operation': 'Broadcasting (A + B.T)',\n",
|
|
267
|
+
" 'PyTensor+JAX Mean (s)': f'{jax_mean:.6f}',\n",
|
|
268
|
+
" 'PyTensor+JAX Std (s)': f'{jax_std:.6f}',\n",
|
|
269
|
+
" 'PyTensor+MLX Mean (s)': f'{mlx_mean:.6f}' if mlx_mean != float('inf') else 'Error',\n",
|
|
270
|
+
" 'PyTensor+MLX Std (s)': f'{mlx_std:.6f}' if mlx_mean != float('inf') else 'N/A',\n",
|
|
271
|
+
" 'MLX Performance': speedup_str\n",
|
|
272
|
+
" })\n",
|
|
273
|
+
" \n",
|
|
274
|
+
" # Create and display results table\n",
|
|
275
|
+
" df = pd.DataFrame(results)\n",
|
|
276
|
+
" return df\n",
|
|
277
|
+
"\n",
|
|
278
|
+
"def main(N=1000):\n",
|
|
279
|
+
" \"\"\"Main benchmark execution\"\"\"\n",
|
|
280
|
+
" # Display system info\n",
|
|
281
|
+
" system_info = {\n",
|
|
282
|
+
" 'JAX version': jax.__version__,\n",
|
|
283
|
+
" 'PyTensor version': pytensor.__version__,\n",
|
|
284
|
+
" 'MLX Available': 'Yes' if MLX_AVAILABLE else 'No',\n",
|
|
285
|
+
" 'Platform': 'Apple Silicon' if MLX_AVAILABLE else 'Generic',\n",
|
|
286
|
+
" 'Repetitions (N)': N\n",
|
|
287
|
+
" }\n",
|
|
288
|
+
" \n",
|
|
289
|
+
" if MLX_AVAILABLE:\n",
|
|
290
|
+
" system_info['MLX version'] = mx.__version__\n",
|
|
291
|
+
" \n",
|
|
292
|
+
" import pandas as pd\n",
|
|
293
|
+
" info_df = pd.DataFrame([system_info])\n",
|
|
294
|
+
" \n",
|
|
295
|
+
" # Then run benchmarks\n",
|
|
296
|
+
" results_df = run_benchmark(N=N)\n",
|
|
297
|
+
" \n",
|
|
298
|
+
" return info_df, results_df\n"
|
|
299
|
+
]
|
|
300
|
+
},
|
|
301
|
+
{
|
|
302
|
+
"cell_type": "code",
|
|
303
|
+
"execution_count": 3,
|
|
304
|
+
"metadata": {},
|
|
305
|
+
"outputs": [
|
|
306
|
+
{
|
|
307
|
+
"name": "stdout",
|
|
308
|
+
"output_type": "stream",
|
|
309
|
+
"text": [
|
|
310
|
+
"Running benchmarks with N=150 repetitions per test...\n",
|
|
311
|
+
"Testing 2x2 matrices...\n",
|
|
312
|
+
"Testing 4x4 matrices...\n",
|
|
313
|
+
"Testing 1080x1080 matrices...\n",
|
|
314
|
+
"Testing 2080x2080 matrices...\n",
|
|
315
|
+
"Testing 3080x3080 matrices...\n"
|
|
316
|
+
]
|
|
317
|
+
}
|
|
318
|
+
],
|
|
319
|
+
"source": [
|
|
320
|
+
"iteration=150\n",
|
|
321
|
+
"_, results = main(N=iteration)"
|
|
322
|
+
]
|
|
323
|
+
},
|
|
324
|
+
{
|
|
325
|
+
"cell_type": "code",
|
|
326
|
+
"execution_count": 4,
|
|
327
|
+
"metadata": {},
|
|
328
|
+
"outputs": [
|
|
329
|
+
{
|
|
330
|
+
"name": "stdout",
|
|
331
|
+
"output_type": "stream",
|
|
332
|
+
"text": [
|
|
333
|
+
"\n",
|
|
334
|
+
"Benchmark Results over 150 repetitions:\n",
|
|
335
|
+
" Size Operation PyTensor+JAX Mean (s) PyTensor+JAX Std (s) PyTensor+MLX Mean (s) PyTensor+MLX Std (s) MLX Performance\n",
|
|
336
|
+
" 2x2 Matrix Chain (A @ B @ C) 0.000009 0.000002 0.000305 0.000299 -3213.5%\n",
|
|
337
|
+
" 2x2 Element-wise (sin(A) + cos(B)) 0.000007 0.000002 0.000352 0.003757 -5078.0%\n",
|
|
338
|
+
" 2x2 Broadcasting (A + B.T) 0.000007 0.000001 0.000188 0.000153 -2721.1%\n",
|
|
339
|
+
" 4x4 Matrix Chain (A @ B @ C) 0.000009 0.000001 0.000209 0.000063 -2126.2%\n",
|
|
340
|
+
" 4x4 Element-wise (sin(A) + cos(B)) 0.000007 0.000001 0.000180 0.000066 -2449.5%\n",
|
|
341
|
+
" 4x4 Broadcasting (A + B.T) 0.000007 0.000003 0.000181 0.000065 -2564.1%\n",
|
|
342
|
+
"1080x1080 Matrix Chain (A @ B @ C) 0.005951 0.000356 0.001355 0.000392 +77.2%\n",
|
|
343
|
+
"1080x1080 Element-wise (sin(A) + cos(B)) 0.002820 0.000107 0.000432 0.000207 +84.7%\n",
|
|
344
|
+
"1080x1080 Broadcasting (A + B.T) 0.000212 0.000035 0.000428 0.000206 -102.0%\n",
|
|
345
|
+
"2080x2080 Matrix Chain (A @ B @ C) 0.027609 0.001255 0.004550 0.002528 +83.5%\n",
|
|
346
|
+
"2080x2080 Element-wise (sin(A) + cos(B)) 0.010086 0.000417 0.001175 0.000350 +88.3%\n",
|
|
347
|
+
"2080x2080 Broadcasting (A + B.T) 0.000856 0.000068 0.001124 0.000241 -31.2%\n",
|
|
348
|
+
"3080x3080 Matrix Chain (A @ B @ C) 0.093115 0.003823 0.013649 0.000513 +85.3%\n",
|
|
349
|
+
"3080x3080 Element-wise (sin(A) + cos(B)) 0.022586 0.000756 0.001930 0.000287 +91.5%\n",
|
|
350
|
+
"3080x3080 Broadcasting (A + B.T) 0.002580 0.000161 0.001937 0.000257 +24.9%\n"
|
|
351
|
+
]
|
|
352
|
+
}
|
|
353
|
+
],
|
|
354
|
+
"source": [
|
|
355
|
+
"print(f\"\\nBenchmark Results over {iteration} repetitions:\")\n",
|
|
356
|
+
"print(results.to_string(index=False))"
|
|
357
|
+
]
|
|
358
|
+
},
|
|
359
|
+
{
|
|
360
|
+
"cell_type": "code",
|
|
361
|
+
"execution_count": null,
|
|
362
|
+
"metadata": {},
|
|
363
|
+
"outputs": [],
|
|
364
|
+
"source": [
|
|
365
|
+
"# # Additional timing analysis - separate compilation vs execution time\n",
|
|
366
|
+
"# if MLX_AVAILABLE:\n",
|
|
367
|
+
"# print(\"\\n=== Detailed MLX Timing Analysis ===\")\n",
|
|
368
|
+
" \n",
|
|
369
|
+
"# # Test with medium-sized matrix\n",
|
|
370
|
+
"# np.random.seed(42)\n",
|
|
371
|
+
"# A = np.random.randn(512, 512).astype(np.float32)\n",
|
|
372
|
+
"# B = np.random.randn(512, 512).astype(np.float32)\n",
|
|
373
|
+
"# C = np.random.randn(512, 512).astype(np.float32)\n",
|
|
374
|
+
" \n",
|
|
375
|
+
"# # Create PyTensor function (compilation time)\n",
|
|
376
|
+
"# start = time.perf_counter()\n",
|
|
377
|
+
"# pt_A = pt.matrix('A', dtype='float32')\n",
|
|
378
|
+
"# pt_B = pt.matrix('B', dtype='float32')\n",
|
|
379
|
+
"# pt_C = pt.matrix('C', dtype='float32')\n",
|
|
380
|
+
"# result_expr = pt_A @ pt_B @ pt_C\n",
|
|
381
|
+
"# f_mlx = function([pt_A, pt_B, pt_C], result_expr, mode=pytensor_mlx_mode)\n",
|
|
382
|
+
"# compilation_time = time.perf_counter() - start\n",
|
|
383
|
+
" \n",
|
|
384
|
+
"# # First execution (may include additional compilation/optimization)\n",
|
|
385
|
+
"# start = time.perf_counter()\n",
|
|
386
|
+
"# result = f_mlx(A, B, C)\n",
|
|
387
|
+
"# mx.eval(result) # Force evaluation\n",
|
|
388
|
+
"# first_exec_time = time.perf_counter() - start\n",
|
|
389
|
+
" \n",
|
|
390
|
+
"# # Subsequent executions (should be faster)\n",
|
|
391
|
+
"# exec_times = []\n",
|
|
392
|
+
"# for _ in range(1000):\n",
|
|
393
|
+
"# start = time.perf_counter()\n",
|
|
394
|
+
"# result = f_mlx(A, B, C)\n",
|
|
395
|
+
"# mx.eval(result)\n",
|
|
396
|
+
"# exec_times.append(time.perf_counter() - start)\n",
|
|
397
|
+
" \n",
|
|
398
|
+
"# avg_exec_time = np.mean(exec_times)\n",
|
|
399
|
+
"# std_exec_time = np.std(exec_times)\n",
|
|
400
|
+
" \n",
|
|
401
|
+
"# print(f\"Compilation time: {compilation_time:.4f}s\")\n",
|
|
402
|
+
"# print(f\"First execution: {first_exec_time:.4f}s\")\n",
|
|
403
|
+
"# print(f\"Average execution (5 runs): {avg_exec_time:.4f}s ± {std_exec_time:.4f}s\")\n",
|
|
404
|
+
"# print(f\"Individual execution times: {[f'{t:.4f}' for t in exec_times]}\")\n"
|
|
405
|
+
]
|
|
406
|
+
},
|
|
407
|
+
{
|
|
408
|
+
"cell_type": "code",
|
|
409
|
+
"execution_count": null,
|
|
410
|
+
"metadata": {},
|
|
411
|
+
"outputs": [],
|
|
412
|
+
"source": []
|
|
413
|
+
}
|
|
414
|
+
],
|
|
415
|
+
"metadata": {
|
|
416
|
+
"kernelspec": {
|
|
417
|
+
"display_name": "mlx_env",
|
|
418
|
+
"language": "python",
|
|
419
|
+
"name": "python3"
|
|
420
|
+
},
|
|
421
|
+
"language_info": {
|
|
422
|
+
"codemirror_mode": {
|
|
423
|
+
"name": "ipython",
|
|
424
|
+
"version": 3
|
|
425
|
+
},
|
|
426
|
+
"file_extension": ".py",
|
|
427
|
+
"mimetype": "text/x-python",
|
|
428
|
+
"name": "python",
|
|
429
|
+
"nbconvert_exporter": "python",
|
|
430
|
+
"pygments_lexer": "ipython3",
|
|
431
|
+
"version": "3.12.2"
|
|
432
|
+
}
|
|
433
|
+
},
|
|
434
|
+
"nbformat": 4,
|
|
435
|
+
"nbformat_minor": 2
|
|
436
|
+
}
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
requires = [
|
|
3
3
|
"setuptools>=59.0.0",
|
|
4
4
|
"cython",
|
|
5
|
-
"numpy>=
|
|
5
|
+
"numpy>=2.0",
|
|
6
6
|
"versioneer[toml]==0.29",
|
|
7
7
|
]
|
|
8
8
|
build-backend = "setuptools.build_meta"
|
|
@@ -10,7 +10,7 @@ build-backend = "setuptools.build_meta"
|
|
|
10
10
|
[project]
|
|
11
11
|
name = "pytensor"
|
|
12
12
|
dynamic = ['version']
|
|
13
|
-
requires-python = ">=3.
|
|
13
|
+
requires-python = ">=3.11,<3.14"
|
|
14
14
|
authors = [{ name = "pymc-devs", email = "pymc.devs@gmail.com" }]
|
|
15
15
|
description = "Optimizing compiler for evaluating mathematical expressions on CPUs and GPUs."
|
|
16
16
|
readme = "README.rst"
|
|
@@ -30,7 +30,6 @@ classifiers = [
|
|
|
30
30
|
"Operating System :: Unix",
|
|
31
31
|
"Operating System :: MacOS",
|
|
32
32
|
"Programming Language :: Python :: 3",
|
|
33
|
-
"Programming Language :: Python :: 3.10",
|
|
34
33
|
"Programming Language :: Python :: 3.11",
|
|
35
34
|
"Programming Language :: Python :: 3.12",
|
|
36
35
|
"Programming Language :: Python :: 3.13",
|
|
@@ -49,7 +48,7 @@ keywords = [
|
|
|
49
48
|
dependencies = [
|
|
50
49
|
"setuptools>=59.0.0",
|
|
51
50
|
"scipy>=1,<2",
|
|
52
|
-
"numpy>=
|
|
51
|
+
"numpy>=2.0",
|
|
53
52
|
"filelock>=3.15",
|
|
54
53
|
"etuples",
|
|
55
54
|
"logical-unification",
|
|
@@ -169,7 +168,7 @@ lines-after-imports = 2
|
|
|
169
168
|
|
|
170
169
|
|
|
171
170
|
[tool.mypy]
|
|
172
|
-
python_version = "3.
|
|
171
|
+
python_version = "3.11"
|
|
173
172
|
ignore_missing_imports = true
|
|
174
173
|
strict_equality = true
|
|
175
174
|
warn_redundant_casts = true
|
|
@@ -8,11 +8,11 @@ import json
|
|
|
8
8
|
|
|
9
9
|
version_json = '''
|
|
10
10
|
{
|
|
11
|
-
"date": "2025-10-
|
|
11
|
+
"date": "2025-10-14T09:22:05+0200",
|
|
12
12
|
"dirty": false,
|
|
13
13
|
"error": null,
|
|
14
|
-
"full-revisionid": "
|
|
15
|
-
"version": "2.
|
|
14
|
+
"full-revisionid": "f772066a9691c64430e100581ab1c2398e43451e",
|
|
15
|
+
"version": "2.35.0"
|
|
16
16
|
}
|
|
17
17
|
''' # END VERSION_JSON
|
|
18
18
|
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import os
|
|
3
3
|
import sys
|
|
4
|
+
from pathlib import Path
|
|
4
5
|
|
|
5
6
|
|
|
6
7
|
if sys.platform == "win32":
|
|
@@ -24,7 +25,7 @@ _logger = logging.getLogger("pytensor.bin.pytensor-cache")
|
|
|
24
25
|
|
|
25
26
|
def print_help(exit_status):
|
|
26
27
|
if exit_status:
|
|
27
|
-
print(f
|
|
28
|
+
print(f'command "{" ".join(sys.argv)}" not recognized')
|
|
28
29
|
print('Type "pytensor-cache" to print the cache location')
|
|
29
30
|
print('Type "pytensor-cache help" to print this help')
|
|
30
31
|
print('Type "pytensor-cache clear" to erase the cache')
|
|
@@ -65,11 +66,7 @@ def main():
|
|
|
65
66
|
# Print a warning if some cached modules were not removed, so that the
|
|
66
67
|
# user knows he should manually delete them, or call
|
|
67
68
|
# pytensor-cache purge, # to properly clear the cache.
|
|
68
|
-
items =
|
|
69
|
-
item
|
|
70
|
-
for item in sorted(os.listdir(cache.dirname))
|
|
71
|
-
if item.startswith("tmp")
|
|
72
|
-
]
|
|
69
|
+
items = list(Path(cache.dirname).glob("tmp*"))
|
|
73
70
|
if items:
|
|
74
71
|
_logger.warning(
|
|
75
72
|
"There remain elements in the cache dir that you may "
|
|
@@ -122,7 +122,7 @@ def construct_nominal_fgraph(
|
|
|
122
122
|
(
|
|
123
123
|
local_inputs,
|
|
124
124
|
local_outputs,
|
|
125
|
-
(
|
|
125
|
+
(_clone_d, update_d, update_expr, new_shared_inputs),
|
|
126
126
|
) = new
|
|
127
127
|
|
|
128
128
|
assert len(local_inputs) == len(inputs) + len(implicit_shared_inputs)
|
|
@@ -12,7 +12,7 @@ from pytensor.compile.profiling import ProfileStats
|
|
|
12
12
|
from pytensor.graph import Variable
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
__all__ = ["
|
|
15
|
+
__all__ = ["pfunc", "types"]
|
|
16
16
|
|
|
17
17
|
__docformat__ = "restructuredtext en"
|
|
18
18
|
_logger = logging.getLogger("pytensor.compile.function")
|
|
@@ -328,8 +328,7 @@ def rebuild_collect_shared(
|
|
|
328
328
|
cloned_outputs = [] # TODO: get Function.__call__ to return None
|
|
329
329
|
else:
|
|
330
330
|
raise TypeError(
|
|
331
|
-
"output must be an PyTensor Variable or Out "
|
|
332
|
-
"instance (or list of them)",
|
|
331
|
+
"output must be an PyTensor Variable or Out instance (or list of them)",
|
|
333
332
|
outputs,
|
|
334
333
|
)
|
|
335
334
|
|
|
@@ -592,7 +591,7 @@ def construct_pfunc_ins_and_outs(
|
|
|
592
591
|
clone_inner_graphs=True,
|
|
593
592
|
)
|
|
594
593
|
input_variables, cloned_extended_outputs, other_stuff = output_vars
|
|
595
|
-
clone_d, update_d,
|
|
594
|
+
clone_d, update_d, _update_expr, shared_inputs = other_stuff
|
|
596
595
|
|
|
597
596
|
# Recover only the clones of the original outputs
|
|
598
597
|
if outputs is None:
|
|
@@ -215,7 +215,7 @@ def add_supervisor_to_fgraph(
|
|
|
215
215
|
input
|
|
216
216
|
for spec, input in zip(input_specs, fgraph.inputs, strict=True)
|
|
217
217
|
if not (
|
|
218
|
-
spec.mutable or has_destroy_handler and fgraph.has_destroyers([input])
|
|
218
|
+
spec.mutable or (has_destroy_handler and fgraph.has_destroyers([input]))
|
|
219
219
|
)
|
|
220
220
|
)
|
|
221
221
|
)
|
|
@@ -27,6 +27,7 @@ from pytensor.graph.rewriting.db import (
|
|
|
27
27
|
from pytensor.link.basic import Linker, PerformLinker
|
|
28
28
|
from pytensor.link.c.basic import CLinker, OpWiseCLinker
|
|
29
29
|
from pytensor.link.jax.linker import JAXLinker
|
|
30
|
+
from pytensor.link.mlx.linker import MLXLinker
|
|
30
31
|
from pytensor.link.numba.linker import NumbaLinker
|
|
31
32
|
from pytensor.link.pytorch.linker import PytorchLinker
|
|
32
33
|
from pytensor.link.vm import VMLinker
|
|
@@ -50,6 +51,7 @@ predefined_linkers = {
|
|
|
50
51
|
"jax": JAXLinker(),
|
|
51
52
|
"pytorch": PytorchLinker(),
|
|
52
53
|
"numba": NumbaLinker(),
|
|
54
|
+
"mlx": MLXLinker(),
|
|
53
55
|
}
|
|
54
56
|
|
|
55
57
|
|
|
@@ -407,7 +409,7 @@ class Mode:
|
|
|
407
409
|
optimizations.
|
|
408
410
|
"""
|
|
409
411
|
|
|
410
|
-
|
|
412
|
+
_link, opt = self.get_linker_optimizer(
|
|
411
413
|
self.provided_linker, self.provided_optimizer
|
|
412
414
|
)
|
|
413
415
|
return self.clone(optimizer=opt.register(*optimizations))
|
|
@@ -504,6 +506,20 @@ PYTORCH = Mode(
|
|
|
504
506
|
),
|
|
505
507
|
)
|
|
506
508
|
|
|
509
|
+
MLX = Mode(
|
|
510
|
+
MLXLinker(),
|
|
511
|
+
RewriteDatabaseQuery(
|
|
512
|
+
include=["fast_run"],
|
|
513
|
+
exclude=[
|
|
514
|
+
"cxx_only",
|
|
515
|
+
"BlasOpt",
|
|
516
|
+
"fusion",
|
|
517
|
+
"inplace",
|
|
518
|
+
"scan_save_mem_prealloc",
|
|
519
|
+
],
|
|
520
|
+
),
|
|
521
|
+
)
|
|
522
|
+
|
|
507
523
|
|
|
508
524
|
predefined_modes = {
|
|
509
525
|
"FAST_COMPILE": FAST_COMPILE,
|
|
@@ -511,6 +527,7 @@ predefined_modes = {
|
|
|
511
527
|
"JAX": JAX,
|
|
512
528
|
"NUMBA": NUMBA,
|
|
513
529
|
"PYTORCH": PYTORCH,
|
|
530
|
+
"MLX": MLX,
|
|
514
531
|
}
|
|
515
532
|
|
|
516
533
|
_CACHED_RUNTIME_MODES: dict[str, Mode] = {}
|