warp-lang 1.0.0b2__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff shows the differences between the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (269)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.so +0 -0
  57. warp/bin/warp.so +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/fem/field/discrete_field.py +0 -80
  257. warp/fem/space/nodal_function_space.py +0 -233
  258. warp/tests/test_all.py +0 -223
  259. warp/tests/test_array_scan.py +0 -60
  260. warp/tests/test_base.py +0 -208
  261. warp/tests/test_unresolved_func.py +0 -7
  262. warp/tests/test_unresolved_symbol.py +0 -7
  263. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  264. warp_lang-1.0.0b2.dist-info/RECORD +0 -378
  265. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  266. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  267. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  268. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  269. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/utils.py CHANGED
@@ -6,11 +6,10 @@
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
8
  import cProfile
9
- import math
10
9
  import sys
11
10
  import timeit
12
11
  import warnings
13
- from typing import Any, Tuple, Union
12
+ from typing import Any
14
13
 
15
14
  import numpy as np
16
15
 
@@ -31,157 +30,9 @@ def warn(message, category=None, stacklevel=1):
31
30
  warnings.warn(message, category, stacklevel + 1) # Increment stacklevel by 1 since we are in a wrapper
32
31
 
33
32
 
34
- def length(a):
35
- return np.linalg.norm(a)
36
-
37
-
38
- def length_sq(a):
39
- return np.dot(a, a)
40
-
41
-
42
- def cross(a, b):
43
- return np.array((a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0]), dtype=np.float32)
44
-
45
-
46
- # NumPy has no normalize() method..
47
- def normalize(v):
48
- norm = np.linalg.norm(v)
49
- if norm == 0.0:
50
- return v
51
- return v / norm
52
-
53
-
54
- def skew(v):
55
- return np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
56
-
57
-
58
- # math utils
59
- # def quat(i, j, k, w):
60
- # return np.array([i, j, k, w])
61
-
62
-
63
- def quat_identity():
64
- return np.array((0.0, 0.0, 0.0, 1.0))
65
-
66
-
67
- def quat_inverse(q):
68
- return np.array((-q[0], -q[1], -q[2], q[3]))
69
-
70
-
71
- def quat_from_axis_angle(axis, angle):
72
- v = normalize(np.array(axis))
73
-
74
- half = angle * 0.5
75
- w = math.cos(half)
76
-
77
- sin_theta_over_two = math.sin(half)
78
- v *= sin_theta_over_two
79
-
80
- return np.array((v[0], v[1], v[2], w))
81
-
82
-
83
- def quat_to_axis_angle(quat):
84
- w2 = quat[3] * quat[3]
85
- if w2 > 1 - 1e-7:
86
- return np.zeros(3), 0.0
87
-
88
- angle = 2 * np.arccos(quat[3])
89
- xyz = quat[:3] / np.sqrt(1 - w2)
90
- return xyz, angle
91
-
92
-
93
- # quat_rotate a vector
94
- def quat_rotate(q, x):
95
- x = np.array(x)
96
- axis = np.array((q[0], q[1], q[2]))
97
- return x * (2.0 * q[3] * q[3] - 1.0) + np.cross(axis, x) * q[3] * 2.0 + axis * np.dot(axis, x) * 2.0
98
-
99
-
100
- # multiply two quats
101
- def quat_multiply(a, b):
102
- return np.array(
103
- (
104
- a[3] * b[0] + b[3] * a[0] + a[1] * b[2] - b[1] * a[2],
105
- a[3] * b[1] + b[3] * a[1] + a[2] * b[0] - b[2] * a[0],
106
- a[3] * b[2] + b[3] * a[2] + a[0] * b[1] - b[0] * a[1],
107
- a[3] * b[3] - a[0] * b[0] - a[1] * b[1] - a[2] * b[2],
108
- )
109
- )
110
-
111
-
112
- # convert to mat33
113
- def quat_to_matrix(q):
114
- c1 = quat_rotate(q, np.array((1.0, 0.0, 0.0)))
115
- c2 = quat_rotate(q, np.array((0.0, 1.0, 0.0)))
116
- c3 = quat_rotate(q, np.array((0.0, 0.0, 1.0)))
117
-
118
- return np.array([c1, c2, c3]).T
119
-
120
-
121
- def quat_rpy(roll, pitch, yaw):
122
- cy = math.cos(yaw * 0.5)
123
- sy = math.sin(yaw * 0.5)
124
- cr = math.cos(roll * 0.5)
125
- sr = math.sin(roll * 0.5)
126
- cp = math.cos(pitch * 0.5)
127
- sp = math.sin(pitch * 0.5)
128
-
129
- w = cy * cr * cp + sy * sr * sp
130
- x = cy * sr * cp - sy * cr * sp
131
- y = cy * cr * sp + sy * sr * cp
132
- z = sy * cr * cp - cy * sr * sp
133
-
134
- return (x, y, z, w)
135
-
136
-
137
- def quat_from_matrix(m):
138
- tr = m[0, 0] + m[1, 1] + m[2, 2]
139
- h = 0.0
140
-
141
- if tr >= 0.0:
142
- h = math.sqrt(tr + 1.0)
143
- w = 0.5 * h
144
- h = 0.5 / h
145
-
146
- x = (m[2, 1] - m[1, 2]) * h
147
- y = (m[0, 2] - m[2, 0]) * h
148
- z = (m[1, 0] - m[0, 1]) * h
149
-
150
- else:
151
- i = 0
152
- if m[1, 1] > m[0, 0]:
153
- i = 1
154
- if m[2, 2] > m[i, i]:
155
- i = 2
156
-
157
- if i == 0:
158
- h = math.sqrt((m[0, 0] - (m[1, 1] + m[2, 2])) + 1.0)
159
- x = 0.5 * h
160
- h = 0.5 / h
161
-
162
- y = (m[0, 1] + m[1, 0]) * h
163
- z = (m[2, 0] + m[0, 2]) * h
164
- w = (m[2, 1] - m[1, 2]) * h
165
-
166
- elif i == 1:
167
- h = math.sqrt((m[1, 1] - (m[2, 2] + m[0, 0])) + 1.0)
168
- y = 0.5 * h
169
- h = 0.5 / h
170
-
171
- z = (m[1, 2] + m[2, 1]) * h
172
- x = (m[0, 1] + m[1, 0]) * h
173
- w = (m[0, 2] - m[2, 0]) * h
174
-
175
- elif i == 2:
176
- h = math.sqrt((m[2, 2] - (m[0, 0] + m[1, 1])) + 1.0)
177
- z = 0.5 * h
178
- h = 0.5 / h
179
-
180
- x = (m[2, 0] + m[0, 2]) * h
181
- y = (m[1, 2] + m[2, 1]) * h
182
- w = (m[1, 0] - m[0, 1]) * h
183
-
184
- return normalize(np.array([x, y, z, w]))
33
+ # expand a 7-vec to a tuple of arrays
34
+ def transform_expand(t):
35
+ return wp.transform(np.array(t[0:3]), np.array(t[3:7]))
185
36
 
186
37
 
187
38
  @wp.func
@@ -197,210 +48,6 @@ def quat_between_vectors(a: wp.vec3, b: wp.vec3) -> wp.quat:
197
48
  return wp.normalize(q)
198
49
 
199
50
 
200
- # rigid body transform
201
-
202
-
203
- # def transform(x, r):
204
- # return (np.array(x), np.array(r))
205
-
206
-
207
- def transform_identity():
208
- return wp.transform(np.array((0.0, 0.0, 0.0)), quat_identity())
209
-
210
-
211
- # se(3) -> SE(3), Park & Lynch pg. 105, screw in [w, v] normalized form
212
- def transform_exp(s, angle):
213
- w = np.array(s[0:3])
214
- v = np.array(s[3:6])
215
-
216
- if length(w) < 1.0:
217
- r = quat_identity()
218
- else:
219
- r = quat_from_axis_angle(w, angle)
220
-
221
- t = v * angle + (1.0 - math.cos(angle)) * np.cross(w, v) + (angle - math.sin(angle)) * np.cross(w, np.cross(w, v))
222
-
223
- return (t, r)
224
-
225
-
226
- def transform_inverse(t):
227
- q_inv = quat_inverse(t.q)
228
- return wp.transform(-quat_rotate(q_inv, t.p), q_inv)
229
-
230
-
231
- def transform_vector(t, v):
232
- return quat_rotate(t.q, v)
233
-
234
-
235
- def transform_point(t, p):
236
- return np.array(t.p) + quat_rotate(t.q, p)
237
-
238
-
239
- def transform_multiply(t, u):
240
- return wp.transform(quat_rotate(t.q, u.p) + t.p, quat_multiply(t.q, u.q))
241
-
242
-
243
- # flatten an array of transforms (p,q) format to a 7-vector
244
- def transform_flatten(t):
245
- return np.array([*t.p, *t.q])
246
-
247
-
248
- # expand a 7-vec to a tuple of arrays
249
- def transform_expand(t):
250
- return wp.transform(np.array(t[0:3]), np.array(t[3:7]))
251
-
252
-
253
- # convert array of transforms to a array of 7-vecs
254
- def transform_flatten_list(xforms):
255
- exp = lambda t: transform_flatten(t)
256
- return list(map(exp, xforms))
257
-
258
-
259
- def transform_expand_list(xforms):
260
- exp = lambda t: transform_expand(t)
261
- return list(map(exp, xforms))
262
-
263
-
264
- def transform_inertia(m, I, p, q):
265
- """
266
- Transforms the inertia tensor described by the given mass and 3x3 inertia
267
- matrix to a new frame described by the given position and orientation.
268
- """
269
- R = quat_to_matrix(q)
270
-
271
- # Steiner's theorem
272
- return R @ I @ R.T + m * (np.dot(p, p) * np.eye(3) - np.outer(p, p))
273
-
274
-
275
- # spatial operators
276
-
277
-
278
- # AdT
279
- def spatial_adjoint(t):
280
- R = quat_to_matrix(t.q)
281
- w = skew(t.p)
282
-
283
- A = np.zeros((6, 6))
284
- A[0:3, 0:3] = R
285
- A[3:6, 0:3] = np.dot(w, R)
286
- A[3:6, 3:6] = R
287
-
288
- return A
289
-
290
-
291
- # (AdT)^-T
292
- def spatial_adjoint_dual(t):
293
- R = quat_to_matrix(t.q)
294
- w = skew(t.p)
295
-
296
- A = np.zeros((6, 6))
297
- A[0:3, 0:3] = R
298
- A[0:3, 3:6] = np.dot(w, R)
299
- A[3:6, 3:6] = R
300
-
301
- return A
302
-
303
-
304
- # AdT*s
305
- def transform_twist(t_ab, s_b):
306
- return np.dot(spatial_adjoint(t_ab), s_b)
307
-
308
-
309
- # AdT^{-T}*s
310
- def transform_wrench(t_ab, f_b):
311
- return np.dot(spatial_adjoint_dual(t_ab), f_b)
312
-
313
-
314
- # transform spatial inertia (6x6) in b frame to a frame
315
- def transform_spatial_inertia(t_ab, I_b):
316
- t_ba = transform_inverse(t_ab)
317
-
318
- # todo: write specialized method
319
- I_a = np.dot(np.dot(spatial_adjoint(t_ba).T, I_b), spatial_adjoint(t_ba))
320
- return I_a
321
-
322
-
323
- def translate_twist(p_ab, s_b):
324
- w = s_b[0:3]
325
- v = np.cross(p_ab, s_b[0:3]) + s_b[3:6]
326
-
327
- return np.array((*w, *v))
328
-
329
-
330
- def translate_wrench(p_ab, s_b):
331
- w = s_b[0:3] + np.cross(p_ab, s_b[3:6])
332
- v = s_b[3:6]
333
-
334
- return np.array((*w, *v))
335
-
336
-
337
- # def spatial_vector(v=(0.0, 0.0, 0.0, 0.0, 0.0, 0.0)):
338
- # return np.array(v)
339
-
340
-
341
- # ad_V pg. 289 L&P, pg. 25 Featherstone
342
- def spatial_cross(a, b):
343
- w = np.cross(a[0:3], b[0:3])
344
- v = np.cross(a[3:6], b[0:3]) + np.cross(a[0:3], b[3:6])
345
-
346
- return np.array((*w, *v))
347
-
348
-
349
- # ad_V^T pg. 290 L&P, pg. 25 Featurestone, note this does not includes the sign flip in the definition
350
- def spatial_cross_dual(a, b):
351
- w = np.cross(a[0:3], b[0:3]) + np.cross(a[3:6], b[3:6])
352
- v = np.cross(a[0:3], b[3:6])
353
-
354
- return np.array((*w, *v))
355
-
356
-
357
- def spatial_dot(a, b):
358
- return np.dot(a, b)
359
-
360
-
361
- def spatial_outer(a, b):
362
- return np.outer(a, b)
363
-
364
-
365
- # def spatial_matrix():
366
- # return np.zeros((6, 6))
367
-
368
-
369
- def spatial_matrix_from_inertia(I, m):
370
- G = spatial_matrix()
371
-
372
- G[0:3, 0:3] = I
373
- G[3, 3] = m
374
- G[4, 4] = m
375
- G[5, 5] = m
376
-
377
- return G
378
-
379
-
380
- # solves x = I^(-1)b
381
- def spatial_solve(I, b):
382
- return np.dot(np.linalg.inv(I), b)
383
-
384
-
385
- # helper to retrive body angular velocity from a twist v_s in se(3)
386
- def get_body_angular_velocity(v_s):
387
- return v_s[0:3]
388
-
389
-
390
- # helper to compute velocity of a point p on a body given it's spatial twist v_s
391
- def get_body_linear_velocity(v_s, p):
392
- dpdt = v_s[3:6] + np.cross(v_s[0:3], p)
393
- return dpdt
394
-
395
-
396
- # helper to build a body twist given the angular and linear velocity of
397
- # the center of mass specified in the world frame, returns the body
398
- # twist with respect to the origin (v_s)
399
- def get_body_twist(w_m, v_m, p_m):
400
- lin = v_m + np.cross(p_m, w_m)
401
- return (*w_m, *lin)
402
-
403
-
404
51
  def array_scan(in_array, out_array, inclusive=True):
405
52
  if in_array.device != out_array.device:
406
53
  raise RuntimeError("Array storage devices do not match")
@@ -411,6 +58,9 @@ def array_scan(in_array, out_array, inclusive=True):
411
58
  if in_array.dtype != out_array.dtype:
412
59
  raise RuntimeError("Array data types do not match")
413
60
 
61
+ if in_array.size == 0:
62
+ return
63
+
414
64
  from warp.context import runtime
415
65
 
416
66
  if in_array.device.is_cpu:
@@ -433,6 +83,9 @@ def radix_sort_pairs(keys, values, count: int):
433
83
  if keys.device != values.device:
434
84
  raise RuntimeError("Array storage devices do not match")
435
85
 
86
+ if count == 0:
87
+ return
88
+
436
89
  if keys.size < 2 * count or values.size < 2 * count:
437
90
  raise RuntimeError("Array storage must be large enough to contain 2*count elements")
438
91
 
@@ -469,14 +122,19 @@ def runlength_encode(values, run_values, run_lengths, run_count=None, value_coun
469
122
  # User can provide a device output array for storing the number of runs
470
123
  # For convenience, if no such array is provided, number of runs is returned on host
471
124
  if run_count is None:
472
- host_return = True
125
+ if value_count == 0:
126
+ return 0
473
127
  run_count = wp.empty(shape=(1,), dtype=int, device=values.device)
128
+ host_return = True
474
129
  else:
475
- host_return = False
476
130
  if run_count.device != values.device:
477
- raise RuntimeError("run_count storage devices does not match other arrays")
131
+ raise RuntimeError("run_count storage device does not match other arrays")
478
132
  if run_count.dtype != wp.int32:
479
133
  raise RuntimeError("run_count array must be of type int32")
134
+ if value_count == 0:
135
+ run_count.zero_()
136
+ return 0
137
+ host_return = False
480
138
 
481
139
  from warp.context import runtime
482
140
 
@@ -532,6 +190,12 @@ def array_sum(values, out=None, value_count=None, axis=None):
532
190
  if out.shape != output_shape:
533
191
  raise RuntimeError(f"out array should have shape {output_shape}")
534
192
 
193
+ if value_count == 0:
194
+ out.zero_()
195
+ if axis is None and host_return:
196
+ return out.numpy()[0]
197
+ return out
198
+
535
199
  from warp.context import runtime
536
200
 
537
201
  if values.device.is_cpu:
@@ -578,7 +242,7 @@ def array_inner(a, b, out=None, count=None, axis=None):
578
242
  raise RuntimeError("Array storage sizes do not match")
579
243
 
580
244
  if a.device != b.device:
581
- raise RuntimeError("Array storage sizes do not match")
245
+ raise RuntimeError("Array storage devices do not match")
582
246
 
583
247
  if a.dtype != b.dtype:
584
248
  raise RuntimeError("Array data types do not match")
@@ -615,6 +279,12 @@ def array_inner(a, b, out=None, count=None, axis=None):
615
279
  if out.shape != output_shape:
616
280
  raise RuntimeError(f"out array should have shape {output_shape}")
617
281
 
282
+ if count == 0:
283
+ if axis is None and host_return:
284
+ return 0.0
285
+ out.zero_()
286
+ return out
287
+
618
288
  from warp.context import runtime
619
289
 
620
290
  if a.device.is_cpu:
@@ -662,29 +332,16 @@ def array_inner(a, b, out=None, count=None, axis=None):
662
332
  return out
663
333
 
664
334
 
665
- _copy_kernel_cache = dict()
335
+ @wp.kernel
336
+ def _array_cast_kernel(
337
+ dest: Any,
338
+ src: Any,
339
+ ):
340
+ i = wp.tid()
341
+ dest[i] = dest.dtype(src[i])
666
342
 
667
343
 
668
344
  def array_cast(in_array, out_array, count=None):
669
- def make_copy_kernel(dest_dtype, src_dtype):
670
- import re
671
-
672
- import warp.context
673
-
674
- def copy_kernel(
675
- dest: Any,
676
- src: Any,
677
- ):
678
- dest[wp.tid()] = dest_dtype(src[wp.tid()])
679
-
680
- module = wp.get_module(copy_kernel.__module__)
681
- key = f"{copy_kernel.__name__}_{warp.context.type_str(src_dtype)}_{warp.context.type_str(dest_dtype)}"
682
- key = re.sub("[^0-9a-zA-Z_]+", "", key)
683
-
684
- if key not in _copy_kernel_cache:
685
- _copy_kernel_cache[key] = wp.Kernel(func=copy_kernel, key=key, module=module)
686
- return _copy_kernel_cache[key]
687
-
688
345
  if in_array.device != out_array.device:
689
346
  raise RuntimeError("Array storage devices do not match")
690
347
 
@@ -739,8 +396,7 @@ def array_cast(in_array, out_array, count=None):
739
396
  # Same data type, can simply copy
740
397
  wp.copy(dest=out_array, src=in_array, count=count)
741
398
  else:
742
- copy_kernel = make_copy_kernel(src_dtype=in_array.dtype, dest_dtype=out_array.dtype)
743
- wp.launch(kernel=copy_kernel, dim=dim, inputs=[out_array, in_array], device=out_array.device)
399
+ wp.launch(kernel=_array_cast_kernel, dim=dim, inputs=[out_array, in_array], device=out_array.device)
744
400
 
745
401
 
746
402
  # code snippet for invoking cProfile
@@ -816,11 +472,8 @@ class MeshAdjacency:
816
472
 
817
473
  self.edges[key] = edge
818
474
 
819
- def opposite_vertex(self, edge):
820
- pass
821
-
822
475
 
823
- def mem_report():
476
+ def mem_report(): #pragma: no cover
824
477
  def _mem_report(tensors, mem_type):
825
478
  """Print the selected tensors of type
826
479
  There are two major storage types in our major concern:
@@ -872,12 +525,6 @@ def mem_report():
872
525
  print("=" * LEN)
873
526
 
874
527
 
875
- def lame_parameters(E, nu):
876
- l = (E * nu) / ((1.0 + nu) * (1.0 - 2.0 * nu))
877
- mu = E / (2.0 * (1.0 + nu))
878
-
879
- return (l, mu)
880
-
881
528
 
882
529
  class ScopedDevice:
883
530
  def __init__(self, device):
@@ -1019,3 +666,17 @@ class ScopedTimer:
1019
666
  print("{}{} took {:.2f} ms".format(indent, self.name, self.elapsed))
1020
667
 
1021
668
  ScopedTimer.indent -= 1
669
+
670
+
671
+ # helper kernels for adj_matmul
672
+ @wp.kernel
673
+ def add_kernel_2d(x: wp.array2d(dtype=Any), acc: wp.array2d(dtype=Any), beta: Any):
674
+ i, j = wp.tid()
675
+
676
+ x[i,j] = x[i,j] + beta * acc[i,j]
677
+
678
+ @wp.kernel
679
+ def add_kernel_3d(x: wp.array3d(dtype=Any), acc: wp.array3d(dtype=Any), beta: Any):
680
+ i, j, k = wp.tid()
681
+
682
+ x[i,j,k] = x[i,j,k] + beta * acc[i,j,k]