warp-lang 0.15.0-py3-none-manylinux2014_x86_64.whl → 1.0.0-py3-none-manylinux2014_x86_64.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.


Files changed (80)
  1. warp/__init__.py +1 -0
  2. warp/codegen.py +7 -3
  3. warp/config.py +2 -1
  4. warp/constants.py +3 -0
  5. warp/context.py +44 -21
  6. warp/examples/assets/bunny.usd +0 -0
  7. warp/examples/assets/cartpole.urdf +110 -0
  8. warp/examples/assets/crazyflie.usd +0 -0
  9. warp/examples/assets/cube.usda +42 -0
  10. warp/examples/assets/nv_ant.xml +92 -0
  11. warp/examples/assets/nv_humanoid.xml +183 -0
  12. warp/examples/assets/quadruped.urdf +268 -0
  13. warp/examples/assets/rocks.nvdb +0 -0
  14. warp/examples/assets/rocks.usd +0 -0
  15. warp/examples/assets/sphere.usda +56 -0
  16. warp/examples/assets/torus.usda +105 -0
  17. warp/examples/core/example_dem.py +6 -6
  18. warp/examples/core/example_fluid.py +3 -3
  19. warp/examples/core/example_graph_capture.py +3 -6
  20. warp/examples/optim/example_bounce.py +9 -8
  21. warp/examples/optim/example_cloth_throw.py +12 -8
  22. warp/examples/optim/example_diffray.py +10 -12
  23. warp/examples/optim/example_drone.py +31 -14
  24. warp/examples/optim/example_spring_cage.py +10 -15
  25. warp/examples/optim/example_trajectory.py +7 -24
  26. warp/examples/sim/example_cartpole.py +3 -9
  27. warp/examples/sim/example_cloth.py +10 -10
  28. warp/examples/sim/example_granular.py +3 -3
  29. warp/examples/sim/example_granular_collision_sdf.py +9 -4
  30. warp/examples/sim/example_jacobian_ik.py +0 -10
  31. warp/examples/sim/example_particle_chain.py +4 -4
  32. warp/examples/sim/example_quadruped.py +15 -11
  33. warp/examples/sim/example_rigid_chain.py +13 -8
  34. warp/examples/sim/example_rigid_contact.py +4 -4
  35. warp/examples/sim/example_rigid_force.py +7 -7
  36. warp/examples/sim/example_rigid_soft_contact.py +4 -4
  37. warp/examples/sim/example_soft_body.py +3 -3
  38. warp/jax.py +45 -0
  39. warp/jax_experimental.py +339 -0
  40. warp/render/render_opengl.py +188 -95
  41. warp/render/render_usd.py +34 -10
  42. warp/sim/__init__.py +13 -4
  43. warp/sim/articulation.py +4 -5
  44. warp/sim/collide.py +320 -175
  45. warp/sim/import_mjcf.py +25 -30
  46. warp/sim/import_urdf.py +94 -63
  47. warp/sim/import_usd.py +51 -36
  48. warp/sim/inertia.py +3 -2
  49. warp/sim/integrator.py +233 -0
  50. warp/sim/integrator_euler.py +447 -469
  51. warp/sim/integrator_featherstone.py +1991 -0
  52. warp/sim/integrator_xpbd.py +1420 -640
  53. warp/sim/model.py +741 -487
  54. warp/sim/particles.py +2 -1
  55. warp/sim/render.py +18 -2
  56. warp/sim/utils.py +222 -11
  57. warp/stubs.py +1 -0
  58. warp/tape.py +6 -9
  59. warp/tests/test_examples.py +87 -20
  60. warp/tests/test_grad_customs.py +122 -0
  61. warp/tests/test_jax.py +254 -0
  62. warp/tests/test_options.py +13 -53
  63. warp/tests/test_quat.py +23 -0
  64. warp/tests/test_snippet.py +2 -0
  65. warp/tests/test_utils.py +31 -26
  66. warp/tests/test_verify_fp.py +65 -0
  67. warp/tests/unittest_suites.py +4 -0
  68. warp/utils.py +50 -1
  69. {warp_lang-0.15.0.dist-info → warp_lang-1.0.0.dist-info}/METADATA +1 -1
  70. {warp_lang-0.15.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +73 -64
  71. warp/examples/env/__init__.py +0 -0
  72. warp/examples/env/env_ant.py +0 -61
  73. warp/examples/env/env_cartpole.py +0 -63
  74. warp/examples/env/env_humanoid.py +0 -65
  75. warp/examples/env/env_usd.py +0 -97
  76. warp/examples/env/environment.py +0 -526
  77. warp/sim/optimizer.py +0 -138
  78. {warp_lang-0.15.0.dist-info → warp_lang-1.0.0.dist-info}/LICENSE.md +0 -0
  79. {warp_lang-0.15.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
  80. {warp_lang-0.15.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
@@ -79,12 +79,12 @@ class Example:
             floating=True,
             density=1000,
             armature=0.01,
-            stiffness=120,
+            stiffness=200,
             damping=1,
-            shape_ke=1.0e4,
-            shape_kd=1.0e2,
-            shape_kf=1.0e2,
-            shape_mu=0.0,
+            contact_ke=1.0e4,
+            contact_kd=1.0e2,
+            contact_kf=1.0e2,
+            contact_mu=1.0,
             limit_ke=1.0e4,
             limit_kd=1.0e1,
         )
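
Note on the hunk above: the shape contact parameters accepted by the warp.sim importers are renamed in 1.0, from shape_ke/shape_kd/shape_kf/shape_mu to contact_ke/contact_kd/contact_kf/contact_mu (this example also raises the friction coefficient from 0.0 to 1.0 and the joint stiffness from 120 to 200). A minimal 1.0-style sketch, assuming the loader is wp.sim.parse_urdf as in this example; the asset path is a placeholder:

    import warp as wp
    import warp.sim

    builder = wp.sim.ModelBuilder()

    wp.sim.parse_urdf(
        "quadruped.urdf",  # placeholder path, substitute your own asset
        builder,
        floating=True,
        density=1000,
        armature=0.01,
        stiffness=200,
        damping=1,
        # renamed in 1.0 (previously shape_ke, shape_kd, shape_kf, shape_mu)
        contact_ke=1.0e4,
        contact_kd=1.0e2,
        contact_kf=1.0e2,
        contact_mu=1.0,
        limit_ke=1.0e4,
        limit_kd=1.0e1,
    )
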
@@ -106,17 +106,21 @@ class Example:
 
         builder.joint_q[-12:] = [0.2, 0.4, -0.6, -0.2, -0.4, 0.6, -0.2, 0.4, -0.6, 0.2, -0.4, 0.6]
 
-        builder.joint_target[-12:] = [0.2, 0.4, -0.6, -0.2, -0.4, 0.6, -0.2, 0.4, -0.6, 0.2, -0.4, 0.6]
+        builder.joint_axis_mode = [wp.sim.JOINT_MODE_TARGET_POSITION] * len(builder.joint_axis_mode)
+        builder.joint_act[-12:] = [0.2, 0.4, -0.6, -0.2, -0.4, 0.6, -0.2, 0.4, -0.6, 0.2, -0.4, 0.6]
 
         np.set_printoptions(suppress=True)
         # finalize model
         self.model = builder.finalize()
         self.model.ground = True
+        # self.model.gravity = 0.0
 
         self.model.joint_attach_ke = 16000.0
         self.model.joint_attach_kd = 200.0
 
-        self.integrator = wp.sim.XPBDIntegrator()
+        # self.integrator = wp.sim.XPBDIntegrator()
+        # self.integrator = wp.sim.SemiImplicitIntegrator()
+        self.integrator = wp.sim.FeatherstoneIntegrator(self.model)
 
         self.renderer = None
         if stage:
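
Note on the hunk above: per-axis joint targets are now written to builder.joint_act together with an explicit builder.joint_axis_mode (replacing builder.joint_target), and the new FeatherstoneIntegrator is constructed from the finalized model. A minimal sketch of the 1.0 pattern, assuming a builder that already contains the articulation (the 12-element slice mirrors this example):

    import warp as wp
    import warp.sim

    builder = wp.sim.ModelBuilder()
    # ... add bodies and joints here, e.g. via wp.sim.parse_urdf(...) ...

    # drive every joint axis toward a target position (builder.joint_target in 0.15.x)
    builder.joint_axis_mode = [wp.sim.JOINT_MODE_TARGET_POSITION] * len(builder.joint_axis_mode)
    builder.joint_act[-12:] = [0.2, 0.4, -0.6, -0.2, -0.4, 0.6, -0.2, 0.4, -0.6, 0.2, -0.4, 0.6]

    model = builder.finalize()
    model.ground = True

    # new in 1.0: the Featherstone integrator takes the model at construction time
    integrator = wp.sim.FeatherstoneIntegrator(model)
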
@@ -130,10 +134,11 @@ class Example:
         wp.sim.eval_fk(self.model, self.model.joint_q, self.model.joint_qd, None, self.state_0)
 
         self.use_graph = wp.get_device().is_cuda
+        self.graph = None
         if self.use_graph:
-            wp.capture_begin()
-            self.simulate()
-            self.graph = wp.capture_end()
+            with wp.ScopedCapture() as capture:
+                self.simulate()
+            self.graph = capture.graph
 
     def simulate(self):
         for _ in range(self.sim_substeps):
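
Note: the graph-capture change in this hunk recurs throughout the updated examples. The explicit wp.capture_begin()/wp.capture_end() pair is replaced by the wp.ScopedCapture context manager, which ends the capture when the block exits and exposes the result as capture.graph. A minimal sketch of the pattern (simulate() stands in for an example's substep loop):

    import warp as wp

    wp.init()

    def simulate():
        # placeholder for the substep loop of kernel launches to record
        pass

    use_graph = wp.get_device().is_cuda
    graph = None

    if use_graph:
        # 1.0 style: the capture is finalized automatically when the block exits
        with wp.ScopedCapture() as capture:
            simulate()
        graph = capture.graph

    # per frame
    if use_graph:
        wp.capture_launch(graph)
    else:
        simulate()
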
@@ -172,4 +177,3 @@ if __name__ == "__main__":
 
     if example.renderer:
         example.renderer.save()
-
@@ -45,7 +45,7 @@ class Example:
         episode_duration = 5.0  # seconds
         self.episode_frames = int(episode_duration / self.frame_dt)
 
-        self.sim_substeps = 32  # 5
+        self.sim_substeps = 10
         self.sim_dt = self.frame_dt / self.sim_substeps
 
         for c, t in enumerate(self.chain_types):
@@ -88,8 +88,8 @@ class Example:
                     limit_upper=joint_limit_upper,
                     target_ke=0.0,
                     target_kd=0.0,
-                    limit_ke=30.0,
-                    limit_kd=30.0,
+                    limit_ke=1e5,
+                    limit_kd=1.0,
                 )
 
             elif joint_type == wp.sim.JOINT_UNIVERSAL:
@@ -128,11 +128,11 @@ class Example:
                     parent_xform=parent_joint_xform,
                     child_xform=wp.transform_identity(),
                 )
-        # finalize model
+
         self.model = builder.finalize()
         self.model.ground = False
 
-        self.integrator = wp.sim.XPBDIntegrator(iterations=5)
+        self.integrator = wp.sim.FeatherstoneIntegrator(self.model)
 
         self.renderer = None
         if stage:
@@ -145,9 +145,9 @@ class Example:
 
         self.use_graph = wp.get_device().is_cuda
         if self.use_graph:
-            wp.capture_begin()
-            self.simulate()
-            self.graph = wp.capture_end()
+            with wp.ScopedCapture() as capture:
+                self.simulate()
+            self.graph = capture.graph
 
     def simulate(self):
         for _ in range(self.sim_substeps):
@@ -167,6 +167,9 @@ class Example:
         if self.renderer is None:
            return
 
+        if self.renderer is None:
+            return
+
        with wp.ScopedTimer("render", active=True):
            self.renderer.begin_frame(self.sim_time)
            self.renderer.render(self.state_0)
@@ -184,3 +187,5 @@ if __name__ == "__main__":
 
     if example.renderer:
         example.renderer.save()
+    if example.renderer:
+        example.renderer.save()
@@ -125,9 +125,9 @@ class Example:
 
         self.use_graph = wp.get_device().is_cuda
         if self.use_graph:
-            wp.capture_begin()
-            self.simulate()
-            self.graph = wp.capture_end()
+            with wp.ScopedCapture() as capture:
+                self.simulate()
+            self.graph = capture.graph
 
     def load_mesh(self, filename, path):
         asset_stage = Usd.Stage.Open(filename)
@@ -156,7 +156,7 @@ class Example:
     def render(self):
         if self.renderer is None:
             return
-
+
         with wp.ScopedTimer("render", active=True):
             self.renderer.begin_frame(self.sim_time)
             self.renderer.render(self.state_0)
@@ -25,7 +25,7 @@ wp.init()
 
 class Example:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--opengl", action='store_true')
+    parser.add_argument("--opengl", action="store_true")
 
     def __init__(self, stage=None, args=None, **kwargs):
         if args is None:
@@ -33,7 +33,7 @@ class Example:
             args = argparse.Namespace(**kwargs)
         args = Example.parser.parse_args(args=[], namespace=args)
         self._args = args
-
+
         self.sim_fps = 60.0
         self.sim_substeps = 5
         self.sim_duration = 5.0
@@ -64,9 +64,9 @@ class Example:
 
         self.use_graph = wp.get_device().is_cuda
         if self.use_graph:
-            wp.capture_begin()
-            self.simulate()
-            self.graph = wp.capture_end()
+            with wp.ScopedCapture() as capture:
+                self.simulate()
+            self.graph = capture.graph
 
     def simulate(self):
         for _ in range(self.sim_substeps):
@@ -77,7 +77,7 @@ class Example:
 
         self.state_0.body_f.assign(
             [
-                [0.0, 0.0, -10000.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, -7000.0, 0.0, 0.0, 0.0],
             ]
         )
 
@@ -123,4 +123,4 @@ if __name__ == "__main__":
     if example.renderer:
         example.renderer.save()
 
-    example.renderer = None
+    example.renderer = None
@@ -77,9 +77,9 @@ class Example:
 
         self.use_graph = wp.get_device().is_cuda
         if self.use_graph:
-            wp.capture_begin()
-            self.simulate()
-            self.graph = wp.capture_end()
+            with wp.ScopedCapture() as capture:
+                self.simulate()
+            self.graph = capture.graph
 
     def simulate(self):
         for s in range(self.sim_substeps):
@@ -98,7 +98,7 @@ class Example:
         if self.use_graph:
             wp.capture_launch(self.graph)
         else:
-            self.simulate()
+            self.simulate()
         self.sim_time += self.frame_dt
 
     def render(self):
@@ -118,9 +118,9 @@ class Example:
 
         self.use_graph = wp.get_device().is_cuda
         if self.use_graph:
-            wp.capture_begin()
-            self.simulate()
-            self.graph = wp.capture_end()
+            with wp.ScopedCapture() as capture:
+                self.simulate()
+            self.graph = capture.graph
 
     def simulate(self):
         for _ in range(self.sim_substeps):
warp/jax.py CHANGED
@@ -6,6 +6,7 @@
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
 import warp
+from warp.context import type_str
 
 
 def device_to_jax(wp_device):
@@ -34,6 +35,50 @@ def device_from_jax(jax_device):
     raise RuntimeError(f"Unknown or unsupported Jax device platform '{jax_device.platform}'")
 
 
+def dtype_to_jax(wp_dtype):
+    import jax.numpy as jp
+
+    warp_to_jax_dict = {
+        warp.float16: jp.float16,
+        warp.float32: jp.float32,
+        warp.float64: jp.float64,
+        warp.int8: jp.int8,
+        warp.int16: jp.int16,
+        warp.int32: jp.int32,
+        warp.int64: jp.int64,
+        warp.uint8: jp.uint8,
+        warp.uint16: jp.uint16,
+        warp.uint32: jp.uint32,
+        warp.uint64: jp.uint64,
+    }
+    jax_dtype = warp_to_jax_dict.get(wp_dtype)
+    if jax_dtype is None:
+        raise TypeError(f"Invalid or unsupported data type: {type_str(wp_dtype)}")
+    return jax_dtype
+
+
+def dtype_from_jax(jax_dtype):
+    import jax.numpy as jp
+
+    jax_to_warp_dict = {
+        jp.float16: warp.float16,
+        jp.float32: warp.float32,
+        jp.float64: warp.float64,
+        jp.int8: warp.int8,
+        jp.int16: warp.int16,
+        jp.int32: warp.int32,
+        jp.int64: warp.int64,
+        jp.uint8: warp.uint8,
+        jp.uint16: warp.uint16,
+        jp.uint32: warp.uint32,
+        jp.uint64: warp.uint64,
+    }
+    wp_dtype = jax_to_warp_dict.get(jax_dtype)
+    if wp_dtype is None:
+        raise TypeError(f"Invalid or unsupported data type: {jax_dtype}")
+    return wp_dtype
+
+
 def to_jax(wp_array):
     import jax.dlpack
 
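
Note: the dtype_to_jax()/dtype_from_jax() helpers added above map Warp scalar types to jax.numpy types and back, raising TypeError for anything outside the lookup tables. A small usage sketch (requires JAX to be installed):

    import warp as wp
    import jax.numpy as jp
    from warp.jax import dtype_to_jax, dtype_from_jax

    assert dtype_to_jax(wp.float32) is jp.float32
    assert dtype_from_jax(jp.int64) is wp.int64

    # unsupported types (e.g. vectors) raise TypeError
    try:
        dtype_to_jax(wp.vec3)
    except TypeError as err:
        print(err)
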
@@ -0,0 +1,339 @@
+# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import ctypes
+import warp as wp
+from warp.types import array_t, launch_bounds_t, strides_from_shape
+from warp.context import type_str
+import jax
+import jax.numpy as jp
+
+_jax_warp_p = None
+
+# Holder for the custom callback to keep it alive.
+_cc_callback = None
+_registered_kernels = [None]
+_registered_kernel_to_id = {}
+
+
+def jax_kernel(wp_kernel):
+    """Create a Jax primitive from a Warp kernel.
+
+    NOTE: This is an experimental feature under development.
+
+    Current limitations:
+    - All kernel arguments must be arrays.
+    - Kernel launch dimensions are inferred from the shape of the first argument.
+    - Input arguments are followed by output arguments in the Warp kernel definition.
+    - There must be at least one input argument and at least one output argument.
+    - Output shapes must match the launch dimensions (i.e., output shapes must match the shape of the first argument).
+    - All arrays must be contiguous.
+    - Only the CUDA backend is supported.
+    """
+
+    if _jax_warp_p == None:
+        # Create and register the primitive
+        _create_jax_warp_primitive()
+    if not wp_kernel in _registered_kernel_to_id:
+        id = len(_registered_kernels)
+        _registered_kernels.append(wp_kernel)
+        _registered_kernel_to_id[wp_kernel] = id
+    else:
+        id = _registered_kernel_to_id[wp_kernel]
+
+    def bind(*args):
+        return _jax_warp_p.bind(*args, kernel=id)
+
+    return bind
+
+
+def _warp_custom_callback(stream, buffers, opaque, opaque_len):
+    # The descriptor is the form
+    # <kernel-id>|<launch-dims>|<arg-dims-list>
+    # Example: 42|16,32|16,32;100;16,32
+    kernel_id_str, dim_str, args_str = opaque.decode().split("|")
+
+    # Get the kernel from the registry.
+    kernel_id = int(kernel_id_str)
+    kernel = _registered_kernels[kernel_id]
+
+    # Parse launch dimensions.
+    dims = [int(d) for d in dim_str.split(",")]
+    bounds = launch_bounds_t(dims)
+
+    # Parse arguments.
+    arg_strings = args_str.split(";")
+    num_args = len(arg_strings)
+    assert num_args == len(kernel.adj.args), "Incorrect number of arguments"
+
+    # First param is the launch bounds.
+    kernel_params = (ctypes.c_void_p * (1 + num_args))()
+    kernel_params[0] = ctypes.addressof(bounds)
+
+    # Parse array descriptors.
+    args = []
+    for i in range(num_args):
+        dtype = kernel.adj.args[i].type.dtype
+        shape = [int(d) for d in arg_strings[i].split(",")]
+        strides = strides_from_shape(shape, dtype)
+
+        arr = array_t(buffers[i], 0, len(shape), shape, strides)
+        args.append(arr)  # keep a reference
+        arg_ptr = ctypes.addressof(arr)
+
+        kernel_params[i + 1] = arg_ptr
+
+    # Get current device.
+    device = wp.device_from_jax(_get_jax_device())
+
+    # Get kernel hooks.
+    # Note: module was loaded during jit lowering.
+    hooks = kernel.module.get_kernel_hooks(kernel, device)
+    assert hooks.forward, "Failed to find kernel entry point"
+
+    # Launch the kernel.
+    wp.context.runtime.core.cuda_launch_kernel(
+        device.context, hooks.forward, bounds.size, 0, kernel_params, stream
+    )
+
+
+# TODO: is there a simpler way of getting the Jax "current" device?
+def _get_jax_device():
+    # check if jax.default_device() context manager is active
+    device = jax.config.jax_default_device
+    # if default device is not set, use first device
+    if device is None:
+        device = jax.devices()[0]
+    return device
+
+
+def _create_jax_warp_primitive():
+    from functools import reduce
+    import jax
+    from jax._src.interpreters import batching
+    from jax.interpreters import mlir
+    from jax.interpreters.mlir import ir
+    from jaxlib.hlo_helpers import custom_call
+
+    global _jax_warp_p
+    global _cc_callback
+
+    # Create and register the primitive.
+    # TODO add default implementation that calls the kernel via warp.
+    _jax_warp_p = jax.core.Primitive("jax_warp")
+    _jax_warp_p.multiple_results = True
+
+    # TODO Just launch the kernel directly, but make sure the argument
+    # shapes are massaged the same way as below so that vmap works.
+    def impl(*args):
+        raise Exception("Not implemented")
+
+    _jax_warp_p.def_impl(impl)
+
+    # Auto-batching. Make sure all the arguments are fully broadcasted
+    # so that Warp is not confused about dimensions.
+    def vectorized_multi_batcher(args, dims, **params):
+        # Figure out the number of outputs.
+        wp_kernel = _registered_kernels[params["kernel"]]
+        output_count = len(wp_kernel.adj.args) - len(args)
+        shape, dim = next((a.shape, d) for a, d in zip(args, dims) if d is not None)
+        size = shape[dim]
+        args = [batching.bdim_at_front(a, d, size) if len(a.shape) else a for a, d in zip(args, dims)]
+        # Create the batched primitive.
+        return _jax_warp_p.bind(*args, **params), [dims[0]] * output_count
+
+    batching.primitive_batchers[_jax_warp_p] = vectorized_multi_batcher
+
+    def get_vecmat_shape(warp_type):
+        if hasattr(warp_type.dtype, "_shape_"):
+            return warp_type.dtype._shape_
+        return []
+
+    def strip_vecmat_dimensions(warp_arg, actual_shape):
+        shape = get_vecmat_shape(warp_arg.type)
+        for i, s in enumerate(reversed(shape)):
+            item = actual_shape[-i - 1]
+            if s != item:
+                raise Exception(f"The vector/matrix shape for argument {warp_arg.label} does not match")
+        return actual_shape[: len(actual_shape) - len(shape)]
+
+    def collapse_into_leading_dimension(warp_arg, actual_shape):
+        if len(actual_shape) < warp_arg.type.ndim:
+            raise Exception(f"Argument {warp_arg.label} has too few non-matrix/vector dimensions")
+        index_rest = len(actual_shape) - warp_arg.type.ndim + 1
+        leading_size = reduce(lambda x, y: x * y, actual_shape[:index_rest])
+        return [leading_size] + actual_shape[index_rest:]
+
+    # Infer array dimensions from input type.
+    def infer_dimensions(warp_arg, actual_shape):
+        actual_shape = strip_vecmat_dimensions(warp_arg, actual_shape)
+        return collapse_into_leading_dimension(warp_arg, actual_shape)
+
+    def base_type_to_jax(warp_dtype):
+        if hasattr(warp_dtype, "_wp_scalar_type_"):
+            return wp.jax.dtype_to_jax(warp_dtype._wp_scalar_type_)
+        return wp.jax.dtype_to_jax(warp_dtype)
+
+    def base_type_to_jax_ir(warp_dtype):
+        warp_to_jax_dict = {
+            wp.float16: ir.F16Type.get(),
+            wp.float32: ir.F32Type.get(),
+            wp.float64: ir.F64Type.get(),
+            wp.int8: ir.IntegerType.get_signless(8),
+            wp.int16: ir.IntegerType.get_signless(16),
+            wp.int32: ir.IntegerType.get_signless(32),
+            wp.int64: ir.IntegerType.get_signless(64),
+            wp.uint8: ir.IntegerType.get_unsigned(8),
+            wp.uint16: ir.IntegerType.get_unsigned(16),
+            wp.uint32: ir.IntegerType.get_unsigned(32),
+            wp.uint64: ir.IntegerType.get_unsigned(64),
+        }
+        if hasattr(warp_dtype, "_wp_scalar_type_"):
+            warp_dtype = warp_dtype._wp_scalar_type_
+        jax_dtype = warp_to_jax_dict.get(warp_dtype)
+        if jax_dtype is None:
+            raise TypeError(f"Invalid or unsupported data type: {warp_dtype}")
+        return jax_dtype
+
+    def base_type_is_compatible(warp_type, jax_ir_type):
+        jax_ir_to_warp = {
+            "f16": wp.float16,
+            "f32": wp.float32,
+            "f64": wp.float64,
+            "i8": wp.int8,
+            "i16": wp.int16,
+            "i32": wp.int32,
+            "i64": wp.int64,
+            "ui8": wp.uint8,
+            "ui16": wp.uint16,
+            "ui32": wp.uint32,
+            "ui64": wp.uint64,
+        }
+        expected_warp_type = jax_ir_to_warp.get(str(jax_ir_type))
+        if expected_warp_type is not None:
+            if hasattr(warp_type, "_wp_scalar_type_"):
+                return warp_type._wp_scalar_type_ == expected_warp_type
+            else:
+                return warp_type == expected_warp_type
+        else:
+            raise TypeError(f"Invalid or unsupported data type: {jax_ir_type}")
+
+    # Abstract evaluation.
+    def jax_warp_abstract(*args, kernel=None):
+        wp_kernel = _registered_kernels[kernel]
+        # All the extra arguments to the warp kernel are outputs.
+        warp_outputs = [o.type for o in wp_kernel.adj.args[len(args) :]]
+        # TODO. Let's just use the first input dimension to infer the output's dimensions.
+        dims = strip_vecmat_dimensions(wp_kernel.adj.args[0], list(args[0].shape))
+        jax_outputs = []
+        for o in warp_outputs:
+            shape = list(dims) + list(get_vecmat_shape(o))
+            dtype = base_type_to_jax(o.dtype)
+            jax_outputs.append(jax.core.ShapedArray(shape, dtype))
+        return jax_outputs
+
+    _jax_warp_p.def_abstract_eval(jax_warp_abstract)
+
+    # Lowering to MLIR.
+
+    # Create python-land custom call target.
+    CCALLFUNC = ctypes.CFUNCTYPE(
+        ctypes.c_voidp, ctypes.c_void_p, ctypes.POINTER(ctypes.c_void_p), ctypes.c_char_p, ctypes.c_size_t
+    )
+    _cc_callback = CCALLFUNC(_warp_custom_callback)
+    ccall_address = ctypes.cast(_cc_callback, ctypes.c_void_p)
+
+    # Put the custom call into a capsule, as required by XLA.
+    PyCapsule_Destructor = ctypes.CFUNCTYPE(None, ctypes.py_object)
+    PyCapsule_New = ctypes.pythonapi.PyCapsule_New
+    PyCapsule_New.restype = ctypes.py_object
+    PyCapsule_New.argtypes = (ctypes.c_void_p, ctypes.c_char_p, PyCapsule_Destructor)
+    capsule = PyCapsule_New(ccall_address.value, b"xla._CUSTOM_CALL_TARGET", PyCapsule_Destructor(0))
+
+    # Register the callback in XLA.
+    jax.lib.xla_client.register_custom_call_target("warp_call", capsule, platform="gpu")
+
+    def default_layout(shape):
+        return range(len(shape) - 1, -1, -1)
+
+    def warp_call_lowering(ctx, *args, kernel=None):
+        if not kernel:
+            raise Exception("Unknown kernel id " + str(kernel))
+        wp_kernel = _registered_kernels[kernel]
+
+        # TODO This may not be necessary, but it is perhaps better not to be
+        # mucking with kernel loading while already running the workload.
+        module = wp_kernel.module
+        device = wp.device_from_jax(_get_jax_device())
+        if not module.load(device):
+            raise Exception("Could not load kernel on device")
+
+        # Infer dimensions from the first input.
+        warp_arg0 = wp_kernel.adj.args[0]
+        actual_shape0 = ir.RankedTensorType(args[0].type).shape
+        dims = strip_vecmat_dimensions(warp_arg0, actual_shape0)
+        warp_dims = collapse_into_leading_dimension(warp_arg0, dims)
+
+        # Figure out the types and shapes of the input arrays.
+        arg_strings = []
+        operand_layouts = []
+        for actual, warg in zip(args, wp_kernel.adj.args):
+            wtype = warg.type
+            rtt = ir.RankedTensorType(actual.type)
+
+            if not isinstance(wtype, wp.array):
+                raise Exception("Only contiguous arrays are supported for Jax kernel arguments")
+
+            if not base_type_is_compatible(wtype.dtype, rtt.element_type):
+                raise TypeError(f"Incompatible data type for argument '{warg.label}', expected {type_str(wtype.dtype)}, got {rtt.element_type}")
+
+            # Infer array dimension (by removing the vector/matrix dimensions and
+            # collapsing the initial dimensions).
+            shape = infer_dimensions(warg, rtt.shape)
+
+            if len(shape) != wtype.ndim:
+                raise TypeError(f"Incompatible array dimensionality for argument '{warg.label}'")
+
+            arg_strings.append(",".join([str(d) for d in shape]))
+            operand_layouts.append(default_layout(rtt.shape))
+
+        # Figure out the types and shapes of the output arrays.
+        result_types = []
+        result_layouts = []
+        for warg in wp_kernel.adj.args[len(args) :]:
+            wtype = warg.type
+
+            if not isinstance(wtype, wp.array):
+                raise Exception("Only contiguous arrays are supported for Jax kernel arguments")
+
+            # Infer dimensions from the first input.
+            arg_strings.append(",".join([str(d) for d in warp_dims]))
+
+            result_shape = list(dims) + list(get_vecmat_shape(wtype))
+            result_types.append(ir.RankedTensorType.get(result_shape, base_type_to_jax_ir(wtype.dtype)))
+            result_layouts.append(default_layout(result_shape))
+
+        # Build opaque descriptor for callback.
+        shape_str = ",".join([str(d) for d in warp_dims])
+        args_str = ";".join(arg_strings)
+        descriptor = f"{kernel}|{shape_str}|{args_str}"
+
+        out = custom_call(
+            b"warp_call",
+            result_types=result_types,
+            operands=args,
+            backend_config=descriptor.encode("utf-8"),
+            operand_layouts=operand_layouts,
+            result_layouts=result_layouts,
+        ).results
+        return out
+
+    mlir.register_lowering(
+        _jax_warp_p,
+        warp_call_lowering,
+        platform="gpu",
+    )