openscvx 0.4.1.dev88__tar.gz → 0.4.1.dev90__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (307) hide show
  1. {openscvx-0.4.1.dev88/openscvx.egg-info → openscvx-0.4.1.dev90}/PKG-INFO +1 -1
  2. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/_version.py +3 -3
  3. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/discretization/linearize_discretize.py +27 -40
  4. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/expert/lowering.py +39 -17
  5. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/propagation/propagation.py +8 -12
  6. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/augmentation.py +1 -1
  7. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/builder.py +13 -1
  8. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/expr/expr.py +2 -35
  9. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lower.py +1 -16
  10. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/utils/caching.py +0 -3
  11. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90/openscvx.egg-info}/PKG-INFO +1 -1
  12. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/test_augmentation.py +6 -0
  13. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/test_discretization.py +23 -27
  14. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/test_propagation.py +14 -20
  15. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/.github/assets/logo.svg +0 -0
  16. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/.github/release-drafter.yml +0 -0
  17. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/.github/workflows/_docs.yml +0 -0
  18. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/.github/workflows/branch-name.yml +0 -0
  19. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/.github/workflows/docs.yml +0 -0
  20. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/.github/workflows/lint.yml +0 -0
  21. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/.github/workflows/nightly.yml +0 -0
  22. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/.github/workflows/release-drafter.yml +0 -0
  23. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/.github/workflows/release.yml +0 -0
  24. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/.github/workflows/tests-integration.yml +0 -0
  25. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/.github/workflows/tests-unit.yml +0 -0
  26. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/.gitignore +0 -0
  27. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/CONTRIBUTING.md +0 -0
  28. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/LICENSE +0 -0
  29. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/README.md +0 -0
  30. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/Foundations/constraint_reformulation.md +0 -0
  31. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/Foundations/control_parameterization.md +0 -0
  32. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/Foundations/discretization.md +0 -0
  33. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/Foundations/ocp.md +0 -0
  34. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/Foundations/scvx.md +0 -0
  35. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/Foundations/time_dilation.md +0 -0
  36. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/UnderTheHood/lowering_architecture.md +0 -0
  37. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/UnderTheHood/vectorization_and_vmapping.md +0 -0
  38. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/UsersGuide/00_introduction.md +0 -0
  39. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/UsersGuide/01_hello_world_brachistochrone.md +0 -0
  40. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/UsersGuide/02_drone_racing_constraints.md +0 -0
  41. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/UsersGuide/03_obstacle_avoidance_vmap.md +0 -0
  42. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/UsersGuide/04_viewpoint_constraints.md +0 -0
  43. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/UsersGuide/05_visualization.md +0 -0
  44. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/UsersGuide/06_logic.md +0 -0
  45. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/UsersGuide/07_lie.md +0 -0
  46. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/assets/favicon.png +0 -0
  47. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/assets/images/ct-scvx_dark.png +0 -0
  48. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/assets/images/ct-scvx_light.png +0 -0
  49. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/assets/images/ctcs_dark.png +0 -0
  50. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/assets/images/ctcs_light.png +0 -0
  51. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/assets/images/problem_class_dark.png +0 -0
  52. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/assets/images/problem_class_light.png +0 -0
  53. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/assets/logo.svg +0 -0
  54. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/citation.md +0 -0
  55. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/examples.md +0 -0
  56. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/getting-started.md +0 -0
  57. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/index.md +0 -0
  58. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/docs/javascripts/mathjax.js +0 -0
  59. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/abstract/brachistochrone.py +0 -0
  60. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/arm/three_link_arm.py +0 -0
  61. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/car/dubins_car.py +0 -0
  62. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/car/dubins_car_conditional.py +0 -0
  63. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/car/dubins_car_disjoint.py +0 -0
  64. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/car/dubins_car_stljax.py +0 -0
  65. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/drone/cinema_vp.py +0 -0
  66. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/drone/dr_double_integrator.py +0 -0
  67. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/drone/dr_vp.py +0 -0
  68. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/drone/dr_vp_nodal.py +0 -0
  69. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/drone/dr_vp_polytope.py +0 -0
  70. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/drone/drone_racing.py +0 -0
  71. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/drone/logo.py +0 -0
  72. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/drone/logo_utils/acl_logo.svg +0 -0
  73. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/drone/logo_utils/svg_path_utils.py +0 -0
  74. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/drone/obstacle_avoidance.py +0 -0
  75. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/drone/obstacle_avoidance_nodal.py +0 -0
  76. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/drone/obstacle_avoidance_vmap.py +0 -0
  77. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/plotting.py +0 -0
  78. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/plotting_viser.py +0 -0
  79. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/realtime/base_problems/cinema_vp_realtime_base.py +0 -0
  80. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/realtime/base_problems/drone_racing_realtime_base.py +0 -0
  81. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/realtime/base_problems/obstacle_avoidance_realtime_base.py +0 -0
  82. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/realtime/cinema_vp_realtime.py +0 -0
  83. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/realtime/drone_racing_realtime.py +0 -0
  84. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/realtime/dubins_car_realtime.py +0 -0
  85. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/realtime/obstacle_avoidance_realtime.py +0 -0
  86. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/rocket/3DoF_pdg.py +0 -0
  87. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/examples/spacecraft/proxops_cw.py +0 -0
  88. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/figures/ctlos_cine.gif +0 -0
  89. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/figures/ctlos_dr.gif +0 -0
  90. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/figures/dtlos_cine.gif +0 -0
  91. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/figures/dtlos_dr.gif +0 -0
  92. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/figures/openscvx_logo.svg +0 -0
  93. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/figures/openscvx_logo_square.png +0 -0
  94. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/figures/oscvx_structure_full_dark.svg +0 -0
  95. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/figures/video_preview.png +0 -0
  96. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/__init__.py +0 -0
  97. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/1-background.avif +0 -0
  98. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/1-background@1x.avif +0 -0
  99. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/1-background@2x.avif +0 -0
  100. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/1-background@3x.avif +0 -0
  101. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/1-background@4x.avif +0 -0
  102. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/2-mars.avif +0 -0
  103. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/2-mars@1x.avif +0 -0
  104. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/2-mars@2x.avif +0 -0
  105. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/2-mars@3x.avif +0 -0
  106. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/2-mars@4x.avif +0 -0
  107. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/3-moon.avif +0 -0
  108. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/3-moon@1x.avif +0 -0
  109. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/3-moon@2x.avif +0 -0
  110. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/3-moon@3x.avif +0 -0
  111. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/3-moon@4x.avif +0 -0
  112. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/4-sat1.avif +0 -0
  113. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/4-sat1@1x.avif +0 -0
  114. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/4-sat1@2x.avif +0 -0
  115. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/4-sat1@3x.avif +0 -0
  116. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/4-sat1@4x.avif +0 -0
  117. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/5-space.avif +0 -0
  118. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/5-space@1x.avif +0 -0
  119. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/5-space@2x.avif +0 -0
  120. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/5-space@3x.avif +0 -0
  121. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/5-space@4x.avif +0 -0
  122. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/6-earth.avif +0 -0
  123. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/6-earth@1x.avif +0 -0
  124. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/6-earth@2x.avif +0 -0
  125. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/6-earth@3x.avif +0 -0
  126. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/images/layers/6-earth@4x.avif +0 -0
  127. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/javascripts/parallax.js +0 -0
  128. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/logo.svg +0 -0
  129. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/stylesheets/custom.css +0 -0
  130. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/assets/stylesheets/parallax.css +0 -0
  131. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/home.html +0 -0
  132. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/main.html +0 -0
  133. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/partials/parallax/hero.html +0 -0
  134. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/material/overrides/partials/parallax.html +0 -0
  135. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/mkdocs.yml +0 -0
  136. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/__init__.py +0 -0
  137. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/__main__.py +0 -0
  138. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/algorithms/AugmentedLagrangian.py +0 -0
  139. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/algorithms/ConstantProximalWeight.py +0 -0
  140. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/algorithms/RampProximalWeight.py +0 -0
  141. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/algorithms/__init__.py +0 -0
  142. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/algorithms/base.py +0 -0
  143. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/algorithms/optimization_results.py +0 -0
  144. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/algorithms/penalized_trust_region.py +0 -0
  145. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/config.py +0 -0
  146. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/discretization/__init__.py +0 -0
  147. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/discretization/base.py +0 -0
  148. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/expert/__init__.py +0 -0
  149. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/expert/byof.py +0 -0
  150. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/expert/validation.py +0 -0
  151. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/init/__init__.py +0 -0
  152. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/init/interpolation.py +0 -0
  153. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/integrators/__init__.py +0 -0
  154. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/integrators/runge_kutta.py +0 -0
  155. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/loader.py +0 -0
  156. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/lowered/__init__.py +0 -0
  157. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/lowered/cvxpy_constraints.py +0 -0
  158. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/lowered/cvxpy_variables.py +0 -0
  159. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/lowered/dynamics.py +0 -0
  160. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/lowered/jax_constraints.py +0 -0
  161. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/lowered/parameters.py +0 -0
  162. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/lowered/problem.py +0 -0
  163. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/lowered/unified.py +0 -0
  164. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/plotting/__init__.py +0 -0
  165. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/plotting/plotting.py +0 -0
  166. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/plotting/scp_iteration.py +0 -0
  167. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/plotting/viser/__init__.py +0 -0
  168. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/plotting/viser/animated.py +0 -0
  169. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/plotting/viser/plotly_integration.py +0 -0
  170. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/plotting/viser/primitives.py +0 -0
  171. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/plotting/viser/scp.py +0 -0
  172. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/plotting/viser/server.py +0 -0
  173. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/problem.py +0 -0
  174. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/propagation/__init__.py +0 -0
  175. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/propagation/post_processing.py +0 -0
  176. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/solvers/__init__.py +0 -0
  177. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/solvers/base.py +0 -0
  178. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/solvers/ptr_solver.py +0 -0
  179. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/__init__.py +0 -0
  180. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/constraint_set.py +0 -0
  181. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/expr/__init__.py +0 -0
  182. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/expr/arithmetic.py +0 -0
  183. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/expr/array.py +0 -0
  184. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/expr/constraint.py +0 -0
  185. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/expr/control.py +0 -0
  186. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/expr/lie/__init__.py +0 -0
  187. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/expr/lie/adjoint.py +0 -0
  188. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/expr/lie/se3.py +0 -0
  189. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/expr/lie/so3.py +0 -0
  190. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/expr/linalg.py +0 -0
  191. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/expr/logic.py +0 -0
  192. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/expr/math.py +0 -0
  193. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/expr/spatial.py +0 -0
  194. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/expr/state.py +0 -0
  195. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/expr/stl.py +0 -0
  196. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/expr/variable.py +0 -0
  197. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/expr/vmap.py +0 -0
  198. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/hashing.py +0 -0
  199. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/__init__.py +0 -0
  200. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/cvxpy/__init__.py +0 -0
  201. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/cvxpy/_lowerer.py +0 -0
  202. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/cvxpy/_registry.py +0 -0
  203. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/cvxpy/arithmetic.py +0 -0
  204. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/cvxpy/array.py +0 -0
  205. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/cvxpy/constraint.py +0 -0
  206. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/cvxpy/control.py +0 -0
  207. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/cvxpy/expr.py +0 -0
  208. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/cvxpy/linalg.py +0 -0
  209. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/cvxpy/logic.py +0 -0
  210. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/cvxpy/math.py +0 -0
  211. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/cvxpy/state.py +0 -0
  212. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/jax/__init__.py +0 -0
  213. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/jax/_lowerer.py +0 -0
  214. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/jax/_registry.py +0 -0
  215. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/jax/arithmetic.py +0 -0
  216. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/jax/array.py +0 -0
  217. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/jax/constraint.py +0 -0
  218. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/jax/control.py +0 -0
  219. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/jax/expr.py +0 -0
  220. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/jax/lie.py +0 -0
  221. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/jax/linalg.py +0 -0
  222. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/jax/logic.py +0 -0
  223. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/jax/math.py +0 -0
  224. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/jax/spatial.py +0 -0
  225. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/jax/state.py +0 -0
  226. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/jax/stl.py +0 -0
  227. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/lowerers/jax/vmap.py +0 -0
  228. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/parser/__init__.py +0 -0
  229. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/parser/_registry.py +0 -0
  230. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/parser/array.py +0 -0
  231. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/parser/constraint.py +0 -0
  232. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/parser/lie.py +0 -0
  233. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/parser/linalg.py +0 -0
  234. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/parser/logic.py +0 -0
  235. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/parser/math.py +0 -0
  236. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/parser/parser.py +0 -0
  237. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/parser/spatial.py +0 -0
  238. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/parser/stl.py +0 -0
  239. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/parser/tokenizer.py +0 -0
  240. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/preprocessing.py +0 -0
  241. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/problem.py +0 -0
  242. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/sparsity.py +0 -0
  243. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/time.py +0 -0
  244. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/symbolic/unified.py +0 -0
  245. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/utils/__init__.py +0 -0
  246. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/utils/cache.py +0 -0
  247. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/utils/printing.py +0 -0
  248. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/utils/profiling.py +0 -0
  249. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx/utils/utils.py +0 -0
  250. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx.egg-info/SOURCES.txt +0 -0
  251. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx.egg-info/dependency_links.txt +0 -0
  252. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx.egg-info/entry_points.txt +0 -0
  253. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx.egg-info/requires.txt +0 -0
  254. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/openscvx.egg-info/top_level.txt +0 -0
  255. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/pyproject.toml +0 -0
  256. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/scripts/gen_example_pages.py +0 -0
  257. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/scripts/gen_ref_pages.py +0 -0
  258. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/setup.cfg +0 -0
  259. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/__init__.py +0 -0
  260. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/brachistochrone_analytical.py +0 -0
  261. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/fixtures/brachistochrone.json +0 -0
  262. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/fixtures/brachistochrone.yaml +0 -0
  263. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/__init__.py +0 -0
  264. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/expr/__init__.py +0 -0
  265. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/expr/test_arithmetic.py +0 -0
  266. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/expr/test_array.py +0 -0
  267. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/expr/test_constraint.py +0 -0
  268. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/expr/test_expr.py +0 -0
  269. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/expr/test_lie.py +0 -0
  270. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/expr/test_linalg.py +0 -0
  271. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/expr/test_logic.py +0 -0
  272. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/expr/test_math.py +0 -0
  273. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/expr/test_node_reference.py +0 -0
  274. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/expr/test_parameters.py +0 -0
  275. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/expr/test_scaling.py +0 -0
  276. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/expr/test_spatial.py +0 -0
  277. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/expr/test_variable.py +0 -0
  278. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/expr/test_vmap.py +0 -0
  279. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/parser/__init__.py +0 -0
  280. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/parser/test_array.py +0 -0
  281. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/parser/test_constraint.py +0 -0
  282. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/parser/test_lie.py +0 -0
  283. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/parser/test_linalg.py +0 -0
  284. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/parser/test_load.py +0 -0
  285. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/parser/test_logic.py +0 -0
  286. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/parser/test_math.py +0 -0
  287. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/parser/test_parser.py +0 -0
  288. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/parser/test_spatial.py +0 -0
  289. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/parser/test_stl.py +0 -0
  290. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/parser/test_tokenizer.py +0 -0
  291. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/parser/test_vmap.py +0 -0
  292. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/test_hashing.py +0 -0
  293. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/test_lower_cvxpy.py +0 -0
  294. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/test_lower_jax.py +0 -0
  295. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/test_preprocessing.py +0 -0
  296. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/test_sparsity.py +0 -0
  297. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/symbolic/test_unified.py +0 -0
  298. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/test_autotuning.py +0 -0
  299. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/test_brachistochrone.py +0 -0
  300. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/test_cvxpygen_optional.py +0 -0
  301. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/test_examples.py +0 -0
  302. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/test_expert.py +0 -0
  303. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/test_init.py +0 -0
  304. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/test_integrators.py +0 -0
  305. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/test_loader.py +0 -0
  306. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/test_optimization_results.py +0 -0
  307. {openscvx-0.4.1.dev88 → openscvx-0.4.1.dev90}/tests/test_plotting.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: openscvx
3
- Version: 0.4.1.dev88
3
+ Version: 0.4.1.dev90
4
4
  Summary: A general Python-based successive convexification implementation which uses a JAX backend.
5
5
  Author-email: Chris Hayner and Griffin Norris <haynec@uw.edu>
6
6
  License: Apache Software License
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.4.1.dev88'
32
- __version_tuple__ = version_tuple = (0, 4, 1, 'dev88')
31
+ __version__ = version = '0.4.1.dev90'
32
+ __version_tuple__ = version_tuple = (0, 4, 1, 'dev90')
33
33
 
34
- __commit_id__ = commit_id = 'g6d7b113ab'
34
+ __commit_id__ = commit_id = 'g7eea8c25b'
@@ -110,25 +110,26 @@ def _dVdt(
110
110
 
111
111
  Their derivatives follow from the variational equations:
112
112
 
113
- - ``dx/dτ = s · f(x, u)``
114
- - ``dΦ/dτ = s · A(x, u) · Φ``
115
- - ``dB_d/dτ = s · A(x, u) · B_d + α · s · B(x, u)``
116
- - ``dC_d/dτ = s · A(x, u) · C_d + β · s · B(x, u)``
113
+ - ``dx/dτ = F(x, u)`` where F = s · f(x, u) is the time-dilated dynamics
114
+ - ``dΦ/dτ = A(x, u) · Φ``
115
+ - ``dB_d/dτ = A(x, u) · B_d + α · B(x, u)``
116
+ - ``dC_d/dτ = A(x, u) · C_d + β · B(x, u)``
117
117
 
118
- where ``s`` is the time-dilation factor (last column of ``u``),
119
- ``A = ∂f/∂x``, ``B = ∂f/∂u``, and ``α, β`` are interpolation weights
120
- determined by the hold type (ZOH: α=1, β=0; FOH: linear blend).
118
+ where ``A = ∂F/∂x`` and ``B = ∂F/∂u`` are Jacobians of the
119
+ time-dilated dynamics (which include the time-dilation factor ``s``
120
+ symbolically), and ``α, β`` are interpolation weights determined by the
121
+ hold type (ZOH: α=1, β=0; FOH: linear blend).
121
122
 
122
123
  Args:
123
124
  tau: Normalized time in [0, 1] within the current segment.
124
125
  V: Flattened augmented state vector, shape ``((N-1) * aug_dim,)``.
125
- u_cur: Control at current node, shape ``(N-1, n_u+1)``.
126
- u_next: Control at next node, shape ``(N-1, n_u+1)``.
127
- state_dot: Vmapped dynamics ``f(x, u, node, params) -> x_dot``.
128
- A: Vmapped state Jacobian ``∂f/∂x(x, u, node, params)``.
129
- B: Vmapped control Jacobian ``∂f/∂u(x, u, node, params)``.
126
+ u_cur: Control at current node, shape ``(N-1, n_u)``.
127
+ u_next: Control at next node, shape ``(N-1, n_u)``.
128
+ state_dot: Vmapped time-dilated dynamics ``F(x, u, node, params) -> x_dot``.
129
+ A: Vmapped state Jacobian ``∂F/∂x(x, u, node, params)``.
130
+ B: Vmapped control Jacobian ``∂F/∂u(x, u, node, params)``.
130
131
  n_x: Number of states.
131
- n_u: Number of controls (excluding time-dilation slack).
132
+ n_u: Number of controls (including time-dilation).
132
133
  N: Number of trajectory nodes.
133
134
  dis_type: ``"ZOH"`` (zero-order hold) or ``"FOH"`` (first-order hold).
134
135
  S_x: State scaling matrix (unused, reserved for future scaling).
@@ -165,46 +166,32 @@ def _dVdt(
165
166
  beta = (tau) * N
166
167
  alpha = 1 - beta
167
168
 
168
- # TODO: (norrisg) integrate the multiplication with `s` into the symbolic layer
169
- # This currently requires a hack to get the sparsity pattern to include the relation with s
170
-
171
169
  # Interpolate the control input
172
170
  u = u_cur + beta * (u_next - u_cur)
173
- s = u[:, -1]
174
-
175
- # Initialize the augmented Jacobians
176
- dfdx = jnp.zeros((V.shape[0], n_x, n_x))
177
- dfdu = jnp.zeros((V.shape[0], n_x, n_u))
178
171
 
179
172
  # Ensure x_seq and u have the same batch size
180
173
  x = V[:, :n_x]
181
174
  u = u[: x.shape[0]]
182
175
 
183
- # Compute the nonlinear propagation term
184
- f = state_dot(x, u[:, :-1], nodes, params)
185
- F = s[:, None] * f
186
-
187
- # Evaluate the State Jacobian
188
- dfdx = A(x, u[:, :-1], nodes, params)
189
- sdfdx = s[:, None, None] * dfdx
176
+ # Compute the time-dilated dynamics (s * f already included symbolically)
177
+ F = state_dot(x, u, nodes, params)
190
178
 
191
- # Evaluate the Control Jacobian
192
- dfdu_veh = B(x, u[:, :-1], nodes, params)
193
- dfdu = dfdu.at[:, :, :-1].set(s[:, None, None] * dfdu_veh)
194
- dfdu = dfdu.at[:, :, -1].set(f)
179
+ # Evaluate the Jacobians (already include time-dilation derivatives via autodiff)
180
+ dfdx = A(x, u, nodes, params)
181
+ dfdu = B(x, u, nodes, params)
195
182
 
196
183
  # Stack up the results into the augmented state vector
197
184
  # fmt: off
198
185
  dVdt = jnp.zeros_like(V)
199
186
  dVdt = dVdt.at[:, i0:i1].set(F)
200
187
  dVdt = dVdt.at[:, i1:i2].set(
201
- jnp.matmul(sdfdx, V[:, i1:i2].reshape(-1, n_x, n_x)).reshape(-1, n_x * n_x)
188
+ jnp.matmul(dfdx, V[:, i1:i2].reshape(-1, n_x, n_x)).reshape(-1, n_x * n_x)
202
189
  )
203
190
  dVdt = dVdt.at[:, i2:i3].set(
204
- (jnp.matmul(sdfdx, V[:, i2:i3].reshape(-1, n_x, n_u)) + dfdu * alpha).reshape(-1, n_x * n_u)
191
+ (jnp.matmul(dfdx, V[:, i2:i3].reshape(-1, n_x, n_u)) + dfdu * alpha).reshape(-1, n_x * n_u)
205
192
  )
206
193
  dVdt = dVdt.at[:, i3:i4].set(
207
- (jnp.matmul(sdfdx, V[:, i3:i4].reshape(-1, n_x, n_u)) + dfdu * beta).reshape(-1, n_x * n_u)
194
+ (jnp.matmul(dfdx, V[:, i3:i4].reshape(-1, n_x, n_u)) + dfdu * beta).reshape(-1, n_x * n_u)
208
195
  )
209
196
  # fmt: on
210
197
 
@@ -231,11 +218,11 @@ def _calculate_discretization(
231
218
 
232
219
  Args:
233
220
  x: Reference state trajectory, shape ``(N, n_x)``.
234
- u: Reference control trajectory, shape ``(N, n_u+1)`` (includes
235
- time-dilation slack as last column).
236
- state_dot: Vmapped dynamics ``f(x, u, node, params) -> x_dot``.
237
- A: Vmapped state Jacobian ``∂f/∂x``.
238
- B: Vmapped control Jacobian ``∂f/∂u``.
221
+ u: Reference control trajectory, shape ``(N, n_u)`` (includes
222
+ time-dilation as part of the unified control vector).
223
+ state_dot: Vmapped time-dilated dynamics ``F(x, u, node, params) -> x_dot``.
224
+ A: Vmapped state Jacobian ``∂F/∂x``.
225
+ B: Vmapped control Jacobian ``∂F/∂u``.
239
226
  settings: Configuration (integrator choice, tolerances, hold type, etc.).
240
227
  params: Parameters forwarded to ``state_dot``, ``A``, and ``B``.
241
228
 
@@ -78,17 +78,23 @@ def apply_byof(
78
78
  state_slices = {state.name: state._slice for state in states}
79
79
  state_slices_prop = {state.name: state._slice for state in states_prop}
80
80
 
81
- def _make_composite_dynamics(orig_f, byof_fns, slices_map):
81
+ # Time-dilation slice for multiplying byof outputs by s
82
+ td_slice = u_unified.time_dilation_slice
83
+
84
+ def _make_composite_dynamics(orig_f, byof_fns, slices_map, td_sl):
82
85
  """Create composite dynamics combining symbolic and byof state derivatives.
83
86
 
84
87
  This factory splices user-provided byof dynamics into the unified dynamics
85
88
  function at the appropriate slice indices, replacing the symbolic dynamics
86
- for specific states while preserving the rest.
89
+ for specific states while preserving the rest. The byof outputs are
90
+ multiplied by the time-dilation factor s to match the symbolic dynamics
91
+ which already include s * f(x, u) via the Mul node.
87
92
 
88
93
  Args:
89
94
  orig_f: Original unified dynamics (x, u, node, params) -> xdot
90
95
  byof_fns: Dict mapping state names to byof dynamics functions
91
96
  slices_map: Dict mapping state names to slice objects for indexing
97
+ td_sl: Slice for the time-dilation control in the unified u vector
92
98
 
93
99
  Returns:
94
100
  Composite dynamics function with byof derivatives spliced in
@@ -98,11 +104,14 @@ def apply_byof(
98
104
  # Start with symbolic/default dynamics for all states
99
105
  xdot = orig_f(x, u, node, params)
100
106
 
101
- # Splice in byof dynamics for specific states
107
+ # Time-dilation factor (symbolic dynamics already include s *)
108
+ s = u[td_sl]
109
+
110
+ # Splice in byof dynamics for specific states, multiplied by s
102
111
  for state_name, byof_fn in byof_fns.items():
103
112
  sl = slices_map[state_name]
104
- # Replace the derivative for this state with the byof result
105
- xdot = xdot.at[sl].set(byof_fn(x, u, node, params))
113
+ # Replace the derivative for this state with s * byof result
114
+ xdot = xdot.at[sl].set(s * byof_fn(x, u, node, params))
106
115
 
107
116
  return xdot
108
117
 
@@ -110,12 +119,12 @@ def apply_byof(
110
119
 
111
120
  # Create composite optimization dynamics
112
121
  # Jacobians are computed by the discretizer, not here.
113
- composite_f = _make_composite_dynamics(dynamics.f, byof_dynamics, state_slices)
122
+ composite_f = _make_composite_dynamics(dynamics.f, byof_dynamics, state_slices, td_slice)
114
123
  dynamics = Dynamics(f=composite_f)
115
124
 
116
125
  # Create composite propagation dynamics
117
126
  composite_f_prop = _make_composite_dynamics(
118
- dynamics_prop.f, byof_dynamics, state_slices_prop
127
+ dynamics_prop.f, byof_dynamics, state_slices_prop, td_slice
119
128
  )
120
129
  dynamics_prop = Dynamics(f=composite_f_prop)
121
130
 
@@ -273,17 +282,24 @@ def apply_byof(
273
282
 
274
283
  penalty_fns.append(_make_penalty_fn(constraint_fn, penalty_func, over_interval))
275
284
 
285
+ # Time-dilation slice for multiplying byof CTCS penalties by s
286
+ td_slice = u_unified.time_dilation_slice
287
+
276
288
  if idx in idx_to_aug_slice:
277
289
  # This idx already exists from symbolic CTCS - add penalties to existing state
278
290
  aug_slice = idx_to_aug_slice[idx]
279
291
 
280
- def _make_ctcs_addition(orig_f, pen_fns, aug_sl):
292
+ def _make_ctcs_addition(orig_f, pen_fns, aug_sl, td_sl):
281
293
  """Create dynamics that adds penalties to existing augmented state.
282
294
 
295
+ The penalty is multiplied by the time-dilation factor s to match
296
+ the symbolic dynamics which already include s * f(x, u).
297
+
283
298
  Args:
284
299
  orig_f: Original dynamics function
285
300
  pen_fns: List of penalty functions to add
286
301
  aug_sl: Slice of the augmented state to modify
302
+ td_sl: Slice for the time-dilation control
287
303
 
288
304
  Returns:
289
305
  Modified dynamics function
@@ -292,8 +308,9 @@ def apply_byof(
292
308
  def modified_f(x, u, node, params):
293
309
  xdot = orig_f(x, u, node, params)
294
310
 
295
- # Sum all penalties for this idx
296
- total_penalty = sum(pen_fn(x, u, node, params) for pen_fn in pen_fns)
311
+ # Sum all penalties for this idx, scaled by time-dilation
312
+ s = u[td_sl]
313
+ total_penalty = s * sum(pen_fn(x, u, node, params) for pen_fn in pen_fns)
297
314
 
298
315
  # Add to existing augmented state derivative
299
316
  current_deriv = xdot[aug_sl]
@@ -305,8 +322,8 @@ def apply_byof(
305
322
 
306
323
  # Modify both optimization and propagation dynamics
307
324
  # Jacobians are computed by the discretizer, not here.
308
- dynamics.f = _make_ctcs_addition(dynamics.f, penalty_fns, aug_slice)
309
- dynamics_prop.f = _make_ctcs_addition(dynamics_prop.f, penalty_fns, aug_slice)
325
+ dynamics.f = _make_ctcs_addition(dynamics.f, penalty_fns, aug_slice, td_slice)
326
+ dynamics_prop.f = _make_ctcs_addition(dynamics_prop.f, penalty_fns, aug_slice, td_slice)
310
327
 
311
328
  else:
312
329
  # New idx - create new augmented state
@@ -315,12 +332,16 @@ def apply_byof(
315
332
  bounds = first_spec.get("bounds", (0.0, 1e-4))
316
333
  initial = first_spec.get("initial", bounds[0])
317
334
 
318
- def _make_ctcs_new_state(orig_f, pen_fns):
335
+ def _make_ctcs_new_state(orig_f, pen_fns, td_sl):
319
336
  """Create dynamics augmented with new CTCS state.
320
337
 
338
+ The penalty is multiplied by the time-dilation factor s to match
339
+ the symbolic dynamics which already include s * f(x, u).
340
+
321
341
  Args:
322
342
  orig_f: Original dynamics function
323
343
  pen_fns: List of penalty functions to sum
344
+ td_sl: Slice for the time-dilation control
324
345
 
325
346
  Returns:
326
347
  Augmented dynamics function
@@ -329,8 +350,9 @@ def apply_byof(
329
350
  def augmented_f(x, u, node, params):
330
351
  xdot = orig_f(x, u, node, params)
331
352
 
332
- # Sum all penalties for this new idx
333
- total_penalty = sum(pen_fn(x, u, node, params) for pen_fn in pen_fns)
353
+ # Sum all penalties for this new idx, scaled by time-dilation
354
+ s = u[td_sl]
355
+ total_penalty = s * sum(pen_fn(x, u, node, params) for pen_fn in pen_fns)
334
356
 
335
357
  # Append as new augmented state derivative
336
358
  return jnp.concatenate([xdot, jnp.atleast_1d(total_penalty)])
@@ -339,11 +361,11 @@ def apply_byof(
339
361
 
340
362
  # Augment optimization dynamics
341
363
  # Jacobians are computed by the discretizer, not here.
342
- aug_f = _make_ctcs_new_state(dynamics.f, penalty_fns)
364
+ aug_f = _make_ctcs_new_state(dynamics.f, penalty_fns, td_slice)
343
365
  dynamics = Dynamics(f=aug_f)
344
366
 
345
367
  # Augment propagation dynamics
346
- aug_f_prop = _make_ctcs_new_state(dynamics_prop.f, penalty_fns)
368
+ aug_f_prop = _make_ctcs_new_state(dynamics_prop.f, penalty_fns, td_slice)
347
369
  dynamics_prop = Dynamics(f=aug_f_prop)
348
370
 
349
371
  # Create State objects for the new augmented states
@@ -12,7 +12,6 @@ def prop_aug_dy(
12
12
  u_next: np.ndarray,
13
13
  tau_init: float,
14
14
  node: int,
15
- idx_s: int,
16
15
  state_dot: callable,
17
16
  dis_type: str,
18
17
  N: int,
@@ -20,8 +19,10 @@ def prop_aug_dy(
20
19
  ) -> np.ndarray:
21
20
  """Compute the augmented dynamics for propagation.
22
21
 
23
- This function computes the time-scaled dynamics for propagating the system state,
24
- taking into account the discretization type (ZOH or FOH) and time dilation.
22
+ This function computes the time-dilated dynamics for propagating the system
23
+ state, taking into account the discretization type (ZOH or FOH). The
24
+ time-dilation multiplication is already included in ``state_dot``
25
+ symbolically.
25
26
 
26
27
  Args:
27
28
  tau (float): Current normalized time in [0,1].
@@ -30,14 +31,13 @@ def prop_aug_dy(
30
31
  u_next (np.ndarray): Control input at next node.
31
32
  tau_init (float): Initial normalized time.
32
33
  node (int): Current node index.
33
- idx_s (int): Index of time dilation variable in control vector.
34
- state_dot (callable): Function computing state derivatives.
34
+ state_dot (callable): Function computing time-dilated state derivatives.
35
35
  dis_type (str): Discretization type ("ZOH" or "FOH").
36
36
  N (int): Number of nodes in trajectory.
37
37
  params: Dictionary of additional parameters passed to state_dot.
38
38
 
39
39
  Returns:
40
- np.ndarray: Time-scaled state derivatives.
40
+ np.ndarray: Time-dilated state derivatives.
41
41
  """
42
42
  x = x[None, :]
43
43
 
@@ -47,7 +47,7 @@ def prop_aug_dy(
47
47
  beta = (tau - tau_init) * N
48
48
  u = u_current + beta * (u_next - u_current)
49
49
 
50
- return u[:, idx_s] * state_dot(x, u[:, :-1], node, params).squeeze()
50
+ return state_dot(x, u, node, params).squeeze()
51
51
 
52
52
 
53
53
  def get_propagation_solver(state_dot: Dynamics, settings: Config) -> callable:
@@ -64,9 +64,7 @@ def get_propagation_solver(state_dot: Dynamics, settings: Config) -> callable:
64
64
  callable: A function that solves the propagation problem.
65
65
  """
66
66
 
67
- def propagation_solver(
68
- V0, tau_grid, u_cur, u_next, tau_init, node, idx_s, save_time, mask, params
69
- ):
67
+ def propagation_solver(V0, tau_grid, u_cur, u_next, tau_init, node, save_time, mask, params):
70
68
  param_map_update = params
71
69
  return solve_ivp_diffrax_prop(
72
70
  f=prop_aug_dy,
@@ -77,7 +75,6 @@ def get_propagation_solver(state_dot: Dynamics, settings: Config) -> callable:
77
75
  u_next, # shape (1, n_controls)
78
76
  tau_init, # shape (1, 1)
79
77
  node, # shape (1, 1)
80
- idx_s, # int
81
78
  state_dot, # function or array
82
79
  settings.dis.dis_type,
83
80
  settings.scp.n,
@@ -242,7 +239,6 @@ def simulate_nonlinear_time(
242
239
  controls_next,
243
240
  np.array([[tau[k]]]),
244
241
  np.array([[k]]),
245
- settings.sim.time_dilation_slice.stop,
246
242
  tau_cur_padded,
247
243
  mask_padded,
248
244
  params,
@@ -457,7 +457,7 @@ def augment_dynamics_with_ctcs(
457
457
  The augmented dynamics become:
458
458
  x_dot = f(x, u)
459
459
  aug_dot = penalty(g(x, u)) # For each constraint group
460
- time_dot = time_dilation
460
+ time_dot = 1.0
461
461
 
462
462
  Args:
463
463
  xdot: Original dynamics expression for states
@@ -341,7 +341,19 @@ def preprocess_symbolic_problem(
341
341
  parameters=parameters,
342
342
  )
343
343
 
344
- # ==================== PHASE 6: Process Algebraic Outputs ====================
344
+ # ==================== PHASE 6: Apply Time-Dilation Multiplication ===========
345
+ #
346
+ # Multiply all dynamics by the time-dilation control symbolically.
347
+ # This transforms f(x,u) -> s * f(x,u) so that JAX autodiff automatically
348
+ # computes the correct Jacobians including df/ds = f(x,u).
349
+ #
350
+ # This must happen AFTER propagation dynamics are assembled so that extra
351
+ # propagation states (e.g. distance) also get the s * multiplication.
352
+ time_dilation = next(c for c in controls_aug if c.name == "_time_dilation")
353
+ dynamics_aug = time_dilation * dynamics_aug
354
+ dynamics_prop = time_dilation * dynamics_prop
355
+
356
+ # ==================== PHASE 7: Process Algebraic Outputs ====================
345
357
 
346
358
  # Validate and canonicalize algebraic_prop expressions
347
359
  algebraic_prop_processed = None
@@ -572,7 +572,7 @@ def traverse(expr: Expr, visit: Callable[[Expr], None]):
572
572
  traverse(child, visit)
573
573
 
574
574
 
575
- class Constant(Expr):
575
+ class Constant(Leaf):
576
576
  """Constant value expression.
577
577
 
578
578
  Represents a constant numeric value in the expression tree. Constants are
@@ -601,40 +601,7 @@ class Constant(Expr):
601
601
  if not isinstance(value, np.ndarray):
602
602
  value = np.array(value, dtype=float)
603
603
  self.value = np.squeeze(value)
604
-
605
- def canonicalize(self) -> "Expr":
606
- """Constants are already in canonical form.
607
-
608
- Returns:
609
- Expr: Returns self since constants are already canonical
610
- """
611
- return self
612
-
613
- def check_shape(self) -> Tuple[int, ...]:
614
- """Return the shape of this constant's value.
615
-
616
- Returns:
617
- tuple: The shape of the constant's numpy array value
618
- """
619
- # Verify the invariant: constants should already be squeezed during construction
620
- original_shape = self.value.shape
621
- squeezed_shape = np.squeeze(self.value).shape
622
- if original_shape != squeezed_shape:
623
- raise ValueError(
624
- f"Constant not properly normalized: has shape {original_shape} "
625
- "but should have shape {squeezed_shape}. "
626
- "Constants should be squeezed during construction."
627
- )
628
- return self.value.shape
629
-
630
- def sparsity(self, n_x: int, n_u: int) -> Tuple[np.ndarray, np.ndarray]:
631
- """Constants have no decision-variable dependence."""
632
- shape = self.value.shape
633
- n_out = int(np.prod(shape)) if shape else 1
634
- return (
635
- np.zeros((n_out, n_x), dtype=bool),
636
- np.zeros((n_out, n_u), dtype=bool),
637
- )
604
+ super().__init__(name="__constant__", shape=self.value.shape)
638
605
 
639
606
  def _hash_into(self, hasher: "hashlib._Hash") -> None:
640
607
  """Hash constant by its value.
@@ -717,27 +717,12 @@ def lower_symbolic_problem(
717
717
  # patterns (the superset) since dis_type isn't known yet.
718
718
  dynamics_sparsity = None
719
719
  if byof is None and problem.dynamics is not None:
720
- from openscvx.symbolic.sparsity import _element_liveness, discrete_sparsity
720
+ from openscvx.symbolic.sparsity import discrete_sparsity
721
721
 
722
722
  n_x = sum(s.shape[0] for s in problem.states)
723
723
  n_u = sum(c.shape[0] for c in problem.controls)
724
724
  A_c, B_c = problem.dynamics.sparsity(n_x, n_u)
725
725
 
726
- # TODO: (norrisg) tidy up how time-dilation is multiplied onto the
727
- # dynamics so that it is automatically included into the sparsity.
728
- # Requires fixing stuff in the discretizer.
729
-
730
- # The discretization multiplies all dynamics by the time-dilation
731
- # factor s and adds a column df/d(sigma) = f(x, u) to the control
732
- # Jacobian (see linearize_discretize.py). The symbolic dynamics
733
- # don't reference sigma, so we patch B_c here: every row whose
734
- # dynamics can be nonzero depends on sigma.
735
- for ctrl in problem.controls:
736
- if ctrl.name == "_time_dilation" and ctrl._slice is not None:
737
- live = _element_liveness(problem.dynamics)
738
- B_c[live, ctrl._slice] = True
739
- break
740
-
741
726
  dynamics_sparsity = discrete_sparsity(A_c, B_c, dis_type="FOH")
742
727
 
743
728
  # Compute per-constraint Jacobian sparsity from the symbolic AST.
@@ -192,7 +192,6 @@ def load_or_compile_propagation_solver(
192
192
  np.ones((1, n_controls)), # controls_next
193
193
  np.ones((1, 1)), # tau_0
194
194
  np.ones((1, 1)).astype("int"), # segment index
195
- 0, # idx_s_stop
196
195
  np.ones((max_tau_len,)), # save_time (tau_cur_padded)
197
196
  np.ones((max_tau_len,), dtype=bool), # mask_padded (boolean mask)
198
197
  params, # additional parameters as dict
@@ -223,7 +222,6 @@ def prime_propagation_solver(
223
222
  controls_next = np.ones((1, settings.sim.u.shape[0]), dtype=settings.sim.u.guess.dtype)
224
223
  tau_init = np.array([[0.0]], dtype=np.float64)
225
224
  node = np.array([[0]], dtype=np.int64)
226
- idx_s_stop = settings.sim.time_dilation_slice.stop
227
225
  save_time = np.ones((settings.prp.max_tau_len,), dtype=np.float64)
228
226
  mask_padded = np.ones((settings.prp.max_tau_len,), dtype=bool)
229
227
  # Create dummy params dict with same structure
@@ -238,7 +236,6 @@ def prime_propagation_solver(
238
236
  controls_next,
239
237
  tau_init,
240
238
  node,
241
- idx_s_stop,
242
239
  save_time,
243
240
  mask_padded,
244
241
  dummy_params,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: openscvx
3
- Version: 0.4.1.dev88
3
+ Version: 0.4.1.dev90
4
4
  Summary: A general Python-based successive convexification implementation which uses a JAX backend.
5
5
  Author-email: Chris Hayner and Griffin Norris <haynec@uw.edu>
6
6
  License: Apache Software License
@@ -387,6 +387,9 @@ def test_augment_single_penalty_no_add():
387
387
  N,
388
388
  )
389
389
 
390
+ # Augmented dynamics should be Concat
391
+ assert isinstance(xdot_aug, Concat)
392
+
390
393
  # Single penalty should not be wrapped in Add
391
394
  # But it should be wrapped in CTCS
392
395
  ctcs_expr = xdot_aug.exprs[1]
@@ -418,6 +421,9 @@ def test_augment_multiple_penalties_create_add():
418
421
  N,
419
422
  )
420
423
 
424
+ # Augmented dynamics should be Concat
425
+ assert isinstance(xdot_aug, Concat)
426
+
421
427
  # Multiple penalties should be wrapped in Add
422
428
  penalty_expr = xdot_aug.exprs[1]
423
429
  assert isinstance(penalty_expr, Add)
@@ -19,7 +19,7 @@ def settings():
19
19
  p = Dummy()
20
20
  p.sim = Dummy()
21
21
  p.sim.n_states = 2
22
- p.sim.n_controls = 1
22
+ p.sim.n_controls = 2 # 1 vehicle control + 1 time-dilation (unified)
23
23
  p.sim.S_x = jnp.eye(p.sim.n_states)
24
24
  p.sim.c_x = jnp.zeros(p.sim.n_states)
25
25
  p.sim.S_u = jnp.eye(p.sim.n_controls)
@@ -41,20 +41,12 @@ def settings():
41
41
 
42
42
 
43
43
  def state_dot(x, u, node, params):
44
- # simple linear: x' = A_true x + B_true u
45
- return x + u
46
-
47
-
48
- def A(x, u, node, params):
49
- batch = x.shape[0]
50
- eye = jnp.eye(2)
51
- return jnp.broadcast_to(eye, (batch, 2, 2))
52
-
53
-
54
- def B(x, u, node, params):
55
- batch = x.shape[0]
56
- ones = jnp.ones((2, 1))
57
- return jnp.broadcast_to(ones, (batch, 2, 1))
44
+ # simple time-dilated dynamics: x' = s * (x + u_vehicle)
45
+ # u = [u_vehicle, s] includes both vehicle control and time-dilation
46
+ # This is the un-vmapped version (single sample, not batched)
47
+ s = u[1]
48
+ u_v = u[0]
49
+ return s * (x + u_v)
58
50
 
59
51
 
60
52
  @pytest.fixture
@@ -72,9 +64,9 @@ def test_discretization_shapes(settings, dynamics):
72
64
  discretizer = LinearizeDiscretize()
73
65
  solver = discretizer.get_solver(dynamics, settings)
74
66
 
75
- # dummy x,u
67
+ # dummy x,u (n_controls already includes time-dilation)
76
68
  x = jnp.ones((settings.scp.n, settings.sim.n_states))
77
- u = jnp.ones((settings.scp.n, settings.sim.n_controls + 1)) # +1 slack
69
+ u = jnp.ones((settings.scp.n, settings.sim.n_controls))
78
70
 
79
71
  A_bar, B_bar, C_bar, x_prop, Vmulti = solver(x, u, {})
80
72
 
@@ -85,19 +77,23 @@ def test_discretization_shapes(settings, dynamics):
85
77
  assert B_bar.shape == ((N - 1), n_x, n_u)
86
78
  assert C_bar.shape == ((N - 1), n_x, n_u)
87
79
  assert x_prop.shape == ((N - 1), n_x)
88
- # assert Vmulti.shape == (N, (n_x + n_x*n_x + 2*n_x*n_u + n_x))
89
80
 
90
81
 
91
82
  def test_jit_dVdt_compiles(settings):
92
- # prepare trivial inputs
83
+ # prepare trivial inputs (n_u already includes time-dilation)
93
84
  n_x, n_u = settings.sim.n_states, settings.sim.n_controls
94
85
  N = settings.scp.n
95
86
  aug_dim = n_x + n_x * n_x + 2 * n_x * n_u
96
87
 
97
88
  tau = jnp.array(0.3)
98
89
  V_flat = jnp.ones((N - 1) * aug_dim)
99
- u_cur = jnp.ones((N - 1, n_u + 1))
100
- u_next = jnp.ones((N - 1, n_u + 1))
90
+ u_cur = jnp.ones((N - 1, n_u))
91
+ u_next = jnp.ones((N - 1, n_u))
92
+
93
+ # Create vmapped versions of dynamics and Jacobians (as _dVdt expects)
94
+ f_vmapped = jax.vmap(state_dot, in_axes=(0, 0, 0, None))
95
+ A_vmapped = jax.vmap(jax.jacfwd(state_dot, argnums=0), in_axes=(0, 0, 0, None))
96
+ B_vmapped = jax.vmap(jax.jacfwd(state_dot, argnums=1), in_axes=(0, 0, 0, None))
101
97
 
102
98
  # bind out the Python callables & settings
103
99
  def wrapped(tau_, V_):
@@ -106,20 +102,20 @@ def test_jit_dVdt_compiles(settings):
106
102
  V_,
107
103
  u_cur,
108
104
  u_next,
109
- state_dot,
110
- A,
111
- B,
105
+ f_vmapped,
106
+ A_vmapped,
107
+ B_vmapped,
112
108
  n_x,
113
109
  n_u,
114
110
  N,
115
111
  settings.dis.dis_type,
116
- {},
117
112
  settings.sim.S_x,
118
113
  settings.sim.c_x,
119
114
  settings.sim.S_u,
120
115
  settings.sim.c_u,
121
116
  settings.sim.inv_S_x,
122
117
  settings.sim.inv_S_u,
118
+ {},
123
119
  )
124
120
 
125
121
  # now JIT only over (tau_, V_)
@@ -141,9 +137,9 @@ def test_jit_discretization_solver_compiles(settings, dynamics, integrator):
141
137
  discretizer = LinearizeDiscretize()
142
138
  solver = discretizer.get_solver(dynamics, settings)
143
139
 
144
- # dummy x,u (including slack column)
140
+ # dummy x,u (n_controls already includes time-dilation)
145
141
  x = jnp.ones((settings.scp.n, settings.sim.n_states))
146
- u = jnp.ones((settings.scp.n, settings.sim.n_controls + 1))
142
+ u = jnp.ones((settings.scp.n, settings.sim.n_controls))
147
143
 
148
144
  # jit & lower & compile
149
145
  jitted = jax.jit(solver)