warp-lang 0.11.0-py3-none-manylinux2014_x86_64.whl → 1.0.0-py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (170)
  1. warp/__init__.py +8 -0
  2. warp/bin/warp-clang.so +0 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +7 -6
  5. warp/build_dll.py +70 -79
  6. warp/builtins.py +10 -6
  7. warp/codegen.py +51 -19
  8. warp/config.py +7 -8
  9. warp/constants.py +3 -0
  10. warp/context.py +948 -245
  11. warp/dlpack.py +198 -113
  12. warp/examples/assets/bunny.usd +0 -0
  13. warp/examples/assets/cartpole.urdf +110 -0
  14. warp/examples/assets/crazyflie.usd +0 -0
  15. warp/examples/assets/cube.usda +42 -0
  16. warp/examples/assets/nv_ant.xml +92 -0
  17. warp/examples/assets/nv_humanoid.xml +183 -0
  18. warp/examples/assets/quadruped.urdf +268 -0
  19. warp/examples/assets/rocks.nvdb +0 -0
  20. warp/examples/assets/rocks.usd +0 -0
  21. warp/examples/assets/sphere.usda +56 -0
  22. warp/examples/assets/torus.usda +105 -0
  23. warp/examples/benchmarks/benchmark_api.py +383 -0
  24. warp/examples/benchmarks/benchmark_cloth.py +279 -0
  25. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
  26. warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
  27. warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
  28. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
  29. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
  30. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
  31. warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
  32. warp/examples/benchmarks/benchmark_launches.py +295 -0
  33. warp/examples/core/example_dem.py +221 -0
  34. warp/examples/core/example_fluid.py +267 -0
  35. warp/examples/core/example_graph_capture.py +129 -0
  36. warp/examples/core/example_marching_cubes.py +177 -0
  37. warp/examples/core/example_mesh.py +154 -0
  38. warp/examples/core/example_mesh_intersect.py +193 -0
  39. warp/examples/core/example_nvdb.py +169 -0
  40. warp/examples/core/example_raycast.py +89 -0
  41. warp/examples/core/example_raymarch.py +178 -0
  42. warp/examples/core/example_render_opengl.py +141 -0
  43. warp/examples/core/example_sph.py +389 -0
  44. warp/examples/core/example_torch.py +181 -0
  45. warp/examples/core/example_wave.py +249 -0
  46. warp/examples/fem/bsr_utils.py +380 -0
  47. warp/examples/fem/example_apic_fluid.py +391 -0
  48. warp/examples/fem/example_convection_diffusion.py +168 -0
  49. warp/examples/fem/example_convection_diffusion_dg.py +209 -0
  50. warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
  51. warp/examples/fem/example_deformed_geometry.py +159 -0
  52. warp/examples/fem/example_diffusion.py +173 -0
  53. warp/examples/fem/example_diffusion_3d.py +152 -0
  54. warp/examples/fem/example_diffusion_mgpu.py +214 -0
  55. warp/examples/fem/example_mixed_elasticity.py +222 -0
  56. warp/examples/fem/example_navier_stokes.py +243 -0
  57. warp/examples/fem/example_stokes.py +192 -0
  58. warp/examples/fem/example_stokes_transfer.py +249 -0
  59. warp/examples/fem/mesh_utils.py +109 -0
  60. warp/examples/fem/plot_utils.py +287 -0
  61. warp/examples/optim/example_bounce.py +248 -0
  62. warp/examples/optim/example_cloth_throw.py +210 -0
  63. warp/examples/optim/example_diffray.py +535 -0
  64. warp/examples/optim/example_drone.py +850 -0
  65. warp/examples/optim/example_inverse_kinematics.py +169 -0
  66. warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
  67. warp/examples/optim/example_spring_cage.py +234 -0
  68. warp/examples/optim/example_trajectory.py +201 -0
  69. warp/examples/sim/example_cartpole.py +128 -0
  70. warp/examples/sim/example_cloth.py +184 -0
  71. warp/examples/sim/example_granular.py +113 -0
  72. warp/examples/sim/example_granular_collision_sdf.py +185 -0
  73. warp/examples/sim/example_jacobian_ik.py +213 -0
  74. warp/examples/sim/example_particle_chain.py +106 -0
  75. warp/examples/sim/example_quadruped.py +179 -0
  76. warp/examples/sim/example_rigid_chain.py +191 -0
  77. warp/examples/sim/example_rigid_contact.py +176 -0
  78. warp/examples/sim/example_rigid_force.py +126 -0
  79. warp/examples/sim/example_rigid_gyroscopic.py +97 -0
  80. warp/examples/sim/example_rigid_soft_contact.py +124 -0
  81. warp/examples/sim/example_soft_body.py +178 -0
  82. warp/fabric.py +29 -20
  83. warp/fem/cache.py +0 -1
  84. warp/fem/dirichlet.py +0 -2
  85. warp/fem/integrate.py +0 -1
  86. warp/jax.py +45 -0
  87. warp/jax_experimental.py +339 -0
  88. warp/native/builtin.h +12 -0
  89. warp/native/bvh.cu +18 -18
  90. warp/native/clang/clang.cpp +8 -3
  91. warp/native/cuda_util.cpp +94 -5
  92. warp/native/cuda_util.h +35 -6
  93. warp/native/cutlass_gemm.cpp +1 -1
  94. warp/native/cutlass_gemm.cu +4 -1
  95. warp/native/error.cpp +66 -0
  96. warp/native/error.h +27 -0
  97. warp/native/mesh.cu +2 -2
  98. warp/native/reduce.cu +4 -4
  99. warp/native/runlength_encode.cu +2 -2
  100. warp/native/scan.cu +2 -2
  101. warp/native/sparse.cu +0 -1
  102. warp/native/temp_buffer.h +2 -2
  103. warp/native/warp.cpp +95 -60
  104. warp/native/warp.cu +1053 -218
  105. warp/native/warp.h +49 -32
  106. warp/optim/linear.py +33 -16
  107. warp/render/render_opengl.py +202 -101
  108. warp/render/render_usd.py +82 -40
  109. warp/sim/__init__.py +13 -4
  110. warp/sim/articulation.py +4 -5
  111. warp/sim/collide.py +320 -175
  112. warp/sim/import_mjcf.py +25 -30
  113. warp/sim/import_urdf.py +94 -63
  114. warp/sim/import_usd.py +51 -36
  115. warp/sim/inertia.py +3 -2
  116. warp/sim/integrator.py +233 -0
  117. warp/sim/integrator_euler.py +447 -469
  118. warp/sim/integrator_featherstone.py +1991 -0
  119. warp/sim/integrator_xpbd.py +1420 -640
  120. warp/sim/model.py +765 -487
  121. warp/sim/particles.py +2 -1
  122. warp/sim/render.py +35 -13
  123. warp/sim/utils.py +222 -11
  124. warp/stubs.py +8 -0
  125. warp/tape.py +16 -1
  126. warp/tests/aux_test_grad_customs.py +23 -0
  127. warp/tests/test_array.py +190 -1
  128. warp/tests/test_async.py +656 -0
  129. warp/tests/test_bool.py +50 -0
  130. warp/tests/test_dlpack.py +164 -11
  131. warp/tests/test_examples.py +166 -74
  132. warp/tests/test_fem.py +8 -1
  133. warp/tests/test_generics.py +15 -5
  134. warp/tests/test_grad.py +1 -1
  135. warp/tests/test_grad_customs.py +172 -12
  136. warp/tests/test_jax.py +254 -0
  137. warp/tests/test_large.py +29 -6
  138. warp/tests/test_launch.py +25 -0
  139. warp/tests/test_linear_solvers.py +20 -3
  140. warp/tests/test_matmul.py +61 -16
  141. warp/tests/test_matmul_lite.py +13 -13
  142. warp/tests/test_mempool.py +186 -0
  143. warp/tests/test_multigpu.py +3 -0
  144. warp/tests/test_options.py +16 -2
  145. warp/tests/test_peer.py +137 -0
  146. warp/tests/test_print.py +3 -1
  147. warp/tests/test_quat.py +23 -0
  148. warp/tests/test_sim_kinematics.py +97 -0
  149. warp/tests/test_snippet.py +126 -3
  150. warp/tests/test_streams.py +108 -79
  151. warp/tests/test_torch.py +16 -8
  152. warp/tests/test_utils.py +32 -27
  153. warp/tests/test_verify_fp.py +65 -0
  154. warp/tests/test_volume.py +1 -1
  155. warp/tests/unittest_serial.py +2 -0
  156. warp/tests/unittest_suites.py +12 -0
  157. warp/tests/unittest_utils.py +14 -7
  158. warp/thirdparty/unittest_parallel.py +15 -3
  159. warp/torch.py +10 -8
  160. warp/types.py +363 -246
  161. warp/utils.py +143 -19
  162. warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
  163. warp_lang-1.0.0.dist-info/METADATA +394 -0
  164. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
  165. warp/sim/optimizer.py +0 -138
  166. warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
  167. warp_lang-0.11.0.dist-info/METADATA +0 -238
  168. /warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
  169. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
  170. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
warp/jax_experimental.py ADDED
@@ -0,0 +1,339 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import ctypes
9
+ import warp as wp
10
+ from warp.types import array_t, launch_bounds_t, strides_from_shape
11
+ from warp.context import type_str
12
+ import jax
13
+ import jax.numpy as jp
14
+
15
+ _jax_warp_p = None
16
+
17
+ # Holder for the custom callback to keep it alive.
18
+ _cc_callback = None
19
+ _registered_kernels = [None]
20
+ _registered_kernel_to_id = {}
21
+
22
+
23
+ def jax_kernel(wp_kernel):
24
+ """Create a Jax primitive from a Warp kernel.
25
+
26
+ NOTE: This is an experimental feature under development.
27
+
28
+ Current limitations:
29
+ - All kernel arguments must be arrays.
30
+ - Kernel launch dimensions are inferred from the shape of the first argument.
31
+ - Input arguments are followed by output arguments in the Warp kernel definition.
32
+ - There must be at least one input argument and at least one output argument.
33
+ - Output shapes must match the launch dimensions (i.e., output shapes must match the shape of the first argument).
34
+ - All arrays must be contiguous.
35
+ - Only the CUDA backend is supported.
36
+ """
37
+
38
+ if _jax_warp_p is None:
39
+ # Create and register the primitive
40
+ _create_jax_warp_primitive()
41
+ if wp_kernel not in _registered_kernel_to_id:
42
+ id = len(_registered_kernels)
43
+ _registered_kernels.append(wp_kernel)
44
+ _registered_kernel_to_id[wp_kernel] = id
45
+ else:
46
+ id = _registered_kernel_to_id[wp_kernel]
47
+
48
+ def bind(*args):
49
+ return _jax_warp_p.bind(*args, kernel=id)
50
+
51
+ return bind
52
+
53
+
54
+ def _warp_custom_callback(stream, buffers, opaque, opaque_len):
55
+ # The descriptor is of the form
56
+ # <kernel-id>|<launch-dims>|<arg-dims-list>
57
+ # Example: 42|16,32|16,32;100;16,32
58
+ kernel_id_str, dim_str, args_str = opaque.decode().split("|")
59
+
60
+ # Get the kernel from the registry.
61
+ kernel_id = int(kernel_id_str)
62
+ kernel = _registered_kernels[kernel_id]
63
+
64
+ # Parse launch dimensions.
65
+ dims = [int(d) for d in dim_str.split(",")]
66
+ bounds = launch_bounds_t(dims)
67
+
68
+ # Parse arguments.
69
+ arg_strings = args_str.split(";")
70
+ num_args = len(arg_strings)
71
+ assert num_args == len(kernel.adj.args), "Incorrect number of arguments"
72
+
73
+ # First param is the launch bounds.
74
+ kernel_params = (ctypes.c_void_p * (1 + num_args))()
75
+ kernel_params[0] = ctypes.addressof(bounds)
76
+
77
+ # Parse array descriptors.
78
+ args = []
79
+ for i in range(num_args):
80
+ dtype = kernel.adj.args[i].type.dtype
81
+ shape = [int(d) for d in arg_strings[i].split(",")]
82
+ strides = strides_from_shape(shape, dtype)
83
+
84
+ arr = array_t(buffers[i], 0, len(shape), shape, strides)
85
+ args.append(arr) # keep a reference
86
+ arg_ptr = ctypes.addressof(arr)
87
+
88
+ kernel_params[i + 1] = arg_ptr
89
+
90
+ # Get current device.
91
+ device = wp.device_from_jax(_get_jax_device())
92
+
93
+ # Get kernel hooks.
94
+ # Note: module was loaded during jit lowering.
95
+ hooks = kernel.module.get_kernel_hooks(kernel, device)
96
+ assert hooks.forward, "Failed to find kernel entry point"
97
+
98
+ # Launch the kernel.
99
+ wp.context.runtime.core.cuda_launch_kernel(
100
+ device.context, hooks.forward, bounds.size, 0, kernel_params, stream
101
+ )
102
+
103
+
104
+ # TODO: is there a simpler way of getting the Jax "current" device?
105
+ def _get_jax_device():
106
+ # check if jax.default_device() context manager is active
107
+ device = jax.config.jax_default_device
108
+ # if default device is not set, use first device
109
+ if device is None:
110
+ device = jax.devices()[0]
111
+ return device
112
+
113
+
114
+ def _create_jax_warp_primitive():
115
+ from functools import reduce
116
+ import jax
117
+ from jax._src.interpreters import batching
118
+ from jax.interpreters import mlir
119
+ from jax.interpreters.mlir import ir
120
+ from jaxlib.hlo_helpers import custom_call
121
+
122
+ global _jax_warp_p
123
+ global _cc_callback
124
+
125
+ # Create and register the primitive.
126
+ # TODO add default implementation that calls the kernel via warp.
127
+ _jax_warp_p = jax.core.Primitive("jax_warp")
128
+ _jax_warp_p.multiple_results = True
129
+
130
+ # TODO Just launch the kernel directly, but make sure the argument
131
+ # shapes are massaged the same way as below so that vmap works.
132
+ def impl(*args):
133
+ raise Exception("Not implemented")
134
+
135
+ _jax_warp_p.def_impl(impl)
136
+
137
+ # Auto-batching. Make sure all the arguments are fully broadcasted
138
+ # so that Warp is not confused about dimensions.
139
+ def vectorized_multi_batcher(args, dims, **params):
140
+ # Figure out the number of outputs.
141
+ wp_kernel = _registered_kernels[params["kernel"]]
142
+ output_count = len(wp_kernel.adj.args) - len(args)
143
+ shape, dim = next((a.shape, d) for a, d in zip(args, dims) if d is not None)
144
+ size = shape[dim]
145
+ args = [batching.bdim_at_front(a, d, size) if len(a.shape) else a for a, d in zip(args, dims)]
146
+ # Create the batched primitive.
147
+ return _jax_warp_p.bind(*args, **params), [dims[0]] * output_count
148
+
149
+ batching.primitive_batchers[_jax_warp_p] = vectorized_multi_batcher
150
+
151
+ def get_vecmat_shape(warp_type):
152
+ if hasattr(warp_type.dtype, "_shape_"):
153
+ return warp_type.dtype._shape_
154
+ return []
155
+
156
+ def strip_vecmat_dimensions(warp_arg, actual_shape):
157
+ shape = get_vecmat_shape(warp_arg.type)
158
+ for i, s in enumerate(reversed(shape)):
159
+ item = actual_shape[-i - 1]
160
+ if s != item:
161
+ raise Exception(f"The vector/matrix shape for argument {warp_arg.label} does not match")
162
+ return actual_shape[: len(actual_shape) - len(shape)]
163
+
164
+ def collapse_into_leading_dimension(warp_arg, actual_shape):
165
+ if len(actual_shape) < warp_arg.type.ndim:
166
+ raise Exception(f"Argument {warp_arg.label} has too few non-matrix/vector dimensions")
167
+ index_rest = len(actual_shape) - warp_arg.type.ndim + 1
168
+ leading_size = reduce(lambda x, y: x * y, actual_shape[:index_rest])
169
+ return [leading_size] + actual_shape[index_rest:]
170
+
171
+ # Infer array dimensions from input type.
172
+ def infer_dimensions(warp_arg, actual_shape):
173
+ actual_shape = strip_vecmat_dimensions(warp_arg, actual_shape)
174
+ return collapse_into_leading_dimension(warp_arg, actual_shape)
175
+
176
+ def base_type_to_jax(warp_dtype):
177
+ if hasattr(warp_dtype, "_wp_scalar_type_"):
178
+ return wp.jax.dtype_to_jax(warp_dtype._wp_scalar_type_)
179
+ return wp.jax.dtype_to_jax(warp_dtype)
180
+
181
+ def base_type_to_jax_ir(warp_dtype):
182
+ warp_to_jax_dict = {
183
+ wp.float16: ir.F16Type.get(),
184
+ wp.float32: ir.F32Type.get(),
185
+ wp.float64: ir.F64Type.get(),
186
+ wp.int8: ir.IntegerType.get_signless(8),
187
+ wp.int16: ir.IntegerType.get_signless(16),
188
+ wp.int32: ir.IntegerType.get_signless(32),
189
+ wp.int64: ir.IntegerType.get_signless(64),
190
+ wp.uint8: ir.IntegerType.get_unsigned(8),
191
+ wp.uint16: ir.IntegerType.get_unsigned(16),
192
+ wp.uint32: ir.IntegerType.get_unsigned(32),
193
+ wp.uint64: ir.IntegerType.get_unsigned(64),
194
+ }
195
+ if hasattr(warp_dtype, "_wp_scalar_type_"):
196
+ warp_dtype = warp_dtype._wp_scalar_type_
197
+ jax_dtype = warp_to_jax_dict.get(warp_dtype)
198
+ if jax_dtype is None:
199
+ raise TypeError(f"Invalid or unsupported data type: {warp_dtype}")
200
+ return jax_dtype
201
+
202
+ def base_type_is_compatible(warp_type, jax_ir_type):
203
+ jax_ir_to_warp = {
204
+ "f16": wp.float16,
205
+ "f32": wp.float32,
206
+ "f64": wp.float64,
207
+ "i8": wp.int8,
208
+ "i16": wp.int16,
209
+ "i32": wp.int32,
210
+ "i64": wp.int64,
211
+ "ui8": wp.uint8,
212
+ "ui16": wp.uint16,
213
+ "ui32": wp.uint32,
214
+ "ui64": wp.uint64,
215
+ }
216
+ expected_warp_type = jax_ir_to_warp.get(str(jax_ir_type))
217
+ if expected_warp_type is not None:
218
+ if hasattr(warp_type, "_wp_scalar_type_"):
219
+ return warp_type._wp_scalar_type_ == expected_warp_type
220
+ else:
221
+ return warp_type == expected_warp_type
222
+ else:
223
+ raise TypeError(f"Invalid or unsupported data type: {jax_ir_type}")
224
+
225
+ # Abstract evaluation.
226
+ def jax_warp_abstract(*args, kernel=None):
227
+ wp_kernel = _registered_kernels[kernel]
228
+ # All the extra arguments to the warp kernel are outputs.
229
+ warp_outputs = [o.type for o in wp_kernel.adj.args[len(args) :]]
230
+ # TODO. Let's just use the first input dimension to infer the output's dimensions.
231
+ dims = strip_vecmat_dimensions(wp_kernel.adj.args[0], list(args[0].shape))
232
+ jax_outputs = []
233
+ for o in warp_outputs:
234
+ shape = list(dims) + list(get_vecmat_shape(o))
235
+ dtype = base_type_to_jax(o.dtype)
236
+ jax_outputs.append(jax.core.ShapedArray(shape, dtype))
237
+ return jax_outputs
238
+
239
+ _jax_warp_p.def_abstract_eval(jax_warp_abstract)
240
+
241
+ # Lowering to MLIR.
242
+
243
+ # Create python-land custom call target.
244
+ CCALLFUNC = ctypes.CFUNCTYPE(
245
+ ctypes.c_voidp, ctypes.c_void_p, ctypes.POINTER(ctypes.c_void_p), ctypes.c_char_p, ctypes.c_size_t
246
+ )
247
+ _cc_callback = CCALLFUNC(_warp_custom_callback)
248
+ ccall_address = ctypes.cast(_cc_callback, ctypes.c_void_p)
249
+
250
+ # Put the custom call into a capsule, as required by XLA.
251
+ PyCapsule_Destructor = ctypes.CFUNCTYPE(None, ctypes.py_object)
252
+ PyCapsule_New = ctypes.pythonapi.PyCapsule_New
253
+ PyCapsule_New.restype = ctypes.py_object
254
+ PyCapsule_New.argtypes = (ctypes.c_void_p, ctypes.c_char_p, PyCapsule_Destructor)
255
+ capsule = PyCapsule_New(ccall_address.value, b"xla._CUSTOM_CALL_TARGET", PyCapsule_Destructor(0))
256
+
257
+ # Register the callback in XLA.
258
+ jax.lib.xla_client.register_custom_call_target("warp_call", capsule, platform="gpu")
259
+
260
+ def default_layout(shape):
261
+ return range(len(shape) - 1, -1, -1)
262
+
263
+ def warp_call_lowering(ctx, *args, kernel=None):
264
+ if not kernel:
265
+ raise Exception("Unknown kernel id " + str(kernel))
266
+ wp_kernel = _registered_kernels[kernel]
267
+
268
+ # TODO This may not be necessary, but it is perhaps better not to be
269
+ # mucking with kernel loading while already running the workload.
270
+ module = wp_kernel.module
271
+ device = wp.device_from_jax(_get_jax_device())
272
+ if not module.load(device):
273
+ raise Exception("Could not load kernel on device")
274
+
275
+ # Infer dimensions from the first input.
276
+ warp_arg0 = wp_kernel.adj.args[0]
277
+ actual_shape0 = ir.RankedTensorType(args[0].type).shape
278
+ dims = strip_vecmat_dimensions(warp_arg0, actual_shape0)
279
+ warp_dims = collapse_into_leading_dimension(warp_arg0, dims)
280
+
281
+ # Figure out the types and shapes of the input arrays.
282
+ arg_strings = []
283
+ operand_layouts = []
284
+ for actual, warg in zip(args, wp_kernel.adj.args):
285
+ wtype = warg.type
286
+ rtt = ir.RankedTensorType(actual.type)
287
+
288
+ if not isinstance(wtype, wp.array):
289
+ raise Exception("Only contiguous arrays are supported for Jax kernel arguments")
290
+
291
+ if not base_type_is_compatible(wtype.dtype, rtt.element_type):
292
+ raise TypeError(f"Incompatible data type for argument '{warg.label}', expected {type_str(wtype.dtype)}, got {rtt.element_type}")
293
+
294
+ # Infer array dimension (by removing the vector/matrix dimensions and
295
+ # collapsing the initial dimensions).
296
+ shape = infer_dimensions(warg, rtt.shape)
297
+
298
+ if len(shape) != wtype.ndim:
299
+ raise TypeError(f"Incompatible array dimensionality for argument '{warg.label}'")
300
+
301
+ arg_strings.append(",".join([str(d) for d in shape]))
302
+ operand_layouts.append(default_layout(rtt.shape))
303
+
304
+ # Figure out the types and shapes of the output arrays.
305
+ result_types = []
306
+ result_layouts = []
307
+ for warg in wp_kernel.adj.args[len(args) :]:
308
+ wtype = warg.type
309
+
310
+ if not isinstance(wtype, wp.array):
311
+ raise Exception("Only contiguous arrays are supported for Jax kernel arguments")
312
+
313
+ # Infer dimensions from the first input.
314
+ arg_strings.append(",".join([str(d) for d in warp_dims]))
315
+
316
+ result_shape = list(dims) + list(get_vecmat_shape(wtype))
317
+ result_types.append(ir.RankedTensorType.get(result_shape, base_type_to_jax_ir(wtype.dtype)))
318
+ result_layouts.append(default_layout(result_shape))
319
+
320
+ # Build opaque descriptor for callback.
321
+ shape_str = ",".join([str(d) for d in warp_dims])
322
+ args_str = ";".join(arg_strings)
323
+ descriptor = f"{kernel}|{shape_str}|{args_str}"
324
+
325
+ out = custom_call(
326
+ b"warp_call",
327
+ result_types=result_types,
328
+ operands=args,
329
+ backend_config=descriptor.encode("utf-8"),
330
+ operand_layouts=operand_layouts,
331
+ result_layouts=result_layouts,
332
+ ).results
333
+ return out
334
+
335
+ mlir.register_lowering(
336
+ _jax_warp_p,
337
+ warp_call_lowering,
338
+ platform="gpu",
339
+ )
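
The docstring at the top of warp/jax_experimental.py lists the current constraints on jax_kernel (array arguments only, inputs before outputs, launch dimensions taken from the first argument, CUDA backend only). A minimal usage sketch under those constraints; the kernel, shapes, and values below are illustrative and not taken from the package:

    import jax
    import jax.numpy as jp
    import warp as wp
    from warp.jax_experimental import jax_kernel

    wp.init()

    @wp.kernel
    def scale_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
        # Launch dimensions are inferred from the shape of the first argument (x).
        tid = wp.tid()
        y[tid] = 2.0 * x[tid]

    # Wrap the Warp kernel as a JAX primitive; inputs precede outputs in the signature.
    jax_scale = jax_kernel(scale_kernel)

    @jax.jit
    def run(x):
        # The primitive returns one output per trailing kernel argument.
        return jax_scale(x)[0]

    print(run(jp.arange(16, dtype=jp.float32)))

Note that the call has to go through jax.jit: the primitive's eager implementation is deliberately left unimplemented, so only the MLIR lowering path (the custom call registered above) is supported.
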
warp/native/builtin.h CHANGED
@@ -354,6 +354,12 @@ inline CUDA_CALLABLE uint32 sign(uint32 x) { return 1; }
354
354
  inline CUDA_CALLABLE uint64 sign(uint64 x) { return 1; }
355
355
 
356
356
 
357
+ // Catch-all for non-float types
358
+ template<typename T>
359
+ inline bool CUDA_CALLABLE isfinite(const T&)
360
+ {
361
+ return true;
362
+ }
357
363
 
358
364
  inline bool CUDA_CALLABLE isfinite(half x)
359
365
  {
@@ -368,6 +374,12 @@ inline bool CUDA_CALLABLE isfinite(double x)
368
374
  return ::isfinite(x);
369
375
  }
370
376
 
377
+ template<typename T>
378
+ inline CUDA_CALLABLE void print(const T&)
379
+ {
380
+ printf("<type without print implementation>\n");
381
+ }
382
+
371
383
  inline CUDA_CALLABLE void print(float16 f)
372
384
  {
373
385
  printf("%g\n", half_to_float(f));
warp/native/bvh.cu CHANGED
@@ -373,16 +373,16 @@ LinearBVHBuilderGPU::LinearBVHBuilderGPU()
373
373
  , total_upper(NULL)
374
374
  , total_inv_edges(NULL)
375
375
  {
376
- total_lower = (vec3*)alloc_temp_device(WP_CURRENT_CONTEXT, sizeof(vec3));
377
- total_upper = (vec3*)alloc_temp_device(WP_CURRENT_CONTEXT, sizeof(vec3));
378
- total_inv_edges = (vec3*)alloc_temp_device(WP_CURRENT_CONTEXT, sizeof(vec3));
376
+ total_lower = (vec3*)alloc_device(WP_CURRENT_CONTEXT, sizeof(vec3));
377
+ total_upper = (vec3*)alloc_device(WP_CURRENT_CONTEXT, sizeof(vec3));
378
+ total_inv_edges = (vec3*)alloc_device(WP_CURRENT_CONTEXT, sizeof(vec3));
379
379
  }
380
380
 
381
381
  LinearBVHBuilderGPU::~LinearBVHBuilderGPU()
382
382
  {
383
- free_temp_device(WP_CURRENT_CONTEXT, total_lower);
384
- free_temp_device(WP_CURRENT_CONTEXT, total_upper);
385
- free_temp_device(WP_CURRENT_CONTEXT, total_inv_edges);
383
+ free_device(WP_CURRENT_CONTEXT, total_lower);
384
+ free_device(WP_CURRENT_CONTEXT, total_upper);
385
+ free_device(WP_CURRENT_CONTEXT, total_inv_edges);
386
386
  }
387
387
 
388
388
 
@@ -390,12 +390,12 @@ LinearBVHBuilderGPU::~LinearBVHBuilderGPU()
390
390
  void LinearBVHBuilderGPU::build(BVH& bvh, const vec3* item_lowers, const vec3* item_uppers, int num_items, bounds3* total_bounds)
391
391
  {
392
392
  // allocate temporary memory used during building
393
- indices = (int*)alloc_temp_device(WP_CURRENT_CONTEXT, sizeof(int)*num_items*2); // *2 for radix sort
394
- keys = (int*)alloc_temp_device(WP_CURRENT_CONTEXT, sizeof(int)*num_items*2); // *2 for radix sort
395
- deltas = (int*)alloc_temp_device(WP_CURRENT_CONTEXT, sizeof(int)*num_items); // highest differenting bit between keys for item i and i+1
396
- range_lefts = (int*)alloc_temp_device(WP_CURRENT_CONTEXT, sizeof(int)*bvh.max_nodes);
397
- range_rights = (int*)alloc_temp_device(WP_CURRENT_CONTEXT, sizeof(int)*bvh.max_nodes);
398
- num_children = (int*)alloc_temp_device(WP_CURRENT_CONTEXT, sizeof(int)*bvh.max_nodes);
393
+ indices = (int*)alloc_device(WP_CURRENT_CONTEXT, sizeof(int)*num_items*2); // *2 for radix sort
394
+ keys = (int*)alloc_device(WP_CURRENT_CONTEXT, sizeof(int)*num_items*2); // *2 for radix sort
395
+ deltas = (int*)alloc_device(WP_CURRENT_CONTEXT, sizeof(int)*num_items); // highest differing bit between keys for item i and i+1
396
+ range_lefts = (int*)alloc_device(WP_CURRENT_CONTEXT, sizeof(int)*bvh.max_nodes);
397
+ range_rights = (int*)alloc_device(WP_CURRENT_CONTEXT, sizeof(int)*bvh.max_nodes);
398
+ num_children = (int*)alloc_device(WP_CURRENT_CONTEXT, sizeof(int)*bvh.max_nodes);
399
399
 
400
400
  // if total bounds supplied by the host then we just
401
401
  // compute our edge length and upload it to the GPU directly
@@ -445,13 +445,13 @@ void LinearBVHBuilderGPU::build(BVH& bvh, const vec3* item_lowers, const vec3* i
445
445
  wp_launch_device(WP_CURRENT_CONTEXT, build_hierarchy, num_items, (num_items, bvh.root, deltas, num_children, range_lefts, range_rights, bvh.node_parents, bvh.node_lowers, bvh.node_uppers));
446
446
 
447
447
  // free temporary memory
448
- free_temp_device(WP_CURRENT_CONTEXT, indices);
449
- free_temp_device(WP_CURRENT_CONTEXT, keys);
450
- free_temp_device(WP_CURRENT_CONTEXT, deltas);
448
+ free_device(WP_CURRENT_CONTEXT, indices);
449
+ free_device(WP_CURRENT_CONTEXT, keys);
450
+ free_device(WP_CURRENT_CONTEXT, deltas);
451
451
 
452
- free_temp_device(WP_CURRENT_CONTEXT, range_lefts);
453
- free_temp_device(WP_CURRENT_CONTEXT, range_rights);
454
- free_temp_device(WP_CURRENT_CONTEXT, num_children);
452
+ free_device(WP_CURRENT_CONTEXT, range_lefts);
453
+ free_device(WP_CURRENT_CONTEXT, range_rights);
454
+ free_device(WP_CURRENT_CONTEXT, num_children);
455
455
 
456
456
  }
457
457
 
warp/native/clang/clang.cpp CHANGED
@@ -81,7 +81,7 @@ static void initialize_llvm()
81
81
  llvm::InitializeAllAsmPrinters();
82
82
  }
83
83
 
84
- static std::unique_ptr<llvm::Module> cpp_to_llvm(const std::string& input_file, const char* cpp_src, const char* include_dir, bool debug, llvm::LLVMContext& context)
84
+ static std::unique_ptr<llvm::Module> cpp_to_llvm(const std::string& input_file, const char* cpp_src, const char* include_dir, bool debug, bool verify_fp, llvm::LLVMContext& context)
85
85
  {
86
86
  // Compilation arguments
87
87
  std::vector<const char*> args;
@@ -126,6 +126,11 @@ static std::unique_ptr<llvm::Module> cpp_to_llvm(const std::string& input_file,
126
126
  compiler_instance.getPreprocessorOpts().addMacroDef("NDEBUG");
127
127
  }
128
128
 
129
+ if(verify_fp)
130
+ {
131
+ compiler_instance.getPreprocessorOpts().addMacroDef("WP_VERIFY_FP");
132
+ }
133
+
129
134
  compiler_instance.getLangOpts().MicrosoftExt = 1; // __forceinline / __int64
130
135
  compiler_instance.getLangOpts().DeclSpecKeyword = 1; // __declspec
131
136
 
@@ -201,12 +206,12 @@ static std::unique_ptr<llvm::Module> cuda_to_llvm(const std::string& input_file,
201
206
 
202
207
  extern "C" {
203
208
 
204
- WP_API int compile_cpp(const char* cpp_src, const char *input_file, const char* include_dir, const char* output_file, bool debug)
209
+ WP_API int compile_cpp(const char* cpp_src, const char *input_file, const char* include_dir, const char* output_file, bool debug, bool verify_fp)
205
210
  {
206
211
  initialize_llvm();
207
212
 
208
213
  llvm::LLVMContext context;
209
- std::unique_ptr<llvm::Module> module = cpp_to_llvm(input_file, cpp_src, include_dir, debug, context);
214
+ std::unique_ptr<llvm::Module> module = cpp_to_llvm(input_file, cpp_src, include_dir, debug, verify_fp, context);
210
215
 
211
216
  if(!module)
212
217
  {
warp/native/cuda_util.cpp CHANGED
@@ -9,6 +9,7 @@
9
9
  #if WP_ENABLE_CUDA
10
10
 
11
11
  #include "cuda_util.h"
12
+ #include "error.h"
12
13
 
13
14
  #if defined(_WIN32)
14
15
  #define WIN32_LEAN_AND_MEAN
@@ -19,6 +20,9 @@
19
20
  #include <dlfcn.h>
20
21
  #endif
21
22
 
23
+ #include <set>
24
+ #include <stack>
25
+
22
26
  // the minimum CUDA version required from the driver
23
27
  #define WP_CUDA_DRIVER_VERSION 11030
24
28
 
@@ -63,6 +67,7 @@ static PFN_cuDeviceGetUuid_v11040 pfn_cuDeviceGetUuid;
63
67
  static PFN_cuDevicePrimaryCtxRetain_v7000 pfn_cuDevicePrimaryCtxRetain;
64
68
  static PFN_cuDevicePrimaryCtxRelease_v11000 pfn_cuDevicePrimaryCtxRelease;
65
69
  static PFN_cuDeviceCanAccessPeer_v4000 pfn_cuDeviceCanAccessPeer;
70
+ static PFN_cuMemGetInfo_v3020 pfn_cuMemGetInfo;
66
71
  static PFN_cuCtxGetCurrent_v4000 pfn_cuCtxGetCurrent;
67
72
  static PFN_cuCtxSetCurrent_v4000 pfn_cuCtxSetCurrent;
68
73
  static PFN_cuCtxPushCurrent_v4000 pfn_cuCtxPushCurrent;
@@ -72,18 +77,23 @@ static PFN_cuCtxGetDevice_v2000 pfn_cuCtxGetDevice;
72
77
  static PFN_cuCtxCreate_v3020 pfn_cuCtxCreate;
73
78
  static PFN_cuCtxDestroy_v4000 pfn_cuCtxDestroy;
74
79
  static PFN_cuCtxEnablePeerAccess_v4000 pfn_cuCtxEnablePeerAccess;
80
+ static PFN_cuCtxDisablePeerAccess_v4000 pfn_cuCtxDisablePeerAccess;
75
81
  static PFN_cuStreamCreate_v2000 pfn_cuStreamCreate;
76
82
  static PFN_cuStreamDestroy_v4000 pfn_cuStreamDestroy;
77
83
  static PFN_cuStreamSynchronize_v2000 pfn_cuStreamSynchronize;
78
84
  static PFN_cuStreamWaitEvent_v3020 pfn_cuStreamWaitEvent;
85
+ static PFN_cuStreamGetCaptureInfo_v11030 pfn_cuStreamGetCaptureInfo;
86
+ static PFN_cuStreamUpdateCaptureDependencies_v11030 pfn_cuStreamUpdateCaptureDependencies;
79
87
  static PFN_cuEventCreate_v2000 pfn_cuEventCreate;
80
88
  static PFN_cuEventDestroy_v4000 pfn_cuEventDestroy;
81
89
  static PFN_cuEventRecord_v2000 pfn_cuEventRecord;
90
+ static PFN_cuEventRecordWithFlags_v11010 pfn_cuEventRecordWithFlags;
82
91
  static PFN_cuModuleLoadDataEx_v2010 pfn_cuModuleLoadDataEx;
83
92
  static PFN_cuModuleUnload_v2000 pfn_cuModuleUnload;
84
93
  static PFN_cuModuleGetFunction_v2000 pfn_cuModuleGetFunction;
85
94
  static PFN_cuLaunchKernel_v4000 pfn_cuLaunchKernel;
86
95
  static PFN_cuMemcpyPeerAsync_v4000 pfn_cuMemcpyPeerAsync;
96
+ static PFN_cuPointerGetAttribute_v4000 pfn_cuPointerGetAttribute;
87
97
  static PFN_cuGraphicsMapResources_v3000 pfn_cuGraphicsMapResources;
88
98
  static PFN_cuGraphicsUnmapResources_v3000 pfn_cuGraphicsUnmapResources;
89
99
  static PFN_cuGraphicsResourceGetMappedPointer_v3020 pfn_cuGraphicsResourceGetMappedPointer;
@@ -171,6 +181,7 @@ bool init_cuda_driver()
171
181
  get_driver_entry_point("cuDevicePrimaryCtxRetain", &(void*&)pfn_cuDevicePrimaryCtxRetain);
172
182
  get_driver_entry_point("cuDevicePrimaryCtxRelease", &(void*&)pfn_cuDevicePrimaryCtxRelease);
173
183
  get_driver_entry_point("cuDeviceCanAccessPeer", &(void*&)pfn_cuDeviceCanAccessPeer);
184
+ get_driver_entry_point("cuMemGetInfo", &(void*&)pfn_cuMemGetInfo);
174
185
  get_driver_entry_point("cuCtxSetCurrent", &(void*&)pfn_cuCtxSetCurrent);
175
186
  get_driver_entry_point("cuCtxGetCurrent", &(void*&)pfn_cuCtxGetCurrent);
176
187
  get_driver_entry_point("cuCtxPushCurrent", &(void*&)pfn_cuCtxPushCurrent);
@@ -180,18 +191,23 @@ bool init_cuda_driver()
180
191
  get_driver_entry_point("cuCtxCreate", &(void*&)pfn_cuCtxCreate);
181
192
  get_driver_entry_point("cuCtxDestroy", &(void*&)pfn_cuCtxDestroy);
182
193
  get_driver_entry_point("cuCtxEnablePeerAccess", &(void*&)pfn_cuCtxEnablePeerAccess);
194
+ get_driver_entry_point("cuCtxDisablePeerAccess", &(void*&)pfn_cuCtxDisablePeerAccess);
183
195
  get_driver_entry_point("cuStreamCreate", &(void*&)pfn_cuStreamCreate);
184
196
  get_driver_entry_point("cuStreamDestroy", &(void*&)pfn_cuStreamDestroy);
185
197
  get_driver_entry_point("cuStreamSynchronize", &(void*&)pfn_cuStreamSynchronize);
186
198
  get_driver_entry_point("cuStreamWaitEvent", &(void*&)pfn_cuStreamWaitEvent);
199
+ get_driver_entry_point("cuStreamGetCaptureInfo", &(void*&)pfn_cuStreamGetCaptureInfo);
200
+ get_driver_entry_point("cuStreamUpdateCaptureDependencies", &(void*&)pfn_cuStreamUpdateCaptureDependencies);
187
201
  get_driver_entry_point("cuEventCreate", &(void*&)pfn_cuEventCreate);
188
202
  get_driver_entry_point("cuEventDestroy", &(void*&)pfn_cuEventDestroy);
189
203
  get_driver_entry_point("cuEventRecord", &(void*&)pfn_cuEventRecord);
204
+ get_driver_entry_point("cuEventRecordWithFlags", &(void*&)pfn_cuEventRecordWithFlags);
190
205
  get_driver_entry_point("cuModuleLoadDataEx", &(void*&)pfn_cuModuleLoadDataEx);
191
206
  get_driver_entry_point("cuModuleUnload", &(void*&)pfn_cuModuleUnload);
192
207
  get_driver_entry_point("cuModuleGetFunction", &(void*&)pfn_cuModuleGetFunction);
193
208
  get_driver_entry_point("cuLaunchKernel", &(void*&)pfn_cuLaunchKernel);
194
209
  get_driver_entry_point("cuMemcpyPeerAsync", &(void*&)pfn_cuMemcpyPeerAsync);
210
+ get_driver_entry_point("cuPointerGetAttribute", &(void*&)pfn_cuPointerGetAttribute);
195
211
  get_driver_entry_point("cuGraphicsMapResources", &(void*&)pfn_cuGraphicsMapResources);
196
212
  get_driver_entry_point("cuGraphicsUnmapResources", &(void*&)pfn_cuGraphicsUnmapResources);
197
213
  get_driver_entry_point("cuGraphicsResourceGetMappedPointer", &(void*&)pfn_cuGraphicsResourceGetMappedPointer);
@@ -209,16 +225,16 @@ bool is_cuda_driver_initialized()
209
225
  return cuda_driver_initialized;
210
226
  }
211
227
 
212
- bool check_cuda_result(cudaError_t code, const char* file, int line)
228
+ bool check_cuda_result(cudaError_t code, const char* func, const char* file, int line)
213
229
  {
214
230
  if (code == cudaSuccess)
215
231
  return true;
216
232
 
217
- fprintf(stderr, "Warp CUDA error %u: %s (%s:%d)\n", unsigned(code), cudaGetErrorString(code), file, line);
233
+ wp::set_error_string("Warp CUDA error %u: %s (in function %s, %s:%d)", unsigned(code), cudaGetErrorString(code), func, file, line);
218
234
  return false;
219
235
  }
220
236
 
221
- bool check_cu_result(CUresult result, const char* file, int line)
237
+ bool check_cu_result(CUresult result, const char* func, const char* file, int line)
222
238
  {
223
239
  if (result == CUDA_SUCCESS)
224
240
  return true;
@@ -228,13 +244,56 @@ bool check_cu_result(CUresult result, const char* file, int line)
228
244
  pfn_cuGetErrorString(result, &errString);
229
245
 
230
246
  if (errString)
231
- fprintf(stderr, "Warp CUDA error %u: %s (%s:%d)\n", unsigned(result), errString, file, line);
247
+ wp::set_error_string("Warp CUDA error %u: %s (in function %s, %s:%d)", unsigned(result), errString, func, file, line);
232
248
  else
233
- fprintf(stderr, "Warp CUDA error %u (%s:%d)\n", unsigned(result), file, line);
249
+ wp::set_error_string("Warp CUDA error %u (in function %s, %s:%d)", unsigned(result), func, file, line);
234
250
 
235
251
  return false;
236
252
  }
237
253
 
254
+ bool get_capture_dependencies(CUstream stream, std::vector<CUgraphNode>& dependencies_ret)
255
+ {
256
+ CUstreamCaptureStatus status;
257
+ size_t num_dependencies = 0;
258
+ const CUgraphNode* dependencies = NULL;
259
+ dependencies_ret.clear();
260
+ if (check_cu(cuStreamGetCaptureInfo_f(stream, &status, NULL, NULL, &dependencies, &num_dependencies)))
261
+ {
262
+ if (dependencies && num_dependencies > 0)
263
+ dependencies_ret.insert(dependencies_ret.begin(), dependencies, dependencies + num_dependencies);
264
+ return true;
265
+ }
266
+ return false;
267
+ }
268
+
269
+ bool get_graph_leaf_nodes(cudaGraph_t graph, std::vector<cudaGraphNode_t>& leaf_nodes_ret)
270
+ {
271
+ if (!graph)
272
+ return false;
273
+
274
+ size_t node_count = 0;
275
+ if (!check_cuda(cudaGraphGetNodes(graph, NULL, &node_count)))
276
+ return false;
277
+
278
+ std::vector<cudaGraphNode_t> nodes(node_count);
279
+ if (!check_cuda(cudaGraphGetNodes(graph, nodes.data(), &node_count)))
280
+ return false;
281
+
282
+ leaf_nodes_ret.clear();
283
+
284
+ for (cudaGraphNode_t node : nodes)
285
+ {
286
+ size_t dependent_count;
287
+ if (!check_cuda(cudaGraphNodeGetDependentNodes(node, NULL, &dependent_count)))
288
+ return false;
289
+
290
+ if (dependent_count == 0)
291
+ leaf_nodes_ret.push_back(node);
292
+ }
293
+
294
+ return true;
295
+ }
296
+
238
297
 
239
298
  #define DRIVER_ENTRY_POINT_ERROR driver_entry_point_error(__FUNCTION__)
240
299
 
@@ -311,6 +370,11 @@ CUresult cuDeviceCanAccessPeer_f(int* can_access, CUdevice dev, CUdevice peer_de
311
370
  return pfn_cuDeviceCanAccessPeer ? pfn_cuDeviceCanAccessPeer(can_access, dev, peer_dev) : DRIVER_ENTRY_POINT_ERROR;
312
371
  }
313
372
 
373
+ CUresult cuMemGetInfo_f(size_t* free, size_t* total)
374
+ {
375
+ return pfn_cuMemGetInfo ? pfn_cuMemGetInfo(free, total) : DRIVER_ENTRY_POINT_ERROR;
376
+ }
377
+
314
378
  CUresult cuCtxGetCurrent_f(CUcontext* ctx)
315
379
  {
316
380
  return pfn_cuCtxGetCurrent ? pfn_cuCtxGetCurrent(ctx) : DRIVER_ENTRY_POINT_ERROR;
@@ -356,6 +420,11 @@ CUresult cuCtxEnablePeerAccess_f(CUcontext peer_ctx, unsigned int flags)
356
420
  return pfn_cuCtxEnablePeerAccess ? pfn_cuCtxEnablePeerAccess(peer_ctx, flags) : DRIVER_ENTRY_POINT_ERROR;
357
421
  }
358
422
 
423
+ CUresult cuCtxDisablePeerAccess_f(CUcontext peer_ctx)
424
+ {
425
+ return pfn_cuCtxDisablePeerAccess ? pfn_cuCtxDisablePeerAccess(peer_ctx) : DRIVER_ENTRY_POINT_ERROR;
426
+ }
427
+
359
428
  CUresult cuStreamCreate_f(CUstream* stream, unsigned int flags)
360
429
  {
361
430
  return pfn_cuStreamCreate ? pfn_cuStreamCreate(stream, flags) : DRIVER_ENTRY_POINT_ERROR;
@@ -376,6 +445,16 @@ CUresult cuStreamWaitEvent_f(CUstream stream, CUevent event, unsigned int flags)
376
445
  return pfn_cuStreamWaitEvent ? pfn_cuStreamWaitEvent(stream, event, flags) : DRIVER_ENTRY_POINT_ERROR;
377
446
  }
378
447
 
448
+ CUresult cuStreamGetCaptureInfo_f(CUstream stream, CUstreamCaptureStatus *captureStatus_out, cuuint64_t *id_out, CUgraph *graph_out, const CUgraphNode **dependencies_out, size_t *numDependencies_out)
449
+ {
450
+ return pfn_cuStreamGetCaptureInfo ? pfn_cuStreamGetCaptureInfo(stream, captureStatus_out, id_out, graph_out, dependencies_out, numDependencies_out) : DRIVER_ENTRY_POINT_ERROR;
451
+ }
452
+
453
+ CUresult cuStreamUpdateCaptureDependencies_f(CUstream stream, CUgraphNode *dependencies, size_t numDependencies, unsigned int flags)
454
+ {
455
+ return pfn_cuStreamUpdateCaptureDependencies ? pfn_cuStreamUpdateCaptureDependencies(stream, dependencies, numDependencies, flags) : DRIVER_ENTRY_POINT_ERROR;
456
+ }
457
+
379
458
  CUresult cuEventCreate_f(CUevent* event, unsigned int flags)
380
459
  {
381
460
  return pfn_cuEventCreate ? pfn_cuEventCreate(event, flags) : DRIVER_ENTRY_POINT_ERROR;
@@ -391,6 +470,11 @@ CUresult cuEventRecord_f(CUevent event, CUstream stream)
391
470
  return pfn_cuEventRecord ? pfn_cuEventRecord(event, stream) : DRIVER_ENTRY_POINT_ERROR;
392
471
  }
393
472
 
473
+ CUresult cuEventRecordWithFlags_f(CUevent event, CUstream stream, unsigned int flags)
474
+ {
475
+ return pfn_cuEventRecordWithFlags ? pfn_cuEventRecordWithFlags(event, stream, flags) : DRIVER_ENTRY_POINT_ERROR;
476
+ }
477
+
394
478
  CUresult cuModuleLoadDataEx_f(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues)
395
479
  {
396
480
  return pfn_cuModuleLoadDataEx ? pfn_cuModuleLoadDataEx(module, image, numOptions, options, optionValues) : DRIVER_ENTRY_POINT_ERROR;
@@ -416,6 +500,11 @@ CUresult cuMemcpyPeerAsync_f(CUdeviceptr dst_ptr, CUcontext dst_ctx, CUdeviceptr
416
500
  return pfn_cuMemcpyPeerAsync ? pfn_cuMemcpyPeerAsync(dst_ptr, dst_ctx, src_ptr, src_ctx, n, stream) : DRIVER_ENTRY_POINT_ERROR;
417
501
  }
418
502
 
503
+ CUresult cuPointerGetAttribute_f(void* data, CUpointer_attribute attribute, CUdeviceptr ptr)
504
+ {
505
+ return pfn_cuPointerGetAttribute ? pfn_cuPointerGetAttribute(data, attribute, ptr) : DRIVER_ENTRY_POINT_ERROR;
506
+ }
507
+
419
508
  CUresult cuGraphicsMapResources_f(unsigned int count, CUgraphicsResource* resources, CUstream stream)
420
509
  {
421
510
  return pfn_cuGraphicsMapResources ? pfn_cuGraphicsMapResources(count, resources, stream) : DRIVER_ENTRY_POINT_ERROR;
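
The stream-capture additions in this file (cuStreamGetCaptureInfo, cuStreamUpdateCaptureDependencies, get_capture_dependencies, get_graph_leaf_nodes) appear to back Warp's CUDA graph capture support. A minimal sketch of how capture is driven from the public Python API; the kernel, sizes, and device below are illustrative:

    import warp as wp

    wp.init()

    @wp.kernel
    def saxpy(a: float, x: wp.array(dtype=float), y: wp.array(dtype=float)):
        tid = wp.tid()
        y[tid] = a * x[tid] + y[tid]

    n = 1024
    x = wp.zeros(n, dtype=float, device="cuda:0")  # contents are illustrative
    y = wp.zeros(n, dtype=float, device="cuda:0")

    # Record a sequence of launches into a CUDA graph, then replay it cheaply.
    wp.capture_begin(device="cuda:0")
    try:
        for _ in range(10):
            wp.launch(saxpy, dim=n, inputs=[2.0, x, y], device="cuda:0")
    finally:
        graph = wp.capture_end(device="cuda:0")

    wp.capture_launch(graph)
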