warp-lang 0.11.0__py3-none-manylinux2014_x86_64.whl → 1.0.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (170) hide show
  1. warp/__init__.py +8 -0
  2. warp/bin/warp-clang.so +0 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +7 -6
  5. warp/build_dll.py +70 -79
  6. warp/builtins.py +10 -6
  7. warp/codegen.py +51 -19
  8. warp/config.py +7 -8
  9. warp/constants.py +3 -0
  10. warp/context.py +948 -245
  11. warp/dlpack.py +198 -113
  12. warp/examples/assets/bunny.usd +0 -0
  13. warp/examples/assets/cartpole.urdf +110 -0
  14. warp/examples/assets/crazyflie.usd +0 -0
  15. warp/examples/assets/cube.usda +42 -0
  16. warp/examples/assets/nv_ant.xml +92 -0
  17. warp/examples/assets/nv_humanoid.xml +183 -0
  18. warp/examples/assets/quadruped.urdf +268 -0
  19. warp/examples/assets/rocks.nvdb +0 -0
  20. warp/examples/assets/rocks.usd +0 -0
  21. warp/examples/assets/sphere.usda +56 -0
  22. warp/examples/assets/torus.usda +105 -0
  23. warp/examples/benchmarks/benchmark_api.py +383 -0
  24. warp/examples/benchmarks/benchmark_cloth.py +279 -0
  25. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
  26. warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
  27. warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
  28. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
  29. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
  30. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
  31. warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
  32. warp/examples/benchmarks/benchmark_launches.py +295 -0
  33. warp/examples/core/example_dem.py +221 -0
  34. warp/examples/core/example_fluid.py +267 -0
  35. warp/examples/core/example_graph_capture.py +129 -0
  36. warp/examples/core/example_marching_cubes.py +177 -0
  37. warp/examples/core/example_mesh.py +154 -0
  38. warp/examples/core/example_mesh_intersect.py +193 -0
  39. warp/examples/core/example_nvdb.py +169 -0
  40. warp/examples/core/example_raycast.py +89 -0
  41. warp/examples/core/example_raymarch.py +178 -0
  42. warp/examples/core/example_render_opengl.py +141 -0
  43. warp/examples/core/example_sph.py +389 -0
  44. warp/examples/core/example_torch.py +181 -0
  45. warp/examples/core/example_wave.py +249 -0
  46. warp/examples/fem/bsr_utils.py +380 -0
  47. warp/examples/fem/example_apic_fluid.py +391 -0
  48. warp/examples/fem/example_convection_diffusion.py +168 -0
  49. warp/examples/fem/example_convection_diffusion_dg.py +209 -0
  50. warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
  51. warp/examples/fem/example_deformed_geometry.py +159 -0
  52. warp/examples/fem/example_diffusion.py +173 -0
  53. warp/examples/fem/example_diffusion_3d.py +152 -0
  54. warp/examples/fem/example_diffusion_mgpu.py +214 -0
  55. warp/examples/fem/example_mixed_elasticity.py +222 -0
  56. warp/examples/fem/example_navier_stokes.py +243 -0
  57. warp/examples/fem/example_stokes.py +192 -0
  58. warp/examples/fem/example_stokes_transfer.py +249 -0
  59. warp/examples/fem/mesh_utils.py +109 -0
  60. warp/examples/fem/plot_utils.py +287 -0
  61. warp/examples/optim/example_bounce.py +248 -0
  62. warp/examples/optim/example_cloth_throw.py +210 -0
  63. warp/examples/optim/example_diffray.py +535 -0
  64. warp/examples/optim/example_drone.py +850 -0
  65. warp/examples/optim/example_inverse_kinematics.py +169 -0
  66. warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
  67. warp/examples/optim/example_spring_cage.py +234 -0
  68. warp/examples/optim/example_trajectory.py +201 -0
  69. warp/examples/sim/example_cartpole.py +128 -0
  70. warp/examples/sim/example_cloth.py +184 -0
  71. warp/examples/sim/example_granular.py +113 -0
  72. warp/examples/sim/example_granular_collision_sdf.py +185 -0
  73. warp/examples/sim/example_jacobian_ik.py +213 -0
  74. warp/examples/sim/example_particle_chain.py +106 -0
  75. warp/examples/sim/example_quadruped.py +179 -0
  76. warp/examples/sim/example_rigid_chain.py +191 -0
  77. warp/examples/sim/example_rigid_contact.py +176 -0
  78. warp/examples/sim/example_rigid_force.py +126 -0
  79. warp/examples/sim/example_rigid_gyroscopic.py +97 -0
  80. warp/examples/sim/example_rigid_soft_contact.py +124 -0
  81. warp/examples/sim/example_soft_body.py +178 -0
  82. warp/fabric.py +29 -20
  83. warp/fem/cache.py +0 -1
  84. warp/fem/dirichlet.py +0 -2
  85. warp/fem/integrate.py +0 -1
  86. warp/jax.py +45 -0
  87. warp/jax_experimental.py +339 -0
  88. warp/native/builtin.h +12 -0
  89. warp/native/bvh.cu +18 -18
  90. warp/native/clang/clang.cpp +8 -3
  91. warp/native/cuda_util.cpp +94 -5
  92. warp/native/cuda_util.h +35 -6
  93. warp/native/cutlass_gemm.cpp +1 -1
  94. warp/native/cutlass_gemm.cu +4 -1
  95. warp/native/error.cpp +66 -0
  96. warp/native/error.h +27 -0
  97. warp/native/mesh.cu +2 -2
  98. warp/native/reduce.cu +4 -4
  99. warp/native/runlength_encode.cu +2 -2
  100. warp/native/scan.cu +2 -2
  101. warp/native/sparse.cu +0 -1
  102. warp/native/temp_buffer.h +2 -2
  103. warp/native/warp.cpp +95 -60
  104. warp/native/warp.cu +1053 -218
  105. warp/native/warp.h +49 -32
  106. warp/optim/linear.py +33 -16
  107. warp/render/render_opengl.py +202 -101
  108. warp/render/render_usd.py +82 -40
  109. warp/sim/__init__.py +13 -4
  110. warp/sim/articulation.py +4 -5
  111. warp/sim/collide.py +320 -175
  112. warp/sim/import_mjcf.py +25 -30
  113. warp/sim/import_urdf.py +94 -63
  114. warp/sim/import_usd.py +51 -36
  115. warp/sim/inertia.py +3 -2
  116. warp/sim/integrator.py +233 -0
  117. warp/sim/integrator_euler.py +447 -469
  118. warp/sim/integrator_featherstone.py +1991 -0
  119. warp/sim/integrator_xpbd.py +1420 -640
  120. warp/sim/model.py +765 -487
  121. warp/sim/particles.py +2 -1
  122. warp/sim/render.py +35 -13
  123. warp/sim/utils.py +222 -11
  124. warp/stubs.py +8 -0
  125. warp/tape.py +16 -1
  126. warp/tests/aux_test_grad_customs.py +23 -0
  127. warp/tests/test_array.py +190 -1
  128. warp/tests/test_async.py +656 -0
  129. warp/tests/test_bool.py +50 -0
  130. warp/tests/test_dlpack.py +164 -11
  131. warp/tests/test_examples.py +166 -74
  132. warp/tests/test_fem.py +8 -1
  133. warp/tests/test_generics.py +15 -5
  134. warp/tests/test_grad.py +1 -1
  135. warp/tests/test_grad_customs.py +172 -12
  136. warp/tests/test_jax.py +254 -0
  137. warp/tests/test_large.py +29 -6
  138. warp/tests/test_launch.py +25 -0
  139. warp/tests/test_linear_solvers.py +20 -3
  140. warp/tests/test_matmul.py +61 -16
  141. warp/tests/test_matmul_lite.py +13 -13
  142. warp/tests/test_mempool.py +186 -0
  143. warp/tests/test_multigpu.py +3 -0
  144. warp/tests/test_options.py +16 -2
  145. warp/tests/test_peer.py +137 -0
  146. warp/tests/test_print.py +3 -1
  147. warp/tests/test_quat.py +23 -0
  148. warp/tests/test_sim_kinematics.py +97 -0
  149. warp/tests/test_snippet.py +126 -3
  150. warp/tests/test_streams.py +108 -79
  151. warp/tests/test_torch.py +16 -8
  152. warp/tests/test_utils.py +32 -27
  153. warp/tests/test_verify_fp.py +65 -0
  154. warp/tests/test_volume.py +1 -1
  155. warp/tests/unittest_serial.py +2 -0
  156. warp/tests/unittest_suites.py +12 -0
  157. warp/tests/unittest_utils.py +14 -7
  158. warp/thirdparty/unittest_parallel.py +15 -3
  159. warp/torch.py +10 -8
  160. warp/types.py +363 -246
  161. warp/utils.py +143 -19
  162. warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
  163. warp_lang-1.0.0.dist-info/METADATA +394 -0
  164. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
  165. warp/sim/optimizer.py +0 -138
  166. warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
  167. warp_lang-0.11.0.dist-info/METADATA +0 -238
  168. /warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
  169. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
  170. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
warp/native/cuda_util.h CHANGED
@@ -17,8 +17,10 @@
17
17
 
18
18
  #include <stdio.h>
19
19
 
20
- #define check_cuda(code) (check_cuda_result(code, __FILE__, __LINE__))
21
- #define check_cu(code) (check_cu_result(code, __FILE__, __LINE__))
20
+ #include <vector>
21
+
22
+ #define check_cuda(code) (check_cuda_result(code, __FUNCTION__, __FILE__, __LINE__))
23
+ #define check_cu(code) (check_cu_result(code, __FUNCTION__, __FILE__, __LINE__))
22
24
 
23
25
 
24
26
  #if defined(__CUDACC__)
@@ -55,6 +57,7 @@ CUresult cuDeviceGetUuid_f(CUuuid* uuid, CUdevice dev);
55
57
  CUresult cuDevicePrimaryCtxRetain_f(CUcontext* ctx, CUdevice dev);
56
58
  CUresult cuDevicePrimaryCtxRelease_f(CUdevice dev);
57
59
  CUresult cuDeviceCanAccessPeer_f(int* can_access, CUdevice dev, CUdevice peer_dev);
60
+ CUresult cuMemGetInfo_f(size_t* free, size_t* total);
58
61
  CUresult cuCtxGetCurrent_f(CUcontext* ctx);
59
62
  CUresult cuCtxSetCurrent_f(CUcontext ctx);
60
63
  CUresult cuCtxPushCurrent_f(CUcontext ctx);
@@ -64,18 +67,23 @@ CUresult cuCtxGetDevice_f(CUdevice* dev);
64
67
  CUresult cuCtxCreate_f(CUcontext* ctx, unsigned int flags, CUdevice dev);
65
68
  CUresult cuCtxDestroy_f(CUcontext ctx);
66
69
  CUresult cuCtxEnablePeerAccess_f(CUcontext peer_ctx, unsigned int flags);
70
+ CUresult cuCtxDisablePeerAccess_f(CUcontext peer_ctx);
67
71
  CUresult cuStreamCreate_f(CUstream* stream, unsigned int flags);
68
72
  CUresult cuStreamDestroy_f(CUstream stream);
69
73
  CUresult cuStreamSynchronize_f(CUstream stream);
70
74
  CUresult cuStreamWaitEvent_f(CUstream stream, CUevent event, unsigned int flags);
75
+ CUresult cuStreamGetCaptureInfo_f(CUstream stream, CUstreamCaptureStatus *captureStatus_out, cuuint64_t *id_out, CUgraph *graph_out, const CUgraphNode **dependencies_out, size_t *numDependencies_out);
76
+ CUresult cuStreamUpdateCaptureDependencies_f(CUstream stream, CUgraphNode *dependencies, size_t numDependencies, unsigned int flags);
71
77
  CUresult cuEventCreate_f(CUevent* event, unsigned int flags);
72
78
  CUresult cuEventDestroy_f(CUevent event);
73
79
  CUresult cuEventRecord_f(CUevent event, CUstream stream);
80
+ CUresult cuEventRecordWithFlags_f(CUevent event, CUstream stream, unsigned int flags);
74
81
  CUresult cuModuleUnload_f(CUmodule hmod);
75
82
  CUresult cuModuleLoadDataEx_f(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
76
83
  CUresult cuModuleGetFunction_f(CUfunction *hfunc, CUmodule hmod, const char *name);
77
84
  CUresult cuLaunchKernel_f(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
78
85
  CUresult cuMemcpyPeerAsync_f(CUdeviceptr dst_ptr, CUcontext dst_ctx, CUdeviceptr src_ptr, CUcontext src_ctx, size_t n, CUstream stream);
86
+ CUresult cuPointerGetAttribute_f(void* data, CUpointer_attribute attribute, CUdeviceptr ptr);
79
87
  CUresult cuGraphicsMapResources_f(unsigned int count, CUgraphicsResource* resources, CUstream stream);
80
88
  CUresult cuGraphicsUnmapResources_f(unsigned int count, CUgraphicsResource* resources, CUstream hStream);
81
89
  CUresult cuGraphicsResourceGetMappedPointer_f(CUdeviceptr* pDevPtr, size_t* pSize, CUgraphicsResource resource);
@@ -86,13 +94,34 @@ CUresult cuGraphicsUnregisterResource_f(CUgraphicsResource resource);
86
94
  bool init_cuda_driver();
87
95
  bool is_cuda_driver_initialized();
88
96
 
89
- bool check_cuda_result(cudaError_t code, const char* file, int line);
90
- inline bool check_cuda_result(uint64_t code, const char* file, int line)
97
+ bool check_cuda_result(cudaError_t code, const char* func, const char* file, int line);
98
+
99
+ inline bool check_cuda_result(uint64_t code, const char* func, const char* file, int line)
100
+ {
101
+ return check_cuda_result(static_cast<cudaError_t>(code), func, file, line);
102
+ }
103
+
104
+ bool check_cu_result(CUresult result, const char* func, const char* file, int line);
105
+
106
+ inline uint64_t get_capture_id(CUstream stream)
91
107
  {
92
- return check_cuda_result(static_cast<cudaError_t>(code), file, line);
108
+ CUstreamCaptureStatus status;
109
+ uint64_t id = 0;
110
+ check_cu(cuStreamGetCaptureInfo_f(stream, &status, &id, NULL, NULL, NULL));
111
+ return id;
93
112
  }
94
113
 
95
- bool check_cu_result(CUresult result, const char* file, int line);
114
+ inline CUgraph get_capture_graph(CUstream stream)
115
+ {
116
+ CUstreamCaptureStatus status;
117
+ CUgraph graph = NULL;
118
+ check_cu(cuStreamGetCaptureInfo_f(stream, &status, NULL, &graph, NULL, NULL));
119
+ return graph;
120
+ }
121
+
122
+ bool get_capture_dependencies(CUstream stream, std::vector<CUgraphNode>& dependencies_ret);
123
+
124
+ bool get_graph_leaf_nodes(cudaGraph_t graph, std::vector<cudaGraphNode_t>& leaf_nodes_ret);
96
125
 
97
126
 
98
127
  //
@@ -16,7 +16,7 @@ extern "C"
16
16
 
17
17
  WP_API
18
18
  bool cutlass_gemm(
19
- int compute_capability,
19
+ void* context, int compute_capability,
20
20
  int m, int n, int k,
21
21
  const char* datatype_str,
22
22
  const void* a, const void* b, const void* c, void* d,
@@ -8,6 +8,7 @@
8
8
 
9
9
  #include "builtin.h"
10
10
  #include "temp_buffer.h"
11
+ #include "cuda_util.h"
11
12
 
12
13
  #include "cutlass/cutlass.h"
13
14
  #include "cutlass/gemm/device/gemm_universal.h"
@@ -226,7 +227,7 @@ extern "C" {
226
227
 
227
228
  WP_API
228
229
  bool cutlass_gemm(
229
- int compute_capability,
230
+ void* context, int compute_capability,
230
231
  int m, int n, int k,
231
232
  const char* datatype_str,
232
233
  const void* a, const void* b, const void* c, void* d,
@@ -237,6 +238,8 @@ bool cutlass_gemm(
237
238
 
238
239
  std::string datatype(datatype_str);
239
240
 
241
+ ContextGuard guard(context);
242
+
240
243
  // Specializations for using Tensor Cores and A/B RowMajor/ColumnMajor designations
241
244
  if (compute_capability == 80) {
242
245
  if (datatype == F64_STR) {
warp/native/error.cpp ADDED
@@ -0,0 +1,66 @@
1
+ /** Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
+ * NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ * and proprietary rights in and to this software, related documentation
4
+ * and any modifications thereto. Any use, reproduction, disclosure or
5
+ * distribution of this software and related documentation without an express
6
+ * license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+ */
8
+
9
+ #include <stdarg.h>
10
+ #include <stdio.h>
11
+ #include <string.h>
12
+
13
+ namespace wp
14
+ {
15
+ static char g_error_buffer[4096] = "";
16
+ static bool g_error_output_enabled = true;
17
+ static FILE* g_error_stream = stderr;
18
+
19
+ const char* get_error_string()
20
+ {
21
+ return g_error_buffer;
22
+ }
23
+
24
+ void set_error_string(const char* fmt, ...)
25
+ {
26
+ va_list args;
27
+ va_start(args, fmt);
28
+ vsnprintf(g_error_buffer, sizeof(g_error_buffer), fmt, args);
29
+ if (g_error_output_enabled)
30
+ {
31
+ vfprintf(g_error_stream, fmt, args);
32
+ fputc('\n', g_error_stream);
33
+ fflush(g_error_stream);
34
+ }
35
+ va_end(args);
36
+ }
37
+
38
+ void append_error_string(const char* fmt, ...)
39
+ {
40
+ size_t offset = strlen(g_error_buffer);
41
+ if (offset + 2 > sizeof(g_error_buffer))
42
+ return;
43
+ g_error_buffer[offset++] = '\n';
44
+ va_list args;
45
+ va_start(args, fmt);
46
+ vsnprintf(g_error_buffer + offset, sizeof(g_error_buffer) - offset, fmt, args);
47
+ if (g_error_output_enabled)
48
+ {
49
+ vfprintf(g_error_stream, fmt, args);
50
+ fputc('\n', g_error_stream);
51
+ fflush(g_error_stream);
52
+ }
53
+ va_end(args);
54
+ }
55
+
56
+ void set_error_output_enabled(bool enable)
57
+ {
58
+ g_error_output_enabled = enable;
59
+ }
60
+
61
+ bool is_error_output_enabled()
62
+ {
63
+ return g_error_output_enabled;
64
+ }
65
+
66
+ } // end of namespace wp
warp/native/error.h ADDED
@@ -0,0 +1,27 @@
1
+ /** Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
+ * NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ * and proprietary rights in and to this software, related documentation
4
+ * and any modifications thereto. Any use, reproduction, disclosure or
5
+ * distribution of this software and related documentation without an express
6
+ * license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+ */
8
+
9
+ #pragma once
10
+
11
+ namespace wp
12
+ {
13
+ // functions related to error reporting
14
+
15
+ // get error string from Python
16
+ const char* get_error_string();
17
+
18
+ // set error message for Python
19
+ // these functions also print the error message if error output is enabled
20
+ void set_error_string(const char* fmt, ...);
21
+ void append_error_string(const char* fmt, ...);
22
+
23
+ // allow disabling printing errors, which is handy during tests that expect failure
24
+ void set_error_output_enabled(bool enable);
25
+ bool is_error_output_enabled();
26
+
27
+ }
warp/native/mesh.cu CHANGED
@@ -203,8 +203,8 @@ uint64_t mesh_create_device(void* context, wp::array_t<wp::vec3> points, wp::arr
203
203
  // bvh_destroy_host(bvh_host);
204
204
 
205
205
  // create lower upper arrays expected by GPU BVH builder
206
- mesh.lowers = (wp::vec3*)alloc_temp_device(WP_CURRENT_CONTEXT, sizeof(wp::vec3)*num_tris);
207
- mesh.uppers = (wp::vec3*)alloc_temp_device(WP_CURRENT_CONTEXT, sizeof(wp::vec3)*num_tris);
206
+ mesh.lowers = (wp::vec3*)alloc_device(WP_CURRENT_CONTEXT, sizeof(wp::vec3)*num_tris);
207
+ mesh.uppers = (wp::vec3*)alloc_device(WP_CURRENT_CONTEXT, sizeof(wp::vec3)*num_tris);
208
208
 
209
209
  wp_launch_device(WP_CURRENT_CONTEXT, wp::compute_triangle_bounds, num_tris, (num_tris, points.data, indices.data, mesh.lowers, mesh.uppers));
210
210
 
warp/native/reduce.cu CHANGED
@@ -110,7 +110,7 @@ template <typename T> void array_sum_device(const T *ptr_a, T *ptr_out, int coun
110
110
 
111
111
  size_t buff_size = 0;
112
112
  check_cuda(cub::DeviceReduce::Sum(nullptr, buff_size, ptr_strided, ptr_out, count, stream));
113
- void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
113
+ void* temp_buffer = alloc_device(WP_CURRENT_CONTEXT, buff_size);
114
114
 
115
115
  for (int k = 0; k < type_length; ++k)
116
116
  {
@@ -118,7 +118,7 @@ template <typename T> void array_sum_device(const T *ptr_a, T *ptr_out, int coun
118
118
  check_cuda(cub::DeviceReduce::Sum(temp_buffer, buff_size, ptr_strided, ptr_out + k, count, stream));
119
119
  }
120
120
 
121
- free_temp_device(WP_CURRENT_CONTEXT, temp_buffer);
121
+ free_device(WP_CURRENT_CONTEXT, temp_buffer);
122
122
  }
123
123
 
124
124
  template <typename T>
@@ -271,11 +271,11 @@ void array_inner_device(const ElemT *ptr_a, const ElemT *ptr_b, ScalarT *ptr_out
271
271
 
272
272
  size_t buff_size = 0;
273
273
  check_cuda(cub::DeviceReduce::Sum(nullptr, buff_size, inner_iterator, ptr_out, count, stream));
274
- void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
274
+ void* temp_buffer = alloc_device(WP_CURRENT_CONTEXT, buff_size);
275
275
 
276
276
  check_cuda(cub::DeviceReduce::Sum(temp_buffer, buff_size, inner_iterator, ptr_out, count, stream));
277
277
 
278
- free_temp_device(WP_CURRENT_CONTEXT, temp_buffer);
278
+ free_device(WP_CURRENT_CONTEXT, temp_buffer);
279
279
  }
280
280
 
281
281
  template <typename T>
@@ -21,13 +21,13 @@ void runlength_encode_device(int n,
21
21
  nullptr, buff_size, values, run_values, run_lengths, run_count,
22
22
  n, stream));
23
23
 
24
- void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
24
+ void* temp_buffer = alloc_device(WP_CURRENT_CONTEXT, buff_size);
25
25
 
26
26
  check_cuda(cub::DeviceRunLengthEncode::Encode(
27
27
  temp_buffer, buff_size, values, run_values, run_lengths, run_count,
28
28
  n, stream));
29
29
 
30
- free_temp_device(WP_CURRENT_CONTEXT, temp_buffer);
30
+ free_device(WP_CURRENT_CONTEXT, temp_buffer);
31
31
  }
32
32
 
33
33
  void runlength_encode_int_device(
warp/native/scan.cu CHANGED
@@ -20,7 +20,7 @@ void scan_device(const T* values_in, T* values_out, int n, bool inclusive)
20
20
  check_cuda(cub::DeviceScan::ExclusiveSum(NULL, scan_temp_size, values_in, values_out, n));
21
21
  }
22
22
 
23
- void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, scan_temp_size);
23
+ void* temp_buffer = alloc_device(WP_CURRENT_CONTEXT, scan_temp_size);
24
24
 
25
25
  // scan
26
26
  if (inclusive) {
@@ -29,7 +29,7 @@ void scan_device(const T* values_in, T* values_out, int n, bool inclusive)
29
29
  check_cuda(cub::DeviceScan::ExclusiveSum(temp_buffer, scan_temp_size, values_in, values_out, n, stream));
30
30
  }
31
31
 
32
- free_temp_device(WP_CURRENT_CONTEXT, temp_buffer);
32
+ free_device(WP_CURRENT_CONTEXT, temp_buffer);
33
33
  }
34
34
 
35
35
  template void scan_device(const int*, int*, int, bool);
warp/native/sparse.cu CHANGED
@@ -456,7 +456,6 @@ void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count,
456
456
  size_t buff_size = 0;
457
457
  check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values,
458
458
  d_keys, nnz, 0, 64, stream));
459
- void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
460
459
  ScopedTemporary<> temp(context, buff_size);
461
460
  check_cuda(cub::DeviceRadixSort::SortPairs(
462
461
  temp.buffer(), buff_size, d_values, d_keys, nnz, 0, 64, stream));
warp/native/temp_buffer.h CHANGED
@@ -10,13 +10,13 @@ template <typename T = char> struct ScopedTemporary
10
10
  {
11
11
 
12
12
  ScopedTemporary(void *context, size_t size)
13
- : m_context(context), m_buffer(static_cast<T*>(alloc_temp_device(m_context, size * sizeof(T))))
13
+ : m_context(context), m_buffer(static_cast<T*>(alloc_device(m_context, size * sizeof(T))))
14
14
  {
15
15
  }
16
16
 
17
17
  ~ScopedTemporary()
18
18
  {
19
- free_temp_device(m_context, m_buffer);
19
+ free_device(m_context, m_buffer);
20
20
  }
21
21
 
22
22
  T *buffer() const