warp-lang 1.4.2-py3-none-manylinux2014_aarch64.whl → 1.5.0-py3-none-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of warp-lang might be problematic.
- warp/__init__.py +4 -0
- warp/autograd.py +43 -8
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +21 -2
- warp/build_dll.py +23 -6
- warp/builtins.py +1783 -2
- warp/codegen.py +177 -45
- warp/config.py +2 -2
- warp/context.py +321 -73
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
- warp/examples/benchmarks/benchmark_gemm.py +121 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
- warp/examples/benchmarks/benchmark_tile.py +179 -0
- warp/examples/fem/example_adaptive_grid.py +37 -10
- warp/examples/fem/example_apic_fluid.py +3 -2
- warp/examples/fem/example_convection_diffusion_dg.py +4 -5
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion_3d.py +47 -4
- warp/examples/fem/example_distortion_energy.py +220 -0
- warp/examples/fem/example_magnetostatics.py +127 -85
- warp/examples/fem/example_nonconforming_contact.py +5 -5
- warp/examples/fem/example_stokes.py +3 -1
- warp/examples/fem/example_streamlines.py +12 -19
- warp/examples/fem/utils.py +38 -15
- warp/examples/sim/example_cloth.py +2 -25
- warp/examples/sim/example_quadruped.py +2 -1
- warp/examples/tile/example_tile_convolution.py +58 -0
- warp/examples/tile/example_tile_fft.py +47 -0
- warp/examples/tile/example_tile_filtering.py +105 -0
- warp/examples/tile/example_tile_matmul.py +79 -0
- warp/examples/tile/example_tile_mlp.py +375 -0
- warp/fem/__init__.py +8 -0
- warp/fem/cache.py +16 -12
- warp/fem/dirichlet.py +1 -1
- warp/fem/domain.py +44 -1
- warp/fem/field/__init__.py +1 -2
- warp/fem/field/field.py +31 -19
- warp/fem/field/nodal_field.py +101 -49
- warp/fem/field/virtual.py +794 -0
- warp/fem/geometry/__init__.py +2 -2
- warp/fem/geometry/deformed_geometry.py +3 -105
- warp/fem/geometry/element.py +13 -0
- warp/fem/geometry/geometry.py +165 -5
- warp/fem/geometry/grid_2d.py +3 -6
- warp/fem/geometry/grid_3d.py +31 -28
- warp/fem/geometry/hexmesh.py +3 -46
- warp/fem/geometry/nanogrid.py +3 -2
- warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
- warp/fem/geometry/tetmesh.py +2 -43
- warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
- warp/fem/integrate.py +683 -261
- warp/fem/linalg.py +404 -0
- warp/fem/operator.py +101 -18
- warp/fem/polynomial.py +5 -5
- warp/fem/quadrature/quadrature.py +45 -21
- warp/fem/space/__init__.py +45 -11
- warp/fem/space/basis_function_space.py +451 -0
- warp/fem/space/basis_space.py +58 -11
- warp/fem/space/function_space.py +146 -5
- warp/fem/space/grid_2d_function_space.py +80 -66
- warp/fem/space/grid_3d_function_space.py +113 -68
- warp/fem/space/hexmesh_function_space.py +96 -108
- warp/fem/space/nanogrid_function_space.py +62 -110
- warp/fem/space/quadmesh_function_space.py +208 -0
- warp/fem/space/shape/__init__.py +45 -7
- warp/fem/space/shape/cube_shape_function.py +328 -54
- warp/fem/space/shape/shape_function.py +10 -1
- warp/fem/space/shape/square_shape_function.py +328 -60
- warp/fem/space/shape/tet_shape_function.py +269 -19
- warp/fem/space/shape/triangle_shape_function.py +238 -19
- warp/fem/space/tetmesh_function_space.py +69 -37
- warp/fem/space/topology.py +38 -0
- warp/fem/space/trimesh_function_space.py +179 -0
- warp/fem/utils.py +6 -331
- warp/jax_experimental.py +3 -1
- warp/native/array.h +15 -0
- warp/native/builtin.h +66 -26
- warp/native/bvh.h +4 -0
- warp/native/coloring.cpp +600 -0
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -1
- warp/native/fabric.h +8 -0
- warp/native/hashgrid.h +4 -0
- warp/native/marching.cu +8 -0
- warp/native/mat.h +14 -3
- warp/native/mathdx.cpp +59 -0
- warp/native/mesh.h +4 -0
- warp/native/range.h +13 -1
- warp/native/reduce.cpp +9 -1
- warp/native/reduce.cu +7 -0
- warp/native/runlength_encode.cpp +9 -1
- warp/native/runlength_encode.cu +7 -1
- warp/native/scan.cpp +8 -0
- warp/native/scan.cu +8 -0
- warp/native/scan.h +8 -1
- warp/native/sparse.cpp +8 -0
- warp/native/sparse.cu +8 -0
- warp/native/temp_buffer.h +7 -0
- warp/native/tile.h +1857 -0
- warp/native/tile_gemm.h +341 -0
- warp/native/tile_reduce.h +210 -0
- warp/native/volume_builder.cu +8 -0
- warp/native/volume_builder.h +8 -0
- warp/native/warp.cpp +10 -2
- warp/native/warp.cu +369 -15
- warp/native/warp.h +12 -2
- warp/optim/adam.py +39 -4
- warp/paddle.py +29 -12
- warp/render/render_opengl.py +137 -65
- warp/sim/graph_coloring.py +292 -0
- warp/sim/integrator_euler.py +4 -2
- warp/sim/integrator_featherstone.py +115 -44
- warp/sim/integrator_vbd.py +6 -0
- warp/sim/model.py +88 -15
- warp/stubs.py +569 -4
- warp/tape.py +12 -7
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/aux_test_instancing_gc.py +18 -0
- warp/tests/test_array.py +39 -0
- warp/tests/test_codegen.py +81 -1
- warp/tests/test_codegen_instancing.py +30 -0
- warp/tests/test_collision.py +110 -0
- warp/tests/test_coloring.py +241 -0
- warp/tests/test_context.py +34 -0
- warp/tests/test_examples.py +18 -4
- warp/tests/test_fem.py +453 -113
- warp/tests/test_func.py +13 -0
- warp/tests/test_generics.py +52 -0
- warp/tests/test_iter.py +68 -0
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_mesh_query_point.py +1 -1
- warp/tests/test_module_hashing.py +23 -0
- warp/tests/test_paddle.py +27 -87
- warp/tests/test_print.py +56 -1
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_tile.py +700 -0
- warp/tests/test_tile_mathdx.py +144 -0
- warp/tests/test_tile_mlp.py +383 -0
- warp/tests/test_tile_reduce.py +374 -0
- warp/tests/test_tile_shared_memory.py +190 -0
- warp/tests/test_vbd.py +12 -20
- warp/tests/test_volume.py +43 -0
- warp/tests/unittest_suites.py +19 -2
- warp/tests/unittest_utils.py +4 -0
- warp/types.py +338 -72
- warp/utils.py +22 -1
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/RECORD +153 -126
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
- warp/fem/field/test.py +0 -180
- warp/fem/field/trial.py +0 -183
- warp/fem/space/collocated_function_space.py +0 -102
- warp/fem/space/quadmesh_2d_function_space.py +0 -261
- warp/fem/space/trimesh_2d_function_space.py +0 -153
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
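Most of the new files above belong to the tile programming support introduced in 1.5.0 (warp/native/tile.h, the warp/examples/tile/ samples, the test_tile*.py suites), plus MathDx-backed FFT/GEMM code generation and graph-coloring utilities. For orientation, here is a minimal sketch of a tile-style GEMM kernel in the spirit of the added example_tile_matmul.py; the builtins and signatures used (wp.tile_zeros, wp.tile_load, wp.tile_matmul, wp.tile_store, wp.launch_tiled) are assumptions inferred from the file names above, not copied from the shipped examples:

    import numpy as np
    import warp as wp

    wp.init()

    # illustrative tile sizes; one thread block cooperates on each output tile
    TILE_M = wp.constant(8)
    TILE_N = wp.constant(8)
    TILE_K = wp.constant(8)
    TILE_THREADS = 64

    @wp.kernel
    def tile_gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.array2d(dtype=float)):
        # each launch index (i, j) produces one TILE_M x TILE_N tile of C
        i, j = wp.tid()
        acc = wp.tile_zeros(m=TILE_M, n=TILE_N, dtype=float)
        count = int(A.shape[1] / TILE_K)
        for k in range(count):
            a = wp.tile_load(A, i, k, m=TILE_M, n=TILE_K)
            b = wp.tile_load(B, k, j, m=TILE_K, n=TILE_N)
            wp.tile_matmul(a, b, acc)  # acc += a @ b
        wp.tile_store(C, i, j, acc)

    M = N = K = 64
    A = wp.array(np.random.rand(M, K), dtype=float)
    B = wp.array(np.random.rand(K, N), dtype=float)
    C = wp.zeros(shape=(M, N), dtype=float)

    # tile kernels are launched over the grid of output tiles, with block_dim
    # threads cooperating on each tile
    wp.launch_tiled(tile_gemm, dim=[int(M / TILE_M), int(N / TILE_N)], inputs=[A, B, C], block_dim=TILE_THREADS)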
warp/native/warp.cu
CHANGED

@@ -11,9 +11,16 @@
 #include "cuda_util.h"
 #include "error.h"
 
+#include <cstdlib>
+#include <fstream>
 #include <nvrtc.h>
 #include <nvPTXCompiler.h>
+#if WP_ENABLE_MATHDX
+#include <nvJitLink.h>
+#include <libmathdx.h>
+#endif
 
+#include <array>
 #include <algorithm>
 #include <iterator>
 #include <list>
@@ -23,8 +30,39 @@
 #include <unordered_set>
 #include <vector>
 
+#define check_any(result) (check_generic(result, __FILE__, __LINE__))
 #define check_nvrtc(code) (check_nvrtc_result(code, __FILE__, __LINE__))
 #define check_nvptx(code) (check_nvptx_result(code, __FILE__, __LINE__))
+#define check_nvjitlink(handle, code) (check_nvjitlink_result(handle, code, __FILE__, __LINE__))
+#define check_cufftdx(code) (check_cufftdx_result(code, __FILE__, __LINE__))
+#define check_cublasdx(code) (check_cublasdx_result(code, __FILE__, __LINE__))
+#define CHECK_ANY(code) \
+    { \
+        do { \
+            bool out = (check_any(code)); \
+            if(!out) { \
+                return out; \
+            } \
+        } while(0); \
+    }
+#define CHECK_CUFFTDX(code) \
+    { \
+        do { \
+            bool out = (check_cufftdx(code)); \
+            if(!out) { \
+                return out; \
+            } \
+        } while(0); \
+    }
+#define CHECK_CUBLASDX(code) \
+    { \
+        do { \
+            bool out = (check_cufftdx(code)); \
+            if(!out) { \
+                return out; \
+            } \
+        } while(0); \
+    }
 
 bool check_nvrtc_result(nvrtcResult result, const char* file, int line)
 {
@@ -74,6 +112,15 @@ bool check_nvptx_result(nvPTXCompileResult result, const char* file, int line)
     return false;
 }
 
+bool check_generic(int result, const char* file, int line)
+{
+    if (!result) {
+        fprintf(stderr, "Error %d on %s:%d\n", (int)result, file, line);
+        return false;
+    } else {
+        return true;
+    }
+}
 
 struct DeviceInfo
 {
@@ -89,6 +136,7 @@ struct DeviceInfo
     int arch = 0;
     int is_uva = 0;
     int is_mempool_supported = 0;
+    int max_smem_bytes = 0;
     CUcontext primary_context = NULL;
 };
 
@@ -202,6 +250,7 @@ int cuda_init()
         check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_device_id, CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, device));
         check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_uva, CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, device));
         check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_mempool_supported, CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, device));
+        check_cu(cuDeviceGetAttribute_f(&g_devices[i].max_smem_bytes, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device));
         int major = 0;
         int minor = 0;
         check_cu(cuDeviceGetAttribute_f(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device));
@@ -2520,11 +2569,57 @@ bool cuda_graph_destroy(void* context, void* graph_exec)
     return check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
 }
 
-size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_dir, bool debug, bool verbose, bool verify_fp, bool fast_math, const char* output_path)
+bool write_file(const char* data, size_t size, std::string filename, const char* mode)
+{
+    const bool print_debug = (std::getenv("WARP_DEBUG") != nullptr);
+    if (print_debug)
+    {
+        printf("Writing %zu B to %s (%s)\n", size, filename.c_str(), mode);
+    }
+    FILE* file = fopen(filename.c_str(), mode);
+    if (file)
+    {
+        if (fwrite(data, 1, size, file) != size) {
+            fprintf(stderr, "Warp error: Failed to write to output file '%s'\n", filename.c_str());
+            return false;
+        }
+        fclose(file);
+        return true;
+    }
+    else
+    {
+        fprintf(stderr, "Warp error: Failed to open file '%s'\n", filename.c_str());
+        return false;
+    }
+}
+
+#if WP_ENABLE_MATHDX
+bool check_nvjitlink_result(nvJitLinkHandle handle, nvJitLinkResult result, const char* file, int line)
+{
+    if (result != NVJITLINK_SUCCESS) {
+        fprintf(stderr, "nvJitLink error: %d on %s:%d\n", (int)result, file, line);
+        size_t lsize;
+        result = nvJitLinkGetErrorLogSize(handle, &lsize);
+        if (result == NVJITLINK_SUCCESS && lsize > 0) {
+            std::vector<char> log(lsize);
+            result = nvJitLinkGetErrorLog(handle, log.data());
+            if (result == NVJITLINK_SUCCESS) {
+                fprintf(stderr, "%s\n", log.data());
+            }
+        }
+        return false;
+    } else {
+        return true;
+    }
+}
+#endif
+
+size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes)
 {
     // use file extension to determine whether to output PTX or CUBIN
     const char* output_ext = strrchr(output_path, '.');
     bool use_ptx = output_ext && strcmp(output_ext + 1, "ptx") == 0;
+    const bool print_debug = (std::getenv("WARP_DEBUG") != nullptr);
 
     // check include dir path len (path + option)
     const int max_path = 4096 + 16;
@@ -2534,17 +2629,37 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
         return size_t(-1);
     }
 
+    if (print_debug)
+    {
+        // Not available in all nvJitLink versions
+        // unsigned major = 0;
+        // unsigned minor = 0;
+        // nvJitLinkVersion(&major, &minor);
+        // printf("nvJitLink version %d.%d\n", major, minor);
+        int major = 0;
+        int minor = 0;
+        nvrtcVersion(&major, &minor);
+        printf("NVRTC version %d.%d\n", major, minor);
+    }
+
     char include_opt[max_path];
     strcpy(include_opt, "--include-path=");
     strcat(include_opt, include_dir);
 
     const int max_arch = 128;
     char arch_opt[max_arch];
+    char arch_opt_lto[max_arch];
 
     if (use_ptx)
+    {
         snprintf(arch_opt, max_arch, "--gpu-architecture=compute_%d", arch);
+        snprintf(arch_opt_lto, max_arch, "-arch=compute_%d", arch);
+    }
     else
+    {
         snprintf(arch_opt, max_arch, "--gpu-architecture=sm_%d", arch);
+        snprintf(arch_opt_lto, max_arch, "-arch=sm_%d", arch);
+    }
 
     std::vector<const char*> opts;
     opts.push_back(arch_opt);
@@ -2555,6 +2670,7 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
     {
         opts.push_back("--define-macro=_DEBUG");
         opts.push_back("--generate-line-info");
+
         // disabling since it causes issues with `Unresolved extern function 'cudaGetParameterBufferV2'
         //opts.push_back("--device-debug");
     }
@@ -2569,6 +2685,26 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
     if (fast_math)
         opts.push_back("--use_fast_math");
 
+    char include_cutlass[max_path];
+    sprintf(include_cutlass, "--include-path=%s/cutlass/include", include_dir);
+    opts.push_back(include_cutlass);
+
+    std::vector<std::string> cuda_include_opt;
+    for(int i = 0; i < num_cuda_include_dirs; i++)
+    {
+        cuda_include_opt.push_back(std::string("--include-path=") + cuda_include_dirs[i]);
+        opts.push_back(cuda_include_opt.back().c_str());
+    }
+
+    opts.push_back("--device-as-default-execution-space");
+    opts.push_back("--extra-device-vectorization");
+    opts.push_back("--restrict");
+
+    if (num_ltoirs > 0)
+    {
+        opts.push_back("-dlto");
+        opts.push_back("--relocatable-device-code=true");
+    }
 
     nvrtcProgram prog;
     nvrtcResult res;
@@ -2584,6 +2720,13 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
     if (!check_nvrtc(res))
         return size_t(res);
 
+    if (print_debug)
+    {
+        printf("NVRTC options:\n");
+        for(auto o: opts) {
+            printf("%s\n", o);
+        }
+    }
     res = nvrtcCompileProgram(prog, int(opts.size()), opts.data());
 
     if (!check_nvrtc(res) || verbose)
@@ -2613,7 +2756,17 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
     nvrtcResult (*get_output_size)(nvrtcProgram, size_t*);
     nvrtcResult (*get_output_data)(nvrtcProgram, char*);
     const char* output_mode;
-    if (use_ptx)
+    if(num_ltoirs > 0) {
+#if WP_ENABLE_MATHDX
+        get_output_size = nvrtcGetLTOIRSize;
+        get_output_data = nvrtcGetLTOIR;
+        output_mode = "wb";
+#else
+        fprintf(stderr, "Warp error: num_ltoirs > 0 but Warp was not built with MathDx support\n");
+        return size_t(-1);
+#endif
+    }
+    else if (use_ptx)
     {
         get_output_size = nvrtcGetPTXSize;
         get_output_data = nvrtcGetPTX;
@@ -2635,19 +2788,78 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
     res = get_output_data(prog, output.data());
     if (check_nvrtc(res))
     {
-
-
+
+        // LTOIR case - need an extra step
+        if (num_ltoirs > 0)
         {
-
+#if WP_ENABLE_MATHDX
+            nvJitLinkHandle handle;
+            std::vector<const char *> lopts = {"-dlto", arch_opt_lto};
+            if (use_ptx) {
+                lopts.push_back("-ptx");
+            }
+            if (print_debug)
+            {
+                printf("nvJitLink options:\n");
+                for(auto o: lopts) {
+                    printf("%s\n", o);
+                }
+            }
+            if(!check_nvjitlink(handle, nvJitLinkCreate(&handle, lopts.size(), lopts.data())))
+            {
+                res = nvrtcResult(-1);
+            }
+            // Links
+            if(std::getenv("WARP_DUMP_LTOIR"))
+            {
+                write_file(output.data(), output.size(), "nvrtc_output.ltoir", "wb");
+            }
+            if(!check_nvjitlink(handle, nvJitLinkAddData(handle, NVJITLINK_INPUT_LTOIR, output.data(), output.size(), "nvrtc_output"))) // NVRTC business
             {
-                fprintf(stderr, "Warp error: Failed to write output file '%s'\n", output_path);
                 res = nvrtcResult(-1);
             }
-
+            for(size_t ltoidx = 0; ltoidx < num_ltoirs; ltoidx++)
+            {
+                if(std::getenv("WARP_DUMP_LTOIR"))
+                {
+                    write_file(ltoirs[ltoidx], ltoir_sizes[ltoidx], std::string("lto_online_") + std::to_string(ltoidx) + ".ltoir", "wb");
+                }
+                if(!check_nvjitlink(handle, nvJitLinkAddData(handle, NVJITLINK_INPUT_LTOIR, ltoirs[ltoidx], ltoir_sizes[ltoidx], "lto_online"))) // External LTOIR
+                {
+                    res = nvrtcResult(-1);
+                }
+            }
+            if(!check_nvjitlink(handle, nvJitLinkComplete(handle)))
+            {
+                res = nvrtcResult(-1);
+            }
+            else
+            {
+                if(use_ptx)
+                {
+                    size_t ptx_size = 0;
+                    check_nvjitlink(handle, nvJitLinkGetLinkedPtxSize(handle, &ptx_size));
+                    std::vector<char> ptx(ptx_size);
+                    check_nvjitlink(handle, nvJitLinkGetLinkedPtx(handle, ptx.data()));
+                    output = ptx;
+                }
+                else
+                {
+                    size_t cubin_size = 0;
+                    check_nvjitlink(handle, nvJitLinkGetLinkedCubinSize(handle, &cubin_size));
+                    std::vector<char> cubin(cubin_size);
+                    check_nvjitlink(handle, nvJitLinkGetLinkedCubin(handle, cubin.data()));
+                    output = cubin;
+                }
+            }
+            check_nvjitlink(handle, nvJitLinkDestroy(&handle));
+#else
+            fprintf(stderr, "Warp error: num_ltoirs > 0 but Warp was not built with MathDx support\n");
+            return size_t(-1);
+#endif
        }
-
-        {
-            fprintf(stderr, "Warp error: Failed to open output file '%s'\n", output_path);
+
+        if(!write_file(output.data(), output.size(), output_path, output_mode)) {
            res = nvrtcResult(-1);
        }
     }
@@ -2658,6 +2870,119 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
     return res;
 }
 
+#if WP_ENABLE_MATHDX
+bool check_cufftdx_result(commonDxStatusType result, const char* file, int line)
+{
+    if (result != commonDxStatusType::COMMONDX_SUCCESS) {
+        fprintf(stderr, "libmathdx cuFFTDx error: %d on %s:%d\n", (int)result, file, line);
+        return false;
+    } else {
+        return true;
+    }
+}
+
+bool check_cublasdx_result(commonDxStatusType result, const char* file, int line)
+{
+    if (result != commonDxStatusType::COMMONDX_SUCCESS) {
+        fprintf(stderr, "libmathdx cuBLASDx error: %d on %s:%d\n", (int)result, file, line);
+        return false;
+    } else {
+        return true;
+    }
+}
+
+bool cuda_compile_fft(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int size, int elements_per_thread, int direction, int precision, int* shared_memory_size)
+{
+
+    CHECK_ANY(ltoir_output_path != nullptr);
+    CHECK_ANY(symbol_name != nullptr);
+    CHECK_ANY(shared_memory_size != nullptr);
+    // Includes currently unused
+    CHECK_ANY(include_dirs == nullptr);
+    CHECK_ANY(mathdx_include_dir == nullptr);
+    CHECK_ANY(num_include_dirs == 0);
+
+    bool res = true;
+    cufftdxHandle h;
+    CHECK_CUFFTDX(cufftDxCreate(&h));
+
+    // CUFFTDX_API_BLOCK_LMEM means each thread starts with a subset of the data
+    CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_API, cufftDxApi::CUFFTDX_API_BLOCK_LMEM));
+    CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_EXECUTION, commonDxExecution::COMMONDX_EXECUTION_BLOCK));
+    CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_SIZE, (long long)size));
+    CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_DIRECTION, (cufftDxDirection)direction));
+    CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_PRECISION, (commonDxPrecision)precision));
+    CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_SM, (long long)(arch * 10)));
+    CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_ELEMENTS_PER_THREAD, (long long)(elements_per_thread)));
+    CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_FFTS_PER_BLOCK, 1));
+
+    CHECK_CUFFTDX(cufftDxSetOptionStr(h, commonDxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
+
+    size_t lto_size = 0;
+    CHECK_CUFFTDX(cufftDxGetLTOIRSize(h, &lto_size));
+
+    std::vector<char> lto(lto_size);
+    CHECK_CUFFTDX(cufftDxGetLTOIR(h, lto.size(), lto.data()));
+
+    long long int smem = 0;
+    CHECK_CUFFTDX(cufftDxGetTraitInt64(h, cufftDxTraitType::CUFFTDX_TRAIT_SHARED_MEMORY_SIZE, &smem));
+    *shared_memory_size = (int)smem;
+
+    if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
+        res = false;
+    }
+
+    CHECK_CUFFTDX(cufftDxDestroy(h));
+
+    return res;
+}
+
+bool cuda_compile_dot(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int K, int precision_A, int precision_B, int precision_C, int type, int arrangement_A, int arrangement_B, int arrangement_C, int num_threads)
+{
+
+    CHECK_ANY(ltoir_output_path != nullptr);
+    CHECK_ANY(symbol_name != nullptr);
+    // Includes currently unused
+    CHECK_ANY(include_dirs == nullptr);
+    CHECK_ANY(mathdx_include_dir == nullptr);
+    CHECK_ANY(num_include_dirs == 0);
+
+    bool res = true;
+    cublasdxHandle h;
+    CHECK_CUBLASDX(cublasDxCreate(&h));
+
+    CHECK_CUBLASDX(cublasDxSetOperatorInt64(h, cublasDxOperatorType::CUBLASDX_OPERATOR_FUNCTION, cublasDxFunction::CUBLASDX_FUNCTION_MM));
+    CHECK_CUBLASDX(cublasDxSetOperatorInt64(h, cublasDxOperatorType::CUBLASDX_OPERATOR_EXECUTION, commonDxExecution::COMMONDX_EXECUTION_BLOCK));
+    CHECK_CUBLASDX(cublasDxSetOperatorInt64(h, cublasDxOperatorType::CUBLASDX_OPERATOR_API, cublasDxApi::CUBLASDX_API_BLOCK_SMEM));
+    std::array<long long int, 3> precisions = {precision_A, precision_B, precision_C};
+    CHECK_CUBLASDX(cublasDxSetOperatorInt64Array(h, cublasDxOperatorType::CUBLASDX_OPERATOR_PRECISION, 3, precisions.data()));
+    CHECK_CUBLASDX(cublasDxSetOperatorInt64(h, cublasDxOperatorType::CUBLASDX_OPERATOR_SM, (long long)(arch * 10)));
+    CHECK_CUBLASDX(cublasDxSetOperatorInt64(h, cublasDxOperatorType::CUBLASDX_OPERATOR_TYPE, (cublasDxType)type));
+    std::array<long long int, 3> block_dim = {num_threads, 1, 1};
+    CHECK_CUBLASDX(cublasDxSetOperatorInt64Array(h, cublasDxOperatorType::CUBLASDX_OPERATOR_BLOCK_DIM, block_dim.size(), block_dim.data()));
+    std::array<long long int, 3> size = {M, N, K};
+    CHECK_CUBLASDX(cublasDxSetOperatorInt64Array(h, cublasDxOperatorType::CUBLASDX_OPERATOR_SIZE, size.size(), size.data()));
+    std::array<long long int, 3> arrangement = {arrangement_A, arrangement_B, arrangement_C};
+    CHECK_CUBLASDX(cublasDxSetOperatorInt64Array(h, cublasDxOperatorType::CUBLASDX_OPERATOR_ARRANGEMENT, arrangement.size(), arrangement.data()));
+
+    CHECK_CUBLASDX(cublasDxSetOptionStr(h, commonDxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
+
+    size_t lto_size = 0;
+    CHECK_CUBLASDX(cublasDxGetLTOIRSize(h, &lto_size));
+
+    std::vector<char> lto(lto_size);
+    CHECK_CUBLASDX(cublasDxGetLTOIR(h, lto.size(), lto.data()));
+
+    if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
+        res = false;
+    }
+
+    CHECK_CUBLASDX(cublasDxDestroy(h));
+
+    return res;
+}
+#endif
+
 void* cuda_load_module(void* context, const char* path)
 {
     ContextGuard guard(context);
@@ -2784,6 +3109,29 @@ void cuda_unload_module(void* context, void* module)
     check_cu(cuModuleUnload_f((CUmodule)module));
 }
 
+
+int cuda_get_max_shared_memory(void* context)
+{
+    ContextInfo* info = get_context_info(context);
+    if (!info)
+        return -1;
+
+    int max_smem_bytes = info->device_info->max_smem_bytes;
+    return max_smem_bytes;
+}
+
+bool cuda_configure_kernel_shared_memory(void* kernel, int size)
+{
+    int requested_smem_bytes = size;
+
+    // configure shared memory
+    CUresult res = cuFuncSetAttribute_f((CUfunction)kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, requested_smem_bytes);
+    if (res != CUDA_SUCCESS)
+        return false;
+
+    return true;
+}
+
 void* cuda_get_kernel(void* context, void* module, const char* name)
 {
     ContextGuard guard(context);
@@ -2796,15 +3144,21 @@ void* cuda_get_kernel(void* context, void* module, const char* name)
     }
 
     g_kernel_names[kernel] = name;
-
     return kernel;
 }
 
-size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, void** args, void* stream)
+size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, int block_dim, int shared_memory_bytes, void** args, void* stream)
 {
     ContextGuard guard(context);
 
-
+    if (block_dim <= 0)
+    {
+#if defined(_DEBUG)
+        fprintf(stderr, "Warp warning: Launch got block_dim %d. Setting to 256.\n", dim, block_dim);
+#endif
+        block_dim = 256;
+    }
+
     // CUDA specs up to compute capability 9.0 says the max x-dim grid is 2**31-1, so
     // grid_dim is fine as an int for the near future
     int grid_dim = (dim + block_dim - 1)/block_dim;
@@ -2835,7 +3189,8 @@ size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_block
         (CUfunction)kernel,
         grid_dim, 1, 1,
         block_dim, 1, 1,
-
+        shared_memory_bytes,
+        static_cast<CUstream>(stream),
         args,
         0);
 
@@ -2940,7 +3295,6 @@ void cuda_timing_end(timing_result_t* results, int size)
     g_cuda_timing_state = parent_state;
 }
 
-
 // impl. files
 #include "bvh.cu"
 #include "mesh.cu"
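Note that cuda_launch_kernel() now takes the block size and dynamic shared memory explicitly instead of assuming a fixed block size. A small Python mirror of the grid-sizing rule in the hunk above, for reference only (this is not a Warp API):

    def launch_dims(dim, block_dim):
        # non-positive block_dim falls back to 256, then the grid is rounded up
        # so that grid_dim * block_dim >= dim
        if block_dim <= 0:
            block_dim = 256
        grid_dim = (dim + block_dim - 1) // block_dim
        return grid_dim, block_dim

    assert launch_dims(1_000_000, 256) == (3907, 256)
    assert launch_dims(10, 0) == (1, 256)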
warp/native/warp.h
CHANGED

@@ -34,6 +34,8 @@ extern "C"
     WP_API int is_cuda_compatibility_enabled();
     // whether Warp was compiled with CUTLASS support
     WP_API int is_cutlass_enabled();
+    // whether Warp was compiled with MathDx support
+    WP_API int is_mathdx_enabled();
     // whether Warp was compiled with debug support
     WP_API int is_debug_enabled();
 
@@ -315,12 +317,16 @@ extern "C"
     WP_API bool cuda_graph_launch(void* graph, void* stream);
     WP_API bool cuda_graph_destroy(void* context, void* graph);
 
-    WP_API size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_dir, bool debug, bool verbose, bool verify_fp, bool fast_math, const char* output_path);
+    WP_API size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes);
+    WP_API bool cuda_compile_fft(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int size, int elements_per_thread, int direction, int precision, int* shared_memory_size);
+    WP_API bool cuda_compile_dot(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int K, int precision_A, int precision_B, int precision_C, int type, int arrangement_A, int arrangement_B, int arrangement_C, int num_threads);
 
     WP_API void* cuda_load_module(void* context, const char* ptx);
     WP_API void cuda_unload_module(void* context, void* module);
     WP_API void* cuda_get_kernel(void* context, void* module, const char* name);
-    WP_API size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, void** args, void* stream);
+    WP_API size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, int block_dim, int shared_memory_bytes, void** args, void* stream);
+    WP_API int cuda_get_max_shared_memory(void* context);
+    WP_API bool cuda_configure_kernel_shared_memory(void* kernel, int size);
 
     WP_API void cuda_set_context_restore_policy(bool always_restore);
     WP_API int cuda_get_context_restore_policy();
@@ -336,4 +342,8 @@ extern "C"
     WP_API int cuda_timing_get_result_count();
     WP_API void cuda_timing_end(timing_result_t* results, int size);
 
+    // graph coloring
+    WP_API int graph_coloring(int num_nodes, wp::array_t<int> edges, int algorithm, wp::array_t<int> node_colors);
+    WP_API float balance_coloring(int num_nodes, wp::array_t<int> edges, int num_colors, float target_max_min_ratio, wp::array_t<int> node_colors);
+
 } // extern "C"
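The new graph_coloring()/balance_coloring() exports take an edge list and fill a per-node color array; they appear to back the coloring additions listed above (warp/native/coloring.cpp, warp/sim/graph_coloring.py). As a plain illustration of the invariant a valid coloring satisfies (nodes joined by an edge never share a color), here is a greedy sketch; it is not Warp's algorithm and the edge-list layout is assumed:

    from collections import defaultdict

    def greedy_coloring(num_nodes, edges):
        # edges: iterable of (a, b) node-index pairs; returns one color per node
        adjacency = defaultdict(set)
        for a, b in edges:
            adjacency[a].add(b)
            adjacency[b].add(a)

        colors = [-1] * num_nodes
        for node in range(num_nodes):
            used = {colors[n] for n in adjacency[node] if colors[n] >= 0}
            c = 0
            while c in used:
                c += 1
            colors[node] = c
        return colors

    edges = [(0, 1), (1, 2), (2, 3), (3, 0)]
    colors = greedy_coloring(4, edges)
    assert all(colors[a] != colors[b] for a, b in edges)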
warp/optim/adam.py
CHANGED

@@ -50,6 +50,26 @@ def adam_step_kernel_float(
     params[i] = params[i] - lr * mhat / (wp.sqrt(vhat) + eps)
 
 
+@wp.kernel
+def adam_step_kernel_half(
+    g: wp.array(dtype=wp.float16),
+    m: wp.array(dtype=float),
+    v: wp.array(dtype=float),
+    lr: float,
+    beta1: float,
+    beta2: float,
+    t: float,
+    eps: float,
+    params: wp.array(dtype=wp.float16),
+):
+    i = wp.tid()
+    m[i] = beta1 * m[i] + (1.0 - beta1) * float(g[i])
+    v[i] = beta2 * v[i] + (1.0 - beta2) * float(g[i]) * float(g[i])
+    mhat = m[i] / (1.0 - wp.pow(beta1, (t + 1.0)))
+    vhat = v[i] / (1.0 - wp.pow(beta2, (t + 1.0)))
+    params[i] = params[i] - wp.float16(lr * mhat / (wp.sqrt(vhat) + eps))
+
+
 class Adam:
     """An implementation of the Adam Optimizer
     It is designed to mimic Pytorch's version.
@@ -75,10 +95,20 @@ class Adam:
         self.v = [None] * len(params)  # reset second moment
         for i in range(len(params)):
             param = params[i]
+
+            if param.dtype == wp.vec3:
+                dtype = wp.vec3
+            elif param.dtype == wp.float32:
+                dtype = wp.float32
+            elif param.dtype == wp.float16:
+                dtype = wp.float32  # we always use fp32 for moments, even if params are fp16
+            else:
+                raise RuntimeError(f"Unsupported dtype for Warp Adam optimizer: {param.dtype}")
+
             if self.m[i] is None or self.m[i].shape != param.shape or self.m[i].dtype != param.dtype:
-                self.m[i] = wp.
+                self.m[i] = wp.zeros(shape=param.shape, dtype=dtype, device=param.device)
             if self.v[i] is None or self.v[i].shape != param.shape or self.v[i].dtype != param.dtype:
-                self.v[i] = wp.
+                self.v[i] = wp.zeros(shape=param.shape, dtype=dtype, device=param.device)
 
     def reset_internal_state(self):
         for m_i in self.m:
@@ -98,8 +128,6 @@ class Adam:
     @staticmethod
     def step_detail(g, m, v, lr, beta1, beta2, t, eps, params):
         assert params.dtype == g.dtype
-        assert params.dtype == m.dtype
-        assert params.dtype == v.dtype
         assert params.shape == g.shape
         kernel_inputs = [g, m, v, lr, beta1, beta2, t, eps, params]
         if params.dtype == wp.types.float32:
@@ -109,6 +137,13 @@ class Adam:
                 inputs=kernel_inputs,
                 device=params.device,
             )
+        elif params.dtype == wp.types.float16:
+            wp.launch(
+                kernel=adam_step_kernel_half,
+                dim=len(params),
+                inputs=kernel_inputs,
+                device=params.device,
+            )
         elif params.dtype == wp.types.vec3:
             wp.launch(
                 kernel=adam_step_kernel_vec3,