warp-lang 1.4.2__py3-none-macosx_10_13_universal2.whl → 1.5.1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (165)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/libwarp.dylib +0 -0
  4. warp/build.py +21 -2
  5. warp/build_dll.py +23 -6
  6. warp/builtins.py +1819 -7
  7. warp/codegen.py +197 -61
  8. warp/config.py +2 -2
  9. warp/context.py +379 -107
  10. warp/examples/assets/pixel.jpg +0 -0
  11. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  12. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  13. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  14. warp/examples/benchmarks/benchmark_tile.py +179 -0
  15. warp/examples/fem/example_adaptive_grid.py +37 -10
  16. warp/examples/fem/example_apic_fluid.py +3 -2
  17. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  18. warp/examples/fem/example_deformed_geometry.py +1 -1
  19. warp/examples/fem/example_diffusion_3d.py +47 -4
  20. warp/examples/fem/example_distortion_energy.py +220 -0
  21. warp/examples/fem/example_magnetostatics.py +127 -85
  22. warp/examples/fem/example_nonconforming_contact.py +5 -5
  23. warp/examples/fem/example_stokes.py +3 -1
  24. warp/examples/fem/example_streamlines.py +12 -19
  25. warp/examples/fem/utils.py +38 -15
  26. warp/examples/sim/example_cloth.py +4 -25
  27. warp/examples/sim/example_quadruped.py +2 -1
  28. warp/examples/tile/example_tile_convolution.py +58 -0
  29. warp/examples/tile/example_tile_fft.py +47 -0
  30. warp/examples/tile/example_tile_filtering.py +105 -0
  31. warp/examples/tile/example_tile_matmul.py +79 -0
  32. warp/examples/tile/example_tile_mlp.py +375 -0
  33. warp/fem/__init__.py +8 -0
  34. warp/fem/cache.py +16 -12
  35. warp/fem/dirichlet.py +1 -1
  36. warp/fem/domain.py +44 -1
  37. warp/fem/field/__init__.py +1 -2
  38. warp/fem/field/field.py +31 -19
  39. warp/fem/field/nodal_field.py +101 -49
  40. warp/fem/field/virtual.py +794 -0
  41. warp/fem/geometry/__init__.py +2 -2
  42. warp/fem/geometry/deformed_geometry.py +3 -105
  43. warp/fem/geometry/element.py +13 -0
  44. warp/fem/geometry/geometry.py +165 -7
  45. warp/fem/geometry/grid_2d.py +3 -6
  46. warp/fem/geometry/grid_3d.py +31 -28
  47. warp/fem/geometry/hexmesh.py +3 -46
  48. warp/fem/geometry/nanogrid.py +3 -2
  49. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  50. warp/fem/geometry/tetmesh.py +2 -43
  51. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  52. warp/fem/integrate.py +683 -261
  53. warp/fem/linalg.py +404 -0
  54. warp/fem/operator.py +101 -18
  55. warp/fem/polynomial.py +5 -5
  56. warp/fem/quadrature/quadrature.py +45 -21
  57. warp/fem/space/__init__.py +45 -11
  58. warp/fem/space/basis_function_space.py +451 -0
  59. warp/fem/space/basis_space.py +58 -11
  60. warp/fem/space/function_space.py +146 -5
  61. warp/fem/space/grid_2d_function_space.py +80 -66
  62. warp/fem/space/grid_3d_function_space.py +113 -68
  63. warp/fem/space/hexmesh_function_space.py +96 -108
  64. warp/fem/space/nanogrid_function_space.py +62 -110
  65. warp/fem/space/quadmesh_function_space.py +208 -0
  66. warp/fem/space/shape/__init__.py +45 -7
  67. warp/fem/space/shape/cube_shape_function.py +328 -54
  68. warp/fem/space/shape/shape_function.py +10 -1
  69. warp/fem/space/shape/square_shape_function.py +328 -60
  70. warp/fem/space/shape/tet_shape_function.py +269 -19
  71. warp/fem/space/shape/triangle_shape_function.py +238 -19
  72. warp/fem/space/tetmesh_function_space.py +69 -37
  73. warp/fem/space/topology.py +38 -0
  74. warp/fem/space/trimesh_function_space.py +179 -0
  75. warp/fem/utils.py +6 -331
  76. warp/jax_experimental.py +3 -1
  77. warp/native/array.h +15 -0
  78. warp/native/builtin.h +66 -26
  79. warp/native/bvh.h +4 -0
  80. warp/native/coloring.cpp +604 -0
  81. warp/native/cuda_util.cpp +68 -51
  82. warp/native/cuda_util.h +2 -1
  83. warp/native/fabric.h +8 -0
  84. warp/native/hashgrid.h +4 -0
  85. warp/native/marching.cu +8 -0
  86. warp/native/mat.h +14 -3
  87. warp/native/mathdx.cpp +59 -0
  88. warp/native/mesh.h +4 -0
  89. warp/native/range.h +13 -1
  90. warp/native/reduce.cpp +9 -1
  91. warp/native/reduce.cu +7 -0
  92. warp/native/runlength_encode.cpp +9 -1
  93. warp/native/runlength_encode.cu +7 -1
  94. warp/native/scan.cpp +8 -0
  95. warp/native/scan.cu +8 -0
  96. warp/native/scan.h +8 -1
  97. warp/native/sparse.cpp +8 -0
  98. warp/native/sparse.cu +8 -0
  99. warp/native/temp_buffer.h +7 -0
  100. warp/native/tile.h +1854 -0
  101. warp/native/tile_gemm.h +341 -0
  102. warp/native/tile_reduce.h +210 -0
  103. warp/native/volume_builder.cu +8 -0
  104. warp/native/volume_builder.h +8 -0
  105. warp/native/warp.cpp +10 -2
  106. warp/native/warp.cu +369 -15
  107. warp/native/warp.h +12 -2
  108. warp/optim/adam.py +39 -4
  109. warp/paddle.py +29 -12
  110. warp/render/render_opengl.py +140 -67
  111. warp/sim/graph_coloring.py +292 -0
  112. warp/sim/import_urdf.py +8 -8
  113. warp/sim/integrator_euler.py +4 -2
  114. warp/sim/integrator_featherstone.py +115 -44
  115. warp/sim/integrator_vbd.py +6 -0
  116. warp/sim/model.py +109 -32
  117. warp/sparse.py +1 -1
  118. warp/stubs.py +569 -4
  119. warp/tape.py +12 -7
  120. warp/tests/assets/pixel.npy +0 -0
  121. warp/tests/aux_test_instancing_gc.py +18 -0
  122. warp/tests/test_array.py +39 -0
  123. warp/tests/test_codegen.py +81 -1
  124. warp/tests/test_codegen_instancing.py +30 -0
  125. warp/tests/test_collision.py +110 -0
  126. warp/tests/test_coloring.py +251 -0
  127. warp/tests/test_context.py +34 -0
  128. warp/tests/test_examples.py +21 -5
  129. warp/tests/test_fem.py +453 -113
  130. warp/tests/test_func.py +34 -4
  131. warp/tests/test_generics.py +52 -0
  132. warp/tests/test_iter.py +68 -0
  133. warp/tests/test_lerp.py +13 -87
  134. warp/tests/test_mat_scalar_ops.py +1 -1
  135. warp/tests/test_matmul.py +6 -9
  136. warp/tests/test_matmul_lite.py +6 -11
  137. warp/tests/test_mesh_query_point.py +1 -1
  138. warp/tests/test_module_hashing.py +23 -0
  139. warp/tests/test_overwrite.py +45 -0
  140. warp/tests/test_paddle.py +27 -87
  141. warp/tests/test_print.py +56 -1
  142. warp/tests/test_smoothstep.py +17 -83
  143. warp/tests/test_spatial.py +1 -1
  144. warp/tests/test_static.py +3 -3
  145. warp/tests/test_tile.py +744 -0
  146. warp/tests/test_tile_mathdx.py +144 -0
  147. warp/tests/test_tile_mlp.py +383 -0
  148. warp/tests/test_tile_reduce.py +374 -0
  149. warp/tests/test_tile_shared_memory.py +190 -0
  150. warp/tests/test_vbd.py +12 -20
  151. warp/tests/test_volume.py +43 -0
  152. warp/tests/unittest_suites.py +19 -2
  153. warp/tests/unittest_utils.py +4 -2
  154. warp/types.py +340 -74
  155. warp/utils.py +23 -3
  156. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/METADATA +32 -7
  157. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/RECORD +160 -133
  158. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/WHEEL +1 -1
  159. warp/fem/field/test.py +0 -180
  160. warp/fem/field/trial.py +0 -183
  161. warp/fem/space/collocated_function_space.py +0 -102
  162. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  163. warp/fem/space/trimesh_2d_function_space.py +0 -153
  164. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/LICENSE.md +0 -0
  165. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/top_level.txt +0 -0
@@ -200,4 +200,6 @@ if __name__ == "__main__":
200
200
  example.render()
201
201
 
202
202
  if not args.headless:
203
- example.renderer.plot(options={"velocity": {"streamlines": {}}, "pressure": {"contours": {}}})
203
+ example.renderer.plot(
204
+ options={"velocity": {"streamlines": {}}, "pressure": {"contours": {}}}, backend="matplotlib"
205
+ )
@@ -82,7 +82,7 @@ def mass_form(
82
82
  u: fem.Field,
83
83
  v: fem.Field,
84
84
  ):
85
- return u(s) * v(s)
85
+ return wp.dot(u(s), v(s))
86
86
 
87
87
 
88
88
  @fem.integrand
@@ -199,7 +199,7 @@ class Example:
199
199
  domain=self._inflow, order=self._degree, family=fem.Polynomial.GAUSS_LEGENDRE
200
200
  )
201
201
  n_streamlines = streamline_spawn.total_point_count()
202
- spawn_points = wp.array(dtype=wp.vec3, shape=n_streamlines)
202
+ spawn_points = wp.empty(dtype=wp.vec3, shape=n_streamlines)
203
203
 
204
204
  jitter_amount = self._streamline_dx / self._degree
205
205
  fem.interpolate(
@@ -212,8 +212,8 @@ class Example:
212
212
  # to populate the per-point data
213
213
 
214
214
  point_count = self._streamline_point_count
215
- points = wp.array(dtype=wp.vec3, shape=(n_streamlines, point_count))
216
- speed = wp.array(dtype=float, shape=(n_streamlines, point_count))
215
+ points = wp.empty(dtype=wp.vec3, shape=(n_streamlines, point_count))
216
+ speed = wp.empty(dtype=float, shape=(n_streamlines, point_count))
217
217
 
218
218
  fem.interpolate(
219
219
  gen_streamlines,
@@ -235,7 +235,7 @@ class Example:
235
235
  def render(self):
236
236
  # self.renderer.add_field("solution", self.pressure_field)
237
237
  self.plot.add_field("pressure", self.pressure_field)
238
- self.plot.add_field("velocity", self.velocity_field)
238
+ # self.plot.add_field("velocity", self.velocity_field)
239
239
 
240
240
  if self.renderer is not None:
241
241
  streamline_count = self._points.shape[0]
@@ -259,10 +259,11 @@ class Example:
259
259
  self.renderer.end_frame()
260
260
 
261
261
  def _generate_incompressible_flow(self):
262
- # Function spaces for velocity, scalars and pressure (Pk / Pk / Pk-1)
263
- u_space = fem.make_polynomial_space(geo=self._geo, degree=self._degree, dtype=wp.vec3)
264
- s_space = fem.make_polynomial_space(geo=self._geo, degree=self._degree, dtype=float)
265
- p_space = fem.make_polynomial_space(geo=self._geo, degree=self._degree - 1, dtype=float)
262
+ # Function spaces for velocity and pressure (RT1 / P0)
263
+ u_space = fem.make_polynomial_space(
264
+ geo=self._geo, element_basis=fem.ElementBasis.RAVIART_THOMAS, degree=1, dtype=wp.vec3
265
+ )
266
+ p_space = fem.make_polynomial_space(geo=self._geo, degree=0, dtype=float)
266
267
 
267
268
  self.pressure_field = p_space.make_field()
268
269
  self.velocity_field = u_space.make_field()
@@ -288,8 +289,8 @@ class Example:
288
289
  fem.interpolate(inflow_velocity, dest=fem.make_restriction(self.velocity_field, domain=self._inflow))
289
290
 
290
291
  # (Diagonal) mass matrix
291
- rho_test = fem.make_test(s_space)
292
- rho_trial = fem.make_trial(s_space)
292
+ rho_test = fem.make_test(u_space)
293
+ rho_trial = fem.make_trial(u_space)
293
294
  inv_mass_matrix = fem.integrate(
294
295
  mass_form, fields={"u": rho_trial, "v": rho_test}, nodal=True, output_dtype=float
295
296
  )
@@ -341,11 +342,3 @@ if __name__ == "__main__":
341
342
 
342
343
  example.step()
343
344
  example.render()
344
-
345
- if not args.headless:
346
- example.plot.plot(
347
- {
348
- "velocity": {"streamlines": {"density": 2}},
349
- "pressure": {"contours": {}},
350
- }
351
- )
@@ -31,7 +31,7 @@ def gen_trimesh(res, bounds_lo: Optional[wp.vec2] = None, bounds_hi: Optional[wp
31
31
  Args:
32
32
  res: Resolution of the grid along each dimension
33
33
  bounds_lo: Position of the lower bound of the axis-aligned grid
34
- bounds_up: Position of the upper bound of the axis-aligned grid
34
+ bounds_hi: Position of the upper bound of the axis-aligned grid
35
35
 
36
36
  Returns:
37
37
  Tuple of ndarrays: (Vertex positions, Triangle vertex indices)
@@ -62,7 +62,7 @@ def gen_tetmesh(res, bounds_lo: Optional[wp.vec3] = None, bounds_hi: Optional[wp
62
62
  Args:
63
63
  res: Resolution of the grid along each dimension
64
64
  bounds_lo: Position of the lower bound of the axis-aligned grid
65
- bounds_up: Position of the upper bound of the axis-aligned grid
65
+ bounds_hi: Position of the upper bound of the axis-aligned grid
66
66
 
67
67
  Returns:
68
68
  Tuple of ndarrays: (Vertex positions, Tetrahedron vertex indices)
@@ -95,7 +95,7 @@ def gen_quadmesh(res, bounds_lo: Optional[wp.vec2] = None, bounds_hi: Optional[w
95
95
  Args:
96
96
  res: Resolution of the grid along each dimension
97
97
  bounds_lo: Position of the lower bound of the axis-aligned grid
98
- bounds_up: Position of the upper bound of the axis-aligned grid
98
+ bounds_hi: Position of the upper bound of the axis-aligned grid
99
99
 
100
100
  Returns:
101
101
  Tuple of ndarrays: (Vertex positions, Triangle vertex indices)
@@ -125,7 +125,7 @@ def gen_hexmesh(res, bounds_lo: Optional[wp.vec3] = None, bounds_hi: Optional[wp
125
125
  Args:
126
126
  res: Resolution of the grid along each dimension
127
127
  bounds_lo: Position of the lower bound of the axis-aligned grid
128
- bounds_up: Position of the upper bound of the axis-aligned grid
128
+ bounds_hi: Position of the upper bound of the axis-aligned grid
129
129
 
130
130
  Returns:
131
131
  Tuple of ndarrays: (Vertex positions, Triangle vertex indices)
@@ -158,7 +158,7 @@ def gen_volume(res, bounds_lo: Optional[wp.vec3] = None, bounds_hi: Optional[wp.
158
158
  Args:
159
159
  res: Resolution of the grid along each dimension
160
160
  bounds_lo: Position of the lower bound of the axis-aligned grid
161
- bounds_up: Position of the upper bound of the axis-aligned grid
161
+ bounds_hi: Position of the upper bound of the axis-aligned grid
162
162
  device: Cuda device on which to allocate the grid
163
163
  """
164
164
 
@@ -575,6 +575,7 @@ class Plot:
575
575
 
576
576
  def _plot_pyvista(self, options: Dict[str, Any]):
577
577
  import pyvista
578
+ import pyvista.themes
578
579
 
579
580
  grids = {}
580
581
  scales = {}
@@ -702,7 +703,7 @@ class Plot:
702
703
  subplot_rows = options.get("rows", 1)
703
704
  subplot_shape = (subplot_rows, (len(grids) + subplot_rows - 1) // subplot_rows)
704
705
 
705
- plotter = pyvista.Plotter(shape=subplot_shape)
706
+ plotter = pyvista.Plotter(shape=subplot_shape, theme=pyvista.themes.DocumentProTheme())
706
707
  plotter.link_views()
707
708
  plotter.add_camera_orientation_widget()
708
709
  for index, (name, grid) in enumerate(grids.items()):
@@ -717,7 +718,7 @@ class Plot:
717
718
  plotter.view_xy()
718
719
  else:
719
720
  plotter.add_mesh(marker)
720
- elif field.space.dimension == 3:
721
+ elif field.space.geometry.cell_dimension == 3:
721
722
  plotter.add_mesh_clip_plane(grid, show_edges=True, clim=value_range, assign_to_axis="z")
722
723
  else:
723
724
  plotter.add_mesh(grid, show_edges=True, clim=value_range)
@@ -809,6 +810,8 @@ class Plot:
809
810
  if "arrows" in args or "streamlines" in args:
810
811
  plot_opts["glyph_scale"] = args.get("arrows", {}).get("glyph_scale", 1.0)
811
812
  plot_fn = _plot_quivers_3d
813
+ elif field.space.geometry.cell_dimension == 2:
814
+ plot_fn = _plot_surface
812
815
  else:
813
816
  plot_fn = _plot_3d_scatter
814
817
  plot_3d = True
@@ -856,23 +859,43 @@ def _field_triangulation(field):
856
859
 
857
860
 
858
861
  def _plot_surface(field, axes, **kwargs):
859
- Z = _value_or_magnitude(field.dof_values.numpy())
862
+ from matplotlib.cm import get_cmap
863
+ from matplotlib.colors import Normalize
860
864
 
861
- if "clim" in kwargs:
862
- axes.set_zlim(*kwargs["clim"])
865
+ C = _value_or_magnitude(field.dof_values.numpy())
866
+
867
+ positions = field.space.node_positions().numpy().T
868
+ if field.space.dimension == 3:
869
+ X, Y, Z = positions
870
+ else:
871
+ X, Y = positions
872
+ Z = C
873
+ axes.set_zlim(kwargs["clim"])
863
874
 
864
875
  if hasattr(field.space, "node_grid"):
865
876
  X, Y = field.space.node_grid()
866
- Z = Z.reshape(X.shape)
867
- return axes.plot_surface(X, Y, Z, linewidth=0.1, antialiased=False, **kwargs)
877
+ C = C.reshape(X.shape)
878
+ return axes.plot_surface(X, Y, C, linewidth=0.1, antialiased=False, **kwargs)
868
879
 
869
880
  if hasattr(field.space, "node_triangulation"):
870
881
  triangulation = _field_triangulation(field)
871
- return axes.plot_trisurf(triangulation, Z, linewidth=0.1, antialiased=False, **kwargs)
882
+
883
+ if field.space.dimension == 3:
884
+ plot = axes.plot_trisurf(triangulation, Z, linewidth=0.1, antialiased=False)
885
+ # change colors -- recompute color map manually
886
+ vmin, vmax = kwargs["clim"]
887
+ norm = Normalize(vmin=vmin, vmax=vmax)
888
+ values = np.mean(C[triangulation.triangles], axis=1)
889
+ colors = get_cmap(kwargs["cmap"])(norm(values))
890
+ plot.set_norm(norm)
891
+ plot.set_fc(colors)
892
+ else:
893
+ plot = axes.plot_trisurf(triangulation, C, linewidth=0.1, antialiased=False, **kwargs)
894
+
895
+ return plot
872
896
 
873
897
  # scatter
874
- X, Y = field.space.node_positions().numpy().T
875
- return axes.scatter(X, Y, Z, c=Z, **kwargs)
898
+ return axes.scatter(X, Y, Z, c=C, **kwargs)
876
899
 
877
900
 
878
901
  def _plot_displaced_tri_mesh(field, axes, **kwargs):
@@ -26,29 +26,6 @@ import warp.sim
26
26
  import warp.sim.render
27
27
 
28
28
 
29
- def color_lattice_grid(num_x, num_y):
30
- colors = []
31
- for _i in range(4):
32
- colors.append([])
33
-
34
- for xi in range(num_x + 1):
35
- for yi in range(num_y + 1):
36
- vId = xi * (num_y + 1) + yi
37
-
38
- a = 1 if xi % 2 else 0
39
- b = 1 if yi % 2 else 0
40
-
41
- c = a * 2 + b
42
-
43
- colors[c].append(vId)
44
-
45
- colors_wp = []
46
- for i_color in range(len(colors)):
47
- colors_wp.append(wp.array(colors[i_color], dtype=wp.int32))
48
-
49
- return colors_wp
50
-
51
-
52
29
  class IntegratorType(Enum):
53
30
  EULER = "euler"
54
31
  XPBD = "xpbd"
@@ -122,6 +99,7 @@ class Example:
122
99
  tri_ke=1e4,
123
100
  tri_ka=1e4,
124
101
  tri_kd=1e-5,
102
+ edge_ke=100,
125
103
  )
126
104
 
127
105
  usd_stage = Usd.Stage.Open(os.path.join(warp.examples.get_asset_directory(), "bunny.usd"))
@@ -143,6 +121,9 @@ class Example:
143
121
  kf=1.0e1,
144
122
  )
145
123
 
124
+ if self.integrator_type == IntegratorType.VBD:
125
+ builder.color()
126
+
146
127
  self.model = builder.finalize()
147
128
  self.model.ground = True
148
129
  self.model.soft_contact_ke = 1.0e4
@@ -154,8 +135,6 @@ class Example:
154
135
  self.integrator = wp.sim.XPBDIntegrator(iterations=1)
155
136
  else:
156
137
  self.integrator = wp.sim.VBDIntegrator(self.model, iterations=1)
157
- # we need to give VBD coloring information
158
- self.model.particle_coloring = color_lattice_grid(width, height)
159
138
 
160
139
  self.state_0 = self.model.state()
161
140
  self.state_1 = self.model.state()
@@ -115,10 +115,11 @@ class Example:
115
115
 
116
116
  self.model.joint_attach_ke = 16000.0
117
117
  self.model.joint_attach_kd = 200.0
118
+ self.use_tile_gemm = False
118
119
 
119
120
  # self.integrator = wp.sim.XPBDIntegrator()
120
121
  # self.integrator = wp.sim.SemiImplicitIntegrator()
121
- self.integrator = wp.sim.FeatherstoneIntegrator(self.model)
122
+ self.integrator = wp.sim.FeatherstoneIntegrator(self.model, use_tile_gemm=self.use_tile_gemm)
122
123
 
123
124
  if stage_path:
124
125
  self.renderer = wp.sim.render.SimRenderer(self.model, stage_path)
@@ -0,0 +1,58 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ ###########################################################################
9
+ # Example Tile Convolution
10
+ #
11
+ # Shows how to write a simple convolution kernel using Warp FFT tile
12
+ # primitives.
13
+ #
14
+ ###########################################################################
15
+
16
+ import numpy as np
17
+
18
+ import warp as wp
19
+
20
+ wp.set_module_options({"enable_backward": False})
21
+
22
+ BLOCK_DIM = 64
23
+ TILE_M = 1
24
+ TILE_N = 128
25
+
26
+ scale = wp.vec2d(wp.float64(1 / TILE_N), wp.float64(1 / TILE_N))
27
+
28
+
29
+ @wp.func
30
+ def filter(x: wp.vec2d):
31
+ return wp.cw_mul(x, scale)
32
+
33
+
34
+ @wp.kernel
35
+ def conv_tiled(x: wp.array2d(dtype=wp.vec2d), y: wp.array2d(dtype=wp.vec2d)):
36
+ i, j, _ = wp.tid()
37
+ a = wp.tile_load(x, i, j, m=TILE_M, n=TILE_N)
38
+ wp.tile_fft(a)
39
+ b = wp.tile_map(filter, a)
40
+ wp.tile_ifft(b)
41
+ wp.tile_store(y, i, j, b)
42
+
43
+
44
+ if __name__ == "__main__":
45
+ wp.set_device("cuda:0")
46
+
47
+ rng = np.random.default_rng(42)
48
+
49
+ x_h = rng.standard_normal((TILE_M, TILE_N, 2), dtype=np.float64)
50
+ y_h = np.zeros_like(x_h)
51
+
52
+ x_wp = wp.array2d(x_h, dtype=wp.vec2d)
53
+ y_wp = wp.array2d(y_h, dtype=wp.vec2d)
54
+
55
+ wp.launch_tiled(conv_tiled, dim=[1, 1], inputs=[x_wp], outputs=[y_wp], block_dim=BLOCK_DIM)
56
+
57
+ # Since filter is 1/N, conv_tiled is a ~no-op
58
+ assert np.allclose(x_h, y_wp.numpy())
@@ -0,0 +1,47 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ ###########################################################################
9
+ # Example Tile FFT
10
+ #
11
+ # Shows how to write a simple FFT kernel using Warp tile primitives.
12
+ #
13
+ ###########################################################################
14
+
15
+ import numpy as np
16
+
17
+ import warp as wp
18
+
19
+ wp.set_module_options({"enable_backward": False})
20
+
21
+ BLOCK_DIM = 8
22
+ TILE_M = 1
23
+ TILE_N = 32
24
+
25
+
26
+ @wp.kernel
27
+ def fft_tiled(x: wp.array2d(dtype=wp.vec2d), y: wp.array2d(dtype=wp.vec2d)):
28
+ i, j, _ = wp.tid()
29
+ a = wp.tile_load(x, i, j, m=TILE_M, n=TILE_N)
30
+ wp.tile_fft(a)
31
+ wp.tile_ifft(a)
32
+ wp.tile_store(y, i, j, a)
33
+
34
+
35
+ if __name__ == "__main__":
36
+ wp.set_device("cuda:0")
37
+
38
+ x_h = np.ones((TILE_M, TILE_N, 2), dtype=np.float64)
39
+ x_h[:, :, 1] = 0
40
+ y_h = 3 * np.ones((TILE_M, TILE_N, 2), dtype=np.float64)
41
+ x_wp = wp.array2d(x_h, dtype=wp.vec2d)
42
+ y_wp = wp.array2d(y_h, dtype=wp.vec2d)
43
+
44
+ wp.launch_tiled(fft_tiled, dim=[1, 1], inputs=[x_wp], outputs=[y_wp], block_dim=BLOCK_DIM)
45
+
46
+ print("Inputs:\n", x_wp) # [1+0i, 1+0i, 1+0i, ...]
47
+ print("Output:\n", y_wp) # [32+0i, 0, 0, ...]
@@ -0,0 +1,105 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ ###########################################################################
9
+ # Example Tile Filtering
10
+ #
11
+ # Shows how to write a simple filtering kernel using Warp FFT tile
12
+ # primitives.
13
+ #
14
+ ###########################################################################
15
+
16
+ import numpy as np
17
+
18
+ import warp as wp
19
+
20
+ wp.set_module_options({"enable_backward": False})
21
+
22
+ BLOCK_DIM = 128
23
+ TILE_M = 1
24
+ TILE_N = 512
25
+
26
+ scale = wp.vec2d(wp.float64(1 / TILE_N), wp.float64(1 / TILE_N))
27
+
28
+
29
+ def cplx(array):
30
+ return array[..., 0] + 1j * array[..., 1]
31
+
32
+
33
+ @wp.func
34
+ def cplx_prod(x: wp.vec2d, y: wp.vec2d):
35
+ return wp.cw_mul(wp.vec2d(x[0] * y[0] - x[1] * y[1], x[0] * y[1] + x[1] * y[0]), scale)
36
+
37
+
38
+ @wp.kernel
39
+ def conv_tiled(x: wp.array2d(dtype=wp.vec2d), y: wp.array2d(dtype=wp.vec2d), z: wp.array2d(dtype=wp.vec2d)):
40
+ i, j, _ = wp.tid()
41
+ a = wp.tile_load(x, i, j, m=TILE_M, n=TILE_N)
42
+ b = wp.tile_load(y, i, j, m=TILE_M, n=TILE_N)
43
+ wp.tile_fft(a)
44
+ c = wp.tile_map(cplx_prod, a, b)
45
+ wp.tile_ifft(c)
46
+ wp.tile_store(z, i, j, c)
47
+
48
+
49
+ if __name__ == "__main__":
50
+ rng = np.random.default_rng(42)
51
+
52
+ # Create noisy input signal
53
+ t = np.linspace(0, 2 * np.pi, TILE_N, dtype=np.float64)
54
+ x = np.sin(t) + 0.5 * rng.random(TILE_N, dtype=np.float64)
55
+
56
+ # Create filter. This filter keeps only ~10% of the frequencies at the center
57
+ # of the spectrum.
58
+ f = np.ones_like(x)
59
+ freq = np.fft.fftfreq(TILE_N)
60
+ f[np.abs(freq) > 0.05] = 0.0
61
+ f[np.abs(freq) <= 0.05] = 1.0
62
+
63
+ # Create Warp input data
64
+ # We use vec2d to hold complex numbers
65
+ x_h = np.zeros((TILE_M, TILE_N, 2), dtype=np.float64)
66
+ f_h = np.zeros_like(x_h)
67
+ y_h = np.zeros_like(f_h)
68
+
69
+ x_h[:, :, 0] = x
70
+ f_h[:, :, 0] = f
71
+
72
+ x_wp = wp.array2d(x_h, dtype=wp.vec2d)
73
+ f_wp = wp.array2d(f_h, dtype=wp.vec2d)
74
+ y_wp = wp.array2d(y_h, dtype=wp.vec2d)
75
+
76
+ wp.launch_tiled(conv_tiled, dim=[1, 1], inputs=[x_wp, f_wp], outputs=[y_wp], block_dim=BLOCK_DIM)
77
+
78
+ # Extract output and compare with numpy
79
+ x_np = cplx(x_h)
80
+ f_np = cplx(f_h)
81
+ y_test = cplx(y_wp.numpy())
82
+ y_ref = np.fft.ifft(f_np * np.fft.fft(x_np))
83
+ assert np.allclose(y_ref, y_test)
84
+
85
+ try:
86
+ import matplotlib.pyplot as plt
87
+
88
+ fig, ax = plt.subplots(figsize=(10, 5))
89
+
90
+ ax.plot(
91
+ x,
92
+ color="#DDDDDD",
93
+ linewidth=2,
94
+ label="Original",
95
+ )
96
+ ax.plot(y_test[0, :].real, color="#76B900", linewidth=3, label="Smoothed")
97
+
98
+ ax.legend()
99
+ ax.grid(True)
100
+
101
+ plt.tight_layout()
102
+ plt.show()
103
+
104
+ except ModuleNotFoundError:
105
+ print("Matplotlib not available; skipping figure")
@@ -0,0 +1,79 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ ###########################################################################
9
+ # Example Tile MatMul
10
+ #
11
+ # Shows how to write a simple GEMM kernel using Warp tile primitives.
12
+ #
13
+ ###########################################################################
14
+
15
+ import numpy as np
16
+
17
+ import warp as wp
18
+
19
+ # tile size
20
+ TILE_M = wp.constant(8)
21
+ TILE_N = wp.constant(4)
22
+ TILE_K = wp.constant(8)
23
+
24
+ # num threads per-tile
25
+ TILE_THREADS = 64
26
+
27
+
28
+ @wp.kernel
29
+ def tile_gemm(A: wp.array2d(dtype=wp.float32), B: wp.array2d(dtype=wp.float16), C: wp.array2d(dtype=wp.float64)):
30
+ # output tile index
31
+ i, j = wp.tid()
32
+
33
+ sum = wp.tile_zeros(m=TILE_M, n=TILE_N, dtype=wp.float64)
34
+
35
+ _M = A.shape[0]
36
+ _N = B.shape[1]
37
+ K = A.shape[1]
38
+
39
+ count = int(K / TILE_K)
40
+
41
+ for k in range(0, count):
42
+ a = wp.tile_load(A, i, k, m=TILE_M, n=TILE_K)
43
+ b = wp.tile_load(B, k, j, m=TILE_K, n=TILE_N)
44
+
45
+ # sum += a*b
46
+ wp.tile_matmul(a, b, sum)
47
+
48
+ wp.tile_store(C, i, j, sum)
49
+
50
+
51
+ if __name__ == "__main__":
52
+ wp.set_device("cuda:0")
53
+
54
+ # generate some tile aligned matrix dimensions
55
+ M = TILE_M * 7
56
+ K = TILE_K * 6
57
+ N = TILE_N * 5
58
+
59
+ rng = np.random.default_rng(42)
60
+ A = rng.random((M, K), dtype=np.float32)
61
+ B = rng.random((K, N), dtype=np.float32).astype(np.float16)
62
+ C = np.zeros((M, N), dtype=np.float64)
63
+
64
+ A_wp = wp.array(A, requires_grad=True)
65
+ B_wp = wp.array(B, requires_grad=True)
66
+ C_wp = wp.array(C, requires_grad=True)
67
+
68
+ with wp.Tape() as tape:
69
+ wp.launch_tiled(
70
+ tile_gemm,
71
+ dim=(int(M / TILE_M), int(N / TILE_N)),
72
+ inputs=[A_wp, B_wp],
73
+ outputs=[C_wp],
74
+ block_dim=TILE_THREADS,
75
+ )
76
+
77
+ assert np.allclose(C_wp.numpy(), A @ B)
78
+
79
+ print("Example matrix multiplication passed")