warp-lang 1.7.2rc1__py3-none-win_amd64.whl → 1.8.1__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +130 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +272 -104
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +770 -238
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +99 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +181 -95
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +210 -67
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +103 -4
- warp/native/builtin.h +182 -35
- warp/native/coloring.cpp +6 -2
- warp/native/cuda_util.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +5 -5
- warp/native/mat.h +8 -13
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/mesh.h +1 -1
- warp/native/quat.h +34 -6
- warp/native/rand.h +7 -7
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/svd.h +23 -8
- warp/native/tile.h +603 -73
- warp/native/tile_radix_sort.h +1112 -0
- warp/native/tile_reduce.h +239 -13
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +10 -20
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +588 -52
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +110 -80
- warp/render/render_usd.py +124 -62
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +253 -80
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +761 -322
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +54 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +91 -2
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +312 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +23 -24
- warp/tests/test_quat.py +28 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +83 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_static.py +48 -0
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tape.py +38 -0
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/test_vec.py +38 -408
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +438 -131
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_reduce.py +307 -5
- warp/tests/tile/test_tile_shared_memory.py +136 -7
- warp/tests/tile/test_tile_sort.py +121 -0
- warp/tests/unittest_suites.py +14 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/jax_experimental/xla_ffi.py
CHANGED
|
@@ -130,14 +130,14 @@ class XLA_FFI_DataType(enum.IntEnum):
|
|
|
130
130
|
# int64_t* dims; // length == rank
|
|
131
131
|
# };
|
|
132
132
|
class XLA_FFI_Buffer(ctypes.Structure):
|
|
133
|
-
_fields_ =
|
|
133
|
+
_fields_ = (
|
|
134
134
|
("struct_size", ctypes.c_size_t),
|
|
135
135
|
("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
|
|
136
136
|
("dtype", ctypes.c_int), # XLA_FFI_DataType
|
|
137
137
|
("data", ctypes.c_void_p),
|
|
138
138
|
("rank", ctypes.c_int64),
|
|
139
139
|
("dims", ctypes.POINTER(ctypes.c_int64)),
|
|
140
|
-
|
|
140
|
+
)
|
|
141
141
|
|
|
142
142
|
|
|
143
143
|
# typedef enum {
|
|
@@ -162,13 +162,13 @@ class XLA_FFI_RetType(enum.IntEnum):
|
|
|
162
162
|
# void** args; // length == size
|
|
163
163
|
# };
|
|
164
164
|
class XLA_FFI_Args(ctypes.Structure):
|
|
165
|
-
_fields_ =
|
|
165
|
+
_fields_ = (
|
|
166
166
|
("struct_size", ctypes.c_size_t),
|
|
167
167
|
("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
|
|
168
168
|
("size", ctypes.c_int64),
|
|
169
169
|
("types", ctypes.POINTER(ctypes.c_int)), # XLA_FFI_ArgType*
|
|
170
170
|
("args", ctypes.POINTER(ctypes.c_void_p)),
|
|
171
|
-
|
|
171
|
+
)
|
|
172
172
|
|
|
173
173
|
|
|
174
174
|
# struct XLA_FFI_Rets {
|
|
@@ -179,13 +179,13 @@ class XLA_FFI_Args(ctypes.Structure):
|
|
|
179
179
|
# void** rets; // length == size
|
|
180
180
|
# };
|
|
181
181
|
class XLA_FFI_Rets(ctypes.Structure):
|
|
182
|
-
_fields_ =
|
|
182
|
+
_fields_ = (
|
|
183
183
|
("struct_size", ctypes.c_size_t),
|
|
184
184
|
("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
|
|
185
185
|
("size", ctypes.c_int64),
|
|
186
186
|
("types", ctypes.POINTER(ctypes.c_int)), # XLA_FFI_RetType*
|
|
187
187
|
("rets", ctypes.POINTER(ctypes.c_void_p)),
|
|
188
|
-
|
|
188
|
+
)
|
|
189
189
|
|
|
190
190
|
|
|
191
191
|
# typedef struct XLA_FFI_ByteSpan {
|
|
@@ -193,7 +193,10 @@ class XLA_FFI_Rets(ctypes.Structure):
|
|
|
193
193
|
# size_t len;
|
|
194
194
|
# } XLA_FFI_ByteSpan;
|
|
195
195
|
class XLA_FFI_ByteSpan(ctypes.Structure):
|
|
196
|
-
_fields_ =
|
|
196
|
+
_fields_ = (
|
|
197
|
+
("ptr", ctypes.POINTER(ctypes.c_char)),
|
|
198
|
+
("len", ctypes.c_size_t),
|
|
199
|
+
)
|
|
197
200
|
|
|
198
201
|
|
|
199
202
|
# typedef struct XLA_FFI_Scalar {
|
|
@@ -201,7 +204,10 @@ class XLA_FFI_ByteSpan(ctypes.Structure):
|
|
|
201
204
|
# void* value;
|
|
202
205
|
# } XLA_FFI_Scalar;
|
|
203
206
|
class XLA_FFI_Scalar(ctypes.Structure):
|
|
204
|
-
_fields_ =
|
|
207
|
+
_fields_ = (
|
|
208
|
+
("dtype", ctypes.c_int),
|
|
209
|
+
("value", ctypes.c_void_p),
|
|
210
|
+
)
|
|
205
211
|
|
|
206
212
|
|
|
207
213
|
# typedef struct XLA_FFI_Array {
|
|
@@ -210,7 +216,11 @@ class XLA_FFI_Scalar(ctypes.Structure):
|
|
|
210
216
|
# void* data;
|
|
211
217
|
# } XLA_FFI_Array;
|
|
212
218
|
class XLA_FFI_Array(ctypes.Structure):
|
|
213
|
-
_fields_ =
|
|
219
|
+
_fields_ = (
|
|
220
|
+
("dtype", ctypes.c_int),
|
|
221
|
+
("size", ctypes.c_size_t),
|
|
222
|
+
("data", ctypes.c_void_p),
|
|
223
|
+
)
|
|
214
224
|
|
|
215
225
|
|
|
216
226
|
# typedef enum {
|
|
@@ -235,14 +245,14 @@ class XLA_FFI_AttrType(enum.IntEnum):
|
|
|
235
245
|
# void** attrs; // length == size
|
|
236
246
|
# };
|
|
237
247
|
class XLA_FFI_Attrs(ctypes.Structure):
|
|
238
|
-
_fields_ =
|
|
248
|
+
_fields_ = (
|
|
239
249
|
("struct_size", ctypes.c_size_t),
|
|
240
250
|
("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
|
|
241
251
|
("size", ctypes.c_int64),
|
|
242
252
|
("types", ctypes.POINTER(ctypes.c_int)), # XLA_FFI_AttrType*
|
|
243
253
|
("names", ctypes.POINTER(ctypes.POINTER(XLA_FFI_ByteSpan))),
|
|
244
254
|
("attrs", ctypes.POINTER(ctypes.c_void_p)),
|
|
245
|
-
|
|
255
|
+
)
|
|
246
256
|
|
|
247
257
|
|
|
248
258
|
# struct XLA_FFI_Api_Version {
|
|
@@ -252,12 +262,12 @@ class XLA_FFI_Attrs(ctypes.Structure):
|
|
|
252
262
|
# int minor_version; // out
|
|
253
263
|
# };
|
|
254
264
|
class XLA_FFI_Api_Version(ctypes.Structure):
|
|
255
|
-
_fields_ =
|
|
265
|
+
_fields_ = (
|
|
256
266
|
("struct_size", ctypes.c_size_t),
|
|
257
267
|
("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
|
|
258
268
|
("major_version", ctypes.c_int),
|
|
259
269
|
("minor_version", ctypes.c_int),
|
|
260
|
-
|
|
270
|
+
)
|
|
261
271
|
|
|
262
272
|
|
|
263
273
|
# enum XLA_FFI_Handler_TraitsBits {
|
|
@@ -276,11 +286,11 @@ class XLA_FFI_Handler_TraitsBits(enum.IntEnum):
|
|
|
276
286
|
# XLA_FFI_Handler_Traits traits;
|
|
277
287
|
# };
|
|
278
288
|
class XLA_FFI_Metadata(ctypes.Structure):
|
|
279
|
-
_fields_ =
|
|
289
|
+
_fields_ = (
|
|
280
290
|
("struct_size", ctypes.c_size_t),
|
|
281
291
|
("api_version", XLA_FFI_Api_Version), # XLA_FFI_Extension_Type
|
|
282
292
|
("traits", ctypes.c_uint32), # XLA_FFI_Handler_Traits
|
|
283
|
-
|
|
293
|
+
)
|
|
284
294
|
|
|
285
295
|
|
|
286
296
|
# struct XLA_FFI_Metadata_Extension {
|
|
@@ -288,7 +298,10 @@ class XLA_FFI_Metadata(ctypes.Structure):
|
|
|
288
298
|
# XLA_FFI_Metadata* metadata;
|
|
289
299
|
# };
|
|
290
300
|
class XLA_FFI_Metadata_Extension(ctypes.Structure):
|
|
291
|
-
_fields_ =
|
|
301
|
+
_fields_ = (
|
|
302
|
+
("extension_base", XLA_FFI_Extension_Base),
|
|
303
|
+
("metadata", ctypes.POINTER(XLA_FFI_Metadata)),
|
|
304
|
+
)
|
|
292
305
|
|
|
293
306
|
|
|
294
307
|
# typedef enum {
|
|
@@ -337,12 +350,12 @@ class XLA_FFI_Error_Code(enum.IntEnum):
|
|
|
337
350
|
# XLA_FFI_Error_Code errc;
|
|
338
351
|
# };
|
|
339
352
|
class XLA_FFI_Error_Create_Args(ctypes.Structure):
|
|
340
|
-
_fields_ =
|
|
353
|
+
_fields_ = (
|
|
341
354
|
("struct_size", ctypes.c_size_t),
|
|
342
355
|
("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
|
|
343
356
|
("message", ctypes.c_char_p),
|
|
344
357
|
("errc", ctypes.c_int),
|
|
345
|
-
|
|
358
|
+
) # XLA_FFI_Error_Code
|
|
346
359
|
|
|
347
360
|
|
|
348
361
|
XLA_FFI_Error_Create = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_Error_Create_Args))
|
|
@@ -355,12 +368,12 @@ XLA_FFI_Error_Create = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_
|
|
|
355
368
|
# void* stream; // out
|
|
356
369
|
# };
|
|
357
370
|
class XLA_FFI_Stream_Get_Args(ctypes.Structure):
|
|
358
|
-
_fields_ =
|
|
371
|
+
_fields_ = (
|
|
359
372
|
("struct_size", ctypes.c_size_t),
|
|
360
373
|
("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
|
|
361
374
|
("ctx", ctypes.c_void_p), # XLA_FFI_ExecutionContext*
|
|
362
375
|
("stream", ctypes.c_void_p),
|
|
363
|
-
|
|
376
|
+
) # // out
|
|
364
377
|
|
|
365
378
|
|
|
366
379
|
XLA_FFI_Stream_Get = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_Stream_Get_Args))
|
|
@@ -391,7 +404,7 @@ XLA_FFI_Stream_Get = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_St
|
|
|
391
404
|
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_SetError);
|
|
392
405
|
# };
|
|
393
406
|
class XLA_FFI_Api(ctypes.Structure):
|
|
394
|
-
_fields_ =
|
|
407
|
+
_fields_ = (
|
|
395
408
|
("struct_size", ctypes.c_size_t),
|
|
396
409
|
("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
|
|
397
410
|
("api_version", XLA_FFI_Api_Version),
|
|
@@ -412,7 +425,7 @@ class XLA_FFI_Api(ctypes.Structure):
|
|
|
412
425
|
("XLA_FFI_Future_Create", ctypes.c_void_p), # XLA_FFI_Future_Create
|
|
413
426
|
("XLA_FFI_Future_SetAvailable", ctypes.c_void_p), # XLA_FFI_Future_SetAvailable
|
|
414
427
|
("XLA_FFI_Future_SetError", ctypes.c_void_p), # XLA_FFI_Future_SetError
|
|
415
|
-
|
|
428
|
+
)
|
|
416
429
|
|
|
417
430
|
|
|
418
431
|
# struct XLA_FFI_CallFrame {
|
|
@@ -431,7 +444,7 @@ class XLA_FFI_Api(ctypes.Structure):
|
|
|
431
444
|
# XLA_FFI_Future* future; // out
|
|
432
445
|
# };
|
|
433
446
|
class XLA_FFI_CallFrame(ctypes.Structure):
|
|
434
|
-
_fields_ =
|
|
447
|
+
_fields_ = (
|
|
435
448
|
("struct_size", ctypes.c_size_t),
|
|
436
449
|
("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
|
|
437
450
|
("api", ctypes.POINTER(XLA_FFI_Api)),
|
|
@@ -441,7 +454,7 @@ class XLA_FFI_CallFrame(ctypes.Structure):
|
|
|
441
454
|
("rets", XLA_FFI_Rets),
|
|
442
455
|
("attrs", XLA_FFI_Attrs),
|
|
443
456
|
("future", ctypes.c_void_p), # XLA_FFI_Future* // out
|
|
444
|
-
|
|
457
|
+
)
|
|
445
458
|
|
|
446
459
|
|
|
447
460
|
_xla_data_type_to_constructor = {
|
warp/math.py
CHANGED
|
@@ -22,11 +22,13 @@ Vector norm functions
|
|
|
22
22
|
"""
|
|
23
23
|
|
|
24
24
|
__all__ = [
|
|
25
|
+
"norm_huber",
|
|
25
26
|
"norm_l1",
|
|
26
27
|
"norm_l2",
|
|
27
|
-
"norm_huber",
|
|
28
28
|
"norm_pseudo_huber",
|
|
29
29
|
"smooth_normalize",
|
|
30
|
+
"transform_compose",
|
|
31
|
+
"transform_decompose",
|
|
30
32
|
"transform_from_matrix",
|
|
31
33
|
"transform_to_matrix",
|
|
32
34
|
]
|
|
@@ -142,6 +144,19 @@ def create_transform_from_matrix_func(dtype):
|
|
|
142
144
|
"""
|
|
143
145
|
Construct a transformation from a 4x4 matrix.
|
|
144
146
|
|
|
147
|
+
.. math::
|
|
148
|
+
M = \\begin{bmatrix}
|
|
149
|
+
R_{00} & R_{01} & R_{02} & p_x \\\\
|
|
150
|
+
R_{10} & R_{11} & R_{12} & p_y \\\\
|
|
151
|
+
R_{20} & R_{21} & R_{22} & p_z \\\\
|
|
152
|
+
0 & 0 & 0 & 1
|
|
153
|
+
\\end{bmatrix}
|
|
154
|
+
|
|
155
|
+
Where:
|
|
156
|
+
|
|
157
|
+
* :math:`R` is the 3x3 rotation matrix created from the orientation quaternion of the input transform.
|
|
158
|
+
* :math:`p` is the 3D position vector :math:`[p_x, p_y, p_z]` of the input transform.
|
|
159
|
+
|
|
145
160
|
Args:
|
|
146
161
|
mat (Matrix[4, 4, Float]): Matrix to convert.
|
|
147
162
|
|
|
@@ -177,6 +192,19 @@ def create_transform_to_matrix_func(dtype):
|
|
|
177
192
|
"""
|
|
178
193
|
Convert a transformation to a 4x4 matrix.
|
|
179
194
|
|
|
195
|
+
.. math::
|
|
196
|
+
M = \\begin{bmatrix}
|
|
197
|
+
R_{00} & R_{01} & R_{02} & p_x \\\\
|
|
198
|
+
R_{10} & R_{11} & R_{12} & p_y \\\\
|
|
199
|
+
R_{20} & R_{21} & R_{22} & p_z \\\\
|
|
200
|
+
0 & 0 & 0 & 1
|
|
201
|
+
\\end{bmatrix}
|
|
202
|
+
|
|
203
|
+
Where:
|
|
204
|
+
|
|
205
|
+
* :math:`R` is the 3x3 rotation matrix created from the orientation quaternion of the input transform.
|
|
206
|
+
* :math:`p` is the 3D position vector :math:`[p_x, p_y, p_z]` of the input transform.
|
|
207
|
+
|
|
180
208
|
Args:
|
|
181
209
|
xform (Transformation[Float]): Transformation to convert.
|
|
182
210
|
|
|
@@ -212,6 +240,140 @@ wp.func(
|
|
|
212
240
|
)
|
|
213
241
|
|
|
214
242
|
|
|
243
|
+
def create_transform_compose_func(dtype):
|
|
244
|
+
mat44 = wp.types.matrix((4, 4), dtype)
|
|
245
|
+
quat = wp.types.quaternion(dtype)
|
|
246
|
+
vec3 = wp.types.vector(3, dtype)
|
|
247
|
+
|
|
248
|
+
def transform_compose(position: vec3, rotation: quat, scale: vec3):
|
|
249
|
+
"""
|
|
250
|
+
Compose a 4x4 transformation matrix from a 3D position, quaternion orientation, and 3D scale.
|
|
251
|
+
|
|
252
|
+
.. math::
|
|
253
|
+
M = \\begin{bmatrix}
|
|
254
|
+
s_x R_{00} & s_y R_{01} & s_z R_{02} & p_x \\\\
|
|
255
|
+
s_x R_{10} & s_y R_{11} & s_z R_{12} & p_y \\\\
|
|
256
|
+
s_x R_{20} & s_y R_{21} & s_z R_{22} & p_z \\\\
|
|
257
|
+
0 & 0 & 0 & 1
|
|
258
|
+
\\end{bmatrix}
|
|
259
|
+
|
|
260
|
+
Where:
|
|
261
|
+
|
|
262
|
+
* :math:`R` is the 3x3 rotation matrix created from the orientation quaternion of the input transform.
|
|
263
|
+
* :math:`p` is the 3D position vector :math:`[p_x, p_y, p_z]` of the input transform.
|
|
264
|
+
* :math:`s` is the 3D scale vector :math:`[s_x, s_y, s_z]` of the input transform.
|
|
265
|
+
|
|
266
|
+
Args:
|
|
267
|
+
position (Vector[3, Float]): The 3D position vector.
|
|
268
|
+
rotation (Quaternion[Float]): The quaternion orientation.
|
|
269
|
+
scale (Vector[3, Float]): The 3D scale vector.
|
|
270
|
+
|
|
271
|
+
Returns:
|
|
272
|
+
Matrix[4, 4, Float]: The transformation matrix.
|
|
273
|
+
"""
|
|
274
|
+
R = wp.quat_to_matrix(rotation)
|
|
275
|
+
# fmt: off
|
|
276
|
+
return mat44(
|
|
277
|
+
scale[0] * R[0,0], scale[1] * R[0,1], scale[2] * R[0,2], position[0],
|
|
278
|
+
scale[0] * R[1,0], scale[1] * R[1,1], scale[2] * R[1,2], position[1],
|
|
279
|
+
scale[0] * R[2,0], scale[1] * R[2,1], scale[2] * R[2,2], position[2],
|
|
280
|
+
dtype(0.0), dtype(0.0), dtype(0.0), dtype(1.0),
|
|
281
|
+
)
|
|
282
|
+
# fmt: on
|
|
283
|
+
|
|
284
|
+
return transform_compose
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
transform_compose = wp.func(
|
|
288
|
+
create_transform_compose_func(wp.float32),
|
|
289
|
+
name="transform_compose",
|
|
290
|
+
)
|
|
291
|
+
wp.func(
|
|
292
|
+
create_transform_compose_func(wp.float16),
|
|
293
|
+
name="transform_compose",
|
|
294
|
+
)
|
|
295
|
+
wp.func(
|
|
296
|
+
create_transform_compose_func(wp.float64),
|
|
297
|
+
name="transform_compose",
|
|
298
|
+
)
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def create_transform_decompose_func(dtype):
|
|
302
|
+
mat44 = wp.types.matrix((4, 4), dtype)
|
|
303
|
+
vec3 = wp.types.vector(3, dtype)
|
|
304
|
+
mat33 = wp.types.matrix((3, 3), dtype)
|
|
305
|
+
zero = dtype(0.0)
|
|
306
|
+
|
|
307
|
+
def transform_decompose(m: mat44):
|
|
308
|
+
"""
|
|
309
|
+
Decompose a 4x4 transformation matrix into 3D position, quaternion orientation, and 3D scale.
|
|
310
|
+
|
|
311
|
+
.. math::
|
|
312
|
+
M = \\begin{bmatrix}
|
|
313
|
+
s_x R_{00} & s_y R_{01} & s_z R_{02} & p_x \\\\
|
|
314
|
+
s_x R_{10} & s_y R_{11} & s_z R_{12} & p_y \\\\
|
|
315
|
+
s_x R_{20} & s_y R_{21} & s_z R_{22} & p_z \\\\
|
|
316
|
+
0 & 0 & 0 & 1
|
|
317
|
+
\\end{bmatrix}
|
|
318
|
+
|
|
319
|
+
Where:
|
|
320
|
+
|
|
321
|
+
* :math:`R` is the 3x3 rotation matrix created from the orientation quaternion of the input transform.
|
|
322
|
+
* :math:`p` is the 3D position vector :math:`[p_x, p_y, p_z]` of the input transform.
|
|
323
|
+
* :math:`s` is the 3D scale vector :math:`[s_x, s_y, s_z]` of the input transform.
|
|
324
|
+
|
|
325
|
+
Args:
|
|
326
|
+
m (Matrix[4, 4, Float]): The matrix to decompose.
|
|
327
|
+
|
|
328
|
+
Returns:
|
|
329
|
+
Tuple[Vector[3, Float], Quaternion[Float], Vector[3, Float]]: A tuple containing the position vector, quaternion orientation, and scale vector.
|
|
330
|
+
"""
|
|
331
|
+
# extract position
|
|
332
|
+
position = vec3(m[0, 3], m[1, 3], m[2, 3])
|
|
333
|
+
# extract rotation matrix components
|
|
334
|
+
r00, r01, r02 = m[0, 0], m[0, 1], m[0, 2]
|
|
335
|
+
r10, r11, r12 = m[1, 0], m[1, 1], m[1, 2]
|
|
336
|
+
r20, r21, r22 = m[2, 0], m[2, 1], m[2, 2]
|
|
337
|
+
# get scale magnitudes
|
|
338
|
+
sx = wp.sqrt(r00 * r00 + r10 * r10 + r20 * r20)
|
|
339
|
+
sy = wp.sqrt(r01 * r01 + r11 * r11 + r21 * r21)
|
|
340
|
+
sz = wp.sqrt(r02 * r02 + r12 * r12 + r22 * r22)
|
|
341
|
+
# normalize rotation matrix components
|
|
342
|
+
if sx != zero:
|
|
343
|
+
r00 /= sx
|
|
344
|
+
r10 /= sx
|
|
345
|
+
r20 /= sx
|
|
346
|
+
if sy != zero:
|
|
347
|
+
r01 /= sy
|
|
348
|
+
r11 /= sy
|
|
349
|
+
r21 /= sy
|
|
350
|
+
if sz != zero:
|
|
351
|
+
r02 /= sz
|
|
352
|
+
r12 /= sz
|
|
353
|
+
r22 /= sz
|
|
354
|
+
# extract rotation (quaternion)
|
|
355
|
+
rotation = wp.quat_from_matrix(mat33(r00, r01, r02, r10, r11, r12, r20, r21, r22))
|
|
356
|
+
# extract scale
|
|
357
|
+
scale = vec3(sx, sy, sz)
|
|
358
|
+
return position, rotation, scale
|
|
359
|
+
|
|
360
|
+
return transform_decompose
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
transform_decompose = wp.func(
|
|
364
|
+
create_transform_decompose_func(wp.float32),
|
|
365
|
+
name="transform_decompose",
|
|
366
|
+
)
|
|
367
|
+
wp.func(
|
|
368
|
+
create_transform_decompose_func(wp.float16),
|
|
369
|
+
name="transform_decompose",
|
|
370
|
+
)
|
|
371
|
+
wp.func(
|
|
372
|
+
create_transform_decompose_func(wp.float64),
|
|
373
|
+
name="transform_decompose",
|
|
374
|
+
)
|
|
375
|
+
|
|
376
|
+
|
|
215
377
|
# register API functions so they appear in the documentation
|
|
216
378
|
|
|
217
379
|
wp.context.register_api_function(
|
|
@@ -242,3 +404,11 @@ wp.context.register_api_function(
|
|
|
242
404
|
transform_to_matrix,
|
|
243
405
|
group="Transformations",
|
|
244
406
|
)
|
|
407
|
+
wp.context.register_api_function(
|
|
408
|
+
transform_compose,
|
|
409
|
+
group="Transformations",
|
|
410
|
+
)
|
|
411
|
+
wp.context.register_api_function(
|
|
412
|
+
transform_decompose,
|
|
413
|
+
group="Transformations",
|
|
414
|
+
)
|
warp/native/array.h
CHANGED
|
@@ -161,7 +161,7 @@ inline CUDA_CALLABLE void print(shape_t s)
|
|
|
161
161
|
// should probably store ndim with shape
|
|
162
162
|
printf("(%d, %d, %d, %d)\n", s.dims[0], s.dims[1], s.dims[2], s.dims[3]);
|
|
163
163
|
}
|
|
164
|
-
inline CUDA_CALLABLE void adj_print(shape_t s, shape_t&
|
|
164
|
+
inline CUDA_CALLABLE void adj_print(shape_t s, shape_t& adj_s) {}
|
|
165
165
|
|
|
166
166
|
|
|
167
167
|
template <typename T>
|
|
@@ -665,11 +665,11 @@ CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i, int j
|
|
|
665
665
|
}
|
|
666
666
|
|
|
667
667
|
template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
|
|
668
|
-
inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, A2<T>& adj_src, int adj_i, A3<T
|
|
668
|
+
inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, A2<T>& adj_src, int adj_i, A3<T>& adj_ret) {}
|
|
669
669
|
template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
|
|
670
|
-
inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, A2<T>& adj_src, int adj_i, int adj_j, A3<T
|
|
670
|
+
inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, A2<T>& adj_src, int adj_i, int adj_j, A3<T>& adj_ret) {}
|
|
671
671
|
template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
|
|
672
|
-
inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, int k, A2<T>& adj_src, int adj_i, int adj_j, int adj_k, A3<T
|
|
672
|
+
inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, int k, A2<T>& adj_src, int adj_i, int adj_j, int adj_k, A3<T>& adj_ret) {}
|
|
673
673
|
|
|
674
674
|
// TODO: lower_bound() for indexed arrays?
|
|
675
675
|
|
|
@@ -743,6 +743,24 @@ inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, int k, T value)
|
|
|
743
743
|
template<template<typename> class A, typename T>
|
|
744
744
|
inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_max(&index(buf, i, j, k, l), value); }
|
|
745
745
|
|
|
746
|
+
template<template<typename> class A, typename T>
|
|
747
|
+
inline CUDA_CALLABLE T atomic_cas(const A<T>& buf, int i, T old_value, T new_value) { return atomic_cas(&index(buf, i), old_value, new_value); }
|
|
748
|
+
template<template<typename> class A, typename T>
|
|
749
|
+
inline CUDA_CALLABLE T atomic_cas(const A<T>& buf, int i, int j, T old_value, T new_value) { return atomic_cas(&index(buf, i, j), old_value, new_value); }
|
|
750
|
+
template<template<typename> class A, typename T>
|
|
751
|
+
inline CUDA_CALLABLE T atomic_cas(const A<T>& buf, int i, int j, int k, T old_value, T new_value) { return atomic_cas(&index(buf, i, j, k), old_value, new_value); }
|
|
752
|
+
template<template<typename> class A, typename T>
|
|
753
|
+
inline CUDA_CALLABLE T atomic_cas(const A<T>& buf, int i, int j, int k, int l, T old_value, T new_value) { return atomic_cas(&index(buf, i, j, k, l), old_value, new_value); }
|
|
754
|
+
|
|
755
|
+
template<template<typename> class A, typename T>
|
|
756
|
+
inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, T value) { return atomic_exch(&index(buf, i), value); }
|
|
757
|
+
template<template<typename> class A, typename T>
|
|
758
|
+
inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, int j, T value) { return atomic_exch(&index(buf, i, j), value); }
|
|
759
|
+
template<template<typename> class A, typename T>
|
|
760
|
+
inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, int j, int k, T value) { return atomic_exch(&index(buf, i, j, k), value); }
|
|
761
|
+
template<template<typename> class A, typename T>
|
|
762
|
+
inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_exch(&index(buf, i, j, k, l), value); }
|
|
763
|
+
|
|
746
764
|
template<template<typename> class A, typename T>
|
|
747
765
|
inline CUDA_CALLABLE T* address(const A<T>& buf, int i) { return &index(buf, i); }
|
|
748
766
|
template<template<typename> class A, typename T>
|
|
@@ -1128,6 +1146,87 @@ inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k,
|
|
|
1128
1146
|
FP_VERIFY_ADJ_4(value, adj_value)
|
|
1129
1147
|
}
|
|
1130
1148
|
|
|
1149
|
+
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
1150
|
+
inline CUDA_CALLABLE void adj_atomic_cas(const A1<T>& buf, int i, T compare, T value, const A2<T>& adj_buf, int adj_i, T& adj_compare, T& adj_value, const T& adj_ret) {
|
|
1151
|
+
if (adj_buf.data)
|
|
1152
|
+
adj_atomic_cas(&index(buf, i), compare, value, &index(adj_buf, i), adj_compare, adj_value, adj_ret);
|
|
1153
|
+
else if (buf.grad)
|
|
1154
|
+
adj_atomic_cas(&index(buf, i), compare, value, &index_grad(buf, i), adj_compare, adj_value, adj_ret);
|
|
1155
|
+
|
|
1156
|
+
FP_VERIFY_ADJ_1(value, adj_value)
|
|
1157
|
+
}
|
|
1158
|
+
|
|
1159
|
+
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
1160
|
+
inline CUDA_CALLABLE void adj_atomic_cas(const A1<T>& buf, int i, int j, T compare, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_compare, T& adj_value, const T& adj_ret) {
|
|
1161
|
+
if (adj_buf.data)
|
|
1162
|
+
adj_atomic_cas(&index(buf, i, j), compare, value, &index(adj_buf, i, j), adj_compare, adj_value, adj_ret);
|
|
1163
|
+
else if (buf.grad)
|
|
1164
|
+
adj_atomic_cas(&index(buf, i, j), compare, value, &index_grad(buf, i, j), adj_compare, adj_value, adj_ret);
|
|
1165
|
+
|
|
1166
|
+
FP_VERIFY_ADJ_2(value, adj_value)
|
|
1167
|
+
}
|
|
1168
|
+
|
|
1169
|
+
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
1170
|
+
inline CUDA_CALLABLE void adj_atomic_cas(const A1<T>& buf, int i, int j, int k, T compare, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_compare, T& adj_value, const T& adj_ret) {
|
|
1171
|
+
if (adj_buf.data)
|
|
1172
|
+
adj_atomic_cas(&index(buf, i, j, k), compare, value, &index(adj_buf, i, j, k), adj_compare, adj_value, adj_ret);
|
|
1173
|
+
else if (buf.grad)
|
|
1174
|
+
adj_atomic_cas(&index(buf, i, j, k), compare, value, &index_grad(buf, i, j, k), adj_compare, adj_value, adj_ret);
|
|
1175
|
+
|
|
1176
|
+
FP_VERIFY_ADJ_3(value, adj_value)
|
|
1177
|
+
}
|
|
1178
|
+
|
|
1179
|
+
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
1180
|
+
inline CUDA_CALLABLE void adj_atomic_cas(const A1<T>& buf, int i, int j, int k, int l, T compare, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_compare, T& adj_value, const T& adj_ret) {
|
|
1181
|
+
if (adj_buf.data)
|
|
1182
|
+
adj_atomic_cas(&index(buf, i, j, k, l), compare, value, &index(adj_buf, i, j, k, l), adj_compare, adj_value, adj_ret);
|
|
1183
|
+
else if (buf.grad)
|
|
1184
|
+
adj_atomic_cas(&index(buf, i, j, k, l), compare, value, &index_grad(buf, i, j, k, l), adj_compare, adj_value, adj_ret);
|
|
1185
|
+
|
|
1186
|
+
FP_VERIFY_ADJ_4(value, adj_value)
|
|
1187
|
+
}
|
|
1188
|
+
|
|
1189
|
+
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
1190
|
+
inline CUDA_CALLABLE void adj_atomic_exch(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {
|
|
1191
|
+
if (adj_buf.data)
|
|
1192
|
+
adj_atomic_exch(&index(buf, i), value, &index(adj_buf, i), adj_value, adj_ret);
|
|
1193
|
+
else if (buf.grad)
|
|
1194
|
+
adj_atomic_exch(&index(buf, i), value, &index_grad(buf, i), adj_value, adj_ret);
|
|
1195
|
+
|
|
1196
|
+
FP_VERIFY_ADJ_1(value, adj_value)
|
|
1197
|
+
}
|
|
1198
|
+
|
|
1199
|
+
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
1200
|
+
inline CUDA_CALLABLE void adj_atomic_exch(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {
|
|
1201
|
+
if (adj_buf.data)
|
|
1202
|
+
adj_atomic_exch(&index(buf, i, j), value, &index(adj_buf, i, j), adj_value, adj_ret);
|
|
1203
|
+
else if (buf.grad)
|
|
1204
|
+
adj_atomic_exch(&index(buf, i, j), value, &index_grad(buf, i, j), adj_value, adj_ret);
|
|
1205
|
+
|
|
1206
|
+
FP_VERIFY_ADJ_2(value, adj_value)
|
|
1207
|
+
}
|
|
1208
|
+
|
|
1209
|
+
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
1210
|
+
inline CUDA_CALLABLE void adj_atomic_exch(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {
|
|
1211
|
+
if (adj_buf.data)
|
|
1212
|
+
adj_atomic_exch(&index(buf, i, j, k), value, &index(adj_buf, i, j, k), adj_value, adj_ret);
|
|
1213
|
+
else if (buf.grad)
|
|
1214
|
+
adj_atomic_exch(&index(buf, i, j, k), value, &index_grad(buf, i, j, k), adj_value, adj_ret);
|
|
1215
|
+
|
|
1216
|
+
FP_VERIFY_ADJ_3(value, adj_value)
|
|
1217
|
+
}
|
|
1218
|
+
|
|
1219
|
+
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
1220
|
+
inline CUDA_CALLABLE void adj_atomic_exch(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {
|
|
1221
|
+
if (adj_buf.data)
|
|
1222
|
+
adj_atomic_exch(&index(buf, i, j, k, l), value, &index(adj_buf, i, j, k, l), adj_value, adj_ret);
|
|
1223
|
+
else if (buf.grad)
|
|
1224
|
+
adj_atomic_exch(&index(buf, i, j, k, l), value, &index_grad(buf, i, j, k, l), adj_value, adj_ret);
|
|
1225
|
+
|
|
1226
|
+
FP_VERIFY_ADJ_4(value, adj_value)
|
|
1227
|
+
}
|
|
1228
|
+
|
|
1229
|
+
|
|
1131
1230
|
template<template<typename> class A, typename T>
|
|
1132
1231
|
CUDA_CALLABLE inline int len(const A<T>& a)
|
|
1133
1232
|
{
|