warp-lang 1.9.0__py3-none-macosx_10_13_universal2.whl → 1.9.1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

warp/native/tile.h CHANGED
@@ -542,7 +542,7 @@ struct tile_register_t
 
     // define the += operator which is used during backward pass codegen
     // when returning a register tile from a user defined function
-    inline CUDA_CALLABLE auto& operator += (tile_register_t<T, Layout>& rhs)
+    inline CUDA_CALLABLE auto& operator += (const tile_register_t<T, Layout>& rhs)
     {
         grad_add(rhs);
         return *this;
@@ -658,7 +658,7 @@ struct tile_register_t
             data[i] += tile.data[i];
         }
 
-    CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
+    inline CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
     {
         apply([&](int reg, auto c) {data[reg] += global.load_grad(c);});
     }
@@ -758,6 +758,7 @@ inline CUDA_CALLABLE void* tile_alloc_shared(int num_bytes, bool init=false, boo
 
     // one entry per-thread so no need for synchronization
     smem_base[WP_TILE_THREAD_IDX] += tile_align(num_bytes);
+    assert(smem_base[WP_TILE_THREAD_IDX] >= 0);
 
 #ifdef __CUDA_ARCH__
     extern __shared__ char dynamic_smem_base[];
@@ -905,6 +906,28 @@ struct tile_shared_t
     {
     }
 
+    // we delete the copy constructor because in the case the shared tile is owning,
+    // this leads to a double deallocation.
+    // this also forces one to handle copies explicitly
+    inline CUDA_CALLABLE tile_shared_t(const tile_shared_t& other) : data(other.data), grad(other.grad), initialized(other.initialized)
+    {
+        static_assert(!Owner, "Copy constructor is only supported for non-owning tiles.");
+    }
+
+    // move constructor
+    inline CUDA_CALLABLE tile_shared_t(tile_shared_t&& other) : data(other.data), grad(other.grad), initialized(other.initialized)
+    {
+        other.data.ptr = nullptr;
+        other.grad.ptr = nullptr;
+    }
+
+    template <typename OtherT, typename OtherLayout, bool OtherOwner>
+    inline CUDA_CALLABLE tile_shared_t(const tile_shared_t<OtherT, OtherLayout, OtherOwner>& other) : data(other.data.ptr), grad(other.grad.ptr), initialized(other.initialized)
+    {
+        static_assert(!Owner, "Copy constructor is only supported for non-owning tiles.");
+        static_assert(Layout::Size == OtherLayout::Size, "Expected Size == OtherLayout::Size");
+    }
+
     // initialize from an existing tile's memory
     inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=nullptr, bool initialized=true) : data(data), grad(grad), initialized(initialized)
     {
@@ -932,19 +955,47 @@ struct tile_shared_t
 
     // construct from another shared tile, this constructor
     // is invoked for reshape operations like `wp.tile_transpose()`
+    // or `wp::copy()`
     template <typename OtherT, typename OtherLayout, bool OtherOwner>
     inline CUDA_CALLABLE auto& operator=(const tile_shared_t<OtherT, OtherLayout, OtherOwner>& rhs)
     {
         // check dimensions are compatible
         static_assert(Layout::Size == OtherLayout::Size, "Expected Size == OtherLayout::Size");
 
-        // alias tile directly
-        data.ptr = rhs.data.ptr;
-        grad.ptr = rhs.grad.ptr;
-        initialized = rhs.initialized;
+
+        if (Owner)
+        {
+            // if the tile owns the data we need to copy
+            assign(rhs);
+        }
+        else
+        {
+            // alias tile directly
+            data.ptr = rhs.data.ptr;
+            grad.ptr = rhs.grad.ptr;
+            initialized = rhs.initialized;
+        }
 
         return *this;
-    }
+    }
+
+    inline CUDA_CALLABLE auto& operator=(const tile_shared_t& rhs)
+    {
+        if (Owner)
+        {
+            // if the tile owns the data we need to copy
+            assign(rhs);
+        }
+        else
+        {
+            // alias tile directly
+            data.ptr = rhs.data.ptr;
+            grad.ptr = rhs.grad.ptr;
+            initialized = rhs.initialized;
+        }
+
+        return *this;
+    }
 
     // assign from a global tile (load)
 
@@ -972,6 +1023,21 @@ struct tile_shared_t
         return *this;
     }
 
+    // define the += operator which is used during backward pass codegen
+    // when returning a register tile from a user defined function
+    template<typename OtherLayout>
+    inline CUDA_CALLABLE auto& operator += (const tile_register_t<T, OtherLayout>& rhs)
+    {
+        grad_add(rhs);
+        return *this;
+    }
+
+    inline CUDA_CALLABLE auto& operator += (const tile_shared_t<T, Layout>& rhs)
+    {
+        grad_add(rhs);
+        return *this;
+    }
+
     // in-place zero
     inline CUDA_CALLABLE void zero()
     {
@@ -1053,6 +1119,27 @@ struct tile_shared_t
         WP_TILE_SYNC();
     }
 
+    // shared tile deep copy
+    template <typename OtherT, typename OtherLayout, bool OtherOwner>
+    inline CUDA_CALLABLE void assign(const tile_shared_t<OtherT, OtherLayout, OtherOwner>& tile)
+    {
+        // check dimensions are compatible
+        static_assert(Layout::Size == OtherLayout::Size, "Expected Size == OtherLayout::Size");
+
+        if (initialized)
+            WP_TILE_SYNC();
+
+        WP_PRAGMA_UNROLL
+        for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+        {
+            auto c = Layout::coord_from_linear(i);
+            data(c) = tile.data(c);
+        }
+
+        initialized = true;
+        WP_TILE_SYNC();
+    }
+
     // in-place gradient zero
     inline CUDA_CALLABLE void grad_zero()
     {
@@ -1092,8 +1179,21 @@ struct tile_shared_t
         WP_TILE_SYNC();
     }
 
+    // accumulate gradients onto this tile from another shared tile
+    inline CUDA_CALLABLE void grad_add(const tile_shared_t<T, Layout>& tile)
+    {
+        WP_PRAGMA_UNROLL
+        for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+        {
+            auto c = Layout::coord_from_linear(i);
+            grad(c) += tile.grad(c);
+        }
+
+        WP_TILE_SYNC();
+    }
+
     // accumulate gradient onto this tile from a global array
-    CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
+    inline CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
     {
         WP_PRAGMA_UNROLL
         for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
@@ -1477,9 +1577,16 @@ void tile_register_t<T, L>::print() const
 // print entry points
 template <typename T, typename L>
 inline CUDA_CALLABLE void print(const tile_register_t<T, L>& t) { t.print(); }
+
+template <typename T, typename L>
+inline CUDA_CALLABLE void adj_print(const tile_register_t<T, L>& t, const tile_register_t<T, L>& a) { a.print(); }
+
 template <typename T, typename L, bool Owner>
 inline CUDA_CALLABLE void print(const tile_shared_t<T, L, Owner>& t) { t.print(); }
 
+template <typename T, typename L, bool Owner>
+inline CUDA_CALLABLE void adj_print(const tile_shared_t<T, L, Owner>& t, const tile_shared_t<T, L, Owner>& a) { a.print(true); }
+
 template <typename T, typename L, bool O>
 inline CUDA_CALLABLE int len(const tile_shared_t<T, L, O>& t)
 {
@@ -1502,13 +1609,81 @@ inline CUDA_CALLABLE void adj_len(const tile_register_t<T,L>& t, const AdjTile&
 {
 }
 
+// select specialization for shared tiles
+template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
+inline CUDA_CALLABLE auto select(const C& cond, const tile_register_t<T, LRegister>& a, const tile_shared_t<T, LShared, Owner>& b)
+{
+    // The double NOT operator !! casts to bool without compiler warnings.
+    return (!!cond) ? b.copy_to_register() : a;
+}
 
-template <typename T, typename L>
-inline CUDA_CALLABLE void adj_print(const tile_register_t<T, L>& t, const tile_register_t<T, L>& a) { a.print(); }
-template <typename T, typename L, bool Owner>
-inline CUDA_CALLABLE void adj_print(const tile_shared_t<T, L, Owner>& t, const tile_shared_t<T, L, Owner>& a) { a.print(true); }
+template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
+inline CUDA_CALLABLE auto select(const C& cond, const tile_shared_t<T, LShared, Owner>& a, const tile_register_t<T, LRegister>& b)
+{
+    // The double NOT operator !! casts to bool without compiler warnings.
+    return (!!cond) ? b : a.copy_to_register();
+}
+
+template <typename C, typename T, typename L, bool Owner>
+inline CUDA_CALLABLE auto select(const C& cond, const tile_shared_t<T, L, Owner>& a, const tile_shared_t<T, L, Owner>& b)
+{
+    // The double NOT operator !! casts to bool without compiler warnings.
+    return (!!cond) ? tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr) : tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr);
+}
+
+template <typename C, typename T, typename L, bool LOwner, bool ROwner>
+inline CUDA_CALLABLE auto select(const C& cond, const tile_shared_t<T, L, LOwner>& a, const tile_shared_t<T, L, ROwner>& b)
+{
+    // The double NOT operator !! casts to bool without compiler warnings.
+    return (!!cond) ? tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr) : tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr);
+}
 
+// adj_select same as in builtin.h
 
+// where specialization for register/shared tiles
+template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
+inline CUDA_CALLABLE auto where(const C& cond, const tile_register_t<T, LRegister>& a, const tile_shared_t<T, LShared, Owner>& b)
+{
+    // The double NOT operator !! casts to bool without compiler warnings.
+    return (!!cond) ? a : b.copy_to_register();
+}
+
+template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
+inline CUDA_CALLABLE auto where(const C& cond, const tile_shared_t<T, LShared, Owner>& a, const tile_register_t<T, LRegister>& b)
+{
+    // The double NOT operator !! casts to bool without compiler warnings.
+    return (!!cond) ? a.copy_to_register() : b;
+}
+
+template <typename C, typename T, typename L, bool Owner>
+inline CUDA_CALLABLE auto where(const C& cond, const tile_shared_t<T, L, Owner>& a, const tile_shared_t<T, L, Owner>& b)
+{
+    // The double NOT operator !! casts to bool without compiler warnings.
+    return (!!cond) ? tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr) : tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr);
+}
+
+template <typename C, typename T, typename L, bool LOwner, bool ROwner>
+inline CUDA_CALLABLE auto where(const C& cond, const tile_shared_t<T, L, LOwner>& a, const tile_shared_t<T, L, ROwner>& b)
+{
+    // The double NOT operator !! casts to bool without compiler warnings.
+    return (!!cond) ? tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr) : tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr);
+}
+
+// adj_where same as in builtin.h
+
+// copy specialization for shared tiles, the lvalue this gets assigned to is owning, thus, this invokes the copy assign path
+template <typename T, typename L, bool Owner>
+inline CUDA_CALLABLE auto copy(const tile_shared_t<T, L, Owner>& t)
+{
+    return tile_shared_t<T, L, false>(t.data.ptr, t.grad.ptr);
+}
+
+template <typename T, typename L, bool Owner>
+inline CUDA_CALLABLE void adj_copy(const tile_shared_t<T, L, Owner>& src, tile_shared_t<T, L, Owner>& adj_src, tile_shared_t<T, L, Owner>& adj_dest)
+{
+    adj_src += adj_dest;
+    adj_dest.grad_zero();
+}
 
 // helpers to allocate shared tiles
 template <typename T, typename Shape, typename Strides, bool RequiresGrad>
@@ -3048,7 +3223,7 @@ template <typename Tile, typename AdjTile>
 inline CUDA_CALLABLE void adj_tile_transpose(Tile& t, Tile& adj_t, AdjTile& adj_ret)
 {
     auto a = tile_transpose(adj_ret);
-    auto b = adj_t;
+    auto& b = adj_t;
 
     adj_t.assign(tile_add(a,b));
 }
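
Note on the tile.h changes: shared tiles now have explicit copy semantics. Assignment between shared tiles deep-copies through the new assign() path when the destination tile owns its storage, and aliases the source pointers when it does not; the copy constructor is restricted to non-owning tiles to avoid double deallocation, and a move constructor transfers ownership. A minimal standalone sketch of that owner/alias dispatch follows (toy_tile is illustrative only, not Warp's actual tile_shared_t):

    // toy_tile: hypothetical stand-in for tile_shared_t. The Owner flag selects
    // between deep-copy (value) and aliasing (view) semantics in operator=,
    // mirroring the 1.9.1 change above.
    template <typename T, int Size, bool Owner>
    struct toy_tile
    {
        T* ptr = nullptr;

        template <bool OtherOwner>
        toy_tile& operator=(const toy_tile<T, Size, OtherOwner>& rhs)
        {
            if (Owner)
            {
                // owning destination: element-wise deep copy into our own storage
                for (int i = 0; i < Size; ++i)
                    ptr[i] = rhs.ptr[i];
            }
            else
            {
                // non-owning destination: alias the source storage directly
                ptr = rhs.ptr;
            }
            return *this;
        }
    };

This is also why the new copy() specialization returns a non-owning view: when that view is assigned to an owning destination tile, the copy-assignment path performs the actual deep copy.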
warp/native/vec.h CHANGED
@@ -343,17 +343,6 @@ inline CUDA_CALLABLE vec_t<Length, Type> add(vec_t<Length, Type> a, vec_t<Length
     return ret;
 }
 
-template<unsigned Length, typename Type>
-inline CUDA_CALLABLE vec_t<Length, Type> add(Type a, vec_t<Length, Type> b)
-{
-    vec_t<Length, Type> ret;
-    for( unsigned i=0; i < Length; ++i )
-    {
-        ret[i] = a + b[i];
-    }
-    return ret;
-}
-
 template<typename Type>
 inline CUDA_CALLABLE vec_t<2, Type> add(vec_t<2, Type> a, vec_t<2, Type> b)
 {
@@ -378,18 +367,6 @@ inline CUDA_CALLABLE vec_t<Length, Type> sub(vec_t<Length, Type> a, vec_t<Length
     return ret;
 }
 
-template<unsigned Length, typename Type>
-inline CUDA_CALLABLE vec_t<Length, Type> sub(Type a, vec_t<Length, Type> b)
-{
-    vec_t<Length, Type> ret;
-    for (unsigned i=0; i < Length; ++i)
-    {
-        ret[i] = Type(a - b[i]);
-    }
-
-    return ret;
-}
-
 template<typename Type>
 inline CUDA_CALLABLE vec_t<2, Type> sub(vec_t<2, Type> a, vec_t<2, Type> b)
 {
@@ -1303,21 +1280,6 @@ inline CUDA_CALLABLE void adj_add(vec_t<Length, Type> a, vec_t<Length, Type> b,
     adj_b += adj_ret;
 }
 
-template<unsigned Length, typename Type>
-inline CUDA_CALLABLE void adj_add(
-    Type a, vec_t<Length, Type> b,
-    Type& adj_a, vec_t<Length, Type>& adj_b,
-    const vec_t<Length, Type>& adj_ret
-)
-{
-    for (unsigned i = 0; i < Length; ++i)
-    {
-        adj_a += adj_ret.c[i];
-    }
-
-    adj_b += adj_ret;
-}
-
 template<typename Type>
 inline CUDA_CALLABLE void adj_add(vec_t<2, Type> a, vec_t<2, Type> b, vec_t<2, Type>& adj_a, vec_t<2, Type>& adj_b, const vec_t<2, Type>& adj_ret)
 {
@@ -1345,21 +1307,6 @@ inline CUDA_CALLABLE void adj_sub(vec_t<Length, Type> a, vec_t<Length, Type> b,
     adj_b -= adj_ret;
 }
 
-template<unsigned Length, typename Type>
-inline CUDA_CALLABLE void adj_sub(
-    Type a, vec_t<Length, Type> b,
-    Type& adj_a, vec_t<Length, Type>& adj_b,
-    const vec_t<Length, Type>& adj_ret
-)
-{
-    for (unsigned i = 0; i < Length; ++i)
-    {
-        adj_a += adj_ret.c[i];
-    }
-
-    adj_b -= adj_ret;
-}
-
 template<typename Type>
 inline CUDA_CALLABLE void adj_sub(vec_t<2, Type> a, vec_t<2, Type> b, vec_t<2, Type>& adj_a, vec_t<2, Type>& adj_b, const vec_t<2, Type>& adj_ret)
 {
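
Note on the vec.h changes: 1.9.1 removes the scalar-vector overloads of add()/sub() and their adjoints; only the component-wise vector-vector forms shown in the surrounding context remain. For reference, a simplified standalone sketch of the surviving pattern and its reverse-mode adjoint (toy_vec is illustrative only, not Warp's vec_t):

    // component-wise add and its adjoint; d(a+b)/da = d(a+b)/db = identity
    template <unsigned Length, typename Type>
    struct toy_vec
    {
        Type c[Length] = {};
    };

    template <unsigned Length, typename Type>
    toy_vec<Length, Type> add(const toy_vec<Length, Type>& a, const toy_vec<Length, Type>& b)
    {
        toy_vec<Length, Type> ret;
        for (unsigned i = 0; i < Length; ++i)
            ret.c[i] = a.c[i] + b.c[i];
        return ret;
    }

    template <unsigned Length, typename Type>
    void adj_add(const toy_vec<Length, Type>& a, const toy_vec<Length, Type>& b,
                 toy_vec<Length, Type>& adj_a, toy_vec<Length, Type>& adj_b,
                 const toy_vec<Length, Type>& adj_ret)
    {
        for (unsigned i = 0; i < Length; ++i)
        {
            adj_a.c[i] += adj_ret.c[i];  // gradient flows unchanged to both inputs
            adj_b.c[i] += adj_ret.c[i];
        }
    }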
warp/native/warp.cpp CHANGED
@@ -1078,9 +1078,9 @@ WP_API bool wp_cuda_graph_destroy(void* context, void* graph) { return false; }
 WP_API bool wp_cuda_graph_exec_destroy(void* context, void* graph_exec) { return false; }
 WP_API bool wp_capture_debug_dot_print(void* graph, const char *path, uint32_t flags) { return false; }
 
-WP_API bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret) { return false; }
-WP_API bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret) { return false; }
-WP_API bool wp_cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle) { return false; }
+WP_API bool wp_cuda_graph_insert_if_else(void* context, void* stream, int arch, bool use_ptx, int* condition, void** if_graph_ret, void** else_graph_ret) { return false; }
+WP_API bool wp_cuda_graph_insert_while(void* context, void* stream, int arch, bool use_ptx, int* condition, void** body_graph_ret, uint64_t* handle_ret) { return false; }
+WP_API bool wp_cuda_graph_set_condition(void* context, void* stream, int arch, bool use_ptx, int* condition, uint64_t handle) { return false; }
 WP_API bool wp_cuda_graph_pause_capture(void* context, void* stream, void** graph_ret) { return false; }
 WP_API bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph) { return false; }
 WP_API bool wp_cuda_graph_insert_child_graph(void* context, void* stream, void* child_graph) { return false; }
warp/native/warp.cu CHANGED
@@ -19,6 +19,7 @@
 #include "scan.h"
 #include "cuda_util.h"
 #include "error.h"
+#include "sort.h"
 
 #include <cstdlib>
 #include <fstream>
@@ -2448,6 +2449,9 @@ void wp_cuda_stream_destroy(void* context, void* stream)
 
     wp_cuda_stream_unregister(context, stream);
 
+    // release temporary radix sort buffer associated with this stream
+    radix_sort_release(context, stream);
+
     check_cu(cuStreamDestroy_f(static_cast<CUstream>(stream)));
 }
 
@@ -2811,11 +2815,12 @@ bool wp_cuda_graph_create_exec(void* context, void* stream, void* graph, void**
 // Support for conditional graph nodes available with CUDA 12.4+.
 #if CUDA_VERSION >= 12040
 
-// CUBIN data for compiled conditional modules, loaded on demand, keyed on device architecture
-static std::map<int, void*> g_conditional_cubins;
+// CUBIN or PTX data for compiled conditional modules, loaded on demand, keyed on device architecture
+using ModuleKey = std::pair<int, bool>; // <arch, use_ptx>
+static std::map<ModuleKey, void*> g_conditional_modules;
 
 // Compile module with conditional helper kernels
-static void* compile_conditional_module(int arch)
+static void* compile_conditional_module(int arch, bool use_ptx)
 {
     static const char* kernel_source = R"(
 typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;
@@ -2844,8 +2849,9 @@ static void* compile_conditional_module(int arch)
     )";
 
     // avoid recompilation
-    auto it = g_conditional_cubins.find(arch);
-    if (it != g_conditional_cubins.end())
+    ModuleKey key = {arch, use_ptx};
+    auto it = g_conditional_modules.find(key);
+    if (it != g_conditional_modules.end())
         return it->second;
 
     nvrtcProgram prog;
@@ -2853,11 +2859,23 @@ static void* compile_conditional_module(int arch)
         return NULL;
 
     char arch_opt[128];
-    snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=sm_%d", arch);
+    if (use_ptx)
+        snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=compute_%d", arch);
+    else
+        snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=sm_%d", arch);
 
     std::vector<const char*> opts;
     opts.push_back(arch_opt);
 
+    const bool print_debug = (std::getenv("WARP_DEBUG") != nullptr);
+    if (print_debug)
+    {
+        printf("NVRTC options (conditional module, arch=%d, use_ptx=%s):\n", arch, use_ptx ? "true" : "false");
+        for(auto o: opts) {
+            printf("%s\n", o);
+        }
+    }
+
     if (!check_nvrtc(nvrtcCompileProgram(prog, int(opts.size()), opts.data())))
     {
         size_t log_size;
@@ -2874,23 +2892,37 @@ static void* compile_conditional_module(int arch)
     // get output
     char* output = NULL;
     size_t output_size = 0;
-    check_nvrtc(nvrtcGetCUBINSize(prog, &output_size));
-    if (output_size > 0)
+
+    if (use_ptx)
+    {
+        check_nvrtc(nvrtcGetPTXSize(prog, &output_size));
+        if (output_size > 0)
+        {
+            output = new char[output_size];
+            if (check_nvrtc(nvrtcGetPTX(prog, output)))
+                g_conditional_modules[key] = output;
+        }
+    }
+    else
     {
-        output = new char[output_size];
-        if (check_nvrtc(nvrtcGetCUBIN(prog, output)))
-            g_conditional_cubins[arch] = output;
+        check_nvrtc(nvrtcGetCUBINSize(prog, &output_size));
+        if (output_size > 0)
+        {
+            output = new char[output_size];
+            if (check_nvrtc(nvrtcGetCUBIN(prog, output)))
+                g_conditional_modules[key] = output;
+        }
     }
 
     nvrtcDestroyProgram(&prog);
 
-    // return CUBIN data
+    // return CUBIN or PTX data
    return output;
 }
 
 
 // Load module with conditional helper kernels
-static CUmodule load_conditional_module(void* context)
+static CUmodule load_conditional_module(void* context, int arch, bool use_ptx)
 {
     ContextInfo* context_info = get_context_info(context);
     if (!context_info)
@@ -2900,17 +2932,15 @@ static CUmodule load_conditional_module(void* context)
     if (context_info->conditional_module)
         return context_info->conditional_module;
 
-    int arch = context_info->device_info->arch;
-
     // compile if needed
-    void* compiled_module = compile_conditional_module(arch);
+    void* compiled_module = compile_conditional_module(arch, use_ptx);
     if (!compiled_module)
     {
         fprintf(stderr, "Warp error: Failed to compile conditional kernels\n");
         return NULL;
     }
 
-    // load module
+    // load module (handles both PTX and CUBIN data automatically)
     CUmodule module = NULL;
     if (!check_cu(cuModuleLoadDataEx_f(&module, compiled_module, 0, NULL, NULL)))
     {
@@ -2923,10 +2953,10 @@ static CUmodule load_conditional_module(void* context)
     return module;
 }
 
-static CUfunction get_conditional_kernel(void* context, const char* name)
+static CUfunction get_conditional_kernel(void* context, int arch, bool use_ptx, const char* name)
 {
     // load module if needed
-    CUmodule module = load_conditional_module(context);
+    CUmodule module = load_conditional_module(context, arch, use_ptx);
     if (!module)
         return NULL;
 
@@ -2976,7 +3006,7 @@ bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
 // https://developer.nvidia.com/blog/dynamic-control-flow-in-cuda-graphs-with-conditional-nodes/
 // condition is a gpu pointer
 // if_graph_ret and else_graph_ret should be NULL if not needed
-bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret)
+bool wp_cuda_graph_insert_if_else(void* context, void* stream, int arch, bool use_ptx, int* condition, void** if_graph_ret, void** else_graph_ret)
 {
     bool has_if = if_graph_ret != NULL;
     bool has_else = else_graph_ret != NULL;
@@ -3019,9 +3049,9 @@ bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, v
     // (need to negate the condition if only the else branch is used)
     CUfunction kernel;
     if (has_if)
-        kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
+        kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
     else
-        kernel = get_conditional_kernel(context, "set_conditional_else_handle_kernel");
+        kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_else_handle_kernel");
 
     if (!kernel)
     {
@@ -3072,7 +3102,7 @@ bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, v
     check_cuda(cudaGraphConditionalHandleCreate(&if_handle, cuda_graph));
     check_cuda(cudaGraphConditionalHandleCreate(&else_handle, cuda_graph));
 
-    CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_else_handles_kernel");
+    CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_else_handles_kernel");
     if (!kernel)
     {
         wp::set_error_string("Failed to get built-in conditional kernel");
@@ -3273,7 +3303,7 @@ bool wp_cuda_graph_insert_child_graph(void* context, void* stream, void* child_g
     return true;
 }
 
-bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret)
+bool wp_cuda_graph_insert_while(void* context, void* stream, int arch, bool use_ptx, int* condition, void** body_graph_ret, uint64_t* handle_ret)
 {
     // if there's no body, it's a no-op
     if (!body_graph_ret)
@@ -3303,7 +3333,7 @@ bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, voi
         return false;
 
     // launch a kernel to set the condition handle from condition pointer
-    CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
+    CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
     if (!kernel)
     {
         wp::set_error_string("Failed to get built-in conditional kernel");
@@ -3339,14 +3369,14 @@ bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, voi
     return true;
 }
 
-bool wp_cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle)
+bool wp_cuda_graph_set_condition(void* context, void* stream, int arch, bool use_ptx, int* condition, uint64_t handle)
 {
     ContextGuard guard(context);
 
     CUstream cuda_stream = static_cast<CUstream>(stream);
 
     // launch a kernel to set the condition handle from condition pointer
-    CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
+    CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
     if (!kernel)
     {
         wp::set_error_string("Failed to get built-in conditional kernel");
@@ -3378,19 +3408,19 @@ bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
     return false;
 }
 
-bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret)
+bool wp_cuda_graph_insert_if_else(void* context, void* stream, int arch, bool use_ptx, int* condition, void** if_graph_ret, void** else_graph_ret)
 {
     wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
     return false;
 }
 
-bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret)
+bool wp_cuda_graph_insert_while(void* context, void* stream, int arch, bool use_ptx, int* condition, void** body_graph_ret, uint64_t* handle_ret)
 {
     wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
     return false;
 }
 
-bool wp_cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle)
+bool wp_cuda_graph_set_condition(void* context, void* stream, int arch, bool use_ptx, int* condition, uint64_t handle)
 {
     wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
     return false;
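
Note on the warp.cu changes: the compiled conditional-kernel helper module is now cached per (arch, use_ptx) pair instead of per architecture alone, and NVRTC targets compute_XX (PTX) or sm_XX (CUBIN) depending on the flag; cuModuleLoadDataEx accepts either form. A minimal standalone sketch of the keyed caching pattern (compile() below is a hypothetical stand-in for the NVRTC call in compile_conditional_module()):

    #include <map>
    #include <string>
    #include <utility>

    // hypothetical stand-in for the NVRTC compilation step
    static std::string compile(int arch, bool use_ptx)
    {
        return std::string(use_ptx ? "compute_" : "sm_") + std::to_string(arch);
    }

    // cache keyed on <arch, use_ptx>, mirroring g_conditional_modules above
    using ModuleKey = std::pair<int, bool>;
    static std::map<ModuleKey, std::string> g_module_cache;

    static const std::string& get_module(int arch, bool use_ptx)
    {
        ModuleKey key = {arch, use_ptx};
        auto it = g_module_cache.find(key);
        if (it != g_module_cache.end())
            return it->second;  // reuse previously compiled output
        return g_module_cache[key] = compile(arch, use_ptx);
    }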
warp/native/warp.h CHANGED
@@ -314,9 +314,9 @@ extern "C"
     WP_API bool wp_cuda_graph_exec_destroy(void* context, void* graph_exec);
     WP_API bool wp_capture_debug_dot_print(void* graph, const char *path, uint32_t flags);
 
-    WP_API bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret);
-    WP_API bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret);
-    WP_API bool wp_cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle);
+    WP_API bool wp_cuda_graph_insert_if_else(void* context, void* stream, int arch, bool use_ptx, int* condition, void** if_graph_ret, void** else_graph_ret);
+    WP_API bool wp_cuda_graph_insert_while(void* context, void* stream, int arch, bool use_ptx, int* condition, void** body_graph_ret, uint64_t* handle_ret);
+    WP_API bool wp_cuda_graph_set_condition(void* context, void* stream, int arch, bool use_ptx, int* condition, uint64_t handle);
     WP_API bool wp_cuda_graph_pause_capture(void* context, void* stream, void** graph_ret);
     WP_API bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph);
     WP_API bool wp_cuda_graph_insert_child_graph(void* context, void* stream, void* child_graph);
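
Note on the warp.h changes: the exported conditional-graph entry points now take the target architecture and a PTX flag explicitly rather than deriving them inside the library. A hedged sketch of a call site under the new signature (the arch value and use_ptx choice below are illustrative; in practice these come from Warp's Python runtime):

    #include "warp.h"  // declares wp_cuda_graph_insert_if_else with the 1.9.1 signature

    // condition must be a device pointer to an int (see the comments in warp.cu);
    // error handling is omitted for brevity.
    bool insert_branch(void* context, void* stream, int* condition,
                       void** if_graph, void** else_graph)
    {
        const int arch = 86;        // e.g. sm_86 / compute_86 (illustrative value)
        const bool use_ptx = true;  // emit PTX so the driver can JIT for newer GPUs
        return wp_cuda_graph_insert_if_else(context, stream, arch, use_ptx,
                                            condition, if_graph, else_graph);
    }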