warp-lang 1.6.2__py3-none-win_amd64.whl → 1.7.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (179) hide show
  1. warp/__init__.py +7 -1
  2. warp/bin/warp-clang.dll +0 -0
  3. warp/bin/warp.dll +0 -0
  4. warp/build.py +410 -0
  5. warp/build_dll.py +6 -14
  6. warp/builtins.py +452 -362
  7. warp/codegen.py +179 -119
  8. warp/config.py +42 -6
  9. warp/context.py +490 -271
  10. warp/dlpack.py +8 -6
  11. warp/examples/assets/nonuniform.usd +0 -0
  12. warp/examples/assets/nvidia_logo.png +0 -0
  13. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  14. warp/examples/core/example_sample_mesh.py +300 -0
  15. warp/examples/fem/example_apic_fluid.py +1 -1
  16. warp/examples/fem/example_burgers.py +2 -2
  17. warp/examples/fem/example_deformed_geometry.py +1 -1
  18. warp/examples/fem/example_distortion_energy.py +1 -1
  19. warp/examples/fem/example_magnetostatics.py +6 -6
  20. warp/examples/fem/utils.py +9 -3
  21. warp/examples/interop/example_jax_callable.py +116 -0
  22. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  23. warp/examples/interop/example_jax_kernel.py +205 -0
  24. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  25. warp/examples/tile/example_tile_matmul.py +2 -4
  26. warp/fem/__init__.py +11 -1
  27. warp/fem/adaptivity.py +4 -4
  28. warp/fem/field/nodal_field.py +22 -68
  29. warp/fem/field/virtual.py +62 -23
  30. warp/fem/geometry/adaptive_nanogrid.py +9 -10
  31. warp/fem/geometry/closest_point.py +1 -1
  32. warp/fem/geometry/deformed_geometry.py +5 -2
  33. warp/fem/geometry/geometry.py +5 -0
  34. warp/fem/geometry/grid_2d.py +12 -12
  35. warp/fem/geometry/grid_3d.py +12 -15
  36. warp/fem/geometry/hexmesh.py +5 -7
  37. warp/fem/geometry/nanogrid.py +9 -11
  38. warp/fem/geometry/quadmesh.py +13 -13
  39. warp/fem/geometry/tetmesh.py +3 -4
  40. warp/fem/geometry/trimesh.py +3 -8
  41. warp/fem/integrate.py +262 -93
  42. warp/fem/linalg.py +5 -5
  43. warp/fem/quadrature/pic_quadrature.py +37 -22
  44. warp/fem/quadrature/quadrature.py +194 -25
  45. warp/fem/space/__init__.py +1 -1
  46. warp/fem/space/basis_function_space.py +4 -2
  47. warp/fem/space/basis_space.py +25 -18
  48. warp/fem/space/hexmesh_function_space.py +2 -2
  49. warp/fem/space/partition.py +6 -2
  50. warp/fem/space/quadmesh_function_space.py +8 -8
  51. warp/fem/space/shape/cube_shape_function.py +23 -23
  52. warp/fem/space/shape/square_shape_function.py +12 -12
  53. warp/fem/space/shape/triangle_shape_function.py +1 -1
  54. warp/fem/space/tetmesh_function_space.py +3 -3
  55. warp/fem/space/trimesh_function_space.py +2 -2
  56. warp/fem/utils.py +12 -6
  57. warp/jax.py +14 -1
  58. warp/jax_experimental/__init__.py +16 -0
  59. warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
  60. warp/jax_experimental/ffi.py +698 -0
  61. warp/jax_experimental/xla_ffi.py +602 -0
  62. warp/math.py +89 -0
  63. warp/native/array.h +13 -0
  64. warp/native/builtin.h +29 -3
  65. warp/native/bvh.cpp +3 -1
  66. warp/native/bvh.cu +42 -14
  67. warp/native/bvh.h +2 -1
  68. warp/native/clang/clang.cpp +30 -3
  69. warp/native/cuda_util.cpp +14 -0
  70. warp/native/cuda_util.h +2 -0
  71. warp/native/exports.h +68 -63
  72. warp/native/intersect.h +26 -26
  73. warp/native/intersect_adj.h +33 -33
  74. warp/native/marching.cu +1 -1
  75. warp/native/mat.h +513 -9
  76. warp/native/mesh.h +10 -10
  77. warp/native/quat.h +99 -11
  78. warp/native/rand.h +6 -0
  79. warp/native/sort.cpp +122 -59
  80. warp/native/sort.cu +152 -15
  81. warp/native/sort.h +8 -1
  82. warp/native/sparse.cpp +43 -22
  83. warp/native/sparse.cu +52 -17
  84. warp/native/svd.h +116 -0
  85. warp/native/tile.h +301 -105
  86. warp/native/tile_reduce.h +46 -3
  87. warp/native/vec.h +68 -7
  88. warp/native/volume.cpp +85 -113
  89. warp/native/volume_builder.cu +25 -10
  90. warp/native/volume_builder.h +6 -0
  91. warp/native/warp.cpp +5 -6
  92. warp/native/warp.cu +99 -10
  93. warp/native/warp.h +19 -10
  94. warp/optim/linear.py +10 -10
  95. warp/sim/articulation.py +4 -4
  96. warp/sim/collide.py +21 -10
  97. warp/sim/import_mjcf.py +449 -155
  98. warp/sim/import_urdf.py +32 -12
  99. warp/sim/integrator_euler.py +5 -5
  100. warp/sim/integrator_featherstone.py +3 -10
  101. warp/sim/integrator_vbd.py +207 -2
  102. warp/sim/integrator_xpbd.py +5 -5
  103. warp/sim/model.py +42 -13
  104. warp/sim/utils.py +2 -2
  105. warp/sparse.py +642 -555
  106. warp/stubs.py +216 -19
  107. warp/tests/__main__.py +0 -15
  108. warp/tests/cuda/__init__.py +0 -0
  109. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  110. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  111. warp/tests/geometry/__init__.py +0 -0
  112. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  113. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  114. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  115. warp/tests/interop/__init__.py +0 -0
  116. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  117. warp/tests/sim/__init__.py +0 -0
  118. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  119. warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
  120. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  121. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  122. warp/tests/sim/test_vbd.py +597 -0
  123. warp/tests/test_bool.py +1 -1
  124. warp/tests/test_examples.py +28 -36
  125. warp/tests/test_fem.py +23 -4
  126. warp/tests/test_linear_solvers.py +0 -11
  127. warp/tests/test_mat.py +233 -79
  128. warp/tests/test_mat_scalar_ops.py +4 -4
  129. warp/tests/test_overwrite.py +0 -60
  130. warp/tests/test_quat.py +67 -46
  131. warp/tests/test_rand.py +44 -37
  132. warp/tests/test_sparse.py +47 -6
  133. warp/tests/test_spatial.py +75 -0
  134. warp/tests/test_static.py +1 -1
  135. warp/tests/test_utils.py +84 -4
  136. warp/tests/test_vec.py +46 -34
  137. warp/tests/tile/__init__.py +0 -0
  138. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  139. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
  140. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  141. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  142. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  143. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  144. warp/tests/unittest_serial.py +1 -0
  145. warp/tests/unittest_suites.py +45 -59
  146. warp/tests/unittest_utils.py +2 -1
  147. warp/thirdparty/unittest_parallel.py +3 -1
  148. warp/types.py +110 -658
  149. warp/utils.py +137 -72
  150. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
  151. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
  152. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  153. warp/examples/optim/example_walker.py +0 -317
  154. warp/native/cutlass_gemm.cpp +0 -43
  155. warp/native/cutlass_gemm.cu +0 -382
  156. warp/tests/test_matmul.py +0 -511
  157. warp/tests/test_matmul_lite.py +0 -411
  158. warp/tests/test_vbd.py +0 -386
  159. warp/tests/unused_test_misc.py +0 -77
  160. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  161. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  162. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  163. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  164. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  165. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  166. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  167. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  168. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  169. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  170. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  171. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  172. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  173. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  174. /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
  175. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  176. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  177. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  178. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
  179. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
@@ -1,317 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- ###########################################################################
17
- # Example Walker
18
- #
19
- # Trains a tetrahedral mesh quadruped to run. Feeds 8 time-varying input
20
- # phases as inputs into a single layer fully connected network with a tanh
21
- # activation function. Interprets the output of the network as tet
22
- # activations, which are fed into the wp.sim soft mesh model. This is
23
- # simulated forward in time and then evaluated based on the center of mass
24
- # momentum of the mesh.
25
- #
26
- # This example uses the deprecated wp.matmul() for matrix multiplication,
27
- # which will be removed in a future version. See the updated version of
28
- # this example, example_tile_walker.py, in examples/tile for the new
29
- # approach to GEMMs using Warp's tile API.
30
- #
31
- ###########################################################################
32
-
33
- import math
34
- import os
35
-
36
- import numpy as np
37
- from pxr import Gf, Usd, UsdGeom
38
-
39
- import warp as wp
40
- import warp.examples
41
- import warp.optim
42
- import warp.sim
43
- import warp.sim.render
44
-
45
-
46
- @wp.kernel
47
- def loss_kernel(com: wp.array(dtype=wp.vec3), loss: wp.array(dtype=float)):
48
- tid = wp.tid()
49
- vx = com[tid][0]
50
- vy = com[tid][1]
51
- vz = com[tid][2]
52
- delta = wp.sqrt(vx * vx) + wp.sqrt(vy * vy) - vz
53
-
54
- wp.atomic_add(loss, 0, delta)
55
-
56
-
57
- @wp.kernel
58
- def com_kernel(velocities: wp.array(dtype=wp.vec3), n: int, com: wp.array(dtype=wp.vec3)):
59
- tid = wp.tid()
60
- v = velocities[tid]
61
- a = v / wp.float32(n)
62
- wp.atomic_add(com, 0, a)
63
-
64
-
65
- @wp.kernel
66
- def compute_phases(phases: wp.array(dtype=float), sim_time: float):
67
- tid = wp.tid()
68
- phases[tid] = wp.sin(phase_freq * sim_time + wp.float32(tid) * phase_step)
69
-
70
-
71
- @wp.kernel
72
- def activation_function(tet_activations: wp.array(dtype=float), activation_inputs: wp.array(dtype=float)):
73
- tid = wp.tid()
74
- activation = wp.tanh(activation_inputs[tid])
75
- tet_activations[tid] = activation_strength * activation
76
-
77
-
78
- phase_count = 8
79
- phase_step = wp.constant((2.0 * math.pi) / phase_count)
80
- phase_freq = wp.constant(5.0)
81
- activation_strength = wp.constant(0.3)
82
-
83
-
84
- class Example:
85
- def __init__(self, stage_path="example_walker.usd", verbose=False, num_frames=300):
86
- self.verbose = verbose
87
-
88
- fps = 60
89
- self.frame_dt = 1.0 / fps
90
- self.num_frames = num_frames
91
-
92
- self.sim_substeps = 80
93
- self.sim_dt = self.frame_dt / self.sim_substeps
94
- self.sim_time = 0.0
95
-
96
- self.iter = 0
97
- self.train_rate = 0.025
98
-
99
- self.phase_count = phase_count
100
-
101
- self.render_time = 0.0
102
-
103
- # bear
104
- asset_stage = Usd.Stage.Open(os.path.join(warp.examples.get_asset_directory(), "bear.usd"))
105
-
106
- geom = UsdGeom.Mesh(asset_stage.GetPrimAtPath("/root/bear"))
107
- points = geom.GetPointsAttr().Get()
108
-
109
- xform = Gf.Matrix4f(geom.ComputeLocalToWorldTransform(0.0))
110
- for i in range(len(points)):
111
- points[i] = xform.Transform(points[i])
112
-
113
- self.points = [wp.vec3(point) for point in points]
114
- self.tet_indices = geom.GetPrim().GetAttribute("tetraIndices").Get()
115
-
116
- # sim model
117
- builder = wp.sim.ModelBuilder()
118
- builder.add_soft_mesh(
119
- pos=wp.vec3(0.0, 0.5, 0.0),
120
- rot=wp.quat_identity(),
121
- scale=1.0,
122
- vel=wp.vec3(0.0, 0.0, 0.0),
123
- vertices=self.points,
124
- indices=self.tet_indices,
125
- density=1.0,
126
- k_mu=2000.0,
127
- k_lambda=2000.0,
128
- k_damp=2.0,
129
- tri_ke=0.0,
130
- tri_ka=1e-8,
131
- tri_kd=0.0,
132
- tri_drag=0.0,
133
- tri_lift=0.0,
134
- )
135
-
136
- # finalize model
137
- self.model = builder.finalize(requires_grad=True)
138
- self.control = self.model.control()
139
-
140
- self.model.soft_contact_ke = 2.0e3
141
- self.model.soft_contact_kd = 0.1
142
- self.model.soft_contact_kf = 10.0
143
- self.model.soft_contact_mu = 0.7
144
-
145
- radii = wp.zeros(self.model.particle_count, dtype=float)
146
- radii.fill_(0.05)
147
- self.model.particle_radius = radii
148
- self.model.ground = True
149
-
150
- # allocate sim states
151
- self.states = []
152
- for _i in range(self.num_frames * self.sim_substeps + 1):
153
- self.states.append(self.model.state(requires_grad=True))
154
-
155
- # initialize the integrator.
156
- self.integrator = wp.sim.SemiImplicitIntegrator()
157
-
158
- # model input
159
- self.phases = []
160
- for _i in range(self.num_frames):
161
- self.phases.append(wp.zeros(self.phase_count, dtype=float, requires_grad=True))
162
-
163
- # single layer linear network
164
- rng = np.random.default_rng(42)
165
- k = 1.0 / self.phase_count
166
- weights = rng.uniform(-np.sqrt(k), np.sqrt(k), (self.model.tet_count, self.phase_count))
167
- self.weights = wp.array(weights, dtype=float, requires_grad=True)
168
- self.bias = wp.zeros(self.model.tet_count, dtype=float, requires_grad=True)
169
-
170
- # tanh activation layer
171
- self.activation_inputs = []
172
- self.tet_activations = []
173
- for _i in range(self.num_frames):
174
- self.activation_inputs.append(wp.zeros(self.model.tet_count, dtype=float, requires_grad=True))
175
- self.tet_activations.append(wp.zeros(self.model.tet_count, dtype=float, requires_grad=True))
176
-
177
- # optimization
178
- self.loss = wp.zeros(1, dtype=float, requires_grad=True)
179
- self.coms = []
180
- for _i in range(self.num_frames):
181
- self.coms.append(wp.zeros(1, dtype=wp.vec3, requires_grad=True))
182
- self.optimizer = warp.optim.Adam([self.weights.flatten()], lr=self.train_rate)
183
-
184
- # rendering
185
- if stage_path:
186
- self.renderer = wp.sim.render.SimRenderer(self.model, stage_path)
187
- else:
188
- self.renderer = None
189
-
190
- # capture forward/backward passes
191
- self.use_cuda_graph = wp.get_device().is_cuda
192
- if self.use_cuda_graph:
193
- with wp.ScopedCapture() as capture:
194
- self.tape = wp.Tape()
195
- with self.tape:
196
- for i in range(self.num_frames):
197
- self.forward(i)
198
- self.tape.backward(self.loss)
199
- self.graph = capture.graph
200
-
201
- def forward(self, frame):
202
- with wp.ScopedTimer("network", active=self.verbose):
203
- # build sinusoidal input phases
204
- wp.launch(kernel=compute_phases, dim=self.phase_count, inputs=[self.phases[frame], self.sim_time])
205
- # fully connected, linear transformation layer
206
- wp.matmul(
207
- self.weights,
208
- self.phases[frame].reshape((self.phase_count, 1)),
209
- self.bias.reshape((self.model.tet_count, 1)),
210
- self.activation_inputs[frame].reshape((self.model.tet_count, 1)),
211
- )
212
- # tanh activation function
213
- wp.launch(
214
- kernel=activation_function,
215
- dim=self.model.tet_count,
216
- inputs=[self.tet_activations[frame], self.activation_inputs[frame]],
217
- )
218
- self.control.tet_activations = self.tet_activations[frame]
219
-
220
- with wp.ScopedTimer("simulate", active=self.verbose):
221
- # run simulation loop
222
- for i in range(self.sim_substeps):
223
- self.states[frame * self.sim_substeps + i].clear_forces()
224
- self.integrator.simulate(
225
- self.model,
226
- self.states[frame * self.sim_substeps + i],
227
- self.states[frame * self.sim_substeps + i + 1],
228
- self.sim_dt,
229
- self.control,
230
- )
231
- self.sim_time += self.sim_dt
232
-
233
- with wp.ScopedTimer("loss", active=self.verbose):
234
- # compute center of mass velocity
235
- wp.launch(
236
- com_kernel,
237
- dim=self.model.particle_count,
238
- inputs=[
239
- self.states[(frame + 1) * self.sim_substeps].particle_qd,
240
- self.model.particle_count,
241
- self.coms[frame],
242
- ],
243
- outputs=[],
244
- )
245
- # compute loss
246
- wp.launch(loss_kernel, dim=1, inputs=[self.coms[frame], self.loss], outputs=[])
247
-
248
- def step(self):
249
- with wp.ScopedTimer("step"):
250
- if self.use_cuda_graph:
251
- wp.capture_launch(self.graph)
252
- else:
253
- self.tape = wp.Tape()
254
- with self.tape:
255
- for i in range(self.num_frames):
256
- self.forward(i)
257
- self.tape.backward(self.loss)
258
-
259
- # optimization
260
- x = self.weights.grad.flatten()
261
- self.optimizer.step([x])
262
-
263
- loss = self.loss.numpy()
264
- if self.verbose:
265
- print(f"Iteration {self.iter}: {loss}")
266
-
267
- # reset sim
268
- self.sim_time = 0.0
269
- self.states[0] = self.model.state(requires_grad=True)
270
-
271
- # clear grads and zero arrays for next iteration
272
- self.tape.zero()
273
- self.loss.zero_()
274
- for i in range(self.num_frames):
275
- self.coms[i].zero_()
276
-
277
- self.iter += 1
278
-
279
- def render(self):
280
- if self.renderer is None:
281
- return
282
-
283
- with wp.ScopedTimer("render"):
284
- for i in range(self.num_frames + 1):
285
- self.renderer.begin_frame(self.render_time)
286
- self.renderer.render(self.states[i * self.sim_substeps])
287
- self.renderer.end_frame()
288
-
289
- self.render_time += self.frame_dt
290
-
291
-
292
- if __name__ == "__main__":
293
- import argparse
294
-
295
- parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
296
- parser.add_argument("--device", type=str, default=None, help="Override the default Warp device.")
297
- parser.add_argument(
298
- "--stage_path",
299
- type=lambda x: None if x == "None" else str(x),
300
- default="example_walker.usd",
301
- help="Path to the output USD file.",
302
- )
303
- parser.add_argument("--num_frames", type=int, default=300, help="Total number of frames per training iteration.")
304
- parser.add_argument("--train_iters", type=int, default=30, help="Total number of training iterations.")
305
- parser.add_argument("--verbose", action="store_true", help="Print out additional status messages during execution.")
306
-
307
- args = parser.parse_known_args()[0]
308
-
309
- with wp.ScopedDevice(args.device):
310
- example = Example(stage_path=args.stage_path, verbose=args.verbose, num_frames=args.num_frames)
311
-
312
- for _ in range(args.train_iters):
313
- example.step()
314
- example.render()
315
-
316
- if example.renderer:
317
- example.renderer.save()
@@ -1,43 +0,0 @@
1
- /*
2
- * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: Apache-2.0
4
- *
5
- * Licensed under the Apache License, Version 2.0 (the "License");
6
- * you may not use this file except in compliance with the License.
7
- * You may obtain a copy of the License at
8
- *
9
- * http://www.apache.org/licenses/LICENSE-2.0
10
- *
11
- * Unless required by applicable law or agreed to in writing, software
12
- * distributed under the License is distributed on an "AS IS" BASIS,
13
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- * See the License for the specific language governing permissions and
15
- * limitations under the License.
16
- */
17
-
18
- #include "builtin.h"
19
-
20
- // stubs for platforms where there is no CUDA
21
- #if !WP_ENABLE_CUDA || !WP_ENABLE_CUTLASS
22
-
23
- extern "C"
24
- {
25
-
26
- WP_API
27
- bool cutlass_gemm(
28
- void* context, int compute_capability,
29
- int m, int n, int k,
30
- const char* datatype_str,
31
- const void* a, const void* b, const void* c, void* d,
32
- float alpha, float beta,
33
- bool row_major_a, bool row_major_b,
34
- bool allow_tf32x3_arith,
35
- int batch_count)
36
- {
37
- printf("CUDA is disabled and/or CUTLASS is disabled.\n");
38
- return false;
39
- }
40
-
41
- } // extern "C"
42
-
43
- #endif // !WP_ENABLE_CUDA || !WP_ENABLE_CUTLASS