warp-lang 0.11.0__py3-none-manylinux2014_x86_64.whl → 1.0.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (170) hide show
  1. warp/__init__.py +8 -0
  2. warp/bin/warp-clang.so +0 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +7 -6
  5. warp/build_dll.py +70 -79
  6. warp/builtins.py +10 -6
  7. warp/codegen.py +51 -19
  8. warp/config.py +7 -8
  9. warp/constants.py +3 -0
  10. warp/context.py +948 -245
  11. warp/dlpack.py +198 -113
  12. warp/examples/assets/bunny.usd +0 -0
  13. warp/examples/assets/cartpole.urdf +110 -0
  14. warp/examples/assets/crazyflie.usd +0 -0
  15. warp/examples/assets/cube.usda +42 -0
  16. warp/examples/assets/nv_ant.xml +92 -0
  17. warp/examples/assets/nv_humanoid.xml +183 -0
  18. warp/examples/assets/quadruped.urdf +268 -0
  19. warp/examples/assets/rocks.nvdb +0 -0
  20. warp/examples/assets/rocks.usd +0 -0
  21. warp/examples/assets/sphere.usda +56 -0
  22. warp/examples/assets/torus.usda +105 -0
  23. warp/examples/benchmarks/benchmark_api.py +383 -0
  24. warp/examples/benchmarks/benchmark_cloth.py +279 -0
  25. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
  26. warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
  27. warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
  28. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
  29. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
  30. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
  31. warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
  32. warp/examples/benchmarks/benchmark_launches.py +295 -0
  33. warp/examples/core/example_dem.py +221 -0
  34. warp/examples/core/example_fluid.py +267 -0
  35. warp/examples/core/example_graph_capture.py +129 -0
  36. warp/examples/core/example_marching_cubes.py +177 -0
  37. warp/examples/core/example_mesh.py +154 -0
  38. warp/examples/core/example_mesh_intersect.py +193 -0
  39. warp/examples/core/example_nvdb.py +169 -0
  40. warp/examples/core/example_raycast.py +89 -0
  41. warp/examples/core/example_raymarch.py +178 -0
  42. warp/examples/core/example_render_opengl.py +141 -0
  43. warp/examples/core/example_sph.py +389 -0
  44. warp/examples/core/example_torch.py +181 -0
  45. warp/examples/core/example_wave.py +249 -0
  46. warp/examples/fem/bsr_utils.py +380 -0
  47. warp/examples/fem/example_apic_fluid.py +391 -0
  48. warp/examples/fem/example_convection_diffusion.py +168 -0
  49. warp/examples/fem/example_convection_diffusion_dg.py +209 -0
  50. warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
  51. warp/examples/fem/example_deformed_geometry.py +159 -0
  52. warp/examples/fem/example_diffusion.py +173 -0
  53. warp/examples/fem/example_diffusion_3d.py +152 -0
  54. warp/examples/fem/example_diffusion_mgpu.py +214 -0
  55. warp/examples/fem/example_mixed_elasticity.py +222 -0
  56. warp/examples/fem/example_navier_stokes.py +243 -0
  57. warp/examples/fem/example_stokes.py +192 -0
  58. warp/examples/fem/example_stokes_transfer.py +249 -0
  59. warp/examples/fem/mesh_utils.py +109 -0
  60. warp/examples/fem/plot_utils.py +287 -0
  61. warp/examples/optim/example_bounce.py +248 -0
  62. warp/examples/optim/example_cloth_throw.py +210 -0
  63. warp/examples/optim/example_diffray.py +535 -0
  64. warp/examples/optim/example_drone.py +850 -0
  65. warp/examples/optim/example_inverse_kinematics.py +169 -0
  66. warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
  67. warp/examples/optim/example_spring_cage.py +234 -0
  68. warp/examples/optim/example_trajectory.py +201 -0
  69. warp/examples/sim/example_cartpole.py +128 -0
  70. warp/examples/sim/example_cloth.py +184 -0
  71. warp/examples/sim/example_granular.py +113 -0
  72. warp/examples/sim/example_granular_collision_sdf.py +185 -0
  73. warp/examples/sim/example_jacobian_ik.py +213 -0
  74. warp/examples/sim/example_particle_chain.py +106 -0
  75. warp/examples/sim/example_quadruped.py +179 -0
  76. warp/examples/sim/example_rigid_chain.py +191 -0
  77. warp/examples/sim/example_rigid_contact.py +176 -0
  78. warp/examples/sim/example_rigid_force.py +126 -0
  79. warp/examples/sim/example_rigid_gyroscopic.py +97 -0
  80. warp/examples/sim/example_rigid_soft_contact.py +124 -0
  81. warp/examples/sim/example_soft_body.py +178 -0
  82. warp/fabric.py +29 -20
  83. warp/fem/cache.py +0 -1
  84. warp/fem/dirichlet.py +0 -2
  85. warp/fem/integrate.py +0 -1
  86. warp/jax.py +45 -0
  87. warp/jax_experimental.py +339 -0
  88. warp/native/builtin.h +12 -0
  89. warp/native/bvh.cu +18 -18
  90. warp/native/clang/clang.cpp +8 -3
  91. warp/native/cuda_util.cpp +94 -5
  92. warp/native/cuda_util.h +35 -6
  93. warp/native/cutlass_gemm.cpp +1 -1
  94. warp/native/cutlass_gemm.cu +4 -1
  95. warp/native/error.cpp +66 -0
  96. warp/native/error.h +27 -0
  97. warp/native/mesh.cu +2 -2
  98. warp/native/reduce.cu +4 -4
  99. warp/native/runlength_encode.cu +2 -2
  100. warp/native/scan.cu +2 -2
  101. warp/native/sparse.cu +0 -1
  102. warp/native/temp_buffer.h +2 -2
  103. warp/native/warp.cpp +95 -60
  104. warp/native/warp.cu +1053 -218
  105. warp/native/warp.h +49 -32
  106. warp/optim/linear.py +33 -16
  107. warp/render/render_opengl.py +202 -101
  108. warp/render/render_usd.py +82 -40
  109. warp/sim/__init__.py +13 -4
  110. warp/sim/articulation.py +4 -5
  111. warp/sim/collide.py +320 -175
  112. warp/sim/import_mjcf.py +25 -30
  113. warp/sim/import_urdf.py +94 -63
  114. warp/sim/import_usd.py +51 -36
  115. warp/sim/inertia.py +3 -2
  116. warp/sim/integrator.py +233 -0
  117. warp/sim/integrator_euler.py +447 -469
  118. warp/sim/integrator_featherstone.py +1991 -0
  119. warp/sim/integrator_xpbd.py +1420 -640
  120. warp/sim/model.py +765 -487
  121. warp/sim/particles.py +2 -1
  122. warp/sim/render.py +35 -13
  123. warp/sim/utils.py +222 -11
  124. warp/stubs.py +8 -0
  125. warp/tape.py +16 -1
  126. warp/tests/aux_test_grad_customs.py +23 -0
  127. warp/tests/test_array.py +190 -1
  128. warp/tests/test_async.py +656 -0
  129. warp/tests/test_bool.py +50 -0
  130. warp/tests/test_dlpack.py +164 -11
  131. warp/tests/test_examples.py +166 -74
  132. warp/tests/test_fem.py +8 -1
  133. warp/tests/test_generics.py +15 -5
  134. warp/tests/test_grad.py +1 -1
  135. warp/tests/test_grad_customs.py +172 -12
  136. warp/tests/test_jax.py +254 -0
  137. warp/tests/test_large.py +29 -6
  138. warp/tests/test_launch.py +25 -0
  139. warp/tests/test_linear_solvers.py +20 -3
  140. warp/tests/test_matmul.py +61 -16
  141. warp/tests/test_matmul_lite.py +13 -13
  142. warp/tests/test_mempool.py +186 -0
  143. warp/tests/test_multigpu.py +3 -0
  144. warp/tests/test_options.py +16 -2
  145. warp/tests/test_peer.py +137 -0
  146. warp/tests/test_print.py +3 -1
  147. warp/tests/test_quat.py +23 -0
  148. warp/tests/test_sim_kinematics.py +97 -0
  149. warp/tests/test_snippet.py +126 -3
  150. warp/tests/test_streams.py +108 -79
  151. warp/tests/test_torch.py +16 -8
  152. warp/tests/test_utils.py +32 -27
  153. warp/tests/test_verify_fp.py +65 -0
  154. warp/tests/test_volume.py +1 -1
  155. warp/tests/unittest_serial.py +2 -0
  156. warp/tests/unittest_suites.py +12 -0
  157. warp/tests/unittest_utils.py +14 -7
  158. warp/thirdparty/unittest_parallel.py +15 -3
  159. warp/torch.py +10 -8
  160. warp/types.py +363 -246
  161. warp/utils.py +143 -19
  162. warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
  163. warp_lang-1.0.0.dist-info/METADATA +394 -0
  164. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
  165. warp/sim/optimizer.py +0 -138
  166. warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
  167. warp_lang-0.11.0.dist-info/METADATA +0 -238
  168. /warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
  169. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
  170. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,173 @@
1
+ # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ ###########################################################################
9
+ # Example Diffusion
10
+ #
11
+ # This example solves a 2d diffusion problem:
12
+ #
13
+ # -nu Div(Grad u) = 1
14
+ #
15
+ # with Dirichlet boundary conditions on vertical edges and
16
+ # homogeneous Neumann on horizontal edges.
17
+ ###########################################################################
18
+
19
+ import argparse
20
+
21
+ import warp as wp
22
+ import warp.fem as fem
23
+
24
+ from warp.sparse import bsr_axpy
25
+ from warp.fem.utils import array_axpy
26
+
27
+
28
+ # Import example utilities
29
+ # Make sure this works both when imported as a module and when run as a standalone file
30
+ try:
31
+ from .bsr_utils import bsr_cg
32
+ from .mesh_utils import gen_trimesh, gen_quadmesh
33
+ from .plot_utils import Plot
34
+ except ImportError:
35
+ from bsr_utils import bsr_cg
36
+ from mesh_utils import gen_trimesh, gen_quadmesh
37
+ from plot_utils import Plot
38
+
39
+ wp.init()
40
+
41
+
42
@fem.integrand
def linear_form(s: fem.Sample, v: fem.Field):
    """Forcing term of the problem: the test field ``v`` evaluated at ``s`` (constant slope 1)."""
    test_value = v(s)
    return test_value
49
+
50
+
51
@fem.integrand
def diffusion_form(s: fem.Sample, u: fem.Field, v: fem.Field, nu: float):
    """Diffusion bilinear form: ``nu`` times the dot product of the gradients of ``u`` and ``v`` at ``s``."""
    grad_u = fem.grad(u, s)
    grad_v = fem.grad(v, s)
    return nu * wp.dot(grad_u, grad_v)
58
+
59
+
60
@fem.integrand
def y_boundary_value_form(s: fem.Sample, domain: fem.Domain, v: fem.Field, val: float):
    """Linear form ``val * v`` weighted by the absolute first component of the normal.

    The weight is zero wherever the normal has no x component, so only the
    vertical edges contribute.
    """
    normal = fem.normal(domain, s)
    edge_weight = wp.abs(normal[0])
    return val * v(s) * edge_weight
65
+
66
+
67
@fem.integrand
def y_boundary_projector_form(s: fem.Sample, domain: fem.Domain, u: fem.Field, v: fem.Field):
    """Bilinear projector for the boundary condition; non-zero on vertical edges only."""
    # The trial field is evaluated first and handed to the linear form as an
    # ordinary scalar coefficient, so the edge weighting is written only once.
    trial_value = u(s)
    return y_boundary_value_form(s, domain, v, trial_value)
79
+
80
+
81
class Example:
    """Set up, solve and display the 2d diffusion problem.

    ``parser`` declares the command-line options; its defaults are also applied
    when the example is constructed from keyword arguments.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("--resolution", type=int, default=50)
    parser.add_argument("--degree", type=int, default=2)
    parser.add_argument("--serendipity", action="store_true", default=False)
    parser.add_argument("--viscosity", type=float, default=2.0)
    parser.add_argument("--boundary_value", type=float, default=5.0)
    parser.add_argument("--boundary_compliance", type=float, default=0, help="Dirichlet boundary condition compliance")
    parser.add_argument("--mesh", choices=("grid", "tri", "quad"), default="grid", help="Mesh type")

    def __init__(self, stage=None, quiet=False, args=None, **kwargs):
        """Build the geometry, the scalar function space, its field and the plot helper.

        Args:
            stage: Passed unchanged to ``Plot``.
            quiet: Passed to ``bsr_cg`` when the system is solved.
            args: Parsed options; when ``None`` they are derived from ``kwargs``.
            **kwargs: Option overrides, only used when ``args`` is ``None``.
        """
        if args is None:
            # Seed a namespace with the overrides and let the parser fill in
            # its defaults for every option that was not given
            overrides = argparse.Namespace(**kwargs)
            args = Example.parser.parse_args(args=[], namespace=overrides)
        self._args = args
        self._quiet = quiet

        # Geometry: triangle mesh, quad mesh, or a regular grid otherwise
        resolution = wp.vec2i(args.resolution)
        mesh_kind = args.mesh
        if mesh_kind == "tri":
            positions, tri_vidx = gen_trimesh(res=resolution)
            geo = fem.Trimesh2D(tri_vertex_indices=tri_vidx, positions=positions)
        elif mesh_kind == "quad":
            positions, quad_vidx = gen_quadmesh(res=resolution)
            geo = fem.Quadmesh2D(quad_vertex_indices=quad_vidx, positions=positions)
        else:
            geo = fem.Grid2D(res=resolution)
        self._geo = geo

        # Scalar function space, with the serendipity basis only on request
        if args.serendipity:
            element_basis = fem.ElementBasis.SERENDIPITY
        else:
            element_basis = None
        self._scalar_space = fem.make_polynomial_space(geo, degree=args.degree, element_basis=element_basis)

        # Discrete field over that space; receives the solution in step()
        self._scalar_field = self._scalar_space.make_field()

        self.renderer = Plot(stage)

    def step(self):
        """Assemble the diffusion system with its boundary condition and solve it."""
        options = self._args
        space = self._scalar_space

        cells = fem.Cells(geometry=self._geo)

        # Forcing term integrated against the test functions
        test = fem.make_test(space=space, domain=cells)
        rhs = fem.integrate(linear_form, fields={"v": test})

        # Diffusion matrix
        trial = fem.make_trial(space=space, domain=cells)
        matrix = fem.integrate(
            diffusion_form,
            fields={"u": trial, "v": test},
            values={"nu": options.viscosity},
        )

        # Boundary condition terms.
        # Nodal integration is used so that the condition is specified on each node independently
        sides = fem.BoundarySides(self._geo)
        bd_test = fem.make_test(space=space, domain=sides)
        bd_trial = fem.make_trial(space=space, domain=sides)

        bd_matrix = fem.integrate(y_boundary_projector_form, fields={"u": bd_trial, "v": bd_test}, nodal=True)
        bd_rhs = fem.integrate(
            y_boundary_value_form,
            fields={"v": bd_test},
            values={"val": options.boundary_value},
            nodal=True,
        )

        if options.boundary_compliance != 0.0:
            # Weak BC: add together the diffusion and boundary condition
            # systems, scaling the latter by the inverse compliance
            boundary_strength = 1.0 / options.boundary_compliance
            bsr_axpy(x=bd_matrix, y=matrix, alpha=boundary_strength, beta=1)
            array_axpy(x=bd_rhs, y=rhs, alpha=boundary_strength, beta=1)
        else:
            # Hard BC: project the linear system
            fem.project_linear_system(matrix, rhs, bd_matrix, bd_rhs)

        # Conjugate Gradient solve, starting from a zero vector
        solution = wp.zeros_like(rhs)
        bsr_cg(matrix, b=rhs, x=solution, quiet=self._quiet)

        # Store the solved degrees of freedom in the discrete field
        self._scalar_field.dof_values = solution

    def render(self):
        """Register the solution field with the plot helper as a surface."""
        self.renderer.add_surface("solution", self._scalar_field)
162
+
163
+
164
if __name__ == "__main__":
    # Turn the "enable_backward" module option off before anything is run
    wp.set_module_options({"enable_backward": False})

    cli_args = Example.parser.parse_args()

    demo = Example(args=cli_args)
    demo.step()
    demo.render()

    demo.renderer.plot()
@@ -0,0 +1,152 @@
1
+ # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ ###########################################################################
9
+ # Example Diffusion 3D
10
+ #
11
+ # This example solves a 3d diffusion problem:
12
+ #
13
+ # -nu Div(Grad u) = 1
14
+ #
15
+ # with homogeneous Neumann conditions on horizontal sides
16
+ # and homogeneous Dirichlet boundary conditions on the other sides.
17
+ ###########################################################################
18
+
19
+ import argparse
20
+
21
+ import warp as wp
22
+ import warp.fem as fem
23
+
24
+ from warp.sparse import bsr_axpy
25
+
26
+ # Import example utilities
27
+ # Make sure this works both when imported as a module and when run as a standalone file
28
# Both branches must bring in the same names: the class below calls
# gen_hexmesh for the "hex" mesh option, so it has to be available when this
# file is imported as part of the package, not only when it is run standalone.
try:
    from .example_diffusion import diffusion_form, linear_form
    from .bsr_utils import bsr_cg
    from .mesh_utils import gen_tetmesh, gen_hexmesh
    from .plot_utils import Plot
except ImportError:
    from example_diffusion import diffusion_form, linear_form
    from bsr_utils import bsr_cg
    from mesh_utils import gen_tetmesh, gen_hexmesh
    from plot_utils import Plot
38
+
39
+ wp.init()
40
+
41
+
42
@fem.integrand
def vert_boundary_projector_form(s: fem.Sample, domain: fem.Domain, u: fem.Field, v: fem.Field):
    """Mass-like bilinear form weighted by one minus the absolute second component of the normal.

    Gives non-zero mass on vertical sides only; the weight is zero where the
    normal is aligned with the y axis (assuming unit-length normals).
    """
    normal = fem.normal(domain, s)
    side_weight = 1.0 - wp.abs(normal[1])
    return side_weight * u(s) * v(s)
52
+
53
+
54
class Example:
    """Set up, solve and display the 3d diffusion problem.

    ``parser`` declares the command-line options; its defaults are also applied
    when the example is constructed from keyword arguments.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("--resolution", type=int, default=10)
    parser.add_argument("--degree", type=int, default=2)
    parser.add_argument("--serendipity", action="store_true", default=False)
    parser.add_argument("--viscosity", type=float, default=2.0)
    parser.add_argument("--boundary_compliance", type=float, default=0, help="Dirichlet boundary condition compliance")
    parser.add_argument("--mesh", choices=("grid", "tet", "hex"), default="grid", help="Mesh type")

    def __init__(self, stage=None, quiet=False, args=None, **kwargs):
        """Build the geometry, the scalar function space, its field and the plot helper.

        Args:
            stage: Passed unchanged to ``Plot``.
            quiet: Passed to ``bsr_cg`` when the system is solved.
            args: Parsed options; when ``None`` they are derived from ``kwargs``.
            **kwargs: Option overrides, only used when ``args`` is ``None``.
        """
        if args is None:
            # Seed a namespace with the overrides and let the parser fill in
            # its defaults for every option that was not given
            overrides = argparse.Namespace(**kwargs)
            args = Example.parser.parse_args(args=[], namespace=overrides)
        self._args = args
        self._quiet = quiet

        # Cell counts per axis: resolution, half of it, and twice it
        res = wp.vec3i(args.resolution, args.resolution // 2, args.resolution * 2)

        # The same bounding box is used for every mesh type
        box = {
            "bounds_lo": wp.vec3(0.0, 0.0, 0.0),
            "bounds_hi": wp.vec3(1.0, 0.5, 2.0),
        }

        mesh_kind = args.mesh
        if mesh_kind == "tet":
            pos, tet_vtx_indices = gen_tetmesh(res=res, **box)
            geo = fem.Tetmesh(tet_vtx_indices, pos)
        elif mesh_kind == "hex":
            pos, hex_vtx_indices = gen_hexmesh(res=res, **box)
            geo = fem.Hexmesh(hex_vtx_indices, pos)
        else:
            geo = fem.Grid3D(res=res, **box)
        self._geo = geo

        # Scalar function space, with the serendipity basis only on request
        if args.serendipity:
            element_basis = fem.ElementBasis.SERENDIPITY
        else:
            element_basis = None
        self._scalar_space = fem.make_polynomial_space(geo, degree=args.degree, element_basis=element_basis)

        # Discrete field over that space; receives the solution in step()
        self._scalar_field: fem.DiscreteField = self._scalar_space.make_field()

        self.renderer = Plot(stage)

    def step(self):
        """Assemble the diffusion system with its boundary condition and solve it."""
        options = self._args
        space = self._scalar_space

        cells = fem.Cells(geometry=self._geo)

        # Right-hand side
        test = fem.make_test(space=space, domain=cells)
        rhs = fem.integrate(linear_form, fields={"v": test})

        with wp.ScopedTimer("Integrate"):
            # Boundary-condition matrix over the boundary sides, using nodal integration
            sides = fem.BoundarySides(self._geo)

            bd_test = fem.make_test(space=space, domain=sides)
            bd_trial = fem.make_trial(space=space, domain=sides)
            bd_matrix = fem.integrate(vert_boundary_projector_form, fields={"u": bd_trial, "v": bd_test}, nodal=True)

            # Diffusion matrix
            trial = fem.make_trial(space=space, domain=cells)
            matrix = fem.integrate(
                diffusion_form,
                fields={"u": trial, "v": test},
                values={"nu": options.viscosity},
            )

        if options.boundary_compliance != 0.0:
            # Weak BC: add together the diffusion and boundary condition
            # matrices, scaling the latter by the inverse compliance
            boundary_strength = 1.0 / options.boundary_compliance
            bsr_axpy(x=bd_matrix, y=matrix, alpha=boundary_strength, beta=1)
        else:
            # Hard BC: project the linear system against an all-zero boundary rhs
            bd_rhs = wp.zeros_like(rhs)
            fem.project_linear_system(matrix, rhs, bd_matrix, bd_rhs)

        with wp.ScopedTimer("CG solve"):
            solution = wp.zeros_like(rhs)
            bsr_cg(matrix, b=rhs, x=solution, quiet=self._quiet)
            self._scalar_field.dof_values = solution

    def render(self):
        """Register the solution field with the plot helper as a volume."""
        self.renderer.add_volume("solution", self._scalar_field)
141
+
142
+
143
if __name__ == "__main__":
    # Turn the "enable_backward" module option off before anything is run
    wp.set_module_options({"enable_backward": False})

    cli_args = Example.parser.parse_args()

    demo = Example(args=cli_args)
    demo.step()
    demo.render()

    demo.renderer.plot()
@@ -0,0 +1,214 @@
1
+ # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ ###########################################################################
9
+ # Example Diffusion MGPU
10
+ #
11
+ # This example illustrates using domain decomposition to
12
+ # solve a diffusion PDE over multiple devices
13
+ ###########################################################################
14
+
15
+ from typing import Tuple
16
+
17
+ import warp as wp
18
+ import warp.fem as fem
19
+ from warp.sparse import bsr_axpy, bsr_mv
20
+ from warp.utils import array_cast
21
+
22
+ # Import example utilities
23
+ # Make sure this works both when imported as a module and when run as a standalone file
24
+ try:
25
+ from .bsr_utils import bsr_cg
26
+ from .example_diffusion import diffusion_form, linear_form
27
+ from .plot_utils import Plot
28
+ except ImportError:
29
+ from bsr_utils import bsr_cg
30
+ from example_diffusion import diffusion_form, linear_form
31
+ from plot_utils import Plot
32
+
33
+ wp.init()
34
+
35
+
36
@fem.integrand
def mass_form(s: fem.Sample, u: fem.Field, v: fem.Field):
    """Mass bilinear form: product of ``u`` and ``v`` evaluated at sample ``s``."""
    trial_value = u(s)
    test_value = v(s)
    return trial_value * test_value
43
+
44
+
45
@wp.kernel
def scal_kernel(a: wp.array(dtype=wp.float64), res: wp.array(dtype=wp.float64), alpha: wp.float64):
    # Each thread writes one entry of `a`, scaled by `alpha`, into `res`
    i = wp.tid()
    res[i] = a[i] * alpha
48
+
49
+
50
@wp.kernel
def sum_kernel(a: wp.indexedarray(dtype=wp.float64), b: wp.array(dtype=wp.float64)):
    # Each thread adds one entry of `b` onto the matching entry of the indexed array `a`
    i = wp.tid()
    a[i] = a[i] + b[i]
53
+
54
+
55
def sum_vecs(vecs, indices, sum: wp.array, tmp: wp.array):
    """Accumulate each vector of ``vecs`` into the entries of ``sum`` picked by its index array.

    Args:
        vecs: Vectors to accumulate, paired one-to-one with ``indices``.
        indices: For each vector, the index array selecting the entries of ``sum`` it is added to.
        sum: Array receiving the accumulated values; modified in place.
        tmp: Scratch array each vector is copied into before being added.

    Returns:
        The ``sum`` array that was passed in.
    """
    for local_vec, node_ids in zip(vecs, indices):
        # Stage the vector in the scratch buffer first
        wp.copy(dest=tmp, src=local_vec)

        # View of `sum` restricted to the entries this vector contributes to
        target = wp.indexedarray(sum, node_ids)
        wp.launch(
            kernel=sum_kernel,
            dim=node_ids.shape,
            device=sum.device,
            inputs=[target, tmp],
        )

    return sum
62
+
63
+
64
class DistributedSystem:
    """Linear system whose matrix-vector product is evaluated piecewise on several devices.

    The attributes are plain fields that the caller fills in after construction.
    """

    device = None
    # NOTE(review): the example code in this file assigns ``dtype`` on the
    # instance rather than ``scalar_type`` -- confirm which one the solver reads.
    scalar_type: type
    # Staging buffer that mv_routine re-views with the shape of each index array
    tmp_buf: wp.array

    nrow: int
    # Annotation only, like ``nrow``: the previous ``shape = Tuple[int, int]``
    # stored the typing alias itself as the class attribute's value.
    shape: Tuple[int, int]
    # Four parallel sequences, iterated together by mv_routine as
    # (local matrix, local input vector, local output vector, index array)
    rank_data = None

    def mv_routine(self, x: wp.array, y: wp.array, z: wp.array, alpha=1.0, beta=0.0):
        """Distributed matrix-vector multiplication routine, for example purposes

        ``z`` is first set to ``beta * y``. Then, for every entry of
        ``rank_data``, the entries of ``x`` selected by the index array are
        copied to the local input vector, ``bsr_mv`` is run with the local
        matrix and ``alpha``, and the local output vector is added onto the
        same entries of ``z``.

        Args:
            x: Input vector.
            y: Vector scaled by ``beta`` to initialize ``z``.
            z: Output vector, overwritten.
            alpha: Scaling passed to each local ``bsr_mv`` call.
            beta: Scaling applied to ``y``.
        """

        tmp = self.tmp_buf

        wp.launch(kernel=scal_kernel, dim=y.shape, device=y.device, inputs=[y, z, wp.float64(beta)])

        stream = wp.get_stream()

        for mat_i, x_i, y_i, idx in zip(*self.rank_data):
            # WAR copy with indexed array requiring matching shape
            # (tmp_i aliases the memory of tmp through its pointer)
            tmp_i = wp.array(
                ptr=tmp.ptr, device=tmp.device, capacity=tmp.capacity, dtype=tmp.dtype, shape=idx.shape
            )

            # Compress rhs on rank 0
            x_idx = wp.indexedarray(x, idx)
            wp.copy(dest=tmp_i, src=x_idx, count=idx.size, stream=stream)

            # Send to rank i
            wp.copy(dest=x_i, src=tmp_i, count=idx.size, stream=stream)

            with wp.ScopedDevice(x_i.device):
                # Order the local product after the copies issued on `stream`
                wp.wait_stream(stream)
                bsr_mv(A=mat_i, x=x_i, y=y_i, alpha=alpha, beta=0.0)

            wp.wait_stream(wp.get_stream(x_i.device))

            # Back to rank 0 for sum
            wp.copy(dest=tmp_i, src=y_i, count=idx.size, stream=stream)
            z_idx = wp.indexedarray(z, idx)
            wp.launch(kernel=sum_kernel, dim=idx.shape, device=z_idx.device, inputs=[z_idx, tmp_i], stream=stream)

        wp.wait_stream(stream)
107
+
108
+
109
class Example:
    """Diffusion solve in which each CUDA device assembles the system for one partition of the grid."""

    def __init__(self, stage=None, quiet=False):
        """Create the grid, the scalar space and field on the main device, and the plot helper.

        Args:
            stage: Passed unchanged to ``Plot``.
            quiet: Passed to ``bsr_cg`` when the system is solved.
        """
        self._quiet = quiet
        # Factor applied to the boundary mass matrix when it is added to the diffusion matrix
        self._bd_weight = 100.0

        self._geo = fem.Grid2D(res=wp.vec2i(25))

        main_device = wp.get_device("cuda")
        self._main_device = main_device

        with wp.ScopedDevice(main_device):
            space = fem.make_polynomial_space(self._geo, degree=3)
            self._scalar_space = space
            self._scalar_field = space.make_field()

        self.renderer = Plot(stage)

    def _assemble_local_system(self, geo_partition: fem.GeometryPartition):
        """Assemble the matrix and right-hand side restricted to one geometry partition.

        Returns:
            The local matrix, the local right-hand side, and the space node
            indices of the partition.
        """
        space = self._scalar_space
        space_partition = fem.make_space_partition(space, geo_partition)

        cells = fem.Cells(geometry=geo_partition)

        # Right-hand side
        test = fem.make_test(space=space, space_partition=space_partition, domain=cells)
        rhs = fem.integrate(linear_form, fields={"v": test})

        # Weakly-imposed boundary conditions on all sides
        sides = fem.BoundarySides(geometry=geo_partition)
        bd_test = fem.make_test(space=space, space_partition=space_partition, domain=sides)
        bd_trial = fem.make_trial(space=space, space_partition=space_partition, domain=sides)
        bd_matrix = fem.integrate(mass_form, fields={"u": bd_trial, "v": bd_test})

        # Diffusion matrix
        trial = fem.make_trial(space=space, space_partition=space_partition, domain=cells)
        matrix = fem.integrate(diffusion_form, fields={"u": trial, "v": test}, values={"nu": 1.0})

        # Add the boundary matrix, scaled by the boundary weight, to the diffusion matrix
        bsr_axpy(y=matrix, x=bd_matrix, alpha=self._bd_weight)

        node_indices = space_partition.space_node_indices()
        return matrix, rhs, node_indices

    def step(self):
        """Assemble one local system per CUDA device, then solve the global system with CG."""
        main_device = self._main_device
        devices = wp.get_cuda_devices()
        device_count = len(devices)

        matrices, rhs_vecs, res_vecs, indices = [], [], [], []

        # Build the local system of each device
        for rank, device in enumerate(devices):
            with wp.ScopedDevice(device):
                # Partition of the geometry handled by this device
                geo_partition = fem.LinearGeometryPartition(self._geo, rank, device_count)
                local_matrix, local_rhs, node_indices = self._assemble_local_system(geo_partition)

                matrices.append(local_matrix)
                rhs_vecs.append(local_rhs)
                res_vecs.append(wp.empty_like(local_rhs))
                indices.append(node_indices.to(main_device))

        node_count = self._scalar_space.node_count()

        # Global rhs, filled below as the sum of all local rhs
        glob_rhs = wp.zeros(n=node_count, dtype=wp.float64, device=main_device)

        # The scratch buffer takes part in peer-to-peer copies during graph
        # capture, so it is allocated with the default CUDA allocator: the
        # copies then succeed without enabling mempool access between devices,
        # which is not supported on all systems.
        with wp.ScopedMempool(main_device, False):
            tmp = wp.empty_like(glob_rhs)

        sum_vecs(rhs_vecs, indices, glob_rhs, tmp)

        # Distributed CG
        global_res = wp.zeros_like(glob_rhs)

        system = DistributedSystem()
        system.device = main_device
        system.dtype = glob_rhs.dtype
        system.nrow = node_count
        system.shape = (node_count, node_count)
        system.tmp_buf = tmp
        system.rank_data = (matrices, rhs_vecs, res_vecs, indices)

        with wp.ScopedDevice(main_device):
            bsr_cg(
                system,
                x=global_res,
                b=glob_rhs,
                use_diag_precond=False,
                quiet=self._quiet,
                mv_routine=system.mv_routine,
            )

        # Convert the result into the field's degrees of freedom
        array_cast(in_array=global_res, out_array=self._scalar_field.dof_values)

    def render(self):
        """Register the solution field with the plot helper as a surface."""
        self.renderer.add_surface("solution", self._scalar_field)
205
+
206
+
207
if __name__ == "__main__":
    # Turn the "enable_backward" module option off before anything is run
    wp.set_module_options({"enable_backward": False})

    demo = Example()
    demo.step()
    demo.render()

    demo.renderer.plot()