warp-lang 1.7.2rc1__py3-none-manylinux_2_34_aarch64.whl → 1.8.1__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (193) hide show
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +130 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +272 -104
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +770 -238
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_callable.py +34 -4
  36. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  37. warp/examples/interop/example_jax_kernel.py +27 -1
  38. warp/examples/optim/example_drone.py +1 -1
  39. warp/examples/sim/example_cloth.py +1 -1
  40. warp/examples/sim/example_cloth_self_contact.py +48 -54
  41. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  42. warp/examples/tile/example_tile_cholesky.py +2 -1
  43. warp/examples/tile/example_tile_convolution.py +1 -1
  44. warp/examples/tile/example_tile_filtering.py +1 -1
  45. warp/examples/tile/example_tile_matmul.py +1 -1
  46. warp/examples/tile/example_tile_mlp.py +2 -0
  47. warp/fabric.py +7 -7
  48. warp/fem/__init__.py +5 -0
  49. warp/fem/adaptivity.py +1 -1
  50. warp/fem/cache.py +152 -63
  51. warp/fem/dirichlet.py +2 -2
  52. warp/fem/domain.py +136 -6
  53. warp/fem/field/field.py +141 -99
  54. warp/fem/field/nodal_field.py +85 -39
  55. warp/fem/field/virtual.py +99 -52
  56. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  57. warp/fem/geometry/closest_point.py +13 -0
  58. warp/fem/geometry/deformed_geometry.py +102 -40
  59. warp/fem/geometry/element.py +56 -2
  60. warp/fem/geometry/geometry.py +323 -22
  61. warp/fem/geometry/grid_2d.py +157 -62
  62. warp/fem/geometry/grid_3d.py +116 -20
  63. warp/fem/geometry/hexmesh.py +86 -20
  64. warp/fem/geometry/nanogrid.py +166 -86
  65. warp/fem/geometry/partition.py +59 -25
  66. warp/fem/geometry/quadmesh.py +86 -135
  67. warp/fem/geometry/tetmesh.py +47 -119
  68. warp/fem/geometry/trimesh.py +77 -270
  69. warp/fem/integrate.py +181 -95
  70. warp/fem/linalg.py +25 -58
  71. warp/fem/operator.py +124 -27
  72. warp/fem/quadrature/pic_quadrature.py +36 -14
  73. warp/fem/quadrature/quadrature.py +40 -16
  74. warp/fem/space/__init__.py +1 -1
  75. warp/fem/space/basis_function_space.py +66 -46
  76. warp/fem/space/basis_space.py +17 -4
  77. warp/fem/space/dof_mapper.py +1 -1
  78. warp/fem/space/function_space.py +2 -2
  79. warp/fem/space/grid_2d_function_space.py +4 -1
  80. warp/fem/space/hexmesh_function_space.py +4 -2
  81. warp/fem/space/nanogrid_function_space.py +3 -1
  82. warp/fem/space/partition.py +11 -2
  83. warp/fem/space/quadmesh_function_space.py +4 -1
  84. warp/fem/space/restriction.py +5 -2
  85. warp/fem/space/shape/__init__.py +10 -8
  86. warp/fem/space/tetmesh_function_space.py +4 -1
  87. warp/fem/space/topology.py +52 -21
  88. warp/fem/space/trimesh_function_space.py +4 -1
  89. warp/fem/utils.py +53 -8
  90. warp/jax.py +1 -2
  91. warp/jax_experimental/ffi.py +210 -67
  92. warp/jax_experimental/xla_ffi.py +37 -24
  93. warp/math.py +171 -1
  94. warp/native/array.h +103 -4
  95. warp/native/builtin.h +182 -35
  96. warp/native/coloring.cpp +6 -2
  97. warp/native/cuda_util.cpp +1 -1
  98. warp/native/exports.h +118 -63
  99. warp/native/intersect.h +5 -5
  100. warp/native/mat.h +8 -13
  101. warp/native/mathdx.cpp +11 -5
  102. warp/native/matnn.h +1 -123
  103. warp/native/mesh.h +1 -1
  104. warp/native/quat.h +34 -6
  105. warp/native/rand.h +7 -7
  106. warp/native/sparse.cpp +121 -258
  107. warp/native/sparse.cu +181 -274
  108. warp/native/spatial.h +305 -17
  109. warp/native/svd.h +23 -8
  110. warp/native/tile.h +603 -73
  111. warp/native/tile_radix_sort.h +1112 -0
  112. warp/native/tile_reduce.h +239 -13
  113. warp/native/tile_scan.h +240 -0
  114. warp/native/tuple.h +189 -0
  115. warp/native/vec.h +10 -20
  116. warp/native/warp.cpp +36 -4
  117. warp/native/warp.cu +588 -52
  118. warp/native/warp.h +47 -74
  119. warp/optim/linear.py +5 -1
  120. warp/paddle.py +7 -8
  121. warp/py.typed +0 -0
  122. warp/render/render_opengl.py +110 -80
  123. warp/render/render_usd.py +124 -62
  124. warp/sim/__init__.py +9 -0
  125. warp/sim/collide.py +253 -80
  126. warp/sim/graph_coloring.py +8 -1
  127. warp/sim/import_mjcf.py +4 -3
  128. warp/sim/import_usd.py +11 -7
  129. warp/sim/integrator.py +5 -2
  130. warp/sim/integrator_euler.py +1 -1
  131. warp/sim/integrator_featherstone.py +1 -1
  132. warp/sim/integrator_vbd.py +761 -322
  133. warp/sim/integrator_xpbd.py +1 -1
  134. warp/sim/model.py +265 -260
  135. warp/sim/utils.py +10 -7
  136. warp/sparse.py +303 -166
  137. warp/tape.py +54 -51
  138. warp/tests/cuda/test_conditional_captures.py +1046 -0
  139. warp/tests/cuda/test_streams.py +1 -1
  140. warp/tests/geometry/test_volume.py +2 -2
  141. warp/tests/interop/test_dlpack.py +9 -9
  142. warp/tests/interop/test_jax.py +0 -1
  143. warp/tests/run_coverage_serial.py +1 -1
  144. warp/tests/sim/disabled_kinematics.py +2 -2
  145. warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
  146. warp/tests/sim/test_collision.py +159 -51
  147. warp/tests/sim/test_coloring.py +91 -2
  148. warp/tests/test_array.py +254 -2
  149. warp/tests/test_array_reduce.py +2 -2
  150. warp/tests/test_assert.py +53 -0
  151. warp/tests/test_atomic_cas.py +312 -0
  152. warp/tests/test_codegen.py +142 -19
  153. warp/tests/test_conditional.py +47 -1
  154. warp/tests/test_ctypes.py +0 -20
  155. warp/tests/test_devices.py +8 -0
  156. warp/tests/test_fabricarray.py +4 -2
  157. warp/tests/test_fem.py +58 -25
  158. warp/tests/test_func.py +42 -1
  159. warp/tests/test_grad.py +1 -1
  160. warp/tests/test_lerp.py +1 -3
  161. warp/tests/test_map.py +481 -0
  162. warp/tests/test_mat.py +23 -24
  163. warp/tests/test_quat.py +28 -15
  164. warp/tests/test_rounding.py +10 -38
  165. warp/tests/test_runlength_encode.py +7 -7
  166. warp/tests/test_smoothstep.py +1 -1
  167. warp/tests/test_sparse.py +83 -2
  168. warp/tests/test_spatial.py +507 -1
  169. warp/tests/test_static.py +48 -0
  170. warp/tests/test_struct.py +2 -2
  171. warp/tests/test_tape.py +38 -0
  172. warp/tests/test_tuple.py +265 -0
  173. warp/tests/test_types.py +2 -2
  174. warp/tests/test_utils.py +24 -18
  175. warp/tests/test_vec.py +38 -408
  176. warp/tests/test_vec_constructors.py +325 -0
  177. warp/tests/tile/test_tile.py +438 -131
  178. warp/tests/tile/test_tile_mathdx.py +518 -14
  179. warp/tests/tile/test_tile_matmul.py +179 -0
  180. warp/tests/tile/test_tile_reduce.py +307 -5
  181. warp/tests/tile/test_tile_shared_memory.py +136 -7
  182. warp/tests/tile/test_tile_sort.py +121 -0
  183. warp/tests/unittest_suites.py +14 -6
  184. warp/types.py +462 -308
  185. warp/utils.py +647 -86
  186. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
  187. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
  188. warp/stubs.py +0 -3381
  189. warp/tests/sim/test_xpbd.py +0 -399
  190. warp/tests/test_mlp.py +0 -282
  191. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  192. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  193. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
@@ -130,14 +130,14 @@ class XLA_FFI_DataType(enum.IntEnum):
130
130
  # int64_t* dims; // length == rank
131
131
  # };
132
132
  class XLA_FFI_Buffer(ctypes.Structure):
133
- _fields_ = [
133
+ _fields_ = (
134
134
  ("struct_size", ctypes.c_size_t),
135
135
  ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
136
136
  ("dtype", ctypes.c_int), # XLA_FFI_DataType
137
137
  ("data", ctypes.c_void_p),
138
138
  ("rank", ctypes.c_int64),
139
139
  ("dims", ctypes.POINTER(ctypes.c_int64)),
140
- ]
140
+ )
141
141
 
142
142
 
143
143
  # typedef enum {
@@ -162,13 +162,13 @@ class XLA_FFI_RetType(enum.IntEnum):
162
162
  # void** args; // length == size
163
163
  # };
164
164
  class XLA_FFI_Args(ctypes.Structure):
165
- _fields_ = [
165
+ _fields_ = (
166
166
  ("struct_size", ctypes.c_size_t),
167
167
  ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
168
168
  ("size", ctypes.c_int64),
169
169
  ("types", ctypes.POINTER(ctypes.c_int)), # XLA_FFI_ArgType*
170
170
  ("args", ctypes.POINTER(ctypes.c_void_p)),
171
- ]
171
+ )
172
172
 
173
173
 
174
174
  # struct XLA_FFI_Rets {
@@ -179,13 +179,13 @@ class XLA_FFI_Args(ctypes.Structure):
179
179
  # void** rets; // length == size
180
180
  # };
181
181
  class XLA_FFI_Rets(ctypes.Structure):
182
- _fields_ = [
182
+ _fields_ = (
183
183
  ("struct_size", ctypes.c_size_t),
184
184
  ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
185
185
  ("size", ctypes.c_int64),
186
186
  ("types", ctypes.POINTER(ctypes.c_int)), # XLA_FFI_RetType*
187
187
  ("rets", ctypes.POINTER(ctypes.c_void_p)),
188
- ]
188
+ )
189
189
 
190
190
 
191
191
  # typedef struct XLA_FFI_ByteSpan {
@@ -193,7 +193,10 @@ class XLA_FFI_Rets(ctypes.Structure):
193
193
  # size_t len;
194
194
  # } XLA_FFI_ByteSpan;
195
195
  class XLA_FFI_ByteSpan(ctypes.Structure):
196
- _fields_ = [("ptr", ctypes.POINTER(ctypes.c_char)), ("len", ctypes.c_size_t)]
196
+ _fields_ = (
197
+ ("ptr", ctypes.POINTER(ctypes.c_char)),
198
+ ("len", ctypes.c_size_t),
199
+ )
197
200
 
198
201
 
199
202
  # typedef struct XLA_FFI_Scalar {
@@ -201,7 +204,10 @@ class XLA_FFI_ByteSpan(ctypes.Structure):
201
204
  # void* value;
202
205
  # } XLA_FFI_Scalar;
203
206
  class XLA_FFI_Scalar(ctypes.Structure):
204
- _fields_ = [("dtype", ctypes.c_int), ("value", ctypes.c_void_p)]
207
+ _fields_ = (
208
+ ("dtype", ctypes.c_int),
209
+ ("value", ctypes.c_void_p),
210
+ )
205
211
 
206
212
 
207
213
  # typedef struct XLA_FFI_Array {
@@ -210,7 +216,11 @@ class XLA_FFI_Scalar(ctypes.Structure):
210
216
  # void* data;
211
217
  # } XLA_FFI_Array;
212
218
  class XLA_FFI_Array(ctypes.Structure):
213
- _fields_ = [("dtype", ctypes.c_int), ("size", ctypes.c_size_t), ("data", ctypes.c_void_p)]
219
+ _fields_ = (
220
+ ("dtype", ctypes.c_int),
221
+ ("size", ctypes.c_size_t),
222
+ ("data", ctypes.c_void_p),
223
+ )
214
224
 
215
225
 
216
226
  # typedef enum {
@@ -235,14 +245,14 @@ class XLA_FFI_AttrType(enum.IntEnum):
235
245
  # void** attrs; // length == size
236
246
  # };
237
247
  class XLA_FFI_Attrs(ctypes.Structure):
238
- _fields_ = [
248
+ _fields_ = (
239
249
  ("struct_size", ctypes.c_size_t),
240
250
  ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
241
251
  ("size", ctypes.c_int64),
242
252
  ("types", ctypes.POINTER(ctypes.c_int)), # XLA_FFI_AttrType*
243
253
  ("names", ctypes.POINTER(ctypes.POINTER(XLA_FFI_ByteSpan))),
244
254
  ("attrs", ctypes.POINTER(ctypes.c_void_p)),
245
- ]
255
+ )
246
256
 
247
257
 
248
258
  # struct XLA_FFI_Api_Version {
@@ -252,12 +262,12 @@ class XLA_FFI_Attrs(ctypes.Structure):
252
262
  # int minor_version; // out
253
263
  # };
254
264
  class XLA_FFI_Api_Version(ctypes.Structure):
255
- _fields_ = [
265
+ _fields_ = (
256
266
  ("struct_size", ctypes.c_size_t),
257
267
  ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
258
268
  ("major_version", ctypes.c_int),
259
269
  ("minor_version", ctypes.c_int),
260
- ]
270
+ )
261
271
 
262
272
 
263
273
  # enum XLA_FFI_Handler_TraitsBits {
@@ -276,11 +286,11 @@ class XLA_FFI_Handler_TraitsBits(enum.IntEnum):
276
286
  # XLA_FFI_Handler_Traits traits;
277
287
  # };
278
288
  class XLA_FFI_Metadata(ctypes.Structure):
279
- _fields_ = [
289
+ _fields_ = (
280
290
  ("struct_size", ctypes.c_size_t),
281
291
  ("api_version", XLA_FFI_Api_Version), # XLA_FFI_Extension_Type
282
292
  ("traits", ctypes.c_uint32), # XLA_FFI_Handler_Traits
283
- ]
293
+ )
284
294
 
285
295
 
286
296
  # struct XLA_FFI_Metadata_Extension {
@@ -288,7 +298,10 @@ class XLA_FFI_Metadata(ctypes.Structure):
288
298
  # XLA_FFI_Metadata* metadata;
289
299
  # };
290
300
  class XLA_FFI_Metadata_Extension(ctypes.Structure):
291
- _fields_ = [("extension_base", XLA_FFI_Extension_Base), ("metadata", ctypes.POINTER(XLA_FFI_Metadata))]
301
+ _fields_ = (
302
+ ("extension_base", XLA_FFI_Extension_Base),
303
+ ("metadata", ctypes.POINTER(XLA_FFI_Metadata)),
304
+ )
292
305
 
293
306
 
294
307
  # typedef enum {
@@ -337,12 +350,12 @@ class XLA_FFI_Error_Code(enum.IntEnum):
337
350
  # XLA_FFI_Error_Code errc;
338
351
  # };
339
352
  class XLA_FFI_Error_Create_Args(ctypes.Structure):
340
- _fields_ = [
353
+ _fields_ = (
341
354
  ("struct_size", ctypes.c_size_t),
342
355
  ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
343
356
  ("message", ctypes.c_char_p),
344
357
  ("errc", ctypes.c_int),
345
- ] # XLA_FFI_Error_Code
358
+ ) # XLA_FFI_Error_Code
346
359
 
347
360
 
348
361
  XLA_FFI_Error_Create = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_Error_Create_Args))
@@ -355,12 +368,12 @@ XLA_FFI_Error_Create = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_
355
368
  # void* stream; // out
356
369
  # };
357
370
  class XLA_FFI_Stream_Get_Args(ctypes.Structure):
358
- _fields_ = [
371
+ _fields_ = (
359
372
  ("struct_size", ctypes.c_size_t),
360
373
  ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
361
374
  ("ctx", ctypes.c_void_p), # XLA_FFI_ExecutionContext*
362
375
  ("stream", ctypes.c_void_p),
363
- ] # // out
376
+ ) # // out
364
377
 
365
378
 
366
379
  XLA_FFI_Stream_Get = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_Stream_Get_Args))
@@ -391,7 +404,7 @@ XLA_FFI_Stream_Get = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_St
391
404
  # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_SetError);
392
405
  # };
393
406
  class XLA_FFI_Api(ctypes.Structure):
394
- _fields_ = [
407
+ _fields_ = (
395
408
  ("struct_size", ctypes.c_size_t),
396
409
  ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
397
410
  ("api_version", XLA_FFI_Api_Version),
@@ -412,7 +425,7 @@ class XLA_FFI_Api(ctypes.Structure):
412
425
  ("XLA_FFI_Future_Create", ctypes.c_void_p), # XLA_FFI_Future_Create
413
426
  ("XLA_FFI_Future_SetAvailable", ctypes.c_void_p), # XLA_FFI_Future_SetAvailable
414
427
  ("XLA_FFI_Future_SetError", ctypes.c_void_p), # XLA_FFI_Future_SetError
415
- ]
428
+ )
416
429
 
417
430
 
418
431
  # struct XLA_FFI_CallFrame {
@@ -431,7 +444,7 @@ class XLA_FFI_Api(ctypes.Structure):
431
444
  # XLA_FFI_Future* future; // out
432
445
  # };
433
446
  class XLA_FFI_CallFrame(ctypes.Structure):
434
- _fields_ = [
447
+ _fields_ = (
435
448
  ("struct_size", ctypes.c_size_t),
436
449
  ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
437
450
  ("api", ctypes.POINTER(XLA_FFI_Api)),
@@ -441,7 +454,7 @@ class XLA_FFI_CallFrame(ctypes.Structure):
441
454
  ("rets", XLA_FFI_Rets),
442
455
  ("attrs", XLA_FFI_Attrs),
443
456
  ("future", ctypes.c_void_p), # XLA_FFI_Future* // out
444
- ]
457
+ )
445
458
 
446
459
 
447
460
  _xla_data_type_to_constructor = {
warp/math.py CHANGED
@@ -22,11 +22,13 @@ Vector norm functions
22
22
  """
23
23
 
24
24
  __all__ = [
25
+ "norm_huber",
25
26
  "norm_l1",
26
27
  "norm_l2",
27
- "norm_huber",
28
28
  "norm_pseudo_huber",
29
29
  "smooth_normalize",
30
+ "transform_compose",
31
+ "transform_decompose",
30
32
  "transform_from_matrix",
31
33
  "transform_to_matrix",
32
34
  ]
@@ -142,6 +144,19 @@ def create_transform_from_matrix_func(dtype):
142
144
  """
143
145
  Construct a transformation from a 4x4 matrix.
144
146
 
147
+ .. math::
148
+ M = \\begin{bmatrix}
149
+ R_{00} & R_{01} & R_{02} & p_x \\\\
150
+ R_{10} & R_{11} & R_{12} & p_y \\\\
151
+ R_{20} & R_{21} & R_{22} & p_z \\\\
152
+ 0 & 0 & 0 & 1
153
+ \\end{bmatrix}
154
+
155
+ Where:
156
+
157
+ * :math:`R` is the 3x3 rotation matrix created from the orientation quaternion of the input transform.
158
+ * :math:`p` is the 3D position vector :math:`[p_x, p_y, p_z]` of the input transform.
159
+
145
160
  Args:
146
161
  mat (Matrix[4, 4, Float]): Matrix to convert.
147
162
 
@@ -177,6 +192,19 @@ def create_transform_to_matrix_func(dtype):
177
192
  """
178
193
  Convert a transformation to a 4x4 matrix.
179
194
 
195
+ .. math::
196
+ M = \\begin{bmatrix}
197
+ R_{00} & R_{01} & R_{02} & p_x \\\\
198
+ R_{10} & R_{11} & R_{12} & p_y \\\\
199
+ R_{20} & R_{21} & R_{22} & p_z \\\\
200
+ 0 & 0 & 0 & 1
201
+ \\end{bmatrix}
202
+
203
+ Where:
204
+
205
+ * :math:`R` is the 3x3 rotation matrix created from the orientation quaternion of the input transform.
206
+ * :math:`p` is the 3D position vector :math:`[p_x, p_y, p_z]` of the input transform.
207
+
180
208
  Args:
181
209
  xform (Transformation[Float]): Transformation to convert.
182
210
 
@@ -212,6 +240,140 @@ wp.func(
212
240
  )
213
241
 
214
242
 
243
+ def create_transform_compose_func(dtype):
244
+ mat44 = wp.types.matrix((4, 4), dtype)
245
+ quat = wp.types.quaternion(dtype)
246
+ vec3 = wp.types.vector(3, dtype)
247
+
248
+ def transform_compose(position: vec3, rotation: quat, scale: vec3):
249
+ """
250
+ Compose a 4x4 transformation matrix from a 3D position, quaternion orientation, and 3D scale.
251
+
252
+ .. math::
253
+ M = \\begin{bmatrix}
254
+ s_x R_{00} & s_y R_{01} & s_z R_{02} & p_x \\\\
255
+ s_x R_{10} & s_y R_{11} & s_z R_{12} & p_y \\\\
256
+ s_x R_{20} & s_y R_{21} & s_z R_{22} & p_z \\\\
257
+ 0 & 0 & 0 & 1
258
+ \\end{bmatrix}
259
+
260
+ Where:
261
+
262
+ * :math:`R` is the 3x3 rotation matrix created from the orientation quaternion of the input transform.
263
+ * :math:`p` is the 3D position vector :math:`[p_x, p_y, p_z]` of the input transform.
264
+ * :math:`s` is the 3D scale vector :math:`[s_x, s_y, s_z]` of the input transform.
265
+
266
+ Args:
267
+ position (Vector[3, Float]): The 3D position vector.
268
+ rotation (Quaternion[Float]): The quaternion orientation.
269
+ scale (Vector[3, Float]): The 3D scale vector.
270
+
271
+ Returns:
272
+ Matrix[4, 4, Float]: The transformation matrix.
273
+ """
274
+ R = wp.quat_to_matrix(rotation)
275
+ # fmt: off
276
+ return mat44(
277
+ scale[0] * R[0,0], scale[1] * R[0,1], scale[2] * R[0,2], position[0],
278
+ scale[0] * R[1,0], scale[1] * R[1,1], scale[2] * R[1,2], position[1],
279
+ scale[0] * R[2,0], scale[1] * R[2,1], scale[2] * R[2,2], position[2],
280
+ dtype(0.0), dtype(0.0), dtype(0.0), dtype(1.0),
281
+ )
282
+ # fmt: on
283
+
284
+ return transform_compose
285
+
286
+
287
+ transform_compose = wp.func(
288
+ create_transform_compose_func(wp.float32),
289
+ name="transform_compose",
290
+ )
291
+ wp.func(
292
+ create_transform_compose_func(wp.float16),
293
+ name="transform_compose",
294
+ )
295
+ wp.func(
296
+ create_transform_compose_func(wp.float64),
297
+ name="transform_compose",
298
+ )
299
+
300
+
301
+ def create_transform_decompose_func(dtype):
302
+ mat44 = wp.types.matrix((4, 4), dtype)
303
+ vec3 = wp.types.vector(3, dtype)
304
+ mat33 = wp.types.matrix((3, 3), dtype)
305
+ zero = dtype(0.0)
306
+
307
+ def transform_decompose(m: mat44):
308
+ """
309
+ Decompose a 4x4 transformation matrix into 3D position, quaternion orientation, and 3D scale.
310
+
311
+ .. math::
312
+ M = \\begin{bmatrix}
313
+ s_x R_{00} & s_y R_{01} & s_z R_{02} & p_x \\\\
314
+ s_x R_{10} & s_y R_{11} & s_z R_{12} & p_y \\\\
315
+ s_x R_{20} & s_y R_{21} & s_z R_{22} & p_z \\\\
316
+ 0 & 0 & 0 & 1
317
+ \\end{bmatrix}
318
+
319
+ Where:
320
+
321
+ * :math:`R` is the 3x3 rotation matrix created from the orientation quaternion of the input transform.
322
+ * :math:`p` is the 3D position vector :math:`[p_x, p_y, p_z]` of the input transform.
323
+ * :math:`s` is the 3D scale vector :math:`[s_x, s_y, s_z]` of the input transform.
324
+
325
+ Args:
326
+ m (Matrix[4, 4, Float]): The matrix to decompose.
327
+
328
+ Returns:
329
+ Tuple[Vector[3, Float], Quaternion[Float], Vector[3, Float]]: A tuple containing the position vector, quaternion orientation, and scale vector.
330
+ """
331
+ # extract position
332
+ position = vec3(m[0, 3], m[1, 3], m[2, 3])
333
+ # extract rotation matrix components
334
+ r00, r01, r02 = m[0, 0], m[0, 1], m[0, 2]
335
+ r10, r11, r12 = m[1, 0], m[1, 1], m[1, 2]
336
+ r20, r21, r22 = m[2, 0], m[2, 1], m[2, 2]
337
+ # get scale magnitudes
338
+ sx = wp.sqrt(r00 * r00 + r10 * r10 + r20 * r20)
339
+ sy = wp.sqrt(r01 * r01 + r11 * r11 + r21 * r21)
340
+ sz = wp.sqrt(r02 * r02 + r12 * r12 + r22 * r22)
341
+ # normalize rotation matrix components
342
+ if sx != zero:
343
+ r00 /= sx
344
+ r10 /= sx
345
+ r20 /= sx
346
+ if sy != zero:
347
+ r01 /= sy
348
+ r11 /= sy
349
+ r21 /= sy
350
+ if sz != zero:
351
+ r02 /= sz
352
+ r12 /= sz
353
+ r22 /= sz
354
+ # extract rotation (quaternion)
355
+ rotation = wp.quat_from_matrix(mat33(r00, r01, r02, r10, r11, r12, r20, r21, r22))
356
+ # extract scale
357
+ scale = vec3(sx, sy, sz)
358
+ return position, rotation, scale
359
+
360
+ return transform_decompose
361
+
362
+
363
+ transform_decompose = wp.func(
364
+ create_transform_decompose_func(wp.float32),
365
+ name="transform_decompose",
366
+ )
367
+ wp.func(
368
+ create_transform_decompose_func(wp.float16),
369
+ name="transform_decompose",
370
+ )
371
+ wp.func(
372
+ create_transform_decompose_func(wp.float64),
373
+ name="transform_decompose",
374
+ )
375
+
376
+
215
377
  # register API functions so they appear in the documentation
216
378
 
217
379
  wp.context.register_api_function(
@@ -242,3 +404,11 @@ wp.context.register_api_function(
242
404
  transform_to_matrix,
243
405
  group="Transformations",
244
406
  )
407
+ wp.context.register_api_function(
408
+ transform_compose,
409
+ group="Transformations",
410
+ )
411
+ wp.context.register_api_function(
412
+ transform_decompose,
413
+ group="Transformations",
414
+ )
warp/native/array.h CHANGED
@@ -161,7 +161,7 @@ inline CUDA_CALLABLE void print(shape_t s)
161
161
  // should probably store ndim with shape
162
162
  printf("(%d, %d, %d, %d)\n", s.dims[0], s.dims[1], s.dims[2], s.dims[3]);
163
163
  }
164
- inline CUDA_CALLABLE void adj_print(shape_t s, shape_t& shape_t) {}
164
+ inline CUDA_CALLABLE void adj_print(shape_t s, shape_t& adj_s) {}
165
165
 
166
166
 
167
167
  template <typename T>
@@ -665,11 +665,11 @@ CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i, int j
665
665
  }
666
666
 
667
667
  template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
668
- inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, A2<T>& adj_src, int adj_i, A3<T> adj_ret) {}
668
+ inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, A2<T>& adj_src, int adj_i, A3<T>& adj_ret) {}
669
669
  template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
670
- inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, A2<T>& adj_src, int adj_i, int adj_j, A3<T> adj_ret) {}
670
+ inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, A2<T>& adj_src, int adj_i, int adj_j, A3<T>& adj_ret) {}
671
671
  template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
672
- inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, int k, A2<T>& adj_src, int adj_i, int adj_j, int adj_k, A3<T> adj_ret) {}
672
+ inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, int k, A2<T>& adj_src, int adj_i, int adj_j, int adj_k, A3<T>& adj_ret) {}
673
673
 
674
674
  // TODO: lower_bound() for indexed arrays?
675
675
 
@@ -743,6 +743,24 @@ inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, int k, T value)
743
743
  template<template<typename> class A, typename T>
744
744
  inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_max(&index(buf, i, j, k, l), value); }
745
745
 
746
+ template<template<typename> class A, typename T>
747
+ inline CUDA_CALLABLE T atomic_cas(const A<T>& buf, int i, T old_value, T new_value) { return atomic_cas(&index(buf, i), old_value, new_value); }
748
+ template<template<typename> class A, typename T>
749
+ inline CUDA_CALLABLE T atomic_cas(const A<T>& buf, int i, int j, T old_value, T new_value) { return atomic_cas(&index(buf, i, j), old_value, new_value); }
750
+ template<template<typename> class A, typename T>
751
+ inline CUDA_CALLABLE T atomic_cas(const A<T>& buf, int i, int j, int k, T old_value, T new_value) { return atomic_cas(&index(buf, i, j, k), old_value, new_value); }
752
+ template<template<typename> class A, typename T>
753
+ inline CUDA_CALLABLE T atomic_cas(const A<T>& buf, int i, int j, int k, int l, T old_value, T new_value) { return atomic_cas(&index(buf, i, j, k, l), old_value, new_value); }
754
+
755
+ template<template<typename> class A, typename T>
756
+ inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, T value) { return atomic_exch(&index(buf, i), value); }
757
+ template<template<typename> class A, typename T>
758
+ inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, int j, T value) { return atomic_exch(&index(buf, i, j), value); }
759
+ template<template<typename> class A, typename T>
760
+ inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, int j, int k, T value) { return atomic_exch(&index(buf, i, j, k), value); }
761
+ template<template<typename> class A, typename T>
762
+ inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_exch(&index(buf, i, j, k, l), value); }
763
+
746
764
  template<template<typename> class A, typename T>
747
765
  inline CUDA_CALLABLE T* address(const A<T>& buf, int i) { return &index(buf, i); }
748
766
  template<template<typename> class A, typename T>
@@ -1128,6 +1146,87 @@ inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k,
1128
1146
  FP_VERIFY_ADJ_4(value, adj_value)
1129
1147
  }
1130
1148
 
1149
+ template<template<typename> class A1, template<typename> class A2, typename T>
1150
+ inline CUDA_CALLABLE void adj_atomic_cas(const A1<T>& buf, int i, T compare, T value, const A2<T>& adj_buf, int adj_i, T& adj_compare, T& adj_value, const T& adj_ret) {
1151
+ if (adj_buf.data)
1152
+ adj_atomic_cas(&index(buf, i), compare, value, &index(adj_buf, i), adj_compare, adj_value, adj_ret);
1153
+ else if (buf.grad)
1154
+ adj_atomic_cas(&index(buf, i), compare, value, &index_grad(buf, i), adj_compare, adj_value, adj_ret);
1155
+
1156
+ FP_VERIFY_ADJ_1(value, adj_value)
1157
+ }
1158
+
1159
+ template<template<typename> class A1, template<typename> class A2, typename T>
1160
+ inline CUDA_CALLABLE void adj_atomic_cas(const A1<T>& buf, int i, int j, T compare, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_compare, T& adj_value, const T& adj_ret) {
1161
+ if (adj_buf.data)
1162
+ adj_atomic_cas(&index(buf, i, j), compare, value, &index(adj_buf, i, j), adj_compare, adj_value, adj_ret);
1163
+ else if (buf.grad)
1164
+ adj_atomic_cas(&index(buf, i, j), compare, value, &index_grad(buf, i, j), adj_compare, adj_value, adj_ret);
1165
+
1166
+ FP_VERIFY_ADJ_2(value, adj_value)
1167
+ }
1168
+
1169
+ template<template<typename> class A1, template<typename> class A2, typename T>
1170
+ inline CUDA_CALLABLE void adj_atomic_cas(const A1<T>& buf, int i, int j, int k, T compare, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_compare, T& adj_value, const T& adj_ret) {
1171
+ if (adj_buf.data)
1172
+ adj_atomic_cas(&index(buf, i, j, k), compare, value, &index(adj_buf, i, j, k), adj_compare, adj_value, adj_ret);
1173
+ else if (buf.grad)
1174
+ adj_atomic_cas(&index(buf, i, j, k), compare, value, &index_grad(buf, i, j, k), adj_compare, adj_value, adj_ret);
1175
+
1176
+ FP_VERIFY_ADJ_3(value, adj_value)
1177
+ }
1178
+
1179
+ template<template<typename> class A1, template<typename> class A2, typename T>
1180
+ inline CUDA_CALLABLE void adj_atomic_cas(const A1<T>& buf, int i, int j, int k, int l, T compare, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_compare, T& adj_value, const T& adj_ret) {
1181
+ if (adj_buf.data)
1182
+ adj_atomic_cas(&index(buf, i, j, k, l), compare, value, &index(adj_buf, i, j, k, l), adj_compare, adj_value, adj_ret);
1183
+ else if (buf.grad)
1184
+ adj_atomic_cas(&index(buf, i, j, k, l), compare, value, &index_grad(buf, i, j, k, l), adj_compare, adj_value, adj_ret);
1185
+
1186
+ FP_VERIFY_ADJ_4(value, adj_value)
1187
+ }
1188
+
1189
+ template<template<typename> class A1, template<typename> class A2, typename T>
1190
+ inline CUDA_CALLABLE void adj_atomic_exch(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {
1191
+ if (adj_buf.data)
1192
+ adj_atomic_exch(&index(buf, i), value, &index(adj_buf, i), adj_value, adj_ret);
1193
+ else if (buf.grad)
1194
+ adj_atomic_exch(&index(buf, i), value, &index_grad(buf, i), adj_value, adj_ret);
1195
+
1196
+ FP_VERIFY_ADJ_1(value, adj_value)
1197
+ }
1198
+
1199
+ template<template<typename> class A1, template<typename> class A2, typename T>
1200
+ inline CUDA_CALLABLE void adj_atomic_exch(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {
1201
+ if (adj_buf.data)
1202
+ adj_atomic_exch(&index(buf, i, j), value, &index(adj_buf, i, j), adj_value, adj_ret);
1203
+ else if (buf.grad)
1204
+ adj_atomic_exch(&index(buf, i, j), value, &index_grad(buf, i, j), adj_value, adj_ret);
1205
+
1206
+ FP_VERIFY_ADJ_2(value, adj_value)
1207
+ }
1208
+
1209
+ template<template<typename> class A1, template<typename> class A2, typename T>
1210
+ inline CUDA_CALLABLE void adj_atomic_exch(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {
1211
+ if (adj_buf.data)
1212
+ adj_atomic_exch(&index(buf, i, j, k), value, &index(adj_buf, i, j, k), adj_value, adj_ret);
1213
+ else if (buf.grad)
1214
+ adj_atomic_exch(&index(buf, i, j, k), value, &index_grad(buf, i, j, k), adj_value, adj_ret);
1215
+
1216
+ FP_VERIFY_ADJ_3(value, adj_value)
1217
+ }
1218
+
1219
+ template<template<typename> class A1, template<typename> class A2, typename T>
1220
+ inline CUDA_CALLABLE void adj_atomic_exch(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {
1221
+ if (adj_buf.data)
1222
+ adj_atomic_exch(&index(buf, i, j, k, l), value, &index(adj_buf, i, j, k, l), adj_value, adj_ret);
1223
+ else if (buf.grad)
1224
+ adj_atomic_exch(&index(buf, i, j, k, l), value, &index_grad(buf, i, j, k, l), adj_value, adj_ret);
1225
+
1226
+ FP_VERIFY_ADJ_4(value, adj_value)
1227
+ }
1228
+
1229
+
1131
1230
  template<template<typename> class A, typename T>
1132
1231
  CUDA_CALLABLE inline int len(const A<T>& a)
1133
1232
  {