warp-lang 0.11.0-py3-none-manylinux2014_x86_64.whl → 1.0.0-py3-none-manylinux2014_x86_64.whl
- warp/__init__.py +8 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +7 -6
- warp/build_dll.py +70 -79
- warp/builtins.py +10 -6
- warp/codegen.py +51 -19
- warp/config.py +7 -8
- warp/constants.py +3 -0
- warp/context.py +948 -245
- warp/dlpack.py +198 -113
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usda +42 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usda +56 -0
- warp/examples/assets/torus.usda +105 -0
- warp/examples/benchmarks/benchmark_api.py +383 -0
- warp/examples/benchmarks/benchmark_cloth.py +279 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
- warp/examples/benchmarks/benchmark_launches.py +295 -0
- warp/examples/core/example_dem.py +221 -0
- warp/examples/core/example_fluid.py +267 -0
- warp/examples/core/example_graph_capture.py +129 -0
- warp/examples/core/example_marching_cubes.py +177 -0
- warp/examples/core/example_mesh.py +154 -0
- warp/examples/core/example_mesh_intersect.py +193 -0
- warp/examples/core/example_nvdb.py +169 -0
- warp/examples/core/example_raycast.py +89 -0
- warp/examples/core/example_raymarch.py +178 -0
- warp/examples/core/example_render_opengl.py +141 -0
- warp/examples/core/example_sph.py +389 -0
- warp/examples/core/example_torch.py +181 -0
- warp/examples/core/example_wave.py +249 -0
- warp/examples/fem/bsr_utils.py +380 -0
- warp/examples/fem/example_apic_fluid.py +391 -0
- warp/examples/fem/example_convection_diffusion.py +168 -0
- warp/examples/fem/example_convection_diffusion_dg.py +209 -0
- warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
- warp/examples/fem/example_deformed_geometry.py +159 -0
- warp/examples/fem/example_diffusion.py +173 -0
- warp/examples/fem/example_diffusion_3d.py +152 -0
- warp/examples/fem/example_diffusion_mgpu.py +214 -0
- warp/examples/fem/example_mixed_elasticity.py +222 -0
- warp/examples/fem/example_navier_stokes.py +243 -0
- warp/examples/fem/example_stokes.py +192 -0
- warp/examples/fem/example_stokes_transfer.py +249 -0
- warp/examples/fem/mesh_utils.py +109 -0
- warp/examples/fem/plot_utils.py +287 -0
- warp/examples/optim/example_bounce.py +248 -0
- warp/examples/optim/example_cloth_throw.py +210 -0
- warp/examples/optim/example_diffray.py +535 -0
- warp/examples/optim/example_drone.py +850 -0
- warp/examples/optim/example_inverse_kinematics.py +169 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
- warp/examples/optim/example_spring_cage.py +234 -0
- warp/examples/optim/example_trajectory.py +201 -0
- warp/examples/sim/example_cartpole.py +128 -0
- warp/examples/sim/example_cloth.py +184 -0
- warp/examples/sim/example_granular.py +113 -0
- warp/examples/sim/example_granular_collision_sdf.py +185 -0
- warp/examples/sim/example_jacobian_ik.py +213 -0
- warp/examples/sim/example_particle_chain.py +106 -0
- warp/examples/sim/example_quadruped.py +179 -0
- warp/examples/sim/example_rigid_chain.py +191 -0
- warp/examples/sim/example_rigid_contact.py +176 -0
- warp/examples/sim/example_rigid_force.py +126 -0
- warp/examples/sim/example_rigid_gyroscopic.py +97 -0
- warp/examples/sim/example_rigid_soft_contact.py +124 -0
- warp/examples/sim/example_soft_body.py +178 -0
- warp/fabric.py +29 -20
- warp/fem/cache.py +0 -1
- warp/fem/dirichlet.py +0 -2
- warp/fem/integrate.py +0 -1
- warp/jax.py +45 -0
- warp/jax_experimental.py +339 -0
- warp/native/builtin.h +12 -0
- warp/native/bvh.cu +18 -18
- warp/native/clang/clang.cpp +8 -3
- warp/native/cuda_util.cpp +94 -5
- warp/native/cuda_util.h +35 -6
- warp/native/cutlass_gemm.cpp +1 -1
- warp/native/cutlass_gemm.cu +4 -1
- warp/native/error.cpp +66 -0
- warp/native/error.h +27 -0
- warp/native/mesh.cu +2 -2
- warp/native/reduce.cu +4 -4
- warp/native/runlength_encode.cu +2 -2
- warp/native/scan.cu +2 -2
- warp/native/sparse.cu +0 -1
- warp/native/temp_buffer.h +2 -2
- warp/native/warp.cpp +95 -60
- warp/native/warp.cu +1053 -218
- warp/native/warp.h +49 -32
- warp/optim/linear.py +33 -16
- warp/render/render_opengl.py +202 -101
- warp/render/render_usd.py +82 -40
- warp/sim/__init__.py +13 -4
- warp/sim/articulation.py +4 -5
- warp/sim/collide.py +320 -175
- warp/sim/import_mjcf.py +25 -30
- warp/sim/import_urdf.py +94 -63
- warp/sim/import_usd.py +51 -36
- warp/sim/inertia.py +3 -2
- warp/sim/integrator.py +233 -0
- warp/sim/integrator_euler.py +447 -469
- warp/sim/integrator_featherstone.py +1991 -0
- warp/sim/integrator_xpbd.py +1420 -640
- warp/sim/model.py +765 -487
- warp/sim/particles.py +2 -1
- warp/sim/render.py +35 -13
- warp/sim/utils.py +222 -11
- warp/stubs.py +8 -0
- warp/tape.py +16 -1
- warp/tests/aux_test_grad_customs.py +23 -0
- warp/tests/test_array.py +190 -1
- warp/tests/test_async.py +656 -0
- warp/tests/test_bool.py +50 -0
- warp/tests/test_dlpack.py +164 -11
- warp/tests/test_examples.py +166 -74
- warp/tests/test_fem.py +8 -1
- warp/tests/test_generics.py +15 -5
- warp/tests/test_grad.py +1 -1
- warp/tests/test_grad_customs.py +172 -12
- warp/tests/test_jax.py +254 -0
- warp/tests/test_large.py +29 -6
- warp/tests/test_launch.py +25 -0
- warp/tests/test_linear_solvers.py +20 -3
- warp/tests/test_matmul.py +61 -16
- warp/tests/test_matmul_lite.py +13 -13
- warp/tests/test_mempool.py +186 -0
- warp/tests/test_multigpu.py +3 -0
- warp/tests/test_options.py +16 -2
- warp/tests/test_peer.py +137 -0
- warp/tests/test_print.py +3 -1
- warp/tests/test_quat.py +23 -0
- warp/tests/test_sim_kinematics.py +97 -0
- warp/tests/test_snippet.py +126 -3
- warp/tests/test_streams.py +108 -79
- warp/tests/test_torch.py +16 -8
- warp/tests/test_utils.py +32 -27
- warp/tests/test_verify_fp.py +65 -0
- warp/tests/test_volume.py +1 -1
- warp/tests/unittest_serial.py +2 -0
- warp/tests/unittest_suites.py +12 -0
- warp/tests/unittest_utils.py +14 -7
- warp/thirdparty/unittest_parallel.py +15 -3
- warp/torch.py +10 -8
- warp/types.py +363 -246
- warp/utils.py +143 -19
- warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
- warp_lang-1.0.0.dist-info/METADATA +394 -0
- {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
- warp/sim/optimizer.py +0 -138
- warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
- warp_lang-0.11.0.dist-info/METADATA +0 -238
- /warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
- {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
- {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
warp/native/cuda_util.h
CHANGED
@@ -17,8 +17,10 @@
 
 #include <stdio.h>
 
-#define check_cuda(code) (check_cuda_result(code, __FILE__, __LINE__))
-#define check_cu(code) (check_cu_result(code, __FILE__, __LINE__))
+#include <vector>
+
+#define check_cuda(code) (check_cuda_result(code, __FUNCTION__, __FILE__, __LINE__))
+#define check_cu(code) (check_cu_result(code, __FUNCTION__, __FILE__, __LINE__))
 
 
 #if defined(__CUDACC__)
@@ -55,6 +57,7 @@ CUresult cuDeviceGetUuid_f(CUuuid* uuid, CUdevice dev);
 CUresult cuDevicePrimaryCtxRetain_f(CUcontext* ctx, CUdevice dev);
 CUresult cuDevicePrimaryCtxRelease_f(CUdevice dev);
 CUresult cuDeviceCanAccessPeer_f(int* can_access, CUdevice dev, CUdevice peer_dev);
+CUresult cuMemGetInfo_f(size_t* free, size_t* total);
 CUresult cuCtxGetCurrent_f(CUcontext* ctx);
 CUresult cuCtxSetCurrent_f(CUcontext ctx);
 CUresult cuCtxPushCurrent_f(CUcontext ctx);
@@ -64,18 +67,23 @@ CUresult cuCtxGetDevice_f(CUdevice* dev);
 CUresult cuCtxCreate_f(CUcontext* ctx, unsigned int flags, CUdevice dev);
 CUresult cuCtxDestroy_f(CUcontext ctx);
 CUresult cuCtxEnablePeerAccess_f(CUcontext peer_ctx, unsigned int flags);
+CUresult cuCtxDisablePeerAccess_f(CUcontext peer_ctx);
 CUresult cuStreamCreate_f(CUstream* stream, unsigned int flags);
 CUresult cuStreamDestroy_f(CUstream stream);
 CUresult cuStreamSynchronize_f(CUstream stream);
 CUresult cuStreamWaitEvent_f(CUstream stream, CUevent event, unsigned int flags);
+CUresult cuStreamGetCaptureInfo_f(CUstream stream, CUstreamCaptureStatus *captureStatus_out, cuuint64_t *id_out, CUgraph *graph_out, const CUgraphNode **dependencies_out, size_t *numDependencies_out);
+CUresult cuStreamUpdateCaptureDependencies_f(CUstream stream, CUgraphNode *dependencies, size_t numDependencies, unsigned int flags);
 CUresult cuEventCreate_f(CUevent* event, unsigned int flags);
 CUresult cuEventDestroy_f(CUevent event);
 CUresult cuEventRecord_f(CUevent event, CUstream stream);
+CUresult cuEventRecordWithFlags_f(CUevent event, CUstream stream, unsigned int flags);
 CUresult cuModuleUnload_f(CUmodule hmod);
 CUresult cuModuleLoadDataEx_f(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
 CUresult cuModuleGetFunction_f(CUfunction *hfunc, CUmodule hmod, const char *name);
 CUresult cuLaunchKernel_f(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
 CUresult cuMemcpyPeerAsync_f(CUdeviceptr dst_ptr, CUcontext dst_ctx, CUdeviceptr src_ptr, CUcontext src_ctx, size_t n, CUstream stream);
+CUresult cuPointerGetAttribute_f(void* data, CUpointer_attribute attribute, CUdeviceptr ptr);
 CUresult cuGraphicsMapResources_f(unsigned int count, CUgraphicsResource* resources, CUstream stream);
 CUresult cuGraphicsUnmapResources_f(unsigned int count, CUgraphicsResource* resources, CUstream hStream);
 CUresult cuGraphicsResourceGetMappedPointer_f(CUdeviceptr* pDevPtr, size_t* pSize, CUgraphicsResource resource);
@@ -86,13 +94,34 @@ CUresult cuGraphicsUnregisterResource_f(CUgraphicsResource resource);
 bool init_cuda_driver();
 bool is_cuda_driver_initialized();
 
-bool check_cuda_result(cudaError_t code, const char* file, int line);
-inline bool check_cuda_result(uint64_t code, const char* file, int line)
+bool check_cuda_result(cudaError_t code, const char* func, const char* file, int line);
+
+inline bool check_cuda_result(uint64_t code, const char* func, const char* file, int line)
+{
+    return check_cuda_result(static_cast<cudaError_t>(code), func, file, line);
+}
+
+bool check_cu_result(CUresult result, const char* func, const char* file, int line);
+
+inline uint64_t get_capture_id(CUstream stream)
 {
-    return check_cuda_result(static_cast<cudaError_t>(code), file, line);
+    CUstreamCaptureStatus status;
+    uint64_t id = 0;
+    check_cu(cuStreamGetCaptureInfo_f(stream, &status, &id, NULL, NULL, NULL));
+    return id;
 }
 
-bool check_cu_result(CUresult result, const char* file, int line);
+inline CUgraph get_capture_graph(CUstream stream)
+{
+    CUstreamCaptureStatus status;
+    CUgraph graph = NULL;
+    check_cu(cuStreamGetCaptureInfo_f(stream, &status, NULL, &graph, NULL, NULL));
+    return graph;
+}
+
+bool get_capture_dependencies(CUstream stream, std::vector<CUgraphNode>& dependencies_ret);
+
+bool get_graph_leaf_nodes(cudaGraph_t graph, std::vector<cudaGraphNode_t>& leaf_nodes_ret);
 
 
 //
warp/native/cutlass_gemm.cpp
CHANGED
warp/native/cutlass_gemm.cu
CHANGED
@@ -8,6 +8,7 @@
 
 #include "builtin.h"
 #include "temp_buffer.h"
+#include "cuda_util.h"
 
 #include "cutlass/cutlass.h"
 #include "cutlass/gemm/device/gemm_universal.h"
@@ -226,7 +227,7 @@ extern "C" {
 
 WP_API
 bool cutlass_gemm(
-    int compute_capability,
+    void* context, int compute_capability,
     int m, int n, int k,
     const char* datatype_str,
     const void* a, const void* b, const void* c, void* d,
@@ -237,6 +238,8 @@ bool cutlass_gemm(
 
     std::string datatype(datatype_str);
 
+    ContextGuard guard(context);
+
     // Specializations for using Tensor Cores and A/B RowMajor/ColumnMajor designations
    if (compute_capability == 80) {
        if (datatype == F64_STR) {
warp/native/error.cpp
ADDED
@@ -0,0 +1,66 @@
+/** Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
+ * NVIDIA CORPORATION and its licensors retain all intellectual property
+ * and proprietary rights in and to this software, related documentation
+ * and any modifications thereto. Any use, reproduction, disclosure or
+ * distribution of this software and related documentation without an express
+ * license agreement from NVIDIA CORPORATION is strictly prohibited.
+ */
+
+#include <stdarg.h>
+#include <stdio.h>
+#include <string.h>
+
+namespace wp
+{
+static char g_error_buffer[4096] = "";
+static bool g_error_output_enabled = true;
+static FILE* g_error_stream = stderr;
+
+const char* get_error_string()
+{
+    return g_error_buffer;
+}
+
+void set_error_string(const char* fmt, ...)
+{
+    va_list args;
+    va_start(args, fmt);
+    vsnprintf(g_error_buffer, sizeof(g_error_buffer), fmt, args);
+    if (g_error_output_enabled)
+    {
+        vfprintf(g_error_stream, fmt, args);
+        fputc('\n', g_error_stream);
+        fflush(g_error_stream);
+    }
+    va_end(args);
+}
+
+void append_error_string(const char* fmt, ...)
+{
+    size_t offset = strlen(g_error_buffer);
+    if (offset + 2 > sizeof(g_error_buffer))
+        return;
+    g_error_buffer[offset++] = '\n';
+    va_list args;
+    va_start(args, fmt);
+    vsnprintf(g_error_buffer + offset, sizeof(g_error_buffer) - offset, fmt, args);
+    if (g_error_output_enabled)
+    {
+        vfprintf(g_error_stream, fmt, args);
+        fputc('\n', g_error_stream);
+        fflush(g_error_stream);
+    }
+    va_end(args);
+}
+
+void set_error_output_enabled(bool enable)
+{
+    g_error_output_enabled = enable;
+}
+
+bool is_error_output_enabled()
+{
+    return g_error_output_enabled;
+}
+
+} // end of namespace wp
warp/native/error.h
ADDED
@@ -0,0 +1,27 @@
+/** Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
+ * NVIDIA CORPORATION and its licensors retain all intellectual property
+ * and proprietary rights in and to this software, related documentation
+ * and any modifications thereto. Any use, reproduction, disclosure or
+ * distribution of this software and related documentation without an express
+ * license agreement from NVIDIA CORPORATION is strictly prohibited.
+ */
+
+#pragma once
+
+namespace wp
+{
+// functions related to error reporting
+
+// get error string from Python
+const char* get_error_string();
+
+// set error message for Python
+// these functions also print the error message if error output is enabled
+void set_error_string(const char* fmt, ...);
+void append_error_string(const char* fmt, ...);
+
+// allow disabling printing errors, which is handy during tests that expect failure
+void set_error_output_enabled(bool enable);
+bool is_error_output_enabled();
+
+}
warp/native/mesh.cu
CHANGED
@@ -203,8 +203,8 @@ uint64_t mesh_create_device(void* context, wp::array_t<wp::vec3> points, wp::arr
     // bvh_destroy_host(bvh_host);
 
     // create lower upper arrays expected by GPU BVH builder
-    mesh.lowers = (wp::vec3*)
-    mesh.uppers = (wp::vec3*)
+    mesh.lowers = (wp::vec3*)alloc_device(WP_CURRENT_CONTEXT, sizeof(wp::vec3)*num_tris);
+    mesh.uppers = (wp::vec3*)alloc_device(WP_CURRENT_CONTEXT, sizeof(wp::vec3)*num_tris);
 
     wp_launch_device(WP_CURRENT_CONTEXT, wp::compute_triangle_bounds, num_tris, (num_tris, points.data, indices.data, mesh.lowers, mesh.uppers));
 
warp/native/reduce.cu
CHANGED
@@ -110,7 +110,7 @@ template <typename T> void array_sum_device(const T *ptr_a, T *ptr_out, int coun
 
    size_t buff_size = 0;
    check_cuda(cub::DeviceReduce::Sum(nullptr, buff_size, ptr_strided, ptr_out, count, stream));
-    void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
+    void* temp_buffer = alloc_device(WP_CURRENT_CONTEXT, buff_size);
 
    for (int k = 0; k < type_length; ++k)
    {
@@ -118,7 +118,7 @@ template <typename T> void array_sum_device(const T *ptr_a, T *ptr_out, int coun
        check_cuda(cub::DeviceReduce::Sum(temp_buffer, buff_size, ptr_strided, ptr_out + k, count, stream));
    }
 
-    free_temp_device(WP_CURRENT_CONTEXT, temp_buffer);
+    free_device(WP_CURRENT_CONTEXT, temp_buffer);
 }
 
 template <typename T>
@@ -271,11 +271,11 @@ void array_inner_device(const ElemT *ptr_a, const ElemT *ptr_b, ScalarT *ptr_out
 
    size_t buff_size = 0;
    check_cuda(cub::DeviceReduce::Sum(nullptr, buff_size, inner_iterator, ptr_out, count, stream));
-    void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
+    void* temp_buffer = alloc_device(WP_CURRENT_CONTEXT, buff_size);
 
    check_cuda(cub::DeviceReduce::Sum(temp_buffer, buff_size, inner_iterator, ptr_out, count, stream));
 
-    free_temp_device(WP_CURRENT_CONTEXT, temp_buffer);
+    free_device(WP_CURRENT_CONTEXT, temp_buffer);
 }
 
 template <typename T>
warp/native/runlength_encode.cu
CHANGED
@@ -21,13 +21,13 @@ void runlength_encode_device(int n,
        nullptr, buff_size, values, run_values, run_lengths, run_count,
        n, stream));
 
-    void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
+    void* temp_buffer = alloc_device(WP_CURRENT_CONTEXT, buff_size);
 
    check_cuda(cub::DeviceRunLengthEncode::Encode(
        temp_buffer, buff_size, values, run_values, run_lengths, run_count,
        n, stream));
 
-    free_temp_device(WP_CURRENT_CONTEXT, temp_buffer);
+    free_device(WP_CURRENT_CONTEXT, temp_buffer);
 }
 
 void runlength_encode_int_device(
warp/native/scan.cu
CHANGED
@@ -20,7 +20,7 @@ void scan_device(const T* values_in, T* values_out, int n, bool inclusive)
        check_cuda(cub::DeviceScan::ExclusiveSum(NULL, scan_temp_size, values_in, values_out, n));
    }
 
-    void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, scan_temp_size);
+    void* temp_buffer = alloc_device(WP_CURRENT_CONTEXT, scan_temp_size);
 
    // scan
    if (inclusive) {
@@ -29,7 +29,7 @@ void scan_device(const T* values_in, T* values_out, int n, bool inclusive)
        check_cuda(cub::DeviceScan::ExclusiveSum(temp_buffer, scan_temp_size, values_in, values_out, n, stream));
    }
 
-    free_temp_device(WP_CURRENT_CONTEXT, temp_buffer);
+    free_device(WP_CURRENT_CONTEXT, temp_buffer);
 }
 
 template void scan_device(const int*, int*, int, bool);
warp/native/sparse.cu
CHANGED
@@ -456,7 +456,6 @@ void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count,
    size_t buff_size = 0;
    check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values,
                                               d_keys, nnz, 0, 64, stream));
-    void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
    ScopedTemporary<> temp(context, buff_size);
    check_cuda(cub::DeviceRadixSort::SortPairs(
        temp.buffer(), buff_size, d_values, d_keys, nnz, 0, 64, stream));
warp/native/temp_buffer.h
CHANGED
@@ -10,13 +10,13 @@ template <typename T = char> struct ScopedTemporary
 {
 
     ScopedTemporary(void *context, size_t size)
-        : m_context(context), m_buffer(static_cast<T*>(alloc_temp_device(m_context, size * sizeof(T))))
+        : m_context(context), m_buffer(static_cast<T*>(alloc_device(m_context, size * sizeof(T))))
     {
     }
 
     ~ScopedTemporary()
     {
-        free_temp_device(m_context, m_buffer);
+        free_device(m_context, m_buffer);
     }
 
     T *buffer() const