warp-lang 0.11.0__py3-none-manylinux2014_x86_64.whl → 1.0.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (170)
  1. warp/__init__.py +8 -0
  2. warp/bin/warp-clang.so +0 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +7 -6
  5. warp/build_dll.py +70 -79
  6. warp/builtins.py +10 -6
  7. warp/codegen.py +51 -19
  8. warp/config.py +7 -8
  9. warp/constants.py +3 -0
  10. warp/context.py +948 -245
  11. warp/dlpack.py +198 -113
  12. warp/examples/assets/bunny.usd +0 -0
  13. warp/examples/assets/cartpole.urdf +110 -0
  14. warp/examples/assets/crazyflie.usd +0 -0
  15. warp/examples/assets/cube.usda +42 -0
  16. warp/examples/assets/nv_ant.xml +92 -0
  17. warp/examples/assets/nv_humanoid.xml +183 -0
  18. warp/examples/assets/quadruped.urdf +268 -0
  19. warp/examples/assets/rocks.nvdb +0 -0
  20. warp/examples/assets/rocks.usd +0 -0
  21. warp/examples/assets/sphere.usda +56 -0
  22. warp/examples/assets/torus.usda +105 -0
  23. warp/examples/benchmarks/benchmark_api.py +383 -0
  24. warp/examples/benchmarks/benchmark_cloth.py +279 -0
  25. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
  26. warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
  27. warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
  28. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
  29. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
  30. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
  31. warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
  32. warp/examples/benchmarks/benchmark_launches.py +295 -0
  33. warp/examples/core/example_dem.py +221 -0
  34. warp/examples/core/example_fluid.py +267 -0
  35. warp/examples/core/example_graph_capture.py +129 -0
  36. warp/examples/core/example_marching_cubes.py +177 -0
  37. warp/examples/core/example_mesh.py +154 -0
  38. warp/examples/core/example_mesh_intersect.py +193 -0
  39. warp/examples/core/example_nvdb.py +169 -0
  40. warp/examples/core/example_raycast.py +89 -0
  41. warp/examples/core/example_raymarch.py +178 -0
  42. warp/examples/core/example_render_opengl.py +141 -0
  43. warp/examples/core/example_sph.py +389 -0
  44. warp/examples/core/example_torch.py +181 -0
  45. warp/examples/core/example_wave.py +249 -0
  46. warp/examples/fem/bsr_utils.py +380 -0
  47. warp/examples/fem/example_apic_fluid.py +391 -0
  48. warp/examples/fem/example_convection_diffusion.py +168 -0
  49. warp/examples/fem/example_convection_diffusion_dg.py +209 -0
  50. warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
  51. warp/examples/fem/example_deformed_geometry.py +159 -0
  52. warp/examples/fem/example_diffusion.py +173 -0
  53. warp/examples/fem/example_diffusion_3d.py +152 -0
  54. warp/examples/fem/example_diffusion_mgpu.py +214 -0
  55. warp/examples/fem/example_mixed_elasticity.py +222 -0
  56. warp/examples/fem/example_navier_stokes.py +243 -0
  57. warp/examples/fem/example_stokes.py +192 -0
  58. warp/examples/fem/example_stokes_transfer.py +249 -0
  59. warp/examples/fem/mesh_utils.py +109 -0
  60. warp/examples/fem/plot_utils.py +287 -0
  61. warp/examples/optim/example_bounce.py +248 -0
  62. warp/examples/optim/example_cloth_throw.py +210 -0
  63. warp/examples/optim/example_diffray.py +535 -0
  64. warp/examples/optim/example_drone.py +850 -0
  65. warp/examples/optim/example_inverse_kinematics.py +169 -0
  66. warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
  67. warp/examples/optim/example_spring_cage.py +234 -0
  68. warp/examples/optim/example_trajectory.py +201 -0
  69. warp/examples/sim/example_cartpole.py +128 -0
  70. warp/examples/sim/example_cloth.py +184 -0
  71. warp/examples/sim/example_granular.py +113 -0
  72. warp/examples/sim/example_granular_collision_sdf.py +185 -0
  73. warp/examples/sim/example_jacobian_ik.py +213 -0
  74. warp/examples/sim/example_particle_chain.py +106 -0
  75. warp/examples/sim/example_quadruped.py +179 -0
  76. warp/examples/sim/example_rigid_chain.py +191 -0
  77. warp/examples/sim/example_rigid_contact.py +176 -0
  78. warp/examples/sim/example_rigid_force.py +126 -0
  79. warp/examples/sim/example_rigid_gyroscopic.py +97 -0
  80. warp/examples/sim/example_rigid_soft_contact.py +124 -0
  81. warp/examples/sim/example_soft_body.py +178 -0
  82. warp/fabric.py +29 -20
  83. warp/fem/cache.py +0 -1
  84. warp/fem/dirichlet.py +0 -2
  85. warp/fem/integrate.py +0 -1
  86. warp/jax.py +45 -0
  87. warp/jax_experimental.py +339 -0
  88. warp/native/builtin.h +12 -0
  89. warp/native/bvh.cu +18 -18
  90. warp/native/clang/clang.cpp +8 -3
  91. warp/native/cuda_util.cpp +94 -5
  92. warp/native/cuda_util.h +35 -6
  93. warp/native/cutlass_gemm.cpp +1 -1
  94. warp/native/cutlass_gemm.cu +4 -1
  95. warp/native/error.cpp +66 -0
  96. warp/native/error.h +27 -0
  97. warp/native/mesh.cu +2 -2
  98. warp/native/reduce.cu +4 -4
  99. warp/native/runlength_encode.cu +2 -2
  100. warp/native/scan.cu +2 -2
  101. warp/native/sparse.cu +0 -1
  102. warp/native/temp_buffer.h +2 -2
  103. warp/native/warp.cpp +95 -60
  104. warp/native/warp.cu +1053 -218
  105. warp/native/warp.h +49 -32
  106. warp/optim/linear.py +33 -16
  107. warp/render/render_opengl.py +202 -101
  108. warp/render/render_usd.py +82 -40
  109. warp/sim/__init__.py +13 -4
  110. warp/sim/articulation.py +4 -5
  111. warp/sim/collide.py +320 -175
  112. warp/sim/import_mjcf.py +25 -30
  113. warp/sim/import_urdf.py +94 -63
  114. warp/sim/import_usd.py +51 -36
  115. warp/sim/inertia.py +3 -2
  116. warp/sim/integrator.py +233 -0
  117. warp/sim/integrator_euler.py +447 -469
  118. warp/sim/integrator_featherstone.py +1991 -0
  119. warp/sim/integrator_xpbd.py +1420 -640
  120. warp/sim/model.py +765 -487
  121. warp/sim/particles.py +2 -1
  122. warp/sim/render.py +35 -13
  123. warp/sim/utils.py +222 -11
  124. warp/stubs.py +8 -0
  125. warp/tape.py +16 -1
  126. warp/tests/aux_test_grad_customs.py +23 -0
  127. warp/tests/test_array.py +190 -1
  128. warp/tests/test_async.py +656 -0
  129. warp/tests/test_bool.py +50 -0
  130. warp/tests/test_dlpack.py +164 -11
  131. warp/tests/test_examples.py +166 -74
  132. warp/tests/test_fem.py +8 -1
  133. warp/tests/test_generics.py +15 -5
  134. warp/tests/test_grad.py +1 -1
  135. warp/tests/test_grad_customs.py +172 -12
  136. warp/tests/test_jax.py +254 -0
  137. warp/tests/test_large.py +29 -6
  138. warp/tests/test_launch.py +25 -0
  139. warp/tests/test_linear_solvers.py +20 -3
  140. warp/tests/test_matmul.py +61 -16
  141. warp/tests/test_matmul_lite.py +13 -13
  142. warp/tests/test_mempool.py +186 -0
  143. warp/tests/test_multigpu.py +3 -0
  144. warp/tests/test_options.py +16 -2
  145. warp/tests/test_peer.py +137 -0
  146. warp/tests/test_print.py +3 -1
  147. warp/tests/test_quat.py +23 -0
  148. warp/tests/test_sim_kinematics.py +97 -0
  149. warp/tests/test_snippet.py +126 -3
  150. warp/tests/test_streams.py +108 -79
  151. warp/tests/test_torch.py +16 -8
  152. warp/tests/test_utils.py +32 -27
  153. warp/tests/test_verify_fp.py +65 -0
  154. warp/tests/test_volume.py +1 -1
  155. warp/tests/unittest_serial.py +2 -0
  156. warp/tests/unittest_suites.py +12 -0
  157. warp/tests/unittest_utils.py +14 -7
  158. warp/thirdparty/unittest_parallel.py +15 -3
  159. warp/torch.py +10 -8
  160. warp/types.py +363 -246
  161. warp/utils.py +143 -19
  162. warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
  163. warp_lang-1.0.0.dist-info/METADATA +394 -0
  164. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
  165. warp/sim/optimizer.py +0 -138
  166. warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
  167. warp_lang-0.11.0.dist-info/METADATA +0 -238
  168. /warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
  169. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
  170. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,124 @@
1
+ # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ ###########################################################################
9
+ # Example Sim Rigid FEM
10
+ #
11
+ # Shows how to set up a rigid sphere colliding with an FEM beam
12
+ # using wp.sim.ModelBuilder().
13
+ #
14
+ ###########################################################################
15
+
16
+ import os
17
+
18
+ import warp as wp
19
+ import warp.sim
20
+ import warp.sim.render
21
+
22
+ wp.init()
23
+
24
+
25
class Example:
    """Rigid sphere dropped onto a soft FEM grid, advanced with the semi-implicit integrator.

    Args:
        stage: Path of the USD stage to write to; when falsy, no renderer is created.
    """

    def __init__(self, stage):
        # kept as attributes; not read by the simulation code in this class
        self.sim_width = 8
        self.sim_height = 8

        # timing: each rendered frame is split into sim_substeps integrator steps
        self.sim_fps = 60.0
        self.sim_substeps = 32
        self.sim_duration = 5.0
        self.frame_dt = 1.0 / self.sim_fps
        self.sim_dt = self.frame_dt / self.sim_substeps
        self.sim_frames = int(self.sim_duration * self.sim_fps)
        self.sim_time = 0.0
        self.sim_iterations = 1
        self.sim_relaxation = 1.0
        self.profiler = {}

        builder = wp.sim.ModelBuilder()
        builder.default_particle_radius = 0.01

        # soft FEM beam: 20 x 10 x 10 cells of size 0.1, anchored at the origin
        builder.add_soft_grid(
            pos=wp.vec3(0.0, 0.0, 0.0),
            rot=wp.quat_identity(),
            vel=wp.vec3(0.0, 0.0, 0.0),
            dim_x=20,
            dim_y=10,
            dim_z=10,
            cell_x=0.1,
            cell_y=0.1,
            cell_z=0.1,
            density=100.0,
            k_mu=50000.0,
            k_lambda=20000.0,
            k_damp=0.0,
        )

        # rigid sphere starting above the beam
        sphere_body = builder.add_body(origin=wp.transform((0.5, 2.5, 0.5), wp.quat_identity()))
        builder.add_shape_sphere(body=sphere_body, radius=0.75, density=100.0)

        model = builder.finalize()
        model.ground = True
        model.soft_contact_ke = 1.0e3
        model.soft_contact_kd = 0.0
        model.soft_contact_kf = 1.0e3
        self.model = model

        self.integrator = wp.sim.SemiImplicitIntegrator()

        # ping-pong state pair: state_0 is the input of a step, state_1 its output
        self.state_0 = self.model.state()
        self.state_1 = self.model.state()

        self.renderer = wp.sim.render.SimRenderer(self.model, stage, scaling=1.0) if stage else None

        # on CUDA devices, one frame's worth of substeps is recorded once and replayed in step()
        self.use_graph = wp.get_device().is_cuda
        if self.use_graph:
            with wp.ScopedCapture() as capture:
                self.simulate()
            self.graph = capture.graph

    def simulate(self):
        """Advance the simulation by one frame: sim_substeps steps of sim_dt each."""
        model = self.model
        for _ in range(self.sim_substeps):
            wp.sim.collide(model, self.state_0)

            self.state_0.clear_forces()
            self.state_1.clear_forces()

            self.integrator.simulate(model, self.state_0, self.state_1, self.sim_dt)

            # NOTE(review): this swap only rebinds Python references. When it runs under graph
            # capture, an even sim_substeps leaves state_0 naming the same buffer before and
            # after a replay. Confirm this still holds if sim_substeps is changed.
            self.state_0, self.state_1 = self.state_1, self.state_0

    def step(self):
        """Advance one frame (graph replay on CUDA, direct simulate otherwise) and bump sim_time."""
        with wp.ScopedTimer("step", dict=self.profiler):
            if not self.use_graph:
                self.simulate()
            else:
                wp.capture_launch(self.graph)
        self.sim_time += self.frame_dt

    def render(self):
        """Emit the current state_0 as a frame at sim_time; no-op when there is no renderer."""
        if self.renderer is not None:
            with wp.ScopedTimer("render", active=True):
                self.renderer.begin_frame(self.sim_time)
                self.renderer.render(self.state_0)
                self.renderer.end_frame()
112
+
113
+
114
if __name__ == "__main__":
    # write the USD stage next to this script
    script_dir = os.path.dirname(__file__)
    stage_path = os.path.join(script_dir, "example_rigid_soft_contact.usd")

    example = Example(stage_path)

    for _frame in range(example.sim_frames):
        example.step()
        example.render()

    # the renderer only exists when a stage path was supplied
    if example.renderer:
        example.renderer.save()
@@ -0,0 +1,178 @@
1
+ # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ ###########################################################################
9
+ # Example Sim Neo-Hookean
10
+ #
11
+ # Shows a simulation of a Neo-Hookean FEM beam being twisted through a
12
+ # 180 degree rotation.
13
+ #
14
+ ###########################################################################
15
+
16
+ import math
17
+ import os
18
+
19
+ import warp as wp
20
+ import warp.sim
21
+ import warp.sim.render
22
+
23
+ wp.init()
24
+
25
+
26
@wp.kernel
def twist_points(
    rest: wp.array(dtype=wp.vec3), points: wp.array(dtype=wp.vec3), mass: wp.array(dtype=float), xform: wp.transform
):
    """Move selected particles to their rest position transformed by ``xform``.

    A particle is moved when its mass is zero and its current y coordinate is non-zero.
    The original comment describes these as the top layer of the beam. Other particles
    are left untouched.
    """
    idx = wp.tid()

    rest_pos = rest[idx]
    cur_pos = points[idx]
    particle_mass = mass[idx]

    # twist the top layer of particles in the beam
    if particle_mass == 0 and cur_pos[1] != 0.0:
        points[idx] = wp.transform_point(xform, rest_pos)
39
+
40
+
41
@wp.kernel
def compute_volume(points: wp.array(dtype=wp.vec3), indices: wp.array2d(dtype=int), volume: wp.array(dtype=float)):
    """Atomically add the signed volume of tetrahedron ``wp.tid()`` into ``volume[0]``.

    Each row of ``indices`` holds the four particle indices of one tetrahedron.
    """
    tet = wp.tid()

    p0 = points[indices[tet, 0]]
    p1 = points[indices[tet, 1]]
    p2 = points[indices[tet, 2]]
    p3 = points[indices[tet, 3]]

    # edge vectors from the first vertex
    e1 = p1 - p0
    e2 = p2 - p0
    e3 = p3 - p0

    # scalar triple product / 6 gives the signed tetrahedron volume
    signed_vol = wp.dot(e1, wp.cross(e2, e3)) / 6.0

    wp.atomic_add(volume, 0, signed_vol)
62
+
63
+
64
class Example:
    """Neo-Hookean FEM cube whose fixed top layer is lifted and twisted every frame.

    Args:
        stage: Path of the USD stage to write to; when falsy, no renderer is created.
    """

    def __init__(self, stage):
        fps = 60.0
        duration = 5.0

        self.sim_substeps = 64
        self.sim_frames = int(duration * fps)
        self.frame_dt = 1.0 / fps
        self.sim_dt = self.frame_dt / self.sim_substeps
        self.sim_time = 0.0
        # kinematic drive for the top layer: rise along +y, rotate pi radians about +y over the run
        self.lift_speed = 2.5 / duration * 2.0  # from Smith et al.
        self.rot_speed = math.pi / duration

        # cells^3 grid, spacing * cells = 2 units wide, centred on the y axis with its base at y = 0
        cells = 15
        spacing = 2.0 / cells
        half_width = spacing * cells * 0.5

        builder = wp.sim.ModelBuilder()
        builder.add_soft_grid(
            pos=wp.vec3(-half_width, 0.0, -half_width),
            rot=wp.quat_identity(),
            vel=wp.vec3(0.0, 0.0, 0.0),
            dim_x=cells,
            dim_y=cells,
            dim_z=cells,
            cell_x=spacing,
            cell_y=spacing,
            cell_z=spacing,
            density=100.0,
            fix_bottom=True,
            fix_top=True,
            k_mu=1000.0,
            k_lambda=5000.0,
            k_damp=0.0,
        )

        model = builder.finalize()
        model.ground = False
        model.gravity[1] = 0.0
        self.model = model

        self.integrator = wp.sim.SemiImplicitIntegrator()

        # only read, never integrated: holds the initial particle positions used as the twist reference
        self.rest = model.state()
        self.rest_vol = (spacing * cells) ** 3

        # ping-pong state pair: state_0 is the input of a step, state_1 its output
        self.state_0 = model.state()
        self.state_1 = model.state()

        # one-element accumulator written by compute_volume in step()
        self.volume = wp.zeros(1, dtype=wp.float32)

        self.renderer = wp.sim.render.SimRenderer(model, stage, scaling=20.0) if stage else None

        # on CUDA devices, one frame's worth of substeps is recorded once and replayed in step()
        self.use_graph = wp.get_device().is_cuda
        if self.use_graph:
            with wp.ScopedCapture() as capture:
                self.simulate()
            self.graph = capture.graph

    def simulate(self):
        """Advance the simulation by one frame: sim_substeps steps of sim_dt each."""
        model = self.model
        for _ in range(self.sim_substeps):
            self.state_0.clear_forces()
            self.state_1.clear_forces()

            self.integrator.simulate(model, self.state_0, self.state_1, self.sim_dt)

            # NOTE(review): this swap only rebinds Python references. Under graph capture, an even
            # sim_substeps leaves state_0 naming the same buffer that step() writes with
            # twist_points before each replay. Confirm this still holds if sim_substeps is changed.
            self.state_0, self.state_1 = self.state_1, self.state_0

    def step(self):
        """Drive the top layer, advance one frame, recompute the mesh volume, and bump sim_time."""
        with wp.ScopedTimer("step"):
            height = self.lift_speed * self.sim_time
            angle = self.rot_speed * self.sim_time
            xform = wp.transform(
                (0.0, height, 0.0),
                wp.quat_from_axis_angle(wp.vec3(0.0, 1.0, 0.0), angle),
            )

            # place the kinematic particles of state_0 before integrating
            wp.launch(
                kernel=twist_points,
                dim=len(self.state_0.particle_q),
                inputs=[self.rest.particle_q, self.state_0.particle_q, self.model.particle_mass, xform],
            )

            if not self.use_graph:
                self.simulate()
            else:
                wp.capture_launch(self.graph)

            # total signed volume of the deformed tet mesh ends up in self.volume[0]
            self.volume.zero_()
            wp.launch(
                kernel=compute_volume,
                dim=self.model.tet_count,
                inputs=[self.state_0.particle_q, self.model.tet_indices, self.volume],
            )
        self.sim_time += self.frame_dt

    def render(self):
        """Emit the current state_0 as a frame at sim_time; no-op when there is no renderer."""
        if self.renderer is not None:
            with wp.ScopedTimer("render"):
                self.renderer.begin_frame(self.sim_time)
                self.renderer.render(self.state_0)
                self.renderer.end_frame()
166
+
167
+
168
if __name__ == "__main__":
    # write the USD stage next to this script
    script_dir = os.path.dirname(__file__)
    stage_path = os.path.join(script_dir, "example_soft_body.usd")

    example = Example(stage_path)

    for _frame in range(example.sim_frames):
        example.step()
        example.render()

    # the renderer only exists when a stage path was supplied
    if example.renderer:
        example.renderer.save()
warp/fabric.py CHANGED
@@ -108,6 +108,8 @@ class fabricarray(noncontiguous_array_base[T]):
108
108
  def __init__(self, data=None, attrib=None, dtype=Any, ndim=None):
109
109
  super().__init__(ARRAY_TYPE_FABRIC)
110
110
 
111
+ self.deleter = None
112
+
111
113
  if data is not None:
112
114
  from .context import runtime
113
115
 
@@ -174,24 +176,29 @@ class fabricarray(noncontiguous_array_base[T]):
174
176
 
175
177
  num_buckets = len(pointers)
176
178
  size = 0
177
-
178
179
  buckets = (fabricbucket_t * num_buckets)()
179
- for i in range(num_buckets):
180
- buckets[i].index_start = size
181
- buckets[i].index_end = size + counts[i]
182
- buckets[i].ptr = pointers[i]
183
- if array_lengths:
184
- buckets[i].lengths = array_lengths[i]
185
- size += counts[i]
186
-
187
- if self.device.is_cuda:
188
- # copy bucket info to device
189
- with warp.ScopedStream(self.device.null_stream):
180
+
181
+ if num_buckets > 0:
182
+ for i in range(num_buckets):
183
+ buckets[i].index_start = size
184
+ buckets[i].index_end = size + counts[i]
185
+ buckets[i].ptr = pointers[i]
186
+ if array_lengths:
187
+ buckets[i].lengths = array_lengths[i]
188
+ size += counts[i]
189
+
190
+ if self.device.is_cuda:
191
+ # copy bucket info to device
190
192
  buckets_size = ctypes.sizeof(buckets)
191
- buckets_ptr = self.device.allocator.alloc(buckets_size)
192
- runtime.core.memcpy_h2d(self.device.context, buckets_ptr, ctypes.addressof(buckets), buckets_size)
193
+ allocator = self.device.get_allocator()
194
+ buckets_ptr = allocator.alloc(buckets_size)
195
+ cuda_stream = self.device.stream.cuda_stream
196
+ runtime.core.memcpy_h2d(self.device.context, buckets_ptr, ctypes.addressof(buckets), buckets_size, cuda_stream)
197
+ self.deleter = allocator.deleter
198
+ else:
199
+ buckets_ptr = ctypes.addressof(buckets)
193
200
  else:
194
- buckets_ptr = ctypes.addressof(buckets)
201
+ buckets_ptr = None
195
202
 
196
203
  self.buckets = buckets
197
204
  self.size = size
@@ -211,11 +218,13 @@ class fabricarray(noncontiguous_array_base[T]):
211
218
  self.ctype = fabricarray_t()
212
219
 
213
220
  def __del__(self):
214
- # release the GPU copy of bucket info
215
- if self.buckets is not None and self.device.is_cuda:
216
- buckets_size = ctypes.sizeof(self.buckets)
217
- with self.device.context_guard:
218
- self.device.allocator.free(self.ctype.buckets, buckets_size)
221
+ # release the bucket info if needed
222
+ if self.deleter is None:
223
+ return
224
+
225
+ buckets_size = ctypes.sizeof(self.buckets)
226
+ with self.device.context_guard:
227
+ self.deleter(self.ctype.buckets, buckets_size)
219
228
 
220
229
  def __ctype__(self):
221
230
  return self.ctype
warp/fem/cache.py CHANGED
@@ -223,7 +223,6 @@ class Temporary:
223
223
  pinned=array.pinned,
224
224
  capacity=array.capacity,
225
225
  copy=False,
226
- owner=False,
227
226
  grad=None if array.grad is None else _view_reshaped_truncated(array.grad),
228
227
  )
229
228
 
warp/fem/dirichlet.py CHANGED
@@ -23,7 +23,6 @@ def normalize_dirichlet_projector(projector_matrix: BsrMatrix, fixed_value: Opti
23
23
  data=None,
24
24
  ptr=projector_values.ptr,
25
25
  capacity=projector_values.capacity,
26
- owner=False,
27
26
  device=projector_values.device,
28
27
  dtype=wp.mat(shape=projector_matrix.block_shape, dtype=projector_matrix.scalar_type),
29
28
  shape=projector_values.shape[0],
@@ -47,7 +46,6 @@ def normalize_dirichlet_projector(projector_matrix: BsrMatrix, fixed_value: Opti
47
46
  data=None,
48
47
  ptr=fixed_value.ptr,
49
48
  capacity=fixed_value.capacity,
50
- owner=False,
51
49
  device=fixed_value.device,
52
50
  dtype=wp.vec(length=projector_matrix.block_shape[0], dtype=projector_matrix.scalar_type),
53
51
  shape=fixed_value.shape[0],
warp/fem/integrate.py CHANGED
@@ -969,7 +969,6 @@ def _launch_integrate_kernel(
969
969
  data=None,
970
970
  ptr=array.ptr,
971
971
  capacity=array.capacity,
972
- owner=False,
973
972
  device=array.device,
974
973
  shape=(test.space_partition.node_count(), test.space.VALUE_DOF_COUNT),
975
974
  dtype=wp.types.type_scalar_type(output_dtype),
warp/jax.py CHANGED
@@ -6,6 +6,7 @@
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
8
  import warp
9
+ from warp.context import type_str
9
10
 
10
11
 
11
12
  def device_to_jax(wp_device):
@@ -34,6 +35,50 @@ def device_from_jax(jax_device):
34
35
  raise RuntimeError(f"Unknown or unsupported Jax device platform '{jax_device.platform}'")
35
36
 
36
37
 
38
def dtype_to_jax(wp_dtype):
    """Return the Jax scalar type corresponding to a Warp scalar type.

    Args:
        wp_dtype: A Warp scalar type such as ``warp.float32``.

    Returns:
        The matching ``jax.numpy`` scalar type.

    Raises:
        TypeError: If ``wp_dtype`` has no entry in the conversion table.
    """
    # imported lazily so that importing this module does not require Jax
    import jax.numpy as jp

    conversions = dict(
        (
            (warp.float16, jp.float16),
            (warp.float32, jp.float32),
            (warp.float64, jp.float64),
            (warp.int8, jp.int8),
            (warp.int16, jp.int16),
            (warp.int32, jp.int32),
            (warp.int64, jp.int64),
            (warp.uint8, jp.uint8),
            (warp.uint16, jp.uint16),
            (warp.uint32, jp.uint32),
            (warp.uint64, jp.uint64),
        )
    )

    if wp_dtype not in conversions:
        raise TypeError(f"Invalid or unsupported data type: {type_str(wp_dtype)}")
    return conversions[wp_dtype]
58
+
59
+
60
def dtype_from_jax(jax_dtype):
    """Return the Warp scalar type corresponding to a Jax data type.

    Args:
        jax_dtype: Either a Jax scalar type (e.g. ``jax.numpy.float32``) or a dtype
            object (e.g. ``jax_array.dtype``).

    Returns:
        The matching Warp scalar type.

    Raises:
        TypeError: If ``jax_dtype`` has no entry in the conversion table.
    """
    # imported lazily so that importing this module does not require Jax
    import jax.numpy as jp

    scalar_pairs = (
        (jp.float16, warp.float16),
        (jp.float32, warp.float32),
        (jp.float64, warp.float64),
        (jp.int8, warp.int8),
        (jp.int16, warp.int16),
        (jp.int32, warp.int32),
        (jp.int64, warp.int64),
        (jp.uint8, warp.uint8),
        (jp.uint16, warp.uint16),
        (jp.uint32, warp.uint32),
        (jp.uint64, warp.uint64),
    )

    jax_to_warp_dict = {}
    for jax_type, wp_type in scalar_pairs:
        # scalar type classes, e.g. jp.float32
        jax_to_warp_dict[jax_type] = wp_type
        # dtype objects, e.g. jax_array.dtype. They compare equal to the scalar type but do not
        # hash the same, so a table keyed only by scalar types misses them.
        jax_to_warp_dict[jp.dtype(jax_type)] = wp_type

    wp_dtype = jax_to_warp_dict.get(jax_dtype)
    if wp_dtype is None:
        raise TypeError(f"Invalid or unsupported data type: {jax_dtype}")
    return wp_dtype
80
+
81
+
37
82
  def to_jax(wp_array):
38
83
  import jax.dlpack
39
84