saiunit 0.2.0__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (183)
  1. {saiunit-0.2.0 → saiunit-0.2.2}/PKG-INFO +72 -5
  2. {saiunit-0.2.0 → saiunit-0.2.2}/README.md +57 -3
  3. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/__init__.py +33 -1
  4. saiunit-0.2.2/brainunit/brainunit/_backend.py +17 -0
  5. saiunit-0.2.2/brainunit/brainunit/_exceptions.py +17 -0
  6. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/linalg/__init__.py +6 -2
  7. {saiunit-0.2.0 → saiunit-0.2.2}/pyproject.toml +11 -1
  8. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/__init__.py +75 -12
  9. saiunit-0.2.2/saiunit/_backend.py +390 -0
  10. saiunit-0.2.2/saiunit/_backend_parametrize_test.py +126 -0
  11. saiunit-0.2.2/saiunit/_backend_test.py +503 -0
  12. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_base_decorators.py +16 -16
  13. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_base_dimension.py +28 -13
  14. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_base_getters.py +74 -11
  15. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_base_quantity.py +651 -171
  16. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_base_quantity_test.py +241 -15
  17. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_base_unit.py +303 -88
  18. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_base_unit_test.py +156 -46
  19. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_celsius.py +70 -1
  20. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_celsius_test.py +40 -0
  21. saiunit-0.2.2/saiunit/_compatible_import.py +129 -0
  22. saiunit-0.2.2/saiunit/_dask_laziness_test.py +139 -0
  23. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_display_test.py +30 -0
  24. saiunit-0.2.2/saiunit/_exceptions.py +24 -0
  25. saiunit-0.2.2/saiunit/_exceptions_test.py +26 -0
  26. saiunit-0.2.2/saiunit/_jax_compat.py +279 -0
  27. saiunit-0.2.2/saiunit/_jax_guard.py +90 -0
  28. saiunit-0.2.2/saiunit/_jax_guard_test.py +104 -0
  29. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_matplotlib_compat.py +26 -1
  30. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_matplotlib_compat_test.py +6 -3
  31. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_misc.py +10 -2
  32. saiunit-0.2.2/saiunit/_ndonnx_test.py +59 -0
  33. saiunit-0.2.2/saiunit/_no_jax_test.py +113 -0
  34. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_sparse_base.py +84 -13
  35. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_unit_common.py +40 -0
  36. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_unit_constants.py +2 -11
  37. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_unit_constants_test.py +0 -13
  38. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_version.py +1 -1
  39. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/autograd/_jacobian.py +46 -24
  40. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/constants.py +12 -17
  41. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/constants_test.py +0 -6
  42. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/custom_array.py +33 -58
  43. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/fft/_fft_change_unit.py +64 -44
  44. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/fft/_fft_change_unit_test.py +11 -0
  45. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/fft/_fft_keep_unit.py +14 -6
  46. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/lax/_lax_accept_unitless.py +59 -31
  47. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/lax/_lax_array_creation.py +12 -9
  48. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/lax/_lax_change_unit.py +54 -40
  49. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/lax/_lax_keep_unit.py +162 -90
  50. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/lax/_lax_keep_unit_test.py +45 -0
  51. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/lax/_misc.py +10 -7
  52. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/linalg/_linalg_change_unit.py +26 -20
  53. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/linalg/_linalg_keep_unit.py +44 -26
  54. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/linalg/_linalg_keep_unit_test.py +11 -0
  55. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/linalg/_linalg_remove_unit.py +11 -9
  56. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/_activation.py +8 -2
  57. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/_einops.py +124 -125
  58. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/_exprel.py +32 -20
  59. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/_fun_accept_unitless.py +110 -60
  60. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/_fun_accept_unitless_test.py +11 -0
  61. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/_fun_array_creation.py +104 -81
  62. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/_fun_array_creation_test.py +34 -0
  63. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/_fun_change_unit.py +183 -97
  64. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/_fun_change_unit_test.py +115 -0
  65. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/_fun_keep_unit.py +399 -202
  66. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/_fun_keep_unit_test.py +27 -0
  67. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/_fun_remove_unit.py +120 -64
  68. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/_fun_remove_unit_test.py +43 -0
  69. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/_misc.py +53 -33
  70. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/sparse/_coo.py +13 -10
  71. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/sparse/_csr.py +35 -25
  72. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/sparse/_csr_test.py +20 -0
  73. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/typing.py +3 -2
  74. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit.egg-info/PKG-INFO +72 -5
  75. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit.egg-info/SOURCES.txt +13 -7
  76. saiunit-0.2.2/saiunit.egg-info/requires.txt +38 -0
  77. saiunit-0.2.0/brainunit/brainunit/sparse/_block_csr.py +0 -20
  78. saiunit-0.2.0/brainunit/brainunit/sparse/_block_ell.py +0 -20
  79. saiunit-0.2.0/saiunit/_compatible_import.py +0 -89
  80. saiunit-0.2.0/saiunit/sparse/_block_csr.py +0 -249
  81. saiunit-0.2.0/saiunit/sparse/_block_csr_benchmark.py +0 -99
  82. saiunit-0.2.0/saiunit/sparse/_block_ell.py +0 -294
  83. saiunit-0.2.0/saiunit/sparse/_block_ell_benchmark.py +0 -99
  84. saiunit-0.2.0/saiunit/sparse/_block_sparse_test.py +0 -55
  85. saiunit-0.2.0/saiunit.egg-info/requires.txt +0 -19
  86. {saiunit-0.2.0 → saiunit-0.2.2}/LICENSE +0 -0
  87. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/_base_decorators.py +0 -0
  88. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/_base_dimension.py +0 -0
  89. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/_base_getters.py +0 -0
  90. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/_base_quantity.py +0 -0
  91. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/_base_unit.py +0 -0
  92. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/_celsius.py +0 -0
  93. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/_misc.py +0 -0
  94. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/_unit_common.py +0 -0
  95. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/_unit_constants.py +0 -0
  96. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/_unit_shortcuts.py +0 -0
  97. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/autograd/__init__.py +0 -0
  98. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/autograd/_hessian.py +0 -0
  99. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/autograd/_jacobian.py +0 -0
  100. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/autograd/_value_and_grad.py +0 -0
  101. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/autograd/_vector_grad.py +0 -0
  102. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/constants.py +0 -0
  103. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/custom_array.py +0 -0
  104. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/fft/__init__.py +0 -0
  105. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/fft/_fft_change_unit.py +0 -0
  106. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/fft/_fft_keep_unit.py +0 -0
  107. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/lax/__init__.py +0 -0
  108. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/lax/_lax_accept_unitless.py +0 -0
  109. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/lax/_lax_array_creation.py +0 -0
  110. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/lax/_lax_change_unit.py +0 -0
  111. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/lax/_lax_keep_unit.py +0 -0
  112. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/lax/_lax_linalg.py +0 -0
  113. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/lax/_lax_remove_unit.py +0 -0
  114. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/lax/_misc.py +0 -0
  115. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/linalg/_linalg_change_unit.py +0 -0
  116. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/linalg/_linalg_keep_unit.py +0 -0
  117. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/linalg/_linalg_remove_unit.py +0 -0
  118. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/math/__init__.py +0 -0
  119. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/math/_activation.py +0 -0
  120. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/math/_alias.py +0 -0
  121. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/math/_einops.py +0 -0
  122. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/math/_fun_accept_unitless.py +0 -0
  123. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/math/_fun_array_creation.py +0 -0
  124. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/math/_fun_change_unit.py +0 -0
  125. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/math/_fun_keep_unit.py +0 -0
  126. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/math/_fun_remove_unit.py +0 -0
  127. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/math/_misc.py +0 -0
  128. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/math/fft.py +0 -0
  129. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/math/linalg.py +0 -0
  130. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/sparse/__init__.py +0 -0
  131. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/sparse/_coo.py +0 -0
  132. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/sparse/_csr.py +0 -0
  133. {saiunit-0.2.0 → saiunit-0.2.2}/brainunit/brainunit/typing.py +0 -0
  134. {saiunit-0.2.0 → saiunit-0.2.2}/examples/matplotlib_plain_basics.py +0 -0
  135. {saiunit-0.2.0 → saiunit-0.2.2}/examples/matplotlib_quantity_basics.py +0 -0
  136. {saiunit-0.2.0 → saiunit-0.2.2}/examples/matplotlib_quantity_interactive.py +0 -0
  137. {saiunit-0.2.0 → saiunit-0.2.2}/examples/matplotlib_quantity_vs_plain.py +0 -0
  138. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_base_decorators_test.py +0 -0
  139. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_base_dimension_test.py +0 -0
  140. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_base_getters_test.py +0 -0
  141. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_misc_test.py +0 -0
  142. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/_unit_shortcuts.py +0 -0
  143. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/autograd/__init__.py +0 -0
  144. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/autograd/_hessian.py +0 -0
  145. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/autograd/_hessian_test.py +0 -0
  146. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/autograd/_jacobian_test.py +0 -0
  147. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/autograd/_misc.py +0 -0
  148. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/autograd/_value_and_grad.py +0 -0
  149. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/autograd/_value_and_grad_test.py +0 -0
  150. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/autograd/_vector_grad.py +0 -0
  151. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/autograd/_vector_grad_test.py +0 -0
  152. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/custom_array_test.py +0 -0
  153. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/fft/__init__.py +0 -0
  154. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/fft/_fft_keep_unit_test.py +0 -0
  155. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/lax/__init__.py +0 -0
  156. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/lax/_lax_accept_unitless_test.py +0 -0
  157. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/lax/_lax_array_creation_test.py +0 -0
  158. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/lax/_lax_change_unit_test.py +0 -0
  159. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/lax/_lax_linalg.py +0 -0
  160. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/lax/_lax_linalg_test.py +0 -0
  161. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/lax/_lax_remove_unit.py +0 -0
  162. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/lax/_lax_remove_unit_test.py +0 -0
  163. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/lax/_misc_test.py +0 -0
  164. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/linalg/__init__.py +0 -0
  165. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/linalg/_linalg_change_unit_test.py +0 -0
  166. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/linalg/_linalg_docstring_test.py +0 -0
  167. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/linalg/_linalg_remove_unit_test.py +0 -0
  168. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/__init__.py +0 -0
  169. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/_activation_test.py +0 -0
  170. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/_alias.py +0 -0
  171. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/_einops_parsing.py +0 -0
  172. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/_einops_parsing_test.py +0 -0
  173. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/_einops_test.py +0 -0
  174. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/_exprel_test.py +0 -0
  175. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/_misc_test.py +0 -0
  176. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/fft.py +0 -0
  177. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/math/linalg.py +0 -0
  178. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/sparse/__init__.py +0 -0
  179. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/sparse/_coo_test.py +0 -0
  180. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit/typing_test.py +0 -0
  181. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit.egg-info/dependency_links.txt +0 -0
  182. {saiunit-0.2.0 → saiunit-0.2.2}/saiunit.egg-info/top_level.txt +0 -0
  183. {saiunit-0.2.0 → saiunit-0.2.2}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: saiunit
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Summary: Enabling Unit-aware Computations for AI-driven Scientific Computing.
5
5
  Author-email: SAIUnit Developers <chao.brain@qq.com>
6
6
  License-Expression: Apache-2.0
@@ -28,12 +28,15 @@ Classifier: Topic :: Software Development :: Libraries
28
28
  Requires-Python: >=3.10
29
29
  Description-Content-Type: text/markdown
30
30
  License-File: LICENSE
31
- Requires-Dist: jax
32
31
  Requires-Dist: numpy
33
32
  Requires-Dist: typing_extensions
33
+ Requires-Dist: array_api_compat>=1.9
34
+ Requires-Dist: opt_einsum
34
35
  Provides-Extra: testing
35
36
  Requires-Dist: pytest; extra == "testing"
36
37
  Requires-Dist: brainstate; extra == "testing"
38
+ Provides-Extra: jax
39
+ Requires-Dist: jax; extra == "jax"
37
40
  Provides-Extra: cpu
38
41
  Requires-Dist: jax[cpu]; extra == "cpu"
39
42
  Provides-Extra: cuda12
@@ -42,6 +45,16 @@ Provides-Extra: cuda13
42
45
  Requires-Dist: jax[cuda13]; extra == "cuda13"
43
46
  Provides-Extra: tpu
44
47
  Requires-Dist: jax[tpu]; extra == "tpu"
48
+ Provides-Extra: cupy
49
+ Requires-Dist: cupy-cuda12x>=13.0; extra == "cupy"
50
+ Provides-Extra: torch
51
+ Requires-Dist: torch>=2.0; extra == "torch"
52
+ Provides-Extra: dask
53
+ Requires-Dist: dask[array]>=2024.1; extra == "dask"
54
+ Provides-Extra: ndonnx
55
+ Requires-Dist: ndonnx>=0.9; extra == "ndonnx"
56
+ Provides-Extra: all
57
+ Requires-Dist: saiunit[cupy,dask,jax,ndonnx,torch]; extra == "all"
45
58
  Dynamic: license-file
46
59
 
47
60
  <p align="center">
@@ -124,17 +137,71 @@ jax.vmap(f)(u.math.arange(0. * u.mV, 10. * u.mV, 1. * u.mV))
124
137
 
125
138
 
126
139
 
140
+ ## Multiple-backend support
141
+
142
+ `saiunit` is **backend-agnostic**: a `Quantity` pairs a unit with an array
143
+ mantissa, and that mantissa can live on any of the supported array libraries.
144
+ Every unit-aware operation dispatches to the matching backend, so you can stay
145
+ in one library end-to-end or convert with a single method call.
146
+
147
+ | Backend | Mantissa | Install | When to use |
148
+ |----------|---------------------|--------------------------|---------------------------------------------|
149
+ | `numpy` | `numpy.ndarray` | core (always installed) | eager CPU, scipy/pandas/sklearn interop |
150
+ | `jax` | `jax.Array` | `saiunit[jax]` (or `[cpu]`/`[cuda12]`/`[cuda13]`/`[tpu]`) | autograd, JIT, vmap, accelerators |
151
+ | `cupy` | `cupy.ndarray` | `saiunit[cupy]` | NVIDIA GPU arrays |
152
+ | `torch` | `torch.Tensor` | `saiunit[torch]` | PyTorch models, torch autograd |
153
+ | `dask` | `dask.array.Array` | `saiunit[dask]` | out-of-core / parallel, lazy compute |
154
+ | `ndonnx` | `ndonnx.Array` | `saiunit[ndonnx]` | symbolic graph for ONNX export |
155
+
156
+ Select or override the backend explicitly with `u.using_backend(...)` or
157
+ `u.set_default_backend(...)`; convert with `q.to_jax()` / `q.to_numpy()` /
158
+ `q.to_cupy()` / `q.to_torch()` / `q.to_dask()` / `q.to_ndonnx()`. Requesting an
159
+ uninstalled backend raises `saiunit.BackendError` with the install command —
160
+ not a bare `ImportError`. See the
161
+ [Backends documentation](https://saiunit.readthedocs.io/en/latest/backends/overview.html)
162
+ for the full story.
163
+
164
+
127
165
  ## Installation
128
166
 
129
- ``saiunit`` has been well tested on ``python>=3.9`` + ``jax>=0.4.30`` environments, and can be installed on Windows, Linux, and MacOS.
167
+ ``saiunit`` has been well tested on ``python>=3.10`` and can be installed on
168
+ Windows, Linux, and MacOS. The core package depends only on NumPy. JAX is
169
+ optional — install it to enable the ``saiunit.autograd``, ``saiunit.lax``,
170
+ and ``saiunit.sparse`` submodules, the custom ``exprel`` primitive, and the
171
+ ``"jax"`` backend.
130
172
 
131
- You can install ``saiunit`` via pip:
173
+ Install the NumPy-only core:
132
174
 
133
175
  ```bash
134
176
  pip install saiunit --upgrade
135
177
  ```
136
178
 
137
- which should install in about 1 minute. If you want to install the latest version from the source, you can clone the repository and install it:
179
+ Or pull in JAX with the accelerator build that matches your hardware:
180
+
181
+ ```bash
182
+ pip install -U saiunit[jax] # plain JAX
183
+ pip install -U saiunit[cpu] # pinned JAX CPU wheels
184
+ pip install -U saiunit[cuda12] # JAX on CUDA 12
185
+ pip install -U saiunit[cuda13] # JAX on CUDA 13
186
+ pip install -U saiunit[tpu] # JAX on TPU
187
+ ```
188
+
189
+ Opt into additional array backends with the matching extra:
190
+
191
+ ```bash
192
+ pip install -U saiunit[cupy] # CuPy (NVIDIA GPU)
193
+ pip install -U saiunit[torch] # PyTorch
194
+ pip install -U saiunit[dask] # Dask
195
+ pip install -U saiunit[ndonnx] # ndonnx
196
+ pip install -U saiunit[all] # jax + cupy + torch + dask + ndonnx
197
+ ```
198
+
199
+ Without JAX, the NumPy backend is auto-selected and any access to a
200
+ JAX-only submodule (``saiunit.autograd``, ``saiunit.lax``, ``saiunit.sparse``)
201
+ raises ``saiunit.BackendError`` with an install hint. The optional extras are
202
+ independent and can be combined freely.
203
+
204
+ To install the latest version from source:
138
205
 
139
206
  ```bash
140
207
  git clone https://github.com/chaobrain/saiunit.git
@@ -78,17 +78,71 @@ jax.vmap(f)(u.math.arange(0. * u.mV, 10. * u.mV, 1. * u.mV))
78
78
 
79
79
 
80
80
 
81
+ ## Multiple-backend support
82
+
83
+ `saiunit` is **backend-agnostic**: a `Quantity` pairs a unit with an array
84
+ mantissa, and that mantissa can live on any of the supported array libraries.
85
+ Every unit-aware operation dispatches to the matching backend, so you can stay
86
+ in one library end-to-end or convert with a single method call.
87
+
88
+ | Backend | Mantissa | Install | When to use |
89
+ |----------|---------------------|--------------------------|---------------------------------------------|
90
+ | `numpy` | `numpy.ndarray` | core (always installed) | eager CPU, scipy/pandas/sklearn interop |
91
+ | `jax` | `jax.Array` | `saiunit[jax]` (or `[cpu]`/`[cuda12]`/`[cuda13]`/`[tpu]`) | autograd, JIT, vmap, accelerators |
92
+ | `cupy` | `cupy.ndarray` | `saiunit[cupy]` | NVIDIA GPU arrays |
93
+ | `torch` | `torch.Tensor` | `saiunit[torch]` | PyTorch models, torch autograd |
94
+ | `dask` | `dask.array.Array` | `saiunit[dask]` | out-of-core / parallel, lazy compute |
95
+ | `ndonnx` | `ndonnx.Array` | `saiunit[ndonnx]` | symbolic graph for ONNX export |
96
+
97
+ Select or override the backend explicitly with `u.using_backend(...)` or
98
+ `u.set_default_backend(...)`; convert with `q.to_jax()` / `q.to_numpy()` /
99
+ `q.to_cupy()` / `q.to_torch()` / `q.to_dask()` / `q.to_ndonnx()`. Requesting an
100
+ uninstalled backend raises `saiunit.BackendError` with the install command —
101
+ not a bare `ImportError`. See the
102
+ [Backends documentation](https://saiunit.readthedocs.io/en/latest/backends/overview.html)
103
+ for the full story.
104
+
105
+
81
106
  ## Installation
82
107
 
83
- ``saiunit`` has been well tested on ``python>=3.9`` + ``jax>=0.4.30`` environments, and can be installed on Windows, Linux, and MacOS.
108
+ ``saiunit`` has been well tested on ``python>=3.10`` and can be installed on
109
+ Windows, Linux, and MacOS. The core package depends only on NumPy. JAX is
110
+ optional — install it to enable the ``saiunit.autograd``, ``saiunit.lax``,
111
+ and ``saiunit.sparse`` submodules, the custom ``exprel`` primitive, and the
112
+ ``"jax"`` backend.
84
113
 
85
- You can install ``saiunit`` via pip:
114
+ Install the NumPy-only core:
86
115
 
87
116
  ```bash
88
117
  pip install saiunit --upgrade
89
118
  ```
90
119
 
91
- which should install in about 1 minute. If you want to install the latest version from the source, you can clone the repository and install it:
120
+ Or pull in JAX with the accelerator build that matches your hardware:
121
+
122
+ ```bash
123
+ pip install -U saiunit[jax] # plain JAX
124
+ pip install -U saiunit[cpu] # pinned JAX CPU wheels
125
+ pip install -U saiunit[cuda12] # JAX on CUDA 12
126
+ pip install -U saiunit[cuda13] # JAX on CUDA 13
127
+ pip install -U saiunit[tpu] # JAX on TPU
128
+ ```
129
+
130
+ Opt into additional array backends with the matching extra:
131
+
132
+ ```bash
133
+ pip install -U saiunit[cupy] # CuPy (NVIDIA GPU)
134
+ pip install -U saiunit[torch] # PyTorch
135
+ pip install -U saiunit[dask] # Dask
136
+ pip install -U saiunit[ndonnx] # ndonnx
137
+ pip install -U saiunit[all] # jax + cupy + torch + dask + ndonnx
138
+ ```
139
+
140
+ Without JAX, the NumPy backend is auto-selected and any access to a
141
+ JAX-only submodule (``saiunit.autograd``, ``saiunit.lax``, ``saiunit.sparse``)
142
+ raises ``saiunit.BackendError`` with an install hint. The optional extras are
143
+ independent and can be combined freely.
144
+
145
+ To install the latest version from source:
92
146
 
93
147
  ```bash
94
148
  git clone https://github.com/chaobrain/saiunit.git
@@ -18,6 +18,7 @@ __version__ = saiunit.__version__
18
18
  __version_info__ = saiunit.__version_info__
19
19
 
20
20
  from . import autograd
21
+ from saiunit._matplotlib_compat import enable_matplotlib_support
21
22
  from . import constants
22
23
  from . import fft
23
24
  from . import lax
@@ -34,6 +35,18 @@ from ._base_dimension import (
34
35
  get_dim_for_display,
35
36
  get_or_create_dimension,
36
37
  )
38
+ from ._backend import (
39
+ get_default_backend,
40
+ is_cupy_array,
41
+ is_dask_array,
42
+ is_jax_array,
43
+ is_ndonnx_array,
44
+ is_numpy_array,
45
+ is_torch_array,
46
+ set_default_backend,
47
+ using_backend,
48
+ )
49
+ from ._exceptions import BackendError
37
50
  from ._base_getters import (
38
51
  array_with_unit,
39
52
  assert_quantity,
@@ -48,6 +61,7 @@ from ._base_getters import (
48
61
  have_same_dim,
49
62
  is_dimensionless,
50
63
  is_scalar_type,
64
+ is_unit_equal_math,
51
65
  is_unitless,
52
66
  maybe_decimal,
53
67
  split_mantissa_unit,
@@ -55,7 +69,7 @@ from ._base_getters import (
55
69
  )
56
70
  from ._base_quantity import Quantity, compatible_with_equinox
57
71
  from ._base_unit import UNITLESS, Unit, add_standard_unit, parse_unit
58
- from ._celsius import celsius2kelvin, kelvin2celsius
72
+ from ._celsius import celsius2kelvin, kelvin2celsius, fahrenheit2kelvin, kelvin2fahrenheit
59
73
  from ._misc import maybe_custom_array, maybe_custom_array_tree
60
74
  from ._unit_common import *
61
75
  from ._unit_common import __all__ as _common_all
@@ -100,6 +114,18 @@ __all__ = [
100
114
  'DIMENSIONLESS',
101
115
  'DimensionMismatchError',
102
116
  'UnitMismatchError',
117
+ 'BackendError',
118
+
119
+ # _backend
120
+ 'get_default_backend',
121
+ 'set_default_backend',
122
+ 'using_backend',
123
+ 'is_jax_array',
124
+ 'is_numpy_array',
125
+ 'is_cupy_array',
126
+ 'is_torch_array',
127
+ 'is_dask_array',
128
+ 'is_ndonnx_array',
103
129
  'get_or_create_dimension',
104
130
  'get_dim_for_display',
105
131
 
@@ -125,6 +151,7 @@ __all__ = [
125
151
  'assert_quantity',
126
152
  'have_same_dim',
127
153
  'has_same_unit',
154
+ 'is_unit_equal_math',
128
155
  'unit_scale_align_to_first',
129
156
  'array_with_unit',
130
157
 
@@ -140,6 +167,11 @@ __all__ = [
140
167
  # _celsius
141
168
  'celsius2kelvin',
142
169
  'kelvin2celsius',
170
+ 'fahrenheit2kelvin',
171
+ 'kelvin2fahrenheit',
172
+
173
+ # _matplotlib_compat
174
+ 'enable_matplotlib_support',
143
175
 
144
176
  # old version compatibility
145
177
  'avogadro_constant',
@@ -0,0 +1,17 @@
1
+ # Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from saiunit._backend import *
17
+ from saiunit._backend import __all__
@@ -0,0 +1,17 @@
1
+ # Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from saiunit._exceptions import *
17
+ from saiunit._exceptions import __all__
@@ -17,9 +17,13 @@ from ._linalg_change_unit import *
17
17
  from ._linalg_change_unit import __all__ as _linalg_change_unit_all
18
18
  from ._linalg_keep_unit import *
19
19
  from ._linalg_keep_unit import __all__ as _linalg_keep_unit_all
20
+ from ._linalg_remove_unit import *
21
+ from ._linalg_remove_unit import __all__ as _linalg_remove_unit_all
20
22
 
21
23
  __all__ = (_linalg_change_unit_all +
22
- _linalg_keep_unit_all)
24
+ _linalg_keep_unit_all +
25
+ _linalg_remove_unit_all)
23
26
 
24
27
  del (_linalg_change_unit_all,
25
- _linalg_keep_unit_all)
28
+ _linalg_keep_unit_all,
29
+ _linalg_remove_unit_all)
@@ -47,25 +47,35 @@ classifiers = [
47
47
  keywords = ['physical unit', 'physical quantity', 'scientific computing', 'AI for science', ]
48
48
 
49
49
  dependencies = [
50
- 'jax',
51
50
  'numpy',
52
51
  'typing_extensions',
52
+ 'array_api_compat>=1.9',
53
+ 'opt_einsum',
53
54
  ]
54
55
 
55
56
  dynamic = ['version']
56
57
 
58
+
57
59
  [tool.flit.module]
58
60
  name = "saiunit"
59
61
 
62
+
60
63
  [project.urls]
61
64
  homepage = 'https://github.com/chaobrain/saiunit'
62
65
  repository = 'https://github.com/chaobrain/saiunit'
63
66
  "Bug Tracker" = "https://github.com/chaobrain/saiunit/issues"
64
67
  "Documentation" = "https://saiunit.readthedocs.io/"
65
68
 
69
+
66
70
  [project.optional-dependencies]
67
71
  testing = ['pytest', 'brainstate']
72
+ jax = ["jax"]
68
73
  cpu = ["jax[cpu]"]
69
74
  cuda12 = ["jax[cuda12]"]
70
75
  cuda13 = ["jax[cuda13]"]
71
76
  tpu = ["jax[tpu]"]
77
+ cupy = ["cupy-cuda12x>=13.0"]
78
+ torch = ["torch>=2.0"]
79
+ dask = ["dask[array]>=2024.1"]
80
+ ndonnx = ["ndonnx>=0.9"]
81
+ all = ["saiunit[jax,cupy,torch,dask,ndonnx]"]
@@ -14,13 +14,15 @@
14
14
  # ==============================================================================
15
15
 
16
16
  """
17
- saiunit -- Physical units for JAX arrays.
17
+ saiunit -- Physical units for JAX and NumPy arrays.
18
18
 
19
- ``saiunit`` provides a :class:`Quantity` type that pairs a JAX array with a
20
- physical :class:`Unit`, ensuring dimensional correctness at every arithmetic
21
- operation. It also supplies the standard SI base and derived units (e.g.
22
- ``meter``, ``second``, ``volt``), physical constants, and unit-aware wrappers
23
- for NumPy/JAX math functions.
19
+ ``saiunit`` provides a :class:`Quantity` type that pairs a JAX array or
20
+ NumPy array with a physical :class:`Unit`, ensuring dimensional correctness
21
+ at every arithmetic operation. The backend is detected from the mantissa
22
+ type; users can force a default with :func:`set_default_backend` or the
23
+ :func:`using_backend` context manager. ``saiunit`` also supplies the standard
24
+ SI base and derived units (e.g. ``meter``, ``second``, ``volt``), physical
25
+ constants, and unit-aware wrappers for NumPy/JAX math functions.
24
26
 
25
27
  Subpackages
26
28
  -----------
@@ -52,15 +54,13 @@ Examples
52
54
  True
53
55
  """
54
56
 
55
- from . import _matplotlib_compat
56
- from . import autograd
57
+ from ._matplotlib_compat import enable_matplotlib_support
57
58
  from . import constants
58
59
  from . import fft
59
- from . import lax
60
60
  from . import linalg
61
61
  from . import math
62
- from . import sparse
63
62
  from . import typing
63
+ from ._jax_compat import HAS_JAX as _HAS_JAX
64
64
  from ._base_decorators import assign_units, check_dims, check_units
65
65
  from ._base_dimension import (
66
66
  DIMENSIONLESS,
@@ -70,6 +70,18 @@ from ._base_dimension import (
70
70
  get_dim_for_display,
71
71
  get_or_create_dimension,
72
72
  )
73
+ from ._backend import (
74
+ get_default_backend,
75
+ is_cupy_array,
76
+ is_dask_array,
77
+ is_jax_array,
78
+ is_ndonnx_array,
79
+ is_numpy_array,
80
+ is_torch_array,
81
+ set_default_backend,
82
+ using_backend,
83
+ )
84
+ from ._exceptions import BackendError
73
85
  from ._base_getters import (
74
86
  array_with_unit,
75
87
  assert_quantity,
@@ -84,6 +96,7 @@ from ._base_getters import (
84
96
  have_same_dim,
85
97
  is_dimensionless,
86
98
  is_scalar_type,
99
+ is_unit_equal_math,
87
100
  is_unitless,
88
101
  maybe_decimal,
89
102
  split_mantissa_unit,
@@ -91,7 +104,7 @@ from ._base_getters import (
91
104
  )
92
105
  from ._base_quantity import Quantity, compatible_with_equinox
93
106
  from ._base_unit import UNITLESS, Unit, add_standard_unit, parse_unit
94
- from ._celsius import celsius2kelvin, kelvin2celsius
107
+ from ._celsius import celsius2kelvin, kelvin2celsius, fahrenheit2kelvin, kelvin2fahrenheit
95
108
  from ._misc import maybe_custom_array, maybe_custom_array_tree
96
109
  from ._unit_common import *
97
110
  from ._unit_common import __all__ as _common_all
@@ -137,6 +150,18 @@ __all__ = [
137
150
  'DIMENSIONLESS',
138
151
  'DimensionMismatchError',
139
152
  'UnitMismatchError',
153
+ 'BackendError',
154
+
155
+ # _backend
156
+ 'get_default_backend',
157
+ 'set_default_backend',
158
+ 'using_backend',
159
+ 'is_jax_array',
160
+ 'is_numpy_array',
161
+ 'is_cupy_array',
162
+ 'is_torch_array',
163
+ 'is_dask_array',
164
+ 'is_ndonnx_array',
140
165
  'get_or_create_dimension',
141
166
  'get_dim_for_display',
142
167
 
@@ -162,6 +187,7 @@ __all__ = [
162
187
  'assert_quantity',
163
188
  'have_same_dim',
164
189
  'has_same_unit',
190
+ 'is_unit_equal_math',
165
191
  'unit_scale_align_to_first',
166
192
  'array_with_unit',
167
193
 
@@ -177,6 +203,11 @@ __all__ = [
177
203
  # _celsius
178
204
  'celsius2kelvin',
179
205
  'kelvin2celsius',
206
+ 'fahrenheit2kelvin',
207
+ 'kelvin2fahrenheit',
208
+
209
+ # _matplotlib_compat
210
+ 'enable_matplotlib_support',
180
211
 
181
212
  # old version compatibility
182
213
  'avogadro_constant',
@@ -189,4 +220,36 @@ __all__ = [
189
220
  'magnetic_constant',
190
221
  'molar_mass_constant',
191
222
  ] + _common_all + _std_units_all + _constants_all
192
- del _common_all, _std_units_all, _matplotlib_compat, _constants_all
223
+ del _common_all, _std_units_all, _constants_all
224
+
225
+ # ---------------------------------------------------------------------------
226
+ # Lazy submodule loading for JAX-only features.
227
+ #
228
+ # ``autograd``, ``lax`` and ``sparse`` use JAX primitives that have no NumPy
229
+ # equivalent. Importing them eagerly would force a hard JAX dependency on
230
+ # every ``import saiunit`` call. Instead, expose them through a module-level
231
+ # ``__getattr__``: the first attribute access triggers the real import and,
232
+ # if JAX is missing, raises :class:`~saiunit._exceptions.BackendError` with
233
+ # an actionable install hint.
234
+ # ---------------------------------------------------------------------------
235
+
236
+ _JAX_ONLY_SUBMODULES = ("autograd", "lax", "sparse")
237
+ __all__ = __all__ + ["autograd", "lax", "sparse"]
238
+
239
+
240
+ def __getattr__(name):
241
+ if name in _JAX_ONLY_SUBMODULES:
242
+ if not _HAS_JAX:
243
+ from ._exceptions import BackendError
244
+ raise BackendError(
245
+ f"saiunit.{name} requires JAX. Install with: pip install saiunit[jax]"
246
+ )
247
+ import importlib
248
+ mod = importlib.import_module(f"saiunit.{name}")
249
+ globals()[name] = mod
250
+ return mod
251
+ raise AttributeError(f"module 'saiunit' has no attribute {name!r}")
252
+
253
+
254
+ def __dir__():
255
+ return sorted(set(__all__) | set(globals()))