warp-lang 1.4.2__py3-none-manylinux2014_x86_64.whl → 1.5.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package registry's advisory page for more details.

Files changed (158)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1783 -2
  8. warp/codegen.py +177 -45
  9. warp/config.py +2 -2
  10. warp/context.py +321 -73
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/fem/example_adaptive_grid.py +37 -10
  17. warp/examples/fem/example_apic_fluid.py +3 -2
  18. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  19. warp/examples/fem/example_deformed_geometry.py +1 -1
  20. warp/examples/fem/example_diffusion_3d.py +47 -4
  21. warp/examples/fem/example_distortion_energy.py +220 -0
  22. warp/examples/fem/example_magnetostatics.py +127 -85
  23. warp/examples/fem/example_nonconforming_contact.py +5 -5
  24. warp/examples/fem/example_stokes.py +3 -1
  25. warp/examples/fem/example_streamlines.py +12 -19
  26. warp/examples/fem/utils.py +38 -15
  27. warp/examples/sim/example_cloth.py +2 -25
  28. warp/examples/sim/example_quadruped.py +2 -1
  29. warp/examples/tile/example_tile_convolution.py +58 -0
  30. warp/examples/tile/example_tile_fft.py +47 -0
  31. warp/examples/tile/example_tile_filtering.py +105 -0
  32. warp/examples/tile/example_tile_matmul.py +79 -0
  33. warp/examples/tile/example_tile_mlp.py +375 -0
  34. warp/fem/__init__.py +8 -0
  35. warp/fem/cache.py +16 -12
  36. warp/fem/dirichlet.py +1 -1
  37. warp/fem/domain.py +44 -1
  38. warp/fem/field/__init__.py +1 -2
  39. warp/fem/field/field.py +31 -19
  40. warp/fem/field/nodal_field.py +101 -49
  41. warp/fem/field/virtual.py +794 -0
  42. warp/fem/geometry/__init__.py +2 -2
  43. warp/fem/geometry/deformed_geometry.py +3 -105
  44. warp/fem/geometry/element.py +13 -0
  45. warp/fem/geometry/geometry.py +165 -5
  46. warp/fem/geometry/grid_2d.py +3 -6
  47. warp/fem/geometry/grid_3d.py +31 -28
  48. warp/fem/geometry/hexmesh.py +3 -46
  49. warp/fem/geometry/nanogrid.py +3 -2
  50. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  51. warp/fem/geometry/tetmesh.py +2 -43
  52. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  53. warp/fem/integrate.py +683 -261
  54. warp/fem/linalg.py +404 -0
  55. warp/fem/operator.py +101 -18
  56. warp/fem/polynomial.py +5 -5
  57. warp/fem/quadrature/quadrature.py +45 -21
  58. warp/fem/space/__init__.py +45 -11
  59. warp/fem/space/basis_function_space.py +451 -0
  60. warp/fem/space/basis_space.py +58 -11
  61. warp/fem/space/function_space.py +146 -5
  62. warp/fem/space/grid_2d_function_space.py +80 -66
  63. warp/fem/space/grid_3d_function_space.py +113 -68
  64. warp/fem/space/hexmesh_function_space.py +96 -108
  65. warp/fem/space/nanogrid_function_space.py +62 -110
  66. warp/fem/space/quadmesh_function_space.py +208 -0
  67. warp/fem/space/shape/__init__.py +45 -7
  68. warp/fem/space/shape/cube_shape_function.py +328 -54
  69. warp/fem/space/shape/shape_function.py +10 -1
  70. warp/fem/space/shape/square_shape_function.py +328 -60
  71. warp/fem/space/shape/tet_shape_function.py +269 -19
  72. warp/fem/space/shape/triangle_shape_function.py +238 -19
  73. warp/fem/space/tetmesh_function_space.py +69 -37
  74. warp/fem/space/topology.py +38 -0
  75. warp/fem/space/trimesh_function_space.py +179 -0
  76. warp/fem/utils.py +6 -331
  77. warp/jax_experimental.py +3 -1
  78. warp/native/array.h +15 -0
  79. warp/native/builtin.h +66 -26
  80. warp/native/bvh.h +4 -0
  81. warp/native/coloring.cpp +600 -0
  82. warp/native/cuda_util.cpp +14 -0
  83. warp/native/cuda_util.h +2 -1
  84. warp/native/fabric.h +8 -0
  85. warp/native/hashgrid.h +4 -0
  86. warp/native/marching.cu +8 -0
  87. warp/native/mat.h +14 -3
  88. warp/native/mathdx.cpp +59 -0
  89. warp/native/mesh.h +4 -0
  90. warp/native/range.h +13 -1
  91. warp/native/reduce.cpp +9 -1
  92. warp/native/reduce.cu +7 -0
  93. warp/native/runlength_encode.cpp +9 -1
  94. warp/native/runlength_encode.cu +7 -1
  95. warp/native/scan.cpp +8 -0
  96. warp/native/scan.cu +8 -0
  97. warp/native/scan.h +8 -1
  98. warp/native/sparse.cpp +8 -0
  99. warp/native/sparse.cu +8 -0
  100. warp/native/temp_buffer.h +7 -0
  101. warp/native/tile.h +1857 -0
  102. warp/native/tile_gemm.h +341 -0
  103. warp/native/tile_reduce.h +210 -0
  104. warp/native/volume_builder.cu +8 -0
  105. warp/native/volume_builder.h +8 -0
  106. warp/native/warp.cpp +10 -2
  107. warp/native/warp.cu +369 -15
  108. warp/native/warp.h +12 -2
  109. warp/optim/adam.py +39 -4
  110. warp/paddle.py +29 -12
  111. warp/render/render_opengl.py +137 -65
  112. warp/sim/graph_coloring.py +292 -0
  113. warp/sim/integrator_euler.py +4 -2
  114. warp/sim/integrator_featherstone.py +115 -44
  115. warp/sim/integrator_vbd.py +6 -0
  116. warp/sim/model.py +88 -15
  117. warp/stubs.py +569 -4
  118. warp/tape.py +12 -7
  119. warp/tests/assets/pixel.npy +0 -0
  120. warp/tests/aux_test_instancing_gc.py +18 -0
  121. warp/tests/test_array.py +39 -0
  122. warp/tests/test_codegen.py +81 -1
  123. warp/tests/test_codegen_instancing.py +30 -0
  124. warp/tests/test_collision.py +110 -0
  125. warp/tests/test_coloring.py +241 -0
  126. warp/tests/test_context.py +34 -0
  127. warp/tests/test_examples.py +18 -4
  128. warp/tests/test_fem.py +453 -113
  129. warp/tests/test_func.py +13 -0
  130. warp/tests/test_generics.py +52 -0
  131. warp/tests/test_iter.py +68 -0
  132. warp/tests/test_mat_scalar_ops.py +1 -1
  133. warp/tests/test_mesh_query_point.py +1 -1
  134. warp/tests/test_module_hashing.py +23 -0
  135. warp/tests/test_paddle.py +27 -87
  136. warp/tests/test_print.py +56 -1
  137. warp/tests/test_spatial.py +1 -1
  138. warp/tests/test_tile.py +700 -0
  139. warp/tests/test_tile_mathdx.py +144 -0
  140. warp/tests/test_tile_mlp.py +383 -0
  141. warp/tests/test_tile_reduce.py +374 -0
  142. warp/tests/test_tile_shared_memory.py +190 -0
  143. warp/tests/test_vbd.py +12 -20
  144. warp/tests/test_volume.py +43 -0
  145. warp/tests/unittest_suites.py +19 -2
  146. warp/tests/unittest_utils.py +4 -0
  147. warp/types.py +338 -72
  148. warp/utils.py +22 -1
  149. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
  150. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/RECORD +153 -126
  151. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
  152. warp/fem/field/test.py +0 -180
  153. warp/fem/field/trial.py +0 -183
  154. warp/fem/space/collocated_function_space.py +0 -102
  155. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  156. warp/fem/space/trimesh_2d_function_space.py +0 -153
  157. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
  158. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
warp/native/warp.cu CHANGED
@@ -11,9 +11,16 @@
11
11
  #include "cuda_util.h"
12
12
  #include "error.h"
13
13
 
14
+ #include <cstdlib>
15
+ #include <fstream>
14
16
  #include <nvrtc.h>
15
17
  #include <nvPTXCompiler.h>
18
+ #if WP_ENABLE_MATHDX
19
+ #include <nvJitLink.h>
20
+ #include <libmathdx.h>
21
+ #endif
16
22
 
23
+ #include <array>
17
24
  #include <algorithm>
18
25
  #include <iterator>
19
26
  #include <list>
@@ -23,8 +30,39 @@
23
30
  #include <unordered_set>
24
31
  #include <vector>
25
32
 
33
+ #define check_any(result) (check_generic(result, __FILE__, __LINE__))
26
34
  #define check_nvrtc(code) (check_nvrtc_result(code, __FILE__, __LINE__))
27
35
  #define check_nvptx(code) (check_nvptx_result(code, __FILE__, __LINE__))
36
+ #define check_nvjitlink(handle, code) (check_nvjitlink_result(handle, code, __FILE__, __LINE__))
37
+ #define check_cufftdx(code) (check_cufftdx_result(code, __FILE__, __LINE__))
38
+ #define check_cublasdx(code) (check_cublasdx_result(code, __FILE__, __LINE__))
39
+ #define CHECK_ANY(code) \
40
+ { \
41
+ do { \
42
+ bool out = (check_any(code)); \
43
+ if(!out) { \
44
+ return out; \
45
+ } \
46
+ } while(0); \
47
+ }
48
+ #define CHECK_CUFFTDX(code) \
49
+ { \
50
+ do { \
51
+ bool out = (check_cufftdx(code)); \
52
+ if(!out) { \
53
+ return out; \
54
+ } \
55
+ } while(0); \
56
+ }
57
+ #define CHECK_CUBLASDX(code) \
58
+ { \
59
+ do { \
60
+ bool out = (check_cufftdx(code)); \
61
+ if(!out) { \
62
+ return out; \
63
+ } \
64
+ } while(0); \
65
+ }
28
66
 
29
67
  bool check_nvrtc_result(nvrtcResult result, const char* file, int line)
30
68
  {
@@ -74,6 +112,15 @@ bool check_nvptx_result(nvPTXCompileResult result, const char* file, int line)
74
112
  return false;
75
113
  }
76
114
 
115
+ bool check_generic(int result, const char* file, int line)
116
+ {
117
+ if (!result) {
118
+ fprintf(stderr, "Error %d on %s:%d\n", (int)result, file, line);
119
+ return false;
120
+ } else {
121
+ return true;
122
+ }
123
+ }
77
124
 
78
125
  struct DeviceInfo
79
126
  {
@@ -89,6 +136,7 @@ struct DeviceInfo
89
136
  int arch = 0;
90
137
  int is_uva = 0;
91
138
  int is_mempool_supported = 0;
139
+ int max_smem_bytes = 0;
92
140
  CUcontext primary_context = NULL;
93
141
  };
94
142
 
@@ -202,6 +250,7 @@ int cuda_init()
202
250
  check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_device_id, CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, device));
203
251
  check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_uva, CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, device));
204
252
  check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_mempool_supported, CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, device));
253
+ check_cu(cuDeviceGetAttribute_f(&g_devices[i].max_smem_bytes, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device));
205
254
  int major = 0;
206
255
  int minor = 0;
207
256
  check_cu(cuDeviceGetAttribute_f(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device));
@@ -2520,11 +2569,57 @@ bool cuda_graph_destroy(void* context, void* graph_exec)
2520
2569
  return check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
2521
2570
  }
2522
2571
 
2523
- size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_dir, bool debug, bool verbose, bool verify_fp, bool fast_math, const char* output_path)
2572
+ bool write_file(const char* data, size_t size, std::string filename, const char* mode)
2573
+ {
2574
+ const bool print_debug = (std::getenv("WARP_DEBUG") != nullptr);
2575
+ if (print_debug)
2576
+ {
2577
+ printf("Writing %zu B to %s (%s)\n", size, filename.c_str(), mode);
2578
+ }
2579
+ FILE* file = fopen(filename.c_str(), mode);
2580
+ if (file)
2581
+ {
2582
+ if (fwrite(data, 1, size, file) != size) {
2583
+ fprintf(stderr, "Warp error: Failed to write to output file '%s'\n", filename.c_str());
2584
+ return false;
2585
+ }
2586
+ fclose(file);
2587
+ return true;
2588
+ }
2589
+ else
2590
+ {
2591
+ fprintf(stderr, "Warp error: Failed to open file '%s'\n", filename.c_str());
2592
+ return false;
2593
+ }
2594
+ }
2595
+
2596
+ #if WP_ENABLE_MATHDX
2597
+ bool check_nvjitlink_result(nvJitLinkHandle handle, nvJitLinkResult result, const char* file, int line)
2598
+ {
2599
+ if (result != NVJITLINK_SUCCESS) {
2600
+ fprintf(stderr, "nvJitLink error: %d on %s:%d\n", (int)result, file, line);
2601
+ size_t lsize;
2602
+ result = nvJitLinkGetErrorLogSize(handle, &lsize);
2603
+ if (result == NVJITLINK_SUCCESS && lsize > 0) {
2604
+ std::vector<char> log(lsize);
2605
+ result = nvJitLinkGetErrorLog(handle, log.data());
2606
+ if (result == NVJITLINK_SUCCESS) {
2607
+ fprintf(stderr, "%s\n", log.data());
2608
+ }
2609
+ }
2610
+ return false;
2611
+ } else {
2612
+ return true;
2613
+ }
2614
+ }
2615
+ #endif
2616
+
2617
+ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes)
2524
2618
  {
2525
2619
  // use file extension to determine whether to output PTX or CUBIN
2526
2620
  const char* output_ext = strrchr(output_path, '.');
2527
2621
  bool use_ptx = output_ext && strcmp(output_ext + 1, "ptx") == 0;
2622
+ const bool print_debug = (std::getenv("WARP_DEBUG") != nullptr);
2528
2623
 
2529
2624
  // check include dir path len (path + option)
2530
2625
  const int max_path = 4096 + 16;
@@ -2534,17 +2629,37 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
2534
2629
  return size_t(-1);
2535
2630
  }
2536
2631
 
2632
+ if (print_debug)
2633
+ {
2634
+ // Not available in all nvJitLink versions
2635
+ // unsigned major = 0;
2636
+ // unsigned minor = 0;
2637
+ // nvJitLinkVersion(&major, &minor);
2638
+ // printf("nvJitLink version %d.%d\n", major, minor);
2639
+ int major = 0;
2640
+ int minor = 0;
2641
+ nvrtcVersion(&major, &minor);
2642
+ printf("NVRTC version %d.%d\n", major, minor);
2643
+ }
2644
+
2537
2645
  char include_opt[max_path];
2538
2646
  strcpy(include_opt, "--include-path=");
2539
2647
  strcat(include_opt, include_dir);
2540
2648
 
2541
2649
  const int max_arch = 128;
2542
2650
  char arch_opt[max_arch];
2651
+ char arch_opt_lto[max_arch];
2543
2652
 
2544
2653
  if (use_ptx)
2654
+ {
2545
2655
  snprintf(arch_opt, max_arch, "--gpu-architecture=compute_%d", arch);
2656
+ snprintf(arch_opt_lto, max_arch, "-arch=compute_%d", arch);
2657
+ }
2546
2658
  else
2659
+ {
2547
2660
  snprintf(arch_opt, max_arch, "--gpu-architecture=sm_%d", arch);
2661
+ snprintf(arch_opt_lto, max_arch, "-arch=sm_%d", arch);
2662
+ }
2548
2663
 
2549
2664
  std::vector<const char*> opts;
2550
2665
  opts.push_back(arch_opt);
@@ -2555,6 +2670,7 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
2555
2670
  {
2556
2671
  opts.push_back("--define-macro=_DEBUG");
2557
2672
  opts.push_back("--generate-line-info");
2673
+
2558
2674
  // disabling since it causes issues with `Unresolved extern function 'cudaGetParameterBufferV2'
2559
2675
  //opts.push_back("--device-debug");
2560
2676
  }
@@ -2569,6 +2685,26 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
2569
2685
  if (fast_math)
2570
2686
  opts.push_back("--use_fast_math");
2571
2687
 
2688
+ char include_cutlass[max_path];
2689
+ sprintf(include_cutlass, "--include-path=%s/cutlass/include", include_dir);
2690
+ opts.push_back(include_cutlass);
2691
+
2692
+ std::vector<std::string> cuda_include_opt;
2693
+ for(int i = 0; i < num_cuda_include_dirs; i++)
2694
+ {
2695
+ cuda_include_opt.push_back(std::string("--include-path=") + cuda_include_dirs[i]);
2696
+ opts.push_back(cuda_include_opt.back().c_str());
2697
+ }
2698
+
2699
+ opts.push_back("--device-as-default-execution-space");
2700
+ opts.push_back("--extra-device-vectorization");
2701
+ opts.push_back("--restrict");
2702
+
2703
+ if (num_ltoirs > 0)
2704
+ {
2705
+ opts.push_back("-dlto");
2706
+ opts.push_back("--relocatable-device-code=true");
2707
+ }
2572
2708
 
2573
2709
  nvrtcProgram prog;
2574
2710
  nvrtcResult res;
@@ -2584,6 +2720,13 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
2584
2720
  if (!check_nvrtc(res))
2585
2721
  return size_t(res);
2586
2722
 
2723
+ if (print_debug)
2724
+ {
2725
+ printf("NVRTC options:\n");
2726
+ for(auto o: opts) {
2727
+ printf("%s\n", o);
2728
+ }
2729
+ }
2587
2730
  res = nvrtcCompileProgram(prog, int(opts.size()), opts.data());
2588
2731
 
2589
2732
  if (!check_nvrtc(res) || verbose)
@@ -2613,7 +2756,17 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
2613
2756
  nvrtcResult (*get_output_size)(nvrtcProgram, size_t*);
2614
2757
  nvrtcResult (*get_output_data)(nvrtcProgram, char*);
2615
2758
  const char* output_mode;
2616
- if (use_ptx)
2759
+ if(num_ltoirs > 0) {
2760
+ #if WP_ENABLE_MATHDX
2761
+ get_output_size = nvrtcGetLTOIRSize;
2762
+ get_output_data = nvrtcGetLTOIR;
2763
+ output_mode = "wb";
2764
+ #else
2765
+ fprintf(stderr, "Warp error: num_ltoirs > 0 but Warp was not built with MathDx support\n");
2766
+ return size_t(-1);
2767
+ #endif
2768
+ }
2769
+ else if (use_ptx)
2617
2770
  {
2618
2771
  get_output_size = nvrtcGetPTXSize;
2619
2772
  get_output_data = nvrtcGetPTX;
@@ -2635,19 +2788,78 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
2635
2788
  res = get_output_data(prog, output.data());
2636
2789
  if (check_nvrtc(res))
2637
2790
  {
2638
- FILE* file = fopen(output_path, output_mode);
2639
- if (file)
2791
+
2792
+ // LTOIR case - need an extra step
2793
+ if (num_ltoirs > 0)
2640
2794
  {
2641
- if (fwrite(output.data(), 1, output_size, file) != output_size)
2795
+ #if WP_ENABLE_MATHDX
2796
+ nvJitLinkHandle handle;
2797
+ std::vector<const char *> lopts = {"-dlto", arch_opt_lto};
2798
+ if (use_ptx) {
2799
+ lopts.push_back("-ptx");
2800
+ }
2801
+ if (print_debug)
2802
+ {
2803
+ printf("nvJitLink options:\n");
2804
+ for(auto o: lopts) {
2805
+ printf("%s\n", o);
2806
+ }
2807
+ }
2808
+ if(!check_nvjitlink(handle, nvJitLinkCreate(&handle, lopts.size(), lopts.data())))
2809
+ {
2810
+ res = nvrtcResult(-1);
2811
+ }
2812
+ // Links
2813
+ if(std::getenv("WARP_DUMP_LTOIR"))
2814
+ {
2815
+ write_file(output.data(), output.size(), "nvrtc_output.ltoir", "wb");
2816
+ }
2817
+ if(!check_nvjitlink(handle, nvJitLinkAddData(handle, NVJITLINK_INPUT_LTOIR, output.data(), output.size(), "nvrtc_output"))) // NVRTC business
2642
2818
  {
2643
- fprintf(stderr, "Warp error: Failed to write output file '%s'\n", output_path);
2644
2819
  res = nvrtcResult(-1);
2645
2820
  }
2646
- fclose(file);
2821
+ for(size_t ltoidx = 0; ltoidx < num_ltoirs; ltoidx++)
2822
+ {
2823
+ if(std::getenv("WARP_DUMP_LTOIR"))
2824
+ {
2825
+ write_file(ltoirs[ltoidx], ltoir_sizes[ltoidx], std::string("lto_online_") + std::to_string(ltoidx) + ".ltoir", "wb");
2826
+ }
2827
+ if(!check_nvjitlink(handle, nvJitLinkAddData(handle, NVJITLINK_INPUT_LTOIR, ltoirs[ltoidx], ltoir_sizes[ltoidx], "lto_online"))) // External LTOIR
2828
+ {
2829
+ res = nvrtcResult(-1);
2830
+ }
2831
+ }
2832
+ if(!check_nvjitlink(handle, nvJitLinkComplete(handle)))
2833
+ {
2834
+ res = nvrtcResult(-1);
2835
+ }
2836
+ else
2837
+ {
2838
+ if(use_ptx)
2839
+ {
2840
+ size_t ptx_size = 0;
2841
+ check_nvjitlink(handle, nvJitLinkGetLinkedPtxSize(handle, &ptx_size));
2842
+ std::vector<char> ptx(ptx_size);
2843
+ check_nvjitlink(handle, nvJitLinkGetLinkedPtx(handle, ptx.data()));
2844
+ output = ptx;
2845
+ }
2846
+ else
2847
+ {
2848
+ size_t cubin_size = 0;
2849
+ check_nvjitlink(handle, nvJitLinkGetLinkedCubinSize(handle, &cubin_size));
2850
+ std::vector<char> cubin(cubin_size);
2851
+ check_nvjitlink(handle, nvJitLinkGetLinkedCubin(handle, cubin.data()));
2852
+ output = cubin;
2853
+ }
2854
+ }
2855
+ check_nvjitlink(handle, nvJitLinkDestroy(&handle));
2856
+ #else
2857
+ fprintf(stderr, "Warp error: num_ltoirs > 0 but Warp was not built with MathDx support\n");
2858
+ return size_t(-1);
2859
+ #endif
2647
2860
  }
2648
- else
2649
- {
2650
- fprintf(stderr, "Warp error: Failed to open output file '%s'\n", output_path);
2861
+
2862
+ if(!write_file(output.data(), output.size(), output_path, output_mode)) {
2651
2863
  res = nvrtcResult(-1);
2652
2864
  }
2653
2865
  }
@@ -2658,6 +2870,119 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
2658
2870
  return res;
2659
2871
  }
2660
2872
 
2873
+ #if WP_ENABLE_MATHDX
2874
+ bool check_cufftdx_result(commonDxStatusType result, const char* file, int line)
2875
+ {
2876
+ if (result != commonDxStatusType::COMMONDX_SUCCESS) {
2877
+ fprintf(stderr, "libmathdx cuFFTDx error: %d on %s:%d\n", (int)result, file, line);
2878
+ return false;
2879
+ } else {
2880
+ return true;
2881
+ }
2882
+ }
2883
+
2884
+ bool check_cublasdx_result(commonDxStatusType result, const char* file, int line)
2885
+ {
2886
+ if (result != commonDxStatusType::COMMONDX_SUCCESS) {
2887
+ fprintf(stderr, "libmathdx cuBLASDx error: %d on %s:%d\n", (int)result, file, line);
2888
+ return false;
2889
+ } else {
2890
+ return true;
2891
+ }
2892
+ }
2893
+
2894
+ bool cuda_compile_fft(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int size, int elements_per_thread, int direction, int precision, int* shared_memory_size)
2895
+ {
2896
+
2897
+ CHECK_ANY(ltoir_output_path != nullptr);
2898
+ CHECK_ANY(symbol_name != nullptr);
2899
+ CHECK_ANY(shared_memory_size != nullptr);
2900
+ // Includes currently unused
2901
+ CHECK_ANY(include_dirs == nullptr);
2902
+ CHECK_ANY(mathdx_include_dir == nullptr);
2903
+ CHECK_ANY(num_include_dirs == 0);
2904
+
2905
+ bool res = true;
2906
+ cufftdxHandle h;
2907
+ CHECK_CUFFTDX(cufftDxCreate(&h));
2908
+
2909
+ // CUFFTDX_API_BLOCK_LMEM means each thread starts with a subset of the data
2910
+ CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_API, cufftDxApi::CUFFTDX_API_BLOCK_LMEM));
2911
+ CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_EXECUTION, commonDxExecution::COMMONDX_EXECUTION_BLOCK));
2912
+ CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_SIZE, (long long)size));
2913
+ CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_DIRECTION, (cufftDxDirection)direction));
2914
+ CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_PRECISION, (commonDxPrecision)precision));
2915
+ CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_SM, (long long)(arch * 10)));
2916
+ CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_ELEMENTS_PER_THREAD, (long long)(elements_per_thread)));
2917
+ CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_FFTS_PER_BLOCK, 1));
2918
+
2919
+ CHECK_CUFFTDX(cufftDxSetOptionStr(h, commonDxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
2920
+
2921
+ size_t lto_size = 0;
2922
+ CHECK_CUFFTDX(cufftDxGetLTOIRSize(h, &lto_size));
2923
+
2924
+ std::vector<char> lto(lto_size);
2925
+ CHECK_CUFFTDX(cufftDxGetLTOIR(h, lto.size(), lto.data()));
2926
+
2927
+ long long int smem = 0;
2928
+ CHECK_CUFFTDX(cufftDxGetTraitInt64(h, cufftDxTraitType::CUFFTDX_TRAIT_SHARED_MEMORY_SIZE, &smem));
2929
+ *shared_memory_size = (int)smem;
2930
+
2931
+ if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
2932
+ res = false;
2933
+ }
2934
+
2935
+ CHECK_CUFFTDX(cufftDxDestroy(h));
2936
+
2937
+ return res;
2938
+ }
2939
+
2940
+ bool cuda_compile_dot(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int K, int precision_A, int precision_B, int precision_C, int type, int arrangement_A, int arrangement_B, int arrangement_C, int num_threads)
2941
+ {
2942
+
2943
+ CHECK_ANY(ltoir_output_path != nullptr);
2944
+ CHECK_ANY(symbol_name != nullptr);
2945
+ // Includes currently unused
2946
+ CHECK_ANY(include_dirs == nullptr);
2947
+ CHECK_ANY(mathdx_include_dir == nullptr);
2948
+ CHECK_ANY(num_include_dirs == 0);
2949
+
2950
+ bool res = true;
2951
+ cublasdxHandle h;
2952
+ CHECK_CUBLASDX(cublasDxCreate(&h));
2953
+
2954
+ CHECK_CUBLASDX(cublasDxSetOperatorInt64(h, cublasDxOperatorType::CUBLASDX_OPERATOR_FUNCTION, cublasDxFunction::CUBLASDX_FUNCTION_MM));
2955
+ CHECK_CUBLASDX(cublasDxSetOperatorInt64(h, cublasDxOperatorType::CUBLASDX_OPERATOR_EXECUTION, commonDxExecution::COMMONDX_EXECUTION_BLOCK));
2956
+ CHECK_CUBLASDX(cublasDxSetOperatorInt64(h, cublasDxOperatorType::CUBLASDX_OPERATOR_API, cublasDxApi::CUBLASDX_API_BLOCK_SMEM));
2957
+ std::array<long long int, 3> precisions = {precision_A, precision_B, precision_C};
2958
+ CHECK_CUBLASDX(cublasDxSetOperatorInt64Array(h, cublasDxOperatorType::CUBLASDX_OPERATOR_PRECISION, 3, precisions.data()));
2959
+ CHECK_CUBLASDX(cublasDxSetOperatorInt64(h, cublasDxOperatorType::CUBLASDX_OPERATOR_SM, (long long)(arch * 10)));
2960
+ CHECK_CUBLASDX(cublasDxSetOperatorInt64(h, cublasDxOperatorType::CUBLASDX_OPERATOR_TYPE, (cublasDxType)type));
2961
+ std::array<long long int, 3> block_dim = {num_threads, 1, 1};
2962
+ CHECK_CUBLASDX(cublasDxSetOperatorInt64Array(h, cublasDxOperatorType::CUBLASDX_OPERATOR_BLOCK_DIM, block_dim.size(), block_dim.data()));
2963
+ std::array<long long int, 3> size = {M, N, K};
2964
+ CHECK_CUBLASDX(cublasDxSetOperatorInt64Array(h, cublasDxOperatorType::CUBLASDX_OPERATOR_SIZE, size.size(), size.data()));
2965
+ std::array<long long int, 3> arrangement = {arrangement_A, arrangement_B, arrangement_C};
2966
+ CHECK_CUBLASDX(cublasDxSetOperatorInt64Array(h, cublasDxOperatorType::CUBLASDX_OPERATOR_ARRANGEMENT, arrangement.size(), arrangement.data()));
2967
+
2968
+ CHECK_CUBLASDX(cublasDxSetOptionStr(h, commonDxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
2969
+
2970
+ size_t lto_size = 0;
2971
+ CHECK_CUBLASDX(cublasDxGetLTOIRSize(h, &lto_size));
2972
+
2973
+ std::vector<char> lto(lto_size);
2974
+ CHECK_CUBLASDX(cublasDxGetLTOIR(h, lto.size(), lto.data()));
2975
+
2976
+ if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
2977
+ res = false;
2978
+ }
2979
+
2980
+ CHECK_CUBLASDX(cublasDxDestroy(h));
2981
+
2982
+ return res;
2983
+ }
2984
+ #endif
2985
+
2661
2986
  void* cuda_load_module(void* context, const char* path)
2662
2987
  {
2663
2988
  ContextGuard guard(context);
@@ -2784,6 +3109,29 @@ void cuda_unload_module(void* context, void* module)
2784
3109
  check_cu(cuModuleUnload_f((CUmodule)module));
2785
3110
  }
2786
3111
 
3112
+
3113
+ int cuda_get_max_shared_memory(void* context)
3114
+ {
3115
+ ContextInfo* info = get_context_info(context);
3116
+ if (!info)
3117
+ return -1;
3118
+
3119
+ int max_smem_bytes = info->device_info->max_smem_bytes;
3120
+ return max_smem_bytes;
3121
+ }
3122
+
3123
+ bool cuda_configure_kernel_shared_memory(void* kernel, int size)
3124
+ {
3125
+ int requested_smem_bytes = size;
3126
+
3127
+ // configure shared memory
3128
+ CUresult res = cuFuncSetAttribute_f((CUfunction)kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, requested_smem_bytes);
3129
+ if (res != CUDA_SUCCESS)
3130
+ return false;
3131
+
3132
+ return true;
3133
+ }
3134
+
2787
3135
  void* cuda_get_kernel(void* context, void* module, const char* name)
2788
3136
  {
2789
3137
  ContextGuard guard(context);
@@ -2796,15 +3144,21 @@ void* cuda_get_kernel(void* context, void* module, const char* name)
2796
3144
  }
2797
3145
 
2798
3146
  g_kernel_names[kernel] = name;
2799
-
2800
3147
  return kernel;
2801
3148
  }
2802
3149
 
2803
- size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, void** args, void* stream)
3150
+ size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, int block_dim, int shared_memory_bytes, void** args, void* stream)
2804
3151
  {
2805
3152
  ContextGuard guard(context);
2806
3153
 
2807
- const int block_dim = 256;
3154
+ if (block_dim <= 0)
3155
+ {
3156
+ #if defined(_DEBUG)
3157
+ fprintf(stderr, "Warp warning: Launch got block_dim %d. Setting to 256.\n", dim, block_dim);
3158
+ #endif
3159
+ block_dim = 256;
3160
+ }
3161
+
2808
3162
  // CUDA specs up to compute capability 9.0 says the max x-dim grid is 2**31-1, so
2809
3163
  // grid_dim is fine as an int for the near future
2810
3164
  int grid_dim = (dim + block_dim - 1)/block_dim;
@@ -2835,7 +3189,8 @@ size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_block
2835
3189
  (CUfunction)kernel,
2836
3190
  grid_dim, 1, 1,
2837
3191
  block_dim, 1, 1,
2838
- 0, static_cast<CUstream>(stream),
3192
+ shared_memory_bytes,
3193
+ static_cast<CUstream>(stream),
2839
3194
  args,
2840
3195
  0);
2841
3196
 
@@ -2940,7 +3295,6 @@ void cuda_timing_end(timing_result_t* results, int size)
2940
3295
  g_cuda_timing_state = parent_state;
2941
3296
  }
2942
3297
 
2943
-
2944
3298
  // impl. files
2945
3299
  #include "bvh.cu"
2946
3300
  #include "mesh.cu"
warp/native/warp.h CHANGED
@@ -34,6 +34,8 @@ extern "C"
34
34
  WP_API int is_cuda_compatibility_enabled();
35
35
  // whether Warp was compiled with CUTLASS support
36
36
  WP_API int is_cutlass_enabled();
37
+ // whether Warp was compiled with MathDx support
38
+ WP_API int is_mathdx_enabled();
37
39
  // whether Warp was compiled with debug support
38
40
  WP_API int is_debug_enabled();
39
41
 
@@ -315,12 +317,16 @@ extern "C"
315
317
  WP_API bool cuda_graph_launch(void* graph, void* stream);
316
318
  WP_API bool cuda_graph_destroy(void* context, void* graph);
317
319
 
318
- WP_API size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_dir, bool debug, bool verbose, bool verify_fp, bool fast_math, const char* output_file);
320
+ WP_API size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes);
321
+ WP_API bool cuda_compile_fft(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int size, int elements_per_thread, int direction, int precision, int* shared_memory_size);
322
+ WP_API bool cuda_compile_dot(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int K, int precision_A, int precision_B, int precision_C, int type, int arrangement_A, int arrangement_B, int arrangement_C, int num_threads);
319
323
 
320
324
  WP_API void* cuda_load_module(void* context, const char* ptx);
321
325
  WP_API void cuda_unload_module(void* context, void* module);
322
326
  WP_API void* cuda_get_kernel(void* context, void* module, const char* name);
323
- WP_API size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, void** args, void* stream);
327
+ WP_API size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, int block_dim, int shared_memory_bytes, void** args, void* stream);
328
+ WP_API int cuda_get_max_shared_memory(void* context);
329
+ WP_API bool cuda_configure_kernel_shared_memory(void* kernel, int size);
324
330
 
325
331
  WP_API void cuda_set_context_restore_policy(bool always_restore);
326
332
  WP_API int cuda_get_context_restore_policy();
@@ -336,4 +342,8 @@ extern "C"
336
342
  WP_API int cuda_timing_get_result_count();
337
343
  WP_API void cuda_timing_end(timing_result_t* results, int size);
338
344
 
345
+ // graph coloring
346
+ WP_API int graph_coloring(int num_nodes, wp::array_t<int> edges, int algorithm, wp::array_t<int> node_colors);
347
+ WP_API float balance_coloring(int num_nodes, wp::array_t<int> edges, int num_colors, float target_max_min_ratio, wp::array_t<int> node_colors);
348
+
339
349
  } // extern "C"
warp/optim/adam.py CHANGED
@@ -50,6 +50,26 @@ def adam_step_kernel_float(
50
50
  params[i] = params[i] - lr * mhat / (wp.sqrt(vhat) + eps)
51
51
 
52
52
 
53
+ @wp.kernel
54
+ def adam_step_kernel_half(
55
+ g: wp.array(dtype=wp.float16),
56
+ m: wp.array(dtype=float),
57
+ v: wp.array(dtype=float),
58
+ lr: float,
59
+ beta1: float,
60
+ beta2: float,
61
+ t: float,
62
+ eps: float,
63
+ params: wp.array(dtype=wp.float16),
64
+ ):
65
+ i = wp.tid()
66
+ m[i] = beta1 * m[i] + (1.0 - beta1) * float(g[i])
67
+ v[i] = beta2 * v[i] + (1.0 - beta2) * float(g[i]) * float(g[i])
68
+ mhat = m[i] / (1.0 - wp.pow(beta1, (t + 1.0)))
69
+ vhat = v[i] / (1.0 - wp.pow(beta2, (t + 1.0)))
70
+ params[i] = params[i] - wp.float16(lr * mhat / (wp.sqrt(vhat) + eps))
71
+
72
+
53
73
  class Adam:
54
74
  """An implementation of the Adam Optimizer
55
75
  It is designed to mimic Pytorch's version.
@@ -75,10 +95,20 @@ class Adam:
75
95
  self.v = [None] * len(params) # reset second moment
76
96
  for i in range(len(params)):
77
97
  param = params[i]
98
+
99
+ if param.dtype == wp.vec3:
100
+ dtype = wp.vec3
101
+ elif param.dtype == wp.float32:
102
+ dtype = wp.float32
103
+ elif param.dtype == wp.float16:
104
+ dtype = wp.float32 # we always use fp32 for moments, even if params are fp16
105
+ else:
106
+ raise RuntimeError(f"Unsupported dtype for Warp Adam optimizer: {param.dtype}")
107
+
78
108
  if self.m[i] is None or self.m[i].shape != param.shape or self.m[i].dtype != param.dtype:
79
- self.m[i] = wp.zeros_like(param)
109
+ self.m[i] = wp.zeros(shape=param.shape, dtype=dtype, device=param.device)
80
110
  if self.v[i] is None or self.v[i].shape != param.shape or self.v[i].dtype != param.dtype:
81
- self.v[i] = wp.zeros_like(param)
111
+ self.v[i] = wp.zeros(shape=param.shape, dtype=dtype, device=param.device)
82
112
 
83
113
  def reset_internal_state(self):
84
114
  for m_i in self.m:
@@ -98,8 +128,6 @@ class Adam:
98
128
  @staticmethod
99
129
  def step_detail(g, m, v, lr, beta1, beta2, t, eps, params):
100
130
  assert params.dtype == g.dtype
101
- assert params.dtype == m.dtype
102
- assert params.dtype == v.dtype
103
131
  assert params.shape == g.shape
104
132
  kernel_inputs = [g, m, v, lr, beta1, beta2, t, eps, params]
105
133
  if params.dtype == wp.types.float32:
@@ -109,6 +137,13 @@ class Adam:
109
137
  inputs=kernel_inputs,
110
138
  device=params.device,
111
139
  )
140
+ elif params.dtype == wp.types.float16:
141
+ wp.launch(
142
+ kernel=adam_step_kernel_half,
143
+ dim=len(params),
144
+ inputs=kernel_inputs,
145
+ device=params.device,
146
+ )
112
147
  elif params.dtype == wp.types.vec3:
113
148
  wp.launch(
114
149
  kernel=adam_step_kernel_vec3,